Diffstat (limited to 'lib/Transforms')
-rw-r--r--  lib/Transforms/Coroutines/CoroElide.cpp | 5
-rw-r--r--  lib/Transforms/Coroutines/CoroFrame.cpp | 40
-rw-r--r--  lib/Transforms/Coroutines/CoroInstr.h | 5
-rw-r--r--  lib/Transforms/Coroutines/CoroSplit.cpp | 58
-rw-r--r--  lib/Transforms/Coroutines/Coroutines.cpp | 6
-rw-r--r--  lib/Transforms/IPO/ArgumentPromotion.cpp | 1309
-rw-r--r--  lib/Transforms/IPO/ConstantMerge.cpp | 26
-rw-r--r--  lib/Transforms/IPO/CrossDSOCFI.cpp | 7
-rw-r--r--  lib/Transforms/IPO/DeadArgumentElimination.cpp | 152
-rw-r--r--  lib/Transforms/IPO/FunctionAttrs.cpp | 150
-rw-r--r--  lib/Transforms/IPO/FunctionImport.cpp | 110
-rw-r--r--  lib/Transforms/IPO/GlobalDCE.cpp | 152
-rw-r--r--  lib/Transforms/IPO/GlobalOpt.cpp | 10
-rw-r--r--  lib/Transforms/IPO/GlobalSplit.cpp | 11
-rw-r--r--  lib/Transforms/IPO/IPConstantPropagation.cpp | 8
-rw-r--r--  lib/Transforms/IPO/InlineSimple.cpp | 13
-rw-r--r--  lib/Transforms/IPO/Inliner.cpp | 205
-rw-r--r--  lib/Transforms/IPO/LowerTypeTests.cpp | 308
-rw-r--r--  lib/Transforms/IPO/MergeFunctions.cpp | 299
-rw-r--r--  lib/Transforms/IPO/PartialInlining.cpp | 2
-rw-r--r--  lib/Transforms/IPO/PassManagerBuilder.cpp | 87
-rw-r--r--  lib/Transforms/IPO/SampleProfile.cpp | 271
-rw-r--r--  lib/Transforms/IPO/StripSymbols.cpp | 24
-rw-r--r--  lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp | 261
-rw-r--r--  lib/Transforms/IPO/WholeProgramDevirt.cpp | 764
-rw-r--r--  lib/Transforms/InstCombine/InstCombineAddSub.cpp | 135
-rw-r--r--  lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 1192
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCalls.cpp | 1148
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCasts.cpp | 169
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCompares.cpp | 212
-rw-r--r--  lib/Transforms/InstCombine/InstCombineInternal.h | 61
-rw-r--r--  lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp | 145
-rw-r--r--  lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 58
-rw-r--r--  lib/Transforms/InstCombine/InstCombinePHI.cpp | 6
-rw-r--r--  lib/Transforms/InstCombine/InstCombineSelect.cpp | 86
-rw-r--r--  lib/Transforms/InstCombine/InstCombineShifts.cpp | 703
-rw-r--r--  lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 562
-rw-r--r--  lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 44
-rw-r--r--  lib/Transforms/InstCombine/InstructionCombining.cpp | 285
-rw-r--r--  lib/Transforms/Instrumentation/AddressSanitizer.cpp | 156
-rw-r--r--  lib/Transforms/Instrumentation/DataFlowSanitizer.cpp | 78
-rw-r--r--  lib/Transforms/Instrumentation/EfficiencySanitizer.cpp | 22
-rw-r--r--  lib/Transforms/Instrumentation/IndirectCallPromotion.cpp | 506
-rw-r--r--  lib/Transforms/Instrumentation/InstrProfiling.cpp | 170
-rw-r--r--  lib/Transforms/Instrumentation/Instrumentation.cpp | 1
-rw-r--r--  lib/Transforms/Instrumentation/MemorySanitizer.cpp | 103
-rw-r--r--  lib/Transforms/Instrumentation/PGOInstrumentation.cpp | 353
-rw-r--r--  lib/Transforms/Instrumentation/SanitizerCoverage.cpp | 100
-rw-r--r--  lib/Transforms/Instrumentation/ThreadSanitizer.cpp | 54
-rw-r--r--  lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h | 14
-rw-r--r--  lib/Transforms/ObjCARC/ObjCARCContract.cpp | 1
-rw-r--r--  lib/Transforms/ObjCARC/ObjCARCOpts.cpp | 92
-rw-r--r--  lib/Transforms/Scalar/ADCE.cpp | 43
-rw-r--r--  lib/Transforms/Scalar/AlignmentFromAssumptions.cpp | 12
-rw-r--r--  lib/Transforms/Scalar/BDCE.cpp | 4
-rw-r--r--  lib/Transforms/Scalar/CMakeLists.txt | 2
-rw-r--r--  lib/Transforms/Scalar/ConstantHoisting.cpp | 21
-rw-r--r--  lib/Transforms/Scalar/CorrelatedValuePropagation.cpp | 25
-rw-r--r--  lib/Transforms/Scalar/DCE.cpp | 9
-rw-r--r--  lib/Transforms/Scalar/DeadStoreElimination.cpp | 62
-rw-r--r--  lib/Transforms/Scalar/EarlyCSE.cpp | 92
-rw-r--r--  lib/Transforms/Scalar/Float2Int.cpp | 11
-rw-r--r--  lib/Transforms/Scalar/GVN.cpp | 498
-rw-r--r--  lib/Transforms/Scalar/GVNHoist.cpp | 89
-rw-r--r--  lib/Transforms/Scalar/GuardWidening.cpp | 11
-rw-r--r--  lib/Transforms/Scalar/IndVarSimplify.cpp | 19
-rw-r--r--  lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp | 44
-rw-r--r--  lib/Transforms/Scalar/InferAddressSpaces.cpp | 903
-rw-r--r--  lib/Transforms/Scalar/JumpThreading.cpp | 293
-rw-r--r--  lib/Transforms/Scalar/LICM.cpp | 170
-rw-r--r--  lib/Transforms/Scalar/LoadCombine.cpp | 19
-rw-r--r--  lib/Transforms/Scalar/LoopDeletion.cpp | 127
-rw-r--r--  lib/Transforms/Scalar/LoopDistribute.cpp | 46
-rw-r--r--  lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 8
-rw-r--r--  lib/Transforms/Scalar/LoopInstSimplify.cpp | 4
-rw-r--r--  lib/Transforms/Scalar/LoopInterchange.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/LoopLoadElimination.cpp | 105
-rw-r--r--  lib/Transforms/Scalar/LoopPassManager.cpp | 7
-rw-r--r--  lib/Transforms/Scalar/LoopPredication.cpp | 282
-rw-r--r--  lib/Transforms/Scalar/LoopRotation.cpp | 49
-rw-r--r--  lib/Transforms/Scalar/LoopSimplifyCFG.cpp | 1
-rw-r--r--  lib/Transforms/Scalar/LoopSink.cpp | 40
-rw-r--r--  lib/Transforms/Scalar/LoopStrengthReduce.cpp | 432
-rw-r--r--  lib/Transforms/Scalar/LoopUnrollPass.cpp | 172
-rw-r--r--  lib/Transforms/Scalar/LoopUnswitch.cpp | 301
-rw-r--r--  lib/Transforms/Scalar/LowerExpectIntrinsic.cpp | 4
-rw-r--r--  lib/Transforms/Scalar/MemCpyOptimizer.cpp | 147
-rw-r--r--  lib/Transforms/Scalar/MergedLoadStoreMotion.cpp | 179
-rw-r--r--  lib/Transforms/Scalar/NaryReassociate.cpp | 12
-rw-r--r--  lib/Transforms/Scalar/NewGVN.cpp | 2750
-rw-r--r--  lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp | 12
-rw-r--r--  lib/Transforms/Scalar/Reassociate.cpp | 13
-rw-r--r--  lib/Transforms/Scalar/RewriteStatepointsForGC.cpp | 59
-rw-r--r--  lib/Transforms/Scalar/SCCP.cpp | 171
-rw-r--r--  lib/Transforms/Scalar/SROA.cpp | 24
-rw-r--r--  lib/Transforms/Scalar/Scalar.cpp | 9
-rw-r--r--  lib/Transforms/Scalar/Scalarizer.cpp | 19
-rw-r--r--  lib/Transforms/Scalar/SimplifyCFGPass.cpp | 76
-rw-r--r--  lib/Transforms/Scalar/Sink.cpp | 12
-rw-r--r--  lib/Transforms/Utils/AddDiscriminators.cpp | 24
-rw-r--r--  lib/Transforms/Utils/BasicBlockUtils.cpp | 9
-rw-r--r--  lib/Transforms/Utils/BuildLibCalls.cpp | 448
-rw-r--r--  lib/Transforms/Utils/BypassSlowDivision.cpp | 532
-rw-r--r--  lib/Transforms/Utils/CMakeLists.txt | 4
-rw-r--r--  lib/Transforms/Utils/CloneFunction.cpp | 65
-rw-r--r--  lib/Transforms/Utils/CloneModule.cpp | 13
-rw-r--r--  lib/Transforms/Utils/CodeExtractor.cpp | 19
-rw-r--r--  lib/Transforms/Utils/DemoteRegToStack.cpp | 17
-rw-r--r--  lib/Transforms/Utils/Evaluator.cpp | 3
-rw-r--r--  lib/Transforms/Utils/FunctionComparator.cpp | 8
-rw-r--r--  lib/Transforms/Utils/FunctionImportUtils.cpp | 14
-rw-r--r--  lib/Transforms/Utils/GlobalStatus.cpp | 21
-rw-r--r--  lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp | 2
-rw-r--r--  lib/Transforms/Utils/InlineFunction.cpp | 152
-rw-r--r--  lib/Transforms/Utils/LCSSA.cpp | 84
-rw-r--r--  lib/Transforms/Utils/LibCallsShrinkWrap.cpp | 182
-rw-r--r--  lib/Transforms/Utils/Local.cpp | 148
-rw-r--r--  lib/Transforms/Utils/LoopSimplify.cpp | 40
-rw-r--r--  lib/Transforms/Utils/LoopUnroll.cpp | 167
-rw-r--r--  lib/Transforms/Utils/LoopUnrollPeel.cpp | 98
-rw-r--r--  lib/Transforms/Utils/LoopUnrollRuntime.cpp | 52
-rw-r--r--  lib/Transforms/Utils/LoopUtils.cpp | 25
-rw-r--r--  lib/Transforms/Utils/LowerMemIntrinsics.cpp | 231
-rw-r--r--  lib/Transforms/Utils/LowerSwitch.cpp | 8
-rw-r--r--  lib/Transforms/Utils/Mem2Reg.cpp | 7
-rw-r--r--  lib/Transforms/Utils/MemorySSA.cpp | 2305
-rw-r--r--  lib/Transforms/Utils/MetaRenamer.cpp | 17
-rw-r--r--  lib/Transforms/Utils/ModuleUtils.cpp | 23
-rw-r--r--  lib/Transforms/Utils/PredicateInfo.cpp | 782
-rw-r--r--  lib/Transforms/Utils/PromoteMemoryToRegister.cpp | 93
-rw-r--r--  lib/Transforms/Utils/SSAUpdater.cpp | 27
-rw-r--r--  lib/Transforms/Utils/SimplifyCFG.cpp | 171
-rw-r--r--  lib/Transforms/Utils/SimplifyIndVar.cpp | 42
-rw-r--r--  lib/Transforms/Utils/SimplifyInstructions.cpp | 22
-rw-r--r--  lib/Transforms/Utils/SimplifyLibCalls.cpp | 399
-rw-r--r--  lib/Transforms/Utils/Utils.cpp | 3
-rw-r--r--  lib/Transforms/Utils/VNCoercion.cpp | 482
-rw-r--r--  lib/Transforms/Utils/ValueMapper.cpp | 1
-rw-r--r--  lib/Transforms/Vectorize/BBVectorize.cpp | 77
-rw-r--r--  lib/Transforms/Vectorize/LoadStoreVectorizer.cpp | 13
-rw-r--r--  lib/Transforms/Vectorize/LoopVectorize.cpp | 2965
-rw-r--r--  lib/Transforms/Vectorize/SLPVectorizer.cpp | 846
142 files changed, 17793 insertions, 10977 deletions
diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp
index 99974d8da64c0..c6ac3f614ff7e 100644
--- a/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/lib/Transforms/Coroutines/CoroElide.cpp
@@ -92,7 +92,7 @@ static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
// Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type.
static Type *getFrameType(Function *Resume) {
- auto *ArgType = Resume->getArgumentList().front().getType();
+ auto *ArgType = Resume->arg_begin()->getType();
return cast<PointerType>(ArgType)->getElementType();
}
@@ -127,7 +127,8 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) {
// is spilled into the coroutine frame and recreate the alignment information
// here. Possibly we will need to do a mini SROA here and break the coroutine
// frame into individual AllocaInst recreating the original alignment.
- auto *Frame = new AllocaInst(FrameTy, "", InsertPt);
+ const DataLayout &DL = F->getParent()->getDataLayout();
+ auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
auto *FrameVoidPtr =
new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);
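
A note on the constructor change above: newer AllocaInst constructors take the alloca address space explicitly, so it is read from the module's DataLayout rather than defaulted. A minimal sketch of the pattern under that assumption (the helper name is illustrative, not from the patch):

    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/Module.h"

    using namespace llvm;

    // Allocate a coroutine frame on the stack, passing the alloca address
    // space reported by the target's DataLayout.
    static AllocaInst *createFrameAlloca(Function *F, Type *FrameTy,
                                         Instruction *InsertPt) {
      const DataLayout &DL = F->getParent()->getDataLayout();
      return new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "frame", InsertPt);
    }
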
diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp
index bb28558a29e24..19e6789dfa74a 100644
--- a/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -133,6 +133,7 @@ struct SuspendCrossingInfo {
};
} // end anonymous namespace
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(StringRef Label,
BitVector const &BV) const {
dbgs() << Label << ":";
@@ -151,6 +152,7 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
}
dbgs() << "\n";
}
+#endif
SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
: Mapping(F) {
@@ -420,15 +422,31 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
report_fatal_error("Coroutines cannot handle non static allocas yet");
} else {
// Otherwise, create a store instruction storing the value into the
- // coroutine frame. For, argument, we will place the store instruction
- // right after the coroutine frame pointer instruction, i.e. bitcase of
- // coro.begin from i8* to %f.frame*. For all other values, the spill is
- // placed immediately after the definition.
- Builder.SetInsertPoint(
- isa<Argument>(CurrentValue)
- ? FramePtr->getNextNode()
- : dyn_cast<Instruction>(E.def())->getNextNode());
+ // coroutine frame.
+
+ Instruction *InsertPt = nullptr;
+ if (isa<Argument>(CurrentValue)) {
+ // For arguments, we will place the store instruction right after
+ // the coroutine frame pointer instruction, i.e. bitcast of
+ // coro.begin from i8* to %f.frame*.
+ InsertPt = FramePtr->getNextNode();
+ } else if (auto *II = dyn_cast<InvokeInst>(CurrentValue)) {
+ // If we are spilling the result of the invoke instruction, split the
+ // normal edge and insert the spill in the new block.
+ auto NewBB = SplitEdge(II->getParent(), II->getNormalDest());
+ InsertPt = NewBB->getTerminator();
+ } else if (dyn_cast<PHINode>(CurrentValue)) {
+ // Skip the PHINodes and EH pads instructions.
+ InsertPt =
+ &*cast<Instruction>(E.def())->getParent()->getFirstInsertionPt();
+ } else {
+ // For all other values, the spill is placed immediately after
+ // the definition.
+ assert(!isa<TerminatorInst>(E.def()) && "unexpected terminator");
+ InsertPt = cast<Instruction>(E.def())->getNextNode();
+ }
+ Builder.SetInsertPoint(InsertPt);
auto *G = Builder.CreateConstInBoundsGEP2_32(
FrameTy, FramePtr, 0, Index,
CurrentValue->getName() + Twine(".spill.addr"));
@@ -484,7 +502,7 @@ static void rewritePHIs(BasicBlock &BB) {
// loop:
// %n.val = phi i32[%n, %entry], [%inc, %loop]
//
- // It will create:
+ // It will create:
//
// loop.from.entry:
// %n.loop.pre = phi i32 [%n, %entry]
@@ -687,13 +705,12 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
Spills.emplace_back(&I, U);
// Rewrite materializable instructions to be materialized at the use point.
- std::sort(Spills.begin(), Spills.end());
DEBUG(dump("Materializations", Spills));
rewriteMaterializableInstructions(Builder, Spills);
// Collect the spills for arguments and other not-materializable values.
Spills.clear();
- for (Argument &A : F.getArgumentList())
+ for (Argument &A : F.args())
for (User *U : A.users())
if (Checker.isDefinitionAcrossSuspend(A, U))
Spills.emplace_back(&A, U);
@@ -719,7 +736,6 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
Spills.emplace_back(&I, U);
}
}
- std::sort(Spills.begin(), Spills.end());
DEBUG(dump("Spills", Spills));
moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin);
Shape.FrameTy = buildFrameType(F, Shape, Spills);
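
Condensed for illustration, the insertion-point selection added above amounts to the sketch below; Def stands in for the spilled definition and FramePtr for the frame-pointer bitcast, and the helper itself is not part of the patch:

    #include "llvm/IR/Argument.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"

    using namespace llvm;

    // Decide where the store into the coroutine frame goes for a value that
    // lives across a suspend point.
    static Instruction *getSpillInsertionPt(Value *Def, Instruction *FramePtr) {
      if (isa<Argument>(Def))
        // Arguments are spilled right after the frame pointer bitcast.
        return FramePtr->getNextNode();
      if (auto *II = dyn_cast<InvokeInst>(Def)) {
        // An invoke's result is only valid on the normal edge, so split that
        // edge and spill in the freshly created block.
        BasicBlock *NewBB = SplitEdge(II->getParent(), II->getNormalDest());
        return NewBB->getTerminator();
      }
      if (isa<PHINode>(Def))
        // Stores cannot be placed among PHIs or EH pads; skip past them.
        return &*cast<Instruction>(Def)->getParent()->getFirstInsertionPt();
      // Everything else is spilled immediately after its definition.
      return cast<Instruction>(Def)->getNextNode();
    }
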
diff --git a/lib/Transforms/Coroutines/CoroInstr.h b/lib/Transforms/Coroutines/CoroInstr.h
index e03cef4bfc460..5c666bdfea1f6 100644
--- a/lib/Transforms/Coroutines/CoroInstr.h
+++ b/lib/Transforms/Coroutines/CoroInstr.h
@@ -23,6 +23,9 @@
// the Coroutine library.
//===----------------------------------------------------------------------===//
+#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H
+#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H
+
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IntrinsicInst.h"
@@ -316,3 +319,5 @@ public:
};
} // End namespace llvm.
+
+#endif
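
The guard added above follows the usual naming for headers that live under lib/ rather than include/. A generic skeleton, with a purely illustrative name:

    #ifndef LLVM_LIB_TRANSFORMS_EXAMPLE_EXAMPLEINSTR_H
    #define LLVM_LIB_TRANSFORMS_EXAMPLE_EXAMPLEINSTR_H

    // Declarations private to this library go here.

    #endif // LLVM_LIB_TRANSFORMS_EXAMPLE_EXAMPLEINSTR_H
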
diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp
index 7a3f4f60bae90..ab648f884c5b1 100644
--- a/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -22,6 +22,7 @@
#include "CoroInternal.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
@@ -144,6 +145,33 @@ static void replaceFallthroughCoroEnd(IntrinsicInst *End,
BB->getTerminator()->eraseFromParent();
}
+// In Resumers, we replace unwind coro.end with True to force the immediate
+// unwind to caller.
+static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
+ if (Shape.CoroEnds.empty())
+ return;
+
+ LLVMContext &Context = Shape.CoroEnds.front()->getContext();
+ auto *True = ConstantInt::getTrue(Context);
+ for (CoroEndInst *CE : Shape.CoroEnds) {
+ if (!CE->isUnwind())
+ continue;
+
+ auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
+
+ // If coro.end has an associated bundle, add cleanupret instruction.
+ if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
+ Value *FromPad = Bundle->Inputs[0];
+ auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
+ NewCE->getParent()->splitBasicBlock(NewCE);
+ CleanupRet->getParent()->getTerminator()->eraseFromParent();
+ }
+
+ NewCE->replaceAllUsesWith(True);
+ NewCE->eraseFromParent();
+ }
+}
+
// Rewrite final suspend point handling. We do not use suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
@@ -157,9 +185,9 @@ static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
coro::Shape &Shape, SwitchInst *Switch,
bool IsDestroy) {
assert(Shape.HasFinalSuspend);
- auto FinalCase = --Switch->case_end();
- BasicBlock *ResumeBB = FinalCase.getCaseSuccessor();
- Switch->removeCase(FinalCase);
+ auto FinalCaseIt = std::prev(Switch->case_end());
+ BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
+ Switch->removeCase(FinalCaseIt);
if (IsDestroy) {
BasicBlock *OldSwitchBB = Switch->getParent();
auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
@@ -195,7 +223,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
// Replace all args with undefs. The buildCoroutineFrame algorithm already
// rewritten access to the args that occurs after suspend points with loads
// and stores to/from the coroutine frame.
- for (Argument &A : F.getArgumentList())
+ for (Argument &A : F.args())
VMap[&A] = UndefValue::get(A.getType());
SmallVector<ReturnInst *, 4> Returns;
@@ -216,9 +244,9 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
// Remove old return attributes.
NewF->removeAttributes(
- AttributeSet::ReturnIndex,
- AttributeSet::get(
- NewF->getContext(), AttributeSet::ReturnIndex,
+ AttributeList::ReturnIndex,
+ AttributeList::get(
+ NewF->getContext(), AttributeList::ReturnIndex,
AttributeFuncs::typeIncompatible(NewF->getReturnType())));
// Make AllocaSpillBlock the new entry block.
@@ -236,7 +264,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
IRBuilder<> Builder(&NewF->getEntryBlock().front());
// Remap frame pointer.
- Argument *NewFramePtr = &NewF->getArgumentList().front();
+ Argument *NewFramePtr = &*NewF->arg_begin();
Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
NewFramePtr->takeName(OldFramePtr);
OldFramePtr->replaceAllUsesWith(NewFramePtr);
@@ -270,9 +298,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
// Remove coro.end intrinsics.
replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
- // FIXME: coming in upcoming patches:
- // replaceUnwindCoroEnds(Shape.CoroEnds, VMap);
-
+ replaceUnwindCoroEnds(Shape, VMap);
// Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
// to suppress deallocation code.
coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
@@ -284,8 +310,16 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
}
static void removeCoroEnds(coro::Shape &Shape) {
- for (CoroEndInst *CE : Shape.CoroEnds)
+ if (Shape.CoroEnds.empty())
+ return;
+
+ LLVMContext &Context = Shape.CoroEnds.front()->getContext();
+ auto *False = ConstantInt::getFalse(Context);
+
+ for (CoroEndInst *CE : Shape.CoroEnds) {
+ CE->replaceAllUsesWith(False);
CE->eraseFromParent();
+ }
}
static void replaceFrameSize(coro::Shape &Shape) {
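
Two API migrations recur in this file: AttributeSet indices become AttributeList indices, and switch-case iterators now yield a case handle that is accessed with '->'. A minimal sketch of the latter, assuming a SwitchInst with at least one case (the helper name is illustrative, not from the patch):

    #include "llvm/IR/Instructions.h"
    #include <iterator>

    using namespace llvm;

    // Detach the last case from a switch and hand back its successor block.
    // Dereferencing a case iterator yields a case handle, hence the '->'.
    static BasicBlock *removeLastCase(SwitchInst *Switch) {
      auto LastCaseIt = std::prev(Switch->case_end());
      BasicBlock *Succ = LastCaseIt->getCaseSuccessor();
      Switch->removeCase(LastCaseIt);
      return Succ;
    }
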
diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp
index 877ec34b4d3b2..ea48043f9381f 100644
--- a/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/lib/Transforms/Coroutines/Coroutines.cpp
@@ -245,9 +245,9 @@ void coro::Shape::buildFrom(Function &F) {
if (CoroBegin)
report_fatal_error(
"coroutine should have exactly one defining @llvm.coro.begin");
- CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
- CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NoAlias);
- CB->removeAttribute(AttributeSet::FunctionIndex,
+ CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+ CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
+ CB->removeAttribute(AttributeList::FunctionIndex,
Attribute::NoDuplicate);
CoroBegin = CB;
}
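
The same AttributeSet-to-AttributeList renaming appears here, with the index constants carrying over unchanged. As a sketch, tagging a call's return value now reads as follows (CB and the helper name are illustrative assumptions):

    #include "llvm/IR/Attributes.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Attach return-value attributes to a call using AttributeList indices.
    static void markReturnNonNullNoAlias(CallInst *CB) {
      CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
      CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
    }
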
diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp
index 65b7bad3b1ed3..a2c8a32dfe866 100644
--- a/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -29,8 +29,9 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/ArgumentPromotion.h"
#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AliasAnalysis.h"
@@ -38,6 +39,7 @@
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/CFG.h"
@@ -51,323 +53,400 @@
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO.h"
#include <set>
using namespace llvm;
#define DEBUG_TYPE "argpromotion"
-STATISTIC(NumArgumentsPromoted , "Number of pointer arguments promoted");
+STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
STATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted");
-STATISTIC(NumByValArgsPromoted , "Number of byval arguments promoted");
-STATISTIC(NumArgumentsDead , "Number of dead pointer args eliminated");
+STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted");
+STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");
-namespace {
- /// ArgPromotion - The 'by reference' to 'by value' argument promotion pass.
- ///
- struct ArgPromotion : public CallGraphSCCPass {
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- getAAResultsAnalysisUsage(AU);
- CallGraphSCCPass::getAnalysisUsage(AU);
- }
+/// A vector used to hold the indices of a single GEP instruction
+typedef std::vector<uint64_t> IndicesVector;
- bool runOnSCC(CallGraphSCC &SCC) override;
- static char ID; // Pass identification, replacement for typeid
- explicit ArgPromotion(unsigned maxElements = 3)
- : CallGraphSCCPass(ID), maxElements(maxElements) {
- initializeArgPromotionPass(*PassRegistry::getPassRegistry());
- }
+/// DoPromotion - This method actually performs the promotion of the specified
+/// arguments, and returns the new function. At this point, we know that it's
+/// safe to do so.
+static Function *
+doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote,
+ SmallPtrSetImpl<Argument *> &ByValArgsToTransform,
+ Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>>
+ ReplaceCallSite) {
- private:
+ // Start by computing a new prototype for the function, which is the same as
+ // the old function, but has modified arguments.
+ FunctionType *FTy = F->getFunctionType();
+ std::vector<Type *> Params;
- using llvm::Pass::doInitialization;
- bool doInitialization(CallGraph &CG) override;
- /// The maximum number of elements to expand, or 0 for unlimited.
- unsigned maxElements;
- };
-}
+ typedef std::set<std::pair<Type *, IndicesVector>> ScalarizeTable;
-/// A vector used to hold the indices of a single GEP instruction
-typedef std::vector<uint64_t> IndicesVector;
+ // ScalarizedElements - If we are promoting a pointer that has elements
+ // accessed out of it, keep track of which elements are accessed so that we
+ // can add one argument for each.
+ //
+ // Arguments that are directly loaded will have a zero element value here, to
+ // handle cases where there are both a direct load and GEP accesses.
+ //
+ std::map<Argument *, ScalarizeTable> ScalarizedElements;
-static CallGraphNode *
-PromoteArguments(CallGraphNode *CGN, CallGraph &CG,
- function_ref<AAResults &(Function &F)> AARGetter,
- unsigned MaxElements);
-static bool isDenselyPacked(Type *type, const DataLayout &DL);
-static bool canPaddingBeAccessed(Argument *Arg);
-static bool isSafeToPromoteArgument(Argument *Arg, bool isByVal, AAResults &AAR,
- unsigned MaxElements);
-static CallGraphNode *
-DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote,
- SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG);
+ // OriginalLoads - Keep track of a representative load instruction from the
+ // original function so that we can tell the alias analysis implementation
+ // what the new GEP/Load instructions we are inserting look like.
+ // We need to keep the original loads for each argument and the elements
+ // of the argument that are accessed.
+ std::map<std::pair<Argument *, IndicesVector>, LoadInst *> OriginalLoads;
-char ArgPromotion::ID = 0;
-INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion",
- "Promote 'by reference' arguments to scalars", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(ArgPromotion, "argpromotion",
- "Promote 'by reference' arguments to scalars", false, false)
+ // Attribute - Keep track of the parameter attributes for the arguments
+ // that we are *not* promoting. For the ones that we do promote, the parameter
+ // attributes are lost
+ SmallVector<AttributeSet, 8> ArgAttrVec;
+ AttributeList PAL = F->getAttributes();
-Pass *llvm::createArgumentPromotionPass(unsigned maxElements) {
- return new ArgPromotion(maxElements);
-}
+ // First, determine the new argument list
+ unsigned ArgIndex = 0;
+ for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
+ ++I, ++ArgIndex) {
+ if (ByValArgsToTransform.count(&*I)) {
+ // Simple byval argument? Just add all the struct element types.
+ Type *AgTy = cast<PointerType>(I->getType())->getElementType();
+ StructType *STy = cast<StructType>(AgTy);
+ Params.insert(Params.end(), STy->element_begin(), STy->element_end());
+ ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(),
+ AttributeSet());
+ ++NumByValArgsPromoted;
+ } else if (!ArgsToPromote.count(&*I)) {
+ // Unchanged argument
+ Params.push_back(I->getType());
+ ArgAttrVec.push_back(PAL.getParamAttributes(ArgIndex));
+ } else if (I->use_empty()) {
+ // Dead argument (which are always marked as promotable)
+ ++NumArgumentsDead;
+ } else {
+ // Okay, this is being promoted. This means that the only uses are loads
+ // or GEPs which are only used by loads
-static bool runImpl(CallGraphSCC &SCC, CallGraph &CG,
- function_ref<AAResults &(Function &F)> AARGetter,
- unsigned MaxElements) {
- bool Changed = false, LocalChange;
+ // In this table, we will track which indices are loaded from the argument
+ // (where direct loads are tracked as no indices).
+ ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
+ for (User *U : I->users()) {
+ Instruction *UI = cast<Instruction>(U);
+ Type *SrcTy;
+ if (LoadInst *L = dyn_cast<LoadInst>(UI))
+ SrcTy = L->getType();
+ else
+ SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType();
+ IndicesVector Indices;
+ Indices.reserve(UI->getNumOperands() - 1);
+ // Since loads will only have a single operand, and GEPs only a single
+ // non-index operand, this will record direct loads without any indices,
+ // and gep+loads with the GEP indices.
+ for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end();
+ II != IE; ++II)
+ Indices.push_back(cast<ConstantInt>(*II)->getSExtValue());
+ // GEPs with a single 0 index can be merged with direct loads
+ if (Indices.size() == 1 && Indices.front() == 0)
+ Indices.clear();
+ ArgIndices.insert(std::make_pair(SrcTy, Indices));
+ LoadInst *OrigLoad;
+ if (LoadInst *L = dyn_cast<LoadInst>(UI))
+ OrigLoad = L;
+ else
+ // Take any load, we will use it only to update Alias Analysis
+ OrigLoad = cast<LoadInst>(UI->user_back());
+ OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad;
+ }
- do { // Iterate until we stop promoting from this SCC.
- LocalChange = false;
- // Attempt to promote arguments from all functions in this SCC.
- for (CallGraphNode *OldNode : SCC) {
- if (CallGraphNode *NewNode =
- PromoteArguments(OldNode, CG, AARGetter, MaxElements)) {
- LocalChange = true;
- SCC.ReplaceNode(OldNode, NewNode);
+ // Add a parameter to the function for each element passed in.
+ for (const auto &ArgIndex : ArgIndices) {
+ // not allowed to dereference ->begin() if size() is 0
+ Params.push_back(GetElementPtrInst::getIndexedType(
+ cast<PointerType>(I->getType()->getScalarType())->getElementType(),
+ ArgIndex.second));
+ ArgAttrVec.push_back(AttributeSet());
+ assert(Params.back());
}
+
+ if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty())
+ ++NumArgumentsPromoted;
+ else
+ ++NumAggregatesPromoted;
}
- Changed |= LocalChange; // Remember that we changed something.
- } while (LocalChange);
-
- return Changed;
-}
+ }
-bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
- if (skipSCC(SCC))
- return false;
+ Type *RetTy = FTy->getReturnType();
- // Get the callgraph information that we need to update to reflect our
- // changes.
- CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
+ // Construct the new function type using the new arguments.
+ FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());
- // We compute dedicated AA results for each function in the SCC as needed. We
- // use a lambda referencing external objects so that they live long enough to
- // be queried, but we re-use them each time.
- Optional<BasicAAResult> BAR;
- Optional<AAResults> AAR;
- auto AARGetter = [&](Function &F) -> AAResults & {
- BAR.emplace(createLegacyPMBasicAAResult(*this, F));
- AAR.emplace(createLegacyPMAAResults(*this, F, *BAR));
- return *AAR;
- };
-
- return runImpl(SCC, CG, AARGetter, maxElements);
-}
+ // Create the new function body and insert it into the module.
+ Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());
+ NF->copyAttributesFrom(F);
-/// \brief Checks if a type could have padding bytes.
-static bool isDenselyPacked(Type *type, const DataLayout &DL) {
+ // Patch the pointer to LLVM function in debug info descriptor.
+ NF->setSubprogram(F->getSubprogram());
+ F->setSubprogram(nullptr);
- // There is no size information, so be conservative.
- if (!type->isSized())
- return false;
+ DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n"
+ << "From: " << *F);
- // If the alloc size is not equal to the storage size, then there are padding
- // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128.
- if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type))
- return false;
+ // Recompute the parameter attributes list based on the new arguments for
+ // the function.
+ NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttributes(),
+ PAL.getRetAttributes(), ArgAttrVec));
+ ArgAttrVec.clear();
- if (!isa<CompositeType>(type))
- return true;
+ F->getParent()->getFunctionList().insert(F->getIterator(), NF);
+ NF->takeName(F);
- // For homogenous sequential types, check for padding within members.
- if (SequentialType *seqTy = dyn_cast<SequentialType>(type))
- return isDenselyPacked(seqTy->getElementType(), DL);
+ // Loop over all of the callers of the function, transforming the call sites
+ // to pass in the loaded pointers.
+ //
+ SmallVector<Value *, 16> Args;
+ while (!F->use_empty()) {
+ CallSite CS(F->user_back());
+ assert(CS.getCalledFunction() == F);
+ Instruction *Call = CS.getInstruction();
+ const AttributeList &CallPAL = CS.getAttributes();
- // Check for padding within and between elements of a struct.
- StructType *StructTy = cast<StructType>(type);
- const StructLayout *Layout = DL.getStructLayout(StructTy);
- uint64_t StartPos = 0;
- for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) {
- Type *ElTy = StructTy->getElementType(i);
- if (!isDenselyPacked(ElTy, DL))
- return false;
- if (StartPos != Layout->getElementOffsetInBits(i))
- return false;
- StartPos += DL.getTypeAllocSizeInBits(ElTy);
- }
+ // Loop over the operands, inserting GEP and loads in the caller as
+ // appropriate.
+ CallSite::arg_iterator AI = CS.arg_begin();
+ ArgIndex = 1;
+ for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
+ ++I, ++AI, ++ArgIndex)
+ if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
+ Args.push_back(*AI); // Unmodified argument
+ ArgAttrVec.push_back(CallPAL.getAttributes(ArgIndex));
+ } else if (ByValArgsToTransform.count(&*I)) {
+ // Emit a GEP and load for each element of the struct.
+ Type *AgTy = cast<PointerType>(I->getType())->getElementType();
+ StructType *STy = cast<StructType>(AgTy);
+ Value *Idxs[2] = {
+ ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr};
+ for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+ Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
+ Value *Idx = GetElementPtrInst::Create(
+ STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i), Call);
+ // TODO: Tell AA about the new values?
+ Args.push_back(new LoadInst(Idx, Idx->getName() + ".val", Call));
+ ArgAttrVec.push_back(AttributeSet());
+ }
+ } else if (!I->use_empty()) {
+ // Non-dead argument: insert GEPs and loads as appropriate.
+ ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
+ // Store the Value* version of the indices in here, but declare it now
+ // for reuse.
+ std::vector<Value *> Ops;
+ for (const auto &ArgIndex : ArgIndices) {
+ Value *V = *AI;
+ LoadInst *OrigLoad =
+ OriginalLoads[std::make_pair(&*I, ArgIndex.second)];
+ if (!ArgIndex.second.empty()) {
+ Ops.reserve(ArgIndex.second.size());
+ Type *ElTy = V->getType();
+ for (unsigned long II : ArgIndex.second) {
+ // Use i32 to index structs, and i64 for others (pointers/arrays).
+ // This satisfies GEP constraints.
+ Type *IdxTy =
+ (ElTy->isStructTy() ? Type::getInt32Ty(F->getContext())
+ : Type::getInt64Ty(F->getContext()));
+ Ops.push_back(ConstantInt::get(IdxTy, II));
+ // Keep track of the type we're currently indexing.
+ if (auto *ElPTy = dyn_cast<PointerType>(ElTy))
+ ElTy = ElPTy->getElementType();
+ else
+ ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II);
+ }
+ // And create a GEP to extract those indices.
+ V = GetElementPtrInst::Create(ArgIndex.first, V, Ops,
+ V->getName() + ".idx", Call);
+ Ops.clear();
+ }
+ // Since we're replacing a load make sure we take the alignment
+ // of the previous load.
+ LoadInst *newLoad = new LoadInst(V, V->getName() + ".val", Call);
+ newLoad->setAlignment(OrigLoad->getAlignment());
+ // Transfer the AA info too.
+ AAMDNodes AAInfo;
+ OrigLoad->getAAMetadata(AAInfo);
+ newLoad->setAAMetadata(AAInfo);
- return true;
-}
+ Args.push_back(newLoad);
+ ArgAttrVec.push_back(AttributeSet());
+ }
+ }
-/// \brief Checks if the padding bytes of an argument could be accessed.
-static bool canPaddingBeAccessed(Argument *arg) {
+ // Push any varargs arguments on the list.
+ for (; AI != CS.arg_end(); ++AI, ++ArgIndex) {
+ Args.push_back(*AI);
+ ArgAttrVec.push_back(CallPAL.getAttributes(ArgIndex));
+ }
- assert(arg->hasByValAttr());
+ SmallVector<OperandBundleDef, 1> OpBundles;
+ CS.getOperandBundlesAsDefs(OpBundles);
- // Track all the pointers to the argument to make sure they are not captured.
- SmallPtrSet<Value *, 16> PtrValues;
- PtrValues.insert(arg);
+ CallSite NewCS;
+ if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
+ NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
+ Args, OpBundles, "", Call);
+ } else {
+ auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call);
+ NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());
+ NewCS = NewCall;
+ }
+ NewCS.setCallingConv(CS.getCallingConv());
+ NewCS.setAttributes(
+ AttributeList::get(F->getContext(), CallPAL.getFnAttributes(),
+ CallPAL.getRetAttributes(), ArgAttrVec));
+ NewCS->setDebugLoc(Call->getDebugLoc());
+ uint64_t W;
+ if (Call->extractProfTotalWeight(W))
+ NewCS->setProfWeight(W);
+ Args.clear();
+ ArgAttrVec.clear();
- // Track all of the stores.
- SmallVector<StoreInst *, 16> Stores;
+ // Update the callgraph to know that the callsite has been transformed.
+ if (ReplaceCallSite)
+ (*ReplaceCallSite)(CS, NewCS);
- // Scan through the uses recursively to make sure the pointer is always used
- // sanely.
- SmallVector<Value *, 16> WorkList;
- WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end());
- while (!WorkList.empty()) {
- Value *V = WorkList.back();
- WorkList.pop_back();
- if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) {
- if (PtrValues.insert(V).second)
- WorkList.insert(WorkList.end(), V->user_begin(), V->user_end());
- } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
- Stores.push_back(Store);
- } else if (!isa<LoadInst>(V)) {
- return true;
+ if (!Call->use_empty()) {
+ Call->replaceAllUsesWith(NewCS.getInstruction());
+ NewCS->takeName(Call);
}
- }
-// Check to make sure the pointers aren't captured
- for (StoreInst *Store : Stores)
- if (PtrValues.count(Store->getValueOperand()))
- return true;
-
- return false;
-}
+ // Finally, remove the old call from the program, reducing the use-count of
+ // F.
+ Call->eraseFromParent();
+ }
-/// PromoteArguments - This method checks the specified function to see if there
-/// are any promotable arguments and if it is safe to promote the function (for
-/// example, all callers are direct). If safe to promote some arguments, it
-/// calls the DoPromotion method.
-///
-static CallGraphNode *
-PromoteArguments(CallGraphNode *CGN, CallGraph &CG,
- function_ref<AAResults &(Function &F)> AARGetter,
- unsigned MaxElements) {
- Function *F = CGN->getFunction();
+ const DataLayout &DL = F->getParent()->getDataLayout();
- // Make sure that it is local to this module.
- if (!F || !F->hasLocalLinkage()) return nullptr;
+ // Since we have now created the new function, splice the body of the old
+ // function right into the new function, leaving the old rotting hulk of the
+ // function empty.
+ NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
- // Don't promote arguments for variadic functions. Adding, removing, or
- // changing non-pack parameters can change the classification of pack
- // parameters. Frontends encode that classification at the call site in the
- // IR, while in the callee the classification is determined dynamically based
- // on the number of registers consumed so far.
- if (F->isVarArg()) return nullptr;
+ // Loop over the argument list, transferring uses of the old arguments over to
+ // the new arguments, also transferring over the names as well.
+ //
+ for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(),
+ I2 = NF->arg_begin();
+ I != E; ++I) {
+ if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
+ // If this is an unmodified argument, move the name and users over to the
+ // new version.
+ I->replaceAllUsesWith(&*I2);
+ I2->takeName(&*I);
+ ++I2;
+ continue;
+ }
- // First check: see if there are any pointer arguments! If not, quick exit.
- SmallVector<Argument*, 16> PointerArgs;
- for (Argument &I : F->args())
- if (I.getType()->isPointerTy())
- PointerArgs.push_back(&I);
- if (PointerArgs.empty()) return nullptr;
+ if (ByValArgsToTransform.count(&*I)) {
+ // In the callee, we create an alloca, and store each of the new incoming
+ // arguments into the alloca.
+ Instruction *InsertPt = &NF->begin()->front();
- // Second check: make sure that all callers are direct callers. We can't
- // transform functions that have indirect callers. Also see if the function
- // is self-recursive.
- bool isSelfRecursive = false;
- for (Use &U : F->uses()) {
- CallSite CS(U.getUser());
- // Must be a direct call.
- if (CS.getInstruction() == nullptr || !CS.isCallee(&U)) return nullptr;
-
- if (CS.getInstruction()->getParent()->getParent() == F)
- isSelfRecursive = true;
- }
-
- const DataLayout &DL = F->getParent()->getDataLayout();
+ // Just add all the struct element types.
+ Type *AgTy = cast<PointerType>(I->getType())->getElementType();
+ Value *TheAlloca = new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr,
+ "", InsertPt);
+ StructType *STy = cast<StructType>(AgTy);
+ Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0),
+ nullptr};
- AAResults &AAR = AARGetter(*F);
+ for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+ Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
+ Value *Idx = GetElementPtrInst::Create(
+ AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i),
+ InsertPt);
+ I2->setName(I->getName() + "." + Twine(i));
+ new StoreInst(&*I2++, Idx, InsertPt);
+ }
- // Check to see which arguments are promotable. If an argument is promotable,
- // add it to ArgsToPromote.
- SmallPtrSet<Argument*, 8> ArgsToPromote;
- SmallPtrSet<Argument*, 8> ByValArgsToTransform;
- for (Argument *PtrArg : PointerArgs) {
- Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType();
+ // Anything that used the arg should now use the alloca.
+ I->replaceAllUsesWith(TheAlloca);
+ TheAlloca->takeName(&*I);
- // Replace sret attribute with noalias. This reduces register pressure by
- // avoiding a register copy.
- if (PtrArg->hasStructRetAttr()) {
- unsigned ArgNo = PtrArg->getArgNo();
- F->setAttributes(
- F->getAttributes()
- .removeAttribute(F->getContext(), ArgNo + 1, Attribute::StructRet)
- .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias));
- for (Use &U : F->uses()) {
- CallSite CS(U.getUser());
- CS.setAttributes(
- CS.getAttributes()
- .removeAttribute(F->getContext(), ArgNo + 1,
- Attribute::StructRet)
- .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias));
+ // If the alloca is used in a call, we must clear the tail flag since
+ // the callee now uses an alloca from the caller.
+ for (User *U : TheAlloca->users()) {
+ CallInst *Call = dyn_cast<CallInst>(U);
+ if (!Call)
+ continue;
+ Call->setTailCall(false);
}
+ continue;
}
- // If this is a byval argument, and if the aggregate type is small, just
- // pass the elements, which is always safe, if the passed value is densely
- // packed or if we can prove the padding bytes are never accessed. This does
- // not apply to inalloca.
- bool isSafeToPromote =
- PtrArg->hasByValAttr() &&
- (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg));
- if (isSafeToPromote) {
- if (StructType *STy = dyn_cast<StructType>(AgTy)) {
- if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
- DEBUG(dbgs() << "argpromotion disable promoting argument '"
- << PtrArg->getName() << "' because it would require adding more"
- << " than " << MaxElements << " arguments to the function.\n");
- continue;
- }
-
- // If all the elements are single-value types, we can promote it.
- bool AllSimple = true;
- for (const auto *EltTy : STy->elements()) {
- if (!EltTy->isSingleValueType()) {
- AllSimple = false;
- break;
- }
+ if (I->use_empty())
+ continue;
+
+ // Otherwise, if we promoted this argument, then all users are load
+ // instructions (or GEPs with only load users), and all loads should be
+ // using the new argument that we added.
+ ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
+
+ while (!I->use_empty()) {
+ if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) {
+ assert(ArgIndices.begin()->second.empty() &&
+ "Load element should sort to front!");
+ I2->setName(I->getName() + ".val");
+ LI->replaceAllUsesWith(&*I2);
+ LI->eraseFromParent();
+ DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName()
+ << "' in function '" << F->getName() << "'\n");
+ } else {
+ GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back());
+ IndicesVector Operands;
+ Operands.reserve(GEP->getNumIndices());
+ for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end();
+ II != IE; ++II)
+ Operands.push_back(cast<ConstantInt>(*II)->getSExtValue());
+
+ // GEPs with a single 0 index can be merged with direct loads
+ if (Operands.size() == 1 && Operands.front() == 0)
+ Operands.clear();
+
+ Function::arg_iterator TheArg = I2;
+ for (ScalarizeTable::iterator It = ArgIndices.begin();
+ It->second != Operands; ++It, ++TheArg) {
+ assert(It != ArgIndices.end() && "GEP not handled??");
}
- // Safe to transform, don't even bother trying to "promote" it.
- // Passing the elements as a scalar will allow sroa to hack on
- // the new alloca we introduce.
- if (AllSimple) {
- ByValArgsToTransform.insert(PtrArg);
- continue;
+ std::string NewName = I->getName();
+ for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
+ NewName += "." + utostr(Operands[i]);
}
- }
- }
+ NewName += ".val";
+ TheArg->setName(NewName);
- // If the argument is a recursive type and we're in a recursive
- // function, we could end up infinitely peeling the function argument.
- if (isSelfRecursive) {
- if (StructType *STy = dyn_cast<StructType>(AgTy)) {
- bool RecursiveType = false;
- for (const auto *EltTy : STy->elements()) {
- if (EltTy == PtrArg->getType()) {
- RecursiveType = true;
- break;
- }
+ DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName()
+ << "' of function '" << NF->getName() << "'\n");
+
+ // All of the uses must be load instructions. Replace them all with
+ // the argument specified by ArgNo.
+ while (!GEP->use_empty()) {
+ LoadInst *L = cast<LoadInst>(GEP->user_back());
+ L->replaceAllUsesWith(&*TheArg);
+ L->eraseFromParent();
}
- if (RecursiveType)
- continue;
+ GEP->eraseFromParent();
}
}
-
- // Otherwise, see if we can promote the pointer to its value.
- if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR,
- MaxElements))
- ArgsToPromote.insert(PtrArg);
- }
- // No promotable pointer arguments.
- if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
- return nullptr;
+ // Increment I2 past all of the arguments added for this promoted pointer.
+ std::advance(I2, ArgIndices.size());
+ }
- return DoPromotion(F, ArgsToPromote, ByValArgsToTransform, CG);
+ return NF;
}
/// AllCallersPassInValidPointerForArgument - Return true if we can prove that
/// all callees pass in a valid pointer for the specified function argument.
-static bool AllCallersPassInValidPointerForArgument(Argument *Arg) {
+static bool allCallersPassInValidPointerForArgument(Argument *Arg) {
Function *Callee = Arg->getParent();
const DataLayout &DL = Callee->getParent()->getDataLayout();
@@ -390,26 +469,25 @@ static bool AllCallersPassInValidPointerForArgument(Argument *Arg) {
/// elements in Prefix is the same as the corresponding elements in Longer.
///
/// This means it also returns true when Prefix and Longer are equal!
-static bool IsPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) {
+static bool isPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) {
if (Prefix.size() > Longer.size())
return false;
return std::equal(Prefix.begin(), Prefix.end(), Longer.begin());
}
-
/// Checks if Indices, or a prefix of Indices, is in Set.
-static bool PrefixIn(const IndicesVector &Indices,
+static bool prefixIn(const IndicesVector &Indices,
std::set<IndicesVector> &Set) {
- std::set<IndicesVector>::iterator Low;
- Low = Set.upper_bound(Indices);
- if (Low != Set.begin())
- Low--;
- // Low is now the last element smaller than or equal to Indices. This means
- // it points to a prefix of Indices (possibly Indices itself), if such
- // prefix exists.
- //
- // This load is safe if any prefix of its operands is safe to load.
- return Low != Set.end() && IsPrefix(*Low, Indices);
+ std::set<IndicesVector>::iterator Low;
+ Low = Set.upper_bound(Indices);
+ if (Low != Set.begin())
+ Low--;
+ // Low is now the last element smaller than or equal to Indices. This means
+ // it points to a prefix of Indices (possibly Indices itself), if such
+ // prefix exists.
+ //
+ // This load is safe if any prefix of its operands is safe to load.
+ return Low != Set.end() && isPrefix(*Low, Indices);
}
/// Mark the given indices (ToMark) as safe in the given set of indices
@@ -417,7 +495,7 @@ static bool PrefixIn(const IndicesVector &Indices,
/// is already a prefix of Indices in Safe, Indices are implicitely marked safe
/// already. Furthermore, any indices that Indices is itself a prefix of, are
/// removed from Safe (since they are implicitely safe because of Indices now).
-static void MarkIndicesSafe(const IndicesVector &ToMark,
+static void markIndicesSafe(const IndicesVector &ToMark,
std::set<IndicesVector> &Safe) {
std::set<IndicesVector>::iterator Low;
Low = Safe.upper_bound(ToMark);
@@ -428,7 +506,7 @@ static void MarkIndicesSafe(const IndicesVector &ToMark,
// means it points to a prefix of Indices (possibly Indices itself), if
// such prefix exists.
if (Low != Safe.end()) {
- if (IsPrefix(*Low, ToMark))
+ if (isPrefix(*Low, ToMark))
// If there is already a prefix of these indices (or exactly these
// indices) marked a safe, don't bother adding these indices
return;
@@ -441,7 +519,7 @@ static void MarkIndicesSafe(const IndicesVector &ToMark,
++Low;
// If there we're a prefix of longer index list(s), remove those
std::set<IndicesVector>::iterator End = Safe.end();
- while (Low != End && IsPrefix(ToMark, *Low)) {
+ while (Low != End && isPrefix(ToMark, *Low)) {
std::set<IndicesVector>::iterator Remove = Low;
++Low;
Safe.erase(Remove);
@@ -486,7 +564,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
GEPIndicesSet ToPromote;
// If the pointer is always valid, any load with first index 0 is valid.
- if (isByValOrInAlloca || AllCallersPassInValidPointerForArgument(Arg))
+ if (isByValOrInAlloca || allCallersPassInValidPointerForArgument(Arg))
SafeToUnconditionallyLoad.insert(IndicesVector(1, 0));
// First, iterate the entry block and mark loads of (geps of) arguments as
@@ -512,25 +590,26 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
return false;
// Indices checked out, mark them as safe
- MarkIndicesSafe(Indices, SafeToUnconditionallyLoad);
+ markIndicesSafe(Indices, SafeToUnconditionallyLoad);
Indices.clear();
}
} else if (V == Arg) {
// Direct loads are equivalent to a GEP with a single 0 index.
- MarkIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad);
+ markIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad);
}
}
// Now, iterate all uses of the argument to see if there are any uses that are
// not (GEP+)loads, or any (GEP+)loads that are not safe to promote.
- SmallVector<LoadInst*, 16> Loads;
+ SmallVector<LoadInst *, 16> Loads;
IndicesVector Operands;
for (Use &U : Arg->uses()) {
User *UR = U.getUser();
Operands.clear();
if (LoadInst *LI = dyn_cast<LoadInst>(UR)) {
// Don't hack volatile/atomic loads
- if (!LI->isSimple()) return false;
+ if (!LI->isSimple())
+ return false;
Loads.push_back(LI);
// Direct loads are equivalent to a GEP with a zero index and then a load.
Operands.push_back(0);
@@ -547,30 +626,31 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
}
// Ensure that all of the indices are constants.
- for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end();
- i != e; ++i)
+ for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); i != e;
+ ++i)
if (ConstantInt *C = dyn_cast<ConstantInt>(*i))
Operands.push_back(C->getSExtValue());
else
- return false; // Not a constant operand GEP!
+ return false; // Not a constant operand GEP!
// Ensure that the only users of the GEP are load instructions.
for (User *GEPU : GEP->users())
if (LoadInst *LI = dyn_cast<LoadInst>(GEPU)) {
// Don't hack volatile/atomic loads
- if (!LI->isSimple()) return false;
+ if (!LI->isSimple())
+ return false;
Loads.push_back(LI);
} else {
// Other uses than load?
return false;
}
} else {
- return false; // Not a load or a GEP.
+ return false; // Not a load or a GEP.
}
// Now, see if it is safe to promote this load / loads of this GEP. Loading
// is safe if Operands, or a prefix of Operands, is marked as safe.
- if (!PrefixIn(Operands, SafeToUnconditionallyLoad))
+ if (!prefixIn(Operands, SafeToUnconditionallyLoad))
return false;
// See if we are already promoting a load with these indices. If not, check
@@ -579,8 +659,10 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
if (ToPromote.find(Operands) == ToPromote.end()) {
if (MaxElements > 0 && ToPromote.size() == MaxElements) {
DEBUG(dbgs() << "argpromotion not promoting argument '"
- << Arg->getName() << "' because it would require adding more "
- << "than " << MaxElements << " arguments to the function.\n");
+ << Arg->getName()
+ << "' because it would require adding more "
+ << "than " << MaxElements
+ << " arguments to the function.\n");
// We limit aggregate promotion to only promoting up to a fixed number
// of elements of the aggregate.
return false;
@@ -589,7 +671,8 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
}
}
- if (Loads.empty()) return true; // No users, this is a dead argument.
+ if (Loads.empty())
+ return true; // No users, this is a dead argument.
// Okay, now we know that the argument is only used by load instructions and
// it is safe to unconditionally perform all of them. Use alias analysis to
@@ -598,7 +681,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
// Because there could be several/many load instructions, remember which
// blocks we know to be transparent to the load.
- df_iterator_default_set<BasicBlock*, 16> TranspBlocks;
+ df_iterator_default_set<BasicBlock *, 16> TranspBlocks;
for (LoadInst *Load : Loads) {
// Check to see if the load is invalidated from the start of the block to
@@ -607,7 +690,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
MemoryLocation Loc = MemoryLocation::get(Load);
if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, MRI_Mod))
- return false; // Pointer is invalidated!
+ return false; // Pointer is invalidated!
// Now check every path from the entry block to the load for transparency.
// To do this, we perform a depth first search on the inverse CFG from the
@@ -625,416 +708,352 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
return true;
}
-/// DoPromotion - This method actually performs the promotion of the specified
-/// arguments, and returns the new function. At this point, we know that it's
-/// safe to do so.
-static CallGraphNode *
-DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote,
- SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG) {
+/// \brief Checks if a type could have padding bytes.
+static bool isDenselyPacked(Type *type, const DataLayout &DL) {
- // Start by computing a new prototype for the function, which is the same as
- // the old function, but has modified arguments.
- FunctionType *FTy = F->getFunctionType();
- std::vector<Type*> Params;
+ // There is no size information, so be conservative.
+ if (!type->isSized())
+ return false;
- typedef std::set<std::pair<Type *, IndicesVector>> ScalarizeTable;
+ // If the alloc size is not equal to the storage size, then there are padding
+ // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128.
+ if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type))
+ return false;
- // ScalarizedElements - If we are promoting a pointer that has elements
- // accessed out of it, keep track of which elements are accessed so that we
- // can add one argument for each.
- //
- // Arguments that are directly loaded will have a zero element value here, to
- // handle cases where there are both a direct load and GEP accesses.
- //
- std::map<Argument*, ScalarizeTable> ScalarizedElements;
+ if (!isa<CompositeType>(type))
+ return true;
- // OriginalLoads - Keep track of a representative load instruction from the
- // original function so that we can tell the alias analysis implementation
- // what the new GEP/Load instructions we are inserting look like.
- // We need to keep the original loads for each argument and the elements
- // of the argument that are accessed.
- std::map<std::pair<Argument*, IndicesVector>, LoadInst*> OriginalLoads;
+ // For homogenous sequential types, check for padding within members.
+ if (SequentialType *seqTy = dyn_cast<SequentialType>(type))
+ return isDenselyPacked(seqTy->getElementType(), DL);
- // Attribute - Keep track of the parameter attributes for the arguments
- // that we are *not* promoting. For the ones that we do promote, the parameter
- // attributes are lost
- SmallVector<AttributeSet, 8> AttributesVec;
- const AttributeSet &PAL = F->getAttributes();
+ // Check for padding within and between elements of a struct.
+ StructType *StructTy = cast<StructType>(type);
+ const StructLayout *Layout = DL.getStructLayout(StructTy);
+ uint64_t StartPos = 0;
+ for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) {
+ Type *ElTy = StructTy->getElementType(i);
+ if (!isDenselyPacked(ElTy, DL))
+ return false;
+ if (StartPos != Layout->getElementOffsetInBits(i))
+ return false;
+ StartPos += DL.getTypeAllocSizeInBits(ElTy);
+ }
- // Add any return attributes.
- if (PAL.hasAttributes(AttributeSet::ReturnIndex))
- AttributesVec.push_back(AttributeSet::get(F->getContext(),
- PAL.getRetAttributes()));
+ return true;
+}
- // First, determine the new argument list
- unsigned ArgIndex = 1;
- for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
- ++I, ++ArgIndex) {
- if (ByValArgsToTransform.count(&*I)) {
- // Simple byval argument? Just add all the struct element types.
- Type *AgTy = cast<PointerType>(I->getType())->getElementType();
- StructType *STy = cast<StructType>(AgTy);
- Params.insert(Params.end(), STy->element_begin(), STy->element_end());
- ++NumByValArgsPromoted;
- } else if (!ArgsToPromote.count(&*I)) {
- // Unchanged argument
- Params.push_back(I->getType());
- AttributeSet attrs = PAL.getParamAttributes(ArgIndex);
- if (attrs.hasAttributes(ArgIndex)) {
- AttrBuilder B(attrs, ArgIndex);
- AttributesVec.
- push_back(AttributeSet::get(F->getContext(), Params.size(), B));
- }
- } else if (I->use_empty()) {
- // Dead argument (which are always marked as promotable)
- ++NumArgumentsDead;
- } else {
- // Okay, this is being promoted. This means that the only uses are loads
- // or GEPs which are only used by loads
+/// \brief Checks if the padding bytes of an argument could be accessed.
+static bool canPaddingBeAccessed(Argument *arg) {
- // In this table, we will track which indices are loaded from the argument
- // (where direct loads are tracked as no indices).
- ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
- for (User *U : I->users()) {
- Instruction *UI = cast<Instruction>(U);
- Type *SrcTy;
- if (LoadInst *L = dyn_cast<LoadInst>(UI))
- SrcTy = L->getType();
- else
- SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType();
- IndicesVector Indices;
- Indices.reserve(UI->getNumOperands() - 1);
- // Since loads will only have a single operand, and GEPs only a single
- // non-index operand, this will record direct loads without any indices,
- // and gep+loads with the GEP indices.
- for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end();
- II != IE; ++II)
- Indices.push_back(cast<ConstantInt>(*II)->getSExtValue());
- // GEPs with a single 0 index can be merged with direct loads
- if (Indices.size() == 1 && Indices.front() == 0)
- Indices.clear();
- ArgIndices.insert(std::make_pair(SrcTy, Indices));
- LoadInst *OrigLoad;
- if (LoadInst *L = dyn_cast<LoadInst>(UI))
- OrigLoad = L;
- else
- // Take any load, we will use it only to update Alias Analysis
- OrigLoad = cast<LoadInst>(UI->user_back());
- OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad;
- }
+ assert(arg->hasByValAttr());
- // Add a parameter to the function for each element passed in.
- for (const auto &ArgIndex : ArgIndices) {
- // not allowed to dereference ->begin() if size() is 0
- Params.push_back(GetElementPtrInst::getIndexedType(
- cast<PointerType>(I->getType()->getScalarType())->getElementType(),
- ArgIndex.second));
- assert(Params.back());
- }
+ // Track all the pointers to the argument to make sure they are not captured.
+ SmallPtrSet<Value *, 16> PtrValues;
+ PtrValues.insert(arg);
- if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty())
- ++NumArgumentsPromoted;
- else
- ++NumAggregatesPromoted;
+ // Track all of the stores.
+ SmallVector<StoreInst *, 16> Stores;
+
+ // Scan through the uses recursively to make sure the pointer is always used
+ // sanely.
+ SmallVector<Value *, 16> WorkList;
+ WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end());
+ while (!WorkList.empty()) {
+ Value *V = WorkList.back();
+ WorkList.pop_back();
+ if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) {
+ if (PtrValues.insert(V).second)
+ WorkList.insert(WorkList.end(), V->user_begin(), V->user_end());
+ } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
+ Stores.push_back(Store);
+ } else if (!isa<LoadInst>(V)) {
+ return true;
}
}
- // Add any function attributes.
- if (PAL.hasAttributes(AttributeSet::FunctionIndex))
- AttributesVec.push_back(AttributeSet::get(FTy->getContext(),
- PAL.getFnAttributes()));
+ // Check to make sure the pointers aren't captured
+ for (StoreInst *Store : Stores)
+ if (PtrValues.count(Store->getValueOperand()))
+ return true;
- Type *RetTy = FTy->getReturnType();
+ return false;
+}
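
A rough source-level picture of what canPaddingBeAccessed guards against
(illustrative only; the struct layout and the byval lowering on x86-64 are
assumptions): once the byval pointer itself is stored somewhere, every byte of
the copy, padding included, becomes reachable, so the argument cannot be safely
split into its elements.

  // Hypothetical example: 7 padding bytes sit between 'c' and 'd' on common
  // 64-bit targets; the 32-byte aggregate is typically passed byval.
  struct Padded { char c; double d[3]; };

  char *Escaped; // storing the pointer makes every byte of 'P' reachable

  void callee(Padded P) {
    Escaped = reinterpret_cast<char *>(&P); // the pointer escapes via a store
  }
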
- // Construct the new function type using the new arguments.
- FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());
+/// promoteArguments - This method checks the specified function to see if
+/// there are any promotable arguments and if it is safe to promote them (for
+/// example, all callers must be direct). If it is safe to promote some
+/// arguments, it calls the doPromotion method to do the transformation.
+///
+static Function *
+promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
+ unsigned MaxElements,
+ Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>>
+ ReplaceCallSite) {
+ // Make sure that it is local to this module.
+ if (!F->hasLocalLinkage())
+ return nullptr;
- // Create the new function body and insert it into the module.
- Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());
- NF->copyAttributesFrom(F);
+ // Don't promote arguments for variadic functions. Adding, removing, or
+ // changing non-pack parameters can change the classification of pack
+ // parameters. Frontends encode that classification at the call site in the
+ // IR, while in the callee the classification is determined dynamically based
+ // on the number of registers consumed so far.
+ if (F->isVarArg())
+ return nullptr;
- // Patch the pointer to LLVM function in debug info descriptor.
- NF->setSubprogram(F->getSubprogram());
- F->setSubprogram(nullptr);
+ // First check: see if there are any pointer arguments! If not, quick exit.
+ SmallVector<Argument *, 16> PointerArgs;
+ for (Argument &I : F->args())
+ if (I.getType()->isPointerTy())
+ PointerArgs.push_back(&I);
+ if (PointerArgs.empty())
+ return nullptr;
- DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n"
- << "From: " << *F);
-
- // Recompute the parameter attributes list based on the new arguments for
- // the function.
- NF->setAttributes(AttributeSet::get(F->getContext(), AttributesVec));
- AttributesVec.clear();
+ // Second check: make sure that all callers are direct callers. We can't
+ // transform functions that have indirect callers. Also see if the function
+ // is self-recursive.
+ bool isSelfRecursive = false;
+ for (Use &U : F->uses()) {
+ CallSite CS(U.getUser());
+ // Must be a direct call.
+ if (CS.getInstruction() == nullptr || !CS.isCallee(&U))
+ return nullptr;
- F->getParent()->getFunctionList().insert(F->getIterator(), NF);
- NF->takeName(F);
+ if (CS.getInstruction()->getParent()->getParent() == F)
+ isSelfRecursive = true;
+ }
- // Get a new callgraph node for NF.
- CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);
+ const DataLayout &DL = F->getParent()->getDataLayout();
- // Loop over all of the callers of the function, transforming the call sites
- // to pass in the loaded pointers.
- //
- SmallVector<Value*, 16> Args;
- while (!F->use_empty()) {
- CallSite CS(F->user_back());
- assert(CS.getCalledFunction() == F);
- Instruction *Call = CS.getInstruction();
- const AttributeSet &CallPAL = CS.getAttributes();
+ AAResults &AAR = AARGetter(*F);
- // Add any return attributes.
- if (CallPAL.hasAttributes(AttributeSet::ReturnIndex))
- AttributesVec.push_back(AttributeSet::get(F->getContext(),
- CallPAL.getRetAttributes()));
+ // Check to see which arguments are promotable. If an argument is promotable,
+ // add it to ArgsToPromote.
+ SmallPtrSet<Argument *, 8> ArgsToPromote;
+ SmallPtrSet<Argument *, 8> ByValArgsToTransform;
+ for (Argument *PtrArg : PointerArgs) {
+ Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType();
- // Loop over the operands, inserting GEP and loads in the caller as
- // appropriate.
- CallSite::arg_iterator AI = CS.arg_begin();
- ArgIndex = 1;
- for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
- I != E; ++I, ++AI, ++ArgIndex)
- if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
- Args.push_back(*AI); // Unmodified argument
+ // Replace sret attribute with noalias. This reduces register pressure by
+ // avoiding a register copy.
+ if (PtrArg->hasStructRetAttr()) {
+ unsigned ArgNo = PtrArg->getArgNo();
+ F->setAttributes(
+ F->getAttributes()
+ .removeAttribute(F->getContext(), ArgNo + 1, Attribute::StructRet)
+ .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias));
+ for (Use &U : F->uses()) {
+ CallSite CS(U.getUser());
+ CS.setAttributes(
+ CS.getAttributes()
+ .removeAttribute(F->getContext(), ArgNo + 1,
+ Attribute::StructRet)
+ .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias));
+ }
+ }
- if (CallPAL.hasAttributes(ArgIndex)) {
- AttrBuilder B(CallPAL, ArgIndex);
- AttributesVec.
- push_back(AttributeSet::get(F->getContext(), Args.size(), B));
- }
- } else if (ByValArgsToTransform.count(&*I)) {
- // Emit a GEP and load for each element of the struct.
- Type *AgTy = cast<PointerType>(I->getType())->getElementType();
- StructType *STy = cast<StructType>(AgTy);
- Value *Idxs[2] = {
- ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr };
- for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
- Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
- Value *Idx = GetElementPtrInst::Create(
- STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i), Call);
- // TODO: Tell AA about the new values?
- Args.push_back(new LoadInst(Idx, Idx->getName()+".val", Call));
+ // If this is a byval argument and the aggregate type is small, just pass the
+ // elements directly. This is always safe if the passed value is densely
+ // packed or if we can prove the padding bytes are never accessed. This does
+ // not apply to inalloca.
+ bool isSafeToPromote =
+ PtrArg->hasByValAttr() &&
+ (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg));
+ if (isSafeToPromote) {
+ if (StructType *STy = dyn_cast<StructType>(AgTy)) {
+ if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
+ DEBUG(dbgs() << "argpromotion disable promoting argument '"
+ << PtrArg->getName()
+ << "' because it would require adding more"
+ << " than " << MaxElements
+ << " arguments to the function.\n");
+ continue;
}
- } else if (!I->use_empty()) {
- // Non-dead argument: insert GEPs and loads as appropriate.
- ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
- // Store the Value* version of the indices in here, but declare it now
- // for reuse.
- std::vector<Value*> Ops;
- for (const auto &ArgIndex : ArgIndices) {
- Value *V = *AI;
- LoadInst *OrigLoad =
- OriginalLoads[std::make_pair(&*I, ArgIndex.second)];
- if (!ArgIndex.second.empty()) {
- Ops.reserve(ArgIndex.second.size());
- Type *ElTy = V->getType();
- for (unsigned long II : ArgIndex.second) {
- // Use i32 to index structs, and i64 for others (pointers/arrays).
- // This satisfies GEP constraints.
- Type *IdxTy = (ElTy->isStructTy() ?
- Type::getInt32Ty(F->getContext()) :
- Type::getInt64Ty(F->getContext()));
- Ops.push_back(ConstantInt::get(IdxTy, II));
- // Keep track of the type we're currently indexing.
- if (auto *ElPTy = dyn_cast<PointerType>(ElTy))
- ElTy = ElPTy->getElementType();
- else
- ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II);
- }
- // And create a GEP to extract those indices.
- V = GetElementPtrInst::Create(ArgIndex.first, V, Ops,
- V->getName() + ".idx", Call);
- Ops.clear();
+
+ // If all the elements are single-value types, we can promote it.
+ bool AllSimple = true;
+ for (const auto *EltTy : STy->elements()) {
+ if (!EltTy->isSingleValueType()) {
+ AllSimple = false;
+ break;
}
- // Since we're replacing a load make sure we take the alignment
- // of the previous load.
- LoadInst *newLoad = new LoadInst(V, V->getName()+".val", Call);
- newLoad->setAlignment(OrigLoad->getAlignment());
- // Transfer the AA info too.
- AAMDNodes AAInfo;
- OrigLoad->getAAMetadata(AAInfo);
- newLoad->setAAMetadata(AAInfo);
+ }
- Args.push_back(newLoad);
+ // Safe to transform; don't even bother trying to "promote" it.
+ // Passing the elements as scalars will allow SROA to hack on
+ // the new alloca we introduce.
+ if (AllSimple) {
+ ByValArgsToTransform.insert(PtrArg);
+ continue;
}
}
+ }
- // Push any varargs arguments on the list.
- for (; AI != CS.arg_end(); ++AI, ++ArgIndex) {
- Args.push_back(*AI);
- if (CallPAL.hasAttributes(ArgIndex)) {
- AttrBuilder B(CallPAL, ArgIndex);
- AttributesVec.
- push_back(AttributeSet::get(F->getContext(), Args.size(), B));
+ // If the argument is a recursive type and we're in a recursive
+ // function, we could end up infinitely peeling the function argument.
+ if (isSelfRecursive) {
+ if (StructType *STy = dyn_cast<StructType>(AgTy)) {
+ bool RecursiveType = false;
+ for (const auto *EltTy : STy->elements()) {
+ if (EltTy == PtrArg->getType()) {
+ RecursiveType = true;
+ break;
+ }
+ }
+ if (RecursiveType)
+ continue;
}
}
- // Add any function attributes.
- if (CallPAL.hasAttributes(AttributeSet::FunctionIndex))
- AttributesVec.push_back(AttributeSet::get(Call->getContext(),
- CallPAL.getFnAttributes()));
+ // Otherwise, see if we can promote the pointer to its value.
+ if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR,
+ MaxElements))
+ ArgsToPromote.insert(PtrArg);
+ }
+
+ // No promotable pointer arguments.
+ if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
+ return nullptr;
- SmallVector<OperandBundleDef, 1> OpBundles;
- CS.getOperandBundlesAsDefs(OpBundles);
+ return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite);
+}
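
A conceptual before/after of the rewrite this function gates, in source terms
rather than the literal IR the pass produces (illustrative sketch only): a
pointer argument whose only use is a load becomes a plain scalar parameter, and
callers load the value themselves.

  // Before: 'p' is only loaded, so it is promotable.
  static int before(const int *p) { return *p + 1; }

  // After (roughly what doPromotion arranges at the IR level).
  static int after(int p_val) { return p_val + 1; }
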
- Instruction *New;
- if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
- New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
- Args, OpBundles, "", Call);
- cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
- cast<InvokeInst>(New)->setAttributes(AttributeSet::get(II->getContext(),
- AttributesVec));
- } else {
- New = CallInst::Create(NF, Args, OpBundles, "", Call);
- cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
- cast<CallInst>(New)->setAttributes(AttributeSet::get(New->getContext(),
- AttributesVec));
- cast<CallInst>(New)->setTailCallKind(
- cast<CallInst>(Call)->getTailCallKind());
- }
- New->setDebugLoc(Call->getDebugLoc());
- Args.clear();
- AttributesVec.clear();
+PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
+ CGSCCAnalysisManager &AM,
+ LazyCallGraph &CG,
+ CGSCCUpdateResult &UR) {
+ bool Changed = false, LocalChange;
- // Update the callgraph to know that the callsite has been transformed.
- CallGraphNode *CalleeNode = CG[Call->getParent()->getParent()];
- CalleeNode->replaceCallEdge(CS, CallSite(New), NF_CGN);
+ // Iterate until we stop promoting from this SCC.
+ do {
+ LocalChange = false;
- if (!Call->use_empty()) {
- Call->replaceAllUsesWith(New);
- New->takeName(Call);
+ for (LazyCallGraph::Node &N : C) {
+ Function &OldF = N.getFunction();
+
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
+ // FIXME: This lambda must only be used with this function. We should
+ // skip the lambda and just get the AA results directly.
+ auto AARGetter = [&](Function &F) -> AAResults & {
+ assert(&F == &OldF && "Called with an unexpected function!");
+ return FAM.getResult<AAManager>(F);
+ };
+
+ Function *NewF = promoteArguments(&OldF, AARGetter, 3u, None);
+ if (!NewF)
+ continue;
+ LocalChange = true;
+
+ // Directly substitute the functions in the call graph. Note that this
+ // requires the old function to be completely dead and completely
+ // replaced by the new function. It does no call graph updates; it merely
+ // swaps out the particular function mapped to a particular node in the
+ // graph.
+ C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
+ OldF.eraseFromParent();
}
- // Finally, remove the old call from the program, reducing the use-count of
- // F.
- Call->eraseFromParent();
- }
-
- // Since we have now created the new function, splice the body of the old
- // function right into the new function, leaving the old rotting hulk of the
- // function empty.
- NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
+ Changed |= LocalChange;
+ } while (LocalChange);
- // Loop over the argument list, transferring uses of the old arguments over to
- // the new arguments, also transferring over the names as well.
- //
- for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(),
- I2 = NF->arg_begin(); I != E; ++I) {
- if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
- // If this is an unmodified argument, move the name and users over to the
- // new version.
- I->replaceAllUsesWith(&*I2);
- I2->takeName(&*I);
- ++I2;
- continue;
- }
+ if (!Changed)
+ return PreservedAnalyses::all();
- if (ByValArgsToTransform.count(&*I)) {
- // In the callee, we create an alloca, and store each of the new incoming
- // arguments into the alloca.
- Instruction *InsertPt = &NF->begin()->front();
+ return PreservedAnalyses::none();
+}
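
A hypothetical driver for the new-pass-manager entry point above, sketched
under the assumption that the usual PassBuilder registration APIs and header
locations of this revision apply (not part of the patch):

  #include "llvm/Analysis/CGSCCPassManager.h"
  #include "llvm/Analysis/LoopAnalysisManager.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PassManager.h"
  #include "llvm/Passes/PassBuilder.h"
  #include "llvm/Transforms/IPO/ArgumentPromotion.h"
  using namespace llvm;

  void runArgPromotion(Module &M) {
    PassBuilder PB;
    LoopAnalysisManager LAM;
    FunctionAnalysisManager FAM;
    CGSCCAnalysisManager CGAM;
    ModuleAnalysisManager MAM;
    PB.registerModuleAnalyses(MAM);
    PB.registerCGSCCAnalyses(CGAM);
    PB.registerFunctionAnalyses(FAM);
    PB.registerLoopAnalyses(LAM);
    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

    // Run ArgumentPromotionPass over the module's SCCs in post order.
    CGSCCPassManager CGPM;
    CGPM.addPass(ArgumentPromotionPass());
    ModulePassManager MPM;
    MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
    MPM.run(M, MAM);
  }
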
- // Just add all the struct element types.
- Type *AgTy = cast<PointerType>(I->getType())->getElementType();
- Value *TheAlloca = new AllocaInst(AgTy, nullptr, "", InsertPt);
- StructType *STy = cast<StructType>(AgTy);
- Value *Idxs[2] = {
- ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr };
+namespace {
+/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass.
+///
+struct ArgPromotion : public CallGraphSCCPass {
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ getAAResultsAnalysisUsage(AU);
+ CallGraphSCCPass::getAnalysisUsage(AU);
+ }
- for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
- Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
- Value *Idx = GetElementPtrInst::Create(
- AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i),
- InsertPt);
- I2->setName(I->getName()+"."+Twine(i));
- new StoreInst(&*I2++, Idx, InsertPt);
- }
+ bool runOnSCC(CallGraphSCC &SCC) override;
+ static char ID; // Pass identification, replacement for typeid
+ explicit ArgPromotion(unsigned MaxElements = 3)
+ : CallGraphSCCPass(ID), MaxElements(MaxElements) {
+ initializeArgPromotionPass(*PassRegistry::getPassRegistry());
+ }
- // Anything that used the arg should now use the alloca.
- I->replaceAllUsesWith(TheAlloca);
- TheAlloca->takeName(&*I);
+private:
+ using llvm::Pass::doInitialization;
+ bool doInitialization(CallGraph &CG) override;
+ /// The maximum number of elements to expand, or 0 for unlimited.
+ unsigned MaxElements;
+};
+}
- // If the alloca is used in a call, we must clear the tail flag since
- // the callee now uses an alloca from the caller.
- for (User *U : TheAlloca->users()) {
- CallInst *Call = dyn_cast<CallInst>(U);
- if (!Call)
- continue;
- Call->setTailCall(false);
- }
- continue;
- }
+char ArgPromotion::ID = 0;
+INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion",
+ "Promote 'by reference' arguments to scalars", false,
+ false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(ArgPromotion, "argpromotion",
+ "Promote 'by reference' arguments to scalars", false, false)
- if (I->use_empty())
- continue;
+Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) {
+ return new ArgPromotion(MaxElements);
+}
- // Otherwise, if we promoted this argument, then all users are load
- // instructions (or GEPs with only load users), and all loads should be
- // using the new argument that we added.
- ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
+bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
+ if (skipSCC(SCC))
+ return false;
- while (!I->use_empty()) {
- if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) {
- assert(ArgIndices.begin()->second.empty() &&
- "Load element should sort to front!");
- I2->setName(I->getName()+".val");
- LI->replaceAllUsesWith(&*I2);
- LI->eraseFromParent();
- DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName()
- << "' in function '" << F->getName() << "'\n");
- } else {
- GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back());
- IndicesVector Operands;
- Operands.reserve(GEP->getNumIndices());
- for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end();
- II != IE; ++II)
- Operands.push_back(cast<ConstantInt>(*II)->getSExtValue());
+ // Get the callgraph information that we need to update to reflect our
+ // changes.
+ CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
- // GEPs with a single 0 index can be merged with direct loads
- if (Operands.size() == 1 && Operands.front() == 0)
- Operands.clear();
+ LegacyAARGetter AARGetter(*this);
- Function::arg_iterator TheArg = I2;
- for (ScalarizeTable::iterator It = ArgIndices.begin();
- It->second != Operands; ++It, ++TheArg) {
- assert(It != ArgIndices.end() && "GEP not handled??");
- }
+ bool Changed = false, LocalChange;
- std::string NewName = I->getName();
- for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
- NewName += "." + utostr(Operands[i]);
- }
- NewName += ".val";
- TheArg->setName(NewName);
+ // Iterate until we stop promoting from this SCC.
+ do {
+ LocalChange = false;
+ // Attempt to promote arguments from all functions in this SCC.
+ for (CallGraphNode *OldNode : SCC) {
+ Function *OldF = OldNode->getFunction();
+ if (!OldF)
+ continue;
+
+ auto ReplaceCallSite = [&](CallSite OldCS, CallSite NewCS) {
+ Function *Caller = OldCS.getInstruction()->getParent()->getParent();
+ CallGraphNode *NewCalleeNode =
+ CG.getOrInsertFunction(NewCS.getCalledFunction());
+ CallGraphNode *CallerNode = CG[Caller];
+ CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode);
+ };
+
+ if (Function *NewF = promoteArguments(OldF, AARGetter, MaxElements,
+ {ReplaceCallSite})) {
+ LocalChange = true;
- DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName()
- << "' of function '" << NF->getName() << "'\n");
+ // Update the call graph for the newly promoted function.
+ CallGraphNode *NewNode = CG.getOrInsertFunction(NewF);
+ NewNode->stealCalledFunctionsFrom(OldNode);
+ if (OldNode->getNumReferences() == 0)
+ delete CG.removeFunctionFromModule(OldNode);
+ else
+ OldF->setLinkage(Function::ExternalLinkage);
- // All of the uses must be load instructions. Replace them all with
- // the argument specified by ArgNo.
- while (!GEP->use_empty()) {
- LoadInst *L = cast<LoadInst>(GEP->user_back());
- L->replaceAllUsesWith(&*TheArg);
- L->eraseFromParent();
- }
- GEP->eraseFromParent();
+ // And update the SCC we're iterating over as well.
+ SCC.ReplaceNode(OldNode, NewNode);
}
}
+ // Remember that we changed something.
+ Changed |= LocalChange;
+ } while (LocalChange);
- // Increment I2 past all of the arguments added for this promoted pointer.
- std::advance(I2, ArgIndices.size());
- }
-
- NF_CGN->stealCalledFunctionsFrom(CG[F]);
-
- // Now that the old function is dead, delete it. If there is a dangling
- // reference to the CallgraphNode, just leave the dead function around for
- // someone else to nuke.
- CallGraphNode *CGN = CG[F];
- if (CGN->getNumReferences() == 0)
- delete CG.removeFunctionFromModule(CGN);
- else
- F->setLinkage(Function::ExternalLinkage);
-
- return NF_CGN;
+ return Changed;
}
bool ArgPromotion::doInitialization(CallGraph &CG) {
diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp
index d75ed206ad23c..62b5a9c9ba266 100644
--- a/lib/Transforms/IPO/ConstantMerge.cpp
+++ b/lib/Transforms/IPO/ConstantMerge.cpp
@@ -60,6 +60,23 @@ static bool IsBetterCanonical(const GlobalVariable &A,
return A.hasGlobalUnnamedAddr();
}
+static bool hasMetadataOtherThanDebugLoc(const GlobalVariable *GV) {
+ SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
+ GV->getAllMetadata(MDs);
+ for (const auto &V : MDs)
+ if (V.first != LLVMContext::MD_dbg)
+ return true;
+ return false;
+}
+
+static void copyDebugLocMetadata(const GlobalVariable *From,
+ GlobalVariable *To) {
+ SmallVector<DIGlobalVariableExpression *, 1> MDs;
+ From->getDebugInfo(MDs);
+ for (auto MD : MDs)
+ To->addDebugInfo(MD);
+}
+
static unsigned getAlignment(GlobalVariable *GV) {
unsigned Align = GV->getAlignment();
if (Align)
@@ -113,6 +130,10 @@ static bool mergeConstants(Module &M) {
if (GV->isWeakForLinker())
continue;
+ // Don't touch globals with metadata other than !dbg.
+ if (hasMetadataOtherThanDebugLoc(GV))
+ continue;
+
Constant *Init = GV->getInitializer();
// Check to see if the initializer is already known.
@@ -155,6 +176,9 @@ static bool mergeConstants(Module &M) {
if (!Slot->hasGlobalUnnamedAddr() && !GV->hasGlobalUnnamedAddr())
continue;
+ if (hasMetadataOtherThanDebugLoc(GV))
+ continue;
+
if (!GV->hasGlobalUnnamedAddr())
Slot->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
@@ -178,6 +202,8 @@ static bool mergeConstants(Module &M) {
getAlignment(Replacements[i].second)));
}
+ copyDebugLocMetadata(Replacements[i].first, Replacements[i].second);
+
// Eliminate any uses of the dead global.
Replacements[i].first->replaceAllUsesWith(Replacements[i].second);
diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp
index ba2e60dee3bcb..1b111de061576 100644
--- a/lib/Transforms/IPO/CrossDSOCFI.cpp
+++ b/lib/Transforms/IPO/CrossDSOCFI.cpp
@@ -98,8 +98,11 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
LLVMContext &Ctx = M.getContext();
Constant *C = M.getOrInsertFunction(
"__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),
- Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx), nullptr);
+ Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));
Function *F = dyn_cast<Function>(C);
+ // Take over the existing function. The frontend emits a weak stub so that the
+ // linker knows about the symbol; this pass replaces the function body.
+ F->deleteBody();
F->setAlignment(4096);
auto args = F->arg_begin();
Value &CallSiteTypeId = *(args++);
@@ -117,7 +120,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
IRBuilder<> IRBFail(TrapBB);
Constant *CFICheckFailFn = M.getOrInsertFunction(
"__cfi_check_fail", Type::getVoidTy(Ctx), Type::getInt8PtrTy(Ctx),
- Type::getInt8PtrTy(Ctx), nullptr);
+ Type::getInt8PtrTy(Ctx));
IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr});
IRBFail.CreateBr(ExitBB);
diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp
index 1a5ed46922118..375b74c494d92 100644
--- a/lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -166,41 +166,43 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) {
Args.assign(CS.arg_begin(), CS.arg_begin() + NumArgs);
// Drop any attributes that were on the vararg arguments.
- AttributeSet PAL = CS.getAttributes();
+ AttributeList PAL = CS.getAttributes();
if (!PAL.isEmpty() && PAL.getSlotIndex(PAL.getNumSlots() - 1) > NumArgs) {
- SmallVector<AttributeSet, 8> AttributesVec;
+ SmallVector<AttributeList, 8> AttributesVec;
for (unsigned i = 0; PAL.getSlotIndex(i) <= NumArgs; ++i)
AttributesVec.push_back(PAL.getSlotAttributes(i));
- if (PAL.hasAttributes(AttributeSet::FunctionIndex))
- AttributesVec.push_back(AttributeSet::get(Fn.getContext(),
- PAL.getFnAttributes()));
- PAL = AttributeSet::get(Fn.getContext(), AttributesVec);
+ if (PAL.hasAttributes(AttributeList::FunctionIndex))
+ AttributesVec.push_back(AttributeList::get(Fn.getContext(),
+ AttributeList::FunctionIndex,
+ PAL.getFnAttributes()));
+ PAL = AttributeList::get(Fn.getContext(), AttributesVec);
}
SmallVector<OperandBundleDef, 1> OpBundles;
CS.getOperandBundlesAsDefs(OpBundles);
- Instruction *New;
+ CallSite NewCS;
if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
- New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
- Args, OpBundles, "", Call);
- cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
- cast<InvokeInst>(New)->setAttributes(PAL);
+ NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
+ Args, OpBundles, "", Call);
} else {
- New = CallInst::Create(NF, Args, OpBundles, "", Call);
- cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
- cast<CallInst>(New)->setAttributes(PAL);
- cast<CallInst>(New)->setTailCallKind(
- cast<CallInst>(Call)->getTailCallKind());
+ NewCS = CallInst::Create(NF, Args, OpBundles, "", Call);
+ cast<CallInst>(NewCS.getInstruction())
+ ->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());
}
- New->setDebugLoc(Call->getDebugLoc());
+ NewCS.setCallingConv(CS.getCallingConv());
+ NewCS.setAttributes(PAL);
+ NewCS->setDebugLoc(Call->getDebugLoc());
+ uint64_t W;
+ if (Call->extractProfTotalWeight(W))
+ NewCS->setProfWeight(W);
Args.clear();
if (!Call->use_empty())
- Call->replaceAllUsesWith(New);
+ Call->replaceAllUsesWith(NewCS.getInstruction());
- New->takeName(Call);
+ NewCS->takeName(Call);
// Finally, remove the old call from the program, reducing the use-count of
// F.
@@ -681,8 +683,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
bool HasLiveReturnedArg = false;
// Set up to build a new list of parameter attributes.
- SmallVector<AttributeSet, 8> AttributesVec;
- const AttributeSet &PAL = F->getAttributes();
+ SmallVector<AttributeSet, 8> ArgAttrVec;
+ const AttributeList &PAL = F->getAttributes();
// Remember which arguments are still alive.
SmallVector<bool, 10> ArgAlive(FTy->getNumParams(), false);
@@ -696,16 +698,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
if (LiveValues.erase(Arg)) {
Params.push_back(I->getType());
ArgAlive[i] = true;
-
- // Get the original parameter attributes (skipping the first one, that is
- // for the return value.
- if (PAL.hasAttributes(i + 1)) {
- AttrBuilder B(PAL, i + 1);
- if (B.contains(Attribute::Returned))
- HasLiveReturnedArg = true;
- AttributesVec.
- push_back(AttributeSet::get(F->getContext(), Params.size(), B));
- }
+ ArgAttrVec.push_back(PAL.getParamAttributes(i));
+ HasLiveReturnedArg |= PAL.hasParamAttribute(i, Attribute::Returned);
} else {
++NumArgumentsEliminated;
DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing argument " << i
@@ -779,30 +773,24 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
assert(NRetTy && "No new return type found?");
// The existing function return attributes.
- AttributeSet RAttrs = PAL.getRetAttributes();
+ AttrBuilder RAttrs(PAL.getRetAttributes());
// Remove any incompatible attributes, but only if we removed all return
// values. Otherwise, ensure that we don't have any conflicting attributes
// here. Currently, this should not be possible, but special handling might be
// required when new return value attributes are added.
if (NRetTy->isVoidTy())
- RAttrs = RAttrs.removeAttributes(NRetTy->getContext(),
- AttributeSet::ReturnIndex,
- AttributeFuncs::typeIncompatible(NRetTy));
+ RAttrs.remove(AttributeFuncs::typeIncompatible(NRetTy));
else
- assert(!AttrBuilder(RAttrs, AttributeSet::ReturnIndex).
- overlaps(AttributeFuncs::typeIncompatible(NRetTy)) &&
+ assert(!RAttrs.overlaps(AttributeFuncs::typeIncompatible(NRetTy)) &&
"Return attributes no longer compatible?");
- if (RAttrs.hasAttributes(AttributeSet::ReturnIndex))
- AttributesVec.push_back(AttributeSet::get(NRetTy->getContext(), RAttrs));
-
- if (PAL.hasAttributes(AttributeSet::FunctionIndex))
- AttributesVec.push_back(AttributeSet::get(F->getContext(),
- PAL.getFnAttributes()));
+ AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs);
// Reconstruct the AttributesList based on the vector we constructed.
- AttributeSet NewPAL = AttributeSet::get(F->getContext(), AttributesVec);
+ assert(ArgAttrVec.size() == Params.size());
+ AttributeList NewPAL = AttributeList::get(
+ F->getContext(), PAL.getFnAttributes(), RetAttrs, ArgAttrVec);
// Create the new function type based on the recomputed parameters.
FunctionType *NFTy = FunctionType::get(NRetTy, Params, FTy->isVarArg());
@@ -829,18 +817,14 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
CallSite CS(F->user_back());
Instruction *Call = CS.getInstruction();
- AttributesVec.clear();
- const AttributeSet &CallPAL = CS.getAttributes();
-
- // The call return attributes.
- AttributeSet RAttrs = CallPAL.getRetAttributes();
+ ArgAttrVec.clear();
+ const AttributeList &CallPAL = CS.getAttributes();
- // Adjust in case the function was changed to return void.
- RAttrs = RAttrs.removeAttributes(NRetTy->getContext(),
- AttributeSet::ReturnIndex,
- AttributeFuncs::typeIncompatible(NF->getReturnType()));
- if (RAttrs.hasAttributes(AttributeSet::ReturnIndex))
- AttributesVec.push_back(AttributeSet::get(NF->getContext(), RAttrs));
+ // Adjust the call return attributes in case the function was changed to
+ // return void.
+ AttrBuilder RAttrs(CallPAL.getRetAttributes());
+ RAttrs.remove(AttributeFuncs::typeIncompatible(NRetTy));
+ AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs);
// Declare these outside of the loops, so we can reuse them for the second
// loop, which loops the varargs.
@@ -852,57 +836,55 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) {
if (ArgAlive[i]) {
Args.push_back(*I);
// Get original parameter attributes, but skip return attributes.
- if (CallPAL.hasAttributes(i + 1)) {
- AttrBuilder B(CallPAL, i + 1);
+ AttributeSet Attrs = CallPAL.getParamAttributes(i);
+ if (NRetTy != RetTy && Attrs.hasAttribute(Attribute::Returned)) {
// If the return type has changed, then get rid of 'returned' on the
// call site. The alternative is to make all 'returned' attributes on
// call sites keep the return value alive just like 'returned'
- // attributes on function declaration but it's less clearly a win
- // and this is not an expected case anyway
- if (NRetTy != RetTy && B.contains(Attribute::Returned))
- B.removeAttribute(Attribute::Returned);
- AttributesVec.
- push_back(AttributeSet::get(F->getContext(), Args.size(), B));
+ // attributes on function declarations, but it's less clearly a win and
+ // this is not an expected case anyway.
+ ArgAttrVec.push_back(AttributeSet::get(
+ F->getContext(),
+ AttrBuilder(Attrs).removeAttribute(Attribute::Returned)));
+ } else {
+ // Otherwise, use the original attributes.
+ ArgAttrVec.push_back(Attrs);
}
}
// Push any varargs arguments on the list. Don't forget their attributes.
for (CallSite::arg_iterator E = CS.arg_end(); I != E; ++I, ++i) {
Args.push_back(*I);
- if (CallPAL.hasAttributes(i + 1)) {
- AttrBuilder B(CallPAL, i + 1);
- AttributesVec.
- push_back(AttributeSet::get(F->getContext(), Args.size(), B));
- }
+ ArgAttrVec.push_back(CallPAL.getParamAttributes(i));
}
- if (CallPAL.hasAttributes(AttributeSet::FunctionIndex))
- AttributesVec.push_back(AttributeSet::get(Call->getContext(),
- CallPAL.getFnAttributes()));
-
// Reconstruct the AttributesList based on the vector we constructed.
- AttributeSet NewCallPAL = AttributeSet::get(F->getContext(), AttributesVec);
+ assert(ArgAttrVec.size() == Args.size());
+ AttributeList NewCallPAL = AttributeList::get(
+ F->getContext(), CallPAL.getFnAttributes(), RetAttrs, ArgAttrVec);
SmallVector<OperandBundleDef, 1> OpBundles;
CS.getOperandBundlesAsDefs(OpBundles);
- Instruction *New;
+ CallSite NewCS;
if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
- New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
- Args, OpBundles, "", Call->getParent());
- cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
- cast<InvokeInst>(New)->setAttributes(NewCallPAL);
+ NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
+ Args, OpBundles, "", Call->getParent());
} else {
- New = CallInst::Create(NF, Args, OpBundles, "", Call);
- cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
- cast<CallInst>(New)->setAttributes(NewCallPAL);
- cast<CallInst>(New)->setTailCallKind(
- cast<CallInst>(Call)->getTailCallKind());
+ NewCS = CallInst::Create(NF, Args, OpBundles, "", Call);
+ cast<CallInst>(NewCS.getInstruction())
+ ->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());
}
- New->setDebugLoc(Call->getDebugLoc());
-
+ NewCS.setCallingConv(CS.getCallingConv());
+ NewCS.setAttributes(NewCallPAL);
+ NewCS->setDebugLoc(Call->getDebugLoc());
+ uint64_t W;
+ if (Call->extractProfTotalWeight(W))
+ NewCS->setProfWeight(W);
Args.clear();
+ ArgAttrVec.clear();
+ Instruction *New = NewCS.getInstruction();
if (!Call->use_empty()) {
if (New->getType() == Call->getType()) {
// Return type not changed? Just replace users then.
diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp
index 402a66552c243..4d13b3f406887 100644
--- a/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -49,31 +49,35 @@ STATISTIC(NumNoAlias, "Number of function returns marked noalias");
STATISTIC(NumNonNullReturn, "Number of function returns marked nonnull");
STATISTIC(NumNoRecurse, "Number of functions marked as norecurse");
-namespace {
-typedef SmallSetVector<Function *, 8> SCCNodeSet;
-}
+// FIXME: This is disabled by default to avoid exposing security vulnerabilities
+// in C/C++ code compiled by clang:
+// http://lists.llvm.org/pipermail/cfe-dev/2017-January/052066.html
+static cl::opt<bool> EnableNonnullArgPropagation(
+ "enable-nonnull-arg-prop", cl::Hidden,
+ cl::desc("Try to propagate nonnull argument attributes from callsites to "
+ "caller functions."));
namespace {
-/// The three kinds of memory access relevant to 'readonly' and
-/// 'readnone' attributes.
-enum MemoryAccessKind {
- MAK_ReadNone = 0,
- MAK_ReadOnly = 1,
- MAK_MayWrite = 2
-};
+typedef SmallSetVector<Function *, 8> SCCNodeSet;
}
-static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR,
+/// Returns the memory access attribute for function F using AAR for AA results,
+/// where SCCNodes is the current SCC.
+///
+/// If ThisBody is true, this function may examine the function body and will
+/// return a result pertaining to this copy of the function. If it is false, the
+/// result will be based only on AA results for the function declaration; it
+/// will be assumed that some other (perhaps less optimized) version of the
+/// function may be selected at link time.
+static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody,
+ AAResults &AAR,
const SCCNodeSet &SCCNodes) {
FunctionModRefBehavior MRB = AAR.getModRefBehavior(&F);
if (MRB == FMRB_DoesNotAccessMemory)
// Already perfect!
return MAK_ReadNone;
- // Non-exact function definitions may not be selected at link time, and an
- // alternative version that writes to memory may be selected. See the comment
- // on GlobalValue::isDefinitionExact for more details.
- if (!F.hasExactDefinition()) {
+ if (!ThisBody) {
if (AliasAnalysis::onlyReadsMemory(MRB))
return MAK_ReadOnly;
@@ -172,9 +176,14 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR,
return ReadsMemory ? MAK_ReadOnly : MAK_ReadNone;
}
+MemoryAccessKind llvm::computeFunctionBodyMemoryAccess(Function &F,
+ AAResults &AAR) {
+ return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {});
+}
+
/// Deduce readonly/readnone attributes for the SCC.
template <typename AARGetterT>
-static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) {
+static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) {
// Check if any of the functions in the SCC read or write memory. If they
// write memory then they can't be marked readnone or readonly.
bool ReadsMemory = false;
@@ -182,7 +191,11 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) {
// Call the callable parameter to look up AA results for this function.
AAResults &AAR = AARGetter(*F);
- switch (checkFunctionMemoryAccess(*F, AAR, SCCNodes)) {
+ // Non-exact function definitions may not be selected at link time, and an
+ // alternative version that writes to memory may be selected. See the
+ // comment on GlobalValue::isDefinitionExact for more details.
+ switch (checkFunctionMemoryAccess(*F, F->hasExactDefinition(),
+ AAR, SCCNodes)) {
case MAK_MayWrite:
return false;
case MAK_ReadOnly:
@@ -212,11 +225,11 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) {
AttrBuilder B;
B.addAttribute(Attribute::ReadOnly).addAttribute(Attribute::ReadNone);
F->removeAttributes(
- AttributeSet::FunctionIndex,
- AttributeSet::get(F->getContext(), AttributeSet::FunctionIndex, B));
+ AttributeList::FunctionIndex,
+ AttributeList::get(F->getContext(), AttributeList::FunctionIndex, B));
// Add in the new attribute.
- F->addAttribute(AttributeSet::FunctionIndex,
+ F->addAttribute(AttributeList::FunctionIndex,
ReadsMemory ? Attribute::ReadOnly : Attribute::ReadNone);
if (ReadsMemory)
@@ -522,7 +535,7 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) {
if (Value *RetArg = FindRetArg()) {
auto *A = cast<Argument>(RetArg);
- A->addAttr(AttributeSet::get(F->getContext(), A->getArgNo() + 1, B));
+ A->addAttr(AttributeList::get(F->getContext(), A->getArgNo() + 1, B));
++NumReturned;
Changed = true;
}
@@ -531,6 +544,49 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) {
return Changed;
}
+/// If a callsite has arguments that are also arguments to the parent function,
+/// try to propagate attributes from the callsite's arguments to the parent's
+/// arguments. This may be important because inlining can cause information loss
+/// when attribute knowledge disappears with the inlined call.
+static bool addArgumentAttrsFromCallsites(Function &F) {
+ if (!EnableNonnullArgPropagation)
+ return false;
+
+ bool Changed = false;
+
+ // For an argument attribute to transfer from a callsite to the parent, the
+ // call must be guaranteed to execute every time the parent is called.
+ // Conservatively, just check for calls in the entry block that are guaranteed
+ // to execute.
+ // TODO: This could be enhanced by testing if the callsite post-dominates the
+ // entry block or by doing simple forward walks or backward walks to the
+ // callsite.
+ BasicBlock &Entry = F.getEntryBlock();
+ for (Instruction &I : Entry) {
+ if (auto CS = CallSite(&I)) {
+ if (auto *CalledFunc = CS.getCalledFunction()) {
+ for (auto &CSArg : CalledFunc->args()) {
+ if (!CSArg.hasNonNullAttr())
+ continue;
+
+ // If the non-null callsite argument operand is an argument to 'F'
+ // (the caller) and the call is guaranteed to execute, then the value
+ // must be non-null throughout 'F'.
+ auto *FArg = dyn_cast<Argument>(CS.getArgOperand(CSArg.getArgNo()));
+ if (FArg && !FArg->hasNonNullAttr()) {
+ FArg->addAttr(Attribute::NonNull);
+ Changed = true;
+ }
+ }
+ }
+ }
+ if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+ break;
+ }
+
+ return Changed;
+}
+
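
A source-level picture of the propagation (illustrative; exactly which frontend
constructs lower to the IR nonnull parameter attribute is an assumption here):
if the callee's parameter is nonnull and the call executes unconditionally from
the caller's entry block, the caller's own argument can inherit nonnull.

  // 'x' is a reference, which clang typically lowers to a nonnull pointer
  // parameter (an assumption about the frontend's lowering).
  int callee(int &x);

  int caller(int *p) {
    return callee(*p); // unconditional call in the entry block, so 'p' can be
                       // marked nonnull throughout 'caller'
  }
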
/// Deduce nocapture attributes for the SCC.
static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
bool Changed = false;
@@ -549,6 +605,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
if (!F->hasExactDefinition())
continue;
+ Changed |= addArgumentAttrsFromCallsites(*F);
+
// Functions that are readonly (or readnone) and nounwind and don't return
// a value can't capture arguments. Don't analyze them.
if (F->onlyReadsMemory() && F->doesNotThrow() &&
@@ -556,7 +614,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
for (Function::arg_iterator A = F->arg_begin(), E = F->arg_end(); A != E;
++A) {
if (A->getType()->isPointerTy() && !A->hasNoCaptureAttr()) {
- A->addAttr(AttributeSet::get(F->getContext(), A->getArgNo() + 1, B));
+ A->addAttr(AttributeList::get(F->getContext(), A->getArgNo() + 1, B));
++NumNoCapture;
Changed = true;
}
@@ -576,7 +634,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
if (Tracker.Uses.empty()) {
// If it's trivially not captured, mark it nocapture now.
A->addAttr(
- AttributeSet::get(F->getContext(), A->getArgNo() + 1, B));
+ AttributeList::get(F->getContext(), A->getArgNo() + 1, B));
++NumNoCapture;
Changed = true;
} else {
@@ -604,7 +662,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
if (R != Attribute::None) {
AttrBuilder B;
B.addAttribute(R);
- A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B));
+ A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));
Changed = true;
R == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg;
}
@@ -629,7 +687,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
if (ArgumentSCC[0]->Uses.size() == 1 &&
ArgumentSCC[0]->Uses[0] == ArgumentSCC[0]) {
Argument *A = ArgumentSCC[0]->Definition;
- A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B));
+ A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));
++NumNoCapture;
Changed = true;
}
@@ -671,7 +729,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) {
Argument *A = ArgumentSCC[i]->Definition;
- A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B));
+ A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));
++NumNoCapture;
Changed = true;
}
@@ -708,8 +766,9 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) {
Argument *A = ArgumentSCC[i]->Definition;
// Clear out existing readonly/readnone attributes
- A->removeAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, R));
- A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B));
+ A->removeAttr(
+ AttributeList::get(A->getContext(), A->getArgNo() + 1, R));
+ A->addAttr(AttributeList::get(A->getContext(), A->getArgNo() + 1, B));
ReadAttr == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg;
Changed = true;
}
@@ -769,7 +828,7 @@ static bool isFunctionMallocLike(Function *F, const SCCNodeSet &SCCNodes) {
case Instruction::Call:
case Instruction::Invoke: {
CallSite CS(RVI);
- if (CS.paramHasAttr(0, Attribute::NoAlias))
+ if (CS.hasRetAttr(Attribute::NoAlias))
break;
if (CS.getCalledFunction() && SCCNodes.count(CS.getCalledFunction()))
break;
@@ -905,7 +964,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) {
// pointers.
for (Function *F : SCCNodes) {
// Already nonnull.
- if (F->getAttributes().hasAttribute(AttributeSet::ReturnIndex,
+ if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex,
Attribute::NonNull))
continue;
@@ -926,7 +985,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) {
// Mark the function eagerly since we may discover a function
// which prevents us from speculating about the entire SCC
DEBUG(dbgs() << "Eagerly marking " << F->getName() << " as nonnull\n");
- F->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
+ F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
++NumNonNullReturn;
MadeChange = true;
}
@@ -939,13 +998,13 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) {
if (SCCReturnsNonNull) {
for (Function *F : SCCNodes) {
- if (F->getAttributes().hasAttribute(AttributeSet::ReturnIndex,
+ if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex,
Attribute::NonNull) ||
!F->getReturnType()->isPointerTy())
continue;
DEBUG(dbgs() << "SCC marking " << F->getName() << " as nonnull\n");
- F->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
+ F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
++NumNonNullReturn;
MadeChange = true;
}
@@ -1163,19 +1222,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) {
bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) {
if (skipSCC(SCC))
return false;
-
- // We compute dedicated AA results for each function in the SCC as needed. We
- // use a lambda referencing external objects so that they live long enough to
- // be queried, but we re-use them each time.
- Optional<BasicAAResult> BAR;
- Optional<AAResults> AAR;
- auto AARGetter = [&](Function &F) -> AAResults & {
- BAR.emplace(createLegacyPMBasicAAResult(*this, F));
- AAR.emplace(createLegacyPMAAResults(*this, F, *BAR));
- return *AAR;
- };
-
- return runImpl(SCC, AARGetter);
+ return runImpl(SCC, LegacyAARGetter(*this));
}
namespace {
@@ -1275,16 +1322,9 @@ PreservedAnalyses
ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) {
auto &CG = AM.getResult<CallGraphAnalysis>(M);
- bool Changed = deduceFunctionAttributeInRPO(M, CG);
-
- // CallGraphAnalysis holds AssertingVH and must be invalidated eagerly so
- // that other passes don't delete stuff from under it.
- // FIXME: We need to invalidate this to avoid PR28400. Is there a better
- // solution?
- AM.invalidate<CallGraphAnalysis>(M);
-
- if (!Changed)
+ if (!deduceFunctionAttributeInRPO(M, CG))
return PreservedAnalyses::all();
+
PreservedAnalyses PA;
PA.preserve<CallGraphAnalysis>();
return PA;
diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp
index 6b32f6c31f722..d66411f04cc4a 100644
--- a/lib/Transforms/IPO/FunctionImport.cpp
+++ b/lib/Transforms/IPO/FunctionImport.cpp
@@ -75,12 +75,6 @@ static cl::opt<bool> PrintImports("print-imports", cl::init(false), cl::Hidden,
static cl::opt<bool> ComputeDead("compute-dead", cl::init(true), cl::Hidden,
cl::desc("Compute dead symbols"));
-// Temporary allows the function import pass to disable always linking
-// referenced discardable symbols.
-static cl::opt<bool>
- DontForceImportReferencedDiscardableSymbols("disable-force-link-odr",
- cl::init(false), cl::Hidden);
-
static cl::opt<bool> EnableImportMetadata(
"enable-import-metadata", cl::init(
#if !defined(NDEBUG)
@@ -124,7 +118,7 @@ namespace {
static const GlobalValueSummary *
selectCallee(const ModuleSummaryIndex &Index,
const GlobalValueSummaryList &CalleeSummaryList,
- unsigned Threshold) {
+ unsigned Threshold, StringRef CallerModulePath) {
auto It = llvm::find_if(
CalleeSummaryList,
[&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) {
@@ -145,6 +139,21 @@ selectCallee(const ModuleSummaryIndex &Index,
auto *Summary = cast<FunctionSummary>(GVSummary);
+ // If this is a local function, make sure we import the copy
+ // in the caller's module. The only time a local function can
+ // share an entry in the index is if there is a local with the same name
+ // in another module that had the same source file name (in a different
+ // directory), where each was compiled in its own directory so there
+ // was no distinguishing path.
+ // However, do the import from another module if there is only one
+ // entry in the list - in that case this must be a reference due
+ // to indirect call profile data, since a function pointer can point to
+ // a local in another module.
+ if (GlobalValue::isLocalLinkage(Summary->linkage()) &&
+ CalleeSummaryList.size() > 1 &&
+ Summary->modulePath() != CallerModulePath)
+ return false;
+
if (Summary->instCount() > Threshold)
return false;
@@ -163,11 +172,13 @@ selectCallee(const ModuleSummaryIndex &Index,
/// null if there's no match.
static const GlobalValueSummary *selectCallee(GlobalValue::GUID GUID,
unsigned Threshold,
- const ModuleSummaryIndex &Index) {
+ const ModuleSummaryIndex &Index,
+ StringRef CallerModulePath) {
auto CalleeSummaryList = Index.findGlobalValueSummaryList(GUID);
if (CalleeSummaryList == Index.end())
return nullptr; // This function does not have a summary
- return selectCallee(Index, CalleeSummaryList->second, Threshold);
+ return selectCallee(Index, CalleeSummaryList->second, Threshold,
+ CallerModulePath);
}
using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */,
@@ -186,6 +197,15 @@ static void computeImportForFunction(
auto GUID = Edge.first.getGUID();
DEBUG(dbgs() << " edge -> " << GUID << " Threshold:" << Threshold << "\n");
+ if (Index.findGlobalValueSummaryList(GUID) == Index.end()) {
+ // For SamplePGO, the indirect call targets for local functions will
+ // have their original names annotated in the profile. We try to find the
+ // corresponding PGOFuncName as the GUID.
+ GUID = Index.getGUIDFromOriginalID(GUID);
+ if (GUID == 0)
+ continue;
+ }
+
if (DefinedGVSummaries.count(GUID)) {
DEBUG(dbgs() << "ignored! Target already in destination module.\n");
continue;
@@ -202,7 +222,8 @@ static void computeImportForFunction(
const auto NewThreshold =
Threshold * GetBonusMultiplier(Edge.second.Hotness);
- auto *CalleeSummary = selectCallee(GUID, NewThreshold, Index);
+ auto *CalleeSummary =
+ selectCallee(GUID, NewThreshold, Index, Summary.modulePath());
if (!CalleeSummary) {
DEBUG(dbgs() << "ignored! No qualifying callee with summary found.\n");
continue;
@@ -522,6 +543,23 @@ llvm::EmitImportsFiles(StringRef ModulePath, StringRef OutputFilename,
/// Fixup WeakForLinker linkages in \p TheModule based on summary analysis.
void llvm::thinLTOResolveWeakForLinkerModule(
Module &TheModule, const GVSummaryMapTy &DefinedGlobals) {
+ auto ConvertToDeclaration = [](GlobalValue &GV) {
+ DEBUG(dbgs() << "Converting to a declaration: `" << GV.getName() << "\n");
+ if (Function *F = dyn_cast<Function>(&GV)) {
+ F->deleteBody();
+ F->clearMetadata();
+ } else if (GlobalVariable *V = dyn_cast<GlobalVariable>(&GV)) {
+ V->setInitializer(nullptr);
+ V->setLinkage(GlobalValue::ExternalLinkage);
+ V->clearMetadata();
+ } else
+ // For now we don't resolve or drop aliases. Once we do we'll
+ // need to add support here for creating either a function or
+ // variable declaration, and return the new GlobalValue* for
+ // the caller to use.
+ llvm_unreachable("Expected function or variable");
+ };
+
auto updateLinkage = [&](GlobalValue &GV) {
if (!GlobalValue::isWeakForLinker(GV.getLinkage()))
return;
@@ -532,18 +570,25 @@ void llvm::thinLTOResolveWeakForLinkerModule(
auto NewLinkage = GS->second->linkage();
if (NewLinkage == GV.getLinkage())
return;
- DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from "
- << GV.getLinkage() << " to " << NewLinkage << "\n");
- GV.setLinkage(NewLinkage);
- // Remove functions converted to available_externally from comdats,
+ // Check for a non-prevailing def that has interposable linkage
+ // (e.g. non-odr weak or linkonce). In that case we can't simply
+ // convert to available_externally, since it would lose the
+ // interposable property and possibly get inlined. Simply drop
+ // the definition instead.
+ if (GlobalValue::isAvailableExternallyLinkage(NewLinkage) &&
+ GlobalValue::isInterposableLinkage(GV.getLinkage()))
+ ConvertToDeclaration(GV);
+ else {
+ DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from "
+ << GV.getLinkage() << " to " << NewLinkage << "\n");
+ GV.setLinkage(NewLinkage);
+ }
+ // Remove declarations from comdats, including available_externally
// as this is a declaration for the linker, and will be dropped eventually.
// It is illegal for comdats to contain declarations.
auto *GO = dyn_cast_or_null<GlobalObject>(&GV);
- if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) {
- assert(GO->hasAvailableExternallyLinkage() &&
- "Expected comdat on definition (possibly available external)");
+ if (GO && GO->isDeclarationForLinker() && GO->hasComdat())
GO->setComdat(nullptr);
- }
};
// Process functions and global now
@@ -562,7 +607,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule,
// the current module.
StringSet<> AsmUndefinedRefs;
ModuleSymbolTable::CollectAsmSymbols(
- Triple(TheModule.getTargetTriple()), TheModule.getModuleInlineAsm(),
+ TheModule,
[&AsmUndefinedRefs](StringRef Name, object::BasicSymbolRef::Flags Flags) {
if (Flags & object::BasicSymbolRef::SF_Undefined)
AsmUndefinedRefs.insert(Name);
@@ -617,14 +662,12 @@ void llvm::thinLTOInternalizeModule(Module &TheModule,
// index.
//
Expected<bool> FunctionImporter::importFunctions(
- Module &DestModule, const FunctionImporter::ImportMapTy &ImportList,
- bool ForceImportReferencedDiscardableSymbols) {
+ Module &DestModule, const FunctionImporter::ImportMapTy &ImportList) {
DEBUG(dbgs() << "Starting import for Module "
<< DestModule.getModuleIdentifier() << "\n");
unsigned ImportedCount = 0;
- // Linker that will be used for importing function
- Linker TheLinker(DestModule);
+ IRMover Mover(DestModule);
// Do the actual import of functions now, one Module at a time
std::set<StringRef> ModuleNameOrderedList;
for (auto &FunctionsToImportPerModule : ImportList) {
@@ -648,7 +691,7 @@ Expected<bool> FunctionImporter::importFunctions(
auto &ImportGUIDs = FunctionsToImportPerModule->second;
// Find the globals to import
- DenseSet<const GlobalValue *> GlobalsToImport;
+ SetVector<GlobalValue *> GlobalsToImport;
for (Function &F : *SrcModule) {
if (!F.hasName())
continue;
@@ -687,6 +730,13 @@ Expected<bool> FunctionImporter::importFunctions(
}
}
for (GlobalAlias &GA : SrcModule->aliases()) {
+ // FIXME: This should eventually be controlled entirely by the summary.
+ if (FunctionImportGlobalProcessing::doImportAsDefinition(
+ &GA, &GlobalsToImport)) {
+ GlobalsToImport.insert(&GA);
+ continue;
+ }
+
if (!GA.hasName())
continue;
auto GUID = GA.getGUID();
@@ -731,12 +781,9 @@ Expected<bool> FunctionImporter::importFunctions(
<< " from " << SrcModule->getSourceFileName() << "\n";
}
- // Instruct the linker that the client will take care of linkonce resolution
- unsigned Flags = Linker::Flags::None;
- if (!ForceImportReferencedDiscardableSymbols)
- Flags |= Linker::Flags::DontForceLinkLinkonceODR;
-
- if (TheLinker.linkInModule(std::move(SrcModule), Flags, &GlobalsToImport))
+ if (Mover.move(std::move(SrcModule), GlobalsToImport.getArrayRef(),
+ [](GlobalValue &, IRMover::ValueAdder) {},
+ /*IsPerformingImport=*/true))
report_fatal_error("Function Import: link error");
ImportedCount += GlobalsToImport.size();
@@ -796,8 +843,7 @@ static bool doImportingForModule(Module &M) {
return loadFile(Identifier, M.getContext());
};
FunctionImporter Importer(*Index, ModuleLoader);
- Expected<bool> Result = Importer.importFunctions(
- M, ImportList, !DontForceImportReferencedDiscardableSymbols);
+ Expected<bool> Result = Importer.importFunctions(M, ImportList);
// FIXME: Probably need to propagate Errors through the pass manager.
if (!Result) {
diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp
index 7a04de3d12dbf..c91e8b454927f 100644
--- a/lib/Transforms/IPO/GlobalDCE.cpp
+++ b/lib/Transforms/IPO/GlobalDCE.cpp
@@ -25,7 +25,7 @@
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/CtorUtils.h"
#include "llvm/Transforms/Utils/GlobalStatus.h"
-#include <unordered_map>
+
using namespace llvm;
#define DEBUG_TYPE "globaldce"
@@ -50,7 +50,14 @@ namespace {
if (skipModule(M))
return false;
+ // We need a minimally functional dummy module analysis manager. It needs
+ // to at least know about the possibility of proxying a function analysis
+ // manager.
+ FunctionAnalysisManager DummyFAM;
ModuleAnalysisManager DummyMAM;
+ DummyMAM.registerPass(
+ [&] { return FunctionAnalysisManagerModuleProxy(DummyFAM); });
+
auto PA = Impl.run(M, DummyMAM);
return !PA.areAllPreserved();
}
@@ -78,9 +85,67 @@ static bool isEmptyFunction(Function *F) {
return RI.getReturnValue() == nullptr;
}
-PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {
+/// Compute the set of GlobalValues that depend on V.
+/// The recursion stops as soon as a GlobalValue is met.
+void GlobalDCEPass::ComputeDependencies(Value *V,
+ SmallPtrSetImpl<GlobalValue *> &Deps) {
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ Function *Parent = I->getParent()->getParent();
+ Deps.insert(Parent);
+ } else if (auto *GV = dyn_cast<GlobalValue>(V)) {
+ Deps.insert(GV);
+ } else if (auto *CE = dyn_cast<Constant>(V)) {
+ // Avoid walking the whole tree of a big ConstantExpr multiple times.
+ auto Where = ConstantDependenciesCache.find(CE);
+ if (Where != ConstantDependenciesCache.end()) {
+ auto const &K = Where->second;
+ Deps.insert(K.begin(), K.end());
+ } else {
+ SmallPtrSetImpl<GlobalValue *> &LocalDeps = ConstantDependenciesCache[CE];
+ for (User *CEUser : CE->users())
+ ComputeDependencies(CEUser, LocalDeps);
+ Deps.insert(LocalDeps.begin(), LocalDeps.end());
+ }
+ }
+}
+
+void GlobalDCEPass::UpdateGVDependencies(GlobalValue &GV) {
+ SmallPtrSet<GlobalValue *, 8> Deps;
+ for (User *User : GV.users())
+ ComputeDependencies(User, Deps);
+ Deps.erase(&GV); // Remove self-reference.
+ for (GlobalValue *GVU : Deps) {
+ GVDependencies.insert(std::make_pair(GVU, &GV));
+ }
+}
+
+/// Mark the given GlobalValue as live.
+void GlobalDCEPass::MarkLive(GlobalValue &GV,
+ SmallVectorImpl<GlobalValue *> *Updates) {
+ auto const Ret = AliveGlobals.insert(&GV);
+ if (!Ret.second)
+ return;
+
+ if (Updates)
+ Updates->push_back(&GV);
+ if (Comdat *C = GV.getComdat()) {
+ for (auto &&CM : make_range(ComdatMembers.equal_range(C)))
+ MarkLive(*CM.second, Updates); // Recursion depth is only two because only
+ // globals in the same comdat are visited.
+ }
+}
+
+PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
bool Changed = false;
+ // The algorithm first computes the set L of global variables that are
+ // trivially live. Then it walks the initialization of these variables to
+ // compute the globals used to initialize them, which effectively builds a
+ // directed graph where nodes are global variables, and an edge from A to B
+ // means B is used to initialize A. Finally, it propagates the liveness
+ // information through the graph starting from the nodes in L. Nodes not
+ // marked as alive are discarded.
+
// Remove empty functions from the global ctors list.
Changed |= optimizeGlobalCtorsList(M, isEmptyFunction);
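
To make the marking scheme described above concrete, here is a minimal standalone sketch of the worklist propagation (hypothetical types and container choices, for illustration only; the pass itself keys its dependency multimap by GlobalValue*):

    // Illustration only: propagate liveness from root nodes through a
    // user -> used dependency multimap, the same shape as GVDependencies.
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    using Node = int;
    using DepMap = std::unordered_multimap<Node, Node>; // user -> used

    std::unordered_set<Node> markLive(const std::vector<Node> &Roots,
                                      const DepMap &Deps) {
      std::unordered_set<Node> Alive(Roots.begin(), Roots.end());
      std::vector<Node> Worklist(Roots.begin(), Roots.end());
      while (!Worklist.empty()) {
        Node N = Worklist.back();
        Worklist.pop_back();
        // Everything N uses must stay alive as well.
        auto Range = Deps.equal_range(N);
        for (auto It = Range.first; It != Range.second; ++It)
          if (Alive.insert(It->second).second)
            Worklist.push_back(It->second);
      }
      return Alive; // Nodes not in this set can be discarded.
    }
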
@@ -103,21 +168,39 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {
// initializer.
if (!GO.isDeclaration() && !GO.hasAvailableExternallyLinkage())
if (!GO.isDiscardableIfUnused())
- GlobalIsNeeded(&GO);
+ MarkLive(GO);
+
+ UpdateGVDependencies(GO);
}
+ // Compute direct dependencies of aliases.
for (GlobalAlias &GA : M.aliases()) {
Changed |= RemoveUnusedGlobalValue(GA);
// Externally visible aliases are needed.
if (!GA.isDiscardableIfUnused())
- GlobalIsNeeded(&GA);
+ MarkLive(GA);
+
+ UpdateGVDependencies(GA);
}
+ // Compute direct dependencies of ifuncs.
for (GlobalIFunc &GIF : M.ifuncs()) {
Changed |= RemoveUnusedGlobalValue(GIF);
// Externally visible ifuncs are needed.
if (!GIF.isDiscardableIfUnused())
- GlobalIsNeeded(&GIF);
+ MarkLive(GIF);
+
+ UpdateGVDependencies(GIF);
+ }
+
+ // Propagate liveness from collected Global Values through the computed
+ // dependencies.
+ SmallVector<GlobalValue *, 8> NewLiveGVs{AliveGlobals.begin(),
+ AliveGlobals.end()};
+ while (!NewLiveGVs.empty()) {
+ GlobalValue *LGV = NewLiveGVs.pop_back_val();
+ for (auto &&GVD : make_range(GVDependencies.equal_range(LGV)))
+ MarkLive(*GVD.second, &NewLiveGVs);
}
// Now that all globals which are needed are in the AliveGlobals set, we loop
@@ -154,7 +237,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {
GA.setAliasee(nullptr);
}
- // The third pass drops targets of ifuncs which are dead...
+ // The fourth pass drops targets of ifuncs which are dead...
std::vector<GlobalIFunc*> DeadIFuncs;
for (GlobalIFunc &GIF : M.ifuncs())
if (!AliveGlobals.count(&GIF)) {
@@ -188,7 +271,8 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {
// Make sure that all memory is released
AliveGlobals.clear();
- SeenConstants.clear();
+ ConstantDependenciesCache.clear();
+ GVDependencies.clear();
ComdatMembers.clear();
if (Changed)
@@ -196,60 +280,6 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) {
return PreservedAnalyses::all();
}
-/// GlobalIsNeeded - the specific global value as needed, and
-/// recursively mark anything that it uses as also needed.
-void GlobalDCEPass::GlobalIsNeeded(GlobalValue *G) {
- // If the global is already in the set, no need to reprocess it.
- if (!AliveGlobals.insert(G).second)
- return;
-
- if (Comdat *C = G->getComdat()) {
- for (auto &&CM : make_range(ComdatMembers.equal_range(C)))
- GlobalIsNeeded(CM.second);
- }
-
- if (GlobalVariable *GV = dyn_cast<GlobalVariable>(G)) {
- // If this is a global variable, we must make sure to add any global values
- // referenced by the initializer to the alive set.
- if (GV->hasInitializer())
- MarkUsedGlobalsAsNeeded(GV->getInitializer());
- } else if (GlobalIndirectSymbol *GIS = dyn_cast<GlobalIndirectSymbol>(G)) {
- // The target of a global alias or ifunc is needed.
- MarkUsedGlobalsAsNeeded(GIS->getIndirectSymbol());
- } else {
- // Otherwise this must be a function object. We have to scan the body of
- // the function looking for constants and global values which are used as
- // operands. Any operands of these types must be processed to ensure that
- // any globals used will be marked as needed.
- Function *F = cast<Function>(G);
-
- for (Use &U : F->operands())
- MarkUsedGlobalsAsNeeded(cast<Constant>(U.get()));
-
- for (BasicBlock &BB : *F)
- for (Instruction &I : BB)
- for (Use &U : I.operands())
- if (GlobalValue *GV = dyn_cast<GlobalValue>(U))
- GlobalIsNeeded(GV);
- else if (Constant *C = dyn_cast<Constant>(U))
- MarkUsedGlobalsAsNeeded(C);
- }
-}
-
-void GlobalDCEPass::MarkUsedGlobalsAsNeeded(Constant *C) {
- if (GlobalValue *GV = dyn_cast<GlobalValue>(C))
- return GlobalIsNeeded(GV);
-
- // Loop over all of the operands of the constant, adding any globals they
- // use to the list of needed globals.
- for (Use &U : C->operands()) {
- // If we've already processed this constant there's no need to do it again.
- Constant *Op = dyn_cast<Constant>(U);
- if (Op && SeenConstants.insert(Op).second)
- MarkUsedGlobalsAsNeeded(Op);
- }
-}
-
// RemoveUnusedGlobalValue - Loop over all of the uses of the specified
// GlobalValue, looking for the constant pointer ref that may be pointing to it.
// If found, check to see if the constant pointer ref is safe to destroy, and if
diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp
index 5b0d5e3bc01e5..ade4f21ceb524 100644
--- a/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/lib/Transforms/IPO/GlobalOpt.cpp
@@ -1819,12 +1819,14 @@ static bool processInternalGlobal(
GS.AccessingFunction->doesNotRecurse() &&
isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV,
LookupDomTree)) {
+ const DataLayout &DL = GV->getParent()->getDataLayout();
+
DEBUG(dbgs() << "LOCALIZING GLOBAL: " << *GV << "\n");
Instruction &FirstI = const_cast<Instruction&>(*GS.AccessingFunction
->getEntryBlock().begin());
Type *ElemTy = GV->getValueType();
// FIXME: Pass Global's alignment when globals have alignment
- AllocaInst *Alloca = new AllocaInst(ElemTy, nullptr,
+ AllocaInst *Alloca = new AllocaInst(ElemTy, DL.getAllocaAddrSpace(), nullptr,
GV->getName(), &FirstI);
if (!isa<UndefValue>(GV->getInitializer()))
new StoreInst(GV->getInitializer(), Alloca, &FirstI);
@@ -1977,7 +1979,7 @@ static void ChangeCalleesToFastCall(Function *F) {
}
}
-static AttributeSet StripNest(LLVMContext &C, const AttributeSet &Attrs) {
+static AttributeList StripNest(LLVMContext &C, const AttributeList &Attrs) {
for (unsigned i = 0, e = Attrs.getNumSlots(); i != e; ++i) {
unsigned Index = Attrs.getSlotIndex(i);
if (!Attrs.getSlotAttributes(i).hasAttribute(Index, Attribute::Nest))
@@ -2387,7 +2389,7 @@ OptimizeGlobalAliases(Module &M,
}
static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) {
- LibFunc::Func F = LibFunc::cxa_atexit;
+ LibFunc F = LibFunc_cxa_atexit;
if (!TLI->has(F))
return nullptr;
@@ -2396,7 +2398,7 @@ static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) {
return nullptr;
// Make sure that the function has the correct prototype.
- if (!TLI->getLibFunc(*Fn, F) || F != LibFunc::cxa_atexit)
+ if (!TLI->getLibFunc(*Fn, F) || F != LibFunc_cxa_atexit)
return nullptr;
return Fn;
diff --git a/lib/Transforms/IPO/GlobalSplit.cpp b/lib/Transforms/IPO/GlobalSplit.cpp
index bbbd096e89c0f..4705ebe265ae1 100644
--- a/lib/Transforms/IPO/GlobalSplit.cpp
+++ b/lib/Transforms/IPO/GlobalSplit.cpp
@@ -85,7 +85,16 @@ bool splitGlobal(GlobalVariable &GV) {
uint64_t ByteOffset = cast<ConstantInt>(
cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
->getZExtValue();
- if (ByteOffset < SplitBegin || ByteOffset >= SplitEnd)
+ // Type metadata may be attached one byte after the end of the vtable, for
+ // classes without virtual methods in the Itanium ABI. As far as we know, it
+ // is never attached to the first byte of a vtable. Subtract one to get the
+ // right slice.
+ // This assumes that vtable groups are the only kind of global variable that
+ // !type metadata can be attached to, and that they are either Itanium ABI
+ // vtable groups or contain a single vtable (i.e. Microsoft ABI vtables).
+ uint64_t AttachedTo = (ByteOffset == 0) ? ByteOffset : ByteOffset - 1;
+ if (AttachedTo < SplitBegin || AttachedTo >= SplitEnd)
continue;
SplitGV->addMetadata(
LLVMContext::MD_type,
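
A minimal hedged sketch of the offset adjustment introduced above (standalone helper with illustrative names): an attachment at the one-past-the-end offset is credited to the last byte of the preceding vtable, so it lands in the correct split half.

    // Illustration only: decide whether a !type metadata attachment at
    // ByteOffset belongs to the [SplitBegin, SplitEnd) slice.
    #include <cstdint>

    bool metadataBelongsToSlice(uint64_t ByteOffset, uint64_t SplitBegin,
                                uint64_t SplitEnd) {
      // An attachment one byte past the end of a vtable is treated as if it
      // were attached to the vtable's last byte.
      uint64_t AttachedTo = (ByteOffset == 0) ? ByteOffset : ByteOffset - 1;
      return AttachedTo >= SplitBegin && AttachedTo < SplitEnd;
    }
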
diff --git a/lib/Transforms/IPO/IPConstantPropagation.cpp b/lib/Transforms/IPO/IPConstantPropagation.cpp
index 916135e33cd50..349807496dc2c 100644
--- a/lib/Transforms/IPO/IPConstantPropagation.cpp
+++ b/lib/Transforms/IPO/IPConstantPropagation.cpp
@@ -136,7 +136,13 @@ static bool PropagateConstantReturn(Function &F) {
// For more details, see GlobalValue::mayBeDerefined.
if (!F.isDefinitionExact())
return false;
-
+
+ // Don't touch naked functions. They may contain asm returning a
+ // value we don't see, so we may end up interprocedurally propagating
+ // the return value incorrectly.
+ if (F.hasFnAttribute(Attribute::Naked))
+ return false;
+
// Check to see if this function returns a constant.
SmallVector<Value *,4> RetVals;
StructType *STy = dyn_cast<StructType>(F.getReturnType());
diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp
index 1770445b413f6..50e7cc89a3b32 100644
--- a/lib/Transforms/IPO/InlineSimple.cpp
+++ b/lib/Transforms/IPO/InlineSimple.cpp
@@ -48,7 +48,7 @@ public:
}
explicit SimpleInliner(InlineParams Params)
- : LegacyInlinerBase(ID), Params(Params) {
+ : LegacyInlinerBase(ID), Params(std::move(Params)) {
initializeSimpleInlinerPass(*PassRegistry::getPassRegistry());
}
@@ -61,7 +61,8 @@ public:
[&](Function &F) -> AssumptionCache & {
return ACT->getAssumptionCache(F);
};
- return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, PSI);
+ return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache,
+ /*GetBFI=*/None, PSI);
}
bool runOnSCC(CallGraphSCC &SCC) override;
@@ -92,8 +93,12 @@ Pass *llvm::createFunctionInliningPass(int Threshold) {
}
Pass *llvm::createFunctionInliningPass(unsigned OptLevel,
- unsigned SizeOptLevel) {
- return new SimpleInliner(llvm::getInlineParams(OptLevel, SizeOptLevel));
+ unsigned SizeOptLevel,
+ bool DisableInlineHotCallSite) {
+ auto Param = llvm::getInlineParams(OptLevel, SizeOptLevel);
+ if (DisableInlineHotCallSite)
+ Param.HotCallSiteThreshold = 0;
+ return new SimpleInliner(Param);
}
Pass *llvm::createFunctionInliningPass(InlineParams &Params) {
diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp
index 3f4731c937d17..6c83c99ae3be5 100644
--- a/lib/Transforms/IPO/Inliner.cpp
+++ b/lib/Transforms/IPO/Inliner.cpp
@@ -19,6 +19,7 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/OptimizationDiagnosticInfo.h"
@@ -260,8 +261,8 @@ static bool InlineCallIfPossible(
/// Return true if inlining of CS can block the caller from being
/// inlined which is proved to be more beneficial. \p IC is the
/// estimated inline cost associated with callsite \p CS.
-/// \p TotalAltCost will be set to the estimated cost of inlining the caller
-/// if \p CS is suppressed for inlining.
+/// \p TotalSecondaryCost will be set to the estimated cost of inlining the
+/// caller if \p CS is suppressed for inlining.
static bool
shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC,
int &TotalSecondaryCost,
@@ -288,7 +289,7 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC,
// treating them as truly abstract units etc.
TotalSecondaryCost = 0;
// The candidate cost to be imposed upon the current function.
- int CandidateCost = IC.getCost() - (InlineConstants::CallPenalty + 1);
+ int CandidateCost = IC.getCost() - 1;
// This bool tracks what happens if we do NOT inline C into B.
bool callerWillBeRemoved = Caller->hasLocalLinkage();
// This bool tracks what happens if we DO inline C into B.
@@ -325,7 +326,7 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC,
// one is set very low by getInlineCost, in anticipation that Caller will
// be removed entirely. We did not account for this above unless there
// is only one caller of Caller.
- if (callerWillBeRemoved && !Caller->use_empty())
+ if (callerWillBeRemoved && !Caller->hasOneUse())
TotalSecondaryCost -= InlineConstants::LastCallToStaticBonus;
if (inliningPreventsSomeOuterInline && TotalSecondaryCost < IC.getCost())
@@ -342,6 +343,7 @@ static bool shouldInline(CallSite CS,
InlineCost IC = GetInlineCost(CS);
Instruction *Call = CS.getInstruction();
Function *Callee = CS.getCalledFunction();
+ Function *Caller = CS.getCaller();
if (IC.isAlways()) {
DEBUG(dbgs() << " Inlining: cost=always"
@@ -355,19 +357,20 @@ static bool shouldInline(CallSite CS,
if (IC.isNever()) {
DEBUG(dbgs() << " NOT Inlining: cost=never"
<< ", Call: " << *CS.getInstruction() << "\n");
- ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "NeverInline", Call)
- << NV("Callee", Callee)
- << " should never be inlined (cost=never)");
+ ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call)
+ << NV("Callee", Callee) << " not inlined into "
+ << NV("Caller", Caller)
+ << " because it should never be inlined (cost=never)");
return false;
}
- Function *Caller = CS.getCaller();
if (!IC) {
DEBUG(dbgs() << " NOT Inlining: cost=" << IC.getCost()
<< ", thres=" << (IC.getCostDelta() + IC.getCost())
<< ", Call: " << *CS.getInstruction() << "\n");
- ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call)
- << NV("Callee", Callee) << " too costly to inline (cost="
+ ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "TooCostly", Call)
+ << NV("Callee", Callee) << " not inlined into "
+ << NV("Caller", Caller) << " because too costly to inline (cost="
<< NV("Cost", IC.getCost()) << ", threshold="
<< NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")");
return false;
@@ -378,8 +381,8 @@ static bool shouldInline(CallSite CS,
DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction()
<< " Cost = " << IC.getCost()
<< ", outer Cost = " << TotalSecondaryCost << '\n');
- ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE,
- "IncreaseCostInOtherContexts", Call)
+ ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "IncreaseCostInOtherContexts",
+ Call)
<< "Not inlining. Cost of inlining " << NV("Callee", Callee)
<< " increases the cost of inlining " << NV("Caller", Caller)
<< " in other contexts");
@@ -552,16 +555,11 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
// If the policy determines that we should inline this function,
// try to do so.
- using namespace ore;
- if (!shouldInline(CS, GetInlineCost, ORE)) {
- ORE.emit(
- OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block)
- << NV("Callee", Callee) << " will not be inlined into "
- << NV("Caller", Caller));
+ if (!shouldInline(CS, GetInlineCost, ORE))
continue;
- }
// Attempt to inline the function.
+ using namespace ore;
if (!InlineCallIfPossible(CS, InlineInfo, InlinedArrayAllocas,
InlineHistoryID, InsertLifetime, AARGetter,
ImportedFunctionsStats)) {
@@ -638,22 +636,12 @@ bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) {
ACT = &getAnalysis<AssumptionCacheTracker>();
PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
- // We compute dedicated AA results for each function in the SCC as needed. We
- // use a lambda referencing external objects so that they live long enough to
- // be queried, but we re-use them each time.
- Optional<BasicAAResult> BAR;
- Optional<AAResults> AAR;
- auto AARGetter = [&](Function &F) -> AAResults & {
- BAR.emplace(createLegacyPMBasicAAResult(*this, F));
- AAR.emplace(createLegacyPMAAResults(*this, F, *BAR));
- return *AAR;
- };
auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
return ACT->getAssumptionCache(F);
};
return inlineCallsImpl(SCC, CG, GetAssumptionCache, PSI, TLI, InsertLifetime,
[this](CallSite CS) { return getInlineCost(CS); },
- AARGetter, ImportedFunctionsStats);
+ LegacyAARGetter(*this), ImportedFunctionsStats);
}
/// Remove now-dead linkonce functions at the end of
@@ -750,9 +738,6 @@ bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG,
PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
CGSCCAnalysisManager &AM, LazyCallGraph &CG,
CGSCCUpdateResult &UR) {
- FunctionAnalysisManager &FAM =
- AM.getResult<FunctionAnalysisManagerCGSCCProxy>(InitialC, CG)
- .getManager();
const ModuleAnalysisManager &MAM =
AM.getResult<ModuleAnalysisManagerCGSCCProxy>(InitialC, CG).getManager();
bool Changed = false;
@@ -761,35 +746,52 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
Module &M = *InitialC.begin()->getFunction().getParent();
ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M);
- std::function<AssumptionCache &(Function &)> GetAssumptionCache =
- [&](Function &F) -> AssumptionCache & {
- return FAM.getResult<AssumptionAnalysis>(F);
- };
-
- // Setup the data structure used to plumb customization into the
- // `InlineFunction` routine.
- InlineFunctionInfo IFI(/*cg=*/nullptr, &GetAssumptionCache);
+ // We use a single common worklist for calls across the entire SCC. We
+ // process these in-order and append new calls introduced during inlining to
+ // the end.
+ //
+ // Note that this particular order of processing is actually critical to
+ // avoid very bad behaviors. Consider *highly connected* call graphs where
+ // each function contains a small amount of code and a couple of calls to
+ // other functions. Because the LLVM inliner is fundamentally a bottom-up
+ // inliner, it can handle gracefully the fact that these all appear to be
+ // reasonable inlining candidates as it will flatten things until they become
+ // too big to inline, and then move on and flatten another batch.
+ //
+ // However, when processing call edges *within* an SCC we cannot rely on this
+ // bottom-up behavior. As a consequence, with heavily connected *SCCs* of
+ // functions we can end up incrementally inlining N calls into each of
+ // N functions because each incremental inlining decision looks good and we
+ // don't have a topological ordering to prevent explosions.
+ //
+ // To compensate for this, we don't process transitive edges made immediate
+ // by inlining until we've done one pass of inlining across the entire SCC.
+ // Large, highly connected SCCs still lead to some amount of code bloat in
+ // this model, but it is uniformly spread across all the functions in the SCC
+ // and eventually they all become too large to inline, rather than
+ // incrementally making a single function grow in a superlinear fashion.
+ SmallVector<std::pair<CallSite, int>, 16> Calls;
- auto GetInlineCost = [&](CallSite CS) {
- Function &Callee = *CS.getCalledFunction();
- auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee);
- return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, PSI);
- };
+ // Populate the initial list of calls in this SCC.
+ for (auto &N : InitialC) {
+ // We want to generally process call sites top-down in order for
+ // simplifications stemming from replacing the call with the returned value
+ // after inlining to be visible to subsequent inlining decisions.
+ // FIXME: Using the instruction sequence is a really bad way to do this.
+ // Instead we should do an actual RPO walk of the function body.
+ for (Instruction &I : instructions(N.getFunction()))
+ if (auto CS = CallSite(&I))
+ if (Function *Callee = CS.getCalledFunction())
+ if (!Callee->isDeclaration())
+ Calls.push_back({CS, -1});
+ }
+ if (Calls.empty())
+ return PreservedAnalyses::all();
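
A hedged sketch of the flat, caller-batched worklist described in the comment above (the Call type here is hypothetical; the pass itself stores a CallSite plus an inline-history index). Calls appended during inlining are only reached after the current pass over the SCC finishes:

    // Illustration only: iterate by index because the vector may grow, and
    // batch consecutive entries that share a caller so per-caller setup
    // happens once.
    #include <string>
    #include <utility>
    #include <vector>

    struct Call { std::string Caller; std::string Callee; };

    void processCalls(std::vector<std::pair<Call, int>> &Calls) {
      for (int i = 0; i < (int)Calls.size(); ++i) {
        // Copy the caller name; Calls may reallocate as we append to it.
        const std::string Caller = Calls[i].first.Caller;
        // ... per-caller setup (analyses, remark emitter, ...) goes here ...
        for (; i < (int)Calls.size() && Calls[i].first.Caller == Caller; ++i) {
          // Decide whether to inline Calls[i]; newly exposed call sites are
          // appended to Calls and picked up on a later iteration.
        }
        --i; // Compensate for the inner loop's final increment.
      }
    }
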
- // We use a worklist of nodes to process so that we can handle if the SCC
- // structure changes and some nodes are no longer part of the current SCC. We
- // also need to use an updatable pointer for the SCC as a consequence.
- SmallVector<LazyCallGraph::Node *, 16> Nodes;
- for (auto &N : InitialC)
- Nodes.push_back(&N);
+ // Capture updatable variables for the current SCC and RefSCC.
auto *C = &InitialC;
auto *RC = &C->getOuterRefSCC();
- // We also use a secondary worklist of call sites within a particular node to
- // allow quickly continuing to inline through newly inlined call sites where
- // possible.
- SmallVector<std::pair<CallSite, int>, 16> Calls;
-
// When inlining a callee produces new call sites, we want to keep track of
// the fact that they were inlined from the callee. This allows us to avoid
// infinite inlining in some obscure cases. To represent this, we use an
@@ -805,34 +807,58 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
// defer deleting these to make it easier to handle the call graph updates.
SmallVector<Function *, 4> DeadFunctions;
- do {
- auto &N = *Nodes.pop_back_val();
+ // Loop forward over all of the calls. Note that we cannot cache the size as
+ // inlining can introduce new calls that need to be processed.
+ for (int i = 0; i < (int)Calls.size(); ++i) {
+ // We expect the calls to typically be batched with sequences of calls that
+ // have the same caller, so we first set up some shared infrastructure for
+ // this caller. We also do any pruning we can at this layer on the caller
+ // alone.
+ Function &F = *Calls[i].first.getCaller();
+ LazyCallGraph::Node &N = *CG.lookup(F);
if (CG.lookupSCC(N) != C)
continue;
- Function &F = N.getFunction();
if (F.hasFnAttribute(Attribute::OptimizeNone))
continue;
+ DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n");
+
+ // Get a FunctionAnalysisManager via a proxy for this particular node. We
+ // do this each time we visit a node as the SCC may have changed and as
+ // we're going to mutate this particular function we want to make sure the
+ // proxy is in place to forward any invalidation events. We can use the
+ // manager we get here for looking up results for functions other than this
+ // node however because those functions aren't going to be mutated by this
+ // pass.
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(*C, CG)
+ .getManager();
+ std::function<AssumptionCache &(Function &)> GetAssumptionCache =
+ [&](Function &F) -> AssumptionCache & {
+ return FAM.getResult<AssumptionAnalysis>(F);
+ };
+ auto GetBFI = [&](Function &F) -> BlockFrequencyInfo & {
+ return FAM.getResult<BlockFrequencyAnalysis>(F);
+ };
+
+ auto GetInlineCost = [&](CallSite CS) {
+ Function &Callee = *CS.getCalledFunction();
+ auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee);
+ return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, {GetBFI},
+ PSI);
+ };
+
// Get the remarks emission analysis for the caller.
auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
- // We want to generally process call sites top-down in order for
- // simplifications stemming from replacing the call with the returned value
- // after inlining to be visible to subsequent inlining decisions. So we
- // walk the function backwards and then process the back of the vector.
- // FIXME: Using reverse is a really bad way to do this. Instead we should
- // do an actual PO walk of the function body.
- for (Instruction &I : reverse(instructions(F)))
- if (auto CS = CallSite(&I))
- if (Function *Callee = CS.getCalledFunction())
- if (!Callee->isDeclaration())
- Calls.push_back({CS, -1});
-
+ // Now process as many calls as we have within this caller in the sequence.
+ // We bail out as soon as the caller has to change so we can update the
+ // call graph and prepare the context of that new caller.
bool DidInline = false;
- while (!Calls.empty()) {
+ for (; i < (int)Calls.size() && Calls[i].first.getCaller() == &F; ++i) {
int InlineHistoryID;
CallSite CS;
- std::tie(CS, InlineHistoryID) = Calls.pop_back_val();
+ std::tie(CS, InlineHistoryID) = Calls[i];
Function &Callee = *CS.getCalledFunction();
if (InlineHistoryID != -1 &&
@@ -843,6 +869,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
if (!shouldInline(CS, GetInlineCost, ORE))
continue;
+ // Setup the data structure used to plumb customization into the
+ // `InlineFunction` routine.
+ InlineFunctionInfo IFI(
+ /*cg=*/nullptr, &GetAssumptionCache,
+ &FAM.getResult<BlockFrequencyAnalysis>(*(CS.getCaller())),
+ &FAM.getResult<BlockFrequencyAnalysis>(Callee));
+
if (!InlineFunction(CS, IFI))
continue;
DidInline = true;
@@ -870,6 +903,12 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
// made dead by this operation on other functions).
Callee.removeDeadConstantUsers();
if (Callee.use_empty()) {
+ Calls.erase(
+ std::remove_if(Calls.begin() + i + 1, Calls.end(),
+ [&Callee](const std::pair<CallSite, int> &Call) {
+ return Call.first.getCaller() == &Callee;
+ }),
+ Calls.end());
// Clear the body and queue the function itself for deletion when we
// finish inlining and call graph updates.
// Note that after this point, it is an error to do anything other
@@ -882,6 +921,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
}
}
+ // Back the call index up by one to put us in a good position to go around
+ // the outer loop.
+ --i;
+
if (!DidInline)
continue;
Changed = true;
@@ -896,8 +939,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
// below.
for (Function *InlinedCallee : InlinedCallees) {
LazyCallGraph::Node &CalleeN = *CG.lookup(*InlinedCallee);
- for (LazyCallGraph::Edge &E : CalleeN)
- RC->insertTrivialRefEdge(N, *E.getNode());
+ for (LazyCallGraph::Edge &E : *CalleeN)
+ RC->insertTrivialRefEdge(N, E.getNode());
}
InlinedCallees.clear();
@@ -908,8 +951,9 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
// re-use the exact same logic for updating the call graph to reflect the
// change..
C = &updateCGAndAnalysisManagerForFunctionPass(CG, *C, N, AM, UR);
+ DEBUG(dbgs() << "Updated inlining SCC: " << *C << "\n");
RC = &C->getOuterRefSCC();
- } while (!Nodes.empty());
+ }
// Now that we've finished inlining all of the calls across this SCC, delete
// all of the trivially dead functions, updating the call graph and the CGSCC
@@ -920,8 +964,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
// sets.
for (Function *DeadF : DeadFunctions) {
// Get the necessary information out of the call graph and nuke the
- // function there.
+ // function there. Also, clear out any cached analyses.
auto &DeadC = *CG.lookupSCC(*CG.lookup(*DeadF));
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(DeadC, CG)
+ .getManager();
+ FAM.clear(*DeadF);
+ AM.clear(DeadC);
auto &DeadRC = DeadC.getOuterRefSCC();
CG.removeDeadFunction(*DeadF);
diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp
index deb7e819480b4..785207efbe5c8 100644
--- a/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -42,8 +42,6 @@
using namespace llvm;
using namespace lowertypetests;
-using SummaryAction = LowerTypeTestsSummaryAction;
-
#define DEBUG_TYPE "lowertypetests"
STATISTIC(ByteArraySizeBits, "Byte array size in bits");
@@ -57,13 +55,13 @@ static cl::opt<bool> AvoidReuse(
cl::desc("Try to avoid reuse of byte array addresses using aliases"),
cl::Hidden, cl::init(true));
-static cl::opt<SummaryAction> ClSummaryAction(
+static cl::opt<PassSummaryAction> ClSummaryAction(
"lowertypetests-summary-action",
cl::desc("What to do with the summary when running this pass"),
- cl::values(clEnumValN(SummaryAction::None, "none", "Do nothing"),
- clEnumValN(SummaryAction::Import, "import",
+ cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
+ clEnumValN(PassSummaryAction::Import, "import",
"Import typeid resolutions from summary and globals"),
- clEnumValN(SummaryAction::Export, "export",
+ clEnumValN(PassSummaryAction::Export, "export",
"Export typeid resolutions to summary and globals")),
cl::Hidden);
@@ -234,8 +232,8 @@ public:
class LowerTypeTestsModule {
Module &M;
- SummaryAction Action;
- ModuleSummaryIndex *Summary;
+ ModuleSummaryIndex *ExportSummary;
+ const ModuleSummaryIndex *ImportSummary;
bool LinkerSubsectionsViaSymbols;
Triple::ArchType Arch;
@@ -253,15 +251,21 @@ class LowerTypeTestsModule {
// Indirect function call index assignment counter for WebAssembly
uint64_t IndirectIndex = 1;
- // Mapping from type identifiers to the call sites that test them.
- DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites;
+ // Mapping from type identifiers to the call sites that test them, as well as
+ // whether the type identifier needs to be exported to ThinLTO backends as
+ // part of the regular LTO phase of the ThinLTO pipeline (see exportTypeId).
+ struct TypeIdUserInfo {
+ std::vector<CallInst *> CallSites;
+ bool IsExported = false;
+ };
+ DenseMap<Metadata *, TypeIdUserInfo> TypeIdUsers;
/// This structure describes how to lower type tests for a particular type
/// identifier. It is either built directly from the global analysis (during
/// regular LTO or the regular LTO phase of ThinLTO), or indirectly using type
/// identifier summaries and external symbol references (in ThinLTO backends).
struct TypeIdLowering {
- TypeTestResolution::Kind TheKind;
+ TypeTestResolution::Kind TheKind = TypeTestResolution::Unsat;
/// All except Unsat: the start address within the combined global.
Constant *OffsetedGlobal;
@@ -274,9 +278,6 @@ class LowerTypeTestsModule {
/// covering members of this type identifier as a multiple of 2^AlignLog2.
Constant *SizeM1;
- /// ByteArray, Inline, AllOnes: range of SizeM1 expressed as a bit width.
- unsigned SizeM1BitWidth;
-
/// ByteArray: the byte array to test the address against.
Constant *TheByteArray;
@@ -291,6 +292,10 @@ class LowerTypeTestsModule {
Function *WeakInitializerFn = nullptr;
+ void exportTypeId(StringRef TypeId, const TypeIdLowering &TIL);
+ TypeIdLowering importTypeId(StringRef TypeId);
+ void importTypeTest(CallInst *CI);
+
BitSetInfo
buildBitSet(Metadata *TypeId,
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
@@ -327,8 +332,8 @@ class LowerTypeTestsModule {
void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions);
public:
- LowerTypeTestsModule(Module &M, SummaryAction Action,
- ModuleSummaryIndex *Summary);
+ LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary,
+ const ModuleSummaryIndex *ImportSummary);
bool lower();
// Lower the module using the action and summary passed as command line
@@ -341,15 +346,17 @@ struct LowerTypeTests : public ModulePass {
bool UseCommandLine = false;
- SummaryAction Action;
- ModuleSummaryIndex *Summary;
+ ModuleSummaryIndex *ExportSummary;
+ const ModuleSummaryIndex *ImportSummary;
LowerTypeTests() : ModulePass(ID), UseCommandLine(true) {
initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());
}
- LowerTypeTests(SummaryAction Action, ModuleSummaryIndex *Summary)
- : ModulePass(ID), Action(Action), Summary(Summary) {
+ LowerTypeTests(ModuleSummaryIndex *ExportSummary,
+ const ModuleSummaryIndex *ImportSummary)
+ : ModulePass(ID), ExportSummary(ExportSummary),
+ ImportSummary(ImportSummary) {
initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());
}
@@ -358,7 +365,7 @@ struct LowerTypeTests : public ModulePass {
return false;
if (UseCommandLine)
return LowerTypeTestsModule::runForTesting(M);
- return LowerTypeTestsModule(M, Action, Summary).lower();
+ return LowerTypeTestsModule(M, ExportSummary, ImportSummary).lower();
}
};
@@ -368,9 +375,10 @@ INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false,
false)
char LowerTypeTests::ID = 0;
-ModulePass *llvm::createLowerTypeTestsPass(SummaryAction Action,
- ModuleSummaryIndex *Summary) {
- return new LowerTypeTests(Action, Summary);
+ModulePass *
+llvm::createLowerTypeTestsPass(ModuleSummaryIndex *ExportSummary,
+ const ModuleSummaryIndex *ImportSummary) {
+ return new LowerTypeTests(ExportSummary, ImportSummary);
}
/// Build a bit set for TypeId using the object layouts in
@@ -494,10 +502,11 @@ Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B,
return createMaskedBitTest(B, TIL.InlineBits, BitOffset);
} else {
Constant *ByteArray = TIL.TheByteArray;
- if (!LinkerSubsectionsViaSymbols && AvoidReuse) {
+ if (!LinkerSubsectionsViaSymbols && AvoidReuse && !ImportSummary) {
// Each use of the byte array uses a different alias. This makes the
// backend less likely to reuse previously computed byte array addresses,
// improving the security of the CFI mechanism based on this pass.
+ // This won't work when importing because TheByteArray is external.
ByteArray = GlobalAlias::create(Int8Ty, 0, GlobalValue::PrivateLinkage,
"bits_use", ByteArray, &M);
}
@@ -593,8 +602,7 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
IntPtrTy));
Value *BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);
- Constant *BitSizeConst = ConstantExpr::getZExt(TIL.SizeM1, IntPtrTy);
- Value *OffsetInRange = B.CreateICmpULE(BitOffset, BitSizeConst);
+ Value *OffsetInRange = B.CreateICmpULE(BitOffset, TIL.SizeM1);
// If the bit set is all ones, testing against it is unnecessary.
if (TIL.TheKind == TypeTestResolution::AllOnes)
@@ -687,6 +695,123 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables(
}
}
+/// Export the given type identifier so that ThinLTO backends may import it.
+/// Type identifiers are exported by adding coarse-grained information about how
+/// to test the type identifier to the summary, and creating symbols in the
+/// object file (aliases and absolute symbols) containing fine-grained
+/// information about the type identifier.
+void LowerTypeTestsModule::exportTypeId(StringRef TypeId,
+ const TypeIdLowering &TIL) {
+ TypeTestResolution &TTRes =
+ ExportSummary->getOrInsertTypeIdSummary(TypeId).TTRes;
+ TTRes.TheKind = TIL.TheKind;
+
+ auto ExportGlobal = [&](StringRef Name, Constant *C) {
+ GlobalAlias *GA =
+ GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
+ "__typeid_" + TypeId + "_" + Name, C, &M);
+ GA->setVisibility(GlobalValue::HiddenVisibility);
+ };
+
+ if (TIL.TheKind != TypeTestResolution::Unsat)
+ ExportGlobal("global_addr", TIL.OffsetedGlobal);
+
+ if (TIL.TheKind == TypeTestResolution::ByteArray ||
+ TIL.TheKind == TypeTestResolution::Inline ||
+ TIL.TheKind == TypeTestResolution::AllOnes) {
+ ExportGlobal("align", ConstantExpr::getIntToPtr(TIL.AlignLog2, Int8PtrTy));
+ ExportGlobal("size_m1", ConstantExpr::getIntToPtr(TIL.SizeM1, Int8PtrTy));
+
+ uint64_t BitSize = cast<ConstantInt>(TIL.SizeM1)->getZExtValue() + 1;
+ if (TIL.TheKind == TypeTestResolution::Inline)
+ TTRes.SizeM1BitWidth = (BitSize <= 32) ? 5 : 6;
+ else
+ TTRes.SizeM1BitWidth = (BitSize <= 128) ? 7 : 32;
+ }
+
+ if (TIL.TheKind == TypeTestResolution::ByteArray) {
+ ExportGlobal("byte_array", TIL.TheByteArray);
+ ExportGlobal("bit_mask", TIL.BitMask);
+ }
+
+ if (TIL.TheKind == TypeTestResolution::Inline)
+ ExportGlobal("inline_bits",
+ ConstantExpr::getIntToPtr(TIL.InlineBits, Int8PtrTy));
+}
+
+LowerTypeTestsModule::TypeIdLowering
+LowerTypeTestsModule::importTypeId(StringRef TypeId) {
+ const TypeIdSummary *TidSummary = ImportSummary->getTypeIdSummary(TypeId);
+ if (!TidSummary)
+ return {}; // Unsat: no globals match this type id.
+ const TypeTestResolution &TTRes = TidSummary->TTRes;
+
+ TypeIdLowering TIL;
+ TIL.TheKind = TTRes.TheKind;
+
+ auto ImportGlobal = [&](StringRef Name, unsigned AbsWidth) {
+ Constant *C =
+ M.getOrInsertGlobal(("__typeid_" + TypeId + "_" + Name).str(), Int8Ty);
+ auto *GV = dyn_cast<GlobalVariable>(C);
+ // We only need to set metadata if the global is newly created, in which
+ // case it would not have hidden visibility.
+ if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility)
+ return C;
+
+ GV->setVisibility(GlobalValue::HiddenVisibility);
+ auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
+ auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
+ auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
+ GV->setMetadata(LLVMContext::MD_absolute_symbol,
+ MDNode::get(M.getContext(), {MinC, MaxC}));
+ };
+ if (AbsWidth == IntPtrTy->getBitWidth())
+ SetAbsRange(~0ull, ~0ull); // Full set.
+ else if (AbsWidth)
+ SetAbsRange(0, 1ull << AbsWidth);
+ return C;
+ };
+
+ if (TIL.TheKind != TypeTestResolution::Unsat)
+ TIL.OffsetedGlobal = ImportGlobal("global_addr", 0);
+
+ if (TIL.TheKind == TypeTestResolution::ByteArray ||
+ TIL.TheKind == TypeTestResolution::Inline ||
+ TIL.TheKind == TypeTestResolution::AllOnes) {
+ TIL.AlignLog2 = ConstantExpr::getPtrToInt(ImportGlobal("align", 8), Int8Ty);
+ TIL.SizeM1 = ConstantExpr::getPtrToInt(
+ ImportGlobal("size_m1", TTRes.SizeM1BitWidth), IntPtrTy);
+ }
+
+ if (TIL.TheKind == TypeTestResolution::ByteArray) {
+ TIL.TheByteArray = ImportGlobal("byte_array", 0);
+ TIL.BitMask = ImportGlobal("bit_mask", 8);
+ }
+
+ if (TIL.TheKind == TypeTestResolution::Inline)
+ TIL.InlineBits = ConstantExpr::getPtrToInt(
+ ImportGlobal("inline_bits", 1 << TTRes.SizeM1BitWidth),
+ TTRes.SizeM1BitWidth <= 5 ? Int32Ty : Int64Ty);
+
+ return TIL;
+}
+
+void LowerTypeTestsModule::importTypeTest(CallInst *CI) {
+ auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
+ if (!TypeIdMDVal)
+ report_fatal_error("Second argument of llvm.type.test must be metadata");
+
+ auto TypeIdStr = dyn_cast<MDString>(TypeIdMDVal->getMetadata());
+ if (!TypeIdStr)
+ report_fatal_error(
+ "Second argument of llvm.type.test must be a metadata string");
+
+ TypeIdLowering TIL = importTypeId(TypeIdStr->getString());
+ Value *Lowered = lowerTypeTestCall(TypeIdStr, CI, TIL);
+ CI->replaceAllUsesWith(Lowered);
+ CI->eraseFromParent();
+}
+
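
As a hedged illustration of the export/import contract just defined (the helper names below are hypothetical, but the symbol-naming scheme and absolute-symbol range encoding mirror ExportGlobal and ImportGlobal above):

    // Illustration only: the exporter publishes hidden symbols named
    // "__typeid_<TypeId>_<Field>"; the importer declares them and, when the
    // value is known to fit in AbsWidth bits, attaches an !absolute_symbol
    // range so the backend can fold comparisons.
    #include <cstdint>
    #include <string>
    #include <utility>

    std::string typeIdFieldSymbol(const std::string &TypeId,
                                  const std::string &Field) {
      return "__typeid_" + TypeId + "_" + Field; // e.g. "__typeid_T_size_m1"
    }

    // Returns the [Min, Max) range to attach; {~0ull, ~0ull} means "full
    // pointer width" and {0, 0} means "width unknown, attach nothing".
    std::pair<uint64_t, uint64_t> absSymbolRange(unsigned AbsWidth,
                                                 unsigned PtrBits) {
      if (AbsWidth == PtrBits)
        return {~0ull, ~0ull};
      if (AbsWidth == 0)
        return {0, 0};
      return {0, 1ull << AbsWidth};
    }
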
void LowerTypeTestsModule::lowerTypeTestCalls(
ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
@@ -708,16 +833,12 @@ void LowerTypeTestsModule::lowerTypeTestCalls(
TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr(
Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)),
TIL.AlignLog2 = ConstantInt::get(Int8Ty, BSI.AlignLog2);
+ TIL.SizeM1 = ConstantInt::get(IntPtrTy, BSI.BitSize - 1);
if (BSI.isAllOnes()) {
TIL.TheKind = (BSI.BitSize == 1) ? TypeTestResolution::Single
: TypeTestResolution::AllOnes;
- TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32;
- TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? Int8Ty : Int32Ty,
- BSI.BitSize - 1);
} else if (BSI.BitSize <= 64) {
TIL.TheKind = TypeTestResolution::Inline;
- TIL.SizeM1BitWidth = (BSI.BitSize <= 32) ? 5 : 6;
- TIL.SizeM1 = ConstantInt::get(Int8Ty, BSI.BitSize - 1);
uint64_t InlineBits = 0;
for (auto Bit : BSI.Bits)
InlineBits |= uint64_t(1) << Bit;
@@ -728,17 +849,19 @@ void LowerTypeTestsModule::lowerTypeTestCalls(
(BSI.BitSize <= 32) ? Int32Ty : Int64Ty, InlineBits);
} else {
TIL.TheKind = TypeTestResolution::ByteArray;
- TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32;
- TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? Int8Ty : Int32Ty,
- BSI.BitSize - 1);
++NumByteArraysCreated;
ByteArrayInfo *BAI = createByteArray(BSI);
TIL.TheByteArray = BAI->ByteArray;
TIL.BitMask = BAI->MaskGlobal;
}
+ TypeIdUserInfo &TIUI = TypeIdUsers[TypeId];
+
+ if (TIUI.IsExported)
+ exportTypeId(cast<MDString>(TypeId)->getString(), TIL);
+
// Lower each call to llvm.type.test for this type identifier.
- for (CallInst *CI : TypeTestCallSites[TypeId]) {
+ for (CallInst *CI : TIUI.CallSites) {
++NumTypeTestCallsLowered;
Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL);
CI->replaceAllUsesWith(Lowered);
@@ -757,9 +880,9 @@ void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) {
report_fatal_error(
"A member of a type identifier may not have an explicit section");
- if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker())
- report_fatal_error(
- "A global var member of a type identifier must be a definition");
+ // FIXME: We previously checked that global var member of a type identifier
+ // must be a definition, but the IR linker may leave type metadata on
+ // declarations. We should restore this check after fixing PR31759.
auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0));
if (!OffsetConstMD)
@@ -1173,13 +1296,11 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
}
/// Lower all type tests in this module.
-LowerTypeTestsModule::LowerTypeTestsModule(Module &M, SummaryAction Action,
- ModuleSummaryIndex *Summary)
- : M(M), Action(Action), Summary(Summary) {
- // FIXME: Use these fields.
- (void)this->Action;
- (void)this->Summary;
-
+LowerTypeTestsModule::LowerTypeTestsModule(
+ Module &M, ModuleSummaryIndex *ExportSummary,
+ const ModuleSummaryIndex *ImportSummary)
+ : M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary) {
+ assert(!(ExportSummary && ImportSummary));
Triple TargetTriple(M.getTargetTriple());
LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX();
Arch = TargetTriple.getArch();
@@ -1203,7 +1324,11 @@ bool LowerTypeTestsModule::runForTesting(Module &M) {
ExitOnErr(errorCodeToError(In.error()));
}
- bool Changed = LowerTypeTestsModule(M, ClSummaryAction, &Summary).lower();
+ bool Changed =
+ LowerTypeTestsModule(
+ M, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
+ ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
+ .lower();
if (!ClWriteSummary.empty()) {
ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary +
@@ -1222,9 +1347,18 @@ bool LowerTypeTestsModule::runForTesting(Module &M) {
bool LowerTypeTestsModule::lower() {
Function *TypeTestFunc =
M.getFunction(Intrinsic::getName(Intrinsic::type_test));
- if (!TypeTestFunc || TypeTestFunc->use_empty())
+ if ((!TypeTestFunc || TypeTestFunc->use_empty()) && !ExportSummary)
return false;
+ if (ImportSummary) {
+ for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end();
+ UI != UE;) {
+ auto *CI = cast<CallInst>((*UI++).getUser());
+ importTypeTest(CI);
+ }
+ return true;
+ }
+
// Equivalence class set containing type identifiers and the globals that
// reference them. This is used to partition the set of type identifiers in
// the module into disjoint sets.
@@ -1248,6 +1382,9 @@ bool LowerTypeTestsModule::lower() {
unsigned I = 0;
SmallVector<MDNode *, 2> Types;
for (GlobalObject &GO : M.global_objects()) {
+ if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker())
+ continue;
+
Types.clear();
GO.getMetadata(LLVMContext::MD_type, Types);
if (Types.empty())
@@ -1262,33 +1399,57 @@ bool LowerTypeTestsModule::lower() {
}
}
- for (const Use &U : TypeTestFunc->uses()) {
- auto CI = cast<CallInst>(U.getUser());
+ auto AddTypeIdUse = [&](Metadata *TypeId) -> TypeIdUserInfo & {
+ // Add the call site to the list of call sites for this type identifier. We
+ // also use TypeIdUsers to keep track of whether we have seen this type
+ // identifier before. If we have, we don't need to re-add the referenced
+ // globals to the equivalence class.
+ auto Ins = TypeIdUsers.insert({TypeId, {}});
+ if (Ins.second) {
+ // Add the type identifier to the equivalence class.
+ GlobalClassesTy::iterator GCI = GlobalClasses.insert(TypeId);
+ GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);
+
+ // Add the referenced globals to the type identifier's equivalence class.
+ for (GlobalTypeMember *GTM : TypeIdInfo[TypeId].RefGlobals)
+ CurSet = GlobalClasses.unionSets(
+ CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM)));
+ }
+
+ return Ins.first->second;
+ };
- auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
- if (!BitSetMDVal)
- report_fatal_error("Second argument of llvm.type.test must be metadata");
- auto BitSet = BitSetMDVal->getMetadata();
+ if (TypeTestFunc) {
+ for (const Use &U : TypeTestFunc->uses()) {
+ auto CI = cast<CallInst>(U.getUser());
- // Add the call site to the list of call sites for this type identifier. We
- // also use TypeTestCallSites to keep track of whether we have seen this
- // type identifier before. If we have, we don't need to re-add the
- // referenced globals to the equivalence class.
- std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool>
- Ins = TypeTestCallSites.insert(
- std::make_pair(BitSet, std::vector<CallInst *>()));
- Ins.first->second.push_back(CI);
- if (!Ins.second)
- continue;
+ auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
+ if (!TypeIdMDVal)
+ report_fatal_error("Second argument of llvm.type.test must be metadata");
+ auto TypeId = TypeIdMDVal->getMetadata();
+ AddTypeIdUse(TypeId).CallSites.push_back(CI);
+ }
+ }
- // Add the type identifier to the equivalence class.
- GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet);
- GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);
+ if (ExportSummary) {
+ DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
+ for (auto &P : TypeIdInfo) {
+ if (auto *TypeId = dyn_cast<MDString>(P.first))
+ MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
+ TypeId);
+ }
- // Add the referenced globals to the type identifier's equivalence class.
- for (GlobalTypeMember *GTM : TypeIdInfo[BitSet].RefGlobals)
- CurSet = GlobalClasses.unionSets(
- CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM)));
+ for (auto &P : *ExportSummary) {
+ for (auto &S : P.second) {
+ auto *FS = dyn_cast<FunctionSummary>(S.get());
+ if (!FS)
+ continue;
+ // FIXME: Only add live functions.
+ for (GlobalValue::GUID G : FS->type_tests())
+ for (Metadata *MD : MetadataByGUID[G])
+ AddTypeIdUse(MD).IsExported = true;
+ }
+ }
}
if (GlobalClasses.empty())
@@ -1349,8 +1510,9 @@ bool LowerTypeTestsModule::lower() {
PreservedAnalyses LowerTypeTestsPass::run(Module &M,
ModuleAnalysisManager &AM) {
- bool Changed =
- LowerTypeTestsModule(M, SummaryAction::None, /*Summary=*/nullptr).lower();
+ bool Changed = LowerTypeTestsModule(M, /*ExportSummary=*/nullptr,
+ /*ImportSummary=*/nullptr)
+ .lower();
if (!Changed)
return PreservedAnalyses::all();
return PreservedAnalyses::none();
diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp
index e0bb0eb42b590..771770ddc0600 100644
--- a/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/lib/Transforms/IPO/MergeFunctions.cpp
@@ -96,8 +96,10 @@
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ValueHandle.h"
@@ -127,6 +129,26 @@ static cl::opt<unsigned> NumFunctionsForSanityCheck(
"'0' disables this check. Works only with '-debug' key."),
cl::init(0), cl::Hidden);
+// Under option -mergefunc-preserve-debug-info we:
+// - Do not create a new function for a thunk.
+// - Retain the debug info for a thunk's parameters (and associated
+// instructions for the debug info) from the entry block.
+// Note: -debug will display the algorithm at work.
+// - Create debug-info for the call (to the shared implementation) made by
+// a thunk and its return value.
+// - Erase the rest of the function, retaining the (minimally sized) entry
+// block to create a thunk.
+// - Preserve a thunk's call site to point to the thunk even when both occur
+// within the same translation unit, to aid debuggability. Note that this
+// behavior differs from the underlying -mergefunc implementation, which
+// modifies the thunk's call site to point to the shared implementation
+// when both occur within the same translation unit.
+static cl::opt<bool>
+ MergeFunctionsPDI("mergefunc-preserve-debug-info", cl::Hidden,
+ cl::init(false),
+ cl::desc("Preserve debug info in thunk when mergefunc "
+ "transformations are made."));
+
namespace {
class FunctionNode {
@@ -215,8 +237,21 @@ private:
/// Replace G with a thunk or an alias to F. Deletes G.
void writeThunkOrAlias(Function *F, Function *G);
- /// Replace G with a simple tail call to bitcast(F). Also replace direct uses
- /// of G with bitcast(F). Deletes G.
+ /// Fill PDIUnrelatedWL with instructions from the entry block that are
+ /// unrelated to parameter related debug info.
+ void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock,
+ std::vector<Instruction *> &PDIUnrelatedWL);
+
+ /// Erase the rest of the CFG (i.e. barring the entry block).
+ void eraseTail(Function *G);
+
+ /// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
+ /// parameter debug info, from the entry block.
+ void eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL);
+
+ /// Replace G with a simple tail call to bitcast(F). Also (unless
+ /// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F),
+ /// and delete G.
void writeThunk(Function *F, Function *G);
/// Replace G with an alias to F. Deletes G.
@@ -269,8 +304,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
if (Res1 != -Res2) {
dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber
<< "\n";
- F1->dump();
- F2->dump();
+ dbgs() << *F1 << '\n' << *F2 << '\n';
Valid = false;
}
@@ -305,9 +339,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
<< TripleNumber << "\n";
dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", "
<< Res4 << "\n";
- F1->dump();
- F2->dump();
- F3->dump();
+ dbgs() << *F1 << '\n' << *F2 << '\n' << *F3 << '\n';
Valid = false;
}
}
@@ -400,19 +432,15 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
// Transferring other attributes may help other optimizations, but that
// should be done uniformly and not in this ad-hoc way.
auto &Context = New->getContext();
- auto NewFuncAttrs = New->getAttributes();
- auto CallSiteAttrs = CS.getAttributes();
-
- CallSiteAttrs = CallSiteAttrs.addAttributes(
- Context, AttributeSet::ReturnIndex, NewFuncAttrs.getRetAttributes());
-
- for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++) {
- AttributeSet Attrs = NewFuncAttrs.getParamAttributes(argIdx);
- if (Attrs.getNumSlots())
- CallSiteAttrs = CallSiteAttrs.addAttributes(Context, argIdx, Attrs);
- }
-
- CS.setAttributes(CallSiteAttrs);
+ auto NewPAL = New->getAttributes();
+ SmallVector<AttributeSet, 4> NewArgAttrs;
+ for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++)
+ NewArgAttrs.push_back(NewPAL.getParamAttributes(argIdx));
+ // Don't transfer attributes from the function to the callee. Function
+ // attributes typically aren't relevant to the calling convention or ABI.
+ CS.setAttributes(AttributeList::get(Context, /*FnAttrs=*/AttributeSet(),
+ NewPAL.getRetAttributes(),
+ NewArgAttrs));
remove(CS.getInstruction()->getParent()->getParent());
U->set(BitcastNew);
@@ -461,51 +489,242 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
return Builder.CreateBitCast(V, DestTy);
}
-// Replace G with a simple tail call to bitcast(F). Also replace direct uses
-// of G with bitcast(F). Deletes G.
+// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
+// parameter debug info, from the entry block.
+void MergeFunctions::eraseInstsUnrelatedToPDI(
+ std::vector<Instruction *> &PDIUnrelatedWL) {
+
+ DEBUG(dbgs() << " Erasing instructions (in reverse order of appearance in "
+ "entry block) unrelated to parameter debug info from entry "
+ "block: {\n");
+ while (!PDIUnrelatedWL.empty()) {
+ Instruction *I = PDIUnrelatedWL.back();
+ DEBUG(dbgs() << " Deleting Instruction: ");
+ DEBUG(I->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ I->eraseFromParent();
+ PDIUnrelatedWL.pop_back();
+ }
+ DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter "
+ "debug info from entry block. \n");
+}
+
+// Reduce G to its entry block.
+void MergeFunctions::eraseTail(Function *G) {
+
+ std::vector<BasicBlock *> WorklistBB;
+ for (Function::iterator BBI = std::next(G->begin()), BBE = G->end();
+ BBI != BBE; ++BBI) {
+ BBI->dropAllReferences();
+ WorklistBB.push_back(&*BBI);
+ }
+ while (!WorklistBB.empty()) {
+ BasicBlock *BB = WorklistBB.back();
+ BB->eraseFromParent();
+ WorklistBB.pop_back();
+ }
+}
+
+// We are interested in the following instructions from the entry block as being
+// related to parameter debug info:
+// - @llvm.dbg.declare
+// - stores from the incoming parameters to locations on the stack-frame
+// - allocas that create these locations on the stack-frame
+// - @llvm.dbg.value
+// - the entry block's terminator
+// The rest are unrelated to debug info for the parameters; fill up
+// PDIUnrelatedWL with such instructions.
+void MergeFunctions::filterInstsUnrelatedToPDI(
+ BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) {
+
+ std::set<Instruction *> PDIRelated;
+ for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end();
+ BI != BIE; ++BI) {
+ if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) {
+ DEBUG(dbgs() << " Deciding: ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ DILocalVariable *DILocVar = DVI->getVariable();
+ if (DILocVar->isParameter()) {
+ DEBUG(dbgs() << " Include (parameter): ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ PDIRelated.insert(&*BI);
+ } else {
+ DEBUG(dbgs() << " Delete (!parameter): ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ }
+ } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) {
+ DEBUG(dbgs() << " Deciding: ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ DILocalVariable *DILocVar = DDI->getVariable();
+ if (DILocVar->isParameter()) {
+ DEBUG(dbgs() << " Parameter: ");
+ DEBUG(DILocVar->print(dbgs()));
+ AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress());
+ if (AI) {
+ DEBUG(dbgs() << " Processing alloca users: ");
+ DEBUG(dbgs() << "\n");
+ for (User *U : AI->users()) {
+ if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
+ if (Value *Arg = SI->getValueOperand()) {
+ if (dyn_cast<Argument>(Arg)) {
+ DEBUG(dbgs() << " Include: ");
+ DEBUG(AI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ PDIRelated.insert(AI);
+ DEBUG(dbgs() << " Include (parameter): ");
+ DEBUG(SI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ PDIRelated.insert(SI);
+ DEBUG(dbgs() << " Include: ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ PDIRelated.insert(&*BI);
+ } else {
+ DEBUG(dbgs() << " Delete (!parameter): ");
+ DEBUG(SI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ }
+ }
+ } else {
+ DEBUG(dbgs() << " Defer: ");
+ DEBUG(U->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ }
+ }
+ } else {
+ DEBUG(dbgs() << " Delete (alloca NULL): ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ }
+ } else {
+ DEBUG(dbgs() << " Delete (!parameter): ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ }
+ } else if (dyn_cast<TerminatorInst>(BI) == GEntryBlock->getTerminator()) {
+ DEBUG(dbgs() << " Will Include Terminator: ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ PDIRelated.insert(&*BI);
+ } else {
+ DEBUG(dbgs() << " Defer: ");
+ DEBUG(BI->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ }
+ }
+ DEBUG(dbgs()
+ << " Report parameter debug info related/related instructions: {\n");
+ for (BasicBlock::iterator BI = GEntryBlock->begin(), BE = GEntryBlock->end();
+ BI != BE; ++BI) {
+
+ Instruction *I = &*BI;
+ if (PDIRelated.find(I) == PDIRelated.end()) {
+ DEBUG(dbgs() << " !PDIRelated: ");
+ DEBUG(I->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ PDIUnrelatedWL.push_back(I);
+ } else {
+ DEBUG(dbgs() << " PDIRelated: ");
+ DEBUG(I->print(dbgs()));
+ DEBUG(dbgs() << "\n");
+ }
+ }
+ DEBUG(dbgs() << " }\n");
+}
+
+// Replace G with a simple tail call to bitcast(F). Also (unless
+// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F) and
+// delete G. Under MergeFunctionsPDI, we use G itself for creating the
+// thunk, as we preserve the debug info (and associated instructions) from
+// G's entry block pertaining to G's incoming arguments, which are passed
+// on as corresponding arguments in the call that G makes to F.
+// For better debuggability, under MergeFunctionsPDI, we do not modify G's
+// call sites to point to F, even when they are within the same translation unit.
void MergeFunctions::writeThunk(Function *F, Function *G) {
- if (!G->isInterposable()) {
- // Redirect direct callers of G to F.
+ if (!G->isInterposable() && !MergeFunctionsPDI) {
+ // Redirect direct callers of G to F. (See note on MergeFunctionsPDI
+ // above).
replaceDirectCallers(G, F);
}
// If G was internal then we may have replaced all uses of G with F. If so,
- // stop here and delete G. There's no need for a thunk.
- if (G->hasLocalLinkage() && G->use_empty()) {
+ // stop here and delete G. There's no need for a thunk. (See note on
+ // MergeFunctionsPDI above).
+ if (G->hasLocalLinkage() && G->use_empty() && !MergeFunctionsPDI) {
G->eraseFromParent();
return;
}
- Function *NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "",
- G->getParent());
- BasicBlock *BB = BasicBlock::Create(F->getContext(), "", NewG);
- IRBuilder<> Builder(BB);
+ BasicBlock *GEntryBlock = nullptr;
+ std::vector<Instruction *> PDIUnrelatedWL;
+ BasicBlock *BB = nullptr;
+ Function *NewG = nullptr;
+ if (MergeFunctionsPDI) {
+ DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new "
+ "function as thunk; retain original: "
+ << G->getName() << "()\n");
+ GEntryBlock = &G->getEntryBlock();
+ DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related "
+ "debug info for "
+ << G->getName() << "() {\n");
+ filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL);
+ GEntryBlock->getTerminator()->eraseFromParent();
+ BB = GEntryBlock;
+ } else {
+ NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "",
+ G->getParent());
+ BB = BasicBlock::Create(F->getContext(), "", NewG);
+ }
+ IRBuilder<> Builder(BB);
+ Function *H = MergeFunctionsPDI ? G : NewG;
SmallVector<Value *, 16> Args;
unsigned i = 0;
FunctionType *FFTy = F->getFunctionType();
- for (Argument & AI : NewG->args()) {
+ for (Argument & AI : H->args()) {
Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
++i;
}
CallInst *CI = Builder.CreateCall(F, Args);
+ ReturnInst *RI = nullptr;
CI->setTailCall();
CI->setCallingConv(F->getCallingConv());
CI->setAttributes(F->getAttributes());
- if (NewG->getReturnType()->isVoidTy()) {
- Builder.CreateRetVoid();
+ if (H->getReturnType()->isVoidTy()) {
+ RI = Builder.CreateRetVoid();
} else {
- Builder.CreateRet(createCast(Builder, CI, NewG->getReturnType()));
+ RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
}
- NewG->copyAttributesFrom(G);
- NewG->takeName(G);
- removeUsers(G);
- G->replaceAllUsesWith(NewG);
- G->eraseFromParent();
+ if (MergeFunctionsPDI) {
+ DISubprogram *DIS = G->getSubprogram();
+ if (DIS) {
+ DebugLoc CIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS);
+ DebugLoc RIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS);
+ CI->setDebugLoc(CIDbgLoc);
+ RI->setDebugLoc(RIDbgLoc);
+ } else {
+ DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for "
+ << G->getName() << "()\n");
+ }
+ eraseTail(G);
+ eraseInstsUnrelatedToPDI(PDIUnrelatedWL);
+ DEBUG(dbgs() << "} // End of parameter related debug info filtering for: "
+ << G->getName() << "()\n");
+ } else {
+ NewG->copyAttributesFrom(G);
+ NewG->takeName(G);
+ removeUsers(G);
+ G->replaceAllUsesWith(NewG);
+ G->eraseFromParent();
+ }
- DEBUG(dbgs() << "writeThunk: " << NewG->getName() << '\n');
+ DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n');
++NumThunksWritten;
}
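// A minimal source-level sketch of the thunk shape writeThunk produces once
// MergeFunctions has proven two functions equivalent: the merged function's
// body is reduced to a forwarding (tail) call to the canonical definition,
// and under MergeFunctionsPDI the parameter debug info in its entry block is
// kept alive. The names and types below are illustrative only, not taken
// from this patch.
static int F(int A, int B) { return A + B; }   // canonical definition, kept
static int G(int A, int B) { return F(A, B); } // merged: now just forwards to F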
diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp
index 7ef3fc1fc2a7f..a2f6e5639d9d4 100644
--- a/lib/Transforms/IPO/PartialInlining.cpp
+++ b/lib/Transforms/IPO/PartialInlining.cpp
@@ -33,7 +33,7 @@ STATISTIC(NumPartialInlined, "Number of functions partially inlined");
namespace {
struct PartialInlinerImpl {
- PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {}
+ PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(std::move(IFI)) {}
bool run(Module &M);
Function *unswitchFunction(Function *F);
diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp
index 941efb210d1c8..f11b58d1adc48 100644
--- a/lib/Transforms/IPO/PassManagerBuilder.cpp
+++ b/lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -93,10 +93,6 @@ static cl::opt<CFLAAType>
clEnumValN(CFLAAType::Both, "both",
"Enable both variants of CFL-AA")));
-static cl::opt<bool>
-EnableMLSM("mlsm", cl::init(true), cl::Hidden,
- cl::desc("Enable motion of merged load and store"));
-
static cl::opt<bool> EnableLoopInterchange(
"enable-loopinterchange", cl::init(false), cl::Hidden,
cl::desc("Enable the new, experimental LoopInterchange Pass"));
@@ -141,8 +137,8 @@ static cl::opt<int> PreInlineThreshold(
"(default = 75)"));
static cl::opt<bool> EnableGVNHoist(
- "enable-gvn-hoist", cl::init(false), cl::Hidden,
- cl::desc("Enable the GVN hoisting pass"));
+ "enable-gvn-hoist", cl::init(true), cl::Hidden,
+ cl::desc("Enable the GVN hoisting pass (default = on)"));
static cl::opt<bool>
DisableLibCallsShrinkWrap("disable-libcalls-shrinkwrap", cl::init(false),
@@ -172,6 +168,7 @@ PassManagerBuilder::PassManagerBuilder() {
PGOInstrUse = RunPGOInstrUse;
PrepareForThinLTO = EnablePrepareForThinLTO;
PerformThinLTO = false;
+ DivergentTarget = false;
}
PassManagerBuilder::~PassManagerBuilder() {
@@ -248,8 +245,6 @@ void PassManagerBuilder::populateFunctionPassManager(
FPM.add(createCFGSimplificationPass());
FPM.add(createSROAPass());
FPM.add(createEarlyCSEPass());
- if(EnableGVNHoist)
- FPM.add(createGVNHoistPass());
FPM.add(createLowerExpectIntrinsicPass());
}
@@ -294,6 +289,8 @@ void PassManagerBuilder::addFunctionSimplificationPasses(
// Break up aggregate allocas, using SSAUpdater.
MPM.add(createSROAPass());
MPM.add(createEarlyCSEPass()); // Catch trivial redundancies
+ if (EnableGVNHoist)
+ MPM.add(createGVNHoistPass());
// Speculative execution if the target has divergent branches; otherwise nop.
MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass());
MPM.add(createJumpThreadingPass()); // Thread jumps.
@@ -305,29 +302,34 @@ void PassManagerBuilder::addFunctionSimplificationPasses(
MPM.add(createLibCallsShrinkWrapPass());
addExtensionsToPM(EP_Peephole, MPM);
+ // Optimize memory intrinsic calls based on the profiled size information.
+ if (SizeLevel == 0)
+ MPM.add(createPGOMemOPSizeOptLegacyPass());
+
MPM.add(createTailCallEliminationPass()); // Eliminate tail calls
MPM.add(createCFGSimplificationPass()); // Merge & remove BBs
MPM.add(createReassociatePass()); // Reassociate expressions
// Rotate Loop - disable header duplication at -Oz
MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1));
MPM.add(createLICMPass()); // Hoist loop invariants
- MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3));
+ MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget));
MPM.add(createCFGSimplificationPass());
addInstructionCombiningPass(MPM);
MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars
MPM.add(createLoopIdiomPass()); // Recognize idioms like memset.
+ addExtensionsToPM(EP_LateLoopOptimizations, MPM);
MPM.add(createLoopDeletionPass()); // Delete dead loops
+
if (EnableLoopInterchange) {
MPM.add(createLoopInterchangePass()); // Interchange loops
MPM.add(createCFGSimplificationPass());
}
if (!DisableUnrollLoops)
- MPM.add(createSimpleLoopUnrollPass()); // Unroll small loops
+ MPM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops
addExtensionsToPM(EP_LoopOptimizerEnd, MPM);
if (OptLevel > 1) {
- if (EnableMLSM)
- MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds
+ MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds
MPM.add(NewGVN ? createNewGVNPass()
: createGVNPass(DisableGVNLoadPRE)); // Remove redundancies
}
@@ -369,7 +371,7 @@ void PassManagerBuilder::addFunctionSimplificationPasses(
// BBVectorize may have significantly shortened a loop body; unroll again.
if (!DisableUnrollLoops)
- MPM.add(createLoopUnrollPass());
+ MPM.add(createLoopUnrollPass(OptLevel));
}
}
@@ -434,7 +436,16 @@ void PassManagerBuilder::populateModulePassManager(
// earlier in the pass pipeline, here before globalopt. Otherwise imported
// available_externally functions look unreferenced and are removed.
if (PerformThinLTO)
- MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true));
+ MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true,
+ !PGOSampleUse.empty()));
+
+ // For SamplePGO in the ThinLTO compile phase, we do not want to unroll
+ // loops, as doing so changes the CFG too much and makes the second profile
+ // annotation in the backend more difficult.
+ bool PrepareForThinLTOUsingPGOSampleProfile =
+ PrepareForThinLTO && !PGOSampleUse.empty();
+ if (PrepareForThinLTOUsingPGOSampleProfile)
+ DisableUnrollLoops = true;
if (!DisableUnitAtATime) {
// Infer attributes about declarations if possible.
@@ -454,14 +465,18 @@ void PassManagerBuilder::populateModulePassManager(
MPM.add(createCFGSimplificationPass()); // Clean up after IPCP & DAE
}
- if (!PerformThinLTO) {
+ // For SamplePGO in the ThinLTO compile phase, we do not want to do indirect
+ // call promotion, as it changes the CFG too much and makes the second
+ // profile annotation in the backend more difficult.
+ if (!PerformThinLTO && !PrepareForThinLTOUsingPGOSampleProfile) {
/// PGO instrumentation is added during the compile phase for ThinLTO, do
/// not run it a second time
addPGOInstrPasses(MPM);
// Indirect call promotion that promotes intra-module targets only.
// For ThinLTO this is done earlier due to interactions with globalopt
// for imported functions.
- MPM.add(createPGOIndirectCallPromotionLegacyPass());
+ MPM.add(
+ createPGOIndirectCallPromotionLegacyPass(false, !PGOSampleUse.empty()));
}
if (EnableNonLTOGlobalsModRef)
@@ -589,7 +604,7 @@ void PassManagerBuilder::populateModulePassManager(
MPM.add(createCorrelatedValuePropagationPass());
addInstructionCombiningPass(MPM);
MPM.add(createLICMPass());
- MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3));
+ MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget));
MPM.add(createCFGSimplificationPass());
addInstructionCombiningPass(MPM);
}
@@ -615,16 +630,16 @@ void PassManagerBuilder::populateModulePassManager(
// BBVectorize may have significantly shortened a loop body; unroll again.
if (!DisableUnrollLoops)
- MPM.add(createLoopUnrollPass());
+ MPM.add(createLoopUnrollPass(OptLevel));
}
}
addExtensionsToPM(EP_Peephole, MPM);
- MPM.add(createCFGSimplificationPass());
+ MPM.add(createLateCFGSimplificationPass()); // Switches to lookup tables
addInstructionCombiningPass(MPM);
if (!DisableUnrollLoops) {
- MPM.add(createLoopUnrollPass()); // Unroll small loops
+ MPM.add(createLoopUnrollPass(OptLevel)); // Unroll small loops
// LoopUnroll may generate some redundancy to clean up.
addInstructionCombiningPass(MPM);
@@ -684,7 +699,8 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {
// left by the earlier promotion pass that promotes intra-module targets.
// This two-step promotion is to save the compile time. For LTO, it should
// produce the same result as if we only do promotion here.
- PM.add(createPGOIndirectCallPromotionLegacyPass(true));
+ PM.add(
+ createPGOIndirectCallPromotionLegacyPass(true, !PGOSampleUse.empty()));
// Propagate constants at call sites into the functions they call. This
// opens opportunities for globalopt (and inlining) by substituting function
@@ -703,7 +719,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {
PM.add(createGlobalSplitPass());
// Apply whole-program devirtualization and virtual constant propagation.
- PM.add(createWholeProgramDevirtPass());
+ PM.add(createWholeProgramDevirtPass(ExportSummary, nullptr));
// That's all we need at opt level 1.
if (OptLevel == 1)
@@ -759,8 +775,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {
PM.add(createGlobalsAAWrapperPass()); // IP alias analysis.
PM.add(createLICMPass()); // Hoist loop invariants.
- if (EnableMLSM)
- PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds.
+ PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds.
PM.add(NewGVN ? createNewGVNPass()
: createGVNPass(DisableGVNLoadPRE)); // Remove redundancies.
PM.add(createMemCpyOptPass()); // Remove dead memcpys.
@@ -775,11 +790,11 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) {
PM.add(createLoopInterchangePass());
if (!DisableUnrollLoops)
- PM.add(createSimpleLoopUnrollPass()); // Unroll small loops
+ PM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops
PM.add(createLoopVectorizePass(true, LoopVectorize));
// The vectorizer may have significantly shortened a loop body; unroll again.
if (!DisableUnrollLoops)
- PM.add(createLoopUnrollPass());
+ PM.add(createLoopUnrollPass(OptLevel));
// Now that we've optimized loops (in particular loop induction variables),
// we may have exposed more scalar opportunities. Run parts of the scalar
@@ -833,6 +848,23 @@ void PassManagerBuilder::populateThinLTOPassManager(
if (VerifyInput)
PM.add(createVerifierPass());
+ if (ImportSummary) {
+ // These passes import type identifier resolutions for whole-program
+ // devirtualization and CFI. They must run early because other passes may
+ // disturb the specific instruction patterns that these passes look for,
+ // creating dependencies on resolutions that may not appear in the summary.
+ //
+ // For example, GVN may transform the pattern assume(type.test) appearing in
+ // two basic blocks into assume(phi(type.test, type.test)), which would
+ // transform a dependency on a WPD resolution into a dependency on a type
+ // identifier resolution for CFI.
+ //
+ // Also, WPD has access to more precise information than ICP and can
+ // devirtualize more effectively, so it should operate on the IR first.
+ PM.add(createWholeProgramDevirtPass(nullptr, ImportSummary));
+ PM.add(createLowerTypeTestsPass(nullptr, ImportSummary));
+ }
+
populateModulePassManager(PM);
if (VerifyOutput)
@@ -857,8 +889,7 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) {
// Lower type metadata and the type.test intrinsic. This pass supports Clang's
// control flow integrity mechanisms (-fsanitize=cfi*) and needs to run at
// link time if CFI is enabled. The pass does nothing if CFI is disabled.
- PM.add(createLowerTypeTestsPass(LowerTypeTestsSummaryAction::None,
- /*Summary=*/nullptr));
+ PM.add(createLowerTypeTestsPass(ExportSummary, nullptr));
if (OptLevel != 0)
addLateLTOOptimizationPasses(PM);
diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp
index 6a43f8dbac48a..3371de6e3d147 100644
--- a/lib/Transforms/IPO/SampleProfile.cpp
+++ b/lib/Transforms/IPO/SampleProfile.cpp
@@ -35,6 +35,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
@@ -43,6 +44,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
+#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/SampleProfReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -50,6 +52,7 @@
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <cctype>
@@ -159,8 +162,11 @@ protected:
ErrorOr<uint64_t> getInstWeight(const Instruction &I);
ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB);
const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const;
+ std::vector<const FunctionSamples *>
+ findIndirectCallFunctionSamples(const Instruction &I) const;
const FunctionSamples *findFunctionSamples(const Instruction &I) const;
- bool inlineHotFunctions(Function &F);
+ bool inlineHotFunctions(Function &F,
+ DenseSet<GlobalValue::GUID> &ImportGUIDs);
void printEdgeWeight(raw_ostream &OS, Edge E);
void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const;
void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB);
@@ -173,7 +179,7 @@ protected:
void buildEdges(Function &F);
bool propagateThroughEdges(Function &F, bool UpdateBlockCount);
void computeDominanceAndLoopInfo(Function &F);
- unsigned getOffset(unsigned L, unsigned H) const;
+ unsigned getOffset(const DILocation *DIL) const;
void clearFunctionData();
/// \brief Map basic blocks to their computed weights.
@@ -326,11 +332,12 @@ SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS) const {
// If there are inlined callsites in this function, count the samples found
// in the respective bodies. However, do not bother counting callees with 0
// total samples, these are callees that were never invoked at runtime.
- for (const auto &I : FS->getCallsiteSamples()) {
- const FunctionSamples *CalleeSamples = &I.second;
- if (callsiteIsHot(FS, CalleeSamples))
- Count += countUsedRecords(CalleeSamples);
- }
+ for (const auto &I : FS->getCallsiteSamples())
+ for (const auto &J : I.second) {
+ const FunctionSamples *CalleeSamples = &J.second;
+ if (callsiteIsHot(FS, CalleeSamples))
+ Count += countUsedRecords(CalleeSamples);
+ }
return Count;
}
@@ -343,11 +350,12 @@ SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS) const {
unsigned Count = FS->getBodySamples().size();
// Only count records in hot callsites.
- for (const auto &I : FS->getCallsiteSamples()) {
- const FunctionSamples *CalleeSamples = &I.second;
- if (callsiteIsHot(FS, CalleeSamples))
- Count += countBodyRecords(CalleeSamples);
- }
+ for (const auto &I : FS->getCallsiteSamples())
+ for (const auto &J : I.second) {
+ const FunctionSamples *CalleeSamples = &J.second;
+ if (callsiteIsHot(FS, CalleeSamples))
+ Count += countBodyRecords(CalleeSamples);
+ }
return Count;
}
@@ -362,11 +370,12 @@ SampleCoverageTracker::countBodySamples(const FunctionSamples *FS) const {
Total += I.second.getSamples();
// Only count samples in hot callsites.
- for (const auto &I : FS->getCallsiteSamples()) {
- const FunctionSamples *CalleeSamples = &I.second;
- if (callsiteIsHot(FS, CalleeSamples))
- Total += countBodySamples(CalleeSamples);
- }
+ for (const auto &I : FS->getCallsiteSamples())
+ for (const auto &J : I.second) {
+ const FunctionSamples *CalleeSamples = &J.second;
+ if (callsiteIsHot(FS, CalleeSamples))
+ Total += countBodySamples(CalleeSamples);
+ }
return Total;
}
@@ -398,15 +407,11 @@ void SampleProfileLoader::clearFunctionData() {
CoverageTracker.clear();
}
-/// \brief Returns the offset of lineno \p L to head_lineno \p H
-///
-/// \param L Lineno
-/// \param H Header lineno of the function
-///
-/// \returns offset to the header lineno. 16 bits are used to represent offset.
+/// Returns the line offset to the start line of the subprogram.
/// We assume that a single function will not exceed 65535 LOC.
-unsigned SampleProfileLoader::getOffset(unsigned L, unsigned H) const {
- return (L - H) & 0xffff;
+unsigned SampleProfileLoader::getOffset(const DILocation *DIL) const {
+ return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) &
+ 0xffff;
}
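// A small self-contained illustration of the line-offset encoding now used by
// getOffset: the offset is the instruction's line minus the enclosing
// subprogram's start line, truncated to 16 bits (per the assumption above that
// a single function does not exceed 65535 lines). The helper name below is
// illustrative, not part of the patch.
static unsigned sampleLineOffset(unsigned InstLine, unsigned SubprogramLine) {
  return (InstLine - SubprogramLine) & 0xffff; // same arithmetic as getOffset
}
// For example, sampleLineOffset(1042, 1030) == 12, matching the
// "offset.discriminator" keys looked up in the sample profile.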
/// \brief Print the weight of edge \p E on stream \p OS.
@@ -451,8 +456,7 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS,
/// \param Inst Instruction to query.
///
/// \returns the weight of \p Inst.
-ErrorOr<uint64_t>
-SampleProfileLoader::getInstWeight(const Instruction &Inst) {
+ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
const DebugLoc &DLoc = Inst.getDebugLoc();
if (!DLoc)
return std::error_code();
@@ -470,19 +474,14 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) {
// If a call/invoke instruction is inlined in profile, but not inlined here,
// it means that the inlined callsite has no sample, thus the call
// instruction should have 0 count.
- bool IsCall = isa<CallInst>(Inst) || isa<InvokeInst>(Inst);
- if (IsCall && findCalleeFunctionSamples(Inst))
+ if ((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) &&
+ findCalleeFunctionSamples(Inst))
return 0;
const DILocation *DIL = DLoc;
- unsigned Lineno = DLoc.getLine();
- unsigned HeaderLineno = DIL->getScope()->getSubprogram()->getLine();
-
- uint32_t LineOffset = getOffset(Lineno, HeaderLineno);
- uint32_t Discriminator = DIL->getDiscriminator();
- ErrorOr<uint64_t> R = IsCall
- ? FS->findCallSamplesAt(LineOffset, Discriminator)
- : FS->findSamplesAt(LineOffset, Discriminator);
+ uint32_t LineOffset = getOffset(DIL);
+ uint32_t Discriminator = DIL->getBaseDiscriminator();
+ ErrorOr<uint64_t> R = FS->findSamplesAt(LineOffset, Discriminator);
if (R) {
bool FirstMark =
CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator, R.get());
@@ -491,13 +490,14 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) {
LLVMContext &Ctx = F->getContext();
emitOptimizationRemark(
Ctx, DEBUG_TYPE, *F, DLoc,
- Twine("Applied ") + Twine(*R) + " samples from profile (offset: " +
- Twine(LineOffset) +
+ Twine("Applied ") + Twine(*R) +
+ " samples from profile (offset: " + Twine(LineOffset) +
((Discriminator) ? Twine(".") + Twine(Discriminator) : "") + ")");
}
- DEBUG(dbgs() << " " << Lineno << "." << DIL->getDiscriminator() << ":"
- << Inst << " (line offset: " << Lineno - HeaderLineno << "."
- << DIL->getDiscriminator() << " - weight: " << R.get()
+ DEBUG(dbgs() << " " << DLoc.getLine() << "."
+ << DIL->getBaseDiscriminator() << ":" << Inst
+ << " (line offset: " << LineOffset << "."
+ << DIL->getBaseDiscriminator() << " - weight: " << R.get()
<< ")\n");
}
return R;
@@ -511,8 +511,7 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) {
/// \param BB The basic block to query.
///
/// \returns the weight for \p BB.
-ErrorOr<uint64_t>
-SampleProfileLoader::getBlockWeight(const BasicBlock *BB) {
+ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) {
uint64_t Max = 0;
bool HasWeight = false;
for (auto &I : BB->getInstList()) {
@@ -565,16 +564,49 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const {
if (!DIL) {
return nullptr;
}
- DISubprogram *SP = DIL->getScope()->getSubprogram();
- if (!SP)
- return nullptr;
+
+ StringRef CalleeName;
+ if (const CallInst *CI = dyn_cast<CallInst>(&Inst))
+ if (Function *Callee = CI->getCalledFunction())
+ CalleeName = Callee->getName();
const FunctionSamples *FS = findFunctionSamples(Inst);
if (FS == nullptr)
return nullptr;
- return FS->findFunctionSamplesAt(LineLocation(
- getOffset(DIL->getLine(), SP->getLine()), DIL->getDiscriminator()));
+ return FS->findFunctionSamplesAt(
+ LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()), CalleeName);
+}
+
+/// Returns a vector of FunctionSamples that are the indirect call targets
+/// of \p Inst, sorted in descending order of each target's total number of
+/// samples.
+std::vector<const FunctionSamples *>
+SampleProfileLoader::findIndirectCallFunctionSamples(
+ const Instruction &Inst) const {
+ const DILocation *DIL = Inst.getDebugLoc();
+ std::vector<const FunctionSamples *> R;
+
+ if (!DIL) {
+ return R;
+ }
+
+ const FunctionSamples *FS = findFunctionSamples(Inst);
+ if (FS == nullptr)
+ return R;
+
+ if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt(
+ LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()))) {
+ if (M->size() == 0)
+ return R;
+ for (const auto &NameFS : *M) {
+ R.push_back(&NameFS.second);
+ }
+ std::sort(R.begin(), R.end(),
+ [](const FunctionSamples *L, const FunctionSamples *R) {
+ return L->getTotalSamples() > R->getTotalSamples();
+ });
+ }
+ return R;
}
/// \brief Get the FunctionSamples for an instruction.
@@ -588,23 +620,23 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const {
/// \returns the FunctionSamples pointer to the inlined instance.
const FunctionSamples *
SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const {
- SmallVector<LineLocation, 10> S;
+ SmallVector<std::pair<LineLocation, StringRef>, 10> S;
const DILocation *DIL = Inst.getDebugLoc();
- if (!DIL) {
+ if (!DIL)
return Samples;
- }
+
+ const DILocation *PrevDIL = DIL;
for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) {
- DISubprogram *SP = DIL->getScope()->getSubprogram();
- if (!SP)
- return nullptr;
- S.push_back(LineLocation(getOffset(DIL->getLine(), SP->getLine()),
- DIL->getDiscriminator()));
+ S.push_back(std::make_pair(
+ LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()),
+ PrevDIL->getScope()->getSubprogram()->getLinkageName()));
+ PrevDIL = DIL;
}
if (S.size() == 0)
return Samples;
const FunctionSamples *FS = Samples;
for (int i = S.size() - 1; i >= 0 && FS != nullptr; i--) {
- FS = FS->findFunctionSamplesAt(S[i]);
+ FS = FS->findFunctionSamplesAt(S[i].first, S[i].second);
}
return FS;
}
@@ -614,14 +646,17 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const {
/// Iteratively traverse all callsites of the function \p F, and find if
/// the corresponding inlined instance exists and is hot in profile. If
/// it is hot enough, inline the callsites and add new callsites of the
-/// callee into the caller.
-///
-/// TODO: investigate the possibility of not invoking InlineFunction directly.
+/// callee into the caller. If the call is an indirect call, promote it to a
+/// direct call first; promotion is limited to a single target per indirect call.
///
/// \param F function to perform iterative inlining.
+/// \param ImportGUIDs a set to be updated to include all GUIDs that come
+/// from a different module but were inlined in the profiled binary.
///
/// \returns True if any inlining happened.
-bool SampleProfileLoader::inlineHotFunctions(Function &F) {
+bool SampleProfileLoader::inlineHotFunctions(
+ Function &F, DenseSet<GlobalValue::GUID> &ImportGUIDs) {
+ DenseSet<Instruction *> PromotedInsns;
bool Changed = false;
LLVMContext &Ctx = F.getContext();
std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&](
@@ -647,18 +682,42 @@ bool SampleProfileLoader::inlineHotFunctions(Function &F) {
}
for (auto I : CIS) {
InlineFunctionInfo IFI(nullptr, ACT ? &GetAssumptionCache : nullptr);
- CallSite CS(I);
- Function *CalledFunction = CS.getCalledFunction();
- if (!CalledFunction || !CalledFunction->getSubprogram())
+ Function *CalledFunction = CallSite(I).getCalledFunction();
+ Instruction *DI = I;
+ if (!CalledFunction && !PromotedInsns.count(I) &&
+ CallSite(I).isIndirectCall())
+ for (const auto *FS : findIndirectCallFunctionSamples(*I)) {
+ auto CalleeFunctionName = FS->getName();
+ const char *Reason = "Callee function not available";
+ CalledFunction = F.getParent()->getFunction(CalleeFunctionName);
+ if (CalledFunction && isLegalToPromote(I, CalledFunction, &Reason)) {
+ // The indirect target was promoted and inlined in the profile; as a
+ // result, we do not have profile info for the branch probability.
+ // We set the probability to 80% taken to indicate that the promoted
+ // direct call is likely taken.
+ DI = dyn_cast<Instruction>(
+ promoteIndirectCall(I, CalledFunction, 80, 100, false)
+ ->stripPointerCasts());
+ PromotedInsns.insert(I);
+ } else {
+ DEBUG(dbgs() << "\nFailed to promote indirect call to "
+ << CalleeFunctionName << " because " << Reason
+ << "\n");
+ continue;
+ }
+ }
+ if (!CalledFunction || !CalledFunction->getSubprogram()) {
+ findCalleeFunctionSamples(*I)->findImportedFunctions(
+ ImportGUIDs, F.getParent(),
+ Samples->getTotalSamples() * SampleProfileHotThreshold / 100);
continue;
+ }
DebugLoc DLoc = I->getDebugLoc();
- uint64_t NumSamples = findCalleeFunctionSamples(*I)->getTotalSamples();
- if (InlineFunction(CS, IFI)) {
+ if (InlineFunction(CallSite(DI), IFI)) {
LocalChanged = true;
emitOptimizationRemark(Ctx, DEBUG_TYPE, F, DLoc,
Twine("inlined hot callee '") +
- CalledFunction->getName() + "' with " +
- Twine(NumSamples) + " samples into '" +
+ CalledFunction->getName() + "' into '" +
F.getName() + "'");
}
}
@@ -994,6 +1053,26 @@ void SampleProfileLoader::buildEdges(Function &F) {
}
}
+/// Sorts the CallTargetMap \p M by count in descending order and stores the
+/// sorted result in \p Sorted. Returns the sum of all counts.
+static uint64_t SortCallTargets(SmallVector<InstrProfValueData, 2> &Sorted,
+ const SampleRecord::CallTargetMap &M) {
+ Sorted.clear();
+ uint64_t Sum = 0;
+ for (auto I = M.begin(); I != M.end(); ++I) {
+ Sum += I->getValue();
+ Sorted.push_back({Function::getGUID(I->getKey()), I->getValue()});
+ }
+ std::sort(Sorted.begin(), Sorted.end(),
+ [](const InstrProfValueData &L, const InstrProfValueData &R) {
+ if (L.Count == R.Count)
+ return L.Value > R.Value;
+ else
+ return L.Count > R.Count;
+ });
+ return Sum;
+}
+
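// A self-contained analogue (plain std containers; the names are illustrative
// and not the pass's API) of what SortCallTargets does: flatten a
// target-to-count map into a count-descending vector and return the total
// count, which the pass then feeds to annotateValueSite as indirect-call
// value-profile data.
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

static uint64_t sortCallTargetsByCount(
    std::vector<std::pair<std::string, uint64_t>> &Sorted,
    const std::map<std::string, uint64_t> &Targets) {
  Sorted.assign(Targets.begin(), Targets.end());
  uint64_t Sum = 0;
  for (const auto &NameCount : Sorted)
    Sum += NameCount.second; // accumulate the total count across all targets
  std::sort(Sorted.begin(), Sorted.end(),
            [](const std::pair<std::string, uint64_t> &L,
               const std::pair<std::string, uint64_t> &R) {
              return L.second > R.second; // descending by count
            });
  return Sum;
}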
/// \brief Propagate weights into edges
///
/// The following rules are applied to every block BB in the CFG:
@@ -1015,10 +1094,6 @@ void SampleProfileLoader::propagateWeights(Function &F) {
bool Changed = true;
unsigned I = 0;
- // Add an entry count to the function using the samples gathered
- // at the function entry.
- F.setEntryCount(Samples->getHeadSamples() + 1);
-
// If BB weight is larger than its corresponding loop's header BB weight,
// use the BB weight to replace the loop header BB weight.
for (auto &BI : F) {
@@ -1071,13 +1146,32 @@ void SampleProfileLoader::propagateWeights(Function &F) {
if (BlockWeights[BB]) {
for (auto &I : BB->getInstList()) {
- if (CallInst *CI = dyn_cast<CallInst>(&I)) {
- if (!dyn_cast<IntrinsicInst>(&I)) {
- SmallVector<uint32_t, 1> Weights;
- Weights.push_back(BlockWeights[BB]);
- CI->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Weights));
- }
+ if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
+ continue;
+ CallSite CS(&I);
+ if (!CS.getCalledFunction()) {
+ const DebugLoc &DLoc = I.getDebugLoc();
+ if (!DLoc)
+ continue;
+ const DILocation *DIL = DLoc;
+ uint32_t LineOffset = getOffset(DIL);
+ uint32_t Discriminator = DIL->getBaseDiscriminator();
+
+ const FunctionSamples *FS = findFunctionSamples(I);
+ if (!FS)
+ continue;
+ auto T = FS->findCallTargetMapAt(LineOffset, Discriminator);
+ if (!T || T.get().size() == 0)
+ continue;
+ SmallVector<InstrProfValueData, 2> SortedCallTargets;
+ uint64_t Sum = SortCallTargets(SortedCallTargets, T.get());
+ annotateValueSite(*I.getParent()->getParent()->getParent(), I,
+ SortedCallTargets, Sum, IPVK_IndirectCallTarget,
+ SortedCallTargets.size());
+ } else if (!isa<IntrinsicInst>(&I)) {
+ SmallVector<uint32_t, 1> Weights;
+ Weights.push_back(BlockWeights[BB]);
+ I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
}
}
}
@@ -1115,9 +1209,13 @@ void SampleProfileLoader::propagateWeights(Function &F) {
}
}
+ uint64_t TempWeight;
// Only set weights if there is at least one non-zero weight.
// In any other case, let the analyzer set weights.
- if (MaxWeight > 0) {
+ // Do not set weights if they are already present. In ThinLTO, the profile
+ // annotation is done twice. If the first annotation already set the
+ // weights, the second pass does not need to set them again.
+ if (MaxWeight > 0 && !TI->extractProfTotalWeight(TempWeight)) {
DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
TI->setMetadata(llvm::LLVMContext::MD_prof,
MDB.createBranchWeights(Weights));
@@ -1228,12 +1326,19 @@ bool SampleProfileLoader::emitAnnotations(Function &F) {
DEBUG(dbgs() << "Line number for the first instruction in " << F.getName()
<< ": " << getFunctionLoc(F) << "\n");
- Changed |= inlineHotFunctions(F);
+ DenseSet<GlobalValue::GUID> ImportGUIDs;
+ Changed |= inlineHotFunctions(F, ImportGUIDs);
// Compute basic block weights.
Changed |= computeBlockWeights(F);
if (Changed) {
+ // Add an entry count to the function using the samples gathered at the
+ // function entry. Also set the GUIDs that come from a different module
+ // but were inlined in the profiled binary. This aims to make the IR
+ // match the profiled binary before annotation.
+ F.setEntryCount(Samples->getHeadSamples() + 1, &ImportGUIDs);
+
// Compute dominance and loop info needed for propagation.
computeDominanceAndLoopInfo(F);
@@ -1329,7 +1434,7 @@ bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) {
bool SampleProfileLoader::runOnFunction(Function &F) {
F.setEntryCount(0);
Samples = Reader->getSamplesFor(F);
- if (!Samples->empty())
+ if (Samples && !Samples->empty())
return emitAnnotations(F);
return false;
}
diff --git a/lib/Transforms/IPO/StripSymbols.cpp b/lib/Transforms/IPO/StripSymbols.cpp
index 8f6f161428e89..fb64367eef917 100644
--- a/lib/Transforms/IPO/StripSymbols.cpp
+++ b/lib/Transforms/IPO/StripSymbols.cpp
@@ -323,6 +323,14 @@ bool StripDeadDebugInfo::runOnModule(Module &M) {
LiveGVs.insert(GVE);
}
+ std::set<DICompileUnit *> LiveCUs;
+ // Any CU referenced from a subprogram is live.
+ for (DISubprogram *SP : F.subprograms()) {
+ if (SP->getUnit())
+ LiveCUs.insert(SP->getUnit());
+ }
+
+ bool HasDeadCUs = false;
for (DICompileUnit *DIC : F.compile_units()) {
// Create our live global variable list.
bool GlobalVariableChange = false;
@@ -341,6 +349,11 @@ bool StripDeadDebugInfo::runOnModule(Module &M) {
GlobalVariableChange = true;
}
+ if (!LiveGlobalVariables.empty())
+ LiveCUs.insert(DIC);
+ else if (!LiveCUs.count(DIC))
+ HasDeadCUs = true;
+
// If we found dead global variables, replace the current global
// variable list with our new live global variable list.
if (GlobalVariableChange) {
@@ -352,5 +365,16 @@ bool StripDeadDebugInfo::runOnModule(Module &M) {
LiveGlobalVariables.clear();
}
+ if (HasDeadCUs) {
+ // Delete the old node and replace it with a new one
+ NamedMDNode *NMD = M.getOrInsertNamedMetadata("llvm.dbg.cu");
+ NMD->clearOperands();
+ if (!LiveCUs.empty()) {
+ for (DICompileUnit *CU : LiveCUs)
+ NMD->addOperand(CU);
+ }
+ Changed = true;
+ }
+
return Changed;
}
diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
index 3680cfc813a18..65deb82cd2a5f 100644
--- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
+++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
@@ -14,16 +14,21 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/IPO.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/ModuleSummaryAnalysis.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
+#include "llvm/Support/FileSystem.h"
#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/Cloning.h"
using namespace llvm;
@@ -41,23 +46,14 @@ namespace {
std::string getModuleId(Module *M) {
MD5 Md5;
bool ExportsSymbols = false;
- auto AddGlobal = [&](GlobalValue &GV) {
+ for (auto &GV : M->global_values()) {
if (GV.isDeclaration() || GV.getName().startswith("llvm.") ||
!GV.hasExternalLinkage())
- return;
+ continue;
ExportsSymbols = true;
Md5.update(GV.getName());
Md5.update(ArrayRef<uint8_t>{0});
- };
-
- for (auto &F : *M)
- AddGlobal(F);
- for (auto &GV : M->globals())
- AddGlobal(GV);
- for (auto &GA : M->aliases())
- AddGlobal(GA);
- for (auto &IF : M->ifuncs())
- AddGlobal(IF);
+ }
if (!ExportsSymbols)
return "";
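// An illustrative, self-contained analogue of how getModuleId derives a stable
// module identifier: hash the names of all defined, externally visible,
// non-"llvm."-prefixed globals (with a separator between names) and return an
// empty id when the module exports nothing. The real code hashes with
// llvm::MD5; std::hash and the "$" prefix below are placeholders.
#include <functional>
#include <string>
#include <vector>

static std::string toyModuleId(const std::vector<std::string> &ExportedNames) {
  if (ExportedNames.empty())
    return ""; // nothing exported: no unique id needed, module is not split
  std::string Blob;
  for (const std::string &Name : ExportedNames) {
    Blob += Name;
    Blob += '\0'; // separator, mirroring Md5.update(ArrayRef<uint8_t>{0})
  }
  return "$" + std::to_string(std::hash<std::string>{}(Blob));
}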
@@ -73,15 +69,21 @@ std::string getModuleId(Module *M) {
// Promote each local-linkage entity defined by ExportM and used by ImportM by
// changing visibility and appending the given ModuleId.
void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId) {
- auto PromoteInternal = [&](GlobalValue &ExportGV) {
+ DenseMap<const Comdat *, Comdat *> RenamedComdats;
+ for (auto &ExportGV : ExportM.global_values()) {
if (!ExportGV.hasLocalLinkage())
- return;
+ continue;
- GlobalValue *ImportGV = ImportM.getNamedValue(ExportGV.getName());
+ auto Name = ExportGV.getName();
+ GlobalValue *ImportGV = ImportM.getNamedValue(Name);
if (!ImportGV || ImportGV->use_empty())
- return;
+ continue;
+
+ std::string NewName = (Name + ModuleId).str();
- std::string NewName = (ExportGV.getName() + ModuleId).str();
+ if (const auto *C = ExportGV.getComdat())
+ if (C->getName() == Name)
+ RenamedComdats.try_emplace(C, ExportM.getOrInsertComdat(NewName));
ExportGV.setName(NewName);
ExportGV.setLinkage(GlobalValue::ExternalLinkage);
@@ -89,16 +91,15 @@ void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId) {
ImportGV->setName(NewName);
ImportGV->setVisibility(GlobalValue::HiddenVisibility);
- };
+ }
- for (auto &F : ExportM)
- PromoteInternal(F);
- for (auto &GV : ExportM.globals())
- PromoteInternal(GV);
- for (auto &GA : ExportM.aliases())
- PromoteInternal(GA);
- for (auto &IF : ExportM.ifuncs())
- PromoteInternal(IF);
+ if (!RenamedComdats.empty())
+ for (auto &GO : ExportM.global_objects())
+ if (auto *C = GO.getComdat()) {
+ auto Replacement = RenamedComdats.find(C);
+ if (Replacement != RenamedComdats.end())
+ GO.setComdat(Replacement->second);
+ }
}
// Promote all internal (i.e. distinct) type ids used by the module by replacing
@@ -194,24 +195,7 @@ void simplifyExternals(Module &M) {
}
void filterModule(
- Module *M, std::function<bool(const GlobalValue *)> ShouldKeepDefinition) {
- for (Function &F : *M) {
- if (ShouldKeepDefinition(&F))
- continue;
-
- F.deleteBody();
- F.clearMetadata();
- }
-
- for (GlobalVariable &GV : M->globals()) {
- if (ShouldKeepDefinition(&GV))
- continue;
-
- GV.setInitializer(nullptr);
- GV.setLinkage(GlobalValue::ExternalLinkage);
- GV.clearMetadata();
- }
-
+ Module *M, function_ref<bool(const GlobalValue *)> ShouldKeepDefinition) {
for (Module::alias_iterator I = M->alias_begin(), E = M->alias_end();
I != E;) {
GlobalAlias *GA = &*I++;
@@ -219,7 +203,7 @@ void filterModule(
continue;
GlobalObject *GO;
- if (I->getValueType()->isFunctionTy())
+ if (GA->getValueType()->isFunctionTy())
GO = Function::Create(cast<FunctionType>(GA->getValueType()),
GlobalValue::ExternalLinkage, "", M);
else
@@ -231,53 +215,168 @@ void filterModule(
GA->replaceAllUsesWith(GO);
GA->eraseFromParent();
}
+
+ for (Function &F : *M) {
+ if (ShouldKeepDefinition(&F))
+ continue;
+
+ F.deleteBody();
+ F.setComdat(nullptr);
+ F.clearMetadata();
+ }
+
+ for (GlobalVariable &GV : M->globals()) {
+ if (ShouldKeepDefinition(&GV))
+ continue;
+
+ GV.setInitializer(nullptr);
+ GV.setLinkage(GlobalValue::ExternalLinkage);
+ GV.setComdat(nullptr);
+ GV.clearMetadata();
+ }
+}
+
+void forEachVirtualFunction(Constant *C, function_ref<void(Function *)> Fn) {
+ if (auto *F = dyn_cast<Function>(C))
+ return Fn(F);
+ if (isa<GlobalValue>(C))
+ return;
+ for (Value *Op : C->operands())
+ forEachVirtualFunction(cast<Constant>(Op), Fn);
}
// If it's possible to split M into regular and thin LTO parts, do so and write
// a multi-module bitcode file with the two parts to OS. Otherwise, write only a
// regular LTO bitcode file to OS.
-void splitAndWriteThinLTOBitcode(raw_ostream &OS, Module &M) {
+void splitAndWriteThinLTOBitcode(
+ raw_ostream &OS, raw_ostream *ThinLinkOS,
+ function_ref<AAResults &(Function &)> AARGetter, Module &M) {
std::string ModuleId = getModuleId(&M);
if (ModuleId.empty()) {
// We couldn't generate a module ID for this module, just write it out as a
// regular LTO module.
WriteBitcodeToFile(&M, OS);
+ if (ThinLinkOS)
+ // We don't have a ThinLTO part, but still write the module to the
+ // ThinLinkOS if requested so that the expected output file is produced.
+ WriteBitcodeToFile(&M, *ThinLinkOS);
return;
}
promoteTypeIds(M, ModuleId);
- auto IsInMergedM = [&](const GlobalValue *GV) {
- auto *GVar = dyn_cast<GlobalVariable>(GV->getBaseObject());
- if (!GVar)
- return false;
-
+ // Returns whether a global has attached type metadata. Such globals may
+ // participate in CFI or whole-program devirtualization, so they need to
+ // appear in the merged module instead of the thin LTO module.
+ auto HasTypeMetadata = [&](const GlobalObject *GO) {
SmallVector<MDNode *, 1> MDs;
- GVar->getMetadata(LLVMContext::MD_type, MDs);
+ GO->getMetadata(LLVMContext::MD_type, MDs);
return !MDs.empty();
};
+ // Collect the set of virtual functions that are eligible for virtual constant
+ // propagation. Each eligible function must not access memory, must return
+ // an integer of width <=64 bits, must take at least one argument, must not
+ // use its first argument (assumed to be "this"), and all arguments other
+ // than the first must be integers of width <=64 bits.
+ //
+ // Note that we test whether this copy of the function is readnone, rather
+ // than testing function attributes, which must hold for any copy of the
+ // function, even a less optimized version substituted at link time. This is
+ // sound because the virtual constant propagation optimizations effectively
+ // inline all implementations of the virtual function into each call site,
+ // rather than using function attributes to perform local optimization.
+ std::set<const Function *> EligibleVirtualFns;
+ // If any member of a comdat lives in MergedM, put all members of that
+ // comdat in MergedM to keep the comdat together.
+ DenseSet<const Comdat *> MergedMComdats;
+ for (GlobalVariable &GV : M.globals())
+ if (HasTypeMetadata(&GV)) {
+ if (const auto *C = GV.getComdat())
+ MergedMComdats.insert(C);
+ forEachVirtualFunction(GV.getInitializer(), [&](Function *F) {
+ auto *RT = dyn_cast<IntegerType>(F->getReturnType());
+ if (!RT || RT->getBitWidth() > 64 || F->arg_empty() ||
+ !F->arg_begin()->use_empty())
+ return;
+ for (auto &Arg : make_range(std::next(F->arg_begin()), F->arg_end())) {
+ auto *ArgT = dyn_cast<IntegerType>(Arg.getType());
+ if (!ArgT || ArgT->getBitWidth() > 64)
+ return;
+ }
+ if (computeFunctionBodyMemoryAccess(*F, AARGetter(*F)) == MAK_ReadNone)
+ EligibleVirtualFns.insert(F);
+ });
+ }
+
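// An illustrative C++ source-level view (not from this patch) of the
// eligibility rules above: getTag satisfies them (integer return and parameter
// of width <=64 bits, "this" unused, no memory access in the implementation),
// while scale does not because of its floating-point types. The class names
// are hypothetical.
struct ExampleBase {
  virtual int getTag(int Key) const = 0;
  virtual double scale(double X) const = 0;
  virtual ~ExampleBase() = default;
};
struct ExampleImpl : ExampleBase {
  int getTag(int Key) const override { return Key + 1; }    // eligible
  double scale(double X) const override { return X * 2.0; } // not eligible
};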
ValueToValueMapTy VMap;
- std::unique_ptr<Module> MergedM(CloneModule(&M, VMap, IsInMergedM));
+ std::unique_ptr<Module> MergedM(
+ CloneModule(&M, VMap, [&](const GlobalValue *GV) -> bool {
+ if (const auto *C = GV->getComdat())
+ if (MergedMComdats.count(C))
+ return true;
+ if (auto *F = dyn_cast<Function>(GV))
+ return EligibleVirtualFns.count(F);
+ if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject()))
+ return HasTypeMetadata(GVar);
+ return false;
+ }));
+ StripDebugInfo(*MergedM);
+
+ for (Function &F : *MergedM)
+ if (!F.isDeclaration()) {
+ // Reset the linkage of all functions eligible for virtual constant
+ // propagation. The canonical definitions live in the thin LTO module so
+ // that they can be imported.
+ F.setLinkage(GlobalValue::AvailableExternallyLinkage);
+ F.setComdat(nullptr);
+ }
- filterModule(&M, [&](const GlobalValue *GV) { return !IsInMergedM(GV); });
+ // Remove from the thin LTO module all globals with type metadata, all
+ // globals whose comdat lives in MergedM, and all aliases pointing to such
+ // globals.
+ filterModule(&M, [&](const GlobalValue *GV) {
+ if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject()))
+ if (HasTypeMetadata(GVar))
+ return false;
+ if (const auto *C = GV->getComdat())
+ if (MergedMComdats.count(C))
+ return false;
+ return true;
+ });
promoteInternals(*MergedM, M, ModuleId);
promoteInternals(M, *MergedM, ModuleId);
simplifyExternals(*MergedM);
- SmallVector<char, 0> Buffer;
- BitcodeWriter W(Buffer);
// FIXME: Try to re-use BSI and PFI from the original module here.
ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, nullptr);
- W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index,
- /*GenerateHash=*/true);
- W.writeModule(MergedM.get());
+ SmallVector<char, 0> Buffer;
+ BitcodeWriter W(Buffer);
+ // Save the module hash produced for the full bitcode, which will
+ // be used in the backends, and use that same hash in the minimized
+ // bitcode produced for the thin link.
+ ModuleHash ModHash = {{0}};
+ W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index,
+ /*GenerateHash=*/true, &ModHash);
+ W.writeModule(MergedM.get());
OS << Buffer;
+
+ // If a minimized bitcode module was requested for the thin link,
+ // strip the debug info (the merged module was already stripped above)
+ // and write it to the given OS.
+ if (ThinLinkOS) {
+ Buffer.clear();
+ BitcodeWriter W2(Buffer);
+ StripDebugInfo(M);
+ W2.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index,
+ /*GenerateHash=*/false, &ModHash);
+ W2.writeModule(MergedM.get());
+ *ThinLinkOS << Buffer;
+ }
}
// Returns whether this module needs to be split because it uses type metadata.
@@ -292,28 +391,45 @@ bool requiresSplit(Module &M) {
return false;
}
-void writeThinLTOBitcode(raw_ostream &OS, Module &M,
- const ModuleSummaryIndex *Index) {
+void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS,
+ function_ref<AAResults &(Function &)> AARGetter,
+ Module &M, const ModuleSummaryIndex *Index) {
// See if this module has any type metadata. If so, we need to split it.
if (requiresSplit(M))
- return splitAndWriteThinLTOBitcode(OS, M);
+ return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M);
// Otherwise we can just write it out as a regular module.
+
+ // Save the module hash produced for the full bitcode, which will
+ // be used in the backends, and use that same hash in the minimized
+ // bitcode produced for the thin link.
+ ModuleHash ModHash = {{0}};
WriteBitcodeToFile(&M, OS, /*ShouldPreserveUseListOrder=*/false, Index,
- /*GenerateHash=*/true);
+ /*GenerateHash=*/true, &ModHash);
+ // If a minimized bitcode module was requested for the thin link,
+ // strip the debug info and write it to the given OS.
+ if (ThinLinkOS) {
+ StripDebugInfo(M);
+ WriteBitcodeToFile(&M, *ThinLinkOS, /*ShouldPreserveUseListOrder=*/false,
+ Index,
+ /*GenerateHash=*/false, &ModHash);
+ }
}
class WriteThinLTOBitcode : public ModulePass {
raw_ostream &OS; // raw_ostream to print on
+ // The output stream on which to emit a minimized module for use
+ // just in the thin link, if requested.
+ raw_ostream *ThinLinkOS;
public:
static char ID; // Pass identification, replacement for typeid
- WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()) {
+ WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()), ThinLinkOS(nullptr) {
initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry());
}
- explicit WriteThinLTOBitcode(raw_ostream &o)
- : ModulePass(ID), OS(o) {
+ explicit WriteThinLTOBitcode(raw_ostream &o, raw_ostream *ThinLinkOS)
+ : ModulePass(ID), OS(o), ThinLinkOS(ThinLinkOS) {
initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry());
}
@@ -322,12 +438,14 @@ public:
bool runOnModule(Module &M) override {
const ModuleSummaryIndex *Index =
&(getAnalysis<ModuleSummaryIndexWrapperPass>().getIndex());
- writeThinLTOBitcode(OS, M, Index);
+ writeThinLTOBitcode(OS, ThinLinkOS, LegacyAARGetter(*this), M, Index);
return true;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
+ AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<ModuleSummaryIndexWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
}
};
} // anonymous namespace
@@ -335,10 +453,13 @@ public:
char WriteThinLTOBitcode::ID = 0;
INITIALIZE_PASS_BEGIN(WriteThinLTOBitcode, "write-thinlto-bitcode",
"Write ThinLTO Bitcode", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(ModuleSummaryIndexWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(WriteThinLTOBitcode, "write-thinlto-bitcode",
"Write ThinLTO Bitcode", false, true)
-ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str) {
- return new WriteThinLTOBitcode(Str);
+ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str,
+ raw_ostream *ThinLinkOS) {
+ return new WriteThinLTOBitcode(Str, ThinLinkOS);
}
diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 844cc0f70eed6..cb7d487b68b0b 100644
--- a/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -25,6 +25,20 @@
// returns 0, or a single vtable's function returns 1, replace each virtual
// call with a comparison of the vptr against that vtable's address.
//
+// This pass is intended to be used during the regular and thin LTO pipelines.
+// During regular LTO, the pass determines the best optimization for each
+// virtual call and applies the resolutions directly to virtual calls that are
+// eligible for virtual call optimization (i.e. calls that use either of the
+// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During
+// ThinLTO, the pass operates in two phases:
+// - Export phase: this is run during the thin link over a single merged module
+// that contains all vtables with !type metadata that participate in the link.
+// The pass computes a resolution for each virtual call and stores it in the
+// type identifier summary.
+// - Import phase: this is run during the thin backends over the individual
+// modules. The pass applies the resolutions previously computed during the
+// export phase to each eligible virtual call.
+//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
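// A minimal sketch of how the two ThinLTO phases are driven, mirroring the
// PassManagerBuilder changes elsewhere in this patch; the helper name is
// hypothetical, and the pass-manager and summary pointers are assumed to be
// provided by the surrounding LTO driver (with llvm/IR/LegacyPassManager.h
// and llvm/Transforms/IPO.h available for the declarations used).
static void addWholeProgramDevirt(llvm::legacy::PassManagerBase &PM,
                                  llvm::ModuleSummaryIndex *ExportSummary,
                                  const llvm::ModuleSummaryIndex *ImportSummary) {
  if (ExportSummary)
    // Thin link (export phase): compute resolutions into the combined summary.
    PM.add(llvm::createWholeProgramDevirtPass(ExportSummary, nullptr));
  else if (ImportSummary)
    // Thin backend (import phase): apply previously exported resolutions.
    PM.add(llvm::createWholeProgramDevirtPass(nullptr, ImportSummary));
}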
@@ -35,6 +49,8 @@
#include "llvm/ADT/iterator_range.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
@@ -54,12 +70,16 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/ModuleSummaryIndexYAML.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include <algorithm>
#include <cstddef>
@@ -72,6 +92,26 @@ using namespace wholeprogramdevirt;
#define DEBUG_TYPE "wholeprogramdevirt"
+static cl::opt<PassSummaryAction> ClSummaryAction(
+ "wholeprogramdevirt-summary-action",
+ cl::desc("What to do with the summary when running this pass"),
+ cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
+ clEnumValN(PassSummaryAction::Import, "import",
+ "Import typeid resolutions from summary and globals"),
+ clEnumValN(PassSummaryAction::Export, "export",
+ "Export typeid resolutions to summary and globals")),
+ cl::Hidden);
+
+static cl::opt<std::string> ClReadSummary(
+ "wholeprogramdevirt-read-summary",
+ cl::desc("Read summary from given YAML file before running pass"),
+ cl::Hidden);
+
+static cl::opt<std::string> ClWriteSummary(
+ "wholeprogramdevirt-write-summary",
+ cl::desc("Write summary to given YAML file after running pass"),
+ cl::Hidden);
+
// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset before the object, otherwise look for an
// offset after the object.
@@ -259,15 +299,92 @@ struct VirtualCallSite {
}
};
+// Call site information collected for a specific VTableSlot and possibly a list
+// of constant integer arguments. The grouping by arguments is handled by the
+// VTableSlotInfo class.
+struct CallSiteInfo {
+ /// The set of call sites for this slot. Used during regular LTO and the
+ /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
+ /// call sites that appear in the merged module itself); in each of these
+ /// cases we are directly operating on the call sites at the IR level.
+ std::vector<VirtualCallSite> CallSites;
+
+ // These fields are used during the export phase of ThinLTO and reflect
+ // information collected from function summaries.
+
+ /// Whether any function summary contains an llvm.assume(llvm.type.test) for
+ /// this slot.
+ bool SummaryHasTypeTestAssumeUsers;
+
+ /// CFI-specific: a vector containing the list of function summaries that use
+ /// the llvm.type.checked.load intrinsic and therefore will require
+ /// resolutions for llvm.type.test in order to implement CFI checks if
+ /// devirtualization was unsuccessful. If devirtualization was successful, the
+ /// pass will clear this vector by calling markDevirt(). If at the end of the
+ /// pass the vector is non-empty, we will need to add a use of llvm.type.test
+ /// to each of the function summaries in the vector.
+ std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
+
+ bool isExported() const {
+ return SummaryHasTypeTestAssumeUsers ||
+ !SummaryTypeCheckedLoadUsers.empty();
+ }
+
+ /// As explained in the comment for SummaryTypeCheckedLoadUsers.
+ void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); }
+};
+
+// Call site information collected for a specific VTableSlot.
+struct VTableSlotInfo {
+ // The set of call sites which do not have all constant integer arguments
+ // (excluding "this").
+ CallSiteInfo CSInfo;
+
+ // The set of call sites with all constant integer arguments (excluding
+ // "this"), grouped by argument list.
+ std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
+
+ void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
+
+private:
+ CallSiteInfo &findCallSiteInfo(CallSite CS);
+};
+
+CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
+ std::vector<uint64_t> Args;
+ auto *CI = dyn_cast<IntegerType>(CS.getType());
+ if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
+ return CSInfo;
+ for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
+ auto *CI = dyn_cast<ConstantInt>(Arg);
+ if (!CI || CI->getBitWidth() > 64)
+ return CSInfo;
+ Args.push_back(CI->getZExtValue());
+ }
+ return ConstCSInfo[Args];
+}
+
+void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
+ unsigned *NumUnsafeUses) {
+ findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
+}
+
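// An illustrative, declaration-only example (names not from this patch) of the
// bucketing findCallSiteInfo performs for one virtual call slot: call sites
// whose non-"this" arguments are all constant integers of width <=64 bits are
// grouped by their argument list in ConstCSInfo, and everything else lands in
// the fallback CSInfo.
struct Widget { virtual int classify(int Kind, int Flags); };
int bucketingExample(Widget *W, int Runtime) {
  int A = W->classify(1, 2);       // ConstCSInfo[{1, 2}]
  int B = W->classify(1, 2);       // same bucket: ConstCSInfo[{1, 2}]
  int C = W->classify(Runtime, 2); // non-constant argument: CSInfo
  return A + B + C;
}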
struct DevirtModule {
Module &M;
+ function_ref<AAResults &(Function &)> AARGetter;
+
+ ModuleSummaryIndex *ExportSummary;
+ const ModuleSummaryIndex *ImportSummary;
+
IntegerType *Int8Ty;
PointerType *Int8PtrTy;
IntegerType *Int32Ty;
+ IntegerType *Int64Ty;
+ IntegerType *IntPtrTy;
bool RemarksEnabled;
- MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;
+ MapVector<VTableSlot, VTableSlotInfo> CallSlots;
// This map keeps track of the number of "unsafe" uses of a loaded function
// pointer. The key is the associated llvm.type.test intrinsic call generated
@@ -279,11 +396,18 @@ struct DevirtModule {
// true.
std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
- DevirtModule(Module &M)
- : M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
+ DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
+ ModuleSummaryIndex *ExportSummary,
+ const ModuleSummaryIndex *ImportSummary)
+ : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary),
+ ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())),
Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
Int32Ty(Type::getInt32Ty(M.getContext())),
- RemarksEnabled(areRemarksEnabled()) {}
+ Int64Ty(Type::getInt64Ty(M.getContext())),
+ IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
+ RemarksEnabled(areRemarksEnabled()) {
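+ // A single invocation either exports summary information or imports it,
+ // never both.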
+ assert(!(ExportSummary && ImportSummary));
+ }
bool areRemarksEnabled();
@@ -298,57 +422,169 @@ struct DevirtModule {
tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
const std::set<TypeMemberInfo> &TypeMemberInfos,
uint64_t ByteOffset);
+
+ void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
+ bool &IsExported);
bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- MutableArrayRef<VirtualCallSite> CallSites);
+ VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res);
+
bool tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- ArrayRef<ConstantInt *> Args);
- bool tryUniformRetValOpt(IntegerType *RetType,
- MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- MutableArrayRef<VirtualCallSite> CallSites);
+ ArrayRef<uint64_t> Args);
+
+ void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
+ uint64_t TheRetVal);
+ bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
+ CallSiteInfo &CSInfo,
+ WholeProgramDevirtResolution::ByArg *Res);
+
+ // Returns the global symbol name that is used to export information about the
+ // given vtable slot and list of arguments.
+ std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
+ StringRef Name);
+
+ // This function is called during the export phase to create a symbol
+ // definition containing information about the given vtable slot and list of
+ // arguments.
+ void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
+ Constant *C);
+
+ // This function is called during the import phase to create a reference to
+ // the symbol definition created during the export phase.
+ Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
+ StringRef Name, unsigned AbsWidth = 0);
+
+ void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
+ Constant *UniqueMemberAddr);
bool tryUniqueRetValOpt(unsigned BitWidth,
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- MutableArrayRef<VirtualCallSite> CallSites);
+ CallSiteInfo &CSInfo,
+ WholeProgramDevirtResolution::ByArg *Res,
+ VTableSlot Slot, ArrayRef<uint64_t> Args);
+
+ void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
+ Constant *Byte, Constant *Bit);
bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- ArrayRef<VirtualCallSite> CallSites);
+ VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res, VTableSlot Slot);
void rebuildGlobal(VTableBits &B);
+ // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
+ void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);
+
+ // If we were able to eliminate all unsafe uses for a type checked load,
+ // eliminate the associated type tests by replacing them with true.
+ void removeRedundantTypeTests();
+
bool run();
+
+ // Lower the module using the action and summary passed as command line
+ // arguments. For testing purposes only.
+ static bool runForTesting(Module &M,
+ function_ref<AAResults &(Function &)> AARGetter);
};
struct WholeProgramDevirt : public ModulePass {
static char ID;
- WholeProgramDevirt() : ModulePass(ID) {
+ bool UseCommandLine = false;
+
+ ModuleSummaryIndex *ExportSummary;
+ const ModuleSummaryIndex *ImportSummary;
+
+ WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
+ initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
+ }
+
+ WholeProgramDevirt(ModuleSummaryIndex *ExportSummary,
+ const ModuleSummaryIndex *ImportSummary)
+ : ModulePass(ID), ExportSummary(ExportSummary),
+ ImportSummary(ImportSummary) {
initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
}
bool runOnModule(Module &M) override {
if (skipModule(M))
return false;
+ if (UseCommandLine)
+ return DevirtModule::runForTesting(M, LegacyAARGetter(*this));
+ return DevirtModule(M, LegacyAARGetter(*this), ExportSummary, ImportSummary)
+ .run();
+ }
- return DevirtModule(M).run();
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
}
};
} // end anonymous namespace
-INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt",
- "Whole program devirtualization", false, false)
+INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
+ "Whole program devirtualization", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
+ "Whole program devirtualization", false, false)
char WholeProgramDevirt::ID = 0;
-ModulePass *llvm::createWholeProgramDevirtPass() {
- return new WholeProgramDevirt;
+ModulePass *
+llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
+ const ModuleSummaryIndex *ImportSummary) {
+ return new WholeProgramDevirt(ExportSummary, ImportSummary);
}
PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
- ModuleAnalysisManager &) {
- if (!DevirtModule(M).run())
+ ModuleAnalysisManager &AM) {
+ auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto AARGetter = [&](Function &F) -> AAResults & {
+ return FAM.getResult<AAManager>(F);
+ };
+ if (!DevirtModule(M, AARGetter, nullptr, nullptr).run())
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
+bool DevirtModule::runForTesting(
+ Module &M, function_ref<AAResults &(Function &)> AARGetter) {
+ ModuleSummaryIndex Summary;
+
+ // Handle the command-line summary arguments. This code is for testing
+ // purposes only, so we handle errors directly.
+ if (!ClReadSummary.empty()) {
+ ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
+ ": ");
+ auto ReadSummaryFile =
+ ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
+
+ yaml::Input In(ReadSummaryFile->getBuffer());
+ In >> Summary;
+ ExitOnErr(errorCodeToError(In.error()));
+ }
+
+ bool Changed =
+ DevirtModule(
+ M, AARGetter,
+ ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
+ ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
+ .run();
+
+ if (!ClWriteSummary.empty()) {
+ ExitOnError ExitOnErr(
+ "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
+ std::error_code EC;
+ raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
+ ExitOnErr(errorCodeToError(EC));
+
+ yaml::Output Out(OS);
+ Out << Summary;
+ }
+
+ return Changed;
+}
+
void DevirtModule::buildTypeIdentifierMap(
std::vector<VTableBits> &Bits,
DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
@@ -443,9 +679,31 @@ bool DevirtModule::tryFindVirtualCallTargets(
return !TargetsForSlot.empty();
}
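+// Rewrite every call site recorded for this slot to call TheFn directly,
+// decrementing the unsafe-use count for call sites that came from
+// llvm.type.checked.load. IsExported is set if any users of this slot are
+// known only from function summaries, meaning a resolution must be recorded
+// for other modules.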
+void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
+ Constant *TheFn, bool &IsExported) {
+ auto Apply = [&](CallSiteInfo &CSInfo) {
+ for (auto &&VCallSite : CSInfo.CallSites) {
+ if (RemarksEnabled)
+ VCallSite.emitRemark("single-impl", TheFn->getName());
+ VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
+ TheFn, VCallSite.CS.getCalledValue()->getType()));
+ // This use is no longer unsafe.
+ if (VCallSite.NumUnsafeUses)
+ --*VCallSite.NumUnsafeUses;
+ }
+ if (CSInfo.isExported()) {
+ IsExported = true;
+ CSInfo.markDevirt();
+ }
+ };
+ Apply(SlotInfo.CSInfo);
+ for (auto &P : SlotInfo.ConstCSInfo)
+ Apply(P.second);
+}
+
bool DevirtModule::trySingleImplDevirt(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- MutableArrayRef<VirtualCallSite> CallSites) {
+ VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) {
// See if the program contains a single implementation of this virtual
// function.
Function *TheFn = TargetsForSlot[0].Fn;
@@ -453,39 +711,51 @@ bool DevirtModule::trySingleImplDevirt(
if (TheFn != Target.Fn)
return false;
+ // If so, update each call site to call that implementation directly.
if (RemarksEnabled)
TargetsForSlot[0].WasDevirt = true;
- // If so, update each call site to call that implementation directly.
- for (auto &&VCallSite : CallSites) {
- if (RemarksEnabled)
- VCallSite.emitRemark("single-impl", TheFn->getName());
- VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
- TheFn, VCallSite.CS.getCalledValue()->getType()));
- // This use is no longer unsafe.
- if (VCallSite.NumUnsafeUses)
- --*VCallSite.NumUnsafeUses;
+
+ bool IsExported = false;
+ applySingleImplDevirt(SlotInfo, TheFn, IsExported);
+ if (!IsExported)
+ return false;
+
+ // If the only implementation has local linkage, we must promote to external
+ // to make it visible to ThinLTO objects. We can only get here during the
+ // ThinLTO export phase.
+ if (TheFn->hasLocalLinkage()) {
+ TheFn->setLinkage(GlobalValue::ExternalLinkage);
+ TheFn->setVisibility(GlobalValue::HiddenVisibility);
+ TheFn->setName(TheFn->getName() + "$merged");
}
+
+ Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
+ Res->SingleImplName = TheFn->getName();
+
return true;
}
bool DevirtModule::tryEvaluateFunctionsWithArgs(
MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- ArrayRef<ConstantInt *> Args) {
+ ArrayRef<uint64_t> Args) {
// Evaluate each function and store the result in each target's RetVal
// field.
for (VirtualCallTarget &Target : TargetsForSlot) {
if (Target.Fn->arg_size() != Args.size() + 1)
return false;
- for (unsigned I = 0; I != Args.size(); ++I)
- if (Target.Fn->getFunctionType()->getParamType(I + 1) !=
- Args[I]->getType())
- return false;
Evaluator Eval(M.getDataLayout(), nullptr);
SmallVector<Constant *, 2> EvalArgs;
EvalArgs.push_back(
Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
- EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end());
+ for (unsigned I = 0; I != Args.size(); ++I) {
+ auto *ArgTy = dyn_cast<IntegerType>(
+ Target.Fn->getFunctionType()->getParamType(I + 1));
+ if (!ArgTy)
+ return false;
+ EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
+ }
+
Constant *RetVal;
if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
!isa<ConstantInt>(RetVal))
@@ -495,9 +765,18 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs(
return true;
}
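+// Replace every recorded call with the constant value that all implementations
+// of the slot were proven to return for this argument list.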
+void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
+ uint64_t TheRetVal) {
+ for (auto Call : CSInfo.CallSites)
+ Call.replaceAndErase(
+ "uniform-ret-val", FnName, RemarksEnabled,
+ ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
+ CSInfo.markDevirt();
+}
+
bool DevirtModule::tryUniformRetValOpt(
- IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- MutableArrayRef<VirtualCallSite> CallSites) {
+ MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
+ WholeProgramDevirtResolution::ByArg *Res) {
// Uniform return value optimization. If all functions return the same
// constant, replace all calls with that constant.
uint64_t TheRetVal = TargetsForSlot[0].RetVal;
@@ -505,19 +784,77 @@ bool DevirtModule::tryUniformRetValOpt(
if (Target.RetVal != TheRetVal)
return false;
- auto TheRetValConst = ConstantInt::get(RetType, TheRetVal);
- for (auto Call : CallSites)
- Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(),
- RemarksEnabled, TheRetValConst);
+ if (CSInfo.isExported()) {
+ Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
+ Res->Info = TheRetVal;
+ }
+
+ applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
if (RemarksEnabled)
for (auto &&Target : TargetsForSlot)
Target.WasDevirt = true;
return true;
}
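+// Exported symbol names encode the vtable slot they describe: "__typeid_"
+// followed by the type identifier, the byte offset, any constant call
+// arguments, and the kind of value being exported (for example
+// "unique_member", "byte" or "bit").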
+std::string DevirtModule::getGlobalName(VTableSlot Slot,
+ ArrayRef<uint64_t> Args,
+ StringRef Name) {
+ std::string FullName = "__typeid_";
+ raw_string_ostream OS(FullName);
+ OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
+ for (uint64_t Arg : Args)
+ OS << '_' << Arg;
+ OS << '_' << Name;
+ return OS.str();
+}
+
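+// The exported constant is published as a hidden alias under the name computed
+// by getGlobalName, so importing modules can refer to it by name without
+// seeing its value.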
+void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
+ StringRef Name, Constant *C) {
+ GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
+ getGlobalName(Slot, Args, Name), C, &M);
+ GA->setVisibility(GlobalValue::HiddenVisibility);
+}
+
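+// The imported value is not known until link time, when it is provided by the
+// absolute symbol created by exportGlobal. If AbsWidth is nonzero, the range
+// of that value is bounded via !absolute_symbol metadata.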
+Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
+ StringRef Name, unsigned AbsWidth) {
+ Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty);
+ auto *GV = dyn_cast<GlobalVariable>(C);
+ // We only need to set metadata if the global is newly created, in which
+ // case it would not have hidden visibility.
+ if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility)
+ return C;
+
+ GV->setVisibility(GlobalValue::HiddenVisibility);
+ auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
+ auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
+ auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
+ GV->setMetadata(LLVMContext::MD_absolute_symbol,
+ MDNode::get(M.getContext(), {MinC, MaxC}));
+ };
+ if (AbsWidth == IntPtrTy->getBitWidth())
+ SetAbsRange(~0ull, ~0ull); // Full set.
+ else if (AbsWidth)
+ SetAbsRange(0, 1ull << AbsWidth);
+ return GV;
+}
+
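+// Replace each call with a comparison of the vtable pointer against the single
+// vtable whose target returns IsOne. Using EQ when IsOne is true and NE
+// otherwise makes the i1 result equal to the value the virtual call would have
+// returned.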
+void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
+ bool IsOne,
+ Constant *UniqueMemberAddr) {
+ for (auto &&Call : CSInfo.CallSites) {
+ IRBuilder<> B(Call.CS.getInstruction());
+ Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
+ Call.VTable, UniqueMemberAddr);
+ Cmp = B.CreateZExt(Cmp, Call.CS->getType());
+ Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp);
+ }
+ CSInfo.markDevirt();
+}
+
bool DevirtModule::tryUniqueRetValOpt(
unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- MutableArrayRef<VirtualCallSite> CallSites) {
+ CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
+ VTableSlot Slot, ArrayRef<uint64_t> Args) {
// IsOne controls whether we look for a 0 or a 1.
auto tryUniqueRetValOptFor = [&](bool IsOne) {
const TypeMemberInfo *UniqueMember = nullptr;
@@ -533,16 +870,23 @@ bool DevirtModule::tryUniqueRetValOpt(
// checked for a uniform return value in tryUniformRetValOpt.
assert(UniqueMember);
- // Replace each call with the comparison.
- for (auto &&Call : CallSites) {
- IRBuilder<> B(Call.CS.getInstruction());
- Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy);
- OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset);
- Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
- Call.VTable, OneAddr);
- Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(),
- RemarksEnabled, Cmp);
+ Constant *UniqueMemberAddr =
+ ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
+ UniqueMemberAddr = ConstantExpr::getGetElementPtr(
+ Int8Ty, UniqueMemberAddr,
+ ConstantInt::get(Int64Ty, UniqueMember->Offset));
+
+ if (CSInfo.isExported()) {
+ Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
+ Res->Info = IsOne;
+
+ exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
}
+
+ // Replace each call with the comparison.
+ applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
+ UniqueMemberAddr);
+
// Update devirtualization statistics for targets.
if (RemarksEnabled)
for (auto &&Target : TargetsForSlot)
@@ -560,9 +904,30 @@ bool DevirtModule::tryUniqueRetValOpt(
return false;
}
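+// Replace each call with a load of the precomputed return value stored next to
+// the vtable: a masked bit test at Byte/Bit for i1 returns, otherwise a load
+// of the integer at offset Byte from the vtable pointer.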
+void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
+ Constant *Byte, Constant *Bit) {
+ for (auto Call : CSInfo.CallSites) {
+ auto *RetType = cast<IntegerType>(Call.CS.getType());
+ IRBuilder<> B(Call.CS.getInstruction());
+ Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte);
+ if (RetType->getBitWidth() == 1) {
+ Value *Bits = B.CreateLoad(Addr);
+ Value *BitsAndBit = B.CreateAnd(Bits, Bit);
+ auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
+ Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
+ IsBitSet);
+ } else {
+ Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
+ Value *Val = B.CreateLoad(RetType, ValAddr);
+ Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val);
+ }
+ }
+ CSInfo.markDevirt();
+}
+
bool DevirtModule::tryVirtualConstProp(
- MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- ArrayRef<VirtualCallSite> CallSites) {
+ MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res, VTableSlot Slot) {
// This only works if the function returns an integer.
auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
if (!RetType)
@@ -571,55 +936,38 @@ bool DevirtModule::tryVirtualConstProp(
if (BitWidth > 64)
return false;
- // Make sure that each function does not access memory, takes at least one
- // argument, does not use its first argument (which we assume is 'this'),
- // and has the same return type.
+ // Make sure that each function is defined, does not access memory, takes at
+ // least one argument, does not use its first argument (which we assume is
+ // 'this'), and has the same return type.
+ //
+ // Note that we test whether this copy of the function is readnone, rather
+ // than testing function attributes, which must hold for any copy of the
+ // function, even a less optimized version substituted at link time. This is
+ // sound because the virtual constant propagation optimizations effectively
+ // inline all implementations of the virtual function into each call site,
+ // rather than using function attributes to perform local optimization.
for (VirtualCallTarget &Target : TargetsForSlot) {
- if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() ||
- !Target.Fn->arg_begin()->use_empty() ||
+ if (Target.Fn->isDeclaration() ||
+ computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
+ MAK_ReadNone ||
+ Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
Target.Fn->getReturnType() != RetType)
return false;
}
- // Group call sites by the list of constant arguments they pass.
- // The comparator ensures deterministic ordering.
- struct ByAPIntValue {
- bool operator()(const std::vector<ConstantInt *> &A,
- const std::vector<ConstantInt *> &B) const {
- return std::lexicographical_compare(
- A.begin(), A.end(), B.begin(), B.end(),
- [](ConstantInt *AI, ConstantInt *BI) {
- return AI->getValue().ult(BI->getValue());
- });
- }
- };
- std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>,
- ByAPIntValue>
- VCallSitesByConstantArg;
- for (auto &&VCallSite : CallSites) {
- std::vector<ConstantInt *> Args;
- if (VCallSite.CS.getType() != RetType)
- continue;
- for (auto &&Arg :
- make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) {
- if (!isa<ConstantInt>(Arg))
- break;
- Args.push_back(cast<ConstantInt>(&Arg));
- }
- if (Args.size() + 1 != VCallSite.CS.arg_size())
- continue;
-
- VCallSitesByConstantArg[Args].push_back(VCallSite);
- }
-
- for (auto &&CSByConstantArg : VCallSitesByConstantArg) {
+ for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
continue;
- if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second))
+ WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
+ if (Res)
+ ResByArg = &Res->ResByArg[CSByConstantArg.first];
+
+ if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
continue;
- if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
+ if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
+ ResByArg, Slot, CSByConstantArg.first))
continue;
// Find an allocation offset in bits in all vtables associated with the
@@ -659,26 +1007,20 @@ bool DevirtModule::tryVirtualConstProp(
for (auto &&Target : TargetsForSlot)
Target.WasDevirt = true;
- // Rewrite each call to a load from OffsetByte/OffsetBit.
- for (auto Call : CSByConstantArg.second) {
- IRBuilder<> B(Call.CS.getInstruction());
- Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte);
- if (BitWidth == 1) {
- Value *Bits = B.CreateLoad(Addr);
- Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
- Value *BitsAndBit = B.CreateAnd(Bits, Bit);
- auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
- Call.replaceAndErase("virtual-const-prop-1-bit",
- TargetsForSlot[0].Fn->getName(),
- RemarksEnabled, IsBitSet);
- } else {
- Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
- Value *Val = B.CreateLoad(RetType, ValAddr);
- Call.replaceAndErase("virtual-const-prop",
- TargetsForSlot[0].Fn->getName(),
- RemarksEnabled, Val);
- }
+ Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
+ Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
+
+ if (CSByConstantArg.second.isExported()) {
+ ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
+ exportGlobal(Slot, CSByConstantArg.first, "byte",
+ ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy));
+ exportGlobal(Slot, CSByConstantArg.first, "bit",
+ ConstantExpr::getIntToPtr(BitConst, Int8PtrTy));
}
+
+ // Rewrite each call to a load from OffsetByte/OffsetBit.
+ applyVirtualConstProp(CSByConstantArg.second,
+ TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
}
return true;
}
@@ -733,7 +1075,11 @@ bool DevirtModule::areRemarksEnabled() {
if (FL.empty())
return false;
const Function &Fn = FL.front();
- auto DI = OptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), "");
+
+ const auto &BBL = Fn.getBasicBlockList();
+ if (BBL.empty())
+ return false;
+ auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
return DI.isEnabled();
}
@@ -766,8 +1112,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
if (SeenPtrs.insert(Ptr).second) {
for (DevirtCallSite Call : DevirtCalls) {
- CallSlots[{TypeId, Call.Offset}].push_back(
- {CI->getArgOperand(0), Call.CS, nullptr});
+ CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0),
+ Call.CS, nullptr);
}
}
}
@@ -853,14 +1199,79 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
if (HasNonCallUses)
++NumUnsafeUses;
for (DevirtCallSite Call : DevirtCalls) {
- CallSlots[{TypeId, Call.Offset}].push_back(
- {Ptr, Call.CS, &NumUnsafeUses});
+ CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
+ &NumUnsafeUses);
}
CI->eraseFromParent();
}
}
+void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
+ const TypeIdSummary *TidSummary =
+ ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString());
+ if (!TidSummary)
+ return;
+ auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
+ if (ResI == TidSummary->WPDRes.end())
+ return;
+ const WholeProgramDevirtResolution &Res = ResI->second;
+
+ if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
+ // The type of the function in the declaration is irrelevant because every
+ // call site will cast it to the correct type.
+ auto *SingleImpl = M.getOrInsertFunction(
+ Res.SingleImplName, Type::getVoidTy(M.getContext()));
+
+ // This is the import phase so we should not be exporting anything.
+ bool IsExported = false;
+ applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
+ assert(!IsExported);
+ }
+
+ for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
+ auto I = Res.ResByArg.find(CSByConstantArg.first);
+ if (I == Res.ResByArg.end())
+ continue;
+ auto &ResByArg = I->second;
+ // FIXME: We should figure out what to do about the "function name" argument
+ // to the apply* functions, as the function names are unavailable during the
+ // importing phase. For now we just pass the empty string. This does not
+ // impact correctness because the function names are just used for remarks.
+ switch (ResByArg.TheKind) {
+ case WholeProgramDevirtResolution::ByArg::UniformRetVal:
+ applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
+ break;
+ case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
+ Constant *UniqueMemberAddr =
+ importGlobal(Slot, CSByConstantArg.first, "unique_member");
+ applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
+ UniqueMemberAddr);
+ break;
+ }
+ case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
+ Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32);
+ Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty);
+ Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8);
+ Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty);
+ applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
+ break;
+ }
+ default:
+ break;
+ }
+ }
+}
+
+void DevirtModule::removeRedundantTypeTests() {
+ auto True = ConstantInt::getTrue(M.getContext());
+ for (auto &&U : NumUnsafeUsesForTypeTest) {
+ if (U.second == 0) {
+ U.first->replaceAllUsesWith(True);
+ U.first->eraseFromParent();
+ }
+ }
+}
+
bool DevirtModule::run() {
Function *TypeTestFunc =
M.getFunction(Intrinsic::getName(Intrinsic::type_test));
@@ -868,7 +1279,11 @@ bool DevirtModule::run() {
M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
- if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
+ // Normally if there are no users of the devirtualization intrinsics in the
+ // module, this pass has nothing to do. But if we are exporting, we also need
+ // to handle any users that appear only in the function summaries.
+ if (!ExportSummary &&
+ (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
AssumeFunc->use_empty()) &&
(!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
return false;
@@ -879,6 +1294,17 @@ bool DevirtModule::run() {
if (TypeCheckedLoadFunc)
scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
+ if (ImportSummary) {
+ for (auto &S : CallSlots)
+ importResolution(S.first, S.second);
+
+ removeRedundantTypeTests();
+
+ // The rest of the code is only necessary when exporting or during regular
+ // LTO, so we are done.
+ return true;
+ }
+
// Rebuild type metadata into a map for easy lookup.
std::vector<VTableBits> Bits;
DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
@@ -886,6 +1312,53 @@ bool DevirtModule::run() {
if (TypeIdMap.empty())
return true;
+ // Collect information from summary about which calls to try to devirtualize.
+ if (ExportSummary) {
+ DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
+ for (auto &P : TypeIdMap) {
+ if (auto *TypeId = dyn_cast<MDString>(P.first))
+ MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
+ TypeId);
+ }
+
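+ // Function summaries refer to type identifiers by GUID, so translate each
+ // summary's virtual call records back to the MDString type ids present in
+ // this module and record them in CallSlots.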
+ for (auto &P : *ExportSummary) {
+ for (auto &S : P.second) {
+ auto *FS = dyn_cast<FunctionSummary>(S.get());
+ if (!FS)
+ continue;
+ // FIXME: Only add live functions.
+ for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
+ for (Metadata *MD : MetadataByGUID[VF.GUID]) {
+ CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers =
+ true;
+ }
+ }
+ for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
+ for (Metadata *MD : MetadataByGUID[VF.GUID]) {
+ CallSlots[{MD, VF.Offset}]
+ .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS);
+ }
+ }
+ for (const FunctionSummary::ConstVCall &VC :
+ FS->type_test_assume_const_vcalls()) {
+ for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
+ CallSlots[{MD, VC.VFunc.Offset}]
+ .ConstCSInfo[VC.Args]
+ .SummaryHasTypeTestAssumeUsers = true;
+ }
+ }
+ for (const FunctionSummary::ConstVCall &VC :
+ FS->type_checked_load_const_vcalls()) {
+ for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
+ CallSlots[{MD, VC.VFunc.Offset}]
+ .ConstCSInfo[VC.Args]
+ .SummaryTypeCheckedLoadUsers.push_back(FS);
+ }
+ }
+ }
+ }
+ }
+
// For each (type, offset) pair:
bool DidVirtualConstProp = false;
std::map<std::string, Function*> DevirtTargets;
@@ -894,19 +1367,39 @@ bool DevirtModule::run() {
// function implementation at offset S.first.ByteOffset, and add to
// TargetsForSlot.
std::vector<VirtualCallTarget> TargetsForSlot;
- if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
- S.first.ByteOffset))
- continue;
-
- if (!trySingleImplDevirt(TargetsForSlot, S.second) &&
- tryVirtualConstProp(TargetsForSlot, S.second))
+ if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
+ S.first.ByteOffset)) {
+ WholeProgramDevirtResolution *Res = nullptr;
+ if (ExportSummary && isa<MDString>(S.first.TypeID))
+ Res = &ExportSummary
+ ->getOrInsertTypeIdSummary(
+ cast<MDString>(S.first.TypeID)->getString())
+ .WPDRes[S.first.ByteOffset];
+
+ if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) &&
+ tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first))
DidVirtualConstProp = true;
- // Collect functions devirtualized at least for one call site for stats.
- if (RemarksEnabled)
- for (const auto &T : TargetsForSlot)
- if (T.WasDevirt)
- DevirtTargets[T.Fn->getName()] = T.Fn;
+ // Collect functions devirtualized at least for one call site for stats.
+ if (RemarksEnabled)
+ for (const auto &T : TargetsForSlot)
+ if (T.WasDevirt)
+ DevirtTargets[T.Fn->getName()] = T.Fn;
+ }
+
+ // CFI-specific: if we are exporting and any llvm.type.checked.load
+ // intrinsics were *not* devirtualized, we need to add the resulting
+ // llvm.type.test intrinsics to the function summaries so that the
+ // LowerTypeTests pass will export them.
+ if (ExportSummary && isa<MDString>(S.first.TypeID)) {
+ auto GUID =
+ GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
+ for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
+ FS->addTypeTest(GUID);
+ for (auto &CCS : S.second.ConstCSInfo)
+ for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers)
+ FS->addTypeTest(GUID);
+ }
}
if (RemarksEnabled) {
@@ -914,23 +1407,12 @@ bool DevirtModule::run() {
for (const auto &DT : DevirtTargets) {
Function *F = DT.second;
DISubprogram *SP = F->getSubprogram();
- DebugLoc DL = SP ? DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc();
- emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL,
+ emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP,
Twine("devirtualized ") + F->getName());
}
}
- // If we were able to eliminate all unsafe uses for a type checked load,
- // eliminate the type test by replacing it with true.
- if (TypeCheckedLoadFunc) {
- auto True = ConstantInt::getTrue(M.getContext());
- for (auto &&U : NumUnsafeUsesForTypeTest) {
- if (U.second == 0) {
- U.first->replaceAllUsesWith(True);
- U.first->eraseFromParent();
- }
- }
- }
+ removeRedundantTypeTests();
// Rebuild each global we touched as part of virtual constant propagation to
// include the before and after bytes.
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 2d34c1cc74bd8..174ec8036274e 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -902,7 +902,7 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
- // Addition of two 2's compliment numbers having opposite signs will never
+ // Addition of two 2's complement numbers having opposite signs will never
// overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1]))
@@ -939,7 +939,7 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
- // Subtraction of two 2's compliment numbers having identical signs will
+ // Subtraction of two 2's complement numbers having identical signs will
// never overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
@@ -1042,43 +1042,42 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Value *V = SimplifyUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
- const APInt *Val;
- if (match(RHS, m_APInt(Val))) {
- // X + (signbit) --> X ^ signbit
- if (Val->isSignBit())
+ const APInt *RHSC;
+ if (match(RHS, m_APInt(RHSC))) {
+ if (RHSC->isSignBit()) {
+ // If wrapping is not allowed, then the addition must set the sign bit:
+ // X + (signbit) --> X | signbit
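+ // (With nuw, the sign bit of X must already be clear; with nsw, X must be
+ // non-negative for the add not to overflow, so the bit is clear too.)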
+ if (I.hasNoSignedWrap() || I.hasNoUnsignedWrap())
+ return BinaryOperator::CreateOr(LHS, RHS);
+
+ // If wrapping is allowed, then the addition flips the sign bit of LHS:
+ // X + (signbit) --> X ^ signbit
return BinaryOperator::CreateXor(LHS, RHS);
+ }
// Is this add the last step in a convoluted sext?
Value *X;
const APInt *C;
if (match(LHS, m_ZExt(m_Xor(m_Value(X), m_APInt(C)))) &&
C->isMinSignedValue() &&
- C->sext(LHS->getType()->getScalarSizeInBits()) == *Val) {
+ C->sext(LHS->getType()->getScalarSizeInBits()) == *RHSC) {
// add(zext(xor i16 X, -32768), -32768) --> sext X
return CastInst::Create(Instruction::SExt, X, LHS->getType());
}
- if (Val->isNegative() &&
+ if (RHSC->isNegative() &&
match(LHS, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C)))) &&
- Val->sge(-C->sext(Val->getBitWidth()))) {
+ RHSC->sge(-C->sext(RHSC->getBitWidth()))) {
// (add (zext (add nuw X, C)), Val) -> (zext (add nuw X, C+Val))
- return CastInst::Create(
- Instruction::ZExt,
- Builder->CreateNUWAdd(
- X, Constant::getIntegerValue(X->getType(),
- *C + Val->trunc(C->getBitWidth()))),
- I.getType());
+ Constant *NewC =
+ ConstantInt::get(X->getType(), *C + RHSC->trunc(C->getBitWidth()));
+ return new ZExtInst(Builder->CreateNUWAdd(X, NewC), I.getType());
}
}
// FIXME: Use the match above instead of dyn_cast to allow these transforms
// for splat vectors.
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
- // See if SimplifyDemandedBits can simplify this. This handles stuff like
- // (X & 254)+1 -> (X&254)|1
- if (SimplifyDemandedInstructionBits(I))
- return &I;
-
// zext(bool) + C -> bool ? C + 1 : C
if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
if (ZI->getSrcTy()->isIntegerTy(1))
@@ -1129,8 +1128,8 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
}
}
- if (isa<Constant>(RHS) && isa<PHINode>(LHS))
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (isa<Constant>(RHS))
+ if (Instruction *NV = foldOpWithConstantIntoOperand(I))
return NV;
if (I.getType()->getScalarType()->isIntegerTy(1))
@@ -1201,11 +1200,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
return BinaryOperator::CreateAnd(NewAdd, C2);
}
}
-
- // Try to fold constant add into select arguments.
- if (SelectInst *SI = dyn_cast<SelectInst>(LHS))
- if (Instruction *R = FoldOpIntoSelect(I, SI))
- return R;
}
// add (select X 0 (sub n A)) A --> select X A n
@@ -1253,7 +1247,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// (add (sext x), (sext y)) --> (sext (add int x, y))
if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) {
- // Only do this if x/y have the same type, if at last one of them has a
+ // Only do this if x/y have the same type, if at least one of them has a
// single use (so we don't increase the number of sexts), and if the
// integer add will not overflow.
if (LHSConv->getOperand(0)->getType() ==
@@ -1290,7 +1284,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// (add (zext x), (zext y)) --> (zext (add int x, y))
if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) {
- // Only do this if x/y have the same type, if at last one of them has a
+ // Only do this if x/y have the same type, if at least one of them has a
// single use (so we don't increase the number of zexts), and if the
// integer add will not overflow.
if (LHSConv->getOperand(0)->getType() ==
@@ -1311,13 +1305,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
{
Value *A = nullptr, *B = nullptr;
if (match(RHS, m_Xor(m_Value(A), m_Value(B))) &&
- (match(LHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(LHS, m_And(m_Specific(B), m_Specific(A)))))
+ match(LHS, m_c_And(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
if (match(LHS, m_Xor(m_Value(A), m_Value(B))) &&
- (match(RHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(RHS, m_And(m_Specific(B), m_Specific(A)))))
+ match(RHS, m_c_And(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
}
@@ -1325,8 +1317,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
{
Value *A = nullptr, *B = nullptr;
if (match(RHS, m_Or(m_Value(A), m_Value(B))) &&
- (match(LHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(LHS, m_And(m_Specific(B), m_Specific(A))))) {
+ match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) {
auto *New = BinaryOperator::CreateAdd(A, B);
New->setHasNoSignedWrap(I.hasNoSignedWrap());
New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
@@ -1334,8 +1325,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
}
if (match(LHS, m_Or(m_Value(A), m_Value(B))) &&
- (match(RHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(RHS, m_And(m_Specific(B), m_Specific(A))))) {
+ match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) {
auto *New = BinaryOperator::CreateAdd(A, B);
New->setHasNoSignedWrap(I.hasNoSignedWrap());
New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
@@ -1394,6 +1384,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
// Check for (fadd double (sitofp x), y), see if we can merge this into an
// integer add followed by a promotion.
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
+ Value *LHSIntVal = LHSConv->getOperand(0);
+
// (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
// ... if the constant fits in the integer value. This is useful for things
// like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
@@ -1401,12 +1393,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
// instcombined.
if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) {
Constant *CI =
- ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
+ ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
if (LHSConv->hasOneUse() &&
ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
- WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) {
+ WillNotOverflowSignedAdd(LHSIntVal, CI, I)) {
// Insert the new integer add.
- Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
+ Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal,
CI, "addconv");
return new SIToFPInst(NewAdd, I.getType());
}
@@ -1414,17 +1406,17 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
// (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
- // Only do this if x/y have the same type, if at last one of them has a
+ Value *RHSIntVal = RHSConv->getOperand(0);
+
+ // Only do this if x/y have the same type, if at least one of them has a
// single use (so we don't increase the number of int->fp conversions),
// and if the integer add will not overflow.
- if (LHSConv->getOperand(0)->getType() ==
- RHSConv->getOperand(0)->getType() &&
+ if (LHSIntVal->getType() == RHSIntVal->getType() &&
(LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
- WillNotOverflowSignedAdd(LHSConv->getOperand(0),
- RHSConv->getOperand(0), I)) {
+ WillNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) {
// Insert the new integer add.
- Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
- RHSConv->getOperand(0),"addconv");
+ Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal,
+ RHSIntVal, "addconv");
return new SIToFPInst(NewAdd, I.getType());
}
}
@@ -1562,7 +1554,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return Res;
}
- if (I.getType()->isIntegerTy(1))
+ if (I.getType()->getScalarType()->isIntegerTy(1))
return BinaryOperator::CreateXor(Op0, Op1);
// Replace (-1 - A) with (~A).
@@ -1580,14 +1572,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
+ // Try to fold constant sub into PHI values.
+ if (PHINode *PN = dyn_cast<PHINode>(Op1))
+ if (Instruction *R = foldOpIntoPhi(I, PN))
+ return R;
+
// C-(X+C2) --> (C-C2)-X
Constant *C2;
if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))
return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
- if (SimplifyDemandedInstructionBits(I))
- return &I;
-
// Fold (sub 0, (zext bool to B)) --> (sext bool to B)
if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X))))
if (X->getType()->getScalarType()->isIntegerTy(1))
@@ -1622,11 +1616,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
// zero.
- if ((*Op0C + 1).isPowerOf2()) {
- APInt KnownZero(BitWidth, 0);
- APInt KnownOne(BitWidth, 0);
- computeKnownBits(&I, KnownZero, KnownOne, 0, &I);
- if ((*Op0C | KnownZero).isAllOnesValue())
+ if (Op0C->isMask()) {
+ APInt RHSKnownZero(BitWidth, 0);
+ APInt RHSKnownOne(BitWidth, 0);
+ computeKnownBits(Op1, RHSKnownZero, RHSKnownOne, 0, &I);
+ if ((*Op0C | RHSKnownZero).isAllOnesValue())
return BinaryOperator::CreateXor(Op1, Op0);
}
}
@@ -1634,8 +1628,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
{
Value *Y;
// X-(X+Y) == -Y X-(Y+X) == -Y
- if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) ||
- match(Op1, m_Add(m_Value(Y), m_Specific(Op0))))
+ if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y))))
return BinaryOperator::CreateNeg(Y);
// (X-Y)-X == -Y
@@ -1645,18 +1638,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// (sub (or A, B) (xor A, B)) --> (and A, B)
{
- Value *A = nullptr, *B = nullptr;
+ Value *A, *B;
if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
- (match(Op0, m_Or(m_Specific(A), m_Specific(B))) ||
- match(Op0, m_Or(m_Specific(B), m_Specific(A)))))
+ match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateAnd(A, B);
}
- if (Op0->hasOneUse()) {
- Value *Y = nullptr;
+ {
+ Value *Y;
// ((X | Y) - X) --> (~X & Y)
- if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) ||
- match(Op0, m_Or(m_Specific(Op1), m_Value(Y))))
+ if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1)))))
return BinaryOperator::CreateAnd(
Y, Builder->CreateNot(Op1, Op1->getName() + ".not"));
}
@@ -1664,7 +1655,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
if (Op1->hasOneUse()) {
Value *X = nullptr, *Y = nullptr, *Z = nullptr;
Constant *C = nullptr;
- Constant *CI = nullptr;
// (X - (Y - Z)) --> (X + (Z - Y)).
if (match(Op1, m_Sub(m_Value(Y), m_Value(Z))))
@@ -1673,8 +1663,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// (X - (X & Y)) --> (X & ~Y)
//
- if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) ||
- match(Op1, m_And(m_Specific(Op0), m_Value(Y))))
+ if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0))))
return BinaryOperator::CreateAnd(Op0,
Builder->CreateNot(Y, Y->getName() + ".not"));
@@ -1702,14 +1691,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// X - A*-B -> X + A*B
// X - -A*B -> X + A*B
Value *A, *B;
- if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) ||
- match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B))))
+ Constant *CI;
+ if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B)))))
return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B));
// X - A*CI -> X + A*-CI
- // X - CI*A -> X + A*-CI
- if (match(Op1, m_Mul(m_Value(A), m_Constant(CI))) ||
- match(Op1, m_Mul(m_Constant(CI), m_Value(A)))) {
+ // No need to handle the commuted multiply because multiply handling will
+ // ensure the constant is moved to the right-hand side.
+ if (match(Op1, m_Mul(m_Value(A), m_Constant(CI)))) {
Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI));
return BinaryOperator::CreateAdd(Op0, NewMul);
}
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index da5384a86aac6..b2a41c699202a 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -137,9 +137,8 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) {
}
/// This handles expressions of the form ((val OP C1) & C2). Where
-/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is
-/// guaranteed to be a binary operator.
-Instruction *InstCombiner::OptAndOp(Instruction *Op,
+/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'.
+Instruction *InstCombiner::OptAndOp(BinaryOperator *Op,
ConstantInt *OpRHS,
ConstantInt *AndRHS,
BinaryOperator &TheAnd) {
@@ -149,6 +148,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op,
Together = ConstantExpr::getAnd(AndRHS, OpRHS);
switch (Op->getOpcode()) {
+ default: break;
case Instruction::Xor:
if (Op->hasOneUse()) {
// (X ^ C1) & C2 --> (X & C2) ^ (C1&C2)
@@ -159,13 +159,6 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op,
break;
case Instruction::Or:
if (Op->hasOneUse()){
- if (Together != OpRHS) {
- // (X | C1) & C2 --> (X | (C1&C2)) & C2
- Value *Or = Builder->CreateOr(X, Together);
- Or->takeName(Op);
- return BinaryOperator::CreateAnd(Or, AndRHS);
- }
-
ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together);
if (TogetherCI && !TogetherCI->isZero()){
// (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1
@@ -302,178 +295,91 @@ Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo);
}
-/// Returns true iff Val consists of one contiguous run of 1s with any number
-/// of 0s on either side. The 1s are allowed to wrap from LSB to MSB,
-/// so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs. 0x0F0F0000 is
-/// not, since all 1s are not contiguous.
-static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) {
- const APInt& V = Val->getValue();
- uint32_t BitWidth = Val->getType()->getBitWidth();
- if (!APIntOps::isShiftedMask(BitWidth, V)) return false;
-
- // look for the first zero bit after the run of ones
- MB = BitWidth - ((V - 1) ^ V).countLeadingZeros();
- // look for the first non-zero bit
- ME = V.getActiveBits();
- return true;
-}
-
-/// This is part of an expression (LHS +/- RHS) & Mask, where isSub determines
-/// whether the operator is a sub. If we can fold one of the following xforms:
+/// Classify (icmp eq (A & B), C) and (icmp ne (A & B), C) as matching patterns
+/// that can be simplified.
+/// One of A and B is considered the mask. The other is the value. This is
+/// described as the "AMask" or "BMask" part of the enum. If the enum contains
+/// only "Mask", then both A and B can be considered masks. If A is the mask,
+/// then it was proven that (A & C) == C. This is trivial if C == A or C == 0.
+/// If both A and C are constants, this proof is also easy.
+/// For the following explanations, we assume that A is the mask.
///
-/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask
-/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0
-/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0
+/// "AllOnes" declares that the comparison is true only if (A & B) == A or all
+/// bits of A are set in B.
+/// Example: (icmp eq (A & 3), 3) -> AMask_AllOnes
///
-/// return (A +/- B).
+/// "AllZeros" declares that the comparison is true only if (A & B) == 0 or all
+/// bits of A are cleared in B.
+/// Example: (icmp eq (A & 3), 0) -> Mask_AllZeros
+///
+/// "Mixed" declares that (A & B) == C and C might or might not contain any
+/// number of one bits and zero bits.
+/// Example: (icmp eq (A & 3), 1) -> AMask_Mixed
+///
+/// "Not" means that in above descriptions "==" should be replaced by "!=".
+/// Example: (icmp ne (A & 3), 3) -> AMask_NotAllOnes
///
-Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS,
- ConstantInt *Mask, bool isSub,
- Instruction &I) {
- Instruction *LHSI = dyn_cast<Instruction>(LHS);
- if (!LHSI || LHSI->getNumOperands() != 2 ||
- !isa<ConstantInt>(LHSI->getOperand(1))) return nullptr;
-
- ConstantInt *N = cast<ConstantInt>(LHSI->getOperand(1));
-
- switch (LHSI->getOpcode()) {
- default: return nullptr;
- case Instruction::And:
- if (ConstantExpr::getAnd(N, Mask) == Mask) {
- // If the AndRHS is a power of two minus one (0+1+), this is simple.
- if ((Mask->getValue().countLeadingZeros() +
- Mask->getValue().countPopulation()) ==
- Mask->getValue().getBitWidth())
- break;
-
- // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+
- // part, we don't need any explicit masks to take them out of A. If that
- // is all N is, ignore it.
- uint32_t MB = 0, ME = 0;
- if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive
- uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth();
- APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1));
- if (MaskedValueIsZero(RHS, Mask, 0, &I))
- break;
- }
- }
- return nullptr;
- case Instruction::Or:
- case Instruction::Xor:
- // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0
- if ((Mask->getValue().countLeadingZeros() +
- Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth()
- && ConstantExpr::getAnd(N, Mask)->isNullValue())
- break;
- return nullptr;
- }
-
- if (isSub)
- return Builder->CreateSub(LHSI->getOperand(0), RHS, "fold");
- return Builder->CreateAdd(LHSI->getOperand(0), RHS, "fold");
-}
-
-/// enum for classifying (icmp eq (A & B), C) and (icmp ne (A & B), C)
-/// One of A and B is considered the mask, the other the value. This is
-/// described as the "AMask" or "BMask" part of the enum. If the enum
-/// contains only "Mask", then both A and B can be considered masks.
-/// If A is the mask, then it was proven, that (A & C) == C. This
-/// is trivial if C == A, or C == 0. If both A and C are constants, this
-/// proof is also easy.
-/// For the following explanations we assume that A is the mask.
-/// The part "AllOnes" declares, that the comparison is true only
-/// if (A & B) == A, or all bits of A are set in B.
-/// Example: (icmp eq (A & 3), 3) -> FoldMskICmp_AMask_AllOnes
-/// The part "AllZeroes" declares, that the comparison is true only
-/// if (A & B) == 0, or all bits of A are cleared in B.
-/// Example: (icmp eq (A & 3), 0) -> FoldMskICmp_Mask_AllZeroes
-/// The part "Mixed" declares, that (A & B) == C and C might or might not
-/// contain any number of one bits and zero bits.
-/// Example: (icmp eq (A & 3), 1) -> FoldMskICmp_AMask_Mixed
-/// The Part "Not" means, that in above descriptions "==" should be replaced
-/// by "!=".
-/// Example: (icmp ne (A & 3), 3) -> FoldMskICmp_AMask_NotAllOnes
/// If the mask A contains a single bit, then the following is equivalent:
/// (icmp eq (A & B), A) equals (icmp ne (A & B), 0)
/// (icmp ne (A & B), A) equals (icmp eq (A & B), 0)
enum MaskedICmpType {
- FoldMskICmp_AMask_AllOnes = 1,
- FoldMskICmp_AMask_NotAllOnes = 2,
- FoldMskICmp_BMask_AllOnes = 4,
- FoldMskICmp_BMask_NotAllOnes = 8,
- FoldMskICmp_Mask_AllZeroes = 16,
- FoldMskICmp_Mask_NotAllZeroes = 32,
- FoldMskICmp_AMask_Mixed = 64,
- FoldMskICmp_AMask_NotMixed = 128,
- FoldMskICmp_BMask_Mixed = 256,
- FoldMskICmp_BMask_NotMixed = 512
+ AMask_AllOnes = 1,
+ AMask_NotAllOnes = 2,
+ BMask_AllOnes = 4,
+ BMask_NotAllOnes = 8,
+ Mask_AllZeros = 16,
+ Mask_NotAllZeros = 32,
+ AMask_Mixed = 64,
+ AMask_NotMixed = 128,
+ BMask_Mixed = 256,
+ BMask_NotMixed = 512
};
-/// Return the set of pattern classes (from MaskedICmpType)
-/// that (icmp SCC (A & B), C) satisfies.
-static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C,
- ICmpInst::Predicate SCC)
-{
+/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C)
+/// satisfies.
+static unsigned getMaskedICmpType(Value *A, Value *B, Value *C,
+ ICmpInst::Predicate Pred) {
ConstantInt *ACst = dyn_cast<ConstantInt>(A);
ConstantInt *BCst = dyn_cast<ConstantInt>(B);
ConstantInt *CCst = dyn_cast<ConstantInt>(C);
- bool icmp_eq = (SCC == ICmpInst::ICMP_EQ);
- bool icmp_abit = (ACst && !ACst->isZero() &&
- ACst->getValue().isPowerOf2());
- bool icmp_bbit = (BCst && !BCst->isZero() &&
- BCst->getValue().isPowerOf2());
- unsigned result = 0;
+ bool IsEq = (Pred == ICmpInst::ICMP_EQ);
+ bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2());
+ bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2());
+ unsigned MaskVal = 0;
if (CCst && CCst->isZero()) {
// if C is zero, then both A and B qualify as mask
- result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes |
- FoldMskICmp_AMask_Mixed |
- FoldMskICmp_BMask_Mixed)
- : (FoldMskICmp_Mask_NotAllZeroes |
- FoldMskICmp_AMask_NotMixed |
- FoldMskICmp_BMask_NotMixed));
- if (icmp_abit)
- result |= (icmp_eq ? (FoldMskICmp_AMask_NotAllOnes |
- FoldMskICmp_AMask_NotMixed)
- : (FoldMskICmp_AMask_AllOnes |
- FoldMskICmp_AMask_Mixed));
- if (icmp_bbit)
- result |= (icmp_eq ? (FoldMskICmp_BMask_NotAllOnes |
- FoldMskICmp_BMask_NotMixed)
- : (FoldMskICmp_BMask_AllOnes |
- FoldMskICmp_BMask_Mixed));
- return result;
+ MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed)
+ : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed));
+ if (IsAPow2)
+ MaskVal |= (IsEq ? (AMask_NotAllOnes | AMask_NotMixed)
+ : (AMask_AllOnes | AMask_Mixed));
+ if (IsBPow2)
+ MaskVal |= (IsEq ? (BMask_NotAllOnes | BMask_NotMixed)
+ : (BMask_AllOnes | BMask_Mixed));
+ return MaskVal;
}
+
if (A == C) {
- result |= (icmp_eq ? (FoldMskICmp_AMask_AllOnes |
- FoldMskICmp_AMask_Mixed)
- : (FoldMskICmp_AMask_NotAllOnes |
- FoldMskICmp_AMask_NotMixed));
- if (icmp_abit)
- result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes |
- FoldMskICmp_AMask_NotMixed)
- : (FoldMskICmp_Mask_AllZeroes |
- FoldMskICmp_AMask_Mixed));
- } else if (ACst && CCst &&
- ConstantExpr::getAnd(ACst, CCst) == CCst) {
- result |= (icmp_eq ? FoldMskICmp_AMask_Mixed
- : FoldMskICmp_AMask_NotMixed);
+ MaskVal |= (IsEq ? (AMask_AllOnes | AMask_Mixed)
+ : (AMask_NotAllOnes | AMask_NotMixed));
+ if (IsAPow2)
+ MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed)
+ : (Mask_AllZeros | AMask_Mixed));
+ } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) {
+ MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed);
}
+
if (B == C) {
- result |= (icmp_eq ? (FoldMskICmp_BMask_AllOnes |
- FoldMskICmp_BMask_Mixed)
- : (FoldMskICmp_BMask_NotAllOnes |
- FoldMskICmp_BMask_NotMixed));
- if (icmp_bbit)
- result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes |
- FoldMskICmp_BMask_NotMixed)
- : (FoldMskICmp_Mask_AllZeroes |
- FoldMskICmp_BMask_Mixed));
- } else if (BCst && CCst &&
- ConstantExpr::getAnd(BCst, CCst) == CCst) {
- result |= (icmp_eq ? FoldMskICmp_BMask_Mixed
- : FoldMskICmp_BMask_NotMixed);
- }
- return result;
+ MaskVal |= (IsEq ? (BMask_AllOnes | BMask_Mixed)
+ : (BMask_NotAllOnes | BMask_NotMixed));
+ if (IsBPow2)
+ MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed)
+ : (Mask_AllZeros | BMask_Mixed));
+ } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) {
+ MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed);
+ }
+
+ return MaskVal;
}
/// Convert an analysis of a masked ICmp into its equivalent if all boolean
@@ -482,32 +388,30 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C,
/// involves swapping those bits over.
static unsigned conjugateICmpMask(unsigned Mask) {
unsigned NewMask;
- NewMask = (Mask & (FoldMskICmp_AMask_AllOnes | FoldMskICmp_BMask_AllOnes |
- FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed |
- FoldMskICmp_BMask_Mixed))
+ NewMask = (Mask & (AMask_AllOnes | BMask_AllOnes | Mask_AllZeros |
+ AMask_Mixed | BMask_Mixed))
<< 1;
- NewMask |=
- (Mask & (FoldMskICmp_AMask_NotAllOnes | FoldMskICmp_BMask_NotAllOnes |
- FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed |
- FoldMskICmp_BMask_NotMixed))
- >> 1;
+ NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros |
+ AMask_NotMixed | BMask_NotMixed))
+ >> 1;
return NewMask;
}
-/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
-/// Return the set of pattern classes (from MaskedICmpType)
-/// that both LHS and RHS satisfy.
-static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
- Value*& B, Value*& C,
- Value*& D, Value*& E,
- ICmpInst *LHS, ICmpInst *RHS,
- ICmpInst::Predicate &LHSCC,
- ICmpInst::Predicate &RHSCC) {
- if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0;
+/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E).
+/// Return the set of pattern classes (from MaskedICmpType) that both LHS and
+/// RHS satisfy.
+static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
+ Value *&D, Value *&E, ICmpInst *LHS,
+ ICmpInst *RHS,
+ ICmpInst::Predicate &PredL,
+ ICmpInst::Predicate &PredR) {
+ if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType())
+ return 0;
// vectors are not (yet?) supported
- if (LHS->getOperand(0)->getType()->isVectorTy()) return 0;
+ if (LHS->getOperand(0)->getType()->isVectorTy())
+ return 0;
// Here comes the tricky part:
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
@@ -517,9 +421,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
// above.
Value *L1 = LHS->getOperand(0);
Value *L2 = LHS->getOperand(1);
- Value *L11,*L12,*L21,*L22;
+ Value *L11, *L12, *L21, *L22;
// Check whether the icmp can be decomposed into a bit test.
- if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) {
+ if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) {
L21 = L22 = L1 = nullptr;
} else {
// Look for ANDs in the LHS icmp.
@@ -543,22 +447,26 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
}
  // Bail if LHS was an icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(LHSCC))
+ if (!ICmpInst::isEquality(PredL))
return 0;
Value *R1 = RHS->getOperand(0);
Value *R2 = RHS->getOperand(1);
- Value *R11,*R12;
- bool ok = false;
- if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) {
+ Value *R11, *R12;
+ bool Ok = false;
+ if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) {
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11; D = R12;
+ A = R11;
+ D = R12;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12; D = R11;
+ A = R12;
+ D = R11;
} else {
return 0;
}
- E = R2; R1 = nullptr; ok = true;
+ E = R2;
+ R1 = nullptr;
+ Ok = true;
} else if (R1->getType()->isIntegerTy()) {
if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
// As before, model no mask as a trivial mask if it'll let us do an
@@ -568,46 +476,62 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
}
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11; D = R12; E = R2; ok = true;
+ A = R11;
+ D = R12;
+ E = R2;
+ Ok = true;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12; D = R11; E = R2; ok = true;
+ A = R12;
+ D = R11;
+ E = R2;
+ Ok = true;
}
}
  // Bail if RHS was an icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(RHSCC))
+ if (!ICmpInst::isEquality(PredR))
return 0;
// Look for ANDs on the right side of the RHS icmp.
- if (!ok && R2->getType()->isIntegerTy()) {
+ if (!Ok && R2->getType()->isIntegerTy()) {
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
R11 = R2;
R12 = Constant::getAllOnesValue(R2->getType());
}
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11; D = R12; E = R1; ok = true;
+ A = R11;
+ D = R12;
+ E = R1;
+ Ok = true;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12; D = R11; E = R1; ok = true;
+ A = R12;
+ D = R11;
+ E = R1;
+ Ok = true;
} else {
return 0;
}
}
- if (!ok)
+ if (!Ok)
return 0;
if (L11 == A) {
- B = L12; C = L2;
+ B = L12;
+ C = L2;
} else if (L12 == A) {
- B = L11; C = L2;
+ B = L11;
+ C = L2;
} else if (L21 == A) {
- B = L22; C = L1;
+ B = L22;
+ C = L1;
} else if (L22 == A) {
- B = L21; C = L1;
+ B = L21;
+ C = L1;
}
- unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC);
- unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC);
+ unsigned LeftType = getMaskedICmpType(A, B, C, PredL);
+ unsigned RightType = getMaskedICmpType(A, D, E, PredR);
return LeftType & RightType;
}
@@ -616,12 +540,14 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
llvm::InstCombiner::BuilderTy *Builder) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
- unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS,
- LHSCC, RHSCC);
- if (Mask == 0) return nullptr;
- assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) &&
- "foldLogOpOfMaskedICmpsHelper must return an equality predicate.");
+ ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+ unsigned Mask =
+ getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
+ if (Mask == 0)
+ return nullptr;
+
+ assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
+ "Expected equality predicates for masked type of icmps.");
// In full generality:
// (icmp (A & B) Op C) | (icmp (A & D) Op E)
@@ -642,7 +568,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
Mask = conjugateICmpMask(Mask);
}
- if (Mask & FoldMskICmp_Mask_AllZeroes) {
+ if (Mask & Mask_AllZeros) {
// (icmp eq (A & B), 0) & (icmp eq (A & D), 0)
// -> (icmp eq (A & (B|D)), 0)
Value *NewOr = Builder->CreateOr(B, D);
@@ -653,14 +579,14 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
Value *Zero = Constant::getNullValue(A->getType());
return Builder->CreateICmp(NewCC, NewAnd, Zero);
}
- if (Mask & FoldMskICmp_BMask_AllOnes) {
+ if (Mask & BMask_AllOnes) {
// (icmp eq (A & B), B) & (icmp eq (A & D), D)
// -> (icmp eq (A & (B|D)), (B|D))
Value *NewOr = Builder->CreateOr(B, D);
Value *NewAnd = Builder->CreateAnd(A, NewOr);
return Builder->CreateICmp(NewCC, NewAnd, NewOr);
}
- if (Mask & FoldMskICmp_AMask_AllOnes) {
+ if (Mask & AMask_AllOnes) {
// (icmp eq (A & B), A) & (icmp eq (A & D), A)
// -> (icmp eq (A & (B&D)), A)
Value *NewAnd1 = Builder->CreateAnd(B, D);
@@ -672,11 +598,13 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
// their actual values. This isn't strictly necessary, just a "handle the
// easy cases for now" decision.
ConstantInt *BCst = dyn_cast<ConstantInt>(B);
- if (!BCst) return nullptr;
+ if (!BCst)
+ return nullptr;
ConstantInt *DCst = dyn_cast<ConstantInt>(D);
- if (!DCst) return nullptr;
+ if (!DCst)
+ return nullptr;
- if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) {
+ if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {
// (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and
// (icmp ne (A & B), B) & (icmp ne (A & D), D)
// -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0)
@@ -689,7 +617,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
else if (NewMask == DCst->getValue())
return RHS;
}
- if (Mask & FoldMskICmp_AMask_NotAllOnes) {
+
+ if (Mask & AMask_NotAllOnes) {
// (icmp ne (A & B), B) & (icmp ne (A & D), D)
// -> (icmp ne (A & B), A) or (icmp ne (A & D), A)
// Only valid if one of the masks is a superset of the other (check "B|D" is
@@ -701,7 +630,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
else if (NewMask == DCst->getValue())
return RHS;
}
- if (Mask & FoldMskICmp_BMask_Mixed) {
+
+ if (Mask & BMask_Mixed) {
// (icmp eq (A & B), C) & (icmp eq (A & D), E)
// We already know that B & C == C && D & E == E.
// If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of
@@ -713,23 +643,28 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
// (icmp ne (A & B), B) & (icmp eq (A & D), D)
// with B and D, having a single bit set.
ConstantInt *CCst = dyn_cast<ConstantInt>(C);
- if (!CCst) return nullptr;
+ if (!CCst)
+ return nullptr;
ConstantInt *ECst = dyn_cast<ConstantInt>(E);
- if (!ECst) return nullptr;
- if (LHSCC != NewCC)
+ if (!ECst)
+ return nullptr;
+ if (PredL != NewCC)
CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst));
- if (RHSCC != NewCC)
+ if (PredR != NewCC)
ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst));
+
// If there is a conflict, we should actually return a false for the
// whole construct.
if (((BCst->getValue() & DCst->getValue()) &
(CCst->getValue() ^ ECst->getValue())) != 0)
return ConstantInt::get(LHS->getType(), !IsAnd);
+
Value *NewOr1 = Builder->CreateOr(B, D);
Value *NewOr2 = ConstantExpr::getOr(CCst, ECst);
Value *NewAnd = Builder->CreateAnd(A, NewOr1);
return Builder->CreateICmp(NewCC, NewAnd, NewOr2);
}
+
return nullptr;
}
@@ -789,12 +724,67 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
return Builder->CreateICmp(NewPred, Input, RangeEnd);
}
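+/// Fold (X == C1 || X == C2) or (X != C1 && X != C2) into a single comparison
+/// when C1 and C2 differ by exactly one bit or are consecutive values.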
+static Value *
+foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS,
+ bool JoinedByAnd,
+ InstCombiner::BuilderTy *Builder) {
+ Value *X = LHS->getOperand(0);
+ if (X != RHS->getOperand(0))
+ return nullptr;
+
+ const APInt *C1, *C2;
+ if (!match(LHS->getOperand(1), m_APInt(C1)) ||
+ !match(RHS->getOperand(1), m_APInt(C2)))
+ return nullptr;
+
+ // We only handle (X != C1 && X != C2) and (X == C1 || X == C2).
+ ICmpInst::Predicate Pred = LHS->getPredicate();
+ if (Pred != RHS->getPredicate())
+ return nullptr;
+ if (JoinedByAnd && Pred != ICmpInst::ICMP_NE)
+ return nullptr;
+ if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
+ return nullptr;
+
+ // The larger unsigned constant goes on the right.
+ if (C1->ugt(*C2))
+ std::swap(C1, C2);
+
+ APInt Xor = *C1 ^ *C2;
+ if (Xor.isPowerOf2()) {
+    // If C1 and C2 differ by only one bit, then set that bit in X and
+ // compare against the larger constant:
+ // (X == C1 || X == C2) --> (X | (C1 ^ C2)) == C2
+ // (X != C1 && X != C2) --> (X | (C1 ^ C2)) != C2
+ // We choose an 'or' with a Pow2 constant rather than the inverse mask with
+ // 'and' because that may lead to smaller codegen from a smaller constant.
+ Value *Or = Builder->CreateOr(X, ConstantInt::get(X->getType(), Xor));
+ return Builder->CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2));
+ }
+
+ // Special case: get the ordering right when the values wrap around zero.
+  // That is, we assumed the constants were unsigned when swapping earlier.
+ if (*C1 == 0 && C2->isAllOnesValue())
+ std::swap(C1, C2);
+
+ if (*C1 == *C2 - 1) {
+ // (X == 13 || X == 14) --> X - 13 <=u 1
+ // (X != 13 && X != 14) --> X - 13 >u 1
+ // An 'add' is the canonical IR form, so favor that over a 'sub'.
+ Value *Add = Builder->CreateAdd(X, ConstantInt::get(X->getType(), -(*C1)));
+ auto NewPred = JoinedByAnd ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE;
+ return Builder->CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1));
+ }
+
+ return nullptr;
+}
+
/// Fold (icmp)&(icmp) if possible.
Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
- ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+ ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
// (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
- if (PredicatesFoldable(LHSCC, RHSCC)) {
+ if (PredicatesFoldable(PredL, PredR)) {
if (LHS->getOperand(0) == RHS->getOperand(1) &&
LHS->getOperand(1) == RHS->getOperand(0))
LHS->swapOperands();
@@ -819,86 +809,90 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false))
return V;
+ if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder))
+ return V;
+
// This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
- Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
- ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
- if (!LHSCst || !RHSCst) return nullptr;
+ Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
+ ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+ ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
+ if (!LHSC || !RHSC)
+ return nullptr;
- if (LHSCst == RHSCst && LHSCC == RHSCC) {
+ if (LHSC == RHSC && PredL == PredR) {
// (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C)
// where C is a power of 2 or
// (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0)
- if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) ||
- (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) {
- Value *NewOr = Builder->CreateOr(Val, Val2);
- return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
+ if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) ||
+ (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) {
+ Value *NewOr = Builder->CreateOr(LHS0, RHS0);
+ return Builder->CreateICmp(PredL, NewOr, LHSC);
}
}
// (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2
// where CMAX is the all ones value for the truncated type,
// iff the lower bits of C2 and CA are zero.
- if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC &&
- LHS->hasOneUse() && RHS->hasOneUse()) {
+ if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() &&
+ RHS->hasOneUse()) {
Value *V;
- ConstantInt *AndCst, *SmallCst = nullptr, *BigCst = nullptr;
+ ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr;
// (trunc x) == C1 & (and x, CA) == C2
// (and x, CA) == C2 & (trunc x) == C1
- if (match(Val2, m_Trunc(m_Value(V))) &&
- match(Val, m_And(m_Specific(V), m_ConstantInt(AndCst)))) {
- SmallCst = RHSCst;
- BigCst = LHSCst;
- } else if (match(Val, m_Trunc(m_Value(V))) &&
- match(Val2, m_And(m_Specific(V), m_ConstantInt(AndCst)))) {
- SmallCst = LHSCst;
- BigCst = RHSCst;
+ if (match(RHS0, m_Trunc(m_Value(V))) &&
+ match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) {
+ SmallC = RHSC;
+ BigC = LHSC;
+ } else if (match(LHS0, m_Trunc(m_Value(V))) &&
+ match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) {
+ SmallC = LHSC;
+ BigC = RHSC;
}
- if (SmallCst && BigCst) {
- unsigned BigBitSize = BigCst->getType()->getBitWidth();
- unsigned SmallBitSize = SmallCst->getType()->getBitWidth();
+ if (SmallC && BigC) {
+ unsigned BigBitSize = BigC->getType()->getBitWidth();
+ unsigned SmallBitSize = SmallC->getType()->getBitWidth();
// Check that the low bits are zero.
APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize);
- if ((Low & AndCst->getValue()) == 0 && (Low & BigCst->getValue()) == 0) {
- Value *NewAnd = Builder->CreateAnd(V, Low | AndCst->getValue());
- APInt N = SmallCst->getValue().zext(BigBitSize) | BigCst->getValue();
- Value *NewVal = ConstantInt::get(AndCst->getType()->getContext(), N);
- return Builder->CreateICmp(LHSCC, NewAnd, NewVal);
+ if ((Low & AndC->getValue()) == 0 && (Low & BigC->getValue()) == 0) {
+ Value *NewAnd = Builder->CreateAnd(V, Low | AndC->getValue());
+ APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue();
+ Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N);
+ return Builder->CreateICmp(PredL, NewAnd, NewVal);
}
}
}
// From here on, we only handle:
// (icmp1 A, C1) & (icmp2 A, C2) --> something simpler.
- if (Val != Val2) return nullptr;
+ if (LHS0 != RHS0)
+ return nullptr;
- // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere.
- if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE ||
- RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE ||
- LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE ||
- RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE)
+ // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere.
+ if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE ||
+ PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE ||
+ PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE ||
+ PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE)
return nullptr;
// We can't fold (ugt x, C) & (sgt x, C2).
- if (!PredicatesFoldable(LHSCC, RHSCC))
+ if (!PredicatesFoldable(PredL, PredR))
return nullptr;
// Ensure that the larger constant is on the RHS.
bool ShouldSwap;
- if (CmpInst::isSigned(LHSCC) ||
- (ICmpInst::isEquality(LHSCC) &&
- CmpInst::isSigned(RHSCC)))
- ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue());
+ if (CmpInst::isSigned(PredL) ||
+ (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR)))
+ ShouldSwap = LHSC->getValue().sgt(RHSC->getValue());
else
- ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue());
+ ShouldSwap = LHSC->getValue().ugt(RHSC->getValue());
if (ShouldSwap) {
std::swap(LHS, RHS);
- std::swap(LHSCst, RHSCst);
- std::swap(LHSCC, RHSCC);
+ std::swap(LHSC, RHSC);
+ std::swap(PredL, PredR);
}
// At this point, we know we have two icmp instructions
@@ -907,113 +901,95 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
// icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know
// (from the icmp folding check above), that the two constants
// are not equal and that the larger constant is on the RHS
- assert(LHSCst != RHSCst && "Compares not folded above?");
+ assert(LHSC != RHSC && "Compares not folded above?");
- switch (LHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredL) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_EQ:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13
- case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13
- case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13
+ case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13
+ case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13
return LHS;
}
case ICmpInst::ICMP_NE:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_ULT:
- if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13
- return Builder->CreateICmpULT(Val, LHSCst);
- if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13
- return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(),
+ if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13
+ return Builder->CreateICmpULT(LHS0, LHSC);
+ if (LHSC->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
false, true);
- break; // (X != 13 & X u< 15) -> no change
+ break; // (X != 13 & X u< 15) -> no change
case ICmpInst::ICMP_SLT:
- if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13
- return Builder->CreateICmpSLT(Val, LHSCst);
- break; // (X != 13 & X s< 15) -> no change
- case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15
- case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15
- case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15
+ if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13
+ return Builder->CreateICmpSLT(LHS0, LHSC);
+ break; // (X != 13 & X s< 15) -> no change
+ case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15
+ case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15
+ case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15
return RHS;
case ICmpInst::ICMP_NE:
- // Special case to get the ordering right when the values wrap around
- // zero.
- if (LHSCst->getValue() == 0 && RHSCst->getValue().isAllOnesValue())
- std::swap(LHSCst, RHSCst);
- if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1
- Constant *AddCST = ConstantExpr::getNeg(LHSCst);
- Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off");
- return Builder->CreateICmpUGT(Add, ConstantInt::get(Add->getType(), 1),
- Val->getName()+".cmp");
- }
- break; // (X != 13 & X != 15) -> no change
+ // Potential folds for this case should already be handled.
+ break;
}
break;
case ICmpInst::ICMP_ULT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false
- case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false
+ case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false
return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0);
- case ICmpInst::ICMP_SGT: // (X u< 13 & X s> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13
- case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13
+ case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13
+ case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13
return LHS;
- case ICmpInst::ICMP_SLT: // (X u< 13 & X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SLT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_UGT: // (X s< 13 & X u> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13
- case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13
+ case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13
return LHS;
- case ICmpInst::ICMP_ULT: // (X s< 13 & X u< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_UGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15
- case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15
+ case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15
return RHS;
- case ICmpInst::ICMP_SGT: // (X u> 13 & X s> 15) -> no change
- break;
case ICmpInst::ICMP_NE:
- if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14
- return Builder->CreateICmp(LHSCC, Val, RHSCst);
- break; // (X u> 13 & X != 15) -> no change
- case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1
- return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(),
+ if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14
+ return Builder->CreateICmp(PredL, LHS0, RHSC);
+ break; // (X u> 13 & X != 15) -> no change
+ case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
false, true);
- case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15
- case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15
+ case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15
return RHS;
- case ICmpInst::ICMP_UGT: // (X s> 13 & X u> 15) -> no change
- break;
case ICmpInst::ICMP_NE:
- if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14
- return Builder->CreateICmp(LHSCC, Val, RHSCst);
- break; // (X s> 13 & X != 15) -> no change
- case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1
- return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(),
- true, true);
- case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change
- break;
+ if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14
+ return Builder->CreateICmp(PredL, LHS0, RHSC);
+ break; // (X s> 13 & X != 15) -> no change
+ case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true,
+ true);
}
break;
}
@@ -1314,39 +1290,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
break;
}
- case Instruction::Add:
- // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS.
- // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0
- // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0
- if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I))
- return BinaryOperator::CreateAnd(V, AndRHS);
- if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I))
- return BinaryOperator::CreateAnd(V, AndRHS); // Add commutes
- break;
-
case Instruction::Sub:
- // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS.
- // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0
- // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0
- if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I))
- return BinaryOperator::CreateAnd(V, AndRHS);
-
// -x & 1 -> x & 1
if (AndRHSMask == 1 && match(Op0LHS, m_Zero()))
return BinaryOperator::CreateAnd(Op0RHS, AndRHS);
- // (A - N) & AndRHS -> -N & AndRHS iff A&AndRHS==0 and AndRHS
- // has 1's for all bits that the subtraction with A might affect.
- if (Op0I->hasOneUse() && !match(Op0LHS, m_Zero())) {
- uint32_t BitWidth = AndRHSMask.getBitWidth();
- uint32_t Zeros = AndRHSMask.countLeadingZeros();
- APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros);
-
- if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) {
- Value *NewNeg = Builder->CreateNeg(Op0RHS);
- return BinaryOperator::CreateAnd(NewNeg, AndRHS);
- }
- }
break;
case Instruction::Shl:
@@ -1361,6 +1309,33 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
break;
}
+  // ((C1 OP zext(X)) & C2) -> zext((C1 OP X) & C2) if C2 fits in the bitwidth
+ // of X and OP behaves well when given trunc(C1) and X.
+ switch (Op0I->getOpcode()) {
+ default:
+ break;
+ case Instruction::Xor:
+ case Instruction::Or:
+ case Instruction::Mul:
+ case Instruction::Add:
+ case Instruction::Sub:
+ Value *X;
+ ConstantInt *C1;
+ if (match(Op0I, m_c_BinOp(m_ZExt(m_Value(X)), m_ConstantInt(C1)))) {
+ if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) {
+ auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType());
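+          // Keep the original operand order of Op0I; it matters for the
+          // non-commutative Sub case.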
+ Value *BinOp;
+ if (isa<ZExtInst>(Op0LHS))
+ BinOp = Builder->CreateBinOp(Op0I->getOpcode(), X, TruncC1);
+ else
+ BinOp = Builder->CreateBinOp(Op0I->getOpcode(), TruncC1, X);
+ auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType());
+ auto *And = Builder->CreateAnd(BinOp, TruncC2);
+ return new ZExtInst(And, I.getType());
+ }
+ }
+ }
+
if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1)))
if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I))
return Res;
@@ -1381,10 +1356,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
return BinaryOperator::CreateAnd(NewCast, C3);
}
}
+ }
+ if (isa<Constant>(Op1))
if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))
return FoldedLogic;
- }
if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder))
return DeMorgan;
@@ -1630,15 +1606,15 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D,
/// Fold (icmp)|(icmp) if possible.
Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Instruction *CxtI) {
- ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+ ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
// if K1 and K2 are a one-bit mask.
- ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
+ ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+ ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
- if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero() &&
- RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) {
+ if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() &&
+ RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
BinaryOperator *LAnd = dyn_cast<BinaryOperator>(LHS->getOperand(0));
BinaryOperator *RAnd = dyn_cast<BinaryOperator>(RHS->getOperand(0));
@@ -1680,52 +1656,52 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask.
// This implies all values in the two ranges differ by exactly one bit.
- if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) &&
- LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() &&
- RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() &&
- LHSCst->getValue() == (RHSCst->getValue())) {
+ if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
+ PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() &&
+ LHSC->getType() == RHSC->getType() &&
+ LHSC->getValue() == (RHSC->getValue())) {
Value *LAdd = LHS->getOperand(0);
Value *RAdd = RHS->getOperand(0);
Value *LAddOpnd, *RAddOpnd;
- ConstantInt *LAddCst, *RAddCst;
- if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) &&
- match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) &&
- LAddCst->getValue().ugt(LHSCst->getValue()) &&
- RAddCst->getValue().ugt(LHSCst->getValue())) {
-
- APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue();
- if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) {
- ConstantInt *MaxAddCst = nullptr;
- if (LAddCst->getValue().ult(RAddCst->getValue()))
- MaxAddCst = RAddCst;
+ ConstantInt *LAddC, *RAddC;
+ if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) &&
+ match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) &&
+ LAddC->getValue().ugt(LHSC->getValue()) &&
+ RAddC->getValue().ugt(LHSC->getValue())) {
+
+ APInt DiffC = LAddC->getValue() ^ RAddC->getValue();
+ if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) {
+ ConstantInt *MaxAddC = nullptr;
+ if (LAddC->getValue().ult(RAddC->getValue()))
+ MaxAddC = RAddC;
else
- MaxAddCst = LAddCst;
+ MaxAddC = LAddC;
- APInt RRangeLow = -RAddCst->getValue();
- APInt RRangeHigh = RRangeLow + LHSCst->getValue();
- APInt LRangeLow = -LAddCst->getValue();
- APInt LRangeHigh = LRangeLow + LHSCst->getValue();
+ APInt RRangeLow = -RAddC->getValue();
+ APInt RRangeHigh = RRangeLow + LHSC->getValue();
+ APInt LRangeLow = -LAddC->getValue();
+ APInt LRangeHigh = LRangeLow + LHSC->getValue();
APInt LowRangeDiff = RRangeLow ^ LRangeLow;
APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow
: RRangeLow - LRangeLow;
if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff &&
- RangeDiff.ugt(LHSCst->getValue())) {
- Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst);
+ RangeDiff.ugt(LHSC->getValue())) {
+ Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC);
- Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst);
- Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst);
- return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst));
+ Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskC);
+ Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddC);
+ return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSC));
}
}
}
}
// (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
- if (PredicatesFoldable(LHSCC, RHSCC)) {
+ if (PredicatesFoldable(PredL, PredR)) {
if (LHS->getOperand(0) == RHS->getOperand(1) &&
LHS->getOperand(1) == RHS->getOperand(0))
LHS->swapOperands();
@@ -1743,25 +1719,25 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder))
return V;
- Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
+ Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
if (LHS->hasOneUse() || RHS->hasOneUse()) {
// (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1)
// (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1)
Value *A = nullptr, *B = nullptr;
- if (LHSCC == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero()) {
- B = Val;
- if (RHSCC == ICmpInst::ICMP_ULT && Val == RHS->getOperand(1))
- A = Val2;
- else if (RHSCC == ICmpInst::ICMP_UGT && Val == Val2)
+ if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) {
+ B = LHS0;
+ if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1))
+ A = RHS0;
+ else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0)
A = RHS->getOperand(1);
}
// (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1)
// (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1)
- else if (RHSCC == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) {
- B = Val2;
- if (LHSCC == ICmpInst::ICMP_ULT && Val2 == LHS->getOperand(1))
- A = Val;
- else if (LHSCC == ICmpInst::ICMP_UGT && Val2 == Val)
+ else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
+ B = RHS0;
+ if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1))
+ A = LHS0;
+ else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0)
A = LHS->getOperand(1);
}
if (A && B)
@@ -1778,54 +1754,58 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true))
return V;
+ if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder))
+ return V;
+
// This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
- if (!LHSCst || !RHSCst) return nullptr;
+ if (!LHSC || !RHSC)
+ return nullptr;
- if (LHSCst == RHSCst && LHSCC == RHSCC) {
+ if (LHSC == RHSC && PredL == PredR) {
// (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
- if (LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) {
- Value *NewOr = Builder->CreateOr(Val, Val2);
- return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
+ if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) {
+ Value *NewOr = Builder->CreateOr(LHS0, RHS0);
+ return Builder->CreateICmp(PredL, NewOr, LHSC);
}
}
// (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1)
// iff C2 + CA == C1.
- if (LHSCC == ICmpInst::ICMP_ULT && RHSCC == ICmpInst::ICMP_EQ) {
- ConstantInt *AddCst;
- if (match(Val, m_Add(m_Specific(Val2), m_ConstantInt(AddCst))))
- if (RHSCst->getValue() + AddCst->getValue() == LHSCst->getValue())
- return Builder->CreateICmpULE(Val, LHSCst);
+ if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) {
+ ConstantInt *AddC;
+ if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC))))
+ if (RHSC->getValue() + AddC->getValue() == LHSC->getValue())
+ return Builder->CreateICmpULE(LHS0, LHSC);
}
// From here on, we only handle:
// (icmp1 A, C1) | (icmp2 A, C2) --> something simpler.
- if (Val != Val2) return nullptr;
+ if (LHS0 != RHS0)
+ return nullptr;
- // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere.
- if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE ||
- RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE ||
- LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE ||
- RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE)
+ // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere.
+ if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE ||
+ PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE ||
+ PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE ||
+ PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE)
return nullptr;
// We can't fold (ugt x, C) | (sgt x, C2).
- if (!PredicatesFoldable(LHSCC, RHSCC))
+ if (!PredicatesFoldable(PredL, PredR))
return nullptr;
// Ensure that the larger constant is on the RHS.
bool ShouldSwap;
- if (CmpInst::isSigned(LHSCC) ||
- (ICmpInst::isEquality(LHSCC) &&
- CmpInst::isSigned(RHSCC)))
- ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue());
+ if (CmpInst::isSigned(PredL) ||
+ (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR)))
+ ShouldSwap = LHSC->getValue().sgt(RHSC->getValue());
else
- ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue());
+ ShouldSwap = LHSC->getValue().ugt(RHSC->getValue());
if (ShouldSwap) {
std::swap(LHS, RHS);
- std::swap(LHSCst, RHSCst);
- std::swap(LHSCC, RHSCC);
+ std::swap(LHSC, RHSC);
+ std::swap(PredL, PredR);
}
// At this point, we know we have two icmp instructions
@@ -1834,127 +1814,98 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the
// icmp folding check above), that the two constants are not
// equal.
- assert(LHSCst != RHSCst && "Compares not folded above?");
+ assert(LHSC != RHSC && "Compares not folded above?");
- switch (LHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredL) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_EQ:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_EQ:
- if (LHS->getOperand(0) == RHS->getOperand(0)) {
- // if LHSCst and RHSCst differ only by one bit:
- // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2
- assert(LHSCst->getValue().ule(LHSCst->getValue()));
-
- APInt Xor = LHSCst->getValue() ^ RHSCst->getValue();
- if (Xor.isPowerOf2()) {
- Value *Cst = Builder->getInt(Xor);
- Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst);
- return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst);
- }
- }
-
- if (LHSCst == SubOne(RHSCst)) {
- // (X == 13 | X == 14) -> X-13 <u 2
- Constant *AddCST = ConstantExpr::getNeg(LHSCst);
- Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off");
- AddCST = ConstantExpr::getSub(AddOne(RHSCst), LHSCst);
- return Builder->CreateICmpULT(Add, AddCST);
- }
-
- break; // (X == 13 | X == 15) -> no change
- case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change
- case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change
+ // Potential folds for this case should already be handled.
+ break;
+ case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change
+ case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change
break;
- case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15
- case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15
- case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15
+ case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15
+ case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15
+ case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15
return RHS;
}
break;
case ICmpInst::ICMP_NE:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13
- case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13
- case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13
+ case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13
+ case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13
return LHS;
- case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true
- case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true
- case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true
+ case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true
+ case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true
+ case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true
return Builder->getTrue();
}
case ICmpInst::ICMP_ULT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change
break;
- case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2
- // If RHSCst is [us]MAXINT, it is always false. Not handling
+ case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2
+ // If RHSC is [us]MAXINT, it is always false. Not handling
// this can cause overflow.
- if (RHSCst->isMaxValue(false))
+ if (RHSC->isMaxValue(false))
return LHS;
- return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1,
+ return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1,
false, false);
- case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15
- case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15
+ case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15
+ case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15
return RHS;
- case ICmpInst::ICMP_SLT: // (X u< 13 | X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SLT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change
break;
- case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2
- // If RHSCst is [us]MAXINT, it is always false. Not handling
+ case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2
+ // If RHSC is [us]MAXINT, it is always false. Not handling
// this can cause overflow.
- if (RHSCst->isMaxValue(true))
+ if (RHSC->isMaxValue(true))
return LHS;
- return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1,
- true, false);
- case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15
- case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15
+ return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true,
+ false);
+ case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15
+ case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15
return RHS;
- case ICmpInst::ICMP_ULT: // (X s< 13 | X u< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_UGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13
- case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13
+ case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13
return LHS;
- case ICmpInst::ICMP_SGT: // (X u> 13 | X s> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true
- case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true
+ case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true
+ case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true
return Builder->getTrue();
- case ICmpInst::ICMP_SLT: // (X u> 13 | X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13
- case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13
+ case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13
return LHS;
- case ICmpInst::ICMP_UGT: // (X s> 13 | X u> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true
- case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true
+ case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true
+ case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true
return Builder->getTrue();
- case ICmpInst::ICMP_ULT: // (X s> 13 | X u< 15) -> no change
- break;
}
break;
}
@@ -2100,17 +2051,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
ConstantInt *C1 = nullptr; Value *X = nullptr;
- // (X & C1) | C2 --> (X | C2) & (C1|C2)
- // iff (C1 & C2) == 0.
- if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) &&
- (RHS->getValue() & C1->getValue()) != 0 &&
- Op0->hasOneUse()) {
- Value *Or = Builder->CreateOr(X, RHS);
- Or->takeName(Op0);
- return BinaryOperator::CreateAnd(Or,
- Builder->getInt(RHS->getValue() | C1->getValue()));
- }
-
// (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2)
if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) &&
Op0->hasOneUse()) {
@@ -2119,45 +2059,51 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
return BinaryOperator::CreateXor(Or,
Builder->getInt(C1->getValue() & ~RHS->getValue()));
}
+ }
+ if (isa<Constant>(Op1))
if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))
return FoldedLogic;
- }
// Given an OR instruction, check to see if this is a bswap.
if (Instruction *BSwap = MatchBSwap(I))
return BSwap;
- Value *A = nullptr, *B = nullptr;
- ConstantInt *C1 = nullptr, *C2 = nullptr;
+ {
+ Value *A;
+ const APInt *C;
+ // (X^C)|Y -> (X|Y)^C iff Y&C == 0
+ if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) &&
+ MaskedValueIsZero(Op1, *C, 0, &I)) {
+ Value *NOr = Builder->CreateOr(A, Op1);
+ NOr->takeName(Op0);
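+      // Reusing the original constant operand works for both scalar and
+      // splat vector constants.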
+ return BinaryOperator::CreateXor(NOr,
+ cast<Instruction>(Op0)->getOperand(1));
+ }
- // (X^C)|Y -> (X|Y)^C iff Y&C == 0
- if (Op0->hasOneUse() &&
- match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) &&
- MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) {
- Value *NOr = Builder->CreateOr(A, Op1);
- NOr->takeName(Op0);
- return BinaryOperator::CreateXor(NOr, C1);
+ // Y|(X^C) -> (X|Y)^C iff Y&C == 0
+ if (match(Op1, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) &&
+ MaskedValueIsZero(Op0, *C, 0, &I)) {
+ Value *NOr = Builder->CreateOr(A, Op0);
+ NOr->takeName(Op0);
+ return BinaryOperator::CreateXor(NOr,
+ cast<Instruction>(Op1)->getOperand(1));
+ }
}
- // Y|(X^C) -> (X|Y)^C iff Y&C == 0
- if (Op1->hasOneUse() &&
- match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) &&
- MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) {
- Value *NOr = Builder->CreateOr(A, Op0);
- NOr->takeName(Op0);
- return BinaryOperator::CreateXor(NOr, C1);
- }
+ Value *A, *B;
// ((~A & B) | A) -> (A | B)
- if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) &&
- match(Op1, m_Specific(A)))
- return BinaryOperator::CreateOr(A, B);
+ if (match(Op0, m_c_And(m_Not(m_Specific(Op1)), m_Value(A))))
+ return BinaryOperator::CreateOr(A, Op1);
+ if (match(Op1, m_c_And(m_Not(m_Specific(Op0)), m_Value(A))))
+ return BinaryOperator::CreateOr(Op0, A);
// ((A & B) | ~A) -> (~A | B)
- if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
- match(Op1, m_Not(m_Specific(A))))
- return BinaryOperator::CreateOr(Builder->CreateNot(A), B);
+ // The NOT is guaranteed to be in the RHS by complexity ordering.
+ if (match(Op1, m_Not(m_Value(A))) &&
+ match(Op0, m_c_And(m_Specific(A), m_Value(B))))
+ return BinaryOperator::CreateOr(Op1, B);
// (A & ~B) | (A ^ B) -> (A ^ B)
// (~B & A) | (A ^ B) -> (A ^ B)
@@ -2177,8 +2123,8 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
if (match(Op0, m_And(m_Value(A), m_Value(C))) &&
match(Op1, m_And(m_Value(B), m_Value(D)))) {
Value *V1 = nullptr, *V2 = nullptr;
- C1 = dyn_cast<ConstantInt>(C);
- C2 = dyn_cast<ConstantInt>(D);
+ ConstantInt *C1 = dyn_cast<ConstantInt>(C);
+ ConstantInt *C2 = dyn_cast<ConstantInt>(D);
if (C1 && C2) { // (A & C1)|(B & C2)
if ((C1->getValue() & C2->getValue()) == 0) {
// ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2)
@@ -2403,6 +2349,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
// be simplified by a later pass either, so we try swapping the inner/outer
// ORs in the hopes that we'll be able to simplify it this way.
// (X|C) | V --> (X|V) | C
+ ConstantInt *C1;
if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) &&
match(Op0, m_Or(m_Value(A), m_ConstantInt(C1)))) {
Value *Inner = Builder->CreateOr(A, Op1);
@@ -2493,23 +2440,22 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
}
}
- if (Constant *RHS = dyn_cast<Constant>(Op1)) {
- if (RHS->isAllOnesValue() && Op0->hasOneUse())
- // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B
- if (CmpInst *CI = dyn_cast<CmpInst>(Op0))
- return CmpInst::Create(CI->getOpcode(),
- CI->getInversePredicate(),
- CI->getOperand(0), CI->getOperand(1));
+ // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B
+ ICmpInst::Predicate Pred;
+ if (match(Op0, m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))) &&
+ match(Op1, m_AllOnes())) {
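+    // The compare has a single use, so it is safe to invert it in place.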
+ cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred));
+ return replaceInstUsesWith(I, Op0);
}
- if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
+ if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) {
// fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp).
if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) {
if (CI->hasOneUse() && Op0C->hasOneUse()) {
Instruction::CastOps Opcode = Op0C->getOpcode();
if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) &&
- (RHS == ConstantExpr::getCast(Opcode, Builder->getTrue(),
+ (RHSC == ConstantExpr::getCast(Opcode, Builder->getTrue(),
Op0C->getDestTy()))) {
CI->setPredicate(CI->getInversePredicate());
return CastInst::Create(Opcode, CI, Op0C->getType());
@@ -2520,26 +2466,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
// ~(c-X) == X-c-1 == X+(-c-1)
- if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue())
+ if (Op0I->getOpcode() == Instruction::Sub && RHSC->isAllOnesValue())
if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) {
Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C);
- Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C,
- ConstantInt::get(I.getType(), 1));
- return BinaryOperator::CreateAdd(Op0I->getOperand(1), ConstantRHS);
+ return BinaryOperator::CreateAdd(Op0I->getOperand(1),
+ SubOne(NegOp0I0C));
}
if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
if (Op0I->getOpcode() == Instruction::Add) {
// ~(X-c) --> (-c-1)-X
- if (RHS->isAllOnesValue()) {
+ if (RHSC->isAllOnesValue()) {
Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI);
- return BinaryOperator::CreateSub(
- ConstantExpr::getSub(NegOp0CI,
- ConstantInt::get(I.getType(), 1)),
- Op0I->getOperand(0));
- } else if (RHS->getValue().isSignBit()) {
+ return BinaryOperator::CreateSub(SubOne(NegOp0CI),
+ Op0I->getOperand(0));
+ } else if (RHSC->getValue().isSignBit()) {
// (X + C) ^ signbit -> (X + C + signbit)
- Constant *C = Builder->getInt(RHS->getValue() + Op0CI->getValue());
+ Constant *C = Builder->getInt(RHSC->getValue() + Op0CI->getValue());
return BinaryOperator::CreateAdd(Op0I->getOperand(0), C);
}
@@ -2547,10 +2490,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
// (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0
if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(),
0, &I)) {
- Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS);
+ Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC);
// Anything in both C1 and C2 is known to be zero, remove it from
// NewRHS.
- Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHS);
+ Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC);
NewRHS = ConstantExpr::getAnd(NewRHS,
ConstantExpr::getNot(CommonBits));
Worklist.Add(Op0I);
@@ -2568,7 +2511,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
E1->getOpcode() == Instruction::Xor &&
(C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) {
// fold (C1 >> C2) ^ C3
- ConstantInt *C2 = Op0CI, *C3 = RHS;
+ ConstantInt *C2 = Op0CI, *C3 = RHSC;
APInt FoldConst = C1->getValue().lshr(C2->getValue());
FoldConst ^= C3->getValue();
// Prepare the two operands.
@@ -2582,27 +2525,26 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
}
}
}
+ }
+ if (isa<Constant>(Op1))
if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))
return FoldedLogic;
- }
- BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1);
- if (Op1I) {
+ {
Value *A, *B;
- if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) {
- if (A == Op0) { // B^(B|A) == (A|B)^B
- Op1I->swapOperands();
- I.swapOperands();
- std::swap(Op0, Op1);
- } else if (B == Op0) { // B^(A|B) == (A|B)^B
+ if (match(Op1, m_OneUse(m_Or(m_Value(A), m_Value(B))))) {
+ if (A == Op0) { // A^(A|B) == A^(B|A)
+ cast<BinaryOperator>(Op1)->swapOperands();
+ std::swap(A, B);
+ }
+ if (B == Op0) { // A^(B|A) == (B|A)^A
I.swapOperands(); // Simplified below.
std::swap(Op0, Op1);
}
- } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) &&
- Op1I->hasOneUse()){
+ } else if (match(Op1, m_OneUse(m_And(m_Value(A), m_Value(B))))) {
if (A == Op0) { // A^(A&B) -> A^(B&A)
- Op1I->swapOperands();
+ cast<BinaryOperator>(Op1)->swapOperands();
std::swap(A, B);
}
if (B == Op0) { // A^(B&A) -> (B&A)^A
@@ -2612,65 +2554,63 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
}
}
- BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0);
- if (Op0I) {
+ {
Value *A, *B;
- if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
- Op0I->hasOneUse()) {
+ if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B))))) {
if (A == Op1) // (B|A)^B == (A|B)^B
std::swap(A, B);
if (B == Op1) // (A|B)^B == A & ~B
return BinaryOperator::CreateAnd(A, Builder->CreateNot(Op1));
- } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
- Op0I->hasOneUse()){
+ } else if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B))))) {
if (A == Op1) // (A&B)^A -> (B&A)^A
std::swap(A, B);
+ const APInt *C;
if (B == Op1 && // (B&A)^A == ~B & A
- !isa<ConstantInt>(Op1)) { // Canonical form is (B&C)^C
+ !match(Op1, m_APInt(C))) { // Canonical form is (B&C)^C
return BinaryOperator::CreateAnd(Builder->CreateNot(A), Op1);
}
}
}
- if (Op0I && Op1I) {
+ {
Value *A, *B, *C, *D;
// (A & B)^(A | B) -> A ^ B
- if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
- match(Op1I, m_Or(m_Value(C), m_Value(D)))) {
+ if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1, m_Or(m_Value(C), m_Value(D)))) {
if ((A == C && B == D) || (A == D && B == C))
return BinaryOperator::CreateXor(A, B);
}
// (A | B)^(A & B) -> A ^ B
- if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
- match(Op1I, m_And(m_Value(C), m_Value(D)))) {
+ if (match(Op0, m_Or(m_Value(A), m_Value(B))) &&
+ match(Op1, m_And(m_Value(C), m_Value(D)))) {
if ((A == C && B == D) || (A == D && B == C))
return BinaryOperator::CreateXor(A, B);
}
// (A | ~B) ^ (~A | B) -> A ^ B
// (~B | A) ^ (~A | B) -> A ^ B
- if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) &&
- match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B))))
+ if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) &&
+ match(Op1, m_Or(m_Not(m_Specific(A)), m_Specific(B))))
return BinaryOperator::CreateXor(A, B);
// (~A | B) ^ (A | ~B) -> A ^ B
- if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) &&
- match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) {
+ if (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) &&
+ match(Op1, m_Or(m_Specific(A), m_Not(m_Specific(B))))) {
return BinaryOperator::CreateXor(A, B);
}
// (A & ~B) ^ (~A & B) -> A ^ B
// (~B & A) ^ (~A & B) -> A ^ B
- if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) &&
- match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B))))
+ if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) &&
+ match(Op1, m_And(m_Not(m_Specific(A)), m_Specific(B))))
return BinaryOperator::CreateXor(A, B);
// (~A & B) ^ (A & ~B) -> A ^ B
- if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) &&
- match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) {
+ if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) &&
+ match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) {
return BinaryOperator::CreateXor(A, B);
}
// (A ^ C)^(A | B) -> ((~A) & B) ^ C
- if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) &&
- match(Op1I, m_Or(m_Value(A), m_Value(B)))) {
+ if (match(Op0, m_Xor(m_Value(D), m_Value(C))) &&
+ match(Op1, m_Or(m_Value(A), m_Value(B)))) {
if (D == A)
return BinaryOperator::CreateXor(
Builder->CreateAnd(Builder->CreateNot(A), B), C);
@@ -2679,8 +2619,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
Builder->CreateAnd(Builder->CreateNot(B), A), C);
}
// (A | B)^(A ^ C) -> ((~A) & B) ^ C
- if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
- match(Op1I, m_Xor(m_Value(D), m_Value(C)))) {
+ if (match(Op0, m_Or(m_Value(A), m_Value(B))) &&
+ match(Op1, m_Xor(m_Value(D), m_Value(C)))) {
if (D == A)
return BinaryOperator::CreateXor(
Builder->CreateAnd(Builder->CreateNot(A), B), C);
@@ -2689,12 +2629,12 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
Builder->CreateAnd(Builder->CreateNot(B), A), C);
}
// (A & B) ^ (A ^ B) -> (A | B)
- if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
- match(Op1I, m_Xor(m_Specific(A), m_Specific(B))))
+ if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_Xor(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
// (A ^ B) ^ (A & B) -> (A | B)
- if (match(Op0I, m_Xor(m_Value(A), m_Value(B))) &&
- match(Op1I, m_And(m_Specific(A), m_Specific(B))))
+ if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_And(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
}
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 2ef82ba3ed8c4..69484f47223f7 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -60,6 +60,12 @@ using namespace PatternMatch;
STATISTIC(NumSimplified, "Number of library calls simplified");
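+// Element-wise atomic memcpys with at least this many elements are left
+// alone rather than unfolded into explicit loads and stores.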
+static cl::opt<unsigned> UnfoldElementAtomicMemcpyMaxElements(
+ "unfold-element-atomic-memcpy-max-elements",
+ cl::init(16),
+ cl::desc("Maximum number of elements in atomic memcpy the optimizer is "
+ "allowed to unfold"));
+
/// Return the specified type promoted as it would be to pass through a va_arg
/// area.
static Type *getPromotedType(Type *Ty) {
@@ -70,27 +76,6 @@ static Type *getPromotedType(Type *Ty) {
return Ty;
}
-/// Given an aggregate type which ultimately holds a single scalar element,
-/// like {{{type}}} or [1 x type], return type.
-static Type *reduceToSingleValueType(Type *T) {
- while (!T->isSingleValueType()) {
- if (StructType *STy = dyn_cast<StructType>(T)) {
- if (STy->getNumElements() == 1)
- T = STy->getElementType(0);
- else
- break;
- } else if (ArrayType *ATy = dyn_cast<ArrayType>(T)) {
- if (ATy->getNumElements() == 1)
- T = ATy->getElementType();
- else
- break;
- } else
- break;
- }
-
- return T;
-}
-
/// Return a constant boolean vector that has true elements in all positions
/// where the input constant data vector has an element with the sign bit set.
static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
@@ -108,6 +93,78 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
return ConstantVector::get(BoolVec);
}
+Instruction *
+InstCombiner::SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI) {
+  // Try to unfold this intrinsic into a sequence of explicit atomic loads and
+  // stores.
+  // First, check that the number of elements is a compile-time constant.
+ auto *NumElementsCI = dyn_cast<ConstantInt>(AMI->getNumElements());
+ if (!NumElementsCI)
+ return nullptr;
+
+ // Check that there are not too many elements.
+ uint64_t NumElements = NumElementsCI->getZExtValue();
+ if (NumElements >= UnfoldElementAtomicMemcpyMaxElements)
+ return nullptr;
+
+  // Don't unfold into illegal integers. Note that the element size is
+  // converted to bits here because isLegalInteger expects a bit width.
+  uint64_t ElementSizeInBits = AMI->getElementSizeInBytes() * 8;
+  if (!getDataLayout().isLegalInteger(ElementSizeInBits))
+    return nullptr;
+
+  // Cast the source and destination to the correct type. Intrinsic input
+  // arguments are usually represented as i8*.
+  // Often the operands will already have been cast to i8* explicitly, and we
+  // could simply strip those casts instead of inserting new ones. However, it
+  // is easier to rely on other InstCombine rules, which cover the trivial
+  // cases anyway.
+ Value *Src = AMI->getRawSource();
+ Value *Dst = AMI->getRawDest();
+  Type *ElementPointerType =
+      Type::getIntNPtrTy(AMI->getContext(), ElementSizeInBits,
+                         Src->getType()->getPointerAddressSpace());
+
+ Value *SrcCasted = Builder->CreatePointerCast(Src, ElementPointerType,
+ "memcpy_unfold.src_casted");
+ Value *DstCasted = Builder->CreatePointerCast(Dst, ElementPointerType,
+ "memcpy_unfold.dst_casted");
+
+ for (uint64_t i = 0; i < NumElements; ++i) {
+ // Get current element addresses
+ ConstantInt *ElementIdxCI =
+ ConstantInt::get(AMI->getContext(), APInt(64, i));
+ Value *SrcElementAddr =
+ Builder->CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr");
+ Value *DstElementAddr =
+ Builder->CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr");
+
+ // Load from the source. Transfer alignment information and mark load as
+ // unordered atomic.
+ LoadInst *Load = Builder->CreateLoad(SrcElementAddr, "memcpy_unfold.val");
+ Load->setOrdering(AtomicOrdering::Unordered);
+    // We know the alignment of the first element. The verifier also
+    // guarantees that the element size is less than or equal to the first
+    // element's alignment and that both values are powers of two.
+    // This means that all subsequent accesses are at least element-size
+    // aligned.
+    // TODO: We could infer better alignment, but there is no evidence that
+    // this will matter.
+ Load->setAlignment(i == 0 ? AMI->getSrcAlignment()
+ : AMI->getElementSizeInBytes());
+ Load->setDebugLoc(AMI->getDebugLoc());
+
+ // Store loaded value via unordered atomic store.
+ StoreInst *Store = Builder->CreateStore(Load, DstElementAddr);
+ Store->setOrdering(AtomicOrdering::Unordered);
+ Store->setAlignment(i == 0 ? AMI->getDstAlignment()
+ : AMI->getElementSizeInBytes());
+ Store->setDebugLoc(AMI->getDebugLoc());
+ }
+
+  // Set the number of elements of the copy to 0; the intrinsic will then be
+  // deleted on the next iteration.
+ AMI->setNumElements(Constant::getNullValue(NumElementsCI->getType()));
+ return AMI;
+}
+
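As a side note on the alignment reasoning inside the loop above: because the element size is a power of two no larger than the (power-of-two) starting alignment, every byte offset i * ElementSize is a multiple of the element size, so the later loads and stores keep at least element-size alignment. A minimal standalone sketch of that claim (the helper name is made up for illustration):

#include <cassert>
#include <cstdint>

// Largest power of two dividing Offset (i.e. the alignment still guaranteed
// for an access at a BaseAlign-aligned base plus Offset), capped at BaseAlign.
static uint64_t alignmentAtOffset(uint64_t BaseAlign, uint64_t Offset) {
  if (Offset == 0)
    return BaseAlign;
  uint64_t OffsetAlign = Offset & (~Offset + 1); // lowest set bit of Offset
  return OffsetAlign < BaseAlign ? OffsetAlign : BaseAlign;
}

int main() {
  // Element size <= starting alignment, both powers of two, as the verifier
  // guarantees for the intrinsic.
  const uint64_t SrcAlign = 8, ElementSize = 4;
  for (uint64_t I = 0; I != 16; ++I)
    assert(alignmentAtOffset(SrcAlign, I * ElementSize) >= ElementSize);
  return 0;
}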
Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT);
unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT);
@@ -144,41 +201,19 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);
Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp);
- // Memcpy forces the use of i8* for the source and destination. That means
- // that if you're using memcpy to move one double around, you'll get a cast
- // from double* to i8*. We'd much rather use a double load+store rather than
- // an i64 load+store, here because this improves the odds that the source or
- // dest address will be promotable. See if we can find a better type than the
- // integer datatype.
- Value *StrippedDest = MI->getArgOperand(0)->stripPointerCasts();
+ // If the memcpy has metadata describing the members, see if we can get the
+ // TBAA tag describing our copy.
MDNode *CopyMD = nullptr;
- if (StrippedDest != MI->getArgOperand(0)) {
- Type *SrcETy = cast<PointerType>(StrippedDest->getType())
- ->getElementType();
- if (SrcETy->isSized() && DL.getTypeStoreSize(SrcETy) == Size) {
- // The SrcETy might be something like {{{double}}} or [1 x double]. Rip
- // down through these levels if so.
- SrcETy = reduceToSingleValueType(SrcETy);
-
- if (SrcETy->isSingleValueType()) {
- NewSrcPtrTy = PointerType::get(SrcETy, SrcAddrSp);
- NewDstPtrTy = PointerType::get(SrcETy, DstAddrSp);
-
- // If the memcpy has metadata describing the members, see if we can
- // get the TBAA tag describing our copy.
- if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) {
- if (M->getNumOperands() == 3 && M->getOperand(0) &&
- mdconst::hasa<ConstantInt>(M->getOperand(0)) &&
- mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() &&
- M->getOperand(1) &&
- mdconst::hasa<ConstantInt>(M->getOperand(1)) &&
- mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() ==
- Size &&
- M->getOperand(2) && isa<MDNode>(M->getOperand(2)))
- CopyMD = cast<MDNode>(M->getOperand(2));
- }
- }
- }
+ if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) {
+ if (M->getNumOperands() == 3 && M->getOperand(0) &&
+ mdconst::hasa<ConstantInt>(M->getOperand(0)) &&
+ mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() &&
+ M->getOperand(1) &&
+ mdconst::hasa<ConstantInt>(M->getOperand(1)) &&
+ mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() ==
+ Size &&
+ M->getOperand(2) && isa<MDNode>(M->getOperand(2)))
+ CopyMD = cast<MDNode>(M->getOperand(2));
}
// If the memcpy/memmove provides better alignment info than we can
@@ -510,6 +545,131 @@ static Value *simplifyX86varShift(const IntrinsicInst &II,
return Builder.CreateAShr(Vec, ShiftVec);
}
+static Value *simplifyX86muldq(const IntrinsicInst &II,
+ InstCombiner::BuilderTy &Builder) {
+ Value *Arg0 = II.getArgOperand(0);
+ Value *Arg1 = II.getArgOperand(1);
+ Type *ResTy = II.getType();
+ assert(Arg0->getType()->getScalarSizeInBits() == 32 &&
+ Arg1->getType()->getScalarSizeInBits() == 32 &&
+ ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types");
+
+ // muldq/muludq(undef, undef) -> zero (matches generic mul behavior)
+ if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1))
+ return ConstantAggregateZero::get(ResTy);
+
+ // Constant folding.
+ // PMULDQ = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)),
+ // vXi64 sext(shuffle<0,2,..>(Arg1))))
+ // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)),
+ // vXi64 zext(shuffle<0,2,..>(Arg1))))
+ if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
+ return nullptr;
+
+ unsigned NumElts = ResTy->getVectorNumElements();
+ assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) &&
+ Arg1->getType()->getVectorNumElements() == (2 * NumElts) &&
+ "Unexpected muldq/muludq types");
+
+ unsigned IntrinsicID = II.getIntrinsicID();
+ bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID ||
+ Intrinsic::x86_avx2_pmul_dq == IntrinsicID ||
+ Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID);
+
+ SmallVector<unsigned, 16> ShuffleMask;
+ for (unsigned i = 0; i != NumElts; ++i)
+ ShuffleMask.push_back(i * 2);
+
+ auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask);
+ auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask);
+
+ if (IsSigned) {
+ LHS = Builder.CreateSExt(LHS, ResTy);
+ RHS = Builder.CreateSExt(RHS, ResTy);
+ } else {
+ LHS = Builder.CreateZExt(LHS, ResTy);
+ RHS = Builder.CreateZExt(RHS, ResTy);
+ }
+
+ return Builder.CreateMul(LHS, RHS);
+}
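For reference, each 64-bit lane produced by pmuldq/pmuludq is the product of the low 32-bit halves of the corresponding input lanes, sign- or zero-extended; that is what the shuffle-plus-extend constant folding above reconstructs. A small scalar sketch of the per-lane arithmetic (the helper names are illustrative, not LLVM APIs):

#include <cassert>
#include <cstdint>

// One result lane of PMULDQ: signed multiply of the low 32-bit halves,
// widened to 64 bits.
static int64_t pmuldqLane(int32_t A, int32_t B) {
  return static_cast<int64_t>(A) * static_cast<int64_t>(B);
}

// One result lane of PMULUDQ: the unsigned equivalent.
static uint64_t pmuludqLane(uint32_t A, uint32_t B) {
  return static_cast<uint64_t>(A) * static_cast<uint64_t>(B);
}

int main() {
  assert(pmuldqLane(-2, 3) == -6);
  assert(pmuludqLane(0xFFFFFFFFu, 2u) == 0x1FFFFFFFEull);
  return 0;
}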
+
+static Value *simplifyX86pack(IntrinsicInst &II, InstCombiner &IC,
+ InstCombiner::BuilderTy &Builder, bool IsSigned) {
+ Value *Arg0 = II.getArgOperand(0);
+ Value *Arg1 = II.getArgOperand(1);
+ Type *ResTy = II.getType();
+
+  // Fast path for the all-undef case.
+ if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1))
+ return UndefValue::get(ResTy);
+
+ Type *ArgTy = Arg0->getType();
+ unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128;
+ unsigned NumDstElts = ResTy->getVectorNumElements();
+ unsigned NumSrcElts = ArgTy->getVectorNumElements();
+ assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types");
+
+ unsigned NumDstEltsPerLane = NumDstElts / NumLanes;
+ unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes;
+ unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits();
+ assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) &&
+ "Unexpected packing types");
+
+ // Constant folding.
+ auto *Cst0 = dyn_cast<Constant>(Arg0);
+ auto *Cst1 = dyn_cast<Constant>(Arg1);
+ if (!Cst0 || !Cst1)
+ return nullptr;
+
+ SmallVector<Constant *, 32> Vals;
+ for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+ for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) {
+ unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane;
+ auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0;
+ auto *COp = Cst->getAggregateElement(SrcIdx);
+ if (COp && isa<UndefValue>(COp)) {
+ Vals.push_back(UndefValue::get(ResTy->getScalarType()));
+ continue;
+ }
+
+ auto *CInt = dyn_cast_or_null<ConstantInt>(COp);
+ if (!CInt)
+ return nullptr;
+
+ APInt Val = CInt->getValue();
+ assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() &&
+ "Unexpected constant bitwidth");
+
+ if (IsSigned) {
+ // PACKSS: Truncate signed value with signed saturation.
+ // Source values less than dst minint are saturated to minint.
+ // Source values greater than dst maxint are saturated to maxint.
+ if (Val.isSignedIntN(DstScalarSizeInBits))
+ Val = Val.trunc(DstScalarSizeInBits);
+ else if (Val.isNegative())
+ Val = APInt::getSignedMinValue(DstScalarSizeInBits);
+ else
+ Val = APInt::getSignedMaxValue(DstScalarSizeInBits);
+ } else {
+ // PACKUS: Truncate signed value with unsigned saturation.
+ // Source values less than zero are saturated to zero.
+ // Source values greater than dst maxuint are saturated to maxuint.
+ if (Val.isIntN(DstScalarSizeInBits))
+ Val = Val.trunc(DstScalarSizeInBits);
+ else if (Val.isNegative())
+ Val = APInt::getNullValue(DstScalarSizeInBits);
+ else
+ Val = APInt::getAllOnesValue(DstScalarSizeInBits);
+ }
+
+ Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val));
+ }
+ }
+
+ return ConstantVector::get(Vals);
+}
+
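Written out for the common i32 -> i16 case, the per-element saturation applied above looks as follows; the packssdw-style folding uses the signed clamp and the packusdw-style folding the unsigned one. A standalone sketch, not tied to the intrinsic signatures:

#include <cassert>
#include <cstdint>

// PACKSS-style: truncate with signed saturation to [INT16_MIN, INT16_MAX].
static int16_t packss(int32_t V) {
  if (V < INT16_MIN) return INT16_MIN;
  if (V > INT16_MAX) return INT16_MAX;
  return static_cast<int16_t>(V);
}

// PACKUS-style: truncate a *signed* source with unsigned saturation to
// [0, UINT16_MAX].
static uint16_t packus(int32_t V) {
  if (V < 0) return 0;
  if (V > UINT16_MAX) return UINT16_MAX;
  return static_cast<uint16_t>(V);
}

int main() {
  assert(packss(100000) == INT16_MAX && packss(-100000) == INT16_MIN);
  assert(packus(-5) == 0 && packus(100000) == UINT16_MAX);
  assert(packss(1234) == 1234 && packus(1234) == 1234);
  return 0;
}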
static Value *simplifyX86movmsk(const IntrinsicInst &II,
InstCombiner::BuilderTy &Builder) {
Value *Arg = II.getArgOperand(0);
@@ -1330,6 +1490,27 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
return true;
}
+// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs.
+//
+// A single NaN input is folded to minnum, so we rely on that folding for
+// handling NaNs.
+static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1,
+ const APFloat &Src2) {
+ APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2);
+
+ APFloat::cmpResult Cmp0 = Max3.compare(Src0);
+ assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately");
+ if (Cmp0 == APFloat::cmpEqual)
+ return maxnum(Src1, Src2);
+
+ APFloat::cmpResult Cmp1 = Max3.compare(Src1);
+ assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately");
+ if (Cmp1 == APFloat::cmpEqual)
+ return maxnum(Src0, Src2);
+
+ return maxnum(Src0, Src1);
+}
+
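For finite, non-NaN inputs the helper above selects the median of the three values: it finds the maximum, drops it, and returns the larger of the remaining two. A quick standalone cross-check, with doubles standing in for the APFloat arithmetic:

#include <algorithm>
#include <cassert>
#include <iterator>

// Mirror of the max-based selection above: drop the maximum, take the
// larger of the two remaining values.
static double med3(double A, double B, double C) {
  double Max3 = std::max(std::max(A, B), C);
  if (Max3 == A) return std::max(B, C);
  if (Max3 == B) return std::max(A, C);
  return std::max(A, B);
}

int main() {
  double Vals[] = {-1.5, 0.0, 42.0};
  std::sort(std::begin(Vals), std::end(Vals));
  assert(med3(42.0, -1.5, 0.0) == Vals[1]); // median is the middle element
  assert(med3(0.0, 42.0, -1.5) == Vals[1]);
  return 0;
}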
// Returns true iff the 2 intrinsics have the same operands, limiting the
// comparison to the first NumOperands.
static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E,
@@ -1373,6 +1554,254 @@ static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID,
return false;
}
+// Convert NVVM intrinsics to target-generic LLVM code where possible.
+static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
+ // Each NVVM intrinsic we can simplify can be replaced with one of:
+ //
+ // * an LLVM intrinsic,
+ // * an LLVM cast operation,
+ // * an LLVM binary operation, or
+ // * ad-hoc LLVM IR for the particular operation.
+
+ // Some transformations are only valid when the module's
+ // flush-denormals-to-zero (ftz) setting is true/false, whereas other
+ // transformations are valid regardless of the module's ftz setting.
+ enum FtzRequirementTy {
+ FTZ_Any, // Any ftz setting is ok.
+ FTZ_MustBeOn, // Transformation is valid only if ftz is on.
+ FTZ_MustBeOff, // Transformation is valid only if ftz is off.
+ };
+ // Classes of NVVM intrinsics that can't be replaced one-to-one with a
+ // target-generic intrinsic, cast op, or binary op but that we can nonetheless
+ // simplify.
+ enum SpecialCase {
+ SPC_Reciprocal,
+ };
+
+ // SimplifyAction is a poor-man's variant (plus an additional flag) that
+ // represents how to replace an NVVM intrinsic with target-generic LLVM IR.
+ struct SimplifyAction {
+ // Invariant: At most one of these Optionals has a value.
+ Optional<Intrinsic::ID> IID;
+ Optional<Instruction::CastOps> CastOp;
+ Optional<Instruction::BinaryOps> BinaryOp;
+ Optional<SpecialCase> Special;
+
+ FtzRequirementTy FtzRequirement = FTZ_Any;
+
+ SimplifyAction() = default;
+
+ SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq)
+ : IID(IID), FtzRequirement(FtzReq) {}
+
+ // Cast operations don't have anything to do with FTZ, so we skip that
+ // argument.
+ SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {}
+
+ SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq)
+ : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {}
+
+ SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq)
+ : Special(Special), FtzRequirement(FtzReq) {}
+ };
+
+ // Try to generate a SimplifyAction describing how to replace our
+ // IntrinsicInstr with target-generic LLVM IR.
+ const SimplifyAction Action = [II]() -> SimplifyAction {
+ switch (II->getIntrinsicID()) {
+
+ // NVVM intrinsics that map directly to LLVM intrinsics.
+ case Intrinsic::nvvm_ceil_d:
+ return {Intrinsic::ceil, FTZ_Any};
+ case Intrinsic::nvvm_ceil_f:
+ return {Intrinsic::ceil, FTZ_MustBeOff};
+ case Intrinsic::nvvm_ceil_ftz_f:
+ return {Intrinsic::ceil, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fabs_d:
+ return {Intrinsic::fabs, FTZ_Any};
+ case Intrinsic::nvvm_fabs_f:
+ return {Intrinsic::fabs, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fabs_ftz_f:
+ return {Intrinsic::fabs, FTZ_MustBeOn};
+ case Intrinsic::nvvm_floor_d:
+ return {Intrinsic::floor, FTZ_Any};
+ case Intrinsic::nvvm_floor_f:
+ return {Intrinsic::floor, FTZ_MustBeOff};
+ case Intrinsic::nvvm_floor_ftz_f:
+ return {Intrinsic::floor, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fma_rn_d:
+ return {Intrinsic::fma, FTZ_Any};
+ case Intrinsic::nvvm_fma_rn_f:
+ return {Intrinsic::fma, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ return {Intrinsic::fma, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fmax_d:
+ return {Intrinsic::maxnum, FTZ_Any};
+ case Intrinsic::nvvm_fmax_f:
+ return {Intrinsic::maxnum, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fmax_ftz_f:
+ return {Intrinsic::maxnum, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fmin_d:
+ return {Intrinsic::minnum, FTZ_Any};
+ case Intrinsic::nvvm_fmin_f:
+ return {Intrinsic::minnum, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fmin_ftz_f:
+ return {Intrinsic::minnum, FTZ_MustBeOn};
+ case Intrinsic::nvvm_round_d:
+ return {Intrinsic::round, FTZ_Any};
+ case Intrinsic::nvvm_round_f:
+ return {Intrinsic::round, FTZ_MustBeOff};
+ case Intrinsic::nvvm_round_ftz_f:
+ return {Intrinsic::round, FTZ_MustBeOn};
+ case Intrinsic::nvvm_sqrt_rn_d:
+ return {Intrinsic::sqrt, FTZ_Any};
+ case Intrinsic::nvvm_sqrt_f:
+ // nvvm_sqrt_f is a special case. For most intrinsics, foo_ftz_f is the
+ // ftz version, and foo_f is the non-ftz version. But nvvm_sqrt_f adopts
+ // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are
+ // the versions with explicit ftz-ness.
+ return {Intrinsic::sqrt, FTZ_Any};
+ case Intrinsic::nvvm_sqrt_rn_f:
+ return {Intrinsic::sqrt, FTZ_MustBeOff};
+ case Intrinsic::nvvm_sqrt_rn_ftz_f:
+ return {Intrinsic::sqrt, FTZ_MustBeOn};
+ case Intrinsic::nvvm_trunc_d:
+ return {Intrinsic::trunc, FTZ_Any};
+ case Intrinsic::nvvm_trunc_f:
+ return {Intrinsic::trunc, FTZ_MustBeOff};
+ case Intrinsic::nvvm_trunc_ftz_f:
+ return {Intrinsic::trunc, FTZ_MustBeOn};
+
+ // NVVM intrinsics that map to LLVM cast operations.
+ //
+ // Note that llvm's target-generic conversion operators correspond to the rz
+ // (round to zero) versions of the nvvm conversion intrinsics, even though
+  // almost everything else here uses the rn (round to nearest even) nvvm ops.
+ case Intrinsic::nvvm_d2i_rz:
+ case Intrinsic::nvvm_f2i_rz:
+ case Intrinsic::nvvm_d2ll_rz:
+ case Intrinsic::nvvm_f2ll_rz:
+ return {Instruction::FPToSI};
+ case Intrinsic::nvvm_d2ui_rz:
+ case Intrinsic::nvvm_f2ui_rz:
+ case Intrinsic::nvvm_d2ull_rz:
+ case Intrinsic::nvvm_f2ull_rz:
+ return {Instruction::FPToUI};
+ case Intrinsic::nvvm_i2d_rz:
+ case Intrinsic::nvvm_i2f_rz:
+ case Intrinsic::nvvm_ll2d_rz:
+ case Intrinsic::nvvm_ll2f_rz:
+ return {Instruction::SIToFP};
+ case Intrinsic::nvvm_ui2d_rz:
+ case Intrinsic::nvvm_ui2f_rz:
+ case Intrinsic::nvvm_ull2d_rz:
+ case Intrinsic::nvvm_ull2f_rz:
+ return {Instruction::UIToFP};
+
+ // NVVM intrinsics that map to LLVM binary ops.
+ case Intrinsic::nvvm_add_rn_d:
+ return {Instruction::FAdd, FTZ_Any};
+ case Intrinsic::nvvm_add_rn_f:
+ return {Instruction::FAdd, FTZ_MustBeOff};
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ return {Instruction::FAdd, FTZ_MustBeOn};
+ case Intrinsic::nvvm_mul_rn_d:
+ return {Instruction::FMul, FTZ_Any};
+ case Intrinsic::nvvm_mul_rn_f:
+ return {Instruction::FMul, FTZ_MustBeOff};
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ return {Instruction::FMul, FTZ_MustBeOn};
+ case Intrinsic::nvvm_div_rn_d:
+ return {Instruction::FDiv, FTZ_Any};
+ case Intrinsic::nvvm_div_rn_f:
+ return {Instruction::FDiv, FTZ_MustBeOff};
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ return {Instruction::FDiv, FTZ_MustBeOn};
+
+    // The remaining cases are NVVM intrinsics that map to LLVM idioms but
+    // need special handling.
+    //
+    // We seem to be missing intrinsics for rcp.approx.{ftz.}f32, which is
+    // just as well.
+ case Intrinsic::nvvm_rcp_rn_d:
+ return {SPC_Reciprocal, FTZ_Any};
+ case Intrinsic::nvvm_rcp_rn_f:
+ return {SPC_Reciprocal, FTZ_MustBeOff};
+ case Intrinsic::nvvm_rcp_rn_ftz_f:
+ return {SPC_Reciprocal, FTZ_MustBeOn};
+
+ // We do not currently simplify intrinsics that give an approximate answer.
+ // These include:
+ //
+ // - nvvm_cos_approx_{f,ftz_f}
+ // - nvvm_ex2_approx_{d,f,ftz_f}
+ // - nvvm_lg2_approx_{d,f,ftz_f}
+ // - nvvm_sin_approx_{f,ftz_f}
+ // - nvvm_sqrt_approx_{f,ftz_f}
+ // - nvvm_rsqrt_approx_{d,f,ftz_f}
+ // - nvvm_div_approx_{ftz_d,ftz_f,f}
+ // - nvvm_rcp_approx_ftz_d
+ //
+ // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast"
+ // means that fastmath is enabled in the intrinsic. Unfortunately only
+ // binary operators (currently) have a fastmath bit in SelectionDAG, so this
+ // information gets lost and we can't select on it.
+ //
+    // TODO: div and rcp are lowered to a binary op, so in theory we could
+    // lower these to a "fast fdiv".
+
+ default:
+ return {};
+ }
+ }();
+
+  // If Action.FtzRequirement is not satisfied by the function's ftz state, we
+  // can bail out now. (Notice that in the case that IID is not an NVVM
+  // intrinsic, we don't have to look up any attributes, as FtzRequirement
+  // will be FTZ_Any.)
+ if (Action.FtzRequirement != FTZ_Any) {
+ bool FtzEnabled =
+ II->getFunction()->getFnAttribute("nvptx-f32ftz").getValueAsString() ==
+ "true";
+
+ if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn))
+ return nullptr;
+ }
+
+ // Simplify to target-generic intrinsic.
+ if (Action.IID) {
+ SmallVector<Value *, 4> Args(II->arg_operands());
+ // All the target-generic intrinsics currently of interest to us have one
+ // type argument, equal to that of the nvvm intrinsic's argument.
+ Type *Tys[] = {II->getArgOperand(0)->getType()};
+ return CallInst::Create(
+ Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args);
+ }
+
+ // Simplify to target-generic binary op.
+ if (Action.BinaryOp)
+ return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0),
+ II->getArgOperand(1), II->getName());
+
+ // Simplify to target-generic cast op.
+ if (Action.CastOp)
+ return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(),
+ II->getName());
+
+ // All that's left are the special cases.
+ if (!Action.Special)
+ return nullptr;
+
+ switch (*Action.Special) {
+ case SPC_Reciprocal:
+ // Simplify reciprocal.
+ return BinaryOperator::Create(
+ Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1),
+ II->getArgOperand(0), II->getName());
+ }
+ llvm_unreachable("All SpecialCase enumerators should be handled in switch.");
+}
+
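The ftz check near the end of the function boils down to a small predicate over the requirement and the function's "nvptx-f32ftz" attribute. A sketch of that decision in isolation (the attribute lookup is replaced by a plain bool here, purely for illustration):

#include <cassert>

enum FtzRequirementTy { FTZ_Any, FTZ_MustBeOn, FTZ_MustBeOff };

// Returns true if a rewrite with the given requirement is allowed in a
// function whose "nvptx-f32ftz" attribute evaluates to FtzEnabled.
static bool ftzAllows(FtzRequirementTy Req, bool FtzEnabled) {
  if (Req == FTZ_Any)
    return true;
  return FtzEnabled == (Req == FTZ_MustBeOn);
}

int main() {
  assert(ftzAllows(FTZ_Any, true) && ftzAllows(FTZ_Any, false));
  assert(ftzAllows(FTZ_MustBeOn, true) && !ftzAllows(FTZ_MustBeOn, false));
  assert(ftzAllows(FTZ_MustBeOff, false) && !ftzAllows(FTZ_MustBeOff, true));
  return 0;
}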
Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) {
removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this);
return nullptr;
@@ -1462,6 +1891,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
if (Changed) return II;
}
+ if (auto *AMI = dyn_cast<ElementAtomicMemCpyInst>(II)) {
+ if (Constant *C = dyn_cast<Constant>(AMI->getNumElements()))
+ if (C->isNullValue())
+ return eraseInstFromFunction(*AMI);
+
+ if (Instruction *I = SimplifyElementAtomicMemCpy(AMI))
+ return I;
+ }
+
+ if (Instruction *I = SimplifyNVVMIntrinsic(II, *this))
+ return I;
+
auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width,
unsigned DemandedWidth) {
APInt UndefElts(Width, 0);
@@ -1581,8 +2022,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(*II, V);
break;
}
- case Intrinsic::fma:
case Intrinsic::fmuladd: {
+ // Canonicalize fast fmuladd to the separate fmul + fadd.
+ if (II->hasUnsafeAlgebra()) {
+ BuilderTy::FastMathFlagGuard Guard(*Builder);
+ Builder->setFastMathFlags(II->getFastMathFlags());
+ Value *Mul = Builder->CreateFMul(II->getArgOperand(0),
+ II->getArgOperand(1));
+ Value *Add = Builder->CreateFAdd(Mul, II->getArgOperand(2));
+ Add->takeName(II);
+ return replaceInstUsesWith(*II, Add);
+ }
+
+ LLVM_FALLTHROUGH;
+ }
+ case Intrinsic::fma: {
Value *Src0 = II->getArgOperand(0);
Value *Src1 = II->getArgOperand(1);
@@ -1631,6 +2085,26 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return SelectInst::Create(Cond, Call0, Call1);
}
+ LLVM_FALLTHROUGH;
+ }
+ case Intrinsic::ceil:
+ case Intrinsic::floor:
+ case Intrinsic::round:
+ case Intrinsic::nearbyint:
+ case Intrinsic::rint:
+ case Intrinsic::trunc: {
+ Value *ExtSrc;
+ if (match(II->getArgOperand(0), m_FPExt(m_Value(ExtSrc))) &&
+ II->getArgOperand(0)->hasOneUse()) {
+      // op (fpext x) -> fpext (op x), e.g. fabs (fpext x) -> fpext (fabs x)
+ Value *F = Intrinsic::getDeclaration(II->getModule(), II->getIntrinsicID(),
+ { ExtSrc->getType() });
+ CallInst *NewFabs = Builder->CreateCall(F, ExtSrc);
+ NewFabs->copyFastMathFlags(II);
+ NewFabs->takeName(II);
+ return new FPExtInst(NewFabs, II->getType());
+ }
+
break;
}
case Intrinsic::cos:
@@ -1863,6 +2337,37 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return II;
break;
}
+ case Intrinsic::x86_avx512_mask_cmp_pd_128:
+ case Intrinsic::x86_avx512_mask_cmp_pd_256:
+ case Intrinsic::x86_avx512_mask_cmp_pd_512:
+ case Intrinsic::x86_avx512_mask_cmp_ps_128:
+ case Intrinsic::x86_avx512_mask_cmp_ps_256:
+ case Intrinsic::x86_avx512_mask_cmp_ps_512: {
+ // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a)
+ Value *Arg0 = II->getArgOperand(0);
+ Value *Arg1 = II->getArgOperand(1);
+ bool Arg0IsZero = match(Arg0, m_Zero());
+ if (Arg0IsZero)
+ std::swap(Arg0, Arg1);
+ Value *A, *B;
+    // This fold requires only the NINF (no infinities) fast-math flag, since
+    // inf minus inf is NaN.
+    // NSZ (no signed zeros) is not needed because zeros of either sign
+    // compare equal in both compares.
+    // NNAN is not needed because NaNs compare the same in both compares.
+    // The compare intrinsic relies on the above assumptions and therefore
+    // doesn't require additional flags.
+ if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) &&
+ match(Arg1, m_Zero()) &&
+ cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) {
+ if (Arg0IsZero)
+ std::swap(A, B);
+ II->setArgOperand(0, A);
+ II->setArgOperand(1, B);
+ return II;
+ }
+ break;
+ }
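To make the NINF requirement above concrete: with infinite operands, a - b can be NaN even though a == b is true, so comparing the difference against zero is not equivalent to comparing the operands directly. A standalone demonstration with doubles (compiled without fast-math):

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const double Inf = std::numeric_limits<double>::infinity();

  // Finite case: (a - b) == 0 and a == b agree.
  double A = 2.5, B = 2.5;
  assert(((A - B) == 0.0) == (A == B));

  // Infinite case: inf - inf is NaN, so the "compare the difference" form
  // gives a different answer than comparing the operands directly.
  assert(std::isnan(Inf - Inf));
  assert(((Inf - Inf) == 0.0) == false);
  assert((Inf == Inf) == true);
  return 0;
}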
case Intrinsic::x86_avx512_mask_add_ps_512:
case Intrinsic::x86_avx512_mask_div_ps_512:
@@ -2130,6 +2635,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::x86_avx2_pmulu_dq:
case Intrinsic::x86_avx512_pmul_dq_512:
case Intrinsic::x86_avx512_pmulu_dq_512: {
+ if (Value *V = simplifyX86muldq(*II, *Builder))
+ return replaceInstUsesWith(*II, V);
+
unsigned VWidth = II->getType()->getVectorNumElements();
APInt UndefElts(VWidth, 0);
APInt DemandedElts = APInt::getAllOnesValue(VWidth);
@@ -2141,6 +2649,64 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
+ case Intrinsic::x86_sse2_packssdw_128:
+ case Intrinsic::x86_sse2_packsswb_128:
+ case Intrinsic::x86_avx2_packssdw:
+ case Intrinsic::x86_avx2_packsswb:
+ case Intrinsic::x86_avx512_packssdw_512:
+ case Intrinsic::x86_avx512_packsswb_512:
+ if (Value *V = simplifyX86pack(*II, *this, *Builder, true))
+ return replaceInstUsesWith(*II, V);
+ break;
+
+ case Intrinsic::x86_sse2_packuswb_128:
+ case Intrinsic::x86_sse41_packusdw:
+ case Intrinsic::x86_avx2_packusdw:
+ case Intrinsic::x86_avx2_packuswb:
+ case Intrinsic::x86_avx512_packusdw_512:
+ case Intrinsic::x86_avx512_packuswb_512:
+ if (Value *V = simplifyX86pack(*II, *this, *Builder, false))
+ return replaceInstUsesWith(*II, V);
+ break;
+
+ case Intrinsic::x86_pclmulqdq: {
+ if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) {
+ unsigned Imm = C->getZExtValue();
+
+ bool MadeChange = false;
+ Value *Arg0 = II->getArgOperand(0);
+ Value *Arg1 = II->getArgOperand(1);
+ unsigned VWidth = Arg0->getType()->getVectorNumElements();
+ APInt DemandedElts(VWidth, 0);
+
+ APInt UndefElts1(VWidth, 0);
+ DemandedElts = (Imm & 0x01) ? 2 : 1;
+ if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts,
+ UndefElts1)) {
+ II->setArgOperand(0, V);
+ MadeChange = true;
+ }
+
+ APInt UndefElts2(VWidth, 0);
+ DemandedElts = (Imm & 0x10) ? 2 : 1;
+ if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts,
+ UndefElts2)) {
+ II->setArgOperand(1, V);
+ MadeChange = true;
+ }
+
+      // If either of the demanded input elements is undef, fold to zero.
+ if (UndefElts1[(Imm & 0x01) ? 1 : 0] ||
+ UndefElts2[(Imm & 0x10) ? 1 : 0])
+ return replaceInstUsesWith(*II,
+ ConstantAggregateZero::get(II->getType()));
+
+ if (MadeChange)
+ return II;
+ }
+ break;
+ }
+
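The demanded-elements computation above only decodes which quadword of each source the immediate selects: bit 0 of the immediate picks the half of operand 0 and bit 4 the half of operand 1. A tiny sketch of that decoding (the helper name is made up):

#include <cassert>

// Returns the index (0 or 1) of the quadword of the given operand that
// PCLMULQDQ actually reads, per the immediate encoding used above.
static unsigned demandedHalf(unsigned Imm, unsigned OperandNo) {
  unsigned Bit = OperandNo == 0 ? 0x01u : 0x10u;
  return (Imm & Bit) ? 1 : 0;
}

int main() {
  assert(demandedHalf(0x00, 0) == 0 && demandedHalf(0x00, 1) == 0);
  assert(demandedHalf(0x01, 0) == 1 && demandedHalf(0x01, 1) == 0);
  assert(demandedHalf(0x10, 0) == 0 && demandedHalf(0x10, 1) == 1);
  assert(demandedHalf(0x11, 0) == 1 && demandedHalf(0x11, 1) == 1);
  return 0;
}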
case Intrinsic::x86_sse41_insertps:
if (Value *V = simplifyX86insertps(*II, *Builder))
return replaceInstUsesWith(*II, V);
@@ -2531,9 +3097,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
-
case Intrinsic::amdgcn_rcp: {
- if (const ConstantFP *C = dyn_cast<ConstantFP>(II->getArgOperand(0))) {
+ Value *Src = II->getArgOperand(0);
+
+ // TODO: Move to ConstantFolding/InstSimplify?
+ if (isa<UndefValue>(Src))
+ return replaceInstUsesWith(CI, Src);
+
+ if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
const APFloat &ArgVal = C->getValueAPF();
APFloat Val(ArgVal.getSemantics(), 1.0);
APFloat::opStatus Status = Val.divide(ArgVal,
@@ -2546,6 +3117,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
+ case Intrinsic::amdgcn_rsq: {
+ Value *Src = II->getArgOperand(0);
+
+ // TODO: Move to ConstantFolding/InstSimplify?
+ if (isa<UndefValue>(Src))
+ return replaceInstUsesWith(CI, Src);
+ break;
+ }
case Intrinsic::amdgcn_frexp_mant:
case Intrinsic::amdgcn_frexp_exp: {
Value *Src = II->getArgOperand(0);
@@ -2650,6 +3229,274 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result));
}
+ case Intrinsic::amdgcn_cvt_pkrtz: {
+ Value *Src0 = II->getArgOperand(0);
+ Value *Src1 = II->getArgOperand(1);
+ if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
+ if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
+ const fltSemantics &HalfSem
+ = II->getType()->getScalarType()->getFltSemantics();
+ bool LosesInfo;
+ APFloat Val0 = C0->getValueAPF();
+ APFloat Val1 = C1->getValueAPF();
+ Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
+ Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
+
+ Constant *Folded = ConstantVector::get({
+ ConstantFP::get(II->getContext(), Val0),
+ ConstantFP::get(II->getContext(), Val1) });
+ return replaceInstUsesWith(*II, Folded);
+ }
+ }
+
+ if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1))
+ return replaceInstUsesWith(*II, UndefValue::get(II->getType()));
+
+ break;
+ }
+ case Intrinsic::amdgcn_ubfe:
+ case Intrinsic::amdgcn_sbfe: {
+ // Decompose simple cases into standard shifts.
+ Value *Src = II->getArgOperand(0);
+ if (isa<UndefValue>(Src))
+ return replaceInstUsesWith(*II, Src);
+
+ unsigned Width;
+ Type *Ty = II->getType();
+ unsigned IntSize = Ty->getIntegerBitWidth();
+
+ ConstantInt *CWidth = dyn_cast<ConstantInt>(II->getArgOperand(2));
+ if (CWidth) {
+ Width = CWidth->getZExtValue();
+ if ((Width & (IntSize - 1)) == 0)
+ return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty));
+
+ if (Width >= IntSize) {
+ // Hardware ignores high bits, so remove those.
+ II->setArgOperand(2, ConstantInt::get(CWidth->getType(),
+ Width & (IntSize - 1)));
+ return II;
+ }
+ }
+
+ unsigned Offset;
+ ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1));
+ if (COffset) {
+ Offset = COffset->getZExtValue();
+ if (Offset >= IntSize) {
+ II->setArgOperand(1, ConstantInt::get(COffset->getType(),
+ Offset & (IntSize - 1)));
+ return II;
+ }
+ }
+
+ bool Signed = II->getIntrinsicID() == Intrinsic::amdgcn_sbfe;
+
+ // TODO: Also emit sub if only width is constant.
+ if (!CWidth && COffset && Offset == 0) {
+ Constant *KSize = ConstantInt::get(COffset->getType(), IntSize);
+ Value *ShiftVal = Builder->CreateSub(KSize, II->getArgOperand(2));
+ ShiftVal = Builder->CreateZExt(ShiftVal, II->getType());
+
+ Value *Shl = Builder->CreateShl(Src, ShiftVal);
+ Value *RightShift = Signed ?
+ Builder->CreateAShr(Shl, ShiftVal) :
+ Builder->CreateLShr(Shl, ShiftVal);
+ RightShift->takeName(II);
+ return replaceInstUsesWith(*II, RightShift);
+ }
+
+ if (!CWidth || !COffset)
+ break;
+
+ // TODO: This allows folding to undef when the hardware has specific
+ // behavior?
+ if (Offset + Width < IntSize) {
+ Value *Shl = Builder->CreateShl(Src, IntSize - Offset - Width);
+ Value *RightShift = Signed ?
+ Builder->CreateAShr(Shl, IntSize - Width) :
+ Builder->CreateLShr(Shl, IntSize - Width);
+ RightShift->takeName(II);
+ return replaceInstUsesWith(*II, RightShift);
+ }
+
+ Value *RightShift = Signed ?
+ Builder->CreateAShr(Src, Offset) :
+ Builder->CreateLShr(Src, Offset);
+
+ RightShift->takeName(II);
+ return replaceInstUsesWith(*II, RightShift);
+ }
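The shift decomposition above has a direct scalar counterpart: move the field up against the top of the word, then shift it back down logically (ubfe) or arithmetically (sbfe). A standalone 32-bit sketch for the fully-constant case with 0 < width and offset + width < 32 (assumes two's complement and an arithmetic >> on signed values):

#include <cassert>
#include <cstdint>

// Unsigned bitfield extract: Width bits starting at Offset, zero-extended.
static uint32_t ubfe32(uint32_t Src, unsigned Offset, unsigned Width) {
  return (Src << (32 - Offset - Width)) >> (32 - Width);
}

// Signed bitfield extract: the same field, sign-extended.
static int32_t sbfe32(uint32_t Src, unsigned Offset, unsigned Width) {
  return static_cast<int32_t>(Src << (32 - Offset - Width)) >> (32 - Width);
}

int main() {
  // Bits [4, 12) of 0xABCD are 0xBC.
  assert(ubfe32(0xABCD, 4, 8) == 0xBC);
  // The same field sign-extends to a negative value (top bit of 0xBC is set).
  assert(sbfe32(0xABCD, 4, 8) == static_cast<int32_t>(0xFFFFFFBC));
  return 0;
}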
+ case Intrinsic::amdgcn_exp:
+ case Intrinsic::amdgcn_exp_compr: {
+ ConstantInt *En = dyn_cast<ConstantInt>(II->getArgOperand(1));
+ if (!En) // Illegal.
+ break;
+
+ unsigned EnBits = En->getZExtValue();
+ if (EnBits == 0xf)
+ break; // All inputs enabled.
+
+ bool IsCompr = II->getIntrinsicID() == Intrinsic::amdgcn_exp_compr;
+ bool Changed = false;
+ for (int I = 0; I < (IsCompr ? 2 : 4); ++I) {
+ if ((!IsCompr && (EnBits & (1 << I)) == 0) ||
+ (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) {
+ Value *Src = II->getArgOperand(I + 2);
+ if (!isa<UndefValue>(Src)) {
+ II->setArgOperand(I + 2, UndefValue::get(Src->getType()));
+ Changed = true;
+ }
+ }
+ }
+
+ if (Changed)
+ return II;
+
+    break;
+  }
+ case Intrinsic::amdgcn_fmed3: {
+ // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled
+ // for the shader.
+
+ Value *Src0 = II->getArgOperand(0);
+ Value *Src1 = II->getArgOperand(1);
+ Value *Src2 = II->getArgOperand(2);
+
+ bool Swap = false;
+ // Canonicalize constants to RHS operands.
+ //
+ // fmed3(c0, x, c1) -> fmed3(x, c0, c1)
+ if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
+ std::swap(Src0, Src1);
+ Swap = true;
+ }
+
+ if (isa<Constant>(Src1) && !isa<Constant>(Src2)) {
+ std::swap(Src1, Src2);
+ Swap = true;
+ }
+
+ if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
+ std::swap(Src0, Src1);
+ Swap = true;
+ }
+
+ if (Swap) {
+ II->setArgOperand(0, Src0);
+ II->setArgOperand(1, Src1);
+ II->setArgOperand(2, Src2);
+ return II;
+ }
+
+ if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) {
+ CallInst *NewCall = Builder->CreateMinNum(Src0, Src1);
+ NewCall->copyFastMathFlags(II);
+ NewCall->takeName(II);
+ return replaceInstUsesWith(*II, NewCall);
+ }
+
+ if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
+ if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
+ if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) {
+ APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(),
+ C2->getValueAPF());
+ return replaceInstUsesWith(*II,
+ ConstantFP::get(Builder->getContext(), Result));
+ }
+ }
+ }
+
+ break;
+ }
+ case Intrinsic::amdgcn_icmp:
+ case Intrinsic::amdgcn_fcmp: {
+ const ConstantInt *CC = dyn_cast<ConstantInt>(II->getArgOperand(2));
+ if (!CC)
+ break;
+
+ // Guard against invalid arguments.
+ int64_t CCVal = CC->getZExtValue();
+ bool IsInteger = II->getIntrinsicID() == Intrinsic::amdgcn_icmp;
+ if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE ||
+ CCVal > CmpInst::LAST_ICMP_PREDICATE)) ||
+ (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE ||
+ CCVal > CmpInst::LAST_FCMP_PREDICATE)))
+ break;
+
+ Value *Src0 = II->getArgOperand(0);
+ Value *Src1 = II->getArgOperand(1);
+
+ if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
+ if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
+ Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1);
+ return replaceInstUsesWith(*II,
+ ConstantExpr::getSExt(CCmp, II->getType()));
+ }
+
+ // Canonicalize constants to RHS.
+ CmpInst::Predicate SwapPred
+ = CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal));
+ II->setArgOperand(0, Src1);
+ II->setArgOperand(1, Src0);
+ II->setArgOperand(2, ConstantInt::get(CC->getType(),
+ static_cast<int>(SwapPred)));
+ return II;
+ }
+
+ if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE)
+ break;
+
+ // Canonicalize compare eq with true value to compare != 0
+ // llvm.amdgcn.icmp(zext (i1 x), 1, eq)
+ // -> llvm.amdgcn.icmp(zext (i1 x), 0, ne)
+ // llvm.amdgcn.icmp(sext (i1 x), -1, eq)
+ // -> llvm.amdgcn.icmp(sext (i1 x), 0, ne)
+ Value *ExtSrc;
+ if (CCVal == CmpInst::ICMP_EQ &&
+ ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) ||
+ (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) &&
+ ExtSrc->getType()->isIntegerTy(1)) {
+ II->setArgOperand(1, ConstantInt::getNullValue(Src1->getType()));
+ II->setArgOperand(2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE));
+ return II;
+ }
+
+ CmpInst::Predicate SrcPred;
+ Value *SrcLHS;
+ Value *SrcRHS;
+
+ // Fold compare eq/ne with 0 from a compare result as the predicate to the
+ // intrinsic. The typical use is a wave vote function in the library, which
+ // will be fed from a user code condition compared with 0. Fold in the
+ // redundant compare.
+
+ // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne)
+ // -> llvm.amdgcn.[if]cmp(a, b, pred)
+ //
+ // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq)
+ // -> llvm.amdgcn.[if]cmp(a, b, inv pred)
+ if (match(Src1, m_Zero()) &&
+ match(Src0,
+ m_ZExtOrSExt(m_Cmp(SrcPred, m_Value(SrcLHS), m_Value(SrcRHS))))) {
+ if (CCVal == CmpInst::ICMP_EQ)
+ SrcPred = CmpInst::getInversePredicate(SrcPred);
+
+ Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ?
+ Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp;
+
+ Value *NewF = Intrinsic::getDeclaration(II->getModule(), NewIID,
+ SrcLHS->getType());
+ Value *Args[] = { SrcLHS, SrcRHS,
+ ConstantInt::get(CC->getType(), SrcPred) };
+ CallInst *NewCall = Builder->CreateCall(NewF, Args);
+ NewCall->takeName(II);
+ return replaceInstUsesWith(*II, NewCall);
+ }
+
+ break;
+ }
case Intrinsic::stackrestore: {
// If the save is right next to the restore, remove the restore. This can
// happen when variable allocas are DCE'd.
@@ -2790,7 +3637,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// isKnownNonNull -> nonnull attribute
if (isKnownNonNullAt(DerivedPtr, II, &DT))
- II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
+ II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
}
// TODO: bitcast(relocate(p)) -> relocate(bitcast(p))
@@ -2799,11 +3646,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...)
break;
}
- }
+ case Intrinsic::experimental_guard: {
+ // Is this guard followed by another guard?
+ Instruction *NextInst = II->getNextNode();
+ Value *NextCond = nullptr;
+ if (match(NextInst,
+ m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) {
+ Value *CurrCond = II->getArgOperand(0);
+
+      // Remove a guard that is immediately preceded by an identical guard.
+ if (CurrCond == NextCond)
+ return eraseInstFromFunction(*NextInst);
+
+ // Otherwise canonicalize guard(a); guard(b) -> guard(a & b).
+ II->setArgOperand(0, Builder->CreateAnd(CurrCond, NextCond));
+ return eraseInstFromFunction(*NextInst);
+ }
+ break;
+ }
+ }
return visitCallSite(II);
}
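The guard(a); guard(b) -> guard(a & b) canonicalization in the experimental_guard case above treats two consecutive guards as one guard on the conjunction: execution continues only if both conditions hold, and otherwise we deoptimize (the merged form is allowed to deoptimize at the earlier point). A small model of that equivalence, with guardOrDeopt as a made-up stand-in for the intrinsic:

#include <cassert>

// Model a guard as "either the condition holds and execution continues,
// or we deoptimize". Returns false to signal deoptimization.
static bool guardOrDeopt(bool Cond) { return Cond; }

// Two consecutive guards: continue only if both conditions hold.
static bool twoGuards(bool A, bool B) {
  if (!guardOrDeopt(A)) return false;
  if (!guardOrDeopt(B)) return false;
  return true;
}

// The merged form: one guard on the conjunction.
static bool mergedGuard(bool A, bool B) { return guardOrDeopt(A & B); }

int main() {
  for (bool A : {false, true})
    for (bool B : {false, true})
      assert(twoGuards(A, B) == mergedGuard(A, B));
  return 0;
}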
+// Fence instruction simplification
+Instruction *InstCombiner::visitFenceInst(FenceInst &FI) {
+ // Remove identical consecutive fences.
+ if (auto *NFI = dyn_cast<FenceInst>(FI.getNextNode()))
+ if (FI.isIdenticalTo(NFI))
+ return eraseInstFromFunction(FI);
+ return nullptr;
+}
+
// InvokeInst simplification
//
Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) {
@@ -2950,7 +3824,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
for (Value *V : CS.args()) {
if (V->getType()->isPointerTy() &&
- !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) &&
+ !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&
isKnownNonNullAt(V, CS.getInstruction(), &DT))
Indices.push_back(ArgNo + 1);
ArgNo++;
@@ -2959,7 +3833,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
assert(ArgNo == CS.arg_size() && "sanity check");
if (!Indices.empty()) {
- AttributeSet AS = CS.getAttributes();
+ AttributeList AS = CS.getAttributes();
LLVMContext &Ctx = CS.getInstruction()->getContext();
AS = AS.addAttribute(Ctx, Indices,
Attribute::get(Ctx, Attribute::NonNull));
@@ -3081,7 +3955,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
return false;
Instruction *Caller = CS.getInstruction();
- const AttributeSet &CallerPAL = CS.getAttributes();
+ const AttributeList &CallerPAL = CS.getAttributes();
// Okay, this is a cast from a function to a different type. Unless doing so
// would cause a type conversion of one of our arguments, change this call to
@@ -3108,7 +3982,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
}
if (!CallerPAL.isEmpty() && !Caller->use_empty()) {
- AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex);
+ AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex);
if (RAttrs.overlaps(AttributeFuncs::typeIncompatible(NewRetTy)))
return false; // Attribute not compatible with transformed value.
}
@@ -3149,8 +4023,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL))
return false; // Cannot transform this parameter value.
- if (AttrBuilder(CallerPAL.getParamAttributes(i + 1), i + 1).
- overlaps(AttributeFuncs::typeIncompatible(ParamTy)))
+ if (AttrBuilder(CallerPAL.getParamAttributes(i))
+ .overlaps(AttributeFuncs::typeIncompatible(ParamTy)))
return false; // Attribute not compatible with transformed value.
if (CS.isInAllocaArgument(i))
@@ -3158,9 +4032,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
// If the parameter is passed as a byval argument, then we have to have a
// sized type and the sized type has to have the same size as the old type.
- if (ParamTy != ActTy &&
- CallerPAL.getParamAttributes(i + 1).hasAttribute(i + 1,
- Attribute::ByVal)) {
+ if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) {
PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy);
if (!ParamPTy || !ParamPTy->getElementType()->isSized())
return false;
@@ -3205,7 +4077,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
break;
// Check if it has an attribute that's incompatible with varargs.
- AttributeSet PAttrs = CallerPAL.getSlotAttributes(i - 1);
+ AttributeList PAttrs = CallerPAL.getSlotAttributes(i - 1);
if (PAttrs.hasAttribute(Index, Attribute::StructRet))
return false;
}
@@ -3213,44 +4085,37 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
// Okay, we decided that this is a safe thing to do: go ahead and start
// inserting cast instructions as necessary.
- std::vector<Value*> Args;
+ SmallVector<Value *, 8> Args;
+ SmallVector<AttributeSet, 8> ArgAttrs;
Args.reserve(NumActualArgs);
- SmallVector<AttributeSet, 8> attrVec;
- attrVec.reserve(NumCommonArgs);
+ ArgAttrs.reserve(NumActualArgs);
// Get any return attributes.
- AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex);
+ AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex);
// If the return value is not being used, the type may not be compatible
// with the existing attributes. Wipe out any problematic attributes.
RAttrs.remove(AttributeFuncs::typeIncompatible(NewRetTy));
- // Add the new return attributes.
- if (RAttrs.hasAttributes())
- attrVec.push_back(AttributeSet::get(Caller->getContext(),
- AttributeSet::ReturnIndex, RAttrs));
-
AI = CS.arg_begin();
for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) {
Type *ParamTy = FT->getParamType(i);
- if ((*AI)->getType() == ParamTy) {
- Args.push_back(*AI);
- } else {
- Args.push_back(Builder->CreateBitOrPointerCast(*AI, ParamTy));
- }
+ Value *NewArg = *AI;
+ if ((*AI)->getType() != ParamTy)
+ NewArg = Builder->CreateBitOrPointerCast(*AI, ParamTy);
+ Args.push_back(NewArg);
// Add any parameter attributes.
- AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1);
- if (PAttrs.hasAttributes())
- attrVec.push_back(AttributeSet::get(Caller->getContext(), i + 1,
- PAttrs));
+ ArgAttrs.push_back(CallerPAL.getParamAttributes(i));
}
// If the function takes more arguments than the call was taking, add them
// now.
- for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i)
+ for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) {
Args.push_back(Constant::getNullValue(FT->getParamType(i)));
+ ArgAttrs.push_back(AttributeSet());
+ }
// If we are removing arguments to the function, emit an obnoxious warning.
if (FT->getNumParams() < NumActualArgs) {
@@ -3259,54 +4124,56 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
// Add all of the arguments in their promoted form to the arg list.
for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) {
Type *PTy = getPromotedType((*AI)->getType());
+ Value *NewArg = *AI;
if (PTy != (*AI)->getType()) {
// Must promote to pass through va_arg area!
Instruction::CastOps opcode =
CastInst::getCastOpcode(*AI, false, PTy, false);
- Args.push_back(Builder->CreateCast(opcode, *AI, PTy));
- } else {
- Args.push_back(*AI);
+ NewArg = Builder->CreateCast(opcode, *AI, PTy);
}
+ Args.push_back(NewArg);
// Add any parameter attributes.
- AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1);
- if (PAttrs.hasAttributes())
- attrVec.push_back(AttributeSet::get(FT->getContext(), i + 1,
- PAttrs));
+ ArgAttrs.push_back(CallerPAL.getParamAttributes(i));
}
}
}
AttributeSet FnAttrs = CallerPAL.getFnAttributes();
- if (CallerPAL.hasAttributes(AttributeSet::FunctionIndex))
- attrVec.push_back(AttributeSet::get(Callee->getContext(), FnAttrs));
if (NewRetTy->isVoidTy())
Caller->setName(""); // Void type should not have a name.
- const AttributeSet &NewCallerPAL = AttributeSet::get(Callee->getContext(),
- attrVec);
+ assert((ArgAttrs.size() == FT->getNumParams() || FT->isVarArg()) &&
+ "missing argument attributes");
+ LLVMContext &Ctx = Callee->getContext();
+ AttributeList NewCallerPAL = AttributeList::get(
+ Ctx, FnAttrs, AttributeSet::get(Ctx, RAttrs), ArgAttrs);
SmallVector<OperandBundleDef, 1> OpBundles;
CS.getOperandBundlesAsDefs(OpBundles);
- Instruction *NC;
+ CallSite NewCS;
if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
- NC = Builder->CreateInvoke(Callee, II->getNormalDest(), II->getUnwindDest(),
- Args, OpBundles);
- NC->takeName(II);
- cast<InvokeInst>(NC)->setCallingConv(II->getCallingConv());
- cast<InvokeInst>(NC)->setAttributes(NewCallerPAL);
+ NewCS = Builder->CreateInvoke(Callee, II->getNormalDest(),
+ II->getUnwindDest(), Args, OpBundles);
} else {
- CallInst *CI = cast<CallInst>(Caller);
- NC = Builder->CreateCall(Callee, Args, OpBundles);
- NC->takeName(CI);
- cast<CallInst>(NC)->setTailCallKind(CI->getTailCallKind());
- cast<CallInst>(NC)->setCallingConv(CI->getCallingConv());
- cast<CallInst>(NC)->setAttributes(NewCallerPAL);
+ NewCS = Builder->CreateCall(Callee, Args, OpBundles);
+ cast<CallInst>(NewCS.getInstruction())
+ ->setTailCallKind(cast<CallInst>(Caller)->getTailCallKind());
}
+ NewCS->takeName(Caller);
+ NewCS.setCallingConv(CS.getCallingConv());
+ NewCS.setAttributes(NewCallerPAL);
+
+ // Preserve the weight metadata for the new call instruction. The metadata
+  // is used by SamplePGO to check the callsite's hotness.
+ uint64_t W;
+ if (Caller->extractProfTotalWeight(W))
+ NewCS->setProfWeight(W);
// Insert a cast of the return type as necessary.
+ Instruction *NC = NewCS.getInstruction();
Value *NV = NC;
if (OldRetTy != NV->getType() && !Caller->use_empty()) {
if (!NV->getType()->isVoidTy()) {
@@ -3351,7 +4218,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
Value *Callee = CS.getCalledValue();
PointerType *PTy = cast<PointerType>(Callee->getType());
FunctionType *FTy = cast<FunctionType>(PTy->getElementType());
- const AttributeSet &Attrs = CS.getAttributes();
+ AttributeList Attrs = CS.getAttributes();
// If the call already has the 'nest' attribute somewhere then give up -
// otherwise 'nest' would occur twice after splicing in the chain.
@@ -3364,50 +4231,46 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts());
FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType());
- const AttributeSet &NestAttrs = NestF->getAttributes();
+ AttributeList NestAttrs = NestF->getAttributes();
if (!NestAttrs.isEmpty()) {
- unsigned NestIdx = 1;
+ unsigned NestArgNo = 0;
Type *NestTy = nullptr;
AttributeSet NestAttr;
// Look for a parameter marked with the 'nest' attribute.
for (FunctionType::param_iterator I = NestFTy->param_begin(),
- E = NestFTy->param_end(); I != E; ++NestIdx, ++I)
- if (NestAttrs.hasAttribute(NestIdx, Attribute::Nest)) {
+ E = NestFTy->param_end();
+ I != E; ++NestArgNo, ++I) {
+ AttributeSet AS = NestAttrs.getParamAttributes(NestArgNo);
+ if (AS.hasAttribute(Attribute::Nest)) {
// Record the parameter type and any other attributes.
NestTy = *I;
- NestAttr = NestAttrs.getParamAttributes(NestIdx);
+ NestAttr = AS;
break;
}
+ }
if (NestTy) {
Instruction *Caller = CS.getInstruction();
std::vector<Value*> NewArgs;
+ std::vector<AttributeSet> NewArgAttrs;
NewArgs.reserve(CS.arg_size() + 1);
-
- SmallVector<AttributeSet, 8> NewAttrs;
- NewAttrs.reserve(Attrs.getNumSlots() + 1);
+ NewArgAttrs.reserve(CS.arg_size());
// Insert the nest argument into the call argument list, which may
// mean appending it. Likewise for attributes.
- // Add any result attributes.
- if (Attrs.hasAttributes(AttributeSet::ReturnIndex))
- NewAttrs.push_back(AttributeSet::get(Caller->getContext(),
- Attrs.getRetAttributes()));
-
{
- unsigned Idx = 1;
+ unsigned ArgNo = 0;
CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end();
do {
- if (Idx == NestIdx) {
+ if (ArgNo == NestArgNo) {
// Add the chain argument and attributes.
Value *NestVal = Tramp->getArgOperand(2);
if (NestVal->getType() != NestTy)
NestVal = Builder->CreateBitCast(NestVal, NestTy, "nest");
NewArgs.push_back(NestVal);
- NewAttrs.push_back(AttributeSet::get(Caller->getContext(),
- NestAttr));
+ NewArgAttrs.push_back(NestAttr);
}
if (I == E)
@@ -3415,23 +4278,13 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
// Add the original argument and attributes.
NewArgs.push_back(*I);
- AttributeSet Attr = Attrs.getParamAttributes(Idx);
- if (Attr.hasAttributes(Idx)) {
- AttrBuilder B(Attr, Idx);
- NewAttrs.push_back(AttributeSet::get(Caller->getContext(),
- Idx + (Idx >= NestIdx), B));
- }
+ NewArgAttrs.push_back(Attrs.getParamAttributes(ArgNo));
- ++Idx;
+ ++ArgNo;
++I;
} while (true);
}
- // Add any function attributes.
- if (Attrs.hasAttributes(AttributeSet::FunctionIndex))
- NewAttrs.push_back(AttributeSet::get(FTy->getContext(),
- Attrs.getFnAttributes()));
-
// The trampoline may have been bitcast to a bogus type (FTy).
// Handle this by synthesizing a new function type, equal to FTy
// with the chain parameter inserted.
@@ -3442,12 +4295,12 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
// Insert the chain's type into the list of parameter types, which may
// mean appending it.
{
- unsigned Idx = 1;
+ unsigned ArgNo = 0;
FunctionType::param_iterator I = FTy->param_begin(),
E = FTy->param_end();
do {
- if (Idx == NestIdx)
+ if (ArgNo == NestArgNo)
// Add the chain's type.
NewTypes.push_back(NestTy);
@@ -3457,7 +4310,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
// Add the original type.
NewTypes.push_back(*I);
- ++Idx;
+ ++ArgNo;
++I;
} while (true);
}
@@ -3470,8 +4323,9 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
NestF->getType() == PointerType::getUnqual(NewFTy) ?
NestF : ConstantExpr::getBitCast(NestF,
PointerType::getUnqual(NewFTy));
- const AttributeSet &NewPAL =
- AttributeSet::get(FTy->getContext(), NewAttrs);
+ AttributeList NewPAL =
+ AttributeList::get(FTy->getContext(), Attrs.getFnAttributes(),
+ Attrs.getRetAttributes(), NewArgAttrs);
SmallVector<OperandBundleDef, 1> OpBundles;
CS.getOperandBundlesAsDefs(OpBundles);
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp
index e74b590e2b7c9..25683132c7860 100644
--- a/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -274,12 +274,12 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
return NV;
// If we are casting a PHI, then fold the cast into the PHI.
- if (isa<PHINode>(Src)) {
+ if (auto *PN = dyn_cast<PHINode>(Src)) {
// Don't do this if it would create a PHI node with an illegal type from a
// legal type.
if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() ||
- ShouldChangeType(CI.getType(), Src->getType()))
- if (Instruction *NV = FoldOpIntoPhi(CI))
+ shouldChangeType(CI.getType(), Src->getType()))
+ if (Instruction *NV = foldOpIntoPhi(CI, PN))
return NV;
}
@@ -447,7 +447,7 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC,
Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) {
Type *SrcTy = Trunc.getSrcTy();
Type *DestTy = Trunc.getType();
- if (isa<IntegerType>(SrcTy) && !ShouldChangeType(SrcTy, DestTy))
+ if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
return nullptr;
BinaryOperator *LogicOp;
@@ -463,6 +463,56 @@ Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) {
return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC);
}
+/// Try to narrow the width of a splat shuffle. This could be generalized to any
+/// shuffle with a constant operand, but we limit the transform to avoid
+/// creating a shuffle type that targets may not be able to lower effectively.
+static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
+ InstCombiner::BuilderTy &Builder) {
+ auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
+ if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) &&
+ Shuf->getMask()->getSplatValue() &&
+ Shuf->getType() == Shuf->getOperand(0)->getType()) {
+ // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask
+ Constant *NarrowUndef = UndefValue::get(Trunc.getType());
+ Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType());
+ return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask());
+ }
+
+ return nullptr;
+}
+
+/// Try to narrow the width of an insert element. This could be generalized for
+/// any vector constant, but we limit the transform to insertion into undef to
+/// avoid potential backend problems from unsupported insertion widths. This
+/// could also be extended to handle the case of inserting a scalar constant
+/// into a vector variable.
+static Instruction *shrinkInsertElt(CastInst &Trunc,
+ InstCombiner::BuilderTy &Builder) {
+ Instruction::CastOps Opcode = Trunc.getOpcode();
+ assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
+ "Unexpected instruction for shrinking");
+
+ auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0));
+ if (!InsElt || !InsElt->hasOneUse())
+ return nullptr;
+
+ Type *DestTy = Trunc.getType();
+ Type *DestScalarTy = DestTy->getScalarType();
+ Value *VecOp = InsElt->getOperand(0);
+ Value *ScalarOp = InsElt->getOperand(1);
+ Value *Index = InsElt->getOperand(2);
+
+ if (isa<UndefValue>(VecOp)) {
+ // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index
+ // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index
+ UndefValue *NarrowUndef = UndefValue::get(DestTy);
+ Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy);
+ return InsertElementInst::Create(NarrowUndef, NarrowOp, Index);
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
if (Instruction *Result = commonCastTransforms(CI))
return Result;
@@ -488,7 +538,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
// type. Only do this if the dest type is a simple type, don't convert the
// expression tree to something weird like i93 unless the source is also
// strange.
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
canEvaluateTruncated(Src, DestTy, *this, &CI)) {
     // If this cast is a truncate, evaluating in a different type always
@@ -554,8 +604,14 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
if (Instruction *I = shrinkBitwiseLogic(CI))
return I;
+ if (Instruction *I = shrinkSplatShuffle(CI, *Builder))
+ return I;
+
+ if (Instruction *I = shrinkInsertElt(CI, *Builder))
+ return I;
+
if (Src->hasOneUse() && isa<IntegerType>(SrcTy) &&
- ShouldChangeType(SrcTy, DestTy)) {
+ shouldChangeType(SrcTy, DestTy)) {
// Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
// dest type is native and cst < dest size.
if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) &&
@@ -838,11 +894,6 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
if (Instruction *Result = commonCastTransforms(CI))
return Result;
- // See if we can simplify any instructions used by the input whose sole
- // purpose is to compute bits we don't care about.
- if (SimplifyDemandedInstructionBits(CI))
- return &CI;
-
Value *Src = CI.getOperand(0);
Type *SrcTy = Src->getType(), *DestTy = CI.getType();
@@ -851,10 +902,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
// expression tree to something weird like i93 unless the source is also
// strange.
unsigned BitsToClear;
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) {
- assert(BitsToClear < SrcTy->getScalarSizeInBits() &&
- "Unreasonable BitsToClear");
+ assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
+ "Can't clear more bits than in SrcTy");
// Okay, we can transform this! Insert the new expression now.
DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -1124,11 +1175,6 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
if (Instruction *I = commonCastTransforms(CI))
return I;
- // See if we can simplify any instructions used by the input whose sole
- // purpose is to compute bits we don't care about.
- if (SimplifyDemandedInstructionBits(CI))
- return &CI;
-
Value *Src = CI.getOperand(0);
Type *SrcTy = Src->getType(), *DestTy = CI.getType();
@@ -1145,7 +1191,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
// type. Only do this if the dest type is a simple type, don't convert the
// expression tree to something weird like i93 unless the source is also
// strange.
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
canEvaluateSExtd(Src, DestTy)) {
// Okay, we can transform this! Insert the new expression now.
DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -1167,18 +1213,16 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
ShAmt);
}
- // If this input is a trunc from our destination, then turn sext(trunc(x))
+ // If the input is a trunc from the destination type, then turn sext(trunc(x))
// into shifts.
- if (TruncInst *TI = dyn_cast<TruncInst>(Src))
- if (TI->hasOneUse() && TI->getOperand(0)->getType() == DestTy) {
- uint32_t SrcBitSize = SrcTy->getScalarSizeInBits();
- uint32_t DestBitSize = DestTy->getScalarSizeInBits();
-
- // We need to emit a shl + ashr to do the sign extend.
- Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize);
- Value *Res = Builder->CreateShl(TI->getOperand(0), ShAmt, "sext");
- return BinaryOperator::CreateAShr(Res, ShAmt);
- }
+ Value *X;
+ if (match(Src, m_OneUse(m_Trunc(m_Value(X)))) && X->getType() == DestTy) {
+ // sext(trunc(X)) --> ashr(shl(X, C), C)
+ unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
+ unsigned DestBitSize = DestTy->getScalarSizeInBits();
+ Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize);
+ return BinaryOperator::CreateAShr(Builder->CreateShl(X, ShAmt), ShAmt);
+ }
if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))
return transformSExtICmp(ICI, CI);
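The shl+ashr rewrite above has a direct scalar analogue. A minimal standalone C++ check (independent of LLVM; the helper name and the i32 -> i8 -> i32 widths are only illustrative) that brute-forces the equivalence:

    #include <cassert>
    #include <cstdint>

    // Models sext(trunc(X)) --> ashr(shl(X, C), C) for i32 -> i8 -> i32.
    // Assumes two's-complement narrowing and an arithmetic right shift of
    // negative values (guaranteed since C++20, universal in practice before).
    static int32_t sextTruncViaShifts(int32_t X) {
      const int ShAmt = 32 - 8; // DestBitSize - SrcBitSize
      return static_cast<int32_t>(static_cast<uint32_t>(X) << ShAmt) >> ShAmt;
    }

    int main() {
      for (int32_t X = -100000; X <= 100000; ++X)
        assert(sextTruncViaShifts(X) == static_cast<int8_t>(X)); // trunc + sext reference
      return 0;
    }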
@@ -1225,17 +1269,15 @@ static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
return nullptr;
}
-/// If this is a floating-point extension instruction, look
-/// through it until we get the source value.
+/// Look through floating-point extensions until we get the source value.
static Value *lookThroughFPExtensions(Value *V) {
- if (Instruction *I = dyn_cast<Instruction>(V))
- if (I->getOpcode() == Instruction::FPExt)
- return lookThroughFPExtensions(I->getOperand(0));
+ while (auto *FPExt = dyn_cast<FPExtInst>(V))
+ V = FPExt->getOperand(0);
// If this value is a constant, return the constant in the smallest FP type
// that can accurately represent it. This allows us to turn
// (float)((double)X+2.0) into x+2.0f.
- if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
+ if (auto *CFP = dyn_cast<ConstantFP>(V)) {
if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext()))
return V; // No constant folding of this.
// See if the value can be truncated to half and then reextended.
@@ -1392,24 +1434,49 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI.getOperand(0));
if (II) {
switch (II->getIntrinsicID()) {
- default: break;
- case Intrinsic::fabs: {
- // (fptrunc (fabs x)) -> (fabs (fptrunc x))
- Value *InnerTrunc = Builder->CreateFPTrunc(II->getArgOperand(0),
- CI.getType());
- Type *IntrinsicType[] = { CI.getType() };
- Function *Overload = Intrinsic::getDeclaration(
- CI.getModule(), II->getIntrinsicID(), IntrinsicType);
-
- SmallVector<OperandBundleDef, 1> OpBundles;
- II->getOperandBundlesAsDefs(OpBundles);
-
- Value *Args[] = { InnerTrunc };
- return CallInst::Create(Overload, Args, OpBundles, II->getName());
+ default: break;
+ case Intrinsic::fabs:
+ case Intrinsic::ceil:
+ case Intrinsic::floor:
+ case Intrinsic::rint:
+ case Intrinsic::round:
+ case Intrinsic::nearbyint:
+ case Intrinsic::trunc: {
+ Value *Src = II->getArgOperand(0);
+ if (!Src->hasOneUse())
+ break;
+
+ // Except for fabs, this transformation requires the input of the unary FP
+ // operation to be itself an fpext from the type to which we're
+ // truncating.
+ if (II->getIntrinsicID() != Intrinsic::fabs) {
+ FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src);
+ if (!FPExtSrc || FPExtSrc->getOperand(0)->getType() != CI.getType())
+ break;
}
+
+ // Do unary FP operation on smaller type.
+ // (fptrunc (fabs x)) -> (fabs (fptrunc x))
+ Value *InnerTrunc = Builder->CreateFPTrunc(Src, CI.getType());
+ Type *IntrinsicType[] = { CI.getType() };
+ Function *Overload = Intrinsic::getDeclaration(
+ CI.getModule(), II->getIntrinsicID(), IntrinsicType);
+
+ SmallVector<OperandBundleDef, 1> OpBundles;
+ II->getOperandBundlesAsDefs(OpBundles);
+
+ Value *Args[] = { InnerTrunc };
+ CallInst *NewCI = CallInst::Create(Overload, Args,
+ OpBundles, II->getName());
+ NewCI->copyFastMathFlags(II);
+ return NewCI;
+ }
}
}
+ if (Instruction *I = shrinkInsertElt(CI, *Builder))
+ return I;
+
return nullptr;
}
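A rough way to see why the fpext precondition matters (a sketch only, not a proof): when the wide input itself came from the narrow type, performing the unary operation in the wide type and truncating back matches performing it directly in the narrow type. A few standalone spot checks in plain C++, using float/double in place of the IR types:

    #include <cassert>
    #include <cmath>

    int main() {
      // Each F stands for the narrow value feeding the fpext.
      const float Samples[] = {0.0f, 1.5f, -2.75f, 1234.0625f, -0.5f, 3.9999998f};
      for (float F : Samples) {
        assert(static_cast<float>(std::trunc(static_cast<double>(F))) == std::trunc(F));
        assert(static_cast<float>(std::floor(static_cast<double>(F))) == std::floor(F));
        assert(static_cast<float>(std::ceil(static_cast<double>(F))) == std::ceil(F));
        assert(static_cast<float>(std::fabs(static_cast<double>(F))) == std::fabs(F));
      }
      return 0;
    }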
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 428f94bb5e932..bbafa9e9f4687 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -230,7 +230,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
return nullptr;
uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
- if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays.
+ // Don't blow up on huge arrays.
+ if (ArrayElementCount > MaxArraySizeForCombine)
+ return nullptr;
// There are many forms of this optimization we can handle, for now, just do
// the simple index into a single-dimensional array.
@@ -1663,7 +1665,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
(Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) {
// TODO: Is this a good transform for vectors? Wider types may reduce
// throughput. Should this transform be limited (even for scalars) by using
- // ShouldChangeType()?
+ // shouldChangeType()?
if (!Cmp.getType()->isVectorTy()) {
Type *WideType = W->getType();
unsigned WideScalarBits = WideType->getScalarSizeInBits();
@@ -1792,6 +1794,15 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
ConstantInt::get(V->getType(), 1));
}
+ // X | C == C --> X <=u C
+ // X | C != C --> X >u C
+ // iff C+1 is a power of 2 (C is a bitmask of the low bits)
+ if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) &&
+ (*C + 1).isPowerOf2()) {
+ Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
+ return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1));
+ }
+
if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse())
return nullptr;
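The new X | C equality fold above is easy to sanity-check exhaustively at a small width. A throwaway C++ brute force (not part of the patch; 8-bit values stand in for an arbitrary width):

    #include <cassert>

    // (X | C) == C  <=>  X <=u C, provided C + 1 is a power of two,
    // i.e. C is a mask of the low bits.
    int main() {
      for (unsigned C = 0; C < 256; ++C) {
        unsigned CPlus1 = (C + 1) & 0xff;             // wrap like an 8-bit APInt
        bool IsLowBitMask = CPlus1 && !(CPlus1 & (CPlus1 - 1));
        if (!IsLowBitMask)
          continue;
        for (unsigned X = 0; X < 256; ++X)
          assert(((X | C) == C) == (X <= C));
      }
      return 0;
    }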
@@ -1914,61 +1925,89 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Shl->getOperand(0);
- if (Cmp.isEquality()) {
- // If the shift is NUW, then it is just shifting out zeros, no need for an
- // AND.
- Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt));
- if (Shl->hasNoUnsignedWrap())
- return new ICmpInst(Pred, X, LShrC);
-
- // If the shift is NSW and we compare to 0, then it is just shifting out
- // sign bits, no need for an AND either.
- if (Shl->hasNoSignedWrap() && *C == 0)
- return new ICmpInst(Pred, X, LShrC);
-
- if (Shl->hasOneUse()) {
- // Otherwise, strength reduce the shift into an and.
- Constant *Mask = ConstantInt::get(Shl->getType(),
- APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
-
- Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
- return new ICmpInst(Pred, And, LShrC);
+ Type *ShType = Shl->getType();
+
+ // NSW guarantees that we are only shifting out sign bits from the high bits,
+ // so we can ASHR the compare constant without needing a mask and eliminate
+ // the shift.
+ if (Shl->hasNoSignedWrap()) {
+ if (Pred == ICmpInst::ICMP_SGT) {
+ // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt)
+ APInt ShiftedC = C->ashr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
+ // This is the same code as the SGT case, but assert the pre-condition
+ // that is needed for this to work with equality predicates.
+ assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+ "Compare known true or false was not folded");
+ APInt ShiftedC = C->ashr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_SLT) {
+ // SLE is the same as above, but SLE is canonicalized to SLT, so convert:
+ // (X << S) <=s C is equiv to X <=s (C >> S) for all C
+ // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX
+ // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN
+ assert(!C->isMinSignedValue() && "Unexpected icmp slt");
+ APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1;
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ // If this is a signed comparison to 0 and the shift is sign preserving,
+ // use the shift LHS operand instead; isSignTest may change 'Pred', so only
+ // do that if we're sure to not continue on in this function.
+ if (isSignTest(Pred, *C))
+ return new ICmpInst(Pred, X, Constant::getNullValue(ShType));
+ }
+
+ // NUW guarantees that we are only shifting out zero bits from the high bits,
+ // so we can LSHR the compare constant without needing a mask and eliminate
+ // the shift.
+ if (Shl->hasNoUnsignedWrap()) {
+ if (Pred == ICmpInst::ICMP_UGT) {
+ // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt)
+ APInt ShiftedC = C->lshr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
+ // This is the same code as the UGT case, but assert the pre-condition
+ // that is needed for this to work with equality predicates.
+ assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+ "Compare known true or false was not folded");
+ APInt ShiftedC = C->lshr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_ULT) {
+ // ULE is the same as above, but ULE is canonicalized to ULT, so convert:
+ // (X << S) <=u C is equiv to X <=u (C >> S) for all C
+ // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u
+ // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0
+ assert(C->ugt(0) && "ult 0 should have been eliminated");
+ APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1;
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
}
- // If this is a signed comparison to 0 and the shift is sign preserving,
- // use the shift LHS operand instead; isSignTest may change 'Pred', so only
- // do that if we're sure to not continue on in this function.
- if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C))
- return new ICmpInst(Pred, X, Constant::getNullValue(X->getType()));
+ if (Cmp.isEquality() && Shl->hasOneUse()) {
+ // Strength-reduce the shift into an 'and'.
+ Constant *Mask = ConstantInt::get(
+ ShType,
+ APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
+ Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
+ Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt));
+ return new ICmpInst(Pred, And, LShrC);
+ }
// Otherwise, if this is a comparison of the sign bit, simplify to and/test.
bool TrueIfSigned = false;
if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) {
// (X << 31) <s 0 --> (X & 1) != 0
Constant *Mask = ConstantInt::get(
- X->getType(),
+ ShType,
APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1));
Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
- And, Constant::getNullValue(And->getType()));
- }
-
- // When the shift is nuw and pred is >u or <=u, comparison only really happens
- // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the
- // <=u case can be further converted to match <u (see below).
- if (Shl->hasNoUnsignedWrap() &&
- (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) {
- // Derivation for the ult case:
- // (X << S) <=u C is equiv to X <=u (C >> S) for all C
- // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u
- // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0
- assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) &&
- "Encountered `ult 0` that should have been eliminated by "
- "InstSimplify.");
- APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1
- : C->lshr(*ShiftAmt);
- return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC));
+ And, Constant::getNullValue(ShType));
}
// Transform (icmp pred iM (shl iM %v, N), C)
@@ -1981,8 +2020,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt &&
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
- if (X->getType()->isVectorTy())
- TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements());
+ if (ShType->isVectorTy())
+ TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());
Constant *NewC =
ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC);
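The signed shl-nsw compare folds earlier in this hunk can likewise be cross-checked by brute force at a small width. A standalone C++ sketch (8-bit values; ashr() models APInt::ashr as floor division, and only (X, S) pairs that would qualify as 'shl nsw' are tested):

    #include <cassert>

    // icmp sgt (shl nsw X, S), C --> icmp sgt X, (C >>s S)
    // icmp slt (shl nsw X, S), C --> icmp slt X, ((C - 1) >>s S) + 1   (C != SMIN)
    static int ashr(int C, unsigned S) {
      int D = 1 << S;
      int Q = C / D;
      if (C % D != 0 && C < 0)
        --Q; // round toward negative infinity, like an arithmetic shift
      return Q;
    }

    int main() {
      for (unsigned S = 1; S < 8; ++S) {
        for (int X = -128; X <= 127; ++X) {
          int Shifted = X * (1 << S);
          if (Shifted < -128 || Shifted > 127)
            continue; // this (X, S) pair would not be 'shl nsw' for i8
          for (int C = -128; C <= 127; ++C) {
            assert((Shifted > C) == (X > ashr(C, S)));
            if (C != -128)
              assert((Shifted < C) == (X < ashr(C - 1, S) + 1));
          }
        }
      }
      return 0;
    }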
@@ -2342,8 +2381,24 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
// Fold icmp pred (add X, C2), C.
Value *X = Add->getOperand(0);
Type *Ty = Add->getType();
- auto CR =
- ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2);
+ CmpInst::Predicate Pred = Cmp.getPredicate();
+
+ // If the add does not wrap, we can always adjust the compare by subtracting
+ // the constants. Equality comparisons are handled elsewhere. SGE/SLE are
+ // canonicalized to SGT/SLT.
+ if (Add->hasNoSignedWrap() &&
+ (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) {
+ bool Overflow;
+ APInt NewC = C->ssub_ov(*C2, Overflow);
+ // If there is overflow, the result must be true or false.
+ // TODO: Can we assert there is no overflow because InstSimplify always
+ // handles those cases?
+ if (!Overflow)
+ // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2)
+ return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
+ }
+
+ auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2);
const APInt &Upper = CR.getUpper();
const APInt &Lower = CR.getLower();
if (Cmp.isSigned()) {
@@ -2364,16 +2419,14 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
// X+C <u C2 -> (X & -C2) == C
// iff C & (C2-1) == 0
// C2 is a power of 2
- if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() &&
- (*C2 & (*C - 1)) == 0)
+ if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0)
return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)),
ConstantExpr::getNeg(cast<Constant>(Y)));
// X+C >u C2 -> (X & ~C2) != C
// iff C & C2 == 0
// C2+1 is a power of 2
- if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() &&
- (*C2 & *C) == 0)
+ if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0)
return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)),
ConstantExpr::getNeg(cast<Constant>(Y)));
@@ -2656,7 +2709,7 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) {
// block. If in the same block, we're encouraging jump threading. If
// not, we are just pessimizing the code by making an i1 phi.
if (LHSI->getParent() == I.getParent())
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
return NV;
break;
case Instruction::Select: {
@@ -2767,12 +2820,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
D = BO1->getOperand(1);
}
- // icmp (X+cst) < 0 --> X < -cst
- if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero()))
- if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B))
- if (!RHSC->isMinValue(/*isSigned=*/true))
- return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC));
-
// icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
return new ICmpInst(Pred, A == Op1 ? B : A,
@@ -2847,6 +2894,31 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
+ // TODO: The subtraction-related identities shown below also hold, but
+ // canonicalization from (X -nuw 1) to (X + -1) means that the combinations
+ // wouldn't happen even if they were implemented.
+ //
+ // icmp ult (X - 1), Y -> icmp ule X, Y
+ // icmp uge (X - 1), Y -> icmp ugt X, Y
+ // icmp ugt X, (Y - 1) -> icmp uge X, Y
+ // icmp ule X, (Y - 1) -> icmp ult X, Y
+
+ // icmp ule (X + 1), Y -> icmp ult X, Y
+ if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
+ return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);
+
+ // icmp ugt (X + 1), Y -> icmp uge X, Y
+ if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
+ return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);
+
+ // icmp uge X, (Y + 1) -> icmp ugt X, Y
+ if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
+ return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);
+
+ // icmp ult X, (Y + 1) -> icmp ule X, Y
+ if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
+ return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);
+
// if C1 has greater magnitude than C2:
// icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
// s.t. C3 = C1 - C2
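The '+1' compare folds added above depend only on the increment not wrapping. A quick 8-bit brute force (a sanity sketch; the no-wrap precondition is modeled by excluding X == 255):

    #include <cassert>

    // With no unsigned wrap on X + 1:
    //   (X + 1) <=u Y  <=>  X <u Y        (X + 1) >u Y  <=>  X >=u Y
    int main() {
      for (unsigned X = 0; X < 255; ++X) {
        for (unsigned Y = 0; Y < 256; ++Y) {
          assert(((X + 1) <= Y) == (X < Y));
          assert(((X + 1) > Y) == (X >= Y));
        }
      }
      return 0;
    }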
@@ -3738,16 +3810,14 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth,
// greater than the RHS must differ in a bit higher than these due to carry.
case ICmpInst::ICMP_UGT: {
unsigned trailingOnes = RHS.countTrailingOnes();
- APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes);
- return ~lowBitsSet;
+ return APInt::getBitsSetFrom(BitWidth, trailingOnes);
}
// Similarly, for a ULT comparison, we don't care about the trailing zeros.
// Any value less than the RHS must differ in a higher bit because of carries.
case ICmpInst::ICMP_ULT: {
unsigned trailingZeros = RHS.countTrailingZeros();
- APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros);
- return ~lowBitsSet;
+ return APInt::getBitsSetFrom(BitWidth, trailingZeros);
}
default:
@@ -3887,7 +3957,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
- // The check for the unique predecessor is not the best that can be
+ // The check for the single predecessor is not the best that can be
// done. But it protects efficiently against cases like when SI's
// home block has two successors, Succ and Succ1, and Succ1 predecessor
// of Succ. Then SI can't be replaced by SIOpd because the use that gets
@@ -3895,8 +3965,10 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
// guarantees that the path all uses of SI (outside SI's parent) are on
// is disjoint from all other paths out of SI. But that information
// is more expensive to compute, and the trade-off here is in favor
- // of compile-time.
- if (Succ->getUniquePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
+    // of compile-time. It should also be noted that we check for a single
+    // predecessor and not only uniqueness. This is to handle the situation
+    // when Succ and Succ1 point to the same basic block.
+ if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
NumSel++;
SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
return true;
@@ -3932,12 +4004,12 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
- if (SimplifyDemandedBits(I.getOperandUse(0),
+ if (SimplifyDemandedBits(&I, 0,
getDemandedBitsLHSMask(I, BitWidth, IsSignBit),
Op0KnownZero, Op0KnownOne, 0))
return &I;
- if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth),
+ if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth),
Op1KnownZero, Op1KnownOne, 0))
return &I;
@@ -4801,7 +4873,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
// block. If in the same block, we're encouraging jump threading. If
// not, we are just pessimizing the code by making an i1 phi.
if (LHSI->getParent() == I.getParent())
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
return NV;
break;
case Instruction::SIToFP:
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index 2847ce858e791..71000063ab3cb 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -28,6 +28,9 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/Dwarf.h"
+#include "llvm/IR/DIBuilder.h"
#define DEBUG_TYPE "instcombine"
@@ -40,21 +43,29 @@ class DbgDeclareInst;
class MemIntrinsic;
class MemSetInst;
-/// \brief Assign a complexity or rank value to LLVM Values.
+/// Assign a complexity or rank value to LLVM Values. This is used to reduce
+/// the amount of pattern matching needed for compares and commutative
+/// instructions. For example, if we have:
+/// icmp ugt X, Constant
+/// or
+/// xor (add X, Constant), cast Z
+///
+/// We do not have to consider the commuted variants of these patterns because
+/// canonicalization based on complexity guarantees the above ordering.
///
/// This routine maps IR values to various complexity ranks:
/// 0 -> undef
/// 1 -> Constants
/// 2 -> Other non-instructions
/// 3 -> Arguments
-/// 3 -> Unary operations
-/// 4 -> Other instructions
+/// 4 -> Cast and (f)neg/not instructions
+/// 5 -> Other instructions
static inline unsigned getComplexity(Value *V) {
if (isa<Instruction>(V)) {
- if (BinaryOperator::isNeg(V) || BinaryOperator::isFNeg(V) ||
- BinaryOperator::isNot(V))
- return 3;
- return 4;
+ if (isa<CastInst>(V) || BinaryOperator::isNeg(V) ||
+ BinaryOperator::isFNeg(V) || BinaryOperator::isNot(V))
+ return 4;
+ return 5;
}
if (isa<Argument>(V))
return 3;
@@ -289,6 +300,7 @@ public:
Instruction *visitLoadInst(LoadInst &LI);
Instruction *visitStoreInst(StoreInst &SI);
Instruction *visitBranchInst(BranchInst &BI);
+ Instruction *visitFenceInst(FenceInst &FI);
Instruction *visitSwitchInst(SwitchInst &SI);
Instruction *visitReturnInst(ReturnInst &RI);
Instruction *visitInsertValueInst(InsertValueInst &IV);
@@ -313,9 +325,14 @@ public:
bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp,
const unsigned SIOpd);
+  /// Try to replace instruction \p I with value \p V, which are pointers
+  /// in different address spaces.
+ /// \return true if successful.
+ bool replacePointer(Instruction &I, Value *V);
+
private:
- bool ShouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
- bool ShouldChangeType(Type *From, Type *To) const;
+ bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
+ bool shouldChangeType(Type *From, Type *To) const;
Value *dyn_castNegVal(Value *V) const;
Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const;
Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset,
@@ -456,8 +473,9 @@ public:
/// methods should return the value returned by this function.
Instruction *eraseInstFromFunction(Instruction &I) {
DEBUG(dbgs() << "IC: ERASE " << I << '\n');
-
assert(I.use_empty() && "Cannot erase instruction that is used!");
+ salvageDebugInfo(I);
+
// Make sure that we reprocess all operands now that we reduced their
// use counts.
if (I.getNumOperands() < 8) {
@@ -499,6 +517,9 @@ public:
return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
}
+ /// Maximum size of array considered when transforming.
+ uint64_t MaxArraySizeForCombine;
+
private:
/// \brief Performs a few simplifications for operators which are associative
/// or commutative.
@@ -518,8 +539,16 @@ private:
Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero,
APInt &KnownOne, unsigned Depth,
Instruction *CxtI);
- bool SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero,
+ bool SimplifyDemandedBits(Instruction *I, unsigned Op,
+ const APInt &DemandedMask, APInt &KnownZero,
APInt &KnownOne, unsigned Depth = 0);
+ /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne
+ /// bits. It also tries to handle simplifications that can be done based on
+ /// DemandedMask, but without modifying the Instruction.
+ Value *SimplifyMultipleUseDemandedBits(Instruction *I,
+ const APInt &DemandedMask,
+ APInt &KnownZero, APInt &KnownOne,
+ unsigned Depth, Instruction *CxtI);
/// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded
/// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence.
Value *SimplifyShrShlDemandedBits(Instruction *Lsr, Instruction *Sftl,
@@ -540,7 +569,7 @@ private:
/// Given a binary operator, cast instruction, or select which has a PHI node
/// as operand #0, see if we can fold the instruction into the PHI (which is
/// only possible if all operands to the PHI are constants).
- Instruction *FoldOpIntoPhi(Instruction &I);
+ Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN);
/// Given an instruction with a select as one operand and a constant as the
/// other operand, try to fold the binary operator into the select arguments.
@@ -549,7 +578,7 @@ private:
Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI);
/// This is a convenience wrapper function for the above two functions.
- Instruction *foldOpWithConstantIntoOperand(Instruction &I);
+ Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I);
/// \brief Try to rotate an operation below a PHI node, using PHI nodes for
/// its operands.
@@ -628,16 +657,16 @@ private:
SelectPatternFlavor SPF2, Value *C);
Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
- Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS,
+ Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS,
ConstantInt *AndRHS, BinaryOperator &TheAnd);
- Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask,
- bool isSub, Instruction &I);
Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
bool isSigned, bool Inside);
Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI);
Instruction *MatchBSwap(BinaryOperator &I);
bool SimplifyStoreAtEndOfBlock(StoreInst &SI);
+
+ Instruction *SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI);
Instruction *SimplifyMemTransfer(MemIntrinsic *MI);
Instruction *SimplifyMemSet(MemSetInst *MI);
diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 49e516e9c1761..6288e054f1bc5 100644
--- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -12,13 +12,15 @@
//===----------------------------------------------------------------------===//
#include "InstCombineInternal.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -223,6 +225,107 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) {
return nullptr;
}
+namespace {
+// If I and V are pointers in different address spaces, it is not allowed to
+// use replaceAllUsesWith since I and V have different types. A
+// non-target-specific transformation should not use addrspacecast on V since
+// the two address spaces may be disjoint depending on the target.
+//
+// This class chases down uses of the old pointer until reaching the load
+// instructions, then replaces the old pointer in the load instructions with
+// the new pointer. If, during the chasing, it sees a bitcast or GEP, it
+// creates a new bitcast or GEP with the new pointer and uses it in the load
+// instruction.
+class PointerReplacer {
+public:
+ PointerReplacer(InstCombiner &IC) : IC(IC) {}
+ void replacePointer(Instruction &I, Value *V);
+
+private:
+ void findLoadAndReplace(Instruction &I);
+ void replace(Instruction *I);
+ Value *getReplacement(Value *I);
+
+ SmallVector<Instruction *, 4> Path;
+ MapVector<Value *, Value *> WorkMap;
+ InstCombiner &IC;
+};
+} // end anonymous namespace
+
+void PointerReplacer::findLoadAndReplace(Instruction &I) {
+ for (auto U : I.users()) {
+ auto *Inst = dyn_cast<Instruction>(&*U);
+ if (!Inst)
+ return;
+ DEBUG(dbgs() << "Found pointer user: " << *U << '\n');
+ if (isa<LoadInst>(Inst)) {
+ for (auto P : Path)
+ replace(P);
+ replace(Inst);
+ } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) {
+ Path.push_back(Inst);
+ findLoadAndReplace(*Inst);
+ Path.pop_back();
+ } else {
+ return;
+ }
+ }
+}
+
+Value *PointerReplacer::getReplacement(Value *V) {
+ auto Loc = WorkMap.find(V);
+ if (Loc != WorkMap.end())
+ return Loc->second;
+ return nullptr;
+}
+
+void PointerReplacer::replace(Instruction *I) {
+ if (getReplacement(I))
+ return;
+
+ if (auto *LT = dyn_cast<LoadInst>(I)) {
+ auto *V = getReplacement(LT->getPointerOperand());
+ assert(V && "Operand not replaced");
+ auto *NewI = new LoadInst(V);
+ NewI->takeName(LT);
+ IC.InsertNewInstWith(NewI, *LT);
+ IC.replaceInstUsesWith(*LT, NewI);
+ WorkMap[LT] = NewI;
+ } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
+ auto *V = getReplacement(GEP->getPointerOperand());
+ assert(V && "Operand not replaced");
+ SmallVector<Value *, 8> Indices;
+ Indices.append(GEP->idx_begin(), GEP->idx_end());
+ auto *NewI = GetElementPtrInst::Create(
+ V->getType()->getPointerElementType(), V, Indices);
+ IC.InsertNewInstWith(NewI, *GEP);
+ NewI->takeName(GEP);
+ WorkMap[GEP] = NewI;
+ } else if (auto *BC = dyn_cast<BitCastInst>(I)) {
+ auto *V = getReplacement(BC->getOperand(0));
+ assert(V && "Operand not replaced");
+ auto *NewT = PointerType::get(BC->getType()->getPointerElementType(),
+ V->getType()->getPointerAddressSpace());
+ auto *NewI = new BitCastInst(V, NewT);
+ IC.InsertNewInstWith(NewI, *BC);
+ NewI->takeName(BC);
+ WorkMap[BC] = NewI;
+ } else {
+ llvm_unreachable("should never reach here");
+ }
+}
+
+void PointerReplacer::replacePointer(Instruction &I, Value *V) {
+#ifndef NDEBUG
+ auto *PT = cast<PointerType>(I.getType());
+ auto *NT = cast<PointerType>(V->getType());
+ assert(PT != NT && PT->getElementType() == NT->getElementType() &&
+ "Invalid usage");
+#endif
+ WorkMap[&I] = V;
+ findLoadAndReplace(I);
+}
+
Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
if (auto *I = simplifyAllocaArraySize(*this, AI))
return I;
@@ -293,12 +396,22 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
for (unsigned i = 0, e = ToDelete.size(); i != e; ++i)
eraseInstFromFunction(*ToDelete[i]);
Constant *TheSrc = cast<Constant>(Copy->getSource());
- Constant *Cast
- = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType());
- Instruction *NewI = replaceInstUsesWith(AI, Cast);
- eraseInstFromFunction(*Copy);
- ++NumGlobalCopies;
- return NewI;
+ auto *SrcTy = TheSrc->getType();
+ auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(),
+ SrcTy->getPointerAddressSpace());
+ Constant *Cast =
+ ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, DestTy);
+ if (AI.getType()->getPointerAddressSpace() ==
+ SrcTy->getPointerAddressSpace()) {
+ Instruction *NewI = replaceInstUsesWith(AI, Cast);
+ eraseInstFromFunction(*Copy);
+ ++NumGlobalCopies;
+ return NewI;
+ } else {
+ PointerReplacer PtrReplacer(*this);
+ PtrReplacer.replacePointer(AI, Cast);
+ ++NumGlobalCopies;
+ }
}
}
}
@@ -608,7 +721,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
// arrays of arbitrary size but this has a terrible impact on compile time.
// The threshold here is chosen arbitrarily, maybe needs a little bit of
// tuning.
- if (NumElements > 1024)
+ if (NumElements > IC.MaxArraySizeForCombine)
return nullptr;
const DataLayout &DL = IC.getDataLayout();
@@ -1113,7 +1226,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) {
// arrays of arbitrary size but this has a terrible impact on compile time.
// The threshold here is chosen arbitrarily, maybe needs a little bit of
// tuning.
- if (NumElements > 1024)
+ if (NumElements > IC.MaxArraySizeForCombine)
return false;
const DataLayout &DL = IC.getDataLayout();
@@ -1268,8 +1381,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
break;
}
- // Don't skip over loads or things that can modify memory.
- if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory())
+ // Don't skip over loads, throws or things that can modify memory.
+ if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory() || BBI->mayThrow())
break;
}
@@ -1392,8 +1505,8 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {
}
// If we find something that may be using or overwriting the stored
// value, or if we run out of instructions, we can't do the xform.
- if (BBI->mayReadFromMemory() || BBI->mayWriteToMemory() ||
- BBI == OtherBB->begin())
+ if (BBI->mayReadFromMemory() || BBI->mayThrow() ||
+ BBI->mayWriteToMemory() || BBI == OtherBB->begin())
return false;
}
@@ -1402,7 +1515,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {
// StoreBB.
for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) {
// FIXME: This should really be AA driven.
- if (I->mayReadFromMemory() || I->mayWriteToMemory())
+ if (I->mayReadFromMemory() || I->mayThrow() || I->mayWriteToMemory())
return false;
}
}
@@ -1425,7 +1538,9 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {
SI.getOrdering(),
SI.getSynchScope());
InsertNewInstBefore(NewSI, *BBI);
- NewSI->setDebugLoc(OtherStore->getDebugLoc());
+ // The debug locations of the original instructions might differ; merge them.
+ NewSI->setDebugLoc(DILocation::getMergedLocation(SI.getDebugLoc(),
+ OtherStore->getDebugLoc()));
// If the two stores had AA tags, merge them.
AAMDNodes AATags;
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 45a19fb0f1f26..f1ac82057e6cf 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -298,39 +298,33 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
// (X / Y) * Y = X - (X % Y)
// (X / Y) * -Y = (X % Y) - X
{
- Value *Op1C = Op1;
- BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0);
- if (!BO ||
- (BO->getOpcode() != Instruction::UDiv &&
- BO->getOpcode() != Instruction::SDiv)) {
- Op1C = Op0;
- BO = dyn_cast<BinaryOperator>(Op1);
+ Value *Y = Op1;
+ BinaryOperator *Div = dyn_cast<BinaryOperator>(Op0);
+ if (!Div || (Div->getOpcode() != Instruction::UDiv &&
+ Div->getOpcode() != Instruction::SDiv)) {
+ Y = Op0;
+ Div = dyn_cast<BinaryOperator>(Op1);
}
- Value *Neg = dyn_castNegVal(Op1C);
- if (BO && BO->hasOneUse() &&
- (BO->getOperand(1) == Op1C || BO->getOperand(1) == Neg) &&
- (BO->getOpcode() == Instruction::UDiv ||
- BO->getOpcode() == Instruction::SDiv)) {
- Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1);
+ Value *Neg = dyn_castNegVal(Y);
+ if (Div && Div->hasOneUse() &&
+ (Div->getOperand(1) == Y || Div->getOperand(1) == Neg) &&
+ (Div->getOpcode() == Instruction::UDiv ||
+ Div->getOpcode() == Instruction::SDiv)) {
+ Value *X = Div->getOperand(0), *DivOp1 = Div->getOperand(1);
// If the division is exact, X % Y is zero, so we end up with X or -X.
- if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO))
- if (SDiv->isExact()) {
- if (Op1BO == Op1C)
- return replaceInstUsesWith(I, Op0BO);
- return BinaryOperator::CreateNeg(Op0BO);
- }
-
- Value *Rem;
- if (BO->getOpcode() == Instruction::UDiv)
- Rem = Builder->CreateURem(Op0BO, Op1BO);
- else
- Rem = Builder->CreateSRem(Op0BO, Op1BO);
- Rem->takeName(BO);
+ if (Div->isExact()) {
+ if (DivOp1 == Y)
+ return replaceInstUsesWith(I, X);
+ return BinaryOperator::CreateNeg(X);
+ }
- if (Op1BO == Op1C)
- return BinaryOperator::CreateSub(Op0BO, Rem);
- return BinaryOperator::CreateSub(Rem, Op0BO);
+ auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem
+ : Instruction::SRem;
+ Value *Rem = Builder->CreateBinOp(RemOpc, X, DivOp1);
+ if (DivOp1 == Y)
+ return BinaryOperator::CreateSub(X, Rem);
+ return BinaryOperator::CreateSub(Rem, X);
}
}
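The identity driving this block, (X / Y) * Y == X - (X % Y), matches C++'s truncating division and remainder, which follow the same rules as sdiv/srem, so it can be checked directly with a small standalone loop (division by zero excluded):

    #include <cassert>

    int main() {
      for (int X = -50; X <= 50; ++X) {
        for (int Y = -50; Y <= 50; ++Y) {
          if (Y == 0)
            continue;
          assert((X / Y) * Y == X - (X % Y));   // (X / Y) * Y  = X - (X % Y)
          assert((X / Y) * -Y == (X % Y) - X);  // (X / Y) * -Y = (X % Y) - X
        }
      }
      return 0;
    }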
@@ -1461,16 +1455,16 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
- } else if (isa<PHINode>(Op0I)) {
+ } else if (auto *PN = dyn_cast<PHINode>(Op0I)) {
using namespace llvm::PatternMatch;
const APInt *Op1Int;
if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&
(I.getOpcode() == Instruction::URem ||
!Op1Int->isMinSignedValue())) {
- // FoldOpIntoPhi will speculate instructions to the end of the PHI's
+ // foldOpIntoPhi will speculate instructions to the end of the PHI's
// predecessor blocks, so do this only if we know the srem or urem
// will not fault.
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (Instruction *NV = foldOpIntoPhi(I, PN))
return NV;
}
}
diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 4cbffe9533b75..85e5b6ba2dc2f 100644
--- a/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -457,8 +457,8 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) {
}
// The more common cases of a phi with no constant operands or just one
- // variable operand are handled by FoldPHIArgOpIntoPHI() and FoldOpIntoPhi()
- // respectively. FoldOpIntoPhi() wants to do the opposite transform that is
+ // variable operand are handled by FoldPHIArgOpIntoPHI() and foldOpIntoPhi()
+ // respectively. foldOpIntoPhi() wants to do the opposite transform that is
// performed here. It tries to replicate a cast in the phi operand's basic
// block to expose other folding opportunities. Thus, InstCombine will
// infinite loop without this check.
@@ -507,7 +507,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {
// Be careful about transforming integer PHIs. We don't want to pessimize
// the code by turning an i32 into an i1293.
if (PN.getType()->isIntegerTy() && CastSrcTy->isIntegerTy()) {
- if (!ShouldChangeType(PN.getType(), CastSrcTy))
+ if (!shouldChangeType(PN.getType(), CastSrcTy))
return nullptr;
}
} else if (isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)) {
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 36644845352ed..693b6c95c169c 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -120,6 +120,16 @@ static Constant *getSelectFoldableConstant(Instruction *I) {
/// We have (select c, TI, FI), and we know that TI and FI have the same opcode.
Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
Instruction *FI) {
+ // Don't break up min/max patterns. The hasOneUse checks below prevent that
+ // for most cases, but vector min/max with bitcasts can be transformed. If the
+ // one-use restrictions are eased for other patterns, we still don't want to
+ // obfuscate min/max.
+ if ((match(&SI, m_SMin(m_Value(), m_Value())) ||
+ match(&SI, m_SMax(m_Value(), m_Value())) ||
+ match(&SI, m_UMin(m_Value(), m_Value())) ||
+ match(&SI, m_UMax(m_Value(), m_Value()))))
+ return nullptr;
+
// If this is a cast from the same type, merge.
if (TI->getNumOperands() == 1 && TI->isCast()) {
Type *FIOpndTy = FI->getOperand(0)->getType();
@@ -364,7 +374,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
/// into:
/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false)
static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
- InstCombiner::BuilderTy *Builder) {
+ InstCombiner::BuilderTy *Builder) {
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
@@ -395,13 +405,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) ||
match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) {
IntrinsicInst *II = cast<IntrinsicInst>(Count);
- IRBuilder<> Builder(II);
// Explicitly clear the 'undef_on_zero' flag.
IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone());
Type *Ty = NewI->getArgOperand(1)->getType();
NewI->setArgOperand(1, Constant::getNullValue(Ty));
- Builder.Insert(NewI);
- return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType());
+ Builder->Insert(NewI);
+ return Builder->CreateZExtOrTrunc(NewI, ValueOnZero->getType());
}
return nullptr;
@@ -500,18 +509,16 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
return true;
}
-/// If this is an integer min/max where the select's 'true' operand is a
-/// constant, canonicalize that constant to the 'false' operand:
-/// select (icmp Pred X, C), C, X --> select (icmp Pred' X, C), X, C
+/// If this is an integer min/max (icmp + select) with a constant operand,
+/// create the canonical icmp for the min/max operation and canonicalize the
+/// constant to the 'false' operand of the select:
+/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2
+/// Note: if C1 != C2, this will change the icmp constant to the existing
+/// constant operand of the select.
static Instruction *
canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
InstCombiner::BuilderTy &Builder) {
- // TODO: We should also canonicalize min/max when the select has a different
- // constant value than the cmp constant, but we need to fix the backend first.
- if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)) ||
- !isa<Constant>(Sel.getTrueValue()) ||
- isa<Constant>(Sel.getFalseValue()) ||
- Cmp.getOperand(1) != Sel.getTrueValue())
+ if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)))
return nullptr;
// Canonicalize the compare predicate based on whether we have min or max.
@@ -526,16 +533,25 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
default: return nullptr;
}
- // Canonicalize the constant to the right side.
- if (isa<Constant>(LHS))
- std::swap(LHS, RHS);
+ // Is this already canonical?
+ if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS &&
+ Cmp.getPredicate() == NewPred)
+ return nullptr;
+
+ // Create the canonical compare and plug it into the select.
+ Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS));
- Value *NewCmp = Builder.CreateICmp(NewPred, LHS, RHS);
- SelectInst *NewSel = SelectInst::Create(NewCmp, LHS, RHS, "", nullptr, &Sel);
+ // If the select operands did not change, we're done.
+ if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS)
+ return &Sel;
- // We swapped the select operands, so swap the metadata too.
- NewSel->swapProfMetadata();
- return NewSel;
+ // If we are swapping the select operands, swap the metadata too.
+ assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS &&
+ "Unexpected results from matchSelectPattern");
+ Sel.setTrueValue(LHS);
+ Sel.setFalseValue(RHS);
+ Sel.swapProfMetadata();
+ return &Sel;
}
/// Visit a SelectInst that has an ICmpInst as its first operand.
@@ -786,7 +802,9 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,
// This transform is performance neutral if we can elide at least one xor from
// the set of three operands, since we'll be tacking on an xor at the very
// end.
- if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) &&
+ if (SelectPatternResult::isMinOrMax(SPF1) &&
+ SelectPatternResult::isMinOrMax(SPF2) &&
+ IsFreeOrProfitableToInvert(A, NotA, ElidesXor) &&
IsFreeOrProfitableToInvert(B, NotB, ElidesXor) &&
IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) {
if (!NotA)
@@ -1035,8 +1053,10 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
// If the select condition element is false, choose from the 2nd vector.
Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts));
} else if (isa<UndefValue>(Elt)) {
- // If the select condition element is undef, the shuffle mask is undef.
- Mask.push_back(UndefValue::get(Int32Ty));
+ // Undef in a select condition (choose one of the operands) does not mean
+ // the same thing as undef in a shuffle mask (any value is acceptable), so
+ // give up.
+ return nullptr;
} else {
// Bail out on a constant expression.
return nullptr;
@@ -1364,11 +1384,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
// See if we can fold the select into a phi node if the condition is a select.
- if (isa<PHINode>(SI.getCondition()))
+ if (auto *PN = dyn_cast<PHINode>(SI.getCondition()))
// The true/false values have to be live in the PHI predecessor's blocks.
if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) &&
canSelectOperandBeMappingIntoPredBlock(FalseVal, SI))
- if (Instruction *NV = FoldOpIntoPhi(SI))
+ if (Instruction *NV = foldOpIntoPhi(SI, PN))
return NV;
if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
@@ -1450,6 +1470,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
}
+ // If we can compute the condition, there's no need for a select.
+ // Like the above fold, we are attempting to reduce compile-time cost by
+ // putting this fold here with limitations rather than in InstSimplify.
+ // The motivation for this call into value tracking is to take advantage of
+ // the assumption cache, so make sure that is populated.
+ if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) {
+ APInt KnownOne(1, 0), KnownZero(1, 0);
+ computeKnownBits(CondVal, KnownZero, KnownOne, 0, &SI);
+ if (KnownOne == 1)
+ return replaceInstUsesWith(SI, TrueVal);
+ if (KnownZero == 1)
+ return replaceInstUsesWith(SI, FalseVal);
+ }
+
if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder))
return BitCastSel;
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 4ff9b64ac57cf..9aa679c60e47b 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -22,8 +22,8 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
- assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ assert(Op0->getType() == Op1->getType());
// See if we can fold away this shift.
if (SimplifyDemandedInstructionBits(I))
@@ -65,63 +65,60 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
}
/// Return true if we can simplify two logical (either left or right) shifts
-/// that have constant shift amounts.
-static bool canEvaluateShiftedShift(unsigned FirstShiftAmt,
- bool IsFirstShiftLeft,
- Instruction *SecondShift, InstCombiner &IC,
+/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
+static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
+ Instruction *InnerShift, InstCombiner &IC,
Instruction *CxtI) {
- assert(SecondShift->isLogicalShift() && "Unexpected instruction type");
+ assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
- // We need constant shifts.
- auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1));
- if (!SecondShiftConst)
+ // We need constant scalar or constant splat shifts.
+ const APInt *InnerShiftConst;
+ if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
return false;
- unsigned SecondShiftAmt = SecondShiftConst->getZExtValue();
- bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl;
-
- // We can always fold shl(c1) + shl(c2) -> shl(c1+c2).
- // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2).
- if (IsFirstShiftLeft == IsSecondShiftLeft)
+ // Two logical shifts in the same direction:
+ // shl (shl X, C1), C2 --> shl X, C1 + C2
+ // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+ bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+ if (IsInnerShl == IsOuterShl)
return true;
- // We can always fold lshr(c) + shl(c) -> and(c2).
- // We can always fold shl(c) + lshr(c) -> and(c2).
- if (FirstShiftAmt == SecondShiftAmt)
+ // Equal shift amounts in opposite directions become bitwise 'and':
+ // lshr (shl X, C), C --> and X, C'
+ // shl (lshr X, C), C --> and X, C'
+ unsigned InnerShAmt = InnerShiftConst->getZExtValue();
+ if (InnerShAmt == OuterShAmt)
return true;
- unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits();
-
// If the 2nd shift is bigger than the 1st, we can fold:
- // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or
- // shl(c1) + lshr(c2) -> lshr(c3) + and(c4),
+ // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
+ // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
// but it isn't profitable unless we know the and'd out bits are already zero.
- // Also check that the 2nd shift is valid (less than the type width) or we'll
- // crash trying to produce the bit mask for the 'and'.
- if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) {
- unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt
- : SecondShiftAmt - FirstShiftAmt;
- APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift;
- if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI))
+ // Also, check that the inner shift is valid (less than the type width) or
+ // we'll crash trying to produce the bit mask for the 'and'.
+ unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
+ if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) {
+ unsigned MaskShift =
+ IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
+ APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
+ if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
return true;
}
return false;
}
-/// See if we can compute the specified value, but shifted
-/// logically to the left or right by some number of bits. This should return
-/// true if the expression can be computed for the same cost as the current
-/// expression tree. This is used to eliminate extraneous shifting from things
-/// like:
+/// See if we can compute the specified value, but shifted logically to the left
+/// or right by some number of bits. This should return true if the expression
+/// can be computed for the same cost as the current expression tree. This is
+/// used to eliminate extraneous shifting from things like:
/// %C = shl i128 %A, 64
/// %D = shl i128 %B, 96
/// %E = or i128 %C, %D
/// %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, the GetShiftedValue function will be called to produce the
-/// value.
-static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
+/// where the client will ask if E can be computed shifted right by 64-bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
InstCombiner &IC, Instruction *CxtI) {
// We can always evaluate constants shifted.
if (isa<Constant>(V))
@@ -165,8 +162,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
case Instruction::Or:
case Instruction::Xor:
     // Bitwise operators can all be evaluated shifted.
- return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
- CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+ return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
+ canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
case Instruction::Shl:
case Instruction::LShr:
@@ -176,8 +173,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
SelectInst *SI = cast<SelectInst>(I);
Value *TrueVal = SI->getTrueValue();
Value *FalseVal = SI->getFalseValue();
- return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
- CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+ return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
+ canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -185,16 +182,79 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
- if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+ if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
return false;
return true;
}
}
}
-/// When CanEvaluateShifted returned true for an expression,
-/// this value inserts the new computation that produces the shifted value.
-static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
+/// Fold OuterShift (InnerShift X, C1), C2.
+/// See canEvaluateShiftedShift() for the constraints on these instructions.
+static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
+ bool IsOuterShl,
+ InstCombiner::BuilderTy &Builder) {
+ bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+ Type *ShType = InnerShift->getType();
+ unsigned TypeWidth = ShType->getScalarSizeInBits();
+
+ // We only accept shifts-by-a-constant in canEvaluateShifted().
+ const APInt *C1;
+ match(InnerShift->getOperand(1), m_APInt(C1));
+ unsigned InnerShAmt = C1->getZExtValue();
+
+ // Change the shift amount and clear the appropriate IR flags.
+ auto NewInnerShift = [&](unsigned ShAmt) {
+ InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
+ if (IsInnerShl) {
+ InnerShift->setHasNoUnsignedWrap(false);
+ InnerShift->setHasNoSignedWrap(false);
+ } else {
+ InnerShift->setIsExact(false);
+ }
+ return InnerShift;
+ };
+
+ // Two logical shifts in the same direction:
+ // shl (shl X, C1), C2 --> shl X, C1 + C2
+ // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+ if (IsInnerShl == IsOuterShl) {
+ // If this is an oversized composite shift, then unsigned shifts get 0.
+ if (InnerShAmt + OuterShAmt >= TypeWidth)
+ return Constant::getNullValue(ShType);
+
+ return NewInnerShift(InnerShAmt + OuterShAmt);
+ }
+
+ // Equal shift amounts in opposite directions become bitwise 'and':
+ // lshr (shl X, C), C --> and X, C'
+ // shl (lshr X, C), C --> and X, C'
+ if (InnerShAmt == OuterShAmt) {
+ APInt Mask = IsInnerShl
+ ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
+ : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
+ Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
+ ConstantInt::get(ShType, Mask));
+ if (auto *AndI = dyn_cast<Instruction>(And)) {
+ AndI->moveBefore(InnerShift);
+ AndI->takeName(InnerShift);
+ }
+ return And;
+ }
+
+ assert(InnerShAmt > OuterShAmt &&
+ "Unexpected opposite direction logical shift pair");
+
+ // In general, we would need an 'and' for this transform, but
+ // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
+ // lshr (shl X, C1), C2 --> shl X, C1 - C2
+ // shl (lshr X, C1), C2 --> lshr X, C1 - C2
+ return NewInnerShift(InnerShAmt - OuterShAmt);
+}
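The equal-amount, opposite-direction case above collapses the shift pair into a single mask. A standalone 8-bit brute force of both directions (illustrative only; LowMask/HighMask mirror the APInt masks built in the code):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned C = 1; C < 8; ++C) {
        uint8_t LowMask = static_cast<uint8_t>(0xffu >> C);   // getLowBitsSet(8, 8 - C)
        uint8_t HighMask = static_cast<uint8_t>(0xffu << C);  // getHighBitsSet(8, 8 - C)
        for (unsigned X = 0; X < 256; ++X) {
          // lshr (shl X, C), C --> and X, LowMask
          assert((static_cast<uint8_t>(X << C) >> C) == (X & LowMask));
          // shl (lshr X, C), C --> and X, HighMask
          assert(static_cast<uint8_t>((X >> C) << C) == (X & HighMask));
        }
      }
      return 0;
    }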
+
+/// When canEvaluateShifted() returns true for an expression, this function
+/// inserts the new computation that produces the shifted value.
+static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
InstCombiner &IC, const DataLayout &DL) {
// We can always evaluate constants shifted.
if (Constant *C = dyn_cast<Constant>(V)) {
@@ -220,100 +280,21 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::Xor:
     // Bitwise operators can all be evaluated shifted.
I->setOperand(
- 0, GetShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
+ 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
I->setOperand(
- 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+ 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
return I;
- case Instruction::Shl: {
- BinaryOperator *BO = cast<BinaryOperator>(I);
- unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
-
- // We only accept shifts-by-a-constant in CanEvaluateShifted.
- ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
- // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
- if (isLeftShift) {
- // If this is oversized composite shift, then unsigned shifts get 0.
- unsigned NewShAmt = NumBits+CI->getZExtValue();
- if (NewShAmt >= TypeWidth)
- return Constant::getNullValue(I->getType());
-
- BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
- BO->setHasNoUnsignedWrap(false);
- BO->setHasNoSignedWrap(false);
- return I;
- }
-
- // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have
- // zeros.
- if (CI->getValue() == NumBits) {
- APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits));
- V = IC.Builder->CreateAnd(BO->getOperand(0),
- ConstantInt::get(BO->getContext(), Mask));
- if (Instruction *VI = dyn_cast<Instruction>(V)) {
- VI->moveBefore(BO);
- VI->takeName(BO);
- }
- return V;
- }
-
- // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that
- // the and won't be needed.
- assert(CI->getZExtValue() > NumBits);
- BO->setOperand(1, ConstantInt::get(BO->getType(),
- CI->getZExtValue() - NumBits));
- BO->setHasNoUnsignedWrap(false);
- BO->setHasNoSignedWrap(false);
- return BO;
- }
- // FIXME: This is almost identical to the SHL case. Refactor both cases into
- // a helper function.
- case Instruction::LShr: {
- BinaryOperator *BO = cast<BinaryOperator>(I);
- unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
- // We only accept shifts-by-a-constant in CanEvaluateShifted.
- ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
- // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
- if (!isLeftShift) {
- // If this is oversized composite shift, then unsigned shifts get 0.
- unsigned NewShAmt = NumBits+CI->getZExtValue();
- if (NewShAmt >= TypeWidth)
- return Constant::getNullValue(BO->getType());
-
- BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
- BO->setIsExact(false);
- return I;
- }
-
- // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have
- // zeros.
- if (CI->getValue() == NumBits) {
- APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits));
- V = IC.Builder->CreateAnd(I->getOperand(0),
- ConstantInt::get(BO->getContext(), Mask));
- if (Instruction *VI = dyn_cast<Instruction>(V)) {
- VI->moveBefore(I);
- VI->takeName(I);
- }
- return V;
- }
-
- // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that
- // the and won't be needed.
- assert(CI->getZExtValue() > NumBits);
- BO->setOperand(1, ConstantInt::get(BO->getType(),
- CI->getZExtValue() - NumBits));
- BO->setIsExact(false);
- return BO;
- }
+ case Instruction::Shl:
+ case Instruction::LShr:
+ return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
+ *(IC.Builder));
case Instruction::Select:
I->setOperand(
- 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+ 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
I->setOperand(
- 2, GetShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
+ 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
return I;
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -321,215 +302,39 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
- PN->setIncomingValue(i, GetShiftedValue(PN->getIncomingValue(i), NumBits,
+ PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
isLeftShift, IC, DL));
return PN;
}
}
}
-/// Try to fold (X << C1) << C2, where the shifts are some combination of
-/// shl/ashr/lshr.
-static Instruction *
-foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1,
- InstCombiner::BuilderTy *Builder) {
- Value *Op0 = I.getOperand(0);
- uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
-
- // Find out if this is a shift of a shift by a constant.
- BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0);
- if (ShiftOp && !ShiftOp->isShift())
- ShiftOp = nullptr;
-
- if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) {
-
- // This is a constant shift of a constant shift. Be careful about hiding
- // shl instructions behind bit masks. They are used to represent multiplies
- // by a constant, and it is important that simple arithmetic expressions
- // are still recognizable by scalar evolution.
- //
- // The transforms applied to shl are very similar to the transforms applied
- // to mul by constant. We can be more aggressive about optimizing right
- // shifts.
- //
- // Combinations of right and left shifts will still be optimized in
- // DAGCombine where scalar evolution no longer applies.
-
- ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1));
- uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
- uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
- assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
- if (ShiftAmt1 == 0)
- return nullptr; // Will be simplified in the future.
- Value *X = ShiftOp->getOperand(0);
-
- IntegerType *Ty = cast<IntegerType>(I.getType());
-
- // Check for (X << c1) << c2 and (X >> c1) >> c2
- if (I.getOpcode() == ShiftOp->getOpcode()) {
- uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift.
- // If this is an oversized composite shift, then unsigned shifts become
- // zero (handled in InstSimplify) and ashr saturates.
- if (AmtSum >= TypeBits) {
- if (I.getOpcode() != Instruction::AShr)
- return nullptr;
- AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr.
- }
-
- return BinaryOperator::Create(I.getOpcode(), X,
- ConstantInt::get(Ty, AmtSum));
- }
-
- if (ShiftAmt1 == ShiftAmt2) {
- // If we have ((X << C) >>u C), turn this into X & (-1 >>u C).
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1));
- return BinaryOperator::CreateAnd(
- X, ConstantInt::get(I.getContext(), Mask));
- }
- } else if (ShiftAmt1 < ShiftAmt2) {
- uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1;
-
- // (X >>?,exact C1) << C2 --> X << (C2-C1)
- // The inexact version is deferred to DAGCombine so we don't hide shl
- // behind a bit mask.
- if (I.getOpcode() == Instruction::Shl &&
- ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) {
- assert(ShiftOp->getOpcode() == Instruction::LShr ||
- ShiftOp->getOpcode() == Instruction::AShr);
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
- NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
- return NewShl;
- }
-
- // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2)
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- // (X <<nuw C1) >>u C2 --> X >>u (C2-C1)
- if (ShiftOp->hasNoUnsignedWrap()) {
- BinaryOperator *NewLShr =
- BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst);
- NewLShr->setIsExact(I.isExact());
- return NewLShr;
- }
- Value *Shift = Builder->CreateLShr(X, ShiftDiffCst);
-
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2));
- return BinaryOperator::CreateAnd(
- Shift, ConstantInt::get(I.getContext(), Mask));
- }
-
- // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
- // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
- if (I.getOpcode() == Instruction::AShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- if (ShiftOp->hasNoSignedWrap()) {
- // (X <<nsw C1) >>s C2 --> X >>s (C2-C1)
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewAShr =
- BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst);
- NewAShr->setIsExact(I.isExact());
- return NewAShr;
- }
- }
- } else {
- assert(ShiftAmt2 < ShiftAmt1);
- uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2;
-
- // (X >>?exact C1) << C2 --> X >>?exact (C1-C2)
- // The inexact version is deferred to DAGCombine so we don't hide shl
- // behind a bit mask.
- if (I.getOpcode() == Instruction::Shl &&
- ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShr =
- BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst);
- NewShr->setIsExact(true);
- return NewShr;
- }
-
- // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2)
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- if (ShiftOp->hasNoUnsignedWrap()) {
- // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2)
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoUnsignedWrap(true);
- return NewShl;
- }
- Value *Shift = Builder->CreateShl(X, ShiftDiffCst);
-
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2));
- return BinaryOperator::CreateAnd(
- Shift, ConstantInt::get(I.getContext(), Mask));
- }
-
- // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
- // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
- if (I.getOpcode() == Instruction::AShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- if (ShiftOp->hasNoSignedWrap()) {
- // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2)
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoSignedWrap(true);
- return NewShl;
- }
- }
- }
- }
-
- return nullptr;
-}
-
Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I) {
bool isLeftShift = I.getOpcode() == Instruction::Shl;
- ConstantInt *COp1 = nullptr;
- if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1))
- COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
- else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1))
- COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
- else
- COp1 = dyn_cast<ConstantInt>(Op1);
-
- if (!COp1)
+ const APInt *Op1C;
+ if (!match(Op1, m_APInt(Op1C)))
return nullptr;
// See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
if (I.getOpcode() != Instruction::AShr &&
- CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) {
+ canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
" to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
return replaceInstUsesWith(
- I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL));
+ I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
}
// See if we can simplify any instructions used by the instruction whose sole
// purpose is to compute bits we don't care about.
- uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
+ unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
- assert(!COp1->uge(TypeBits) &&
+ assert(!Op1C->uge(TypeBits) &&
"Shift over the type width should have been removed already");
- // ((X*C1) << C2) == (X * (C1 << C2))
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0))
- if (BO->getOpcode() == Instruction::Mul && isLeftShift)
- if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1)))
- return BinaryOperator::CreateMul(BO->getOperand(0),
- ConstantExpr::getShl(BOOp, Op1));
-
if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I))
return FoldedShift;
@@ -544,7 +349,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
isa<ConstantInt>(TrOp->getOperand(1))) {
// Okay, we'll do this xform. Make the shift of shift.
- Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType());
+ Constant *ShAmt =
+ ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
// (shift2 (shift1 & 0x00FF), c2)
Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName());
@@ -561,10 +367,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// shift. We know that it is a logical shift by a constant, so adjust the
// mask as appropriate.
if (I.getOpcode() == Instruction::Shl)
- MaskV <<= COp1->getZExtValue();
+ MaskV <<= Op1C->getZExtValue();
else {
assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
- MaskV = MaskV.lshr(COp1->getZExtValue());
+ MaskV = MaskV.lshr(Op1C->getZExtValue());
}
// shift1 & 0x00FF
@@ -598,7 +404,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,
Op0BO->getOperand(1)->getName());
- uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+ unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -634,7 +440,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,
Op0BO->getOperand(0)->getName());
- uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+ unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -705,9 +511,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
}
}
- if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder))
- return Folded;
-
return nullptr;
}
@@ -715,59 +518,97 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V =
- SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(),
- I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(),
+ I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *V = commonShiftTransforms(I))
return V;
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) {
- unsigned ShAmt = Op1C->getZExtValue();
-
- // Turn:
- // %zext = zext i32 %V to i64
- // %res = shl i64 %V, 8
- //
- // Into:
- // %shl = shl i32 %V, 8
- // %res = zext i32 %shl to i64
- //
- // This is only valid if %V would have zeros shifted out.
- if (auto *ZI = dyn_cast<ZExtInst>(I.getOperand(0))) {
- unsigned SrcBitWidth = ZI->getSrcTy()->getScalarSizeInBits();
- if (ShAmt < SrcBitWidth &&
- MaskedValueIsZero(ZI->getOperand(0),
- APInt::getHighBitsSet(SrcBitWidth, ShAmt), 0, &I)) {
- auto *Shl = Builder->CreateShl(ZI->getOperand(0), ShAmt);
- return new ZExtInst(Shl, I.getType());
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ unsigned BitWidth = I.getType()->getScalarSizeInBits();
+ Type *Ty = I.getType();
+
+ // shl (zext X), ShAmt --> zext (shl X, ShAmt)
+ // This is only valid if X would have zeros shifted out.
+ Value *X;
+ if (match(Op0, m_ZExt(m_Value(X)))) {
+ unsigned SrcWidth = X->getType()->getScalarSizeInBits();
+ if (ShAmt < SrcWidth &&
+ MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
+ return new ZExtInst(Builder->CreateShl(X, ShAmt), Ty);
+ }
+
+ // (X >>u C) << C --> X & (-1 << C)
+ if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) {
+ APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
+ }
+
+ // Be careful about hiding shl instructions behind bit masks. They are used
+ // to represent multiplies by a constant, and it is important that simple
+ // arithmetic expressions are still recognizable by scalar evolution.
+ // The inexact versions are deferred to DAGCombine, so we don't hide shl
+ // behind a bit mask.
+ const APInt *ShOp1;
+ if (match(Op0, m_CombineOr(m_Exact(m_LShr(m_Value(X), m_APInt(ShOp1))),
+ m_Exact(m_AShr(m_Value(X), m_APInt(ShOp1)))))) {
+ unsigned ShrAmt = ShOp1->getZExtValue();
+ if (ShrAmt < ShAmt) {
+ // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
+ auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
+ NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
+ return NewShl;
}
+ if (ShrAmt > ShAmt) {
+ // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
+ auto *NewShr = BinaryOperator::Create(
+ cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
+ NewShr->setIsExact(true);
+ return NewShr;
+ }
+ }
+
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized shifts are simplified to zero in InstSimplify.
+ if (AmtSum < BitWidth)
+ // (X << C1) << C2 --> X << (C1 + C2)
+ return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is a NUW shift.
if (!I.hasNoUnsignedWrap() &&
- MaskedValueIsZero(I.getOperand(0),
- APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0,
- &I)) {
+ MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setHasNoUnsignedWrap();
return &I;
}
- // If the shifted out value is all signbits, this is a NSW shift.
- if (!I.hasNoSignedWrap() &&
- ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) {
+ // If the shifted-out value is all signbits, then this is a NSW shift.
+ if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
I.setHasNoSignedWrap();
return &I;
}
}
- // (C1 << A) << C2 -> (C1 << C2) << A
- Constant *C1, *C2;
- Value *A;
- if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) &&
- match(I.getOperand(1), m_Constant(C2)))
- return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A);
+ Constant *C1;
+ if (match(Op1, m_Constant(C1))) {
+ Constant *C2;
+ Value *X;
+ // (C2 << X) << C1 --> (C2 << C1) << X
+ if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
+ return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);
+
+ // (X * C2) << C1 --> X * (C2 << C1)
+ if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
+ return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
+ }
return nullptr;
}
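
The shl folds above correspond to simple scalar identities. A standalone C++ spot-check (illustrative only, not patch code; all values are unsigned so the arithmetic is wrapping and well defined):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t X = 0x12345678u;
  const unsigned C = 5;

  // (X >>u C) << C --> X & (-1 << C)
  assert(((X >> C) << C) == (X & (~0u << C)));

  // (X * C2) << C1 --> X * (C2 << C1), valid in wrapping (modular) arithmetic.
  const uint32_t C1 = 3, C2 = 10;
  assert(((X * C2) << C1) == X * (C2 << C1));

  // shl (zext X), C --> zext (shl X, C) is only valid when X's top C bits are
  // zero, i.e. nothing is shifted out in the narrow type. Modelled here as an
  // i8 -> i32 zext with C = 3 and a value whose top 3 bits are clear.
  const uint8_t N = 0x1A;
  assert((uint32_t(N) << 3) == uint32_t(uint8_t(N << 3)));
  return 0;
}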
@@ -776,43 +617,83 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
- DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyLShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
- unsigned ShAmt = Op1C->getZExtValue();
-
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
- unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+ Type *Ty = I.getType();
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ auto *II = dyn_cast<IntrinsicInst>(Op0);
+ if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
+ (II->getIntrinsicID() == Intrinsic::ctlz ||
+ II->getIntrinsicID() == Intrinsic::cttz ||
+ II->getIntrinsicID() == Intrinsic::ctpop)) {
// ctlz.i32(x)>>5 --> zext(x == 0)
// cttz.i32(x)>>5 --> zext(x == 0)
// ctpop.i32(x)>>5 --> zext(x == -1)
- if ((II->getIntrinsicID() == Intrinsic::ctlz ||
- II->getIntrinsicID() == Intrinsic::cttz ||
- II->getIntrinsicID() == Intrinsic::ctpop) &&
- isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) {
- bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop;
- Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0);
- Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
- return new ZExtInst(Cmp, II->getType());
+ bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
+ Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
+ Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
+ return new ZExtInst(Cmp, Ty);
+ }
+
+ Value *X;
+ const APInt *ShOp1;
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned ShlAmt = ShOp1->getZExtValue();
+ if (ShlAmt < ShAmt) {
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
+ if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
+ // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
+ auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
+ NewLShr->setIsExact(I.isExact());
+ return NewLShr;
+ }
+ // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2)
+ Value *NewLShr = Builder->CreateLShr(X, ShiftDiff, "", I.isExact());
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
}
+ if (ShlAmt > ShAmt) {
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
+ if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
+ // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
+ auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
+ NewShl->setHasNoUnsignedWrap(true);
+ return NewShl;
+ }
+ // (X << C1) >>u C2 --> (X << (C1 - C2)) & (-1 >> C2)
+ Value *NewShl = Builder->CreateShl(X, ShiftDiff);
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
+ }
+ assert(ShlAmt == ShAmt);
+ // (X << C) >>u C --> X & (-1 >>u C)
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
+ }
+
+ if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized shifts are simplified to zero in InstSimplify.
+ if (AmtSum < BitWidth)
+ // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
+ return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
- 0, &I)){
+ MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setIsExact();
return &I;
}
}
-
return nullptr;
}
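
The bit-count folds above can be spot-checked with C++20's <bit> helpers, whose defined-at-zero results (countl_zero(0) == countr_zero(0) == 32 for a 32-bit value) match the case the transform models (illustrative only, compile as C++20):

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x : {0u, 1u, 0x80000000u, 0xFFFFFFFFu, 0x12345678u}) {
    // ctlz.i32(x) >> 5 --> zext(x == 0): only x == 0 produces a count of 32.
    assert((unsigned(std::countl_zero(x)) >> 5) == (x == 0 ? 1u : 0u));
    // cttz.i32(x) >> 5 --> zext(x == 0)
    assert((unsigned(std::countr_zero(x)) >> 5) == (x == 0 ? 1u : 0u));
    // ctpop.i32(x) >> 5 --> zext(x == -1): only all-ones has popcount 32.
    assert((unsigned(std::popcount(x)) >> 5) == (x == 0xFFFFFFFFu ? 1u : 0u));
  }
  return 0;
}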
@@ -820,48 +701,66 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
- DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyAShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
- unsigned ShAmt = Op1C->getZExtValue();
+ Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
- // If the input is a SHL by the same constant (ashr (shl X, C), C), then we
- // have a sign-extend idiom.
+ // If the shift amount equals the difference in width of the destination
+ // and source scalar types:
+ // ashr (shl (zext X), C), C --> sext X
Value *X;
- if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) {
- // If the input is an extension from the shifted amount value, e.g.
- // %x = zext i8 %A to i32
- // %y = shl i32 %x, 24
- // %z = ashr %y, 24
- // then turn this into "z = sext i8 A to i32".
- if (ZExtInst *ZI = dyn_cast<ZExtInst>(X)) {
- uint32_t SrcBits = ZI->getOperand(0)->getType()->getScalarSizeInBits();
- uint32_t DestBits = ZI->getType()->getScalarSizeInBits();
- if (Op1C->getZExtValue() == DestBits-SrcBits)
- return new SExtInst(ZI->getOperand(0), ZI->getType());
+ if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
+ ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
+ return new SExtInst(X, Ty);
+
+ // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
+ // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
+ const APInt *ShOp1;
+ if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned ShlAmt = ShOp1->getZExtValue();
+ if (ShlAmt < ShAmt) {
+ // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
+ auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
+ NewAShr->setIsExact(I.isExact());
+ return NewAShr;
}
+ if (ShlAmt > ShAmt) {
+ // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
+ auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
+ NewShl->setHasNoSignedWrap(true);
+ return NewShl;
+ }
+ }
+
+ if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized arithmetic shifts replicate the sign bit.
+ AmtSum = std::min(AmtSum, BitWidth - 1);
+ // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
+ return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
- 0, &I)) {
+ MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setIsExact();
return &I;
}
}
// See if we can turn a signed shr into an unsigned shr.
- if (MaskedValueIsZero(Op0,
- APInt::getSignBit(I.getType()->getScalarSizeInBits()),
- 0, &I))
+ if (MaskedValueIsZero(Op0, APInt::getSignBit(BitWidth), 0, &I))
return BinaryOperator::CreateLShr(Op0, Op1);
return nullptr;
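
Two of the ashr facts used above, checked in ordinary C++ (illustrative only; compile as C++20 so that signed right shift is guaranteed to be arithmetic and out-of-range unsigned-to-signed conversion is modular, which is what the IR transform models):

#include <cassert>
#include <cstdint>

int main() {
  // ashr (shl (zext i8 X to i32), 24), 24 --> sext i8 X to i32
  for (uint8_t X : {uint8_t(0x00), uint8_t(0x7F), uint8_t(0x80), uint8_t(0xC3)}) {
    int32_t ShlAshr = int32_t(uint32_t(X) << 24) >> 24;
    assert(ShlAshr == int32_t(int8_t(X)));
  }

  // (X >>s C1) >>s C2 --> X >>s min(C1 + C2, BitWidth - 1): an oversized
  // arithmetic shift saturates because it only replicates the sign bit.
  const int32_t Y = INT32_MIN | 0x1234;
  assert(((Y >> 20) >> 20) == (Y >> 31));
  return 0;
}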
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 8b930bd95dfe3..4e6f02058d839 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -30,18 +30,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
assert(I && "No instruction?");
assert(OpNo < I->getNumOperands() && "Operand index too large");
- // If the operand is not a constant integer, nothing to do.
- ConstantInt *OpC = dyn_cast<ConstantInt>(I->getOperand(OpNo));
- if (!OpC) return false;
+ // The operand must be a constant integer or splat integer.
+ Value *Op = I->getOperand(OpNo);
+ const APInt *C;
+ if (!match(Op, m_APInt(C)))
+ return false;
// If there are no bits set that aren't demanded, nothing to do.
- Demanded = Demanded.zextOrTrunc(OpC->getValue().getBitWidth());
- if ((~Demanded & OpC->getValue()) == 0)
+ Demanded = Demanded.zextOrTrunc(C->getBitWidth());
+ if ((~Demanded & *C) == 0)
return false;
// This instruction is producing bits that are not demanded. Shrink the RHS.
- Demanded &= OpC->getValue();
- I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded));
+ Demanded &= *C;
+ I->setOperand(OpNo, ConstantInt::get(Op->getType(), Demanded));
return true;
}
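
Why ShrinkDemandedConstant is sound for the operations it is applied to: once the constant is intersected with the demanded bits, every demanded bit of the result is unchanged. A scalar C++ check (illustrative only; the masks and values are arbitrary):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t X = 0xDEADBEEFu, C = 0x12345FFFu;
  const uint32_t Demanded = 0x0000FFFFu;
  const uint32_t CShrunk = C & Demanded;

  // For bitwise ops, clearing undemanded bits of the constant cannot change
  // any demanded bit of the result.
  assert(((X & C) & Demanded) == ((X & CShrunk) & Demanded));
  assert(((X | C) & Demanded) == ((X | CShrunk) & Demanded));
  assert(((X ^ C) & Demanded) == ((X ^ CShrunk) & Demanded));

  // For add/sub the callers pass a mask of low bits only (the DemandedFromOps
  // mask used later), which is safe because carries and borrows never
  // propagate downward.
  const uint32_t LowMask = 0x000000FFu;
  assert(((X + C) & LowMask) == ((X + (C & LowMask)) & LowMask));
  assert(((X - C) & LowMask) == ((X - (C & LowMask)) & LowMask));
  return 0;
}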
@@ -66,12 +68,13 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
-bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask,
+bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
+ const APInt &DemandedMask,
APInt &KnownZero, APInt &KnownOne,
unsigned Depth) {
- auto *UserI = dyn_cast<Instruction>(U.getUser());
+ Use &U = I->getOperandUse(OpNo);
Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero,
- KnownOne, Depth, UserI);
+ KnownOne, Depth, I);
if (!NewVal) return false;
U = NewVal;
return true;
@@ -114,9 +117,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownOne.getBitWidth() == BitWidth &&
"Value *V, DemandedMask, KnownZero and KnownOne "
"must have same BitWidth");
- if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- // We know all of the bits for a constant!
- KnownOne = CI->getValue() & DemandedMask;
+ const APInt *C;
+ if (match(V, m_APInt(C))) {
+ // We know all of the bits for a scalar constant or a splat vector constant!
+ KnownOne = *C & DemandedMask;
KnownZero = ~KnownOne & DemandedMask;
return nullptr;
}
@@ -138,9 +142,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 6) // Limit search depth.
return nullptr;
- APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
- APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);
-
Instruction *I = dyn_cast<Instruction>(V);
if (!I) {
computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI);
@@ -151,107 +152,43 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// we can't do any simplifications of the operands, because DemandedMask
// only reflects the bits demanded by *one* of the users.
if (Depth != 0 && !I->hasOneUse()) {
- // Despite the fact that we can't simplify this instruction in all User's
- // context, we can at least compute the knownzero/knownone bits, and we can
- // do simplifications that apply to *just* the one user if we know that
- // this instruction has a simpler value in that context.
- if (I->getOpcode() == Instruction::And) {
- // If either the LHS or the RHS are Zero, the result is zero.
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
- CxtI);
- computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
- CxtI);
-
- // If all of the demanded bits are known 1 on one side, return the other.
- // These bits cannot contribute to the result of the 'and' in this
- // context.
- if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) ==
- (DemandedMask & ~LHSKnownZero))
- return I->getOperand(0);
- if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) ==
- (DemandedMask & ~RHSKnownZero))
- return I->getOperand(1);
-
- // If all of the demanded bits in the inputs are known zeros, return zero.
- if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask)
- return Constant::getNullValue(VTy);
-
- } else if (I->getOpcode() == Instruction::Or) {
- // We can simplify (X|Y) -> X or Y in the user's context if we know that
- // only bits from X or Y are demanded.
-
- // If either the LHS or the RHS are One, the result is One.
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
- CxtI);
- computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
- CxtI);
-
- // If all of the demanded bits are known zero on one side, return the
- // other. These bits cannot contribute to the result of the 'or' in this
- // context.
- if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) ==
- (DemandedMask & ~LHSKnownOne))
- return I->getOperand(0);
- if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) ==
- (DemandedMask & ~RHSKnownOne))
- return I->getOperand(1);
-
- // If all of the potentially set bits on one side are known to be set on
- // the other side, just use the 'other' side.
- if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) ==
- (DemandedMask & (~RHSKnownZero)))
- return I->getOperand(0);
- if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) ==
- (DemandedMask & (~LHSKnownZero)))
- return I->getOperand(1);
- } else if (I->getOpcode() == Instruction::Xor) {
- // We can simplify (X^Y) -> X or Y in the user's context if we know that
- // only bits from X or Y are demanded.
-
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
- CxtI);
- computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
- CxtI);
-
- // If all of the demanded bits are known zero on one side, return the
- // other.
- if ((DemandedMask & RHSKnownZero) == DemandedMask)
- return I->getOperand(0);
- if ((DemandedMask & LHSKnownZero) == DemandedMask)
- return I->getOperand(1);
- }
-
- // Compute the KnownZero/KnownOne bits to simplify things downstream.
- computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
- return nullptr;
+ return SimplifyMultipleUseDemandedBits(I, DemandedMask, KnownZero, KnownOne,
+ Depth, CxtI);
}
+ APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
+ APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);
+
// If this is the root being simplified, allow it to have multiple uses,
// just set the DemandedMask to all bits so that we can try to simplify the
// operands. This allows visitTruncInst (for example) to simplify the
// operand of a trunc without duplicating all the logic below.
if (Depth == 0 && !V->hasOneUse())
- DemandedMask = APInt::getAllOnesValue(BitWidth);
+ DemandedMask.setAllBits();
switch (I->getOpcode()) {
default:
computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
break;
- case Instruction::And:
+ case Instruction::And: {
// If either the LHS or the RHS are Zero, the result is zero.
- if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero,
- LHSKnownZero, LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownZero, LHSKnownZero,
+ LHSKnownOne, Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
+ // Output known-0 bits are clear if the corresponding bit is zero in either the LHS or the RHS.
+ APInt IKnownZero = RHSKnownZero | LHSKnownZero;
+ // Output known-1 bits are only known if set in both the LHS & RHS.
+ APInt IKnownOne = RHSKnownOne & LHSKnownOne;
+
// If the client is only demanding bits that we know, return the known
// constant.
- if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)|
- (RHSKnownOne & LHSKnownOne))) == DemandedMask)
- return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne);
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(VTy, IKnownOne);
// If all of the demanded bits are known 1 on one side, return the other.
// These bits cannot contribute to the result of the 'and'.
@@ -262,34 +199,33 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
(DemandedMask & ~RHSKnownZero))
return I->getOperand(1);
- // If all of the demanded bits in the inputs are known zeros, return zero.
- if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask)
- return Constant::getNullValue(VTy);
-
// If the RHS is a constant, see if we can simplify it.
if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero))
return I;
- // Output known-1 bits are only known if set in both the LHS & RHS.
- KnownOne = RHSKnownOne & LHSKnownOne;
- // Output known-0 are known to be clear if zero in either the LHS | RHS.
- KnownZero = RHSKnownZero | LHSKnownZero;
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
break;
- case Instruction::Or:
+ }
+ case Instruction::Or: {
// If either the LHS or the RHS are One, the result is One.
- if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne,
- LHSKnownZero, LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownOne, LHSKnownZero,
+ LHSKnownOne, Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
+ // Output known-0 bits are only known if clear in both the LHS & RHS.
+ APInt IKnownZero = RHSKnownZero & LHSKnownZero;
+ // Output known-1 bits are set if the corresponding bit is set in either the LHS or the RHS.
+ APInt IKnownOne = RHSKnownOne | LHSKnownOne;
+
// If the client is only demanding bits that we know, return the known
// constant.
- if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)|
- (RHSKnownOne | LHSKnownOne))) == DemandedMask)
- return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne);
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(VTy, IKnownOne);
// If all of the demanded bits are known zero on one side, return the other.
// These bits cannot contribute to the result of the 'or'.
@@ -313,16 +249,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (ShrinkDemandedConstant(I, 1, DemandedMask))
return I;
- // Output known-0 bits are only known if clear in both the LHS & RHS.
- KnownZero = RHSKnownZero & LHSKnownZero;
- // Output known-1 are known to be set if set in either the LHS | RHS.
- KnownOne = RHSKnownOne | LHSKnownOne;
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
break;
+ }
case Instruction::Xor: {
- if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, LHSKnownZero,
- LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 0, DemandedMask, LHSKnownZero, LHSKnownOne,
+ Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
@@ -400,9 +335,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
// Output known-0 bits are known if clear or set in both the LHS & RHS.
- KnownZero= (RHSKnownZero & LHSKnownZero) | (RHSKnownOne & LHSKnownOne);
+ KnownZero = std::move(IKnownZero);
// Output known-1 are known to be set if set in only one of the LHS, RHS.
- KnownOne = (RHSKnownZero & LHSKnownOne) | (RHSKnownOne & LHSKnownZero);
+ KnownOne = std::move(IKnownOne);
break;
}
case Instruction::Select:
@@ -412,10 +347,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN)
return nullptr;
- if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, LHSKnownZero,
- LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 1, DemandedMask, LHSKnownZero, LHSKnownOne,
+ Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
@@ -434,8 +369,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
DemandedMask = DemandedMask.zext(truncBf);
KnownZero = KnownZero.zext(truncBf);
KnownOne = KnownOne.zext(truncBf);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne,
+ Depth + 1))
return I;
DemandedMask = DemandedMask.trunc(BitWidth);
KnownZero = KnownZero.trunc(BitWidth);
@@ -460,8 +395,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Don't touch a vector-to-scalar bitcast.
return nullptr;
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
break;
@@ -472,15 +407,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
DemandedMask = DemandedMask.trunc(SrcBitWidth);
KnownZero = KnownZero.trunc(SrcBitWidth);
KnownOne = KnownOne.trunc(SrcBitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne,
+ Depth + 1))
return I;
DemandedMask = DemandedMask.zext(BitWidth);
KnownZero = KnownZero.zext(BitWidth);
KnownOne = KnownOne.zext(BitWidth);
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
// The top bits are known to be zero.
- KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth);
+ KnownZero.setBitsFrom(SrcBitWidth);
break;
}
case Instruction::SExt: {
@@ -490,7 +425,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt InputDemandedBits = DemandedMask &
APInt::getLowBitsSet(BitWidth, SrcBitWidth);
- APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth));
+ APInt NewBits(APInt::getBitsSetFrom(BitWidth, SrcBitWidth));
// If any of the sign extended bits are demanded, we know that the sign
// bit is demanded.
if ((NewBits & DemandedMask) != 0)
@@ -499,8 +434,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
InputDemandedBits = InputDemandedBits.trunc(SrcBitWidth);
KnownZero = KnownZero.trunc(SrcBitWidth);
KnownOne = KnownOne.trunc(SrcBitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, InputDemandedBits, KnownZero, KnownOne,
+ Depth + 1))
return I;
InputDemandedBits = InputDemandedBits.zext(BitWidth);
KnownZero = KnownZero.zext(BitWidth);
@@ -530,11 +465,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Right fill the mask of bits for this ADD/SUB to demand the most
// significant bit and all those below it.
APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
- LHSKnownZero, LHSKnownOne, Depth + 1) ||
+ if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
+ SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnownZero, LHSKnownOne,
+ Depth + 1) ||
ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
- SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
- LHSKnownZero, LHSKnownOne, Depth + 1)) {
+ SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnownZero, RHSKnownOne,
+ Depth + 1)) {
// Disable the nsw and nuw flags here: We can no longer guarantee that
// we won't wrap after simplification. Removing the nsw/nuw flags is
// legal here because the top bit is not demanded.
@@ -543,6 +479,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
BinOP.setHasNoUnsignedWrap(false);
return I;
}
+
+ // If we are known to be adding/subtracting zeros to every bit below
+ // the highest demanded bit, we just return the other side.
+ if ((DemandedFromOps & RHSKnownZero) == DemandedFromOps)
+ return I->getOperand(0);
+ // We can't do this with the LHS for subtraction.
+ if (I->getOpcode() == Instruction::Add &&
+ (DemandedFromOps & LHSKnownZero) == DemandedFromOps)
+ return I->getOperand(1);
}
// Otherwise just hand the add/sub off to computeKnownBits to fill in
@@ -569,19 +514,19 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the shift is NUW/NSW, then it does demand the high bits.
ShlOperator *IOp = cast<ShlOperator>(I);
if (IOp->hasNoSignedWrap())
- DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
+ DemandedMaskIn.setHighBits(ShiftAmt+1);
else if (IOp->hasNoUnsignedWrap())
- DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
+ DemandedMaskIn.setHighBits(ShiftAmt);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
KnownZero <<= ShiftAmt;
KnownOne <<= ShiftAmt;
// low bits known zero.
if (ShiftAmt)
- KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ KnownZero.setLowBits(ShiftAmt);
}
break;
case Instruction::LShr:
@@ -595,19 +540,16 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the shift is exact, then it does demand the low bits (and knows that
// they are zero).
if (cast<LShrOperator>(I)->isExact())
- DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ DemandedMaskIn.setLowBits(ShiftAmt);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
- KnownZero = APIntOps::lshr(KnownZero, ShiftAmt);
- KnownOne = APIntOps::lshr(KnownOne, ShiftAmt);
- if (ShiftAmt) {
- // Compute the new bits that are at the top now.
- APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
- KnownZero |= HighBits; // high bits known zero.
- }
+ KnownZero = KnownZero.lshr(ShiftAmt);
+ KnownOne = KnownOne.lshr(ShiftAmt);
+ if (ShiftAmt)
+ KnownZero.setHighBits(ShiftAmt); // high bits known zero.
}
break;
case Instruction::AShr:
@@ -635,26 +577,26 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If any of the "high bits" are demanded, we should set the sign bit as
// demanded.
if (DemandedMask.countLeadingZeros() <= ShiftAmt)
- DemandedMaskIn.setBit(BitWidth-1);
+ DemandedMaskIn.setSignBit();
// If the shift is exact, then it does demand the low bits (and knows that
// they are zero).
if (cast<AShrOperator>(I)->isExact())
- DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ DemandedMaskIn.setLowBits(ShiftAmt);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
// Compute the new bits that are at the top now.
APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
- KnownZero = APIntOps::lshr(KnownZero, ShiftAmt);
- KnownOne = APIntOps::lshr(KnownOne, ShiftAmt);
+ KnownZero = KnownZero.lshr(ShiftAmt);
+ KnownOne = KnownOne.lshr(ShiftAmt);
// Handle the sign bits.
APInt SignBit(APInt::getSignBit(BitWidth));
// Adjust to where it is now in the mask.
- SignBit = APIntOps::lshr(SignBit, ShiftAmt);
+ SignBit = SignBit.lshr(ShiftAmt);
// If the input sign bit is known to be zero, or if none of the top bits
// are demanded, turn this into an unsigned shift right.
@@ -683,8 +625,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt LowBits = RA - 1;
APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), Mask2, LHSKnownZero,
- LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, Mask2, LHSKnownZero, LHSKnownOne,
+ Depth + 1))
return I;
// The low bits of LHS are unchanged by the srem.
@@ -693,12 +635,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If LHS is non-negative or has all low bits zero, then the upper bits
// are all zero.
- if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits))
+ if (LHSKnownZero.isNegative() || ((LHSKnownZero & LowBits) == LowBits))
KnownZero |= ~LowBits;
// If LHS is negative and not all low bits are zero, then the upper bits
// are all one.
- if (LHSKnownOne[BitWidth-1] && ((LHSKnownOne & LowBits) != 0))
+ if (LHSKnownOne.isNegative() && ((LHSKnownOne & LowBits) != 0))
KnownOne |= ~LowBits;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
@@ -713,21 +655,17 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
CxtI);
// If it's known zero, our sign bit is also zero.
if (LHSKnownZero.isNegative())
- KnownZero.setBit(KnownZero.getBitWidth() - 1);
+ KnownZero.setSignBit();
}
break;
case Instruction::URem: {
APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);
APInt AllOnes = APInt::getAllOnesValue(BitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes, KnownZero2,
- KnownOne2, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(1), AllOnes, KnownZero2,
- KnownOne2, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, AllOnes, KnownZero2, KnownOne2, Depth + 1) ||
+ SimplifyDemandedBits(I, 1, AllOnes, KnownZero2, KnownOne2, Depth + 1))
return I;
unsigned Leaders = KnownZero2.countLeadingOnes();
- Leaders = std::max(Leaders,
- KnownZero2.countLeadingOnes());
KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
break;
}
@@ -792,11 +730,11 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return ConstantInt::getNullValue(VTy);
// We know that the upper bits are set to zero.
- KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - ArgWidth);
+ KnownZero.setBitsFrom(ArgWidth);
return nullptr;
}
case Intrinsic::x86_sse42_crc32_64_64:
- KnownZero = APInt::getHighBitsSet(64, 32);
+ KnownZero.setBitsFrom(32);
return nullptr;
}
}
@@ -811,6 +749,150 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return nullptr;
}
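
The add/sub simplification added above (return the other operand when one operand is known zero on every bit at or below the highest demanded bit) follows from the fact that carries only move upward. A small C++ illustration with a low-bits demanded mask (illustrative only, not patch code):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t DemandedFromOps = 0x00000FFFu; // demand only the low 12 bits
  const uint32_t X = 0xCAFED00Du;
  const uint32_t Y = 0xABCDE000u;               // zero on every demanded bit
  assert((Y & DemandedFromOps) == 0);

  // add: either operand can be dropped when it is zero on the demanded bits.
  assert(((X + Y) & DemandedFromOps) == (X & DemandedFromOps));
  assert(((Y + X) & DemandedFromOps) == (X & DemandedFromOps));

  // sub: only the RHS can be dropped. Dropping the LHS would claim that
  // (Y - X) equals X on the demanded bits, which is false for this X.
  assert(((X - Y) & DemandedFromOps) == (X & DemandedFromOps));
  assert(((Y - X) & DemandedFromOps) != (X & DemandedFromOps));
  return 0;
}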
+/// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne
+/// bits. It also tries to handle simplifications that can be done based on
+/// DemandedMask, but without modifying the Instruction.
+Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I,
+ const APInt &DemandedMask,
+ APInt &KnownZero,
+ APInt &KnownOne,
+ unsigned Depth,
+ Instruction *CxtI) {
+ unsigned BitWidth = DemandedMask.getBitWidth();
+ Type *ITy = I->getType();
+
+ APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
+ APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);
+
+ // Even though we can't simplify this instruction in every user's context,
+ // we can at least compute the knownzero/knownone bits, and we can do
+ // simplifications that apply to *just* the one user if we know that this
+ // instruction has a simpler value in that context.
+ switch (I->getOpcode()) {
+ case Instruction::And: {
+ // If either the LHS or the RHS are Zero, the result is zero.
+ computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+ CxtI);
+ computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+ CxtI);
+
+ // Output known-0 bits are clear if the corresponding bit is zero in either the LHS or the RHS.
+ APInt IKnownZero = RHSKnownZero | LHSKnownZero;
+ // Output known-1 bits are only known if set in both the LHS & RHS.
+ APInt IKnownOne = RHSKnownOne & LHSKnownOne;
+
+ // If the client is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, IKnownOne);
+
+ // If all of the demanded bits are known 1 on one side, return the other.
+ // These bits cannot contribute to the result of the 'and' in this
+ // context.
+ if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) ==
+ (DemandedMask & ~LHSKnownZero))
+ return I->getOperand(0);
+ if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) ==
+ (DemandedMask & ~RHSKnownZero))
+ return I->getOperand(1);
+
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
+ break;
+ }
+ case Instruction::Or: {
+ // We can simplify (X|Y) -> X or Y in the user's context if we know that
+ // only bits from X or Y are demanded.
+
+ // If either the LHS or the RHS are One, the result is One.
+ computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+ CxtI);
+ computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+ CxtI);
+
+ // Output known-0 bits are only known if clear in both the LHS & RHS.
+ APInt IKnownZero = RHSKnownZero & LHSKnownZero;
+ // Output known-1 bits are set if the corresponding bit is set in either the LHS or the RHS.
+ APInt IKnownOne = RHSKnownOne | LHSKnownOne;
+
+ // If the client is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, IKnownOne);
+
+ // If all of the demanded bits are known zero on one side, return the
+ // other. These bits cannot contribute to the result of the 'or' in this
+ // context.
+ if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) ==
+ (DemandedMask & ~LHSKnownOne))
+ return I->getOperand(0);
+ if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) ==
+ (DemandedMask & ~RHSKnownOne))
+ return I->getOperand(1);
+
+ // If all of the potentially set bits on one side are known to be set on
+ // the other side, just use the 'other' side.
+ if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) ==
+ (DemandedMask & (~RHSKnownZero)))
+ return I->getOperand(0);
+ if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) ==
+ (DemandedMask & (~LHSKnownZero)))
+ return I->getOperand(1);
+
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
+ break;
+ }
+ case Instruction::Xor: {
+ // We can simplify (X^Y) -> X or Y in the user's context if we know that
+ // only bits from X or Y are demanded.
+
+ computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+ CxtI);
+ computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+ CxtI);
+
+ // Output known-0 bits are known if clear or set in both the LHS & RHS.
+ APInt IKnownZero = (RHSKnownZero & LHSKnownZero) |
+ (RHSKnownOne & LHSKnownOne);
+ // Output known-1 bits are set if the corresponding bit is set in exactly one of the LHS and RHS.
+ APInt IKnownOne = (RHSKnownZero & LHSKnownOne) |
+ (RHSKnownOne & LHSKnownZero);
+
+ // If the client is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, IKnownOne);
+
+ // If all of the demanded bits are known zero on one side, return the
+ // other.
+ if ((DemandedMask & RHSKnownZero) == DemandedMask)
+ return I->getOperand(0);
+ if ((DemandedMask & LHSKnownZero) == DemandedMask)
+ return I->getOperand(1);
+
+ // Output known-0 bits are known if clear or set in both the LHS & RHS.
+ KnownZero = std::move(IKnownZero);
+ // Output known-1 bits are set if the corresponding bit is set in exactly one of the LHS and RHS.
+ KnownOne = std::move(IKnownOne);
+ break;
+ }
+ default:
+ // Compute the KnownZero/KnownOne bits to simplify things downstream.
+ computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
+
+ // If this user is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, KnownOne);
+
+ break;
+ }
+
+ return nullptr;
+}
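
The IKnownZero/IKnownOne combination rules used for and/or/xor (in both SimplifyDemandedUseBits and the new multiple-use helper) can be validated exhaustively on a small bit width: every pair of inputs consistent with the operands' known bits must yield a result consistent with the combined known bits. A standalone C++ model (illustrative only; the two Known values are arbitrary but non-conflicting):

#include <cassert>
#include <cstdint>

// KnownZero/KnownOne for an 8-bit value; a set mask bit means "known".
struct Known {
  uint8_t Zero, One;
};

// Combination rules mirroring the patch.
static Known knownAnd(Known L, Known R) {
  return {uint8_t(L.Zero | R.Zero), uint8_t(L.One & R.One)};
}
static Known knownOr(Known L, Known R) {
  return {uint8_t(L.Zero & R.Zero), uint8_t(L.One | R.One)};
}
static Known knownXor(Known L, Known R) {
  return {uint8_t((L.Zero & R.Zero) | (L.One & R.One)),
          uint8_t((L.Zero & R.One) | (L.One & R.Zero))};
}

static bool consistent(uint8_t V, Known K) {
  return (V & K.Zero) == 0 && (V & K.One) == K.One;
}

int main() {
  const Known L{0xF0, 0x01}; // values of the form 0000???1
  const Known R{0x0C, 0x80}; // values of the form 1???00??
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      if (!consistent(uint8_t(A), L) || !consistent(uint8_t(B), R))
        continue;
      assert(consistent(uint8_t(A & B), knownAnd(L, R)));
      assert(consistent(uint8_t(A | B), knownOr(L, R)));
      assert(consistent(uint8_t(A ^ B), knownXor(L, R)));
    }
  return 0;
}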
+
+
/// Helper routine of SimplifyDemandedUseBits. It tries to simplify
/// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into
/// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign
@@ -849,7 +931,7 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr,
unsigned ShrAmt = ShrOp1.getZExtValue();
KnownOne.clearAllBits();
- KnownZero = APInt::getBitsSet(KnownZero.getBitWidth(), 0, ShlAmt-1);
+ KnownZero.setLowBits(ShlAmt - 1);
KnownZero &= DemandedMask;
APInt BitMask1(APInt::getAllOnesValue(BitWidth));
@@ -1472,14 +1554,136 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
break;
}
+ case Intrinsic::x86_sse2_packssdw_128:
+ case Intrinsic::x86_sse2_packsswb_128:
+ case Intrinsic::x86_sse2_packuswb_128:
+ case Intrinsic::x86_sse41_packusdw:
+ case Intrinsic::x86_avx2_packssdw:
+ case Intrinsic::x86_avx2_packsswb:
+ case Intrinsic::x86_avx2_packusdw:
+ case Intrinsic::x86_avx2_packuswb:
+ case Intrinsic::x86_avx512_packssdw_512:
+ case Intrinsic::x86_avx512_packsswb_512:
+ case Intrinsic::x86_avx512_packusdw_512:
+ case Intrinsic::x86_avx512_packuswb_512: {
+ auto *Ty0 = II->getArgOperand(0)->getType();
+ unsigned InnerVWidth = Ty0->getVectorNumElements();
+ assert(VWidth == (InnerVWidth * 2) && "Unexpected input size");
+
+ unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128;
+ unsigned VWidthPerLane = VWidth / NumLanes;
+ unsigned InnerVWidthPerLane = InnerVWidth / NumLanes;
+
+ // Per lane, pack the elements of the first input and then the second.
+ // e.g.
+ // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3])
+ // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15])
+ for (int OpNum = 0; OpNum != 2; ++OpNum) {
+ APInt OpDemandedElts(InnerVWidth, 0);
+ for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+ unsigned LaneIdx = Lane * VWidthPerLane;
+ for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) {
+ unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum;
+ if (DemandedElts[Idx])
+ OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt);
+ }
+ }
+
+ // Demand elements from the operand.
+ auto *Op = II->getArgOperand(OpNum);
+ APInt OpUndefElts(InnerVWidth, 0);
+ TmpV = SimplifyDemandedVectorElts(Op, OpDemandedElts, OpUndefElts,
+ Depth + 1);
+ if (TmpV) {
+ II->setArgOperand(OpNum, TmpV);
+ MadeChange = true;
+ }
+
+ // Pack the operand's UNDEF elements, one lane at a time.
+ OpUndefElts = OpUndefElts.zext(VWidth);
+ for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+ APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane);
+ LaneElts = LaneElts.getLoBits(InnerVWidthPerLane);
+ LaneElts = LaneElts.shl(InnerVWidthPerLane * (2 * Lane + OpNum));
+ UndefElts |= LaneElts;
+ }
+ }
+ break;
+ }
+
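A self-contained C++ sketch (illustration only, not part of the patch) of the index arithmetic in the PACK case above, instantiated for the single-lane v8i16 PACK(v4i32 X, v4i32 Y) form; the demanded mask is an arbitrary example:

#include <cstdio>

int main() {
  // v8i16 PACK(v4i32 X, v4i32 Y): one 128-bit lane; result elements 0..3
  // come from X (OpNum 0), elements 4..7 come from Y (OpNum 1).
  const unsigned VWidth = 8, InnerVWidth = 4, NumLanes = 1;
  const unsigned VWidthPerLane = VWidth / NumLanes;           // 8
  const unsigned InnerVWidthPerLane = InnerVWidth / NumLanes; // 4

  const unsigned DemandedElts = 0b00110001; // result elements 0, 4 and 5

  for (unsigned OpNum = 0; OpNum != 2; ++OpNum) {
    unsigned OpDemandedElts = 0;
    for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
      unsigned LaneIdx = Lane * VWidthPerLane;
      for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) {
        unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum;
        if (DemandedElts & (1u << Idx))
          OpDemandedElts |= 1u << (Lane * InnerVWidthPerLane + Elt);
      }
    }
    // Prints 0x1 for X (element 0) and 0x3 for Y (elements 0 and 1).
    printf("operand %u demanded elements: 0x%x\n", OpNum, OpDemandedElts);
  }
  return 0;
}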
+ // PSHUFB
+ case Intrinsic::x86_ssse3_pshuf_b_128:
+ case Intrinsic::x86_avx2_pshuf_b:
+ case Intrinsic::x86_avx512_pshuf_b_512:
+ // PERMILVAR
+ case Intrinsic::x86_avx_vpermilvar_ps:
+ case Intrinsic::x86_avx_vpermilvar_ps_256:
+ case Intrinsic::x86_avx512_vpermilvar_ps_512:
+ case Intrinsic::x86_avx_vpermilvar_pd:
+ case Intrinsic::x86_avx_vpermilvar_pd_256:
+ case Intrinsic::x86_avx512_vpermilvar_pd_512:
+ // PERMV
+ case Intrinsic::x86_avx2_permd:
+ case Intrinsic::x86_avx2_permps: {
+ Value *Op1 = II->getArgOperand(1);
+ TmpV = SimplifyDemandedVectorElts(Op1, DemandedElts, UndefElts,
+ Depth + 1);
+ if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; }
+ break;
+ }
+
// SSE4A instructions leave the upper 64-bits of the 128-bit result
// in an undefined state.
case Intrinsic::x86_sse4a_extrq:
case Intrinsic::x86_sse4a_extrqi:
case Intrinsic::x86_sse4a_insertq:
case Intrinsic::x86_sse4a_insertqi:
- UndefElts |= APInt::getHighBitsSet(VWidth, VWidth / 2);
+ UndefElts.setHighBits(VWidth / 2);
break;
+ case Intrinsic::amdgcn_buffer_load:
+ case Intrinsic::amdgcn_buffer_load_format: {
+ if (VWidth == 1 || !DemandedElts.isMask())
+ return nullptr;
+
+    // TODO: Handle 3-element vectors when supported in code gen.
+ unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countTrailingOnes());
+ if (NewNumElts == VWidth)
+ return nullptr;
+
+ Module *M = II->getParent()->getParent()->getParent();
+ Type *EltTy = V->getType()->getVectorElementType();
+
+ Type *NewTy = (NewNumElts == 1) ? EltTy :
+ VectorType::get(EltTy, NewNumElts);
+
+ Function *NewIntrin = Intrinsic::getDeclaration(M, II->getIntrinsicID(),
+ NewTy);
+
+ SmallVector<Value *, 5> Args;
+ for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I)
+ Args.push_back(II->getArgOperand(I));
+
+ IRBuilderBase::InsertPointGuard Guard(*Builder);
+ Builder->SetInsertPoint(II);
+
+ CallInst *NewCall = Builder->CreateCall(NewIntrin, Args);
+ NewCall->takeName(II);
+ NewCall->copyMetadata(*II);
+ if (NewNumElts == 1) {
+ return Builder->CreateInsertElement(UndefValue::get(V->getType()),
+ NewCall, static_cast<uint64_t>(0));
+ }
+
+ SmallVector<uint32_t, 8> EltMask;
+ for (unsigned I = 0; I < VWidth; ++I)
+ EltMask.push_back(I);
+
+ Value *Shuffle = Builder->CreateShuffleVector(
+ NewCall, UndefValue::get(NewTy), EltMask);
+
+ MadeChange = true;
+ return Shuffle;
+ }
}
break;
}
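The amdgcn_buffer_load case above only fires when the demanded elements form a low mask; the following standalone C++ sketch reproduces that narrowing decision, with plain bit twiddling standing in for APInt::isMask, countTrailingOnes and PowerOf2Ceil (the mask is an arbitrary example, not part of the patch):

#include <cstdio>

// Round x up to the next power of two (valid for x >= 1).
static unsigned powerOf2Ceil(unsigned x) {
  unsigned P = 1;
  while (P < x)
    P <<= 1;
  return P;
}

int main() {
  const unsigned VWidth = 4;      // e.g. a <4 x float> buffer load
  unsigned DemandedElts = 0b0011; // only the first two lanes are used

  // The combine requires a mask of the form 0b0..01..1 (isMask()).
  bool IsLowMask =
      DemandedElts != 0 && ((DemandedElts + 1) & DemandedElts) == 0;
  if (VWidth == 1 || !IsLowMask) {
    printf("keep the load as is\n");
    return 0;
  }

  unsigned TrailingOnes = 0;
  for (unsigned M = DemandedElts; M & 1; M >>= 1)
    ++TrailingOnes;

  unsigned NewNumElts = powerOf2Ceil(TrailingOnes);
  if (NewNumElts == VWidth)
    printf("no narrowing possible\n");
  else
    printf("narrow the load to %u element(s)\n", NewNumElts); // prints 2 here
  return 0;
}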
diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index b2477f6c8633b..e89b400a4afc8 100644
--- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -645,6 +645,36 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) {
return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask);
}
+/// If we have an insertelement instruction feeding into another insertelement
+/// and the 2nd is inserting a constant into the vector, canonicalize that
+/// constant insertion before the insertion of a variable:
+///
+/// insertelement (insertelement X, Y, IdxC1), ScalarC, IdxC2 -->
+/// insertelement (insertelement X, ScalarC, IdxC2), Y, IdxC1
+///
+/// This has the potential to eliminate the 2nd insertelement instruction
+/// via constant folding of the scalar constant into a vector constant.
+static Instruction *hoistInsEltConst(InsertElementInst &InsElt2,
+ InstCombiner::BuilderTy &Builder) {
+ auto *InsElt1 = dyn_cast<InsertElementInst>(InsElt2.getOperand(0));
+ if (!InsElt1 || !InsElt1->hasOneUse())
+ return nullptr;
+
+ Value *X, *Y;
+ Constant *ScalarC;
+ ConstantInt *IdxC1, *IdxC2;
+ if (match(InsElt1->getOperand(0), m_Value(X)) &&
+ match(InsElt1->getOperand(1), m_Value(Y)) && !isa<Constant>(Y) &&
+ match(InsElt1->getOperand(2), m_ConstantInt(IdxC1)) &&
+ match(InsElt2.getOperand(1), m_Constant(ScalarC)) &&
+ match(InsElt2.getOperand(2), m_ConstantInt(IdxC2)) && IdxC1 != IdxC2) {
+ Value *NewInsElt1 = Builder.CreateInsertElement(X, ScalarC, IdxC2);
+ return InsertElementInst::Create(NewInsElt1, Y, IdxC1);
+ }
+
+ return nullptr;
+}
+
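The reorder performed by hoistInsEltConst above is only legal because two insertions at distinct indices commute; here is a tiny standalone C++ check of that property, with std::array standing in for an IR vector and arbitrary example values (illustration only, not part of the patch):

#include <array>
#include <cassert>

// Model of IR insertelement: produces a new vector, the input is unchanged.
static std::array<int, 4> insertElt(std::array<int, 4> V, int Val,
                                    unsigned Idx) {
  V[Idx] = Val;
  return V;
}

int main() {
  const std::array<int, 4> X = {10, 20, 30, 40};
  const int Y = 7;                     // the variable scalar
  const int ScalarC = -1;              // the constant scalar
  const unsigned IdxC1 = 1, IdxC2 = 3; // distinct indices, as the code requires

  // insertelement (insertelement X, Y, IdxC1), ScalarC, IdxC2
  auto Before = insertElt(insertElt(X, Y, IdxC1), ScalarC, IdxC2);
  // insertelement (insertelement X, ScalarC, IdxC2), Y, IdxC1
  auto After = insertElt(insertElt(X, ScalarC, IdxC2), Y, IdxC1);

  assert(Before == After && "insertions at distinct indices commute");
  return 0;
}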
/// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex
/// --> shufflevector X, CVec', Mask'
static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) {
@@ -806,6 +836,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) {
if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE))
return Shuf;
+ if (Instruction *NewInsElt = hoistInsEltConst(IE, *Builder))
+ return NewInsElt;
+
// Turn a sequence of inserts that broadcasts a scalar into a single
// insert + shufflevector.
if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE))
@@ -1107,12 +1140,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
SmallVector<int, 16> Mask = SVI.getShuffleMask();
Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
- bool MadeChange = false;
-
- // Undefined shuffle mask -> undefined value.
- if (isa<UndefValue>(SVI.getOperand(2)))
- return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType()));
+ if (auto *V = SimplifyShuffleVectorInst(LHS, RHS, SVI.getMask(),
+ SVI.getType(), DL, &TLI, &DT, &AC))
+ return replaceInstUsesWith(SVI, V);
+ bool MadeChange = false;
unsigned VWidth = SVI.getType()->getVectorNumElements();
APInt UndefElts(VWidth, 0);
@@ -1209,7 +1241,6 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
if (isShuffleExtractingFromLHS(SVI, Mask)) {
Value *V = LHS;
unsigned MaskElems = Mask.size();
- unsigned BegIdx = Mask.front();
VectorType *SrcTy = cast<VectorType>(V->getType());
unsigned VecBitWidth = SrcTy->getBitWidth();
unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType());
@@ -1223,6 +1254,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
// Only visit bitcasts that weren't previously handled.
BCs.push_back(BC);
for (BitCastInst *BC : BCs) {
+ unsigned BegIdx = Mask.front();
Type *TgtTy = BC->getDestTy();
unsigned TgtElemBitWidth = DL.getTypeSizeInBits(TgtTy);
if (!TgtElemBitWidth)
diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp
index 27fc34d231752..88ef17bbc8fa6 100644
--- a/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -82,18 +82,24 @@ static cl::opt<bool>
EnableExpensiveCombines("expensive-combines",
cl::desc("Enable expensive instruction combines"));
+static cl::opt<unsigned>
+MaxArraySize("instcombine-maxarray-size", cl::init(1024),
+ cl::desc("Maximum array size considered when doing a combine"));
+
Value *InstCombiner::EmitGEPOffset(User *GEP) {
return llvm::EmitGEPOffset(Builder, DL, GEP);
}
/// Return true if it is desirable to convert an integer computation from a
/// given bit width to a new bit width.
-/// We don't want to convert from a legal to an illegal type for example or from
-/// a smaller to a larger illegal type.
-bool InstCombiner::ShouldChangeType(unsigned FromWidth,
+/// We don't want to convert from a legal to an illegal type or from a smaller
+/// to a larger illegal type. A width of '1' is always treated as a legal type
+/// because i1 is a fundamental type in IR, and there are many specialized
+/// optimizations for i1 types.
+bool InstCombiner::shouldChangeType(unsigned FromWidth,
unsigned ToWidth) const {
- bool FromLegal = DL.isLegalInteger(FromWidth);
- bool ToLegal = DL.isLegalInteger(ToWidth);
+ bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth);
+ bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth);
// If this is a legal integer from type, and the result would be an illegal
// type, don't do the transformation.
@@ -109,14 +115,16 @@ bool InstCombiner::ShouldChangeType(unsigned FromWidth,
}
/// Return true if it is desirable to convert a computation from 'From' to 'To'.
-/// We don't want to convert from a legal to an illegal type for example or from
-/// a smaller to a larger illegal type.
-bool InstCombiner::ShouldChangeType(Type *From, Type *To) const {
+/// We don't want to convert from a legal to an illegal type or from a smaller
+/// to a larger illegal type. i1 is always treated as a legal type because it is
+/// a fundamental type in IR, and there are many specialized optimizations for
+/// i1 types.
+bool InstCombiner::shouldChangeType(Type *From, Type *To) const {
assert(From->isIntegerTy() && To->isIntegerTy());
unsigned FromWidth = From->getPrimitiveSizeInBits();
unsigned ToWidth = To->getPrimitiveSizeInBits();
- return ShouldChangeType(FromWidth, ToWidth);
+ return shouldChangeType(FromWidth, ToWidth);
}
// Return true, if No Signed Wrap should be maintained for I.
@@ -447,16 +455,11 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp,
/// This function returns the identity value for the given opcode, which can be
/// used to factor patterns like (X * 2) + X ==> (X * 2) + (X * 1) ==> X * (2 + 1).
-static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) {
+static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) {
if (isa<Constant>(V))
return nullptr;
- if (OpCode == Instruction::Mul)
- return ConstantInt::get(V->getType(), 1);
-
- // TODO: We can handle other cases e.g. Instruction::And, Instruction::Or etc.
-
- return nullptr;
+ return ConstantExpr::getBinOpIdentity(Opcode, V->getType());
}
/// This function factors binary ops which can be combined using distributive
@@ -468,8 +471,7 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) {
static Instruction::BinaryOps
getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode,
BinaryOperator *Op, Value *&LHS, Value *&RHS) {
- if (!Op)
- return Instruction::BinaryOpsEnd;
+ assert(Op && "Expected a binary operator");
LHS = Op->getOperand(0);
RHS = Op->getOperand(1);
@@ -499,11 +501,7 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder,
const DataLayout &DL, BinaryOperator &I,
Instruction::BinaryOps InnerOpcode, Value *A,
Value *B, Value *C, Value *D) {
-
- // If any of A, B, C, D are null, we can not factor I, return early.
- // Checking A and C should be enough.
- if (!A || !C || !B || !D)
- return nullptr;
+ assert(A && B && C && D && "All values must be provided");
Value *V = nullptr;
Value *SimplifiedInst = nullptr;
@@ -564,13 +562,11 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder,
if (isa<OverflowingBinaryOperator>(&I))
HasNSW = I.hasNoSignedWrap();
- if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS))
- if (isa<OverflowingBinaryOperator>(Op0))
- HasNSW &= Op0->hasNoSignedWrap();
+ if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS))
+ HasNSW &= LOBO->hasNoSignedWrap();
- if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS))
- if (isa<OverflowingBinaryOperator>(Op1))
- HasNSW &= Op1->hasNoSignedWrap();
+ if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS))
+ HasNSW &= ROBO->hasNoSignedWrap();
// We can propagate 'nsw' if we know that
// %Y = mul nsw i16 %X, C
@@ -599,31 +595,39 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
+ Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
- // Factorization.
- Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
- auto TopLevelOpcode = I.getOpcode();
- auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
- auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
-
- // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
- // a common term.
- if (LHSOpcode == RHSOpcode) {
- if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D))
- return V;
- }
-
- // The instruction has the form "(A op' B) op (C)". Try to factorize common
- // term.
- if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS,
- getIdentityValue(LHSOpcode, RHS)))
- return V;
+ {
+ // Factorization.
+ Value *A, *B, *C, *D;
+ Instruction::BinaryOps LHSOpcode, RHSOpcode;
+ if (Op0)
+ LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+ if (Op1)
+ RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
+
+ // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
+ // a common term.
+ if (Op0 && Op1 && LHSOpcode == RHSOpcode)
+ if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D))
+ return V;
+
+    // The instruction has the form "(A op' B) op (C)". Try to factorize a
+    // common term.
+ if (Op0)
+ if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
+ if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS,
+ Ident))
+ return V;
- // The instruction has the form "(B) op (C op' D)". Try to factorize common
- // term.
- if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS,
- getIdentityValue(RHSOpcode, LHS), C, D))
- return V;
+    // The instruction has the form "(B) op (C op' D)". Try to factorize a
+    // common term.
+ if (Op1)
+ if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
+ if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS, Ident,
+ C, D))
+ return V;
+ }
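A quick standalone arithmetic check in C++ of the identity-value trick the factorization block above relies on, matching the example in getIdentityValue's comment (illustration only):

#include <cassert>

int main() {
  // (X * 2) + X  ==>  (X * 2) + (X * 1)  ==>  X * (2 + 1)
  // '1' is the identity that ConstantExpr::getBinOpIdentity returns for mul,
  // so the bare operand X can participate in the common-term factorization.
  for (int X = -4; X <= 4; ++X)
    assert((X * 2) + X == X * (2 + 1));
  return 0;
}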
// Expansion.
if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {
@@ -720,6 +724,21 @@ Value *InstCombiner::dyn_castNegVal(Value *V) const {
if (C->getType()->getElementType()->isIntegerTy())
return ConstantExpr::getNeg(C);
+ if (ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
+ for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
+ Constant *Elt = CV->getAggregateElement(i);
+ if (!Elt)
+ return nullptr;
+
+ if (isa<UndefValue>(Elt))
+ continue;
+
+ if (!isa<ConstantInt>(Elt))
+ return nullptr;
+ }
+ return ConstantExpr::getNeg(CV);
+ }
+
return nullptr;
}
@@ -820,8 +839,29 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) {
return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
}
-Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
- PHINode *PN = cast<PHINode>(I.getOperand(0));
+static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV,
+ InstCombiner *IC) {
+ bool ConstIsRHS = isa<Constant>(I->getOperand(1));
+ Constant *C = cast<Constant>(I->getOperand(ConstIsRHS));
+
+ if (auto *InC = dyn_cast<Constant>(InV)) {
+ if (ConstIsRHS)
+ return ConstantExpr::get(I->getOpcode(), InC, C);
+ return ConstantExpr::get(I->getOpcode(), C, InC);
+ }
+
+ Value *Op0 = InV, *Op1 = C;
+ if (!ConstIsRHS)
+ std::swap(Op0, Op1);
+
+ Value *RI = IC->Builder->CreateBinOp(I->getOpcode(), Op0, Op1, "phitmp");
+ auto *FPInst = dyn_cast<Instruction>(RI);
+ if (FPInst && isa<FPMathOperator>(FPInst))
+ FPInst->copyFastMathFlags(I);
+ return RI;
+}
+
+Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
unsigned NumPHIValues = PN->getNumIncomingValues();
if (NumPHIValues == 0)
return nullptr;
@@ -902,7 +942,11 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
// Beware of ConstantExpr: it may eventually evaluate to getNullValue,
// even if currently isNullValue gives false.
Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i));
- if (InC && !isa<ConstantExpr>(InC))
+      // For vector constants, we cannot use isNullValue to choose between
+      // FalseVInPred and TrueVInPred: a vector with a mix of zero and nonzero
+      // elements is not a null value, so it would be folded entirely to
+      // `TrueVInPred`, which is wrong for the zero elements.
+ if (InC && !isa<ConstantExpr>(InC) && isa<ConstantInt>(InC))
InV = InC->isNullValue() ? FalseVInPred : TrueVInPred;
else
InV = Builder->CreateSelect(PN->getIncomingValue(i),
@@ -923,15 +967,9 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
C, "phitmp");
NewPN->addIncoming(InV, PN->getIncomingBlock(i));
}
- } else if (I.getNumOperands() == 2) {
- Constant *C = cast<Constant>(I.getOperand(1));
+ } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
for (unsigned i = 0; i != NumPHIValues; ++i) {
- Value *InV = nullptr;
- if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)))
- InV = ConstantExpr::get(I.getOpcode(), InC, C);
- else
- InV = Builder->CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
- PN->getIncomingValue(i), C, "phitmp");
+ Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), this);
NewPN->addIncoming(InV, PN->getIncomingBlock(i));
}
} else {
@@ -957,14 +995,14 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
return replaceInstUsesWith(I, NewPN);
}
-Instruction *InstCombiner::foldOpWithConstantIntoOperand(Instruction &I) {
+Instruction *InstCombiner::foldOpWithConstantIntoOperand(BinaryOperator &I) {
assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type");
if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {
if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))
return NewSel;
- } else if (isa<PHINode>(I.getOperand(0))) {
- if (Instruction *NewPhi = FoldOpIntoPhi(I))
+ } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
+ if (Instruction *NewPhi = foldOpIntoPhi(I, PN))
return NewPhi;
}
return nullptr;
@@ -1315,22 +1353,19 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) {
assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth);
assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth);
- // If both arguments of binary operation are shuffles, which use the same
- // mask and shuffle within a single vector, it is worthwhile to move the
- // shuffle after binary operation:
+ // If both arguments of the binary operation are shuffles that use the same
+ // mask and shuffle within a single vector, move the shuffle after the binop:
// Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m)
- if (isa<ShuffleVectorInst>(LHS) && isa<ShuffleVectorInst>(RHS)) {
- ShuffleVectorInst *LShuf = cast<ShuffleVectorInst>(LHS);
- ShuffleVectorInst *RShuf = cast<ShuffleVectorInst>(RHS);
- if (isa<UndefValue>(LShuf->getOperand(1)) &&
- isa<UndefValue>(RShuf->getOperand(1)) &&
- LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType() &&
- LShuf->getMask() == RShuf->getMask()) {
- Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0),
- RShuf->getOperand(0), Builder);
- return Builder->CreateShuffleVector(NewBO,
- UndefValue::get(NewBO->getType()), LShuf->getMask());
- }
+ auto *LShuf = dyn_cast<ShuffleVectorInst>(LHS);
+ auto *RShuf = dyn_cast<ShuffleVectorInst>(RHS);
+ if (LShuf && RShuf && LShuf->getMask() == RShuf->getMask() &&
+ isa<UndefValue>(LShuf->getOperand(1)) &&
+ isa<UndefValue>(RShuf->getOperand(1)) &&
+ LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType()) {
+ Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0),
+ RShuf->getOperand(0), Builder);
+ return Builder->CreateShuffleVector(
+ NewBO, UndefValue::get(NewBO->getType()), LShuf->getMask());
}
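A standalone C++ sanity check (illustrative values, addition standing in for an arbitrary binop; not part of the patch) of the rewrite above, Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m):

#include <array>
#include <cassert>

int main() {
  const std::array<int, 4> V1 = {1, 2, 3, 4};
  const std::array<int, 4> V2 = {10, 20, 30, 40};
  const std::array<unsigned, 4> Mask = {3, 0, 0, 2}; // shuffle of one vector

  // Op(v1, v2) computed once, then shuffled.
  std::array<int, 4> NewBO;
  for (unsigned I = 0; I != 4; ++I)
    NewBO[I] = V1[I] + V2[I];

  for (unsigned I = 0; I != 4; ++I) {
    int Before = V1[Mask[I]] + V2[Mask[I]]; // Op(shuffle(v1,m), shuffle(v2,m))
    int After = NewBO[Mask[I]];             // shuffle(Op(v1,v2), m)
    assert(Before == After && "binop commutes with a shared shuffle mask");
  }
  return 0;
}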
// If one argument is a shuffle within one vector, the other is a constant,
@@ -1559,27 +1594,21 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// Replace: gep (gep %P, long B), long A, ...
// With: T = long A+B; gep %P, T, ...
//
- Value *Sum;
Value *SO1 = Src->getOperand(Src->getNumOperands()-1);
Value *GO1 = GEP.getOperand(1);
- if (SO1 == Constant::getNullValue(SO1->getType())) {
- Sum = GO1;
- } else if (GO1 == Constant::getNullValue(GO1->getType())) {
- Sum = SO1;
- } else {
- // If they aren't the same type, then the input hasn't been processed
- // by the loop above yet (which canonicalizes sequential index types to
- // intptr_t). Just avoid transforming this until the input has been
- // normalized.
- if (SO1->getType() != GO1->getType())
- return nullptr;
- // Only do the combine when GO1 and SO1 are both constants. Only in
- // this case, we are sure the cost after the merge is never more than
- // that before the merge.
- if (!isa<Constant>(GO1) || !isa<Constant>(SO1))
- return nullptr;
- Sum = Builder->CreateAdd(SO1, GO1, PtrOp->getName()+".sum");
- }
+
+ // If they aren't the same type, then the input hasn't been processed
+ // by the loop above yet (which canonicalizes sequential index types to
+ // intptr_t). Just avoid transforming this until the input has been
+ // normalized.
+ if (SO1->getType() != GO1->getType())
+ return nullptr;
+
+  Value *Sum = SimplifyAddInst(GO1, SO1, false, false, DL, &TLI, &DT, &AC);
+ // Only do the combine when we are sure the cost after the
+ // merge is never more than that before the merge.
+ if (Sum == nullptr)
+ return nullptr;
// Update the GEP in place if possible.
if (Src->getNumOperands() == 2) {
@@ -1654,14 +1683,14 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
}
}
- // Handle gep(bitcast x) and gep(gep x, 0, 0, 0).
- Value *StrippedPtr = PtrOp->stripPointerCasts();
- PointerType *StrippedPtrTy = dyn_cast<PointerType>(StrippedPtr->getType());
-
// We do not handle pointer-vector geps here.
- if (!StrippedPtrTy)
+ if (GEP.getType()->isVectorTy())
return nullptr;
+ // Handle gep(bitcast x) and gep(gep x, 0, 0, 0).
+ Value *StrippedPtr = PtrOp->stripPointerCasts();
+ PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType());
+
if (StrippedPtr != PtrOp) {
bool HasZeroPointerIndex = false;
if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1)))
@@ -2239,11 +2268,11 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
ConstantInt *AddRHS;
if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) {
// Change 'switch (X+4) case 1:' into 'switch (X) case -3'.
- for (SwitchInst::CaseIt CaseIter : SI.cases()) {
- Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS);
+ for (auto Case : SI.cases()) {
+ Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS);
assert(isa<ConstantInt>(NewCase) &&
"Result of expression should be constant");
- CaseIter.setValue(cast<ConstantInt>(NewCase));
+ Case.setValue(cast<ConstantInt>(NewCase));
}
SI.setCondition(Op0);
return &SI;
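A standalone C++ check (example case values only, not part of the patch) that the rewrite above, 'switch (X+4) case 1:' into 'switch (X) case -3:', selects the same case for every input:

#include <cassert>

static int beforeRewrite(int X) {
  switch (X + 4) {
  case 1: return 100;
  case 7: return 200;
  default: return 0;
  }
}

static int afterRewrite(int X) {
  // Each case value C was replaced by C - 4 (ConstantExpr::getSub in the
  // hunk) and the add was dropped from the condition.
  switch (X) {
  case -3: return 100;
  case 3: return 200;
  default: return 0;
  }
}

int main() {
  for (int X = -100; X <= 100; ++X)
    assert(beforeRewrite(X) == afterRewrite(X));
  return 0;
}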
@@ -2275,9 +2304,9 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc");
SI.setCondition(NewCond);
- for (SwitchInst::CaseIt CaseIter : SI.cases()) {
- APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth);
- CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
+ for (auto Case : SI.cases()) {
+ APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth);
+ Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
}
return &SI;
}
@@ -2934,8 +2963,8 @@ bool InstCombiner::run() {
Result->takeName(I);
// Push the new instruction and any users onto the worklist.
- Worklist.Add(Result);
Worklist.AddUsersToWorkList(*Result);
+ Worklist.Add(Result);
// Insert the new instruction into the basic block...
BasicBlock *InstParent = I->getParent();
@@ -2958,8 +2987,8 @@ bool InstCombiner::run() {
if (isInstructionTriviallyDead(I, &TLI)) {
eraseInstFromFunction(*I);
} else {
- Worklist.Add(I);
Worklist.AddUsersToWorkList(*I);
+ Worklist.Add(I);
}
}
MadeIRChange = true;
@@ -3022,12 +3051,11 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
}
// See if we can constant fold its operands.
- for (User::op_iterator i = Inst->op_begin(), e = Inst->op_end(); i != e;
- ++i) {
- if (!isa<ConstantVector>(i) && !isa<ConstantExpr>(i))
+ for (Use &U : Inst->operands()) {
+ if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U))
continue;
- auto *C = cast<Constant>(i);
+ auto *C = cast<Constant>(U);
Constant *&FoldRes = FoldedConstants[C];
if (!FoldRes)
FoldRes = ConstantFoldConstant(C, DL, TLI);
@@ -3035,7 +3063,10 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
FoldRes = C;
if (FoldRes != C) {
- *i = FoldRes;
+ DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst
+ << "\n Old = " << *C
+ << "\n New = " << *FoldRes << '\n');
+ U = FoldRes;
MadeIRChange = true;
}
}
@@ -3055,17 +3086,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
}
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) {
- // See if this is an explicit destination.
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
- i != e; ++i)
- if (i.getCaseValue() == Cond) {
- BasicBlock *ReachableBB = i.getCaseSuccessor();
- Worklist.push_back(ReachableBB);
- continue;
- }
-
- // Otherwise it is the default destination.
- Worklist.push_back(SI->getDefaultDest());
+ Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor());
continue;
}
}
@@ -3152,6 +3173,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist,
InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines,
AA, AC, TLI, DT, DL, LI);
+ IC.MaxArraySizeForCombine = MaxArraySize;
Changed |= IC.run();
if (!Changed)
@@ -3176,9 +3198,10 @@ PreservedAnalyses InstCombinePass::run(Function &F,
return PreservedAnalyses::all();
// Mark all the analyses that instcombine updates as preserved.
- // FIXME: This should also 'preserve the CFG'.
PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
+ PA.preserve<AAManager>();
+ PA.preserve<GlobalsAA>();
return PA;
}
diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp
index f5e9e7dd5a93e..94cfc69ed5551 100644
--- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -80,6 +80,7 @@ static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37;
static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36;
static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30;
static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46;
+static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40;
static const uint64_t kWindowsShadowOffset32 = 3ULL << 28;
// The shadow memory space is dynamically allocated.
static const uint64_t kWindowsShadowOffset64 = kDynamicShadowSentinel;
@@ -380,6 +381,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,
bool IsAndroid = TargetTriple.isAndroid();
bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS();
bool IsFreeBSD = TargetTriple.isOSFreeBSD();
+ bool IsPS4CPU = TargetTriple.isPS4CPU();
bool IsLinux = TargetTriple.isOSLinux();
bool IsPPC64 = TargetTriple.getArch() == llvm::Triple::ppc64 ||
TargetTriple.getArch() == llvm::Triple::ppc64le;
@@ -392,6 +394,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,
TargetTriple.getArch() == llvm::Triple::mips64el;
bool IsAArch64 = TargetTriple.getArch() == llvm::Triple::aarch64;
bool IsWindows = TargetTriple.isOSWindows();
+ bool IsFuchsia = TargetTriple.isOSFuchsia();
ShadowMapping Mapping;
@@ -412,12 +415,18 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,
else
Mapping.Offset = kDefaultShadowOffset32;
} else { // LongSize == 64
- if (IsPPC64)
+ // Fuchsia is always PIE, which means that the beginning of the address
+ // space is always available.
+ if (IsFuchsia)
+ Mapping.Offset = 0;
+ else if (IsPPC64)
Mapping.Offset = kPPC64_ShadowOffset64;
else if (IsSystemZ)
Mapping.Offset = kSystemZ_ShadowOffset64;
else if (IsFreeBSD)
Mapping.Offset = kFreeBSD_ShadowOffset64;
+ else if (IsPS4CPU)
+ Mapping.Offset = kPS4CPU_ShadowOffset64;
else if (IsLinux && IsX86_64) {
if (IsKasan)
Mapping.Offset = kLinuxKasan_ShadowOffset64;
@@ -456,9 +465,9 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize,
  // offset is not necessarily 1/8-th of the address space. On SystemZ,
// we could OR the constant in a single instruction, but it's more
// efficient to load it once and use indexed addressing.
- Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ
- && !(Mapping.Offset & (Mapping.Offset - 1))
- && Mapping.Offset != kDynamicShadowSentinel;
+ Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !IsPS4CPU &&
+ !(Mapping.Offset & (Mapping.Offset - 1)) &&
+ Mapping.Offset != kDynamicShadowSentinel;
return Mapping;
}
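For context on the offsets added above, ASan computes a shadow address as (Mem >> Scale) + Offset, with Scale normally 3 (one shadow byte per 8 application bytes). A standalone C++ sketch using the constants from this hunk and an arbitrary example address:

#include <cstdint>
#include <cstdio>

// Classic ASan mapping: Shadow = (Mem >> Scale) + Offset.
static uint64_t shadowFor(uint64_t Addr, uint64_t Offset, unsigned Scale = 3) {
  return (Addr >> Scale) + Offset;
}

int main() {
  const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; // added by this patch
  const uint64_t kFuchsiaOffset = 0; // PIE-only target, low memory is usable

  const uint64_t Addr = 0x7fff12345678ULL; // arbitrary application address
  printf("PS4 shadow:     0x%llx\n",
         (unsigned long long)shadowFor(Addr, kPS4CPU_ShadowOffset64));
  printf("Fuchsia shadow: 0x%llx\n",
         (unsigned long long)shadowFor(Addr, kFuchsiaOffset));
  return 0;
}

Note that the hunk above also adds !IsPS4CPU to the OrShadowOffset condition, so on PS4 the offset is always added rather than OR-ed in.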
@@ -567,8 +576,6 @@ struct AddressSanitizer : public FunctionPass {
Type *IntptrTy;
ShadowMapping Mapping;
DominatorTree *DT;
- Function *AsanCtorFunction = nullptr;
- Function *AsanInitFunction = nullptr;
Function *AsanHandleNoReturnFunc;
Function *AsanPtrCmpFunction, *AsanPtrSubFunction;
// This array is indexed by AccessIsWrite, Experiment and log2(AccessSize).
@@ -1561,31 +1568,31 @@ void AddressSanitizerModule::initializeCallbacks(Module &M) {
// Declare our poisoning and unpoisoning functions.
AsanPoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr));
+ kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy));
AsanPoisonGlobals->setLinkage(Function::ExternalLinkage);
AsanUnpoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanUnpoisonGlobalsName, IRB.getVoidTy(), nullptr));
+ kAsanUnpoisonGlobalsName, IRB.getVoidTy()));
AsanUnpoisonGlobals->setLinkage(Function::ExternalLinkage);
// Declare functions that register/unregister globals.
AsanRegisterGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr));
+ kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy));
AsanRegisterGlobals->setLinkage(Function::ExternalLinkage);
AsanUnregisterGlobals = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(kAsanUnregisterGlobalsName, IRB.getVoidTy(),
- IntptrTy, IntptrTy, nullptr));
+ IntptrTy, IntptrTy));
AsanUnregisterGlobals->setLinkage(Function::ExternalLinkage);
// Declare the functions that find globals in a shared object and then invoke
// the (un)register function on them.
AsanRegisterImageGlobals =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr));
+ kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy));
AsanRegisterImageGlobals->setLinkage(Function::ExternalLinkage);
AsanUnregisterImageGlobals =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr));
+ kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy));
AsanUnregisterImageGlobals->setLinkage(Function::ExternalLinkage);
}
@@ -1618,11 +1625,12 @@ void AddressSanitizerModule::SetComdatForGlobalMetadata(
GlobalVariable *
AddressSanitizerModule::CreateMetadataGlobal(Module &M, Constant *Initializer,
StringRef OriginalName) {
- GlobalVariable *Metadata =
- new GlobalVariable(M, Initializer->getType(), false,
- GlobalVariable::InternalLinkage, Initializer,
- Twine("__asan_global_") +
- GlobalValue::getRealLinkageName(OriginalName));
+ auto Linkage = TargetTriple.isOSBinFormatMachO()
+ ? GlobalVariable::InternalLinkage
+ : GlobalVariable::PrivateLinkage;
+ GlobalVariable *Metadata = new GlobalVariable(
+ M, Initializer->getType(), false, Linkage, Initializer,
+ Twine("__asan_global_") + GlobalValue::getRealLinkageName(OriginalName));
Metadata->setSection(getGlobalMetadataSection());
return Metadata;
}
@@ -1862,7 +1870,8 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) {
GlobalValue *InstrumentedGlobal = NewGlobal;
bool CanUsePrivateAliases =
- TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO();
+ TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO() ||
+ TargetTriple.isOSBinFormatWasm();
if (CanUsePrivateAliases && ClUsePrivateAliasForGlobals) {
// Create local alias for NewGlobal to avoid crash on ODR between
// instrumented and non-instrumented libraries.
@@ -1926,13 +1935,19 @@ bool AddressSanitizerModule::runOnModule(Module &M) {
Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel);
initializeCallbacks(M);
- bool Changed = false;
+ if (CompileKernel)
+ return false;
+
+ Function *AsanCtorFunction;
+ std::tie(AsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions(
+ M, kAsanModuleCtorName, kAsanInitName, /*InitArgTypes=*/{},
+ /*InitArgs=*/{}, kAsanVersionCheckName);
+ appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority);
+ bool Changed = false;
// TODO(glider): temporarily disabled globals instrumentation for KASan.
- if (ClGlobals && !CompileKernel) {
- Function *CtorFunc = M.getFunction(kAsanModuleCtorName);
- assert(CtorFunc);
- IRBuilder<> IRB(CtorFunc->getEntryBlock().getTerminator());
+ if (ClGlobals) {
+ IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator());
Changed |= InstrumentGlobals(IRB, M);
}
@@ -1949,49 +1964,60 @@ void AddressSanitizer::initializeCallbacks(Module &M) {
const std::string ExpStr = Exp ? "exp_" : "";
const std::string SuffixStr = CompileKernel ? "N" : "_n";
const std::string EndingStr = Recover ? "_noabort" : "";
- Type *ExpType = Exp ? Type::getInt32Ty(*C) : nullptr;
- AsanErrorCallbackSized[AccessIsWrite][Exp] =
- checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + EndingStr,
- IRB.getVoidTy(), IntptrTy, IntptrTy, ExpType, nullptr));
- AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] =
- checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr,
- IRB.getVoidTy(), IntptrTy, IntptrTy, ExpType, nullptr));
- for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes;
- AccessSizeIndex++) {
- const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex);
- AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] =
- checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr,
- IRB.getVoidTy(), IntptrTy, ExpType, nullptr));
- AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] =
- checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr,
- IRB.getVoidTy(), IntptrTy, ExpType, nullptr));
+
+ SmallVector<Type *, 3> Args2 = {IntptrTy, IntptrTy};
+ SmallVector<Type *, 2> Args1{1, IntptrTy};
+ if (Exp) {
+ Type *ExpType = Type::getInt32Ty(*C);
+ Args2.push_back(ExpType);
+ Args1.push_back(ExpType);
}
- }
+ AsanErrorCallbackSized[AccessIsWrite][Exp] =
+ checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+ kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr +
+ EndingStr,
+ FunctionType::get(IRB.getVoidTy(), Args2, false)));
+
+ AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] =
+ checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+ ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr,
+ FunctionType::get(IRB.getVoidTy(), Args2, false)));
+
+ for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes;
+ AccessSizeIndex++) {
+ const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex);
+ AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] =
+ checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+ kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr,
+ FunctionType::get(IRB.getVoidTy(), Args1, false)));
+
+ AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] =
+ checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+ ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr,
+ FunctionType::get(IRB.getVoidTy(), Args1, false)));
+ }
+ }
}
const std::string MemIntrinCallbackPrefix =
CompileKernel ? std::string("") : ClMemoryAccessCallbackPrefix;
AsanMemmove = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy));
AsanMemcpy = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
MemIntrinCallbackPrefix + "memcpy", IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy));
AsanMemset = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
MemIntrinCallbackPrefix + "memset", IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy));
AsanHandleNoReturnFunc = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy(), nullptr));
+ M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy()));
AsanPtrCmpFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr));
+ kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy));
AsanPtrSubFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr));
+ kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy));
// We insert an empty inline asm after __asan_report* to avoid callback merge.
EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false),
StringRef(""), StringRef(""),
@@ -2001,7 +2027,6 @@ void AddressSanitizer::initializeCallbacks(Module &M) {
// virtual
bool AddressSanitizer::doInitialization(Module &M) {
// Initialize the private fields. No one has accessed them before.
-
GlobalsMD.init(M);
C = &(M.getContext());
@@ -2009,13 +2034,6 @@ bool AddressSanitizer::doInitialization(Module &M) {
IntptrTy = Type::getIntNTy(*C, LongSize);
TargetTriple = Triple(M.getTargetTriple());
- if (!CompileKernel) {
- std::tie(AsanCtorFunction, AsanInitFunction) =
- createSanitizerCtorAndInitFunctions(
- M, kAsanModuleCtorName, kAsanInitName,
- /*InitArgTypes=*/{}, /*InitArgs=*/{}, kAsanVersionCheckName);
- appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority);
- }
Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel);
return true;
}
@@ -2034,6 +2052,8 @@ bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) {
// We cannot just ignore these methods, because they may call other
// instrumented functions.
if (F.getName().find(" load]") != std::string::npos) {
+ Function *AsanInitFunction =
+ declareSanitizerInitFunction(*F.getParent(), kAsanInitName, {});
IRBuilder<> IRB(&F.front(), F.front().begin());
IRB.CreateCall(AsanInitFunction, {});
return true;
@@ -2081,7 +2101,6 @@ void AddressSanitizer::markEscapedLocalAllocas(Function &F) {
}
bool AddressSanitizer::runOnFunction(Function &F) {
- if (&F == AsanCtorFunction) return false;
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false;
if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false;
if (F.getName().startswith("__asan_")) return false;
@@ -2175,8 +2194,9 @@ bool AddressSanitizer::runOnFunction(Function &F) {
(ClInstrumentationWithCallsThreshold >= 0 &&
ToInstrument.size() > (unsigned)ClInstrumentationWithCallsThreshold);
const DataLayout &DL = F.getParent()->getDataLayout();
- ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(),
- /*RoundToAlign=*/true);
+ ObjectSizeOpts ObjSizeOpts;
+ ObjSizeOpts.RoundToAlign = true;
+ ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(), ObjSizeOpts);
// Instrument.
int NumInstrumented = 0;
@@ -2234,18 +2254,18 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) {
std::string Suffix = itostr(i);
AsanStackMallocFunc[i] = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(kAsanStackMallocNameTemplate + Suffix, IntptrTy,
- IntptrTy, nullptr));
+ IntptrTy));
AsanStackFreeFunc[i] = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(kAsanStackFreeNameTemplate + Suffix,
- IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr));
+ IRB.getVoidTy(), IntptrTy, IntptrTy));
}
if (ASan.UseAfterScope) {
AsanPoisonStackMemoryFunc = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(kAsanPoisonStackMemoryName, IRB.getVoidTy(),
- IntptrTy, IntptrTy, nullptr));
+ IntptrTy, IntptrTy));
AsanUnpoisonStackMemoryFunc = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(kAsanUnpoisonStackMemoryName, IRB.getVoidTy(),
- IntptrTy, IntptrTy, nullptr));
+ IntptrTy, IntptrTy));
}
for (size_t Val : {0x00, 0xf1, 0xf2, 0xf3, 0xf5, 0xf8}) {
@@ -2254,14 +2274,14 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) {
Name << std::setw(2) << std::setfill('0') << std::hex << Val;
AsanSetShadowFunc[Val] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr));
+ Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy));
}
AsanAllocaPoisonFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr));
+ kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy));
AsanAllocasUnpoisonFunc =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr));
+ kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy));
}
void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask,
diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index b34d5b8c45a71..4e454f0c95b65 100644
--- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -254,7 +254,7 @@ class DataFlowSanitizer : public ModulePass {
MDNode *ColdCallWeights;
DFSanABIList ABIList;
DenseMap<Value *, Function *> UnwrappedFnMap;
- AttributeSet ReadOnlyNoneAttrs;
+ AttributeList ReadOnlyNoneAttrs;
bool DFSanRuntimeShadowMask;
Value *getShadowAddress(Value *Addr, Instruction *Pos);
@@ -331,6 +331,10 @@ class DFSanVisitor : public InstVisitor<DFSanVisitor> {
DFSanFunction &DFSF;
DFSanVisitor(DFSanFunction &DFSF) : DFSF(DFSF) {}
+ const DataLayout &getDataLayout() const {
+ return DFSF.F->getParent()->getDataLayout();
+ }
+
void visitOperandShadowInst(Instruction &I);
void visitBinaryOperator(BinaryOperator &BO);
@@ -539,16 +543,17 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName,
F->getParent());
NewF->copyAttributesFrom(F);
NewF->removeAttributes(
- AttributeSet::ReturnIndex,
- AttributeSet::get(F->getContext(), AttributeSet::ReturnIndex,
- AttributeFuncs::typeIncompatible(NewFT->getReturnType())));
+ AttributeList::ReturnIndex,
+ AttributeList::get(
+ F->getContext(), AttributeList::ReturnIndex,
+ AttributeFuncs::typeIncompatible(NewFT->getReturnType())));
BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", NewF);
if (F->isVarArg()) {
NewF->removeAttributes(
- AttributeSet::FunctionIndex,
- AttributeSet().addAttribute(*Ctx, AttributeSet::FunctionIndex,
- "split-stack"));
+ AttributeList::FunctionIndex,
+ AttributeList().addAttribute(*Ctx, AttributeList::FunctionIndex,
+ "split-stack"));
CallInst::Create(DFSanVarargWrapperFn,
IRBuilder<>(BB).CreateGlobalStringPtr(F->getName()), "",
BB);
@@ -580,8 +585,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT,
Function::arg_iterator AI = F->arg_begin(); ++AI;
for (unsigned N = FT->getNumParams(); N != 0; ++AI, --N)
Args.push_back(&*AI);
- CallInst *CI =
- CallInst::Create(&F->getArgumentList().front(), Args, "", BB);
+ CallInst *CI = CallInst::Create(&*F->arg_begin(), Args, "", BB);
ReturnInst *RI;
if (FT->getReturnType()->isVoidTy())
RI = ReturnInst::Create(*Ctx, BB);
@@ -595,7 +599,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT,
DFSanVisitor(DFSF).visitCallInst(*CI);
if (!FT->getReturnType()->isVoidTy())
new StoreInst(DFSF.getShadow(RI->getReturnValue()),
- &F->getArgumentList().back(), RI);
+ &*std::prev(F->arg_end()), RI);
}
return C;
@@ -622,26 +626,26 @@ bool DataFlowSanitizer::runOnModule(Module &M) {
DFSanUnionFn = Mod->getOrInsertFunction("__dfsan_union", DFSanUnionFnTy);
if (Function *F = dyn_cast<Function>(DFSanUnionFn)) {
- F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind);
- F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
- F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
+ F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind);
+ F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
+ F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
F->addAttribute(1, Attribute::ZExt);
F->addAttribute(2, Attribute::ZExt);
}
DFSanCheckedUnionFn = Mod->getOrInsertFunction("dfsan_union", DFSanUnionFnTy);
if (Function *F = dyn_cast<Function>(DFSanCheckedUnionFn)) {
- F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind);
- F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
- F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
+ F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind);
+ F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
+ F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
F->addAttribute(1, Attribute::ZExt);
F->addAttribute(2, Attribute::ZExt);
}
DFSanUnionLoadFn =
Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy);
if (Function *F = dyn_cast<Function>(DFSanUnionLoadFn)) {
- F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind);
- F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadOnly);
- F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
+ F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind);
+ F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
+ F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
}
DFSanUnimplementedFn =
Mod->getOrInsertFunction("__dfsan_unimplemented", DFSanUnimplementedFnTy);
@@ -696,7 +700,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) {
AttrBuilder B;
B.addAttribute(Attribute::ReadOnly).addAttribute(Attribute::ReadNone);
- ReadOnlyNoneAttrs = AttributeSet::get(*Ctx, AttributeSet::FunctionIndex, B);
+ ReadOnlyNoneAttrs = AttributeList::get(*Ctx, AttributeList::FunctionIndex, B);
// First, change the ABI of every function in the module. ABI-listed
// functions keep their original ABI and get a wrapper function.
@@ -717,9 +721,10 @@ bool DataFlowSanitizer::runOnModule(Module &M) {
Function *NewF = Function::Create(NewFT, F.getLinkage(), "", &M);
NewF->copyAttributesFrom(&F);
NewF->removeAttributes(
- AttributeSet::ReturnIndex,
- AttributeSet::get(NewF->getContext(), AttributeSet::ReturnIndex,
- AttributeFuncs::typeIncompatible(NewFT->getReturnType())));
+ AttributeList::ReturnIndex,
+ AttributeList::get(
+ NewF->getContext(), AttributeList::ReturnIndex,
+ AttributeFuncs::typeIncompatible(NewFT->getReturnType())));
for (Function::arg_iterator FArg = F.arg_begin(),
NewFArg = NewF->arg_begin(),
FArgEnd = F.arg_end();
@@ -758,7 +763,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) {
&F, std::string("dfsw$") + std::string(F.getName()),
GlobalValue::LinkOnceODRLinkage, NewFT);
if (getInstrumentedABI() == IA_TLS)
- NewF->removeAttributes(AttributeSet::FunctionIndex, ReadOnlyNoneAttrs);
+ NewF->removeAttributes(AttributeList::FunctionIndex, ReadOnlyNoneAttrs);
Value *WrappedFnCst =
ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT));
@@ -906,7 +911,7 @@ Value *DFSanFunction::getShadow(Value *V) {
break;
}
case DataFlowSanitizer::IA_Args: {
- unsigned ArgIdx = A->getArgNo() + F->getArgumentList().size() / 2;
+ unsigned ArgIdx = A->getArgNo() + F->arg_size() / 2;
Function::arg_iterator i = F->arg_begin();
while (ArgIdx--)
++i;
@@ -983,7 +988,7 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) {
IRBuilder<> IRB(Pos);
if (AvoidNewBlocks) {
CallInst *Call = IRB.CreateCall(DFS.DFSanCheckedUnionFn, {V1, V2});
- Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
+ Call->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
Call->addAttribute(1, Attribute::ZExt);
Call->addAttribute(2, Attribute::ZExt);
@@ -996,7 +1001,7 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) {
Ne, Pos, /*Unreachable=*/false, DFS.ColdCallWeights, &DT));
IRBuilder<> ThenIRB(BI);
CallInst *Call = ThenIRB.CreateCall(DFS.DFSanUnionFn, {V1, V2});
- Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
+ Call->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
Call->addAttribute(1, Attribute::ZExt);
Call->addAttribute(2, Attribute::ZExt);
@@ -1099,7 +1104,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align,
CallInst *FallbackCall = FallbackIRB.CreateCall(
DFS.DFSanUnionLoadFn,
{ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)});
- FallbackCall->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
+ FallbackCall->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
// Compare each of the shadows stored in the loaded 64 bits to each other,
// by computing (WideShadow rotl ShadowWidth) == WideShadow.
@@ -1156,7 +1161,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align,
IRBuilder<> IRB(Pos);
CallInst *FallbackCall = IRB.CreateCall(
DFS.DFSanUnionLoadFn, {ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)});
- FallbackCall->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
+ FallbackCall->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt);
return FallbackCall;
}
@@ -1446,7 +1451,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) {
// Custom functions returning non-void will write to the return label.
if (!FT->getReturnType()->isVoidTy()) {
- CustomFn->removeAttributes(AttributeSet::FunctionIndex,
+ CustomFn->removeAttributes(AttributeList::FunctionIndex,
DFSF.DFS.ReadOnlyNoneAttrs);
}
}
@@ -1481,7 +1486,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) {
auto *LabelVATy = ArrayType::get(DFSF.DFS.ShadowTy,
CS.arg_size() - FT->getNumParams());
auto *LabelVAAlloca = new AllocaInst(
- LabelVATy, "labelva", &DFSF.F->getEntryBlock().front());
+ LabelVATy, getDataLayout().getAllocaAddrSpace(),
+ "labelva", &DFSF.F->getEntryBlock().front());
for (unsigned n = 0; i != CS.arg_end(); ++i, ++n) {
auto LabelVAPtr = IRB.CreateStructGEP(LabelVATy, LabelVAAlloca, n);
@@ -1494,8 +1500,9 @@ void DFSanVisitor::visitCallSite(CallSite CS) {
if (!FT->getReturnType()->isVoidTy()) {
if (!DFSF.LabelReturnAlloca) {
DFSF.LabelReturnAlloca =
- new AllocaInst(DFSF.DFS.ShadowTy, "labelreturn",
- &DFSF.F->getEntryBlock().front());
+ new AllocaInst(DFSF.DFS.ShadowTy,
+ getDataLayout().getAllocaAddrSpace(),
+ "labelreturn", &DFSF.F->getEntryBlock().front());
}
Args.push_back(DFSF.LabelReturnAlloca);
}
@@ -1574,7 +1581,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) {
unsigned VarArgSize = CS.arg_size() - FT->getNumParams();
ArrayType *VarArgArrayTy = ArrayType::get(DFSF.DFS.ShadowTy, VarArgSize);
AllocaInst *VarArgShadow =
- new AllocaInst(VarArgArrayTy, "", &DFSF.F->getEntryBlock().front());
+ new AllocaInst(VarArgArrayTy, getDataLayout().getAllocaAddrSpace(),
+ "", &DFSF.F->getEntryBlock().front());
Args.push_back(IRB.CreateConstGEP2_32(VarArgArrayTy, VarArgShadow, 0, 0));
for (unsigned n = 0; i != e; ++i, ++n) {
IRB.CreateStore(
@@ -1593,7 +1601,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) {
}
NewCS.setCallingConv(CS.getCallingConv());
NewCS.setAttributes(CS.getAttributes().removeAttributes(
- *DFSF.DFS.Ctx, AttributeSet::ReturnIndex,
+ *DFSF.DFS.Ctx, AttributeList::ReturnIndex,
AttributeFuncs::typeIncompatible(NewCS.getInstruction()->getType())));
if (Next) {
diff --git a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp
index 05eba6c4dc699..7dea1dee756ac 100644
--- a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp
+++ b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp
@@ -267,35 +267,35 @@ void EfficiencySanitizer::initializeCallbacks(Module &M) {
SmallString<32> AlignedLoadName("__esan_aligned_load" + ByteSizeStr);
EsanAlignedLoad[Idx] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy()));
SmallString<32> AlignedStoreName("__esan_aligned_store" + ByteSizeStr);
EsanAlignedStore[Idx] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy()));
SmallString<32> UnalignedLoadName("__esan_unaligned_load" + ByteSizeStr);
EsanUnalignedLoad[Idx] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy()));
SmallString<32> UnalignedStoreName("__esan_unaligned_store" + ByteSizeStr);
EsanUnalignedStore[Idx] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy()));
}
EsanUnalignedLoadN = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("__esan_unaligned_loadN", IRB.getVoidTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IntptrTy));
EsanUnalignedStoreN = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("__esan_unaligned_storeN", IRB.getVoidTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IntptrTy));
MemmoveFn = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IntptrTy));
MemcpyFn = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IntptrTy));
MemsetFn = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt32Ty(), IntptrTy, nullptr));
+ IRB.getInt32Ty(), IntptrTy));
}
bool EfficiencySanitizer::shouldIgnoreStructType(StructType *StructTy) {
@@ -533,7 +533,7 @@ void EfficiencySanitizer::createDestructor(Module &M, Constant *ToolInfoArg) {
IRBuilder<> IRB_Dtor(EsanDtorFunction->getEntryBlock().getTerminator());
Function *EsanExit = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(EsanExitName, IRB_Dtor.getVoidTy(),
- Int8PtrTy, nullptr));
+ Int8PtrTy));
EsanExit->setLinkage(Function::ExternalLinkage);
IRB_Dtor.CreateCall(EsanExit, {ToolInfoArg});
appendToGlobalDtors(M, EsanDtorFunction, EsanCtorAndDtorPriority);
@@ -757,7 +757,7 @@ bool EfficiencySanitizer::instrumentGetElementPtr(Instruction *I, Module &M) {
return false;
}
Type *SourceTy = GepInst->getSourceElementType();
- StructType *StructTy;
+ StructType *StructTy = nullptr;
ConstantInt *Idx;
// Check if GEP calculates address from a struct array.
if (isa<StructType>(SourceTy)) {
diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 1ba13bdfe05ad..61d627673c907 100644
--- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -1,4 +1,4 @@
-//===-- IndirectCallPromotion.cpp - Promote indirect calls to direct calls ===//
+//===-- IndirectCallPromotion.cpp - Optimizations based on value profiling ===//
//
// The LLVM Compiler Infrastructure
//
@@ -17,6 +17,8 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/IndirectCallPromotionAnalysis.h"
#include "llvm/Analysis/IndirectCallSiteVisitor.h"
#include "llvm/IR/BasicBlock.h"
@@ -40,6 +42,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/PGOInstrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -53,6 +56,8 @@ using namespace llvm;
STATISTIC(NumOfPGOICallPromotion, "Number of indirect call promotions.");
STATISTIC(NumOfPGOICallsites, "Number of indirect call candidate sites.");
+STATISTIC(NumOfPGOMemOPOpt, "Number of memop intrinsics optimized.");
+STATISTIC(NumOfPGOMemOPAnnotate, "Number of memop intrinsics annotated.");
// Command line option to disable indirect-call promotion. The default is
// false. This is for debugging purposes.
@@ -80,6 +85,12 @@ static cl::opt<bool> ICPLTOMode("icp-lto", cl::init(false), cl::Hidden,
cl::desc("Run indirect-call promotion in LTO "
"mode"));
+// Set if the pass is called in SamplePGO mode. The difference for SamplePGO
+// mode is that it will add prof metadata to the created direct call.
+static cl::opt<bool>
+ ICPSamplePGOMode("icp-samplepgo", cl::init(false), cl::Hidden,
+ cl::desc("Run indirect-call promotion in SamplePGO mode"));
+
// If the option is set to true, only call instructions will be considered for
// transformation -- invoke instructions will be ignored.
static cl::opt<bool>
@@ -100,13 +111,51 @@ static cl::opt<bool>
ICPDUMPAFTER("icp-dumpafter", cl::init(false), cl::Hidden,
cl::desc("Dump IR after transformation happens"));
+// The minimum call count to optimize memory intrinsic calls.
+static cl::opt<unsigned>
+ MemOPCountThreshold("pgo-memop-count-threshold", cl::Hidden, cl::ZeroOrMore,
+ cl::init(1000),
+ cl::desc("The minimum count to optimize memory "
+ "intrinsic calls"));
+
+// Command line option to disable memory intrinsic optimization. The default is
+// false. This is for debugging purposes.
+static cl::opt<bool> DisableMemOPOPT("disable-memop-opt", cl::init(false),
+ cl::Hidden, cl::desc("Disable optimize"));
+
+// The percent threshold to optimize memory intrinsic calls.
+static cl::opt<unsigned>
+ MemOPPercentThreshold("pgo-memop-percent-threshold", cl::init(40),
+ cl::Hidden, cl::ZeroOrMore,
+ cl::desc("The percentage threshold for the "
+ "memory intrinsic calls optimization"));
+
+// Maximum number of versions for optimizing memory intrinsic call.
+static cl::opt<unsigned>
+ MemOPMaxVersion("pgo-memop-max-version", cl::init(3), cl::Hidden,
+ cl::ZeroOrMore,
+ cl::desc("The max version for the optimized memory "
+                    "intrinsic calls"));
+
+// Scale the counts from the annotation using the BB count value.
+static cl::opt<bool>
+ MemOPScaleCount("pgo-memop-scale-count", cl::init(true), cl::Hidden,
+ cl::desc("Scale the memop size counts using the basic "
+                             "block count value"));
+
+// This option sets the range of precise profile memop sizes.
+extern cl::opt<std::string> MemOPSizeRange;
+
+// This option sets the value that groups large memop sizes.
+extern cl::opt<unsigned> MemOPSizeLarge;
+
namespace {
class PGOIndirectCallPromotionLegacyPass : public ModulePass {
public:
static char ID;
- PGOIndirectCallPromotionLegacyPass(bool InLTO = false)
- : ModulePass(ID), InLTO(InLTO) {
+ PGOIndirectCallPromotionLegacyPass(bool InLTO = false, bool SamplePGO = false)
+ : ModulePass(ID), InLTO(InLTO), SamplePGO(SamplePGO) {
initializePGOIndirectCallPromotionLegacyPassPass(
*PassRegistry::getPassRegistry());
}
@@ -119,6 +168,28 @@ private:
// If this pass is called in LTO. We need to special handling the PGOFuncName
// for the static variables due to LTO's internalization.
bool InLTO;
+
+  // If this pass is called in SamplePGO, we need to add the prof metadata to
+ // the promoted direct call.
+ bool SamplePGO;
+};
+
+class PGOMemOPSizeOptLegacyPass : public FunctionPass {
+public:
+ static char ID;
+
+ PGOMemOPSizeOptLegacyPass() : FunctionPass(ID) {
+ initializePGOMemOPSizeOptLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ StringRef getPassName() const override { return "PGOMemOPSize"; }
+
+private:
+ bool runOnFunction(Function &F) override;
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<BlockFrequencyInfoWrapperPass>();
+ AU.addPreserved<GlobalsAAWrapperPass>();
+ }
};
} // end anonymous namespace
@@ -128,8 +199,22 @@ INITIALIZE_PASS(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom",
"direct calls.",
false, false)
-ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO) {
- return new PGOIndirectCallPromotionLegacyPass(InLTO);
+ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO,
+ bool SamplePGO) {
+ return new PGOIndirectCallPromotionLegacyPass(InLTO, SamplePGO);
+}
+
+char PGOMemOPSizeOptLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt",
+ "Optimize memory intrinsic using its size value profile",
+ false, false)
+INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
+INITIALIZE_PASS_END(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt",
+ "Optimize memory intrinsic using its size value profile",
+ false, false)
+
+FunctionPass *llvm::createPGOMemOPSizeOptLegacyPass() {
+ return new PGOMemOPSizeOptLegacyPass();
}
namespace {
@@ -144,17 +229,11 @@ private:
// defines.
InstrProfSymtab *Symtab;
- enum TargetStatus {
- OK, // Should be able to promote.
- NotAvailableInModule, // Cannot find the target in current module.
- ReturnTypeMismatch, // Return type mismatch b/w target and indirect-call.
- NumArgsMismatch, // Number of arguments does not match.
- ArgTypeMismatch // Type mismatch in the arguments (cannot bitcast).
- };
+ bool SamplePGO;
// Test if we can legally promote this direct-call of Target.
- TargetStatus isPromotionLegal(Instruction *Inst, uint64_t Target,
- Function *&F);
+ bool isPromotionLegal(Instruction *Inst, uint64_t Target, Function *&F,
+ const char **Reason = nullptr);
// A struct that records the direct target and it's call count.
struct PromotionCandidate {
@@ -172,91 +251,77 @@ private:
Instruction *Inst, const ArrayRef<InstrProfValueData> &ValueDataRef,
uint64_t TotalCount, uint32_t NumCandidates);
- // Main function that transforms Inst (either a indirect-call instruction, or
- // an invoke instruction , to a conditional call to F. This is like:
- // if (Inst.CalledValue == F)
- // F(...);
- // else
- // Inst(...);
- // end
- // TotalCount is the profile count value that the instruction executes.
- // Count is the profile count value that F is the target function.
- // These two values are being used to update the branch weight.
- void promote(Instruction *Inst, Function *F, uint64_t Count,
- uint64_t TotalCount);
-
// Promote a list of targets for one indirect-call callsite. Return
// the number of promotions.
uint32_t tryToPromote(Instruction *Inst,
const std::vector<PromotionCandidate> &Candidates,
uint64_t &TotalCount);
- static const char *StatusToString(const TargetStatus S) {
- switch (S) {
- case OK:
- return "OK to promote";
- case NotAvailableInModule:
- return "Cannot find the target";
- case ReturnTypeMismatch:
- return "Return type mismatch";
- case NumArgsMismatch:
- return "The number of arguments mismatch";
- case ArgTypeMismatch:
- return "Argument Type mismatch";
- }
- llvm_unreachable("Should not reach here");
- }
-
// Noncopyable
ICallPromotionFunc(const ICallPromotionFunc &other) = delete;
ICallPromotionFunc &operator=(const ICallPromotionFunc &other) = delete;
public:
- ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab)
- : F(Func), M(Modu), Symtab(Symtab) {
- }
+ ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab,
+ bool SamplePGO)
+ : F(Func), M(Modu), Symtab(Symtab), SamplePGO(SamplePGO) {}
bool processFunction();
};
} // end anonymous namespace
-ICallPromotionFunc::TargetStatus
-ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target,
- Function *&TargetFunction) {
- Function *DirectCallee = Symtab->getFunction(Target);
- if (DirectCallee == nullptr)
- return NotAvailableInModule;
+bool llvm::isLegalToPromote(Instruction *Inst, Function *F,
+ const char **Reason) {
// Check the return type.
Type *CallRetType = Inst->getType();
if (!CallRetType->isVoidTy()) {
- Type *FuncRetType = DirectCallee->getReturnType();
+ Type *FuncRetType = F->getReturnType();
if (FuncRetType != CallRetType &&
- !CastInst::isBitCastable(FuncRetType, CallRetType))
- return ReturnTypeMismatch;
+ !CastInst::isBitCastable(FuncRetType, CallRetType)) {
+ if (Reason)
+ *Reason = "Return type mismatch";
+ return false;
+ }
}
// Check if the arguments are compatible with the parameters
- FunctionType *DirectCalleeType = DirectCallee->getFunctionType();
+ FunctionType *DirectCalleeType = F->getFunctionType();
unsigned ParamNum = DirectCalleeType->getFunctionNumParams();
CallSite CS(Inst);
unsigned ArgNum = CS.arg_size();
- if (ParamNum != ArgNum && !DirectCalleeType->isVarArg())
- return NumArgsMismatch;
+ if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) {
+ if (Reason)
+ *Reason = "The number of arguments mismatch";
+ return false;
+ }
for (unsigned I = 0; I < ParamNum; ++I) {
Type *PTy = DirectCalleeType->getFunctionParamType(I);
Type *ATy = CS.getArgument(I)->getType();
if (PTy == ATy)
continue;
- if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy))
- return ArgTypeMismatch;
+ if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) {
+ if (Reason)
+ *Reason = "Argument type mismatch";
+ return false;
+ }
}
DEBUG(dbgs() << " #" << NumOfPGOICallPromotion << " Promote the icall to "
- << Symtab->getFuncName(Target) << "\n");
- TargetFunction = DirectCallee;
- return OK;
+ << F->getName() << "\n");
+ return true;
+}
+
+bool ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target,
+ Function *&TargetFunction,
+ const char **Reason) {
+ TargetFunction = Symtab->getFunction(Target);
+ if (TargetFunction == nullptr) {
+ *Reason = "Cannot find the target";
+ return false;
+ }
+ return isLegalToPromote(Inst, TargetFunction, Reason);
}
// Indirect-call promotion heuristic. The direct targets are sorted based on
@@ -296,10 +361,9 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite(
break;
}
Function *TargetFunction = nullptr;
- TargetStatus Status = isPromotionLegal(Inst, Target, TargetFunction);
- if (Status != OK) {
+ const char *Reason = nullptr;
+ if (!isPromotionLegal(Inst, Target, TargetFunction, &Reason)) {
StringRef TargetFuncName = Symtab->getFuncName(Target);
- const char *Reason = StatusToString(Status);
DEBUG(dbgs() << " Not promote: " << Reason << "\n");
emitOptimizationRemarkMissed(
F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(),
@@ -532,8 +596,14 @@ static void insertCallRetPHI(Instruction *Inst, Instruction *CallResult,
// Ret = phi(Ret1, Ret2);
// It adds type casts for the args do not match the parameters and the return
// value. Branch weights metadata also updated.
-void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee,
- uint64_t Count, uint64_t TotalCount) {
+// If \p AttachProfToDirectCall is true, prof metadata is attached to the
+// new direct call to contain \p Count. This is used by the SamplePGO inliner
+// to check callsite hotness.
+// Returns the promoted direct call instruction.
+Instruction *llvm::promoteIndirectCall(Instruction *Inst,
+ Function *DirectCallee, uint64_t Count,
+ uint64_t TotalCount,
+ bool AttachProfToDirectCall) {
assert(DirectCallee != nullptr);
BasicBlock *BB = Inst->getParent();
// Just to suppress the non-debug build warning.
@@ -548,6 +618,14 @@ void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee,
Instruction *NewInst =
createDirectCallInst(Inst, DirectCallee, DirectCallBB, MergeBB);
+ if (AttachProfToDirectCall) {
+ SmallVector<uint32_t, 1> Weights;
+ Weights.push_back(Count);
+ MDBuilder MDB(NewInst->getContext());
+    cast<Instruction>(NewInst->stripPointerCasts())
+ ->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ }
+
// Move Inst from MergeBB to IndirectCallBB.
Inst->removeFromParent();
IndirectCallBB->getInstList().insert(IndirectCallBB->getFirstInsertionPt(),
@@ -576,9 +654,10 @@ void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee,
DEBUG(dbgs() << *BB << *DirectCallBB << *IndirectCallBB << *MergeBB << "\n");
emitOptimizationRemark(
- F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(),
+ BB->getContext(), "pgo-icall-prom", *BB->getParent(), Inst->getDebugLoc(),
Twine("Promote indirect call to ") + DirectCallee->getName() +
" with count " + Twine(Count) + " out of " + Twine(TotalCount));
+ return NewInst;
}
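The transformation performed by promoteIndirectCall keeps the original indirect call on a cold path and guards a new direct call with a pointer comparison, so the direct callee becomes visible to the inliner. Below is a minimal standalone C++ sketch of that control-flow shape, modeled with plain function pointers rather than IR; the names target and callThroughPromotion are illustrative stand-ins, not part of the patch.

#include <cstdio>

static int target(int X) { return X + 1; }

// Models the conditional direct-call shape produced by promotion:
//   if (Callee == Target)  -> direct call   (taken Count times)
//   else                   -> indirect call (taken TotalCount - Count times)
static int callThroughPromotion(int (*Callee)(int), int X) {
  if (Callee == &target)
    return target(X); // hot path: now a direct, inlinable call
  return Callee(X);   // cold path: the original indirect call
}

int main() { std::printf("%d\n", callThroughPromotion(&target, 41)); } // 42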
// Promote indirect-call to conditional direct-call for one callsite.
@@ -589,7 +668,7 @@ uint32_t ICallPromotionFunc::tryToPromote(
for (auto &C : Candidates) {
uint64_t Count = C.Count;
- promote(Inst, C.TargetFunction, Count, TotalCount);
+ promoteIndirectCall(Inst, C.TargetFunction, Count, TotalCount, SamplePGO);
assert(TotalCount >= Count);
TotalCount -= Count;
NumOfPGOICallPromotion++;
@@ -630,7 +709,7 @@ bool ICallPromotionFunc::processFunction() {
}
// A wrapper function that does the actual work.
-static bool promoteIndirectCalls(Module &M, bool InLTO) {
+static bool promoteIndirectCalls(Module &M, bool InLTO, bool SamplePGO) {
if (DisableICP)
return false;
InstrProfSymtab Symtab;
@@ -641,7 +720,7 @@ static bool promoteIndirectCalls(Module &M, bool InLTO) {
continue;
if (F.hasFnAttribute(Attribute::OptimizeNone))
continue;
- ICallPromotionFunc ICallPromotion(F, &M, &Symtab);
+ ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO);
bool FuncChanged = ICallPromotion.processFunction();
if (ICPDUMPAFTER && FuncChanged) {
DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs()));
@@ -658,12 +737,289 @@ static bool promoteIndirectCalls(Module &M, bool InLTO) {
bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) {
// Command-line option has the priority for InLTO.
- return promoteIndirectCalls(M, InLTO | ICPLTOMode);
+ return promoteIndirectCalls(M, InLTO | ICPLTOMode,
+ SamplePGO | ICPSamplePGOMode);
}
-PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, ModuleAnalysisManager &AM) {
- if (!promoteIndirectCalls(M, InLTO | ICPLTOMode))
+PreservedAnalyses PGOIndirectCallPromotion::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ if (!promoteIndirectCalls(M, InLTO | ICPLTOMode,
+ SamplePGO | ICPSamplePGOMode))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
+
+namespace {
+class MemOPSizeOpt : public InstVisitor<MemOPSizeOpt> {
+public:
+ MemOPSizeOpt(Function &Func, BlockFrequencyInfo &BFI)
+ : Func(Func), BFI(BFI), Changed(false) {
+ ValueDataArray =
+ llvm::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2);
+ // Get the MemOPSize range information from option MemOPSizeRange,
+ getMemOPSizeRangeFromOption(MemOPSizeRange, PreciseRangeStart,
+ PreciseRangeLast);
+ }
+ bool isChanged() const { return Changed; }
+ void perform() {
+ WorkList.clear();
+ visit(Func);
+
+ for (auto &MI : WorkList) {
+ ++NumOfPGOMemOPAnnotate;
+ if (perform(MI)) {
+ Changed = true;
+ ++NumOfPGOMemOPOpt;
+ DEBUG(dbgs() << "MemOP calls: " << MI->getCalledFunction()->getName()
+                     << " is Transformed.\n");
+ }
+ }
+ }
+
+ void visitMemIntrinsic(MemIntrinsic &MI) {
+ Value *Length = MI.getLength();
+    // Do not perform the optimization on constant-length calls.
+ if (dyn_cast<ConstantInt>(Length))
+ return;
+ WorkList.push_back(&MI);
+ }
+
+private:
+ Function &Func;
+ BlockFrequencyInfo &BFI;
+ bool Changed;
+ std::vector<MemIntrinsic *> WorkList;
+  // Start of the precise range.
+  int64_t PreciseRangeStart;
+  // Last value of the precise range.
+ int64_t PreciseRangeLast;
+ // The space to read the profile annotation.
+ std::unique_ptr<InstrProfValueData[]> ValueDataArray;
+ bool perform(MemIntrinsic *MI);
+
+ // This kind shows which group the value falls in. For PreciseValue, we have
+ // the profile count for that value. LargeGroup groups the values that are in
+  // range [LargeValue, +inf). NonLargeGroup groups the rest of the values.
+ enum MemOPSizeKind { PreciseValue, NonLargeGroup, LargeGroup };
+
+ MemOPSizeKind getMemOPSizeKind(int64_t Value) const {
+ if (Value == MemOPSizeLarge && MemOPSizeLarge != 0)
+ return LargeGroup;
+ if (Value == PreciseRangeLast + 1)
+ return NonLargeGroup;
+ return PreciseValue;
+ }
+};
+
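getMemOPSizeKind above relies on the instrumentation runtime having already folded raw sizes into sentinel values: sizes past the precise range are recorded as PreciseRangeLast + 1 and, when large-value profiling is enabled, very large sizes are recorded as MemOPSizeLarge itself. A small standalone sketch of that classification follows, assuming a precise range ending at 8 and the 8192 default of -memop-size-large; the helper name classify is illustrative only.

#include <cstdint>
#include <iostream>

enum MemOPSizeKind { PreciseValue, NonLargeGroup, LargeGroup };

// Mirrors the logic of getMemOPSizeKind for one assumed configuration.
static MemOPSizeKind classify(int64_t Value, int64_t PreciseRangeLast = 8,
                              int64_t LargeValue = 8192) {
  if (LargeValue != 0 && Value == LargeValue)
    return LargeGroup;    // grouped bucket for sizes >= LargeValue
  if (Value == PreciseRangeLast + 1)
    return NonLargeGroup; // grouped bucket for out-of-range, non-large sizes
  return PreciseValue;    // an exact size with its own profile count
}

int main() {
  std::cout << classify(4) << " " << classify(9) << " " << classify(8192)
            << "\n"; // prints: 0 1 2
}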
+static const char *getMIName(const MemIntrinsic *MI) {
+ switch (MI->getIntrinsicID()) {
+ case Intrinsic::memcpy:
+ return "memcpy";
+ case Intrinsic::memmove:
+ return "memmove";
+ case Intrinsic::memset:
+ return "memset";
+ default:
+ return "unknown";
+ }
+}
+
+static bool isProfitable(uint64_t Count, uint64_t TotalCount) {
+ assert(Count <= TotalCount);
+ if (Count < MemOPCountThreshold)
+ return false;
+ if (Count < TotalCount * MemOPPercentThreshold / 100)
+ return false;
+ return true;
+}
+
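A minimal sketch of how the two gates in isProfitable combine, assuming the 1000 and 40 defaults of -pgo-memop-count-threshold and -pgo-memop-percent-threshold; the standalone helper below is an illustration, not the in-tree function.

#include <cstdint>
#include <iostream>

static bool profitable(uint64_t Count, uint64_t TotalCount,
                       uint64_t CountThreshold = 1000,
                       uint64_t PercentThreshold = 40) {
  if (Count < CountThreshold) // absolute hotness gate
    return false;
  return Count >= TotalCount * PercentThreshold / 100; // relative share gate
}

int main() {
  std::cout << profitable(5000, 8000) << "\n";  // 1: 62% of calls, above floor
  std::cout << profitable(1200, 10000) << "\n"; // 0: hot enough, but only 12%
}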
+static inline uint64_t getScaledCount(uint64_t Count, uint64_t Num,
+ uint64_t Denom) {
+ if (!MemOPScaleCount)
+ return Count;
+ bool Overflowed;
+ uint64_t ScaleCount = SaturatingMultiply(Count, Num, &Overflowed);
+ return ScaleCount / Denom;
+}
+
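When -pgo-memop-scale-count is on, the size counts recorded in the value profile are rescaled so they line up with the BFI count of the enclosing block, since the two can diverge after inlining or other CFG changes. The standalone sketch below rescales with saturation on overflow, mirroring the intent of getScaledCount above; the helper name and the clamping choice are assumptions for illustration.

#include <cstdint>
#include <iostream>
#include <limits>

static uint64_t scaled(uint64_t Count, uint64_t Num, uint64_t Denom) {
  if (Denom == 0)
    return Count;
  // Clamp instead of wrapping if Count * Num would overflow 64 bits.
  if (Num != 0 && Count > std::numeric_limits<uint64_t>::max() / Num)
    return std::numeric_limits<uint64_t>::max() / Denom;
  return Count * Num / Denom;
}

int main() {
  // A size seen 300 times in the value profile; the block ran 600 times but
  // the recorded value-profile total is 1200, so the count scales to 150.
  std::cout << scaled(300, 600, 1200) << "\n";
}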
+bool MemOPSizeOpt::perform(MemIntrinsic *MI) {
+ assert(MI);
+ if (MI->getIntrinsicID() == Intrinsic::memmove)
+ return false;
+
+ uint32_t NumVals, MaxNumPromotions = MemOPMaxVersion + 2;
+ uint64_t TotalCount;
+ if (!getValueProfDataFromInst(*MI, IPVK_MemOPSize, MaxNumPromotions,
+ ValueDataArray.get(), NumVals, TotalCount))
+ return false;
+
+ uint64_t ActualCount = TotalCount;
+ uint64_t SavedTotalCount = TotalCount;
+ if (MemOPScaleCount) {
+ auto BBEdgeCount = BFI.getBlockProfileCount(MI->getParent());
+ if (!BBEdgeCount)
+ return false;
+ ActualCount = *BBEdgeCount;
+ }
+
+ if (ActualCount < MemOPCountThreshold)
+ return false;
+
+ ArrayRef<InstrProfValueData> VDs(ValueDataArray.get(), NumVals);
+ TotalCount = ActualCount;
+ if (MemOPScaleCount)
+ DEBUG(dbgs() << "Scale counts: numberator = " << ActualCount
+ << " denominator = " << SavedTotalCount << "\n");
+
+ // Keeping track of the count of the default case:
+ uint64_t RemainCount = TotalCount;
+ SmallVector<uint64_t, 16> SizeIds;
+ SmallVector<uint64_t, 16> CaseCounts;
+ uint64_t MaxCount = 0;
+ unsigned Version = 0;
+ // Default case is in the front -- save the slot here.
+ CaseCounts.push_back(0);
+ for (auto &VD : VDs) {
+ int64_t V = VD.Value;
+ uint64_t C = VD.Count;
+ if (MemOPScaleCount)
+ C = getScaledCount(C, ActualCount, SavedTotalCount);
+
+    // Only care about precise values here.
+ if (getMemOPSizeKind(V) != PreciseValue)
+ continue;
+
+    // ValueCounts are sorted on the count. Break at the first unprofitable
+ // value.
+ if (!isProfitable(C, RemainCount))
+ break;
+
+ SizeIds.push_back(V);
+ CaseCounts.push_back(C);
+ if (C > MaxCount)
+ MaxCount = C;
+
+ assert(RemainCount >= C);
+ RemainCount -= C;
+
+ if (++Version > MemOPMaxVersion && MemOPMaxVersion != 0)
+ break;
+ }
+
+ if (Version == 0)
+ return false;
+
+ CaseCounts[0] = RemainCount;
+ if (RemainCount > MaxCount)
+ MaxCount = RemainCount;
+
+ uint64_t SumForOpt = TotalCount - RemainCount;
+ DEBUG(dbgs() << "Read one memory intrinsic profile: " << SumForOpt << " vs "
+ << TotalCount << "\n");
+ DEBUG(
+ for (auto &VD
+ : VDs) { dbgs() << " (" << VD.Value << "," << VD.Count << ")\n"; });
+
+ DEBUG(dbgs() << "Optimize one memory intrinsic call to " << Version
+ << " Versions\n");
+
+ // mem_op(..., size)
+ // ==>
+ // switch (size) {
+ // case s1:
+ // mem_op(..., s1);
+ // goto merge_bb;
+ // case s2:
+ // mem_op(..., s2);
+ // goto merge_bb;
+ // ...
+ // default:
+ // mem_op(..., size);
+ // goto merge_bb;
+ // }
+ // merge_bb:
+
+ BasicBlock *BB = MI->getParent();
+ DEBUG(dbgs() << "\n\n== Basic Block Before ==\n");
+ DEBUG(dbgs() << *BB << "\n");
+
+ BasicBlock *DefaultBB = SplitBlock(BB, MI);
+ BasicBlock::iterator It(*MI);
+ ++It;
+ assert(It != DefaultBB->end());
+ BasicBlock *MergeBB = SplitBlock(DefaultBB, &(*It));
+ DefaultBB->setName("MemOP.Default");
+ MergeBB->setName("MemOP.Merge");
+
+ auto &Ctx = Func.getContext();
+ IRBuilder<> IRB(BB);
+ BB->getTerminator()->eraseFromParent();
+ Value *SizeVar = MI->getLength();
+ SwitchInst *SI = IRB.CreateSwitch(SizeVar, DefaultBB, SizeIds.size());
+
+ // Clear the value profile data.
+ MI->setMetadata(LLVMContext::MD_prof, nullptr);
+
+  DEBUG(dbgs() << "\n\n== Basic Block After ==\n");
+
+ for (uint64_t SizeId : SizeIds) {
+ ConstantInt *CaseSizeId = ConstantInt::get(Type::getInt64Ty(Ctx), SizeId);
+ BasicBlock *CaseBB = BasicBlock::Create(
+ Ctx, Twine("MemOP.Case.") + Twine(SizeId), &Func, DefaultBB);
+ Instruction *NewInst = MI->clone();
+ // Fix the argument.
+    cast<MemIntrinsic>(NewInst)->setLength(CaseSizeId);
+ CaseBB->getInstList().push_back(NewInst);
+ IRBuilder<> IRBCase(CaseBB);
+ IRBCase.CreateBr(MergeBB);
+ SI->addCase(CaseSizeId, CaseBB);
+ DEBUG(dbgs() << *CaseBB << "\n");
+ }
+ setProfMetadata(Func.getParent(), SI, CaseCounts, MaxCount);
+
+ DEBUG(dbgs() << *BB << "\n");
+ DEBUG(dbgs() << *DefaultBB << "\n");
+ DEBUG(dbgs() << *MergeBB << "\n");
+
+ emitOptimizationRemark(Func.getContext(), "memop-opt", Func,
+ MI->getDebugLoc(),
+ Twine("optimize ") + getMIName(MI) + " with count " +
+ Twine(SumForOpt) + " out of " + Twine(TotalCount) +
+ " for " + Twine(Version) + " versions");
+
+ return true;
+}
+} // namespace
+
+static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI) {
+ if (DisableMemOPOPT)
+ return false;
+
+ if (F.hasFnAttribute(Attribute::OptimizeForSize))
+ return false;
+ MemOPSizeOpt MemOPSizeOpt(F, BFI);
+ MemOPSizeOpt.perform();
+ return MemOPSizeOpt.isChanged();
+}
+
+bool PGOMemOPSizeOptLegacyPass::runOnFunction(Function &F) {
+ BlockFrequencyInfo &BFI =
+ getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI();
+ return PGOMemOPSizeOptImpl(F, BFI);
+}
+
+namespace llvm {
+char &PGOMemOPSizeOptID = PGOMemOPSizeOptLegacyPass::ID;
+
+PreservedAnalyses PGOMemOPSizeOpt::run(Function &F,
+ FunctionAnalysisManager &FAM) {
+ auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
+ bool Changed = PGOMemOPSizeOptImpl(F, BFI);
+ if (!Changed)
+ return PreservedAnalyses::all();
+ auto PA = PreservedAnalyses();
+ PA.preserve<GlobalsAA>();
+ return PA;
+}
+} // namespace llvm
diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp
index adea7e7724476..d91ac6ac78830 100644
--- a/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -14,18 +14,58 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/InstrProfiling.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <string>
using namespace llvm;
#define DEBUG_TYPE "instrprof"
+// The start and end values of the precise value profile range for memory
+// intrinsic sizes.
+cl::opt<std::string> MemOPSizeRange(
+ "memop-size-range",
+ cl::desc("Set the range of size in memory intrinsic calls to be profiled "
+ "precisely, in a format of <start_val>:<end_val>"),
+ cl::init(""));
+
+// The value that is considered to be a large value in memory intrinsic size
+// profiling.
+cl::opt<unsigned> MemOPSizeLarge(
+ "memop-size-large",
+ cl::desc("Set large value thresthold in memory intrinsic size profiling. "
+ "Value of 0 disables the large value profiling."),
+ cl::init(8192));
+
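The -memop-size-range string is later split into the start and last values of the precise range by getMemOPSizeRangeFromOption. Below is a minimal sketch of parsing that <start_val>:<end_val> format; the defaults used when the string is empty and the helper name parseSizeRange are assumptions for illustration.

#include <cstdint>
#include <iostream>
#include <string>

static void parseSizeRange(const std::string &Opt, int64_t &Start,
                           int64_t &Last) {
  Start = 0; // assumed defaults when the option is not given
  Last = 8;
  if (Opt.empty())
    return;
  auto Colon = Opt.find(':');
  if (Colon == std::string::npos)
    return;
  Start = std::stoll(Opt.substr(0, Colon));
  Last = std::stoll(Opt.substr(Colon + 1));
}

int main() {
  int64_t S, L;
  parseSizeRange("1:16", S, L);
  std::cout << S << " " << L << "\n"; // 1 16
}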
namespace {
cl::opt<bool> DoNameCompression("enable-name-compression",
@@ -41,6 +81,7 @@ cl::opt<bool> ValueProfileStaticAlloc(
"vp-static-alloc",
cl::desc("Do static counter allocation for value profiler"),
cl::init(true));
+
cl::opt<double> NumCountersPerValueSite(
"vp-counters-per-site",
cl::desc("The average number of profile counters allocated "
@@ -56,9 +97,11 @@ class InstrProfilingLegacyPass : public ModulePass {
public:
static char ID;
- InstrProfilingLegacyPass() : ModulePass(ID), InstrProf() {}
+
+ InstrProfilingLegacyPass() : ModulePass(ID) {}
InstrProfilingLegacyPass(const InstrProfOptions &Options)
: ModulePass(ID), InstrProf(Options) {}
+
StringRef getPassName() const override {
return "Frontend instrumentation-based coverage lowering";
}
@@ -73,7 +116,7 @@ public:
}
};
-} // anonymous namespace
+} // end anonymous namespace
PreservedAnalyses InstrProfiling::run(Module &M, ModuleAnalysisManager &AM) {
auto &TLI = AM.getResult<TargetLibraryAnalysis>(M);
@@ -97,30 +140,6 @@ llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options) {
return new InstrProfilingLegacyPass(Options);
}
-bool InstrProfiling::isMachO() const {
- return Triple(M->getTargetTriple()).isOSBinFormatMachO();
-}
-
-/// Get the section name for the counter variables.
-StringRef InstrProfiling::getCountersSection() const {
- return getInstrProfCountersSectionName(isMachO());
-}
-
-/// Get the section name for the name variables.
-StringRef InstrProfiling::getNameSection() const {
- return getInstrProfNameSectionName(isMachO());
-}
-
-/// Get the section name for the profile data variables.
-StringRef InstrProfiling::getDataSection() const {
- return getInstrProfDataSectionName(isMachO());
-}
-
-/// Get the section name for the coverage mapping data.
-StringRef InstrProfiling::getCoverageSection() const {
- return getInstrProfCoverageSectionName(isMachO());
-}
-
static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) {
InstrProfIncrementInst *Inc = dyn_cast<InstrProfIncrementInstStep>(Instr);
if (Inc)
@@ -137,6 +156,9 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) {
NamesSize = 0;
ProfileDataMap.clear();
UsedVars.clear();
+ getMemOPSizeRangeFromOption(MemOPSizeRange, MemOPSizeRangeStart,
+ MemOPSizeRangeLast);
+ TT = Triple(M.getTargetTriple());
// We did not know how many value sites there would be inside
// the instrumented function. This is counting the number of instrumented
@@ -189,17 +211,34 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) {
}
static Constant *getOrInsertValueProfilingCall(Module &M,
- const TargetLibraryInfo &TLI) {
+ const TargetLibraryInfo &TLI,
+ bool IsRange = false) {
LLVMContext &Ctx = M.getContext();
auto *ReturnTy = Type::getVoidTy(M.getContext());
- Type *ParamTypes[] = {
+
+ Constant *Res;
+ if (!IsRange) {
+ Type *ParamTypes[] = {
#define VALUE_PROF_FUNC_PARAM(ParamType, ParamName, ParamLLVMType) ParamLLVMType
#include "llvm/ProfileData/InstrProfData.inc"
- };
- auto *ValueProfilingCallTy =
- FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false);
- Constant *Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(),
- ValueProfilingCallTy);
+ };
+ auto *ValueProfilingCallTy =
+ FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false);
+ Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(),
+ ValueProfilingCallTy);
+ } else {
+ Type *RangeParamTypes[] = {
+#define VALUE_RANGE_PROF 1
+#define VALUE_PROF_FUNC_PARAM(ParamType, ParamName, ParamLLVMType) ParamLLVMType
+#include "llvm/ProfileData/InstrProfData.inc"
+#undef VALUE_RANGE_PROF
+ };
+ auto *ValueRangeProfilingCallTy =
+ FunctionType::get(ReturnTy, makeArrayRef(RangeParamTypes), false);
+ Res = M.getOrInsertFunction(getInstrProfValueRangeProfFuncName(),
+ ValueRangeProfilingCallTy);
+ }
+
if (Function *FunRes = dyn_cast<Function>(Res)) {
if (auto AK = TLI.getExtAttrForI32Param(false))
FunRes->addAttribute(3, AK);
@@ -208,7 +247,6 @@ static Constant *getOrInsertValueProfilingCall(Module &M,
}
void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) {
-
GlobalVariable *Name = Ind->getName();
uint64_t ValueKind = Ind->getValueKind()->getZExtValue();
uint64_t Index = Ind->getIndex()->getZExtValue();
@@ -222,7 +260,6 @@ void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) {
}
void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) {
-
GlobalVariable *Name = Ind->getName();
auto It = ProfileDataMap.find(Name);
assert(It != ProfileDataMap.end() && It->second.DataVar &&
@@ -235,11 +272,25 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) {
Index += It->second.NumValueSites[Kind];
IRBuilder<> Builder(Ind);
- Value *Args[3] = {Ind->getTargetValue(),
- Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()),
- Builder.getInt32(Index)};
- CallInst *Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI),
- Args);
+ bool IsRange = (Ind->getValueKind()->getZExtValue() ==
+ llvm::InstrProfValueKind::IPVK_MemOPSize);
+ CallInst *Call = nullptr;
+ if (!IsRange) {
+ Value *Args[3] = {Ind->getTargetValue(),
+ Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()),
+ Builder.getInt32(Index)};
+ Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), Args);
+ } else {
+ Value *Args[6] = {
+ Ind->getTargetValue(),
+ Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()),
+ Builder.getInt32(Index),
+ Builder.getInt64(MemOPSizeRangeStart),
+ Builder.getInt64(MemOPSizeRangeLast),
+ Builder.getInt64(MemOPSizeLarge == 0 ? INT64_MIN : MemOPSizeLarge)};
+ Call =
+ Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI, true), Args);
+ }
if (auto AK = TLI->getExtAttrForI32Param(false))
Call->addAttribute(3, AK);
Ind->replaceAllUsesWith(Call);
@@ -259,7 +310,6 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) {
}
void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) {
-
ConstantArray *Names =
cast<ConstantArray>(CoverageNamesVar->getInitializer());
for (unsigned I = 0, E = Names->getNumOperands(); I < E; ++I) {
@@ -270,7 +320,9 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) {
Name->setLinkage(GlobalValue::PrivateLinkage);
ReferencedNames.push_back(Name);
+ NC->dropAllReferences();
}
+ CoverageNamesVar->eraseFromParent();
}
/// Get the name of a profiling variable for a particular function.
@@ -367,7 +419,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
Constant::getNullValue(CounterTy),
getVarName(Inc, getInstrProfCountersVarPrefix()));
CounterPtr->setVisibility(NamePtr->getVisibility());
- CounterPtr->setSection(getCountersSection());
+ CounterPtr->setSection(
+ getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat()));
CounterPtr->setAlignment(8);
CounterPtr->setComdat(ProfileVarsComdat);
@@ -376,7 +429,6 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
// the current function.
Constant *ValuesPtrExpr = ConstantPointerNull::get(Int8PtrTy);
if (ValueProfileStaticAlloc && !needsRuntimeRegistrationOfSectionRange(*M)) {
-
uint64_t NS = 0;
for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
NS += PD.NumValueSites[Kind];
@@ -388,11 +440,12 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
Constant::getNullValue(ValuesTy),
getVarName(Inc, getInstrProfValuesVarPrefix()));
ValuesVar->setVisibility(NamePtr->getVisibility());
- ValuesVar->setSection(getInstrProfValuesSectionName(isMachO()));
+ ValuesVar->setSection(
+ getInstrProfSectionName(IPSK_vals, TT.getObjectFormat()));
ValuesVar->setAlignment(8);
ValuesVar->setComdat(ProfileVarsComdat);
ValuesPtrExpr =
- ConstantExpr::getBitCast(ValuesVar, llvm::Type::getInt8PtrTy(Ctx));
+ ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx));
}
}
@@ -421,7 +474,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
ConstantStruct::get(DataTy, DataVals),
getVarName(Inc, getInstrProfDataVarPrefix()));
Data->setVisibility(NamePtr->getVisibility());
- Data->setSection(getDataSection());
+ Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat()));
Data->setAlignment(INSTR_PROF_DATA_ALIGNMENT);
Data->setComdat(ProfileVarsComdat);
@@ -481,9 +534,10 @@ void InstrProfiling::emitVNodes() {
ArrayType *VNodesTy = ArrayType::get(VNodeTy, NumCounters);
auto *VNodesVar = new GlobalVariable(
- *M, VNodesTy, false, llvm::GlobalValue::PrivateLinkage,
+ *M, VNodesTy, false, GlobalValue::PrivateLinkage,
Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName());
- VNodesVar->setSection(getInstrProfVNodesSectionName(isMachO()));
+ VNodesVar->setSection(
+ getInstrProfSectionName(IPSK_vnodes, TT.getObjectFormat()));
UsedVars.push_back(VNodesVar);
}
@@ -496,18 +550,22 @@ void InstrProfiling::emitNameData() {
std::string CompressedNameStr;
if (Error E = collectPGOFuncNameStrings(ReferencedNames, CompressedNameStr,
DoNameCompression)) {
- llvm::report_fatal_error(toString(std::move(E)), false);
+ report_fatal_error(toString(std::move(E)), false);
}
auto &Ctx = M->getContext();
- auto *NamesVal = llvm::ConstantDataArray::getString(
+ auto *NamesVal = ConstantDataArray::getString(
Ctx, StringRef(CompressedNameStr), false);
- NamesVar = new llvm::GlobalVariable(*M, NamesVal->getType(), true,
- llvm::GlobalValue::PrivateLinkage,
- NamesVal, getInstrProfNamesVarName());
+ NamesVar = new GlobalVariable(*M, NamesVal->getType(), true,
+ GlobalValue::PrivateLinkage, NamesVal,
+ getInstrProfNamesVarName());
NamesSize = CompressedNameStr.size();
- NamesVar->setSection(getNameSection());
+ NamesVar->setSection(
+ getInstrProfSectionName(IPSK_name, TT.getObjectFormat()));
UsedVars.push_back(NamesVar);
+
+ for (auto *NamePtr : ReferencedNames)
+ NamePtr->eraseFromParent();
}
void InstrProfiling::emitRegistration() {
@@ -550,7 +608,6 @@ void InstrProfiling::emitRegistration() {
}
void InstrProfiling::emitRuntimeHook() {
-
// We expect the linker to be invoked with -u<hook_var> flag for linux,
// for which case there is no need to emit the user function.
if (Triple(M->getTargetTriple()).isOSLinux())
@@ -600,7 +657,6 @@ void InstrProfiling::emitInitialization() {
GlobalVariable *ProfileNameVar = new GlobalVariable(
*M, ProfileNameConst->getType(), true, GlobalValue::WeakAnyLinkage,
ProfileNameConst, INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_NAME_VAR));
- Triple TT(M->getTargetTriple());
if (TT.supportsCOMDAT()) {
ProfileNameVar->setLinkage(GlobalValue::ExternalLinkage);
ProfileNameVar->setComdat(M->getOrInsertComdat(
diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp
index 2963d08752c46..7bb62d2c8455f 100644
--- a/lib/Transforms/Instrumentation/Instrumentation.cpp
+++ b/lib/Transforms/Instrumentation/Instrumentation.cpp
@@ -63,6 +63,7 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) {
initializePGOInstrumentationGenLegacyPassPass(Registry);
initializePGOInstrumentationUseLegacyPassPass(Registry);
initializePGOIndirectCallPromotionLegacyPassPass(Registry);
+ initializePGOMemOPSizeOptLegacyPassPass(Registry);
initializeInstrProfilingLegacyPassPass(Registry);
initializeMemorySanitizerPass(Registry);
initializeThreadSanitizerPass(Registry);
diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index fafb0fcbd0172..190f05db4b0c1 100644
--- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -425,7 +425,7 @@ void MemorySanitizer::initializeCallbacks(Module &M) {
// which is not yet implemented.
StringRef WarningFnName = Recover ? "__msan_warning"
: "__msan_warning_noreturn";
- WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy(), nullptr);
+ WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy());
for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes;
AccessSizeIndex++) {
@@ -433,31 +433,31 @@ void MemorySanitizer::initializeCallbacks(Module &M) {
std::string FunctionName = "__msan_maybe_warning_" + itostr(AccessSize);
MaybeWarningFn[AccessSizeIndex] = M.getOrInsertFunction(
FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8),
- IRB.getInt32Ty(), nullptr);
+ IRB.getInt32Ty());
FunctionName = "__msan_maybe_store_origin_" + itostr(AccessSize);
MaybeStoreOriginFn[AccessSizeIndex] = M.getOrInsertFunction(
FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8),
- IRB.getInt8PtrTy(), IRB.getInt32Ty(), nullptr);
+ IRB.getInt8PtrTy(), IRB.getInt32Ty());
}
MsanSetAllocaOrigin4Fn = M.getOrInsertFunction(
"__msan_set_alloca_origin4", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy,
- IRB.getInt8PtrTy(), IntptrTy, nullptr);
+ IRB.getInt8PtrTy(), IntptrTy);
MsanPoisonStackFn =
M.getOrInsertFunction("__msan_poison_stack", IRB.getVoidTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr);
+ IRB.getInt8PtrTy(), IntptrTy);
MsanChainOriginFn = M.getOrInsertFunction(
- "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty(), nullptr);
+ "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty());
MemmoveFn = M.getOrInsertFunction(
"__msan_memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr);
+ IRB.getInt8PtrTy(), IntptrTy);
MemcpyFn = M.getOrInsertFunction(
"__msan_memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IntptrTy, nullptr);
+ IntptrTy);
MemsetFn = M.getOrInsertFunction(
"__msan_memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(),
- IntptrTy, nullptr);
+ IntptrTy);
// Create globals.
RetvalTLS = new GlobalVariable(
@@ -1037,15 +1037,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
OriginMap[V] = Origin;
}
+ Constant *getCleanShadow(Type *OrigTy) {
+ Type *ShadowTy = getShadowTy(OrigTy);
+ if (!ShadowTy)
+ return nullptr;
+ return Constant::getNullValue(ShadowTy);
+ }
+
/// \brief Create a clean shadow value for a given value.
///
/// Clean shadow (all zeroes) means all bits of the value are defined
/// (initialized).
Constant *getCleanShadow(Value *V) {
- Type *ShadowTy = getShadowTy(V);
- if (!ShadowTy)
- return nullptr;
- return Constant::getNullValue(ShadowTy);
+ return getCleanShadow(V->getType());
}
/// \brief Create a dirty shadow of a given shadow type.
@@ -1942,7 +1946,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
if (ClCheckAccessAddress)
insertShadowCheck(Addr, &I);
- // FIXME: use ClStoreCleanOrigin
// FIXME: factor out common code from materializeStores
if (MS.TrackOrigins)
IRB.CreateStore(getOrigin(&I, 1), getOriginPtr(Addr, IRB, 1));
@@ -2325,11 +2328,49 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}
+ void handleStmxcsr(IntrinsicInst &I) {
+ IRBuilder<> IRB(&I);
+ Value* Addr = I.getArgOperand(0);
+ Type *Ty = IRB.getInt32Ty();
+ Value *ShadowPtr = getShadowPtr(Addr, Ty, IRB);
+
+ IRB.CreateStore(getCleanShadow(Ty),
+ IRB.CreatePointerCast(ShadowPtr, Ty->getPointerTo()));
+
+ if (ClCheckAccessAddress)
+ insertShadowCheck(Addr, &I);
+ }
+
+ void handleLdmxcsr(IntrinsicInst &I) {
+ if (!InsertChecks) return;
+
+ IRBuilder<> IRB(&I);
+ Value *Addr = I.getArgOperand(0);
+ Type *Ty = IRB.getInt32Ty();
+ unsigned Alignment = 1;
+
+ if (ClCheckAccessAddress)
+ insertShadowCheck(Addr, &I);
+
+ Value *Shadow = IRB.CreateAlignedLoad(getShadowPtr(Addr, Ty, IRB),
+ Alignment, "_ldmxcsr");
+ Value *Origin = MS.TrackOrigins
+ ? IRB.CreateLoad(getOriginPtr(Addr, IRB, Alignment))
+ : getCleanOrigin();
+ insertShadowCheck(Shadow, Origin, &I);
+ }
+
void visitIntrinsicInst(IntrinsicInst &I) {
switch (I.getIntrinsicID()) {
case llvm::Intrinsic::bswap:
handleBswap(I);
break;
+ case llvm::Intrinsic::x86_sse_stmxcsr:
+ handleStmxcsr(I);
+ break;
+ case llvm::Intrinsic::x86_sse_ldmxcsr:
+ handleLdmxcsr(I);
+ break;
case llvm::Intrinsic::x86_avx512_vcvtsd2usi64:
case llvm::Intrinsic::x86_avx512_vcvtsd2usi32:
case llvm::Intrinsic::x86_avx512_vcvtss2usi64:
@@ -2566,10 +2607,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
AttrBuilder B;
B.addAttribute(Attribute::ReadOnly)
.addAttribute(Attribute::ReadNone);
- Func->removeAttributes(AttributeSet::FunctionIndex,
- AttributeSet::get(Func->getContext(),
- AttributeSet::FunctionIndex,
- B));
+ Func->removeAttributes(AttributeList::FunctionIndex,
+ AttributeList::get(Func->getContext(),
+ AttributeList::FunctionIndex,
+ B));
}
maybeMarkSanitizerLibraryCallNoBuiltin(Call, TLI);
@@ -2597,7 +2638,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
" Shadow: " << *ArgShadow << "\n");
bool ArgIsInitialized = false;
const DataLayout &DL = F.getParent()->getDataLayout();
- if (CS.paramHasAttr(i + 1, Attribute::ByVal)) {
+ if (CS.paramHasAttr(i, Attribute::ByVal)) {
assert(A->getType()->isPointerTy() &&
"ByVal argument is not a pointer!");
Size = DL.getTypeAllocSize(A->getType()->getPointerElementType());
@@ -2690,7 +2731,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
} else {
Value *Shadow = getShadow(RetVal);
IRB.CreateAlignedStore(Shadow, ShadowPtr, kShadowTLSAlignment);
- // FIXME: make it conditional if ClStoreCleanOrigin==0
if (MS.TrackOrigins)
IRB.CreateStore(getOrigin(RetVal), getOriginPtrForRetval(IRB));
}
@@ -2717,15 +2757,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOrigin(&I, getCleanOrigin());
IRBuilder<> IRB(I.getNextNode());
const DataLayout &DL = F.getParent()->getDataLayout();
- uint64_t Size = DL.getTypeAllocSize(I.getAllocatedType());
+ uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType());
+ Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize);
+ if (I.isArrayAllocation())
+ Len = IRB.CreateMul(Len, I.getArraySize());
if (PoisonStack && ClPoisonStackWithCall) {
IRB.CreateCall(MS.MsanPoisonStackFn,
- {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()),
- ConstantInt::get(MS.IntptrTy, Size)});
+ {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len});
} else {
Value *ShadowBase = getShadowPtr(&I, Type::getInt8PtrTy(*MS.C), IRB);
Value *PoisonValue = IRB.getInt8(PoisonStack ? ClPoisonStackPattern : 0);
- IRB.CreateMemSet(ShadowBase, PoisonValue, Size, I.getAlignment());
+ IRB.CreateMemSet(ShadowBase, PoisonValue, Len, I.getAlignment());
}
if (PoisonStack && MS.TrackOrigins) {
@@ -2742,8 +2784,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
StackDescription.str());
IRB.CreateCall(MS.MsanSetAllocaOrigin4Fn,
- {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()),
- ConstantInt::get(MS.IntptrTy, Size),
+ {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len,
IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy()),
IRB.CreatePointerCast(&F, MS.IntptrTy)});
}
@@ -2935,7 +2976,7 @@ struct VarArgAMD64Helper : public VarArgHelper {
Value *A = *ArgIt;
unsigned ArgNo = CS.getArgumentNo(ArgIt);
bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams();
- bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal);
+ bool IsByVal = CS.paramHasAttr(ArgNo, Attribute::ByVal);
if (IsByVal) {
// ByVal arguments always go to the overflow area.
// Fixed arguments passed through the overflow area will be stepped
@@ -3456,7 +3497,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper {
Value *A = *ArgIt;
unsigned ArgNo = CS.getArgumentNo(ArgIt);
bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams();
- bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal);
+ bool IsByVal = CS.paramHasAttr(ArgNo, Attribute::ByVal);
if (IsByVal) {
assert(A->getType()->isPointerTy());
Type *RealTy = A->getType()->getPointerElementType();
@@ -3618,9 +3659,9 @@ bool MemorySanitizer::runOnFunction(Function &F) {
AttrBuilder B;
B.addAttribute(Attribute::ReadOnly)
.addAttribute(Attribute::ReadNone);
- F.removeAttributes(AttributeSet::FunctionIndex,
- AttributeSet::get(F.getContext(),
- AttributeSet::FunctionIndex, B));
+ F.removeAttributes(
+ AttributeList::FunctionIndex,
+ AttributeList::get(F.getContext(), AttributeList::FunctionIndex, B));
return Visitor.runOnFunction();
}
diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 04f9a64bef9fc..990bcec109de7 100644
--- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -58,8 +58,10 @@
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/IndirectCallSiteVisitor.h"
+#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
@@ -71,7 +73,9 @@
#include "llvm/ProfileData/InstrProfReader.h"
#include "llvm/ProfileData/ProfileCommon.h"
#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/JamCRC.h"
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -87,6 +91,7 @@ using namespace llvm;
STATISTIC(NumOfPGOInstrument, "Number of edges instrumented.");
STATISTIC(NumOfPGOSelectInsts, "Number of select instruction instrumented.");
+STATISTIC(NumOfPGOMemIntrinsics, "Number of mem intrinsics instrumented.");
STATISTIC(NumOfPGOEdge, "Number of edges.");
STATISTIC(NumOfPGOBB, "Number of basic-blocks.");
STATISTIC(NumOfPGOSplit, "Number of critical edge splits.");
@@ -116,6 +121,13 @@ static cl::opt<unsigned> MaxNumAnnotations(
cl::desc("Max number of annotations for a single indirect "
"call callsite"));
+// Command line option to set the maximum number of value annotations
+// to write to the metadata for a single memop intrinsic.
+static cl::opt<unsigned> MaxNumMemOPAnnotations(
+ "memop-max-annotations", cl::init(4), cl::Hidden, cl::ZeroOrMore,
+ cl::desc("Max number of preicise value annotations for a single memop"
+ "intrinsic"));
+
// Command line option to control appending FunctionHash to the name of a COMDAT
// function. This is to avoid the hash mismatch caused by the preinliner.
static cl::opt<bool> DoComdatRenaming(
@@ -125,24 +137,59 @@ static cl::opt<bool> DoComdatRenaming(
// Command line option to enable/disable the warning about missing profile
// information.
-static cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function",
- cl::init(false),
- cl::Hidden);
+static cl::opt<bool>
+ PGOWarnMissing("pgo-warn-missing-function", cl::init(false), cl::Hidden,
+ cl::desc("Use this option to turn on/off "
+ "warnings about missing profile data for "
+ "functions."));
// Command line option to enable/disable the warning about a hash mismatch in
// the profile data.
-static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false),
- cl::Hidden);
+static cl::opt<bool>
+ NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden,
+ cl::desc("Use this option to turn off/on "
+ "warnings about profile cfg mismatch."));
// Command line option to enable/disable the warning about a hash mismatch in
// the profile data for Comdat functions, which often turns out to be false
// positive due to the pre-instrumentation inline.
-static cl::opt<bool> NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat",
- cl::init(true), cl::Hidden);
+static cl::opt<bool>
+ NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat", cl::init(true),
+ cl::Hidden,
+ cl::desc("The option is used to turn on/off "
+ "warnings about hash mismatch for comdat "
+ "functions."));
// Command line option to enable/disable select instruction instrumentation.
-static cl::opt<bool> PGOInstrSelect("pgo-instr-select", cl::init(true),
- cl::Hidden);
+static cl::opt<bool>
+ PGOInstrSelect("pgo-instr-select", cl::init(true), cl::Hidden,
+ cl::desc("Use this option to turn on/off SELECT "
+ "instruction instrumentation. "));
+
+// Command line option to turn on CFG dot dump of raw profile counts
+static cl::opt<bool>
+ PGOViewRawCounts("pgo-view-raw-counts", cl::init(false), cl::Hidden,
+ cl::desc("A boolean option to show CFG dag "
+ "with raw profile counts from "
+ "profile data. See also option "
+ "-pgo-view-counts. To limit graph "
+ "display to only one function, use "
+ "filtering option -view-bfi-func-name."));
+
+// Command line option to enable/disable memop intrinsic call size profiling.
+static cl::opt<bool>
+ PGOInstrMemOP("pgo-instr-memop", cl::init(true), cl::Hidden,
+ cl::desc("Use this option to turn on/off "
+                           "memory intrinsic size profiling."));
+
+// Command line option to turn on CFG dot dump after profile annotation.
+// Defined in Analysis/BlockFrequencyInfo.cpp: -pgo-view-counts
+extern cl::opt<bool> PGOViewCounts;
+
+// Command line option to specify the name of the function for CFG dump
+// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
+extern cl::opt<std::string> ViewBlockFreqFuncName;
+
namespace {
/// The select instruction visitor plays three roles specified
@@ -167,6 +214,7 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
SelectInstVisitor(Function &Func) : F(Func) {}
void countSelects(Function &Func) {
+ NSIs = 0;
Mode = VM_counting;
visit(Func);
}
@@ -196,9 +244,54 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
void annotateOneSelectInst(SelectInst &SI);
// Visit \p SI instruction and perform tasks according to visit mode.
void visitSelectInst(SelectInst &SI);
+  // Return the number of select instructions. This needs to be called after
+ // countSelects().
unsigned getNumOfSelectInsts() const { return NSIs; }
};
+/// Instruction Visitor class to visit memory intrinsic calls.
+struct MemIntrinsicVisitor : public InstVisitor<MemIntrinsicVisitor> {
+ Function &F;
+ unsigned NMemIs = 0; // Number of memIntrinsics instrumented.
+ VisitMode Mode = VM_counting; // Visiting mode.
+ unsigned CurCtrId = 0; // Current counter index.
+ unsigned TotalNumCtrs = 0; // Total number of counters
+ GlobalVariable *FuncNameVar = nullptr;
+ uint64_t FuncHash = 0;
+ PGOUseFunc *UseFunc = nullptr;
+ std::vector<Instruction *> Candidates;
+
+ MemIntrinsicVisitor(Function &Func) : F(Func) {}
+
+ void countMemIntrinsics(Function &Func) {
+ NMemIs = 0;
+ Mode = VM_counting;
+ visit(Func);
+ }
+
+ void instrumentMemIntrinsics(Function &Func, unsigned TotalNC,
+ GlobalVariable *FNV, uint64_t FHash) {
+ Mode = VM_instrument;
+ TotalNumCtrs = TotalNC;
+ FuncHash = FHash;
+ FuncNameVar = FNV;
+ visit(Func);
+ }
+
+ std::vector<Instruction *> findMemIntrinsics(Function &Func) {
+ Candidates.clear();
+ Mode = VM_annotate;
+ visit(Func);
+ return Candidates;
+ }
+
+  // Visit the IR stream and instrument all mem intrinsic call instructions.
+ void instrumentOneMemIntrinsic(MemIntrinsic &MI);
+ // Visit \p MI instruction and perform tasks according to visit mode.
+ void visitMemIntrinsic(MemIntrinsic &SI);
+ unsigned getNumOfMemIntrinsics() const { return NMemIs; }
+};
+
class PGOInstrumentationGenLegacyPass : public ModulePass {
public:
static char ID;
@@ -316,8 +409,9 @@ private:
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
public:
- std::vector<Instruction *> IndirectCallSites;
+ std::vector<std::vector<Instruction *>> ValueSites;
SelectInstVisitor SIVisitor;
+ MemIntrinsicVisitor MIVisitor;
std::string FuncName;
GlobalVariable *FuncNameVar;
// CFG hash value for this function.
@@ -347,13 +441,16 @@ public:
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,
BlockFrequencyInfo *BFI = nullptr)
- : F(Func), ComdatMembers(ComdatMembers), SIVisitor(Func), FunctionHash(0),
- MST(F, BPI, BFI) {
+ : F(Func), ComdatMembers(ComdatMembers), ValueSites(IPVK_Last + 1),
+ SIVisitor(Func), MIVisitor(Func), FunctionHash(0), MST(F, BPI, BFI) {
// This should be done before CFG hash computation.
SIVisitor.countSelects(Func);
+ MIVisitor.countMemIntrinsics(Func);
NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
- IndirectCallSites = findIndirectCallSites(Func);
+ NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics();
+ ValueSites[IPVK_IndirectCallTarget] = findIndirectCallSites(Func);
+ ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func);
FuncName = getPGOFuncName(F);
computeCFGHash();
@@ -405,7 +502,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
}
JC.update(Indexes);
FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 |
- (uint64_t)IndirectCallSites.size() << 48 |
+ (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 |
(uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();
}
@@ -552,7 +649,7 @@ static void instrumentOneFunc(
return;
unsigned NumIndirectCallSites = 0;
- for (auto &I : FuncInfo.IndirectCallSites) {
+ for (auto &I : FuncInfo.ValueSites[IPVK_IndirectCallTarget]) {
CallSite CS(I);
Value *Callee = CS.getCalledValue();
DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = "
@@ -565,10 +662,14 @@ static void instrumentOneFunc(
{llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
Builder.getInt64(FuncInfo.FunctionHash),
Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()),
- Builder.getInt32(llvm::InstrProfValueKind::IPVK_IndirectCallTarget),
+ Builder.getInt32(IPVK_IndirectCallTarget),
Builder.getInt32(NumIndirectCallSites++)});
}
NumOfPGOICall += NumIndirectCallSites;
+
+ // Now instrument memop intrinsic calls.
+ FuncInfo.MIVisitor.instrumentMemIntrinsics(
+ F, NumCounters, FuncInfo.FuncNameVar, FuncInfo.FunctionHash);
}
// This class represents a CFG edge in profile use compilation.
@@ -653,8 +754,11 @@ public:
// Set the branch weights based on the count values.
void setBranchWeights();
- // Annotate the indirect call sites.
- void annotateIndirectCallSites();
+  // Annotate the value profile call sites for all value kinds.
+ void annotateValueSites();
+
+ // Annotate the value profile call sites for one value kind.
+ void annotateValueSites(uint32_t Kind);
// The hotness of the function from the profile count.
enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot };
@@ -677,6 +781,8 @@ public:
return FuncInfo.findBBInfo(BB);
}
+ Function &getFunc() const { return F; }
+
private:
Function &F;
Module *M;
@@ -761,7 +867,7 @@ void PGOUseFunc::setInstrumentedCounts(
NewEdge1.InMST = true;
getBBInfo(InstrBB).setBBInfoCount(CountValue);
}
- ProfileCountSize = CountFromProfile.size();
+ ProfileCountSize = CountFromProfile.size();
CountPosition = I;
}
@@ -932,21 +1038,6 @@ void PGOUseFunc::populateCounters() {
DEBUG(FuncInfo.dumpInfo("after reading profile."));
}
-static void setProfMetadata(Module *M, Instruction *TI,
- ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
- MDBuilder MDB(M->getContext());
- assert(MaxCount > 0 && "Bad max count");
- uint64_t Scale = calculateCountScale(MaxCount);
- SmallVector<unsigned, 4> Weights;
- for (const auto &ECI : EdgeCounts)
- Weights.push_back(scaleBranchCount(ECI, Scale));
-
- DEBUG(dbgs() << "Weight is: ";
- for (const auto &W : Weights) { dbgs() << W << " "; }
- dbgs() << "\n";);
- TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
-}
-
// Assign the scaled count values to the BB with multiple out edges.
void PGOUseFunc::setBranchWeights() {
// Generate MD_prof metadata for every branch instruction.
@@ -990,8 +1081,8 @@ void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) {
Builder.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step),
{llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
- Builder.getInt64(FuncHash),
- Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step});
+ Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs),
+ Builder.getInt32(*CurCtrIdx), Step});
++(*CurCtrIdx);
}
@@ -1020,9 +1111,9 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
if (SI.getCondition()->getType()->isVectorTy())
return;
- NSIs++;
switch (Mode) {
case VM_counting:
+ NSIs++;
return;
case VM_instrument:
instrumentOneSelectInst(SI);
@@ -1035,35 +1126,79 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
llvm_unreachable("Unknown visiting mode");
}
-// Traverse all the indirect callsites and annotate the instructions.
-void PGOUseFunc::annotateIndirectCallSites() {
+void MemIntrinsicVisitor::instrumentOneMemIntrinsic(MemIntrinsic &MI) {
+ Module *M = F.getParent();
+ IRBuilder<> Builder(&MI);
+ Type *Int64Ty = Builder.getInt64Ty();
+ Type *I8PtrTy = Builder.getInt8PtrTy();
+ Value *Length = MI.getLength();
+ assert(!dyn_cast<ConstantInt>(Length));
+ Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
+ {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
+ Builder.getInt64(FuncHash), Builder.CreatePtrToInt(Length, Int64Ty),
+ Builder.getInt32(IPVK_MemOPSize), Builder.getInt32(CurCtrId)});
+ ++CurCtrId;
+}
+
+void MemIntrinsicVisitor::visitMemIntrinsic(MemIntrinsic &MI) {
+ if (!PGOInstrMemOP)
+ return;
+ Value *Length = MI.getLength();
+  // Do not instrument constant-length calls.
+ if (dyn_cast<ConstantInt>(Length))
+ return;
+
+ switch (Mode) {
+ case VM_counting:
+ NMemIs++;
+ return;
+ case VM_instrument:
+ instrumentOneMemIntrinsic(MI);
+ return;
+ case VM_annotate:
+ Candidates.push_back(&MI);
+ return;
+ }
+ llvm_unreachable("Unknown visiting mode");
+}
+
+// Traverse all value sites and annotate the instructions for all value kinds.
+void PGOUseFunc::annotateValueSites() {
if (DisableValueProfiling)
return;
// Create the PGOFuncName meta data.
createPGOFuncNameMetadata(F, FuncInfo.FuncName);
- unsigned IndirectCallSiteIndex = 0;
- auto &IndirectCallSites = FuncInfo.IndirectCallSites;
- unsigned NumValueSites =
- ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget);
- if (NumValueSites != IndirectCallSites.size()) {
- std::string Msg =
- std::string("Inconsistent number of indirect call sites: ") +
- F.getName().str();
+ for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
+ annotateValueSites(Kind);
+}
+
+// Annotate the instructions for a specific value kind.
+void PGOUseFunc::annotateValueSites(uint32_t Kind) {
+ unsigned ValueSiteIndex = 0;
+ auto &ValueSites = FuncInfo.ValueSites[Kind];
+ unsigned NumValueSites = ProfileRecord.getNumValueSites(Kind);
+ if (NumValueSites != ValueSites.size()) {
auto &Ctx = M->getContext();
- Ctx.diagnose(
- DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
+ Ctx.diagnose(DiagnosticInfoPGOProfile(
+ M->getName().data(),
+ Twine("Inconsistent number of value sites for kind = ") + Twine(Kind) +
+ " in " + F.getName().str(),
+ DS_Warning));
return;
}
- for (auto &I : IndirectCallSites) {
- DEBUG(dbgs() << "Read one indirect call instrumentation: Index="
- << IndirectCallSiteIndex << " out of " << NumValueSites
- << "\n");
- annotateValueSite(*M, *I, ProfileRecord, IPVK_IndirectCallTarget,
- IndirectCallSiteIndex, MaxNumAnnotations);
- IndirectCallSiteIndex++;
+ for (auto &I : ValueSites) {
+ DEBUG(dbgs() << "Read one value site profile (kind = " << Kind
+ << "): Index = " << ValueSiteIndex << " out of "
+ << NumValueSites << "\n");
+ annotateValueSite(*M, *I, ProfileRecord,
+ static_cast<InstrProfValueKind>(Kind), ValueSiteIndex,
+ Kind == IPVK_MemOPSize ? MaxNumMemOPAnnotations
+ : MaxNumAnnotations);
+ ValueSiteIndex++;
}
}
} // end anonymous namespace
@@ -1196,12 +1331,29 @@ static bool annotateAllFunctions(
continue;
Func.populateCounters();
Func.setBranchWeights();
- Func.annotateIndirectCallSites();
+ Func.annotateValueSites();
PGOUseFunc::FuncFreqAttr FreqAttr = Func.getFuncFreqAttr();
if (FreqAttr == PGOUseFunc::FFA_Cold)
ColdFunctions.push_back(&F);
else if (FreqAttr == PGOUseFunc::FFA_Hot)
HotFunctions.push_back(&F);
+ if (PGOViewCounts && (ViewBlockFreqFuncName.empty() ||
+ F.getName().equals(ViewBlockFreqFuncName))) {
+ LoopInfo LI{DominatorTree(F)};
+ std::unique_ptr<BranchProbabilityInfo> NewBPI =
+ llvm::make_unique<BranchProbabilityInfo>(F, LI);
+ std::unique_ptr<BlockFrequencyInfo> NewBFI =
+ llvm::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI);
+
+ NewBFI->view();
+ }
+ if (PGOViewRawCounts && (ViewBlockFreqFuncName.empty() ||
+ F.getName().equals(ViewBlockFreqFuncName))) {
+ if (ViewBlockFreqFuncName.empty())
+ WriteGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName());
+ else
+ ViewGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName());
+ }
}
M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext()));
// Set function hotness attribute from the profile.
@@ -1257,3 +1409,90 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) {
return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI);
}
+
+namespace llvm {
+void setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
+ uint64_t MaxCount) {
+ MDBuilder MDB(M->getContext());
+ assert(MaxCount > 0 && "Bad max count");
+ uint64_t Scale = calculateCountScale(MaxCount);
+ SmallVector<unsigned, 4> Weights;
+ for (const auto &ECI : EdgeCounts)
+ Weights.push_back(scaleBranchCount(ECI, Scale));
+
+ DEBUG(dbgs() << "Weight is: ";
+ for (const auto &W : Weights) { dbgs() << W << " "; }
+ dbgs() << "\n";);
+ TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+}
+
+template <> struct GraphTraits<PGOUseFunc *> {
+ typedef const BasicBlock *NodeRef;
+ typedef succ_const_iterator ChildIteratorType;
+ typedef pointer_iterator<Function::const_iterator> nodes_iterator;
+
+ static NodeRef getEntryNode(const PGOUseFunc *G) {
+ return &G->getFunc().front();
+ }
+ static ChildIteratorType child_begin(const NodeRef N) {
+ return succ_begin(N);
+ }
+ static ChildIteratorType child_end(const NodeRef N) { return succ_end(N); }
+ static nodes_iterator nodes_begin(const PGOUseFunc *G) {
+ return nodes_iterator(G->getFunc().begin());
+ }
+ static nodes_iterator nodes_end(const PGOUseFunc *G) {
+ return nodes_iterator(G->getFunc().end());
+ }
+};
+
+static std::string getSimpleNodeName(const BasicBlock *Node) {
+ if (!Node->getName().empty())
+ return Node->getName();
+
+ std::string SimpleNodeName;
+ raw_string_ostream OS(SimpleNodeName);
+ Node->printAsOperand(OS, false);
+ return OS.str();
+}
+
+template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
+ explicit DOTGraphTraits(bool isSimple = false)
+ : DefaultDOTGraphTraits(isSimple) {}
+
+ static std::string getGraphName(const PGOUseFunc *G) {
+ return G->getFunc().getName();
+ }
+
+ std::string getNodeLabel(const BasicBlock *Node, const PGOUseFunc *Graph) {
+ std::string Result;
+ raw_string_ostream OS(Result);
+
+ OS << getSimpleNodeName(Node) << ":\\l";
+ UseBBInfo *BI = Graph->findBBInfo(Node);
+ OS << "Count : ";
+ if (BI && BI->CountValid)
+ OS << BI->CountValue << "\\l";
+ else
+ OS << "Unknown\\l";
+
+ if (!PGOInstrSelect)
+ return Result;
+
+ for (auto BI = Node->begin(); BI != Node->end(); ++BI) {
+ auto *I = &*BI;
+ if (!isa<SelectInst>(I))
+ continue;
+ // Display scaled counts for SELECT instruction:
+ OS << "SELECT : { T = ";
+ uint64_t TC, FC;
+ bool HasProf = I->extractProfMetadata(TC, FC);
+ if (!HasProf)
+ OS << "Unknown, F = Unknown }\\l";
+ else
+ OS << TC << ", F = " << FC << " }\\l";
+ }
+ return Result;
+ }
+};
+} // namespace llvm
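The exported setProfMetadata above compresses 64-bit edge counts into the 32-bit weights that MD_prof branch_weights metadata can carry. A minimal standalone sketch of that scaling step, assuming the simple divide-by-scale scheme that calculateCountScale/scaleBranchCount suggest (the exact rounding in the tree may differ):

#include <cstdint>
#include <limits>
#include <vector>

// Scale raw 64-bit profile counts so the largest one fits in a uint32_t.
static std::vector<uint32_t> scaleCounts(const std::vector<uint64_t> &Counts,
                                         uint64_t MaxCount) {
  const uint64_t Max32 = std::numeric_limits<uint32_t>::max();
  uint64_t Scale = MaxCount > Max32 ? MaxCount / Max32 + 1 : 1;
  std::vector<uint32_t> Weights;
  Weights.reserve(Counts.size());
  for (uint64_t C : Counts)
    Weights.push_back(static_cast<uint32_t>(C / Scale)); // relative weights only
  return Weights;
}

Only the relative ratios matter to branch-weight consumers, so losing low bits in the division is acceptable.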
diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
index 5b4b1fb771347..fa0c7cc5a4c53 100644
--- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
+++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
@@ -78,7 +78,6 @@ static const char *const SanCovTraceSwitchName = "__sanitizer_cov_trace_switch";
static const char *const SanCovModuleCtorName = "sancov.module_ctor";
static const uint64_t SanCtorAndDtorPriority = 2;
-static const char *const SanCovTracePCGuardSection = "__sancov_guards";
static const char *const SanCovTracePCGuardName =
"__sanitizer_cov_trace_pc_guard";
static const char *const SanCovTracePCGuardInitName =
@@ -95,7 +94,7 @@ static cl::opt<unsigned> ClCoverageBlockThreshold(
"sanitizer-coverage-block-threshold",
cl::desc("Use a callback with a guard check inside it if there are"
" more than this number of blocks."),
- cl::Hidden, cl::init(500));
+ cl::Hidden, cl::init(0));
static cl::opt<bool>
ClExperimentalTracing("sanitizer-coverage-experimental-tracing",
@@ -216,6 +215,9 @@ private:
SanCovWithCheckFunction->getNumUses() + SanCovTraceBB->getNumUses() +
SanCovTraceEnter->getNumUses();
}
+ StringRef getSanCovTracePCGuardSection() const;
+ StringRef getSanCovTracePCGuardSectionStart() const;
+ StringRef getSanCovTracePCGuardSectionEnd() const;
Function *SanCovFunction;
Function *SanCovWithCheckFunction;
Function *SanCovIndirCallFunction, *SanCovTracePCIndir;
@@ -227,6 +229,7 @@ private:
InlineAsm *EmptyAsm;
Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy;
Module *CurModule;
+ Triple TargetTriple;
LLVMContext *C;
const DataLayout *DL;
@@ -246,6 +249,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {
C = &(M.getContext());
DL = &M.getDataLayout();
CurModule = &M;
+ TargetTriple = Triple(M.getTargetTriple());
HasSancovGuardsSection = false;
IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits());
IntptrPtrTy = PointerType::getUnqual(IntptrTy);
@@ -258,39 +262,39 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {
Int32Ty = IRB.getInt32Ty();
SanCovFunction = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(SanCovName, VoidTy, Int32PtrTy, nullptr));
+ M.getOrInsertFunction(SanCovName, VoidTy, Int32PtrTy));
SanCovWithCheckFunction = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(SanCovWithCheckName, VoidTy, Int32PtrTy, nullptr));
+ M.getOrInsertFunction(SanCovWithCheckName, VoidTy, Int32PtrTy));
SanCovTracePCIndir = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy, nullptr));
+ M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy));
SanCovIndirCallFunction =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovIndirCallName, VoidTy, IntptrTy, IntptrTy, nullptr));
+ SanCovIndirCallName, VoidTy, IntptrTy, IntptrTy));
SanCovTraceCmpFunction[0] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty(), nullptr));
+ SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty()));
SanCovTraceCmpFunction[1] = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(SanCovTraceCmp2, VoidTy, IRB.getInt16Ty(),
- IRB.getInt16Ty(), nullptr));
+ IRB.getInt16Ty()));
SanCovTraceCmpFunction[2] = checkSanitizerInterfaceFunction(
M.getOrInsertFunction(SanCovTraceCmp4, VoidTy, IRB.getInt32Ty(),
- IRB.getInt32Ty(), nullptr));
+ IRB.getInt32Ty()));
SanCovTraceCmpFunction[3] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty, nullptr));
+ SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty));
SanCovTraceDivFunction[0] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovTraceDiv4, VoidTy, IRB.getInt32Ty(), nullptr));
+ SanCovTraceDiv4, VoidTy, IRB.getInt32Ty()));
SanCovTraceDivFunction[1] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovTraceDiv8, VoidTy, Int64Ty, nullptr));
+ SanCovTraceDiv8, VoidTy, Int64Ty));
SanCovTraceGepFunction =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovTraceGep, VoidTy, IntptrTy, nullptr));
+ SanCovTraceGep, VoidTy, IntptrTy));
SanCovTraceSwitchFunction =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy, nullptr));
+ SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy));
// We insert an empty inline asm after cov callbacks to avoid callback merge.
EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false),
@@ -298,13 +302,13 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {
/*hasSideEffects=*/true);
SanCovTracePC = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(SanCovTracePCName, VoidTy, nullptr));
+ M.getOrInsertFunction(SanCovTracePCName, VoidTy));
SanCovTracePCGuard = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- SanCovTracePCGuardName, VoidTy, Int32PtrTy, nullptr));
+ SanCovTracePCGuardName, VoidTy, Int32PtrTy));
SanCovTraceEnter = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(SanCovTraceEnterName, VoidTy, Int32PtrTy, nullptr));
+ M.getOrInsertFunction(SanCovTraceEnterName, VoidTy, Int32PtrTy));
SanCovTraceBB = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(SanCovTraceBBName, VoidTy, Int32PtrTy, nullptr));
+ M.getOrInsertFunction(SanCovTraceBBName, VoidTy, Int32PtrTy));
// At this point we create a dummy array of guards because we don't
// know how many elements we will need.
@@ -363,22 +367,28 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {
if (Options.TracePCGuard) {
if (HasSancovGuardsSection) {
Function *CtorFunc;
- std::string SectionName(SanCovTracePCGuardSection);
- GlobalVariable *Bounds[2];
- const char *Prefix[2] = {"__start_", "__stop_"};
- for (int i = 0; i < 2; i++) {
- Bounds[i] = new GlobalVariable(M, Int32PtrTy, false,
- GlobalVariable::ExternalLinkage, nullptr,
- Prefix[i] + SectionName);
- Bounds[i]->setVisibility(GlobalValue::HiddenVisibility);
- }
+ GlobalVariable *SecStart = new GlobalVariable(
+ M, Int32PtrTy, false, GlobalVariable::ExternalLinkage, nullptr,
+ getSanCovTracePCGuardSectionStart());
+ SecStart->setVisibility(GlobalValue::HiddenVisibility);
+ GlobalVariable *SecEnd = new GlobalVariable(
+ M, Int32PtrTy, false, GlobalVariable::ExternalLinkage, nullptr,
+ getSanCovTracePCGuardSectionEnd());
+ SecEnd->setVisibility(GlobalValue::HiddenVisibility);
+
std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions(
M, SanCovModuleCtorName, SanCovTracePCGuardInitName,
{Int32PtrTy, Int32PtrTy},
- {IRB.CreatePointerCast(Bounds[0], Int32PtrTy),
- IRB.CreatePointerCast(Bounds[1], Int32PtrTy)});
-
- appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority);
+ {IRB.CreatePointerCast(SecStart, Int32PtrTy),
+ IRB.CreatePointerCast(SecEnd, Int32PtrTy)});
+
+ if (TargetTriple.supportsCOMDAT()) {
+ // Use comdat to dedup CtorFunc.
+ CtorFunc->setComdat(M.getOrInsertComdat(SanCovModuleCtorName));
+ appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority, CtorFunc);
+ } else {
+ appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority);
+ }
}
} else if (!Options.TracePC) {
Function *CtorFunc;
@@ -435,6 +445,11 @@ static bool shouldInstrumentBlock(const Function& F, const BasicBlock *BB, const
if (isa<UnreachableInst>(BB->getTerminator()))
return false;
+ // Don't insert coverage into blocks without a valid insertion point
+ // (catchswitch blocks).
+ if (BB->getFirstInsertionPt() == BB->end())
+ return false;
+
if (!ClPruneBlocks || &F.getEntryBlock() == BB)
return true;
@@ -517,7 +532,7 @@ void SanitizerCoverageModule::CreateFunctionGuardArray(size_t NumGuards,
Constant::getNullValue(ArrayOfInt32Ty), "__sancov_gen_");
if (auto Comdat = F.getComdat())
FunctionGuardArray->setComdat(Comdat);
- FunctionGuardArray->setSection(SanCovTracePCGuardSection);
+ FunctionGuardArray->setSection(getSanCovTracePCGuardSection());
}
bool SanitizerCoverageModule::InjectCoverage(Function &F,
@@ -755,6 +770,27 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
}
}
+StringRef SanitizerCoverageModule::getSanCovTracePCGuardSection() const {
+ if (TargetTriple.getObjectFormat() == Triple::COFF)
+ return ".SCOV$M";
+ if (TargetTriple.isOSBinFormatMachO())
+ return "__DATA,__sancov_guards";
+ return "__sancov_guards";
+}
+
+StringRef SanitizerCoverageModule::getSanCovTracePCGuardSectionStart() const {
+ if (TargetTriple.isOSBinFormatMachO())
+ return "\1section$start$__DATA$__sancov_guards";
+ return "__start___sancov_guards";
+}
+
+StringRef SanitizerCoverageModule::getSanCovTracePCGuardSectionEnd() const {
+ if (TargetTriple.isOSBinFormatMachO())
+ return "\1section$end$__DATA$__sancov_guards";
+ return "__stop___sancov_guards";
+}
+
+
char SanitizerCoverageModule::ID = 0;
INITIALIZE_PASS_BEGIN(SanitizerCoverageModule, "sancov",
"SanitizerCoverage: TODO."
diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
index 52035c79a4a38..9260217bd5e62 100644
--- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
@@ -155,17 +155,18 @@ FunctionPass *llvm::createThreadSanitizerPass() {
void ThreadSanitizer::initializeCallbacks(Module &M) {
IRBuilder<> IRB(M.getContext());
- AttributeSet Attr;
- Attr = Attr.addAttribute(M.getContext(), AttributeSet::FunctionIndex, Attribute::NoUnwind);
+ AttributeList Attr;
+ Attr = Attr.addAttribute(M.getContext(), AttributeList::FunctionIndex,
+ Attribute::NoUnwind);
// Initialize the callbacks.
TsanFuncEntry = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));
TsanFuncExit = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy(), nullptr));
+ M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy()));
TsanIgnoreBegin = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy(), nullptr));
+ "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy()));
TsanIgnoreEnd = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- "__tsan_ignore_thread_end", Attr, IRB.getVoidTy(), nullptr));
+ "__tsan_ignore_thread_end", Attr, IRB.getVoidTy()));
OrdTy = IRB.getInt32Ty();
for (size_t i = 0; i < kNumberOfAccessSizes; ++i) {
const unsigned ByteSize = 1U << i;
@@ -174,31 +175,31 @@ void ThreadSanitizer::initializeCallbacks(Module &M) {
std::string BitSizeStr = utostr(BitSize);
SmallString<32> ReadName("__tsan_read" + ByteSizeStr);
TsanRead[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));
SmallString<32> WriteName("__tsan_write" + ByteSizeStr);
TsanWrite[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));
SmallString<64> UnalignedReadName("__tsan_unaligned_read" + ByteSizeStr);
TsanUnalignedRead[i] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));
SmallString<64> UnalignedWriteName("__tsan_unaligned_write" + ByteSizeStr);
TsanUnalignedWrite[i] =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));
Type *Ty = Type::getIntNTy(M.getContext(), BitSize);
Type *PtrTy = Ty->getPointerTo();
SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load");
TsanAtomicLoad[i] = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy, nullptr));
+ M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy));
SmallString<32> AtomicStoreName("__tsan_atomic" + BitSizeStr + "_store");
TsanAtomicStore[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy, nullptr));
+ AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy));
for (int op = AtomicRMWInst::FIRST_BINOP;
op <= AtomicRMWInst::LAST_BINOP; ++op) {
@@ -222,33 +223,33 @@ void ThreadSanitizer::initializeCallbacks(Module &M) {
continue;
SmallString<32> RMWName("__tsan_atomic" + itostr(BitSize) + NamePart);
TsanAtomicRMW[op][i] = checkSanitizerInterfaceFunction(
- M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy, nullptr));
+ M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy));
}
SmallString<32> AtomicCASName("__tsan_atomic" + BitSizeStr +
"_compare_exchange_val");
TsanAtomicCAS[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, nullptr));
+ AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy));
}
TsanVptrUpdate = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), nullptr));
+ IRB.getInt8PtrTy(), IRB.getInt8PtrTy()));
TsanVptrLoad = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr));
+ "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()));
TsanAtomicThreadFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr));
+ "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy));
TsanAtomicSignalFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr));
+ "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy));
MemmoveFn = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("memmove", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IntptrTy));
MemcpyFn = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("memcpy", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy, nullptr));
+ IRB.getInt8PtrTy(), IntptrTy));
MemsetFn = checkSanitizerInterfaceFunction(
M.getOrInsertFunction("memset", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt32Ty(), IntptrTy, nullptr));
+ IRB.getInt32Ty(), IntptrTy));
}
bool ThreadSanitizer::doInitialization(Module &M) {
@@ -271,7 +272,7 @@ static bool isVtableAccess(Instruction *I) {
// Do not instrument known races/"benign races" that come from compiler
// instrumentation. The user has no way of suppressing them.
-static bool shouldInstrumentReadWriteFromAddress(Value *Addr) {
+static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
// Peel off GEPs and BitCasts.
Addr = Addr->stripInBoundsOffsets();
@@ -279,8 +280,9 @@ static bool shouldInstrumentReadWriteFromAddress(Value *Addr) {
if (GV->hasSection()) {
StringRef SectionName = GV->getSection();
// Check if the global is in the PGO counters section.
- if (SectionName.endswith(getInstrProfCountersSectionName(
- /*AddSegment=*/false)))
+ auto OF = Triple(M->getTargetTriple()).getObjectFormat();
+ if (SectionName.endswith(
+ getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
return false;
}
@@ -342,13 +344,13 @@ void ThreadSanitizer::chooseInstructionsToInstrument(
for (Instruction *I : reverse(Local)) {
if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
Value *Addr = Store->getPointerOperand();
- if (!shouldInstrumentReadWriteFromAddress(Addr))
+ if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
continue;
WriteTargets.insert(Addr);
} else {
LoadInst *Load = cast<LoadInst>(I);
Value *Addr = Load->getPointerOperand();
- if (!shouldInstrumentReadWriteFromAddress(Addr))
+ if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
continue;
if (WriteTargets.count(Addr)) {
// We will write to this temp, so no reason to analyze the read.
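Most of the churn in the SanitizerCoverage and ThreadSanitizer hunks is the same mechanical update: Module::getOrInsertFunction no longer takes a trailing nullptr sentinel to close its argument-type list (the AttributeSet-to-AttributeList rename is equally mechanical). A small sketch of the before/after shape, assuming only a Module and an AttributeList in scope; the declared function name is taken from the hunk above purely as an example:

#include "llvm/IR/Attributes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static void declareEntryHook(Module &M, AttributeList Attr) {
  IRBuilder<> IRB(M.getContext());
  // Old form: varargs-style call closed by a nullptr sentinel, e.g.
  //   M.getOrInsertFunction("__tsan_func_entry", Attr,
  //                         IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr);
  // New form: the sentinel is gone, so a stray extra argument becomes a
  // compile-time error instead of a silently malformed declaration.
  M.getOrInsertFunction("__tsan_func_entry", Attr, IRB.getVoidTy(),
                        IRB.getInt8PtrTy());
}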
diff --git a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h
index c74827210364e..c541fa4c8bee7 100644
--- a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h
+++ b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h
@@ -127,9 +127,8 @@ private:
LLVMContext &C = TheModule->getContext();
Type *Params[] = { PointerType::getUnqual(Type::getInt8Ty(C)) };
- AttributeSet Attr =
- AttributeSet().addAttribute(C, AttributeSet::FunctionIndex,
- Attribute::NoUnwind);
+ AttributeList Attr = AttributeList().addAttribute(
+ C, AttributeList::FunctionIndex, Attribute::NoUnwind);
FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params,
/*isVarArg=*/false);
return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr);
@@ -144,10 +143,10 @@ private:
Type *I8X = PointerType::getUnqual(Type::getInt8Ty(C));
Type *Params[] = { I8X };
FunctionType *Fty = FunctionType::get(I8X, Params, /*isVarArg=*/false);
- AttributeSet Attr = AttributeSet();
+ AttributeList Attr = AttributeList();
if (NoUnwind)
- Attr = Attr.addAttribute(C, AttributeSet::FunctionIndex,
+ Attr = Attr.addAttribute(C, AttributeList::FunctionIndex,
Attribute::NoUnwind);
return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr);
@@ -162,9 +161,8 @@ private:
Type *I8XX = PointerType::getUnqual(I8X);
Type *Params[] = { I8XX, I8X };
- AttributeSet Attr =
- AttributeSet().addAttribute(C, AttributeSet::FunctionIndex,
- Attribute::NoUnwind);
+ AttributeList Attr = AttributeList().addAttribute(
+ C, AttributeList::FunctionIndex, Attribute::NoUnwind);
Attr = Attr.addAttribute(C, 1, Attribute::NoCapture);
FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params,
diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp
index 23c1f5990ba52..a86eaaec76412 100644
--- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp
+++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp
@@ -394,6 +394,7 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong(Instruction *Release,
DEBUG(llvm::dbgs() << " New Store Strong: " << *StoreStrong << "\n");
+ if (&*Iter == Retain) ++Iter;
if (&*Iter == Store) ++Iter;
Store->eraseFromParent();
Release->eraseFromParent();
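The one-line ObjCARCContract fix guards the caller's instruction iterator before the about-to-be-erased Retain is deleted, mirroring the existing guard for Store. The same pattern in isolation (the helper name is made up for illustration):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Erase Dead without leaving It dangling: if the iterator currently points
// at Dead, step past it first.
static void eraseKeepingIterator(BasicBlock::iterator &It, Instruction *Dead) {
  if (&*It == Dead)
    ++It;
  Dead->eraseFromParent();
}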
diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
index 136d54a6cb759..3c73376c99068 100644
--- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
+++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
@@ -85,41 +85,6 @@ static const Value *FindSingleUseIdentifiedObject(const Value *Arg) {
return nullptr;
}
-/// This is a wrapper around getUnderlyingObjCPtr along the lines of
-/// GetUnderlyingObjects except that it returns early when it sees the first
-/// alloca.
-static inline bool AreAnyUnderlyingObjectsAnAlloca(const Value *V,
- const DataLayout &DL) {
- SmallPtrSet<const Value *, 4> Visited;
- SmallVector<const Value *, 4> Worklist;
- Worklist.push_back(V);
- do {
- const Value *P = Worklist.pop_back_val();
- P = GetUnderlyingObjCPtr(P, DL);
-
- if (isa<AllocaInst>(P))
- return true;
-
- if (!Visited.insert(P).second)
- continue;
-
- if (const SelectInst *SI = dyn_cast<const SelectInst>(P)) {
- Worklist.push_back(SI->getTrueValue());
- Worklist.push_back(SI->getFalseValue());
- continue;
- }
-
- if (const PHINode *PN = dyn_cast<const PHINode>(P)) {
- for (Value *IncValue : PN->incoming_values())
- Worklist.push_back(IncValue);
- continue;
- }
- } while (!Worklist.empty());
-
- return false;
-}
-
-
/// @}
///
/// \defgroup ARCOpt ARC Optimization.
@@ -481,9 +446,6 @@ namespace {
/// MDKind identifiers.
ARCMDKindCache MDKindCache;
- // This is used to track if a pointer is stored into an alloca.
- DenseSet<const Value *> MultiOwnersSet;
-
/// A flag indicating whether this optimization pass should run.
bool Run;
@@ -524,8 +486,7 @@ namespace {
PairUpRetainsAndReleases(DenseMap<const BasicBlock *, BBState> &BBStates,
BlotMapVector<Value *, RRInfo> &Retains,
DenseMap<Value *, RRInfo> &Releases, Module *M,
- SmallVectorImpl<Instruction *> &NewRetains,
- SmallVectorImpl<Instruction *> &NewReleases,
+ Instruction *Retain,
SmallVectorImpl<Instruction *> &DeadInsts,
RRInfo &RetainsToMove, RRInfo &ReleasesToMove,
Value *Arg, bool KnownSafe,
@@ -1155,29 +1116,6 @@ bool ObjCARCOpt::VisitInstructionBottomUp(
case ARCInstKind::None:
// These are irrelevant.
return NestingDetected;
- case ARCInstKind::User:
- // If we have a store into an alloca of a pointer we are tracking, the
- // pointer has multiple owners implying that we must be more conservative.
- //
- // This comes up in the context of a pointer being ``KnownSafe''. In the
- // presence of a block being initialized, the frontend will emit the
- // objc_retain on the original pointer and the release on the pointer loaded
- // from the alloca. The optimizer will through the provenance analysis
- // realize that the two are related, but since we only require KnownSafe in
- // one direction, will match the inner retain on the original pointer with
- // the guard release on the original pointer. This is fixed by ensuring that
- // in the presence of allocas we only unconditionally remove pointers if
- // both our retain and our release are KnownSafe.
- if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
- const DataLayout &DL = BB->getModule()->getDataLayout();
- if (AreAnyUnderlyingObjectsAnAlloca(SI->getPointerOperand(), DL)) {
- auto I = MyStates.findPtrBottomUpState(
- GetRCIdentityRoot(SI->getValueOperand()));
- if (I != MyStates.bottom_up_ptr_end())
- MultiOwnersSet.insert(I->first);
- }
- }
- break;
default:
break;
}
@@ -1540,8 +1478,7 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(
DenseMap<const BasicBlock *, BBState> &BBStates,
BlotMapVector<Value *, RRInfo> &Retains,
DenseMap<Value *, RRInfo> &Releases, Module *M,
- SmallVectorImpl<Instruction *> &NewRetains,
- SmallVectorImpl<Instruction *> &NewReleases,
+ Instruction *Retain,
SmallVectorImpl<Instruction *> &DeadInsts, RRInfo &RetainsToMove,
RRInfo &ReleasesToMove, Value *Arg, bool KnownSafe,
bool &AnyPairsCompletelyEliminated) {
@@ -1549,7 +1486,6 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(
// is already incremented, we can similarly ignore possible decrements unless
// we are dealing with a retainable object with multiple provenance sources.
bool KnownSafeTD = true, KnownSafeBU = true;
- bool MultipleOwners = false;
bool CFGHazardAfflicted = false;
// Connect the dots between the top-down-collected RetainsToMove and
@@ -1561,14 +1497,13 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(
unsigned OldCount = 0;
unsigned NewCount = 0;
bool FirstRelease = true;
- for (;;) {
+ for (SmallVector<Instruction *, 4> NewRetains{Retain};;) {
+ SmallVector<Instruction *, 4> NewReleases;
for (Instruction *NewRetain : NewRetains) {
auto It = Retains.find(NewRetain);
assert(It != Retains.end());
const RRInfo &NewRetainRRI = It->second;
KnownSafeTD &= NewRetainRRI.KnownSafe;
- MultipleOwners =
- MultipleOwners || MultiOwnersSet.count(GetArgRCIdentityRoot(NewRetain));
for (Instruction *NewRetainRelease : NewRetainRRI.Calls) {
auto Jt = Releases.find(NewRetainRelease);
if (Jt == Releases.end())
@@ -1691,7 +1626,6 @@ bool ObjCARCOpt::PairUpRetainsAndReleases(
}
}
}
- NewReleases.clear();
if (NewRetains.empty()) break;
}
@@ -1745,10 +1679,6 @@ bool ObjCARCOpt::PerformCodePlacement(
DEBUG(dbgs() << "\n== ObjCARCOpt::PerformCodePlacement ==\n");
bool AnyPairsCompletelyEliminated = false;
- RRInfo RetainsToMove;
- RRInfo ReleasesToMove;
- SmallVector<Instruction *, 4> NewRetains;
- SmallVector<Instruction *, 4> NewReleases;
SmallVector<Instruction *, 8> DeadInsts;
// Visit each retain.
@@ -1780,9 +1710,10 @@ bool ObjCARCOpt::PerformCodePlacement(
// Connect the dots between the top-down-collected RetainsToMove and
// bottom-up-collected ReleasesToMove to form sets of related calls.
- NewRetains.push_back(Retain);
+ RRInfo RetainsToMove, ReleasesToMove;
+
bool PerformMoveCalls = PairUpRetainsAndReleases(
- BBStates, Retains, Releases, M, NewRetains, NewReleases, DeadInsts,
+ BBStates, Retains, Releases, M, Retain, DeadInsts,
RetainsToMove, ReleasesToMove, Arg, KnownSafe,
AnyPairsCompletelyEliminated);
@@ -1792,12 +1723,6 @@ bool ObjCARCOpt::PerformCodePlacement(
MoveCalls(Arg, RetainsToMove, ReleasesToMove,
Retains, Releases, DeadInsts, M);
}
-
- // Clean up state for next retain.
- NewReleases.clear();
- NewRetains.clear();
- RetainsToMove.clear();
- ReleasesToMove.clear();
}
// Now that we're done moving everything, we can delete the newly dead
@@ -1987,9 +1912,6 @@ bool ObjCARCOpt::OptimizeSequences(Function &F) {
Releases,
F.getParent());
- // Cleanup.
- MultiOwnersSet.clear();
-
return AnyPairsCompletelyEliminated && NestingDetected;
}
diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp
index adc903cab31b7..5b467dc9fe125 100644
--- a/lib/Transforms/Scalar/ADCE.cpp
+++ b/lib/Transforms/Scalar/ADCE.cpp
@@ -41,8 +41,8 @@ using namespace llvm;
STATISTIC(NumRemoved, "Number of instructions removed");
STATISTIC(NumBranchesRemoved, "Number of branch instructions removed");
-// This is a tempoary option until we change the interface
-// to this pass based on optimization level.
+// This is a temporary option until we change the interface to this pass based
+// on optimization level.
static cl::opt<bool> RemoveControlFlowFlag("adce-remove-control-flow",
cl::init(true), cl::Hidden);
@@ -110,7 +110,7 @@ class AggressiveDeadCodeElimination {
/// The set of blocks which we have determined whose control
/// dependence sources must be live and which have not had
- /// those dependences analyized.
+ /// those dependences analyzed.
SmallPtrSet<BasicBlock *, 16> NewLiveBlocks;
/// Set up auxiliary data structures for Instructions and BasicBlocks and
@@ -145,7 +145,7 @@ class AggressiveDeadCodeElimination {
/// was removed.
bool removeDeadInstructions();
- /// Identify connected sections of the control flow grap which have
+ /// Identify connected sections of the control flow graph which have
/// dead terminators and rewrite the control flow graph to remove them.
void updateDeadRegions();
@@ -234,7 +234,7 @@ void AggressiveDeadCodeElimination::initialize() {
return Iter != end() && Iter->second;
}
} State;
-
+
State.reserve(F.size());
// Iterate over blocks in depth-first pre-order and
// treat all edges to a block already seen as loop back edges
@@ -262,25 +262,6 @@ void AggressiveDeadCodeElimination::initialize() {
continue;
auto *BB = BBInfo.BB;
if (!PDT.getNode(BB)) {
- markLive(BBInfo.Terminator);
- continue;
- }
- for (auto *Succ : successors(BB))
- if (!PDT.getNode(Succ)) {
- markLive(BBInfo.Terminator);
- break;
- }
- }
-
- // Mark blocks live if there is no path from the block to the
- // return of the function or a successor for which this is true.
- // This protects IDFCalculator which cannot handle such blocks.
- for (auto &BBInfoPair : BlockInfo) {
- auto &BBInfo = BBInfoPair.second;
- if (BBInfo.terminatorIsLive())
- continue;
- auto *BB = BBInfo.BB;
- if (!PDT.getNode(BB)) {
DEBUG(dbgs() << "Not post-dominated by return: " << BB->getName()
<< '\n';);
markLive(BBInfo.Terminator);
@@ -579,7 +560,7 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {
PreferredSucc = Info;
}
assert((PreferredSucc && PreferredSucc->PostOrder > 0) &&
- "Failed to find safe successor for dead branc");
+ "Failed to find safe successor for dead branch");
bool First = true;
for (auto *Succ : successors(BB)) {
if (!First || Succ != PreferredSucc->BB)
@@ -594,13 +575,13 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {
// reverse top-sort order
void AggressiveDeadCodeElimination::computeReversePostOrder() {
-
- // This provides a post-order numbering of the reverse conrtol flow graph
+
+ // This provides a post-order numbering of the reverse control flow graph
// Note that it is incomplete in the presence of infinite loops but we don't
// need to number blocks which don't reach the end of the function since
// all branches in those blocks are forced live.
-
- // For each block without successors, extend the DFS from the bloack
+
+ // For each block without successors, extend the DFS from the block
// backward through the graph
SmallPtrSet<BasicBlock*, 16> Visited;
unsigned PostOrder = 0;
@@ -644,8 +625,8 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) {
if (!AggressiveDeadCodeElimination(F, PDT).performDeadCodeElimination())
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
- auto PA = PreservedAnalyses();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
return PA;
}
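ADCE, and several passes later in this patch (BDCE, ConstantHoisting, DCE, DSE, EarlyCSE, Float2Int), switch from hand-listing preserved analyses to PA.preserveSet<CFGAnalyses>(), which marks every analysis that only depends on the shape of the CFG as preserved. A minimal new-pass-manager sketch of the idiom (the pass itself is a placeholder):

#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
using namespace llvm;

struct NoCFGChangePass : PassInfoMixin<NoCFGChangePass> {
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &) {
    bool Changed = false; // transform instructions here, never touch the CFG
    if (!Changed)
      return PreservedAnalyses::all();
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>(); // dominators, loop info, etc. stay valid
    return PA;
  }
};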
diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
index c1df3173c0fc5..fd931c521c8f1 100644
--- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -438,19 +438,13 @@ AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
- bool Changed = runImpl(F, AC, &SE, &DT);
-
- // FIXME: We need to invalidate this to avoid PR28400. Is there a better
- // solution?
- AM.invalidate<ScalarEvolutionAnalysis>(F);
-
- if (!Changed)
+ if (!runImpl(F, AC, &SE, &DT))
return PreservedAnalyses::all();
+
PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<AAManager>();
PA.preserve<ScalarEvolutionAnalysis>();
PA.preserve<GlobalsAA>();
- PA.preserve<LoopAnalysis>();
- PA.preserve<DominatorTreeAnalysis>();
return PA;
}
diff --git a/lib/Transforms/Scalar/BDCE.cpp b/lib/Transforms/Scalar/BDCE.cpp
index 251b387077697..61e8700f1cd67 100644
--- a/lib/Transforms/Scalar/BDCE.cpp
+++ b/lib/Transforms/Scalar/BDCE.cpp
@@ -80,8 +80,8 @@ PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) {
if (!bitTrackingDCE(F, DB))
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
- auto PA = PreservedAnalyses();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
return PA;
}
diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt
index 06d3d6a73954e..b323ab3bd443c 100644
--- a/lib/Transforms/Scalar/CMakeLists.txt
+++ b/lib/Transforms/Scalar/CMakeLists.txt
@@ -16,6 +16,7 @@ add_llvm_library(LLVMScalarOpts
IVUsersPrinter.cpp
InductiveRangeCheckElimination.cpp
IndVarSimplify.cpp
+ InferAddressSpaces.cpp
JumpThreading.cpp
LICM.cpp
LoopAccessAnalysisPrinter.cpp
@@ -29,6 +30,7 @@ add_llvm_library(LLVMScalarOpts
LoopInterchange.cpp
LoopLoadElimination.cpp
LoopPassManager.cpp
+ LoopPredication.cpp
LoopRerollPass.cpp
LoopRotation.cpp
LoopSimplifyCFG.cpp
diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp
index 38262514c9ec4..ee6333e88716b 100644
--- a/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -136,8 +136,16 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst,
if (Idx != ~0U && isa<PHINode>(Inst))
return cast<PHINode>(Inst)->getIncomingBlock(Idx)->getTerminator();
- BasicBlock *IDom = DT->getNode(Inst->getParent())->getIDom()->getBlock();
- return IDom->getTerminator();
+ // This must be an EH pad. Iterate over immediate dominators until we find a
+ // non-EH pad. We need to skip over catchswitch blocks, which are both EH pads
+ // and terminators.
+ auto IDom = DT->getNode(Inst->getParent())->getIDom();
+ while (IDom->getBlock()->isEHPad()) {
+ assert(Entry != IDom->getBlock() && "eh pad in entry block");
+ IDom = IDom->getIDom();
+ }
+
+ return IDom->getBlock()->getTerminator();
}
/// \brief Find an insertion point that dominates all uses.
@@ -289,8 +297,8 @@ void ConstantHoistingPass::collectConstantCandidates(Function &Fn) {
// bit widths (APInt Operator- does not like that). If the value cannot be
// represented in uint64 we return an "empty" APInt. This is then interpreted
// as the value is not in range.
-static llvm::Optional<APInt> calculateOffsetDiff(APInt V1, APInt V2)
-{
+static llvm::Optional<APInt> calculateOffsetDiff(const APInt &V1,
+ const APInt &V2) {
llvm::Optional<APInt> Res = None;
unsigned BW = V1.getBitWidth() > V2.getBitWidth() ?
V1.getBitWidth() : V2.getBitWidth();
@@ -623,6 +631,7 @@ PreservedAnalyses ConstantHoistingPass::run(Function &F,
if (!runImpl(F, TTI, DT, F.getEntryBlock()))
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
- return PreservedAnalyses::none();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 84f9373ae9144..c843c61ea94ed 100644
--- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -235,9 +235,8 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {
-// Analyse each switch case in turn. This is done in reverse order so that
-// removing a case doesn't cause trouble for the iteration.
+// Analyse each switch case in turn.
bool Changed = false;
- for (SwitchInst::CaseIt CI = SI->case_end(), CE = SI->case_begin(); CI-- != CE;
- ) {
- ConstantInt *Case = CI.getCaseValue();
+ for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) {
+ ConstantInt *Case = CI->getCaseValue();
// Check to see if the switch condition is equal to/not equal to the case
// value on every incoming edge, equal/not equal being the same each time.
@@ -270,8 +269,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {
if (State == LazyValueInfo::False) {
// This case never fires - remove it.
- CI.getCaseSuccessor()->removePredecessor(BB);
- SI->removeCase(CI); // Does not invalidate the iterator.
+ CI->getCaseSuccessor()->removePredecessor(BB);
+ CI = SI->removeCase(CI);
+ CE = SI->case_end();
// The condition can be modified by removePredecessor's PHI simplification
// logic.
@@ -279,7 +279,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {
++NumDeadCases;
Changed = true;
- } else if (State == LazyValueInfo::True) {
+ continue;
+ }
+ if (State == LazyValueInfo::True) {
// This case always fires. Arrange for the switch to be turned into an
// unconditional branch by replacing the switch condition with the case
// value.
@@ -288,6 +290,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) {
Changed = true;
break;
}
+
+ // Increment the case iterator since we didn't delete it.
+ ++CI;
}
if (Changed)
@@ -308,7 +313,7 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
// Try to mark pointer typed parameters as non-null. We skip the
// relatively expensive analysis for constants which are obviously either
// null or non-null to start with.
- if (Type && !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) &&
+ if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&
!isa<Constant>(V) &&
LVI->getPredicateAt(ICmpInst::ICMP_EQ, V,
ConstantPointerNull::get(Type),
@@ -322,7 +327,7 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
if (Indices.empty())
return false;
- AttributeSet AS = CS.getAttributes();
+ AttributeList AS = CS.getAttributes();
LLVMContext &Ctx = CS.getInstruction()->getContext();
AS = AS.addAttribute(Ctx, Indices, Attribute::get(Ctx, Attribute::NonNull));
CS.setAttributes(AS);
@@ -570,10 +575,6 @@ CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) {
LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);
bool Changed = runImpl(F, LVI);
- // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better
- // solution?
- AM.invalidate<LazyValueAnalysis>(F);
-
if (!Changed)
return PreservedAnalyses::all();
PreservedAnalyses PA;
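processSwitch now walks the cases forward and relies on SwitchInst::removeCase returning the iterator to the following case, the familiar erase-returns-next idiom; note that it also refreshes the cached end iterator after each removal. The same shape with a plain standard container, for illustration only:

#include <vector>

// Drop negative entries while walking forward; only advance when nothing
// was erased on this step.
void pruneNegative(std::vector<int> &Values) {
  for (auto It = Values.begin(), End = Values.end(); It != End;) {
    if (*It < 0) {
      It = Values.erase(It); // erase returns the next element
      End = Values.end();    // the old end iterator is stale after erase
      continue;
    }
    ++It;
  }
}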
diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp
index cc2a3cfaf9d19..07a0ba9b12221 100644
--- a/lib/Transforms/Scalar/DCE.cpp
+++ b/lib/Transforms/Scalar/DCE.cpp
@@ -124,9 +124,12 @@ static bool eliminateDeadCode(Function &F, TargetLibraryInfo *TLI) {
}
PreservedAnalyses DCEPass::run(Function &F, FunctionAnalysisManager &AM) {
- if (eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F)))
- return PreservedAnalyses::none();
- return PreservedAnalyses::all();
+ if (!eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F)))
+ return PreservedAnalyses::all();
+
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
namespace {
diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 4d4c3baef3f5f..1ec38e56aa4cb 100644
--- a/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -135,13 +135,13 @@ static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) {
if (auto CS = CallSite(I)) {
if (Function *F = CS.getCalledFunction()) {
StringRef FnName = F->getName();
- if (TLI.has(LibFunc::strcpy) && FnName == TLI.getName(LibFunc::strcpy))
+ if (TLI.has(LibFunc_strcpy) && FnName == TLI.getName(LibFunc_strcpy))
return true;
- if (TLI.has(LibFunc::strncpy) && FnName == TLI.getName(LibFunc::strncpy))
+ if (TLI.has(LibFunc_strncpy) && FnName == TLI.getName(LibFunc_strncpy))
return true;
- if (TLI.has(LibFunc::strcat) && FnName == TLI.getName(LibFunc::strcat))
+ if (TLI.has(LibFunc_strcat) && FnName == TLI.getName(LibFunc_strcat))
return true;
- if (TLI.has(LibFunc::strncat) && FnName == TLI.getName(LibFunc::strncat))
+ if (TLI.has(LibFunc_strncat) && FnName == TLI.getName(LibFunc_strncat))
return true;
}
}
@@ -287,19 +287,14 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL,
}
namespace {
-enum OverwriteResult {
- OverwriteBegin,
- OverwriteComplete,
- OverwriteEnd,
- OverwriteUnknown
-};
+enum OverwriteResult { OW_Begin, OW_Complete, OW_End, OW_Unknown };
}
-/// Return 'OverwriteComplete' if a store to the 'Later' location completely
-/// overwrites a store to the 'Earlier' location, 'OverwriteEnd' if the end of
-/// the 'Earlier' location is completely overwritten by 'Later',
-/// 'OverwriteBegin' if the beginning of the 'Earlier' location is overwritten
-/// by 'Later', or 'OverwriteUnknown' if nothing can be determined.
+/// Return 'OW_Complete' if a store to the 'Later' location completely
+/// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the
+/// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the
+/// beginning of the 'Earlier' location is overwritten by 'Later', or
+/// 'OW_Unknown' if nothing can be determined.
static OverwriteResult isOverwrite(const MemoryLocation &Later,
const MemoryLocation &Earlier,
const DataLayout &DL,
@@ -310,7 +305,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
// If we don't know the sizes of either access, then we can't do a comparison.
if (Later.Size == MemoryLocation::UnknownSize ||
Earlier.Size == MemoryLocation::UnknownSize)
- return OverwriteUnknown;
+ return OW_Unknown;
const Value *P1 = Earlier.Ptr->stripPointerCasts();
const Value *P2 = Later.Ptr->stripPointerCasts();
@@ -320,7 +315,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
if (P1 == P2) {
// Make sure that the Later size is >= the Earlier size.
if (Later.Size >= Earlier.Size)
- return OverwriteComplete;
+ return OW_Complete;
}
// Check to see if the later store is to the entire object (either a global,
@@ -332,13 +327,13 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
// If we can't resolve the same pointers to the same object, then we can't
// analyze them at all.
if (UO1 != UO2)
- return OverwriteUnknown;
+ return OW_Unknown;
// If the "Later" store is to a recognizable object, get its size.
uint64_t ObjectSize = getPointerSize(UO2, DL, TLI);
if (ObjectSize != MemoryLocation::UnknownSize)
if (ObjectSize == Later.Size && ObjectSize >= Earlier.Size)
- return OverwriteComplete;
+ return OW_Complete;
// Okay, we have stores to two completely different pointers. Try to
// decompose the pointer into a "base + constant_offset" form. If the base
@@ -350,7 +345,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
// If the base pointers still differ, we have two completely different stores.
if (BP1 != BP2)
- return OverwriteUnknown;
+ return OW_Unknown;
// The later store completely overlaps the earlier store if:
//
@@ -370,7 +365,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
if (EarlierOff >= LaterOff &&
Later.Size >= Earlier.Size &&
uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size)
- return OverwriteComplete;
+ return OW_Complete;
// We may now overlap, although the overlap is not complete. There might also
// be other incomplete overlaps, and together, they might cover the complete
@@ -428,7 +423,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
") Composite Later [" <<
ILI->second << ", " << ILI->first << ")\n");
++NumCompletePartials;
- return OverwriteComplete;
+ return OW_Complete;
}
}
@@ -443,7 +438,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
if (!EnablePartialOverwriteTracking &&
(LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + Earlier.Size) &&
int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size)))
- return OverwriteEnd;
+ return OW_End;
// Finally, we also need to check if the later store overwrites the beginning
// of the earlier store.
@@ -458,11 +453,11 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
(LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff)) {
assert(int64_t(LaterOff + Later.Size) <
int64_t(EarlierOff + Earlier.Size) &&
- "Expect to be handled as OverwriteComplete");
- return OverwriteBegin;
+ "Expect to be handled as OW_Complete");
+ return OW_Begin;
}
// Otherwise, they don't completely overlap.
- return OverwriteUnknown;
+ return OW_Unknown;
}
/// If 'Inst' might be a self read (i.e. a noop copy of a
@@ -551,7 +546,7 @@ static bool memoryIsNotModifiedBetween(Instruction *FirstI,
Instruction *I = &*BI;
if (I->mayWriteToMemory() && I != SecondI) {
auto Res = AA->getModRefInfo(I, MemLoc);
- if (Res != MRI_NoModRef)
+ if (Res & MRI_Mod)
return false;
}
}
@@ -909,7 +904,7 @@ static bool tryToShortenBegin(Instruction *EarlierWrite,
if (LaterStart <= EarlierStart && LaterStart + LaterSize > EarlierStart) {
assert(LaterStart + LaterSize < EarlierStart + EarlierSize &&
- "Should have been handled as OverwriteComplete");
+ "Should have been handled as OW_Complete");
if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart,
LaterSize, false)) {
IntervalMap.erase(OII);
@@ -1105,7 +1100,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
OverwriteResult OR =
isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset,
DepWrite, IOL);
- if (OR == OverwriteComplete) {
+ if (OR == OW_Complete) {
DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: "
<< *DepWrite << "\n KILLER: " << *Inst << '\n');
@@ -1117,15 +1112,15 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
// We erased DepWrite; start over.
InstDep = MD->getDependency(Inst);
continue;
- } else if ((OR == OverwriteEnd && isShortenableAtTheEnd(DepWrite)) ||
- ((OR == OverwriteBegin &&
+ } else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) ||
+ ((OR == OW_Begin &&
isShortenableAtTheBeginning(DepWrite)))) {
assert(!EnablePartialOverwriteTracking && "Do not expect to perform "
"when partial-overwrite "
"tracking is enabled");
int64_t EarlierSize = DepLoc.Size;
int64_t LaterSize = Loc.Size;
- bool IsOverwriteEnd = (OR == OverwriteEnd);
+ bool IsOverwriteEnd = (OR == OW_End);
MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize,
InstWriteOffset, LaterSize, IsOverwriteEnd);
}
@@ -1186,8 +1181,9 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {
if (!eliminateDeadStores(F, AA, MD, DT, TLI))
return PreservedAnalyses::all();
+
PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
PA.preserve<MemoryDependenceAnalysis>();
return PA;
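Besides the OW_* and LibFunc_ renames, the notable DSE change is in memoryIsNotModifiedBetween: the property it proves only requires that nothing writes the location, so the mod/ref result is now tested for the Mod bit instead of rejecting any non-NoModRef answer, and intervening reads no longer block the optimization. A small sketch of the check, assuming the MRI_* bitmask ModRefInfo used in this tree (the helper name is illustrative):

#include "llvm/Analysis/AliasAnalysis.h"
using namespace llvm;

// True only when the instruction may write the queried location; a pure
// read (MRI_Ref) between the two accesses is harmless here.
static bool mayClobber(ModRefInfo MRI) {
  return (MRI & MRI_Mod) != 0;
}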
diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp
index 16e08ee58fbeb..04479b6e49ac8 100644
--- a/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -19,6 +19,8 @@
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
@@ -32,7 +34,6 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Utils/MemorySSA.h"
#include <deque>
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -253,6 +254,7 @@ public:
DominatorTree &DT;
AssumptionCache &AC;
MemorySSA *MSSA;
+ std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
typedef RecyclingAllocator<
BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value *>> AllocatorTy;
typedef ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>,
@@ -315,7 +317,9 @@ public:
/// \brief Set up the EarlyCSE runner for a particular function.
EarlyCSE(const TargetLibraryInfo &TLI, const TargetTransformInfo &TTI,
DominatorTree &DT, AssumptionCache &AC, MemorySSA *MSSA)
- : TLI(TLI), TTI(TTI), DT(DT), AC(AC), MSSA(MSSA), CurrentGeneration(0) {}
+ : TLI(TLI), TTI(TTI), DT(DT), AC(AC), MSSA(MSSA),
+ MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)), CurrentGeneration(0) {
+ }
bool run();
@@ -388,7 +392,7 @@ private:
ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)
: IsTargetMemInst(false), Inst(Inst) {
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst))
- if (TTI.getTgtMemIntrinsic(II, Info) && Info.NumMemRefs == 1)
+ if (TTI.getTgtMemIntrinsic(II, Info))
IsTargetMemInst = true;
}
bool isLoad() const {
@@ -400,17 +404,14 @@ private:
return isa<StoreInst>(Inst);
}
bool isAtomic() const {
- if (IsTargetMemInst) {
- assert(Info.IsSimple && "need to refine IsSimple in TTI");
- return false;
- }
+ if (IsTargetMemInst)
+ return Info.Ordering != AtomicOrdering::NotAtomic;
return Inst->isAtomic();
}
bool isUnordered() const {
- if (IsTargetMemInst) {
- assert(Info.IsSimple && "need to refine IsSimple in TTI");
- return true;
- }
+ if (IsTargetMemInst)
+ return Info.isUnordered();
+
if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
return LI->isUnordered();
} else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
@@ -421,10 +422,9 @@ private:
}
bool isVolatile() const {
- if (IsTargetMemInst) {
- assert(Info.IsSimple && "need to refine IsSimple in TTI");
- return false;
- }
+ if (IsTargetMemInst)
+ return Info.IsVolatile;
+
if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
return LI->isVolatile();
} else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
@@ -517,7 +517,7 @@ private:
if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U))
PhisToCheck.push_back(MP);
- MSSA->removeMemoryAccess(WI);
+ MSSAUpdater->removeMemoryAccess(WI);
for (MemoryPhi *MP : PhisToCheck) {
MemoryAccess *FirstIn = MP->getIncomingValue(0);
@@ -587,27 +587,28 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// which reaches this block where the condition might hold a different
// value. Since we're adding this to the scoped hash table (like any other
// def), it will have been popped if we encounter a future merge block.
- if (BasicBlock *Pred = BB->getSinglePredecessor())
- if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
- if (BI->isConditional())
- if (auto *CondInst = dyn_cast<Instruction>(BI->getCondition()))
- if (SimpleValue::canHandle(CondInst)) {
- assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB);
- auto *ConditionalConstant = (BI->getSuccessor(0) == BB) ?
- ConstantInt::getTrue(BB->getContext()) :
- ConstantInt::getFalse(BB->getContext());
- AvailableValues.insert(CondInst, ConditionalConstant);
- DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '"
- << CondInst->getName() << "' as " << *ConditionalConstant
- << " in " << BB->getName() << "\n");
- // Replace all dominated uses with the known value.
- if (unsigned Count =
- replaceDominatedUsesWith(CondInst, ConditionalConstant, DT,
- BasicBlockEdge(Pred, BB))) {
- Changed = true;
- NumCSECVP = NumCSECVP + Count;
- }
- }
+ if (BasicBlock *Pred = BB->getSinglePredecessor()) {
+ auto *BI = dyn_cast<BranchInst>(Pred->getTerminator());
+ if (BI && BI->isConditional()) {
+ auto *CondInst = dyn_cast<Instruction>(BI->getCondition());
+ if (CondInst && SimpleValue::canHandle(CondInst)) {
+ assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB);
+ auto *TorF = (BI->getSuccessor(0) == BB)
+ ? ConstantInt::getTrue(BB->getContext())
+ : ConstantInt::getFalse(BB->getContext());
+ AvailableValues.insert(CondInst, TorF);
+ DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '"
+ << CondInst->getName() << "' as " << *TorF << " in "
+ << BB->getName() << "\n");
+ // Replace all dominated uses with the known value.
+ if (unsigned Count = replaceDominatedUsesWith(
+ CondInst, TorF, DT, BasicBlockEdge(Pred, BB))) {
+ Changed = true;
+ NumCSECVP = NumCSECVP + Count;
+ }
+ }
+ }
+ }
/// LastStore - Keep track of the last non-volatile store that we saw... for
/// as long as there is no instruction that reads memory. If we see a store
@@ -761,12 +762,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
continue;
}
- // If this instruction may read from memory, forget LastStore.
- // Load/store intrinsics will indicate both a read and a write to
- // memory. The target may override this (e.g. so that a store intrinsic
- // does not read from memory, and thus will be treated the same as a
- // regular store for commoning purposes).
- if (Inst->mayReadFromMemory() &&
+ // If this instruction may read from memory or throw (and potentially read
+ // from memory in the exception handler), forget LastStore. Load/store
+ // intrinsics will indicate both a read and a write to memory. The target
+ // may override this (e.g. so that a store intrinsic does not read from
+ // memory, and thus will be treated the same as a regular store for
+ // commoning purposes).
+ if ((Inst->mayReadFromMemory() || Inst->mayThrow()) &&
!(MemInst.isValid() && !MemInst.mayReadFromMemory()))
LastStore = nullptr;
@@ -967,10 +969,8 @@ PreservedAnalyses EarlyCSEPass::run(Function &F,
if (!CSE.run())
return PreservedAnalyses::all();
- // CSE preserves the dominator tree because it doesn't mutate the CFG.
- // FIXME: Bundle this with other CFG-preservation.
PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
if (UseMemorySSA)
PA.preserve<MemorySSAAnalysis>();
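EarlyCSE now owns a MemorySSAUpdater and routes removals through it instead of mutating MemorySSA directly. A rough sketch of that ownership pattern, assuming MemorySSA may be null when the pass runs without the MemorySSA-backed variant (the wrapper struct and helper are illustrative, not the pass's real layout):

#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/Instruction.h"
#include <memory>
using namespace llvm;

struct CSEState {
  MemorySSA *MSSA; // may be null
  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;

  explicit CSEState(MemorySSA *MSSA)
      : MSSA(MSSA), MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {}

  // Drop the MemorySSA node for an instruction we are about to erase so the
  // analysis stays consistent with the IR.
  void removeAccessFor(Instruction *I) {
    if (!MSSA)
      return;
    if (MemoryAccess *MA = MSSA->getMemoryAccess(I))
      MSSAUpdater->removeMemoryAccess(MA);
  }
};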
diff --git a/lib/Transforms/Scalar/Float2Int.cpp b/lib/Transforms/Scalar/Float2Int.cpp
index 545036d724ef1..8a5af6195f1b5 100644
--- a/lib/Transforms/Scalar/Float2Int.cpp
+++ b/lib/Transforms/Scalar/Float2Int.cpp
@@ -516,11 +516,10 @@ FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }
PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &) {
if (!runImpl(F))
return PreservedAnalyses::all();
- else {
- // FIXME: This should also 'preserve the CFG'.
- PreservedAnalyses PA;
- PA.preserve<GlobalsAA>();
- return PA;
- }
+
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ PA.preserve<GlobalsAA>();
+ return PA;
}
} // End namespace llvm
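Several hunks in this patch (EarlyCSE, Float2Int, GuardWidening, IndVarSimplify) converge on the same new-pass-manager idiom; a minimal sketch with a hypothetical pass and helper name:

  // When the transform leaves the CFG untouched, report the whole set of
  // CFG-dependent analyses (dominator tree, loop info, ...) as preserved
  // instead of enumerating them one by one.
  PreservedAnalyses ExamplePass::run(Function &F, FunctionAnalysisManager &) {
    if (!runImpl(F))                 // assumed helper: returns true on change
      return PreservedAnalyses::all();
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    PA.preserve<GlobalsAA>();
    return PA;
  }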
diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp
index 0137378b828b4..be696df548d52 100644
--- a/lib/Transforms/Scalar/GVN.cpp
+++ b/lib/Transforms/Scalar/GVN.cpp
@@ -36,7 +36,6 @@
#include "llvm/Analysis/OptimizationDiagnosticInfo.h"
#include "llvm/Analysis/PHITransAddr.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GlobalVariable.h"
@@ -51,9 +50,12 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
+#include "llvm/Transforms/Utils/VNCoercion.h"
+
#include <vector>
using namespace llvm;
using namespace llvm::gvn;
+using namespace llvm::VNCoercion;
using namespace PatternMatch;
#define DEBUG_TYPE "gvn"
@@ -595,11 +597,12 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) {
PreservedAnalyses PA;
PA.preserve<DominatorTreeAnalysis>();
PA.preserve<GlobalsAA>();
+ PA.preserve<TargetLibraryAnalysis>();
return PA;
}
-LLVM_DUMP_METHOD
-void GVN::dump(DenseMap<uint32_t, Value*>& d) {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void GVN::dump(DenseMap<uint32_t, Value*>& d) {
errs() << "{\n";
for (DenseMap<uint32_t, Value*>::iterator I = d.begin(),
E = d.end(); I != E; ++I) {
@@ -608,6 +611,7 @@ void GVN::dump(DenseMap<uint32_t, Value*>& d) {
}
errs() << "}\n";
}
+#endif
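A short sketch of the conditional-compilation pattern applied above (class and member names are hypothetical): debug-only dump() bodies are compiled only in asserts builds or when LLVM_ENABLE_DUMP is defined, so release binaries carry no dead printing code.

  #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
  LLVM_DUMP_METHOD void ExampleTable::dump() const {
    errs() << "entries: " << Entries.size() << "\n";
  }
  #endif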
/// Return true if we can prove that the value
/// we're analyzing is fully available in the specified block. As we go, keep
@@ -690,442 +694,6 @@ SpeculationFailure:
}
-/// Return true if CoerceAvailableValueToLoadType will succeed.
-static bool CanCoerceMustAliasedValueToLoad(Value *StoredVal,
- Type *LoadTy,
- const DataLayout &DL) {
- // If the loaded or stored value is an first class array or struct, don't try
- // to transform them. We need to be able to bitcast to integer.
- if (LoadTy->isStructTy() || LoadTy->isArrayTy() ||
- StoredVal->getType()->isStructTy() ||
- StoredVal->getType()->isArrayTy())
- return false;
-
- // The store has to be at least as big as the load.
- if (DL.getTypeSizeInBits(StoredVal->getType()) <
- DL.getTypeSizeInBits(LoadTy))
- return false;
-
- return true;
-}
-
-/// If we saw a store of a value to memory, and
-/// then a load from a must-aliased pointer of a different type, try to coerce
-/// the stored value. LoadedTy is the type of the load we want to replace.
-/// IRB is IRBuilder used to insert new instructions.
-///
-/// If we can't do it, return null.
-static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy,
- IRBuilder<> &IRB,
- const DataLayout &DL) {
- assert(CanCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) &&
- "precondition violation - materialization can't fail");
-
- if (auto *C = dyn_cast<Constant>(StoredVal))
- if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL))
- StoredVal = FoldedStoredVal;
-
- // If this is already the right type, just return it.
- Type *StoredValTy = StoredVal->getType();
-
- uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy);
- uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy);
-
- // If the store and reload are the same size, we can always reuse it.
- if (StoredValSize == LoadedValSize) {
- // Pointer to Pointer -> use bitcast.
- if (StoredValTy->getScalarType()->isPointerTy() &&
- LoadedTy->getScalarType()->isPointerTy()) {
- StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy);
- } else {
- // Convert source pointers to integers, which can be bitcast.
- if (StoredValTy->getScalarType()->isPointerTy()) {
- StoredValTy = DL.getIntPtrType(StoredValTy);
- StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy);
- }
-
- Type *TypeToCastTo = LoadedTy;
- if (TypeToCastTo->getScalarType()->isPointerTy())
- TypeToCastTo = DL.getIntPtrType(TypeToCastTo);
-
- if (StoredValTy != TypeToCastTo)
- StoredVal = IRB.CreateBitCast(StoredVal, TypeToCastTo);
-
- // Cast to pointer if the load needs a pointer type.
- if (LoadedTy->getScalarType()->isPointerTy())
- StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy);
- }
-
- if (auto *C = dyn_cast<ConstantExpr>(StoredVal))
- if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL))
- StoredVal = FoldedStoredVal;
-
- return StoredVal;
- }
-
- // If the loaded value is smaller than the available value, then we can
- // extract out a piece from it. If the available value is too small, then we
- // can't do anything.
- assert(StoredValSize >= LoadedValSize &&
- "CanCoerceMustAliasedValueToLoad fail");
-
- // Convert source pointers to integers, which can be manipulated.
- if (StoredValTy->getScalarType()->isPointerTy()) {
- StoredValTy = DL.getIntPtrType(StoredValTy);
- StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy);
- }
-
- // Convert vectors and fp to integer, which can be manipulated.
- if (!StoredValTy->isIntegerTy()) {
- StoredValTy = IntegerType::get(StoredValTy->getContext(), StoredValSize);
- StoredVal = IRB.CreateBitCast(StoredVal, StoredValTy);
- }
-
- // If this is a big-endian system, we need to shift the value down to the low
- // bits so that a truncate will work.
- if (DL.isBigEndian()) {
- uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy) -
- DL.getTypeStoreSizeInBits(LoadedTy);
- StoredVal = IRB.CreateLShr(StoredVal, ShiftAmt, "tmp");
- }
-
- // Truncate the integer to the right size now.
- Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadedValSize);
- StoredVal = IRB.CreateTrunc(StoredVal, NewIntTy, "trunc");
-
- if (LoadedTy != NewIntTy) {
- // If the result is a pointer, inttoptr.
- if (LoadedTy->getScalarType()->isPointerTy())
- StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy, "inttoptr");
- else
- // Otherwise, bitcast.
- StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy, "bitcast");
- }
-
- if (auto *C = dyn_cast<Constant>(StoredVal))
- if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL))
- StoredVal = FoldedStoredVal;
-
- return StoredVal;
-}
-
-/// This function is called when we have a
-/// memdep query of a load that ends up being a clobbering memory write (store,
-/// memset, memcpy, memmove). This means that the write *may* provide bits used
-/// by the load but we can't be sure because the pointers don't mustalias.
-///
-/// Check this case to see if there is anything more we can do before we give
-/// up. This returns -1 if we have to give up, or a byte number in the stored
-/// value of the piece that feeds the load.
-static int AnalyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr,
- Value *WritePtr,
- uint64_t WriteSizeInBits,
- const DataLayout &DL) {
- // If the loaded or stored value is a first class array or struct, don't try
- // to transform them. We need to be able to bitcast to integer.
- if (LoadTy->isStructTy() || LoadTy->isArrayTy())
- return -1;
-
- int64_t StoreOffset = 0, LoadOffset = 0;
- Value *StoreBase =
- GetPointerBaseWithConstantOffset(WritePtr, StoreOffset, DL);
- Value *LoadBase = GetPointerBaseWithConstantOffset(LoadPtr, LoadOffset, DL);
- if (StoreBase != LoadBase)
- return -1;
-
- // If the load and store are to the exact same address, they should have been
- // a must alias. AA must have gotten confused.
- // FIXME: Study to see if/when this happens. One case is forwarding a memset
- // to a load from the base of the memset.
-
- // If the load and store don't overlap at all, the store doesn't provide
- // anything to the load. In this case, they really don't alias at all, AA
- // must have gotten confused.
- uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy);
-
- if ((WriteSizeInBits & 7) | (LoadSize & 7))
- return -1;
- uint64_t StoreSize = WriteSizeInBits / 8; // Convert to bytes.
- LoadSize /= 8;
-
-
- bool isAAFailure = false;
- if (StoreOffset < LoadOffset)
- isAAFailure = StoreOffset+int64_t(StoreSize) <= LoadOffset;
- else
- isAAFailure = LoadOffset+int64_t(LoadSize) <= StoreOffset;
-
- if (isAAFailure)
- return -1;
-
- // If the Load isn't completely contained within the stored bits, we don't
- // have all the bits to feed it. We could do something crazy in the future
- // (issue a smaller load then merge the bits in) but this seems unlikely to be
- // valuable.
- if (StoreOffset > LoadOffset ||
- StoreOffset+StoreSize < LoadOffset+LoadSize)
- return -1;
-
- // Okay, we can do this transformation. Return the number of bytes into the
- // store that the load is.
- return LoadOffset-StoreOffset;
-}
-
-/// This function is called when we have a
-/// memdep query of a load that ends up being a clobbering store.
-static int AnalyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr,
- StoreInst *DepSI) {
- // Cannot handle reading from store of first-class aggregate yet.
- if (DepSI->getValueOperand()->getType()->isStructTy() ||
- DepSI->getValueOperand()->getType()->isArrayTy())
- return -1;
-
- const DataLayout &DL = DepSI->getModule()->getDataLayout();
- Value *StorePtr = DepSI->getPointerOperand();
- uint64_t StoreSize =DL.getTypeSizeInBits(DepSI->getValueOperand()->getType());
- return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr,
- StorePtr, StoreSize, DL);
-}
-
-/// This function is called when we have a
-/// memdep query of a load that ends up being clobbered by another load. See if
-/// the other load can feed into the second load.
-static int AnalyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr,
- LoadInst *DepLI, const DataLayout &DL){
- // Cannot handle reading from store of first-class aggregate yet.
- if (DepLI->getType()->isStructTy() || DepLI->getType()->isArrayTy())
- return -1;
-
- Value *DepPtr = DepLI->getPointerOperand();
- uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType());
- int R = AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL);
- if (R != -1) return R;
-
- // If we have a load/load clobber an DepLI can be widened to cover this load,
- // then we should widen it!
- int64_t LoadOffs = 0;
- const Value *LoadBase =
- GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL);
- unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
-
- unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize(
- LoadBase, LoadOffs, LoadSize, DepLI);
- if (Size == 0) return -1;
-
- // Check non-obvious conditions enforced by MDA which we rely on for being
- // able to materialize this potentially available value
- assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!");
- assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load");
-
- return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size*8, DL);
-}
-
-
-
-static int AnalyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr,
- MemIntrinsic *MI,
- const DataLayout &DL) {
- // If the mem operation is a non-constant size, we can't handle it.
- ConstantInt *SizeCst = dyn_cast<ConstantInt>(MI->getLength());
- if (!SizeCst) return -1;
- uint64_t MemSizeInBits = SizeCst->getZExtValue()*8;
-
- // If this is memset, we just need to see if the offset is valid in the size
- // of the memset..
- if (MI->getIntrinsicID() == Intrinsic::memset)
- return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(),
- MemSizeInBits, DL);
-
- // If we have a memcpy/memmove, the only case we can handle is if this is a
- // copy from constant memory. In that case, we can read directly from the
- // constant memory.
- MemTransferInst *MTI = cast<MemTransferInst>(MI);
-
- Constant *Src = dyn_cast<Constant>(MTI->getSource());
- if (!Src) return -1;
-
- GlobalVariable *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(Src, DL));
- if (!GV || !GV->isConstant()) return -1;
-
- // See if the access is within the bounds of the transfer.
- int Offset = AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr,
- MI->getDest(), MemSizeInBits, DL);
- if (Offset == -1)
- return Offset;
-
- unsigned AS = Src->getType()->getPointerAddressSpace();
- // Otherwise, see if we can constant fold a load from the constant with the
- // offset applied as appropriate.
- Src = ConstantExpr::getBitCast(Src,
- Type::getInt8PtrTy(Src->getContext(), AS));
- Constant *OffsetCst =
- ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset);
- Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src,
- OffsetCst);
- Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS));
- if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL))
- return Offset;
- return -1;
-}
-
-
-/// This function is called when we have a
-/// memdep query of a load that ends up being a clobbering store. This means
-/// that the store provides bits used by the load but we the pointers don't
-/// mustalias. Check this case to see if there is anything more we can do
-/// before we give up.
-static Value *GetStoreValueForLoad(Value *SrcVal, unsigned Offset,
- Type *LoadTy,
- Instruction *InsertPt, const DataLayout &DL){
- LLVMContext &Ctx = SrcVal->getType()->getContext();
-
- uint64_t StoreSize = (DL.getTypeSizeInBits(SrcVal->getType()) + 7) / 8;
- uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy) + 7) / 8;
-
- IRBuilder<> Builder(InsertPt);
-
- // Compute which bits of the stored value are being used by the load. Convert
- // to an integer type to start with.
- if (SrcVal->getType()->getScalarType()->isPointerTy())
- SrcVal = Builder.CreatePtrToInt(SrcVal,
- DL.getIntPtrType(SrcVal->getType()));
- if (!SrcVal->getType()->isIntegerTy())
- SrcVal = Builder.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize*8));
-
- // Shift the bits to the least significant depending on endianness.
- unsigned ShiftAmt;
- if (DL.isLittleEndian())
- ShiftAmt = Offset*8;
- else
- ShiftAmt = (StoreSize-LoadSize-Offset)*8;
-
- if (ShiftAmt)
- SrcVal = Builder.CreateLShr(SrcVal, ShiftAmt);
-
- if (LoadSize != StoreSize)
- SrcVal = Builder.CreateTrunc(SrcVal, IntegerType::get(Ctx, LoadSize*8));
-
- return CoerceAvailableValueToLoadType(SrcVal, LoadTy, Builder, DL);
-}
-
-/// This function is called when we have a
-/// memdep query of a load that ends up being a clobbering load. This means
-/// that the load *may* provide bits used by the load but we can't be sure
-/// because the pointers don't mustalias. Check this case to see if there is
-/// anything more we can do before we give up.
-static Value *GetLoadValueForLoad(LoadInst *SrcVal, unsigned Offset,
- Type *LoadTy, Instruction *InsertPt,
- GVN &gvn) {
- const DataLayout &DL = SrcVal->getModule()->getDataLayout();
- // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to
- // widen SrcVal out to a larger load.
- unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType());
- unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
- if (Offset+LoadSize > SrcValStoreSize) {
- assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!");
- assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load");
- // If we have a load/load clobber an DepLI can be widened to cover this
- // load, then we should widen it to the next power of 2 size big enough!
- unsigned NewLoadSize = Offset+LoadSize;
- if (!isPowerOf2_32(NewLoadSize))
- NewLoadSize = NextPowerOf2(NewLoadSize);
-
- Value *PtrVal = SrcVal->getPointerOperand();
-
- // Insert the new load after the old load. This ensures that subsequent
- // memdep queries will find the new load. We can't easily remove the old
- // load completely because it is already in the value numbering table.
- IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal));
- Type *DestPTy =
- IntegerType::get(LoadTy->getContext(), NewLoadSize*8);
- DestPTy = PointerType::get(DestPTy,
- PtrVal->getType()->getPointerAddressSpace());
- Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc());
- PtrVal = Builder.CreateBitCast(PtrVal, DestPTy);
- LoadInst *NewLoad = Builder.CreateLoad(PtrVal);
- NewLoad->takeName(SrcVal);
- NewLoad->setAlignment(SrcVal->getAlignment());
-
- DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n");
- DEBUG(dbgs() << "TO: " << *NewLoad << "\n");
-
- // Replace uses of the original load with the wider load. On a big endian
- // system, we need to shift down to get the relevant bits.
- Value *RV = NewLoad;
- if (DL.isBigEndian())
- RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8);
- RV = Builder.CreateTrunc(RV, SrcVal->getType());
- SrcVal->replaceAllUsesWith(RV);
-
- // We would like to use gvn.markInstructionForDeletion here, but we can't
- // because the load is already memoized into the leader map table that GVN
- // tracks. It is potentially possible to remove the load from the table,
- // but then there all of the operations based on it would need to be
- // rehashed. Just leave the dead load around.
- gvn.getMemDep().removeInstruction(SrcVal);
- SrcVal = NewLoad;
- }
-
- return GetStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL);
-}
-
-
-/// This function is called when we have a
-/// memdep query of a load that ends up being a clobbering mem intrinsic.
-static Value *GetMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset,
- Type *LoadTy, Instruction *InsertPt,
- const DataLayout &DL){
- LLVMContext &Ctx = LoadTy->getContext();
- uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy)/8;
-
- IRBuilder<> Builder(InsertPt);
-
- // We know that this method is only called when the mem transfer fully
- // provides the bits for the load.
- if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) {
- // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and
- // independently of what the offset is.
- Value *Val = MSI->getValue();
- if (LoadSize != 1)
- Val = Builder.CreateZExt(Val, IntegerType::get(Ctx, LoadSize*8));
-
- Value *OneElt = Val;
-
- // Splat the value out to the right number of bits.
- for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize; ) {
- // If we can double the number of bytes set, do it.
- if (NumBytesSet*2 <= LoadSize) {
- Value *ShVal = Builder.CreateShl(Val, NumBytesSet*8);
- Val = Builder.CreateOr(Val, ShVal);
- NumBytesSet <<= 1;
- continue;
- }
-
- // Otherwise insert one byte at a time.
- Value *ShVal = Builder.CreateShl(Val, 1*8);
- Val = Builder.CreateOr(OneElt, ShVal);
- ++NumBytesSet;
- }
-
- return CoerceAvailableValueToLoadType(Val, LoadTy, Builder, DL);
- }
-
- // Otherwise, this is a memcpy/memmove from a constant global.
- MemTransferInst *MTI = cast<MemTransferInst>(SrcInst);
- Constant *Src = cast<Constant>(MTI->getSource());
- unsigned AS = Src->getType()->getPointerAddressSpace();
-
- // Otherwise, see if we can constant fold a load from the constant with the
- // offset applied as appropriate.
- Src = ConstantExpr::getBitCast(Src,
- Type::getInt8PtrTy(Src->getContext(), AS));
- Constant *OffsetCst =
- ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset);
- Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src,
- OffsetCst);
- Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS));
- return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL);
-}
/// Given a set of loads specified by ValuesPerBlock,
@@ -1171,7 +739,7 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI,
if (isSimpleValue()) {
Res = getSimpleValue();
if (Res->getType() != LoadTy) {
- Res = GetStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL);
+ Res = getStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL);
DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset << " "
<< *getSimpleValue() << '\n'
@@ -1182,14 +750,20 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI,
if (Load->getType() == LoadTy && Offset == 0) {
Res = Load;
} else {
- Res = GetLoadValueForLoad(Load, Offset, LoadTy, InsertPt, gvn);
-
+ Res = getLoadValueForLoad(Load, Offset, LoadTy, InsertPt, DL);
+ // We would like to use gvn.markInstructionForDeletion here, but we can't
+ // because the load is already memoized into the leader map table that GVN
+ // tracks. It is potentially possible to remove the load from the table,
+      // but then all of the operations based on it would need to be
+ // rehashed. Just leave the dead load around.
+ gvn.getMemDep().removeInstruction(Load);
DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset << " "
<< *getCoercedLoadValue() << '\n'
- << *Res << '\n' << "\n\n\n");
+ << *Res << '\n'
+ << "\n\n\n");
}
} else if (isMemIntrinValue()) {
- Res = GetMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy,
+ Res = getMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy,
InsertPt, DL);
DEBUG(dbgs() << "GVN COERCED NONLOCAL MEM INTRIN:\nOffset: " << Offset
<< " " << *getMemIntrinValue() << '\n'
@@ -1258,7 +832,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,
// Can't forward from non-atomic to atomic without violating memory model.
if (Address && LI->isAtomic() <= DepSI->isAtomic()) {
int Offset =
- AnalyzeLoadFromClobberingStore(LI->getType(), Address, DepSI);
+ analyzeLoadFromClobberingStore(LI->getType(), Address, DepSI, DL);
if (Offset != -1) {
Res = AvailableValue::get(DepSI->getValueOperand(), Offset);
return true;
@@ -1276,7 +850,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,
// Can't forward from non-atomic to atomic without violating memory model.
if (DepLI != LI && Address && LI->isAtomic() <= DepLI->isAtomic()) {
int Offset =
- AnalyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL);
+ analyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL);
if (Offset != -1) {
Res = AvailableValue::getLoad(DepLI, Offset);
@@ -1289,7 +863,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,
// forward a value on from it.
if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInfo.getInst())) {
if (Address && !LI->isAtomic()) {
- int Offset = AnalyzeLoadFromClobberingMemInst(LI->getType(), Address,
+ int Offset = analyzeLoadFromClobberingMemInst(LI->getType(), Address,
DepMI, DL);
if (Offset != -1) {
Res = AvailableValue::getMI(DepMI, Offset);
@@ -1334,7 +908,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,
// different types if we have to. If the stored value is larger or equal to
// the loaded value, we can reuse it.
if (S->getValueOperand()->getType() != LI->getType() &&
- !CanCoerceMustAliasedValueToLoad(S->getValueOperand(),
+ !canCoerceMustAliasedValueToLoad(S->getValueOperand(),
LI->getType(), DL))
return false;
@@ -1351,7 +925,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,
// If the stored value is larger or equal to the loaded value, we can reuse
// it.
if (LD->getType() != LI->getType() &&
- !CanCoerceMustAliasedValueToLoad(LD, LI->getType(), DL))
+ !canCoerceMustAliasedValueToLoad(LD, LI->getType(), DL))
return false;
// Can't forward from non-atomic to atomic without violating memory model.
@@ -1713,7 +1287,7 @@ bool GVN::processNonLocalLoad(LoadInst *LI) {
// If instruction I has debug info, then we should not update it.
// Also, if I has a null DebugLoc, then it is still potentially incorrect
// to propagate LI's DebugLoc because LI may not post-dominate I.
- if (LI->getDebugLoc() && ValuesPerBlock.size() != 1)
+ if (LI->getDebugLoc() && LI->getParent() == I->getParent())
I->setDebugLoc(LI->getDebugLoc());
if (V->getType()->getScalarType()->isPointerTy())
MD->invalidateCachedPointerInfo(V);
@@ -1795,7 +1369,7 @@ static void patchReplacementInstruction(Instruction *I, Value *Repl) {
// Patch the replacement so that it is not more restrictive than the value
// being replaced.
- // Note that if 'I' is a load being replaced by some operation,
+ // Note that if 'I' is a load being replaced by some operation,
// for example, by an arithmetic operation, then andIRFlags()
// would just erase all math flags from the original arithmetic
// operation, which is clearly not wanted and not needed.
@@ -2187,11 +1761,11 @@ bool GVN::processInstruction(Instruction *I) {
for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
i != e; ++i) {
- BasicBlock *Dst = i.getCaseSuccessor();
+ BasicBlock *Dst = i->getCaseSuccessor();
// If there is only a single edge, propagate the case value into it.
if (SwitchEdges.lookup(Dst) == 1) {
BasicBlockEdge E(Parent, Dst);
- Changed |= propagateEquality(SwitchCond, i.getCaseValue(), E, true);
+ Changed |= propagateEquality(SwitchCond, i->getCaseValue(), E, true);
}
}
return Changed;
@@ -2581,21 +2155,12 @@ bool GVN::iterateOnFunction(Function &F) {
// Top-down walk of the dominator tree
bool Changed = false;
- // Save the blocks this function have before transformation begins. GVN may
- // split critical edge, and hence may invalidate the RPO/DT iterator.
- //
- std::vector<BasicBlock *> BBVect;
- BBVect.reserve(256);
// Needed for value numbering with phi construction to work.
+ // RPOT walks the graph in its constructor and will not be invalidated during
+ // processBlock.
ReversePostOrderTraversal<Function *> RPOT(&F);
- for (ReversePostOrderTraversal<Function *>::rpo_iterator RI = RPOT.begin(),
- RE = RPOT.end();
- RI != RE; ++RI)
- BBVect.push_back(*RI);
-
- for (std::vector<BasicBlock *>::iterator I = BBVect.begin(), E = BBVect.end();
- I != E; I++)
- Changed |= processBlock(*I);
+ for (BasicBlock *BB : RPOT)
+ Changed |= processBlock(BB);
return Changed;
}
@@ -2783,6 +2348,7 @@ public:
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
+ AU.addPreserved<TargetLibraryInfoWrapperPass>();
AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
}
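The coercion helpers deleted from GVN.cpp above now live in llvm/Transforms/Utils/VNCoercion.h under namespace llvm::VNCoercion (pulled in by the new using-directive). A hedged sketch of the store-forwarding path, using only the calls visible in this diff; DepSI and LI stand in for the clobbering store and the load, and the load's pointer operand is used in place of the memdep-provided address:

  const DataLayout &DL = DepSI->getModule()->getDataLayout();
  // Byte offset of the loaded bits inside the stored value, or -1 on failure.
  int Offset = analyzeLoadFromClobberingStore(
      LI->getType(), LI->getPointerOperand(), DepSI, DL);
  Value *Forwarded = nullptr;
  if (Offset != -1 &&
      canCoerceMustAliasedValueToLoad(DepSI->getValueOperand(), LI->getType(),
                                      DL))
    Forwarded = getStoreValueForLoad(DepSI->getValueOperand(), Offset,
                                     LI->getType(), /*InsertPt=*/LI, DL);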
diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp
index f8e1d2e1a08ab..6adfe130d148b 100644
--- a/lib/Transforms/Scalar/GVNHoist.cpp
+++ b/lib/Transforms/Scalar/GVNHoist.cpp
@@ -17,16 +17,39 @@
// is disabled in the following cases.
// 1. Scalars across calls.
// 2. geps when corresponding load/store cannot be hoisted.
+//
+// TODO: Hoist from >2 successors. Currently GVNHoist will not hoist stores
+// in this case because it works on two instructions at a time.
+// entry:
+// switch i32 %c1, label %exit1 [
+// i32 0, label %sw0
+// i32 1, label %sw1
+// ]
+//
+// sw0:
+// store i32 1, i32* @G
+// br label %exit
+//
+// sw1:
+// store i32 1, i32* @G
+// br label %exit
+//
+// exit1:
+// store i32 1, i32* @G
+// ret void
+// exit:
+// ret void
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Utils/MemorySSA.h"
using namespace llvm;
@@ -60,7 +83,7 @@ static cl::opt<int>
cl::desc("Maximum length of dependent chains to hoist "
"(default = 10, unlimited = -1)"));
-namespace {
+namespace llvm {
// Provides a sorting function based on the execution order of two instructions.
struct SortByDFSIn {
@@ -72,13 +95,6 @@ public:
// Returns true when A executes before B.
bool operator()(const Instruction *A, const Instruction *B) const {
- // FIXME: libc++ has a std::sort() algorithm that will call the compare
- // function on the same element. Once PR20837 is fixed and some more years
- // pass by and all the buildbots have moved to a corrected std::sort(),
- // enable the following assert:
- //
- // assert(A != B);
-
const BasicBlock *BA = A->getParent();
const BasicBlock *BB = B->getParent();
unsigned ADFS, BDFS;
@@ -202,6 +218,7 @@ public:
GVNHoist(DominatorTree *DT, AliasAnalysis *AA, MemoryDependenceResults *MD,
MemorySSA *MSSA)
: DT(DT), AA(AA), MD(MD), MSSA(MSSA),
+ MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)),
HoistingGeps(false),
HoistedCtr(0)
{ }
@@ -249,9 +266,11 @@ private:
AliasAnalysis *AA;
MemoryDependenceResults *MD;
MemorySSA *MSSA;
+ std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
const bool HoistingGeps;
DenseMap<const Value *, unsigned> DFSNumber;
BBSideEffectsSet BBSideEffects;
+ DenseSet<const BasicBlock*> HoistBarrier;
int HoistedCtr;
enum InsKind { Unknown, Scalar, Load, Store };
@@ -307,8 +326,8 @@ private:
continue;
}
- // Check for end of function, calls that do not return, etc.
- if (!isGuaranteedToTransferExecutionToSuccessor(BB->getTerminator()))
+    // We reached a leaf basic block, so not all paths have this instruction.
+ if (!BB->getTerminator()->getNumSuccessors())
return false;
// When reaching the back-edge of a loop, there may be a path through the
@@ -360,7 +379,7 @@ private:
ReachedNewPt = true;
}
}
- if (defClobbersUseOrDef(Def, MU, *AA))
+ if (MemorySSAUtil::defClobbersUseOrDef(Def, MU, *AA))
return true;
}
@@ -387,7 +406,8 @@ private:
// executed between the execution of NewBB and OldBB. Hoisting an expression
// from OldBB into NewBB has to be safe on all execution paths.
for (auto I = idf_begin(OldBB), E = idf_end(OldBB); I != E;) {
- if (*I == NewBB) {
+ const BasicBlock *BB = *I;
+ if (BB == NewBB) {
// Stop traversal when reaching HoistPt.
I.skipChildren();
continue;
@@ -398,11 +418,17 @@ private:
return true;
// Impossible to hoist with exceptions on the path.
- if (hasEH(*I))
+ if (hasEH(BB))
+ return true;
+
+      // Instructions that follow a hoist barrier in a block are never
+      // selected for hoisting, so anything selected within a block that
+      // contains a barrier precedes the barrier and can still be hoisted;
+      // only bail out for barrier blocks other than OldBB.
+ if ((BB != OldBB) && HoistBarrier.count(BB))
return true;
// Check that we do not move a store past loads.
- if (hasMemoryUse(NewPt, Def, *I))
+ if (hasMemoryUse(NewPt, Def, BB))
return true;
// -1 is unlimited number of blocks on all paths.
@@ -419,17 +445,18 @@ private:
// Decrement by 1 NBBsOnAllPaths for each block between HoistPt and BB, and
// return true when the counter NBBsOnAllPaths reaches 0, except when it is
// initialized to -1 which is unlimited.
- bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *BB,
+ bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *SrcBB,
int &NBBsOnAllPaths) {
- assert(DT->dominates(HoistPt, BB) && "Invalid path");
+ assert(DT->dominates(HoistPt, SrcBB) && "Invalid path");
// Walk all basic blocks reachable in depth-first iteration on
// the inverse CFG from BBInsn to NewHoistPt. These blocks are all the
// blocks that may be executed between the execution of NewHoistPt and
// BBInsn. Hoisting an expression from BBInsn into NewHoistPt has to be safe
// on all execution paths.
- for (auto I = idf_begin(BB), E = idf_end(BB); I != E;) {
- if (*I == HoistPt) {
+ for (auto I = idf_begin(SrcBB), E = idf_end(SrcBB); I != E;) {
+ const BasicBlock *BB = *I;
+ if (BB == HoistPt) {
// Stop traversal when reaching NewHoistPt.
I.skipChildren();
continue;
@@ -440,7 +467,13 @@ private:
return true;
// Impossible to hoist with exceptions on the path.
- if (hasEH(*I))
+ if (hasEH(BB))
+ return true;
+
+      // Instructions that follow a hoist barrier in a block are never
+      // selected for hoisting, so anything selected within a block that
+      // contains a barrier precedes the barrier and can still be hoisted;
+      // only bail out for barrier blocks other than SrcBB.
+ if ((BB != SrcBB) && HoistBarrier.count(BB))
return true;
// -1 is unlimited number of blocks on all paths.
@@ -626,6 +659,8 @@ private:
// Compute the insertion point and the list of expressions to be hoisted.
SmallVecInsn InstructionsToHoist;
for (auto I : V)
+ // We don't need to check for hoist-barriers here because if
+ // I->getParent() is a barrier then I precedes the barrier.
if (!hasEH(I->getParent()))
InstructionsToHoist.push_back(I);
@@ -809,9 +844,9 @@ private:
// legal when the ld/st is not moved past its current definition.
MemoryAccess *Def = OldMemAcc->getDefiningAccess();
NewMemAcc =
- MSSA->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End);
+ MSSAUpdater->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End);
OldMemAcc->replaceAllUsesWith(NewMemAcc);
- MSSA->removeMemoryAccess(OldMemAcc);
+ MSSAUpdater->removeMemoryAccess(OldMemAcc);
}
}
@@ -850,7 +885,7 @@ private:
// Update the uses of the old MSSA access with NewMemAcc.
MemoryAccess *OldMA = MSSA->getMemoryAccess(I);
OldMA->replaceAllUsesWith(NewMemAcc);
- MSSA->removeMemoryAccess(OldMA);
+ MSSAUpdater->removeMemoryAccess(OldMA);
}
Repl->andIRFlags(I);
@@ -872,7 +907,7 @@ private:
auto In = Phi->incoming_values();
if (all_of(In, [&](Use &U) { return U == NewMemAcc; })) {
Phi->replaceAllUsesWith(NewMemAcc);
- MSSA->removeMemoryAccess(Phi);
+ MSSAUpdater->removeMemoryAccess(Phi);
}
}
}
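The MemorySSA edits above now go through MemorySSAUpdater rather than mutating MemorySSA directly; a short sketch of the idiom, using only the updater calls shown in this diff (Repl, Def, HoistPt, and OldMemAcc stand in for the values GVNHoist computes):

  MemorySSAUpdater Updater(MSSA);
  // Create the hoisted instruction's access at the end of the hoist point so
  // MemorySSA stays consistent, redirect users, then retire the old access.
  MemoryAccess *NewMemAcc =
      Updater.createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End);
  OldMemAcc->replaceAllUsesWith(NewMemAcc);
  Updater.removeMemoryAccess(OldMemAcc);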
@@ -896,6 +931,12 @@ private:
for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {
int InstructionNb = 0;
for (Instruction &I1 : *BB) {
+      // If I1 cannot guarantee progress, subsequent instructions
+      // in BB cannot be hoisted anyway.
+ if (!isGuaranteedToTransferExecutionToSuccessor(&I1)) {
+ HoistBarrier.insert(BB);
+ break;
+ }
// Only hoist the first instructions in BB up to MaxDepthInBB. Hoisting
// deeper may increase the register pressure and compilation time.
if (MaxDepthInBB != -1 && InstructionNb++ >= MaxDepthInBB)
diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp
index b05ef002a4561..7019287954a15 100644
--- a/lib/Transforms/Scalar/GuardWidening.cpp
+++ b/lib/Transforms/Scalar/GuardWidening.cpp
@@ -568,8 +568,7 @@ bool GuardWideningImpl::combineRangeChecks(
return RC.getBase() == CurrentBase && RC.getLength() == CurrentLength;
};
- std::copy_if(Checks.begin(), Checks.end(),
- std::back_inserter(CurrentChecks), IsCurrentCheck);
+ copy_if(Checks, std::back_inserter(CurrentChecks), IsCurrentCheck);
Checks.erase(remove_if(Checks, IsCurrentCheck), Checks.end());
assert(CurrentChecks.size() != 0 && "We know we have at least one!");
@@ -658,8 +657,12 @@ PreservedAnalyses GuardWideningPass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &LI = AM.getResult<LoopAnalysis>(F);
auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
- bool Changed = GuardWideningImpl(DT, PDT, LI).run();
- return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+ if (!GuardWideningImpl(DT, PDT, LI).run())
+ return PreservedAnalyses::all();
+
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) {
diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp
index 1752fb75eb1b1..dcb2a4a0c6e6b 100644
--- a/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -231,8 +231,9 @@ static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) {
bool isExact = false;
// See if we can convert this to an int64_t
uint64_t UIntVal;
- if (APF.convertToInteger(&UIntVal, 64, true, APFloat::rmTowardZero,
- &isExact) != APFloat::opOK || !isExact)
+ if (APF.convertToInteger(makeMutableArrayRef(UIntVal), 64, true,
+ APFloat::rmTowardZero, &isExact) != APFloat::opOK ||
+ !isExact)
return false;
IntVal = UIntVal;
return true;
@@ -906,7 +907,7 @@ class WidenIV {
SmallVector<NarrowIVDefUse, 8> NarrowIVUsers;
enum ExtendKind { ZeroExtended, SignExtended, Unknown };
- // A map tracking the kind of extension used to widen each narrow IV
+ // A map tracking the kind of extension used to widen each narrow IV
// and narrow IV user.
// Key: pointer to a narrow IV or IV user.
// Value: the kind of extension used to widen this Instruction.
@@ -1608,7 +1609,7 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef,
return;
CmpInst::Predicate P =
- TrueDest ? Pred : CmpInst::getInversePredicate(Pred);
+ TrueDest ? Pred : CmpInst::getInversePredicate(Pred);
auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS));
auto CmpConstrainedLHSRange =
@@ -1634,7 +1635,7 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef,
UpdateRangeFromGuards(NarrowUser);
BasicBlock *NarrowUserBB = NarrowUser->getParent();
- // If NarrowUserBB is statically unreachable asking dominator queries may
+ // If NarrowUserBB is statically unreachable asking dominator queries may
// yield surprising results. (e.g. the block may not have a dom tree node)
if (!DT->isReachableFromEntry(NarrowUserBB))
return;
@@ -2152,6 +2153,8 @@ linearFunctionTestReplace(Loop *L,
Value *CmpIndVar = IndVar;
const SCEV *IVCount = BackedgeTakenCount;
+ assert(L->getLoopLatch() && "Loop no longer in simplified form?");
+
// If the exiting block is the same as the backedge block, we prefer to
// compare against the post-incremented value, otherwise we must compare
// against the preincremented value.
@@ -2376,6 +2379,7 @@ bool IndVarSimplify::run(Loop *L) {
// Loop::getCanonicalInductionVariable only supports loops with preheaders,
// and we're in trouble if we can't find the induction variable even when
// we've manually inserted one.
+ // - LFTR relies on having a single backedge.
if (!L->isLoopSimplifyForm())
return false;
@@ -2492,8 +2496,9 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
if (!IVS.run(&L))
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
- return getLoopPassPreservedAnalyses();
+ auto PA = getLoopPassPreservedAnalyses();
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
namespace {
diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
index 8e81541c23378..85db6e5e11052 100644
--- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -446,6 +446,15 @@ struct LoopStructure {
BasicBlock *LatchExit;
unsigned LatchBrExitIdx;
+ // The loop represented by this instance of LoopStructure is semantically
+ // equivalent to:
+ //
+ // intN_ty inc = IndVarIncreasing ? 1 : -1;
+ // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT;
+ //
+ // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarNext)
+ // ... body ...
+
Value *IndVarNext;
Value *IndVarStart;
Value *LoopExitAt;
@@ -789,6 +798,10 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP
return None;
}
+ const SCEV *StartNext = IndVarNext->getStart();
+ const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE));
+ const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
+
ConstantInt *One = ConstantInt::get(IndVarTy, 1);
// TODO: generalize the predicates here to also match their unsigned variants.
if (IsIncreasing) {
@@ -809,10 +822,22 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP
return None;
}
+ if (!SE.isLoopEntryGuardedByCond(
+ &L, CmpInst::ICMP_SLT, IndVarStart,
+ SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())))) {
+ FailureReason = "Induction variable start not bounded by upper limit";
+ return None;
+ }
+
IRBuilder<> B(Preheader->getTerminator());
RightValue = B.CreateAdd(RightValue, One);
+ } else {
+ if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SLT, IndVarStart,
+ RightSCEV)) {
+ FailureReason = "Induction variable start not bounded by upper limit";
+ return None;
+ }
}
-
} else {
bool FoundExpectedPred =
(Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) ||
@@ -831,15 +856,24 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP
return None;
}
+ if (!SE.isLoopEntryGuardedByCond(
+ &L, CmpInst::ICMP_SGT, IndVarStart,
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) {
+ FailureReason = "Induction variable start not bounded by lower limit";
+ return None;
+ }
+
IRBuilder<> B(Preheader->getTerminator());
RightValue = B.CreateSub(RightValue, One);
+ } else {
+ if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SGT, IndVarStart,
+ RightSCEV)) {
+ FailureReason = "Induction variable start not bounded by lower limit";
+ return None;
+ }
}
}
- const SCEV *StartNext = IndVarNext->getStart();
- const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE));
- const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
-
BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
assert(SE.getLoopDisposition(LatchCount, &L) ==
diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp
new file mode 100644
index 0000000000000..5d8701431a2ce
--- /dev/null
+++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -0,0 +1,903 @@
+//===-- InferAddressSpaces.cpp - Infer address spaces --------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// CUDA C/C++ includes memory space designation as variable type qualifiers (such
+// as __global__ and __shared__). Knowing the space of a memory access allows
+// CUDA compilers to emit faster PTX loads and stores. For example, a load from
+// shared memory can be translated to `ld.shared` which is roughly 10% faster
+// than a generic `ld` on an NVIDIA Tesla K40c.
+//
+// Unfortunately, type qualifiers only apply to variable declarations, so CUDA
+// compilers must infer the memory space of an address expression from
+// type-qualified variables.
+//
+// LLVM IR uses non-zero (so-called) specific address spaces to represent memory
+// spaces (e.g. addrspace(3) means shared memory). The Clang frontend
+// places only type-qualified variables in specific address spaces, and then
+// conservatively `addrspacecast`s each type-qualified variable to addrspace(0)
+// (the so-called generic address space) for other instructions to use.
+//
+// For example, Clang translates the following CUDA code
+// __shared__ float a[10];
+// float v = a[i];
+// to
+// %0 = addrspacecast [10 x float] addrspace(3)* @a to [10 x float]*
+// %1 = gep [10 x float], [10 x float]* %0, i64 0, i64 %i
+// %v = load float, float* %1 ; emits ld.f32
+// @a is in addrspace(3) since it's type-qualified, but its use from %1 is
+// redirected to %0 (the generic version of @a).
+//
+// The optimization implemented in this file propagates specific address spaces
+// from type-qualified variable declarations to its users. For example, it
+// optimizes the above IR to
+// %1 = gep [10 x float] addrspace(3)* @a, i64 0, i64 %i
+// %v = load float addrspace(3)* %1 ; emits ld.shared.f32
+// propagating the addrspace(3) from @a to %1. As a result, the NVPTX
+// codegen is able to emit ld.shared.f32 for %v.
+//
+// Address space inference works in two steps. First, it uses a data-flow
+// analysis to infer as many generic pointers as possible to point to only one
+// specific address space. In the above example, it can prove that %1 only
+// points to addrspace(3). This algorithm was published in
+// CUDA: Compiling and optimizing for a GPU platform
+// Chakrabarti, Grover, Aarts, Kong, Kudlur, Lin, Marathe, Murphy, Wang
+// ICCS 2012
+//
+// Then, address space inference replaces all refinable generic pointers with
+// equivalent specific pointers.
+//
+// The major challenge of implementing this optimization is handling PHINodes,
+// which may create loops in the data flow graph. This brings two complications.
+//
+// First, the data flow analysis in Step 1 needs to be circular. For example,
+// %generic.input = addrspacecast float addrspace(3)* %input to float*
+// loop:
+// %y = phi [ %generic.input, %y2 ]
+// %y2 = getelementptr %y, 1
+// %v = load %y2
+// br ..., label %loop, ...
+// proving %y specific requires proving both %generic.input and %y2 specific,
+// but proving %y2 specific circles back to %y. To address this complication,
+// the data flow analysis operates on a lattice:
+// uninitialized > specific address spaces > generic.
+// All address expressions (our implementation only considers phi, bitcast,
+// addrspacecast, and getelementptr) start with the uninitialized address space.
+// The monotone transfer function moves the address space of a pointer down a
+// lattice path from uninitialized to specific and then to generic. A join
+// operation of two different specific address spaces pushes the expression down
+// to the generic address space. The analysis completes once it reaches a fixed
+// point.
+//
+// Second, IR rewriting in Step 2 also needs to be circular. For example,
+// converting %y to addrspace(3) requires the compiler to know the converted
+// %y2, but converting %y2 needs the converted %y. To address this complication,
+// we break these cycles using "undef" placeholders. When converting an
+// instruction `I` to a new address space, if its operand `Op` is not converted
+// yet, we let `I` temporarily use `undef` and fix all the uses of undef later.
+// For instance, our algorithm first converts %y to
+// %y' = phi float addrspace(3)* [ %input, undef ]
+// Then, it converts %y2 to
+// %y2' = getelementptr %y', 1
+// Finally, it fixes the undef in %y' so that
+// %y' = phi float addrspace(3)* [ %input, %y2' ]
+//
+//===----------------------------------------------------------------------===//
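A sketch (assumed semantics, not necessarily the pass's exact code) of the lattice join described above, where UninitializedAddressSpace is the top element and any disagreement between specific spaces falls back to the flat/generic space:

  static unsigned joinAddressSpaces(unsigned AS1, unsigned AS2,
                                    unsigned FlatAddrSpace) {
    if (AS1 == UninitializedAddressSpace)
      return AS2;
    if (AS2 == UninitializedAddressSpace)
      return AS1;
    // Two different specific spaces (or anything joined with flat) -> flat.
    return AS1 == AS2 ? AS1 : FlatAddrSpace;
  }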
+
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+
+#define DEBUG_TYPE "infer-address-spaces"
+
+using namespace llvm;
+
+namespace {
+static const unsigned UninitializedAddressSpace = ~0u;
+
+using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>;
+
+/// \brief InferAddressSpaces
+class InferAddressSpaces : public FunctionPass {
+  /// Target-specific address space whose uses should be replaced if
+ /// possible.
+ unsigned FlatAddrSpace;
+
+public:
+ static char ID;
+
+ InferAddressSpaces() : FunctionPass(ID) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+ }
+
+ bool runOnFunction(Function &F) override;
+
+private:
+ // Returns the new address space of V if updated; otherwise, returns None.
+ Optional<unsigned>
+ updateAddressSpace(const Value &V,
+ const ValueToAddrSpaceMapTy &InferredAddrSpace) const;
+
+ // Tries to infer the specific address space of each address expression in
+ // Postorder.
+ void inferAddressSpaces(const std::vector<Value *> &Postorder,
+ ValueToAddrSpaceMapTy *InferredAddrSpace) const;
+
+ bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;
+
+ // Changes the flat address expressions in function F to point to specific
+ // address spaces if InferredAddrSpace says so. Postorder is the postorder of
+ // all flat expressions in the use-def graph of function F.
+ bool
+ rewriteWithNewAddressSpaces(const std::vector<Value *> &Postorder,
+ const ValueToAddrSpaceMapTy &InferredAddrSpace,
+ Function *F) const;
+
+ void appendsFlatAddressExpressionToPostorderStack(
+ Value *V, std::vector<std::pair<Value *, bool>> *PostorderStack,
+ DenseSet<Value *> *Visited) const;
+
+ bool rewriteIntrinsicOperands(IntrinsicInst *II,
+ Value *OldV, Value *NewV) const;
+ void collectRewritableIntrinsicOperands(
+ IntrinsicInst *II,
+ std::vector<std::pair<Value *, bool>> *PostorderStack,
+ DenseSet<Value *> *Visited) const;
+
+ std::vector<Value *> collectFlatAddressExpressions(Function &F) const;
+
+ Value *cloneValueWithNewAddressSpace(
+ Value *V, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace,
+ SmallVectorImpl<const Use *> *UndefUsesToFix) const;
+ unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const;
+};
+} // end anonymous namespace
+
+char InferAddressSpaces::ID = 0;
+
+namespace llvm {
+void initializeInferAddressSpacesPass(PassRegistry &);
+}
+
+INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
+ false, false)
+
+// Returns true if V is an address expression.
+// TODO: Currently, we consider only phi, bitcast, addrspacecast, select, and
+// getelementptr operators.
+static bool isAddressExpression(const Value &V) {
+ if (!isa<Operator>(V))
+ return false;
+
+ switch (cast<Operator>(V).getOpcode()) {
+ case Instruction::PHI:
+ case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
+ case Instruction::GetElementPtr:
+ case Instruction::Select:
+ return true;
+ default:
+ return false;
+ }
+}
+
+// Returns the pointer operands of V.
+//
+// Precondition: V is an address expression.
+static SmallVector<Value *, 2> getPointerOperands(const Value &V) {
+ assert(isAddressExpression(V));
+ const Operator &Op = cast<Operator>(V);
+ switch (Op.getOpcode()) {
+ case Instruction::PHI: {
+ auto IncomingValues = cast<PHINode>(Op).incoming_values();
+ return SmallVector<Value *, 2>(IncomingValues.begin(),
+ IncomingValues.end());
+ }
+ case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
+ case Instruction::GetElementPtr:
+ return {Op.getOperand(0)};
+ case Instruction::Select:
+ return {Op.getOperand(1), Op.getOperand(2)};
+ default:
+ llvm_unreachable("Unexpected instruction type.");
+ }
+}
+
+// TODO: Move logic to TTI?
+bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
+ Value *OldV,
+ Value *NewV) const {
+ Module *M = II->getParent()->getParent()->getParent();
+
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::amdgcn_atomic_inc:
+ case Intrinsic::amdgcn_atomic_dec:{
+ const ConstantInt *IsVolatile = dyn_cast<ConstantInt>(II->getArgOperand(4));
+ if (!IsVolatile || !IsVolatile->isNullValue())
+ return false;
+
+ LLVM_FALLTHROUGH;
+ }
+ case Intrinsic::objectsize: {
+ Type *DestTy = II->getType();
+ Type *SrcTy = NewV->getType();
+ Function *NewDecl =
+ Intrinsic::getDeclaration(M, II->getIntrinsicID(), {DestTy, SrcTy});
+ II->setArgOperand(0, NewV);
+ II->setCalledFunction(NewDecl);
+ return true;
+ }
+ default:
+ return false;
+ }
+}
+
+// TODO: Move logic to TTI?
+void InferAddressSpaces::collectRewritableIntrinsicOperands(
+ IntrinsicInst *II, std::vector<std::pair<Value *, bool>> *PostorderStack,
+ DenseSet<Value *> *Visited) const {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::objectsize:
+ case Intrinsic::amdgcn_atomic_inc:
+ case Intrinsic::amdgcn_atomic_dec:
+ appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
+ PostorderStack, Visited);
+ break;
+ default:
+ break;
+ }
+}
+
+// If V is an unvisited flat address expression, appends V to PostorderStack
+// and marks it as visited.
+void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack(
+ Value *V, std::vector<std::pair<Value *, bool>> *PostorderStack,
+ DenseSet<Value *> *Visited) const {
+ assert(V->getType()->isPointerTy());
+ if (isAddressExpression(*V) &&
+ V->getType()->getPointerAddressSpace() == FlatAddrSpace) {
+ if (Visited->insert(V).second)
+ PostorderStack->push_back(std::make_pair(V, false));
+ }
+}
+
+// Returns all flat address expressions in function F. The elements are
+// ordered in postorder.
+std::vector<Value *>
+InferAddressSpaces::collectFlatAddressExpressions(Function &F) const {
+ // This function implements a non-recursive postorder traversal of a partial
+ // use-def graph of function F.
+ std::vector<std::pair<Value *, bool>> PostorderStack;
+ // The set of visited expressions.
+ DenseSet<Value *> Visited;
+
+ auto PushPtrOperand = [&](Value *Ptr) {
+ appendsFlatAddressExpressionToPostorderStack(Ptr, &PostorderStack,
+ &Visited);
+ };
+
+ // We only explore address expressions that are reachable from loads and
+ // stores for now because we aim at generating faster loads and stores.
+ for (Instruction &I : instructions(F)) {
+ if (auto *LI = dyn_cast<LoadInst>(&I))
+ PushPtrOperand(LI->getPointerOperand());
+ else if (auto *SI = dyn_cast<StoreInst>(&I))
+ PushPtrOperand(SI->getPointerOperand());
+ else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
+ PushPtrOperand(RMW->getPointerOperand());
+ else if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(&I))
+ PushPtrOperand(CmpX->getPointerOperand());
+ else if (auto *MI = dyn_cast<MemIntrinsic>(&I)) {
+ // For memset/memcpy/memmove, any pointer operand can be replaced.
+ PushPtrOperand(MI->getRawDest());
+
+ // Handle 2nd operand for memcpy/memmove.
+ if (auto *MTI = dyn_cast<MemTransferInst>(MI))
+ PushPtrOperand(MTI->getRawSource());
+ } else if (auto *II = dyn_cast<IntrinsicInst>(&I))
+ collectRewritableIntrinsicOperands(II, &PostorderStack, &Visited);
+ else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) {
+ // FIXME: Handle vectors of pointers
+ if (Cmp->getOperand(0)->getType()->isPointerTy()) {
+ PushPtrOperand(Cmp->getOperand(0));
+ PushPtrOperand(Cmp->getOperand(1));
+ }
+ }
+ }
+
+ std::vector<Value *> Postorder; // The resultant postorder.
+ while (!PostorderStack.empty()) {
+    // If the operands of the expression on top of the stack are already
+    // explored, adds that expression to the resultant postorder.
+ if (PostorderStack.back().second) {
+ Postorder.push_back(PostorderStack.back().first);
+ PostorderStack.pop_back();
+ continue;
+ }
+ // Otherwise, adds its operands to the stack and explores them.
+ PostorderStack.back().second = true;
+ for (Value *PtrOperand : getPointerOperands(*PostorderStack.back().first)) {
+ appendsFlatAddressExpressionToPostorderStack(PtrOperand, &PostorderStack,
+ &Visited);
+ }
+ }
+ return Postorder;
+}
+
+// A helper function for cloneInstructionWithNewAddressSpace. Returns the clone
+// of OperandUse.get() in the new address space. If the clone is not ready yet,
+// returns an undef in the new address space as a placeholder.
+static Value *operandWithNewAddressSpaceOrCreateUndef(
+ const Use &OperandUse, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace,
+ SmallVectorImpl<const Use *> *UndefUsesToFix) {
+ Value *Operand = OperandUse.get();
+
+ Type *NewPtrTy =
+ Operand->getType()->getPointerElementType()->getPointerTo(NewAddrSpace);
+
+ if (Constant *C = dyn_cast<Constant>(Operand))
+ return ConstantExpr::getAddrSpaceCast(C, NewPtrTy);
+
+ if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand))
+ return NewOperand;
+
+ UndefUsesToFix->push_back(&OperandUse);
+ return UndefValue::get(NewPtrTy);
+}
+
+// Returns a clone of `I` with its operands converted to those specified in
+// ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an
+// operand whose address space needs to be modified might not exist in
+// ValueWithNewAddrSpace. In that case, uses undef as a placeholder operand and
+// adds that operand use to UndefUsesToFix so the caller can fix them later.
+//
+// Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast
+// from a pointer whose type already matches. Therefore, this function returns a
+// Value* instead of an Instruction*.
+static Value *cloneInstructionWithNewAddressSpace(
+ Instruction *I, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace,
+ SmallVectorImpl<const Use *> *UndefUsesToFix) {
+ Type *NewPtrType =
+ I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace);
+
+ if (I->getOpcode() == Instruction::AddrSpaceCast) {
+ Value *Src = I->getOperand(0);
+ // Because `I` is flat, the source address space must be specific.
+ // Therefore, the inferred address space must be the source space, according
+ // to our algorithm.
+ assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+ if (Src->getType() != NewPtrType)
+ return new BitCastInst(Src, NewPtrType);
+ return Src;
+ }
+
+ // Computes the converted pointer operands.
+ SmallVector<Value *, 4> NewPointerOperands;
+ for (const Use &OperandUse : I->operands()) {
+ if (!OperandUse.get()->getType()->isPointerTy())
+ NewPointerOperands.push_back(nullptr);
+ else
+ NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef(
+ OperandUse, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix));
+ }
+
+ switch (I->getOpcode()) {
+ case Instruction::BitCast:
+ return new BitCastInst(NewPointerOperands[0], NewPtrType);
+ case Instruction::PHI: {
+ assert(I->getType()->isPointerTy());
+ PHINode *PHI = cast<PHINode>(I);
+ PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues());
+ for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) {
+ unsigned OperandNo = PHINode::getOperandNumForIncomingValue(Index);
+ NewPHI->addIncoming(NewPointerOperands[OperandNo],
+ PHI->getIncomingBlock(Index));
+ }
+ return NewPHI;
+ }
+ case Instruction::GetElementPtr: {
+ GetElementPtrInst *GEP = cast<GetElementPtrInst>(I);
+ GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
+ GEP->getSourceElementType(), NewPointerOperands[0],
+ SmallVector<Value *, 4>(GEP->idx_begin(), GEP->idx_end()));
+ NewGEP->setIsInBounds(GEP->isInBounds());
+ return NewGEP;
+ }
+ case Instruction::Select: {
+ assert(I->getType()->isPointerTy());
+ return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],
+ NewPointerOperands[2], "", nullptr, I);
+ }
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+}
+
+// Similar to cloneInstructionWithNewAddressSpace, returns a clone of the
+// constant expression `CE` with its operands replaced as specified in
+// ValueWithNewAddrSpace.
+static Value *cloneConstantExprWithNewAddressSpace(
+ ConstantExpr *CE, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace) {
+ Type *TargetType =
+ CE->getType()->getPointerElementType()->getPointerTo(NewAddrSpace);
+
+ if (CE->getOpcode() == Instruction::AddrSpaceCast) {
+ // Because CE is flat, the source address space must be specific.
+ // Therefore, the inferred address space must be the source space according
+ // to our algorithm.
+ assert(CE->getOperand(0)->getType()->getPointerAddressSpace() ==
+ NewAddrSpace);
+ return ConstantExpr::getBitCast(CE->getOperand(0), TargetType);
+ }
+
+ if (CE->getOpcode() == Instruction::BitCast) {
+ if (Value *NewOperand = ValueWithNewAddrSpace.lookup(CE->getOperand(0)))
+ return ConstantExpr::getBitCast(cast<Constant>(NewOperand), TargetType);
+ return ConstantExpr::getAddrSpaceCast(CE, TargetType);
+ }
+
+ if (CE->getOpcode() == Instruction::Select) {
+ Constant *Src0 = CE->getOperand(1);
+ Constant *Src1 = CE->getOperand(2);
+ if (Src0->getType()->getPointerAddressSpace() ==
+ Src1->getType()->getPointerAddressSpace()) {
+
+ return ConstantExpr::getSelect(
+ CE->getOperand(0), ConstantExpr::getAddrSpaceCast(Src0, TargetType),
+ ConstantExpr::getAddrSpaceCast(Src1, TargetType));
+ }
+ }
+
+ // Computes the operands of the new constant expression.
+ SmallVector<Constant *, 4> NewOperands;
+ for (unsigned Index = 0; Index < CE->getNumOperands(); ++Index) {
+ Constant *Operand = CE->getOperand(Index);
+ // If the address space of `Operand` needs to be modified, the new operand
+ // with the new address space should already be in ValueWithNewAddrSpace
+ // because (1) the constant expressions we consider (i.e. addrspacecast,
+ // bitcast, and getelementptr) do not incur cycles in the data flow graph
+ // and (2) this function is called on constant expressions in postorder.
+ if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) {
+ NewOperands.push_back(cast<Constant>(NewOperand));
+ } else {
+ // Otherwise, reuses the old operand.
+ NewOperands.push_back(Operand);
+ }
+ }
+
+ if (CE->getOpcode() == Instruction::GetElementPtr) {
+ // Needs to specify the source type while constructing a getelementptr
+ // constant expression.
+ return CE->getWithOperands(
+ NewOperands, TargetType, /*OnlyIfReduced=*/false,
+ NewOperands[0]->getType()->getPointerElementType());
+ }
+
+ return CE->getWithOperands(NewOperands, TargetType);
+}
+
+// Returns a clone of the value `V`, with its operands replaced as specified in
+// ValueWithNewAddrSpace. This function is called on every flat address
+// expression whose address space needs to be modified, in postorder.
+//
+// See cloneInstructionWithNewAddressSpace for the meaning of UndefUsesToFix.
+Value *InferAddressSpaces::cloneValueWithNewAddressSpace(
+ Value *V, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace,
+ SmallVectorImpl<const Use *> *UndefUsesToFix) const {
+ // All values in Postorder are flat address expressions.
+ assert(isAddressExpression(*V) &&
+ V->getType()->getPointerAddressSpace() == FlatAddrSpace);
+
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ Value *NewV = cloneInstructionWithNewAddressSpace(
+ I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix);
+ if (Instruction *NewI = dyn_cast<Instruction>(NewV)) {
+ if (NewI->getParent() == nullptr) {
+ NewI->insertBefore(I);
+ NewI->takeName(I);
+ }
+ }
+ return NewV;
+ }
+
+ return cloneConstantExprWithNewAddressSpace(
+ cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace);
+}
+
+// Defines the join operation on the address space lattice (see the file header
+// comments).
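+//
+// For example, join(AS, AS) = AS, join(uninitialized, AS) = AS, and the join
+// of two distinct specific address spaces is flat.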
+unsigned InferAddressSpaces::joinAddressSpaces(unsigned AS1,
+ unsigned AS2) const {
+ if (AS1 == FlatAddrSpace || AS2 == FlatAddrSpace)
+ return FlatAddrSpace;
+
+ if (AS1 == UninitializedAddressSpace)
+ return AS2;
+ if (AS2 == UninitializedAddressSpace)
+ return AS1;
+
+ // The join of two different specific address spaces is flat.
+ return (AS1 == AS2) ? AS1 : FlatAddrSpace;
+}
+
+bool InferAddressSpaces::runOnFunction(Function &F) {
+ if (skipFunction(F))
+ return false;
+
+ const TargetTransformInfo &TTI =
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ FlatAddrSpace = TTI.getFlatAddressSpace();
+ if (FlatAddrSpace == UninitializedAddressSpace)
+ return false;
+
+ // Collects all flat address expressions in postorder.
+ std::vector<Value *> Postorder = collectFlatAddressExpressions(F);
+
+ // Runs a data-flow analysis to refine the address spaces of every expression
+ // in Postorder.
+ ValueToAddrSpaceMapTy InferredAddrSpace;
+ inferAddressSpaces(Postorder, &InferredAddrSpace);
+
+ // Changes the address spaces of the flat address expressions that are
+ // inferred to point to a specific address space.
+ return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, &F);
+}
+
+void InferAddressSpaces::inferAddressSpaces(
+ const std::vector<Value *> &Postorder,
+ ValueToAddrSpaceMapTy *InferredAddrSpace) const {
+ SetVector<Value *> Worklist(Postorder.begin(), Postorder.end());
+ // Initially, all expressions are in the uninitialized address space.
+ for (Value *V : Postorder)
+ (*InferredAddrSpace)[V] = UninitializedAddressSpace;
+
+ while (!Worklist.empty()) {
+ Value *V = Worklist.pop_back_val();
+
+ // Tries to update the address space of V according to the address spaces of
+ // its operands.
+ DEBUG(dbgs() << "Updating the address space of\n " << *V << '\n');
+ Optional<unsigned> NewAS = updateAddressSpace(*V, *InferredAddrSpace);
+ if (!NewAS.hasValue())
+ continue;
+ // If the update succeeds, adds the users of V to the worklist because their
+ // address spaces may also need to be updated.
+ DEBUG(dbgs() << " to " << NewAS.getValue() << '\n');
+ (*InferredAddrSpace)[V] = NewAS.getValue();
+
+ for (Value *User : V->users()) {
+ // Skip if User is already in the worklist.
+ if (Worklist.count(User))
+ continue;
+
+ auto Pos = InferredAddrSpace->find(User);
+ // Our algorithm only updates the address spaces of flat address
+ // expressions, which are those in InferredAddrSpace.
+ if (Pos == InferredAddrSpace->end())
+ continue;
+
+ // Function updateAddressSpace moves the address space down a lattice
+ // path. Therefore, nothing to do if User is already inferred as flat (the
+ // bottom element in the lattice).
+ if (Pos->second == FlatAddrSpace)
+ continue;
+
+ Worklist.insert(User);
+ }
+ }
+}
+
+Optional<unsigned> InferAddressSpaces::updateAddressSpace(
+ const Value &V, const ValueToAddrSpaceMapTy &InferredAddrSpace) const {
+ assert(InferredAddrSpace.count(&V));
+
+ // The new inferred address space equals the join of the address spaces
+ // of all its pointer operands.
+ unsigned NewAS = UninitializedAddressSpace;
+
+ const Operator &Op = cast<Operator>(V);
+ if (Op.getOpcode() == Instruction::Select) {
+ Value *Src0 = Op.getOperand(1);
+ Value *Src1 = Op.getOperand(2);
+
+ auto I = InferredAddrSpace.find(Src0);
+ unsigned Src0AS = (I != InferredAddrSpace.end()) ?
+ I->second : Src0->getType()->getPointerAddressSpace();
+
+ auto J = InferredAddrSpace.find(Src1);
+ unsigned Src1AS = (J != InferredAddrSpace.end()) ?
+ J->second : Src1->getType()->getPointerAddressSpace();
+
+ auto *C0 = dyn_cast<Constant>(Src0);
+ auto *C1 = dyn_cast<Constant>(Src1);
+
+ // If one of the inputs is a constant, we may be able to do a constant
+ // addrspacecast of it. Defer inferring the address space until the input
+ // address space is known.
+ if ((C1 && Src0AS == UninitializedAddressSpace) ||
+ (C0 && Src1AS == UninitializedAddressSpace))
+ return None;
+
+ if (C0 && isSafeToCastConstAddrSpace(C0, Src1AS))
+ NewAS = Src1AS;
+ else if (C1 && isSafeToCastConstAddrSpace(C1, Src0AS))
+ NewAS = Src0AS;
+ else
+ NewAS = joinAddressSpaces(Src0AS, Src1AS);
+ } else {
+ for (Value *PtrOperand : getPointerOperands(V)) {
+ auto I = InferredAddrSpace.find(PtrOperand);
+ unsigned OperandAS = I != InferredAddrSpace.end() ?
+ I->second : PtrOperand->getType()->getPointerAddressSpace();
+
+ // join(flat, *) = flat. So we can break if NewAS is already flat.
+ NewAS = joinAddressSpaces(NewAS, OperandAS);
+ if (NewAS == FlatAddrSpace)
+ break;
+ }
+ }
+
+ unsigned OldAS = InferredAddrSpace.lookup(&V);
+ assert(OldAS != FlatAddrSpace);
+ if (OldAS == NewAS)
+ return None;
+ return NewAS;
+}
+
+/// \returns true if \p U is the pointer operand of a memory instruction with
+/// a single pointer operand that can have its address space changed by simply
+/// mutating the use to a new value.
+static bool isSimplePointerUseValidToReplace(Use &U) {
+ User *Inst = U.getUser();
+ unsigned OpNo = U.getOperandNo();
+
+ if (auto *LI = dyn_cast<LoadInst>(Inst))
+ return OpNo == LoadInst::getPointerOperandIndex() && !LI->isVolatile();
+
+ if (auto *SI = dyn_cast<StoreInst>(Inst))
+ return OpNo == StoreInst::getPointerOperandIndex() && !SI->isVolatile();
+
+ if (auto *RMW = dyn_cast<AtomicRMWInst>(Inst))
+ return OpNo == AtomicRMWInst::getPointerOperandIndex() && !RMW->isVolatile();
+
+ if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) {
+ return OpNo == AtomicCmpXchgInst::getPointerOperandIndex() &&
+ !CmpX->isVolatile();
+ }
+
+ return false;
+}
+
+/// Update memory intrinsic uses that require more complex processing than
+/// simple memory instructions. These require re-mangling and may have multiple
+/// pointer operands.
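+///
+/// For example, changing the address space of a memcpy operand changes the
+/// intrinsic's pointer-type mangling, so the call is rebuilt with IRBuilder
+/// (preserving TBAA, scope, and noalias metadata) and the old call is erased.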
+static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV,
+ Value *NewV) {
+ IRBuilder<> B(MI);
+ MDNode *TBAA = MI->getMetadata(LLVMContext::MD_tbaa);
+ MDNode *ScopeMD = MI->getMetadata(LLVMContext::MD_alias_scope);
+ MDNode *NoAliasMD = MI->getMetadata(LLVMContext::MD_noalias);
+
+ if (auto *MSI = dyn_cast<MemSetInst>(MI)) {
+ B.CreateMemSet(NewV, MSI->getValue(),
+ MSI->getLength(), MSI->getAlignment(),
+ false, // isVolatile
+ TBAA, ScopeMD, NoAliasMD);
+ } else if (auto *MTI = dyn_cast<MemTransferInst>(MI)) {
+ Value *Src = MTI->getRawSource();
+ Value *Dest = MTI->getRawDest();
+
+ // Be careful in case this is a self-to-self copy.
+ if (Src == OldV)
+ Src = NewV;
+
+ if (Dest == OldV)
+ Dest = NewV;
+
+ if (isa<MemCpyInst>(MTI)) {
+ MDNode *TBAAStruct = MTI->getMetadata(LLVMContext::MD_tbaa_struct);
+ B.CreateMemCpy(Dest, Src, MTI->getLength(),
+ MTI->getAlignment(),
+ false, // isVolatile
+ TBAA, TBAAStruct, ScopeMD, NoAliasMD);
+ } else {
+ assert(isa<MemMoveInst>(MTI));
+ B.CreateMemMove(Dest, Src, MTI->getLength(),
+ MTI->getAlignment(),
+ false, // isVolatile
+ TBAA, ScopeMD, NoAliasMD);
+ }
+ } else
+ llvm_unreachable("unhandled MemIntrinsic");
+
+ MI->eraseFromParent();
+ return true;
+}
+
+// \returns true if it is OK to change the address space of constant \p C with
+// a ConstantExpr addrspacecast.
+bool InferAddressSpaces::isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const {
+ assert(NewAS != UninitializedAddressSpace);
+
+ unsigned SrcAS = C->getType()->getPointerAddressSpace();
+ if (SrcAS == NewAS || isa<UndefValue>(C))
+ return true;
+
+ // Prevent illegal casts between different non-flat address spaces.
+ if (SrcAS != FlatAddrSpace && NewAS != FlatAddrSpace)
+ return false;
+
+ if (isa<ConstantPointerNull>(C))
+ return true;
+
+ if (auto *Op = dyn_cast<Operator>(C)) {
+ // If we already have a constant addrspacecast, it should be safe to cast it
+ // off.
+ if (Op->getOpcode() == Instruction::AddrSpaceCast)
+ return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS);
+
+ if (Op->getOpcode() == Instruction::IntToPtr &&
+ Op->getType()->getPointerAddressSpace() == FlatAddrSpace)
+ return true;
+ }
+
+ return false;
+}
+
+static Value::use_iterator skipToNextUser(Value::use_iterator I,
+ Value::use_iterator End) {
+ User *CurUser = I->getUser();
+ ++I;
+
+ while (I != End && I->getUser() == CurUser)
+ ++I;
+
+ return I;
+}
+
+bool InferAddressSpaces::rewriteWithNewAddressSpaces(
+ const std::vector<Value *> &Postorder,
+ const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const {
+ // For each address expression to be modified, creates a clone of it with its
+ // pointer operands converted to the new address space. Since the pointer
+ // operands are converted, the clone is naturally in the new address space by
+ // construction.
+ ValueToValueMapTy ValueWithNewAddrSpace;
+ SmallVector<const Use *, 32> UndefUsesToFix;
+ for (Value* V : Postorder) {
+ unsigned NewAddrSpace = InferredAddrSpace.lookup(V);
+ if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
+ ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace(
+ V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix);
+ }
+ }
+
+ if (ValueWithNewAddrSpace.empty())
+ return false;
+
+ // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace.
+ for (const Use *UndefUse : UndefUsesToFix) {
+ User *V = UndefUse->getUser();
+ User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V));
+ unsigned OperandNo = UndefUse->getOperandNo();
+ assert(isa<UndefValue>(NewV->getOperand(OperandNo)));
+ NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get()));
+ }
+
+ // Replaces the uses of the old address expressions with the new ones.
+ for (Value *V : Postorder) {
+ Value *NewV = ValueWithNewAddrSpace.lookup(V);
+ if (NewV == nullptr)
+ continue;
+
+ DEBUG(dbgs() << "Replacing the uses of " << *V
+ << "\n with\n " << *NewV << '\n');
+
+ Value::use_iterator I, E, Next;
+ for (I = V->use_begin(), E = V->use_end(); I != E; ) {
+ Use &U = *I;
+
+ // Some users may see the same pointer operand in multiple operands. Skip
+ // to the next instruction.
+ I = skipToNextUser(I, E);
+
+ if (isSimplePointerUseValidToReplace(U)) {
+ // If V is used as the pointer operand of a compatible memory operation,
+ // sets the pointer operand to NewV. This replacement does not change
+ // the element type, so the resultant load/store is still valid.
+ U.set(NewV);
+ continue;
+ }
+
+ User *CurUser = U.getUser();
+ // Handle more complex cases like intrinsics that need to be remangled.
+ if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
+ if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
+ continue;
+ }
+
+ if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
+ if (rewriteIntrinsicOperands(II, V, NewV))
+ continue;
+ }
+
+ if (isa<Instruction>(CurUser)) {
+ if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) {
+ // If we can infer that both pointers are in the same addrspace,
+ // transform e.g.
+ // %cmp = icmp eq float* %p, %q
+ // into
+ // %cmp = icmp eq float addrspace(3)* %new_p, %new_q
+
+ unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+ int SrcIdx = U.getOperandNo();
+ int OtherIdx = (SrcIdx == 0) ? 1 : 0;
+ Value *OtherSrc = Cmp->getOperand(OtherIdx);
+
+ if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
+ if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
+ Cmp->setOperand(OtherIdx, OtherNewV);
+ Cmp->setOperand(SrcIdx, NewV);
+ continue;
+ }
+ }
+
+ // Even if the type mismatches, we can cast the constant.
+ if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
+ if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
+ Cmp->setOperand(SrcIdx, NewV);
+ Cmp->setOperand(OtherIdx,
+ ConstantExpr::getAddrSpaceCast(KOtherSrc, NewV->getType()));
+ continue;
+ }
+ }
+ }
+
+ // Otherwise, replaces the use with flat(NewV).
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ BasicBlock::iterator InsertPos = std::next(I->getIterator());
+ while (isa<PHINode>(InsertPos))
+ ++InsertPos;
+ U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos));
+ } else {
+ U.set(ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
+ V->getType()));
+ }
+ }
+ }
+
+ if (V->use_empty())
+ RecursivelyDeleteTriviallyDeadInstructions(V);
+ }
+
+ return true;
+}
+
+FunctionPass *llvm::createInferAddressSpacesPass() {
+ return new InferAddressSpaces();
+}
diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp
index 1870c3deb4f33..08eb95a1a3d3e 100644
--- a/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/lib/Transforms/Scalar/JumpThreading.cpp
@@ -17,6 +17,7 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
@@ -30,11 +31,13 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include <algorithm>
@@ -89,6 +92,7 @@ namespace {
bool runOnFunction(Function &F) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AAResultsWrapperPass>();
AU.addRequired<LazyValueInfoWrapperPass>();
AU.addPreserved<LazyValueInfoWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
@@ -104,6 +108,7 @@ INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading",
"Jump Threading", false, false)
INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(JumpThreading, "jump-threading",
"Jump Threading", false, false)
@@ -121,6 +126,7 @@ bool JumpThreading::runOnFunction(Function &F) {
return false;
auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
+ auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
std::unique_ptr<BlockFrequencyInfo> BFI;
std::unique_ptr<BranchProbabilityInfo> BPI;
bool HasProfileData = F.getEntryCount().hasValue();
@@ -129,7 +135,8 @@ bool JumpThreading::runOnFunction(Function &F) {
BPI.reset(new BranchProbabilityInfo(F, LI));
BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
}
- return Impl.runImpl(F, TLI, LVI, HasProfileData, std::move(BFI),
+
+ return Impl.runImpl(F, TLI, LVI, AA, HasProfileData, std::move(BFI),
std::move(BPI));
}
@@ -138,6 +145,8 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &LVI = AM.getResult<LazyValueAnalysis>(F);
+ auto &AA = AM.getResult<AAManager>(F);
+
std::unique_ptr<BlockFrequencyInfo> BFI;
std::unique_ptr<BranchProbabilityInfo> BPI;
bool HasProfileData = F.getEntryCount().hasValue();
@@ -146,12 +155,9 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,
BPI.reset(new BranchProbabilityInfo(F, LI));
BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
}
- bool Changed =
- runImpl(F, &TLI, &LVI, HasProfileData, std::move(BFI), std::move(BPI));
- // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better
- // solution?
- AM.invalidate<LazyValueAnalysis>(F);
+ bool Changed = runImpl(F, &TLI, &LVI, &AA, HasProfileData, std::move(BFI),
+ std::move(BPI));
if (!Changed)
return PreservedAnalyses::all();
@@ -161,18 +167,23 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,
}
bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
- LazyValueInfo *LVI_, bool HasProfileData_,
+ LazyValueInfo *LVI_, AliasAnalysis *AA_,
+ bool HasProfileData_,
std::unique_ptr<BlockFrequencyInfo> BFI_,
std::unique_ptr<BranchProbabilityInfo> BPI_) {
DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n");
TLI = TLI_;
LVI = LVI_;
+ AA = AA_;
BFI.reset();
BPI.reset();
// When profile data is available, we need to update edge weights after
// successful jump threading, which requires both BPI and BFI being available.
HasProfileData = HasProfileData_;
+ auto *GuardDecl = F.getParent()->getFunction(
+ Intrinsic::getName(Intrinsic::experimental_guard));
+ HasGuards = GuardDecl && !GuardDecl->use_empty();
if (HasProfileData) {
BPI = std::move(BPI_);
BFI = std::move(BFI_);
@@ -226,26 +237,13 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
BB != &BB->getParent()->getEntryBlock() &&
// If the terminator is the only non-phi instruction, try to nuke it.
BB->getFirstNonPHIOrDbg()->isTerminator() && !LoopHeaders.count(BB)) {
- // Since TryToSimplifyUncondBranchFromEmptyBlock may delete the
- // block, we have to make sure it isn't in the LoopHeaders set. We
- // reinsert afterward if needed.
- bool ErasedFromLoopHeaders = LoopHeaders.erase(BB);
- BasicBlock *Succ = BI->getSuccessor(0);
-
// FIXME: It is always conservatively correct to drop the info
// for a block even if it doesn't get erased. This isn't totally
// awesome, but it allows us to use AssertingVH to prevent nasty
// dangling pointer issues within LazyValueInfo.
LVI->eraseBlock(BB);
- if (TryToSimplifyUncondBranchFromEmptyBlock(BB)) {
+ if (TryToSimplifyUncondBranchFromEmptyBlock(BB))
Changed = true;
- // If we deleted BB and BB was the header of a loop, then the
- // successor is now the header of the loop.
- BB = Succ;
- }
-
- if (ErasedFromLoopHeaders)
- LoopHeaders.insert(BB);
}
}
EverChanged |= Changed;
@@ -255,10 +253,13 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
return EverChanged;
}
-/// getJumpThreadDuplicationCost - Return the cost of duplicating this block to
-/// thread across it. Stop scanning the block when passing the threshold.
-static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB,
+/// Return the cost of duplicating a piece of this block from first non-phi
+/// and before StopAt instruction to thread across it. Stop scanning the block
+/// when exceeding the threshold. If duplication is impossible, returns ~0U.
+static unsigned getJumpThreadDuplicationCost(BasicBlock *BB,
+ Instruction *StopAt,
unsigned Threshold) {
+ assert(StopAt->getParent() == BB && "Not an instruction from proper BB?");
/// Ignore PHI nodes, these will be flattened when duplication happens.
BasicBlock::const_iterator I(BB->getFirstNonPHI());
@@ -266,15 +267,17 @@ static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB,
// branch, so they shouldn't count against the duplication cost.
unsigned Bonus = 0;
- const TerminatorInst *BBTerm = BB->getTerminator();
- // Threading through a switch statement is particularly profitable. If this
- // block ends in a switch, decrease its cost to make it more likely to happen.
- if (isa<SwitchInst>(BBTerm))
- Bonus = 6;
-
- // The same holds for indirect branches, but slightly more so.
- if (isa<IndirectBrInst>(BBTerm))
- Bonus = 8;
+ if (BB->getTerminator() == StopAt) {
+ // Threading through a switch statement is particularly profitable. If this
+ // block ends in a switch, decrease its cost to make it more likely to
+ // happen.
+ if (isa<SwitchInst>(StopAt))
+ Bonus = 6;
+
+ // The same holds for indirect branches, but slightly more so.
+ if (isa<IndirectBrInst>(StopAt))
+ Bonus = 8;
+ }
// Bump the threshold up so the early exit from the loop doesn't skip the
// terminator-based Size adjustment at the end.
@@ -283,7 +286,7 @@ static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB,
// Sum up the cost of each instruction until we get to the terminator. Don't
// include the terminator because the copy won't include it.
unsigned Size = 0;
- for (; !isa<TerminatorInst>(I); ++I) {
+ for (; &*I != StopAt; ++I) {
// Stop scanning the block if we've reached the threshold.
if (Size > Threshold)
@@ -729,6 +732,10 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
if (TryToUnfoldSelectInCurrBB(BB))
return true;
+ // Look if we can propagate guards to predecessors.
+ if (HasGuards && ProcessGuards(BB))
+ return true;
+
// What kind of constant we're looking for.
ConstantPreference Preference = WantInteger;
@@ -804,7 +811,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
return false;
}
-
if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst)) {
// If we're branching on a conditional, LVI might be able to determine
// it's value at the branch instruction. We only handle comparisons
@@ -812,7 +818,12 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
// TODO: This should be extended to handle switches as well.
BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1));
- if (CondBr && CondConst && CondBr->isConditional()) {
+ if (CondBr && CondConst) {
+ // We should have returned as soon as we turned a conditional branch into
+ // an unconditional one, because it is no longer interesting as far as
+ // jump threading is concerned.
+ assert(CondBr->isConditional() && "Threading on unconditional terminator");
+
LazyValueInfo::Tristate Ret =
LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0),
CondConst, CondBr);
@@ -835,10 +846,12 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
}
return true;
}
- }
- if (CondBr && CondConst && TryToUnfoldSelect(CondCmp, BB))
- return true;
+ // We did not manage to simplify this branch; see whether CondCmp depends
+ // on a known phi-select pattern.
+ if (TryToUnfoldSelect(CondCmp, BB))
+ return true;
+ }
}
// Check for some cases that are worth simplifying. Right now we want to look
@@ -857,7 +870,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
if (SimplifyPartiallyRedundantLoad(LI))
return true;
-
// Handle a variety of cases where we are branching on something derived from
// a PHI node in the current block. If we can prove that any predecessors
// compute a predictable value based on a PHI node, thread those predecessors.
@@ -871,7 +883,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
if (PN->getParent() == BB && isa<BranchInst>(BB->getTerminator()))
return ProcessBranchOnPHI(PN);
-
// If this is an otherwise-unfoldable branch on a XOR, see if we can simplify.
if (CondInst->getOpcode() == Instruction::Xor &&
CondInst->getParent() == BB && isa<BranchInst>(BB->getTerminator()))
@@ -920,6 +931,14 @@ bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) {
return false;
}
+/// Return true if Op is an instruction defined in the given block.
+static bool isOpDefinedInBlock(Value *Op, BasicBlock *BB) {
+ if (Instruction *OpInst = dyn_cast<Instruction>(Op))
+ if (OpInst->getParent() == BB)
+ return true;
+ return false;
+}
+
/// SimplifyPartiallyRedundantLoad - If LI is an obviously partially redundant
/// load instruction, eliminate it by replacing it with a PHI node. This is an
/// important optimization that encourages jump threading, and needs to be run
@@ -942,18 +961,17 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
Value *LoadedPtr = LI->getOperand(0);
- // If the loaded operand is defined in the LoadBB, it can't be available.
- // TODO: Could do simple PHI translation, that would be fun :)
- if (Instruction *PtrOp = dyn_cast<Instruction>(LoadedPtr))
- if (PtrOp->getParent() == LoadBB)
- return false;
+ // If the loaded operand is defined in the LoadBB and it's not a PHI,
+ // it can't be available in predecessors.
+ if (isOpDefinedInBlock(LoadedPtr, LoadBB) && !isa<PHINode>(LoadedPtr))
+ return false;
// Scan a few instructions up from the load, to see if it is obviously live at
// the entry to its block.
BasicBlock::iterator BBIt(LI);
bool IsLoadCSE;
- if (Value *AvailableVal =
- FindAvailableLoadedValue(LI, LoadBB, BBIt, DefMaxInstsToScan, nullptr, &IsLoadCSE)) {
+ if (Value *AvailableVal = FindAvailableLoadedValue(
+ LI, LoadBB, BBIt, DefMaxInstsToScan, AA, &IsLoadCSE)) {
// If the value of the load is locally available within the block, just use
// it. This frequently occurs for reg2mem'd allocas.
@@ -997,12 +1015,34 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
if (!PredsScanned.insert(PredBB).second)
continue;
- // Scan the predecessor to see if the value is available in the pred.
BBIt = PredBB->end();
- Value *PredAvailable = FindAvailableLoadedValue(LI, PredBB, BBIt,
- DefMaxInstsToScan,
- nullptr,
- &IsLoadCSE);
+ unsigned NumScanedInst = 0;
+ Value *PredAvailable = nullptr;
+ // NOTE: We don't CSE loads that are volatile or anything stronger than
+ // unordered; that should have been checked when we entered the function.
+ assert(LI->isUnordered() && "Attempting to CSE volatile or atomic loads");
+ // If this is a load on a phi pointer, phi-translate it and search
+ // for available load/store to the pointer in predecessors.
+ Value *Ptr = LoadedPtr->DoPHITranslation(LoadBB, PredBB);
+ PredAvailable = FindAvailablePtrLoadStore(
+ Ptr, LI->getType(), LI->isAtomic(), PredBB, BBIt, DefMaxInstsToScan,
+ AA, &IsLoadCSE, &NumScanedInst);
+
+ // If PredBB has a single predecessor, continue scanning through the
+ // single predecessor.
+ BasicBlock *SinglePredBB = PredBB;
+ while (!PredAvailable && SinglePredBB && BBIt == SinglePredBB->begin() &&
+ NumScanedInst < DefMaxInstsToScan) {
+ SinglePredBB = SinglePredBB->getSinglePredecessor();
+ if (SinglePredBB) {
+ BBIt = SinglePredBB->end();
+ PredAvailable = FindAvailablePtrLoadStore(
+ Ptr, LI->getType(), LI->isAtomic(), SinglePredBB, BBIt,
+ (DefMaxInstsToScan - NumScanedInst), AA, &IsLoadCSE,
+ &NumScanedInst);
+ }
+ }
+
if (!PredAvailable) {
OneUnavailablePred = PredBB;
continue;
@@ -1062,10 +1102,10 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
if (UnavailablePred) {
assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 &&
"Can't handle critical edge here!");
- LoadInst *NewVal =
- new LoadInst(LoadedPtr, LI->getName() + ".pr", false,
- LI->getAlignment(), LI->getOrdering(), LI->getSynchScope(),
- UnavailablePred->getTerminator());
+ LoadInst *NewVal = new LoadInst(
+ LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred),
+ LI->getName() + ".pr", false, LI->getAlignment(), LI->getOrdering(),
+ LI->getSynchScope(), UnavailablePred->getTerminator());
NewVal->setDebugLoc(LI->getDebugLoc());
if (AATags)
NewVal->setAAMetadata(AATags);
@@ -1229,7 +1269,7 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()))
DestBB = BI->getSuccessor(cast<ConstantInt>(Val)->isZero());
else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
- DestBB = SI->findCaseValue(cast<ConstantInt>(Val)).getCaseSuccessor();
+ DestBB = SI->findCaseValue(cast<ConstantInt>(Val))->getCaseSuccessor();
} else {
assert(isa<IndirectBrInst>(BB->getTerminator())
&& "Unexpected terminator");
@@ -1468,7 +1508,8 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB,
return false;
}
- unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold);
+ unsigned JumpThreadCost =
+ getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
if (JumpThreadCost > BBDupThreshold) {
DEBUG(dbgs() << " Not threading BB '" << BB->getName()
<< "' - Cost is too high: " << JumpThreadCost << "\n");
@@ -1756,7 +1797,8 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(
return false;
}
- unsigned DuplicationCost = getJumpThreadDuplicationCost(BB, BBDupThreshold);
+ unsigned DuplicationCost =
+ getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
if (DuplicationCost > BBDupThreshold) {
DEBUG(dbgs() << " Not duplicating BB '" << BB->getName()
<< "' - Cost is too high: " << DuplicationCost << "\n");
@@ -1888,10 +1930,10 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(
/// TryToUnfoldSelect - Look for blocks of the form
/// bb1:
/// %a = select
-/// br bb
+/// br bb2
///
/// bb2:
-/// %p = phi [%a, %bb] ...
+/// %p = phi [%a, %bb1] ...
/// %c = icmp %p
/// br i1 %c
///
@@ -2021,3 +2063,130 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) {
return false;
}
+
+/// Try to propagate a guard from the current BB into one of its predecessors
+/// in case another branch of execution implies that the condition of this
+/// guard is always true. Currently we only process the simplest case that
+/// looks like:
+///
+/// Start:
+/// %cond = ...
+/// br i1 %cond, label %T1, label %F1
+/// T1:
+/// br label %Merge
+/// F1:
+/// br label %Merge
+/// Merge:
+/// %condGuard = ...
+/// call void(i1, ...) @llvm.experimental.guard( i1 %condGuard )[ "deopt"() ]
+///
+/// And cond either implies condGuard or !condGuard. In this case all the
+/// instructions before the guard can be duplicated in both branches, and the
+/// guard is then threaded to one of them.
+bool JumpThreadingPass::ProcessGuards(BasicBlock *BB) {
+ using namespace PatternMatch;
+ // We only want to deal with two predecessors.
+ BasicBlock *Pred1, *Pred2;
+ auto PI = pred_begin(BB), PE = pred_end(BB);
+ if (PI == PE)
+ return false;
+ Pred1 = *PI++;
+ if (PI == PE)
+ return false;
+ Pred2 = *PI++;
+ if (PI != PE)
+ return false;
+ if (Pred1 == Pred2)
+ return false;
+
+ // Try to thread one of the guards of the block.
+ // TODO: Look up deeper than to immediate predecessor?
+ auto *Parent = Pred1->getSinglePredecessor();
+ if (!Parent || Parent != Pred2->getSinglePredecessor())
+ return false;
+
+ if (auto *BI = dyn_cast<BranchInst>(Parent->getTerminator()))
+ for (auto &I : *BB)
+ if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>()))
+ if (ThreadGuard(BB, cast<IntrinsicInst>(&I), BI))
+ return true;
+
+ return false;
+}
+
+/// Try to propagate the guard from BB which is the lower block of a diamond
+/// to one of its branches, in case the diamond's condition implies the guard's
+/// condition.
+bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard,
+ BranchInst *BI) {
+ assert(BI->getNumSuccessors() == 2 && "Wrong number of successors?");
+ assert(BI->isConditional() && "Unconditional branch has 2 successors?");
+ Value *GuardCond = Guard->getArgOperand(0);
+ Value *BranchCond = BI->getCondition();
+ BasicBlock *TrueDest = BI->getSuccessor(0);
+ BasicBlock *FalseDest = BI->getSuccessor(1);
+
+ auto &DL = BB->getModule()->getDataLayout();
+ bool TrueDestIsSafe = false;
+ bool FalseDestIsSafe = false;
+
+ // True dest is safe if BranchCond => GuardCond.
+ auto Impl = isImpliedCondition(BranchCond, GuardCond, DL);
+ if (Impl && *Impl)
+ TrueDestIsSafe = true;
+ else {
+ // False dest is safe if !BranchCond => GuardCond.
+ Impl =
+ isImpliedCondition(BranchCond, GuardCond, DL, /* InvertAPred */ true);
+ if (Impl && *Impl)
+ FalseDestIsSafe = true;
+ }
+
+ if (!TrueDestIsSafe && !FalseDestIsSafe)
+ return false;
+
+ BasicBlock *UnguardedBlock = TrueDestIsSafe ? TrueDest : FalseDest;
+ BasicBlock *GuardedBlock = FalseDestIsSafe ? TrueDest : FalseDest;
+
+ ValueToValueMapTy UnguardedMapping, GuardedMapping;
+ Instruction *AfterGuard = Guard->getNextNode();
+ unsigned Cost = getJumpThreadDuplicationCost(BB, AfterGuard, BBDupThreshold);
+ if (Cost > BBDupThreshold)
+ return false;
+ // Duplicate all instructions before the guard and the guard itself to the
+ // branch where the implication is not proved.
+ GuardedBlock = DuplicateInstructionsInSplitBetween(
+ BB, GuardedBlock, AfterGuard, GuardedMapping);
+ assert(GuardedBlock && "Could not create the guarded block?");
+ // Duplicate all instructions before the guard in the unguarded branch.
+ // Since we have successfully duplicated the guarded block and this block
+ // has fewer instructions, we expect it to succeed.
+ UnguardedBlock = DuplicateInstructionsInSplitBetween(BB, UnguardedBlock,
+ Guard, UnguardedMapping);
+ assert(UnguardedBlock && "Could not create the unguarded block?");
+ DEBUG(dbgs() << "Moved guard " << *Guard << " to block "
+ << GuardedBlock->getName() << "\n");
+
+ // Some instructions before the guard may still have uses. For them, we need
+ // to create Phi nodes merging their copies in both guarded and unguarded
+ // branches. Those instructions that have no uses can be just removed.
+ SmallVector<Instruction *, 4> ToRemove;
+ for (auto BI = BB->begin(); &*BI != AfterGuard; ++BI)
+ if (!isa<PHINode>(&*BI))
+ ToRemove.push_back(&*BI);
+
+ Instruction *InsertionPoint = &*BB->getFirstInsertionPt();
+ assert(InsertionPoint && "Empty block?");
+ // Substitute with Phis & remove.
+ for (auto *Inst : reverse(ToRemove)) {
+ if (!Inst->use_empty()) {
+ PHINode *NewPN = PHINode::Create(Inst->getType(), 2);
+ NewPN->addIncoming(UnguardedMapping[Inst], UnguardedBlock);
+ NewPN->addIncoming(GuardedMapping[Inst], GuardedBlock);
+ NewPN->insertBefore(InsertionPoint);
+ Inst->replaceAllUsesWith(NewPN);
+ }
+ Inst->eraseFromParent();
+ }
+ return true;
+}
diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp
index f51d11c04cb23..340c81fed0fda 100644
--- a/lib/Transforms/Scalar/LICM.cpp
+++ b/lib/Transforms/Scalar/LICM.cpp
@@ -77,10 +77,16 @@ STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk");
STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk");
STATISTIC(NumPromoted, "Number of memory locations promoted to registers");
+/// Memory promotion is enabled by default.
static cl::opt<bool>
- DisablePromotion("disable-licm-promotion", cl::Hidden,
+ DisablePromotion("disable-licm-promotion", cl::Hidden, cl::init(false),
cl::desc("Disable memory promotion in LICM pass"));
+static cl::opt<uint32_t> MaxNumUsesTraversed(
+ "licm-max-num-uses-traversed", cl::Hidden, cl::init(8),
+ cl::desc("Max num uses visited for identifying load "
+ "invariance in loop using invariant start (default = 8)"));
+
static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI);
static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop,
const LoopSafetyInfo *SafetyInfo);
@@ -201,9 +207,9 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM,
if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.SE, ORE, true))
return PreservedAnalyses::all();
- // FIXME: There is no setPreservesCFG in the new PM. When that becomes
- // available, it should be used here.
- return getLoopPassPreservedAnalyses();
+ auto PA = getLoopPassPreservedAnalyses();
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
char LegacyLICMPass::ID = 0;
@@ -425,6 +431,29 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
continue;
}
+ // Attempt to remove floating point division out of the loop by converting
+ // it to a reciprocal multiplication.
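+ // That is, for a loop-invariant divisor D: X / D ==> X * (1.0 / D), where
+ // the reciprocal (1.0 / D) is hoisted to the preheader.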
+ if (I.getOpcode() == Instruction::FDiv &&
+ CurLoop->isLoopInvariant(I.getOperand(1)) &&
+ I.hasAllowReciprocal()) {
+ auto Divisor = I.getOperand(1);
+ auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0);
+ auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor);
+ ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags());
+ ReciprocalDivisor->insertBefore(&I);
+
+ auto Product = BinaryOperator::CreateFMul(I.getOperand(0),
+ ReciprocalDivisor);
+ Product->setFastMathFlags(I.getFastMathFlags());
+ Product->insertAfter(&I);
+ I.replaceAllUsesWith(Product);
+ I.eraseFromParent();
+
+ hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE);
+ Changed = true;
+ continue;
+ }
+
// Try hoisting the instruction out to the preheader. We can only do this
// if all of the operands of the instruction are loop invariant and if it
// is safe to hoist the instruction.
@@ -461,7 +490,10 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) {
SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow;
// Iterate over loop instructions and compute safety info.
- for (Loop::block_iterator BB = CurLoop->block_begin(),
+ // Skip header as it has been computed and stored in HeaderMayThrow.
+ // The first block in loopinfo.Blocks is guaranteed to be the header.
+ assert(Header == *CurLoop->getBlocks().begin() && "First block must be header");
+ for (Loop::block_iterator BB = std::next(CurLoop->block_begin()),
BBE = CurLoop->block_end();
(BB != BBE) && !SafetyInfo->MayThrow; ++BB)
for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end();
@@ -477,6 +509,59 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) {
SafetyInfo->BlockColors = colorEHFunclets(*Fn);
}
+// Return true if LI is invariant within the scope of the loop. LI is invariant
+// if CurLoop is dominated by an invariant.start on the same address whose size
+// covers the size of the memory location LI loads from, and the invariant.start
+// has no uses.
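+//
+// For example (the exact signature varies between LLVM versions), a load of
+// %ptr inside the loop qualifies when a block properly dominating the loop
+// header contains
+//   %inv = call {}* @llvm.invariant.start(i64 %size, i8* %ptr)
+// with %size covering the loaded type and %inv having no uses.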
+static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT,
+ Loop *CurLoop) {
+ Value *Addr = LI->getOperand(0);
+ const DataLayout &DL = LI->getModule()->getDataLayout();
+ const uint32_t LocSizeInBits = DL.getTypeSizeInBits(
+ cast<PointerType>(Addr->getType())->getElementType());
+
+ // If the type is i8 addrspace(x)*, we know this is the type of the
+ // llvm.invariant.start operand.
+ auto *PtrInt8Ty = PointerType::get(Type::getInt8Ty(LI->getContext()),
+ LI->getPointerAddressSpace());
+ unsigned BitcastsVisited = 0;
+ // Look through bitcasts until we reach the i8* type (this is invariant.start
+ // operand type).
+ while (Addr->getType() != PtrInt8Ty) {
+ auto *BC = dyn_cast<BitCastInst>(Addr);
+ // Avoid traversing a long chain of bitcasts.
+ if (++BitcastsVisited > MaxNumUsesTraversed || !BC)
+ return false;
+ Addr = BC->getOperand(0);
+ }
+
+ unsigned UsesVisited = 0;
+ // Traverse all uses of the load operand value, to see if invariant.start is
+ // one of the uses, and whether it dominates the load instruction.
+ for (auto *U : Addr->users()) {
+ // Avoid traversing too many users of the load operand.
+ if (++UsesVisited > MaxNumUsesTraversed)
+ return false;
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+ // If there are escaping uses of the invariant.start instruction, the load
+ // may be non-invariant.
+ if (!II || II->getIntrinsicID() != Intrinsic::invariant_start ||
+ II->hasNUsesOrMore(1))
+ continue;
+ unsigned InvariantSizeInBits =
+ cast<ConstantInt>(II->getArgOperand(0))->getSExtValue() * 8;
+ // Confirm the invariant.start location size contains the load operand size
+ // in bits. Also, the invariant.start should dominate the load, and we
+ // should not hoist the load out of a loop that contains this dominating
+ // invariant.start.
+ if (LocSizeInBits <= InvariantSizeInBits &&
+ DT->properlyDominates(II->getParent(), CurLoop->getHeader()))
+ return true;
+ }
+
+ return false;
+}
+
bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
Loop *CurLoop, AliasSetTracker *CurAST,
LoopSafetyInfo *SafetyInfo,
@@ -493,6 +578,10 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
if (LI->getMetadata(LLVMContext::MD_invariant_load))
return true;
+ // This checks for an invariant.start dominating the load.
+ if (isLoadInvariantInLoop(LI, DT, CurLoop))
+ return true;
+
// Don't hoist loads which have may-aliased stores in loop.
uint64_t Size = 0;
if (LI->getType()->isSized())
@@ -782,7 +871,7 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop,
DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I
<< "\n");
ORE->emit(OptimizationRemark(DEBUG_TYPE, "Hoisted", &I)
- << "hosting " << ore::NV("Inst", &I));
+ << "hoisting " << ore::NV("Inst", &I));
// Metadata can be dependent on conditions we are hoisting above.
// Conservatively strip all metadata on the instruction unless we were
@@ -852,6 +941,7 @@ class LoopPromoter : public LoadAndStorePromoter {
LoopInfo &LI;
DebugLoc DL;
int Alignment;
+ bool UnorderedAtomic;
AAMDNodes AATags;
Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const {
@@ -875,10 +965,11 @@ public:
SmallVectorImpl<BasicBlock *> &LEB,
SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC,
AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment,
- const AAMDNodes &AATags)
+ bool UnorderedAtomic, const AAMDNodes &AATags)
: LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA),
LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast),
- LI(li), DL(std::move(dl)), Alignment(alignment), AATags(AATags) {}
+ LI(li), DL(std::move(dl)), Alignment(alignment),
+ UnorderedAtomic(UnorderedAtomic), AATags(AATags) {}
bool isInstInList(Instruction *I,
const SmallVectorImpl<Instruction *> &) const override {
@@ -902,6 +993,8 @@ public:
Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock);
Instruction *InsertPos = LoopInsertPts[i];
StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos);
+ if (UnorderedAtomic)
+ NewSI->setOrdering(AtomicOrdering::Unordered);
NewSI->setAlignment(Alignment);
NewSI->setDebugLoc(DL);
if (AATags)
@@ -992,18 +1085,41 @@ bool llvm::promoteLoopAccessesToScalars(
// We start with an alignment of one and try to find instructions that allow
// us to prove better alignment.
unsigned Alignment = 1;
+ // Keep track of which types of access we see
+ bool SawUnorderedAtomic = false;
+ bool SawNotAtomic = false;
AAMDNodes AATags;
const DataLayout &MDL = Preheader->getModule()->getDataLayout();
+ // Do we know this object does not escape?
+ bool IsKnownNonEscapingObject = false;
if (SafetyInfo->MayThrow) {
// If a loop can throw, we have to insert a store along each unwind edge.
// That said, we can't actually make the unwind edge explicit. Therefore,
// we have to prove that the store is dead along the unwind edge.
//
- // Currently, this code just special-cases alloca instructions.
- if (!isa<AllocaInst>(GetUnderlyingObject(SomePtr, MDL)))
- return false;
+ // If the underlying object is neither an alloca nor a pointer that does not
+ // escape, then we cannot effectively prove that the store is dead along
+ // the unwind edge, i.e. the caller of this function could have ways to
+ // access the pointed-to object.
+ Value *Object = GetUnderlyingObject(SomePtr, MDL);
+ // If this is a base pointer we do not understand, simply bail.
+ // We only handle allocas and return values from alloc-like functions for now.
+ if (!isa<AllocaInst>(Object)) {
+ if (!isAllocLikeFn(Object, TLI))
+ return false;
+ // If this is an alloc-like function, there are more constraints we need to
+ // verify. More specifically, we must make sure that the pointer cannot escape.
+ //
+ // NOTE: PointerMayBeCaptured is not enough, as the pointer may have escaped
+ // even though it is not captured by the enclosing function. Standard allocation
+ // functions like malloc, calloc, and operator new return values which can
+ // be assumed not to have previously escaped.
+ if (PointerMayBeCaptured(Object, true, true))
+ return false;
+ IsKnownNonEscapingObject = true;
+ }
}
// Check that all of the pointers in the alias set have the same type. We
@@ -1029,8 +1145,11 @@ bool llvm::promoteLoopAccessesToScalars(
// it.
if (LoadInst *Load = dyn_cast<LoadInst>(UI)) {
assert(!Load->isVolatile() && "AST broken");
- if (!Load->isSimple())
+ if (!Load->isUnordered())
return false;
+
+ SawUnorderedAtomic |= Load->isAtomic();
+ SawNotAtomic |= !Load->isAtomic();
if (!DereferenceableInPH)
DereferenceableInPH = isSafeToExecuteUnconditionally(
@@ -1041,9 +1160,12 @@ bool llvm::promoteLoopAccessesToScalars(
if (UI->getOperand(1) != ASIV)
continue;
assert(!Store->isVolatile() && "AST broken");
- if (!Store->isSimple())
+ if (!Store->isUnordered())
return false;
+ SawUnorderedAtomic |= Store->isAtomic();
+ SawNotAtomic |= !Store->isAtomic();
+
// If the store is guaranteed to execute, both properties are satisfied.
// We may want to check if a store is guaranteed to execute even if we
// already know that promotion is safe, since it may have higher
@@ -1096,6 +1218,12 @@ bool llvm::promoteLoopAccessesToScalars(
}
}
+ // If we found both an unordered atomic instruction and a non-atomic memory
+ // access, bail. We can't blindly promote non-atomic to atomic since we
+ // might not be able to lower the result. We can't downgrade since that
+ // would violate the memory model. Also, align 0 is an error for atomics.
+ if (SawUnorderedAtomic && SawNotAtomic)
+ return false;
// If we couldn't prove we can hoist the load, bail.
if (!DereferenceableInPH)
@@ -1106,10 +1234,15 @@ bool llvm::promoteLoopAccessesToScalars(
// stores along paths which originally didn't have them without violating the
// memory model.
if (!SafeToInsertStore) {
- Value *Object = GetUnderlyingObject(SomePtr, MDL);
- SafeToInsertStore =
- (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) &&
+ // If this is a known non-escaping object, it is safe to insert the stores.
+ if (IsKnownNonEscapingObject)
+ SafeToInsertStore = true;
+ else {
+ Value *Object = GetUnderlyingObject(SomePtr, MDL);
+ SafeToInsertStore =
+ (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) &&
!PointerMayBeCaptured(Object, true, true);
+ }
}
// If we've still failed to prove we can sink the store, give up.
@@ -1134,12 +1267,15 @@ bool llvm::promoteLoopAccessesToScalars(
SmallVector<PHINode *, 16> NewPHIs;
SSAUpdater SSA(&NewPHIs);
LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks,
- InsertPts, PIC, *CurAST, *LI, DL, Alignment, AATags);
+ InsertPts, PIC, *CurAST, *LI, DL, Alignment,
+ SawUnorderedAtomic, AATags);
// Set up the preheader to have a definition of the value. It is the live-out
// value from the preheader that uses in the loop will use.
LoadInst *PreheaderLoad = new LoadInst(
SomePtr, SomePtr->getName() + ".promoted", Preheader->getTerminator());
+ if (SawUnorderedAtomic)
+ PreheaderLoad->setOrdering(AtomicOrdering::Unordered);
PreheaderLoad->setAlignment(Alignment);
PreheaderLoad->setDebugLoc(DL);
if (AATags)
diff --git a/lib/Transforms/Scalar/LoadCombine.cpp b/lib/Transforms/Scalar/LoadCombine.cpp
index 389f1c595aa40..02215d3450c23 100644
--- a/lib/Transforms/Scalar/LoadCombine.cpp
+++ b/lib/Transforms/Scalar/LoadCombine.cpp
@@ -19,6 +19,7 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -53,18 +54,20 @@ struct LoadPOPPair {
class LoadCombine : public BasicBlockPass {
LLVMContext *C;
AliasAnalysis *AA;
+ DominatorTree *DT;
public:
LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {
initializeLoadCombinePass(*PassRegistry::getPassRegistry());
}
-
+
using llvm::Pass::doInitialization;
bool doInitialization(Function &) override;
bool runOnBasicBlock(BasicBlock &BB) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
}
@@ -234,6 +237,14 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
return false;
AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+ DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+
+ // Skip analysing dead blocks (not forward reachable from function entry).
+ if (!DT->isReachableFromEntry(&BB)) {
+ DEBUG(dbgs() << "LC: skipping unreachable " << BB.getName() <<
+ " in " << BB.getParent()->getName() << "\n");
+ return false;
+ }
IRBuilder<TargetFolder> TheBuilder(
BB.getContext(), TargetFolder(BB.getModule()->getDataLayout()));
@@ -245,13 +256,17 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
bool Combined = false;
unsigned Index = 0;
for (auto &I : BB) {
- if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) {
+ if (I.mayThrow() || AST.containsUnknown(&I)) {
if (combineLoads(LoadMap))
Combined = true;
LoadMap.clear();
AST.clear();
continue;
}
+ if (I.mayWriteToMemory()) {
+ AST.add(&I);
+ continue;
+ }
LoadInst *LI = dyn_cast<LoadInst>(&I);
if (!LI)
continue;
diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp
index cca75a365024c..73e8ce0e1d93c 100644
--- a/lib/Transforms/Scalar/LoopDeletion.cpp
+++ b/lib/Transforms/Scalar/LoopDeletion.cpp
@@ -29,32 +29,31 @@ using namespace llvm;
STATISTIC(NumDeleted, "Number of loops deleted");
-/// isLoopDead - Determined if a loop is dead. This assumes that we've already
-/// checked for unique exit and exiting blocks, and that the code is in LCSSA
-/// form.
-bool LoopDeletionPass::isLoopDead(Loop *L, ScalarEvolution &SE,
- SmallVectorImpl<BasicBlock *> &exitingBlocks,
- SmallVectorImpl<BasicBlock *> &exitBlocks,
- bool &Changed, BasicBlock *Preheader) {
- BasicBlock *exitBlock = exitBlocks[0];
-
+/// Determines if a loop is dead.
+///
+/// This assumes that we've already checked for unique exit and exiting blocks,
+/// and that the code is in LCSSA form.
+static bool isLoopDead(Loop *L, ScalarEvolution &SE,
+ SmallVectorImpl<BasicBlock *> &ExitingBlocks,
+ BasicBlock *ExitBlock, bool &Changed,
+ BasicBlock *Preheader) {
// Make sure that all PHI entries coming from the loop are loop invariant.
// Because the code is in LCSSA form, any values used outside of the loop
// must pass through a PHI in the exit block, meaning that this check is
// sufficient to guarantee that no loop-variant values are used outside
// of the loop.
- BasicBlock::iterator BI = exitBlock->begin();
+ BasicBlock::iterator BI = ExitBlock->begin();
bool AllEntriesInvariant = true;
bool AllOutgoingValuesSame = true;
while (PHINode *P = dyn_cast<PHINode>(BI)) {
- Value *incoming = P->getIncomingValueForBlock(exitingBlocks[0]);
+ Value *incoming = P->getIncomingValueForBlock(ExitingBlocks[0]);
// Make sure all exiting blocks produce the same incoming value for the exit
// block. If there are different incoming values for different exiting
// blocks, then it is impossible to statically determine which value should
// be used.
AllOutgoingValuesSame =
- all_of(makeArrayRef(exitingBlocks).slice(1), [&](BasicBlock *BB) {
+ all_of(makeArrayRef(ExitingBlocks).slice(1), [&](BasicBlock *BB) {
return incoming == P->getIncomingValueForBlock(BB);
});
@@ -78,33 +77,37 @@ bool LoopDeletionPass::isLoopDead(Loop *L, ScalarEvolution &SE,
// Make sure that no instructions in the block have potential side-effects.
// This includes instructions that could write to memory, and loads that are
- // marked volatile. This could be made more aggressive by using aliasing
- // information to identify readonly and readnone calls.
- for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end();
- LI != LE; ++LI) {
- for (Instruction &I : **LI) {
- if (I.mayHaveSideEffects())
- return false;
- }
- }
-
+ // marked volatile.
+ for (auto &I : L->blocks())
+ if (any_of(*I, [](Instruction &I) { return I.mayHaveSideEffects(); }))
+ return false;
return true;
}
-/// Remove dead loops, by which we mean loops that do not impact the observable
-/// behavior of the program other than finite running time. Note we do ensure
-/// that this never remove a loop that might be infinite, as doing so could
-/// change the halting/non-halting nature of a program. NOTE: This entire
-/// process relies pretty heavily on LoopSimplify and LCSSA in order to make
-/// various safety checks work.
-bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
- LoopInfo &loopInfo) {
+/// Remove a loop if it is dead.
+///
+/// A loop is considered dead if it does not impact the observable behavior of
+/// the program other than finite running time. This never removes a loop that
+/// might be infinite, as doing so could change the halting/non-halting nature
+/// of a program.
+///
+/// This entire process relies pretty heavily on LoopSimplify form and LCSSA in
+/// order to make various safety checks work.
+///
+/// \returns true if any changes were made. This may mutate the loop even if it
+/// is unable to delete it due to hoisting trivially loop invariant
+/// instructions out of the loop.
+///
+/// This also updates the relevant analysis information in \p DT, \p SE, and \p
+/// LI. It also updates the loop PM if an updater struct is provided.
+static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
+ LoopInfo &LI, LPMUpdater *Updater = nullptr) {
assert(L->isLCSSAForm(DT) && "Expected LCSSA!");
// We can only remove the loop if there is a preheader that we can
// branch from after removing it.
- BasicBlock *preheader = L->getLoopPreheader();
- if (!preheader)
+ BasicBlock *Preheader = L->getLoopPreheader();
+ if (!Preheader)
return false;
// If LoopSimplify form is not available, stay out of trouble.
@@ -116,22 +119,20 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
if (L->begin() != L->end())
return false;
- SmallVector<BasicBlock *, 4> exitingBlocks;
- L->getExitingBlocks(exitingBlocks);
-
- SmallVector<BasicBlock *, 4> exitBlocks;
- L->getUniqueExitBlocks(exitBlocks);
+ SmallVector<BasicBlock *, 4> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
// We require that the loop only have a single exit block. Otherwise, we'd
// be in the situation of needing to be able to solve statically which exit
// block will be branched to, or trying to preserve the branching logic in
// a loop invariant manner.
- if (exitBlocks.size() != 1)
+ BasicBlock *ExitBlock = L->getUniqueExitBlock();
+ if (!ExitBlock)
return false;
// Finally, we have to check that the loop really is dead.
bool Changed = false;
- if (!isLoopDead(L, SE, exitingBlocks, exitBlocks, Changed, preheader))
+ if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader))
return Changed;
// Don't remove loops for which we can't solve the trip count.
@@ -142,11 +143,13 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
// Now that we know the removal is safe, remove the loop by changing the
// branch from the preheader to go to the single exit block.
- BasicBlock *exitBlock = exitBlocks[0];
-
+ //
// Because we're deleting a large chunk of code at once, the sequence in which
- // we remove things is very important to avoid invalidation issues. Don't
- // mess with this unless you have good reason and know what you're doing.
+ // we remove things is very important to avoid invalidation issues.
+
+ // If we have an LPM updater, tell it about the loop being removed.
+ if (Updater)
+ Updater->markLoopAsDeleted(*L);
// Tell ScalarEvolution that the loop is deleted. Do this before
// deleting the loop so that ScalarEvolution can look at the loop
@@ -154,19 +157,19 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
SE.forgetLoop(L);
// Connect the preheader directly to the exit block.
- TerminatorInst *TI = preheader->getTerminator();
- TI->replaceUsesOfWith(L->getHeader(), exitBlock);
+ TerminatorInst *TI = Preheader->getTerminator();
+ TI->replaceUsesOfWith(L->getHeader(), ExitBlock);
// Rewrite phis in the exit block to get their inputs from
// the preheader instead of the exiting block.
- BasicBlock *exitingBlock = exitingBlocks[0];
- BasicBlock::iterator BI = exitBlock->begin();
+ BasicBlock *ExitingBlock = ExitingBlocks[0];
+ BasicBlock::iterator BI = ExitBlock->begin();
while (PHINode *P = dyn_cast<PHINode>(BI)) {
- int j = P->getBasicBlockIndex(exitingBlock);
+ int j = P->getBasicBlockIndex(ExitingBlock);
assert(j >= 0 && "Can't find exiting block in exit block's phi node!");
- P->setIncomingBlock(j, preheader);
- for (unsigned i = 1; i < exitingBlocks.size(); ++i)
- P->removeIncomingValue(exitingBlocks[i]);
+ P->setIncomingBlock(j, Preheader);
+ for (unsigned i = 1; i < ExitingBlocks.size(); ++i)
+ P->removeIncomingValue(ExitingBlocks[i]);
++BI;
}
@@ -175,11 +178,11 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
SmallVector<DomTreeNode*, 8> ChildNodes;
for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end();
LI != LE; ++LI) {
- // Move all of the block's children to be children of the preheader, which
+ // Move all of the block's children to be children of the Preheader, which
// allows us to remove the domtree entry for the block.
ChildNodes.insert(ChildNodes.begin(), DT[*LI]->begin(), DT[*LI]->end());
for (DomTreeNode *ChildNode : ChildNodes) {
- DT.changeImmediateDominator(ChildNode, DT[preheader]);
+ DT.changeImmediateDominator(ChildNode, DT[Preheader]);
}
ChildNodes.clear();
@@ -204,22 +207,19 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
SmallPtrSet<BasicBlock *, 8> blocks;
blocks.insert(L->block_begin(), L->block_end());
for (BasicBlock *BB : blocks)
- loopInfo.removeBlock(BB);
+ LI.removeBlock(BB);
// The last step is to update LoopInfo now that we've eliminated this loop.
- loopInfo.markAsRemoved(L);
- Changed = true;
-
+ LI.markAsRemoved(L);
++NumDeleted;
- return Changed;
+ return true;
}
PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
- LPMUpdater &) {
- bool Changed = runImpl(&L, AR.DT, AR.SE, AR.LI);
- if (!Changed)
+ LPMUpdater &Updater) {
+ if (!deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI, &Updater))
return PreservedAnalyses::all();
return getLoopPassPreservedAnalyses();
@@ -257,8 +257,7 @@ bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &) {
DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- LoopInfo &loopInfo = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- LoopDeletionPass Impl;
- return Impl.runImpl(L, DT, SE, loopInfo);
+ return deleteLoopIfDead(L, DT, SE, LI);
}
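As a source-level illustration of what the checks above accept (a hypothetical example, not code or a test from this patch), a loop qualifies as dead when it has no side effects, a single exit block, and a computable trip count; deleteLoopIfDead then simply retargets the preheader branch at the exit block:

  // C++ sketch: the loop writes only to loop-local values that are never
  // observed, so once LoopDeletion proves it dead it can be removed entirely.
  int f(int n) {
    int x = 42;
    for (int i = 0; i < n; ++i) {
      int tmp = i * i; // never read outside the loop
      (void)tmp;
    }
    return x; // the only observable result
  }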
diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp
index 19716b28ad66a..3624bba103450 100644
--- a/lib/Transforms/Scalar/LoopDistribute.cpp
+++ b/lib/Transforms/Scalar/LoopDistribute.cpp
@@ -812,29 +812,29 @@ private:
const RuntimePointerChecking *RtPtrChecking) {
SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
- std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks),
- [&](const RuntimePointerChecking::PointerCheck &Check) {
- for (unsigned PtrIdx1 : Check.first->Members)
- for (unsigned PtrIdx2 : Check.second->Members)
- // Only include this check if there is a pair of pointers
- // that require checking and the pointers fall into
- // separate partitions.
- //
- // (Note that we already know at this point that the two
- // pointer groups need checking but it doesn't follow
- // that each pair of pointers within the two groups need
- // checking as well.
- //
- // In other words we don't want to include a check just
- // because there is a pair of pointers between the two
- // pointer groups that require checks and a different
- // pair whose pointers fall into different partitions.)
- if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) &&
- !RuntimePointerChecking::arePointersInSamePartition(
- PtrToPartition, PtrIdx1, PtrIdx2))
- return true;
- return false;
- });
+ copy_if(AllChecks, std::back_inserter(Checks),
+ [&](const RuntimePointerChecking::PointerCheck &Check) {
+ for (unsigned PtrIdx1 : Check.first->Members)
+ for (unsigned PtrIdx2 : Check.second->Members)
+ // Only include this check if there is a pair of pointers
+ // that require checking and the pointers fall into
+ // separate partitions.
+ //
+ // (Note that we already know at this point that the two
+ // pointer groups need checking but it doesn't follow
+ // that each pair of pointers within the two groups need
+ // checking as well.
+ //
+ // In other words we don't want to include a check just
+ // because there is a pair of pointers between the two
+ // pointer groups that require checks and a different
+ // pair whose pointers fall into different partitions.)
+ if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) &&
+ !RuntimePointerChecking::arePointersInSamePartition(
+ PtrToPartition, PtrIdx1, PtrIdx2))
+ return true;
+ return false;
+ });
return Checks;
}
diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 5fec51c095d09..946d85d7360fd 100644
--- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -236,9 +236,9 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
ApplyCodeSizeHeuristics =
L->getHeader()->getParent()->optForSize() && UseLIRCodeSizeHeurs;
- HasMemset = TLI->has(LibFunc::memset);
- HasMemsetPattern = TLI->has(LibFunc::memset_pattern16);
- HasMemcpy = TLI->has(LibFunc::memcpy);
+ HasMemset = TLI->has(LibFunc_memset);
+ HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
+ HasMemcpy = TLI->has(LibFunc_memcpy);
if (HasMemset || HasMemsetPattern || HasMemcpy)
if (SE->hasLoopInvariantBackedgeTakenCount(L))
@@ -823,7 +823,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
Module *M = TheStore->getModule();
Value *MSP =
M->getOrInsertFunction("memset_pattern16", Builder.getVoidTy(),
- Int8PtrTy, Int8PtrTy, IntPtr, (void *)nullptr);
+ Int8PtrTy, Int8PtrTy, IntPtr);
inferLibFuncAttributes(*M->getFunction("memset_pattern16"), *TLI);
// Otherwise we should form a memset_pattern16. PatternValue is known to be
diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp
index 69102d10ff60f..28e71ca05436c 100644
--- a/lib/Transforms/Scalar/LoopInstSimplify.cpp
+++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp
@@ -189,7 +189,9 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
if (!SimplifyLoopInst(&L, &AR.DT, &AR.LI, &AR.AC, &AR.TLI))
return PreservedAnalyses::all();
- return getLoopPassPreservedAnalyses();
+ auto PA = getLoopPassPreservedAnalyses();
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
char LoopInstSimplifyLegacyPass::ID = 0;
diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp
index e9f84edd1cbff..9f3875a3027f4 100644
--- a/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -39,7 +39,7 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
-#include "llvm/Transforms/Utils/SSAUpdater.h"
+
using namespace llvm;
#define DEBUG_TYPE "loop-interchange"
diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp
index 8fb580183e307..cf63cb660db8c 100644
--- a/lib/Transforms/Scalar/LoopLoadElimination.cpp
+++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -20,13 +20,14 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Transforms/Scalar/LoopLoadElimination.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
@@ -45,9 +46,9 @@
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/LoopVersioning.h"
-#include <forward_list>
-#include <cassert>
#include <algorithm>
+#include <cassert>
+#include <forward_list>
#include <set>
#include <tuple>
#include <utility>
@@ -373,15 +374,15 @@ public:
const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();
SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
- std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks),
- [&](const RuntimePointerChecking::PointerCheck &Check) {
- for (auto PtrIdx1 : Check.first->Members)
- for (auto PtrIdx2 : Check.second->Members)
- if (needsChecking(PtrIdx1, PtrIdx2,
- PtrsWrittenOnFwdingPath, CandLoadPtrs))
- return true;
- return false;
- });
+ copy_if(AllChecks, std::back_inserter(Checks),
+ [&](const RuntimePointerChecking::PointerCheck &Check) {
+ for (auto PtrIdx1 : Check.first->Members)
+ for (auto PtrIdx2 : Check.second->Members)
+ if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,
+ CandLoadPtrs))
+ return true;
+ return false;
+ });
DEBUG(dbgs() << "\nPointer Checks (count: " << Checks.size() << "):\n");
DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
@@ -558,6 +559,32 @@ private:
PredicatedScalarEvolution PSE;
};
+static bool
+eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT,
+ function_ref<const LoopAccessInfo &(Loop &)> GetLAI) {
+ // Build up a worklist of inner-loops to transform to avoid iterator
+ // invalidation.
+ // FIXME: This logic comes from other passes that actually change the loop
+ // nest structure. It isn't clear this is necessary (or useful) for a pass
+ // which merely optimizes the use of loads in a loop.
+ SmallVector<Loop *, 8> Worklist;
+
+ for (Loop *TopLevelLoop : LI)
+ for (Loop *L : depth_first(TopLevelLoop))
+ // We only handle inner-most loops.
+ if (L->empty())
+ Worklist.push_back(L);
+
+ // Now walk the identified inner loops.
+ bool Changed = false;
+ for (Loop *L : Worklist) {
+ // The actual work is performed by LoadEliminationForLoop.
+ LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT);
+ Changed |= LEL.processLoop();
+ }
+ return Changed;
+}
+
/// \brief The pass. Most of the work is delegated to the per-loop
/// LoadEliminationForLoop class.
class LoopLoadElimination : public FunctionPass {
@@ -570,32 +597,14 @@ public:
if (skipFunction(F))
return false;
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-
- // Build up a worklist of inner-loops to vectorize. This is necessary as the
- // act of distributing a loop creates new loops and can invalidate iterators
- // across the loops.
- SmallVector<Loop *, 8> Worklist;
-
- for (Loop *TopLevelLoop : *LI)
- for (Loop *L : depth_first(TopLevelLoop))
- // We only handle inner-most loops.
- if (L->empty())
- Worklist.push_back(L);
-
- // Now walk the identified inner loops.
- bool Changed = false;
- for (Loop *L : Worklist) {
- const LoopAccessInfo &LAI = LAA->getInfo(L);
- // The actual work is performed by LoadEliminationForLoop.
- LoadEliminationForLoop LEL(L, LI, LAI, DT);
- Changed |= LEL.processLoop();
- }
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ auto &LAA = getAnalysis<LoopAccessLegacyAnalysis>();
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
// Process each loop nest in the function.
- return Changed;
+ return eliminateLoadsAcrossLoops(
+ F, LI, DT,
+ [&LAA](Loop &L) -> const LoopAccessInfo & { return LAA.getInfo(&L); });
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -631,4 +640,28 @@ FunctionPass *createLoopLoadEliminationPass() {
return new LoopLoadElimination();
}
+PreservedAnalyses LoopLoadEliminationPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+ auto &LI = AM.getResult<LoopAnalysis>(F);
+ auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+ auto &AA = AM.getResult<AAManager>(F);
+ auto &AC = AM.getResult<AssumptionAnalysis>(F);
+
+ auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
+ bool Changed = eliminateLoadsAcrossLoops(
+ F, LI, DT, [&](Loop &L) -> const LoopAccessInfo & {
+ LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI};
+ return LAM.getResult<LoopAccessAnalysis>(L, AR);
+ });
+
+ if (!Changed)
+ return PreservedAnalyses::all();
+
+ PreservedAnalyses PA;
+ return PA;
+}
+
} // end namespace llvm
diff --git a/lib/Transforms/Scalar/LoopPassManager.cpp b/lib/Transforms/Scalar/LoopPassManager.cpp
index 028f4bba8b1d8..10f6fcdcfdb7c 100644
--- a/lib/Transforms/Scalar/LoopPassManager.cpp
+++ b/lib/Transforms/Scalar/LoopPassManager.cpp
@@ -42,6 +42,13 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
break;
}
+#ifndef NDEBUG
+ // Verify the loop structure and LCSSA form before visiting the loop.
+ L.verifyLoop();
+ assert(L.isRecursivelyLCSSAForm(AR.DT, AR.LI) &&
+ "Loops must remain in LCSSA form!");
+#endif
+
// Update the analysis manager as each pass runs and potentially
// invalidates analyses.
AM.invalidate(L, PassPA);
diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp
new file mode 100644
index 0000000000000..0ce6044293261
--- /dev/null
+++ b/lib/Transforms/Scalar/LoopPredication.cpp
@@ -0,0 +1,282 @@
+//===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// The LoopPredication pass tries to convert loop variant range checks to loop
+// invariant by widening checks across loop iterations. For example, it will
+// convert
+//
+// for (i = 0; i < n; i++) {
+// guard(i < len);
+// ...
+// }
+//
+// to
+//
+// for (i = 0; i < n; i++) {
+// guard(n - 1 < len);
+// ...
+// }
+//
+// After this transformation the condition of the guard is loop invariant, so
+// loop-unswitch can later unswitch the loop by this condition which basically
+// predicates the loop by the widened condition:
+//
+// if (n - 1 < len)
+// for (i = 0; i < n; i++) {
+// ...
+// }
+// else
+// deoptimize
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopPredication.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+#define DEBUG_TYPE "loop-predication"
+
+using namespace llvm;
+
+namespace {
+class LoopPredication {
+ ScalarEvolution *SE;
+
+ Loop *L;
+ const DataLayout *DL;
+ BasicBlock *Preheader;
+
+ Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
+ IRBuilder<> &Builder);
+ bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
+
+public:
+ LoopPredication(ScalarEvolution *SE) : SE(SE) {}
+ bool runOnLoop(Loop *L);
+};
+
+class LoopPredicationLegacyPass : public LoopPass {
+public:
+ static char ID;
+ LoopPredicationLegacyPass() : LoopPass(ID) {
+ initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ getLoopAnalysisUsage(AU);
+ }
+
+ bool runOnLoop(Loop *L, LPPassManager &LPM) override {
+ if (skipLoop(L))
+ return false;
+ auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+ LoopPredication LP(SE);
+ return LP.runOnLoop(L);
+ }
+};
+
+char LoopPredicationLegacyPass::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
+ "Loop predication", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
+ "Loop predication", false, false)
+
+Pass *llvm::createLoopPredicationPass() {
+ return new LoopPredicationLegacyPass();
+}
+
+PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
+ LoopStandardAnalysisResults &AR,
+ LPMUpdater &U) {
+ LoopPredication LP(&AR.SE);
+ if (!LP.runOnLoop(&L))
+ return PreservedAnalyses::all();
+
+ return getLoopPassPreservedAnalyses();
+}
+
+/// If ICI can be widened to a loop-invariant condition, emit that condition in
+/// the loop preheader and return it; otherwise return None.
+Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
+ SCEVExpander &Expander,
+ IRBuilder<> &Builder) {
+ DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
+ DEBUG(ICI->dump());
+
+ ICmpInst::Predicate Pred = ICI->getPredicate();
+ Value *LHS = ICI->getOperand(0);
+ Value *RHS = ICI->getOperand(1);
+ const SCEV *LHSS = SE->getSCEV(LHS);
+ if (isa<SCEVCouldNotCompute>(LHSS))
+ return None;
+ const SCEV *RHSS = SE->getSCEV(RHS);
+ if (isa<SCEVCouldNotCompute>(RHSS))
+ return None;
+
+ // Canonicalize so that RHS is the loop-invariant bound and LHS is a
+ // loop-computable index.
+ if (SE->isLoopInvariant(LHSS, L)) {
+ std::swap(LHS, RHS);
+ std::swap(LHSS, RHSS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+ if (!SE->isLoopInvariant(RHSS, L) || !isSafeToExpand(RHSS, *SE))
+ return None;
+
+ const SCEVAddRecExpr *IndexAR = dyn_cast<SCEVAddRecExpr>(LHSS);
+ if (!IndexAR || IndexAR->getLoop() != L)
+ return None;
+
+ DEBUG(dbgs() << "IndexAR: ");
+ DEBUG(IndexAR->dump());
+
+ bool IsIncreasing = false;
+ if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing))
+ return None;
+
+ // If the predicate is increasing, the condition can change from false to
+ // true as the loop progresses; in this case take the value on the first
+ // iteration for the widened check. Otherwise the condition can change from
+ // true to false as the loop progresses, so take the value on the last
+ // iteration.
+ const SCEV *NewLHSS = IsIncreasing
+ ? IndexAR->getStart()
+ : SE->getSCEVAtScope(IndexAR, L->getParentLoop());
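+ // Tying this to the example in the file header: for guard(i < len) with
+ // i = {0,+,1}, the predicate can only change from true to false (i.e. it is
+ // not increasing), so the last-iteration value is used and the widened check
+ // becomes n - 1 < len.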
+ if (NewLHSS == IndexAR) {
+ DEBUG(dbgs() << "Can't compute NewLHSS!\n");
+ return None;
+ }
+
+ DEBUG(dbgs() << "NewLHSS: ");
+ DEBUG(NewLHSS->dump());
+
+ if (!SE->isLoopInvariant(NewLHSS, L) || !isSafeToExpand(NewLHSS, *SE))
+ return None;
+
+ DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. Expand!\n");
+
+ Type *Ty = LHS->getType();
+ Instruction *InsertAt = Preheader->getTerminator();
+ assert(Ty == RHS->getType() && "icmp operands have different types?");
+ Value *NewLHS = Expander.expandCodeFor(NewLHSS, Ty, InsertAt);
+ Value *NewRHS = Expander.expandCodeFor(RHSS, Ty, InsertAt);
+ return Builder.CreateICmp(Pred, NewLHS, NewRHS);
+}
+
+bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
+ SCEVExpander &Expander) {
+ DEBUG(dbgs() << "Processing guard:\n");
+ DEBUG(Guard->dump());
+
+ IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
+
+ // The guard condition is expected to be of the form:
+ //   cond1 && cond2 && cond3 ...
+ // Iterate over the subconditions looking for icmp conditions which can be
+ // widened across loop iterations. While widening these conditions, remember
+ // the resulting list of subconditions in the Checks vector.
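+ // For example (illustrative): a guard on ((i < len) && p) is split into the
+ // two subconditions; if (i < len) widens to (n - 1 < len) as in the file
+ // header, the guard is rebuilt below as the conjunction of (n - 1 < len)
+ // and p.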
+ SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0));
+ SmallPtrSet<Value *, 4> Visited;
+
+ SmallVector<Value *, 4> Checks;
+
+ unsigned NumWidened = 0;
+ do {
+ Value *Condition = Worklist.pop_back_val();
+ if (!Visited.insert(Condition).second)
+ continue;
+
+ Value *LHS, *RHS;
+ using namespace llvm::PatternMatch;
+ if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
+ Worklist.push_back(LHS);
+ Worklist.push_back(RHS);
+ continue;
+ }
+
+ if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
+ if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) {
+ Checks.push_back(NewRangeCheck.getValue());
+ NumWidened++;
+ continue;
+ }
+ }
+
+ // Save the condition as is if we can't widen it
+ Checks.push_back(Condition);
+ } while (Worklist.size() != 0);
+
+ if (NumWidened == 0)
+ return false;
+
+ // Emit the new guard condition
+ Builder.SetInsertPoint(Guard);
+ Value *LastCheck = nullptr;
+ for (auto *Check : Checks)
+ if (!LastCheck)
+ LastCheck = Check;
+ else
+ LastCheck = Builder.CreateAnd(LastCheck, Check);
+ Guard->setOperand(0, LastCheck);
+
+ DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
+ return true;
+}
+
+bool LoopPredication::runOnLoop(Loop *Loop) {
+ L = Loop;
+
+ DEBUG(dbgs() << "Analyzing ");
+ DEBUG(L->dump());
+
+ Module *M = L->getHeader()->getModule();
+
+ // There is nothing to do if the module doesn't use guards
+ auto *GuardDecl =
+ M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
+ if (!GuardDecl || GuardDecl->use_empty())
+ return false;
+
+ DL = &M->getDataLayout();
+
+ Preheader = L->getLoopPreheader();
+ if (!Preheader)
+ return false;
+
+ // Collect all the guards into a vector and process later, so as not
+ // to invalidate the instruction iterator.
+ SmallVector<IntrinsicInst *, 4> Guards;
+ for (const auto BB : L->blocks())
+ for (auto &I : *BB)
+ if (auto *II = dyn_cast<IntrinsicInst>(&I))
+ if (II->getIntrinsicID() == Intrinsic::experimental_guard)
+ Guards.push_back(II);
+
+ SCEVExpander Expander(*SE, *DL, "loop-predication");
+
+ bool Changed = false;
+ for (auto *Guard : Guards)
+ Changed |= widenGuardConditions(Guard, Expander);
+
+ return Changed;
+}
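A minimal wiring sketch for the new-PM pass added above (an assumption about typical use, not part of this patch; it relies on the standard createFunctionToLoopPassAdaptor helper declared in LoopPassManager.h):

  // Run loop predication on every loop of a function via the loop adaptor.
  llvm::LoopPassManager LPM;
  LPM.addPass(llvm::LoopPredicationPass());
  llvm::FunctionPassManager FPM;
  FPM.addPass(llvm::createFunctionToLoopPassAdaptor(std::move(LPM)));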
diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp
index cc83069d5f52d..e5689368de80d 100644
--- a/lib/Transforms/Scalar/LoopRotation.cpp
+++ b/lib/Transforms/Scalar/LoopRotation.cpp
@@ -79,7 +79,8 @@ private:
/// to merge the two values. Do this now.
static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,
BasicBlock *OrigPreheader,
- ValueToValueMapTy &ValueMap) {
+ ValueToValueMapTy &ValueMap,
+ SmallVectorImpl<PHINode*> *InsertedPHIs) {
// Remove PHI node entries that are no longer live.
BasicBlock::iterator I, E = OrigHeader->end();
for (I = OrigHeader->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I)
@@ -87,7 +88,7 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,
// Now fix up users of the instructions in OrigHeader, inserting PHI nodes
// as necessary.
- SSAUpdater SSA;
+ SSAUpdater SSA(InsertedPHIs);
for (I = OrigHeader->begin(); I != E; ++I) {
Value *OrigHeaderVal = &*I;
@@ -174,6 +175,38 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,
}
}
+/// Propagate dbg.value intrinsics through the newly inserted Phis.
+static void insertDebugValues(BasicBlock *OrigHeader,
+ SmallVectorImpl<PHINode*> &InsertedPHIs) {
+ ValueToValueMapTy DbgValueMap;
+
+ // Map existing PHI nodes to their dbg.values.
+ for (auto &I : *OrigHeader) {
+ if (auto DbgII = dyn_cast<DbgInfoIntrinsic>(&I)) {
+ if (auto *Loc = dyn_cast_or_null<PHINode>(DbgII->getVariableLocation()))
+ DbgValueMap.insert({Loc, DbgII});
+ }
+ }
+
+ // Then iterate through the new PHIs and look to see if they use one of the
+ // previously mapped PHIs. If so, insert a new dbg.value intrinsic that will
+ // propagate the info through the new PHI.
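+ // For example (illustrative): if an original header PHI %p was described by
+ // a dbg.value, and SSAUpdater inserted a new PHI %p.new that takes %p as an
+ // operand, a cloned dbg.value describing %p.new is inserted into %p.new's
+ // block.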
+ LLVMContext &C = OrigHeader->getContext();
+ for (auto PHI : InsertedPHIs) {
+ for (auto VI : PHI->operand_values()) {
+ auto V = DbgValueMap.find(VI);
+ if (V != DbgValueMap.end()) {
+ auto *DbgII = cast<DbgInfoIntrinsic>(V->second);
+ Instruction *NewDbgII = DbgII->clone();
+ auto PhiMAV = MetadataAsValue::get(C, ValueAsMetadata::get(PHI));
+ NewDbgII->setOperand(0, PhiMAV);
+ BasicBlock *Parent = PHI->getParent();
+ NewDbgII->insertBefore(Parent->getFirstNonPHIOrDbgOrLifetime());
+ }
+ }
+ }
+}
+
/// Rotate loop LP. Return true if the loop is rotated.
///
/// \param SimplifiedLatch is true if the latch was just folded into the final
@@ -347,9 +380,18 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// remove the corresponding incoming values from the PHI nodes in OrigHeader.
LoopEntryBranch->eraseFromParent();
+
+ SmallVector<PHINode*, 2> InsertedPHIs;
// If there were any uses of instructions in the duplicated block outside the
// loop, update them, inserting PHI nodes as required
- RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap);
+ RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap,
+ &InsertedPHIs);
+
+ // Attach dbg.value intrinsics to the new phis if those phis use a value that
+ // previously had debug metadata attached. This keeps the debug info
+ // up-to-date in the loop body.
+ if (!InsertedPHIs.empty())
+ insertDebugValues(OrigHeader, InsertedPHIs);
// NewHeader is now the header of the loop.
L->moveToHeader(NewHeader);
@@ -634,6 +676,7 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM,
bool Changed = LR.processLoop(&L);
if (!Changed)
return PreservedAnalyses::all();
+
return getLoopPassPreservedAnalyses();
}
diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
index 16061212ba384..a5a81c33a8ebd 100644
--- a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
+++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
@@ -69,6 +69,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM,
LPMUpdater &) {
if (!simplifyLoopCFG(L, AR.DT, AR.LI))
return PreservedAnalyses::all();
+
return getLoopPassPreservedAnalyses();
}
diff --git a/lib/Transforms/Scalar/LoopSink.cpp b/lib/Transforms/Scalar/LoopSink.cpp
index f3f415275c0e0..c9d55b4594fea 100644
--- a/lib/Transforms/Scalar/LoopSink.cpp
+++ b/lib/Transforms/Scalar/LoopSink.cpp
@@ -1,4 +1,4 @@
-//===-- LoopSink.cpp - Loop Sink Pass ------------------------===//
+//===-- LoopSink.cpp - Loop Sink Pass -------------------------------------===//
//
// The LLVM Compiler Infrastructure
//
@@ -28,8 +28,10 @@
// InsertBBs = UseBBs - DomBBs + BB
// For BB in InsertBBs:
// Insert I at BB's beginning
+//
//===----------------------------------------------------------------------===//
+#include "llvm/Transforms/Scalar/LoopSink.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AliasSetTracker.h"
@@ -297,6 +299,42 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI,
return Changed;
}
+PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) {
+ LoopInfo &LI = FAM.getResult<LoopAnalysis>(F);
+ // Nothing to do if there are no loops.
+ if (LI.empty())
+ return PreservedAnalyses::all();
+
+ AAResults &AA = FAM.getResult<AAManager>(F);
+ DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+ BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
+
+ // We want to do a postorder walk over the loops. Since loops are a tree this
+ // is equivalent to a reversed preorder walk and preorder is easy to compute
+ // without recursion. Since we reverse the preorder, we will visit siblings
+ // in reverse program order. This isn't expected to matter at all but is more
+ // consistent with sinking algorithms which generally work bottom-up.
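+ // For example (illustrative): a preorder of {L1, L1.1, L1.2, L2} is popped
+ // off the back of the vector below as {L2, L1.2, L1.1, L1}, so each loop is
+ // visited before its parent.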
+ SmallVector<Loop *, 4> PreorderLoops = LI.getLoopsInPreorder();
+
+ bool Changed = false;
+ do {
+ Loop &L = *PreorderLoops.pop_back_val();
+
+ // Note that we don't pass SCEV here because it is only used to invalidate
+ // loops in SCEV and we don't preserve (or request) SCEV at all making that
+ // unnecessary.
+ Changed |= sinkLoopInvariantInstructions(L, AA, LI, DT, BFI,
+ /*ScalarEvolution*/ nullptr);
+ } while (!PreorderLoops.empty());
+
+ if (!Changed)
+ return PreservedAnalyses::all();
+
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
+}
+
namespace {
struct LegacyLoopSinkPass : public LoopPass {
static char ID;
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 194587a85e7c1..af137f6faa634 100644
--- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -129,6 +129,17 @@ static cl::opt<bool> EnablePhiElim(
"enable-lsr-phielim", cl::Hidden, cl::init(true),
cl::desc("Enable LSR phi elimination"));
+// This flag adds the instruction count to the solution cost comparison.
+static cl::opt<bool> InsnsCost(
+ "lsr-insns-cost", cl::Hidden, cl::init(false),
+ cl::desc("Add instruction count to a LSR cost model"));
+
+// Flag to choose how to narrow a complex LSR solution.
+static cl::opt<bool> LSRExpNarrow(
+ "lsr-exp-narrow", cl::Hidden, cl::init(false),
+ cl::desc("Narrow LSR complex solution using"
+ " expectation of registers number"));
+
#ifndef NDEBUG
// Stress test IV chain generation.
static cl::opt<bool> StressIVChain(
@@ -181,10 +192,11 @@ void RegSortData::print(raw_ostream &OS) const {
OS << "[NumUses=" << UsedByIndices.count() << ']';
}
-LLVM_DUMP_METHOD
-void RegSortData::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void RegSortData::dump() const {
print(errs()); errs() << '\n';
}
+#endif
namespace {
@@ -295,9 +307,13 @@ struct Formula {
/// canonical representation of a formula is
/// 1. BaseRegs.size > 1 implies ScaledReg != NULL and
/// 2. ScaledReg != NULL implies Scale != 1 || !BaseRegs.empty().
+ /// 3. The reg containing the recurrent expr related to the current loop in
+ /// the formula should be put in the ScaledReg.
/// #1 enforces that the scaled register is always used when at least two
/// registers are needed by the formula: e.g., reg1 + reg2 is reg1 + 1 * reg2.
/// #2 enforces that 1 * reg is reg.
+ /// #3 ensures that invariant regs with respect to the current loop can be
+ /// combined together in LSR codegen.
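+ /// For example (illustrative), for the current loop L the formula
+ /// reg({0,+,1}<L>) + reg(a) + reg(b) keeps the addrec for L in ScaledReg so
+ /// that the L-invariant regs a and b stay together in BaseRegs.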
/// This invariant can be temporarily broken while building a formula.
/// However, every formula inserted into the LSRInstance must be in canonical
/// form.
@@ -318,12 +334,14 @@ struct Formula {
void initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE);
- bool isCanonical() const;
+ bool isCanonical(const Loop &L) const;
- void canonicalize();
+ void canonicalize(const Loop &L);
bool unscale();
+ bool hasZeroEnd() const;
+
size_t getNumRegs() const;
Type *getType() const;
@@ -410,16 +428,35 @@ void Formula::initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE) {
BaseRegs.push_back(Sum);
HasBaseReg = true;
}
- canonicalize();
+ canonicalize(*L);
}
/// \brief Check whether or not this formula satisfies the canonical
/// representation.
/// \see Formula::BaseRegs.
-bool Formula::isCanonical() const {
- if (ScaledReg)
- return Scale != 1 || !BaseRegs.empty();
- return BaseRegs.size() <= 1;
+bool Formula::isCanonical(const Loop &L) const {
+ if (!ScaledReg)
+ return BaseRegs.size() <= 1;
+
+ if (Scale != 1)
+ return true;
+
+ if (Scale == 1 && BaseRegs.empty())
+ return false;
+
+ const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg);
+ if (SAR && SAR->getLoop() == &L)
+ return true;
+
+ // If ScaledReg is not a recurrent expr, or it is one whose loop is not the
+ // current loop, while BaseRegs contains a recurrent expr reg related to the
+ // current loop, the formula is not canonical: that reg in BaseRegs should be
+ // swapped with ScaledReg.
+ auto I =
+ find_if(make_range(BaseRegs.begin(), BaseRegs.end()), [&](const SCEV *S) {
+ return isa<const SCEVAddRecExpr>(S) &&
+ (cast<SCEVAddRecExpr>(S)->getLoop() == &L);
+ });
+ return I == BaseRegs.end();
}
/// \brief Helper method to morph a formula into its canonical representation.
@@ -428,21 +465,33 @@ bool Formula::isCanonical() const {
/// field. Otherwise, we would have to do special cases everywhere in LSR
/// to treat reg1 + reg2 + ... the same way as reg1 + 1*reg2 + ...
/// On the other hand, 1*reg should be canonicalized into reg.
-void Formula::canonicalize() {
- if (isCanonical())
+void Formula::canonicalize(const Loop &L) {
+ if (isCanonical(L))
return;
// So far we did not need this case. This is easy to implement but it is
// useless to maintain dead code. Besides, it could hurt compile time.
assert(!BaseRegs.empty() && "1*reg => reg, should not be needed.");
+
// Keep the invariant sum in BaseRegs and one of the variant sum in ScaledReg.
- ScaledReg = BaseRegs.back();
- BaseRegs.pop_back();
- Scale = 1;
- size_t BaseRegsSize = BaseRegs.size();
- size_t Try = 0;
- // If ScaledReg is an invariant, try to find a variant expression.
- while (Try < BaseRegsSize && !isa<SCEVAddRecExpr>(ScaledReg))
- std::swap(ScaledReg, BaseRegs[Try++]);
+ if (!ScaledReg) {
+ ScaledReg = BaseRegs.back();
+ BaseRegs.pop_back();
+ Scale = 1;
+ }
+
+ // If ScaledReg is an invariant with respect to L, find the reg from
+ // BaseRegs containing the recurrent expr related to Loop L. Swap that
+ // reg with ScaledReg.
+ const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg);
+ if (!SAR || SAR->getLoop() != &L) {
+ auto I = find_if(make_range(BaseRegs.begin(), BaseRegs.end()),
+ [&](const SCEV *S) {
+ return isa<const SCEVAddRecExpr>(S) &&
+ (cast<SCEVAddRecExpr>(S)->getLoop() == &L);
+ });
+ if (I != BaseRegs.end())
+ std::swap(ScaledReg, *I);
+ }
}
/// \brief Get rid of the scale in the formula.
@@ -458,6 +507,14 @@ bool Formula::unscale() {
return true;
}
+bool Formula::hasZeroEnd() const {
+ if (UnfoldedOffset || BaseOffset)
+ return false;
+ if (BaseRegs.size() != 1 || ScaledReg)
+ return false;
+ return true;
+}
+
/// Return the total number of register operands used by this formula. This does
/// not include register uses implied by non-constant addrec strides.
size_t Formula::getNumRegs() const {
@@ -534,10 +591,11 @@ void Formula::print(raw_ostream &OS) const {
}
}
-LLVM_DUMP_METHOD
-void Formula::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void Formula::dump() const {
print(errs()); errs() << '\n';
}
+#endif
/// Return true if the given addrec can be sign-extended without changing its
/// value.
@@ -711,7 +769,7 @@ static GlobalValue *ExtractSymbol(const SCEV *&S, ScalarEvolution &SE) {
static bool isAddressUse(Instruction *Inst, Value *OperandVal) {
bool isAddress = isa<LoadInst>(Inst);
if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
- if (SI->getOperand(1) == OperandVal)
+ if (SI->getPointerOperand() == OperandVal)
isAddress = true;
} else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
// Addressing modes can also be folded into prefetches and a variety
@@ -723,6 +781,12 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) {
isAddress = true;
break;
}
+ } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) {
+ if (RMW->getPointerOperand() == OperandVal)
+ isAddress = true;
+ } else if (AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) {
+ if (CmpX->getPointerOperand() == OperandVal)
+ isAddress = true;
}
return isAddress;
}
@@ -735,6 +799,10 @@ static MemAccessTy getAccessType(const Instruction *Inst) {
AccessTy.AddrSpace = SI->getPointerAddressSpace();
} else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
AccessTy.AddrSpace = LI->getPointerAddressSpace();
+ } else if (const AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) {
+ AccessTy.AddrSpace = RMW->getPointerAddressSpace();
+ } else if (const AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) {
+ AccessTy.AddrSpace = CmpX->getPointerAddressSpace();
}
// All pointers have the same requirements, so canonicalize them to an
@@ -875,7 +943,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
const LSRUse &LU, const Formula &F);
// Get the cost of the scaling factor used in F for LU.
static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
- const LSRUse &LU, const Formula &F);
+ const LSRUse &LU, const Formula &F,
+ const Loop &L);
namespace {
@@ -883,6 +952,7 @@ namespace {
class Cost {
/// TODO: Some of these could be merged. Also, a lexical ordering
/// isn't always optimal.
+ unsigned Insns;
unsigned NumRegs;
unsigned AddRecCost;
unsigned NumIVMuls;
@@ -893,8 +963,8 @@ class Cost {
public:
Cost()
- : NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0),
- SetupCost(0), ScaleCost(0) {}
+ : Insns(0), NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0),
+ ImmCost(0), SetupCost(0), ScaleCost(0) {}
bool operator<(const Cost &Other) const;
@@ -903,9 +973,9 @@ public:
#ifndef NDEBUG
// Once any of the metrics loses, they must all remain losers.
bool isValid() {
- return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds
+ return ((Insns | NumRegs | AddRecCost | NumIVMuls | NumBaseAdds
| ImmCost | SetupCost | ScaleCost) != ~0u)
- || ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds
+ || ((Insns & NumRegs & AddRecCost & NumIVMuls & NumBaseAdds
& ImmCost & SetupCost & ScaleCost) == ~0u);
}
#endif
@@ -1067,7 +1137,8 @@ public:
}
bool HasFormulaWithSameRegs(const Formula &F) const;
- bool InsertFormula(const Formula &F);
+ float getNotSelectedProbability(const SCEV *Reg) const;
+ bool InsertFormula(const Formula &F, const Loop &L);
void DeleteFormula(Formula &F);
void RecomputeRegs(size_t LUIdx, RegUseTracker &Reguses);
@@ -1083,17 +1154,23 @@ void Cost::RateRegister(const SCEV *Reg,
const Loop *L,
ScalarEvolution &SE, DominatorTree &DT) {
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) {
- // If this is an addrec for another loop, don't second-guess its addrec phi
- // nodes. LSR isn't currently smart enough to reason about more than one
- // loop at a time. LSR has already run on inner loops, will not run on outer
- // loops, and cannot be expected to change sibling loops.
+ // If this is an addrec for another loop, it should be an invariant
+ // with respect to L since L is the innermost loop (at least
+ // for now LSR only handles innermost loops).
if (AR->getLoop() != L) {
// If the AddRec exists, consider it's register free and leave it alone.
if (isExistingPhi(AR, SE))
return;
- // Otherwise, do not consider this formula at all.
- Lose();
+ // It is bad to allow LSR for current loop to add induction variables
+ // for its sibling loops.
+ if (!AR->getLoop()->contains(L)) {
+ Lose();
+ return;
+ }
+
+ // Otherwise, it will be an invariant with respect to Loop L.
+ ++NumRegs;
return;
}
AddRecCost += 1; /// TODO: This should be a function of the stride.
@@ -1150,8 +1227,11 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,
ScalarEvolution &SE, DominatorTree &DT,
const LSRUse &LU,
SmallPtrSetImpl<const SCEV *> *LoserRegs) {
- assert(F.isCanonical() && "Cost is accurate only for canonical formula");
+ assert(F.isCanonical(*L) && "Cost is accurate only for canonical formula");
// Tally up the registers.
+ unsigned PrevAddRecCost = AddRecCost;
+ unsigned PrevNumRegs = NumRegs;
+ unsigned PrevNumBaseAdds = NumBaseAdds;
if (const SCEV *ScaledReg = F.ScaledReg) {
if (VisitedRegs.count(ScaledReg)) {
Lose();
@@ -1171,6 +1251,18 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,
return;
}
+ // Treat every new register that exceeds TTI.getNumberOfRegisters() - 1 as an
+ // additional instruction (at least a fill).
+ unsigned TTIRegNum = TTI.getNumberOfRegisters(false) - 1;
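+ // For example, with TTIRegNum == 15, growing from 14 to 18 registers below
+ // adds 18 - 15 == 3 instructions, while growing from 17 to 18 adds just one.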
+ if (NumRegs > TTIRegNum) {
+ // If the previous count already exceeded TTIRegNum, only the newly added
+ // registers can add new instructions.
+ if (PrevNumRegs > TTIRegNum)
+ Insns += (NumRegs - PrevNumRegs);
+ else
+ Insns += (NumRegs - TTIRegNum);
+ }
+
// Determine how many (unfolded) adds we'll need inside the loop.
size_t NumBaseParts = F.getNumRegs();
if (NumBaseParts > 1)
@@ -1181,7 +1273,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,
NumBaseAdds += (F.UnfoldedOffset != 0);
// Accumulate non-free scaling amounts.
- ScaleCost += getScalingFactorCost(TTI, LU, F);
+ ScaleCost += getScalingFactorCost(TTI, LU, F, *L);
// Tally up the non-zero immediates.
for (const LSRFixup &Fixup : LU.Fixups) {
@@ -1199,11 +1291,30 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,
!TTI.isFoldableMemAccessOffset(Fixup.UserInst, Offset))
NumBaseAdds++;
}
+
+ // If the ICmpZero formula does not end in 0, it cannot be replaced by just an
+ // add or sub. We'll need to compare the final result of the AddRec, which
+ // takes an additional instruction.
+ // For -10 + {0, +, 1}:
+ // i = i + 1;
+ // cmp i, 10
+ //
+ // For {-10, +, 1}:
+ // i = i + 1;
+ if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd())
+ Insns++;
+ // Each new AddRec adds 1 instruction to calculation.
+ Insns += (AddRecCost - PrevAddRecCost);
+
+ // BaseAdds adds instructions for unfolded registers.
+ if (LU.Kind != LSRUse::ICmpZero)
+ Insns += NumBaseAdds - PrevNumBaseAdds;
assert(isValid() && "invalid cost");
}
/// Set this cost to a losing value.
void Cost::Lose() {
+ Insns = ~0u;
NumRegs = ~0u;
AddRecCost = ~0u;
NumIVMuls = ~0u;
@@ -1215,6 +1326,8 @@ void Cost::Lose() {
/// Choose the lower cost.
bool Cost::operator<(const Cost &Other) const {
+ if (InsnsCost && Insns != Other.Insns)
+ return Insns < Other.Insns;
return std::tie(NumRegs, AddRecCost, NumIVMuls, NumBaseAdds, ScaleCost,
ImmCost, SetupCost) <
std::tie(Other.NumRegs, Other.AddRecCost, Other.NumIVMuls,
@@ -1223,6 +1336,7 @@ bool Cost::operator<(const Cost &Other) const {
}
void Cost::print(raw_ostream &OS) const {
+ OS << Insns << " instruction" << (Insns == 1 ? " " : "s ");
OS << NumRegs << " reg" << (NumRegs == 1 ? "" : "s");
if (AddRecCost != 0)
OS << ", with addrec cost " << AddRecCost;
@@ -1239,10 +1353,11 @@ void Cost::print(raw_ostream &OS) const {
OS << ", plus " << SetupCost << " setup cost";
}
-LLVM_DUMP_METHOD
-void Cost::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void Cost::dump() const {
print(errs()); errs() << '\n';
}
+#endif
LSRFixup::LSRFixup()
: UserInst(nullptr), OperandValToReplace(nullptr),
@@ -1285,10 +1400,11 @@ void LSRFixup::print(raw_ostream &OS) const {
OS << ", Offset=" << Offset;
}
-LLVM_DUMP_METHOD
-void LSRFixup::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void LSRFixup::dump() const {
print(errs()); errs() << '\n';
}
+#endif
/// Test whether this use as a formula which has the same registers as the given
/// formula.
@@ -1300,10 +1416,19 @@ bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const {
return Uniquifier.count(Key);
}
+/// Returns the probability of selecting a formula that does not reference Reg.
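+/// For example, if 2 of a use's 3 formulae reference Reg, the result is
+/// (3 - 2) / 3 == 1/3.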
+float LSRUse::getNotSelectedProbability(const SCEV *Reg) const {
+ unsigned FNum = 0;
+ for (const Formula &F : Formulae)
+ if (F.referencesReg(Reg))
+ FNum++;
+ return ((float)(Formulae.size() - FNum)) / Formulae.size();
+}
+
/// If the given formula has not yet been inserted, add it to the list, and
/// return true. Return false otherwise. The formula must be in canonical form.
-bool LSRUse::InsertFormula(const Formula &F) {
- assert(F.isCanonical() && "Invalid canonical representation");
+bool LSRUse::InsertFormula(const Formula &F, const Loop &L) {
+ assert(F.isCanonical(L) && "Invalid canonical representation");
if (!Formulae.empty() && RigidFormula)
return false;
@@ -1391,10 +1516,11 @@ void LSRUse::print(raw_ostream &OS) const {
OS << ", widest fixup type: " << *WidestFixupType;
}
-LLVM_DUMP_METHOD
-void LSRUse::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void LSRUse::dump() const {
print(errs()); errs() << '\n';
}
+#endif
static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
@@ -1472,7 +1598,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
int64_t MinOffset, int64_t MaxOffset,
LSRUse::KindType Kind, MemAccessTy AccessTy,
- const Formula &F) {
+ const Formula &F, const Loop &L) {
// For the purpose of isAMCompletelyFolded either having a canonical formula
// or a scale not equal to zero is correct.
// Problems may arise from non canonical formulae having a scale == 0.
@@ -1480,7 +1606,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
// However, when we generate the scaled formulae, we first check that the
// scaling factor is profitable before computing the actual ScaledReg for
// compile time sake.
- assert((F.isCanonical() || F.Scale != 0));
+ assert((F.isCanonical(L) || F.Scale != 0));
return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy,
F.BaseGV, F.BaseOffset, F.HasBaseReg, F.Scale);
}
@@ -1515,14 +1641,15 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
}
static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
- const LSRUse &LU, const Formula &F) {
+ const LSRUse &LU, const Formula &F,
+ const Loop &L) {
if (!F.Scale)
return 0;
// If the use is not completely folded in that instruction, we will have to
// pay an extra cost only for scale != 1.
if (!isAMCompletelyFolded(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
- LU.AccessTy, F))
+ LU.AccessTy, F, L))
return F.Scale != 1;
switch (LU.Kind) {
@@ -1772,6 +1899,7 @@ class LSRInstance {
void NarrowSearchSpaceByDetectingSupersets();
void NarrowSearchSpaceByCollapsingUnrolledCode();
void NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters();
+ void NarrowSearchSpaceByDeletingCostlyFormulas();
void NarrowSearchSpaceByPickingWinnerRegs();
void NarrowSearchSpaceUsingHeuristics();
@@ -2492,7 +2620,12 @@ static Value *getWideOperand(Value *Oper) {
static bool isCompatibleIVType(Value *LVal, Value *RVal) {
Type *LType = LVal->getType();
Type *RType = RVal->getType();
- return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy());
+ return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy() &&
+ // Different address spaces means (possibly)
+ // different types of the pointer implementation,
+ // e.g. i16 vs i32 so disallow that.
+ (LType->getPointerAddressSpace() ==
+ RType->getPointerAddressSpace()));
}
/// Return an approximation of this SCEV expression's "base", or NULL for any
@@ -2989,8 +3122,10 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
User::op_iterator UseI =
find(UserInst->operands(), U.getOperandValToReplace());
assert(UseI != UserInst->op_end() && "cannot find IV operand");
- if (IVIncSet.count(UseI))
+ if (IVIncSet.count(UseI)) {
+ DEBUG(dbgs() << "Use is in profitable chain: " << **UseI << '\n');
continue;
+ }
LSRUse::KindType Kind = LSRUse::Basic;
MemAccessTy AccessTy;
@@ -3025,8 +3160,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
if (SE.isLoopInvariant(N, L) && isSafeToExpand(N, SE)) {
// S is normalized, so normalize N before folding it into S
// to keep the result normalized.
- N = TransformForPostIncUse(Normalize, N, CI, nullptr,
- TmpPostIncLoops, SE, DT);
+ N = normalizeForPostIncUse(N, TmpPostIncLoops, SE);
Kind = LSRUse::ICmpZero;
S = SE.getMinusSCEV(N, S);
}
@@ -3108,7 +3242,8 @@ bool LSRInstance::InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F) {
// Do not insert formula that we will not be able to expand.
assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F) &&
"Formula is illegal");
- if (!LU.InsertFormula(F))
+
+ if (!LU.InsertFormula(F, *L))
return false;
CountRegisters(F, LUIdx);
@@ -3347,7 +3482,7 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx,
F.BaseRegs.push_back(*J);
// We may have changed the number of register in base regs, adjust the
// formula accordingly.
- F.canonicalize();
+ F.canonicalize(*L);
if (InsertFormula(LU, LUIdx, F))
// If that formula hadn't been seen before, recurse to find more like
@@ -3359,7 +3494,7 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx,
/// Split out subexpressions from adds and the bases of addrecs.
void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx,
Formula Base, unsigned Depth) {
- assert(Base.isCanonical() && "Input must be in the canonical form");
+ assert(Base.isCanonical(*L) && "Input must be in the canonical form");
// Arbitrarily cap recursion to protect compile time.
if (Depth >= 3)
return;
@@ -3400,7 +3535,7 @@ void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx,
// rather than proceed with zero in a register.
if (!Sum->isZero()) {
F.BaseRegs.push_back(Sum);
- F.canonicalize();
+ F.canonicalize(*L);
(void)InsertFormula(LU, LUIdx, F);
}
}
@@ -3457,7 +3592,7 @@ void LSRInstance::GenerateConstantOffsetsImpl(
F.ScaledReg = nullptr;
} else
F.deleteBaseReg(F.BaseRegs[Idx]);
- F.canonicalize();
+ F.canonicalize(*L);
} else if (IsScaledReg)
F.ScaledReg = NewG;
else
@@ -3620,10 +3755,10 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) {
if (LU.Kind == LSRUse::ICmpZero &&
!Base.HasBaseReg && Base.BaseOffset == 0 && !Base.BaseGV)
continue;
- // For each addrec base reg, apply the scale, if possible.
- for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i)
- if (const SCEVAddRecExpr *AR =
- dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i])) {
+ // For each addrec base reg, if its loop is current loop, apply the scale.
+ for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) {
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i]);
+ if (AR && (AR->getLoop() == L || LU.AllFixupsOutsideLoop)) {
const SCEV *FactorS = SE.getConstant(IntTy, Factor);
if (FactorS->isZero())
continue;
@@ -3637,11 +3772,17 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) {
// The canonical representation of 1*reg is reg, which is already in
// Base. In that case, do not try to insert the formula, it will be
// rejected anyway.
- if (F.Scale == 1 && F.BaseRegs.empty())
+ if (F.Scale == 1 && (F.BaseRegs.empty() ||
+ (AR->getLoop() != L && LU.AllFixupsOutsideLoop)))
continue;
+ // If AllFixupsOutsideLoop is true and F.Scale is 1, we may generate a
+ // non-canonical Formula whose ScaledReg's loop is not L.
+ if (F.Scale == 1 && LU.AllFixupsOutsideLoop)
+ F.canonicalize(*L);
(void)InsertFormula(LU, LUIdx, F);
}
}
+ }
}
}
@@ -3697,10 +3838,11 @@ void WorkItem::print(raw_ostream &OS) const {
<< " , add offset " << Imm;
}
-LLVM_DUMP_METHOD
-void WorkItem::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void WorkItem::dump() const {
print(errs()); errs() << '\n';
}
+#endif
/// Look for registers which are a constant distance apart and try to form reuse
/// opportunities between them.
@@ -3821,7 +3963,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() {
continue;
// OK, looks good.
- NewF.canonicalize();
+ NewF.canonicalize(*this->L);
(void)InsertFormula(LU, LUIdx, NewF);
} else {
// Use the immediate in a base register.
@@ -3853,7 +3995,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() {
goto skip_formula;
// Ok, looks good.
- NewF.canonicalize();
+ NewF.canonicalize(*this->L);
(void)InsertFormula(LU, LUIdx, NewF);
break;
skip_formula:;
@@ -4165,6 +4307,144 @@ void LSRInstance::NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(){
}
}
+/// This function deletes formulas with a high expected register count.
+/// Assuming we don't know the value of each formula (all inefficient ones
+/// have already been deleted), generate the probability of not being selected
+/// for each register.
+/// For example,
+/// Use1:
+/// reg(a) + reg({0,+,1})
+/// reg(a) + reg({-1,+,1}) + 1
+/// reg({a,+,1})
+/// Use2:
+/// reg(b) + reg({0,+,1})
+/// reg(b) + reg({-1,+,1}) + 1
+/// reg({b,+,1})
+/// Use3:
+/// reg(c) + reg(b) + reg({0,+,1})
+/// reg(c) + reg({b,+,1})
+///
+/// Probability of not selecting
+/// Use1 Use2 Use3
+/// reg(a) (1/3) * 1 * 1
+/// reg(b) 1 * (1/3) * (1/2)
+/// reg({0,+,1}) (2/3) * (2/3) * (1/2)
+/// reg({-1,+,1}) (2/3) * (2/3) * 1
+/// reg({a,+,1}) (2/3) * 1 * 1
+/// reg({b,+,1}) 1 * (2/3) * (2/3)
+/// reg(c) 1 * 1 * 0
+///
+/// Now compute the expected register count for each formula.
+/// Note that for each use we exclude the probability of not selecting for that
+/// use itself. For example, for Use1 the probability for reg(a) is just 1 * 1
+/// (excluding the 1/3 probability of not selecting for Use1).
+/// Use1:
+/// reg(a) + reg({0,+,1}) 1 + 1/3 -- to be deleted
+/// reg(a) + reg({-1,+,1}) + 1 1 + 4/9 -- to be deleted
+/// reg({a,+,1}) 1
+/// Use2:
+/// reg(b) + reg({0,+,1}) 1/2 + 1/3 -- to be deleted
+/// reg(b) + reg({-1,+,1}) + 1 1/2 + 2/3 -- to be deleted
+/// reg({b,+,1}) 2/3
+/// Use3:
+/// reg(c) + reg(b) + reg({0,+,1}) 1 + 1/3 + 4/9 -- to be deleted
+/// reg(c) + reg({b,+,1}) 1 + 2/3
+
+void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() {
+ if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+ return;
+ // Ok, we have too many formulae on our hands to conveniently handle.
+ // Use a rough heuristic to thin out the list.
+
+ // Set of regs which will be 100% used in the final solution.
+ // Such a reg is used in each formula of some use (in the example above this
+ // is reg(c)), so we can skip it in the calculations.
+ SmallPtrSet<const SCEV *, 4> UniqRegs;
+ DEBUG(dbgs() << "The search space is too complex.\n");
+
+ // Map each register to the probability of it not being selected.
+ DenseMap <const SCEV *, float> RegNumMap;
+ for (const SCEV *Reg : RegUses) {
+ if (UniqRegs.count(Reg))
+ continue;
+ float PNotSel = 1;
+ for (const LSRUse &LU : Uses) {
+ if (!LU.Regs.count(Reg))
+ continue;
+ float P = LU.getNotSelectedProbability(Reg);
+ if (P != 0.0)
+ PNotSel *= P;
+ else
+ UniqRegs.insert(Reg);
+ }
+ RegNumMap.insert(std::make_pair(Reg, PNotSel));
+ }
+
+ DEBUG(dbgs() << "Narrowing the search space by deleting costly formulas\n");
+
+ // Delete formulas whose expected register count is high.
+ for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+ LSRUse &LU = Uses[LUIdx];
+ // If nothing to delete - continue.
+ if (LU.Formulae.size() < 2)
+ continue;
+ // This is a temporary solution to test performance. The float should be
+ // replaced with a rounding-independent type (based on integers) to avoid
+ // different results for different target builds.
+ float FMinRegNum = LU.Formulae[0].getNumRegs();
+ float FMinARegNum = LU.Formulae[0].getNumRegs();
+ size_t MinIdx = 0;
+ for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) {
+ Formula &F = LU.Formulae[i];
+ float FRegNum = 0;
+ float FARegNum = 0;
+ for (const SCEV *BaseReg : F.BaseRegs) {
+ if (UniqRegs.count(BaseReg))
+ continue;
+ FRegNum += RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg);
+ if (isa<SCEVAddRecExpr>(BaseReg))
+ FARegNum +=
+ RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg);
+ }
+ if (const SCEV *ScaledReg = F.ScaledReg) {
+ if (!UniqRegs.count(ScaledReg)) {
+ FRegNum +=
+ RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg);
+ if (isa<SCEVAddRecExpr>(ScaledReg))
+ FARegNum +=
+ RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg);
+ }
+ }
+ if (FMinRegNum > FRegNum ||
+ (FMinRegNum == FRegNum && FMinARegNum > FARegNum)) {
+ FMinRegNum = FRegNum;
+ FMinARegNum = FARegNum;
+ MinIdx = i;
+ }
+ }
+ DEBUG(dbgs() << " The formula "; LU.Formulae[MinIdx].print(dbgs());
+ dbgs() << " with min reg num " << FMinRegNum << '\n');
+ if (MinIdx != 0)
+ std::swap(LU.Formulae[MinIdx], LU.Formulae[0]);
+ while (LU.Formulae.size() != 1) {
+ DEBUG(dbgs() << " Deleting "; LU.Formulae.back().print(dbgs());
+ dbgs() << '\n');
+ LU.Formulae.pop_back();
+ }
+ LU.RecomputeRegs(LUIdx, RegUses);
+ assert(LU.Formulae.size() == 1 && "Should be exactly 1 min regs formula");
+ Formula &F = LU.Formulae[0];
+ DEBUG(dbgs() << " Leaving only "; F.print(dbgs()); dbgs() << '\n');
+ // When we choose the formula, the regs become unique.
+ UniqRegs.insert(F.BaseRegs.begin(), F.BaseRegs.end());
+ if (F.ScaledReg)
+ UniqRegs.insert(F.ScaledReg);
+ }
+ DEBUG(dbgs() << "After pre-selection:\n";
+ print_uses(dbgs()));
+}
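A minimal standalone sketch of the per-formula expectation that FRegNum accumulates above, assuming simplified container types; the names here are illustrative, not LSR's own:

  #include <map>
  #include <string>
  #include <vector>

  // GlobalPNotSel[Reg]: product over all uses of P(Reg not selected for use).
  // UsePNotSel[Reg]:    P(Reg not selected for the one use being scored).
  float expectedRegCount(const std::vector<std::string> &FormulaRegs,
                         const std::map<std::string, float> &GlobalPNotSel,
                         const std::map<std::string, float> &UsePNotSel) {
    float E = 0;
    // Divide out this use's own "not selected" factor, as described above.
    for (const std::string &Reg : FormulaRegs)
      E += GlobalPNotSel.at(Reg) / UsePNotSel.at(Reg);
    return E;
  }

  // For Use1 and formula reg(a) + reg({0,+,1}) in the worked example:
  //   reg(a):       (1/3) / (1/3) = 1
  //   reg({0,+,1}): (2/9) / (2/3) = 1/3   => expectation 1 + 1/3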
+
+
/// Pick a register which seems likely to be profitable, and then in any use
/// which has any reference to that register, delete all formulae which do not
/// reference that register.
@@ -4237,7 +4517,10 @@ void LSRInstance::NarrowSearchSpaceUsingHeuristics() {
NarrowSearchSpaceByDetectingSupersets();
NarrowSearchSpaceByCollapsingUnrolledCode();
NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters();
- NarrowSearchSpaceByPickingWinnerRegs();
+ if (LSRExpNarrow)
+ NarrowSearchSpaceByDeletingCostlyFormulas();
+ else
+ NarrowSearchSpaceByPickingWinnerRegs();
}
/// This is the recursive solver.
@@ -4515,11 +4798,7 @@ Value *LSRInstance::Expand(const LSRUse &LU,
assert(!Reg->isZero() && "Zero allocated in a base register!");
// If we're expanding for a post-inc user, make the post-inc adjustment.
- PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops);
- Reg = TransformForPostIncUse(Denormalize, Reg,
- LF.UserInst, LF.OperandValToReplace,
- Loops, SE, DT);
-
+ Reg = denormalizeForPostIncUse(Reg, LF.PostIncLoops, SE);
Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, nullptr)));
}
@@ -4530,9 +4809,7 @@ Value *LSRInstance::Expand(const LSRUse &LU,
// If we're expanding for a post-inc user, make the post-inc adjustment.
PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops);
- ScaledS = TransformForPostIncUse(Denormalize, ScaledS,
- LF.UserInst, LF.OperandValToReplace,
- Loops, SE, DT);
+ ScaledS = denormalizeForPostIncUse(ScaledS, Loops, SE);
if (LU.Kind == LSRUse::ICmpZero) {
// Expand ScaleReg as if it was part of the base regs.
@@ -4975,10 +5252,11 @@ void LSRInstance::print(raw_ostream &OS) const {
print_uses(OS);
}
-LLVM_DUMP_METHOD
-void LSRInstance::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void LSRInstance::dump() const {
print(errs()); errs() << '\n';
}
+#endif
namespace {
diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp
index c7f91226d2229..62aa6ee48069d 100644
--- a/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -44,7 +44,11 @@ using namespace llvm;
static cl::opt<unsigned>
UnrollThreshold("unroll-threshold", cl::Hidden,
- cl::desc("The baseline cost threshold for loop unrolling"));
+ cl::desc("The cost threshold for loop unrolling"));
+
+static cl::opt<unsigned> UnrollPartialThreshold(
+ "unroll-partial-threshold", cl::Hidden,
+ cl::desc("The cost threshold for partial loop unrolling"));
static cl::opt<unsigned> UnrollMaxPercentThresholdBoost(
"unroll-max-percent-threshold-boost", cl::init(400), cl::Hidden,
@@ -106,10 +110,19 @@ static cl::opt<unsigned> FlatLoopTripCountThreshold(
"aggressively unrolled."));
static cl::opt<bool>
- UnrollAllowPeeling("unroll-allow-peeling", cl::Hidden,
+ UnrollAllowPeeling("unroll-allow-peeling", cl::init(true), cl::Hidden,
cl::desc("Allows loops to be peeled when the dynamic "
"trip count is known to be low."));
+// This option isn't ever intended to be enabled; it serves to allow
+// experiments to check the assumptions about when this kind of revisit is
+// necessary.
+static cl::opt<bool> UnrollRevisitChildLoops(
+ "unroll-revisit-child-loops", cl::Hidden,
+ cl::desc("Enqueue and re-visit child loops in the loop PM after unrolling. "
+ "This shouldn't typically be needed as child loops (or their "
+ "clones) were already visited."));
+
/// A magic value for use with the Threshold parameter to indicate
/// that the loop unroll should be performed regardless of how much
/// code expansion would result.
@@ -118,16 +131,17 @@ static const unsigned NoThreshold = UINT_MAX;
/// Gather the various unrolling parameters based on the defaults, compiler
/// flags, TTI overrides and user specified parameters.
static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences(
- Loop *L, const TargetTransformInfo &TTI, Optional<unsigned> UserThreshold,
- Optional<unsigned> UserCount, Optional<bool> UserAllowPartial,
- Optional<bool> UserRuntime, Optional<bool> UserUpperBound) {
+ Loop *L, const TargetTransformInfo &TTI, int OptLevel,
+ Optional<unsigned> UserThreshold, Optional<unsigned> UserCount,
+ Optional<bool> UserAllowPartial, Optional<bool> UserRuntime,
+ Optional<bool> UserUpperBound) {
TargetTransformInfo::UnrollingPreferences UP;
// Set up the defaults
- UP.Threshold = 150;
+ UP.Threshold = OptLevel > 2 ? 300 : 150;
UP.MaxPercentThresholdBoost = 400;
UP.OptSizeThreshold = 0;
- UP.PartialThreshold = UP.Threshold;
+ UP.PartialThreshold = 150;
UP.PartialOptSizeThreshold = 0;
UP.Count = 0;
UP.PeelCount = 0;
@@ -141,7 +155,7 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences(
UP.AllowExpensiveTripCount = false;
UP.Force = false;
UP.UpperBound = false;
- UP.AllowPeeling = false;
+ UP.AllowPeeling = true;
// Override with any target specific settings
TTI.getUnrollingPreferences(L, UP);
@@ -153,10 +167,10 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences(
}
// Apply any user values specified by cl::opt
- if (UnrollThreshold.getNumOccurrences() > 0) {
+ if (UnrollThreshold.getNumOccurrences() > 0)
UP.Threshold = UnrollThreshold;
- UP.PartialThreshold = UnrollThreshold;
- }
+ if (UnrollPartialThreshold.getNumOccurrences() > 0)
+ UP.PartialThreshold = UnrollPartialThreshold;
if (UnrollMaxPercentThresholdBoost.getNumOccurrences() > 0)
UP.MaxPercentThresholdBoost = UnrollMaxPercentThresholdBoost;
if (UnrollMaxCount.getNumOccurrences() > 0)
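A short sketch of the layering this hunk changes: OptLevel-aware defaults, then target overrides, then the now-independent threshold flags. The struct and flag plumbing below are illustrative stand-ins, not the real UnrollingPreferences API:

  struct Prefs {
    unsigned Threshold = 150;
    unsigned PartialThreshold = 150;
  };

  Prefs gather(int OptLevel, bool HasThresholdFlag, unsigned ThresholdFlag,
               bool HasPartialFlag, unsigned PartialFlag) {
    Prefs P;
    P.Threshold = OptLevel > 2 ? 300 : 150; // 1. defaults, now OptLevel-aware
    // 2. target overrides would be applied here (TTI.getUnrollingPreferences)
    if (HasThresholdFlag)                   // 3. -unroll-threshold no longer
      P.Threshold = ThresholdFlag;          //    clobbers PartialThreshold
    if (HasPartialFlag)                     // 4. new -unroll-partial-threshold
      P.PartialThreshold = PartialFlag;
    return P;
  }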
@@ -495,7 +509,7 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT,
KnownSucc = SI->getSuccessor(0);
else if (ConstantInt *SimpleCondVal =
dyn_cast<ConstantInt>(SimpleCond))
- KnownSucc = SI->findCaseValue(SimpleCondVal).getCaseSuccessor();
+ KnownSucc = SI->findCaseValue(SimpleCondVal)->getCaseSuccessor();
}
}
if (KnownSucc) {
@@ -770,7 +784,15 @@ static bool computeUnrollCount(
}
}
- // 4rd priority is partial unrolling.
+ // 4th priority is loop peeling
+ computePeelCount(L, LoopSize, UP, TripCount);
+ if (UP.PeelCount) {
+ UP.Runtime = false;
+ UP.Count = 1;
+ return ExplicitUnroll;
+ }
+
+ // 5th priority is partial unrolling.
// Try partial unroll only when TripCount could be statically calculated.
if (TripCount) {
UP.Partial |= ExplicitUnroll;
@@ -833,14 +855,6 @@ static bool computeUnrollCount(
<< "Unable to fully unroll loop as directed by unroll(full) pragma "
"because loop has a runtime trip count.");
- // 5th priority is loop peeling
- computePeelCount(L, LoopSize, UP);
- if (UP.PeelCount) {
- UP.Runtime = false;
- UP.Count = 1;
- return ExplicitUnroll;
- }
-
// 6th priority is runtime unrolling.
// Don't unroll a runtime trip count loop when it is disabled.
if (HasRuntimeUnrollDisablePragma(L)) {
@@ -914,7 +928,7 @@ static bool computeUnrollCount(
static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
ScalarEvolution *SE, const TargetTransformInfo &TTI,
AssumptionCache &AC, OptimizationRemarkEmitter &ORE,
- bool PreserveLCSSA,
+ bool PreserveLCSSA, int OptLevel,
Optional<unsigned> ProvidedCount,
Optional<unsigned> ProvidedThreshold,
Optional<bool> ProvidedAllowPartial,
@@ -934,7 +948,7 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
bool NotDuplicatable;
bool Convergent;
TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
- L, TTI, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial,
+ L, TTI, OptLevel, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial,
ProvidedRuntime, ProvidedUpperBound);
// Exit early if unrolling is disabled.
if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0))
@@ -1034,16 +1048,17 @@ namespace {
class LoopUnroll : public LoopPass {
public:
static char ID; // Pass ID, replacement for typeid
- LoopUnroll(Optional<unsigned> Threshold = None,
+ LoopUnroll(int OptLevel = 2, Optional<unsigned> Threshold = None,
Optional<unsigned> Count = None,
Optional<bool> AllowPartial = None, Optional<bool> Runtime = None,
Optional<bool> UpperBound = None)
- : LoopPass(ID), ProvidedCount(std::move(Count)),
+ : LoopPass(ID), OptLevel(OptLevel), ProvidedCount(std::move(Count)),
ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial),
ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound) {
initializeLoopUnrollPass(*PassRegistry::getPassRegistry());
}
+ int OptLevel;
Optional<unsigned> ProvidedCount;
Optional<unsigned> ProvidedThreshold;
Optional<bool> ProvidedAllowPartial;
@@ -1068,7 +1083,7 @@ public:
OptimizationRemarkEmitter ORE(&F);
bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
- return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA,
+ return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel,
ProvidedCount, ProvidedThreshold,
ProvidedAllowPartial, ProvidedRuntime,
ProvidedUpperBound);
@@ -1094,26 +1109,27 @@ INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false)
-Pass *llvm::createLoopUnrollPass(int Threshold, int Count, int AllowPartial,
- int Runtime, int UpperBound) {
+Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count,
+ int AllowPartial, int Runtime,
+ int UpperBound) {
// TODO: It would make more sense for this function to take the optionals
// directly, but that's dangerous since it would silently break out of tree
// callers.
- return new LoopUnroll(Threshold == -1 ? None : Optional<unsigned>(Threshold),
- Count == -1 ? None : Optional<unsigned>(Count),
- AllowPartial == -1 ? None
- : Optional<bool>(AllowPartial),
- Runtime == -1 ? None : Optional<bool>(Runtime),
- UpperBound == -1 ? None : Optional<bool>(UpperBound));
+ return new LoopUnroll(
+ OptLevel, Threshold == -1 ? None : Optional<unsigned>(Threshold),
+ Count == -1 ? None : Optional<unsigned>(Count),
+ AllowPartial == -1 ? None : Optional<bool>(AllowPartial),
+ Runtime == -1 ? None : Optional<bool>(Runtime),
+ UpperBound == -1 ? None : Optional<bool>(UpperBound));
}
-Pass *llvm::createSimpleLoopUnrollPass() {
- return llvm::createLoopUnrollPass(-1, -1, 0, 0, 0);
+Pass *llvm::createSimpleLoopUnrollPass(int OptLevel) {
+ return llvm::createLoopUnrollPass(OptLevel, -1, -1, 0, 0, 0);
}
PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
- LPMUpdater &) {
+ LPMUpdater &Updater) {
const auto &FAM =
AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
Function *F = L.getHeader()->getParent();
@@ -1124,12 +1140,84 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
report_fatal_error("LoopUnrollPass: OptimizationRemarkEmitterAnalysis not "
"cached at a higher level");
- bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, &AR.SE, AR.TTI, AR.AC, *ORE,
- /*PreserveLCSSA*/ true, ProvidedCount,
- ProvidedThreshold, ProvidedAllowPartial,
- ProvidedRuntime, ProvidedUpperBound);
-
+ // Keep track of the previous loop structure so we can identify new loops
+ // created by unrolling.
+ Loop *ParentL = L.getParentLoop();
+ SmallPtrSet<Loop *, 4> OldLoops;
+ if (ParentL)
+ OldLoops.insert(ParentL->begin(), ParentL->end());
+ else
+ OldLoops.insert(AR.LI.begin(), AR.LI.end());
+
+ // The API here is quite complex to call, but there are only two interesting
+ // states we support: partial and full (or "simple") unrolling. However, to
+ // enable these things we actually pass "None" in for the optional to avoid
+ // providing an explicit choice.
+ Optional<bool> AllowPartialParam, RuntimeParam, UpperBoundParam;
+ if (!AllowPartialUnrolling)
+ AllowPartialParam = RuntimeParam = UpperBoundParam = false;
+ bool Changed = tryToUnrollLoop(
+ &L, AR.DT, &AR.LI, &AR.SE, AR.TTI, AR.AC, *ORE,
+ /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None,
+ /*Threshold*/ None, AllowPartialParam, RuntimeParam, UpperBoundParam);
if (!Changed)
return PreservedAnalyses::all();
+
+ // The parent must not be damaged by unrolling!
+#ifndef NDEBUG
+ if (ParentL)
+ ParentL->verifyLoop();
+#endif
+
+ // Unrolling can do several things to introduce new loops into a loop nest:
+ // - Partial unrolling clones child loops within the current loop. If it
+ // uses a remainder, then it can also create any number of sibling loops.
+ // - Full unrolling clones child loops within the current loop but then
+ // removes the current loop making all of the children appear to be new
+ // sibling loops.
+ // - Loop peeling can directly introduce new sibling loops by peeling one
+ // iteration.
+ //
+ // When a new loop appears as a sibling loop, either from peeling an
+ // iteration or fully unrolling, its nesting structure has fundamentally
+ // changed and we want to revisit it to reflect that.
+ //
+ // When unrolling has removed the current loop, we need to tell the
+ // infrastructure that it is gone.
+ //
+ // Finally, we support a debugging/testing mode where we revisit child loops
+ // as well. These are not expected to require further optimizations as either
+ // they or the loop they were cloned from have been directly visited already.
+ // But the debugging mode allows us to check this assumption.
+ bool IsCurrentLoopValid = false;
+ SmallVector<Loop *, 4> SibLoops;
+ if (ParentL)
+ SibLoops.append(ParentL->begin(), ParentL->end());
+ else
+ SibLoops.append(AR.LI.begin(), AR.LI.end());
+ erase_if(SibLoops, [&](Loop *SibLoop) {
+ if (SibLoop == &L) {
+ IsCurrentLoopValid = true;
+ return true;
+ }
+
+ // Otherwise erase the loop from the list if it was in the old loops.
+ return OldLoops.count(SibLoop) != 0;
+ });
+ Updater.addSiblingLoops(SibLoops);
+
+ if (!IsCurrentLoopValid) {
+ Updater.markLoopAsDeleted(L);
+ } else {
+ // We can only walk child loops if the current loop remained valid.
+ if (UnrollRevisitChildLoops) {
+ // Walk *all* of the child loops. This is a highly speculative mode
+ // anyway, so look for any simplifications that arose from partial
+ // unrolling or peeling off iterations.
+ SmallVector<Loop *, 4> ChildLoops(L.begin(), L.end());
+ Updater.addChildLoops(ChildLoops);
+ }
+ }
+
return getLoopPassPreservedAnalyses();
}
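A standalone sketch of the sibling-loop bookkeeping added to run() above, using plain std containers; Loop is opaque here and every name is illustrative:

  #include <algorithm>
  #include <set>
  #include <vector>

  struct Loop; // opaque stand-in

  // OldLoops holds the loops recorded at the relevant nesting level before
  // unrolling. Afterwards, anything at that level that is neither the current
  // loop nor in OldLoops must be a new sibling from peeling / full unrolling.
  std::vector<Loop *> newSiblings(const std::set<Loop *> &OldLoops,
                                  std::vector<Loop *> LoopsNow, Loop *Current,
                                  bool &CurrentStillValid) {
    CurrentStillValid = false;
    auto NewEnd = std::remove_if(LoopsNow.begin(), LoopsNow.end(),
                                 [&](Loop *L) {
                                   if (L == Current) {
                                     CurrentStillValid = true;
                                     return true; // never re-enqueue ourselves
                                   }
                                   return OldLoops.count(L) != 0; // not new
                                 });
    LoopsNow.erase(NewEnd, LoopsNow.end());
    return LoopsNow; // what would feed Updater.addSiblingLoops(...)
  }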
diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp
index 76fe91884c7b8..a99c9999c6191 100644
--- a/lib/Transforms/Scalar/LoopUnswitch.cpp
+++ b/lib/Transforms/Scalar/LoopUnswitch.cpp
@@ -33,6 +33,7 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
@@ -47,6 +48,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/CommandLine.h"
@@ -77,19 +79,6 @@ static cl::opt<unsigned>
Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"),
cl::init(100), cl::Hidden);
-static cl::opt<bool>
-LoopUnswitchWithBlockFrequency("loop-unswitch-with-block-frequency",
- cl::init(false), cl::Hidden,
- cl::desc("Enable the use of the block frequency analysis to access PGO "
- "heuristics to minimize code growth in cold regions."));
-
-static cl::opt<unsigned>
-ColdnessThreshold("loop-unswitch-coldness-threshold", cl::init(1), cl::Hidden,
- cl::desc("Coldness threshold in percentage. The loop header frequency "
- "(relative to the entry frequency) is compared with this "
- "threshold to determine if non-trivial unswitching should be "
- "enabled."));
-
namespace {
class LUAnalysisCache {
@@ -174,13 +163,6 @@ namespace {
LUAnalysisCache BranchesInfo;
- bool EnabledPGO;
-
- // BFI and ColdEntryFreq are only used when PGO and
- // LoopUnswitchWithBlockFrequency are enabled.
- BlockFrequencyInfo BFI;
- BlockFrequency ColdEntryFreq;
-
bool OptimizeForSize;
bool redoLoop;
@@ -199,12 +181,14 @@ namespace {
// NewBlocks contained cloned copy of basic blocks from LoopBlocks.
std::vector<BasicBlock*> NewBlocks;
+ bool hasBranchDivergence;
+
public:
static char ID; // Pass ID, replacement for typeid
- explicit LoopUnswitch(bool Os = false) :
+ explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) :
LoopPass(ID), OptimizeForSize(Os), redoLoop(false),
currentLoop(nullptr), DT(nullptr), loopHeader(nullptr),
- loopPreheader(nullptr) {
+ loopPreheader(nullptr), hasBranchDivergence(hasBranchDivergence) {
initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());
}
@@ -217,6 +201,8 @@ namespace {
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<TargetTransformInfoWrapperPass>();
+ if (hasBranchDivergence)
+ AU.addRequired<DivergenceAnalysis>();
getLoopAnalysisUsage(AU);
}
@@ -255,6 +241,11 @@ namespace {
TerminatorInst *TI);
void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L);
+
+ /// Given that the Invariant is not equal to Val. Simplify instructions
+ /// Given that the Invariant is not equal to Val, simplify instructions
+ /// in the loop.
+ Constant *Val);
};
}
@@ -381,16 +372,35 @@ INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops",
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis)
INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops",
false, false)
-Pass *llvm::createLoopUnswitchPass(bool Os) {
- return new LoopUnswitch(Os);
+Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) {
+ return new LoopUnswitch(Os, hasBranchDivergence);
}
+/// Operator chain lattice.
+enum OperatorChain {
+ OC_OpChainNone, ///< There is no operator.
+ OC_OpChainOr, ///< There are only ORs.
+ OC_OpChainAnd, ///< There are only ANDs.
+ OC_OpChainMixed ///< There are ANDs and ORs.
+};
+
/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
/// an invariant piece, return the invariant. Otherwise, return null.
+///
+/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a
+/// mixed operator chain, as we cannot reliably find a value which will
+/// simplify the operator chain. If the chain is AND-only or OR-only, we can
+/// use 0 or ~0 to simplify the chain.
+///
+/// NOTE: In the case of a partial LIV and a mixed operator chain, we may be
+/// able to simplify the condition itself to a loop-variant condition, but at
+/// the cost of creating an entirely new loop.
static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
+ OperatorChain &ParentChain,
DenseMap<Value *, Value *> &Cache) {
auto CacheIt = Cache.find(Cond);
if (CacheIt != Cache.end())
@@ -414,21 +424,53 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
return Cond;
}
+ // Walk up the operator chain to find partial invariant conditions.
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))
if (BO->getOpcode() == Instruction::And ||
BO->getOpcode() == Instruction::Or) {
- // If either the left or right side is invariant, we can unswitch on this,
- // which will cause the branch to go away in one loop and the condition to
- // simplify in the other one.
- if (Value *LHS =
- FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) {
- Cache[Cond] = LHS;
- return LHS;
+ // Given the previous operator, compute the current operator chain status.
+ OperatorChain NewChain;
+ switch (ParentChain) {
+ case OC_OpChainNone:
+ NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
+ OC_OpChainOr;
+ break;
+ case OC_OpChainOr:
+ NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr :
+ OC_OpChainMixed;
+ break;
+ case OC_OpChainAnd:
+ NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
+ OC_OpChainMixed;
+ break;
+ case OC_OpChainMixed:
+ NewChain = OC_OpChainMixed;
+ break;
}
- if (Value *RHS =
- FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) {
- Cache[Cond] = RHS;
- return RHS;
+
+ // If we reach a Mixed state, we do not want to keep walking up as we
+ // cannot reliably find a value that will simplify the chain. With this
+ // check, we will return null at the first sight of a mixed chain and the
+ // caller will either backtrack to find a partial LIV in the other operand
+ // or return null.
+ if (NewChain != OC_OpChainMixed) {
+ // Update the current operator chain type before we search up the chain.
+ ParentChain = NewChain;
+ // If either the left or right side is invariant, we can unswitch on this,
+ // which will cause the branch to go away in one loop and the condition to
+ // simplify in the other one.
+ if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed,
+ ParentChain, Cache)) {
+ Cache[Cond] = LHS;
+ return LHS;
+ }
+ // We did not manage to find a partial LIV in operand(0). Backtrack and try
+ // operand(1).
+ ParentChain = NewChain;
+ if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed,
+ ParentChain, Cache)) {
+ Cache[Cond] = RHS;
+ return RHS;
+ }
}
}
@@ -436,9 +478,21 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
return nullptr;
}
-static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) {
+/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
+/// an invariant piece, return the invariant along with the operator chain type.
+/// Otherwise, return null.
+static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond,
+ Loop *L,
+ bool &Changed) {
DenseMap<Value *, Value *> Cache;
- return FindLIVLoopCondition(Cond, L, Changed, Cache);
+ OperatorChain OpChain = OC_OpChainNone;
+ Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache);
+
+ // If we do find a LIV, it cannot have been obtained by walking up a mixed
+ // operator chain.
+ assert((!FCond || OpChain != OC_OpChainMixed) &&
+ "Do not expect a partial LIV with mixed operator chain");
+ return {FCond, OpChain};
}
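A minimal sketch of the chain-lattice update the recursion above performs; it reuses the enum names from the patch but is otherwise standalone:

  enum OperatorChain {
    OC_OpChainNone,
    OC_OpChainOr,
    OC_OpChainAnd,
    OC_OpChainMixed
  };

  OperatorChain updateChain(OperatorChain Parent, bool IsAnd) {
    switch (Parent) {
    case OC_OpChainNone:  return IsAnd ? OC_OpChainAnd : OC_OpChainOr;
    case OC_OpChainOr:    return IsAnd ? OC_OpChainMixed : OC_OpChainOr;
    case OC_OpChainAnd:   return IsAnd ? OC_OpChainAnd : OC_OpChainMixed;
    case OC_OpChainMixed: return OC_OpChainMixed;
    }
    return OC_OpChainMixed; // unreachable; silences -Wreturn-type
  }

  // For cond = (inv & x) & y the walk sees And, And and stays OC_OpChainAnd,
  // so inv is usable as a partial LIV. For (inv & x) | y it reaches
  // OC_OpChainMixed and the search gives up on that path.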
bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
@@ -457,19 +511,6 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
if (SanitizeMemory)
computeLoopSafetyInfo(&SafetyInfo, L);
- EnabledPGO = F->getEntryCount().hasValue();
-
- if (LoopUnswitchWithBlockFrequency && EnabledPGO) {
- BranchProbabilityInfo BPI(*F, *LI);
- BFI.calculate(*L->getHeader()->getParent(), BPI, *LI);
-
- // Use BranchProbability to compute a minimum frequency based on
- // function entry baseline frequency. Loops with headers below this
- // frequency are considered as cold.
- const BranchProbability ColdProb(ColdnessThreshold, 100);
- ColdEntryFreq = BlockFrequency(BFI.getEntryFreq()) * ColdProb;
- }
-
bool Changed = false;
do {
assert(currentLoop->isLCSSAForm(*DT));
@@ -581,19 +622,9 @@ bool LoopUnswitch::processCurrentLoop() {
loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize))
return false;
- if (LoopUnswitchWithBlockFrequency && EnabledPGO) {
- // Compute the weighted frequency of the hottest block in the
- // loop (loopHeader in this case since inner loops should be
- // processed before outer loop). If it is less than ColdFrequency,
- // we should not unswitch.
- BlockFrequency LoopEntryFreq = BFI.getBlockFreq(loopHeader);
- if (LoopEntryFreq < ColdEntryFreq)
- return false;
- }
-
for (IntrinsicInst *Guard : Guards) {
Value *LoopCond =
- FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed);
+ FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first;
if (LoopCond &&
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
// NB! Unswitching (if successful) could have erased some of the
@@ -634,7 +665,7 @@ bool LoopUnswitch::processCurrentLoop() {
// See if this, or some part of it, is loop invariant. If so, we can
// unswitch on it if we desire.
Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
- currentLoop, Changed);
+ currentLoop, Changed).first;
if (LoopCond &&
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
++NumBranches;
@@ -642,24 +673,48 @@ bool LoopUnswitch::processCurrentLoop() {
}
}
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
- Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
- currentLoop, Changed);
+ Value *SC = SI->getCondition();
+ Value *LoopCond;
+ OperatorChain OpChain;
+ std::tie(LoopCond, OpChain) =
+ FindLIVLoopCondition(SC, currentLoop, Changed);
+
unsigned NumCases = SI->getNumCases();
if (LoopCond && NumCases) {
// Find a value to unswitch on:
// FIXME: this should choose the most expensive case!
// FIXME: scan for a case with a non-critical edge?
Constant *UnswitchVal = nullptr;
-
- // Do not process same value again and again.
- // At this point we have some cases already unswitched and
- // some not yet unswitched. Let's find the first not yet unswitched one.
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
- i != e; ++i) {
- Constant *UnswitchValCandidate = i.getCaseValue();
- if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
- UnswitchVal = UnswitchValCandidate;
- break;
+ // Find a case value such that at least one case value is unswitched
+ // out.
+ if (OpChain == OC_OpChainAnd) {
+ // If the chain only has ANDs and the switch has a case value of 0,
+ // dropping a 0 into the chain will unswitch out the 0 case value.
+ auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType()));
+ if (BranchesInfo.isUnswitched(SI, AllZero))
+ continue;
+ // We are unswitching 0 out.
+ UnswitchVal = AllZero;
+ } else if (OpChain == OC_OpChainOr) {
+ // If the chain only has ORs and the switch has a case value of ~0,
+ // dropping a ~0 into the chain will unswitch out the ~0 case value.
+ auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType()));
+ if (BranchesInfo.isUnswitched(SI, AllOne))
+ continue;
+ // We are unswitching ~0 out.
+ UnswitchVal = AllOne;
+ } else {
+ assert(OpChain == OC_OpChainNone &&
+ "Expect to unswitch on trivial chain");
+ // Do not process same value again and again.
+ // At this point we have some cases already unswitched and
+ // some not yet unswitched. Let's find the first not yet unswitched one.
+ for (auto Case : SI->cases()) {
+ Constant *UnswitchValCandidate = Case.getCaseValue();
+ if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
+ UnswitchVal = UnswitchValCandidate;
+ break;
+ }
}
}
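A small sketch of why 0 and ~0 are the right case values for pure chains, with plain integers standing in for the switch condition:

  #include <cstdint>

  // If cond = inv & x & y and the loop copy assumes inv == 0, cond is 0
  // regardless of x and y, so the "case 0" successor is the one taken and
  // that case is effectively unswitched out. Dually, inv == ~0 with an
  // OR-only chain forces cond == ~0.
  uint32_t chainValue(bool AndChain, uint32_t Inv, uint32_t X, uint32_t Y) {
    return AndChain ? (Inv & X & Y) : (Inv | X | Y);
  }

  // chainValue(true, 0, x, y)    == 0    for any x, y
  // chainValue(false, ~0u, x, y) == ~0u  for any x, y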
@@ -668,6 +723,11 @@ bool LoopUnswitch::processCurrentLoop() {
if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {
++NumSwitches;
+ // In the case of a full LIV, UnswitchVal is the value we unswitched out.
+ // In the case of a partial LIV, we only unswitch when it is an AND-chain
+ // or OR-chain. In both cases the switch input value simplifies to
+ // UnswitchVal.
+ BranchesInfo.setUnswitched(SI, UnswitchVal);
return true;
}
}
@@ -678,7 +738,7 @@ bool LoopUnswitch::processCurrentLoop() {
BBI != E; ++BBI)
if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
- currentLoop, Changed);
+ currentLoop, Changed).first;
if (LoopCond && UnswitchIfProfitable(LoopCond,
ConstantInt::getTrue(Context))) {
++NumSelects;
@@ -753,6 +813,15 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val,
<< ". Cost too high.\n");
return false;
}
+ if (hasBranchDivergence &&
+ getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) {
+ DEBUG(dbgs() << "NOT unswitching loop %"
+ << currentLoop->getHeader()->getName()
+ << " at non-trivial condition '" << *Val
+ << "' == " << *LoopCond << "\n"
+ << ". Condition is divergent.\n");
+ return false;
+ }
UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI);
return true;
@@ -899,7 +968,6 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
if (I.mayHaveSideEffects())
return false;
- // FIXME: add check for constant foldable switch instructions.
if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) {
if (BI->isUnconditional()) {
CurrentBB = BI->getSuccessor(0);
@@ -911,7 +979,16 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
// Found a trivial condition candidate: non-foldable conditional branch.
break;
}
+ } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
+ // At this point, any constant-foldable instructions should probably have
+ // been folded already.
+ ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition());
+ if (!Cond)
+ break;
+ // Find the target block we are definitely going to.
+ CurrentBB = SI->findCaseValue(Cond)->getCaseSuccessor();
} else {
+ // We do not understand these terminator instructions.
break;
}
@@ -929,7 +1006,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
return false;
Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
- currentLoop, Changed);
+ currentLoop, Changed).first;
// Unswitch only if the trivial condition itself is an LIV (not
// partial LIV which could occur in and/or)
@@ -960,7 +1037,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
// If this isn't switching on an invariant condition, we can't unswitch it.
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
- currentLoop, Changed);
+ currentLoop, Changed).first;
// Unswitch only if the trivial condition itself is an LIV (not
// partial LIV which could occur in and/or)
@@ -973,13 +1050,12 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
// this.
// Note that we can't trivially unswitch on the default case or
// on already unswitched cases.
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
- i != e; ++i) {
+ for (auto Case : SI->cases()) {
BasicBlock *LoopExitCandidate;
- if ((LoopExitCandidate = isTrivialLoopExitBlock(currentLoop,
- i.getCaseSuccessor()))) {
+ if ((LoopExitCandidate =
+ isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) {
// Okay, we found a trivial case, remember the value that is trivial.
- ConstantInt *CaseVal = i.getCaseValue();
+ ConstantInt *CaseVal = Case.getCaseValue();
// Check that it was not unswitched before, since already unswitched
// trivial vals look trivial too.
@@ -998,6 +1074,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
nullptr);
+
+ // We are only unswitching full LIV.
+ BranchesInfo.setUnswitched(SI, CondVal);
++NumSwitches;
return true;
}
@@ -1253,18 +1332,38 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
if (!UI || !L->contains(UI))
continue;
- Worklist.push_back(UI);
+ // At this point, we know LIC is definitely not Val. Try to use some simple
+ // logic to simplify the user w.r.t. the context.
+ if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) {
+ if (LI->replacementPreservesLCSSAForm(UI, Replacement)) {
+ // This in-loop instruction has been simplified w.r.t. its context,
+ // i.e. LIC != Val, make sure we propagate its replacement value to
+ // all its users.
+ //
+ // We cannot delete UI, the LIC user, yet, because that would invalidate
+ // the LIC->users() iterator. However, we can make this instruction dead
+ // by replacing all its users, and push it onto the worklist so that it
+ // can be properly deleted and its operands simplified.
+ UI->replaceAllUsesWith(Replacement);
+ }
+ }
- // TODO: We could do other simplifications, for example, turning
- // 'icmp eq LIC, Val' -> false.
+ // This is a LIC user; push it onto the worklist so that SimplifyCode can
+ // attempt to simplify it.
+ Worklist.push_back(UI);
// If we know that LIC is not Val, use this info to simplify code.
SwitchInst *SI = dyn_cast<SwitchInst>(UI);
if (!SI || !isa<ConstantInt>(Val)) continue;
- SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val));
+ // NOTE: if a case value for the switch is unswitched out, we record it
+ // after the unswitch finishes. We can not record it here as the switch
+ // is not a direct user of the partial LIV.
+ SwitchInst::CaseHandle DeadCase =
+ *SI->findCaseValue(cast<ConstantInt>(Val));
// Default case is live for multiple values.
- if (DeadCase == SI->case_default()) continue;
+ if (DeadCase == *SI->case_default())
+ continue;
// Found a dead case value. Don't remove PHI nodes in the
// successor if they become single-entry, those PHI nodes may
@@ -1274,8 +1373,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
BasicBlock *SISucc = DeadCase.getCaseSuccessor();
BasicBlock *Latch = L->getLoopLatch();
- BranchesInfo.setUnswitched(SI, Val);
-
if (!SI->findCaseDest(SISucc)) continue; // Edge is critical.
// If the DeadCase successor dominates the loop latch, then the
// transformation isn't safe since it will delete the sole predecessor edge
@@ -1397,3 +1494,27 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
}
}
}
+
+/// Simple simplifications we can do given the information that Invariant is
+/// definitely not equal to Val.
+Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst,
+ Value *Invariant,
+ Constant *Val) {
+ // icmp eq cond, val -> false
+ ICmpInst *CI = dyn_cast<ICmpInst>(Inst);
+ if (CI && CI->isEquality()) {
+ Value *Op0 = CI->getOperand(0);
+ Value *Op1 = CI->getOperand(1);
+ if ((Op0 == Invariant && Op1 == Val) || (Op0 == Val && Op1 == Invariant)) {
+ LLVMContext &Ctx = Inst->getContext();
+ if (CI->getPredicate() == CmpInst::ICMP_EQ)
+ return ConstantInt::getFalse(Ctx);
+ else
+ return ConstantInt::getTrue(Ctx);
+ }
+ }
+
+ // FIXME: there may be other opportunities, e.g. comparison with floating
+ // point, or Invariant - Val != 0, etc.
+ return nullptr;
+}
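A standalone sketch of the single fold implemented above, with opaque pointers standing in for Value*; everything beyond what the diff shows is illustrative:

  #include <optional>

  // Given that Invariant is known not to equal Val in this loop copy:
  //   icmp eq Invariant, Val  ->  false
  //   icmp ne Invariant, Val  ->  true
  // Operands are matched by identity and may appear in either order.
  std::optional<bool> foldKnownNotEqual(bool IsEq, const void *Op0,
                                        const void *Op1, const void *Invariant,
                                        const void *Val) {
    bool Match = (Op0 == Invariant && Op1 == Val) ||
                 (Op0 == Val && Op1 == Invariant);
    if (!Match)
      return std::nullopt; // no fold applies
    return !IsEq;          // eq folds to false, ne folds to true
  }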
diff --git a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index 52975ef351531..a143b9a3c645f 100644
--- a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -67,11 +67,11 @@ static bool handleSwitchExpect(SwitchInst &SI) {
if (!ExpectedValue)
return false;
- SwitchInst::CaseIt Case = SI.findCaseValue(ExpectedValue);
+ SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
unsigned n = SI.getNumCases(); // +1 for default case.
SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight);
- if (Case == SI.case_default())
+ if (Case == *SI.case_default())
Weights[0] = LikelyBranchWeight;
else
Weights[Case.getCaseIndex() + 1] = LikelyBranchWeight;
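A sketch of the SwitchInst API migration these hunks apply, using only spellings that already appear in this patch (findCaseValue now returns an iterator that is dereferenced to a CaseHandle, and cases() is range-iterable); the helper name is hypothetical:

  #include "llvm/IR/Instructions.h"
  using namespace llvm;

  static void walkCases(SwitchInst &SI, ConstantInt *Expected) {
    // Old style (removed above):
    //   for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); ...)
    //     Constant *V = i.getCaseValue();
    // New style:
    for (auto Case : SI.cases()) {
      ConstantInt *CaseVal = Case.getCaseValue();
      BasicBlock *Succ = Case.getCaseSuccessor();
      (void)CaseVal;
      (void)Succ;
    }
    SwitchInst::CaseHandle Handle = *SI.findCaseValue(Expected);
    if (Handle == *SI.case_default()) {
      // Expected hits the default; otherwise Handle.getCaseIndex() gives the
      // 0-based explicit-case index (weight slot 0 is the default case).
    }
  }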
diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 1b590140f70aa..a3f3f25c1e0f6 100644
--- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -12,20 +12,49 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
+#include <cassert>
+#include <cstdint>
+
using namespace llvm;
#define DEBUG_TYPE "memcpyopt"
@@ -119,6 +148,7 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,
return true;
}
+namespace {
/// Represents a range of memset'd bytes with the ByteVal value.
/// This allows us to analyze stores like:
@@ -130,7 +160,6 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,
/// the first store, we make a range [1, 2). The second store extends the range
/// to [0, 2). The third makes a new range [2, 3). The fourth store joins the
/// two ranges into [0, 3) which is memset'able.
-namespace {
struct MemsetRange {
// Start/End - A semi range that describes the span that this range covers.
// The range is closed at the start and open at the end: [Start, End).
@@ -148,7 +177,8 @@ struct MemsetRange {
bool isProfitableToUseMemset(const DataLayout &DL) const;
};
-} // end anon namespace
+
+} // end anonymous namespace
bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const {
// If we found more than 4 stores to merge or 16 bytes, use memset.
@@ -192,13 +222,14 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const {
return TheStores.size() > NumPointerStores+NumByteStores;
}
-
namespace {
+
class MemsetRanges {
/// A sorted list of the memset ranges.
SmallVector<MemsetRange, 8> Ranges;
typedef SmallVectorImpl<MemsetRange>::iterator range_iterator;
const DataLayout &DL;
+
public:
MemsetRanges(const DataLayout &DL) : DL(DL) {}
@@ -231,8 +262,7 @@ public:
};
-} // end anon namespace
-
+} // end anonymous namespace
/// Add a new store to the MemsetRanges data structure. This adds a
/// new range for the specified store at the specified offset, merging into
@@ -299,48 +329,36 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr,
//===----------------------------------------------------------------------===//
namespace {
- class MemCpyOptLegacyPass : public FunctionPass {
- MemCpyOptPass Impl;
- public:
- static char ID; // Pass identification, replacement for typeid
- MemCpyOptLegacyPass() : FunctionPass(ID) {
- initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override;
-
- private:
- // This transformation requires dominator postdominator info
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<MemoryDependenceWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addPreserved<MemoryDependenceWrapperPass>();
- }
+class MemCpyOptLegacyPass : public FunctionPass {
+ MemCpyOptPass Impl;
- // Helper functions
- bool processStore(StoreInst *SI, BasicBlock::iterator &BBI);
- bool processMemSet(MemSetInst *SI, BasicBlock::iterator &BBI);
- bool processMemCpy(MemCpyInst *M);
- bool processMemMove(MemMoveInst *M);
- bool performCallSlotOptzn(Instruction *cpy, Value *cpyDst, Value *cpySrc,
- uint64_t cpyLen, unsigned cpyAlign, CallInst *C);
- bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep);
- bool processMemSetMemCpyDependence(MemCpyInst *M, MemSetInst *MDep);
- bool performMemCpyToMemSetOptzn(MemCpyInst *M, MemSetInst *MDep);
- bool processByValArgument(CallSite CS, unsigned ArgNo);
- Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr,
- Value *ByteVal);
-
- bool iterateOnFunction(Function &F);
- };
+public:
+ static char ID; // Pass identification, replacement for typeid
- char MemCpyOptLegacyPass::ID = 0;
-}
+ MemCpyOptLegacyPass() : FunctionPass(ID) {
+ initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override;
+
+private:
+ // This transformation requires dominator postdominator info
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<MemoryDependenceWrapperPass>();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addPreserved<GlobalsAAWrapperPass>();
+ AU.addPreserved<MemoryDependenceWrapperPass>();
+ }
+};
+
+char MemCpyOptLegacyPass::ID = 0;
+
+} // end anonymous namespace
/// The public interface to this file...
FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); }
@@ -523,14 +541,15 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P,
if (Args.erase(C))
NeedLift = true;
else if (MayAlias) {
- NeedLift = any_of(MemLocs, [C, &AA](const MemoryLocation &ML) {
+ NeedLift = llvm::any_of(MemLocs, [C, &AA](const MemoryLocation &ML) {
return AA.getModRefInfo(C, ML);
});
if (!NeedLift)
- NeedLift = any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) {
- return AA.getModRefInfo(C, CS);
- });
+ NeedLift =
+ llvm::any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) {
+ return AA.getModRefInfo(C, CS);
+ });
}
if (!NeedLift)
@@ -567,7 +586,7 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P,
}
// We made it, we need to lift
- for (auto *I : reverse(ToLift)) {
+ for (auto *I : llvm::reverse(ToLift)) {
DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n");
I->moveBefore(P);
}
@@ -761,7 +780,6 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
return false;
}
-
/// Takes a memcpy and a call that it depends on,
/// and checks for the possibility of a call slot optimization by having
/// the call write its result directly into the destination of the memcpy.
@@ -914,6 +932,17 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
if (MR != MRI_NoModRef)
return false;
+ // We can't create address space casts here because we don't know if they're
+ // safe for the target.
+ if (cpySrc->getType()->getPointerAddressSpace() !=
+ cpyDest->getType()->getPointerAddressSpace())
+ return false;
+ for (unsigned i = 0; i < CS.arg_size(); ++i)
+ if (CS.getArgument(i)->stripPointerCasts() == cpySrc &&
+ cpySrc->getType()->getPointerAddressSpace() !=
+ CS.getArgument(i)->getType()->getPointerAddressSpace())
+ return false;
+
// All the checks have passed, so do the transformation.
bool changedArgument = false;
for (unsigned i = 0; i < CS.arg_size(); ++i)
@@ -1240,7 +1269,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) {
bool MemCpyOptPass::processMemMove(MemMoveInst *M) {
AliasAnalysis &AA = LookupAliasAnalysis();
- if (!TLI->has(LibFunc::memmove))
+ if (!TLI->has(LibFunc_memmove))
return false;
// See if the pointers alias.
@@ -1306,6 +1335,11 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) {
CS.getInstruction(), &AC, &DT) < ByValAlign)
return false;
+ // The address space of the memcpy source must match the byval argument
+ if (MDep->getSource()->getType()->getPointerAddressSpace() !=
+ ByValArg->getType()->getPointerAddressSpace())
+ return false;
+
// Verify that the copied-from memory doesn't change in between the memcpy and
// the byval call.
// memcpy(a <- b)
@@ -1375,7 +1409,6 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) {
}
PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) {
-
auto &MD = AM.getResult<MemoryDependenceAnalysis>(F);
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
@@ -1393,7 +1426,9 @@ PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) {
LookupAssumptionCache, LookupDomTree);
if (!MadeChange)
return PreservedAnalyses::all();
+
PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
PA.preserve<MemoryDependenceAnalysis>();
return PA;
@@ -1414,10 +1449,10 @@ bool MemCpyOptPass::runImpl(
// If we don't have at least memset and memcpy, there is little point of doing
// anything here. These are required by a freestanding implementation, so if
// even they are disabled, there is no point in trying hard.
- if (!TLI->has(LibFunc::memset) || !TLI->has(LibFunc::memcpy))
+ if (!TLI->has(LibFunc_memset) || !TLI->has(LibFunc_memcpy))
return false;
- while (1) {
+ while (true) {
if (!iterateOnFunction(F))
break;
MadeChange = true;
diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
index 6a64c6b3619c4..acd3ef6791bed 100644
--- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
+++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
@@ -19,6 +19,8 @@
// thinks it safe to do so. This optimization helps with eg. hiding load
// latencies, triggering if-conversion, and reducing static code size.
//
+// NOTE: This code no longer performs load hoisting, it is subsumed by GVNHoist.
+//
//===----------------------------------------------------------------------===//
//
//
@@ -87,7 +89,6 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-#include "llvm/Transforms/Utils/SSAUpdater.h"
using namespace llvm;
@@ -118,16 +119,6 @@ private:
void removeInstruction(Instruction *Inst);
BasicBlock *getDiamondTail(BasicBlock *BB);
bool isDiamondHead(BasicBlock *BB);
- // Routines for hoisting loads
- bool isLoadHoistBarrierInRange(const Instruction &Start,
- const Instruction &End, LoadInst *LI,
- bool SafeToLoadUnconditionally);
- LoadInst *canHoistFromBlock(BasicBlock *BB, LoadInst *LI);
- void hoistInstruction(BasicBlock *BB, Instruction *HoistCand,
- Instruction *ElseInst);
- bool isSafeToHoist(Instruction *I) const;
- bool hoistLoad(BasicBlock *BB, LoadInst *HoistCand, LoadInst *ElseInst);
- bool mergeLoads(BasicBlock *BB);
// Routines for sinking stores
StoreInst *canSinkFromBlock(BasicBlock *BB, StoreInst *SI);
PHINode *getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1);
@@ -188,169 +179,6 @@ bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) {
return true;
}
-///
-/// \brief True when instruction is a hoist barrier for a load
-///
-/// Whenever an instruction could possibly modify the value
-/// being loaded or protect against the load from happening
-/// it is considered a hoist barrier.
-///
-bool MergedLoadStoreMotion::isLoadHoistBarrierInRange(
- const Instruction &Start, const Instruction &End, LoadInst *LI,
- bool SafeToLoadUnconditionally) {
- if (!SafeToLoadUnconditionally)
- for (const Instruction &Inst :
- make_range(Start.getIterator(), End.getIterator()))
- if (!isGuaranteedToTransferExecutionToSuccessor(&Inst))
- return true;
- MemoryLocation Loc = MemoryLocation::get(LI);
- return AA->canInstructionRangeModRef(Start, End, Loc, MRI_Mod);
-}
-
-///
-/// \brief Decide if a load can be hoisted
-///
-/// When there is a load in \p BB to the same address as \p LI
-/// and it can be hoisted from \p BB, return that load.
-/// Otherwise return Null.
-///
-LoadInst *MergedLoadStoreMotion::canHoistFromBlock(BasicBlock *BB1,
- LoadInst *Load0) {
- BasicBlock *BB0 = Load0->getParent();
- BasicBlock *Head = BB0->getSinglePredecessor();
- bool SafeToLoadUnconditionally = isSafeToLoadUnconditionally(
- Load0->getPointerOperand(), Load0->getAlignment(),
- Load0->getModule()->getDataLayout(),
- /*ScanFrom=*/Head->getTerminator());
- for (BasicBlock::iterator BBI = BB1->begin(), BBE = BB1->end(); BBI != BBE;
- ++BBI) {
- Instruction *Inst = &*BBI;
-
- // Only merge and hoist loads when their result in used only in BB
- auto *Load1 = dyn_cast<LoadInst>(Inst);
- if (!Load1 || Inst->isUsedOutsideOfBlock(BB1))
- continue;
-
- MemoryLocation Loc0 = MemoryLocation::get(Load0);
- MemoryLocation Loc1 = MemoryLocation::get(Load1);
- if (Load0->isSameOperationAs(Load1) && AA->isMustAlias(Loc0, Loc1) &&
- !isLoadHoistBarrierInRange(BB1->front(), *Load1, Load1,
- SafeToLoadUnconditionally) &&
- !isLoadHoistBarrierInRange(BB0->front(), *Load0, Load0,
- SafeToLoadUnconditionally)) {
- return Load1;
- }
- }
- return nullptr;
-}
-
-///
-/// \brief Merge two equivalent instructions \p HoistCand and \p ElseInst into
-/// \p BB
-///
-/// BB is the head of a diamond
-///
-void MergedLoadStoreMotion::hoistInstruction(BasicBlock *BB,
- Instruction *HoistCand,
- Instruction *ElseInst) {
- DEBUG(dbgs() << " Hoist Instruction into BB \n"; BB->dump();
- dbgs() << "Instruction Left\n"; HoistCand->dump(); dbgs() << "\n";
- dbgs() << "Instruction Right\n"; ElseInst->dump(); dbgs() << "\n");
- // Hoist the instruction.
- assert(HoistCand->getParent() != BB);
-
- // Intersect optional metadata.
- HoistCand->andIRFlags(ElseInst);
- HoistCand->dropUnknownNonDebugMetadata();
-
- // Prepend point for instruction insert
- Instruction *HoistPt = BB->getTerminator();
-
- // Merged instruction
- Instruction *HoistedInst = HoistCand->clone();
-
- // Hoist instruction.
- HoistedInst->insertBefore(HoistPt);
-
- HoistCand->replaceAllUsesWith(HoistedInst);
- removeInstruction(HoistCand);
- // Replace the else block instruction.
- ElseInst->replaceAllUsesWith(HoistedInst);
- removeInstruction(ElseInst);
-}
-
-///
-/// \brief Return true if no operand of \p I is defined in I's parent block
-///
-bool MergedLoadStoreMotion::isSafeToHoist(Instruction *I) const {
- BasicBlock *Parent = I->getParent();
- for (Use &U : I->operands())
- if (auto *Instr = dyn_cast<Instruction>(&U))
- if (Instr->getParent() == Parent)
- return false;
- return true;
-}
-
-///
-/// \brief Merge two equivalent loads and GEPs and hoist into diamond head
-///
-bool MergedLoadStoreMotion::hoistLoad(BasicBlock *BB, LoadInst *L0,
- LoadInst *L1) {
- // Only one definition?
- auto *A0 = dyn_cast<Instruction>(L0->getPointerOperand());
- auto *A1 = dyn_cast<Instruction>(L1->getPointerOperand());
- if (A0 && A1 && A0->isIdenticalTo(A1) && isSafeToHoist(A0) &&
- A0->hasOneUse() && (A0->getParent() == L0->getParent()) &&
- A1->hasOneUse() && (A1->getParent() == L1->getParent()) &&
- isa<GetElementPtrInst>(A0)) {
- DEBUG(dbgs() << "Hoist Instruction into BB \n"; BB->dump();
- dbgs() << "Instruction Left\n"; L0->dump(); dbgs() << "\n";
- dbgs() << "Instruction Right\n"; L1->dump(); dbgs() << "\n");
- hoistInstruction(BB, A0, A1);
- hoistInstruction(BB, L0, L1);
- return true;
- }
- return false;
-}
-
-///
-/// \brief Try to hoist two loads to same address into diamond header
-///
-/// Starting from a diamond head block, iterate over the instructions in one
-/// successor block and try to match a load in the second successor.
-///
-bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) {
- bool MergedLoads = false;
- assert(isDiamondHead(BB));
- BranchInst *BI = cast<BranchInst>(BB->getTerminator());
- BasicBlock *Succ0 = BI->getSuccessor(0);
- BasicBlock *Succ1 = BI->getSuccessor(1);
- // #Instructions in Succ1 for Compile Time Control
- int Size1 = Succ1->size();
- int NLoads = 0;
- for (BasicBlock::iterator BBI = Succ0->begin(), BBE = Succ0->end();
- BBI != BBE;) {
- Instruction *I = &*BBI;
- ++BBI;
-
- // Don't move non-simple (atomic, volatile) loads.
- auto *L0 = dyn_cast<LoadInst>(I);
- if (!L0 || !L0->isSimple() || L0->isUsedOutsideOfBlock(Succ0))
- continue;
-
- ++NLoads;
- if (NLoads * Size1 >= MagicCompileTimeControl)
- break;
- if (LoadInst *L1 = canHoistFromBlock(Succ1, L0)) {
- bool Res = hoistLoad(BB, L0, L1);
- MergedLoads |= Res;
- // Don't attempt to hoist above loads that had not been hoisted.
- if (!Res)
- break;
- }
- }
- return MergedLoads;
-}
///
/// \brief True when instruction is a sink barrier for a store
@@ -534,7 +362,6 @@ bool MergedLoadStoreMotion::run(Function &F, MemoryDependenceResults *MD,
// Hoist equivalent loads and sink stores
// outside diamonds when possible
if (isDiamondHead(BB)) {
- Changed |= mergeLoads(BB);
Changed |= mergeStores(getDiamondTail(BB));
}
}
@@ -596,8 +423,8 @@ MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) {
if (!Impl.run(F, MD, AA))
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
PA.preserve<MemoryDependenceAnalysis>();
return PA;
diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp
index 0a3bf7b4c31b7..c5bf2f28d1852 100644
--- a/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -156,20 +156,12 @@ PreservedAnalyses NaryReassociatePass::run(Function &F,
auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
- bool Changed = runImpl(F, AC, DT, SE, TLI, TTI);
-
- // FIXME: We need to invalidate this to avoid PR28400. Is there a better
- // solution?
- AM.invalidate<ScalarEvolutionAnalysis>(F);
-
- if (!Changed)
+ if (!runImpl(F, AC, DT, SE, TLI, TTI))
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<ScalarEvolutionAnalysis>();
- PA.preserve<TargetLibraryAnalysis>();
return PA;
}
diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp
index 57e6e3ddad949..3d8ce888867ea 100644
--- a/lib/Transforms/Scalar/NewGVN.cpp
+++ b/lib/Transforms/Scalar/NewGVN.cpp
@@ -17,6 +17,27 @@
/// "A Sparse Algorithm for Predicated Global Value Numbering" from
/// Karthik Gargi.
///
+/// A brief overview of the algorithm: The algorithm is essentially the same as
+/// the standard RPO value numbering algorithm (a good reference is the paper
+/// "SCC based value numbering" by L. Taylor Simpson) with one major difference:
+/// The RPO algorithm proceeds, on every iteration, to process every reachable
+/// block and every instruction in that block. This is because the standard RPO
+/// algorithm does not track what things have the same value number, it only
+/// tracks what the value number of a given operation is (the mapping is
+/// operation -> value number). Thus, when a value number of an operation
+/// changes, it must reprocess everything to ensure all uses of a value number
+/// get updated properly. In contrast, the sparse algorithm we use *also*
+/// tracks what operations have a given value number (i.e. it also tracks the
+/// reverse mapping from value number -> operations with that value number), so
+/// that it only needs to reprocess the instructions that are affected when
+/// something's value number changes. The rest of the algorithm is devoted to
+/// performing symbolic evaluation, forward propagation, and simplification of
+/// operations based on the value numbers deduced so far.
+///
+/// We also do not perform elimination by using any published algorithm. All
+/// published algorithms are O(Instructions). Instead, we use a technique that
+/// is O(number of operations with the same value number), enabling us to skip
+/// trying to eliminate things that have unique value numbers.
//===----------------------------------------------------------------------===//
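A compact sketch (plain containers, illustrative names) of the extra bookkeeping the overview above describes: keeping the reverse map from value number to members so that only affected operations are reprocessed:

  #include <unordered_map>
  #include <unordered_set>

  struct Instruction; // opaque stand-in

  struct SparseVNState {
    std::unordered_map<const Instruction *, unsigned> ValueNumber; // op -> VN
    std::unordered_map<unsigned, std::unordered_set<const Instruction *>>
        Members;                                                   // VN -> ops

    // Returns true only when the value number actually changed, i.e. only
    // then do the users of I need to be revisited.
    bool assign(const Instruction *I, unsigned VN) {
      auto It = ValueNumber.find(I);
      if (It != ValueNumber.end() && It->second == VN)
        return false;                 // unchanged: touch nothing
      if (It != ValueNumber.end())
        Members[It->second].erase(I); // leave the old congruence class
      ValueNumber[I] = VN;
      Members[VN].insert(I);
      return true;
    }
  };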
#include "llvm/Transforms/Scalar/NewGVN.h"
@@ -40,13 +61,10 @@
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
-#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryBuiltins.h"
-#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
-#include "llvm/Analysis/PHITransAddr.h"
+#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GlobalVariable.h"
@@ -55,24 +73,25 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/PredIteratorCache.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugCounter.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVNExpression.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Utils/MemorySSA.h"
-#include "llvm/Transforms/Utils/SSAUpdater.h"
+#include "llvm/Transforms/Utils/PredicateInfo.h"
+#include "llvm/Transforms/Utils/VNCoercion.h"
+#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace llvm;
using namespace PatternMatch;
using namespace llvm::GVNExpression;
-
+using namespace llvm::VNCoercion;
#define DEBUG_TYPE "newgvn"
STATISTIC(NumGVNInstrDeleted, "Number of instructions deleted");
@@ -87,6 +106,15 @@ STATISTIC(NumGVNAvoidedSortedLeaderChanges,
"Number of avoided sorted leader changes");
STATISTIC(NumGVNNotMostDominatingLeader,
"Number of times a member dominated it's new classes' leader");
+STATISTIC(NumGVNDeadStores, "Number of redundant/dead stores eliminated");
+DEBUG_COUNTER(VNCounter, "newgvn-vn",
+ "Controls which instructions are value numbered")
+
+// Currently store defining access refinement is too slow due to basicaa being
+// egregiously slow. This flag lets us keep it working while we work on this
+// issue.
+static cl::opt<bool> EnableStoreRefinement("enable-store-refinement",
+ cl::init(false), cl::Hidden);
//===----------------------------------------------------------------------===//
// GVN Pass
@@ -105,6 +133,77 @@ PHIExpression::~PHIExpression() = default;
}
}
+// Tarjan's SCC finding algorithm with Nuutila's improvements
+// SCCIterator is actually fairly complex for the simple thing we want.
+// It also wants to hand us SCC's that are unrelated to the phi node we ask
+// about, and have us process them there or risk redoing work.
+// Graph traits over a filter iterator also doesn't work that well here.
+// This SCC finder is specialized to walk use-def chains, and only follows
+// instructions, not generic values (arguments, etc).
+struct TarjanSCC {
+
+ TarjanSCC() : Components(1) {}
+
+ void Start(const Instruction *Start) {
+ if (Root.lookup(Start) == 0)
+ FindSCC(Start);
+ }
+
+ const SmallPtrSetImpl<const Value *> &getComponentFor(const Value *V) const {
+ unsigned ComponentID = ValueToComponent.lookup(V);
+
+ assert(ComponentID > 0 &&
+ "Asking for a component for a value we never processed");
+ return Components[ComponentID];
+ }
+
+private:
+ void FindSCC(const Instruction *I) {
+ Root[I] = ++DFSNum;
+ // Store the DFS Number we had before it possibly gets incremented.
+ unsigned int OurDFS = DFSNum;
+ for (auto &Op : I->operands()) {
+ if (auto *InstOp = dyn_cast<Instruction>(Op)) {
+ if (Root.lookup(Op) == 0)
+ FindSCC(InstOp);
+ if (!InComponent.count(Op))
+ Root[I] = std::min(Root.lookup(I), Root.lookup(Op));
+ }
+ }
+    // See if we really were the root of a component, by seeing if we still
+    // have our DFSNumber. If we do, we are the root of the component, and we
+    // have completed a component. If not, we belong on the component stack.
+ if (Root.lookup(I) == OurDFS) {
+ unsigned ComponentID = Components.size();
+ Components.resize(Components.size() + 1);
+ auto &Component = Components.back();
+ Component.insert(I);
+ DEBUG(dbgs() << "Component root is " << *I << "\n");
+ InComponent.insert(I);
+ ValueToComponent[I] = ComponentID;
+ // Pop a component off the stack and label it.
+ while (!Stack.empty() && Root.lookup(Stack.back()) >= OurDFS) {
+ auto *Member = Stack.back();
+ DEBUG(dbgs() << "Component member is " << *Member << "\n");
+ Component.insert(Member);
+ InComponent.insert(Member);
+ ValueToComponent[Member] = ComponentID;
+ Stack.pop_back();
+ }
+ } else {
+ // Part of a component, push to stack
+ Stack.push_back(I);
+ }
+ }
+ unsigned int DFSNum = 1;
+ SmallPtrSet<const Value *, 8> InComponent;
+ DenseMap<const Value *, unsigned int> Root;
+ SmallVector<const Value *, 8> Stack;
+  // Store the components as a vector of ptr sets, because we need the topo
+  // order of SCC's, but not individual member order.
+ SmallVector<SmallPtrSet<const Value *, 8>, 8> Components;
+ DenseMap<const Value *, unsigned> ValueToComponent;
+};
// Congruence classes represent the set of expressions/instructions
// that are all the same *during some scope in the function*.
// That is, because of the way we perform equality propagation, and
@@ -115,43 +214,152 @@ PHIExpression::~PHIExpression() = default;
// For any Value in the Member set, it is valid to replace any dominated member
// with that Value.
//
-// Every congruence class has a leader, and the leader is used to
-// symbolize instructions in a canonical way (IE every operand of an
-// instruction that is a member of the same congruence class will
-// always be replaced with leader during symbolization).
-// To simplify symbolization, we keep the leader as a constant if class can be
-// proved to be a constant value.
-// Otherwise, the leader is a randomly chosen member of the value set, it does
-// not matter which one is chosen.
-// Each congruence class also has a defining expression,
-// though the expression may be null. If it exists, it can be used for forward
-// propagation and reassociation of values.
-//
-struct CongruenceClass {
- using MemberSet = SmallPtrSet<Value *, 4>;
+// Every congruence class has a leader, and the leader is used to symbolize
+// instructions in a canonical way (IE every operand of an instruction that is a
+// member of the same congruence class will always be replaced with leader
+// during symbolization). To simplify symbolization, we keep the leader as a
+// constant if the class can be proved to be a constant value. Otherwise, the
+// leader is the member of the value set with the smallest DFS number. Each
+// congruence class also has a defining expression, though the expression may be
+// null. If it exists, it can be used for forward propagation and reassociation
+// of values.
+
+// For memory, we also track a representative MemoryAccess, and a set of memory
+// members for MemoryPhis (which have no real instructions). Note that for
+// memory, it seems tempting to try to split the memory members into a
+// MemoryCongruenceClass or something. Unfortunately, this does not work
+// easily. The value numbering of a given memory expression depends on the
+// leader of the memory congruence class, and the leader of memory congruence
+// class depends on the value numbering of a given memory expression. This
+// leads to wasted propagation, and in some cases, missed optimization. For
+// example: If we had value numbered two stores together before, but now do not,
+// we move them to a new value congruence class. This in turn moves at least
+// one of the MemoryDefs to a new memory congruence class, which in turn affects
+// the value numbering of the stores we just value numbered (because the memory
+// congruence class is part of the value number). So while theoretically
+// possible to split them up, it turns out to be *incredibly* complicated to get
+// it to work right, because of the interdependency. While structurally
+// slightly messier, it is algorithmically much simpler and faster to do what we
+// do here, and track them both at once in the same class.
+// Note: The default iterators for this class iterate over values
+class CongruenceClass {
+public:
+ using MemberType = Value;
+ using MemberSet = SmallPtrSet<MemberType *, 4>;
+ using MemoryMemberType = MemoryPhi;
+ using MemoryMemberSet = SmallPtrSet<const MemoryMemberType *, 2>;
+
+ explicit CongruenceClass(unsigned ID) : ID(ID) {}
+ CongruenceClass(unsigned ID, Value *Leader, const Expression *E)
+ : ID(ID), RepLeader(Leader), DefiningExpr(E) {}
+ unsigned getID() const { return ID; }
+ // True if this class has no members left. This is mainly used for assertion
+ // purposes, and for skipping empty classes.
+ bool isDead() const {
+ // If it's both dead from a value perspective, and dead from a memory
+ // perspective, it's really dead.
+ return empty() && memory_empty();
+ }
+ // Leader functions
+ Value *getLeader() const { return RepLeader; }
+ void setLeader(Value *Leader) { RepLeader = Leader; }
+ const std::pair<Value *, unsigned int> &getNextLeader() const {
+ return NextLeader;
+ }
+ void resetNextLeader() { NextLeader = {nullptr, ~0}; }
+
+ void addPossibleNextLeader(std::pair<Value *, unsigned int> LeaderPair) {
+ if (LeaderPair.second < NextLeader.second)
+ NextLeader = LeaderPair;
+ }
+
+ Value *getStoredValue() const { return RepStoredValue; }
+ void setStoredValue(Value *Leader) { RepStoredValue = Leader; }
+ const MemoryAccess *getMemoryLeader() const { return RepMemoryAccess; }
+ void setMemoryLeader(const MemoryAccess *Leader) { RepMemoryAccess = Leader; }
+
+ // Forward propagation info
+ const Expression *getDefiningExpr() const { return DefiningExpr; }
+ void setDefiningExpr(const Expression *E) { DefiningExpr = E; }
+
+ // Value member set
+ bool empty() const { return Members.empty(); }
+ unsigned size() const { return Members.size(); }
+ MemberSet::const_iterator begin() const { return Members.begin(); }
+ MemberSet::const_iterator end() const { return Members.end(); }
+ void insert(MemberType *M) { Members.insert(M); }
+ void erase(MemberType *M) { Members.erase(M); }
+ void swap(MemberSet &Other) { Members.swap(Other); }
+
+ // Memory member set
+ bool memory_empty() const { return MemoryMembers.empty(); }
+ unsigned memory_size() const { return MemoryMembers.size(); }
+ MemoryMemberSet::const_iterator memory_begin() const {
+ return MemoryMembers.begin();
+ }
+ MemoryMemberSet::const_iterator memory_end() const {
+ return MemoryMembers.end();
+ }
+ iterator_range<MemoryMemberSet::const_iterator> memory() const {
+ return make_range(memory_begin(), memory_end());
+ }
+ void memory_insert(const MemoryMemberType *M) { MemoryMembers.insert(M); }
+ void memory_erase(const MemoryMemberType *M) { MemoryMembers.erase(M); }
+
+ // Store count
+ unsigned getStoreCount() const { return StoreCount; }
+ void incStoreCount() { ++StoreCount; }
+ void decStoreCount() {
+ assert(StoreCount != 0 && "Store count went negative");
+ --StoreCount;
+ }
+
+  // Return true if two congruence classes are equivalent to each other. This
+  // means that every field but the ID number and the dead field is
+  // equivalent.
+ bool isEquivalentTo(const CongruenceClass *Other) const {
+ if (!Other)
+ return false;
+ if (this == Other)
+ return true;
+
+ if (std::tie(StoreCount, RepLeader, RepStoredValue, RepMemoryAccess) !=
+ std::tie(Other->StoreCount, Other->RepLeader, Other->RepStoredValue,
+ Other->RepMemoryAccess))
+ return false;
+ if (DefiningExpr != Other->DefiningExpr)
+ if (!DefiningExpr || !Other->DefiningExpr ||
+ *DefiningExpr != *Other->DefiningExpr)
+ return false;
+    // We need an ordered container to compare the member sets.
+    std::set<Value *> AMembers(Members.begin(), Members.end());
+    std::set<Value *> BMembers(Other->Members.begin(), Other->Members.end());
+ return AMembers == BMembers;
+ }
+
+private:
unsigned ID;
// Representative leader.
Value *RepLeader = nullptr;
+ // The most dominating leader after our current leader, because the member set
+ // is not sorted and is expensive to keep sorted all the time.
+ std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U};
+ // If this is represented by a store, the value of the store.
+ Value *RepStoredValue = nullptr;
+ // If this class contains MemoryDefs or MemoryPhis, this is the leading memory
+ // access.
+ const MemoryAccess *RepMemoryAccess = nullptr;
// Defining Expression.
const Expression *DefiningExpr = nullptr;
// Actual members of this class.
MemberSet Members;
-
- // True if this class has no members left. This is mainly used for assertion
- // purposes, and for skipping empty classes.
- bool Dead = false;
-
+ // This is the set of MemoryPhis that exist in the class. MemoryDefs and
+ // MemoryUses have real instructions representing them, so we only need to
+ // track MemoryPhis here.
+ MemoryMemberSet MemoryMembers;
// Number of stores in this congruence class.
// This is used so we can detect store equivalence changes properly.
int StoreCount = 0;
-
- // The most dominating leader after our current leader, because the member set
- // is not sorted and is expensive to keep sorted all the time.
- std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U};
-
- explicit CongruenceClass(unsigned ID) : ID(ID) {}
- CongruenceClass(unsigned ID, Value *Leader, const Expression *E)
- : ID(ID), RepLeader(Leader), DefiningExpr(E) {}
};
namespace llvm {
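The NextLeader bookkeeping above is easier to see in isolation; a small sketch, assuming a generic (value, DFS number) pair rather than the pass's types (LeaderCache and the sample values are illustrative only):

#include <cassert>
#include <string>
#include <utility>

struct LeaderCache {
  // Best candidate seen so far, as (value, DFS number); ~0U means "none yet".
  std::pair<const std::string *, unsigned> Next{nullptr, ~0U};

  // Remember a member that could become the leader; keep the smallest DFS
  // number instead of keeping the whole member set sorted.
  void addPossible(const std::string *V, unsigned DFSNum) {
    if (DFSNum < Next.second)
      Next = {V, DFSNum};
  }
  void reset() { Next = {nullptr, ~0U}; }
};

int main() {
  std::string A = "a", B = "b", C = "c";
  LeaderCache Cache;
  Cache.addPossible(&B, 7);
  Cache.addPossible(&A, 3); // smaller DFS number wins
  Cache.addPossible(&C, 9);
  assert(Cache.Next.first == &A && Cache.Next.second == 3);
  return 0;
}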
@@ -180,19 +388,34 @@ template <> struct DenseMapInfo<const Expression *> {
};
} // end namespace llvm
-class NewGVN : public FunctionPass {
+namespace {
+class NewGVN {
+ Function &F;
DominatorTree *DT;
- const DataLayout *DL;
- const TargetLibraryInfo *TLI;
AssumptionCache *AC;
+ const TargetLibraryInfo *TLI;
AliasAnalysis *AA;
MemorySSA *MSSA;
MemorySSAWalker *MSSAWalker;
+ const DataLayout &DL;
+ std::unique_ptr<PredicateInfo> PredInfo;
BumpPtrAllocator ExpressionAllocator;
ArrayRecycler<Value *> ArgRecycler;
+ TarjanSCC SCCFinder;
+
+ // Number of function arguments, used by ranking
+ unsigned int NumFuncArgs;
+
+ // RPOOrdering of basic blocks
+ DenseMap<const DomTreeNode *, unsigned> RPOOrdering;
// Congruence class info.
- CongruenceClass *InitialClass;
+
+ // This class is called INITIAL in the paper. It is the class everything
+  // starts out in, and represents any value. Being an optimistic analysis,
+ // anything in the TOP class has the value TOP, which is indeterminate and
+ // equivalent to everything.
+ CongruenceClass *TOPClass;
std::vector<CongruenceClass *> CongruenceClasses;
unsigned NextCongruenceNum;
@@ -200,13 +423,38 @@ class NewGVN : public FunctionPass {
DenseMap<Value *, CongruenceClass *> ValueToClass;
DenseMap<Value *, const Expression *> ValueToExpression;
+ // Mapping from predicate info we used to the instructions we used it with.
+ // In order to correctly ensure propagation, we must keep track of what
+ // comparisons we used, so that when the values of the comparisons change, we
+ // propagate the information to the places we used the comparison.
+ DenseMap<const Value *, SmallPtrSet<Instruction *, 2>> PredicateToUsers;
+  // Mapping from a MemoryAccess we used to the MemoryAccesses we used it with.
+  // Has the same reasoning as PredicateToUsers. When we skip MemoryAccesses for
+  // stores, we can no longer rely solely on the def-use chains of MemorySSA.
+ DenseMap<const MemoryAccess *, SmallPtrSet<MemoryAccess *, 2>> MemoryToUsers;
+
// A table storing which memorydefs/phis represent a memory state provably
// equivalent to another memory state.
// We could use the congruence class machinery, but the MemoryAccess's are
// abstract memory states, so they can only ever be equivalent to each other,
// and not to constants, etc.
- DenseMap<const MemoryAccess *, MemoryAccess *> MemoryAccessEquiv;
-
+ DenseMap<const MemoryAccess *, CongruenceClass *> MemoryAccessToClass;
+
+ // We could, if we wanted, build MemoryPhiExpressions and
+ // MemoryVariableExpressions, etc, and value number them the same way we value
+ // number phi expressions. For the moment, this seems like overkill. They
+ // can only exist in one of three states: they can be TOP (equal to
+ // everything), Equivalent to something else, or unique. Because we do not
+ // create expressions for them, we need to simulate leader change not just
+ // when they change class, but when they change state. Note: We can do the
+  // same thing for phis, and avoid having phi expressions if we wanted. We
+  // should eventually unify in one direction or the other, so this is a little
+  // bit of an experiment to see which turns out easier to maintain.
+ enum MemoryPhiState { MPS_Invalid, MPS_TOP, MPS_Equivalent, MPS_Unique };
+ DenseMap<const MemoryPhi *, MemoryPhiState> MemoryPhiState;
+
+ enum PhiCycleState { PCS_Unknown, PCS_CycleFree, PCS_Cycle };
+ DenseMap<const PHINode *, PhiCycleState> PhiCycleState;
// Expression to class mapping.
using ExpressionClassMap = DenseMap<const Expression *, CongruenceClass *>;
ExpressionClassMap ExpressionToClass;
@@ -231,8 +479,6 @@ class NewGVN : public FunctionPass {
BitVector TouchedInstructions;
DenseMap<const BasicBlock *, std::pair<unsigned, unsigned>> BlockInstRange;
- DenseMap<const DomTreeNode *, std::pair<unsigned, unsigned>>
- DominatedInstRange;
#ifndef NDEBUG
// Debugging for how many times each block and instruction got processed.
@@ -240,56 +486,42 @@ class NewGVN : public FunctionPass {
#endif
// DFS info.
- DenseMap<const BasicBlock *, std::pair<int, int>> DFSDomMap;
+ // This contains a mapping from Instructions to DFS numbers.
+ // The numbering starts at 1. An instruction with DFS number zero
+ // means that the instruction is dead.
DenseMap<const Value *, unsigned> InstrDFS;
+
+  // This contains the mapping from DFS numbers to instructions.
SmallVector<Value *, 32> DFSToInstr;
// Deletion info.
SmallPtrSet<Instruction *, 8> InstructionsToErase;
public:
- static char ID; // Pass identification, replacement for typeid.
- NewGVN() : FunctionPass(ID) {
- initializeNewGVNPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
- bool runGVN(Function &F, DominatorTree *DT, AssumptionCache *AC,
- TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA);
+ NewGVN(Function &F, DominatorTree *DT, AssumptionCache *AC,
+ TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA,
+ const DataLayout &DL)
+ : F(F), DT(DT), AC(AC), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL),
+ PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)) {}
+ bool runGVN();
private:
- // This transformation requires dominator postdominator info.
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
-
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-
// Expression handling.
- const Expression *createExpression(Instruction *, const BasicBlock *);
- const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *,
- const BasicBlock *);
- PHIExpression *createPHIExpression(Instruction *);
+ const Expression *createExpression(Instruction *);
+ const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *);
+ PHIExpression *createPHIExpression(Instruction *, bool &HasBackedge,
+ bool &AllConstant);
const VariableExpression *createVariableExpression(Value *);
const ConstantExpression *createConstantExpression(Constant *);
- const Expression *createVariableOrConstant(Value *V, const BasicBlock *B);
+ const Expression *createVariableOrConstant(Value *V);
const UnknownExpression *createUnknownExpression(Instruction *);
- const StoreExpression *createStoreExpression(StoreInst *, MemoryAccess *,
- const BasicBlock *);
+ const StoreExpression *createStoreExpression(StoreInst *,
+ const MemoryAccess *);
LoadExpression *createLoadExpression(Type *, Value *, LoadInst *,
- MemoryAccess *, const BasicBlock *);
-
- const CallExpression *createCallExpression(CallInst *, MemoryAccess *,
- const BasicBlock *);
- const AggregateValueExpression *
- createAggregateValueExpression(Instruction *, const BasicBlock *);
- bool setBasicExpressionInfo(Instruction *, BasicExpression *,
- const BasicBlock *);
+ const MemoryAccess *);
+ const CallExpression *createCallExpression(CallInst *, const MemoryAccess *);
+ const AggregateValueExpression *createAggregateValueExpression(Instruction *);
+ bool setBasicExpressionInfo(Instruction *, BasicExpression *);
// Congruence class handling.
CongruenceClass *createCongruenceClass(Value *Leader, const Expression *E) {
@@ -298,9 +530,21 @@ private:
return result;
}
+ CongruenceClass *createMemoryClass(MemoryAccess *MA) {
+ auto *CC = createCongruenceClass(nullptr, nullptr);
+ CC->setMemoryLeader(MA);
+ return CC;
+ }
+ CongruenceClass *ensureLeaderOfMemoryClass(MemoryAccess *MA) {
+ auto *CC = getMemoryClass(MA);
+ if (CC->getMemoryLeader() != MA)
+ CC = createMemoryClass(MA);
+ return CC;
+ }
+
CongruenceClass *createSingletonCongruenceClass(Value *Member) {
CongruenceClass *CClass = createCongruenceClass(Member, nullptr);
- CClass->Members.insert(Member);
+ CClass->insert(Member);
ValueToClass[Member] = CClass;
return CClass;
}
@@ -313,37 +557,49 @@ private:
// Symbolic evaluation.
const Expression *checkSimplificationResults(Expression *, Instruction *,
Value *);
- const Expression *performSymbolicEvaluation(Value *, const BasicBlock *);
- const Expression *performSymbolicLoadEvaluation(Instruction *,
- const BasicBlock *);
- const Expression *performSymbolicStoreEvaluation(Instruction *,
- const BasicBlock *);
- const Expression *performSymbolicCallEvaluation(Instruction *,
- const BasicBlock *);
- const Expression *performSymbolicPHIEvaluation(Instruction *,
- const BasicBlock *);
- bool setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To);
- const Expression *performSymbolicAggrValueEvaluation(Instruction *,
- const BasicBlock *);
+ const Expression *performSymbolicEvaluation(Value *);
+ const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *,
+ Instruction *, MemoryAccess *);
+ const Expression *performSymbolicLoadEvaluation(Instruction *);
+ const Expression *performSymbolicStoreEvaluation(Instruction *);
+ const Expression *performSymbolicCallEvaluation(Instruction *);
+ const Expression *performSymbolicPHIEvaluation(Instruction *);
+ const Expression *performSymbolicAggrValueEvaluation(Instruction *);
+ const Expression *performSymbolicCmpEvaluation(Instruction *);
+ const Expression *performSymbolicPredicateInfoEvaluation(Instruction *);
// Congruence finding.
- // Templated to allow them to work both on BB's and BB-edges.
- template <class T>
- Value *lookupOperandLeader(Value *, const User *, const T &) const;
+ bool someEquivalentDominates(const Instruction *, const Instruction *) const;
+ Value *lookupOperandLeader(Value *) const;
void performCongruenceFinding(Instruction *, const Expression *);
- void moveValueToNewCongruenceClass(Instruction *, CongruenceClass *,
- CongruenceClass *);
+ void moveValueToNewCongruenceClass(Instruction *, const Expression *,
+ CongruenceClass *, CongruenceClass *);
+ void moveMemoryToNewCongruenceClass(Instruction *, MemoryAccess *,
+ CongruenceClass *, CongruenceClass *);
+ Value *getNextValueLeader(CongruenceClass *) const;
+ const MemoryAccess *getNextMemoryLeader(CongruenceClass *) const;
+ bool setMemoryClass(const MemoryAccess *From, CongruenceClass *To);
+ CongruenceClass *getMemoryClass(const MemoryAccess *MA) const;
+ const MemoryAccess *lookupMemoryLeader(const MemoryAccess *) const;
+ bool isMemoryAccessTop(const MemoryAccess *) const;
+
+ // Ranking
+ unsigned int getRank(const Value *) const;
+ bool shouldSwapOperands(const Value *, const Value *) const;
+
// Reachability handling.
void updateReachableEdge(BasicBlock *, BasicBlock *);
void processOutgoingEdges(TerminatorInst *, BasicBlock *);
- bool isOnlyReachableViaThisEdge(const BasicBlockEdge &) const;
- Value *findConditionEquivalence(Value *, BasicBlock *) const;
- MemoryAccess *lookupMemoryAccessEquiv(MemoryAccess *) const;
+ Value *findConditionEquivalence(Value *) const;
// Elimination.
struct ValueDFS;
- void convertDenseToDFSOrdered(CongruenceClass::MemberSet &,
- SmallVectorImpl<ValueDFS> &);
+ void convertClassToDFSOrdered(const CongruenceClass &,
+ SmallVectorImpl<ValueDFS> &,
+ DenseMap<const Value *, unsigned int> &,
+ SmallPtrSetImpl<Instruction *> &) const;
+ void convertClassToLoadsAndStores(const CongruenceClass &,
+ SmallVectorImpl<ValueDFS> &) const;
bool eliminateInstructions(Function &);
void replaceInstruction(Instruction *, Value *);
@@ -355,35 +611,58 @@ private:
// Various instruction touch utilities
void markUsersTouched(Value *);
- void markMemoryUsersTouched(MemoryAccess *);
- void markLeaderChangeTouched(CongruenceClass *CC);
+ void markMemoryUsersTouched(const MemoryAccess *);
+ void markMemoryDefTouched(const MemoryAccess *);
+ void markPredicateUsersTouched(Instruction *);
+ void markValueLeaderChangeTouched(CongruenceClass *CC);
+ void markMemoryLeaderChangeTouched(CongruenceClass *CC);
+ void addPredicateUsers(const PredicateBase *, Instruction *);
+ void addMemoryUsers(const MemoryAccess *To, MemoryAccess *U);
+
+ // Main loop of value numbering
+ void iterateTouchedInstructions();
// Utilities.
void cleanupTables();
std::pair<unsigned, unsigned> assignDFSNumbers(BasicBlock *, unsigned);
void updateProcessedCount(Value *V);
void verifyMemoryCongruency() const;
+ void verifyIterationSettled(Function &F);
bool singleReachablePHIPath(const MemoryAccess *, const MemoryAccess *) const;
-};
-
-char NewGVN::ID = 0;
+ BasicBlock *getBlockForValue(Value *V) const;
+ void deleteExpression(const Expression *E);
+ unsigned InstrToDFSNum(const Value *V) const {
+ assert(isa<Instruction>(V) && "This should not be used for MemoryAccesses");
+ return InstrDFS.lookup(V);
+ }
-// createGVNPass - The public interface to this file.
-FunctionPass *llvm::createNewGVNPass() { return new NewGVN(); }
+ unsigned InstrToDFSNum(const MemoryAccess *MA) const {
+ return MemoryToDFSNum(MA);
+ }
+ Value *InstrFromDFSNum(unsigned DFSNum) { return DFSToInstr[DFSNum]; }
+ // Given a MemoryAccess, return the relevant instruction DFS number. Note:
+ // This deliberately takes a value so it can be used with Use's, which will
+ // auto-convert to Value's but not to MemoryAccess's.
+ unsigned MemoryToDFSNum(const Value *MA) const {
+ assert(isa<MemoryAccess>(MA) &&
+ "This should not be used with instructions");
+ return isa<MemoryUseOrDef>(MA)
+ ? InstrToDFSNum(cast<MemoryUseOrDef>(MA)->getMemoryInst())
+ : InstrDFS.lookup(MA);
+ }
+ bool isCycleFree(const PHINode *PN);
+ template <class T, class Range> T *getMinDFSOfRange(const Range &) const;
+ // Debug counter info. When verifying, we have to reset the value numbering
+ // debug counter to the same state it started in to get the same results.
+ std::pair<int, int> StartingVNCounter;
+};
+} // end anonymous namespace
template <typename T>
static bool equalsLoadStoreHelper(const T &LHS, const Expression &RHS) {
- if ((!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS)) ||
- !LHS.BasicExpression::equals(RHS)) {
+ if (!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS))
return false;
- } else if (const auto *L = dyn_cast<LoadExpression>(&RHS)) {
- if (LHS.getDefiningAccess() != L->getDefiningAccess())
- return false;
- } else if (const auto *S = dyn_cast<StoreExpression>(&RHS)) {
- if (LHS.getDefiningAccess() != S->getDefiningAccess())
- return false;
- }
- return true;
+ return LHS.MemoryExpression::equals(RHS);
}
bool LoadExpression::equals(const Expression &Other) const {
@@ -391,7 +670,13 @@ bool LoadExpression::equals(const Expression &Other) const {
}
bool StoreExpression::equals(const Expression &Other) const {
- return equalsLoadStoreHelper(*this, Other);
+ if (!equalsLoadStoreHelper(*this, Other))
+ return false;
+ // Make sure that store vs store includes the value operand.
+ if (const auto *S = dyn_cast<StoreExpression>(&Other))
+ if (getStoredValue() != S->getStoredValue())
+ return false;
+ return true;
}
#ifndef NDEBUG
@@ -400,16 +685,28 @@ static std::string getBlockName(const BasicBlock *B) {
}
#endif
-INITIALIZE_PASS_BEGIN(NewGVN, "newgvn", "Global Value Numbering", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_END(NewGVN, "newgvn", "Global Value Numbering", false, false)
+// Get the basic block from an instruction/memory value.
+BasicBlock *NewGVN::getBlockForValue(Value *V) const {
+ if (auto *I = dyn_cast<Instruction>(V))
+ return I->getParent();
+ else if (auto *MP = dyn_cast<MemoryPhi>(V))
+ return MP->getBlock();
+ llvm_unreachable("Should have been able to figure out a block for our value");
+ return nullptr;
+}
-PHIExpression *NewGVN::createPHIExpression(Instruction *I) {
+// Delete a definitely dead expression, so it can be reused by the expression
+// allocator. Some of these are not in creation functions, so we have to accept
+// const versions.
+void NewGVN::deleteExpression(const Expression *E) {
+ assert(isa<BasicExpression>(E));
+ auto *BE = cast<BasicExpression>(E);
+ const_cast<BasicExpression *>(BE)->deallocateOperands(ArgRecycler);
+ ExpressionAllocator.Deallocate(E);
+}
+
+PHIExpression *NewGVN::createPHIExpression(Instruction *I, bool &HasBackedge,
+ bool &AllConstant) {
BasicBlock *PHIBlock = I->getParent();
auto *PN = cast<PHINode>(I);
auto *E =
@@ -419,28 +716,32 @@ PHIExpression *NewGVN::createPHIExpression(Instruction *I) {
E->setType(I->getType());
E->setOpcode(I->getOpcode());
- auto ReachablePhiArg = [&](const Use &U) {
- return ReachableBlocks.count(PN->getIncomingBlock(U));
- };
+ unsigned PHIRPO = RPOOrdering.lookup(DT->getNode(PHIBlock));
- // Filter out unreachable operands
- auto Filtered = make_filter_range(PN->operands(), ReachablePhiArg);
+ // Filter out unreachable phi operands.
+ auto Filtered = make_filter_range(PN->operands(), [&](const Use &U) {
+ return ReachableEdges.count({PN->getIncomingBlock(U), PHIBlock});
+ });
std::transform(Filtered.begin(), Filtered.end(), op_inserter(E),
[&](const Use &U) -> Value * {
+ auto *BB = PN->getIncomingBlock(U);
+ auto *DTN = DT->getNode(BB);
+ if (RPOOrdering.lookup(DTN) >= PHIRPO)
+ HasBackedge = true;
+ AllConstant &= isa<UndefValue>(U) || isa<Constant>(U);
+
// Don't try to transform self-defined phis.
if (U == PN)
return PN;
- const BasicBlockEdge BBE(PN->getIncomingBlock(U), PHIBlock);
- return lookupOperandLeader(U, I, BBE);
+ return lookupOperandLeader(U);
});
return E;
}
// Set basic expression info (Arguments, type, opcode) for Expression
// E from Instruction I in block B.
-bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,
- const BasicBlock *B) {
+bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E) {
bool AllConstant = true;
if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
E->setType(GEP->getSourceElementType());
@@ -452,7 +753,7 @@ bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,
// Transform the operand array into an operand leader array, and keep track of
// whether all members are constant.
std::transform(I->op_begin(), I->op_end(), op_inserter(E), [&](Value *O) {
- auto Operand = lookupOperandLeader(O, I, B);
+ auto Operand = lookupOperandLeader(O);
AllConstant &= isa<Constant>(Operand);
return Operand;
});
@@ -461,8 +762,7 @@ bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,
}
const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T,
- Value *Arg1, Value *Arg2,
- const BasicBlock *B) {
+ Value *Arg1, Value *Arg2) {
auto *E = new (ExpressionAllocator) BasicExpression(2);
E->setType(T);
@@ -473,13 +773,13 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T,
// of their operands get the same value number by sorting the operand value
// numbers. Since all commutative instructions have two operands it is more
// efficient to sort by hand rather than using, say, std::sort.
- if (Arg1 > Arg2)
+ if (shouldSwapOperands(Arg1, Arg2))
std::swap(Arg1, Arg2);
}
- E->op_push_back(lookupOperandLeader(Arg1, nullptr, B));
- E->op_push_back(lookupOperandLeader(Arg2, nullptr, B));
+ E->op_push_back(lookupOperandLeader(Arg1));
+ E->op_push_back(lookupOperandLeader(Arg2));
- Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), *DL, TLI,
+ Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), DL, TLI,
DT, AC);
if (const Expression *SimplifiedE = checkSimplificationResults(E, nullptr, V))
return SimplifiedE;
@@ -502,40 +802,32 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E,
NumGVNOpsSimplified++;
assert(isa<BasicExpression>(E) &&
"We should always have had a basic expression here");
-
- cast<BasicExpression>(E)->deallocateOperands(ArgRecycler);
- ExpressionAllocator.Deallocate(E);
+ deleteExpression(E);
return createConstantExpression(C);
} else if (isa<Argument>(V) || isa<GlobalVariable>(V)) {
if (I)
DEBUG(dbgs() << "Simplified " << *I << " to "
<< " variable " << *V << "\n");
- cast<BasicExpression>(E)->deallocateOperands(ArgRecycler);
- ExpressionAllocator.Deallocate(E);
+ deleteExpression(E);
return createVariableExpression(V);
}
CongruenceClass *CC = ValueToClass.lookup(V);
- if (CC && CC->DefiningExpr) {
+ if (CC && CC->getDefiningExpr()) {
if (I)
DEBUG(dbgs() << "Simplified " << *I << " to "
<< " expression " << *V << "\n");
NumGVNOpsSimplified++;
- assert(isa<BasicExpression>(E) &&
- "We should always have had a basic expression here");
- cast<BasicExpression>(E)->deallocateOperands(ArgRecycler);
- ExpressionAllocator.Deallocate(E);
- return CC->DefiningExpr;
+ deleteExpression(E);
+ return CC->getDefiningExpr();
}
return nullptr;
}
-const Expression *NewGVN::createExpression(Instruction *I,
- const BasicBlock *B) {
-
+const Expression *NewGVN::createExpression(Instruction *I) {
auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands());
- bool AllConstant = setBasicExpressionInfo(I, E, B);
+ bool AllConstant = setBasicExpressionInfo(I, E);
if (I->isCommutative()) {
// Ensure that commutative instructions that only differ by a permutation
@@ -543,7 +835,7 @@ const Expression *NewGVN::createExpression(Instruction *I,
// numbers. Since all commutative instructions have two operands it is more
// efficient to sort by hand rather than using, say, std::sort.
assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!");
- if (E->getOperand(0) > E->getOperand(1))
+ if (shouldSwapOperands(E->getOperand(0), E->getOperand(1)))
E->swapOperands(0, 1);
}
@@ -559,48 +851,43 @@ const Expression *NewGVN::createExpression(Instruction *I,
// Sort the operand value numbers so x<y and y>x get the same value
// number.
CmpInst::Predicate Predicate = CI->getPredicate();
- if (E->getOperand(0) > E->getOperand(1)) {
+ if (shouldSwapOperands(E->getOperand(0), E->getOperand(1))) {
E->swapOperands(0, 1);
Predicate = CmpInst::getSwappedPredicate(Predicate);
}
E->setOpcode((CI->getOpcode() << 8) | Predicate);
// TODO: 25% of our time is spent in SimplifyCmpInst with pointer operands
- // TODO: Since we noop bitcasts, we may need to check types before
- // simplifying, so that we don't end up simplifying based on a wrong
- // type assumption. We should clean this up so we can use constants of the
- // wrong type
-
assert(I->getOperand(0)->getType() == I->getOperand(1)->getType() &&
"Wrong types on cmp instruction");
- if ((E->getOperand(0)->getType() == I->getOperand(0)->getType() &&
- E->getOperand(1)->getType() == I->getOperand(1)->getType())) {
- Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1),
- *DL, TLI, DT, AC);
- if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
- return SimplifiedE;
- }
+ assert((E->getOperand(0)->getType() == I->getOperand(0)->getType() &&
+ E->getOperand(1)->getType() == I->getOperand(1)->getType()));
+ Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1),
+ DL, TLI, DT, AC);
+ if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
+ return SimplifiedE;
} else if (isa<SelectInst>(I)) {
if (isa<Constant>(E->getOperand(0)) ||
- (E->getOperand(1)->getType() == I->getOperand(1)->getType() &&
- E->getOperand(2)->getType() == I->getOperand(2)->getType())) {
+ E->getOperand(0) == E->getOperand(1)) {
+ assert(E->getOperand(1)->getType() == I->getOperand(1)->getType() &&
+ E->getOperand(2)->getType() == I->getOperand(2)->getType());
Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1),
- E->getOperand(2), *DL, TLI, DT, AC);
+ E->getOperand(2), DL, TLI, DT, AC);
if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
return SimplifiedE;
}
} else if (I->isBinaryOp()) {
Value *V = SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1),
- *DL, TLI, DT, AC);
+ DL, TLI, DT, AC);
if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
return SimplifiedE;
} else if (auto *BI = dyn_cast<BitCastInst>(I)) {
- Value *V = SimplifyInstruction(BI, *DL, TLI, DT, AC);
+ Value *V = SimplifyInstruction(BI, DL, TLI, DT, AC);
if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
return SimplifiedE;
} else if (isa<GetElementPtrInst>(I)) {
Value *V = SimplifyGEPInst(E->getType(),
ArrayRef<Value *>(E->op_begin(), E->op_end()),
- *DL, TLI, DT, AC);
+ DL, TLI, DT, AC);
if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
return SimplifiedE;
} else if (AllConstant) {
@@ -615,7 +902,7 @@ const Expression *NewGVN::createExpression(Instruction *I,
for (Value *Arg : E->operands())
C.emplace_back(cast<Constant>(Arg));
- if (Value *V = ConstantFoldInstOperands(I, C, *DL, TLI))
+ if (Value *V = ConstantFoldInstOperands(I, C, DL, TLI))
if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
return SimplifiedE;
}
@@ -623,18 +910,18 @@ const Expression *NewGVN::createExpression(Instruction *I,
}
const AggregateValueExpression *
-NewGVN::createAggregateValueExpression(Instruction *I, const BasicBlock *B) {
+NewGVN::createAggregateValueExpression(Instruction *I) {
if (auto *II = dyn_cast<InsertValueInst>(I)) {
auto *E = new (ExpressionAllocator)
AggregateValueExpression(I->getNumOperands(), II->getNumIndices());
- setBasicExpressionInfo(I, E, B);
+ setBasicExpressionInfo(I, E);
E->allocateIntOperands(ExpressionAllocator);
std::copy(II->idx_begin(), II->idx_end(), int_op_inserter(E));
return E;
} else if (auto *EI = dyn_cast<ExtractValueInst>(I)) {
auto *E = new (ExpressionAllocator)
AggregateValueExpression(I->getNumOperands(), EI->getNumIndices());
- setBasicExpressionInfo(EI, E, B);
+ setBasicExpressionInfo(EI, E);
E->allocateIntOperands(ExpressionAllocator);
std::copy(EI->idx_begin(), EI->idx_end(), int_op_inserter(E));
return E;
@@ -648,12 +935,10 @@ const VariableExpression *NewGVN::createVariableExpression(Value *V) {
return E;
}
-const Expression *NewGVN::createVariableOrConstant(Value *V,
- const BasicBlock *B) {
- auto Leader = lookupOperandLeader(V, nullptr, B);
- if (auto *C = dyn_cast<Constant>(Leader))
+const Expression *NewGVN::createVariableOrConstant(Value *V) {
+ if (auto *C = dyn_cast<Constant>(V))
return createConstantExpression(C);
- return createVariableExpression(Leader);
+ return createVariableExpression(V);
}
const ConstantExpression *NewGVN::createConstantExpression(Constant *C) {
@@ -669,40 +954,90 @@ const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) {
}
const CallExpression *NewGVN::createCallExpression(CallInst *CI,
- MemoryAccess *HV,
- const BasicBlock *B) {
+ const MemoryAccess *MA) {
// FIXME: Add operand bundles for calls.
auto *E =
- new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, HV);
- setBasicExpressionInfo(CI, E, B);
+ new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, MA);
+ setBasicExpressionInfo(CI, E);
return E;
}
+// Return true if some equivalent of instruction Inst dominates instruction U.
+bool NewGVN::someEquivalentDominates(const Instruction *Inst,
+ const Instruction *U) const {
+ auto *CC = ValueToClass.lookup(Inst);
+ // This must be an instruction because we are only called from phi nodes
+ // in the case that the value it needs to check against is an instruction.
+
+  // The most likely candidates for dominance are the leader and the next leader.
+  // The leader or next leader will dominate in all cases where there is an
+ // equivalent that is higher up in the dom tree.
+ // We can't *only* check them, however, because the
+ // dominator tree could have an infinite number of non-dominating siblings
+ // with instructions that are in the right congruence class.
+ // A
+ // B C D E F G
+ // |
+ // H
+ // Instruction U could be in H, with equivalents in every other sibling.
+ // Depending on the rpo order picked, the leader could be the equivalent in
+ // any of these siblings.
+ if (!CC)
+ return false;
+ if (DT->dominates(cast<Instruction>(CC->getLeader()), U))
+ return true;
+ if (CC->getNextLeader().first &&
+ DT->dominates(cast<Instruction>(CC->getNextLeader().first), U))
+ return true;
+ return llvm::any_of(*CC, [&](const Value *Member) {
+ return Member != CC->getLeader() &&
+ DT->dominates(cast<Instruction>(Member), U);
+ });
+}
+
// See if we have a congruence class and leader for this operand, and if so,
// return it. Otherwise, return the operand itself.
-template <class T>
-Value *NewGVN::lookupOperandLeader(Value *V, const User *U, const T &B) const {
+Value *NewGVN::lookupOperandLeader(Value *V) const {
CongruenceClass *CC = ValueToClass.lookup(V);
- if (CC && (CC != InitialClass))
- return CC->RepLeader;
+ if (CC) {
+    // Everything in TOP is represented by undef, as it can be any value.
+ // We do have to make sure we get the type right though, so we can't set the
+ // RepLeader to undef.
+ if (CC == TOPClass)
+ return UndefValue::get(V->getType());
+ return CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader();
+ }
+
return V;
}
-MemoryAccess *NewGVN::lookupMemoryAccessEquiv(MemoryAccess *MA) const {
- MemoryAccess *Result = MemoryAccessEquiv.lookup(MA);
- return Result ? Result : MA;
+const MemoryAccess *NewGVN::lookupMemoryLeader(const MemoryAccess *MA) const {
+ auto *CC = getMemoryClass(MA);
+ assert(CC->getMemoryLeader() &&
+ "Every MemoryAccess should be mapped to a "
+ "congruence class with a represenative memory "
+ "access");
+ return CC->getMemoryLeader();
+}
+
+// Return true if the MemoryAccess is really equivalent to everything. This is
+// equivalent to the lattice value "TOP" in most lattices. This is the initial
+// state of all MemoryAccesses.
+bool NewGVN::isMemoryAccessTop(const MemoryAccess *MA) const {
+ return getMemoryClass(MA) == TOPClass;
}
LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,
- LoadInst *LI, MemoryAccess *DA,
- const BasicBlock *B) {
- auto *E = new (ExpressionAllocator) LoadExpression(1, LI, DA);
+ LoadInst *LI,
+ const MemoryAccess *MA) {
+ auto *E =
+ new (ExpressionAllocator) LoadExpression(1, LI, lookupMemoryLeader(MA));
E->allocateOperands(ArgRecycler, ExpressionAllocator);
E->setType(LoadType);
// Give store and loads same opcode so they value number together.
E->setOpcode(0);
- E->op_push_back(lookupOperandLeader(PointerOp, LI, B));
+ E->op_push_back(PointerOp);
if (LI)
E->setAlignment(LI->getAlignment());
@@ -713,16 +1048,16 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,
}
const StoreExpression *NewGVN::createStoreExpression(StoreInst *SI,
- MemoryAccess *DA,
- const BasicBlock *B) {
- auto *E =
- new (ExpressionAllocator) StoreExpression(SI->getNumOperands(), SI, DA);
+ const MemoryAccess *MA) {
+ auto *StoredValueLeader = lookupOperandLeader(SI->getValueOperand());
+ auto *E = new (ExpressionAllocator)
+ StoreExpression(SI->getNumOperands(), SI, StoredValueLeader, MA);
E->allocateOperands(ArgRecycler, ExpressionAllocator);
E->setType(SI->getValueOperand()->getType());
// Give store and loads same opcode so they value number together.
E->setOpcode(0);
- E->op_push_back(lookupOperandLeader(SI->getPointerOperand(), SI, B));
+ E->op_push_back(lookupOperandLeader(SI->getPointerOperand()));
// TODO: Value number heap versions. We may be able to discover
   // things alias analysis can't on its own (IE that a store and a
@@ -730,44 +1065,140 @@ const StoreExpression *NewGVN::createStoreExpression(StoreInst *SI,
return E;
}
-// Utility function to check whether the congruence class has a member other
-// than the given instruction.
-bool hasMemberOtherThanUs(const CongruenceClass *CC, Instruction *I) {
- // Either it has more than one store, in which case it must contain something
- // other than us (because it's indexed by value), or if it only has one store
- // right now, that member should not be us.
- return CC->StoreCount > 1 || CC->Members.count(I) == 0;
-}
-
-const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I,
- const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I) {
// Unlike loads, we never try to eliminate stores, so we do not check if they
// are simple and avoid value numbering them.
auto *SI = cast<StoreInst>(I);
- MemoryAccess *StoreAccess = MSSA->getMemoryAccess(SI);
- // See if we are defined by a previous store expression, it already has a
- // value, and it's the same value as our current store. FIXME: Right now, we
- // only do this for simple stores, we should expand to cover memcpys, etc.
+ auto *StoreAccess = MSSA->getMemoryAccess(SI);
+ // Get the expression, if any, for the RHS of the MemoryDef.
+ const MemoryAccess *StoreRHS = StoreAccess->getDefiningAccess();
+ if (EnableStoreRefinement)
+ StoreRHS = MSSAWalker->getClobberingMemoryAccess(StoreAccess);
+ // If we bypassed the use-def chains, make sure we add a use.
+ if (StoreRHS != StoreAccess->getDefiningAccess())
+ addMemoryUsers(StoreRHS, StoreAccess);
+
+ StoreRHS = lookupMemoryLeader(StoreRHS);
+ // If we are defined by ourselves, use the live on entry def.
+ if (StoreRHS == StoreAccess)
+ StoreRHS = MSSA->getLiveOnEntryDef();
+
if (SI->isSimple()) {
- // Get the expression, if any, for the RHS of the MemoryDef.
- MemoryAccess *StoreRHS = lookupMemoryAccessEquiv(
- cast<MemoryDef>(StoreAccess)->getDefiningAccess());
- const Expression *OldStore = createStoreExpression(SI, StoreRHS, B);
- CongruenceClass *CC = ExpressionToClass.lookup(OldStore);
+ // See if we are defined by a previous store expression, it already has a
+ // value, and it's the same value as our current store. FIXME: Right now, we
+ // only do this for simple stores, we should expand to cover memcpys, etc.
+ const auto *LastStore = createStoreExpression(SI, StoreRHS);
+ const auto *LastCC = ExpressionToClass.lookup(LastStore);
// Basically, check if the congruence class the store is in is defined by a
// store that isn't us, and has the same value. MemorySSA takes care of
// ensuring the store has the same memory state as us already.
- if (CC && CC->DefiningExpr && isa<StoreExpression>(CC->DefiningExpr) &&
- CC->RepLeader == lookupOperandLeader(SI->getValueOperand(), SI, B) &&
- hasMemberOtherThanUs(CC, I))
- return createStoreExpression(SI, StoreRHS, B);
+ // The RepStoredValue gets nulled if all the stores disappear in a class, so
+ // we don't need to check if the class contains a store besides us.
+ if (LastCC &&
+ LastCC->getStoredValue() == lookupOperandLeader(SI->getValueOperand()))
+ return LastStore;
+ deleteExpression(LastStore);
+ // Also check if our value operand is defined by a load of the same memory
+ // location, and the memory state is the same as it was then (otherwise, it
+ // could have been overwritten later. See test32 in
+ // transforms/DeadStoreElimination/simple.ll).
+ if (auto *LI =
+ dyn_cast<LoadInst>(lookupOperandLeader(SI->getValueOperand()))) {
+ if ((lookupOperandLeader(LI->getPointerOperand()) ==
+ lookupOperandLeader(SI->getPointerOperand())) &&
+ (lookupMemoryLeader(MSSA->getMemoryAccess(LI)->getDefiningAccess()) ==
+ StoreRHS))
+ return createVariableExpression(LI);
+ }
}
- return createStoreExpression(SI, StoreAccess, B);
+ // If the store is not equivalent to anything, value number it as a store that
+  // produces a unique memory state (instead of using its MemoryUse, we use
+  // its MemoryDef).
+ return createStoreExpression(SI, StoreAccess);
}
-const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
- const BasicBlock *B) {
+// See if we can extract the value of a loaded pointer from a load, a store, or
+// a memory instruction.
+const Expression *
+NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr,
+ LoadInst *LI, Instruction *DepInst,
+ MemoryAccess *DefiningAccess) {
+ assert((!LI || LI->isSimple()) && "Not a simple load");
+ if (auto *DepSI = dyn_cast<StoreInst>(DepInst)) {
+ // Can't forward from non-atomic to atomic without violating memory model.
+    // Also, we don't need to coerce if they are the same type; we will just
+    // propagate.
+ if (LI->isAtomic() > DepSI->isAtomic() ||
+ LoadType == DepSI->getValueOperand()->getType())
+ return nullptr;
+ int Offset = analyzeLoadFromClobberingStore(LoadType, LoadPtr, DepSI, DL);
+ if (Offset >= 0) {
+ if (auto *C = dyn_cast<Constant>(
+ lookupOperandLeader(DepSI->getValueOperand()))) {
+ DEBUG(dbgs() << "Coercing load from store " << *DepSI << " to constant "
+ << *C << "\n");
+ return createConstantExpression(
+ getConstantStoreValueForLoad(C, Offset, LoadType, DL));
+ }
+ }
+
+ } else if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInst)) {
+ // Can't forward from non-atomic to atomic without violating memory model.
+ if (LI->isAtomic() > DepLI->isAtomic())
+ return nullptr;
+ int Offset = analyzeLoadFromClobberingLoad(LoadType, LoadPtr, DepLI, DL);
+ if (Offset >= 0) {
+ // We can coerce a constant load into a load
+ if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI)))
+ if (auto *PossibleConstant =
+ getConstantLoadValueForLoad(C, Offset, LoadType, DL)) {
+ DEBUG(dbgs() << "Coercing load from load " << *LI << " to constant "
+ << *PossibleConstant << "\n");
+ return createConstantExpression(PossibleConstant);
+ }
+ }
+
+ } else if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInst)) {
+ int Offset = analyzeLoadFromClobberingMemInst(LoadType, LoadPtr, DepMI, DL);
+ if (Offset >= 0) {
+ if (auto *PossibleConstant =
+ getConstantMemInstValueForLoad(DepMI, Offset, LoadType, DL)) {
+ DEBUG(dbgs() << "Coercing load from meminst " << *DepMI
+ << " to constant " << *PossibleConstant << "\n");
+ return createConstantExpression(PossibleConstant);
+ }
+ }
+ }
+
+ // All of the below are only true if the loaded pointer is produced
+ // by the dependent instruction.
+ if (LoadPtr != lookupOperandLeader(DepInst) &&
+ !AA->isMustAlias(LoadPtr, DepInst))
+ return nullptr;
+ // If this load really doesn't depend on anything, then we must be loading an
+ // undef value. This can happen when loading for a fresh allocation with no
+ // intervening stores, for example. Note that this is only true in the case
+ // that the result of the allocation is pointer equal to the load ptr.
+ if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) {
+ return createConstantExpression(UndefValue::get(LoadType));
+ }
+  // If this load occurs right after a lifetime begin,
+ // then the loaded value is undefined.
+ else if (auto *II = dyn_cast<IntrinsicInst>(DepInst)) {
+ if (II->getIntrinsicID() == Intrinsic::lifetime_start)
+ return createConstantExpression(UndefValue::get(LoadType));
+ }
+ // If this load follows a calloc (which zero initializes memory),
+ // then the loaded value is zero
+ else if (isCallocLikeFn(DepInst, TLI)) {
+ return createConstantExpression(Constant::getNullValue(LoadType));
+ }
+
+ return nullptr;
+}
+
+const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I) {
auto *LI = cast<LoadInst>(I);
// We can eliminate in favor of non-simple loads, but we won't be able to
@@ -775,7 +1206,7 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
if (!LI->isSimple())
return nullptr;
- Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand(), I, B);
+ Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand());
// Load of undef is undef.
if (isa<UndefValue>(LoadAddressLeader))
return createConstantExpression(UndefValue::get(LI->getType()));
@@ -788,61 +1219,233 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
// If the defining instruction is not reachable, replace with undef.
if (!ReachableBlocks.count(DefiningInst->getParent()))
return createConstantExpression(UndefValue::get(LI->getType()));
+      // This will handle stores and memory insts. We only do this if the
+ // defining access has a different type, or it is a pointer produced by
+ // certain memory operations that cause the memory to have a fixed value
+ // (IE things like calloc).
+ if (const auto *CoercionResult =
+ performSymbolicLoadCoercion(LI->getType(), LoadAddressLeader, LI,
+ DefiningInst, DefiningAccess))
+ return CoercionResult;
}
}
- const Expression *E =
- createLoadExpression(LI->getType(), LI->getPointerOperand(), LI,
- lookupMemoryAccessEquiv(DefiningAccess), B);
+ const Expression *E = createLoadExpression(LI->getType(), LoadAddressLeader,
+ LI, DefiningAccess);
return E;
}
+const Expression *
+NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) {
+ auto *PI = PredInfo->getPredicateInfoFor(I);
+ if (!PI)
+ return nullptr;
+
+ DEBUG(dbgs() << "Found predicate info from instruction !\n");
+
+ auto *PWC = dyn_cast<PredicateWithCondition>(PI);
+ if (!PWC)
+ return nullptr;
+
+ auto *CopyOf = I->getOperand(0);
+ auto *Cond = PWC->Condition;
+
+  // If this is a copy of the condition, it must be either true or false
+  // depending on the predicate info type and edge.
+ if (CopyOf == Cond) {
+ // We should not need to add predicate users because the predicate info is
+ // already a use of this operand.
+ if (isa<PredicateAssume>(PI))
+ return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+ if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+ if (PBranch->TrueEdge)
+ return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+ return createConstantExpression(ConstantInt::getFalse(Cond->getType()));
+ }
+ if (auto *PSwitch = dyn_cast<PredicateSwitch>(PI))
+ return createConstantExpression(cast<Constant>(PSwitch->CaseValue));
+ }
+
+ // Not a copy of the condition, so see what the predicates tell us about this
+ // value. First, though, we check to make sure the value is actually a copy
+ // of one of the condition operands. It's possible, in certain cases, for it
+ // to be a copy of a predicateinfo copy. In particular, if two branch
+ // operations use the same condition, and one branch dominates the other, we
+ // will end up with a copy of a copy. This is currently a small deficiency in
+ // predicateinfo. What will end up happening here is that we will value
+ // number both copies the same anyway.
+
+ // Everything below relies on the condition being a comparison.
+ auto *Cmp = dyn_cast<CmpInst>(Cond);
+ if (!Cmp)
+ return nullptr;
+
+ if (CopyOf != Cmp->getOperand(0) && CopyOf != Cmp->getOperand(1)) {
+ DEBUG(dbgs() << "Copy is not of any condition operands!");
+ return nullptr;
+ }
+ Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0));
+ Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1));
+ bool SwappedOps = false;
+ // Sort the ops
+ if (shouldSwapOperands(FirstOp, SecondOp)) {
+ std::swap(FirstOp, SecondOp);
+ SwappedOps = true;
+ }
+ CmpInst::Predicate Predicate =
+ SwappedOps ? Cmp->getSwappedPredicate() : Cmp->getPredicate();
+
+ if (isa<PredicateAssume>(PI)) {
+ // If the comparison is true when the operands are equal, then we know the
+ // operands are equal, because assumes must always be true.
+ if (CmpInst::isTrueWhenEqual(Predicate)) {
+ addPredicateUsers(PI, I);
+ return createVariableOrConstant(FirstOp);
+ }
+ }
+ if (const auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+    // If we are *not* a copy of the comparison, we may be equal to the other
+ // operand when the predicate implies something about equality of
+ // operations. In particular, if the comparison is true/false when the
+ // operands are equal, and we are on the right edge, we know this operation
+ // is equal to something.
+ if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) ||
+ (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) {
+ addPredicateUsers(PI, I);
+ return createVariableOrConstant(FirstOp);
+ }
+ // Handle the special case of floating point.
+ if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) ||
+ (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) &&
+ isa<ConstantFP>(FirstOp) && !cast<ConstantFP>(FirstOp)->isZero()) {
+ addPredicateUsers(PI, I);
+ return createConstantExpression(cast<Constant>(FirstOp));
+ }
+ }
+ return nullptr;
+}
+
// Evaluate read only and pure calls, and create an expression result.
-const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I,
- const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) {
auto *CI = cast<CallInst>(I);
- if (AA->doesNotAccessMemory(CI))
- return createCallExpression(CI, nullptr, B);
- if (AA->onlyReadsMemory(CI)) {
+ if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+ // Intrinsics with the returned attribute are copies of arguments.
+ if (auto *ReturnedValue = II->getReturnedArgOperand()) {
+ if (II->getIntrinsicID() == Intrinsic::ssa_copy)
+ if (const auto *Result = performSymbolicPredicateInfoEvaluation(I))
+ return Result;
+ return createVariableOrConstant(ReturnedValue);
+ }
+ }
+ if (AA->doesNotAccessMemory(CI)) {
+ return createCallExpression(CI, TOPClass->getMemoryLeader());
+ } else if (AA->onlyReadsMemory(CI)) {
MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI);
- return createCallExpression(CI, lookupMemoryAccessEquiv(DefiningAccess), B);
+ return createCallExpression(CI, DefiningAccess);
}
return nullptr;
}
-// Update the memory access equivalence table to say that From is equal to To,
+// Retrieve the memory class for a given MemoryAccess.
+CongruenceClass *NewGVN::getMemoryClass(const MemoryAccess *MA) const {
+
+ auto *Result = MemoryAccessToClass.lookup(MA);
+ assert(Result && "Should have found memory class");
+ return Result;
+}
+
+// Update the MemoryAccess equivalence table to say that From is equal to To,
// and return true if this is different from what already existed in the table.
-bool NewGVN::setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To) {
- DEBUG(dbgs() << "Setting " << *From << " equivalent to ");
- if (!To)
- DEBUG(dbgs() << "itself");
- else
- DEBUG(dbgs() << *To);
+bool NewGVN::setMemoryClass(const MemoryAccess *From,
+ CongruenceClass *NewClass) {
+ assert(NewClass &&
+ "Every MemoryAccess should be getting mapped to a non-null class");
+ DEBUG(dbgs() << "Setting " << *From);
+ DEBUG(dbgs() << " equivalent to congruence class ");
+ DEBUG(dbgs() << NewClass->getID() << " with current MemoryAccess leader ");
+ DEBUG(dbgs() << *NewClass->getMemoryLeader());
DEBUG(dbgs() << "\n");
- auto LookupResult = MemoryAccessEquiv.find(From);
+
+ auto LookupResult = MemoryAccessToClass.find(From);
bool Changed = false;
// If it's already in the table, see if the value changed.
- if (LookupResult != MemoryAccessEquiv.end()) {
- if (To && LookupResult->second != To) {
+ if (LookupResult != MemoryAccessToClass.end()) {
+ auto *OldClass = LookupResult->second;
+ if (OldClass != NewClass) {
+ // If this is a phi, we have to handle memory member updates.
+ if (auto *MP = dyn_cast<MemoryPhi>(From)) {
+ OldClass->memory_erase(MP);
+ NewClass->memory_insert(MP);
+ // This may have killed the class if it had no non-memory members
+ if (OldClass->getMemoryLeader() == From) {
+ if (OldClass->memory_empty()) {
+ OldClass->setMemoryLeader(nullptr);
+ } else {
+ OldClass->setMemoryLeader(getNextMemoryLeader(OldClass));
+ DEBUG(dbgs() << "Memory class leader change for class "
+ << OldClass->getID() << " to "
+ << *OldClass->getMemoryLeader()
+ << " due to removal of a memory member " << *From
+ << "\n");
+ markMemoryLeaderChangeTouched(OldClass);
+ }
+ }
+ }
// It wasn't equivalent before, and now it is.
- LookupResult->second = To;
- Changed = true;
- } else if (!To) {
- // It used to be equivalent to something, and now it's not.
- MemoryAccessEquiv.erase(LookupResult);
+ LookupResult->second = NewClass;
Changed = true;
}
- } else {
- assert(!To &&
- "Memory equivalence should never change from nothing to something");
}
return Changed;
}
+
+// Determine if a phi is cycle-free. That means the values in the phi don't
+// depend on any expressions that can change value as a result of the phi.
+// For example, a non-cycle free phi would be v = phi(0, v+1).
+bool NewGVN::isCycleFree(const PHINode *PN) {
+ // In order to compute cycle-freeness, we do SCC finding on the phi, and see
+ // what kind of SCC it ends up in. If it is a singleton, it is cycle-free.
+ // If it is not in a singleton, it is only cycle free if the other members are
+ // all phi nodes (as they do not compute anything, they are copies). TODO:
+ // There are likely a few other intrinsics or expressions that could be
+ // included here, but this happens so infrequently already that it is not
+ // likely to be worth it.
+ auto PCS = PhiCycleState.lookup(PN);
+ if (PCS == PCS_Unknown) {
+ SCCFinder.Start(PN);
+ auto &SCC = SCCFinder.getComponentFor(PN);
+ // It's cycle free if its size is 1 or the SCC is *only* phi nodes.
+ if (SCC.size() == 1)
+ PhiCycleState.insert({PN, PCS_CycleFree});
+ else {
+ bool AllPhis =
+ llvm::all_of(SCC, [](const Value *V) { return isa<PHINode>(V); });
+ PCS = AllPhis ? PCS_CycleFree : PCS_Cycle;
+ for (auto *Member : SCC)
+ if (auto *MemberPhi = dyn_cast<PHINode>(Member))
+ PhiCycleState.insert({MemberPhi, PCS});
+ }
+ }
+ if (PCS == PCS_Cycle)
+ return false;
+ return true;
+}
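
As a rough standalone illustration of the SCC-based test described above (not the pass's code; the Graph type and node ids are invented for the example), the sketch below runs Tarjan's SCC algorithm on a toy value-dependence graph and treats a node as cycle-free when its SCC is a singleton or contains only phi-like nodes:

    #include <algorithm>
    #include <cassert>
    #include <functional>
    #include <vector>

    // Toy value graph: Succ[v] lists the values v depends on; IsPhi[v] marks
    // phi-like nodes, which only forward values and compute nothing new.
    struct Graph {
      std::vector<std::vector<int>> Succ;
      std::vector<bool> IsPhi;
    };

    // Tarjan's SCC; returns, for every node, the members of its SCC.
    static std::vector<std::vector<int>> findSCCs(const Graph &G) {
      int N = G.Succ.size(), NextIndex = 0;
      std::vector<int> Index(N, -1), Low(N, 0), Stack;
      std::vector<bool> OnStack(N, false);
      std::vector<std::vector<int>> SCCOf(N);
      std::function<void(int)> Visit = [&](int V) {
        Index[V] = Low[V] = NextIndex++;
        Stack.push_back(V);
        OnStack[V] = true;
        for (int W : G.Succ[V]) {
          if (Index[W] == -1) {
            Visit(W);
            Low[V] = std::min(Low[V], Low[W]);
          } else if (OnStack[W]) {
            Low[V] = std::min(Low[V], Index[W]);
          }
        }
        if (Low[V] == Index[V]) { // V is the root of an SCC; pop it off.
          std::vector<int> SCC;
          int W;
          do {
            W = Stack.back();
            Stack.pop_back();
            OnStack[W] = false;
            SCC.push_back(W);
          } while (W != V);
          for (int M : SCC)
            SCCOf[M] = SCC;
        }
      };
      for (int V = 0; V < N; ++V)
        if (Index[V] == -1)
          Visit(V);
      return SCCOf;
    }

    // A node is "cycle-free" if its SCC is a singleton, or if every member of
    // the SCC is phi-like (pure copies cannot change value on their own).
    static bool isCycleFree(const Graph &G,
                            const std::vector<std::vector<int>> &SCCOf, int V) {
      const auto &SCC = SCCOf[V];
      if (SCC.size() == 1)
        return true;
      return std::all_of(SCC.begin(), SCC.end(),
                         [&](int M) { return G.IsPhi[M]; });
    }

    int main() {
      // v0 = phi(..., v1) and v1 = v0 + 1: a real value cycle.
      Graph G{{{1}, {0}}, {true, false}};
      auto SCCOf = findSCCs(G);
      assert(!isCycleFree(G, SCCOf, 0));
      // Two phis that only copy each other: a cycle of copies, so cycle-free.
      Graph H{{{1}, {0}}, {true, true}};
      auto SCCOf2 = findSCCs(H);
      assert(isCycleFree(H, SCCOf2, 0));
      return 0;
    }
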
+
// Evaluate PHI nodes symbolically, and create an expression result.
-const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,
- const BasicBlock *B) {
- auto *E = cast<PHIExpression>(createPHIExpression(I));
+const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I) {
+ // True if one of the incoming phi edges is a backedge.
+ bool HasBackedge = false;
+ // AllConstant tracks whether all of the *original* phi operands were
+ // constant. This is really shorthand for "this phi cannot cycle due to
+ // forward propagation", as any change in the value of the phi is
+ // guaranteed not to later change the value of the phi.
+ // IE it can't be v = phi(undef, v+1)
+ bool AllConstant = true;
+ auto *E =
+ cast<PHIExpression>(createPHIExpression(I, HasBackedge, AllConstant));
// We match the semantics of SimplifyPhiNode from InstructionSimplify here.
 // See if all arguments are the same.
@@ -861,14 +1464,15 @@ const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,
if (Filtered.begin() == Filtered.end()) {
DEBUG(dbgs() << "Simplified PHI node " << *I << " to undef"
<< "\n");
- E->deallocateOperands(ArgRecycler);
- ExpressionAllocator.Deallocate(E);
+ deleteExpression(E);
return createConstantExpression(UndefValue::get(I->getType()));
}
+ unsigned NumOps = 0;
Value *AllSameValue = *(Filtered.begin());
++Filtered.begin();
// Can't use std::equal here, sadly, because filter.begin moves.
- if (llvm::all_of(Filtered, [AllSameValue](const Value *V) {
+ if (llvm::all_of(Filtered, [AllSameValue, &NumOps](const Value *V) {
+ ++NumOps;
return V == AllSameValue;
})) {
// In LLVM's non-standard representation of phi nodes, it's possible to have
@@ -881,27 +1485,32 @@ const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,
// We also special case undef, so that if we have an undef, we can't use the
// common value unless it dominates the phi block.
if (HasUndef) {
+ // If we have undef and at least one other value, this is really a
+ // multivalued phi, and we need to know if it's cycle free in order to
+ // evaluate whether we can ignore the undef. The other parts of this are
+ // just shortcuts. If there is no backedge, or all operands are
+ // constants, or all operands are ignored but the undef, it also must be
+ // cycle free.
+ if (!AllConstant && HasBackedge && NumOps > 0 &&
+ !isa<UndefValue>(AllSameValue) && !isCycleFree(cast<PHINode>(I)))
+ return E;
+
// Only have to check for instructions
if (auto *AllSameInst = dyn_cast<Instruction>(AllSameValue))
- if (!DT->dominates(AllSameInst, I))
+ if (!someEquivalentDominates(AllSameInst, I))
return E;
}
NumGVNPhisAllSame++;
DEBUG(dbgs() << "Simplified PHI node " << *I << " to " << *AllSameValue
<< "\n");
- E->deallocateOperands(ArgRecycler);
- ExpressionAllocator.Deallocate(E);
- if (auto *C = dyn_cast<Constant>(AllSameValue))
- return createConstantExpression(C);
- return createVariableExpression(AllSameValue);
+ deleteExpression(E);
+ return createVariableOrConstant(AllSameValue);
}
return E;
}
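
A minimal standalone sketch of the "all incoming values are really one value" simplification above, using plain ints as stand-ins for value leaders and -1 for undef; it deliberately omits the backedge/cycle-free and domination checks the real code performs:

    #include <algorithm>
    #include <cassert>
    #include <optional>
    #include <vector>

    // -1 plays the role of undef here; everything else is an ordinary value id.
    constexpr int Undef = -1;

    // Returns the single value the phi simplifies to, if filtering out undefs
    // and self-references leaves exactly one distinct incoming value.
    std::optional<int> simplifyPhi(int PhiId, const std::vector<int> &Incoming) {
      std::vector<int> Filtered;
      for (int V : Incoming)
        if (V != Undef && V != PhiId)
          Filtered.push_back(V);
      if (Filtered.empty())
        return Undef; // Everything was undef (or the phi itself).
      int First = Filtered.front();
      if (std::all_of(Filtered.begin(), Filtered.end(),
                      [First](int V) { return V == First; }))
        return First;
      return std::nullopt; // Genuinely multivalued; keep the phi expression.
    }

    int main() {
      assert(simplifyPhi(7, {Undef, 3, 3}) == 3);      // phi(undef, 3, 3)  -> 3
      assert(simplifyPhi(7, {Undef, Undef}) == Undef); // phi(undef, undef) -> undef
      assert(!simplifyPhi(7, {3, 4}));                 // phi(3, 4) stays a phi
      return 0;
    }
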
-const Expression *
-NewGVN::performSymbolicAggrValueEvaluation(Instruction *I,
- const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) {
if (auto *EI = dyn_cast<ExtractValueInst>(I)) {
auto *II = dyn_cast<IntrinsicInst>(EI->getAggregateOperand());
if (II && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) {
@@ -931,19 +1540,130 @@ NewGVN::performSymbolicAggrValueEvaluation(Instruction *I,
// expression.
assert(II->getNumArgOperands() == 2 &&
"Expect two args for recognised intrinsics.");
- return createBinaryExpression(Opcode, EI->getType(),
- II->getArgOperand(0),
- II->getArgOperand(1), B);
+ return createBinaryExpression(
+ Opcode, EI->getType(), II->getArgOperand(0), II->getArgOperand(1));
}
}
}
- return createAggregateValueExpression(I, B);
+ return createAggregateValueExpression(I);
+}
+const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) {
+ auto *CI = dyn_cast<CmpInst>(I);
+ // See if our operands are equal to those of a previous predicate, and if so,
+ // if it implies true or false.
+ auto Op0 = lookupOperandLeader(CI->getOperand(0));
+ auto Op1 = lookupOperandLeader(CI->getOperand(1));
+ auto OurPredicate = CI->getPredicate();
+ if (shouldSwapOperands(Op0, Op1)) {
+ std::swap(Op0, Op1);
+ OurPredicate = CI->getSwappedPredicate();
+ }
+
+ // Avoid processing the same info twice
+ const PredicateBase *LastPredInfo = nullptr;
+ // See if we know something about the comparison itself, like it is the target
+ // of an assume.
+ auto *CmpPI = PredInfo->getPredicateInfoFor(I);
+ if (dyn_cast_or_null<PredicateAssume>(CmpPI))
+ return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+
+ if (Op0 == Op1) {
+ // This condition does not depend on predicates, no need to add users
+ if (CI->isTrueWhenEqual())
+ return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+ else if (CI->isFalseWhenEqual())
+ return createConstantExpression(ConstantInt::getFalse(CI->getType()));
+ }
+
+ // NOTE: Because we are comparing both operands here and below, and using
+ // previous comparisons, we rely on the fact that predicateinfo knows to mark
+ // comparisons that use renamed operands as users of the earlier comparisons.
+ // It is *not* enough to just mark predicateinfo renamed operands as users of
+ // the earlier comparisons, because the *other* operand may have changed in a
+ // previous iteration.
+ // Example:
+ // icmp slt %a, %b
+ // %b.0 = ssa.copy(%b)
+ // false branch:
+ // icmp slt %c, %b.0
+
+ // %c and %a may start out equal, and thus, the code below will say the second
+ // icmp is false. %c may become equal to something else, and in that case the
+ // second icmp *must* be reexamined, but would not if only the renamed
+ // operands are considered users of the icmp.
+
+ // *Currently* we only check one level of comparisons back, and only mark one
+ // level back as touched when changes happen. If you modify this code to look
+ // back farther through comparisons, you *must* mark the appropriate
+ // comparisons as users in PredicateInfo.cpp, or you will cause bugs.
+ // See if we know something just from the operands themselves.
+
+ // See if our operands have predicate info, so that we may be able to derive
+ // something from a previous comparison.
+ for (const auto &Op : CI->operands()) {
+ auto *PI = PredInfo->getPredicateInfoFor(Op);
+ if (const auto *PBranch = dyn_cast_or_null<PredicateBranch>(PI)) {
+ if (PI == LastPredInfo)
+ continue;
+ LastPredInfo = PI;
+
+ // TODO: Along the false edge, we may know more things too, like that an
+ // icmp of the same operands is false.
+ // TODO: We only handle actual comparison conditions below, not and/or.
+ auto *BranchCond = dyn_cast<CmpInst>(PBranch->Condition);
+ if (!BranchCond)
+ continue;
+ auto *BranchOp0 = lookupOperandLeader(BranchCond->getOperand(0));
+ auto *BranchOp1 = lookupOperandLeader(BranchCond->getOperand(1));
+ auto BranchPredicate = BranchCond->getPredicate();
+ if (shouldSwapOperands(BranchOp0, BranchOp1)) {
+ std::swap(BranchOp0, BranchOp1);
+ BranchPredicate = BranchCond->getSwappedPredicate();
+ }
+ if (BranchOp0 == Op0 && BranchOp1 == Op1) {
+ if (PBranch->TrueEdge) {
+ // If we know the previous predicate is true and we are on the true
+ // edge, then this comparison may be implied true or false.
+ if (CmpInst::isImpliedTrueByMatchingCmp(OurPredicate,
+ BranchPredicate)) {
+ addPredicateUsers(PI, I);
+ return createConstantExpression(
+ ConstantInt::getTrue(CI->getType()));
+ }
+
+ if (CmpInst::isImpliedFalseByMatchingCmp(OurPredicate,
+ BranchPredicate)) {
+ addPredicateUsers(PI, I);
+ return createConstantExpression(
+ ConstantInt::getFalse(CI->getType()));
+ }
+
+ } else {
+ // Just handle the ne and eq cases, where if we have the same
+ // operands, we may know something.
+ if (BranchPredicate == OurPredicate) {
+ addPredicateUsers(PI, I);
+ // Same predicate, same ops, we know it was false, so this is false.
+ return createConstantExpression(
+ ConstantInt::getFalse(CI->getType()));
+ } else if (BranchPredicate ==
+ CmpInst::getInversePredicate(OurPredicate)) {
+ addPredicateUsers(PI, I);
+ // Inverse predicate, we know the other was false, so this is true.
+ return createConstantExpression(
+ ConstantInt::getTrue(CI->getType()));
+ }
+ }
+ }
+ }
+ }
+ // createExpression will take care of simplifyCmpInst.
+ return createExpression(I);
}
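
The operand canonicalization used above (sort the operands into a fixed order and use the swapped predicate) can be pictured with this standalone sketch; the Pred enum and the integer "value ids" are stand-ins, not the pass's types:

    #include <cassert>
    #include <tuple>
    #include <utility>

    // Tiny stand-in for the integer comparison predicates we need.
    enum class Pred { EQ, NE, SLT, SGT, SLE, SGE };

    // Swapping the operands of a comparison requires swapping the predicate:
    // (a < b) states the same fact as (b > a).
    Pred swapped(Pred P) {
      switch (P) {
      case Pred::SLT: return Pred::SGT;
      case Pred::SGT: return Pred::SLT;
      case Pred::SLE: return Pred::SGE;
      case Pred::SGE: return Pred::SLE;
      default:        return P; // EQ and NE are symmetric.
      }
    }

    // Canonical form of a comparison: always put the lower-ranked operand
    // (here just the smaller value id) first, adjusting the predicate. Two
    // comparisons of the same operands in either order then get the same key.
    std::tuple<int, int, Pred> canonicalize(int LHS, int RHS, Pred P) {
      if (RHS < LHS) {
        std::swap(LHS, RHS);
        P = swapped(P);
      }
      return std::make_tuple(LHS, RHS, P);
    }

    int main() {
      // "%5 slt %9" and "%9 sgt %5" express the same fact and compare equal.
      assert(canonicalize(5, 9, Pred::SLT) == canonicalize(9, 5, Pred::SGT));
      assert(canonicalize(3, 3, Pred::EQ) == canonicalize(3, 3, Pred::EQ));
      return 0;
    }
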
// Substitute and symbolize the value before value numbering.
-const Expression *NewGVN::performSymbolicEvaluation(Value *V,
- const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicEvaluation(Value *V) {
const Expression *E = nullptr;
if (auto *C = dyn_cast<Constant>(V))
E = createConstantExpression(C);
@@ -957,24 +1677,27 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,
switch (I->getOpcode()) {
case Instruction::ExtractValue:
case Instruction::InsertValue:
- E = performSymbolicAggrValueEvaluation(I, B);
+ E = performSymbolicAggrValueEvaluation(I);
break;
case Instruction::PHI:
- E = performSymbolicPHIEvaluation(I, B);
+ E = performSymbolicPHIEvaluation(I);
break;
case Instruction::Call:
- E = performSymbolicCallEvaluation(I, B);
+ E = performSymbolicCallEvaluation(I);
break;
case Instruction::Store:
- E = performSymbolicStoreEvaluation(I, B);
+ E = performSymbolicStoreEvaluation(I);
break;
case Instruction::Load:
- E = performSymbolicLoadEvaluation(I, B);
+ E = performSymbolicLoadEvaluation(I);
break;
case Instruction::BitCast: {
- E = createExpression(I, B);
+ E = createExpression(I);
+ } break;
+ case Instruction::ICmp:
+ case Instruction::FCmp: {
+ E = performSymbolicCmpEvaluation(I);
} break;
-
case Instruction::Add:
case Instruction::FAdd:
case Instruction::Sub:
@@ -993,8 +1716,6 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
- case Instruction::ICmp:
- case Instruction::FCmp:
case Instruction::Trunc:
case Instruction::ZExt:
case Instruction::SExt:
@@ -1011,7 +1732,7 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,
case Instruction::InsertElement:
case Instruction::ShuffleVector:
case Instruction::GetElementPtr:
- E = createExpression(I, B);
+ E = createExpression(I);
break;
default:
return nullptr;
@@ -1020,129 +1741,297 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V,
return E;
}
-// There is an edge from 'Src' to 'Dst'. Return true if every path from
-// the entry block to 'Dst' passes via this edge. In particular 'Dst'
-// must not be reachable via another edge from 'Src'.
-bool NewGVN::isOnlyReachableViaThisEdge(const BasicBlockEdge &E) const {
-
- // While in theory it is interesting to consider the case in which Dst has
- // more than one predecessor, because Dst might be part of a loop which is
- // only reachable from Src, in practice it is pointless since at the time
- // GVN runs all such loops have preheaders, which means that Dst will have
- // been changed to have only one predecessor, namely Src.
- const BasicBlock *Pred = E.getEnd()->getSinglePredecessor();
- const BasicBlock *Src = E.getStart();
- assert((!Pred || Pred == Src) && "No edge between these basic blocks!");
- (void)Src;
- return Pred != nullptr;
-}
-
void NewGVN::markUsersTouched(Value *V) {
// Now mark the users as touched.
for (auto *User : V->users()) {
assert(isa<Instruction>(User) && "Use of value not within an instruction?");
- TouchedInstructions.set(InstrDFS[User]);
+ TouchedInstructions.set(InstrToDFSNum(User));
}
}
-void NewGVN::markMemoryUsersTouched(MemoryAccess *MA) {
- for (auto U : MA->users()) {
- if (auto *MUD = dyn_cast<MemoryUseOrDef>(U))
- TouchedInstructions.set(InstrDFS[MUD->getMemoryInst()]);
- else
- TouchedInstructions.set(InstrDFS[U]);
+void NewGVN::addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) {
+ DEBUG(dbgs() << "Adding memory user " << *U << " to " << *To << "\n");
+ MemoryToUsers[To].insert(U);
+}
+
+void NewGVN::markMemoryDefTouched(const MemoryAccess *MA) {
+ TouchedInstructions.set(MemoryToDFSNum(MA));
+}
+
+void NewGVN::markMemoryUsersTouched(const MemoryAccess *MA) {
+ if (isa<MemoryUse>(MA))
+ return;
+ for (auto U : MA->users())
+ TouchedInstructions.set(MemoryToDFSNum(U));
+ const auto Result = MemoryToUsers.find(MA);
+ if (Result != MemoryToUsers.end()) {
+ for (auto *User : Result->second)
+ TouchedInstructions.set(MemoryToDFSNum(User));
+ MemoryToUsers.erase(Result);
+ }
+}
+
+// Add I to the set of users of a given predicate.
+void NewGVN::addPredicateUsers(const PredicateBase *PB, Instruction *I) {
+ if (auto *PBranch = dyn_cast<PredicateBranch>(PB))
+ PredicateToUsers[PBranch->Condition].insert(I);
+ else if (auto *PAssume = dyn_cast<PredicateAssume>(PB))
+ PredicateToUsers[PAssume->Condition].insert(I);
+}
+
+// Touch all the predicates that depend on this instruction.
+void NewGVN::markPredicateUsersTouched(Instruction *I) {
+ const auto Result = PredicateToUsers.find(I);
+ if (Result != PredicateToUsers.end()) {
+ for (auto *User : Result->second)
+ TouchedInstructions.set(InstrToDFSNum(User));
+ PredicateToUsers.erase(Result);
}
}
+// Mark users affected by a memory leader change.
+void NewGVN::markMemoryLeaderChangeTouched(CongruenceClass *CC) {
+ for (auto M : CC->memory())
+ markMemoryDefTouched(M);
+}
+
// Touch the instructions that need to be updated after a congruence class has a
// leader change, and mark changed values.
-void NewGVN::markLeaderChangeTouched(CongruenceClass *CC) {
- for (auto M : CC->Members) {
+void NewGVN::markValueLeaderChangeTouched(CongruenceClass *CC) {
+ for (auto M : *CC) {
if (auto *I = dyn_cast<Instruction>(M))
- TouchedInstructions.set(InstrDFS[I]);
+ TouchedInstructions.set(InstrToDFSNum(I));
LeaderChanges.insert(M);
}
}
+// Given a range of things that have instruction DFS numbers, this will return
+// the member of the range with the smallest DFS number.
+template <class T, class Range>
+T *NewGVN::getMinDFSOfRange(const Range &R) const {
+ std::pair<T *, unsigned> MinDFS = {nullptr, ~0U};
+ for (const auto X : R) {
+ auto DFSNum = InstrToDFSNum(X);
+ if (DFSNum < MinDFS.second)
+ MinDFS = {X, DFSNum};
+ }
+ return MinDFS.first;
+}
+
+// This function returns the MemoryAccess that should be the next leader of
+// congruence class CC, under the assumption that the current leader is going to
+// disappear.
+const MemoryAccess *NewGVN::getNextMemoryLeader(CongruenceClass *CC) const {
+ // TODO: If this ends up too slow, we can maintain a next memory leader like we
+ // do for regular leaders.
+ // Make sure there will be a leader to find
+ assert((CC->getStoreCount() > 0 || !CC->memory_empty()) &&
+ "Can't get next leader if there is none");
+ if (CC->getStoreCount() > 0) {
+ if (auto *NL = dyn_cast_or_null<StoreInst>(CC->getNextLeader().first))
+ return MSSA->getMemoryAccess(NL);
+ // Find the store with the minimum DFS number.
+ auto *V = getMinDFSOfRange<Value>(make_filter_range(
+ *CC, [&](const Value *V) { return isa<StoreInst>(V); }));
+ return MSSA->getMemoryAccess(cast<StoreInst>(V));
+ }
+ assert(CC->getStoreCount() == 0);
+
+ // Given our assertion, hitting this part must mean
+ // !OldClass->memory_empty()
+ if (CC->memory_size() == 1)
+ return *CC->memory_begin();
+ return getMinDFSOfRange<const MemoryPhi>(CC->memory());
+}
+
+// This function returns the next value leader of a congruence class, under the
+// assumption that the current leader is going away. This should end up being
+// the next most dominating member.
+Value *NewGVN::getNextValueLeader(CongruenceClass *CC) const {
+ // We don't need to sort members if there is only 1, and we don't care about
+ // sorting the TOP class because everything either gets out of it or is
+ // unreachable.
+
+ if (CC->size() == 1 || CC == TOPClass) {
+ return *(CC->begin());
+ } else if (CC->getNextLeader().first) {
+ ++NumGVNAvoidedSortedLeaderChanges;
+ return CC->getNextLeader().first;
+ } else {
+ ++NumGVNSortedLeaderChanges;
+ // NOTE: If this ends up too slow, we can maintain a dual structure for
+ // member testing/insertion, or keep things mostly sorted, and sort only
+ // here, or use SparseBitVector or ....
+ return getMinDFSOfRange<Value>(*CC);
+ }
+}
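
A standalone sketch of the leader-replacement policy above (toy types, not the pass's data structures): prefer a cached next-leader candidate when one exists, otherwise scan for the member with the smallest DFS number.

    #include <cassert>
    #include <limits>
    #include <vector>

    struct Member {
      int Id;
      unsigned DFSNum; // Lower DFS number == earlier in RPO == more dominating.
    };

    struct ToyClass {
      std::vector<Member> Members;
      // Best candidate for "who leads if the current leader goes away";
      // Id < 0 means no candidate has been cached.
      Member NextLeader{-1, std::numeric_limits<unsigned>::max()};
    };

    // Pick the next leader: the cached candidate if we have one (cheap),
    // otherwise the member with the minimum DFS number (a linear scan).
    Member getNextValueLeader(const ToyClass &CC) {
      if (CC.Members.size() == 1)
        return CC.Members.front();
      if (CC.NextLeader.Id >= 0)
        return CC.NextLeader;
      Member Min{-1, std::numeric_limits<unsigned>::max()};
      for (const Member &M : CC.Members)
        if (M.DFSNum < Min.DFSNum)
          Min = M;
      return Min;
    }

    int main() {
      ToyClass CC;
      CC.Members = {{10, 7}, {11, 3}, {12, 9}};
      assert(getNextValueLeader(CC).Id == 11); // Smallest DFS number wins.
      CC.NextLeader = {12, 9};
      assert(getNextValueLeader(CC).Id == 12); // Cached candidate wins.
      return 0;
    }
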
+
+// Move a MemoryAccess, currently in OldClass, to NewClass, including updates to
+// the memory members, etc for the move.
+//
+// The invariants of this function are:
+//
+// I must be moving to NewClass from OldClass. The StoreCount of OldClass and
+// NewClass is expected to have been updated for I already if it is a store.
+// The OldClass memory leader has not been updated yet if I was the leader.
+void NewGVN::moveMemoryToNewCongruenceClass(Instruction *I,
+ MemoryAccess *InstMA,
+ CongruenceClass *OldClass,
+ CongruenceClass *NewClass) {
+ // If the leader is I, and we had a representative MemoryAccess, it should
+ // be the MemoryAccess of OldClass.
+ assert((!InstMA || !OldClass->getMemoryLeader() ||
+ OldClass->getLeader() != I ||
+ OldClass->getMemoryLeader() == InstMA) &&
+ "Representative MemoryAccess mismatch");
+ // First, see what happens to the new class
+ if (!NewClass->getMemoryLeader()) {
+ // Should be a new class, or a store becoming a leader of a new class.
+ assert(NewClass->size() == 1 ||
+ (isa<StoreInst>(I) && NewClass->getStoreCount() == 1));
+ NewClass->setMemoryLeader(InstMA);
+ // Mark it touched if we didn't just create a singleton
+ DEBUG(dbgs() << "Memory class leader change for class " << NewClass->getID()
+ << " due to new memory instruction becoming leader\n");
+ markMemoryLeaderChangeTouched(NewClass);
+ }
+ setMemoryClass(InstMA, NewClass);
+ // Now, fixup the old class if necessary
+ if (OldClass->getMemoryLeader() == InstMA) {
+ if (OldClass->getStoreCount() != 0 || !OldClass->memory_empty()) {
+ OldClass->setMemoryLeader(getNextMemoryLeader(OldClass));
+ DEBUG(dbgs() << "Memory class leader change for class "
+ << OldClass->getID() << " to "
+ << *OldClass->getMemoryLeader()
+ << " due to removal of old leader " << *InstMA << "\n");
+ markMemoryLeaderChangeTouched(OldClass);
+ } else
+ OldClass->setMemoryLeader(nullptr);
+ }
+}
+
// Move a value, currently in OldClass, to be part of NewClass
-// Update OldClass for the move (including changing leaders, etc)
-void NewGVN::moveValueToNewCongruenceClass(Instruction *I,
+// Update OldClass and NewClass for the move (including changing leaders, etc).
+void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E,
CongruenceClass *OldClass,
CongruenceClass *NewClass) {
- DEBUG(dbgs() << "New congruence class for " << I << " is " << NewClass->ID
- << "\n");
-
- if (I == OldClass->NextLeader.first)
- OldClass->NextLeader = {nullptr, ~0U};
+ if (I == OldClass->getNextLeader().first)
+ OldClass->resetNextLeader();
// It's possible, though unlikely, for us to discover equivalences such
// that the current leader does not dominate the old one.
// This statistic tracks how often this happens.
// We assert on phi nodes when this happens, currently, for debugging, because
// we want to make sure we name phi node cycles properly.
- if (isa<Instruction>(NewClass->RepLeader) && NewClass->RepLeader &&
- I != NewClass->RepLeader &&
- DT->properlyDominates(
- I->getParent(),
- cast<Instruction>(NewClass->RepLeader)->getParent())) {
- ++NumGVNNotMostDominatingLeader;
- assert(!isa<PHINode>(I) &&
- "New class for instruction should not be dominated by instruction");
- }
-
- if (NewClass->RepLeader != I) {
- auto DFSNum = InstrDFS.lookup(I);
- if (DFSNum < NewClass->NextLeader.second)
- NewClass->NextLeader = {I, DFSNum};
+ if (isa<Instruction>(NewClass->getLeader()) && NewClass->getLeader() &&
+ I != NewClass->getLeader()) {
+ auto *IBB = I->getParent();
+ auto *NCBB = cast<Instruction>(NewClass->getLeader())->getParent();
+ bool Dominated =
+ IBB == NCBB && InstrToDFSNum(I) < InstrToDFSNum(NewClass->getLeader());
+ Dominated = Dominated || DT->properlyDominates(IBB, NCBB);
+ if (Dominated) {
+ ++NumGVNNotMostDominatingLeader;
+ assert(
+ !isa<PHINode>(I) &&
+ "New class for instruction should not be dominated by instruction");
+ }
}
- OldClass->Members.erase(I);
- NewClass->Members.insert(I);
- if (isa<StoreInst>(I)) {
- --OldClass->StoreCount;
- assert(OldClass->StoreCount >= 0);
- ++NewClass->StoreCount;
- assert(NewClass->StoreCount > 0);
+ if (NewClass->getLeader() != I)
+ NewClass->addPossibleNextLeader({I, InstrToDFSNum(I)});
+
+ OldClass->erase(I);
+ NewClass->insert(I);
+ // Handle our special casing of stores.
+ if (auto *SI = dyn_cast<StoreInst>(I)) {
+ OldClass->decStoreCount();
+ // Okay, so when do we want to make a store a leader of a class?
+ // If we have a store defined by an earlier load, we want the earlier load
+ // to lead the class.
+ // If we have a store defined by something else, we want the store to lead
+ // the class so everything else gets the "something else" as a value.
+ // If we have a store as the single member of the class, we want the store
+ // as the leader
+ if (NewClass->getStoreCount() == 0 && !NewClass->getStoredValue()) {
+ // If it's a store expression we are using, it means we are not equivalent
+ // to something earlier.
+ if (isa<StoreExpression>(E)) {
+ assert(lookupOperandLeader(SI->getValueOperand()) !=
+ NewClass->getLeader());
+ NewClass->setStoredValue(lookupOperandLeader(SI->getValueOperand()));
+ markValueLeaderChangeTouched(NewClass);
+ // Shift the new class leader to be the store
+ DEBUG(dbgs() << "Changing leader of congruence class "
+ << NewClass->getID() << " from " << *NewClass->getLeader()
+ << " to " << *SI << " because store joined class\n");
+ // If we changed the leader, we have to mark it changed because we don't
+ // know what it will do to symbolic evaluation.
+ NewClass->setLeader(SI);
+ }
+ // We rely on the code below handling the MemoryAccess change.
+ }
+ NewClass->incStoreCount();
}
-
+ // True if there are no memory instructions left in a class that had memory
+ // instructions before.
+
+ // If it's not a memory use, set the MemoryAccess equivalence
+ auto *InstMA = dyn_cast_or_null<MemoryDef>(MSSA->getMemoryAccess(I));
+ bool InstWasMemoryLeader = InstMA && OldClass->getMemoryLeader() == InstMA;
+ if (InstMA)
+ moveMemoryToNewCongruenceClass(I, InstMA, OldClass, NewClass);
ValueToClass[I] = NewClass;
// See if we destroyed the class or need to swap leaders.
- if (OldClass->Members.empty() && OldClass != InitialClass) {
- if (OldClass->DefiningExpr) {
- OldClass->Dead = true;
- DEBUG(dbgs() << "Erasing expression " << OldClass->DefiningExpr
+ if (OldClass->empty() && OldClass != TOPClass) {
+ if (OldClass->getDefiningExpr()) {
+ DEBUG(dbgs() << "Erasing expression " << OldClass->getDefiningExpr()
<< " from table\n");
- ExpressionToClass.erase(OldClass->DefiningExpr);
+ ExpressionToClass.erase(OldClass->getDefiningExpr());
}
- } else if (OldClass->RepLeader == I) {
+ } else if (OldClass->getLeader() == I) {
// When the leader changes, the value numbering of
// everything may change due to symbolization changes, so we need to
// reprocess.
- DEBUG(dbgs() << "Leader change!\n");
+ DEBUG(dbgs() << "Value class leader change for class " << OldClass->getID()
+ << "\n");
++NumGVNLeaderChanges;
- // We don't need to sort members if there is only 1, and we don't care about
- // sorting the initial class because everything either gets out of it or is
- // unreachable.
- if (OldClass->Members.size() == 1 || OldClass == InitialClass) {
- OldClass->RepLeader = *(OldClass->Members.begin());
- } else if (OldClass->NextLeader.first) {
- ++NumGVNAvoidedSortedLeaderChanges;
- OldClass->RepLeader = OldClass->NextLeader.first;
- OldClass->NextLeader = {nullptr, ~0U};
- } else {
- ++NumGVNSortedLeaderChanges;
- // TODO: If this ends up to slow, we can maintain a dual structure for
- // member testing/insertion, or keep things mostly sorted, and sort only
- // here, or ....
- std::pair<Value *, unsigned> MinDFS = {nullptr, ~0U};
- for (const auto X : OldClass->Members) {
- auto DFSNum = InstrDFS.lookup(X);
- if (DFSNum < MinDFS.second)
- MinDFS = {X, DFSNum};
- }
- OldClass->RepLeader = MinDFS.first;
+ // Destroy the stored value if there are no more stores to represent it.
+ // Note that this is basically clean up for the expression removal that
+ // happens below. If we remove stores from a class, we may leave it as a
+ // class of equivalent memory phis.
+ if (OldClass->getStoreCount() == 0) {
+ if (OldClass->getStoredValue())
+ OldClass->setStoredValue(nullptr);
+ }
+ // If we destroy the old access leader and it's a store, we have to
+ // effectively destroy the congruence class. When it comes to scalars,
+ // anything with the same value is as good as any other. That means that
+ // one leader is as good as another, and as long as you have some leader for
+ // the value, you are good.. When it comes to *memory states*, only one
+ // particular thing really represents the definition of a given memory
+ // state. Once it goes away, we need to re-evaluate which pieces of memory
+ // are really still equivalent. The best way to do this is to re-value
+ // number things. The only way to really make that happen is to destroy the
+ // rest of the class. In order to effectively destroy the class, we reset
+ // ExpressionToClass for each by using the ValueToExpression mapping. The
+ // members later get marked as touched due to the leader change. We will
+ // create new congruence classes, and the pieces that are still equivalent
+ // will end back together in a new class. If this becomes too expensive, it
+ // is possible to use a versioning scheme for the congruence classes to
+ // avoid the expressions finding this old class. Note that the situation is
+ // different for memory phis, because they are evaluated anew each time, and
+ // they become equal not by hashing, but by seeing if all operands are the
+ // same (or only one is reachable).
+ if (OldClass->getStoreCount() > 0 && InstWasMemoryLeader) {
+ DEBUG(dbgs() << "Kicking everything out of class " << OldClass->getID()
+ << " because MemoryAccess leader changed");
+ for (auto Member : *OldClass)
+ ExpressionToClass.erase(ValueToExpression.lookup(Member));
}
- markLeaderChangeTouched(OldClass);
+ OldClass->setLeader(getNextValueLeader(OldClass));
+ OldClass->resetNextLeader();
+ markValueLeaderChangeTouched(OldClass);
}
}
@@ -1150,12 +2039,12 @@ void NewGVN::moveValueToNewCongruenceClass(Instruction *I,
void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) {
ValueToExpression[I] = E;
// This is guaranteed to return something, since it will at least find
- // INITIAL.
+ // TOP.
CongruenceClass *IClass = ValueToClass[I];
assert(IClass && "Should have found a IClass");
// Dead classes should have been eliminated from the mapping.
- assert(!IClass->Dead && "Found a dead class");
+ assert(!IClass->isDead() && "Found a dead class");
CongruenceClass *EClass;
if (const auto *VE = dyn_cast<VariableExpression>(E)) {
@@ -1171,79 +2060,52 @@ void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) {
// Constants and variables should always be made the leader.
if (const auto *CE = dyn_cast<ConstantExpression>(E)) {
- NewClass->RepLeader = CE->getConstantValue();
+ NewClass->setLeader(CE->getConstantValue());
} else if (const auto *SE = dyn_cast<StoreExpression>(E)) {
StoreInst *SI = SE->getStoreInst();
- NewClass->RepLeader =
- lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent());
+ NewClass->setLeader(SI);
+ NewClass->setStoredValue(lookupOperandLeader(SI->getValueOperand()));
+ // The RepMemoryAccess field will be filled in properly by the
+ // moveValueToNewCongruenceClass call.
} else {
- NewClass->RepLeader = I;
+ NewClass->setLeader(I);
}
assert(!isa<VariableExpression>(E) &&
"VariableExpression should have been handled already");
EClass = NewClass;
DEBUG(dbgs() << "Created new congruence class for " << *I
- << " using expression " << *E << " at " << NewClass->ID
- << " and leader " << *(NewClass->RepLeader) << "\n");
- DEBUG(dbgs() << "Hash value was " << E->getHashValue() << "\n");
+ << " using expression " << *E << " at " << NewClass->getID()
+ << " and leader " << *(NewClass->getLeader()));
+ if (NewClass->getStoredValue())
+ DEBUG(dbgs() << " and stored value " << *(NewClass->getStoredValue()));
+ DEBUG(dbgs() << "\n");
} else {
EClass = lookupResult.first->second;
if (isa<ConstantExpression>(E))
- assert(isa<Constant>(EClass->RepLeader) &&
+ assert((isa<Constant>(EClass->getLeader()) ||
+ (EClass->getStoredValue() &&
+ isa<Constant>(EClass->getStoredValue()))) &&
"Any class with a constant expression should have a "
"constant leader");
assert(EClass && "Somehow don't have an eclass");
- assert(!EClass->Dead && "We accidentally looked up a dead class");
+ assert(!EClass->isDead() && "We accidentally looked up a dead class");
}
}
bool ClassChanged = IClass != EClass;
bool LeaderChanged = LeaderChanges.erase(I);
if (ClassChanged || LeaderChanged) {
- DEBUG(dbgs() << "Found class " << EClass->ID << " for expression " << E
+ DEBUG(dbgs() << "New class " << EClass->getID() << " for expression " << *E
<< "\n");
-
if (ClassChanged)
- moveValueToNewCongruenceClass(I, IClass, EClass);
+ moveValueToNewCongruenceClass(I, E, IClass, EClass);
markUsersTouched(I);
- if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) {
- // If this is a MemoryDef, we need to update the equivalence table. If
- // we determined the expression is congruent to a different memory
- // state, use that different memory state. If we determined it didn't,
- // we update that as well. Right now, we only support store
- // expressions.
- if (!isa<MemoryUse>(MA) && isa<StoreExpression>(E) &&
- EClass->Members.size() != 1) {
- auto *DefAccess = cast<StoreExpression>(E)->getDefiningAccess();
- setMemoryAccessEquivTo(MA, DefAccess != MA ? DefAccess : nullptr);
- } else {
- setMemoryAccessEquivTo(MA, nullptr);
- }
+ if (MemoryAccess *MA = MSSA->getMemoryAccess(I))
markMemoryUsersTouched(MA);
- }
- } else if (auto *SI = dyn_cast<StoreInst>(I)) {
- // There is, sadly, one complicating thing for stores. Stores do not
- // produce values, only consume them. However, in order to make loads and
- // stores value number the same, we ignore the value operand of the store.
- // But the value operand will still be the leader of our class, and thus, it
- // may change. Because the store is a use, the store will get reprocessed,
- // but nothing will change about it, and so nothing above will catch it
- // (since the class will not change). In order to make sure everything ends
- // up okay, we need to recheck the leader of the class. Since stores of
- // different values value number differently due to different memorydefs, we
- // are guaranteed the leader is always the same between stores in the same
- // class.
- DEBUG(dbgs() << "Checking store leader\n");
- auto ProperLeader =
- lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent());
- if (EClass->RepLeader != ProperLeader) {
- DEBUG(dbgs() << "Store leader changed, fixing\n");
- EClass->RepLeader = ProperLeader;
- markLeaderChangeTouched(EClass);
- markMemoryUsersTouched(MSSA->getMemoryAccess(SI));
- }
+ if (auto *CI = dyn_cast<CmpInst>(I))
+ markPredicateUsersTouched(CI);
}
}
@@ -1267,11 +2129,11 @@ void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) {
// they are the only thing that depend on new edges. Anything using their
// values will get propagated to if necessary.
if (MemoryAccess *MemPhi = MSSA->getMemoryAccess(To))
- TouchedInstructions.set(InstrDFS[MemPhi]);
+ TouchedInstructions.set(InstrToDFSNum(MemPhi));
auto BI = To->begin();
while (isa<PHINode>(BI)) {
- TouchedInstructions.set(InstrDFS[&*BI]);
+ TouchedInstructions.set(InstrToDFSNum(&*BI));
++BI;
}
}
@@ -1280,8 +2142,8 @@ void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) {
// Given a predicate condition (from a switch, cmp, or whatever) and a block,
// see if we know some constant value for it already.
-Value *NewGVN::findConditionEquivalence(Value *Cond, BasicBlock *B) const {
- auto Result = lookupOperandLeader(Cond, nullptr, B);
+Value *NewGVN::findConditionEquivalence(Value *Cond) const {
+ auto Result = lookupOperandLeader(Cond);
if (isa<Constant>(Result))
return Result;
return nullptr;
@@ -1293,10 +2155,10 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {
BranchInst *BR;
if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) {
Value *Cond = BR->getCondition();
- Value *CondEvaluated = findConditionEquivalence(Cond, B);
+ Value *CondEvaluated = findConditionEquivalence(Cond);
if (!CondEvaluated) {
if (auto *I = dyn_cast<Instruction>(Cond)) {
- const Expression *E = createExpression(I, B);
+ const Expression *E = createExpression(I);
if (const auto *CE = dyn_cast<ConstantExpression>(E)) {
CondEvaluated = CE->getConstantValue();
}
@@ -1329,13 +2191,13 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {
SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges;
Value *SwitchCond = SI->getCondition();
- Value *CondEvaluated = findConditionEquivalence(SwitchCond, B);
+ Value *CondEvaluated = findConditionEquivalence(SwitchCond);
// See if we were able to turn this switch statement into a constant.
if (CondEvaluated && isa<ConstantInt>(CondEvaluated)) {
auto *CondVal = cast<ConstantInt>(CondEvaluated);
// We should be able to get case value for this.
- auto CaseVal = SI->findCaseValue(CondVal);
- if (CaseVal.getCaseSuccessor() == SI->getDefaultDest()) {
+ auto Case = *SI->findCaseValue(CondVal);
+ if (Case.getCaseSuccessor() == SI->getDefaultDest()) {
// We proved the value is outside of the range of the case.
// We can't do anything other than mark the default dest as reachable,
// and go home.
@@ -1343,7 +2205,7 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {
return;
}
// Now get where it goes and mark it reachable.
- BasicBlock *TargetBlock = CaseVal.getCaseSuccessor();
+ BasicBlock *TargetBlock = Case.getCaseSuccessor();
updateReachableEdge(B, TargetBlock);
} else {
for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) {
@@ -1361,45 +2223,66 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) {
}
// This also may be a memory defining terminator, in which case, set it
- // equivalent to nothing.
- if (MemoryAccess *MA = MSSA->getMemoryAccess(TI))
- setMemoryAccessEquivTo(MA, nullptr);
+ // equivalent only to itself.
+ //
+ auto *MA = MSSA->getMemoryAccess(TI);
+ if (MA && !isa<MemoryUse>(MA)) {
+ auto *CC = ensureLeaderOfMemoryClass(MA);
+ if (setMemoryClass(MA, CC))
+ markMemoryUsersTouched(MA);
+ }
}
}
-// The algorithm initially places the values of the routine in the INITIAL
-// congruence
-// class. The leader of INITIAL is the undetermined value `TOP`.
-// When the algorithm has finished, values still in INITIAL are unreachable.
+// The algorithm initially places the values of the routine in the TOP
+// congruence class. The leader of TOP is the undetermined value `undef`.
+// When the algorithm has finished, values still in TOP are unreachable.
void NewGVN::initializeCongruenceClasses(Function &F) {
- // FIXME now i can't remember why this is 2
- NextCongruenceNum = 2;
- // Initialize all other instructions to be in INITIAL class.
- CongruenceClass::MemberSet InitialValues;
- InitialClass = createCongruenceClass(nullptr, nullptr);
+ NextCongruenceNum = 0;
+
+ // Note that even though we use the live on entry def as a representative
+ // MemoryAccess, it is *not* the same as the actual live on entry def. We
+ // have no real equivalent to undef for MemoryAccesses, and so we really
+ // should be checking whether the MemoryAccess is top if we want to know if it
+ // is equivalent to everything. Otherwise, what this really signifies is that
+ // the access reaches all the way back to the beginning of the function.
+
+ // Initialize all other instructions to be in TOP class.
+ TOPClass = createCongruenceClass(nullptr, nullptr);
+ TOPClass->setMemoryLeader(MSSA->getLiveOnEntryDef());
+ // The live on entry def gets put into its own class.
+ MemoryAccessToClass[MSSA->getLiveOnEntryDef()] =
+ createMemoryClass(MSSA->getLiveOnEntryDef());
+
for (auto &B : F) {
- if (auto *MP = MSSA->getMemoryAccess(&B))
- MemoryAccessEquiv.insert({MP, MSSA->getLiveOnEntryDef()});
+ // All MemoryAccesses are equivalent to live on entry to start. They must
+ // be initialized to something so that initial changes are noticed. For
+ // the maximal answer, we initialize them all to be the same as
+ // liveOnEntry.
+ auto *MemoryBlockDefs = MSSA->getBlockDefs(&B);
+ if (MemoryBlockDefs)
+ for (const auto &Def : *MemoryBlockDefs) {
+ MemoryAccessToClass[&Def] = TOPClass;
+ auto *MD = dyn_cast<MemoryDef>(&Def);
+ // Insert the memory phis into the member list.
+ if (!MD) {
+ const MemoryPhi *MP = cast<MemoryPhi>(&Def);
+ TOPClass->memory_insert(MP);
+ MemoryPhiState.insert({MP, MPS_TOP});
+ }
- for (auto &I : B) {
- InitialValues.insert(&I);
- ValueToClass[&I] = InitialClass;
- // All memory accesses are equivalent to live on entry to start. They must
- // be initialized to something so that initial changes are noticed. For
- // the maximal answer, we initialize them all to be the same as
- // liveOnEntry. Note that to save time, we only initialize the
- // MemoryDef's for stores and all MemoryPhis to be equal. Right now, no
- // other expression can generate a memory equivalence. If we start
- // handling memcpy/etc, we can expand this.
- if (isa<StoreInst>(&I)) {
- MemoryAccessEquiv.insert(
- {MSSA->getMemoryAccess(&I), MSSA->getLiveOnEntryDef()});
- ++InitialClass->StoreCount;
- assert(InitialClass->StoreCount > 0);
+ if (MD && isa<StoreInst>(MD->getMemoryInst()))
+ TOPClass->incStoreCount();
}
+ for (auto &I : B) {
+ // Don't insert void terminators into the class. We don't value number
+ // them, and they just end up sitting in TOP.
+ if (isa<TerminatorInst>(I) && I.getType()->isVoidTy())
+ continue;
+ TOPClass->insert(&I);
+ ValueToClass[&I] = TOPClass;
}
}
- InitialClass->Members.swap(InitialValues);
// Initialize arguments to be in their own unique congruence classes
for (auto &FA : F.args())
@@ -1408,8 +2291,8 @@ void NewGVN::initializeCongruenceClasses(Function &F) {
void NewGVN::cleanupTables() {
for (unsigned i = 0, e = CongruenceClasses.size(); i != e; ++i) {
- DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->ID << " has "
- << CongruenceClasses[i]->Members.size() << " members\n");
+ DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->getID()
+ << " has " << CongruenceClasses[i]->size() << " members\n");
// Make sure we delete the congruence class (probably worth switching to
// a unique_ptr at some point.
delete CongruenceClasses[i];
@@ -1427,15 +2310,14 @@ void NewGVN::cleanupTables() {
#ifndef NDEBUG
ProcessedCount.clear();
#endif
- DFSDomMap.clear();
InstrDFS.clear();
InstructionsToErase.clear();
-
DFSToInstr.clear();
BlockInstRange.clear();
TouchedInstructions.clear();
- DominatedInstRange.clear();
- MemoryAccessEquiv.clear();
+ MemoryAccessToClass.clear();
+ PredicateToUsers.clear();
+ MemoryToUsers.clear();
}
std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B,
@@ -1447,6 +2329,16 @@ std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B,
}
for (auto &I : *B) {
+ // There's no need to call isInstructionTriviallyDead more than once on
+ // an instruction. Therefore, once we know that an instruction is dead
+ // we change its DFS number so that it doesn't get value numbered.
+ if (isInstructionTriviallyDead(&I, TLI)) {
+ InstrDFS[&I] = 0;
+ DEBUG(dbgs() << "Skipping trivially dead instruction " << I << "\n");
+ markInstructionForDeletion(&I);
+ continue;
+ }
+
InstrDFS[&I] = End++;
DFSToInstr.emplace_back(&I);
}
@@ -1462,7 +2354,7 @@ void NewGVN::updateProcessedCount(Value *V) {
if (ProcessedCount.count(V) == 0) {
ProcessedCount.insert({V, 1});
} else {
- ProcessedCount[V] += 1;
+ ++ProcessedCount[V];
assert(ProcessedCount[V] < 100 &&
"Seem to have processed the same Value a lot");
}
@@ -1472,26 +2364,33 @@ void NewGVN::updateProcessedCount(Value *V) {
void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {
// If all the arguments are the same, the MemoryPhi has the same value as the
// argument.
- // Filter out unreachable blocks from our operands.
+ // Filter out unreachable blocks and self phis from our operands.
+ const BasicBlock *PHIBlock = MP->getBlock();
auto Filtered = make_filter_range(MP->operands(), [&](const Use &U) {
- return ReachableBlocks.count(MP->getIncomingBlock(U));
+ return lookupMemoryLeader(cast<MemoryAccess>(U)) != MP &&
+ !isMemoryAccessTop(cast<MemoryAccess>(U)) &&
+ ReachableEdges.count({MP->getIncomingBlock(U), PHIBlock});
});
-
- assert(Filtered.begin() != Filtered.end() &&
- "We should not be processing a MemoryPhi in a completely "
- "unreachable block");
+ // If all that is left is nothing, our MemoryPhi is undef. We keep it as
+ // TOPClass. Note: The only case this should happen is if we have at
+ // least one self-argument.
+ if (Filtered.begin() == Filtered.end()) {
+ if (setMemoryClass(MP, TOPClass))
+ markMemoryUsersTouched(MP);
+ return;
+ }
// Transform the remaining operands into operand leaders.
// FIXME: mapped_iterator should have a range version.
auto LookupFunc = [&](const Use &U) {
- return lookupMemoryAccessEquiv(cast<MemoryAccess>(U));
+ return lookupMemoryLeader(cast<MemoryAccess>(U));
};
auto MappedBegin = map_iterator(Filtered.begin(), LookupFunc);
auto MappedEnd = map_iterator(Filtered.end(), LookupFunc);
// and now check if all the elements are equal.
 // Sadly, we can't use std::equal since these are random access iterators.
- MemoryAccess *AllSameValue = *MappedBegin;
+ const auto *AllSameValue = *MappedBegin;
++MappedBegin;
bool AllEqual = std::all_of(
MappedBegin, MappedEnd,
@@ -1501,8 +2400,18 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {
DEBUG(dbgs() << "Memory Phi value numbered to " << *AllSameValue << "\n");
else
DEBUG(dbgs() << "Memory Phi value numbered to itself\n");
-
- if (setMemoryAccessEquivTo(MP, AllEqual ? AllSameValue : nullptr))
+ // If it's equal to something, it's in that class. Otherwise, it has to be in
+ // a class where it is the leader (other things may be equivalent to it, but
+ // it needs to start off in its own class, which means it must have been the
+ // leader, and it can't have stopped being the leader because it was never
+ // removed).
+ CongruenceClass *CC =
+ AllEqual ? getMemoryClass(AllSameValue) : ensureLeaderOfMemoryClass(MP);
+ auto OldState = MemoryPhiState.lookup(MP);
+ assert(OldState != MPS_Invalid && "Invalid memory phi state");
+ auto NewState = AllEqual ? MPS_Equivalent : MPS_Unique;
+ MemoryPhiState[MP] = NewState;
+ if (setMemoryClass(MP, CC) || OldState != NewState)
markMemoryUsersTouched(MP);
}
@@ -1510,21 +2419,25 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {
// congruence finding, and updating mappings.
void NewGVN::valueNumberInstruction(Instruction *I) {
DEBUG(dbgs() << "Processing instruction " << *I << "\n");
- if (isInstructionTriviallyDead(I, TLI)) {
- DEBUG(dbgs() << "Skipping unused instruction\n");
- markInstructionForDeletion(I);
- return;
- }
if (!I->isTerminator()) {
- const auto *Symbolized = performSymbolicEvaluation(I, I->getParent());
+ const Expression *Symbolized = nullptr;
+ if (DebugCounter::shouldExecute(VNCounter)) {
+ Symbolized = performSymbolicEvaluation(I);
+ } else {
+ // Mark the instruction as unused so we don't value number it again.
+ InstrDFS[I] = 0;
+ }
// If we couldn't come up with a symbolic expression, use the unknown
// expression
- if (Symbolized == nullptr)
+ if (Symbolized == nullptr) {
Symbolized = createUnknownExpression(I);
+ }
+
performCongruenceFinding(I, Symbolized);
} else {
// Handle terminators that return values. All of them produce values we
- // don't currently understand.
+ // don't currently understand. We don't place non-value producing
+ // terminators in a class.
if (!I->getType()->isVoidTy()) {
auto *Symbolized = createUnknownExpression(I);
performCongruenceFinding(I, Symbolized);
@@ -1539,72 +2452,102 @@ bool NewGVN::singleReachablePHIPath(const MemoryAccess *First,
const MemoryAccess *Second) const {
if (First == Second)
return true;
-
- if (auto *FirstDef = dyn_cast<MemoryUseOrDef>(First)) {
- auto *DefAccess = FirstDef->getDefiningAccess();
- return singleReachablePHIPath(DefAccess, Second);
- } else {
- auto *MP = cast<MemoryPhi>(First);
- auto ReachableOperandPred = [&](const Use &U) {
- return ReachableBlocks.count(MP->getIncomingBlock(U));
- };
- auto FilteredPhiArgs =
- make_filter_range(MP->operands(), ReachableOperandPred);
- SmallVector<const Value *, 32> OperandList;
- std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(),
- std::back_inserter(OperandList));
- bool Okay = OperandList.size() == 1;
- if (!Okay)
- Okay = std::equal(OperandList.begin(), OperandList.end(),
- OperandList.begin());
- if (Okay)
- return singleReachablePHIPath(cast<MemoryAccess>(OperandList[0]), Second);
+ if (MSSA->isLiveOnEntryDef(First))
return false;
+
+ const auto *EndDef = First;
+ for (auto *ChainDef : optimized_def_chain(First)) {
+ if (ChainDef == Second)
+ return true;
+ if (MSSA->isLiveOnEntryDef(ChainDef))
+ return false;
+ EndDef = ChainDef;
}
+ auto *MP = cast<MemoryPhi>(EndDef);
+ auto ReachableOperandPred = [&](const Use &U) {
+ return ReachableEdges.count({MP->getIncomingBlock(U), MP->getBlock()});
+ };
+ auto FilteredPhiArgs =
+ make_filter_range(MP->operands(), ReachableOperandPred);
+ SmallVector<const Value *, 32> OperandList;
+ std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(),
+ std::back_inserter(OperandList));
+ bool Okay = OperandList.size() == 1;
+ if (!Okay)
+ Okay =
+ std::equal(OperandList.begin(), OperandList.end(), OperandList.begin());
+ if (Okay)
+ return singleReachablePHIPath(cast<MemoryAccess>(OperandList[0]), Second);
+ return false;
}
 // Verify that the memory equivalence table makes sense relative to the
// congruence classes. Note that this checking is not perfect, and is currently
-// subject to very rare false negatives. It is only useful for testing/debugging.
+// subject to very rare false negatives. It is only useful for
+// testing/debugging.
void NewGVN::verifyMemoryCongruency() const {
- // Anything equivalent in the memory access table should be in the same
+#ifndef NDEBUG
+ // Verify that the memory table equivalence and memory member set match
+ for (const auto *CC : CongruenceClasses) {
+ if (CC == TOPClass || CC->isDead())
+ continue;
+ if (CC->getStoreCount() != 0) {
+ assert((CC->getStoredValue() || !isa<StoreInst>(CC->getLeader())) &&
+ "Any class with a store as a "
+ "leader should have a "
+ "representative stored value\n");
+ assert(CC->getMemoryLeader() &&
+ "Any congruence class with a store should "
+ "have a representative access\n");
+ }
+
+ if (CC->getMemoryLeader())
+ assert(MemoryAccessToClass.lookup(CC->getMemoryLeader()) == CC &&
+ "Representative MemoryAccess does not appear to be reverse "
+ "mapped properly");
+ for (auto M : CC->memory())
+ assert(MemoryAccessToClass.lookup(M) == CC &&
+ "Memory member does not appear to be reverse mapped properly");
+ }
+
+ // Anything equivalent in the MemoryAccess table should be in the same
// congruence class.
// Filter out the unreachable and trivially dead entries, because they may
// never have been updated if the instructions were not processed.
auto ReachableAccessPred =
- [&](const std::pair<const MemoryAccess *, MemoryAccess *> Pair) {
+ [&](const std::pair<const MemoryAccess *, CongruenceClass *> Pair) {
bool Result = ReachableBlocks.count(Pair.first->getBlock());
if (!Result)
return false;
+ if (MSSA->isLiveOnEntryDef(Pair.first))
+ return true;
if (auto *MemDef = dyn_cast<MemoryDef>(Pair.first))
return !isInstructionTriviallyDead(MemDef->getMemoryInst());
+ if (MemoryToDFSNum(Pair.first) == 0)
+ return false;
return true;
};
- auto Filtered = make_filter_range(MemoryAccessEquiv, ReachableAccessPred);
+ auto Filtered = make_filter_range(MemoryAccessToClass, ReachableAccessPred);
for (auto KV : Filtered) {
- assert(KV.first != KV.second &&
- "We added a useless equivalence to the memory equivalence table");
- // Unreachable instructions may not have changed because we never process
- // them.
- if (!ReachableBlocks.count(KV.first->getBlock()))
- continue;
+ assert(KV.second != TOPClass &&
+ "Memory not unreachable but ended up in TOP");
if (auto *FirstMUD = dyn_cast<MemoryUseOrDef>(KV.first)) {
- auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second);
+ auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second->getMemoryLeader());
if (FirstMUD && SecondMUD)
assert((singleReachablePHIPath(FirstMUD, SecondMUD) ||
- ValueToClass.lookup(FirstMUD->getMemoryInst()) ==
- ValueToClass.lookup(SecondMUD->getMemoryInst())) &&
- "The instructions for these memory operations should have "
- "been in the same congruence class or reachable through"
- "a single argument phi");
+ ValueToClass.lookup(FirstMUD->getMemoryInst()) ==
+ ValueToClass.lookup(SecondMUD->getMemoryInst())) &&
+ "The instructions for these memory operations should have "
+ "been in the same congruence class or reachable through"
+ "a single argument phi");
} else if (auto *FirstMP = dyn_cast<MemoryPhi>(KV.first)) {
-
// We can only sanely verify that MemoryDefs in the operand list all have
// the same class.
auto ReachableOperandPred = [&](const Use &U) {
- return ReachableBlocks.count(FirstMP->getIncomingBlock(U)) &&
+ return ReachableEdges.count(
+ {FirstMP->getIncomingBlock(U), FirstMP->getBlock()}) &&
isa<MemoryDef>(U);
};
@@ -1622,19 +2565,127 @@ void NewGVN::verifyMemoryCongruency() const {
"All MemoryPhi arguments should be in the same class");
}
}
+#endif
+}
+
+// Verify that the sparse propagation we did actually found the maximal fixpoint.
+// We do this by storing the value to class mapping, touching all instructions,
+// and redoing the iteration to see if anything changed.
+void NewGVN::verifyIterationSettled(Function &F) {
+#ifndef NDEBUG
+ DEBUG(dbgs() << "Beginning iteration verification\n");
+ if (DebugCounter::isCounterSet(VNCounter))
+ DebugCounter::setCounterValue(VNCounter, StartingVNCounter);
+
+ // Note that we have to store the actual classes, as we may change existing
+ // classes during iteration. This is because our memory iteration propagation
+ // is not perfect, and so may waste a little work. But it should generate
+ // exactly the same congruence classes we have now, with different IDs.
+ std::map<const Value *, CongruenceClass> BeforeIteration;
+
+ for (auto &KV : ValueToClass) {
+ if (auto *I = dyn_cast<Instruction>(KV.first))
+ // Skip unused/dead instructions.
+ if (InstrToDFSNum(I) == 0)
+ continue;
+ BeforeIteration.insert({KV.first, *KV.second});
+ }
+
+ TouchedInstructions.set();
+ TouchedInstructions.reset(0);
+ iterateTouchedInstructions();
+ DenseSet<std::pair<const CongruenceClass *, const CongruenceClass *>>
+ EqualClasses;
+ for (const auto &KV : ValueToClass) {
+ if (auto *I = dyn_cast<Instruction>(KV.first))
+ // Skip unused/dead instructions.
+ if (InstrToDFSNum(I) == 0)
+ continue;
+ // We could sink these uses, but I think this adds a bit of clarity here as
+ // to what we are comparing.
+ auto *BeforeCC = &BeforeIteration.find(KV.first)->second;
+ auto *AfterCC = KV.second;
+ // Note that the classes can't change at this point, so we memoize the set
+ // that are equal.
+ if (!EqualClasses.count({BeforeCC, AfterCC})) {
+ assert(BeforeCC->isEquivalentTo(AfterCC) &&
+ "Value number changed after main loop completed!");
+ EqualClasses.insert({BeforeCC, AfterCC});
+ }
+ }
+#endif
+}
+
+// This is the main value numbering loop, it iterates over the initial touched
+// instruction set, propagating value numbers, marking things touched, etc,
+// until the set of touched instructions is completely empty.
+void NewGVN::iterateTouchedInstructions() {
+ unsigned int Iterations = 0;
+ // Figure out where TouchedInstructions starts.
+ int FirstInstr = TouchedInstructions.find_first();
+ // Nothing set, nothing to iterate, just return.
+ if (FirstInstr == -1)
+ return;
+ BasicBlock *LastBlock = getBlockForValue(InstrFromDFSNum(FirstInstr));
+ while (TouchedInstructions.any()) {
+ ++Iterations;
+ // Walk through all the instructions in all the blocks in RPO.
+ // TODO: As we hit a new block, we should push and pop equalities into a
+ // table lookupOperandLeader can use, to catch things PredicateInfo
+ // might miss, like edge-only equivalences.
+ for (int InstrNum = TouchedInstructions.find_first(); InstrNum != -1;
+ InstrNum = TouchedInstructions.find_next(InstrNum)) {
+
+ // This instruction was found to be dead. We don't bother looking
+ // at it again.
+ if (InstrNum == 0) {
+ TouchedInstructions.reset(InstrNum);
+ continue;
+ }
+
+ Value *V = InstrFromDFSNum(InstrNum);
+ BasicBlock *CurrBlock = getBlockForValue(V);
+
+ // If we hit a new block, do reachability processing.
+ if (CurrBlock != LastBlock) {
+ LastBlock = CurrBlock;
+ bool BlockReachable = ReachableBlocks.count(CurrBlock);
+ const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock);
+
+ // If it's not reachable, erase any touched instructions and move on.
+ if (!BlockReachable) {
+ TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second);
+ DEBUG(dbgs() << "Skipping instructions in block "
+ << getBlockName(CurrBlock)
+ << " because it is unreachable\n");
+ continue;
+ }
+ updateProcessedCount(CurrBlock);
+ }
+
+ if (auto *MP = dyn_cast<MemoryPhi>(V)) {
+ DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n");
+ valueNumberMemoryPhi(MP);
+ } else if (auto *I = dyn_cast<Instruction>(V)) {
+ valueNumberInstruction(I);
+ } else {
+ llvm_unreachable("Should have been a MemoryPhi or Instruction");
+ }
+ updateProcessedCount(V);
+ // Reset after processing (because we may mark ourselves as touched when
+ // we propagate equalities).
+ TouchedInstructions.reset(InstrNum);
+ }
+ }
+ NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations);
}
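
For readers not used to this worklist style: the bitvector is indexed by RPO/DFS number, so resuming at the lowest set bit revisits blocks in RPO order within each outer pass. A standalone sketch of the same pattern using std::vector<bool> in place of the LLVM bitvector (hypothetical names, not the NewGVN types):

  #include <functional>
  #include <vector>

  // Drain a worklist keyed by RPO number. Processing an item may re-mark
  // earlier items; those are picked up on the next outer pass.
  void iterateWorklist(std::vector<bool> &Touched,
                       const std::function<void(unsigned)> &Process) {
    bool FoundAny = true;
    while (FoundAny) {
      FoundAny = false;
      for (unsigned I = 0, E = Touched.size(); I != E; ++I) {
        if (!Touched[I])
          continue;
        FoundAny = true;
        Process(I);
        // Clear after processing so a self-touch made while propagating
        // does not immediately requeue the same item.
        Touched[I] = false;
      }
    }
  }
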
// This is the main transformation entry point.
-bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
- TargetLibraryInfo *_TLI, AliasAnalysis *_AA,
- MemorySSA *_MSSA) {
+bool NewGVN::runGVN() {
+ if (DebugCounter::isCounterSet(VNCounter))
+ StartingVNCounter = DebugCounter::getCounterValue(VNCounter);
bool Changed = false;
- DT = _DT;
- AC = _AC;
- TLI = _TLI;
- AA = _AA;
- MSSA = _MSSA;
- DL = &F.getParent()->getDataLayout();
+ NumFuncArgs = F.arg_size();
MSSAWalker = MSSA->getWalker();
// Count number of instructions for sizing of hash tables, and come
@@ -1642,15 +2693,14 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
unsigned ICount = 1;
// Add an empty instruction to account for the fact that we start at 1
DFSToInstr.emplace_back(nullptr);
- // Note: We want RPO traversal of the blocks, which is not quite the same as
- // dominator tree order, particularly with regard whether backedges get
- // visited first or second, given a block with multiple successors.
+  // Note: We want ideal RPO traversal of the blocks, which is not quite the
+  // same as dominator tree order, particularly with regard to whether
+  // backedges get visited first or second, given a block with multiple
+  // successors.
// If we visit in the wrong order, we will end up performing N times as many
// iterations.
// The dominator tree does guarantee that, for a given dom tree node, it's
// parent must occur before it in the RPO ordering. Thus, we only need to sort
// the siblings.
- DenseMap<const DomTreeNode *, unsigned> RPOOrdering;
ReversePostOrderTraversal<Function *> RPOT(&F);
unsigned Counter = 0;
for (auto &B : RPOT) {
@@ -1663,7 +2713,7 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
auto *Node = DT->getNode(B);
if (Node->getChildren().size() > 1)
std::sort(Node->begin(), Node->end(),
- [&RPOOrdering](const DomTreeNode *A, const DomTreeNode *B) {
+ [&](const DomTreeNode *A, const DomTreeNode *B) {
return RPOOrdering[A] < RPOOrdering[B];
});
}
@@ -1689,7 +2739,6 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
}
TouchedInstructions.resize(ICount);
- DominatedInstRange.reserve(F.size());
// Ensure we don't end up resizing the expressionToClass map, as
// that can be quite expensive. At most, we have one expression per
// instruction.
@@ -1701,62 +2750,10 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
ReachableBlocks.insert(&F.getEntryBlock());
initializeCongruenceClasses(F);
-
- unsigned int Iterations = 0;
- // We start out in the entry block.
- BasicBlock *LastBlock = &F.getEntryBlock();
- while (TouchedInstructions.any()) {
- ++Iterations;
- // Walk through all the instructions in all the blocks in RPO.
- for (int InstrNum = TouchedInstructions.find_first(); InstrNum != -1;
- InstrNum = TouchedInstructions.find_next(InstrNum)) {
- assert(InstrNum != 0 && "Bit 0 should never be set, something touched an "
- "instruction not in the lookup table");
- Value *V = DFSToInstr[InstrNum];
- BasicBlock *CurrBlock = nullptr;
-
- if (auto *I = dyn_cast<Instruction>(V))
- CurrBlock = I->getParent();
- else if (auto *MP = dyn_cast<MemoryPhi>(V))
- CurrBlock = MP->getBlock();
- else
- llvm_unreachable("DFSToInstr gave us an unknown type of instruction");
-
- // If we hit a new block, do reachability processing.
- if (CurrBlock != LastBlock) {
- LastBlock = CurrBlock;
- bool BlockReachable = ReachableBlocks.count(CurrBlock);
- const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock);
-
- // If it's not reachable, erase any touched instructions and move on.
- if (!BlockReachable) {
- TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second);
- DEBUG(dbgs() << "Skipping instructions in block "
- << getBlockName(CurrBlock)
- << " because it is unreachable\n");
- continue;
- }
- updateProcessedCount(CurrBlock);
- }
-
- if (auto *MP = dyn_cast<MemoryPhi>(V)) {
- DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n");
- valueNumberMemoryPhi(MP);
- } else if (auto *I = dyn_cast<Instruction>(V)) {
- valueNumberInstruction(I);
- } else {
- llvm_unreachable("Should have been a MemoryPhi or Instruction");
- }
- updateProcessedCount(V);
- // Reset after processing (because we may mark ourselves as touched when
- // we propagate equalities).
- TouchedInstructions.reset(InstrNum);
- }
- }
- NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations);
-#ifndef NDEBUG
+ iterateTouchedInstructions();
verifyMemoryCongruency();
-#endif
+ verifyIterationSettled(F);
+
Changed |= eliminateInstructions(F);
// Delete all instructions marked for deletion.
@@ -1783,36 +2780,6 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
return Changed;
}
-bool NewGVN::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
- return runGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
- &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
- &getAnalysis<AAResultsWrapperPass>().getAAResults(),
- &getAnalysis<MemorySSAWrapperPass>().getMSSA());
-}
-
-PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) {
- NewGVN Impl;
-
- // Apparently the order in which we get these results matter for
- // the old GVN (see Chandler's comment in GVN.cpp). I'll keep
- // the same order here, just in case.
- auto &AC = AM.getResult<AssumptionAnalysis>(F);
- auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
- auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
- auto &AA = AM.getResult<AAManager>(F);
- auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
- bool Changed = Impl.runGVN(F, &DT, &AC, &TLI, &AA, &MSSA);
- if (!Changed)
- return PreservedAnalyses::all();
- PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
- PA.preserve<GlobalsAA>();
- return PA;
-}
-
// Return true if V is a value that will always be available (IE can
// be placed anywhere) in the function. We don't do globals here
// because they are often worse to put in place.
@@ -1821,21 +2788,15 @@ static bool alwaysAvailable(Value *V) {
return isa<Constant>(V) || isa<Argument>(V);
}
-// Get the basic block from an instruction/value.
-static BasicBlock *getBlockForValue(Value *V) {
- if (auto *I = dyn_cast<Instruction>(V))
- return I->getParent();
- return nullptr;
-}
-
struct NewGVN::ValueDFS {
int DFSIn = 0;
int DFSOut = 0;
int LocalNum = 0;
- // Only one of these will be set.
- Value *Val = nullptr;
+ // Only one of Def and U will be set.
+ // The bool in the Def tells us whether the Def is the stored value of a
+ // store.
+ PointerIntPair<Value *, 1, bool> Def;
Use *U = nullptr;
-
bool operator<(const ValueDFS &Other) const {
// It's not enough that any given field be less than - we have sets
// of fields that need to be evaluated together to give a proper ordering.
@@ -1875,89 +2836,151 @@ struct NewGVN::ValueDFS {
// but .val and .u.
// It does not matter what order we replace these operands in.
// You will always end up with the same IR, and this is guaranteed.
- return std::tie(DFSIn, DFSOut, LocalNum, Val, U) <
- std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Val,
+ return std::tie(DFSIn, DFSOut, LocalNum, Def, U) <
+ std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Def,
Other.U);
}
};
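
The DFSIn/DFSOut pair is the standard dominator-tree interval numbering: A dominates B exactly when A's interval encloses B's, so sorting on these fields puts every def ahead of everything it dominates. A small standalone illustration of that property (hypothetical struct, not the ValueDFS above):

  #include <algorithm>
  #include <tuple>
  #include <vector>

  struct DFSEntry {
    int DFSIn, DFSOut; // dominator-tree DFS interval of the containing block
    int LocalNum;      // position within the block
  };

  // A dominates B iff A's interval encloses B's.
  static bool dominates(const DFSEntry &A, const DFSEntry &B) {
    return A.DFSIn <= B.DFSIn && B.DFSOut <= A.DFSOut;
  }

  // After this sort, each entry precedes every entry it dominates, so a
  // simple scope stack (pop when leaving the current interval) can drive
  // replacement.
  static void sortForElimination(std::vector<DFSEntry> &Entries) {
    std::sort(Entries.begin(), Entries.end(),
              [](const DFSEntry &A, const DFSEntry &B) {
                return std::tie(A.DFSIn, A.DFSOut, A.LocalNum) <
                       std::tie(B.DFSIn, B.DFSOut, B.LocalNum);
              });
  }
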
-void NewGVN::convertDenseToDFSOrdered(
- CongruenceClass::MemberSet &Dense,
- SmallVectorImpl<ValueDFS> &DFSOrderedSet) {
+// This function converts the set of members for a congruence class from
+// values to sets of defs and uses with associated DFS info. The total number
+// of reachable uses for each value is stored in UseCounts, and instructions
+// that appear dead (have no non-dead uses) are stored in ProbablyDead.
+void NewGVN::convertClassToDFSOrdered(
+ const CongruenceClass &Dense, SmallVectorImpl<ValueDFS> &DFSOrderedSet,
+ DenseMap<const Value *, unsigned int> &UseCounts,
+ SmallPtrSetImpl<Instruction *> &ProbablyDead) const {
for (auto D : Dense) {
// First add the value.
BasicBlock *BB = getBlockForValue(D);
// Constants are handled prior to ever calling this function, so
// we should only be left with instructions as members.
assert(BB && "Should have figured out a basic block for value");
- ValueDFS VD;
-
- std::pair<int, int> DFSPair = DFSDomMap[BB];
- assert(DFSPair.first != -1 && DFSPair.second != -1 && "Invalid DFS Pair");
- VD.DFSIn = DFSPair.first;
- VD.DFSOut = DFSPair.second;
- VD.Val = D;
- // If it's an instruction, use the real local dfs number.
- if (auto *I = dyn_cast<Instruction>(D))
- VD.LocalNum = InstrDFS[I];
- else
- llvm_unreachable("Should have been an instruction");
-
- DFSOrderedSet.emplace_back(VD);
-
- // Now add the users.
- for (auto &U : D->uses()) {
+ ValueDFS VDDef;
+ DomTreeNode *DomNode = DT->getNode(BB);
+ VDDef.DFSIn = DomNode->getDFSNumIn();
+ VDDef.DFSOut = DomNode->getDFSNumOut();
+    // If it's a store, use the leader of the value operand if it is always
+    // available; otherwise fall back to the value operand itself. TODO: We
+    // could do dominance checks to find a dominating leader, but it's not
+    // worth it ATM.
+ if (auto *SI = dyn_cast<StoreInst>(D)) {
+ auto Leader = lookupOperandLeader(SI->getValueOperand());
+ if (alwaysAvailable(Leader)) {
+ VDDef.Def.setPointer(Leader);
+ } else {
+ VDDef.Def.setPointer(SI->getValueOperand());
+ VDDef.Def.setInt(true);
+ }
+ } else {
+ VDDef.Def.setPointer(D);
+ }
+ assert(isa<Instruction>(D) &&
+ "The dense set member should always be an instruction");
+ VDDef.LocalNum = InstrToDFSNum(D);
+ DFSOrderedSet.emplace_back(VDDef);
+ Instruction *Def = cast<Instruction>(D);
+ unsigned int UseCount = 0;
+ // Now add the uses.
+ for (auto &U : Def->uses()) {
if (auto *I = dyn_cast<Instruction>(U.getUser())) {
- ValueDFS VD;
+ // Don't try to replace into dead uses
+ if (InstructionsToErase.count(I))
+ continue;
+ ValueDFS VDUse;
// Put the phi node uses in the incoming block.
BasicBlock *IBlock;
if (auto *P = dyn_cast<PHINode>(I)) {
IBlock = P->getIncomingBlock(U);
// Make phi node users appear last in the incoming block
// they are from.
- VD.LocalNum = InstrDFS.size() + 1;
+ VDUse.LocalNum = InstrDFS.size() + 1;
} else {
IBlock = I->getParent();
- VD.LocalNum = InstrDFS[I];
+ VDUse.LocalNum = InstrToDFSNum(I);
}
- std::pair<int, int> DFSPair = DFSDomMap[IBlock];
- VD.DFSIn = DFSPair.first;
- VD.DFSOut = DFSPair.second;
- VD.U = &U;
- DFSOrderedSet.emplace_back(VD);
+
+ // Skip uses in unreachable blocks, as we're going
+ // to delete them.
+ if (ReachableBlocks.count(IBlock) == 0)
+ continue;
+
+ DomTreeNode *DomNode = DT->getNode(IBlock);
+ VDUse.DFSIn = DomNode->getDFSNumIn();
+ VDUse.DFSOut = DomNode->getDFSNumOut();
+ VDUse.U = &U;
+ ++UseCount;
+ DFSOrderedSet.emplace_back(VDUse);
}
}
+
+    // If there are no uses, it's probably dead (but it may have side effects,
+    // so it is not definitely dead). Otherwise, store the number of uses so
+    // we can track whether it becomes dead later.
+ if (UseCount == 0)
+ ProbablyDead.insert(Def);
+ else
+ UseCounts[Def] = UseCount;
}
}
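
The UseCounts/ProbablyDead bookkeeping is then consumed during elimination: whenever a user of a value is erased, the value's remaining-use count drops, and at zero the value becomes a deletion candidate (only "probably" dead because it may still have side effects). A minimal sketch of that bookkeeping with hypothetical types:

  #include <map>
  #include <set>

  // Record that one user of V has been erased; once no live uses remain, V is
  // only kept alive by its own side effects (if any).
  template <typename ValueT>
  void noteErasedUse(const ValueT *V,
                     std::map<const ValueT *, unsigned> &UseCounts,
                     std::set<const ValueT *> &ProbablyDead) {
    auto It = UseCounts.find(V);
    if (It != UseCounts.end() && It->second > 0 && --It->second == 0)
      ProbablyDead.insert(V);
  }
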
-static void patchReplacementInstruction(Instruction *I, Value *Repl) {
- // Patch the replacement so that it is not more restrictive than the value
- // being replaced.
- auto *Op = dyn_cast<BinaryOperator>(I);
- auto *ReplOp = dyn_cast<BinaryOperator>(Repl);
+// This function converts the set of members for a congruence class from
+// values to the set of defs for loads and stores, with associated DFS info.
+void NewGVN::convertClassToLoadsAndStores(
+ const CongruenceClass &Dense,
+ SmallVectorImpl<ValueDFS> &LoadsAndStores) const {
+ for (auto D : Dense) {
+ if (!isa<LoadInst>(D) && !isa<StoreInst>(D))
+ continue;
- if (Op && ReplOp)
- ReplOp->andIRFlags(Op);
+ BasicBlock *BB = getBlockForValue(D);
+ ValueDFS VD;
+ DomTreeNode *DomNode = DT->getNode(BB);
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.Def.setPointer(D);
- if (auto *ReplInst = dyn_cast<Instruction>(Repl)) {
- // FIXME: If both the original and replacement value are part of the
- // same control-flow region (meaning that the execution of one
- // guarentees the executation of the other), then we can combine the
- // noalias scopes here and do better than the general conservative
- // answer used in combineMetadata().
+ // If it's an instruction, use the real local dfs number.
+ if (auto *I = dyn_cast<Instruction>(D))
+ VD.LocalNum = InstrToDFSNum(I);
+ else
+ llvm_unreachable("Should have been an instruction");
- // In general, GVN unifies expressions over different control-flow
- // regions, and so we need a conservative combination of the noalias
- // scopes.
- unsigned KnownIDs[] = {
- LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
- LLVMContext::MD_noalias, LLVMContext::MD_range,
- LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load,
- LLVMContext::MD_invariant_group};
- combineMetadata(ReplInst, I, KnownIDs);
+ LoadsAndStores.emplace_back(VD);
}
}
+static void patchReplacementInstruction(Instruction *I, Value *Repl) {
+ auto *ReplInst = dyn_cast<Instruction>(Repl);
+ if (!ReplInst)
+ return;
+
+ // Patch the replacement so that it is not more restrictive than the value
+ // being replaced.
+ // Note that if 'I' is a load being replaced by some operation,
+ // for example, by an arithmetic operation, then andIRFlags()
+ // would just erase all math flags from the original arithmetic
+ // operation, which is clearly not wanted and not needed.
+ if (!isa<LoadInst>(I))
+ ReplInst->andIRFlags(I);
+
+ // FIXME: If both the original and replacement value are part of the
+ // same control-flow region (meaning that the execution of one
+ // guarantees the execution of the other), then we can combine the
+ // noalias scopes here and do better than the general conservative
+ // answer used in combineMetadata().
+
+ // In general, GVN unifies expressions over different control-flow
+ // regions, and so we need a conservative combination of the noalias
+ // scopes.
+ static const unsigned KnownIDs[] = {
+ LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias, LLVMContext::MD_range,
+ LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load,
+ LLVMContext::MD_invariant_group};
+ combineMetadata(ReplInst, I, KnownIDs);
+}
+
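
The andIRFlags() call above is the usual "intersect the flags" rule for merging two computations proven equal: the surviving instruction may only keep a poison-generating flag if both instructions carried it. A tiny standalone model of that rule (hypothetical flag struct, not the LLVM API):

  // Keep a wrapping/exactness flag on the surviving instruction only if both
  // merged instructions carried it; otherwise the flag must be dropped.
  struct IRFlags {
    bool NoSignedWrap = false;
    bool NoUnsignedWrap = false;
    bool Exact = false;
  };

  static IRFlags intersect(const IRFlags &A, const IRFlags &B) {
    return {A.NoSignedWrap && B.NoSignedWrap,
            A.NoUnsignedWrap && B.NoUnsignedWrap,
            A.Exact && B.Exact};
  }
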
static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) {
patchReplacementInstruction(I, Repl);
I->replaceAllUsesWith(Repl);
@@ -1967,10 +2990,6 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) {
DEBUG(dbgs() << " BasicBlock Dead:" << *BB);
++NumGVNBlocksDeleted;
- // Check to see if there are non-terminating instructions to delete.
- if (isa<TerminatorInst>(BB->begin()))
- return;
-
// Delete the instructions backwards, as it has a reduced likelihood of having
// to update as many def-use and use-def chains. Start after the terminator.
auto StartPoint = BB->rbegin();
@@ -1987,6 +3006,11 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) {
Inst.eraseFromParent();
++NumGVNInstrDeleted;
}
+ // Now insert something that simplifycfg will turn into an unreachable.
+ Type *Int8Ty = Type::getInt8Ty(BB->getContext());
+ new StoreInst(UndefValue::get(Int8Ty),
+ Constant::getNullValue(Int8Ty->getPointerTo()),
+ BB->getTerminator());
}
void NewGVN::markInstructionForDeletion(Instruction *I) {
@@ -2086,59 +3110,59 @@ bool NewGVN::eliminateInstructions(Function &F) {
}
}
}
- DomTreeNode *Node = DT->getNode(&B);
- if (Node)
- DFSDomMap[&B] = {Node->getDFSNumIn(), Node->getDFSNumOut()};
}
- for (CongruenceClass *CC : CongruenceClasses) {
- // FIXME: We should eventually be able to replace everything still
- // in the initial class with undef, as they should be unreachable.
- // Right now, initial still contains some things we skip value
- // numbering of (UNREACHABLE's, for example).
- if (CC == InitialClass || CC->Dead)
+ // Map to store the use counts
+ DenseMap<const Value *, unsigned int> UseCounts;
+ for (CongruenceClass *CC : reverse(CongruenceClasses)) {
+ // Track the equivalent store info so we can decide whether to try
+ // dead store elimination.
+ SmallVector<ValueDFS, 8> PossibleDeadStores;
+ SmallPtrSet<Instruction *, 8> ProbablyDead;
+ if (CC->isDead() || CC->empty())
continue;
- assert(CC->RepLeader && "We should have had a leader");
+ // Everything still in the TOP class is unreachable or dead.
+ if (CC == TOPClass) {
+#ifndef NDEBUG
+ for (auto M : *CC)
+ assert((!ReachableBlocks.count(cast<Instruction>(M)->getParent()) ||
+ InstructionsToErase.count(cast<Instruction>(M))) &&
+ "Everything in TOP should be unreachable or dead at this "
+ "point");
+#endif
+ continue;
+ }
+ assert(CC->getLeader() && "We should have had a leader");
// If this is a leader that is always available, and it's a
// constant or has no equivalences, just replace everything with
// it. We then update the congruence class with whatever members
// are left.
- if (alwaysAvailable(CC->RepLeader)) {
- SmallPtrSet<Value *, 4> MembersLeft;
- for (auto M : CC->Members) {
-
+ Value *Leader =
+ CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader();
+ if (alwaysAvailable(Leader)) {
+ CongruenceClass::MemberSet MembersLeft;
+ for (auto M : *CC) {
Value *Member = M;
-
// Void things have no uses we can replace.
- if (Member == CC->RepLeader || Member->getType()->isVoidTy()) {
+ if (Member == Leader || !isa<Instruction>(Member) ||
+ Member->getType()->isVoidTy()) {
MembersLeft.insert(Member);
continue;
}
-
- DEBUG(dbgs() << "Found replacement " << *(CC->RepLeader) << " for "
- << *Member << "\n");
- // Due to equality propagation, these may not always be
- // instructions, they may be real values. We don't really
- // care about trying to replace the non-instructions.
- if (auto *I = dyn_cast<Instruction>(Member)) {
- assert(CC->RepLeader != I &&
- "About to accidentally remove our leader");
- replaceInstruction(I, CC->RepLeader);
- AnythingReplaced = true;
-
- continue;
- } else {
- MembersLeft.insert(I);
- }
+ DEBUG(dbgs() << "Found replacement " << *(Leader) << " for " << *Member
+ << "\n");
+ auto *I = cast<Instruction>(Member);
+ assert(Leader != I && "About to accidentally remove our leader");
+ replaceInstruction(I, Leader);
+ AnythingReplaced = true;
}
- CC->Members.swap(MembersLeft);
-
+ CC->swap(MembersLeft);
} else {
- DEBUG(dbgs() << "Eliminating in congruence class " << CC->ID << "\n");
+ DEBUG(dbgs() << "Eliminating in congruence class " << CC->getID()
+ << "\n");
// If this is a singleton, we can skip it.
- if (CC->Members.size() != 1) {
-
+ if (CC->size() != 1) {
// This is a stack because equality replacement/etc may place
// constants in the middle of the member list, and we want to use
// those constant values in preference to the current leader, over
@@ -2147,24 +3171,19 @@ bool NewGVN::eliminateInstructions(Function &F) {
// Convert the members to DFS ordered sets and then merge them.
SmallVector<ValueDFS, 8> DFSOrderedSet;
- convertDenseToDFSOrdered(CC->Members, DFSOrderedSet);
+ convertClassToDFSOrdered(*CC, DFSOrderedSet, UseCounts, ProbablyDead);
// Sort the whole thing.
std::sort(DFSOrderedSet.begin(), DFSOrderedSet.end());
-
for (auto &VD : DFSOrderedSet) {
int MemberDFSIn = VD.DFSIn;
int MemberDFSOut = VD.DFSOut;
- Value *Member = VD.Val;
- Use *MemberUse = VD.U;
-
- if (Member) {
- // We ignore void things because we can't get a value from them.
- // FIXME: We could actually use this to kill dead stores that are
- // dominated by equivalent earlier stores.
- if (Member->getType()->isVoidTy())
- continue;
- }
+ Value *Def = VD.Def.getPointer();
+ bool FromStore = VD.Def.getInt();
+ Use *U = VD.U;
+ // We ignore void things because we can't get a value from them.
+ if (Def && Def->getType()->isVoidTy())
+ continue;
if (EliminationStack.empty()) {
DEBUG(dbgs() << "Elimination Stack is empty\n");
@@ -2189,69 +3208,240 @@ bool NewGVN::eliminateInstructions(Function &F) {
// start using, we also push.
// Otherwise, we walk along, processing members who are
// dominated by this scope, and eliminate them.
- bool ShouldPush =
- Member && (EliminationStack.empty() || isa<Constant>(Member));
+ bool ShouldPush = Def && EliminationStack.empty();
bool OutOfScope =
!EliminationStack.isInScope(MemberDFSIn, MemberDFSOut);
if (OutOfScope || ShouldPush) {
// Sync to our current scope.
EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut);
- ShouldPush |= Member && EliminationStack.empty();
+ bool ShouldPush = Def && EliminationStack.empty();
if (ShouldPush) {
- EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut);
+ EliminationStack.push_back(Def, MemberDFSIn, MemberDFSOut);
+ }
+ }
+
+          // Skip the Defs; we only want to eliminate on their uses, but we
+          // mark dominated defs as dead.
+ if (Def) {
+          // For anything in this case, the way we value number guarantees
+          // that any side effects that would have occurred (i.e. throwing,
+          // etc.) can be proven to either still occur (because the def is
+          // dominated by something with the same side effects) or never
+          // occur. Otherwise, we would not have been able to prove it value
+          // equivalent to something else. For these things, we can just mark
+          // it all dead. Note that this is different from the "ProbablyDead"
+          // set, whose members may not be dominated by anything, and thus
+          // are only easy to prove dead if they are also side-effect free.
+          // Note that because stores are put in terms of the stored value,
+          // we skip stored values here. If the stored value is really dead,
+          // it will still be marked for deletion when we process it in its
+          // own class.
+ if (!EliminationStack.empty() && Def != EliminationStack.back() &&
+ isa<Instruction>(Def) && !FromStore)
+ markInstructionForDeletion(cast<Instruction>(Def));
+ continue;
+ }
+ // At this point, we know it is a Use we are trying to possibly
+ // replace.
+
+ assert(isa<Instruction>(U->get()) &&
+ "Current def should have been an instruction");
+ assert(isa<Instruction>(U->getUser()) &&
+ "Current user should have been an instruction");
+
+ // If the thing we are replacing into is already marked to be dead,
+ // this use is dead. Note that this is true regardless of whether
+ // we have anything dominating the use or not. We do this here
+ // because we are already walking all the uses anyway.
+ Instruction *InstUse = cast<Instruction>(U->getUser());
+ if (InstructionsToErase.count(InstUse)) {
+ auto &UseCount = UseCounts[U->get()];
+ if (--UseCount == 0) {
+ ProbablyDead.insert(cast<Instruction>(U->get()));
}
}
// If we get to this point, and the stack is empty we must have a use
- // with nothing we can use to eliminate it, just skip it.
+ // with nothing we can use to eliminate this use, so just skip it.
if (EliminationStack.empty())
continue;
- // Skip the Value's, we only want to eliminate on their uses.
- if (Member)
- continue;
- Value *Result = EliminationStack.back();
+ Value *DominatingLeader = EliminationStack.back();
// Don't replace our existing users with ourselves.
- if (MemberUse->get() == Result)
+ if (U->get() == DominatingLeader)
continue;
-
- DEBUG(dbgs() << "Found replacement " << *Result << " for "
- << *MemberUse->get() << " in " << *(MemberUse->getUser())
- << "\n");
+ DEBUG(dbgs() << "Found replacement " << *DominatingLeader << " for "
+ << *U->get() << " in " << *(U->getUser()) << "\n");
// If we replaced something in an instruction, handle the patching of
- // metadata.
- if (auto *ReplacedInst = dyn_cast<Instruction>(MemberUse->get()))
- patchReplacementInstruction(ReplacedInst, Result);
-
- assert(isa<Instruction>(MemberUse->getUser()));
- MemberUse->set(Result);
+ // metadata. Skip this if we are replacing predicateinfo with its
+ // original operand, as we already know we can just drop it.
+ auto *ReplacedInst = cast<Instruction>(U->get());
+ auto *PI = PredInfo->getPredicateInfoFor(ReplacedInst);
+ if (!PI || DominatingLeader != PI->OriginalOp)
+ patchReplacementInstruction(ReplacedInst, DominatingLeader);
+ U->set(DominatingLeader);
+ // This is now a use of the dominating leader, which means if the
+ // dominating leader was dead, it's now live!
+ auto &LeaderUseCount = UseCounts[DominatingLeader];
+ // It's about to be alive again.
+ if (LeaderUseCount == 0 && isa<Instruction>(DominatingLeader))
+ ProbablyDead.erase(cast<Instruction>(DominatingLeader));
+ ++LeaderUseCount;
AnythingReplaced = true;
}
}
}
+    // At this point, anything still in the ProbablyDead set is actually dead
+    // if it would be trivially dead.
+ for (auto *I : ProbablyDead)
+ if (wouldInstructionBeTriviallyDead(I))
+ markInstructionForDeletion(I);
+
// Cleanup the congruence class.
- SmallPtrSet<Value *, 4> MembersLeft;
- for (Value *Member : CC->Members) {
- if (Member->getType()->isVoidTy()) {
+ CongruenceClass::MemberSet MembersLeft;
+ for (auto *Member : *CC)
+ if (!isa<Instruction>(Member) ||
+ !InstructionsToErase.count(cast<Instruction>(Member)))
MembersLeft.insert(Member);
- continue;
- }
-
- if (auto *MemberInst = dyn_cast<Instruction>(Member)) {
- if (isInstructionTriviallyDead(MemberInst)) {
- // TODO: Don't mark loads of undefs.
- markInstructionForDeletion(MemberInst);
- continue;
+ CC->swap(MembersLeft);
+
+ // If we have possible dead stores to look at, try to eliminate them.
+ if (CC->getStoreCount() > 0) {
+ convertClassToLoadsAndStores(*CC, PossibleDeadStores);
+ std::sort(PossibleDeadStores.begin(), PossibleDeadStores.end());
+ ValueDFSStack EliminationStack;
+ for (auto &VD : PossibleDeadStores) {
+ int MemberDFSIn = VD.DFSIn;
+ int MemberDFSOut = VD.DFSOut;
+ Instruction *Member = cast<Instruction>(VD.Def.getPointer());
+ if (EliminationStack.empty() ||
+ !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut)) {
+ // Sync to our current scope.
+ EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut);
+ if (EliminationStack.empty()) {
+ EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut);
+ continue;
+ }
}
+ // We already did load elimination, so nothing to do here.
+ if (isa<LoadInst>(Member))
+ continue;
+ assert(!EliminationStack.empty());
+ Instruction *Leader = cast<Instruction>(EliminationStack.back());
+ (void)Leader;
+ assert(DT->dominates(Leader->getParent(), Member->getParent()));
+        // Member is dominated by Leader, and thus dead.
+ DEBUG(dbgs() << "Marking dead store " << *Member
+ << " that is dominated by " << *Leader << "\n");
+ markInstructionForDeletion(Member);
+ CC->erase(Member);
+ ++NumGVNDeadStores;
}
- MembersLeft.insert(Member);
}
- CC->Members.swap(MembersLeft);
}
return AnythingReplaced;
}
+
+// This function provides a global ranking of operations so that we can place
+// them in a canonical order. Note that rank alone is not necessarily enough
+// for a complete ordering, as constants all have the same rank. However, an
+// operation whose operands are all constants will generally be simplified
+// away, so it doesn't matter what order they appear in.
+unsigned int NewGVN::getRank(const Value *V) const {
+ // Prefer undef to anything else
+ if (isa<UndefValue>(V))
+ return 0;
+ if (isa<Constant>(V))
+ return 1;
+ else if (auto *A = dyn_cast<Argument>(V))
+ return 2 + A->getArgNo();
+
+  // Need to shift the instruction DFS number by NumFuncArgs + 3 to account
+  // for the undef, constant, and argument ranks above.
+ unsigned Result = InstrToDFSNum(V);
+ if (Result > 0)
+ return 3 + NumFuncArgs + Result;
+ // Unreachable or something else, just return a really large number.
+ return ~0;
+}
+
+// This is a function that says whether the two operands of a commutative
+// operation should have their order swapped when canonicalizing.
+bool NewGVN::shouldSwapOperands(const Value *A, const Value *B) const {
+ // Because we only care about a total ordering, and don't rewrite expressions
+ // in this order, we order by rank, which will give a strict weak ordering to
+ // everything but constants, and then we order by pointer address.
+ return std::make_pair(getRank(A), A) > std::make_pair(getRank(B), B);
+}
+
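
The rank is only used to pick a canonical operand order for commutative expressions so that, e.g., a + b and b + a produce the same expression. A standalone sketch of how such a ranking would be consumed (hypothetical helper, mirroring shouldSwapOperands above):

  #include <utility>

  // Put the lower-ranked value first, breaking ties by value identity, so
  // commutative expressions built from the pair hash identically.
  template <typename T, typename RankFn>
  std::pair<T, T> canonicalizeCommutative(T A, T B, RankFn Rank) {
    if (std::make_pair(Rank(A), A) > std::make_pair(Rank(B), B))
      std::swap(A, B);
    return {A, B};
  }
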
+class NewGVNLegacyPass : public FunctionPass {
+public:
+ static char ID; // Pass identification, replacement for typeid.
+ NewGVNLegacyPass() : FunctionPass(ID) {
+ initializeNewGVNLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+ bool runOnFunction(Function &F) override;
+
+private:
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<MemorySSAWrapperPass>();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addPreserved<DominatorTreeWrapperPass>();
+ AU.addPreserved<GlobalsAAWrapperPass>();
+ }
+};
+
+bool NewGVNLegacyPass::runOnFunction(Function &F) {
+ if (skipFunction(F))
+ return false;
+ return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+ &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
+ &getAnalysis<AAResultsWrapperPass>().getAAResults(),
+ &getAnalysis<MemorySSAWrapperPass>().getMSSA(),
+ F.getParent()->getDataLayout())
+ .runGVN();
+}
+
+INITIALIZE_PASS_BEGIN(NewGVNLegacyPass, "newgvn", "Global Value Numbering",
+ false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
+INITIALIZE_PASS_END(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false,
+ false)
+
+char NewGVNLegacyPass::ID = 0;
+
+// createNewGVNPass - The public interface to this file.
+FunctionPass *llvm::createNewGVNPass() { return new NewGVNLegacyPass(); }
+
+PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) {
+  // Apparently the order in which we get these results matters for
+ // the old GVN (see Chandler's comment in GVN.cpp). I'll keep
+ // the same order here, just in case.
+ auto &AC = AM.getResult<AssumptionAnalysis>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+ auto &AA = AM.getResult<AAManager>(F);
+ auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+ bool Changed =
+ NewGVN(F, &DT, &AC, &TLI, &AA, &MSSA, F.getParent()->getDataLayout())
+ .runGVN();
+ if (!Changed)
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<GlobalsAA>();
+ return PA;
+}
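
Both entry points are wired up in the usual way; a minimal sketch of scheduling the pass from C++ through the legacy pass manager (assuming the createNewGVNPass declaration in llvm/Transforms/Scalar.h), with `opt -newgvn` or `opt -passes=newgvn` as the command-line equivalents:

  #include "llvm/IR/LegacyPassManager.h"
  #include "llvm/IR/Module.h"
  #include "llvm/Transforms/Scalar.h"

  // Run NewGVN over every function in a module via the legacy pass manager.
  static bool runNewGVN(llvm::Module &M) {
    llvm::legacy::PassManager PM;
    PM.add(llvm::createNewGVNPass());
    return PM.run(M);
  }
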
diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
index 1a7ddc9585baf..1bfecea2f61e6 100644
--- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
+++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
@@ -66,7 +66,7 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc,
// Add attribute "readnone" so that backend can use a native sqrt instruction
// for this call. Insert a FP compare instruction and a conditional branch
// at the end of CurrBB.
- Call->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
+ Call->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
CurrBB.getTerminator()->eraseFromParent();
Builder.SetInsertPoint(&CurrBB);
Value *FCmp = Builder.CreateFCmpOEQ(Call, Call);
@@ -98,14 +98,14 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI,
// Skip if function either has local linkage or is not a known library
// function.
- LibFunc::Func LibFunc;
+ LibFunc LF;
if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() ||
- !TLI->getLibFunc(CalledFunc->getName(), LibFunc))
+ !TLI->getLibFunc(CalledFunc->getName(), LF))
continue;
- switch (LibFunc) {
- case LibFunc::sqrtf:
- case LibFunc::sqrt:
+ switch (LF) {
+ case LibFunc_sqrtf:
+ case LibFunc_sqrt:
if (TTI->haveFastSqrt(Call->getType()) &&
optimizeSQRT(Call, CalledFunc, *CurrBB, BB))
break;
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index 65c814d7a63b8..3dcab60907896 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -1069,8 +1069,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
///
/// Ops is the top-level list of add operands we're trying to factor.
static void FindSingleUseMultiplyFactors(Value *V,
- SmallVectorImpl<Value*> &Factors,
- const SmallVectorImpl<ValueEntry> &Ops) {
+ SmallVectorImpl<Value*> &Factors) {
BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul);
if (!BO) {
Factors.push_back(V);
@@ -1078,8 +1077,8 @@ static void FindSingleUseMultiplyFactors(Value *V,
}
// Otherwise, add the LHS and RHS to the list of factors.
- FindSingleUseMultiplyFactors(BO->getOperand(1), Factors, Ops);
- FindSingleUseMultiplyFactors(BO->getOperand(0), Factors, Ops);
+ FindSingleUseMultiplyFactors(BO->getOperand(1), Factors);
+ FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);
}
/// Optimize a series of operands to an 'and', 'or', or 'xor' instruction.
@@ -1499,7 +1498,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// Compute all of the factors of this added value.
SmallVector<Value*, 8> Factors;
- FindSingleUseMultiplyFactors(BOp, Factors, Ops);
+ FindSingleUseMultiplyFactors(BOp, Factors);
assert(Factors.size() > 1 && "Bad linearize!");
// Add one to FactorOccurrences for each unique factor in this op.
@@ -2236,8 +2235,8 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {
ValueRankMap.clear();
if (MadeChange) {
- // FIXME: This should also 'preserve the CFG'.
- auto PA = PreservedAnalyses();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
return PA;
}
diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index 1de742050cb33..f344eb151464a 100644
--- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -365,6 +365,11 @@ findBaseDefiningValueOfVector(Value *I) {
// for particular sufflevector patterns.
return BaseDefiningValueResult(I, false);
+ // The behavior of getelementptr instructions is the same for vector and
+ // non-vector data types.
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
+ return findBaseDefiningValue(GEP->getPointerOperand());
+
// A PHI or Select is a base defining value. The outer findBasePointer
// algorithm is responsible for constructing a base value for this BDV.
assert((isa<SelectInst>(I) || isa<PHINode>(I)) &&
@@ -634,7 +639,7 @@ static BDVState meetBDVStateImpl(const BDVState &LHS, const BDVState &RHS) {
// Values of type BDVState form a lattice, and this function implements the meet
// operation.
-static BDVState meetBDVState(BDVState LHS, BDVState RHS) {
+static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {
BDVState Result = meetBDVStateImpl(LHS, RHS);
assert(Result == meetBDVStateImpl(RHS, LHS) &&
"Math is wrong: meet does not commute!");
@@ -1123,14 +1128,14 @@ normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent,
// Create new attribute set containing only attributes which can be transferred
// from original call to the safepoint.
-static AttributeSet legalizeCallAttributes(AttributeSet AS) {
- AttributeSet Ret;
+static AttributeList legalizeCallAttributes(AttributeList AS) {
+ AttributeList Ret;
for (unsigned Slot = 0; Slot < AS.getNumSlots(); Slot++) {
unsigned Index = AS.getSlotIndex(Slot);
- if (Index == AttributeSet::ReturnIndex ||
- Index == AttributeSet::FunctionIndex) {
+ if (Index == AttributeList::ReturnIndex ||
+ Index == AttributeList::FunctionIndex) {
for (Attribute Attr : make_range(AS.begin(Slot), AS.end(Slot))) {
@@ -1148,7 +1153,7 @@ static AttributeSet legalizeCallAttributes(AttributeSet AS) {
Ret = Ret.addAttributes(
AS.getContext(), Index,
- AttributeSet::get(AS.getContext(), Index, AttrBuilder(Attr)));
+ AttributeList::get(AS.getContext(), Index, AttrBuilder(Attr)));
}
}
@@ -1299,12 +1304,11 @@ static StringRef getDeoptLowering(CallSite CS) {
const char *DeoptLowering = "deopt-lowering";
if (CS.hasFnAttr(DeoptLowering)) {
// FIXME: CallSite has a *really* confusing interface around attributes
- // with values.
- const AttributeSet &CSAS = CS.getAttributes();
- if (CSAS.hasAttribute(AttributeSet::FunctionIndex,
- DeoptLowering))
- return CSAS.getAttribute(AttributeSet::FunctionIndex,
- DeoptLowering).getValueAsString();
+ // with values.
+ const AttributeList &CSAS = CS.getAttributes();
+ if (CSAS.hasAttribute(AttributeList::FunctionIndex, DeoptLowering))
+ return CSAS.getAttribute(AttributeList::FunctionIndex, DeoptLowering)
+ .getValueAsString();
Function *F = CS.getCalledFunction();
assert(F && F->hasFnAttribute(DeoptLowering));
return F->getFnAttribute(DeoptLowering).getValueAsString();
@@ -1388,7 +1392,6 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */
// Create the statepoint given all the arguments
Instruction *Token = nullptr;
- AttributeSet ReturnAttrs;
if (CS.isCall()) {
CallInst *ToReplace = cast<CallInst>(CS.getInstruction());
CallInst *Call = Builder.CreateGCStatepointCall(
@@ -1400,11 +1403,12 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */
// Currently we will fail on parameter attributes and on certain
// function attributes.
- AttributeSet NewAttrs = legalizeCallAttributes(ToReplace->getAttributes());
+ AttributeList NewAttrs = legalizeCallAttributes(ToReplace->getAttributes());
// In case if we can handle this set of attributes - set up function attrs
// directly on statepoint and return attrs later for gc_result intrinsic.
- Call->setAttributes(NewAttrs.getFnAttributes());
- ReturnAttrs = NewAttrs.getRetAttributes();
+ Call->setAttributes(AttributeList::get(Call->getContext(),
+ AttributeList::FunctionIndex,
+ NewAttrs.getFnAttributes()));
Token = Call;
@@ -1428,11 +1432,12 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */
// Currently we will fail on parameter attributes and on certain
// function attributes.
- AttributeSet NewAttrs = legalizeCallAttributes(ToReplace->getAttributes());
+ AttributeList NewAttrs = legalizeCallAttributes(ToReplace->getAttributes());
// In case if we can handle this set of attributes - set up function attrs
// directly on statepoint and return attrs later for gc_result intrinsic.
- Invoke->setAttributes(NewAttrs.getFnAttributes());
- ReturnAttrs = NewAttrs.getRetAttributes();
+ Invoke->setAttributes(AttributeList::get(Invoke->getContext(),
+ AttributeList::FunctionIndex,
+ NewAttrs.getFnAttributes()));
Token = Invoke;
@@ -1478,7 +1483,9 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */
StringRef Name =
CS.getInstruction()->hasName() ? CS.getInstruction()->getName() : "";
CallInst *GCResult = Builder.CreateGCResult(Token, CS.getType(), Name);
- GCResult->setAttributes(CS.getAttributes().getRetAttributes());
+ GCResult->setAttributes(
+ AttributeList::get(GCResult->getContext(), AttributeList::ReturnIndex,
+ CS.getAttributes().getRetAttributes()));
// We cannot RAUW or delete CS.getInstruction() because it could be in the
// live set of some other safepoint, in which case that safepoint's
@@ -1615,8 +1622,10 @@ static void relocationViaAlloca(
// Emit alloca for "LiveValue" and record it in "allocaMap" and
// "PromotableAllocas"
+ const DataLayout &DL = F.getParent()->getDataLayout();
auto emitAllocaFor = [&](Value *LiveValue) {
- AllocaInst *Alloca = new AllocaInst(LiveValue->getType(), "",
+ AllocaInst *Alloca = new AllocaInst(LiveValue->getType(),
+ DL.getAllocaAddrSpace(), "",
F.getEntryBlock().getFirstNonPHI());
AllocaMap[LiveValue] = Alloca;
PromotableAllocas.push_back(Alloca);
@@ -1873,7 +1882,7 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,
"non noop cast is found during rematerialization");
Type *SrcTy = CI->getOperand(0)->getType();
- Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy);
+ Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, CI);
} else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {
// Cost of the address calculation
@@ -2304,7 +2313,7 @@ static void RemoveNonValidAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH,
if (!R.empty())
AH.setAttributes(AH.getAttributes().removeAttributes(
- Ctx, Index, AttributeSet::get(Ctx, Index, R)));
+ Ctx, Index, AttributeList::get(Ctx, Index, R)));
}
void
@@ -2316,7 +2325,7 @@ RewriteStatepointsForGC::stripNonValidAttributesFromPrototype(Function &F) {
RemoveNonValidAttrAtIndex(Ctx, F, A.getArgNo() + 1);
if (isa<PointerType>(F.getReturnType()))
- RemoveNonValidAttrAtIndex(Ctx, F, AttributeSet::ReturnIndex);
+ RemoveNonValidAttrAtIndex(Ctx, F, AttributeList::ReturnIndex);
}
void RewriteStatepointsForGC::stripNonValidAttributesFromBody(Function &F) {
@@ -2351,7 +2360,7 @@ void RewriteStatepointsForGC::stripNonValidAttributesFromBody(Function &F) {
if (isa<PointerType>(CS.getArgument(i)->getType()))
RemoveNonValidAttrAtIndex(Ctx, CS, i + 1);
if (isa<PointerType>(CS.getType()))
- RemoveNonValidAttrAtIndex(Ctx, CS, AttributeSet::ReturnIndex);
+ RemoveNonValidAttrAtIndex(Ctx, CS, AttributeList::ReturnIndex);
}
}
}
diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp
index ede381c4c2430..8908dae2f5459 100644
--- a/lib/Transforms/Scalar/SCCP.cpp
+++ b/lib/Transforms/Scalar/SCCP.cpp
@@ -140,6 +140,14 @@ public:
return nullptr;
}
+ /// getBlockAddress - If this is a constant with a BlockAddress value, return
+ /// it, otherwise return null.
+ BlockAddress *getBlockAddress() const {
+ if (isConstant())
+ return dyn_cast<BlockAddress>(getConstant());
+ return nullptr;
+ }
+
void markForcedConstant(Constant *V) {
assert(isUnknown() && "Can't force a defined value!");
Val.setInt(forcedconstant);
@@ -306,20 +314,14 @@ public:
return MRVFunctionsTracked;
}
- void markOverdefined(Value *V) {
- assert(!V->getType()->isStructTy() &&
- "structs should use markAnythingOverdefined");
- markOverdefined(ValueState[V], V);
- }
-
- /// markAnythingOverdefined - Mark the specified value overdefined. This
+ /// markOverdefined - Mark the specified value overdefined. This
/// works with both scalars and structs.
- void markAnythingOverdefined(Value *V) {
+ void markOverdefined(Value *V) {
if (auto *STy = dyn_cast<StructType>(V->getType()))
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
markOverdefined(getStructValueState(V, i), V);
else
- markOverdefined(V);
+ markOverdefined(ValueState[V], V);
}
// isStructLatticeConstant - Return true if all the lattice values
@@ -513,12 +515,12 @@ private:
void visitCmpInst(CmpInst &I);
void visitExtractValueInst(ExtractValueInst &EVI);
void visitInsertValueInst(InsertValueInst &IVI);
- void visitLandingPadInst(LandingPadInst &I) { markAnythingOverdefined(&I); }
+ void visitLandingPadInst(LandingPadInst &I) { markOverdefined(&I); }
void visitFuncletPadInst(FuncletPadInst &FPI) {
- markAnythingOverdefined(&FPI);
+ markOverdefined(&FPI);
}
void visitCatchSwitchInst(CatchSwitchInst &CPI) {
- markAnythingOverdefined(&CPI);
+ markOverdefined(&CPI);
visitTerminatorInst(CPI);
}
@@ -538,16 +540,16 @@ private:
void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ }
void visitFenceInst (FenceInst &I) { /*returns void*/ }
void visitAtomicCmpXchgInst(AtomicCmpXchgInst &I) {
- markAnythingOverdefined(&I);
+ markOverdefined(&I);
}
void visitAtomicRMWInst (AtomicRMWInst &I) { markOverdefined(&I); }
void visitAllocaInst (Instruction &I) { markOverdefined(&I); }
- void visitVAArgInst (Instruction &I) { markAnythingOverdefined(&I); }
+ void visitVAArgInst (Instruction &I) { markOverdefined(&I); }
void visitInstruction(Instruction &I) {
// If a new instruction is added to LLVM that we don't handle.
DEBUG(dbgs() << "SCCP: Don't know how to handle: " << I << '\n');
- markAnythingOverdefined(&I); // Just in case
+ markOverdefined(&I); // Just in case
}
};
@@ -602,14 +604,36 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI,
return;
}
- Succs[SI->findCaseValue(CI).getSuccessorIndex()] = true;
+ Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true;
return;
}
- // TODO: This could be improved if the operand is a [cast of a] BlockAddress.
- if (isa<IndirectBrInst>(&TI)) {
- // Just mark all destinations executable!
- Succs.assign(TI.getNumSuccessors(), true);
+  // If this is an indirect branch whose address resolves to a blockaddress,
+  // mark only the target block as executable.
+ if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) {
+ // Casts are folded by visitCastInst.
+ LatticeVal IBRValue = getValueState(IBR->getAddress());
+ BlockAddress *Addr = IBRValue.getBlockAddress();
+ if (!Addr) { // Overdefined or unknown condition?
+ // All destinations are executable!
+ if (!IBRValue.isUnknown())
+ Succs.assign(TI.getNumSuccessors(), true);
+ return;
+ }
+
+ BasicBlock* T = Addr->getBasicBlock();
+ assert(Addr->getFunction() == T->getParent() &&
+           "Block address of a different function?");
+ for (unsigned i = 0; i < IBR->getNumSuccessors(); ++i) {
+ // This is the target.
+ if (IBR->getDestination(i) == T) {
+ Succs[i] = true;
+ return;
+ }
+ }
+
+ // If we didn't find our destination in the IBR successor list, then we
+      // have undefined behavior. It's OK to assume no successor is executable.
return;
}
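
In other words, once the address lattice value resolves to a concrete blockaddress, at most one successor of the indirectbr is feasible, and a missing target means the branch is undefined behavior, so nothing needs to be marked. A standalone sketch of that successor selection (hypothetical types, not the SCCP solver):

  #include <cstddef>
  #include <vector>

  // Mark only the successor that matches the resolved target; if the target
  // is not in the successor list, the branch is UB and nothing is feasible.
  static std::vector<bool> feasibleIndirectSuccs(int Target,
                                                 const std::vector<int> &Succs) {
    std::vector<bool> Feasible(Succs.size(), false);
    for (std::size_t I = 0, E = Succs.size(); I != E; ++I)
      if (Succs[I] == Target) {
        Feasible[I] = true;
        break;
      }
    return Feasible;
  }
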
@@ -659,13 +683,21 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) {
if (!CI)
return !SCValue.isUnknown();
- return SI->findCaseValue(CI).getCaseSuccessor() == To;
+ return SI->findCaseValue(CI)->getCaseSuccessor() == To;
}
- // Just mark all destinations executable!
- // TODO: This could be improved if the operand is a [cast of a] BlockAddress.
- if (isa<IndirectBrInst>(TI))
- return true;
+  // If this is an indirect branch whose address resolves to a blockaddress,
+  // only the edge to that block is feasible.
+ if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) {
+ LatticeVal IBRValue = getValueState(IBR->getAddress());
+ BlockAddress *Addr = IBRValue.getBlockAddress();
+
+ if (!Addr)
+ return !IBRValue.isUnknown();
+
+ // At this point, the indirectbr is branching on a blockaddress.
+ return Addr->getBasicBlock() == To;
+ }
DEBUG(dbgs() << "Unknown terminator instruction: " << *TI << '\n');
llvm_unreachable("SCCP: Don't know how to handle this terminator!");
@@ -693,7 +725,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) {
// If this PN returns a struct, just mark the result overdefined.
// TODO: We could do a lot better than this if code actually uses this.
if (PN.getType()->isStructTy())
- return markAnythingOverdefined(&PN);
+ return markOverdefined(&PN);
if (getValueState(&PN).isOverdefined())
return; // Quick exit
@@ -803,7 +835,7 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {
// If this returns a struct, mark all elements over defined, we don't track
// structs in structs.
if (EVI.getType()->isStructTy())
- return markAnythingOverdefined(&EVI);
+ return markOverdefined(&EVI);
// If this is extracting from more than one level of struct, we don't know.
if (EVI.getNumIndices() != 1)
@@ -828,7 +860,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {
// If this has more than one index, we can't handle it, drive all results to
// undef.
if (IVI.getNumIndices() != 1)
- return markAnythingOverdefined(&IVI);
+ return markOverdefined(&IVI);
Value *Aggr = IVI.getAggregateOperand();
unsigned Idx = *IVI.idx_begin();
@@ -857,7 +889,7 @@ void SCCPSolver::visitSelectInst(SelectInst &I) {
// If this select returns a struct, just mark the result overdefined.
// TODO: We could do a lot better than this if code actually uses this.
if (I.getType()->isStructTy())
- return markAnythingOverdefined(&I);
+ return markOverdefined(&I);
LatticeVal CondValue = getValueState(I.getCondition());
if (CondValue.isUnknown())
@@ -910,9 +942,16 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) {
// Otherwise, one of our operands is overdefined. Try to produce something
// better than overdefined with some tricks.
-
- // If this is an AND or OR with 0 or -1, it doesn't matter that the other
- // operand is overdefined.
+ // If this is 0 / Y, it doesn't matter that the second operand is
+ // overdefined, and we can replace it with zero.
+ if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv)
+ if (V1State.isConstant() && V1State.getConstant()->isNullValue())
+ return markConstant(IV, &I, V1State.getConstant());
+
+ // If this is:
+ // -> AND/MUL with 0
+ // -> OR with -1
+ // it doesn't matter that the other operand is overdefined.
if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul ||
I.getOpcode() == Instruction::Or) {
LatticeVal *NonOverdefVal = nullptr;
@@ -1021,7 +1060,7 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) {
void SCCPSolver::visitLoadInst(LoadInst &I) {
// If this load is of a struct, just mark the result overdefined.
if (I.getType()->isStructTy())
- return markAnythingOverdefined(&I);
+ return markOverdefined(&I);
LatticeVal PtrVal = getValueState(I.getOperand(0));
if (PtrVal.isUnknown()) return; // The pointer is not resolved yet!
@@ -1107,7 +1146,7 @@ CallOverdefined:
}
// Otherwise, we don't know anything about this call, mark it overdefined.
- return markAnythingOverdefined(I);
+ return markOverdefined(I);
}
// If this is a local function that doesn't have its address taken, mark its
@@ -1483,6 +1522,31 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
return true;
}
+ if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) {
+      // Indirect branch with no successors? It's OK to assume it branches
+      // to no target.
+ if (IBR->getNumSuccessors() < 1)
+ continue;
+
+ if (!getValueState(IBR->getAddress()).isUnknown())
+ continue;
+
+      // If the input to SCCP is actually a branch on undef, fix the undef to
+ // the first successor of the indirect branch.
+ if (isa<UndefValue>(IBR->getAddress())) {
+ IBR->setAddress(BlockAddress::get(IBR->getSuccessor(0)));
+ markEdgeExecutable(&BB, IBR->getSuccessor(0));
+ return true;
+ }
+
+      // Otherwise, it is a branch on a symbolic value which is currently
+      // considered to be undef. Handle this by forcing the branch's input
+      // value to the blockaddress of the first successor.
+ markForcedConstant(IBR->getAddress(),
+ BlockAddress::get(IBR->getSuccessor(0)));
+ return true;
+ }
+
if (auto *SI = dyn_cast<SwitchInst>(TI)) {
if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown())
continue;
@@ -1490,12 +1554,12 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
// If the input to SCCP is actually switch on undef, fix the undef to
// the first constant.
if (isa<UndefValue>(SI->getCondition())) {
- SI->setCondition(SI->case_begin().getCaseValue());
- markEdgeExecutable(&BB, SI->case_begin().getCaseSuccessor());
+ SI->setCondition(SI->case_begin()->getCaseValue());
+ markEdgeExecutable(&BB, SI->case_begin()->getCaseSuccessor());
return true;
}
- markForcedConstant(SI->getCondition(), SI->case_begin().getCaseValue());
+ markForcedConstant(SI->getCondition(), SI->case_begin()->getCaseValue());
return true;
}
}
@@ -1545,7 +1609,7 @@ static bool runSCCP(Function &F, const DataLayout &DL,
// Mark all arguments to the function as being overdefined.
for (Argument &AI : F.args())
- Solver.markAnythingOverdefined(&AI);
+ Solver.markOverdefined(&AI);
// Solve for constants.
bool ResolvedUndefs = true;
@@ -1728,7 +1792,7 @@ static bool runIPSCCP(Module &M, const DataLayout &DL,
// Assume nothing about the incoming arguments.
for (Argument &AI : F.args())
- Solver.markAnythingOverdefined(&AI);
+ Solver.markOverdefined(&AI);
}
// Loop over global variables. We inform the solver about any internal global
@@ -1817,32 +1881,9 @@ static bool runIPSCCP(Module &M, const DataLayout &DL,
if (!I) continue;
bool Folded = ConstantFoldTerminator(I->getParent());
- if (!Folded) {
- // The constant folder may not have been able to fold the terminator
- // if this is a branch or switch on undef. Fold it manually as a
- // branch to the first successor.
-#ifndef NDEBUG
- if (auto *BI = dyn_cast<BranchInst>(I)) {
- assert(BI->isConditional() && isa<UndefValue>(BI->getCondition()) &&
- "Branch should be foldable!");
- } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
- assert(isa<UndefValue>(SI->getCondition()) && "Switch should fold");
- } else {
- llvm_unreachable("Didn't fold away reference to block!");
- }
-#endif
-
- // Make this an uncond branch to the first successor.
- TerminatorInst *TI = I->getParent()->getTerminator();
- BranchInst::Create(TI->getSuccessor(0), TI);
-
- // Remove entries in successor phi nodes to remove edges.
- for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i)
- TI->getSuccessor(i)->removePredecessor(TI->getParent());
-
- // Remove the old terminator.
- TI->eraseFromParent();
- }
+ assert(Folded &&
+ "Expect TermInst on constantint or blockaddress to be folded");
+ (void) Folded;
}
// Finally, delete the basic block.
diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp
index bfcb15530ef53..d01e91a7f2356 100644
--- a/lib/Transforms/Scalar/SROA.cpp
+++ b/lib/Transforms/Scalar/SROA.cpp
@@ -1825,6 +1825,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
// Rank the remaining candidate vector types. This is easy because we know
// they're all integer vectors. We sort by ascending number of elements.
auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+ (void)DL;
assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) &&
"Cannot have vector types of different sizes!");
assert(RHSTy->getElementType()->isIntegerTy() &&
@@ -2294,7 +2295,8 @@ private:
#endif
return getAdjustedPtr(IRB, DL, &NewAI,
- APInt(DL.getPointerSizeInBits(), Offset), PointerTy,
+ APInt(DL.getPointerTypeSizeInBits(PointerTy), Offset),
+ PointerTy,
#ifndef NDEBUG
Twine(OldName) + "."
#else
@@ -2369,6 +2371,8 @@ private:
Value *OldOp = LI.getOperand(0);
assert(OldOp == OldPtr);
+ unsigned AS = LI.getPointerAddressSpace();
+
Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8)
: LI.getType();
const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize;
@@ -2387,6 +2391,10 @@ private:
LI.isVolatile(), LI.getName());
if (LI.isVolatile())
NewLI->setAtomic(LI.getOrdering(), LI.getSynchScope());
+
+ // Try to preserve nonnull metadata
+ if (TargetTy->isPointerTy())
+ NewLI->copyMetadata(LI, LLVMContext::MD_nonnull);
V = NewLI;
// If this is an integer load past the end of the slice (which means the
@@ -2401,7 +2409,7 @@ private:
"endian_shift");
}
} else {
- Type *LTy = TargetTy->getPointerTo();
+ Type *LTy = TargetTy->getPointerTo(AS);
LoadInst *NewLI = IRB.CreateAlignedLoad(getNewAllocaSlicePtr(IRB, LTy),
getSliceAlign(TargetTy),
LI.isVolatile(), LI.getName());
@@ -2429,7 +2437,7 @@ private:
// the computed value, and then replace the placeholder with LI, leaving
// LI only used for this computation.
Value *Placeholder =
- new LoadInst(UndefValue::get(LI.getType()->getPointerTo()));
+ new LoadInst(UndefValue::get(LI.getType()->getPointerTo(AS)));
V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset,
"insert");
LI.replaceAllUsesWith(V);
@@ -2542,7 +2550,8 @@ private:
NewSI = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(),
SI.isVolatile());
} else {
- Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo());
+ unsigned AS = SI.getPointerAddressSpace();
+ Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS));
NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()),
SI.isVolatile());
}
@@ -3857,7 +3866,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
if (Alignment <= DL.getABITypeAlignment(SliceTy))
Alignment = 0;
NewAI = new AllocaInst(
- SliceTy, nullptr, Alignment,
+ SliceTy, AI.getType()->getAddressSpace(), nullptr, Alignment,
AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI);
++NumNewAllocas;
}
@@ -4184,7 +4193,7 @@ bool SROA::promoteAllocas(Function &F) {
NumPromoted += PromotableAllocas.size();
DEBUG(dbgs() << "Promoting allocas with mem2reg...\n");
- PromoteMemToReg(PromotableAllocas, *DT, nullptr, AC);
+ PromoteMemToReg(PromotableAllocas, *DT, AC);
PromotableAllocas.clear();
return true;
}
@@ -4234,9 +4243,8 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT,
if (!Changed)
return PreservedAnalyses::all();
- // FIXME: Even when promoting allocas we should preserve some abstract set of
- // CFG-specific analyses.
PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
return PA;
}
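The recurring theme in the SROA hunks above is carrying the original address space into every pointer type and alloca the rewriter creates. A short sketch of the idea, assuming the usual LLVM headers; the helper name is illustrative and not part of the patch.

    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/Type.h"
    using namespace llvm;

    // Build the pointer type for a rewritten slice access. Taking the address
    // space from the original load matters on targets (e.g. GPUs) where it is
    // not 0; the old code effectively used getPointerTo() with address space 0.
    static Type *slicePointerType(LoadInst &LI, Type *TargetTy) {
      unsigned AS = LI.getPointerAddressSpace();
      return TargetTy->getPointerTo(AS);
    }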
diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp
index afe7483006ae6..00e3c95f6f06d 100644
--- a/lib/Transforms/Scalar/Scalar.cpp
+++ b/lib/Transforms/Scalar/Scalar.cpp
@@ -43,13 +43,14 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeDSELegacyPassPass(Registry);
initializeGuardWideningLegacyPassPass(Registry);
initializeGVNLegacyPassPass(Registry);
- initializeNewGVNPass(Registry);
+ initializeNewGVNLegacyPassPass(Registry);
initializeEarlyCSELegacyPassPass(Registry);
initializeEarlyCSEMemSSALegacyPassPass(Registry);
initializeGVNHoistLegacyPassPass(Registry);
initializeFlattenCFGPassPass(Registry);
initializeInductiveRangeCheckEliminationPass(Registry);
initializeIndVarSimplifyLegacyPassPass(Registry);
+ initializeInferAddressSpacesPass(Registry);
initializeJumpThreadingPass(Registry);
initializeLegacyLICMPassPass(Registry);
initializeLegacyLoopSinkPassPass(Registry);
@@ -58,6 +59,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeLoopAccessLegacyAnalysisPass(Registry);
initializeLoopInstSimplifyLegacyPassPass(Registry);
initializeLoopInterchangePass(Registry);
+ initializeLoopPredicationLegacyPassPass(Registry);
initializeLoopRotateLegacyPassPass(Registry);
initializeLoopStrengthReducePass(Registry);
initializeLoopRerollPass(Registry);
@@ -79,6 +81,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeIPSCCPLegacyPassPass(Registry);
initializeSROALegacyPassPass(Registry);
initializeCFGSimplifyPassPass(Registry);
+ initializeLateCFGSimplifyPassPass(Registry);
initializeStructurizeCFGPass(Registry);
initializeSinkingLegacyPassPass(Registry);
initializeTailCallElimPass(Registry);
@@ -115,6 +118,10 @@ void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createCFGSimplificationPass());
}
+void LLVMAddLateCFGSimplificationPass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createLateCFGSimplificationPass());
+}
+
void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createDeadStoreEliminationPass());
}
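For context, a sketch of how the newly registered creator would be used from C++ with the legacy pass manager. It assumes createLateCFGSimplificationPass is declared alongside createCFGSimplificationPass in llvm/Transforms/Scalar.h, which this listing does not show.

    #include "llvm/IR/LegacyPassManager.h"
    #include "llvm/Transforms/Scalar.h"
    using namespace llvm;

    // Pick the ordinary or the more aggressive "late" CFG simplification,
    // mirroring what the new LLVMAddLateCFGSimplificationPass C binding does.
    static void addCFGCleanup(legacy::FunctionPassManager &FPM, bool Late) {
      FPM.add(Late ? createLateCFGSimplificationPass()
                   : createCFGSimplificationPass());
    }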
diff --git a/lib/Transforms/Scalar/Scalarizer.cpp b/lib/Transforms/Scalar/Scalarizer.cpp
index 39969e27367f6..c0c09a7e43fe9 100644
--- a/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/lib/Transforms/Scalar/Scalarizer.cpp
@@ -520,12 +520,25 @@ bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
unsigned NumElems = VT->getNumElements();
unsigned NumIndices = GEPI.getNumIndices();
- Scatterer Base = scatter(&GEPI, GEPI.getOperand(0));
+ // The base pointer might be scalar even if it's a vector GEP. In those cases,
+ // splat the pointer into a vector value, and scatter that vector.
+ Value *Op0 = GEPI.getOperand(0);
+ if (!Op0->getType()->isVectorTy())
+ Op0 = Builder.CreateVectorSplat(NumElems, Op0);
+ Scatterer Base = scatter(&GEPI, Op0);
SmallVector<Scatterer, 8> Ops;
Ops.resize(NumIndices);
- for (unsigned I = 0; I < NumIndices; ++I)
- Ops[I] = scatter(&GEPI, GEPI.getOperand(I + 1));
+ for (unsigned I = 0; I < NumIndices; ++I) {
+ Value *Op = GEPI.getOperand(I + 1);
+
+ // The indices might be scalars even if it's a vector GEP. In those cases,
+ // splat the scalar into a vector value, and scatter that vector.
+ if (!Op->getType()->isVectorTy())
+ Op = Builder.CreateVectorSplat(NumElems, Op);
+
+ Ops[I] = scatter(&GEPI, Op);
+ }
ValueVector Res;
Res.resize(NumElems);
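The splatting in the hunk above boils down to a single IRBuilder call. A minimal sketch, assuming standard LLVM headers; the helper name is illustrative.

    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // A vector GEP may have scalar operands; give every lane the same value by
    // splatting the scalar, so the per-lane scatter sees one value per element.
    static Value *splatIfScalar(IRBuilder<> &B, Value *Op, unsigned NumElems) {
      if (Op->getType()->isVectorTy())
        return Op;
      return B.CreateVectorSplat(NumElems, Op);
    }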
diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp
index f2723bd7af82e..8754c714c5b28 100644
--- a/lib/Transforms/Scalar/SimplifyCFGPass.cpp
+++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -130,7 +130,8 @@ static bool mergeEmptyReturnBlocks(Function &F) {
/// iterating until no more changes are made.
static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
AssumptionCache *AC,
- unsigned BonusInstThreshold) {
+ unsigned BonusInstThreshold,
+ bool LateSimplifyCFG) {
bool Changed = false;
bool LocalChange = true;
@@ -145,7 +146,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
// Loop over all of the basic blocks and remove them if they are unneeded.
for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) {
- if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders)) {
+ if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders, LateSimplifyCFG)) {
LocalChange = true;
++NumSimpl;
}
@@ -156,10 +157,12 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
}
static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,
- AssumptionCache *AC, int BonusInstThreshold) {
+ AssumptionCache *AC, int BonusInstThreshold,
+ bool LateSimplifyCFG) {
bool EverChanged = removeUnreachableBlocks(F);
EverChanged |= mergeEmptyReturnBlocks(F);
- EverChanged |= iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold);
+ EverChanged |= iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold,
+ LateSimplifyCFG);
// If neither pass changed anything, we're done.
if (!EverChanged) return false;
@@ -173,7 +176,8 @@ static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,
return true;
do {
- EverChanged = iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold);
+ EverChanged = iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold,
+ LateSimplifyCFG);
EverChanged |= removeUnreachableBlocks(F);
} while (EverChanged);
@@ -181,17 +185,19 @@ static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,
}
SimplifyCFGPass::SimplifyCFGPass()
- : BonusInstThreshold(UserBonusInstThreshold) {}
+ : BonusInstThreshold(UserBonusInstThreshold),
+ LateSimplifyCFG(true) {}
-SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold)
- : BonusInstThreshold(BonusInstThreshold) {}
+SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold, bool LateSimplifyCFG)
+ : BonusInstThreshold(BonusInstThreshold),
+ LateSimplifyCFG(LateSimplifyCFG) {}
PreservedAnalyses SimplifyCFGPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
- if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold))
+ if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold, LateSimplifyCFG))
return PreservedAnalyses::all();
PreservedAnalyses PA;
PA.preserve<GlobalsAA>();
@@ -199,16 +205,17 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F,
}
namespace {
-struct CFGSimplifyPass : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
+struct BaseCFGSimplifyPass : public FunctionPass {
unsigned BonusInstThreshold;
std::function<bool(const Function &)> PredicateFtor;
+ bool LateSimplifyCFG;
- CFGSimplifyPass(int T = -1,
- std::function<bool(const Function &)> Ftor = nullptr)
- : FunctionPass(ID), PredicateFtor(std::move(Ftor)) {
+ BaseCFGSimplifyPass(int T, bool LateSimplifyCFG,
+ std::function<bool(const Function &)> Ftor,
+ char &ID)
+ : FunctionPass(ID), PredicateFtor(std::move(Ftor)),
+ LateSimplifyCFG(LateSimplifyCFG) {
BonusInstThreshold = (T == -1) ? UserBonusInstThreshold : unsigned(T);
- initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override {
if (skipFunction(F) || (PredicateFtor && !PredicateFtor(F)))
@@ -218,7 +225,7 @@ struct CFGSimplifyPass : public FunctionPass {
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
const TargetTransformInfo &TTI =
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- return simplifyFunctionCFG(F, TTI, AC, BonusInstThreshold);
+ return simplifyFunctionCFG(F, TTI, AC, BonusInstThreshold, LateSimplifyCFG);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -227,6 +234,26 @@ struct CFGSimplifyPass : public FunctionPass {
AU.addPreserved<GlobalsAAWrapperPass>();
}
};
+
+struct CFGSimplifyPass : public BaseCFGSimplifyPass {
+ static char ID; // Pass identification, replacement for typeid
+
+ CFGSimplifyPass(int T = -1,
+ std::function<bool(const Function &)> Ftor = nullptr)
+ : BaseCFGSimplifyPass(T, false, Ftor, ID) {
+ initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry());
+ }
+};
+
+struct LateCFGSimplifyPass : public BaseCFGSimplifyPass {
+ static char ID; // Pass identification, replacement for typeid
+
+ LateCFGSimplifyPass(int T = -1,
+ std::function<bool(const Function &)> Ftor = nullptr)
+ : BaseCFGSimplifyPass(T, true, Ftor, ID) {
+ initializeLateCFGSimplifyPassPass(*PassRegistry::getPassRegistry());
+ }
+};
}
char CFGSimplifyPass::ID = 0;
@@ -237,9 +264,24 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,
false)
+char LateCFGSimplifyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LateCFGSimplifyPass, "latesimplifycfg",
+ "Simplify the CFG more aggressively", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_END(LateCFGSimplifyPass, "latesimplifycfg",
+ "Simplify the CFG more aggressively", false, false)
+
// Public interface to the CFGSimplification pass
FunctionPass *
llvm::createCFGSimplificationPass(int Threshold,
- std::function<bool(const Function &)> Ftor) {
+ std::function<bool(const Function &)> Ftor) {
return new CFGSimplifyPass(Threshold, std::move(Ftor));
}
+
+// Public interface to the LateCFGSimplification pass
+FunctionPass *
+llvm::createLateCFGSimplificationPass(int Threshold,
+ std::function<bool(const Function &)> Ftor) {
+ return new LateCFGSimplifyPass(Threshold, std::move(Ftor));
+}
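On the new pass manager side, the extra constructor argument shown above lets a pipeline request the aggressive behaviour explicitly. A hedged usage sketch, assuming the two-argument constructor is exposed through llvm/Transforms/Scalar/SimplifyCFG.h.

    #include "llvm/IR/PassManager.h"
    #include "llvm/Transforms/Scalar/SimplifyCFG.h"
    using namespace llvm;

    // Add the "late" variant to a new-PM function pipeline. Note that the
    // zero-argument constructor above defaults LateSimplifyCFG to true; this
    // helper simply makes the choice explicit.
    static void addLateSimplify(FunctionPassManager &FPM, int Threshold) {
      FPM.addPass(SimplifyCFGPass(Threshold, /*LateSimplifyCFG=*/true));
    }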
diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp
index c3f14a0f4b1e5..102e9eaeab772 100644
--- a/lib/Transforms/Scalar/Sink.cpp
+++ b/lib/Transforms/Scalar/Sink.cpp
@@ -164,13 +164,14 @@ static bool SinkInstruction(Instruction *Inst,
// Instructions can only be sunk if all their uses are in blocks
// dominated by one of the successors.
- // Look at all the postdominators and see if we can sink it in one.
+ // Look at all the dominated blocks and see if we can sink it in one.
DomTreeNode *DTN = DT.getNode(Inst->getParent());
for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end();
I != E && SuccToSinkTo == nullptr; ++I) {
BasicBlock *Candidate = (*I)->getBlock();
- if ((*I)->getIDom()->getBlock() == Inst->getParent() &&
- IsAcceptableTarget(Inst, Candidate, DT, LI))
+ // A node always immediately dominates its children in the dominator
+ // tree.
+ if (IsAcceptableTarget(Inst, Candidate, DT, LI))
SuccToSinkTo = Candidate;
}
@@ -262,9 +263,8 @@ PreservedAnalyses SinkingPass::run(Function &F, FunctionAnalysisManager &AM) {
if (!iterativelySinkInstructions(F, DT, LI, AA))
return PreservedAnalyses::all();
- auto PA = PreservedAnalyses();
- PA.preserve<DominatorTreeAnalysis>();
- PA.preserve<LoopAnalysis>();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
return PA;
}
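The justification in the hunk above is that every child of a dominator-tree node is by construction immediately dominated by it, so the removed getIDom() comparison was always true. A small sketch of walking those children, assuming standard LLVM headers; the counting helper is illustrative.

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Dominators.h"
    using namespace llvm;

    // Enumerate the blocks immediately dominated by BB; these are exactly the
    // candidate sink targets the loop above iterates over.
    static unsigned countSinkCandidates(DominatorTree &DT, BasicBlock *BB) {
      unsigned N = 0;
      for (DomTreeNode *Child : *DT.getNode(BB))
        if (Child->getBlock())
          ++N;
      return N;
    }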
diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp
index 2e95926c0b3f5..4c9746b8c691e 100644
--- a/lib/Transforms/Utils/AddDiscriminators.cpp
+++ b/lib/Transforms/Utils/AddDiscriminators.cpp
@@ -102,6 +102,10 @@ FunctionPass *llvm::createAddDiscriminatorsPass() {
return new AddDiscriminatorsLegacyPass();
}
+static bool shouldHaveDiscriminator(const Instruction *I) {
+ return !isa<IntrinsicInst>(I) || isa<MemIntrinsic>(I);
+}
+
/// \brief Assign DWARF discriminators.
///
/// To assign discriminators, we examine the boundaries of every
@@ -176,7 +180,13 @@ static bool addDiscriminators(Function &F) {
// discriminator for this instruction.
for (BasicBlock &B : F) {
for (auto &I : B.getInstList()) {
- if (isa<IntrinsicInst>(&I))
+ // Not all intrinsic calls should have a discriminator.
+ // We want to avoid a non-deterministic assignment of discriminators at
+ // different debug levels. We still allow discriminators on memory
+ // intrinsic calls because those can be early expanded by SROA into
+ // pairs of loads and stores, and the expanded load/store instructions
+ // should have a valid discriminator.
+ if (!shouldHaveDiscriminator(&I))
continue;
const DILocation *DIL = I.getDebugLoc();
if (!DIL)
@@ -190,8 +200,8 @@ static bool addDiscriminators(Function &F) {
// discriminator is needed to distinguish both instructions.
// Only the lowest 7 bits are used to represent a discriminator to fit
// it in 1 byte ULEB128 representation.
- unsigned Discriminator = (R.second ? ++LDM[L] : LDM[L]) & 0x7f;
- I.setDebugLoc(DIL->cloneWithDiscriminator(Discriminator));
+ unsigned Discriminator = R.second ? ++LDM[L] : LDM[L];
+ I.setDebugLoc(DIL->setBaseDiscriminator(Discriminator));
DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
<< DIL->getColumn() << ":" << Discriminator << " " << I
<< "\n");
@@ -207,6 +217,10 @@ static bool addDiscriminators(Function &F) {
LocationSet CallLocations;
for (auto &I : B.getInstList()) {
CallInst *Current = dyn_cast<CallInst>(&I);
+ // We bypass intrinsic calls for the following two reasons:
+ // 1) We want to avoid a non-deterministic assignment of
+ // discriminators.
+ // 2) We want to minimize the number of base discriminators used.
if (!Current || isa<IntrinsicInst>(&I))
continue;
@@ -216,8 +230,8 @@ static bool addDiscriminators(Function &F) {
Location L =
std::make_pair(CurrentDIL->getFilename(), CurrentDIL->getLine());
if (!CallLocations.insert(L).second) {
- Current->setDebugLoc(
- CurrentDIL->cloneWithDiscriminator((++LDM[L]) & 0x7f));
+ unsigned Discriminator = ++LDM[L];
+ Current->setDebugLoc(CurrentDIL->setBaseDiscriminator(Discriminator));
Changed = true;
}
}
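Taken together, the two hunks apply the same filter: ordinary intrinsics are skipped, while memory intrinsics keep a discriminator because SROA can lower llvm.memcpy and friends into plain loads and stores that inherit the debug location. A sketch of the predicate in use; the counting helper is made up for illustration.

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/IntrinsicInst.h"
    using namespace llvm;

    // Count the instructions in a block that would receive a discriminator
    // under the rule introduced above.
    static unsigned countDiscriminatorCandidates(BasicBlock &BB) {
      unsigned N = 0;
      for (Instruction &I : BB)
        if (!isa<IntrinsicInst>(&I) || isa<MemIntrinsic>(&I))
          ++N;
      return N;
    }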
diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp
index b90349d3cdad1..22af21d55c019 100644
--- a/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -438,7 +438,7 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB,
// The new block unconditionally branches to the old block.
BranchInst *BI = BranchInst::Create(BB, NewBB);
- BI->setDebugLoc(BB->getFirstNonPHI()->getDebugLoc());
+ BI->setDebugLoc(BB->getFirstNonPHIOrDbg()->getDebugLoc());
// Move the edges from Preds to point to NewBB instead of BB.
for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
@@ -646,9 +646,10 @@ llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore,
}
if (LI) {
- Loop *L = LI->getLoopFor(Head);
- L->addBasicBlockToLoop(ThenBlock, *LI);
- L->addBasicBlockToLoop(Tail, *LI);
+ if (Loop *L = LI->getLoopFor(Head)) {
+ L->addBasicBlockToLoop(ThenBlock, *LI);
+ L->addBasicBlockToLoop(Tail, *LI);
+ }
}
return CheckTerm;
diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp
index e61b04fbdd574..6cd9f1614991a 100644
--- a/lib/Transforms/Utils/BuildLibCalls.cpp
+++ b/lib/Transforms/Utils/BuildLibCalls.cpp
@@ -96,9 +96,9 @@ static bool setDoesNotAlias(Function &F, unsigned n) {
}
static bool setNonNull(Function &F, unsigned n) {
- assert((n != AttributeSet::ReturnIndex ||
- F.getReturnType()->isPointerTy()) &&
- "nonnull applies only to pointers");
+ assert(
+ (n != AttributeList::ReturnIndex || F.getReturnType()->isPointerTy()) &&
+ "nonnull applies only to pointers");
if (F.getAttributes().hasAttribute(n, Attribute::NonNull))
return false;
F.addAttribute(n, Attribute::NonNull);
@@ -107,255 +107,255 @@ static bool setNonNull(Function &F, unsigned n) {
}
bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
- LibFunc::Func TheLibFunc;
+ LibFunc TheLibFunc;
if (!(TLI.getLibFunc(F, TheLibFunc) && TLI.has(TheLibFunc)))
return false;
bool Changed = false;
switch (TheLibFunc) {
- case LibFunc::strlen:
+ case LibFunc_strlen:
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::strchr:
- case LibFunc::strrchr:
+ case LibFunc_strchr:
+ case LibFunc_strrchr:
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotThrow(F);
return Changed;
- case LibFunc::strtol:
- case LibFunc::strtod:
- case LibFunc::strtof:
- case LibFunc::strtoul:
- case LibFunc::strtoll:
- case LibFunc::strtold:
- case LibFunc::strtoull:
+ case LibFunc_strtol:
+ case LibFunc_strtod:
+ case LibFunc_strtof:
+ case LibFunc_strtoul:
+ case LibFunc_strtoll:
+ case LibFunc_strtold:
+ case LibFunc_strtoull:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::strcpy:
- case LibFunc::stpcpy:
- case LibFunc::strcat:
- case LibFunc::strncat:
- case LibFunc::strncpy:
- case LibFunc::stpncpy:
+ case LibFunc_strcpy:
+ case LibFunc_stpcpy:
+ case LibFunc_strcat:
+ case LibFunc_strncat:
+ case LibFunc_strncpy:
+ case LibFunc_stpncpy:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::strxfrm:
+ case LibFunc_strxfrm:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::strcmp: // 0,1
- case LibFunc::strspn: // 0,1
- case LibFunc::strncmp: // 0,1
- case LibFunc::strcspn: // 0,1
- case LibFunc::strcoll: // 0,1
- case LibFunc::strcasecmp: // 0,1
- case LibFunc::strncasecmp: //
+ case LibFunc_strcmp: // 0,1
+ case LibFunc_strspn: // 0,1
+ case LibFunc_strncmp: // 0,1
+ case LibFunc_strcspn: // 0,1
+ case LibFunc_strcoll: // 0,1
+ case LibFunc_strcasecmp: // 0,1
+ case LibFunc_strncasecmp: //
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::strstr:
- case LibFunc::strpbrk:
+ case LibFunc_strstr:
+ case LibFunc_strpbrk:
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::strtok:
- case LibFunc::strtok_r:
+ case LibFunc_strtok:
+ case LibFunc_strtok_r:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::scanf:
+ case LibFunc_scanf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::setbuf:
- case LibFunc::setvbuf:
+ case LibFunc_setbuf:
+ case LibFunc_setvbuf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::strdup:
- case LibFunc::strndup:
+ case LibFunc_strdup:
+ case LibFunc_strndup:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::stat:
- case LibFunc::statvfs:
+ case LibFunc_stat:
+ case LibFunc_statvfs:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::sscanf:
+ case LibFunc_sscanf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::sprintf:
+ case LibFunc_sprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::snprintf:
+ case LibFunc_snprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 3);
Changed |= setOnlyReadsMemory(F, 3);
return Changed;
- case LibFunc::setitimer:
+ case LibFunc_setitimer:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
Changed |= setDoesNotCapture(F, 3);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::system:
+ case LibFunc_system:
// May throw; "system" is a valid pthread cancellation point.
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::malloc:
+ case LibFunc_malloc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
return Changed;
- case LibFunc::memcmp:
+ case LibFunc_memcmp:
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::memchr:
- case LibFunc::memrchr:
+ case LibFunc_memchr:
+ case LibFunc_memrchr:
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotThrow(F);
return Changed;
- case LibFunc::modf:
- case LibFunc::modff:
- case LibFunc::modfl:
+ case LibFunc_modf:
+ case LibFunc_modff:
+ case LibFunc_modfl:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::memcpy:
- case LibFunc::mempcpy:
- case LibFunc::memccpy:
- case LibFunc::memmove:
+ case LibFunc_memcpy:
+ case LibFunc_mempcpy:
+ case LibFunc_memccpy:
+ case LibFunc_memmove:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::memcpy_chk:
+ case LibFunc_memcpy_chk:
Changed |= setDoesNotThrow(F);
return Changed;
- case LibFunc::memalign:
+ case LibFunc_memalign:
Changed |= setDoesNotAlias(F, 0);
return Changed;
- case LibFunc::mkdir:
+ case LibFunc_mkdir:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::mktime:
+ case LibFunc_mktime:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::realloc:
+ case LibFunc_realloc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::read:
+ case LibFunc_read:
// May throw; "read" is a valid pthread cancellation point.
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::rewind:
+ case LibFunc_rewind:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::rmdir:
- case LibFunc::remove:
- case LibFunc::realpath:
+ case LibFunc_rmdir:
+ case LibFunc_remove:
+ case LibFunc_realpath:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::rename:
+ case LibFunc_rename:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::readlink:
+ case LibFunc_readlink:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::write:
+ case LibFunc_write:
// May throw; "write" is a valid pthread cancellation point.
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::bcopy:
+ case LibFunc_bcopy:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::bcmp:
+ case LibFunc_bcmp:
Changed |= setDoesNotThrow(F);
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::bzero:
+ case LibFunc_bzero:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::calloc:
+ case LibFunc_calloc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
return Changed;
- case LibFunc::chmod:
- case LibFunc::chown:
+ case LibFunc_chmod:
+ case LibFunc_chown:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::ctermid:
- case LibFunc::clearerr:
- case LibFunc::closedir:
+ case LibFunc_ctermid:
+ case LibFunc_clearerr:
+ case LibFunc_closedir:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::atoi:
- case LibFunc::atol:
- case LibFunc::atof:
- case LibFunc::atoll:
+ case LibFunc_atoi:
+ case LibFunc_atol:
+ case LibFunc_atof:
+ case LibFunc_atoll:
Changed |= setDoesNotThrow(F);
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::access:
+ case LibFunc_access:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::fopen:
+ case LibFunc_fopen:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
@@ -363,150 +363,150 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::fdopen:
+ case LibFunc_fdopen:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::feof:
- case LibFunc::free:
- case LibFunc::fseek:
- case LibFunc::ftell:
- case LibFunc::fgetc:
- case LibFunc::fseeko:
- case LibFunc::ftello:
- case LibFunc::fileno:
- case LibFunc::fflush:
- case LibFunc::fclose:
- case LibFunc::fsetpos:
- case LibFunc::flockfile:
- case LibFunc::funlockfile:
- case LibFunc::ftrylockfile:
+ case LibFunc_feof:
+ case LibFunc_free:
+ case LibFunc_fseek:
+ case LibFunc_ftell:
+ case LibFunc_fgetc:
+ case LibFunc_fseeko:
+ case LibFunc_ftello:
+ case LibFunc_fileno:
+ case LibFunc_fflush:
+ case LibFunc_fclose:
+ case LibFunc_fsetpos:
+ case LibFunc_flockfile:
+ case LibFunc_funlockfile:
+ case LibFunc_ftrylockfile:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::ferror:
+ case LibFunc_ferror:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F);
return Changed;
- case LibFunc::fputc:
- case LibFunc::fstat:
- case LibFunc::frexp:
- case LibFunc::frexpf:
- case LibFunc::frexpl:
- case LibFunc::fstatvfs:
+ case LibFunc_fputc:
+ case LibFunc_fstat:
+ case LibFunc_frexp:
+ case LibFunc_frexpf:
+ case LibFunc_frexpl:
+ case LibFunc_fstatvfs:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::fgets:
+ case LibFunc_fgets:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 3);
return Changed;
- case LibFunc::fread:
+ case LibFunc_fread:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 4);
return Changed;
- case LibFunc::fwrite:
+ case LibFunc_fwrite:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 4);
// FIXME: readonly #1?
return Changed;
- case LibFunc::fputs:
+ case LibFunc_fputs:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::fscanf:
- case LibFunc::fprintf:
+ case LibFunc_fscanf:
+ case LibFunc_fprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::fgetpos:
+ case LibFunc_fgetpos:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::getc:
- case LibFunc::getlogin_r:
- case LibFunc::getc_unlocked:
+ case LibFunc_getc:
+ case LibFunc_getlogin_r:
+ case LibFunc_getc_unlocked:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::getenv:
+ case LibFunc_getenv:
Changed |= setDoesNotThrow(F);
Changed |= setOnlyReadsMemory(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::gets:
- case LibFunc::getchar:
+ case LibFunc_gets:
+ case LibFunc_getchar:
Changed |= setDoesNotThrow(F);
return Changed;
- case LibFunc::getitimer:
+ case LibFunc_getitimer:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::getpwnam:
+ case LibFunc_getpwnam:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::ungetc:
+ case LibFunc_ungetc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::uname:
+ case LibFunc_uname:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::unlink:
+ case LibFunc_unlink:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::unsetenv:
+ case LibFunc_unsetenv:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::utime:
- case LibFunc::utimes:
+ case LibFunc_utime:
+ case LibFunc_utimes:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::putc:
+ case LibFunc_putc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::puts:
- case LibFunc::printf:
- case LibFunc::perror:
+ case LibFunc_puts:
+ case LibFunc_printf:
+ case LibFunc_perror:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::pread:
+ case LibFunc_pread:
// May throw; "pread" is a valid pthread cancellation point.
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::pwrite:
+ case LibFunc_pwrite:
// May throw; "pwrite" is a valid pthread cancellation point.
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::putchar:
+ case LibFunc_putchar:
Changed |= setDoesNotThrow(F);
return Changed;
- case LibFunc::popen:
+ case LibFunc_popen:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
@@ -514,132 +514,132 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::pclose:
+ case LibFunc_pclose:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::vscanf:
+ case LibFunc_vscanf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::vsscanf:
+ case LibFunc_vsscanf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::vfscanf:
+ case LibFunc_vfscanf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::valloc:
+ case LibFunc_valloc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
return Changed;
- case LibFunc::vprintf:
+ case LibFunc_vprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::vfprintf:
- case LibFunc::vsprintf:
+ case LibFunc_vfprintf:
+ case LibFunc_vsprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::vsnprintf:
+ case LibFunc_vsnprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 3);
Changed |= setOnlyReadsMemory(F, 3);
return Changed;
- case LibFunc::open:
+ case LibFunc_open:
// May throw; "open" is a valid pthread cancellation point.
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::opendir:
+ case LibFunc_opendir:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::tmpfile:
+ case LibFunc_tmpfile:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
return Changed;
- case LibFunc::times:
+ case LibFunc_times:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::htonl:
- case LibFunc::htons:
- case LibFunc::ntohl:
- case LibFunc::ntohs:
+ case LibFunc_htonl:
+ case LibFunc_htons:
+ case LibFunc_ntohl:
+ case LibFunc_ntohs:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAccessMemory(F);
return Changed;
- case LibFunc::lstat:
+ case LibFunc_lstat:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::lchown:
+ case LibFunc_lchown:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::qsort:
+ case LibFunc_qsort:
// May throw; places call through function pointer.
Changed |= setDoesNotCapture(F, 4);
return Changed;
- case LibFunc::dunder_strdup:
- case LibFunc::dunder_strndup:
+ case LibFunc_dunder_strdup:
+ case LibFunc_dunder_strndup:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::dunder_strtok_r:
+ case LibFunc_dunder_strtok_r:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::under_IO_getc:
+ case LibFunc_under_IO_getc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::under_IO_putc:
+ case LibFunc_under_IO_putc:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::dunder_isoc99_scanf:
+ case LibFunc_dunder_isoc99_scanf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::stat64:
- case LibFunc::lstat64:
- case LibFunc::statvfs64:
+ case LibFunc_stat64:
+ case LibFunc_lstat64:
+ case LibFunc_statvfs64:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::dunder_isoc99_sscanf:
+ case LibFunc_dunder_isoc99_sscanf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::fopen64:
+ case LibFunc_fopen64:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
@@ -647,26 +647,26 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
Changed |= setOnlyReadsMemory(F, 1);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
- case LibFunc::fseeko64:
- case LibFunc::ftello64:
+ case LibFunc_fseeko64:
+ case LibFunc_ftello64:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
return Changed;
- case LibFunc::tmpfile64:
+ case LibFunc_tmpfile64:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotAlias(F, 0);
return Changed;
- case LibFunc::fstat64:
- case LibFunc::fstatvfs64:
+ case LibFunc_fstat64:
+ case LibFunc_fstatvfs64:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::open64:
+ case LibFunc_open64:
// May throw; "open" is a valid pthread cancellation point.
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
- case LibFunc::gettimeofday:
+ case LibFunc_gettimeofday:
// Currently some platforms have the restrict keyword on the arguments to
// gettimeofday. To be conservative, do not add noalias to gettimeofday's
// arguments.
@@ -674,29 +674,29 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
return Changed;
- case LibFunc::Znwj: // new(unsigned int)
- case LibFunc::Znwm: // new(unsigned long)
- case LibFunc::Znaj: // new[](unsigned int)
- case LibFunc::Znam: // new[](unsigned long)
- case LibFunc::msvc_new_int: // new(unsigned int)
- case LibFunc::msvc_new_longlong: // new(unsigned long long)
- case LibFunc::msvc_new_array_int: // new[](unsigned int)
- case LibFunc::msvc_new_array_longlong: // new[](unsigned long long)
+ case LibFunc_Znwj: // new(unsigned int)
+ case LibFunc_Znwm: // new(unsigned long)
+ case LibFunc_Znaj: // new[](unsigned int)
+ case LibFunc_Znam: // new[](unsigned long)
+ case LibFunc_msvc_new_int: // new(unsigned int)
+ case LibFunc_msvc_new_longlong: // new(unsigned long long)
+ case LibFunc_msvc_new_array_int: // new[](unsigned int)
+ case LibFunc_msvc_new_array_longlong: // new[](unsigned long long)
// Operator new always returns a nonnull noalias pointer
- Changed |= setNonNull(F, AttributeSet::ReturnIndex);
- Changed |= setDoesNotAlias(F, AttributeSet::ReturnIndex);
+ Changed |= setNonNull(F, AttributeList::ReturnIndex);
+ Changed |= setDoesNotAlias(F, AttributeList::ReturnIndex);
return Changed;
//TODO: add LibFunc entries for:
- //case LibFunc::memset_pattern4:
- //case LibFunc::memset_pattern8:
- case LibFunc::memset_pattern16:
+ //case LibFunc_memset_pattern4:
+ //case LibFunc_memset_pattern8:
+ case LibFunc_memset_pattern16:
Changed |= setOnlyAccessesArgMemory(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
// int __nvvm_reflect(const char *)
- case LibFunc::nvvm_reflect:
+ case LibFunc_nvvm_reflect:
Changed |= setDoesNotAccessMemory(F);
Changed |= setDoesNotThrow(F);
return Changed;
@@ -717,13 +717,13 @@ Value *llvm::castToCStr(Value *V, IRBuilder<> &B) {
Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL,
const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::strlen))
+ if (!TLI->has(LibFunc_strlen))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
LLVMContext &Context = B.GetInsertBlock()->getContext();
Constant *StrLen = M->getOrInsertFunction("strlen", DL.getIntPtrType(Context),
- B.getInt8PtrTy(), nullptr);
+ B.getInt8PtrTy());
inferLibFuncAttributes(*M->getFunction("strlen"), *TLI);
CallInst *CI = B.CreateCall(StrLen, castToCStr(Ptr, B), "strlen");
if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts()))
@@ -734,14 +734,14 @@ Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL,
Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B,
const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::strchr))
+ if (!TLI->has(LibFunc_strchr))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
Type *I8Ptr = B.getInt8PtrTy();
Type *I32Ty = B.getInt32Ty();
Constant *StrChr =
- M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty, nullptr);
+ M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty);
inferLibFuncAttributes(*M->getFunction("strchr"), *TLI);
CallInst *CI = B.CreateCall(
StrChr, {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, "strchr");
@@ -752,14 +752,14 @@ Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B,
Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::strncmp))
+ if (!TLI->has(LibFunc_strncmp))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
LLVMContext &Context = B.GetInsertBlock()->getContext();
Value *StrNCmp = M->getOrInsertFunction("strncmp", B.getInt32Ty(),
B.getInt8PtrTy(), B.getInt8PtrTy(),
- DL.getIntPtrType(Context), nullptr);
+ DL.getIntPtrType(Context));
inferLibFuncAttributes(*M->getFunction("strncmp"), *TLI);
CallInst *CI = B.CreateCall(
StrNCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "strncmp");
@@ -772,12 +772,12 @@ Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B,
Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B,
const TargetLibraryInfo *TLI, StringRef Name) {
- if (!TLI->has(LibFunc::strcpy))
+ if (!TLI->has(LibFunc_strcpy))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
Type *I8Ptr = B.getInt8PtrTy();
- Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, nullptr);
+ Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr);
inferLibFuncAttributes(*M->getFunction(Name), *TLI);
CallInst *CI =
B.CreateCall(StrCpy, {castToCStr(Dst, B), castToCStr(Src, B)}, Name);
@@ -788,13 +788,13 @@ Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B,
Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B,
const TargetLibraryInfo *TLI, StringRef Name) {
- if (!TLI->has(LibFunc::strncpy))
+ if (!TLI->has(LibFunc_strncpy))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
Type *I8Ptr = B.getInt8PtrTy();
Value *StrNCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr,
- Len->getType(), nullptr);
+ Len->getType());
inferLibFuncAttributes(*M->getFunction(Name), *TLI);
CallInst *CI = B.CreateCall(
StrNCpy, {castToCStr(Dst, B), castToCStr(Src, B), Len}, "strncpy");
@@ -806,18 +806,18 @@ Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B,
Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,
IRBuilder<> &B, const DataLayout &DL,
const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::memcpy_chk))
+ if (!TLI->has(LibFunc_memcpy_chk))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
- AttributeSet AS;
- AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex,
- Attribute::NoUnwind);
+ AttributeList AS;
+ AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex,
+ Attribute::NoUnwind);
LLVMContext &Context = B.GetInsertBlock()->getContext();
Value *MemCpy = M->getOrInsertFunction(
- "__memcpy_chk", AttributeSet::get(M->getContext(), AS), B.getInt8PtrTy(),
+ "__memcpy_chk", AttributeList::get(M->getContext(), AS), B.getInt8PtrTy(),
B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context),
- DL.getIntPtrType(Context), nullptr);
+ DL.getIntPtrType(Context));
Dst = castToCStr(Dst, B);
Src = castToCStr(Src, B);
CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize});
@@ -828,14 +828,14 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,
Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::memchr))
+ if (!TLI->has(LibFunc_memchr))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
LLVMContext &Context = B.GetInsertBlock()->getContext();
Value *MemChr = M->getOrInsertFunction("memchr", B.getInt8PtrTy(),
B.getInt8PtrTy(), B.getInt32Ty(),
- DL.getIntPtrType(Context), nullptr);
+ DL.getIntPtrType(Context));
inferLibFuncAttributes(*M->getFunction("memchr"), *TLI);
CallInst *CI = B.CreateCall(MemChr, {castToCStr(Ptr, B), Val, Len}, "memchr");
@@ -847,14 +847,14 @@ Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B,
Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::memcmp))
+ if (!TLI->has(LibFunc_memcmp))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
LLVMContext &Context = B.GetInsertBlock()->getContext();
Value *MemCmp = M->getOrInsertFunction("memcmp", B.getInt32Ty(),
B.getInt8PtrTy(), B.getInt8PtrTy(),
- DL.getIntPtrType(Context), nullptr);
+ DL.getIntPtrType(Context));
inferLibFuncAttributes(*M->getFunction("memcmp"), *TLI);
CallInst *CI = B.CreateCall(
MemCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "memcmp");
@@ -881,13 +881,13 @@ static void appendTypeSuffix(Value *Op, StringRef &Name,
}
Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B,
- const AttributeSet &Attrs) {
+ const AttributeList &Attrs) {
SmallString<20> NameBuffer;
appendTypeSuffix(Op, Name, NameBuffer);
Module *M = B.GetInsertBlock()->getModule();
Value *Callee = M->getOrInsertFunction(Name, Op->getType(),
- Op->getType(), nullptr);
+ Op->getType());
CallInst *CI = B.CreateCall(Callee, Op, Name);
CI->setAttributes(Attrs);
if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts()))
@@ -897,13 +897,13 @@ Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B,
}
Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name,
- IRBuilder<> &B, const AttributeSet &Attrs) {
+ IRBuilder<> &B, const AttributeList &Attrs) {
SmallString<20> NameBuffer;
appendTypeSuffix(Op1, Name, NameBuffer);
Module *M = B.GetInsertBlock()->getModule();
Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), Op1->getType(),
- Op2->getType(), nullptr);
+ Op2->getType());
CallInst *CI = B.CreateCall(Callee, {Op1, Op2}, Name);
CI->setAttributes(Attrs);
if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts()))
@@ -914,12 +914,12 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name,
Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B,
const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::putchar))
+ if (!TLI->has(LibFunc_putchar))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
- Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(),
- B.getInt32Ty(), nullptr);
+ Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), B.getInt32Ty());
+ inferLibFuncAttributes(*M->getFunction("putchar"), *TLI);
CallInst *CI = B.CreateCall(PutChar,
B.CreateIntCast(Char,
B.getInt32Ty(),
@@ -934,12 +934,12 @@ Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B,
Value *llvm::emitPutS(Value *Str, IRBuilder<> &B,
const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::puts))
+ if (!TLI->has(LibFunc_puts))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
Value *PutS =
- M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy(), nullptr);
+ M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy());
inferLibFuncAttributes(*M->getFunction("puts"), *TLI);
CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), "puts");
if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts()))
@@ -949,12 +949,12 @@ Value *llvm::emitPutS(Value *Str, IRBuilder<> &B,
Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B,
const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::fputc))
+ if (!TLI->has(LibFunc_fputc))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
Constant *F = M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(),
- File->getType(), nullptr);
+ File->getType());
if (File->getType()->isPointerTy())
inferLibFuncAttributes(*M->getFunction("fputc"), *TLI);
Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true,
@@ -968,13 +968,13 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B,
Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B,
const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::fputs))
+ if (!TLI->has(LibFunc_fputs))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
- StringRef FPutsName = TLI->getName(LibFunc::fputs);
+ StringRef FPutsName = TLI->getName(LibFunc_fputs);
Constant *F = M->getOrInsertFunction(
- FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType(), nullptr);
+ FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType());
if (File->getType()->isPointerTy())
inferLibFuncAttributes(*M->getFunction(FPutsName), *TLI);
CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, "fputs");
@@ -986,16 +986,16 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B,
Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- if (!TLI->has(LibFunc::fwrite))
+ if (!TLI->has(LibFunc_fwrite))
return nullptr;
Module *M = B.GetInsertBlock()->getModule();
LLVMContext &Context = B.GetInsertBlock()->getContext();
- StringRef FWriteName = TLI->getName(LibFunc::fwrite);
+ StringRef FWriteName = TLI->getName(LibFunc_fwrite);
Constant *F = M->getOrInsertFunction(
FWriteName, DL.getIntPtrType(Context), B.getInt8PtrTy(),
- DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType(),
- nullptr);
+ DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType());
+
if (File->getType()->isPointerTy())
inferLibFuncAttributes(*M->getFunction(FWriteName), *TLI);
CallInst *CI =
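Besides the LibFunc_ enum rename and the AttributeSet to AttributeList switch, the other mechanical change in this file is dropping the trailing nullptr sentinel from getOrInsertFunction, which now takes the argument types directly. A minimal sketch of the new call form, matching the puts declaration above; the helper name is illustrative.

    #include "llvm/IR/Module.h"
    #include "llvm/IR/Type.h"
    using namespace llvm;

    // Declare (or look up) int puts(char*) in the module. No nullptr
    // terminator: the argument type list simply ends after the last parameter.
    static Constant *declarePuts(Module &M) {
      LLVMContext &Ctx = M.getContext();
      return M.getOrInsertFunction("puts", Type::getInt32Ty(Ctx),
                                   Type::getInt8PtrTy(Ctx));
    }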
diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp b/lib/Transforms/Utils/BypassSlowDivision.cpp
index bc2cef26edcbc..1cfe3bd536482 100644
--- a/lib/Transforms/Utils/BypassSlowDivision.cpp
+++ b/lib/Transforms/Utils/BypassSlowDivision.cpp
@@ -17,6 +17,8 @@
#include "llvm/Transforms/Utils/BypassSlowDivision.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -36,12 +38,21 @@ namespace {
: SignedOp(InSignedOp), Dividend(InDividend), Divisor(InDivisor) {}
};
- struct DivPhiNodes {
- PHINode *Quotient;
- PHINode *Remainder;
+ struct QuotRemPair {
+ Value *Quotient;
+ Value *Remainder;
- DivPhiNodes(PHINode *InQuotient, PHINode *InRemainder)
- : Quotient(InQuotient), Remainder(InRemainder) {}
+ QuotRemPair(Value *InQuotient, Value *InRemainder)
+ : Quotient(InQuotient), Remainder(InRemainder) {}
+ };
+
+ /// A quotient and remainder, plus a BB from which they logically "originate".
+ /// If you use Quotient or Remainder in a Phi node, you should use BB as its
+ /// corresponding predecessor.
+ struct QuotRemWithBB {
+ BasicBlock *BB = nullptr;
+ Value *Quotient = nullptr;
+ Value *Remainder = nullptr;
};
}
@@ -69,159 +80,376 @@ namespace llvm {
}
};
- typedef DenseMap<DivOpInfo, DivPhiNodes> DivCacheTy;
+ typedef DenseMap<DivOpInfo, QuotRemPair> DivCacheTy;
+ typedef DenseMap<unsigned, unsigned> BypassWidthsTy;
+ typedef SmallPtrSet<Instruction *, 4> VisitedSetTy;
}
-// insertFastDiv - Substitutes the div/rem instruction with code that checks the
-// value of the operands and uses a shorter-faster div/rem instruction when
-// possible and the longer-slower div/rem instruction otherwise.
-static bool insertFastDiv(Instruction *I, IntegerType *BypassType,
- bool UseDivOp, bool UseSignedOp,
- DivCacheTy &PerBBDivCache) {
- Function *F = I->getParent()->getParent();
- // Get instruction operands
- Value *Dividend = I->getOperand(0);
- Value *Divisor = I->getOperand(1);
+namespace {
+enum ValueRange {
+ /// Operand definitely fits into BypassType. No runtime checks are needed.
+ VALRNG_KNOWN_SHORT,
+ /// A runtime check is required, as value range is unknown.
+ VALRNG_UNKNOWN,
+ /// Operand is unlikely to fit into BypassType. The bypassing should be
+ /// disabled.
+ VALRNG_LIKELY_LONG
+};
+
+class FastDivInsertionTask {
+ bool IsValidTask = false;
+ Instruction *SlowDivOrRem = nullptr;
+ IntegerType *BypassType = nullptr;
+ BasicBlock *MainBB = nullptr;
+
+ bool isHashLikeValue(Value *V, VisitedSetTy &Visited);
+ ValueRange getValueRange(Value *Op, VisitedSetTy &Visited);
+ QuotRemWithBB createSlowBB(BasicBlock *Successor);
+ QuotRemWithBB createFastBB(BasicBlock *Successor);
+ QuotRemPair createDivRemPhiNodes(QuotRemWithBB &LHS, QuotRemWithBB &RHS,
+ BasicBlock *PhiBB);
+ Value *insertOperandRuntimeCheck(Value *Op1, Value *Op2);
+ Optional<QuotRemPair> insertFastDivAndRem();
+
+ bool isSignedOp() {
+ return SlowDivOrRem->getOpcode() == Instruction::SDiv ||
+ SlowDivOrRem->getOpcode() == Instruction::SRem;
+ }
+ bool isDivisionOp() {
+ return SlowDivOrRem->getOpcode() == Instruction::SDiv ||
+ SlowDivOrRem->getOpcode() == Instruction::UDiv;
+ }
+ Type *getSlowType() { return SlowDivOrRem->getType(); }
+
+public:
+ FastDivInsertionTask(Instruction *I, const BypassWidthsTy &BypassWidths);
+ Value *getReplacement(DivCacheTy &Cache);
+};
+} // anonymous namespace
+
+FastDivInsertionTask::FastDivInsertionTask(Instruction *I,
+ const BypassWidthsTy &BypassWidths) {
+ switch (I->getOpcode()) {
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ SlowDivOrRem = I;
+ break;
+ default:
+ // I is not a div/rem operation.
+ return;
+ }
- if (isa<ConstantInt>(Divisor)) {
- // Division by a constant should have been been solved and replaced earlier
- // in the pipeline.
- return false;
+ // Skip division on vector types. Only optimize integer instructions.
+ IntegerType *SlowType = dyn_cast<IntegerType>(SlowDivOrRem->getType());
+ if (!SlowType)
+ return;
+
+ // Skip if this bitwidth is not bypassed.
+ auto BI = BypassWidths.find(SlowType->getBitWidth());
+ if (BI == BypassWidths.end())
+ return;
+
+ // Get type for div/rem instruction with bypass bitwidth.
+ IntegerType *BT = IntegerType::get(I->getContext(), BI->second);
+ BypassType = BT;
+
+ // The original basic block.
+ MainBB = I->getParent();
+
+ // The instruction is indeed a slow div or rem operation.
+ IsValidTask = true;
+}
+
+/// Reuses a previously-computed quotient or remainder from the current BB if
+/// operands and operation are identical. Otherwise calls insertFastDivAndRem to
+/// perform the optimization and caches the resulting quotient and remainder.
+/// If no replacement can be generated, nullptr is returned.
+Value *FastDivInsertionTask::getReplacement(DivCacheTy &Cache) {
+ // First, make sure that the task is valid.
+ if (!IsValidTask)
+ return nullptr;
+
+ // Then, look for a value in Cache.
+ Value *Dividend = SlowDivOrRem->getOperand(0);
+ Value *Divisor = SlowDivOrRem->getOperand(1);
+ DivOpInfo Key(isSignedOp(), Dividend, Divisor);
+ auto CacheI = Cache.find(Key);
+
+ if (CacheI == Cache.end()) {
+ // If previous instance does not exist, try to insert fast div.
+ Optional<QuotRemPair> OptResult = insertFastDivAndRem();
+ // Bail out if insertFastDivAndRem has failed.
+ if (!OptResult)
+ return nullptr;
+ CacheI = Cache.insert({Key, *OptResult}).first;
}
- // If the numerator is a constant, bail if it doesn't fit into BypassType.
- if (ConstantInt *ConstDividend = dyn_cast<ConstantInt>(Dividend))
- if (ConstDividend->getValue().getActiveBits() > BypassType->getBitWidth())
+ QuotRemPair &Value = CacheI->second;
+ return isDivisionOp() ? Value.Quotient : Value.Remainder;
+}
+
+/// \brief Check if a value looks like a hash.
+///
+/// The routine is expected to detect values computed using the most common hash
+/// algorithms. Typically, hash computations end with one of the following
+/// instructions:
+///
+/// 1) MUL with a constant wider than BypassType
+/// 2) XOR instruction
+///
+/// And even if we are wrong and the value is not a hash, it is still quite
+/// unlikely that such values will fit into BypassType.
+///
+/// To detect string hash algorithms like FNV we have to look through PHI-nodes.
+/// It is implemented as a depth-first search for values that look neither long
+/// nor hash-like.
+bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return false;
+
+ switch (I->getOpcode()) {
+ case Instruction::Xor:
+ return true;
+ case Instruction::Mul: {
+ // After the Constant Hoisting pass, long constants may be represented as
+ // bitcast instructions. As a result, some constants may look like an
+ // instruction at first, and an additional check is necessary to find out if
+ // an operand is actually a constant.
+ Value *Op1 = I->getOperand(1);
+ ConstantInt *C = dyn_cast<ConstantInt>(Op1);
+ if (!C && isa<BitCastInst>(Op1))
+ C = dyn_cast<ConstantInt>(cast<BitCastInst>(Op1)->getOperand(0));
+ return C && C->getValue().getMinSignedBits() > BypassType->getBitWidth();
+ }
+ case Instruction::PHI: {
+ // Stop IR traversal in case of crazy input code. This limits recursion
+ // depth.
+ if (Visited.size() >= 16)
return false;
+ // Do not visit nodes that have been visited already. We return true because
+ // it means that we couldn't find any value that doesn't look hash-like.
+ if (Visited.find(I) != Visited.end())
+ return true;
+ Visited.insert(I);
+ return llvm::all_of(cast<PHINode>(I)->incoming_values(), [&](Value *V) {
+ // Ignore undef values as they probably don't affect the division
+ // operands.
+ return getValueRange(V, Visited) == VALRNG_LIKELY_LONG ||
+ isa<UndefValue>(V);
+ });
+ }
+ default:
+ return false;
+ }
+}
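The heuristic above leans on the shape of common hash loops: each step multiplies by a constant wider than the bypass type or XORs in fresh data, so the final value is spread over the full bit width. As a rough illustration only (a 64-bit FNV-1a hash, assuming a 64-bit slow type and a 32-bit bypass type), such values essentially never have their top 32 bits all zero:

  #include <cstdint>
  #include <cstdio>
  #include <string>

  // 64-bit FNV-1a: it ends in a multiply by a 64-bit prime, exactly the
  // pattern isHashLikeValue looks for (a Mul by a constant wider than
  // BypassType).
  uint64_t fnv1a(const std::string &S) {
    uint64_t H = 0xcbf29ce484222325ULL;      // FNV offset basis
    for (unsigned char C : S) {
      H ^= C;
      H *= 0x100000001b3ULL;                 // FNV prime (wider than 32 bits)
    }
    return H;
  }

  int main() {
    for (const char *Key : {"a", "bb", "ccc"}) {
      uint64_t H = fnv1a(Key);
      // The top 32 bits are almost never all zero, so a 64->32 bypass check
      // on such a value would practically always take the slow path.
      std::printf("%s -> %016llx (fits in 32 bits: %s)\n", Key,
                  (unsigned long long)H, (H >> 32) == 0 ? "yes" : "no");
    }
  }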
+
+/// Check if an integer value fits into our bypass type.
+ValueRange FastDivInsertionTask::getValueRange(Value *V,
+ VisitedSetTy &Visited) {
+ unsigned ShortLen = BypassType->getBitWidth();
+ unsigned LongLen = V->getType()->getIntegerBitWidth();
+
+ assert(LongLen > ShortLen && "Value type must be wider than BypassType");
+ unsigned HiBits = LongLen - ShortLen;
+
+ const DataLayout &DL = SlowDivOrRem->getModule()->getDataLayout();
+ APInt Zeros(LongLen, 0), Ones(LongLen, 0);
- // Basic Block is split before divide
- BasicBlock *MainBB = &*I->getParent();
- BasicBlock *SuccessorBB = MainBB->splitBasicBlock(I);
-
- // Add new basic block for slow divide operation
- BasicBlock *SlowBB =
- BasicBlock::Create(F->getContext(), "", MainBB->getParent(), SuccessorBB);
- SlowBB->moveBefore(SuccessorBB);
- IRBuilder<> SlowBuilder(SlowBB, SlowBB->begin());
- Value *SlowQuotientV;
- Value *SlowRemainderV;
- if (UseSignedOp) {
- SlowQuotientV = SlowBuilder.CreateSDiv(Dividend, Divisor);
- SlowRemainderV = SlowBuilder.CreateSRem(Dividend, Divisor);
+ computeKnownBits(V, Zeros, Ones, DL);
+
+ if (Zeros.countLeadingOnes() >= HiBits)
+ return VALRNG_KNOWN_SHORT;
+
+ if (Ones.countLeadingZeros() < HiBits)
+ return VALRNG_LIKELY_LONG;
+
+ // Long integer divisions are often used in hashtable implementations. It's
+ // not worth bypassing such divisions because hash values are extremely
+ // unlikely to have enough leading zeros. The call below tries to detect
+ // values that are unlikely to fit BypassType (including hashes).
+ if (isHashLikeValue(V, Visited))
+ return VALRNG_LIKELY_LONG;
+
+ return VALRNG_UNKNOWN;
+}
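The classification above is a pure known-bits question: a wide value is provably short when all of its high bits are known zero, and provably long when any high bit is known one. A standalone sketch of that three-way test, with plain integers standing in for APInt/computeKnownBits (names are hypothetical):

  #include <cstdint>
  #include <cstdio>

  enum ValueRange { KnownShort, LikelyLong, Unknown };

  // KnownZero/KnownOne play the role of the masks computeKnownBits fills in:
  // bit i set in KnownZero means "bit i is definitely 0", likewise for one.
  ValueRange classify(uint64_t KnownZero, uint64_t KnownOne, unsigned ShortLen) {
    uint64_t HiMask = ~0ULL << ShortLen;        // the bits that must vanish
    if ((KnownZero & HiMask) == HiMask)
      return KnownShort;                        // all high bits provably 0
    if (KnownOne & HiMask)
      return LikelyLong;                        // some high bit provably 1
    return Unknown;                             // needs a runtime check
  }

  int main() {
    // e.g. a value produced by "x & 0xffff": the top 48 bits are known zero.
    std::printf("%d\n", classify(~0xffffULL, 0, 32));  // 0 == KnownShort
    // e.g. a value with bit 40 known set: cannot fit into 32 bits.
    std::printf("%d\n", classify(0, 1ULL << 40, 32));  // 1 == LikelyLong
  }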
+
+/// Add new basic block for slow div and rem operations and put it before
+/// SuccessorBB.
+QuotRemWithBB FastDivInsertionTask::createSlowBB(BasicBlock *SuccessorBB) {
+ QuotRemWithBB DivRemPair;
+ DivRemPair.BB = BasicBlock::Create(MainBB->getParent()->getContext(), "",
+ MainBB->getParent(), SuccessorBB);
+ IRBuilder<> Builder(DivRemPair.BB, DivRemPair.BB->begin());
+
+ Value *Dividend = SlowDivOrRem->getOperand(0);
+ Value *Divisor = SlowDivOrRem->getOperand(1);
+
+ if (isSignedOp()) {
+ DivRemPair.Quotient = Builder.CreateSDiv(Dividend, Divisor);
+ DivRemPair.Remainder = Builder.CreateSRem(Dividend, Divisor);
} else {
- SlowQuotientV = SlowBuilder.CreateUDiv(Dividend, Divisor);
- SlowRemainderV = SlowBuilder.CreateURem(Dividend, Divisor);
+ DivRemPair.Quotient = Builder.CreateUDiv(Dividend, Divisor);
+ DivRemPair.Remainder = Builder.CreateURem(Dividend, Divisor);
}
- SlowBuilder.CreateBr(SuccessorBB);
-
- // Add new basic block for fast divide operation
- BasicBlock *FastBB =
- BasicBlock::Create(F->getContext(), "", MainBB->getParent(), SuccessorBB);
- FastBB->moveBefore(SlowBB);
- IRBuilder<> FastBuilder(FastBB, FastBB->begin());
- Value *ShortDivisorV = FastBuilder.CreateCast(Instruction::Trunc, Divisor,
- BypassType);
- Value *ShortDividendV = FastBuilder.CreateCast(Instruction::Trunc, Dividend,
- BypassType);
-
- // udiv/urem because optimization only handles positive numbers
- Value *ShortQuotientV = FastBuilder.CreateUDiv(ShortDividendV, ShortDivisorV);
- Value *ShortRemainderV = FastBuilder.CreateURem(ShortDividendV,
- ShortDivisorV);
- Value *FastQuotientV = FastBuilder.CreateCast(Instruction::ZExt,
- ShortQuotientV,
- Dividend->getType());
- Value *FastRemainderV = FastBuilder.CreateCast(Instruction::ZExt,
- ShortRemainderV,
- Dividend->getType());
- FastBuilder.CreateBr(SuccessorBB);
-
- // Phi nodes for result of div and rem
- IRBuilder<> SuccessorBuilder(SuccessorBB, SuccessorBB->begin());
- PHINode *QuoPhi = SuccessorBuilder.CreatePHI(I->getType(), 2);
- QuoPhi->addIncoming(SlowQuotientV, SlowBB);
- QuoPhi->addIncoming(FastQuotientV, FastBB);
- PHINode *RemPhi = SuccessorBuilder.CreatePHI(I->getType(), 2);
- RemPhi->addIncoming(SlowRemainderV, SlowBB);
- RemPhi->addIncoming(FastRemainderV, FastBB);
-
- // Replace I with appropriate phi node
- if (UseDivOp)
- I->replaceAllUsesWith(QuoPhi);
- else
- I->replaceAllUsesWith(RemPhi);
- I->eraseFromParent();
- // Combine operands into a single value with OR for value testing below
- MainBB->getInstList().back().eraseFromParent();
- IRBuilder<> MainBuilder(MainBB, MainBB->end());
+ Builder.CreateBr(SuccessorBB);
+ return DivRemPair;
+}
+
+/// Add new basic block for fast div and rem operations and put it before
+/// SuccessorBB.
+QuotRemWithBB FastDivInsertionTask::createFastBB(BasicBlock *SuccessorBB) {
+ QuotRemWithBB DivRemPair;
+ DivRemPair.BB = BasicBlock::Create(MainBB->getParent()->getContext(), "",
+ MainBB->getParent(), SuccessorBB);
+ IRBuilder<> Builder(DivRemPair.BB, DivRemPair.BB->begin());
+
+ Value *Dividend = SlowDivOrRem->getOperand(0);
+ Value *Divisor = SlowDivOrRem->getOperand(1);
+ Value *ShortDivisorV =
+ Builder.CreateCast(Instruction::Trunc, Divisor, BypassType);
+ Value *ShortDividendV =
+ Builder.CreateCast(Instruction::Trunc, Dividend, BypassType);
+
+ // udiv/urem because this optimization only handles positive numbers.
+ Value *ShortQV = Builder.CreateUDiv(ShortDividendV, ShortDivisorV);
+ Value *ShortRV = Builder.CreateURem(ShortDividendV, ShortDivisorV);
+ DivRemPair.Quotient =
+ Builder.CreateCast(Instruction::ZExt, ShortQV, getSlowType());
+ DivRemPair.Remainder =
+ Builder.CreateCast(Instruction::ZExt, ShortRV, getSlowType());
+ Builder.CreateBr(SuccessorBB);
+
+ return DivRemPair;
+}
- // We should have bailed out above if the divisor is a constant, but the
- // dividend may still be a constant. Set OrV to our non-constant operands
- // OR'ed together.
- assert(!isa<ConstantInt>(Divisor));
+/// Creates Phi nodes for result of Div and Rem.
+QuotRemPair FastDivInsertionTask::createDivRemPhiNodes(QuotRemWithBB &LHS,
+ QuotRemWithBB &RHS,
+ BasicBlock *PhiBB) {
+ IRBuilder<> Builder(PhiBB, PhiBB->begin());
+ PHINode *QuoPhi = Builder.CreatePHI(getSlowType(), 2);
+ QuoPhi->addIncoming(LHS.Quotient, LHS.BB);
+ QuoPhi->addIncoming(RHS.Quotient, RHS.BB);
+ PHINode *RemPhi = Builder.CreatePHI(getSlowType(), 2);
+ RemPhi->addIncoming(LHS.Remainder, LHS.BB);
+ RemPhi->addIncoming(RHS.Remainder, RHS.BB);
+ return QuotRemPair(QuoPhi, RemPhi);
+}
+
+/// Creates a runtime check to test whether both the divisor and dividend fit
+/// into BypassType. The check is inserted at the end of MainBB. True return
+/// value means that the operands fit. Either of the operands may be NULL if it
+/// doesn't need a runtime check.
+Value *FastDivInsertionTask::insertOperandRuntimeCheck(Value *Op1, Value *Op2) {
+ assert((Op1 || Op2) && "Nothing to check");
+ IRBuilder<> Builder(MainBB, MainBB->end());
Value *OrV;
- if (!isa<ConstantInt>(Dividend))
- OrV = MainBuilder.CreateOr(Dividend, Divisor);
+ if (Op1 && Op2)
+ OrV = Builder.CreateOr(Op1, Op2);
else
- OrV = Divisor;
+ OrV = Op1 ? Op1 : Op2;
// BitMask is inverted to check if the operands are
// larger than the bypass type
uint64_t BitMask = ~BypassType->getBitMask();
- Value *AndV = MainBuilder.CreateAnd(OrV, BitMask);
-
- // Compare operand values and branch
- Value *ZeroV = ConstantInt::getSigned(Dividend->getType(), 0);
- Value *CmpV = MainBuilder.CreateICmpEQ(AndV, ZeroV);
- MainBuilder.CreateCondBr(CmpV, FastBB, SlowBB);
-
- // Cache phi nodes to be used later in place of other instances
- // of div or rem with the same sign, dividend, and divisor
- DivOpInfo Key(UseSignedOp, Dividend, Divisor);
- DivPhiNodes Value(QuoPhi, RemPhi);
- PerBBDivCache.insert(std::pair<DivOpInfo, DivPhiNodes>(Key, Value));
- return true;
+ Value *AndV = Builder.CreateAnd(OrV, BitMask);
+
+ // Compare operand values
+ Value *ZeroV = ConstantInt::getSigned(getSlowType(), 0);
+ return Builder.CreateICmpEQ(AndV, ZeroV);
}
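The emitted check is the usual scalar idiom: OR the operands together and test the bits above the bypass width. In plain C++ terms, assuming a 64-bit division bypassed through 32-bit operations:

  #include <cstdint>
  #include <cstdio>

  // Equivalent of the emitted IR: ((Op1 | Op2) & ~BypassTypeMask) == 0.
  bool bothFitIn32Bits(uint64_t Dividend, uint64_t Divisor) {
    uint64_t HighBits = ~uint64_t{0xffffffff}; // bits a 32-bit value can't have
    return ((Dividend | Divisor) & HighBits) == 0;
  }

  int main() {
    std::printf("%d\n", bothFitIn32Bits(100, 7));          // 1: take the fast path
    std::printf("%d\n", bothFitIn32Bits(1ULL << 40, 7));   // 0: take the slow path
  }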
-// reuseOrInsertFastDiv - Reuses previously computed dividend or remainder from
-// the current BB if operands and operation are identical. Otherwise calls
-// insertFastDiv to perform the optimization and caches the resulting dividend
-// and remainder.
-static bool reuseOrInsertFastDiv(Instruction *I, IntegerType *BypassType,
- bool UseDivOp, bool UseSignedOp,
- DivCacheTy &PerBBDivCache) {
- // Get instruction operands
- DivOpInfo Key(UseSignedOp, I->getOperand(0), I->getOperand(1));
- DivCacheTy::iterator CacheI = PerBBDivCache.find(Key);
-
- if (CacheI == PerBBDivCache.end()) {
- // If previous instance does not exist, insert fast div
- return insertFastDiv(I, BypassType, UseDivOp, UseSignedOp, PerBBDivCache);
+/// Substitutes the div/rem instruction with code that checks the value of the
+/// operands and uses a shorter-faster div/rem instruction when possible.
+Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() {
+ Value *Dividend = SlowDivOrRem->getOperand(0);
+ Value *Divisor = SlowDivOrRem->getOperand(1);
+
+ if (isa<ConstantInt>(Divisor)) {
+ // Keep division by a constant for DAGCombiner.
+ return None;
}
- // Replace operation value with previously generated phi node
- DivPhiNodes &Value = CacheI->second;
- if (UseDivOp) {
- // Replace all uses of div instruction with quotient phi node
- I->replaceAllUsesWith(Value.Quotient);
+ VisitedSetTy SetL;
+ ValueRange DividendRange = getValueRange(Dividend, SetL);
+ if (DividendRange == VALRNG_LIKELY_LONG)
+ return None;
+
+ VisitedSetTy SetR;
+ ValueRange DivisorRange = getValueRange(Divisor, SetR);
+ if (DivisorRange == VALRNG_LIKELY_LONG)
+ return None;
+
+ bool DividendShort = (DividendRange == VALRNG_KNOWN_SHORT);
+ bool DivisorShort = (DivisorRange == VALRNG_KNOWN_SHORT);
+
+ if (DividendShort && DivisorShort) {
+ // If both operands are known to be short then just replace the long
+ // division with a short one in-place.
+
+ IRBuilder<> Builder(SlowDivOrRem);
+ Value *TruncDividend = Builder.CreateTrunc(Dividend, BypassType);
+ Value *TruncDivisor = Builder.CreateTrunc(Divisor, BypassType);
+ Value *TruncDiv = Builder.CreateUDiv(TruncDividend, TruncDivisor);
+ Value *TruncRem = Builder.CreateURem(TruncDividend, TruncDivisor);
+ Value *ExtDiv = Builder.CreateZExt(TruncDiv, getSlowType());
+ Value *ExtRem = Builder.CreateZExt(TruncRem, getSlowType());
+ return QuotRemPair(ExtDiv, ExtRem);
+ } else if (DividendShort && !isSignedOp()) {
+ // If the division is unsigned and Dividend is known to be short, then
+ // either
+ // 1) Divisor is less than or equal to Dividend, and the result can be computed
+ // with a short division.
+ // 2) Divisor is greater than Dividend. In this case, no division is needed
+ // at all: The quotient is 0 and the remainder is equal to Dividend.
+ //
+ // So instead of checking at runtime whether Divisor fits into BypassType,
+ // we emit a runtime check to differentiate between these two cases. This
+ // lets us entirely avoid a long div.
+
+ // Split the basic block before the div/rem.
+ BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem);
+ // Remove the unconditional branch from MainBB to SuccessorBB.
+ MainBB->getInstList().back().eraseFromParent();
+ QuotRemWithBB Long;
+ Long.BB = MainBB;
+ Long.Quotient = ConstantInt::get(getSlowType(), 0);
+ Long.Remainder = Dividend;
+ QuotRemWithBB Fast = createFastBB(SuccessorBB);
+ QuotRemPair Result = createDivRemPhiNodes(Fast, Long, SuccessorBB);
+ IRBuilder<> Builder(MainBB, MainBB->end());
+ Value *CmpV = Builder.CreateICmpUGE(Dividend, Divisor);
+ Builder.CreateCondBr(CmpV, Fast.BB, SuccessorBB);
+ return Result;
} else {
- // Replace all uses of rem instruction with remainder phi node
- I->replaceAllUsesWith(Value.Remainder);
+ // General case. Create both slow and fast div/rem pairs and choose one of
+ // them at runtime.
+
+ // Split the basic block before the div/rem.
+ BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem);
+ // Remove the unconditional branch from MainBB to SuccessorBB.
+ MainBB->getInstList().back().eraseFromParent();
+ QuotRemWithBB Fast = createFastBB(SuccessorBB);
+ QuotRemWithBB Slow = createSlowBB(SuccessorBB);
+ QuotRemPair Result = createDivRemPhiNodes(Fast, Slow, SuccessorBB);
+ Value *CmpV = insertOperandRuntimeCheck(DividendShort ? nullptr : Dividend,
+ DivisorShort ? nullptr : Divisor);
+ IRBuilder<> Builder(MainBB, MainBB->end());
+ Builder.CreateCondBr(CmpV, Fast.BB, Slow.BB);
+ return Result;
}
-
- // Remove redundant operation
- I->eraseFromParent();
- return true;
}
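Taken together, the rewrite has the runtime shape of the following standalone sketch (plain C++ with hypothetical names; the pass itself emits this as IR with trunc/zext, phi nodes and a conditional branch):

  #include <cstdint>
  #include <cstdio>

  // General case: runtime-check the operands, then use a 32-bit udiv when both
  // fit, falling back to the full 64-bit divide otherwise.
  uint64_t bypassedUDiv(uint64_t A, uint64_t B) {
    if (((A | B) >> 32) == 0)                    // both operands are "short"
      return uint32_t(A) / uint32_t(B);          // fast 32-bit divide, zero-extended
    return A / B;                                // slow 64-bit divide
  }

  // The unsigned short-dividend special case from above: "Divisor > Dividend"
  // already means quotient 0 and remainder == Dividend, so no long division is
  // emitted on either path.
  uint64_t bypassedUDivShortDividend(uint64_t A /* known < 2^32 */, uint64_t B) {
    if (A >= B)
      return uint32_t(A) / uint32_t(B);          // divisor also fits: fast path
    return 0;                                    // quotient is simply 0
  }

  int main() {
    std::printf("%llu\n", (unsigned long long)bypassedUDiv(1000, 7));           // 142
    std::printf("%llu\n", (unsigned long long)bypassedUDiv(1ULL << 40, 3));     // slow path
    std::printf("%llu\n", (unsigned long long)bypassedUDivShortDividend(5, 9)); // 0
  }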
-// bypassSlowDivision - This optimization identifies DIV instructions in a BB
-// that can be profitably bypassed and carried out with a shorter, faster
-// divide.
-bool llvm::bypassSlowDivision(
- BasicBlock *BB, const DenseMap<unsigned int, unsigned int> &BypassWidths) {
- DivCacheTy DivCache;
+/// This optimization identifies DIV/REM instructions in a BB that can be
+/// profitably bypassed and carried out with a shorter, faster divide.
+bool llvm::bypassSlowDivision(BasicBlock *BB,
+ const BypassWidthsTy &BypassWidths) {
+ DivCacheTy PerBBDivCache;
bool MadeChange = false;
Instruction* Next = &*BB->begin();
@@ -231,42 +459,20 @@ bool llvm::bypassSlowDivision(
Instruction* I = Next;
Next = Next->getNextNode();
- // Get instruction details
- unsigned Opcode = I->getOpcode();
- bool UseDivOp = Opcode == Instruction::SDiv || Opcode == Instruction::UDiv;
- bool UseRemOp = Opcode == Instruction::SRem || Opcode == Instruction::URem;
- bool UseSignedOp = Opcode == Instruction::SDiv ||
- Opcode == Instruction::SRem;
-
- // Only optimize div or rem ops
- if (!UseDivOp && !UseRemOp)
- continue;
-
- // Skip division on vector types, only optimize integer instructions
- if (!I->getType()->isIntegerTy())
- continue;
-
- // Get bitwidth of div/rem instruction
- IntegerType *T = cast<IntegerType>(I->getType());
- unsigned int bitwidth = T->getBitWidth();
-
- // Continue if bitwidth is not bypassed
- DenseMap<unsigned int, unsigned int>::const_iterator BI = BypassWidths.find(bitwidth);
- if (BI == BypassWidths.end())
- continue;
-
- // Get type for div/rem instruction with bypass bitwidth
- IntegerType *BT = IntegerType::get(I->getContext(), BI->second);
-
- MadeChange |= reuseOrInsertFastDiv(I, BT, UseDivOp, UseSignedOp, DivCache);
+ FastDivInsertionTask Task(I, BypassWidths);
+ if (Value *Replacement = Task.getReplacement(PerBBDivCache)) {
+ I->replaceAllUsesWith(Replacement);
+ I->eraseFromParent();
+ MadeChange = true;
+ }
}
// Above we eagerly create divs and rems, as pairs, so that we can efficiently
// create divrem machine instructions. Now erase any unused divs / rems so we
// don't leave extra instructions sitting around.
- for (auto &KV : DivCache)
- for (Instruction *Phi : {KV.second.Quotient, KV.second.Remainder})
- RecursivelyDeleteTriviallyDeadInstructions(Phi);
+ for (auto &KV : PerBBDivCache)
+ for (Value *V : {KV.second.Quotient, KV.second.Remainder})
+ RecursivelyDeleteTriviallyDeadInstructions(V);
return MadeChange;
}
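For callers, the entry point stays a per-block routine driven by a map from slow bit widths to the narrower width to try. A hedged sketch of a driver, assuming the BypassWidthsTy typedef and header path from this tree, roughly mirroring how CodeGenPrepare consults the target's bypass widths:

  #include "llvm/IR/Function.h"
  #include "llvm/Transforms/Utils/BypassSlowDivision.h"

  // Hypothetical driver: ask for 64-bit div/rem to be bypassed via 32-bit ops,
  // roughly what CodeGenPrepare does for targets that report a slow 64-bit
  // divide.
  static bool bypassDivsInFunction(llvm::Function &F) {
    llvm::BypassWidthsTy BypassWidths;
    BypassWidths[64] = 32;                // try a 32-bit divide for 64-bit ops
    bool Changed = false;
    for (llvm::BasicBlock &BB : F)
      Changed |= llvm::bypassSlowDivision(&BB, BypassWidths);
    return Changed;
  }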
diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt
index 69889ec72f908..7a21c03da221a 100644
--- a/lib/Transforms/Utils/CMakeLists.txt
+++ b/lib/Transforms/Utils/CMakeLists.txt
@@ -31,12 +31,13 @@ add_llvm_library(LLVMTransformUtils
LoopUtils.cpp
LoopVersioning.cpp
LowerInvoke.cpp
+ LowerMemIntrinsics.cpp
LowerSwitch.cpp
Mem2Reg.cpp
- MemorySSA.cpp
MetaRenamer.cpp
ModuleUtils.cpp
NameAnonGlobals.cpp
+ PredicateInfo.cpp
PromoteMemoryToRegister.cpp
StripGCRelocates.cpp
SSAUpdater.cpp
@@ -51,6 +52,7 @@ add_llvm_library(LLVMTransformUtils
UnifyFunctionExitNodes.cpp
Utils.cpp
ValueMapper.cpp
+ VNCoercion.cpp
ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms
diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp
index 4d33e22fecfbd..385c12302e040 100644
--- a/lib/Transforms/Utils/CloneFunction.cpp
+++ b/lib/Transforms/Utils/CloneFunction.cpp
@@ -90,9 +90,9 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc,
assert(VMap.count(&I) && "No mapping from source argument specified!");
#endif
- // Copy all attributes other than those stored in the AttributeSet. We need
- // to remap the parameter indices of the AttributeSet.
- AttributeSet NewAttrs = NewFunc->getAttributes();
+ // Copy all attributes other than those stored in the AttributeList. We need
+ // to remap the parameter indices of the AttributeList.
+ AttributeList NewAttrs = NewFunc->getAttributes();
NewFunc->copyAttributesFrom(OldFunc);
NewFunc->setAttributes(NewAttrs);
@@ -103,22 +103,20 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc,
ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges,
TypeMapper, Materializer));
- AttributeSet OldAttrs = OldFunc->getAttributes();
+ SmallVector<AttributeSet, 4> NewArgAttrs(NewFunc->arg_size());
+ AttributeList OldAttrs = OldFunc->getAttributes();
+
// Clone any argument attributes that are present in the VMap.
- for (const Argument &OldArg : OldFunc->args())
+ for (const Argument &OldArg : OldFunc->args()) {
if (Argument *NewArg = dyn_cast<Argument>(VMap[&OldArg])) {
- AttributeSet attrs =
- OldAttrs.getParamAttributes(OldArg.getArgNo() + 1);
- if (attrs.getNumSlots() > 0)
- NewArg->addAttr(attrs);
+ NewArgAttrs[NewArg->getArgNo()] =
+ OldAttrs.getParamAttributes(OldArg.getArgNo());
}
+ }
NewFunc->setAttributes(
- NewFunc->getAttributes()
- .addAttributes(NewFunc->getContext(), AttributeSet::ReturnIndex,
- OldAttrs.getRetAttributes())
- .addAttributes(NewFunc->getContext(), AttributeSet::FunctionIndex,
- OldAttrs.getFnAttributes()));
+ AttributeList::get(NewFunc->getContext(), OldAttrs.getFnAttributes(),
+ OldAttrs.getRetAttributes(), NewArgAttrs));
SmallVector<std::pair<unsigned, MDNode *>, 1> MDs;
OldFunc->getAllMetadata(MDs);
@@ -353,7 +351,7 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB,
Cond = dyn_cast_or_null<ConstantInt>(V);
}
if (Cond) { // Constant fold to uncond branch!
- SwitchInst::ConstCaseIt Case = SI->findCaseValue(Cond);
+ SwitchInst::ConstCaseHandle Case = *SI->findCaseValue(Cond);
BasicBlock *Dest = const_cast<BasicBlock*>(Case.getCaseSuccessor());
VMap[OldTI] = BranchInst::Create(Dest, NewBB);
ToClone.push_back(Dest);
@@ -747,3 +745,40 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB,
return NewLoop;
}
+
+/// \brief Duplicate non-Phi instructions from the beginning of BB up to the
+/// StopAt instruction into a split block between BB and its predecessor.
+BasicBlock *
+llvm::DuplicateInstructionsInSplitBetween(BasicBlock *BB, BasicBlock *PredBB,
+ Instruction *StopAt,
+ ValueToValueMapTy &ValueMapping) {
+ // We are going to have to map operands from the original BB block to the new
+ // copy of the block 'NewBB'. If there are PHI nodes in BB, evaluate them to
+ // account for entry from PredBB.
+ BasicBlock::iterator BI = BB->begin();
+ for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI)
+ ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB);
+
+ BasicBlock *NewBB = SplitEdge(PredBB, BB);
+ NewBB->setName(PredBB->getName() + ".split");
+ Instruction *NewTerm = NewBB->getTerminator();
+
+ // Clone the non-phi instructions of BB into NewBB, keeping track of the
+ // mapping and using it to remap operands in the cloned instructions.
+ for (; StopAt != &*BI; ++BI) {
+ Instruction *New = BI->clone();
+ New->setName(BI->getName());
+ New->insertBefore(NewTerm);
+ ValueMapping[&*BI] = New;
+
+ // Remap operands to patch up intra-block references.
+ for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i)
+ if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) {
+ auto I = ValueMapping.find(Inst);
+ if (I != ValueMapping.end())
+ New->setOperand(i, I->second);
+ }
+ }
+
+ return NewBB;
+}
diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp
index 7ebeb615d2481..4e9d67252d6c5 100644
--- a/lib/Transforms/Utils/CloneModule.cpp
+++ b/lib/Transforms/Utils/CloneModule.cpp
@@ -20,6 +20,15 @@
#include "llvm-c/Core.h"
using namespace llvm;
+static void copyComdat(GlobalObject *Dst, const GlobalObject *Src) {
+ const Comdat *SC = Src->getComdat();
+ if (!SC)
+ return;
+ Comdat *DC = Dst->getParent()->getOrInsertComdat(SC->getName());
+ DC->setSelectionKind(SC->getSelectionKind());
+ Dst->setComdat(DC);
+}
+
/// This is not as easy as it might seem because we have to worry about making
/// copies of global variables and functions, and making their (initializers and
/// references, respectively) refer to the right globals.
@@ -124,6 +133,8 @@ std::unique_ptr<Module> llvm::CloneModule(
I->getAllMetadata(MDs);
for (auto MD : MDs)
GV->addMetadata(MD.first, *MapMetadata(MD.second, VMap));
+
+ copyComdat(GV, &*I);
}
// Similarly, copy over function bodies now...
@@ -153,6 +164,8 @@ std::unique_ptr<Module> llvm::CloneModule(
if (I.hasPersonalityFn())
F->setPersonalityFn(MapValue(I.getPersonalityFn(), VMap));
+
+ copyComdat(F, &I);
}
// And aliases
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp
index c514c9c9cd4a6..644d93b727b3d 100644
--- a/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/lib/Transforms/Utils/CodeExtractor.cpp
@@ -362,9 +362,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
// "target-features" attribute allowing it to be lowered.
// FIXME: This should be changed to check to see if a specific
// attribute can not be inherited.
- AttributeSet OldFnAttrs = oldFunction->getAttributes().getFnAttributes();
- AttrBuilder AB(OldFnAttrs, AttributeSet::FunctionIndex);
- for (auto Attr : AB.td_attrs())
+ AttrBuilder AB(oldFunction->getAttributes().getFnAttributes());
+ for (const auto &Attr : AB.td_attrs())
newFunction->addFnAttr(Attr.first, Attr.second);
newFunction->getBasicBlockList().push_back(newRootNode);
@@ -440,8 +439,10 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
// Emit a call to the new function, passing in: *pointer to struct (if
// aggregating parameters), or plan inputs and allocated memory for outputs
std::vector<Value*> params, StructValues, ReloadOutputs, Reloads;
-
- LLVMContext &Context = newFunction->getContext();
+
+ Module *M = newFunction->getParent();
+ LLVMContext &Context = M->getContext();
+ const DataLayout &DL = M->getDataLayout();
// Add inputs as params, or to be filled into the struct
for (Value *input : inputs)
@@ -456,8 +457,9 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
StructValues.push_back(output);
} else {
AllocaInst *alloca =
- new AllocaInst(output->getType(), nullptr, output->getName() + ".loc",
- &codeReplacer->getParent()->front().front());
+ new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
+ nullptr, output->getName() + ".loc",
+ &codeReplacer->getParent()->front().front());
ReloadOutputs.push_back(alloca);
params.push_back(alloca);
}
@@ -473,7 +475,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
// Allocate a struct at the beginning of this function
StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
- Struct = new AllocaInst(StructArgTy, nullptr, "structArg",
+ Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
+ "structArg",
&codeReplacer->getParent()->front().front());
params.push_back(Struct);
diff --git a/lib/Transforms/Utils/DemoteRegToStack.cpp b/lib/Transforms/Utils/DemoteRegToStack.cpp
index 75a1dde57c4cc..0eee6e19efac6 100644
--- a/lib/Transforms/Utils/DemoteRegToStack.cpp
+++ b/lib/Transforms/Utils/DemoteRegToStack.cpp
@@ -28,15 +28,17 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads,
return nullptr;
}
+ Function *F = I.getParent()->getParent();
+ const DataLayout &DL = F->getParent()->getDataLayout();
+
// Create a stack slot to hold the value.
AllocaInst *Slot;
if (AllocaPoint) {
- Slot = new AllocaInst(I.getType(), nullptr,
+ Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr,
I.getName()+".reg2mem", AllocaPoint);
} else {
- Function *F = I.getParent()->getParent();
- Slot = new AllocaInst(I.getType(), nullptr, I.getName() + ".reg2mem",
- &F->getEntryBlock().front());
+ Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr,
+ I.getName() + ".reg2mem", &F->getEntryBlock().front());
}
// We cannot demote invoke instructions to the stack if their normal edge
@@ -110,14 +112,17 @@ AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) {
return nullptr;
}
+ const DataLayout &DL = P->getModule()->getDataLayout();
+
// Create a stack slot to hold the value.
AllocaInst *Slot;
if (AllocaPoint) {
- Slot = new AllocaInst(P->getType(), nullptr,
+ Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr,
P->getName()+".reg2mem", AllocaPoint);
} else {
Function *F = P->getParent()->getParent();
- Slot = new AllocaInst(P->getType(), nullptr, P->getName() + ".reg2mem",
+ Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr,
+ P->getName() + ".reg2mem",
&F->getEntryBlock().front());
}
diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp
index 4adf1754253d0..59f176e2f231d 100644
--- a/lib/Transforms/Utils/Evaluator.cpp
+++ b/lib/Transforms/Utils/Evaluator.cpp
@@ -16,6 +16,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/GlobalVariable.h"
@@ -486,7 +487,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst,
ConstantInt *Val =
dyn_cast<ConstantInt>(getVal(SI->getCondition()));
if (!Val) return false; // Cannot determine.
- NextBB = SI->findCaseValue(Val).getCaseSuccessor();
+ NextBB = SI->findCaseValue(Val)->getCaseSuccessor();
} else if (IndirectBrInst *IBI = dyn_cast<IndirectBrInst>(CurInst)) {
Value *Val = getVal(IBI->getAddress())->stripPointerCasts();
if (BlockAddress *BA = dyn_cast<BlockAddress>(Val))
diff --git a/lib/Transforms/Utils/FunctionComparator.cpp b/lib/Transforms/Utils/FunctionComparator.cpp
index 81a7c4ceffab8..73a0b2737e957 100644
--- a/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/lib/Transforms/Utils/FunctionComparator.cpp
@@ -74,14 +74,14 @@ int FunctionComparator::cmpMem(StringRef L, StringRef R) const {
return L.compare(R);
}
-int FunctionComparator::cmpAttrs(const AttributeSet L,
- const AttributeSet R) const {
+int FunctionComparator::cmpAttrs(const AttributeList L,
+ const AttributeList R) const {
if (int Res = cmpNumbers(L.getNumSlots(), R.getNumSlots()))
return Res;
for (unsigned i = 0, e = L.getNumSlots(); i != e; ++i) {
- AttributeSet::iterator LI = L.begin(i), LE = L.end(i), RI = R.begin(i),
- RE = R.end(i);
+ AttributeList::iterator LI = L.begin(i), LE = L.end(i), RI = R.begin(i),
+ RE = R.end(i);
for (; LI != LE && RI != RE; ++LI, ++RI) {
Attribute LA = *LI;
Attribute RA = *RI;
diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp
index 9844190ef84a2..b00f4b14068a2 100644
--- a/lib/Transforms/Utils/FunctionImportUtils.cpp
+++ b/lib/Transforms/Utils/FunctionImportUtils.cpp
@@ -21,11 +21,11 @@ using namespace llvm;
/// Checks if we should import SGV as a definition, otherwise import as a
/// declaration.
bool FunctionImportGlobalProcessing::doImportAsDefinition(
- const GlobalValue *SGV, DenseSet<const GlobalValue *> *GlobalsToImport) {
+ const GlobalValue *SGV, SetVector<GlobalValue *> *GlobalsToImport) {
// For alias, we tie the definition to the base object. Extract it and recurse
if (auto *GA = dyn_cast<GlobalAlias>(SGV)) {
- if (GA->hasWeakAnyLinkage())
+ if (GA->isInterposable())
return false;
const GlobalObject *GO = GA->getBaseObject();
if (!GO->hasLinkOnceODRLinkage())
@@ -34,7 +34,7 @@ bool FunctionImportGlobalProcessing::doImportAsDefinition(
GO, GlobalsToImport);
}
// Only import the globals requested for importing.
- if (GlobalsToImport->count(SGV))
+ if (GlobalsToImport->count(const_cast<GlobalValue *>(SGV)))
return true;
// Otherwise no.
return false;
@@ -57,7 +57,8 @@ bool FunctionImportGlobalProcessing::shouldPromoteLocalToGlobal(
return false;
if (isPerformingImport()) {
- assert((!GlobalsToImport->count(SGV) || !isNonRenamableLocal(*SGV)) &&
+ assert((!GlobalsToImport->count(const_cast<GlobalValue *>(SGV)) ||
+ !isNonRenamableLocal(*SGV)) &&
"Attempting to promote non-renamable local");
// We don't know for sure yet if we are importing this value (as either
// a reference or a def), since we are simply walking all values in the
@@ -254,9 +255,8 @@ bool FunctionImportGlobalProcessing::run() {
return false;
}
-bool llvm::renameModuleForThinLTO(
- Module &M, const ModuleSummaryIndex &Index,
- DenseSet<const GlobalValue *> *GlobalsToImport) {
+bool llvm::renameModuleForThinLTO(Module &M, const ModuleSummaryIndex &Index,
+ SetVector<GlobalValue *> *GlobalsToImport) {
FunctionImportGlobalProcessing ThinLTOProcessing(M, Index, GlobalsToImport);
return ThinLTOProcessing.run();
}
diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp
index 74ebcda8355c1..ba4b78ac758a6 100644
--- a/lib/Transforms/Utils/GlobalStatus.cpp
+++ b/lib/Transforms/Utils/GlobalStatus.cpp
@@ -10,9 +10,22 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/GlobalStatus.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <cassert>
using namespace llvm;
@@ -175,13 +188,9 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS,
return false;
}
+GlobalStatus::GlobalStatus() = default;
+
bool GlobalStatus::analyzeGlobal(const Value *V, GlobalStatus &GS) {
SmallPtrSet<const PHINode *, 16> PhiUsers;
return analyzeGlobalAux(V, GS, PhiUsers);
}
-
-GlobalStatus::GlobalStatus()
- : IsCompared(false), IsLoaded(false), StoredType(NotStored),
- StoredOnceValue(nullptr), AccessingFunction(nullptr),
- HasMultipleAccessingFunctions(false), HasNonInstructionUser(false),
- Ordering(AtomicOrdering::NotAtomic) {}
diff --git a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp
index ed018bb731074..b8c12ad5ea846 100644
--- a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp
+++ b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp
@@ -62,6 +62,8 @@ void ImportedFunctionsInliningStatistics::recordInline(const Function &Caller,
void ImportedFunctionsInliningStatistics::setModuleInfo(const Module &M) {
ModuleName = M.getName();
for (const auto &F : M.functions()) {
+ if (F.isDeclaration())
+ continue;
AllFunctions++;
ImportedFunctions += int(F.getMetadata("thinlto_src_module") != nullptr);
}
diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp
index a40079ca8e76e..5d6fbc3325fff 100644
--- a/lib/Transforms/Utils/InlineFunction.cpp
+++ b/lib/Transforms/Utils/InlineFunction.cpp
@@ -20,10 +20,12 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CallSite.h"
@@ -40,8 +42,8 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
-#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
using namespace llvm;
@@ -1107,26 +1109,23 @@ static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) {
bool DTCalculated = false;
Function *CalledFunc = CS.getCalledFunction();
- for (Function::arg_iterator I = CalledFunc->arg_begin(),
- E = CalledFunc->arg_end();
- I != E; ++I) {
- unsigned Align = I->getType()->isPointerTy() ? I->getParamAlignment() : 0;
- if (Align && !I->hasByValOrInAllocaAttr() && !I->hasNUses(0)) {
+ for (Argument &Arg : CalledFunc->args()) {
+ unsigned Align = Arg.getType()->isPointerTy() ? Arg.getParamAlignment() : 0;
+ if (Align && !Arg.hasByValOrInAllocaAttr() && !Arg.hasNUses(0)) {
if (!DTCalculated) {
- DT.recalculate(const_cast<Function&>(*CS.getInstruction()->getParent()
- ->getParent()));
+ DT.recalculate(*CS.getCaller());
DTCalculated = true;
}
// If we can already prove the asserted alignment in the context of the
// caller, then don't bother inserting the assumption.
- Value *Arg = CS.getArgument(I->getArgNo());
- if (getKnownAlignment(Arg, DL, CS.getInstruction(), AC, &DT) >= Align)
+ Value *ArgVal = CS.getArgument(Arg.getArgNo());
+ if (getKnownAlignment(ArgVal, DL, CS.getInstruction(), AC, &DT) >= Align)
continue;
- CallInst *NewAssumption = IRBuilder<>(CS.getInstruction())
- .CreateAlignmentAssumption(DL, Arg, Align);
- AC->registerAssumption(NewAssumption);
+ CallInst *NewAsmp = IRBuilder<>(CS.getInstruction())
+ .CreateAlignmentAssumption(DL, ArgVal, Align);
+ AC->registerAssumption(NewAsmp);
}
}
}
@@ -1140,7 +1139,7 @@ static void UpdateCallGraphAfterInlining(CallSite CS,
ValueToValueMapTy &VMap,
InlineFunctionInfo &IFI) {
CallGraph &CG = *IFI.CG;
- const Function *Caller = CS.getInstruction()->getParent()->getParent();
+ const Function *Caller = CS.getCaller();
const Function *Callee = CS.getCalledFunction();
CallGraphNode *CalleeNode = CG[Callee];
CallGraphNode *CallerNode = CG[Caller];
@@ -1225,7 +1224,8 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall,
PointerType *ArgTy = cast<PointerType>(Arg->getType());
Type *AggTy = ArgTy->getElementType();
- Function *Caller = TheCall->getParent()->getParent();
+ Function *Caller = TheCall->getFunction();
+ const DataLayout &DL = Caller->getParent()->getDataLayout();
// If the called function is readonly, then it could not mutate the caller's
// copy of the byval'd memory. In this case, it is safe to elide the copy and
@@ -1239,31 +1239,30 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall,
AssumptionCache *AC =
IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr;
- const DataLayout &DL = Caller->getParent()->getDataLayout();
// If the pointer is already known to be sufficiently aligned, or if we can
// round it up to a larger alignment, then we don't need a temporary.
if (getOrEnforceKnownAlignment(Arg, ByValAlignment, DL, TheCall, AC) >=
ByValAlignment)
return Arg;
-
+
// Otherwise, we have to make a memcpy to get a safe alignment. This is bad
// for code quality, but rarely happens and is required for correctness.
}
// Create the alloca. If we have DataLayout, use nice alignment.
- unsigned Align =
- Caller->getParent()->getDataLayout().getPrefTypeAlignment(AggTy);
+ unsigned Align = DL.getPrefTypeAlignment(AggTy);
// If the byval had an alignment specified, we *must* use at least that
// alignment, as it is required by the byval argument (and uses of the
// pointer inside the callee).
Align = std::max(Align, ByValAlignment);
-
- Value *NewAlloca = new AllocaInst(AggTy, nullptr, Align, Arg->getName(),
+
+ Value *NewAlloca = new AllocaInst(AggTy, DL.getAllocaAddrSpace(),
+ nullptr, Align, Arg->getName(),
&*Caller->begin()->begin());
IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca));
-
+
// Uses of the argument in the function should use our new alloca
// instead.
return NewAlloca;
@@ -1393,6 +1392,89 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI,
}
}
}
+/// Update the block frequencies of the caller after a callee has been inlined.
+///
+/// Each block cloned into the caller has its block frequency scaled by the
+/// ratio of CallSiteFreq/CalleeEntryFreq. This ensures that the cloned copy of
+/// callee's entry block gets the same frequency as the callsite block and the
+/// relative frequencies of all cloned blocks remain the same after cloning.
+static void updateCallerBFI(BasicBlock *CallSiteBlock,
+ const ValueToValueMapTy &VMap,
+ BlockFrequencyInfo *CallerBFI,
+ BlockFrequencyInfo *CalleeBFI,
+ const BasicBlock &CalleeEntryBlock) {
+ SmallPtrSet<BasicBlock *, 16> ClonedBBs;
+ for (auto const &Entry : VMap) {
+ if (!isa<BasicBlock>(Entry.first) || !Entry.second)
+ continue;
+ auto *OrigBB = cast<BasicBlock>(Entry.first);
+ auto *ClonedBB = cast<BasicBlock>(Entry.second);
+ uint64_t Freq = CalleeBFI->getBlockFreq(OrigBB).getFrequency();
+ if (!ClonedBBs.insert(ClonedBB).second) {
+ // Multiple blocks in the callee might get mapped to one cloned block in
+ // the caller since we prune the callee as we clone it. When that happens,
+ // we want to use the maximum among the original blocks' frequencies.
+ uint64_t NewFreq = CallerBFI->getBlockFreq(ClonedBB).getFrequency();
+ if (NewFreq > Freq)
+ Freq = NewFreq;
+ }
+ CallerBFI->setBlockFreq(ClonedBB, Freq);
+ }
+ BasicBlock *EntryClone = cast<BasicBlock>(VMap.lookup(&CalleeEntryBlock));
+ CallerBFI->setBlockFreqAndScale(
+ EntryClone, CallerBFI->getBlockFreq(CallSiteBlock).getFrequency(),
+ ClonedBBs);
+}
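The scaling rule in the comment above is plain proportionality: each cloned block keeps its frequency relative to the callee entry, re-anchored at the callsite's frequency. A worked example with made-up numbers:

  #include <cstdint>
  #include <cstdio>

  // If the callee entry block has frequency 100 and the callsite block has
  // frequency 40, a callee block with frequency 250 (a loop body run ~2.5x per
  // call) should get 2.5 * 40 = 100 in the caller after inlining.
  uint64_t scaledCloneFreq(uint64_t CalleeBlockFreq, uint64_t CalleeEntryFreq,
                           uint64_t CallSiteFreq) {
    return CalleeBlockFreq * CallSiteFreq / CalleeEntryFreq;
  }

  int main() {
    std::printf("%llu\n",
                (unsigned long long)scaledCloneFreq(250, 100, 40)); // 100
  }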
+
+/// Update the branch metadata for cloned call instructions.
+static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap,
+ const Optional<uint64_t> &CalleeEntryCount,
+ const Instruction *TheCall) {
+ if (!CalleeEntryCount.hasValue() || CalleeEntryCount.getValue() < 1)
+ return;
+ Optional<uint64_t> CallSiteCount =
+ ProfileSummaryInfo::getProfileCount(TheCall, nullptr);
+ uint64_t CallCount =
+ std::min(CallSiteCount.hasValue() ? CallSiteCount.getValue() : 0,
+ CalleeEntryCount.getValue());
+
+ for (auto const &Entry : VMap)
+ if (isa<CallInst>(Entry.first))
+ if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second))
+ CI->updateProfWeight(CallCount, CalleeEntryCount.getValue());
+ for (BasicBlock &BB : *Callee)
+ // No need to update the callsite if it is pruned during inlining.
+ if (VMap.count(&BB))
+ for (Instruction &I : BB)
+ if (CallInst *CI = dyn_cast<CallInst>(&I))
+ CI->updateProfWeight(CalleeEntryCount.getValue() - CallCount,
+ CalleeEntryCount.getValue());
+}
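The split of profile counts works out as two scale factors, on the assumption that updateProfWeight(Num, Denom) scales the existing weights by Num/Denom: calls cloned into the caller are scaled by CallCount/CalleeEntryCount, and the calls remaining in the out-of-line callee get the rest. With made-up counts:

  #include <cstdint>
  #include <cstdio>

  int main() {
    // Assume the callee's entry count is 1000 and this callsite accounts for an
    // estimated 400 of those entries (clamped to the entry count, as above).
    uint64_t CalleeEntryCount = 1000, CallCount = 400;

    // Cloned calls are scaled by CallCount / CalleeEntryCount; the calls left
    // in the out-of-line callee are scaled by the remaining fraction.
    double ClonedScale = double(CallCount) / CalleeEntryCount;                  // 0.4
    double RemainingScale =
        double(CalleeEntryCount - CallCount) / CalleeEntryCount;                // 0.6

    std::printf("cloned x%.1f, remaining x%.1f\n", ClonedScale, RemainingScale);
  }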
+
+/// Update the entry count of callee after inlining.
+///
+/// The callsite's block count is subtracted from the callee's function entry
+/// count.
+static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB,
+ Instruction *CallInst, Function *Callee) {
+ // If the callee has an original count of N, and the estimated count of
+ // callsite is M, the new callee count is set to N - M. M is estimated from
+ // the caller's entry count, its entry block frequency and the block frequency
+ // of the callsite.
+ Optional<uint64_t> CalleeCount = Callee->getEntryCount();
+ if (!CalleeCount.hasValue())
+ return;
+ Optional<uint64_t> CallCount =
+ ProfileSummaryInfo::getProfileCount(CallInst, CallerBFI);
+ if (!CallCount.hasValue())
+ return;
+ // Since the callsite count is only an estimate, it could exceed the original
+ // callee count; in that case the new callee count is set to 0.
+ if (CallCount.getValue() > CalleeCount.getValue())
+ Callee->setEntryCount(0);
+ else
+ Callee->setEntryCount(CalleeCount.getValue() - CallCount.getValue());
+}
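The entry-count bookkeeping is exactly the subtraction described in the comment, clamped at zero because the callsite count is only an estimate. A small worked sketch:

  #include <cstdint>
  #include <cstdio>

  // New callee entry count after inlining one callsite: N - M, clamped at 0.
  uint64_t newCalleeCount(uint64_t CalleeCount, uint64_t CallSiteCount) {
    return CallSiteCount > CalleeCount ? 0 : CalleeCount - CallSiteCount;
  }

  int main() {
    std::printf("%llu\n", (unsigned long long)newCalleeCount(1000, 300));  // 700
    std::printf("%llu\n", (unsigned long long)newCalleeCount(1000, 1200)); // 0
  }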
/// This function inlines the called function into the basic block of the
/// caller. This returns false if it is not possible to inline this call.
@@ -1405,13 +1487,13 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI,
bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
AAResults *CalleeAAR, bool InsertLifetime) {
Instruction *TheCall = CS.getInstruction();
- assert(TheCall->getParent() && TheCall->getParent()->getParent() &&
- "Instruction not in function!");
+ assert(TheCall->getParent() && TheCall->getFunction()
+ && "Instruction not in function!");
// If IFI has any state in it, zap it before we fill it in.
IFI.reset();
-
- const Function *CalledFunc = CS.getCalledFunction();
+
+ Function *CalledFunc = CS.getCalledFunction();
if (!CalledFunc || // Can't inline external function or indirect
CalledFunc->isDeclaration() || // call, or call to a vararg function!
CalledFunc->getFunctionType()->isVarArg()) return false;
@@ -1548,7 +1630,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
// matches up the formal to the actual argument values.
CallSite::arg_iterator AI = CS.arg_begin();
unsigned ArgNo = 0;
- for (Function::const_arg_iterator I = CalledFunc->arg_begin(),
+ for (Function::arg_iterator I = CalledFunc->arg_begin(),
E = CalledFunc->arg_end(); I != E; ++I, ++AI, ++ArgNo) {
Value *ActualArg = *AI;
@@ -1578,10 +1660,18 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
CloneAndPruneFunctionInto(Caller, CalledFunc, VMap,
/*ModuleLevelChanges=*/false, Returns, ".i",
&InlinedFunctionInfo, TheCall);
-
// Remember the first block that is newly cloned over.
FirstNewBlock = LastBlock; ++FirstNewBlock;
+ if (IFI.CallerBFI != nullptr && IFI.CalleeBFI != nullptr)
+ // Update the BFI of blocks cloned into the caller.
+ updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI,
+ CalledFunc->front());
+
+ updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), TheCall);
+ // Update the profile count of callee.
+ updateCalleeCount(IFI.CallerBFI, OrigBB, TheCall, CalledFunc);
+
// Inject byval arguments initialization.
for (std::pair<Value*, Value*> &Init : ByValInit)
HandleByValArgumentInit(Init.first, Init.second, Caller->getParent(),
@@ -2087,6 +2177,12 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
CalledFunc->getName() + ".exit");
}
+ if (IFI.CallerBFI) {
+ // Copy original BB's block frequency to AfterCallBB
+ IFI.CallerBFI->setBlockFreq(
+ AfterCallBB, IFI.CallerBFI->getBlockFreq(OrigBB).getFrequency());
+ }
+
// Change the branch that used to go to AfterCallBB to branch to the first
// basic block of the inlined function.
//
diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp
index 68c6b74d5e5b7..49b4bd92faf4b 100644
--- a/lib/Transforms/Utils/LCSSA.cpp
+++ b/lib/Transforms/Utils/LCSSA.cpp
@@ -87,7 +87,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
Instruction *I = Worklist.pop_back_val();
BasicBlock *InstBB = I->getParent();
Loop *L = LI.getLoopFor(InstBB);
- if (!LoopExitBlocks.count(L))
+ assert(L && "Instruction belongs to a BB that's not part of a loop");
+ if (!LoopExitBlocks.count(L))
L->getExitBlocks(LoopExitBlocks[L]);
assert(LoopExitBlocks.count(L));
const SmallVectorImpl<BasicBlock *> &ExitBlocks = LoopExitBlocks[L];
@@ -105,7 +106,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
for (Use &U : I->uses()) {
Instruction *User = cast<Instruction>(U.getUser());
BasicBlock *UserBB = User->getParent();
- if (PHINode *PN = dyn_cast<PHINode>(User))
+ if (auto *PN = dyn_cast<PHINode>(User))
UserBB = PN->getIncomingBlock(U);
if (InstBB != UserBB && !L->contains(UserBB))
@@ -123,7 +124,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
// DomBB dominates the value, so adjust DomBB to the normal destination
// block, which is effectively where the value is first usable.
BasicBlock *DomBB = InstBB;
- if (InvokeInst *Inv = dyn_cast<InvokeInst>(I))
+ if (auto *Inv = dyn_cast<InvokeInst>(I))
DomBB = Inv->getNormalDest();
DomTreeNode *DomNode = DT.getNode(DomBB);
@@ -188,7 +189,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
// block.
Instruction *User = cast<Instruction>(UseToRewrite->getUser());
BasicBlock *UserBB = User->getParent();
- if (PHINode *PN = dyn_cast<PHINode>(User))
+ if (auto *PN = dyn_cast<PHINode>(User))
UserBB = PN->getIncomingBlock(*UseToRewrite);
if (isa<PHINode>(UserBB->begin()) && isExitBlock(UserBB, ExitBlocks)) {
@@ -237,40 +238,75 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
return Changed;
}
-/// Return true if the specified block dominates at least
-/// one of the blocks in the specified list.
-static bool
-blockDominatesAnExit(BasicBlock *BB,
- DominatorTree &DT,
- const SmallVectorImpl<BasicBlock *> &ExitBlocks) {
- DomTreeNode *DomNode = DT.getNode(BB);
- return any_of(ExitBlocks, [&](BasicBlock *EB) {
- return DT.dominates(DomNode, DT.getNode(EB));
- });
+// Compute the set of BasicBlocks in the loop `L` dominating at least one exit.
+static void computeBlocksDominatingExits(
+ Loop &L, DominatorTree &DT, SmallVector<BasicBlock *, 8> &ExitBlocks,
+ SmallPtrSet<BasicBlock *, 8> &BlocksDominatingExits) {
+ SmallVector<BasicBlock *, 8> BBWorklist;
+
+ // We start from the exit blocks, as every block trivially dominates itself
+ // (not strictly).
+ for (BasicBlock *BB : ExitBlocks)
+ BBWorklist.push_back(BB);
+
+ while (!BBWorklist.empty()) {
+ BasicBlock *BB = BBWorklist.pop_back_val();
+
+ // Check if this is a loop header. If this is the case, we're done.
+ if (L.getHeader() == BB)
+ continue;
+
+ // Otherwise, add its immediate predecessor in the dominator tree to the
+ // worklist, unless we visited it already.
+ BasicBlock *IDomBB = DT.getNode(BB)->getIDom()->getBlock();
+
+ // Exit blocks can have an immediate dominator not belonging to the
+ // loop. If an exit block is immediately dominated by a block outside the
+ // loop, then not all paths from that dominator to the exit block go
+ // through the loop.
+ // Example:
+ //
+ // |---- A
+ // | |
+ // | B<--
+ // | | |
+ // |---> C --
+ // |
+ // D
+ //
+ // C is the exit block of the loop and it's immediately dominated by A,
+ // which doesn't belong to the loop.
+ if (!L.contains(IDomBB))
+ continue;
+
+ if (BlocksDominatingExits.insert(IDomBB).second)
+ BBWorklist.push_back(IDomBB);
+ }
}
bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI,
ScalarEvolution *SE) {
bool Changed = false;
- // Get the set of exiting blocks.
SmallVector<BasicBlock *, 8> ExitBlocks;
L.getExitBlocks(ExitBlocks);
-
if (ExitBlocks.empty())
return false;
+ SmallPtrSet<BasicBlock *, 8> BlocksDominatingExits;
+
+ // We want to avoid use-scanning by leveraging dominance information: if a
+ // block doesn't dominate any of the loop exits, then none of the values
+ // defined in that block can be used outside the loop.
+ // We compute the set of blocks fulfilling this condition in advance by
+ // walking the dominator tree upwards until we hit a loop header.
+ computeBlocksDominatingExits(L, DT, ExitBlocks, BlocksDominatingExits);
+
SmallVector<Instruction *, 8> Worklist;
// Look at all the instructions in the loop, checking to see if they have uses
// outside the loop. If so, put them into the worklist to rewrite those uses.
- for (BasicBlock *BB : L.blocks()) {
- // For large loops, avoid use-scanning by using dominance information: In
- // particular, if a block does not dominate any of the loop exits, then none
- // of the values defined in the block could be used outside the loop.
- if (!blockDominatesAnExit(BB, DT, ExitBlocks))
- continue;
-
+ for (BasicBlock *BB : BlocksDominatingExits) {
for (Instruction &I : *BB) {
// Reject two common cases fast: instructions with no uses (like stores)
// and instructions with one use that is in the same block as this.
@@ -395,8 +431,8 @@ PreservedAnalyses LCSSAPass::run(Function &F, FunctionAnalysisManager &AM) {
if (!formLCSSAOnAllLoops(&LI, DT, SE))
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<BasicAA>();
PA.preserve<GlobalsAA>();
PA.preserve<SCEVAA>();
diff --git a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
index d97cd7582eaae..fe93d6927c638 100644
--- a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
+++ b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
@@ -100,12 +100,12 @@ private:
bool perform(CallInst *CI);
void checkCandidate(CallInst &CI);
void shrinkWrapCI(CallInst *CI, Value *Cond);
- bool performCallDomainErrorOnly(CallInst *CI, const LibFunc::Func &Func);
- bool performCallErrors(CallInst *CI, const LibFunc::Func &Func);
- bool performCallRangeErrorOnly(CallInst *CI, const LibFunc::Func &Func);
- Value *generateOneRangeCond(CallInst *CI, const LibFunc::Func &Func);
- Value *generateTwoRangeCond(CallInst *CI, const LibFunc::Func &Func);
- Value *generateCondForPow(CallInst *CI, const LibFunc::Func &Func);
+ bool performCallDomainErrorOnly(CallInst *CI, const LibFunc &Func);
+ bool performCallErrors(CallInst *CI, const LibFunc &Func);
+ bool performCallRangeErrorOnly(CallInst *CI, const LibFunc &Func);
+ Value *generateOneRangeCond(CallInst *CI, const LibFunc &Func);
+ Value *generateTwoRangeCond(CallInst *CI, const LibFunc &Func);
+ Value *generateCondForPow(CallInst *CI, const LibFunc &Func);
// Create an OR of two conditions.
Value *createOrCond(CallInst *CI, CmpInst::Predicate Cmp, float Val,
@@ -141,44 +141,44 @@ private:
// Perform the transformation to calls with errno set by domain error.
bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI,
- const LibFunc::Func &Func) {
+ const LibFunc &Func) {
Value *Cond = nullptr;
switch (Func) {
- case LibFunc::acos: // DomainError: (x < -1 || x > 1)
- case LibFunc::acosf: // Same as acos
- case LibFunc::acosl: // Same as acos
- case LibFunc::asin: // DomainError: (x < -1 || x > 1)
- case LibFunc::asinf: // Same as asin
- case LibFunc::asinl: // Same as asin
+ case LibFunc_acos: // DomainError: (x < -1 || x > 1)
+ case LibFunc_acosf: // Same as acos
+ case LibFunc_acosl: // Same as acos
+ case LibFunc_asin: // DomainError: (x < -1 || x > 1)
+ case LibFunc_asinf: // Same as asin
+ case LibFunc_asinl: // Same as asin
{
++NumWrappedTwoCond;
Cond = createOrCond(CI, CmpInst::FCMP_OLT, -1.0f, CmpInst::FCMP_OGT, 1.0f);
break;
}
- case LibFunc::cos: // DomainError: (x == +inf || x == -inf)
- case LibFunc::cosf: // Same as cos
- case LibFunc::cosl: // Same as cos
- case LibFunc::sin: // DomainError: (x == +inf || x == -inf)
- case LibFunc::sinf: // Same as sin
- case LibFunc::sinl: // Same as sin
+ case LibFunc_cos: // DomainError: (x == +inf || x == -inf)
+ case LibFunc_cosf: // Same as cos
+ case LibFunc_cosl: // Same as cos
+ case LibFunc_sin: // DomainError: (x == +inf || x == -inf)
+ case LibFunc_sinf: // Same as sin
+ case LibFunc_sinl: // Same as sin
{
++NumWrappedTwoCond;
Cond = createOrCond(CI, CmpInst::FCMP_OEQ, INFINITY, CmpInst::FCMP_OEQ,
-INFINITY);
break;
}
- case LibFunc::acosh: // DomainError: (x < 1)
- case LibFunc::acoshf: // Same as acosh
- case LibFunc::acoshl: // Same as acosh
+ case LibFunc_acosh: // DomainError: (x < 1)
+ case LibFunc_acoshf: // Same as acosh
+ case LibFunc_acoshl: // Same as acosh
{
++NumWrappedOneCond;
Cond = createCond(CI, CmpInst::FCMP_OLT, 1.0f);
break;
}
- case LibFunc::sqrt: // DomainError: (x < 0)
- case LibFunc::sqrtf: // Same as sqrt
- case LibFunc::sqrtl: // Same as sqrt
+ case LibFunc_sqrt: // DomainError: (x < 0)
+ case LibFunc_sqrtf: // Same as sqrt
+ case LibFunc_sqrtl: // Same as sqrt
{
++NumWrappedOneCond;
Cond = createCond(CI, CmpInst::FCMP_OLT, 0.0f);
@@ -193,31 +193,31 @@ bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI,
// Perform the transformation to calls with errno set by range error.
bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI,
- const LibFunc::Func &Func) {
+ const LibFunc &Func) {
Value *Cond = nullptr;
switch (Func) {
- case LibFunc::cosh:
- case LibFunc::coshf:
- case LibFunc::coshl:
- case LibFunc::exp:
- case LibFunc::expf:
- case LibFunc::expl:
- case LibFunc::exp10:
- case LibFunc::exp10f:
- case LibFunc::exp10l:
- case LibFunc::exp2:
- case LibFunc::exp2f:
- case LibFunc::exp2l:
- case LibFunc::sinh:
- case LibFunc::sinhf:
- case LibFunc::sinhl: {
+ case LibFunc_cosh:
+ case LibFunc_coshf:
+ case LibFunc_coshl:
+ case LibFunc_exp:
+ case LibFunc_expf:
+ case LibFunc_expl:
+ case LibFunc_exp10:
+ case LibFunc_exp10f:
+ case LibFunc_exp10l:
+ case LibFunc_exp2:
+ case LibFunc_exp2f:
+ case LibFunc_exp2l:
+ case LibFunc_sinh:
+ case LibFunc_sinhf:
+ case LibFunc_sinhl: {
Cond = generateTwoRangeCond(CI, Func);
break;
}
- case LibFunc::expm1: // RangeError: (709, inf)
- case LibFunc::expm1f: // RangeError: (88, inf)
- case LibFunc::expm1l: // RangeError: (11356, inf)
+ case LibFunc_expm1: // RangeError: (709, inf)
+ case LibFunc_expm1f: // RangeError: (88, inf)
+ case LibFunc_expm1l: // RangeError: (11356, inf)
{
Cond = generateOneRangeCond(CI, Func);
break;
@@ -231,15 +231,15 @@ bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI,
// Perform the transformation to calls with errno set by combination of errors.
bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,
- const LibFunc::Func &Func) {
+ const LibFunc &Func) {
Value *Cond = nullptr;
switch (Func) {
- case LibFunc::atanh: // DomainError: (x < -1 || x > 1)
+ case LibFunc_atanh: // DomainError: (x < -1 || x > 1)
// PoleError: (x == -1 || x == 1)
// Overall Cond: (x <= -1 || x >= 1)
- case LibFunc::atanhf: // Same as atanh
- case LibFunc::atanhl: // Same as atanh
+ case LibFunc_atanhf: // Same as atanh
+ case LibFunc_atanhl: // Same as atanh
{
if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError)
return false;
@@ -247,20 +247,20 @@ bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,
Cond = createOrCond(CI, CmpInst::FCMP_OLE, -1.0f, CmpInst::FCMP_OGE, 1.0f);
break;
}
- case LibFunc::log: // DomainError: (x < 0)
+ case LibFunc_log: // DomainError: (x < 0)
// PoleError: (x == 0)
// Overall Cond: (x <= 0)
- case LibFunc::logf: // Same as log
- case LibFunc::logl: // Same as log
- case LibFunc::log10: // Same as log
- case LibFunc::log10f: // Same as log
- case LibFunc::log10l: // Same as log
- case LibFunc::log2: // Same as log
- case LibFunc::log2f: // Same as log
- case LibFunc::log2l: // Same as log
- case LibFunc::logb: // Same as log
- case LibFunc::logbf: // Same as log
- case LibFunc::logbl: // Same as log
+ case LibFunc_logf: // Same as log
+ case LibFunc_logl: // Same as log
+ case LibFunc_log10: // Same as log
+ case LibFunc_log10f: // Same as log
+ case LibFunc_log10l: // Same as log
+ case LibFunc_log2: // Same as log
+ case LibFunc_log2f: // Same as log
+ case LibFunc_log2l: // Same as log
+ case LibFunc_logb: // Same as log
+ case LibFunc_logbf: // Same as log
+ case LibFunc_logbl: // Same as log
{
if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError)
return false;
@@ -268,11 +268,11 @@ bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,
Cond = createCond(CI, CmpInst::FCMP_OLE, 0.0f);
break;
}
- case LibFunc::log1p: // DomainError: (x < -1)
+ case LibFunc_log1p: // DomainError: (x < -1)
// PoleError: (x == -1)
// Overall Cond: (x <= -1)
- case LibFunc::log1pf: // Same as log1p
- case LibFunc::log1pl: // Same as log1p
+ case LibFunc_log1pf: // Same as log1p
+ case LibFunc_log1pl: // Same as log1p
{
if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError)
return false;
@@ -280,11 +280,11 @@ bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,
Cond = createCond(CI, CmpInst::FCMP_OLE, -1.0f);
break;
}
- case LibFunc::pow: // DomainError: x < 0 and y is noninteger
+ case LibFunc_pow: // DomainError: x < 0 and y is noninteger
// PoleError: x == 0 and y < 0
// RangeError: overflow or underflow
- case LibFunc::powf:
- case LibFunc::powl: {
+ case LibFunc_powf:
+ case LibFunc_powl: {
if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError ||
!LibCallsShrinkWrapDoRangeError)
return false;
@@ -313,7 +313,7 @@ void LibCallsShrinkWrap::checkCandidate(CallInst &CI) {
if (!CI.use_empty())
return;
- LibFunc::Func Func;
+ LibFunc Func;
Function *Callee = CI.getCalledFunction();
if (!Callee)
return;
@@ -333,16 +333,16 @@ void LibCallsShrinkWrap::checkCandidate(CallInst &CI) {
// Generate the upper bound condition for RangeError.
Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI,
- const LibFunc::Func &Func) {
+ const LibFunc &Func) {
float UpperBound;
switch (Func) {
- case LibFunc::expm1: // RangeError: (709, inf)
+ case LibFunc_expm1: // RangeError: (709, inf)
UpperBound = 709.0f;
break;
- case LibFunc::expm1f: // RangeError: (88, inf)
+ case LibFunc_expm1f: // RangeError: (88, inf)
UpperBound = 88.0f;
break;
- case LibFunc::expm1l: // RangeError: (11356, inf)
+ case LibFunc_expm1l: // RangeError: (11356, inf)
UpperBound = 11356.0f;
break;
default:
@@ -355,57 +355,57 @@ Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI,
// Generate the lower and upper bound condition for RangeError.
Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI,
- const LibFunc::Func &Func) {
+ const LibFunc &Func) {
float UpperBound, LowerBound;
switch (Func) {
- case LibFunc::cosh: // RangeError: (x < -710 || x > 710)
- case LibFunc::sinh: // Same as cosh
+ case LibFunc_cosh: // RangeError: (x < -710 || x > 710)
+ case LibFunc_sinh: // Same as cosh
LowerBound = -710.0f;
UpperBound = 710.0f;
break;
- case LibFunc::coshf: // RangeError: (x < -89 || x > 89)
- case LibFunc::sinhf: // Same as coshf
+ case LibFunc_coshf: // RangeError: (x < -89 || x > 89)
+ case LibFunc_sinhf: // Same as coshf
LowerBound = -89.0f;
UpperBound = 89.0f;
break;
- case LibFunc::coshl: // RangeError: (x < -11357 || x > 11357)
- case LibFunc::sinhl: // Same as coshl
+ case LibFunc_coshl: // RangeError: (x < -11357 || x > 11357)
+ case LibFunc_sinhl: // Same as coshl
LowerBound = -11357.0f;
UpperBound = 11357.0f;
break;
- case LibFunc::exp: // RangeError: (x < -745 || x > 709)
+ case LibFunc_exp: // RangeError: (x < -745 || x > 709)
LowerBound = -745.0f;
UpperBound = 709.0f;
break;
- case LibFunc::expf: // RangeError: (x < -103 || x > 88)
+ case LibFunc_expf: // RangeError: (x < -103 || x > 88)
LowerBound = -103.0f;
UpperBound = 88.0f;
break;
- case LibFunc::expl: // RangeError: (x < -11399 || x > 11356)
+ case LibFunc_expl: // RangeError: (x < -11399 || x > 11356)
LowerBound = -11399.0f;
UpperBound = 11356.0f;
break;
- case LibFunc::exp10: // RangeError: (x < -323 || x > 308)
+ case LibFunc_exp10: // RangeError: (x < -323 || x > 308)
LowerBound = -323.0f;
UpperBound = 308.0f;
break;
- case LibFunc::exp10f: // RangeError: (x < -45 || x > 38)
+ case LibFunc_exp10f: // RangeError: (x < -45 || x > 38)
LowerBound = -45.0f;
UpperBound = 38.0f;
break;
- case LibFunc::exp10l: // RangeError: (x < -4950 || x > 4932)
+ case LibFunc_exp10l: // RangeError: (x < -4950 || x > 4932)
LowerBound = -4950.0f;
UpperBound = 4932.0f;
break;
- case LibFunc::exp2: // RangeError: (x < -1074 || x > 1023)
+ case LibFunc_exp2: // RangeError: (x < -1074 || x > 1023)
LowerBound = -1074.0f;
UpperBound = 1023.0f;
break;
- case LibFunc::exp2f: // RangeError: (x < -149 || x > 127)
+ case LibFunc_exp2f: // RangeError: (x < -149 || x > 127)
LowerBound = -149.0f;
UpperBound = 127.0f;
break;
- case LibFunc::exp2l: // RangeError: (x < -16445 || x > 11383)
+ case LibFunc_exp2l: // RangeError: (x < -16445 || x > 11383)
LowerBound = -16445.0f;
UpperBound = 11383.0f;
break;
@@ -434,9 +434,9 @@ Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI,
// (i.e. we might invoke the calls that will not set the errno.).
//
Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI,
- const LibFunc::Func &Func) {
- // FIXME: LibFunc::powf and powl TBD.
- if (Func != LibFunc::pow) {
+ const LibFunc &Func) {
+ // FIXME: LibFunc_powf and powl TBD.
+ if (Func != LibFunc_pow) {
DEBUG(dbgs() << "Not handled powf() and powl()\n");
return nullptr;
}
@@ -516,7 +516,7 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) {
// Perform the transformation to a single candidate.
bool LibCallsShrinkWrap::perform(CallInst *CI) {
- LibFunc::Func Func;
+ LibFunc Func;
Function *Callee = CI->getCalledFunction();
assert(Callee && "perform() should apply to a non-empty callee");
TLI.getLibFunc(*Callee, Func);
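
The hunks above migrate LibCallsShrinkWrap from the old nested LibFunc::Func enum to the new top-level LibFunc enum. A minimal sketch of the post-rename TargetLibraryInfo query, assuming a standalone helper (isWrappableLibCall is an illustrative name, not part of the patch):

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Sketch only: the getLibFunc() call is unchanged; only the enum spelling is.
static bool isWrappableLibCall(CallInst &CI, const TargetLibraryInfo &TLI) {
  Function *Callee = CI.getCalledFunction();
  if (!Callee)
    return false;
  LibFunc Func;                         // was: LibFunc::Func Func;
  if (!TLI.getLibFunc(*Callee, Func))
    return false;
  return Func == LibFunc_exp || Func == LibFunc_pow; // was: LibFunc::exp, ...
}
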
diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp
index 6e4174aa0cda1..18b29226c2ef5 100644
--- a/lib/Transforms/Utils/Local.cpp
+++ b/lib/Transforms/Utils/Local.cpp
@@ -126,21 +126,20 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// If the default is unreachable, ignore it when searching for TheOnlyDest.
if (isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()) &&
SI->getNumCases() > 0) {
- TheOnlyDest = SI->case_begin().getCaseSuccessor();
+ TheOnlyDest = SI->case_begin()->getCaseSuccessor();
}
// Figure out which case it goes to.
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
- i != e; ++i) {
+ for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) {
// Found case matching a constant operand?
- if (i.getCaseValue() == CI) {
- TheOnlyDest = i.getCaseSuccessor();
+ if (i->getCaseValue() == CI) {
+ TheOnlyDest = i->getCaseSuccessor();
break;
}
// Check to see if this branch is going to the same place as the default
// dest. If so, eliminate it as an explicit compare.
- if (i.getCaseSuccessor() == DefaultDest) {
+ if (i->getCaseSuccessor() == DefaultDest) {
MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
unsigned NCases = SI->getNumCases();
// Fold the case metadata into the default if there will be any branches
@@ -154,7 +153,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
Weights.push_back(CI->getValue().getZExtValue());
}
// Merge weight of this case to the default weight.
- unsigned idx = i.getCaseIndex();
+ unsigned idx = i->getCaseIndex();
Weights[0] += Weights[idx+1];
// Remove weight for this case.
std::swap(Weights[idx+1], Weights.back());
@@ -165,15 +164,19 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
}
// Remove this entry.
DefaultDest->removePredecessor(SI->getParent());
- SI->removeCase(i);
- --i; --e;
+ i = SI->removeCase(i);
+ e = SI->case_end();
continue;
}
// Otherwise, check to see if the switch only branches to one destination.
// We do this by reseting "TheOnlyDest" to null when we find two non-equal
// destinations.
- if (i.getCaseSuccessor() != TheOnlyDest) TheOnlyDest = nullptr;
+ if (i->getCaseSuccessor() != TheOnlyDest)
+ TheOnlyDest = nullptr;
+
+ // Increment this iterator as we haven't removed the case.
+ ++i;
}
if (CI && !TheOnlyDest) {
@@ -209,7 +212,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
if (SI->getNumCases() == 1) {
// Otherwise, we can fold this switch into a conditional branch
// instruction if it has only one non-default destination.
- SwitchInst::CaseIt FirstCase = SI->case_begin();
+ auto FirstCase = *SI->case_begin();
Value *Cond = Builder.CreateICmpEQ(SI->getCondition(),
FirstCase.getCaseValue(), "cond");
@@ -287,7 +290,15 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
///
bool llvm::isInstructionTriviallyDead(Instruction *I,
const TargetLibraryInfo *TLI) {
- if (!I->use_empty() || isa<TerminatorInst>(I)) return false;
+ if (!I->use_empty())
+ return false;
+ return wouldInstructionBeTriviallyDead(I, TLI);
+}
+
+bool llvm::wouldInstructionBeTriviallyDead(Instruction *I,
+ const TargetLibraryInfo *TLI) {
+ if (isa<TerminatorInst>(I))
+ return false;
// We don't want the landingpad-like instructions removed by anything this
// general.
@@ -307,7 +318,8 @@ bool llvm::isInstructionTriviallyDead(Instruction *I,
return true;
}
- if (!I->mayHaveSideEffects()) return true;
+ if (!I->mayHaveSideEffects())
+ return true;
// Special case intrinsics that "may have side effects" but can be deleted
// when dead.
@@ -334,7 +346,8 @@ bool llvm::isInstructionTriviallyDead(Instruction *I,
}
}
- if (isAllocLikeFn(I, TLI)) return true;
+ if (isAllocLikeFn(I, TLI))
+ return true;
if (CallInst *CI = isFreeCall(I, TLI))
if (Constant *C = dyn_cast<Constant>(CI->getArgOperand(0)))
@@ -1075,11 +1088,11 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar,
// Since we can't guarantee that the original dbg.declare instrinsic
// is removed by LowerDbgDeclare(), we need to make sure that we are
// not inserting the same dbg.value intrinsic over and over.
- DbgValueList DbgValues;
- FindAllocaDbgValues(DbgValues, APN);
- for (auto DVI : DbgValues) {
- assert (DVI->getValue() == APN);
- assert (DVI->getOffset() == 0);
+ SmallVector<DbgValueInst *, 1> DbgValues;
+ findDbgValues(DbgValues, APN);
+ for (auto *DVI : DbgValues) {
+ assert(DVI->getValue() == APN);
+ assert(DVI->getOffset() == 0);
if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr))
return true;
}
@@ -1241,9 +1254,7 @@ DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) {
return nullptr;
}
-/// FindAllocaDbgValues - Finds the llvm.dbg.value intrinsics describing the
-/// alloca 'V', if any.
-void llvm::FindAllocaDbgValues(DbgValueList &DbgValues, Value *V) {
+void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) {
if (auto *L = LocalAsMetadata::getIfExists(V))
if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L))
for (User *U : MDV->users())
@@ -1251,36 +1262,32 @@ void llvm::FindAllocaDbgValues(DbgValueList &DbgValues, Value *V) {
DbgValues.push_back(DVI);
}
-static void DIExprAddDeref(SmallVectorImpl<uint64_t> &Expr) {
- Expr.push_back(dwarf::DW_OP_deref);
-}
-
-static void DIExprAddOffset(SmallVectorImpl<uint64_t> &Expr, int Offset) {
+static void appendOffset(SmallVectorImpl<uint64_t> &Ops, int64_t Offset) {
if (Offset > 0) {
- Expr.push_back(dwarf::DW_OP_plus);
- Expr.push_back(Offset);
+ Ops.push_back(dwarf::DW_OP_plus);
+ Ops.push_back(Offset);
} else if (Offset < 0) {
- Expr.push_back(dwarf::DW_OP_minus);
- Expr.push_back(-Offset);
+ Ops.push_back(dwarf::DW_OP_minus);
+ Ops.push_back(-Offset);
}
}
-static DIExpression *BuildReplacementDIExpr(DIBuilder &Builder,
- DIExpression *DIExpr, bool Deref,
- int Offset) {
+/// Prepend \p DIExpr with a deref and offset operation.
+static DIExpression *prependDIExpr(DIBuilder &Builder, DIExpression *DIExpr,
+ bool Deref, int64_t Offset) {
if (!Deref && !Offset)
return DIExpr;
// Create a copy of the original DIDescriptor for user variable, prepending
// "deref" operation to a list of address elements, as new llvm.dbg.declare
// will take a value storing address of the memory for variable, not
// alloca itself.
- SmallVector<uint64_t, 4> NewDIExpr;
+ SmallVector<uint64_t, 4> Ops;
if (Deref)
- DIExprAddDeref(NewDIExpr);
- DIExprAddOffset(NewDIExpr, Offset);
+ Ops.push_back(dwarf::DW_OP_deref);
+ appendOffset(Ops, Offset);
if (DIExpr)
- NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end());
- return Builder.createExpression(NewDIExpr);
+ Ops.append(DIExpr->elements_begin(), DIExpr->elements_end());
+ return Builder.createExpression(Ops);
}
bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress,
@@ -1294,7 +1301,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress,
auto *DIExpr = DDI->getExpression();
assert(DIVar && "Missing variable");
- DIExpr = BuildReplacementDIExpr(Builder, DIExpr, Deref, Offset);
+ DIExpr = prependDIExpr(Builder, DIExpr, Deref, Offset);
// Insert llvm.dbg.declare immediately after the original alloca, and remove
// old llvm.dbg.declare.
@@ -1326,11 +1333,11 @@ static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress,
// Insert the offset immediately after the first deref.
// We could just change the offset argument of dbg.value, but it's unsigned...
if (Offset) {
- SmallVector<uint64_t, 4> NewDIExpr;
- DIExprAddDeref(NewDIExpr);
- DIExprAddOffset(NewDIExpr, Offset);
- NewDIExpr.append(DIExpr->elements_begin() + 1, DIExpr->elements_end());
- DIExpr = Builder.createExpression(NewDIExpr);
+ SmallVector<uint64_t, 4> Ops;
+ Ops.push_back(dwarf::DW_OP_deref);
+ appendOffset(Ops, Offset);
+ Ops.append(DIExpr->elements_begin() + 1, DIExpr->elements_end());
+ DIExpr = Builder.createExpression(Ops);
}
Builder.insertDbgValueIntrinsic(NewAddress, DVI->getOffset(), DIVar, DIExpr,
@@ -1349,6 +1356,53 @@ void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress,
}
}
+void llvm::salvageDebugInfo(Instruction &I) {
+ SmallVector<DbgValueInst *, 1> DbgValues;
+ auto &M = *I.getModule();
+
+ auto MDWrap = [&](Value *V) {
+ return MetadataAsValue::get(I.getContext(), ValueAsMetadata::get(V));
+ };
+
+ if (isa<BitCastInst>(&I)) {
+ findDbgValues(DbgValues, &I);
+ for (auto *DVI : DbgValues) {
+ // Bitcasts are entirely irrelevant for debug info. Rewrite the dbg.value
+ // to use the cast's source.
+ DVI->setOperand(0, MDWrap(I.getOperand(0)));
+ DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n');
+ }
+ } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
+ findDbgValues(DbgValues, &I);
+ for (auto *DVI : DbgValues) {
+ unsigned BitWidth =
+ M.getDataLayout().getPointerSizeInBits(GEP->getPointerAddressSpace());
+ APInt Offset(BitWidth, 0);
+ // Rewrite a constant GEP into a DIExpression.
+ if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) {
+ auto *DIExpr = DVI->getExpression();
+ DIBuilder DIB(M, /*AllowUnresolved*/ false);
+        // GEP offsets are i32 and thus always fit into an int64_t.
+ DIExpr = prependDIExpr(DIB, DIExpr, NoDeref, Offset.getSExtValue());
+ DVI->setOperand(0, MDWrap(I.getOperand(0)));
+ DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr));
+ DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n');
+ }
+ }
+ } else if (isa<LoadInst>(&I)) {
+ findDbgValues(DbgValues, &I);
+ for (auto *DVI : DbgValues) {
+ // Rewrite the load into DW_OP_deref.
+ auto *DIExpr = DVI->getExpression();
+ DIBuilder DIB(M, /*AllowUnresolved*/ false);
+ DIExpr = prependDIExpr(DIB, DIExpr, WithDeref, 0);
+ DVI->setOperand(0, MDWrap(I.getOperand(0)));
+ DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr));
+ DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n');
+ }
+ }
+}
+
unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) {
unsigned NumDeadInst = 0;
// Delete the instructions backwards, as it has a reduced likelihood of
@@ -2068,9 +2122,9 @@ bool llvm::recognizeBSwapOrBitReverseIdiom(
void llvm::maybeMarkSanitizerLibraryCallNoBuiltin(
CallInst *CI, const TargetLibraryInfo *TLI) {
Function *F = CI->getCalledFunction();
- LibFunc::Func Func;
+ LibFunc Func;
if (F && !F->hasLocalLinkage() && F->hasName() &&
TLI->getLibFunc(F->getName(), Func) && TLI->hasOptimizedCodeGen(Func) &&
!F->doesNotAccessMemory())
- CI->addAttribute(AttributeSet::FunctionIndex, Attribute::NoBuiltin);
+ CI->addAttribute(AttributeList::FunctionIndex, Attribute::NoBuiltin);
}
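
The ConstantFoldTerminator hunk above adopts the new SwitchInst case iterator, which dereferences to a case handle and whose removeCase() returns the iterator to the following case. A condensed sketch of the erase-while-iterating idiom it relies on (DeadBB and dropCasesTo are hypothetical, not from the patch):

#include "llvm/IR/Instructions.h"
using namespace llvm;

// Sketch only: remove every case that branches to DeadBB using the
// erase-and-advance contract of SwitchInst::removeCase().
static void dropCasesTo(SwitchInst *SI, BasicBlock *DeadBB) {
  for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) {
    if (I->getCaseSuccessor() == DeadBB) {
      DeadBB->removePredecessor(SI->getParent());
      I = SI->removeCase(I);  // returns the iterator to the next case
      E = SI->case_end();     // re-fetch end; it may have moved
    } else {
      ++I;                    // nothing removed; step forward normally
    }
  }
}
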
diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp
index 00cda2af00c62..e7ba19665d591 100644
--- a/lib/Transforms/Utils/LoopSimplify.cpp
+++ b/lib/Transforms/Utils/LoopSimplify.cpp
@@ -645,14 +645,7 @@ ReprocessLoop:
// loop-invariant instructions out of the way to open up more
// opportunities, and the disadvantage of having the responsibility
// to preserve dominator information.
- bool UniqueExit = true;
- if (!ExitBlocks.empty())
- for (unsigned i = 1, e = ExitBlocks.size(); i != e; ++i)
- if (ExitBlocks[i] != ExitBlocks[0]) {
- UniqueExit = false;
- break;
- }
- if (UniqueExit) {
+ if (ExitBlockSet.size() == 1) {
for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
BasicBlock *ExitingBlock = ExitingBlocks[i];
if (!ExitingBlock->getSinglePredecessor()) continue;
@@ -735,6 +728,17 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI,
bool PreserveLCSSA) {
bool Changed = false;
+#ifndef NDEBUG
+ // If we're asked to preserve LCSSA, the loop nest needs to start in LCSSA
+ // form.
+ if (PreserveLCSSA) {
+ assert(DT && "DT not available.");
+ assert(LI && "LI not available.");
+ assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
+ "Requested to preserve LCSSA, but it's already broken.");
+ }
+#endif
+
// Worklist maintains our depth-first queue of loops in this nest to process.
SmallVector<Loop *, 4> Worklist;
Worklist.push_back(L);
@@ -814,15 +818,6 @@ bool LoopSimplify::runOnFunction(Function &F) {
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
-#ifndef NDEBUG
- if (PreserveLCSSA) {
- assert(DT && "DT not available.");
- assert(LI && "LI not available.");
- bool InLCSSA = all_of(
- *LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT, *LI); });
- assert(InLCSSA && "Requested to preserve LCSSA, but it's already broken.");
- }
-#endif
// Simplify each loop nest in the function.
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
@@ -846,17 +841,14 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F,
ScalarEvolution *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F);
AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
- // FIXME: This pass should verify that the loops on which it's operating
- // are in canonical SSA form, and that the pass itself preserves this form.
+  // Note that we don't preserve LCSSA in the new PM; if you need it, run the
+  // LCSSA pass after simplifying the loops.
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
- Changed |= simplifyLoop(*I, DT, LI, SE, AC, true /* PreserveLCSSA */);
-
- // FIXME: We need to invalidate this to avoid PR28400. Is there a better
- // solution?
- AM.invalidate<ScalarEvolutionAnalysis>(F);
+ Changed |= simplifyLoop(*I, DT, LI, SE, AC, /*PreserveLCSSA*/ false);
if (!Changed)
return PreservedAnalyses::all();
+
PreservedAnalyses PA;
PA.preserve<DominatorTreeAnalysis>();
PA.preserve<LoopAnalysis>();
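
Because the new-pass-manager LoopSimplifyPass above no longer claims to preserve LCSSA, a pipeline that needs LCSSA form afterwards has to rebuild it explicitly. A small scheduling sketch, assuming the standard LCSSAPass is used (names and headers are the usual new-PM ones, not taken from this patch):

#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Utils/LCSSA.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
using namespace llvm;

// Sketch only: simplify the loop nests, then re-establish LCSSA for any later
// loop passes that require it.
static FunctionPassManager buildLoopCanonicalizationPipeline() {
  FunctionPassManager FPM;
  FPM.addPass(LoopSimplifyPass());
  FPM.addPass(LCSSAPass()); // LoopSimplifyPass no longer preserves LCSSA
  return FPM;
}
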
diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp
index e346ebd6a000a..3c669ce644e20 100644
--- a/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/lib/Transforms/Utils/LoopUnroll.cpp
@@ -27,6 +27,7 @@
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
@@ -51,6 +52,16 @@ UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden,
cl::desc("Allow runtime unrolled loops to be unrolled "
"with epilog instead of prolog."));
+static cl::opt<bool>
+UnrollVerifyDomtree("unroll-verify-domtree", cl::Hidden,
+ cl::desc("Verify domtree after unrolling"),
+#ifdef NDEBUG
+ cl::init(false)
+#else
+ cl::init(true)
+#endif
+ );
+
/// Convert the instruction operands from referencing the current values into
/// those specified by VMap.
static inline void remapInstruction(Instruction *I,
@@ -205,6 +216,45 @@ const Loop* llvm::addClonedBlockToLoopInfo(BasicBlock *OriginalBB,
}
}
+/// The function chooses which type of unroll (epilog or prolog) is more
+/// profitable.
+/// Epilog unrolling is more profitable when there is a PHI that starts from a
+/// constant: in that case the epilog keeps the PHI starting from the constant,
+/// whereas the prolog converts it to a non-constant value.
+///
+/// loop:
+/// PN = PHI [I, Latch], [CI, PreHeader]
+/// I = foo(PN)
+/// ...
+///
+/// Epilog unroll case.
+/// loop:
+/// PN = PHI [I2, Latch], [CI, PreHeader]
+/// I1 = foo(PN)
+/// I2 = foo(I1)
+/// ...
+/// Prolog unroll case.
+/// NewPN = PHI [PrologI, Prolog], [CI, PreHeader]
+/// loop:
+/// PN = PHI [I2, Latch], [NewPN, PreHeader]
+/// I1 = foo(PN)
+/// I2 = foo(I1)
+/// ...
+///
+static bool isEpilogProfitable(Loop *L) {
+ BasicBlock *PreHeader = L->getLoopPreheader();
+ BasicBlock *Header = L->getHeader();
+ assert(PreHeader && Header);
+ for (Instruction &BBI : *Header) {
+ PHINode *PN = dyn_cast<PHINode>(&BBI);
+ if (!PN)
+ break;
+ if (isa<ConstantInt>(PN->getIncomingValueForBlock(PreHeader)))
+ return true;
+ }
+ return false;
+}
+
/// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true
/// if unrolling was successful, or false if the loop was unmodified. Unrolling
/// can only fail when the loop's latch block is not terminated by a conditional
@@ -296,8 +346,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
Count = TripCount;
// Don't enter the unroll code if there is nothing to do.
- if (TripCount == 0 && Count < 2 && PeelCount == 0)
+ if (TripCount == 0 && Count < 2 && PeelCount == 0) {
+ DEBUG(dbgs() << "Won't unroll; almost nothing to do\n");
return false;
+ }
assert(Count > 0);
assert(TripMultiple > 0);
@@ -330,7 +382,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
"and peeling for the same loop");
if (PeelCount)
- peelLoop(L, PeelCount, LI, SE, DT, PreserveLCSSA);
+ peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA);
// Loops containing convergent instructions must have a count that divides
// their TripMultiple.
@@ -346,14 +398,22 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
"convergent operation.");
});
+ bool EpilogProfitability =
+ UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog
+ : isEpilogProfitable(L);
+
if (RuntimeTripCount && TripMultiple % Count != 0 &&
!UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount,
- UnrollRuntimeEpilog, LI, SE, DT,
+ EpilogProfitability, LI, SE, DT,
PreserveLCSSA)) {
if (Force)
RuntimeTripCount = false;
- else
+ else {
+ DEBUG(
+          dbgs() << "Won't unroll; remainder loop could not be generated "
+ "when assuming runtime trip count\n");
return false;
+ }
}
// Notify ScalarEvolution that the loop will be substantially changed,
@@ -446,6 +506,12 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
for (Loop *SubLoop : *L)
LoopsToSimplify.insert(SubLoop);
+ if (Header->getParent()->isDebugInfoForProfiling())
+ for (BasicBlock *BB : L->getBlocks())
+ for (Instruction &I : *BB)
+ if (const DILocation *DIL = I.getDebugLoc())
+ I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count));
+
for (unsigned It = 1; It != Count; ++It) {
std::vector<BasicBlock*> NewBlocks;
SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
@@ -456,19 +522,16 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
Header->getParent()->getBasicBlockList().push_back(New);
+ assert((*BB != Header || LI->getLoopFor(*BB) == L) &&
+ "Header should not be in a sub-loop");
// Tell LI about New.
- if (*BB == Header) {
- assert(LI->getLoopFor(*BB) == L && "Header should not be in a sub-loop");
- L->addBasicBlockToLoop(New, *LI);
- } else {
- const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
- if (OldLoop) {
- LoopsToSimplify.insert(NewLoops[OldLoop]);
+ const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
+ if (OldLoop) {
+ LoopsToSimplify.insert(NewLoops[OldLoop]);
- // Forget the old loop, since its inputs may have changed.
- if (SE)
- SE->forgetLoop(OldLoop);
- }
+ // Forget the old loop, since its inputs may have changed.
+ if (SE)
+ SE->forgetLoop(OldLoop);
}
if (*BB == Header)
@@ -615,14 +678,11 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
Term->eraseFromParent();
}
}
+
// Update dominators of blocks we might reach through exits.
// Immediate dominator of such block might change, because we add more
// routes which can lead to the exit: we can now reach it from the copied
- // iterations too. Thus, the new idom of the block will be the nearest
- // common dominator of the previous idom and common dominator of all copies of
- // the previous idom. This is equivalent to the nearest common dominator of
- // the previous idom and the first latch, which dominates all copies of the
- // previous idom.
+ // iterations too.
if (DT && Count > 1) {
for (auto *BB : OriginalLoopBlocks) {
auto *BBDomNode = DT->getNode(BB);
@@ -632,12 +692,38 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
if (!L->contains(ChildBB))
ChildrenToUpdate.push_back(ChildBB);
}
- BasicBlock *NewIDom = DT->findNearestCommonDominator(BB, Latches[0]);
+ BasicBlock *NewIDom;
+ if (BB == LatchBlock) {
+ // The latch is special because we emit unconditional branches in
+ // some cases where the original loop contained a conditional branch.
+ // Since the latch is always at the bottom of the loop, if the latch
+ // dominated an exit before unrolling, the new dominator of that exit
+ // must also be a latch. Specifically, the dominator is the first
+ // latch which ends in a conditional branch, or the last latch if
+ // there is no such latch.
+ NewIDom = Latches.back();
+ for (BasicBlock *IterLatch : Latches) {
+ TerminatorInst *Term = IterLatch->getTerminator();
+ if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) {
+ NewIDom = IterLatch;
+ break;
+ }
+ }
+ } else {
+ // The new idom of the block will be the nearest common dominator
+ // of all copies of the previous idom. This is equivalent to the
+ // nearest common dominator of the previous idom and the first latch,
+ // which dominates all copies of the previous idom.
+ NewIDom = DT->findNearestCommonDominator(BB, LatchBlock);
+ }
for (auto *ChildBB : ChildrenToUpdate)
DT->changeImmediateDominator(ChildBB, NewIDom);
}
}
+ if (DT && UnrollVerifyDomtree)
+ DT->verifyDomTree();
+
// Merge adjacent basic blocks, if possible.
SmallPtrSet<Loop *, 4> ForgottenLoops;
for (BasicBlock *Latch : Latches) {
@@ -655,13 +741,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
}
}
- // FIXME: We only preserve DT info for complete unrolling now. Incrementally
- // updating domtree after partial loop unrolling should also be easy.
- if (DT && !CompletelyUnroll)
- DT->recalculate(*L->getHeader()->getParent());
- else if (DT)
- DEBUG(DT->verifyDomTree());
-
// Simplify any new induction variables in the partially unrolled loop.
if (SE && !CompletelyUnroll && Count > 1) {
SmallVector<WeakVH, 16> DeadInsts;
@@ -721,29 +800,29 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
// at least one layer outside of the loop that was unrolled so that any
// changes to the parent loop exposed by the unrolling are considered.
if (DT) {
- if (!OuterL && !CompletelyUnroll)
- OuterL = L;
if (OuterL) {
// OuterL includes all loops for which we can break loop-simplify, so
// it's sufficient to simplify only it (it'll recursively simplify inner
// loops too).
+ if (NeedToFixLCSSA) {
+ // LCSSA must be performed on the outermost affected loop. The unrolled
+ // loop's last loop latch is guaranteed to be in the outermost loop
+ // after LoopInfo's been updated by markAsRemoved.
+ Loop *LatchLoop = LI->getLoopFor(Latches.back());
+ Loop *FixLCSSALoop = OuterL;
+ if (!FixLCSSALoop->contains(LatchLoop))
+ while (FixLCSSALoop->getParentLoop() != LatchLoop)
+ FixLCSSALoop = FixLCSSALoop->getParentLoop();
+
+ formLCSSARecursively(*FixLCSSALoop, *DT, LI, SE);
+ } else if (PreserveLCSSA) {
+ assert(OuterL->isLCSSAForm(*DT) &&
+ "Loops should be in LCSSA form after loop-unroll.");
+ }
+
// TODO: That potentially might be compile-time expensive. We should try
// to fix the loop-simplified form incrementally.
simplifyLoop(OuterL, DT, LI, SE, AC, PreserveLCSSA);
-
- // LCSSA must be performed on the outermost affected loop. The unrolled
- // loop's last loop latch is guaranteed to be in the outermost loop after
- // LoopInfo's been updated by markAsRemoved.
- Loop *LatchLoop = LI->getLoopFor(Latches.back());
- if (!OuterL->contains(LatchLoop))
- while (OuterL->getParentLoop() != LatchLoop)
- OuterL = OuterL->getParentLoop();
-
- if (NeedToFixLCSSA)
- formLCSSARecursively(*OuterL, *DT, LI, SE);
- else
- assert(OuterL->isLCSSAForm(*DT) &&
- "Loops should be in LCSSA form after loop-unroll.");
} else {
// Simplify loops for which we might've broken loop-simplify form.
for (Loop *SubLoop : LoopsToSimplify)
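
The isEpilogProfitable() heuristic above prefers an epilog remainder when a header PHI has a constant incoming value from the preheader. A hypothetical source-level example of such a loop (illustrative only, not from the patch):

// The accumulator Sum enters the loop as the constant 0, so its header PHI is
// phi [0, preheader], [next, latch]; epilog unrolling keeps that constant
// start, while a prolog would replace it with a value computed in the prolog.
int sumArray(const int *A, int N) {
  int Sum = 0;
  for (int I = 0; I < N; ++I)
    Sum += A[I];
  return Sum;
}
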
diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp
index 842cf31f2e3d4..73c14f5606b73 100644
--- a/lib/Transforms/Utils/LoopUnrollPeel.cpp
+++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp
@@ -28,6 +28,7 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <algorithm>
@@ -55,12 +56,20 @@ static bool canPeel(Loop *L) {
if (!L->getExitingBlock() || !L->getUniqueExitBlock())
return false;
+ // Don't try to peel loops where the latch is not the exiting block.
+ // This can be an indication of two different things:
+ // 1) The loop is not rotated.
+ // 2) The loop contains irreducible control flow that involves the latch.
+ if (L->getLoopLatch() != L->getExitingBlock())
+ return false;
+
return true;
}
// Return the number of iterations we want to peel off.
void llvm::computePeelCount(Loop *L, unsigned LoopSize,
- TargetTransformInfo::UnrollingPreferences &UP) {
+ TargetTransformInfo::UnrollingPreferences &UP,
+ unsigned &TripCount) {
UP.PeelCount = 0;
if (!canPeel(L))
return;
@@ -69,6 +78,39 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
if (!L->empty())
return;
+  // Try to find a Phi node that takes a loop-invariant value as the input from
+  // its only back edge. If there is such a Phi, peeling one iteration from the
+  // loop is profitable, because starting from the second iteration we will have
+  // the invariant value instead of this Phi.
+ if (LoopSize <= UP.Threshold) {
+ BasicBlock *BackEdge = L->getLoopLatch();
+ assert(BackEdge && "Loop is not in simplified form?");
+ BasicBlock *Header = L->getHeader();
+ // Iterate over Phis to find one with invariant input on back edge.
+ bool FoundCandidate = false;
+ PHINode *Phi;
+ for (auto BI = Header->begin(); isa<PHINode>(&*BI); ++BI) {
+ Phi = cast<PHINode>(&*BI);
+ Value *Input = Phi->getIncomingValueForBlock(BackEdge);
+ if (L->isLoopInvariant(Input)) {
+ FoundCandidate = true;
+ break;
+ }
+ }
+ if (FoundCandidate) {
+ DEBUG(dbgs() << "Peel one iteration to get rid of " << *Phi
+ << " because starting from 2nd iteration it is always"
+ << " an invariant\n");
+ UP.PeelCount = 1;
+ return;
+ }
+ }
+
+  // Bail if the trip count is statically known; in that case we prefer
+  // partial unrolling over peeling.
+ if (TripCount)
+ return;
+
// If the user provided a peel count, use that.
bool UserPeelCount = UnrollForcePeelCount.getNumOccurrences() > 0;
if (UserPeelCount) {
@@ -164,7 +206,8 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,
BasicBlock *InsertBot, BasicBlock *Exit,
SmallVectorImpl<BasicBlock *> &NewBlocks,
LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
- ValueToValueMapTy &LVMap, LoopInfo *LI) {
+ ValueToValueMapTy &LVMap, DominatorTree *DT,
+ LoopInfo *LI) {
BasicBlock *Header = L->getHeader();
BasicBlock *Latch = L->getLoopLatch();
@@ -185,6 +228,17 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,
ParentLoop->addBasicBlockToLoop(NewBB, *LI);
VMap[*BB] = NewBB;
+
+ // If dominator tree is available, insert nodes to represent cloned blocks.
+ if (DT) {
+ if (Header == *BB)
+ DT->addNewBlock(NewBB, InsertTop);
+ else {
+ DomTreeNode *IDom = DT->getNode(*BB)->getIDom();
+ // VMap must contain entry for IDom, as the iteration order is RPO.
+ DT->addNewBlock(NewBB, cast<BasicBlock>(VMap[IDom->getBlock()]));
+ }
+ }
}
// Hook-up the control flow for the newly inserted blocks.
@@ -198,11 +252,13 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,
// The backedge now goes to the "bottom", which is either the loop's real
// header (for the last peeled iteration) or the copied header of the next
// iteration (for every other iteration)
- BranchInst *LatchBR =
- cast<BranchInst>(cast<BasicBlock>(VMap[Latch])->getTerminator());
+ BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]);
+ BranchInst *LatchBR = cast<BranchInst>(NewLatch->getTerminator());
unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1);
LatchBR->setSuccessor(HeaderIdx, InsertBot);
LatchBR->setSuccessor(1 - HeaderIdx, Exit);
+ if (DT)
+ DT->changeImmediateDominator(InsertBot, NewLatch);
// The new copy of the loop body starts with a bunch of PHI nodes
// that pick an incoming value from either the preheader, or the previous
@@ -257,7 +313,7 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,
/// optimizations.
bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
ScalarEvolution *SE, DominatorTree *DT,
- bool PreserveLCSSA) {
+ AssumptionCache *AC, bool PreserveLCSSA) {
if (!canPeel(L))
return false;
@@ -358,7 +414,24 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
CurHeaderWeight = 1;
cloneLoopBlocks(L, Iter, InsertTop, InsertBot, Exit,
- NewBlocks, LoopBlocks, VMap, LVMap, LI);
+ NewBlocks, LoopBlocks, VMap, LVMap, DT, LI);
+
+ // Remap to use values from the current iteration instead of the
+ // previous one.
+ remapInstructionsInBlocks(NewBlocks, VMap);
+
+ if (DT) {
+    // The latches of the cloned iterations dominate the loop exit, so the
+    // exit's idom becomes the latch of the first cloned iteration, just as the
+    // original PreHeader dominates the original loop body.
+ if (Iter == 0)
+ DT->changeImmediateDominator(Exit, cast<BasicBlock>(LVMap[Latch]));
+#ifndef NDEBUG
+ if (VerifyDomInfo)
+ DT->verifyDomTree();
+#endif
+ }
+
updateBranchWeights(InsertBot, cast<BranchInst>(VMap[LatchBR]), Iter,
PeelCount, ExitWeight);
@@ -369,10 +442,6 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
F->getBasicBlockList().splice(InsertTop->getIterator(),
F->getBasicBlockList(),
NewBlocks[0]->getIterator(), F->end());
-
- // Remap to use values from the current iteration instead of the
- // previous one.
- remapInstructionsInBlocks(NewBlocks, VMap);
}
// Now adjust the phi nodes in the loop header to get their initial values
@@ -405,9 +474,16 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
}
// If the loop is nested, we changed the parent loop, update SE.
- if (Loop *ParentLoop = L->getParentLoop())
+ if (Loop *ParentLoop = L->getParentLoop()) {
SE->forgetLoop(ParentLoop);
+ // FIXME: Incrementally update loop-simplify
+ simplifyLoop(ParentLoop, DT, LI, SE, AC, PreserveLCSSA);
+ } else {
+ // FIXME: Incrementally update loop-simplify
+ simplifyLoop(L, DT, LI, SE, AC, PreserveLCSSA);
+ }
+
NumPeeled++;
return true;
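
The peeling heuristic added to computePeelCount() fires when a header PHI's back-edge input is loop invariant, since after one peeled iteration the PHI collapses to that invariant. A hypothetical source-level shape that exhibits this (illustrative only, not from the patch):

// Scale is a header PHI: phi [1, preheader], [Factor, latch]. Its back-edge
// input Factor is loop invariant, so from the second iteration onward the PHI
// is simply Factor; peeling one iteration exposes that.
int scaleSum(const int *A, int N, int Factor) {
  int Scale = 1, Sum = 0;
  for (int I = 0; I < N; ++I) {
    Sum += A[I] * Scale;
    Scale = Factor; // loop-invariant value fed back through the PHI
  }
  return Sum;
}
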
diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index d3ea1564115bb..85db734fb1827 100644
--- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -146,6 +146,8 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
// Add the branch to the exit block (around the unrolled loop)
B.CreateCondBr(BrLoopExit, Exit, NewPreHeader);
InsertPt->eraseFromParent();
+ if (DT)
+ DT->changeImmediateDominator(Exit, PrologExit);
}
/// Connect the unrolling epilog code to the original loop.
@@ -260,13 +262,20 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
IRBuilder<> B(InsertPt);
Value *BrLoopExit = B.CreateIsNotNull(ModVal, "lcmp.mod");
assert(Exit && "Loop must have a single exit block only");
- // Split the exit to maintain loop canonicalization guarantees
+ // Split the epilogue exit to maintain loop canonicalization guarantees
SmallVector<BasicBlock*, 4> Preds(predecessors(Exit));
SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI,
PreserveLCSSA);
// Add the branch to the exit block (around the unrolling loop)
B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit);
InsertPt->eraseFromParent();
+ if (DT)
+ DT->changeImmediateDominator(Exit, NewExit);
+
+ // Split the main loop exit to maintain canonicalization guarantees.
+ SmallVector<BasicBlock*, 4> NewExitPreds{Latch};
+ SplitBlockPredecessors(NewExit, NewExitPreds, ".loopexit", DT, LI,
+ PreserveLCSSA);
}
/// Create a clone of the blocks in a loop and connect them together.
@@ -284,27 +293,17 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,
BasicBlock *Preheader,
std::vector<BasicBlock *> &NewBlocks,
LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
- LoopInfo *LI) {
+ DominatorTree *DT, LoopInfo *LI) {
StringRef suffix = UseEpilogRemainder ? "epil" : "prol";
BasicBlock *Header = L->getHeader();
BasicBlock *Latch = L->getLoopLatch();
Function *F = Header->getParent();
LoopBlocksDFS::RPOIterator BlockBegin = LoopBlocks.beginRPO();
LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO();
- Loop *NewLoop = nullptr;
Loop *ParentLoop = L->getParentLoop();
- if (CreateRemainderLoop) {
- NewLoop = new Loop();
- if (ParentLoop)
- ParentLoop->addChildLoop(NewLoop);
- else
- LI->addTopLevelLoop(NewLoop);
- }
-
NewLoopsMap NewLoops;
- if (NewLoop)
- NewLoops[L] = NewLoop;
- else if (ParentLoop)
+ NewLoops[ParentLoop] = ParentLoop;
+ if (!CreateRemainderLoop)
NewLoops[L] = ParentLoop;
// For each block in the original loop, create a new copy,
@@ -312,7 +311,7 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,
for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." + suffix, F);
NewBlocks.push_back(NewBB);
-
+
// If we're unrolling the outermost loop, there's no remainder loop,
// and this block isn't in a nested loop, then the new block is not
// in any loop. Otherwise, add it to loopinfo.
@@ -326,6 +325,17 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,
InsertTop->getTerminator()->setSuccessor(0, NewBB);
}
+ if (DT) {
+ if (Header == *BB) {
+ // The header is dominated by the preheader.
+ DT->addNewBlock(NewBB, InsertTop);
+ } else {
+ // Copy information from original loop to unrolled loop.
+ BasicBlock *IDomBB = DT->getNode(*BB)->getIDom()->getBlock();
+ DT->addNewBlock(NewBB, cast<BasicBlock>(VMap[IDomBB]));
+ }
+ }
+
if (Latch == *BB) {
// For the last block, if CreateRemainderLoop is false, create a direct
// jump to InsertBot. If not, create a loop back to cloned head.
@@ -376,7 +386,9 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter,
NewPHI->setIncomingValue(idx, V);
}
}
- if (NewLoop) {
+ if (CreateRemainderLoop) {
+ Loop *NewLoop = NewLoops[L];
+ assert(NewLoop && "L should have been cloned");
// Add unroll disable metadata to disable future unrolling for this loop.
SmallVector<Metadata *, 4> MDs;
// Reserve first location for self reference to the LoopID metadata node.
@@ -599,6 +611,12 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
// Branch to either remainder (extra iterations) loop or unrolling loop.
B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop);
PreHeaderBR->eraseFromParent();
+ if (DT) {
+ if (UseEpilogRemainder)
+ DT->changeImmediateDominator(NewExit, PreHeader);
+ else
+ DT->changeImmediateDominator(PrologExit, PreHeader);
+ }
Function *F = Header->getParent();
// Get an ordered list of blocks in the loop to help with the ordering of the
// cloned blocks in the prolog/epilog code
@@ -623,7 +641,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
BasicBlock *InsertBot = UseEpilogRemainder ? Exit : PrologExit;
BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
CloneLoopBlocks(L, ModVal, CreateRemainderLoop, UseEpilogRemainder, InsertTop,
- InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, LI);
+ InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI);
// Insert the cloned blocks into the function.
F->getBasicBlockList().splice(InsertBot->getIterator(),
diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp
index c8efa9efc7f34..175d013a011d0 100644
--- a/lib/Transforms/Utils/LoopUtils.cpp
+++ b/lib/Transforms/Utils/LoopUtils.cpp
@@ -230,8 +230,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
// - PHI:
// - All uses of the PHI must be the reduction (safe).
// - Otherwise, not safe.
- // - By one instruction outside of the loop (safe).
- // - By further instructions outside of the loop (not safe).
+ // - By instructions outside of the loop (safe).
+ // * One value may have several outside users, but all outside
+ // uses must be of the same value.
// - By an instruction that is not part of the reduction (not safe).
// This is either:
// * An instruction type other than PHI or the reduction operation.
@@ -297,10 +298,15 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
// Check if we found the exit user.
BasicBlock *Parent = UI->getParent();
if (!TheLoop->contains(Parent)) {
- // Exit if you find multiple outside users or if the header phi node is
- // being used. In this case the user uses the value of the previous
- // iteration, in which case we would loose "VF-1" iterations of the
- // reduction operation if we vectorize.
+ // If we already know this instruction is used externally, move on to
+ // the next user.
+ if (ExitInstruction == Cur)
+ continue;
+
+ // Exit if you find multiple values used outside or if the header phi
+ // node is being used. In this case the user uses the value of the
+      // previous iteration, in which case we would lose "VF-1" iterations of
+ // the reduction operation if we vectorize.
if (ExitInstruction != nullptr || Cur == Phi)
return false;
@@ -547,13 +553,14 @@ bool RecurrenceDescriptor::isFirstOrderRecurrence(PHINode *Phi, Loop *TheLoop,
if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous))
return false;
- // Ensure every user of the phi node is dominated by the previous value. The
- // dominance requirement ensures the loop vectorizer will not need to
+ // Ensure every user of the phi node is dominated by the previous value.
+ // The dominance requirement ensures the loop vectorizer will not need to
// vectorize the initial value prior to the first iteration of the loop.
for (User *U : Phi->users())
- if (auto *I = dyn_cast<Instruction>(U))
+ if (auto *I = dyn_cast<Instruction>(U)) {
if (!DT->dominates(Previous, I))
return false;
+ }
return true;
}
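
The relaxed reduction check above now tolerates several users outside the loop as long as they all consume the same exit value. A hypothetical example of a reduction it accepts (illustrative only, not from the patch):

// Sum has two users outside the loop, but both use the same final value, so
// the reduction is still recognized and remains vectorizable.
int sumAndReport(const int *A, int N, int *Out) {
  int Sum = 0;
  for (int I = 0; I < N; ++I)
    Sum += A[I];
  *Out = Sum;     // first outside user
  return Sum + 1; // second outside user of the same exit value
}
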
diff --git a/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/lib/Transforms/Utils/LowerMemIntrinsics.cpp
new file mode 100644
index 0000000000000..c7cb561b5e21d
--- /dev/null
+++ b/lib/Transforms/Utils/LowerMemIntrinsics.cpp
@@ -0,0 +1,231 @@
+//===- LowerMemIntrinsics.cpp ----------------------------------*- C++ -*--===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
+
+using namespace llvm;
+
+void llvm::createMemCpyLoop(Instruction *InsertBefore,
+ Value *SrcAddr, Value *DstAddr, Value *CopyLen,
+ unsigned SrcAlign, unsigned DestAlign,
+ bool SrcIsVolatile, bool DstIsVolatile) {
+ Type *TypeOfCopyLen = CopyLen->getType();
+
+ BasicBlock *OrigBB = InsertBefore->getParent();
+ Function *F = OrigBB->getParent();
+ BasicBlock *NewBB =
+ InsertBefore->getParent()->splitBasicBlock(InsertBefore, "split");
+ BasicBlock *LoopBB = BasicBlock::Create(F->getContext(), "loadstoreloop",
+ F, NewBB);
+
+ OrigBB->getTerminator()->setSuccessor(0, LoopBB);
+ IRBuilder<> Builder(OrigBB->getTerminator());
+
+ // SrcAddr and DstAddr are expected to be pointer types,
+ // so no check is made here.
+ unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
+ unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
+
+ // Cast pointers to (char *)
+ SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS));
+ DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS));
+
+ IRBuilder<> LoopBuilder(LoopBB);
+ PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
+ LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB);
+
+ // load from SrcAddr+LoopIndex
+ // TODO: we can leverage the align parameter of llvm.memcpy for more efficient
+ // word-sized loads and stores.
+ Value *Element =
+ LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP(
+ LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex),
+ SrcIsVolatile);
+ // store at DstAddr+LoopIndex
+ LoopBuilder.CreateStore(Element,
+ LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(),
+ DstAddr, LoopIndex),
+ DstIsVolatile);
+
+ // The value for LoopIndex coming from backedge is (LoopIndex + 1)
+ Value *NewIndex =
+ LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1));
+ LoopIndex->addIncoming(NewIndex, LoopBB);
+
+ LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
+ NewBB);
+}
+
+// Lower memmove to IR. memmove is required to correctly copy overlapping memory
+// regions; therefore, it has to check the relative positions of the source and
+// destination pointers and choose the copy direction accordingly.
+//
+// The code below is an IR rendition of this C function:
+//
+// void* memmove(void* dst, const void* src, size_t n) {
+// unsigned char* d = dst;
+// const unsigned char* s = src;
+// if (s < d) {
+// // copy backwards
+// while (n--) {
+// d[n] = s[n];
+// }
+// } else {
+// // copy forward
+// for (size_t i = 0; i < n; ++i) {
+// d[i] = s[i];
+// }
+// }
+// return dst;
+// }
+static void createMemMoveLoop(Instruction *InsertBefore,
+ Value *SrcAddr, Value *DstAddr, Value *CopyLen,
+ unsigned SrcAlign, unsigned DestAlign,
+ bool SrcIsVolatile, bool DstIsVolatile) {
+ Type *TypeOfCopyLen = CopyLen->getType();
+ BasicBlock *OrigBB = InsertBefore->getParent();
+ Function *F = OrigBB->getParent();
+
+  // Create a comparison of src and dst, based on which we jump to either
+ // the forward-copy part of the function (if src >= dst) or the backwards-copy
+ // part (if src < dst).
+ // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
+ // structure. Its block terminators (unconditional branches) are replaced by
+ // the appropriate conditional branches when the loop is built.
+ ICmpInst *PtrCompare = new ICmpInst(InsertBefore, ICmpInst::ICMP_ULT,
+ SrcAddr, DstAddr, "compare_src_dst");
+ TerminatorInst *ThenTerm, *ElseTerm;
+ SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore, &ThenTerm,
+ &ElseTerm);
+
+ // Each part of the function consists of two blocks:
+ // copy_backwards: used to skip the loop when n == 0
+ // copy_backwards_loop: the actual backwards loop BB
+ // copy_forward: used to skip the loop when n == 0
+ // copy_forward_loop: the actual forward loop BB
+ BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
+ CopyBackwardsBB->setName("copy_backwards");
+ BasicBlock *CopyForwardBB = ElseTerm->getParent();
+ CopyForwardBB->setName("copy_forward");
+ BasicBlock *ExitBB = InsertBefore->getParent();
+ ExitBB->setName("memmove_done");
+
+ // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
+ // between both backwards and forward copy clauses.
+ ICmpInst *CompareN =
+ new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen,
+ ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0");
+
+ // Copying backwards.
+ BasicBlock *LoopBB =
+ BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB);
+ IRBuilder<> LoopBuilder(LoopBB);
+ PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
+ Value *IndexPtr = LoopBuilder.CreateSub(
+ LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
+ Value *Element = LoopBuilder.CreateLoad(
+ LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element");
+ LoopBuilder.CreateStore(Element,
+ LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr));
+ LoopBuilder.CreateCondBr(
+ LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)),
+ ExitBB, LoopBB);
+ LoopPhi->addIncoming(IndexPtr, LoopBB);
+ LoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
+ BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm);
+ ThenTerm->eraseFromParent();
+
+ // Copying forward.
+ BasicBlock *FwdLoopBB =
+ BasicBlock::Create(F->getContext(), "copy_forward_loop", F, ExitBB);
+ IRBuilder<> FwdLoopBuilder(FwdLoopBB);
+ PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr");
+ Value *FwdElement = FwdLoopBuilder.CreateLoad(
+ FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element");
+ FwdLoopBuilder.CreateStore(
+ FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi));
+ Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
+ FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment");
+ FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen),
+ ExitBB, FwdLoopBB);
+ FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
+ FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB);
+
+ BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm);
+ ElseTerm->eraseFromParent();
+}
+
+static void createMemSetLoop(Instruction *InsertBefore,
+ Value *DstAddr, Value *CopyLen, Value *SetValue,
+ unsigned Align, bool IsVolatile) {
+ BasicBlock *OrigBB = InsertBefore->getParent();
+ Function *F = OrigBB->getParent();
+ BasicBlock *NewBB =
+ OrigBB->splitBasicBlock(InsertBefore, "split");
+ BasicBlock *LoopBB
+ = BasicBlock::Create(F->getContext(), "loadstoreloop", F, NewBB);
+
+ OrigBB->getTerminator()->setSuccessor(0, LoopBB);
+ IRBuilder<> Builder(OrigBB->getTerminator());
+
+ // Cast pointer to the type of value getting stored
+ unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
+ DstAddr = Builder.CreateBitCast(DstAddr,
+ PointerType::get(SetValue->getType(), dstAS));
+
+ IRBuilder<> LoopBuilder(LoopBB);
+ PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLen->getType(), 0);
+ LoopIndex->addIncoming(ConstantInt::get(CopyLen->getType(), 0), OrigBB);
+
+ LoopBuilder.CreateStore(
+ SetValue,
+ LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex),
+ IsVolatile);
+
+ Value *NewIndex =
+ LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLen->getType(), 1));
+ LoopIndex->addIncoming(NewIndex, LoopBB);
+
+ LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
+ NewBB);
+}
+
+void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy) {
+ createMemCpyLoop(/* InsertBefore */ Memcpy,
+ /* SrcAddr */ Memcpy->getRawSource(),
+ /* DstAddr */ Memcpy->getRawDest(),
+ /* CopyLen */ Memcpy->getLength(),
+ /* SrcAlign */ Memcpy->getAlignment(),
+ /* DestAlign */ Memcpy->getAlignment(),
+ /* SrcIsVolatile */ Memcpy->isVolatile(),
+ /* DstIsVolatile */ Memcpy->isVolatile());
+}
+
+void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) {
+ createMemMoveLoop(/* InsertBefore */ Memmove,
+ /* SrcAddr */ Memmove->getRawSource(),
+ /* DstAddr */ Memmove->getRawDest(),
+ /* CopyLen */ Memmove->getLength(),
+ /* SrcAlign */ Memmove->getAlignment(),
+ /* DestAlign */ Memmove->getAlignment(),
+ /* SrcIsVolatile */ Memmove->isVolatile(),
+ /* DstIsVolatile */ Memmove->isVolatile());
+}
+
+void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
+ createMemSetLoop(/* InsertBefore */ Memset,
+ /* DstAddr */ Memset->getRawDest(),
+ /* CopyLen */ Memset->getLength(),
+ /* SetValue */ Memset->getValue(),
+ /* Alignment */ Memset->getAlignment(),
+ Memset->isVolatile());
+}
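
The new LowerMemIntrinsics utilities are meant to be driven by a pass that cannot, or does not want to, emit libcalls. A minimal sketch of a call site, using only the expandMemCpyAsLoop() entry point declared above (lowerIfMemCpy is an illustrative name, not part of the patch):

#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
using namespace llvm;

// Sketch only: replace a single llvm.memcpy with the explicit byte-copy loop
// built by expandMemCpyAsLoop(), then drop the now-dead intrinsic.
static bool lowerIfMemCpy(Instruction &I) {
  auto *MC = dyn_cast<MemCpyInst>(&I);
  if (!MC)
    return false;
  expandMemCpyAsLoop(MC); // emits the loadstoreloop right before MC
  MC->eraseFromParent();
  return true;
}
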
diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp
index 75cd3bc8b2bfb..b375d51005d57 100644
--- a/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/lib/Transforms/Utils/LowerSwitch.cpp
@@ -356,10 +356,10 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
unsigned numCmps = 0;
// Start with "simple" cases
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i)
- Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(),
- i.getCaseSuccessor()));
-
+ for (auto Case : SI->cases())
+ Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(),
+ Case.getCaseSuccessor()));
+
std::sort(Cases.begin(), Cases.end(), CaseCmp());
// Merge case into clusters
diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp
index 24b3b12930ac5..b659a2e4463ff 100644
--- a/lib/Transforms/Utils/Mem2Reg.cpp
+++ b/lib/Transforms/Utils/Mem2Reg.cpp
@@ -46,7 +46,7 @@ static bool promoteMemoryToRegister(Function &F, DominatorTree &DT,
if (Allocas.empty())
break;
- PromoteMemToReg(Allocas, DT, nullptr, &AC);
+ PromoteMemToReg(Allocas, DT, &AC);
NumPromoted += Allocas.size();
Changed = true;
}
@@ -59,8 +59,9 @@ PreservedAnalyses PromotePass::run(Function &F, FunctionAnalysisManager &AM) {
if (!promoteMemoryToRegister(F, DT, AC))
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
- return PreservedAnalyses::none();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
namespace {
diff --git a/lib/Transforms/Utils/MemorySSA.cpp b/lib/Transforms/Utils/MemorySSA.cpp
deleted file mode 100644
index 1ce4225f09cc0..0000000000000
--- a/lib/Transforms/Utils/MemorySSA.cpp
+++ /dev/null
@@ -1,2305 +0,0 @@
-//===-- MemorySSA.cpp - Memory SSA Builder---------------------------===//
-//
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
-//
-//===----------------------------------------------------------------===//
-//
-// This file implements the MemorySSA class.
-//
-//===----------------------------------------------------------------===//
-#include "llvm/Transforms/Utils/MemorySSA.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/DepthFirstIterator.h"
-#include "llvm/ADT/GraphTraits.h"
-#include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/ADT/SmallSet.h"
-#include "llvm/ADT/Statistic.h"
-#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/CFG.h"
-#include "llvm/Analysis/GlobalsModRef.h"
-#include "llvm/Analysis/IteratedDominanceFrontier.h"
-#include "llvm/Analysis/MemoryLocation.h"
-#include "llvm/Analysis/PHITransAddr.h"
-#include "llvm/IR/AssemblyAnnotationWriter.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/Dominators.h"
-#include "llvm/IR/GlobalVariable.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/Metadata.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/PatternMatch.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/FormattedStream.h"
-#include "llvm/Transforms/Scalar.h"
-#include <algorithm>
-
-#define DEBUG_TYPE "memoryssa"
-using namespace llvm;
-STATISTIC(NumClobberCacheLookups, "Number of Memory SSA version cache lookups");
-STATISTIC(NumClobberCacheHits, "Number of Memory SSA version cache hits");
-STATISTIC(NumClobberCacheInserts, "Number of MemorySSA version cache inserts");
-
-INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false,
- true)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false,
- true)
-
-INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa",
- "Memory SSA Printer", false, false)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa",
- "Memory SSA Printer", false, false)
-
-static cl::opt<unsigned> MaxCheckLimit(
- "memssa-check-limit", cl::Hidden, cl::init(100),
- cl::desc("The maximum number of stores/phis MemorySSA"
- "will consider trying to walk past (default = 100)"));
-
-static cl::opt<bool>
- VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden,
- cl::desc("Verify MemorySSA in legacy printer pass."));
-
-namespace llvm {
-/// \brief An assembly annotator class to print Memory SSA information in
-/// comments.
-class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter {
- friend class MemorySSA;
- const MemorySSA *MSSA;
-
-public:
- MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {}
-
- virtual void emitBasicBlockStartAnnot(const BasicBlock *BB,
- formatted_raw_ostream &OS) {
- if (MemoryAccess *MA = MSSA->getMemoryAccess(BB))
- OS << "; " << *MA << "\n";
- }
-
- virtual void emitInstructionAnnot(const Instruction *I,
- formatted_raw_ostream &OS) {
- if (MemoryAccess *MA = MSSA->getMemoryAccess(I))
- OS << "; " << *MA << "\n";
- }
-};
-}
-
-namespace {
-/// Our current alias analysis API differentiates heavily between calls and
-/// non-calls, and functions called on one usually assert on the other.
-/// This class encapsulates the distinction to simplify other code that wants
-/// "Memory affecting instructions and related data" to use as a key.
-/// For example, this class is used as a densemap key in the use optimizer.
-class MemoryLocOrCall {
-public:
- MemoryLocOrCall() : IsCall(false) {}
- MemoryLocOrCall(MemoryUseOrDef *MUD)
- : MemoryLocOrCall(MUD->getMemoryInst()) {}
- MemoryLocOrCall(const MemoryUseOrDef *MUD)
- : MemoryLocOrCall(MUD->getMemoryInst()) {}
-
- MemoryLocOrCall(Instruction *Inst) {
- if (ImmutableCallSite(Inst)) {
- IsCall = true;
- CS = ImmutableCallSite(Inst);
- } else {
- IsCall = false;
-      // There is no such thing as a MemoryLocation for a fence instruction,
-      // and it is unique in that regard.
- if (!isa<FenceInst>(Inst))
- Loc = MemoryLocation::get(Inst);
- }
- }
-
- explicit MemoryLocOrCall(const MemoryLocation &Loc)
- : IsCall(false), Loc(Loc) {}
-
- bool IsCall;
- ImmutableCallSite getCS() const {
- assert(IsCall);
- return CS;
- }
- MemoryLocation getLoc() const {
- assert(!IsCall);
- return Loc;
- }
-
- bool operator==(const MemoryLocOrCall &Other) const {
- if (IsCall != Other.IsCall)
- return false;
-
- if (IsCall)
- return CS.getCalledValue() == Other.CS.getCalledValue();
- return Loc == Other.Loc;
- }
-
-private:
- union {
- ImmutableCallSite CS;
- MemoryLocation Loc;
- };
-};
-}
-
-namespace llvm {
-template <> struct DenseMapInfo<MemoryLocOrCall> {
- static inline MemoryLocOrCall getEmptyKey() {
- return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey());
- }
- static inline MemoryLocOrCall getTombstoneKey() {
- return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey());
- }
- static unsigned getHashValue(const MemoryLocOrCall &MLOC) {
- if (MLOC.IsCall)
- return hash_combine(MLOC.IsCall,
- DenseMapInfo<const Value *>::getHashValue(
- MLOC.getCS().getCalledValue()));
- return hash_combine(
- MLOC.IsCall, DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc()));
- }
- static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) {
- return LHS == RHS;
- }
-};
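Reviewer note, illustrative only: the specialization above is what allows MemoryLocOrCall to be used directly as a DenseMap key (for example, in the use optimizer's per-location bookkeeping). A minimal sketch under assumed names, where MU is some MemoryUseOrDef* and Clobber a previously computed MemoryAccess*:

  // Sketch only (hypothetical values): key a map on MemoryLocOrCall.
  DenseMap<MemoryLocOrCall, MemoryAccess *> LastSeenClobber;
  MemoryLocOrCall Key(MU);        // hashes the callee for calls, the MemoryLocation otherwise
  LastSeenClobber[Key] = Clobber; // equality falls back to operator== defined above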
-
-enum class Reorderability { Always, IfNoAlias, Never };
-
-/// This does one-way checks to see if Use could theoretically be hoisted above
-/// MayClobber. This will not check the other way around.
-///
-/// This assumes that, for the purposes of MemorySSA, Use comes directly after
-/// MayClobber, with no potentially clobbering operations in between them.
-/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.)
-static Reorderability getLoadReorderability(const LoadInst *Use,
- const LoadInst *MayClobber) {
- bool VolatileUse = Use->isVolatile();
- bool VolatileClobber = MayClobber->isVolatile();
- // Volatile operations may never be reordered with other volatile operations.
- if (VolatileUse && VolatileClobber)
- return Reorderability::Never;
-
- // The lang ref allows reordering of volatile and non-volatile operations.
- // Whether an aliasing nonvolatile load and volatile load can be reordered,
- // though, is ambiguous. Because it may not be best to exploit this ambiguity,
- // we only allow volatile/non-volatile reordering if the volatile and
- // non-volatile operations don't alias.
- Reorderability Result = VolatileUse || VolatileClobber
- ? Reorderability::IfNoAlias
- : Reorderability::Always;
-
- // If a load is seq_cst, it cannot be moved above other loads. If its ordering
- // is weaker, it can be moved above other loads. We just need to be sure that
- // MayClobber isn't an acquire load, because loads can't be moved above
- // acquire loads.
- //
- // Note that this explicitly *does* allow the free reordering of monotonic (or
- // weaker) loads of the same address.
- bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent;
- bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(),
- AtomicOrdering::Acquire);
- if (SeqCstUse || MayClobberIsAcquire)
- return Reorderability::Never;
- return Result;
-}
-
-static bool instructionClobbersQuery(MemoryDef *MD,
- const MemoryLocation &UseLoc,
- const Instruction *UseInst,
- AliasAnalysis &AA) {
- Instruction *DefInst = MD->getMemoryInst();
- assert(DefInst && "Defining instruction not actually an instruction");
-
- if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) {
- // These intrinsics will show up as affecting memory, but they are just
- // markers.
- switch (II->getIntrinsicID()) {
- case Intrinsic::lifetime_start:
- case Intrinsic::lifetime_end:
- case Intrinsic::invariant_start:
- case Intrinsic::invariant_end:
- case Intrinsic::assume:
- return false;
- default:
- break;
- }
- }
-
- ImmutableCallSite UseCS(UseInst);
- if (UseCS) {
- ModRefInfo I = AA.getModRefInfo(DefInst, UseCS);
- return I != MRI_NoModRef;
- }
-
- if (auto *DefLoad = dyn_cast<LoadInst>(DefInst)) {
- if (auto *UseLoad = dyn_cast<LoadInst>(UseInst)) {
- switch (getLoadReorderability(UseLoad, DefLoad)) {
- case Reorderability::Always:
- return false;
- case Reorderability::Never:
- return true;
- case Reorderability::IfNoAlias:
- return !AA.isNoAlias(UseLoc, MemoryLocation::get(DefLoad));
- }
- }
- }
-
- return AA.getModRefInfo(DefInst, UseLoc) & MRI_Mod;
-}
-
-static bool instructionClobbersQuery(MemoryDef *MD, const MemoryUseOrDef *MU,
- const MemoryLocOrCall &UseMLOC,
- AliasAnalysis &AA) {
- // FIXME: This is a temporary hack to allow a single instructionClobbersQuery
- // to exist while MemoryLocOrCall is pushed through places.
- if (UseMLOC.IsCall)
- return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(),
- AA);
- return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(),
- AA);
-}
-
-// Return true when MD may alias MU, return false otherwise.
-bool defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU,
- AliasAnalysis &AA) {
- return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA);
-}
-}
-
-namespace {
-struct UpwardsMemoryQuery {
- // True if our original query started off as a call
- bool IsCall;
- // The pointer location we started the query with. This will be empty if
- // IsCall is true.
- MemoryLocation StartingLoc;
- // This is the instruction we were querying about.
- const Instruction *Inst;
- // The MemoryAccess we actually got called with, used to test local domination
- const MemoryAccess *OriginalAccess;
-
- UpwardsMemoryQuery()
- : IsCall(false), Inst(nullptr), OriginalAccess(nullptr) {}
-
- UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access)
- : IsCall(ImmutableCallSite(Inst)), Inst(Inst), OriginalAccess(Access) {
- if (!IsCall)
- StartingLoc = MemoryLocation::get(Inst);
- }
-};
-
-static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc,
- AliasAnalysis &AA) {
- Instruction *Inst = MD->getMemoryInst();
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
- switch (II->getIntrinsicID()) {
- case Intrinsic::lifetime_start:
- case Intrinsic::lifetime_end:
- return AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), Loc);
- default:
- return false;
- }
- }
- return false;
-}
-
-static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysis &AA,
- const Instruction *I) {
- // If the memory can't be changed, then loads of the memory can't be
- // clobbered.
- //
- // FIXME: We should handle invariant groups, as well. It's a bit harder,
- // because we need to pay close attention to invariant group barriers.
- return isa<LoadInst>(I) && (I->getMetadata(LLVMContext::MD_invariant_load) ||
- AA.pointsToConstantMemory(I));
-}
-
-/// Cache for our caching MemorySSA walker.
-class WalkerCache {
- DenseMap<ConstMemoryAccessPair, MemoryAccess *> Accesses;
- DenseMap<const MemoryAccess *, MemoryAccess *> Calls;
-
-public:
- MemoryAccess *lookup(const MemoryAccess *MA, const MemoryLocation &Loc,
- bool IsCall) const {
- ++NumClobberCacheLookups;
- MemoryAccess *R = IsCall ? Calls.lookup(MA) : Accesses.lookup({MA, Loc});
- if (R)
- ++NumClobberCacheHits;
- return R;
- }
-
- bool insert(const MemoryAccess *MA, MemoryAccess *To,
- const MemoryLocation &Loc, bool IsCall) {
- // This is fine for Phis, since there are times where we can't optimize
- // them. Making a def its own clobber is never correct, though.
- assert((MA != To || isa<MemoryPhi>(MA)) &&
- "Something can't clobber itself!");
-
- ++NumClobberCacheInserts;
- bool Inserted;
- if (IsCall)
- Inserted = Calls.insert({MA, To}).second;
- else
- Inserted = Accesses.insert({{MA, Loc}, To}).second;
-
- return Inserted;
- }
-
- bool remove(const MemoryAccess *MA, const MemoryLocation &Loc, bool IsCall) {
- return IsCall ? Calls.erase(MA) : Accesses.erase({MA, Loc});
- }
-
- void clear() {
- Accesses.clear();
- Calls.clear();
- }
-
- bool contains(const MemoryAccess *MA) const {
- for (auto &P : Accesses)
- if (P.first.first == MA || P.second == MA)
- return true;
- for (auto &P : Calls)
- if (P.first == MA || P.second == MA)
- return true;
- return false;
- }
-};
-
-/// Walks the defining uses of MemoryDefs. Stops after we hit something that has
-/// no defining use (e.g. a MemoryPhi or liveOnEntry). Note that, when comparing
-/// against a null def_chain_iterator, this will compare equal only after
-/// walking said Phi/liveOnEntry.
-struct def_chain_iterator
- : public iterator_facade_base<def_chain_iterator, std::forward_iterator_tag,
- MemoryAccess *> {
- def_chain_iterator() : MA(nullptr) {}
- def_chain_iterator(MemoryAccess *MA) : MA(MA) {}
-
- MemoryAccess *operator*() const { return MA; }
-
- def_chain_iterator &operator++() {
- // N.B. liveOnEntry has a null defining access.
- if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
- MA = MUD->getDefiningAccess();
- else
- MA = nullptr;
- return *this;
- }
-
- bool operator==(const def_chain_iterator &O) const { return MA == O.MA; }
-
-private:
- MemoryAccess *MA;
-};
-
-static iterator_range<def_chain_iterator>
-def_chain(MemoryAccess *MA, MemoryAccess *UpTo = nullptr) {
-#ifdef EXPENSIVE_CHECKS
- assert((!UpTo || find(def_chain(MA), UpTo) != def_chain_iterator()) &&
- "UpTo isn't in the def chain!");
-#endif
- return make_range(def_chain_iterator(MA), def_chain_iterator(UpTo));
-}
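Illustrative only (not part of the patch): the helper above gives clients a plain range-for over every dominating definition. A sketch, assuming Start is some MemoryAccess in an already-built MemorySSA; the terminating MemoryPhi or liveOnEntry def is included as the last element:

  // Sketch only: print each dominating MemoryDef on the def chain of Start.
  for (MemoryAccess *MA : def_chain(Start))
    if (auto *MD = dyn_cast<MemoryDef>(MA))
      dbgs() << "dominating def: " << *MD << "\n";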
-
-/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing
-/// in between `Start` and `ClobberAt` can clobber `Start`.
-///
-/// This is meant to be as simple and self-contained as possible. Because it
-/// uses no cache, etc., it can be relatively expensive.
-///
-/// \param Start The MemoryAccess that we want to walk from.
-/// \param ClobberAt A clobber for Start.
-/// \param StartLoc The MemoryLocation for Start.
-/// \param MSSA The MemorySSA instance that Start and ClobberAt belong to.
-/// \param Query The UpwardsMemoryQuery we used for our search.
-/// \param AA The AliasAnalysis we used for our search.
-static void LLVM_ATTRIBUTE_UNUSED
-checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt,
- const MemoryLocation &StartLoc, const MemorySSA &MSSA,
- const UpwardsMemoryQuery &Query, AliasAnalysis &AA) {
- assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?");
-
- if (MSSA.isLiveOnEntryDef(Start)) {
- assert(MSSA.isLiveOnEntryDef(ClobberAt) &&
- "liveOnEntry must clobber itself");
- return;
- }
-
- bool FoundClobber = false;
- DenseSet<MemoryAccessPair> VisitedPhis;
- SmallVector<MemoryAccessPair, 8> Worklist;
- Worklist.emplace_back(Start, StartLoc);
- // Walk all paths from Start to ClobberAt, while looking for clobbers. If one
- // is found, complain.
- while (!Worklist.empty()) {
- MemoryAccessPair MAP = Worklist.pop_back_val();
- // All we care about is that nothing from Start to ClobberAt clobbers Start.
- // We learn nothing from revisiting nodes.
- if (!VisitedPhis.insert(MAP).second)
- continue;
-
- for (MemoryAccess *MA : def_chain(MAP.first)) {
- if (MA == ClobberAt) {
- if (auto *MD = dyn_cast<MemoryDef>(MA)) {
- // instructionClobbersQuery isn't essentially free, so don't use `|=`,
- // since it won't let us short-circuit.
- //
- // Also, note that this can't be hoisted out of the `Worklist` loop,
- // since MD may only act as a clobber for 1 of N MemoryLocations.
- FoundClobber =
- FoundClobber || MSSA.isLiveOnEntryDef(MD) ||
- instructionClobbersQuery(MD, MAP.second, Query.Inst, AA);
- }
- break;
- }
-
- // We should never hit liveOnEntry, unless it's the clobber.
- assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?");
-
- if (auto *MD = dyn_cast<MemoryDef>(MA)) {
- (void)MD;
- assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) &&
- "Found clobber before reaching ClobberAt!");
- continue;
- }
-
- assert(isa<MemoryPhi>(MA));
- Worklist.append(upward_defs_begin({MA, MAP.second}), upward_defs_end());
- }
- }
-
- // If ClobberAt is a MemoryPhi, we can assume something above it acted as a
- // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point.
- assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) &&
- "ClobberAt never acted as a clobber");
-}
-
-/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up
-/// in one class.
-class ClobberWalker {
- /// Save a few bytes by using unsigned instead of size_t.
- using ListIndex = unsigned;
-
- /// Represents a span of contiguous MemoryDefs, potentially ending in a
- /// MemoryPhi.
- struct DefPath {
- MemoryLocation Loc;
- // Note that, because we always walk in reverse, Last will always dominate
- // First. Also note that First and Last are inclusive.
- MemoryAccess *First;
- MemoryAccess *Last;
- Optional<ListIndex> Previous;
-
- DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last,
- Optional<ListIndex> Previous)
- : Loc(Loc), First(First), Last(Last), Previous(Previous) {}
-
- DefPath(const MemoryLocation &Loc, MemoryAccess *Init,
- Optional<ListIndex> Previous)
- : DefPath(Loc, Init, Init, Previous) {}
- };
-
- const MemorySSA &MSSA;
- AliasAnalysis &AA;
- DominatorTree &DT;
- WalkerCache &WC;
- UpwardsMemoryQuery *Query;
- bool UseCache;
-
- // Phi optimization bookkeeping
- SmallVector<DefPath, 32> Paths;
- DenseSet<ConstMemoryAccessPair> VisitedPhis;
- DenseMap<const BasicBlock *, MemoryAccess *> WalkTargetCache;
-
- void setUseCache(bool Use) { UseCache = Use; }
- bool shouldIgnoreCache() const {
- // UseCache will only be false when we're debugging, or when expensive
- // checks are enabled. In either case, we don't care deeply about speed.
- return LLVM_UNLIKELY(!UseCache);
- }
-
- void addCacheEntry(const MemoryAccess *What, MemoryAccess *To,
- const MemoryLocation &Loc) const {
-// EXPENSIVE_CHECKS because most of these queries are redundant.
-#ifdef EXPENSIVE_CHECKS
- assert(MSSA.dominates(To, What));
-#endif
- if (shouldIgnoreCache())
- return;
- WC.insert(What, To, Loc, Query->IsCall);
- }
-
- MemoryAccess *lookupCache(const MemoryAccess *MA, const MemoryLocation &Loc) {
- return shouldIgnoreCache() ? nullptr : WC.lookup(MA, Loc, Query->IsCall);
- }
-
- void cacheDefPath(const DefPath &DN, MemoryAccess *Target) const {
- if (shouldIgnoreCache())
- return;
-
- for (MemoryAccess *MA : def_chain(DN.First, DN.Last))
- addCacheEntry(MA, Target, DN.Loc);
-
- // DefPaths only express the path we walked. So, DN.Last could either be a
- // thing we want to cache, or not.
- if (DN.Last != Target)
- addCacheEntry(DN.Last, Target, DN.Loc);
- }
-
- /// Find the nearest def or phi that `From` can legally be optimized to.
- ///
- /// FIXME: Deduplicate this with MSSA::findDominatingDef. Ideally, MSSA should
- /// keep track of this information for us, and allow us O(1) lookups of this
- /// info.
- MemoryAccess *getWalkTarget(const MemoryPhi *From) {
- assert(From->getNumOperands() && "Phi with no operands?");
-
- BasicBlock *BB = From->getBlock();
- auto At = WalkTargetCache.find(BB);
- if (At != WalkTargetCache.end())
- return At->second;
-
- SmallVector<const BasicBlock *, 8> ToCache;
- ToCache.push_back(BB);
-
- MemoryAccess *Result = MSSA.getLiveOnEntryDef();
- DomTreeNode *Node = DT.getNode(BB);
- while ((Node = Node->getIDom())) {
- auto At = WalkTargetCache.find(BB);
- if (At != WalkTargetCache.end()) {
- Result = At->second;
- break;
- }
-
- auto *Accesses = MSSA.getBlockAccesses(Node->getBlock());
- if (Accesses) {
- auto Iter = find_if(reverse(*Accesses), [](const MemoryAccess &MA) {
- return !isa<MemoryUse>(MA);
- });
- if (Iter != Accesses->rend()) {
- Result = const_cast<MemoryAccess *>(&*Iter);
- break;
- }
- }
-
- ToCache.push_back(Node->getBlock());
- }
-
- for (const BasicBlock *BB : ToCache)
- WalkTargetCache.insert({BB, Result});
- return Result;
- }
-
- /// Result of calling walkToPhiOrClobber.
- struct UpwardsWalkResult {
- /// The "Result" of the walk. Either a clobber, the last thing we walked, or
- /// both.
- MemoryAccess *Result;
- bool IsKnownClobber;
- bool FromCache;
- };
-
- /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last.
- /// This will update Desc.Last as it walks. It will (optionally) also stop at
- /// StopAt.
- ///
-  /// This does not test whether StopAt is a clobber.
- UpwardsWalkResult walkToPhiOrClobber(DefPath &Desc,
- MemoryAccess *StopAt = nullptr) {
- assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world");
-
- for (MemoryAccess *Current : def_chain(Desc.Last)) {
- Desc.Last = Current;
- if (Current == StopAt)
- return {Current, false, false};
-
- if (auto *MD = dyn_cast<MemoryDef>(Current))
- if (MSSA.isLiveOnEntryDef(MD) ||
- instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA))
- return {MD, true, false};
-
- // Cache checks must be done last, because if Current is a clobber, the
- // cache will contain the clobber for Current.
- if (MemoryAccess *MA = lookupCache(Current, Desc.Loc))
- return {MA, true, true};
- }
-
- assert(isa<MemoryPhi>(Desc.Last) &&
- "Ended at a non-clobber that's not a phi?");
- return {Desc.Last, false, false};
- }
-
- void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches,
- ListIndex PriorNode) {
- auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}),
- upward_defs_end());
- for (const MemoryAccessPair &P : UpwardDefs) {
- PausedSearches.push_back(Paths.size());
- Paths.emplace_back(P.second, P.first, PriorNode);
- }
- }
-
- /// Represents a search that terminated after finding a clobber. This clobber
- /// may or may not be present in the path of defs from LastNode..SearchStart,
- /// since it may have been retrieved from cache.
- struct TerminatedPath {
- MemoryAccess *Clobber;
- ListIndex LastNode;
- };
-
- /// Get an access that keeps us from optimizing to the given phi.
- ///
- /// PausedSearches is an array of indices into the Paths array. Its incoming
- /// value is the indices of searches that stopped at the last phi optimization
- /// target. It's left in an unspecified state.
- ///
- /// If this returns None, NewPaused is a vector of searches that terminated
- /// at StopWhere. Otherwise, NewPaused is left in an unspecified state.
- Optional<TerminatedPath>
- getBlockingAccess(MemoryAccess *StopWhere,
- SmallVectorImpl<ListIndex> &PausedSearches,
- SmallVectorImpl<ListIndex> &NewPaused,
- SmallVectorImpl<TerminatedPath> &Terminated) {
- assert(!PausedSearches.empty() && "No searches to continue?");
-
- // BFS vs DFS really doesn't make a difference here, so just do a DFS with
- // PausedSearches as our stack.
- while (!PausedSearches.empty()) {
- ListIndex PathIndex = PausedSearches.pop_back_val();
- DefPath &Node = Paths[PathIndex];
-
- // If we've already visited this path with this MemoryLocation, we don't
- // need to do so again.
- //
- // NOTE: That we just drop these paths on the ground makes caching
- // behavior sporadic. e.g. given a diamond:
- // A
- // B C
- // D
- //
- // ...If we walk D, B, A, C, we'll only cache the result of phi
- // optimization for A, B, and D; C will be skipped because it dies here.
- // This arguably isn't the worst thing ever, since:
- // - We generally query things in a top-down order, so if we got below D
- // without needing cache entries for {C, MemLoc}, then chances are
- // that those cache entries would end up ultimately unused.
- // - We still cache things for A, so C only needs to walk up a bit.
- // If this behavior becomes problematic, we can fix without a ton of extra
- // work.
- if (!VisitedPhis.insert({Node.Last, Node.Loc}).second)
- continue;
-
- UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere);
- if (Res.IsKnownClobber) {
- assert(Res.Result != StopWhere || Res.FromCache);
- // If this wasn't a cache hit, we hit a clobber when walking. That's a
- // failure.
- TerminatedPath Term{Res.Result, PathIndex};
- if (!Res.FromCache || !MSSA.dominates(Res.Result, StopWhere))
- return Term;
-
- // Otherwise, it's a valid thing to potentially optimize to.
- Terminated.push_back(Term);
- continue;
- }
-
- if (Res.Result == StopWhere) {
- // We've hit our target. Save this path off for if we want to continue
- // walking.
- NewPaused.push_back(PathIndex);
- continue;
- }
-
- assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber");
- addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex);
- }
-
- return None;
- }
-
- template <typename T, typename Walker>
- struct generic_def_path_iterator
- : public iterator_facade_base<generic_def_path_iterator<T, Walker>,
- std::forward_iterator_tag, T *> {
- generic_def_path_iterator() : W(nullptr), N(None) {}
- generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {}
-
- T &operator*() const { return curNode(); }
-
- generic_def_path_iterator &operator++() {
- N = curNode().Previous;
- return *this;
- }
-
- bool operator==(const generic_def_path_iterator &O) const {
- if (N.hasValue() != O.N.hasValue())
- return false;
- return !N.hasValue() || *N == *O.N;
- }
-
- private:
- T &curNode() const { return W->Paths[*N]; }
-
- Walker *W;
- Optional<ListIndex> N;
- };
-
- using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>;
- using const_def_path_iterator =
- generic_def_path_iterator<const DefPath, const ClobberWalker>;
-
- iterator_range<def_path_iterator> def_path(ListIndex From) {
- return make_range(def_path_iterator(this, From), def_path_iterator());
- }
-
- iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const {
- return make_range(const_def_path_iterator(this, From),
- const_def_path_iterator());
- }
-
- struct OptznResult {
- /// The path that contains our result.
- TerminatedPath PrimaryClobber;
- /// The paths that we can legally cache back from, but that aren't
- /// necessarily the result of the Phi optimization.
- SmallVector<TerminatedPath, 4> OtherClobbers;
- };
-
- ListIndex defPathIndex(const DefPath &N) const {
- // The assert looks nicer if we don't need to do &N
- const DefPath *NP = &N;
- assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() &&
- "Out of bounds DefPath!");
- return NP - &Paths.front();
- }
-
- /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths
- /// that act as legal clobbers. Note that this won't return *all* clobbers.
- ///
- /// Phi optimization algorithm tl;dr:
- /// - Find the earliest def/phi, A, we can optimize to
- /// - Find if all paths from the starting memory access ultimately reach A
- /// - If not, optimization isn't possible.
- /// - Otherwise, walk from A to another clobber or phi, A'.
- /// - If A' is a def, we're done.
- /// - If A' is a phi, try to optimize it.
- ///
- /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path
- /// terminates when a MemoryAccess that clobbers said MemoryLocation is found.
- OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start,
- const MemoryLocation &Loc) {
- assert(Paths.empty() && VisitedPhis.empty() &&
- "Reset the optimization state.");
-
- Paths.emplace_back(Loc, Start, Phi, None);
- // Stores how many "valid" optimization nodes we had prior to calling
- // addSearches/getBlockingAccess. Necessary for caching if we had a blocker.
- auto PriorPathsSize = Paths.size();
-
- SmallVector<ListIndex, 16> PausedSearches;
- SmallVector<ListIndex, 8> NewPaused;
- SmallVector<TerminatedPath, 4> TerminatedPaths;
-
- addSearches(Phi, PausedSearches, 0);
-
- // Moves the TerminatedPath with the "most dominated" Clobber to the end of
- // Paths.
- auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) {
- assert(!Paths.empty() && "Need a path to move");
- auto Dom = Paths.begin();
- for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I)
- if (!MSSA.dominates(I->Clobber, Dom->Clobber))
- Dom = I;
- auto Last = Paths.end() - 1;
- if (Last != Dom)
- std::iter_swap(Last, Dom);
- };
-
- MemoryPhi *Current = Phi;
- while (1) {
- assert(!MSSA.isLiveOnEntryDef(Current) &&
- "liveOnEntry wasn't treated as a clobber?");
-
- MemoryAccess *Target = getWalkTarget(Current);
- // If a TerminatedPath doesn't dominate Target, then it wasn't a legal
- // optimization for the prior phi.
- assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) {
- return MSSA.dominates(P.Clobber, Target);
- }));
-
- // FIXME: This is broken, because the Blocker may be reported to be
- // liveOnEntry, and we'll happily wait for that to disappear (read: never)
- // For the moment, this is fine, since we do nothing with blocker info.
- if (Optional<TerminatedPath> Blocker = getBlockingAccess(
- Target, PausedSearches, NewPaused, TerminatedPaths)) {
- // Cache our work on the blocking node, since we know that's correct.
- cacheDefPath(Paths[Blocker->LastNode], Blocker->Clobber);
-
- // Find the node we started at. We can't search based on N->Last, since
- // we may have gone around a loop with a different MemoryLocation.
- auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) {
- return defPathIndex(N) < PriorPathsSize;
- });
- assert(Iter != def_path_iterator());
-
- DefPath &CurNode = *Iter;
- assert(CurNode.Last == Current);
-
- // Two things:
- // A. We can't reliably cache all of NewPaused back. Consider a case
- // where we have two paths in NewPaused; one of which can't optimize
- // above this phi, whereas the other can. If we cache the second path
- // back, we'll end up with suboptimal cache entries. We can handle
- // cases like this a bit better when we either try to find all
- // clobbers that block phi optimization, or when our cache starts
- // supporting unfinished searches.
- // B. We can't reliably cache TerminatedPaths back here without doing
- // extra checks; consider a case like:
- // T
- // / \
- // D C
- // \ /
- // S
- // Where T is our target, C is a node with a clobber on it, D is a
- // diamond (with a clobber *only* on the left or right node, N), and
- // S is our start. Say we walk to D, through the node opposite N
- // (read: ignoring the clobber), and see a cache entry in the top
- // node of D. That cache entry gets put into TerminatedPaths. We then
- // walk up to C (N is later in our worklist), find the clobber, and
- // quit. If we append TerminatedPaths to OtherClobbers, we'll cache
- // the bottom part of D to the cached clobber, ignoring the clobber
- // in N. Again, this problem goes away if we start tracking all
- // blockers for a given phi optimization.
- TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)};
- return {Result, {}};
- }
-
- // If there's nothing left to search, then all paths led to valid clobbers
- // that we got from our cache; pick the nearest to the start, and allow
- // the rest to be cached back.
- if (NewPaused.empty()) {
- MoveDominatedPathToEnd(TerminatedPaths);
- TerminatedPath Result = TerminatedPaths.pop_back_val();
- return {Result, std::move(TerminatedPaths)};
- }
-
- MemoryAccess *DefChainEnd = nullptr;
- SmallVector<TerminatedPath, 4> Clobbers;
- for (ListIndex Paused : NewPaused) {
- UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]);
- if (WR.IsKnownClobber)
- Clobbers.push_back({WR.Result, Paused});
- else
- // Micro-opt: If we hit the end of the chain, save it.
- DefChainEnd = WR.Result;
- }
-
- if (!TerminatedPaths.empty()) {
- // If we couldn't find the dominating phi/liveOnEntry in the above loop,
- // do it now.
- if (!DefChainEnd)
- for (MemoryAccess *MA : def_chain(Target))
- DefChainEnd = MA;
-
- // If any of the terminated paths don't dominate the phi we'll try to
- // optimize, we need to figure out what they are and quit.
- const BasicBlock *ChainBB = DefChainEnd->getBlock();
- for (const TerminatedPath &TP : TerminatedPaths) {
- // Because we know that DefChainEnd is as "high" as we can go, we
- // don't need local dominance checks; BB dominance is sufficient.
- if (DT.dominates(ChainBB, TP.Clobber->getBlock()))
- Clobbers.push_back(TP);
- }
- }
-
- // If we have clobbers in the def chain, find the one closest to Current
- // and quit.
- if (!Clobbers.empty()) {
- MoveDominatedPathToEnd(Clobbers);
- TerminatedPath Result = Clobbers.pop_back_val();
- return {Result, std::move(Clobbers)};
- }
-
- assert(all_of(NewPaused,
- [&](ListIndex I) { return Paths[I].Last == DefChainEnd; }));
-
- // Because liveOnEntry is a clobber, this must be a phi.
- auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd);
-
- PriorPathsSize = Paths.size();
- PausedSearches.clear();
- for (ListIndex I : NewPaused)
- addSearches(DefChainPhi, PausedSearches, I);
- NewPaused.clear();
-
- Current = DefChainPhi;
- }
- }
-
- /// Caches everything in an OptznResult.
- void cacheOptResult(const OptznResult &R) {
- if (R.OtherClobbers.empty()) {
- // If we're not going to be caching OtherClobbers, don't bother with
- // marking visited/etc.
- for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode))
- cacheDefPath(N, R.PrimaryClobber.Clobber);
- return;
- }
-
- // PrimaryClobber is our answer. If we can cache anything back, we need to
- // stop caching when we visit PrimaryClobber.
- SmallBitVector Visited(Paths.size());
- for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode)) {
- Visited[defPathIndex(N)] = true;
- cacheDefPath(N, R.PrimaryClobber.Clobber);
- }
-
- for (const TerminatedPath &P : R.OtherClobbers) {
- for (const DefPath &N : const_def_path(P.LastNode)) {
- ListIndex NIndex = defPathIndex(N);
- if (Visited[NIndex])
- break;
- Visited[NIndex] = true;
- cacheDefPath(N, P.Clobber);
- }
- }
- }
-
- void verifyOptResult(const OptznResult &R) const {
- assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) {
- return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber);
- }));
- }
-
- void resetPhiOptznState() {
- Paths.clear();
- VisitedPhis.clear();
- }
-
-public:
- ClobberWalker(const MemorySSA &MSSA, AliasAnalysis &AA, DominatorTree &DT,
- WalkerCache &WC)
- : MSSA(MSSA), AA(AA), DT(DT), WC(WC), UseCache(true) {}
-
- void reset() { WalkTargetCache.clear(); }
-
- /// Finds the nearest clobber for the given query, optimizing phis if
- /// possible.
- MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q,
- bool UseWalkerCache = true) {
- setUseCache(UseWalkerCache);
- Query = &Q;
-
- MemoryAccess *Current = Start;
- // This walker pretends uses don't exist. If we're handed one, silently grab
- // its def. (This has the nice side-effect of ensuring we never cache uses)
- if (auto *MU = dyn_cast<MemoryUse>(Start))
- Current = MU->getDefiningAccess();
-
- DefPath FirstDesc(Q.StartingLoc, Current, Current, None);
- // Fast path for the overly-common case (no crazy phi optimization
- // necessary)
- UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc);
- MemoryAccess *Result;
- if (WalkResult.IsKnownClobber) {
- cacheDefPath(FirstDesc, WalkResult.Result);
- Result = WalkResult.Result;
- } else {
- OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last),
- Current, Q.StartingLoc);
- verifyOptResult(OptRes);
- cacheOptResult(OptRes);
- resetPhiOptznState();
- Result = OptRes.PrimaryClobber.Clobber;
- }
-
-#ifdef EXPENSIVE_CHECKS
- checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA);
-#endif
- return Result;
- }
-
- void verify(const MemorySSA *MSSA) { assert(MSSA == &this->MSSA); }
-};
-
-struct RenamePassData {
- DomTreeNode *DTN;
- DomTreeNode::const_iterator ChildIt;
- MemoryAccess *IncomingVal;
-
- RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It,
- MemoryAccess *M)
- : DTN(D), ChildIt(It), IncomingVal(M) {}
- void swap(RenamePassData &RHS) {
- std::swap(DTN, RHS.DTN);
- std::swap(ChildIt, RHS.ChildIt);
- std::swap(IncomingVal, RHS.IncomingVal);
- }
-};
-} // anonymous namespace
-
-namespace llvm {
-/// \brief A MemorySSAWalker that does AA walks and caching of lookups to
-/// disambiguate accesses.
-///
-/// FIXME: The current implementation of this can take quadratic space in rare
-/// cases. This can be fixed, but it is something to note until it is fixed.
-///
-/// In order to trigger this behavior, you need to store to N distinct locations
-/// (that AA can prove don't alias), perform M stores to other memory
-/// locations that AA can prove don't alias any of the initial N locations, and
-/// then load from all of the N locations. In this case, we insert M cache
-/// entries for each of the N loads.
-///
-/// For example:
-/// define i32 @foo() {
-/// %a = alloca i32, align 4
-/// %b = alloca i32, align 4
-/// store i32 0, i32* %a, align 4
-/// store i32 0, i32* %b, align 4
-///
-/// ; Insert M stores to other memory that doesn't alias %a or %b here
-///
-/// %c = load i32, i32* %a, align 4 ; Caches M entries in
-/// ; CachedUpwardsClobberingAccess for the
-/// ; MemoryLocation %a
-/// %d = load i32, i32* %b, align 4 ; Caches M entries in
-/// ; CachedUpwardsClobberingAccess for the
-/// ; MemoryLocation %b
-///
-/// ; For completeness' sake, loading %a or %b again would not cache *another*
-/// ; M entries.
-/// %r = add i32 %c, %d
-/// ret i32 %r
-/// }
-class MemorySSA::CachingWalker final : public MemorySSAWalker {
- WalkerCache Cache;
- ClobberWalker Walker;
- bool AutoResetWalker;
-
- MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &);
- void verifyRemoved(MemoryAccess *);
-
-public:
- CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *);
- ~CachingWalker() override;
-
- using MemorySSAWalker::getClobberingMemoryAccess;
- MemoryAccess *getClobberingMemoryAccess(MemoryAccess *) override;
- MemoryAccess *getClobberingMemoryAccess(MemoryAccess *,
- const MemoryLocation &) override;
- void invalidateInfo(MemoryAccess *) override;
-
- /// Whether we call resetClobberWalker() after each time we *actually* walk to
- /// answer a clobber query.
- void setAutoResetWalker(bool AutoReset) { AutoResetWalker = AutoReset; }
-
- /// Drop the walker's persistent data structures. At the moment, this means
- /// "drop the walker's cache of BasicBlocks ->
- /// earliest-MemoryAccess-we-can-optimize-to". This is necessary if we're
- /// going to have DT updates, if we remove MemoryAccesses, etc.
- void resetClobberWalker() { Walker.reset(); }
-
- void verify(const MemorySSA *MSSA) override {
- MemorySSAWalker::verify(MSSA);
- Walker.verify(MSSA);
- }
-};
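For context, a hedged sketch of how a client normally reaches this walker; MSSA and LoadI are assumed to be a built MemorySSA and some memory-touching Instruction* in the same function:

  // Sketch only: ask the walker for the nearest access that clobbers LoadI.
  MemorySSAWalker *Walker = MSSA.getWalker();
  MemoryAccess *Clobber = Walker->getClobberingMemoryAccess(LoadI);
  if (MSSA.isLiveOnEntryDef(Clobber))
    dbgs() << "no clobber inside this function\n";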
-
-/// \brief Rename a single basic block into MemorySSA form.
-/// Uses the standard SSA renaming algorithm.
-/// \returns The new incoming value.
-MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB,
- MemoryAccess *IncomingVal) {
- auto It = PerBlockAccesses.find(BB);
- // Skip most processing if the list is empty.
- if (It != PerBlockAccesses.end()) {
- AccessList *Accesses = It->second.get();
- for (MemoryAccess &L : *Accesses) {
- if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) {
- if (MUD->getDefiningAccess() == nullptr)
- MUD->setDefiningAccess(IncomingVal);
- if (isa<MemoryDef>(&L))
- IncomingVal = &L;
- } else {
- IncomingVal = &L;
- }
- }
- }
-
- // Pass through values to our successors
- for (const BasicBlock *S : successors(BB)) {
- auto It = PerBlockAccesses.find(S);
- // Rename the phi nodes in our successor block
- if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front()))
- continue;
- AccessList *Accesses = It->second.get();
- auto *Phi = cast<MemoryPhi>(&Accesses->front());
- Phi->addIncoming(IncomingVal, BB);
- }
-
- return IncomingVal;
-}
-
-/// \brief This is the standard SSA renaming algorithm.
-///
-/// We walk the dominator tree in preorder, renaming accesses, and then filling
-/// in phi nodes in our successors.
-void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal,
- SmallPtrSet<BasicBlock *, 16> &Visited) {
- SmallVector<RenamePassData, 32> WorkStack;
- IncomingVal = renameBlock(Root->getBlock(), IncomingVal);
- WorkStack.push_back({Root, Root->begin(), IncomingVal});
- Visited.insert(Root->getBlock());
-
- while (!WorkStack.empty()) {
- DomTreeNode *Node = WorkStack.back().DTN;
- DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt;
- IncomingVal = WorkStack.back().IncomingVal;
-
- if (ChildIt == Node->end()) {
- WorkStack.pop_back();
- } else {
- DomTreeNode *Child = *ChildIt;
- ++WorkStack.back().ChildIt;
- BasicBlock *BB = Child->getBlock();
- Visited.insert(BB);
- IncomingVal = renameBlock(BB, IncomingVal);
- WorkStack.push_back({Child, Child->begin(), IncomingVal});
- }
- }
-}
-
-/// \brief Compute dominator levels, used by the phi insertion algorithm above.
-void MemorySSA::computeDomLevels(DenseMap<DomTreeNode *, unsigned> &DomLevels) {
- for (auto DFI = df_begin(DT->getRootNode()), DFE = df_end(DT->getRootNode());
- DFI != DFE; ++DFI)
- DomLevels[*DFI] = DFI.getPathLength() - 1;
-}
-
-/// \brief This handles unreachable block accesses by deleting phi nodes in
-/// unreachable blocks, and marking all other unreachable MemoryAccess's as
-/// being uses of the live on entry definition.
-void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) {
- assert(!DT->isReachableFromEntry(BB) &&
- "Reachable block found while handling unreachable blocks");
-
- // Make sure phi nodes in our reachable successors end up with a
- // LiveOnEntryDef for our incoming edge, even though our block is forward
- // unreachable. We could just disconnect these blocks from the CFG fully,
- // but we do not right now.
- for (const BasicBlock *S : successors(BB)) {
- if (!DT->isReachableFromEntry(S))
- continue;
- auto It = PerBlockAccesses.find(S);
- // Rename the phi nodes in our successor block
- if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front()))
- continue;
- AccessList *Accesses = It->second.get();
- auto *Phi = cast<MemoryPhi>(&Accesses->front());
- Phi->addIncoming(LiveOnEntryDef.get(), BB);
- }
-
- auto It = PerBlockAccesses.find(BB);
- if (It == PerBlockAccesses.end())
- return;
-
- auto &Accesses = It->second;
- for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) {
- auto Next = std::next(AI);
- // If we have a phi, just remove it. We are going to replace all
- // users with live on entry.
- if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI))
- UseOrDef->setDefiningAccess(LiveOnEntryDef.get());
- else
- Accesses->erase(AI);
- AI = Next;
- }
-}
-
-MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT)
- : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr),
- NextID(INVALID_MEMORYACCESS_ID) {
- buildMemorySSA();
-}
-
-MemorySSA::~MemorySSA() {
- // Drop all our references
- for (const auto &Pair : PerBlockAccesses)
- for (MemoryAccess &MA : *Pair.second)
- MA.dropAllReferences();
-}
-
-MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) {
- auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr));
-
- if (Res.second)
- Res.first->second = make_unique<AccessList>();
- return Res.first->second.get();
-}
-
-/// This class is a batch walker of all MemoryUse's in the program, and points
-/// their defining access at the thing that actually clobbers them. Because it
-/// is a batch walker that touches everything, it does not operate like the
-/// other walkers. This walker is basically performing a top-down SSA renaming
-/// pass, where the version stack is used as the cache. This enables it to be
-/// significantly more time and memory efficient than using the regular walker,
-/// which is walking bottom-up.
-class MemorySSA::OptimizeUses {
-public:
- OptimizeUses(MemorySSA *MSSA, MemorySSAWalker *Walker, AliasAnalysis *AA,
- DominatorTree *DT)
- : MSSA(MSSA), Walker(Walker), AA(AA), DT(DT) {
- Walker = MSSA->getWalker();
- }
-
- void optimizeUses();
-
-private:
- /// This represents where a given memorylocation is in the stack.
- struct MemlocStackInfo {
- // This essentially is keeping track of versions of the stack. Whenever
- // the stack changes due to pushes or pops, these versions increase.
- unsigned long StackEpoch;
- unsigned long PopEpoch;
- // This is the lower bound of places on the stack to check. It is equal to
- // the place the last stack walk ended.
-    // Note: Correctness depends on this being initialized to 0, which DenseMap
-    // does.
- unsigned long LowerBound;
- const BasicBlock *LowerBoundBlock;
- // This is where the last walk for this memory location ended.
- unsigned long LastKill;
- bool LastKillValid;
- };
- void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &,
- SmallVectorImpl<MemoryAccess *> &,
- DenseMap<MemoryLocOrCall, MemlocStackInfo> &);
- MemorySSA *MSSA;
- MemorySSAWalker *Walker;
- AliasAnalysis *AA;
- DominatorTree *DT;
-};
-
-/// Optimize the uses in a given block. This is basically the SSA renaming
-/// algorithm, with one caveat: We are able to use a single stack for all
-/// MemoryUses. This is because the set of *possible* reaching MemoryDefs is
-/// the same for every MemoryUse. The *actual* clobbering MemoryDef is just
-/// going to be some position in that stack of possible ones.
-///
-/// We track the stack positions that each MemoryLocation needs
-/// to check, and last ended at. This is because we only want to check the
-/// things that changed since last time. The same MemoryLocation should
-/// get clobbered by the same store (getModRefInfo does not use invariantness or
-/// things like this, and if they start, we can modify MemoryLocOrCall to
-/// include relevant data)
-void MemorySSA::OptimizeUses::optimizeUsesInBlock(
- const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch,
- SmallVectorImpl<MemoryAccess *> &VersionStack,
- DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) {
-
- /// If no accesses, nothing to do.
- MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB);
- if (Accesses == nullptr)
- return;
-
- // Pop everything that doesn't dominate the current block off the stack,
- // increment the PopEpoch to account for this.
- while (!VersionStack.empty()) {
- BasicBlock *BackBlock = VersionStack.back()->getBlock();
- if (DT->dominates(BackBlock, BB))
- break;
- while (VersionStack.back()->getBlock() == BackBlock)
- VersionStack.pop_back();
- ++PopEpoch;
- }
- for (MemoryAccess &MA : *Accesses) {
- auto *MU = dyn_cast<MemoryUse>(&MA);
- if (!MU) {
- VersionStack.push_back(&MA);
- ++StackEpoch;
- continue;
- }
-
- if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) {
- MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true);
- continue;
- }
-
- MemoryLocOrCall UseMLOC(MU);
- auto &LocInfo = LocStackInfo[UseMLOC];
- // If the pop epoch changed, it means we've removed stuff from top of
- // stack due to changing blocks. We may have to reset the lower bound or
- // last kill info.
- if (LocInfo.PopEpoch != PopEpoch) {
- LocInfo.PopEpoch = PopEpoch;
- LocInfo.StackEpoch = StackEpoch;
- // If the lower bound was in something that no longer dominates us, we
- // have to reset it.
- // We can't simply track stack size, because the stack may have had
- // pushes/pops in the meantime.
-      // XXX: This is non-optimal, but is only slower in cases with heavily
-      // branching dominator trees. Getting the optimal number of queries would
-      // require making lowerbound and lastkill a per-loc stack, and popping it
-      // until the top of that stack dominates us. This does not seem worth it
-      // ATM.
- // A much cheaper optimization would be to always explore the deepest
- // branch of the dominator tree first. This will guarantee this resets on
- // the smallest set of blocks.
- if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB &&
- !DT->dominates(LocInfo.LowerBoundBlock, BB)) {
- // Reset the lower bound of things to check.
- // TODO: Some day we should be able to reset to last kill, rather than
- // 0.
- LocInfo.LowerBound = 0;
- LocInfo.LowerBoundBlock = VersionStack[0]->getBlock();
- LocInfo.LastKillValid = false;
- }
- } else if (LocInfo.StackEpoch != StackEpoch) {
- // If all that has changed is the StackEpoch, we only have to check the
- // new things on the stack, because we've checked everything before. In
- // this case, the lower bound of things to check remains the same.
- LocInfo.PopEpoch = PopEpoch;
- LocInfo.StackEpoch = StackEpoch;
- }
- if (!LocInfo.LastKillValid) {
- LocInfo.LastKill = VersionStack.size() - 1;
- LocInfo.LastKillValid = true;
- }
-
- // At this point, we should have corrected last kill and LowerBound to be
- // in bounds.
- assert(LocInfo.LowerBound < VersionStack.size() &&
- "Lower bound out of range");
- assert(LocInfo.LastKill < VersionStack.size() &&
- "Last kill info out of range");
- // In any case, the new upper bound is the top of the stack.
- unsigned long UpperBound = VersionStack.size() - 1;
-
- if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) {
- DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " ("
- << *(MU->getMemoryInst()) << ")"
- << " because there are " << UpperBound - LocInfo.LowerBound
- << " stores to disambiguate\n");
- // Because we did not walk, LastKill is no longer valid, as this may
- // have been a kill.
- LocInfo.LastKillValid = false;
- continue;
- }
- bool FoundClobberResult = false;
- while (UpperBound > LocInfo.LowerBound) {
- if (isa<MemoryPhi>(VersionStack[UpperBound])) {
- // For phis, use the walker, see where we ended up, go there
- Instruction *UseInst = MU->getMemoryInst();
- MemoryAccess *Result = Walker->getClobberingMemoryAccess(UseInst);
- // We are guaranteed to find it or something is wrong
- while (VersionStack[UpperBound] != Result) {
- assert(UpperBound != 0);
- --UpperBound;
- }
- FoundClobberResult = true;
- break;
- }
-
- MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]);
- // If the lifetime of the pointer ends at this instruction, it's live on
- // entry.
- if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) {
- // Reset UpperBound to liveOnEntryDef's place in the stack
- UpperBound = 0;
- FoundClobberResult = true;
- break;
- }
- if (instructionClobbersQuery(MD, MU, UseMLOC, *AA)) {
- FoundClobberResult = true;
- break;
- }
- --UpperBound;
- }
-    // At the end of this loop, UpperBound is either a clobber or the lower
-    // bound. PHI walking may cause it to be < LowerBound, and in fact,
-    // < LastKill.
- if (FoundClobberResult || UpperBound < LocInfo.LastKill) {
- MU->setDefiningAccess(VersionStack[UpperBound], true);
- // We were last killed now by where we got to
- LocInfo.LastKill = UpperBound;
- } else {
- // Otherwise, we checked all the new ones, and now we know we can get to
- // LastKill.
- MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true);
- }
- LocInfo.LowerBound = VersionStack.size() - 1;
- LocInfo.LowerBoundBlock = BB;
- }
-}
-
-/// Optimize uses to point to their actual clobbering definitions.
-void MemorySSA::OptimizeUses::optimizeUses() {
-
- // We perform a non-recursive top-down dominator tree walk
- struct StackInfo {
- const DomTreeNode *Node;
- DomTreeNode::const_iterator Iter;
- };
-
- SmallVector<MemoryAccess *, 16> VersionStack;
- SmallVector<StackInfo, 16> DomTreeWorklist;
- DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo;
- VersionStack.push_back(MSSA->getLiveOnEntryDef());
-
- unsigned long StackEpoch = 1;
- unsigned long PopEpoch = 1;
- for (const auto *DomNode : depth_first(DT->getRootNode()))
- optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack,
- LocStackInfo);
-}
-
-void MemorySSA::placePHINodes(
- const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks,
- const DenseMap<const BasicBlock *, unsigned int> &BBNumbers) {
- // Determine where our MemoryPhi's should go
- ForwardIDFCalculator IDFs(*DT);
- IDFs.setDefiningBlocks(DefiningBlocks);
- SmallVector<BasicBlock *, 32> IDFBlocks;
- IDFs.calculate(IDFBlocks);
-
- std::sort(IDFBlocks.begin(), IDFBlocks.end(),
- [&BBNumbers](const BasicBlock *A, const BasicBlock *B) {
- return BBNumbers.lookup(A) < BBNumbers.lookup(B);
- });
-
- // Now place MemoryPhi nodes.
- for (auto &BB : IDFBlocks) {
- // Insert phi node
- AccessList *Accesses = getOrCreateAccessList(BB);
- MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
- ValueToMemoryAccess[BB] = Phi;
-    // Phis are always placed at the front of the block.
- Accesses->push_front(Phi);
- }
-}
-
-void MemorySSA::buildMemorySSA() {
- // We create an access to represent "live on entry", for things like
- // arguments or users of globals, where the memory they use is defined before
- // the beginning of the function. We do not actually insert it into the IR.
- // We do not define a live on exit for the immediate uses, and thus our
- // semantics do *not* imply that something with no immediate uses can simply
- // be removed.
- BasicBlock &StartingPoint = F.getEntryBlock();
- LiveOnEntryDef = make_unique<MemoryDef>(F.getContext(), nullptr, nullptr,
- &StartingPoint, NextID++);
- DenseMap<const BasicBlock *, unsigned int> BBNumbers;
- unsigned NextBBNum = 0;
-
- // We maintain lists of memory accesses per-block, trading memory for time. We
- // could just look up the memory access for every possible instruction in the
- // stream.
- SmallPtrSet<BasicBlock *, 32> DefiningBlocks;
- SmallPtrSet<BasicBlock *, 32> DefUseBlocks;
- // Go through each block, figure out where defs occur, and chain together all
- // the accesses.
- for (BasicBlock &B : F) {
- BBNumbers[&B] = NextBBNum++;
- bool InsertIntoDef = false;
- AccessList *Accesses = nullptr;
- for (Instruction &I : B) {
- MemoryUseOrDef *MUD = createNewAccess(&I);
- if (!MUD)
- continue;
- InsertIntoDef |= isa<MemoryDef>(MUD);
-
- if (!Accesses)
- Accesses = getOrCreateAccessList(&B);
- Accesses->push_back(MUD);
- }
- if (InsertIntoDef)
- DefiningBlocks.insert(&B);
- if (Accesses)
- DefUseBlocks.insert(&B);
- }
- placePHINodes(DefiningBlocks, BBNumbers);
-
- // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get
- // filled in with all blocks.
- SmallPtrSet<BasicBlock *, 16> Visited;
- renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited);
-
- CachingWalker *Walker = getWalkerImpl();
-
- // We're doing a batch of updates; don't drop useful caches between them.
- Walker->setAutoResetWalker(false);
- OptimizeUses(this, Walker, AA, DT).optimizeUses();
- Walker->setAutoResetWalker(true);
- Walker->resetClobberWalker();
-
- // Mark the uses in unreachable blocks as live on entry, so that they go
- // somewhere.
- for (auto &BB : F)
- if (!Visited.count(&BB))
- markUnreachableAsLiveOnEntry(&BB);
-}
-
-MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); }
-
-MemorySSA::CachingWalker *MemorySSA::getWalkerImpl() {
- if (Walker)
- return Walker.get();
-
- Walker = make_unique<CachingWalker>(this, AA, DT);
- return Walker.get();
-}
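A hedged usage sketch (not in the patch): a legacy pass that wants this analysis would typically go through the wrapper pass; this assumes the wrapper exposes the analysis via a getMSSA() accessor:

  // Sketch only, inside some FunctionPass::runOnFunction:
  MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
  MemorySSAWalker *Walker = MSSA.getWalker(); // lazily constructs the CachingWalker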
-
-MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) {
- assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB");
- AccessList *Accesses = getOrCreateAccessList(BB);
- MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
- ValueToMemoryAccess[BB] = Phi;
-  // Phis are always placed at the front of the block.
- Accesses->push_front(Phi);
- BlockNumberingValid.erase(BB);
- return Phi;
-}
-
-MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I,
- MemoryAccess *Definition) {
- assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI");
- MemoryUseOrDef *NewAccess = createNewAccess(I);
- assert(
- NewAccess != nullptr &&
- "Tried to create a memory access for a non-memory touching instruction");
- NewAccess->setDefiningAccess(Definition);
- return NewAccess;
-}
-
-MemoryAccess *MemorySSA::createMemoryAccessInBB(Instruction *I,
- MemoryAccess *Definition,
- const BasicBlock *BB,
- InsertionPlace Point) {
- MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition);
- auto *Accesses = getOrCreateAccessList(BB);
- if (Point == Beginning) {
- // It goes after any phi nodes
- auto AI = find_if(
- *Accesses, [](const MemoryAccess &MA) { return !isa<MemoryPhi>(MA); });
-
- Accesses->insert(AI, NewAccess);
- } else {
- Accesses->push_back(NewAccess);
- }
- BlockNumberingValid.erase(BB);
- return NewAccess;
-}
-
-MemoryUseOrDef *MemorySSA::createMemoryAccessBefore(Instruction *I,
- MemoryAccess *Definition,
- MemoryUseOrDef *InsertPt) {
- assert(I->getParent() == InsertPt->getBlock() &&
- "New and old access must be in the same block");
- MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition);
- auto *Accesses = getOrCreateAccessList(InsertPt->getBlock());
- Accesses->insert(AccessList::iterator(InsertPt), NewAccess);
- BlockNumberingValid.erase(InsertPt->getBlock());
- return NewAccess;
-}
-
-MemoryUseOrDef *MemorySSA::createMemoryAccessAfter(Instruction *I,
- MemoryAccess *Definition,
- MemoryAccess *InsertPt) {
- assert(I->getParent() == InsertPt->getBlock() &&
- "New and old access must be in the same block");
- MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition);
- auto *Accesses = getOrCreateAccessList(InsertPt->getBlock());
- Accesses->insertAfter(AccessList::iterator(InsertPt), NewAccess);
- BlockNumberingValid.erase(InsertPt->getBlock());
- return NewAccess;
-}
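Illustrative only: a transform that clones a store immediately after an existing one could keep MemorySSA consistent roughly as below; OrigStore and NewStore are hypothetical Instruction*s, with NewStore already inserted right after OrigStore in the IR:

  // Sketch only: the clone sits right after the original, so the original's
  // MemoryDef is both its defining access and the insertion point.
  auto *OrigDef = cast<MemoryDef>(MSSA.getMemoryAccess(OrigStore));
  MSSA.createMemoryAccessAfter(NewStore, /*Definition=*/OrigDef, /*InsertPt=*/OrigDef);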
-
-void MemorySSA::spliceMemoryAccessAbove(MemoryDef *Where,
- MemoryUseOrDef *What) {
- assert(What != getLiveOnEntryDef() &&
- Where != getLiveOnEntryDef() && "Can't splice (above) LOE.");
- assert(dominates(Where, What) && "Only upwards splices are permitted.");
-
- if (Where == What)
- return;
- if (isa<MemoryDef>(What)) {
- // TODO: possibly use removeMemoryAccess' more efficient RAUW
- What->replaceAllUsesWith(What->getDefiningAccess());
- What->setDefiningAccess(Where->getDefiningAccess());
- Where->setDefiningAccess(What);
- }
- AccessList *Src = getWritableBlockAccesses(What->getBlock());
- AccessList *Dest = getWritableBlockAccesses(Where->getBlock());
- Dest->splice(AccessList::iterator(Where), *Src, What);
-
- BlockNumberingValid.erase(What->getBlock());
- if (What->getBlock() != Where->getBlock())
- BlockNumberingValid.erase(Where->getBlock());
-}
-
-/// \brief Helper function to create new memory accesses
-MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) {
- // The assume intrinsic has a control dependency which we model by claiming
- // that it writes arbitrarily. Ignore that fake memory dependency here.
- // FIXME: Replace this special casing with a more accurate modelling of
- // assume's control dependency.
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
- if (II->getIntrinsicID() == Intrinsic::assume)
- return nullptr;
-
-  // Find out what effect this instruction has on memory.
- ModRefInfo ModRef = AA->getModRefInfo(I);
- bool Def = bool(ModRef & MRI_Mod);
- bool Use = bool(ModRef & MRI_Ref);
-
- // It's possible for an instruction to not modify memory at all. During
- // construction, we ignore them.
- if (!Def && !Use)
- return nullptr;
-
- assert((Def || Use) &&
- "Trying to create a memory access with a non-memory instruction");
-
- MemoryUseOrDef *MUD;
- if (Def)
- MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++);
- else
- MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent());
- ValueToMemoryAccess[I] = MUD;
- return MUD;
-}
-
-MemoryAccess *MemorySSA::findDominatingDef(BasicBlock *UseBlock,
- enum InsertionPlace Where) {
- // Handle the initial case
- if (Where == Beginning)
- // The only thing that could define us at the beginning is a phi node
- if (MemoryPhi *Phi = getMemoryAccess(UseBlock))
- return Phi;
-
- DomTreeNode *CurrNode = DT->getNode(UseBlock);
- // Need to be defined by our dominator
- if (Where == Beginning)
- CurrNode = CurrNode->getIDom();
- Where = End;
- while (CurrNode) {
- auto It = PerBlockAccesses.find(CurrNode->getBlock());
- if (It != PerBlockAccesses.end()) {
- auto &Accesses = It->second;
- for (MemoryAccess &RA : reverse(*Accesses)) {
- if (isa<MemoryDef>(RA) || isa<MemoryPhi>(RA))
- return &RA;
- }
- }
- CurrNode = CurrNode->getIDom();
- }
- return LiveOnEntryDef.get();
-}
-
-/// \brief Returns true if \p Replacer dominates \p Replacee .
-bool MemorySSA::dominatesUse(const MemoryAccess *Replacer,
- const MemoryAccess *Replacee) const {
- if (isa<MemoryUseOrDef>(Replacee))
- return DT->dominates(Replacer->getBlock(), Replacee->getBlock());
- const auto *MP = cast<MemoryPhi>(Replacee);
- // For a phi node, the use occurs in the predecessor block of the phi node.
- // Since we may occur multiple times in the phi node, we have to check each
- // operand to ensure Replacer dominates each operand where Replacee occurs.
- for (const Use &Arg : MP->operands()) {
- if (Arg.get() != Replacee &&
- !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg)))
- return false;
- }
- return true;
-}
-
-/// \brief If all arguments of a MemoryPHI are defined by the same incoming
-/// argument, return that argument.
-static MemoryAccess *onlySingleValue(MemoryPhi *MP) {
- MemoryAccess *MA = nullptr;
-
- for (auto &Arg : MP->operands()) {
- if (!MA)
- MA = cast<MemoryAccess>(Arg);
- else if (MA != Arg)
- return nullptr;
- }
- return MA;
-}
-
-/// \brief Properly remove \p MA from all of MemorySSA's lookup tables.
-///
-/// Because of the way the intrusive list and use lists work, it is important to
-/// do removal in the right order.
-void MemorySSA::removeFromLookups(MemoryAccess *MA) {
- assert(MA->use_empty() &&
- "Trying to remove memory access that still has uses");
- BlockNumbering.erase(MA);
- if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA))
- MUD->setDefiningAccess(nullptr);
- // Invalidate our walker's cache if necessary
- if (!isa<MemoryUse>(MA))
- Walker->invalidateInfo(MA);
- // The call below to erase will destroy MA, so we can't change the order we
- // are doing things here
- Value *MemoryInst;
- if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) {
- MemoryInst = MUD->getMemoryInst();
- } else {
- MemoryInst = MA->getBlock();
- }
- auto VMA = ValueToMemoryAccess.find(MemoryInst);
- if (VMA->second == MA)
- ValueToMemoryAccess.erase(VMA);
-
- auto AccessIt = PerBlockAccesses.find(MA->getBlock());
- std::unique_ptr<AccessList> &Accesses = AccessIt->second;
- Accesses->erase(MA);
- if (Accesses->empty())
- PerBlockAccesses.erase(AccessIt);
-}
-
-void MemorySSA::removeMemoryAccess(MemoryAccess *MA) {
- assert(!isLiveOnEntryDef(MA) && "Trying to remove the live on entry def");
- // We can only delete phi nodes if they have no uses, or we can replace all
- // uses with a single definition.
- MemoryAccess *NewDefTarget = nullptr;
- if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) {
- // Note that it is sufficient to know that all edges of the phi node have
- // the same argument. If they do, by the definition of dominance frontiers
- // (which we used to place this phi), that argument must dominate this phi,
- // and thus, must dominate the phi's uses, and so we will not hit the assert
- // below.
- NewDefTarget = onlySingleValue(MP);
- assert((NewDefTarget || MP->use_empty()) &&
- "We can't delete this memory phi");
- } else {
- NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess();
- }
-
- // Re-point the uses at our defining access
- if (!MA->use_empty()) {
- // Reset optimized on users of this store, and reset the uses.
- // A few notes:
- // 1. This is a slightly modified version of RAUW to avoid walking the
- // uses twice here.
- // 2. If we wanted to be complete, we would have to reset the optimized
- // flags on users of phi nodes if doing the below makes a phi node have all
- // the same arguments. Instead, we prefer users to call removeMemoryAccess on
- // those phi nodes, because doing it here would be N^3.
- if (MA->hasValueHandle())
- ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget);
- // Note: We assume MemorySSA is not used in metadata since it's not really
- // part of the IR.
-
- while (!MA->use_empty()) {
- Use &U = *MA->use_begin();
- if (MemoryUse *MU = dyn_cast<MemoryUse>(U.getUser()))
- MU->resetOptimized();
- U.set(NewDefTarget);
- }
- }
-
- // The call below to erase will destroy MA, so we can't change the order we
- // are doing things here
- removeFromLookups(MA);
-}
-
-void MemorySSA::print(raw_ostream &OS) const {
- MemorySSAAnnotatedWriter Writer(this);
- F.print(OS, &Writer);
-}
-
-void MemorySSA::dump() const {
- MemorySSAAnnotatedWriter Writer(this);
- F.print(dbgs(), &Writer);
-}
-
-void MemorySSA::verifyMemorySSA() const {
- verifyDefUses(F);
- verifyDomination(F);
- verifyOrdering(F);
- Walker->verify(this);
-}
-
-/// \brief Verify that the order and existence of MemoryAccesses matches the
-/// order and existence of memory affecting instructions.
-void MemorySSA::verifyOrdering(Function &F) const {
- // Walk all the blocks, comparing what the lookups think and what the access
- // lists think, as well as the order in the blocks vs the order in the access
- // lists.
- SmallVector<MemoryAccess *, 32> ActualAccesses;
- for (BasicBlock &B : F) {
- const AccessList *AL = getBlockAccesses(&B);
- MemoryAccess *Phi = getMemoryAccess(&B);
- if (Phi)
- ActualAccesses.push_back(Phi);
- for (Instruction &I : B) {
- MemoryAccess *MA = getMemoryAccess(&I);
- assert((!MA || AL) && "We have memory affecting instructions "
- "in this block but they are not in the "
- "access list");
- if (MA)
- ActualAccesses.push_back(MA);
- }
- // Either we hit the assert, really have no accesses, or we have both
- // accesses and an access list
- if (!AL)
- continue;
- assert(AL->size() == ActualAccesses.size() &&
- "We don't have the same number of accesses in the block as on the "
- "access list");
- auto ALI = AL->begin();
- auto AAI = ActualAccesses.begin();
- while (ALI != AL->end() && AAI != ActualAccesses.end()) {
- assert(&*ALI == *AAI && "Not the same accesses in the same order");
- ++ALI;
- ++AAI;
- }
- ActualAccesses.clear();
- }
-}
-
-/// \brief Verify the domination properties of MemorySSA by checking that each
-/// definition dominates all of its uses.
-void MemorySSA::verifyDomination(Function &F) const {
-#ifndef NDEBUG
- for (BasicBlock &B : F) {
- // Phi nodes are attached to basic blocks
- if (MemoryPhi *MP = getMemoryAccess(&B))
- for (const Use &U : MP->uses())
- assert(dominates(MP, U) && "Memory PHI does not dominate its uses");
-
- for (Instruction &I : B) {
- MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I));
- if (!MD)
- continue;
-
- for (const Use &U : MD->uses())
- assert(dominates(MD, U) && "Memory Def does not dominate its uses");
- }
- }
-#endif
-}
-
-/// \brief Verify the def-use lists in MemorySSA, by verifying that \p Use
-/// appears in the use list of \p Def.
-
-void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const {
-#ifndef NDEBUG
- // The live on entry use may cause us to get a NULL def here
- if (!Def)
- assert(isLiveOnEntryDef(Use) &&
- "Null def but use not point to live on entry def");
- else
- assert(is_contained(Def->users(), Use) &&
- "Did not find use in def's use list");
-#endif
-}
-
-/// \brief Verify the immediate use information, by walking all the memory
-/// accesses and verifying that, for each use, it appears in the
-/// appropriate def's use list
-void MemorySSA::verifyDefUses(Function &F) const {
- for (BasicBlock &B : F) {
- // Phi nodes are attached to basic blocks
- if (MemoryPhi *Phi = getMemoryAccess(&B)) {
- assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance(
- pred_begin(&B), pred_end(&B))) &&
- "Incomplete MemoryPhi Node");
- for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
- verifyUseInDefs(Phi->getIncomingValue(I), Phi);
- }
-
- for (Instruction &I : B) {
- if (MemoryUseOrDef *MA = getMemoryAccess(&I)) {
- verifyUseInDefs(MA->getDefiningAccess(), MA);
- }
- }
- }
-}
-
-MemoryUseOrDef *MemorySSA::getMemoryAccess(const Instruction *I) const {
- return cast_or_null<MemoryUseOrDef>(ValueToMemoryAccess.lookup(I));
-}
-
-MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const {
- return cast_or_null<MemoryPhi>(ValueToMemoryAccess.lookup(cast<Value>(BB)));
-}
-
-/// Perform a local numbering on blocks so that instruction ordering can be
-/// determined in constant time.
-/// TODO: We currently just number in order. If we numbered by N, we could
-/// allow at least N-1 sequences of insertBefore or insertAfter (and at least
-/// log2(N) sequences of mixed before and after) without needing to invalidate
-/// the numbering.
-void MemorySSA::renumberBlock(const BasicBlock *B) const {
- // The pre-increment ensures the numbers really start at 1.
- unsigned long CurrentNumber = 0;
- const AccessList *AL = getBlockAccesses(B);
- assert(AL != nullptr && "Asking to renumber an empty block");
- for (const auto &I : *AL)
- BlockNumbering[&I] = ++CurrentNumber;
- BlockNumberingValid.insert(B);
-}
-
-/// \brief Determine, for two memory accesses in the same block,
-/// whether \p Dominator dominates \p Dominatee.
-/// \returns True if \p Dominator dominates \p Dominatee.
-bool MemorySSA::locallyDominates(const MemoryAccess *Dominator,
- const MemoryAccess *Dominatee) const {
-
- const BasicBlock *DominatorBlock = Dominator->getBlock();
-
- assert((DominatorBlock == Dominatee->getBlock()) &&
- "Asking for local domination when accesses are in different blocks!");
- // A node dominates itself.
- if (Dominatee == Dominator)
- return true;
-
- // When Dominatee is defined on function entry, it is not dominated by another
- // memory access.
- if (isLiveOnEntryDef(Dominatee))
- return false;
-
- // When Dominator is defined on function entry, it dominates the other memory
- // access.
- if (isLiveOnEntryDef(Dominator))
- return true;
-
- if (!BlockNumberingValid.count(DominatorBlock))
- renumberBlock(DominatorBlock);
-
- unsigned long DominatorNum = BlockNumbering.lookup(Dominator);
- // All numbers start with 1
- assert(DominatorNum != 0 && "Block was not numbered properly");
- unsigned long DominateeNum = BlockNumbering.lookup(Dominatee);
- assert(DominateeNum != 0 && "Block was not numbered properly");
- return DominatorNum < DominateeNum;
-}
-
-bool MemorySSA::dominates(const MemoryAccess *Dominator,
- const MemoryAccess *Dominatee) const {
- if (Dominator == Dominatee)
- return true;
-
- if (isLiveOnEntryDef(Dominatee))
- return false;
-
- if (Dominator->getBlock() != Dominatee->getBlock())
- return DT->dominates(Dominator->getBlock(), Dominatee->getBlock());
- return locallyDominates(Dominator, Dominatee);
-}
-
-bool MemorySSA::dominates(const MemoryAccess *Dominator,
- const Use &Dominatee) const {
- if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) {
- BasicBlock *UseBB = MP->getIncomingBlock(Dominatee);
- // The def must dominate the incoming block of the phi.
- if (UseBB != Dominator->getBlock())
- return DT->dominates(Dominator->getBlock(), UseBB);
- // If the UseBB and the DefBB are the same, compare locally.
- return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee));
- }
- // If it's not a PHI node use, the normal dominates can already handle it.
- return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser()));
-}
-
-const static char LiveOnEntryStr[] = "liveOnEntry";
-
-void MemoryDef::print(raw_ostream &OS) const {
- MemoryAccess *UO = getDefiningAccess();
-
- OS << getID() << " = MemoryDef(";
- if (UO && UO->getID())
- OS << UO->getID();
- else
- OS << LiveOnEntryStr;
- OS << ')';
-}
-
-void MemoryPhi::print(raw_ostream &OS) const {
- bool First = true;
- OS << getID() << " = MemoryPhi(";
- for (const auto &Op : operands()) {
- BasicBlock *BB = getIncomingBlock(Op);
- MemoryAccess *MA = cast<MemoryAccess>(Op);
- if (!First)
- OS << ',';
- else
- First = false;
-
- OS << '{';
- if (BB->hasName())
- OS << BB->getName();
- else
- BB->printAsOperand(OS, false);
- OS << ',';
- if (unsigned ID = MA->getID())
- OS << ID;
- else
- OS << LiveOnEntryStr;
- OS << '}';
- }
- OS << ')';
-}
-
-MemoryAccess::~MemoryAccess() {}
-
-void MemoryUse::print(raw_ostream &OS) const {
- MemoryAccess *UO = getDefiningAccess();
- OS << "MemoryUse(";
- if (UO && UO->getID())
- OS << UO->getID();
- else
- OS << LiveOnEntryStr;
- OS << ')';
-}
-
-void MemoryAccess::dump() const {
- print(dbgs());
- dbgs() << "\n";
-}
-
-char MemorySSAPrinterLegacyPass::ID = 0;
-
-MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) {
- initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
-}
-
-void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.setPreservesAll();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
-}
-
-bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) {
- auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
- MSSA.print(dbgs());
- if (VerifyMemorySSA)
- MSSA.verifyMemorySSA();
- return false;
-}
-
-AnalysisKey MemorySSAAnalysis::Key;
-
-MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F,
- FunctionAnalysisManager &AM) {
- auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
- auto &AA = AM.getResult<AAManager>(F);
- return MemorySSAAnalysis::Result(make_unique<MemorySSA>(F, &AA, &DT));
-}
-
-PreservedAnalyses MemorySSAPrinterPass::run(Function &F,
- FunctionAnalysisManager &AM) {
- OS << "MemorySSA for function: " << F.getName() << "\n";
- AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS);
-
- return PreservedAnalyses::all();
-}
-
-PreservedAnalyses MemorySSAVerifierPass::run(Function &F,
- FunctionAnalysisManager &AM) {
- AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA();
-
- return PreservedAnalyses::all();
-}
-
-char MemorySSAWrapperPass::ID = 0;
-
-MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) {
- initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry());
-}
-
-void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); }
-
-void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.setPreservesAll();
- AU.addRequiredTransitive<DominatorTreeWrapperPass>();
- AU.addRequiredTransitive<AAResultsWrapperPass>();
-}
-
-bool MemorySSAWrapperPass::runOnFunction(Function &F) {
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- MSSA.reset(new MemorySSA(F, &AA, &DT));
- return false;
-}
-
-void MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); }
-
-void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const {
- MSSA->print(OS);
-}
-
-MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {}
-
-MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A,
- DominatorTree *D)
- : MemorySSAWalker(M), Walker(*M, *A, *D, Cache), AutoResetWalker(true) {}
-
-MemorySSA::CachingWalker::~CachingWalker() {}
-
-void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) {
- // TODO: We can do much better cache invalidation with differently stored
- // caches. For now, for MemoryUses, we simply remove them
- // from the cache, and kill the entire call/non-call cache for everything
- // else. The problem is for phis or defs, currently we'd need to follow use
- // chains down and invalidate anything below us in the chain that currently
- // terminates at this access.
-
- // See if this is a MemoryUse; if so, just remove the cached info. A MemoryUse
- // is by definition never a barrier, so nothing in the cache could point to
- // this use. In that case, we only need to invalidate the info for the use
- // itself.
-
- if (MemoryUse *MU = dyn_cast<MemoryUse>(MA)) {
- UpwardsMemoryQuery Q(MU->getMemoryInst(), MU);
- Cache.remove(MU, Q.StartingLoc, Q.IsCall);
- MU->resetOptimized();
- } else {
- // If it is not a use, the best we can do right now is destroy the cache.
- Cache.clear();
- }
-
-#ifdef EXPENSIVE_CHECKS
- verifyRemoved(MA);
-#endif
-}
-
-/// \brief Walk the use-def chains starting at \p MA and find
-/// the MemoryAccess that actually clobbers Loc.
-///
-/// \returns our clobbering memory access
-MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess(
- MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) {
- MemoryAccess *New = Walker.findClobber(StartingAccess, Q);
-#ifdef EXPENSIVE_CHECKS
- MemoryAccess *NewNoCache =
- Walker.findClobber(StartingAccess, Q, /*UseWalkerCache=*/false);
- assert(NewNoCache == New && "Cache made us hand back a different result?");
-#endif
- if (AutoResetWalker)
- resetClobberWalker();
- return New;
-}
-
-MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess(
- MemoryAccess *StartingAccess, const MemoryLocation &Loc) {
- if (isa<MemoryPhi>(StartingAccess))
- return StartingAccess;
-
- auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess);
- if (MSSA->isLiveOnEntryDef(StartingUseOrDef))
- return StartingUseOrDef;
-
- Instruction *I = StartingUseOrDef->getMemoryInst();
-
- // Conservatively, fences are always clobbers, so don't perform the walk if we
- // hit a fence.
- if (!ImmutableCallSite(I) && I->isFenceLike())
- return StartingUseOrDef;
-
- UpwardsMemoryQuery Q;
- Q.OriginalAccess = StartingUseOrDef;
- Q.StartingLoc = Loc;
- Q.Inst = I;
- Q.IsCall = false;
-
- if (auto *CacheResult = Cache.lookup(StartingUseOrDef, Loc, Q.IsCall))
- return CacheResult;
-
- // Unlike the other function, do not walk to the def of a def, because we are
- // handed something we already believe is the clobbering access.
- MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef)
- ? StartingUseOrDef->getDefiningAccess()
- : StartingUseOrDef;
-
- MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q);
- DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is ");
- DEBUG(dbgs() << *StartingUseOrDef << "\n");
- DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is ");
- DEBUG(dbgs() << *Clobber << "\n");
- return Clobber;
-}
-
-MemoryAccess *
-MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) {
- auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA);
- // If this is a MemoryPhi, we can't do anything.
- if (!StartingAccess)
- return MA;
-
- // If this is an already optimized use or def, return the optimized result.
- // Note: Currently, we do not store the optimized def result because we'd need
- // a separate field, since we can't use it as the defining access.
- if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess))
- if (MU->isOptimized())
- return MU->getDefiningAccess();
-
- const Instruction *I = StartingAccess->getMemoryInst();
- UpwardsMemoryQuery Q(I, StartingAccess);
- // We can't sanely do anything with fences; they conservatively
- // clobber all memory, and have no locations to get pointers from to
- // try to disambiguate.
- if (!Q.IsCall && I->isFenceLike())
- return StartingAccess;
-
- if (auto *CacheResult = Cache.lookup(StartingAccess, Q.StartingLoc, Q.IsCall))
- return CacheResult;
-
- if (isUseTriviallyOptimizableToLiveOnEntry(*MSSA->AA, I)) {
- MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef();
- Cache.insert(StartingAccess, LiveOnEntry, Q.StartingLoc, Q.IsCall);
- if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess))
- MU->setDefiningAccess(LiveOnEntry, true);
- return LiveOnEntry;
- }
-
- // Start with the thing we already think clobbers this location
- MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess();
-
- // At this point, DefiningAccess may be the live on entry def.
- // If it is, we will not get a better result.
- if (MSSA->isLiveOnEntryDef(DefiningAccess))
- return DefiningAccess;
-
- MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q);
- DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is ");
- DEBUG(dbgs() << *DefiningAccess << "\n");
- DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is ");
- DEBUG(dbgs() << *Result << "\n");
- if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess))
- MU->setDefiningAccess(Result, true);
-
- return Result;
-}
-
-// Verify that MA doesn't exist in any of the caches.
-void MemorySSA::CachingWalker::verifyRemoved(MemoryAccess *MA) {
- assert(!Cache.contains(MA) && "Found removed MemoryAccess in cache.");
-}
-
-MemoryAccess *
-DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) {
- if (auto *Use = dyn_cast<MemoryUseOrDef>(MA))
- return Use->getDefiningAccess();
- return MA;
-}
-
-MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess(
- MemoryAccess *StartingAccess, const MemoryLocation &) {
- if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess))
- return Use->getDefiningAccess();
- return StartingAccess;
-}
-} // namespace llvm
diff --git a/lib/Transforms/Utils/MetaRenamer.cpp b/lib/Transforms/Utils/MetaRenamer.cpp
index c999bd008fefd..481c6aa29c3a1 100644
--- a/lib/Transforms/Utils/MetaRenamer.cpp
+++ b/lib/Transforms/Utils/MetaRenamer.cpp
@@ -16,6 +16,7 @@
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
@@ -67,6 +68,7 @@ namespace {
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.setPreservesAll();
}
@@ -110,9 +112,15 @@ namespace {
}
// Rename all functions
+ const TargetLibraryInfo &TLI =
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
for (auto &F : M) {
StringRef Name = F.getName();
- if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1))
+ LibFunc Tmp;
+ // Leave library functions alone because their presence or absence could
+ // affect the behavior of other passes.
+ if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) ||
+ TLI.getLibFunc(F, Tmp))
continue;
F.setName(renamer.newName());
@@ -139,8 +147,11 @@ namespace {
}
char MetaRenamer::ID = 0;
-INITIALIZE_PASS(MetaRenamer, "metarenamer",
- "Assign new names to everything", false, false)
+INITIALIZE_PASS_BEGIN(MetaRenamer, "metarenamer",
+ "Assign new names to everything", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(MetaRenamer, "metarenamer",
+ "Assign new names to everything", false, false)
//===----------------------------------------------------------------------===//
//
// MetaRenamer - Rename everything with metasyntactic names.
diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp
index 0d623df77a67e..dbe42c201dd4f 100644
--- a/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/lib/Transforms/Utils/ModuleUtils.cpp
@@ -130,13 +130,25 @@ void llvm::appendToCompilerUsed(Module &M, ArrayRef<GlobalValue *> Values) {
Function *llvm::checkSanitizerInterfaceFunction(Constant *FuncOrBitcast) {
if (isa<Function>(FuncOrBitcast))
return cast<Function>(FuncOrBitcast);
- FuncOrBitcast->dump();
+ FuncOrBitcast->print(errs());
+ errs() << '\n';
std::string Err;
raw_string_ostream Stream(Err);
Stream << "Sanitizer interface function redefined: " << *FuncOrBitcast;
report_fatal_error(Err);
}
+Function *llvm::declareSanitizerInitFunction(Module &M, StringRef InitName,
+ ArrayRef<Type *> InitArgTypes) {
+ assert(!InitName.empty() && "Expected init function name");
+ Function *F = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+ InitName,
+ FunctionType::get(Type::getVoidTy(M.getContext()), InitArgTypes, false),
+ AttributeList()));
+ F->setLinkage(Function::ExternalLinkage);
+ return F;
+}
+
std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions(
Module &M, StringRef CtorName, StringRef InitName,
ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs,
@@ -144,22 +156,19 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions(
assert(!InitName.empty() && "Expected init function name");
assert(InitArgs.size() == InitArgTypes.size() &&
"Sanitizer's init function expects different number of arguments");
+ Function *InitFunction =
+ declareSanitizerInitFunction(M, InitName, InitArgTypes);
Function *Ctor = Function::Create(
FunctionType::get(Type::getVoidTy(M.getContext()), false),
GlobalValue::InternalLinkage, CtorName, &M);
BasicBlock *CtorBB = BasicBlock::Create(M.getContext(), "", Ctor);
IRBuilder<> IRB(ReturnInst::Create(M.getContext(), CtorBB));
- Function *InitFunction =
- checkSanitizerInterfaceFunction(M.getOrInsertFunction(
- InitName, FunctionType::get(IRB.getVoidTy(), InitArgTypes, false),
- AttributeSet()));
- InitFunction->setLinkage(Function::ExternalLinkage);
IRB.CreateCall(InitFunction, InitArgs);
if (!VersionCheckName.empty()) {
Function *VersionCheckFunction =
checkSanitizerInterfaceFunction(M.getOrInsertFunction(
VersionCheckName, FunctionType::get(IRB.getVoidTy(), {}, false),
- AttributeSet()));
+ AttributeList()));
IRB.CreateCall(VersionCheckFunction, {});
}
return std::make_pair(Ctor, InitFunction);
diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp
new file mode 100644
index 0000000000000..8877aeafecdec
--- /dev/null
+++ b/lib/Transforms/Utils/PredicateInfo.cpp
@@ -0,0 +1,782 @@
+//===-- PredicateInfo.cpp - PredicateInfo Builder--------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------===//
+//
+// This file implements the PredicateInfo class.
+//
+//===----------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/PredicateInfo.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/OrderedBasicBlock.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugCounter.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Transforms/Scalar.h"
+#include <algorithm>
+#define DEBUG_TYPE "predicateinfo"
+using namespace llvm;
+using namespace PatternMatch;
+using namespace llvm::PredicateInfoClasses;
+
+INITIALIZE_PASS_BEGIN(PredicateInfoPrinterLegacyPass, "print-predicateinfo",
+ "PredicateInfo Printer", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_END(PredicateInfoPrinterLegacyPass, "print-predicateinfo",
+ "PredicateInfo Printer", false, false)
+static cl::opt<bool> VerifyPredicateInfo(
+ "verify-predicateinfo", cl::init(false), cl::Hidden,
+ cl::desc("Verify PredicateInfo in legacy printer pass."));
+namespace {
+DEBUG_COUNTER(RenameCounter, "predicateinfo-rename",
+ "Controls which variables are renamed with predicateinfo")
+// Given a predicate info that is a type of branching terminator, get the
+// branching block.
+const BasicBlock *getBranchBlock(const PredicateBase *PB) {
+ assert(isa<PredicateWithEdge>(PB) &&
+ "Only branches and switches should have PHIOnly defs that "
+ "require branch blocks.");
+ return cast<PredicateWithEdge>(PB)->From;
+}
+
+// Given a predicate info that is a type of branching terminator, get the
+// branching terminator.
+static Instruction *getBranchTerminator(const PredicateBase *PB) {
+ assert(isa<PredicateWithEdge>(PB) &&
+ "Not a predicate info type we know how to get a terminator from.");
+ return cast<PredicateWithEdge>(PB)->From->getTerminator();
+}
+
+// Given a predicate info that is a type of branching terminator, get the
+// edge this predicate info represents
+const std::pair<BasicBlock *, BasicBlock *>
+getBlockEdge(const PredicateBase *PB) {
+ assert(isa<PredicateWithEdge>(PB) &&
+ "Not a predicate info type we know how to get an edge from.");
+ const auto *PEdge = cast<PredicateWithEdge>(PB);
+ return std::make_pair(PEdge->From, PEdge->To);
+}
+}
+
+namespace llvm {
+namespace PredicateInfoClasses {
+enum LocalNum {
+ // Operations that must appear first in the block.
+ LN_First,
+ // Operations that are somewhere in the middle of the block, and are sorted on
+ // demand.
+ LN_Middle,
+ // Operations that must appear last in a block, like successor phi node uses.
+ LN_Last
+};
+
+// Associate global and local DFS info with defs and uses, so we can sort them
+// into a global domination ordering.
+struct ValueDFS {
+ int DFSIn = 0;
+ int DFSOut = 0;
+ unsigned int LocalNum = LN_Middle;
+ // Only one of Def or Use will be set.
+ Value *Def = nullptr;
+ Use *U = nullptr;
+ // Neither PInfo nor EdgeOnly participate in the ordering
+ PredicateBase *PInfo = nullptr;
+ bool EdgeOnly = false;
+};
+
+// This compares ValueDFS structures, creating OrderedBasicBlocks where
+// necessary to compare uses/defs in the same block. Doing so allows us to walk
+// the minimum number of instructions necessary to compute our def/use ordering.
+struct ValueDFS_Compare {
+ DenseMap<const BasicBlock *, std::unique_ptr<OrderedBasicBlock>> &OBBMap;
+ ValueDFS_Compare(
+ DenseMap<const BasicBlock *, std::unique_ptr<OrderedBasicBlock>> &OBBMap)
+ : OBBMap(OBBMap) {}
+ bool operator()(const ValueDFS &A, const ValueDFS &B) const {
+ if (&A == &B)
+ return false;
+ // The only case we can't directly compare them is when they are in the same
+ // block, and both have localnum == middle. In that case, we have to use
+ // comesbefore to see what the real ordering is, because they are in the
+ // same basic block.
+
+ bool SameBlock = std::tie(A.DFSIn, A.DFSOut) == std::tie(B.DFSIn, B.DFSOut);
+
+ // We want to put the def that will get used for a given set of phi uses
+ // before those phi uses.
+ // So we sort by edge, then by def.
+ // Note that only phi node uses and defs can come last.
+ if (SameBlock && A.LocalNum == LN_Last && B.LocalNum == LN_Last)
+ return comparePHIRelated(A, B);
+
+ if (!SameBlock || A.LocalNum != LN_Middle || B.LocalNum != LN_Middle)
+ return std::tie(A.DFSIn, A.DFSOut, A.LocalNum, A.Def, A.U) <
+ std::tie(B.DFSIn, B.DFSOut, B.LocalNum, B.Def, B.U);
+ return localComesBefore(A, B);
+ }
+
+ // For a phi use, or a non-materialized def, return the edge it represents.
+ const std::pair<BasicBlock *, BasicBlock *>
+ getBlockEdge(const ValueDFS &VD) const {
+ if (!VD.Def && VD.U) {
+ auto *PHI = cast<PHINode>(VD.U->getUser());
+ return std::make_pair(PHI->getIncomingBlock(*VD.U), PHI->getParent());
+ }
+ // This is really a non-materialized def.
+ return ::getBlockEdge(VD.PInfo);
+ }
+
+ // For two phi related values, return the ordering.
+ bool comparePHIRelated(const ValueDFS &A, const ValueDFS &B) const {
+ auto &ABlockEdge = getBlockEdge(A);
+ auto &BBlockEdge = getBlockEdge(B);
+ // Now sort by block edge and then defs before uses.
+ return std::tie(ABlockEdge, A.Def, A.U) < std::tie(BBlockEdge, B.Def, B.U);
+ }
+
+ // Get the definition of an instruction that occurs in the middle of a block.
+ Value *getMiddleDef(const ValueDFS &VD) const {
+ if (VD.Def)
+ return VD.Def;
+ // It's possible for the defs and uses to be null. For branches, the local
+ // numbering will say the placed predicateinfos should go first (i.e.
+ // LN_First), so we won't be in this function. For assumes, we will end
+ // up here, because we need to order the def we will place relative to the
+ // assume. So for the purpose of ordering, we pretend the def is the assume
+ // because that is where we will insert the info.
+ if (!VD.U) {
+ assert(VD.PInfo &&
+ "No def, no use, and no predicateinfo should not occur");
+ assert(isa<PredicateAssume>(VD.PInfo) &&
+ "Middle of block should only occur for assumes");
+ return cast<PredicateAssume>(VD.PInfo)->AssumeInst;
+ }
+ return nullptr;
+ }
+
+ // Return either the Def, if it's not null, or the user of the Use, if the def
+ // is null.
+ const Instruction *getDefOrUser(const Value *Def, const Use *U) const {
+ if (Def)
+ return cast<Instruction>(Def);
+ return cast<Instruction>(U->getUser());
+ }
+
+ // This performs the necessary local basic block ordering checks to tell
+ // whether A comes before B, where both are in the same basic block.
+ bool localComesBefore(const ValueDFS &A, const ValueDFS &B) const {
+ auto *ADef = getMiddleDef(A);
+ auto *BDef = getMiddleDef(B);
+
+ // See if we have real values or uses. If we have real values, we are
+ // guaranteed they are instructions or arguments. No matter what, we are
+ // guaranteed they are in the same block if they are instructions.
+ auto *ArgA = dyn_cast_or_null<Argument>(ADef);
+ auto *ArgB = dyn_cast_or_null<Argument>(BDef);
+
+ if (ArgA && !ArgB)
+ return true;
+ if (ArgB && !ArgA)
+ return false;
+ if (ArgA && ArgB)
+ return ArgA->getArgNo() < ArgB->getArgNo();
+
+ auto *AInst = getDefOrUser(ADef, A.U);
+ auto *BInst = getDefOrUser(BDef, B.U);
+
+ auto *BB = AInst->getParent();
+ auto LookupResult = OBBMap.find(BB);
+ if (LookupResult != OBBMap.end())
+ return LookupResult->second->dominates(AInst, BInst);
+
+ auto Result = OBBMap.insert({BB, make_unique<OrderedBasicBlock>(BB)});
+ return Result.first->second->dominates(AInst, BInst);
+ }
+};
+
+} // namespace PredicateInfoClasses
+
+bool PredicateInfo::stackIsInScope(const ValueDFSStack &Stack,
+ const ValueDFS &VDUse) const {
+ if (Stack.empty())
+ return false;
+ // If it's a phi only use, make sure it's for this phi node edge, and that the
+ // use is in a phi node. If it's anything else, and the top of the stack is
+ // EdgeOnly, we need to pop the stack. We deliberately sort phi uses next to
+ // the defs they must go with so that we can know it's time to pop the stack
+ // when we hit the end of the phi uses for a given def.
+ if (Stack.back().EdgeOnly) {
+ if (!VDUse.U)
+ return false;
+ auto *PHI = dyn_cast<PHINode>(VDUse.U->getUser());
+ if (!PHI)
+ return false;
+ // Check edge
+ BasicBlock *EdgePred = PHI->getIncomingBlock(*VDUse.U);
+ if (EdgePred != getBranchBlock(Stack.back().PInfo))
+ return false;
+
+ // Use dominates, which knows how to handle edge dominance.
+ return DT.dominates(getBlockEdge(Stack.back().PInfo), *VDUse.U);
+ }
+
+ return (VDUse.DFSIn >= Stack.back().DFSIn &&
+ VDUse.DFSOut <= Stack.back().DFSOut);
+}
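+//
+// Illustrative sketch (hypothetical IR, not produced here): with a branch
+// predicate on the edge (%bb -> %succ) and
+//   succ:
+//     %p = phi i32 [ %x, %bb ], [ %x, %other ]
+// only the [ %x, %bb ] operand is in scope for that predicate; the operand
+// coming from %other is not, because the edge does not dominate that use.
+//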
+
+void PredicateInfo::popStackUntilDFSScope(ValueDFSStack &Stack,
+ const ValueDFS &VD) {
+ while (!Stack.empty() && !stackIsInScope(Stack, VD))
+ Stack.pop_back();
+}
+
+// Convert the uses of Op into a vector of uses, associating global and local
+// DFS info with each one.
+void PredicateInfo::convertUsesToDFSOrdered(
+ Value *Op, SmallVectorImpl<ValueDFS> &DFSOrderedSet) {
+ for (auto &U : Op->uses()) {
+ if (auto *I = dyn_cast<Instruction>(U.getUser())) {
+ ValueDFS VD;
+ // Put the phi node uses in the incoming block.
+ BasicBlock *IBlock;
+ if (auto *PN = dyn_cast<PHINode>(I)) {
+ IBlock = PN->getIncomingBlock(U);
+ // Make phi node users appear last in the incoming block
+ // they are from.
+ VD.LocalNum = LN_Last;
+ } else {
+ // If it's not a phi node use, it is somewhere in the middle of the
+ // block.
+ IBlock = I->getParent();
+ VD.LocalNum = LN_Middle;
+ }
+ DomTreeNode *DomNode = DT.getNode(IBlock);
+ // It's possible our use is in an unreachable block. Skip it if so.
+ if (!DomNode)
+ continue;
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.U = &U;
+ DFSOrderedSet.push_back(VD);
+ }
+ }
+}
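+//
+// Illustrative sketch (hypothetical CFG): given
+//   bb1:
+//     br label %merge
+//   merge:
+//     %m = phi i32 [ %x, %bb1 ], [ %y, %bb2 ]
+// the use of %x in %m is recorded with bb1's DFS numbers and LN_Last, so it
+// sorts after every ordinary def and use inside bb1 during renaming.
+//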
+
+// Collect relevant operations from Comparison that we may want to insert copies
+// for.
+void collectCmpOps(CmpInst *Comparison, SmallVectorImpl<Value *> &CmpOperands) {
+ auto *Op0 = Comparison->getOperand(0);
+ auto *Op1 = Comparison->getOperand(1);
+ if (Op0 == Op1)
+ return;
+ CmpOperands.push_back(Comparison);
+ // Only want real values, not constants. Additionally, operands with one use
+ // are only being used in the comparison, which means they will not be useful
+ // for us to consider for predicateinfo.
+ //
+ if ((isa<Instruction>(Op0) || isa<Argument>(Op0)) && !Op0->hasOneUse())
+ CmpOperands.push_back(Op0);
+ if ((isa<Instruction>(Op1) || isa<Argument>(Op1)) && !Op1->hasOneUse())
+ CmpOperands.push_back(Op1);
+}
+
+// Add Op, PB to the list of value infos for Op, and mark Op to be renamed.
+void PredicateInfo::addInfoFor(SmallPtrSetImpl<Value *> &OpsToRename, Value *Op,
+ PredicateBase *PB) {
+ OpsToRename.insert(Op);
+ auto &OperandInfo = getOrCreateValueInfo(Op);
+ AllInfos.push_back(PB);
+ OperandInfo.Infos.push_back(PB);
+}
+
+// Process an assume instruction and place relevant operations we want to rename
+// into OpsToRename.
+void PredicateInfo::processAssume(IntrinsicInst *II, BasicBlock *AssumeBB,
+ SmallPtrSetImpl<Value *> &OpsToRename) {
+ // See if we have a comparison we support
+ SmallVector<Value *, 8> CmpOperands;
+ SmallVector<Value *, 2> ConditionsToProcess;
+ CmpInst::Predicate Pred;
+ Value *Operand = II->getOperand(0);
+ if (m_c_And(m_Cmp(Pred, m_Value(), m_Value()),
+ m_Cmp(Pred, m_Value(), m_Value()))
+ .match(II->getOperand(0))) {
+ ConditionsToProcess.push_back(cast<BinaryOperator>(Operand)->getOperand(0));
+ ConditionsToProcess.push_back(cast<BinaryOperator>(Operand)->getOperand(1));
+ ConditionsToProcess.push_back(Operand);
+ } else if (isa<CmpInst>(Operand)) {
+
+ ConditionsToProcess.push_back(Operand);
+ }
+ for (auto Cond : ConditionsToProcess) {
+ if (auto *Cmp = dyn_cast<CmpInst>(Cond)) {
+ collectCmpOps(Cmp, CmpOperands);
+ // Now add our copy infos for our operands
+ for (auto *Op : CmpOperands) {
+ auto *PA = new PredicateAssume(Op, II, Cmp);
+ addInfoFor(OpsToRename, Op, PA);
+ }
+ CmpOperands.clear();
+ } else if (auto *BinOp = dyn_cast<BinaryOperator>(Cond)) {
+ // Otherwise, it should be an AND.
+ assert(BinOp->getOpcode() == Instruction::And &&
+ "Should have been an AND");
+ auto *PA = new PredicateAssume(BinOp, II, BinOp);
+ addInfoFor(OpsToRename, BinOp, PA);
+ } else {
+ llvm_unreachable("Unknown type of condition");
+ }
+ }
+}
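+//
+// Illustrative sketch (hypothetical IR; the copy is only materialized later,
+// in materializeStack, once a dominated use of %x is seen):
+//   %c = icmp eq i32 %x, 0
+//   call void @llvm.assume(i1 %c)
+// roughly becomes
+//   %c = icmp eq i32 %x, 0
+//   %x.0 = call i32 @llvm.ssa.copy.i32(i32 %x)
+//   call void @llvm.assume(i1 %c)
+// with uses of %x dominated by the assume rewritten to %x.0.
+//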
+
+// Process a block terminating branch, and place relevant operations to be
+// renamed into OpsToRename.
+void PredicateInfo::processBranch(BranchInst *BI, BasicBlock *BranchBB,
+ SmallPtrSetImpl<Value *> &OpsToRename) {
+ BasicBlock *FirstBB = BI->getSuccessor(0);
+ BasicBlock *SecondBB = BI->getSuccessor(1);
+ SmallVector<BasicBlock *, 2> SuccsToProcess;
+ SuccsToProcess.push_back(FirstBB);
+ SuccsToProcess.push_back(SecondBB);
+ SmallVector<Value *, 2> ConditionsToProcess;
+
+ auto InsertHelper = [&](Value *Op, bool isAnd, bool isOr, Value *Cond) {
+ for (auto *Succ : SuccsToProcess) {
+ // Don't try to insert on a self-edge. This is mainly because we will
+ // eliminate it during renaming anyway.
+ if (Succ == BranchBB)
+ continue;
+ bool TakenEdge = (Succ == FirstBB);
+ // For and, only insert on the true edge
+ // For or, only insert on the false edge
+ if ((isAnd && !TakenEdge) || (isOr && TakenEdge))
+ continue;
+ PredicateBase *PB =
+ new PredicateBranch(Op, BranchBB, Succ, Cond, TakenEdge);
+ addInfoFor(OpsToRename, Op, PB);
+ if (!Succ->getSinglePredecessor())
+ EdgeUsesOnly.insert({BranchBB, Succ});
+ }
+ };
+
+ // Match combinations of conditions.
+ CmpInst::Predicate Pred;
+ bool isAnd = false;
+ bool isOr = false;
+ SmallVector<Value *, 8> CmpOperands;
+ if (match(BI->getCondition(), m_And(m_Cmp(Pred, m_Value(), m_Value()),
+ m_Cmp(Pred, m_Value(), m_Value()))) ||
+ match(BI->getCondition(), m_Or(m_Cmp(Pred, m_Value(), m_Value()),
+ m_Cmp(Pred, m_Value(), m_Value())))) {
+ auto *BinOp = cast<BinaryOperator>(BI->getCondition());
+ if (BinOp->getOpcode() == Instruction::And)
+ isAnd = true;
+ else if (BinOp->getOpcode() == Instruction::Or)
+ isOr = true;
+ ConditionsToProcess.push_back(BinOp->getOperand(0));
+ ConditionsToProcess.push_back(BinOp->getOperand(1));
+ ConditionsToProcess.push_back(BI->getCondition());
+ } else if (isa<CmpInst>(BI->getCondition())) {
+ ConditionsToProcess.push_back(BI->getCondition());
+ }
+ for (auto Cond : ConditionsToProcess) {
+ if (auto *Cmp = dyn_cast<CmpInst>(Cond)) {
+ collectCmpOps(Cmp, CmpOperands);
+ // Now add our copy infos for our operands
+ for (auto *Op : CmpOperands)
+ InsertHelper(Op, isAnd, isOr, Cmp);
+ } else if (auto *BinOp = dyn_cast<BinaryOperator>(Cond)) {
+ // This must be an AND or an OR.
+ assert((BinOp->getOpcode() == Instruction::And ||
+ BinOp->getOpcode() == Instruction::Or) &&
+ "Should have been an AND or an OR");
+ // The actual value of the binop is not subject to the same restrictions
+ // as the comparison. It's either true or false on the true/false branch.
+ InsertHelper(BinOp, false, false, BinOp);
+ } else {
+ llvm_unreachable("Unknown type of condition");
+ }
+ CmpOperands.clear();
+ }
+}
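+//
+// Illustrative sketch (hypothetical IR): for
+//   %c = and i1 %ca, %cb
+//   br i1 %c, label %t, label %f
+// where %ca and %cb are comparisons, copies for the compared operands are
+// inserted only on the true edge (only on the false edge for an or), while
+// %c itself gets a copy on both edges, since its value is known on each.
+//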
+// Process a block terminating switch, and place relevant operations to be
+// renamed into OpsToRename.
+void PredicateInfo::processSwitch(SwitchInst *SI, BasicBlock *BranchBB,
+ SmallPtrSetImpl<Value *> &OpsToRename) {
+ Value *Op = SI->getCondition();
+ if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse())
+ return;
+
+ // Remember how many outgoing edges there are to every successor.
+ SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges;
+ for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) {
+ BasicBlock *TargetBlock = SI->getSuccessor(i);
+ ++SwitchEdges[TargetBlock];
+ }
+
+ // Now propagate info for each case value
+ for (auto C : SI->cases()) {
+ BasicBlock *TargetBlock = C.getCaseSuccessor();
+ if (SwitchEdges.lookup(TargetBlock) == 1) {
+ PredicateSwitch *PS = new PredicateSwitch(
+ Op, SI->getParent(), TargetBlock, C.getCaseValue(), SI);
+ addInfoFor(OpsToRename, Op, PS);
+ if (!TargetBlock->getSinglePredecessor())
+ EdgeUsesOnly.insert({BranchBB, TargetBlock});
+ }
+ }
+}
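+//
+// Illustrative sketch (hypothetical IR): for
+//   switch i32 %x, label %def [ i32 1, label %one
+//                               i32 2, label %two ]
+// a PredicateSwitch recording %x == 1 is attached on the edge to %one (and
+// likewise for %two), but only when exactly one switch edge targets that
+// block; otherwise the case value would not be known there.
+//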
+
+// Build predicate info for our function
+void PredicateInfo::buildPredicateInfo() {
+ DT.updateDFSNumbers();
+ // Collect operands to rename from all conditional branch terminators, as well
+ // as assume statements.
+ SmallPtrSet<Value *, 8> OpsToRename;
+ for (auto DTN : depth_first(DT.getRootNode())) {
+ BasicBlock *BranchBB = DTN->getBlock();
+ if (auto *BI = dyn_cast<BranchInst>(BranchBB->getTerminator())) {
+ if (!BI->isConditional())
+ continue;
+ processBranch(BI, BranchBB, OpsToRename);
+ } else if (auto *SI = dyn_cast<SwitchInst>(BranchBB->getTerminator())) {
+ processSwitch(SI, BranchBB, OpsToRename);
+ }
+ }
+ for (auto &Assume : AC.assumptions()) {
+ if (auto *II = dyn_cast_or_null<IntrinsicInst>(Assume))
+ processAssume(II, II->getParent(), OpsToRename);
+ }
+ // Now rename all our operations.
+ renameUses(OpsToRename);
+}
+
+// Given the renaming stack, make all the operands currently on the stack real
+// by inserting them into the IR. Return the last operation's value.
+Value *PredicateInfo::materializeStack(unsigned int &Counter,
+ ValueDFSStack &RenameStack,
+ Value *OrigOp) {
+ // Find the first thing we have to materialize
+ auto RevIter = RenameStack.rbegin();
+ for (; RevIter != RenameStack.rend(); ++RevIter)
+ if (RevIter->Def)
+ break;
+
+ size_t Start = RevIter - RenameStack.rbegin();
+ // The maximum number of things we should be trying to materialize at once
+ // right now is 4, depending on whether we had an assume, a branch, and whether
+ // both used an and of conditions.
+ for (auto RenameIter = RenameStack.end() - Start;
+ RenameIter != RenameStack.end(); ++RenameIter) {
+ auto *Op =
+ RenameIter == RenameStack.begin() ? OrigOp : (RenameIter - 1)->Def;
+ ValueDFS &Result = *RenameIter;
+ auto *ValInfo = Result.PInfo;
+ // For edge predicates, we can just place the operand in the block before
+ // the terminator. For assume, we have to place it right before the assume
+ // to ensure we dominate all of our uses. Always insert right before the
+ // relevant instruction (terminator, assume), so that we insert in proper
+ // order in the case of multiple predicateinfo in the same block.
+ if (isa<PredicateWithEdge>(ValInfo)) {
+ IRBuilder<> B(getBranchTerminator(ValInfo));
+ Function *IF = Intrinsic::getDeclaration(
+ F.getParent(), Intrinsic::ssa_copy, Op->getType());
+ CallInst *PIC =
+ B.CreateCall(IF, Op, Op->getName() + "." + Twine(Counter++));
+ PredicateMap.insert({PIC, ValInfo});
+ Result.Def = PIC;
+ } else {
+ auto *PAssume = dyn_cast<PredicateAssume>(ValInfo);
+ assert(PAssume &&
+ "Should not have gotten here without it being an assume");
+ IRBuilder<> B(PAssume->AssumeInst);
+ Function *IF = Intrinsic::getDeclaration(
+ F.getParent(), Intrinsic::ssa_copy, Op->getType());
+ CallInst *PIC = B.CreateCall(IF, Op);
+ PredicateMap.insert({PIC, ValInfo});
+ Result.Def = PIC;
+ }
+ }
+ return RenameStack.back().Def;
+}
+
+// Instead of the standard SSA renaming algorithm, which is O(Number of
+// instructions), and walks the entire dominator tree, we walk only the defs +
+// uses. The standard SSA renaming algorithm does not really rely on the
+// dominator tree except to order the stack push/pops of the renaming stacks, so
+// that defs end up getting pushed before hitting the correct uses. This does
+// not require the dominator tree, only the *order* of the dominator tree. The
+// complete and correct ordering of the defs and uses, in dominator tree order, is
+// contained in the DFS numbering of the dominator tree. So we sort the defs and
+// uses into the DFS ordering, and then just use the renaming stack as per
+// normal, pushing when we hit a def (which is a predicateinfo instruction),
+// popping when we are out of the dfs scope for that def, and replacing any uses
+// with top of stack if it exists. In order to handle liveness without
+// propagating liveness info, we don't actually insert the predicateinfo
+// instruction def until we see a use that it would dominate. Once we see such
+// a use, we materialize the predicateinfo instruction in the right place and
+// use it.
+//
+// TODO: Use this algorithm to perform fast single-variable renaming in
+// promotememtoreg and memoryssa.
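+//
+// Minimal sketch of the stack discipline (hypothetical DFS numbers): with
+// entries sorted as (DFSIn, DFSOut) = (1,10) def, (2,5) use, (6,9) use, the
+// def is pushed at (1,10); both uses fall inside its scope and take the top
+// of the stack; a later entry with DFSIn greater than 10 first pops the def.
+//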
+void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpsToRename) {
+ ValueDFS_Compare Compare(OBBMap);
+ // Compute liveness, and rename in O(uses) per Op.
+ for (auto *Op : OpsToRename) {
+ unsigned Counter = 0;
+ SmallVector<ValueDFS, 16> OrderedUses;
+ const auto &ValueInfo = getValueInfo(Op);
+ // Insert the possible copies into the def/use list.
+ // They will become real copies if we find a real use for them, and are
+ // never created otherwise.
+ for (auto &PossibleCopy : ValueInfo.Infos) {
+ ValueDFS VD;
+ // Determine where we are going to place the copy by the copy type.
+ // The predicate info for branches always comes first; it will get
+ // materialized in the split block at the top of the block.
+ // The predicate info for assumes will be somewhere in the middle;
+ // it will get materialized in front of the assume.
+ if (const auto *PAssume = dyn_cast<PredicateAssume>(PossibleCopy)) {
+ VD.LocalNum = LN_Middle;
+ DomTreeNode *DomNode = DT.getNode(PAssume->AssumeInst->getParent());
+ if (!DomNode)
+ continue;
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.PInfo = PossibleCopy;
+ OrderedUses.push_back(VD);
+ } else if (isa<PredicateWithEdge>(PossibleCopy)) {
+ // If we can only do phi uses, we treat it like it's in the branch
+ // block, and handle it specially. We know that it goes last, and only
+ // dominates phi uses.
+ auto BlockEdge = getBlockEdge(PossibleCopy);
+ if (EdgeUsesOnly.count(BlockEdge)) {
+ VD.LocalNum = LN_Last;
+ auto *DomNode = DT.getNode(BlockEdge.first);
+ if (DomNode) {
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.PInfo = PossibleCopy;
+ VD.EdgeOnly = true;
+ OrderedUses.push_back(VD);
+ }
+ } else {
+ // Otherwise, we are in the split block (even though we perform
+ // insertion in the branch block).
+ // Insert a possible copy at the split block and before the branch.
+ VD.LocalNum = LN_First;
+ auto *DomNode = DT.getNode(BlockEdge.second);
+ if (DomNode) {
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.PInfo = PossibleCopy;
+ OrderedUses.push_back(VD);
+ }
+ }
+ }
+ }
+
+ convertUsesToDFSOrdered(Op, OrderedUses);
+ std::sort(OrderedUses.begin(), OrderedUses.end(), Compare);
+ SmallVector<ValueDFS, 8> RenameStack;
+ // For each use, sorted into DFS order, push values and replace uses with the
+ // top of the stack, which will represent the reaching def.
+ for (auto &VD : OrderedUses) {
+ // We currently do not materialize copy over copy, but we should decide if
+ // we want to.
+ bool PossibleCopy = VD.PInfo != nullptr;
+ if (RenameStack.empty()) {
+ DEBUG(dbgs() << "Rename Stack is empty\n");
+ } else {
+ DEBUG(dbgs() << "Rename Stack Top DFS numbers are ("
+ << RenameStack.back().DFSIn << ","
+ << RenameStack.back().DFSOut << ")\n");
+ }
+
+ DEBUG(dbgs() << "Current DFS numbers are (" << VD.DFSIn << ","
+ << VD.DFSOut << ")\n");
+
+ bool ShouldPush = (VD.Def || PossibleCopy);
+ bool OutOfScope = !stackIsInScope(RenameStack, VD);
+ if (OutOfScope || ShouldPush) {
+ // Sync to our current scope.
+ popStackUntilDFSScope(RenameStack, VD);
+ if (ShouldPush) {
+ RenameStack.push_back(VD);
+ }
+ }
+ // If we get to this point and the stack is empty, we must have a use
+ // with no renaming needed; just skip it.
+ if (RenameStack.empty())
+ continue;
+ // Skip values; we only want to rename the uses.
+ if (VD.Def || PossibleCopy)
+ continue;
+ if (!DebugCounter::shouldExecute(RenameCounter)) {
+ DEBUG(dbgs() << "Skipping execution due to debug counter\n");
+ continue;
+ }
+ ValueDFS &Result = RenameStack.back();
+
+ // If the possible copy dominates something, materialize our stack up to
+ // this point. This ensures every comparison that affects our operation
+ // ends up with predicateinfo.
+ if (!Result.Def)
+ Result.Def = materializeStack(Counter, RenameStack, Op);
+
+ DEBUG(dbgs() << "Found replacement " << *Result.Def << " for "
+ << *VD.U->get() << " in " << *(VD.U->getUser()) << "\n");
+ assert(DT.dominates(cast<Instruction>(Result.Def), *VD.U) &&
+ "Predicateinfo def should have dominated this use");
+ VD.U->set(Result.Def);
+ }
+ }
+}
+
+PredicateInfo::ValueInfo &PredicateInfo::getOrCreateValueInfo(Value *Operand) {
+ auto OIN = ValueInfoNums.find(Operand);
+ if (OIN == ValueInfoNums.end()) {
+ // This will grow it
+ ValueInfos.resize(ValueInfos.size() + 1);
+ // This will use the new size and give us a 0-based number for the info.
+ auto InsertResult = ValueInfoNums.insert({Operand, ValueInfos.size() - 1});
+ assert(InsertResult.second && "Value info number already existed?");
+ return ValueInfos[InsertResult.first->second];
+ }
+ return ValueInfos[OIN->second];
+}
+
+const PredicateInfo::ValueInfo &
+PredicateInfo::getValueInfo(Value *Operand) const {
+ auto OINI = ValueInfoNums.lookup(Operand);
+ assert(OINI != 0 && "Operand was not really in the Value Info Numbers");
+ assert(OINI < ValueInfos.size() &&
+ "Value Info Number greater than size of Value Info Table");
+ return ValueInfos[OINI];
+}
+
+PredicateInfo::PredicateInfo(Function &F, DominatorTree &DT,
+ AssumptionCache &AC)
+ : F(F), DT(DT), AC(AC) {
+ // Push an empty operand info so that we can detect 0 as not finding one
+ ValueInfos.resize(1);
+ buildPredicateInfo();
+}
+
+PredicateInfo::~PredicateInfo() {}
+
+void PredicateInfo::verifyPredicateInfo() const {}
+
+char PredicateInfoPrinterLegacyPass::ID = 0;
+
+PredicateInfoPrinterLegacyPass::PredicateInfoPrinterLegacyPass()
+ : FunctionPass(ID) {
+ initializePredicateInfoPrinterLegacyPassPass(
+ *PassRegistry::getPassRegistry());
+}
+
+void PredicateInfoPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<DominatorTreeWrapperPass>();
+ AU.addRequired<AssumptionCacheTracker>();
+}
+
+bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) {
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ auto PredInfo = make_unique<PredicateInfo>(F, DT, AC);
+ PredInfo->print(dbgs());
+ if (VerifyPredicateInfo)
+ PredInfo->verifyPredicateInfo();
+ return false;
+}
+
+PreservedAnalyses PredicateInfoPrinterPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &AC = AM.getResult<AssumptionAnalysis>(F);
+ OS << "PredicateInfo for function: " << F.getName() << "\n";
+ make_unique<PredicateInfo>(F, DT, AC)->print(OS);
+
+ return PreservedAnalyses::all();
+}
+
+/// \brief An assembly annotator class to print PredicateInfo information in
+/// comments.
+class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
+ friend class PredicateInfo;
+ const PredicateInfo *PredInfo;
+
+public:
+ PredicateInfoAnnotatedWriter(const PredicateInfo *M) : PredInfo(M) {}
+
+ virtual void emitBasicBlockStartAnnot(const BasicBlock *BB,
+ formatted_raw_ostream &OS) {}
+
+ virtual void emitInstructionAnnot(const Instruction *I,
+ formatted_raw_ostream &OS) {
+ if (const auto *PI = PredInfo->getPredicateInfoFor(I)) {
+ OS << "; Has predicate info\n";
+ if (const auto *PB = dyn_cast<PredicateBranch>(PI)) {
+ OS << "; branch predicate info { TrueEdge: " << PB->TrueEdge
+ << " Comparison:" << *PB->Condition << " Edge: [";
+ PB->From->printAsOperand(OS);
+ OS << ",";
+ PB->To->printAsOperand(OS);
+ OS << "] }\n";
+ } else if (const auto *PS = dyn_cast<PredicateSwitch>(PI)) {
+ OS << "; switch predicate info { CaseValue: " << *PS->CaseValue
+ << " Switch:" << *PS->Switch << " Edge: [";
+ PS->From->printAsOperand(OS);
+ OS << ",";
+ PS->To->printAsOperand(OS);
+ OS << "] }\n";
+ } else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) {
+ OS << "; assume predicate info {"
+ << " Comparison:" << *PA->Condition << " }\n";
+ }
+ }
+ }
+};
+
+void PredicateInfo::print(raw_ostream &OS) const {
+ PredicateInfoAnnotatedWriter Writer(this);
+ F.print(OS, &Writer);
+}
+
+void PredicateInfo::dump() const {
+ PredicateInfoAnnotatedWriter Writer(this);
+ F.print(dbgs(), &Writer);
+}
+
+PreservedAnalyses PredicateInfoVerifierPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &AC = AM.getResult<AssumptionAnalysis>(F);
+ make_unique<PredicateInfo>(F, DT, AC)->verifyPredicateInfo();
+
+ return PreservedAnalyses::all();
+}
+}
diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
index 35faa6f65efda..a33b85c4ee69a 100644
--- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
+++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
@@ -15,7 +15,6 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
@@ -23,6 +22,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/IteratedDominanceFrontier.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -38,6 +38,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include <algorithm>
using namespace llvm;
@@ -225,9 +226,6 @@ struct PromoteMem2Reg {
DominatorTree &DT;
DIBuilder DIB;
- /// An AliasSetTracker object to update. If null, don't update it.
- AliasSetTracker *AST;
-
/// A cache of @llvm.assume intrinsics used by SimplifyInstruction.
AssumptionCache *AC;
@@ -269,10 +267,10 @@ struct PromoteMem2Reg {
public:
PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT,
- AliasSetTracker *AST, AssumptionCache *AC)
+ AssumptionCache *AC)
: Allocas(Allocas.begin(), Allocas.end()), DT(DT),
DIB(*DT.getRoot()->getParent()->getParent(), /*AllowUnresolved*/ false),
- AST(AST), AC(AC) {}
+ AC(AC) {}
void run();
@@ -301,6 +299,18 @@ private:
} // end of anonymous namespace
+/// Given a LoadInst LI, this adds assume(LI != null) after it.
+static void addAssumeNonNull(AssumptionCache *AC, LoadInst *LI) {
+ Function *AssumeIntrinsic =
+ Intrinsic::getDeclaration(LI->getModule(), Intrinsic::assume);
+ ICmpInst *LoadNotNull = new ICmpInst(ICmpInst::ICMP_NE, LI,
+ Constant::getNullValue(LI->getType()));
+ LoadNotNull->insertAfter(LI);
+ CallInst *CI = CallInst::Create(AssumeIntrinsic, {LoadNotNull});
+ CI->insertAfter(LoadNotNull);
+ AC->registerAssumption(CI);
+}
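+//
+// Illustrative sketch of the IR this emits (hypothetical names): after
+//   %v = load i32*, i32** %p, !nonnull !0
+// it appends
+//   %nn = icmp ne i32* %v, null
+//   call void @llvm.assume(i1 %nn)
+// so the nonnull fact survives once the load itself is erased.
+//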
+
static void removeLifetimeIntrinsicUsers(AllocaInst *AI) {
// Knowing that this alloca is promotable, we know that it's safe to kill all
// instructions except for load and store.
@@ -334,9 +344,8 @@ static void removeLifetimeIntrinsicUsers(AllocaInst *AI) {
/// and thus must be phi-ed with undef. We fall back to the standard alloca
/// promotion algorithm in that case.
static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,
- LargeBlockInfo &LBI,
- DominatorTree &DT,
- AliasSetTracker *AST) {
+ LargeBlockInfo &LBI, DominatorTree &DT,
+ AssumptionCache *AC) {
StoreInst *OnlyStore = Info.OnlyStore;
bool StoringGlobalVal = !isa<Instruction>(OnlyStore->getOperand(0));
BasicBlock *StoreBB = OnlyStore->getParent();
@@ -387,9 +396,15 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,
// code.
if (ReplVal == LI)
ReplVal = UndefValue::get(LI->getType());
+
+ // If the load was marked as nonnull we don't want to lose
+ // that information when we erase this Load. So we preserve
+ // it with an assume.
+ if (AC && LI->getMetadata(LLVMContext::MD_nonnull) &&
+ !llvm::isKnownNonNullAt(ReplVal, LI, &DT))
+ addAssumeNonNull(AC, LI);
+
LI->replaceAllUsesWith(ReplVal);
- if (AST && LI->getType()->isPointerTy())
- AST->deleteValue(LI);
LI->eraseFromParent();
LBI.deleteValue(LI);
}
@@ -410,8 +425,6 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,
Info.OnlyStore->eraseFromParent();
LBI.deleteValue(Info.OnlyStore);
- if (AST)
- AST->deleteValue(AI);
AI->eraseFromParent();
LBI.deleteValue(AI);
return true;
@@ -435,7 +448,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,
/// }
static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,
LargeBlockInfo &LBI,
- AliasSetTracker *AST) {
+ DominatorTree &DT,
+ AssumptionCache *AC) {
// The trickiest case to handle is when we have large blocks. Because of this,
// this code is optimized assuming that large blocks happen. This does not
// significantly pessimize the small block case. This uses LargeBlockInfo to
@@ -476,13 +490,18 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,
// There is no store before this load, bail out (load may be affected
// by the following stores - see main comment).
return false;
- }
- else
+ } else {
// Otherwise, there was a store before this load, the load takes its value.
- LI->replaceAllUsesWith(std::prev(I)->second->getOperand(0));
+ // Note, if the load was marked as nonnull we don't want to lose that
+ // information when we erase it. So we preserve it with an assume.
+ Value *ReplVal = std::prev(I)->second->getOperand(0);
+ if (AC && LI->getMetadata(LLVMContext::MD_nonnull) &&
+ !llvm::isKnownNonNullAt(ReplVal, LI, &DT))
+ addAssumeNonNull(AC, LI);
+
+ LI->replaceAllUsesWith(ReplVal);
+ }
- if (AST && LI->getType()->isPointerTy())
- AST->deleteValue(LI);
LI->eraseFromParent();
LBI.deleteValue(LI);
}
@@ -499,8 +518,6 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,
LBI.deleteValue(SI);
}
- if (AST)
- AST->deleteValue(AI);
AI->eraseFromParent();
LBI.deleteValue(AI);
@@ -517,8 +534,6 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,
void PromoteMem2Reg::run() {
Function &F = *DT.getRoot()->getParent();
- if (AST)
- PointerAllocaValues.resize(Allocas.size());
AllocaDbgDeclares.resize(Allocas.size());
AllocaInfo Info;
@@ -536,8 +551,6 @@ void PromoteMem2Reg::run() {
if (AI->use_empty()) {
// If there are no uses of the alloca, just delete it now.
- if (AST)
- AST->deleteValue(AI);
AI->eraseFromParent();
// Remove the alloca from the Allocas list, since it has been processed
@@ -553,7 +566,7 @@ void PromoteMem2Reg::run() {
// If there is only a single store to this value, replace any loads of
// it that are directly dominated by the definition with the value stored.
if (Info.DefiningBlocks.size() == 1) {
- if (rewriteSingleStoreAlloca(AI, Info, LBI, DT, AST)) {
+ if (rewriteSingleStoreAlloca(AI, Info, LBI, DT, AC)) {
// The alloca has been processed, move on.
RemoveFromAllocasList(AllocaNum);
++NumSingleStore;
@@ -564,7 +577,7 @@ void PromoteMem2Reg::run() {
// If the alloca is only read and written in one basic block, just perform a
// linear sweep over the block to eliminate it.
if (Info.OnlyUsedInOneBlock &&
- promoteSingleBlockAlloca(AI, Info, LBI, AST)) {
+ promoteSingleBlockAlloca(AI, Info, LBI, DT, AC)) {
// The alloca has been processed, move on.
RemoveFromAllocasList(AllocaNum);
continue;
@@ -578,11 +591,6 @@ void PromoteMem2Reg::run() {
BBNumbers[&BB] = ID++;
}
- // If we have an AST to keep updated, remember some pointer value that is
- // stored into the alloca.
- if (AST)
- PointerAllocaValues[AllocaNum] = Info.AllocaPointerVal;
-
// Remember the dbg.declare intrinsic describing this alloca, if any.
if (Info.DbgDeclare)
AllocaDbgDeclares[AllocaNum] = Info.DbgDeclare;
@@ -662,8 +670,6 @@ void PromoteMem2Reg::run() {
// tree. Just delete the users now.
if (!A->use_empty())
A->replaceAllUsesWith(UndefValue::get(A->getType()));
- if (AST)
- AST->deleteValue(A);
A->eraseFromParent();
}
@@ -694,8 +700,6 @@ void PromoteMem2Reg::run() {
// If this PHI node merges one value and/or undefs, get the value.
if (Value *V = SimplifyInstruction(PN, DL, nullptr, &DT, AC)) {
- if (AST && PN->getType()->isPointerTy())
- AST->deleteValue(PN);
PN->replaceAllUsesWith(V);
PN->eraseFromParent();
NewPhiNodes.erase(I++);
@@ -863,10 +867,6 @@ bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo,
&BB->front());
++NumPHIInsert;
PhiToAllocaMap[PN] = AllocaNo;
-
- if (AST && PN->getType()->isPointerTy())
- AST->copyValue(PointerAllocaValues[AllocaNo], PN);
-
return true;
}
@@ -940,10 +940,15 @@ NextIteration:
Value *V = IncomingVals[AI->second];
+ // If the load was marked as nonnull we don't want to lose
+ // that information when we erase this Load. So we preserve
+ // it with an assume.
+ if (AC && LI->getMetadata(LLVMContext::MD_nonnull) &&
+ !llvm::isKnownNonNullAt(V, LI, &DT))
+ addAssumeNonNull(AC, LI);
+
// Anything using the load now uses the current value.
LI->replaceAllUsesWith(V);
- if (AST && LI->getType()->isPointerTy())
- AST->deleteValue(LI);
BB->getInstList().erase(LI);
} else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
// Delete this instruction and mark the name as the current holder of the
@@ -987,10 +992,10 @@ NextIteration:
}
void llvm::PromoteMemToReg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT,
- AliasSetTracker *AST, AssumptionCache *AC) {
+ AssumptionCache *AC) {
// If there is nothing to do, bail out...
if (Allocas.empty())
return;
- PromoteMem2Reg(Allocas, DT, AST, AC).run();
+ PromoteMem2Reg(Allocas, DT, AC).run();
}
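With the AliasSetTracker parameter gone, a mem2reg client passes only the dominator tree and an optional AssumptionCache; supplying the cache is what enables the nonnull-to-assume preservation added above. A minimal caller sketch, assuming the usual pass plumbing (the helper name runMem2RegOn is illustrative and not part of this patch):

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
using namespace llvm;

// Illustrative helper: promote every promotable alloca in the entry block.
static void runMem2RegOn(Function &F, DominatorTree &DT, AssumptionCache &AC) {
  SmallVector<AllocaInst *, 8> Allocas;
  for (Instruction &I : F.getEntryBlock())
    if (auto *AI = dyn_cast<AllocaInst>(&I))
      if (isAllocaPromotable(AI))
        Allocas.push_back(AI);
  if (!Allocas.empty())
    PromoteMemToReg(Allocas, DT, &AC); // no AliasSetTracker argument any more
}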
diff --git a/lib/Transforms/Utils/SSAUpdater.cpp b/lib/Transforms/Utils/SSAUpdater.cpp
index 8e93ee757a159..8b6a2c3766d26 100644
--- a/lib/Transforms/Utils/SSAUpdater.cpp
+++ b/lib/Transforms/Utils/SSAUpdater.cpp
@@ -11,20 +11,29 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/Utils/SSAUpdater.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
-#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/SSAUpdater.h"
#include "llvm/Transforms/Utils/SSAUpdaterImpl.h"
+#include <cassert>
+#include <utility>
using namespace llvm;
@@ -36,7 +45,7 @@ static AvailableValsTy &getAvailableVals(void *AV) {
}
SSAUpdater::SSAUpdater(SmallVectorImpl<PHINode*> *NewPHI)
- : AV(nullptr), ProtoType(nullptr), ProtoName(), InsertedPHIs(NewPHI) {}
+ : InsertedPHIs(NewPHI) {}
SSAUpdater::~SSAUpdater() {
delete static_cast<AvailableValsTy*>(AV);
@@ -205,6 +214,7 @@ void SSAUpdater::RewriteUseAfterInsertions(Use &U) {
}
namespace llvm {
+
template<>
class SSAUpdaterTraits<SSAUpdater> {
public:
@@ -230,6 +240,7 @@ public:
PHI_iterator &operator++() { ++idx; return *this; }
bool operator==(const PHI_iterator& x) const { return idx == x.idx; }
bool operator!=(const PHI_iterator& x) const { return !operator==(x); }
+
Value *getIncomingValue() { return PHI->getIncomingValue(idx); }
BasicBlock *getIncomingBlock() { return PHI->getIncomingBlock(idx); }
};
@@ -303,7 +314,7 @@ public:
}
};
-} // End llvm namespace
+} // end namespace llvm
/// Check to see if AvailableVals has an entry for the specified BB and if so,
/// return it. If not, construct SSA form by first calculating the required
@@ -337,14 +348,12 @@ LoadAndStorePromoter(ArrayRef<const Instruction*> Insts,
SSA.Initialize(SomeVal->getType(), BaseName);
}
-
void LoadAndStorePromoter::
run(const SmallVectorImpl<Instruction*> &Insts) const {
-
// First step: bucket up uses of the alloca by the block they occur in.
// This is important because we have to handle multiple defs/uses in a block
// ourselves: SSAUpdater is purely for cross-block references.
- DenseMap<BasicBlock*, TinyPtrVector<Instruction*> > UsesByBlock;
+ DenseMap<BasicBlock*, TinyPtrVector<Instruction*>> UsesByBlock;
for (Instruction *User : Insts)
UsesByBlock[User->getParent()].push_back(User);
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 7b0bddbbb8314..127a44df5344f 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -169,6 +170,8 @@ class SimplifyCFGOpt {
unsigned BonusInstThreshold;
AssumptionCache *AC;
SmallPtrSetImpl<BasicBlock *> *LoopHeaders;
+ // See comments in SimplifyCFGOpt::SimplifySwitch.
+ bool LateSimplifyCFG;
Value *isValueEqualityComparison(TerminatorInst *TI);
BasicBlock *GetValueEqualityComparisonCases(
TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases);
@@ -192,9 +195,10 @@ class SimplifyCFGOpt {
public:
SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout &DL,
unsigned BonusInstThreshold, AssumptionCache *AC,
- SmallPtrSetImpl<BasicBlock *> *LoopHeaders)
+ SmallPtrSetImpl<BasicBlock *> *LoopHeaders,
+ bool LateSimplifyCFG)
: TTI(TTI), DL(DL), BonusInstThreshold(BonusInstThreshold), AC(AC),
- LoopHeaders(LoopHeaders) {}
+ LoopHeaders(LoopHeaders), LateSimplifyCFG(LateSimplifyCFG) {}
bool run(BasicBlock *BB);
};
@@ -710,10 +714,9 @@ BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases(
TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases) {
if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
Cases.reserve(SI->getNumCases());
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
- ++i)
- Cases.push_back(
- ValueEqualityComparisonCase(i.getCaseValue(), i.getCaseSuccessor()));
+ for (auto Case : SI->cases())
+ Cases.push_back(ValueEqualityComparisonCase(Case.getCaseValue(),
+ Case.getCaseSuccessor()));
return SI->getDefaultDest();
}
@@ -846,12 +849,12 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor(
}
for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) {
--i;
- if (DeadCases.count(i.getCaseValue())) {
+ if (DeadCases.count(i->getCaseValue())) {
if (HasWeight) {
- std::swap(Weights[i.getCaseIndex() + 1], Weights.back());
+ std::swap(Weights[i->getCaseIndex() + 1], Weights.back());
Weights.pop_back();
}
- i.getCaseSuccessor()->removePredecessor(TI->getParent());
+ i->getCaseSuccessor()->removePredecessor(TI->getParent());
SI->removeCase(i);
}
}
@@ -996,8 +999,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
SmallSetVector<BasicBlock*, 4> FailBlocks;
if (!SafeToMergeTerminators(TI, PTI, &FailBlocks)) {
for (auto *Succ : FailBlocks) {
- std::vector<BasicBlock*> Blocks = { TI->getParent() };
- if (!SplitBlockPredecessors(Succ, Blocks, ".fold.split"))
+ if (!SplitBlockPredecessors(Succ, TI->getParent(), ".fold.split"))
return false;
}
}
@@ -1280,7 +1282,7 @@ static bool HoistThenElseCodeToIf(BranchInst *BI,
if (!isa<CallInst>(I1))
I1->setDebugLoc(
DILocation::getMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()));
-
+
I2->eraseFromParent();
Changed = true;
@@ -1472,29 +1474,28 @@ static bool canSinkInstructions(
return false;
}
+ // Because SROA can't handle speculating stores of selects, try not
+ // to sink loads or stores of allocas when we'd have to create a PHI for
+ // the address operand. Also, because it is likely that loads or stores
+ // of allocas will disappear when Mem2Reg/SROA is run, don't sink them.
+ // This can cause code churn which can have unintended consequences down
+ // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244.
+ // FIXME: This is a workaround for a deficiency in SROA - see
+ // https://llvm.org/bugs/show_bug.cgi?id=30188
+ if (isa<StoreInst>(I0) && any_of(Insts, [](const Instruction *I) {
+ return isa<AllocaInst>(I->getOperand(1));
+ }))
+ return false;
+ if (isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) {
+ return isa<AllocaInst>(I->getOperand(0));
+ }))
+ return false;
+
for (unsigned OI = 0, OE = I0->getNumOperands(); OI != OE; ++OI) {
if (I0->getOperand(OI)->getType()->isTokenTy())
// Don't touch any operand of token type.
return false;
- // Because SROA can't handle speculating stores of selects, try not
- // to sink loads or stores of allocas when we'd have to create a PHI for
- // the address operand. Also, because it is likely that loads or stores
- // of allocas will disappear when Mem2Reg/SROA is run, don't sink them.
- // This can cause code churn which can have unintended consequences down
- // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244.
- // FIXME: This is a workaround for a deficiency in SROA - see
- // https://llvm.org/bugs/show_bug.cgi?id=30188
- if (OI == 1 && isa<StoreInst>(I0) &&
- any_of(Insts, [](const Instruction *I) {
- return isa<AllocaInst>(I->getOperand(1));
- }))
- return false;
- if (OI == 0 && isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) {
- return isa<AllocaInst>(I->getOperand(0));
- }))
- return false;
-
auto SameAsI0 = [&I0, OI](const Instruction *I) {
assert(I->getNumOperands() == I0->getNumOperands());
return I->getOperand(OI) == I0->getOperand(OI);
@@ -1546,7 +1547,7 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) {
}))
return false;
}
-
+
// We don't need to do any more checking here; canSinkLastInstruction should
// have done it all for us.
SmallVector<Value*, 4> NewOperands;
@@ -1653,7 +1654,7 @@ namespace {
bool isValid() const {
return !Fail;
}
-
+
void operator -- () {
if (Fail)
return;
@@ -1699,7 +1700,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {
// / \
// [f(1)] [if]
// | | \
- // | | \
+ // | | |
// | [f(2)]|
// \ | /
// [ end ]
@@ -1737,7 +1738,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {
}
if (UnconditionalPreds.size() < 2)
return false;
-
+
bool Changed = false;
// We take a two-step approach to tail sinking. First we scan from the end of
// each block upwards in lockstep. If the n'th instruction from the end of each
@@ -1767,7 +1768,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {
unsigned NumPHIInsts = NumPHIdValues / UnconditionalPreds.size();
if ((NumPHIdValues % UnconditionalPreds.size()) != 0)
NumPHIInsts++;
-
+
return NumPHIInsts <= 1;
};
@@ -1790,7 +1791,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {
}
if (!Profitable)
return false;
-
+
DEBUG(dbgs() << "SINK: Splitting edge\n");
// We have a conditional edge and we're going to sink some instructions.
// Insert a new block postdominating all blocks we're going to sink from.
@@ -1800,7 +1801,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {
return false;
Changed = true;
}
-
+
// Now that we've analyzed all potential sinking candidates, perform the
// actual sink. We iteratively sink the last non-terminator of the source
// blocks into their common successor unless doing so would require too
@@ -1826,7 +1827,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) {
DEBUG(dbgs() << "SINK: stopping here, too many PHIs would be created!\n");
break;
}
-
+
if (!sinkLastInstruction(UnconditionalPreds))
return Changed;
NumSinkCommons++;
@@ -2078,6 +2079,9 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
Value *S = Builder.CreateSelect(
BrCond, TrueV, FalseV, TrueV->getName() + "." + FalseV->getName(), BI);
SpeculatedStore->setOperand(0, S);
+ SpeculatedStore->setDebugLoc(
+ DILocation::getMergedLocation(
+ BI->getDebugLoc(), SpeculatedStore->getDebugLoc()));
}
// Metadata can be dependent on the condition we are hoisting above.
@@ -2147,7 +2151,8 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) {
/// If we have a conditional branch on a PHI node value that is defined in the
/// same block as the branch and if any PHI entries are constants, thread edges
/// corresponding to that entry to be branches to their ultimate destination.
-static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) {
+static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL,
+ AssumptionCache *AC) {
BasicBlock *BB = BI->getParent();
PHINode *PN = dyn_cast<PHINode>(BI->getCondition());
// NOTE: we currently cannot transform this case if the PHI node is used
@@ -2239,6 +2244,11 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) {
// Insert the new instruction into its new home.
if (N)
EdgeBB->getInstList().insert(InsertPt, N);
+
+ // Register the new instruction with the assumption cache if necessary.
+ if (auto *II = dyn_cast_or_null<IntrinsicInst>(N))
+ if (II->getIntrinsicID() == Intrinsic::assume)
+ AC->registerAssumption(II);
}
// Loop over all of the edges from PredBB to BB, changing them to branch
@@ -2251,7 +2261,7 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) {
}
// Recurse, simplifying any other constants.
- return FoldCondBranchOnPHI(BI, DL) | true;
+ return FoldCondBranchOnPHI(BI, DL, AC) | true;
}
return false;
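FoldCondBranchOnPHI clones instructions into EdgeBB, and a cloned llvm.assume that never reaches the AssumptionCache would be invisible to later queries; hence the AC parameter threaded through above. A standalone restatement of that registration step, with a hypothetical helper name and an explicit null check on AC added for the sketch:

#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/IR/IntrinsicInst.h"
using namespace llvm;

// Register a freshly cloned instruction with the cache when it is an assume.
static void registerClonedAssume(Instruction *N, AssumptionCache *AC) {
  if (auto *II = dyn_cast_or_null<IntrinsicInst>(N))
    if (AC && II->getIntrinsicID() == Intrinsic::assume)
      AC->registerAssumption(II);
}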
@@ -3433,8 +3443,8 @@ static bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select) {
// Find the relevant condition and destinations.
Value *Condition = Select->getCondition();
- BasicBlock *TrueBB = SI->findCaseValue(TrueVal).getCaseSuccessor();
- BasicBlock *FalseBB = SI->findCaseValue(FalseVal).getCaseSuccessor();
+ BasicBlock *TrueBB = SI->findCaseValue(TrueVal)->getCaseSuccessor();
+ BasicBlock *FalseBB = SI->findCaseValue(FalseVal)->getCaseSuccessor();
// Get weight for TrueBB and FalseBB.
uint32_t TrueWeight = 0, FalseWeight = 0;
@@ -3444,9 +3454,9 @@ static bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select) {
GetBranchWeights(SI, Weights);
if (Weights.size() == 1 + SI->getNumCases()) {
TrueWeight =
- (uint32_t)Weights[SI->findCaseValue(TrueVal).getSuccessorIndex()];
+ (uint32_t)Weights[SI->findCaseValue(TrueVal)->getSuccessorIndex()];
FalseWeight =
- (uint32_t)Weights[SI->findCaseValue(FalseVal).getSuccessorIndex()];
+ (uint32_t)Weights[SI->findCaseValue(FalseVal)->getSuccessorIndex()];
}
}
@@ -4148,15 +4158,16 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) {
}
}
} else if (auto *SI = dyn_cast<SwitchInst>(TI)) {
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
- ++i)
- if (i.getCaseSuccessor() == BB) {
- BB->removePredecessor(SI->getParent());
- SI->removeCase(i);
- --i;
- --e;
- Changed = true;
+ for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) {
+ if (i->getCaseSuccessor() != BB) {
+ ++i;
+ continue;
}
+ BB->removePredecessor(SI->getParent());
+ i = SI->removeCase(i);
+ e = SI->case_end();
+ Changed = true;
+ }
} else if (auto *II = dyn_cast<InvokeInst>(TI)) {
if (II->getUnwindDest() == BB) {
removeUnwindEdge(TI->getParent());
@@ -4239,18 +4250,18 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) {
SmallVector<ConstantInt *, 16> CasesA;
SmallVector<ConstantInt *, 16> CasesB;
- for (SwitchInst::CaseIt I : SI->cases()) {
- BasicBlock *Dest = I.getCaseSuccessor();
+ for (auto Case : SI->cases()) {
+ BasicBlock *Dest = Case.getCaseSuccessor();
if (!DestA)
DestA = Dest;
if (Dest == DestA) {
- CasesA.push_back(I.getCaseValue());
+ CasesA.push_back(Case.getCaseValue());
continue;
}
if (!DestB)
DestB = Dest;
if (Dest == DestB) {
- CasesB.push_back(I.getCaseValue());
+ CasesB.push_back(Case.getCaseValue());
continue;
}
return false; // More than two destinations.
@@ -4375,7 +4386,7 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC,
bool HasDefault =
!isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg());
const unsigned NumUnknownBits =
- Bits - (KnownZero.Or(KnownOne)).countPopulation();
+ Bits - (KnownZero | KnownOne).countPopulation();
assert(NumUnknownBits <= Bits);
if (HasDefault && DeadCases.empty() &&
NumUnknownBits < 64 /* avoid overflow */ &&
@@ -4400,17 +4411,17 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC,
// Remove dead cases from the switch.
for (ConstantInt *DeadCase : DeadCases) {
- SwitchInst::CaseIt Case = SI->findCaseValue(DeadCase);
- assert(Case != SI->case_default() &&
+ SwitchInst::CaseIt CaseI = SI->findCaseValue(DeadCase);
+ assert(CaseI != SI->case_default() &&
"Case was not found. Probably mistake in DeadCases forming.");
if (HasWeight) {
- std::swap(Weights[Case.getCaseIndex() + 1], Weights.back());
+ std::swap(Weights[CaseI->getCaseIndex() + 1], Weights.back());
Weights.pop_back();
}
// Prune unused values from PHI nodes.
- Case.getCaseSuccessor()->removePredecessor(SI->getParent());
- SI->removeCase(Case);
+ CaseI->getCaseSuccessor()->removePredecessor(SI->getParent());
+ SI->removeCase(CaseI);
}
if (HasWeight && Weights.size() >= 2) {
SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
@@ -4464,10 +4475,9 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) {
typedef DenseMap<PHINode *, SmallVector<int, 4>> ForwardingNodesMap;
ForwardingNodesMap ForwardingNodes;
- for (SwitchInst::CaseIt I = SI->case_begin(), E = SI->case_end(); I != E;
- ++I) {
- ConstantInt *CaseValue = I.getCaseValue();
- BasicBlock *CaseDest = I.getCaseSuccessor();
+ for (auto Case : SI->cases()) {
+ ConstantInt *CaseValue = Case.getCaseValue();
+ BasicBlock *CaseDest = Case.getCaseSuccessor();
int PhiIndex;
PHINode *PHI =
@@ -5202,8 +5212,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
// common destination, as well as the min and max case values.
assert(SI->case_begin() != SI->case_end());
SwitchInst::CaseIt CI = SI->case_begin();
- ConstantInt *MinCaseVal = CI.getCaseValue();
- ConstantInt *MaxCaseVal = CI.getCaseValue();
+ ConstantInt *MinCaseVal = CI->getCaseValue();
+ ConstantInt *MaxCaseVal = CI->getCaseValue();
BasicBlock *CommonDest = nullptr;
typedef SmallVector<std::pair<ConstantInt *, Constant *>, 4> ResultListTy;
@@ -5213,7 +5223,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
SmallVector<PHINode *, 4> PHIs;
for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) {
- ConstantInt *CaseVal = CI.getCaseValue();
+ ConstantInt *CaseVal = CI->getCaseValue();
if (CaseVal->getValue().slt(MinCaseVal->getValue()))
MinCaseVal = CaseVal;
if (CaseVal->getValue().sgt(MaxCaseVal->getValue()))
@@ -5222,7 +5232,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
// Resulting value at phi nodes for this case value.
typedef SmallVector<std::pair<PHINode *, Constant *>, 4> ResultsTy;
ResultsTy Results;
- if (!GetCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest,
+ if (!GetCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest,
Results, DL, TTI))
return false;
@@ -5503,11 +5513,10 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
auto *Rot = Builder.CreateOr(LShr, Shl);
SI->replaceUsesOfWith(SI->getCondition(), Rot);
- for (SwitchInst::CaseIt C = SI->case_begin(), E = SI->case_end(); C != E;
- ++C) {
- auto *Orig = C.getCaseValue();
+ for (auto Case : SI->cases()) {
+ auto *Orig = Case.getCaseValue();
auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base);
- C.setValue(
+ Case.setValue(
cast<ConstantInt>(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue()))));
}
return true;
@@ -5553,7 +5562,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (ForwardSwitchConditionToPHI(SI))
return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;
- if (SwitchToLookupTable(SI, Builder, DL, TTI))
+  // The conversion from switch to lookup tables results in
+  // difficult-to-analyze code and makes pruning branches much harder.
+  // This is a problem if the switch expression itself can still be
+  // restricted as a result of inlining or CVP. Therefore only apply this
+  // transformation during late steps of the optimisation chain.
+ if (LateSimplifyCFG && SwitchToLookupTable(SI, Builder, DL, TTI))
return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;
if (ReduceSwitchRange(SI, Builder, DL, TTI))
@@ -5833,7 +5847,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
// through this block if any PHI node entries are constants.
if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition()))
if (PN->getParent() == BI->getParent())
- if (FoldCondBranchOnPHI(BI, DL))
+ if (FoldCondBranchOnPHI(BI, DL, AC))
return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;
// Scan predecessor blocks for conditional branches.
@@ -6012,8 +6026,9 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) {
///
bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI,
unsigned BonusInstThreshold, AssumptionCache *AC,
- SmallPtrSetImpl<BasicBlock *> *LoopHeaders) {
+ SmallPtrSetImpl<BasicBlock *> *LoopHeaders,
+ bool LateSimplifyCFG) {
return SimplifyCFGOpt(TTI, BB->getModule()->getDataLayout(),
- BonusInstThreshold, AC, LoopHeaders)
+ BonusInstThreshold, AC, LoopHeaders, LateSimplifyCFG)
.run(BB);
}
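Because the new LateSimplifyCFG flag gates SwitchToLookupTable, callers choose the behaviour by pipeline position. A hedged sketch of how a client might call the updated entry point (the wrapper name and threshold value are placeholders, not taken from this patch):

#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static bool simplifyBlock(BasicBlock *BB, const TargetTransformInfo &TTI,
                          AssumptionCache *AC, bool IsLatePipeline) {
  // Early runs keep switches intact so inlining/CVP can still narrow the
  // condition; only the late run may turn a switch into a lookup table.
  return SimplifyCFG(BB, TTI, /*BonusInstThreshold=*/1, AC,
                     /*LoopHeaders=*/nullptr,
                     /*LateSimplifyCFG=*/IsLatePipeline);
}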
diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp
index 6b1d3dc413305..a4cc6a031ad4c 100644
--- a/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -35,6 +35,9 @@ using namespace llvm;
STATISTIC(NumElimIdentity, "Number of IV identities eliminated");
STATISTIC(NumElimOperand, "Number of IV operands folded into a use");
STATISTIC(NumElimRem , "Number of IV remainder operations eliminated");
+STATISTIC(
+ NumSimplifiedSDiv,
+ "Number of IV signed division operations converted to unsigned division");
STATISTIC(NumElimCmp , "Number of IV comparisons eliminated");
namespace {
@@ -75,6 +78,7 @@ namespace {
void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand);
void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand,
bool IsSigned);
+ bool eliminateSDiv(BinaryOperator *SDiv);
bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand);
};
}
@@ -265,6 +269,33 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) {
Changed = true;
}
+bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) {
+ // Get the SCEVs for the ICmp operands.
+ auto *N = SE->getSCEV(SDiv->getOperand(0));
+ auto *D = SE->getSCEV(SDiv->getOperand(1));
+
+ // Simplify unnecessary loops away.
+ const Loop *L = LI->getLoopFor(SDiv->getParent());
+ N = SE->getSCEVAtScope(N, L);
+ D = SE->getSCEVAtScope(D, L);
+
+  // Replace sdiv by udiv if both of the operands are non-negative.
+ if (SE->isKnownNonNegative(N) && SE->isKnownNonNegative(D)) {
+ auto *UDiv = BinaryOperator::Create(
+ BinaryOperator::UDiv, SDiv->getOperand(0), SDiv->getOperand(1),
+ SDiv->getName() + ".udiv", SDiv);
+ UDiv->setIsExact(SDiv->isExact());
+ SDiv->replaceAllUsesWith(UDiv);
+ DEBUG(dbgs() << "INDVARS: Simplified sdiv: " << *SDiv << '\n');
+ ++NumSimplifiedSDiv;
+ Changed = true;
+ DeadInsts.push_back(SDiv);
+ return true;
+ }
+
+ return false;
+}
+
/// SimplifyIVUsers helper for eliminating useless
/// remainder operations operating on an induction variable.
void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem,
@@ -426,12 +457,15 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst,
eliminateIVComparison(ICmp, IVOperand);
return true;
}
- if (BinaryOperator *Rem = dyn_cast<BinaryOperator>(UseInst)) {
- bool IsSigned = Rem->getOpcode() == Instruction::SRem;
- if (IsSigned || Rem->getOpcode() == Instruction::URem) {
- eliminateIVRemainder(Rem, IVOperand, IsSigned);
+ if (BinaryOperator *Bin = dyn_cast<BinaryOperator>(UseInst)) {
+ bool IsSRem = Bin->getOpcode() == Instruction::SRem;
+ if (IsSRem || Bin->getOpcode() == Instruction::URem) {
+ eliminateIVRemainder(Bin, IVOperand, IsSRem);
return true;
}
+
+ if (Bin->getOpcode() == Instruction::SDiv)
+ return eliminateSDiv(Bin);
}
if (auto *CI = dyn_cast<CallInst>(UseInst))
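The sdiv-to-udiv rewrite above is legal exactly when SCEV proves both operands non-negative at the loop's scope. A standalone restatement of that guard, assuming the usual ScalarEvolution/LoopInfo analyses are available (the helper name is not part of the patch):

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// True if `sdiv N, D` can be replaced by `udiv N, D` without changing results.
static bool sdivIsEquivalentToUDiv(BinaryOperator *SDiv, ScalarEvolution &SE,
                                   LoopInfo &LI) {
  const Loop *L = LI.getLoopFor(SDiv->getParent());
  const SCEV *N = SE.getSCEVAtScope(SE.getSCEV(SDiv->getOperand(0)), L);
  const SCEV *D = SE.getSCEVAtScope(SE.getSCEV(SDiv->getOperand(1)), L);
  return SE.isKnownNonNegative(N) && SE.isKnownNonNegative(D);
}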
diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Utils/SimplifyInstructions.cpp
index 1220490123ce5..f6070868de44e 100644
--- a/lib/Transforms/Utils/SimplifyInstructions.cpp
+++ b/lib/Transforms/Utils/SimplifyInstructions.cpp
@@ -20,6 +20,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/OptimizationDiagnosticInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
@@ -35,7 +36,8 @@ using namespace llvm;
STATISTIC(NumSimplified, "Number of redundant instructions removed");
static bool runImpl(Function &F, const DominatorTree *DT,
- const TargetLibraryInfo *TLI, AssumptionCache *AC) {
+ const TargetLibraryInfo *TLI, AssumptionCache *AC,
+ OptimizationRemarkEmitter *ORE) {
const DataLayout &DL = F.getParent()->getDataLayout();
SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2;
bool Changed = false;
@@ -54,7 +56,7 @@ static bool runImpl(Function &F, const DominatorTree *DT,
// Don't waste time simplifying unused instructions.
if (!I->use_empty()) {
- if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) {
+ if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC, ORE)) {
// Mark all uses for resimplification next time round the loop.
for (User *U : I->users())
Next->insert(cast<Instruction>(U));
@@ -95,6 +97,7 @@ namespace {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
}
/// runOnFunction - Remove instructions that simplify.
@@ -108,7 +111,10 @@ namespace {
&getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
AssumptionCache *AC =
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- return runImpl(F, DT, TLI, AC);
+ OptimizationRemarkEmitter *ORE =
+ &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
+
+ return runImpl(F, DT, TLI, AC, ORE);
}
};
}
@@ -119,6 +125,7 @@ INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify",
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(InstSimplifier, "instsimplify",
"Remove redundant instructions", false, false)
char &llvm::InstructionSimplifierID = InstSimplifier::ID;
@@ -133,9 +140,12 @@ PreservedAnalyses InstSimplifierPass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
- bool Changed = runImpl(F, &DT, &TLI, &AC);
+ auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+ bool Changed = runImpl(F, &DT, &TLI, &AC, &ORE);
if (!Changed)
return PreservedAnalyses::all();
- // FIXME: This should also 'preserve the CFG'.
- return PreservedAnalyses::none();
+
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
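The new-PM return value switches from "preserve nothing" to "preserve the CFG analyses", which instsimplify can honestly claim since it only rewrites values and never edits block structure. A minimal restatement of the pattern as a standalone sketch (not the pass itself):

#include "llvm/IR/PassManager.h"
using namespace llvm;

static PreservedAnalyses reportInstOnlyChanges(bool Changed) {
  if (!Changed)
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>(); // dominators, loops, etc. remain valid
  return PA;
}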
diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 8eaeb1073a761..aa71e3669ea27 100644
--- a/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -51,9 +51,9 @@ static cl::opt<bool>
// Helper Functions
//===----------------------------------------------------------------------===//
-static bool ignoreCallingConv(LibFunc::Func Func) {
- return Func == LibFunc::abs || Func == LibFunc::labs ||
- Func == LibFunc::llabs || Func == LibFunc::strlen;
+static bool ignoreCallingConv(LibFunc Func) {
+ return Func == LibFunc_abs || Func == LibFunc_labs ||
+ Func == LibFunc_llabs || Func == LibFunc_strlen;
}
static bool isCallingConvCCompatible(CallInst *CI) {
@@ -123,8 +123,8 @@ static bool callHasFloatingPointArgument(const CallInst *CI) {
/// \brief Check whether the overloaded unary floating point function
/// corresponding to \a Ty is available.
static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty,
- LibFunc::Func DoubleFn, LibFunc::Func FloatFn,
- LibFunc::Func LongDoubleFn) {
+ LibFunc DoubleFn, LibFunc FloatFn,
+ LibFunc LongDoubleFn) {
switch (Ty->getTypeID()) {
case Type::FloatTyID:
return TLI->has(FloatFn);
@@ -809,9 +809,9 @@ Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) {
// TODO: Does this belong in BuildLibCalls or should all of those similar
// functions be moved here?
-static Value *emitCalloc(Value *Num, Value *Size, const AttributeSet &Attrs,
+static Value *emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs,
IRBuilder<> &B, const TargetLibraryInfo &TLI) {
- LibFunc::Func Func;
+ LibFunc Func;
if (!TLI.getLibFunc("calloc", Func) || !TLI.has(Func))
return nullptr;
@@ -819,7 +819,7 @@ static Value *emitCalloc(Value *Num, Value *Size, const AttributeSet &Attrs,
const DataLayout &DL = M->getDataLayout();
IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext()));
Value *Calloc = M->getOrInsertFunction("calloc", Attrs, B.getInt8PtrTy(),
- PtrType, PtrType, nullptr);
+ PtrType, PtrType);
CallInst *CI = B.CreateCall(Calloc, { Num, Size }, "calloc");
if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts()))
@@ -846,9 +846,9 @@ static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B,
// Is the inner call really malloc()?
Function *InnerCallee = Malloc->getCalledFunction();
- LibFunc::Func Func;
+ LibFunc Func;
if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) ||
- Func != LibFunc::malloc)
+ Func != LibFunc_malloc)
return nullptr;
// The memset must cover the same number of bytes that are malloc'd.
@@ -948,6 +948,20 @@ static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
return B.CreateFPExt(V, B.getDoubleTy());
}
+// Replace a libcall \p CI with a call to intrinsic \p IID
+static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) {
+ // Propagate fast-math flags from the existing call to the new call.
+ IRBuilder<>::FastMathFlagGuard Guard(B);
+ B.setFastMathFlags(CI->getFastMathFlags());
+
+ Module *M = CI->getModule();
+ Value *V = CI->getArgOperand(0);
+ Function *F = Intrinsic::getDeclaration(M, IID, CI->getType());
+ CallInst *NewCall = B.CreateCall(F, V);
+ NewCall->takeName(CI);
+ return NewCall;
+}
+
/// Shrink double -> float for binary functions like 'fmin/fmax'.
static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
@@ -1041,9 +1055,9 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) {
// pow(10.0, x) -> exp10(x)
if (Op1C->isExactlyValue(10.0) &&
- hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f,
- LibFunc::exp10l))
- return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B,
+ hasUnaryFloatFn(TLI, Op1->getType(), LibFunc_exp10, LibFunc_exp10f,
+ LibFunc_exp10l))
+ return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc_exp10), B,
Callee->getAttributes());
}
@@ -1055,10 +1069,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
// pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1).
auto *OpC = dyn_cast<CallInst>(Op1);
if (OpC && OpC->hasUnsafeAlgebra() && CI->hasUnsafeAlgebra()) {
- LibFunc::Func Func;
+ LibFunc Func;
Function *OpCCallee = OpC->getCalledFunction();
if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) &&
- TLI->has(Func) && (Func == LibFunc::exp || Func == LibFunc::exp2)) {
+ TLI->has(Func) && (Func == LibFunc_exp || Func == LibFunc_exp2)) {
IRBuilder<>::FastMathFlagGuard Guard(B);
B.setFastMathFlags(CI->getFastMathFlags());
Value *FMul = B.CreateFMul(OpC->getArgOperand(0), Op2, "mul");
@@ -1075,17 +1089,20 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
return ConstantFP::get(CI->getType(), 1.0);
if (Op2C->isExactlyValue(-0.5) &&
- hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf,
- LibFunc::sqrtl)) {
+ hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf,
+ LibFunc_sqrtl)) {
// If -ffast-math:
// pow(x, -0.5) -> 1.0 / sqrt(x)
if (CI->hasUnsafeAlgebra()) {
IRBuilder<>::FastMathFlagGuard Guard(B);
B.setFastMathFlags(CI->getFastMathFlags());
- // Here we cannot lower to an intrinsic because C99 sqrt() and llvm.sqrt
- // are not guaranteed to have the same semantics.
- Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B,
+ // TODO: If the pow call is an intrinsic, we should lower to the sqrt
+ // intrinsic, so we match errno semantics. We also should check that the
+ // target can in fact lower the sqrt intrinsic -- we currently have no way
+ // to ask this question other than asking whether the target has a sqrt
+ // libcall, which is a sufficient but not necessary condition.
+ Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,
Callee->getAttributes());
return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Sqrt, "sqrtrecip");
@@ -1093,19 +1110,17 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
}
if (Op2C->isExactlyValue(0.5) &&
- hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf,
- LibFunc::sqrtl) &&
- hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf,
- LibFunc::fabsl)) {
+ hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf,
+ LibFunc_sqrtl)) {
// In -ffast-math, pow(x, 0.5) -> sqrt(x).
if (CI->hasUnsafeAlgebra()) {
IRBuilder<>::FastMathFlagGuard Guard(B);
B.setFastMathFlags(CI->getFastMathFlags());
- // Unlike other math intrinsics, sqrt has differerent semantics
- // from the libc function. See LangRef for details.
- return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B,
+ // TODO: As above, we should lower to the sqrt intrinsic if the pow is an
+ // intrinsic, to match errno semantics.
+ return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,
Callee->getAttributes());
}
@@ -1115,9 +1130,16 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
// TODO: In finite-only mode, this could be just fabs(sqrt(x)).
Value *Inf = ConstantFP::getInfinity(CI->getType());
Value *NegInf = ConstantFP::getInfinity(CI->getType(), true);
+
+ // TODO: As above, we should lower to the sqrt intrinsic if the pow is an
+ // intrinsic, to match errno semantics.
Value *Sqrt = emitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes());
- Value *FAbs =
- emitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes());
+
+ Module *M = Callee->getParent();
+ Function *FabsF = Intrinsic::getDeclaration(M, Intrinsic::fabs,
+ CI->getType());
+ Value *FAbs = B.CreateCall(FabsF, Sqrt);
+
Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf);
Value *Sel = B.CreateSelect(FCmp, Inf, FAbs);
return Sel;
@@ -1173,11 +1195,11 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {
Value *Op = CI->getArgOperand(0);
// Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32
// Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32
- LibFunc::Func LdExp = LibFunc::ldexpl;
+ LibFunc LdExp = LibFunc_ldexpl;
if (Op->getType()->isFloatTy())
- LdExp = LibFunc::ldexpf;
+ LdExp = LibFunc_ldexpf;
else if (Op->getType()->isDoubleTy())
- LdExp = LibFunc::ldexp;
+ LdExp = LibFunc_ldexp;
if (TLI->has(LdExp)) {
Value *LdExpArg = nullptr;
@@ -1197,7 +1219,7 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {
Module *M = CI->getModule();
Value *NewCallee =
M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(),
- Op->getType(), B.getInt32Ty(), nullptr);
+ Op->getType(), B.getInt32Ty());
CallInst *CI = B.CreateCall(NewCallee, {One, LdExpArg});
if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts()))
CI->setCallingConv(F->getCallingConv());
@@ -1208,15 +1230,6 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {
return Ret;
}
-Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
- Function *Callee = CI->getCalledFunction();
- StringRef Name = Callee->getName();
- if (Name == "fabs" && hasFloatVersion(Name))
- return optimizeUnaryDoubleFP(CI, B, false);
-
- return nullptr;
-}
-
Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
// If we can shrink the call to a float function rather than a double
@@ -1280,17 +1293,17 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
FMF.setUnsafeAlgebra();
B.setFastMathFlags(FMF);
- LibFunc::Func Func;
+ LibFunc Func;
Function *F = OpC->getCalledFunction();
if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) &&
- Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow))
+ Func == LibFunc_pow) || F->getIntrinsicID() == Intrinsic::pow))
return B.CreateFMul(OpC->getArgOperand(1),
emitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,
Callee->getAttributes()), "mul");
// log(exp2(y)) -> y*log(2)
if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) &&
- TLI->has(Func) && Func == LibFunc::exp2)
+ TLI->has(Func) && Func == LibFunc_exp2)
return B.CreateFMul(
OpC->getArgOperand(0),
emitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0),
@@ -1302,8 +1315,11 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
- if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" ||
- Callee->getIntrinsicID() == Intrinsic::sqrt))
+  // TODO: Once we have a way (other than checking for the existence of the
+ // libcall) to tell whether our target can lower @llvm.sqrt, relax the
+ // condition below.
+ if (TLI->has(LibFunc_sqrtf) && (Callee->getName() == "sqrt" ||
+ Callee->getIntrinsicID() == Intrinsic::sqrt))
Ret = optimizeUnaryDoubleFP(CI, B, true);
if (!CI->hasUnsafeAlgebra())
@@ -1385,12 +1401,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) {
// tan(atan(x)) -> x
// tanf(atanf(x)) -> x
// tanl(atanl(x)) -> x
- LibFunc::Func Func;
+ LibFunc Func;
Function *F = OpC->getCalledFunction();
if (F && TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) &&
- ((Func == LibFunc::atan && Callee->getName() == "tan") ||
- (Func == LibFunc::atanf && Callee->getName() == "tanf") ||
- (Func == LibFunc::atanl && Callee->getName() == "tanl")))
+ ((Func == LibFunc_atan && Callee->getName() == "tan") ||
+ (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
+ (Func == LibFunc_atanl && Callee->getName() == "tanl")))
Ret = OpC->getArgOperand(0);
return Ret;
}
@@ -1427,7 +1443,7 @@ static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
Module *M = OrigCallee->getParent();
Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(),
- ResTy, ArgTy, nullptr);
+ ResTy, ArgTy);
if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) {
// If the argument is an instruction, it must dominate all uses so put our
@@ -1508,24 +1524,24 @@ void LibCallSimplifier::classifyArgUse(
return;
Function *Callee = CI->getCalledFunction();
- LibFunc::Func Func;
+ LibFunc Func;
if (!Callee || !TLI->getLibFunc(*Callee, Func) || !TLI->has(Func) ||
!isTrigLibCall(CI))
return;
if (IsFloat) {
- if (Func == LibFunc::sinpif)
+ if (Func == LibFunc_sinpif)
SinCalls.push_back(CI);
- else if (Func == LibFunc::cospif)
+ else if (Func == LibFunc_cospif)
CosCalls.push_back(CI);
- else if (Func == LibFunc::sincospif_stret)
+ else if (Func == LibFunc_sincospif_stret)
SinCosCalls.push_back(CI);
} else {
- if (Func == LibFunc::sinpi)
+ if (Func == LibFunc_sinpi)
SinCalls.push_back(CI);
- else if (Func == LibFunc::cospi)
+ else if (Func == LibFunc_cospi)
CosCalls.push_back(CI);
- else if (Func == LibFunc::sincospi_stret)
+ else if (Func == LibFunc_sincospi_stret)
SinCosCalls.push_back(CI);
}
}
@@ -1609,7 +1625,7 @@ Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B,
// Proceedings of PACT'98, Oct. 1998, IEEE
if (!CI->hasFnAttr(Attribute::Cold) &&
isReportingError(Callee, CI, StreamArg)) {
- CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold);
+ CI->addAttribute(AttributeList::FunctionIndex, Attribute::Cold);
}
return nullptr;
@@ -1699,7 +1715,7 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) {
// printf(format, ...) -> iprintf(format, ...) if no floating point
// arguments.
- if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) {
+ if (TLI->has(LibFunc_iprintf) && !callHasFloatingPointArgument(CI)) {
Module *M = B.GetInsertBlock()->getParent()->getParent();
Constant *IPrintFFn =
M->getOrInsertFunction("iprintf", FT, Callee->getAttributes());
@@ -1780,7 +1796,7 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) {
// sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating
// point arguments.
- if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) {
+ if (TLI->has(LibFunc_siprintf) && !callHasFloatingPointArgument(CI)) {
Module *M = B.GetInsertBlock()->getParent()->getParent();
Constant *SIPrintFFn =
M->getOrInsertFunction("siprintf", FT, Callee->getAttributes());
@@ -1850,7 +1866,7 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) {
// fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no
// floating point arguments.
- if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) {
+ if (TLI->has(LibFunc_fiprintf) && !callHasFloatingPointArgument(CI)) {
Module *M = B.GetInsertBlock()->getParent()->getParent();
Constant *FIPrintFFn =
M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes());
@@ -1929,7 +1945,7 @@ Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) {
}
bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) {
- LibFunc::Func Func;
+ LibFunc Func;
SmallString<20> FloatFuncName = FuncName;
FloatFuncName += 'f';
if (TLI->getLibFunc(FloatFuncName, Func))
@@ -1939,7 +1955,7 @@ bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) {
Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,
IRBuilder<> &Builder) {
- LibFunc::Func Func;
+ LibFunc Func;
Function *Callee = CI->getCalledFunction();
// Check for string/memory library functions.
if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) {
@@ -1948,51 +1964,51 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,
isCallingConvCCompatible(CI)) &&
"Optimizing string/memory libcall would change the calling convention");
switch (Func) {
- case LibFunc::strcat:
+ case LibFunc_strcat:
return optimizeStrCat(CI, Builder);
- case LibFunc::strncat:
+ case LibFunc_strncat:
return optimizeStrNCat(CI, Builder);
- case LibFunc::strchr:
+ case LibFunc_strchr:
return optimizeStrChr(CI, Builder);
- case LibFunc::strrchr:
+ case LibFunc_strrchr:
return optimizeStrRChr(CI, Builder);
- case LibFunc::strcmp:
+ case LibFunc_strcmp:
return optimizeStrCmp(CI, Builder);
- case LibFunc::strncmp:
+ case LibFunc_strncmp:
return optimizeStrNCmp(CI, Builder);
- case LibFunc::strcpy:
+ case LibFunc_strcpy:
return optimizeStrCpy(CI, Builder);
- case LibFunc::stpcpy:
+ case LibFunc_stpcpy:
return optimizeStpCpy(CI, Builder);
- case LibFunc::strncpy:
+ case LibFunc_strncpy:
return optimizeStrNCpy(CI, Builder);
- case LibFunc::strlen:
+ case LibFunc_strlen:
return optimizeStrLen(CI, Builder);
- case LibFunc::strpbrk:
+ case LibFunc_strpbrk:
return optimizeStrPBrk(CI, Builder);
- case LibFunc::strtol:
- case LibFunc::strtod:
- case LibFunc::strtof:
- case LibFunc::strtoul:
- case LibFunc::strtoll:
- case LibFunc::strtold:
- case LibFunc::strtoull:
+ case LibFunc_strtol:
+ case LibFunc_strtod:
+ case LibFunc_strtof:
+ case LibFunc_strtoul:
+ case LibFunc_strtoll:
+ case LibFunc_strtold:
+ case LibFunc_strtoull:
return optimizeStrTo(CI, Builder);
- case LibFunc::strspn:
+ case LibFunc_strspn:
return optimizeStrSpn(CI, Builder);
- case LibFunc::strcspn:
+ case LibFunc_strcspn:
return optimizeStrCSpn(CI, Builder);
- case LibFunc::strstr:
+ case LibFunc_strstr:
return optimizeStrStr(CI, Builder);
- case LibFunc::memchr:
+ case LibFunc_memchr:
return optimizeMemChr(CI, Builder);
- case LibFunc::memcmp:
+ case LibFunc_memcmp:
return optimizeMemCmp(CI, Builder);
- case LibFunc::memcpy:
+ case LibFunc_memcpy:
return optimizeMemCpy(CI, Builder);
- case LibFunc::memmove:
+ case LibFunc_memmove:
return optimizeMemMove(CI, Builder);
- case LibFunc::memset:
+ case LibFunc_memset:
return optimizeMemSet(CI, Builder);
default:
break;
@@ -2005,7 +2021,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
if (CI->isNoBuiltin())
return nullptr;
- LibFunc::Func Func;
+ LibFunc Func;
Function *Callee = CI->getCalledFunction();
StringRef FuncName = Callee->getName();
@@ -2029,8 +2045,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
return optimizePow(CI, Builder);
case Intrinsic::exp2:
return optimizeExp2(CI, Builder);
- case Intrinsic::fabs:
- return optimizeFabs(CI, Builder);
case Intrinsic::log:
return optimizeLog(CI, Builder);
case Intrinsic::sqrt:
@@ -2067,114 +2081,117 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
if (Value *V = optimizeStringMemoryLibCall(CI, Builder))
return V;
switch (Func) {
- case LibFunc::cosf:
- case LibFunc::cos:
- case LibFunc::cosl:
+ case LibFunc_cosf:
+ case LibFunc_cos:
+ case LibFunc_cosl:
return optimizeCos(CI, Builder);
- case LibFunc::sinpif:
- case LibFunc::sinpi:
- case LibFunc::cospif:
- case LibFunc::cospi:
+ case LibFunc_sinpif:
+ case LibFunc_sinpi:
+ case LibFunc_cospif:
+ case LibFunc_cospi:
return optimizeSinCosPi(CI, Builder);
- case LibFunc::powf:
- case LibFunc::pow:
- case LibFunc::powl:
+ case LibFunc_powf:
+ case LibFunc_pow:
+ case LibFunc_powl:
return optimizePow(CI, Builder);
- case LibFunc::exp2l:
- case LibFunc::exp2:
- case LibFunc::exp2f:
+ case LibFunc_exp2l:
+ case LibFunc_exp2:
+ case LibFunc_exp2f:
return optimizeExp2(CI, Builder);
- case LibFunc::fabsf:
- case LibFunc::fabs:
- case LibFunc::fabsl:
- return optimizeFabs(CI, Builder);
- case LibFunc::sqrtf:
- case LibFunc::sqrt:
- case LibFunc::sqrtl:
+ case LibFunc_fabsf:
+ case LibFunc_fabs:
+ case LibFunc_fabsl:
+ return replaceUnaryCall(CI, Builder, Intrinsic::fabs);
+ case LibFunc_sqrtf:
+ case LibFunc_sqrt:
+ case LibFunc_sqrtl:
return optimizeSqrt(CI, Builder);
- case LibFunc::ffs:
- case LibFunc::ffsl:
- case LibFunc::ffsll:
+ case LibFunc_ffs:
+ case LibFunc_ffsl:
+ case LibFunc_ffsll:
return optimizeFFS(CI, Builder);
- case LibFunc::fls:
- case LibFunc::flsl:
- case LibFunc::flsll:
+ case LibFunc_fls:
+ case LibFunc_flsl:
+ case LibFunc_flsll:
return optimizeFls(CI, Builder);
- case LibFunc::abs:
- case LibFunc::labs:
- case LibFunc::llabs:
+ case LibFunc_abs:
+ case LibFunc_labs:
+ case LibFunc_llabs:
return optimizeAbs(CI, Builder);
- case LibFunc::isdigit:
+ case LibFunc_isdigit:
return optimizeIsDigit(CI, Builder);
- case LibFunc::isascii:
+ case LibFunc_isascii:
return optimizeIsAscii(CI, Builder);
- case LibFunc::toascii:
+ case LibFunc_toascii:
return optimizeToAscii(CI, Builder);
- case LibFunc::printf:
+ case LibFunc_printf:
return optimizePrintF(CI, Builder);
- case LibFunc::sprintf:
+ case LibFunc_sprintf:
return optimizeSPrintF(CI, Builder);
- case LibFunc::fprintf:
+ case LibFunc_fprintf:
return optimizeFPrintF(CI, Builder);
- case LibFunc::fwrite:
+ case LibFunc_fwrite:
return optimizeFWrite(CI, Builder);
- case LibFunc::fputs:
+ case LibFunc_fputs:
return optimizeFPuts(CI, Builder);
- case LibFunc::log:
- case LibFunc::log10:
- case LibFunc::log1p:
- case LibFunc::log2:
- case LibFunc::logb:
+ case LibFunc_log:
+ case LibFunc_log10:
+ case LibFunc_log1p:
+ case LibFunc_log2:
+ case LibFunc_logb:
return optimizeLog(CI, Builder);
- case LibFunc::puts:
+ case LibFunc_puts:
return optimizePuts(CI, Builder);
- case LibFunc::tan:
- case LibFunc::tanf:
- case LibFunc::tanl:
+ case LibFunc_tan:
+ case LibFunc_tanf:
+ case LibFunc_tanl:
return optimizeTan(CI, Builder);
- case LibFunc::perror:
+ case LibFunc_perror:
return optimizeErrorReporting(CI, Builder);
- case LibFunc::vfprintf:
- case LibFunc::fiprintf:
+ case LibFunc_vfprintf:
+ case LibFunc_fiprintf:
return optimizeErrorReporting(CI, Builder, 0);
- case LibFunc::fputc:
+ case LibFunc_fputc:
return optimizeErrorReporting(CI, Builder, 1);
- case LibFunc::ceil:
- case LibFunc::floor:
- case LibFunc::rint:
- case LibFunc::round:
- case LibFunc::nearbyint:
- case LibFunc::trunc:
- if (hasFloatVersion(FuncName))
- return optimizeUnaryDoubleFP(CI, Builder, false);
- return nullptr;
- case LibFunc::acos:
- case LibFunc::acosh:
- case LibFunc::asin:
- case LibFunc::asinh:
- case LibFunc::atan:
- case LibFunc::atanh:
- case LibFunc::cbrt:
- case LibFunc::cosh:
- case LibFunc::exp:
- case LibFunc::exp10:
- case LibFunc::expm1:
- case LibFunc::sin:
- case LibFunc::sinh:
- case LibFunc::tanh:
+ case LibFunc_ceil:
+ return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
+ case LibFunc_floor:
+ return replaceUnaryCall(CI, Builder, Intrinsic::floor);
+ case LibFunc_round:
+ return replaceUnaryCall(CI, Builder, Intrinsic::round);
+ case LibFunc_nearbyint:
+ return replaceUnaryCall(CI, Builder, Intrinsic::nearbyint);
+ case LibFunc_rint:
+ return replaceUnaryCall(CI, Builder, Intrinsic::rint);
+ case LibFunc_trunc:
+ return replaceUnaryCall(CI, Builder, Intrinsic::trunc);
+ case LibFunc_acos:
+ case LibFunc_acosh:
+ case LibFunc_asin:
+ case LibFunc_asinh:
+ case LibFunc_atan:
+ case LibFunc_atanh:
+ case LibFunc_cbrt:
+ case LibFunc_cosh:
+ case LibFunc_exp:
+ case LibFunc_exp10:
+ case LibFunc_expm1:
+ case LibFunc_sin:
+ case LibFunc_sinh:
+ case LibFunc_tanh:
if (UnsafeFPShrink && hasFloatVersion(FuncName))
return optimizeUnaryDoubleFP(CI, Builder, true);
return nullptr;
- case LibFunc::copysign:
+ case LibFunc_copysign:
if (hasFloatVersion(FuncName))
return optimizeBinaryDoubleFP(CI, Builder);
return nullptr;
- case LibFunc::fminf:
- case LibFunc::fmin:
- case LibFunc::fminl:
- case LibFunc::fmaxf:
- case LibFunc::fmax:
- case LibFunc::fmaxl:
+ case LibFunc_fminf:
+ case LibFunc_fmin:
+ case LibFunc_fminl:
+ case LibFunc_fmaxf:
+ case LibFunc_fmax:
+ case LibFunc_fmaxl:
return optimizeFMinFMax(CI, Builder);
default:
return nullptr;
@@ -2211,16 +2228,10 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {
// * log(exp10(y)) -> y*log(10)
// * log(sqrt(x)) -> 0.5*log(x)
//
-// lround, lroundf, lroundl:
-// * lround(cnst) -> cnst'
-//
// pow, powf, powl:
// * pow(sqrt(x),y) -> pow(x,y*0.5)
// * pow(pow(x,y),z)-> pow(x,y*z)
//
-// round, roundf, roundl:
-// * round(cnst) -> cnst'
-//
// signbit:
// * signbit(cnst) -> cnst'
// * signbit(nncst) -> 0 (if pstv is a non-negative constant)
@@ -2230,10 +2241,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {
// * sqrt(Nroot(x)) -> pow(x,1/(2*N))
// * sqrt(pow(x,y)) -> pow(|x|,y*0.5)
//
-// trunc, truncf, truncl:
-// * trunc(cnst) -> cnst'
-//
-//
//===----------------------------------------------------------------------===//
// Fortified Library Call Optimizations
@@ -2300,7 +2307,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI,
Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,
IRBuilder<> &B,
- LibFunc::Func Func) {
+ LibFunc Func) {
Function *Callee = CI->getCalledFunction();
StringRef Name = Callee->getName();
const DataLayout &DL = CI->getModule()->getDataLayout();
@@ -2308,7 +2315,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,
*ObjSize = CI->getArgOperand(2);
// __stpcpy_chk(x,x,...) -> x+strlen(x)
- if (Func == LibFunc::stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) {
+ if (Func == LibFunc_stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) {
Value *StrLen = emitStrLen(Src, B, DL, TLI);
return StrLen ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr;
}
@@ -2334,14 +2341,14 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,
Value *Ret = emitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI);
// If the function was an __stpcpy_chk, and we were able to fold it into
// a __memcpy_chk, we still need to return the correct end pointer.
- if (Ret && Func == LibFunc::stpcpy_chk)
+ if (Ret && Func == LibFunc_stpcpy_chk)
return B.CreateGEP(B.getInt8Ty(), Dst, ConstantInt::get(SizeTTy, Len - 1));
return Ret;
}
Value *FortifiedLibCallSimplifier::optimizeStrpNCpyChk(CallInst *CI,
IRBuilder<> &B,
- LibFunc::Func Func) {
+ LibFunc Func) {
Function *Callee = CI->getCalledFunction();
StringRef Name = Callee->getName();
if (isFortifiedCallFoldable(CI, 3, 2, false)) {
@@ -2366,7 +2373,7 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) {
//
// PR23093.
- LibFunc::Func Func;
+ LibFunc Func;
Function *Callee = CI->getCalledFunction();
SmallVector<OperandBundleDef, 2> OpBundles;
@@ -2384,17 +2391,17 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) {
return nullptr;
switch (Func) {
- case LibFunc::memcpy_chk:
+ case LibFunc_memcpy_chk:
return optimizeMemCpyChk(CI, Builder);
- case LibFunc::memmove_chk:
+ case LibFunc_memmove_chk:
return optimizeMemMoveChk(CI, Builder);
- case LibFunc::memset_chk:
+ case LibFunc_memset_chk:
return optimizeMemSetChk(CI, Builder);
- case LibFunc::stpcpy_chk:
- case LibFunc::strcpy_chk:
+ case LibFunc_stpcpy_chk:
+ case LibFunc_strcpy_chk:
return optimizeStrpCpyChk(CI, Builder, Func);
- case LibFunc::stpncpy_chk:
- case LibFunc::strncpy_chk:
+ case LibFunc_stpncpy_chk:
+ case LibFunc_strncpy_chk:
return optimizeStrpNCpyChk(CI, Builder, Func);
default:
break;
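
The rename applied throughout this file swaps the nested LibFunc::Func enumerators (LibFunc::copysign, LibFunc::fmin, ...) for a single top-level LibFunc enum whose enumerators carry a LibFunc_ prefix, and the helper signatures now take LibFunc directly. A minimal shape-only sketch of that convention, with simplified names (the real enum lives in TargetLibraryInfo.h and has far more entries):

// New style: one top-level enum, enumerators prefixed with LibFunc_.
enum LibFunc { LibFunc_copysign, LibFunc_fmin, LibFunc_fmax, NumLibFuncs };

// Helpers can now take the enum type itself instead of LibFunc::Func.
static const char *libFuncName(LibFunc F) {
  switch (F) {
  case LibFunc_copysign: return "copysign";
  case LibFunc_fmin:     return "fmin";
  case LibFunc_fmax:     return "fmax";
  default:               return "<other>";
  }
}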
diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp
index 7b9de2eadc611..7106483c3bd2a 100644
--- a/lib/Transforms/Utils/Utils.cpp
+++ b/lib/Transforms/Utils/Utils.cpp
@@ -35,9 +35,8 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) {
initializeUnifyFunctionExitNodesPass(Registry);
initializeInstSimplifierPass(Registry);
initializeMetaRenamerPass(Registry);
- initializeMemorySSAWrapperPassPass(Registry);
- initializeMemorySSAPrinterLegacyPassPass(Registry);
initializeStripGCRelocatesPass(Registry);
+ initializePredicateInfoPrinterLegacyPassPass(Registry);
}
/// LLVMInitializeTransformUtils - C binding for initializeTransformUtilsPasses.
diff --git a/lib/Transforms/Utils/VNCoercion.cpp b/lib/Transforms/Utils/VNCoercion.cpp
new file mode 100644
index 0000000000000..4aeea02b1b1bf
--- /dev/null
+++ b/lib/Transforms/Utils/VNCoercion.cpp
@@ -0,0 +1,482 @@
+#include "llvm/Transforms/Utils/VNCoercion.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "vncoerce"
+namespace llvm {
+namespace VNCoercion {
+
+/// Return true if coerceAvailableValueToLoadType will succeed.
+bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy,
+ const DataLayout &DL) {
+ // If the loaded or stored value is a first class array or struct, don't try
+ // to transform them. We need to be able to bitcast to integer.
+ if (LoadTy->isStructTy() || LoadTy->isArrayTy() ||
+ StoredVal->getType()->isStructTy() || StoredVal->getType()->isArrayTy())
+ return false;
+
+ // The store has to be at least as big as the load.
+ if (DL.getTypeSizeInBits(StoredVal->getType()) < DL.getTypeSizeInBits(LoadTy))
+ return false;
+
+ return true;
+}
+
+template <class T, class HelperClass>
+static T *coerceAvailableValueToLoadTypeHelper(T *StoredVal, Type *LoadedTy,
+ HelperClass &Helper,
+ const DataLayout &DL) {
+ assert(canCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) &&
+ "precondition violation - materialization can't fail");
+ if (auto *C = dyn_cast<Constant>(StoredVal))
+ if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL))
+ StoredVal = FoldedStoredVal;
+
+ // If this is already the right type, just return it.
+ Type *StoredValTy = StoredVal->getType();
+
+ uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy);
+ uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy);
+
+ // If the store and reload are the same size, we can always reuse it.
+ if (StoredValSize == LoadedValSize) {
+ // Pointer to Pointer -> use bitcast.
+ if (StoredValTy->getScalarType()->isPointerTy() &&
+ LoadedTy->getScalarType()->isPointerTy()) {
+ StoredVal = Helper.CreateBitCast(StoredVal, LoadedTy);
+ } else {
+ // Convert source pointers to integers, which can be bitcast.
+ if (StoredValTy->getScalarType()->isPointerTy()) {
+ StoredValTy = DL.getIntPtrType(StoredValTy);
+ StoredVal = Helper.CreatePtrToInt(StoredVal, StoredValTy);
+ }
+
+ Type *TypeToCastTo = LoadedTy;
+ if (TypeToCastTo->getScalarType()->isPointerTy())
+ TypeToCastTo = DL.getIntPtrType(TypeToCastTo);
+
+ if (StoredValTy != TypeToCastTo)
+ StoredVal = Helper.CreateBitCast(StoredVal, TypeToCastTo);
+
+ // Cast to pointer if the load needs a pointer type.
+ if (LoadedTy->getScalarType()->isPointerTy())
+ StoredVal = Helper.CreateIntToPtr(StoredVal, LoadedTy);
+ }
+
+ if (auto *C = dyn_cast<ConstantExpr>(StoredVal))
+ if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL))
+ StoredVal = FoldedStoredVal;
+
+ return StoredVal;
+ }
+ // If the loaded value is smaller than the available value, then we can
+ // extract out a piece from it. If the available value is too small, then we
+ // can't do anything.
+ assert(StoredValSize >= LoadedValSize &&
+ "canCoerceMustAliasedValueToLoad fail");
+
+ // Convert source pointers to integers, which can be manipulated.
+ if (StoredValTy->getScalarType()->isPointerTy()) {
+ StoredValTy = DL.getIntPtrType(StoredValTy);
+ StoredVal = Helper.CreatePtrToInt(StoredVal, StoredValTy);
+ }
+
+ // Convert vectors and fp to integer, which can be manipulated.
+ if (!StoredValTy->isIntegerTy()) {
+ StoredValTy = IntegerType::get(StoredValTy->getContext(), StoredValSize);
+ StoredVal = Helper.CreateBitCast(StoredVal, StoredValTy);
+ }
+
+ // If this is a big-endian system, we need to shift the value down to the low
+ // bits so that a truncate will work.
+ if (DL.isBigEndian()) {
+ uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy) -
+ DL.getTypeStoreSizeInBits(LoadedTy);
+ StoredVal = Helper.CreateLShr(
+ StoredVal, ConstantInt::get(StoredVal->getType(), ShiftAmt));
+ }
+
+ // Truncate the integer to the right size now.
+ Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadedValSize);
+ StoredVal = Helper.CreateTruncOrBitCast(StoredVal, NewIntTy);
+
+ if (LoadedTy != NewIntTy) {
+ // If the result is a pointer, inttoptr.
+ if (LoadedTy->getScalarType()->isPointerTy())
+ StoredVal = Helper.CreateIntToPtr(StoredVal, LoadedTy);
+ else
+ // Otherwise, bitcast.
+ StoredVal = Helper.CreateBitCast(StoredVal, LoadedTy);
+ }
+
+ if (auto *C = dyn_cast<Constant>(StoredVal))
+ if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL))
+ StoredVal = FoldedStoredVal;
+
+ return StoredVal;
+}
+
+/// If we saw a store of a value to memory, and
+/// then a load from a must-aliased pointer of a different type, try to coerce
+/// the stored value. LoadedTy is the type of the load we want to replace.
+/// IRB is IRBuilder used to insert new instructions.
+///
+/// If we can't do it, return null.
+Value *coerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy,
+ IRBuilder<> &IRB, const DataLayout &DL) {
+ return coerceAvailableValueToLoadTypeHelper(StoredVal, LoadedTy, IRB, DL);
+}
+
+/// This function is called when we have a memdep query of a load that ends up
+/// being a clobbering memory write (store, memset, memcpy, memmove). This
+/// means that the write *may* provide bits used by the load but we can't be
+/// sure because the pointers don't must-alias.
+///
+/// Check this case to see if there is anything more we can do before we give
+/// up. This returns -1 if we have to give up, or a byte number in the stored
+/// value of the piece that feeds the load.
+static int analyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr,
+ Value *WritePtr,
+ uint64_t WriteSizeInBits,
+ const DataLayout &DL) {
+ // If the loaded or stored value is a first class array or struct, don't try
+ // to transform them. We need to be able to bitcast to integer.
+ if (LoadTy->isStructTy() || LoadTy->isArrayTy())
+ return -1;
+
+ int64_t StoreOffset = 0, LoadOffset = 0;
+ Value *StoreBase =
+ GetPointerBaseWithConstantOffset(WritePtr, StoreOffset, DL);
+ Value *LoadBase = GetPointerBaseWithConstantOffset(LoadPtr, LoadOffset, DL);
+ if (StoreBase != LoadBase)
+ return -1;
+
+ // If the load and store are to the exact same address, they should have been
+ // a must alias. AA must have gotten confused.
+ // FIXME: Study to see if/when this happens. One case is forwarding a memset
+ // to a load from the base of the memset.
+
+ // If the load and store don't overlap at all, the store doesn't provide
+ // anything to the load. In this case, they really don't alias at all, AA
+ // must have gotten confused.
+ uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy);
+
+ if ((WriteSizeInBits & 7) | (LoadSize & 7))
+ return -1;
+ uint64_t StoreSize = WriteSizeInBits / 8; // Convert to bytes.
+ LoadSize /= 8;
+
+ bool isAAFailure = false;
+ if (StoreOffset < LoadOffset)
+ isAAFailure = StoreOffset + int64_t(StoreSize) <= LoadOffset;
+ else
+ isAAFailure = LoadOffset + int64_t(LoadSize) <= StoreOffset;
+
+ if (isAAFailure)
+ return -1;
+
+ // If the Load isn't completely contained within the stored bits, we don't
+ // have all the bits to feed it. We could do something crazy in the future
+ // (issue a smaller load and then merge the bits in), but this seems unlikely
+ // to be valuable.
+ if (StoreOffset > LoadOffset ||
+ StoreOffset + StoreSize < LoadOffset + LoadSize)
+ return -1;
+
+ // Okay, we can do this transformation. Return the number of bytes into the
+ // store that the load is.
+ return LoadOffset - StoreOffset;
+}
+
+/// This function is called when we have a
+/// memdep query of a load that ends up being a clobbering store.
+int analyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr,
+ StoreInst *DepSI, const DataLayout &DL) {
+ // Cannot handle reading from store of first-class aggregate yet.
+ if (DepSI->getValueOperand()->getType()->isStructTy() ||
+ DepSI->getValueOperand()->getType()->isArrayTy())
+ return -1;
+
+ Value *StorePtr = DepSI->getPointerOperand();
+ uint64_t StoreSize =
+ DL.getTypeSizeInBits(DepSI->getValueOperand()->getType());
+ return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, StorePtr, StoreSize,
+ DL);
+}
+
+/// This function is called when we have a
+/// memdep query of a load that ends up being clobbered by another load. See if
+/// the other load can feed into the second load.
+int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI,
+ const DataLayout &DL) {
+ // Cannot handle reading from store of first-class aggregate yet.
+ if (DepLI->getType()->isStructTy() || DepLI->getType()->isArrayTy())
+ return -1;
+
+ Value *DepPtr = DepLI->getPointerOperand();
+ uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType());
+ int R = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL);
+ if (R != -1)
+ return R;
+
+ // If we have a load/load clobber and DepLI can be widened to cover this
+ // load, then we should widen it!
+ int64_t LoadOffs = 0;
+ const Value *LoadBase =
+ GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL);
+ unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
+
+ unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize(
+ LoadBase, LoadOffs, LoadSize, DepLI);
+ if (Size == 0)
+ return -1;
+
+ // Check non-obvious conditions enforced by MDA, which we rely on for being
+ // able to materialize this potentially available value.
+ assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!");
+ assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load");
+
+ return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size * 8, DL);
+}
+
+int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr,
+ MemIntrinsic *MI, const DataLayout &DL) {
+ // If the mem operation is a non-constant size, we can't handle it.
+ ConstantInt *SizeCst = dyn_cast<ConstantInt>(MI->getLength());
+ if (!SizeCst)
+ return -1;
+ uint64_t MemSizeInBits = SizeCst->getZExtValue() * 8;
+
+ // If this is a memset, we just need to see if the offset is valid within the
+ // size of the memset.
+ if (MI->getIntrinsicID() == Intrinsic::memset)
+ return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(),
+ MemSizeInBits, DL);
+
+ // If we have a memcpy/memmove, the only case we can handle is if this is a
+ // copy from constant memory. In that case, we can read directly from the
+ // constant memory.
+ MemTransferInst *MTI = cast<MemTransferInst>(MI);
+
+ Constant *Src = dyn_cast<Constant>(MTI->getSource());
+ if (!Src)
+ return -1;
+
+ GlobalVariable *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(Src, DL));
+ if (!GV || !GV->isConstant())
+ return -1;
+
+ // See if the access is within the bounds of the transfer.
+ int Offset = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(),
+ MemSizeInBits, DL);
+ if (Offset == -1)
+ return Offset;
+
+ unsigned AS = Src->getType()->getPointerAddressSpace();
+ // Otherwise, see if we can constant fold a load from the constant with the
+ // offset applied as appropriate.
+ Src =
+ ConstantExpr::getBitCast(Src, Type::getInt8PtrTy(Src->getContext(), AS));
+ Constant *OffsetCst =
+ ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset);
+ Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src,
+ OffsetCst);
+ Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS));
+ if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL))
+ return Offset;
+ return -1;
+}
+
+template <class T, class HelperClass>
+static T *getStoreValueForLoadHelper(T *SrcVal, unsigned Offset, Type *LoadTy,
+ HelperClass &Helper,
+ const DataLayout &DL) {
+ LLVMContext &Ctx = SrcVal->getType()->getContext();
+
+ uint64_t StoreSize = (DL.getTypeSizeInBits(SrcVal->getType()) + 7) / 8;
+ uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy) + 7) / 8;
+ // Compute which bits of the stored value are being used by the load. Convert
+ // to an integer type to start with.
+ if (SrcVal->getType()->getScalarType()->isPointerTy())
+ SrcVal = Helper.CreatePtrToInt(SrcVal, DL.getIntPtrType(SrcVal->getType()));
+ if (!SrcVal->getType()->isIntegerTy())
+ SrcVal = Helper.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize * 8));
+
+ // Shift the bits to the least significant depending on endianness.
+ unsigned ShiftAmt;
+ if (DL.isLittleEndian())
+ ShiftAmt = Offset * 8;
+ else
+ ShiftAmt = (StoreSize - LoadSize - Offset) * 8;
+ if (ShiftAmt)
+ SrcVal = Helper.CreateLShr(SrcVal,
+ ConstantInt::get(SrcVal->getType(), ShiftAmt));
+
+ if (LoadSize != StoreSize)
+ SrcVal = Helper.CreateTruncOrBitCast(SrcVal,
+ IntegerType::get(Ctx, LoadSize * 8));
+ return SrcVal;
+}
+
+/// This function is called when we have a memdep query of a load that ends up
+/// being a clobbering store. This means that the store provides bits used by
+/// the load but the pointers don't must-alias. Check this case to see if
+/// there is anything more we can do before we give up.
+Value *getStoreValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy,
+ Instruction *InsertPt, const DataLayout &DL) {
+
+ IRBuilder<> Builder(InsertPt);
+ SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, Builder, DL);
+ return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, Builder, DL);
+}
+
+Constant *getConstantStoreValueForLoad(Constant *SrcVal, unsigned Offset,
+ Type *LoadTy, const DataLayout &DL) {
+ ConstantFolder F;
+ SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, F, DL);
+ return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, F, DL);
+}
+
+/// This function is called when we have a memdep query of a load that ends up
+/// being a clobbering load. This means that the load *may* provide bits used
+/// by the load but we can't be sure because the pointers don't must-alias.
+/// Check this case to see if there is anything more we can do before we give
+/// up.
+Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy,
+ Instruction *InsertPt, const DataLayout &DL) {
+ // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to
+ // widen SrcVal out to a larger load.
+ unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType());
+ unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
+ if (Offset + LoadSize > SrcValStoreSize) {
+ assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!");
+ assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load");
+ // If we have a load/load clobber and DepLI can be widened to cover this
+ // load, then we should widen it to the next power-of-2 size big enough!
+ unsigned NewLoadSize = Offset + LoadSize;
+ if (!isPowerOf2_32(NewLoadSize))
+ NewLoadSize = NextPowerOf2(NewLoadSize);
+
+ Value *PtrVal = SrcVal->getPointerOperand();
+ // Insert the new load after the old load. This ensures that subsequent
+ // memdep queries will find the new load. We can't easily remove the old
+ // load completely because it is already in the value numbering table.
+ IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal));
+ Type *DestPTy = IntegerType::get(LoadTy->getContext(), NewLoadSize * 8);
+ DestPTy =
+ PointerType::get(DestPTy, PtrVal->getType()->getPointerAddressSpace());
+ Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc());
+ PtrVal = Builder.CreateBitCast(PtrVal, DestPTy);
+ LoadInst *NewLoad = Builder.CreateLoad(PtrVal);
+ NewLoad->takeName(SrcVal);
+ NewLoad->setAlignment(SrcVal->getAlignment());
+
+ DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n");
+ DEBUG(dbgs() << "TO: " << *NewLoad << "\n");
+
+ // Replace uses of the original load with the wider load. On a big endian
+ // system, we need to shift down to get the relevant bits.
+ Value *RV = NewLoad;
+ if (DL.isBigEndian())
+ RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8);
+ RV = Builder.CreateTrunc(RV, SrcVal->getType());
+ SrcVal->replaceAllUsesWith(RV);
+
+ SrcVal = NewLoad;
+ }
+
+ return getStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL);
+}
+
+Constant *getConstantLoadValueForLoad(Constant *SrcVal, unsigned Offset,
+ Type *LoadTy, const DataLayout &DL) {
+ unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType());
+ unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
+ if (Offset + LoadSize > SrcValStoreSize)
+ return nullptr;
+ return getConstantStoreValueForLoad(SrcVal, Offset, LoadTy, DL);
+}
+
+template <class T, class HelperClass>
+T *getMemInstValueForLoadHelper(MemIntrinsic *SrcInst, unsigned Offset,
+ Type *LoadTy, HelperClass &Helper,
+ const DataLayout &DL) {
+ LLVMContext &Ctx = LoadTy->getContext();
+ uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy) / 8;
+
+ // We know that this method is only called when the mem transfer fully
+ // provides the bits for the load.
+ if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) {
+ // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and
+ // independently of what the offset is.
+ T *Val = cast<T>(MSI->getValue());
+ if (LoadSize != 1)
+ Val =
+ Helper.CreateZExtOrBitCast(Val, IntegerType::get(Ctx, LoadSize * 8));
+ T *OneElt = Val;
+
+ // Splat the value out to the right number of bits.
+ for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize;) {
+ // If we can double the number of bytes set, do it.
+ if (NumBytesSet * 2 <= LoadSize) {
+ T *ShVal = Helper.CreateShl(
+ Val, ConstantInt::get(Val->getType(), NumBytesSet * 8));
+ Val = Helper.CreateOr(Val, ShVal);
+ NumBytesSet <<= 1;
+ continue;
+ }
+
+ // Otherwise insert one byte at a time.
+ T *ShVal = Helper.CreateShl(Val, ConstantInt::get(Val->getType(), 1 * 8));
+ Val = Helper.CreateOr(OneElt, ShVal);
+ ++NumBytesSet;
+ }
+
+ return coerceAvailableValueToLoadTypeHelper(Val, LoadTy, Helper, DL);
+ }
+
+ // Otherwise, this is a memcpy/memmove from a constant global.
+ MemTransferInst *MTI = cast<MemTransferInst>(SrcInst);
+ Constant *Src = cast<Constant>(MTI->getSource());
+ unsigned AS = Src->getType()->getPointerAddressSpace();
+
+ // Otherwise, see if we can constant fold a load from the constant with the
+ // offset applied as appropriate.
+ Src =
+ ConstantExpr::getBitCast(Src, Type::getInt8PtrTy(Src->getContext(), AS));
+ Constant *OffsetCst =
+ ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset);
+ Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src,
+ OffsetCst);
+ Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS));
+ return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL);
+}
+
+/// This function is called when we have a
+/// memdep query of a load that ends up being a clobbering mem intrinsic.
+Value *getMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset,
+ Type *LoadTy, Instruction *InsertPt,
+ const DataLayout &DL) {
+ IRBuilder<> Builder(InsertPt);
+ return getMemInstValueForLoadHelper<Value, IRBuilder<>>(SrcInst, Offset,
+ LoadTy, Builder, DL);
+}
+
+Constant *getConstantMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset,
+ Type *LoadTy, const DataLayout &DL) {
+ // The only case handled by analyzeLoadFromClobberingMemInst that cannot be
+ // converted to a constant is a memset of a non-constant value.
+ if (auto *MSI = dyn_cast<MemSetInst>(SrcInst))
+ if (!isa<Constant>(MSI->getValue()))
+ return nullptr;
+ ConstantFolder F;
+ return getMemInstValueForLoadHelper<Constant, ConstantFolder>(SrcInst, Offset,
+ LoadTy, F, DL);
+}
+} // namespace VNCoercion
+} // namespace llvm
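
The core of the new VNCoercion helpers is the byte-offset and shift arithmetic: analyzeLoadFromClobberingWrite reports how many bytes into the stored value the load begins, and getStoreValueForLoadHelper shifts those bytes to the least-significant end (the shift depends on endianness) before truncating. A minimal standalone sketch of just that arithmetic, with hypothetical helper names and no LLVM dependencies, assuming byte-granular offsets as in the code above:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Mirrors analyzeLoadFromClobberingWrite: returns the byte offset of the load
// inside the stored bytes, or -1 if the load is not fully covered.
static int offsetIntoStore(int64_t StoreOff, uint64_t StoreSize,
                           int64_t LoadOff, uint64_t LoadSize) {
  if (StoreOff > LoadOff ||
      StoreOff + (int64_t)StoreSize < LoadOff + (int64_t)LoadSize)
    return -1;
  return (int)(LoadOff - StoreOff);
}

// Mirrors the shift in getStoreValueForLoadHelper: bring the wanted bytes to
// the least-significant end before truncating.
static unsigned shiftAmountBits(bool LittleEndian, unsigned Offset,
                                unsigned StoreSize, unsigned LoadSize) {
  return LittleEndian ? Offset * 8 : (StoreSize - LoadSize - Offset) * 8;
}

int main() {
  // An 8-byte store at offset 0; a 2-byte load at offset 2 reads bytes [2,4).
  int Off = offsetIntoStore(0, 8, 2, 2);
  assert(Off == 2);
  printf("offset=%d LE-shift=%u BE-shift=%u\n", Off,
         shiftAmountBits(true, Off, 8, 2), shiftAmountBits(false, Off, 8, 2));
  return 0;
}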
diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp
index 0e9baaf8649d0..f77c10b6dd473 100644
--- a/lib/Transforms/Utils/ValueMapper.cpp
+++ b/lib/Transforms/Utils/ValueMapper.cpp
@@ -681,6 +681,7 @@ void MDNodeMapper::mapNodesInPOT(UniquedGraph &G) {
remapOperands(*ClonedN, [this, &D, &G](Metadata *Old) {
if (Optional<Metadata *> MappedOp = getMappedOp(Old))
return *MappedOp;
+ (void)D;
assert(G.Info[Old].ID > D.ID && "Expected a forward reference");
return &G.getFwdReference(*cast<MDNode>(Old));
});
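
The added (void)D is the usual idiom for a value that is referenced only inside an assert: when asserts compile away (NDEBUG builds), the capture would otherwise trigger an unused-variable warning. A tiny sketch of the idiom in isolation (hypothetical function, for illustration only):

#include <cassert>

static int clampedIndex(int Hint, int Max) {
  const int Limit = Max;  // Used only by the assert below.
  (void)Limit;            // Keeps release (NDEBUG) builds warning-free.
  assert(Hint < Limit && "index hint out of range");
  return Hint;
}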
diff --git a/lib/Transforms/Vectorize/BBVectorize.cpp b/lib/Transforms/Vectorize/BBVectorize.cpp
index c01740b27d59b..c83b3f7b225bc 100644
--- a/lib/Transforms/Vectorize/BBVectorize.cpp
+++ b/lib/Transforms/Vectorize/BBVectorize.cpp
@@ -494,13 +494,13 @@ namespace {
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
// For stores, it is the value type, not the pointer type that matters
// because the value is what will come from a vector register.
-
+
Value *IVal = SI->getValueOperand();
T1 = IVal->getType();
} else {
T1 = I->getType();
}
-
+
if (CastInst *CI = dyn_cast<CastInst>(I))
T2 = CI->getSrcTy();
else
@@ -547,10 +547,11 @@ namespace {
// Returns the cost of the provided instruction using TTI.
// This does not handle loads and stores.
unsigned getInstrCost(unsigned Opcode, Type *T1, Type *T2,
- TargetTransformInfo::OperandValueKind Op1VK =
+ TargetTransformInfo::OperandValueKind Op1VK =
TargetTransformInfo::OK_AnyValue,
TargetTransformInfo::OperandValueKind Op2VK =
- TargetTransformInfo::OK_AnyValue) {
+ TargetTransformInfo::OK_AnyValue,
+ const Instruction *I = nullptr) {
switch (Opcode) {
default: break;
case Instruction::GetElementPtr:
@@ -584,7 +585,7 @@ namespace {
case Instruction::Select:
case Instruction::ICmp:
case Instruction::FCmp:
- return TTI->getCmpSelInstrCost(Opcode, T1, T2);
+ return TTI->getCmpSelInstrCost(Opcode, T1, T2, I);
case Instruction::ZExt:
case Instruction::SExt:
case Instruction::FPToUI:
@@ -598,7 +599,7 @@ namespace {
case Instruction::FPTrunc:
case Instruction::BitCast:
case Instruction::ShuffleVector:
- return TTI->getCastInstrCost(Opcode, T1, T2);
+ return TTI->getCastInstrCost(Opcode, T1, T2, I);
}
return 1;
@@ -894,7 +895,7 @@ namespace {
// vectors that has a scalar condition results in a malformed select.
// FIXME: We could probably be smarter about this by rewriting the select
// with different types instead.
- return (SI->getCondition()->getType()->isVectorTy() ==
+ return (SI->getCondition()->getType()->isVectorTy() ==
SI->getTrueValue()->getType()->isVectorTy());
} else if (isa<CmpInst>(I)) {
if (!Config.VectorizeCmp)
@@ -1044,14 +1045,14 @@ namespace {
return false;
}
} else if (TTI) {
- unsigned ICost = getInstrCost(I->getOpcode(), IT1, IT2);
- unsigned JCost = getInstrCost(J->getOpcode(), JT1, JT2);
- Type *VT1 = getVecTypeForPair(IT1, JT1),
- *VT2 = getVecTypeForPair(IT2, JT2);
TargetTransformInfo::OperandValueKind Op1VK =
TargetTransformInfo::OK_AnyValue;
TargetTransformInfo::OperandValueKind Op2VK =
TargetTransformInfo::OK_AnyValue;
+ unsigned ICost = getInstrCost(I->getOpcode(), IT1, IT2, Op1VK, Op2VK, I);
+ unsigned JCost = getInstrCost(J->getOpcode(), JT1, JT2, Op1VK, Op2VK, J);
+ Type *VT1 = getVecTypeForPair(IT1, JT1),
+ *VT2 = getVecTypeForPair(IT2, JT2);
// On some targets (example X86) the cost of a vector shift may vary
// depending on whether the second operand is a Uniform or
@@ -1090,7 +1091,7 @@ namespace {
// but this cost is ignored (because insert and extract element
// instructions are assigned a zero depth factor and are not really
// fused in general).
- unsigned VCost = getInstrCost(I->getOpcode(), VT1, VT2, Op1VK, Op2VK);
+ unsigned VCost = getInstrCost(I->getOpcode(), VT1, VT2, Op1VK, Op2VK, I);
if (VCost > ICost + JCost)
return false;
@@ -1127,39 +1128,51 @@ namespace {
FastMathFlags FMFCI;
if (auto *FPMOCI = dyn_cast<FPMathOperator>(CI))
FMFCI = FPMOCI->getFastMathFlags();
+ SmallVector<Value *, 4> IArgs(CI->arg_operands());
+ unsigned ICost = TTI->getIntrinsicInstrCost(IID, IT1, IArgs, FMFCI);
- SmallVector<Type*, 4> Tys;
- for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i)
- Tys.push_back(CI->getArgOperand(i)->getType());
- unsigned ICost = TTI->getIntrinsicInstrCost(IID, IT1, Tys, FMFCI);
-
- Tys.clear();
CallInst *CJ = cast<CallInst>(J);
FastMathFlags FMFCJ;
if (auto *FPMOCJ = dyn_cast<FPMathOperator>(CJ))
FMFCJ = FPMOCJ->getFastMathFlags();
- for (unsigned i = 0, ie = CJ->getNumArgOperands(); i != ie; ++i)
- Tys.push_back(CJ->getArgOperand(i)->getType());
- unsigned JCost = TTI->getIntrinsicInstrCost(IID, JT1, Tys, FMFCJ);
+ SmallVector<Value *, 4> JArgs(CJ->arg_operands());
+ unsigned JCost = TTI->getIntrinsicInstrCost(IID, JT1, JArgs, FMFCJ);
- Tys.clear();
assert(CI->getNumArgOperands() == CJ->getNumArgOperands() &&
"Intrinsic argument counts differ");
+ SmallVector<Type*, 4> Tys;
+ SmallVector<Value *, 4> VecArgs;
for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) {
if ((IID == Intrinsic::powi || IID == Intrinsic::ctlz ||
- IID == Intrinsic::cttz) && i == 1)
+ IID == Intrinsic::cttz) && i == 1) {
Tys.push_back(CI->getArgOperand(i)->getType());
- else
+ VecArgs.push_back(CI->getArgOperand(i));
+ }
+ else {
Tys.push_back(getVecTypeForPair(CI->getArgOperand(i)->getType(),
CJ->getArgOperand(i)->getType()));
+ // Add both operands, and then count their scalarization overhead
+ // with VF 1.
+ VecArgs.push_back(CI->getArgOperand(i));
+ VecArgs.push_back(CJ->getArgOperand(i));
+ }
}
+ // Compute the scalarization cost here with the original operands (to
+ // check for uniqueness etc), and then call getIntrinsicInstrCost()
+ // with the constructed vector types.
+ Type *RetTy = getVecTypeForPair(IT1, JT1);
+ unsigned ScalarizationCost = 0;
+ if (!RetTy->isVoidTy())
+ ScalarizationCost += TTI->getScalarizationOverhead(RetTy, true, false);
+ ScalarizationCost += TTI->getOperandsScalarizationOverhead(VecArgs, 1);
+
FastMathFlags FMFV = FMFCI;
FMFV &= FMFCJ;
- Type *RetTy = getVecTypeForPair(IT1, JT1);
- unsigned VCost = TTI->getIntrinsicInstrCost(IID, RetTy, Tys, FMFV);
+ unsigned VCost = TTI->getIntrinsicInstrCost(IID, RetTy, Tys, FMFV,
+ ScalarizationCost);
if (VCost > ICost + JCost)
return false;
@@ -2502,7 +2515,7 @@ namespace {
if (I2 == I1 || isa<UndefValue>(I2))
I2 = nullptr;
}
-
+
if (HEE) {
Value *I3 = HEE->getOperand(0);
if (!I2 && I3 != I1)
@@ -2693,14 +2706,14 @@ namespace {
// so extend the smaller vector to be the same length as the larger one.
Instruction *NLOp;
if (numElemL > 1) {
-
+
std::vector<Constant *> Mask(numElemH);
unsigned v = 0;
for (; v < numElemL; ++v)
Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
for (; v < numElemH; ++v)
Mask[v] = UndefValue::get(Type::getInt32Ty(Context));
-
+
NLOp = new ShuffleVectorInst(LOp, UndefValue::get(ArgTypeL),
ConstantVector::get(Mask),
getReplacementName(IBeforeJ ? I : J,
@@ -2710,7 +2723,7 @@ namespace {
getReplacementName(IBeforeJ ? I : J,
true, o, 1));
}
-
+
NLOp->insertBefore(IBeforeJ ? J : I);
LOp = NLOp;
}
@@ -2720,7 +2733,7 @@ namespace {
if (numElemH == 1 && expandIEChain(Context, I, J, o, LOp, numElemL,
ArgTypeH, VArgType, IBeforeJ)) {
Instruction *S =
- InsertElementInst::Create(LOp, HOp,
+ InsertElementInst::Create(LOp, HOp,
ConstantInt::get(Type::getInt32Ty(Context),
numElemL),
getReplacementName(IBeforeJ ? I : J,
@@ -2737,7 +2750,7 @@ namespace {
Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
for (; v < numElemL; ++v)
Mask[v] = UndefValue::get(Type::getInt32Ty(Context));
-
+
NHOp = new ShuffleVectorInst(HOp, UndefValue::get(ArgTypeH),
ConstantVector::get(Mask),
getReplacementName(IBeforeJ ? I : J,
diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index c44a393cf846d..4409d7a404f8b 100644
--- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -432,9 +432,12 @@ Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain,
unsigned ElementSizeBytes = ElementSizeBits / 8;
unsigned SizeBytes = ElementSizeBytes * Chain.size();
unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes;
- if (NumLeft == Chain.size())
- --NumLeft;
- else if (NumLeft == 0)
+ if (NumLeft == Chain.size()) {
+ if ((NumLeft & 1) == 0)
+ NumLeft /= 2; // Split even in half
+ else
+ --NumLeft; // Split off last element
+ } else if (NumLeft == 0)
NumLeft = 1;
return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));
}
@@ -588,7 +591,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) {
continue;
// Make sure all the users of a vector are constant-index extracts.
- if (isa<VectorType>(Ty) && !all_of(LI->users(), [LI](const User *U) {
+ if (isa<VectorType>(Ty) && !all_of(LI->users(), [](const User *U) {
const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);
return EEI && isa<ConstantInt>(EEI->getOperand(1));
}))
@@ -622,7 +625,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) {
if (TySize > VecRegSize / 2)
continue;
- if (isa<VectorType>(Ty) && !all_of(SI->users(), [SI](const User *U) {
+ if (isa<VectorType>(Ty) && !all_of(SI->users(), [](const User *U) {
const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);
return EEI && isa<ConstantInt>(EEI->getOperand(1));
}))
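
The splitOddVectorElts change above refines how an already 4-byte-aligned chain is split: an even-length chain is now cut in half instead of merely dropping its last element, an odd-length chain still peels the final element, and a chain with no 4-byte-aligned prefix keeps at least one element. A small standalone sketch of just that split rule (hypothetical helper name, plain C++):

#include <cstdio>

// Returns how many elements stay in the left half of the chain.
static unsigned splitPoint(unsigned ChainLen, unsigned ElemBytes) {
  unsigned SizeBytes = ElemBytes * ChainLen;
  unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElemBytes;
  if (NumLeft == ChainLen)
    NumLeft = (NumLeft & 1) == 0 ? NumLeft / 2 : NumLeft - 1;
  else if (NumLeft == 0)
    NumLeft = 1;
  return NumLeft;
}

int main() {
  printf("%u %u %u\n",
         splitPoint(8, 4),   // even chain -> 4 (split in half)
         splitPoint(5, 4),   // odd chain  -> 4 (drop last element)
         splitPoint(1, 2));  // tiny chain -> 1 (keep one element)
  return 0;
}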
diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp
index dac7032fa08f6..595b2ec889435 100644
--- a/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -50,6 +50,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -92,6 +93,7 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/LoopVersioning.h"
#include "llvm/Transforms/Vectorize.h"
@@ -266,21 +268,6 @@ static bool hasCyclesInLoopBody(const Loop &L) {
return false;
}
-/// \brief This modifies LoopAccessReport to initialize message with
-/// loop-vectorizer-specific part.
-class VectorizationReport : public LoopAccessReport {
-public:
- VectorizationReport(Instruction *I = nullptr)
- : LoopAccessReport("loop not vectorized: ", I) {}
-
- /// \brief This allows promotion of the loop-access analysis report into the
- /// loop-vectorizer report. It modifies the message to add the
- /// loop-vectorizer-specific part of the message.
- explicit VectorizationReport(const LoopAccessReport &R)
- : LoopAccessReport(Twine("loop not vectorized: ") + R.str(),
- R.getInstr()) {}
-};
-
/// A helper function for converting Scalar types to vector types.
/// If the incoming type is void, we return void. If the VF is 1, we return
/// the scalar type.
@@ -290,31 +277,9 @@ static Type *ToVectorTy(Type *Scalar, unsigned VF) {
return VectorType::get(Scalar, VF);
}
-/// A helper function that returns GEP instruction and knows to skip a
-/// 'bitcast'. The 'bitcast' may be skipped if the source and the destination
-/// pointee types of the 'bitcast' have the same size.
-/// For example:
-/// bitcast double** %var to i64* - can be skipped
-/// bitcast double** %var to i8* - can not
-static GetElementPtrInst *getGEPInstruction(Value *Ptr) {
-
- if (isa<GetElementPtrInst>(Ptr))
- return cast<GetElementPtrInst>(Ptr);
-
- if (isa<BitCastInst>(Ptr) &&
- isa<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0))) {
- Type *BitcastTy = Ptr->getType();
- Type *GEPTy = cast<BitCastInst>(Ptr)->getSrcTy();
- if (!isa<PointerType>(BitcastTy) || !isa<PointerType>(GEPTy))
- return nullptr;
- Type *Pointee1Ty = cast<PointerType>(BitcastTy)->getPointerElementType();
- Type *Pointee2Ty = cast<PointerType>(GEPTy)->getPointerElementType();
- const DataLayout &DL = cast<BitCastInst>(Ptr)->getModule()->getDataLayout();
- if (DL.getTypeSizeInBits(Pointee1Ty) == DL.getTypeSizeInBits(Pointee2Ty))
- return cast<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0));
- }
- return nullptr;
-}
+// FIXME: The following helper functions have multiple implementations
+// in the project. They can be effectively organized in a common Load/Store
+// utilities unit.
/// A helper function that returns the pointer operand of a load or store
/// instruction.
@@ -326,6 +291,34 @@ static Value *getPointerOperand(Value *I) {
return nullptr;
}
+/// A helper function that returns the type of loaded or stored value.
+static Type *getMemInstValueType(Value *I) {
+ assert((isa<LoadInst>(I) || isa<StoreInst>(I)) &&
+ "Expected Load or Store instruction");
+ if (auto *LI = dyn_cast<LoadInst>(I))
+ return LI->getType();
+ return cast<StoreInst>(I)->getValueOperand()->getType();
+}
+
+/// A helper function that returns the alignment of load or store instruction.
+static unsigned getMemInstAlignment(Value *I) {
+ assert((isa<LoadInst>(I) || isa<StoreInst>(I)) &&
+ "Expected Load or Store instruction");
+ if (auto *LI = dyn_cast<LoadInst>(I))
+ return LI->getAlignment();
+ return cast<StoreInst>(I)->getAlignment();
+}
+
+/// A helper function that returns the address space of the pointer operand of
+/// load or store instruction.
+static unsigned getMemInstAddressSpace(Value *I) {
+ assert((isa<LoadInst>(I) || isa<StoreInst>(I)) &&
+ "Expected Load or Store instruction");
+ if (auto *LI = dyn_cast<LoadInst>(I))
+ return LI->getPointerAddressSpace();
+ return cast<StoreInst>(I)->getPointerAddressSpace();
+}
+
/// A helper function that returns true if the given type is irregular. The
/// type is irregular if its allocated size doesn't equal the store size of an
/// element of the corresponding vector type at the given vectorization factor.
@@ -351,6 +344,23 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL, unsigned VF) {
/// we always assume predicated blocks have a 50% chance of executing.
static unsigned getReciprocalPredBlockProb() { return 2; }
+/// A helper function that adds a 'fast' flag to floating-point operations.
+static Value *addFastMathFlag(Value *V) {
+ if (isa<FPMathOperator>(V)) {
+ FastMathFlags Flags;
+ Flags.setUnsafeAlgebra();
+ cast<Instruction>(V)->setFastMathFlags(Flags);
+ }
+ return V;
+}
+
+/// A helper function that returns an integer or floating-point constant with
+/// value C.
+static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) {
+ return Ty->isIntegerTy() ? ConstantInt::getSigned(Ty, C)
+ : ConstantFP::get(Ty, C);
+}
+
/// InnerLoopVectorizer vectorizes loops which contain only one basic
/// block to a specified vectorization factor (VF).
/// This class performs the widening of scalars into vectors, or multiple
@@ -428,10 +438,17 @@ protected:
/// Copy and widen the instructions from the old loop.
virtual void vectorizeLoop();
+ /// Handle all cross-iteration phis in the header.
+ void fixCrossIterationPHIs();
+
/// Fix a first-order recurrence. This is the second phase of vectorizing
/// this phi node.
void fixFirstOrderRecurrence(PHINode *Phi);
+ /// Fix a reduction cross-iteration phi. This is the second phase of
+ /// vectorizing this phi node.
+ void fixReduction(PHINode *Phi);
+
/// \brief The Loop exit block may have single value PHI nodes where the
/// incoming value is 'Undef'. While vectorizing we only handled real values
/// that were defined inside the loop. Here we fix the 'undef case'.
@@ -448,7 +465,8 @@ protected:
/// Collect the instructions from the original loop that would be trivially
/// dead in the vectorized loop if generated.
- void collectTriviallyDeadInstructions();
+ void collectTriviallyDeadInstructions(
+ SmallPtrSetImpl<Instruction *> &DeadInstructions);
/// Shrinks vector element sizes to the smallest bitwidth they can be legally
/// represented as.
@@ -462,14 +480,14 @@ protected:
/// and DST.
VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst);
- /// A helper function to vectorize a single BB within the innermost loop.
- void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV);
+ /// A helper function to vectorize a single instruction within the innermost
+ /// loop.
+ void vectorizeInstruction(Instruction &I);
/// Vectorize a single PHINode in a block. This method handles the induction
/// variable canonicalization. It supports both VF = 1 for unrolled loops and
/// arbitrary length vectors.
- void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF,
- PhiVector *PV);
+ void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF);
/// Insert the new loop to the loop hierarchy and pass manager
/// and update the analysis passes.
@@ -504,20 +522,21 @@ protected:
/// \p EntryVal is the value from the original loop that maps to the steps.
/// Note that \p EntryVal doesn't have to be an induction variable (e.g., it
/// can be a truncate instruction).
- void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal);
-
- /// Create a vector induction phi node based on an existing scalar one. This
- /// currently only works for integer induction variables with a constant
- /// step. \p EntryVal is the value from the original loop that maps to the
- /// vector phi node. If \p EntryVal is a truncate instruction, instead of
- /// widening the original IV, we widen a version of the IV truncated to \p
- /// EntryVal's type.
- void createVectorIntInductionPHI(const InductionDescriptor &II,
- Instruction *EntryVal);
-
- /// Widen an integer induction variable \p IV. If \p Trunc is provided, the
- /// induction variable will first be truncated to the corresponding type.
- void widenIntInduction(PHINode *IV, TruncInst *Trunc = nullptr);
+ void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal,
+ const InductionDescriptor &ID);
+
+ /// Create a vector induction phi node based on an existing scalar one. \p
+ /// EntryVal is the value from the original loop that maps to the vector phi
+ /// node, and \p Step is the loop-invariant step. If \p EntryVal is a
+ /// truncate instruction, instead of widening the original IV, we widen a
+ /// version of the IV truncated to \p EntryVal's type.
+ void createVectorIntOrFpInductionPHI(const InductionDescriptor &II,
+ Value *Step, Instruction *EntryVal);
+
+ /// Widen an integer or floating-point induction variable \p IV. If \p Trunc
+ /// is provided, the integer induction variable will first be truncated to
+ /// the corresponding type.
+ void widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc = nullptr);
/// Returns true if an instruction \p I should be scalarized instead of
/// vectorized for the chosen vectorization factor.
@@ -583,6 +602,10 @@ protected:
/// vector of instructions.
void addMetadata(ArrayRef<Value *> To, Instruction *From);
+ /// \brief Set the debug location in the builder using the debug location in
+ /// the instruction.
+ void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr);
+
/// This is a helper class for maintaining vectorization state. It's used for
/// mapping values from the original loop to their corresponding values in
/// the new loop. Two mappings are maintained: one for vectorized values and
@@ -777,14 +800,6 @@ protected:
// Record whether runtime checks are added.
bool AddedSafetyChecks;
- // Holds instructions from the original loop whose counterparts in the
- // vectorized loop would be trivially dead if generated. For example,
- // original induction update instructions can become dead because we
- // separately emit induction "steps" when generating code for the new loop.
- // Similarly, we create a new latch condition when setting up the structure
- // of the new loop, so the old one can become dead.
- SmallPtrSet<Instruction *, 4> DeadInstructions;
-
// Holds the end values for each induction variable. We save the end values
// so we can later fix-up the external users of the induction variables.
DenseMap<PHINode *, Value *> IVEndValues;
@@ -803,8 +818,6 @@ public:
UnrollFactor, LVL, CM) {}
private:
- void scalarizeInstruction(Instruction *Instr,
- bool IfPredicateInstr = false) override;
void vectorizeMemoryInstruction(Instruction *Instr) override;
Value *getBroadcastInstrs(Value *V) override;
Value *getStepVector(Value *Val, int StartIdx, Value *Step,
@@ -832,12 +845,14 @@ static Instruction *getDebugLocFromInstOrOperands(Instruction *I) {
return I;
}
-/// \brief Set the debug location in the builder using the debug location in the
-/// instruction.
-static void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) {
- if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr))
- B.SetCurrentDebugLocation(Inst->getDebugLoc());
- else
+void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) {
+ if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr)) {
+ const DILocation *DIL = Inst->getDebugLoc();
+ if (DIL && Inst->getFunction()->isDebugInfoForProfiling())
+ B.SetCurrentDebugLocation(DIL->cloneWithDuplicationFactor(UF * VF));
+ else
+ B.SetCurrentDebugLocation(DIL);
+ } else
B.SetCurrentDebugLocation(DebugLoc());
}
@@ -1497,14 +1512,6 @@ private:
OptimizationRemarkEmitter &ORE;
};
-static void emitAnalysisDiag(const Loop *TheLoop,
- const LoopVectorizeHints &Hints,
- OptimizationRemarkEmitter &ORE,
- const LoopAccessReport &Message) {
- const char *Name = Hints.vectorizeAnalysisPassName();
- LoopAccessReport::emitAnalysis(Message, TheLoop, Name, ORE);
-}
-
static void emitMissedWarning(Function *F, Loop *L,
const LoopVectorizeHints &LH,
OptimizationRemarkEmitter *ORE) {
@@ -1512,13 +1519,17 @@ static void emitMissedWarning(Function *F, Loop *L,
if (LH.getForce() == LoopVectorizeHints::FK_Enabled) {
if (LH.getWidth() != 1)
- emitLoopVectorizeWarning(
- F->getContext(), *F, L->getStartLoc(),
- "failed explicitly specified loop vectorization");
+ ORE->emit(DiagnosticInfoOptimizationFailure(
+ DEBUG_TYPE, "FailedRequestedVectorization",
+ L->getStartLoc(), L->getHeader())
+ << "loop not vectorized: "
+ << "failed explicitly specified loop vectorization");
else if (LH.getInterleave() != 1)
- emitLoopInterleaveWarning(
- F->getContext(), *F, L->getStartLoc(),
- "failed explicitly specified loop interleaving");
+ ORE->emit(DiagnosticInfoOptimizationFailure(
+ DEBUG_TYPE, "FailedRequestedInterleaving", L->getStartLoc(),
+ L->getHeader())
+ << "loop not interleaved: "
+ << "failed explicitly specified loop interleaving");
}
}
@@ -1546,7 +1557,7 @@ public:
LoopVectorizeHints *H)
: NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT),
GetLAA(GetLAA), LAI(nullptr), ORE(ORE), InterleaveInfo(PSE, L, DT, LI),
- Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
+ PrimaryInduction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
Requirements(R), Hints(H) {}
/// ReductionList contains the reduction descriptors for all
@@ -1566,8 +1577,8 @@ public:
/// loop, only that it is legal to do so.
bool canVectorize();
- /// Returns the Induction variable.
- PHINode *getInduction() { return Induction; }
+ /// Returns the primary induction variable.
+ PHINode *getPrimaryInduction() { return PrimaryInduction; }
/// Returns the reduction variables found in the loop.
ReductionList *getReductionVars() { return &Reductions; }
@@ -1607,12 +1618,6 @@ public:
/// Returns true if the value V is uniform within the loop.
bool isUniform(Value *V);
- /// Returns true if \p I is known to be uniform after vectorization.
- bool isUniformAfterVectorization(Instruction *I) { return Uniforms.count(I); }
-
- /// Returns true if \p I is known to be scalar after vectorization.
- bool isScalarAfterVectorization(Instruction *I) { return Scalars.count(I); }
-
/// Returns the information that we collected about runtime memory check.
const RuntimePointerChecking *getRuntimePointerChecking() const {
return LAI->getRuntimePointerChecking();
@@ -1689,15 +1694,9 @@ public:
/// instructions that may divide by zero.
bool isScalarWithPredication(Instruction *I);
- /// Returns true if \p I is a memory instruction that has a consecutive or
- /// consecutive-like pointer operand. Consecutive-like pointers are pointers
- /// that are treated like consecutive pointers during vectorization. The
- /// pointer operands of interleaved accesses are an example.
- bool hasConsecutiveLikePtrOperand(Instruction *I);
-
- /// Returns true if \p I is a memory instruction that must be scalarized
- /// during vectorization.
- bool memoryInstructionMustBeScalarized(Instruction *I, unsigned VF = 1);
+ /// Returns true if \p I is a memory instruction with consecutive memory
+ /// access that can be widened.
+ bool memoryInstructionCanBeWidened(Instruction *I, unsigned VF = 1);
private:
/// Check if a single basic block loop is vectorizable.
@@ -1715,24 +1714,6 @@ private:
/// transformation.
bool canVectorizeWithIfConvert();
- /// Collect the instructions that are uniform after vectorization. An
- /// instruction is uniform if we represent it with a single scalar value in
- /// the vectorized loop corresponding to each vector iteration. Examples of
- /// uniform instructions include pointer operands of consecutive or
- /// interleaved memory accesses. Note that although uniformity implies an
- /// instruction will be scalar, the reverse is not true. In general, a
- /// scalarized instruction will be represented by VF scalar values in the
- /// vectorized loop, each corresponding to an iteration of the original
- /// scalar loop.
- void collectLoopUniforms();
-
- /// Collect the instructions that are scalar after vectorization. An
- /// instruction is scalar if it is known to be uniform or will be scalarized
- /// during vectorization. Non-uniform scalarized instructions will be
- /// represented by VF values in the vectorized loop, each corresponding to an
- /// iteration of the original scalar loop.
- void collectLoopScalars();
-
/// Return true if all of the instructions in the block can be speculatively
/// executed. \p SafePtrs is a list of addresses that are known to be legal
/// and we know that we can read from them without segfault.
@@ -1744,14 +1725,6 @@ private:
void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,
SmallPtrSetImpl<Value *> &AllowedExit);
- /// Report an analysis message to assist the user in diagnosing loops that are
- /// not vectorized. These are handled as LoopAccessReport rather than
- /// VectorizationReport because the << operator of VectorizationReport returns
- /// LoopAccessReport.
- void emitAnalysis(const LoopAccessReport &Message) const {
- emitAnalysisDiag(TheLoop, *Hints, *ORE, Message);
- }
-
/// Create an analysis remark that explains why vectorization failed
///
/// \p RemarkName is the identifier for the remark. If \p I is passed it is
@@ -1804,9 +1777,9 @@ private:
// --- vectorization state --- //
- /// Holds the integer induction variable. This is the counter of the
+ /// Holds the primary induction variable. This is the counter of the
/// loop.
- PHINode *Induction;
+ PHINode *PrimaryInduction;
/// Holds the reduction variables.
ReductionList Reductions;
/// Holds all of the induction variables that we found in the loop.
@@ -1822,12 +1795,6 @@ private:
/// vars which can be accessed from outside the loop.
SmallPtrSet<Value *, 4> AllowedExit;
- /// Holds the instructions known to be uniform after vectorization.
- SmallPtrSet<Instruction *, 4> Uniforms;
-
- /// Holds the instructions known to be scalar after vectorization.
- SmallPtrSet<Instruction *, 4> Scalars;
-
/// Can we assume the absence of NaNs.
bool HasFunNoNaNAttr;
@@ -1861,16 +1828,26 @@ public:
: TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
AC(AC), ORE(ORE), TheFunction(F), Hints(Hints) {}
+ /// \return An upper bound for the vectorization factor, or None if
+ /// vectorization should be avoided up front.
+ Optional<unsigned> computeMaxVF(bool OptForSize);
+
/// Information about vectorization costs
struct VectorizationFactor {
unsigned Width; // Vector width with best cost
unsigned Cost; // Cost of the loop with that width
};
/// \return The most profitable vectorization factor and the cost of that VF.
- /// This method checks every power of two up to VF. If UserVF is not ZERO
+ /// This method checks every power of two up to MaxVF. If UserVF is not ZERO
/// then this vectorization factor will be selected if vectorization is
/// possible.
- VectorizationFactor selectVectorizationFactor(bool OptForSize);
+ VectorizationFactor selectVectorizationFactor(unsigned MaxVF);
+
+ /// Set up cost-based decisions for the user-specified vectorization factor.
+ void selectUserVectorizationFactor(unsigned UserVF) {
+ collectUniformsAndScalars(UserVF);
+ collectInstsToScalarize(UserVF);
+ }
/// \return The size (in bits) of the smallest and widest types in the code
/// that needs to be vectorized. We ignore values that remain scalar such as
@@ -1884,6 +1861,15 @@ public:
unsigned selectInterleaveCount(bool OptForSize, unsigned VF,
unsigned LoopCost);
+ /// A memory access instruction may be vectorized in more than one way; the
+ /// form it takes after vectorization depends on cost. This function makes
+ /// cost-based decisions for Load/Store instructions and collects them in a
+ /// map. The decision map is used for building the lists of loop-uniform and
+ /// loop-scalar instructions. The calculated cost is saved with the widening
+ /// decision in order to avoid redundant calculations.
+ void setCostBasedWideningDecision(unsigned VF);
+
/// \brief A struct that represents some properties of the register usage
/// of a loop.
struct RegisterUsage {
@@ -1918,14 +1904,118 @@ public:
return Scalars->second.count(I);
}
+ /// Returns true if \p I is known to be uniform after vectorization.
+ bool isUniformAfterVectorization(Instruction *I, unsigned VF) const {
+ if (VF == 1)
+ return true;
+ assert(Uniforms.count(VF) && "VF not yet analyzed for uniformity");
+ auto UniformsPerVF = Uniforms.find(VF);
+ return UniformsPerVF->second.count(I);
+ }
+
+ /// Returns true if \p I is known to be scalar after vectorization.
+ bool isScalarAfterVectorization(Instruction *I, unsigned VF) const {
+ if (VF == 1)
+ return true;
+ assert(Scalars.count(VF) && "Scalar values are not calculated for VF");
+ auto ScalarsPerVF = Scalars.find(VF);
+ return ScalarsPerVF->second.count(I);
+ }
+
/// \returns True if instruction \p I can be truncated to a smaller bitwidth
/// for vectorization factor \p VF.
bool canTruncateToMinimalBitwidth(Instruction *I, unsigned VF) const {
return VF > 1 && MinBWs.count(I) && !isProfitableToScalarize(I, VF) &&
- !Legal->isScalarAfterVectorization(I);
+ !isScalarAfterVectorization(I, VF);
+ }
+
+ /// Decision that was taken during cost calculation for memory instruction.
+ enum InstWidening {
+ CM_Unknown,
+ CM_Widen,
+ CM_Interleave,
+ CM_GatherScatter,
+ CM_Scalarize
+ };
+
+ /// Save vectorization decision \p W and \p Cost taken by the cost model for
+ /// instruction \p I and vector width \p VF.
+ void setWideningDecision(Instruction *I, unsigned VF, InstWidening W,
+ unsigned Cost) {
+ assert(VF >= 2 && "Expected VF >=2");
+ WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, Cost);
+ }
+
+ /// Save vectorization decision \p W and \p Cost taken by the cost model for
+ /// interleaving group \p Grp and vector width \p VF.
+ void setWideningDecision(const InterleaveGroup *Grp, unsigned VF,
+ InstWidening W, unsigned Cost) {
+ assert(VF >= 2 && "Expected VF >=2");
+ /// Broadcast this decision to all instructions inside the group.
+ /// But the cost will be assigned to one instruction only.
+ for (unsigned i = 0; i < Grp->getFactor(); ++i) {
+ if (auto *I = Grp->getMember(i)) {
+ if (Grp->getInsertPos() == I)
+ WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, Cost);
+ else
+ WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, 0);
+ }
+ }
+ }
+
+ /// Return the cost model decision for the given instruction \p I and vector
+ /// width \p VF. Return CM_Unknown if this instruction did not pass
+ /// through the cost modeling.
+ InstWidening getWideningDecision(Instruction *I, unsigned VF) {
+ assert(VF >= 2 && "Expected VF >=2");
+ std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF);
+ auto Itr = WideningDecisions.find(InstOnVF);
+ if (Itr == WideningDecisions.end())
+ return CM_Unknown;
+ return Itr->second.first;
+ }
+
+ /// Return the vectorization cost for the given instruction \p I and vector
+ /// width \p VF.
+ unsigned getWideningCost(Instruction *I, unsigned VF) {
+ assert(VF >= 2 && "Expected VF >=2");
+ std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF);
+ assert(WideningDecisions.count(InstOnVF) && "The cost is not calculated");
+ return WideningDecisions[InstOnVF].second;
+ }
+
+ /// Return True if instruction \p I is an optimizable truncate whose operand
+ /// is an induction variable. Such a truncate will be removed by adding a new
+ /// induction variable with the destination type.
+ bool isOptimizableIVTruncate(Instruction *I, unsigned VF) {
+
+ // If the instruction is not a truncate, return false.
+ auto *Trunc = dyn_cast<TruncInst>(I);
+ if (!Trunc)
+ return false;
+
+ // Get the source and destination types of the truncate.
+ Type *SrcTy = ToVectorTy(cast<CastInst>(I)->getSrcTy(), VF);
+ Type *DestTy = ToVectorTy(cast<CastInst>(I)->getDestTy(), VF);
+
+ // If the truncate is free for the given types, return false. Replacing a
+ // free truncate with an induction variable would add an induction variable
+ // update instruction to each iteration of the loop. We exclude from this
+ // check the primary induction variable since it will need an update
+ // instruction regardless.
+ Value *Op = Trunc->getOperand(0);
+ if (Op != Legal->getPrimaryInduction() && TTI.isTruncateFree(SrcTy, DestTy))
+ return false;
+
+ // If the truncated value is not an induction variable, return false.
+ return Legal->isInductionVariable(Op);
}
private:
+ /// \return An upper bound for the vectorization factor, larger than zero.
+ /// One is returned if vectorization is best avoided due to cost.
+ unsigned computeFeasibleMaxVF(bool OptForSize);
+
/// The vectorization cost is a combination of the cost itself and a boolean
/// indicating whether any of the contributing operations will actually
/// operate on
@@ -1949,6 +2039,26 @@ private:
/// the vector type as an output parameter.
unsigned getInstructionCost(Instruction *I, unsigned VF, Type *&VectorTy);
+ /// Calculate vectorization cost of memory instruction \p I.
+ unsigned getMemoryInstructionCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for scalarized memory instruction.
+ unsigned getMemInstScalarizationCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for interleaving group of memory instructions.
+ unsigned getInterleaveGroupCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for Gather/Scatter instruction.
+ unsigned getGatherScatterCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for widening instruction \p I with consecutive
+ /// memory access.
+ unsigned getConsecutiveMemOpCost(Instruction *I, unsigned VF);
+
+ /// The cost calculation for Load instruction \p I with uniform pointer -
+ /// scalar load + broadcast.
+ unsigned getUniformMemOpCost(Instruction *I, unsigned VF);
+
/// Returns whether the instruction is a load or store and will be emitted
/// as a vector operation.
bool isConsecutiveLoadOrStore(Instruction *I);
@@ -1972,12 +2082,24 @@ private:
/// pairs.
typedef DenseMap<Instruction *, unsigned> ScalarCostsTy;
+ /// A set containing all BasicBlocks that are known to be present after
+ /// vectorization as predicated blocks.
+ SmallPtrSet<BasicBlock *, 4> PredicatedBBsAfterVectorization;
+
/// A map holding scalar costs for different vectorization factors. The
/// presence of a cost for an instruction in the mapping indicates that the
/// instruction will be scalarized when vectorizing with the associated
/// vectorization factor. The entries are VF-ScalarCostTy pairs.
DenseMap<unsigned, ScalarCostsTy> InstsToScalarize;
+ /// Holds the instructions known to be uniform after vectorization.
+ /// The data is collected per VF.
+ DenseMap<unsigned, SmallPtrSet<Instruction *, 4>> Uniforms;
+
+ /// Holds the instructions known to be scalar after vectorization.
+ /// The data is collected per VF.
+ DenseMap<unsigned, SmallPtrSet<Instruction *, 4>> Scalars;
+
/// Returns the expected difference in cost from scalarizing the expression
/// feeding a predicated instruction \p PredInst. The instructions to
/// scalarize and their scalar costs are collected in \p ScalarCosts. A
@@ -1990,6 +2112,44 @@ private:
/// the loop.
void collectInstsToScalarize(unsigned VF);
+ /// Collect the instructions that are uniform after vectorization. An
+ /// instruction is uniform if we represent it with a single scalar value in
+ /// the vectorized loop corresponding to each vector iteration. Examples of
+ /// uniform instructions include pointer operands of consecutive or
+ /// interleaved memory accesses. Note that although uniformity implies an
+ /// instruction will be scalar, the reverse is not true. In general, a
+ /// scalarized instruction will be represented by VF scalar values in the
+ /// vectorized loop, each corresponding to an iteration of the original
+ /// scalar loop.
+ void collectLoopUniforms(unsigned VF);
+
+ /// Collect the instructions that are scalar after vectorization. An
+ /// instruction is scalar if it is known to be uniform or will be scalarized
+ /// during vectorization. Non-uniform scalarized instructions will be
+ /// represented by VF values in the vectorized loop, each corresponding to an
+ /// iteration of the original scalar loop.
+ void collectLoopScalars(unsigned VF);
+
+ /// Collect Uniform and Scalar values for the given \p VF.
+ /// The sets depend on the CM decision for Load/Store instructions
+ /// that may be vectorized as an interleave group, as a gather/scatter, or
+ /// scalarized.
+ void collectUniformsAndScalars(unsigned VF) {
+ // Do the analysis once.
+ if (VF == 1 || Uniforms.count(VF))
+ return;
+ setCostBasedWideningDecision(VF);
+ collectLoopUniforms(VF);
+ collectLoopScalars(VF);
+ }
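// A minimal sketch (not from this patch) of the uniform/scalar distinction the
// comments above draw: a uniform instruction is materialized once per unroll
// part, a non-uniform scalarized instruction needs VF copies per part, and a
// widened instruction is represented by a vector rather than by scalars. The
// function name is illustrative only.
unsigned scalarCopiesPerPart(bool IsUniform, bool IsScalarized, unsigned VF) {
  if (IsUniform)
    return 1;  // e.g. the pointer operand of a consecutive access
  if (IsScalarized)
    return VF; // one copy per lane of the original scalar loop
  return 0;    // widened: lives as a vector value instead
}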
+
+ /// Keeps cost model vectorization decision and cost for instructions.
+ /// Right now it is used for memory instructions only.
+ typedef DenseMap<std::pair<Instruction *, unsigned>,
+ std::pair<InstWidening, unsigned>>
+ DecisionList;
+
+ DecisionList WideningDecisions;
+
public:
/// The loop that we evaluate.
Loop *TheLoop;
@@ -2019,6 +2179,23 @@ public:
SmallPtrSet<const Value *, 16> VecValuesToIgnore;
};
+/// LoopVectorizationPlanner - drives the vectorization process after having
+/// passed Legality checks.
+class LoopVectorizationPlanner {
+public:
+ LoopVectorizationPlanner(LoopVectorizationCostModel &CM) : CM(CM) {}
+
+ ~LoopVectorizationPlanner() {}
+
+ /// Plan how to best vectorize, return the best VF and its cost.
+ LoopVectorizationCostModel::VectorizationFactor plan(bool OptForSize,
+ unsigned UserVF);
+
+private:
+ /// The profitability analysis.
+ LoopVectorizationCostModel &CM;
+};
+
/// \brief This holds vectorization requirements that must be verified late in
/// the process. The requirements are set by legalize and costmodel. Once
/// vectorization has been determined to be possible and profitable the
@@ -2134,8 +2311,6 @@ struct LoopVectorize : public FunctionPass {
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
- AU.addRequiredID(LoopSimplifyID);
- AU.addRequiredID(LCSSAID);
AU.addRequired<BlockFrequencyInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
@@ -2156,7 +2331,7 @@ struct LoopVectorize : public FunctionPass {
//===----------------------------------------------------------------------===//
// Implementation of LoopVectorizationLegality, InnerLoopVectorizer and
-// LoopVectorizationCostModel.
+// LoopVectorizationCostModel and LoopVectorizationPlanner.
//===----------------------------------------------------------------------===//
Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) {
@@ -2176,27 +2351,51 @@ Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) {
return Shuf;
}
-void InnerLoopVectorizer::createVectorIntInductionPHI(
- const InductionDescriptor &II, Instruction *EntryVal) {
+void InnerLoopVectorizer::createVectorIntOrFpInductionPHI(
+ const InductionDescriptor &II, Value *Step, Instruction *EntryVal) {
Value *Start = II.getStartValue();
- ConstantInt *Step = II.getConstIntStepValue();
- assert(Step && "Can not widen an IV with a non-constant step");
// Construct the initial value of the vector IV in the vector loop preheader
auto CurrIP = Builder.saveIP();
Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator());
if (isa<TruncInst>(EntryVal)) {
+ assert(Start->getType()->isIntegerTy() &&
+ "Truncation requires an integer type");
auto *TruncType = cast<IntegerType>(EntryVal->getType());
- Step = ConstantInt::getSigned(TruncType, Step->getSExtValue());
+ Step = Builder.CreateTrunc(Step, TruncType);
Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType);
}
Value *SplatStart = Builder.CreateVectorSplat(VF, Start);
- Value *SteppedStart = getStepVector(SplatStart, 0, Step);
+ Value *SteppedStart =
+ getStepVector(SplatStart, 0, Step, II.getInductionOpcode());
+
+ // We create vector phi nodes for both integer and floating-point induction
+ // variables. Here, we determine the kind of arithmetic we will perform.
+ Instruction::BinaryOps AddOp;
+ Instruction::BinaryOps MulOp;
+ if (Step->getType()->isIntegerTy()) {
+ AddOp = Instruction::Add;
+ MulOp = Instruction::Mul;
+ } else {
+ AddOp = II.getInductionOpcode();
+ MulOp = Instruction::FMul;
+ }
+
+ // Multiply the vectorization factor by the step using integer or
+ // floating-point arithmetic as appropriate.
+ Value *ConstVF = getSignedIntOrFpConstant(Step->getType(), VF);
+ Value *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, Step, ConstVF));
+
+ // Create a vector splat to use in the induction update.
+ //
+ // FIXME: If the step is non-constant, we create the vector splat with
+ // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't
+ // handle a constant vector splat.
+ Value *SplatVF = isa<Constant>(Mul)
+ ? ConstantVector::getSplat(VF, cast<Constant>(Mul))
+ : Builder.CreateVectorSplat(VF, Mul);
Builder.restoreIP(CurrIP);
- Value *SplatVF =
- ConstantVector::getSplat(VF, ConstantInt::getSigned(Start->getType(),
- VF * Step->getSExtValue()));
// We may need to add the step a number of times, depending on the unroll
// factor. The last of those goes into the PHI.
PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind",
@@ -2205,8 +2404,8 @@ void InnerLoopVectorizer::createVectorIntInductionPHI(
VectorParts Entry(UF);
for (unsigned Part = 0; Part < UF; ++Part) {
Entry[Part] = LastInduction;
- LastInduction = cast<Instruction>(
- Builder.CreateAdd(LastInduction, SplatVF, "step.add"));
+ LastInduction = cast<Instruction>(addFastMathFlag(
+ Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add")));
}
VectorLoopValueMap.initVector(EntryVal, Entry);
if (isa<TruncInst>(EntryVal))
@@ -2225,7 +2424,7 @@ void InnerLoopVectorizer::createVectorIntInductionPHI(
}
bool InnerLoopVectorizer::shouldScalarizeInstruction(Instruction *I) const {
- return Legal->isScalarAfterVectorization(I) ||
+ return Cost->isScalarAfterVectorization(I, VF) ||
Cost->isProfitableToScalarize(I, VF);
}
@@ -2239,7 +2438,10 @@ bool InnerLoopVectorizer::needsScalarInduction(Instruction *IV) const {
return any_of(IV->users(), isScalarInst);
}
-void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {
+void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) {
+
+ assert((IV->getType()->isIntegerTy() || IV != OldInduction) &&
+ "Primary induction variable must have an integer type");
auto II = Legal->getInductionVars()->find(IV);
assert(II != Legal->getInductionVars()->end() && "IV is not an induction");
@@ -2251,9 +2453,6 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {
// induction variable.
Value *ScalarIV = nullptr;
- // The step of the induction.
- Value *Step = nullptr;
-
// The value from the original loop to which we are mapping the new induction
// variable.
Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV;
@@ -2266,45 +2465,49 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {
// least one user in the loop that is not widened.
auto NeedsScalarIV = VF > 1 && needsScalarInduction(EntryVal);
- // If the induction variable has a constant integer step value, go ahead and
- // get it now.
- if (ID.getConstIntStepValue())
- Step = ID.getConstIntStepValue();
+ // Generate code for the induction step. Note that induction steps are
+ // required to be loop-invariant.
+ assert(PSE.getSE()->isLoopInvariant(ID.getStep(), OrigLoop) &&
+ "Induction step should be loop invariant");
+ auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout();
+ Value *Step = nullptr;
+ if (PSE.getSE()->isSCEVable(IV->getType())) {
+ SCEVExpander Exp(*PSE.getSE(), DL, "induction");
+ Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(),
+ LoopVectorPreHeader->getTerminator());
+ } else {
+ Step = cast<SCEVUnknown>(ID.getStep())->getValue();
+ }
// Try to create a new independent vector induction variable. If we can't
// create the phi node, we will splat the scalar induction variable in each
// loop iteration.
- if (VF > 1 && IV->getType() == Induction->getType() && Step &&
- !shouldScalarizeInstruction(EntryVal)) {
- createVectorIntInductionPHI(ID, EntryVal);
+ if (VF > 1 && !shouldScalarizeInstruction(EntryVal)) {
+ createVectorIntOrFpInductionPHI(ID, Step, EntryVal);
VectorizedIV = true;
}
// If we haven't yet vectorized the induction variable, or if we will create
// a scalar one, we need to define the scalar induction variable and step
// values. If we were given a truncation type, truncate the canonical
- // induction variable and constant step. Otherwise, derive these values from
- // the induction descriptor.
+ // induction variable and step. Otherwise, derive these values from the
+ // induction descriptor.
if (!VectorizedIV || NeedsScalarIV) {
+ ScalarIV = Induction;
+ if (IV != OldInduction) {
+ ScalarIV = IV->getType()->isIntegerTy()
+ ? Builder.CreateSExtOrTrunc(Induction, IV->getType())
+ : Builder.CreateCast(Instruction::SIToFP, Induction,
+ IV->getType());
+ ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL);
+ ScalarIV->setName("offset.idx");
+ }
if (Trunc) {
auto *TruncType = cast<IntegerType>(Trunc->getType());
- assert(Step && "Truncation requires constant integer step");
- auto StepInt = cast<ConstantInt>(Step)->getSExtValue();
- ScalarIV = Builder.CreateCast(Instruction::Trunc, Induction, TruncType);
- Step = ConstantInt::getSigned(TruncType, StepInt);
- } else {
- ScalarIV = Induction;
- auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout();
- if (IV != OldInduction) {
- ScalarIV = Builder.CreateSExtOrTrunc(ScalarIV, IV->getType());
- ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL);
- ScalarIV->setName("offset.idx");
- }
- if (!Step) {
- SCEVExpander Exp(*PSE.getSE(), DL, "induction");
- Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(),
- &*Builder.GetInsertPoint());
- }
+ assert(Step->getType()->isIntegerTy() &&
+ "Truncation requires an integer step");
+ ScalarIV = Builder.CreateTrunc(ScalarIV, TruncType);
+ Step = Builder.CreateTrunc(Step, TruncType);
}
}
@@ -2314,7 +2517,8 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {
Value *Broadcasted = getBroadcastInstrs(ScalarIV);
VectorParts Entry(UF);
for (unsigned Part = 0; Part < UF; ++Part)
- Entry[Part] = getStepVector(Broadcasted, VF * Part, Step);
+ Entry[Part] =
+ getStepVector(Broadcasted, VF * Part, Step, ID.getInductionOpcode());
VectorLoopValueMap.initVector(EntryVal, Entry);
if (Trunc)
addMetadata(Entry, Trunc);
@@ -2327,7 +2531,7 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) {
// in the loop in the common case prior to InstCombine. We will be trading
// one vector extract for each scalar step.
if (NeedsScalarIV)
- buildScalarSteps(ScalarIV, Step, EntryVal);
+ buildScalarSteps(ScalarIV, Step, EntryVal, ID);
}
Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step,
@@ -2387,30 +2591,43 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step,
}
void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step,
- Value *EntryVal) {
+ Value *EntryVal,
+ const InductionDescriptor &ID) {
// We shouldn't have to build scalar steps if we aren't vectorizing.
assert(VF > 1 && "VF should be greater than one");
// Get the value type and ensure it and the step have the same type.
Type *ScalarIVTy = ScalarIV->getType()->getScalarType();
- assert(ScalarIVTy->isIntegerTy() && ScalarIVTy == Step->getType() &&
- "Val and Step should have the same integer type");
+ assert(ScalarIVTy == Step->getType() &&
+ "Val and Step should have the same type");
+
+ // We build scalar steps for both integer and floating-point induction
+ // variables. Here, we determine the kind of arithmetic we will perform.
+ Instruction::BinaryOps AddOp;
+ Instruction::BinaryOps MulOp;
+ if (ScalarIVTy->isIntegerTy()) {
+ AddOp = Instruction::Add;
+ MulOp = Instruction::Mul;
+ } else {
+ AddOp = ID.getInductionOpcode();
+ MulOp = Instruction::FMul;
+ }
// Determine the number of scalars we need to generate for each unroll
// iteration. If EntryVal is uniform, we only need to generate the first
// lane. Otherwise, we generate all VF values.
unsigned Lanes =
- Legal->isUniformAfterVectorization(cast<Instruction>(EntryVal)) ? 1 : VF;
+ Cost->isUniformAfterVectorization(cast<Instruction>(EntryVal), VF) ? 1 : VF;
// Compute the scalar steps and save the results in VectorLoopValueMap.
ScalarParts Entry(UF);
for (unsigned Part = 0; Part < UF; ++Part) {
Entry[Part].resize(VF);
for (unsigned Lane = 0; Lane < Lanes; ++Lane) {
- auto *StartIdx = ConstantInt::get(ScalarIVTy, VF * Part + Lane);
- auto *Mul = Builder.CreateMul(StartIdx, Step);
- auto *Add = Builder.CreateAdd(ScalarIV, Mul);
+ auto *StartIdx = getSignedIntOrFpConstant(ScalarIVTy, VF * Part + Lane);
+ auto *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, StartIdx, Step));
+ auto *Add = addFastMathFlag(Builder.CreateBinOp(AddOp, ScalarIV, Mul));
Entry[Part][Lane] = Add;
}
}
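// A minimal sketch (not from this patch) of the values the loop above builds,
// written out for plain integers: lane L of unroll part P receives
// ScalarIV + (VF * P + L) * Step. Floating-point inductions follow the same
// shape with FMul and the induction's FAdd/FSub opcode. Names are illustrative.
#include <vector>

std::vector<std::vector<long>> buildScalarStepsSketch(long ScalarIV, long Step,
                                                      unsigned VF, unsigned UF,
                                                      bool Uniform) {
  unsigned Lanes = Uniform ? 1 : VF; // uniform values only need lane zero
  std::vector<std::vector<long>> Entry(UF);
  for (unsigned Part = 0; Part < UF; ++Part)
    for (unsigned Lane = 0; Lane < Lanes; ++Lane)
      Entry[Part].push_back(ScalarIV +
                            static_cast<long>(VF * Part + Lane) * Step);
  return Entry;
}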
@@ -2469,7 +2686,7 @@ InnerLoopVectorizer::getVectorValue(Value *V) {
// known to be uniform after vectorization, this corresponds to lane zero
// of the last unroll iteration. Otherwise, the last instruction is the one
// we created for the last vector lane of the last unroll iteration.
- unsigned LastLane = Legal->isUniformAfterVectorization(I) ? 0 : VF - 1;
+ unsigned LastLane = Cost->isUniformAfterVectorization(I, VF) ? 0 : VF - 1;
auto *LastInst = cast<Instruction>(getScalarValue(V, UF - 1, LastLane));
// Set the insert point after the last scalarized instruction. This ensures
@@ -2486,7 +2703,7 @@ InnerLoopVectorizer::getVectorValue(Value *V) {
// VectorLoopValueMap, we will only generate the insertelements once.
for (unsigned Part = 0; Part < UF; ++Part) {
Value *VectorValue = nullptr;
- if (Legal->isUniformAfterVectorization(I)) {
+ if (Cost->isUniformAfterVectorization(I, VF)) {
VectorValue = getBroadcastInstrs(getScalarValue(V, Part, 0));
} else {
VectorValue = UndefValue::get(VectorType::get(V->getType(), VF));
@@ -2515,8 +2732,9 @@ Value *InnerLoopVectorizer::getScalarValue(Value *V, unsigned Part,
if (OrigLoop->isLoopInvariant(V))
return V;
- assert(Lane > 0 ? !Legal->isUniformAfterVectorization(cast<Instruction>(V))
- : true && "Uniform values only have lane zero");
+ assert(Lane > 0 ?
+ !Cost->isUniformAfterVectorization(cast<Instruction>(V), VF)
+ : true && "Uniform values only have lane zero");
// If the value from the original loop has not been vectorized, it is
// represented by UF x VF scalar values in the new loop. Return the requested
@@ -2551,102 +2769,6 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) {
"reverse");
}
-// Get a mask to interleave \p NumVec vectors into a wide vector.
-// I.e. <0, VF, VF*2, ..., VF*(NumVec-1), 1, VF+1, VF*2+1, ...>
-// E.g. For 2 interleaved vectors, if VF is 4, the mask is:
-// <0, 4, 1, 5, 2, 6, 3, 7>
-static Constant *getInterleavedMask(IRBuilder<> &Builder, unsigned VF,
- unsigned NumVec) {
- SmallVector<Constant *, 16> Mask;
- for (unsigned i = 0; i < VF; i++)
- for (unsigned j = 0; j < NumVec; j++)
- Mask.push_back(Builder.getInt32(j * VF + i));
-
- return ConstantVector::get(Mask);
-}
-
-// Get the strided mask starting from index \p Start.
-// I.e. <Start, Start + Stride, ..., Start + Stride*(VF-1)>
-static Constant *getStridedMask(IRBuilder<> &Builder, unsigned Start,
- unsigned Stride, unsigned VF) {
- SmallVector<Constant *, 16> Mask;
- for (unsigned i = 0; i < VF; i++)
- Mask.push_back(Builder.getInt32(Start + i * Stride));
-
- return ConstantVector::get(Mask);
-}
-
-// Get a mask of two parts: The first part consists of sequential integers
-// starting from 0, The second part consists of UNDEFs.
-// I.e. <0, 1, 2, ..., NumInt - 1, undef, ..., undef>
-static Constant *getSequentialMask(IRBuilder<> &Builder, unsigned NumInt,
- unsigned NumUndef) {
- SmallVector<Constant *, 16> Mask;
- for (unsigned i = 0; i < NumInt; i++)
- Mask.push_back(Builder.getInt32(i));
-
- Constant *Undef = UndefValue::get(Builder.getInt32Ty());
- for (unsigned i = 0; i < NumUndef; i++)
- Mask.push_back(Undef);
-
- return ConstantVector::get(Mask);
-}
-
-// Concatenate two vectors with the same element type. The 2nd vector should
-// not have more elements than the 1st vector. If the 2nd vector has less
-// elements, extend it with UNDEFs.
-static Value *ConcatenateTwoVectors(IRBuilder<> &Builder, Value *V1,
- Value *V2) {
- VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
- VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
- assert(VecTy1 && VecTy2 &&
- VecTy1->getScalarType() == VecTy2->getScalarType() &&
- "Expect two vectors with the same element type");
-
- unsigned NumElts1 = VecTy1->getNumElements();
- unsigned NumElts2 = VecTy2->getNumElements();
- assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements");
-
- if (NumElts1 > NumElts2) {
- // Extend with UNDEFs.
- Constant *ExtMask =
- getSequentialMask(Builder, NumElts2, NumElts1 - NumElts2);
- V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
- }
-
- Constant *Mask = getSequentialMask(Builder, NumElts1 + NumElts2, 0);
- return Builder.CreateShuffleVector(V1, V2, Mask);
-}
-
-// Concatenate vectors in the given list. All vectors have the same type.
-static Value *ConcatenateVectors(IRBuilder<> &Builder,
- ArrayRef<Value *> InputList) {
- unsigned NumVec = InputList.size();
- assert(NumVec > 1 && "Should be at least two vectors");
-
- SmallVector<Value *, 8> ResList;
- ResList.append(InputList.begin(), InputList.end());
- do {
- SmallVector<Value *, 8> TmpList;
- for (unsigned i = 0; i < NumVec - 1; i += 2) {
- Value *V0 = ResList[i], *V1 = ResList[i + 1];
- assert((V0->getType() == V1->getType() || i == NumVec - 2) &&
- "Only the last vector may have a different type");
-
- TmpList.push_back(ConcatenateTwoVectors(Builder, V0, V1));
- }
-
- // Push the last vector if the total number of vectors is odd.
- if (NumVec % 2 != 0)
- TmpList.push_back(ResList[NumVec - 1]);
-
- ResList = TmpList;
- NumVec = ResList.size();
- } while (NumVec > 1);
-
- return ResList[0];
-}
-
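// A minimal standalone sketch (not from this patch) of the shuffle masks the
// removed helpers produced; the vectorizer now obtains them from the shared
// createInterleaveMask/createStrideMask utilities. For VF = 4 and factor = 2
// the interleave mask is <0,4,1,5,2,6,3,7> and the stride-2 mask starting at 1
// is <1,3,5,7>. Function names are illustrative only.
#include <vector>

std::vector<unsigned> interleaveMaskSketch(unsigned VF, unsigned NumVecs) {
  std::vector<unsigned> Mask;
  for (unsigned I = 0; I < VF; ++I)
    for (unsigned J = 0; J < NumVecs; ++J)
      Mask.push_back(J * VF + I);
  return Mask;
}

std::vector<unsigned> strideMaskSketch(unsigned Start, unsigned Stride,
                                       unsigned VF) {
  std::vector<unsigned> Mask;
  for (unsigned I = 0; I < VF; ++I)
    Mask.push_back(Start + I * Stride);
  return Mask;
}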
// Try to vectorize the interleave group that \p Instr belongs to.
//
// E.g. Translate following interleaved load group (factor = 3):
@@ -2683,15 +2805,13 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {
if (Instr != Group->getInsertPos())
return;
- LoadInst *LI = dyn_cast<LoadInst>(Instr);
- StoreInst *SI = dyn_cast<StoreInst>(Instr);
Value *Ptr = getPointerOperand(Instr);
// Prepare for the vector type of the interleaved load/store.
- Type *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType();
+ Type *ScalarTy = getMemInstValueType(Instr);
unsigned InterleaveFactor = Group->getFactor();
Type *VecTy = VectorType::get(ScalarTy, InterleaveFactor * VF);
- Type *PtrTy = VecTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
+ Type *PtrTy = VecTy->getPointerTo(getMemInstAddressSpace(Instr));
// Prepare for the new pointers.
setDebugLocFromInst(Builder, Ptr);
@@ -2731,7 +2851,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {
Value *UndefVec = UndefValue::get(VecTy);
// Vectorize the interleaved load group.
- if (LI) {
+ if (isa<LoadInst>(Instr)) {
// For each unroll part, create a wide load for the group.
SmallVector<Value *, 2> NewLoads;
@@ -2752,7 +2872,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {
continue;
VectorParts Entry(UF);
- Constant *StrideMask = getStridedMask(Builder, I, InterleaveFactor, VF);
+ Constant *StrideMask = createStrideMask(Builder, I, InterleaveFactor, VF);
for (unsigned Part = 0; Part < UF; Part++) {
Value *StridedVec = Builder.CreateShuffleVector(
NewLoads[Part], UndefVec, StrideMask, "strided.vec");
@@ -2796,10 +2916,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) {
}
// Concatenate all vectors into a wide vector.
- Value *WideVec = ConcatenateVectors(Builder, StoredVecs);
+ Value *WideVec = concatenateVectors(Builder, StoredVecs);
// Interleave the elements in the wide vector.
- Constant *IMask = getInterleavedMask(Builder, VF, InterleaveFactor);
+ Constant *IMask = createInterleaveMask(Builder, VF, InterleaveFactor);
Value *IVec = Builder.CreateShuffleVector(WideVec, UndefVec, IMask,
"interleaved.vec");
@@ -2816,103 +2936,44 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {
assert((LI || SI) && "Invalid Load/Store instruction");
- // Try to vectorize the interleave group if this access is interleaved.
- if (Legal->isAccessInterleaved(Instr))
+ LoopVectorizationCostModel::InstWidening Decision =
+ Cost->getWideningDecision(Instr, VF);
+ assert(Decision != LoopVectorizationCostModel::CM_Unknown &&
+ "CM decision should be taken at this point");
+ if (Decision == LoopVectorizationCostModel::CM_Interleave)
return vectorizeInterleaveGroup(Instr);
- Type *ScalarDataTy = LI ? LI->getType() : SI->getValueOperand()->getType();
+ Type *ScalarDataTy = getMemInstValueType(Instr);
Type *DataTy = VectorType::get(ScalarDataTy, VF);
Value *Ptr = getPointerOperand(Instr);
- unsigned Alignment = LI ? LI->getAlignment() : SI->getAlignment();
+ unsigned Alignment = getMemInstAlignment(Instr);
// An alignment of 0 means target abi alignment. We need to use the scalar's
// target abi alignment in such a case.
const DataLayout &DL = Instr->getModule()->getDataLayout();
if (!Alignment)
Alignment = DL.getABITypeAlignment(ScalarDataTy);
- unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace();
+ unsigned AddressSpace = getMemInstAddressSpace(Instr);
// Scalarize the memory instruction if necessary.
- if (Legal->memoryInstructionMustBeScalarized(Instr, VF))
+ if (Decision == LoopVectorizationCostModel::CM_Scalarize)
return scalarizeInstruction(Instr, Legal->isScalarWithPredication(Instr));
// Determine if the pointer operand of the access is either consecutive or
// reverse consecutive.
int ConsecutiveStride = Legal->isConsecutivePtr(Ptr);
bool Reverse = ConsecutiveStride < 0;
-
- // Determine if either a gather or scatter operation is legal.
bool CreateGatherScatter =
- !ConsecutiveStride && Legal->isLegalGatherOrScatter(Instr);
+ (Decision == LoopVectorizationCostModel::CM_GatherScatter);
VectorParts VectorGep;
// Handle consecutive loads/stores.
- GetElementPtrInst *Gep = getGEPInstruction(Ptr);
if (ConsecutiveStride) {
- if (Gep) {
- unsigned NumOperands = Gep->getNumOperands();
-#ifndef NDEBUG
- // The original GEP that identified as a consecutive memory access
- // should have only one loop-variant operand.
- unsigned NumOfLoopVariantOps = 0;
- for (unsigned i = 0; i < NumOperands; ++i)
- if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)),
- OrigLoop))
- NumOfLoopVariantOps++;
- assert(NumOfLoopVariantOps == 1 &&
- "Consecutive GEP should have only one loop-variant operand");
-#endif
- GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone());
- Gep2->setName("gep.indvar");
-
- // A new GEP is created for a 0-lane value of the first unroll iteration.
- // The GEPs for the rest of the unroll iterations are computed below as an
- // offset from this GEP.
- for (unsigned i = 0; i < NumOperands; ++i)
- // We can apply getScalarValue() for all GEP indices. It returns an
- // original value for loop-invariant operand and 0-lane for consecutive
- // operand.
- Gep2->setOperand(i, getScalarValue(Gep->getOperand(i),
- 0, /* First unroll iteration */
- 0 /* 0-lane of the vector */ ));
- setDebugLocFromInst(Builder, Gep);
- Ptr = Builder.Insert(Gep2);
-
- } else { // No GEP
- setDebugLocFromInst(Builder, Ptr);
- Ptr = getScalarValue(Ptr, 0, 0);
- }
+ Ptr = getScalarValue(Ptr, 0, 0);
} else {
// At this point we should have a vector version of the GEP for Gather or Scatter
assert(CreateGatherScatter && "The instruction should be scalarized");
- if (Gep) {
- // Vectorizing GEP, across UF parts. We want to get a vector value for base
- // and each index that's defined inside the loop, even if it is
- // loop-invariant but wasn't hoisted out. Otherwise we want to keep them
- // scalar.
- SmallVector<VectorParts, 4> OpsV;
- for (Value *Op : Gep->operands()) {
- Instruction *SrcInst = dyn_cast<Instruction>(Op);
- if (SrcInst && OrigLoop->contains(SrcInst))
- OpsV.push_back(getVectorValue(Op));
- else
- OpsV.push_back(VectorParts(UF, Op));
- }
- for (unsigned Part = 0; Part < UF; ++Part) {
- SmallVector<Value *, 4> Ops;
- Value *GEPBasePtr = OpsV[0][Part];
- for (unsigned i = 1; i < Gep->getNumOperands(); i++)
- Ops.push_back(OpsV[i][Part]);
- Value *NewGep = Builder.CreateGEP(GEPBasePtr, Ops, "VectorGep");
- cast<GetElementPtrInst>(NewGep)->setIsInBounds(Gep->isInBounds());
- assert(NewGep->getType()->isVectorTy() && "Expected vector GEP");
-
- NewGep =
- Builder.CreateBitCast(NewGep, VectorType::get(Ptr->getType(), VF));
- VectorGep.push_back(NewGep);
- }
- } else
- VectorGep = getVectorValue(Ptr);
+ VectorGep = getVectorValue(Ptr);
}
VectorParts Mask = createBlockInMask(Instr->getParent());
@@ -3027,7 +3088,7 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr,
// Determine the number of scalars we need to generate for each unroll
// iteration. If the instruction is uniform, we only need to generate the
// first lane. Otherwise, we generate all VF values.
- unsigned Lanes = Legal->isUniformAfterVectorization(Instr) ? 1 : VF;
+ unsigned Lanes = Cost->isUniformAfterVectorization(Instr, VF) ? 1 : VF;
// For each vector unroll 'part':
for (unsigned Part = 0; Part < UF; ++Part) {
@@ -3038,7 +3099,9 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr,
// Start if-block.
Value *Cmp = nullptr;
if (IfPredicateInstr) {
- Cmp = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(Lane));
+ Cmp = Cond[Part];
+ if (Cmp->getType()->isVectorTy())
+ Cmp = Builder.CreateExtractElement(Cmp, Builder.getInt32(Lane));
Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp,
ConstantInt::get(Cmp->getType(), 1));
}
@@ -3346,7 +3409,7 @@ void InnerLoopVectorizer::createEmptyLoop() {
// - counts from zero, stepping by one
// - is the size of the widest induction variable type
// then we create a new one.
- OldInduction = Legal->getInduction();
+ OldInduction = Legal->getPrimaryInduction();
Type *IdxTy = Legal->getWidestInductionType();
// Split the single block loop into the two loop structure described above.
@@ -3543,7 +3606,7 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
namespace {
struct CSEDenseMapInfo {
- static bool canHandle(Instruction *I) {
+ static bool canHandle(const Instruction *I) {
return isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) ||
isa<ShuffleVectorInst>(I) || isa<GetElementPtrInst>(I);
}
@@ -3553,12 +3616,12 @@ struct CSEDenseMapInfo {
static inline Instruction *getTombstoneKey() {
return DenseMapInfo<Instruction *>::getTombstoneKey();
}
- static unsigned getHashValue(Instruction *I) {
+ static unsigned getHashValue(const Instruction *I) {
assert(canHandle(I) && "Unknown instruction!");
return hash_combine(I->getOpcode(), hash_combine_range(I->value_op_begin(),
I->value_op_end()));
}
- static bool isEqual(Instruction *LHS, Instruction *RHS) {
+ static bool isEqual(const Instruction *LHS, const Instruction *RHS) {
if (LHS == getEmptyKey() || RHS == getEmptyKey() ||
LHS == getTombstoneKey() || RHS == getTombstoneKey())
return LHS == RHS;
@@ -3589,51 +3652,6 @@ static void cse(BasicBlock *BB) {
}
}
-/// \brief Adds a 'fast' flag to floating point operations.
-static Value *addFastMathFlag(Value *V) {
- if (isa<FPMathOperator>(V)) {
- FastMathFlags Flags;
- Flags.setUnsafeAlgebra();
- cast<Instruction>(V)->setFastMathFlags(Flags);
- }
- return V;
-}
-
-/// \brief Estimate the overhead of scalarizing a value based on its type.
-/// Insert and Extract are set if the result needs to be inserted and/or
-/// extracted from vectors.
-static unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract,
- const TargetTransformInfo &TTI) {
- if (Ty->isVoidTy())
- return 0;
-
- assert(Ty->isVectorTy() && "Can only scalarize vectors");
- unsigned Cost = 0;
-
- for (unsigned I = 0, E = Ty->getVectorNumElements(); I < E; ++I) {
- if (Extract)
- Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, Ty, I);
- if (Insert)
- Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, I);
- }
-
- return Cost;
-}
-
-/// \brief Estimate the overhead of scalarizing an Instruction based on the
-/// types of its operands and return value.
-static unsigned getScalarizationOverhead(SmallVectorImpl<Type *> &OpTys,
- Type *RetTy,
- const TargetTransformInfo &TTI) {
- unsigned ScalarizationCost =
- getScalarizationOverhead(RetTy, true, false, TTI);
-
- for (Type *Ty : OpTys)
- ScalarizationCost += getScalarizationOverhead(Ty, false, true, TTI);
-
- return ScalarizationCost;
-}
-
/// \brief Estimate the overhead of scalarizing an instruction. This is a
/// convenience wrapper for the type-based getScalarizationOverhead API.
static unsigned getScalarizationOverhead(Instruction *I, unsigned VF,
@@ -3641,14 +3659,24 @@ static unsigned getScalarizationOverhead(Instruction *I, unsigned VF,
if (VF == 1)
return 0;
+ unsigned Cost = 0;
Type *RetTy = ToVectorTy(I->getType(), VF);
+ if (!RetTy->isVoidTy() &&
+ (!isa<LoadInst>(I) ||
+ !TTI.supportsEfficientVectorElementLoadStore()))
+ Cost += TTI.getScalarizationOverhead(RetTy, true, false);
- SmallVector<Type *, 4> OpTys;
- unsigned OperandsNum = I->getNumOperands();
- for (unsigned OpInd = 0; OpInd < OperandsNum; ++OpInd)
- OpTys.push_back(ToVectorTy(I->getOperand(OpInd)->getType(), VF));
+ if (CallInst *CI = dyn_cast<CallInst>(I)) {
+ SmallVector<const Value *, 4> Operands(CI->arg_operands());
+ Cost += TTI.getOperandsScalarizationOverhead(Operands, VF);
+ }
+ else if (!isa<StoreInst>(I) ||
+ !TTI.supportsEfficientVectorElementLoadStore()) {
+ SmallVector<const Value *, 4> Operands(I->operand_values());
+ Cost += TTI.getOperandsScalarizationOverhead(Operands, VF);
+ }
- return getScalarizationOverhead(OpTys, RetTy, TTI);
+ return Cost;
}
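// A minimal sketch (not from this patch) of the quantity being estimated above:
// scalarizing an instruction pays roughly one insertelement per result lane (to
// rebuild the vector) and one extractelement per vector-operand lane (to feed
// the scalar copies); the patch defers the actual per-element numbers to the
// TTI hooks and skips halves that the target can load/store element-wise. The
// uniform one-unit-per-element cost here is a stated simplification.
unsigned scalarizationOverheadSketch(unsigned VF, unsigned NumVectorOperands,
                                     bool HasVectorResult) {
  unsigned Cost = 0;
  if (HasVectorResult)
    Cost += VF;                   // inserts for the result lanes
  Cost += NumVectorOperands * VF; // extracts for each operand's lanes
  return Cost;
}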
// Estimate cost of a call instruction CI if it were vectorized with factor VF.
@@ -3681,7 +3709,7 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF,
// Compute costs of unpacking argument values for the scalar calls and
// packing the return values to a vector.
- unsigned ScalarizationCost = getScalarizationOverhead(Tys, RetTy, TTI);
+ unsigned ScalarizationCost = getScalarizationOverhead(CI, VF, TTI);
unsigned Cost = ScalarCallCost * VF + ScalarizationCost;
@@ -3709,16 +3737,12 @@ static unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF,
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
assert(ID && "Expected intrinsic call!");
- Type *RetTy = ToVectorTy(CI->getType(), VF);
- SmallVector<Type *, 4> Tys;
- for (Value *ArgOperand : CI->arg_operands())
- Tys.push_back(ToVectorTy(ArgOperand->getType(), VF));
-
FastMathFlags FMF;
if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
FMF = FPMO->getFastMathFlags();
- return TTI.getIntrinsicInstrCost(ID, RetTy, Tys, FMF);
+ SmallVector<Value *, 4> Operands(CI->arg_operands());
+ return TTI.getIntrinsicInstrCost(ID, CI->getType(), Operands, FMF, VF);
}
static Type *smallestIntegerVectorType(Type *T1, Type *T2) {
@@ -3861,30 +3885,27 @@ void InnerLoopVectorizer::vectorizeLoop() {
// the cost-model.
//
//===------------------------------------------------===//
- Constant *Zero = Builder.getInt32(0);
- // In order to support recurrences we need to be able to vectorize Phi nodes.
- // Phi nodes have cycles, so we need to vectorize them in two stages. First,
- // we create a new vector PHI node with no incoming edges. We use this value
- // when we vectorize all of the instructions that use the PHI. Next, after
- // all of the instructions in the block are complete we add the new incoming
- // edges to the PHI. At this point all of the instructions in the basic block
- // are vectorized, so we can use them to construct the PHI.
- PhiVector PHIsToFix;
-
- // Collect instructions from the original loop that will become trivially
- // dead in the vectorized loop. We don't need to vectorize these
- // instructions.
- collectTriviallyDeadInstructions();
+ // Collect instructions from the original loop that will become trivially dead
+ // in the vectorized loop. We don't need to vectorize these instructions. For
+ // example, original induction update instructions can become dead because we
+ // separately emit induction "steps" when generating code for the new loop.
+ // Similarly, we create a new latch condition when setting up the structure
+ // of the new loop, so the old one can become dead.
+ SmallPtrSet<Instruction *, 4> DeadInstructions;
+ collectTriviallyDeadInstructions(DeadInstructions);
// Scan the loop in a topological order to ensure that defs are vectorized
// before users.
LoopBlocksDFS DFS(OrigLoop);
DFS.perform(LI);
- // Vectorize all of the blocks in the original loop.
+ // Vectorize all instructions in the original loop that will not become
+ // trivially dead when vectorized.
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO()))
- vectorizeBlockInLoop(BB, &PHIsToFix);
+ for (Instruction &I : *BB)
+ if (!DeadInstructions.count(&I))
+ vectorizeInstruction(I);
// Insert truncates and extends for any truncated instructions as hints to
// InstCombine.
@@ -3892,224 +3913,10 @@ void InnerLoopVectorizer::vectorizeLoop() {
truncateToMinimalBitwidths();
// At this point every instruction in the original loop is widened to a
- // vector form. Now we need to fix the recurrences in PHIsToFix. These PHI
+ // vector form. Now we need to fix the recurrences in the loop. These PHI
// nodes are currently empty because we did not want to introduce cycles.
// This is the second stage of vectorizing recurrences.
- for (PHINode *Phi : PHIsToFix) {
- assert(Phi && "Unable to recover vectorized PHI");
-
- // Handle first-order recurrences that need to be fixed.
- if (Legal->isFirstOrderRecurrence(Phi)) {
- fixFirstOrderRecurrence(Phi);
- continue;
- }
-
- // If the phi node is not a first-order recurrence, it must be a reduction.
- // Get it's reduction variable descriptor.
- assert(Legal->isReductionVariable(Phi) &&
- "Unable to find the reduction variable");
- RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi];
-
- RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind();
- TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue();
- Instruction *LoopExitInst = RdxDesc.getLoopExitInstr();
- RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind =
- RdxDesc.getMinMaxRecurrenceKind();
- setDebugLocFromInst(Builder, ReductionStartValue);
-
- // We need to generate a reduction vector from the incoming scalar.
- // To do so, we need to generate the 'identity' vector and override
- // one of the elements with the incoming scalar reduction. We need
- // to do it in the vector-loop preheader.
- Builder.SetInsertPoint(LoopBypassBlocks[1]->getTerminator());
-
- // This is the vector-clone of the value that leaves the loop.
- const VectorParts &VectorExit = getVectorValue(LoopExitInst);
- Type *VecTy = VectorExit[0]->getType();
-
- // Find the reduction identity variable. Zero for addition, or, xor,
- // one for multiplication, -1 for And.
- Value *Identity;
- Value *VectorStart;
- if (RK == RecurrenceDescriptor::RK_IntegerMinMax ||
- RK == RecurrenceDescriptor::RK_FloatMinMax) {
- // MinMax reduction have the start value as their identify.
- if (VF == 1) {
- VectorStart = Identity = ReductionStartValue;
- } else {
- VectorStart = Identity =
- Builder.CreateVectorSplat(VF, ReductionStartValue, "minmax.ident");
- }
- } else {
- // Handle other reduction kinds:
- Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity(
- RK, VecTy->getScalarType());
- if (VF == 1) {
- Identity = Iden;
- // This vector is the Identity vector where the first element is the
- // incoming scalar reduction.
- VectorStart = ReductionStartValue;
- } else {
- Identity = ConstantVector::getSplat(VF, Iden);
-
- // This vector is the Identity vector where the first element is the
- // incoming scalar reduction.
- VectorStart =
- Builder.CreateInsertElement(Identity, ReductionStartValue, Zero);
- }
- }
-
- // Fix the vector-loop phi.
-
- // Reductions do not have to start at zero. They can start with
- // any loop invariant values.
- const VectorParts &VecRdxPhi = getVectorValue(Phi);
- BasicBlock *Latch = OrigLoop->getLoopLatch();
- Value *LoopVal = Phi->getIncomingValueForBlock(Latch);
- const VectorParts &Val = getVectorValue(LoopVal);
- for (unsigned part = 0; part < UF; ++part) {
- // Make sure to add the reduction stat value only to the
- // first unroll part.
- Value *StartVal = (part == 0) ? VectorStart : Identity;
- cast<PHINode>(VecRdxPhi[part])
- ->addIncoming(StartVal, LoopVectorPreHeader);
- cast<PHINode>(VecRdxPhi[part])
- ->addIncoming(Val[part], LoopVectorBody);
- }
-
- // Before each round, move the insertion point right between
- // the PHIs and the values we are going to write.
- // This allows us to write both PHINodes and the extractelement
- // instructions.
- Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt());
-
- VectorParts &RdxParts = VectorLoopValueMap.getVector(LoopExitInst);
- setDebugLocFromInst(Builder, LoopExitInst);
-
- // If the vector reduction can be performed in a smaller type, we truncate
- // then extend the loop exit value to enable InstCombine to evaluate the
- // entire expression in the smaller type.
- if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) {
- Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
- Builder.SetInsertPoint(LoopVectorBody->getTerminator());
- for (unsigned part = 0; part < UF; ++part) {
- Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy);
- Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy)
- : Builder.CreateZExt(Trunc, VecTy);
- for (Value::user_iterator UI = RdxParts[part]->user_begin();
- UI != RdxParts[part]->user_end();)
- if (*UI != Trunc) {
- (*UI++)->replaceUsesOfWith(RdxParts[part], Extnd);
- RdxParts[part] = Extnd;
- } else {
- ++UI;
- }
- }
- Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt());
- for (unsigned part = 0; part < UF; ++part)
- RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy);
- }
-
- // Reduce all of the unrolled parts into a single vector.
- Value *ReducedPartRdx = RdxParts[0];
- unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK);
- setDebugLocFromInst(Builder, ReducedPartRdx);
- for (unsigned part = 1; part < UF; ++part) {
- if (Op != Instruction::ICmp && Op != Instruction::FCmp)
- // Floating point operations had to be 'fast' to enable the reduction.
- ReducedPartRdx = addFastMathFlag(
- Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part],
- ReducedPartRdx, "bin.rdx"));
- else
- ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp(
- Builder, MinMaxKind, ReducedPartRdx, RdxParts[part]);
- }
-
- if (VF > 1) {
- // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
- // and vector ops, reducing the set of values being computed by half each
- // round.
- assert(isPowerOf2_32(VF) &&
- "Reduction emission only supported for pow2 vectors!");
- Value *TmpVec = ReducedPartRdx;
- SmallVector<Constant *, 32> ShuffleMask(VF, nullptr);
- for (unsigned i = VF; i != 1; i >>= 1) {
- // Move the upper half of the vector to the lower half.
- for (unsigned j = 0; j != i / 2; ++j)
- ShuffleMask[j] = Builder.getInt32(i / 2 + j);
-
- // Fill the rest of the mask with undef.
- std::fill(&ShuffleMask[i / 2], ShuffleMask.end(),
- UndefValue::get(Builder.getInt32Ty()));
-
- Value *Shuf = Builder.CreateShuffleVector(
- TmpVec, UndefValue::get(TmpVec->getType()),
- ConstantVector::get(ShuffleMask), "rdx.shuf");
-
- if (Op != Instruction::ICmp && Op != Instruction::FCmp)
- // Floating point operations had to be 'fast' to enable the reduction.
- TmpVec = addFastMathFlag(Builder.CreateBinOp(
- (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx"));
- else
- TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind,
- TmpVec, Shuf);
- }
-
- // The result is in the first element of the vector.
- ReducedPartRdx =
- Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
-
- // If the reduction can be performed in a smaller type, we need to extend
- // the reduction to the wider type before we branch to the original loop.
- if (Phi->getType() != RdxDesc.getRecurrenceType())
- ReducedPartRdx =
- RdxDesc.isSigned()
- ? Builder.CreateSExt(ReducedPartRdx, Phi->getType())
- : Builder.CreateZExt(ReducedPartRdx, Phi->getType());
- }
-
- // Create a phi node that merges control-flow from the backedge-taken check
- // block and the middle block.
- PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx",
- LoopScalarPreHeader->getTerminator());
- for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I)
- BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]);
- BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock);
-
- // Now, we need to fix the users of the reduction variable
- // inside and outside of the scalar remainder loop.
- // We know that the loop is in LCSSA form. We need to update the
- // PHI nodes in the exit blocks.
- for (BasicBlock::iterator LEI = LoopExitBlock->begin(),
- LEE = LoopExitBlock->end();
- LEI != LEE; ++LEI) {
- PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI);
- if (!LCSSAPhi)
- break;
-
- // All PHINodes need to have a single entry edge, or two if
- // we already fixed them.
- assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI");
-
- // We found our reduction value exit-PHI. Update it with the
- // incoming bypass edge.
- if (LCSSAPhi->getIncomingValue(0) == LoopExitInst) {
- // Add an edge coming from the bypass.
- LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock);
- break;
- }
- } // end of the LCSSA phi scan.
-
- // Fix the scalar loop reduction variable with the incoming reduction sum
- // from the vector body and from the backedge value.
- int IncomingEdgeBlockIdx =
- Phi->getBasicBlockIndex(OrigLoop->getLoopLatch());
- assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index");
- // Pick the other block.
- int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1);
- Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi);
- Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst);
- } // end of for each Phi in PHIsToFix.
+ fixCrossIterationPHIs();
// Update the dominator tree.
//
@@ -4134,6 +3941,25 @@ void InnerLoopVectorizer::vectorizeLoop() {
cse(LoopVectorBody);
}
+void InnerLoopVectorizer::fixCrossIterationPHIs() {
+ // In order to support recurrences we need to be able to vectorize Phi nodes.
+ // Phi nodes have cycles, so we need to vectorize them in two stages. This is
+ // stage #2: We now need to fix the recurrences by adding incoming edges to
+ // the currently empty PHI nodes. At this point every instruction in the
+ // original loop is widened to a vector form so we can use them to construct
+ // the incoming edges.
+ for (Instruction &I : *OrigLoop->getHeader()) {
+ PHINode *Phi = dyn_cast<PHINode>(&I);
+ if (!Phi)
+ break;
+ // Handle first-order recurrences and reductions that need to be fixed.
+ if (Legal->isFirstOrderRecurrence(Phi))
+ fixFirstOrderRecurrence(Phi);
+ else if (Legal->isReductionVariable(Phi))
+ fixReduction(Phi);
+ }
+}
+
void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {
// This is the second phase of vectorizing first-order recurrences. An
@@ -4212,15 +4038,17 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {
auto *VecPhi = Builder.CreatePHI(VectorInit->getType(), 2, "vector.recur");
VecPhi->addIncoming(VectorInit, LoopVectorPreHeader);
- // Get the vectorized previous value. We ensured the previous values was an
- // instruction when detecting the recurrence.
+ // Get the vectorized previous value.
auto &PreviousParts = getVectorValue(Previous);
- // Set the insertion point to be after this instruction. We ensured the
- // previous value dominated all uses of the phi when detecting the
- // recurrence.
- Builder.SetInsertPoint(
- &*++BasicBlock::iterator(cast<Instruction>(PreviousParts[UF - 1])));
+ // Set the insertion point after the previous value if it is an instruction.
+ // Note that the previous value may have been constant-folded so it is not
+ // guaranteed to be an instruction in the vector loop.
+ if (LI->getLoopFor(LoopVectorBody)->isLoopInvariant(PreviousParts[UF - 1]))
+ Builder.SetInsertPoint(&*LoopVectorBody->getFirstInsertionPt());
+ else
+ Builder.SetInsertPoint(
+ &*++BasicBlock::iterator(cast<Instruction>(PreviousParts[UF - 1])));
// We will construct a vector for the recurrence by combining the values for
// the current and previous iterations. This is the required shuffle mask.
@@ -4251,18 +4079,33 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {
// Extract the last vector element in the middle block. This will be the
// initial value for the recurrence when jumping to the scalar loop.
- auto *Extract = Incoming;
+ auto *ExtractForScalar = Incoming;
if (VF > 1) {
Builder.SetInsertPoint(LoopMiddleBlock->getTerminator());
- Extract = Builder.CreateExtractElement(Extract, Builder.getInt32(VF - 1),
- "vector.recur.extract");
- }
+ ExtractForScalar = Builder.CreateExtractElement(
+ ExtractForScalar, Builder.getInt32(VF - 1), "vector.recur.extract");
+ }
+ // Extract the second-to-last element in the middle block if the
+ // Phi is used outside the loop. We need to extract the phi itself
+ // and not the last element (the phi update in the current iteration). This
+ // will be the value when jumping to the exit block from the LoopMiddleBlock,
+ // when the scalar loop is not run at all.
+ Value *ExtractForPhiUsedOutsideLoop = nullptr;
+ if (VF > 1)
+ ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement(
+ Incoming, Builder.getInt32(VF - 2), "vector.recur.extract.for.phi");
+ // When the loop is unrolled without vectorizing, initialize
+ // ExtractForPhiUsedOutsideLoop with the value just prior to the unrolled
+ // value of `Incoming`. This is analogous to the vectorized case above:
+ // extracting the second-to-last element when VF > 1.
+ else if (UF > 1)
+ ExtractForPhiUsedOutsideLoop = PreviousParts[UF - 2];
// Fix the initial value of the original recurrence in the scalar loop.
Builder.SetInsertPoint(&*LoopScalarPreHeader->begin());
auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init");
for (auto *BB : predecessors(LoopScalarPreHeader)) {
- auto *Incoming = BB == LoopMiddleBlock ? Extract : ScalarInit;
+ auto *Incoming = BB == LoopMiddleBlock ? ExtractForScalar : ScalarInit;
Start->addIncoming(Incoming, BB);
}
@@ -4279,12 +4122,218 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) {
if (!LCSSAPhi)
break;
if (LCSSAPhi->getIncomingValue(0) == Phi) {
- LCSSAPhi->addIncoming(Extract, LoopMiddleBlock);
+ LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock);
break;
}
}
}
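// A minimal sketch (not from this patch) of the vector built for a first-order
// recurrence above: the last lane of the vector carrying the previous
// iteration's values is prepended to the first VF-1 lanes of the current
// vector, i.e. a shuffle with a mask of the form <VF-1, VF, ..., 2*VF-2>.
// Written here for plain arrays; the name is illustrative only.
#include <vector>

std::vector<long> recurrenceShuffleSketch(const std::vector<long> &Previous,
                                          const std::vector<long> &Current) {
  const unsigned VF = static_cast<unsigned>(Previous.size());
  std::vector<long> Result;
  Result.push_back(Previous[VF - 1]);  // value from the previous iteration
  for (unsigned I = 0; I + 1 < VF; ++I)
    Result.push_back(Current[I]);      // first VF-1 values of this iteration
  return Result;
}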
+void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
+ Constant *Zero = Builder.getInt32(0);
+
+ // Get its reduction variable descriptor.
+ assert(Legal->isReductionVariable(Phi) &&
+ "Unable to find the reduction variable");
+ RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi];
+
+ RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind();
+ TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue();
+ Instruction *LoopExitInst = RdxDesc.getLoopExitInstr();
+ RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind =
+ RdxDesc.getMinMaxRecurrenceKind();
+ setDebugLocFromInst(Builder, ReductionStartValue);
+
+ // We need to generate a reduction vector from the incoming scalar.
+ // To do so, we need to generate the 'identity' vector and override
+ // one of the elements with the incoming scalar reduction. We need
+ // to do it in the vector-loop preheader.
+ Builder.SetInsertPoint(LoopBypassBlocks[1]->getTerminator());
+
+ // This is the vector-clone of the value that leaves the loop.
+ const VectorParts &VectorExit = getVectorValue(LoopExitInst);
+ Type *VecTy = VectorExit[0]->getType();
+
+ // Find the reduction identity variable. Zero for addition, or, xor,
+ // one for multiplication, -1 for And.
+ Value *Identity;
+ Value *VectorStart;
+ if (RK == RecurrenceDescriptor::RK_IntegerMinMax ||
+ RK == RecurrenceDescriptor::RK_FloatMinMax) {
+ // MinMax reductions have the start value as their identity.
+ if (VF == 1) {
+ VectorStart = Identity = ReductionStartValue;
+ } else {
+ VectorStart = Identity =
+ Builder.CreateVectorSplat(VF, ReductionStartValue, "minmax.ident");
+ }
+ } else {
+ // Handle other reduction kinds:
+ Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity(
+ RK, VecTy->getScalarType());
+ if (VF == 1) {
+ Identity = Iden;
+ // This vector is the Identity vector where the first element is the
+ // incoming scalar reduction.
+ VectorStart = ReductionStartValue;
+ } else {
+ Identity = ConstantVector::getSplat(VF, Iden);
+
+ // This vector is the Identity vector where the first element is the
+ // incoming scalar reduction.
+ VectorStart =
+ Builder.CreateInsertElement(Identity, ReductionStartValue, Zero);
+ }
+ }
+
+ // Fix the vector-loop phi.
+
+ // Reductions do not have to start at zero. They can start with
+ // any loop invariant values.
+ const VectorParts &VecRdxPhi = getVectorValue(Phi);
+ BasicBlock *Latch = OrigLoop->getLoopLatch();
+ Value *LoopVal = Phi->getIncomingValueForBlock(Latch);
+ const VectorParts &Val = getVectorValue(LoopVal);
+ for (unsigned part = 0; part < UF; ++part) {
+ // Make sure to add the reduction start value only to the
+ // first unroll part.
+ Value *StartVal = (part == 0) ? VectorStart : Identity;
+ cast<PHINode>(VecRdxPhi[part])
+ ->addIncoming(StartVal, LoopVectorPreHeader);
+ cast<PHINode>(VecRdxPhi[part])
+ ->addIncoming(Val[part], LI->getLoopFor(LoopVectorBody)->getLoopLatch());
+ }
+
+ // Before each round, move the insertion point right between
+ // the PHIs and the values we are going to write.
+ // This allows us to write both PHINodes and the extractelement
+ // instructions.
+ Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt());
+
+ VectorParts &RdxParts = VectorLoopValueMap.getVector(LoopExitInst);
+ setDebugLocFromInst(Builder, LoopExitInst);
+
+ // If the vector reduction can be performed in a smaller type, we truncate
+ // then extend the loop exit value to enable InstCombine to evaluate the
+ // entire expression in the smaller type.
+ if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) {
+ Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
+ Builder.SetInsertPoint(LoopVectorBody->getTerminator());
+ for (unsigned part = 0; part < UF; ++part) {
+ Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy);
+ Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy)
+ : Builder.CreateZExt(Trunc, VecTy);
+ for (Value::user_iterator UI = RdxParts[part]->user_begin();
+ UI != RdxParts[part]->user_end();)
+ if (*UI != Trunc) {
+ (*UI++)->replaceUsesOfWith(RdxParts[part], Extnd);
+ RdxParts[part] = Extnd;
+ } else {
+ ++UI;
+ }
+ }
+ Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt());
+ for (unsigned part = 0; part < UF; ++part)
+ RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy);
+ }
+
+ // Reduce all of the unrolled parts into a single vector.
+ Value *ReducedPartRdx = RdxParts[0];
+ unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK);
+ setDebugLocFromInst(Builder, ReducedPartRdx);
+ for (unsigned part = 1; part < UF; ++part) {
+ if (Op != Instruction::ICmp && Op != Instruction::FCmp)
+ // Floating point operations had to be 'fast' to enable the reduction.
+ ReducedPartRdx = addFastMathFlag(
+ Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part],
+ ReducedPartRdx, "bin.rdx"));
+ else
+ ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp(
+ Builder, MinMaxKind, ReducedPartRdx, RdxParts[part]);
+ }
+
+ if (VF > 1) {
+ // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
+ // and vector ops, reducing the set of values being computed by half each
+ // round.
+ assert(isPowerOf2_32(VF) &&
+ "Reduction emission only supported for pow2 vectors!");
+ Value *TmpVec = ReducedPartRdx;
+ SmallVector<Constant *, 32> ShuffleMask(VF, nullptr);
+ for (unsigned i = VF; i != 1; i >>= 1) {
+ // Move the upper half of the vector to the lower half.
+ for (unsigned j = 0; j != i / 2; ++j)
+ ShuffleMask[j] = Builder.getInt32(i / 2 + j);
+
+ // Fill the rest of the mask with undef.
+ std::fill(&ShuffleMask[i / 2], ShuffleMask.end(),
+ UndefValue::get(Builder.getInt32Ty()));
+
+ Value *Shuf = Builder.CreateShuffleVector(
+ TmpVec, UndefValue::get(TmpVec->getType()),
+ ConstantVector::get(ShuffleMask), "rdx.shuf");
+
+ if (Op != Instruction::ICmp && Op != Instruction::FCmp)
+ // Floating point operations had to be 'fast' to enable the reduction.
+ TmpVec = addFastMathFlag(Builder.CreateBinOp(
+ (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx"));
+ else
+ TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind,
+ TmpVec, Shuf);
+ }
+
+ // The result is in the first element of the vector.
+ ReducedPartRdx =
+ Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
+
+ // If the reduction can be performed in a smaller type, we need to extend
+ // the reduction to the wider type before we branch to the original loop.
+ if (Phi->getType() != RdxDesc.getRecurrenceType())
+ ReducedPartRdx =
+ RdxDesc.isSigned()
+ ? Builder.CreateSExt(ReducedPartRdx, Phi->getType())
+ : Builder.CreateZExt(ReducedPartRdx, Phi->getType());
+ }
+
+ // Create a phi node that merges control-flow from the backedge-taken check
+ // block and the middle block.
+ PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx",
+ LoopScalarPreHeader->getTerminator());
+ for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I)
+ BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]);
+ BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock);
+
+ // Now, we need to fix the users of the reduction variable
+ // inside and outside of the scalar remainder loop.
+ // We know that the loop is in LCSSA form. We need to update the
+ // PHI nodes in the exit blocks.
+ for (BasicBlock::iterator LEI = LoopExitBlock->begin(),
+ LEE = LoopExitBlock->end();
+ LEI != LEE; ++LEI) {
+ PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI);
+ if (!LCSSAPhi)
+ break;
+
+ // All PHINodes need to have a single entry edge, or two if
+ // we already fixed them.
+ assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI");
+
+ // We found a reduction value exit-PHI. Update it with the
+ // incoming bypass edge.
+ if (LCSSAPhi->getIncomingValue(0) == LoopExitInst)
+ LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock);
+ } // end of the LCSSA phi scan.
+
+ // Fix the scalar loop reduction variable with the incoming reduction sum
+ // from the vector body and from the backedge value.
+ int IncomingEdgeBlockIdx =
+ Phi->getBasicBlockIndex(OrigLoop->getLoopLatch());
+ assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index");
+ // Pick the other block.
+ int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1);
+ Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi);
+ Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst);
+}
+
void InnerLoopVectorizer::fixLCSSAPHIs() {
for (Instruction &LEI : *LoopExitBlock) {
auto *LCSSAPhi = dyn_cast<PHINode>(&LEI);
@@ -4296,7 +4345,8 @@ void InnerLoopVectorizer::fixLCSSAPHIs() {
}
}
-void InnerLoopVectorizer::collectTriviallyDeadInstructions() {
+void InnerLoopVectorizer::collectTriviallyDeadInstructions(
+ SmallPtrSetImpl<Instruction *> &DeadInstructions) {
BasicBlock *Latch = OrigLoop->getLoopLatch();
// We create new control-flow for the vectorized loop, so the original
@@ -4563,9 +4613,12 @@ InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) {
}
void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,
- unsigned VF, PhiVector *PV) {
+ unsigned VF) {
PHINode *P = cast<PHINode>(PN);
- // Handle recurrences.
+ // In order to support recurrences we need to be able to vectorize PHI nodes.
+ // PHI nodes can participate in cycles, so we vectorize them in two stages.
+ // This is stage #1: we create a new vector PHI node with no incoming edges.
+ // We'll use this value when we vectorize all of the instructions that use the PHI.
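+ // Stage #2, performed after the rest of the loop body has been vectorized,
+ // fills in the incoming values of these vector PHIs (e.g. in the reduction
+ // fix-up shown earlier in this patch), closing the cycle.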
if (Legal->isReductionVariable(P) || Legal->isFirstOrderRecurrence(P)) {
VectorParts Entry(UF);
for (unsigned part = 0; part < UF; ++part) {
@@ -4576,7 +4629,6 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,
VecTy, 2, "vec.phi", &*LoopVectorBody->getFirstInsertionPt());
}
VectorLoopValueMap.initVector(P, Entry);
- PV->push_back(P);
return;
}
@@ -4631,7 +4683,8 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,
case InductionDescriptor::IK_NoInduction:
llvm_unreachable("Unknown induction");
case InductionDescriptor::IK_IntInduction:
- return widenIntInduction(P);
+ case InductionDescriptor::IK_FpInduction:
+ return widenIntOrFpInduction(P);
case InductionDescriptor::IK_PtrInduction: {
// Handle the pointer induction variable case.
assert(P->getType()->isPointerTy() && "Unexpected type.");
@@ -4641,7 +4694,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,
// Determine the number of scalars we need to generate for each unroll
// iteration. If the instruction is uniform, we only need to generate the
// first lane. Otherwise, we generate all VF values.
- unsigned Lanes = Legal->isUniformAfterVectorization(P) ? 1 : VF;
+ unsigned Lanes = Cost->isUniformAfterVectorization(P, VF) ? 1 : VF;
// These are the scalar results. Notice that we don't generate vector GEPs
// because scalar GEPs result in better code.
ScalarParts Entry(UF);
@@ -4658,30 +4711,6 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF,
VectorLoopValueMap.initScalar(P, Entry);
return;
}
- case InductionDescriptor::IK_FpInduction: {
- assert(P->getType() == II.getStartValue()->getType() &&
- "Types must match");
- // Handle other induction variables that are now based on the
- // canonical one.
- assert(P != OldInduction && "Primary induction can be integer only");
-
- Value *V = Builder.CreateCast(Instruction::SIToFP, Induction, P->getType());
- V = II.transform(Builder, V, PSE.getSE(), DL);
- V->setName("fp.offset.idx");
-
- // Now we have scalar op: %fp.offset.idx = StartVal +/- Induction*StepVal
-
- Value *Broadcasted = getBroadcastInstrs(V);
- // After broadcasting the induction variable we need to make the vector
- // consecutive by adding StepVal*0, StepVal*1, StepVal*2, etc.
- Value *StepVal = cast<SCEVUnknown>(II.getStep())->getValue();
- VectorParts Entry(UF);
- for (unsigned part = 0; part < UF; ++part)
- Entry[part] = getStepVector(Broadcasted, VF * part, StepVal,
- II.getInductionOpcode());
- VectorLoopValueMap.initVector(P, Entry);
- return;
- }
}
}
@@ -4703,269 +4732,323 @@ static bool mayDivideByZero(Instruction &I) {
return !CInt || CInt->isZero();
}
-void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
- // For each instruction in the old loop.
- for (Instruction &I : *BB) {
-
- // If the instruction will become trivially dead when vectorized, we don't
- // need to generate it.
- if (DeadInstructions.count(&I))
- continue;
+void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) {
+ // Scalarize instructions that should remain scalar after vectorization.
+ if (VF > 1 &&
+ !(isa<BranchInst>(&I) || isa<PHINode>(&I) || isa<DbgInfoIntrinsic>(&I)) &&
+ shouldScalarizeInstruction(&I)) {
+ scalarizeInstruction(&I, Legal->isScalarWithPredication(&I));
+ return;
+ }
- // Scalarize instructions that should remain scalar after vectorization.
- if (VF > 1 &&
- !(isa<BranchInst>(&I) || isa<PHINode>(&I) ||
- isa<DbgInfoIntrinsic>(&I)) &&
- shouldScalarizeInstruction(&I)) {
- scalarizeInstruction(&I, Legal->isScalarWithPredication(&I));
- continue;
- }
+ switch (I.getOpcode()) {
+ case Instruction::Br:
+ // Nothing to do for PHIs and BR, since we already took care of the
+ // loop control flow instructions.
+ break;
+ case Instruction::PHI: {
+ // Vectorize PHINodes.
+ widenPHIInstruction(&I, UF, VF);
+ break;
+ } // End of PHI.
+ case Instruction::GetElementPtr: {
+ // Construct a vector GEP by widening the operands of the scalar GEP as
+ // necessary. We mark the vector GEP 'inbounds' if appropriate. A GEP
+ // results in a vector of pointers when at least one operand of the GEP
+ // is vector-typed. Thus, to keep the representation compact, we only use
+ // vector-typed operands for loop-varying values.
+ auto *GEP = cast<GetElementPtrInst>(&I);
+ VectorParts Entry(UF);
- switch (I.getOpcode()) {
- case Instruction::Br:
- // Nothing to do for PHIs and BR, since we already took care of the
- // loop control flow instructions.
- continue;
- case Instruction::PHI: {
- // Vectorize PHINodes.
- widenPHIInstruction(&I, UF, VF, PV);
- continue;
- } // End of PHI.
-
- case Instruction::UDiv:
- case Instruction::SDiv:
- case Instruction::SRem:
- case Instruction::URem:
- // Scalarize with predication if this instruction may divide by zero and
- // block execution is conditional, otherwise fallthrough.
- if (Legal->isScalarWithPredication(&I)) {
- scalarizeInstruction(&I, true);
- continue;
- }
- case Instruction::Add:
- case Instruction::FAdd:
- case Instruction::Sub:
- case Instruction::FSub:
- case Instruction::Mul:
- case Instruction::FMul:
- case Instruction::FDiv:
- case Instruction::FRem:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- case Instruction::And:
- case Instruction::Or:
- case Instruction::Xor: {
- // Just widen binops.
- auto *BinOp = cast<BinaryOperator>(&I);
- setDebugLocFromInst(Builder, BinOp);
- const VectorParts &A = getVectorValue(BinOp->getOperand(0));
- const VectorParts &B = getVectorValue(BinOp->getOperand(1));
-
- // Use this vector value for all users of the original instruction.
- VectorParts Entry(UF);
+ if (VF > 1 && OrigLoop->hasLoopInvariantOperands(GEP)) {
+ // If we are vectorizing, but the GEP has only loop-invariant operands,
+ // the GEP we build (by only using vector-typed operands for
+ // loop-varying values) would be a scalar pointer. Thus, to ensure we
+ // produce a vector of pointers, we need to either arbitrarily pick an
+ // operand to broadcast, or broadcast a clone of the original GEP.
+ // Here, we broadcast a clone of the original.
+ //
+ // TODO: If at some point we decide to scalarize instructions having
+ // loop-invariant operands, this special case will no longer be
+ // required. We would add the scalarization decision to
+ // collectLoopScalars() and teach getVectorValue() to broadcast
+ // the lane-zero scalar value.
+ auto *Clone = Builder.Insert(GEP->clone());
+ for (unsigned Part = 0; Part < UF; ++Part)
+ Entry[Part] = Builder.CreateVectorSplat(VF, Clone);
+ } else {
+ // If the GEP has at least one loop-varying operand, we are sure to
+ // produce a vector of pointers. But if we are only unrolling, we want
+ // to produce a scalar GEP for each unroll part. Thus, the GEP we
+ // produce with the code below will be scalar (if VF == 1) or vector
+ // (otherwise). Note that for the unroll-only case, we still maintain
+ // values in the vector mapping with initVector, as we do for other
+ // instructions.
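+ // For example, for 'getelementptr inbounds i32, i32* %base, i64 %i' with a
+ // loop-invariant %base and loop-varying %i, each unroll part gets a GEP whose
+ // index operand is the corresponding vector value of %i, yielding a
+ // <VF x i32*> result when VF > 1.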
for (unsigned Part = 0; Part < UF; ++Part) {
- Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]);
- if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V))
- VecOp->copyIRFlags(BinOp);
+ // The pointer operand of the new GEP. If it's loop-invariant, we
+ // won't broadcast it.
+ auto *Ptr = OrigLoop->isLoopInvariant(GEP->getPointerOperand())
+ ? GEP->getPointerOperand()
+ : getVectorValue(GEP->getPointerOperand())[Part];
+
+ // Collect all the indices for the new GEP. If any index is
+ // loop-invariant, we won't broadcast it.
+ SmallVector<Value *, 4> Indices;
+ for (auto &U : make_range(GEP->idx_begin(), GEP->idx_end())) {
+ if (OrigLoop->isLoopInvariant(U.get()))
+ Indices.push_back(U.get());
+ else
+ Indices.push_back(getVectorValue(U.get())[Part]);
+ }
- Entry[Part] = V;
+ // Create the new GEP. Note that this GEP may be a scalar if VF == 1,
+ // but it should be a vector, otherwise.
+ auto *NewGEP = GEP->isInBounds()
+ ? Builder.CreateInBoundsGEP(Ptr, Indices)
+ : Builder.CreateGEP(Ptr, Indices);
+ assert((VF == 1 || NewGEP->getType()->isVectorTy()) &&
+ "NewGEP is not a pointer vector");
+ Entry[Part] = NewGEP;
}
+ }
- VectorLoopValueMap.initVector(&I, Entry);
- addMetadata(Entry, BinOp);
+ VectorLoopValueMap.initVector(&I, Entry);
+ addMetadata(Entry, GEP);
+ break;
+ }
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::SRem:
+ case Instruction::URem:
+ // Scalarize with predication if this instruction may divide by zero and
+ // block execution is conditional, otherwise fallthrough.
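+ // For example, a 'udiv' that only executes under an 'if' in the loop body
+ // could divide by zero in lanes where the branch is not taken if it were
+ // widened unconditionally, so it is scalarized and predicated instead.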
+ if (Legal->isScalarWithPredication(&I)) {
+ scalarizeInstruction(&I, true);
break;
}
- case Instruction::Select: {
- // Widen selects.
- // If the selector is loop invariant we can create a select
- // instruction with a scalar condition. Otherwise, use vector-select.
- auto *SE = PSE.getSE();
- bool InvariantCond =
- SE->isLoopInvariant(PSE.getSCEV(I.getOperand(0)), OrigLoop);
- setDebugLocFromInst(Builder, &I);
-
- // The condition can be loop invariant but still defined inside the
- // loop. This means that we can't just use the original 'cond' value.
- // We have to take the 'vectorized' value and pick the first lane.
- // Instcombine will make this a no-op.
- const VectorParts &Cond = getVectorValue(I.getOperand(0));
- const VectorParts &Op0 = getVectorValue(I.getOperand(1));
- const VectorParts &Op1 = getVectorValue(I.getOperand(2));
-
- auto *ScalarCond = getScalarValue(I.getOperand(0), 0, 0);
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::FDiv:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor: {
+ // Just widen binops.
+ auto *BinOp = cast<BinaryOperator>(&I);
+ setDebugLocFromInst(Builder, BinOp);
+ const VectorParts &A = getVectorValue(BinOp->getOperand(0));
+ const VectorParts &B = getVectorValue(BinOp->getOperand(1));
+
+ // Use this vector value for all users of the original instruction.
+ VectorParts Entry(UF);
+ for (unsigned Part = 0; Part < UF; ++Part) {
+ Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]);
+
+ if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V))
+ VecOp->copyIRFlags(BinOp);
+
+ Entry[Part] = V;
+ }
- VectorParts Entry(UF);
- for (unsigned Part = 0; Part < UF; ++Part) {
- Entry[Part] = Builder.CreateSelect(
- InvariantCond ? ScalarCond : Cond[Part], Op0[Part], Op1[Part]);
- }
+ VectorLoopValueMap.initVector(&I, Entry);
+ addMetadata(Entry, BinOp);
+ break;
+ }
+ case Instruction::Select: {
+ // Widen selects.
+ // If the selector is loop invariant we can create a select
+ // instruction with a scalar condition. Otherwise, use vector-select.
+ auto *SE = PSE.getSE();
+ bool InvariantCond =
+ SE->isLoopInvariant(PSE.getSCEV(I.getOperand(0)), OrigLoop);
+ setDebugLocFromInst(Builder, &I);
+
+ // The condition can be loop invariant but still defined inside the
+ // loop. This means that we can't just use the original 'cond' value.
+ // We have to take the 'vectorized' value and pick the first lane.
+ // Instcombine will make this a no-op.
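+ // For example, a compare of two loop-invariant operands that happens to be
+ // emitted inside the loop is itself invariant; we use lane 0 of its
+ // (splatted) vector value as the scalar condition of the widened select.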
+ const VectorParts &Cond = getVectorValue(I.getOperand(0));
+ const VectorParts &Op0 = getVectorValue(I.getOperand(1));
+ const VectorParts &Op1 = getVectorValue(I.getOperand(2));
+
+ auto *ScalarCond = getScalarValue(I.getOperand(0), 0, 0);
- VectorLoopValueMap.initVector(&I, Entry);
- addMetadata(Entry, &I);
- break;
+ VectorParts Entry(UF);
+ for (unsigned Part = 0; Part < UF; ++Part) {
+ Entry[Part] = Builder.CreateSelect(
+ InvariantCond ? ScalarCond : Cond[Part], Op0[Part], Op1[Part]);
}
- case Instruction::ICmp:
- case Instruction::FCmp: {
- // Widen compares. Generate vector compares.
- bool FCmp = (I.getOpcode() == Instruction::FCmp);
- auto *Cmp = dyn_cast<CmpInst>(&I);
- setDebugLocFromInst(Builder, Cmp);
- const VectorParts &A = getVectorValue(Cmp->getOperand(0));
- const VectorParts &B = getVectorValue(Cmp->getOperand(1));
- VectorParts Entry(UF);
- for (unsigned Part = 0; Part < UF; ++Part) {
- Value *C = nullptr;
- if (FCmp) {
- C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]);
- cast<FCmpInst>(C)->copyFastMathFlags(Cmp);
- } else {
- C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]);
- }
- Entry[Part] = C;
+ VectorLoopValueMap.initVector(&I, Entry);
+ addMetadata(Entry, &I);
+ break;
+ }
+
+ case Instruction::ICmp:
+ case Instruction::FCmp: {
+ // Widen compares. Generate vector compares.
+ bool FCmp = (I.getOpcode() == Instruction::FCmp);
+ auto *Cmp = dyn_cast<CmpInst>(&I);
+ setDebugLocFromInst(Builder, Cmp);
+ const VectorParts &A = getVectorValue(Cmp->getOperand(0));
+ const VectorParts &B = getVectorValue(Cmp->getOperand(1));
+ VectorParts Entry(UF);
+ for (unsigned Part = 0; Part < UF; ++Part) {
+ Value *C = nullptr;
+ if (FCmp) {
+ C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]);
+ cast<FCmpInst>(C)->copyFastMathFlags(Cmp);
+ } else {
+ C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]);
}
+ Entry[Part] = C;
+ }
+
+ VectorLoopValueMap.initVector(&I, Entry);
+ addMetadata(Entry, &I);
+ break;
+ }
- VectorLoopValueMap.initVector(&I, Entry);
- addMetadata(Entry, &I);
+ case Instruction::Store:
+ case Instruction::Load:
+ vectorizeMemoryInstruction(&I);
+ break;
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::FPExt:
+ case Instruction::PtrToInt:
+ case Instruction::IntToPtr:
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ case Instruction::Trunc:
+ case Instruction::FPTrunc:
+ case Instruction::BitCast: {
+ auto *CI = dyn_cast<CastInst>(&I);
+ setDebugLocFromInst(Builder, CI);
+
+ // Optimize the special case where the source is a constant integer
+ // induction variable. Notice that we can only optimize the 'trunc' case
+ // because (a) FP conversions lose precision, (b) sext/zext may wrap, and
+ // (c) other casts depend on pointer size.
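+ // For example, 'trunc i64 %iv to i32' applied to a widened i64 induction can
+ // instead be generated directly as an i32 vector induction, avoiding a wide
+ // induction plus a truncate of every lane.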
+ if (Cost->isOptimizableIVTruncate(CI, VF)) {
+ widenIntOrFpInduction(cast<PHINode>(CI->getOperand(0)),
+ cast<TruncInst>(CI));
break;
}
- case Instruction::Store:
- case Instruction::Load:
- vectorizeMemoryInstruction(&I);
+ /// Vectorize casts.
+ Type *DestTy =
+ (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF);
+
+ const VectorParts &A = getVectorValue(CI->getOperand(0));
+ VectorParts Entry(UF);
+ for (unsigned Part = 0; Part < UF; ++Part)
+ Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy);
+ VectorLoopValueMap.initVector(&I, Entry);
+ addMetadata(Entry, &I);
+ break;
+ }
+
+ case Instruction::Call: {
+ // Ignore dbg intrinsics.
+ if (isa<DbgInfoIntrinsic>(I))
break;
- case Instruction::ZExt:
- case Instruction::SExt:
- case Instruction::FPToUI:
- case Instruction::FPToSI:
- case Instruction::FPExt:
- case Instruction::PtrToInt:
- case Instruction::IntToPtr:
- case Instruction::SIToFP:
- case Instruction::UIToFP:
- case Instruction::Trunc:
- case Instruction::FPTrunc:
- case Instruction::BitCast: {
- auto *CI = dyn_cast<CastInst>(&I);
- setDebugLocFromInst(Builder, CI);
-
- // Optimize the special case where the source is a constant integer
- // induction variable. Notice that we can only optimize the 'trunc' case
- // because (a) FP conversions lose precision, (b) sext/zext may wrap, and
- // (c) other casts depend on pointer size.
- auto ID = Legal->getInductionVars()->lookup(OldInduction);
- if (isa<TruncInst>(CI) && CI->getOperand(0) == OldInduction &&
- ID.getConstIntStepValue()) {
- widenIntInduction(OldInduction, cast<TruncInst>(CI));
- break;
- }
+ setDebugLocFromInst(Builder, &I);
- /// Vectorize casts.
- Type *DestTy =
- (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF);
+ Module *M = I.getParent()->getParent()->getParent();
+ auto *CI = cast<CallInst>(&I);
- const VectorParts &A = getVectorValue(CI->getOperand(0));
- VectorParts Entry(UF);
- for (unsigned Part = 0; Part < UF; ++Part)
- Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy);
- VectorLoopValueMap.initVector(&I, Entry);
- addMetadata(Entry, &I);
+ StringRef FnName = CI->getCalledFunction()->getName();
+ Function *F = CI->getCalledFunction();
+ Type *RetTy = ToVectorTy(CI->getType(), VF);
+ SmallVector<Type *, 4> Tys;
+ for (Value *ArgOperand : CI->arg_operands())
+ Tys.push_back(ToVectorTy(ArgOperand->getType(), VF));
+
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+ if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end ||
+ ID == Intrinsic::lifetime_start)) {
+ scalarizeInstruction(&I);
+ break;
+ }
+ // The flag below indicates whether the vectorized version of the instruction
+ // should be emitted as an intrinsic call or as an ordinary call, i.e.
+ // whether an intrinsic call would be cheaper than the vector library call.
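+ // For example, if the target has both a vector 'llvm.sqrt' intrinsic and a
+ // vector library version of 'sqrtf' available through TLI, the cheaper of
+ // the two per the cost model is used; if neither is profitable, the call is
+ // scalarized.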
+ bool NeedToScalarize;
+ unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize);
+ bool UseVectorIntrinsic =
+ ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost;
+ if (!UseVectorIntrinsic && NeedToScalarize) {
+ scalarizeInstruction(&I);
break;
}
- case Instruction::Call: {
- // Ignore dbg intrinsics.
- if (isa<DbgInfoIntrinsic>(I))
- break;
- setDebugLocFromInst(Builder, &I);
-
- Module *M = BB->getParent()->getParent();
- auto *CI = cast<CallInst>(&I);
-
- StringRef FnName = CI->getCalledFunction()->getName();
- Function *F = CI->getCalledFunction();
- Type *RetTy = ToVectorTy(CI->getType(), VF);
- SmallVector<Type *, 4> Tys;
- for (Value *ArgOperand : CI->arg_operands())
- Tys.push_back(ToVectorTy(ArgOperand->getType(), VF));
-
- Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
- if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end ||
- ID == Intrinsic::lifetime_start)) {
- scalarizeInstruction(&I);
- break;
- }
- // The flag shows whether we use Intrinsic or a usual Call for vectorized
- // version of the instruction.
- // Is it beneficial to perform intrinsic call compared to lib call?
- bool NeedToScalarize;
- unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize);
- bool UseVectorIntrinsic =
- ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost;
- if (!UseVectorIntrinsic && NeedToScalarize) {
- scalarizeInstruction(&I);
- break;
- }
-
- VectorParts Entry(UF);
- for (unsigned Part = 0; Part < UF; ++Part) {
- SmallVector<Value *, 4> Args;
- for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) {
- Value *Arg = CI->getArgOperand(i);
- // Some intrinsics have a scalar argument - don't replace it with a
- // vector.
- if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) {
- const VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i));
- Arg = VectorArg[Part];
- }
- Args.push_back(Arg);
+ VectorParts Entry(UF);
+ for (unsigned Part = 0; Part < UF; ++Part) {
+ SmallVector<Value *, 4> Args;
+ for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) {
+ Value *Arg = CI->getArgOperand(i);
+ // Some intrinsics have a scalar argument - don't replace it with a
+ // vector.
+ if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) {
+ const VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i));
+ Arg = VectorArg[Part];
}
+ Args.push_back(Arg);
+ }
- Function *VectorF;
- if (UseVectorIntrinsic) {
- // Use vector version of the intrinsic.
- Type *TysForDecl[] = {CI->getType()};
- if (VF > 1)
- TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF);
- VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl);
- } else {
- // Use vector version of the library call.
- StringRef VFnName = TLI->getVectorizedFunction(FnName, VF);
- assert(!VFnName.empty() && "Vector function name is empty.");
- VectorF = M->getFunction(VFnName);
- if (!VectorF) {
- // Generate a declaration
- FunctionType *FTy = FunctionType::get(RetTy, Tys, false);
- VectorF =
- Function::Create(FTy, Function::ExternalLinkage, VFnName, M);
- VectorF->copyAttributesFrom(F);
- }
+ Function *VectorF;
+ if (UseVectorIntrinsic) {
+ // Use vector version of the intrinsic.
+ Type *TysForDecl[] = {CI->getType()};
+ if (VF > 1)
+ TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF);
+ VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl);
+ } else {
+ // Use vector version of the library call.
+ StringRef VFnName = TLI->getVectorizedFunction(FnName, VF);
+ assert(!VFnName.empty() && "Vector function name is empty.");
+ VectorF = M->getFunction(VFnName);
+ if (!VectorF) {
+ // Generate a declaration
+ FunctionType *FTy = FunctionType::get(RetTy, Tys, false);
+ VectorF =
+ Function::Create(FTy, Function::ExternalLinkage, VFnName, M);
+ VectorF->copyAttributesFrom(F);
}
- assert(VectorF && "Can't create vector function.");
-
- SmallVector<OperandBundleDef, 1> OpBundles;
- CI->getOperandBundlesAsDefs(OpBundles);
- CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles);
+ }
+ assert(VectorF && "Can't create vector function.");
- if (isa<FPMathOperator>(V))
- V->copyFastMathFlags(CI);
+ SmallVector<OperandBundleDef, 1> OpBundles;
+ CI->getOperandBundlesAsDefs(OpBundles);
+ CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles);
- Entry[Part] = V;
- }
+ if (isa<FPMathOperator>(V))
+ V->copyFastMathFlags(CI);
- VectorLoopValueMap.initVector(&I, Entry);
- addMetadata(Entry, &I);
- break;
+ Entry[Part] = V;
}
- default:
- // All other instructions are unsupported. Scalarize them.
- scalarizeInstruction(&I);
- break;
- } // end of switch.
- } // end of for_each instr.
+ VectorLoopValueMap.initVector(&I, Entry);
+ addMetadata(Entry, &I);
+ break;
+ }
+
+ default:
+ // All other instructions are unsupported. Scalarize them.
+ scalarizeInstruction(&I);
+ break;
+ } // end of switch.
}
void InnerLoopVectorizer::updateAnalysis() {
@@ -4976,11 +5059,10 @@ void InnerLoopVectorizer::updateAnalysis() {
assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) &&
"Entry does not dominate exit.");
- // We don't predicate stores by this point, so the vector body should be a
- // single loop.
- DT->addNewBlock(LoopVectorBody, LoopVectorPreHeader);
-
- DT->addNewBlock(LoopMiddleBlock, LoopVectorBody);
+ DT->addNewBlock(LI->getLoopFor(LoopVectorBody)->getHeader(),
+ LoopVectorPreHeader);
+ DT->addNewBlock(LoopMiddleBlock,
+ LI->getLoopFor(LoopVectorBody)->getLoopLatch());
DT->addNewBlock(LoopScalarPreHeader, LoopBypassBlocks[0]);
DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader);
DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]);
@@ -5145,12 +5227,6 @@ bool LoopVectorizationLegality::canVectorize() {
if (UseInterleaved)
InterleaveInfo.analyzeInterleaving(*getSymbolicStrides());
- // Collect all instructions that are known to be uniform after vectorization.
- collectLoopUniforms();
-
- // Collect all instructions that are known to be scalar after vectorization.
- collectLoopScalars();
-
unsigned SCEVThreshold = VectorizeSCEVCheckThreshold;
if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
@@ -5234,8 +5310,8 @@ void LoopVectorizationLegality::addInductionPhi(
// one if there are multiple (no good reason for doing this other
// than it is expedient). We've checked that it begins at zero and
// steps by one, so this is a canonical induction variable.
- if (!Induction || PhiTy == WidestIndTy)
- Induction = Phi;
+ if (!PrimaryInduction || PhiTy == WidestIndTy)
+ PrimaryInduction = Phi;
}
// Both the PHI node itself, and the "post-increment" value feeding
@@ -5398,7 +5474,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
} // next instr.
}
- if (!Induction) {
+ if (!PrimaryInduction) {
DEBUG(dbgs() << "LV: Did not find one integer induction var.\n");
if (Inductions.empty()) {
ORE->emit(createMissedAnalysis("NoInductionVariable")
@@ -5410,46 +5486,166 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// Now we know the widest induction type, check if our found induction
// is the same size. If it's not, unset it here and InnerLoopVectorizer
// will create another.
- if (Induction && WidestIndTy != Induction->getType())
- Induction = nullptr;
+ if (PrimaryInduction && WidestIndTy != PrimaryInduction->getType())
+ PrimaryInduction = nullptr;
return true;
}
-void LoopVectorizationLegality::collectLoopScalars() {
+void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) {
+
+ // We should not collect Scalars more than once per VF. Right now, this
+ // function is called from collectUniformsAndScalars(), which already does
+ // this check. Collecting Scalars for VF=1 does not make any sense.
+ assert(VF >= 2 && !Scalars.count(VF) &&
+ "This function should not be visited twice for the same VF");
+
+ SmallSetVector<Instruction *, 8> Worklist;
+
+ // These sets are used to seed the analysis with pointers used by memory
+ // accesses that will remain scalar.
+ SmallSetVector<Instruction *, 8> ScalarPtrs;
+ SmallPtrSet<Instruction *, 8> PossibleNonScalarPtrs;
+
+ // A helper that returns true if the use of Ptr by MemAccess will be scalar.
+ // The pointer operands of loads and stores will be scalar as long as the
+ // memory access is not a gather or scatter operation. The value operand of a
+ // store will remain scalar if the store is scalarized.
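+ // For example, the value operand of a store whose widening decision is
+ // CM_Scalarize is used lane by lane and therefore counts as a scalar use,
+ // whereas the pointer operand of a CM_GatherScatter access does not.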
+ auto isScalarUse = [&](Instruction *MemAccess, Value *Ptr) {
+ InstWidening WideningDecision = getWideningDecision(MemAccess, VF);
+ assert(WideningDecision != CM_Unknown &&
+ "Widening decision should be ready at this moment");
+ if (auto *Store = dyn_cast<StoreInst>(MemAccess))
+ if (Ptr == Store->getValueOperand())
+ return WideningDecision == CM_Scalarize;
+ assert(Ptr == getPointerOperand(MemAccess) &&
+ "Ptr is neither a value or pointer operand");
+ return WideningDecision != CM_GatherScatter;
+ };
+
+ // A helper that returns true if the given value is a bitcast or
+ // getelementptr instruction contained in the loop.
+ auto isLoopVaryingBitCastOrGEP = [&](Value *V) {
+ return ((isa<BitCastInst>(V) && V->getType()->isPointerTy()) ||
+ isa<GetElementPtrInst>(V)) &&
+ !TheLoop->isLoopInvariant(V);
+ };
- // If an instruction is uniform after vectorization, it will remain scalar.
- Scalars.insert(Uniforms.begin(), Uniforms.end());
+ // A helper that evaluates a memory access's use of a pointer. If the use
+ // will be a scalar use, and the pointer is only used by memory accesses, we
+ // place the pointer in ScalarPtrs. Otherwise, the pointer is placed in
+ // PossibleNonScalarPtrs.
+ auto evaluatePtrUse = [&](Instruction *MemAccess, Value *Ptr) {
- // Collect the getelementptr instructions that will not be vectorized. A
- // getelementptr instruction is only vectorized if it is used for a legal
- // gather or scatter operation.
+ // We only care about bitcast and getelementptr instructions contained in
+ // the loop.
+ if (!isLoopVaryingBitCastOrGEP(Ptr))
+ return;
+
+ // If the pointer has already been identified as scalar (e.g., if it was
+ // also identified as uniform), there's nothing to do.
+ auto *I = cast<Instruction>(Ptr);
+ if (Worklist.count(I))
+ return;
+
+ // If the use of the pointer will be a scalar use, and all users of the
+ // pointer are memory accesses, place the pointer in ScalarPtrs. Otherwise,
+ // place the pointer in PossibleNonScalarPtrs.
+ if (isScalarUse(MemAccess, Ptr) && all_of(I->users(), [&](User *U) {
+ return isa<LoadInst>(U) || isa<StoreInst>(U);
+ }))
+ ScalarPtrs.insert(I);
+ else
+ PossibleNonScalarPtrs.insert(I);
+ };
+
+ // We seed the scalars analysis with three classes of instructions: (1)
+ // instructions marked uniform-after-vectorization, (2) bitcast and
+ // getelementptr instructions used by memory accesses requiring a scalar use,
+ // and (3) pointer induction variables and their update instructions (we
+ // currently only scalarize these).
+ //
+ // (1) Add to the worklist all instructions that have been identified as
+ // uniform-after-vectorization.
+ Worklist.insert(Uniforms[VF].begin(), Uniforms[VF].end());
+
+ // (2) Add to the worklist all bitcast and getelementptr instructions used by
+ // memory accesses requiring a scalar use. The pointer operands of loads and
+ // stores will be scalar as long as the memory access is not a gather or
+ // scatter operation. The value operand of a store will remain scalar if the
+ // store is scalarized.
for (auto *BB : TheLoop->blocks())
for (auto &I : *BB) {
- if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
- Scalars.insert(GEP);
- continue;
+ if (auto *Load = dyn_cast<LoadInst>(&I)) {
+ evaluatePtrUse(Load, Load->getPointerOperand());
+ } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
+ evaluatePtrUse(Store, Store->getPointerOperand());
+ evaluatePtrUse(Store, Store->getValueOperand());
}
- auto *Ptr = getPointerOperand(&I);
- if (!Ptr)
- continue;
- auto *GEP = getGEPInstruction(Ptr);
- if (GEP && isLegalGatherOrScatter(&I))
- Scalars.erase(GEP);
+ }
+ for (auto *I : ScalarPtrs)
+ if (!PossibleNonScalarPtrs.count(I)) {
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *I << "\n");
+ Worklist.insert(I);
}
+ // (3) Add to the worklist all pointer induction variables and their update
+ // instructions.
+ //
+ // TODO: Once we are able to vectorize pointer induction variables we should
+ // no longer insert them into the worklist here.
+ auto *Latch = TheLoop->getLoopLatch();
+ for (auto &Induction : *Legal->getInductionVars()) {
+ auto *Ind = Induction.first;
+ auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch));
+ if (Induction.second.getKind() != InductionDescriptor::IK_PtrInduction)
+ continue;
+ Worklist.insert(Ind);
+ Worklist.insert(IndUpdate);
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n");
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n");
+ }
+
+ // Expand the worklist by looking through any bitcasts and getelementptr
+ // instructions we've already identified as scalar. This is similar to the
+ // expansion step in collectLoopUniforms(); however, here we're only
+ // expanding to include additional bitcasts and getelementptr instructions.
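+ // For example, if a scalar bitcast in the worklist takes its operand from a
+ // getelementptr whose other users are all memory accesses that use it as a
+ // scalar pointer, the getelementptr is added to the worklist as well.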
+ unsigned Idx = 0;
+ while (Idx != Worklist.size()) {
+ Instruction *Dst = Worklist[Idx++];
+ if (!isLoopVaryingBitCastOrGEP(Dst->getOperand(0)))
+ continue;
+ auto *Src = cast<Instruction>(Dst->getOperand(0));
+ if (all_of(Src->users(), [&](User *U) -> bool {
+ auto *J = cast<Instruction>(U);
+ return !TheLoop->contains(J) || Worklist.count(J) ||
+ ((isa<LoadInst>(J) || isa<StoreInst>(J)) &&
+ isScalarUse(J, Src));
+ })) {
+ Worklist.insert(Src);
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *Src << "\n");
+ }
+ }
+
// An induction variable will remain scalar if all users of the induction
// variable and induction variable update remain scalar.
- auto *Latch = TheLoop->getLoopLatch();
- for (auto &Induction : *getInductionVars()) {
+ for (auto &Induction : *Legal->getInductionVars()) {
auto *Ind = Induction.first;
auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch));
+ // We already considered pointer induction variables, so there's no reason
+ // to look at their users again.
+ //
+ // TODO: Once we are able to vectorize pointer induction variables we
+ // should no longer skip over them here.
+ if (Induction.second.getKind() == InductionDescriptor::IK_PtrInduction)
+ continue;
+
// Determine if all users of the induction variable are scalar after
// vectorization.
auto ScalarInd = all_of(Ind->users(), [&](User *U) -> bool {
auto *I = cast<Instruction>(U);
- return I == IndUpdate || !TheLoop->contains(I) || Scalars.count(I);
+ return I == IndUpdate || !TheLoop->contains(I) || Worklist.count(I);
});
if (!ScalarInd)
continue;
@@ -5458,23 +5654,19 @@ void LoopVectorizationLegality::collectLoopScalars() {
// scalar after vectorization.
auto ScalarIndUpdate = all_of(IndUpdate->users(), [&](User *U) -> bool {
auto *I = cast<Instruction>(U);
- return I == Ind || !TheLoop->contains(I) || Scalars.count(I);
+ return I == Ind || !TheLoop->contains(I) || Worklist.count(I);
});
if (!ScalarIndUpdate)
continue;
// The induction variable and its update instruction will remain scalar.
- Scalars.insert(Ind);
- Scalars.insert(IndUpdate);
+ Worklist.insert(Ind);
+ Worklist.insert(IndUpdate);
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n");
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n");
}
-}
-bool LoopVectorizationLegality::hasConsecutiveLikePtrOperand(Instruction *I) {
- if (isAccessInterleaved(I))
- return true;
- if (auto *Ptr = getPointerOperand(I))
- return isConsecutivePtr(Ptr);
- return false;
+ Scalars[VF].insert(Worklist.begin(), Worklist.end());
}
bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) {
@@ -5494,48 +5686,48 @@ bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) {
return false;
}
-bool LoopVectorizationLegality::memoryInstructionMustBeScalarized(
- Instruction *I, unsigned VF) {
-
- // If the memory instruction is in an interleaved group, it will be
- // vectorized and its pointer will remain uniform.
- if (isAccessInterleaved(I))
- return false;
-
+bool LoopVectorizationLegality::memoryInstructionCanBeWidened(Instruction *I,
+ unsigned VF) {
// Get and ensure we have a valid memory instruction.
LoadInst *LI = dyn_cast<LoadInst>(I);
StoreInst *SI = dyn_cast<StoreInst>(I);
assert((LI || SI) && "Invalid memory instruction");
- // If the pointer operand is uniform (loop invariant), the memory instruction
- // will be scalarized.
auto *Ptr = getPointerOperand(I);
- if (LI && isUniform(Ptr))
- return true;
- // If the pointer operand is non-consecutive and neither a gather nor a
- // scatter operation is legal, the memory instruction will be scalarized.
- if (!isConsecutivePtr(Ptr) && !isLegalGatherOrScatter(I))
- return true;
+ // First of all, in order to be widened, the pointer must be consecutive.
+ if (!isConsecutivePtr(Ptr))
+ return false;
// If the instruction is a store located in a predicated block, it will be
// scalarized.
if (isScalarWithPredication(I))
- return true;
+ return false;
// If the instruction's allocated size doesn't equal its type size, it
// requires padding and will be scalarized.
auto &DL = I->getModule()->getDataLayout();
auto *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType();
if (hasIrregularType(ScalarTy, DL, VF))
- return true;
+ return false;
- // Otherwise, the memory instruction should be vectorized if the rest of the
- // loop is.
- return false;
+ return true;
}
-void LoopVectorizationLegality::collectLoopUniforms() {
+void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) {
+
+ // We should not collect Uniforms more than once per VF. Right now,
+ // this function is called from collectUniformsAndScalars(), which
+ // already does this check. Collecting Uniforms for VF=1 does not make any
+ // sense.
+
+ assert(VF >= 2 && !Uniforms.count(VF) &&
+ "This function should not be visited twice for the same VF");
+
+ // Visit the list of Uniforms. Even if we do not find any uniform value, we
+ // won't analyze this VF again: Uniforms.count(VF) will return 1 either way.
+ Uniforms[VF].clear();
+
// We now know that the loop is vectorizable!
// Collect instructions inside the loop that will remain uniform after
// vectorization.
@@ -5568,6 +5760,14 @@ void LoopVectorizationLegality::collectLoopUniforms() {
// Holds pointer operands of instructions that are possibly non-uniform.
SmallPtrSet<Instruction *, 8> PossibleNonUniformPtrs;
+ auto isUniformDecision = [&](Instruction *I, unsigned VF) {
+ InstWidening WideningDecision = getWideningDecision(I, VF);
+ assert(WideningDecision != CM_Unknown &&
+ "Widening decision should be ready at this moment");
+
+ return (WideningDecision == CM_Widen ||
+ WideningDecision == CM_Interleave);
+ };
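+ // That is, a memory access whose pointer stays uniform is one that will be
+ // widened as a consecutive access or emitted as part of an interleave group;
+ // scalarized accesses and gathers/scatters need per-lane pointers instead.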
// Iterate over the instructions in the loop, and collect all
// consecutive-like pointer operands in ConsecutiveLikePtrs. If it's possible
// that a consecutive-like pointer operand will be scalarized, we collect it
@@ -5590,25 +5790,18 @@ void LoopVectorizationLegality::collectLoopUniforms() {
return getPointerOperand(U) == Ptr;
});
- // Ensure the memory instruction will not be scalarized, making its
- // pointer operand non-uniform. If the pointer operand is used by some
- // instruction other than a memory access, we're not going to check if
- // that other instruction may be scalarized here. Thus, conservatively
- // assume the pointer operand may be non-uniform.
- if (!UsersAreMemAccesses || memoryInstructionMustBeScalarized(&I))
+ // Ensure the memory instruction will not be scalarized or used by
+ // gather/scatter, making its pointer operand non-uniform. If the pointer
+ // operand is used by any instruction other than a memory access, we
+ // conservatively assume the pointer operand may be non-uniform.
+ if (!UsersAreMemAccesses || !isUniformDecision(&I, VF))
PossibleNonUniformPtrs.insert(Ptr);
// If the memory instruction will be vectorized and its pointer operand
- // is consecutive-like, the pointer operand should remain uniform.
- else if (hasConsecutiveLikePtrOperand(&I))
- ConsecutiveLikePtrs.insert(Ptr);
-
- // Otherwise, if the memory instruction will be vectorized and its
- // pointer operand is non-consecutive-like, the memory instruction should
- // be a gather or scatter operation. Its pointer operand will be
- // non-uniform.
+ // is consecutive-like or part of an interleave group, the pointer
+ // operand should remain uniform.
else
- PossibleNonUniformPtrs.insert(Ptr);
+ ConsecutiveLikePtrs.insert(Ptr);
}
// Add to the Worklist all consecutive and consecutive-like pointers that
@@ -5632,7 +5825,9 @@ void LoopVectorizationLegality::collectLoopUniforms() {
continue;
auto *OI = cast<Instruction>(OV);
if (all_of(OI->users(), [&](User *U) -> bool {
- return isOutOfScope(U) || Worklist.count(cast<Instruction>(U));
+ auto *J = cast<Instruction>(U);
+ return !TheLoop->contains(J) || Worklist.count(J) ||
+ (OI == getPointerOperand(J) && isUniformDecision(J, VF));
})) {
Worklist.insert(OI);
DEBUG(dbgs() << "LV: Found uniform instruction: " << *OI << "\n");
@@ -5643,7 +5838,7 @@ void LoopVectorizationLegality::collectLoopUniforms() {
// Returns true if Ptr is the pointer operand of a memory access instruction
// I, and I is known to not require scalarization.
auto isVectorizedMemAccessUse = [&](Instruction *I, Value *Ptr) -> bool {
- return getPointerOperand(I) == Ptr && !memoryInstructionMustBeScalarized(I);
+ return getPointerOperand(I) == Ptr && isUniformDecision(I, VF);
};
// For an instruction to be added into Worklist above, all its users inside
@@ -5652,7 +5847,7 @@ void LoopVectorizationLegality::collectLoopUniforms() {
// nodes separately. An induction variable will remain uniform if all users
// of the induction variable and induction variable update remain uniform.
// The code below handles both pointer and non-pointer induction variables.
- for (auto &Induction : Inductions) {
+ for (auto &Induction : *Legal->getInductionVars()) {
auto *Ind = Induction.first;
auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch));
@@ -5683,7 +5878,7 @@ void LoopVectorizationLegality::collectLoopUniforms() {
DEBUG(dbgs() << "LV: Found uniform instruction: " << *IndUpdate << "\n");
}
- Uniforms.insert(Worklist.begin(), Worklist.end());
+ Uniforms[VF].insert(Worklist.begin(), Worklist.end());
}
bool LoopVectorizationLegality::canVectorizeMemory() {
@@ -5823,7 +6018,7 @@ void InterleavedAccessInfo::collectConstStrideAccesses(
uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType());
// An alignment of 0 means target ABI alignment.
- unsigned Align = LI ? LI->getAlignment() : SI->getAlignment();
+ unsigned Align = getMemInstAlignment(&I);
if (!Align)
Align = DL.getABITypeAlignment(PtrTy->getElementType());
@@ -5978,6 +6173,11 @@ void InterleavedAccessInfo::analyzeInterleaving(
if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size)
continue;
+ // Ignore A if the memory objects of A and B don't belong to the same
+ // address space.
+ if (getMemInstAddressSpace(A) != getMemInstAddressSpace(B))
+ continue;
+
// Calculate the distance from A to B.
const SCEVConstant *DistToB = dyn_cast<SCEVConstant>(
PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev));
@@ -6020,36 +6220,36 @@ void InterleavedAccessInfo::analyzeInterleaving(
if (Group->getNumMembers() != Group->getFactor())
releaseGroup(Group);
- // Remove interleaved groups with gaps (currently only loads) whose memory
- // accesses may wrap around. We have to revisit the getPtrStride analysis,
- // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does
+ // Remove interleaved groups with gaps (currently only loads) whose memory
+ // accesses may wrap around. We have to revisit the getPtrStride analysis,
+ // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does
// not check wrapping (see documentation there).
- // FORNOW we use Assume=false;
- // TODO: Change to Assume=true but making sure we don't exceed the threshold
+ // FORNOW we use Assume=false;
+ // TODO: Change to Assume=true but making sure we don't exceed the threshold
// of runtime SCEV assumptions checks (thereby potentially failing to
- // vectorize altogether).
+ // vectorize altogether).
// Additional optional optimizations:
- // TODO: If we are peeling the loop and we know that the first pointer doesn't
+ // TODO: If we are peeling the loop and we know that the first pointer doesn't
// wrap then we can deduce that all pointers in the group don't wrap.
- // This means that we can forcefully peel the loop in order to only have to
- // check the first pointer for no-wrap. When we'll change to use Assume=true
+ // This means that we can forcefully peel the loop in order to only have to
+ // check the first pointer for no-wrap. When we'll change to use Assume=true
// we'll only need at most one runtime check per interleaved group.
//
for (InterleaveGroup *Group : LoadGroups) {
// Case 1: A full group. Can Skip the checks; For full groups, if the wide
- // load would wrap around the address space we would do a memory access at
- // nullptr even without the transformation.
- if (Group->getNumMembers() == Group->getFactor())
+ // load would wrap around the address space we would do a memory access at
+ // nullptr even without the transformation.
+ if (Group->getNumMembers() == Group->getFactor())
continue;
- // Case 2: If first and last members of the group don't wrap this implies
+ // Case 2: If first and last members of the group don't wrap this implies
// that all the pointers in the group don't wrap.
// So we check only group member 0 (which is always guaranteed to exist),
- // and group member Factor - 1; If the latter doesn't exist we rely on
+ // and group member Factor - 1; If the latter doesn't exist we rely on
// peeling (if it is a non-reversed access -- see Case 3).
Value *FirstMemberPtr = getPointerOperand(Group->getMember(0));
- if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
+ if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
/*ShouldCheckWrap=*/true)) {
DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
"first group member potentially pointer-wrapping.\n");
@@ -6065,8 +6265,7 @@ void InterleavedAccessInfo::analyzeInterleaving(
"last group member potentially pointer-wrapping.\n");
releaseGroup(Group);
}
- }
- else {
+ } else {
// Case 3: A non-reversed interleaved load group with gaps: We need
// to execute at least one scalar epilogue iteration. This will ensure
// we don't speculatively access memory out-of-bounds. We only need
@@ -6082,27 +6281,62 @@ void InterleavedAccessInfo::analyzeInterleaving(
}
}
-LoopVectorizationCostModel::VectorizationFactor
-LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
- // Width 1 means no vectorize
- VectorizationFactor Factor = {1U, 0U};
- if (OptForSize && Legal->getRuntimePointerChecking()->Need) {
+Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) {
+ if (!EnableCondStoresVectorization && Legal->getNumPredStores()) {
+ ORE->emit(createMissedAnalysis("ConditionalStore")
+ << "store that is conditionally executed prevents vectorization");
+ DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n");
+ return None;
+ }
+
+ if (!OptForSize) // Remaining checks deal with scalar loop when OptForSize.
+ return computeFeasibleMaxVF(OptForSize);
+
+ if (Legal->getRuntimePointerChecking()->Need) {
ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize")
<< "runtime pointer checks needed. Enable vectorization of this "
"loop with '#pragma clang loop vectorize(enable)' when "
"compiling with -Os/-Oz");
DEBUG(dbgs()
<< "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n");
- return Factor;
+ return None;
}
- if (!EnableCondStoresVectorization && Legal->getNumPredStores()) {
- ORE->emit(createMissedAnalysis("ConditionalStore")
- << "store that is conditionally executed prevents vectorization");
- DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n");
- return Factor;
+ // If we optimize the program for size, avoid creating the tail loop.
+ unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
+ DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
+
+ // If we don't know the precise trip count, don't try to vectorize.
+ if (TC < 2) {
+ ORE->emit(
+ createMissedAnalysis("UnknownLoopCountComplexCFG")
+ << "unable to calculate the loop count due to complex control flow");
+ DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n");
+ return None;
}
+ unsigned MaxVF = computeFeasibleMaxVF(OptForSize);
+
+ if (TC % MaxVF != 0) {
+ // If the trip count that we found modulo the vectorization factor is not
+ // zero then we require a tail.
+ // FIXME: look for a smaller MaxVF that does divide TC rather than give up.
+ // FIXME: return None if loop requiresScalarEpilog(<MaxVF>), or look for a
+ // smaller MaxVF that does not require a scalar epilog.
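+ // For example, a constant trip count of 17 with MaxVF = 4 leaves a
+ // one-iteration scalar tail, so under -Os/-Oz we give up instead of
+ // emitting a scalar epilogue.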
+
+ ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize")
+ << "cannot optimize for size and vectorize at the "
+ "same time. Enable vectorization of this loop "
+ "with '#pragma clang loop vectorize(enable)' "
+ "when compiling with -Os/-Oz");
+ DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n");
+ return None;
+ }
+
+ return MaxVF;
+}
+
+unsigned LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize) {
MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);
unsigned SmallestType, WidestType;
std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes();
@@ -6136,7 +6370,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements"
" into one vector!");
- unsigned VF = MaxVectorSize;
+ unsigned MaxVF = MaxVectorSize;
if (MaximizeBandwidth && !OptForSize) {
// Collect all viable vectorization factors.
SmallVector<unsigned, 8> VFs;
@@ -6152,54 +6386,16 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
unsigned TargetNumRegisters = TTI.getNumberOfRegisters(true);
for (int i = RUs.size() - 1; i >= 0; --i) {
if (RUs[i].MaxLocalUsers <= TargetNumRegisters) {
- VF = VFs[i];
+ MaxVF = VFs[i];
break;
}
}
}
+ return MaxVF;
+}
- // If we optimize the program for size, avoid creating the tail loop.
- if (OptForSize) {
- unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
- DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
-
- // If we don't know the precise trip count, don't try to vectorize.
- if (TC < 2) {
- ORE->emit(
- createMissedAnalysis("UnknownLoopCountComplexCFG")
- << "unable to calculate the loop count due to complex control flow");
- DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n");
- return Factor;
- }
-
- // Find the maximum SIMD width that can fit within the trip count.
- VF = TC % MaxVectorSize;
-
- if (VF == 0)
- VF = MaxVectorSize;
- else {
- // If the trip count that we found modulo the vectorization factor is not
- // zero then we require a tail.
- ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize")
- << "cannot optimize for size and vectorize at the "
- "same time. Enable vectorization of this loop "
- "with '#pragma clang loop vectorize(enable)' "
- "when compiling with -Os/-Oz");
- DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n");
- return Factor;
- }
- }
-
- int UserVF = Hints->getWidth();
- if (UserVF != 0) {
- assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two");
- DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n");
-
- Factor.Width = UserVF;
- collectInstsToScalarize(UserVF);
- return Factor;
- }
-
+LoopVectorizationCostModel::VectorizationFactor
+LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) {
float Cost = expectedCost(1).first;
#ifndef NDEBUG
const float ScalarCost = Cost;
@@ -6209,12 +6405,12 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled;
// Ignore scalar width, because the user explicitly wants vectorization.
- if (ForceVectorization && VF > 1) {
+ if (ForceVectorization && MaxVF > 1) {
Width = 2;
Cost = expectedCost(Width).first / (float)Width;
}
- for (unsigned i = 2; i <= VF; i *= 2) {
+ for (unsigned i = 2; i <= MaxVF; i *= 2) {
// Notice that the vector loop needs to be executed less times, so
// we need to divide the cost of the vector loops by the width of
// the vector elements.
@@ -6238,8 +6434,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
<< "LV: Vectorization seems to be not beneficial, "
<< "but was forced by a user.\n");
DEBUG(dbgs() << "LV: Selecting VF: " << Width << ".\n");
- Factor.Width = Width;
- Factor.Cost = Width * Cost;
+ VectorizationFactor Factor = {Width, (unsigned)(Width * Cost)};
return Factor;
}
@@ -6277,9 +6472,16 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() {
T = ST->getValueOperand()->getType();
// Ignore loaded pointer types and stored pointer types that are not
- // consecutive. However, we do want to take consecutive stores/loads of
- // pointer vectors into account.
- if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I))
+ // vectorizable.
+ //
+ // FIXME: The check here attempts to predict whether a load or store will
+ // be vectorized. We only know this for certain after a VF has
+ // been selected. Here, we assume that if an access can be
+ // vectorized, it will be. We should also look at extending this
+ // optimization to non-pointer types.
+ //
+ if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I) &&
+ !Legal->isAccessInterleaved(&I) && !Legal->isLegalGatherOrScatter(&I))
continue;
MinWidth = std::min(MinWidth,
@@ -6562,12 +6764,13 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) {
MaxUsages[j] = std::max(MaxUsages[j], OpenIntervals.size());
continue;
}
-
+ collectUniformsAndScalars(VFs[j]);
// Count the number of live intervals.
unsigned RegUsage = 0;
for (auto Inst : OpenIntervals) {
// Skip ignored values for VF > 1.
- if (VecValuesToIgnore.count(Inst))
+ if (VecValuesToIgnore.count(Inst) ||
+ isScalarAfterVectorization(Inst, VFs[j]))
continue;
RegUsage += GetRegUsage(Inst->getType(), VFs[j]);
}
@@ -6628,6 +6831,9 @@ void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) {
ScalarCostsTy ScalarCosts;
if (computePredInstDiscount(&I, ScalarCosts, VF) >= 0)
ScalarCostsVF.insert(ScalarCosts.begin(), ScalarCosts.end());
+
+ // Remember that BB will remain after vectorization.
+ PredicatedBBsAfterVectorization.insert(BB);
}
}
}
@@ -6636,7 +6842,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(
Instruction *PredInst, DenseMap<Instruction *, unsigned> &ScalarCosts,
unsigned VF) {
- assert(!Legal->isUniformAfterVectorization(PredInst) &&
+ assert(!isUniformAfterVectorization(PredInst, VF) &&
"Instruction marked uniform-after-vectorization will be predicated");
// Initialize the discount to zero, meaning that the scalar version and the
@@ -6657,7 +6863,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(
// already be scalar to avoid traversing chains that are unlikely to be
// beneficial.
if (!I->hasOneUse() || PredInst->getParent() != I->getParent() ||
- Legal->isScalarAfterVectorization(I))
+ isScalarAfterVectorization(I, VF))
return false;
// If the instruction is scalar with predication, it will be analyzed
@@ -6677,7 +6883,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(
// the lane zero values for uniforms rather than asserting.
for (Use &U : I->operands())
if (auto *J = dyn_cast<Instruction>(U.get()))
- if (Legal->isUniformAfterVectorization(J))
+ if (isUniformAfterVectorization(J, VF))
return false;
// Otherwise, we can scalarize the instruction.
@@ -6690,7 +6896,7 @@ int LoopVectorizationCostModel::computePredInstDiscount(
// and their return values are inserted into vectors. Thus, an extract would
// still be required.
auto needsExtract = [&](Instruction *I) -> bool {
- return TheLoop->contains(I) && !Legal->isScalarAfterVectorization(I);
+ return TheLoop->contains(I) && !isScalarAfterVectorization(I, VF);
};
// Compute the expected cost discount from scalarizing the entire expression
@@ -6717,8 +6923,8 @@ int LoopVectorizationCostModel::computePredInstDiscount(
// Compute the scalarization overhead of needed insertelement instructions
// and phi nodes.
if (Legal->isScalarWithPredication(I) && !I->getType()->isVoidTy()) {
- ScalarCost += getScalarizationOverhead(ToVectorTy(I->getType(), VF), true,
- false, TTI);
+ ScalarCost += TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF),
+ true, false);
ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI);
}
@@ -6733,8 +6939,8 @@ int LoopVectorizationCostModel::computePredInstDiscount(
if (canBeScalarized(J))
Worklist.push_back(J);
else if (needsExtract(J))
- ScalarCost += getScalarizationOverhead(ToVectorTy(J->getType(), VF),
- false, true, TTI);
+ ScalarCost += TTI.getScalarizationOverhead(
+ ToVectorTy(J->getType(),VF), false, true);
}
// Scale the total scalar cost by block probability.
@@ -6753,6 +6959,9 @@ LoopVectorizationCostModel::VectorizationCostTy
LoopVectorizationCostModel::expectedCost(unsigned VF) {
VectorizationCostTy Cost;
+ // Collect Uniform and Scalar instructions after vectorization with VF.
+ collectUniformsAndScalars(VF);
+
// Collect the instructions (and their associated costs) that will be more
// profitable to scalarize.
collectInstsToScalarize(VF);
@@ -6832,11 +7041,141 @@ static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) {
Legal->hasStride(I->getOperand(1));
}
+unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I,
+ unsigned VF) {
+ Type *ValTy = getMemInstValueType(I);
+ auto SE = PSE.getSE();
+
+ unsigned Alignment = getMemInstAlignment(I);
+ unsigned AS = getMemInstAddressSpace(I);
+ Value *Ptr = getPointerOperand(I);
+ Type *PtrTy = ToVectorTy(Ptr->getType(), VF);
+
+ // Figure out whether the access is strided and get the stride value
+ // if it's known at compile time.
+ const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop);
+
+ // Get the cost of the scalar memory instruction and address computation.
+ unsigned Cost = VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV);
+
+ Cost += VF *
+ TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment,
+ AS, I);
+
+ // Get the overhead of the extractelement and insertelement instructions
+ // we might create due to scalarization.
+ Cost += getScalarizationOverhead(I, VF, TTI);
+
+ // If we have a predicated store, it may not be executed for each vector
+ // lane. Scale the cost by the probability of executing the predicated
+ // block.
+ if (Legal->isScalarWithPredication(I))
+ Cost /= getReciprocalPredBlockProb();
+
+ return Cost;
+}
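A rough standalone sketch of the arithmetic above (hypothetical helper name, plain integers standing in for the TTI queries; a sketch, not the patch's API): VF address computations plus VF scalar memory operations plus the lane insert/extract overhead, scaled down when the access only executes under a predicate.

    #include <iostream>

    // Illustrative mirror of the scalarized memory cost: VF address
    // computations, VF scalar loads/stores, plus insert/extract overhead,
    // divided by the predicated-block probability factor.
    unsigned scalarizedMemCost(unsigned VF, unsigned AddrCost, unsigned MemOpCost,
                               unsigned InsertExtractCost, bool Predicated) {
      unsigned Cost = VF * AddrCost + VF * MemOpCost + InsertExtractCost;
      if (Predicated)
        Cost /= 2; // assumes getReciprocalPredBlockProb() == 2
      return Cost;
    }

    int main() {
      // Hypothetical numbers: VF=4, address 1, scalar op 4, overhead 8, predicated.
      std::cout << scalarizedMemCost(4, 1, 4, 8, true) << "\n"; // prints 14
    }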
+
+unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I,
+ unsigned VF) {
+ Type *ValTy = getMemInstValueType(I);
+ Type *VectorTy = ToVectorTy(ValTy, VF);
+ unsigned Alignment = getMemInstAlignment(I);
+ Value *Ptr = getPointerOperand(I);
+ unsigned AS = getMemInstAddressSpace(I);
+ int ConsecutiveStride = Legal->isConsecutivePtr(Ptr);
+
+ assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) &&
+ "Stride should be 1 or -1 for consecutive memory access");
+ unsigned Cost = 0;
+ if (Legal->isMaskRequired(I))
+ Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS);
+ else
+ Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS, I);
+
+ bool Reverse = ConsecutiveStride < 0;
+ if (Reverse)
+ Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0);
+ return Cost;
+}
+
+unsigned LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I,
+ unsigned VF) {
+ LoadInst *LI = cast<LoadInst>(I);
+ Type *ValTy = LI->getType();
+ Type *VectorTy = ToVectorTy(ValTy, VF);
+ unsigned Alignment = LI->getAlignment();
+ unsigned AS = LI->getPointerAddressSpace();
+
+ return TTI.getAddressComputationCost(ValTy) +
+ TTI.getMemoryOpCost(Instruction::Load, ValTy, Alignment, AS) +
+ TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VectorTy);
+}
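By contrast, a load from a uniform address is costed once and then broadcast, so VF never enters the formula. A minimal sketch with made-up unit costs (illustrative names, not the TTI interface):

    // Illustrative: one scalar address computation, one scalar load, one
    // broadcast shuffle -- e.g. uniformLoadCost(1, 4, 1) == 6 for any VF.
    unsigned uniformLoadCost(unsigned AddrCost, unsigned ScalarLoadCost,
                             unsigned BroadcastCost) {
      return AddrCost + ScalarLoadCost + BroadcastCost;
    }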
+
+unsigned LoopVectorizationCostModel::getGatherScatterCost(Instruction *I,
+ unsigned VF) {
+ Type *ValTy = getMemInstValueType(I);
+ Type *VectorTy = ToVectorTy(ValTy, VF);
+ unsigned Alignment = getMemInstAlignment(I);
+ Value *Ptr = getPointerOperand(I);
+
+ return TTI.getAddressComputationCost(VectorTy) +
+ TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr,
+ Legal->isMaskRequired(I), Alignment);
+}
+
+unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
+ unsigned VF) {
+ Type *ValTy = getMemInstValueType(I);
+ Type *VectorTy = ToVectorTy(ValTy, VF);
+ unsigned AS = getMemInstAddressSpace(I);
+
+ auto Group = Legal->getInterleavedAccessGroup(I);
+ assert(Group && "Fail to get an interleaved access group.");
+
+ unsigned InterleaveFactor = Group->getFactor();
+ Type *WideVecTy = VectorType::get(ValTy, VF * InterleaveFactor);
+
+ // Holds the indices of existing members in an interleaved load group.
+ // An interleaved store group doesn't need this as it doesn't allow gaps.
+ SmallVector<unsigned, 4> Indices;
+ if (isa<LoadInst>(I)) {
+ for (unsigned i = 0; i < InterleaveFactor; i++)
+ if (Group->getMember(i))
+ Indices.push_back(i);
+ }
+
+ // Calculate the cost of the whole interleaved group.
+ unsigned Cost = TTI.getInterleavedMemoryOpCost(I->getOpcode(), WideVecTy,
+ Group->getFactor(), Indices,
+ Group->getAlignment(), AS);
+
+ if (Group->isReverse())
+ Cost += Group->getNumMembers() *
+ TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0);
+ return Cost;
+}
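Only the members that actually exist in an interleaved load group contribute indices, since load groups may have gaps while store groups are always complete. A standalone sketch of that index collection (hypothetical types, not the InterleaveGroup API):

    #include <vector>

    // Present[i] says whether member i of a load group with factor
    // Present.size() exists; only existing members get an index.
    std::vector<unsigned> presentMemberIndices(const std::vector<bool> &Present) {
      std::vector<unsigned> Indices;
      for (unsigned i = 0; i < Present.size(); ++i)
        if (Present[i])
          Indices.push_back(i);
      return Indices;
    }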
+
+unsigned LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I,
+ unsigned VF) {
+
+ // Calculate scalar cost only. Vectorization cost should be ready at this
+ // moment.
+ if (VF == 1) {
+ Type *ValTy = getMemInstValueType(I);
+ unsigned Alignment = getMemInstAlignment(I);
+ unsigned AS = getMemInstAddressSpace(I);
+
+ return TTI.getAddressComputationCost(ValTy) +
+ TTI.getMemoryOpCost(I->getOpcode(), ValTy, Alignment, AS, I);
+ }
+ return getWideningCost(I, VF);
+}
+
LoopVectorizationCostModel::VectorizationCostTy
LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
// If we know that this instruction will remain uniform, check the cost of
// the scalar version.
- if (Legal->isUniformAfterVectorization(I))
+ if (isUniformAfterVectorization(I, VF))
VF = 1;
if (VF > 1 && isProfitableToScalarize(I, VF))
@@ -6850,6 +7189,79 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
return VectorizationCostTy(C, TypeNotScalarized);
}
+void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) {
+ if (VF == 1)
+ return;
+ for (BasicBlock *BB : TheLoop->blocks()) {
+ // For each instruction in the old loop.
+ for (Instruction &I : *BB) {
+ Value *Ptr = getPointerOperand(&I);
+ if (!Ptr)
+ continue;
+
+ if (isa<LoadInst>(&I) && Legal->isUniform(Ptr)) {
+ // Scalar load + broadcast
+ unsigned Cost = getUniformMemOpCost(&I, VF);
+ setWideningDecision(&I, VF, CM_Scalarize, Cost);
+ continue;
+ }
+
+ // We assume that widening is the best solution when possible.
+ if (Legal->memoryInstructionCanBeWidened(&I, VF)) {
+ unsigned Cost = getConsecutiveMemOpCost(&I, VF);
+ setWideningDecision(&I, VF, CM_Widen, Cost);
+ continue;
+ }
+
+ // Choose between Interleaving, Gather/Scatter or Scalarization.
+ unsigned InterleaveCost = UINT_MAX;
+ unsigned NumAccesses = 1;
+ if (Legal->isAccessInterleaved(&I)) {
+ auto Group = Legal->getInterleavedAccessGroup(&I);
+ assert(Group && "Fail to get an interleaved access group.");
+
+ // Make one decision for the whole group.
+ if (getWideningDecision(&I, VF) != CM_Unknown)
+ continue;
+
+ NumAccesses = Group->getNumMembers();
+ InterleaveCost = getInterleaveGroupCost(&I, VF);
+ }
+
+ unsigned GatherScatterCost =
+ Legal->isLegalGatherOrScatter(&I)
+ ? getGatherScatterCost(&I, VF) * NumAccesses
+ : UINT_MAX;
+
+ unsigned ScalarizationCost =
+ getMemInstScalarizationCost(&I, VF) * NumAccesses;
+
+ // Choose the better solution for the current VF, record the decision,
+ // and use it during vectorization.
+ unsigned Cost;
+ InstWidening Decision;
+ if (InterleaveCost <= GatherScatterCost &&
+ InterleaveCost < ScalarizationCost) {
+ Decision = CM_Interleave;
+ Cost = InterleaveCost;
+ } else if (GatherScatterCost < ScalarizationCost) {
+ Decision = CM_GatherScatter;
+ Cost = GatherScatterCost;
+ } else {
+ Decision = CM_Scalarize;
+ Cost = ScalarizationCost;
+ }
+ // If the instruction belongs to an interleave group, the whole group
+ // receives the same decision. The whole group receives the cost, but
+ // the cost will actually be assigned to one instruction.
+ if (auto Group = Legal->getInterleavedAccessGroup(&I))
+ setWideningDecision(Group, VF, Decision, Cost);
+ else
+ setWideningDecision(&I, VF, Decision, Cost);
+ }
+ }
+}
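The selection above reduces to taking the cheapest of three strategies, with interleaving winning ties against gather/scatter but having to beat scalarization strictly. A compact standalone sketch of that tie-breaking (enum and names are illustrative):

    enum class Widening { Interleave, GatherScatter, Scalarize }; // illustrative

    Widening pickWidening(unsigned InterleaveCost, unsigned GatherScatterCost,
                          unsigned ScalarizationCost, unsigned &Cost) {
      if (InterleaveCost <= GatherScatterCost &&
          InterleaveCost < ScalarizationCost) {
        Cost = InterleaveCost;    // ties with gather/scatter go to interleaving
        return Widening::Interleave;
      }
      if (GatherScatterCost < ScalarizationCost) {
        Cost = GatherScatterCost;
        return Widening::GatherScatter;
      }
      Cost = ScalarizationCost;   // scalarization is the fallback
      return Widening::Scalarize;
    }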
+
unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
unsigned VF,
Type *&VectorTy) {
@@ -6868,7 +7280,31 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
// instruction cost.
return 0;
case Instruction::Br: {
- return TTI.getCFInstrCost(I->getOpcode());
+ // In cases of scalarized and predicated instructions, there will be VF
+ // predicated blocks in the vectorized loop. Each branch around these
+ // blocks also requires an extract of its vector compare i1 element.
+ bool ScalarPredicatedBB = false;
+ BranchInst *BI = cast<BranchInst>(I);
+ if (VF > 1 && BI->isConditional() &&
+ (PredicatedBBsAfterVectorization.count(BI->getSuccessor(0)) ||
+ PredicatedBBsAfterVectorization.count(BI->getSuccessor(1))))
+ ScalarPredicatedBB = true;
+
+ if (ScalarPredicatedBB) {
+ // Return cost for branches around scalarized and predicated blocks.
+ Type *Vec_i1Ty =
+ VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF);
+ return (TTI.getScalarizationOverhead(Vec_i1Ty, false, true) +
+ (TTI.getCFInstrCost(Instruction::Br) * VF));
+ } else if (I->getParent() == TheLoop->getLoopLatch() || VF == 1)
+ // The back-edge branch will remain, as will all scalar branches.
+ return TTI.getCFInstrCost(Instruction::Br);
+ else
+ // This branch will be eliminated by if-conversion.
+ return 0;
+ // Note: We currently assume zero cost for an unconditional branch inside
+ // a predicated block since it will become a fall-through, although we
+ // may decide in the future to call TTI for all branches.
}
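For a branch around a scalarized, predicated block, the cost amounts to extracting each lane's i1 condition plus one scalar branch per lane. A hedged sketch with the extraction modeled per lane (made-up unit costs, not the TTI calls):

    // Illustrative: VF i1 extracts from the vector compare plus VF branches.
    unsigned predicatedBranchCost(unsigned VF, unsigned PerLaneExtractCost,
                                  unsigned BranchCost) {
      return VF * (PerLaneExtractCost + BranchCost);
    }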
case Instruction::PHI: {
auto *Phi = cast<PHINode>(I);
@@ -6969,7 +7405,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
if (!ScalarCond)
CondTy = VectorType::get(CondTy, VF);
- return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy);
+ return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy, I);
}
case Instruction::ICmp:
case Instruction::FCmp: {
@@ -6978,130 +7414,12 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
if (canTruncateToMinimalBitwidth(Op0AsInstruction, VF))
ValTy = IntegerType::get(ValTy->getContext(), MinBWs[Op0AsInstruction]);
VectorTy = ToVectorTy(ValTy, VF);
- return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy);
+ return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, nullptr, I);
}
case Instruction::Store:
case Instruction::Load: {
- StoreInst *SI = dyn_cast<StoreInst>(I);
- LoadInst *LI = dyn_cast<LoadInst>(I);
- Type *ValTy = (SI ? SI->getValueOperand()->getType() : LI->getType());
- VectorTy = ToVectorTy(ValTy, VF);
-
- unsigned Alignment = SI ? SI->getAlignment() : LI->getAlignment();
- unsigned AS =
- SI ? SI->getPointerAddressSpace() : LI->getPointerAddressSpace();
- Value *Ptr = getPointerOperand(I);
- // We add the cost of address computation here instead of with the gep
- // instruction because only here we know whether the operation is
- // scalarized.
- if (VF == 1)
- return TTI.getAddressComputationCost(VectorTy) +
- TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS);
-
- if (LI && Legal->isUniform(Ptr)) {
- // Scalar load + broadcast
- unsigned Cost = TTI.getAddressComputationCost(ValTy->getScalarType());
- Cost += TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(),
- Alignment, AS);
- return Cost +
- TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, ValTy);
- }
-
- // For an interleaved access, calculate the total cost of the whole
- // interleave group.
- if (Legal->isAccessInterleaved(I)) {
- auto Group = Legal->getInterleavedAccessGroup(I);
- assert(Group && "Fail to get an interleaved access group.");
-
- // Only calculate the cost once at the insert position.
- if (Group->getInsertPos() != I)
- return 0;
-
- unsigned InterleaveFactor = Group->getFactor();
- Type *WideVecTy =
- VectorType::get(VectorTy->getVectorElementType(),
- VectorTy->getVectorNumElements() * InterleaveFactor);
-
- // Holds the indices of existing members in an interleaved load group.
- // An interleaved store group doesn't need this as it doesn't allow gaps.
- SmallVector<unsigned, 4> Indices;
- if (LI) {
- for (unsigned i = 0; i < InterleaveFactor; i++)
- if (Group->getMember(i))
- Indices.push_back(i);
- }
-
- // Calculate the cost of the whole interleaved group.
- unsigned Cost = TTI.getInterleavedMemoryOpCost(
- I->getOpcode(), WideVecTy, Group->getFactor(), Indices,
- Group->getAlignment(), AS);
-
- if (Group->isReverse())
- Cost +=
- Group->getNumMembers() *
- TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0);
-
- // FIXME: The interleaved load group with a huge gap could be even more
- // expensive than scalar operations. Then we could ignore such group and
- // use scalar operations instead.
- return Cost;
- }
-
- // Check if the memory instruction will be scalarized.
- if (Legal->memoryInstructionMustBeScalarized(I, VF)) {
- unsigned Cost = 0;
- Type *PtrTy = ToVectorTy(Ptr->getType(), VF);
-
- // Figure out whether the access is strided and get the stride value
- // if it's known in compile time
- const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop);
-
- // Get the cost of the scalar memory instruction and address computation.
- Cost += VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV);
- Cost += VF *
- TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(),
- Alignment, AS);
-
- // Get the overhead of the extractelement and insertelement instructions
- // we might create due to scalarization.
- Cost += getScalarizationOverhead(I, VF, TTI);
-
- // If we have a predicated store, it may not be executed for each vector
- // lane. Scale the cost by the probability of executing the predicated
- // block.
- if (Legal->isScalarWithPredication(I))
- Cost /= getReciprocalPredBlockProb();
-
- return Cost;
- }
-
- // Determine if the pointer operand of the access is either consecutive or
- // reverse consecutive.
- int ConsecutiveStride = Legal->isConsecutivePtr(Ptr);
- bool Reverse = ConsecutiveStride < 0;
-
- // Determine if either a gather or scatter operation is legal.
- bool UseGatherOrScatter =
- !ConsecutiveStride && Legal->isLegalGatherOrScatter(I);
-
- unsigned Cost = TTI.getAddressComputationCost(VectorTy);
- if (UseGatherOrScatter) {
- assert(ConsecutiveStride == 0 &&
- "Gather/Scatter are not used for consecutive stride");
- return Cost +
- TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr,
- Legal->isMaskRequired(I), Alignment);
- }
- // Wide load/stores.
- if (Legal->isMaskRequired(I))
- Cost +=
- TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS);
- else
- Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS);
-
- if (Reverse)
- Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0);
- return Cost;
+ VectorTy = ToVectorTy(getMemInstValueType(I), VF);
+ return getMemoryInstructionCost(I, VF);
}
case Instruction::ZExt:
case Instruction::SExt:
@@ -7115,12 +7433,14 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast: {
- // We optimize the truncation of induction variable.
- // The cost of these is the same as the scalar operation.
- if (I->getOpcode() == Instruction::Trunc &&
- Legal->isInductionVariable(I->getOperand(0)))
- return TTI.getCastInstrCost(I->getOpcode(), I->getType(),
- I->getOperand(0)->getType());
+ // We optimize the truncation of induction variables having constant
+ // integer steps. The cost of these truncations is the same as the scalar
+ // operation.
+ if (isOptimizableIVTruncate(I, VF)) {
+ auto *Trunc = cast<TruncInst>(I);
+ return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(),
+ Trunc->getSrcTy(), Trunc);
+ }
Type *SrcScalarTy = I->getOperand(0)->getType();
Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF);
@@ -7143,7 +7463,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
}
}
- return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy);
+ return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy, I);
}
case Instruction::Call: {
bool NeedToScalarize;
@@ -7172,9 +7492,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
@@ -7206,81 +7524,34 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts();
VecValuesToIgnore.insert(Casts.begin(), Casts.end());
}
-
- // Insert values known to be scalar into VecValuesToIgnore. This is a
- // conservative estimation of the values that will later be scalarized.
- //
- // FIXME: Even though an instruction is not scalar-after-vectoriztion, it may
- // still be scalarized. For example, we may find an instruction to be
- // more profitable for a given vectorization factor if it were to be
- // scalarized. But at this point, we haven't yet computed the
- // vectorization factor.
- for (auto *BB : TheLoop->getBlocks())
- for (auto &I : *BB)
- if (Legal->isScalarAfterVectorization(&I))
- VecValuesToIgnore.insert(&I);
}
-void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr,
- bool IfPredicateInstr) {
- assert(!Instr->getType()->isAggregateType() && "Can't handle vectors");
- // Holds vector parameters or scalars, in case of uniform vals.
- SmallVector<VectorParts, 4> Params;
-
- setDebugLocFromInst(Builder, Instr);
-
- // Does this instruction return a value ?
- bool IsVoidRetTy = Instr->getType()->isVoidTy();
-
- // Initialize a new scalar map entry.
- ScalarParts Entry(UF);
-
- VectorParts Cond;
- if (IfPredicateInstr)
- Cond = createBlockInMask(Instr->getParent());
-
- // For each vector unroll 'part':
- for (unsigned Part = 0; Part < UF; ++Part) {
- Entry[Part].resize(1);
- // For each scalar that we create:
-
- // Start an "if (pred) a[i] = ..." block.
- Value *Cmp = nullptr;
- if (IfPredicateInstr) {
- if (Cond[Part]->getType()->isVectorTy())
- Cond[Part] =
- Builder.CreateExtractElement(Cond[Part], Builder.getInt32(0));
- Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cond[Part],
- ConstantInt::get(Cond[Part]->getType(), 1));
- }
-
- Instruction *Cloned = Instr->clone();
- if (!IsVoidRetTy)
- Cloned->setName(Instr->getName() + ".cloned");
-
- // Replace the operands of the cloned instructions with their scalar
- // equivalents in the new loop.
- for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) {
- auto *NewOp = getScalarValue(Instr->getOperand(op), Part, 0);
- Cloned->setOperand(op, NewOp);
- }
+LoopVectorizationCostModel::VectorizationFactor
+LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) {
- // Place the cloned scalar in the new loop.
- Builder.Insert(Cloned);
+ // Width 1 means no vectorization, cost 0 means uncomputed cost.
+ const LoopVectorizationCostModel::VectorizationFactor NoVectorization = {1U,
+ 0U};
+ Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(OptForSize);
+ if (!MaybeMaxVF.hasValue()) // Cases considered too costly to vectorize.
+ return NoVectorization;
- // Add the cloned scalar to the scalar map entry.
- Entry[Part][0] = Cloned;
+ if (UserVF) {
+ DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n");
+ assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two");
+ // Collect the instructions (and their associated costs) that will be more
+ // profitable to scalarize.
+ CM.selectUserVectorizationFactor(UserVF);
+ return {UserVF, 0};
+ }
- // If we just cloned a new assumption, add it the assumption cache.
- if (auto *II = dyn_cast<IntrinsicInst>(Cloned))
- if (II->getIntrinsicID() == Intrinsic::assume)
- AC->registerAssumption(II);
+ unsigned MaxVF = MaybeMaxVF.getValue();
+ assert(MaxVF != 0 && "MaxVF is zero.");
+ if (MaxVF == 1)
+ return NoVectorization;
- // End if-block.
- if (IfPredicateInstr)
- PredicatedInstructions.push_back(std::make_pair(Cloned, Cmp));
- }
- VectorLoopValueMap.initScalar(Instr, Entry);
+ // Select the optimal vectorization factor.
+ return CM.selectVectorizationFactor(MaxVF);
}
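The planner's flow, stripped of LLVM types, is: bail out when no profitable maximum VF exists, honor a user-forced VF if one was given, and otherwise let the cost model choose up to that maximum. A simplified standalone sketch in which the callbacks stand in for the cost-model queries (all names here are illustrative):

    #include <functional>
    #include <optional>
    #include <utility>

    // {Width, Cost}; Width == 1 means "do not vectorize".
    using VFChoice = std::pair<unsigned, unsigned>;

    VFChoice planSketch(unsigned UserVF,
                        const std::function<std::optional<unsigned>()> &ComputeMaxVF,
                        const std::function<VFChoice(unsigned)> &SelectBest) {
      const VFChoice NoVectorization{1U, 0U};
      std::optional<unsigned> MaxVF = ComputeMaxVF();
      if (!MaxVF)                // considered too costly or illegal to vectorize
        return NoVectorization;
      if (UserVF)                // a user-forced width short-circuits the search
        return {UserVF, 0U};
      if (*MaxVF == 1)
        return NoVectorization;
      return SelectBest(*MaxVF); // cost-model pick up to MaxVF
    }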
void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) {
@@ -7414,11 +7685,6 @@ bool LoopVectorizePass::processLoop(Loop *L) {
return false;
}
- // Use the cost model.
- LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F,
- &Hints);
- CM.collectValuesToIgnore();
-
// Check the function attributes to find out if this function should be
// optimized for size.
bool OptForSize =
@@ -7464,9 +7730,20 @@ bool LoopVectorizePass::processLoop(Loop *L) {
return false;
}
- // Select the optimal vectorization factor.
- const LoopVectorizationCostModel::VectorizationFactor VF =
- CM.selectVectorizationFactor(OptForSize);
+ // Use the cost model.
+ LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F,
+ &Hints);
+ CM.collectValuesToIgnore();
+
+ // Use the planner for vectorization.
+ LoopVectorizationPlanner LVP(CM);
+
+ // Get user vectorization factor.
+ unsigned UserVF = Hints.getWidth();
+
+ // Plan how to best vectorize, return the best VF and its cost.
+ LoopVectorizationCostModel::VectorizationFactor VF =
+ LVP.plan(OptForSize, UserVF);
// Select the interleave count.
unsigned IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost);
@@ -7522,10 +7799,10 @@ bool LoopVectorizePass::processLoop(Loop *L) {
const char *VAPassName = Hints.vectorizeAnalysisPassName();
if (!VectorizeLoop && !InterleaveLoop) {
// Do not vectorize or interleave the loop.
- ORE->emit(OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first,
+ ORE->emit(OptimizationRemarkMissed(VAPassName, VecDiagMsg.first,
L->getStartLoc(), L->getHeader())
<< VecDiagMsg.second);
- ORE->emit(OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first,
+ ORE->emit(OptimizationRemarkMissed(LV_NAME, IntDiagMsg.first,
L->getStartLoc(), L->getHeader())
<< IntDiagMsg.second);
return false;
@@ -7621,6 +7898,16 @@ bool LoopVectorizePass::runImpl(
if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2)
return false;
+ bool Changed = false;
+
+ // The vectorizer requires loops to be in simplified form.
+ // Since simplification may add new inner loops, it has to run before the
+ // legality and profitability checks. This means running the loop vectorizer
+ // will simplify all loops, regardless of whether anything ends up being
+ // vectorized.
+ for (auto &L : *LI)
+ Changed |= simplifyLoop(L, DT, LI, SE, AC, false /* PreserveLCSSA */);
+
// Build up a worklist of inner-loops to vectorize. This is necessary as
// the act of vectorizing or partially unrolling a loop creates new loops
// and can invalidate iterators across the loops.
@@ -7632,9 +7919,15 @@ bool LoopVectorizePass::runImpl(
LoopsAnalyzed += Worklist.size();
// Now walk the identified inner loops.
- bool Changed = false;
- while (!Worklist.empty())
- Changed |= processLoop(Worklist.pop_back_val());
+ while (!Worklist.empty()) {
+ Loop *L = Worklist.pop_back_val();
+
+ // For the inner loops we actually process, form LCSSA to simplify the
+ // transform.
+ Changed |= formLCSSARecursively(*L, *DT, LI, SE);
+
+ Changed |= processLoop(L);
+ }
// Process each loop nest in the function.
return Changed;
diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 328f270029604..da3ac06ab464e 100644
--- a/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -39,6 +39,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Vectorize.h"
#include <algorithm>
@@ -90,6 +91,10 @@ static cl::opt<unsigned> MinTreeSize(
"slp-min-tree-size", cl::init(3), cl::Hidden,
cl::desc("Only vectorize small trees if they are fully vectorizable"));
+static cl::opt<bool>
+ ViewSLPTree("view-slp-tree", cl::Hidden,
+ cl::desc("Display the SLP trees with Graphviz"));
+
// Limit the number of alias checks. The limit is chosen so that
// it has no negative effect on the llvm benchmarks.
static const unsigned AliasedCheckLimit = 10;
@@ -212,14 +217,14 @@ static unsigned getSameOpcode(ArrayRef<Value *> VL) {
/// Flag set: NSW, NUW, exact, and all of fast-math.
static void propagateIRFlags(Value *I, ArrayRef<Value *> VL) {
if (auto *VecOp = dyn_cast<Instruction>(I)) {
- if (auto *Intersection = dyn_cast<Instruction>(VL[0])) {
- // Intersection is initialized to the 0th scalar,
- // so start counting from index '1'.
+ if (auto *I0 = dyn_cast<Instruction>(VL[0])) {
+ // VecOp is initialized to the 0th scalar, so start counting from index
+ // '1'.
+ VecOp->copyIRFlags(I0);
for (int i = 1, e = VL.size(); i < e; ++i) {
if (auto *Scalar = dyn_cast<Instruction>(VL[i]))
- Intersection->andIRFlags(Scalar);
+ VecOp->andIRFlags(Scalar);
}
- VecOp->copyIRFlags(Intersection);
}
}
}
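The rewritten propagateIRFlags copies the first scalar's flags onto the vector instruction and then intersects in everyone else's, so only flags shared by all scalars survive. A tiny standalone sketch of that copy-then-intersect pattern over a simplified flag set (NSW/NUW only; the real code also covers exact and fast-math flags):

    #include <vector>

    struct Flags {            // simplified stand-in for IR flags
      bool NSW = false;
      bool NUW = false;
      void intersectWith(const Flags &O) {
        NSW = NSW && O.NSW;
        NUW = NUW && O.NUW;
      }
    };

    // Assumes Scalars is non-empty: start from scalar 0, then AND in the rest.
    Flags intersectFlags(const std::vector<Flags> &Scalars) {
      Flags Result = Scalars.front();
      for (auto It = Scalars.begin() + 1, E = Scalars.end(); It != E; ++It)
        Result.intersectWith(*It);
      return Result;
    }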
@@ -304,6 +309,8 @@ public:
typedef SmallVector<Instruction *, 16> InstrList;
typedef SmallPtrSet<Value *, 16> ValueSet;
typedef SmallVector<StoreInst *, 8> StoreList;
+ typedef MapVector<Value *, SmallVector<Instruction *, 2>>
+ ExtraValueToDebugLocsMap;
BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti,
TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li,
@@ -330,6 +337,10 @@ public:
/// \brief Vectorize the tree that starts with the elements in \p VL.
/// Returns the vectorized root.
Value *vectorizeTree();
+ /// Vectorize the tree but with the list of externally used values \p
+ /// ExternallyUsedValues. Values in this MapVector can be replaced by the
+ /// generated extractelement instructions.
+ Value *vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues);
/// \returns the cost incurred by unwanted spills and fills, caused by
/// holding live values over call sites.
@@ -343,6 +354,13 @@ public:
/// the purpose of scheduling and extraction in the \p UserIgnoreLst.
void buildTree(ArrayRef<Value *> Roots,
ArrayRef<Value *> UserIgnoreLst = None);
+ /// Construct a vectorizable tree that starts at \p Roots, ignoring users for
+ /// the purpose of scheduling and extraction in the \p UserIgnoreLst, taking
+ /// into account (and updating, if required) the list of externally used
+ /// values stored in \p ExternallyUsedValues.
+ void buildTree(ArrayRef<Value *> Roots,
+ ExtraValueToDebugLocsMap &ExternallyUsedValues,
+ ArrayRef<Value *> UserIgnoreLst = None);
/// Clear the internal data structures that are created by 'buildTree'.
void deleteTree() {
@@ -404,7 +422,7 @@ private:
int getEntryCost(TreeEntry *E);
/// This is the recursive part of buildTree.
- void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth);
+ void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, int);
/// \returns True if the ExtractElement/ExtractValue instructions in VL can
/// be vectorized to use the original vector (or aggregate "bitcast" to a vector).
@@ -451,8 +469,9 @@ private:
SmallVectorImpl<Value *> &Left,
SmallVectorImpl<Value *> &Right);
struct TreeEntry {
- TreeEntry() : Scalars(), VectorizedValue(nullptr),
- NeedToGather(0) {}
+ TreeEntry(std::vector<TreeEntry> &Container)
+ : Scalars(), VectorizedValue(nullptr), NeedToGather(0),
+ Container(Container) {}
/// \returns true if the scalars in VL are equal to this entry.
bool isSame(ArrayRef<Value *> VL) const {
@@ -468,11 +487,24 @@ private:
/// Do we need to gather this sequence ?
bool NeedToGather;
+
+ /// Points back to the VectorizableTree.
+ ///
+ /// Only used for Graphviz right now. Unfortunately GraphTraits::NodeRef has
+ /// to be a pointer and needs to be able to initialize the child iterator.
+ /// Thus we need a reference back to the container to translate the indices
+ /// to entries.
+ std::vector<TreeEntry> &Container;
+
+ /// The TreeEntry index containing the user of this entry. We can actually
+ /// have multiple users so the data structure is not truly a tree.
+ SmallVector<int, 1> UserTreeIndices;
};
/// Create a new VectorizableTree entry.
- TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized) {
- VectorizableTree.emplace_back();
+ TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized,
+ int &UserTreeIdx) {
+ VectorizableTree.emplace_back(VectorizableTree);
int idx = VectorizableTree.size() - 1;
TreeEntry *Last = &VectorizableTree[idx];
Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end());
@@ -485,6 +517,10 @@ private:
} else {
MustGather.insert(VL.begin(), VL.end());
}
+
+ if (UserTreeIdx >= 0)
+ Last->UserTreeIndices.push_back(UserTreeIdx);
+ UserTreeIdx = idx;
return Last;
}
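The UserTreeIdx parameter threads the index of the using entry through the recursion: a new entry records who used it, and the caller's index variable is updated so children built afterwards point back at the new entry. A minimal sketch of that bookkeeping with a flat index-based tree (hypothetical types, not the real TreeEntry):

    #include <vector>

    struct Node {                     // stand-in for a tree entry
      std::vector<int> UserIndices;   // entries that use this one
    };

    // Append a node, link it to its user (if any), and hand the new index back
    // through UserIdx so subsequently built children record it as their user.
    int newNode(std::vector<Node> &Tree, int &UserIdx) {
      Tree.push_back(Node());
      int Idx = static_cast<int>(Tree.size()) - 1;
      if (UserIdx >= 0)
        Tree[Idx].UserIndices.push_back(UserIdx);
      UserIdx = Idx;
      return Idx;
    }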
@@ -558,7 +594,9 @@ private:
SmallVector<std::unique_ptr<Instruction>, 8> DeletedInstructions;
/// A list of values that need to be extracted out of the tree.
- /// This list holds pairs of (Internal Scalar : External User).
+ /// This list holds pairs of (Internal Scalar : External User). The External
+ /// User can be nullptr, meaning that this Internal Scalar will be used
+ /// later, after vectorization.
UserList ExternalUses;
/// Values used only by @llvm.assume calls.
@@ -706,6 +744,8 @@ private:
return os;
}
#endif
+ friend struct GraphTraits<BoUpSLP *>;
+ friend struct DOTGraphTraits<BoUpSLP *>;
/// Contains all scheduling data for a basic block.
///
@@ -916,17 +956,98 @@ private:
/// original width.
MapVector<Value *, std::pair<uint64_t, bool>> MinBWs;
};
+} // end namespace slpvectorizer
+
+template <> struct GraphTraits<BoUpSLP *> {
+ typedef BoUpSLP::TreeEntry TreeEntry;
+
+ /// NodeRef has to be a pointer per the GraphWriter.
+ typedef TreeEntry *NodeRef;
+
+ /// \brief Add the VectorizableTree to the index iterator to be able to return
+ /// TreeEntry pointers.
+ struct ChildIteratorType
+ : public iterator_adaptor_base<ChildIteratorType,
+ SmallVector<int, 1>::iterator> {
+
+ std::vector<TreeEntry> &VectorizableTree;
+
+ ChildIteratorType(SmallVector<int, 1>::iterator W,
+ std::vector<TreeEntry> &VT)
+ : ChildIteratorType::iterator_adaptor_base(W), VectorizableTree(VT) {}
+
+ NodeRef operator*() { return &VectorizableTree[*I]; }
+ };
+
+ static NodeRef getEntryNode(BoUpSLP &R) { return &R.VectorizableTree[0]; }
+
+ static ChildIteratorType child_begin(NodeRef N) {
+ return {N->UserTreeIndices.begin(), N->Container};
+ }
+ static ChildIteratorType child_end(NodeRef N) {
+ return {N->UserTreeIndices.end(), N->Container};
+ }
+
+ /// For the node iterator we just need to turn the TreeEntry iterator into a
+ /// TreeEntry* iterator so that it dereferences to NodeRef.
+ typedef pointer_iterator<std::vector<TreeEntry>::iterator> nodes_iterator;
+
+ static nodes_iterator nodes_begin(BoUpSLP *R) {
+ return nodes_iterator(R->VectorizableTree.begin());
+ }
+ static nodes_iterator nodes_end(BoUpSLP *R) {
+ return nodes_iterator(R->VectorizableTree.end());
+ }
+
+ static unsigned size(BoUpSLP *R) { return R->VectorizableTree.size(); }
+};
+
+template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits {
+ typedef BoUpSLP::TreeEntry TreeEntry;
+
+ DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+
+ std::string getNodeLabel(const TreeEntry *Entry, const BoUpSLP *R) {
+ std::string Str;
+ raw_string_ostream OS(Str);
+ if (isSplat(Entry->Scalars)) {
+ OS << "<splat> " << *Entry->Scalars[0];
+ return Str;
+ }
+ for (auto V : Entry->Scalars) {
+ OS << *V;
+ if (std::any_of(
+ R->ExternalUses.begin(), R->ExternalUses.end(),
+ [&](const BoUpSLP::ExternalUser &EU) { return EU.Scalar == V; }))
+ OS << " <extract>";
+ OS << "\n";
+ }
+ return Str;
+ }
+
+ static std::string getNodeAttributes(const TreeEntry *Entry,
+ const BoUpSLP *) {
+ if (Entry->NeedToGather)
+ return "color=red";
+ return "";
+ }
+};
} // end namespace llvm
-} // end namespace slpvectorizer
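What the GraphTraits/DOTGraphTraits specializations buy is that a generic DOT writer can walk the SLP tree even though edges are stored as indices into one flat container rather than as pointers. A standalone sketch of that idea, decoupled from LLVM's GraphWriter (all names here are illustrative):

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    struct SketchNode {
      std::string Label;          // e.g. the scalars of a tree entry
      std::vector<int> Edges;     // indices into the flat node container
    };

    // Emit a minimal DOT graph; a real DOTGraphTraits also supplies attributes
    // such as color=red for gather nodes.
    void writeDot(const std::vector<SketchNode> &Tree, std::ostream &OS) {
      OS << "digraph SLP {\n";
      for (std::size_t i = 0; i < Tree.size(); ++i) {
        OS << "  n" << i << " [label=\"" << Tree[i].Label << "\"];\n";
        for (int E : Tree[i].Edges)
          OS << "  n" << i << " -> n" << E << ";\n";
      }
      OS << "}\n";
    }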
void BoUpSLP::buildTree(ArrayRef<Value *> Roots,
ArrayRef<Value *> UserIgnoreLst) {
+ ExtraValueToDebugLocsMap ExternallyUsedValues;
+ buildTree(Roots, ExternallyUsedValues, UserIgnoreLst);
+}
+void BoUpSLP::buildTree(ArrayRef<Value *> Roots,
+ ExtraValueToDebugLocsMap &ExternallyUsedValues,
+ ArrayRef<Value *> UserIgnoreLst) {
deleteTree();
UserIgnoreList = UserIgnoreLst;
if (!allSameType(Roots))
return;
- buildTree_rec(Roots, 0);
+ buildTree_rec(Roots, 0, -1);
// Collect the values that we need to extract from the tree.
for (TreeEntry &EIdx : VectorizableTree) {
@@ -940,6 +1061,14 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots,
if (Entry->NeedToGather)
continue;
+ // Check if the scalar is externally used as an extra arg.
+ auto ExtI = ExternallyUsedValues.find(Scalar);
+ if (ExtI != ExternallyUsedValues.end()) {
+ DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane " <<
+ Lane << " from " << *Scalar << ".\n");
+ ExternalUses.emplace_back(Scalar, nullptr, Lane);
+ continue;
+ }
for (User *U : Scalar->users()) {
DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n");
@@ -976,28 +1105,28 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots,
}
}
-
-void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
+void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
+ int UserTreeIdx) {
bool isAltShuffle = false;
assert((allConstant(VL) || allSameType(VL)) && "Invalid types!");
if (Depth == RecursionMaxDepth) {
DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
// Don't handle vectors.
if (VL[0]->getType()->isVectorTy()) {
DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
if (SI->getValueOperand()->getType()->isVectorTy()) {
DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
unsigned Opcode = getSameOpcode(VL);
@@ -1014,7 +1143,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
// If all of the operands are identical or constant we have a simple solution.
if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !Opcode) {
DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
@@ -1026,7 +1155,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (EphValues.count(VL[i])) {
DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] <<
") is ephemeral.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
@@ -1039,10 +1168,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
DEBUG(dbgs() << "SLP: \tChecking bundle: " << *VL[i] << ".\n");
if (E->Scalars[i] != VL[i]) {
DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
+ // Record the reuse of the tree node. FIXME: currently this is only used
+ // to properly draw the graph rather than for the actual vectorization.
+ E->UserTreeIndices.push_back(UserTreeIdx);
DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *VL[0] << ".\n");
return;
}
@@ -1052,7 +1184,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (ScalarToTreeEntry.count(VL[i])) {
DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] <<
") is already in tree.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
@@ -1062,7 +1194,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (unsigned i = 0, e = VL.size(); i != e; ++i) {
if (MustGather.count(VL[i])) {
DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
@@ -1076,7 +1208,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
// Don't go into unreachable blocks. They may contain instructions with
// dependency cycles which confuse the final scheduling.
DEBUG(dbgs() << "SLP: bundle in unreachable block.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
@@ -1085,7 +1217,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (unsigned j = i+1; j < e; ++j)
if (VL[i] == VL[j]) {
DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
@@ -1100,7 +1232,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
assert((!BS.getScheduleData(VL[0]) ||
!BS.getScheduleData(VL[0])->isPartOfBundle()) &&
"tryScheduleBundle should cancelScheduling on failure");
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n");
@@ -1117,12 +1249,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (Term) {
DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n");
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n");
for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
@@ -1132,7 +1264,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock(
PH->getIncomingBlock(i)));
- buildTree_rec(Operands, Depth + 1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
}
return;
}
@@ -1144,7 +1276,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
} else {
BS.cancelScheduling(VL);
}
- newTreeEntry(VL, Reuse);
+ newTreeEntry(VL, Reuse, UserTreeIdx);
return;
}
case Instruction::Load: {
@@ -1160,7 +1292,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (DL->getTypeSizeInBits(ScalarTy) !=
DL->getTypeAllocSizeInBits(ScalarTy)) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n");
return;
}
@@ -1171,7 +1303,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
LoadInst *L = cast<LoadInst>(VL[i]);
if (!L->isSimple()) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n");
return;
}
@@ -1193,7 +1325,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (Consecutive) {
++NumLoadsWantToKeepOrder;
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a vector of loads.\n");
return;
}
@@ -1208,7 +1340,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
}
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
if (ReverseConsecutive) {
++NumLoadsWantToChangeOrder;
@@ -1235,12 +1367,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType();
if (Ty != SrcTy || !isValidElementType(Ty)) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n");
return;
}
}
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a vector of casts.\n");
for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
@@ -1249,7 +1381,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (Value *j : VL)
Operands.push_back(cast<Instruction>(j)->getOperand(i));
- buildTree_rec(Operands, Depth+1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
}
return;
}
@@ -1263,13 +1395,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (Cmp->getPredicate() != P0 ||
Cmp->getOperand(0)->getType() != ComparedTy) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n");
return;
}
}
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a vector of compares.\n");
for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
@@ -1278,7 +1410,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (Value *j : VL)
Operands.push_back(cast<Instruction>(j)->getOperand(i));
- buildTree_rec(Operands, Depth+1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
}
return;
}
@@ -1301,7 +1433,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor: {
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a vector of bin op.\n");
// Sort operands of the instructions so that each side is more likely to
@@ -1309,8 +1441,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) {
ValueList Left, Right;
reorderInputsAccordingToOpcode(VL, Left, Right);
- buildTree_rec(Left, Depth + 1);
- buildTree_rec(Right, Depth + 1);
+ buildTree_rec(Left, Depth + 1, UserTreeIdx);
+ buildTree_rec(Right, Depth + 1, UserTreeIdx);
return;
}
@@ -1320,7 +1452,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (Value *j : VL)
Operands.push_back(cast<Instruction>(j)->getOperand(i));
- buildTree_rec(Operands, Depth+1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
}
return;
}
@@ -1330,7 +1462,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (cast<Instruction>(VL[j])->getNumOperands() != 2) {
DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n");
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
@@ -1343,7 +1475,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
if (Ty0 != CurTy) {
DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n");
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
@@ -1355,12 +1487,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
DEBUG(
dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n");
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
return;
}
}
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a vector of GEPs.\n");
for (unsigned i = 0, e = 2; i < e; ++i) {
ValueList Operands;
@@ -1368,7 +1500,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (Value *j : VL)
Operands.push_back(cast<Instruction>(j)->getOperand(i));
- buildTree_rec(Operands, Depth + 1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
}
return;
}
@@ -1377,19 +1509,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: Non-consecutive store.\n");
return;
}
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a vector of stores.\n");
ValueList Operands;
for (Value *j : VL)
Operands.push_back(cast<Instruction>(j)->getOperand(0));
- buildTree_rec(Operands, Depth + 1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
return;
}
case Instruction::Call: {
@@ -1400,7 +1532,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
if (!isTriviallyVectorizable(ID)) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
return;
}
@@ -1414,7 +1546,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
getVectorIntrinsicIDForCall(CI2, TLI) != ID ||
!CI->hasIdenticalOperandBundleSchema(*CI2)) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i]
<< "\n");
return;
@@ -1425,7 +1557,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
Value *A1J = CI2->getArgOperand(1);
if (A1I != A1J) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI
<< " argument "<< A1I<<"!=" << A1J
<< "\n");
@@ -1438,14 +1570,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
CI->op_begin() + CI->getBundleOperandsEndIndex(),
CI2->op_begin() + CI2->getBundleOperandsStartIndex())) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI << "!="
<< *VL[i] << '\n');
return;
}
}
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) {
ValueList Operands;
// Prepare the operand vector.
@@ -1453,7 +1585,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
CallInst *CI2 = dyn_cast<CallInst>(j);
Operands.push_back(CI2->getArgOperand(i));
}
- buildTree_rec(Operands, Depth + 1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
}
return;
}
@@ -1462,19 +1594,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
// then do not vectorize this instruction.
if (!isAltShuffle) {
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n");
return;
}
- newTreeEntry(VL, true);
+ newTreeEntry(VL, true, UserTreeIdx);
DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n");
// Reorder operands if reordering would enable vectorization.
if (isa<BinaryOperator>(VL0)) {
ValueList Left, Right;
reorderAltShuffleOperands(VL, Left, Right);
- buildTree_rec(Left, Depth + 1);
- buildTree_rec(Right, Depth + 1);
+ buildTree_rec(Left, Depth + 1, UserTreeIdx);
+ buildTree_rec(Right, Depth + 1, UserTreeIdx);
return;
}
@@ -1484,13 +1616,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
for (Value *j : VL)
Operands.push_back(cast<Instruction>(j)->getOperand(i));
- buildTree_rec(Operands, Depth + 1);
+ buildTree_rec(Operands, Depth + 1, UserTreeIdx);
}
return;
}
default:
BS.cancelScheduling(VL);
- newTreeEntry(VL, false);
+ newTreeEntry(VL, false, UserTreeIdx);
DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n");
return;
}
@@ -1570,6 +1702,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
Type *ScalarTy = VL[0]->getType();
if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
ScalarTy = SI->getValueOperand()->getType();
+ else if (CmpInst *CI = dyn_cast<CmpInst>(VL[0]))
+ ScalarTy = CI->getOperand(0)->getType();
VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
// If we have computed a smaller type for the expression, update VecTy so
@@ -1599,7 +1733,13 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
int DeadCost = 0;
for (unsigned i = 0, e = VL.size(); i < e; ++i) {
Instruction *E = cast<Instruction>(VL[i]);
- if (E->hasOneUse())
+ // If all users are going to be vectorized, the instruction can be
+ // considered dead. The same holds if it has only one user, since that
+ // user will certainly be vectorized.
+ if (E->hasOneUse() ||
+ std::all_of(E->user_begin(), E->user_end(), [this](User *U) {
+ return ScalarToTreeEntry.count(U) > 0;
+ }))
// Take credit for instruction that will become dead.
DeadCost +=
TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i);
@@ -1624,10 +1764,10 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
// Calculate the cost of this instruction.
int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
- VL0->getType(), SrcTy);
+ VL0->getType(), SrcTy, VL0);
VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
- int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
+ int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy, VL0);
return VecCost - ScalarCost;
}
case Instruction::FCmp:
@@ -1636,8 +1776,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
// Calculate the cost of this instruction.
VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
int ScalarCost = VecTy->getNumElements() *
- TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty());
- int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy);
+ TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty(), VL0);
+ int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy, VL0);
return VecCost - ScalarCost;
}
case Instruction::Add:
@@ -1720,18 +1860,18 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
// Cost of wide load - cost of scalar loads.
unsigned alignment = dyn_cast<LoadInst>(VL0)->getAlignment();
int ScalarLdCost = VecTy->getNumElements() *
- TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0);
+ TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0, VL0);
int VecLdCost = TTI->getMemoryOpCost(Instruction::Load,
- VecTy, alignment, 0);
+ VecTy, alignment, 0, VL0);
return VecLdCost - ScalarLdCost;
}
case Instruction::Store: {
// We know that we can merge the stores. Calculate the cost.
unsigned alignment = dyn_cast<StoreInst>(VL0)->getAlignment();
int ScalarStCost = VecTy->getNumElements() *
- TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0);
+ TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0, VL0);
int VecStCost = TTI->getMemoryOpCost(Instruction::Store,
- VecTy, alignment, 0);
+ VecTy, alignment, 0, VL0);
return VecStCost - ScalarStCost;
}
case Instruction::Call: {
@@ -1739,12 +1879,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
// Calculate the cost of the scalar and vector calls.
- SmallVector<Type*, 4> ScalarTys, VecTys;
- for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op) {
+ SmallVector<Type*, 4> ScalarTys;
+ for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op)
ScalarTys.push_back(CI->getArgOperand(op)->getType());
- VecTys.push_back(VectorType::get(CI->getArgOperand(op)->getType(),
- VecTy->getNumElements()));
- }
FastMathFlags FMF;
if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
@@ -1753,7 +1890,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
int ScalarCallCost = VecTy->getNumElements() *
TTI->getIntrinsicInstrCost(ID, ScalarTy, ScalarTys, FMF);
- int VecCallCost = TTI->getIntrinsicInstrCost(ID, VecTy, VecTys, FMF);
+ SmallVector<Value *, 4> Args(CI->arg_operands());
+ int VecCallCost = TTI->getIntrinsicInstrCost(ID, CI->getType(), Args, FMF,
+ VecTy->getNumElements());
DEBUG(dbgs() << "SLP: Call cost "<< VecCallCost - ScalarCallCost
<< " (" << VecCallCost << "-" << ScalarCallCost << ")"
@@ -1947,9 +2086,18 @@ int BoUpSLP::getTreeCost() {
int SpillCost = getSpillCost();
Cost += SpillCost + ExtractCost;
- DEBUG(dbgs() << "SLP: Spill Cost = " << SpillCost << ".\n"
- << "SLP: Extract Cost = " << ExtractCost << ".\n"
- << "SLP: Total Cost = " << Cost << ".\n");
+ std::string Str;
+ {
+ raw_string_ostream OS(Str);
+ OS << "SLP: Spill Cost = " << SpillCost << ".\n"
+ << "SLP: Extract Cost = " << ExtractCost << ".\n"
+ << "SLP: Total Cost = " << Cost << ".\n";
+ }
+ DEBUG(dbgs() << Str);
+
+ if (ViewSLPTree)
+ ViewGraph(this, "SLP" + F->getName(), false, Str);
+
return Cost;
}
@@ -2702,6 +2850,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
Value *BoUpSLP::vectorizeTree() {
+ ExtraValueToDebugLocsMap ExternallyUsedValues;
+ return vectorizeTree(ExternallyUsedValues);
+}
+
+Value *
+BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) {
// All blocks must be scheduled before any instructions are inserted.
for (auto &BSIter : BlocksSchedules) {
@@ -2744,7 +2898,7 @@ Value *BoUpSLP::vectorizeTree() {
// Skip users that we already RAUW. This happens when one instruction
// has multiple uses of the same value.
- if (!is_contained(Scalar->users(), User))
+ if (User && !is_contained(Scalar->users(), User))
continue;
assert(ScalarToTreeEntry.count(Scalar) && "Invalid scalar");
@@ -2756,6 +2910,28 @@ Value *BoUpSLP::vectorizeTree() {
assert(Vec && "Can't find vectorizable value");
Value *Lane = Builder.getInt32(ExternalUse.Lane);
+ // If User == nullptr, the Scalar is used as an extra argument. Generate
+ // an ExtractElement instruction and update the record for this scalar in
+ // ExternallyUsedValues.
+ if (!User) {
+ assert(ExternallyUsedValues.count(Scalar) &&
+ "Scalar with nullptr as an external user must be registered in "
+ "ExternallyUsedValues map");
+ if (auto *VecI = dyn_cast<Instruction>(Vec)) {
+ Builder.SetInsertPoint(VecI->getParent(),
+ std::next(VecI->getIterator()));
+ } else {
+ Builder.SetInsertPoint(&F->getEntryBlock().front());
+ }
+ Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+ Ex = extend(ScalarRoot, Ex, Scalar->getType());
+ CSEBlocks.insert(cast<Instruction>(Scalar)->getParent());
+ auto &Locs = ExternallyUsedValues[Scalar];
+ ExternallyUsedValues.insert({Ex, Locs});
+ ExternallyUsedValues.erase(Scalar);
+ continue;
+ }
+
// Generate extracts for out-of-tree users.
// Find the insertion point for the extractelement lane.
if (auto *VecI = dyn_cast<Instruction>(Vec)) {
@@ -3264,7 +3440,7 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) {
// sorted by the original instruction location. This lets the final schedule
// be as close as possible to the original instruction order.
struct ScheduleDataCompare {
- bool operator()(ScheduleData *SD1, ScheduleData *SD2) {
+ bool operator()(ScheduleData *SD1, ScheduleData *SD2) const {
return SD2->SchedulingPriority < SD1->SchedulingPriority;
}
};
@@ -3645,9 +3821,9 @@ PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &A
bool Changed = runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB);
if (!Changed)
return PreservedAnalyses::all();
+
PreservedAnalyses PA;
- PA.preserve<LoopAnalysis>();
- PA.preserve<DominatorTreeAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<AAManager>();
PA.preserve<GlobalsAA>();
return PA;
@@ -4026,36 +4202,40 @@ bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) {
if (!V)
return false;
+ Value *P = V->getParent();
+
+ // Vectorize in current basic block only.
+ auto *Op0 = dyn_cast<Instruction>(V->getOperand(0));
+ auto *Op1 = dyn_cast<Instruction>(V->getOperand(1));
+ if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P)
+ return false;
+
// Try to vectorize V.
- if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
+ if (tryToVectorizePair(Op0, Op1, R))
return true;
- BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
- BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
+ auto *A = dyn_cast<BinaryOperator>(Op0);
+ auto *B = dyn_cast<BinaryOperator>(Op1);
// Try to skip B.
if (B && B->hasOneUse()) {
- BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
- BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
- if (tryToVectorizePair(A, B0, R)) {
+ auto *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
+ auto *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
+ if (B0 && B0->getParent() == P && tryToVectorizePair(A, B0, R))
return true;
- }
- if (tryToVectorizePair(A, B1, R)) {
+ if (B1 && B1->getParent() == P && tryToVectorizePair(A, B1, R))
return true;
- }
}
// Try to skip A.
if (A && A->hasOneUse()) {
- BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
- BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
- if (tryToVectorizePair(A0, B, R)) {
+ auto *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
+ auto *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
+ if (A0 && A0->getParent() == P && tryToVectorizePair(A0, B, R))
return true;
- }
- if (tryToVectorizePair(A1, B, R)) {
+ if (A1 && A1->getParent() == P && tryToVectorizePair(A1, B, R))
return true;
- }
}
- return 0;
+ return false;
}
/// \brief Generate a shuffle mask to be used in a reduction tree.
@@ -4119,37 +4299,41 @@ namespace {
class HorizontalReduction {
SmallVector<Value *, 16> ReductionOps;
SmallVector<Value *, 32> ReducedVals;
+ // Use map vector to make stable output.
+ MapVector<Instruction *, Value *> ExtraArgs;
- BinaryOperator *ReductionRoot;
- // After successfull horizontal reduction vectorization attempt for PHI node
- // vectorizer tries to update root binary op by combining vectorized tree and
- // the ReductionPHI node. But during vectorization this ReductionPHI can be
- // vectorized itself and replaced by the undef value, while the instruction
- // itself is marked for deletion. This 'marked for deletion' PHI node then can
- // be used in new binary operation, causing "Use still stuck around after Def
- // is destroyed" crash upon PHI node deletion.
- WeakVH ReductionPHI;
+ BinaryOperator *ReductionRoot = nullptr;
/// The opcode of the reduction.
- unsigned ReductionOpcode;
+ Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd;
/// The opcode of the values we perform a reduction on.
- unsigned ReducedValueOpcode;
+ unsigned ReducedValueOpcode = 0;
/// Should we model this reduction as a pairwise reduction tree or a tree that
/// splits the vector in halves and adds those halves.
- bool IsPairwiseReduction;
+ bool IsPairwiseReduction = false;
+
+ /// Checks if the ParentStackElem.first should be marked as a reduction
+ /// operation with an extra argument or as an extra argument itself.
+ void markExtraArg(std::pair<Instruction *, unsigned> &ParentStackElem,
+ Value *ExtraArg) {
+ if (ExtraArgs.count(ParentStackElem.first)) {
+ ExtraArgs[ParentStackElem.first] = nullptr;
+ // We ran into something like:
+ // ParentStackElem.first = ExtraArgs[ParentStackElem.first] + ExtraArg.
+ // The whole ParentStackElem.first should be considered as an extra value
+ // in this case.
+ // Do not perform analysis of remaining operands of ParentStackElem.first
+ // instruction; the whole instruction is an extra argument.
+ ParentStackElem.second = ParentStackElem.first->getNumOperands();
+ } else {
+ // We ran into something like:
+ // ParentStackElem.first += ... + ExtraArg + ...
+ ExtraArgs[ParentStackElem.first] = ExtraArg;
+ }
+ }
public:
- /// The width of one full horizontal reduction operation.
- unsigned ReduxWidth;
-
- /// Minimal width of available vector registers. It's used to determine
- /// ReduxWidth.
- unsigned MinVecRegSize;
-
- HorizontalReduction(unsigned MinVecRegSize)
- : ReductionRoot(nullptr), ReductionOpcode(0), ReducedValueOpcode(0),
- IsPairwiseReduction(false), ReduxWidth(0),
- MinVecRegSize(MinVecRegSize) {}
+ HorizontalReduction() = default;
/// \brief Try to find a reduction tree.
bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) {
@@ -4176,21 +4360,14 @@ public:
if (!isValidElementType(Ty))
return false;
- const DataLayout &DL = B->getModule()->getDataLayout();
ReductionOpcode = B->getOpcode();
ReducedValueOpcode = 0;
- // FIXME: Register size should be a parameter to this function, so we can
- // try different vectorization factors.
- ReduxWidth = MinVecRegSize / DL.getTypeSizeInBits(Ty);
ReductionRoot = B;
- ReductionPHI = Phi;
-
- if (ReduxWidth < 4)
- return false;
// We currently only support adds.
- if (ReductionOpcode != Instruction::Add &&
- ReductionOpcode != Instruction::FAdd)
+ if ((ReductionOpcode != Instruction::Add &&
+ ReductionOpcode != Instruction::FAdd) ||
+ !B->isAssociative())
return false;
// Post order traverse the reduction tree starting at B. We only handle true
@@ -4202,30 +4379,26 @@ public:
unsigned EdgeToVist = Stack.back().second++;
bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode;
- // Only handle trees in the current basic block.
- if (TreeN->getParent() != B->getParent())
- return false;
-
- // Each tree node needs to have one user except for the ultimate
- // reduction.
- if (!TreeN->hasOneUse() && TreeN != B)
- return false;
-
// Postorder vist.
if (EdgeToVist == 2 || IsReducedValue) {
- if (IsReducedValue) {
- // Make sure that the opcodes of the operations that we are going to
- // reduce match.
- if (!ReducedValueOpcode)
- ReducedValueOpcode = TreeN->getOpcode();
- else if (ReducedValueOpcode != TreeN->getOpcode())
- return false;
+ if (IsReducedValue)
ReducedVals.push_back(TreeN);
- } else {
- // We need to be able to reassociate the adds.
- if (!TreeN->isAssociative())
- return false;
- ReductionOps.push_back(TreeN);
+ else {
+ auto I = ExtraArgs.find(TreeN);
+ if (I != ExtraArgs.end() && !I->second) {
+ // Check if TreeN is an extra argument of its parent operation.
+ if (Stack.size() <= 1) {
+ // TreeN can't be an extra argument as it is a root reduction
+ // operation.
+ return false;
+ }
+ // Yes, TreeN is an extra argument, do not add it to a list of
+ // reduction operations.
+ // Stack[Stack.size() - 2] always points to the parent operation.
+ markExtraArg(Stack[Stack.size() - 2], TreeN);
+ ExtraArgs.erase(TreeN);
+ } else
+ ReductionOps.push_back(TreeN);
}
// Retract.
Stack.pop_back();
@@ -4242,13 +4415,44 @@ public:
// reduced value class.
if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode ||
I->getOpcode() == ReductionOpcode)) {
- if (!ReducedValueOpcode && I->getOpcode() != ReductionOpcode)
+ // Only handle trees in the current basic block.
+ if (I->getParent() != B->getParent()) {
+ // I is an extra argument for TreeN (its parent operation).
+ markExtraArg(Stack.back(), I);
+ continue;
+ }
+
+ // Each tree node needs to have one user except for the ultimate
+ // reduction.
+ if (!I->hasOneUse() && I != B) {
+ // I is an extra argument for TreeN (its parent operation).
+ markExtraArg(Stack.back(), I);
+ continue;
+ }
+
+ if (I->getOpcode() == ReductionOpcode) {
+ // We need to be able to reassociate the reduction operations.
+ if (!I->isAssociative()) {
+ // I is an extra argument for TreeN (its parent operation).
+ markExtraArg(Stack.back(), I);
+ continue;
+ }
+ } else if (ReducedValueOpcode &&
+ ReducedValueOpcode != I->getOpcode()) {
+ // Make sure that the opcodes of the operations that we are going to
+ // reduce match.
+ // I is an extra argument for TreeN (its parent operation).
+ markExtraArg(Stack.back(), I);
+ continue;
+ } else if (!ReducedValueOpcode)
ReducedValueOpcode = I->getOpcode();
+
Stack.push_back(std::make_pair(I, 0));
continue;
}
- return false;
}
+ // NextV is an extra argument for TreeN (its parent operation).
+ markExtraArg(Stack.back(), NextV);
}
return true;
}
@@ -4259,10 +4463,15 @@ public:
if (ReducedVals.empty())
return false;
+ // If there is a sufficient number of reduction values, reduce
+ // to a nearby power-of-2. We can safely generate oversized
+ // vectors and rely on the backend to split them to legal sizes.
unsigned NumReducedVals = ReducedVals.size();
- if (NumReducedVals < ReduxWidth)
+ if (NumReducedVals < 4)
return false;
+ unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
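+ // For example, with 14 reduced values the loop below vectorizes a group of
+ // 8 and then a group of 4; the remaining 2 values (and any extra arguments)
+ // are folded into the result as scalar operations afterwards.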
+
Value *VectorizedTree = nullptr;
IRBuilder<> Builder(ReductionRoot);
FastMathFlags Unsafe;
@@ -4270,20 +4479,26 @@ public:
Builder.setFastMathFlags(Unsafe);
unsigned i = 0;
- for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
+ BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
+ // The same extra argument may be used several times, so log each attempt
+ // to use it.
+ for (auto &Pair : ExtraArgs)
+ ExternallyUsedValues[Pair.second].push_back(Pair.first);
+ while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
- V.buildTree(VL, ReductionOps);
+ V.buildTree(VL, ExternallyUsedValues, ReductionOps);
if (V.shouldReorder()) {
SmallVector<Value *, 8> Reversed(VL.rbegin(), VL.rend());
- V.buildTree(Reversed, ReductionOps);
+ V.buildTree(Reversed, ExternallyUsedValues, ReductionOps);
}
if (V.isTreeTinyAndNotFullyVectorizable())
- continue;
+ break;
V.computeMinimumValueSizes();
// Estimate cost.
- int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
+ int Cost =
+ V.getTreeCost() + getReductionCost(TTI, ReducedVals[i], ReduxWidth);
if (Cost >= -SLPCostThreshold)
break;
@@ -4292,33 +4507,44 @@ public:
// Vectorize a tree.
DebugLoc Loc = cast<Instruction>(ReducedVals[i])->getDebugLoc();
- Value *VectorizedRoot = V.vectorizeTree();
+ Value *VectorizedRoot = V.vectorizeTree(ExternallyUsedValues);
// Emit a reduction.
- Value *ReducedSubTree = emitReduction(VectorizedRoot, Builder);
+ Value *ReducedSubTree =
+ emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps);
if (VectorizedTree) {
Builder.SetCurrentDebugLocation(Loc);
- VectorizedTree = createBinOp(Builder, ReductionOpcode, VectorizedTree,
- ReducedSubTree, "bin.rdx");
+ VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
+ ReducedSubTree, "bin.rdx");
+ propagateIRFlags(VectorizedTree, ReductionOps);
} else
VectorizedTree = ReducedSubTree;
+ i += ReduxWidth;
+ ReduxWidth = PowerOf2Floor(NumReducedVals - i);
}
if (VectorizedTree) {
// Finish the reduction.
for (; i < NumReducedVals; ++i) {
- Builder.SetCurrentDebugLocation(
- cast<Instruction>(ReducedVals[i])->getDebugLoc());
- VectorizedTree = createBinOp(Builder, ReductionOpcode, VectorizedTree,
- ReducedVals[i]);
+ auto *I = cast<Instruction>(ReducedVals[i]);
+ Builder.SetCurrentDebugLocation(I->getDebugLoc());
+ VectorizedTree =
+ Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I);
+ propagateIRFlags(VectorizedTree, ReductionOps);
+ }
+ for (auto &Pair : ExternallyUsedValues) {
+ assert(!Pair.second.empty() &&
+ "At least one DebugLoc must be inserted");
+ // Add each externally used value to the final reduction.
+ for (auto *I : Pair.second) {
+ Builder.SetCurrentDebugLocation(I->getDebugLoc());
+ VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
+ Pair.first, "bin.extra");
+ propagateIRFlags(VectorizedTree, I);
+ }
}
// Update users.
- if (ReductionPHI && !isa<UndefValue>(ReductionPHI)) {
- assert(ReductionRoot && "Need a reduction operation");
- ReductionRoot->setOperand(0, VectorizedTree);
- ReductionRoot->setOperand(1, ReductionPHI);
- } else
- ReductionRoot->replaceAllUsesWith(VectorizedTree);
+ ReductionRoot->replaceAllUsesWith(VectorizedTree);
}
return VectorizedTree != nullptr;
}
@@ -4329,7 +4555,8 @@ public:
private:
/// \brief Calculate the cost of a reduction.
- int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal) {
+ int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
+ unsigned ReduxWidth) {
Type *ScalarTy = FirstReducedVal->getType();
Type *VecTy = VectorType::get(ScalarTy, ReduxWidth);
@@ -4352,15 +4579,9 @@ private:
return VecReduxCost - ScalarReduxCost;
}
- static Value *createBinOp(IRBuilder<> &Builder, unsigned Opcode, Value *L,
- Value *R, const Twine &Name = "") {
- if (Opcode == Instruction::FAdd)
- return Builder.CreateFAdd(L, R, Name);
- return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, L, R, Name);
- }
-
/// \brief Emit a horizontal reduction of the vectorized value.
- Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder) {
+ Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder,
+ unsigned ReduxWidth, ArrayRef<Value *> RedOps) {
assert(VectorizedValue && "Need to have a vectorized tree node");
assert(isPowerOf2_32(ReduxWidth) &&
"We only handle power-of-two reductions for now");
@@ -4378,15 +4599,16 @@ private:
Value *RightShuf = Builder.CreateShuffleVector(
TmpVec, UndefValue::get(TmpVec->getType()), (RightMask),
"rdx.shuf.r");
- TmpVec = createBinOp(Builder, ReductionOpcode, LeftShuf, RightShuf,
- "bin.rdx");
+ TmpVec = Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf,
+ "bin.rdx");
} else {
Value *UpperHalf =
createRdxShuffleMask(ReduxWidth, i, false, false, Builder);
Value *Shuf = Builder.CreateShuffleVector(
TmpVec, UndefValue::get(TmpVec->getType()), UpperHalf, "rdx.shuf");
- TmpVec = createBinOp(Builder, ReductionOpcode, TmpVec, Shuf, "bin.rdx");
+ TmpVec = Builder.CreateBinOp(ReductionOpcode, TmpVec, Shuf, "bin.rdx");
}
+ propagateIRFlags(TmpVec, RedOps);
}
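+ // E.g. for ReduxWidth == 4, the splitting (non-pairwise) form first adds
+ // lanes <2,3> onto lanes <0,1>, then lane <1> onto lane <0>, leaving the
+ // total in lane 0.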
// The result is in the first element of the vector.
@@ -4438,16 +4660,19 @@ static bool findBuildVector(InsertElementInst *FirstInsertElem,
static bool findBuildAggregate(InsertValueInst *IV,
SmallVectorImpl<Value *> &BuildVector,
SmallVectorImpl<Value *> &BuildVectorOpds) {
- if (!IV->hasOneUse())
- return false;
- Value *V = IV->getAggregateOperand();
- if (!isa<UndefValue>(V)) {
- InsertValueInst *I = dyn_cast<InsertValueInst>(V);
- if (!I || !findBuildAggregate(I, BuildVector, BuildVectorOpds))
+ Value *V;
+ do {
+ BuildVector.push_back(IV);
+ BuildVectorOpds.push_back(IV->getInsertedValueOperand());
+ V = IV->getAggregateOperand();
+ if (isa<UndefValue>(V))
+ break;
+ IV = dyn_cast<InsertValueInst>(V);
+ if (!IV || !IV->hasOneUse())
return false;
- }
- BuildVector.push_back(IV);
- BuildVectorOpds.push_back(IV->getInsertedValueOperand());
+ } while (true);
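+ // The chain was walked from the last insertvalue back to the first, so
+ // restore source order before returning. For example, for
+ //   %a = insertvalue {float, float} undef, float %x, 0
+ //   %b = insertvalue {float, float} %a, float %y, 1
+ // a call on %b yields BuildVectorOpds == {%x, %y} after the reversal.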
+ std::reverse(BuildVector.begin(), BuildVector.end());
+ std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
return true;
}
@@ -4507,29 +4732,137 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P,
return nullptr;
}
+namespace {
+/// Tracks an instruction and its children (i.e. its operands).
+class WeakVHWithLevel final : public CallbackVH {
+ /// Operand index of the instruction currently being analyzed.
+ unsigned Level = 0;
+ /// Is this the instruction that should be vectorized, or are we now
+ /// processing children (i.e. operands of this instruction) for potential
+ /// vectorization?
+ bool IsInitial = true;
+
+public:
+ explicit WeakVHWithLevel() = default;
+ WeakVHWithLevel(Value *V) : CallbackVH(V) {}
+ /// Restart the analysis of children each time the tracked value is replaced
+ /// by a new instruction.
+ void allUsesReplacedWith(Value *New) override {
+ setValPtr(New);
+ Level = 0;
+ IsInitial = true;
+ }
+ /// Check that the tracked instruction was not deleted during vectorization.
+ bool isValid() const { return getValPtr() != nullptr; }
+ /// Should the instruction itself be vectorized?
+ bool isInitial() const { return IsInitial; }
+ /// Try to vectorize children.
+ void clearInitial() { IsInitial = false; }
+ /// Are all children processed already?
+ bool isFinal() const {
+ assert(getValPtr() &&
+ (isa<Instruction>(getValPtr()) &&
+ cast<Instruction>(getValPtr())->getNumOperands() >= Level));
+ return getValPtr() &&
+ cast<Instruction>(getValPtr())->getNumOperands() == Level;
+ }
+ /// Get the next operand (child) to analyze.
+ Value *nextOperand() {
+ assert(getValPtr() && isa<Instruction>(getValPtr()) &&
+ cast<Instruction>(getValPtr())->getNumOperands() > Level);
+ return cast<Instruction>(getValPtr())->getOperand(Level++);
+ }
+ virtual ~WeakVHWithLevel() = default;
+};
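+// A WeakVHWithLevel serves as a DFS stack entry in canBeVectorized() below:
+// on the initial visit the instruction itself is tried as a reduction root or
+// vectorization seed, and subsequent visits hand out its operands one at a
+// time via nextOperand() until isFinal() reports all children were visited.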
+} // namespace
+
/// \brief Attempt to reduce a horizontal reduction.
/// If it is legal to match a horizontal reduction feeding
-/// the phi node P with reduction operators BI, then check if it
-/// can be done.
+/// the phi node P with a reduction tree rooted at Root in basic block BB, then
+/// check if it can be done.
/// \returns true if a horizontal reduction was matched and reduced.
/// \returns false if a horizontal reduction was not matched.
-static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI,
- BoUpSLP &R, TargetTransformInfo *TTI,
- unsigned MinRegSize) {
+static bool canBeVectorized(
+ PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R,
+ TargetTransformInfo *TTI,
+ const function_ref<bool(BinaryOperator *, BoUpSLP &)> Vectorize) {
if (!ShouldVectorizeHor)
return false;
- HorizontalReduction HorRdx(MinRegSize);
- if (!HorRdx.matchAssociativeReduction(P, BI))
+ if (!Root)
return false;
- // If there is a sufficient number of reduction values, reduce
- // to a nearby power-of-2. Can safely generate oversized
- // vectors and rely on the backend to split them to legal sizes.
- HorRdx.ReduxWidth =
- std::max((uint64_t)4, PowerOf2Floor(HorRdx.numReductionValues()));
+ if (Root->getParent() != BB)
+ return false;
+ SmallVector<WeakVHWithLevel, 8> Stack(1, Root);
+ SmallSet<Value *, 8> VisitedInstrs;
+ bool Res = false;
+ while (!Stack.empty()) {
+ Value *V = Stack.back();
+ if (!V) {
+ Stack.pop_back();
+ continue;
+ }
+ auto *Inst = dyn_cast<Instruction>(V);
+ if (!Inst || isa<PHINode>(Inst)) {
+ Stack.pop_back();
+ continue;
+ }
+ if (Stack.back().isInitial()) {
+ Stack.back().clearInitial();
+ if (auto *BI = dyn_cast<BinaryOperator>(Inst)) {
+ HorizontalReduction HorRdx;
+ if (HorRdx.matchAssociativeReduction(P, BI)) {
+ if (HorRdx.tryToReduce(R, TTI)) {
+ Res = true;
+ P = nullptr;
+ continue;
+ }
+ }
+ if (P) {
+ Inst = dyn_cast<Instruction>(BI->getOperand(0));
+ if (Inst == P)
+ Inst = dyn_cast<Instruction>(BI->getOperand(1));
+ if (!Inst) {
+ P = nullptr;
+ continue;
+ }
+ }
+ }
+ P = nullptr;
+ if (Vectorize(dyn_cast<BinaryOperator>(Inst), R)) {
+ Res = true;
+ continue;
+ }
+ }
+ if (Stack.back().isFinal()) {
+ Stack.pop_back();
+ continue;
+ }
- return HorRdx.tryToReduce(R, TTI);
+ if (auto *NextV = dyn_cast<Instruction>(Stack.back().nextOperand()))
+ if (NextV->getParent() == BB && VisitedInstrs.insert(NextV).second &&
+ Stack.size() < RecursionMaxDepth)
+ Stack.push_back(NextV);
+ }
+ return Res;
+}
+
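+/// \brief Try to vectorize the tree starting at V in basic block BB: either
+/// match and reduce a horizontal reduction feeding the phi node P, or
+/// vectorize the binary operators reachable from V.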
+bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V,
+ BasicBlock *BB, BoUpSLP &R,
+ TargetTransformInfo *TTI) {
+ if (!V)
+ return false;
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return false;
+
+ if (!isa<BinaryOperator>(I))
+ P = nullptr;
+ // Try to match and vectorize a horizontal reduction.
+ return canBeVectorized(P, I, BB, R, TTI,
+ [this](BinaryOperator *BI, BoUpSLP &R) -> bool {
+ return tryToVectorize(BI, R);
+ });
}
bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
@@ -4599,67 +4932,42 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
if (P->getNumIncomingValues() != 2)
return Changed;
- Value *Rdx = getReductionValue(DT, P, BB, LI);
-
- // Check if this is a Binary Operator.
- BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
- if (!BI)
- continue;
-
// Try to match and vectorize a horizontal reduction.
- if (canMatchHorizontalReduction(P, BI, R, TTI, R.getMinVecRegSize())) {
+ if (vectorizeRootInstruction(P, getReductionValue(DT, P, BB, LI), BB, R,
+ TTI)) {
Changed = true;
it = BB->begin();
e = BB->end();
continue;
}
-
- Value *Inst = BI->getOperand(0);
- if (Inst == P)
- Inst = BI->getOperand(1);
-
- if (tryToVectorize(dyn_cast<BinaryOperator>(Inst), R)) {
- // We would like to start over since some instructions are deleted
- // and the iterator may become invalid value.
- Changed = true;
- it = BB->begin();
- e = BB->end();
- continue;
- }
-
continue;
}
- if (ShouldStartVectorizeHorAtStore)
- if (StoreInst *SI = dyn_cast<StoreInst>(it))
- if (BinaryOperator *BinOp =
- dyn_cast<BinaryOperator>(SI->getValueOperand())) {
- if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI,
- R.getMinVecRegSize()) ||
- tryToVectorize(BinOp, R)) {
- Changed = true;
- it = BB->begin();
- e = BB->end();
- continue;
- }
+ if (ShouldStartVectorizeHorAtStore) {
+ if (StoreInst *SI = dyn_cast<StoreInst>(it)) {
+ // Try to match and vectorize a horizontal reduction.
+ if (vectorizeRootInstruction(nullptr, SI->getValueOperand(), BB, R,
+ TTI)) {
+ Changed = true;
+ it = BB->begin();
+ e = BB->end();
+ continue;
}
+ }
+ }
// Try to vectorize horizontal reductions feeding into a return.
- if (ReturnInst *RI = dyn_cast<ReturnInst>(it))
- if (RI->getNumOperands() != 0)
- if (BinaryOperator *BinOp =
- dyn_cast<BinaryOperator>(RI->getOperand(0))) {
- DEBUG(dbgs() << "SLP: Found a return to vectorize.\n");
- if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI,
- R.getMinVecRegSize()) ||
- tryToVectorizePair(BinOp->getOperand(0), BinOp->getOperand(1),
- R)) {
- Changed = true;
- it = BB->begin();
- e = BB->end();
- continue;
- }
+ if (ReturnInst *RI = dyn_cast<ReturnInst>(it)) {
+ if (RI->getNumOperands() != 0) {
+ // Try to match and vectorize a horizontal reduction.
+ if (vectorizeRootInstruction(nullptr, RI->getOperand(0), BB, R, TTI)) {
+ Changed = true;
+ it = BB->begin();
+ e = BB->end();
+ continue;
}
+ }
+ }
// Try to vectorize trees that start at compare instructions.
if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
@@ -4672,16 +4980,14 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
continue;
}
- for (int i = 0; i < 2; ++i) {
- if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) {
- if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R)) {
- Changed = true;
- // We would like to start over since some instructions are deleted
- // and the iterator may become invalid value.
- it = BB->begin();
- e = BB->end();
- break;
- }
+ for (int I = 0; I < 2; ++I) {
+ if (vectorizeRootInstruction(nullptr, CI->getOperand(I), BB, R, TTI)) {
+ Changed = true;
+ // We would like to start over since some instructions are deleted
+ // and the iterator may become invalid value.
+ it = BB->begin();
+ e = BB->end();
+ break;
}
}
continue;