path: root/lib/Transforms
Diffstat (limited to 'lib/Transforms')
-rw-r--r--  lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp | 78
-rw-r--r--  lib/Transforms/Coroutines/CoroCleanup.cpp | 7
-rw-r--r--  lib/Transforms/Coroutines/CoroEarly.cpp | 26
-rw-r--r--  lib/Transforms/Coroutines/CoroElide.cpp | 2
-rw-r--r--  lib/Transforms/Coroutines/CoroFrame.cpp | 652
-rw-r--r--  lib/Transforms/Coroutines/CoroInstr.h | 205
-rw-r--r--  lib/Transforms/Coroutines/CoroInternal.h | 162
-rw-r--r--  lib/Transforms/Coroutines/CoroSplit.cpp | 1166
-rw-r--r--  lib/Transforms/Coroutines/Coroutines.cpp | 342
-rw-r--r--  lib/Transforms/IPO/ArgumentPromotion.cpp | 2
-rw-r--r--  lib/Transforms/IPO/Attributor.cpp | 4959
-rw-r--r--  lib/Transforms/IPO/BlockExtractor.cpp | 5
-rw-r--r--  lib/Transforms/IPO/ConstantMerge.cpp | 4
-rw-r--r--  lib/Transforms/IPO/CrossDSOCFI.cpp | 10
-rw-r--r--  lib/Transforms/IPO/FunctionAttrs.cpp | 38
-rw-r--r--  lib/Transforms/IPO/FunctionImport.cpp | 43
-rw-r--r--  lib/Transforms/IPO/GlobalDCE.cpp | 156
-rw-r--r--  lib/Transforms/IPO/GlobalOpt.cpp | 176
-rw-r--r--  lib/Transforms/IPO/HotColdSplitting.cpp | 61
-rw-r--r--  lib/Transforms/IPO/IPO.cpp | 13
-rw-r--r--  lib/Transforms/IPO/InferFunctionAttrs.cpp | 20
-rw-r--r--  lib/Transforms/IPO/Inliner.cpp | 21
-rw-r--r--  lib/Transforms/IPO/LoopExtractor.cpp | 6
-rw-r--r--  lib/Transforms/IPO/LowerTypeTests.cpp | 305
-rw-r--r--  lib/Transforms/IPO/MergeFunctions.cpp | 4
-rw-r--r--  lib/Transforms/IPO/PartialInlining.cpp | 20
-rw-r--r--  lib/Transforms/IPO/PassManagerBuilder.cpp | 1
-rw-r--r--  lib/Transforms/IPO/SCCP.cpp | 18
-rw-r--r--  lib/Transforms/IPO/SampleProfile.cpp | 238
-rw-r--r--  lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp | 21
-rw-r--r--  lib/Transforms/IPO/WholeProgramDevirt.cpp | 389
-rw-r--r--  lib/Transforms/InstCombine/InstCombineAddSub.cpp | 268
-rw-r--r--  lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 278
-rw-r--r--  lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp | 4
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCalls.cpp | 121
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCasts.cpp | 102
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCompares.cpp | 870
-rw-r--r--  lib/Transforms/InstCombine/InstCombineInternal.h | 116
-rw-r--r--  lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp | 93
-rw-r--r--  lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 77
-rw-r--r--  lib/Transforms/InstCombine/InstCombinePHI.cpp | 6
-rw-r--r--  lib/Transforms/InstCombine/InstCombineSelect.cpp | 455
-rw-r--r--  lib/Transforms/InstCombine/InstCombineShifts.cpp | 370
-rw-r--r--  lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 48
-rw-r--r--  lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 171
-rw-r--r--  lib/Transforms/InstCombine/InstructionCombining.cpp | 67
-rw-r--r--  lib/Transforms/Instrumentation/AddressSanitizer.cpp | 98
-rw-r--r--  lib/Transforms/Instrumentation/BoundsChecking.cpp | 2
-rw-r--r--  lib/Transforms/Instrumentation/CFGMST.h | 4
-rw-r--r--  lib/Transforms/Instrumentation/ControlHeightReduction.cpp | 26
-rw-r--r--  lib/Transforms/Instrumentation/DataFlowSanitizer.cpp | 2
-rw-r--r--  lib/Transforms/Instrumentation/GCOVProfiling.cpp | 49
-rw-r--r--  lib/Transforms/Instrumentation/HWAddressSanitizer.cpp | 376
-rw-r--r--  lib/Transforms/Instrumentation/IndirectCallPromotion.cpp | 2
-rw-r--r--  lib/Transforms/Instrumentation/InstrOrderFile.cpp | 3
-rw-r--r--  lib/Transforms/Instrumentation/InstrProfiling.cpp | 65
-rw-r--r--  lib/Transforms/Instrumentation/Instrumentation.cpp | 5
-rw-r--r--  lib/Transforms/Instrumentation/MemorySanitizer.cpp | 89
-rw-r--r--  lib/Transforms/Instrumentation/PGOInstrumentation.cpp | 220
-rw-r--r--  lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp | 6
-rw-r--r--  lib/Transforms/Instrumentation/SanitizerCoverage.cpp | 164
-rw-r--r--  lib/Transforms/Instrumentation/ThreadSanitizer.cpp | 54
-rw-r--r--  lib/Transforms/Instrumentation/ValueProfileCollector.cpp | 78
-rw-r--r--  lib/Transforms/Instrumentation/ValueProfileCollector.h | 79
-rw-r--r--  lib/Transforms/Instrumentation/ValueProfilePlugins.inc | 75
-rw-r--r--  lib/Transforms/ObjCARC/PtrState.cpp | 4
-rw-r--r--  lib/Transforms/Scalar/AlignmentFromAssumptions.cpp | 8
-rw-r--r--  lib/Transforms/Scalar/CallSiteSplitting.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/ConstantHoisting.cpp | 24
-rw-r--r--  lib/Transforms/Scalar/ConstantProp.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/CorrelatedValuePropagation.cpp | 180
-rw-r--r--  lib/Transforms/Scalar/DCE.cpp | 31
-rw-r--r--  lib/Transforms/Scalar/DeadStoreElimination.cpp | 7
-rw-r--r--  lib/Transforms/Scalar/DivRemPairs.cpp | 219
-rw-r--r--  lib/Transforms/Scalar/EarlyCSE.cpp | 22
-rw-r--r--  lib/Transforms/Scalar/FlattenCFGPass.cpp | 24
-rw-r--r--  lib/Transforms/Scalar/Float2Int.cpp | 47
-rw-r--r--  lib/Transforms/Scalar/GVN.cpp | 201
-rw-r--r--  lib/Transforms/Scalar/GVNHoist.cpp | 17
-rw-r--r--  lib/Transforms/Scalar/GuardWidening.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/IndVarSimplify.cpp | 389
-rw-r--r--  lib/Transforms/Scalar/InferAddressSpaces.cpp | 38
-rw-r--r--  lib/Transforms/Scalar/InstSimplifyPass.cpp | 48
-rw-r--r--  lib/Transforms/Scalar/JumpThreading.cpp | 18
-rw-r--r--  lib/Transforms/Scalar/LICM.cpp | 55
-rw-r--r--  lib/Transforms/Scalar/LoopDataPrefetch.cpp | 4
-rw-r--r--  lib/Transforms/Scalar/LoopDeletion.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/LoopFuse.cpp | 640
-rw-r--r--  lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 890
-rw-r--r--  lib/Transforms/Scalar/LoopInstSimplify.cpp | 5
-rw-r--r--  lib/Transforms/Scalar/LoopInterchange.cpp | 62
-rw-r--r--  lib/Transforms/Scalar/LoopLoadElimination.cpp | 3
-rw-r--r--  lib/Transforms/Scalar/LoopPredication.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/LoopRerollPass.cpp | 3
-rw-r--r--  lib/Transforms/Scalar/LoopRotation.cpp | 10
-rw-r--r--  lib/Transforms/Scalar/LoopSimplifyCFG.cpp | 4
-rw-r--r--  lib/Transforms/Scalar/LoopSink.cpp | 9
-rw-r--r--  lib/Transforms/Scalar/LoopStrengthReduce.cpp | 20
-rw-r--r--  lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp | 8
-rw-r--r--  lib/Transforms/Scalar/LoopUnrollPass.cpp | 128
-rw-r--r--  lib/Transforms/Scalar/LoopUnswitch.cpp | 87
-rw-r--r--  lib/Transforms/Scalar/LoopVersioningLICM.cpp | 31
-rw-r--r--  lib/Transforms/Scalar/LowerConstantIntrinsics.cpp | 170
-rw-r--r--  lib/Transforms/Scalar/LowerExpectIntrinsic.cpp | 33
-rw-r--r--  lib/Transforms/Scalar/MemCpyOptimizer.cpp | 110
-rw-r--r--  lib/Transforms/Scalar/MergeICmps.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/MergedLoadStoreMotion.cpp | 167
-rw-r--r--  lib/Transforms/Scalar/NaryReassociate.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/NewGVN.cpp | 25
-rw-r--r--  lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/PlaceSafepoints.cpp | 6
-rw-r--r--  lib/Transforms/Scalar/Reassociate.cpp | 190
-rw-r--r--  lib/Transforms/Scalar/RewriteStatepointsForGC.cpp | 6
-rw-r--r--  lib/Transforms/Scalar/SCCP.cpp | 75
-rw-r--r--  lib/Transforms/Scalar/SROA.cpp | 40
-rw-r--r--  lib/Transforms/Scalar/Scalar.cpp | 9
-rw-r--r--  lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/SimpleLoopUnswitch.cpp | 25
-rw-r--r--  lib/Transforms/Scalar/SpeculateAroundPHIs.cpp | 6
-rw-r--r--  lib/Transforms/Scalar/StructurizeCFG.cpp | 2
-rw-r--r--  lib/Transforms/Scalar/TailRecursionElimination.cpp | 2
-rw-r--r--  lib/Transforms/Utils/BasicBlockUtils.cpp | 64
-rw-r--r--  lib/Transforms/Utils/BuildLibCalls.cpp | 94
-rw-r--r--  lib/Transforms/Utils/BypassSlowDivision.cpp | 8
-rw-r--r--  lib/Transforms/Utils/CanonicalizeAliases.cpp | 1
-rw-r--r--  lib/Transforms/Utils/CloneFunction.cpp | 15
-rw-r--r--  lib/Transforms/Utils/CloneModule.cpp | 18
-rw-r--r--  lib/Transforms/Utils/CodeExtractor.cpp | 309
-rw-r--r--  lib/Transforms/Utils/EntryExitInstrumenter.cpp | 2
-rw-r--r--  lib/Transforms/Utils/Evaluator.cpp | 2
-rw-r--r--  lib/Transforms/Utils/FlattenCFG.cpp | 20
-rw-r--r--  lib/Transforms/Utils/FunctionImportUtils.cpp | 2
-rw-r--r--  lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp | 6
-rw-r--r--  lib/Transforms/Utils/LibCallsShrinkWrap.cpp | 2
-rw-r--r--  lib/Transforms/Utils/Local.cpp | 209
-rw-r--r--  lib/Transforms/Utils/LoopRotationUtils.cpp | 27
-rw-r--r--  lib/Transforms/Utils/LoopSimplify.cpp | 15
-rw-r--r--  lib/Transforms/Utils/LoopUnroll.cpp | 12
-rw-r--r--  lib/Transforms/Utils/LoopUnrollAndJam.cpp | 6
-rw-r--r--  lib/Transforms/Utils/LoopUnrollPeel.cpp | 161
-rw-r--r--  lib/Transforms/Utils/LoopUtils.cpp | 56
-rw-r--r--  lib/Transforms/Utils/LoopVersioning.cpp | 4
-rw-r--r--  lib/Transforms/Utils/MetaRenamer.cpp | 5
-rw-r--r--  lib/Transforms/Utils/MisExpect.cpp | 177
-rw-r--r--  lib/Transforms/Utils/ModuleUtils.cpp | 2
-rw-r--r--  lib/Transforms/Utils/PredicateInfo.cpp | 80
-rw-r--r--  lib/Transforms/Utils/SimplifyCFG.cpp | 250
-rw-r--r--  lib/Transforms/Utils/SimplifyLibCalls.cpp | 688
-rw-r--r--  lib/Transforms/Utils/SymbolRewriter.cpp | 12
-rw-r--r--  lib/Transforms/Utils/VNCoercion.cpp | 2
-rw-r--r--  lib/Transforms/Utils/ValueMapper.cpp | 60
-rw-r--r--  lib/Transforms/Vectorize/LoadStoreVectorizer.cpp | 26
-rw-r--r--  lib/Transforms/Vectorize/LoopVectorizationLegality.cpp | 186
-rw-r--r--  lib/Transforms/Vectorize/LoopVectorizationPlanner.h | 4
-rw-r--r--  lib/Transforms/Vectorize/LoopVectorize.cpp | 738
-rw-r--r--  lib/Transforms/Vectorize/SLPVectorizer.cpp | 820
-rw-r--r--  lib/Transforms/Vectorize/VPlan.cpp | 19
-rw-r--r--  lib/Transforms/Vectorize/VPlan.h | 4
-rw-r--r--  lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp | 2
-rw-r--r--  lib/Transforms/Vectorize/VPlanSLP.cpp | 13
160 files changed, 16942 insertions, 4892 deletions
diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 06222d7e7e44..a24de3ca213f 100644
--- a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -121,14 +121,13 @@ static bool foldGuardedRotateToFunnelShift(Instruction &I) {
BasicBlock *GuardBB = Phi.getIncomingBlock(RotSrc == P1);
BasicBlock *RotBB = Phi.getIncomingBlock(RotSrc != P1);
Instruction *TermI = GuardBB->getTerminator();
- BasicBlock *TrueBB, *FalseBB;
ICmpInst::Predicate Pred;
- if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), TrueBB,
- FalseBB)))
+ BasicBlock *PhiBB = Phi.getParent();
+ if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()),
+ m_SpecificBB(PhiBB), m_SpecificBB(RotBB))))
return false;
- BasicBlock *PhiBB = Phi.getParent();
- if (Pred != CmpInst::ICMP_EQ || TrueBB != PhiBB || FalseBB != RotBB)
+ if (Pred != CmpInst::ICMP_EQ)
return false;
// We matched a variation of this IR pattern:
@@ -251,6 +250,72 @@ static bool foldAnyOrAllBitsSet(Instruction &I) {
return true;
}
+// Try to recognize the function below as a popcount intrinsic.
+// This is the "best" algorithm from
+// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
+// Also used in TargetLowering::expandCTPOP().
+//
+// int popcount(unsigned int i) {
+// i = i - ((i >> 1) & 0x55555555);
+// i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
+// i = ((i + (i >> 4)) & 0x0F0F0F0F);
+// return (i * 0x01010101) >> 24;
+// }
+static bool tryToRecognizePopCount(Instruction &I) {
+ if (I.getOpcode() != Instruction::LShr)
+ return false;
+
+ Type *Ty = I.getType();
+ if (!Ty->isIntOrIntVectorTy())
+ return false;
+
+ unsigned Len = Ty->getScalarSizeInBits();
+ // FIXME: fix Len == 8 and other irregular type lengths.
+ if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
+ return false;
+
+ APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
+ APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
+ APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
+ APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
+ APInt MaskShift = APInt(Len, Len - 8);
+
+ Value *Op0 = I.getOperand(0);
+ Value *Op1 = I.getOperand(1);
+ Value *MulOp0;
+ // Matching "(i * 0x01010101...) >> 24".
+ if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
+ match(Op1, m_SpecificInt(MaskShift))) {
+ Value *ShiftOp0;
+ // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
+ if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
+ m_Deferred(ShiftOp0)),
+ m_SpecificInt(Mask0F)))) {
+ Value *AndOp0;
+ // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
+ if (match(ShiftOp0,
+ m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
+ m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
+ m_SpecificInt(Mask33))))) {
+ Value *Root, *SubOp1;
+ // Matching "i - ((i >> 1) & 0x55555555...)".
+ if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
+ match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
+ m_SpecificInt(Mask55)))) {
+ LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
+ IRBuilder<> Builder(&I);
+ Function *Func = Intrinsic::getDeclaration(
+ I.getModule(), Intrinsic::ctpop, I.getType());
+ I.replaceAllUsesWith(Builder.CreateCall(Func, {Root}));
+ return true;
+ }
+ }
+ }
+ }
+
+ return false;
+}
+
/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
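Not part of the patch itself, but as a quick standalone illustration of the arithmetic tryToRecognizePopCount looks for: the shift amount it matches is (width - 8), i.e. 24 for 32-bit values, after the multiply by the 0x01 splat. A small, self-contained C++ check (written here purely for illustration) comparing the bit-hack against a naive reference count:

    #include <cassert>
    #include <cstdint>

    // The "best" bithack popcount for 32-bit values, exactly the shape the
    // matcher looks for: sub, and/add pairs, then a multiply by 0x01010101
    // followed by a shift of (width - 8) bits.
    static unsigned bithackPopcount(uint32_t i) {
      i = i - ((i >> 1) & 0x55555555u);
      i = (i & 0x33333333u) + ((i >> 2) & 0x33333333u);
      i = (i + (i >> 4)) & 0x0F0F0F0Fu;
      return (i * 0x01010101u) >> 24;
    }

    // Naive reference: count one bit at a time.
    static unsigned referencePopcount(uint32_t i) {
      unsigned N = 0;
      for (; i; i >>= 1)
        N += i & 1;
      return N;
    }

    int main() {
      for (uint32_t V : {0u, 1u, 0xFFu, 0xDEADBEEFu, 0xFFFFFFFFu})
        assert(bithackPopcount(V) == referencePopcount(V));
      return 0;
    }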
@@ -269,6 +334,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
MadeChange |= foldAnyOrAllBitsSet(I);
MadeChange |= foldGuardedRotateToFunnelShift(I);
+ MadeChange |= tryToRecognizePopCount(I);
}
}
@@ -303,7 +369,7 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
}
bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
return runImpl(F, TLI, DT);
}
diff --git a/lib/Transforms/Coroutines/CoroCleanup.cpp b/lib/Transforms/Coroutines/CoroCleanup.cpp
index 1fb0a114d0c7..c3e05577f044 100644
--- a/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -73,6 +73,8 @@ bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) {
II->replaceAllUsesWith(ConstantInt::getTrue(Context));
break;
case Intrinsic::coro_id:
+ case Intrinsic::coro_id_retcon:
+ case Intrinsic::coro_id_retcon_once:
II->replaceAllUsesWith(ConstantTokenNone::get(Context));
break;
case Intrinsic::coro_subfn_addr:
@@ -111,8 +113,9 @@ struct CoroCleanup : FunctionPass {
bool doInitialization(Module &M) override {
if (coro::declaresIntrinsics(M, {"llvm.coro.alloc", "llvm.coro.begin",
"llvm.coro.subfn.addr", "llvm.coro.free",
- "llvm.coro.id"}))
- L = llvm::make_unique<Lowerer>(M);
+ "llvm.coro.id", "llvm.coro.id.retcon",
+ "llvm.coro.id.retcon.once"}))
+ L = std::make_unique<Lowerer>(M);
return false;
}
diff --git a/lib/Transforms/Coroutines/CoroEarly.cpp b/lib/Transforms/Coroutines/CoroEarly.cpp
index 692697d6f32e..55993d33ee4e 100644
--- a/lib/Transforms/Coroutines/CoroEarly.cpp
+++ b/lib/Transforms/Coroutines/CoroEarly.cpp
@@ -91,13 +91,14 @@ void Lowerer::lowerCoroDone(IntrinsicInst *II) {
Value *Operand = II->getArgOperand(0);
// ResumeFnAddr is the first pointer sized element of the coroutine frame.
+ static_assert(coro::Shape::SwitchFieldIndex::Resume == 0,
+ "resume function not at offset zero");
auto *FrameTy = Int8Ptr;
PointerType *FramePtrTy = FrameTy->getPointerTo();
Builder.SetInsertPoint(II);
auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
- auto *Gep = Builder.CreateConstInBoundsGEP1_32(FrameTy, BCI, 0);
- auto *Load = Builder.CreateLoad(FrameTy, Gep);
+ auto *Load = Builder.CreateLoad(BCI);
auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
II->replaceAllUsesWith(Cond);
@@ -189,6 +190,10 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) {
}
}
break;
+ case Intrinsic::coro_id_retcon:
+ case Intrinsic::coro_id_retcon_once:
+ F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
+ break;
case Intrinsic::coro_resume:
lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex);
break;
@@ -231,11 +236,18 @@ struct CoroEarly : public FunctionPass {
// This pass has work to do only if we find intrinsics we are going to lower
// in the module.
bool doInitialization(Module &M) override {
- if (coro::declaresIntrinsics(
- M, {"llvm.coro.id", "llvm.coro.destroy", "llvm.coro.done",
- "llvm.coro.end", "llvm.coro.noop", "llvm.coro.free",
- "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.suspend"}))
- L = llvm::make_unique<Lowerer>(M);
+ if (coro::declaresIntrinsics(M, {"llvm.coro.id",
+ "llvm.coro.id.retcon",
+ "llvm.coro.id.retcon.once",
+ "llvm.coro.destroy",
+ "llvm.coro.done",
+ "llvm.coro.end",
+ "llvm.coro.noop",
+ "llvm.coro.free",
+ "llvm.coro.promise",
+ "llvm.coro.resume",
+ "llvm.coro.suspend"}))
+ L = std::make_unique<Lowerer>(M);
return false;
}
diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp
index 6707aa1c827d..aca77119023b 100644
--- a/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/lib/Transforms/Coroutines/CoroElide.cpp
@@ -286,7 +286,7 @@ struct CoroElide : FunctionPass {
bool doInitialization(Module &M) override {
if (coro::declaresIntrinsics(M, {"llvm.coro.id"}))
- L = llvm::make_unique<Lowerer>(M);
+ L = std::make_unique<Lowerer>(M);
return false;
}
diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp
index 58bf22bee29b..2c42cf8a6d25 100644
--- a/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -18,6 +18,7 @@
#include "CoroInternal.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/Analysis/PtrUseVisitor.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/CFG.h"
@@ -28,6 +29,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/circular_raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/PromoteMemToReg.h"
using namespace llvm;
@@ -120,6 +122,15 @@ struct SuspendCrossingInfo {
return false;
BasicBlock *UseBB = I->getParent();
+
+ // As a special case, treat uses by an llvm.coro.suspend.retcon
+ // as if they were uses in the suspend's single predecessor: the
+ // uses conceptually occur before the suspend.
+ if (isa<CoroSuspendRetconInst>(I)) {
+ UseBB = UseBB->getSinglePredecessor();
+ assert(UseBB && "should have split coro.suspend into its own block");
+ }
+
return hasPathCrossingSuspendPoint(DefBB, UseBB);
}
@@ -128,7 +139,17 @@ struct SuspendCrossingInfo {
}
bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
- return isDefinitionAcrossSuspend(I.getParent(), U);
+ auto *DefBB = I.getParent();
+
+ // As a special case, treat values produced by an llvm.coro.suspend.*
+ // as if they were defined in the single successor: the uses
+ // conceptually occur after the suspend.
+ if (isa<AnyCoroSuspendInst>(I)) {
+ DefBB = DefBB->getSingleSuccessor();
+ assert(DefBB && "should have split coro.suspend into its own block");
+ }
+
+ return isDefinitionAcrossSuspend(DefBB, U);
}
};
} // end anonymous namespace
@@ -183,9 +204,10 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
B.Suspend = true;
B.Kills |= B.Consumes;
};
- for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
+ for (auto *CSI : Shape.CoroSuspends) {
markSuspendBlock(CSI);
- markSuspendBlock(CSI->getCoroSave());
+ if (auto *Save = CSI->getCoroSave())
+ markSuspendBlock(Save);
}
// Iterate propagating consumes and kills until they stop changing.
@@ -261,11 +283,13 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
// We build up the list of spills for every case where a use is separated
// from the definition by a suspend point.
+static const unsigned InvalidFieldIndex = ~0U;
+
namespace {
class Spill {
Value *Def = nullptr;
Instruction *User = nullptr;
- unsigned FieldNo = 0;
+ unsigned FieldNo = InvalidFieldIndex;
public:
Spill(Value *Def, llvm::User *U) : Def(Def), User(cast<Instruction>(U)) {}
@@ -280,11 +304,11 @@ public:
// the definition the first time they encounter it. Consider refactoring
// SpillInfo into two arrays to normalize the spill representation.
unsigned fieldIndex() const {
- assert(FieldNo && "Accessing unassigned field");
+ assert(FieldNo != InvalidFieldIndex && "Accessing unassigned field");
return FieldNo;
}
void setFieldIndex(unsigned FieldNumber) {
- assert(!FieldNo && "Reassigning field number");
+ assert(FieldNo == InvalidFieldIndex && "Reassigning field number");
FieldNo = FieldNumber;
}
};
@@ -376,18 +400,30 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape,
SmallString<32> Name(F.getName());
Name.append(".Frame");
StructType *FrameTy = StructType::create(C, Name);
- auto *FramePtrTy = FrameTy->getPointerTo();
- auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
- /*isVarArg=*/false);
- auto *FnPtrTy = FnTy->getPointerTo();
-
- // Figure out how wide should be an integer type storing the suspend index.
- unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size()));
- Type *PromiseType = Shape.PromiseAlloca
- ? Shape.PromiseAlloca->getType()->getElementType()
- : Type::getInt1Ty(C);
- SmallVector<Type *, 8> Types{FnPtrTy, FnPtrTy, PromiseType,
- Type::getIntNTy(C, IndexBits)};
+ SmallVector<Type *, 8> Types;
+
+ AllocaInst *PromiseAlloca = Shape.getPromiseAlloca();
+
+ if (Shape.ABI == coro::ABI::Switch) {
+ auto *FramePtrTy = FrameTy->getPointerTo();
+ auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
+ /*IsVarArg=*/false);
+ auto *FnPtrTy = FnTy->getPointerTo();
+
+ // Figure out how wide should be an integer type storing the suspend index.
+ unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size()));
+ Type *PromiseType = PromiseAlloca
+ ? PromiseAlloca->getType()->getElementType()
+ : Type::getInt1Ty(C);
+ Type *IndexType = Type::getIntNTy(C, IndexBits);
+ Types.push_back(FnPtrTy);
+ Types.push_back(FnPtrTy);
+ Types.push_back(PromiseType);
+ Types.push_back(IndexType);
+ } else {
+ assert(PromiseAlloca == nullptr && "lowering doesn't support promises");
+ }
+
Value *CurrentDef = nullptr;
Padder.addTypes(Types);
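For context on the switch lowering above: the fields pushed into Types (two function pointers, the promise, and the suspend index) form a fixed frame header, and spilled values are appended after it (see the SwitchFieldIndex enum added to CoroInternal.h later in this patch). A rough C++ sketch of that conceptual layout; the type names are illustrative only and are not what the pass actually emits:

    #include <cstdint>

    // Hypothetical promise type chosen for illustration; the real promise
    // type comes from the promise alloca, and the index type is an integer
    // just wide enough to number the suspend points.
    struct ExamplePromise { int State; };

    struct CoroFrameSketch {
      void (*ResumeFn)(void *Frame);   // SwitchFieldIndex::Resume
      void (*DestroyFn)(void *Frame);  // SwitchFieldIndex::Destroy
      ExamplePromise Promise;          // SwitchFieldIndex::Promise
      uint8_t SuspendIndex;            // SwitchFieldIndex::Index
      // Spilled values start at SwitchFieldIndex::FirstSpill and follow here.
    };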
@@ -399,7 +435,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape,
CurrentDef = S.def();
// PromiseAlloca was already added to Types array earlier.
- if (CurrentDef == Shape.PromiseAlloca)
+ if (CurrentDef == PromiseAlloca)
continue;
uint64_t Count = 1;
@@ -430,9 +466,80 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape,
}
FrameTy->setBody(Types);
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ break;
+
+ // Remember whether the frame is inline in the storage.
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ auto &Layout = F.getParent()->getDataLayout();
+ auto Id = Shape.getRetconCoroId();
+ Shape.RetconLowering.IsFrameInlineInStorage
+ = (Layout.getTypeAllocSize(FrameTy) <= Id->getStorageSize() &&
+ Layout.getABITypeAlignment(FrameTy) <= Id->getStorageAlignment());
+ break;
+ }
+ }
+
return FrameTy;
}
+// We use a pointer use visitor to discover if there are any writes into an
+// alloca that dominate CoroBegin. If that is the case, insertSpills will copy
+// the value from the alloca into the coroutine frame spill slot corresponding
+// to that alloca.
+namespace {
+struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> {
+ using Base = PtrUseVisitor<AllocaUseVisitor>;
+ AllocaUseVisitor(const DataLayout &DL, const DominatorTree &DT,
+ const CoroBeginInst &CB)
+ : PtrUseVisitor(DL), DT(DT), CoroBegin(CB) {}
+
+ // We are only interested in uses that dominate coro.begin.
+ void visit(Instruction &I) {
+ if (DT.dominates(&I, &CoroBegin))
+ Base::visit(I);
+ }
+ // We need to provide this overload as PtrUseVisitor uses a pointer based
+ // visiting function.
+ void visit(Instruction *I) { return visit(*I); }
+
+ void visitLoadInst(LoadInst &) {} // Good. Nothing to do.
+
+ // If the use is an operand, the pointer escaped and anything can write into
+ // that memory. If the use is the pointer, we are definitely writing into the
+ // alloca and therefore we need to copy.
+ void visitStoreInst(StoreInst &SI) { PI.setAborted(&SI); }
+
+ // Any other instruction that is not filtered out by PtrUseVisitor will
+ // result in the copy.
+ void visitInstruction(Instruction &I) { PI.setAborted(&I); }
+
+private:
+ const DominatorTree &DT;
+ const CoroBeginInst &CoroBegin;
+};
+} // namespace
+static bool mightWriteIntoAllocaPtr(AllocaInst &A, const DominatorTree &DT,
+ const CoroBeginInst &CB) {
+ const DataLayout &DL = A.getModule()->getDataLayout();
+ AllocaUseVisitor Visitor(DL, DT, CB);
+ auto PtrI = Visitor.visitPtr(A);
+ if (PtrI.isEscaped() || PtrI.isAborted()) {
+ auto *PointerEscapingInstr = PtrI.getEscapingInst()
+ ? PtrI.getEscapingInst()
+ : PtrI.getAbortingInst();
+ if (PointerEscapingInstr) {
+ LLVM_DEBUG(
+ dbgs() << "AllocaInst copy was triggered by instruction: "
+ << *PointerEscapingInstr << "\n");
+ }
+ return true;
+ }
+ return false;
+}
+
// We need to make room to insert a spill after initial PHIs, but before
// catchswitch instruction. Placing it before violates the requirement that
// catchswitch, like all other EHPads must be the first nonPHI in a block.
@@ -476,7 +583,7 @@ static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) {
// whatever
//
//
-static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
+static Instruction *insertSpills(const SpillInfo &Spills, coro::Shape &Shape) {
auto *CB = Shape.CoroBegin;
LLVMContext &C = CB->getContext();
IRBuilder<> Builder(CB->getNextNode());
@@ -484,11 +591,14 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
PointerType *FramePtrTy = FrameTy->getPointerTo();
auto *FramePtr =
cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr"));
+ DominatorTree DT(*CB->getFunction());
Value *CurrentValue = nullptr;
BasicBlock *CurrentBlock = nullptr;
Value *CurrentReload = nullptr;
- unsigned Index = 0; // Proper field number will be read from field definition.
+
+ // Proper field number will be read from field definition.
+ unsigned Index = InvalidFieldIndex;
// We need to keep track of any allocas that need "spilling"
// since they will live in the coroutine frame now, all access to them
@@ -496,9 +606,11 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
// we remember allocas and their indices to be handled once we processed
// all the spills.
SmallVector<std::pair<AllocaInst *, unsigned>, 4> Allocas;
- // Promise alloca (if present) has a fixed field number (Shape::PromiseField)
- if (Shape.PromiseAlloca)
- Allocas.emplace_back(Shape.PromiseAlloca, coro::Shape::PromiseField);
+ // Promise alloca (if present) has a fixed field number.
+ if (auto *PromiseAlloca = Shape.getPromiseAlloca()) {
+ assert(Shape.ABI == coro::ABI::Switch);
+ Allocas.emplace_back(PromiseAlloca, coro::Shape::SwitchFieldIndex::Promise);
+ }
// Create a GEP with the given index into the coroutine frame for the original
// value Orig. Appends an extra 0 index for array-allocas, preserving the
@@ -526,7 +638,7 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
// Create a load instruction to reload the spilled value from the coroutine
// frame.
auto CreateReload = [&](Instruction *InsertBefore) {
- assert(Index && "accessing unassigned field number");
+ assert(Index != InvalidFieldIndex && "accessing unassigned field number");
Builder.SetInsertPoint(InsertBefore);
auto *G = GetFramePointer(Index, CurrentValue);
@@ -558,29 +670,45 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
// coroutine frame.
Instruction *InsertPt = nullptr;
- if (isa<Argument>(CurrentValue)) {
+ if (auto Arg = dyn_cast<Argument>(CurrentValue)) {
// For arguments, we will place the store instruction right after
// the coroutine frame pointer instruction, i.e. bitcast of
// coro.begin from i8* to %f.frame*.
InsertPt = FramePtr->getNextNode();
+
+ // If we're spilling an Argument, make sure we clear 'nocapture'
+ // from the coroutine function.
+ Arg->getParent()->removeParamAttr(Arg->getArgNo(),
+ Attribute::NoCapture);
+
} else if (auto *II = dyn_cast<InvokeInst>(CurrentValue)) {
// If we are spilling the result of the invoke instruction, split the
// normal edge and insert the spill in the new block.
auto NewBB = SplitEdge(II->getParent(), II->getNormalDest());
InsertPt = NewBB->getTerminator();
- } else if (dyn_cast<PHINode>(CurrentValue)) {
+ } else if (isa<PHINode>(CurrentValue)) {
// Skip the PHINodes and EH pads instructions.
BasicBlock *DefBlock = cast<Instruction>(E.def())->getParent();
if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator()))
InsertPt = splitBeforeCatchSwitch(CSI);
else
InsertPt = &*DefBlock->getFirstInsertionPt();
+ } else if (auto CSI = dyn_cast<AnyCoroSuspendInst>(CurrentValue)) {
+ // Don't spill immediately after a suspend; splitting assumes
+ // that the suspend will be followed by a branch.
+ InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHI();
} else {
+ auto *I = cast<Instruction>(E.def());
+ assert(!I->isTerminator() && "unexpected terminator");
// For all other values, the spill is placed immediately after
// the definition.
- assert(!cast<Instruction>(E.def())->isTerminator() &&
- "unexpected terminator");
- InsertPt = cast<Instruction>(E.def())->getNextNode();
+ if (DT.dominates(CB, I)) {
+ InsertPt = I->getNextNode();
+ } else {
+ // Otherwise, if it is not dominated by CoroBegin, it will be
+ // inserted immediately after CoroFrame is computed.
+ InsertPt = FramePtr->getNextNode();
+ }
}
Builder.SetInsertPoint(InsertPt);
@@ -613,21 +741,53 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
}
BasicBlock *FramePtrBB = FramePtr->getParent();
- Shape.AllocaSpillBlock =
- FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB");
- Shape.AllocaSpillBlock->splitBasicBlock(&Shape.AllocaSpillBlock->front(),
- "PostSpill");
- Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front());
+ auto SpillBlock =
+ FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB");
+ SpillBlock->splitBasicBlock(&SpillBlock->front(), "PostSpill");
+ Shape.AllocaSpillBlock = SpillBlock;
// If we found any allocas, replace all of their remaining uses with Geps.
+ // Note: we cannot do it indiscriminately as some of the uses may not be
+ // dominated by CoroBegin.
+ bool MightNeedToCopy = false;
+ Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front());
+ SmallVector<Instruction *, 4> UsersToUpdate;
for (auto &P : Allocas) {
- auto *G = GetFramePointer(P.second, P.first);
+ AllocaInst *const A = P.first;
+ UsersToUpdate.clear();
+ for (User *U : A->users()) {
+ auto *I = cast<Instruction>(U);
+ if (DT.dominates(CB, I))
+ UsersToUpdate.push_back(I);
+ else
+ MightNeedToCopy = true;
+ }
+ if (!UsersToUpdate.empty()) {
+ auto *G = GetFramePointer(P.second, A);
+ G->takeName(A);
+ for (Instruction *I : UsersToUpdate)
+ I->replaceUsesOfWith(A, G);
+ }
+ }
+ // If we discovered such uses not dominated by CoroBegin, see if any of them
+ // precede coro begin and have instructions that can modify the
+ // value of the alloca and therefore would require copying the value into
+ // the spill slot in the coroutine frame.
+ if (MightNeedToCopy) {
+ Builder.SetInsertPoint(FramePtr->getNextNode());
+
+ for (auto &P : Allocas) {
+ AllocaInst *const A = P.first;
+ if (mightWriteIntoAllocaPtr(*A, DT, *CB)) {
+ if (A->isArrayAllocation())
+ report_fatal_error(
+ "Coroutines cannot handle copying of array allocas yet");
- // We are not using ReplaceInstWithInst(P.first, cast<Instruction>(G)) here,
- // as we are changing location of the instruction.
- G->takeName(P.first);
- P.first->replaceAllUsesWith(G);
- P.first->eraseFromParent();
+ auto *G = GetFramePointer(P.second, A);
+ auto *Value = Builder.CreateLoad(A);
+ Builder.CreateStore(Value, G);
+ }
+ }
}
return FramePtr;
}
@@ -829,52 +989,6 @@ static void rewriteMaterializableInstructions(IRBuilder<> &IRB,
}
}
-// Move early uses of spilled variable after CoroBegin.
-// For example, if a parameter had address taken, we may end up with the code
-// like:
-// define @f(i32 %n) {
-// %n.addr = alloca i32
-// store %n, %n.addr
-// ...
-// call @coro.begin
-// we need to move the store after coro.begin
-static void moveSpillUsesAfterCoroBegin(Function &F, SpillInfo const &Spills,
- CoroBeginInst *CoroBegin) {
- DominatorTree DT(F);
- SmallVector<Instruction *, 8> NeedsMoving;
-
- Value *CurrentValue = nullptr;
-
- for (auto const &E : Spills) {
- if (CurrentValue == E.def())
- continue;
-
- CurrentValue = E.def();
-
- for (User *U : CurrentValue->users()) {
- Instruction *I = cast<Instruction>(U);
- if (!DT.dominates(CoroBegin, I)) {
- LLVM_DEBUG(dbgs() << "will move: " << *I << "\n");
-
- // TODO: Make this more robust. Currently if we run into a situation
- // where simple instruction move won't work we panic and
- // report_fatal_error.
- for (User *UI : I->users()) {
- if (!DT.dominates(CoroBegin, cast<Instruction>(UI)))
- report_fatal_error("cannot move instruction since its users are not"
- " dominated by CoroBegin");
- }
-
- NeedsMoving.push_back(I);
- }
- }
- }
-
- Instruction *InsertPt = CoroBegin->getNextNode();
- for (Instruction *I : NeedsMoving)
- I->moveBefore(InsertPt);
-}
-
// Splits the block at a particular instruction unless it is the first
// instruction in the block with a single predecessor.
static BasicBlock *splitBlockIfNotFirst(Instruction *I, const Twine &Name) {
@@ -895,21 +1009,337 @@ static void splitAround(Instruction *I, const Twine &Name) {
splitBlockIfNotFirst(I->getNextNode(), "After" + Name);
}
+static bool isSuspendBlock(BasicBlock *BB) {
+ return isa<AnyCoroSuspendInst>(BB->front());
+}
+
+typedef SmallPtrSet<BasicBlock*, 8> VisitedBlocksSet;
+
+/// Does control flow starting at the given block ever reach a suspend
+/// instruction before reaching a block in VisitedOrFreeBBs?
+static bool isSuspendReachableFrom(BasicBlock *From,
+ VisitedBlocksSet &VisitedOrFreeBBs) {
+ // Eagerly try to add this block to the visited set. If it's already
+ // there, stop recursing; this path doesn't reach a suspend before
+ // either looping or reaching a freeing block.
+ if (!VisitedOrFreeBBs.insert(From).second)
+ return false;
+
+ // We assume that we'll already have split suspends into their own blocks.
+ if (isSuspendBlock(From))
+ return true;
+
+ // Recurse on the successors.
+ for (auto Succ : successors(From)) {
+ if (isSuspendReachableFrom(Succ, VisitedOrFreeBBs))
+ return true;
+ }
+
+ return false;
+}
+
+/// Is the given alloca "local", i.e. bounded in lifetime to not cross a
+/// suspend point?
+static bool isLocalAlloca(CoroAllocaAllocInst *AI) {
+ // Seed the visited set with all the basic blocks containing a free
+ // so that we won't pass them up.
+ VisitedBlocksSet VisitedOrFreeBBs;
+ for (auto User : AI->users()) {
+ if (auto FI = dyn_cast<CoroAllocaFreeInst>(User))
+ VisitedOrFreeBBs.insert(FI->getParent());
+ }
+
+ return !isSuspendReachableFrom(AI->getParent(), VisitedOrFreeBBs);
+}
+
+/// After we split the coroutine, will the given basic block be along
+/// an obvious exit path for the resumption function?
+static bool willLeaveFunctionImmediatelyAfter(BasicBlock *BB,
+ unsigned depth = 3) {
+ // If we've bottomed out our depth count, stop searching and assume
+ // that the path might loop back.
+ if (depth == 0) return false;
+
+ // If this is a suspend block, we're about to exit the resumption function.
+ if (isSuspendBlock(BB)) return true;
+
+ // Recurse into the successors.
+ for (auto Succ : successors(BB)) {
+ if (!willLeaveFunctionImmediatelyAfter(Succ, depth - 1))
+ return false;
+ }
+
+ // If none of the successors leads back in a loop, we're on an exit/abort.
+ return true;
+}
+
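willLeaveFunctionImmediatelyAfter above is a depth-bounded search: a block qualifies only if it is a suspend block, has no successors at all, or every successor qualifies within the remaining depth. A standalone C++ sketch of the same idea on a plain adjacency list; the graph encoding and names are invented for illustration:

    #include <vector>

    // Successors[B] lists the successor indices of block B; IsSuspend[B]
    // marks blocks that begin with a suspend. Both stand in for the CFG
    // queries used by the pass.
    static bool leavesImmediately(unsigned BB,
                                  const std::vector<std::vector<unsigned>> &Successors,
                                  const std::vector<bool> &IsSuspend,
                                  unsigned Depth = 3) {
      if (Depth == 0)
        return false;            // Out of budget: conservatively assume a loop back.
      if (IsSuspend[BB])
        return true;             // A suspend exits the resumption function.
      for (unsigned Succ : Successors[BB])
        if (!leavesImmediately(Succ, Successors, IsSuspend, Depth - 1))
          return false;
      return true;               // Every path (or no successors at all) exits.
    }

    int main() {
      // Block 0 branches to blocks 1 and 2; block 1 suspends; block 2 returns.
      std::vector<std::vector<unsigned>> Succ = {{1, 2}, {}, {}};
      std::vector<bool> IsSuspend = {false, true, false};
      return leavesImmediately(0, Succ, IsSuspend) ? 0 : 1;
    }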
+static bool localAllocaNeedsStackSave(CoroAllocaAllocInst *AI) {
+ // Look for a free that isn't sufficiently obviously followed by
+ // either a suspend or a termination, i.e. something that will leave
+ // the coro resumption frame.
+ for (auto U : AI->users()) {
+ auto FI = dyn_cast<CoroAllocaFreeInst>(U);
+ if (!FI) continue;
+
+ if (!willLeaveFunctionImmediatelyAfter(FI->getParent()))
+ return true;
+ }
+
+ // If we never found one, we don't need a stack save.
+ return false;
+}
+
+/// Turn each of the given local allocas into a normal (dynamic) alloca
+/// instruction.
+static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas,
+ SmallVectorImpl<Instruction*> &DeadInsts) {
+ for (auto AI : LocalAllocas) {
+ auto M = AI->getModule();
+ IRBuilder<> Builder(AI);
+
+ // Save the stack depth. Try to avoid doing this if the stackrestore
+ // is going to immediately precede a return or something.
+ Value *StackSave = nullptr;
+ if (localAllocaNeedsStackSave(AI))
+ StackSave = Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::stacksave));
+
+ // Allocate memory.
+ auto Alloca = Builder.CreateAlloca(Builder.getInt8Ty(), AI->getSize());
+ Alloca->setAlignment(MaybeAlign(AI->getAlignment()));
+
+ for (auto U : AI->users()) {
+ // Replace gets with the allocation.
+ if (isa<CoroAllocaGetInst>(U)) {
+ U->replaceAllUsesWith(Alloca);
+
+ // Replace frees with stackrestores. This is safe because
+ // alloca.alloc is required to obey a stack discipline, although we
+ // don't enforce that structurally.
+ } else {
+ auto FI = cast<CoroAllocaFreeInst>(U);
+ if (StackSave) {
+ Builder.SetInsertPoint(FI);
+ Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::stackrestore),
+ StackSave);
+ }
+ }
+ DeadInsts.push_back(cast<Instruction>(U));
+ }
+
+ DeadInsts.push_back(AI);
+ }
+}
+
+/// Turn the given coro.alloca.alloc call into a dynamic allocation.
+/// This happens during the all-instructions iteration, so it must not
+/// delete the call.
+static Instruction *lowerNonLocalAlloca(CoroAllocaAllocInst *AI,
+ coro::Shape &Shape,
+ SmallVectorImpl<Instruction*> &DeadInsts) {
+ IRBuilder<> Builder(AI);
+ auto Alloc = Shape.emitAlloc(Builder, AI->getSize(), nullptr);
+
+ for (User *U : AI->users()) {
+ if (isa<CoroAllocaGetInst>(U)) {
+ U->replaceAllUsesWith(Alloc);
+ } else {
+ auto FI = cast<CoroAllocaFreeInst>(U);
+ Builder.SetInsertPoint(FI);
+ Shape.emitDealloc(Builder, Alloc, nullptr);
+ }
+ DeadInsts.push_back(cast<Instruction>(U));
+ }
+
+ // Push this on last so that it gets deleted after all the others.
+ DeadInsts.push_back(AI);
+
+ // Return the new allocation value so that we can check for needed spills.
+ return cast<Instruction>(Alloc);
+}
+
+/// Get the current swifterror value.
+static Value *emitGetSwiftErrorValue(IRBuilder<> &Builder, Type *ValueTy,
+ coro::Shape &Shape) {
+ // Make a fake function pointer as a sort of intrinsic.
+ auto FnTy = FunctionType::get(ValueTy, {}, false);
+ auto Fn = ConstantPointerNull::get(FnTy->getPointerTo());
+
+ auto Call = Builder.CreateCall(Fn, {});
+ Shape.SwiftErrorOps.push_back(Call);
+
+ return Call;
+}
+
+/// Set the given value as the current swifterror value.
+///
+/// Returns a slot that can be used as a swifterror slot.
+static Value *emitSetSwiftErrorValue(IRBuilder<> &Builder, Value *V,
+ coro::Shape &Shape) {
+ // Make a fake function pointer as a sort of intrinsic.
+ auto FnTy = FunctionType::get(V->getType()->getPointerTo(),
+ {V->getType()}, false);
+ auto Fn = ConstantPointerNull::get(FnTy->getPointerTo());
+
+ auto Call = Builder.CreateCall(Fn, { V });
+ Shape.SwiftErrorOps.push_back(Call);
+
+ return Call;
+}
+
+/// Set the swifterror value from the given alloca before a call,
+/// then put it back in the alloca afterwards.
+///
+/// Returns an address that will stand in for the swifterror slot
+/// until splitting.
+static Value *emitSetAndGetSwiftErrorValueAround(Instruction *Call,
+ AllocaInst *Alloca,
+ coro::Shape &Shape) {
+ auto ValueTy = Alloca->getAllocatedType();
+ IRBuilder<> Builder(Call);
+
+ // Load the current value from the alloca and set it as the
+ // swifterror value.
+ auto ValueBeforeCall = Builder.CreateLoad(ValueTy, Alloca);
+ auto Addr = emitSetSwiftErrorValue(Builder, ValueBeforeCall, Shape);
+
+ // Move to after the call. Since swifterror only has a guaranteed
+ // value on normal exits, we can ignore implicit and explicit unwind
+ // edges.
+ if (isa<CallInst>(Call)) {
+ Builder.SetInsertPoint(Call->getNextNode());
+ } else {
+ auto Invoke = cast<InvokeInst>(Call);
+ Builder.SetInsertPoint(Invoke->getNormalDest()->getFirstNonPHIOrDbg());
+ }
+
+ // Get the current swifterror value and store it to the alloca.
+ auto ValueAfterCall = emitGetSwiftErrorValue(Builder, ValueTy, Shape);
+ Builder.CreateStore(ValueAfterCall, Alloca);
+
+ return Addr;
+}
+
+/// Eliminate a formerly-swifterror alloca by inserting the get/set
+/// intrinsics and attempting to MemToReg the alloca away.
+static void eliminateSwiftErrorAlloca(Function &F, AllocaInst *Alloca,
+ coro::Shape &Shape) {
+ for (auto UI = Alloca->use_begin(), UE = Alloca->use_end(); UI != UE; ) {
+ // We're likely changing the use list, so use a mutation-safe
+ // iteration pattern.
+ auto &Use = *UI;
+ ++UI;
+
+ // swifterror values can only be used in very specific ways.
+ // We take advantage of that here.
+ auto User = Use.getUser();
+ if (isa<LoadInst>(User) || isa<StoreInst>(User))
+ continue;
+
+ assert(isa<CallInst>(User) || isa<InvokeInst>(User));
+ auto Call = cast<Instruction>(User);
+
+ auto Addr = emitSetAndGetSwiftErrorValueAround(Call, Alloca, Shape);
+
+ // Use the returned slot address as the call argument.
+ Use.set(Addr);
+ }
+
+ // All the uses should be loads and stores now.
+ assert(isAllocaPromotable(Alloca));
+}
+
+/// "Eliminate" a swifterror argument by reducing it to the alloca case
+/// and then loading and storing in the prologue and epilog.
+///
+/// The argument keeps the swifterror flag.
+static void eliminateSwiftErrorArgument(Function &F, Argument &Arg,
+ coro::Shape &Shape,
+ SmallVectorImpl<AllocaInst*> &AllocasToPromote) {
+ IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
+
+ auto ArgTy = cast<PointerType>(Arg.getType());
+ auto ValueTy = ArgTy->getElementType();
+
+ // Reduce to the alloca case:
+
+ // Create an alloca and replace all uses of the arg with it.
+ auto Alloca = Builder.CreateAlloca(ValueTy, ArgTy->getAddressSpace());
+ Arg.replaceAllUsesWith(Alloca);
+
+ // Set an initial value in the alloca. swifterror is always null on entry.
+ auto InitialValue = Constant::getNullValue(ValueTy);
+ Builder.CreateStore(InitialValue, Alloca);
+
+ // Find all the suspends in the function and save and restore around them.
+ for (auto Suspend : Shape.CoroSuspends) {
+ (void) emitSetAndGetSwiftErrorValueAround(Suspend, Alloca, Shape);
+ }
+
+ // Find all the coro.ends in the function and restore the error value.
+ for (auto End : Shape.CoroEnds) {
+ Builder.SetInsertPoint(End);
+ auto FinalValue = Builder.CreateLoad(ValueTy, Alloca);
+ (void) emitSetSwiftErrorValue(Builder, FinalValue, Shape);
+ }
+
+ // Now we can use the alloca logic.
+ AllocasToPromote.push_back(Alloca);
+ eliminateSwiftErrorAlloca(F, Alloca, Shape);
+}
+
+/// Eliminate all problematic uses of swifterror arguments and allocas
+/// from the function. We'll fix them up later when splitting the function.
+static void eliminateSwiftError(Function &F, coro::Shape &Shape) {
+ SmallVector<AllocaInst*, 4> AllocasToPromote;
+
+ // Look for a swifterror argument.
+ for (auto &Arg : F.args()) {
+ if (!Arg.hasSwiftErrorAttr()) continue;
+
+ eliminateSwiftErrorArgument(F, Arg, Shape, AllocasToPromote);
+ break;
+ }
+
+ // Look for swifterror allocas.
+ for (auto &Inst : F.getEntryBlock()) {
+ auto Alloca = dyn_cast<AllocaInst>(&Inst);
+ if (!Alloca || !Alloca->isSwiftError()) continue;
+
+ // Clear the swifterror flag.
+ Alloca->setSwiftError(false);
+
+ AllocasToPromote.push_back(Alloca);
+ eliminateSwiftErrorAlloca(F, Alloca, Shape);
+ }
+
+ // If we have any allocas to promote, compute a dominator tree and
+ // promote them en masse.
+ if (!AllocasToPromote.empty()) {
+ DominatorTree DT(F);
+ PromoteMemToReg(AllocasToPromote, DT);
+ }
+}
+
void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
// Lower coro.dbg.declare to coro.dbg.value, since we are going to rewrite
// access to local variables.
LowerDbgDeclare(F);
- Shape.PromiseAlloca = Shape.CoroBegin->getId()->getPromise();
- if (Shape.PromiseAlloca) {
- Shape.CoroBegin->getId()->clearPromise();
+ eliminateSwiftError(F, Shape);
+
+ if (Shape.ABI == coro::ABI::Switch &&
+ Shape.SwitchLowering.PromiseAlloca) {
+ Shape.getSwitchCoroId()->clearPromise();
}
// Make sure that all coro.save, coro.suspend and the fallthrough coro.end
// intrinsics are in their own blocks to simplify the logic of building up
// SuspendCrossing data.
- for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
- splitAround(CSI->getCoroSave(), "CoroSave");
+ for (auto *CSI : Shape.CoroSuspends) {
+ if (auto *Save = CSI->getCoroSave())
+ splitAround(Save, "CoroSave");
splitAround(CSI, "CoroSuspend");
}
@@ -926,6 +1356,8 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
IRBuilder<> Builder(F.getContext());
SpillInfo Spills;
+ SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas;
+ SmallVector<Instruction*, 4> DeadInstructions;
for (int Repeat = 0; Repeat < 4; ++Repeat) {
// See if there are materializable instructions across suspend points.
@@ -955,11 +1387,40 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
// of the Coroutine Frame.
if (isCoroutineStructureIntrinsic(I) || &I == Shape.CoroBegin)
continue;
+
// The Coroutine Promise always included into coroutine frame, no need to
// check for suspend crossing.
- if (Shape.PromiseAlloca == &I)
+ if (Shape.ABI == coro::ABI::Switch &&
+ Shape.SwitchLowering.PromiseAlloca == &I)
continue;
+ // Handle alloca.alloc specially here.
+ if (auto AI = dyn_cast<CoroAllocaAllocInst>(&I)) {
+ // Check whether the alloca's lifetime is bounded by suspend points.
+ if (isLocalAlloca(AI)) {
+ LocalAllocas.push_back(AI);
+ continue;
+ }
+
+ // If not, do a quick rewrite of the alloca and then add spills of
+ // the rewritten value. The rewrite doesn't invalidate anything in
+ // Spills because the other alloca intrinsics have no other operands
+ // besides AI, and it doesn't invalidate the iteration because we delay
+ // erasing AI.
+ auto Alloc = lowerNonLocalAlloca(AI, Shape, DeadInstructions);
+
+ for (User *U : Alloc->users()) {
+ if (Checker.isDefinitionAcrossSuspend(*Alloc, U))
+ Spills.emplace_back(Alloc, U);
+ }
+ continue;
+ }
+
+ // Ignore alloca.get; we process this as part of coro.alloca.alloc.
+ if (isa<CoroAllocaGetInst>(I)) {
+ continue;
+ }
+
for (User *U : I.users())
if (Checker.isDefinitionAcrossSuspend(I, U)) {
// We cannot spill a token.
@@ -970,7 +1431,10 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
}
}
LLVM_DEBUG(dump("Spills", Spills));
- moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin);
Shape.FrameTy = buildFrameType(F, Shape, Spills);
Shape.FramePtr = insertSpills(Spills, Shape);
+ lowerLocalAllocas(LocalAllocas, DeadInstructions);
+
+ for (auto I : DeadInstructions)
+ I->eraseFromParent();
}
diff --git a/lib/Transforms/Coroutines/CoroInstr.h b/lib/Transforms/Coroutines/CoroInstr.h
index 5e19d7642e38..de2d2920cb15 100644
--- a/lib/Transforms/Coroutines/CoroInstr.h
+++ b/lib/Transforms/Coroutines/CoroInstr.h
@@ -27,6 +27,7 @@
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/raw_ostream.h"
namespace llvm {
@@ -77,10 +78,8 @@ public:
}
};
-/// This represents the llvm.coro.alloc instruction.
-class LLVM_LIBRARY_VISIBILITY CoroIdInst : public IntrinsicInst {
- enum { AlignArg, PromiseArg, CoroutineArg, InfoArg };
-
+/// This represents a common base class for llvm.coro.id instructions.
+class LLVM_LIBRARY_VISIBILITY AnyCoroIdInst : public IntrinsicInst {
public:
CoroAllocInst *getCoroAlloc() {
for (User *U : users())
@@ -97,6 +96,24 @@ public:
llvm_unreachable("no coro.begin associated with coro.id");
}
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ auto ID = I->getIntrinsicID();
+ return ID == Intrinsic::coro_id ||
+ ID == Intrinsic::coro_id_retcon ||
+ ID == Intrinsic::coro_id_retcon_once;
+ }
+
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.id instruction.
+class LLVM_LIBRARY_VISIBILITY CoroIdInst : public AnyCoroIdInst {
+ enum { AlignArg, PromiseArg, CoroutineArg, InfoArg };
+
+public:
AllocaInst *getPromise() const {
Value *Arg = getArgOperand(PromiseArg);
return isa<ConstantPointerNull>(Arg)
@@ -182,6 +199,80 @@ public:
}
};
+/// This represents either the llvm.coro.id.retcon or
+/// llvm.coro.id.retcon.once instruction.
+class LLVM_LIBRARY_VISIBILITY AnyCoroIdRetconInst : public AnyCoroIdInst {
+ enum { SizeArg, AlignArg, StorageArg, PrototypeArg, AllocArg, DeallocArg };
+
+public:
+ void checkWellFormed() const;
+
+ uint64_t getStorageSize() const {
+ return cast<ConstantInt>(getArgOperand(SizeArg))->getZExtValue();
+ }
+
+ uint64_t getStorageAlignment() const {
+ return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue();
+ }
+
+ Value *getStorage() const {
+ return getArgOperand(StorageArg);
+ }
+
+ /// Return the prototype for the continuation function. The type,
+ /// attributes, and calling convention of the continuation function(s)
+ /// are taken from this declaration.
+ Function *getPrototype() const {
+ return cast<Function>(getArgOperand(PrototypeArg)->stripPointerCasts());
+ }
+
+ /// Return the function to use for allocating memory.
+ Function *getAllocFunction() const {
+ return cast<Function>(getArgOperand(AllocArg)->stripPointerCasts());
+ }
+
+ /// Return the function to use for deallocating memory.
+ Function *getDeallocFunction() const {
+ return cast<Function>(getArgOperand(DeallocArg)->stripPointerCasts());
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ auto ID = I->getIntrinsicID();
+ return ID == Intrinsic::coro_id_retcon
+ || ID == Intrinsic::coro_id_retcon_once;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.id.retcon instruction.
+class LLVM_LIBRARY_VISIBILITY CoroIdRetconInst
+ : public AnyCoroIdRetconInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_id_retcon;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.id.retcon.once instruction.
+class LLVM_LIBRARY_VISIBILITY CoroIdRetconOnceInst
+ : public AnyCoroIdRetconInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_id_retcon_once;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
/// This represents the llvm.coro.frame instruction.
class LLVM_LIBRARY_VISIBILITY CoroFrameInst : public IntrinsicInst {
public:
@@ -215,7 +306,9 @@ class LLVM_LIBRARY_VISIBILITY CoroBeginInst : public IntrinsicInst {
enum { IdArg, MemArg };
public:
- CoroIdInst *getId() const { return cast<CoroIdInst>(getArgOperand(IdArg)); }
+ AnyCoroIdInst *getId() const {
+ return cast<AnyCoroIdInst>(getArgOperand(IdArg));
+ }
Value *getMem() const { return getArgOperand(MemArg); }
@@ -261,8 +354,22 @@ public:
}
};
+class LLVM_LIBRARY_VISIBILITY AnyCoroSuspendInst : public IntrinsicInst {
+public:
+ CoroSaveInst *getCoroSave() const;
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_suspend ||
+ I->getIntrinsicID() == Intrinsic::coro_suspend_retcon;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
/// This represents the llvm.coro.suspend instruction.
-class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public IntrinsicInst {
+class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public AnyCoroSuspendInst {
enum { SaveArg, FinalArg };
public:
@@ -273,6 +380,7 @@ public:
assert(isa<ConstantTokenNone>(Arg));
return nullptr;
}
+
bool isFinal() const {
return cast<Constant>(getArgOperand(FinalArg))->isOneValue();
}
@@ -286,6 +394,37 @@ public:
}
};
+inline CoroSaveInst *AnyCoroSuspendInst::getCoroSave() const {
+ if (auto Suspend = dyn_cast<CoroSuspendInst>(this))
+ return Suspend->getCoroSave();
+ return nullptr;
+}
+
+/// This represents the llvm.coro.suspend.retcon instruction.
+class LLVM_LIBRARY_VISIBILITY CoroSuspendRetconInst : public AnyCoroSuspendInst {
+public:
+ op_iterator value_begin() { return arg_begin(); }
+ const_op_iterator value_begin() const { return arg_begin(); }
+
+ op_iterator value_end() { return arg_end(); }
+ const_op_iterator value_end() const { return arg_end(); }
+
+ iterator_range<op_iterator> value_operands() {
+ return make_range(value_begin(), value_end());
+ }
+ iterator_range<const_op_iterator> value_operands() const {
+ return make_range(value_begin(), value_end());
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_suspend_retcon;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
/// This represents the llvm.coro.size instruction.
class LLVM_LIBRARY_VISIBILITY CoroSizeInst : public IntrinsicInst {
public:
@@ -317,6 +456,60 @@ public:
}
};
+/// This represents the llvm.coro.alloca.alloc instruction.
+class LLVM_LIBRARY_VISIBILITY CoroAllocaAllocInst : public IntrinsicInst {
+ enum { SizeArg, AlignArg };
+public:
+ Value *getSize() const {
+ return getArgOperand(SizeArg);
+ }
+ unsigned getAlignment() const {
+ return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue();
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_alloca_alloc;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.alloca.get instruction.
+class LLVM_LIBRARY_VISIBILITY CoroAllocaGetInst : public IntrinsicInst {
+ enum { AllocArg };
+public:
+ CoroAllocaAllocInst *getAlloc() const {
+ return cast<CoroAllocaAllocInst>(getArgOperand(AllocArg));
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_alloca_get;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.alloca.free instruction.
+class LLVM_LIBRARY_VISIBILITY CoroAllocaFreeInst : public IntrinsicInst {
+ enum { AllocArg };
+public:
+ CoroAllocaAllocInst *getAlloc() const {
+ return cast<CoroAllocaAllocInst>(getArgOperand(AllocArg));
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_alloca_free;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
} // End namespace llvm.
#endif
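The wrapper classes added in this header all follow LLVM's classof convention, which is what lets isa<>, cast<>, and dyn_cast<> narrow an arbitrary Value to one of these intrinsic wrappers. A minimal sketch of typical usage; the helper function below is hypothetical and exists only to show the pattern:

    #include "CoroInstr.h"
    #include "llvm/Support/Casting.h"

    using namespace llvm;

    // Hypothetical helper: if V is an llvm.coro.alloca.free call, return the
    // llvm.coro.alloca.alloc it frees, otherwise null. dyn_cast<> works here
    // because CoroAllocaFreeInst provides the classof overloads shown above.
    static CoroAllocaAllocInst *getFreedAllocation(Value *V) {
      if (auto *FI = dyn_cast<CoroAllocaFreeInst>(V))
        return FI->getAlloc();
      return nullptr;
    }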
diff --git a/lib/Transforms/Coroutines/CoroInternal.h b/lib/Transforms/Coroutines/CoroInternal.h
index 441c8a20f1f3..c151474316f9 100644
--- a/lib/Transforms/Coroutines/CoroInternal.h
+++ b/lib/Transforms/Coroutines/CoroInternal.h
@@ -12,6 +12,7 @@
#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
#include "CoroInstr.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Coroutines.h"
namespace llvm {
@@ -61,37 +62,174 @@ struct LowererBase {
Value *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt);
};
+enum class ABI {
+ /// The "resume-switch" lowering, where there are separate resume and
+ /// destroy functions that are shared between all suspend points. The
+ /// coroutine frame implicitly stores the resume and destroy functions,
+ /// the current index, and any promise value.
+ Switch,
+
+ /// The "returned-continuation" lowering, where each suspend point creates a
+ /// single continuation function that is used for both resuming and
+ /// destroying. Does not support promises.
+ Retcon,
+
+ /// The "unique returned-continuation" lowering, where each suspend point
+ /// creates a single continuation function that is used for both resuming
+ /// and destroying. Does not support promises. The function is known to
+ /// suspend at most once during its execution, and the return value of
+ /// the continuation is void.
+ RetconOnce,
+};
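As a rough, non-LLVM illustration of the returned-continuation shape described for Retcon/RetconOnce: a suspend hands back a continuation function pointer plus any yielded values, and the caller later invokes that continuation to resume or abort. Everything below is invented for the sketch and deliberately ignores the real ABI details (storage handling, parameter conventions); it only shows the control-flow shape.

    #include <cstdio>

    struct Suspended {
      // Continuation used for both resuming and destroying the coroutine;
      // in this sketch the second parameter selects which of the two.
      Suspended (*Continuation)(void *Storage, bool Resume);
      int YieldedValue; // Extra results ride along in the same return struct.
    };

    static Suspended finished(void *, bool) { return {nullptr, 0}; }

    static Suspended start(void *Storage, bool /*Resume*/) {
      // First "suspend point": hand back a value and the next continuation.
      (void)Storage;
      return {finished, 42};
    }

    int main() {
      char Storage[32];                  // Caller-provided frame storage.
      Suspended S = start(Storage, true);
      std::printf("yielded %d\n", S.YieldedValue);
      S.Continuation(Storage, false);    // Resume-or-destroy via the continuation.
      return 0;
    }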
+
// Holds structural Coroutine Intrinsics for a particular function and other
// values used during CoroSplit pass.
struct LLVM_LIBRARY_VISIBILITY Shape {
CoroBeginInst *CoroBegin;
SmallVector<CoroEndInst *, 4> CoroEnds;
SmallVector<CoroSizeInst *, 2> CoroSizes;
- SmallVector<CoroSuspendInst *, 4> CoroSuspends;
-
- // Field Indexes for known coroutine frame fields.
- enum {
- ResumeField,
- DestroyField,
- PromiseField,
- IndexField,
+ SmallVector<AnyCoroSuspendInst *, 4> CoroSuspends;
+ SmallVector<CallInst*, 2> SwiftErrorOps;
+
+ // Field indexes for special fields in the switch lowering.
+ struct SwitchFieldIndex {
+ enum {
+ Resume,
+ Destroy,
+ Promise,
+ Index,
+ /// The index of the first spill field.
+ FirstSpill
+ };
};
+ coro::ABI ABI;
+
StructType *FrameTy;
Instruction *FramePtr;
BasicBlock *AllocaSpillBlock;
- SwitchInst *ResumeSwitch;
- AllocaInst *PromiseAlloca;
- bool HasFinalSuspend;
+
+ struct SwitchLoweringStorage {
+ SwitchInst *ResumeSwitch;
+ AllocaInst *PromiseAlloca;
+ BasicBlock *ResumeEntryBlock;
+ bool HasFinalSuspend;
+ };
+
+ struct RetconLoweringStorage {
+ Function *ResumePrototype;
+ Function *Alloc;
+ Function *Dealloc;
+ BasicBlock *ReturnBlock;
+ bool IsFrameInlineInStorage;
+ };
+
+ union {
+ SwitchLoweringStorage SwitchLowering;
+ RetconLoweringStorage RetconLowering;
+ };
+
+ CoroIdInst *getSwitchCoroId() const {
+ assert(ABI == coro::ABI::Switch);
+ return cast<CoroIdInst>(CoroBegin->getId());
+ }
+
+ AnyCoroIdRetconInst *getRetconCoroId() const {
+ assert(ABI == coro::ABI::Retcon ||
+ ABI == coro::ABI::RetconOnce);
+ return cast<AnyCoroIdRetconInst>(CoroBegin->getId());
+ }
IntegerType *getIndexType() const {
+ assert(ABI == coro::ABI::Switch);
assert(FrameTy && "frame type not assigned");
- return cast<IntegerType>(FrameTy->getElementType(IndexField));
+ return cast<IntegerType>(FrameTy->getElementType(SwitchFieldIndex::Index));
}
ConstantInt *getIndex(uint64_t Value) const {
return ConstantInt::get(getIndexType(), Value);
}
+ PointerType *getSwitchResumePointerType() const {
+ assert(ABI == coro::ABI::Switch);
+ assert(FrameTy && "frame type not assigned");
+ return cast<PointerType>(FrameTy->getElementType(SwitchFieldIndex::Resume));
+ }
+
+ FunctionType *getResumeFunctionType() const {
+ switch (ABI) {
+ case coro::ABI::Switch: {
+ auto *FnPtrTy = getSwitchResumePointerType();
+ return cast<FunctionType>(FnPtrTy->getPointerElementType());
+ }
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return RetconLowering.ResumePrototype->getFunctionType();
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+ }
+
+ ArrayRef<Type*> getRetconResultTypes() const {
+ assert(ABI == coro::ABI::Retcon ||
+ ABI == coro::ABI::RetconOnce);
+ auto FTy = CoroBegin->getFunction()->getFunctionType();
+
+ // The safety of all this is checked by checkWFRetconPrototype.
+ if (auto STy = dyn_cast<StructType>(FTy->getReturnType())) {
+ return STy->elements().slice(1);
+ } else {
+ return ArrayRef<Type*>();
+ }
+ }
+
+ ArrayRef<Type*> getRetconResumeTypes() const {
+ assert(ABI == coro::ABI::Retcon ||
+ ABI == coro::ABI::RetconOnce);
+
+ // The safety of all this is checked by checkWFRetconPrototype.
+ auto FTy = RetconLowering.ResumePrototype->getFunctionType();
+ return FTy->params().slice(1);
+ }
+
+ CallingConv::ID getResumeFunctionCC() const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+ return CallingConv::Fast;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return RetconLowering.ResumePrototype->getCallingConv();
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+ }
+
+ unsigned getFirstSpillFieldIndex() const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+ return SwitchFieldIndex::FirstSpill;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return 0;
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+ }
+
+ AllocaInst *getPromiseAlloca() const {
+ if (ABI == coro::ABI::Switch)
+ return SwitchLowering.PromiseAlloca;
+ return nullptr;
+ }
+
+ /// Allocate memory according to the rules of the active lowering.
+ ///
+ /// \param CG - if non-null, will be updated for the new call
+ Value *emitAlloc(IRBuilder<> &Builder, Value *Size, CallGraph *CG) const;
+
+ /// Deallocate memory according to the rules of the active lowering.
+ ///
+ /// \param CG - if non-null, will be updated for the new call
+ void emitDealloc(IRBuilder<> &Builder, Value *Ptr, CallGraph *CG) const;
+
Shape() = default;
explicit Shape(Function &F) { buildFrom(F); }
void buildFrom(Function &F);
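
A brief sketch (not part of the diff) of how a client of coro::Shape can stay ABI-generic through the accessors above. countSuspendResultValues is a hypothetical helper; the snippet assumes the usual "using namespace llvm;" of the Coroutines sources and a Shape already built via Shape(F).

// Hypothetical helper shown only to illustrate dispatching on Shape.ABI.
static unsigned countSuspendResultValues(const coro::Shape &Shape) {
  switch (Shape.ABI) {
  case coro::ABI::Switch:
    return 0; // switch lowering: coro.suspend yields only the i8 action code
  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    // Values handed back to the caller at each suspend point.
    return Shape.getRetconResultTypes().size();
  }
  llvm_unreachable("Unknown coro::ABI enum");
}
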
diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp
index 5458e70ff16a..04723cbde417 100644
--- a/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -55,6 +55,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -70,9 +71,197 @@ using namespace llvm;
#define DEBUG_TYPE "coro-split"
+namespace {
+
+/// A little helper class for building the resume, destroy, and cleanup
+/// clones (or continuation functions) of a coroutine.
+class CoroCloner {
+public:
+ enum class Kind {
+ /// The shared resume function for a switch lowering.
+ SwitchResume,
+
+ /// The shared unwind function for a switch lowering.
+ SwitchUnwind,
+
+ /// The shared cleanup function for a switch lowering.
+ SwitchCleanup,
+
+ /// An individual continuation function.
+ Continuation,
+ };
+private:
+ Function &OrigF;
+ Function *NewF;
+ const Twine &Suffix;
+ coro::Shape &Shape;
+ Kind FKind;
+ ValueToValueMapTy VMap;
+ IRBuilder<> Builder;
+ Value *NewFramePtr = nullptr;
+ Value *SwiftErrorSlot = nullptr;
+
+ /// The active suspend instruction; meaningful only for continuation ABIs.
+ AnyCoroSuspendInst *ActiveSuspend = nullptr;
+
+public:
+ /// Create a cloner for a switch lowering.
+ CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
+ Kind FKind)
+ : OrigF(OrigF), NewF(nullptr), Suffix(Suffix), Shape(Shape),
+ FKind(FKind), Builder(OrigF.getContext()) {
+ assert(Shape.ABI == coro::ABI::Switch);
+ }
+
+ /// Create a cloner for a continuation lowering.
+ CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
+ Function *NewF, AnyCoroSuspendInst *ActiveSuspend)
+ : OrigF(OrigF), NewF(NewF), Suffix(Suffix), Shape(Shape),
+ FKind(Kind::Continuation), Builder(OrigF.getContext()),
+ ActiveSuspend(ActiveSuspend) {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+ assert(NewF && "need existing function for continuation");
+ assert(ActiveSuspend && "need active suspend point for continuation");
+ }
+
+ Function *getFunction() const {
+ assert(NewF != nullptr && "declaration not yet set");
+ return NewF;
+ }
+
+ void create();
+
+private:
+ bool isSwitchDestroyFunction() {
+ switch (FKind) {
+ case Kind::Continuation:
+ case Kind::SwitchResume:
+ return false;
+ case Kind::SwitchUnwind:
+ case Kind::SwitchCleanup:
+ return true;
+ }
+ llvm_unreachable("Unknown CoroCloner::Kind enum");
+ }
+
+ void createDeclaration();
+ void replaceEntryBlock();
+ Value *deriveNewFramePointer();
+ void replaceRetconSuspendUses();
+ void replaceCoroSuspends();
+ void replaceCoroEnds();
+ void replaceSwiftErrorOps();
+ void handleFinalSuspend();
+ void maybeFreeContinuationStorage();
+};
+
+} // end anonymous namespace
+
+static void maybeFreeRetconStorage(IRBuilder<> &Builder, coro::Shape &Shape,
+ Value *FramePtr, CallGraph *CG) {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+ if (Shape.RetconLowering.IsFrameInlineInStorage)
+ return;
+
+ Shape.emitDealloc(Builder, FramePtr, CG);
+}
+
+/// Replace a non-unwind call to llvm.coro.end.
+static void replaceFallthroughCoroEnd(CoroEndInst *End, coro::Shape &Shape,
+ Value *FramePtr, bool InResume,
+ CallGraph *CG) {
+ // Start inserting right before the coro.end.
+ IRBuilder<> Builder(End);
+
+ // Create the return instruction.
+ switch (Shape.ABI) {
+ // The cloned functions in switch-lowering always return void.
+ case coro::ABI::Switch:
+ // coro.end doesn't immediately end the coroutine in the main function
+ // in this lowering, because we need to deallocate the coroutine.
+ if (!InResume)
+ return;
+ Builder.CreateRetVoid();
+ break;
+
+ // In unique continuation lowering, the continuations always return void.
+ // But we may have implicitly allocated storage.
+ case coro::ABI::RetconOnce:
+ maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
+ Builder.CreateRetVoid();
+ break;
+
+ // In non-unique continuation lowering, we signal completion by returning
+ // a null continuation.
+ case coro::ABI::Retcon: {
+ maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
+ auto RetTy = Shape.getResumeFunctionType()->getReturnType();
+ auto RetStructTy = dyn_cast<StructType>(RetTy);
+ PointerType *ContinuationTy =
+ cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy);
+
+ Value *ReturnValue = ConstantPointerNull::get(ContinuationTy);
+ if (RetStructTy) {
+ ReturnValue = Builder.CreateInsertValue(UndefValue::get(RetStructTy),
+ ReturnValue, 0);
+ }
+ Builder.CreateRet(ReturnValue);
+ break;
+ }
+ }
+
+ // Remove the rest of the block, by splitting it into an unreachable block.
+ auto *BB = End->getParent();
+ BB->splitBasicBlock(End);
+ BB->getTerminator()->eraseFromParent();
+}
+
+/// Replace an unwind call to llvm.coro.end.
+static void replaceUnwindCoroEnd(CoroEndInst *End, coro::Shape &Shape,
+ Value *FramePtr, bool InResume, CallGraph *CG){
+ IRBuilder<> Builder(End);
+
+ switch (Shape.ABI) {
+ // In switch-lowering, this does nothing in the main function.
+ case coro::ABI::Switch:
+ if (!InResume)
+ return;
+ break;
+
+ // In continuation-lowering, this frees the continuation storage.
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
+ break;
+ }
+
+ // If coro.end has an associated bundle, add cleanupret instruction.
+ if (auto Bundle = End->getOperandBundle(LLVMContext::OB_funclet)) {
+ auto *FromPad = cast<CleanupPadInst>(Bundle->Inputs[0]);
+ auto *CleanupRet = Builder.CreateCleanupRet(FromPad, nullptr);
+ End->getParent()->splitBasicBlock(End);
+ CleanupRet->getParent()->getTerminator()->eraseFromParent();
+ }
+}
+
+static void replaceCoroEnd(CoroEndInst *End, coro::Shape &Shape,
+ Value *FramePtr, bool InResume, CallGraph *CG) {
+ if (End->isUnwind())
+ replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG);
+ else
+ replaceFallthroughCoroEnd(End, Shape, FramePtr, InResume, CG);
+
+ auto &Context = End->getContext();
+ End->replaceAllUsesWith(InResume ? ConstantInt::getTrue(Context)
+ : ConstantInt::getFalse(Context));
+ End->eraseFromParent();
+}
+
// Create an entry block for a resume function with a switch that will jump to
// suspend points.
-static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
+static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
+ assert(Shape.ABI == coro::ABI::Switch);
LLVMContext &C = F.getContext();
// resume.entry:
@@ -91,15 +280,16 @@ static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
IRBuilder<> Builder(NewEntry);
auto *FramePtr = Shape.FramePtr;
auto *FrameTy = Shape.FrameTy;
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
- FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
+ auto *GepIndex = Builder.CreateStructGEP(
+ FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr");
auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
auto *Switch =
Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
- Shape.ResumeSwitch = Switch;
+ Shape.SwitchLowering.ResumeSwitch = Switch;
size_t SuspendIndex = 0;
- for (CoroSuspendInst *S : Shape.CoroSuspends) {
+ for (auto *AnyS : Shape.CoroSuspends) {
+ auto *S = cast<CoroSuspendInst>(AnyS);
ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
// Replace CoroSave with a store to Index:
@@ -109,14 +299,15 @@ static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
Builder.SetInsertPoint(Save);
if (S->isFinal()) {
// Final suspend point is represented by storing zero in ResumeFnAddr.
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
- 0, "ResumeFn.addr");
+ auto *GepIndex = Builder.CreateStructGEP(FrameTy, FramePtr,
+ coro::Shape::SwitchFieldIndex::Resume,
+ "ResumeFn.addr");
auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
cast<PointerType>(GepIndex->getType())->getElementType()));
Builder.CreateStore(NullPtr, GepIndex);
} else {
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
- FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
+ auto *GepIndex = Builder.CreateStructGEP(
+ FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr");
Builder.CreateStore(IndexVal, GepIndex);
}
Save->replaceAllUsesWith(ConstantTokenNone::get(C));
@@ -164,48 +355,9 @@ static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
Builder.SetInsertPoint(UnreachBB);
Builder.CreateUnreachable();
- return NewEntry;
+ Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
}
-// In Resumers, we replace fallthrough coro.end with ret void and delete the
-// rest of the block.
-static void replaceFallthroughCoroEnd(IntrinsicInst *End,
- ValueToValueMapTy &VMap) {
- auto *NewE = cast<IntrinsicInst>(VMap[End]);
- ReturnInst::Create(NewE->getContext(), nullptr, NewE);
-
- // Remove the rest of the block, by splitting it into an unreachable block.
- auto *BB = NewE->getParent();
- BB->splitBasicBlock(NewE);
- BB->getTerminator()->eraseFromParent();
-}
-
-// In Resumers, we replace unwind coro.end with True to force the immediate
-// unwind to caller.
-static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
- if (Shape.CoroEnds.empty())
- return;
-
- LLVMContext &Context = Shape.CoroEnds.front()->getContext();
- auto *True = ConstantInt::getTrue(Context);
- for (CoroEndInst *CE : Shape.CoroEnds) {
- if (!CE->isUnwind())
- continue;
-
- auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
-
- // If coro.end has an associated bundle, add cleanupret instruction.
- if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
- Value *FromPad = Bundle->Inputs[0];
- auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
- NewCE->getParent()->splitBasicBlock(NewCE);
- CleanupRet->getParent()->getTerminator()->eraseFromParent();
- }
-
- NewCE->replaceAllUsesWith(True);
- NewCE->eraseFromParent();
- }
-}
// Rewrite final suspend point handling. We do not use suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
@@ -216,83 +368,364 @@ static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
// In the destroy function, we add a code sequence to check if ResumeFnAddress
// is Null, and if so, jump to the appropriate label to handle cleanup from the
// final suspend point.
-static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
- coro::Shape &Shape, SwitchInst *Switch,
- bool IsDestroy) {
- assert(Shape.HasFinalSuspend);
+void CoroCloner::handleFinalSuspend() {
+ assert(Shape.ABI == coro::ABI::Switch &&
+ Shape.SwitchLowering.HasFinalSuspend);
+ auto *Switch = cast<SwitchInst>(VMap[Shape.SwitchLowering.ResumeSwitch]);
auto FinalCaseIt = std::prev(Switch->case_end());
BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
Switch->removeCase(FinalCaseIt);
- if (IsDestroy) {
+ if (isSwitchDestroyFunction()) {
BasicBlock *OldSwitchBB = Switch->getParent();
auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
Builder.SetInsertPoint(OldSwitchBB->getTerminator());
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
- 0, 0, "ResumeFn.addr");
- auto *Load = Builder.CreateLoad(
- Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex);
- auto *NullPtr =
- ConstantPointerNull::get(cast<PointerType>(Load->getType()));
- auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
+ auto *GepIndex = Builder.CreateStructGEP(Shape.FrameTy, NewFramePtr,
+ coro::Shape::SwitchFieldIndex::Resume,
+ "ResumeFn.addr");
+ auto *Load = Builder.CreateLoad(Shape.getSwitchResumePointerType(),
+ GepIndex);
+ auto *Cond = Builder.CreateIsNull(Load);
Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
OldSwitchBB->getTerminator()->eraseFromParent();
}
}
-// Create a resume clone by cloning the body of the original function, setting
-// new entry block and replacing coro.suspend an appropriate value to force
-// resume or cleanup pass for every suspend point.
-static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
- BasicBlock *ResumeEntry, int8_t FnIndex) {
- Module *M = F.getParent();
- auto *FrameTy = Shape.FrameTy;
- auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
- auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
+static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape,
+ const Twine &Suffix,
+ Module::iterator InsertBefore) {
+ Module *M = OrigF.getParent();
+ auto *FnTy = Shape.getResumeFunctionType();
Function *NewF =
- Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage,
- F.getName() + Suffix, M);
+ Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
+ OrigF.getName() + Suffix);
NewF->addParamAttr(0, Attribute::NonNull);
NewF->addParamAttr(0, Attribute::NoAlias);
- ValueToValueMapTy VMap;
+ M->getFunctionList().insert(InsertBefore, NewF);
+
+ return NewF;
+}
+
+/// Replace uses of the active llvm.coro.suspend.retcon call with the
+/// arguments to the continuation function.
+///
+/// This assumes that the builder has a meaningful insertion point.
+void CoroCloner::replaceRetconSuspendUses() {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+
+ auto NewS = VMap[ActiveSuspend];
+ if (NewS->use_empty()) return;
+
+ // Copy out all the continuation arguments after the buffer pointer into
+ // an easily-indexed data structure for convenience.
+ SmallVector<Value*, 8> Args;
+ for (auto I = std::next(NewF->arg_begin()), E = NewF->arg_end(); I != E; ++I)
+ Args.push_back(&*I);
+
+ // If the suspend returns a single scalar value, we can just do a simple
+ // replacement.
+ if (!isa<StructType>(NewS->getType())) {
+ assert(Args.size() == 1);
+ NewS->replaceAllUsesWith(Args.front());
+ return;
+ }
+
+ // Try to peephole extracts of an aggregate return.
+ for (auto UI = NewS->use_begin(), UE = NewS->use_end(); UI != UE; ) {
+ auto EVI = dyn_cast<ExtractValueInst>((UI++)->getUser());
+ if (!EVI || EVI->getNumIndices() != 1)
+ continue;
+
+ EVI->replaceAllUsesWith(Args[EVI->getIndices().front()]);
+ EVI->eraseFromParent();
+ }
+
+ // If we have no remaining uses, we're done.
+ if (NewS->use_empty()) return;
+
+ // Otherwise, we need to create an aggregate.
+ Value *Agg = UndefValue::get(NewS->getType());
+ for (size_t I = 0, E = Args.size(); I != E; ++I)
+ Agg = Builder.CreateInsertValue(Agg, Args[I], I);
+
+ NewS->replaceAllUsesWith(Agg);
+}
+
+void CoroCloner::replaceCoroSuspends() {
+ Value *SuspendResult;
+
+ switch (Shape.ABI) {
+ // In switch lowering, replace coro.suspend with the appropriate value
+ // for the type of function we're extracting.
+ // Replacing coro.suspend with (0) will result in control flow proceeding to
+ // a resume label associated with a suspend point; replacing it with (1) will
+ // result in control flow proceeding to a cleanup label associated with this
+ // suspend point.
+ case coro::ABI::Switch:
+ SuspendResult = Builder.getInt8(isSwitchDestroyFunction() ? 1 : 0);
+ break;
+
+ // In returned-continuation lowering, the arguments from earlier
+ // continuations are theoretically arbitrary, and they should have been
+ // spilled.
+ case coro::ABI::RetconOnce:
+ case coro::ABI::Retcon:
+ return;
+ }
+
+ for (AnyCoroSuspendInst *CS : Shape.CoroSuspends) {
+ // The active suspend was handled earlier.
+ if (CS == ActiveSuspend) continue;
+
+ auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[CS]);
+ MappedCS->replaceAllUsesWith(SuspendResult);
+ MappedCS->eraseFromParent();
+ }
+}
+
+void CoroCloner::replaceCoroEnds() {
+ for (CoroEndInst *CE : Shape.CoroEnds) {
+ // We use a null call graph because there's no call graph node for
+ // the cloned function yet. We'll just be rebuilding that later.
+ auto NewCE = cast<CoroEndInst>(VMap[CE]);
+ replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr);
+ }
+}
+
+static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape,
+ ValueToValueMapTy *VMap) {
+ Value *CachedSlot = nullptr;
+ auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * {
+ if (CachedSlot) {
+ assert(CachedSlot->getType()->getPointerElementType() == ValueTy &&
+ "multiple swifterror slots in function with different types");
+ return CachedSlot;
+ }
+
+ // Check if the function has a swifterror argument.
+ for (auto &Arg : F.args()) {
+ if (Arg.isSwiftError()) {
+ CachedSlot = &Arg;
+ assert(Arg.getType()->getPointerElementType() == ValueTy &&
+ "swifterror argument does not have expected type");
+ return &Arg;
+ }
+ }
+
+ // Create a swifterror alloca.
+ IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
+ auto Alloca = Builder.CreateAlloca(ValueTy);
+ Alloca->setSwiftError(true);
+
+ CachedSlot = Alloca;
+ return Alloca;
+ };
+
+ for (CallInst *Op : Shape.SwiftErrorOps) {
+ auto MappedOp = VMap ? cast<CallInst>((*VMap)[Op]) : Op;
+ IRBuilder<> Builder(MappedOp);
+
+ // If there are no arguments, this is a 'get' operation.
+ Value *MappedResult;
+ if (Op->getNumArgOperands() == 0) {
+ auto ValueTy = Op->getType();
+ auto Slot = getSwiftErrorSlot(ValueTy);
+ MappedResult = Builder.CreateLoad(ValueTy, Slot);
+ } else {
+ assert(Op->getNumArgOperands() == 1);
+ auto Value = MappedOp->getArgOperand(0);
+ auto ValueTy = Value->getType();
+ auto Slot = getSwiftErrorSlot(ValueTy);
+ Builder.CreateStore(Value, Slot);
+ MappedResult = Slot;
+ }
+
+ MappedOp->replaceAllUsesWith(MappedResult);
+ MappedOp->eraseFromParent();
+ }
+
+ // If we're updating the original function, we've invalidated SwiftErrorOps.
+ if (VMap == nullptr) {
+ Shape.SwiftErrorOps.clear();
+ }
+}
+
+void CoroCloner::replaceSwiftErrorOps() {
+ ::replaceSwiftErrorOps(*NewF, Shape, &VMap);
+}
+
+void CoroCloner::replaceEntryBlock() {
+ // In the original function, the AllocaSpillBlock is a block immediately
+ // following the allocation of the frame object which defines GEPs for
+ // all the allocas that have been moved into the frame, and it ends by
+ // branching to the original beginning of the coroutine. Make this
+ // the entry block of the cloned function.
+ auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
+ Entry->setName("entry" + Suffix);
+ Entry->moveBefore(&NewF->getEntryBlock());
+ Entry->getTerminator()->eraseFromParent();
+
+ // Clear all predecessors of the new entry block. There should be
+ // exactly one predecessor, which we created when splitting out
+ // AllocaSpillBlock to begin with.
+ assert(Entry->hasOneUse());
+ auto BranchToEntry = cast<BranchInst>(Entry->user_back());
+ assert(BranchToEntry->isUnconditional());
+ Builder.SetInsertPoint(BranchToEntry);
+ Builder.CreateUnreachable();
+ BranchToEntry->eraseFromParent();
+
+ // TODO: move any allocas into Entry that weren't moved into the frame.
+ // (Currently we move all allocas into the frame.)
+
+ // Branch from the entry to the appropriate place.
+ Builder.SetInsertPoint(Entry);
+ switch (Shape.ABI) {
+ case coro::ABI::Switch: {
+ // In switch-lowering, we built a resume-entry block in the original
+ // function. Make the entry block branch to this.
+ auto *SwitchBB =
+ cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]);
+ Builder.CreateBr(SwitchBB);
+ break;
+ }
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ // In continuation ABIs, we want to branch to immediately after the
+ // active suspend point. Earlier phases will have put the suspend in its
+ // own basic block, so just thread our jump directly to its successor.
+ auto MappedCS = cast<CoroSuspendRetconInst>(VMap[ActiveSuspend]);
+ auto Branch = cast<BranchInst>(MappedCS->getNextNode());
+ assert(Branch->isUnconditional());
+ Builder.CreateBr(Branch->getSuccessor(0));
+ break;
+ }
+ }
+}
+
+/// Derive the value of the new frame pointer.
+Value *CoroCloner::deriveNewFramePointer() {
+ // Builder should be inserting to the front of the new entry block.
+
+ switch (Shape.ABI) {
+ // In switch-lowering, the argument is the frame pointer.
+ case coro::ABI::Switch:
+ return &*NewF->arg_begin();
+
+ // In continuation-lowering, the argument is the opaque storage.
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ Argument *NewStorage = &*NewF->arg_begin();
+ auto FramePtrTy = Shape.FrameTy->getPointerTo();
+
+ // If the storage is inline, just bitcast the storage to the frame pointer type.
+ if (Shape.RetconLowering.IsFrameInlineInStorage)
+ return Builder.CreateBitCast(NewStorage, FramePtrTy);
+
+ // Otherwise, load the real frame from the opaque storage.
+ auto FramePtrPtr =
+ Builder.CreateBitCast(NewStorage, FramePtrTy->getPointerTo());
+ return Builder.CreateLoad(FramePtrPtr);
+ }
+ }
+ llvm_unreachable("bad ABI");
+}
+
+/// Clone the body of the original function into a resume function of
+/// some sort.
+void CoroCloner::create() {
+ // Create the new function if we don't already have one.
+ if (!NewF) {
+ NewF = createCloneDeclaration(OrigF, Shape, Suffix,
+ OrigF.getParent()->end());
+ }
+
 // Replace all args with undefs. The buildCoroutineFrame algorithm has
 // already rewritten accesses to the args that occur after suspend points
 // with loads and stores to/from the coroutine frame.
- for (Argument &A : F.args())
+ for (Argument &A : OrigF.args())
VMap[&A] = UndefValue::get(A.getType());
SmallVector<ReturnInst *, 4> Returns;
- CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
- NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
+ // Ignore attempts to change certain attributes of the function.
+ // TODO: maybe there should be a way to suppress this during cloning?
+ auto savedVisibility = NewF->getVisibility();
+ auto savedUnnamedAddr = NewF->getUnnamedAddr();
+ auto savedDLLStorageClass = NewF->getDLLStorageClass();
+
+ // NewF's linkage (which CloneFunctionInto does *not* change) might not
+ // be compatible with the visibility of OrigF (which it *does* change),
+ // so protect against that.
+ auto savedLinkage = NewF->getLinkage();
+ NewF->setLinkage(llvm::GlobalValue::ExternalLinkage);
+
+ CloneFunctionInto(NewF, &OrigF, VMap, /*ModuleLevelChanges=*/true, Returns);
+
+ NewF->setLinkage(savedLinkage);
+ NewF->setVisibility(savedVisibility);
+ NewF->setUnnamedAddr(savedUnnamedAddr);
+ NewF->setDLLStorageClass(savedDLLStorageClass);
+
+ auto &Context = NewF->getContext();
+
+ // Replace the attributes of the new function:
+ auto OrigAttrs = NewF->getAttributes();
+ auto NewAttrs = AttributeList();
+
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ // Bootstrap attributes by copying function attributes from the
+ // original function. This should include optimization settings and so on.
+ NewAttrs = NewAttrs.addAttributes(Context, AttributeList::FunctionIndex,
+ OrigAttrs.getFnAttributes());
+ break;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ // If we have a continuation prototype, just use its attributes,
+ // full-stop.
+ NewAttrs = Shape.RetconLowering.ResumePrototype->getAttributes();
+ break;
+ }
- // Remove old returns.
- for (ReturnInst *Return : Returns)
- changeToUnreachable(Return, /*UseLLVMTrap=*/false);
+ // Make the frame parameter nonnull and noalias.
+ NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NonNull);
+ NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NoAlias);
+
+ switch (Shape.ABI) {
+ // In these ABIs, the cloned functions always return 'void', and the
+ // existing return sites are meaningless. Note that for unique
+ // continuations, this includes the returns associated with suspends;
+ // this is fine because we can't suspend twice.
+ case coro::ABI::Switch:
+ case coro::ABI::RetconOnce:
+ // Remove old returns.
+ for (ReturnInst *Return : Returns)
+ changeToUnreachable(Return, /*UseLLVMTrap=*/false);
+ break;
+
+ // With multi-suspend continuations, we'll already have eliminated the
+ // original returns and inserted returns before all the suspend points,
+ // so we want to leave any returns in place.
+ case coro::ABI::Retcon:
+ break;
+ }
- // Remove old return attributes.
- NewF->removeAttributes(
- AttributeList::ReturnIndex,
- AttributeFuncs::typeIncompatible(NewF->getReturnType()));
+ NewF->setAttributes(NewAttrs);
+ NewF->setCallingConv(Shape.getResumeFunctionCC());
- // Make AllocaSpillBlock the new entry block.
- auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
- auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
- Entry->moveBefore(&NewF->getEntryBlock());
- Entry->getTerminator()->eraseFromParent();
- BranchInst::Create(SwitchBB, Entry);
- Entry->setName("entry" + Suffix);
+ // Set up the new entry block.
+ replaceEntryBlock();
- // Clear all predecessors of the new entry block.
- auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
- Entry->replaceAllUsesWith(Switch->getDefaultDest());
-
- IRBuilder<> Builder(&NewF->getEntryBlock().front());
+ Builder.SetInsertPoint(&NewF->getEntryBlock().front());
+ NewFramePtr = deriveNewFramePointer();
// Remap frame pointer.
- Argument *NewFramePtr = &*NewF->arg_begin();
- Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
+ Value *OldFramePtr = VMap[Shape.FramePtr];
NewFramePtr->takeName(OldFramePtr);
OldFramePtr->replaceAllUsesWith(NewFramePtr);
@@ -302,50 +735,55 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
OldVFrame->replaceAllUsesWith(NewVFrame);
- // Rewrite final suspend handling as it is not done via switch (allows to
- // remove final case from the switch, since it is undefined behavior to resume
- // the coroutine suspended at the final suspend point.
- if (Shape.HasFinalSuspend) {
- auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
- bool IsDestroy = FnIndex != 0;
- handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ // Rewrite final suspend handling, since it is not done via the switch (this
+ // allows removing the final case from the switch, because it is undefined
+ // behavior to resume the coroutine suspended at the final suspend point).
+ if (Shape.SwitchLowering.HasFinalSuspend)
+ handleFinalSuspend();
+ break;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ // Replace uses of the active suspend with the corresponding
+ // continuation-function arguments.
+ assert(ActiveSuspend != nullptr &&
+ "no active suspend when lowering a continuation-style coroutine");
+ replaceRetconSuspendUses();
+ break;
}
- // Replace coro suspend with the appropriate resume index.
- // Replacing coro.suspend with (0) will result in control flow proceeding to
- // a resume label associated with a suspend point, replacing it with (1) will
- // result in control flow proceeding to a cleanup label associated with this
- // suspend point.
- auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
- for (CoroSuspendInst *CS : Shape.CoroSuspends) {
- auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
- MappedCS->replaceAllUsesWith(NewValue);
- MappedCS->eraseFromParent();
- }
+ // Handle suspends.
+ replaceCoroSuspends();
+
+ // Handle swifterror.
+ replaceSwiftErrorOps();
// Remove coro.end intrinsics.
- replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
- replaceUnwindCoroEnds(Shape, VMap);
+ replaceCoroEnds();
+
// Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
// to suppress deallocation code.
- coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
- /*Elide=*/FnIndex == 2);
-
- NewF->setCallingConv(CallingConv::Fast);
-
- return NewF;
+ if (Shape.ABI == coro::ABI::Switch)
+ coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
+ /*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup);
}
-static void removeCoroEnds(coro::Shape &Shape) {
- if (Shape.CoroEnds.empty())
- return;
-
- LLVMContext &Context = Shape.CoroEnds.front()->getContext();
- auto *False = ConstantInt::getFalse(Context);
+// Create a resume clone by cloning the body of the original function, setting
+// up a new entry block and replacing each coro.suspend with an appropriate
+// value to force the resume or cleanup path at every suspend point.
+static Function *createClone(Function &F, const Twine &Suffix,
+ coro::Shape &Shape, CoroCloner::Kind FKind) {
+ CoroCloner Cloner(F, Suffix, Shape, FKind);
+ Cloner.create();
+ return Cloner.getFunction();
+}
- for (CoroEndInst *CE : Shape.CoroEnds) {
- CE->replaceAllUsesWith(False);
- CE->eraseFromParent();
+/// Remove calls to llvm.coro.end in the original function.
+static void removeCoroEnds(coro::Shape &Shape, CallGraph *CG) {
+ for (auto End : Shape.CoroEnds) {
+ replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, CG);
}
}
@@ -377,8 +815,12 @@ static void replaceFrameSize(coro::Shape &Shape) {
// i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
-static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
- std::initializer_list<Function *> Fns) {
+static void setCoroInfo(Function &F, coro::Shape &Shape,
+ ArrayRef<Function *> Fns) {
+ // This only works under the switch-lowering ABI because coro elision
+ // only works on the switch-lowering ABI.
+ assert(Shape.ABI == coro::ABI::Switch);
+
SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
assert(!Args.empty());
Function *Part = *Fns.begin();
@@ -393,38 +835,45 @@ static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
// Update coro.begin instruction to refer to this constant.
LLVMContext &C = F.getContext();
auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
- CoroBegin->getId()->setInfo(BC);
+ Shape.getSwitchCoroId()->setInfo(BC);
}
// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
Function *DestroyFn, Function *CleanupFn) {
+ assert(Shape.ABI == coro::ABI::Switch);
+
IRBuilder<> Builder(Shape.FramePtr->getNextNode());
- auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
- Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
+ auto *ResumeAddr = Builder.CreateStructGEP(
+ Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
"resume.addr");
Builder.CreateStore(ResumeFn, ResumeAddr);
Value *DestroyOrCleanupFn = DestroyFn;
- CoroIdInst *CoroId = Shape.CoroBegin->getId();
+ CoroIdInst *CoroId = Shape.getSwitchCoroId();
if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
// If there is a CoroAlloc and it returns false (meaning we elide the
 // allocation), use CleanupFn instead of DestroyFn.
DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
}
- auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
- Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
+ auto *DestroyAddr = Builder.CreateStructGEP(
+ Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
"destroy.addr");
Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
}
static void postSplitCleanup(Function &F) {
removeUnreachableBlocks(F);
+
+ // For now, we do a mandatory verification step because we don't
+ // entirely trust this pass. Note that we don't want to add a verifier
+ // pass to FPM below because it will also verify all the global data.
+ verifyFunction(F);
+
legacy::FunctionPassManager FPM(F.getParent());
- FPM.add(createVerifierPass());
FPM.add(createSCCPPass());
FPM.add(createCFGSimplificationPass());
FPM.add(createEarlyCSEPass());
@@ -520,21 +969,34 @@ static void addMustTailToCoroResumes(Function &F) {
// Coroutine has no suspend points. Remove heap allocation for the coroutine
// frame if possible.
-static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
+static void handleNoSuspendCoroutine(coro::Shape &Shape) {
+ auto *CoroBegin = Shape.CoroBegin;
auto *CoroId = CoroBegin->getId();
auto *AllocInst = CoroId->getCoroAlloc();
- coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
- if (AllocInst) {
- IRBuilder<> Builder(AllocInst);
- // FIXME: Need to handle overaligned members.
- auto *Frame = Builder.CreateAlloca(FrameTy);
- auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
- AllocInst->replaceAllUsesWith(Builder.getFalse());
- AllocInst->eraseFromParent();
- CoroBegin->replaceAllUsesWith(VFrame);
- } else {
- CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
+ switch (Shape.ABI) {
+ case coro::ABI::Switch: {
+ auto SwitchId = cast<CoroIdInst>(CoroId);
+ coro::replaceCoroFree(SwitchId, /*Elide=*/AllocInst != nullptr);
+ if (AllocInst) {
+ IRBuilder<> Builder(AllocInst);
+ // FIXME: Need to handle overaligned members.
+ auto *Frame = Builder.CreateAlloca(Shape.FrameTy);
+ auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
+ AllocInst->replaceAllUsesWith(Builder.getFalse());
+ AllocInst->eraseFromParent();
+ CoroBegin->replaceAllUsesWith(VFrame);
+ } else {
+ CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
+ }
+ break;
+ }
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ CoroBegin->replaceAllUsesWith(UndefValue::get(CoroBegin->getType()));
+ break;
}
+
CoroBegin->eraseFromParent();
}
@@ -670,12 +1132,16 @@ static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
// Remove suspend points that are simplified.
static void simplifySuspendPoints(coro::Shape &Shape) {
+ // Currently, the only simplification we do is switch-lowering-specific.
+ if (Shape.ABI != coro::ABI::Switch)
+ return;
+
auto &S = Shape.CoroSuspends;
size_t I = 0, N = S.size();
if (N == 0)
return;
while (true) {
- if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
+ if (simplifySuspendPoint(cast<CoroSuspendInst>(S[I]), Shape.CoroBegin)) {
if (--N == I)
break;
std::swap(S[I], S[N]);
@@ -687,142 +1153,227 @@ static void simplifySuspendPoints(coro::Shape &Shape) {
S.resize(N);
}
-static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
- // Collect all blocks that we need to look for instructions to relocate.
- SmallPtrSet<BasicBlock *, 4> RelocBlocks;
- SmallVector<BasicBlock *, 4> Work;
- Work.push_back(CB->getParent());
+static void splitSwitchCoroutine(Function &F, coro::Shape &Shape,
+ SmallVectorImpl<Function *> &Clones) {
+ assert(Shape.ABI == coro::ABI::Switch);
- do {
- BasicBlock *Current = Work.pop_back_val();
- for (BasicBlock *BB : predecessors(Current))
- if (RelocBlocks.count(BB) == 0) {
- RelocBlocks.insert(BB);
- Work.push_back(BB);
- }
- } while (!Work.empty());
- return RelocBlocks;
-}
-
-static SmallPtrSet<Instruction *, 8>
-getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
- SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
- SmallPtrSet<Instruction *, 8> DoNotRelocate;
- // Collect all instructions that we should not relocate
- SmallVector<Instruction *, 8> Work;
-
- // Start with CoroBegin and terminators of all preceding blocks.
- Work.push_back(CoroBegin);
- BasicBlock *CoroBeginBB = CoroBegin->getParent();
- for (BasicBlock *BB : RelocBlocks)
- if (BB != CoroBeginBB)
- Work.push_back(BB->getTerminator());
-
- // For every instruction in the Work list, place its operands in DoNotRelocate
- // set.
- do {
- Instruction *Current = Work.pop_back_val();
- LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n");
- DoNotRelocate.insert(Current);
- for (Value *U : Current->operands()) {
- auto *I = dyn_cast<Instruction>(U);
- if (!I)
- continue;
+ createResumeEntryBlock(F, Shape);
+ auto ResumeClone = createClone(F, ".resume", Shape,
+ CoroCloner::Kind::SwitchResume);
+ auto DestroyClone = createClone(F, ".destroy", Shape,
+ CoroCloner::Kind::SwitchUnwind);
+ auto CleanupClone = createClone(F, ".cleanup", Shape,
+ CoroCloner::Kind::SwitchCleanup);
- if (auto *A = dyn_cast<AllocaInst>(I)) {
- // Stores to alloca instructions that occur before the coroutine frame
- // is allocated should not be moved; the stored values may be used by
- // the coroutine frame allocator. The operands to those stores must also
- // remain in place.
- for (const auto &User : A->users())
- if (auto *SI = dyn_cast<llvm::StoreInst>(User))
- if (RelocBlocks.count(SI->getParent()) != 0 &&
- DoNotRelocate.count(SI) == 0) {
- Work.push_back(SI);
- DoNotRelocate.insert(SI);
- }
- continue;
- }
+ postSplitCleanup(*ResumeClone);
+ postSplitCleanup(*DestroyClone);
+ postSplitCleanup(*CleanupClone);
+
+ addMustTailToCoroResumes(*ResumeClone);
+
+ // Store addresses of the resume/destroy/cleanup functions in the coroutine frame.
+ updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
+
+ assert(Clones.empty());
+ Clones.push_back(ResumeClone);
+ Clones.push_back(DestroyClone);
+ Clones.push_back(CleanupClone);
+
+ // Create a constant array referring to the resume/destroy/cleanup functions,
+ // pointed to by the last argument of @llvm.coro.info, so that the CoroElide
+ // pass can determine the correct function to call.
+ setCoroInfo(F, Shape, Clones);
+}
- if (DoNotRelocate.count(I) == 0) {
- Work.push_back(I);
- DoNotRelocate.insert(I);
+static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
+ SmallVectorImpl<Function *> &Clones) {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+ assert(Clones.empty());
+
+ // Reset various things that the optimizer might have decided it
+ // "knows" about the coroutine function due to not seeing a return.
+ F.removeFnAttr(Attribute::NoReturn);
+ F.removeAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
+ F.removeAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+
+ // Allocate the frame.
+ auto *Id = cast<AnyCoroIdRetconInst>(Shape.CoroBegin->getId());
+ Value *RawFramePtr;
+ if (Shape.RetconLowering.IsFrameInlineInStorage) {
+ RawFramePtr = Id->getStorage();
+ } else {
+ IRBuilder<> Builder(Id);
+
+ // Determine the size of the frame.
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ auto Size = DL.getTypeAllocSize(Shape.FrameTy);
+
+ // Allocate. We don't need to update the call graph node because we're
+ // going to recompute it from scratch after splitting.
+ RawFramePtr = Shape.emitAlloc(Builder, Builder.getInt64(Size), nullptr);
+ RawFramePtr =
+ Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType());
+
+ // Stash the allocated frame pointer in the continuation storage.
+ auto Dest = Builder.CreateBitCast(Id->getStorage(),
+ RawFramePtr->getType()->getPointerTo());
+ Builder.CreateStore(RawFramePtr, Dest);
+ }
+
+ // Map all uses of llvm.coro.begin to the allocated frame pointer.
+ {
+ // Make sure we don't invalidate Shape.FramePtr.
+ TrackingVH<Instruction> Handle(Shape.FramePtr);
+ Shape.CoroBegin->replaceAllUsesWith(RawFramePtr);
+ Shape.FramePtr = Handle.getValPtr();
+ }
+
+ // Create a unique return block.
+ BasicBlock *ReturnBB = nullptr;
+ SmallVector<PHINode *, 4> ReturnPHIs;
+
+ // Create all the functions in order after the main function.
+ auto NextF = std::next(F.getIterator());
+
+ // Create a continuation function for each of the suspend points.
+ Clones.reserve(Shape.CoroSuspends.size());
+ for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
+ auto Suspend = cast<CoroSuspendRetconInst>(Shape.CoroSuspends[i]);
+
+ // Create the clone declaration.
+ auto Continuation =
+ createCloneDeclaration(F, Shape, ".resume." + Twine(i), NextF);
+ Clones.push_back(Continuation);
+
+ // Insert a branch to the unified return block immediately before
+ // the suspend point.
+ auto SuspendBB = Suspend->getParent();
+ auto NewSuspendBB = SuspendBB->splitBasicBlock(Suspend);
+ auto Branch = cast<BranchInst>(SuspendBB->getTerminator());
+
+ // Create the unified return block.
+ if (!ReturnBB) {
+ // Place it before the first suspend.
+ ReturnBB = BasicBlock::Create(F.getContext(), "coro.return", &F,
+ NewSuspendBB);
+ Shape.RetconLowering.ReturnBlock = ReturnBB;
+
+ IRBuilder<> Builder(ReturnBB);
+
+ // Create PHIs for all the return values.
+ assert(ReturnPHIs.empty());
+
+ // First, the continuation.
+ ReturnPHIs.push_back(Builder.CreatePHI(Continuation->getType(),
+ Shape.CoroSuspends.size()));
+
+ // Next, all the directly-yielded values.
+ for (auto ResultTy : Shape.getRetconResultTypes())
+ ReturnPHIs.push_back(Builder.CreatePHI(ResultTy,
+ Shape.CoroSuspends.size()));
+
+ // Build the return value.
+ auto RetTy = F.getReturnType();
+
+ // Cast the continuation value if necessary.
+ // We can't rely on the types matching up because that type would
+ // have to be infinite.
+ auto CastedContinuationTy =
+ (ReturnPHIs.size() == 1 ? RetTy : RetTy->getStructElementType(0));
+ auto *CastedContinuation =
+ Builder.CreateBitCast(ReturnPHIs[0], CastedContinuationTy);
+
+ Value *RetV;
+ if (ReturnPHIs.size() == 1) {
+ RetV = CastedContinuation;
+ } else {
+ RetV = UndefValue::get(RetTy);
+ RetV = Builder.CreateInsertValue(RetV, CastedContinuation, 0);
+ for (size_t I = 1, E = ReturnPHIs.size(); I != E; ++I)
+ RetV = Builder.CreateInsertValue(RetV, ReturnPHIs[I], I);
}
+
+ Builder.CreateRet(RetV);
}
- } while (!Work.empty());
- return DoNotRelocate;
-}
-static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
- // Analyze which non-alloca instructions are needed for allocation and
- // relocate the rest to after coro.begin. We need to do it, since some of the
- // targets of those instructions may be placed into coroutine frame memory
- // for which becomes available after coro.begin intrinsic.
+ // Branch to the return block.
+ Branch->setSuccessor(0, ReturnBB);
+ ReturnPHIs[0]->addIncoming(Continuation, SuspendBB);
+ size_t NextPHIIndex = 1;
+ for (auto &VUse : Suspend->value_operands())
+ ReturnPHIs[NextPHIIndex++]->addIncoming(&*VUse, SuspendBB);
+ assert(NextPHIIndex == ReturnPHIs.size());
+ }
- auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
- auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);
+ assert(Clones.size() == Shape.CoroSuspends.size());
+ for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
+ auto Suspend = Shape.CoroSuspends[i];
+ auto Clone = Clones[i];
- Instruction *InsertPt = CoroBegin->getNextNode();
- BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
- for (auto B = BB.begin(), E = BB.end(); B != E;) {
- Instruction &I = *B++;
- if (isa<AllocaInst>(&I))
- continue;
- if (&I == CoroBegin)
- break;
- if (DoNotRelocateSet.count(&I))
- continue;
- I.moveBefore(InsertPt);
+ CoroCloner(F, "resume." + Twine(i), Shape, Clone, Suspend).create();
+ }
+}
+
+namespace {
+ class PrettyStackTraceFunction : public PrettyStackTraceEntry {
+ Function &F;
+ public:
+ PrettyStackTraceFunction(Function &F) : F(F) {}
+ void print(raw_ostream &OS) const override {
+ OS << "While splitting coroutine ";
+ F.printAsOperand(OS, /*print type*/ false, F.getParent());
+ OS << "\n";
+ }
+ };
+}
+
+static void splitCoroutine(Function &F, coro::Shape &Shape,
+ SmallVectorImpl<Function *> &Clones) {
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ return splitSwitchCoroutine(F, Shape, Clones);
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return splitRetconCoroutine(F, Shape, Clones);
}
+ llvm_unreachable("bad ABI kind");
}
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
- EliminateUnreachableBlocks(F);
+ PrettyStackTraceFunction prettyStackTrace(F);
+
+ // The suspend-crossing algorithm in buildCoroutineFrame gets tripped
+ // up by uses in unreachable blocks, so remove them as a first pass.
+ removeUnreachableBlocks(F);
coro::Shape Shape(F);
if (!Shape.CoroBegin)
return;
simplifySuspendPoints(Shape);
- relocateInstructionBefore(Shape.CoroBegin, F);
buildCoroutineFrame(F, Shape);
replaceFrameSize(Shape);
+ SmallVector<Function*, 4> Clones;
+
// If there are no suspend points, no split required, just remove
// the allocation and deallocation blocks, they are not needed.
if (Shape.CoroSuspends.empty()) {
- handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
- removeCoroEnds(Shape);
- postSplitCleanup(F);
- coro::updateCallGraph(F, {}, CG, SCC);
- return;
+ handleNoSuspendCoroutine(Shape);
+ } else {
+ splitCoroutine(F, Shape, Clones);
}
- auto *ResumeEntry = createResumeEntryBlock(F, Shape);
- auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
- auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
- auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
-
- // We no longer need coro.end in F.
- removeCoroEnds(Shape);
+ // Replace all the swifterror operations in the original function.
+ // This invalidates SwiftErrorOps in the Shape.
+ replaceSwiftErrorOps(F, Shape, nullptr);
+ removeCoroEnds(Shape, &CG);
postSplitCleanup(F);
- postSplitCleanup(*ResumeClone);
- postSplitCleanup(*DestroyClone);
- postSplitCleanup(*CleanupClone);
-
- addMustTailToCoroResumes(*ResumeClone);
-
- // Store addresses resume/destroy/cleanup functions in the coroutine frame.
- updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
-
- // Create a constant array referring to resume/destroy/clone functions pointed
- // by the last argument of @llvm.coro.info, so that CoroElide pass can
- // determined correct function to call.
- setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
// Update call graph and add the functions we created to the SCC.
- coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
+ coro::updateCallGraph(F, Clones, CG, SCC);
}
// When we see the coroutine the first time, we insert an indirect call to a
@@ -881,6 +1432,80 @@ static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
SCC.initialize(Nodes);
}
+/// Replace a call to llvm.coro.prepare.retcon.
+static void replacePrepare(CallInst *Prepare, CallGraph &CG) {
+ auto CastFn = Prepare->getArgOperand(0); // as an i8*
+ auto Fn = CastFn->stripPointerCasts(); // as its original type
+
+ // Find call graph nodes for the preparation.
+ CallGraphNode *PrepareUserNode = nullptr, *FnNode = nullptr;
+ if (auto ConcreteFn = dyn_cast<Function>(Fn)) {
+ PrepareUserNode = CG[Prepare->getFunction()];
+ FnNode = CG[ConcreteFn];
+ }
+
+ // Attempt to peephole this pattern:
+ // %0 = bitcast [[TYPE]] @some_function to i8*
+ // %1 = call @llvm.coro.prepare.retcon(i8* %0)
+ // %2 = bitcast %1 to [[TYPE]]
+ // ==>
+ // %2 = @some_function
+ for (auto UI = Prepare->use_begin(), UE = Prepare->use_end();
+ UI != UE; ) {
+ // Look for bitcasts back to the original function type.
+ auto *Cast = dyn_cast<BitCastInst>((UI++)->getUser());
+ if (!Cast || Cast->getType() != Fn->getType()) continue;
+
+ // Check whether the replacement will introduce new direct calls.
+ // If so, we'll need to update the call graph.
+ if (PrepareUserNode) {
+ for (auto &Use : Cast->uses()) {
+ if (auto *CB = dyn_cast<CallBase>(Use.getUser())) {
+ if (!CB->isCallee(&Use))
+ continue;
+ PrepareUserNode->removeCallEdgeFor(*CB);
+ PrepareUserNode->addCalledFunction(CB, FnNode);
+ }
+ }
+ }
+
+ // Replace and remove the cast.
+ Cast->replaceAllUsesWith(Fn);
+ Cast->eraseFromParent();
+ }
+
+ // Replace any remaining uses with the function as an i8*.
+ // This can never directly be a callee, so we don't need to update CG.
+ Prepare->replaceAllUsesWith(CastFn);
+ Prepare->eraseFromParent();
+
+ // Kill dead bitcasts.
+ while (auto *Cast = dyn_cast<BitCastInst>(CastFn)) {
+ if (!Cast->use_empty()) break;
+ CastFn = Cast->getOperand(0);
+ Cast->eraseFromParent();
+ }
+}
+
+/// Remove calls to llvm.coro.prepare.retcon, a barrier meant to prevent
+/// IPO from operating on calls to a retcon coroutine before it's been
+/// split. This is only safe to do after we've split all retcon
+/// coroutines in the module. We can do that this in this pass because
+/// this pass does promise to split all retcon coroutines (as opposed to
+/// switch coroutines, which are lowered in multiple stages).
+static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) {
+ bool Changed = false;
+ for (auto PI = PrepareFn->use_begin(), PE = PrepareFn->use_end();
+ PI != PE; ) {
+ // Intrinsics can only be used in calls.
+ auto *Prepare = cast<CallInst>((PI++)->getUser());
+ replacePrepare(Prepare, CG);
+ Changed = true;
+ }
+
+ return Changed;
+}
+
//===----------------------------------------------------------------------===//
// Top Level Driver
//===----------------------------------------------------------------------===//
@@ -899,7 +1524,9 @@ struct CoroSplit : public CallGraphSCCPass {
// A coroutine is identified by the presence of coro.begin intrinsic, if
// we don't have any, this pass has nothing to do.
bool doInitialization(CallGraph &CG) override {
- Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
+ Run = coro::declaresIntrinsics(CG.getModule(),
+ {"llvm.coro.begin",
+ "llvm.coro.prepare.retcon"});
return CallGraphSCCPass::doInitialization(CG);
}
@@ -907,6 +1534,12 @@ struct CoroSplit : public CallGraphSCCPass {
if (!Run)
return false;
+ // Check for uses of llvm.coro.prepare.retcon.
+ auto PrepareFn =
+ SCC.getCallGraph().getModule().getFunction("llvm.coro.prepare.retcon");
+ if (PrepareFn && PrepareFn->use_empty())
+ PrepareFn = nullptr;
+
// Find coroutines for processing.
SmallVector<Function *, 4> Coroutines;
for (CallGraphNode *CGN : SCC)
@@ -914,12 +1547,17 @@ struct CoroSplit : public CallGraphSCCPass {
if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
Coroutines.push_back(F);
- if (Coroutines.empty())
+ if (Coroutines.empty() && !PrepareFn)
return false;
CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
+
+ if (Coroutines.empty())
+ return replaceAllPrepares(PrepareFn, CG);
+
createDevirtTriggerFunc(CG, SCC);
+ // Split all the coroutines.
for (Function *F : Coroutines) {
Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
StringRef Value = Attr.getValueAsString();
@@ -932,6 +1570,10 @@ struct CoroSplit : public CallGraphSCCPass {
F->removeFnAttr(CORO_PRESPLIT_ATTR);
splitCoroutine(*F, CG, SCC);
}
+
+ if (PrepareFn)
+ replaceAllPrepares(PrepareFn, CG);
+
return true;
}
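
For orientation, a conceptual C++ analog of the run-time shape the switch lowering above arranges. This is illustrative only: the real frame is an LLVM StructType laid out per coro::Shape::SwitchFieldIndex, and the dispatch is the switch emitted by createResumeEntryBlock, not hand-written code; DemoFrame, demoResume, and the int32_t index type are assumptions for the sketch.

#include <cstdint>

// Field order mirrors SwitchFieldIndex::{Resume, Destroy, Promise, Index};
// spilled locals follow from SwitchFieldIndex::FirstSpill onward.
struct DemoFrame {
  void (*Resume)(DemoFrame *);  // set by updateCoroFrame; nulled at the final suspend
  void (*Destroy)(DemoFrame *); // destroy, or cleanup when the allocation was elided
  // promise storage would sit here
  int32_t Index;                // which suspend point the coroutine is parked at
  // ... spill slots ...
};

// Rough analog of the resume-entry block: load Index and switch on it.
static void demoResume(DemoFrame *Frame) {
  switch (Frame->Index) {
  case 0: /* continue after suspend point 0 */ break;
  case 1: /* continue after suspend point 1 */ break;
  default: /* unreachable for a well-formed coroutine */ break;
  }
}
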
diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp
index a581d1d21169..f39483b27518 100644
--- a/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/lib/Transforms/Coroutines/Coroutines.cpp
@@ -123,12 +123,26 @@ Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
static bool isCoroutineIntrinsicName(StringRef Name) {
// NOTE: Must be sorted!
static const char *const CoroIntrinsics[] = {
- "llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.destroy",
- "llvm.coro.done", "llvm.coro.end", "llvm.coro.frame",
- "llvm.coro.free", "llvm.coro.id", "llvm.coro.noop",
- "llvm.coro.param", "llvm.coro.promise", "llvm.coro.resume",
- "llvm.coro.save", "llvm.coro.size", "llvm.coro.subfn.addr",
+ "llvm.coro.alloc",
+ "llvm.coro.begin",
+ "llvm.coro.destroy",
+ "llvm.coro.done",
+ "llvm.coro.end",
+ "llvm.coro.frame",
+ "llvm.coro.free",
+ "llvm.coro.id",
+ "llvm.coro.id.retcon",
+ "llvm.coro.id.retcon.once",
+ "llvm.coro.noop",
+ "llvm.coro.param",
+ "llvm.coro.prepare.retcon",
+ "llvm.coro.promise",
+ "llvm.coro.resume",
+ "llvm.coro.save",
+ "llvm.coro.size",
+ "llvm.coro.subfn.addr",
"llvm.coro.suspend",
+ "llvm.coro.suspend.retcon",
};
return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
}
@@ -217,9 +231,6 @@ static void clear(coro::Shape &Shape) {
Shape.FrameTy = nullptr;
Shape.FramePtr = nullptr;
Shape.AllocaSpillBlock = nullptr;
- Shape.ResumeSwitch = nullptr;
- Shape.PromiseAlloca = nullptr;
- Shape.HasFinalSuspend = false;
}
static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
@@ -235,6 +246,7 @@ static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
// Collect "interesting" coroutine intrinsics.
void coro::Shape::buildFrom(Function &F) {
+ bool HasFinalSuspend = false;
size_t FinalSuspendIndex = 0;
clear(*this);
SmallVector<CoroFrameInst *, 8> CoroFrames;
@@ -257,9 +269,15 @@ void coro::Shape::buildFrom(Function &F) {
if (II->use_empty())
UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
break;
- case Intrinsic::coro_suspend:
- CoroSuspends.push_back(cast<CoroSuspendInst>(II));
- if (CoroSuspends.back()->isFinal()) {
+ case Intrinsic::coro_suspend_retcon: {
+ auto Suspend = cast<CoroSuspendRetconInst>(II);
+ CoroSuspends.push_back(Suspend);
+ break;
+ }
+ case Intrinsic::coro_suspend: {
+ auto Suspend = cast<CoroSuspendInst>(II);
+ CoroSuspends.push_back(Suspend);
+ if (Suspend->isFinal()) {
if (HasFinalSuspend)
report_fatal_error(
"Only one suspend point can be marked as final");
@@ -267,18 +285,23 @@ void coro::Shape::buildFrom(Function &F) {
FinalSuspendIndex = CoroSuspends.size() - 1;
}
break;
+ }
case Intrinsic::coro_begin: {
auto CB = cast<CoroBeginInst>(II);
- if (CB->getId()->getInfo().isPreSplit()) {
- if (CoroBegin)
- report_fatal_error(
+
+ // Ignore coro id's that aren't pre-split.
+ auto Id = dyn_cast<CoroIdInst>(CB->getId());
+ if (Id && !Id->getInfo().isPreSplit())
+ break;
+
+ if (CoroBegin)
+ report_fatal_error(
"coroutine should have exactly one defining @llvm.coro.begin");
- CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
- CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
- CB->removeAttribute(AttributeList::FunctionIndex,
- Attribute::NoDuplicate);
- CoroBegin = CB;
- }
+ CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+ CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
+ CB->removeAttribute(AttributeList::FunctionIndex,
+ Attribute::NoDuplicate);
+ CoroBegin = CB;
break;
}
case Intrinsic::coro_end:
@@ -310,7 +333,7 @@ void coro::Shape::buildFrom(Function &F) {
// Replace all coro.suspend with undef and remove related coro.saves if
// present.
- for (CoroSuspendInst *CS : CoroSuspends) {
+ for (AnyCoroSuspendInst *CS : CoroSuspends) {
CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
CS->eraseFromParent();
if (auto *CoroSave = CS->getCoroSave())
@@ -324,19 +347,136 @@ void coro::Shape::buildFrom(Function &F) {
return;
}
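+  // Determine the coroutine ABI from the kind of coro.id that defines this
+  // coro.begin, and fill in the corresponding ABI-specific lowering state.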
+ auto Id = CoroBegin->getId();
+ switch (auto IdIntrinsic = Id->getIntrinsicID()) {
+ case Intrinsic::coro_id: {
+ auto SwitchId = cast<CoroIdInst>(Id);
+ this->ABI = coro::ABI::Switch;
+ this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
+ this->SwitchLowering.ResumeSwitch = nullptr;
+ this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
+ this->SwitchLowering.ResumeEntryBlock = nullptr;
+
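+    // In the switch lowering every suspend must be a plain llvm.coro.suspend;
+    // canonicalize each one by adding an llvm.coro.save if it lacks one.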
+ for (auto AnySuspend : CoroSuspends) {
+ auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
+ if (!Suspend) {
+#ifndef NDEBUG
+ AnySuspend->dump();
+#endif
+ report_fatal_error("coro.id must be paired with coro.suspend");
+ }
+
+ if (!Suspend->getCoroSave())
+ createCoroSave(CoroBegin, Suspend);
+ }
+ break;
+ }
+
+ case Intrinsic::coro_id_retcon:
+ case Intrinsic::coro_id_retcon_once: {
+ auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
+ ContinuationId->checkWellFormed();
+ this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
+ ? coro::ABI::Retcon
+ : coro::ABI::RetconOnce);
+ auto Prototype = ContinuationId->getPrototype();
+ this->RetconLowering.ResumePrototype = Prototype;
+ this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
+ this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
+ this->RetconLowering.ReturnBlock = nullptr;
+ this->RetconLowering.IsFrameInlineInStorage = false;
+
+ // Determine the result value types, and make sure they match up with
+ // the values passed to the suspends.
+ auto ResultTys = getRetconResultTypes();
+ auto ResumeTys = getRetconResumeTypes();
+
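+    // Every suspend must be an llvm.coro.suspend.retcon whose operands match
+    // the prototype's result types (re-inserting bitcasts the optimizer may
+    // have stripped) and whose results match the resume types.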
+ for (auto AnySuspend : CoroSuspends) {
+ auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
+ if (!Suspend) {
+#ifndef NDEBUG
+ AnySuspend->dump();
+#endif
+ report_fatal_error("coro.id.retcon.* must be paired with "
+ "coro.suspend.retcon");
+ }
+
+ // Check that the argument types of the suspend match the results.
+ auto SI = Suspend->value_begin(), SE = Suspend->value_end();
+ auto RI = ResultTys.begin(), RE = ResultTys.end();
+ for (; SI != SE && RI != RE; ++SI, ++RI) {
+ auto SrcTy = (*SI)->getType();
+ if (SrcTy != *RI) {
+ // The optimizer likes to eliminate bitcasts leading into variadic
+ // calls, but that messes with our invariants. Re-insert the
+ // bitcast and ignore this type mismatch.
+ if (CastInst::isBitCastable(SrcTy, *RI)) {
+ auto BCI = new BitCastInst(*SI, *RI, "", Suspend);
+ SI->set(BCI);
+ continue;
+ }
+
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("argument to coro.suspend.retcon does not "
+ "match corresponding prototype function result");
+ }
+ }
+ if (SI != SE || RI != RE) {
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("wrong number of arguments to coro.suspend.retcon");
+ }
+
+ // Check that the result type of the suspend matches the resume types.
+ Type *SResultTy = Suspend->getType();
+ ArrayRef<Type*> SuspendResultTys;
+ if (SResultTy->isVoidTy()) {
+ // leave as empty array
+ } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
+ SuspendResultTys = SResultStructTy->elements();
+ } else {
+        // Forms a single-element ArrayRef pointing at the local SResultTy, so
+        // it must not outlive this scope.
+ SuspendResultTys = SResultTy;
+ }
+ if (SuspendResultTys.size() != ResumeTys.size()) {
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("wrong number of results from coro.suspend.retcon");
+ }
+ for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
+ if (SuspendResultTys[I] != ResumeTys[I]) {
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("result from coro.suspend.retcon does not "
+ "match corresponding prototype function param");
+ }
+ }
+ }
+ break;
+ }
+
+ default:
+ llvm_unreachable("coro.begin is not dependent on a coro.id call");
+ }
+
// The coro.free intrinsic is always lowered to the result of coro.begin.
for (CoroFrameInst *CF : CoroFrames) {
CF->replaceAllUsesWith(CoroBegin);
CF->eraseFromParent();
}
- // Canonicalize coro.suspend by inserting a coro.save if needed.
- for (CoroSuspendInst *CS : CoroSuspends)
- if (!CS->getCoroSave())
- createCoroSave(CoroBegin, CS);
-
// Move final suspend to be the last element in the CoroSuspends vector.
- if (HasFinalSuspend &&
+ if (ABI == coro::ABI::Switch &&
+ SwitchLowering.HasFinalSuspend &&
FinalSuspendIndex != CoroSuspends.size() - 1)
std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
@@ -345,6 +485,154 @@ void coro::Shape::buildFrom(Function &F) {
CoroSave->eraseFromParent();
}
+static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
+ Call->setCallingConv(Callee->getCallingConv());
+ // TODO: attributes?
+}
+
+static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
+ if (CG)
+ (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
+}
+
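+// Emit a call to the ABI's allocation function to allocate Size bytes for the
+// coroutine frame. Only the returned-continuation ABIs carry an allocator.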
+Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
+ CallGraph *CG) const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+ llvm_unreachable("can't allocate memory in coro switch-lowering");
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ auto Alloc = RetconLowering.Alloc;
+ Size = Builder.CreateIntCast(Size,
+ Alloc->getFunctionType()->getParamType(0),
+ /*is signed*/ false);
+ auto *Call = Builder.CreateCall(Alloc, Size);
+ propagateCallAttrsFromCallee(Call, Alloc);
+ addCallToCallGraph(CG, Call, Alloc);
+ return Call;
+ }
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+}
+
+void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
+ CallGraph *CG) const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+    llvm_unreachable("can't deallocate memory in coro switch-lowering");
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ auto Dealloc = RetconLowering.Dealloc;
+ Ptr = Builder.CreateBitCast(Ptr,
+ Dealloc->getFunctionType()->getParamType(0));
+ auto *Call = Builder.CreateCall(Dealloc, Ptr);
+ propagateCallAttrsFromCallee(Call, Dealloc);
+ addCallToCallGraph(CG, Call, Dealloc);
+ return;
+ }
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+}
+
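+// Report a malformed-coroutine error; in debug builds, dump the offending
+// instruction (and value, if provided) before aborting.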
+LLVM_ATTRIBUTE_NORETURN
+static void fail(const Instruction *I, const char *Reason, Value *V) {
+#ifndef NDEBUG
+ I->dump();
+ if (V) {
+ errs() << " Value: ";
+ V->printAsOperand(llvm::errs());
+ errs() << '\n';
+ }
+#endif
+ report_fatal_error(Reason);
+}
+
+/// Check that the given value is a well-formed prototype for the
+/// llvm.coro.id.retcon.* intrinsics.
+static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
+ auto F = dyn_cast<Function>(V->stripPointerCasts());
+ if (!F)
+ fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
+
+ auto FT = F->getFunctionType();
+
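+  // For llvm.coro.id.retcon the prototype must return a pointer as its first
+  // result (directly or as the first struct element) and must have the same
+  // return type as the current function.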
+ if (isa<CoroIdRetconInst>(I)) {
+ bool ResultOkay;
+ if (FT->getReturnType()->isPointerTy()) {
+ ResultOkay = true;
+ } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
+ ResultOkay = (!SRetTy->isOpaque() &&
+ SRetTy->getNumElements() > 0 &&
+ SRetTy->getElementType(0)->isPointerTy());
+ } else {
+ ResultOkay = false;
+ }
+ if (!ResultOkay)
+ fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
+ "result", F);
+
+ if (FT->getReturnType() !=
+ I->getFunction()->getFunctionType()->getReturnType())
+      fail(I, "llvm.coro.id.retcon prototype return type must be the same as "
+              "the current function's return type", F);
+ } else {
+    // No meaningful validation to do here for llvm.coro.id.retcon.once.
+ }
+
+ if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
+ fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
+ "its first parameter", F);
+}
+
+/// Check that the given value is a well-formed allocator.
+static void checkWFAlloc(const Instruction *I, Value *V) {
+ auto F = dyn_cast<Function>(V->stripPointerCasts());
+ if (!F)
+ fail(I, "llvm.coro.* allocator not a Function", V);
+
+ auto FT = F->getFunctionType();
+ if (!FT->getReturnType()->isPointerTy())
+ fail(I, "llvm.coro.* allocator must return a pointer", F);
+
+ if (FT->getNumParams() != 1 ||
+ !FT->getParamType(0)->isIntegerTy())
+ fail(I, "llvm.coro.* allocator must take integer as only param", F);
+}
+
+/// Check that the given value is a well-formed deallocator.
+static void checkWFDealloc(const Instruction *I, Value *V) {
+ auto F = dyn_cast<Function>(V->stripPointerCasts());
+ if (!F)
+ fail(I, "llvm.coro.* deallocator not a Function", V);
+
+ auto FT = F->getFunctionType();
+ if (!FT->getReturnType()->isVoidTy())
+ fail(I, "llvm.coro.* deallocator must return void", F);
+
+ if (FT->getNumParams() != 1 ||
+ !FT->getParamType(0)->isPointerTy())
+ fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
+}
+
+static void checkConstantInt(const Instruction *I, Value *V,
+ const char *Reason) {
+ if (!isa<ConstantInt>(V)) {
+ fail(I, Reason, V);
+ }
+}
+
+void AnyCoroIdRetconInst::checkWellFormed() const {
+ checkConstantInt(this, getArgOperand(SizeArg),
+ "size argument to coro.id.retcon.* must be constant");
+ checkConstantInt(this, getArgOperand(AlignArg),
+ "alignment argument to coro.id.retcon.* must be constant");
+ checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
+ checkWFAlloc(this, getArgOperand(AllocArg));
+ checkWFDealloc(this, getArgOperand(DeallocArg));
+}
+
void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createCoroEarlyPass());
}
diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp
index 95a9f31cced3..dd9f74a881ee 100644
--- a/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -304,7 +304,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote,
// of the previous load.
LoadInst *newLoad =
IRB.CreateLoad(OrigLoad->getType(), V, V->getName() + ".val");
- newLoad->setAlignment(OrigLoad->getAlignment());
+ newLoad->setAlignment(MaybeAlign(OrigLoad->getAlignment()));
// Transfer the AA info too.
AAMDNodes AAInfo;
OrigLoad->getAAMetadata(AAInfo);
diff --git a/lib/Transforms/IPO/Attributor.cpp b/lib/Transforms/IPO/Attributor.cpp
index 2a52c6b9b4ad..95f47345d8fd 100644
--- a/lib/Transforms/IPO/Attributor.cpp
+++ b/lib/Transforms/IPO/Attributor.cpp
@@ -16,11 +16,15 @@
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/ADT/DepthFirstIterator.h"
-#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
@@ -30,6 +34,9 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+
#include <cassert>
using namespace llvm;
@@ -46,19 +53,50 @@ STATISTIC(NumAttributesValidFixpoint,
"Number of abstract attributes in a valid fixpoint state");
STATISTIC(NumAttributesManifested,
"Number of abstract attributes manifested in IR");
-STATISTIC(NumFnNoUnwind, "Number of functions marked nounwind");
-
-STATISTIC(NumFnUniqueReturned, "Number of function with unique return");
-STATISTIC(NumFnKnownReturns, "Number of function with known return values");
-STATISTIC(NumFnArgumentReturned,
- "Number of function arguments marked returned");
-STATISTIC(NumFnNoSync, "Number of functions marked nosync");
-STATISTIC(NumFnNoFree, "Number of functions marked nofree");
-STATISTIC(NumFnReturnedNonNull,
- "Number of function return values marked nonnull");
-STATISTIC(NumFnArgumentNonNull, "Number of function arguments marked nonnull");
-STATISTIC(NumCSArgumentNonNull, "Number of call site arguments marked nonnull");
-STATISTIC(NumFnWillReturn, "Number of functions marked willreturn");
+
+// Some helper macros to deal with statistics tracking.
+//
+// Usage:
+// For simple IR attribute tracking, overload trackStatistics in the abstract
+// attribute and choose the right STATS_DECLTRACK_********* macro,
+// e.g.,:
+// void trackStatistics() const override {
+// STATS_DECLTRACK_ARG_ATTR(returned)
+// }
+// If there is a single "increment" side, one can use the macro
+// STATS_DECLTRACK with a custom message. If there are multiple increment
+// sides, STATS_DECL and STATS_TRACK can also be used separately.
+//
+#define BUILD_STAT_MSG_IR_ATTR(TYPE, NAME) \
+ ("Number of " #TYPE " marked '" #NAME "'")
+#define BUILD_STAT_NAME(NAME, TYPE) NumIR##TYPE##_##NAME
+#define STATS_DECL_(NAME, MSG) STATISTIC(NAME, MSG);
+#define STATS_DECL(NAME, TYPE, MSG) \
+ STATS_DECL_(BUILD_STAT_NAME(NAME, TYPE), MSG);
+#define STATS_TRACK(NAME, TYPE) ++(BUILD_STAT_NAME(NAME, TYPE));
+#define STATS_DECLTRACK(NAME, TYPE, MSG) \
+ { \
+ STATS_DECL(NAME, TYPE, MSG) \
+ STATS_TRACK(NAME, TYPE) \
+ }
+#define STATS_DECLTRACK_ARG_ATTR(NAME) \
+ STATS_DECLTRACK(NAME, Arguments, BUILD_STAT_MSG_IR_ATTR(arguments, NAME))
+#define STATS_DECLTRACK_CSARG_ATTR(NAME) \
+ STATS_DECLTRACK(NAME, CSArguments, \
+ BUILD_STAT_MSG_IR_ATTR(call site arguments, NAME))
+#define STATS_DECLTRACK_FN_ATTR(NAME) \
+ STATS_DECLTRACK(NAME, Function, BUILD_STAT_MSG_IR_ATTR(functions, NAME))
+#define STATS_DECLTRACK_CS_ATTR(NAME) \
+ STATS_DECLTRACK(NAME, CS, BUILD_STAT_MSG_IR_ATTR(call site, NAME))
+#define STATS_DECLTRACK_FNRET_ATTR(NAME) \
+ STATS_DECLTRACK(NAME, FunctionReturn, \
+ BUILD_STAT_MSG_IR_ATTR(function returns, NAME))
+#define STATS_DECLTRACK_CSRET_ATTR(NAME) \
+ STATS_DECLTRACK(NAME, CSReturn, \
+ BUILD_STAT_MSG_IR_ATTR(call site returns, NAME))
+#define STATS_DECLTRACK_FLOATING_ATTR(NAME) \
+ STATS_DECLTRACK(NAME, Floating, \
+ ("Number of floating values known to be '" #NAME "'"))
// TODO: Determine a good default value.
//
@@ -72,18 +110,32 @@ static cl::opt<unsigned>
MaxFixpointIterations("attributor-max-iterations", cl::Hidden,
cl::desc("Maximal number of fixpoint iterations."),
cl::init(32));
+static cl::opt<bool> VerifyMaxFixpointIterations(
+ "attributor-max-iterations-verify", cl::Hidden,
+ cl::desc("Verify that max-iterations is a tight bound for a fixpoint"),
+ cl::init(false));
static cl::opt<bool> DisableAttributor(
"attributor-disable", cl::Hidden,
cl::desc("Disable the attributor inter-procedural deduction pass."),
cl::init(true));
-static cl::opt<bool> VerifyAttributor(
- "attributor-verify", cl::Hidden,
- cl::desc("Verify the Attributor deduction and "
- "manifestation of attributes -- may issue false-positive errors"),
+static cl::opt<bool> ManifestInternal(
+ "attributor-manifest-internal", cl::Hidden,
+ cl::desc("Manifest Attributor internal string attributes."),
cl::init(false));
+static cl::opt<unsigned> DepRecInterval(
+ "attributor-dependence-recompute-interval", cl::Hidden,
+ cl::desc("Number of iterations until dependences are recomputed."),
+ cl::init(4));
+
+static cl::opt<bool> EnableHeapToStack("enable-heap-to-stack-conversion",
+ cl::init(true), cl::Hidden);
+
+static cl::opt<int> MaxHeapToStackSize("max-heap-to-stack-size", cl::init(128),
+ cl::Hidden);
+
/// Logic operators for the change status enum class.
///
///{
@@ -95,78 +147,30 @@ ChangeStatus llvm::operator&(ChangeStatus l, ChangeStatus r) {
}
///}
-/// Helper to adjust the statistics.
-static void bookkeeping(AbstractAttribute::ManifestPosition MP,
- const Attribute &Attr) {
- if (!AreStatisticsEnabled())
- return;
-
- if (!Attr.isEnumAttribute())
- return;
- switch (Attr.getKindAsEnum()) {
- case Attribute::NoUnwind:
- NumFnNoUnwind++;
- return;
- case Attribute::Returned:
- NumFnArgumentReturned++;
- return;
- case Attribute::NoSync:
- NumFnNoSync++;
- break;
- case Attribute::NoFree:
- NumFnNoFree++;
- break;
- case Attribute::NonNull:
- switch (MP) {
- case AbstractAttribute::MP_RETURNED:
- NumFnReturnedNonNull++;
- break;
- case AbstractAttribute::MP_ARGUMENT:
- NumFnArgumentNonNull++;
- break;
- case AbstractAttribute::MP_CALL_SITE_ARGUMENT:
- NumCSArgumentNonNull++;
- break;
- default:
- break;
- }
- break;
- case Attribute::WillReturn:
- NumFnWillReturn++;
- break;
- default:
- return;
- }
-}
-
-template <typename StateTy>
-using followValueCB_t = std::function<bool(Value *, StateTy &State)>;
-template <typename StateTy>
-using visitValueCB_t = std::function<void(Value *, StateTy &State)>;
-
-/// Recursively visit all values that might become \p InitV at some point. This
+/// Recursively visit all values that might become \p IRP at some point. This
/// will be done by looking through cast instructions, selects, phis, and calls
-/// with the "returned" attribute. The callback \p FollowValueCB is asked before
-/// a potential origin value is looked at. If no \p FollowValueCB is passed, a
-/// default one is used that will make sure we visit every value only once. Once
-/// we cannot look through the value any further, the callback \p VisitValueCB
-/// is invoked and passed the current value and the \p State. To limit how much
-/// effort is invested, we will never visit more than \p MaxValues values.
-template <typename StateTy>
+/// with the "returned" attribute. Once we cannot look through the value any
+/// further, the callback \p VisitValueCB is invoked and passed the current
+/// value, the \p State, and a flag to indicate if we stripped anything. To
+/// limit how much effort is invested, we will never visit more values than
+/// specified by \p MaxValues.
+template <typename AAType, typename StateTy>
static bool genericValueTraversal(
- Value *InitV, StateTy &State, visitValueCB_t<StateTy> &VisitValueCB,
- followValueCB_t<StateTy> *FollowValueCB = nullptr, int MaxValues = 8) {
-
+ Attributor &A, IRPosition IRP, const AAType &QueryingAA, StateTy &State,
+ const function_ref<bool(Value &, StateTy &, bool)> &VisitValueCB,
+ int MaxValues = 8) {
+
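+  // Query liveness for the anchor function so that values reachable only
+  // through dead PHI predecessors can be skipped below.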
+ const AAIsDead *LivenessAA = nullptr;
+ if (IRP.getAnchorScope())
+ LivenessAA = &A.getAAFor<AAIsDead>(
+ QueryingAA, IRPosition::function(*IRP.getAnchorScope()),
+ /* TrackDependence */ false);
+ bool AnyDead = false;
+
+ // TODO: Use Positions here to allow context sensitivity in VisitValueCB
SmallPtrSet<Value *, 16> Visited;
- followValueCB_t<bool> DefaultFollowValueCB = [&](Value *Val, bool &) {
- return Visited.insert(Val).second;
- };
-
- if (!FollowValueCB)
- FollowValueCB = &DefaultFollowValueCB;
-
SmallVector<Value *, 16> Worklist;
- Worklist.push_back(InitV);
+ Worklist.push_back(&IRP.getAssociatedValue());
int Iteration = 0;
do {
@@ -174,7 +178,7 @@ static bool genericValueTraversal(
// Check if we should process the current value. To prevent endless
// recursion keep a record of the values we followed!
- if (!(*FollowValueCB)(V, State))
+ if (!Visited.insert(V).second)
continue;
// Make sure we limit the compile time for complex expressions.
@@ -183,23 +187,23 @@ static bool genericValueTraversal(
// Explicitly look through calls with a "returned" attribute if we do
// not have a pointer as stripPointerCasts only works on them.
+ Value *NewV = nullptr;
if (V->getType()->isPointerTy()) {
- V = V->stripPointerCasts();
+ NewV = V->stripPointerCasts();
} else {
CallSite CS(V);
if (CS && CS.getCalledFunction()) {
- Value *NewV = nullptr;
for (Argument &Arg : CS.getCalledFunction()->args())
if (Arg.hasReturnedAttr()) {
NewV = CS.getArgOperand(Arg.getArgNo());
break;
}
- if (NewV) {
- Worklist.push_back(NewV);
- continue;
- }
}
}
+ if (NewV && NewV != V) {
+ Worklist.push_back(NewV);
+ continue;
+ }
// Look through select instructions, visit both potential values.
if (auto *SI = dyn_cast<SelectInst>(V)) {
@@ -208,35 +212,34 @@ static bool genericValueTraversal(
continue;
}
- // Look through phi nodes, visit all operands.
+ // Look through phi nodes, visit all live operands.
if (auto *PHI = dyn_cast<PHINode>(V)) {
- Worklist.append(PHI->op_begin(), PHI->op_end());
+ assert(LivenessAA &&
+ "Expected liveness in the presence of instructions!");
+ for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) {
+ const BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
+ if (LivenessAA->isAssumedDead(IncomingBB->getTerminator())) {
+ AnyDead = true;
+ continue;
+ }
+ Worklist.push_back(PHI->getIncomingValue(u));
+ }
continue;
}
// Once a leaf is reached we inform the user through the callback.
- VisitValueCB(V, State);
+ if (!VisitValueCB(*V, State, Iteration > 1))
+ return false;
} while (!Worklist.empty());
+  // If we actually used liveness information, we have to record a dependence.
+ if (AnyDead)
+ A.recordDependence(*LivenessAA, QueryingAA);
+
// All values have been visited.
return true;
}
-/// Helper to identify the correct offset into an attribute list.
-static unsigned getAttrIndex(AbstractAttribute::ManifestPosition MP,
- unsigned ArgNo = 0) {
- switch (MP) {
- case AbstractAttribute::MP_ARGUMENT:
- case AbstractAttribute::MP_CALL_SITE_ARGUMENT:
- return ArgNo + AttributeList::FirstArgIndex;
- case AbstractAttribute::MP_FUNCTION:
- return AttributeList::FunctionIndex;
- case AbstractAttribute::MP_RETURNED:
- return AttributeList::ReturnIndex;
- }
- llvm_unreachable("Unknown manifest position!");
-}
-
/// Return true if \p New is equal or worse than \p Old.
static bool isEqualOrWorse(const Attribute &New, const Attribute &Old) {
if (!Old.isIntAttribute())
@@ -247,12 +250,9 @@ static bool isEqualOrWorse(const Attribute &New, const Attribute &Old) {
/// Return true if the information provided by \p Attr was added to the
/// attribute list \p Attrs. This is only the case if it was not already present
-/// in \p Attrs at the position describe by \p MP and \p ArgNo.
+/// in \p Attrs at the position described by \p AttrIdx.
static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr,
- AttributeList &Attrs,
- AbstractAttribute::ManifestPosition MP,
- unsigned ArgNo = 0) {
- unsigned AttrIdx = getAttrIndex(MP, ArgNo);
+ AttributeList &Attrs, int AttrIdx) {
if (Attr.isEnumAttribute()) {
Attribute::AttrKind Kind = Attr.getKindAsEnum();
@@ -270,9 +270,47 @@ static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr,
Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr);
return true;
}
+ if (Attr.isIntAttribute()) {
+ Attribute::AttrKind Kind = Attr.getKindAsEnum();
+ if (Attrs.hasAttribute(AttrIdx, Kind))
+ if (isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind)))
+ return false;
+ Attrs = Attrs.removeAttribute(Ctx, AttrIdx, Kind);
+ Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr);
+ return true;
+ }
llvm_unreachable("Expected enum or string attribute!");
}
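+/// Return the pointer operand of a non-volatile load, store, cmpxchg, or
+/// atomicrmw instruction, or nullptr otherwise.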
+static const Value *getPointerOperand(const Instruction *I) {
+ if (auto *LI = dyn_cast<LoadInst>(I))
+ if (!LI->isVolatile())
+ return LI->getPointerOperand();
+
+ if (auto *SI = dyn_cast<StoreInst>(I))
+ if (!SI->isVolatile())
+ return SI->getPointerOperand();
+
+ if (auto *CXI = dyn_cast<AtomicCmpXchgInst>(I))
+ if (!CXI->isVolatile())
+ return CXI->getPointerOperand();
+
+ if (auto *RMWI = dyn_cast<AtomicRMWInst>(I))
+ if (!RMWI->isVolatile())
+ return RMWI->getPointerOperand();
+
+ return nullptr;
+}
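+
+/// Return the base pointer underlying the access in \p I and the constant
+/// byte offset from it in \p BytesOffset, or nullptr if \p I is not a
+/// non-volatile memory access.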
+static const Value *getBasePointerOfAccessPointerOperand(const Instruction *I,
+ int64_t &BytesOffset,
+ const DataLayout &DL) {
+ const Value *Ptr = getPointerOperand(I);
+ if (!Ptr)
+ return nullptr;
+
+ return GetPointerBaseWithConstantOffset(Ptr, BytesOffset, DL,
+ /*AllowNonInbounds*/ false);
+}
ChangeStatus AbstractAttribute::update(Attributor &A) {
ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
@@ -289,143 +327,527 @@ ChangeStatus AbstractAttribute::update(Attributor &A) {
return HasChanged;
}
-ChangeStatus AbstractAttribute::manifest(Attributor &A) {
- assert(getState().isValidState() &&
- "Attempted to manifest an invalid state!");
- assert(getAssociatedValue() &&
- "Attempted to manifest an attribute without associated value!");
-
- ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
- SmallVector<Attribute, 4> DeducedAttrs;
- getDeducedAttributes(DeducedAttrs);
-
- Function &ScopeFn = getAnchorScope();
- LLVMContext &Ctx = ScopeFn.getContext();
- ManifestPosition MP = getManifestPosition();
-
- AttributeList Attrs;
- SmallVector<unsigned, 4> ArgNos;
+ChangeStatus
+IRAttributeManifest::manifestAttrs(Attributor &A, IRPosition &IRP,
+ const ArrayRef<Attribute> &DeducedAttrs) {
+ Function *ScopeFn = IRP.getAssociatedFunction();
+ IRPosition::Kind PK = IRP.getPositionKind();
// In the following some generic code that will manifest attributes in
// DeducedAttrs if they improve the current IR. Due to the different
// annotation positions we use the underlying AttributeList interface.
- // Note that MP_CALL_SITE_ARGUMENT can annotate multiple locations.
- switch (MP) {
- case MP_ARGUMENT:
- ArgNos.push_back(cast<Argument>(getAssociatedValue())->getArgNo());
- Attrs = ScopeFn.getAttributes();
+ AttributeList Attrs;
+ switch (PK) {
+ case IRPosition::IRP_INVALID:
+ case IRPosition::IRP_FLOAT:
+ return ChangeStatus::UNCHANGED;
+ case IRPosition::IRP_ARGUMENT:
+ case IRPosition::IRP_FUNCTION:
+ case IRPosition::IRP_RETURNED:
+ Attrs = ScopeFn->getAttributes();
break;
- case MP_FUNCTION:
- case MP_RETURNED:
- ArgNos.push_back(0);
- Attrs = ScopeFn.getAttributes();
+ case IRPosition::IRP_CALL_SITE:
+ case IRPosition::IRP_CALL_SITE_RETURNED:
+ case IRPosition::IRP_CALL_SITE_ARGUMENT:
+ Attrs = ImmutableCallSite(&IRP.getAnchorValue()).getAttributes();
break;
- case MP_CALL_SITE_ARGUMENT: {
- CallSite CS(&getAnchoredValue());
- for (unsigned u = 0, e = CS.getNumArgOperands(); u != e; u++)
- if (CS.getArgOperand(u) == getAssociatedValue())
- ArgNos.push_back(u);
- Attrs = CS.getAttributes();
- }
}
+ ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
+ LLVMContext &Ctx = IRP.getAnchorValue().getContext();
for (const Attribute &Attr : DeducedAttrs) {
- for (unsigned ArgNo : ArgNos) {
- if (!addIfNotExistent(Ctx, Attr, Attrs, MP, ArgNo))
- continue;
+ if (!addIfNotExistent(Ctx, Attr, Attrs, IRP.getAttrIdx()))
+ continue;
- HasChanged = ChangeStatus::CHANGED;
- bookkeeping(MP, Attr);
- }
+ HasChanged = ChangeStatus::CHANGED;
}
if (HasChanged == ChangeStatus::UNCHANGED)
return HasChanged;
- switch (MP) {
- case MP_ARGUMENT:
- case MP_FUNCTION:
- case MP_RETURNED:
- ScopeFn.setAttributes(Attrs);
+ switch (PK) {
+ case IRPosition::IRP_ARGUMENT:
+ case IRPosition::IRP_FUNCTION:
+ case IRPosition::IRP_RETURNED:
+ ScopeFn->setAttributes(Attrs);
+ break;
+ case IRPosition::IRP_CALL_SITE:
+ case IRPosition::IRP_CALL_SITE_RETURNED:
+ case IRPosition::IRP_CALL_SITE_ARGUMENT:
+ CallSite(&IRP.getAnchorValue()).setAttributes(Attrs);
+ break;
+ case IRPosition::IRP_INVALID:
+ case IRPosition::IRP_FLOAT:
break;
- case MP_CALL_SITE_ARGUMENT:
- CallSite(&getAnchoredValue()).setAttributes(Attrs);
}
return HasChanged;
}
-Function &AbstractAttribute::getAnchorScope() {
- Value &V = getAnchoredValue();
- if (isa<Function>(V))
- return cast<Function>(V);
- if (isa<Argument>(V))
- return *cast<Argument>(V).getParent();
- if (isa<Instruction>(V))
- return *cast<Instruction>(V).getFunction();
- llvm_unreachable("No scope for anchored value found!");
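+// Sentinel positions used as the DenseMap empty and tombstone keys.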
+const IRPosition IRPosition::EmptyKey(255);
+const IRPosition IRPosition::TombstoneKey(256);
+
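+// Collect the positions that subsume \p IRP, starting with \p IRP itself and
+// adding more generic positions (e.g., the callee function) depending on the
+// position kind.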
+SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) {
+ IRPositions.emplace_back(IRP);
+
+ ImmutableCallSite ICS(&IRP.getAnchorValue());
+ switch (IRP.getPositionKind()) {
+ case IRPosition::IRP_INVALID:
+ case IRPosition::IRP_FLOAT:
+ case IRPosition::IRP_FUNCTION:
+ return;
+ case IRPosition::IRP_ARGUMENT:
+ case IRPosition::IRP_RETURNED:
+ IRPositions.emplace_back(
+ IRPosition::function(*IRP.getAssociatedFunction()));
+ return;
+ case IRPosition::IRP_CALL_SITE:
+ assert(ICS && "Expected call site!");
+ // TODO: We need to look at the operand bundles similar to the redirection
+ // in CallBase.
+ if (!ICS.hasOperandBundles())
+ if (const Function *Callee = ICS.getCalledFunction())
+ IRPositions.emplace_back(IRPosition::function(*Callee));
+ return;
+ case IRPosition::IRP_CALL_SITE_RETURNED:
+ assert(ICS && "Expected call site!");
+ // TODO: We need to look at the operand bundles similar to the redirection
+ // in CallBase.
+ if (!ICS.hasOperandBundles()) {
+ if (const Function *Callee = ICS.getCalledFunction()) {
+ IRPositions.emplace_back(IRPosition::returned(*Callee));
+ IRPositions.emplace_back(IRPosition::function(*Callee));
+ }
+ }
+ IRPositions.emplace_back(
+ IRPosition::callsite_function(cast<CallBase>(*ICS.getInstruction())));
+ return;
+ case IRPosition::IRP_CALL_SITE_ARGUMENT: {
+ int ArgNo = IRP.getArgNo();
+ assert(ICS && ArgNo >= 0 && "Expected call site!");
+ // TODO: We need to look at the operand bundles similar to the redirection
+ // in CallBase.
+ if (!ICS.hasOperandBundles()) {
+ const Function *Callee = ICS.getCalledFunction();
+ if (Callee && Callee->arg_size() > unsigned(ArgNo))
+ IRPositions.emplace_back(IRPosition::argument(*Callee->getArg(ArgNo)));
+ if (Callee)
+ IRPositions.emplace_back(IRPosition::function(*Callee));
+ }
+ IRPositions.emplace_back(IRPosition::value(IRP.getAssociatedValue()));
+ return;
+ }
+ }
+}
+
+bool IRPosition::hasAttr(ArrayRef<Attribute::AttrKind> AKs,
+ bool IgnoreSubsumingPositions) const {
+ for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this)) {
+ for (Attribute::AttrKind AK : AKs)
+ if (EquivIRP.getAttr(AK).getKindAsEnum() == AK)
+ return true;
+ // The first position returned by the SubsumingPositionIterator is
+ // always the position itself. If we ignore subsuming positions we
+ // are done after the first iteration.
+ if (IgnoreSubsumingPositions)
+ break;
+ }
+ return false;
}
-const Function &AbstractAttribute::getAnchorScope() const {
- return const_cast<AbstractAttribute *>(this)->getAnchorScope();
+void IRPosition::getAttrs(ArrayRef<Attribute::AttrKind> AKs,
+ SmallVectorImpl<Attribute> &Attrs) const {
+ for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this))
+ for (Attribute::AttrKind AK : AKs) {
+ const Attribute &Attr = EquivIRP.getAttr(AK);
+ if (Attr.getKindAsEnum() == AK)
+ Attrs.push_back(Attr);
+ }
}
-/// -----------------------NoUnwind Function Attribute--------------------------
+void IRPosition::verify() {
+ switch (KindOrArgNo) {
+ default:
+ assert(KindOrArgNo >= 0 && "Expected argument or call site argument!");
+ assert((isa<CallBase>(AnchorVal) || isa<Argument>(AnchorVal)) &&
+ "Expected call base or argument for positive attribute index!");
+ if (isa<Argument>(AnchorVal)) {
+ assert(cast<Argument>(AnchorVal)->getArgNo() == unsigned(getArgNo()) &&
+ "Argument number mismatch!");
+ assert(cast<Argument>(AnchorVal) == &getAssociatedValue() &&
+ "Associated value mismatch!");
+ } else {
+ assert(cast<CallBase>(*AnchorVal).arg_size() > unsigned(getArgNo()) &&
+ "Call site argument number mismatch!");
+ assert(cast<CallBase>(*AnchorVal).getArgOperand(getArgNo()) ==
+ &getAssociatedValue() &&
+ "Associated value mismatch!");
+ }
+ break;
+ case IRP_INVALID:
+ assert(!AnchorVal && "Expected no value for an invalid position!");
+ break;
+ case IRP_FLOAT:
+ assert((!isa<CallBase>(&getAssociatedValue()) &&
+ !isa<Argument>(&getAssociatedValue())) &&
+ "Expected specialized kind for call base and argument values!");
+ break;
+ case IRP_RETURNED:
+ assert(isa<Function>(AnchorVal) &&
+ "Expected function for a 'returned' position!");
+ assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!");
+ break;
+ case IRP_CALL_SITE_RETURNED:
+ assert((isa<CallBase>(AnchorVal)) &&
+ "Expected call base for 'call site returned' position!");
+ assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!");
+ break;
+ case IRP_CALL_SITE:
+ assert((isa<CallBase>(AnchorVal)) &&
+ "Expected call base for 'call site function' position!");
+ assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!");
+ break;
+ case IRP_FUNCTION:
+ assert(isa<Function>(AnchorVal) &&
+ "Expected function for a 'function' position!");
+ assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!");
+ break;
+ }
+}
+
+namespace {
+/// Helper functions to clamp a state \p S of type \p StateType with the
+/// information in \p R and indicate/return if \p S did change (as-in update is
+/// required to be run again).
+///
+///{
+template <typename StateType>
+ChangeStatus clampStateAndIndicateChange(StateType &S, const StateType &R);
+
+template <>
+ChangeStatus clampStateAndIndicateChange<IntegerState>(IntegerState &S,
+ const IntegerState &R) {
+ auto Assumed = S.getAssumed();
+ S ^= R;
+ return Assumed == S.getAssumed() ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+}
-struct AANoUnwindFunction : AANoUnwind, BooleanState {
+template <>
+ChangeStatus clampStateAndIndicateChange<BooleanState>(BooleanState &S,
+ const BooleanState &R) {
+ return clampStateAndIndicateChange<IntegerState>(S, R);
+}
+///}
- AANoUnwindFunction(Function &F, InformationCache &InfoCache)
- : AANoUnwind(F, InfoCache) {}
+/// Clamp the information known for all returned values of a function
+/// (identified by \p QueryingAA) into \p S.
+template <typename AAType, typename StateType = typename AAType::StateType>
+static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA,
+ StateType &S) {
+ LLVM_DEBUG(dbgs() << "[Attributor] Clamp return value states for "
+ << static_cast<const AbstractAttribute &>(QueryingAA)
+ << " into " << S << "\n");
+
+ assert((QueryingAA.getIRPosition().getPositionKind() ==
+ IRPosition::IRP_RETURNED ||
+ QueryingAA.getIRPosition().getPositionKind() ==
+ IRPosition::IRP_CALL_SITE_RETURNED) &&
+ "Can only clamp returned value states for a function returned or call "
+ "site returned position!");
+
+  // Use an optional state as there might not be any return values, and we want
+  // to join (IntegerState::operator&) the states of all there are.
+ Optional<StateType> T;
+
+ // Callback for each possibly returned value.
+ auto CheckReturnValue = [&](Value &RV) -> bool {
+ const IRPosition &RVPos = IRPosition::value(RV);
+ const AAType &AA = A.getAAFor<AAType>(QueryingAA, RVPos);
+ LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr()
+ << " @ " << RVPos << "\n");
+ const StateType &AAS = static_cast<const StateType &>(AA.getState());
+ if (T.hasValue())
+ *T &= AAS;
+ else
+ T = AAS;
+ LLVM_DEBUG(dbgs() << "[Attributor] AA State: " << AAS << " RV State: " << T
+ << "\n");
+ return T->isValidState();
+ };
- /// See AbstractAttribute::getState()
- /// {
- AbstractState &getState() override { return *this; }
- const AbstractState &getState() const override { return *this; }
- /// }
+ if (!A.checkForAllReturnedValues(CheckReturnValue, QueryingAA))
+ S.indicatePessimisticFixpoint();
+ else if (T.hasValue())
+ S ^= *T;
+}
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override { return MP_FUNCTION; }
+/// Helper class to compose two generic deductions.
+template <typename AAType, typename Base, typename StateType,
+ template <typename...> class F, template <typename...> class G>
+struct AAComposeTwoGenericDeduction
+ : public F<AAType, G<AAType, Base, StateType>, StateType> {
+ AAComposeTwoGenericDeduction(const IRPosition &IRP)
+ : F<AAType, G<AAType, Base, StateType>, StateType>(IRP) {}
- const std::string getAsStr() const override {
- return getAssumed() ? "nounwind" : "may-unwind";
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ ChangeStatus ChangedF = F<AAType, G<AAType, Base, StateType>, StateType>::updateImpl(A);
+ ChangeStatus ChangedG = G<AAType, Base, StateType>::updateImpl(A);
+ return ChangedF | ChangedG;
}
+};
+
+/// Helper class for generic deduction: return value -> returned position.
+template <typename AAType, typename Base,
+ typename StateType = typename AAType::StateType>
+struct AAReturnedFromReturnedValues : public Base {
+ AAReturnedFromReturnedValues(const IRPosition &IRP) : Base(IRP) {}
/// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override;
+ ChangeStatus updateImpl(Attributor &A) override {
+ StateType S;
+ clampReturnedValueStates<AAType, StateType>(A, *this, S);
+    // TODO: If we know we visited all returned values, thus none are assumed
+    // dead, we can take the known information from the state T.
+ return clampStateAndIndicateChange<StateType>(this->getState(), S);
+ }
+};
- /// See AANoUnwind::isAssumedNoUnwind().
- bool isAssumedNoUnwind() const override { return getAssumed(); }
+/// Clamp the information known at all call sites for a given argument
+/// (identified by \p QueryingAA) into \p S.
+template <typename AAType, typename StateType = typename AAType::StateType>
+static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
+ StateType &S) {
+ LLVM_DEBUG(dbgs() << "[Attributor] Clamp call site argument states for "
+ << static_cast<const AbstractAttribute &>(QueryingAA)
+ << " into " << S << "\n");
+
+ assert(QueryingAA.getIRPosition().getPositionKind() ==
+ IRPosition::IRP_ARGUMENT &&
+ "Can only clamp call site argument states for an argument position!");
+
+  // Use an optional state as there might not be any call sites, and we want
+  // to join (IntegerState::operator&) the states of all there are.
+ Optional<StateType> T;
+
+ // The argument number which is also the call site argument number.
+ unsigned ArgNo = QueryingAA.getIRPosition().getArgNo();
+
+ auto CallSiteCheck = [&](AbstractCallSite ACS) {
+ const IRPosition &ACSArgPos = IRPosition::callsite_argument(ACS, ArgNo);
+    // Check if a corresponding argument was found or if it is not associated
+ // (which can happen for callback calls).
+ if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID)
+ return false;
- /// See AANoUnwind::isKnownNoUnwind().
- bool isKnownNoUnwind() const override { return getKnown(); }
+ const AAType &AA = A.getAAFor<AAType>(QueryingAA, ACSArgPos);
+ LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction()
+ << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n");
+ const StateType &AAS = static_cast<const StateType &>(AA.getState());
+ if (T.hasValue())
+ *T &= AAS;
+ else
+ T = AAS;
+ LLVM_DEBUG(dbgs() << "[Attributor] AA State: " << AAS << " CSA State: " << T
+ << "\n");
+ return T->isValidState();
+ };
+
+ if (!A.checkForAllCallSites(CallSiteCheck, QueryingAA, true))
+ S.indicatePessimisticFixpoint();
+ else if (T.hasValue())
+ S ^= *T;
+}
+
+/// Helper class for generic deduction: call site argument -> argument position.
+template <typename AAType, typename Base,
+ typename StateType = typename AAType::StateType>
+struct AAArgumentFromCallSiteArguments : public Base {
+ AAArgumentFromCallSiteArguments(const IRPosition &IRP) : Base(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ StateType S;
+ clampCallSiteArgumentStates<AAType, StateType>(A, *this, S);
+    // TODO: If we know we visited all incoming values, thus none are assumed
+    // dead, we can take the known information from the state T.
+ return clampStateAndIndicateChange<StateType>(this->getState(), S);
+ }
};
-ChangeStatus AANoUnwindFunction::updateImpl(Attributor &A) {
- Function &F = getAnchorScope();
+/// Helper class for generic replication: function returned -> cs returned.
+template <typename AAType, typename Base,
+ typename StateType = typename AAType::StateType>
+struct AACallSiteReturnedFromReturned : public Base {
+ AACallSiteReturnedFromReturned(const IRPosition &IRP) : Base(IRP) {}
- // The map from instruction opcodes to those instructions in the function.
- auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
- auto Opcodes = {
- (unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
- (unsigned)Instruction::Call, (unsigned)Instruction::CleanupRet,
- (unsigned)Instruction::CatchSwitch, (unsigned)Instruction::Resume};
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ assert(this->getIRPosition().getPositionKind() ==
+ IRPosition::IRP_CALL_SITE_RETURNED &&
+ "Can only wrap function returned positions for call site returned "
+ "positions!");
+ auto &S = this->getState();
+
+ const Function *AssociatedFunction =
+ this->getIRPosition().getAssociatedFunction();
+ if (!AssociatedFunction)
+ return S.indicatePessimisticFixpoint();
+
+ IRPosition FnPos = IRPosition::returned(*AssociatedFunction);
+ const AAType &AA = A.getAAFor<AAType>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ S, static_cast<const typename AAType::StateType &>(AA.getState()));
+ }
+};
- for (unsigned Opcode : Opcodes) {
- for (Instruction *I : OpcodeInstMap[Opcode]) {
- if (!I->mayThrow())
- continue;
+/// Helper class for generic deduction using the must-be-executed-context.
+/// The Base class is required to have a `followUse` method.
- auto *NoUnwindAA = A.getAAFor<AANoUnwind>(*this, *I);
+/// bool followUse(Attributor &A, const Use *U, const Instruction *I)
+/// U - Underlying use.
+/// I - The user of the \p U.
+/// `followUse` returns true if the value should be tracked transitively.
- if (!NoUnwindAA || !NoUnwindAA->isAssumedNoUnwind()) {
- indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
+template <typename AAType, typename Base,
+ typename StateType = typename AAType::StateType>
+struct AAFromMustBeExecutedContext : public Base {
+ AAFromMustBeExecutedContext(const IRPosition &IRP) : Base(IRP) {}
+
+ void initialize(Attributor &A) override {
+ Base::initialize(A);
+ IRPosition &IRP = this->getIRPosition();
+ Instruction *CtxI = IRP.getCtxI();
+
+ if (!CtxI)
+ return;
+
+ for (const Use &U : IRP.getAssociatedValue().uses())
+ Uses.insert(&U);
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ auto BeforeState = this->getState();
+ auto &S = this->getState();
+ Instruction *CtxI = this->getIRPosition().getCtxI();
+ if (!CtxI)
+ return ChangeStatus::UNCHANGED;
+
+ MustBeExecutedContextExplorer &Explorer =
+ A.getInfoCache().getMustBeExecutedContextExplorer();
+
+ SetVector<const Use *> NextUses;
+
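+    // For each recorded use, check whether its user is executed in the
+    // must-be-executed context of CtxI; if so, let the base class inspect the
+    // use and transitively queue the user's own uses.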
+ for (const Use *U : Uses) {
+ if (const Instruction *UserI = dyn_cast<Instruction>(U->getUser())) {
+ auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI);
+ bool Found = EIt.count(UserI);
+ while (!Found && ++EIt != EEnd)
+ Found = EIt.getCurrentInst() == UserI;
+ if (Found && Base::followUse(A, U, UserI))
+ for (const Use &Us : UserI->uses())
+ NextUses.insert(&Us);
}
}
+ for (const Use *U : NextUses)
+ Uses.insert(U);
+
+ return BeforeState == S ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED;
}
- return ChangeStatus::UNCHANGED;
-}
+
+private:
+ /// Container for (transitive) uses of the associated value.
+ SetVector<const Use *> Uses;
+};
+
+template <typename AAType, typename Base,
+ typename StateType = typename AAType::StateType>
+using AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext =
+ AAComposeTwoGenericDeduction<AAType, Base, StateType,
+ AAFromMustBeExecutedContext,
+ AAArgumentFromCallSiteArguments>;
+
+template <typename AAType, typename Base,
+ typename StateType = typename AAType::StateType>
+using AACallSiteReturnedFromReturnedAndMustBeExecutedContext =
+ AAComposeTwoGenericDeduction<AAType, Base, StateType,
+ AAFromMustBeExecutedContext,
+ AACallSiteReturnedFromReturned>;
+
+/// -----------------------NoUnwind Function Attribute--------------------------
+
+struct AANoUnwindImpl : AANoUnwind {
+ AANoUnwindImpl(const IRPosition &IRP) : AANoUnwind(IRP) {}
+
+ const std::string getAsStr() const override {
+ return getAssumed() ? "nounwind" : "may-unwind";
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ auto Opcodes = {
+ (unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
+ (unsigned)Instruction::Call, (unsigned)Instruction::CleanupRet,
+ (unsigned)Instruction::CatchSwitch, (unsigned)Instruction::Resume};
+
+ auto CheckForNoUnwind = [&](Instruction &I) {
+ if (!I.mayThrow())
+ return true;
+
+ if (ImmutableCallSite ICS = ImmutableCallSite(&I)) {
+ const auto &NoUnwindAA =
+ A.getAAFor<AANoUnwind>(*this, IRPosition::callsite_function(ICS));
+ return NoUnwindAA.isAssumedNoUnwind();
+ }
+ return false;
+ };
+
+ if (!A.checkForAllInstructions(CheckForNoUnwind, *this, Opcodes))
+ return indicatePessimisticFixpoint();
+
+ return ChangeStatus::UNCHANGED;
+ }
+};
+
+struct AANoUnwindFunction final : public AANoUnwindImpl {
+ AANoUnwindFunction(const IRPosition &IRP) : AANoUnwindImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(nounwind) }
+};
+
+/// NoUnwind attribute deduction for a call site.
+struct AANoUnwindCallSite final : AANoUnwindImpl {
+ AANoUnwindCallSite(const IRPosition &IRP) : AANoUnwindImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AANoUnwindImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+    //       sense to specialize attributes for call site arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto &FnAA = A.getAAFor<AANoUnwind>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(),
+ static_cast<const AANoUnwind::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); }
+};
/// --------------------- Function Return Values -------------------------------
@@ -434,68 +856,48 @@ ChangeStatus AANoUnwindFunction::updateImpl(Attributor &A) {
///
/// If there is a unique returned value R, the manifest method will:
/// - mark R with the "returned" attribute, if R is an argument.
-class AAReturnedValuesImpl final : public AAReturnedValues, AbstractState {
+class AAReturnedValuesImpl : public AAReturnedValues, public AbstractState {
/// Mapping of values potentially returned by the associated function to the
/// return instructions that might return them.
- DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> ReturnedValues;
+ MapVector<Value *, SmallSetVector<ReturnInst *, 4>> ReturnedValues;
+
+ /// Mapping to remember the number of returned values for a call site such
+ /// that we can avoid updates if nothing changed.
+ DenseMap<const CallBase *, unsigned> NumReturnedValuesPerKnownAA;
+
+ /// Set of unresolved calls returned by the associated function.
+ SmallSetVector<CallBase *, 4> UnresolvedCalls;
/// State flags
///
///{
- bool IsFixed;
- bool IsValidState;
- bool HasOverdefinedReturnedCalls;
+ bool IsFixed = false;
+ bool IsValidState = true;
///}
- /// Collect values that could become \p V in the set \p Values, each mapped to
- /// \p ReturnInsts.
- void collectValuesRecursively(
- Attributor &A, Value *V, SmallPtrSetImpl<ReturnInst *> &ReturnInsts,
- DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> &Values) {
-
- visitValueCB_t<bool> VisitValueCB = [&](Value *Val, bool &) {
- assert(!isa<Instruction>(Val) ||
- &getAnchorScope() == cast<Instruction>(Val)->getFunction());
- Values[Val].insert(ReturnInsts.begin(), ReturnInsts.end());
- };
-
- bool UnusedBool;
- bool Success = genericValueTraversal(V, UnusedBool, VisitValueCB);
-
- // If we did abort the above traversal we haven't see all the values.
- // Consequently, we cannot know if the information we would derive is
- // accurate so we give up early.
- if (!Success)
- indicatePessimisticFixpoint();
- }
-
public:
- /// See AbstractAttribute::AbstractAttribute(...).
- AAReturnedValuesImpl(Function &F, InformationCache &InfoCache)
- : AAReturnedValues(F, InfoCache) {
- // We do not have an associated argument yet.
- AssociatedVal = nullptr;
- }
+ AAReturnedValuesImpl(const IRPosition &IRP) : AAReturnedValues(IRP) {}
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
// Reset the state.
- AssociatedVal = nullptr;
IsFixed = false;
IsValidState = true;
- HasOverdefinedReturnedCalls = false;
ReturnedValues.clear();
- Function &F = cast<Function>(getAnchoredValue());
+ Function *F = getAssociatedFunction();
+ if (!F) {
+ indicatePessimisticFixpoint();
+ return;
+ }
// The map from instruction opcodes to those instructions in the function.
- auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
+ auto &OpcodeInstMap = A.getInfoCache().getOpcodeInstMapForFunction(*F);
// Look through all arguments, if one is marked as returned we are done.
- for (Argument &Arg : F.args()) {
+ for (Argument &Arg : F->args()) {
if (Arg.hasReturnedAttr()) {
-
auto &ReturnInstSet = ReturnedValues[&Arg];
for (Instruction *RI : OpcodeInstMap[Instruction::Ret])
ReturnInstSet.insert(cast<ReturnInst>(RI));
@@ -505,13 +907,8 @@ public:
}
}
- // If no argument was marked as returned we look at all return instructions
- // and collect potentially returned values.
- for (Instruction *RI : OpcodeInstMap[Instruction::Ret]) {
- SmallPtrSet<ReturnInst *, 1> RISet({cast<ReturnInst>(RI)});
- collectValuesRecursively(A, cast<ReturnInst>(RI)->getReturnValue(), RISet,
- ReturnedValues);
- }
+ if (!F->hasExactDefinition())
+ indicatePessimisticFixpoint();
}
/// See AbstractAttribute::manifest(...).
@@ -523,25 +920,35 @@ public:
/// See AbstractAttribute::getState(...).
const AbstractState &getState() const override { return *this; }
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override { return MP_ARGUMENT; }
-
/// See AbstractAttribute::updateImpl(Attributor &A).
ChangeStatus updateImpl(Attributor &A) override;
+ llvm::iterator_range<iterator> returned_values() override {
+ return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end());
+ }
+
+ llvm::iterator_range<const_iterator> returned_values() const override {
+ return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end());
+ }
+
+ const SmallSetVector<CallBase *, 4> &getUnresolvedCalls() const override {
+ return UnresolvedCalls;
+ }
+
/// Return the number of potential return values, -1 if unknown.
- size_t getNumReturnValues() const {
+ size_t getNumReturnValues() const override {
return isValidState() ? ReturnedValues.size() : -1;
}
/// Return an assumed unique return value if a single candidate is found. If
/// there cannot be one, return a nullptr. If it is not clear yet, return the
/// Optional::NoneType.
- Optional<Value *> getAssumedUniqueReturnValue() const;
+ Optional<Value *> getAssumedUniqueReturnValue(Attributor &A) const;
- /// See AbstractState::checkForallReturnedValues(...).
- bool
- checkForallReturnedValues(std::function<bool(Value &)> &Pred) const override;
+ /// See AbstractState::checkForAllReturnedValues(...).
+ bool checkForAllReturnedValuesAndReturnInsts(
+ const function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)>
+ &Pred) const override;
/// Pretty print the attribute similar to the IR representation.
const std::string getAsStr() const override;
@@ -553,13 +960,15 @@ public:
bool isValidState() const override { return IsValidState; }
/// See AbstractState::indicateOptimisticFixpoint(...).
- void indicateOptimisticFixpoint() override {
+ ChangeStatus indicateOptimisticFixpoint() override {
IsFixed = true;
- IsValidState &= true;
+ return ChangeStatus::UNCHANGED;
}
- void indicatePessimisticFixpoint() override {
+
+ ChangeStatus indicatePessimisticFixpoint() override {
IsFixed = true;
IsValidState = false;
+ return ChangeStatus::CHANGED;
}
};
@@ -568,21 +977,52 @@ ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) {
// Bookkeeping.
assert(isValidState());
- NumFnKnownReturns++;
+ STATS_DECLTRACK(KnownReturnValues, FunctionReturn,
+ "Number of function with known return values");
// Check if we have an assumed unique return value that we could manifest.
- Optional<Value *> UniqueRV = getAssumedUniqueReturnValue();
+ Optional<Value *> UniqueRV = getAssumedUniqueReturnValue(A);
if (!UniqueRV.hasValue() || !UniqueRV.getValue())
return Changed;
// Bookkeeping.
- NumFnUniqueReturned++;
+ STATS_DECLTRACK(UniqueReturnValue, FunctionReturn,
+ "Number of function with unique return");
+
+ // Callback to replace the uses of CB with the constant C.
+ auto ReplaceCallSiteUsersWith = [](CallBase &CB, Constant &C) {
+ if (CB.getNumUses() == 0 || CB.isMustTailCall())
+ return ChangeStatus::UNCHANGED;
+ CB.replaceAllUsesWith(&C);
+ return ChangeStatus::CHANGED;
+ };
// If the assumed unique return value is an argument, annotate it.
if (auto *UniqueRVArg = dyn_cast<Argument>(UniqueRV.getValue())) {
- AssociatedVal = UniqueRVArg;
- Changed = AbstractAttribute::manifest(A) | Changed;
+ getIRPosition() = IRPosition::argument(*UniqueRVArg);
+ Changed = IRAttribute::manifest(A);
+ } else if (auto *RVC = dyn_cast<Constant>(UniqueRV.getValue())) {
+ // We can replace the returned value with the unique returned constant.
+ Value &AnchorValue = getAnchorValue();
+ if (Function *F = dyn_cast<Function>(&AnchorValue)) {
+ for (const Use &U : F->uses())
+ if (CallBase *CB = dyn_cast<CallBase>(U.getUser()))
+ if (CB->isCallee(&U)) {
+ Constant *RVCCast =
+ ConstantExpr::getTruncOrBitCast(RVC, CB->getType());
+ Changed = ReplaceCallSiteUsersWith(*CB, *RVCCast) | Changed;
+ }
+ } else {
+ assert(isa<CallBase>(AnchorValue) &&
+             "Expected a function or call base anchor!");
+ Constant *RVCCast =
+ ConstantExpr::getTruncOrBitCast(RVC, AnchorValue.getType());
+ Changed = ReplaceCallSiteUsersWith(cast<CallBase>(AnchorValue), *RVCCast);
+ }
+ if (Changed == ChangeStatus::CHANGED)
+ STATS_DECLTRACK(UniqueConstantReturnValue, FunctionReturn,
+ "Number of function returns replaced by constant return");
}
return Changed;
@@ -590,18 +1030,20 @@ ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) {
const std::string AAReturnedValuesImpl::getAsStr() const {
return (isAtFixpoint() ? "returns(#" : "may-return(#") +
- (isValidState() ? std::to_string(getNumReturnValues()) : "?") + ")";
+ (isValidState() ? std::to_string(getNumReturnValues()) : "?") +
+ ")[#UC: " + std::to_string(UnresolvedCalls.size()) + "]";
}
-Optional<Value *> AAReturnedValuesImpl::getAssumedUniqueReturnValue() const {
- // If checkForallReturnedValues provides a unique value, ignoring potential
+Optional<Value *>
+AAReturnedValuesImpl::getAssumedUniqueReturnValue(Attributor &A) const {
+ // If checkForAllReturnedValues provides a unique value, ignoring potential
// undef values that can also be present, it is assumed to be the actual
// return value and forwarded to the caller of this method. If there are
// multiple, a nullptr is returned indicating there cannot be a unique
// returned value.
Optional<Value *> UniqueRV;
- std::function<bool(Value &)> Pred = [&](Value &RV) -> bool {
+ auto Pred = [&](Value &RV) -> bool {
// If we found a second returned value and neither the current nor the saved
// one is an undef, there is no unique returned value. Undefs are special
// since we can pretend they have any value.
@@ -618,14 +1060,15 @@ Optional<Value *> AAReturnedValuesImpl::getAssumedUniqueReturnValue() const {
return true;
};
- if (!checkForallReturnedValues(Pred))
+ if (!A.checkForAllReturnedValues(Pred, *this))
UniqueRV = nullptr;
return UniqueRV;
}
-bool AAReturnedValuesImpl::checkForallReturnedValues(
- std::function<bool(Value &)> &Pred) const {
+bool AAReturnedValuesImpl::checkForAllReturnedValuesAndReturnInsts(
+ const function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)>
+ &Pred) const {
if (!isValidState())
return false;
@@ -634,11 +1077,11 @@ bool AAReturnedValuesImpl::checkForallReturnedValues(
for (auto &It : ReturnedValues) {
Value *RV = It.first;
- ImmutableCallSite ICS(RV);
- if (ICS && !HasOverdefinedReturnedCalls)
+ CallBase *CB = dyn_cast<CallBase>(RV);
+ if (CB && !UnresolvedCalls.count(CB))
continue;
- if (!Pred(*RV))
+ if (!Pred(*RV, It.second))
return false;
}
@@ -646,125 +1089,196 @@ bool AAReturnedValuesImpl::checkForallReturnedValues(
}
ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
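+ // First collect the values returned by all live return instructions via a
+ // generic value traversal, then try to resolve returned call sites through
+ // the AAReturnedValues of their callees; calls that cannot be resolved are
+ // recorded in UnresolvedCalls.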
+ size_t NumUnresolvedCalls = UnresolvedCalls.size();
+ bool Changed = false;
+
+ // State used in the value traversals starting in returned values.
+ struct RVState {
+ // The map in which we collect return values -> return instrs.
+ decltype(ReturnedValues) &RetValsMap;
+ // The flag to indicate a change.
+ bool &Changed;
+ // The return instrs we come from.
+ SmallSetVector<ReturnInst *, 4> RetInsts;
+ };
- // Check if we know of any values returned by the associated function,
- // if not, we are done.
- if (getNumReturnValues() == 0) {
- indicateOptimisticFixpoint();
- return ChangeStatus::UNCHANGED;
- }
+ // Callback for a leaf value returned by the associated function.
+ auto VisitValueCB = [](Value &Val, RVState &RVS, bool) -> bool {
+ auto Size = RVS.RetValsMap[&Val].size();
+ RVS.RetValsMap[&Val].insert(RVS.RetInsts.begin(), RVS.RetInsts.end());
+ bool Inserted = RVS.RetValsMap[&Val].size() != Size;
+ RVS.Changed |= Inserted;
+ LLVM_DEBUG({
+ if (Inserted)
+ dbgs() << "[AAReturnedValues] 1 Add new returned value " << Val
+ << " => " << RVS.RetInsts.size() << "\n";
+ });
+ return true;
+ };
- // Check if any of the returned values is a call site we can refine.
- decltype(ReturnedValues) AddRVs;
- bool HasCallSite = false;
+ // Helper method to invoke the generic value traversal.
+ auto VisitReturnedValue = [&](Value &RV, RVState &RVS) {
+ IRPosition RetValPos = IRPosition::value(RV);
+ return genericValueTraversal<AAReturnedValues, RVState>(A, RetValPos, *this,
+ RVS, VisitValueCB);
+ };
- // Look at all returned call sites.
- for (auto &It : ReturnedValues) {
- SmallPtrSet<ReturnInst *, 2> &ReturnInsts = It.second;
- Value *RV = It.first;
- LLVM_DEBUG(dbgs() << "[AAReturnedValues] Potentially returned value " << *RV
- << "\n");
+ // Callback for all "return intructions" live in the associated function.
+ auto CheckReturnInst = [this, &VisitReturnedValue, &Changed](Instruction &I) {
+ ReturnInst &Ret = cast<ReturnInst>(I);
+ RVState RVS({ReturnedValues, Changed, {}});
+ RVS.RetInsts.insert(&Ret);
+ return VisitReturnedValue(*Ret.getReturnValue(), RVS);
+ };
- // Only call sites can change during an update, ignore the rest.
- CallSite RetCS(RV);
- if (!RetCS)
+ // Start by discovering returned values from all live return instructions in
+ // the associated function.
+ if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}))
+ return indicatePessimisticFixpoint();
+
+ // Once returned values "directly" present in the code are handled we try to
+ // resolve returned calls.
+ decltype(ReturnedValues) NewRVsMap;
+ for (auto &It : ReturnedValues) {
+ LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned value: " << *It.first
+ << " by #" << It.second.size() << " RIs\n");
+ CallBase *CB = dyn_cast<CallBase>(It.first);
+ if (!CB || UnresolvedCalls.count(CB))
continue;
- // For now, any call site we see will prevent us from directly fixing the
- // state. However, if the information on the callees is fixed, the call
- // sites will be removed and we will fix the information for this state.
- HasCallSite = true;
-
- // Try to find a assumed unique return value for the called function.
- auto *RetCSAA = A.getAAFor<AAReturnedValuesImpl>(*this, *RV);
- if (!RetCSAA) {
- HasOverdefinedReturnedCalls = true;
- LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site (" << *RV
- << ") with " << (RetCSAA ? "invalid" : "no")
- << " associated state\n");
+ if (!CB->getCalledFunction()) {
+ LLVM_DEBUG(dbgs() << "[AAReturnedValues] Unresolved call: " << *CB
+ << "\n");
+ UnresolvedCalls.insert(CB);
continue;
}
- // Try to find a assumed unique return value for the called function.
- Optional<Value *> AssumedUniqueRV = RetCSAA->getAssumedUniqueReturnValue();
+ // TODO: use the function scope once we have call site AAReturnedValues.
+ const auto &RetValAA = A.getAAFor<AAReturnedValues>(
+ *this, IRPosition::function(*CB->getCalledFunction()));
+ LLVM_DEBUG(dbgs() << "[AAReturnedValues] Found another AAReturnedValues: "
+ << static_cast<const AbstractAttribute &>(RetValAA)
+ << "\n");
- // If no assumed unique return value was found due to the lack of
- // candidates, we may need to resolve more calls (through more update
- // iterations) or the called function will not return. Either way, we simply
- // stick with the call sites as return values. Because there were not
- // multiple possibilities, we do not treat it as overdefined.
- if (!AssumedUniqueRV.hasValue())
+ // Skip dead ends; if we do not know anything about the returned call we
+ // mark it as unresolved and it will stay that way.
+ if (!RetValAA.getState().isValidState()) {
+ LLVM_DEBUG(dbgs() << "[AAReturnedValues] Unresolved call: " << *CB
+ << "\n");
+ UnresolvedCalls.insert(CB);
continue;
+ }
- // If multiple, non-refinable values were found, there cannot be a unique
- // return value for the called function. The returned call is overdefined!
- if (!AssumedUniqueRV.getValue()) {
- HasOverdefinedReturnedCalls = true;
- LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site has multiple "
- "potentially returned values\n");
+ // Do not try to learn partial information. If the callee has unresolved
+ // return values we will treat the call as unresolved/opaque.
+ auto &RetValAAUnresolvedCalls = RetValAA.getUnresolvedCalls();
+ if (!RetValAAUnresolvedCalls.empty()) {
+ UnresolvedCalls.insert(CB);
continue;
}
- LLVM_DEBUG({
- bool UniqueRVIsKnown = RetCSAA->isAtFixpoint();
- dbgs() << "[AAReturnedValues] Returned call site "
- << (UniqueRVIsKnown ? "known" : "assumed")
- << " unique return value: " << *AssumedUniqueRV << "\n";
- });
+ // Now check if we can track transitively returned values. If possible, that
+ // is, if all returned values can be represented in the current scope, do so.
+ bool Unresolved = false;
+ for (auto &RetValAAIt : RetValAA.returned_values()) {
+ Value *RetVal = RetValAAIt.first;
+ if (isa<Argument>(RetVal) || isa<CallBase>(RetVal) ||
+ isa<Constant>(RetVal))
+ continue;
+ // Anything that did not fit in the above categories cannot be resolved;
+ // mark the call as unresolved.
+ LLVM_DEBUG(dbgs() << "[AAReturnedValues] transitively returned value "
+ "cannot be translated: "
+ << *RetVal << "\n");
+ UnresolvedCalls.insert(CB);
+ Unresolved = true;
+ break;
+ }
- // The assumed unique return value.
- Value *AssumedRetVal = AssumedUniqueRV.getValue();
-
- // If the assumed unique return value is an argument, lookup the matching
- // call site operand and recursively collect new returned values.
- // If it is not an argument, it is just put into the set of returned values
- // as we would have already looked through casts, phis, and similar values.
- if (Argument *AssumedRetArg = dyn_cast<Argument>(AssumedRetVal))
- collectValuesRecursively(A,
- RetCS.getArgOperand(AssumedRetArg->getArgNo()),
- ReturnInsts, AddRVs);
- else
- AddRVs[AssumedRetVal].insert(ReturnInsts.begin(), ReturnInsts.end());
- }
+ if (Unresolved)
+ continue;
- // Keep track of any change to trigger updates on dependent attributes.
- ChangeStatus Changed = ChangeStatus::UNCHANGED;
+ // Now track transitively returned values.
+ unsigned &NumRetAA = NumReturnedValuesPerKnownAA[CB];
+ if (NumRetAA == RetValAA.getNumReturnValues()) {
+ LLVM_DEBUG(dbgs() << "[AAReturnedValues] Skip call as it has not "
+ "changed since it was seen last\n");
+ continue;
+ }
+ NumRetAA = RetValAA.getNumReturnValues();
+
+ for (auto &RetValAAIt : RetValAA.returned_values()) {
+ Value *RetVal = RetValAAIt.first;
+ if (Argument *Arg = dyn_cast<Argument>(RetVal)) {
+ // Arguments are mapped to call site operands and we begin the traversal
+ // again.
+ bool Unused = false;
+ RVState RVS({NewRVsMap, Unused, RetValAAIt.second});
+ VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS);
+ continue;
+ } else if (isa<CallBase>(RetVal)) {
+ // Call sites are resolved by the callee attribute over time, no need for
+ // us to do anything.
+ continue;
+ } else if (isa<Constant>(RetVal)) {
+ // Constants are valid everywhere, we can simply take them.
+ NewRVsMap[RetVal].insert(It.second.begin(), It.second.end());
+ continue;
+ }
+ }
+ }
- for (auto &It : AddRVs) {
+ // To avoid modifying the ReturnedValues map while we iterate over it, we
+ // kept a record of potential new entries in a copy map, NewRVsMap.
+ for (auto &It : NewRVsMap) {
assert(!It.second.empty() && "Entry does not add anything.");
auto &ReturnInsts = ReturnedValues[It.first];
for (ReturnInst *RI : It.second)
- if (ReturnInsts.insert(RI).second) {
+ if (ReturnInsts.insert(RI)) {
LLVM_DEBUG(dbgs() << "[AAReturnedValues] Add new returned value "
<< *It.first << " => " << *RI << "\n");
- Changed = ChangeStatus::CHANGED;
+ Changed = true;
}
}
- // If there is no call site in the returned values we are done.
- if (!HasCallSite) {
- indicateOptimisticFixpoint();
- return ChangeStatus::CHANGED;
- }
-
- return Changed;
+ Changed |= (NumUnresolvedCalls != UnresolvedCalls.size());
+ return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
}
-/// ------------------------ NoSync Function Attribute -------------------------
+struct AAReturnedValuesFunction final : public AAReturnedValuesImpl {
+ AAReturnedValuesFunction(const IRPosition &IRP) : AAReturnedValuesImpl(IRP) {}
-struct AANoSyncFunction : AANoSync, BooleanState {
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(returned) }
+};
- AANoSyncFunction(Function &F, InformationCache &InfoCache)
- : AANoSync(F, InfoCache) {}
+/// Returned values information for a call site.
+struct AAReturnedValuesCallSite final : AAReturnedValuesImpl {
+ AAReturnedValuesCallSite(const IRPosition &IRP) : AAReturnedValuesImpl(IRP) {}
- /// See AbstractAttribute::getState()
- /// {
- AbstractState &getState() override { return *this; }
- const AbstractState &getState() const override { return *this; }
- /// }
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites instead of
+ // redirecting requests to the callee.
+ llvm_unreachable("Abstract attributes for returned values are not "
+ "supported for call sites yet!");
+ }
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override { return MP_FUNCTION; }
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ return indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {}
+};
+
+/// ------------------------ NoSync Function Attribute -------------------------
+
+struct AANoSyncImpl : AANoSync {
+ AANoSyncImpl(const IRPosition &IRP) : AANoSync(IRP) {}
const std::string getAsStr() const override {
return getAssumed() ? "nosync" : "may-sync";
@@ -773,12 +1287,6 @@ struct AANoSyncFunction : AANoSync, BooleanState {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override;
- /// See AANoSync::isAssumedNoSync()
- bool isAssumedNoSync() const override { return getAssumed(); }
-
- /// See AANoSync::isKnownNoSync()
- bool isKnownNoSync() const override { return getKnown(); }
-
/// Helper function used to determine whether an instruction is a non-relaxed
/// atomic, i.e., an atomic instruction that does not have unordered or
/// monotonic ordering.
@@ -792,7 +1300,7 @@ struct AANoSyncFunction : AANoSync, BooleanState {
static bool isNoSyncIntrinsic(Instruction *I);
};
-bool AANoSyncFunction::isNonRelaxedAtomic(Instruction *I) {
+bool AANoSyncImpl::isNonRelaxedAtomic(Instruction *I) {
if (!I->isAtomic())
return false;
@@ -841,7 +1349,7 @@ bool AANoSyncFunction::isNonRelaxedAtomic(Instruction *I) {
/// Checks if an intrinsic is nosync. Currently only checks mem* intrinsics.
/// FIXME: We should improve the handling of intrinsics.
-bool AANoSyncFunction::isNoSyncIntrinsic(Instruction *I) {
+bool AANoSyncImpl::isNoSyncIntrinsic(Instruction *I) {
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
/// Element wise atomic memory intrinsics can only be unordered,
@@ -863,7 +1371,7 @@ bool AANoSyncFunction::isNoSyncIntrinsic(Instruction *I) {
return false;
}
-bool AANoSyncFunction::isVolatile(Instruction *I) {
+bool AANoSyncImpl::isVolatile(Instruction *I) {
assert(!ImmutableCallSite(I) && !isa<CallBase>(I) &&
"Calls should not be checked here");
@@ -881,482 +1389,3074 @@ bool AANoSyncFunction::isVolatile(Instruction *I) {
}
}
-ChangeStatus AANoSyncFunction::updateImpl(Attributor &A) {
- Function &F = getAnchorScope();
+ChangeStatus AANoSyncImpl::updateImpl(Attributor &A) {
- /// We are looking for volatile instructions or Non-Relaxed atomics.
- /// FIXME: We should ipmrove the handling of intrinsics.
- for (Instruction *I : InfoCache.getReadOrWriteInstsForFunction(F)) {
- ImmutableCallSite ICS(I);
- auto *NoSyncAA = A.getAAFor<AANoSyncFunction>(*this, *I);
+ auto CheckRWInstForNoSync = [&](Instruction &I) {
+ /// We are looking for volatile instructions or Non-Relaxed atomics.
+ /// FIXME: We should improve the handling of intrinsics.
- if (isa<IntrinsicInst>(I) && isNoSyncIntrinsic(I))
- continue;
+ if (isa<IntrinsicInst>(&I) && isNoSyncIntrinsic(&I))
+ return true;
+
+ if (ImmutableCallSite ICS = ImmutableCallSite(&I)) {
+ if (ICS.hasFnAttr(Attribute::NoSync))
+ return true;
+
+ const auto &NoSyncAA =
+ A.getAAFor<AANoSync>(*this, IRPosition::callsite_function(ICS));
+ if (NoSyncAA.isAssumedNoSync())
+ return true;
+ return false;
+ }
+
+ if (!isVolatile(&I) && !isNonRelaxedAtomic(&I))
+ return true;
+
+ return false;
+ };
- if (ICS && (!NoSyncAA || !NoSyncAA->isAssumedNoSync()) &&
- !ICS.hasFnAttr(Attribute::NoSync)) {
+ auto CheckForNoSync = [&](Instruction &I) {
+ // At this point we handled all read/write effects and they are all
+ // nosync, so they can be skipped.
+ if (I.mayReadOrWriteMemory())
+ return true;
+
+ // non-convergent and readnone imply nosync.
+ return !ImmutableCallSite(&I).isConvergent();
+ };
+
+ if (!A.checkForAllReadWriteInstructions(CheckRWInstForNoSync, *this) ||
+ !A.checkForAllCallLikeInstructions(CheckForNoSync, *this))
+ return indicatePessimisticFixpoint();
+
+ return ChangeStatus::UNCHANGED;
+}
+
+struct AANoSyncFunction final : public AANoSyncImpl {
+ AANoSyncFunction(const IRPosition &IRP) : AANoSyncImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(nosync) }
+};
+
+/// NoSync attribute deduction for a call site.
+struct AANoSyncCallSite final : AANoSyncImpl {
+ AANoSyncCallSite(const IRPosition &IRP) : AANoSyncImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AANoSyncImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto &FnAA = A.getAAFor<AANoSync>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(), static_cast<const AANoSync::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nosync); }
+};
+
+/// ------------------------ No-Free Attributes ----------------------------
+
+struct AANoFreeImpl : public AANoFree {
+ AANoFreeImpl(const IRPosition &IRP) : AANoFree(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ auto CheckForNoFree = [&](Instruction &I) {
+ ImmutableCallSite ICS(&I);
+ if (ICS.hasFnAttr(Attribute::NoFree))
+ return true;
+
+ const auto &NoFreeAA =
+ A.getAAFor<AANoFree>(*this, IRPosition::callsite_function(ICS));
+ return NoFreeAA.isAssumedNoFree();
+ };
+
+ if (!A.checkForAllCallLikeInstructions(CheckForNoFree, *this))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr() const override {
+ return getAssumed() ? "nofree" : "may-free";
+ }
+};
+
+struct AANoFreeFunction final : public AANoFreeImpl {
+ AANoFreeFunction(const IRPosition &IRP) : AANoFreeImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(nofree) }
+};
+
+/// NoFree attribute deduction for a call site.
+struct AANoFreeCallSite final : AANoFreeImpl {
+ AANoFreeCallSite(const IRPosition &IRP) : AANoFreeImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AANoFreeImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto &FnAA = A.getAAFor<AANoFree>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(), static_cast<const AANoFree::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nofree); }
+};
+
+/// ------------------------ NonNull Argument Attribute ------------------------
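+/// Determine the known dereferenceable bytes (the return value) and
+/// non-null-ness (\p IsNonNull) implied by the use \p U of \p AssociatedValue
+/// in the instruction \p I. \p TrackUse is set if the traversal should follow
+/// the use further.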
+static int64_t getKnownNonNullAndDerefBytesForUse(
+ Attributor &A, AbstractAttribute &QueryingAA, Value &AssociatedValue,
+ const Use *U, const Instruction *I, bool &IsNonNull, bool &TrackUse) {
+ TrackUse = false;
+
+ const Value *UseV = U->get();
+ if (!UseV->getType()->isPointerTy())
+ return 0;
+
+ Type *PtrTy = UseV->getType();
+ const Function *F = I->getFunction();
+ bool NullPointerIsDefined =
+ F ? llvm::NullPointerIsDefined(F, PtrTy->getPointerAddressSpace()) : true;
+ const DataLayout &DL = A.getInfoCache().getDL();
+ if (ImmutableCallSite ICS = ImmutableCallSite(I)) {
+ if (ICS.isBundleOperand(U))
+ return 0;
+
+ if (ICS.isCallee(U)) {
+ IsNonNull |= !NullPointerIsDefined;
+ return 0;
}
- if (ICS)
- continue;
+ unsigned ArgNo = ICS.getArgumentNo(U);
+ IRPosition IRP = IRPosition::callsite_argument(ICS, ArgNo);
+ auto &DerefAA = A.getAAFor<AADereferenceable>(QueryingAA, IRP);
+ IsNonNull |= DerefAA.isKnownNonNull();
+ return DerefAA.getKnownDereferenceableBytes();
+ }
- if (!isVolatile(I) && !isNonRelaxedAtomic(I))
- continue;
+ int64_t Offset;
+ if (const Value *Base = getBasePointerOfAccessPointerOperand(I, Offset, DL)) {
+ if (Base == &AssociatedValue && getPointerOperand(I) == UseV) {
+ int64_t DerefBytes =
+ Offset + (int64_t)DL.getTypeStoreSize(PtrTy->getPointerElementType());
+
+ IsNonNull |= !NullPointerIsDefined;
+ return DerefBytes;
+ }
+ }
+ if (const Value *Base =
+ GetPointerBaseWithConstantOffset(UseV, Offset, DL,
+ /*AllowNonInbounds*/ false)) {
+ auto &DerefAA =
+ A.getAAFor<AADereferenceable>(QueryingAA, IRPosition::value(*Base));
+ IsNonNull |= (!NullPointerIsDefined && DerefAA.isKnownNonNull());
+ IsNonNull |= (!NullPointerIsDefined && (Offset != 0));
+ int64_t DerefBytes = DerefAA.getKnownDereferenceableBytes();
+ return std::max(int64_t(0), DerefBytes - Offset);
+ }
+
+ return 0;
+}
+
+struct AANonNullImpl : AANonNull {
+ AANonNullImpl(const IRPosition &IRP)
+ : AANonNull(IRP),
+ NullIsDefined(NullPointerIsDefined(
+ getAnchorScope(),
+ getAssociatedValue().getType()->getPointerAddressSpace())) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ if (!NullIsDefined &&
+ hasAttr({Attribute::NonNull, Attribute::Dereferenceable}))
+ indicateOptimisticFixpoint();
+ else
+ AANonNull::initialize(A);
+ }
+
+ /// See AAFromMustBeExecutedContext
+ bool followUse(Attributor &A, const Use *U, const Instruction *I) {
+ bool IsNonNull = false;
+ bool TrackUse = false;
+ getKnownNonNullAndDerefBytesForUse(A, *this, getAssociatedValue(), U, I,
+ IsNonNull, TrackUse);
+ takeKnownMaximum(IsNonNull);
+ return TrackUse;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr() const override {
+ return getAssumed() ? "nonnull" : "may-null";
+ }
+
+ /// Flag to determine if the underlying value can be null and still allow
+ /// valid accesses.
+ const bool NullIsDefined;
+};
+
+/// NonNull attribute for a floating value.
+struct AANonNullFloating
+ : AAFromMustBeExecutedContext<AANonNull, AANonNullImpl> {
+ using Base = AAFromMustBeExecutedContext<AANonNull, AANonNullImpl>;
+ AANonNullFloating(const IRPosition &IRP) : Base(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ Base::initialize(A);
+
+ if (isAtFixpoint())
+ return;
+
+ const IRPosition &IRP = getIRPosition();
+ const Value &V = IRP.getAssociatedValue();
+ const DataLayout &DL = A.getDataLayout();
+
+ // TODO: This context sensitive query should be removed once we can do
+ // context sensitive queries in the genericValueTraversal below.
+ if (isKnownNonZero(&V, DL, 0, /* TODO: AC */ nullptr, IRP.getCtxI(),
+ /* TODO: DT */ nullptr))
+ indicateOptimisticFixpoint();
+ }
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ ChangeStatus Change = Base::updateImpl(A);
+ if (isKnownNonNull())
+ return Change;
+
+ if (!NullIsDefined) {
+ const auto &DerefAA = A.getAAFor<AADereferenceable>(*this, getIRPosition());
+ if (DerefAA.getAssumedDereferenceableBytes())
+ return Change;
+ }
+
+ const DataLayout &DL = A.getDataLayout();
+
+ auto VisitValueCB = [&](Value &V, AAAlign::StateType &T,
+ bool Stripped) -> bool {
+ const auto &AA = A.getAAFor<AANonNull>(*this, IRPosition::value(V));
+ if (!Stripped && this == &AA) {
+ if (!isKnownNonZero(&V, DL, 0, /* TODO: AC */ nullptr,
+ /* CtxI */ getCtxI(),
+ /* TODO: DT */ nullptr))
+ T.indicatePessimisticFixpoint();
+ } else {
+ // Use abstract attribute information.
+ const AANonNull::StateType &NS =
+ static_cast<const AANonNull::StateType &>(AA.getState());
+ T ^= NS;
+ }
+ return T.isValidState();
+ };
+
+ StateType T;
+ if (!genericValueTraversal<AANonNull, StateType>(A, getIRPosition(), *this,
+ T, VisitValueCB))
+ return indicatePessimisticFixpoint();
+
+ return clampStateAndIndicateChange(getState(), T);
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(nonnull) }
+};
+
+/// NonNull attribute for function return value.
+struct AANonNullReturned final
+ : AAReturnedFromReturnedValues<AANonNull, AANonNullImpl> {
+ AANonNullReturned(const IRPosition &IRP)
+ : AAReturnedFromReturnedValues<AANonNull, AANonNullImpl>(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(nonnull) }
+};
+
+/// NonNull attribute for function argument.
+struct AANonNullArgument final
+ : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<AANonNull,
+ AANonNullImpl> {
+ AANonNullArgument(const IRPosition &IRP)
+ : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<AANonNull,
+ AANonNullImpl>(
+ IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nonnull) }
+};
+
+struct AANonNullCallSiteArgument final : AANonNullFloating {
+ AANonNullCallSiteArgument(const IRPosition &IRP) : AANonNullFloating(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(nonnull) }
+};
+
+/// NonNull attribute for a call site return position.
+struct AANonNullCallSiteReturned final
+ : AACallSiteReturnedFromReturnedAndMustBeExecutedContext<AANonNull,
+ AANonNullImpl> {
+ AANonNullCallSiteReturned(const IRPosition &IRP)
+ : AACallSiteReturnedFromReturnedAndMustBeExecutedContext<AANonNull,
+ AANonNullImpl>(
+ IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) }
+};
+
+/// ------------------------ No-Recurse Attributes ----------------------------
+
+struct AANoRecurseImpl : public AANoRecurse {
+ AANoRecurseImpl(const IRPosition &IRP) : AANoRecurse(IRP) {}
+
+ /// See AbstractAttribute::getAsStr()
+ const std::string getAsStr() const override {
+ return getAssumed() ? "norecurse" : "may-recurse";
+ }
+};
+
+struct AANoRecurseFunction final : AANoRecurseImpl {
+ AANoRecurseFunction(const IRPosition &IRP) : AANoRecurseImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AANoRecurseImpl::initialize(A);
+ if (const Function *F = getAnchorScope())
+ if (A.getInfoCache().getSccSize(*F) == 1)
+ return;
indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
}
- auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
- auto Opcodes = {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
- (unsigned)Instruction::Call};
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
- for (unsigned Opcode : Opcodes) {
- for (Instruction *I : OpcodeInstMap[Opcode]) {
- // At this point we handled all read/write effects and they are all
- // nosync, so they can be skipped.
- if (I->mayReadOrWriteMemory())
- continue;
+ auto CheckForNoRecurse = [&](Instruction &I) {
+ ImmutableCallSite ICS(&I);
+ if (ICS.hasFnAttr(Attribute::NoRecurse))
+ return true;
- ImmutableCallSite ICS(I);
+ const auto &NoRecurseAA =
+ A.getAAFor<AANoRecurse>(*this, IRPosition::callsite_function(ICS));
+ if (!NoRecurseAA.isAssumedNoRecurse())
+ return false;
- // non-convergent and readnone imply nosync.
- if (!ICS.isConvergent())
- continue;
+ // Recursion to the same function
+ if (ICS.getCalledFunction() == getAnchorScope())
+ return false;
+
+ return true;
+ };
+
+ if (!A.checkForAllCallLikeInstructions(CheckForNoRecurse, *this))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+
+ void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(norecurse) }
+};
+
+/// NoRecurse attribute deduction for a call site.
+struct AANoRecurseCallSite final : AANoRecurseImpl {
+ AANoRecurseCallSite(const IRPosition &IRP) : AANoRecurseImpl(IRP) {}
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AANoRecurseImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto &FnAA = A.getAAFor<AANoRecurse>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(),
+ static_cast<const AANoRecurse::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(norecurse); }
+};
+
+/// ------------------------ Will-Return Attributes ----------------------------
+
+// Helper function that checks whether a function has any cycle.
+// TODO: Replace with more efficient code
+static bool containsCycle(Function &F) {
+ SmallPtrSet<BasicBlock *, 32> Visited;
+
+ // Traverse BB by dfs and check whether successor is already visited.
+ for (BasicBlock *BB : depth_first(&F)) {
+ Visited.insert(BB);
+ for (auto *SuccBB : successors(BB)) {
+ if (Visited.count(SuccBB))
+ return true;
}
}
+ return false;
+}
- return ChangeStatus::UNCHANGED;
+// Helper function that checks whether the function has a loop which might
+// become an endless loop.
+// FIXME: Any cycle is regarded as an endless loop for now.
+// We have to allow some patterns.
+static bool containsPossiblyEndlessLoop(Function *F) {
+ return !F || !F->hasExactDefinition() || containsCycle(*F);
}
-/// ------------------------ No-Free Attributes ----------------------------
+struct AAWillReturnImpl : public AAWillReturn {
+ AAWillReturnImpl(const IRPosition &IRP) : AAWillReturn(IRP) {}
-struct AANoFreeFunction : AbstractAttribute, BooleanState {
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AAWillReturn::initialize(A);
- /// See AbstractAttribute::AbstractAttribute(...).
- AANoFreeFunction(Function &F, InformationCache &InfoCache)
- : AbstractAttribute(F, InfoCache) {}
+ Function *F = getAssociatedFunction();
+ if (containsPossiblyEndlessLoop(F))
+ indicatePessimisticFixpoint();
+ }
- /// See AbstractAttribute::getState()
- ///{
- AbstractState &getState() override { return *this; }
- const AbstractState &getState() const override { return *this; }
- ///}
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
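+ // Accept a call-like instruction if the callee is known to return, or if it
+ // is assumed to return and additionally assumed not to recurse.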
+ auto CheckForWillReturn = [&](Instruction &I) {
+ IRPosition IPos = IRPosition::callsite_function(ImmutableCallSite(&I));
+ const auto &WillReturnAA = A.getAAFor<AAWillReturn>(*this, IPos);
+ if (WillReturnAA.isKnownWillReturn())
+ return true;
+ if (!WillReturnAA.isAssumedWillReturn())
+ return false;
+ const auto &NoRecurseAA = A.getAAFor<AANoRecurse>(*this, IPos);
+ return NoRecurseAA.isAssumedNoRecurse();
+ };
+
+ if (!A.checkForAllCallLikeInstructions(CheckForWillReturn, *this))
+ return indicatePessimisticFixpoint();
+
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::getAsStr()
+ const std::string getAsStr() const override {
+ return getAssumed() ? "willreturn" : "may-noreturn";
+ }
+};
+
+struct AAWillReturnFunction final : AAWillReturnImpl {
+ AAWillReturnFunction(const IRPosition &IRP) : AAWillReturnImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(willreturn) }
+};
+
+/// WillReturn attribute deduction for a call site.
+struct AAWillReturnCallSite final : AAWillReturnImpl {
+ AAWillReturnCallSite(const IRPosition &IRP) : AAWillReturnImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AAWillReturnImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto &FnAA = A.getAAFor<AAWillReturn>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(),
+ static_cast<const AAWillReturn::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(willreturn); }
+};
+
+/// ------------------------ NoAlias Argument Attribute ------------------------
+
+struct AANoAliasImpl : AANoAlias {
+ AANoAliasImpl(const IRPosition &IRP) : AANoAlias(IRP) {}
+
+ const std::string getAsStr() const override {
+ return getAssumed() ? "noalias" : "may-alias";
+ }
+};
+
+/// NoAlias attribute for a floating value.
+struct AANoAliasFloating final : AANoAliasImpl {
+ AANoAliasFloating(const IRPosition &IRP) : AANoAliasImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AANoAliasImpl::initialize(A);
+ Value &Val = getAssociatedValue();
+ if (isa<AllocaInst>(Val))
+ indicateOptimisticFixpoint();
+ if (isa<ConstantPointerNull>(Val) &&
+ Val.getType()->getPointerAddressSpace() == 0)
+ indicateOptimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Implement this.
+ return indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(noalias)
+ }
+};
+
+/// NoAlias attribute for an argument.
+struct AANoAliasArgument final
+ : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl> {
+ AANoAliasArgument(const IRPosition &IRP)
+ : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl>(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noalias) }
+};
+
+struct AANoAliasCallSiteArgument final : AANoAliasImpl {
+ AANoAliasCallSiteArgument(const IRPosition &IRP) : AANoAliasImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ // See callsite argument attribute and callee argument attribute.
+ ImmutableCallSite ICS(&getAnchorValue());
+ if (ICS.paramHasAttr(getArgNo(), Attribute::NoAlias))
+ indicateOptimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // We can deduce "noalias" if the following conditions hold.
+ // (i) Associated value is assumed to be noalias in the definition.
+ // (ii) Associated value is assumed to be no-capture in all the uses
+ // possibly executed before this callsite.
+ // (iii) There is no other pointer argument which could alias with the
+ // value.
+
+ const Value &V = getAssociatedValue();
+ const IRPosition IRP = IRPosition::value(V);
+
+ // (i) Check whether noalias holds in the definition.
+
+ auto &NoAliasAA = A.getAAFor<AANoAlias>(*this, IRP);
+
+ if (!NoAliasAA.isAssumedNoAlias())
+ return indicatePessimisticFixpoint();
+
+ LLVM_DEBUG(dbgs() << "[Attributor][AANoAliasCSArg] " << V
+ << " is assumed NoAlias in the definition\n");
+
+ // (ii) Check whether the value is captured in the scope using AANoCapture.
+ // FIXME: This is conservative though, it is better to look at CFG and
+ // check only uses possibly executed before this callsite.
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override { return MP_FUNCTION; }
+ auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, IRP);
+ if (!NoCaptureAA.isAssumedNoCaptureMaybeReturned()) {
+ LLVM_DEBUG(
+ dbgs() << "[Attributor][AANoAliasCSArg] " << V
+ << " cannot be noalias as it is potentially captured\n");
+ return indicatePessimisticFixpoint();
+ }
+
+ // (iii) Check there is no other pointer argument which could alias with the
+ // value.
+ ImmutableCallSite ICS(&getAnchorValue());
+ for (unsigned i = 0; i < ICS.getNumArgOperands(); i++) {
+ if (getArgNo() == (int)i)
+ continue;
+ const Value *ArgOp = ICS.getArgOperand(i);
+ if (!ArgOp->getType()->isPointerTy())
+ continue;
+
+ if (const Function *F = getAnchorScope()) {
+ if (AAResults *AAR = A.getInfoCache().getAAResultsForFunction(*F)) {
+ bool IsAliasing = AAR->isNoAlias(&getAssociatedValue(), ArgOp);
+ LLVM_DEBUG(dbgs()
+ << "[Attributor][NoAliasCSArg] Check alias between "
+ "callsite arguments "
+ << AAR->isNoAlias(&getAssociatedValue(), ArgOp) << " "
+ << getAssociatedValue() << " " << *ArgOp << " => "
+ << (IsAliasing ? "" : "no-") << "alias \n");
+
+ if (IsAliasing)
+ continue;
+ }
+ }
+ return indicatePessimisticFixpoint();
+ }
+
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(noalias) }
+};
+
+/// NoAlias attribute for function return value.
+struct AANoAliasReturned final : AANoAliasImpl {
+ AANoAliasReturned(const IRPosition &IRP) : AANoAliasImpl(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ virtual ChangeStatus updateImpl(Attributor &A) override {
+
+ auto CheckReturnValue = [&](Value &RV) -> bool {
+ if (Constant *C = dyn_cast<Constant>(&RV))
+ if (C->isNullValue() || isa<UndefValue>(C))
+ return true;
+
+ /// For now, we can only deduce noalias if we have call sites.
+ /// FIXME: add more support.
+ ImmutableCallSite ICS(&RV);
+ if (!ICS)
+ return false;
+
+ const IRPosition &RVPos = IRPosition::value(RV);
+ const auto &NoAliasAA = A.getAAFor<AANoAlias>(*this, RVPos);
+ if (!NoAliasAA.isAssumedNoAlias())
+ return false;
+
+ const auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, RVPos);
+ return NoCaptureAA.isAssumedNoCaptureMaybeReturned();
+ };
+
+ if (!A.checkForAllReturnedValues(CheckReturnValue, *this))
+ return indicatePessimisticFixpoint();
+
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noalias) }
+};
+
+/// NoAlias attribute deduction for a call site return value.
+struct AANoAliasCallSiteReturned final : AANoAliasImpl {
+ AANoAliasCallSiteReturned(const IRPosition &IRP) : AANoAliasImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AANoAliasImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::returned(*F);
+ auto &FnAA = A.getAAFor<AANoAlias>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(), static_cast<const AANoAlias::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noalias); }
+};
+
+/// -------------------AAIsDead Function Attribute-----------------------
+
+struct AAIsDeadImpl : public AAIsDead {
+ AAIsDeadImpl(const IRPosition &IRP) : AAIsDead(IRP) {}
+
+ void initialize(Attributor &A) override {
+ const Function *F = getAssociatedFunction();
+ if (F && !F->isDeclaration())
+ exploreFromEntry(A, F);
+ }
+
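+ /// Seed the exploration from the entry block of \p F: record every assumed
+ /// noreturn call found on the explored paths in NoReturnCalls and mark the
+ /// entry block as assumed live.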
+ void exploreFromEntry(Attributor &A, const Function *F) {
+ ToBeExploredPaths.insert(&(F->getEntryBlock().front()));
+
+ for (size_t i = 0; i < ToBeExploredPaths.size(); ++i)
+ if (const Instruction *NextNoReturnI =
+ findNextNoReturn(A, ToBeExploredPaths[i]))
+ NoReturnCalls.insert(NextNoReturnI);
+
+ // Mark the block live after we looked for no-return instructions.
+ assumeLive(A, F->getEntryBlock());
+ }
+
+ /// Find the next assumed noreturn instruction in the block of \p I starting
+ /// from, thus including, \p I.
+ ///
+ /// The caller is responsible for monitoring the ToBeExploredPaths set, as
+ /// new instructions discovered in other basic blocks will be placed in there.
+ ///
+ /// \returns The next assumed noreturn instruction in the block of \p I,
+ /// starting from, thus including, \p I.
+ const Instruction *findNextNoReturn(Attributor &A, const Instruction *I);
/// See AbstractAttribute::getAsStr().
const std::string getAsStr() const override {
- return getAssumed() ? "nofree" : "may-free";
+ return "Live[#BB " + std::to_string(AssumedLiveBlocks.size()) + "/" +
+ std::to_string(getAssociatedFunction()->size()) + "][#NRI " +
+ std::to_string(NoReturnCalls.size()) + "]";
+ }
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ assert(getState().isValidState() &&
+ "Attempted to manifest an invalid state!");
+
+ ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
+ Function &F = *getAssociatedFunction();
+
+ if (AssumedLiveBlocks.empty()) {
+ A.deleteAfterManifest(F);
+ return ChangeStatus::CHANGED;
+ }
+
+ // Flag to determine if we can change an invoke to a call assuming the
+ // callee is nounwind. This is not possible if the personality of the
+ // function allows to catch asynchronous exceptions.
+ bool Invoke2CallAllowed = !mayCatchAsynchronousExceptions(F);
+
+ for (const Instruction *NRC : NoReturnCalls) {
+ Instruction *I = const_cast<Instruction *>(NRC);
+ BasicBlock *BB = I->getParent();
+ Instruction *SplitPos = I->getNextNode();
+ // TODO: mark stuff before unreachable instructions as dead.
+
+ if (auto *II = dyn_cast<InvokeInst>(I)) {
+ // If we keep the invoke the split position is at the beginning of the
+ // normal destination block (it invokes a noreturn function after all).
+ BasicBlock *NormalDestBB = II->getNormalDest();
+ SplitPos = &NormalDestBB->front();
+
+ /// Invoke is replaced with a call and unreachable is placed after it if
+ /// the callee is nounwind and noreturn. Otherwise, we keep the invoke
+ /// and only place an unreachable in the normal successor.
+ if (Invoke2CallAllowed) {
+ if (II->getCalledFunction()) {
+ const IRPosition &IPos = IRPosition::callsite_function(*II);
+ const auto &AANoUnw = A.getAAFor<AANoUnwind>(*this, IPos);
+ if (AANoUnw.isAssumedNoUnwind()) {
+ LLVM_DEBUG(dbgs()
+ << "[AAIsDead] Replace invoke with call inst\n");
+ // We do not need an invoke (II) but instead want a call followed
+ // by an unreachable. However, we do not remove II as other
+ // abstract attributes might have it cached as part of their
+ // results. Given that we modify the CFG anyway, we simply keep II
+ // around but in a new dead block. To avoid II being live through
+ // a different edge we have to ensure the block we place it in is
+ // only reached from the current block of II and then not reached
+ // at all when we insert the unreachable.
+ SplitBlockPredecessors(NormalDestBB, {BB}, ".i2c");
+ CallInst *CI = createCallMatchingInvoke(II);
+ CI->insertBefore(II);
+ CI->takeName(II);
+ II->replaceAllUsesWith(CI);
+ SplitPos = CI->getNextNode();
+ }
+ }
+ }
+
+ if (SplitPos == &NormalDestBB->front()) {
+ // If this is an invoke of a noreturn function the edge to the normal
+ // destination block is dead but not necessarily the block itself.
+ // TODO: We need to move to an edge based system during deduction and
+ // also manifest.
+ assert(!NormalDestBB->isLandingPad() &&
+ "Expected the normal destination not to be a landingpad!");
+ if (NormalDestBB->getUniquePredecessor() == BB) {
+ assumeLive(A, *NormalDestBB);
+ } else {
+ BasicBlock *SplitBB =
+ SplitBlockPredecessors(NormalDestBB, {BB}, ".dead");
+ // The split block is live even if it contains only an unreachable
+ // instruction at the end.
+ assumeLive(A, *SplitBB);
+ SplitPos = SplitBB->getTerminator();
+ HasChanged = ChangeStatus::CHANGED;
+ }
+ }
+ }
+
+ if (isa_and_nonnull<UnreachableInst>(SplitPos))
+ continue;
+
+ BB = SplitPos->getParent();
+ SplitBlock(BB, SplitPos);
+ changeToUnreachable(BB->getTerminator(), /* UseLLVMTrap */ false);
+ HasChanged = ChangeStatus::CHANGED;
+ }
+
+ for (BasicBlock &BB : F)
+ if (!AssumedLiveBlocks.count(&BB))
+ A.deleteAfterManifest(BB);
+
+ return HasChanged;
}
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override;
- /// See AbstractAttribute::getAttrKind().
- Attribute::AttrKind getAttrKind() const override { return ID; }
+ /// See AAIsDead::isAssumedDead(BasicBlock *).
+ bool isAssumedDead(const BasicBlock *BB) const override {
+ assert(BB->getParent() == getAssociatedFunction() &&
+ "BB must be in the same anchor scope function.");
+
+ if (!getAssumed())
+ return false;
+ return !AssumedLiveBlocks.count(BB);
+ }
+
+ /// See AAIsDead::isKnownDead(BasicBlock *).
+ bool isKnownDead(const BasicBlock *BB) const override {
+ return getKnown() && isAssumedDead(BB);
+ }
+
+ /// See AAIsDead::isAssumed(Instruction *I).
+ bool isAssumedDead(const Instruction *I) const override {
+ assert(I->getParent()->getParent() == getAssociatedFunction() &&
+ "Instruction must be in the same anchor scope function.");
+
+ if (!getAssumed())
+ return false;
+
+ // If it is not in AssumedLiveBlocks then it is for sure dead.
+ // Otherwise, it can still be after a noreturn call in a live block.
+ if (!AssumedLiveBlocks.count(I->getParent()))
+ return true;
+
+ // If it is not after a noreturn call, then it is live.
+ return isAfterNoReturn(I);
+ }
+
+ /// See AAIsDead::isKnownDead(Instruction *I).
+ bool isKnownDead(const Instruction *I) const override {
+ return getKnown() && isAssumedDead(I);
+ }
+
+ /// Check if the instruction is after a noreturn call, in other words, assumed dead.
+ bool isAfterNoReturn(const Instruction *I) const;
- /// Return true if "nofree" is assumed.
- bool isAssumedNoFree() const { return getAssumed(); }
+ /// Determine if \p F might catch asynchronous exceptions.
+ static bool mayCatchAsynchronousExceptions(const Function &F) {
+ return F.hasPersonalityFn() && !canSimplifyInvokeNoUnwind(&F);
+ }
+
+ /// Assume \p BB is (partially) live now and indicate to the Attributor \p A
+ /// that internal functions called from \p BB should now be looked at.
+ void assumeLive(Attributor &A, const BasicBlock &BB) {
+ if (!AssumedLiveBlocks.insert(&BB).second)
+ return;
+
+ // We assume that all of BB is (probably) live now and if there are calls to
+ // internal functions we will assume that those are now live as well. This
+ // is a performance optimization for blocks with calls to a lot of internal
+ // functions. It can however cause dead functions to be treated as live.
+ for (const Instruction &I : BB)
+ if (ImmutableCallSite ICS = ImmutableCallSite(&I))
+ if (const Function *F = ICS.getCalledFunction())
+ if (F->hasLocalLinkage())
+ A.markLiveInternalFunction(*F);
+ }
- /// Return true if "nofree" is known.
- bool isKnownNoFree() const { return getKnown(); }
+ /// Collection of to be explored paths.
+ SmallSetVector<const Instruction *, 8> ToBeExploredPaths;
- /// The identifier used by the Attributor for this class of attributes.
- static constexpr Attribute::AttrKind ID = Attribute::NoFree;
+ /// Collection of all assumed live BasicBlocks.
+ DenseSet<const BasicBlock *> AssumedLiveBlocks;
+
+ /// Collection of calls with noreturn attribute, assumed or known.
+ SmallSetVector<const Instruction *, 4> NoReturnCalls;
};
-ChangeStatus AANoFreeFunction::updateImpl(Attributor &A) {
- Function &F = getAnchorScope();
+struct AAIsDeadFunction final : public AAIsDeadImpl {
+ AAIsDeadFunction(const IRPosition &IRP) : AAIsDeadImpl(IRP) {}
- // The map from instruction opcodes to those instructions in the function.
- auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECL(PartiallyDeadBlocks, Function,
+ "Number of basic blocks classified as partially dead");
+ BUILD_STAT_NAME(PartiallyDeadBlocks, Function) += NoReturnCalls.size();
+ }
+};
- for (unsigned Opcode :
- {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
- (unsigned)Instruction::Call}) {
- for (Instruction *I : OpcodeInstMap[Opcode]) {
+bool AAIsDeadImpl::isAfterNoReturn(const Instruction *I) const {
+ const Instruction *PrevI = I->getPrevNode();
+ while (PrevI) {
+ if (NoReturnCalls.count(PrevI))
+ return true;
+ PrevI = PrevI->getPrevNode();
+ }
+ return false;
+}
- auto ICS = ImmutableCallSite(I);
- auto *NoFreeAA = A.getAAFor<AANoFreeFunction>(*this, *I);
+const Instruction *AAIsDeadImpl::findNextNoReturn(Attributor &A,
+ const Instruction *I) {
+ const BasicBlock *BB = I->getParent();
+ const Function &F = *BB->getParent();
- if ((!NoFreeAA || !NoFreeAA->isAssumedNoFree()) &&
- !ICS.hasFnAttr(Attribute::NoFree)) {
- indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
+ // Flag to determine if we can change an invoke to a call assuming the callee
+ // is nounwind. This is not possible if the personality of the function
+ // allows catching asynchronous exceptions.
+ bool Invoke2CallAllowed = !mayCatchAsynchronousExceptions(F);
+
+ // TODO: We should have a function that determines if an "edge" is dead.
+ // Edges could be from an instruction to the next or from a terminator
+ // to the successor. For now, we need to special case the unwind block
+ // of InvokeInst below.
+
+ while (I) {
+ ImmutableCallSite ICS(I);
+
+ if (ICS) {
+ const IRPosition &IPos = IRPosition::callsite_function(ICS);
+ // Regardless of the no-return property of an invoke instruction we only
+ // learn that the regular successor is not reachable through this
+ // instruction but the unwind block might still be.
+ if (auto *Invoke = dyn_cast<InvokeInst>(I)) {
+ // Use nounwind to justify that the unwind block is dead as well.
+ const auto &AANoUnw = A.getAAFor<AANoUnwind>(*this, IPos);
+ if (!Invoke2CallAllowed || !AANoUnw.isAssumedNoUnwind()) {
+ assumeLive(A, *Invoke->getUnwindDest());
+ ToBeExploredPaths.insert(&Invoke->getUnwindDest()->front());
+ }
}
+
+ const auto &NoReturnAA = A.getAAFor<AANoReturn>(*this, IPos);
+ if (NoReturnAA.isAssumedNoReturn())
+ return I;
}
+
+ I = I->getNextNode();
}
- return ChangeStatus::UNCHANGED;
+
+ // get new paths (reachable blocks).
+ for (const BasicBlock *SuccBB : successors(BB)) {
+ assumeLive(A, *SuccBB);
+ ToBeExploredPaths.insert(&SuccBB->front());
+ }
+
+ // No noreturn instruction found.
+ return nullptr;
}
-/// ------------------------ NonNull Argument Attribute ------------------------
-struct AANonNullImpl : AANonNull, BooleanState {
+ChangeStatus AAIsDeadImpl::updateImpl(Attributor &A) {
+ ChangeStatus Status = ChangeStatus::UNCHANGED;
+
+ // Temporary collection to iterate over existing noreturn instructions. This
+ // will allow easier modification of the NoReturnCalls collection.
+ SmallVector<const Instruction *, 8> NoReturnChanged;
+
+ for (const Instruction *I : NoReturnCalls)
+ NoReturnChanged.push_back(I);
+
+ for (const Instruction *I : NoReturnChanged) {
+ size_t Size = ToBeExploredPaths.size();
+
+ const Instruction *NextNoReturnI = findNextNoReturn(A, I);
+ if (NextNoReturnI != I) {
+ Status = ChangeStatus::CHANGED;
+ NoReturnCalls.remove(I);
+ if (NextNoReturnI)
+ NoReturnCalls.insert(NextNoReturnI);
+ }
- AANonNullImpl(Value &V, InformationCache &InfoCache)
- : AANonNull(V, InfoCache) {}
+ // Explore new paths.
+ while (Size != ToBeExploredPaths.size()) {
+ Status = ChangeStatus::CHANGED;
+ if (const Instruction *NextNoReturnI =
+ findNextNoReturn(A, ToBeExploredPaths[Size++]))
+ NoReturnCalls.insert(NextNoReturnI);
+ }
+ }
+
+ LLVM_DEBUG(dbgs() << "[AAIsDead] AssumedLiveBlocks: "
+ << AssumedLiveBlocks.size() << " Total number of blocks: "
+ << getAssociatedFunction()->size() << "\n");
- AANonNullImpl(Value *AssociatedVal, Value &AnchoredValue,
- InformationCache &InfoCache)
- : AANonNull(AssociatedVal, AnchoredValue, InfoCache) {}
+ // If we know everything is live there is no need to query for liveness.
+ if (NoReturnCalls.empty() &&
+ getAssociatedFunction()->size() == AssumedLiveBlocks.size()) {
+ // Indicating a pessimistic fixpoint will cause the state to be "invalid"
+ // which will cause the Attributor to not return the AAIsDead on request,
+ // which will prevent us from querying isAssumedDead().
+ indicatePessimisticFixpoint();
+ assert(!isValidState() && "Expected an invalid state!");
+ Status = ChangeStatus::CHANGED;
+ }
+
+ return Status;
+}
+
+/// Liveness information for a call site.
+struct AAIsDeadCallSite final : AAIsDeadImpl {
+ AAIsDeadCallSite(const IRPosition &IRP) : AAIsDeadImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites instead of
+ // redirecting requests to the callee.
+ llvm_unreachable("Abstract attributes for liveness are not "
+ "supported for call sites yet!");
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ return indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {}
+};
+
+/// -------------------- Dereferenceable Argument Attribute --------------------
+
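+/// Clamp the dereferenceable-bytes state and the "globally" state of \p S
+/// against \p R and report whether either of them changed.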
+template <>
+ChangeStatus clampStateAndIndicateChange<DerefState>(DerefState &S,
+ const DerefState &R) {
+ ChangeStatus CS0 = clampStateAndIndicateChange<IntegerState>(
+ S.DerefBytesState, R.DerefBytesState);
+ ChangeStatus CS1 =
+ clampStateAndIndicateChange<IntegerState>(S.GlobalState, R.GlobalState);
+ return CS0 | CS1;
+}
+
+struct AADereferenceableImpl : AADereferenceable {
+ AADereferenceableImpl(const IRPosition &IRP) : AADereferenceable(IRP) {}
+ using StateType = DerefState;
+
+ void initialize(Attributor &A) override {
+ SmallVector<Attribute, 4> Attrs;
+ getAttrs({Attribute::Dereferenceable, Attribute::DereferenceableOrNull},
+ Attrs);
+ for (const Attribute &Attr : Attrs)
+ takeKnownDerefBytesMaximum(Attr.getValueAsInt());
+
+ NonNullAA = &A.getAAFor<AANonNull>(*this, getIRPosition());
+
+ const IRPosition &IRP = this->getIRPosition();
+ bool IsFnInterface = IRP.isFnInterfaceKind();
+ const Function *FnScope = IRP.getAnchorScope();
+ if (IsFnInterface && (!FnScope || !FnScope->hasExactDefinition()))
+ indicatePessimisticFixpoint();
+ }
/// See AbstractAttribute::getState()
/// {
- AbstractState &getState() override { return *this; }
- const AbstractState &getState() const override { return *this; }
+ StateType &getState() override { return *this; }
+ const StateType &getState() const override { return *this; }
/// }
+ /// See AAFromMustBeExecutedContext
+ bool followUse(Attributor &A, const Use *U, const Instruction *I) {
+ bool IsNonNull = false;
+ bool TrackUse = false;
+ int64_t DerefBytes = getKnownNonNullAndDerefBytesForUse(
+ A, *this, getAssociatedValue(), U, I, IsNonNull, TrackUse);
+ takeKnownDerefBytesMaximum(DerefBytes);
+ return TrackUse;
+ }
+
+ void getDeducedAttributes(LLVMContext &Ctx,
+ SmallVectorImpl<Attribute> &Attrs) const override {
+ // TODO: Add *_globally support
+ if (isAssumedNonNull())
+ Attrs.emplace_back(Attribute::getWithDereferenceableBytes(
+ Ctx, getAssumedDereferenceableBytes()));
+ else
+ Attrs.emplace_back(Attribute::getWithDereferenceableOrNullBytes(
+ Ctx, getAssumedDereferenceableBytes()));
+ }
+
/// See AbstractAttribute::getAsStr().
const std::string getAsStr() const override {
- return getAssumed() ? "nonnull" : "may-null";
+ if (!getAssumedDereferenceableBytes())
+ return "unknown-dereferenceable";
+ return std::string("dereferenceable") +
+ (isAssumedNonNull() ? "" : "_or_null") +
+ (isAssumedGlobal() ? "_globally" : "") + "<" +
+ std::to_string(getKnownDereferenceableBytes()) + "-" +
+ std::to_string(getAssumedDereferenceableBytes()) + ">";
}
+};
+
+/// Dereferenceable attribute for a floating value.
+struct AADereferenceableFloating
+ : AAFromMustBeExecutedContext<AADereferenceable, AADereferenceableImpl> {
+ using Base =
+ AAFromMustBeExecutedContext<AADereferenceable, AADereferenceableImpl>;
+ AADereferenceableFloating(const IRPosition &IRP) : Base(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ ChangeStatus Change = Base::updateImpl(A);
+
+ const DataLayout &DL = A.getDataLayout();
+
+ auto VisitValueCB = [&](Value &V, DerefState &T, bool Stripped) -> bool {
+ unsigned IdxWidth =
+ DL.getIndexSizeInBits(V.getType()->getPointerAddressSpace());
+ APInt Offset(IdxWidth, 0);
+ const Value *Base =
+ V.stripAndAccumulateInBoundsConstantOffsets(DL, Offset);
+
+ const auto &AA =
+ A.getAAFor<AADereferenceable>(*this, IRPosition::value(*Base));
+ int64_t DerefBytes = 0;
+ if (!Stripped && this == &AA) {
+ // Use IR information if we did not strip anything.
+ // TODO: track globally.
+ bool CanBeNull;
+ DerefBytes = Base->getPointerDereferenceableBytes(DL, CanBeNull);
+ T.GlobalState.indicatePessimisticFixpoint();
+ } else {
+ const DerefState &DS = static_cast<const DerefState &>(AA.getState());
+ DerefBytes = DS.DerefBytesState.getAssumed();
+ T.GlobalState &= DS.GlobalState;
+ }
+
+      // For now we do not try to "increase" dereferenceability due to negative
+      // indices, as we would first have to come up with code that deals with
+      // loops and with overflows of the dereferenceable bytes.
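+      // With a positive offset the dereferenceable bytes shrink accordingly:
+      // if the base is known dereferenceable for 8 bytes and we stripped
+      // "getelementptr inbounds i8, i8* %base, i64 4", the value at hand is
+      // assumed dereferenceable for at most 8 - 4 = 4 bytes.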
+ int64_t OffsetSExt = Offset.getSExtValue();
+ if (OffsetSExt < 0)
+ OffsetSExt = 0;
+
+ T.takeAssumedDerefBytesMinimum(
+ std::max(int64_t(0), DerefBytes - OffsetSExt));
+
+ if (this == &AA) {
+ if (!Stripped) {
+ // If nothing was stripped IR information is all we got.
+ T.takeKnownDerefBytesMaximum(
+ std::max(int64_t(0), DerefBytes - OffsetSExt));
+ T.indicatePessimisticFixpoint();
+ } else if (OffsetSExt > 0) {
+          // If something was stripped but there is circular reasoning, we look
+          // at the offset. If it is positive, the circular update would simply
+          // decrease the dereferenceable bytes step by step until the known
+          // value is reached, which would be very slow; short-circuit that
+          // here.
+ T.indicatePessimisticFixpoint();
+ }
+ }
+
+ return T.isValidState();
+ };
- /// See AANonNull::isAssumedNonNull().
- bool isAssumedNonNull() const override { return getAssumed(); }
+ DerefState T;
+ if (!genericValueTraversal<AADereferenceable, DerefState>(
+ A, getIRPosition(), *this, T, VisitValueCB))
+ return indicatePessimisticFixpoint();
- /// See AANonNull::isKnownNonNull().
- bool isKnownNonNull() const override { return getKnown(); }
+ return Change | clampStateAndIndicateChange(getState(), T);
+ }
- /// Generate a predicate that checks if a given value is assumed nonnull.
- /// The generated function returns true if a value satisfies any of
- /// following conditions.
- /// (i) A value is known nonZero(=nonnull).
- /// (ii) A value is associated with AANonNull and its isAssumedNonNull() is
- /// true.
- std::function<bool(Value &)> generatePredicate(Attributor &);
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(dereferenceable)
+ }
};
-std::function<bool(Value &)> AANonNullImpl::generatePredicate(Attributor &A) {
- // FIXME: The `AAReturnedValues` should provide the predicate with the
- // `ReturnInst` vector as well such that we can use the control flow sensitive
- // version of `isKnownNonZero`. This should fix `test11` in
- // `test/Transforms/FunctionAttrs/nonnull.ll`
+/// Dereferenceable attribute for a return value.
+struct AADereferenceableReturned final
+ : AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl,
+ DerefState> {
+ AADereferenceableReturned(const IRPosition &IRP)
+ : AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl,
+ DerefState>(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FNRET_ATTR(dereferenceable)
+ }
+};
- std::function<bool(Value &)> Pred = [&](Value &RV) -> bool {
- if (isKnownNonZero(&RV, getAnchorScope().getParent()->getDataLayout()))
- return true;
+/// Dereferenceable attribute for an argument
+struct AADereferenceableArgument final
+ : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<
+ AADereferenceable, AADereferenceableImpl, DerefState> {
+ using Base = AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<
+ AADereferenceable, AADereferenceableImpl, DerefState>;
+ AADereferenceableArgument(const IRPosition &IRP) : Base(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_ARG_ATTR(dereferenceable)
+ }
+};
- auto *NonNullAA = A.getAAFor<AANonNull>(*this, RV);
+/// Dereferenceable attribute for a call site argument.
+struct AADereferenceableCallSiteArgument final : AADereferenceableFloating {
+ AADereferenceableCallSiteArgument(const IRPosition &IRP)
+ : AADereferenceableFloating(IRP) {}
- ImmutableCallSite ICS(&RV);
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSARG_ATTR(dereferenceable)
+ }
+};
- if ((!NonNullAA || !NonNullAA->isAssumedNonNull()) &&
- (!ICS || !ICS.hasRetAttr(Attribute::NonNull)))
- return false;
+/// Dereferenceable attribute deduction for a call site return value.
+struct AADereferenceableCallSiteReturned final
+ : AACallSiteReturnedFromReturnedAndMustBeExecutedContext<
+ AADereferenceable, AADereferenceableImpl> {
+ using Base = AACallSiteReturnedFromReturnedAndMustBeExecutedContext<
+ AADereferenceable, AADereferenceableImpl>;
+ AADereferenceableCallSiteReturned(const IRPosition &IRP) : Base(IRP) {}
- return true;
- };
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ Base::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
+ indicatePessimisticFixpoint();
+ }
- return Pred;
-}
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+    //       sense to specialize attributes for call site arguments instead of
+ // redirecting requests to the callee argument.
+
+ ChangeStatus Change = Base::updateImpl(A);
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::returned(*F);
+ auto &FnAA = A.getAAFor<AADereferenceable>(*this, FnPos);
+ return Change |
+ clampStateAndIndicateChange(
+ getState(), static_cast<const DerefState &>(FnAA.getState()));
+ }
-/// NonNull attribute for function return value.
-struct AANonNullReturned : AANonNullImpl {
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CS_ATTR(dereferenceable);
+ }
+};
- AANonNullReturned(Function &F, InformationCache &InfoCache)
- : AANonNullImpl(F, InfoCache) {}
+// ------------------------ Align Argument Attribute ------------------------
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override { return MP_RETURNED; }
+struct AAAlignImpl : AAAlign {
+ AAAlignImpl(const IRPosition &IRP) : AAAlign(IRP) {}
- /// See AbstractAttriubute::initialize(...).
+  // Maximum alignment value allowed in IR
+ static const unsigned MAX_ALIGN = 1U << 29;
+
+ /// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- Function &F = getAnchorScope();
+ takeAssumedMinimum(MAX_ALIGN);
- // Already nonnull.
- if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
- Attribute::NonNull))
- indicateOptimisticFixpoint();
+ SmallVector<Attribute, 4> Attrs;
+ getAttrs({Attribute::Alignment}, Attrs);
+ for (const Attribute &Attr : Attrs)
+ takeKnownMaximum(Attr.getValueAsInt());
+
+ if (getIRPosition().isFnInterfaceKind() &&
+ (!getAssociatedFunction() ||
+ !getAssociatedFunction()->hasExactDefinition()))
+ indicatePessimisticFixpoint();
}
+ /// See AbstractAttribute::manifest(...).
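+  ///
+  /// For example, with an assumed alignment of 16 this rewrites
+  ///   store i32 %v, i32* %p, align 4
+  /// into
+  ///   store i32 %v, i32* %p, align 16
+  /// for stores (and loads) that use the associated value as their pointer
+  /// operand.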
+ ChangeStatus manifest(Attributor &A) override {
+ ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+ // Check for users that allow alignment annotations.
+ Value &AnchorVal = getIRPosition().getAnchorValue();
+ for (const Use &U : AnchorVal.uses()) {
+ if (auto *SI = dyn_cast<StoreInst>(U.getUser())) {
+ if (SI->getPointerOperand() == &AnchorVal)
+ if (SI->getAlignment() < getAssumedAlign()) {
+            STATS_DECLTRACK(AAAlign, Store,
+                            "Number of times alignment added to a store");
+ SI->setAlignment(Align(getAssumedAlign()));
+ Changed = ChangeStatus::CHANGED;
+ }
+ } else if (auto *LI = dyn_cast<LoadInst>(U.getUser())) {
+ if (LI->getPointerOperand() == &AnchorVal)
+ if (LI->getAlignment() < getAssumedAlign()) {
+ LI->setAlignment(Align(getAssumedAlign()));
+            STATS_DECLTRACK(AAAlign, Load,
+                            "Number of times alignment added to a load");
+ Changed = ChangeStatus::CHANGED;
+ }
+ }
+ }
+
+ return AAAlign::manifest(A) | Changed;
+ }
+
+  // TODO: Provide a helper to determine the implied ABI alignment and check,
+  //       in the existing manifest method and in a new one for AAAlignImpl,
+  //       against that value to avoid making the alignment explicit if it
+  //       does not improve on the implied one.
+
+ /// See AbstractAttribute::getDeducedAttributes
+ virtual void
+ getDeducedAttributes(LLVMContext &Ctx,
+ SmallVectorImpl<Attribute> &Attrs) const override {
+ if (getAssumedAlign() > 1)
+ Attrs.emplace_back(
+ Attribute::getWithAlignment(Ctx, Align(getAssumedAlign())));
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr() const override {
+ return getAssumedAlign() ? ("align<" + std::to_string(getKnownAlign()) +
+ "-" + std::to_string(getAssumedAlign()) + ">")
+ : "unknown-align";
+ }
+};
+
+/// Align attribute for a floating value.
+struct AAAlignFloating : AAAlignImpl {
+ AAAlignFloating(const IRPosition &IRP) : AAAlignImpl(IRP) {}
+
/// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override;
+ ChangeStatus updateImpl(Attributor &A) override {
+ const DataLayout &DL = A.getDataLayout();
+
+ auto VisitValueCB = [&](Value &V, AAAlign::StateType &T,
+ bool Stripped) -> bool {
+ const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V));
+ if (!Stripped && this == &AA) {
+ // Use only IR information if we did not strip anything.
+ const MaybeAlign PA = V.getPointerAlignment(DL);
+ T.takeKnownMaximum(PA ? PA->value() : 0);
+ T.indicatePessimisticFixpoint();
+ } else {
+ // Use abstract attribute information.
+ const AAAlign::StateType &DS =
+ static_cast<const AAAlign::StateType &>(AA.getState());
+ T ^= DS;
+ }
+ return T.isValidState();
+ };
+
+ StateType T;
+ if (!genericValueTraversal<AAAlign, StateType>(A, getIRPosition(), *this, T,
+ VisitValueCB))
+ return indicatePessimisticFixpoint();
+
+    // TODO: If we know we visited all incoming values, and thus none are
+    // assumed dead, we can take the known information from the state T.
+ return clampStateAndIndicateChange(getState(), T);
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FLOATING_ATTR(align) }
};
-ChangeStatus AANonNullReturned::updateImpl(Attributor &A) {
- Function &F = getAnchorScope();
+/// Align attribute for function return value.
+struct AAAlignReturned final
+ : AAReturnedFromReturnedValues<AAAlign, AAAlignImpl> {
+ AAAlignReturned(const IRPosition &IRP)
+ : AAReturnedFromReturnedValues<AAAlign, AAAlignImpl>(IRP) {}
- auto *AARetVal = A.getAAFor<AAReturnedValues>(*this, F);
- if (!AARetVal) {
- indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(aligned) }
+};
+
+/// Align attribute for function argument.
+struct AAAlignArgument final
+ : AAArgumentFromCallSiteArguments<AAAlign, AAAlignImpl> {
+ AAAlignArgument(const IRPosition &IRP)
+ : AAArgumentFromCallSiteArguments<AAAlign, AAAlignImpl>(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(aligned) }
+};
+
+struct AAAlignCallSiteArgument final : AAAlignFloating {
+ AAAlignCallSiteArgument(const IRPosition &IRP) : AAAlignFloating(IRP) {}
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ return AAAlignImpl::manifest(A);
}
- std::function<bool(Value &)> Pred = this->generatePredicate(A);
- if (!AARetVal->checkForallReturnedValues(Pred)) {
- indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(aligned) }
+};
+
+/// Align attribute deduction for a call site return value.
+struct AAAlignCallSiteReturned final : AAAlignImpl {
+ AAAlignCallSiteReturned(const IRPosition &IRP) : AAAlignImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AAAlignImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F)
+ indicatePessimisticFixpoint();
}
- return ChangeStatus::UNCHANGED;
-}
-/// NonNull attribute for function argument.
-struct AANonNullArgument : AANonNullImpl {
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+    //       sense to specialize attributes for call site arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::returned(*F);
+ auto &FnAA = A.getAAFor<AAAlign>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(), static_cast<const AAAlign::StateType &>(FnAA.getState()));
+ }
- AANonNullArgument(Argument &A, InformationCache &InfoCache)
- : AANonNullImpl(A, InfoCache) {}
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(align); }
+};
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override { return MP_ARGUMENT; }
+/// ------------------ Function No-Return Attribute ----------------------------
+struct AANoReturnImpl : public AANoReturn {
+ AANoReturnImpl(const IRPosition &IRP) : AANoReturn(IRP) {}
- /// See AbstractAttriubute::initialize(...).
+ /// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- Argument *Arg = cast<Argument>(getAssociatedValue());
- if (Arg->hasNonNullAttr())
- indicateOptimisticFixpoint();
+ AANoReturn::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F || F->hasFnAttribute(Attribute::WillReturn))
+ indicatePessimisticFixpoint();
}
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr() const override {
+ return getAssumed() ? "noreturn" : "may-return";
+ }
+
+ /// See AbstractAttribute::updateImpl(Attributor &A).
+ virtual ChangeStatus updateImpl(Attributor &A) override {
+ const auto &WillReturnAA = A.getAAFor<AAWillReturn>(*this, getIRPosition());
+ if (WillReturnAA.isKnownWillReturn())
+ return indicatePessimisticFixpoint();
+ auto CheckForNoReturn = [](Instruction &) { return false; };
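+    // The callback rejects every instruction, so the check below succeeds
+    // only if no live return instruction exists, i.e., noreturn can be kept.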
+ if (!A.checkForAllInstructions(CheckForNoReturn, *this,
+ {(unsigned)Instruction::Ret}))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+};
+
+struct AANoReturnFunction final : AANoReturnImpl {
+ AANoReturnFunction(const IRPosition &IRP) : AANoReturnImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(noreturn) }
+};
+
+/// NoReturn attribute deduction for a call sites.
+struct AANoReturnCallSite final : AANoReturnImpl {
+ AANoReturnCallSite(const IRPosition &IRP) : AANoReturnImpl(IRP) {}
+
/// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override;
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+    //       sense to specialize attributes for call site arguments instead of
+ // redirecting requests to the callee argument.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos);
+ return clampStateAndIndicateChange(
+ getState(),
+ static_cast<const AANoReturn::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(noreturn); }
};
-/// NonNull attribute for a call site argument.
-struct AANonNullCallSiteArgument : AANonNullImpl {
+/// ----------------------- Variable Capturing ---------------------------------
- /// See AANonNullImpl::AANonNullImpl(...).
- AANonNullCallSiteArgument(CallSite CS, unsigned ArgNo,
- InformationCache &InfoCache)
- : AANonNullImpl(CS.getArgOperand(ArgNo), *CS.getInstruction(), InfoCache),
- ArgNo(ArgNo) {}
+/// A class to hold the state for no-capture attributes.
+struct AANoCaptureImpl : public AANoCapture {
+ AANoCaptureImpl(const IRPosition &IRP) : AANoCapture(IRP) {}
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- CallSite CS(&getAnchoredValue());
- if (isKnownNonZero(getAssociatedValue(),
- getAnchorScope().getParent()->getDataLayout()) ||
- CS.paramHasAttr(ArgNo, getAttrKind()))
+ AANoCapture::initialize(A);
+
+ // You cannot "capture" null in the default address space.
+ if (isa<ConstantPointerNull>(getAssociatedValue()) &&
+ getAssociatedValue().getType()->getPointerAddressSpace() == 0) {
indicateOptimisticFixpoint();
+ return;
+ }
+
+ const IRPosition &IRP = getIRPosition();
+ const Function *F =
+ getArgNo() >= 0 ? IRP.getAssociatedFunction() : IRP.getAnchorScope();
+
+ // Check what state the associated function can actually capture.
+ if (F)
+ determineFunctionCaptureCapabilities(IRP, *F, *this);
+ else
+ indicatePessimisticFixpoint();
}
- /// See AbstractAttribute::updateImpl(Attributor &A).
+ /// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override;
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override {
- return MP_CALL_SITE_ARGUMENT;
- };
+ /// see AbstractAttribute::isAssumedNoCaptureMaybeReturned(...).
+ virtual void
+ getDeducedAttributes(LLVMContext &Ctx,
+ SmallVectorImpl<Attribute> &Attrs) const override {
+ if (!isAssumedNoCaptureMaybeReturned())
+ return;
+
+ if (getArgNo() >= 0) {
+ if (isAssumedNoCapture())
+ Attrs.emplace_back(Attribute::get(Ctx, Attribute::NoCapture));
+ else if (ManifestInternal)
+ Attrs.emplace_back(Attribute::get(Ctx, "no-capture-maybe-returned"));
+ }
+ }
+
+ /// Set the NOT_CAPTURED_IN_MEM and NOT_CAPTURED_IN_RET bits in \p Known
+ /// depending on the ability of the function associated with \p IRP to capture
+ /// state in memory and through "returning/throwing", respectively.
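+  ///
+  /// For example, a nounwind function that only reads memory and returns void
+  /// cannot capture anything, so NO_CAPTURE becomes known; if it only reads
+  /// memory but returns a pointer, merely NOT_CAPTURED_IN_MEM becomes known.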
+ static void determineFunctionCaptureCapabilities(const IRPosition &IRP,
+ const Function &F,
+ IntegerState &State) {
+ // TODO: Once we have memory behavior attributes we should use them here.
+
+ // If we know we cannot communicate or write to memory, we do not care about
+ // ptr2int anymore.
+ if (F.onlyReadsMemory() && F.doesNotThrow() &&
+ F.getReturnType()->isVoidTy()) {
+ State.addKnownBits(NO_CAPTURE);
+ return;
+ }
+
+    // A function cannot capture state in memory if it only reads memory; it
+    // can, however, return/throw state, and that state might be influenced by
+    // the pointer value, e.g., loading from a returned pointer might reveal a
+    // bit.
+ if (F.onlyReadsMemory())
+ State.addKnownBits(NOT_CAPTURED_IN_MEM);
+
+    // A function cannot communicate state back if it does not throw
+    // exceptions and does not return values.
+ if (F.doesNotThrow() && F.getReturnType()->isVoidTy())
+ State.addKnownBits(NOT_CAPTURED_IN_RET);
+
+ // Check existing "returned" attributes.
+ int ArgNo = IRP.getArgNo();
+ if (F.doesNotThrow() && ArgNo >= 0) {
+      for (unsigned u = 0, e = F.arg_size(); u < e; ++u)
+ if (F.hasParamAttribute(u, Attribute::Returned)) {
+ if (u == unsigned(ArgNo))
+ State.removeAssumedBits(NOT_CAPTURED_IN_RET);
+ else if (F.onlyReadsMemory())
+ State.addKnownBits(NO_CAPTURE);
+ else
+ State.addKnownBits(NOT_CAPTURED_IN_RET);
+ break;
+ }
+ }
+ }
- // Return argument index of associated value.
- int getArgNo() const { return ArgNo; }
+ /// See AbstractState::getAsStr().
+ const std::string getAsStr() const override {
+ if (isKnownNoCapture())
+ return "known not-captured";
+ if (isAssumedNoCapture())
+ return "assumed not-captured";
+ if (isKnownNoCaptureMaybeReturned())
+ return "known not-captured-maybe-returned";
+ if (isAssumedNoCaptureMaybeReturned())
+ return "assumed not-captured-maybe-returned";
+ return "assumed-captured";
+ }
+};
+
+/// Attributor-aware capture tracker.
+struct AACaptureUseTracker final : public CaptureTracker {
+
+ /// Create a capture tracker that can lookup in-flight abstract attributes
+ /// through the Attributor \p A.
+ ///
+  /// If a use leads to a potential capture in memory, the NOT_CAPTURED_IN_MEM
+  /// bit is cleared in \p State and the search is stopped. If a use leads to a
+  /// return instruction, only the NOT_CAPTURED_IN_RET bit is cleared. If a use
+  /// leads to a ptr2int which may capture the value, the NOT_CAPTURED_IN_INT
+  /// bit is cleared. If a use is found that is currently assumed
+  /// "no-capture-maybe-returned", the user is added to the \p PotentialCopies
+  /// set. All values in \p PotentialCopies are later tracked as well. For every
+  /// explored use we decrement \p RemainingUsesToExplore. Once it reaches 0,
+  /// the search is stopped with all capture bits conservatively cleared.
+ AACaptureUseTracker(Attributor &A, AANoCapture &NoCaptureAA,
+ const AAIsDead &IsDeadAA, IntegerState &State,
+ SmallVectorImpl<const Value *> &PotentialCopies,
+ unsigned &RemainingUsesToExplore)
+ : A(A), NoCaptureAA(NoCaptureAA), IsDeadAA(IsDeadAA), State(State),
+ PotentialCopies(PotentialCopies),
+ RemainingUsesToExplore(RemainingUsesToExplore) {}
+
+  /// Determine if \p V may be captured. *Also updates the state!*
+ bool valueMayBeCaptured(const Value *V) {
+ if (V->getType()->isPointerTy()) {
+ PointerMayBeCaptured(V, this);
+ } else {
+ State.indicatePessimisticFixpoint();
+ }
+ return State.isAssumed(AANoCapture::NO_CAPTURE_MAYBE_RETURNED);
+ }
+
+ /// See CaptureTracker::tooManyUses().
+ void tooManyUses() override {
+ State.removeAssumedBits(AANoCapture::NO_CAPTURE);
+ }
+
+ bool isDereferenceableOrNull(Value *O, const DataLayout &DL) override {
+ if (CaptureTracker::isDereferenceableOrNull(O, DL))
+ return true;
+ const auto &DerefAA =
+ A.getAAFor<AADereferenceable>(NoCaptureAA, IRPosition::value(*O));
+ return DerefAA.getAssumedDereferenceableBytes();
+ }
+
+ /// See CaptureTracker::captured(...).
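+  ///
+  /// Conservative uses (e.g., storing the pointer or unknown users) clear all
+  /// capture bits and stop the search, a return only clears the "in return"
+  /// bit, ptr2int results are followed through their users, and call site
+  /// operands defer to the no-capture information of the corresponding call
+  /// site argument (possibly registering the call as a potential copy).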
+ bool captured(const Use *U) override {
+ Instruction *UInst = cast<Instruction>(U->getUser());
+ LLVM_DEBUG(dbgs() << "Check use: " << *U->get() << " in " << *UInst
+ << "\n");
+
+ // Because we may reuse the tracker multiple times we keep track of the
+ // number of explored uses ourselves as well.
+ if (RemainingUsesToExplore-- == 0) {
+ LLVM_DEBUG(dbgs() << " - too many uses to explore!\n");
+ return isCapturedIn(/* Memory */ true, /* Integer */ true,
+ /* Return */ true);
+ }
+
+ // Deal with ptr2int by following uses.
+ if (isa<PtrToIntInst>(UInst)) {
+ LLVM_DEBUG(dbgs() << " - ptr2int assume the worst!\n");
+ return valueMayBeCaptured(UInst);
+ }
+
+ // Explicitly catch return instructions.
+ if (isa<ReturnInst>(UInst))
+ return isCapturedIn(/* Memory */ false, /* Integer */ false,
+ /* Return */ true);
+
+ // For now we only use special logic for call sites. However, the tracker
+ // itself knows about a lot of other non-capturing cases already.
+ CallSite CS(UInst);
+ if (!CS || !CS.isArgOperand(U))
+ return isCapturedIn(/* Memory */ true, /* Integer */ true,
+ /* Return */ true);
+
+ unsigned ArgNo = CS.getArgumentNo(U);
+ const IRPosition &CSArgPos = IRPosition::callsite_argument(CS, ArgNo);
+    // If we have an abstract no-capture attribute for the argument we can use
+ // it to justify a non-capture attribute here. This allows recursion!
+ auto &ArgNoCaptureAA = A.getAAFor<AANoCapture>(NoCaptureAA, CSArgPos);
+ if (ArgNoCaptureAA.isAssumedNoCapture())
+ return isCapturedIn(/* Memory */ false, /* Integer */ false,
+ /* Return */ false);
+ if (ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) {
+ addPotentialCopy(CS);
+ return isCapturedIn(/* Memory */ false, /* Integer */ false,
+ /* Return */ false);
+ }
+
+    // Lastly, we could not find a reason to assume no-capture, so we don't.
+ return isCapturedIn(/* Memory */ true, /* Integer */ true,
+ /* Return */ true);
+ }
+
+ /// Register \p CS as potential copy of the value we are checking.
+ void addPotentialCopy(CallSite CS) {
+ PotentialCopies.push_back(CS.getInstruction());
+ }
+
+ /// See CaptureTracker::shouldExplore(...).
+ bool shouldExplore(const Use *U) override {
+ // Check liveness.
+ return !IsDeadAA.isAssumedDead(cast<Instruction>(U->getUser()));
+ }
+
+ /// Update the state according to \p CapturedInMem, \p CapturedInInt, and
+ /// \p CapturedInRet, then return the appropriate value for use in the
+ /// CaptureTracker::captured() interface.
+ bool isCapturedIn(bool CapturedInMem, bool CapturedInInt,
+ bool CapturedInRet) {
+ LLVM_DEBUG(dbgs() << " - captures [Mem " << CapturedInMem << "|Int "
+ << CapturedInInt << "|Ret " << CapturedInRet << "]\n");
+ if (CapturedInMem)
+ State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_MEM);
+ if (CapturedInInt)
+ State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_INT);
+ if (CapturedInRet)
+ State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_RET);
+ return !State.isAssumed(AANoCapture::NO_CAPTURE_MAYBE_RETURNED);
+ }
private:
- unsigned ArgNo;
+ /// The attributor providing in-flight abstract attributes.
+ Attributor &A;
+
+ /// The abstract attribute currently updated.
+ AANoCapture &NoCaptureAA;
+
+ /// The abstract liveness state.
+ const AAIsDead &IsDeadAA;
+
+ /// The state currently updated.
+ IntegerState &State;
+
+ /// Set of potential copies of the tracked value.
+ SmallVectorImpl<const Value *> &PotentialCopies;
+
+ /// Global counter to limit the number of explored uses.
+ unsigned &RemainingUsesToExplore;
+};
+
+ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) {
+ const IRPosition &IRP = getIRPosition();
+ const Value *V =
+ getArgNo() >= 0 ? IRP.getAssociatedArgument() : &IRP.getAssociatedValue();
+ if (!V)
+ return indicatePessimisticFixpoint();
+
+ const Function *F =
+ getArgNo() >= 0 ? IRP.getAssociatedFunction() : IRP.getAnchorScope();
+ assert(F && "Expected a function!");
+ const IRPosition &FnPos = IRPosition::function(*F);
+ const auto &IsDeadAA = A.getAAFor<AAIsDead>(*this, FnPos);
+
+ AANoCapture::StateType T;
+
+ // Readonly means we cannot capture through memory.
+ const auto &FnMemAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos);
+ if (FnMemAA.isAssumedReadOnly()) {
+ T.addKnownBits(NOT_CAPTURED_IN_MEM);
+ if (FnMemAA.isKnownReadOnly())
+ addKnownBits(NOT_CAPTURED_IN_MEM);
+ }
+
+  // Make sure all returned values are different from the underlying value.
+ // TODO: we could do this in a more sophisticated way inside
+ // AAReturnedValues, e.g., track all values that escape through returns
+ // directly somehow.
+ auto CheckReturnedArgs = [&](const AAReturnedValues &RVAA) {
+ bool SeenConstant = false;
+ for (auto &It : RVAA.returned_values()) {
+ if (isa<Constant>(It.first)) {
+ if (SeenConstant)
+ return false;
+ SeenConstant = true;
+ } else if (!isa<Argument>(It.first) ||
+ It.first == getAssociatedArgument())
+ return false;
+ }
+ return true;
+ };
+
+ const auto &NoUnwindAA = A.getAAFor<AANoUnwind>(*this, FnPos);
+ if (NoUnwindAA.isAssumedNoUnwind()) {
+ bool IsVoidTy = F->getReturnType()->isVoidTy();
+ const AAReturnedValues *RVAA =
+ IsVoidTy ? nullptr : &A.getAAFor<AAReturnedValues>(*this, FnPos);
+ if (IsVoidTy || CheckReturnedArgs(*RVAA)) {
+ T.addKnownBits(NOT_CAPTURED_IN_RET);
+ if (T.isKnown(NOT_CAPTURED_IN_MEM))
+ return ChangeStatus::UNCHANGED;
+ if (NoUnwindAA.isKnownNoUnwind() &&
+ (IsVoidTy || RVAA->getState().isAtFixpoint())) {
+ addKnownBits(NOT_CAPTURED_IN_RET);
+ if (isKnown(NOT_CAPTURED_IN_MEM))
+ return indicateOptimisticFixpoint();
+ }
+ }
+ }
+
+ // Use the CaptureTracker interface and logic with the specialized tracker,
+ // defined in AACaptureUseTracker, that can look at in-flight abstract
+ // attributes and directly updates the assumed state.
+ SmallVector<const Value *, 4> PotentialCopies;
+ unsigned RemainingUsesToExplore = DefaultMaxUsesToExplore;
+ AACaptureUseTracker Tracker(A, *this, IsDeadAA, T, PotentialCopies,
+ RemainingUsesToExplore);
+
+ // Check all potential copies of the associated value until we can assume
+ // none will be captured or we have to assume at least one might be.
+ unsigned Idx = 0;
+ PotentialCopies.push_back(V);
+ while (T.isAssumed(NO_CAPTURE_MAYBE_RETURNED) && Idx < PotentialCopies.size())
+ Tracker.valueMayBeCaptured(PotentialCopies[Idx++]);
+
+  AANoCapture::StateType &S = getState();
+ auto Assumed = S.getAssumed();
+ S.intersectAssumedBits(T.getAssumed());
+ return Assumed == S.getAssumed() ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+}
+
+/// NoCapture attribute for function arguments.
+struct AANoCaptureArgument final : AANoCaptureImpl {
+ AANoCaptureArgument(const IRPosition &IRP) : AANoCaptureImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nocapture) }
+};
+
+/// NoCapture attribute for call site arguments.
+struct AANoCaptureCallSiteArgument final : AANoCaptureImpl {
+ AANoCaptureCallSiteArgument(const IRPosition &IRP) : AANoCaptureImpl(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+    //       sense to specialize attributes for call site arguments instead of
+ // redirecting requests to the callee argument.
+ Argument *Arg = getAssociatedArgument();
+ if (!Arg)
+ return indicatePessimisticFixpoint();
+ const IRPosition &ArgPos = IRPosition::argument(*Arg);
+ auto &ArgAA = A.getAAFor<AANoCapture>(*this, ArgPos);
+ return clampStateAndIndicateChange(
+ getState(),
+ static_cast<const AANoCapture::StateType &>(ArgAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override{STATS_DECLTRACK_CSARG_ATTR(nocapture)};
+};
+
+/// NoCapture attribute for floating values.
+struct AANoCaptureFloating final : AANoCaptureImpl {
+ AANoCaptureFloating(const IRPosition &IRP) : AANoCaptureImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(nocapture)
+ }
+};
+
+/// NoCapture attribute for function return value.
+struct AANoCaptureReturned final : AANoCaptureImpl {
+ AANoCaptureReturned(const IRPosition &IRP) : AANoCaptureImpl(IRP) {
+ llvm_unreachable("NoCapture is not applicable to function returns!");
+ }
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ llvm_unreachable("NoCapture is not applicable to function returns!");
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ llvm_unreachable("NoCapture is not applicable to function returns!");
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {}
+};
+
+/// NoCapture attribute deduction for a call site return value.
+struct AANoCaptureCallSiteReturned final : AANoCaptureImpl {
+ AANoCaptureCallSiteReturned(const IRPosition &IRP) : AANoCaptureImpl(IRP) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSRET_ATTR(nocapture)
+ }
};
-ChangeStatus AANonNullArgument::updateImpl(Attributor &A) {
- Function &F = getAnchorScope();
- Argument &Arg = cast<Argument>(getAnchoredValue());
- unsigned ArgNo = Arg.getArgNo();
+/// ------------------ Value Simplify Attribute ----------------------------
+struct AAValueSimplifyImpl : AAValueSimplify {
+ AAValueSimplifyImpl(const IRPosition &IRP) : AAValueSimplify(IRP) {}
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr() const override {
+ return getAssumed() ? (getKnown() ? "simplified" : "maybe-simple")
+ : "not-simple";
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {}
+
+ /// See AAValueSimplify::getAssumedSimplifiedValue()
+ Optional<Value *> getAssumedSimplifiedValue(Attributor &A) const override {
+ if (!getAssumed())
+ return const_cast<Value *>(&getAssociatedValue());
+ return SimplifiedAssociatedValue;
+ }
+ void initialize(Attributor &A) override {}
+
+  /// Helper function for querying AAValueSimplify and updating the candidate.
+ /// \param QueryingValue Value trying to unify with SimplifiedValue
+ /// \param AccumulatedSimplifiedValue Current simplification result.
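+  ///
+  /// For example, if every queried value simplifies to the constant 42, the
+  /// accumulated value stays 42 and manifest(...) can later replace all uses
+  /// of the associated value with that constant.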
+ static bool checkAndUpdate(Attributor &A, const AbstractAttribute &QueryingAA,
+ Value &QueryingValue,
+ Optional<Value *> &AccumulatedSimplifiedValue) {
+ // FIXME: Add a typecast support.
+
+ auto &ValueSimpifyAA = A.getAAFor<AAValueSimplify>(
+ QueryingAA, IRPosition::value(QueryingValue));
- // Callback function
- std::function<bool(CallSite)> CallSiteCheck = [&](CallSite CS) {
- assert(CS && "Sanity check: Call site was not initialized properly!");
+ Optional<Value *> QueryingValueSimplified =
+ ValueSimpifyAA.getAssumedSimplifiedValue(A);
- auto *NonNullAA = A.getAAFor<AANonNull>(*this, *CS.getInstruction(), ArgNo);
+ if (!QueryingValueSimplified.hasValue())
+ return true;
- // Check that NonNullAA is AANonNullCallSiteArgument.
- if (NonNullAA) {
- ImmutableCallSite ICS(&NonNullAA->getAnchoredValue());
- if (ICS && CS.getInstruction() == ICS.getInstruction())
- return NonNullAA->isAssumedNonNull();
+ if (!QueryingValueSimplified.getValue())
return false;
+
+ Value &QueryingValueSimplifiedUnwrapped =
+ *QueryingValueSimplified.getValue();
+
+ if (isa<UndefValue>(QueryingValueSimplifiedUnwrapped))
+ return true;
+
+ if (AccumulatedSimplifiedValue.hasValue())
+ return AccumulatedSimplifiedValue == QueryingValueSimplified;
+
+ LLVM_DEBUG(dbgs() << "[Attributor][ValueSimplify] " << QueryingValue
+ << " is assumed to be "
+ << QueryingValueSimplifiedUnwrapped << "\n");
+
+ AccumulatedSimplifiedValue = QueryingValueSimplified;
+ return true;
+ }
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+ if (!SimplifiedAssociatedValue.hasValue() ||
+ !SimplifiedAssociatedValue.getValue())
+ return Changed;
+
+ if (auto *C = dyn_cast<Constant>(SimplifiedAssociatedValue.getValue())) {
+ // We can replace the AssociatedValue with the constant.
+ Value &V = getAssociatedValue();
+ if (!V.user_empty() && &V != C && V.getType() == C->getType()) {
+ LLVM_DEBUG(dbgs() << "[Attributor][ValueSimplify] " << V << " -> " << *C
+ << "\n");
+ V.replaceAllUsesWith(C);
+ Changed = ChangeStatus::CHANGED;
+ }
+ }
+
+ return Changed | AAValueSimplify::manifest(A);
+ }
+
+protected:
+  // An assumed simplified value. Initially, it is set to Optional::None, which
+  // means that the value is not clear under the current assumption. If in the
+  // pessimistic state, getAssumedSimplifiedValue doesn't return this value but
+  // returns the original associated value.
+ Optional<Value *> SimplifiedAssociatedValue;
+};
+
+struct AAValueSimplifyArgument final : AAValueSimplifyImpl {
+ AAValueSimplifyArgument(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ bool HasValueBefore = SimplifiedAssociatedValue.hasValue();
+
+ auto PredForCallSite = [&](AbstractCallSite ACS) {
+ // Check if we have an associated argument or not (which can happen for
+ // callback calls).
+ if (Value *ArgOp = ACS.getCallArgOperand(getArgNo()))
+ return checkAndUpdate(A, *this, *ArgOp, SimplifiedAssociatedValue);
+ return false;
+ };
+
+ if (!A.checkForAllCallSites(PredForCallSite, *this, true))
+ return indicatePessimisticFixpoint();
+
+    // If a candidate was found in this update, return CHANGED.
+    return HasValueBefore == SimplifiedAssociatedValue.hasValue()
+               ? ChangeStatus::UNCHANGED
+               : ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_ARG_ATTR(value_simplify)
+ }
+};
+
+struct AAValueSimplifyReturned : AAValueSimplifyImpl {
+ AAValueSimplifyReturned(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ bool HasValueBefore = SimplifiedAssociatedValue.hasValue();
+
+ auto PredForReturned = [&](Value &V) {
+ return checkAndUpdate(A, *this, V, SimplifiedAssociatedValue);
+ };
+
+ if (!A.checkForAllReturnedValues(PredForReturned, *this))
+ return indicatePessimisticFixpoint();
+
+    // If a candidate was found in this update, return CHANGED.
+    return HasValueBefore == SimplifiedAssociatedValue.hasValue()
+               ? ChangeStatus::UNCHANGED
+               : ChangeStatus::CHANGED;
+ }
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FNRET_ATTR(value_simplify)
+ }
+};
+
+struct AAValueSimplifyFloating : AAValueSimplifyImpl {
+ AAValueSimplifyFloating(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ Value &V = getAnchorValue();
+
+    // TODO: add other cases
+ if (isa<Constant>(V) || isa<UndefValue>(V))
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ bool HasValueBefore = SimplifiedAssociatedValue.hasValue();
+
+ auto VisitValueCB = [&](Value &V, BooleanState, bool Stripped) -> bool {
+ auto &AA = A.getAAFor<AAValueSimplify>(*this, IRPosition::value(V));
+ if (!Stripped && this == &AA) {
+        // TODO: Look at the instruction and check recursively.
+ LLVM_DEBUG(
+ dbgs() << "[Attributor][ValueSimplify] Can't be stripped more : "
+ << V << "\n");
+ indicatePessimisticFixpoint();
+ return false;
+ }
+ return checkAndUpdate(A, *this, V, SimplifiedAssociatedValue);
+ };
+
+ if (!genericValueTraversal<AAValueSimplify, BooleanState>(
+ A, getIRPosition(), *this, static_cast<BooleanState &>(*this),
+ VisitValueCB))
+ return indicatePessimisticFixpoint();
+
+    // If a candidate was found in this update, return CHANGED.
+    return HasValueBefore == SimplifiedAssociatedValue.hasValue()
+               ? ChangeStatus::UNCHANGED
+               : ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(value_simplify)
+ }
+};
+
+struct AAValueSimplifyFunction : AAValueSimplifyImpl {
+ AAValueSimplifyFunction(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ SimplifiedAssociatedValue = &getAnchorValue();
+ indicateOptimisticFixpoint();
+ }
+ /// See AbstractAttribute::initialize(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ llvm_unreachable(
+ "AAValueSimplify(Function|CallSite)::updateImpl will not be called");
+ }
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FN_ATTR(value_simplify)
+ }
+};
+
+struct AAValueSimplifyCallSite : AAValueSimplifyFunction {
+ AAValueSimplifyCallSite(const IRPosition &IRP)
+ : AAValueSimplifyFunction(IRP) {}
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CS_ATTR(value_simplify)
+ }
+};
+
+struct AAValueSimplifyCallSiteReturned : AAValueSimplifyReturned {
+ AAValueSimplifyCallSiteReturned(const IRPosition &IRP)
+ : AAValueSimplifyReturned(IRP) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSRET_ATTR(value_simplify)
+ }
+};
+struct AAValueSimplifyCallSiteArgument : AAValueSimplifyFloating {
+ AAValueSimplifyCallSiteArgument(const IRPosition &IRP)
+ : AAValueSimplifyFloating(IRP) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSARG_ATTR(value_simplify)
+ }
+};
+
+/// ----------------------- Heap-To-Stack Conversion ---------------------------
+struct AAHeapToStackImpl : public AAHeapToStack {
+ AAHeapToStackImpl(const IRPosition &IRP) : AAHeapToStack(IRP) {}
+
+ const std::string getAsStr() const override {
+ return "[H2S] Mallocs: " + std::to_string(MallocCalls.size());
+ }
+
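+  /// For each replaceable allocation call, create an alloca of the requested
+  /// size, rewrite all uses, and erase the associated free calls. For example,
+  /// assuming the size fits the heap-to-stack limit,
+  ///   %m = call i8* @malloc(i64 16)   ...   call void @free(i8* %m)
+  /// becomes
+  ///   %m = alloca i8, i64 16
+  /// with the free erased; calloc'ed memory is additionally zeroed via memset.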
+ ChangeStatus manifest(Attributor &A) override {
+ assert(getState().isValidState() &&
+ "Attempted to manifest an invalid state!");
+
+ ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
+ Function *F = getAssociatedFunction();
+ const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F);
+
+ for (Instruction *MallocCall : MallocCalls) {
+ // This malloc cannot be replaced.
+ if (BadMallocCalls.count(MallocCall))
+ continue;
+
+ for (Instruction *FreeCall : FreesForMalloc[MallocCall]) {
+ LLVM_DEBUG(dbgs() << "H2S: Removing free call: " << *FreeCall << "\n");
+ A.deleteAfterManifest(*FreeCall);
+ HasChanged = ChangeStatus::CHANGED;
+ }
+
+ LLVM_DEBUG(dbgs() << "H2S: Removing malloc call: " << *MallocCall
+ << "\n");
+
+ Constant *Size;
+ if (isCallocLikeFn(MallocCall, TLI)) {
+ auto *Num = cast<ConstantInt>(MallocCall->getOperand(0));
+ auto *SizeT = dyn_cast<ConstantInt>(MallocCall->getOperand(1));
+ APInt TotalSize = SizeT->getValue() * Num->getValue();
+ Size =
+ ConstantInt::get(MallocCall->getOperand(0)->getType(), TotalSize);
+ } else {
+ Size = cast<ConstantInt>(MallocCall->getOperand(0));
+ }
+
+ unsigned AS = cast<PointerType>(MallocCall->getType())->getAddressSpace();
+ Instruction *AI = new AllocaInst(Type::getInt8Ty(F->getContext()), AS,
+ Size, "", MallocCall->getNextNode());
+
+ if (AI->getType() != MallocCall->getType())
+ AI = new BitCastInst(AI, MallocCall->getType(), "malloc_bc",
+ AI->getNextNode());
+
+ MallocCall->replaceAllUsesWith(AI);
+
+ if (auto *II = dyn_cast<InvokeInst>(MallocCall)) {
+ auto *NBB = II->getNormalDest();
+ BranchInst::Create(NBB, MallocCall->getParent());
+ A.deleteAfterManifest(*MallocCall);
+ } else {
+ A.deleteAfterManifest(*MallocCall);
+ }
+
+ if (isCallocLikeFn(MallocCall, TLI)) {
+ auto *BI = new BitCastInst(AI, MallocCall->getType(), "calloc_bc",
+ AI->getNextNode());
+ Value *Ops[] = {
+ BI, ConstantInt::get(F->getContext(), APInt(8, 0, false)), Size,
+ ConstantInt::get(Type::getInt1Ty(F->getContext()), false)};
+
+ Type *Tys[] = {BI->getType(), MallocCall->getOperand(0)->getType()};
+ Module *M = F->getParent();
+ Function *Fn = Intrinsic::getDeclaration(M, Intrinsic::memset, Tys);
+ CallInst::Create(Fn, Ops, "", BI->getNextNode());
+ }
+ HasChanged = ChangeStatus::CHANGED;
}
- if (CS.paramHasAttr(ArgNo, Attribute::NonNull))
+ return HasChanged;
+ }
+
+ /// Collection of all malloc calls in a function.
+ SmallSetVector<Instruction *, 4> MallocCalls;
+
+ /// Collection of malloc calls that cannot be converted.
+ DenseSet<const Instruction *> BadMallocCalls;
+
+ /// A map for each malloc call to the set of associated free calls.
+ DenseMap<Instruction *, SmallPtrSet<Instruction *, 4>> FreesForMalloc;
+
+ ChangeStatus updateImpl(Attributor &A) override;
+};
+
+ChangeStatus AAHeapToStackImpl::updateImpl(Attributor &A) {
+ const Function *F = getAssociatedFunction();
+ const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F);
+
+ auto UsesCheck = [&](Instruction &I) {
+ SmallPtrSet<const Use *, 8> Visited;
+ SmallVector<const Use *, 8> Worklist;
+
+ for (Use &U : I.uses())
+ Worklist.push_back(&U);
+
+ while (!Worklist.empty()) {
+ const Use *U = Worklist.pop_back_val();
+ if (!Visited.insert(U).second)
+ continue;
+
+ auto *UserI = U->getUser();
+
+ if (isa<LoadInst>(UserI))
+ continue;
+ if (auto *SI = dyn_cast<StoreInst>(UserI)) {
+ if (SI->getValueOperand() == U->get()) {
+ LLVM_DEBUG(dbgs() << "[H2S] escaping store to memory: " << *UserI << "\n");
+ return false;
+ }
+ // A store into the malloc'ed memory is fine.
+ continue;
+ }
+
+      // NOTE: Right now, if a function that takes the malloc'ed pointer as an
+      // argument frees memory, we assume that the malloc'ed pointer is freed.
+
+ // TODO: Add nofree callsite argument attribute to indicate that pointer
+ // argument is not freed.
+ if (auto *CB = dyn_cast<CallBase>(UserI)) {
+ if (!CB->isArgOperand(U))
+ continue;
+
+ if (CB->isLifetimeStartOrEnd())
+ continue;
+
+        // Record the free call for this malloc-like allocation.
+ if (isFreeCall(UserI, TLI)) {
+ FreesForMalloc[&I].insert(
+ cast<Instruction>(const_cast<User *>(UserI)));
+ continue;
+ }
+
+        // If a function does not free memory, we are fine.
+ const auto &NoFreeAA =
+ A.getAAFor<AANoFree>(*this, IRPosition::callsite_function(*CB));
+
+ unsigned ArgNo = U - CB->arg_begin();
+ const auto &NoCaptureAA = A.getAAFor<AANoCapture>(
+ *this, IRPosition::callsite_argument(*CB, ArgNo));
+
+ if (!NoCaptureAA.isAssumedNoCapture() || !NoFreeAA.isAssumedNoFree()) {
+ LLVM_DEBUG(dbgs() << "[H2S] Bad user: " << *UserI << "\n");
+ return false;
+ }
+ continue;
+ }
+
+ if (isa<GetElementPtrInst>(UserI) || isa<BitCastInst>(UserI)) {
+ for (Use &U : UserI->uses())
+ Worklist.push_back(&U);
+ continue;
+ }
+
+ // Unknown user.
+ LLVM_DEBUG(dbgs() << "[H2S] Unknown user: " << *UserI << "\n");
+ return false;
+ }
+ return true;
+ };
+
+ auto MallocCallocCheck = [&](Instruction &I) {
+ if (BadMallocCalls.count(&I))
return true;
- Value *V = CS.getArgOperand(ArgNo);
- if (isKnownNonZero(V, getAnchorScope().getParent()->getDataLayout()))
+ bool IsMalloc = isMallocLikeFn(&I, TLI);
+ bool IsCalloc = !IsMalloc && isCallocLikeFn(&I, TLI);
+ if (!IsMalloc && !IsCalloc) {
+ BadMallocCalls.insert(&I);
return true;
+ }
- return false;
- };
- if (!A.checkForAllCallSites(F, CallSiteCheck, true)) {
- indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
- }
- return ChangeStatus::UNCHANGED;
-}
+ if (IsMalloc) {
+ if (auto *Size = dyn_cast<ConstantInt>(I.getOperand(0)))
+ if (Size->getValue().sle(MaxHeapToStackSize))
+ if (UsesCheck(I)) {
+ MallocCalls.insert(&I);
+ return true;
+ }
+ } else if (IsCalloc) {
+ bool Overflow = false;
+ if (auto *Num = dyn_cast<ConstantInt>(I.getOperand(0)))
+ if (auto *Size = dyn_cast<ConstantInt>(I.getOperand(1)))
+ if ((Size->getValue().umul_ov(Num->getValue(), Overflow))
+ .sle(MaxHeapToStackSize))
+ if (!Overflow && UsesCheck(I)) {
+ MallocCalls.insert(&I);
+ return true;
+ }
+ }
-ChangeStatus AANonNullCallSiteArgument::updateImpl(Attributor &A) {
- // NOTE: Never look at the argument of the callee in this method.
- // If we do this, "nonnull" is always deduced because of the assumption.
+ BadMallocCalls.insert(&I);
+ return true;
+ };
- Value &V = *getAssociatedValue();
+ size_t NumBadMallocs = BadMallocCalls.size();
- auto *NonNullAA = A.getAAFor<AANonNull>(*this, V);
+ A.checkForAllCallLikeInstructions(MallocCallocCheck, *this);
- if (!NonNullAA || !NonNullAA->isAssumedNonNull()) {
- indicatePessimisticFixpoint();
+ if (NumBadMallocs != BadMallocCalls.size())
return ChangeStatus::CHANGED;
- }
return ChangeStatus::UNCHANGED;
}
-/// ------------------------ Will-Return Attributes ----------------------------
+struct AAHeapToStackFunction final : public AAHeapToStackImpl {
+ AAHeapToStackFunction(const IRPosition &IRP) : AAHeapToStackImpl(IRP) {}
-struct AAWillReturnImpl : public AAWillReturn, BooleanState {
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECL(MallocCalls, Function,
+ "Number of MallocCalls converted to allocas");
+ BUILD_STAT_NAME(MallocCalls, Function) += MallocCalls.size();
+ }
+};
+
+/// -------------------- Memory Behavior Attributes ----------------------------
+/// Includes read-none, read-only, and write-only.
+/// ----------------------------------------------------------------------------
+struct AAMemoryBehaviorImpl : public AAMemoryBehavior {
+ AAMemoryBehaviorImpl(const IRPosition &IRP) : AAMemoryBehavior(IRP) {}
- /// See AbstractAttribute::AbstractAttribute(...).
- AAWillReturnImpl(Function &F, InformationCache &InfoCache)
- : AAWillReturn(F, InfoCache) {}
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ intersectAssumedBits(BEST_STATE);
+ getKnownStateFromValue(getIRPosition(), getState());
+ IRAttribute::initialize(A);
+ }
- /// See AAWillReturn::isKnownWillReturn().
- bool isKnownWillReturn() const override { return getKnown(); }
+ /// Return the memory behavior information encoded in the IR for \p IRP.
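+  ///
+  /// For example, an existing "readonly" attribute makes NO_WRITES known and
+  /// "readnone" makes NO_ACCESSES known; for anchored instructions,
+  /// mayReadFromMemory()/mayWriteToMemory() are consulted as well.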
+ static void getKnownStateFromValue(const IRPosition &IRP,
+ IntegerState &State) {
+ SmallVector<Attribute, 2> Attrs;
+ IRP.getAttrs(AttrKinds, Attrs);
+ for (const Attribute &Attr : Attrs) {
+ switch (Attr.getKindAsEnum()) {
+ case Attribute::ReadNone:
+ State.addKnownBits(NO_ACCESSES);
+ break;
+ case Attribute::ReadOnly:
+ State.addKnownBits(NO_WRITES);
+ break;
+ case Attribute::WriteOnly:
+ State.addKnownBits(NO_READS);
+ break;
+ default:
+        llvm_unreachable("Unexpected attribute!");
+ }
+ }
- /// See AAWillReturn::isAssumedWillReturn().
- bool isAssumedWillReturn() const override { return getAssumed(); }
+ if (auto *I = dyn_cast<Instruction>(&IRP.getAnchorValue())) {
+ if (!I->mayReadFromMemory())
+ State.addKnownBits(NO_READS);
+ if (!I->mayWriteToMemory())
+ State.addKnownBits(NO_WRITES);
+ }
+ }
- /// See AbstractAttribute::getState(...).
- AbstractState &getState() override { return *this; }
+ /// See AbstractAttribute::getDeducedAttributes(...).
+ void getDeducedAttributes(LLVMContext &Ctx,
+ SmallVectorImpl<Attribute> &Attrs) const override {
+ assert(Attrs.size() == 0);
+ if (isAssumedReadNone())
+ Attrs.push_back(Attribute::get(Ctx, Attribute::ReadNone));
+ else if (isAssumedReadOnly())
+ Attrs.push_back(Attribute::get(Ctx, Attribute::ReadOnly));
+ else if (isAssumedWriteOnly())
+ Attrs.push_back(Attribute::get(Ctx, Attribute::WriteOnly));
+ assert(Attrs.size() <= 1);
+ }
- /// See AbstractAttribute::getState(...).
- const AbstractState &getState() const override { return *this; }
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ IRPosition &IRP = getIRPosition();
+
+ // Check if we would improve the existing attributes first.
+ SmallVector<Attribute, 4> DeducedAttrs;
+ getDeducedAttributes(IRP.getAnchorValue().getContext(), DeducedAttrs);
+ if (llvm::all_of(DeducedAttrs, [&](const Attribute &Attr) {
+ return IRP.hasAttr(Attr.getKindAsEnum(),
+ /* IgnoreSubsumingPositions */ true);
+ }))
+ return ChangeStatus::UNCHANGED;
+
+ // Clear existing attributes.
+ IRP.removeAttrs(AttrKinds);
+
+ // Use the generic manifest method.
+ return IRAttribute::manifest(A);
+ }
- /// See AbstractAttribute::getAsStr()
+ /// See AbstractState::getAsStr().
const std::string getAsStr() const override {
- return getAssumed() ? "willreturn" : "may-noreturn";
+ if (isAssumedReadNone())
+ return "readnone";
+ if (isAssumedReadOnly())
+ return "readonly";
+ if (isAssumedWriteOnly())
+ return "writeonly";
+ return "may-read/write";
}
+
+ /// The set of IR attributes AAMemoryBehavior deals with.
+ static const Attribute::AttrKind AttrKinds[3];
};
-struct AAWillReturnFunction final : AAWillReturnImpl {
+const Attribute::AttrKind AAMemoryBehaviorImpl::AttrKinds[] = {
+ Attribute::ReadNone, Attribute::ReadOnly, Attribute::WriteOnly};
- /// See AbstractAttribute::AbstractAttribute(...).
- AAWillReturnFunction(Function &F, InformationCache &InfoCache)
- : AAWillReturnImpl(F, InfoCache) {}
+/// Memory behavior attribute for a floating value.
+struct AAMemoryBehaviorFloating : AAMemoryBehaviorImpl {
+ AAMemoryBehaviorFloating(const IRPosition &IRP) : AAMemoryBehaviorImpl(IRP) {}
- /// See AbstractAttribute::getManifestPosition().
- ManifestPosition getManifestPosition() const override {
- return MP_FUNCTION;
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AAMemoryBehaviorImpl::initialize(A);
+ // Initialize the use vector with all direct uses of the associated value.
+ for (const Use &U : getAssociatedValue().uses())
+ Uses.insert(&U);
}
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override;
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ if (isAssumedReadNone())
+ STATS_DECLTRACK_FLOATING_ATTR(readnone)
+ else if (isAssumedReadOnly())
+ STATS_DECLTRACK_FLOATING_ATTR(readonly)
+ else if (isAssumedWriteOnly())
+ STATS_DECLTRACK_FLOATING_ATTR(writeonly)
+ }
+
+private:
+ /// Return true if users of \p UserI might access the underlying
+ /// variable/location described by \p U and should therefore be analyzed.
+ bool followUsersOfUseIn(Attributor &A, const Use *U,
+ const Instruction *UserI);
+
+ /// Update the state according to the effect of use \p U in \p UserI.
+ void analyzeUseIn(Attributor &A, const Use *U, const Instruction *UserI);
+
+protected:
+ /// Container for (transitive) uses of the associated argument.
+ SetVector<const Use *> Uses;
+};
+
+/// Memory behavior attribute for function argument.
+struct AAMemoryBehaviorArgument : AAMemoryBehaviorFloating {
+ AAMemoryBehaviorArgument(const IRPosition &IRP)
+ : AAMemoryBehaviorFloating(IRP) {}
+
/// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override;
+ void initialize(Attributor &A) override {
+ AAMemoryBehaviorFloating::initialize(A);
+
+    // The base class already initialized the use vector; here we only have to
+    // give up on arguments of functions without an exact definition.
+ Argument *Arg = getAssociatedArgument();
+ if (!Arg || !Arg->getParent()->hasExactDefinition())
+ indicatePessimisticFixpoint();
+ }
+
+ ChangeStatus manifest(Attributor &A) override {
+ // TODO: From readattrs.ll: "inalloca parameters are always
+ // considered written"
+ if (hasAttr({Attribute::InAlloca})) {
+ removeKnownBits(NO_WRITES);
+ removeAssumedBits(NO_WRITES);
+ }
+ return AAMemoryBehaviorFloating::manifest(A);
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ if (isAssumedReadNone())
+ STATS_DECLTRACK_ARG_ATTR(readnone)
+ else if (isAssumedReadOnly())
+ STATS_DECLTRACK_ARG_ATTR(readonly)
+ else if (isAssumedWriteOnly())
+ STATS_DECLTRACK_ARG_ATTR(writeonly)
+ }
+};
+
+struct AAMemoryBehaviorCallSiteArgument final : AAMemoryBehaviorArgument {
+ AAMemoryBehaviorCallSiteArgument(const IRPosition &IRP)
+ : AAMemoryBehaviorArgument(IRP) {}
/// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override;
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+    //       call site specific liveness information and then it makes
+    //       sense to specialize attributes for call site arguments instead of
+ // redirecting requests to the callee argument.
+ Argument *Arg = getAssociatedArgument();
+ const IRPosition &ArgPos = IRPosition::argument(*Arg);
+ auto &ArgAA = A.getAAFor<AAMemoryBehavior>(*this, ArgPos);
+ return clampStateAndIndicateChange(
+ getState(),
+        static_cast<const AAMemoryBehavior::StateType &>(ArgAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ if (isAssumedReadNone())
+ STATS_DECLTRACK_CSARG_ATTR(readnone)
+ else if (isAssumedReadOnly())
+ STATS_DECLTRACK_CSARG_ATTR(readonly)
+ else if (isAssumedWriteOnly())
+ STATS_DECLTRACK_CSARG_ATTR(writeonly)
+ }
};
-// Helper function that checks whether a function has any cycle.
-// TODO: Replace with more efficent code
-bool containsCycle(Function &F) {
- SmallPtrSet<BasicBlock *, 32> Visited;
+/// Memory behavior attribute for a call site return position.
+struct AAMemoryBehaviorCallSiteReturned final : AAMemoryBehaviorFloating {
+ AAMemoryBehaviorCallSiteReturned(const IRPosition &IRP)
+ : AAMemoryBehaviorFloating(IRP) {}
- // Traverse BB by dfs and check whether successor is already visited.
- for (BasicBlock *BB : depth_first(&F)) {
- Visited.insert(BB);
- for (auto *SuccBB : successors(BB)) {
- if (Visited.count(SuccBB))
- return true;
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ // We do not annotate returned values.
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {}
+};
+
+/// An AA to represent the memory behavior function attributes.
+struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl {
+ AAMemoryBehaviorFunction(const IRPosition &IRP) : AAMemoryBehaviorImpl(IRP) {}
+
+ /// See AbstractAttribute::updateImpl(Attributor &A).
+ virtual ChangeStatus updateImpl(Attributor &A) override;
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ Function &F = cast<Function>(getAnchorValue());
+ if (isAssumedReadNone()) {
+ F.removeFnAttr(Attribute::ArgMemOnly);
+ F.removeFnAttr(Attribute::InaccessibleMemOnly);
+ F.removeFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
}
+ return AAMemoryBehaviorImpl::manifest(A);
}
- return false;
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ if (isAssumedReadNone())
+ STATS_DECLTRACK_FN_ATTR(readnone)
+ else if (isAssumedReadOnly())
+ STATS_DECLTRACK_FN_ATTR(readonly)
+ else if (isAssumedWriteOnly())
+ STATS_DECLTRACK_FN_ATTR(writeonly)
+ }
+};
+
+/// AAMemoryBehavior attribute for call sites.
+struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl {
+ AAMemoryBehaviorCallSite(const IRPosition &IRP) : AAMemoryBehaviorImpl(IRP) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ AAMemoryBehaviorImpl::initialize(A);
+ Function *F = getAssociatedFunction();
+ if (!F || !F->hasExactDefinition())
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+    // TODO: Once we have call site specific value information we can provide
+    // call site specific liveness information and then it makes sense to
+    // specialize attributes for call sites instead of redirecting requests to
+    // the callee.
+ Function *F = getAssociatedFunction();
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto &FnAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos);
+ return clampStateAndIndicateChange(
+        getState(),
+        static_cast<const AAMemoryBehavior::StateType &>(FnAA.getState()));
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ if (isAssumedReadNone())
+ STATS_DECLTRACK_CS_ATTR(readnone)
+ else if (isAssumedReadOnly())
+ STATS_DECLTRACK_CS_ATTR(readonly)
+ else if (isAssumedWriteOnly())
+ STATS_DECLTRACK_CS_ATTR(writeonly)
+ }
+};
+} // namespace
+
+ChangeStatus AAMemoryBehaviorFunction::updateImpl(Attributor &A) {
+
+ // The current assumed state used to determine a change.
+ auto AssumedState = getAssumed();
+
+ auto CheckRWInst = [&](Instruction &I) {
+    // If the instruction has its own memory behavior state, use it to
+    // restrict the local state. No further analysis is required as the other
+    // memory state is as optimistic as it gets.
+ if (ImmutableCallSite ICS = ImmutableCallSite(&I)) {
+ const auto &MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(
+ *this, IRPosition::callsite_function(ICS));
+ intersectAssumedBits(MemBehaviorAA.getAssumed());
+ return !isAtFixpoint();
+ }
+
+ // Remove access kind modifiers if necessary.
+ if (I.mayReadFromMemory())
+ removeAssumedBits(NO_READS);
+ if (I.mayWriteToMemory())
+ removeAssumedBits(NO_WRITES);
+ return !isAtFixpoint();
+ };
+
+ if (!A.checkForAllReadWriteInstructions(CheckRWInst, *this))
+ return indicatePessimisticFixpoint();
+
+ return (AssumedState != getAssumed()) ? ChangeStatus::CHANGED
+ : ChangeStatus::UNCHANGED;
}
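
As a side note for readers, the CheckRWInst callback above only ever drops bits from the optimistic start state. The following is a minimal standalone sketch of that bit lattice (an assumed simplification for illustration; the real class is the Attributor's IntegerState, which is not part of this hunk):

#include <cassert>
#include <cstdint>

// Toy version of the assumed/known bit lattice: "assumed" starts optimistic
// and can only lose bits, "known" can only gain bits, and a fixpoint is
// reached when both agree.
struct ToyMemState {
  static constexpr uint8_t NO_READS = 1 << 0;
  static constexpr uint8_t NO_WRITES = 1 << 1;
  uint8_t Known = 0;
  uint8_t Assumed = NO_READS | NO_WRITES;

  void removeAssumedBits(uint8_t Bits) { Assumed &= ~Bits; }
  void intersectAssumedBits(uint8_t Bits) { Assumed = (Assumed & Bits) | Known; }
  bool isAtFixpoint() const { return Known == Assumed; }
};

int main() {
  ToyMemState S;
  S.removeAssumedBits(ToyMemState::NO_WRITES); // a store was seen
  assert(S.Assumed == ToyMemState::NO_READS);  // still assumed not to read
  return 0;
}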
-// Helper function that checks the function have a loop which might become an
-// endless loop
-// FIXME: Any cycle is regarded as endless loop for now.
-// We have to allow some patterns.
-bool containsPossiblyEndlessLoop(Function &F) { return containsCycle(F); }
+ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) {
-void AAWillReturnFunction::initialize(Attributor &A) {
- Function &F = getAnchorScope();
+ const IRPosition &IRP = getIRPosition();
+ const IRPosition &FnPos = IRPosition::function_scope(IRP);
+ AAMemoryBehavior::StateType &S = getState();
- if (containsPossiblyEndlessLoop(F))
- indicatePessimisticFixpoint();
+ // First, check the function scope. We take the known information and we avoid
+ // work if the assumed information implies the current assumed information for
+ // this attribute.
+ const auto &FnMemAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos);
+ S.addKnownBits(FnMemAA.getKnown());
+ if ((S.getAssumed() & FnMemAA.getAssumed()) == S.getAssumed())
+ return ChangeStatus::UNCHANGED;
+
+  // Make sure the value is not captured (except through "return"); if it is,
+  // any information derived would be irrelevant anyway as we cannot check the
+  // potential aliases introduced by the capture. However, no need to fall
+  // back to anything less optimistic than the function state.
+ const auto &ArgNoCaptureAA = A.getAAFor<AANoCapture>(*this, IRP);
+ if (!ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) {
+ S.intersectAssumedBits(FnMemAA.getAssumed());
+ return ChangeStatus::CHANGED;
+ }
+
+ // The current assumed state used to determine a change.
+ auto AssumedState = S.getAssumed();
+
+ // Liveness information to exclude dead users.
+ // TODO: Take the FnPos once we have call site specific liveness information.
+ const auto &LivenessAA = A.getAAFor<AAIsDead>(
+ *this, IRPosition::function(*IRP.getAssociatedFunction()));
+
+ // Visit and expand uses until all are analyzed or a fixpoint is reached.
+ for (unsigned i = 0; i < Uses.size() && !isAtFixpoint(); i++) {
+ const Use *U = Uses[i];
+ Instruction *UserI = cast<Instruction>(U->getUser());
+ LLVM_DEBUG(dbgs() << "[AAMemoryBehavior] Use: " << **U << " in " << *UserI
+ << " [Dead: " << (LivenessAA.isAssumedDead(UserI))
+ << "]\n");
+ if (LivenessAA.isAssumedDead(UserI))
+ continue;
+
+ // Check if the users of UserI should also be visited.
+ if (followUsersOfUseIn(A, U, UserI))
+ for (const Use &UserIUse : UserI->uses())
+ Uses.insert(&UserIUse);
+
+ // If UserI might touch memory we analyze the use in detail.
+ if (UserI->mayReadOrWriteMemory())
+ analyzeUseIn(A, U, UserI);
+ }
+
+ return (AssumedState != getAssumed()) ? ChangeStatus::CHANGED
+ : ChangeStatus::UNCHANGED;
}
-ChangeStatus AAWillReturnFunction::updateImpl(Attributor &A) {
- Function &F = getAnchorScope();
+bool AAMemoryBehaviorFloating::followUsersOfUseIn(Attributor &A, const Use *U,
+ const Instruction *UserI) {
+  // The loaded value is unrelated to the pointer argument; no need to
+ // follow the users of the load.
+ if (isa<LoadInst>(UserI))
+ return false;
- // The map from instruction opcodes to those instructions in the function.
- auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
+ // By default we follow all uses assuming UserI might leak information on U,
+ // we have special handling for call sites operands though.
+ ImmutableCallSite ICS(UserI);
+ if (!ICS || !ICS.isArgOperand(U))
+ return true;
- for (unsigned Opcode :
- {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
- (unsigned)Instruction::Call}) {
- for (Instruction *I : OpcodeInstMap[Opcode]) {
- auto ICS = ImmutableCallSite(I);
+ // If the use is a call argument known not to be captured, the users of
+ // the call do not need to be visited because they have to be unrelated to
+ // the input. Note that this check is not trivial even though we disallow
+ // general capturing of the underlying argument. The reason is that the
+  // call might capture the argument "through return", which we allow and for
+  // which we need to check call users.
+ unsigned ArgNo = ICS.getArgumentNo(U);
+ const auto &ArgNoCaptureAA =
+ A.getAAFor<AANoCapture>(*this, IRPosition::callsite_argument(ICS, ArgNo));
+ return !ArgNoCaptureAA.isAssumedNoCapture();
+}
- if (ICS.hasFnAttr(Attribute::WillReturn))
- continue;
+void AAMemoryBehaviorFloating::analyzeUseIn(Attributor &A, const Use *U,
+ const Instruction *UserI) {
+ assert(UserI->mayReadOrWriteMemory());
- auto *WillReturnAA = A.getAAFor<AAWillReturn>(*this, *I);
- if (!WillReturnAA || !WillReturnAA->isAssumedWillReturn()) {
- indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
- }
+ switch (UserI->getOpcode()) {
+ default:
+ // TODO: Handle all atomics and other side-effect operations we know of.
+ break;
+ case Instruction::Load:
+ // Loads cause the NO_READS property to disappear.
+ removeAssumedBits(NO_READS);
+ return;
- auto *NoRecurseAA = A.getAAFor<AANoRecurse>(*this, *I);
+ case Instruction::Store:
+ // Stores cause the NO_WRITES property to disappear if the use is the
+ // pointer operand. Note that we do assume that capturing was taken care of
+ // somewhere else.
+ if (cast<StoreInst>(UserI)->getPointerOperand() == U->get())
+ removeAssumedBits(NO_WRITES);
+ return;
- // FIXME: (i) Prohibit any recursion for now.
- // (ii) AANoRecurse isn't implemented yet so currently any call is
- // regarded as having recursion.
- // Code below should be
- // if ((!NoRecurseAA || !NoRecurseAA->isAssumedNoRecurse()) &&
- if (!NoRecurseAA && !ICS.hasFnAttr(Attribute::NoRecurse)) {
- indicatePessimisticFixpoint();
- return ChangeStatus::CHANGED;
- }
+ case Instruction::Call:
+ case Instruction::CallBr:
+ case Instruction::Invoke: {
+ // For call sites we look at the argument memory behavior attribute (this
+ // could be recursive!) in order to restrict our own state.
+ ImmutableCallSite ICS(UserI);
+
+ // Give up on operand bundles.
+ if (ICS.isBundleOperand(U)) {
+ indicatePessimisticFixpoint();
+ return;
+ }
+
+    // Calling a function reads the function pointer and may write it if the
+    // function is self-modifying.
+ if (ICS.isCallee(U)) {
+ removeAssumedBits(NO_READS);
+ break;
}
+
+ // Adjust the possible access behavior based on the information on the
+ // argument.
+ unsigned ArgNo = ICS.getArgumentNo(U);
+ const IRPosition &ArgPos = IRPosition::callsite_argument(ICS, ArgNo);
+ const auto &MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(*this, ArgPos);
+ // "assumed" has at most the same bits as the MemBehaviorAA assumed
+ // and at least "known".
+ intersectAssumedBits(MemBehaviorAA.getAssumed());
+ return;
}
+ };
- return ChangeStatus::UNCHANGED;
+ // Generally, look at the "may-properties" and adjust the assumed state if we
+ // did not trigger special handling before.
+ if (UserI->mayReadFromMemory())
+ removeAssumedBits(NO_READS);
+ if (UserI->mayWriteToMemory())
+ removeAssumedBits(NO_WRITES);
}
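
The use walk in AAMemoryBehaviorFloating::updateImpl relies on the fact that indexing into the SetVector stays valid while new elements are appended. A reduced, self-contained sketch of that pattern (toy types, not the LLVM classes) looks like this:

#include <set>
#include <vector>

// Toy stand-in for a use: whether the users of its user should be followed
// transitively, and the uses of the value produced by that user.
struct ToyUse {
  bool FollowUsers = false;
  std::vector<ToyUse *> UserUses;
};

// Visit all transitively reachable uses; appending while iterating by index is
// safe because already visited slots are never invalidated.
void walkUses(std::vector<ToyUse *> Uses) {
  std::set<ToyUse *> Seen(Uses.begin(), Uses.end());
  for (unsigned i = 0; i < Uses.size(); ++i) {
    ToyUse *U = Uses[i];
    if (U->FollowUsers)
      for (ToyUse *UU : U->UserUses)
        if (Seen.insert(UU).second)
          Uses.push_back(UU);
    // ... analyzeUseIn(U) would adjust the assumed bits here ...
  }
}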
/// ----------------------------------------------------------------------------
/// Attributor
/// ----------------------------------------------------------------------------
-bool Attributor::checkForAllCallSites(Function &F,
- std::function<bool(CallSite)> &Pred,
- bool RequireAllCallSites) {
+bool Attributor::isAssumedDead(const AbstractAttribute &AA,
+ const AAIsDead *LivenessAA) {
+ const Instruction *CtxI = AA.getIRPosition().getCtxI();
+ if (!CtxI)
+ return false;
+
+ if (!LivenessAA)
+ LivenessAA =
+ &getAAFor<AAIsDead>(AA, IRPosition::function(*CtxI->getFunction()),
+ /* TrackDependence */ false);
+
+ // Don't check liveness for AAIsDead.
+ if (&AA == LivenessAA)
+ return false;
+
+ if (!LivenessAA->isAssumedDead(CtxI))
+ return false;
+
+ // We actually used liveness information so we have to record a dependence.
+ recordDependence(*LivenessAA, AA);
+
+ return true;
+}
+
+bool Attributor::checkForAllCallSites(
+ const function_ref<bool(AbstractCallSite)> &Pred,
+ const AbstractAttribute &QueryingAA, bool RequireAllCallSites) {
// We can try to determine information from
  // the call sites. However, this is only possible if all call sites are known,
// hence the function has internal linkage.
- if (RequireAllCallSites && !F.hasInternalLinkage()) {
+ const IRPosition &IRP = QueryingAA.getIRPosition();
+ const Function *AssociatedFunction = IRP.getAssociatedFunction();
+ if (!AssociatedFunction) {
+ LLVM_DEBUG(dbgs() << "[Attributor] No function associated with " << IRP
+ << "\n");
+ return false;
+ }
+
+ return checkForAllCallSites(Pred, *AssociatedFunction, RequireAllCallSites,
+ &QueryingAA);
+}
+
+bool Attributor::checkForAllCallSites(
+ const function_ref<bool(AbstractCallSite)> &Pred, const Function &Fn,
+ bool RequireAllCallSites, const AbstractAttribute *QueryingAA) {
+ if (RequireAllCallSites && !Fn.hasLocalLinkage()) {
LLVM_DEBUG(
dbgs()
- << "Attributor: Function " << F.getName()
+ << "[Attributor] Function " << Fn.getName()
<< " has no internal linkage, hence not all call sites are known\n");
return false;
}
- for (const Use &U : F.uses()) {
+ for (const Use &U : Fn.uses()) {
+ AbstractCallSite ACS(&U);
+ if (!ACS) {
+ LLVM_DEBUG(dbgs() << "[Attributor] Function "
+ << Fn.getName()
+ << " has non call site use " << *U.get() << " in "
+ << *U.getUser() << "\n");
+ return false;
+ }
+
+ Instruction *I = ACS.getInstruction();
+ Function *Caller = I->getFunction();
+
+ const auto *LivenessAA =
+ lookupAAFor<AAIsDead>(IRPosition::function(*Caller), QueryingAA,
+ /* TrackDependence */ false);
+
+ // Skip dead calls.
+ if (LivenessAA && LivenessAA->isAssumedDead(I)) {
+ // We actually used liveness information so we have to record a
+ // dependence.
+ if (QueryingAA)
+ recordDependence(*LivenessAA, *QueryingAA);
+ continue;
+ }
- CallSite CS(U.getUser());
- if (!CS || !CS.isCallee(&U) || !CS.getCaller()->hasExactDefinition()) {
+ const Use *EffectiveUse =
+ ACS.isCallbackCall() ? &ACS.getCalleeUseForCallback() : &U;
+ if (!ACS.isCallee(EffectiveUse)) {
if (!RequireAllCallSites)
continue;
-
- LLVM_DEBUG(dbgs() << "Attributor: User " << *U.getUser()
- << " is an invalid use of " << F.getName() << "\n");
+ LLVM_DEBUG(dbgs() << "[Attributor] User " << EffectiveUse->getUser()
+ << " is an invalid use of "
+ << Fn.getName() << "\n");
return false;
}
- if (Pred(CS))
+ if (Pred(ACS))
continue;
- LLVM_DEBUG(dbgs() << "Attributor: Call site callback failed for "
- << *CS.getInstruction() << "\n");
+ LLVM_DEBUG(dbgs() << "[Attributor] Call site callback failed for "
+ << *ACS.getInstruction() << "\n");
return false;
}
return true;
}
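
A hypothetical user of this interface (illustrative fragment only; the lambda, the checked argument index, and the surrounding updateImpl are made up for this sketch) would reject a deduction unless every known call site satisfies a property:

  // Inside some AA*::updateImpl(Attributor &A); illustrative fragment only.
  auto CallSiteCheck = [&](AbstractCallSite ACS) {
    // AbstractCallSite hides whether this is a direct call or a callback call.
    Value *ArgOp = ACS.getCallArgOperand(/* ArgNo */ 0);
    return ArgOp && isa<Constant>(ArgOp);
  };
  if (!A.checkForAllCallSites(CallSiteCheck, *this,
                              /* RequireAllCallSites */ true))
    return indicatePessimisticFixpoint();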
-ChangeStatus Attributor::run() {
- // Initialize all abstract attributes.
- for (AbstractAttribute *AA : AllAbstractAttributes)
- AA->initialize(*this);
+bool Attributor::checkForAllReturnedValuesAndReturnInsts(
+ const function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)>
+ &Pred,
+ const AbstractAttribute &QueryingAA) {
+
+ const IRPosition &IRP = QueryingAA.getIRPosition();
+ // Since we need to provide return instructions we have to have an exact
+ // definition.
+ const Function *AssociatedFunction = IRP.getAssociatedFunction();
+ if (!AssociatedFunction)
+ return false;
+ // If this is a call site query we use the call site specific return values
+ // and liveness information.
+ // TODO: use the function scope once we have call site AAReturnedValues.
+ const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction);
+ const auto &AARetVal = getAAFor<AAReturnedValues>(QueryingAA, QueryIRP);
+ if (!AARetVal.getState().isValidState())
+ return false;
+
+ return AARetVal.checkForAllReturnedValuesAndReturnInsts(Pred);
+}
+
+bool Attributor::checkForAllReturnedValues(
+ const function_ref<bool(Value &)> &Pred,
+ const AbstractAttribute &QueryingAA) {
+
+ const IRPosition &IRP = QueryingAA.getIRPosition();
+ const Function *AssociatedFunction = IRP.getAssociatedFunction();
+ if (!AssociatedFunction)
+ return false;
+
+ // TODO: use the function scope once we have call site AAReturnedValues.
+ const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction);
+ const auto &AARetVal = getAAFor<AAReturnedValues>(QueryingAA, QueryIRP);
+ if (!AARetVal.getState().isValidState())
+ return false;
+
+ return AARetVal.checkForAllReturnedValuesAndReturnInsts(
+ [&](Value &RV, const SmallSetVector<ReturnInst *, 4> &) {
+ return Pred(RV);
+ });
+}
+
+static bool
+checkForAllInstructionsImpl(InformationCache::OpcodeInstMapTy &OpcodeInstMap,
+ const function_ref<bool(Instruction &)> &Pred,
+ const AAIsDead *LivenessAA, bool &AnyDead,
+ const ArrayRef<unsigned> &Opcodes) {
+ for (unsigned Opcode : Opcodes) {
+ for (Instruction *I : OpcodeInstMap[Opcode]) {
+ // Skip dead instructions.
+ if (LivenessAA && LivenessAA->isAssumedDead(I)) {
+ AnyDead = true;
+ continue;
+ }
+
+ if (!Pred(*I))
+ return false;
+ }
+ }
+ return true;
+}
+
+bool Attributor::checkForAllInstructions(
+ const llvm::function_ref<bool(Instruction &)> &Pred,
+ const AbstractAttribute &QueryingAA, const ArrayRef<unsigned> &Opcodes) {
+
+ const IRPosition &IRP = QueryingAA.getIRPosition();
+ // Since we need to provide instructions we have to have an exact definition.
+ const Function *AssociatedFunction = IRP.getAssociatedFunction();
+ if (!AssociatedFunction)
+ return false;
+
+ // TODO: use the function scope once we have call site AAReturnedValues.
+ const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction);
+ const auto &LivenessAA =
+ getAAFor<AAIsDead>(QueryingAA, QueryIRP, /* TrackDependence */ false);
+ bool AnyDead = false;
+
+ auto &OpcodeInstMap =
+ InfoCache.getOpcodeInstMapForFunction(*AssociatedFunction);
+ if (!checkForAllInstructionsImpl(OpcodeInstMap, Pred, &LivenessAA, AnyDead,
+ Opcodes))
+ return false;
+
+  // If we actually used liveness information, we have to record a dependence.
+ if (AnyDead)
+ recordDependence(LivenessAA, QueryingAA);
+
+ return true;
+}
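
A hypothetical caller of this helper (again only a sketch, reusing names introduced by this patch) would pass a predicate together with the opcodes of interest and fall back to a pessimistic fixpoint if the check fails:

  // Inside some AA*::updateImpl(Attributor &A); illustrative fragment only.
  auto CheckCall = [&](Instruction &I) {
    return ImmutableCallSite(&I).hasFnAttr(Attribute::WillReturn);
  };
  if (!A.checkForAllInstructions(CheckCall, *this,
                                 {(unsigned)Instruction::Invoke,
                                  (unsigned)Instruction::CallBr,
                                  (unsigned)Instruction::Call}))
    return indicatePessimisticFixpoint();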
+
+bool Attributor::checkForAllReadWriteInstructions(
+ const llvm::function_ref<bool(Instruction &)> &Pred,
+ AbstractAttribute &QueryingAA) {
+
+ const Function *AssociatedFunction =
+ QueryingAA.getIRPosition().getAssociatedFunction();
+ if (!AssociatedFunction)
+ return false;
+
+ // TODO: use the function scope once we have call site AAReturnedValues.
+ const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction);
+ const auto &LivenessAA =
+ getAAFor<AAIsDead>(QueryingAA, QueryIRP, /* TrackDependence */ false);
+ bool AnyDead = false;
+
+ for (Instruction *I :
+ InfoCache.getReadOrWriteInstsForFunction(*AssociatedFunction)) {
+ // Skip dead instructions.
+ if (LivenessAA.isAssumedDead(I)) {
+ AnyDead = true;
+ continue;
+ }
+
+ if (!Pred(*I))
+ return false;
+ }
+
+  // If we actually used liveness information, we have to record a dependence.
+ if (AnyDead)
+ recordDependence(LivenessAA, QueryingAA);
+
+ return true;
+}
+
+ChangeStatus Attributor::run(Module &M) {
LLVM_DEBUG(dbgs() << "[Attributor] Identified and initialized "
<< AllAbstractAttributes.size()
<< " abstract attributes.\n");
@@ -1370,10 +4470,25 @@ ChangeStatus Attributor::run() {
SetVector<AbstractAttribute *> Worklist;
Worklist.insert(AllAbstractAttributes.begin(), AllAbstractAttributes.end());
+ bool RecomputeDependences = false;
+
do {
+ // Remember the size to determine new attributes.
+ size_t NumAAs = AllAbstractAttributes.size();
LLVM_DEBUG(dbgs() << "\n\n[Attributor] #Iteration: " << IterationCounter
<< ", Worklist size: " << Worklist.size() << "\n");
+ // If dependences (=QueryMap) are recomputed we have to look at all abstract
+ // attributes again, regardless of what changed in the last iteration.
+ if (RecomputeDependences) {
+ LLVM_DEBUG(
+ dbgs() << "[Attributor] Run all AAs to recompute dependences\n");
+ QueryMap.clear();
+ ChangedAAs.clear();
+ Worklist.insert(AllAbstractAttributes.begin(),
+ AllAbstractAttributes.end());
+ }
+
// Add all abstract attributes that are potentially dependent on one that
// changed to the work list.
for (AbstractAttribute *ChangedAA : ChangedAAs) {
@@ -1381,27 +4496,42 @@ ChangeStatus Attributor::run() {
Worklist.insert(QuerriedAAs.begin(), QuerriedAAs.end());
}
+ LLVM_DEBUG(dbgs() << "[Attributor] #Iteration: " << IterationCounter
+ << ", Worklist+Dependent size: " << Worklist.size()
+ << "\n");
+
// Reset the changed set.
ChangedAAs.clear();
  // Update all abstract attributes in the work list and record the ones that
// changed.
for (AbstractAttribute *AA : Worklist)
- if (AA->update(*this) == ChangeStatus::CHANGED)
- ChangedAAs.push_back(AA);
+ if (!isAssumedDead(*AA, nullptr))
+ if (AA->update(*this) == ChangeStatus::CHANGED)
+ ChangedAAs.push_back(AA);
+
+    // Check if we should recompute the dependences in the next iteration.
+ RecomputeDependences = (DepRecomputeInterval > 0 &&
+ IterationCounter % DepRecomputeInterval == 0);
+
+ // Add attributes to the changed set if they have been created in the last
+ // iteration.
+ ChangedAAs.append(AllAbstractAttributes.begin() + NumAAs,
+ AllAbstractAttributes.end());
// Reset the work list and repopulate with the changed abstract attributes.
// Note that dependent ones are added above.
Worklist.clear();
Worklist.insert(ChangedAAs.begin(), ChangedAAs.end());
- } while (!Worklist.empty() && ++IterationCounter < MaxFixpointIterations);
+ } while (!Worklist.empty() && (IterationCounter++ < MaxFixpointIterations ||
+ VerifyMaxFixpointIterations));
LLVM_DEBUG(dbgs() << "\n[Attributor] Fixpoint iteration done after: "
<< IterationCounter << "/" << MaxFixpointIterations
<< " iterations\n");
- bool FinishedAtFixpoint = Worklist.empty();
+ size_t NumFinalAAs = AllAbstractAttributes.size();
// Reset abstract arguments not settled in a sound fixpoint by now. This
// happens when we stopped the fixpoint iteration early. Note that only the
@@ -1448,8 +4578,14 @@ ChangeStatus Attributor::run() {
if (!State.isValidState())
continue;
+ // Skip dead code.
+ if (isAssumedDead(*AA, nullptr))
+ continue;
// Manifest the state and record if we changed the IR.
ChangeStatus LocalChange = AA->manifest(*this);
+ if (LocalChange == ChangeStatus::CHANGED && AreStatisticsEnabled())
+ AA->trackStatistics();
+
ManifestChange = ManifestChange | LocalChange;
NumAtFixpoint++;
@@ -1462,69 +4598,92 @@ ChangeStatus Attributor::run() {
<< " arguments while " << NumAtFixpoint
<< " were in a valid fixpoint state\n");
- // If verification is requested, we finished this run at a fixpoint, and the
- // IR was changed, we re-run the whole fixpoint analysis, starting at
- // re-initialization of the arguments. This re-run should not result in an IR
- // change. Though, the (virtual) state of attributes at the end of the re-run
- // might be more optimistic than the known state or the IR state if the better
- // state cannot be manifested.
- if (VerifyAttributor && FinishedAtFixpoint &&
- ManifestChange == ChangeStatus::CHANGED) {
- VerifyAttributor = false;
- ChangeStatus VerifyStatus = run();
- if (VerifyStatus != ChangeStatus::UNCHANGED)
- llvm_unreachable(
- "Attributor verification failed, re-run did result in an IR change "
- "even after a fixpoint was reached in the original run. (False "
- "positives possible!)");
- VerifyAttributor = true;
- }
-
NumAttributesManifested += NumManifested;
NumAttributesValidFixpoint += NumAtFixpoint;
- return ManifestChange;
-}
-
-void Attributor::identifyDefaultAbstractAttributes(
- Function &F, InformationCache &InfoCache,
- DenseSet</* Attribute::AttrKind */ unsigned> *Whitelist) {
+ (void)NumFinalAAs;
+ assert(
+ NumFinalAAs == AllAbstractAttributes.size() &&
+ "Expected the final number of abstract attributes to remain unchanged!");
+
+  // Delete things at the end to avoid invalid references and to get a nice
+  // deletion order.
+ {
+ LLVM_DEBUG(dbgs() << "\n[Attributor] Delete at least "
+ << ToBeDeletedFunctions.size() << " functions and "
+ << ToBeDeletedBlocks.size() << " blocks and "
+ << ToBeDeletedInsts.size() << " instructions\n");
+ for (Instruction *I : ToBeDeletedInsts) {
+ if (!I->use_empty())
+ I->replaceAllUsesWith(UndefValue::get(I->getType()));
+ I->eraseFromParent();
+ }
- // Every function can be nounwind.
- registerAA(*new AANoUnwindFunction(F, InfoCache));
+ if (unsigned NumDeadBlocks = ToBeDeletedBlocks.size()) {
+ SmallVector<BasicBlock *, 8> ToBeDeletedBBs;
+ ToBeDeletedBBs.reserve(NumDeadBlocks);
+ ToBeDeletedBBs.append(ToBeDeletedBlocks.begin(), ToBeDeletedBlocks.end());
+ DeleteDeadBlocks(ToBeDeletedBBs);
+ STATS_DECLTRACK(AAIsDead, BasicBlock,
+ "Number of dead basic blocks deleted.");
+ }
- // Every function might be marked "nosync"
- registerAA(*new AANoSyncFunction(F, InfoCache));
+ STATS_DECL(AAIsDead, Function, "Number of dead functions deleted.");
+ for (Function *Fn : ToBeDeletedFunctions) {
+ Fn->replaceAllUsesWith(UndefValue::get(Fn->getType()));
+ Fn->eraseFromParent();
+ STATS_TRACK(AAIsDead, Function);
+ }
- // Every function might be "no-free".
- registerAA(*new AANoFreeFunction(F, InfoCache));
+ // Identify dead internal functions and delete them. This happens outside
+ // the other fixpoint analysis as we might treat potentially dead functions
+ // as live to lower the number of iterations. If they happen to be dead, the
+ // below fixpoint loop will identify and eliminate them.
+ SmallVector<Function *, 8> InternalFns;
+ for (Function &F : M)
+ if (F.hasLocalLinkage())
+ InternalFns.push_back(&F);
+
+ bool FoundDeadFn = true;
+ while (FoundDeadFn) {
+ FoundDeadFn = false;
+ for (unsigned u = 0, e = InternalFns.size(); u < e; ++u) {
+ Function *F = InternalFns[u];
+ if (!F)
+ continue;
- // Return attributes are only appropriate if the return type is non void.
- Type *ReturnType = F.getReturnType();
- if (!ReturnType->isVoidTy()) {
- // Argument attribute "returned" --- Create only one per function even
- // though it is an argument attribute.
- if (!Whitelist || Whitelist->count(AAReturnedValues::ID))
- registerAA(*new AAReturnedValuesImpl(F, InfoCache));
+ const auto *LivenessAA =
+ lookupAAFor<AAIsDead>(IRPosition::function(*F));
+ if (LivenessAA &&
+ !checkForAllCallSites([](AbstractCallSite ACS) { return false; },
+ *LivenessAA, true))
+ continue;
- // Every function with pointer return type might be marked nonnull.
- if (ReturnType->isPointerTy() &&
- (!Whitelist || Whitelist->count(AANonNullReturned::ID)))
- registerAA(*new AANonNullReturned(F, InfoCache));
+ STATS_TRACK(AAIsDead, Function);
+ F->replaceAllUsesWith(UndefValue::get(F->getType()));
+ F->eraseFromParent();
+ InternalFns[u] = nullptr;
+ FoundDeadFn = true;
+ }
+ }
}
- // Every argument with pointer type might be marked nonnull.
- for (Argument &Arg : F.args()) {
- if (Arg.getType()->isPointerTy())
- registerAA(*new AANonNullArgument(Arg, InfoCache));
+ if (VerifyMaxFixpointIterations &&
+ IterationCounter != MaxFixpointIterations) {
+ errs() << "\n[Attributor] Fixpoint iteration done after: "
+ << IterationCounter << "/" << MaxFixpointIterations
+ << " iterations\n";
+ llvm_unreachable("The fixpoint was not reached with exactly the number of "
+ "specified iterations!");
}
- // Every function might be "will-return".
- registerAA(*new AAWillReturnFunction(F, InfoCache));
+ return ManifestChange;
+}
+
+void Attributor::initializeInformationCache(Function &F) {
- // Walk all instructions to find more attribute opportunities and also
- // interesting instructions that might be queried by abstract attributes
- // during their initialization or update.
+ // Walk all instructions to find interesting instructions that might be
+ // queried by abstract attributes during their initialization or update.
+ // This has to happen before we create attributes.
auto &ReadOrWriteInsts = InfoCache.FuncRWInstsMap[&F];
auto &InstOpcodeMap = InfoCache.FuncInstOpcodeMap[&F];
@@ -1540,8 +4699,12 @@ void Attributor::identifyDefaultAbstractAttributes(
default:
assert((!ImmutableCallSite(&I)) && (!isa<CallBase>(&I)) &&
"New call site/base instruction type needs to be known int the "
- "attributor.");
+ "Attributor.");
break;
+ case Instruction::Load:
+ // The alignment of a pointer is interesting for loads.
+ case Instruction::Store:
+ // The alignment of a pointer is interesting for stores.
case Instruction::Call:
case Instruction::CallBr:
case Instruction::Invoke:
@@ -1555,18 +4718,154 @@ void Attributor::identifyDefaultAbstractAttributes(
InstOpcodeMap[I.getOpcode()].push_back(&I);
if (I.mayReadOrWriteMemory())
ReadOrWriteInsts.push_back(&I);
+ }
+}
+
+void Attributor::identifyDefaultAbstractAttributes(Function &F) {
+ if (!VisitedFunctions.insert(&F).second)
+ return;
+
+ IRPosition FPos = IRPosition::function(F);
+
+ // Check for dead BasicBlocks in every function.
+ // We need dead instruction detection because we do not want to deal with
+ // broken IR in which SSA rules do not apply.
+ getOrCreateAAFor<AAIsDead>(FPos);
+
+ // Every function might be "will-return".
+ getOrCreateAAFor<AAWillReturn>(FPos);
+ // Every function can be nounwind.
+ getOrCreateAAFor<AANoUnwind>(FPos);
+
+ // Every function might be marked "nosync"
+ getOrCreateAAFor<AANoSync>(FPos);
+
+ // Every function might be "no-free".
+ getOrCreateAAFor<AANoFree>(FPos);
+
+ // Every function might be "no-return".
+ getOrCreateAAFor<AANoReturn>(FPos);
+
+ // Every function might be "no-recurse".
+ getOrCreateAAFor<AANoRecurse>(FPos);
+
+ // Every function might be "readnone/readonly/writeonly/...".
+ getOrCreateAAFor<AAMemoryBehavior>(FPos);
+
+ // Every function might be applicable for Heap-To-Stack conversion.
+ if (EnableHeapToStack)
+ getOrCreateAAFor<AAHeapToStack>(FPos);
+
+ // Return attributes are only appropriate if the return type is non void.
+ Type *ReturnType = F.getReturnType();
+ if (!ReturnType->isVoidTy()) {
+ // Argument attribute "returned" --- Create only one per function even
+ // though it is an argument attribute.
+ getOrCreateAAFor<AAReturnedValues>(FPos);
+
+ IRPosition RetPos = IRPosition::returned(F);
+
+ // Every function might be simplified.
+ getOrCreateAAFor<AAValueSimplify>(RetPos);
+
+ if (ReturnType->isPointerTy()) {
+
+ // Every function with pointer return type might be marked align.
+ getOrCreateAAFor<AAAlign>(RetPos);
+
+ // Every function with pointer return type might be marked nonnull.
+ getOrCreateAAFor<AANonNull>(RetPos);
+
+ // Every function with pointer return type might be marked noalias.
+ getOrCreateAAFor<AANoAlias>(RetPos);
+
+ // Every function with pointer return type might be marked
+ // dereferenceable.
+ getOrCreateAAFor<AADereferenceable>(RetPos);
+ }
+ }
+
+ for (Argument &Arg : F.args()) {
+ IRPosition ArgPos = IRPosition::argument(Arg);
+
+ // Every argument might be simplified.
+ getOrCreateAAFor<AAValueSimplify>(ArgPos);
+
+ if (Arg.getType()->isPointerTy()) {
+ // Every argument with pointer type might be marked nonnull.
+ getOrCreateAAFor<AANonNull>(ArgPos);
+
+ // Every argument with pointer type might be marked noalias.
+ getOrCreateAAFor<AANoAlias>(ArgPos);
+
+ // Every argument with pointer type might be marked dereferenceable.
+ getOrCreateAAFor<AADereferenceable>(ArgPos);
+
+ // Every argument with pointer type might be marked align.
+ getOrCreateAAFor<AAAlign>(ArgPos);
+
+ // Every argument with pointer type might be marked nocapture.
+ getOrCreateAAFor<AANoCapture>(ArgPos);
+
+ // Every argument with pointer type might be marked
+ // "readnone/readonly/writeonly/..."
+ getOrCreateAAFor<AAMemoryBehavior>(ArgPos);
+ }
+ }
+
+ auto CallSitePred = [&](Instruction &I) -> bool {
CallSite CS(&I);
- if (CS && CS.getCalledFunction()) {
+ if (CS.getCalledFunction()) {
for (int i = 0, e = CS.getCalledFunction()->arg_size(); i < e; i++) {
+
+ IRPosition CSArgPos = IRPosition::callsite_argument(CS, i);
+
+ // Call site argument might be simplified.
+ getOrCreateAAFor<AAValueSimplify>(CSArgPos);
+
if (!CS.getArgument(i)->getType()->isPointerTy())
continue;
// Call site argument attribute "non-null".
- registerAA(*new AANonNullCallSiteArgument(CS, i, InfoCache), i);
+ getOrCreateAAFor<AANonNull>(CSArgPos);
+
+ // Call site argument attribute "no-alias".
+ getOrCreateAAFor<AANoAlias>(CSArgPos);
+
+ // Call site argument attribute "dereferenceable".
+ getOrCreateAAFor<AADereferenceable>(CSArgPos);
+
+ // Call site argument attribute "align".
+ getOrCreateAAFor<AAAlign>(CSArgPos);
}
}
- }
+ return true;
+ };
+
+ auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
+ bool Success, AnyDead = false;
+ Success = checkForAllInstructionsImpl(
+ OpcodeInstMap, CallSitePred, nullptr, AnyDead,
+ {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
+ (unsigned)Instruction::Call});
+ (void)Success;
+ assert(Success && !AnyDead && "Expected the check call to be successful!");
+
+ auto LoadStorePred = [&](Instruction &I) -> bool {
+ if (isa<LoadInst>(I))
+ getOrCreateAAFor<AAAlign>(
+ IRPosition::value(*cast<LoadInst>(I).getPointerOperand()));
+ else
+ getOrCreateAAFor<AAAlign>(
+ IRPosition::value(*cast<StoreInst>(I).getPointerOperand()));
+ return true;
+ };
+ Success = checkForAllInstructionsImpl(
+ OpcodeInstMap, LoadStorePred, nullptr, AnyDead,
+ {(unsigned)Instruction::Load, (unsigned)Instruction::Store});
+ (void)Success;
+ assert(Success && !AnyDead && "Expected the check call to be successful!");
}
/// Helpers to ease debugging through output streams and print calls.
@@ -1576,21 +4875,39 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, ChangeStatus S) {
return OS << (S == ChangeStatus::CHANGED ? "changed" : "unchanged");
}
-raw_ostream &llvm::operator<<(raw_ostream &OS,
- AbstractAttribute::ManifestPosition AP) {
+raw_ostream &llvm::operator<<(raw_ostream &OS, IRPosition::Kind AP) {
switch (AP) {
- case AbstractAttribute::MP_ARGUMENT:
+ case IRPosition::IRP_INVALID:
+ return OS << "inv";
+ case IRPosition::IRP_FLOAT:
+ return OS << "flt";
+ case IRPosition::IRP_RETURNED:
+ return OS << "fn_ret";
+ case IRPosition::IRP_CALL_SITE_RETURNED:
+ return OS << "cs_ret";
+ case IRPosition::IRP_FUNCTION:
+ return OS << "fn";
+ case IRPosition::IRP_CALL_SITE:
+ return OS << "cs";
+ case IRPosition::IRP_ARGUMENT:
return OS << "arg";
- case AbstractAttribute::MP_CALL_SITE_ARGUMENT:
+ case IRPosition::IRP_CALL_SITE_ARGUMENT:
return OS << "cs_arg";
- case AbstractAttribute::MP_FUNCTION:
- return OS << "fn";
- case AbstractAttribute::MP_RETURNED:
- return OS << "fn_ret";
}
llvm_unreachable("Unknown attribute position!");
}
+raw_ostream &llvm::operator<<(raw_ostream &OS, const IRPosition &Pos) {
+ const Value &AV = Pos.getAssociatedValue();
+ return OS << "{" << Pos.getPositionKind() << ":" << AV.getName() << " ["
+ << Pos.getAnchorValue().getName() << "@" << Pos.getArgNo() << "]}";
+}
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, const IntegerState &S) {
+ return OS << "(" << S.getKnown() << "-" << S.getAssumed() << ")"
+ << static_cast<const AbstractState &>(S);
+}
+
raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractState &S) {
return OS << (!S.isValidState() ? "top" : (S.isAtFixpoint() ? "fix" : ""));
}
@@ -1601,8 +4918,8 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractAttribute &AA) {
}
void AbstractAttribute::print(raw_ostream &OS) const {
- OS << "[" << getManifestPosition() << "][" << getAsStr() << "]["
- << AnchoredVal.getName() << "]";
+ OS << "[P: " << getIRPosition() << "][" << getAsStr() << "][S: " << getState()
+ << "]";
}
///}
@@ -1610,7 +4927,7 @@ void AbstractAttribute::print(raw_ostream &OS) const {
/// Pass (Manager) Boilerplate
/// ----------------------------------------------------------------------------
-static bool runAttributorOnModule(Module &M) {
+static bool runAttributorOnModule(Module &M, AnalysisGetter &AG) {
if (DisableAttributor)
return false;
@@ -1619,39 +4936,39 @@ static bool runAttributorOnModule(Module &M) {
// Create an Attributor and initially empty information cache that is filled
// while we identify default attribute opportunities.
- Attributor A;
- InformationCache InfoCache;
+ InformationCache InfoCache(M, AG);
+ Attributor A(InfoCache, DepRecInterval);
+
+ for (Function &F : M)
+ A.initializeInformationCache(F);
for (Function &F : M) {
- // TODO: Not all attributes require an exact definition. Find a way to
- // enable deduction for some but not all attributes in case the
- // definition might be changed at runtime, see also
- // http://lists.llvm.org/pipermail/llvm-dev/2018-February/121275.html.
- // TODO: We could always determine abstract attributes and if sufficient
- // information was found we could duplicate the functions that do not
- // have an exact definition.
- if (!F.hasExactDefinition()) {
+ if (F.hasExactDefinition())
+ NumFnWithExactDefinition++;
+ else
NumFnWithoutExactDefinition++;
- continue;
- }
- // For now we ignore naked and optnone functions.
- if (F.hasFnAttribute(Attribute::Naked) ||
- F.hasFnAttribute(Attribute::OptimizeNone))
- continue;
-
- NumFnWithExactDefinition++;
+ // We look at internal functions only on-demand but if any use is not a
+ // direct call, we have to do it eagerly.
+ if (F.hasLocalLinkage()) {
+ if (llvm::all_of(F.uses(), [](const Use &U) {
+ return ImmutableCallSite(U.getUser()) &&
+ ImmutableCallSite(U.getUser()).isCallee(&U);
+ }))
+ continue;
+ }
// Populate the Attributor with abstract attribute opportunities in the
// function and the information cache with IR information.
- A.identifyDefaultAbstractAttributes(F, InfoCache);
+ A.identifyDefaultAbstractAttributes(F);
}
- return A.run() == ChangeStatus::CHANGED;
+ return A.run(M) == ChangeStatus::CHANGED;
}
PreservedAnalyses AttributorPass::run(Module &M, ModuleAnalysisManager &AM) {
- if (runAttributorOnModule(M)) {
+ AnalysisGetter AG(AM);
+ if (runAttributorOnModule(M, AG)) {
// FIXME: Think about passes we will preserve and add them here.
return PreservedAnalyses::none();
}
@@ -1670,12 +4987,14 @@ struct AttributorLegacyPass : public ModulePass {
bool runOnModule(Module &M) override {
if (skipModule(M))
return false;
- return runAttributorOnModule(M);
+
+ AnalysisGetter AG;
+ return runAttributorOnModule(M, AG);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
// FIXME: Think about passes we will preserve and add them here.
- AU.setPreservesCFG();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
}
};
@@ -1684,7 +5003,147 @@ struct AttributorLegacyPass : public ModulePass {
Pass *llvm::createAttributorLegacyPass() { return new AttributorLegacyPass(); }
char AttributorLegacyPass::ID = 0;
+
+const char AAReturnedValues::ID = 0;
+const char AANoUnwind::ID = 0;
+const char AANoSync::ID = 0;
+const char AANoFree::ID = 0;
+const char AANonNull::ID = 0;
+const char AANoRecurse::ID = 0;
+const char AAWillReturn::ID = 0;
+const char AANoAlias::ID = 0;
+const char AANoReturn::ID = 0;
+const char AAIsDead::ID = 0;
+const char AADereferenceable::ID = 0;
+const char AAAlign::ID = 0;
+const char AANoCapture::ID = 0;
+const char AAValueSimplify::ID = 0;
+const char AAHeapToStack::ID = 0;
+const char AAMemoryBehavior::ID = 0;
+
+// Macro magic to create the static generator function for attributes that
+// follow the naming scheme.
+
+#define SWITCH_PK_INV(CLASS, PK, POS_NAME) \
+ case IRPosition::PK: \
+ llvm_unreachable("Cannot create " #CLASS " for a " POS_NAME " position!");
+
+#define SWITCH_PK_CREATE(CLASS, IRP, PK, SUFFIX) \
+ case IRPosition::PK: \
+ AA = new CLASS##SUFFIX(IRP); \
+ break;
+
+#define CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
+ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
+ CLASS *AA = nullptr; \
+ switch (IRP.getPositionKind()) { \
+ SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \
+ SWITCH_PK_INV(CLASS, IRP_FLOAT, "floating") \
+ SWITCH_PK_INV(CLASS, IRP_ARGUMENT, "argument") \
+ SWITCH_PK_INV(CLASS, IRP_RETURNED, "returned") \
+ SWITCH_PK_INV(CLASS, IRP_CALL_SITE_RETURNED, "call site returned") \
+ SWITCH_PK_INV(CLASS, IRP_CALL_SITE_ARGUMENT, "call site argument") \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE, CallSite) \
+ } \
+ return *AA; \
+ }
+
+#define CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
+ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
+ CLASS *AA = nullptr; \
+ switch (IRP.getPositionKind()) { \
+ SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \
+ SWITCH_PK_INV(CLASS, IRP_FUNCTION, "function") \
+ SWITCH_PK_INV(CLASS, IRP_CALL_SITE, "call site") \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_FLOAT, Floating) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_ARGUMENT, Argument) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_RETURNED, Returned) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_RETURNED, CallSiteReturned) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_ARGUMENT, CallSiteArgument) \
+ } \
+ return *AA; \
+ }
+
+#define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
+ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
+ CLASS *AA = nullptr; \
+ switch (IRP.getPositionKind()) { \
+ SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE, CallSite) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_FLOAT, Floating) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_ARGUMENT, Argument) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_RETURNED, Returned) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_RETURNED, CallSiteReturned) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_ARGUMENT, CallSiteArgument) \
+ } \
+ return *AA; \
+ }
+
+#define CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
+ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
+ CLASS *AA = nullptr; \
+ switch (IRP.getPositionKind()) { \
+ SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \
+ SWITCH_PK_INV(CLASS, IRP_ARGUMENT, "argument") \
+ SWITCH_PK_INV(CLASS, IRP_FLOAT, "floating") \
+ SWITCH_PK_INV(CLASS, IRP_RETURNED, "returned") \
+ SWITCH_PK_INV(CLASS, IRP_CALL_SITE_RETURNED, "call site returned") \
+ SWITCH_PK_INV(CLASS, IRP_CALL_SITE_ARGUMENT, "call site argument") \
+ SWITCH_PK_INV(CLASS, IRP_CALL_SITE, "call site") \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \
+ } \
+ return *AA; \
+ }
+
+#define CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
+ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
+ CLASS *AA = nullptr; \
+ switch (IRP.getPositionKind()) { \
+ SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \
+ SWITCH_PK_INV(CLASS, IRP_RETURNED, "returned") \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE, CallSite) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_FLOAT, Floating) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_ARGUMENT, Argument) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_RETURNED, CallSiteReturned) \
+ SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_ARGUMENT, CallSiteArgument) \
+ } \
+ return *AA; \
+ }
+
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUnwind)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoSync)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoRecurse)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAWillReturn)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoReturn)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReturnedValues)
+
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADereferenceable)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAlign)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoCapture)
+
+CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify)
+
+CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
+
+CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
+
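To make the macro magic concrete: the instantiation CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUnwind) above expands to roughly the following generator (formatting adjusted and some of the repeated unreachable cases compressed here for readability):

AANoUnwind &AANoUnwind::createForPosition(const IRPosition &IRP, Attributor &A) {
  AANoUnwind *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
    llvm_unreachable("Cannot create AANoUnwind for a invalid position!");
  case IRPosition::IRP_FLOAT:
    llvm_unreachable("Cannot create AANoUnwind for a floating position!");
  // ... the argument, returned, call site returned, and call site argument
  // cases are rejected the same way ...
  case IRPosition::IRP_FUNCTION:
    AA = new AANoUnwindFunction(IRP);
    break;
  case IRPosition::IRP_CALL_SITE:
    AA = new AANoUnwindCallSite(IRP);
    break;
  }
  return *AA;
}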
+#undef CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION
+#undef CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION
+#undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION
+#undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION
+#undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION
+#undef SWITCH_PK_CREATE
+#undef SWITCH_PK_INV
+
INITIALIZE_PASS_BEGIN(AttributorLegacyPass, "attributor",
"Deduce and propagate attributes", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(AttributorLegacyPass, "attributor",
"Deduce and propagate attributes", false, false)
diff --git a/lib/Transforms/IPO/BlockExtractor.cpp b/lib/Transforms/IPO/BlockExtractor.cpp
index 6c365f3f3cbe..de80c88c1591 100644
--- a/lib/Transforms/IPO/BlockExtractor.cpp
+++ b/lib/Transforms/IPO/BlockExtractor.cpp
@@ -119,6 +119,8 @@ void BlockExtractor::loadFile() {
/*KeepEmpty=*/false);
if (LineSplit.empty())
continue;
+    if (LineSplit.size() != 2)
+      report_fatal_error(
+          "Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'");
SmallVector<StringRef, 4> BBNames;
LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
@@ -204,7 +206,8 @@ bool BlockExtractor::runOnModule(Module &M) {
++NumExtracted;
Changed = true;
}
- Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion();
+ CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent());
+ Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC);
if (F)
LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName()
<< "' in: " << F->getName() << '\n');
diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp
index ad877ae1786c..3cf839e397f8 100644
--- a/lib/Transforms/IPO/ConstantMerge.cpp
+++ b/lib/Transforms/IPO/ConstantMerge.cpp
@@ -48,7 +48,7 @@ static void FindUsedValues(GlobalVariable *LLVMUsed,
ConstantArray *Inits = cast<ConstantArray>(LLVMUsed->getInitializer());
for (unsigned i = 0, e = Inits->getNumOperands(); i != e; ++i) {
- Value *Operand = Inits->getOperand(i)->stripPointerCastsNoFollowAliases();
+ Value *Operand = Inits->getOperand(i)->stripPointerCasts();
GlobalValue *GV = cast<GlobalValue>(Operand);
UsedValues.insert(GV);
}
@@ -120,7 +120,7 @@ static void replace(Module &M, GlobalVariable *Old, GlobalVariable *New) {
// Bump the alignment if necessary.
if (Old->getAlignment() || New->getAlignment())
- New->setAlignment(std::max(getAlignment(Old), getAlignment(New)));
+ New->setAlignment(Align(std::max(getAlignment(Old), getAlignment(New))));
copyDebugLocMetadata(Old, New);
Old->replaceAllUsesWith(NewConstant);
diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp
index e30b33aa4872..e20159ba0db5 100644
--- a/lib/Transforms/IPO/CrossDSOCFI.cpp
+++ b/lib/Transforms/IPO/CrossDSOCFI.cpp
@@ -84,13 +84,9 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
for (GlobalObject &GO : M.global_objects()) {
Types.clear();
GO.getMetadata(LLVMContext::MD_type, Types);
- for (MDNode *Type : Types) {
- // Sanity check. GO must not be a function declaration.
- assert(!isa<Function>(&GO) || !cast<Function>(&GO)->isDeclaration());
-
+ for (MDNode *Type : Types)
if (ConstantInt *TypeId = extractNumericTypeId(Type))
TypeIds.insert(TypeId->getZExtValue());
- }
}
NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions");
@@ -108,11 +104,11 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
FunctionCallee C = M.getOrInsertFunction(
"__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),
Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));
- Function *F = dyn_cast<Function>(C.getCallee());
+ Function *F = cast<Function>(C.getCallee());
// Take over the existing function. The frontend emits a weak stub so that the
// linker knows about the symbol; this pass replaces the function body.
F->deleteBody();
- F->setAlignment(4096);
+ F->setAlignment(Align(4096));
Triple T(M.getTargetTriple());
if (T.isARM() || T.isThumb())
diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp
index 5ccd8bc4b0fb..b174c63a577b 100644
--- a/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -78,11 +78,8 @@ STATISTIC(NumNoRecurse, "Number of functions marked as norecurse");
STATISTIC(NumNoUnwind, "Number of functions marked as nounwind");
STATISTIC(NumNoFree, "Number of functions marked as nofree");
-// FIXME: This is disabled by default to avoid exposing security vulnerabilities
-// in C/C++ code compiled by clang:
-// http://lists.llvm.org/pipermail/cfe-dev/2017-January/052066.html
static cl::opt<bool> EnableNonnullArgPropagation(
- "enable-nonnull-arg-prop", cl::Hidden,
+ "enable-nonnull-arg-prop", cl::init(true), cl::Hidden,
cl::desc("Try to propagate nonnull argument attributes from callsites to "
"caller functions."));
@@ -664,6 +661,25 @@ static bool addArgumentAttrsFromCallsites(Function &F) {
return Changed;
}
+static bool addReadAttr(Argument *A, Attribute::AttrKind R) {
+  assert((R == Attribute::ReadOnly || R == Attribute::ReadNone) &&
+         "Must be a Read attribute.");
+ assert(A && "Argument must not be null.");
+
+ // If the argument already has the attribute, nothing needs to be done.
+ if (A->hasAttribute(R))
+ return false;
+
+ // Otherwise, remove potentially conflicting attribute, add the new one,
+ // and update statistics.
+ A->removeAttr(Attribute::WriteOnly);
+ A->removeAttr(Attribute::ReadOnly);
+ A->removeAttr(Attribute::ReadNone);
+ A->addAttr(R);
+ R == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg;
+ return true;
+}
+
/// Deduce nocapture attributes for the SCC.
static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
bool Changed = false;
@@ -732,11 +748,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
SmallPtrSet<Argument *, 8> Self;
Self.insert(&*A);
Attribute::AttrKind R = determinePointerReadAttrs(&*A, Self);
- if (R != Attribute::None) {
- A->addAttr(R);
- Changed = true;
- R == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg;
- }
+ if (R != Attribute::None)
+      Changed |= addReadAttr(A, R);
}
}
}
@@ -833,12 +846,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) {
if (ReadAttr != Attribute::None) {
for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) {
Argument *A = ArgumentSCC[i]->Definition;
- // Clear out existing readonly/readnone attributes
- A->removeAttr(Attribute::ReadOnly);
- A->removeAttr(Attribute::ReadNone);
- A->addAttr(ReadAttr);
- ReadAttr == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg;
- Changed = true;
+          Changed |= addReadAttr(A, ReadAttr);
}
}
}
diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp
index 62c7fbd07223..3f5cc078d75f 100644
--- a/lib/Transforms/IPO/FunctionImport.cpp
+++ b/lib/Transforms/IPO/FunctionImport.cpp
@@ -450,7 +450,7 @@ static void computeImportForFunction(
} else if (PrintImportFailures) {
assert(!FailureInfo &&
"Expected no FailureInfo for newly rejected candidate");
- FailureInfo = llvm::make_unique<FunctionImporter::ImportFailureInfo>(
+ FailureInfo = std::make_unique<FunctionImporter::ImportFailureInfo>(
VI, Edge.second.getHotness(), Reason, 1);
}
LLVM_DEBUG(
@@ -764,7 +764,7 @@ void llvm::computeDeadSymbols(
}
// Make value live and add it to the worklist if it was not live before.
- auto visit = [&](ValueInfo VI) {
+ auto visit = [&](ValueInfo VI, bool IsAliasee) {
// FIXME: If we knew which edges were created for indirect call profiles,
// we could skip them here. Any that are live should be reached via
// other edges, e.g. reference edges. Otherwise, using a profile collected
@@ -800,12 +800,15 @@ void llvm::computeDeadSymbols(
Interposable = true;
}
- if (!KeepAliveLinkage)
- return;
+ if (!IsAliasee) {
+ if (!KeepAliveLinkage)
+ return;
- if (Interposable)
- report_fatal_error(
- "Interposable and available_externally/linkonce_odr/weak_odr symbol");
+ if (Interposable)
+ report_fatal_error(
+ "Interposable and available_externally/linkonce_odr/weak_odr "
+ "symbol");
+ }
}
for (auto &S : VI.getSummaryList())
@@ -821,16 +824,16 @@ void llvm::computeDeadSymbols(
// If this is an alias, visit the aliasee VI to ensure that all copies
// are marked live and it is added to the worklist for further
// processing of its references.
- visit(AS->getAliaseeVI());
+ visit(AS->getAliaseeVI(), true);
continue;
}
Summary->setLive(true);
for (auto Ref : Summary->refs())
- visit(Ref);
+ visit(Ref, false);
if (auto *FS = dyn_cast<FunctionSummary>(Summary.get()))
for (auto Call : FS->calls())
- visit(Call.first);
+ visit(Call.first, false);
}
}
Index.setWithGlobalValueDeadStripping();
@@ -892,7 +895,7 @@ std::error_code llvm::EmitImportsFiles(
StringRef ModulePath, StringRef OutputFilename,
const std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) {
std::error_code EC;
- raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::F_None);
+ raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::OF_None);
if (EC)
return EC;
for (auto &ILI : ModuleToSummariesForIndex)
@@ -948,23 +951,15 @@ void llvm::thinLTOResolvePrevailingInModule(
auto NewLinkage = GS->second->linkage();
if (NewLinkage == GV.getLinkage())
return;
-
- // Switch the linkage to weakany if asked for, e.g. we do this for
- // linker redefined symbols (via --wrap or --defsym).
- // We record that the visibility should be changed here in `addThinLTO`
- // as we need access to the resolution vectors for each input file in
- // order to find which symbols have been redefined.
- // We may consider reorganizing this code and moving the linkage recording
- // somewhere else, e.g. in thinLTOResolvePrevailingInIndex.
- if (NewLinkage == GlobalValue::WeakAnyLinkage) {
- GV.setLinkage(NewLinkage);
- return;
- }
-
if (GlobalValue::isLocalLinkage(GV.getLinkage()) ||
+      // Don't internalize anything here, because the code below lacks the
+      // necessary correctness checks. Leave this job to the LLVM
+      // 'internalize' pass.
+ GlobalValue::isLocalLinkage(NewLinkage) ||
// In case it was dead and already converted to declaration.
GV.isDeclaration())
return;
+
// Check for a non-prevailing def that has interposable linkage
// (e.g. non-odr weak or linkonce). In that case we can't simply
// convert to available_externally, since it would lose the
diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp
index 86b7f3e49ee6..f010f7b703a6 100644
--- a/lib/Transforms/IPO/GlobalDCE.cpp
+++ b/lib/Transforms/IPO/GlobalDCE.cpp
@@ -17,9 +17,11 @@
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/CtorUtils.h"
@@ -29,10 +31,15 @@ using namespace llvm;
#define DEBUG_TYPE "globaldce"
+static cl::opt<bool>
+ ClEnableVFE("enable-vfe", cl::Hidden, cl::init(true), cl::ZeroOrMore,
+ cl::desc("Enable virtual function elimination"));
+
STATISTIC(NumAliases , "Number of global aliases removed");
STATISTIC(NumFunctions, "Number of functions removed");
STATISTIC(NumIFuncs, "Number of indirect functions removed");
STATISTIC(NumVariables, "Number of global variables removed");
+STATISTIC(NumVFuncs, "Number of virtual functions removed");
namespace {
class GlobalDCELegacyPass : public ModulePass {
@@ -118,6 +125,15 @@ void GlobalDCEPass::UpdateGVDependencies(GlobalValue &GV) {
ComputeDependencies(User, Deps);
Deps.erase(&GV); // Remove self-reference.
for (GlobalValue *GVU : Deps) {
+ // If this is a dep from a vtable to a virtual function, and we have
+ // complete information about all virtual call sites which could call
+    // through this vtable, then skip it, because the call site information
+    // will be more precise.
+ if (VFESafeVTables.count(GVU) && isa<Function>(&GV)) {
+ LLVM_DEBUG(dbgs() << "Ignoring dep " << GVU->getName() << " -> "
+ << GV.getName() << "\n");
+ continue;
+ }
GVDependencies[GVU].insert(&GV);
}
}
@@ -132,12 +148,133 @@ void GlobalDCEPass::MarkLive(GlobalValue &GV,
if (Updates)
Updates->push_back(&GV);
if (Comdat *C = GV.getComdat()) {
- for (auto &&CM : make_range(ComdatMembers.equal_range(C)))
+ for (auto &&CM : make_range(ComdatMembers.equal_range(C))) {
MarkLive(*CM.second, Updates); // Recursion depth is only two because only
// globals in the same comdat are visited.
+ }
+ }
+}
+
+void GlobalDCEPass::ScanVTables(Module &M) {
+ SmallVector<MDNode *, 2> Types;
+ LLVM_DEBUG(dbgs() << "Building type info -> vtable map\n");
+
+ auto *LTOPostLinkMD =
+ cast_or_null<ConstantAsMetadata>(M.getModuleFlag("LTOPostLink"));
+ bool LTOPostLink =
+ LTOPostLinkMD &&
+ (cast<ConstantInt>(LTOPostLinkMD->getValue())->getZExtValue() != 0);
+
+ for (GlobalVariable &GV : M.globals()) {
+ Types.clear();
+ GV.getMetadata(LLVMContext::MD_type, Types);
+ if (GV.isDeclaration() || Types.empty())
+ continue;
+
+ // Use the typeid metadata on the vtable to build a mapping from typeids to
+ // the list of (GV, offset) pairs which are the possible vtables for that
+ // typeid.
+ for (MDNode *Type : Types) {
+ Metadata *TypeID = Type->getOperand(1).get();
+
+ uint64_t Offset =
+ cast<ConstantInt>(
+ cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
+ ->getZExtValue();
+
+ TypeIdMap[TypeID].insert(std::make_pair(&GV, Offset));
+ }
+
+ // If the type corresponding to the vtable is private to this translation
+ // unit, we know that we can see all virtual functions which might use it,
+ // so VFE is safe.
+ if (auto GO = dyn_cast<GlobalObject>(&GV)) {
+ GlobalObject::VCallVisibility TypeVis = GO->getVCallVisibility();
+ if (TypeVis == GlobalObject::VCallVisibilityTranslationUnit ||
+ (LTOPostLink &&
+ TypeVis == GlobalObject::VCallVisibilityLinkageUnit)) {
+ LLVM_DEBUG(dbgs() << GV.getName() << " is safe for VFE\n");
+ VFESafeVTables.insert(&GV);
+ }
+ }
+ }
+}
+
+void GlobalDCEPass::ScanVTableLoad(Function *Caller, Metadata *TypeId,
+ uint64_t CallOffset) {
+ for (auto &VTableInfo : TypeIdMap[TypeId]) {
+ GlobalVariable *VTable = VTableInfo.first;
+ uint64_t VTableOffset = VTableInfo.second;
+
+ Constant *Ptr =
+ getPointerAtOffset(VTable->getInitializer(), VTableOffset + CallOffset,
+ *Caller->getParent());
+ if (!Ptr) {
+ LLVM_DEBUG(dbgs() << "can't find pointer in vtable!\n");
+ VFESafeVTables.erase(VTable);
+ return;
+ }
+
+ auto Callee = dyn_cast<Function>(Ptr->stripPointerCasts());
+ if (!Callee) {
+ LLVM_DEBUG(dbgs() << "vtable entry is not function pointer!\n");
+ VFESafeVTables.erase(VTable);
+ return;
+ }
+
+ LLVM_DEBUG(dbgs() << "vfunc dep " << Caller->getName() << " -> "
+ << Callee->getName() << "\n");
+ GVDependencies[Caller].insert(Callee);
}
}
+void GlobalDCEPass::ScanTypeCheckedLoadIntrinsics(Module &M) {
+ LLVM_DEBUG(dbgs() << "Scanning type.checked.load intrinsics\n");
+ Function *TypeCheckedLoadFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
+
+ if (!TypeCheckedLoadFunc)
+ return;
+
+ for (auto U : TypeCheckedLoadFunc->users()) {
+ auto CI = dyn_cast<CallInst>(U);
+ if (!CI)
+ continue;
+
+ auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+ Value *TypeIdValue = CI->getArgOperand(2);
+ auto *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
+
+ if (Offset) {
+ ScanVTableLoad(CI->getFunction(), TypeId, Offset->getZExtValue());
+ } else {
+ // type.checked.load with a non-constant offset, so assume every entry in
+ // every matching vtable is used.
+ for (auto &VTableInfo : TypeIdMap[TypeId]) {
+ VFESafeVTables.erase(VTableInfo.first);
+ }
+ }
+ }
+}
+
+void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) {
+ if (!ClEnableVFE)
+ return;
+
+ ScanVTables(M);
+
+ if (VFESafeVTables.empty())
+ return;
+
+ ScanTypeCheckedLoadIntrinsics(M);
+
+ LLVM_DEBUG(
+ dbgs() << "VFE safe vtables:\n";
+ for (auto *VTable : VFESafeVTables)
+ dbgs() << " " << VTable->getName() << "\n";
+ );
+}
+
PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
bool Changed = false;
@@ -163,6 +300,10 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
if (Comdat *C = GA.getComdat())
ComdatMembers.insert(std::make_pair(C, &GA));
+ // Add dependencies between virtual call sites and the virtual functions they
+ // might call, if we have that information.
+ AddVirtualFunctionDependencies(M);
+
// Loop over the module, adding globals which are obviously necessary.
for (GlobalObject &GO : M.global_objects()) {
Changed |= RemoveUnusedGlobalValue(GO);
@@ -257,8 +398,17 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
};
NumFunctions += DeadFunctions.size();
- for (Function *F : DeadFunctions)
+ for (Function *F : DeadFunctions) {
+ if (!F->use_empty()) {
+ // Virtual functions might still be referenced by one or more vtables,
+ // but if we've proven them to be unused then it's safe to replace the
+ // virtual function pointers with null, allowing us to remove the
+ // function itself.
+ ++NumVFuncs;
+ F->replaceNonMetadataUsesWith(ConstantPointerNull::get(F->getType()));
+ }
EraseUnusedGlobalValue(F);
+ }
NumVariables += DeadGlobalVars.size();
for (GlobalVariable *GV : DeadGlobalVars)
@@ -277,6 +427,8 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
ConstantDependenciesCache.clear();
GVDependencies.clear();
ComdatMembers.clear();
+ TypeIdMap.clear();
+ VFESafeVTables.clear();
if (Changed)
return PreservedAnalyses::none();
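
For reference, the virtual-function-elimination walk added to GlobalDCE above boils down to a small lookup: each type id maps to the (vtable, offset) pairs that carry it, and a llvm.type.checked.load at a constant offset resolves to one concrete slot per candidate vtable. The sketch below is a standalone plain-C++ model of that walk; the names mirror TypeIdMap/ScanVTableLoad for readability only, and none of the LLVM types or values are real.

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Slot = std::pair<std::string, uint64_t>; // (vtable name, typeid offset)

int main() {
  // typeid -> possible vtables, with the offset of that typeid's address point.
  std::map<std::string, std::vector<Slot>> TypeIdMap = {
      {"_ZTS1A", {{"vtable.A", 16}, {"vtable.B", 16}}}};
  // vtable -> byte offset -> virtual function stored in that slot.
  std::map<std::string, std::map<uint64_t, std::string>> VTables = {
      {"vtable.A", {{16, "A::foo"}, {24, "A::bar"}}},
      {"vtable.B", {{16, "B::foo"}, {24, "B::bar"}}}};

  // Mirrors ScanVTableLoad: a checked load at a constant offset adds an edge
  // from the caller to the exact slot's function, not to the whole vtable.
  std::set<std::pair<std::string, std::string>> Deps;
  auto ScanVTableLoad = [&](const std::string &Caller,
                            const std::string &TypeId, uint64_t CallOffset) {
    for (const Slot &S : TypeIdMap[TypeId])
      Deps.insert({Caller, VTables[S.first][S.second + CallOffset]});
  };

  ScanVTableLoad("caller", "_ZTS1A", 8); // virtual call through slot +8
  for (const auto &D : Deps)
    std::cout << D.first << " -> " << D.second << "\n";
  // Prints edges to A::bar and B::bar only; the foo slots pick up no
  // dependency, so they can be dropped and their vtable entries nulled out,
  // which is what the NumVFuncs counter above records.
}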
diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp
index c4fb3ce77f6e..819715b9f8da 100644
--- a/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/lib/Transforms/IPO/GlobalOpt.cpp
@@ -155,7 +155,8 @@ static bool isLeakCheckerRoot(GlobalVariable *GV) {
/// Given a value that is stored to a global but never read, determine whether
/// it's safe to remove the store and the chain of computation that feeds the
/// store.
-static bool IsSafeComputationToRemove(Value *V, const TargetLibraryInfo *TLI) {
+static bool IsSafeComputationToRemove(
+ Value *V, function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
do {
if (isa<Constant>(V))
return true;
@@ -164,7 +165,7 @@ static bool IsSafeComputationToRemove(Value *V, const TargetLibraryInfo *TLI) {
if (isa<LoadInst>(V) || isa<InvokeInst>(V) || isa<Argument>(V) ||
isa<GlobalValue>(V))
return false;
- if (isAllocationFn(V, TLI))
+ if (isAllocationFn(V, GetTLI))
return true;
Instruction *I = cast<Instruction>(V);
@@ -184,8 +185,9 @@ static bool IsSafeComputationToRemove(Value *V, const TargetLibraryInfo *TLI) {
/// This GV is a pointer root. Loop over all users of the global and clean up
/// any that obviously don't assign the global a value that isn't dynamically
/// allocated.
-static bool CleanupPointerRootUsers(GlobalVariable *GV,
- const TargetLibraryInfo *TLI) {
+static bool
+CleanupPointerRootUsers(GlobalVariable *GV,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
// A brief explanation of leak checkers. The goal is to find bugs where
// pointers are forgotten, causing an accumulating growth in memory
// usage over time. The common strategy for leak checkers is to whitelist the
@@ -241,18 +243,18 @@ static bool CleanupPointerRootUsers(GlobalVariable *GV,
C->destroyConstant();
// This could have invalidated UI, start over from scratch.
Dead.clear();
- CleanupPointerRootUsers(GV, TLI);
+ CleanupPointerRootUsers(GV, GetTLI);
return true;
}
}
}
for (int i = 0, e = Dead.size(); i != e; ++i) {
- if (IsSafeComputationToRemove(Dead[i].first, TLI)) {
+ if (IsSafeComputationToRemove(Dead[i].first, GetTLI)) {
Dead[i].second->eraseFromParent();
Instruction *I = Dead[i].first;
do {
- if (isAllocationFn(I, TLI))
+ if (isAllocationFn(I, GetTLI))
break;
Instruction *J = dyn_cast<Instruction>(I->getOperand(0));
if (!J)
@@ -270,9 +272,9 @@ static bool CleanupPointerRootUsers(GlobalVariable *GV,
/// We just marked GV constant. Loop over all users of the global, cleaning up
/// the obvious ones. This is largely just a quick scan over the use list to
/// clean up the easy and obvious cruft. This returns true if it made a change.
-static bool CleanupConstantGlobalUsers(Value *V, Constant *Init,
- const DataLayout &DL,
- TargetLibraryInfo *TLI) {
+static bool CleanupConstantGlobalUsers(
+ Value *V, Constant *Init, const DataLayout &DL,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
bool Changed = false;
// Note that we need to use a weak value handle for the worklist items. When
// we delete a constant array, we may also be holding pointer to one of its
@@ -302,12 +304,12 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init,
Constant *SubInit = nullptr;
if (Init)
SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE);
- Changed |= CleanupConstantGlobalUsers(CE, SubInit, DL, TLI);
+ Changed |= CleanupConstantGlobalUsers(CE, SubInit, DL, GetTLI);
} else if ((CE->getOpcode() == Instruction::BitCast &&
CE->getType()->isPointerTy()) ||
CE->getOpcode() == Instruction::AddrSpaceCast) {
// Pointer cast, delete any stores and memsets to the global.
- Changed |= CleanupConstantGlobalUsers(CE, nullptr, DL, TLI);
+ Changed |= CleanupConstantGlobalUsers(CE, nullptr, DL, GetTLI);
}
if (CE->use_empty()) {
@@ -321,7 +323,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init,
Constant *SubInit = nullptr;
if (!isa<ConstantExpr>(GEP->getOperand(0))) {
ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>(
- ConstantFoldInstruction(GEP, DL, TLI));
+ ConstantFoldInstruction(GEP, DL, &GetTLI(*GEP->getFunction())));
if (Init && CE && CE->getOpcode() == Instruction::GetElementPtr)
SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE);
@@ -331,7 +333,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init,
if (Init && isa<ConstantAggregateZero>(Init) && GEP->isInBounds())
SubInit = Constant::getNullValue(GEP->getResultElementType());
}
- Changed |= CleanupConstantGlobalUsers(GEP, SubInit, DL, TLI);
+ Changed |= CleanupConstantGlobalUsers(GEP, SubInit, DL, GetTLI);
if (GEP->use_empty()) {
GEP->eraseFromParent();
@@ -348,7 +350,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init,
// us, and if they are all dead, nuke them without remorse.
if (isSafeToDestroyConstant(C)) {
C->destroyConstant();
- CleanupConstantGlobalUsers(V, Init, DL, TLI);
+ CleanupConstantGlobalUsers(V, Init, DL, GetTLI);
return true;
}
}
@@ -495,8 +497,8 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) {
// had 256 byte alignment for example, something might depend on that:
// propagate info to each field.
uint64_t FieldOffset = Layout.getElementOffset(i);
- unsigned NewAlign = (unsigned)MinAlign(StartAlignment, FieldOffset);
- if (NewAlign > DL.getABITypeAlignment(STy->getElementType(i)))
+ Align NewAlign(MinAlign(StartAlignment, FieldOffset));
+ if (NewAlign > Align(DL.getABITypeAlignment(STy->getElementType(i))))
NGV->setAlignment(NewAlign);
// Copy over the debug info for the variable.
@@ -511,7 +513,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) {
NewGlobals.reserve(NumElements);
auto ElTy = STy->getElementType();
uint64_t EltSize = DL.getTypeAllocSize(ElTy);
- unsigned EltAlign = DL.getABITypeAlignment(ElTy);
+ Align EltAlign(DL.getABITypeAlignment(ElTy));
uint64_t FragmentSizeInBits = DL.getTypeAllocSizeInBits(ElTy);
for (unsigned i = 0, e = NumElements; i != e; ++i) {
Constant *In = Init->getAggregateElement(i);
@@ -530,7 +532,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) {
// Calculate the known alignment of the field. If the original aggregate
// had 256 byte alignment for example, something might depend on that:
// propagate info to each field.
- unsigned NewAlign = (unsigned)MinAlign(StartAlignment, EltSize*i);
+ Align NewAlign(MinAlign(StartAlignment, EltSize * i));
if (NewAlign > EltAlign)
NGV->setAlignment(NewAlign);
transferSRADebugInfo(GV, NGV, FragmentSizeInBits * i, FragmentSizeInBits,
@@ -745,9 +747,9 @@ static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) {
/// are uses of the loaded value that would trap if the loaded value is
/// dynamically null, then we know that they cannot be reachable with a null
/// value, so we can optimize away the load.
-static bool OptimizeAwayTrappingUsesOfLoads(GlobalVariable *GV, Constant *LV,
- const DataLayout &DL,
- TargetLibraryInfo *TLI) {
+static bool OptimizeAwayTrappingUsesOfLoads(
+ GlobalVariable *GV, Constant *LV, const DataLayout &DL,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
bool Changed = false;
// Keep track of whether we are able to remove all the uses of the global
@@ -793,10 +795,10 @@ static bool OptimizeAwayTrappingUsesOfLoads(GlobalVariable *GV, Constant *LV,
// nor is the global.
if (AllNonStoreUsesGone) {
if (isLeakCheckerRoot(GV)) {
- Changed |= CleanupPointerRootUsers(GV, TLI);
+ Changed |= CleanupPointerRootUsers(GV, GetTLI);
} else {
Changed = true;
- CleanupConstantGlobalUsers(GV, nullptr, DL, TLI);
+ CleanupConstantGlobalUsers(GV, nullptr, DL, GetTLI);
}
if (GV->use_empty()) {
LLVM_DEBUG(dbgs() << " *** GLOBAL NOW DEAD!\n");
@@ -889,8 +891,8 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy,
while (!GV->use_empty()) {
if (StoreInst *SI = dyn_cast<StoreInst>(GV->user_back())) {
// The global is initialized when the store to it occurs.
- new StoreInst(ConstantInt::getTrue(GV->getContext()), InitBool, false, 0,
- SI->getOrdering(), SI->getSyncScopeID(), SI);
+ new StoreInst(ConstantInt::getTrue(GV->getContext()), InitBool, false,
+ None, SI->getOrdering(), SI->getSyncScopeID(), SI);
SI->eraseFromParent();
continue;
}
@@ -907,7 +909,7 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy,
// Replace the cmp X, 0 with a use of the bool value.
// Sink the load to where the compare was, if atomic rules allow us to.
Value *LV = new LoadInst(InitBool->getValueType(), InitBool,
- InitBool->getName() + ".val", false, 0,
+ InitBool->getName() + ".val", false, None,
LI->getOrdering(), LI->getSyncScopeID(),
LI->isUnordered() ? (Instruction *)ICI : LI);
InitBoolUsed = true;
@@ -1562,10 +1564,10 @@ static bool tryToOptimizeStoreOfMallocToGlobal(GlobalVariable *GV, CallInst *CI,
// Try to optimize globals based on the knowledge that only one value (besides
// its initializer) is ever stored to the global.
-static bool optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal,
- AtomicOrdering Ordering,
- const DataLayout &DL,
- TargetLibraryInfo *TLI) {
+static bool
+optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal,
+ AtomicOrdering Ordering, const DataLayout &DL,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
// Ignore no-op GEPs and bitcasts.
StoredOnceVal = StoredOnceVal->stripPointerCasts();
@@ -1583,9 +1585,10 @@ static bool optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal,
SOVC = ConstantExpr::getBitCast(SOVC, GV->getInitializer()->getType());
// Optimize away any trapping uses of the loaded value.
- if (OptimizeAwayTrappingUsesOfLoads(GV, SOVC, DL, TLI))
+ if (OptimizeAwayTrappingUsesOfLoads(GV, SOVC, DL, GetTLI))
return true;
- } else if (CallInst *CI = extractMallocCall(StoredOnceVal, TLI)) {
+ } else if (CallInst *CI = extractMallocCall(StoredOnceVal, GetTLI)) {
+ auto *TLI = &GetTLI(*CI->getFunction());
Type *MallocType = getMallocAllocatedType(CI, TLI);
if (MallocType && tryToOptimizeStoreOfMallocToGlobal(GV, CI, MallocType,
Ordering, DL, TLI))
@@ -1643,10 +1646,12 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) {
// instead of a select to synthesize the desired value.
bool IsOneZero = false;
bool EmitOneOrZero = true;
- if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)){
+ auto *CI = dyn_cast<ConstantInt>(OtherVal);
+ if (CI && CI->getValue().getActiveBits() <= 64) {
IsOneZero = InitVal->isNullValue() && CI->isOne();
- if (ConstantInt *CIInit = dyn_cast<ConstantInt>(GV->getInitializer())){
+ auto *CIInit = dyn_cast<ConstantInt>(GV->getInitializer());
+ if (CIInit && CIInit->getValue().getActiveBits() <= 64) {
uint64_t ValInit = CIInit->getZExtValue();
uint64_t ValOther = CI->getZExtValue();
uint64_t ValMinus = ValOther - ValInit;
@@ -1711,7 +1716,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) {
assert(LI->getOperand(0) == GV && "Not a copy!");
// Insert a new load, to preserve the saved value.
StoreVal = new LoadInst(NewGV->getValueType(), NewGV,
- LI->getName() + ".b", false, 0,
+ LI->getName() + ".b", false, None,
LI->getOrdering(), LI->getSyncScopeID(), LI);
} else {
assert((isa<CastInst>(StoredVal) || isa<SelectInst>(StoredVal)) &&
@@ -1721,15 +1726,15 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) {
}
}
StoreInst *NSI =
- new StoreInst(StoreVal, NewGV, false, 0, SI->getOrdering(),
+ new StoreInst(StoreVal, NewGV, false, None, SI->getOrdering(),
SI->getSyncScopeID(), SI);
NSI->setDebugLoc(SI->getDebugLoc());
} else {
// Change the load into a load of bool then a select.
LoadInst *LI = cast<LoadInst>(UI);
- LoadInst *NLI =
- new LoadInst(NewGV->getValueType(), NewGV, LI->getName() + ".b",
- false, 0, LI->getOrdering(), LI->getSyncScopeID(), LI);
+ LoadInst *NLI = new LoadInst(NewGV->getValueType(), NewGV,
+ LI->getName() + ".b", false, None,
+ LI->getOrdering(), LI->getSyncScopeID(), LI);
Instruction *NSI;
if (IsOneZero)
NSI = new ZExtInst(NLI, LI->getType(), "", LI);
@@ -1914,9 +1919,10 @@ static void makeAllConstantUsesInstructions(Constant *C) {
/// Analyze the specified global variable and optimize
/// it if possible. If we make a change, return true.
-static bool processInternalGlobal(
- GlobalVariable *GV, const GlobalStatus &GS, TargetLibraryInfo *TLI,
- function_ref<DominatorTree &(Function &)> LookupDomTree) {
+static bool
+processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI,
+ function_ref<DominatorTree &(Function &)> LookupDomTree) {
auto &DL = GV->getParent()->getDataLayout();
// If this is a first class global and has only one accessing function and
// this function is non-recursive, we replace the global with a local alloca
@@ -1963,11 +1969,12 @@ static bool processInternalGlobal(
bool Changed;
if (isLeakCheckerRoot(GV)) {
// Delete any constant stores to the global.
- Changed = CleanupPointerRootUsers(GV, TLI);
+ Changed = CleanupPointerRootUsers(GV, GetTLI);
} else {
// Delete any stores we can find to the global. We may not be able to
// make it completely dead though.
- Changed = CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI);
+ Changed =
+ CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI);
}
// If the global is dead now, delete it.
@@ -1989,7 +1996,7 @@ static bool processInternalGlobal(
GV->setConstant(true);
// Clean up any obviously simplifiable users now.
- CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI);
+ CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI);
// If the global is dead now, just nuke it.
if (GV->use_empty()) {
@@ -2019,7 +2026,7 @@ static bool processInternalGlobal(
GV->setInitializer(SOVConstant);
// Clean up any obviously simplifiable users now.
- CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI);
+ CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI);
if (GV->use_empty()) {
LLVM_DEBUG(dbgs() << " *** Substituting initializer allowed us to "
@@ -2033,7 +2040,8 @@ static bool processInternalGlobal(
// Try to optimize globals based on the knowledge that only one value
// (besides its initializer) is ever stored to the global.
- if (optimizeOnceStoredGlobal(GV, GS.StoredOnceValue, GS.Ordering, DL, TLI))
+ if (optimizeOnceStoredGlobal(GV, GS.StoredOnceValue, GS.Ordering, DL,
+ GetTLI))
return true;
// Otherwise, if the global was not a boolean, we can shrink it to be a
@@ -2054,7 +2062,8 @@ static bool processInternalGlobal(
/// Analyze the specified global variable and optimize it if possible. If we
/// make a change, return true.
static bool
-processGlobal(GlobalValue &GV, TargetLibraryInfo *TLI,
+processGlobal(GlobalValue &GV,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI,
function_ref<DominatorTree &(Function &)> LookupDomTree) {
if (GV.getName().startswith("llvm."))
return false;
@@ -2086,7 +2095,7 @@ processGlobal(GlobalValue &GV, TargetLibraryInfo *TLI,
if (GVar->isConstant() || !GVar->hasInitializer())
return Changed;
- return processInternalGlobal(GVar, GS, TLI, LookupDomTree) || Changed;
+ return processInternalGlobal(GVar, GS, GetTLI, LookupDomTree) || Changed;
}
/// Walk all of the direct calls of the specified function, changing them to
@@ -2234,7 +2243,8 @@ hasOnlyColdCalls(Function &F,
}
static bool
-OptimizeFunctions(Module &M, TargetLibraryInfo *TLI,
+OptimizeFunctions(Module &M,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI,
function_ref<TargetTransformInfo &(Function &)> GetTTI,
function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
function_ref<DominatorTree &(Function &)> LookupDomTree,
@@ -2275,17 +2285,13 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI,
// So, remove unreachable blocks from the function, because a) there's
// no point in analyzing them and b) GlobalOpt should otherwise grow
// some more complicated logic to break these cycles.
- // Removing unreachable blocks might invalidate the dominator so we
- // recalculate it.
if (!F->isDeclaration()) {
- if (removeUnreachableBlocks(*F)) {
- auto &DT = LookupDomTree(*F);
- DT.recalculate(*F);
- Changed = true;
- }
+ auto &DT = LookupDomTree(*F);
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+ Changed |= removeUnreachableBlocks(*F, &DTU);
}
- Changed |= processGlobal(*F, TLI, LookupDomTree);
+ Changed |= processGlobal(*F, GetTLI, LookupDomTree);
if (!F->hasLocalLinkage())
continue;
@@ -2342,7 +2348,8 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI,
}
static bool
-OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI,
+OptimizeGlobalVars(Module &M,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI,
function_ref<DominatorTree &(Function &)> LookupDomTree,
SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) {
bool Changed = false;
@@ -2357,7 +2364,10 @@ OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI,
if (GV->hasInitializer())
if (auto *C = dyn_cast<Constant>(GV->getInitializer())) {
auto &DL = M.getDataLayout();
- Constant *New = ConstantFoldConstant(C, DL, TLI);
+ // TLI is not used in the case of a Constant, so use default nullptr
+ // for that optional parameter, since we don't have a Function to
+ // provide GetTLI anyway.
+ Constant *New = ConstantFoldConstant(C, DL, /*TLI*/ nullptr);
if (New && New != C)
GV->setInitializer(New);
}
@@ -2367,7 +2377,7 @@ OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI,
continue;
}
- Changed |= processGlobal(*GV, TLI, LookupDomTree);
+ Changed |= processGlobal(*GV, GetTLI, LookupDomTree);
}
return Changed;
}
@@ -2581,8 +2591,8 @@ static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL,
}
static int compareNames(Constant *const *A, Constant *const *B) {
- Value *AStripped = (*A)->stripPointerCastsNoFollowAliases();
- Value *BStripped = (*B)->stripPointerCastsNoFollowAliases();
+ Value *AStripped = (*A)->stripPointerCasts();
+ Value *BStripped = (*B)->stripPointerCasts();
return AStripped->getName().compare(BStripped->getName());
}
@@ -2809,7 +2819,14 @@ OptimizeGlobalAliases(Module &M,
return Changed;
}
-static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) {
+static Function *
+FindCXAAtExit(Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
+ // Hack to get a default TLI before we have an actual Function.
+ auto FuncIter = M.begin();
+ if (FuncIter == M.end())
+ return nullptr;
+ auto *TLI = &GetTLI(*FuncIter);
+
LibFunc F = LibFunc_cxa_atexit;
if (!TLI->has(F))
return nullptr;
@@ -2818,6 +2835,9 @@ static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) {
if (!Fn)
return nullptr;
+ // Now get the actual TLI for Fn.
+ TLI = &GetTLI(*Fn);
+
// Make sure that the function has the correct prototype.
if (!TLI->getLibFunc(*Fn, F) || F != LibFunc_cxa_atexit)
return nullptr;
@@ -2889,7 +2909,8 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) {
}
static bool optimizeGlobalsInModule(
- Module &M, const DataLayout &DL, TargetLibraryInfo *TLI,
+ Module &M, const DataLayout &DL,
+ function_ref<TargetLibraryInfo &(Function &)> GetTLI,
function_ref<TargetTransformInfo &(Function &)> GetTTI,
function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
function_ref<DominatorTree &(Function &)> LookupDomTree) {
@@ -2914,24 +2935,24 @@ static bool optimizeGlobalsInModule(
NotDiscardableComdats.insert(C);
// Delete functions that are trivially dead, ccc -> fastcc
- LocalChange |= OptimizeFunctions(M, TLI, GetTTI, GetBFI, LookupDomTree,
+ LocalChange |= OptimizeFunctions(M, GetTLI, GetTTI, GetBFI, LookupDomTree,
NotDiscardableComdats);
// Optimize global_ctors list.
LocalChange |= optimizeGlobalCtorsList(M, [&](Function *F) {
- return EvaluateStaticConstructor(F, DL, TLI);
+ return EvaluateStaticConstructor(F, DL, &GetTLI(*F));
});
// Optimize non-address-taken globals.
- LocalChange |= OptimizeGlobalVars(M, TLI, LookupDomTree,
- NotDiscardableComdats);
+ LocalChange |=
+ OptimizeGlobalVars(M, GetTLI, LookupDomTree, NotDiscardableComdats);
// Resolve aliases, when possible.
LocalChange |= OptimizeGlobalAliases(M, NotDiscardableComdats);
// Try to remove trivial global destructors if they are not removed
// already.
- Function *CXAAtExitFn = FindCXAAtExit(M, TLI);
+ Function *CXAAtExitFn = FindCXAAtExit(M, GetTLI);
if (CXAAtExitFn)
LocalChange |= OptimizeEmptyGlobalCXXDtors(CXAAtExitFn);
@@ -2946,12 +2967,14 @@ static bool optimizeGlobalsInModule(
PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) {
auto &DL = M.getDataLayout();
- auto &TLI = AM.getResult<TargetLibraryAnalysis>(M);
auto &FAM =
AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
auto LookupDomTree = [&FAM](Function &F) -> DominatorTree &{
return FAM.getResult<DominatorTreeAnalysis>(F);
};
+ auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ };
auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & {
return FAM.getResult<TargetIRAnalysis>(F);
};
@@ -2960,7 +2983,7 @@ PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) {
return FAM.getResult<BlockFrequencyAnalysis>(F);
};
- if (!optimizeGlobalsInModule(M, DL, &TLI, GetTTI, GetBFI, LookupDomTree))
+ if (!optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, LookupDomTree))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
@@ -2979,10 +3002,12 @@ struct GlobalOptLegacyPass : public ModulePass {
return false;
auto &DL = M.getDataLayout();
- auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
auto LookupDomTree = [this](Function &F) -> DominatorTree & {
return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
};
+ auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
};
@@ -2991,7 +3016,8 @@ struct GlobalOptLegacyPass : public ModulePass {
return this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
};
- return optimizeGlobalsInModule(M, DL, TLI, GetTTI, GetBFI, LookupDomTree);
+ return optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI,
+ LookupDomTree);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
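
The GlobalOpt changes above thread a function_ref<TargetLibraryInfo &(Function &)> getter through every helper in place of a single module-level TLI pointer, presumably so library-call answers can vary per function (for example when individual functions carry no-builtin attributes). Below is a minimal standalone model of that callback pattern; FakeTLI and FakeFn are made-up stand-ins, not the LLVM classes.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct FakeTLI { bool HasMalloc; };                 // stand-in for TargetLibraryInfo
struct FakeFn  { std::string Name; bool NoBuiltin; };

using GetTLIFn = std::function<FakeTLI &(FakeFn &)>;

// Stand-in for a helper like CleanupConstantGlobalUsers: it no longer holds one
// TLI for the whole module but asks for the right one at each use site.
bool canTreatAsAllocation(FakeFn &Caller, const GetTLIFn &GetTLI) {
  return GetTLI(Caller).HasMalloc;
}

int main() {
  std::unordered_map<const FakeFn *, FakeTLI> Cache;
  GetTLIFn GetTLI = [&](FakeFn &F) -> FakeTLI & {
    auto It = Cache.find(&F);
    if (It == Cache.end())
      It = Cache.emplace(&F, FakeTLI{!F.NoBuiltin}).first; // per-function answer
    return It->second;
  };

  FakeFn A{"a", /*NoBuiltin=*/false}, B{"b", /*NoBuiltin=*/true};
  std::cout << canTreatAsAllocation(A, GetTLI) << ' '
            << canTreatAsAllocation(B, GetTLI) << '\n'; // prints: 1 0
}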
diff --git a/lib/Transforms/IPO/HotColdSplitting.cpp b/lib/Transforms/IPO/HotColdSplitting.cpp
index ab1a9a79cad6..cfdcc8db7f50 100644
--- a/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -85,12 +85,6 @@ static cl::opt<int>
"multiple of TCC_Basic)"));
namespace {
-
-/// A sequence of basic blocks.
-///
-/// A 0-sized SmallVector is slightly cheaper to move than a std::vector.
-using BlockSequence = SmallVector<BasicBlock *, 0>;
-
// Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify
// this function unless you modify the MBB version as well.
//
@@ -169,31 +163,6 @@ static bool markFunctionCold(Function &F, bool UpdateEntryCount = false) {
return Changed;
}
-class HotColdSplitting {
-public:
- HotColdSplitting(ProfileSummaryInfo *ProfSI,
- function_ref<BlockFrequencyInfo *(Function &)> GBFI,
- function_ref<TargetTransformInfo &(Function &)> GTTI,
- std::function<OptimizationRemarkEmitter &(Function &)> *GORE,
- function_ref<AssumptionCache *(Function &)> LAC)
- : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {}
- bool run(Module &M);
-
-private:
- bool isFunctionCold(const Function &F) const;
- bool shouldOutlineFrom(const Function &F) const;
- bool outlineColdRegions(Function &F, bool HasProfileSummary);
- Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT,
- BlockFrequencyInfo *BFI, TargetTransformInfo &TTI,
- OptimizationRemarkEmitter &ORE,
- AssumptionCache *AC, unsigned Count);
- ProfileSummaryInfo *PSI;
- function_ref<BlockFrequencyInfo *(Function &)> GetBFI;
- function_ref<TargetTransformInfo &(Function &)> GetTTI;
- std::function<OptimizationRemarkEmitter &(Function &)> *GetORE;
- function_ref<AssumptionCache *(Function &)> LookupAC;
-};
-
class HotColdSplittingLegacyPass : public ModulePass {
public:
static char ID;
@@ -321,13 +290,10 @@ static int getOutliningPenalty(ArrayRef<BasicBlock *> Region,
return Penalty;
}
-Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region,
- DominatorTree &DT,
- BlockFrequencyInfo *BFI,
- TargetTransformInfo &TTI,
- OptimizationRemarkEmitter &ORE,
- AssumptionCache *AC,
- unsigned Count) {
+Function *HotColdSplitting::extractColdRegion(
+ const BlockSequence &Region, const CodeExtractorAnalysisCache &CEAC,
+ DominatorTree &DT, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI,
+ OptimizationRemarkEmitter &ORE, AssumptionCache *AC, unsigned Count) {
assert(!Region.empty());
// TODO: Pass BFI and BPI to update profile information.
@@ -349,7 +315,7 @@ Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region,
return nullptr;
Function *OrigF = Region[0]->getParent();
- if (Function *OutF = CE.extractCodeRegion()) {
+ if (Function *OutF = CE.extractCodeRegion(CEAC)) {
User *U = *OutF->user_begin();
CallInst *CI = cast<CallInst>(U);
CallSite CS(CI);
@@ -607,9 +573,9 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
});
if (!DT)
- DT = make_unique<DominatorTree>(F);
+ DT = std::make_unique<DominatorTree>(F);
if (!PDT)
- PDT = make_unique<PostDominatorTree>(F);
+ PDT = std::make_unique<PostDominatorTree>(F);
auto Regions = OutliningRegion::create(*BB, *DT, *PDT);
for (OutliningRegion &Region : Regions) {
@@ -637,9 +603,14 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
}
}
+ if (OutliningWorklist.empty())
+ return Changed;
+
// Outline single-entry cold regions, splitting up larger regions as needed.
unsigned OutlinedFunctionID = 1;
- while (!OutliningWorklist.empty()) {
+ // Cache and recycle the CodeExtractor analysis to avoid O(n^2) compile-time.
+ CodeExtractorAnalysisCache CEAC(F);
+ do {
OutliningRegion Region = OutliningWorklist.pop_back_val();
assert(!Region.empty() && "Empty outlining region in worklist");
do {
@@ -650,14 +621,14 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
BB->dump();
});
- Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC,
- OutlinedFunctionID);
+ Function *Outlined = extractColdRegion(SubRegion, CEAC, *DT, BFI, TTI,
+ ORE, AC, OutlinedFunctionID);
if (Outlined) {
++OutlinedFunctionID;
Changed = true;
}
} while (!Region.empty());
- }
+ } while (!OutliningWorklist.empty());
return Changed;
}
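
The HotColdSplitting restructuring above has one idea: build the CodeExtractorAnalysisCache once per function and reuse it for every region popped off the worklist, instead of paying for the analysis on each extraction. A toy standalone sketch of that hoisting follows; AnalysisCache and extractRegion are invented stand-ins, not the LLVM classes.

#include <iostream>
#include <string>
#include <vector>

struct AnalysisCache {
  explicit AnalysisCache(const std::string &Fn) {
    std::cout << "analyzing " << Fn << " once\n"; // the formerly repeated work
  }
};

bool extractRegion(const std::string &Region, const AnalysisCache &) {
  std::cout << "  outlining " << Region << '\n';
  return true;
}

int main() {
  std::vector<std::string> Worklist = {"cold.1", "cold.2", "cold.3"};
  if (Worklist.empty())
    return 0;              // mirrors the early return added above
  AnalysisCache CEAC("f"); // hoisted out of the loop, as in the patch
  do {
    extractRegion(Worklist.back(), CEAC);
    Worklist.pop_back();
  } while (!Worklist.empty());
}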
diff --git a/lib/Transforms/IPO/IPO.cpp b/lib/Transforms/IPO/IPO.cpp
index 34db75dd8b03..bddf75211599 100644
--- a/lib/Transforms/IPO/IPO.cpp
+++ b/lib/Transforms/IPO/IPO.cpp
@@ -114,6 +114,10 @@ void LLVMAddIPSCCPPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createIPSCCPPass());
}
+void LLVMAddMergeFunctionsPass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createMergeFunctionsPass());
+}
+
void LLVMAddInternalizePass(LLVMPassManagerRef PM, unsigned AllButMain) {
auto PreserveMain = [=](const GlobalValue &GV) {
return AllButMain && GV.getName() == "main";
@@ -121,6 +125,15 @@ void LLVMAddInternalizePass(LLVMPassManagerRef PM, unsigned AllButMain) {
unwrap(PM)->add(createInternalizePass(PreserveMain));
}
+void LLVMAddInternalizePassWithMustPreservePredicate(
+ LLVMPassManagerRef PM,
+ void *Context,
+ LLVMBool (*Pred)(LLVMValueRef, void *)) {
+ unwrap(PM)->add(createInternalizePass([=](const GlobalValue &GV) {
+ return Pred(wrap(&GV), Context) == 0 ? false : true;
+ }));
+}
+
void LLVMAddStripDeadPrototypesPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createStripDeadPrototypesPass());
}
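
A hedged usage sketch for the two C bindings added above. It assumes the matching declarations are exported from llvm-c/Transforms/IPO.h (the header change is not part of this hunk); LLVMGetValueName2, LLVMPassManagerRef, and LLVMBool are existing LLVM-C API.

#include <string.h>
#include "llvm-c/Core.h"
#include "llvm-c/Transforms/IPO.h"

// Return nonzero for globals that must stay externally visible; everything
// else gets internal linkage.
static LLVMBool preserveMainOnly(LLVMValueRef GV, void *Context) {
  (void)Context;
  size_t Len;
  const char *Name = LLVMGetValueName2(GV, &Len);
  return strcmp(Name, "main") == 0;
}

void addIPOPasses(LLVMPassManagerRef PM) {
  // Internalize everything except main, then merge identical functions.
  LLVMAddInternalizePassWithMustPreservePredicate(PM, /*Context=*/nullptr,
                                                  preserveMainOnly);
  LLVMAddMergeFunctionsPass(PM);
}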
diff --git a/lib/Transforms/IPO/InferFunctionAttrs.cpp b/lib/Transforms/IPO/InferFunctionAttrs.cpp
index 7f5511e008e1..d1a68b28bd33 100644
--- a/lib/Transforms/IPO/InferFunctionAttrs.cpp
+++ b/lib/Transforms/IPO/InferFunctionAttrs.cpp
@@ -18,24 +18,28 @@ using namespace llvm;
#define DEBUG_TYPE "inferattrs"
-static bool inferAllPrototypeAttributes(Module &M,
- const TargetLibraryInfo &TLI) {
+static bool inferAllPrototypeAttributes(
+ Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) {
bool Changed = false;
for (Function &F : M.functions())
// We only infer things using the prototype and the name; we don't need
// definitions.
if (F.isDeclaration() && !F.hasOptNone())
- Changed |= inferLibFuncAttributes(F, TLI);
+ Changed |= inferLibFuncAttributes(F, GetTLI(F));
return Changed;
}
PreservedAnalyses InferFunctionAttrsPass::run(Module &M,
ModuleAnalysisManager &AM) {
- auto &TLI = AM.getResult<TargetLibraryAnalysis>(M);
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ };
- if (!inferAllPrototypeAttributes(M, TLI))
+ if (!inferAllPrototypeAttributes(M, GetTLI))
// If we didn't infer anything, preserve all analyses.
return PreservedAnalyses::all();
@@ -60,8 +64,10 @@ struct InferFunctionAttrsLegacyPass : public ModulePass {
if (skipModule(M))
return false;
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
- return inferAllPrototypeAttributes(M, TLI);
+ auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
+ return inferAllPrototypeAttributes(M, GetTLI);
}
};
}
diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp
index 945f8affae6e..4b72261131c1 100644
--- a/lib/Transforms/IPO/Inliner.cpp
+++ b/lib/Transforms/IPO/Inliner.cpp
@@ -239,7 +239,7 @@ static void mergeInlinedArrayAllocas(
}
if (Align1 > Align2)
- AvailableAlloca->setAlignment(AI->getAlignment());
+ AvailableAlloca->setAlignment(MaybeAlign(AI->getAlignment()));
}
AI->eraseFromParent();
@@ -527,7 +527,8 @@ static void setInlineRemark(CallSite &CS, StringRef message) {
static bool
inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
std::function<AssumptionCache &(Function &)> GetAssumptionCache,
- ProfileSummaryInfo *PSI, TargetLibraryInfo &TLI,
+ ProfileSummaryInfo *PSI,
+ std::function<TargetLibraryInfo &(Function &)> GetTLI,
bool InsertLifetime,
function_ref<InlineCost(CallSite CS)> GetInlineCost,
function_ref<AAResults &(Function &)> AARGetter,
@@ -626,7 +627,8 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
Instruction *Instr = CS.getInstruction();
- bool IsTriviallyDead = isInstructionTriviallyDead(Instr, &TLI);
+ bool IsTriviallyDead =
+ isInstructionTriviallyDead(Instr, &GetTLI(*Caller));
int InlineHistoryID;
if (!IsTriviallyDead) {
@@ -757,13 +759,16 @@ bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) {
CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
ACT = &getAnalysis<AssumptionCacheTracker>();
PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto GetTLI = [&](Function &F) -> TargetLibraryInfo & {
+ return getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
return ACT->getAssumptionCache(F);
};
- return inlineCallsImpl(SCC, CG, GetAssumptionCache, PSI, TLI, InsertLifetime,
- [this](CallSite CS) { return getInlineCost(CS); },
- LegacyAARGetter(*this), ImportedFunctionsStats);
+ return inlineCallsImpl(
+ SCC, CG, GetAssumptionCache, PSI, GetTLI, InsertLifetime,
+ [this](CallSite CS) { return getInlineCost(CS); }, LegacyAARGetter(*this),
+ ImportedFunctionsStats);
}
/// Remove now-dead linkonce functions at the end of
@@ -879,7 +884,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
if (!ImportedFunctionsStats &&
InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) {
ImportedFunctionsStats =
- llvm::make_unique<ImportedFunctionsInliningStatistics>();
+ std::make_unique<ImportedFunctionsInliningStatistics>();
ImportedFunctionsStats->setModuleInfo(M);
}
diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp
index 91c7b5f5f135..add2ae053735 100644
--- a/lib/Transforms/IPO/LoopExtractor.cpp
+++ b/lib/Transforms/IPO/LoopExtractor.cpp
@@ -141,10 +141,12 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) {
if (NumLoops == 0) return Changed;
--NumLoops;
AssumptionCache *AC = nullptr;
+ Function &Func = *L->getHeader()->getParent();
if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>())
- AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent());
+ AC = ACT->lookupAssumptionCache(Func);
+ CodeExtractorAnalysisCache CEAC(Func);
CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
- if (Extractor.extractCodeRegion() != nullptr) {
+ if (Extractor.extractCodeRegion(CEAC) != nullptr) {
Changed = true;
// After extraction, the loop is replaced by a function call, so
// we shouldn't try to run any more loop passes on it.
diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp
index f7371284f47e..2dec366d70e2 100644
--- a/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -230,6 +230,16 @@ void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits,
Bytes[AllocByteOffset + B] |= AllocMask;
}
+bool lowertypetests::isJumpTableCanonical(Function *F) {
+ if (F->isDeclarationForLinker())
+ return false;
+ auto *CI = mdconst::extract_or_null<ConstantInt>(
+ F->getParent()->getModuleFlag("CFI Canonical Jump Tables"));
+ if (!CI || CI->getZExtValue() != 0)
+ return true;
+ return F->hasFnAttribute("cfi-canonical-jump-table");
+}
+
namespace {
struct ByteArrayInfo {
@@ -251,9 +261,12 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
GlobalObject *GO;
size_t NTypes;
- // For functions: true if this is a definition (either in the merged module or
- // in one of the thinlto modules).
- bool IsDefinition;
+ // For functions: true if the jump table is canonical. This essentially means
+ // whether the canonical address (i.e. the symbol table entry) of the function
+ // is provided by the local jump table. This is normally the same as whether
+ // the function is defined locally, but if canonical jump tables are disabled
+ // by the user then the jump table never provides a canonical definition.
+ bool IsJumpTableCanonical;
// For functions: true if this function is either defined or used in a thinlto
// module and its jumptable entry needs to be exported to thinlto backends.
@@ -263,13 +276,13 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
public:
static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
- bool IsDefinition, bool IsExported,
+ bool IsJumpTableCanonical, bool IsExported,
ArrayRef<MDNode *> Types) {
auto *GTM = static_cast<GlobalTypeMember *>(Alloc.Allocate(
totalSizeToAlloc<MDNode *>(Types.size()), alignof(GlobalTypeMember)));
GTM->GO = GO;
GTM->NTypes = Types.size();
- GTM->IsDefinition = IsDefinition;
+ GTM->IsJumpTableCanonical = IsJumpTableCanonical;
GTM->IsExported = IsExported;
std::uninitialized_copy(Types.begin(), Types.end(),
GTM->getTrailingObjects<MDNode *>());
@@ -280,8 +293,8 @@ public:
return GO;
}
- bool isDefinition() const {
- return IsDefinition;
+ bool isJumpTableCanonical() const {
+ return IsJumpTableCanonical;
}
bool isExported() const {
@@ -320,6 +333,49 @@ private:
size_t NTargets;
};
+struct ScopedSaveAliaseesAndUsed {
+ Module &M;
+ SmallPtrSet<GlobalValue *, 16> Used, CompilerUsed;
+ std::vector<std::pair<GlobalIndirectSymbol *, Function *>> FunctionAliases;
+
+ ScopedSaveAliaseesAndUsed(Module &M) : M(M) {
+ // The users of this class want to replace all function references except
+ // for aliases and llvm.used/llvm.compiler.used with references to a jump
+ // table. We avoid replacing aliases in order to avoid introducing a double
+ // indirection (or an alias pointing to a declaration in ThinLTO mode), and
+ // we avoid replacing llvm.used/llvm.compiler.used because these global
+ // variables describe properties of the global, not the jump table (besides,
+ // offset references to the jump table in llvm.used are invalid).
+ // Unfortunately, LLVM doesn't have a "RAUW except for these (possibly
+ // indirect) users", so what we do is save the list of globals referenced by
+ // llvm.used/llvm.compiler.used and aliases, erase the used lists, let RAUW
+ // replace the aliasees and then set them back to their original values at
+ // the end.
+ if (GlobalVariable *GV = collectUsedGlobalVariables(M, Used, false))
+ GV->eraseFromParent();
+ if (GlobalVariable *GV = collectUsedGlobalVariables(M, CompilerUsed, true))
+ GV->eraseFromParent();
+
+ for (auto &GIS : concat<GlobalIndirectSymbol>(M.aliases(), M.ifuncs())) {
+ // FIXME: This should look past all aliases not just interposable ones,
+ // see discussion on D65118.
+ if (auto *F =
+ dyn_cast<Function>(GIS.getIndirectSymbol()->stripPointerCasts()))
+ FunctionAliases.push_back({&GIS, F});
+ }
+ }
+
+ ~ScopedSaveAliaseesAndUsed() {
+ appendToUsed(M, std::vector<GlobalValue *>(Used.begin(), Used.end()));
+ appendToCompilerUsed(M, std::vector<GlobalValue *>(CompilerUsed.begin(),
+ CompilerUsed.end()));
+
+ for (auto P : FunctionAliases)
+ P.first->setIndirectSymbol(
+ ConstantExpr::getBitCast(P.second, P.first->getType()));
+ }
+};
+
class LowerTypeTestsModule {
Module &M;
@@ -387,7 +443,8 @@ class LowerTypeTestsModule {
uint8_t *exportTypeId(StringRef TypeId, const TypeIdLowering &TIL);
TypeIdLowering importTypeId(StringRef TypeId);
void importTypeTest(CallInst *CI);
- void importFunction(Function *F, bool isDefinition);
+ void importFunction(Function *F, bool isJumpTableCanonical,
+ std::vector<GlobalAlias *> &AliasesToErase);
BitSetInfo
buildBitSet(Metadata *TypeId,
@@ -421,7 +478,8 @@ class LowerTypeTestsModule {
ArrayRef<GlobalTypeMember *> Globals,
ArrayRef<ICallBranchFunnel *> ICallBranchFunnels);
- void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT, bool IsDefinition);
+ void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT,
+ bool IsJumpTableCanonical);
void moveInitializerToModuleConstructor(GlobalVariable *GV);
void findGlobalVariableUsersOf(Constant *C,
SmallSetVector<GlobalVariable *, 8> &Out);
@@ -433,7 +491,7 @@ class LowerTypeTestsModule {
/// the block. 'This's use list is expected to have at least one element.
/// Unlike replaceAllUsesWith this function skips blockaddr and direct call
/// uses.
- void replaceCfiUses(Function *Old, Value *New, bool IsDefinition);
+ void replaceCfiUses(Function *Old, Value *New, bool IsJumpTableCanonical);
/// replaceDirectCalls - Go through the uses list for this definition and
/// replace each use, which is a direct function call.
@@ -759,43 +817,50 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables(
// Build a new global with the combined contents of the referenced globals.
// This global is a struct whose even-indexed elements contain the original
// contents of the referenced globals and whose odd-indexed elements contain
- // any padding required to align the next element to the next power of 2.
+ // any padding required to align the next element to the next power of 2 plus
+ // any additional padding required to meet its alignment requirements.
std::vector<Constant *> GlobalInits;
const DataLayout &DL = M.getDataLayout();
+ DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
+ Align MaxAlign;
+ uint64_t CurOffset = 0;
+ uint64_t DesiredPadding = 0;
for (GlobalTypeMember *G : Globals) {
- GlobalVariable *GV = cast<GlobalVariable>(G->getGlobal());
+ auto *GV = cast<GlobalVariable>(G->getGlobal());
+ MaybeAlign Alignment(GV->getAlignment());
+ if (!Alignment)
+ Alignment = Align(DL.getABITypeAlignment(GV->getValueType()));
+ MaxAlign = std::max(MaxAlign, *Alignment);
+ uint64_t GVOffset = alignTo(CurOffset + DesiredPadding, *Alignment);
+ GlobalLayout[G] = GVOffset;
+ if (GVOffset != 0) {
+ uint64_t Padding = GVOffset - CurOffset;
+ GlobalInits.push_back(
+ ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
+ }
+
GlobalInits.push_back(GV->getInitializer());
uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType());
+ CurOffset = GVOffset + InitSize;
- // Compute the amount of padding required.
- uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize;
+ // Compute the amount of padding that we'd like for the next element.
+ DesiredPadding = NextPowerOf2(InitSize - 1) - InitSize;
// Experiments of different caps with Chromium on both x64 and ARM64
// have shown that the 32-byte cap generates the smallest binary on
// both platforms while different caps yield similar performance.
// (see https://lists.llvm.org/pipermail/llvm-dev/2018-July/124694.html)
- if (Padding > 32)
- Padding = alignTo(InitSize, 32) - InitSize;
-
- GlobalInits.push_back(
- ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
+ if (DesiredPadding > 32)
+ DesiredPadding = alignTo(InitSize, 32) - InitSize;
}
- if (!GlobalInits.empty())
- GlobalInits.pop_back();
+
Constant *NewInit = ConstantStruct::getAnon(M.getContext(), GlobalInits);
auto *CombinedGlobal =
new GlobalVariable(M, NewInit->getType(), /*isConstant=*/true,
GlobalValue::PrivateLinkage, NewInit);
+ CombinedGlobal->setAlignment(MaxAlign);
StructType *NewTy = cast<StructType>(NewInit->getType());
- const StructLayout *CombinedGlobalLayout = DL.getStructLayout(NewTy);
-
- // Compute the offsets of the original globals within the new global.
- DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
- for (unsigned I = 0; I != Globals.size(); ++I)
- // Multiply by 2 to account for padding elements.
- GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2);
-
lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout);
// Build aliases pointing to offsets into the combined global for each
@@ -975,14 +1040,16 @@ void LowerTypeTestsModule::importTypeTest(CallInst *CI) {
}
// ThinLTO backend: the function F has a jump table entry; update this module
-// accordingly. isDefinition describes the type of the jump table entry.
-void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) {
+// accordingly. isJumpTableCanonical describes the type of the jump table entry.
+void LowerTypeTestsModule::importFunction(
+ Function *F, bool isJumpTableCanonical,
+ std::vector<GlobalAlias *> &AliasesToErase) {
assert(F->getType()->getAddressSpace() == 0);
GlobalValue::VisibilityTypes Visibility = F->getVisibility();
std::string Name = F->getName();
- if (F->isDeclarationForLinker() && isDefinition) {
+ if (F->isDeclarationForLinker() && isJumpTableCanonical) {
// Non-dso_local functions may be overridden at run time,
// don't short circuit them
if (F->isDSOLocal()) {
@@ -997,12 +1064,13 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) {
}
Function *FDecl;
- if (F->isDeclarationForLinker() && !isDefinition) {
- // Declaration of an external function.
+ if (!isJumpTableCanonical) {
+ // Either a declaration of an external function or a reference to a locally
+ // defined jump table.
FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage,
F->getAddressSpace(), Name + ".cfi_jt", &M);
FDecl->setVisibility(GlobalValue::HiddenVisibility);
- } else if (isDefinition) {
+ } else {
F->setName(Name + ".cfi");
F->setLinkage(GlobalValue::ExternalLinkage);
FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage,
@@ -1011,8 +1079,8 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) {
Visibility = GlobalValue::HiddenVisibility;
// Delete aliases pointing to this function, they'll be re-created in the
- // merged output
- SmallVector<GlobalAlias*, 4> ToErase;
+ // merged output. Don't do it yet though because ScopedSaveAliaseesAndUsed
+ // will want to reset the aliasees first.
for (auto &U : F->uses()) {
if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) {
Function *AliasDecl = Function::Create(
@@ -1020,24 +1088,15 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) {
F->getAddressSpace(), "", &M);
AliasDecl->takeName(A);
A->replaceAllUsesWith(AliasDecl);
- ToErase.push_back(A);
+ AliasesToErase.push_back(A);
}
}
- for (auto *A : ToErase)
- A->eraseFromParent();
- } else {
- // Function definition without type metadata, where some other translation
- // unit contained a declaration with type metadata. This normally happens
- // during mixed CFI + non-CFI compilation. We do nothing with the function
- // so that it is treated the same way as a function defined outside of the
- // LTO unit.
- return;
}
- if (F->isWeakForLinker())
- replaceWeakDeclarationWithJumpTablePtr(F, FDecl, isDefinition);
+ if (F->hasExternalWeakLinkage())
+ replaceWeakDeclarationWithJumpTablePtr(F, FDecl, isJumpTableCanonical);
else
- replaceCfiUses(F, FDecl, isDefinition);
+ replaceCfiUses(F, FDecl, isJumpTableCanonical);
// Set visibility late because it's used in replaceCfiUses() to determine
// whether uses need to be replaced.
@@ -1225,7 +1284,7 @@ void LowerTypeTestsModule::findGlobalVariableUsersOf(
// Replace all uses of F with (F ? JT : 0).
void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr(
- Function *F, Constant *JT, bool IsDefinition) {
+ Function *F, Constant *JT, bool IsJumpTableCanonical) {
// The target expression can not appear in a constant initializer on most
// (all?) targets. Switch to a runtime initializer.
SmallSetVector<GlobalVariable *, 8> GlobalVarUsers;
@@ -1239,7 +1298,7 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr(
Function::Create(cast<FunctionType>(F->getValueType()),
GlobalValue::ExternalWeakLinkage,
F->getAddressSpace(), "", &M);
- replaceCfiUses(F, PlaceholderFn, IsDefinition);
+ replaceCfiUses(F, PlaceholderFn, IsJumpTableCanonical);
Constant *Target = ConstantExpr::getSelect(
ConstantExpr::getICmp(CmpInst::ICMP_NE, F,
@@ -1276,8 +1335,9 @@ selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions,
unsigned ArmCount = 0, ThumbCount = 0;
for (const auto GTM : Functions) {
- if (!GTM->isDefinition()) {
+ if (!GTM->isJumpTableCanonical()) {
// PLT stubs are always ARM.
+ // FIXME: This is the wrong heuristic for non-canonical jump tables.
++ArmCount;
continue;
}
@@ -1303,7 +1363,7 @@ void LowerTypeTestsModule::createJumpTable(
cast<Function>(Functions[I]->getGlobal()));
// Align the whole table by entry size.
- F->setAlignment(getJumpTableEntrySize());
+ F->setAlignment(Align(getJumpTableEntrySize()));
// Skip prologue.
// Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3.
// Luckily, this function does not get any prologue even without the
@@ -1438,47 +1498,53 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);
- // Build aliases pointing to offsets into the jump table, and replace
- // references to the original functions with references to the aliases.
- for (unsigned I = 0; I != Functions.size(); ++I) {
- Function *F = cast<Function>(Functions[I]->getGlobal());
- bool IsDefinition = Functions[I]->isDefinition();
-
- Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast(
- ConstantExpr::getInBoundsGetElementPtr(
- JumpTableType, JumpTable,
- ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
- ConstantInt::get(IntPtrTy, I)}),
- F->getType());
- if (Functions[I]->isExported()) {
- if (IsDefinition) {
- ExportSummary->cfiFunctionDefs().insert(F->getName());
+ {
+ ScopedSaveAliaseesAndUsed S(M);
+
+ // Build aliases pointing to offsets into the jump table, and replace
+ // references to the original functions with references to the aliases.
+ for (unsigned I = 0; I != Functions.size(); ++I) {
+ Function *F = cast<Function>(Functions[I]->getGlobal());
+ bool IsJumpTableCanonical = Functions[I]->isJumpTableCanonical();
+
+ Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast(
+ ConstantExpr::getInBoundsGetElementPtr(
+ JumpTableType, JumpTable,
+ ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
+ ConstantInt::get(IntPtrTy, I)}),
+ F->getType());
+ if (Functions[I]->isExported()) {
+ if (IsJumpTableCanonical) {
+ ExportSummary->cfiFunctionDefs().insert(F->getName());
+ } else {
+ GlobalAlias *JtAlias = GlobalAlias::create(
+ F->getValueType(), 0, GlobalValue::ExternalLinkage,
+ F->getName() + ".cfi_jt", CombinedGlobalElemPtr, &M);
+ JtAlias->setVisibility(GlobalValue::HiddenVisibility);
+ ExportSummary->cfiFunctionDecls().insert(F->getName());
+ }
+ }
+ if (!IsJumpTableCanonical) {
+ if (F->hasExternalWeakLinkage())
+ replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr,
+ IsJumpTableCanonical);
+ else
+ replaceCfiUses(F, CombinedGlobalElemPtr, IsJumpTableCanonical);
} else {
- GlobalAlias *JtAlias = GlobalAlias::create(
- F->getValueType(), 0, GlobalValue::ExternalLinkage,
- F->getName() + ".cfi_jt", CombinedGlobalElemPtr, &M);
- JtAlias->setVisibility(GlobalValue::HiddenVisibility);
- ExportSummary->cfiFunctionDecls().insert(F->getName());
+ assert(F->getType()->getAddressSpace() == 0);
+
+ GlobalAlias *FAlias =
+ GlobalAlias::create(F->getValueType(), 0, F->getLinkage(), "",
+ CombinedGlobalElemPtr, &M);
+ FAlias->setVisibility(F->getVisibility());
+ FAlias->takeName(F);
+ if (FAlias->hasName())
+ F->setName(FAlias->getName() + ".cfi");
+ replaceCfiUses(F, FAlias, IsJumpTableCanonical);
+ if (!F->hasLocalLinkage())
+ F->setVisibility(GlobalVariable::HiddenVisibility);
}
}
- if (!IsDefinition) {
- if (F->isWeakForLinker())
- replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr, IsDefinition);
- else
- replaceCfiUses(F, CombinedGlobalElemPtr, IsDefinition);
- } else {
- assert(F->getType()->getAddressSpace() == 0);
-
- GlobalAlias *FAlias = GlobalAlias::create(
- F->getValueType(), 0, F->getLinkage(), "", CombinedGlobalElemPtr, &M);
- FAlias->setVisibility(F->getVisibility());
- FAlias->takeName(F);
- if (FAlias->hasName())
- F->setName(FAlias->getName() + ".cfi");
- replaceCfiUses(F, FAlias, IsDefinition);
- if (!F->hasLocalLinkage())
- F->setVisibility(GlobalVariable::HiddenVisibility);
- }
}
createJumpTable(JumpTableFn, Functions);
@@ -1623,7 +1689,7 @@ bool LowerTypeTestsModule::runForTesting(Module &M) {
ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary +
": ");
std::error_code EC;
- raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
+ raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text);
ExitOnErr(errorCodeToError(EC));
yaml::Output Out(OS);
@@ -1643,7 +1709,8 @@ static bool isDirectCall(Use& U) {
return false;
}
-void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsDefinition) {
+void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New,
+ bool IsJumpTableCanonical) {
SmallSetVector<Constant *, 4> Constants;
auto UI = Old->use_begin(), E = Old->use_end();
for (; UI != E;) {
@@ -1655,7 +1722,7 @@ void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsDefi
continue;
// Skip direct calls to externally defined or non-dso_local functions
- if (isDirectCall(U) && (Old->isDSOLocal() || !IsDefinition))
+ if (isDirectCall(U) && (Old->isDSOLocal() || !IsJumpTableCanonical))
continue;
// Must handle Constants specially, we cannot call replaceUsesOfWith on a
@@ -1678,16 +1745,7 @@ void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsDefi
}
void LowerTypeTestsModule::replaceDirectCalls(Value *Old, Value *New) {
- auto UI = Old->use_begin(), E = Old->use_end();
- for (; UI != E;) {
- Use &U = *UI;
- ++UI;
-
- if (!isDirectCall(U))
- continue;
-
- U.set(New);
- }
+ Old->replaceUsesWithIf(New, [](Use &U) { return isDirectCall(U); });
}
bool LowerTypeTestsModule::lower() {
@@ -1734,10 +1792,16 @@ bool LowerTypeTestsModule::lower() {
Decls.push_back(&F);
}
- for (auto F : Defs)
- importFunction(F, /*isDefinition*/ true);
- for (auto F : Decls)
- importFunction(F, /*isDefinition*/ false);
+ std::vector<GlobalAlias *> AliasesToErase;
+ {
+ ScopedSaveAliaseesAndUsed S(M);
+ for (auto F : Defs)
+ importFunction(F, /*isJumpTableCanonical*/ true, AliasesToErase);
+ for (auto F : Decls)
+ importFunction(F, /*isJumpTableCanonical*/ false, AliasesToErase);
+ }
+ for (GlobalAlias *GA : AliasesToErase)
+ GA->eraseFromParent();
return true;
}
@@ -1823,6 +1887,17 @@ bool LowerTypeTestsModule::lower() {
CfiFunctionLinkage Linkage = P.second.Linkage;
MDNode *FuncMD = P.second.FuncMD;
Function *F = M.getFunction(FunctionName);
+ if (F && F->hasLocalLinkage()) {
+ // Locally defined function that happens to have the same name as a
+ // function defined in a ThinLTO module. Rename it to move it out of
+ // the way of the external reference that we're about to create.
+ // Note that setName will find a unique name for the function, so even
+ // if there is an existing function with the suffix, there won't be a
+ // name collision.
+ F->setName(F->getName() + ".1");
+ F = nullptr;
+ }
+
if (!F)
F = Function::Create(
FunctionType::get(Type::getVoidTy(M.getContext()), false),
@@ -1871,24 +1946,26 @@ bool LowerTypeTestsModule::lower() {
Types.clear();
GO.getMetadata(LLVMContext::MD_type, Types);
- bool IsDefinition = !GO.isDeclarationForLinker();
+ bool IsJumpTableCanonical = false;
bool IsExported = false;
if (Function *F = dyn_cast<Function>(&GO)) {
+ IsJumpTableCanonical = isJumpTableCanonical(F);
if (ExportedFunctions.count(F->getName())) {
- IsDefinition |= ExportedFunctions[F->getName()].Linkage == CFL_Definition;
+ IsJumpTableCanonical |=
+ ExportedFunctions[F->getName()].Linkage == CFL_Definition;
IsExported = true;
// TODO: The logic here checks only that the function is address taken,
// not that the address takers are live. This can be updated to check
// their liveness and emit fewer jumptable entries once monolithic LTO
// builds also emit summaries.
} else if (!F->hasAddressTaken()) {
- if (!CrossDsoCfi || !IsDefinition || F->hasLocalLinkage())
+ if (!CrossDsoCfi || !IsJumpTableCanonical || F->hasLocalLinkage())
continue;
}
}
- auto *GTM =
- GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types);
+ auto *GTM = GlobalTypeMember::create(Alloc, &GO, IsJumpTableCanonical,
+ IsExported, Types);
GlobalTypeMembers[&GO] = GTM;
for (MDNode *Type : Types) {
verifyTypeMDNode(&GO, Type);
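
The new layout code in buildBitSetsFromGlobalVariables above interleaves two constraints: pad each member up to the next power of two (capped at 32 bytes, per the Chromium numbers cited in the comment) and also honor each member's own alignment, then give the combined global the maximum alignment seen. Below is a standalone worked example of that arithmetic with made-up sizes and alignments; alignTo and nextPowerOf2 are local reimplementations, not the LLVM helpers.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

static uint64_t alignTo(uint64_t Value, uint64_t Align) {
  return (Value + Align - 1) / Align * Align;
}
static uint64_t nextPowerOf2(uint64_t A) { // smallest power of two > A
  uint64_t P = 1;
  while (P <= A)
    P <<= 1;
  return P;
}

int main() {
  struct Member { uint64_t Size, Align; };
  std::vector<Member> Globals = {{13, 4}, {40, 16}, {8, 8}};

  uint64_t CurOffset = 0, DesiredPadding = 0, MaxAlign = 1;
  for (const Member &G : Globals) {
    MaxAlign = std::max(MaxAlign, G.Align);
    uint64_t Offset = alignTo(CurOffset + DesiredPadding, G.Align);
    std::cout << "member of size " << G.Size << " placed at " << Offset << '\n';
    CurOffset = Offset + G.Size;
    DesiredPadding = nextPowerOf2(G.Size - 1) - G.Size;
    if (DesiredPadding > 32) // 32-byte cap, as in the patch
      DesiredPadding = alignTo(G.Size, 32) - G.Size;
  }
  std::cout << "combined global alignment: " << MaxAlign << '\n';
  // Prints offsets 0, 16, 80 and alignment 16 for the made-up inputs.
}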
diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp
index 3a08069dcd4a..8b9abaddc84c 100644
--- a/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/lib/Transforms/IPO/MergeFunctions.cpp
@@ -769,7 +769,7 @@ void MergeFunctions::writeAlias(Function *F, Function *G) {
PtrType->getElementType(), PtrType->getAddressSpace(),
G->getLinkage(), "", BitcastF, G->getParent());
- F->setAlignment(std::max(F->getAlignment(), G->getAlignment()));
+ F->setAlignment(MaybeAlign(std::max(F->getAlignment(), G->getAlignment())));
GA->takeName(G);
GA->setVisibility(G->getVisibility());
GA->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
@@ -816,7 +816,7 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
removeUsers(F);
F->replaceAllUsesWith(NewF);
- unsigned MaxAlignment = std::max(G->getAlignment(), NewF->getAlignment());
+ MaybeAlign MaxAlignment(std::max(G->getAlignment(), NewF->getAlignment()));
writeThunkOrAlias(F, G);
writeThunkOrAlias(F, NewF);
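The MergeFunctions hunks wrap the merged alignment in MaybeAlign, where an absent value stands for "no explicit alignment". A rough standard-library analogue of that design choice (illustrative names, assuming the usual 0-means-unspecified convention):

  #include <algorithm>
  #include <cstdint>
  #include <optional>

  // An optional makes the "unspecified" state explicit instead of threading a
  // 0 sentinel through every caller.
  using MaybeAlignment = std::optional<uint64_t>;

  MaybeAlignment mergeAlignments(uint64_t A, uint64_t B) {
    uint64_t Max = std::max(A, B); // keep the stricter (larger) requirement
    if (Max == 0)
      return std::nullopt;         // neither side had an explicit alignment
    return Max;
  }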
diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp
index 733782e8764d..e193074884af 100644
--- a/lib/Transforms/IPO/PartialInlining.cpp
+++ b/lib/Transforms/IPO/PartialInlining.cpp
@@ -409,7 +409,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F,
return std::unique_ptr<FunctionOutliningMultiRegionInfo>();
std::unique_ptr<FunctionOutliningMultiRegionInfo> OutliningInfo =
- llvm::make_unique<FunctionOutliningMultiRegionInfo>();
+ std::make_unique<FunctionOutliningMultiRegionInfo>();
auto IsSingleEntry = [](SmallVectorImpl<BasicBlock *> &BlockList) {
BasicBlock *Dom = BlockList.front();
@@ -589,7 +589,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) {
};
std::unique_ptr<FunctionOutliningInfo> OutliningInfo =
- llvm::make_unique<FunctionOutliningInfo>();
+ std::make_unique<FunctionOutliningInfo>();
BasicBlock *CurrEntry = EntryBlock;
bool CandidateFound = false;
@@ -966,7 +966,7 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner(
Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE,
function_ref<AssumptionCache *(Function &)> LookupAC)
: OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
- ClonedOI = llvm::make_unique<FunctionOutliningInfo>();
+ ClonedOI = std::make_unique<FunctionOutliningInfo>();
// Clone the function, so that we can hack away on it.
ValueToValueMapTy VMap;
@@ -991,7 +991,7 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner(
OptimizationRemarkEmitter &ORE,
function_ref<AssumptionCache *(Function &)> LookupAC)
: OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
- ClonedOMRI = llvm::make_unique<FunctionOutliningMultiRegionInfo>();
+ ClonedOMRI = std::make_unique<FunctionOutliningMultiRegionInfo>();
// Clone the function, so that we can hack away on it.
ValueToValueMapTy VMap;
@@ -1122,6 +1122,9 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() {
BranchProbabilityInfo BPI(*ClonedFunc, LI);
ClonedFuncBFI.reset(new BlockFrequencyInfo(*ClonedFunc, BPI, LI));
+ // Cache and recycle the CodeExtractor analysis to avoid O(n^2) compile-time.
+ CodeExtractorAnalysisCache CEAC(*ClonedFunc);
+
SetVector<Value *> Inputs, Outputs, Sinks;
for (FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegionInfo :
ClonedOMRI->ORI) {
@@ -1148,7 +1151,7 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() {
if (Outputs.size() > 0 && !ForceLiveExit)
continue;
- Function *OutlinedFunc = CE.extractCodeRegion();
+ Function *OutlinedFunc = CE.extractCodeRegion(CEAC);
if (OutlinedFunc) {
CallSite OCS = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc);
@@ -1210,11 +1213,12 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() {
}
// Extract the body of the if.
+ CodeExtractorAnalysisCache CEAC(*ClonedFunc);
Function *OutlinedFunc =
CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false,
ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc),
/* AllowVarargs */ true)
- .extractCodeRegion();
+ .extractCodeRegion(CEAC);
if (OutlinedFunc) {
BasicBlock *OutliningCallBB =
@@ -1264,7 +1268,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) {
if (PSI->isFunctionEntryCold(F))
return {false, nullptr};
- if (empty(F->users()))
+ if (F->users().empty())
return {false, nullptr};
OptimizationRemarkEmitter ORE(F);
@@ -1370,7 +1374,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {
return false;
}
- assert(empty(Cloner.OrigFunc->users()) &&
+ assert(Cloner.OrigFunc->users().empty() &&
"F's users should all be replaced!");
std::vector<User *> Users(Cloner.ClonedFunc->user_begin(),
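The new CodeExtractorAnalysisCache is created once per cloned function and handed to every extractCodeRegion call, instead of letting each extraction rescan the function. A standalone sketch of that caching pattern (plain C++, hypothetical names):

  #include <string>
  #include <unordered_map>
  #include <vector>

  struct Region { std::vector<std::string> Blocks; };

  // Stand-in for CodeExtractorAnalysisCache: one O(n) walk over the function,
  // reused by every region extraction, instead of an O(n) rescan per region.
  struct AnalysisCache {
    std::unordered_map<std::string, int> Info;
    explicit AnalysisCache(const std::vector<std::string> &AllBlocks) {
      for (const auto &B : AllBlocks)
        Info.emplace(B, static_cast<int>(B.size()));
    }
  };

  int extractRegion(const Region &R, const AnalysisCache &Cache) {
    int Cost = 0;
    for (const auto &B : R.Blocks) // regions are subsets of AllBlocks
      Cost += Cache.Info.at(B);
    return Cost;
  }

  int outlineAll(const std::vector<std::string> &AllBlocks,
                 const std::vector<Region> &Regions) {
    AnalysisCache Cache(AllBlocks); // built once, shared by all regions
    int Total = 0;
    for (const Region &R : Regions)
      Total += extractRegion(R, Cache);
    return Total;
  }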
diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp
index 3ea77f08fd3c..5314a8219b1e 100644
--- a/lib/Transforms/IPO/PassManagerBuilder.cpp
+++ b/lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -654,6 +654,7 @@ void PassManagerBuilder::populateModulePassManager(
MPM.add(createGlobalsAAWrapperPass());
MPM.add(createFloat2IntPass());
+ MPM.add(createLowerConstantIntrinsicsPass());
addExtensionsToPM(EP_VectorizerStart, MPM);
diff --git a/lib/Transforms/IPO/SCCP.cpp b/lib/Transforms/IPO/SCCP.cpp
index 7be3608bd2ec..307690729b14 100644
--- a/lib/Transforms/IPO/SCCP.cpp
+++ b/lib/Transforms/IPO/SCCP.cpp
@@ -9,16 +9,18 @@ using namespace llvm;
PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) {
const DataLayout &DL = M.getDataLayout();
- auto &TLI = AM.getResult<TargetLibraryAnalysis>(M);
auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto GetTLI = [&FAM](Function &F) -> const TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ };
auto getAnalysis = [&FAM](Function &F) -> AnalysisResultsForFn {
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
return {
- make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)),
+ std::make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)),
&DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F)};
};
- if (!runIPSCCP(M, DL, &TLI, getAnalysis))
+ if (!runIPSCCP(M, DL, GetTLI, getAnalysis))
return PreservedAnalyses::all();
PreservedAnalyses PA;
@@ -47,14 +49,14 @@ public:
if (skipModule(M))
return false;
const DataLayout &DL = M.getDataLayout();
- const TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
-
+ auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
auto getAnalysis = [this](Function &F) -> AnalysisResultsForFn {
DominatorTree &DT =
this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
return {
- make_unique<PredicateInfo>(
+ std::make_unique<PredicateInfo>(
F, DT,
this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
F)),
@@ -62,7 +64,7 @@ public:
nullptr}; // manager, so set them to nullptr.
};
- return runIPSCCP(M, DL, TLI, getAnalysis);
+ return runIPSCCP(M, DL, GetTLI, getAnalysis);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
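The IPSCCP hunks stop fetching one module-wide TargetLibraryInfo and instead pass a GetTLI callback so the result is looked up per function. A minimal sketch of the interface shape only (made-up types, not the LLVM classes):

  #include <functional>
  #include <string>
  #include <vector>

  // The pass takes a callback yielding the analysis result for the specific
  // function being examined, so the result can differ per function.
  struct LibraryInfo { bool HasVectorLib = false; };
  struct Function { std::string Name; };

  using GetLibInfoFn = std::function<const LibraryInfo &(Function &)>;

  bool runOnAllFunctions(std::vector<Function> &Funcs, GetLibInfoFn GetLI) {
    bool Changed = false;
    for (Function &F : Funcs)
      Changed |= GetLI(F).HasVectorLib; // per-function query
    return Changed;
  }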
diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp
index 877d20e72ffc..6184681db8a2 100644
--- a/lib/Transforms/IPO/SampleProfile.cpp
+++ b/lib/Transforms/IPO/SampleProfile.cpp
@@ -72,6 +72,7 @@
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/MisExpect.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
@@ -79,6 +80,7 @@
#include <limits>
#include <map>
#include <memory>
+#include <queue>
#include <string>
#include <system_error>
#include <utility>
@@ -128,6 +130,12 @@ static cl::opt<bool> ProfileSampleAccurate(
"callsite and function as having 0 samples. Otherwise, treat "
"un-sampled callsites and functions conservatively as unknown. "));
+static cl::opt<bool> ProfileAccurateForSymsInList(
+ "profile-accurate-for-symsinlist", cl::Hidden, cl::ZeroOrMore,
+ cl::init(true),
+ cl::desc("For symbols in profile symbol list, regard their profiles to "
+ "be accurate. It may be overriden by profile-sample-accurate. "));
+
namespace {
using BlockWeightMap = DenseMap<const BasicBlock *, uint64_t>;
@@ -137,9 +145,11 @@ using EdgeWeightMap = DenseMap<Edge, uint64_t>;
using BlockEdgeMap =
DenseMap<const BasicBlock *, SmallVector<const BasicBlock *, 8>>;
+class SampleProfileLoader;
+
class SampleCoverageTracker {
public:
- SampleCoverageTracker() = default;
+ SampleCoverageTracker(SampleProfileLoader &SPL) : SPLoader(SPL){};
bool markSamplesUsed(const FunctionSamples *FS, uint32_t LineOffset,
uint32_t Discriminator, uint64_t Samples);
@@ -185,6 +195,76 @@ private:
/// keyed by FunctionSamples pointers, but these stats are cleared after
/// every function, so we just need to keep a single counter.
uint64_t TotalUsedSamples = 0;
+
+ SampleProfileLoader &SPLoader;
+};
+
+class GUIDToFuncNameMapper {
+public:
+ GUIDToFuncNameMapper(Module &M, SampleProfileReader &Reader,
+ DenseMap<uint64_t, StringRef> &GUIDToFuncNameMap)
+ : CurrentReader(Reader), CurrentModule(M),
+ CurrentGUIDToFuncNameMap(GUIDToFuncNameMap) {
+ if (CurrentReader.getFormat() != SPF_Compact_Binary)
+ return;
+
+ for (const auto &F : CurrentModule) {
+ StringRef OrigName = F.getName();
+ CurrentGUIDToFuncNameMap.insert(
+ {Function::getGUID(OrigName), OrigName});
+
+ // Local-to-global promotion used by optimizations like ThinLTO will
+ // rename a local and add a suffix like ".llvm.xxx" to its original
+ // name. In the sample profile, such suffixes of function names are
+ // all stripped. Since the mapper may be built in the post-thin-link
+ // phase, after promotion has already been done, we also need to add
+ // the function name without the suffix into the GUIDToFuncNameMap.
+ StringRef CanonName = FunctionSamples::getCanonicalFnName(F);
+ if (CanonName != OrigName)
+ CurrentGUIDToFuncNameMap.insert(
+ {Function::getGUID(CanonName), CanonName});
+ }
+
+ // Update GUIDToFuncNameMap for each function including inlinees.
+ SetGUIDToFuncNameMapForAll(&CurrentGUIDToFuncNameMap);
+ }
+
+ ~GUIDToFuncNameMapper() {
+ if (CurrentReader.getFormat() != SPF_Compact_Binary)
+ return;
+
+ CurrentGUIDToFuncNameMap.clear();
+
+ // Reset the GUIDToFuncNameMap of each function as it is no longer
+ // valid at this point.
+ SetGUIDToFuncNameMapForAll(nullptr);
+ }
+
+private:
+ void SetGUIDToFuncNameMapForAll(DenseMap<uint64_t, StringRef> *Map) {
+ std::queue<FunctionSamples *> FSToUpdate;
+ for (auto &IFS : CurrentReader.getProfiles()) {
+ FSToUpdate.push(&IFS.second);
+ }
+
+ while (!FSToUpdate.empty()) {
+ FunctionSamples *FS = FSToUpdate.front();
+ FSToUpdate.pop();
+ FS->GUIDToFuncNameMap = Map;
+ for (const auto &ICS : FS->getCallsiteSamples()) {
+ const FunctionSamplesMap &FSMap = ICS.second;
+ for (auto &IFS : FSMap) {
+ FunctionSamples &FS = const_cast<FunctionSamples &>(IFS.second);
+ FSToUpdate.push(&FS);
+ }
+ }
+ }
+ }
+
+ SampleProfileReader &CurrentReader;
+ Module &CurrentModule;
+ DenseMap<uint64_t, StringRef> &CurrentGUIDToFuncNameMap;
};
/// Sample profile pass.
@@ -199,8 +279,9 @@ public:
std::function<AssumptionCache &(Function &)> GetAssumptionCache,
std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo)
: GetAC(std::move(GetAssumptionCache)),
- GetTTI(std::move(GetTargetTransformInfo)), Filename(Name),
- RemappingFilename(RemapName), IsThinLTOPreLink(IsThinLTOPreLink) {}
+ GetTTI(std::move(GetTargetTransformInfo)), CoverageTracker(*this),
+ Filename(Name), RemappingFilename(RemapName),
+ IsThinLTOPreLink(IsThinLTOPreLink) {}
bool doInitialization(Module &M);
bool runOnModule(Module &M, ModuleAnalysisManager *AM,
@@ -209,6 +290,8 @@ public:
void dump() { Reader->dump(); }
protected:
+ friend class SampleCoverageTracker;
+
bool runOnFunction(Function &F, ModuleAnalysisManager *AM);
unsigned getFunctionLoc(Function &F);
bool emitAnnotations(Function &F);
@@ -237,6 +320,8 @@ protected:
bool propagateThroughEdges(Function &F, bool UpdateBlockCount);
void computeDominanceAndLoopInfo(Function &F);
void clearFunctionData();
+ bool callsiteIsHot(const FunctionSamples *CallsiteFS,
+ ProfileSummaryInfo *PSI);
/// Map basic blocks to their computed weights.
///
@@ -310,6 +395,10 @@ protected:
/// Profile Summary Info computed from sample profile.
ProfileSummaryInfo *PSI = nullptr;
+ /// Profile symbol list tells whether a function name appears in the binary
+ /// used to generate the current profile.
+ std::unique_ptr<ProfileSymbolList> PSL;
+
/// Total number of samples collected in this profile.
///
/// This is the sum of all the samples collected in all the functions executed
@@ -326,6 +415,21 @@ protected:
uint64_t entryCount;
};
DenseMap<Function *, NotInlinedProfileInfo> notInlinedCallInfo;
+
+ // GUIDToFuncNameMap saves the mapping from GUID to the symbol name, for
+ // all the function symbols defined or declared in current module.
+ DenseMap<uint64_t, StringRef> GUIDToFuncNameMap;
+
+ // All the Names used in FunctionSamples including outline function
+ // names, inline instance names and call target names.
+ StringSet<> NamesInProfile;
+
+ // For symbols in the profile symbol list, whether to regard their profiles
+ // as accurate. It is mainly decided by the existence of the profile symbol
+ // list and the -profile-accurate-for-symsinlist flag, but it can be
+ // overridden by -profile-sample-accurate or the profile-sample-accurate
+ // attribute.
+ bool ProfAccForSymsInList;
};
class SampleProfileLoaderLegacyPass : public ModulePass {
@@ -381,14 +485,23 @@ private:
/// To decide whether an inlined callsite is hot, we compare the callsite
/// sample count with the hot cutoff computed by ProfileSummaryInfo, it is
/// regarded as hot if the count is above the cutoff value.
-static bool callsiteIsHot(const FunctionSamples *CallsiteFS,
- ProfileSummaryInfo *PSI) {
+///
+/// When ProfileAccurateForSymsInList is enabled and a profile symbol list is
+/// present, functions in the profile symbol list but without a profile are
+/// regarded as cold, so much less inlining will happen in the CGSCC inlining
+/// pass. We therefore lower the hotness criterion here to allow more early
+/// inlining of warm callsites, which is helpful for performance.
+bool SampleProfileLoader::callsiteIsHot(const FunctionSamples *CallsiteFS,
+ ProfileSummaryInfo *PSI) {
if (!CallsiteFS)
return false; // The callsite was not inlined in the original binary.
assert(PSI && "PSI is expected to be non null");
uint64_t CallsiteTotalSamples = CallsiteFS->getTotalSamples();
- return PSI->isHotCount(CallsiteTotalSamples);
+ if (ProfAccForSymsInList)
+ return !PSI->isColdCount(CallsiteTotalSamples);
+ else
+ return PSI->isHotCount(CallsiteTotalSamples);
}
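The relaxed criterion above only requires a callsite to be "not cold" when the profile symbol list is in effect. A toy restatement with assumed cutoff values (the real thresholds come from ProfileSummaryInfo, not these numbers):

  #include <cstdint>

  // Assumed example cutoffs, purely for illustration.
  constexpr uint64_t HotCountCutoff = 1000;
  constexpr uint64_t ColdCountCutoff = 10;

  bool callsiteIsHot(uint64_t TotalSamples, bool ProfAccForSymsInList) {
    if (ProfAccForSymsInList)
      return TotalSamples > ColdCountCutoff;  // "not cold" is good enough
    return TotalSamples >= HotCountCutoff;    // strict hotness otherwise
  }

Warm callsites, with counts between the two cutoffs, become early-inlining candidates only under the relaxed rule.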
/// Mark as used the sample record for the given function samples at
@@ -425,7 +538,7 @@ SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS,
for (const auto &I : FS->getCallsiteSamples())
for (const auto &J : I.second) {
const FunctionSamples *CalleeSamples = &J.second;
- if (callsiteIsHot(CalleeSamples, PSI))
+ if (SPLoader.callsiteIsHot(CalleeSamples, PSI))
Count += countUsedRecords(CalleeSamples, PSI);
}
@@ -444,7 +557,7 @@ SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS,
for (const auto &I : FS->getCallsiteSamples())
for (const auto &J : I.second) {
const FunctionSamples *CalleeSamples = &J.second;
- if (callsiteIsHot(CalleeSamples, PSI))
+ if (SPLoader.callsiteIsHot(CalleeSamples, PSI))
Count += countBodyRecords(CalleeSamples, PSI);
}
@@ -465,7 +578,7 @@ SampleCoverageTracker::countBodySamples(const FunctionSamples *FS,
for (const auto &I : FS->getCallsiteSamples())
for (const auto &J : I.second) {
const FunctionSamples *CalleeSamples = &J.second;
- if (callsiteIsHot(CalleeSamples, PSI))
+ if (SPLoader.callsiteIsHot(CalleeSamples, PSI))
Total += countBodySamples(CalleeSamples, PSI);
}
@@ -788,6 +901,14 @@ bool SampleProfileLoader::inlineHotFunctions(
Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs) {
DenseSet<Instruction *> PromotedInsns;
+ // ProfAccForSymsInList is used in callsiteIsHot. The assertion makes sure
+ // the profile symbol list is ignored when profile-sample-accurate is on.
+ assert((!ProfAccForSymsInList ||
+ (!ProfileSampleAccurate &&
+ !F.hasFnAttribute("profile-sample-accurate"))) &&
+ "ProfAccForSymsInList should be false when profile-sample-accurate "
+ "is enabled");
+
DenseMap<Instruction *, const FunctionSamples *> localNotInlinedCallSites;
bool Changed = false;
while (true) {
@@ -1219,17 +1340,12 @@ void SampleProfileLoader::buildEdges(Function &F) {
}
/// Returns the sorted CallTargetMap \p M by count in descending order.
-static SmallVector<InstrProfValueData, 2> SortCallTargets(
- const SampleRecord::CallTargetMap &M) {
+static SmallVector<InstrProfValueData, 2> GetSortedValueDataFromCallTargets(
+ const SampleRecord::CallTargetMap & M) {
SmallVector<InstrProfValueData, 2> R;
- for (auto I = M.begin(); I != M.end(); ++I)
- R.push_back({FunctionSamples::getGUID(I->getKey()), I->getValue()});
- llvm::sort(R, [](const InstrProfValueData &L, const InstrProfValueData &R) {
- if (L.Count == R.Count)
- return L.Value > R.Value;
- else
- return L.Count > R.Count;
- });
+ for (const auto &I : SampleRecord::SortCallTargets(M)) {
+ R.emplace_back(InstrProfValueData{FunctionSamples::getGUID(I.first), I.second});
+ }
return R;
}
@@ -1324,7 +1440,7 @@ void SampleProfileLoader::propagateWeights(Function &F) {
if (!T || T.get().empty())
continue;
SmallVector<InstrProfValueData, 2> SortedCallTargets =
- SortCallTargets(T.get());
+ GetSortedValueDataFromCallTargets(T.get());
uint64_t Sum;
findIndirectCallFunctionSamples(I, Sum);
annotateValueSite(*I.getParent()->getParent()->getParent(), I,
@@ -1374,6 +1490,8 @@ void SampleProfileLoader::propagateWeights(Function &F) {
}
}
+ misexpect::verifyMisExpect(TI, Weights, TI->getContext());
+
uint64_t TempWeight;
// Only set weights if there is at least one non-zero weight.
// In any other case, let the analyzer set weights.
@@ -1557,30 +1675,29 @@ INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile",
bool SampleProfileLoader::doInitialization(Module &M) {
auto &Ctx = M.getContext();
- auto ReaderOrErr = SampleProfileReader::create(Filename, Ctx);
+
+ std::unique_ptr<SampleProfileReaderItaniumRemapper> RemapReader;
+ auto ReaderOrErr =
+ SampleProfileReader::create(Filename, Ctx, RemappingFilename);
if (std::error_code EC = ReaderOrErr.getError()) {
std::string Msg = "Could not open profile: " + EC.message();
Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
return false;
}
Reader = std::move(ReaderOrErr.get());
- Reader->collectFuncsToUse(M);
+ Reader->collectFuncsFrom(M);
ProfileIsValid = (Reader->read() == sampleprof_error::success);
-
- if (!RemappingFilename.empty()) {
- // Apply profile remappings to the loaded profile data if requested.
- // For now, we only support remapping symbols encoded using the Itanium
- // C++ ABI's name mangling scheme.
- ReaderOrErr = SampleProfileReaderItaniumRemapper::create(
- RemappingFilename, Ctx, std::move(Reader));
- if (std::error_code EC = ReaderOrErr.getError()) {
- std::string Msg = "Could not open profile remapping file: " + EC.message();
- Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
- return false;
- }
- Reader = std::move(ReaderOrErr.get());
- ProfileIsValid = (Reader->read() == sampleprof_error::success);
+ PSL = Reader->getProfileSymbolList();
+
+ // When profile-sample-accurate is on, ignore the symbol list.
+ ProfAccForSymsInList =
+ ProfileAccurateForSymsInList && PSL && !ProfileSampleAccurate;
+ if (ProfAccForSymsInList) {
+ NamesInProfile.clear();
+ if (auto NameTable = Reader->getNameTable())
+ NamesInProfile.insert(NameTable->begin(), NameTable->end());
}
+
return true;
}
@@ -1594,7 +1711,7 @@ ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) {
bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
ProfileSummaryInfo *_PSI) {
- FunctionSamples::GUIDToFuncNameMapper Mapper(M);
+ GUIDToFuncNameMapper Mapper(M, *Reader, GUIDToFuncNameMap);
if (!ProfileIsValid)
return false;
@@ -1651,19 +1768,48 @@ bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) {
}
bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) {
-
+
DILocation2SampleMap.clear();
// By default the entry count is initialized to -1, which will be treated
// conservatively by getEntryCount as the same as unknown (None). This is
// to avoid newly added code to be treated as cold. If we have samples
// this will be overwritten in emitAnnotations.
- // If ProfileSampleAccurate is true or F has profile-sample-accurate
- // attribute, initialize the entry count to 0 so callsites or functions
- // unsampled will be treated as cold.
- uint64_t initialEntryCount =
- (ProfileSampleAccurate || F.hasFnAttribute("profile-sample-accurate"))
- ? 0
- : -1;
+ uint64_t initialEntryCount = -1;
+
+ ProfAccForSymsInList = ProfileAccurateForSymsInList && PSL;
+ if (ProfileSampleAccurate || F.hasFnAttribute("profile-sample-accurate")) {
+ // Initialize all function entry counts to 0. It means all functions
+ // without a profile will be regarded as cold.
+ initialEntryCount = 0;
+ // profile-sample-accurate is a user assertion which has a higher precedence
+ // than symbol list. When profile-sample-accurate is on, ignore symbol list.
+ ProfAccForSymsInList = false;
+ }
+
+ // PSL -- the profile symbol list includes all the symbols in the sampled
+ // binary. If ProfileAccurateForSymsInList is enabled, the PSL is used to
+ // treat old functions without samples as cold, without having to worry
+ // about new and hot functions being mistakenly treated as cold.
+ if (ProfAccForSymsInList) {
+ // Initialize the entry count to 0 for functions in the list.
+ if (PSL->contains(F.getName()))
+ initialEntryCount = 0;
+
+ // A function in the symbol list but without samples will be regarded as
+ // cold. To minimize the potential negative performance impact this could
+ // have, be a little conservative here: if a function shows up in the
+ // profile at all -- as an outline function, an inline instance, or a call
+ // target -- treat it as not cold. This handles cases where most callsites
+ // of a function are inlined in the sampled binary but not in the current
+ // build (because of source code drift, imprecise debug information, or
+ // callsites that are individually cold but not cold in aggregate), so an
+ // outline function that shows up as cold in the sampled binary will
+ // actually not be cold in the current build.
+ StringRef CanonName = FunctionSamples::getCanonicalFnName(F);
+ if (NamesInProfile.count(CanonName))
+ initialEntryCount = -1;
+ }
+
F.setEntryCount(ProfileCount(initialEntryCount, Function::PCT_Real));
std::unique_ptr<OptimizationRemarkEmitter> OwnedORE;
if (AM) {
@@ -1672,7 +1818,7 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM)
.getManager();
ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
} else {
- OwnedORE = make_unique<OptimizationRemarkEmitter>(&F);
+ OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F);
ORE = OwnedORE.get();
}
Samples = Reader->getSamplesFor(F);
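The runOnFunction changes above pick the initial entry count from three inputs with a clear precedence. A condensed restatement of that decision (plain containers stand in for ProfileSymbolList and StringSet):

  #include <cstdint>
  #include <set>
  #include <string>

  // 1. profile-sample-accurate wins and forces "unsampled means cold" (0).
  // 2. Otherwise, a function listed in the profile symbol list defaults to 0,
  // 3. unless it shows up anywhere in the profile, in which case it stays
  //    "unknown" (-1) so it is not mistakenly treated as cold.
  uint64_t initialEntryCount(const std::string &Name, bool SampleAccurate,
                             const std::set<std::string> &SymbolList,
                             const std::set<std::string> &NamesInProfile) {
    if (SampleAccurate)
      return 0;
    uint64_t Count = static_cast<uint64_t>(-1);
    if (SymbolList.count(Name)) {
      Count = 0;
      if (NamesInProfile.count(Name))
        Count = static_cast<uint64_t>(-1);
    }
    return Count;
  }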
diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
index 24c476376c14..690b5e8bf49e 100644
--- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
+++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
@@ -24,6 +24,7 @@
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/IPO/FunctionImport.h"
+#include "llvm/Transforms/IPO/LowerTypeTests.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace llvm;
@@ -218,10 +219,18 @@ void splitAndWriteThinLTOBitcode(
promoteTypeIds(M, ModuleId);
- // Returns whether a global has attached type metadata. Such globals may
- // participate in CFI or whole-program devirtualization, so they need to
- // appear in the merged module instead of the thin LTO module.
+ // Returns whether a global or its associated global has attached type
+ // metadata. The former may participate in CFI or whole-program
+ // devirtualization, so it needs to appear in the merged module instead of
+ // the thin LTO module. Similarly, globals that are associated with globals
+ // with type metadata need to appear in the merged module because they will
+ // reference the global's section directly.
auto HasTypeMetadata = [](const GlobalObject *GO) {
+ if (MDNode *MD = GO->getMetadata(LLVMContext::MD_associated))
+ if (auto *AssocVM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(0)))
+ if (auto *AssocGO = dyn_cast<GlobalObject>(AssocVM->getValue()))
+ if (AssocGO->hasMetadata(LLVMContext::MD_type))
+ return true;
return GO->hasMetadata(LLVMContext::MD_type);
};
@@ -315,9 +324,9 @@ void splitAndWriteThinLTOBitcode(
SmallVector<Metadata *, 4> Elts;
Elts.push_back(MDString::get(Ctx, F.getName()));
CfiFunctionLinkage Linkage;
- if (!F.isDeclarationForLinker())
+ if (lowertypetests::isJumpTableCanonical(&F))
Linkage = CFL_Definition;
- else if (F.isWeakForLinker())
+ else if (F.hasExternalWeakLinkage())
Linkage = CFL_WeakDeclaration;
else
Linkage = CFL_Declaration;
@@ -457,7 +466,7 @@ void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS,
// splitAndWriteThinLTOBitcode). Just always build it once via the
// buildModuleSummaryIndex when Module(s) are ready.
ProfileSummaryInfo PSI(M);
- NewIndex = llvm::make_unique<ModuleSummaryIndex>(
+ NewIndex = std::make_unique<ModuleSummaryIndex>(
buildModuleSummaryIndex(M, nullptr, &PSI));
Index = NewIndex.get();
}
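The updated HasTypeMetadata lambda also looks through !associated metadata to the associated global. A toy sketch of just that decision (illustrative struct, not the IR classes):

  #include <string>

  // A global must go to the merged module if it carries type metadata itself,
  // or if the global it is !associated with does, since it will reference that
  // global's section directly.
  struct GlobalObj {
    std::string Name;
    bool HasTypeMD = false;
    const GlobalObj *Associated = nullptr; // models the !associated operand
  };

  bool needsMergedModule(const GlobalObj &GO) {
    if (GO.Associated && GO.Associated->HasTypeMD)
      return true;
    return GO.HasTypeMD;
  }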
diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 6b6dd6194e17..f0cf5581ba8a 100644
--- a/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -24,12 +24,14 @@
// returns 0, or a single vtable's function returns 1, replace each virtual
// call with a comparison of the vptr against that vtable's address.
//
-// This pass is intended to be used during the regular and thin LTO pipelines.
+// This pass is intended to be used during the regular and thin LTO pipelines:
+//
// During regular LTO, the pass determines the best optimization for each
// virtual call and applies the resolutions directly to virtual calls that are
// eligible for virtual call optimization (i.e. calls that use either of the
-// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During
-// ThinLTO, the pass operates in two phases:
+// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
+//
+// During hybrid Regular/ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over a single merged module
// that contains all vtables with !type metadata that participate in the link.
// The pass computes a resolution for each virtual call and stores it in the
@@ -38,6 +40,14 @@
// modules. The pass applies the resolutions previously computed during the
// import phase to each eligible virtual call.
//
+// During ThinLTO, the pass operates in two phases:
+// - Export phase: this is run during the thin link over the index which
+// contains a summary of all vtables with !type metadata that participate in
+// the link. It computes a resolution for each virtual call and stores it in
+// the type identifier summary. Only single implementation devirtualization
+// is supported.
+// - Import phase: (same as with hybrid case above).
+//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
@@ -117,6 +127,11 @@ static cl::opt<unsigned>
cl::desc("Maximum number of call targets per "
"call site to enable branch funnels"));
+static cl::opt<bool>
+ PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden,
+ cl::init(false), cl::ZeroOrMore,
+ cl::desc("Print index-based devirtualization messages"));
+
// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset before the object, otherwise look for an
// offset after the object.
@@ -265,6 +280,25 @@ template <> struct DenseMapInfo<VTableSlot> {
}
};
+template <> struct DenseMapInfo<VTableSlotSummary> {
+ static VTableSlotSummary getEmptyKey() {
+ return {DenseMapInfo<StringRef>::getEmptyKey(),
+ DenseMapInfo<uint64_t>::getEmptyKey()};
+ }
+ static VTableSlotSummary getTombstoneKey() {
+ return {DenseMapInfo<StringRef>::getTombstoneKey(),
+ DenseMapInfo<uint64_t>::getTombstoneKey()};
+ }
+ static unsigned getHashValue(const VTableSlotSummary &I) {
+ return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^
+ DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
+ }
+ static bool isEqual(const VTableSlotSummary &LHS,
+ const VTableSlotSummary &RHS) {
+ return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
+ }
+};
+
} // end namespace llvm
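The DenseMapInfo<VTableSlotSummary> specialization above keys a map on a (type-id string, byte offset) pair. The same idea expressed with the standard library (a sketch; DenseMap's empty/tombstone keys have no std::unordered_map equivalent):

  #include <cstdint>
  #include <functional>
  #include <string>

  struct SlotKey {
    std::string TypeID;
    uint64_t ByteOffset;
    bool operator==(const SlotKey &O) const {
      return TypeID == O.TypeID && ByteOffset == O.ByteOffset;
    }
  };

  struct SlotKeyHash {
    size_t operator()(const SlotKey &K) const {
      // XOR-combining mirrors the patch; a real combiner would mix harder.
      return std::hash<std::string>()(K.TypeID) ^
             std::hash<uint64_t>()(K.ByteOffset);
    }
  };

  // Usage: std::unordered_map<SlotKey, SlotInfo, SlotKeyHash> CallSlots;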
namespace {
@@ -342,19 +376,21 @@ struct CallSiteInfo {
/// pass the vector is non-empty, we will need to add a use of llvm.type.test
/// to each of the function summaries in the vector.
std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
+ std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;
bool isExported() const {
return SummaryHasTypeTestAssumeUsers ||
!SummaryTypeCheckedLoadUsers.empty();
}
- void markSummaryHasTypeTestAssumeUsers() {
- SummaryHasTypeTestAssumeUsers = true;
+ void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
+ SummaryTypeCheckedLoadUsers.push_back(FS);
AllCallSitesDevirted = false;
}
- void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
- SummaryTypeCheckedLoadUsers.push_back(FS);
+ void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
+ SummaryTypeTestAssumeUsers.push_back(FS);
+ SummaryHasTypeTestAssumeUsers = true;
AllCallSitesDevirted = false;
}
@@ -456,7 +492,6 @@ struct DevirtModule {
void buildTypeIdentifierMap(
std::vector<VTableBits> &Bits,
DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
- Constant *getPointerAtOffset(Constant *I, uint64_t Offset);
bool
tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
const std::set<TypeMemberInfo> &TypeMemberInfos,
@@ -464,7 +499,8 @@ struct DevirtModule {
void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
bool &IsExported);
- bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
+ bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary,
+ MutableArrayRef<VirtualCallTarget> TargetsForSlot,
VTableSlotInfo &SlotInfo,
WholeProgramDevirtResolution *Res);
@@ -542,6 +578,38 @@ struct DevirtModule {
function_ref<DominatorTree &(Function &)> LookupDomTree);
};
+struct DevirtIndex {
+ ModuleSummaryIndex &ExportSummary;
+ // The set in which to record GUIDs exported from their module by
+ // devirtualization, used by client to ensure they are not internalized.
+ std::set<GlobalValue::GUID> &ExportedGUIDs;
+ // A map in which to record the information necessary to locate the WPD
+ // resolution for local targets in case they are exported by cross module
+ // importing.
+ std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;
+
+ MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;
+
+ DevirtIndex(
+ ModuleSummaryIndex &ExportSummary,
+ std::set<GlobalValue::GUID> &ExportedGUIDs,
+ std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
+ : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
+ LocalWPDTargetsMap(LocalWPDTargetsMap) {}
+
+ bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
+ const TypeIdCompatibleVtableInfo TIdInfo,
+ uint64_t ByteOffset);
+
+ bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
+ VTableSlotSummary &SlotSummary,
+ VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res,
+ std::set<ValueInfo> &DevirtTargets);
+
+ void run();
+};
+
struct WholeProgramDevirt : public ModulePass {
static char ID;
@@ -572,7 +640,7 @@ struct WholeProgramDevirt : public ModulePass {
// an optimization remark emitter on the fly, when we need it.
std::unique_ptr<OptimizationRemarkEmitter> ORE;
auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
- ORE = make_unique<OptimizationRemarkEmitter>(F);
+ ORE = std::make_unique<OptimizationRemarkEmitter>(F);
return *ORE;
};
@@ -632,6 +700,41 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
return PreservedAnalyses::none();
}
+namespace llvm {
+void runWholeProgramDevirtOnIndex(
+ ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
+ std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
+ DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
+}
+
+void updateIndexWPDForExports(
+ ModuleSummaryIndex &Summary,
+ function_ref<bool(StringRef, GlobalValue::GUID)> isExported,
+ std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
+ for (auto &T : LocalWPDTargetsMap) {
+ auto &VI = T.first;
+ // This was enforced earlier during trySingleImplDevirt.
+ assert(VI.getSummaryList().size() == 1 &&
+ "Devirt of local target has more than one copy");
+ auto &S = VI.getSummaryList()[0];
+ if (!isExported(S->modulePath(), VI.getGUID()))
+ continue;
+
+ // It's been exported by a cross module import.
+ for (auto &SlotSummary : T.second) {
+ auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID);
+ assert(TIdSum);
+ auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset);
+ assert(WPDRes != TIdSum->WPDRes.end());
+ WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
+ WPDRes->second.SingleImplName,
+ Summary.getModuleHash(S->modulePath()));
+ }
+ }
+}
+
+} // end namespace llvm
+
bool DevirtModule::runForTesting(
Module &M, function_ref<AAResults &(Function &)> AARGetter,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
@@ -662,7 +765,7 @@ bool DevirtModule::runForTesting(
ExitOnError ExitOnErr(
"-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
std::error_code EC;
- raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
+ raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text);
ExitOnErr(errorCodeToError(EC));
yaml::Output Out(OS);
@@ -706,38 +809,6 @@ void DevirtModule::buildTypeIdentifierMap(
}
}
-Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) {
- if (I->getType()->isPointerTy()) {
- if (Offset == 0)
- return I;
- return nullptr;
- }
-
- const DataLayout &DL = M.getDataLayout();
-
- if (auto *C = dyn_cast<ConstantStruct>(I)) {
- const StructLayout *SL = DL.getStructLayout(C->getType());
- if (Offset >= SL->getSizeInBytes())
- return nullptr;
-
- unsigned Op = SL->getElementContainingOffset(Offset);
- return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
- Offset - SL->getElementOffset(Op));
- }
- if (auto *C = dyn_cast<ConstantArray>(I)) {
- ArrayType *VTableTy = C->getType();
- uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
-
- unsigned Op = Offset / ElemSize;
- if (Op >= C->getNumOperands())
- return nullptr;
-
- return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
- Offset % ElemSize);
- }
- return nullptr;
-}
-
bool DevirtModule::tryFindVirtualCallTargets(
std::vector<VirtualCallTarget> &TargetsForSlot,
const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
@@ -746,7 +817,7 @@ bool DevirtModule::tryFindVirtualCallTargets(
return false;
Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
- TM.Offset + ByteOffset);
+ TM.Offset + ByteOffset, M);
if (!Ptr)
return false;
@@ -766,6 +837,34 @@ bool DevirtModule::tryFindVirtualCallTargets(
return !TargetsForSlot.empty();
}
+bool DevirtIndex::tryFindVirtualCallTargets(
+ std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo,
+ uint64_t ByteOffset) {
+ for (const TypeIdOffsetVtableInfo P : TIdInfo) {
+ // VTable initializer should have only one summary, or all copies must be
+ // linkonce/weak ODR.
+ assert(P.VTableVI.getSummaryList().size() == 1 ||
+ llvm::all_of(
+ P.VTableVI.getSummaryList(),
+ [&](const std::unique_ptr<GlobalValueSummary> &Summary) {
+ return GlobalValue::isLinkOnceODRLinkage(Summary->linkage()) ||
+ GlobalValue::isWeakODRLinkage(Summary->linkage());
+ }));
+ const auto *VS = cast<GlobalVarSummary>(P.VTableVI.getSummaryList()[0].get());
+ if (!P.VTableVI.getSummaryList()[0]->isLive())
+ continue;
+ for (auto VTP : VS->vTableFuncs()) {
+ if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
+ continue;
+
+ TargetsForSlot.push_back(VTP.FuncVI);
+ }
+ }
+
+ // Give up if we couldn't find any targets.
+ return !TargetsForSlot.empty();
+}
+
void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
Constant *TheFn, bool &IsExported) {
auto Apply = [&](CallSiteInfo &CSInfo) {
@@ -788,9 +887,38 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
Apply(P.second);
}
+static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
+ // We can't add calls if we haven't seen a definition
+ if (Callee.getSummaryList().empty())
+ return false;
+
+ // Insert calls into the summary index so that the devirtualized targets
+ // are eligible for import.
+ // FIXME: Annotate type tests with hotness. For now, mark these as hot
+ // to better ensure we have the opportunity to inline them.
+ bool IsExported = false;
+ auto &S = Callee.getSummaryList()[0];
+ CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
+ auto AddCalls = [&](CallSiteInfo &CSInfo) {
+ for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
+ FS->addCall({Callee, CI});
+ IsExported |= S->modulePath() != FS->modulePath();
+ }
+ for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
+ FS->addCall({Callee, CI});
+ IsExported |= S->modulePath() != FS->modulePath();
+ }
+ };
+ AddCalls(SlotInfo.CSInfo);
+ for (auto &P : SlotInfo.ConstCSInfo)
+ AddCalls(P.second);
+ return IsExported;
+}
+
bool DevirtModule::trySingleImplDevirt(
- MutableArrayRef<VirtualCallTarget> TargetsForSlot,
- VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) {
+ ModuleSummaryIndex *ExportSummary,
+ MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res) {
// See if the program contains a single implementation of this virtual
// function.
Function *TheFn = TargetsForSlot[0].Fn;
@@ -830,6 +958,10 @@ bool DevirtModule::trySingleImplDevirt(
TheFn->setVisibility(GlobalValue::HiddenVisibility);
TheFn->setName(NewName);
}
+ if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID()))
+ // Any needed promotion of 'TheFn' has already been done during
+ // LTO unit split, so we can ignore return value of AddCalls.
+ AddCalls(SlotInfo, TheFnVI);
Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
Res->SingleImplName = TheFn->getName();
@@ -837,6 +969,63 @@ bool DevirtModule::trySingleImplDevirt(
return true;
}
+bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
+ VTableSlotSummary &SlotSummary,
+ VTableSlotInfo &SlotInfo,
+ WholeProgramDevirtResolution *Res,
+ std::set<ValueInfo> &DevirtTargets) {
+ // See if the program contains a single implementation of this virtual
+ // function.
+ auto TheFn = TargetsForSlot[0];
+ for (auto &&Target : TargetsForSlot)
+ if (TheFn != Target)
+ return false;
+
+ // Don't devirtualize if we don't have target definition.
+ auto Size = TheFn.getSummaryList().size();
+ if (!Size)
+ return false;
+
+ // If the summary list contains multiple summaries where at least one is
+ // a local, give up, as we won't know which (possibly promoted) name to use.
+ for (auto &S : TheFn.getSummaryList())
+ if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
+ return false;
+
+ // Collect functions devirtualized at least for one call site for stats.
+ if (PrintSummaryDevirt)
+ DevirtTargets.insert(TheFn);
+
+ auto &S = TheFn.getSummaryList()[0];
+ bool IsExported = AddCalls(SlotInfo, TheFn);
+ if (IsExported)
+ ExportedGUIDs.insert(TheFn.getGUID());
+
+ // Record in summary for use in devirtualization during the ThinLTO import
+ // step.
+ Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
+ if (GlobalValue::isLocalLinkage(S->linkage())) {
+ if (IsExported)
+ // If target is a local function and we are exporting it by
+ // devirtualizing a call in another module, we need to record the
+ // promoted name.
+ Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
+ TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
+ else {
+ LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
+ Res->SingleImplName = TheFn.name();
+ }
+ } else
+ Res->SingleImplName = TheFn.name();
+
+ // The name will be empty if this thin link is driven off of a serialized
+ // combined index (e.g. llvm-lto). However, WPD is not supported/invoked
+ // for the legacy LTO API anyway.
+ assert(!Res->SingleImplName.empty());
+
+ return true;
+}
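Stripped of the summary plumbing, DevirtIndex::trySingleImplDevirt reduces to checking that every candidate resolves to one target that has a definition. A minimal sketch of that core check (plain strings stand in for ValueInfo):

  #include <optional>
  #include <string>
  #include <vector>

  // Devirtualization is only possible when every compatible vtable resolves
  // the slot to the same target and that target has a definition.
  std::optional<std::string>
  findSingleImpl(const std::vector<std::string> &TargetsForSlot,
                 bool TargetHasDefinition) {
    if (TargetsForSlot.empty() || !TargetHasDefinition)
      return std::nullopt;
    const std::string &TheFn = TargetsForSlot.front();
    for (const std::string &T : TargetsForSlot)
      if (T != TheFn)
        return std::nullopt; // more than one implementation: give up
    return TheFn;            // safe to call TheFn directly
  }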
+
void DevirtModule::tryICallBranchFunnel(
MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
WholeProgramDevirtResolution *Res, VTableSlot Slot) {
@@ -1302,10 +1491,13 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
if (B.Before.Bytes.empty() && B.After.Bytes.empty())
return;
- // Align each byte array to pointer width.
- unsigned PointerSize = M.getDataLayout().getPointerSize();
- B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
- B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));
+ // Align the before byte array to the global's minimum alignment so that we
+ // don't break any alignment requirements on the global.
+ MaybeAlign Alignment(B.GV->getAlignment());
+ if (!Alignment)
+ Alignment =
+ Align(M.getDataLayout().getABITypeAlignment(B.GV->getValueType()));
+ B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment));
// Before was stored in reverse order; flip it now.
for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
@@ -1322,6 +1514,7 @@ void DevirtModule::rebuildGlobal(VTableBits &B) {
GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
NewGV->setSection(B.GV->getSection());
NewGV->setComdat(B.GV->getComdat());
+ NewGV->setAlignment(MaybeAlign(B.GV->getAlignment()));
// Copy the original vtable's metadata to the anonymous global, adjusting
// offsets as required.
@@ -1483,8 +1676,11 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
}
void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
+ auto *TypeId = dyn_cast<MDString>(Slot.TypeID);
+ if (!TypeId)
+ return;
const TypeIdSummary *TidSummary =
- ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString());
+ ImportSummary->getTypeIdSummary(TypeId->getString());
if (!TidSummary)
return;
auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
@@ -1493,6 +1689,7 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
const WholeProgramDevirtResolution &Res = ResI->second;
if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
+ assert(!Res.SingleImplName.empty());
// The type of the function in the declaration is irrelevant because every
// call site will cast it to the correct type.
Constant *SingleImpl =
@@ -1627,8 +1824,7 @@ bool DevirtModule::run() {
// FIXME: Only add live functions.
for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
for (Metadata *MD : MetadataByGUID[VF.GUID]) {
- CallSlots[{MD, VF.Offset}]
- .CSInfo.markSummaryHasTypeTestAssumeUsers();
+ CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
}
}
for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
@@ -1641,7 +1837,7 @@ bool DevirtModule::run() {
for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
CallSlots[{MD, VC.VFunc.Offset}]
.ConstCSInfo[VC.Args]
- .markSummaryHasTypeTestAssumeUsers();
+ .addSummaryTypeTestAssumeUser(FS);
}
}
for (const FunctionSummary::ConstVCall &VC :
@@ -1673,7 +1869,7 @@ bool DevirtModule::run() {
cast<MDString>(S.first.TypeID)->getString())
.WPDRes[S.first.ByteOffset];
- if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) {
+ if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
DidVirtualConstProp |=
tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
@@ -1710,7 +1906,7 @@ bool DevirtModule::run() {
using namespace ore;
OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
<< "devirtualized "
- << NV("FunctionName", F->getName()));
+ << NV("FunctionName", DT.first));
}
}
@@ -1722,5 +1918,86 @@ bool DevirtModule::run() {
for (VTableBits &B : Bits)
rebuildGlobal(B);
+ // We have lowered or deleted the type checked load intrinsics, so we no
+ // longer have enough information to reason about the liveness of virtual
+ // function pointers in GlobalDCE.
+ for (GlobalVariable &GV : M.globals())
+ GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
+
return true;
}
+
+void DevirtIndex::run() {
+ if (ExportSummary.typeIdCompatibleVtableMap().empty())
+ return;
+
+ DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
+ for (auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
+ NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
+ }
+
+ // Collect information from summary about which calls to try to devirtualize.
+ for (auto &P : ExportSummary) {
+ for (auto &S : P.second.SummaryList) {
+ auto *FS = dyn_cast<FunctionSummary>(S.get());
+ if (!FS)
+ continue;
+ // FIXME: Only add live functions.
+ for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
+ for (StringRef Name : NameByGUID[VF.GUID]) {
+ CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
+ }
+ }
+ for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
+ for (StringRef Name : NameByGUID[VF.GUID]) {
+ CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
+ }
+ }
+ for (const FunctionSummary::ConstVCall &VC :
+ FS->type_test_assume_const_vcalls()) {
+ for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
+ CallSlots[{Name, VC.VFunc.Offset}]
+ .ConstCSInfo[VC.Args]
+ .addSummaryTypeTestAssumeUser(FS);
+ }
+ }
+ for (const FunctionSummary::ConstVCall &VC :
+ FS->type_checked_load_const_vcalls()) {
+ for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
+ CallSlots[{Name, VC.VFunc.Offset}]
+ .ConstCSInfo[VC.Args]
+ .addSummaryTypeCheckedLoadUser(FS);
+ }
+ }
+ }
+ }
+
+ std::set<ValueInfo> DevirtTargets;
+ // For each (type, offset) pair:
+ for (auto &S : CallSlots) {
+ // Search each of the members of the type identifier for the virtual
+ // function implementation at offset S.first.ByteOffset, and add to
+ // TargetsForSlot.
+ std::vector<ValueInfo> TargetsForSlot;
+ auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
+ assert(TidSummary);
+ if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
+ S.first.ByteOffset)) {
+ WholeProgramDevirtResolution *Res =
+ &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID)
+ .WPDRes[S.first.ByteOffset];
+
+ if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
+ DevirtTargets))
+ continue;
+ }
+ }
+
+ // Optionally have the thin link print message for each devirtualized
+ // function.
+ if (PrintSummaryDevirt)
+ for (const auto &DT : DevirtTargets)
+ errs() << "Devirtualized call to " << DT << "\n";
+
+ return;
+}
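DevirtIndex::run first buckets every summarized virtual call by its (type identifier, byte offset) slot and then resolves each bucket independently. A skeleton of that grouping step (illustrative types):

  #include <cstdint>
  #include <map>
  #include <string>
  #include <utility>
  #include <vector>

  using Slot = std::pair<std::string, uint64_t>; // (type id, byte offset)

  // Every virtual call recorded in a function summary lands in the bucket for
  // its slot; each bucket is then a candidate for single-impl devirtualization.
  std::map<Slot, std::vector<std::string>>
  collectCallSlots(const std::vector<std::pair<std::string, Slot>> &VCalls) {
    std::map<Slot, std::vector<std::string>> CallSlots;
    for (const auto &Call : VCalls) // Call = (caller name, slot)
      CallSlots[Call.second].push_back(Call.first);
    return CallSlots;
  }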
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index ba15b023f2a3..8bc34825f8a7 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1097,6 +1097,107 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
return nullptr;
}
+Instruction *
+InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
+ BinaryOperator &I) {
+ assert((I.getOpcode() == Instruction::Add ||
+ I.getOpcode() == Instruction::Or ||
+ I.getOpcode() == Instruction::Sub) &&
+ "Expecting add/or/sub instruction");
+
+ // We have a subtraction/addition between a (potentially truncated) *logical*
+ // right-shift of X and a "select".
+ Value *X, *Select;
+ Instruction *LowBitsToSkip, *Extract;
+ if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd(
+ m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)),
+ m_Instruction(Extract))),
+ m_Value(Select))))
+ return nullptr;
+
+ // `add`/`or` is commutative; but for `sub`, "select" *must* be on RHS.
+ if (I.getOpcode() == Instruction::Sub && I.getOperand(1) != Select)
+ return nullptr;
+
+ Type *XTy = X->getType();
+ bool HadTrunc = I.getType() != XTy;
+
+ // If there was a truncation of extracted value, then we'll need to produce
+ // one extra instruction, so we need to ensure one instruction will go away.
+ if (HadTrunc && !match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+ return nullptr;
+
+ // Extraction should extract high NBits bits, with shift amount calculated as:
+ // low bits to skip = shift bitwidth - high bits to extract
+ // The shift amount itself may be extended, and we need to look past the
+ // zero-ext when matching NBits; that will matter for matching later.
+ Constant *C;
+ Value *NBits;
+ if (!match(
+ LowBitsToSkip,
+ m_ZExtOrSelf(m_Sub(m_Constant(C), m_ZExtOrSelf(m_Value(NBits))))) ||
+ !match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(C->getType()->getScalarSizeInBits(),
+ X->getType()->getScalarSizeInBits()))))
+ return nullptr;
+
+ // Sign-extending value can be zero-extended if we `sub`tract it,
+ // or sign-extended otherwise.
+ auto SkipExtInMagic = [&I](Value *&V) {
+ if (I.getOpcode() == Instruction::Sub)
+ match(V, m_ZExtOrSelf(m_Value(V)));
+ else
+ match(V, m_SExtOrSelf(m_Value(V)));
+ };
+
+ // Now, finally validate the sign-extending magic.
+ // `select` itself may be appropriately extended, look past that.
+ SkipExtInMagic(Select);
+
+ ICmpInst::Predicate Pred;
+ const APInt *Thr;
+ Value *SignExtendingValue, *Zero;
+ bool ShouldSignext;
+ // It must be a select between two values we will later establish to be a
+ // sign-extending value and a zero constant. The condition guarding the
+ // sign-extension must be based on a sign bit of the same X we had in `lshr`.
+ if (!match(Select, m_Select(m_ICmp(Pred, m_Specific(X), m_APInt(Thr)),
+ m_Value(SignExtendingValue), m_Value(Zero))) ||
+ !isSignBitCheck(Pred, *Thr, ShouldSignext))
+ return nullptr;
+
+ // icmp-select pair is commutative.
+ if (!ShouldSignext)
+ std::swap(SignExtendingValue, Zero);
+
+ // If we should not perform sign-extension then we must add/or/subtract zero.
+ if (!match(Zero, m_Zero()))
+ return nullptr;
+ // Otherwise, it should be some constant, left-shifted by the same NBits we
+ // had in `lshr`. Said left-shift can also be appropriately extended.
+ // Again, we must look past zero-ext when looking for NBits.
+ SkipExtInMagic(SignExtendingValue);
+ Constant *SignExtendingValueBaseConstant;
+ if (!match(SignExtendingValue,
+ m_Shl(m_Constant(SignExtendingValueBaseConstant),
+ m_ZExtOrSelf(m_Specific(NBits)))))
+ return nullptr;
+ // If we `sub`, then the constant should be one, else it should be all-ones.
+ if (I.getOpcode() == Instruction::Sub
+ ? !match(SignExtendingValueBaseConstant, m_One())
+ : !match(SignExtendingValueBaseConstant, m_AllOnes()))
+ return nullptr;
+
+ auto *NewAShr = BinaryOperator::CreateAShr(X, LowBitsToSkip,
+ Extract->getName() + ".sext");
+ NewAShr->copyIRFlags(Extract); // Preserve `exact`-ness.
+ if (!HadTrunc)
+ return NewAShr;
+
+ Builder.Insert(NewAShr);
+ return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType());
+}
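The fold above rests on an arithmetic identity: a logical high-bit extract plus a conditional all-ones mask in the high bits is the same as an arithmetic shift. A quick standalone check of that identity on a few 32-bit values (plain C++, not IR; assumes the usual arithmetic >> on signed values):

  #include <cassert>
  #include <cstdint>
  #include <cstdio>
  #include <initializer_list>

  //   (X >>u (BW - N)) + (X < 0 ? (-1 << N) : 0)  ==  X >>s (BW - N)
  int main() {
    const unsigned BW = 32;
    for (int64_t X : {int64_t(0x7f345678), int64_t(-0x1234567), int64_t(0),
                      int64_t(-1), int64_t(0x12345678)}) {
      for (unsigned N = 1; N < BW; ++N) {
        uint32_t UX = static_cast<uint32_t>(X);
        uint32_t Logical = UX >> (BW - N);                     // high-bit extract
        uint32_t Mask = (X < 0) ? (~uint32_t(0) << N) : 0;     // sign-extend magic
        int32_t Arithmetic = static_cast<int32_t>(UX) >> (BW - N);
        assert(static_cast<int32_t>(Logical + Mask) == Arithmetic);
      }
    }
    std::puts("high-bit extract identity holds");
    return 0;
  }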
+
Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -1302,12 +1403,32 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Instruction *V = canonicalizeLowbitMask(I, Builder))
return V;
+ if (Instruction *V =
+ canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+ return V;
+
if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I))
return SatAdd;
return Changed ? &I : nullptr;
}
+/// Eliminate an op from a linear interpolation (lerp) pattern.
+static Instruction *factorizeLerp(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *X, *Y, *Z;
+ if (!match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_Value(Y),
+ m_OneUse(m_FSub(m_FPOne(),
+ m_Value(Z))))),
+ m_OneUse(m_c_FMul(m_Value(X), m_Deferred(Z))))))
+ return nullptr;
+
+ // (Y * (1.0 - Z)) + (X * Z) --> Y + Z * (X - Y) [8 commuted variants]
+ Value *XY = Builder.CreateFSubFMF(X, Y, &I);
+ Value *MulZ = Builder.CreateFMulFMF(Z, XY, &I);
+ return BinaryOperator::CreateFAddFMF(Y, MulZ, &I);
+}
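factorizeLerp relies on the algebraic rewrite Y*(1 - Z) + X*Z == Y + Z*(X - Y), which drops one multiply; it is exact over the reals but can differ in the last floating-point bits, hence the reassociation fast-math requirement. A small standalone demonstration:

  #include <cmath>
  #include <cstdio>

  int main() {
    double X = 7.25, Y = -3.5;
    for (double Z = 0.0; Z <= 1.0; Z += 0.125) {
      double Original = Y * (1.0 - Z) + X * Z; // 2 mul, 1 sub, 1 add
      double Factored = Y + Z * (X - Y);       // 1 mul, 1 sub, 1 add
      std::printf("Z=%.3f  lerp=%.6f  factored=%.6f  |diff|=%g\n",
                  Z, Original, Factored, std::fabs(Original - Factored));
    }
    return 0;
  }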
+
/// Factor a common operand out of fadd/fsub of fmul/fdiv.
static Instruction *factorizeFAddFSub(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
@@ -1315,6 +1436,10 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I,
I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub");
assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
"FP factorization requires FMF");
+
+ if (Instruction *Lerp = factorizeLerp(I, Builder))
+ return Lerp;
+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Value *X, *Y, *Z;
bool IsFMul;
@@ -1362,17 +1487,32 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I))
return FoldedFAdd;
- Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
- Value *X;
// (-X) + Y --> Y - X
- if (match(LHS, m_FNeg(m_Value(X))))
- return BinaryOperator::CreateFSubFMF(RHS, X, &I);
- // Y + (-X) --> Y - X
- if (match(RHS, m_FNeg(m_Value(X))))
- return BinaryOperator::CreateFSubFMF(LHS, X, &I);
+ Value *X, *Y;
+ if (match(&I, m_c_FAdd(m_FNeg(m_Value(X)), m_Value(Y))))
+ return BinaryOperator::CreateFSubFMF(Y, X, &I);
+
+ // Similar to above, but look through fmul/fdiv for the negated term.
+ // (-X * Y) + Z --> Z - (X * Y) [4 commuted variants]
+ Value *Z;
+ if (match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))),
+ m_Value(Z)))) {
+ Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+ return BinaryOperator::CreateFSubFMF(Z, XY, &I);
+ }
+ // (-X / Y) + Z --> Z - (X / Y) [2 commuted variants]
+ // (X / -Y) + Z --> Z - (X / Y) [2 commuted variants]
+ if (match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y))),
+ m_Value(Z))) ||
+ match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))),
+ m_Value(Z)))) {
+ Value *XY = Builder.CreateFDivFMF(X, Y, &I);
+ return BinaryOperator::CreateFSubFMF(Z, XY, &I);
+ }
// Check for (fadd double (sitofp x), y), see if we can merge this into an
// integer add followed by a promotion.
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
Value *LHSIntVal = LHSConv->getOperand(0);
Type *FPType = LHSConv->getType();
@@ -1631,37 +1771,50 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
const APInt *Op0C;
if (match(Op0, m_APInt(Op0C))) {
- unsigned BitWidth = I.getType()->getScalarSizeInBits();
- // -(X >>u 31) -> (X >>s 31)
- // -(X >>s 31) -> (X >>u 31)
if (Op0C->isNullValue()) {
+ Value *Op1Wide;
+ match(Op1, m_TruncOrSelf(m_Value(Op1Wide)));
+ bool HadTrunc = Op1Wide != Op1;
+ bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse();
+ unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits();
+
Value *X;
const APInt *ShAmt;
- if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
+ // -(X >>u 31) -> (X >>s 31)
+ if (NoTruncOrTruncIsOneUse &&
+ match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
*ShAmt == BitWidth - 1) {
- Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
- return BinaryOperator::CreateAShr(X, ShAmtOp);
+ Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
+ Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp);
+ NewShift->copyIRFlags(Op1Wide);
+ if (!HadTrunc)
+ return NewShift;
+ Builder.Insert(NewShift);
+ return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
}
- if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
+ // -(X >>s 31) -> (X >>u 31)
+ if (NoTruncOrTruncIsOneUse &&
+ match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
*ShAmt == BitWidth - 1) {
- Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
- return BinaryOperator::CreateLShr(X, ShAmtOp);
+ Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
+ Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp);
+ NewShift->copyIRFlags(Op1Wide);
+ if (!HadTrunc)
+ return NewShift;
+ Builder.Insert(NewShift);
+ return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
}
- if (Op1->hasOneUse()) {
+ if (!HadTrunc && Op1->hasOneUse()) {
Value *LHS, *RHS;
SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
if (SPF == SPF_ABS || SPF == SPF_NABS) {
// This is a negate of an ABS/NABS pattern. Just swap the operands
// of the select.
- SelectInst *SI = cast<SelectInst>(Op1);
- Value *TrueVal = SI->getTrueValue();
- Value *FalseVal = SI->getFalseValue();
- SI->setTrueValue(FalseVal);
- SI->setFalseValue(TrueVal);
+ cast<SelectInst>(Op1)->swapValues();
// Don't swap prof metadata, we didn't change the branch behavior.
- return replaceInstUsesWith(I, SI);
+ return replaceInstUsesWith(I, Op1);
}
}
}
@@ -1686,6 +1839,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return BinaryOperator::CreateNeg(Y);
}
+ // (sub (or A, B) (and A, B)) --> (xor A, B)
+ {
+ Value *A, *B;
+ if (match(Op1, m_And(m_Value(A), m_Value(B))) &&
+ match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
+ return BinaryOperator::CreateXor(A, B);
+ }
+
+ // (sub (and A, B) (or A, B)) --> neg (xor A, B)
+ {
+ Value *A, *B;
+ if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse()))
+ return BinaryOperator::CreateNeg(Builder.CreateXor(A, B));
+ }
+
// (sub (or A, B), (xor A, B)) --> (and A, B)
{
Value *A, *B;
@@ -1694,6 +1864,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return BinaryOperator::CreateAnd(A, B);
}
+ // (sub (xor A, B) (or A, B)) --> neg (and A, B)
+ {
+ Value *A, *B;
+ if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse()))
+ return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B));
+ }
+
{
Value *Y;
// ((X | Y) - X) --> (~X & Y)
@@ -1778,7 +1957,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
std::swap(LHS, RHS);
// LHS is now O above and expected to have at least 2 uses (the min/max)
  // NotA is expected to have 2 uses from the min/max and 1 from the sub.
- if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
+ if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
!NotA->hasNUsesOrMore(4)) {
// Note: We don't generate the inverse max/min, just create the not of
// it and let other folds do the rest.
@@ -1826,6 +2005,10 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return SelectInst::Create(Cmp, Neg, A);
}
+ if (Instruction *V =
+ canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+ return V;
+
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
@@ -1865,6 +2048,22 @@ static Instruction *foldFNegIntoConstant(Instruction &I) {
return nullptr;
}
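+/// Hoist an fneg above a one-use fmul or fdiv, e.g. (illustrative):
+///   -(X * Y) --> (-X) * Y
+///   -(X / Y) --> (-X) / Y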
+static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *FNeg;
+ if (!match(&I, m_FNeg(m_Value(FNeg))))
+ return nullptr;
+
+ Value *X, *Y;
+ if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))
+ return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+
+ if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))))
+ return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+
+ return nullptr;
+}
+
Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
Value *Op = I.getOperand(0);
@@ -1882,6 +2081,9 @@ Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y)))))
return BinaryOperator::CreateFSubFMF(Y, X, &I);
+ if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
+ return R;
+
return nullptr;
}
@@ -1903,6 +2105,9 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
if (Instruction *X = foldFNegIntoConstant(I))
return X;
+ if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
+ return R;
+
Value *X, *Y;
Constant *C;
@@ -1944,6 +2149,21 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y))))))
return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I);
+ // Similar to above, but look through fmul/fdiv of the negated value:
+ // Op0 - (-X * Y) --> Op0 + (X * Y)
+ // Op0 - (Y * -X) --> Op0 + (X * Y)
+ if (match(Op1, m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))))) {
+ Value *FMul = Builder.CreateFMulFMF(X, Y, &I);
+ return BinaryOperator::CreateFAddFMF(Op0, FMul, &I);
+ }
+ // Op0 - (-X / Y) --> Op0 + (X / Y)
+ // Op0 - (X / -Y) --> Op0 + (X / Y)
+ if (match(Op1, m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y)))) ||
+ match(Op1, m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))))) {
+ Value *FDiv = Builder.CreateFDivFMF(X, Y, &I);
+ return BinaryOperator::CreateFAddFMF(Op0, FDiv, &I);
+ }
+
// Handle special cases for FSub with selects feeding the operation
if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 2b9859b602f4..4a30b60ca931 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -160,16 +160,14 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op,
}
/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
-/// (V < Lo || V >= Hi). This method expects that Lo <= Hi. IsSigned indicates
+/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates
/// whether to treat V, Lo, and Hi as signed or not.
Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
bool isSigned, bool Inside) {
- assert((isSigned ? Lo.sle(Hi) : Lo.ule(Hi)) &&
- "Lo is not <= Hi in range emission code!");
+ assert((isSigned ? Lo.slt(Hi) : Lo.ult(Hi)) &&
+ "Lo is not < Hi in range emission code!");
Type *Ty = V->getType();
- if (Lo == Hi)
- return Inside ? ConstantInt::getFalse(Ty) : ConstantInt::getTrue(Ty);
// V >= Min && V < Hi --> V < Hi
// V < Min || V >= Hi --> V >= Hi
@@ -1051,9 +1049,103 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
return nullptr;
}
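+/// Fold the combination of an unsigned comparison and an equality test against
+/// zero that together form an unsigned underflow/overflow check, e.g. (one of
+/// the shapes handled below):
+///   Base u>= Offset && (Base - Offset) != 0  -->  Base u> Offset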
+/// Commuted variants are assumed to be handled by calling this function again
+/// with the parameters swapped.
+static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
+ ICmpInst *UnsignedICmp, bool IsAnd,
+ const SimplifyQuery &Q,
+ InstCombiner::BuilderTy &Builder) {
+ Value *ZeroCmpOp;
+ ICmpInst::Predicate EqPred;
+ if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(ZeroCmpOp), m_Zero())) ||
+ !ICmpInst::isEquality(EqPred))
+ return nullptr;
+
+ auto IsKnownNonZero = [&](Value *V) {
+ return isKnownNonZero(V, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+ };
+
+ ICmpInst::Predicate UnsignedPred;
+
+ Value *A, *B;
+ if (match(UnsignedICmp,
+ m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) &&
+ match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) &&
+ (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) {
+ if (UnsignedICmp->getOperand(0) != ZeroCmpOp)
+ UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+
+ auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) {
+ if (!IsKnownNonZero(NonZero))
+ std::swap(NonZero, Other);
+ return IsKnownNonZero(NonZero);
+ };
+
+ // Given ZeroCmpOp = (A + B)
+ // ZeroCmpOp <= A && ZeroCmpOp != 0 --> (0-B) < A
+ // ZeroCmpOp > A || ZeroCmpOp == 0 --> (0-B) >= A
+ //
+    //   ZeroCmpOp <  A && ZeroCmpOp != 0  -->  (0-X) <  Y
+    //   ZeroCmpOp >= A || ZeroCmpOp == 0  -->  (0-X) >= Y
+    //   where X is whichever of A/B is known to be non-zero,
+    //   and Y is the remaining value.
+ if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
+ IsAnd)
+ return Builder.CreateICmpULT(Builder.CreateNeg(B), A);
+ if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE &&
+ IsAnd && GetKnownNonZeroAndOther(B, A))
+ return Builder.CreateICmpULT(Builder.CreateNeg(B), A);
+ if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
+ !IsAnd)
+ return Builder.CreateICmpUGE(Builder.CreateNeg(B), A);
+ if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_EQ &&
+ !IsAnd && GetKnownNonZeroAndOther(B, A))
+ return Builder.CreateICmpUGE(Builder.CreateNeg(B), A);
+ }
+
+ Value *Base, *Offset;
+ if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset))))
+ return nullptr;
+
+ if (!match(UnsignedICmp,
+ m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) ||
+ !ICmpInst::isUnsigned(UnsignedPred))
+ return nullptr;
+ if (UnsignedICmp->getOperand(0) != Base)
+ UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+
+ // Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset
+ // (no overflow and not null)
+ if ((UnsignedPred == ICmpInst::ICMP_UGE ||
+ UnsignedPred == ICmpInst::ICMP_UGT) &&
+ EqPred == ICmpInst::ICMP_NE && IsAnd)
+ return Builder.CreateICmpUGT(Base, Offset);
+
+ // Base <=/< Offset || (Base - Offset) == 0 <--> Base <= Offset
+ // (overflow or null)
+ if ((UnsignedPred == ICmpInst::ICMP_ULE ||
+ UnsignedPred == ICmpInst::ICMP_ULT) &&
+ EqPred == ICmpInst::ICMP_EQ && !IsAnd)
+ return Builder.CreateICmpULE(Base, Offset);
+
+ // Base <= Offset && (Base - Offset) != 0 --> Base < Offset
+ if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
+ IsAnd)
+ return Builder.CreateICmpULT(Base, Offset);
+
+ // Base > Offset || (Base - Offset) == 0 --> Base >= Offset
+ if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
+ !IsAnd)
+ return Builder.CreateICmpUGE(Base, Offset);
+
+ return nullptr;
+}
+
/// Fold (icmp)&(icmp) if possible.
Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Instruction &CxtI) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&CxtI);
+
// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
// if K1 and K2 are a one-bit mask.
if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI))
@@ -1096,6 +1188,13 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder))
return V;
+ if (Value *X =
+ foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/true, Q, Builder))
+ return X;
+ if (Value *X =
+ foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder))
+ return X;
+
// This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
@@ -1196,16 +1295,22 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
default:
llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_ULT:
- if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13
+ // (X != 13 & X u< 14) -> X < 13
+ if (LHSC->getValue() == (RHSC->getValue() - 1))
return Builder.CreateICmpULT(LHS0, LHSC);
- if (LHSC->isZero()) // (X != 0 & X u< 14) -> X-1 u< 13
+ if (LHSC->isZero()) // (X != 0 & X u< C) -> X-1 u< C-1
return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
false, true);
break; // (X != 13 & X u< 15) -> no change
case ICmpInst::ICMP_SLT:
- if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13
+ // (X != 13 & X s< 14) -> X < 13
+ if (LHSC->getValue() == (RHSC->getValue() - 1))
return Builder.CreateICmpSLT(LHS0, LHSC);
- break; // (X != 13 & X s< 15) -> no change
+ // (X != INT_MIN & X s< C) -> X-(INT_MIN+1) u< (C-(INT_MIN+1))
+ if (LHSC->isMinValue(true))
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
+ true, true);
+ break; // (X != 13 & X s< 15) -> no change
case ICmpInst::ICMP_NE:
// Potential folds for this case should already be handled.
break;
@@ -1216,10 +1321,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
default:
llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_NE:
- if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14
+ // (X u> 13 & X != 14) -> X u> 14
+ if (RHSC->getValue() == (LHSC->getValue() + 1))
return Builder.CreateICmp(PredL, LHS0, RHSC);
+ // X u> C & X != UINT_MAX -> (X-(C+1)) u< UINT_MAX-(C+1)
+ if (RHSC->isMaxValue(false))
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
+ false, true);
break; // (X u> 13 & X != 15) -> no change
- case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1
+ case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) u< 1
return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
false, true);
}
@@ -1229,10 +1339,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
default:
llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_NE:
- if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14
+ // (X s> 13 & X != 14) -> X s> 14
+ if (RHSC->getValue() == (LHSC->getValue() + 1))
return Builder.CreateICmp(PredL, LHS0, RHSC);
+ // X s> C & X != INT_MAX -> (X-(C+1)) u< INT_MAX-(C+1)
+ if (RHSC->isMaxValue(true))
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
+ true, true);
break; // (X s> 13 & X != 15) -> no change
- case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1
+ case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) u< 1
return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true,
true);
}
@@ -1352,8 +1467,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
Value *A, *B;
if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) &&
match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) &&
- !IsFreeToInvert(A, A->hasOneUse()) &&
- !IsFreeToInvert(B, B->hasOneUse())) {
+ !isFreeToInvert(A, A->hasOneUse()) &&
+ !isFreeToInvert(B, B->hasOneUse())) {
Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan");
return BinaryOperator::CreateNot(AndOr);
}
@@ -1770,13 +1885,13 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
// (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C
if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
- if (Op1->hasOneUse() || IsFreeToInvert(C, C->hasOneUse()))
+ if (Op1->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C));
// ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C
if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))))
if (match(Op1, m_Xor(m_Specific(B), m_Specific(A))))
- if (Op0->hasOneUse() || IsFreeToInvert(C, C->hasOneUse()))
+ if (Op0->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C));
// (A | B) & ((~A) ^ B) -> (A & B)
@@ -1844,6 +1959,20 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
A->getType()->isIntOrIntVectorTy(1))
return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType()));
+ // and(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? X : 0.
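+  // Rationale: with the nsw sub, the ashr yields an all-ones mask exactly
+  // when Y s< X (i.e. X s> Y) and zero otherwise, so the 'and' produces
+  // either X or 0.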
+ {
+ Value *X, *Y;
+ const APInt *ShAmt;
+ Type *Ty = I.getType();
+ if (match(&I, m_c_And(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)),
+ m_APInt(ShAmt))),
+ m_Deferred(X))) &&
+ *ShAmt == Ty->getScalarSizeInBits() - 1) {
+ Value *NewICmpInst = Builder.CreateICmpSGT(X, Y);
+ return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty));
+ }
+ }
+
return nullptr;
}
@@ -2057,6 +2186,8 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B,
/// Fold (icmp)|(icmp) if possible.
Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Instruction &CxtI) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&CxtI);
+
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
// if K1 and K2 are a one-bit mask.
if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI))
@@ -2182,6 +2313,13 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldIsPowerOf2(LHS, RHS, false /* JoinedByAnd */, Builder))
return V;
+ if (Value *X =
+ foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/false, Q, Builder))
+ return X;
+ if (Value *X =
+ foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder))
+ return X;
+
// This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
if (!LHSC || !RHSC)
return nullptr;
@@ -2251,8 +2389,19 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
case ICmpInst::ICMP_EQ:
// Potential folds for this case should already be handled.
break;
- case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change
- case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change
+ case ICmpInst::ICMP_UGT:
+ // (X == 0 || X u> C) -> (X-1) u>= C
+ if (LHSC->isMinValue(false))
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1,
+ false, false);
+ // (X == 13 | X u> 14) -> no change
+ break;
+ case ICmpInst::ICMP_SGT:
+ // (X == INT_MIN || X s> C) -> (X-(INT_MIN+1)) u>= C-INT_MIN
+ if (LHSC->isMinValue(true))
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1,
+ true, false);
+ // (X == 13 | X s> 14) -> no change
break;
}
break;
@@ -2261,6 +2410,10 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
default:
llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change
+ // (X u< C || X == UINT_MAX) => (X-C) u>= UINT_MAX-C
+ if (RHSC->isMaxValue(false))
+ return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(),
+ false, false);
break;
case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2
assert(!RHSC->isMaxValue(false) && "Missed icmp simplification");
@@ -2272,9 +2425,14 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
switch (PredR) {
default:
llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change
+ case ICmpInst::ICMP_EQ:
+ // (X s< C || X == INT_MAX) => (X-C) u>= INT_MAX-C
+ if (RHSC->isMaxValue(true))
+ return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(),
+ true, false);
+ // (X s< 13 | X == 14) -> no change
break;
- case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2
+ case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) u> 2
assert(!RHSC->isMaxValue(true) && "Missed icmp simplification");
return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true,
false);
@@ -2552,6 +2710,25 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
}
}
+ // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? -1 : X.
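+  // Rationale: with the nsw sub, the ashr yields an all-ones mask exactly
+  // when X s> Y and zero otherwise, so the 'or' produces either -1 or X.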
+ {
+ Value *X, *Y;
+ const APInt *ShAmt;
+ Type *Ty = I.getType();
+ if (match(&I, m_c_Or(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)),
+ m_APInt(ShAmt))),
+ m_Deferred(X))) &&
+ *ShAmt == Ty->getScalarSizeInBits() - 1) {
+ Value *NewICmpInst = Builder.CreateICmpSGT(X, Y);
+ return SelectInst::Create(NewICmpInst, ConstantInt::getAllOnesValue(Ty),
+ X);
+ }
+ }
+
+ if (Instruction *V =
+ canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+ return V;
+
return nullptr;
}
@@ -2617,7 +2794,11 @@ static Instruction *foldXorToXor(BinaryOperator &I,
return nullptr;
}
-Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+ BinaryOperator &I) {
+ assert(I.getOpcode() == Instruction::Xor && I.getOperand(0) == LHS &&
+ I.getOperand(1) == RHS && "Should be 'xor' with these operands");
+
if (predicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) {
if (LHS->getOperand(0) == RHS->getOperand(1) &&
LHS->getOperand(1) == RHS->getOperand(0))
@@ -2672,14 +2853,35 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
// TODO: If OrICmp is false, the whole thing is false (InstSimplify?).
if (Value *AndICmp = SimplifyBinOp(Instruction::And, LHS, RHS, SQ)) {
// TODO: Independently handle cases where the 'and' side is a constant.
- if (OrICmp == LHS && AndICmp == RHS && RHS->hasOneUse()) {
- // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS
- RHS->setPredicate(RHS->getInversePredicate());
- return Builder.CreateAnd(LHS, RHS);
+ ICmpInst *X = nullptr, *Y = nullptr;
+ if (OrICmp == LHS && AndICmp == RHS) {
+ // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS --> X & !Y
+ X = LHS;
+ Y = RHS;
}
- if (OrICmp == RHS && AndICmp == LHS && LHS->hasOneUse()) {
- // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS
- LHS->setPredicate(LHS->getInversePredicate());
+ if (OrICmp == RHS && AndICmp == LHS) {
+ // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS --> !Y & X
+ X = RHS;
+ Y = LHS;
+ }
+ if (X && Y && (Y->hasOneUse() || canFreelyInvertAllUsersOf(Y, &I))) {
+ // Invert the predicate of 'Y', thus inverting its output.
+ Y->setPredicate(Y->getInversePredicate());
+ // So, are there other uses of Y?
+ if (!Y->hasOneUse()) {
+ // We need to adapt other uses of Y though. Get a value that matches
+ // the original value of Y before inversion. While this increases
+ // immediate instruction count, we have just ensured that all the
+ // users are freely-invertible, so that 'not' *will* get folded away.
+ BuilderTy::InsertPointGuard Guard(Builder);
+ // Set insertion point to right after the Y.
+ Builder.SetInsertPoint(Y->getParent(), ++(Y->getIterator()));
+ Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not");
+ // Replace all uses of Y (excluding the one in NotY!) with NotY.
+ Y->replaceUsesWithIf(NotY,
+ [NotY](Use &U) { return U.getUser() != NotY; });
+ }
+ // All done.
return Builder.CreateAnd(LHS, RHS);
}
}
@@ -2747,9 +2949,9 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I,
return nullptr;
// We only want to do the transform if it is free to do.
- if (IsFreeToInvert(X, X->hasOneUse())) {
+ if (isFreeToInvert(X, X->hasOneUse())) {
// Ok, good.
- } else if (IsFreeToInvert(Y, Y->hasOneUse())) {
+ } else if (isFreeToInvert(Y, Y->hasOneUse())) {
std::swap(X, Y);
} else
return nullptr;
@@ -2827,9 +3029,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
// Apply DeMorgan's Law when inverts are free:
// ~(X & Y) --> (~X | ~Y)
// ~(X | Y) --> (~X & ~Y)
- if (IsFreeToInvert(NotVal->getOperand(0),
+ if (isFreeToInvert(NotVal->getOperand(0),
NotVal->getOperand(0)->hasOneUse()) &&
- IsFreeToInvert(NotVal->getOperand(1),
+ isFreeToInvert(NotVal->getOperand(1),
NotVal->getOperand(1)->hasOneUse())) {
Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs");
Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs");
@@ -3004,7 +3206,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
- if (Value *V = foldXorOfICmps(LHS, RHS))
+ if (Value *V = foldXorOfICmps(LHS, RHS, I))
return replaceInstUsesWith(I, V);
if (Instruction *CastedXor = foldCastedBitwiseLogic(I))
@@ -3052,7 +3254,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
if (SelectPatternResult::isMinOrMax(SPF)) {
// It's possible we get here before the not has been simplified, so make
// sure the input to the not isn't freely invertible.
- if (match(LHS, m_Not(m_Value(X))) && !IsFreeToInvert(X, X->hasOneUse())) {
+ if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) {
Value *NotY = Builder.CreateNot(RHS);
return SelectInst::Create(
Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY);
@@ -3060,7 +3262,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
// It's possible we get here before the not has been simplified, so make
// sure the input to the not isn't freely invertible.
- if (match(RHS, m_Not(m_Value(Y))) && !IsFreeToInvert(Y, Y->hasOneUse())) {
+ if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) {
Value *NotX = Builder.CreateNot(LHS);
return SelectInst::Create(
Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y);
@@ -3068,8 +3270,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
// If both sides are freely invertible, then we can get rid of the xor
// completely.
- if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
- IsFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) {
+ if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
+ isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) {
Value *NotLHS = Builder.CreateNot(LHS);
Value *NotRHS = Builder.CreateNot(RHS);
return SelectInst::Create(
diff --git a/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
index 5f37a00f56cf..825f4b468b0a 100644
--- a/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
@@ -124,7 +124,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
auto *SI = new StoreInst(RMWI.getValOperand(),
RMWI.getPointerOperand(), &RMWI);
SI->setAtomic(Ordering, RMWI.getSyncScopeID());
- SI->setAlignment(DL.getABITypeAlignment(RMWI.getType()));
+ SI->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType())));
return eraseInstFromFunction(RMWI);
}
@@ -154,6 +154,6 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand());
Load->setAtomic(Ordering, RMWI.getSyncScopeID());
- Load->setAlignment(DL.getABITypeAlignment(RMWI.getType()));
+ Load->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType())));
return Load;
}
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 4b3333affa72..c650d242cd50 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -185,7 +185,8 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy);
LoadInst *L = Builder.CreateLoad(IntType, Src);
// Alignment from the mem intrinsic will be better, so use it.
- L->setAlignment(CopySrcAlign);
+ L->setAlignment(
+ MaybeAlign(CopySrcAlign)); // FIXME: Check if we can use Align instead.
if (CopyMD)
L->setMetadata(LLVMContext::MD_tbaa, CopyMD);
MDNode *LoopMemParallelMD =
@@ -198,7 +199,8 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
StoreInst *S = Builder.CreateStore(L, Dest);
// Alignment from the mem intrinsic will be better, so use it.
- S->setAlignment(CopyDstAlign);
+ S->setAlignment(
+ MaybeAlign(CopyDstAlign)); // FIXME: Check if we can use Align instead.
if (CopyMD)
S->setMetadata(LLVMContext::MD_tbaa, CopyMD);
if (LoopMemParallelMD)
@@ -223,9 +225,10 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
}
Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) {
- unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT);
- if (MI->getDestAlignment() < Alignment) {
- MI->setDestAlignment(Alignment);
+ const unsigned KnownAlignment =
+ getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT);
+ if (MI->getDestAlignment() < KnownAlignment) {
+ MI->setDestAlignment(KnownAlignment);
return MI;
}
@@ -243,13 +246,9 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) {
ConstantInt *FillC = dyn_cast<ConstantInt>(MI->getValue());
if (!LenC || !FillC || !FillC->getType()->isIntegerTy(8))
return nullptr;
- uint64_t Len = LenC->getLimitedValue();
- Alignment = MI->getDestAlignment();
+ const uint64_t Len = LenC->getLimitedValue();
assert(Len && "0-sized memory setting should be removed already.");
-
- // Alignment 0 is identity for alignment 1 for memset, but not store.
- if (Alignment == 0)
- Alignment = 1;
+ const Align Alignment = assumeAligned(MI->getDestAlignment());
// If it is an atomic and alignment is less than the size then we will
// introduce the unaligned memory access which will be later transformed
@@ -1060,9 +1059,9 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) {
// If we can unconditionally load from this address, replace with a
// load/select idiom. TODO: use DT for context sensitive query
- if (isDereferenceableAndAlignedPointer(LoadPtr, II.getType(), Alignment,
- II.getModule()->getDataLayout(),
- &II, nullptr)) {
+ if (isDereferenceableAndAlignedPointer(
+ LoadPtr, II.getType(), MaybeAlign(Alignment),
+ II.getModule()->getDataLayout(), &II, nullptr)) {
Value *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
"unmaskedload");
return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
@@ -1086,7 +1085,8 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
// If the mask is all ones, this is a plain vector store of the 1st argument.
if (ConstMask->isAllOnesValue()) {
Value *StorePtr = II.getArgOperand(1);
- unsigned Alignment = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+ MaybeAlign Alignment(
+ cast<ConstantInt>(II.getArgOperand(2))->getZExtValue());
return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
}
@@ -2234,6 +2234,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(*II, Add);
}
+ // Try to simplify the underlying FMul.
+ if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
+ II->getFastMathFlags(),
+ SQ.getWithInstruction(II))) {
+ auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
+ FAdd->copyFastMathFlags(II);
+ return FAdd;
+ }
+
LLVM_FALLTHROUGH;
}
case Intrinsic::fma: {
@@ -2258,9 +2267,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return II;
}
- // fma x, 1, z -> fadd x, z
- if (match(Src1, m_FPOne())) {
- auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2));
+ // Try to simplify the underlying FMul. We can only apply simplifications
+ // that do not require rounding.
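+    // (This should subsume e.g. the 'fma x, 1.0, z -> fadd x, z' fold that
+    // was previously handled explicitly here.)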
+ if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1),
+ II->getFastMathFlags(),
+ SQ.getWithInstruction(II))) {
+ auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
FAdd->copyFastMathFlags(II);
return FAdd;
}
@@ -2331,7 +2343,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// Turn PPC VSX loads into normal loads.
Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0),
PointerType::getUnqual(II->getType()));
- return new LoadInst(II->getType(), Ptr, Twine(""), false, 1);
+ return new LoadInst(II->getType(), Ptr, Twine(""), false, Align::None());
}
case Intrinsic::ppc_altivec_stvx:
case Intrinsic::ppc_altivec_stvxl:
@@ -2349,7 +2361,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// Turn PPC VSX stores into normal stores.
Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType());
Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy);
- return new StoreInst(II->getArgOperand(0), Ptr, false, 1);
+ return new StoreInst(II->getArgOperand(0), Ptr, false, Align::None());
}
case Intrinsic::ppc_qpx_qvlfs:
// Turn PPC QPX qvlfs -> load if the pointer is known aligned.
@@ -3885,6 +3897,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// Asan needs to poison memory to detect invalid access which is possible
// even for empty lifetime range.
if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) ||
+ II->getFunction()->hasFnAttribute(Attribute::SanitizeMemory) ||
II->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress))
break;
@@ -3950,10 +3963,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
case Intrinsic::experimental_gc_relocate: {
+ auto &GCR = *cast<GCRelocateInst>(II);
+
+ // If we have two copies of the same pointer in the statepoint argument
+ // list, canonicalize to one. This may let us common gc.relocates.
+ if (GCR.getBasePtr() == GCR.getDerivedPtr() &&
+ GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) {
+ auto *OpIntTy = GCR.getOperand(2)->getType();
+ II->setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex()));
+ return II;
+ }
+
// Translate facts known about a pointer before relocating into
// facts about the relocate value, while being careful to
// preserve relocation semantics.
- Value *DerivedPtr = cast<GCRelocateInst>(II)->getDerivedPtr();
+ Value *DerivedPtr = GCR.getDerivedPtr();
// Remove the relocation if unused, note that this check is required
// to prevent the cases below from looping forever.
@@ -4177,10 +4201,58 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
return nullptr;
}
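+/// Attach dereferenceability attributes to the return value of an allocation
+/// call when the requested size is a known constant, e.g. (illustrative) a
+/// `malloc(42)` call gets `dereferenceable_or_null(42)` on its result.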
+static void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) {
+ unsigned NumArgs = Call.getNumArgOperands();
+ ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0));
+ ConstantInt *Op1C =
+ (NumArgs == 1) ? nullptr : dyn_cast<ConstantInt>(Call.getOperand(1));
+ // Bail out if the allocation size is zero.
+ if ((Op0C && Op0C->isNullValue()) || (Op1C && Op1C->isNullValue()))
+ return;
+
+ if (isMallocLikeFn(&Call, TLI) && Op0C) {
+ if (isOpNewLikeFn(&Call, TLI))
+ Call.addAttribute(AttributeList::ReturnIndex,
+ Attribute::getWithDereferenceableBytes(
+ Call.getContext(), Op0C->getZExtValue()));
+ else
+ Call.addAttribute(AttributeList::ReturnIndex,
+ Attribute::getWithDereferenceableOrNullBytes(
+ Call.getContext(), Op0C->getZExtValue()));
+ } else if (isReallocLikeFn(&Call, TLI) && Op1C) {
+ Call.addAttribute(AttributeList::ReturnIndex,
+ Attribute::getWithDereferenceableOrNullBytes(
+ Call.getContext(), Op1C->getZExtValue()));
+ } else if (isCallocLikeFn(&Call, TLI) && Op0C && Op1C) {
+ bool Overflow;
+ const APInt &N = Op0C->getValue();
+ APInt Size = N.umul_ov(Op1C->getValue(), Overflow);
+ if (!Overflow)
+ Call.addAttribute(AttributeList::ReturnIndex,
+ Attribute::getWithDereferenceableOrNullBytes(
+ Call.getContext(), Size.getZExtValue()));
+ } else if (isStrdupLikeFn(&Call, TLI)) {
+ uint64_t Len = GetStringLength(Call.getOperand(0));
+ if (Len) {
+ // strdup
+ if (NumArgs == 1)
+ Call.addAttribute(AttributeList::ReturnIndex,
+ Attribute::getWithDereferenceableOrNullBytes(
+ Call.getContext(), Len));
+ // strndup
+ else if (NumArgs == 2 && Op1C)
+ Call.addAttribute(
+ AttributeList::ReturnIndex,
+ Attribute::getWithDereferenceableOrNullBytes(
+ Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1)));
+ }
+ }
+}
+
/// Improvements for call, callbr and invoke instructions.
Instruction *InstCombiner::visitCallBase(CallBase &Call) {
- if (isAllocLikeFn(&Call, &TLI))
- return visitAllocSite(Call);
+ if (isAllocationFn(&Call, &TLI))
+ annotateAnyAllocSite(Call, &TLI);
bool Changed = false;
@@ -4312,6 +4384,9 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) {
if (I) return eraseInstFromFunction(*I);
}
+ if (isAllocLikeFn(&Call, &TLI))
+ return visitAllocSite(Call);
+
return Changed ? &Call : nullptr;
}
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 2c9ba203fbf3..65aaef28d87a 100644
--- a/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -140,7 +140,7 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
}
AllocaInst *New = AllocaBuilder.CreateAlloca(CastElTy, Amt);
- New->setAlignment(AI.getAlignment());
+ New->setAlignment(MaybeAlign(AI.getAlignment()));
New->takeName(&AI);
New->setUsedWithInAlloca(AI.isUsedWithInAlloca());
@@ -1531,16 +1531,16 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
// what we can and cannot do safely varies from operation to operation, and
// is explained below in the various case statements.
Type *Ty = FPT.getType();
- BinaryOperator *OpI = dyn_cast<BinaryOperator>(FPT.getOperand(0));
- if (OpI && OpI->hasOneUse()) {
- Type *LHSMinType = getMinimumFPType(OpI->getOperand(0));
- Type *RHSMinType = getMinimumFPType(OpI->getOperand(1));
- unsigned OpWidth = OpI->getType()->getFPMantissaWidth();
+ auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
+ if (BO && BO->hasOneUse()) {
+ Type *LHSMinType = getMinimumFPType(BO->getOperand(0));
+ Type *RHSMinType = getMinimumFPType(BO->getOperand(1));
+ unsigned OpWidth = BO->getType()->getFPMantissaWidth();
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
unsigned SrcWidth = std::max(LHSWidth, RHSWidth);
unsigned DstWidth = Ty->getFPMantissaWidth();
- switch (OpI->getOpcode()) {
+ switch (BO->getOpcode()) {
default: break;
case Instruction::FAdd:
case Instruction::FSub:
@@ -1563,10 +1563,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
// could be tightened for those cases, but they are rare (the main
// case of interest here is (float)((double)float + float)).
if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
- Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
- Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
- Instruction *RI = BinaryOperator::Create(OpI->getOpcode(), LHS, RHS);
- RI->copyFastMathFlags(OpI);
+ Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+ Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+ Instruction *RI = BinaryOperator::Create(BO->getOpcode(), LHS, RHS);
+ RI->copyFastMathFlags(BO);
return RI;
}
break;
@@ -1577,9 +1577,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
// rounding can possibly occur; we can safely perform the operation
// in the destination format if it can represent both sources.
if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
- Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
- Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
- return BinaryOperator::CreateFMulFMF(LHS, RHS, OpI);
+ Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+ Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+ return BinaryOperator::CreateFMulFMF(LHS, RHS, BO);
}
break;
case Instruction::FDiv:
@@ -1590,9 +1590,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
// condition used here is a good conservative first pass.
// TODO: Tighten bound via rigorous analysis of the unbalanced case.
if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
- Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
- Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
- return BinaryOperator::CreateFDivFMF(LHS, RHS, OpI);
+ Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+ Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+ return BinaryOperator::CreateFDivFMF(LHS, RHS, BO);
}
break;
case Instruction::FRem: {
@@ -1604,14 +1604,14 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
break;
Value *LHS, *RHS;
if (LHSWidth == SrcWidth) {
- LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType);
- RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType);
+ LHS = Builder.CreateFPTrunc(BO->getOperand(0), LHSMinType);
+ RHS = Builder.CreateFPTrunc(BO->getOperand(1), LHSMinType);
} else {
- LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType);
- RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType);
+ LHS = Builder.CreateFPTrunc(BO->getOperand(0), RHSMinType);
+ RHS = Builder.CreateFPTrunc(BO->getOperand(1), RHSMinType);
}
- Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, OpI);
+ Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, BO);
return CastInst::CreateFPCast(ExactResult, Ty);
}
}
@@ -2338,8 +2338,23 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
// If we found a path from the src to dest, create the getelementptr now.
if (SrcElTy == DstElTy) {
SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0));
- return GetElementPtrInst::CreateInBounds(SrcPTy->getElementType(), Src,
- Idxs);
+ GetElementPtrInst *GEP =
+ GetElementPtrInst::Create(SrcPTy->getElementType(), Src, Idxs);
+
+ // If the source pointer is dereferenceable, then assume it points to an
+ // allocated object and apply "inbounds" to the GEP.
+ bool CanBeNull;
+ if (Src->getPointerDereferenceableBytes(DL, CanBeNull)) {
+ // In a non-default address space (not 0), a null pointer can not be
+ // assumed inbounds, so ignore that case (dereferenceable_or_null).
+ // The reason is that 'null' is not treated differently in these address
+ // spaces, and we consequently ignore the 'gep inbounds' special case
+ // for 'null' which allows 'inbounds' on 'null' if the indices are
+ // zeros.
+ if (SrcPTy->getAddressSpace() == 0 || !CanBeNull)
+ GEP->setIsInBounds();
+ }
+ return GEP;
}
}
@@ -2391,28 +2406,47 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
}
}
- if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(Src)) {
+ if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) {
// Okay, we have (bitcast (shuffle ..)). Check to see if this is
// a bitcast to a vector with the same # elts.
- if (SVI->hasOneUse() && DestTy->isVectorTy() &&
- DestTy->getVectorNumElements() == SVI->getType()->getNumElements() &&
- SVI->getType()->getNumElements() ==
- SVI->getOperand(0)->getType()->getVectorNumElements()) {
+ Value *ShufOp0 = Shuf->getOperand(0);
+ Value *ShufOp1 = Shuf->getOperand(1);
+ unsigned NumShufElts = Shuf->getType()->getVectorNumElements();
+ unsigned NumSrcVecElts = ShufOp0->getType()->getVectorNumElements();
+ if (Shuf->hasOneUse() && DestTy->isVectorTy() &&
+ DestTy->getVectorNumElements() == NumShufElts &&
+ NumShufElts == NumSrcVecElts) {
BitCastInst *Tmp;
// If either of the operands is a cast from CI.getType(), then
// evaluating the shuffle in the casted destination's type will allow
// us to eliminate at least one cast.
- if (((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(0))) &&
+ if (((Tmp = dyn_cast<BitCastInst>(ShufOp0)) &&
Tmp->getOperand(0)->getType() == DestTy) ||
- ((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(1))) &&
+ ((Tmp = dyn_cast<BitCastInst>(ShufOp1)) &&
Tmp->getOperand(0)->getType() == DestTy)) {
- Value *LHS = Builder.CreateBitCast(SVI->getOperand(0), DestTy);
- Value *RHS = Builder.CreateBitCast(SVI->getOperand(1), DestTy);
+ Value *LHS = Builder.CreateBitCast(ShufOp0, DestTy);
+ Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy);
// Return a new shuffle vector. Use the same element ID's, as we
// know the vector types match #elts.
- return new ShuffleVectorInst(LHS, RHS, SVI->getOperand(2));
+ return new ShuffleVectorInst(LHS, RHS, Shuf->getOperand(2));
}
}
+
+ // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as
+ // a byte-swap:
+ // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X)
+ // TODO: We should match the related pattern for bitreverse.
+ if (DestTy->isIntegerTy() &&
+ DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
+ SrcTy->getScalarSizeInBits() == 8 && NumShufElts % 2 == 0 &&
+ Shuf->hasOneUse() && Shuf->isReverse()) {
+ assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
+ assert(isa<UndefValue>(ShufOp1) && "Unexpected shuffle op");
+ Function *Bswap =
+ Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy);
+ Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy);
+ return IntrinsicInst::Create(Bswap, { ScalarX });
+ }
}
// Handle the A->B->A cast, and there is an intervening PHI node.
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3a4283ae5406..a9f64feb600c 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -69,34 +69,6 @@ static bool hasBranchUse(ICmpInst &I) {
return false;
}
-/// Given an exploded icmp instruction, return true if the comparison only
-/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the
-/// result of the comparison is true when the input value is signed.
-static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
- bool &TrueIfSigned) {
- switch (Pred) {
- case ICmpInst::ICMP_SLT: // True if LHS s< 0
- TrueIfSigned = true;
- return RHS.isNullValue();
- case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1
- TrueIfSigned = true;
- return RHS.isAllOnesValue();
- case ICmpInst::ICMP_SGT: // True if LHS s> -1
- TrueIfSigned = false;
- return RHS.isAllOnesValue();
- case ICmpInst::ICMP_UGT:
- // True if LHS u> RHS and RHS == high-bit-mask - 1
- TrueIfSigned = true;
- return RHS.isMaxSignedValue();
- case ICmpInst::ICMP_UGE:
- // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc)
- TrueIfSigned = true;
- return RHS.isSignMask();
- default:
- return false;
- }
-}
-
/// Returns true if the exploded icmp can be expressed as a signed comparison
/// to zero and updates the predicate accordingly.
/// The signedness of the comparison is preserved.
@@ -832,6 +804,10 @@ getAsConstantIndexedAddress(Value *V, const DataLayout &DL) {
static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
ICmpInst::Predicate Cond,
const DataLayout &DL) {
+ // FIXME: Support vector of pointers.
+ if (GEPLHS->getType()->isVectorTy())
+ return nullptr;
+
if (!GEPLHS->hasAllConstantIndices())
return nullptr;
@@ -882,7 +858,9 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
RHS = RHS->stripPointerCasts();
Value *PtrBase = GEPLHS->getOperand(0);
- if (PtrBase == RHS && GEPLHS->isInBounds()) {
+ // FIXME: Support vector pointer GEPs.
+ if (PtrBase == RHS && GEPLHS->isInBounds() &&
+ !GEPLHS->getType()->isVectorTy()) {
// ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0).
// This transformation (ignoring the base and scales) is valid because we
// know pointers can't overflow since the gep is inbounds. See if we can
@@ -894,6 +872,37 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
Offset = EmitGEPOffset(GEPLHS);
return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
Constant::getNullValue(Offset->getType()));
+ }
+
+ if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) &&
+ isa<Constant>(RHS) && cast<Constant>(RHS)->isNullValue() &&
+ !NullPointerIsDefined(I.getFunction(),
+ RHS->getType()->getPointerAddressSpace())) {
+    // For most address spaces, an allocation can't be placed at null, but null
+    // itself is treated as a zero-size allocation by the inbounds rules. Thus,
+    // the only valid inbounds address derived from null is null itself.
+ // Thus, we have four cases to consider:
+ // 1) Base == nullptr, Offset == 0 -> inbounds, null
+ // 2) Base == nullptr, Offset != 0 -> poison as the result is out of bounds
+ // 3) Base != nullptr, Offset == (-base) -> poison (crossing allocations)
+ // 4) Base != nullptr, Offset != (-base) -> nonnull (and possibly poison)
+ //
+ // (Note if we're indexing a type of size 0, that simply collapses into one
+ // of the buckets above.)
+ //
+ // In general, we're allowed to make values less poison (i.e. remove
+ // sources of full UB), so in this case, we just select between the two
+ // non-poison cases (1 and 4 above).
+ //
+ // For vectors, we apply the same reasoning on a per-lane basis.
+ auto *Base = GEPLHS->getPointerOperand();
+ if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
+ int NumElts = GEPLHS->getType()->getVectorNumElements();
+ Base = Builder.CreateVectorSplat(NumElts, Base);
+ }
+ return new ICmpInst(Cond, Base,
+ ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(RHS), Base->getType()));
} else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) {
// If the base pointers are different, but the indices are the same, just
// compare the base pointer.
@@ -916,11 +925,13 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// If we're comparing GEPs with two base pointers that only differ in type
// and both GEPs have only constant indices or just one use, then fold
// the compare with the adjusted indices.
+ // FIXME: Support vector of pointers.
if (GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
(GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
(GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
PtrBase->stripPointerCasts() ==
- GEPRHS->getOperand(0)->stripPointerCasts()) {
+ GEPRHS->getOperand(0)->stripPointerCasts() &&
+ !GEPLHS->getType()->isVectorTy()) {
Value *LOffset = EmitGEPOffset(GEPLHS);
Value *ROffset = EmitGEPOffset(GEPRHS);
@@ -949,12 +960,14 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
}
// If one of the GEPs has all zero indices, recurse.
- if (GEPLHS->hasAllZeroIndices())
+ // FIXME: Handle vector of pointers.
+ if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices())
return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
ICmpInst::getSwappedPredicate(Cond), I);
// If the other GEP has all zero indices, recurse.
- if (GEPRHS->hasAllZeroIndices())
+ // FIXME: Handle vector of pointers.
+ if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices())
return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I);
bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
@@ -964,15 +977,20 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
unsigned DiffOperand = 0; // The operand that differs.
for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i)
if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) {
- if (GEPLHS->getOperand(i)->getType()->getPrimitiveSizeInBits() !=
- GEPRHS->getOperand(i)->getType()->getPrimitiveSizeInBits()) {
+ Type *LHSType = GEPLHS->getOperand(i)->getType();
+ Type *RHSType = GEPRHS->getOperand(i)->getType();
+ // FIXME: Better support for vector of pointers.
+ if (LHSType->getPrimitiveSizeInBits() !=
+ RHSType->getPrimitiveSizeInBits() ||
+ (GEPLHS->getType()->isVectorTy() &&
+ (!LHSType->isVectorTy() || !RHSType->isVectorTy()))) {
// Irreconcilable differences.
NumDifferences = 2;
break;
- } else {
- if (NumDifferences++) break;
- DiffOperand = i;
}
+
+ if (NumDifferences++) break;
+ DiffOperand = i;
}
if (NumDifferences == 0) // SAME GEP?
@@ -1317,6 +1335,59 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
return ExtractValueInst::Create(Call, 1, "sadd.overflow");
}
+/// If we have:
+/// icmp eq/ne (urem/srem %x, %y), 0
+/// iff %y is a power-of-two, we can replace this with a bit test:
+/// icmp eq/ne (and %x, (add %y, -1)), 0
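+/// e.g. with a constant power-of-two divisor: (%x urem 8) == 0 --> (%x & 7) == 0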
+Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
+ // This fold is only valid for equality predicates.
+ if (!I.isEquality())
+ return nullptr;
+ ICmpInst::Predicate Pred;
+ Value *X, *Y, *Zero;
+ if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
+ m_CombineAnd(m_Zero(), m_Value(Zero)))))
+ return nullptr;
+ if (!isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, 0, &I))
+ return nullptr;
+  // This may increase the instruction count since we don't require Y to be a
+  // constant.
+ Value *Mask = Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType()));
+ Value *Masked = Builder.CreateAnd(X, Mask);
+ return ICmpInst::Create(Instruction::ICmp, Pred, Masked, Zero);
+}
+
+/// Fold equality-comparison between zero and any (maybe truncated) right-shift
+/// by one-less-than-bitwidth into a sign test on the original value.
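+/// e.g. for i32 %x:  (trunc (%x lshr 31)) == 0  -->  %x s>= 0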
+Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) {
+ Instruction *Val;
+ ICmpInst::Predicate Pred;
+ if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
+ return nullptr;
+
+ Value *X;
+ Type *XTy;
+
+ Constant *C;
+ if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) {
+ XTy = X->getType();
+ unsigned XBitWidth = XTy->getScalarSizeInBits();
+ if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(XBitWidth, XBitWidth - 1))))
+ return nullptr;
+ } else if (isa<BinaryOperator>(Val) &&
+ (X = reassociateShiftAmtsOfTwoSameDirectionShifts(
+ cast<BinaryOperator>(Val), SQ.getWithInstruction(Val),
+ /*AnalyzeForSignBitExtraction=*/true))) {
+ XTy = X->getType();
+ } else
+ return nullptr;
+
+ return ICmpInst::Create(Instruction::ICmp,
+ Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_SGE
+ : ICmpInst::ICMP_SLT,
+ X, ConstantInt::getNullValue(XTy));
+}
+
// Handle icmp pred X, 0
Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
CmpInst::Predicate Pred = Cmp.getPredicate();
@@ -1335,6 +1406,9 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
}
}
+ if (Instruction *New = foldIRemByPowerOfTwoToBitTest(Cmp))
+ return New;
+
// Given:
// icmp eq/ne (urem %x, %y), 0
// Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
@@ -2179,6 +2253,44 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
return nullptr;
}
+Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp,
+ BinaryOperator *SRem,
+ const APInt &C) {
+ // Match an 'is positive' or 'is negative' comparison of remainder by a
+ // constant power-of-2 value:
+ // (X % pow2C) sgt/slt 0
+ const ICmpInst::Predicate Pred = Cmp.getPredicate();
+ if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT)
+ return nullptr;
+
+ // TODO: The one-use check is standard because we do not typically want to
+ // create longer instruction sequences, but this might be a special-case
+ // because srem is not good for analysis or codegen.
+ if (!SRem->hasOneUse())
+ return nullptr;
+
+ const APInt *DivisorC;
+ if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC)))
+ return nullptr;
+
+ // Mask off the sign bit and the modulo bits (low-bits).
+ Type *Ty = SRem->getType();
+ APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits());
+ Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1));
+ Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC);
+
+ // For 'is positive?' check that the sign-bit is clear and at least 1 masked
+ // bit is set. Example:
+ // (i8 X % 32) s> 0 --> (X & 159) s> 0
+ if (Pred == ICmpInst::ICMP_SGT)
+ return new ICmpInst(ICmpInst::ICMP_SGT, And, ConstantInt::getNullValue(Ty));
+
+ // For 'is negative?' check that the sign-bit is set and at least 1 masked
+ // bit is set. Example:
+ // (i16 X % 4) s< 0 --> (X & 32771) u> 32768
+ return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
+}
+
/// Fold icmp (udiv X, Y), C.
Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
BinaryOperator *UDiv,
@@ -2387,6 +2499,11 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
const APInt *C2;
APInt SubResult;
+ // icmp eq/ne (sub C, Y), C -> icmp eq/ne Y, 0
+ if (match(X, m_APInt(C2)) && *C2 == C && Cmp.isEquality())
+ return new ICmpInst(Cmp.getPredicate(), Y,
+ ConstantInt::get(Y->getType(), 0));
+
// (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C)
if (match(X, m_APInt(C2)) &&
((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) ||
@@ -2509,20 +2626,49 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
// TODO: Generalize this to work with other comparison idioms or ensure
// they get canonicalized into this form.
- // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32
- // Greater), where Equal, Less and Greater are placeholders for any three
- // constants.
- ICmpInst::Predicate PredA, PredB;
- if (match(SI->getTrueValue(), m_ConstantInt(Equal)) &&
- match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) &&
- PredA == ICmpInst::ICMP_EQ &&
- match(SI->getFalseValue(),
- m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)),
- m_ConstantInt(Less), m_ConstantInt(Greater))) &&
- PredB == ICmpInst::ICMP_SLT) {
- return true;
+ // select i1 (a == b),
+ // i32 Equal,
+ // i32 (select i1 (a < b), i32 Less, i32 Greater)
+ // where Equal, Less and Greater are placeholders for any three constants.
+ ICmpInst::Predicate PredA;
+ if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) ||
+ !ICmpInst::isEquality(PredA))
+ return false;
+ Value *EqualVal = SI->getTrueValue();
+ Value *UnequalVal = SI->getFalseValue();
+ // We still can get non-canonical predicate here, so canonicalize.
+ if (PredA == ICmpInst::ICMP_NE)
+ std::swap(EqualVal, UnequalVal);
+ if (!match(EqualVal, m_ConstantInt(Equal)))
+ return false;
+ ICmpInst::Predicate PredB;
+ Value *LHS2, *RHS2;
+ if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)),
+ m_ConstantInt(Less), m_ConstantInt(Greater))))
+ return false;
+ // We can get predicate mismatch here, so canonicalize if possible:
+ // First, ensure that 'LHS' match.
+ if (LHS2 != LHS) {
+ // x sgt y <--> y slt x
+ std::swap(LHS2, RHS2);
+ PredB = ICmpInst::getSwappedPredicate(PredB);
+ }
+ if (LHS2 != LHS)
+ return false;
+ // We also need to canonicalize 'RHS'.
+ if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
+ // x sgt C-1 <--> x sge C <--> not(x slt C)
+ auto FlippedStrictness =
+ getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
+ if (!FlippedStrictness)
+ return false;
+ assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && "Sanity check");
+ RHS2 = FlippedStrictness->second;
+ // And kind-of perform the result swap.
+ std::swap(Less, Greater);
+ PredB = ICmpInst::ICMP_SLT;
}
- return false;
+ return PredB == ICmpInst::ICMP_SLT && RHS == RHS2;
}
Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
@@ -2702,6 +2848,10 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C))
return I;
break;
+ case Instruction::SRem:
+ if (Instruction *I = foldICmpSRemConstant(Cmp, BO, *C))
+ return I;
+ break;
case Instruction::UDiv:
if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C))
return I;
@@ -2926,6 +3076,28 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
}
break;
}
+
+ case Intrinsic::uadd_sat: {
+ // uadd.sat(a, b) == 0 -> (a | b) == 0
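+    // (a saturating unsigned add is 0 only when both operands are 0)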
+ if (C.isNullValue()) {
+ Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1));
+ return replaceInstUsesWith(Cmp, Builder.CreateICmp(
+ Cmp.getPredicate(), Or, Constant::getNullValue(Ty)));
+
+ }
+ break;
+ }
+
+ case Intrinsic::usub_sat: {
+ // usub.sat(a, b) == 0 -> a <= b
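+    // (the saturating subtraction is 0 exactly when a u<= b)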
+ if (C.isNullValue()) {
+ ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ
+ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
+ return ICmpInst::Create(Instruction::ICmp, NewPred,
+ II->getArgOperand(0), II->getArgOperand(1));
+ }
+ break;
+ }
default:
break;
}
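
A minimal IR sketch (illustrative values, not from the patch) of the two new saturating-intrinsic equality folds added above:

  %u = call i8 @llvm.uadd.sat.i8(i8 %a, i8 %b)
  %c = icmp eq i8 %u, 0      ; --> %o = or i8 %a, %b ; %c = icmp eq i8 %o, 0
  %s = call i8 @llvm.usub.sat.i8(i8 %a, i8 %b)
  %d = icmp eq i8 %s, 0      ; --> %d = icmp ule i8 %a, %b
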
@@ -3275,6 +3447,7 @@ foldICmpWithTruncSignExtendedVal(ICmpInst &I,
// we should move shifts to the same hand of 'and', i.e. rewrite as
// icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x)
// We are only interested in opposite logical shifts here.
+// One of the shifts can be truncated.
// If we can, we want to end up creating 'lshr' shift.
static Value *
foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
@@ -3284,55 +3457,215 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
return nullptr;
auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());
- auto m_AnyLShr = m_LShr(m_Value(), m_Value());
-
- // Look for an 'and' of two (opposite) logical shifts.
- // Pick the single-use shift as XShift.
- Value *XShift, *YShift;
- if (!match(I.getOperand(0),
- m_c_And(m_OneUse(m_CombineAnd(m_AnyLogicalShift, m_Value(XShift))),
- m_CombineAnd(m_AnyLogicalShift, m_Value(YShift)))))
+
+ // Look for an 'and' of two logical shifts, one of which may be truncated.
+  // We use m_TruncOrSelf() on the RHS to correctly handle the commutative case.
+ Instruction *XShift, *MaybeTruncation, *YShift;
+ if (!match(
+ I.getOperand(0),
+ m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
+ m_CombineAnd(m_TruncOrSelf(m_CombineAnd(
+ m_AnyLogicalShift, m_Instruction(YShift))),
+ m_Instruction(MaybeTruncation)))))
return nullptr;
- // If YShift is a single-use 'lshr', swap the shifts around.
- if (match(YShift, m_OneUse(m_AnyLShr)))
+  // We potentially looked past a 'trunc', but only while matching YShift,
+  // so YShift must have the widest type.
+  Instruction *WidestShift = YShift;
+  // XShift must therefore have the narrowest type, or the two types are
+  // identical if there was no truncation.
+  Instruction *NarrowestShift = XShift;
+
+ Type *WidestTy = WidestShift->getType();
+ assert(NarrowestShift->getType() == I.getOperand(0)->getType() &&
+ "We did not look past any shifts while matching XShift though.");
+ bool HadTrunc = WidestTy != I.getOperand(0)->getType();
+
+ // If YShift is a 'lshr', swap the shifts around.
+ if (match(YShift, m_LShr(m_Value(), m_Value())))
std::swap(XShift, YShift);
// The shifts must be in opposite directions.
- Instruction::BinaryOps XShiftOpcode =
- cast<BinaryOperator>(XShift)->getOpcode();
- if (XShiftOpcode == cast<BinaryOperator>(YShift)->getOpcode())
+ auto XShiftOpcode = XShift->getOpcode();
+ if (XShiftOpcode == YShift->getOpcode())
return nullptr; // Do not care about same-direction shifts here.
Value *X, *XShAmt, *Y, *YShAmt;
- match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt)));
- match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt)));
+ match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt))));
+ match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt))));
+
+  // If one of the values being shifted is a constant, then we will end up with
+  // and+icmp, and the [zext+]shift instrs will be constant-folded. If they are
+  // not, however, we will need to ensure that we won't increase the
+  // instruction count.
+ if (!isa<Constant>(X) && !isa<Constant>(Y)) {
+ // At least one of the hands of the 'and' should be one-use shift.
+ if (!match(I.getOperand(0),
+ m_c_And(m_OneUse(m_AnyLogicalShift), m_Value())))
+ return nullptr;
+ if (HadTrunc) {
+ // Due to the 'trunc', we will need to widen X. For that either the old
+ // 'trunc' or the shift amt in the non-truncated shift should be one-use.
+ if (!MaybeTruncation->hasOneUse() &&
+ !NarrowestShift->getOperand(1)->hasOneUse())
+ return nullptr;
+ }
+ }
+
+ // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case, let's bail out now.
+ if (XShAmt->getType() != YShAmt->getType())
+ return nullptr;
// Can we fold (XShAmt+YShAmt) ?
- Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, XShAmt, YShAmt,
- SQ.getWithInstruction(&I));
+ auto *NewShAmt = dyn_cast_or_null<Constant>(
+ SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
+ /*isNUW=*/false, SQ.getWithInstruction(&I)));
if (!NewShAmt)
return nullptr;
+ NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy);
+ unsigned WidestBitWidth = WidestTy->getScalarSizeInBits();
+
// Is the new shift amount smaller than the bit width?
// FIXME: could also rely on ConstantRange.
- unsigned BitWidth = X->getType()->getScalarSizeInBits();
- if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
- APInt(BitWidth, BitWidth))))
+ if (!match(NewShAmt,
+ m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
+ APInt(WidestBitWidth, WidestBitWidth))))
return nullptr;
- // All good, we can do this fold. The shift is the same that was for X.
+
+ // An extra legality check is needed if we had trunc-of-lshr.
+ if (HadTrunc && match(WidestShift, m_LShr(m_Value(), m_Value()))) {
+ auto CanFold = [NewShAmt, WidestBitWidth, NarrowestShift, SQ,
+ WidestShift]() {
+ // It isn't obvious whether it's worth it to analyze non-constants here.
+ // Also, let's basically give up on non-splat cases, pessimizing vectors.
+ // If *any* of these preconditions matches we can perform the fold.
+ Constant *NewShAmtSplat = NewShAmt->getType()->isVectorTy()
+ ? NewShAmt->getSplatValue()
+ : NewShAmt;
+      // If it's an edge-case shift (by 0 or by WidestBitWidth-1), we can fold.
+ if (NewShAmtSplat &&
+ (NewShAmtSplat->isNullValue() ||
+ NewShAmtSplat->getUniqueInteger() == WidestBitWidth - 1))
+ return true;
+ // We consider *min* leading zeros so a single outlier
+ // blocks the transform as opposed to allowing it.
+ if (auto *C = dyn_cast<Constant>(NarrowestShift->getOperand(0))) {
+ KnownBits Known = computeKnownBits(C, SQ.DL);
+ unsigned MinLeadZero = Known.countMinLeadingZeros();
+        // If the value being shifted has at most the lowest bit set, we can fold.
+ unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
+ if (MaxActiveBits <= 1)
+ return true;
+ // Precondition: NewShAmt u<= countLeadingZeros(C)
+ if (NewShAmtSplat && NewShAmtSplat->getUniqueInteger().ule(MinLeadZero))
+ return true;
+ }
+ if (auto *C = dyn_cast<Constant>(WidestShift->getOperand(0))) {
+ KnownBits Known = computeKnownBits(C, SQ.DL);
+ unsigned MinLeadZero = Known.countMinLeadingZeros();
+        // If the value being shifted has at most the lowest bit set, we can fold.
+ unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
+ if (MaxActiveBits <= 1)
+ return true;
+ // Precondition: ((WidestBitWidth-1)-NewShAmt) u<= countLeadingZeros(C)
+ if (NewShAmtSplat) {
+ APInt AdjNewShAmt =
+ (WidestBitWidth - 1) - NewShAmtSplat->getUniqueInteger();
+ if (AdjNewShAmt.ule(MinLeadZero))
+ return true;
+ }
+ }
+ return false; // Can't tell if it's ok.
+ };
+ if (!CanFold())
+ return nullptr;
+ }
+
+ // All good, we can do this fold.
+ X = Builder.CreateZExt(X, WidestTy);
+ Y = Builder.CreateZExt(Y, WidestTy);
+  // The shift is the same one that was used for X.
Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
? Builder.CreateLShr(X, NewShAmt)
: Builder.CreateShl(X, NewShAmt);
Value *T1 = Builder.CreateAnd(T0, Y);
return Builder.CreateICmp(I.getPredicate(), T1,
- Constant::getNullValue(X->getType()));
+ Constant::getNullValue(WidestTy));
+}
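
A hand-written example of the shape this reassociation now handles (values and shift amounts are illustrative); 3 + 7 = 10 u< 32, so the new shift amount is in range:

  %xs  = shl i32 %x, 3
  %ys  = lshr i32 %y, 7
  %and = and i32 %xs, %ys
  %r   = icmp ne i32 %and, 0
  ; --> %s = lshr i32 %y, 10 ; %a2 = and i32 %s, %x ; %r = icmp ne i32 %a2, 0
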
+
+/// Fold
+/// (-1 u/ x) u< y
+/// ((x * y) u/ x) != y
+/// to
+/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit
+/// Note that the comparison is commutative, and that an inverted predicate
+/// (u>=, ==) means that we are looking for the opposite answer.
+Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
+ ICmpInst::Predicate Pred;
+ Value *X, *Y;
+ Instruction *Mul;
+ bool NeedNegation;
+ // Look for: (-1 u/ x) u</u>= y
+ if (!I.isEquality() &&
+ match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
+ m_Value(Y)))) {
+ Mul = nullptr;
+    // Canonicalize as if Y were on the RHS.
+ if (I.getOperand(1) != Y)
+ Pred = I.getSwappedPredicate();
+
+ // Are we checking that overflow does not happen, or does happen?
+ switch (Pred) {
+ case ICmpInst::Predicate::ICMP_ULT:
+ NeedNegation = false;
+ break; // OK
+ case ICmpInst::Predicate::ICMP_UGE:
+ NeedNegation = true;
+ break; // OK
+ default:
+ return nullptr; // Wrong predicate.
+ }
+ } else // Look for: ((x * y) u/ x) !=/== y
+ if (I.isEquality() &&
+ match(&I, m_c_ICmp(Pred, m_Value(Y),
+ m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
+ m_Value(X)),
+ m_Instruction(Mul)),
+ m_Deferred(X)))))) {
+ NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ;
+ } else
+ return nullptr;
+
+ BuilderTy::InsertPointGuard Guard(Builder);
+ // If the pattern included (x * y), we'll want to insert new instructions
+ // right before that original multiplication so that we can replace it.
+ bool MulHadOtherUses = Mul && !Mul->hasOneUse();
+ if (MulHadOtherUses)
+ Builder.SetInsertPoint(Mul);
+
+ Function *F = Intrinsic::getDeclaration(
+ I.getModule(), Intrinsic::umul_with_overflow, X->getType());
+ CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul");
+
+  // If the multiplication had other uses, replace them with the multiplication
+  // result extracted from the with.overflow intrinsic, so that we don't leave
+  // behind a "duplicate" instruction.
+ if (MulHadOtherUses)
+ replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val"));
+
+ Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov");
+ if (NeedNegation) // This technically increases instruction count.
+ Res = Builder.CreateNot(Res, "umul.not.ov");
+
+ return Res;
}
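
A sketch of the (-1 u/ x) form being rewritten (names are illustrative); the ((x * y) u/ x) !=/== y form is handled analogously:

  %d  = udiv i8 -1, %x
  %ov = icmp ult i8 %d, %y
  ; --> %m  = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %x, i8 %y)
  ;     %ov = extractvalue { i8, i1 } %m, 1
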
/// Try to fold icmp (binop), X or icmp X, (binop).
/// TODO: A large part of this logic is duplicated in InstSimplify's
/// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
/// duplication.
-Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
+Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
// Special logic for binary operators.
@@ -3345,13 +3678,13 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
Value *X;
// Convert add-with-unsigned-overflow comparisons into a 'not' with compare.
- // (Op1 + X) <u Op1 --> ~Op1 <u X
- // Op0 >u (Op0 + X) --> X >u ~Op0
+ // (Op1 + X) u</u>= Op1 --> ~Op1 u</u>= X
if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) &&
- Pred == ICmpInst::ICMP_ULT)
+ (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
return new ICmpInst(Pred, Builder.CreateNot(Op1), X);
+ // Op0 u>/u<= (Op0 + X) --> X u>/u<= ~Op0
if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) &&
- Pred == ICmpInst::ICMP_UGT)
+ (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
return new ICmpInst(Pred, X, Builder.CreateNot(Op0));
bool NoOp0WrapProblem = false, NoOp1WrapProblem = false;
@@ -3378,21 +3711,21 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
D = BO1->getOperand(1);
}
- // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
+ // icmp (A+B), A -> icmp B, 0 for equalities or if there is no overflow.
+ // icmp (A+B), B -> icmp A, 0 for equalities or if there is no overflow.
if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
return new ICmpInst(Pred, A == Op1 ? B : A,
Constant::getNullValue(Op1->getType()));
- // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow.
+ // icmp C, (C+D) -> icmp 0, D for equalities or if there is no overflow.
+ // icmp D, (C+D) -> icmp 0, C for equalities or if there is no overflow.
if ((C == Op0 || D == Op0) && NoOp1WrapProblem)
return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()),
C == Op0 ? D : C);
- // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow.
+ // icmp (A+B), (A+D) -> icmp B, D for equalities or if there is no overflow.
if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem &&
- NoOp1WrapProblem &&
- // Try not to increase register pressure.
- BO0->hasOneUse() && BO1->hasOneUse()) {
+ NoOp1WrapProblem) {
// Determine Y and Z in the form icmp (X+Y), (X+Z).
Value *Y, *Z;
if (A == C) {
@@ -3416,39 +3749,39 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
return new ICmpInst(Pred, Y, Z);
}
- // icmp slt (X + -1), Y -> icmp sle X, Y
+ // icmp slt (A + -1), Op1 -> icmp sle A, Op1
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT &&
match(B, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SLE, A, Op1);
- // icmp sge (X + -1), Y -> icmp sgt X, Y
+ // icmp sge (A + -1), Op1 -> icmp sgt A, Op1
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE &&
match(B, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SGT, A, Op1);
- // icmp sle (X + 1), Y -> icmp slt X, Y
+ // icmp sle (A + 1), Op1 -> icmp slt A, Op1
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_SLT, A, Op1);
- // icmp sgt (X + 1), Y -> icmp sge X, Y
+ // icmp sgt (A + 1), Op1 -> icmp sge A, Op1
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_SGE, A, Op1);
- // icmp sgt X, (Y + -1) -> icmp sge X, Y
+ // icmp sgt Op0, (C + -1) -> icmp sge Op0, C
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT &&
match(D, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SGE, Op0, C);
- // icmp sle X, (Y + -1) -> icmp slt X, Y
+ // icmp sle Op0, (C + -1) -> icmp slt Op0, C
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE &&
match(D, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SLT, Op0, C);
- // icmp sge X, (Y + 1) -> icmp sgt X, Y
+ // icmp sge Op0, (C + 1) -> icmp sgt Op0, C
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_SGT, Op0, C);
- // icmp slt X, (Y + 1) -> icmp sle X, Y
+ // icmp slt Op0, (C + 1) -> icmp sle Op0, C
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
@@ -3456,33 +3789,33 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
// canonicalization from (X -nuw 1) to (X + -1) means that the combinations
// wouldn't happen even if they were implemented.
//
- // icmp ult (X - 1), Y -> icmp ule X, Y
- // icmp uge (X - 1), Y -> icmp ugt X, Y
- // icmp ugt X, (Y - 1) -> icmp uge X, Y
- // icmp ule X, (Y - 1) -> icmp ult X, Y
+ // icmp ult (A - 1), Op1 -> icmp ule A, Op1
+ // icmp uge (A - 1), Op1 -> icmp ugt A, Op1
+ // icmp ugt Op0, (C - 1) -> icmp uge Op0, C
+ // icmp ule Op0, (C - 1) -> icmp ult Op0, C
- // icmp ule (X + 1), Y -> icmp ult X, Y
+  // icmp ule (A + 1), Op1 -> icmp ult A, Op1
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);
- // icmp ugt (X + 1), Y -> icmp uge X, Y
+  // icmp ugt (A + 1), Op1 -> icmp uge A, Op1
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);
- // icmp uge X, (Y + 1) -> icmp ugt X, Y
+ // icmp uge Op0, (C + 1) -> icmp ugt Op0, C
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);
- // icmp ult X, (Y + 1) -> icmp ule X, Y
+ // icmp ult Op0, (C + 1) -> icmp ule Op0, C
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);
// if C1 has greater magnitude than C2:
- // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
+ // icmp (A + C1), (C + C2) -> icmp (A + C3), C
// s.t. C3 = C1 - C2
//
// if C2 has greater magnitude than C1:
- // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3)
+ // icmp (A + C1), (C + C2) -> icmp A, (C + C3)
// s.t. C3 = C2 - C1
if (A && C && NoOp0WrapProblem && NoOp1WrapProblem &&
(BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned())
@@ -3520,29 +3853,35 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
D = BO1->getOperand(1);
}
- // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow.
+ // icmp (A-B), A -> icmp 0, B for equalities or if there is no overflow.
if (A == Op1 && NoOp0WrapProblem)
return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B);
- // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow.
+ // icmp C, (C-D) -> icmp D, 0 for equalities or if there is no overflow.
if (C == Op0 && NoOp1WrapProblem)
return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType()));
- // (A - B) >u A --> A <u B
- if (A == Op1 && Pred == ICmpInst::ICMP_UGT)
- return new ICmpInst(ICmpInst::ICMP_ULT, A, B);
- // C <u (C - D) --> C <u D
- if (C == Op0 && Pred == ICmpInst::ICMP_ULT)
- return new ICmpInst(ICmpInst::ICMP_ULT, C, D);
-
- // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow.
- if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem &&
- // Try not to increase register pressure.
- BO0->hasOneUse() && BO1->hasOneUse())
+ // Convert sub-with-unsigned-overflow comparisons into a comparison of args.
+ // (A - B) u>/u<= A --> B u>/u<= A
+ if (A == Op1 && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
+ return new ICmpInst(Pred, B, A);
+ // C u</u>= (C - D) --> C u</u>= D
+ if (C == Op0 && (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
+ return new ICmpInst(Pred, C, D);
+ // (A - B) u>=/u< A --> B u>/u<= A iff B != 0
+ if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) &&
+ isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A);
+  // C u<=/u> (C - D) --> C u</u>= D iff D != 0
+ if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
+ isKnownNonZero(D, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D);
+
+ // icmp (A-B), (C-B) -> icmp A, C for equalities or if there is no overflow.
+ if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, A, C);
- // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow.
- if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem &&
- // Try not to increase register pressure.
- BO0->hasOneUse() && BO1->hasOneUse())
+
+ // icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow.
+ if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, D, B);
// icmp (0-X) < cst --> x > -cst
@@ -3677,6 +4016,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
}
}
+ if (Value *V = foldUnsignedMultiplicationOverflowCheck(I))
+ return replaceInstUsesWith(I, V);
+
if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
return replaceInstUsesWith(I, V);
@@ -3953,125 +4295,140 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
return nullptr;
}
-/// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so
-/// far.
-Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
- const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0));
- Value *LHSCIOp = LHSCI->getOperand(0);
- Type *SrcTy = LHSCIOp->getType();
- Type *DestTy = LHSCI->getType();
-
- // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
- // integer type is the same size as the pointer type.
- const auto& CompatibleSizes = [&](Type* SrcTy, Type* DestTy) -> bool {
- if (isa<VectorType>(SrcTy)) {
- SrcTy = cast<VectorType>(SrcTy)->getElementType();
- DestTy = cast<VectorType>(DestTy)->getElementType();
- }
- return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth();
- };
- if (LHSCI->getOpcode() == Instruction::PtrToInt &&
- CompatibleSizes(SrcTy, DestTy)) {
- Value *RHSOp = nullptr;
- if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) {
- Value *RHSCIOp = RHSC->getOperand(0);
- if (RHSCIOp->getType()->getPointerAddressSpace() ==
- LHSCIOp->getType()->getPointerAddressSpace()) {
- RHSOp = RHSC->getOperand(0);
- // If the pointer types don't match, insert a bitcast.
- if (LHSCIOp->getType() != RHSOp->getType())
- RHSOp = Builder.CreateBitCast(RHSOp, LHSCIOp->getType());
- }
- } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) {
- RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy);
- }
-
- if (RHSOp)
- return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSOp);
- }
-
- // The code below only handles extension cast instructions, so far.
- // Enforce this.
- if (LHSCI->getOpcode() != Instruction::ZExt &&
- LHSCI->getOpcode() != Instruction::SExt)
+static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp,
+ InstCombiner::BuilderTy &Builder) {
+ assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0");
+ auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0));
+ Value *X;
+ if (!match(CastOp0, m_ZExtOrSExt(m_Value(X))))
return nullptr;
- bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt;
- bool isSignedCmp = ICmp.isSigned();
-
- if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) {
- // Not an extension from the same type?
- Value *RHSCIOp = CI->getOperand(0);
- if (RHSCIOp->getType() != LHSCIOp->getType())
- return nullptr;
-
+ bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt;
+ bool IsSignedCmp = ICmp.isSigned();
+ if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) {
// If the signedness of the two casts doesn't agree (i.e. one is a sext
// and the other is a zext), then we can't handle this.
- if (CI->getOpcode() != LHSCI->getOpcode())
+ // TODO: This is too strict. We can handle some predicates (equality?).
+ if (CastOp0->getOpcode() != CastOp1->getOpcode())
return nullptr;
- // Deal with equality cases early.
+ // Not an extension from the same type?
+ Value *Y = CastOp1->getOperand(0);
+ Type *XTy = X->getType(), *YTy = Y->getType();
+ if (XTy != YTy) {
+ // One of the casts must have one use because we are creating a new cast.
+ if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse())
+ return nullptr;
+ // Extend the narrower operand to the type of the wider operand.
+ if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits())
+ X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy);
+ else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits())
+ Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy);
+ else
+ return nullptr;
+ }
+
+ // (zext X) == (zext Y) --> X == Y
+ // (sext X) == (sext Y) --> X == Y
if (ICmp.isEquality())
- return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp);
+ return new ICmpInst(ICmp.getPredicate(), X, Y);
// A signed comparison of sign extended values simplifies into a
// signed comparison.
- if (isSignedCmp && isSignedExt)
- return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp);
+ if (IsSignedCmp && IsSignedExt)
+ return new ICmpInst(ICmp.getPredicate(), X, Y);
// The other three cases all fold into an unsigned comparison.
- return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, RHSCIOp);
+ return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y);
}
- // If we aren't dealing with a constant on the RHS, exit early.
+ // Below here, we are only folding a compare with constant.
auto *C = dyn_cast<Constant>(ICmp.getOperand(1));
if (!C)
return nullptr;
// Compute the constant that would happen if we truncated to SrcTy then
// re-extended to DestTy.
+ Type *SrcTy = CastOp0->getSrcTy();
+ Type *DestTy = CastOp0->getDestTy();
Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy);
- Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy);
+ Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy);
// If the re-extended constant didn't change...
if (Res2 == C) {
- // Deal with equality cases early.
if (ICmp.isEquality())
- return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1);
+ return new ICmpInst(ICmp.getPredicate(), X, Res1);
// A signed comparison of sign extended values simplifies into a
// signed comparison.
- if (isSignedExt && isSignedCmp)
- return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1);
+ if (IsSignedExt && IsSignedCmp)
+ return new ICmpInst(ICmp.getPredicate(), X, Res1);
// The other three cases all fold into an unsigned comparison.
- return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, Res1);
+ return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1);
}
// The re-extended constant changed, partly changed (in the case of a vector),
// or could not be determined to be equal (in the case of a constant
// expression), so the constant cannot be represented in the shorter type.
- // Consequently, we cannot emit a simple comparison.
// All the cases that fold to true or false will have already been handled
// by SimplifyICmpInst, so only deal with the tricky case.
+ if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C))
+ return nullptr;
+
+ // Is source op positive?
+ // icmp ult (sext X), C --> icmp sgt X, -1
+ if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
+ return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy));
+
+ // Is source op negative?
+ // icmp ugt (sext X), C --> icmp slt X, 0
+ assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
+ return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy));
+}
- if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C))
+/// Handle icmp (cast x), (cast or constant).
+Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) {
+ auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
+ if (!CastOp0)
+ return nullptr;
+ if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1)))
return nullptr;
- // Evaluate the comparison for LT (we invert for GT below). LE and GE cases
- // should have been folded away previously and not enter in here.
+ Value *Op0Src = CastOp0->getOperand(0);
+ Type *SrcTy = CastOp0->getSrcTy();
+ Type *DestTy = CastOp0->getDestTy();
- // We're performing an unsigned comp with a sign extended value.
- // This is true if the input is >= 0. [aka >s -1]
- Constant *NegOne = Constant::getAllOnesValue(SrcTy);
- Value *Result = Builder.CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName());
+ // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
+ // integer type is the same size as the pointer type.
+ auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) {
+ if (isa<VectorType>(SrcTy)) {
+ SrcTy = cast<VectorType>(SrcTy)->getElementType();
+ DestTy = cast<VectorType>(DestTy)->getElementType();
+ }
+ return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth();
+ };
+ if (CastOp0->getOpcode() == Instruction::PtrToInt &&
+ CompatibleSizes(SrcTy, DestTy)) {
+ Value *NewOp1 = nullptr;
+ if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) {
+ Value *PtrSrc = PtrToIntOp1->getOperand(0);
+ if (PtrSrc->getType()->getPointerAddressSpace() ==
+ Op0Src->getType()->getPointerAddressSpace()) {
+ NewOp1 = PtrToIntOp1->getOperand(0);
+ // If the pointer types don't match, insert a bitcast.
+ if (Op0Src->getType() != NewOp1->getType())
+ NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType());
+ }
+ } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) {
+ NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy);
+ }
- // Finally, return the value computed.
- if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
- return replaceInstUsesWith(ICmp, Result);
+ if (NewOp1)
+ return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
+ }
- assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
- return BinaryOperator::CreateNot(Result);
+ return foldICmpWithZextOrSext(ICmp, Builder);
}
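
As a rough example (not from the patch), two zero-extended operands compared with a signed predicate reduce to an unsigned compare on the narrow type:

  %xz = zext i8 %x to i32
  %yz = zext i8 %y to i32
  %c  = icmp slt i32 %xz, %yz
  ; --> %c = icmp ult i8 %x, %y
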
static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
@@ -4791,41 +5148,35 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
return nullptr;
}
-/// If we have an icmp le or icmp ge instruction with a constant operand, turn
-/// it into the appropriate icmp lt or icmp gt instruction. This transform
-/// allows them to be folded in visitICmpInst.
-static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
- ICmpInst::Predicate Pred = I.getPredicate();
- if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGE &&
- Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_UGE)
- return nullptr;
+llvm::Optional<std::pair<CmpInst::Predicate, Constant *>>
+llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
+ Constant *C) {
+ assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
+ "Only for relational integer predicates.");
- Value *Op0 = I.getOperand(0);
- Value *Op1 = I.getOperand(1);
- auto *Op1C = dyn_cast<Constant>(Op1);
- if (!Op1C)
- return nullptr;
+ Type *Type = C->getType();
+ bool IsSigned = ICmpInst::isSigned(Pred);
+
+ CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
+ bool WillIncrement =
+ UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
- // Check if the constant operand can be safely incremented/decremented without
- // overflowing/underflowing. For scalars, SimplifyICmpInst has already handled
- // the edge cases for us, so we just assert on them. For vectors, we must
- // handle the edge cases.
- Type *Op1Type = Op1->getType();
- bool IsSigned = I.isSigned();
- bool IsLE = (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE);
- auto *CI = dyn_cast<ConstantInt>(Op1C);
- if (CI) {
- // A <= MAX -> TRUE ; A >= MIN -> TRUE
- assert(IsLE ? !CI->isMaxValue(IsSigned) : !CI->isMinValue(IsSigned));
- } else if (Op1Type->isVectorTy()) {
- // TODO? If the edge cases for vectors were guaranteed to be handled as they
- // are for scalar, we could remove the min/max checks. However, to do that,
- // we would have to use insertelement/shufflevector to replace edge values.
- unsigned NumElts = Op1Type->getVectorNumElements();
+ // Check if the constant operand can be safely incremented/decremented
+ // without overflowing/underflowing.
+ auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
+ return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
+ };
+
+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
+ // Bail out if the constant can't be safely incremented/decremented.
+ if (!ConstantIsOk(CI))
+ return llvm::None;
+ } else if (Type->isVectorTy()) {
+ unsigned NumElts = Type->getVectorNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = Op1C->getAggregateElement(i);
+ Constant *Elt = C->getAggregateElement(i);
if (!Elt)
- return nullptr;
+ return llvm::None;
if (isa<UndefValue>(Elt))
continue;
@@ -4833,20 +5184,43 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
// Bail out if we can't determine if this constant is min/max or if we
// know that this constant is min/max.
auto *CI = dyn_cast<ConstantInt>(Elt);
- if (!CI || (IsLE ? CI->isMaxValue(IsSigned) : CI->isMinValue(IsSigned)))
- return nullptr;
+ if (!CI || !ConstantIsOk(CI))
+ return llvm::None;
}
} else {
// ConstantExpr?
- return nullptr;
+ return llvm::None;
}
- // Increment or decrement the constant and set the new comparison predicate:
- // ULE -> ULT ; UGE -> UGT ; SLE -> SLT ; SGE -> SGT
- Constant *OneOrNegOne = ConstantInt::get(Op1Type, IsLE ? 1 : -1, true);
- CmpInst::Predicate NewPred = IsLE ? ICmpInst::ICMP_ULT: ICmpInst::ICMP_UGT;
- NewPred = IsSigned ? ICmpInst::getSignedPredicate(NewPred) : NewPred;
- return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne));
+ CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
+
+ // Increment or decrement the constant.
+ Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
+ Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
+
+ return std::make_pair(NewPred, NewC);
+}
+
+/// If we have an icmp le or icmp ge instruction with a constant operand, turn
+/// it into the appropriate icmp lt or icmp gt instruction. This transform
+/// allows them to be folded in visitICmpInst.
+static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
+ ICmpInst::Predicate Pred = I.getPredicate();
+ if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) ||
+ isCanonicalPredicate(Pred))
+ return nullptr;
+
+ Value *Op0 = I.getOperand(0);
+ Value *Op1 = I.getOperand(1);
+ auto *Op1C = dyn_cast<Constant>(Op1);
+ if (!Op1C)
+ return nullptr;
+
+ auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
+ if (!FlippedStrictness)
+ return nullptr;
+
+ return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second);
}
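
For example (hand-picked constants), the canonicalization nudges the constant and switches to the strict predicate:

  %a = icmp sle i32 %x, 9      ; --> icmp slt i32 %x, 10
  %b = icmp uge i32 %x, 5      ; --> icmp ugt i32 %x, 4
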
/// Integer compare with boolean values can always be turned into bitwise ops.
@@ -5002,6 +5376,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
bool Changed = false;
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
unsigned Op0Cplxity = getComplexity(Op0);
unsigned Op1Cplxity = getComplexity(Op1);
@@ -5016,8 +5391,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
Changed = true;
}
- if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1,
- SQ.getWithInstruction(&I)))
+ if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, Q))
return replaceInstUsesWith(I, V);
// Comparing -val or val with non-zero is the same as just comparing val
@@ -5050,6 +5424,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpWithDominatingICmp(I))
return Res;
+ if (Instruction *Res = foldICmpBinOp(I, Q))
+ return Res;
+
if (Instruction *Res = foldICmpUsingKnownBits(I))
return Res;
@@ -5098,6 +5475,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpInstWithConstant(I))
return Res;
+ // Try to match comparison as a sign bit test. Intentionally do this after
+ // foldICmpInstWithConstant() to potentially let other folds to happen first.
+ if (Instruction *New = foldSignBitTest(I))
+ return New;
+
if (Instruction *Res = foldICmpInstWithConstantNotInt(I))
return Res;
@@ -5124,20 +5506,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpBitCast(I, Builder))
return Res;
- if (isa<CastInst>(Op0)) {
- // Handle the special case of: icmp (cast bool to X), <cst>
- // This comes up when you have code like
- // int X = A < B;
- // if (X) ...
- // For generality, we handle any zero-extension of any operand comparison
- // with a constant or another cast from the same type.
- if (isa<Constant>(Op1) || isa<CastInst>(Op1))
- if (Instruction *R = foldICmpWithCastAndCast(I))
- return R;
- }
-
- if (Instruction *Res = foldICmpBinOp(I))
- return Res;
+ if (Instruction *R = foldICmpWithCastOp(I))
+ return R;
if (Instruction *Res = foldICmpWithMinMax(I))
return Res;
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index 434b0d591215..1dbc06d92e7a 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -113,6 +113,48 @@ static inline bool isCanonicalPredicate(CmpInst::Predicate Pred) {
}
}
+/// Given an exploded icmp instruction, return true if the comparison only
+/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned to
+/// true if the result of the comparison is true when the input value is
+/// negative (i.e. its sign bit is set).
+inline bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
+ bool &TrueIfSigned) {
+ switch (Pred) {
+ case ICmpInst::ICMP_SLT: // True if LHS s< 0
+ TrueIfSigned = true;
+ return RHS.isNullValue();
+ case ICmpInst::ICMP_SLE: // True if LHS s<= -1
+ TrueIfSigned = true;
+ return RHS.isAllOnesValue();
+ case ICmpInst::ICMP_SGT: // True if LHS s> -1
+ TrueIfSigned = false;
+ return RHS.isAllOnesValue();
+ case ICmpInst::ICMP_SGE: // True if LHS s>= 0
+ TrueIfSigned = false;
+ return RHS.isNullValue();
+ case ICmpInst::ICMP_UGT:
+ // True if LHS u> RHS and RHS == sign-bit-mask - 1
+ TrueIfSigned = true;
+ return RHS.isMaxSignedValue();
+ case ICmpInst::ICMP_UGE:
+ // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
+ TrueIfSigned = true;
+ return RHS.isMinSignedValue();
+ case ICmpInst::ICMP_ULT:
+ // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
+ TrueIfSigned = false;
+ return RHS.isMinSignedValue();
+ case ICmpInst::ICMP_ULE:
+ // True if LHS u<= RHS and RHS == sign-bit-mask - 1
+ TrueIfSigned = false;
+ return RHS.isMaxSignedValue();
+ default:
+ return false;
+ }
+}
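
For illustration (i32 constants chosen by hand), both of the following are pure sign-bit tests and would be reported with TrueIfSigned == true:

  %a = icmp slt i32 %x, 0
  %b = icmp ugt i32 %x, 2147483647   ; RHS == signed max
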
+
+llvm::Optional<std::pair<CmpInst::Predicate, Constant *>>
+getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C);
+
/// Return the source operand of a potentially bitcasted value while optionally
/// checking if it has one use. If there is no bitcast or the one use check is
/// not met, return the input value itself.
@@ -139,32 +181,17 @@ static inline Constant *SubOne(Constant *C) {
/// This happens in cases where the ~ can be eliminated. If WillInvertAllUses
/// is true, work under the assumption that the caller intends to remove all
/// uses of V and only keep uses of ~V.
-static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) {
+///
+/// See also: canFreelyInvertAllUsersOf()
+static inline bool isFreeToInvert(Value *V, bool WillInvertAllUses) {
// ~(~(X)) -> X.
if (match(V, m_Not(m_Value())))
return true;
// Constants can be considered to be not'ed values.
- if (isa<ConstantInt>(V))
+ if (match(V, m_AnyIntegralConstant()))
return true;
- // A vector of constant integers can be inverted easily.
- if (V->getType()->isVectorTy() && isa<Constant>(V)) {
- unsigned NumElts = V->getType()->getVectorNumElements();
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
- if (!Elt)
- return false;
-
- if (isa<UndefValue>(Elt))
- continue;
-
- if (!isa<ConstantInt>(Elt))
- return false;
- }
- return true;
- }
-
// Compares can be inverted if all of their uses are being modified to use the
// ~V.
if (isa<CmpInst>(V))
@@ -185,6 +212,32 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) {
return false;
}
+/// Given i1 V, can every user of V be freely adapted if V is changed to !V ?
+///
+/// See also: isFreeToInvert()
+static inline bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) {
+ // Look at every user of V.
+ for (User *U : V->users()) {
+ if (U == IgnoredUser)
+ continue; // Don't consider this user.
+
+ auto *I = cast<Instruction>(U);
+ switch (I->getOpcode()) {
+ case Instruction::Select:
+ case Instruction::Br:
+ break; // Free to invert by swapping true/false values/destinations.
+ case Instruction::Xor: // Can invert 'xor' if it's a 'not', by ignoring it.
+ if (!match(I, m_Not(m_Value())))
+ return false; // Not a 'not'.
+ break;
+ default:
+ return false; // Don't know, likely not freely invertible.
+ }
+ // So far all users were free to invert...
+ }
+ return true; // Can freely invert all users!
+}
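
A small illustrative case: %c below is freely invertible because its only users are a select and a branch, both of which can simply swap their arms:

  %c = icmp eq i32 %x, 0
  %s = select i1 %c, i32 %a, i32 %b
  br i1 %c, label %bb1, label %bb2
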
+
/// Some binary operators require special handling to avoid poison and undefined
/// behavior. If a constant vector has undef elements, replace those undefs with
/// identity constants if possible because those are always safe to execute.
@@ -337,6 +390,13 @@ public:
Instruction *visitOr(BinaryOperator &I);
Instruction *visitXor(BinaryOperator &I);
Instruction *visitShl(BinaryOperator &I);
+ Value *reassociateShiftAmtsOfTwoSameDirectionShifts(
+ BinaryOperator *Sh0, const SimplifyQuery &SQ,
+ bool AnalyzeForSignBitExtraction = false);
+ Instruction *canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
+ BinaryOperator &I);
+ Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
+ BinaryOperator &OldAShr);
Instruction *visitAShr(BinaryOperator &I);
Instruction *visitLShr(BinaryOperator &I);
Instruction *commonShiftTransforms(BinaryOperator &I);
@@ -541,6 +601,7 @@ private:
Instruction *narrowMathIfNoOverflow(BinaryOperator &I);
Instruction *narrowRotate(TruncInst &Trunc);
Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
+ Instruction *matchSAddSubSat(SelectInst &MinMax1);
/// Determine if a pair of casts can be replaced by a single cast.
///
@@ -557,7 +618,7 @@ private:
Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
- Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+ Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &I);
/// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp).
/// NOTE: Unlike most of instcombine, this returns a Value which should
@@ -725,7 +786,7 @@ public:
Value *LHS, Value *RHS, Instruction *CxtI) const;
/// Maximum size of array considered when transforming.
- uint64_t MaxArraySizeForCombine;
+ uint64_t MaxArraySizeForCombine = 0;
private:
/// Performs a few simplifications for operators which are associative
@@ -798,7 +859,8 @@ private:
int DmaskIdx = -1);
Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
- APInt &UndefElts, unsigned Depth = 0);
+ APInt &UndefElts, unsigned Depth = 0,
+ bool AllowMultipleUsers = false);
/// Canonicalize the position of binops relative to shufflevector.
Instruction *foldVectorBinop(BinaryOperator &Inst);
@@ -847,17 +909,21 @@ private:
Constant *RHSC);
Instruction *foldICmpAddOpConst(Value *X, const APInt &C,
ICmpInst::Predicate Pred);
- Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
+ Instruction *foldICmpWithCastOp(ICmpInst &ICI);
Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
Instruction *foldICmpWithConstant(ICmpInst &Cmp);
Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
- Instruction *foldICmpBinOp(ICmpInst &Cmp);
+ Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
Instruction *foldICmpEquality(ICmpInst &Cmp);
+ Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
+ Instruction *foldSignBitTest(ICmpInst &I);
Instruction *foldICmpWithZero(ICmpInst &Cmp);
+ Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp);
+
Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select,
ConstantInt *C);
Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc,
@@ -874,6 +940,8 @@ private:
const APInt &C);
Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
const APInt &C);
+  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *SRem,
+ const APInt &C);
Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
const APInt &C);
Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div,
diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 054fb7da09a2..3a0e05832fcb 100644
--- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -175,7 +175,7 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI,
uint64_t AllocaSize = DL.getTypeStoreSize(AI->getAllocatedType());
if (!AllocaSize)
return false;
- return isDereferenceableAndAlignedPointer(V, AI->getAlignment(),
+ return isDereferenceableAndAlignedPointer(V, Align(AI->getAlignment()),
APInt(64, AllocaSize), DL);
}
@@ -197,7 +197,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) {
if (C->getValue().getActiveBits() <= 64) {
Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue());
AllocaInst *New = IC.Builder.CreateAlloca(NewTy, nullptr, AI.getName());
- New->setAlignment(AI.getAlignment());
+ New->setAlignment(MaybeAlign(AI.getAlignment()));
// Scan to the end of the allocation instructions, to skip over a block of
// allocas if possible...also skip interleaved debug info
@@ -345,7 +345,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
if (AI.getAllocatedType()->isSized()) {
// If the alignment is 0 (unspecified), assign it the preferred alignment.
if (AI.getAlignment() == 0)
- AI.setAlignment(DL.getPrefTypeAlignment(AI.getAllocatedType()));
+ AI.setAlignment(
+ MaybeAlign(DL.getPrefTypeAlignment(AI.getAllocatedType())));
// Move all alloca's of zero byte objects to the entry block and merge them
// together. Note that we only do this for alloca's, because malloc should
@@ -377,12 +378,12 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
// assign it the preferred alignment.
if (EntryAI->getAlignment() == 0)
EntryAI->setAlignment(
- DL.getPrefTypeAlignment(EntryAI->getAllocatedType()));
+ MaybeAlign(DL.getPrefTypeAlignment(EntryAI->getAllocatedType())));
// Replace this zero-sized alloca with the one at the start of the entry
// block after ensuring that the address will be aligned enough for both
// types.
- unsigned MaxAlign = std::max(EntryAI->getAlignment(),
- AI.getAlignment());
+ const MaybeAlign MaxAlign(
+ std::max(EntryAI->getAlignment(), AI.getAlignment()));
EntryAI->setAlignment(MaxAlign);
if (AI.getType() != EntryAI->getType())
return new BitCastInst(EntryAI, AI.getType());
@@ -455,9 +456,6 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT
Value *Ptr = LI.getPointerOperand();
unsigned AS = LI.getPointerAddressSpace();
- SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
- LI.getAllMetadata(MD);
-
Value *NewPtr = nullptr;
if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) &&
NewPtr->getType()->getPointerElementType() == NewTy &&
@@ -467,48 +465,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT
LoadInst *NewLoad = IC.Builder.CreateAlignedLoad(
NewTy, NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix);
NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
- MDBuilder MDB(NewLoad->getContext());
- for (const auto &MDPair : MD) {
- unsigned ID = MDPair.first;
- MDNode *N = MDPair.second;
- // Note, essentially every kind of metadata should be preserved here! This
- // routine is supposed to clone a load instruction changing *only its type*.
- // The only metadata it makes sense to drop is metadata which is invalidated
- // when the pointer type changes. This should essentially never be the case
- // in LLVM, but we explicitly switch over only known metadata to be
- // conservatively correct. If you are adding metadata to LLVM which pertains
- // to loads, you almost certainly want to add it here.
- switch (ID) {
- case LLVMContext::MD_dbg:
- case LLVMContext::MD_tbaa:
- case LLVMContext::MD_prof:
- case LLVMContext::MD_fpmath:
- case LLVMContext::MD_tbaa_struct:
- case LLVMContext::MD_invariant_load:
- case LLVMContext::MD_alias_scope:
- case LLVMContext::MD_noalias:
- case LLVMContext::MD_nontemporal:
- case LLVMContext::MD_mem_parallel_loop_access:
- case LLVMContext::MD_access_group:
- // All of these directly apply.
- NewLoad->setMetadata(ID, N);
- break;
-
- case LLVMContext::MD_nonnull:
- copyNonnullMetadata(LI, N, *NewLoad);
- break;
- case LLVMContext::MD_align:
- case LLVMContext::MD_dereferenceable:
- case LLVMContext::MD_dereferenceable_or_null:
- // These only directly apply if the new type is also a pointer.
- if (NewTy->isPointerTy())
- NewLoad->setMetadata(ID, N);
- break;
- case LLVMContext::MD_range:
- copyRangeMetadata(IC.getDataLayout(), LI, N, *NewLoad);
- break;
- }
- }
+ copyMetadataForLoad(*NewLoad, LI);
return NewLoad;
}
@@ -1004,9 +961,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
LoadAlign != 0 ? LoadAlign : DL.getABITypeAlignment(LI.getType());
if (KnownAlign > EffectiveLoadAlign)
- LI.setAlignment(KnownAlign);
+ LI.setAlignment(MaybeAlign(KnownAlign));
else if (LoadAlign == 0)
- LI.setAlignment(EffectiveLoadAlign);
+ LI.setAlignment(MaybeAlign(EffectiveLoadAlign));
// Replace GEP indices if possible.
if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) {
@@ -1063,11 +1020,11 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
//
if (SelectInst *SI = dyn_cast<SelectInst>(Op)) {
// load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2).
- unsigned Align = LI.getAlignment();
- if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), Align,
- DL, SI) &&
- isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), Align,
- DL, SI)) {
+ const MaybeAlign Alignment(LI.getAlignment());
+ if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(),
+ Alignment, DL, SI) &&
+ isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(),
+ Alignment, DL, SI)) {
LoadInst *V1 =
Builder.CreateLoad(LI.getType(), SI->getOperand(1),
SI->getOperand(1)->getName() + ".val");
@@ -1075,9 +1032,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
Builder.CreateLoad(LI.getType(), SI->getOperand(2),
SI->getOperand(2)->getName() + ".val");
assert(LI.isUnordered() && "implied by above");
- V1->setAlignment(Align);
+ V1->setAlignment(Alignment);
V1->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
- V2->setAlignment(Align);
+ V2->setAlignment(Alignment);
V2->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
return SelectInst::Create(SI->getCondition(), V1, V2);
}
@@ -1399,15 +1356,15 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
return eraseInstFromFunction(SI);
// Attempt to improve the alignment.
- unsigned KnownAlign = getOrEnforceKnownAlignment(
- Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT);
- unsigned StoreAlign = SI.getAlignment();
- unsigned EffectiveStoreAlign =
- StoreAlign != 0 ? StoreAlign : DL.getABITypeAlignment(Val->getType());
+ const Align KnownAlign = Align(getOrEnforceKnownAlignment(
+ Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT));
+ const MaybeAlign StoreAlign = MaybeAlign(SI.getAlignment());
+ const Align EffectiveStoreAlign =
+ StoreAlign ? *StoreAlign : Align(DL.getABITypeAlignment(Val->getType()));
if (KnownAlign > EffectiveStoreAlign)
SI.setAlignment(KnownAlign);
- else if (StoreAlign == 0)
+ else if (!StoreAlign)
SI.setAlignment(EffectiveStoreAlign);
// Try to canonicalize the stored type.
@@ -1622,8 +1579,8 @@ bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) {
// Advance to a place where it is safe to insert the new store and insert it.
BBI = DestBB->getFirstInsertionPt();
- StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1),
- SI.isVolatile(), SI.getAlignment(),
+ StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(),
+ MaybeAlign(SI.getAlignment()),
SI.getOrdering(), SI.getSyncScopeID());
InsertNewInstBefore(NewSI, *BBI);
NewSI->setDebugLoc(MergedLoc);
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index cc753ce05313..0b9128a9f5a1 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -124,6 +124,50 @@ static Constant *getLogBase2(Type *Ty, Constant *C) {
return ConstantVector::get(Elts);
}
+// TODO: This is a specific form of a much more general pattern.
+// We could detect a select with any binop identity constant, or we
+// could use SimplifyBinOp to see if either arm of the select reduces.
+// But that needs to be done carefully and/or while removing potential
+// reverse canonicalizations as in InstCombiner::foldSelectIntoOp().
+static Value *foldMulSelectToNegate(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *Cond, *OtherOp;
+
+ // mul (select Cond, 1, -1), OtherOp --> select Cond, OtherOp, -OtherOp
+ // mul OtherOp, (select Cond, 1, -1) --> select Cond, OtherOp, -OtherOp
+ if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())),
+ m_Value(OtherOp))))
+ return Builder.CreateSelect(Cond, OtherOp, Builder.CreateNeg(OtherOp));
+
+ // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp
+ // mul OtherOp, (select Cond, -1, 1) --> select Cond, -OtherOp, OtherOp
+ if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())),
+ m_Value(OtherOp))))
+ return Builder.CreateSelect(Cond, Builder.CreateNeg(OtherOp), OtherOp);
+
+ // fmul (select Cond, 1.0, -1.0), OtherOp --> select Cond, OtherOp, -OtherOp
+ // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp
+ if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(1.0),
+ m_SpecificFP(-1.0))),
+ m_Value(OtherOp)))) {
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(I.getFastMathFlags());
+ return Builder.CreateSelect(Cond, OtherOp, Builder.CreateFNeg(OtherOp));
+ }
+
+ // fmul (select Cond, -1.0, 1.0), OtherOp --> select Cond, -OtherOp, OtherOp
+ // fmul OtherOp, (select Cond, -1.0, 1.0) --> select Cond, -OtherOp, OtherOp
+ if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(-1.0),
+ m_SpecificFP(1.0))),
+ m_Value(OtherOp)))) {
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(I.getFastMathFlags());
+ return Builder.CreateSelect(Cond, Builder.CreateFNeg(OtherOp), OtherOp);
+ }
+
+ return nullptr;
+}
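
An illustrative instance of the integer case (the FP cases behave the same, under the instruction's fast-math flags):

  %s = select i1 %cond, i32 1, i32 -1
  %m = mul i32 %s, %x
  ; --> %neg = sub i32 0, %x
  ;     %m   = select i1 %cond, i32 %x, i32 %neg
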
+
Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
@@ -213,6 +257,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I))
return FoldedMul;
+ if (Value *FoldedMul = foldMulSelectToNegate(I, Builder))
+ return replaceInstUsesWith(I, FoldedMul);
+
// Simplify mul instructions with a constant RHS.
if (isa<Constant>(Op1)) {
// Canonicalize (X+C1)*CI -> X*CI+C1*CI.
@@ -358,6 +405,9 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I))
return FoldedMul;
+ if (Value *FoldedMul = foldMulSelectToNegate(I, Builder))
+ return replaceInstUsesWith(I, FoldedMul);
+
// X * -1.0 --> -X
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (match(Op1, m_SpecificFP(-1.0)))
@@ -373,16 +423,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C)))
return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
- // Sink negation: -X * Y --> -(X * Y)
- // But don't transform constant expressions because there's an inverse fold.
- if (match(Op0, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op0))
- return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I);
-
- // Sink negation: Y * -X --> -(X * Y)
- // But don't transform constant expressions because there's an inverse fold.
- if (match(Op1, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op1))
- return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I);
-
// fabs(X) * fabs(X) -> X * X
if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))))
return BinaryOperator::CreateFMulFMF(X, X, &I);
@@ -1211,8 +1251,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
!IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) &&
match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X)));
- if ((IsTan || IsCot) && hasUnaryFloatFn(&TLI, I.getType(), LibFunc_tan,
- LibFunc_tanf, LibFunc_tanl)) {
+ if ((IsTan || IsCot) &&
+ hasFloatFn(&TLI, I.getType(), LibFunc_tan, LibFunc_tanf, LibFunc_tanl)) {
IRBuilder<> B(&I);
IRBuilder<>::FastMathFlagGuard FMFGuard(B);
B.setFastMathFlags(I.getFastMathFlags());
@@ -1244,6 +1284,17 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
return &I;
}
+ // X / fabs(X) -> copysign(1.0, X)
+ // fabs(X) / X -> copysign(1.0, X)
+ if (I.hasNoNaNs() && I.hasNoInfs() &&
+ (match(&I,
+ m_FDiv(m_Value(X), m_Intrinsic<Intrinsic::fabs>(m_Deferred(X)))) ||
+ match(&I, m_FDiv(m_Intrinsic<Intrinsic::fabs>(m_Value(X)),
+ m_Deferred(X))))) {
+ Value *V = Builder.CreateBinaryIntrinsic(
+ Intrinsic::copysign, ConstantFP::get(I.getType(), 1.0), X, &I);
+ return replaceInstUsesWith(I, V);
+ }
return nullptr;
}
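
A minimal sketch of the new fabs fold; the nnan/ninf fast-math flags are required by the guard above:

  %fx = call float @llvm.fabs.f32(float %x)
  %r  = fdiv nnan ninf float %x, %fx
  ; --> %r = call float @llvm.copysign.f32(float 1.0, float %x)
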
@@ -1309,6 +1360,8 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
+    // This may increase instruction count; we don't enforce that the divisor
+    // (Op1) is a constant.
Constant *N1 = Constant::getAllOnesValue(Ty);
Value *Add = Builder.CreateAdd(Op1, N1);
return BinaryOperator::CreateAnd(Op0, Add);
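
Illustrative IR for the power-of-two case; as the new comment notes, the divisor need not be a constant:

  %y = shl i32 1, %n           ; known to be a power of two (or zero)
  %r = urem i32 %x, %y
  ; --> %m = add i32 %y, -1
  ;     %r = and i32 %x, %m
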
diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 5820ab726637..e0376b7582f3 100644
--- a/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -542,7 +542,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) {
// visitLoadInst will propagate an alignment onto the load when TD is around,
// and if TD isn't around, we can't handle the mixed case.
bool isVolatile = FirstLI->isVolatile();
- unsigned LoadAlignment = FirstLI->getAlignment();
+ MaybeAlign LoadAlignment(FirstLI->getAlignment());
unsigned LoadAddrSpace = FirstLI->getPointerAddressSpace();
// We can't sink the load if the loaded value could be modified between the
@@ -574,10 +574,10 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) {
// If some of the loads have an alignment specified but not all of them,
// we can't do the transformation.
- if ((LoadAlignment != 0) != (LI->getAlignment() != 0))
+ if ((LoadAlignment.hasValue()) != (LI->getAlignment() != 0))
return nullptr;
- LoadAlignment = std::min(LoadAlignment, LI->getAlignment());
+ LoadAlignment = std::min(LoadAlignment, MaybeAlign(LI->getAlignment()));
// If the PHI is of volatile loads and the load block has multiple
// successors, sinking it would remove a load of the volatile value from
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index aefaf5af1750..9fc871e49b30 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -785,6 +785,41 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}
+/// Fold the following code sequence:
+/// \code
+///   int a = ctlz(x & -x);
+///   x ? 31 - a : a;
+/// \endcode
+///
+/// into:
+/// cttz(x)
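+///
+/// An illustrative i32 example (value names are arbitrary):
+/// \code
+///   %neg = sub i32 0, %x
+///   %m   = and i32 %x, %neg
+///   %a   = call i32 @llvm.ctlz.i32(i32 %m, i1 false)
+///   %inv = xor i32 %a, 31
+///   %cmp = icmp eq i32 %x, 0
+///   %r   = select i1 %cmp, i32 %a, i32 %inv
+/// \endcode
+/// becomes
+/// \code
+///   %r = call i32 @llvm.cttz.i32(i32 %x, i1 false)
+/// \endcode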
+static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal,
+ Value *FalseVal,
+ InstCombiner::BuilderTy &Builder) {
+ unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits();
+ if (!ICI->isEquality() || !match(ICI->getOperand(1), m_Zero()))
+ return nullptr;
+
+ if (ICI->getPredicate() == ICmpInst::ICMP_NE)
+ std::swap(TrueVal, FalseVal);
+
+ if (!match(FalseVal,
+ m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1))))
+ return nullptr;
+
+ if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>()))
+ return nullptr;
+
+ Value *X = ICI->getOperand(0);
+ auto *II = cast<IntrinsicInst>(TrueVal);
+ if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X)))))
+ return nullptr;
+
+ Function *F = Intrinsic::getDeclaration(II->getModule(), Intrinsic::cttz,
+ II->getType());
+ return CallInst::Create(F, {X, II->getArgOperand(1)});
+}
+
/// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single
/// call to cttz/ctlz with flag 'is_zero_undef' cleared.
///
@@ -973,8 +1008,7 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
// If we are swapping the select operands, swap the metadata too.
assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS &&
"Unexpected results from matchSelectPattern");
- Sel.setTrueValue(LHS);
- Sel.setFalseValue(RHS);
+ Sel.swapValues();
Sel.swapProfMetadata();
return &Sel;
}
@@ -1056,17 +1090,293 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
}
// We are swapping the select operands, so swap the metadata too.
- Sel.setTrueValue(FVal);
- Sel.setFalseValue(TVal);
+ Sel.swapValues();
Sel.swapProfMetadata();
return &Sel;
}
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
+ const SimplifyQuery &Q) {
+ // If this is a binary operator, try to simplify it with the replaced op
+  // because we know Op and ReplaceOp are equivalent.
+ // For example: V = X + 1, Op = X, ReplaceOp = 42
+ // Simplifies as: add(42, 1) --> 43
+ if (auto *BO = dyn_cast<BinaryOperator>(V)) {
+ if (BO->getOperand(0) == Op)
+ return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
+ if (BO->getOperand(1) == Op)
+ return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
+ }
+
+ return nullptr;
+}
+
+/// If we have a select with an equality comparison, then we know the value in
+/// one of the arms of the select. See if substituting this value into an arm
+/// and simplifying the result yields the same value as the other arm.
+///
+/// To make this transform safe, we must drop poison-generating flags
+/// (nsw, etc) if we simplified to a binop because the select may be guarding
+/// that poison from propagating. If the existing binop already had no
+/// poison-generating flags, then this transform can be done by instsimplify.
+///
+/// Consider:
+/// %cmp = icmp eq i32 %x, 2147483647
+/// %add = add nsw i32 %x, 1
+/// %sel = select i1 %cmp, i32 -2147483648, i32 %add
+///
+/// We can't replace %sel with %add unless we strip away the flags.
+/// TODO: Wrapping flags could be preserved in some cases with better analysis.
+static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
+ const SimplifyQuery &Q) {
+ if (!Cmp.isEquality())
+ return nullptr;
+
+ // Canonicalize the pattern to ICMP_EQ by swapping the select operands.
+ Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
+ if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
+ std::swap(TrueVal, FalseVal);
+
+ // Try each equivalence substitution possibility.
+ // We have an 'EQ' comparison, so the select's false value will propagate.
+ // Example:
+ // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
+ // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
+ Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
+ if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
+ simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
+ simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
+ simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
+ if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
+ FalseInst->dropPoisonGeneratingFlags();
+ return FalseVal;
+ }
+ return nullptr;
+}
+
+// See if this is a pattern like:
+// %old_cmp1 = icmp slt i32 %x, C2
+// %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high
+// %old_x_offseted = add i32 %x, C1
+// %old_cmp0 = icmp ult i32 %old_x_offseted, C0
+// %r = select i1 %old_cmp0, i32 %x, i32 %old_replacement
+// This can be rewritten as more canonical pattern:
+// %new_cmp1 = icmp slt i32 %x, -C1
+// %new_cmp2 = icmp sge i32 %x, C0-C1
+// %new_clamped_low = select i1 %new_cmp1, i32 %target_low, i32 %x
+// %r = select i1 %new_cmp2, i32 %target_high, i32 %new_clamped_low
+// Iff -C1 s<= C2 s<= C0-C1
+// The ULT predicate can also be UGT iff C0 != -1 (+invert result)
+// The SLT predicate can also be SGT iff C2 != INT_MAX (+invert result)
+static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
+ InstCombiner::BuilderTy &Builder) {
+ Value *X = Sel0.getTrueValue();
+ Value *Sel1 = Sel0.getFalseValue();
+
+ // First match the condition of the outermost select.
+ // Said condition must be one-use.
+ if (!Cmp0.hasOneUse())
+ return nullptr;
+ Value *Cmp00 = Cmp0.getOperand(0);
+ Constant *C0;
+ if (!match(Cmp0.getOperand(1),
+ m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0))))
+ return nullptr;
+ // Canonicalize Cmp0 into the form we expect.
+ // FIXME: we shouldn't care about lanes that are 'undef' in the end?
+ switch (Cmp0.getPredicate()) {
+ case ICmpInst::Predicate::ICMP_ULT:
+ break; // Great!
+ case ICmpInst::Predicate::ICMP_ULE:
+    // We'd have to increment C0 by one, and for that it must not have an
+    // all-ones element, but then it would have been canonicalized to 'ult'
+    // before we get here. So we can't do anything useful with 'ule'.
+ return nullptr;
+ case ICmpInst::Predicate::ICMP_UGT:
+ // We want to canonicalize it to 'ult', so we'll need to increment C0,
+ // which again means it must not have any all-ones elements.
+ if (!match(C0,
+ m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE,
+ APInt::getAllOnesValue(
+ C0->getType()->getScalarSizeInBits()))))
+ return nullptr; // Can't do, have all-ones element[s].
+ C0 = AddOne(C0);
+ std::swap(X, Sel1);
+ break;
+ case ICmpInst::Predicate::ICMP_UGE:
+    // The only way we'd get this predicate is if this `icmp` has extra uses,
+    // but then we won't be able to do this fold.
+ return nullptr;
+ default:
+ return nullptr; // Unknown predicate.
+ }
+
+ // Now that we've canonicalized the ICmp, we know the X we expect;
+  // the select, on the other hand, should be one-use.
+ if (!Sel1->hasOneUse())
+ return nullptr;
+
+  // We can now finish matching the condition of the outermost select:
+  // it should either be X itself, or an addition of some constant to X.
+ Constant *C1;
+ if (Cmp00 == X)
+ C1 = ConstantInt::getNullValue(Sel0.getType());
+ else if (!match(Cmp00,
+ m_Add(m_Specific(X),
+ m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1)))))
+ return nullptr;
+
+ Value *Cmp1;
+ ICmpInst::Predicate Pred1;
+ Constant *C2;
+ Value *ReplacementLow, *ReplacementHigh;
+ if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow),
+ m_Value(ReplacementHigh))) ||
+ !match(Cmp1,
+ m_ICmp(Pred1, m_Specific(X),
+ m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C2)))))
+ return nullptr;
+
+ if (!Cmp1->hasOneUse() && (Cmp00 == X || !Cmp00->hasOneUse()))
+ return nullptr; // Not enough one-use instructions for the fold.
+  // FIXME: this restriction could be relaxed if Cmp1 can be reused as one of
+  // the two comparisons we'll need to build.
+
+ // Canonicalize Cmp1 into the form we expect.
+ // FIXME: we shouldn't care about lanes that are 'undef' in the end?
+ switch (Pred1) {
+ case ICmpInst::Predicate::ICMP_SLT:
+ break;
+ case ICmpInst::Predicate::ICMP_SLE:
+    // We'd have to increment C2 by one, and for that it must not have a
+    // signed-max element, but then it would have been canonicalized to 'slt'
+    // before we get here. So we can't do anything useful with 'sle'.
+ return nullptr;
+ case ICmpInst::Predicate::ICMP_SGT:
+ // We want to canonicalize it to 'slt', so we'll need to increment C2,
+ // which again means it must not have any signed max elements.
+ if (!match(C2,
+ m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE,
+ APInt::getSignedMaxValue(
+ C2->getType()->getScalarSizeInBits()))))
+ return nullptr; // Can't do, have signed max element[s].
+ C2 = AddOne(C2);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::Predicate::ICMP_SGE:
+    // Also non-canonical, but here we don't need to change C2,
+    // so there are no restrictions on C2, and we can just handle it.
+ std::swap(ReplacementLow, ReplacementHigh);
+ break;
+ default:
+ return nullptr; // Unknown predicate.
+ }
+
+ // The thresholds of this clamp-like pattern.
+ auto *ThresholdLowIncl = ConstantExpr::getNeg(C1);
+ auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1);
+
+ // The fold has a precondition 1: C2 s>= ThresholdLow
+ auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
+ ThresholdLowIncl);
+ if (!match(Precond1, m_One()))
+ return nullptr;
+ // The fold has a precondition 2: C2 s<= ThresholdHigh
+ auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
+ ThresholdHighExcl);
+ if (!match(Precond2, m_One()))
+ return nullptr;
+
+ // All good, finally emit the new pattern.
+ Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl);
+ Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl);
+ Value *MaybeReplacedLow =
+ Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X);
+ Instruction *MaybeReplacedHigh =
+ SelectInst::Create(ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow);
+
+ return MaybeReplacedHigh;
+}
+
+// If we have
+// %cmp = icmp [canonical predicate] i32 %x, C0
+// %r = select i1 %cmp, i32 %y, i32 C1
+// Where C0 != C1 and %x may be different from %y, see if the constant that we
+// will have if we flip the strictness of the predicate (i.e. without changing
+// the result) is identical to C1 in the select. If it matches, we can change
+// the original comparison to one with the swapped predicate, reuse the
+// constant, and swap the operands of the select.
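+//
+// An illustrative example (constants chosen arbitrarily):
+//   %cmp = icmp ult i32 %x, 66
+//   %r = select i1 %cmp, i32 %y, i32 65
+// can be rewritten, reusing the constant 65 from the select, as:
+//   %cmp.inv = icmp ugt i32 %x, 65
+//   %r = select i1 %cmp.inv, i32 65, i32 %y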
+static Instruction *
+tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
+ InstCombiner::BuilderTy &Builder) {
+ ICmpInst::Predicate Pred;
+ Value *X;
+ Constant *C0;
+ if (!match(&Cmp, m_OneUse(m_ICmp(
+ Pred, m_Value(X),
+ m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0))))))
+ return nullptr;
+
+ // If comparison predicate is non-relational, we won't be able to do anything.
+ if (ICmpInst::isEquality(Pred))
+ return nullptr;
+
+ // If comparison predicate is non-canonical, then we certainly won't be able
+ // to make it canonical; canonicalizeCmpWithConstant() already tried.
+ if (!isCanonicalPredicate(Pred))
+ return nullptr;
+
+  // If the [input] type of the comparison and the select type are different,
+  // let's abort for now. We could try to compare constants with trunc/[zs]ext
+  // though.
+ if (C0->getType() != Sel.getType())
+ return nullptr;
+
+ // FIXME: are there any magic icmp predicate+constant pairs we must not touch?
+
+ Value *SelVal0, *SelVal1; // We do not care which one is from where.
+ match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
+  // At least one of these values we are selecting between must be a constant,
+  // or else we'll never succeed.
+ if (!match(SelVal0, m_AnyIntegralConstant()) &&
+ !match(SelVal1, m_AnyIntegralConstant()))
+ return nullptr;
+
+ // Does this constant C match any of the `select` values?
+ auto MatchesSelectValue = [SelVal0, SelVal1](Constant *C) {
+ return C->isElementWiseEqual(SelVal0) || C->isElementWiseEqual(SelVal1);
+ };
+
+ // If C0 *already* matches true/false value of select, we are done.
+ if (MatchesSelectValue(C0))
+ return nullptr;
+
+ // Check the constant we'd have with flipped-strictness predicate.
+ auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
+ if (!FlippedStrictness)
+ return nullptr;
+
+  // If said constant doesn't match either, then there is no hope.
+ if (!MatchesSelectValue(FlippedStrictness->second))
+ return nullptr;
+
+  // It matched! Let's insert the new comparison just before the select.
+ InstCombiner::BuilderTy::InsertPointGuard Guard(Builder);
+ Builder.SetInsertPoint(&Sel);
+
+ Pred = ICmpInst::getSwappedPredicate(Pred); // Yes, swapped.
+ Value *NewCmp = Builder.CreateICmp(Pred, X, FlippedStrictness->second,
+ Cmp.getName() + ".inv");
+ Sel.setCondition(NewCmp);
+ Sel.swapValues();
+ Sel.swapProfMetadata();
+
+ return &Sel;
+}
+
/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
- Value *TrueVal = SI.getTrueValue();
- Value *FalseVal = SI.getFalseValue();
+ if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ))
+ return replaceInstUsesWith(SI, V);
if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder))
return NewSel;
@@ -1074,12 +1384,21 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, Builder))
return NewAbs;
+ if (Instruction *NewAbs = canonicalizeClampLike(SI, *ICI, Builder))
+ return NewAbs;
+
+ if (Instruction *NewSel =
+ tryToReuseConstantFromSelectInComparison(SI, *ICI, Builder))
+ return NewSel;
+
bool Changed = adjustMinMax(SI, *ICI);
if (Value *V = foldSelectICmpAnd(SI, ICI, Builder))
return replaceInstUsesWith(SI, V);
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
@@ -1149,6 +1468,9 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder))
return V;
+ if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder))
+ return V;
+
if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
@@ -1253,6 +1575,16 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,
}
}
+ // max(max(A, B), min(A, B)) --> max(A, B)
+ // min(min(A, B), max(A, B)) --> min(A, B)
+ // TODO: This could be done in instsimplify.
+ if (SPF1 == SPF2 &&
+ ((SPF1 == SPF_UMIN && match(C, m_c_UMax(m_Specific(A), m_Specific(B)))) ||
+ (SPF1 == SPF_SMIN && match(C, m_c_SMax(m_Specific(A), m_Specific(B)))) ||
+ (SPF1 == SPF_UMAX && match(C, m_c_UMin(m_Specific(A), m_Specific(B)))) ||
+ (SPF1 == SPF_SMAX && match(C, m_c_SMin(m_Specific(A), m_Specific(B))))))
+ return replaceInstUsesWith(Outer, Inner);
+
// ABS(ABS(X)) -> ABS(X)
// NABS(NABS(X)) -> NABS(X)
// TODO: This could be done in instsimplify.
@@ -1280,7 +1612,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,
return true;
}
- if (IsFreeToInvert(V, !V->hasNUsesOrMore(3))) {
+ if (isFreeToInvert(V, !V->hasNUsesOrMore(3))) {
NotV = nullptr;
return true;
}
@@ -1492,6 +1824,30 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
ConstantVector::get(Mask));
}
+/// If we have a select of vectors with a scalar condition, try to convert that
+/// to a vector select by splatting the condition. A splat may get folded with
+/// other operations in IR and having all operands of a select be vector types
+/// is likely better for vector codegen.
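+///
+/// For example (an illustrative sketch):
+/// \code
+///   %c = extractelement <4 x i1> %cond.vec, i32 1
+///   %r = select i1 %c, <4 x i32> %t, <4 x i32> %f
+/// \endcode
+/// becomes a vector select on a splat of the extracted bit:
+/// \code
+///   %c.splat = ... splat %c to <4 x i1> (insertelement + shufflevector) ...
+///   %r = select <4 x i1> %c.splat, <4 x i32> %t, <4 x i32> %f
+/// \endcode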
+static Instruction *canonicalizeScalarSelectOfVecs(
+ SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
+ Type *Ty = Sel.getType();
+ if (!Ty->isVectorTy())
+ return nullptr;
+
+ // We can replace a single-use extract with constant index.
+ Value *Cond = Sel.getCondition();
+ if (!match(Cond, m_OneUse(m_ExtractElement(m_Value(), m_ConstantInt()))))
+ return nullptr;
+
+ // select (extelt V, Index), T, F --> select (splat V, Index), T, F
+ // Splatting the extracted condition reduces code (we could directly create a
+ // splat shuffle of the source vector to eliminate the intermediate step).
+ unsigned NumElts = Ty->getVectorNumElements();
+ Value *SplatCond = Builder.CreateVectorSplat(NumElts, Cond);
+ Sel.setCondition(SplatCond);
+ return &Sel;
+}
+
/// Reuse bitcasted operands between a compare and select:
/// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) -->
/// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D))
@@ -1648,6 +2004,71 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X,
return nullptr;
}
+/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
+Instruction *InstCombiner::matchSAddSubSat(SelectInst &MinMax1) {
+ Type *Ty = MinMax1.getType();
+
+ // We are looking for a tree of:
+ // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
+ // Where the min and max could be reversed
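+  // For example (illustrative, i8 operands widened to i32):
+  //   smax(smin(add(sext i8 %a to i32, sext i8 %b to i32), 127), -128)
+  // becomes
+  //   sext(call i8 @llvm.sadd.sat.i8(i8 %a, i8 %b)) to i32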
+ Instruction *MinMax2;
+ BinaryOperator *AddSub;
+ const APInt *MinValue, *MaxValue;
+ if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
+ if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+ return nullptr;
+ } else if (match(&MinMax1,
+ m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
+ if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+ return nullptr;
+ } else
+ return nullptr;
+
+ // Check that the constants clamp a saturate, and that the new type would be
+ // sensible to convert to.
+ if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
+ return nullptr;
+  // In what bitwidth can this be treated as saturating arithmetic?
+ unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
+ // FIXME: This isn't quite right for vectors, but using the scalar type is a
+ // good first approximation for what should be done there.
+ if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
+ return nullptr;
+
+  // Also make sure that the number of uses is as expected. The "3"s are for
+  // the two items of min/max (the compare and the select).
+ if (MinMax2->hasNUsesOrMore(3) || AddSub->hasNUsesOrMore(3))
+ return nullptr;
+
+ // Create the new type (which can be a vector type)
+ Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);
+ // Match the two extends from the add/sub
+ Value *A, *B;
+  if (!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B)))))
+ return nullptr;
+ // And check the incoming values are of a type smaller than or equal to the
+ // size of the saturation. Otherwise the higher bits can cause different
+ // results.
+ if (A->getType()->getScalarSizeInBits() > NewBitWidth ||
+ B->getType()->getScalarSizeInBits() > NewBitWidth)
+ return nullptr;
+
+ Intrinsic::ID IntrinsicID;
+ if (AddSub->getOpcode() == Instruction::Add)
+ IntrinsicID = Intrinsic::sadd_sat;
+ else if (AddSub->getOpcode() == Instruction::Sub)
+ IntrinsicID = Intrinsic::ssub_sat;
+ else
+ return nullptr;
+
+ // Finally create and return the sat intrinsic, truncated to the new type
+ Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
+ Value *AT = Builder.CreateSExt(A, NewTy);
+ Value *BT = Builder.CreateSExt(B, NewTy);
+ Value *Sat = Builder.CreateCall(F, {AT, BT});
+ return CastInst::Create(Instruction::SExt, Sat, Ty);
+}
+
/// Reduce a sequence of min/max with a common operand.
static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
Value *RHS,
@@ -1788,6 +2209,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *I = canonicalizeSelectToShuffle(SI))
return I;
+ if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, Builder))
+ return I;
+
// Canonicalize a one-use integer compare with a non-canonical predicate by
// inverting the predicate and swapping the select operands. This matches a
// compare canonicalization for conditional branches.
@@ -2013,16 +2437,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
(LHS->getType()->isFPOrFPVectorTy() &&
((CmpLHS != LHS && CmpLHS != RHS) ||
(CmpRHS != LHS && CmpRHS != RHS)))) {
- CmpInst::Predicate Pred = getMinMaxPred(SPF, SPR.Ordered);
+ CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
Value *Cmp;
- if (CmpInst::isIntPredicate(Pred)) {
- Cmp = Builder.CreateICmp(Pred, LHS, RHS);
+ if (CmpInst::isIntPredicate(MinMaxPred)) {
+ Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS);
} else {
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
- auto FMF = cast<FPMathOperator>(SI.getCondition())->getFastMathFlags();
+ auto FMF =
+ cast<FPMathOperator>(SI.getCondition())->getFastMathFlags();
Builder.setFastMathFlags(FMF);
- Cmp = Builder.CreateFCmp(Pred, LHS, RHS);
+ Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS);
}
Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI);
@@ -2040,9 +2465,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * {
Value *A;
if (match(X, m_Not(m_Value(A))) && !X->hasNUsesOrMore(3) &&
- !IsFreeToInvert(A, A->hasOneUse()) &&
+ !isFreeToInvert(A, A->hasOneUse()) &&
// Passing false to only consider m_Not and constants.
- IsFreeToInvert(Y, false)) {
+ isFreeToInvert(Y, false)) {
Value *B = Builder.CreateNot(Y);
Value *NewMinMax = createMinMax(Builder, getInverseMinMaxFlavor(SPF),
A, B);
@@ -2070,6 +2495,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder))
return I;
+ if (Instruction *I = matchSAddSubSat(SI))
+ return I;
}
}
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c821292400cd..64294838644f 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -25,50 +25,275 @@ using namespace PatternMatch;
// we should rewrite it as
// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x)
// This is valid for any shift, but they must be identical.
-static Instruction *
-reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
- const SimplifyQuery &SQ) {
- // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1
- Value *X, *ShAmt1, *ShAmt0;
+//
+// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
+// pattern has two right-shifts that sum to one less than the original bit
+// width.
+Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
+ BinaryOperator *Sh0, const SimplifyQuery &SQ,
+ bool AnalyzeForSignBitExtraction) {
+ // Look for a shift of some instruction, ignore zext of shift amount if any.
+ Instruction *Sh0Op0;
+ Value *ShAmt0;
+ if (!match(Sh0,
+ m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0)))))
+ return nullptr;
+
+ // If there is a truncation between the two shifts, we must make note of it
+ // and look through it. The truncation imposes additional constraints on the
+ // transform.
Instruction *Sh1;
- if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)),
- m_Instruction(Sh1)),
- m_Value(ShAmt0))))
+ Value *Trunc = nullptr;
+ match(Sh0Op0,
+ m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)),
+ m_Instruction(Sh1)));
+
+ // Inner shift: (x shiftopcode ShAmt1)
+  // As with the outer shift, ignore a zext of the shift amount, if any.
+ Value *X, *ShAmt1;
+ if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
+ return nullptr;
+
+  // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case, let's bail out now.
+ if (ShAmt0->getType() != ShAmt1->getType())
+ return nullptr;
+
+ // We are only looking for signbit extraction if we have two right shifts.
+ bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
+ match(Sh1, m_Shr(m_Value(), m_Value()));
+ // ... and if it's not two right-shifts, we know the answer already.
+ if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
return nullptr;
- // The shift opcodes must be identical.
+ // The shift opcodes must be identical, unless we are just checking whether
+ // this pattern can be interpreted as a sign-bit-extraction.
Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
- if (ShiftOpcode != Sh1->getOpcode())
+ bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
+ if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
return nullptr;
+
+  // If we saw a truncation, we'll need to produce an extra instruction,
+  // and for that one of the operands of the shift must be one-use,
+  // unless of course we don't actually plan to produce any instructions here.
+ if (Trunc && !AnalyzeForSignBitExtraction &&
+ !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+ return nullptr;
+
// Can we fold (ShAmt0+ShAmt1) ?
- Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1,
- SQ.getWithInstruction(Sh0));
+ auto *NewShAmt = dyn_cast_or_null<Constant>(
+ SimplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false,
+ SQ.getWithInstruction(Sh0)));
if (!NewShAmt)
return nullptr; // Did not simplify.
- // Is the new shift amount smaller than the bit width?
- // FIXME: could also rely on ConstantRange.
- unsigned BitWidth = X->getType()->getScalarSizeInBits();
+ unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits();
+ unsigned XBitWidth = X->getType()->getScalarSizeInBits();
+ // Is the new shift amount smaller than the bit width of inner/new shift?
if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
- APInt(BitWidth, BitWidth))))
- return nullptr;
+ APInt(NewShAmtBitWidth, XBitWidth))))
+ return nullptr; // FIXME: could perform constant-folding.
+
+  // If there was a truncation, and we have a right-shift, we can only fold if
+  // we are left with the original sign bit. Likewise, if we were just checking
+  // that this is a sign-bit extraction, this is the place to check it.
+ // FIXME: zero shift amount is also legal here, but we can't *easily* check
+ // more than one predicate so it's not really worth it.
+ if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
+ // If it's not a sign bit extraction, then we're done.
+ if (!match(NewShAmt,
+ m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(NewShAmtBitWidth, XBitWidth - 1))))
+ return nullptr;
+ // If it is, and that was the question, return the base value.
+ if (AnalyzeForSignBitExtraction)
+ return X;
+ }
+
+ assert(IdenticalShOpcodes && "Should not get here with different shifts.");
+
// All good, we can do this fold.
+ NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
+
BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
- // If both of the original shifts had the same flag set, preserve the flag.
- if (ShiftOpcode == Instruction::BinaryOps::Shl) {
- NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
- Sh1->hasNoUnsignedWrap());
- NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
- Sh1->hasNoSignedWrap());
- } else {
- NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
+
+ // The flags can only be propagated if there wasn't a trunc.
+ if (!Trunc) {
+ // If the pattern did not involve trunc, and both of the original shifts
+ // had the same flag set, preserve the flag.
+ if (ShiftOpcode == Instruction::BinaryOps::Shl) {
+ NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
+ Sh1->hasNoUnsignedWrap());
+ NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
+ Sh1->hasNoSignedWrap());
+ } else {
+ NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
+ }
+ }
+
+ Instruction *Ret = NewShift;
+ if (Trunc) {
+ Builder.Insert(NewShift);
+ Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType());
+ }
+
+ return Ret;
+}
+
+// Try to replace `undef` constants in C with Replacement.
+static Constant *replaceUndefsWith(Constant *C, Constant *Replacement) {
+ if (C && match(C, m_Undef()))
+ return Replacement;
+
+ if (auto *CV = dyn_cast<ConstantVector>(C)) {
+ llvm::SmallVector<Constant *, 32> NewOps(CV->getNumOperands());
+ for (unsigned i = 0, NumElts = NewOps.size(); i != NumElts; ++i) {
+ Constant *EltC = CV->getOperand(i);
+ NewOps[i] = EltC && match(EltC, m_Undef()) ? Replacement : EltC;
+ }
+ return ConstantVector::get(NewOps);
+ }
+
+ // Don't know how to deal with this constant.
+ return C;
+}
+
+// If we have some pattern that leaves only some low bits set, and then performs
+// a left-shift of those bits, and if none of the bits that remain after the
+// final shift are modified by the mask, we can omit the mask.
+//
+// There are many variants to this pattern:
+// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
+// b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt
+// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt
+// d) (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt
+// e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt
+// f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt
+// All these patterns can be simplified to just:
+// x << ShiftShAmt
+// iff:
+// a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
+// c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
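+//
+// A concrete illustration (i8, pattern (a), MaskShAmt == ShiftShAmt == 4):
+//   (x & 15) << 4   simplifies to   x << 4
+// because 4 + 4 u>= 8, so every bit cleared by the mask is shifted out anyway.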
+static Instruction *
+dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
+ const SimplifyQuery &Q,
+ InstCombiner::BuilderTy &Builder) {
+ assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
+ "The input must be 'shl'!");
+
+ Value *Masked, *ShiftShAmt;
+ match(OuterShift, m_Shift(m_Value(Masked), m_Value(ShiftShAmt)));
+
+ Type *NarrowestTy = OuterShift->getType();
+ Type *WidestTy = Masked->getType();
+ // The mask must be computed in a type twice as wide to ensure
+ // that no bits are lost if the sum-of-shifts is wider than the base type.
+ Type *ExtendedTy = WidestTy->getExtendedType();
+
+ Value *MaskShAmt;
+
+ // ((1 << MaskShAmt) - 1)
+ auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
+  // (~(-1 << MaskShAmt))
+ auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
+ // (-1 >> MaskShAmt)
+ auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
+ // ((-1 << MaskShAmt) >> MaskShAmt)
+ auto MaskD =
+ m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt));
+
+ Value *X;
+ Constant *NewMask;
+
+ if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+ // Can we simplify (MaskShAmt+ShiftShAmt) ?
+ auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst(
+ MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
+ if (!SumOfShAmts)
+ return nullptr; // Did not simplify.
+ // In this pattern SumOfShAmts correlates with the number of low bits
+ // that shall remain in the root value (OuterShift).
+
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the `undef` shift amounts with the final
+    // shift bitwidth to ensure that the value remains undef when creating the
+    // subsequent shift op.
+ SumOfShAmts = replaceUndefsWith(
+ SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
+ ExtendedTy->getScalarSizeInBits()));
+ auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+ // And compute the mask as usual: ~(-1 << (SumOfShAmts))
+ auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+ auto *ExtendedInvertedMask =
+ ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
+ NewMask = ConstantExpr::getNot(ExtendedInvertedMask);
+ } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
+ match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
+ m_Deferred(MaskShAmt)))) {
+ // Can we simplify (ShiftShAmt-MaskShAmt) ?
+ auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst(
+ ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
+ if (!ShAmtsDiff)
+ return nullptr; // Did not simplify.
+ // In this pattern ShAmtsDiff correlates with the number of high bits that
+ // shall be unset in the root value (OuterShift).
+
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the `undef` shift amounts with the negated
+    // bitwidth of the innermost shift to ensure that the value remains undef
+    // when creating the subsequent shift op.
+ unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
+ ShAmtsDiff = replaceUndefsWith(
+ ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
+ -WidestTyBitWidth));
+ auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+ ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
+ WidestTyBitWidth,
+ /*isSigned=*/false),
+ ShAmtsDiff),
+ ExtendedTy);
+ // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
+ auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+ NewMask =
+ ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+ } else
+ return nullptr; // Don't know anything about this pattern.
+
+ NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy);
+
+  // Does this mask have any unset bits? If not, then we can just not apply it.
+ bool NeedMask = !match(NewMask, m_AllOnes());
+
+  // If we need to apply a mask, there are several more restrictions.
+ if (NeedMask) {
+ // The old masking instruction must go away.
+ if (!Masked->hasOneUse())
+ return nullptr;
+    // The original "masking" instruction must not have been `ashr`.
+ if (match(Masked, m_AShr(m_Value(), m_Value())))
+ return nullptr;
}
- return NewShift;
+
+ // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
+ auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
+ OuterShift->getOperand(1));
+
+ if (!NeedMask)
+ return NewShift;
+
+ Builder.Insert(NewShift);
+ return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
}
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
assert(Op0->getType() == Op1->getType());
+ // If the shift amount is a one-use `sext`, we can demote it to `zext`.
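+  // (For a non-negative amount, sext and zext agree; for a negative amount the
+  // sign-extended value is already u>= the bit width, so the original shift is
+  // poison and replacing it is a sound refinement.)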
+ Value *Y;
+ if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
+ Value *NewExt = Builder.CreateZExt(Y, I.getType(), Op1->getName());
+ return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
+ }
+
// See if we can fold away this shift.
if (SimplifyDemandedInstructionBits(I))
return &I;
@@ -83,8 +308,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
return Res;
- if (Instruction *NewShift =
- reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))
+ if (auto *NewShift = cast_or_null<Instruction>(
+ reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
return NewShift;
// (C1 shift (A add C2)) -> (C1 shift C2) shift A)
@@ -618,9 +843,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
}
Instruction *InstCombiner::visitShl(BinaryOperator &I) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
+
if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
- I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
- SQ.getWithInstruction(&I)))
+ I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q))
return replaceInstUsesWith(I, V);
if (Instruction *X = foldVectorBinop(I))
@@ -629,6 +855,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
if (Instruction *V = commonShiftTransforms(I))
return V;
+ if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
+ return V;
+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
@@ -636,12 +865,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
const APInt *ShAmtAPInt;
if (match(Op1, m_APInt(ShAmtAPInt))) {
unsigned ShAmt = ShAmtAPInt->getZExtValue();
- unsigned BitWidth = Ty->getScalarSizeInBits();
// shl (zext X), ShAmt --> zext (shl X, ShAmt)
// This is only valid if X would have zeros shifted out.
Value *X;
- if (match(Op0, m_ZExt(m_Value(X)))) {
+ if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
if (ShAmt < SrcWidth &&
MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
@@ -719,6 +947,12 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
// (X * C2) << C1 --> X * (C2 << C1)
if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
+
+ // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
+ if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+ auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
+ return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
+ }
}
// (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
@@ -859,6 +1093,75 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
return nullptr;
}
+Instruction *
+InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
+ BinaryOperator &OldAShr) {
+ assert(OldAShr.getOpcode() == Instruction::AShr &&
+ "Must be called with arithmetic right-shift instruction only.");
+
+ // Check that constant C is a splat of the element-wise bitwidth of V.
+ auto BitWidthSplat = [](Constant *C, Value *V) {
+ return match(
+ C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(C->getType()->getScalarSizeInBits(),
+ V->getType()->getScalarSizeInBits())));
+ };
+
+ // It should look like variable-length sign-extension on the outside:
+ // (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
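+  // For example (an illustrative pseudo-IR sketch on i32), when the innermost
+  // shift is also an ashr:
+  //   %hi = ashr i32 %x, (32 - %nbits)          ; extract the high %nbits bits
+  //   %r  = ashr i32 (shl i32 %hi, (32 - %nbits)), (32 - %nbits)
+  // the outer shl/ashr pair is redundant, and %r can be replaced with %hi.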
+ Value *NBits;
+ Instruction *MaybeTrunc;
+ Constant *C1, *C2;
+ if (!match(&OldAShr,
+ m_AShr(m_Shl(m_Instruction(MaybeTrunc),
+ m_ZExtOrSelf(m_Sub(m_Constant(C1),
+ m_ZExtOrSelf(m_Value(NBits))))),
+ m_ZExtOrSelf(m_Sub(m_Constant(C2),
+ m_ZExtOrSelf(m_Deferred(NBits)))))) ||
+ !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
+ return nullptr;
+
+  // There may or may not be a truncation after the outer two shifts.
+ Instruction *HighBitExtract;
+ match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
+ bool HadTrunc = MaybeTrunc != HighBitExtract;
+
+ // And finally, the innermost part of the pattern must be a right-shift.
+ Value *X, *NumLowBitsToSkip;
+ if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
+ return nullptr;
+
+  // Said right-shift must extract the high NBits bits; C0 must be its bitwidth.
+ Constant *C0;
+ if (!match(NumLowBitsToSkip,
+ m_ZExtOrSelf(
+ m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
+ !BitWidthSplat(C0, HighBitExtract))
+ return nullptr;
+
+  // Since NBits is identical for all shifts, if the outermost and
+  // innermost shifts are identical, then the outermost shifts are redundant.
+  // If we had a truncation, do keep it though.
+ if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
+ return replaceInstUsesWith(OldAShr, MaybeTrunc);
+
+ // Else, if there was a truncation, then we need to ensure that one
+ // instruction will go away.
+ if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+ return nullptr;
+
+ // Finally, bypass two innermost shifts, and perform the outermost shift on
+ // the operands of the innermost shift.
+ Instruction *NewAShr =
+ BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
+ NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
+ if (!HadTrunc)
+ return NewAShr;
+
+ Builder.Insert(NewAShr);
+ return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
+}
+
Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
SQ.getWithInstruction(&I)))
@@ -933,6 +1236,9 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
}
}
+ if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
+ return R;
+
// See if we can turn a signed shr into an unsigned shr.
if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
return BinaryOperator::CreateLShr(Op0, Op1);
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index e0d85c4b49ae..d30ab8001897 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -971,6 +971,13 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1,
Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
APInt DemandedElts,
int DMaskIdx) {
+
+ // FIXME: Allow v3i16/v3f16 in buffer intrinsics when the types are fully supported.
+ if (DMaskIdx < 0 &&
+ II->getType()->getScalarSizeInBits() != 32 &&
+ DemandedElts.getActiveBits() == 3)
+ return nullptr;
+
unsigned VWidth = II->getType()->getVectorNumElements();
if (VWidth == 1)
return nullptr;
@@ -1067,16 +1074,22 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
}
/// The specified value produces a vector with any number of elements.
+/// This method analyzes which elements of the operand are undef and returns
+/// that information in UndefElts.
+///
/// DemandedElts contains the set of elements that are actually used by the
-/// caller. This method analyzes which elements of the operand are undef and
-/// returns that information in UndefElts.
+/// caller, and by default (AllowMultipleUsers equals false) the value is
+/// simplified only if it has a single user. If AllowMultipleUsers is set
+/// to true, DemandedElts refers to the union of the sets of elements that are
+/// used by all of its users.
///
/// If the information about demanded elements can be used to simplify the
/// operation, the operation is simplified, then the resultant value is
/// returned. This returns null if no change was made.
Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
APInt &UndefElts,
- unsigned Depth) {
+ unsigned Depth,
+ bool AllowMultipleUsers) {
unsigned VWidth = V->getType()->getVectorNumElements();
APInt EltMask(APInt::getAllOnesValue(VWidth));
assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
@@ -1130,19 +1143,21 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
if (Depth == 10)
return nullptr;
- // If multiple users are using the root value, proceed with
- // simplification conservatively assuming that all elements
- // are needed.
- if (!V->hasOneUse()) {
- // Quit if we find multiple users of a non-root value though.
- // They'll be handled when it's their turn to be visited by
- // the main instcombine process.
- if (Depth != 0)
- // TODO: Just compute the UndefElts information recursively.
- return nullptr;
+ if (!AllowMultipleUsers) {
+ // If multiple users are using the root value, proceed with
+ // simplification conservatively assuming that all elements
+ // are needed.
+ if (!V->hasOneUse()) {
+ // Quit if we find multiple users of a non-root value though.
+ // They'll be handled when it's their turn to be visited by
+ // the main instcombine process.
+ if (Depth != 0)
+ // TODO: Just compute the UndefElts information recursively.
+ return nullptr;
- // Conservatively assume that all elements are needed.
- DemandedElts = EltMask;
+ // Conservatively assume that all elements are needed.
+ DemandedElts = EltMask;
+ }
}
Instruction *I = dyn_cast<Instruction>(V);
@@ -1674,8 +1689,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
case Intrinsic::amdgcn_buffer_load_format:
case Intrinsic::amdgcn_raw_buffer_load:
case Intrinsic::amdgcn_raw_buffer_load_format:
+ case Intrinsic::amdgcn_raw_tbuffer_load:
case Intrinsic::amdgcn_struct_buffer_load:
case Intrinsic::amdgcn_struct_buffer_load_format:
+ case Intrinsic::amdgcn_struct_tbuffer_load:
+ case Intrinsic::amdgcn_tbuffer_load:
return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts);
default: {
if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID()))
diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index dc9abdd7f47a..9c890748e5ab 100644
--- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -253,6 +253,69 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext,
return nullptr;
}
+/// Find elements of V demanded by UserInstr.
+static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
+ unsigned VWidth = V->getType()->getVectorNumElements();
+
+ // Conservatively assume that all elements are needed.
+ APInt UsedElts(APInt::getAllOnesValue(VWidth));
+
+ switch (UserInstr->getOpcode()) {
+ case Instruction::ExtractElement: {
+ ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr);
+ assert(EEI->getVectorOperand() == V);
+ ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand());
+ if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) {
+ UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue());
+ }
+ break;
+ }
+ case Instruction::ShuffleVector: {
+ ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr);
+ unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements();
+
+ UsedElts = APInt(VWidth, 0);
+ for (unsigned i = 0; i < MaskNumElts; i++) {
+ unsigned MaskVal = Shuffle->getMaskValue(i);
+ if (MaskVal == -1u || MaskVal >= 2 * VWidth)
+ continue;
+ if (Shuffle->getOperand(0) == V && (MaskVal < VWidth))
+ UsedElts.setBit(MaskVal);
+ if (Shuffle->getOperand(1) == V &&
+ ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth)))
+ UsedElts.setBit(MaskVal - VWidth);
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ return UsedElts;
+}
+
+/// Find union of elements of V demanded by all its users.
+/// If it is known by querying findDemandedEltsBySingleUser that
+/// no user demands an element of V, then the corresponding bit
+/// remains unset in the returned value.
+static APInt findDemandedEltsByAllUsers(Value *V) {
+ unsigned VWidth = V->getType()->getVectorNumElements();
+
+ APInt UnionUsedElts(VWidth, 0);
+ for (const Use &U : V->uses()) {
+ if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
+ UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
+ } else {
+ UnionUsedElts = APInt::getAllOnesValue(VWidth);
+ break;
+ }
+
+ if (UnionUsedElts.isAllOnesValue())
+ break;
+ }
+
+ return UnionUsedElts;
+}
+
Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
Value *SrcVec = EI.getVectorOperand();
Value *Index = EI.getIndexOperand();
@@ -271,19 +334,35 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
return nullptr;
// This instruction only demands the single element from the input vector.
- // If the input vector has a single use, simplify it based on this use
- // property.
- if (SrcVec->hasOneUse() && NumElts != 1) {
- APInt UndefElts(NumElts, 0);
- APInt DemandedElts(NumElts, 0);
- DemandedElts.setBit(IndexC->getZExtValue());
- if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts,
- UndefElts)) {
- EI.setOperand(0, V);
- return &EI;
+ if (NumElts != 1) {
+ // If the input vector has a single use, simplify it based on this use
+ // property.
+ if (SrcVec->hasOneUse()) {
+ APInt UndefElts(NumElts, 0);
+ APInt DemandedElts(NumElts, 0);
+ DemandedElts.setBit(IndexC->getZExtValue());
+ if (Value *V =
+ SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) {
+ EI.setOperand(0, V);
+ return &EI;
+ }
+ } else {
+ // If the input vector has multiple uses, simplify it based on a union
+ // of all elements used.
+ APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec);
+ if (!DemandedElts.isAllOnesValue()) {
+ APInt UndefElts(NumElts, 0);
+ if (Value *V = SimplifyDemandedVectorElts(
+ SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
+ true /* AllowMultipleUsers */)) {
+ if (V != SrcVec) {
+ SrcVec->replaceAllUsesWith(V);
+ return &EI;
+ }
+ }
+ }
}
}
-
if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian()))
return I;
@@ -766,6 +845,55 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) {
return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask);
}
+/// Try to fold an extract+insert element into an existing identity shuffle by
+/// changing the shuffle's mask to include the index of this insert element.
+static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) {
+ // Check if the vector operand of this insert is an identity shuffle.
+ auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0));
+ if (!Shuf || !isa<UndefValue>(Shuf->getOperand(1)) ||
+ !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding()))
+ return nullptr;
+
+ // Check for a constant insertion index.
+ uint64_t IdxC;
+ if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC)))
+ return nullptr;
+
+ // Check if this insert's scalar op is extracted from the identity shuffle's
+ // input vector.
+ Value *Scalar = InsElt.getOperand(1);
+ Value *X = Shuf->getOperand(0);
+ if (!match(Scalar, m_ExtractElement(m_Specific(X), m_SpecificInt(IdxC))))
+ return nullptr;
+
+ // Replace the shuffle mask element at the index of this extract+insert with
+ // that same index value.
+ // For example:
+ // inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask'
+ unsigned NumMaskElts = Shuf->getType()->getVectorNumElements();
+ SmallVector<Constant *, 16> NewMaskVec(NumMaskElts);
+ Type *I32Ty = IntegerType::getInt32Ty(Shuf->getContext());
+ Constant *NewMaskEltC = ConstantInt::get(I32Ty, IdxC);
+ Constant *OldMask = Shuf->getMask();
+ for (unsigned i = 0; i != NumMaskElts; ++i) {
+ if (i != IdxC) {
+ // All mask elements besides the inserted element remain the same.
+ NewMaskVec[i] = OldMask->getAggregateElement(i);
+ } else if (OldMask->getAggregateElement(i) == NewMaskEltC) {
+ // If the mask element was already set, there's nothing to do
+ // (demanded elements analysis may unset it later).
+ return nullptr;
+ } else {
+ assert(isa<UndefValue>(OldMask->getAggregateElement(i)) &&
+ "Unexpected shuffle mask element for identity shuffle");
+ NewMaskVec[i] = NewMaskEltC;
+ }
+ }
+
+ Constant *NewMask = ConstantVector::get(NewMaskVec);
+ return new ShuffleVectorInst(X, Shuf->getOperand(1), NewMask);
+}
+
/// If we have an insertelement instruction feeding into another insertelement
/// and the 2nd is inserting a constant into the vector, canonicalize that
/// constant insertion before the insertion of a variable:
@@ -987,6 +1115,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) {
if (Instruction *Splat = foldInsEltIntoSplat(IE))
return Splat;
+ if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE))
+ return IdentityShuf;
+
return nullptr;
}
@@ -1009,17 +1140,23 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
if (Depth == 0) return false;
switch (I->getOpcode()) {
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ // Propagating an undefined shuffle mask element to integer div/rem is not
+ // allowed because those opcodes can create immediate undefined behavior
+ // from an undefined element in an operand.
+ if (llvm::any_of(Mask, [](int M){ return M == -1; }))
+ return false;
+ LLVM_FALLTHROUGH;
case Instruction::Add:
case Instruction::FAdd:
case Instruction::Sub:
case Instruction::FSub:
case Instruction::Mul:
case Instruction::FMul:
- case Instruction::UDiv:
- case Instruction::SDiv:
case Instruction::FDiv:
- case Instruction::URem:
- case Instruction::SRem:
case Instruction::FRem:
case Instruction::Shl:
case Instruction::LShr:
@@ -1040,9 +1177,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
case Instruction::FPExt:
case Instruction::GetElementPtr: {
// Bail out if we would create longer vector ops. We could allow creating
- // longer vector ops, but that may result in more expensive codegen. We
- // would also need to limit the transform to avoid undefined behavior for
- // integer div/rem.
+ // longer vector ops, but that may result in more expensive codegen.
Type *ITy = I->getType();
if (ITy->isVectorTy() && Mask.size() > ITy->getVectorNumElements())
return false;
diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp
index 385f4926b845..ecb486c544e0 100644
--- a/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -200,8 +200,8 @@ bool InstCombiner::shouldChangeType(Type *From, Type *To) const {
// where both B and C should be ConstantInts, results in a constant that does
// not overflow. This function only handles the Add and Sub opcodes. For
// all other opcodes, the function conservatively returns false.
-static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) {
- OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
+static bool maintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) {
+ auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
if (!OBO || !OBO->hasNoSignedWrap())
return false;
@@ -224,10 +224,15 @@ static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) {
}
static bool hasNoUnsignedWrap(BinaryOperator &I) {
- OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
+ auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
return OBO && OBO->hasNoUnsignedWrap();
}
+static bool hasNoSignedWrap(BinaryOperator &I) {
+ auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
+ return OBO && OBO->hasNoSignedWrap();
+}
+
/// Conservatively clears subclassOptionalData after a reassociation or
/// commutation. We preserve fast-math flags when applicable as they can be
/// preserved.
@@ -332,22 +337,21 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
// It simplifies to V. Form "A op V".
I.setOperand(0, A);
I.setOperand(1, V);
- // Conservatively clear the optional flags, since they may not be
- // preserved by the reassociation.
bool IsNUW = hasNoUnsignedWrap(I) && hasNoUnsignedWrap(*Op0);
- bool IsNSW = MaintainNoSignedWrap(I, B, C);
+ bool IsNSW = maintainNoSignedWrap(I, B, C) && hasNoSignedWrap(*Op0);
+ // Conservatively clear all optional flags since they may not be
+ // preserved by the reassociation. Reset nsw/nuw based on the above
+ // analysis.
ClearSubclassDataAfterReassociation(I);
+ // Note: this is only valid because SimplifyBinOp doesn't look at
+ // the operands to Op0.
if (IsNUW)
I.setHasNoUnsignedWrap(true);
- if (IsNSW &&
- (!Op0 || (isa<BinaryOperator>(Op0) && Op0->hasNoSignedWrap()))) {
- // Note: this is only valid because SimplifyBinOp doesn't look at
- // the operands to Op0.
+ if (IsNSW)
I.setHasNoSignedWrap(true);
- }
Changed = true;
++NumReassoc;
@@ -610,7 +614,6 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I,
HasNUW &= ROBO->hasNoUnsignedWrap();
}
- const APInt *CInt;
if (TopLevelOpcode == Instruction::Add &&
InnerOpcode == Instruction::Mul) {
// We can propagate 'nsw' if we know that
@@ -620,6 +623,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I,
// %Z = mul nsw i16 %X, C+1
//
// iff C+1 isn't INT_MIN
+ const APInt *CInt;
if (match(V, m_APInt(CInt))) {
if (!CInt->isMinSignedValue())
BO->setHasNoSignedWrap(HasNSW);
@@ -763,12 +767,16 @@ Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) {
bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse();
+
+ FastMathFlags FMF;
BuilderTy::FastMathFlagGuard Guard(Builder);
- if (isa<FPMathOperator>(&I))
- Builder.setFastMathFlags(I.getFastMathFlags());
+ if (isa<FPMathOperator>(&I)) {
+ FMF = I.getFastMathFlags();
+ Builder.setFastMathFlags(FMF);
+ }
- Value *V1 = SimplifyBinOp(Opcode, C, E, SQ.getWithInstruction(&I));
- Value *V2 = SimplifyBinOp(Opcode, B, D, SQ.getWithInstruction(&I));
+ Value *V1 = SimplifyBinOp(Opcode, C, E, FMF, SQ.getWithInstruction(&I));
+ Value *V2 = SimplifyBinOp(Opcode, B, D, FMF, SQ.getWithInstruction(&I));
if (V1 && V2)
SI = Builder.CreateSelect(A, V2, V1);
else if (V2 && SelectsHaveOneUse)
@@ -1659,7 +1667,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// to an index of zero, so replace it with zero if it is not zero already.
Type *EltTy = GTI.getIndexedType();
if (EltTy->isSized() && DL.getTypeAllocSize(EltTy) == 0)
- if (!isa<Constant>(*I) || !cast<Constant>(*I)->isNullValue()) {
+ if (!isa<Constant>(*I) || !match(I->get(), m_Zero())) {
*I = Constant::getNullValue(NewIndexType);
MadeChange = true;
}
@@ -2549,9 +2557,7 @@ Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) {
Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
// Change br (not X), label True, label False to: br X, label False, True
Value *X = nullptr;
- BasicBlock *TrueDest;
- BasicBlock *FalseDest;
- if (match(&BI, m_Br(m_Not(m_Value(X)), TrueDest, FalseDest)) &&
+ if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) &&
!isa<Constant>(X)) {
// Swap Destinations and condition...
BI.setCondition(X);
@@ -2569,8 +2575,8 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
// Canonicalize, for example, icmp_ne -> icmp_eq or fcmp_one -> fcmp_oeq.
CmpInst::Predicate Pred;
- if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), TrueDest,
- FalseDest)) &&
+ if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())),
+ m_BasicBlock(), m_BasicBlock())) &&
!isCanonicalPredicate(Pred)) {
// Swap destinations and condition.
CmpInst *Cond = cast<CmpInst>(BI.getCondition());
@@ -3156,6 +3162,21 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) {
findDbgUsers(DbgUsers, I);
for (auto *DII : reverse(DbgUsers)) {
if (DII->getParent() == SrcBlock) {
+ if (isa<DbgDeclareInst>(DII)) {
+        // A dbg.declare instruction should not be cloned, since there can only be
+        // one per variable fragment. It should be left in the original place, since
+        // the sunk instruction is not an alloca (otherwise we could not be here).
+        // But we do need to update the dbg.declare's argument so that it does not
+        // point at the sunk instruction.
+ if (!isa<CastInst>(I))
+ continue; // dbg.declare points at something it shouldn't
+
+ DII->setOperand(
+ 0, MetadataAsValue::get(I->getContext(),
+ ValueAsMetadata::get(I->getOperand(0))));
+ continue;
+ }
+
// dbg.value is in the same basic block as the sunk inst, see if we can
// salvage it. Clone a new copy of the instruction: on success we need
// both salvaged and unsalvaged copies.
@@ -3580,7 +3601,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
// Required analyses.
auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp
index 6821e214e921..d92ee11c2e1a 100644
--- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -129,6 +129,8 @@ static const uintptr_t kRetiredStackFrameMagic = 0x45E0360E;
static const char *const kAsanModuleCtorName = "asan.module_ctor";
static const char *const kAsanModuleDtorName = "asan.module_dtor";
static const uint64_t kAsanCtorAndDtorPriority = 1;
+// On Emscripten, the system needs more than one priority for constructors.
+static const uint64_t kAsanEmscriptenCtorAndDtorPriority = 50;
static const char *const kAsanReportErrorTemplate = "__asan_report_";
static const char *const kAsanRegisterGlobalsName = "__asan_register_globals";
static const char *const kAsanUnregisterGlobalsName =
@@ -191,6 +193,11 @@ static cl::opt<bool> ClRecover(
cl::desc("Enable recovery mode (continue-after-error)."),
cl::Hidden, cl::init(false));
+static cl::opt<bool> ClInsertVersionCheck(
+ "asan-guard-against-version-mismatch",
+ cl::desc("Guard against compiler/runtime version mismatch."),
+ cl::Hidden, cl::init(true));
+
// This flag may need to be replaced with -f[no-]asan-reads.
static cl::opt<bool> ClInstrumentReads("asan-instrument-reads",
cl::desc("instrument read instructions"),
@@ -530,6 +537,14 @@ static size_t RedzoneSizeForScale(int MappingScale) {
return std::max(32U, 1U << MappingScale);
}
+static uint64_t GetCtorAndDtorPriority(Triple &TargetTriple) {
+ if (TargetTriple.isOSEmscripten()) {
+ return kAsanEmscriptenCtorAndDtorPriority;
+ } else {
+ return kAsanCtorAndDtorPriority;
+ }
+}
+
namespace {
/// Module analysis for getting various metadata about the module.
@@ -565,10 +580,10 @@ char ASanGlobalsMetadataWrapperPass::ID = 0;
/// AddressSanitizer: instrument the code in module to find memory bugs.
struct AddressSanitizer {
- AddressSanitizer(Module &M, GlobalsMetadata &GlobalsMD,
+ AddressSanitizer(Module &M, const GlobalsMetadata *GlobalsMD,
bool CompileKernel = false, bool Recover = false,
bool UseAfterScope = false)
- : UseAfterScope(UseAfterScope || ClUseAfterScope), GlobalsMD(GlobalsMD) {
+ : UseAfterScope(UseAfterScope || ClUseAfterScope), GlobalsMD(*GlobalsMD) {
this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover;
this->CompileKernel =
ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan : CompileKernel;
@@ -677,7 +692,7 @@ private:
FunctionCallee AsanMemmove, AsanMemcpy, AsanMemset;
InlineAsm *EmptyAsm;
Value *LocalDynamicShadow = nullptr;
- GlobalsMetadata GlobalsMD;
+ const GlobalsMetadata &GlobalsMD;
DenseMap<const AllocaInst *, bool> ProcessedAllocas;
};
@@ -706,8 +721,8 @@ public:
GlobalsMetadata &GlobalsMD =
getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD();
const TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
- AddressSanitizer ASan(*F.getParent(), GlobalsMD, CompileKernel, Recover,
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ AddressSanitizer ASan(*F.getParent(), &GlobalsMD, CompileKernel, Recover,
UseAfterScope);
return ASan.instrumentFunction(F, TLI);
}
@@ -720,10 +735,10 @@ private:
class ModuleAddressSanitizer {
public:
- ModuleAddressSanitizer(Module &M, GlobalsMetadata &GlobalsMD,
+ ModuleAddressSanitizer(Module &M, const GlobalsMetadata *GlobalsMD,
bool CompileKernel = false, bool Recover = false,
bool UseGlobalsGC = true, bool UseOdrIndicator = false)
- : GlobalsMD(GlobalsMD), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC),
+ : GlobalsMD(*GlobalsMD), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC),
// Enable aliases as they should have no downside with ODR indicators.
UsePrivateAlias(UseOdrIndicator || ClUsePrivateAlias),
UseOdrIndicator(UseOdrIndicator || ClUseOdrIndicator),
@@ -783,7 +798,7 @@ private:
}
int GetAsanVersion(const Module &M) const;
- GlobalsMetadata GlobalsMD;
+ const GlobalsMetadata &GlobalsMD;
bool CompileKernel;
bool Recover;
bool UseGlobalsGC;
@@ -830,7 +845,7 @@ public:
bool runOnModule(Module &M) override {
GlobalsMetadata &GlobalsMD =
getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD();
- ModuleAddressSanitizer ASanModule(M, GlobalsMD, CompileKernel, Recover,
+ ModuleAddressSanitizer ASanModule(M, &GlobalsMD, CompileKernel, Recover,
UseGlobalGC, UseOdrIndicator);
return ASanModule.instrumentModule(M);
}
@@ -1033,7 +1048,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
if (!II.isLifetimeStartOrEnd())
return;
// Found lifetime intrinsic, add ASan instrumentation if necessary.
- ConstantInt *Size = dyn_cast<ConstantInt>(II.getArgOperand(0));
+ auto *Size = cast<ConstantInt>(II.getArgOperand(0));
// If size argument is undefined, don't do anything.
if (Size->isMinusOne()) return;
// Check that size doesn't saturate uint64_t and can
@@ -1156,7 +1171,7 @@ PreservedAnalyses AddressSanitizerPass::run(Function &F,
Module &M = *F.getParent();
if (auto *R = MAM.getCachedResult<ASanGlobalsMetadataAnalysis>(M)) {
const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
- AddressSanitizer Sanitizer(M, *R, CompileKernel, Recover, UseAfterScope);
+ AddressSanitizer Sanitizer(M, R, CompileKernel, Recover, UseAfterScope);
if (Sanitizer.instrumentFunction(F, TLI))
return PreservedAnalyses::none();
return PreservedAnalyses::all();
@@ -1178,7 +1193,7 @@ ModuleAddressSanitizerPass::ModuleAddressSanitizerPass(bool CompileKernel,
PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M,
AnalysisManager<Module> &AM) {
GlobalsMetadata &GlobalsMD = AM.getResult<ASanGlobalsMetadataAnalysis>(M);
- ModuleAddressSanitizer Sanitizer(M, GlobalsMD, CompileKernel, Recover,
+ ModuleAddressSanitizer Sanitizer(M, &GlobalsMD, CompileKernel, Recover,
UseGlobalGC, UseOdrIndicator);
if (Sanitizer.instrumentModule(M))
return PreservedAnalyses::none();
@@ -1331,7 +1346,7 @@ Value *AddressSanitizer::isInterestingMemoryAccess(Instruction *I,
unsigned *Alignment,
Value **MaybeMask) {
// Skip memory accesses inserted by another instrumentation.
- if (I->getMetadata("nosanitize")) return nullptr;
+ if (I->hasMetadata("nosanitize")) return nullptr;
// Do not instrument the load fetching the dynamic shadow address.
if (LocalDynamicShadow == I)
@@ -1775,9 +1790,10 @@ void ModuleAddressSanitizer::createInitializerPoisonCalls(
// Must have a function or null ptr.
if (Function *F = dyn_cast<Function>(CS->getOperand(1))) {
if (F->getName() == kAsanModuleCtorName) continue;
- ConstantInt *Priority = dyn_cast<ConstantInt>(CS->getOperand(0));
+ auto *Priority = cast<ConstantInt>(CS->getOperand(0));
// Don't instrument CTORs that will run before asan.module_ctor.
- if (Priority->getLimitedValue() <= kAsanCtorAndDtorPriority) continue;
+ if (Priority->getLimitedValue() <= GetCtorAndDtorPriority(TargetTriple))
+ continue;
poisonOneInitializer(*F, ModuleName);
}
}
@@ -1919,7 +1935,12 @@ StringRef ModuleAddressSanitizer::getGlobalMetadataSection() const {
case Triple::COFF: return ".ASAN$GL";
case Triple::ELF: return "asan_globals";
case Triple::MachO: return "__DATA,__asan_globals,regular";
- default: break;
+ case Triple::Wasm:
+ case Triple::XCOFF:
+ report_fatal_error(
+ "ModuleAddressSanitizer not implemented for object file format.");
+ case Triple::UnknownObjectFormat:
+ break;
}
llvm_unreachable("unsupported object format");
}
@@ -2033,7 +2054,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsCOFF(
unsigned SizeOfGlobalStruct = DL.getTypeAllocSize(Initializer->getType());
assert(isPowerOf2_32(SizeOfGlobalStruct) &&
"global metadata will not be padded appropriately");
- Metadata->setAlignment(SizeOfGlobalStruct);
+ Metadata->setAlignment(assumeAligned(SizeOfGlobalStruct));
SetComdatForGlobalMetadata(G, Metadata, "");
}
@@ -2170,7 +2191,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsWithMetadataArray(
M, ArrayOfGlobalStructTy, false, GlobalVariable::InternalLinkage,
ConstantArray::get(ArrayOfGlobalStructTy, MetadataInitializers), "");
if (Mapping.Scale > 3)
- AllGlobals->setAlignment(1ULL << Mapping.Scale);
+ AllGlobals->setAlignment(Align(1ULL << Mapping.Scale));
IRB.CreateCall(AsanRegisterGlobals,
{IRB.CreatePointerCast(AllGlobals, IntptrTy),
@@ -2270,7 +2291,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
"", G, G->getThreadLocalMode());
NewGlobal->copyAttributesFrom(G);
NewGlobal->setComdat(G->getComdat());
- NewGlobal->setAlignment(MinRZ);
+ NewGlobal->setAlignment(MaybeAlign(MinRZ));
// Don't fold globals with redzones. ODR violation detector and redzone
// poisoning implicitly creates a dependence on the global's address, so it
// is no longer valid for it to be marked unnamed_addr.
@@ -2338,7 +2359,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
// Set meaningful attributes for indicator symbol.
ODRIndicatorSym->setVisibility(NewGlobal->getVisibility());
ODRIndicatorSym->setDLLStorageClass(NewGlobal->getDLLStorageClass());
- ODRIndicatorSym->setAlignment(1);
+ ODRIndicatorSym->setAlignment(Align::None());
ODRIndicator = ODRIndicatorSym;
}
@@ -2410,39 +2431,39 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) {
// Create a module constructor. A destructor is created lazily because not all
// platforms, and not all modules need it.
+ std::string AsanVersion = std::to_string(GetAsanVersion(M));
std::string VersionCheckName =
- kAsanVersionCheckNamePrefix + std::to_string(GetAsanVersion(M));
+ ClInsertVersionCheck ? (kAsanVersionCheckNamePrefix + AsanVersion) : "";
std::tie(AsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions(
M, kAsanModuleCtorName, kAsanInitName, /*InitArgTypes=*/{},
/*InitArgs=*/{}, VersionCheckName);
bool CtorComdat = true;
- bool Changed = false;
// TODO(glider): temporarily disabled globals instrumentation for KASan.
if (ClGlobals) {
IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator());
- Changed |= InstrumentGlobals(IRB, M, &CtorComdat);
+ InstrumentGlobals(IRB, M, &CtorComdat);
}
+ const uint64_t Priority = GetCtorAndDtorPriority(TargetTriple);
+
// Put the constructor and destructor in comdat if both
// (1) global instrumentation is not TU-specific
// (2) target is ELF.
if (UseCtorComdat && TargetTriple.isOSBinFormatELF() && CtorComdat) {
AsanCtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleCtorName));
- appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority,
- AsanCtorFunction);
+ appendToGlobalCtors(M, AsanCtorFunction, Priority, AsanCtorFunction);
if (AsanDtorFunction) {
AsanDtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleDtorName));
- appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority,
- AsanDtorFunction);
+ appendToGlobalDtors(M, AsanDtorFunction, Priority, AsanDtorFunction);
}
} else {
- appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority);
+ appendToGlobalCtors(M, AsanCtorFunction, Priority);
if (AsanDtorFunction)
- appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority);
+ appendToGlobalDtors(M, AsanDtorFunction, Priority);
}
- return Changed;
+ return true;
}
void AddressSanitizer::initializeCallbacks(Module &M) {
@@ -2664,7 +2685,7 @@ bool AddressSanitizer::instrumentFunction(Function &F,
if (CS) {
// A call inside BB.
TempsToInstrument.clear();
- if (CS.doesNotReturn() && !CS->getMetadata("nosanitize"))
+ if (CS.doesNotReturn() && !CS->hasMetadata("nosanitize"))
NoReturnCalls.push_back(CS.getInstruction());
}
if (CallInst *CI = dyn_cast<CallInst>(&Inst))
@@ -2877,18 +2898,19 @@ void FunctionStackPoisoner::copyArgsPassedByValToAllocas() {
for (Argument &Arg : F.args()) {
if (Arg.hasByValAttr()) {
Type *Ty = Arg.getType()->getPointerElementType();
- unsigned Align = Arg.getParamAlignment();
- if (Align == 0) Align = DL.getABITypeAlignment(Ty);
+ unsigned Alignment = Arg.getParamAlignment();
+ if (Alignment == 0)
+ Alignment = DL.getABITypeAlignment(Ty);
AllocaInst *AI = IRB.CreateAlloca(
Ty, nullptr,
(Arg.hasName() ? Arg.getName() : "Arg" + Twine(Arg.getArgNo())) +
".byval");
- AI->setAlignment(Align);
+ AI->setAlignment(Align(Alignment));
Arg.replaceAllUsesWith(AI);
uint64_t AllocSize = DL.getTypeAllocSize(Ty);
- IRB.CreateMemCpy(AI, Align, &Arg, Align, AllocSize);
+ IRB.CreateMemCpy(AI, Alignment, &Arg, Alignment, AllocSize);
}
}
}
@@ -2919,7 +2941,7 @@ Value *FunctionStackPoisoner::createAllocaForLayout(
}
assert((ClRealignStack & (ClRealignStack - 1)) == 0);
size_t FrameAlignment = std::max(L.FrameAlignment, (size_t)ClRealignStack);
- Alloca->setAlignment(FrameAlignment);
+ Alloca->setAlignment(MaybeAlign(FrameAlignment));
return IRB.CreatePointerCast(Alloca, IntptrTy);
}
@@ -2928,7 +2950,7 @@ void FunctionStackPoisoner::createDynamicAllocasInitStorage() {
IRBuilder<> IRB(dyn_cast<Instruction>(FirstBB.begin()));
DynamicAllocaLayout = IRB.CreateAlloca(IntptrTy, nullptr);
IRB.CreateStore(Constant::getNullValue(IntptrTy), DynamicAllocaLayout);
- DynamicAllocaLayout->setAlignment(32);
+ DynamicAllocaLayout->setAlignment(Align(32));
}
void FunctionStackPoisoner::processDynamicAllocas() {
@@ -3275,7 +3297,7 @@ void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) {
// Insert new alloca with new NewSize and Align params.
AllocaInst *NewAlloca = IRB.CreateAlloca(IRB.getInt8Ty(), NewSize);
- NewAlloca->setAlignment(Align);
+ NewAlloca->setAlignment(MaybeAlign(Align));
// NewAddress = Address + Align
Value *NewAddress = IRB.CreateAdd(IRB.CreatePtrToInt(NewAlloca, IntptrTy),
diff --git a/lib/Transforms/Instrumentation/BoundsChecking.cpp b/lib/Transforms/Instrumentation/BoundsChecking.cpp
index 4dc9b611c156..ae34be986537 100644
--- a/lib/Transforms/Instrumentation/BoundsChecking.cpp
+++ b/lib/Transforms/Instrumentation/BoundsChecking.cpp
@@ -224,7 +224,7 @@ struct BoundsCheckingLegacyPass : public FunctionPass {
}
bool runOnFunction(Function &F) override {
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
return addBoundsChecking(F, TLI, SE);
}
diff --git a/lib/Transforms/Instrumentation/CFGMST.h b/lib/Transforms/Instrumentation/CFGMST.h
index 971e00041762..8bb6f47c4846 100644
--- a/lib/Transforms/Instrumentation/CFGMST.h
+++ b/lib/Transforms/Instrumentation/CFGMST.h
@@ -257,13 +257,13 @@ public:
std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
if (Inserted) {
// Newly inserted, update the real info.
- Iter->second = std::move(llvm::make_unique<BBInfo>(Index));
+ Iter->second = std::move(std::make_unique<BBInfo>(Index));
Index++;
}
std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
if (Inserted)
// Newly inserted, update the real info.
- Iter->second = std::move(llvm::make_unique<BBInfo>(Index));
+ Iter->second = std::move(std::make_unique<BBInfo>(Index));
AllEdges.emplace_back(new Edge(Src, Dest, W));
return *AllEdges.back();
}
diff --git a/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 3f4f9bc7145d..55c64fa4b727 100644
--- a/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -512,30 +512,38 @@ static bool isHoistable(Instruction *I, DominatorTree &DT) {
// first-region entry block) or the (hoistable or unhoistable) base values that
// are defined outside (including the first-region entry block) of the
// scope. The returned set doesn't include constants.
-static std::set<Value *> getBaseValues(Value *V,
- DominatorTree &DT) {
+static std::set<Value *> getBaseValues(
+ Value *V, DominatorTree &DT,
+ DenseMap<Value *, std::set<Value *>> &Visited) {
+ if (Visited.count(V)) {
+ return Visited[V];
+ }
std::set<Value *> Result;
if (auto *I = dyn_cast<Instruction>(V)) {
// We don't stop at a block that's not in the Scope because we would miss some
// instructions that are based on the same base values if we stop there.
if (!isHoistable(I, DT)) {
Result.insert(I);
+ Visited.insert(std::make_pair(V, Result));
return Result;
}
// I is hoistable above the Scope.
for (Value *Op : I->operands()) {
- std::set<Value *> OpResult = getBaseValues(Op, DT);
+ std::set<Value *> OpResult = getBaseValues(Op, DT, Visited);
Result.insert(OpResult.begin(), OpResult.end());
}
+ Visited.insert(std::make_pair(V, Result));
return Result;
}
if (isa<Argument>(V)) {
Result.insert(V);
+ Visited.insert(std::make_pair(V, Result));
return Result;
}
// We don't include others like constants because those won't lead to any
// chance of folding of conditions (eg two bit checks merged into one check)
// after CHR.
+ Visited.insert(std::make_pair(V, Result));
return Result; // empty
}
@@ -1078,12 +1086,13 @@ static bool shouldSplit(Instruction *InsertPoint,
if (!PrevConditionValues.empty() && !ConditionValues.empty()) {
// Use std::set as DenseSet doesn't work with set_intersection.
std::set<Value *> PrevBases, Bases;
+ DenseMap<Value *, std::set<Value *>> Visited;
for (Value *V : PrevConditionValues) {
- std::set<Value *> BaseValues = getBaseValues(V, DT);
+ std::set<Value *> BaseValues = getBaseValues(V, DT, Visited);
PrevBases.insert(BaseValues.begin(), BaseValues.end());
}
for (Value *V : ConditionValues) {
- std::set<Value *> BaseValues = getBaseValues(V, DT);
+ std::set<Value *> BaseValues = getBaseValues(V, DT, Visited);
Bases.insert(BaseValues.begin(), BaseValues.end());
}
CHR_DEBUG(
@@ -1538,10 +1547,7 @@ static bool negateICmpIfUsedByBranchOrSelectOnly(ICmpInst *ICmp,
}
if (auto *SI = dyn_cast<SelectInst>(U)) {
// Swap operands
- Value *TrueValue = SI->getTrueValue();
- Value *FalseValue = SI->getFalseValue();
- SI->setTrueValue(FalseValue);
- SI->setFalseValue(TrueValue);
+ SI->swapValues();
SI->swapProfMetadata();
if (Scope->TrueBiasedSelects.count(SI)) {
assert(Scope->FalseBiasedSelects.count(SI) == 0 &&
@@ -2073,7 +2079,7 @@ bool ControlHeightReductionLegacyPass::runOnFunction(Function &F) {
getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
RegionInfo &RI = getAnalysis<RegionInfoPass>().getRegionInfo();
std::unique_ptr<OptimizationRemarkEmitter> OwnedORE =
- llvm::make_unique<OptimizationRemarkEmitter>(&F);
+ std::make_unique<OptimizationRemarkEmitter>(&F);
return CHR(F, BFI, DT, PSI, RI, *OwnedORE.get()).run();
}
diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index 2279c1bcb6a8..c0353cba0b2f 100644
--- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -1212,7 +1212,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align,
return DFS.ZeroShadow;
case 1: {
LoadInst *LI = new LoadInst(DFS.ShadowTy, ShadowAddr, "", Pos);
- LI->setAlignment(ShadowAlign);
+ LI->setAlignment(MaybeAlign(ShadowAlign));
return LI;
}
case 2: {
diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp
index 59950ffc4e9a..ac6082441eae 100644
--- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp
+++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp
@@ -86,7 +86,9 @@ public:
ReversedVersion[3] = Options.Version[0];
ReversedVersion[4] = '\0';
}
- bool runOnModule(Module &M, const TargetLibraryInfo &TLI);
+ bool
+ runOnModule(Module &M,
+ std::function<const TargetLibraryInfo &(Function &F)> GetTLI);
private:
// Create the .gcno files for the Module based on DebugInfo.
@@ -102,9 +104,9 @@ private:
std::vector<Regex> &Regexes);
// Get pointers to the functions in the runtime library.
- FunctionCallee getStartFileFunc();
- FunctionCallee getEmitFunctionFunc();
- FunctionCallee getEmitArcsFunc();
+ FunctionCallee getStartFileFunc(const TargetLibraryInfo *TLI);
+ FunctionCallee getEmitFunctionFunc(const TargetLibraryInfo *TLI);
+ FunctionCallee getEmitArcsFunc(const TargetLibraryInfo *TLI);
FunctionCallee getSummaryInfoFunc();
FunctionCallee getEndFileFunc();
@@ -127,7 +129,7 @@ private:
SmallVector<uint32_t, 4> FileChecksums;
Module *M;
- const TargetLibraryInfo *TLI;
+ std::function<const TargetLibraryInfo &(Function &F)> GetTLI;
LLVMContext *Ctx;
SmallVector<std::unique_ptr<GCOVFunction>, 16> Funcs;
std::vector<Regex> FilterRe;
@@ -147,8 +149,9 @@ public:
StringRef getPassName() const override { return "GCOV Profiler"; }
bool runOnModule(Module &M) override {
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
- return Profiler.runOnModule(M, TLI);
+ return Profiler.runOnModule(M, [this](Function &F) -> TargetLibraryInfo & {
+ return getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ });
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -555,9 +558,10 @@ std::string GCOVProfiler::mangleName(const DICompileUnit *CU,
return CurPath.str();
}
-bool GCOVProfiler::runOnModule(Module &M, const TargetLibraryInfo &TLI) {
+bool GCOVProfiler::runOnModule(
+ Module &M, std::function<const TargetLibraryInfo &(Function &F)> GetTLI) {
this->M = &M;
- this->TLI = &TLI;
+ this->GetTLI = std::move(GetTLI);
Ctx = &M.getContext();
AddFlushBeforeForkAndExec();
@@ -574,9 +578,12 @@ PreservedAnalyses GCOVProfilerPass::run(Module &M,
ModuleAnalysisManager &AM) {
GCOVProfiler Profiler(GCOVOpts);
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
- auto &TLI = AM.getResult<TargetLibraryAnalysis>(M);
- if (!Profiler.runOnModule(M, TLI))
+ if (!Profiler.runOnModule(M, [&](Function &F) -> TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ }))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
@@ -624,6 +631,7 @@ static bool shouldKeepInEntry(BasicBlock::iterator It) {
void GCOVProfiler::AddFlushBeforeForkAndExec() {
SmallVector<Instruction *, 2> ForkAndExecs;
for (auto &F : M->functions()) {
+ auto *TLI = &GetTLI(F);
for (auto &I : instructions(F)) {
if (CallInst *CI = dyn_cast<CallInst>(&I)) {
if (Function *Callee = CI->getCalledFunction()) {
@@ -669,7 +677,8 @@ void GCOVProfiler::emitProfileNotes() {
continue;
std::error_code EC;
- raw_fd_ostream out(mangleName(CU, GCovFileType::GCNO), EC, sys::fs::F_None);
+ raw_fd_ostream out(mangleName(CU, GCovFileType::GCNO), EC,
+ sys::fs::OF_None);
if (EC) {
Ctx->emitError(Twine("failed to open coverage notes file for writing: ") +
EC.message());
@@ -695,7 +704,7 @@ void GCOVProfiler::emitProfileNotes() {
++It;
EntryBlock.splitBasicBlock(It);
- Funcs.push_back(make_unique<GCOVFunction>(SP, &F, &out, FunctionIdent++,
+ Funcs.push_back(std::make_unique<GCOVFunction>(SP, &F, &out, FunctionIdent++,
Options.UseCfgChecksum,
Options.ExitBlockBeforeBody));
GCOVFunction &Func = *Funcs.back();
@@ -873,7 +882,7 @@ bool GCOVProfiler::emitProfileArcs() {
return Result;
}
-FunctionCallee GCOVProfiler::getStartFileFunc() {
+FunctionCallee GCOVProfiler::getStartFileFunc(const TargetLibraryInfo *TLI) {
Type *Args[] = {
Type::getInt8PtrTy(*Ctx), // const char *orig_filename
Type::getInt8PtrTy(*Ctx), // const char version[4]
@@ -887,7 +896,7 @@ FunctionCallee GCOVProfiler::getStartFileFunc() {
return Res;
}
-FunctionCallee GCOVProfiler::getEmitFunctionFunc() {
+FunctionCallee GCOVProfiler::getEmitFunctionFunc(const TargetLibraryInfo *TLI) {
Type *Args[] = {
Type::getInt32Ty(*Ctx), // uint32_t ident
Type::getInt8PtrTy(*Ctx), // const char *function_name
@@ -906,7 +915,7 @@ FunctionCallee GCOVProfiler::getEmitFunctionFunc() {
return M->getOrInsertFunction("llvm_gcda_emit_function", FTy);
}
-FunctionCallee GCOVProfiler::getEmitArcsFunc() {
+FunctionCallee GCOVProfiler::getEmitArcsFunc(const TargetLibraryInfo *TLI) {
Type *Args[] = {
Type::getInt32Ty(*Ctx), // uint32_t num_counters
Type::getInt64PtrTy(*Ctx), // uint64_t *counters
@@ -943,9 +952,11 @@ Function *GCOVProfiler::insertCounterWriteout(
BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", WriteoutF);
IRBuilder<> Builder(BB);
- FunctionCallee StartFile = getStartFileFunc();
- FunctionCallee EmitFunction = getEmitFunctionFunc();
- FunctionCallee EmitArcs = getEmitArcsFunc();
+ auto *TLI = &GetTLI(*WriteoutF);
+
+ FunctionCallee StartFile = getStartFileFunc(TLI);
+ FunctionCallee EmitFunction = getEmitFunctionFunc(TLI);
+ FunctionCallee EmitArcs = getEmitArcsFunc(TLI);
FunctionCallee SummaryInfo = getSummaryInfoFunc();
FunctionCallee EndFile = getEndFileFunc();
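The GCOVProfiler hunks above replace a single cached TargetLibraryInfo with a per-function getter, and InstrProfiling later in this patch does the same; the legacy and new pass managers each supply a suitable callback. A minimal sketch of that shape, with a hypothetical MyInstrumenter standing in for the real passes:

#include <functional>
#include <utility>

namespace llvm {
class Function;
class TargetLibraryInfo;
} // namespace llvm

// Stores the getter instead of a TargetLibraryInfo so that library-call
// information can be resolved per function, when it is actually needed.
class MyInstrumenter {
  std::function<const llvm::TargetLibraryInfo &(llvm::Function &)> GetTLI;

public:
  explicit MyInstrumenter(
      std::function<const llvm::TargetLibraryInfo &(llvm::Function &)> GetTLI)
      : GetTLI(std::move(GetTLI)) {}

  void visit(llvm::Function &F) {
    const llvm::TargetLibraryInfo &TLI = GetTLI(F); // resolved lazily, per F
    (void)TLI;
  }
};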
diff --git a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
index 90a9f4955a4b..f87132ee4758 100644
--- a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -12,10 +12,12 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Instrumentation/HWAddressSanitizer.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
+#include "llvm/BinaryFormat/ELF.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
@@ -52,7 +54,10 @@ using namespace llvm;
#define DEBUG_TYPE "hwasan"
static const char *const kHwasanModuleCtorName = "hwasan.module_ctor";
+static const char *const kHwasanNoteName = "hwasan.note";
static const char *const kHwasanInitName = "__hwasan_init";
+static const char *const kHwasanPersonalityThunkName =
+ "__hwasan_personality_thunk";
static const char *const kHwasanShadowMemoryDynamicAddress =
"__hwasan_shadow_memory_dynamic_address";
@@ -112,6 +117,9 @@ static cl::opt<bool> ClGenerateTagsWithCalls(
cl::desc("generate new tags with runtime library calls"), cl::Hidden,
cl::init(false));
+static cl::opt<bool> ClGlobals("hwasan-globals", cl::desc("Instrument globals"),
+ cl::Hidden, cl::init(false));
+
static cl::opt<int> ClMatchAllTag(
"hwasan-match-all-tag",
cl::desc("don't report bad accesses via pointers with this tag"),
@@ -155,8 +163,18 @@ static cl::opt<bool>
static cl::opt<bool>
ClInstrumentLandingPads("hwasan-instrument-landing-pads",
- cl::desc("instrument landing pads"), cl::Hidden,
- cl::init(true));
+ cl::desc("instrument landing pads"), cl::Hidden,
+ cl::init(false), cl::ZeroOrMore);
+
+static cl::opt<bool> ClUseShortGranules(
+ "hwasan-use-short-granules",
+ cl::desc("use short granules in allocas and outlined checks"), cl::Hidden,
+ cl::init(false), cl::ZeroOrMore);
+
+static cl::opt<bool> ClInstrumentPersonalityFunctions(
+ "hwasan-instrument-personality-functions",
+ cl::desc("instrument personality functions"), cl::Hidden, cl::init(false),
+ cl::ZeroOrMore);
static cl::opt<bool> ClInlineAllChecks("hwasan-inline-all-checks",
cl::desc("inline all checks"),
@@ -169,16 +187,16 @@ namespace {
class HWAddressSanitizer {
public:
explicit HWAddressSanitizer(Module &M, bool CompileKernel = false,
- bool Recover = false) {
+ bool Recover = false) : M(M) {
this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover;
this->CompileKernel = ClEnableKhwasan.getNumOccurrences() > 0 ?
ClEnableKhwasan : CompileKernel;
- initializeModule(M);
+ initializeModule();
}
bool sanitizeFunction(Function &F);
- void initializeModule(Module &M);
+ void initializeModule();
void initializeCallbacks(Module &M);
@@ -216,9 +234,14 @@ public:
Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty);
void emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord);
+ void instrumentGlobal(GlobalVariable *GV, uint8_t Tag);
+ void instrumentGlobals();
+
+ void instrumentPersonalityFunctions();
+
private:
LLVMContext *C;
- std::string CurModuleUniqueId;
+ Module &M;
Triple TargetTriple;
FunctionCallee HWAsanMemmove, HWAsanMemcpy, HWAsanMemset;
FunctionCallee HWAsanHandleVfork;
@@ -238,17 +261,21 @@ private:
bool InTls;
void init(Triple &TargetTriple);
- unsigned getAllocaAlignment() const { return 1U << Scale; }
+ unsigned getObjectAlignment() const { return 1U << Scale; }
};
ShadowMapping Mapping;
+ Type *VoidTy = Type::getVoidTy(M.getContext());
Type *IntptrTy;
Type *Int8PtrTy;
Type *Int8Ty;
Type *Int32Ty;
+ Type *Int64Ty = Type::getInt64Ty(M.getContext());
bool CompileKernel;
bool Recover;
+ bool UseShortGranules;
+ bool InstrumentLandingPads;
Function *HwasanCtorFunction;
@@ -278,7 +305,7 @@ public:
StringRef getPassName() const override { return "HWAddressSanitizer"; }
bool doInitialization(Module &M) override {
- HWASan = llvm::make_unique<HWAddressSanitizer>(M, CompileKernel, Recover);
+ HWASan = std::make_unique<HWAddressSanitizer>(M, CompileKernel, Recover);
return true;
}
@@ -333,7 +360,7 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M,
/// Module-level initialization.
///
/// inserts a call to __hwasan_init to the module's constructor list.
-void HWAddressSanitizer::initializeModule(Module &M) {
+void HWAddressSanitizer::initializeModule() {
LLVM_DEBUG(dbgs() << "Init " << M.getName() << "\n");
auto &DL = M.getDataLayout();
@@ -342,7 +369,6 @@ void HWAddressSanitizer::initializeModule(Module &M) {
Mapping.init(TargetTriple);
C = &(M.getContext());
- CurModuleUniqueId = getUniqueModuleId(&M);
IRBuilder<> IRB(*C);
IntptrTy = IRB.getIntPtrTy(DL);
Int8PtrTy = IRB.getInt8PtrTy();
@@ -350,6 +376,21 @@ void HWAddressSanitizer::initializeModule(Module &M) {
Int32Ty = IRB.getInt32Ty();
HwasanCtorFunction = nullptr;
+
+ // Older versions of Android do not have the required runtime support for
+ // short granules, global or personality function instrumentation. On other
+ // platforms we currently require using the latest version of the runtime.
+ bool NewRuntime =
+ !TargetTriple.isAndroid() || !TargetTriple.isAndroidVersionLT(30);
+
+ UseShortGranules =
+ ClUseShortGranules.getNumOccurrences() ? ClUseShortGranules : NewRuntime;
+
+ // If we don't have personality function support, fall back to landing pads.
+ InstrumentLandingPads = ClInstrumentLandingPads.getNumOccurrences()
+ ? ClInstrumentLandingPads
+ : !NewRuntime;
+
if (!CompileKernel) {
std::tie(HwasanCtorFunction, std::ignore) =
getOrCreateSanitizerCtorAndInitFunctions(
@@ -363,6 +404,18 @@ void HWAddressSanitizer::initializeModule(Module &M) {
Ctor->setComdat(CtorComdat);
appendToGlobalCtors(M, Ctor, 0, Ctor);
});
+
+ bool InstrumentGlobals =
+ ClGlobals.getNumOccurrences() ? ClGlobals : NewRuntime;
+ if (InstrumentGlobals)
+ instrumentGlobals();
+
+ bool InstrumentPersonalityFunctions =
+ ClInstrumentPersonalityFunctions.getNumOccurrences()
+ ? ClInstrumentPersonalityFunctions
+ : NewRuntime;
+ if (InstrumentPersonalityFunctions)
+ instrumentPersonalityFunctions();
}
if (!TargetTriple.isAndroid()) {
@@ -456,7 +509,7 @@ Value *HWAddressSanitizer::isInterestingMemoryAccess(Instruction *I,
unsigned *Alignment,
Value **MaybeMask) {
// Skip memory accesses inserted by another instrumentation.
- if (I->getMetadata("nosanitize")) return nullptr;
+ if (I->hasMetadata("nosanitize")) return nullptr;
// Do not instrument the load fetching the dynamic shadow address.
if (LocalDynamicShadow == I)
@@ -564,9 +617,11 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
TargetTriple.isOSBinFormatELF() && !Recover) {
Module *M = IRB.GetInsertBlock()->getParent()->getParent();
Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy);
- IRB.CreateCall(
- Intrinsic::getDeclaration(M, Intrinsic::hwasan_check_memaccess),
- {shadowBase(), Ptr, ConstantInt::get(Int32Ty, AccessInfo)});
+ IRB.CreateCall(Intrinsic::getDeclaration(
+ M, UseShortGranules
+ ? Intrinsic::hwasan_check_memaccess_shortgranules
+ : Intrinsic::hwasan_check_memaccess),
+ {shadowBase(), Ptr, ConstantInt::get(Int32Ty, AccessInfo)});
return;
}
@@ -718,7 +773,9 @@ static uint64_t getAllocaSizeInBytes(const AllocaInst &AI) {
bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI,
Value *Tag, size_t Size) {
- size_t AlignedSize = alignTo(Size, Mapping.getAllocaAlignment());
+ size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment());
+ if (!UseShortGranules)
+ Size = AlignedSize;
Value *JustTag = IRB.CreateTrunc(Tag, IRB.getInt8Ty());
if (ClInstrumentWithCalls) {
@@ -738,7 +795,7 @@ bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI,
IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, /*Align=*/1);
if (Size != AlignedSize) {
IRB.CreateStore(
- ConstantInt::get(Int8Ty, Size % Mapping.getAllocaAlignment()),
+ ConstantInt::get(Int8Ty, Size % Mapping.getObjectAlignment()),
IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize));
IRB.CreateStore(JustTag, IRB.CreateConstGEP1_32(
Int8Ty, IRB.CreateBitCast(AI, Int8PtrTy),
@@ -778,8 +835,9 @@ Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) {
// FIXME: use addressofreturnaddress (but implement it in aarch64 backend
// first).
Module *M = IRB.GetInsertBlock()->getParent()->getParent();
- auto GetStackPointerFn =
- Intrinsic::getDeclaration(M, Intrinsic::frameaddress);
+ auto GetStackPointerFn = Intrinsic::getDeclaration(
+ M, Intrinsic::frameaddress,
+ IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace()));
Value *StackPointer = IRB.CreateCall(
GetStackPointerFn, {Constant::getNullValue(IRB.getInt32Ty())});
@@ -912,8 +970,10 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) {
PC = readRegister(IRB, "pc");
else
PC = IRB.CreatePtrToInt(F, IntptrTy);
- auto GetStackPointerFn =
- Intrinsic::getDeclaration(F->getParent(), Intrinsic::frameaddress);
+ Module *M = F->getParent();
+ auto GetStackPointerFn = Intrinsic::getDeclaration(
+ M, Intrinsic::frameaddress,
+ IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace()));
Value *SP = IRB.CreatePtrToInt(
IRB.CreateCall(GetStackPointerFn,
{Constant::getNullValue(IRB.getInt32Ty())}),
@@ -999,11 +1059,8 @@ bool HWAddressSanitizer::instrumentStack(
AI->hasName() ? AI->getName().str() : "alloca." + itostr(N);
Replacement->setName(Name + ".hwasan");
- for (auto UI = AI->use_begin(), UE = AI->use_end(); UI != UE;) {
- Use &U = *UI++;
- if (U.getUser() != AILong)
- U.set(Replacement);
- }
+ AI->replaceUsesWithIf(Replacement,
+ [AILong](Use &U) { return U.getUser() != AILong; });
for (auto *DDI : AllocaDeclareMap.lookup(AI)) {
DIExpression *OldExpr = DDI->getExpression();
@@ -1020,7 +1077,7 @@ bool HWAddressSanitizer::instrumentStack(
// Re-tag alloca memory with the special UAR tag.
Value *Tag = getUARTag(IRB, StackTag);
- tagAlloca(IRB, AI, Tag, alignTo(Size, Mapping.getAllocaAlignment()));
+ tagAlloca(IRB, AI, Tag, alignTo(Size, Mapping.getObjectAlignment()));
}
}
@@ -1074,7 +1131,7 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) {
if (auto *Alloca = dyn_cast_or_null<AllocaInst>(DDI->getAddress()))
AllocaDeclareMap[Alloca].push_back(DDI);
- if (ClInstrumentLandingPads && isa<LandingPadInst>(Inst))
+ if (InstrumentLandingPads && isa<LandingPadInst>(Inst))
LandingPadVec.push_back(&Inst);
Value *MaybeMask = nullptr;
@@ -1093,6 +1150,13 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) {
if (!LandingPadVec.empty())
instrumentLandingPads(LandingPadVec);
+ if (AllocasToInstrument.empty() && F.hasPersonalityFn() &&
+ F.getPersonalityFn()->getName() == kHwasanPersonalityThunkName) {
+ // __hwasan_personality_thunk is a no-op for functions without an
+ // instrumented stack, so we can drop it.
+ F.setPersonalityFn(nullptr);
+ }
+
if (AllocasToInstrument.empty() && ToInstrument.empty())
return false;
@@ -1118,8 +1182,9 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) {
DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap;
for (AllocaInst *AI : AllocasToInstrument) {
uint64_t Size = getAllocaSizeInBytes(*AI);
- uint64_t AlignedSize = alignTo(Size, Mapping.getAllocaAlignment());
- AI->setAlignment(std::max(AI->getAlignment(), 16u));
+ uint64_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment());
+ AI->setAlignment(
+ MaybeAlign(std::max(AI->getAlignment(), Mapping.getObjectAlignment())));
if (Size != AlignedSize) {
Type *AllocatedType = AI->getAllocatedType();
if (AI->isArrayAllocation()) {
@@ -1132,7 +1197,7 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) {
auto *NewAI = new AllocaInst(
TypeWithPadding, AI->getType()->getAddressSpace(), nullptr, "", AI);
NewAI->takeName(AI);
- NewAI->setAlignment(AI->getAlignment());
+ NewAI->setAlignment(MaybeAlign(AI->getAlignment()));
NewAI->setUsedWithInAlloca(AI->isUsedWithInAlloca());
NewAI->setSwiftError(AI->isSwiftError());
NewAI->copyMetadata(*AI);
@@ -1179,6 +1244,257 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) {
return Changed;
}
+void HWAddressSanitizer::instrumentGlobal(GlobalVariable *GV, uint8_t Tag) {
+ Constant *Initializer = GV->getInitializer();
+ uint64_t SizeInBytes =
+ M.getDataLayout().getTypeAllocSize(Initializer->getType());
+ uint64_t NewSize = alignTo(SizeInBytes, Mapping.getObjectAlignment());
+ if (SizeInBytes != NewSize) {
+ // Pad the initializer out to the next multiple of 16 bytes and add the
+ // required short granule tag.
+ std::vector<uint8_t> Init(NewSize - SizeInBytes, 0);
+ Init.back() = Tag;
+ Constant *Padding = ConstantDataArray::get(*C, Init);
+ Initializer = ConstantStruct::getAnon({Initializer, Padding});
+ }
+
+ auto *NewGV = new GlobalVariable(M, Initializer->getType(), GV->isConstant(),
+ GlobalValue::ExternalLinkage, Initializer,
+ GV->getName() + ".hwasan");
+ NewGV->copyAttributesFrom(GV);
+ NewGV->setLinkage(GlobalValue::PrivateLinkage);
+ NewGV->copyMetadata(GV, 0);
+ NewGV->setAlignment(
+ MaybeAlign(std::max(GV->getAlignment(), Mapping.getObjectAlignment())));
+
+ // It is invalid to ICF two globals that have different tags. In the case
+ // where the size of the global is a multiple of the tag granularity the
+ // contents of the globals may be the same but the tags (i.e. symbol values)
+ // may be different, and the symbols are not considered during ICF. In the
+ // case where the size is not a multiple of the granularity, the short granule
+ // tags would discriminate two globals with different tags, but there would
+ // otherwise be nothing stopping such a global from being incorrectly ICF'd
+ // with an uninstrumented (i.e. tag 0) global that happened to have the short
+ // granule tag in the last byte.
+ NewGV->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
+
+ // Descriptor format (assuming little-endian):
+ // bytes 0-3: relative address of global
+ // bytes 4-6: size of global (16MB ought to be enough for anyone, but in case
+ // it isn't, we create multiple descriptors)
+ // byte 7: tag
+ auto *DescriptorTy = StructType::get(Int32Ty, Int32Ty);
+ const uint64_t MaxDescriptorSize = 0xfffff0;
+ for (uint64_t DescriptorPos = 0; DescriptorPos < SizeInBytes;
+ DescriptorPos += MaxDescriptorSize) {
+ auto *Descriptor =
+ new GlobalVariable(M, DescriptorTy, true, GlobalValue::PrivateLinkage,
+ nullptr, GV->getName() + ".hwasan.descriptor");
+ auto *GVRelPtr = ConstantExpr::getTrunc(
+ ConstantExpr::getAdd(
+ ConstantExpr::getSub(
+ ConstantExpr::getPtrToInt(NewGV, Int64Ty),
+ ConstantExpr::getPtrToInt(Descriptor, Int64Ty)),
+ ConstantInt::get(Int64Ty, DescriptorPos)),
+ Int32Ty);
+ uint32_t Size = std::min(SizeInBytes - DescriptorPos, MaxDescriptorSize);
+ auto *SizeAndTag = ConstantInt::get(Int32Ty, Size | (uint32_t(Tag) << 24));
+ Descriptor->setComdat(NewGV->getComdat());
+ Descriptor->setInitializer(ConstantStruct::getAnon({GVRelPtr, SizeAndTag}));
+ Descriptor->setSection("hwasan_globals");
+ Descriptor->setMetadata(LLVMContext::MD_associated,
+ MDNode::get(*C, ValueAsMetadata::get(NewGV)));
+ appendToCompilerUsed(M, Descriptor);
+ }
+
+ Constant *Aliasee = ConstantExpr::getIntToPtr(
+ ConstantExpr::getAdd(
+ ConstantExpr::getPtrToInt(NewGV, Int64Ty),
+ ConstantInt::get(Int64Ty, uint64_t(Tag) << kPointerTagShift)),
+ GV->getType());
+ auto *Alias = GlobalAlias::create(GV->getValueType(), GV->getAddressSpace(),
+ GV->getLinkage(), "", Aliasee, &M);
+ Alias->setVisibility(GV->getVisibility());
+ Alias->takeName(GV);
+ GV->replaceAllUsesWith(Alias);
+ GV->eraseFromParent();
+}
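The descriptor built in instrumentGlobal above is an 8-byte record: a 32-bit address of the described chunk relative to the descriptor itself, followed by a 32-bit word holding the chunk size in the low 24 bits and the tag in the top 8. A sketch of how a consumer might decode one entry, assuming the little-endian layout the comment above assumes; the struct and helper names are illustrative, not the runtime's:

#include <cstdint>

struct HwasanGlobalDescriptor {
  int32_t GVRelPtr;    // tagged global (plus chunk offset) minus descriptor address
  uint32_t SizeAndTag; // chunk size in the low 24 bits, tag in the top 8 bits
};

static uint64_t chunkSize(const HwasanGlobalDescriptor &D) {
  return D.SizeAndTag & 0xffffff; // at most MaxDescriptorSize (0xfffff0)
}

static uint8_t chunkTag(const HwasanGlobalDescriptor &D) {
  return static_cast<uint8_t>(D.SizeAndTag >> 24);
}

static uintptr_t chunkAddress(const HwasanGlobalDescriptor &D) {
  // The relative pointer is computed against the descriptor's own address, so
  // adding it back recovers the start of the described chunk of the global.
  return reinterpret_cast<uintptr_t>(&D) + D.GVRelPtr;
}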
+
+void HWAddressSanitizer::instrumentGlobals() {
+ // Start by creating a note that contains pointers to the list of global
+ // descriptors. Adding a note to the output file will cause the linker to
+ // create a PT_NOTE program header pointing to the note that we can use to
+ // find the descriptor list starting from the program headers. A function
+ // provided by the runtime initializes the shadow memory for the globals by
+ // accessing the descriptor list via the note. The dynamic loader needs to
+ // call this function whenever a library is loaded.
+ //
+ // The reason why we use a note for this instead of a more conventional
+ // approach of having a global constructor pass a descriptor list pointer to
+ // the runtime is because of an order of initialization problem. With
+ // constructors we can encounter the following problematic scenario:
+ //
+ // 1) library A depends on library B and also interposes one of B's symbols
+ // 2) B's constructors are called before A's (as required for correctness)
+ // 3) during construction, B accesses one of its "own" globals (actually
+ // interposed by A) and triggers a HWASAN failure due to the initialization
+ // for A not having happened yet
+ //
+ // Even without interposition it is possible to run into similar situations in
+ // cases where two libraries mutually depend on each other.
+ //
+ // We only need one note per binary, so put everything for the note in a
+ // comdat.
+ Comdat *NoteComdat = M.getOrInsertComdat(kHwasanNoteName);
+
+ Type *Int8Arr0Ty = ArrayType::get(Int8Ty, 0);
+ auto Start =
+ new GlobalVariable(M, Int8Arr0Ty, true, GlobalVariable::ExternalLinkage,
+ nullptr, "__start_hwasan_globals");
+ Start->setVisibility(GlobalValue::HiddenVisibility);
+ Start->setDSOLocal(true);
+ auto Stop =
+ new GlobalVariable(M, Int8Arr0Ty, true, GlobalVariable::ExternalLinkage,
+ nullptr, "__stop_hwasan_globals");
+ Stop->setVisibility(GlobalValue::HiddenVisibility);
+ Stop->setDSOLocal(true);
+
+ // Null-terminated so actually 8 bytes, which are required in order to align
+ // the note properly.
+ auto *Name = ConstantDataArray::get(*C, "LLVM\0\0\0");
+
+ auto *NoteTy = StructType::get(Int32Ty, Int32Ty, Int32Ty, Name->getType(),
+ Int32Ty, Int32Ty);
+ auto *Note =
+ new GlobalVariable(M, NoteTy, /*isConstantGlobal=*/true,
+ GlobalValue::PrivateLinkage, nullptr, kHwasanNoteName);
+ Note->setSection(".note.hwasan.globals");
+ Note->setComdat(NoteComdat);
+ Note->setAlignment(Align(4));
+ Note->setDSOLocal(true);
+
+ // The pointers in the note need to be relative so that the note ends up being
+ // placed in rodata, which is the standard location for notes.
+ auto CreateRelPtr = [&](Constant *Ptr) {
+ return ConstantExpr::getTrunc(
+ ConstantExpr::getSub(ConstantExpr::getPtrToInt(Ptr, Int64Ty),
+ ConstantExpr::getPtrToInt(Note, Int64Ty)),
+ Int32Ty);
+ };
+ Note->setInitializer(ConstantStruct::getAnon(
+ {ConstantInt::get(Int32Ty, 8), // n_namesz
+ ConstantInt::get(Int32Ty, 8), // n_descsz
+ ConstantInt::get(Int32Ty, ELF::NT_LLVM_HWASAN_GLOBALS), // n_type
+ Name, CreateRelPtr(Start), CreateRelPtr(Stop)}));
+ appendToCompilerUsed(M, Note);
+
+ // Create a zero-length global in hwasan_globals so that the linker will
+ // always create start and stop symbols.
+ auto Dummy = new GlobalVariable(
+ M, Int8Arr0Ty, /*isConstantGlobal*/ true, GlobalVariable::PrivateLinkage,
+ Constant::getNullValue(Int8Arr0Ty), "hwasan.dummy.global");
+ Dummy->setSection("hwasan_globals");
+ Dummy->setComdat(NoteComdat);
+ Dummy->setMetadata(LLVMContext::MD_associated,
+ MDNode::get(*C, ValueAsMetadata::get(Note)));
+ appendToCompilerUsed(M, Dummy);
+
+ std::vector<GlobalVariable *> Globals;
+ for (GlobalVariable &GV : M.globals()) {
+ if (GV.isDeclarationForLinker() || GV.getName().startswith("llvm.") ||
+ GV.isThreadLocal())
+ continue;
+
+ // Common symbols can't have aliases point to them, so they can't be tagged.
+ if (GV.hasCommonLinkage())
+ continue;
+
+ // Globals with custom sections may be used in __start_/__stop_ enumeration,
+ // which would be broken both by adding tags and potentially by the extra
+ // padding/alignment that we insert.
+ if (GV.hasSection())
+ continue;
+
+ Globals.push_back(&GV);
+ }
+
+ MD5 Hasher;
+ Hasher.update(M.getSourceFileName());
+ MD5::MD5Result Hash;
+ Hasher.final(Hash);
+ uint8_t Tag = Hash[0];
+
+ for (GlobalVariable *GV : Globals) {
+ // Skip tag 0 in order to avoid collisions with untagged memory.
+ if (Tag == 0)
+ Tag = 1;
+ instrumentGlobal(GV, Tag++);
+ }
+}
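The note assembled in instrumentGlobals above is a standard ELF note: a 12-byte header (namesz = 8, descsz = 8, type = NT_LLVM_HWASAN_GLOBALS), an 8-byte name ("LLVM" padded with NULs), and an 8-byte descriptor holding two 32-bit offsets to __start_hwasan_globals and __stop_hwasan_globals, relative to the note itself so that it can live in rodata. A rough loader-side view of that layout; the struct and field names are mine, not the runtime's:

#include <cstdint>

struct HwasanGlobalsNote {
  uint32_t NameSize; // 8
  uint32_t DescSize; // 8
  uint32_t Type;     // ELF::NT_LLVM_HWASAN_GLOBALS
  char Name[8];      // "LLVM" followed by NUL padding
  int32_t BeginRel;  // __start_hwasan_globals minus the note's address
  int32_t EndRel;    // __stop_hwasan_globals minus the note's address
};

// Recover the [begin, end) range of hwasan_globals descriptors from the note.
static void descriptorRange(const HwasanGlobalsNote &N, uintptr_t &Begin,
                            uintptr_t &End) {
  uintptr_t Base = reinterpret_cast<uintptr_t>(&N);
  Begin = Base + N.BeginRel;
  End = Base + N.EndRel;
}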
+
+void HWAddressSanitizer::instrumentPersonalityFunctions() {
+ // We need to untag stack frames as we unwind past them. That is the job of
+ // the personality function wrapper, which either wraps an existing
+ // personality function or acts as a personality function on its own. Each
+ // function that has a personality function or that can be unwound past has
+ // its personality function changed to a thunk that calls the personality
+ // function wrapper in the runtime.
+ MapVector<Constant *, std::vector<Function *>> PersonalityFns;
+ for (Function &F : M) {
+ if (F.isDeclaration() || !F.hasFnAttribute(Attribute::SanitizeHWAddress))
+ continue;
+
+ if (F.hasPersonalityFn()) {
+ PersonalityFns[F.getPersonalityFn()->stripPointerCasts()].push_back(&F);
+ } else if (!F.hasFnAttribute(Attribute::NoUnwind)) {
+ PersonalityFns[nullptr].push_back(&F);
+ }
+ }
+
+ if (PersonalityFns.empty())
+ return;
+
+ FunctionCallee HwasanPersonalityWrapper = M.getOrInsertFunction(
+ "__hwasan_personality_wrapper", Int32Ty, Int32Ty, Int32Ty, Int64Ty,
+ Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy);
+ FunctionCallee UnwindGetGR = M.getOrInsertFunction("_Unwind_GetGR", VoidTy);
+ FunctionCallee UnwindGetCFA = M.getOrInsertFunction("_Unwind_GetCFA", VoidTy);
+
+ for (auto &P : PersonalityFns) {
+ std::string ThunkName = kHwasanPersonalityThunkName;
+ if (P.first)
+ ThunkName += ("." + P.first->getName()).str();
+ FunctionType *ThunkFnTy = FunctionType::get(
+ Int32Ty, {Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int8PtrTy}, false);
+ bool IsLocal = P.first && (!isa<GlobalValue>(P.first) ||
+ cast<GlobalValue>(P.first)->hasLocalLinkage());
+ auto *ThunkFn = Function::Create(ThunkFnTy,
+ IsLocal ? GlobalValue::InternalLinkage
+ : GlobalValue::LinkOnceODRLinkage,
+ ThunkName, &M);
+ if (!IsLocal) {
+ ThunkFn->setVisibility(GlobalValue::HiddenVisibility);
+ ThunkFn->setComdat(M.getOrInsertComdat(ThunkName));
+ }
+
+ auto *BB = BasicBlock::Create(*C, "entry", ThunkFn);
+ IRBuilder<> IRB(BB);
+ CallInst *WrapperCall = IRB.CreateCall(
+ HwasanPersonalityWrapper,
+ {ThunkFn->getArg(0), ThunkFn->getArg(1), ThunkFn->getArg(2),
+ ThunkFn->getArg(3), ThunkFn->getArg(4),
+ P.first ? IRB.CreateBitCast(P.first, Int8PtrTy)
+ : Constant::getNullValue(Int8PtrTy),
+ IRB.CreateBitCast(UnwindGetGR.getCallee(), Int8PtrTy),
+ IRB.CreateBitCast(UnwindGetCFA.getCallee(), Int8PtrTy)});
+ WrapperCall->setTailCall();
+ IRB.CreateRet(WrapperCall);
+
+ for (Function *F : P.second)
+ F->setPersonalityFn(ThunkFn);
+ }
+}
+
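In C terms, each generated __hwasan_personality_thunk.* simply forwards the standard five personality-routine arguments to __hwasan_personality_wrapper, together with the original personality (or null) and the addresses of _Unwind_GetGR and _Unwind_GetCFA. A sketch of that forwarding, where original_personality is a hypothetical stand-in for P.first; the wrapper's parameter list mirrors the getOrInsertFunction call above:

#include <cstdint>

extern "C" int __hwasan_personality_wrapper(int version, int actions,
                                            uint64_t exception_class,
                                            void *exception_object,
                                            void *context,
                                            void *real_personality,
                                            void *get_gr, void *get_cfa);
// Declared only so that their addresses can be taken, matching the bare
// void() declarations emitted above.
extern "C" void _Unwind_GetGR();
extern "C" void _Unwind_GetCFA();
// Hypothetical stand-in for the wrapped personality function (P.first).
extern "C" int original_personality();

extern "C" int hwasan_personality_thunk(int version, int actions,
                                        uint64_t exception_class,
                                        void *exception_object,
                                        void *context) {
  // Tail-call the runtime wrapper, which untags stack frames as the unwinder
  // passes them before deferring to the real personality (if any).
  return __hwasan_personality_wrapper(
      version, actions, exception_class, exception_object, context,
      reinterpret_cast<void *>(&original_personality),
      reinterpret_cast<void *>(&_Unwind_GetGR),
      reinterpret_cast<void *>(&_Unwind_GetCFA));
}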
void HWAddressSanitizer::ShadowMapping::init(Triple &TargetTriple) {
Scale = kDefaultShadowScale;
if (ClMappingOffset.getNumOccurrences() > 0) {
diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index c7371f567ff3..74d6e76eceb6 100644
--- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -403,7 +403,7 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI,
AM->getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
} else {
- OwnedORE = llvm::make_unique<OptimizationRemarkEmitter>(&F);
+ OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F);
ORE = OwnedORE.get();
}
diff --git a/lib/Transforms/Instrumentation/InstrOrderFile.cpp b/lib/Transforms/Instrumentation/InstrOrderFile.cpp
index a2c1ddfd279e..93d3a8a14d5c 100644
--- a/lib/Transforms/Instrumentation/InstrOrderFile.cpp
+++ b/lib/Transforms/Instrumentation/InstrOrderFile.cpp
@@ -100,7 +100,8 @@ public:
if (!ClOrderFileWriteMapping.empty()) {
std::lock_guard<std::mutex> LogLock(MappingMutex);
std::error_code EC;
- llvm::raw_fd_ostream OS(ClOrderFileWriteMapping, EC, llvm::sys::fs::F_Append);
+ llvm::raw_fd_ostream OS(ClOrderFileWriteMapping, EC,
+ llvm::sys::fs::OF_Append);
if (EC) {
report_fatal_error(Twine("Failed to open ") + ClOrderFileWriteMapping +
" to save mapping file for order file instrumentation\n");
diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp
index 63c2b8078967..1f092a5f3103 100644
--- a/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -157,7 +157,10 @@ public:
}
bool runOnModule(Module &M) override {
- return InstrProf.run(M, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI());
+ auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
+ return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ };
+ return InstrProf.run(M, GetTLI);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -370,8 +373,12 @@ private:
} // end anonymous namespace
PreservedAnalyses InstrProfiling::run(Module &M, ModuleAnalysisManager &AM) {
- auto &TLI = AM.getResult<TargetLibraryAnalysis>(M);
- if (!run(M, TLI))
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ };
+ if (!run(M, GetTLI))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
@@ -441,7 +448,7 @@ void InstrProfiling::promoteCounterLoadStores(Function *F) {
std::unique_ptr<BlockFrequencyInfo> BFI;
if (Options.UseBFIInPromotion) {
std::unique_ptr<BranchProbabilityInfo> BPI;
- BPI.reset(new BranchProbabilityInfo(*F, LI, TLI));
+ BPI.reset(new BranchProbabilityInfo(*F, LI, &GetTLI(*F)));
BFI.reset(new BlockFrequencyInfo(*F, *BPI, LI));
}
@@ -482,9 +489,10 @@ static bool containsProfilingIntrinsics(Module &M) {
return false;
}
-bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) {
+bool InstrProfiling::run(
+ Module &M, std::function<const TargetLibraryInfo &(Function &F)> GetTLI) {
this->M = &M;
- this->TLI = &TLI;
+ this->GetTLI = std::move(GetTLI);
NamesVar = nullptr;
NamesSize = 0;
ProfileDataMap.clear();
@@ -601,6 +609,7 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) {
bool IsRange = (Ind->getValueKind()->getZExtValue() ==
llvm::InstrProfValueKind::IPVK_MemOPSize);
CallInst *Call = nullptr;
+ auto *TLI = &GetTLI(*Ind->getFunction());
if (!IsRange) {
Value *Args[3] = {Ind->getTargetValue(),
Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()),
@@ -731,9 +740,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
PD = It->second;
}
- // Match the linkage and visibility of the name global, except on COFF, where
- // the linkage must be local and consequentially the visibility must be
- // default.
+ // Match the linkage and visibility of the name global. COFF supports using
+ // comdats with internal symbols, so do that if we can.
Function *Fn = Inc->getParent()->getParent();
GlobalValue::LinkageTypes Linkage = NamePtr->getLinkage();
GlobalValue::VisibilityTypes Visibility = NamePtr->getVisibility();
@@ -749,19 +757,21 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
// new comdat group for the counters and profiling data. If we use the comdat
// of the parent function, that will result in relocations against discarded
// sections.
- Comdat *Cmdt = nullptr;
- GlobalValue::LinkageTypes CounterLinkage = Linkage;
- if (needsComdatForCounter(*Fn, *M)) {
- StringRef CmdtPrefix = getInstrProfComdatPrefix();
+ bool NeedComdat = needsComdatForCounter(*Fn, *M);
+ if (NeedComdat) {
if (TT.isOSBinFormatCOFF()) {
- // For COFF, the comdat group name must be the name of a symbol in the
- // group. Use the counter variable name, and upgrade its linkage to
- // something externally visible, like linkonce_odr.
- CmdtPrefix = getInstrProfCountersVarPrefix();
- CounterLinkage = GlobalValue::LinkOnceODRLinkage;
+ // For COFF, put the counters, data, and values each into their own
+ // comdats. We can't use a group because the Visual C++ linker will
+ // report duplicate symbol errors if there are multiple external symbols
+ // with the same name marked IMAGE_COMDAT_SELECT_ASSOCIATIVE.
+ Linkage = GlobalValue::LinkOnceODRLinkage;
+ Visibility = GlobalValue::HiddenVisibility;
}
- Cmdt = M->getOrInsertComdat(getVarName(Inc, CmdtPrefix));
}
+ auto MaybeSetComdat = [=](GlobalVariable *GV) {
+ if (NeedComdat)
+ GV->setComdat(M->getOrInsertComdat(GV->getName()));
+ };
uint64_t NumCounters = Inc->getNumCounters()->getZExtValue();
LLVMContext &Ctx = M->getContext();
@@ -775,9 +785,9 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
CounterPtr->setVisibility(Visibility);
CounterPtr->setSection(
getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat()));
- CounterPtr->setAlignment(8);
- CounterPtr->setComdat(Cmdt);
- CounterPtr->setLinkage(CounterLinkage);
+ CounterPtr->setAlignment(Align(8));
+ MaybeSetComdat(CounterPtr);
+ CounterPtr->setLinkage(Linkage);
auto *Int8PtrTy = Type::getInt8PtrTy(Ctx);
// Allocate statically the array of pointers to value profile nodes for
@@ -797,8 +807,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
ValuesVar->setVisibility(Visibility);
ValuesVar->setSection(
getInstrProfSectionName(IPSK_vals, TT.getObjectFormat()));
- ValuesVar->setAlignment(8);
- ValuesVar->setComdat(Cmdt);
+ ValuesVar->setAlignment(Align(8));
+ MaybeSetComdat(ValuesVar);
ValuesPtrExpr =
ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx));
}
@@ -830,8 +840,9 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) {
getVarName(Inc, getInstrProfDataVarPrefix()));
Data->setVisibility(Visibility);
Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat()));
- Data->setAlignment(INSTR_PROF_DATA_ALIGNMENT);
- Data->setComdat(Cmdt);
+ Data->setAlignment(Align(INSTR_PROF_DATA_ALIGNMENT));
+ MaybeSetComdat(Data);
+ Data->setLinkage(Linkage);
PD.RegionCounters = CounterPtr;
PD.DataVar = Data;
@@ -920,7 +931,7 @@ void InstrProfiling::emitNameData() {
// On COFF, it's important to reduce the alignment down to 1 to prevent the
// linker from inserting padding before the start of the names section or
// between names entries.
- NamesVar->setAlignment(1);
+ NamesVar->setAlignment(Align::None());
UsedVars.push_back(NamesVar);
for (auto *NamePtr : ReferencedNames)
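
A minimal standalone sketch of the GetTLI change in the InstrProfiling hunks above: rather than caching one TargetLibraryInfo for the whole module, the lowering now holds a per-function getter and queries it only at the point where a function is lowered. The Function and TargetLibraryInfo types below are toy stand-ins, not the LLVM classes.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Function { std::string Name; };
struct TargetLibraryInfo { bool HasMemmove = true; };  // toy analysis result

class ProfLowering {
  std::function<const TargetLibraryInfo &(Function &)> GetTLI;

public:
  explicit ProfLowering(std::function<const TargetLibraryInfo &(Function &)> GetTLI)
      : GetTLI(std::move(GetTLI)) {}

  void lowerValueProfileInst(Function &F) {
    // Ask for the analysis lazily, for exactly the function being lowered.
    const TargetLibraryInfo &TLI = GetTLI(F);
    std::cout << F.Name << ": memmove available = " << TLI.HasMemmove << "\n";
  }
};

int main() {
  TargetLibraryInfo TLI;
  ProfLowering Lowering([&](Function &) -> const TargetLibraryInfo & { return TLI; });
  std::vector<Function> Funcs = {{"foo"}, {"bar"}};
  for (Function &F : Funcs)
    Lowering.lowerValueProfileInst(F);
}
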
diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp
index f56a1bd91b89..a6c2c9b464b6 100644
--- a/lib/Transforms/Instrumentation/Instrumentation.cpp
+++ b/lib/Transforms/Instrumentation/Instrumentation.cpp
@@ -68,7 +68,8 @@ GlobalVariable *llvm::createPrivateGlobalForString(Module &M, StringRef Str,
GlobalValue::PrivateLinkage, StrConst, NamePrefix);
if (AllowMerging)
GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
- GV->setAlignment(1); // Strings may not be merged w/o setting align 1.
+ GV->setAlignment(Align::None()); // Strings may not be merged w/o setting
+ // alignment explicitly.
return GV;
}
@@ -116,7 +117,7 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) {
initializeMemorySanitizerLegacyPassPass(Registry);
initializeHWAddressSanitizerLegacyPassPass(Registry);
initializeThreadSanitizerLegacyPassPass(Registry);
- initializeSanitizerCoverageModulePass(Registry);
+ initializeModuleSanitizerCoverageLegacyPassPass(Registry);
initializeDataFlowSanitizerPass(Registry);
}
diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index b25cbed1bb02..69c9020e060b 100644
--- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -462,16 +462,9 @@ namespace {
/// the module.
class MemorySanitizer {
public:
- MemorySanitizer(Module &M, MemorySanitizerOptions Options) {
- this->CompileKernel =
- ClEnableKmsan.getNumOccurrences() > 0 ? ClEnableKmsan : Options.Kernel;
- if (ClTrackOrigins.getNumOccurrences() > 0)
- this->TrackOrigins = ClTrackOrigins;
- else
- this->TrackOrigins = this->CompileKernel ? 2 : Options.TrackOrigins;
- this->Recover = ClKeepGoing.getNumOccurrences() > 0
- ? ClKeepGoing
- : (this->CompileKernel | Options.Recover);
+ MemorySanitizer(Module &M, MemorySanitizerOptions Options)
+ : CompileKernel(Options.Kernel), TrackOrigins(Options.TrackOrigins),
+ Recover(Options.Recover) {
initializeModule(M);
}
@@ -594,10 +587,26 @@ private:
/// An empty volatile inline asm that prevents callback merge.
InlineAsm *EmptyAsm;
-
- Function *MsanCtorFunction;
};
+void insertModuleCtor(Module &M) {
+ getOrCreateSanitizerCtorAndInitFunctions(
+ M, kMsanModuleCtorName, kMsanInitName,
+ /*InitArgTypes=*/{},
+ /*InitArgs=*/{},
+ // This callback is invoked when the functions are created the first
+ // time. Hook them into the global ctors list in that case:
+ [&](Function *Ctor, FunctionCallee) {
+ if (!ClWithComdat) {
+ appendToGlobalCtors(M, Ctor, 0);
+ return;
+ }
+ Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName);
+ Ctor->setComdat(MsanCtorComdat);
+ appendToGlobalCtors(M, Ctor, 0, Ctor);
+ });
+}
+
/// A legacy function pass for msan instrumentation.
///
/// Instruments functions to detect uninitialized reads.
@@ -615,7 +624,7 @@ struct MemorySanitizerLegacyPass : public FunctionPass {
bool runOnFunction(Function &F) override {
return MSan->sanitizeFunction(
- F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI());
+ F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F));
}
bool doInitialization(Module &M) override;
@@ -623,8 +632,17 @@ struct MemorySanitizerLegacyPass : public FunctionPass {
MemorySanitizerOptions Options;
};
+template <class T> T getOptOrDefault(const cl::opt<T> &Opt, T Default) {
+ return (Opt.getNumOccurrences() > 0) ? Opt : Default;
+}
+
} // end anonymous namespace
+MemorySanitizerOptions::MemorySanitizerOptions(int TO, bool R, bool K)
+ : Kernel(getOptOrDefault(ClEnableKmsan, K)),
+ TrackOrigins(getOptOrDefault(ClTrackOrigins, Kernel ? 2 : TO)),
+ Recover(getOptOrDefault(ClKeepGoing, Kernel || R)) {}
+
PreservedAnalyses MemorySanitizerPass::run(Function &F,
FunctionAnalysisManager &FAM) {
MemorySanitizer Msan(*F.getParent(), Options);
@@ -633,6 +651,14 @@ PreservedAnalyses MemorySanitizerPass::run(Function &F,
return PreservedAnalyses::all();
}
+PreservedAnalyses MemorySanitizerPass::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ if (Options.Kernel)
+ return PreservedAnalyses::all();
+ insertModuleCtor(M);
+ return PreservedAnalyses::none();
+}
+
char MemorySanitizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(MemorySanitizerLegacyPass, "msan",
@@ -918,23 +944,6 @@ void MemorySanitizer::initializeModule(Module &M) {
OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000);
if (!CompileKernel) {
- std::tie(MsanCtorFunction, std::ignore) =
- getOrCreateSanitizerCtorAndInitFunctions(
- M, kMsanModuleCtorName, kMsanInitName,
- /*InitArgTypes=*/{},
- /*InitArgs=*/{},
- // This callback is invoked when the functions are created the first
- // time. Hook them into the global ctors list in that case:
- [&](Function *Ctor, FunctionCallee) {
- if (!ClWithComdat) {
- appendToGlobalCtors(M, Ctor, 0);
- return;
- }
- Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName);
- Ctor->setComdat(MsanCtorComdat);
- appendToGlobalCtors(M, Ctor, 0, Ctor);
- });
-
if (TrackOrigins)
M.getOrInsertGlobal("__msan_track_origins", IRB.getInt32Ty(), [&] {
return new GlobalVariable(
@@ -952,6 +961,8 @@ void MemorySanitizer::initializeModule(Module &M) {
}
bool MemorySanitizerLegacyPass::doInitialization(Module &M) {
+ if (!Options.Kernel)
+ insertModuleCtor(M);
MSan.emplace(M, Options);
return true;
}
@@ -2562,6 +2573,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
return false;
}
+ void handleInvariantGroup(IntrinsicInst &I) {
+ setShadow(&I, getShadow(&I, 0));
+ setOrigin(&I, getOrigin(&I, 0));
+ }
+
void handleLifetimeStart(IntrinsicInst &I) {
if (!PoisonStack)
return;
@@ -2993,6 +3009,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
case Intrinsic::lifetime_start:
handleLifetimeStart(I);
break;
+ case Intrinsic::launder_invariant_group:
+ case Intrinsic::strip_invariant_group:
+ handleInvariantGroup(I);
+ break;
case Intrinsic::bswap:
handleBswap(I);
break;
@@ -3627,10 +3647,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
int getNumOutputArgs(InlineAsm *IA, CallBase *CB) {
int NumRetOutputs = 0;
int NumOutputs = 0;
- Type *RetTy = dyn_cast<Value>(CB)->getType();
+ Type *RetTy = cast<Value>(CB)->getType();
if (!RetTy->isVoidTy()) {
// Register outputs are returned via the CallInst return value.
- StructType *ST = dyn_cast_or_null<StructType>(RetTy);
+ auto *ST = dyn_cast<StructType>(RetTy);
if (ST)
NumRetOutputs = ST->getNumElements();
else
@@ -3667,7 +3687,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// corresponding CallInst has nO+nI+1 operands (the last operand is the
// function to be called).
const DataLayout &DL = F.getParent()->getDataLayout();
- CallBase *CB = dyn_cast<CallBase>(&I);
+ CallBase *CB = cast<CallBase>(&I);
IRBuilder<> IRB(&I);
InlineAsm *IA = cast<InlineAsm>(CB->getCalledValue());
int OutputArgs = getNumOutputArgs(IA, CB);
@@ -4567,8 +4587,9 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan,
}
bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) {
- if (!CompileKernel && (&F == MsanCtorFunction))
+ if (!CompileKernel && F.getName() == kMsanModuleCtorName)
return false;
+
MemorySanitizerVisitor Visitor(F, *this, TLI);
// Clear out readonly/readnone attributes.
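
The new MemorySanitizerOptions constructor above centralizes one rule: an explicitly passed command-line flag overrides the value supplied programmatically, and kernel mode adjusts the defaults. A self-contained sketch of that rule follows; the Flag type is a toy stand-in for cl::opt, and the names are illustrative only.

#include <iostream>

template <class T> struct Flag {      // toy stand-in for llvm::cl::opt<T>
  T Value{};
  int NumOccurrences = 0;             // becomes > 0 once the user passes the flag
};

template <class T> T getOptOrDefault(const Flag<T> &Opt, T Default) {
  return Opt.NumOccurrences > 0 ? Opt.Value : Default;
}

struct MsanOptions {
  bool Kernel;
  int TrackOrigins;
  bool Recover;
  MsanOptions(const Flag<bool> &ClKernel, const Flag<int> &ClOrigins,
              const Flag<bool> &ClKeepGoing, int TO, bool R, bool K)
      : Kernel(getOptOrDefault(ClKernel, K)),
        TrackOrigins(getOptOrDefault(ClOrigins, Kernel ? 2 : TO)),
        Recover(getOptOrDefault(ClKeepGoing, Kernel || R)) {}
};

int main() {
  Flag<bool> ClKernel;                          // not passed: constructor args win
  Flag<int> ClOrigins{1, /*NumOccurrences=*/1}; // explicitly passed: flag wins
  Flag<bool> ClKeepGoing;
  MsanOptions O(ClKernel, ClOrigins, ClKeepGoing, /*TO=*/0, /*R=*/false, /*K=*/true);
  std::cout << O.Kernel << " " << O.TrackOrigins << " " << O.Recover << "\n"; // 1 1 1
}
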
diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 6fec3c9c79ee..ca1bb62389e9 100644
--- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -48,6 +48,7 @@
//===----------------------------------------------------------------------===//
#include "CFGMST.h"
+#include "ValueProfileCollector.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -61,7 +62,6 @@
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/CFG.h"
-#include "llvm/Analysis/IndirectCallVisitor.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -96,6 +96,7 @@
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/InstrProfReader.h"
#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/CRC.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DOTGraphTraits.h"
@@ -103,11 +104,11 @@
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GraphWriter.h"
-#include "llvm/Support/JamCRC.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/MisExpect.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
@@ -120,6 +121,7 @@
using namespace llvm;
using ProfileCount = Function::ProfileCount;
+using VPCandidateInfo = ValueProfileCollector::CandidateInfo;
#define DEBUG_TYPE "pgo-instrumentation"
@@ -286,6 +288,11 @@ static std::string getBranchCondString(Instruction *TI) {
return result;
}
+static const char *ValueProfKindDescr[] = {
+#define VALUE_PROF_KIND(Enumerator, Value, Descr) Descr,
+#include "llvm/ProfileData/InstrProfData.inc"
+};
+
namespace {
/// The select instruction visitor plays three roles specified
@@ -348,50 +355,6 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
unsigned getNumOfSelectInsts() const { return NSIs; }
};
-/// Instruction Visitor class to visit memory intrinsic calls.
-struct MemIntrinsicVisitor : public InstVisitor<MemIntrinsicVisitor> {
- Function &F;
- unsigned NMemIs = 0; // Number of memIntrinsics instrumented.
- VisitMode Mode = VM_counting; // Visiting mode.
- unsigned CurCtrId = 0; // Current counter index.
- unsigned TotalNumCtrs = 0; // Total number of counters
- GlobalVariable *FuncNameVar = nullptr;
- uint64_t FuncHash = 0;
- PGOUseFunc *UseFunc = nullptr;
- std::vector<Instruction *> Candidates;
-
- MemIntrinsicVisitor(Function &Func) : F(Func) {}
-
- void countMemIntrinsics(Function &Func) {
- NMemIs = 0;
- Mode = VM_counting;
- visit(Func);
- }
-
- void instrumentMemIntrinsics(Function &Func, unsigned TotalNC,
- GlobalVariable *FNV, uint64_t FHash) {
- Mode = VM_instrument;
- TotalNumCtrs = TotalNC;
- FuncHash = FHash;
- FuncNameVar = FNV;
- visit(Func);
- }
-
- std::vector<Instruction *> findMemIntrinsics(Function &Func) {
- Candidates.clear();
- Mode = VM_annotate;
- visit(Func);
- return Candidates;
- }
-
- // Visit the IR stream and annotate all mem intrinsic call instructions.
- void instrumentOneMemIntrinsic(MemIntrinsic &MI);
-
- // Visit \p MI instruction and perform tasks according to visit mode.
- void visitMemIntrinsic(MemIntrinsic &SI);
-
- unsigned getNumOfMemIntrinsics() const { return NMemIs; }
-};
class PGOInstrumentationGenLegacyPass : public ModulePass {
public:
@@ -563,13 +526,14 @@ private:
// A map that stores the Comdat group in function F.
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
+ ValueProfileCollector VPC;
+
void computeCFGHash();
void renameComdatFunction();
public:
- std::vector<std::vector<Instruction *>> ValueSites;
+ std::vector<std::vector<VPCandidateInfo>> ValueSites;
SelectInstVisitor SIVisitor;
- MemIntrinsicVisitor MIVisitor;
std::string FuncName;
GlobalVariable *FuncNameVar;
@@ -604,23 +568,21 @@ public:
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,
BlockFrequencyInfo *BFI = nullptr, bool IsCS = false)
- : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers),
- ValueSites(IPVK_Last + 1), SIVisitor(Func), MIVisitor(Func),
- MST(F, BPI, BFI) {
+ : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), VPC(Func),
+ ValueSites(IPVK_Last + 1), SIVisitor(Func), MST(F, BPI, BFI) {
// This should be done before CFG hash computation.
SIVisitor.countSelects(Func);
- MIVisitor.countMemIntrinsics(Func);
+ ValueSites[IPVK_MemOPSize] = VPC.get(IPVK_MemOPSize);
if (!IsCS) {
NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
- NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics();
+ NumOfPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size();
NumOfPGOBB += MST.BBInfos.size();
- ValueSites[IPVK_IndirectCallTarget] = findIndirectCalls(Func);
+ ValueSites[IPVK_IndirectCallTarget] = VPC.get(IPVK_IndirectCallTarget);
} else {
NumOfCSPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
- NumOfCSPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics();
+ NumOfCSPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size();
NumOfCSPGOBB += MST.BBInfos.size();
}
- ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func);
FuncName = getPGOFuncName(F);
computeCFGHash();
@@ -647,7 +609,7 @@ public:
// value of each BB in the CFG. The higher 32 bits record the number of edges.
template <class Edge, class BBInfo>
void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
- std::vector<char> Indexes;
+ std::vector<uint8_t> Indexes;
JamCRC JC;
for (auto &BB : F) {
const Instruction *TI = BB.getTerminator();
@@ -658,7 +620,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
continue;
uint32_t Index = BI->Index;
for (int J = 0; J < 4; J++)
- Indexes.push_back((char)(Index >> (J * 8)));
+ Indexes.push_back((uint8_t)(Index >> (J * 8)));
}
}
JC.update(Indexes);
@@ -874,28 +836,36 @@ static void instrumentOneFunc(
if (DisableValueProfiling)
return;
- unsigned NumIndirectCalls = 0;
- for (auto &I : FuncInfo.ValueSites[IPVK_IndirectCallTarget]) {
- CallSite CS(I);
- Value *Callee = CS.getCalledValue();
- LLVM_DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = "
- << NumIndirectCalls << "\n");
- IRBuilder<> Builder(I);
- assert(Builder.GetInsertPoint() != I->getParent()->end() &&
- "Cannot get the Instrumentation point");
- Builder.CreateCall(
- Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
- {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
- Builder.getInt64(FuncInfo.FunctionHash),
- Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()),
- Builder.getInt32(IPVK_IndirectCallTarget),
- Builder.getInt32(NumIndirectCalls++)});
- }
- NumOfPGOICall += NumIndirectCalls;
+ NumOfPGOICall += FuncInfo.ValueSites[IPVK_IndirectCallTarget].size();
- // Now instrument memop intrinsic calls.
- FuncInfo.MIVisitor.instrumentMemIntrinsics(
- F, NumCounters, FuncInfo.FuncNameVar, FuncInfo.FunctionHash);
+ // For each VP Kind, walk the VP candidates and instrument each one.
+ for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) {
+ unsigned SiteIndex = 0;
+ if (Kind == IPVK_MemOPSize && !PGOInstrMemOP)
+ continue;
+
+ for (VPCandidateInfo Cand : FuncInfo.ValueSites[Kind]) {
+ LLVM_DEBUG(dbgs() << "Instrument one VP " << ValueProfKindDescr[Kind]
+ << " site: CallSite Index = " << SiteIndex << "\n");
+
+ IRBuilder<> Builder(Cand.InsertPt);
+ assert(Builder.GetInsertPoint() != Cand.InsertPt->getParent()->end() &&
+ "Cannot get the Instrumentation point");
+
+ Value *ToProfile = nullptr;
+ if (Cand.V->getType()->isIntegerTy())
+ ToProfile = Builder.CreateZExtOrTrunc(Cand.V, Builder.getInt64Ty());
+ else if (Cand.V->getType()->isPointerTy())
+ ToProfile = Builder.CreatePtrToInt(Cand.V, Builder.getInt64Ty());
+ assert(ToProfile && "value profiling Value is of unexpected type");
+
+ Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
+ {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
+ Builder.getInt64(FuncInfo.FunctionHash), ToProfile,
+ Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)});
+ }
+ } // IPVK_First <= Kind <= IPVK_Last
}
namespace {
@@ -984,9 +954,9 @@ class PGOUseFunc {
public:
PGOUseFunc(Function &Func, Module *Modu,
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
- BranchProbabilityInfo *BPI = nullptr,
- BlockFrequencyInfo *BFIin = nullptr, bool IsCS = false)
- : F(Func), M(Modu), BFI(BFIin),
+ BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFIin,
+ ProfileSummaryInfo *PSI, bool IsCS)
+ : F(Func), M(Modu), BFI(BFIin), PSI(PSI),
FuncInfo(Func, ComdatMembers, false, BPI, BFIin, IsCS),
FreqAttr(FFA_Normal), IsCS(IsCS) {}
@@ -1041,6 +1011,7 @@ private:
Function &F;
Module *M;
BlockFrequencyInfo *BFI;
+ ProfileSummaryInfo *PSI;
// This member stores the shared information with class PGOGenFunc.
FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo;
@@ -1078,15 +1049,9 @@ private:
// FIXME: This function should be removed once the functionality in
// the inliner is implemented.
void markFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) {
- if (ProgramMaxCount == 0)
- return;
- // Threshold of the hot functions.
- const BranchProbability HotFunctionThreshold(1, 100);
- // Threshold of the cold functions.
- const BranchProbability ColdFunctionThreshold(2, 10000);
- if (EntryCount >= HotFunctionThreshold.scale(ProgramMaxCount))
+ if (PSI->isHotCount(EntryCount))
FreqAttr = FFA_Hot;
- else if (MaxCount <= ColdFunctionThreshold.scale(ProgramMaxCount))
+ else if (PSI->isColdCount(MaxCount))
FreqAttr = FFA_Cold;
}
};
@@ -1433,43 +1398,6 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
llvm_unreachable("Unknown visiting mode");
}
-void MemIntrinsicVisitor::instrumentOneMemIntrinsic(MemIntrinsic &MI) {
- Module *M = F.getParent();
- IRBuilder<> Builder(&MI);
- Type *Int64Ty = Builder.getInt64Ty();
- Type *I8PtrTy = Builder.getInt8PtrTy();
- Value *Length = MI.getLength();
- assert(!isa<ConstantInt>(Length));
- Builder.CreateCall(
- Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
- {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
- Builder.getInt64(FuncHash), Builder.CreateZExtOrTrunc(Length, Int64Ty),
- Builder.getInt32(IPVK_MemOPSize), Builder.getInt32(CurCtrId)});
- ++CurCtrId;
-}
-
-void MemIntrinsicVisitor::visitMemIntrinsic(MemIntrinsic &MI) {
- if (!PGOInstrMemOP)
- return;
- Value *Length = MI.getLength();
- // Not instrument constant length calls.
- if (dyn_cast<ConstantInt>(Length))
- return;
-
- switch (Mode) {
- case VM_counting:
- NMemIs++;
- return;
- case VM_instrument:
- instrumentOneMemIntrinsic(MI);
- return;
- case VM_annotate:
- Candidates.push_back(&MI);
- return;
- }
- llvm_unreachable("Unknown visiting mode");
-}
-
// Traverse all value sites and annotate the instructions for all value kinds.
void PGOUseFunc::annotateValueSites() {
if (DisableValueProfiling)
@@ -1482,11 +1410,6 @@ void PGOUseFunc::annotateValueSites() {
annotateValueSites(Kind);
}
-static const char *ValueProfKindDescr[] = {
-#define VALUE_PROF_KIND(Enumerator, Value, Descr) Descr,
-#include "llvm/ProfileData/InstrProfData.inc"
-};
-
// Annotate the instructions for a specific value kind.
void PGOUseFunc::annotateValueSites(uint32_t Kind) {
assert(Kind <= IPVK_Last);
@@ -1505,11 +1428,11 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) {
return;
}
- for (auto &I : ValueSites) {
+ for (VPCandidateInfo &I : ValueSites) {
LLVM_DEBUG(dbgs() << "Read one value site profile (kind = " << Kind
<< "): Index = " << ValueSiteIndex << " out of "
<< NumValueSites << "\n");
- annotateValueSite(*M, *I, ProfileRecord,
+ annotateValueSite(*M, *I.AnnotatedInst, ProfileRecord,
static_cast<InstrProfValueKind>(Kind), ValueSiteIndex,
Kind == IPVK_MemOPSize ? MaxNumMemOPAnnotations
: MaxNumAnnotations);
@@ -1595,7 +1518,8 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M,
static bool annotateAllFunctions(
Module &M, StringRef ProfileFileName, StringRef ProfileRemappingFileName,
function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
- function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
+ function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+ ProfileSummaryInfo *PSI, bool IsCS) {
LLVM_DEBUG(dbgs() << "Read in profile counters: ");
auto &Ctx = M.getContext();
// Read the counter array from file.
@@ -1626,6 +1550,13 @@ static bool annotateAllFunctions(
return false;
}
+ // Add the profile summary (read from the header of the indexed summary) here
+ // so that we can use it below when reading counters (which checks if the
+ // function should be marked with a cold or inlinehint attribute).
+ M.setProfileSummary(PGOReader->getSummary(IsCS).getMD(M.getContext()),
+ IsCS ? ProfileSummary::PSK_CSInstr
+ : ProfileSummary::PSK_Instr);
+
std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers;
collectComdatMembers(M, ComdatMembers);
std::vector<Function *> HotFunctions;
@@ -1638,7 +1569,7 @@ static bool annotateAllFunctions(
// Split indirectbr critical edges here before computing the MST rather than
// later in getInstrBB() to avoid invalidating it.
SplitIndirectBrCriticalEdges(F, BPI, BFI);
- PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI, IsCS);
+ PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI, PSI, IsCS);
bool AllZeros = false;
if (!Func.readCounters(PGOReader.get(), AllZeros))
continue;
@@ -1662,9 +1593,9 @@ static bool annotateAllFunctions(
F.getName().equals(ViewBlockFreqFuncName))) {
LoopInfo LI{DominatorTree(F)};
std::unique_ptr<BranchProbabilityInfo> NewBPI =
- llvm::make_unique<BranchProbabilityInfo>(F, LI);
+ std::make_unique<BranchProbabilityInfo>(F, LI);
std::unique_ptr<BlockFrequencyInfo> NewBFI =
- llvm::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI);
+ std::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI);
if (PGOViewCounts == PGOVCT_Graph)
NewBFI->view();
else if (PGOViewCounts == PGOVCT_Text) {
@@ -1686,9 +1617,6 @@ static bool annotateAllFunctions(
}
}
}
- M.setProfileSummary(PGOReader->getSummary(IsCS).getMD(M.getContext()),
- IsCS ? ProfileSummary::PSK_CSInstr
- : ProfileSummary::PSK_Instr);
// Set function hotness attribute from the profile.
// We have to apply these attributes at the end because their presence
@@ -1730,8 +1658,10 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M,
return &FAM.getResult<BlockFrequencyAnalysis>(F);
};
+ auto *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
+
if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName,
- LookupBPI, LookupBFI, IsCS))
+ LookupBPI, LookupBFI, PSI, IsCS))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
@@ -1748,7 +1678,8 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) {
return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
};
- return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI,
+ auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
+ return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI, PSI,
IsCS);
}
@@ -1776,6 +1707,9 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
: Weights) {
dbgs() << W << " ";
} dbgs() << "\n";);
+
+ misexpect::verifyMisExpect(TI, Weights, TI->getContext());
+
TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
if (EmitBranchProbability) {
std::string BrCondStr = getBranchCondString(TI);
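
The instrumentOneFunc change above replaces the separate indirect-call and mem-intrinsic paths with one loop over value-profile kinds and their candidate lists. A standalone sketch of that shape, with toy types in place of the LLVM IR classes and a print statement in place of the llvm.instrprof.value.profile call:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

enum VPKind { IPVK_IndirectCallTarget, IPVK_MemOPSize, IPVK_Last = IPVK_MemOPSize };
static const char *KindDescr[] = {"indirect_call_target", "mem_op_size"};

struct CandidateInfo {
  std::string Value;     // what to profile (callee pointer, memop length, ...)
  std::string InsertPt;  // where the profiling call would be inserted
};

void instrumentValueSites(const std::vector<std::vector<CandidateInfo>> &ValueSites) {
  for (uint32_t Kind = IPVK_IndirectCallTarget; Kind <= IPVK_Last; ++Kind) {
    unsigned SiteIndex = 0;
    for (const CandidateInfo &Cand : ValueSites[Kind])
      // The real pass emits a call carrying (name, hash, value, kind, site index).
      std::cout << "profile " << KindDescr[Kind] << " site " << SiteIndex++
                << " value=" << Cand.Value << " before '" << Cand.InsertPt << "'\n";
  }
}

int main() {
  std::vector<std::vector<CandidateInfo>> Sites(IPVK_Last + 1);
  Sites[IPVK_IndirectCallTarget] = {{"callee_fn", "call %fp()"}};
  Sites[IPVK_MemOPSize] = {{"len", "call @llvm.memcpy"}};
  instrumentValueSites(Sites);
}
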
diff --git a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
index 188f95b4676b..9f81bb16d0a7 100644
--- a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
+++ b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
@@ -138,7 +138,7 @@ public:
OptimizationRemarkEmitter &ORE, DominatorTree *DT)
: Func(Func), BFI(BFI), ORE(ORE), DT(DT), Changed(false) {
ValueDataArray =
- llvm::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2);
+ std::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2);
// Get the MemOPSize range information from option MemOPSizeRange,
getMemOPSizeRangeFromOption(MemOPSizeRange, PreciseRangeStart,
PreciseRangeLast);
@@ -374,8 +374,8 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) {
Ctx, Twine("MemOP.Case.") + Twine(SizeId), &Func, DefaultBB);
Instruction *NewInst = MI->clone();
// Fix the argument.
- MemIntrinsic * MemI = dyn_cast<MemIntrinsic>(NewInst);
- IntegerType *SizeType = dyn_cast<IntegerType>(MemI->getLength()->getType());
+ auto *MemI = cast<MemIntrinsic>(NewInst);
+ auto *SizeType = dyn_cast<IntegerType>(MemI->getLength()->getType());
assert(SizeType && "Expected integer type size argument.");
ConstantInt *CaseSizeId = ConstantInt::get(SizeType, SizeId);
MemI->setLength(CaseSizeId);
diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
index ca0cb4bdbe84..f8fa9cad03b8 100644
--- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
+++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Transforms/Instrumentation/SanitizerCoverage.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/EHPersonalities.h"
@@ -176,24 +177,21 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) {
return Options;
}
-class SanitizerCoverageModule : public ModulePass {
+using DomTreeCallback = function_ref<const DominatorTree *(Function &F)>;
+using PostDomTreeCallback =
+ function_ref<const PostDominatorTree *(Function &F)>;
+
+class ModuleSanitizerCoverage {
public:
- SanitizerCoverageModule(
+ ModuleSanitizerCoverage(
const SanitizerCoverageOptions &Options = SanitizerCoverageOptions())
- : ModulePass(ID), Options(OverrideFromCL(Options)) {
- initializeSanitizerCoverageModulePass(*PassRegistry::getPassRegistry());
- }
- bool runOnModule(Module &M) override;
- bool runOnFunction(Function &F);
- static char ID; // Pass identification, replacement for typeid
- StringRef getPassName() const override { return "SanitizerCoverageModule"; }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<PostDominatorTreeWrapperPass>();
- }
+ : Options(OverrideFromCL(Options)) {}
+ bool instrumentModule(Module &M, DomTreeCallback DTCallback,
+ PostDomTreeCallback PDTCallback);
private:
+ void instrumentFunction(Function &F, DomTreeCallback DTCallback,
+ PostDomTreeCallback PDTCallback);
void InjectCoverageForIndirectCalls(Function &F,
ArrayRef<Instruction *> IndirCalls);
void InjectTraceForCmp(Function &F, ArrayRef<Instruction *> CmpTraceTargets);
@@ -252,10 +250,57 @@ private:
SanitizerCoverageOptions Options;
};
+class ModuleSanitizerCoverageLegacyPass : public ModulePass {
+public:
+ ModuleSanitizerCoverageLegacyPass(
+ const SanitizerCoverageOptions &Options = SanitizerCoverageOptions())
+ : ModulePass(ID), Options(Options) {
+ initializeModuleSanitizerCoverageLegacyPassPass(
+ *PassRegistry::getPassRegistry());
+ }
+ bool runOnModule(Module &M) override {
+ ModuleSanitizerCoverage ModuleSancov(Options);
+ auto DTCallback = [this](Function &F) -> const DominatorTree * {
+ return &this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
+ };
+ auto PDTCallback = [this](Function &F) -> const PostDominatorTree * {
+ return &this->getAnalysis<PostDominatorTreeWrapperPass>(F)
+ .getPostDomTree();
+ };
+ return ModuleSancov.instrumentModule(M, DTCallback, PDTCallback);
+ }
+
+ static char ID; // Pass identification, replacement for typeid
+ StringRef getPassName() const override { return "ModuleSanitizerCoverage"; }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<PostDominatorTreeWrapperPass>();
+ }
+
+private:
+ SanitizerCoverageOptions Options;
+};
+
} // namespace
+PreservedAnalyses ModuleSanitizerCoveragePass::run(Module &M,
+ ModuleAnalysisManager &MAM) {
+ ModuleSanitizerCoverage ModuleSancov(Options);
+ auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto DTCallback = [&FAM](Function &F) -> const DominatorTree * {
+ return &FAM.getResult<DominatorTreeAnalysis>(F);
+ };
+ auto PDTCallback = [&FAM](Function &F) -> const PostDominatorTree * {
+ return &FAM.getResult<PostDominatorTreeAnalysis>(F);
+ };
+ if (ModuleSancov.instrumentModule(M, DTCallback, PDTCallback))
+ return PreservedAnalyses::none();
+ return PreservedAnalyses::all();
+}
+
std::pair<Value *, Value *>
-SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section,
+ModuleSanitizerCoverage::CreateSecStartEnd(Module &M, const char *Section,
Type *Ty) {
GlobalVariable *SecStart =
new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, nullptr,
@@ -278,7 +323,7 @@ SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section,
return std::make_pair(IRB.CreatePointerCast(GEP, Ty), SecEndPtr);
}
-Function *SanitizerCoverageModule::CreateInitCallsForSections(
+Function *ModuleSanitizerCoverage::CreateInitCallsForSections(
Module &M, const char *CtorName, const char *InitFunctionName, Type *Ty,
const char *Section) {
auto SecStartEnd = CreateSecStartEnd(M, Section, Ty);
@@ -310,7 +355,8 @@ Function *SanitizerCoverageModule::CreateInitCallsForSections(
return CtorFunc;
}
-bool SanitizerCoverageModule::runOnModule(Module &M) {
+bool ModuleSanitizerCoverage::instrumentModule(
+ Module &M, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback) {
if (Options.CoverageType == SanitizerCoverageOptions::SCK_None)
return false;
C = &(M.getContext());
@@ -403,7 +449,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) {
M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, Int32PtrTy);
for (auto &F : M)
- runOnFunction(F);
+ instrumentFunction(F, DTCallback, PDTCallback);
Function *Ctor = nullptr;
@@ -518,29 +564,30 @@ static bool IsInterestingCmp(ICmpInst *CMP, const DominatorTree *DT,
return true;
}
-bool SanitizerCoverageModule::runOnFunction(Function &F) {
+void ModuleSanitizerCoverage::instrumentFunction(
+ Function &F, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback) {
if (F.empty())
- return false;
+ return;
if (F.getName().find(".module_ctor") != std::string::npos)
- return false; // Should not instrument sanitizer init functions.
+ return; // Should not instrument sanitizer init functions.
if (F.getName().startswith("__sanitizer_"))
- return false; // Don't instrument __sanitizer_* callbacks.
+ return; // Don't instrument __sanitizer_* callbacks.
// Don't touch available_externally functions; their actual body is elsewhere.
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage)
- return false;
+ return;
// Don't instrument MSVC CRT configuration helpers. They may run before normal
// initialization.
if (F.getName() == "__local_stdio_printf_options" ||
F.getName() == "__local_stdio_scanf_options")
- return false;
+ return;
if (isa<UnreachableInst>(F.getEntryBlock().getTerminator()))
- return false;
+ return;
// Don't instrument functions using SEH for now. Splitting basic blocks like
// we do for coverage breaks WinEHPrepare.
// FIXME: Remove this when SEH no longer uses landingpad pattern matching.
if (F.hasPersonalityFn() &&
isAsynchronousEHPersonality(classifyEHPersonality(F.getPersonalityFn())))
- return false;
+ return;
if (Options.CoverageType >= SanitizerCoverageOptions::SCK_Edge)
SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions().setIgnoreUnreachableDests());
SmallVector<Instruction *, 8> IndirCalls;
@@ -550,10 +597,8 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) {
SmallVector<BinaryOperator *, 8> DivTraceTargets;
SmallVector<GetElementPtrInst *, 8> GepTraceTargets;
- const DominatorTree *DT =
- &getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
- const PostDominatorTree *PDT =
- &getAnalysis<PostDominatorTreeWrapperPass>(F).getPostDomTree();
+ const DominatorTree *DT = DTCallback(F);
+ const PostDominatorTree *PDT = PDTCallback(F);
bool IsLeafFunc = true;
for (auto &BB : F) {
@@ -593,10 +638,9 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) {
InjectTraceForSwitch(F, SwitchTraceTargets);
InjectTraceForDiv(F, DivTraceTargets);
InjectTraceForGep(F, GepTraceTargets);
- return true;
}
-GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection(
+GlobalVariable *ModuleSanitizerCoverage::CreateFunctionLocalArrayInSection(
size_t NumElements, Function &F, Type *Ty, const char *Section) {
ArrayType *ArrayTy = ArrayType::get(Ty, NumElements);
auto Array = new GlobalVariable(
@@ -608,8 +652,9 @@ GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection(
GetOrCreateFunctionComdat(F, TargetTriple, CurModuleUniqueId))
Array->setComdat(Comdat);
Array->setSection(getSectionName(Section));
- Array->setAlignment(Ty->isPointerTy() ? DL->getPointerSize()
- : Ty->getPrimitiveSizeInBits() / 8);
+ Array->setAlignment(Align(Ty->isPointerTy()
+ ? DL->getPointerSize()
+ : Ty->getPrimitiveSizeInBits() / 8));
GlobalsToAppendToUsed.push_back(Array);
GlobalsToAppendToCompilerUsed.push_back(Array);
MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F));
@@ -619,7 +664,7 @@ GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection(
}
GlobalVariable *
-SanitizerCoverageModule::CreatePCArray(Function &F,
+ModuleSanitizerCoverage::CreatePCArray(Function &F,
ArrayRef<BasicBlock *> AllBlocks) {
size_t N = AllBlocks.size();
assert(N);
@@ -646,7 +691,7 @@ SanitizerCoverageModule::CreatePCArray(Function &F,
return PCArray;
}
-void SanitizerCoverageModule::CreateFunctionLocalArrays(
+void ModuleSanitizerCoverage::CreateFunctionLocalArrays(
Function &F, ArrayRef<BasicBlock *> AllBlocks) {
if (Options.TracePCGuard)
FunctionGuardArray = CreateFunctionLocalArrayInSection(
@@ -660,7 +705,7 @@ void SanitizerCoverageModule::CreateFunctionLocalArrays(
FunctionPCsArray = CreatePCArray(F, AllBlocks);
}
-bool SanitizerCoverageModule::InjectCoverage(Function &F,
+bool ModuleSanitizerCoverage::InjectCoverage(Function &F,
ArrayRef<BasicBlock *> AllBlocks,
bool IsLeafFunc) {
if (AllBlocks.empty()) return false;
@@ -677,7 +722,7 @@ bool SanitizerCoverageModule::InjectCoverage(Function &F,
// The cache is used to speed up recording the caller-callee pairs.
// The address of the caller is passed implicitly via caller PC.
// CacheSize is encoded in the name of the run-time function.
-void SanitizerCoverageModule::InjectCoverageForIndirectCalls(
+void ModuleSanitizerCoverage::InjectCoverageForIndirectCalls(
Function &F, ArrayRef<Instruction *> IndirCalls) {
if (IndirCalls.empty())
return;
@@ -696,7 +741,7 @@ void SanitizerCoverageModule::InjectCoverageForIndirectCalls(
// __sanitizer_cov_trace_switch(CondValue,
// {NumCases, ValueSizeInBits, Case0Value, Case1Value, Case2Value, ... })
-void SanitizerCoverageModule::InjectTraceForSwitch(
+void ModuleSanitizerCoverage::InjectTraceForSwitch(
Function &, ArrayRef<Instruction *> SwitchTraceTargets) {
for (auto I : SwitchTraceTargets) {
if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) {
@@ -735,7 +780,7 @@ void SanitizerCoverageModule::InjectTraceForSwitch(
}
}
-void SanitizerCoverageModule::InjectTraceForDiv(
+void ModuleSanitizerCoverage::InjectTraceForDiv(
Function &, ArrayRef<BinaryOperator *> DivTraceTargets) {
for (auto BO : DivTraceTargets) {
IRBuilder<> IRB(BO);
@@ -753,7 +798,7 @@ void SanitizerCoverageModule::InjectTraceForDiv(
}
}
-void SanitizerCoverageModule::InjectTraceForGep(
+void ModuleSanitizerCoverage::InjectTraceForGep(
Function &, ArrayRef<GetElementPtrInst *> GepTraceTargets) {
for (auto GEP : GepTraceTargets) {
IRBuilder<> IRB(GEP);
@@ -764,7 +809,7 @@ void SanitizerCoverageModule::InjectTraceForGep(
}
}
-void SanitizerCoverageModule::InjectTraceForCmp(
+void ModuleSanitizerCoverage::InjectTraceForCmp(
Function &, ArrayRef<Instruction *> CmpTraceTargets) {
for (auto I : CmpTraceTargets) {
if (ICmpInst *ICMP = dyn_cast<ICmpInst>(I)) {
@@ -799,7 +844,7 @@ void SanitizerCoverageModule::InjectTraceForCmp(
}
}
-void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
+void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
size_t Idx,
bool IsLeafFunc) {
BasicBlock::iterator IP = BB.getFirstInsertionPt();
@@ -842,8 +887,10 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
}
if (Options.StackDepth && IsEntryBB && !IsLeafFunc) {
// Check stack depth. If it's the deepest so far, record it.
- Function *GetFrameAddr =
- Intrinsic::getDeclaration(F.getParent(), Intrinsic::frameaddress);
+ Module *M = F.getParent();
+ Function *GetFrameAddr = Intrinsic::getDeclaration(
+ M, Intrinsic::frameaddress,
+ IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace()));
auto FrameAddrPtr =
IRB.CreateCall(GetFrameAddr, {Constant::getNullValue(Int32Ty)});
auto FrameAddrInt = IRB.CreatePtrToInt(FrameAddrPtr, IntptrTy);
@@ -858,7 +905,7 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
}
std::string
-SanitizerCoverageModule::getSectionName(const std::string &Section) const {
+ModuleSanitizerCoverage::getSectionName(const std::string &Section) const {
if (TargetTriple.isOSBinFormatCOFF()) {
if (Section == SanCovCountersSectionName)
return ".SCOV$CM";
@@ -872,32 +919,29 @@ SanitizerCoverageModule::getSectionName(const std::string &Section) const {
}
std::string
-SanitizerCoverageModule::getSectionStart(const std::string &Section) const {
+ModuleSanitizerCoverage::getSectionStart(const std::string &Section) const {
if (TargetTriple.isOSBinFormatMachO())
return "\1section$start$__DATA$__" + Section;
return "__start___" + Section;
}
std::string
-SanitizerCoverageModule::getSectionEnd(const std::string &Section) const {
+ModuleSanitizerCoverage::getSectionEnd(const std::string &Section) const {
if (TargetTriple.isOSBinFormatMachO())
return "\1section$end$__DATA$__" + Section;
return "__stop___" + Section;
}
-
-char SanitizerCoverageModule::ID = 0;
-INITIALIZE_PASS_BEGIN(SanitizerCoverageModule, "sancov",
- "SanitizerCoverage: TODO."
- "ModulePass",
- false, false)
+char ModuleSanitizerCoverageLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(ModuleSanitizerCoverageLegacyPass, "sancov",
+ "Pass for instrumenting coverage on functions", false,
+ false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_END(SanitizerCoverageModule, "sancov",
- "SanitizerCoverage: TODO."
- "ModulePass",
- false, false)
-ModulePass *llvm::createSanitizerCoverageModulePass(
+INITIALIZE_PASS_END(ModuleSanitizerCoverageLegacyPass, "sancov",
+ "Pass for instrumenting coverage on functions", false,
+ false)
+ModulePass *llvm::createModuleSanitizerCoverageLegacyPassPass(
const SanitizerCoverageOptions &Options) {
- return new SanitizerCoverageModule(Options);
+ return new ModuleSanitizerCoverageLegacyPass(Options);
}
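
The refactoring above separates the instrumentation logic (ModuleSanitizerCoverage) from the two pass-manager wrappers, with dominator information handed in through callbacks. A plain-C++ sketch of that callback seam, using toy types rather than the LLVM analyses:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Function { std::string Name; };
struct DominatorTree {};  // toy per-function analysis result

using DomTreeCallback = std::function<const DominatorTree *(Function &)>;

class Instrumenter {
public:
  bool instrumentModule(std::vector<Function> &M, DomTreeCallback DTCallback) {
    bool Changed = false;
    for (Function &F : M) {
      const DominatorTree *DT = DTCallback(F);  // supplied by whichever driver runs us
      if (DT) {
        std::cout << "instrumenting " << F.Name << "\n";
        Changed = true;
      }
    }
    return Changed;
  }
};

int main() {
  std::vector<Function> M = {{"foo"}, {"bar"}};
  DominatorTree DT;  // pretend this came from an analysis manager
  Instrumenter I;
  // A legacy-PM wrapper and a new-PM wrapper would each pass their own lambda here.
  I.instrumentModule(M, [&](Function &) { return &DT; });
}
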
diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
index 5be13fa745cb..ac274a155a80 100644
--- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
@@ -92,11 +92,10 @@ namespace {
/// ensures the __tsan_init function is in the list of global constructors for
/// the module.
struct ThreadSanitizer {
- ThreadSanitizer(Module &M);
bool sanitizeFunction(Function &F, const TargetLibraryInfo &TLI);
private:
- void initializeCallbacks(Module &M);
+ void initialize(Module &M);
bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
bool instrumentAtomic(Instruction *I, const DataLayout &DL);
bool instrumentMemIntrinsic(Instruction *I);
@@ -108,8 +107,6 @@ private:
void InsertRuntimeIgnores(Function &F);
Type *IntptrTy;
- IntegerType *OrdTy;
- // Callbacks to run-time library are computed in doInitialization.
FunctionCallee TsanFuncEntry;
FunctionCallee TsanFuncExit;
FunctionCallee TsanIgnoreBegin;
@@ -130,7 +127,6 @@ private:
FunctionCallee TsanVptrUpdate;
FunctionCallee TsanVptrLoad;
FunctionCallee MemmoveFn, MemcpyFn, MemsetFn;
- Function *TsanCtorFunction;
};
struct ThreadSanitizerLegacyPass : FunctionPass {
@@ -143,16 +139,32 @@ struct ThreadSanitizerLegacyPass : FunctionPass {
private:
Optional<ThreadSanitizer> TSan;
};
+
+void insertModuleCtor(Module &M) {
+ getOrCreateSanitizerCtorAndInitFunctions(
+ M, kTsanModuleCtorName, kTsanInitName, /*InitArgTypes=*/{},
+ /*InitArgs=*/{},
+ // This callback is invoked when the functions are created the first
+ // time. Hook them into the global ctors list in that case:
+ [&](Function *Ctor, FunctionCallee) { appendToGlobalCtors(M, Ctor, 0); });
+}
+
} // namespace
PreservedAnalyses ThreadSanitizerPass::run(Function &F,
FunctionAnalysisManager &FAM) {
- ThreadSanitizer TSan(*F.getParent());
+ ThreadSanitizer TSan;
if (TSan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F)))
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
+PreservedAnalyses ThreadSanitizerPass::run(Module &M,
+ ModuleAnalysisManager &MAM) {
+ insertModuleCtor(M);
+ return PreservedAnalyses::none();
+}
+
char ThreadSanitizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ThreadSanitizerLegacyPass, "tsan",
"ThreadSanitizer: detects data races.", false, false)
@@ -169,12 +181,13 @@ void ThreadSanitizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
}
bool ThreadSanitizerLegacyPass::doInitialization(Module &M) {
- TSan.emplace(M);
+ insertModuleCtor(M);
+ TSan.emplace();
return true;
}
bool ThreadSanitizerLegacyPass::runOnFunction(Function &F) {
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
TSan->sanitizeFunction(F, TLI);
return true;
}
@@ -183,7 +196,10 @@ FunctionPass *llvm::createThreadSanitizerLegacyPassPass() {
return new ThreadSanitizerLegacyPass();
}
-void ThreadSanitizer::initializeCallbacks(Module &M) {
+void ThreadSanitizer::initialize(Module &M) {
+ const DataLayout &DL = M.getDataLayout();
+ IntptrTy = DL.getIntPtrType(M.getContext());
+
IRBuilder<> IRB(M.getContext());
AttributeList Attr;
Attr = Attr.addAttribute(M.getContext(), AttributeList::FunctionIndex,
@@ -197,7 +213,7 @@ void ThreadSanitizer::initializeCallbacks(Module &M) {
IRB.getVoidTy());
TsanIgnoreEnd =
M.getOrInsertFunction("__tsan_ignore_thread_end", Attr, IRB.getVoidTy());
- OrdTy = IRB.getInt32Ty();
+ IntegerType *OrdTy = IRB.getInt32Ty();
for (size_t i = 0; i < kNumberOfAccessSizes; ++i) {
const unsigned ByteSize = 1U << i;
const unsigned BitSize = ByteSize * 8;
@@ -280,20 +296,6 @@ void ThreadSanitizer::initializeCallbacks(Module &M) {
IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy);
}
-ThreadSanitizer::ThreadSanitizer(Module &M) {
- const DataLayout &DL = M.getDataLayout();
- IntptrTy = DL.getIntPtrType(M.getContext());
- std::tie(TsanCtorFunction, std::ignore) =
- getOrCreateSanitizerCtorAndInitFunctions(
- M, kTsanModuleCtorName, kTsanInitName, /*InitArgTypes=*/{},
- /*InitArgs=*/{},
- // This callback is invoked when the functions are created the first
- // time. Hook them into the global ctors list in that case:
- [&](Function *Ctor, FunctionCallee) {
- appendToGlobalCtors(M, Ctor, 0);
- });
-}
-
static bool isVtableAccess(Instruction *I) {
if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
return Tag->isTBAAVtableAccess();
@@ -436,9 +438,9 @@ bool ThreadSanitizer::sanitizeFunction(Function &F,
const TargetLibraryInfo &TLI) {
// This is required to prevent instrumenting the call to __tsan_init from within
// the module constructor.
- if (&F == TsanCtorFunction)
+ if (F.getName() == kTsanModuleCtorName)
return false;
- initializeCallbacks(*F.getParent());
+ initialize(*F.getParent());
SmallVector<Instruction*, 8> AllLoadsAndStores;
SmallVector<Instruction*, 8> LocalLoadsAndStores;
SmallVector<Instruction*, 8> AtomicAccesses;
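
The ThreadSanitizer change above moves constructor creation into a standalone insertModuleCtor step and has the per-function code recognize the ctor by name rather than by a stored pointer, so the function pass no longer needs module-level state. A toy sketch of that split (hypothetical names, not the LLVM API):

#include <iostream>
#include <string>
#include <vector>

static const char *kCtorName = "tsan.module_ctor";

struct Function { std::string Name; };
struct Module { std::vector<Function> Functions; };

void insertModuleCtor(Module &M) {
  // The real helper reuses the ctor if it already exists; here we just add it.
  M.Functions.push_back({kCtorName});
}

bool sanitizeFunction(Function &F) {
  if (F.Name == kCtorName)   // never instrument the module constructor itself
    return false;
  std::cout << "instrumented " << F.Name << "\n";
  return true;
}

int main() {
  Module M{{{"main"}, {"worker"}}};
  insertModuleCtor(M);              // module-level step (doInitialization / module pass)
  for (Function &F : M.Functions)   // per-function step
    sanitizeFunction(F);
}
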
diff --git a/lib/Transforms/Instrumentation/ValueProfileCollector.cpp b/lib/Transforms/Instrumentation/ValueProfileCollector.cpp
new file mode 100644
index 000000000000..604726d4f40f
--- /dev/null
+++ b/lib/Transforms/Instrumentation/ValueProfileCollector.cpp
@@ -0,0 +1,78 @@
+//===- ValueProfileCollector.cpp - determine what to value profile --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The implementation of the ValueProfileCollector via ValueProfileCollectorImpl
+//
+//===----------------------------------------------------------------------===//
+
+#include "ValueProfilePlugins.inc"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/InitializePasses.h"
+
+#include <cassert>
+
+using namespace llvm;
+
+namespace {
+
+/// A plugin-based class that takes an arbitrary number of Plugin types.
+/// Each plugin type must satisfy the following API:
+/// 1) the constructor must take a `Function &f`. Typically, the plugin would
+/// scan the function looking for candidates.
+/// 2) contain a member function with the following signature and name:
+/// void run(std::vector<CandidateInfo> &Candidates);
+/// such that the plugin would append its result into the vector parameter.
+///
+/// Plugins are defined in ValueProfilePlugins.inc
+template <class... Ts> class PluginChain;
+
+/// The type PluginChainFinal is the final chain of plugins that will be used by
+/// ValueProfileCollectorImpl.
+using PluginChainFinal = PluginChain<VP_PLUGIN_LIST>;
+
+template <> class PluginChain<> {
+public:
+ PluginChain(Function &F) {}
+ void get(InstrProfValueKind K, std::vector<CandidateInfo> &Candidates) {}
+};
+
+template <class PluginT, class... Ts>
+class PluginChain<PluginT, Ts...> : public PluginChain<Ts...> {
+ PluginT Plugin;
+ using Base = PluginChain<Ts...>;
+
+public:
+ PluginChain(Function &F) : PluginChain<Ts...>(F), Plugin(F) {}
+
+ void get(InstrProfValueKind K, std::vector<CandidateInfo> &Candidates) {
+ if (K == PluginT::Kind)
+ Plugin.run(Candidates);
+ Base::get(K, Candidates);
+ }
+};
+
+} // end anonymous namespace
+
+/// ValueProfileCollectorImpl inherits the API of PluginChainFinal.
+class ValueProfileCollector::ValueProfileCollectorImpl : public PluginChainFinal {
+public:
+ using PluginChainFinal::PluginChainFinal;
+};
+
+ValueProfileCollector::ValueProfileCollector(Function &F)
+ : PImpl(new ValueProfileCollectorImpl(F)) {}
+
+ValueProfileCollector::~ValueProfileCollector() = default;
+
+std::vector<CandidateInfo>
+ValueProfileCollector::get(InstrProfValueKind Kind) const {
+ std::vector<CandidateInfo> Result;
+ PImpl->get(Kind, Result);
+ return Result;
+}
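
The PluginChain above builds the collector out of plugin types via recursive template inheritance, dispatching each query to the plugin whose Kind matches. A self-contained sketch of the same mechanism with two toy plugins (the names and kinds here are invented for illustration):

#include <iostream>
#include <string>
#include <vector>

using Candidate = std::string;

struct SizePlugin {
  static constexpr int Kind = 0;
  explicit SizePlugin(const std::vector<int> &) {}
  void run(std::vector<Candidate> &Cs) { Cs.push_back("size-candidate"); }
};

struct TargetPlugin {
  static constexpr int Kind = 1;
  explicit TargetPlugin(const std::vector<int> &) {}
  void run(std::vector<Candidate> &Cs) { Cs.push_back("target-candidate"); }
};

template <class... Ts> class PluginChain;

template <> class PluginChain<> {
public:
  PluginChain(const std::vector<int> &) {}
  void get(int, std::vector<Candidate> &) {}  // end of the chain: nothing to do
};

template <class PluginT, class... Ts>
class PluginChain<PluginT, Ts...> : public PluginChain<Ts...> {
  PluginT Plugin;
  using Base = PluginChain<Ts...>;

public:
  PluginChain(const std::vector<int> &Data) : Base(Data), Plugin(Data) {}

  void get(int Kind, std::vector<Candidate> &Cs) {
    if (Kind == PluginT::Kind)  // only the plugin that owns this kind contributes
      Plugin.run(Cs);
    Base::get(Kind, Cs);        // then hand the query to the rest of the chain
  }
};

int main() {
  std::vector<int> Data = {1, 2, 3};
  PluginChain<SizePlugin, TargetPlugin> Chain(Data);
  std::vector<Candidate> Cs;
  Chain.get(/*Kind=*/1, Cs);
  for (const Candidate &C : Cs)
    std::cout << C << "\n";     // prints "target-candidate"
}
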
diff --git a/lib/Transforms/Instrumentation/ValueProfileCollector.h b/lib/Transforms/Instrumentation/ValueProfileCollector.h
new file mode 100644
index 000000000000..ff883c8d0c77
--- /dev/null
+++ b/lib/Transforms/Instrumentation/ValueProfileCollector.h
@@ -0,0 +1,79 @@
+//===- ValueProfileCollector.h - determine what to value profile ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a utility class, ValueProfileCollector, that is used to
+// determine what kind of llvm::Value's are worth value-profiling, at which
+// point in the program, and which instruction holds the Value Profile metadata.
+// Currently, the only users of this utility are the PGOInstrumentation[Gen|Use]
+// passes.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_PROFILE_GEN_ANALYSIS_H
+#define LLVM_ANALYSIS_PROFILE_GEN_ANALYSIS_H
+
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/ProfileData/InstrProf.h"
+
+namespace llvm {
+
+/// Utility analysis that determines what values are worth profiling.
+/// The actual logic is inside the ValueProfileCollectorImpl, whose job is to
+/// populate the Candidates vector.
+///
+/// Value profiling an expression means to track the values that this expression
+/// takes at runtime and the frequency of each value.
+/// It is important to distinguish between two sets of value profiles for a
+/// particular expression:
+/// 1) The set of values at the point of evaluation.
+/// 2) The set of values at the point of use.
+/// In some cases, the two sets are identical, but it's not unusual for the two
+/// to differ.
+///
+/// To elaborate more, consider this C code, and focus on the expression `nn`:
+/// void foo(int nn, bool b) {
+/// if (b) memcpy(x, y, nn);
+/// }
+/// The point of evaluation can be as early as the start of the function, and
+/// let's say the value profile for `nn` is:
+/// total=100; (value,freq) set = {(8,10), (32,50)}
+/// The point of use is right before we call memcpy, and since we execute the
+/// memcpy conditionally, the value profile of `nn` can be:
+/// total=15; (value,freq) set = {(8,10), (4,5)}
+///
+/// For this reason, a plugin is responsible for computing the insertion point
+/// for each value to be profiled. The `CandidateInfo` structure encapsulates
+/// all the information needed for each value profile site.
+class ValueProfileCollector {
+public:
+ struct CandidateInfo {
+ Value *V; // The value to profile.
+ Instruction *InsertPt; // Insert the VP lib call before this instr.
+ Instruction *AnnotatedInst; // Where metadata is attached.
+ };
+
+ ValueProfileCollector(Function &Fn);
+ ValueProfileCollector(ValueProfileCollector &&) = delete;
+ ValueProfileCollector &operator=(ValueProfileCollector &&) = delete;
+
+ ValueProfileCollector(const ValueProfileCollector &) = delete;
+ ValueProfileCollector &operator=(const ValueProfileCollector &) = delete;
+ ~ValueProfileCollector();
+
+ /// Returns a list of value profiling candidates of the given kind.
+ std::vector<CandidateInfo> get(InstrProfValueKind Kind) const;
+
+private:
+ class ValueProfileCollectorImpl;
+ std::unique_ptr<ValueProfileCollectorImpl> PImpl;
+};
+
+} // namespace llvm
+
+#endif
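
The header comment above distinguishes the value profile at the point of evaluation from the one at the point of use, which is why each CandidateInfo carries its own InsertPt. A tiny runnable illustration of that distinction with hand-rolled counters (toy code, not the profiling runtime); the counts are made up for the example:

#include <cstdio>
#include <map>

static std::map<int, unsigned> ProfileAtEntry;   // values seen at function entry
static std::map<int, unsigned> ProfileAtMemcpy;  // values seen where memcpy runs

void foo(int nn, bool b) {
  ++ProfileAtEntry[nn];        // point of evaluation: every call is counted
  if (b) {
    ++ProfileAtMemcpy[nn];     // point of use: only sizes memcpy actually sees
    /* memcpy(x, y, nn); */
  }
}

int main() {
  for (int i = 0; i < 10; ++i) foo(8, /*b=*/true);
  for (int i = 0; i < 50; ++i) foo(32, /*b=*/false);
  std::printf("entry:  8 -> %u, 32 -> %u\n", ProfileAtEntry[8], ProfileAtEntry[32]);
  std::printf("memcpy: 8 -> %u, 32 -> %u\n", ProfileAtMemcpy[8], ProfileAtMemcpy[32]);
  // The two distributions differ, so the plugin records the point of use as
  // InsertPt instead of instrumenting where the value is first available.
}
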
diff --git a/lib/Transforms/Instrumentation/ValueProfilePlugins.inc b/lib/Transforms/Instrumentation/ValueProfilePlugins.inc
new file mode 100644
index 000000000000..4cc4c6c848c3
--- /dev/null
+++ b/lib/Transforms/Instrumentation/ValueProfilePlugins.inc
@@ -0,0 +1,75 @@
+//=== ValueProfilePlugins.inc - set of plugins used by ValueProfileCollector =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of plugin classes used in ValueProfileCollectorImpl.
+// Each plugin is responsible for collecting Value Profiling candidates for a
+// particular optimization.
+// Each plugin must satisfy the interface described in ValueProfileCollector.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "ValueProfileCollector.h"
+#include "llvm/Analysis/IndirectCallVisitor.h"
+#include "llvm/IR/InstVisitor.h"
+
+using namespace llvm;
+using CandidateInfo = ValueProfileCollector::CandidateInfo;
+
+///--------------------------- MemIntrinsicPlugin ------------------------------
+class MemIntrinsicPlugin : public InstVisitor<MemIntrinsicPlugin> {
+ Function &F;
+ std::vector<CandidateInfo> *Candidates;
+
+public:
+ static constexpr InstrProfValueKind Kind = IPVK_MemOPSize;
+
+ MemIntrinsicPlugin(Function &Fn) : F(Fn), Candidates(nullptr) {}
+
+ void run(std::vector<CandidateInfo> &Cs) {
+ Candidates = &Cs;
+ visit(F);
+ Candidates = nullptr;
+ }
+ void visitMemIntrinsic(MemIntrinsic &MI) {
+ Value *Length = MI.getLength();
+ // Do not instrument constant-length calls.
+ if (dyn_cast<ConstantInt>(Length))
+ return;
+
+ Instruction *InsertPt = &MI;
+ Instruction *AnnotatedInst = &MI;
+ Candidates->emplace_back(CandidateInfo{Length, InsertPt, AnnotatedInst});
+ }
+};
+
+///------------------------ IndirectCallPromotionPlugin ------------------------
+class IndirectCallPromotionPlugin {
+ Function &F;
+
+public:
+ static constexpr InstrProfValueKind Kind = IPVK_IndirectCallTarget;
+
+ IndirectCallPromotionPlugin(Function &Fn) : F(Fn) {}
+
+ void run(std::vector<CandidateInfo> &Candidates) {
+ std::vector<Instruction *> Result = findIndirectCalls(F);
+ for (Instruction *I : Result) {
+ Value *Callee = CallSite(I).getCalledValue();
+ Instruction *InsertPt = I;
+ Instruction *AnnotatedInst = I;
+ Candidates.emplace_back(CandidateInfo{Callee, InsertPt, AnnotatedInst});
+ }
+ }
+};
+
+///----------------------- Registration of the plugins -------------------------
+/// For now, registering a plugin with the ValueProfileCollector is done by
+/// adding the plugin type to the VP_PLUGIN_LIST macro.
+#define VP_PLUGIN_LIST \
+ MemIntrinsicPlugin, \
+ IndirectCallPromotionPlugin
diff --git a/lib/Transforms/ObjCARC/PtrState.cpp b/lib/Transforms/ObjCARC/PtrState.cpp
index 3243481dee0d..26dd416d6184 100644
--- a/lib/Transforms/ObjCARC/PtrState.cpp
+++ b/lib/Transforms/ObjCARC/PtrState.cpp
@@ -275,6 +275,10 @@ void BottomUpPtrState::HandlePotentialUse(BasicBlock *BB, Instruction *Inst,
} else {
InsertAfter = std::next(Inst->getIterator());
}
+
+ if (InsertAfter != BB->end())
+ InsertAfter = skipDebugIntrinsics(InsertAfter);
+
InsertReverseInsertPt(&*InsertAfter);
};
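
The ObjCARC change above advances the reverse insertion point past debug intrinsics, presumably so the spot chosen for the inserted call does not shift when a function is compiled with debug info. A toy sketch of that skipping logic over a plain instruction list (not the LLVM iterator API):

#include <iostream>
#include <iterator>
#include <list>
#include <string>

bool isDebugIntrinsic(const std::string &I) { return I == "llvm.dbg.value"; }

std::list<std::string>::iterator
skipDebugInstructions(std::list<std::string> &BB, std::list<std::string>::iterator It) {
  while (It != BB.end() && isDebugIntrinsic(*It))
    ++It;
  return It;
}

int main() {
  std::list<std::string> BB = {"retain", "llvm.dbg.value", "store", "ret"};
  auto InsertAfter = std::next(BB.begin());   // right after "retain"
  InsertAfter = skipDebugInstructions(BB, InsertAfter);
  BB.insert(InsertAfter, "release");          // same spot with or without the dbg line
  for (const std::string &I : BB)
    std::cout << I << "\n";
}
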
diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
index de9a62e88c27..0e9f03a06061 100644
--- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -93,9 +93,7 @@ static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
const SCEV *AlignSCEV,
ScalarEvolution *SE) {
// DiffUnits = Diff % int64_t(Alignment)
- const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV);
- const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV);
- const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV);
+ const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
<< *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
@@ -323,7 +321,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
LI->getPointerOperand(), SE);
if (NewAlignment > LI->getAlignment()) {
- LI->setAlignment(NewAlignment);
+ LI->setAlignment(MaybeAlign(NewAlignment));
++NumLoadAlignChanged;
}
} else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
@@ -331,7 +329,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
SI->getPointerOperand(), SE);
if (NewAlignment > SI->getAlignment()) {
- SI->setAlignment(NewAlignment);
+ SI->setAlignment(MaybeAlign(NewAlignment));
++NumStoreAlignChanged;
}
} else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
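
The AlignmentFromAssumptions hunk above computes DiffUnits with a single urem instead of the udiv/mul/sub sequence. A plain-integer sketch of what that remainder decides (the real code operates on SCEV expressions; the fallback rule below is a conservative simplification for illustration):

#include <cstdint>
#include <iostream>

uint64_t newAlignment(uint64_t Diff, uint64_t AlignFromAssume) {
  uint64_t DiffUnits = Diff % AlignFromAssume;  // one urem, as in the patch
  if (DiffUnits == 0)
    return AlignFromAssume;   // the offset is a multiple: full alignment survives
  return DiffUnits & (~DiffUnits + 1);  // otherwise only the low set bit of the remainder
}

int main() {
  std::cout << newAlignment(32, 16) << "\n";  // 16: 32 is a multiple of 16
  std::cout << newAlignment(20, 16) << "\n";  // 4:  20 % 16 == 4
}
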
diff --git a/lib/Transforms/Scalar/CallSiteSplitting.cpp b/lib/Transforms/Scalar/CallSiteSplitting.cpp
index 3519b000a33f..c3fba923104f 100644
--- a/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -562,7 +562,7 @@ struct CallSiteSplittingLegacyPass : public FunctionPass {
if (skipFunction(F))
return false;
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
return doCallSiteSplitting(F, TLI, TTI, DT);
diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp
index 98243a23f1ef..9f340afbf7c2 100644
--- a/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -204,7 +204,7 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst,
/// set found in \p BBs.
static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
BasicBlock *Entry,
- SmallPtrSet<BasicBlock *, 8> &BBs) {
+ SetVector<BasicBlock *> &BBs) {
assert(!BBs.count(Entry) && "Assume Entry is not in BBs");
// Nodes on the current path to the root.
SmallPtrSet<BasicBlock *, 8> Path;
@@ -257,7 +257,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
// Visit Orders in bottom-up order.
using InsertPtsCostPair =
- std::pair<SmallPtrSet<BasicBlock *, 16>, BlockFrequency>;
+ std::pair<SetVector<BasicBlock *>, BlockFrequency>;
// InsertPtsMap is a map from a BB to the best insertion points for the
// subtree of BB (subtree not including the BB itself).
@@ -266,7 +266,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
for (auto RIt = Orders.rbegin(); RIt != Orders.rend(); RIt++) {
BasicBlock *Node = *RIt;
bool NodeInBBs = BBs.count(Node);
- SmallPtrSet<BasicBlock *, 16> &InsertPts = InsertPtsMap[Node].first;
+ auto &InsertPts = InsertPtsMap[Node].first;
BlockFrequency &InsertPtsFreq = InsertPtsMap[Node].second;
// Return the optimal insert points in BBs.
@@ -283,7 +283,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
BasicBlock *Parent = DT.getNode(Node)->getIDom()->getBlock();
// Initially, ParentInsertPts is empty and ParentPtsFreq is 0. Every child
// will update its parent's ParentInsertPts and ParentPtsFreq.
- SmallPtrSet<BasicBlock *, 16> &ParentInsertPts = InsertPtsMap[Parent].first;
+ auto &ParentInsertPts = InsertPtsMap[Parent].first;
BlockFrequency &ParentPtsFreq = InsertPtsMap[Parent].second;
// Choose to insert in Node or in subtree of Node.
// Don't hoist to EHPad because we may not find a proper place to insert
@@ -305,12 +305,12 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
}
/// Find an insertion point that dominates all uses.
-SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint(
+SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint(
const ConstantInfo &ConstInfo) const {
assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry.");
// Collect all basic blocks.
- SmallPtrSet<BasicBlock *, 8> BBs;
- SmallPtrSet<Instruction *, 8> InsertPts;
+ SetVector<BasicBlock *> BBs;
+ SetVector<Instruction *> InsertPts;
for (auto const &RCI : ConstInfo.RebasedConstants)
for (auto const &U : RCI.Uses)
BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent());
@@ -333,15 +333,13 @@ SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint(
while (BBs.size() >= 2) {
BasicBlock *BB, *BB1, *BB2;
- BB1 = *BBs.begin();
- BB2 = *std::next(BBs.begin());
+ BB1 = BBs.pop_back_val();
+ BB2 = BBs.pop_back_val();
BB = DT->findNearestCommonDominator(BB1, BB2);
if (BB == Entry) {
InsertPts.insert(&Entry->front());
return InsertPts;
}
- BBs.erase(BB1);
- BBs.erase(BB2);
BBs.insert(BB);
}
assert((BBs.size() == 1) && "Expected only one element.");
@@ -403,7 +401,7 @@ void ConstantHoistingPass::collectConstantCandidates(
return;
// Get offset from the base GV.
- PointerType *GVPtrTy = dyn_cast<PointerType>(BaseGV->getType());
+ PointerType *GVPtrTy = cast<PointerType>(BaseGV->getType());
IntegerType *PtrIntTy = DL->getIntPtrType(*Ctx, GVPtrTy->getAddressSpace());
APInt Offset(DL->getTypeSizeInBits(PtrIntTy), /*val*/0, /*isSigned*/true);
auto *GEPO = cast<GEPOperator>(ConstExpr);
@@ -830,7 +828,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec =
BaseGV ? ConstGEPInfoMap[BaseGV] : ConstIntInfoVec;
for (auto const &ConstInfo : ConstInfoVec) {
- SmallPtrSet<Instruction *, 8> IPSet = findConstantInsertionPoint(ConstInfo);
+ SetVector<Instruction *> IPSet = findConstantInsertionPoint(ConstInfo);
// We can have an empty set if the function contains unreachable blocks.
if (IPSet.empty())
continue;
diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp
index 770321c740a0..e9e6afe3fdd4 100644
--- a/lib/Transforms/Scalar/ConstantProp.cpp
+++ b/lib/Transforms/Scalar/ConstantProp.cpp
@@ -82,7 +82,7 @@ bool ConstantPropagation::runOnFunction(Function &F) {
bool Changed = false;
const DataLayout &DL = F.getParent()->getDataLayout();
TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
while (!WorkList.empty()) {
SmallVector<Instruction*, 16> NewWorkListVec;
diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 89497177524f..2ef85268df48 100644
--- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -62,6 +62,23 @@ STATISTIC(NumSDivs, "Number of sdiv converted to udiv");
STATISTIC(NumUDivs, "Number of udivs whose width was decreased");
STATISTIC(NumAShrs, "Number of ashr converted to lshr");
STATISTIC(NumSRems, "Number of srem converted to urem");
+STATISTIC(NumSExt, "Number of sext converted to zext");
+STATISTIC(NumAnd, "Number of ands removed");
+STATISTIC(NumNW, "Number of no-wrap deductions");
+STATISTIC(NumNSW, "Number of no-signed-wrap deductions");
+STATISTIC(NumNUW, "Number of no-unsigned-wrap deductions");
+STATISTIC(NumAddNW, "Number of no-wrap deductions for add");
+STATISTIC(NumAddNSW, "Number of no-signed-wrap deductions for add");
+STATISTIC(NumAddNUW, "Number of no-unsigned-wrap deductions for add");
+STATISTIC(NumSubNW, "Number of no-wrap deductions for sub");
+STATISTIC(NumSubNSW, "Number of no-signed-wrap deductions for sub");
+STATISTIC(NumSubNUW, "Number of no-unsigned-wrap deductions for sub");
+STATISTIC(NumMulNW, "Number of no-wrap deductions for mul");
+STATISTIC(NumMulNSW, "Number of no-signed-wrap deductions for mul");
+STATISTIC(NumMulNUW, "Number of no-unsigned-wrap deductions for mul");
+STATISTIC(NumShlNW, "Number of no-wrap deductions for shl");
+STATISTIC(NumShlNSW, "Number of no-signed-wrap deductions for shl");
+STATISTIC(NumShlNUW, "Number of no-unsigned-wrap deductions for shl");
STATISTIC(NumOverflows, "Number of overflow checks removed");
STATISTIC(NumSaturating,
"Number of saturating arithmetics converted to normal arithmetics");
@@ -85,6 +102,7 @@ namespace {
AU.addRequired<LazyValueInfoWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
+ AU.addPreserved<LazyValueInfoWrapperPass>();
}
};
@@ -416,37 +434,96 @@ static bool willNotOverflow(BinaryOpIntrinsic *BO, LazyValueInfo *LVI) {
return NWRegion.contains(LRange);
}
-static void processOverflowIntrinsic(WithOverflowInst *WO) {
- IRBuilder<> B(WO);
- Value *NewOp = B.CreateBinOp(
- WO->getBinaryOp(), WO->getLHS(), WO->getRHS(), WO->getName());
- // Constant-folding could have happened.
- if (auto *Inst = dyn_cast<Instruction>(NewOp)) {
- if (WO->isSigned())
+static void setDeducedOverflowingFlags(Value *V, Instruction::BinaryOps Opcode,
+ bool NewNSW, bool NewNUW) {
+ Statistic *OpcNW, *OpcNSW, *OpcNUW;
+ switch (Opcode) {
+ case Instruction::Add:
+ OpcNW = &NumAddNW;
+ OpcNSW = &NumAddNSW;
+ OpcNUW = &NumAddNUW;
+ break;
+ case Instruction::Sub:
+ OpcNW = &NumSubNW;
+ OpcNSW = &NumSubNSW;
+ OpcNUW = &NumSubNUW;
+ break;
+ case Instruction::Mul:
+ OpcNW = &NumMulNW;
+ OpcNSW = &NumMulNSW;
+ OpcNUW = &NumMulNUW;
+ break;
+ case Instruction::Shl:
+ OpcNW = &NumShlNW;
+ OpcNSW = &NumShlNSW;
+ OpcNUW = &NumShlNUW;
+ break;
+ default:
+ llvm_unreachable("Will not be called with other binops");
+ }
+
+ auto *Inst = dyn_cast<Instruction>(V);
+ if (NewNSW) {
+ ++NumNW;
+ ++*OpcNW;
+ ++NumNSW;
+ ++*OpcNSW;
+ if (Inst)
Inst->setHasNoSignedWrap();
- else
+ }
+ if (NewNUW) {
+ ++NumNW;
+ ++*OpcNW;
+ ++NumNUW;
+ ++*OpcNUW;
+ if (Inst)
Inst->setHasNoUnsignedWrap();
}
+}
- Value *NewI = B.CreateInsertValue(UndefValue::get(WO->getType()), NewOp, 0);
- NewI = B.CreateInsertValue(NewI, ConstantInt::getFalse(WO->getContext()), 1);
+static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI);
+
+// Rewrite the given with.overflow intrinsic as a non-overflowing binary op.
+static void processOverflowIntrinsic(WithOverflowInst *WO, LazyValueInfo *LVI) {
+ IRBuilder<> B(WO);
+ Instruction::BinaryOps Opcode = WO->getBinaryOp();
+ bool NSW = WO->isSigned();
+ bool NUW = !WO->isSigned();
+
+ Value *NewOp =
+ B.CreateBinOp(Opcode, WO->getLHS(), WO->getRHS(), WO->getName());
+ setDeducedOverflowingFlags(NewOp, Opcode, NSW, NUW);
+
+ StructType *ST = cast<StructType>(WO->getType());
+ Constant *Struct = ConstantStruct::get(ST,
+ { UndefValue::get(ST->getElementType(0)),
+ ConstantInt::getFalse(ST->getElementType(1)) });
+ Value *NewI = B.CreateInsertValue(Struct, NewOp, 0);
WO->replaceAllUsesWith(NewI);
WO->eraseFromParent();
++NumOverflows;
+
+ // See if we can infer the other no-wrap too.
+ if (auto *BO = dyn_cast<BinaryOperator>(NewOp))
+ processBinOp(BO, LVI);
}
-static void processSaturatingInst(SaturatingInst *SI) {
+static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) {
+ Instruction::BinaryOps Opcode = SI->getBinaryOp();
+ bool NSW = SI->isSigned();
+ bool NUW = !SI->isSigned();
BinaryOperator *BinOp = BinaryOperator::Create(
- SI->getBinaryOp(), SI->getLHS(), SI->getRHS(), SI->getName(), SI);
+ Opcode, SI->getLHS(), SI->getRHS(), SI->getName(), SI);
BinOp->setDebugLoc(SI->getDebugLoc());
- if (SI->isSigned())
- BinOp->setHasNoSignedWrap();
- else
- BinOp->setHasNoUnsignedWrap();
+ setDeducedOverflowingFlags(BinOp, Opcode, NSW, NUW);
SI->replaceAllUsesWith(BinOp);
SI->eraseFromParent();
++NumSaturating;
+
+ // See if we can infer the other no-wrap too.
+ if (auto *BO = dyn_cast<BinaryOperator>(BinOp))
+ processBinOp(BO, LVI);
}
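A standalone C++ check (plain arithmetic, not the LLVM API) of the reasoning behind these two rewrites: when value ranges prove the operation cannot overflow, the overflow bit of *.with.overflow is always false and the saturating result equals the plain result, so both intrinsics can become an ordinary binop carrying the deduced nsw/nuw flag. The ranges below are made-up stand-ins for what LVI would prove.

#include <cassert>
#include <cstdint>
#include <limits>

int main() {
  // Assume range analysis proved: a in [0, 1000], b in [0, 1000] (i32).
  for (int32_t a = 0; a <= 1000; a += 10) {
    for (int32_t b = 0; b <= 1000; b += 10) {
      int64_t Wide = int64_t(a) + int64_t(b);
      bool Overflows = Wide > std::numeric_limits<int32_t>::max() ||
                       Wide < std::numeric_limits<int32_t>::min();
      // sadd.with.overflow -> {a + b, false}: the overflow bit is dead.
      assert(!Overflows);
      // sadd.sat -> a + b (nsw): saturation never kicks in on this range.
      int32_t Saturated = Overflows ? std::numeric_limits<int32_t>::max()
                                    : int32_t(Wide);
      assert(Saturated == a + b);
    }
  }
  return 0;
}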
/// Infer nonnull attributes for the arguments at the specified callsite.
@@ -456,14 +533,14 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
if (auto *WO = dyn_cast<WithOverflowInst>(CS.getInstruction())) {
if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) {
- processOverflowIntrinsic(WO);
+ processOverflowIntrinsic(WO, LVI);
return true;
}
}
if (auto *SI = dyn_cast<SaturatingInst>(CS.getInstruction())) {
if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) {
- processSaturatingInst(SI);
+ processSaturatingInst(SI, LVI);
return true;
}
}
@@ -632,6 +709,27 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
return true;
}
+static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
+ if (SDI->getType()->isVectorTy())
+ return false;
+
+ Value *Base = SDI->getOperand(0);
+
+ Constant *Zero = ConstantInt::get(Base->getType(), 0);
+ if (LVI->getPredicateAt(ICmpInst::ICMP_SGE, Base, Zero, SDI) !=
+ LazyValueInfo::True)
+ return false;
+
+ ++NumSExt;
+ auto *ZExt =
+ CastInst::CreateZExtOrBitCast(Base, SDI->getType(), SDI->getName(), SDI);
+ ZExt->setDebugLoc(SDI->getDebugLoc());
+ SDI->replaceAllUsesWith(ZExt);
+ SDI->eraseFromParent();
+
+ return true;
+}
+
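A small standalone check (plain C++, not LLVM) of the fact processSExt relies on: when the operand is provably non-negative, sign extension and zero extension produce identical bits, so the sext can be rewritten as a zext.

#include <cassert>
#include <cstdint>

int main() {
  for (int i = 0; i <= 32767; ++i) {
    int16_t v = static_cast<int16_t>(i);  // known non-negative (v >= 0)
    int32_t SExt = static_cast<int32_t>(v);                        // sext
    int32_t ZExt = static_cast<int32_t>(static_cast<uint16_t>(v)); // zext
    assert(SExt == ZExt);  // identical whenever the sign bit is clear
  }
  return 0;
}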
static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
using OBO = OverflowingBinaryOperator;
@@ -648,6 +746,7 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
BasicBlock *BB = BinOp->getParent();
+ Instruction::BinaryOps Opcode = BinOp->getOpcode();
Value *LHS = BinOp->getOperand(0);
Value *RHS = BinOp->getOperand(1);
@@ -655,24 +754,48 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
ConstantRange RRange = LVI->getConstantRange(RHS, BB, BinOp);
bool Changed = false;
+ bool NewNUW = false, NewNSW = false;
if (!NUW) {
ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion(
- BinOp->getOpcode(), RRange, OBO::NoUnsignedWrap);
- bool NewNUW = NUWRange.contains(LRange);
- BinOp->setHasNoUnsignedWrap(NewNUW);
+ Opcode, RRange, OBO::NoUnsignedWrap);
+ NewNUW = NUWRange.contains(LRange);
Changed |= NewNUW;
}
if (!NSW) {
ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion(
- BinOp->getOpcode(), RRange, OBO::NoSignedWrap);
- bool NewNSW = NSWRange.contains(LRange);
- BinOp->setHasNoSignedWrap(NewNSW);
+ Opcode, RRange, OBO::NoSignedWrap);
+ NewNSW = NSWRange.contains(LRange);
Changed |= NewNSW;
}
+ setDeducedOverflowingFlags(BinOp, Opcode, NewNSW, NewNUW);
+
return Changed;
}
+static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) {
+ if (BinOp->getType()->isVectorTy())
+ return false;
+
+ // Pattern match (and lhs, C) where C includes a superset of bits which might
+ // be set in lhs. This is a common truncation idiom created by instcombine.
+ BasicBlock *BB = BinOp->getParent();
+ Value *LHS = BinOp->getOperand(0);
+ ConstantInt *RHS = dyn_cast<ConstantInt>(BinOp->getOperand(1));
+ if (!RHS || !RHS->getValue().isMask())
+ return false;
+
+ ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp);
+ if (!LRange.getUnsignedMax().ule(RHS->getValue()))
+ return false;
+
+ BinOp->replaceAllUsesWith(LHS);
+ BinOp->eraseFromParent();
+ NumAnd++;
+ return true;
+}
+
+
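A standalone check (plain C++) of the condition processAnd tests: the right operand must be a low-bit mask, and the left operand's unsigned maximum, as proven by LVI, must fit under it; then the 'and' returns its left operand unchanged. The concrete range below is a made-up stand-in.

#include <cassert>
#include <cstdint>

// A low-bit mask like 0xff or 0x3f: non-zero, and m+1 is a power of two.
static bool isMask(uint32_t m) { return m != 0 && (m & (m + 1)) == 0; }

int main() {
  const uint32_t Mask = 0xff;      // the RHS constant
  assert(isMask(Mask));
  // Assume LVI proved the unsigned range of x is [0, 200], i.e. max <= Mask.
  for (uint32_t x = 0; x <= 200; ++x)
    assert((x & Mask) == x);       // the and can be replaced by x itself
  return 0;
}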
static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) {
if (Constant *C = LVI->getConstant(V, At->getParent(), At))
return C;
@@ -740,10 +863,18 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
case Instruction::AShr:
BBChanged |= processAShr(cast<BinaryOperator>(II), LVI);
break;
+ case Instruction::SExt:
+ BBChanged |= processSExt(cast<SExtInst>(II), LVI);
+ break;
case Instruction::Add:
case Instruction::Sub:
+ case Instruction::Mul:
+ case Instruction::Shl:
BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI);
break;
+ case Instruction::And:
+ BBChanged |= processAnd(cast<BinaryOperator>(II), LVI);
+ break;
}
}
@@ -796,5 +927,6 @@ CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) {
PreservedAnalyses PA;
PA.preserve<GlobalsAA>();
PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<LazyValueAnalysis>();
return PA;
}
diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp
index 479e0ed74074..a79d775aa7f3 100644
--- a/lib/Transforms/Scalar/DCE.cpp
+++ b/lib/Transforms/Scalar/DCE.cpp
@@ -38,17 +38,19 @@ namespace {
//===--------------------------------------------------------------------===//
// DeadInstElimination pass implementation
//
- struct DeadInstElimination : public BasicBlockPass {
- static char ID; // Pass identification, replacement for typeid
- DeadInstElimination() : BasicBlockPass(ID) {
- initializeDeadInstEliminationPass(*PassRegistry::getPassRegistry());
- }
- bool runOnBasicBlock(BasicBlock &BB) override {
- if (skipBasicBlock(BB))
- return false;
- auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI() : nullptr;
- bool Changed = false;
+struct DeadInstElimination : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ DeadInstElimination() : FunctionPass(ID) {
+ initializeDeadInstEliminationPass(*PassRegistry::getPassRegistry());
+ }
+ bool runOnFunction(Function &F) override {
+ if (skipFunction(F))
+ return false;
+ auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
+ TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
+
+ bool Changed = false;
+ for (auto &BB : F) {
for (BasicBlock::iterator DI = BB.begin(); DI != BB.end(); ) {
Instruction *Inst = &*DI++;
if (isInstructionTriviallyDead(Inst, TLI)) {
@@ -60,13 +62,14 @@ namespace {
++DIEEliminated;
}
}
- return Changed;
}
+ return Changed;
+ }
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
}
- };
+};
}
char DeadInstElimination::ID = 0;
@@ -154,7 +157,7 @@ struct DCELegacyPass : public FunctionPass {
return false;
auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI() : nullptr;
+ TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
return eliminateDeadCode(F, TLI);
}
diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp
index a81645745b48..685de82810ed 100644
--- a/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -1254,8 +1254,9 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
auto *SI = new StoreInst(
ConstantInt::get(Earlier->getValueOperand()->getType(), Merged),
- Earlier->getPointerOperand(), false, Earlier->getAlignment(),
- Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite);
+ Earlier->getPointerOperand(), false,
+ MaybeAlign(Earlier->getAlignment()), Earlier->getOrdering(),
+ Earlier->getSyncScopeID(), DepWrite);
unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa,
LLVMContext::MD_alias_scope,
@@ -1361,7 +1362,7 @@ public:
MemoryDependenceResults *MD =
&getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
const TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
return eliminateDeadStores(F, AA, MD, DT, TLI);
}
diff --git a/lib/Transforms/Scalar/DivRemPairs.cpp b/lib/Transforms/Scalar/DivRemPairs.cpp
index 876681b4f9de..934853507478 100644
--- a/lib/Transforms/Scalar/DivRemPairs.cpp
+++ b/lib/Transforms/Scalar/DivRemPairs.cpp
@@ -1,4 +1,4 @@
-//===- DivRemPairs.cpp - Hoist/decompose division and remainder -*- C++ -*-===//
+//===- DivRemPairs.cpp - Hoist/[dr]ecompose division and remainder --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This pass hoists and/or decomposes integer division and remainder
+// This pass hoists and/or decomposes/recomposes integer division and remainder
// instructions to enable CFG improvements and better codegen.
//
//===----------------------------------------------------------------------===//
@@ -19,37 +19,105 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BypassSlowDivision.h"
+
using namespace llvm;
+using namespace llvm::PatternMatch;
#define DEBUG_TYPE "div-rem-pairs"
STATISTIC(NumPairs, "Number of div/rem pairs");
+STATISTIC(NumRecomposed, "Number of instructions recomposed");
STATISTIC(NumHoisted, "Number of instructions hoisted");
STATISTIC(NumDecomposed, "Number of instructions decomposed");
DEBUG_COUNTER(DRPCounter, "div-rem-pairs-transform",
"Controls transformations in div-rem-pairs pass");
-/// Find matching pairs of integer div/rem ops (they have the same numerator,
-/// denominator, and signedness). If they exist in different basic blocks, bring
-/// them together by hoisting or replace the common division operation that is
-/// implicit in the remainder:
-/// X % Y <--> X - ((X / Y) * Y).
-///
-/// We can largely ignore the normal safety and cost constraints on speculation
-/// of these ops when we find a matching pair. This is because we are already
-/// guaranteed that any exceptions and most cost are already incurred by the
-/// first member of the pair.
-///
-/// Note: This transform could be an oddball enhancement to EarlyCSE, GVN, or
-/// SimplifyCFG, but it's split off on its own because it's different enough
-/// that it doesn't quite match the stated objectives of those passes.
-static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
- const DominatorTree &DT) {
- bool Changed = false;
+namespace {
+struct ExpandedMatch {
+ DivRemMapKey Key;
+ Instruction *Value;
+};
+} // namespace
+
+/// See if we can match the following (which is the form we expand into):
+/// X - ((X ?/ Y) * Y)
+/// which is equivalent to:
+/// X ?% Y
+static llvm::Optional<ExpandedMatch> matchExpandedRem(Instruction &I) {
+ Value *Dividend, *XroundedDownToMultipleOfY;
+ if (!match(&I, m_Sub(m_Value(Dividend), m_Value(XroundedDownToMultipleOfY))))
+ return llvm::None;
+
+ Value *Divisor;
+ Instruction *Div;
+ // Look for ((X / Y) * Y)
+ if (!match(
+ XroundedDownToMultipleOfY,
+ m_c_Mul(m_CombineAnd(m_IDiv(m_Specific(Dividend), m_Value(Divisor)),
+ m_Instruction(Div)),
+ m_Deferred(Divisor))))
+ return llvm::None;
+
+ ExpandedMatch M;
+ M.Key.SignedOp = Div->getOpcode() == Instruction::SDiv;
+ M.Key.Dividend = Dividend;
+ M.Key.Divisor = Divisor;
+ M.Value = &I;
+ return M;
+}
+
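A quick standalone check (plain C++) of the identity this matcher relies on: X - ((X / Y) * Y) computes exactly X % Y for both signed and unsigned division, which is why the sub/mul/div pattern can be treated as a remainder.

#include <cassert>
#include <cstdint>

int main() {
  for (int32_t X = -50; X <= 50; ++X)
    for (int32_t Y = -7; Y <= 7; ++Y) {
      if (Y == 0)
        continue;                              // division by zero is UB
      assert(X - (X / Y) * Y == X % Y);        // signed (srem) form
      uint32_t UX = static_cast<uint32_t>(X + 50); // shifted into [0, 100]
      uint32_t UY = static_cast<uint32_t>(Y + 8);  // shifted into [1, 15]
      assert(UX - (UX / UY) * UY == UX % UY);  // unsigned (urem) form
    }
  return 0;
}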
+/// A thin wrapper to store two values that we matched as a div-rem pair.
+/// We want this extra indirection to avoid dealing with RAUW'ing the map keys.
+struct DivRemPairWorklistEntry {
+ /// The actual udiv/sdiv instruction. Source of truth.
+ AssertingVH<Instruction> DivInst;
+
+ /// The instruction that we have matched as a remainder instruction.
+ /// Should only be used as Value, don't introspect it.
+ AssertingVH<Instruction> RemInst;
+
+ DivRemPairWorklistEntry(Instruction *DivInst_, Instruction *RemInst_)
+ : DivInst(DivInst_), RemInst(RemInst_) {
+ assert((DivInst->getOpcode() == Instruction::UDiv ||
+ DivInst->getOpcode() == Instruction::SDiv) &&
+ "Not a division.");
+ assert(DivInst->getType() == RemInst->getType() && "Types should match.");
+ // We can't check anything else about remainder instruction,
+ // it's not strictly required to be a urem/srem.
+ }
+ /// The type for this pair, identical for both the div and rem.
+ Type *getType() const { return DivInst->getType(); }
+
+ /// Is this pair signed or unsigned?
+ bool isSigned() const { return DivInst->getOpcode() == Instruction::SDiv; }
+
+ /// In this pair, what are the dividend and divisor?
+ Value *getDividend() const { return DivInst->getOperand(0); }
+ Value *getDivisor() const { return DivInst->getOperand(1); }
+
+ bool isRemExpanded() const {
+ switch (RemInst->getOpcode()) {
+ case Instruction::SRem:
+ case Instruction::URem:
+ return false; // single 'rem' instruction - unexpanded form.
+ default:
+ return true; // anything else means we have remainder in expanded form.
+ }
+ }
+};
+using DivRemWorklistTy = SmallVector<DivRemPairWorklistEntry, 4>;
+
+/// Find matching pairs of integer div/rem ops (they have the same numerator,
+/// denominator, and signedness). Place those pairs into a worklist for further
+/// processing. This indirection is needed because we have to use TrackingVH<>:
+/// we will be doing RAUW, and if one of the rem instructions we change
+/// happens to be an input to another div/rem in the maps, we'd have problems.
+static DivRemWorklistTy getWorklist(Function &F) {
// Insert all divide and remainder instructions into maps keyed by their
// operands and opcode (signed or unsigned).
DenseMap<DivRemMapKey, Instruction *> DivMap;
@@ -66,9 +134,14 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
RemMap[DivRemMapKey(true, I.getOperand(0), I.getOperand(1))] = &I;
else if (I.getOpcode() == Instruction::URem)
RemMap[DivRemMapKey(false, I.getOperand(0), I.getOperand(1))] = &I;
+ else if (auto Match = matchExpandedRem(I))
+ RemMap[Match->Key] = Match->Value;
}
}
+ // We'll accumulate the matching pairs of div-rem instructions here.
+ DivRemWorklistTy Worklist;
+
// We can iterate over either map because we are only looking for matched
// pairs. Choose remainders for efficiency because they are usually even more
// rare than division.
@@ -78,12 +151,77 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
if (!DivInst)
continue;
- // We have a matching pair of div/rem instructions. If one dominates the
- // other, hoist and/or replace one.
+ // We have a matching pair of div/rem instructions.
NumPairs++;
Instruction *RemInst = RemPair.second;
- bool IsSigned = DivInst->getOpcode() == Instruction::SDiv;
- bool HasDivRemOp = TTI.hasDivRemOp(DivInst->getType(), IsSigned);
+
+ // Place it in the worklist.
+ Worklist.emplace_back(DivInst, RemInst);
+ }
+
+ return Worklist;
+}
+
+/// Find matching pairs of integer div/rem ops (they have the same numerator,
+/// denominator, and signedness). If they exist in different basic blocks, bring
+/// them together by hoisting or replace the common division operation that is
+/// implicit in the remainder:
+/// X % Y <--> X - ((X / Y) * Y).
+///
+/// We can largely ignore the normal safety and cost constraints on speculation
+/// of these ops when we find a matching pair. This is because we are already
+/// guaranteed that any exceptions and most cost are already incurred by the
+/// first member of the pair.
+///
+/// Note: This transform could be an oddball enhancement to EarlyCSE, GVN, or
+/// SimplifyCFG, but it's split off on its own because it's different enough
+/// that it doesn't quite match the stated objectives of those passes.
+static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
+ const DominatorTree &DT) {
+ bool Changed = false;
+
+ // Get the matching pairs of div-rem instructions. We want this extra
+ // indirection to avoid dealing with having to RAUW the keys of the maps.
+ DivRemWorklistTy Worklist = getWorklist(F);
+
+ // Process each entry in the worklist.
+ for (DivRemPairWorklistEntry &E : Worklist) {
+ if (!DebugCounter::shouldExecute(DRPCounter))
+ continue;
+
+ bool HasDivRemOp = TTI.hasDivRemOp(E.getType(), E.isSigned());
+
+ auto &DivInst = E.DivInst;
+ auto &RemInst = E.RemInst;
+
+ const bool RemOriginallyWasInExpandedForm = E.isRemExpanded();
+ (void)RemOriginallyWasInExpandedForm; // suppress unused variable warning
+
+ if (HasDivRemOp && E.isRemExpanded()) {
+ // The target supports div+rem but the rem is expanded.
+ // We should recompose it first.
+ Value *X = E.getDividend();
+ Value *Y = E.getDivisor();
+ Instruction *RealRem = E.isSigned() ? BinaryOperator::CreateSRem(X, Y)
+ : BinaryOperator::CreateURem(X, Y);
+ // Note that we place it right next to the original expanded instruction,
+ // and let further handling move it if needed.
+ RealRem->setName(RemInst->getName() + ".recomposed");
+ RealRem->insertAfter(RemInst);
+ Instruction *OrigRemInst = RemInst;
+ // Update AssertingVH<> with new instruction so it doesn't assert.
+ RemInst = RealRem;
+ // And replace the original instruction with the new one.
+ OrigRemInst->replaceAllUsesWith(RealRem);
+ OrigRemInst->eraseFromParent();
+ NumRecomposed++;
+ // Note that we have left ((X / Y) * Y) around.
+ // If it had other uses, we could rewrite it as X - X % Y.
+ }
+
+ assert((!E.isRemExpanded() || !HasDivRemOp) &&
+ "*If* the target supports div-rem, then by now the RemInst *is* "
+ "Instruction::[US]Rem.");
// If the target supports div+rem and the instructions are in the same block
// already, there's nothing to do. The backend should handle this. If the
@@ -92,10 +230,16 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
continue;
bool DivDominates = DT.dominates(DivInst, RemInst);
- if (!DivDominates && !DT.dominates(RemInst, DivInst))
+ if (!DivDominates && !DT.dominates(RemInst, DivInst)) {
+ // We have a matching div-rem pair, but they are in two different blocks,
+ // neither of which dominates the other.
+ // FIXME: We could hoist both ops to the common predecessor block?
continue;
+ }
- if (!DebugCounter::shouldExecute(DRPCounter))
+ // The target does not have a single div/rem operation,
+ // and the rem is already in expanded form. Nothing to do.
+ if (!HasDivRemOp && E.isRemExpanded())
continue;
if (HasDivRemOp) {
@@ -107,11 +251,17 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
DivInst->moveAfter(RemInst);
NumHoisted++;
} else {
- // The target does not have a single div/rem operation. Decompose the
- // remainder calculation as:
+ // The target does not have a single div/rem operation,
+ // and the rem is *not* in an already-expanded form.
+ // Decompose the remainder calculation as:
// X % Y --> X - ((X / Y) * Y).
- Value *X = RemInst->getOperand(0);
- Value *Y = RemInst->getOperand(1);
+
+ assert(!RemOriginallyWasInExpandedForm &&
+ "We should not be expanding if the rem was in expanded form to "
+ "begin with.");
+
+ Value *X = E.getDividend();
+ Value *Y = E.getDivisor();
Instruction *Mul = BinaryOperator::CreateMul(DivInst, Y);
Instruction *Sub = BinaryOperator::CreateSub(X, Mul);
@@ -152,8 +302,13 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
// Now kill the explicit remainder. We have replaced it with:
  // (sub X, (mul (div X, Y), Y))
- RemInst->replaceAllUsesWith(Sub);
- RemInst->eraseFromParent();
+ Sub->setName(RemInst->getName() + ".decomposed");
+ Instruction *OrigRemInst = RemInst;
+ // Update AssertingVH<> with new instruction so it doesn't assert.
+ RemInst = Sub;
+ // And replace the original instruction with the new one.
+ OrigRemInst->replaceAllUsesWith(Sub);
+ OrigRemInst->eraseFromParent();
NumDecomposed++;
}
Changed = true;
@@ -188,7 +343,7 @@ struct DivRemPairsLegacyPass : public FunctionPass {
return optimizeDivRem(F, TTI, DT);
}
};
-}
+} // namespace
char DivRemPairsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(DivRemPairsLegacyPass, "div-rem-pairs",
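To summarize the per-pair handling added in optimizeDivRem above, here is an illustrative C++ sketch (not code from the pass; the names are invented) of how the hasDivRemOp query and the expanded/unexpanded state of the remainder select an action:

#include <cstdio>

enum class Action { Recompose, Hoist, Decompose, Nothing };

// Illustrative only: mirrors the control flow of optimizeDivRem for one
// worklist entry, ignoring the dominance/same-block checks for brevity.
// (After recomposing, the pass may still hoist the pair into one block.)
static Action pickAction(bool HasDivRemOp, bool RemIsExpanded) {
  if (HasDivRemOp && RemIsExpanded)
    return Action::Recompose;  // rebuild a real [us]rem next to the expansion
  if (HasDivRemOp)
    return Action::Hoist;      // keep div and rem, bring them into one block
  if (RemIsExpanded)
    return Action::Nothing;    // already in the form the target wants
  return Action::Decompose;    // X % Y --> X - ((X / Y) * Y)
}

int main() {
  static const char *Names[] = {"recompose", "hoist", "decompose", "nothing"};
  for (bool HasDivRem : {false, true})
    for (bool Expanded : {false, true})
      std::printf("hasDivRemOp=%d remExpanded=%d -> %s\n", HasDivRem, Expanded,
                  Names[static_cast<int>(pickAction(HasDivRem, Expanded))]);
  return 0;
}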
diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp
index f1f075257020..ce540683dae2 100644
--- a/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -108,11 +108,12 @@ struct SimpleValue {
// This can only handle non-void readnone functions.
if (CallInst *CI = dyn_cast<CallInst>(Inst))
return CI->doesNotAccessMemory() && !CI->getType()->isVoidTy();
- return isa<CastInst>(Inst) || isa<BinaryOperator>(Inst) ||
- isa<GetElementPtrInst>(Inst) || isa<CmpInst>(Inst) ||
- isa<SelectInst>(Inst) || isa<ExtractElementInst>(Inst) ||
- isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) ||
- isa<ExtractValueInst>(Inst) || isa<InsertValueInst>(Inst);
+ return isa<CastInst>(Inst) || isa<UnaryOperator>(Inst) ||
+ isa<BinaryOperator>(Inst) || isa<GetElementPtrInst>(Inst) ||
+ isa<CmpInst>(Inst) || isa<SelectInst>(Inst) ||
+ isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
+ isa<ShuffleVectorInst>(Inst) || isa<ExtractValueInst>(Inst) ||
+ isa<InsertValueInst>(Inst);
}
};
@@ -240,7 +241,7 @@ static unsigned getHashValueImpl(SimpleValue Val) {
assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) ||
isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
- isa<ShuffleVectorInst>(Inst)) &&
+ isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst)) &&
"Invalid/unknown instruction");
// Mix in the opcode.
@@ -526,7 +527,7 @@ public:
const TargetTransformInfo &TTI, DominatorTree &DT,
AssumptionCache &AC, MemorySSA *MSSA)
: TLI(TLI), TTI(TTI), DT(DT), AC(AC), SQ(DL, &TLI, &DT, &AC), MSSA(MSSA),
- MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {}
+ MSSAUpdater(std::make_unique<MemorySSAUpdater>(MSSA)) {}
bool run();
@@ -651,7 +652,7 @@ private:
bool isInvariantLoad() const {
if (auto *LI = dyn_cast<LoadInst>(Inst))
- return LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr;
+ return LI->hasMetadata(LLVMContext::MD_invariant_load);
return false;
}
@@ -790,7 +791,7 @@ bool EarlyCSE::isOperatingOnInvariantMemAt(Instruction *I, unsigned GenAt) {
// A location loaded from with an invariant_load is assumed to *never* change
// within the visible scope of the compilation.
if (auto *LI = dyn_cast<LoadInst>(I))
- if (LI->getMetadata(LLVMContext::MD_invariant_load))
+ if (LI->hasMetadata(LLVMContext::MD_invariant_load))
return true;
auto MemLocOpt = MemoryLocation::getOrNone(I);
@@ -1359,7 +1360,7 @@ public:
if (skipFunction(F))
return false;
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
@@ -1381,6 +1382,7 @@ public:
AU.addPreserved<MemorySSAWrapperPass>();
}
AU.addPreserved<GlobalsAAWrapperPass>();
+ AU.addPreserved<AAResultsWrapperPass>();
AU.setPreservesCFG();
}
};
diff --git a/lib/Transforms/Scalar/FlattenCFGPass.cpp b/lib/Transforms/Scalar/FlattenCFGPass.cpp
index 31670b1464e4..e6abf1ceb026 100644
--- a/lib/Transforms/Scalar/FlattenCFGPass.cpp
+++ b/lib/Transforms/Scalar/FlattenCFGPass.cpp
@@ -11,10 +11,12 @@
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/CFG.h"
+#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+
using namespace llvm;
#define DEBUG_TYPE "flattencfg"
@@ -52,15 +54,23 @@ FunctionPass *llvm::createFlattenCFGPass() { return new FlattenCFGPass(); }
static bool iterativelyFlattenCFG(Function &F, AliasAnalysis *AA) {
bool Changed = false;
bool LocalChange = true;
+
+ // Use block handles instead of iterating over function blocks directly
+ // to avoid using iterators invalidated by erasing blocks.
+ std::vector<WeakVH> Blocks;
+ Blocks.reserve(F.size());
+ for (auto &BB : F)
+ Blocks.push_back(&BB);
+
while (LocalChange) {
LocalChange = false;
- // Loop over all of the basic blocks and remove them if they are unneeded...
- //
- for (Function::iterator BBIt = F.begin(); BBIt != F.end();) {
- if (FlattenCFG(&*BBIt++, AA)) {
- LocalChange = true;
- }
+ // Loop over all of the basic blocks and try to flatten them.
+ for (WeakVH &BlockHandle : Blocks) {
+ // Skip blocks erased by FlattenCFG.
+ if (auto *BB = cast_or_null<BasicBlock>(BlockHandle))
+ if (FlattenCFG(BB, AA))
+ LocalChange = true;
}
Changed |= LocalChange;
}
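The handle-based iteration pattern above can be illustrated with a standalone C++ analogue that uses std::weak_ptr as a stand-in for WeakVH: take a snapshot of nullable handles first, then walk the snapshot and skip entries whose underlying object has been erased.

#include <cassert>
#include <memory>
#include <vector>

int main() {
  // Stand-ins for basic blocks; some will be "erased" while we iterate.
  std::vector<std::shared_ptr<int>> Blocks;
  for (int i = 0; i < 5; ++i)
    Blocks.push_back(std::make_shared<int>(i));

  // Snapshot of nullable handles, taken before any mutation (like WeakVH).
  std::vector<std::weak_ptr<int>> Handles(Blocks.begin(), Blocks.end());

  int Visited = 0;
  for (auto &H : Handles) {
    if (auto BB = H.lock()) {
      ++Visited;
      if (*BB == 1)
        Blocks[2].reset();  // simulate FlattenCFG deleting another block
    }
  }
  assert(Visited == 4);     // the erased block was skipped safely
  return 0;
}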
diff --git a/lib/Transforms/Scalar/Float2Int.cpp b/lib/Transforms/Scalar/Float2Int.cpp
index 4f83e869b303..4d2eac0451df 100644
--- a/lib/Transforms/Scalar/Float2Int.cpp
+++ b/lib/Transforms/Scalar/Float2Int.cpp
@@ -60,11 +60,13 @@ namespace {
if (skipFunction(F))
return false;
- return Impl.runImpl(F);
+ const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ return Impl.runImpl(F, DT);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
+ AU.addRequired<DominatorTreeWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
}
@@ -116,21 +118,29 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
// Find the roots - instructions that convert from the FP domain to
// integer domain.
-void Float2IntPass::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) {
- for (auto &I : instructions(F)) {
- if (isa<VectorType>(I.getType()))
+void Float2IntPass::findRoots(Function &F, const DominatorTree &DT,
+ SmallPtrSet<Instruction*,8> &Roots) {
+ for (BasicBlock &BB : F) {
+ // Unreachable code can take on strange forms that we are not prepared to
+ // handle. For example, an instruction may have itself as an operand.
+ if (!DT.isReachableFromEntry(&BB))
continue;
- switch (I.getOpcode()) {
- default: break;
- case Instruction::FPToUI:
- case Instruction::FPToSI:
- Roots.insert(&I);
- break;
- case Instruction::FCmp:
- if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
- CmpInst::BAD_ICMP_PREDICATE)
+
+ for (Instruction &I : BB) {
+ if (isa<VectorType>(I.getType()))
+ continue;
+ switch (I.getOpcode()) {
+ default: break;
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
Roots.insert(&I);
- break;
+ break;
+ case Instruction::FCmp:
+ if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
+ CmpInst::BAD_ICMP_PREDICATE)
+ Roots.insert(&I);
+ break;
+ }
}
}
}
@@ -503,7 +513,7 @@ void Float2IntPass::cleanup() {
I.first->eraseFromParent();
}
-bool Float2IntPass::runImpl(Function &F) {
+bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
// Clear out all state.
ECs = EquivalenceClasses<Instruction*>();
@@ -513,7 +523,7 @@ bool Float2IntPass::runImpl(Function &F) {
Ctx = &F.getParent()->getContext();
- findRoots(F, Roots);
+ findRoots(F, DT, Roots);
walkBackwards(Roots);
walkForwards();
@@ -527,8 +537,9 @@ bool Float2IntPass::runImpl(Function &F) {
namespace llvm {
FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }
-PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &) {
- if (!runImpl(F))
+PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {
+ const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ if (!runImpl(F, DT))
return PreservedAnalyses::all();
PreservedAnalyses PA;
diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp
index 1a02e9d33f49..743353eaea22 100644
--- a/lib/Transforms/Scalar/GVN.cpp
+++ b/lib/Transforms/Scalar/GVN.cpp
@@ -70,6 +70,7 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
@@ -626,6 +627,8 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserve<DominatorTreeAnalysis>();
PA.preserve<GlobalsAA>();
PA.preserve<TargetLibraryAnalysis>();
+ if (LI)
+ PA.preserve<LoopAnalysis>();
return PA;
}
@@ -1161,15 +1164,30 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock,
// Do PHI translation to get its value in the predecessor if necessary. The
// returned pointer (if non-null) is guaranteed to dominate UnavailablePred.
+ // We do the translation for each edge we skipped by going from LI's block
+ // to LoadBB; otherwise we might miss pieces needing translation.
// If all preds have a single successor, then we know it is safe to insert
// the load on the pred (?!?), so we can insert code to materialize the
// pointer if it is not available.
- PHITransAddr Address(LI->getPointerOperand(), DL, AC);
- Value *LoadPtr = nullptr;
- LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred,
- *DT, NewInsts);
+ Value *LoadPtr = LI->getPointerOperand();
+ BasicBlock *Cur = LI->getParent();
+ while (Cur != LoadBB) {
+ PHITransAddr Address(LoadPtr, DL, AC);
+ LoadPtr = Address.PHITranslateWithInsertion(
+ Cur, Cur->getSinglePredecessor(), *DT, NewInsts);
+ if (!LoadPtr) {
+ CanDoPRE = false;
+ break;
+ }
+ Cur = Cur->getSinglePredecessor();
+ }
+ if (LoadPtr) {
+ PHITransAddr Address(LoadPtr, DL, AC);
+ LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, *DT,
+ NewInsts);
+ }
// If we couldn't find or insert a computation of this phi translated value,
// we fail PRE.
if (!LoadPtr) {
@@ -1184,8 +1202,12 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock,
if (!CanDoPRE) {
while (!NewInsts.empty()) {
- Instruction *I = NewInsts.pop_back_val();
- markInstructionForDeletion(I);
+ // Erase instructions generated by the failed PHI translation before
+ // trying to number them. PHI translation might insert instructions
+ // in basic blocks other than the current one, and we delete them
+ // directly, as markInstructionForDeletion only allows removing from the
+ // current basic block.
+ NewInsts.pop_back_val()->eraseFromParent();
}
// HINT: Don't revert the edge-splitting as following transformation may
// also need to split these critical edges.
@@ -1219,10 +1241,10 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock,
BasicBlock *UnavailablePred = PredLoad.first;
Value *LoadPtr = PredLoad.second;
- auto *NewLoad =
- new LoadInst(LI->getType(), LoadPtr, LI->getName() + ".pre",
- LI->isVolatile(), LI->getAlignment(), LI->getOrdering(),
- LI->getSyncScopeID(), UnavailablePred->getTerminator());
+ auto *NewLoad = new LoadInst(
+ LI->getType(), LoadPtr, LI->getName() + ".pre", LI->isVolatile(),
+ MaybeAlign(LI->getAlignment()), LI->getOrdering(), LI->getSyncScopeID(),
+ UnavailablePred->getTerminator());
NewLoad->setDebugLoc(LI->getDebugLoc());
// Transfer the old load's AA tags to the new load.
@@ -1365,6 +1387,14 @@ bool GVN::processNonLocalLoad(LoadInst *LI) {
return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks);
}
+static bool hasUsersIn(Value *V, BasicBlock *BB) {
+ for (User *U : V->users())
+ if (isa<Instruction>(U) &&
+ cast<Instruction>(U)->getParent() == BB)
+ return true;
+ return false;
+}
+
bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
assert(IntrinsicI->getIntrinsicID() == Intrinsic::assume &&
"This function can only be called with llvm.assume intrinsic");
@@ -1403,12 +1433,23 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
// We can replace assume value with true, which covers cases like this:
// call void @llvm.assume(i1 %cmp)
// br i1 %cmp, label %bb1, label %bb2 ; will change %cmp to true
- ReplaceWithConstMap[V] = True;
-
- // If one of *cmp *eq operand is const, adding it to map will cover this:
+ ReplaceOperandsWithMap[V] = True;
+
+ // If we find an equality fact, canonicalize all dominated uses in this block
+ // to one of the two values. We heuristically choose the "oldest" of the
+ // two where age is determined by value number. (Note that propagateEquality
+ // above handles the cross block case.)
+ //
+ // Key cases to cover are:
+ // 1)
// %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen
// call void @llvm.assume(i1 %cmp)
// ret float %0 ; will change it to ret float 3.000000e+00
+ // 2)
+ // %load = load float, float* %addr
+ // %cmp = fcmp oeq float %load, %0
+ // call void @llvm.assume(i1 %cmp)
+ // ret float %load ; will change it to ret float %0
if (auto *CmpI = dyn_cast<CmpInst>(V)) {
if (CmpI->getPredicate() == CmpInst::Predicate::ICMP_EQ ||
CmpI->getPredicate() == CmpInst::Predicate::FCMP_OEQ ||
@@ -1416,13 +1457,50 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
CmpI->getFastMathFlags().noNaNs())) {
Value *CmpLHS = CmpI->getOperand(0);
Value *CmpRHS = CmpI->getOperand(1);
- if (isa<Constant>(CmpLHS))
+ // Heuristically pick the better replacement -- the choice of heuristic
+ // isn't terribly important here, but the fact we canonicalize on some
+ // replacement is for exposing other simplifications.
+ // TODO: pull this out as a helper function and reuse w/existing
+ // (slightly different) logic.
+ if (isa<Constant>(CmpLHS) && !isa<Constant>(CmpRHS))
std::swap(CmpLHS, CmpRHS);
- auto *RHSConst = dyn_cast<Constant>(CmpRHS);
+ if (!isa<Instruction>(CmpLHS) && isa<Instruction>(CmpRHS))
+ std::swap(CmpLHS, CmpRHS);
+ if ((isa<Argument>(CmpLHS) && isa<Argument>(CmpRHS)) ||
+ (isa<Instruction>(CmpLHS) && isa<Instruction>(CmpRHS))) {
+ // Move the 'oldest' value to the right-hand side, using the value
+ // number as a proxy for age.
+ uint32_t LVN = VN.lookupOrAdd(CmpLHS);
+ uint32_t RVN = VN.lookupOrAdd(CmpRHS);
+ if (LVN < RVN)
+ std::swap(CmpLHS, CmpRHS);
+ }
- // If only one operand is constant.
- if (RHSConst != nullptr && !isa<Constant>(CmpLHS))
- ReplaceWithConstMap[CmpLHS] = RHSConst;
+ // Handle the degenerate case where we either haven't pruned a dead path
+ // or removed a trivial assume yet.
+ if (isa<Constant>(CmpLHS) && isa<Constant>(CmpRHS))
+ return Changed;
+
+ // +0.0 and -0.0 compare equal, but do not imply equivalence. Unless we
+ // can prove equivalence, bail.
+ if (CmpRHS->getType()->isFloatTy() &&
+ (!isa<ConstantFP>(CmpRHS) || cast<ConstantFP>(CmpRHS)->isZero()))
+ return Changed;
+
+ LLVM_DEBUG(dbgs() << "Replacing dominated uses of "
+ << *CmpLHS << " with "
+ << *CmpRHS << " in block "
+ << IntrinsicI->getParent()->getName() << "\n");
+
+
+ // Setup the replacement map - this handles uses within the same block
+ if (hasUsersIn(CmpLHS, IntrinsicI->getParent()))
+ ReplaceOperandsWithMap[CmpLHS] = CmpRHS;
+
+ // NOTE: The non-block local cases are handled by the call to
+ // propagateEquality above; this block is just about handling the block
+ // local cases. TODO: There's a bunch of logic in propagateEquality which
+ // isn't duplicated for the block local case, can we share it somehow?
}
}
return Changed;
@@ -1522,6 +1600,41 @@ uint32_t GVN::ValueTable::phiTranslate(const BasicBlock *Pred,
return NewNum;
}
+// Return true if the value numbers \p Num and \p NewNum have equal value.
+// Return false if the result is unknown.
+bool GVN::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum,
+ const BasicBlock *Pred,
+ const BasicBlock *PhiBlock, GVN &Gvn) {
+ CallInst *Call = nullptr;
+ LeaderTableEntry *Vals = &Gvn.LeaderTable[Num];
+ while (Vals) {
+ Call = dyn_cast<CallInst>(Vals->Val);
+ if (Call && Call->getParent() == PhiBlock)
+ break;
+ Vals = Vals->Next;
+ }
+
+ if (AA->doesNotAccessMemory(Call))
+ return true;
+
+ if (!MD || !AA->onlyReadsMemory(Call))
+ return false;
+
+ MemDepResult local_dep = MD->getDependency(Call);
+ if (!local_dep.isNonLocal())
+ return false;
+
+ const MemoryDependenceResults::NonLocalDepInfo &deps =
+ MD->getNonLocalCallDependency(Call);
+
+ // Check to see if the Call has no function local clobber.
+ for (unsigned i = 0; i < deps.size(); i++) {
+ if (deps[i].getResult().isNonFuncLocal())
+ return true;
+ }
+ return false;
+}
+
/// Translate value number \p Num using phis, so that it has the values of
/// the phis in BB.
uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred,
@@ -1568,8 +1681,11 @@ uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred,
}
}
- if (uint32_t NewNum = expressionNumbering[Exp])
+ if (uint32_t NewNum = expressionNumbering[Exp]) {
+ if (Exp.opcode == Instruction::Call && NewNum != Num)
+ return areCallValsEqual(Num, NewNum, Pred, PhiBlock, Gvn) ? NewNum : Num;
return NewNum;
+ }
return Num;
}
@@ -1637,16 +1753,12 @@ void GVN::assignBlockRPONumber(Function &F) {
InvalidBlockRPONumbers = false;
}
-// Tries to replace instruction with const, using information from
-// ReplaceWithConstMap.
-bool GVN::replaceOperandsWithConsts(Instruction *Instr) const {
+bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const {
bool Changed = false;
for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) {
- Value *Operand = Instr->getOperand(OpNum);
- auto it = ReplaceWithConstMap.find(Operand);
- if (it != ReplaceWithConstMap.end()) {
- assert(!isa<Constant>(Operand) &&
- "Replacing constants with constants is invalid");
+ Value *Operand = Instr->getOperand(OpNum);
+ auto it = ReplaceOperandsWithMap.find(Operand);
+ if (it != ReplaceOperandsWithMap.end()) {
LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with "
<< *it->second << " in instruction " << *Instr << '\n');
Instr->setOperand(OpNum, it->second);
@@ -1976,6 +2088,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
MD = RunMD;
ImplicitControlFlowTracking ImplicitCFT(DT);
ICF = &ImplicitCFT;
+ this->LI = LI;
VN.setMemDep(MD);
ORE = RunORE;
InvalidBlockRPONumbers = true;
@@ -2037,13 +2150,13 @@ bool GVN::processBlock(BasicBlock *BB) {
return false;
// Clearing map before every BB because it can be used only for single BB.
- ReplaceWithConstMap.clear();
+ ReplaceOperandsWithMap.clear();
bool ChangedFunction = false;
for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
BI != BE;) {
- if (!ReplaceWithConstMap.empty())
- ChangedFunction |= replaceOperandsWithConsts(&*BI);
+ if (!ReplaceOperandsWithMap.empty())
+ ChangedFunction |= replaceOperandsForInBlockEquality(&*BI);
ChangedFunction |= processInstruction(&*BI);
if (InstrsToErase.empty()) {
@@ -2335,7 +2448,7 @@ bool GVN::performPRE(Function &F) {
/// the block inserted to the critical edge.
BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) {
BasicBlock *BB =
- SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT));
+ SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT, LI));
if (MD)
MD->invalidateCachedPredecessors();
InvalidBlockRPONumbers = true;
@@ -2350,7 +2463,7 @@ bool GVN::splitCriticalEdges() {
do {
std::pair<Instruction *, unsigned> Edge = toSplit.pop_back_val();
SplitCriticalEdge(Edge.first, Edge.second,
- CriticalEdgeSplittingOptions(DT));
+ CriticalEdgeSplittingOptions(DT, LI));
} while (!toSplit.empty());
if (MD) MD->invalidateCachedPredecessors();
InvalidBlockRPONumbers = true;
@@ -2456,18 +2569,26 @@ void GVN::addDeadBlock(BasicBlock *BB) {
if (DeadBlocks.count(B))
continue;
+ // First, split the critical edges. This might also create additional blocks
+ // to preserve LoopSimplify form and adjust edges accordingly.
SmallVector<BasicBlock *, 4> Preds(pred_begin(B), pred_end(B));
for (BasicBlock *P : Preds) {
if (!DeadBlocks.count(P))
continue;
- if (isCriticalEdge(P->getTerminator(), GetSuccessorNumber(P, B))) {
+ if (llvm::any_of(successors(P),
+ [B](BasicBlock *Succ) { return Succ == B; }) &&
+ isCriticalEdge(P->getTerminator(), B)) {
if (BasicBlock *S = splitCriticalEdges(P, B))
DeadBlocks.insert(P = S);
}
+ }
- for (BasicBlock::iterator II = B->begin(); isa<PHINode>(II); ++II) {
- PHINode &Phi = cast<PHINode>(*II);
+ // Now undef the incoming values from the dead predecessors.
+ for (BasicBlock *P : predecessors(B)) {
+ if (!DeadBlocks.count(P))
+ continue;
+ for (PHINode &Phi : B->phis()) {
Phi.setIncomingValueForBlock(P, UndefValue::get(Phi.getType()));
if (MD)
MD->invalidateCachedPointerInfo(&Phi);
@@ -2544,10 +2665,11 @@ public:
return Impl.runImpl(
F, getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
getAnalysis<AAResultsWrapperPass>().getAAResults(),
- NoMemDepAnalysis ? nullptr
- : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(),
+ NoMemDepAnalysis
+ ? nullptr
+ : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(),
LIWP ? &LIWP->getLoopInfo() : nullptr,
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE());
}
@@ -2556,6 +2678,7 @@ public:
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
if (!NoMemDepAnalysis)
AU.addRequired<MemoryDependenceWrapperPass>();
AU.addRequired<AAResultsWrapperPass>();
@@ -2563,6 +2686,8 @@ public:
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
AU.addPreserved<TargetLibraryInfoWrapperPass>();
+ AU.addPreserved<LoopInfoWrapperPass>();
+ AU.addPreservedID(LoopSimplifyID);
AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
}
diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp
index 7614599653c4..c87e41484b13 100644
--- a/lib/Transforms/Scalar/GVNHoist.cpp
+++ b/lib/Transforms/Scalar/GVNHoist.cpp
@@ -257,7 +257,7 @@ public:
GVNHoist(DominatorTree *DT, PostDominatorTree *PDT, AliasAnalysis *AA,
MemoryDependenceResults *MD, MemorySSA *MSSA)
: DT(DT), PDT(PDT), AA(AA), MD(MD), MSSA(MSSA),
- MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {}
+ MSSAUpdater(std::make_unique<MemorySSAUpdater>(MSSA)) {}
bool run(Function &F) {
NumFuncArgs = F.arg_size();
@@ -539,7 +539,7 @@ private:
// Check for unsafe hoistings due to side effects.
if (K == InsKind::Store) {
- if (hasEHOrLoadsOnPath(NewPt, dyn_cast<MemoryDef>(U), NBBsOnAllPaths))
+ if (hasEHOrLoadsOnPath(NewPt, cast<MemoryDef>(U), NBBsOnAllPaths))
return false;
} else if (hasEHOnPath(NewBB, OldBB, NBBsOnAllPaths))
return false;
@@ -889,19 +889,18 @@ private:
void updateAlignment(Instruction *I, Instruction *Repl) {
if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) {
- ReplacementLoad->setAlignment(
- std::min(ReplacementLoad->getAlignment(),
- cast<LoadInst>(I)->getAlignment()));
+ ReplacementLoad->setAlignment(MaybeAlign(std::min(
+ ReplacementLoad->getAlignment(), cast<LoadInst>(I)->getAlignment())));
++NumLoadsRemoved;
} else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) {
ReplacementStore->setAlignment(
- std::min(ReplacementStore->getAlignment(),
- cast<StoreInst>(I)->getAlignment()));
+ MaybeAlign(std::min(ReplacementStore->getAlignment(),
+ cast<StoreInst>(I)->getAlignment())));
++NumStoresRemoved;
} else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) {
ReplacementAlloca->setAlignment(
- std::max(ReplacementAlloca->getAlignment(),
- cast<AllocaInst>(I)->getAlignment()));
+ MaybeAlign(std::max(ReplacementAlloca->getAlignment(),
+ cast<AllocaInst>(I)->getAlignment())));
} else if (isa<CallInst>(Repl)) {
++NumCallsRemoved;
}
diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp
index e14f44bb7069..2697d7809568 100644
--- a/lib/Transforms/Scalar/GuardWidening.cpp
+++ b/lib/Transforms/Scalar/GuardWidening.cpp
@@ -591,7 +591,7 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
else
Result = RC.getCheckInst();
}
-
+ assert(Result && "Failed to find result value");
Result->setName("wide.chk");
}
return true;
diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp
index f9fc698a4a9b..5519a00c12c9 100644
--- a/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -124,6 +124,11 @@ static cl::opt<bool>
DisableLFTR("disable-lftr", cl::Hidden, cl::init(false),
cl::desc("Disable Linear Function Test Replace optimization"));
+static cl::opt<bool>
+LoopPredication("indvars-predicate-loops", cl::Hidden, cl::init(false),
+ cl::desc("Predicate conditions in read only loops"));
+
+
namespace {
struct RewritePhi;
@@ -144,7 +149,11 @@ class IndVarSimplify {
bool rewriteNonIntegerIVs(Loop *L);
bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI);
- bool optimizeLoopExits(Loop *L);
+ /// Try to eliminate loop exits based on analyzable exit counts.
+ bool optimizeLoopExits(Loop *L, SCEVExpander &Rewriter);
+ /// Try to form loop-invariant tests for loop exits by changing how many
+ /// iterations of the loop run when that is unobservable.
+ bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet);
bool rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter);
@@ -628,12 +637,30 @@ bool IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) {
// Okay, this instruction has a user outside of the current loop
// and varies predictably *inside* the loop. Evaluate the value it
- // contains when the loop exits, if possible.
+ // contains when the loop exits, if possible. We prefer to start with
+ // expressions which are true for all exits (so as to maximize
+ // expression reuse by the SCEVExpander), but resort to per-exit
+ // evaluation if that fails.
const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop());
- if (!SE->isLoopInvariant(ExitValue, L) ||
- !isSafeToExpand(ExitValue, *SE))
- continue;
-
+ if (isa<SCEVCouldNotCompute>(ExitValue) ||
+ !SE->isLoopInvariant(ExitValue, L) ||
+ !isSafeToExpand(ExitValue, *SE)) {
+ // TODO: This should probably be sunk into SCEV in some way; maybe a
+ // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for
+ // most SCEV expressions and other recurrence types (e.g. shift
+ // recurrences). Is there existing code we can reuse?
+ const SCEV *ExitCount = SE->getExitCount(L, PN->getIncomingBlock(i));
+ if (isa<SCEVCouldNotCompute>(ExitCount))
+ continue;
+ if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Inst)))
+ if (AddRec->getLoop() == L)
+ ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE);
+ if (isa<SCEVCouldNotCompute>(ExitValue) ||
+ !SE->isLoopInvariant(ExitValue, L) ||
+ !isSafeToExpand(ExitValue, *SE))
+ continue;
+ }
+
// Computing the value outside of the loop brings no benefit if it is
// definitely used inside the loop in a way which can not be optimized
// away. Avoid doing so unless we know we have a value which computes
@@ -804,7 +831,7 @@ bool IndVarSimplify::canLoopBeDeleted(
L->getExitingBlocks(ExitingBlocks);
SmallVector<BasicBlock *, 8> ExitBlocks;
L->getUniqueExitBlocks(ExitBlocks);
- if (ExitBlocks.size() > 1 || ExitingBlocks.size() > 1)
+ if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1)
return false;
BasicBlock *ExitBlock = ExitBlocks[0];
@@ -1654,6 +1681,10 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) {
return nullptr;
}
+ // If we reached this point then we are going to replace
+ // DU.NarrowUse with WideUse, so reattach any debug values now.
+ replaceAllDbgUsesWith(*DU.NarrowUse, *WideUse, *WideUse, *DT);
+
ExtendKindMap[DU.NarrowUse] = WideAddRec.second;
// Returning WideUse pushes it on the worklist.
return WideUse;
@@ -1779,14 +1810,9 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) {
DeadInsts.emplace_back(DU.NarrowDef);
}
- // Attach any debug information to the new PHI. Since OrigPhi and WidePHI
- // evaluate the same recurrence, we can just copy the debug info over.
- SmallVector<DbgValueInst *, 1> DbgValues;
- llvm::findDbgValues(DbgValues, OrigPhi);
- auto *MDPhi = MetadataAsValue::get(WidePhi->getContext(),
- ValueAsMetadata::get(WidePhi));
- for (auto &DbgValue : DbgValues)
- DbgValue->setOperand(0, MDPhi);
+ // Attach any debug information to the new PHI.
+ replaceAllDbgUsesWith(*OrigPhi, *WidePhi, *WidePhi, *DT);
+
return WidePhi;
}
@@ -1817,8 +1843,8 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef,
auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS));
auto CmpConstrainedLHSRange =
ConstantRange::makeAllowedICmpRegion(P, CmpRHSRange);
- auto NarrowDefRange =
- CmpConstrainedLHSRange.addWithNoSignedWrap(*NarrowDefRHS);
+ auto NarrowDefRange = CmpConstrainedLHSRange.addWithNoWrap(
+ *NarrowDefRHS, OverflowingBinaryOperator::NoSignedWrap);
updatePostIncRangeInfo(NarrowDef, NarrowUser, NarrowDefRange);
};
@@ -2242,8 +2268,8 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy())
continue;
- const auto *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Phi));
-
+ const auto *AR = cast<SCEVAddRecExpr>(SE->getSCEV(Phi));
+
// AR may be a pointer type, while BECount is an integer type.
// AR may be wider than BECount. With eq/ne tests overflow is immaterial.
// AR may not be a narrower type, or we may never exit.
@@ -2624,74 +2650,125 @@ bool IndVarSimplify::sinkUnusedInvariants(Loop *L) {
return MadeAnyChanges;
}
-bool IndVarSimplify::optimizeLoopExits(Loop *L) {
+/// Return a symbolic upper bound for the backedge taken count of the loop.
+/// This is more general than getConstantMaxBackedgeTakenCount as it returns
+/// an arbitrary expression as opposed to only constants.
+/// TODO: Move into the ScalarEvolution class.
+static const SCEV* getMaxBackedgeTakenCount(ScalarEvolution &SE,
+ DominatorTree &DT, Loop *L) {
SmallVector<BasicBlock*, 16> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
// Form an expression for the maximum exit count possible for this loop. We
// merge the max and exact information to approximate a version of
- // getMaxBackedgeTakenInfo which isn't restricted to just constants.
- // TODO: factor this out as a version of getMaxBackedgeTakenCount which
- // isn't guaranteed to return a constant.
+ // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
SmallVector<const SCEV*, 4> ExitCounts;
- const SCEV *MaxConstEC = SE->getMaxBackedgeTakenCount(L);
+ const SCEV *MaxConstEC = SE.getConstantMaxBackedgeTakenCount(L);
if (!isa<SCEVCouldNotCompute>(MaxConstEC))
ExitCounts.push_back(MaxConstEC);
for (BasicBlock *ExitingBB : ExitingBlocks) {
- const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
+ const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
if (!isa<SCEVCouldNotCompute>(ExitCount)) {
- assert(DT->dominates(ExitingBB, L->getLoopLatch()) &&
+ assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
"We should only have known counts for exiting blocks that "
"dominate latch!");
ExitCounts.push_back(ExitCount);
}
}
if (ExitCounts.empty())
- return false;
- const SCEV *MaxExitCount = SE->getUMinFromMismatchedTypes(ExitCounts);
+ return SE.getCouldNotCompute();
+ return SE.getUMinFromMismatchedTypes(ExitCounts);
+}
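
For illustration only, not part of the patch: the helper above seeds the candidate list with the constant maximum (when known) and then takes the unsigned minimum over every known per-exit count. A standalone sketch of that merge with plain integers and hypothetical names:

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Conceptual analogue of getMaxBackedgeTakenCount above: the symbolic bound
// is the unsigned minimum over all known per-exit counts, optionally seeded
// with the constant maximum. nullopt plays the role of SCEVCouldNotCompute.
static std::optional<uint64_t>
symbolicMaxBackedgeTakenCount(std::optional<uint64_t> ConstMax,
                              const std::vector<uint64_t> &ExitCounts) {
  std::optional<uint64_t> Result = ConstMax;
  for (uint64_t EC : ExitCounts)
    Result = Result ? std::min(*Result, EC) : EC;
  return Result;
}

int main() {
  // Constant max of 100 merged with exit counts {16, 40}: the bound is 16.
  auto Bound = symbolicMaxBackedgeTakenCount(100, {16, 40});
  return (Bound && *Bound == 16) ? 0 : 1;
}
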
- bool Changed = false;
- for (BasicBlock *ExitingBB : ExitingBlocks) {
+bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
+ SmallVector<BasicBlock*, 16> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+
+ // Remove all exits which aren't both rewriteable and analyzeable.
+ auto NewEnd = llvm::remove_if(ExitingBlocks,
+ [&](BasicBlock *ExitingBB) {
    // If our exiting block exits multiple loops, we can only rewrite the
// innermost one. Otherwise, we're changing how many times the innermost
// loop runs before it exits.
if (LI->getLoopFor(ExitingBB) != L)
- continue;
+ return true;
// Can't rewrite non-branch yet.
BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
if (!BI)
- continue;
+ return true;
// If already constant, nothing to do.
if (isa<Constant>(BI->getCondition()))
- continue;
+ return true;
const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
if (isa<SCEVCouldNotCompute>(ExitCount))
- continue;
+ return true;
+ return false;
+ });
+ ExitingBlocks.erase(NewEnd, ExitingBlocks.end());
+
+ if (ExitingBlocks.empty())
+ return false;
+
+ // Get a symbolic upper bound on the loop backedge taken count.
+ const SCEV *MaxExitCount = getMaxBackedgeTakenCount(*SE, *DT, L);
+ if (isa<SCEVCouldNotCompute>(MaxExitCount))
+ return false;
+
+  // Visit our exit blocks in order of dominance. We know from the fact that
+  // all remaining exits are analyzable that there must be a total dominance
+  // order between them, as each must dominate the latch. The visit order
+  // only matters for the provably equal case.
+ llvm::sort(ExitingBlocks,
+ [&](BasicBlock *A, BasicBlock *B) {
+ // std::sort sorts in ascending order, so we want the inverse of
+ // the normal dominance relation.
+ if (DT->properlyDominates(A, B)) return true;
+ if (DT->properlyDominates(B, A)) return false;
+ llvm_unreachable("expected total dominance order!");
+ });
+#ifndef NDEBUG
+ for (unsigned i = 1; i < ExitingBlocks.size(); i++) {
+ assert(DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]));
+ }
+#endif
+
+ auto FoldExit = [&](BasicBlock *ExitingBB, bool IsTaken) {
+ BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
+ bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
+ auto *OldCond = BI->getCondition();
+ auto *NewCond = ConstantInt::get(OldCond->getType(),
+ IsTaken ? ExitIfTrue : !ExitIfTrue);
+ BI->setCondition(NewCond);
+ if (OldCond->use_empty())
+ DeadInsts.push_back(OldCond);
+ };
+ bool Changed = false;
+ SmallSet<const SCEV*, 8> DominatingExitCounts;
+ for (BasicBlock *ExitingBB : ExitingBlocks) {
+ const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
+ assert(!isa<SCEVCouldNotCompute>(ExitCount) && "checked above");
+
// If we know we'd exit on the first iteration, rewrite the exit to
// reflect this. This does not imply the loop must exit through this
// exit; there may be an earlier one taken on the first iteration.
// TODO: Given we know the backedge can't be taken, we should go ahead
// and break it. Or at least, kill all the header phis and simplify.
if (ExitCount->isZero()) {
- bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
- auto *OldCond = BI->getCondition();
- auto *NewCond = ExitIfTrue ? ConstantInt::getTrue(OldCond->getType()) :
- ConstantInt::getFalse(OldCond->getType());
- BI->setCondition(NewCond);
- if (OldCond->use_empty())
- DeadInsts.push_back(OldCond);
+ FoldExit(ExitingBB, true);
Changed = true;
continue;
}
- // If we end up with a pointer exit count, bail.
+ // If we end up with a pointer exit count, bail. Note that we can end up
+ // with a pointer exit count for one exiting block, and not for another in
+ // the same loop.
if (!ExitCount->getType()->isIntegerTy() ||
!MaxExitCount->getType()->isIntegerTy())
- return false;
+ continue;
Type *WiderType =
SE->getWiderType(MaxExitCount->getType(), ExitCount->getType());
@@ -2700,35 +2777,198 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L) {
assert(MaxExitCount->getType() == ExitCount->getType());
// Can we prove that some other exit must be taken strictly before this
- // one? TODO: handle cases where ule is known, and equality is covered
- // by a dominating exit
+ // one?
if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT,
MaxExitCount, ExitCount)) {
- bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
- auto *OldCond = BI->getCondition();
- auto *NewCond = ExitIfTrue ? ConstantInt::getFalse(OldCond->getType()) :
- ConstantInt::getTrue(OldCond->getType());
- BI->setCondition(NewCond);
- if (OldCond->use_empty())
- DeadInsts.push_back(OldCond);
+ FoldExit(ExitingBB, false);
Changed = true;
continue;
}
- // TODO: If we can prove that the exiting iteration is equal to the exit
- // count for this exit and that no previous exit oppurtunities exist within
- // the loop, then we can discharge all other exits. (May fall out of
- // previous TODO.)
-
- // TODO: If we can't prove any relation between our exit count and the
- // loops exit count, but taking this exit doesn't require actually running
- // the loop (i.e. no side effects, no computed values used in exit), then
- // we can replace the exit test with a loop invariant test which exits on
- // the first iteration.
+  // As we run, keep track of which exit counts we've encountered. If we
+  // find a duplicate, we've found an exit which would have been taken on the
+  // same iteration as an earlier exit which (by the visit order) dominates
+  // it; the later exit is therefore dead and can be folded away.
+ if (!DominatingExitCounts.insert(ExitCount).second) {
+ FoldExit(ExitingBB, false);
+ Changed = true;
+ continue;
+ }
+
+  // TODO: There might be another opportunity to leverage SCEV's reasoning
+  // here. If we kept track of the min of dominating exits so far, we could
+  // discharge exits with EC >= MDEC. This is less powerful than the existing
+  // transform (since later exits aren't considered), but potentially more
+  // powerful for any case where SCEV can prove a >=u b, but neither a == b
+  // nor a >u b. Such a case is not currently known.
}
return Changed;
}
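
A standalone sketch, not part of the patch, of the dominance-order bookkeeping in optimizeLoopExits above: exits are visited from dominator to dominated, and any exit whose count repeats one already seen can be folded to "never taken", since the dominating exit fires on the same iteration. Names are hypothetical.

#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>

struct ExitInfo {
  std::string Name;   // exiting block, assumed listed in dominance order
  uint64_t ExitCount; // iteration on which this exit would be taken
};

// Returns the exits that can be rewritten to "never taken" because a
// dominating exit has an identical exit count (the DominatingExitCounts
// set in the code above).
static std::vector<std::string>
foldDominatedDuplicates(const std::vector<ExitInfo> &ExitsInDomOrder) {
  std::unordered_set<uint64_t> Seen;
  std::vector<std::string> Foldable;
  for (const ExitInfo &E : ExitsInDomOrder)
    if (!Seen.insert(E.ExitCount).second)
      Foldable.push_back(E.Name); // a dominating exit fires first
  return Foldable;
}

int main() {
  // exit.b can never be reached: exit.a dominates it and fires on the same
  // iteration (count 8). exit.c has a distinct count and is kept.
  auto Foldable = foldDominatedDuplicates(
      {{"exit.a", 8}, {"exit.b", 8}, {"exit.c", 12}});
  return (Foldable.size() == 1 && Foldable[0] == "exit.b") ? 0 : 1;
}
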
+bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
+ SmallVector<BasicBlock*, 16> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+
+ bool Changed = false;
+
+ // Finally, see if we can rewrite our exit conditions into a loop invariant
+ // form. If we have a read-only loop, and we can tell that we must exit down
+ // a path which does not need any of the values computed within the loop, we
+  // can rewrite the loop to exit on the first iteration. Note that this
+  // neither a) tells us the loop exits on the first iteration (unless *all*
+  // exits are predicatable) nor b) tells us *which* exit might be taken.
+  // This transformation looks a lot like a restricted form of dead loop
+  // elimination, but restricted to read-only loops and without necessarily
+  // needing to kill the loop entirely.
+ if (!LoopPredication)
+ return Changed;
+
+ if (!SE->hasLoopInvariantBackedgeTakenCount(L))
+ return Changed;
+
+ // Note: ExactBTC is the exact backedge taken count *iff* the loop exits
+ // through *explicit* control flow. We have to eliminate the possibility of
+ // implicit exits (see below) before we know it's truly exact.
+ const SCEV *ExactBTC = SE->getBackedgeTakenCount(L);
+ if (isa<SCEVCouldNotCompute>(ExactBTC) ||
+ !SE->isLoopInvariant(ExactBTC, L) ||
+ !isSafeToExpand(ExactBTC, *SE))
+ return Changed;
+
+ auto BadExit = [&](BasicBlock *ExitingBB) {
+ // If our exiting block exits multiple loops, we can only rewrite the
+ // innermost one. Otherwise, we're changing how many times the innermost
+ // loop runs before it exits.
+ if (LI->getLoopFor(ExitingBB) != L)
+ return true;
+
+ // Can't rewrite non-branch yet.
+ BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
+ if (!BI)
+ return true;
+
+ // If already constant, nothing to do.
+ if (isa<Constant>(BI->getCondition()))
+ return true;
+
+    // If the exit block has phis, we need to be able to compute the values
+    // within the loop which contains them. This assumes trivial LCSSA phis
+    // have already been removed; TODO: generalize.
+ BasicBlock *ExitBlock =
+ BI->getSuccessor(L->contains(BI->getSuccessor(0)) ? 1 : 0);
+ if (!ExitBlock->phis().empty())
+ return true;
+
+ const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
+    assert(!isa<SCEVCouldNotCompute>(ExitCount) &&
+           "implied by having exact trip count");
+ if (!SE->isLoopInvariant(ExitCount, L) ||
+ !isSafeToExpand(ExitCount, *SE))
+ return true;
+
+ return false;
+ };
+
+  // If we have any exits which can't be predicated themselves, then we can't
+  // predicate any exit which isn't guaranteed to execute before it. Consider
+  // two exits (a) and (b) which would both exit on the same iteration. If we
+  // can predicate (b), but not (a), and (a) precedes (b) along some path, then
+  // we could convert a loop from exiting through (a) to one exiting through
+  // (b). Note that this problem exists only for exits with the same exit
+  // count, and we could be more aggressive when exit counts are known unequal.
+ llvm::sort(ExitingBlocks,
+ [&](BasicBlock *A, BasicBlock *B) {
+ // std::sort sorts in ascending order, so we want the inverse of
+ // the normal dominance relation, plus a tie breaker for blocks
+ // unordered by dominance.
+ if (DT->properlyDominates(A, B)) return true;
+ if (DT->properlyDominates(B, A)) return false;
+ return A->getName() < B->getName();
+ });
+ // Check to see if our exit blocks are a total order (i.e. a linear chain of
+ // exits before the backedge). If they aren't, reasoning about reachability
+ // is complicated and we choose not to for now.
+ for (unsigned i = 1; i < ExitingBlocks.size(); i++)
+ if (!DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]))
+ return Changed;
+
+  // Given our sorted total order, we know that exit[j] must be evaluated
+  // after all exit[i] such that j > i.
+ for (unsigned i = 0, e = ExitingBlocks.size(); i < e; i++)
+ if (BadExit(ExitingBlocks[i])) {
+ ExitingBlocks.resize(i);
+ break;
+ }
+
+ if (ExitingBlocks.empty())
+ return Changed;
+
+  // We rely on not being able to reach an exiting block on a later iteration
+  // than its statically computed exit count. The implementation of
+  // getExitCount currently has this invariant, but assert it here so that
+  // breakage is obvious if this ever changes.
+ assert(llvm::all_of(ExitingBlocks, [&](BasicBlock *ExitingBB) {
+ return DT->dominates(ExitingBB, L->getLoopLatch());
+ }));
+
+  // At this point, ExitingBlocks consists of only those blocks which are
+  // predicatable. Given that, we know we have at least one exit we can
+  // predicate if the loop doesn't have side effects and doesn't have any
+  // implicit exits (because then our exact BTC isn't actually exact).
+  // TODO: As structured, this is O(I^2) for loop nests. We could bail out
+  // for outer loops, but that seems less than ideal. MemorySSA can find
+  // memory writes; is that enough for *all* side effects?
+ for (BasicBlock *BB : L->blocks())
+ for (auto &I : *BB)
+ // TODO:isGuaranteedToTransfer
+ if (I.mayHaveSideEffects() || I.mayThrow())
+ return Changed;
+
+ // Finally, do the actual predication for all predicatable blocks. A couple
+ // of notes here:
+ // 1) We don't bother to constant fold dominated exits with identical exit
+ // counts; that's simply a form of CSE/equality propagation and we leave
+ // it for dedicated passes.
+ // 2) We insert the comparison at the branch. Hoisting introduces additional
+ // legality constraints and we leave that to dedicated logic. We want to
+ // predicate even if we can't insert a loop invariant expression as
+ // peeling or unrolling will likely reduce the cost of the otherwise loop
+ // varying check.
+ Rewriter.setInsertPoint(L->getLoopPreheader()->getTerminator());
+ IRBuilder<> B(L->getLoopPreheader()->getTerminator());
+  Value *ExactBTCV = nullptr; // lazily generated if needed
+ for (BasicBlock *ExitingBB : ExitingBlocks) {
+ const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
+
+ auto *BI = cast<BranchInst>(ExitingBB->getTerminator());
+ Value *NewCond;
+ if (ExitCount == ExactBTC) {
+ NewCond = L->contains(BI->getSuccessor(0)) ?
+ B.getFalse() : B.getTrue();
+ } else {
+ Value *ECV = Rewriter.expandCodeFor(ExitCount);
+ if (!ExactBTCV)
+ ExactBTCV = Rewriter.expandCodeFor(ExactBTC);
+ Value *RHS = ExactBTCV;
+ if (ECV->getType() != RHS->getType()) {
+ Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
+ ECV = B.CreateZExt(ECV, WiderTy);
+ RHS = B.CreateZExt(RHS, WiderTy);
+ }
+ auto Pred = L->contains(BI->getSuccessor(0)) ?
+ ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
+ NewCond = B.CreateICmp(Pred, ECV, RHS);
+ }
+ Value *OldCond = BI->getCondition();
+ BI->setCondition(NewCond);
+ if (OldCond->use_empty())
+ DeadInsts.push_back(OldCond);
+ Changed = true;
+ }
+
+ return Changed;
+}
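
For illustration only, not part of the patch: after predication, each remaining exit tests the loop-invariant condition "this exit's count equals the exact backedge-taken count", and an exit whose count is syntactically ExactBTC is folded to exit unconditionally. A standalone sketch with plain integers standing in for the SCEV expressions:

#include <cstdint>
#include <cstdio>

// The loop-invariant test installed on a predicatable exit: the exit fires
// iff its exit count equals the loop's exact backedge-taken count. (The code
// above compares SCEV expressions and expands the comparison in IR; plain
// integers stand in here.)
static bool exitTaken(uint64_t ExitCount, uint64_t ExactBTC) {
  return ExitCount == ExactBTC;
}

int main() {
  // Two exits with counts 7 and 10 in a loop whose exact BTC is 10: only the
  // second exit's new condition is true, and it is loop invariant.
  std::printf("%d %d\n", exitTaken(7, 10), exitTaken(10, 10));
  return 0;
}
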
+
//===----------------------------------------------------------------------===//
// IndVarSimplify driver. Manage several subpasses of IV simplification.
//===----------------------------------------------------------------------===//
@@ -2755,7 +2995,10 @@ bool IndVarSimplify::run(Loop *L) {
// transform them to use integer recurrences.
Changed |= rewriteNonIntegerIVs(L);
+#ifndef NDEBUG
+ // Used below for a consistency check only
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+#endif
// Create a rewriter object which we'll use to transform the code with.
SCEVExpander Rewriter(*SE, DL, "indvars");
@@ -2772,20 +3015,22 @@ bool IndVarSimplify::run(Loop *L) {
Rewriter.disableCanonicalMode();
Changed |= simplifyAndExtend(L, Rewriter, LI);
- // Check to see if this loop has a computable loop-invariant execution count.
- // If so, this means that we can compute the final value of any expressions
+ // Check to see if we can compute the final value of any expressions
// that are recurrent in the loop, and substitute the exit values from the
- // loop into any instructions outside of the loop that use the final values of
- // the current expressions.
- //
- if (ReplaceExitValue != NeverRepl &&
- !isa<SCEVCouldNotCompute>(BackedgeTakenCount))
+ // loop into any instructions outside of the loop that use the final values
+ // of the current expressions.
+ if (ReplaceExitValue != NeverRepl)
Changed |= rewriteLoopExitValues(L, Rewriter);
// Eliminate redundant IV cycles.
NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts);
- Changed |= optimizeLoopExits(L);
+ // Try to eliminate loop exits based on analyzeable exit counts
+ Changed |= optimizeLoopExits(L, Rewriter);
+
+ // Try to form loop invariant tests for loop exits by changing how many
+ // iterations of the loop run when that is unobservable.
+ Changed |= predicateLoopExits(L, Rewriter);
// If we have a trip count expression, rewrite the loop's exit condition
// using it.
@@ -2825,7 +3070,7 @@ bool IndVarSimplify::run(Loop *L) {
// that our definition of "high cost" is not exactly principled.
if (Rewriter.isHighCostExpansion(ExitCount, L))
continue;
-
+
// Check preconditions for proper SCEVExpander operation. SCEV does not
// express SCEVExpander's dependencies, such as LoopSimplify. Instead
// any pass that uses the SCEVExpander must do it. This does not work
@@ -2924,7 +3169,7 @@ struct IndVarSimplifyLegacyPass : public LoopPass {
auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- auto *TLI = TLIP ? &TLIP->getTLI() : nullptr;
+ auto *TLI = TLIP ? &TLIP->getTLI(*L->getHeader()->getParent()) : nullptr;
auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr;
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 5f0e2001c73d..e7e73a132fbe 100644
--- a/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -141,6 +141,8 @@ using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>;
/// InferAddressSpaces
class InferAddressSpaces : public FunctionPass {
+ const TargetTransformInfo *TTI;
+
/// Target specific address space which uses of should be replaced if
/// possible.
unsigned FlatAddrSpace;
@@ -264,17 +266,6 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
Module *M = II->getParent()->getParent()->getParent();
switch (II->getIntrinsicID()) {
- case Intrinsic::amdgcn_atomic_inc:
- case Intrinsic::amdgcn_atomic_dec:
- case Intrinsic::amdgcn_ds_fadd:
- case Intrinsic::amdgcn_ds_fmin:
- case Intrinsic::amdgcn_ds_fmax: {
- const ConstantInt *IsVolatile = dyn_cast<ConstantInt>(II->getArgOperand(4));
- if (!IsVolatile || !IsVolatile->isZero())
- return false;
-
- LLVM_FALLTHROUGH;
- }
case Intrinsic::objectsize: {
Type *DestTy = II->getType();
Type *SrcTy = NewV->getType();
@@ -285,25 +276,27 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
return true;
}
default:
- return false;
+ return TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
}
}
-// TODO: Move logic to TTI?
void InferAddressSpaces::collectRewritableIntrinsicOperands(
IntrinsicInst *II, std::vector<std::pair<Value *, bool>> &PostorderStack,
DenseSet<Value *> &Visited) const {
- switch (II->getIntrinsicID()) {
+ auto IID = II->getIntrinsicID();
+ switch (IID) {
case Intrinsic::objectsize:
- case Intrinsic::amdgcn_atomic_inc:
- case Intrinsic::amdgcn_atomic_dec:
- case Intrinsic::amdgcn_ds_fadd:
- case Intrinsic::amdgcn_ds_fmin:
- case Intrinsic::amdgcn_ds_fmax:
appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
PostorderStack, Visited);
break;
default:
+ SmallVector<int, 2> OpIndexes;
+ if (TTI->collectFlatAddressOperands(OpIndexes, IID)) {
+ for (int Idx : OpIndexes) {
+ appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(Idx),
+ PostorderStack, Visited);
+ }
+ }
break;
}
}
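
An editorial sketch of the idea behind this change, not the actual TTI interface: target-specific knowledge of which intrinsic operands are flat pointers now sits behind the collectFlatAddressOperands / rewriteIntrinsicWithAddressSpace queries used above, so the generic pass only needs a per-intrinsic list of operand indexes. The table and names below are hypothetical.

#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the target query: map an intrinsic ID to the
// operand indexes that hold flat pointers the pass may rewrite. Returns
// false when the target has nothing to say about this intrinsic.
static bool collectFlatPointerOperands(std::vector<int> &OpIndexes,
                                       unsigned IntrinsicID) {
  static const std::unordered_map<unsigned, std::vector<int>> Table = {
      {0x1001, {0}}, // e.g. a target atomic whose pointer is operand 0
  };
  auto It = Table.find(IntrinsicID);
  if (It == Table.end())
    return false;
  OpIndexes = It->second;
  return true;
}

int main() {
  std::vector<int> Idx;
  bool Known = collectFlatPointerOperands(Idx, 0x1001); // hypothetical ID
  return (Known && Idx.size() == 1 && Idx[0] == 0) ? 0 : 1;
}
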
@@ -631,11 +624,10 @@ bool InferAddressSpaces::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
- const TargetTransformInfo &TTI =
- getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
if (FlatAddrSpace == UninitializedAddressSpace) {
- FlatAddrSpace = TTI.getFlatAddressSpace();
+ FlatAddrSpace = TTI->getFlatAddressSpace();
if (FlatAddrSpace == UninitializedAddressSpace)
return false;
}
@@ -650,7 +642,7 @@ bool InferAddressSpaces::runOnFunction(Function &F) {
// Changes the address spaces of the flat address expressions who are inferred
// to point to a specific address space.
- return rewriteWithNewAddressSpaces(TTI, Postorder, InferredAddrSpace, &F);
+ return rewriteWithNewAddressSpaces(*TTI, Postorder, InferredAddrSpace, &F);
}
// Constants need to be tracked through RAUW to handle cases with nested
diff --git a/lib/Transforms/Scalar/InstSimplifyPass.cpp b/lib/Transforms/Scalar/InstSimplifyPass.cpp
index 6616364ab203..ec28f790f252 100644
--- a/lib/Transforms/Scalar/InstSimplifyPass.cpp
+++ b/lib/Transforms/Scalar/InstSimplifyPass.cpp
@@ -33,37 +33,39 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ,
bool Changed = false;
do {
- for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {
- // Here be subtlety: the iterator must be incremented before the loop
- // body (not sure why), so a range-for loop won't work here.
- for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
- Instruction *I = &*BI++;
- // The first time through the loop ToSimplify is empty and we try to
- // simplify all instructions. On later iterations ToSimplify is not
+ for (BasicBlock &BB : F) {
+ // Unreachable code can take on strange forms that we are not prepared to
+ // handle. For example, an instruction may have itself as an operand.
+ if (!SQ.DT->isReachableFromEntry(&BB))
+ continue;
+
+ SmallVector<Instruction *, 8> DeadInstsInBB;
+ for (Instruction &I : BB) {
+ // The first time through the loop, ToSimplify is empty and we try to
+ // simplify all instructions. On later iterations, ToSimplify is not
// empty and we only bother simplifying instructions that are in it.
- if (!ToSimplify->empty() && !ToSimplify->count(I))
+ if (!ToSimplify->empty() && !ToSimplify->count(&I))
continue;
- // Don't waste time simplifying unused instructions.
- if (!I->use_empty()) {
- if (Value *V = SimplifyInstruction(I, SQ, ORE)) {
+ // Don't waste time simplifying dead/unused instructions.
+ if (isInstructionTriviallyDead(&I)) {
+ DeadInstsInBB.push_back(&I);
+ Changed = true;
+ } else if (!I.use_empty()) {
+ if (Value *V = SimplifyInstruction(&I, SQ, ORE)) {
// Mark all uses for resimplification next time round the loop.
- for (User *U : I->users())
+ for (User *U : I.users())
Next->insert(cast<Instruction>(U));
- I->replaceAllUsesWith(V);
+ I.replaceAllUsesWith(V);
++NumSimplified;
Changed = true;
+ // A call can get simplified, but it may not be trivially dead.
+ if (isInstructionTriviallyDead(&I))
+ DeadInstsInBB.push_back(&I);
}
}
- if (RecursivelyDeleteTriviallyDeadInstructions(I, SQ.TLI)) {
- // RecursivelyDeleteTriviallyDeadInstruction can remove more than one
- // instruction, so simply incrementing the iterator does not work.
- // When instructions get deleted re-iterate instead.
- BI = BB->begin();
- BE = BB->end();
- Changed = true;
- }
}
+ RecursivelyDeleteTriviallyDeadInstructions(DeadInstsInBB, SQ.TLI);
}
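
For illustration only, not part of the patch: the rewrite above stops restarting block iteration after every deletion and instead collects trivially dead instructions per block, erasing them once the walk is done. The same deferred-erase pattern in standalone form:

#include <list>
#include <vector>

// Erasing elements while iterating invalidates the current iterator, which
// is why the old code restarted from BB->begin(). Collecting victims first
// and erasing afterwards visits each element exactly once; std::list erase
// leaves iterators to the remaining elements valid.
template <typename T, typename Pred>
static void eraseMatching(std::list<T> &Items, Pred IsDead) {
  std::vector<typename std::list<T>::iterator> ToErase;
  for (auto It = Items.begin(), End = Items.end(); It != End; ++It)
    if (IsDead(*It))
      ToErase.push_back(It);
  for (auto It : ToErase)
    Items.erase(It);
}

int main() {
  std::list<int> Items{1, 2, 3, 4};
  eraseMatching(Items, [](int X) { return X % 2 == 0; }); // drop "dead" items
  return Items.size() == 2 ? 0 : 1;
}
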
// Place the list of instructions to simplify on the next loop iteration
@@ -90,7 +92,7 @@ struct InstSimplifyLegacyPass : public FunctionPass {
AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
}
- /// runOnFunction - Remove instructions that simplify.
+ /// Remove instructions that simplify.
bool runOnFunction(Function &F) override {
if (skipFunction(F))
return false;
@@ -98,7 +100,7 @@ struct InstSimplifyLegacyPass : public FunctionPass {
const DominatorTree *DT =
&getAnalysis<DominatorTreeWrapperPass>().getDomTree();
const TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
AssumptionCache *AC =
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
OptimizationRemarkEmitter *ORE =
diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp
index b86bf2fefbe5..0cf00baaa24a 100644
--- a/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/lib/Transforms/Scalar/JumpThreading.cpp
@@ -224,13 +224,21 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
BasicBlock *PhiBB) -> std::pair<BasicBlock *, BasicBlock *> {
auto *PredBB = IncomingBB;
auto *SuccBB = PhiBB;
+ SmallPtrSet<BasicBlock *, 16> Visited;
while (true) {
BranchInst *PredBr = dyn_cast<BranchInst>(PredBB->getTerminator());
if (PredBr && PredBr->isConditional())
return {PredBB, SuccBB};
+ Visited.insert(PredBB);
auto *SinglePredBB = PredBB->getSinglePredecessor();
if (!SinglePredBB)
return {nullptr, nullptr};
+
+      // Stop searching when SinglePredBB has been visited. This means we
+      // have found an unreachable loop.
+ if (Visited.count(SinglePredBB))
+ return {nullptr, nullptr};
+
SuccBB = PredBB;
PredBB = SinglePredBB;
}
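
A standalone sketch, not part of the patch, of the added cycle guard: walking a chain of single predecessors must stop once a block repeats, which can only happen inside an unreachable loop. Integers and names below are hypothetical stand-ins for basic blocks.

#include <unordered_map>
#include <unordered_set>

// Follow a single-predecessor chain from Start, returning the first node
// that satisfies Stop, or -1 if the chain ends or revisits a node (a cycle,
// which only unreachable code can form).
static int walkSinglePreds(int Start,
                           const std::unordered_map<int, int> &SinglePred,
                           bool (*Stop)(int)) {
  std::unordered_set<int> Visited;
  int Cur = Start;
  while (true) {
    if (Stop(Cur))
      return Cur;
    if (!Visited.insert(Cur).second)
      return -1; // revisited: unreachable loop, give up
    auto It = SinglePred.find(Cur);
    if (It == SinglePred.end())
      return -1; // no unique predecessor
    Cur = It->second;
  }
}

static bool isEntry(int N) { return N == 0; }

int main() {
  // 3 -> 2 -> 1 -> 0 reaches the entry; 6 -> 5 -> 6 is an unreachable cycle.
  std::unordered_map<int, int> Pred{{3, 2}, {2, 1}, {1, 0}, {6, 5}, {5, 6}};
  return (walkSinglePreds(3, Pred, isEntry) == 0 &&
          walkSinglePreds(6, Pred, isEntry) == -1) ? 0 : 1;
}
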
@@ -253,7 +261,9 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
return;
BasicBlock *PredBB = PredOutEdge.first;
- BranchInst *PredBr = cast<BranchInst>(PredBB->getTerminator());
+ BranchInst *PredBr = dyn_cast<BranchInst>(PredBB->getTerminator());
+ if (!PredBr)
+ return;
uint64_t PredTrueWeight, PredFalseWeight;
// FIXME: We currently only set the profile data when it is missing.
@@ -286,7 +296,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
bool JumpThreading::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
- auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
// Get DT analysis before LVI. When LVI is initialized it conditionally adds
// DT if it's available.
auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
@@ -1461,7 +1471,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) {
"Can't handle critical edge here!");
LoadInst *NewVal = new LoadInst(
LoadI->getType(), LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred),
- LoadI->getName() + ".pr", false, LoadI->getAlignment(),
+ LoadI->getName() + ".pr", false, MaybeAlign(LoadI->getAlignment()),
LoadI->getOrdering(), LoadI->getSyncScopeID(),
UnavailablePred->getTerminator());
NewVal->setDebugLoc(LoadI->getDebugLoc());
@@ -2423,7 +2433,7 @@ void JumpThreadingPass::UnfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
// |-----
// v
// BB
- BranchInst *PredTerm = dyn_cast<BranchInst>(Pred->getTerminator());
+ BranchInst *PredTerm = cast<BranchInst>(Pred->getTerminator());
BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "select.unfold",
BB->getParent(), BB);
// Move the unconditional branch to NewBB.
diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp
index d9dda4cef2d2..6ce4831a7359 100644
--- a/lib/Transforms/Scalar/LICM.cpp
+++ b/lib/Transforms/Scalar/LICM.cpp
@@ -220,7 +220,8 @@ struct LegacyLICMPass : public LoopPass {
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
&getAnalysis<LoopInfoWrapperPass>().getLoopInfo(),
&getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+ *L->getHeader()->getParent()),
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
*L->getHeader()->getParent()),
SE ? &SE->getSE() : nullptr, MSSA, &ORE, false);
@@ -294,7 +295,7 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM,
PA.preserve<DominatorTreeAnalysis>();
PA.preserve<LoopAnalysis>();
- if (EnableMSSALoopDependency)
+ if (AR.MSSA)
PA.preserve<MemorySSAAnalysis>();
return PA;
@@ -330,6 +331,12 @@ bool LoopInvariantCodeMotion::runOnLoop(
assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form.");
+ // If this loop has metadata indicating that LICM is not to be performed then
+ // just exit.
+ if (hasDisableLICMTransformsHint(L)) {
+ return false;
+ }
+
std::unique_ptr<AliasSetTracker> CurAST;
std::unique_ptr<MemorySSAUpdater> MSSAU;
bool NoOfMemAccTooLarge = false;
@@ -340,7 +347,7 @@ bool LoopInvariantCodeMotion::runOnLoop(
CurAST = collectAliasInfoForLoop(L, LI, AA);
} else {
LLVM_DEBUG(dbgs() << "LICM: Using MemorySSA.\n");
- MSSAU = make_unique<MemorySSAUpdater>(MSSA);
+ MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
unsigned AccessCapCount = 0;
for (auto *BB : L->getBlocks()) {
@@ -956,7 +963,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
// Now that we've finished hoisting make sure that LI and DT are still
// valid.
-#ifndef NDEBUG
+#ifdef EXPENSIVE_CHECKS
if (Changed) {
assert(DT->verify(DominatorTree::VerificationLevel::Fast) &&
"Dominator tree verification failed");
@@ -1026,7 +1033,8 @@ namespace {
bool isHoistableAndSinkableInst(Instruction &I) {
// Only these instructions are hoistable/sinkable.
return (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<CallInst>(I) ||
- isa<FenceInst>(I) || isa<BinaryOperator>(I) || isa<CastInst>(I) ||
+ isa<FenceInst>(I) || isa<CastInst>(I) ||
+ isa<UnaryOperator>(I) || isa<BinaryOperator>(I) ||
isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || isa<CmpInst>(I) ||
isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) ||
isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) ||
@@ -1092,7 +1100,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
// in the same alias set as something that ends up being modified.
if (AA->pointsToConstantMemory(LI->getOperand(0)))
return true;
- if (LI->getMetadata(LLVMContext::MD_invariant_load))
+ if (LI->hasMetadata(LLVMContext::MD_invariant_load))
return true;
if (LI->isAtomic() && !TargetExecutesOncePerLoop)
@@ -1240,12 +1248,22 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
// FIXME: More precise: no Uses that alias SI.
if (!Flags->IsSink && !MSSA->dominates(SIMD, MU))
return false;
- } else if (const auto *MD = dyn_cast<MemoryDef>(&MA))
+ } else if (const auto *MD = dyn_cast<MemoryDef>(&MA)) {
if (auto *LI = dyn_cast<LoadInst>(MD->getMemoryInst())) {
(void)LI; // Silence warning.
assert(!LI->isUnordered() && "Expected unordered load");
return false;
}
+        // Any call, even if it does not clobber SI, may still be a use.
+ if (auto *CI = dyn_cast<CallInst>(MD->getMemoryInst())) {
+          // Check if the call may read from the memory location written
+ // to by SI. Check CI's attributes and arguments; the number of
+ // such checks performed is limited above by NoOfMemAccTooLarge.
+ ModRefInfo MRI = AA->getModRefInfo(CI, MemoryLocation::get(SI));
+ if (isModOrRefSet(MRI))
+ return false;
+ }
+ }
}
auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI);
@@ -1375,8 +1393,7 @@ static Instruction *CloneInstructionInExitBlock(
if (!I.getName().empty())
New->setName(I.getName() + ".le");
- MemoryAccess *OldMemAcc;
- if (MSSAU && (OldMemAcc = MSSAU->getMemorySSA()->getMemoryAccess(&I))) {
+ if (MSSAU && MSSAU->getMemorySSA()->getMemoryAccess(&I)) {
// Create a new MemoryAccess and let MemorySSA set its defining access.
MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB(
New, nullptr, New->getParent(), MemorySSA::Beginning);
@@ -1385,7 +1402,7 @@ static Instruction *CloneInstructionInExitBlock(
MSSAU->insertDef(MemDef, /*RenameUses=*/true);
else {
auto *MemUse = cast<MemoryUse>(NewMemAcc);
- MSSAU->insertUse(MemUse);
+ MSSAU->insertUse(MemUse, /*RenameUses=*/true);
}
}
}
@@ -1783,7 +1800,7 @@ public:
StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos);
if (UnorderedAtomic)
NewSI->setOrdering(AtomicOrdering::Unordered);
- NewSI->setAlignment(Alignment);
+ NewSI->setAlignment(MaybeAlign(Alignment));
NewSI->setDebugLoc(DL);
if (AATags)
NewSI->setAAMetadata(AATags);
@@ -2016,7 +2033,8 @@ bool llvm::promoteLoopAccessesToScalars(
if (!DereferenceableInPH) {
DereferenceableInPH = isDereferenceableAndAlignedPointer(
Store->getPointerOperand(), Store->getValueOperand()->getType(),
- Store->getAlignment(), MDL, Preheader->getTerminator(), DT);
+ MaybeAlign(Store->getAlignment()), MDL,
+ Preheader->getTerminator(), DT);
}
} else
return false; // Not a load or store.
@@ -2101,20 +2119,21 @@ bool llvm::promoteLoopAccessesToScalars(
SomePtr->getName() + ".promoted", Preheader->getTerminator());
if (SawUnorderedAtomic)
PreheaderLoad->setOrdering(AtomicOrdering::Unordered);
- PreheaderLoad->setAlignment(Alignment);
+ PreheaderLoad->setAlignment(MaybeAlign(Alignment));
PreheaderLoad->setDebugLoc(DL);
if (AATags)
PreheaderLoad->setAAMetadata(AATags);
SSA.AddAvailableValue(Preheader, PreheaderLoad);
- MemoryAccess *PreheaderLoadMemoryAccess;
if (MSSAU) {
- PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB(
+ MemoryAccess *PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB(
PreheaderLoad, nullptr, PreheaderLoad->getParent(), MemorySSA::End);
MemoryUse *NewMemUse = cast<MemoryUse>(PreheaderLoadMemoryAccess);
- MSSAU->insertUse(NewMemUse);
+ MSSAU->insertUse(NewMemUse, /*RenameUses=*/true);
}
+ if (MSSAU && VerifyMemorySSA)
+ MSSAU->getMemorySSA()->verifyMemorySSA();
// Rewrite all the loads in the loop and remember all the definitions from
// stores in the loop.
Promoter.run(LoopUses);
@@ -2161,7 +2180,7 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI,
LoopToAliasSetMap.erase(MapI);
}
if (!CurAST)
- CurAST = make_unique<AliasSetTracker>(*AA);
+ CurAST = std::make_unique<AliasSetTracker>(*AA);
// Add everything from the sub loops that are no longer directly available.
for (Loop *InnerL : RecomputeLoops)
@@ -2180,7 +2199,7 @@ std::unique_ptr<AliasSetTracker>
LoopInvariantCodeMotion::collectAliasInfoForLoopWithMSSA(
Loop *L, AliasAnalysis *AA, MemorySSAUpdater *MSSAU) {
auto *MSSA = MSSAU->getMemorySSA();
- auto CurAST = make_unique<AliasSetTracker>(*AA, MSSA, L);
+ auto CurAST = std::make_unique<AliasSetTracker>(*AA, MSSA, L);
CurAST->addAllInstructionsInLoopUsingMSSA();
return CurAST;
}
diff --git a/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/lib/Transforms/Scalar/LoopDataPrefetch.cpp
index 1fcf1315a177..a972d6fa2fcd 100644
--- a/lib/Transforms/Scalar/LoopDataPrefetch.cpp
+++ b/lib/Transforms/Scalar/LoopDataPrefetch.cpp
@@ -312,8 +312,8 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {
IRBuilder<> Builder(MemI);
Module *M = BB->getParent()->getParent();
Type *I32 = Type::getInt32Ty(BB->getContext());
- Function *PrefetchFunc =
- Intrinsic::getDeclaration(M, Intrinsic::prefetch);
+ Function *PrefetchFunc = Intrinsic::getDeclaration(
+ M, Intrinsic::prefetch, PrefPtrValue->getType());
Builder.CreateCall(
PrefetchFunc,
{PrefPtrValue,
diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp
index 8371367e24e7..cee197cf8354 100644
--- a/lib/Transforms/Scalar/LoopDeletion.cpp
+++ b/lib/Transforms/Scalar/LoopDeletion.cpp
@@ -191,7 +191,7 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT,
// Don't remove loops for which we can't solve the trip count.
// They could be infinite, in which case we'd be changing program behavior.
- const SCEV *S = SE.getMaxBackedgeTakenCount(L);
+ const SCEV *S = SE.getConstantMaxBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(S)) {
LLVM_DEBUG(dbgs() << "Could not compute SCEV MaxBackedgeTakenCount.\n");
return Changed ? LoopDeletionResult::Modified
diff --git a/lib/Transforms/Scalar/LoopFuse.cpp b/lib/Transforms/Scalar/LoopFuse.cpp
index 0bc2bcff2ae1..9f93c68e6128 100644
--- a/lib/Transforms/Scalar/LoopFuse.cpp
+++ b/lib/Transforms/Scalar/LoopFuse.cpp
@@ -66,7 +66,7 @@ using namespace llvm;
#define DEBUG_TYPE "loop-fusion"
-STATISTIC(FuseCounter, "Count number of loop fusions performed");
+STATISTIC(FuseCounter, "Loops fused");
STATISTIC(NumFusionCandidates, "Number of candidates for loop fusion");
STATISTIC(InvalidPreheader, "Loop has invalid preheader");
STATISTIC(InvalidHeader, "Loop has invalid header");
@@ -79,12 +79,15 @@ STATISTIC(MayThrowException, "Loop may throw an exception");
STATISTIC(ContainsVolatileAccess, "Loop contains a volatile access");
STATISTIC(NotSimplifiedForm, "Loop is not in simplified form");
STATISTIC(InvalidDependencies, "Dependencies prevent fusion");
-STATISTIC(InvalidTripCount,
- "Loop does not have invariant backedge taken count");
+STATISTIC(UnknownTripCount, "Loop has unknown trip count");
STATISTIC(UncomputableTripCount, "SCEV cannot compute trip count of loop");
-STATISTIC(NonEqualTripCount, "Candidate trip counts are not the same");
-STATISTIC(NonAdjacent, "Candidates are not adjacent");
-STATISTIC(NonEmptyPreheader, "Candidate has a non-empty preheader");
+STATISTIC(NonEqualTripCount, "Loop trip counts are not the same");
+STATISTIC(NonAdjacent, "Loops are not adjacent");
+STATISTIC(NonEmptyPreheader, "Loop has a non-empty preheader");
+STATISTIC(FusionNotBeneficial, "Fusion is not beneficial");
+STATISTIC(NonIdenticalGuards, "Candidates have different guards");
+STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block");
+STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block");
enum FusionDependenceAnalysisChoice {
FUSION_DEPENDENCE_ANALYSIS_SCEV,
@@ -110,6 +113,7 @@ static cl::opt<bool>
cl::Hidden, cl::init(false), cl::ZeroOrMore);
#endif
+namespace {
/// This class is used to represent a candidate for loop fusion. When it is
/// constructed, it checks the conditions for loop fusion to ensure that it
/// represents a valid candidate. It caches several parts of a loop that are
@@ -143,6 +147,8 @@ struct FusionCandidate {
SmallVector<Instruction *, 16> MemWrites;
/// Are all of the members of this fusion candidate still valid
bool Valid;
+ /// Guard branch of the loop, if it exists
+ BranchInst *GuardBranch;
/// Dominator and PostDominator trees are needed for the
/// FusionCandidateCompare function, required by FusionCandidateSet to
@@ -151,11 +157,20 @@ struct FusionCandidate {
const DominatorTree *DT;
const PostDominatorTree *PDT;
+ OptimizationRemarkEmitter &ORE;
+
FusionCandidate(Loop *L, const DominatorTree *DT,
- const PostDominatorTree *PDT)
+ const PostDominatorTree *PDT, OptimizationRemarkEmitter &ORE)
: Preheader(L->getLoopPreheader()), Header(L->getHeader()),
ExitingBlock(L->getExitingBlock()), ExitBlock(L->getExitBlock()),
- Latch(L->getLoopLatch()), L(L), Valid(true), DT(DT), PDT(PDT) {
+ Latch(L->getLoopLatch()), L(L), Valid(true), GuardBranch(nullptr),
+ DT(DT), PDT(PDT), ORE(ORE) {
+
+ // TODO: This is temporary while we fuse both rotated and non-rotated
+ // loops. Once we switch to only fusing rotated loops, the initialization of
+ // GuardBranch can be moved into the initialization list above.
+ if (isRotated())
+ GuardBranch = L->getLoopGuardBranch();
// Walk over all blocks in the loop and check for conditions that may
// prevent fusion. For each block, walk over all instructions and collect
@@ -163,28 +178,28 @@ struct FusionCandidate {
// found, invalidate this object and return.
for (BasicBlock *BB : L->blocks()) {
if (BB->hasAddressTaken()) {
- AddressTakenBB++;
invalidate();
+ reportInvalidCandidate(AddressTakenBB);
return;
}
for (Instruction &I : *BB) {
if (I.mayThrow()) {
- MayThrowException++;
invalidate();
+ reportInvalidCandidate(MayThrowException);
return;
}
if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
if (SI->isVolatile()) {
- ContainsVolatileAccess++;
invalidate();
+ reportInvalidCandidate(ContainsVolatileAccess);
return;
}
}
if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
if (LI->isVolatile()) {
- ContainsVolatileAccess++;
invalidate();
+ reportInvalidCandidate(ContainsVolatileAccess);
return;
}
}
@@ -214,19 +229,96 @@ struct FusionCandidate {
assert(Latch == L->getLoopLatch() && "Latch is out of sync");
}
+ /// Get the entry block for this fusion candidate.
+ ///
+ /// If this fusion candidate represents a guarded loop, the entry block is the
+ /// loop guard block. If it represents an unguarded loop, the entry block is
+ /// the preheader of the loop.
+ BasicBlock *getEntryBlock() const {
+ if (GuardBranch)
+ return GuardBranch->getParent();
+ else
+ return Preheader;
+ }
+
+ /// Given a guarded loop, get the successor of the guard that is not in the
+ /// loop.
+ ///
+ /// This method returns the successor of the loop guard that is not located
+ /// within the loop (i.e., the successor of the guard that is not the
+ /// preheader).
+ /// This method is only valid for guarded loops.
+ BasicBlock *getNonLoopBlock() const {
+ assert(GuardBranch && "Only valid on guarded loops.");
+ assert(GuardBranch->isConditional() &&
+ "Expecting guard to be a conditional branch.");
+ return (GuardBranch->getSuccessor(0) == Preheader)
+ ? GuardBranch->getSuccessor(1)
+ : GuardBranch->getSuccessor(0);
+ }
+
+ bool isRotated() const {
+ assert(L && "Expecting loop to be valid.");
+ assert(Latch && "Expecting latch to be valid.");
+ return L->isLoopExiting(Latch);
+ }
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void dump() const {
- dbgs() << "\tPreheader: " << (Preheader ? Preheader->getName() : "nullptr")
+ dbgs() << "\tGuardBranch: "
+ << (GuardBranch ? GuardBranch->getName() : "nullptr") << "\n"
+ << "\tPreheader: " << (Preheader ? Preheader->getName() : "nullptr")
<< "\n"
<< "\tHeader: " << (Header ? Header->getName() : "nullptr") << "\n"
<< "\tExitingBB: "
<< (ExitingBlock ? ExitingBlock->getName() : "nullptr") << "\n"
<< "\tExitBB: " << (ExitBlock ? ExitBlock->getName() : "nullptr")
<< "\n"
- << "\tLatch: " << (Latch ? Latch->getName() : "nullptr") << "\n";
+ << "\tLatch: " << (Latch ? Latch->getName() : "nullptr") << "\n"
+ << "\tEntryBlock: "
+ << (getEntryBlock() ? getEntryBlock()->getName() : "nullptr")
+ << "\n";
}
#endif
+ /// Determine if a fusion candidate (representing a loop) is eligible for
+ /// fusion. Note that this only checks whether a single loop can be fused - it
+ /// does not check whether it is *legal* to fuse two loops together.
+ bool isEligibleForFusion(ScalarEvolution &SE) const {
+ if (!isValid()) {
+ LLVM_DEBUG(dbgs() << "FC has invalid CFG requirements!\n");
+ if (!Preheader)
+ ++InvalidPreheader;
+ if (!Header)
+ ++InvalidHeader;
+ if (!ExitingBlock)
+ ++InvalidExitingBlock;
+ if (!ExitBlock)
+ ++InvalidExitBlock;
+ if (!Latch)
+ ++InvalidLatch;
+ if (L->isInvalid())
+ ++InvalidLoop;
+
+ return false;
+ }
+
+ // Require ScalarEvolution to be able to determine a trip count.
+ if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+ LLVM_DEBUG(dbgs() << "Loop " << L->getName()
+ << " trip count not computable!\n");
+ return reportInvalidCandidate(UnknownTripCount);
+ }
+
+ if (!L->isLoopSimplifyForm()) {
+ LLVM_DEBUG(dbgs() << "Loop " << L->getName()
+ << " is not in simplified form!\n");
+ return reportInvalidCandidate(NotSimplifiedForm);
+ }
+
+ return true;
+ }
+
private:
// This is only used internally for now, to clear the MemWrites and MemReads
// list and setting Valid to false. I can't envision other uses of this right
@@ -239,17 +331,18 @@ private:
MemReads.clear();
Valid = false;
}
-};
-inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
- const FusionCandidate &FC) {
- if (FC.isValid())
- OS << FC.Preheader->getName();
- else
- OS << "<Invalid>";
-
- return OS;
-}
+ bool reportInvalidCandidate(llvm::Statistic &Stat) const {
+ using namespace ore;
+ assert(L && Preheader && "Fusion candidate not initialized properly!");
+ ++Stat;
+ ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, Stat.getName(),
+ L->getStartLoc(), Preheader)
+ << "[" << Preheader->getParent()->getName() << "]: "
+ << "Loop is not a candidate for fusion: " << Stat.getDesc());
+ return false;
+ }
+};
struct FusionCandidateCompare {
/// Comparison functor to sort two Control Flow Equivalent fusion candidates
@@ -260,21 +353,24 @@ struct FusionCandidateCompare {
const FusionCandidate &RHS) const {
const DominatorTree *DT = LHS.DT;
+ BasicBlock *LHSEntryBlock = LHS.getEntryBlock();
+ BasicBlock *RHSEntryBlock = RHS.getEntryBlock();
+
// Do not save PDT to local variable as it is only used in asserts and thus
// will trigger an unused variable warning if building without asserts.
assert(DT && LHS.PDT && "Expecting valid dominator tree");
// Do this compare first so if LHS == RHS, function returns false.
- if (DT->dominates(RHS.Preheader, LHS.Preheader)) {
+ if (DT->dominates(RHSEntryBlock, LHSEntryBlock)) {
// RHS dominates LHS
// Verify LHS post-dominates RHS
- assert(LHS.PDT->dominates(LHS.Preheader, RHS.Preheader));
+ assert(LHS.PDT->dominates(LHSEntryBlock, RHSEntryBlock));
return false;
}
- if (DT->dominates(LHS.Preheader, RHS.Preheader)) {
+ if (DT->dominates(LHSEntryBlock, RHSEntryBlock)) {
// Verify RHS Postdominates LHS
- assert(LHS.PDT->dominates(RHS.Preheader, LHS.Preheader));
+ assert(LHS.PDT->dominates(RHSEntryBlock, LHSEntryBlock));
return true;
}
@@ -286,7 +382,6 @@ struct FusionCandidateCompare {
}
};
-namespace {
using LoopVector = SmallVector<Loop *, 4>;
// Set of Control Flow Equivalent (CFE) Fusion Candidates, sorted in dominance
@@ -301,17 +396,26 @@ using LoopVector = SmallVector<Loop *, 4>;
// keeps the FusionCandidateSet sorted will also simplify the implementation.
using FusionCandidateSet = std::set<FusionCandidate, FusionCandidateCompare>;
using FusionCandidateCollection = SmallVector<FusionCandidateSet, 4>;
-} // namespace
-inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
+#if !defined(NDEBUG)
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
+ const FusionCandidate &FC) {
+ if (FC.isValid())
+ OS << FC.Preheader->getName();
+ else
+ OS << "<Invalid>";
+
+ return OS;
+}
+
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const FusionCandidateSet &CandSet) {
- for (auto IT : CandSet)
- OS << IT << "\n";
+ for (const FusionCandidate &FC : CandSet)
+ OS << FC << '\n';
return OS;
}
-#if !defined(NDEBUG)
static void
printFusionCandidates(const FusionCandidateCollection &FusionCandidates) {
dbgs() << "Fusion Candidates: \n";
@@ -391,16 +495,6 @@ static void printLoopVector(const LoopVector &LV) {
}
#endif
-static void reportLoopFusion(const FusionCandidate &FC0,
- const FusionCandidate &FC1,
- OptimizationRemarkEmitter &ORE) {
- using namespace ore;
- ORE.emit(
- OptimizationRemark(DEBUG_TYPE, "LoopFusion", FC0.Preheader->getParent())
- << "Fused " << NV("Cand1", StringRef(FC0.Preheader->getName()))
- << " with " << NV("Cand2", StringRef(FC1.Preheader->getName())));
-}
-
struct LoopFuser {
private:
// Sets of control flow equivalent fusion candidates for a given nest level.
@@ -497,53 +591,16 @@ private:
const FusionCandidate &FC1) const {
assert(FC0.Preheader && FC1.Preheader && "Expecting valid preheaders");
- if (DT.dominates(FC0.Preheader, FC1.Preheader))
- return PDT.dominates(FC1.Preheader, FC0.Preheader);
+ BasicBlock *FC0EntryBlock = FC0.getEntryBlock();
+ BasicBlock *FC1EntryBlock = FC1.getEntryBlock();
- if (DT.dominates(FC1.Preheader, FC0.Preheader))
- return PDT.dominates(FC0.Preheader, FC1.Preheader);
+ if (DT.dominates(FC0EntryBlock, FC1EntryBlock))
+ return PDT.dominates(FC1EntryBlock, FC0EntryBlock);
- return false;
- }
-
- /// Determine if a fusion candidate (representing a loop) is eligible for
- /// fusion. Note that this only checks whether a single loop can be fused - it
- /// does not check whether it is *legal* to fuse two loops together.
- bool eligibleForFusion(const FusionCandidate &FC) const {
- if (!FC.isValid()) {
- LLVM_DEBUG(dbgs() << "FC " << FC << " has invalid CFG requirements!\n");
- if (!FC.Preheader)
- InvalidPreheader++;
- if (!FC.Header)
- InvalidHeader++;
- if (!FC.ExitingBlock)
- InvalidExitingBlock++;
- if (!FC.ExitBlock)
- InvalidExitBlock++;
- if (!FC.Latch)
- InvalidLatch++;
- if (FC.L->isInvalid())
- InvalidLoop++;
+ if (DT.dominates(FC1EntryBlock, FC0EntryBlock))
+ return PDT.dominates(FC0EntryBlock, FC1EntryBlock);
- return false;
- }
-
- // Require ScalarEvolution to be able to determine a trip count.
- if (!SE.hasLoopInvariantBackedgeTakenCount(FC.L)) {
- LLVM_DEBUG(dbgs() << "Loop " << FC.L->getName()
- << " trip count not computable!\n");
- InvalidTripCount++;
- return false;
- }
-
- if (!FC.L->isLoopSimplifyForm()) {
- LLVM_DEBUG(dbgs() << "Loop " << FC.L->getName()
- << " is not in simplified form!\n");
- NotSimplifiedForm++;
- return false;
- }
-
- return true;
+ return false;
}
/// Iterate over all loops in the given loop set and identify the loops that
@@ -551,8 +608,8 @@ private:
/// Flow Equivalent sets, sorted by dominance.
void collectFusionCandidates(const LoopVector &LV) {
for (Loop *L : LV) {
- FusionCandidate CurrCand(L, &DT, &PDT);
- if (!eligibleForFusion(CurrCand))
+ FusionCandidate CurrCand(L, &DT, &PDT, ORE);
+ if (!CurrCand.isEligibleForFusion(SE))
continue;
// Go through each list in FusionCandidates and determine if L is control
@@ -664,31 +721,64 @@ private:
if (!identicalTripCounts(*FC0, *FC1)) {
LLVM_DEBUG(dbgs() << "Fusion candidates do not have identical trip "
"counts. Not fusing.\n");
- NonEqualTripCount++;
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ NonEqualTripCount);
continue;
}
if (!isAdjacent(*FC0, *FC1)) {
LLVM_DEBUG(dbgs()
<< "Fusion candidates are not adjacent. Not fusing.\n");
- NonAdjacent++;
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, NonAdjacent);
continue;
}
- // For now we skip fusing if the second candidate has any instructions
- // in the preheader. This is done because we currently do not have the
- // safety checks to determine if it is save to move the preheader of
- // the second candidate past the body of the first candidate. Once
- // these checks are added, this condition can be removed.
+ // Ensure that FC0 and FC1 have identical guards.
+ // If one (or both) are not guarded, this check is not necessary.
+ if (FC0->GuardBranch && FC1->GuardBranch &&
+ !haveIdenticalGuards(*FC0, *FC1)) {
+ LLVM_DEBUG(dbgs() << "Fusion candidates do not have identical "
+ "guards. Not Fusing.\n");
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ NonIdenticalGuards);
+ continue;
+ }
+
+ // The following three checks look for empty blocks in FC0 and FC1. If
+ // any of these blocks are non-empty, we do not fuse. This is done
+ // because we currently do not have the safety checks to determine if
+ // it is safe to move the blocks past other blocks in the loop. Once
+ // these checks are added, these conditions can be relaxed.
if (!isEmptyPreheader(*FC1)) {
LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty "
"preheader. Not fusing.\n");
- NonEmptyPreheader++;
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ NonEmptyPreheader);
+ continue;
+ }
+
+ if (FC0->GuardBranch && !isEmptyExitBlock(*FC0)) {
+ LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty exit "
+ "block. Not fusing.\n");
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ NonEmptyExitBlock);
+ continue;
+ }
+
+ if (FC1->GuardBranch && !isEmptyGuardBlock(*FC1)) {
+ LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty guard "
+ "block. Not fusing.\n");
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ NonEmptyGuardBlock);
continue;
}
+ // Check the dependencies across the loops and do not fuse if it would
+ // violate them.
if (!dependencesAllowFusion(*FC0, *FC1)) {
LLVM_DEBUG(dbgs() << "Memory dependencies do not allow fusion!\n");
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ InvalidDependencies);
continue;
}
@@ -696,9 +786,11 @@ private:
LLVM_DEBUG(dbgs()
<< "\tFusion appears to be "
<< (BeneficialToFuse ? "" : "un") << "profitable!\n");
- if (!BeneficialToFuse)
+ if (!BeneficialToFuse) {
+ reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,
+ FusionNotBeneficial);
continue;
-
+ }
// All analysis has completed and has determined that fusion is legal
// and profitable. At this point, start transforming the code and
// perform fusion.
@@ -710,15 +802,14 @@ private:
// Note this needs to be done *before* performFusion because
// performFusion will change the original loops, making it not
// possible to identify them after fusion is complete.
- reportLoopFusion(*FC0, *FC1, ORE);
+ reportLoopFusion<OptimizationRemark>(*FC0, *FC1, FuseCounter);
- FusionCandidate FusedCand(performFusion(*FC0, *FC1), &DT, &PDT);
+ FusionCandidate FusedCand(performFusion(*FC0, *FC1), &DT, &PDT, ORE);
FusedCand.verify();
- assert(eligibleForFusion(FusedCand) &&
+ assert(FusedCand.isEligibleForFusion(SE) &&
"Fused candidate should be eligible for fusion!");
// Notify the loop-depth-tree that these loops are not valid objects
- // anymore.
LDT.removeLoop(FC1->L);
CandidateSet.erase(FC0);
@@ -889,7 +980,7 @@ private:
LLVM_DEBUG(dbgs() << "Check if " << FC0 << " can be fused with " << FC1
<< "\n");
assert(FC0.L->getLoopDepth() == FC1.L->getLoopDepth());
- assert(DT.dominates(FC0.Preheader, FC1.Preheader));
+ assert(DT.dominates(FC0.getEntryBlock(), FC1.getEntryBlock()));
for (Instruction *WriteL0 : FC0.MemWrites) {
for (Instruction *WriteL1 : FC1.MemWrites)
@@ -939,18 +1030,89 @@ private:
return true;
}
- /// Determine if the exit block of \p FC0 is the preheader of \p FC1. In this
- /// case, there is no code in between the two fusion candidates, thus making
- /// them adjacent.
+ /// Determine if two fusion candidates are adjacent in the CFG.
+ ///
+ /// This method will determine if there are additional basic blocks in the CFG
+ /// between the exit of \p FC0 and the entry of \p FC1.
+ /// If the two candidates are guarded loops, then it checks whether the
+ /// non-loop successor of the \p FC0 guard branch is the entry block of \p
+ /// FC1. If not, then the loops are not adjacent. If the two candidates are
+ /// not guarded loops, then it checks whether the exit block of \p FC0 is the
+ /// preheader of \p FC1.
bool isAdjacent(const FusionCandidate &FC0,
const FusionCandidate &FC1) const {
- return FC0.ExitBlock == FC1.Preheader;
+ // If the successor of the guard branch is FC1, then the loops are adjacent
+ if (FC0.GuardBranch)
+ return FC0.getNonLoopBlock() == FC1.getEntryBlock();
+ else
+ return FC0.ExitBlock == FC1.getEntryBlock();
+ }
+
+ /// Determine if two fusion candidates have identical guards
+ ///
+ /// This method will determine if two fusion candidates have the same guards.
+ /// The guards are considered the same if:
+ /// 1. The instructions to compute the condition used in the compare are
+ /// identical.
+ /// 2. The successors of the guard have the same flow into/around the loop.
+ /// If the compare instructions are identical, then the first successor of the
+ /// guard must go to the same place (either the preheader of the loop or the
+  /// NonLoopBlock). In other words, the first successor of both loops must
+ /// both go into the loop (i.e., the preheader) or go around the loop (i.e.,
+ /// the NonLoopBlock). The same must be true for the second successor.
+ bool haveIdenticalGuards(const FusionCandidate &FC0,
+ const FusionCandidate &FC1) const {
+ assert(FC0.GuardBranch && FC1.GuardBranch &&
+ "Expecting FC0 and FC1 to be guarded loops.");
+
+ if (auto FC0CmpInst =
+ dyn_cast<Instruction>(FC0.GuardBranch->getCondition()))
+ if (auto FC1CmpInst =
+ dyn_cast<Instruction>(FC1.GuardBranch->getCondition()))
+ if (!FC0CmpInst->isIdenticalTo(FC1CmpInst))
+ return false;
+
+ // The compare instructions are identical.
+    // Now make sure the successors of the guards have the same flow
+    // into/around the loop.
+ if (FC0.GuardBranch->getSuccessor(0) == FC0.Preheader)
+ return (FC1.GuardBranch->getSuccessor(0) == FC1.Preheader);
+ else
+ return (FC1.GuardBranch->getSuccessor(1) == FC1.Preheader);
+ }
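
An editorial sketch of the successor check in haveIdenticalGuards above, using a hypothetical simplified type and assuming exactly one successor of each guard is its loop's preheader: once the compares are known identical, the same successor position must enter the loop for both candidates.

#include <string>

// Hypothetical, simplified view of a guarded candidate: a two-way guard
// branch whose successor 0 or 1 is the loop preheader.
struct GuardedLoop {
  std::string GuardSucc[2];
  std::string Preheader;
};

// Mirrors the flow check above: successor 0 of both guards must either enter
// its loop or bypass it; mixing the two would flip which side of the guard
// runs the fused body.
static bool sameGuardFlow(const GuardedLoop &A, const GuardedLoop &B) {
  bool AEntersOnSucc0 = (A.GuardSucc[0] == A.Preheader);
  bool BEntersOnSucc0 = (B.GuardSucc[0] == B.Preheader);
  return AEntersOnSucc0 == BEntersOnSucc0;
}

int main() {
  GuardedLoop A{{"ph0", "bypass0"}, "ph0"};
  GuardedLoop B{{"bypass1", "ph1"}, "ph1"};
  return sameGuardFlow(A, B) ? 1 : 0; // different flow here, so returns 0
}
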
+
+ /// Check that the guard for \p FC *only* contains the cmp/branch for the
+ /// guard.
+  /// Once we are able to handle intervening code, any code in the guard block
+  /// for FC1 will need to be treated as intervening code and checked to see
+  /// whether it can safely be moved around the loops.
+ bool isEmptyGuardBlock(const FusionCandidate &FC) const {
+ assert(FC.GuardBranch && "Expecting a fusion candidate with guard branch.");
+ if (auto *CmpInst = dyn_cast<Instruction>(FC.GuardBranch->getCondition())) {
+ auto *GuardBlock = FC.GuardBranch->getParent();
+ // If the generation of the cmp value is in GuardBlock, then the size of
+ // the guard block should be 2 (cmp + branch). If the generation of the
+ // cmp value is in a different block, then the size of the guard block
+ // should only be 1.
+ if (CmpInst->getParent() == GuardBlock)
+ return GuardBlock->size() == 2;
+ else
+ return GuardBlock->size() == 1;
+ }
+
+ return false;
}
bool isEmptyPreheader(const FusionCandidate &FC) const {
+ assert(FC.Preheader && "Expecting a valid preheader");
return FC.Preheader->size() == 1;
}
+ bool isEmptyExitBlock(const FusionCandidate &FC) const {
+ assert(FC.ExitBlock && "Expecting a valid exit block");
+ return FC.ExitBlock->size() == 1;
+ }
+
/// Fuse two fusion candidates, creating a new fused loop.
///
/// This method contains the mechanics of fusing two loops, represented by \p
@@ -987,6 +1149,12 @@ private:
LLVM_DEBUG(dbgs() << "Fusion Candidate 0: \n"; FC0.dump();
dbgs() << "Fusion Candidate 1: \n"; FC1.dump(););
+ // Fusing guarded loops is handled slightly differently than non-guarded
+ // loops and has been broken out into a separate method instead of trying to
+ // intersperse the logic within a single method.
+ if (FC0.GuardBranch)
+ return fuseGuardedLoops(FC0, FC1);
+
assert(FC1.Preheader == FC0.ExitBlock);
assert(FC1.Preheader->size() == 1 &&
FC1.Preheader->getSingleSuccessor() == FC1.Header);
@@ -1131,7 +1299,258 @@ private:
SE.verify();
#endif
- FuseCounter++;
+ LLVM_DEBUG(dbgs() << "Fusion done:\n");
+
+ return FC0.L;
+ }
+
+ /// Report details on loop fusion opportunities.
+ ///
+ /// This template function can be used to report both successful and missed
+ /// loop fusion opportunities, based on the RemarkKind. The RemarkKind should
+ /// be one of:
+ /// - OptimizationRemarkMissed to report when loop fusion is unsuccessful
+ /// given two valid fusion candidates.
+ /// - OptimizationRemark to report successful fusion of two fusion
+ /// candidates.
+ /// The remarks will be printed using the form:
+ /// <path/filename>:<line number>:<column number>: [<function name>]:
+ /// <Cand1 Preheader> and <Cand2 Preheader>: <Stat Description>
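+  /// For example, a successful fusion might be reported as (hypothetical
+  /// output; names are illustrative):
+  ///   test.cpp:7:3: [foo]: bb0.preheader and bb1.preheader: Loops fused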
+ template <typename RemarkKind>
+ void reportLoopFusion(const FusionCandidate &FC0, const FusionCandidate &FC1,
+ llvm::Statistic &Stat) {
+ assert(FC0.Preheader && FC1.Preheader &&
+ "Expecting valid fusion candidates");
+ using namespace ore;
+ ++Stat;
+ ORE.emit(RemarkKind(DEBUG_TYPE, Stat.getName(), FC0.L->getStartLoc(),
+ FC0.Preheader)
+ << "[" << FC0.Preheader->getParent()->getName()
+ << "]: " << NV("Cand1", StringRef(FC0.Preheader->getName()))
+ << " and " << NV("Cand2", StringRef(FC1.Preheader->getName()))
+ << ": " << Stat.getDesc());
+ }
+
+ /// Fuse two guarded fusion candidates, creating a new fused loop.
+ ///
+ /// Fusing guarded loops is handled much the same way as fusing non-guarded
+ /// loops. The rewiring of the CFG is slightly different though, because of
+ /// the presence of the guards around the loops and the exit blocks after the
+ /// loop body. As such, the new loop is rewired as follows:
+ /// 1. Keep the guard branch from FC0 and use the non-loop block target
+ /// from the FC1 guard branch.
+ /// 2. Remove the exit block from FC0 (this exit block should be empty
+ /// right now).
+  ///    3. Remove the guard branch for FC1.
+ /// 4. Remove the preheader for FC1.
+ /// The exit block successor for the latch of FC0 is updated to be the header
+ /// of FC1 and the non-exit block successor of the latch of FC1 is updated to
+ /// be the header of FC0, thus creating the fused loop.
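+  /// A rough before/after sketch of the guard rewiring (block names are
+  /// illustrative):
+  ///   before: FC0.Guard -> { FC0.Preheader, FC1.Guard }
+  ///           FC1.Guard -> { FC1.Preheader, FC1.NonLoopBlock }
+  ///   after:  FC0.Guard -> { FC0.Preheader, FC1.NonLoopBlock }
+  /// with the latches rewired as described above to form a single loop body.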
+ Loop *fuseGuardedLoops(const FusionCandidate &FC0,
+ const FusionCandidate &FC1) {
+ assert(FC0.GuardBranch && FC1.GuardBranch && "Expecting guarded loops");
+
+ BasicBlock *FC0GuardBlock = FC0.GuardBranch->getParent();
+ BasicBlock *FC1GuardBlock = FC1.GuardBranch->getParent();
+ BasicBlock *FC0NonLoopBlock = FC0.getNonLoopBlock();
+ BasicBlock *FC1NonLoopBlock = FC1.getNonLoopBlock();
+
+ assert(FC0NonLoopBlock == FC1GuardBlock && "Loops are not adjacent");
+
+ SmallVector<DominatorTree::UpdateType, 8> TreeUpdates;
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Update the Loop Guard
+ ////////////////////////////////////////////////////////////////////////////
+    // The guard for FC0 is updated to guard both FC0 and FC1. This is done by
+    // changing the non-loop successor of the FC0 guard to the NonLoopBlock of
+    // FC1. Thus, one path from the guard goes to the preheader for FC0 (and
+    // thus executes the new fused loop) and the other path goes to the
+    // NonLoopBlock for FC1 (where the FC1 guard would have gone had FC1 not
+    // been executed).
+ FC0.GuardBranch->replaceUsesOfWith(FC0NonLoopBlock, FC1NonLoopBlock);
+ FC0.ExitBlock->getTerminator()->replaceUsesOfWith(FC1GuardBlock,
+ FC1.Header);
+
+ // The guard of FC1 is not necessary anymore.
+ FC1.GuardBranch->eraseFromParent();
+ new UnreachableInst(FC1GuardBlock->getContext(), FC1GuardBlock);
+
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Delete, FC1GuardBlock, FC1.Preheader));
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Delete, FC1GuardBlock, FC1NonLoopBlock));
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Delete, FC0GuardBlock, FC1GuardBlock));
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Insert, FC0GuardBlock, FC1NonLoopBlock));
+
+ assert(pred_begin(FC1GuardBlock) == pred_end(FC1GuardBlock) &&
+ "Expecting guard block to have no predecessors");
+ assert(succ_begin(FC1GuardBlock) == succ_end(FC1GuardBlock) &&
+ "Expecting guard block to have no successors");
+
+    // Remember the phi nodes originally in the header of FC0 in order to rewire
+    // them later. However, this is only necessary if the new loop-carried
+    // values might not dominate the exiting branch. We do not generally test
+    // whether this is the case but simply insert intermediate phi nodes; to do
+    // so we need to make sure these intermediate phi nodes have distinct
+    // predecessors. To this end, we filter out the special case where the
+    // exiting block is the latch block of the first loop, in which case nothing
+    // needs to be done because all loop-carried values dominate the latch and
+    // thereby also the exiting branch.
+    // KB: This is no longer necessary because FC0.ExitingBlock == FC0.Latch
+    // (because the loops are rotated). Thus, nothing will ever be added to
+    // OriginalFC0PHIs.
+ SmallVector<PHINode *, 8> OriginalFC0PHIs;
+ if (FC0.ExitingBlock != FC0.Latch)
+ for (PHINode &PHI : FC0.Header->phis())
+ OriginalFC0PHIs.push_back(&PHI);
+
+ assert(OriginalFC0PHIs.empty() && "Expecting OriginalFC0PHIs to be empty!");
+
+ // Replace incoming blocks for header PHIs first.
+ FC1.Preheader->replaceSuccessorsPhiUsesWith(FC0.Preheader);
+ FC0.Latch->replaceSuccessorsPhiUsesWith(FC1.Latch);
+
+    // The old exiting block of the first loop (FC0) has to jump to the header
+    // of the second as we need to execute the code in the second header block
+    // regardless of the trip count. That is, even if the trip count is 0 and
+    // the back edge is never taken, we still have to execute both loop headers,
+    // especially (but not only!) if the second is a do-while style loop.
+    // However, doing so might invalidate the phi nodes of the first loop as
+    // the new values only need to dominate their latch and not the exiting
+    // predicate. To remedy this potential problem we later introduce phi nodes
+    // in the header of the second loop that select the loop-carried value if
+    // the second header was reached through an old latch of the first, or
+    // undef otherwise. This is sound as exiting the first implies the second
+    // will exit too, __without__ taking the back-edge (their trip-counts are
+    // equal after all).
+ FC0.ExitingBlock->getTerminator()->replaceUsesOfWith(FC0.ExitBlock,
+ FC1.Header);
+
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Delete, FC0.ExitingBlock, FC0.ExitBlock));
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Insert, FC0.ExitingBlock, FC1.Header));
+
+ // Remove FC0 Exit Block
+ // The exit block for FC0 is no longer needed since control will flow
+ // directly to the header of FC1. Since it is an empty block, it can be
+ // removed at this point.
+    // TODO: In the future, we can handle non-empty exit blocks by merging any
+    // instructions from the FC0 exit block into the FC1 exit block prior to
+    // removing the block.
+ assert(pred_begin(FC0.ExitBlock) == pred_end(FC0.ExitBlock) &&
+ "Expecting exit block to be empty");
+ FC0.ExitBlock->getTerminator()->eraseFromParent();
+ new UnreachableInst(FC0.ExitBlock->getContext(), FC0.ExitBlock);
+
+ // Remove FC1 Preheader
+ // The pre-header of L1 is not necessary anymore.
+ assert(pred_begin(FC1.Preheader) == pred_end(FC1.Preheader));
+ FC1.Preheader->getTerminator()->eraseFromParent();
+ new UnreachableInst(FC1.Preheader->getContext(), FC1.Preheader);
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Delete, FC1.Preheader, FC1.Header));
+
+    // Move the phi nodes from the header block of the second loop to the
+    // header block of the first loop.
+ while (PHINode *PHI = dyn_cast<PHINode>(&FC1.Header->front())) {
+ if (SE.isSCEVable(PHI->getType()))
+ SE.forgetValue(PHI);
+ if (PHI->hasNUsesOrMore(1))
+ PHI->moveBefore(&*FC0.Header->getFirstInsertionPt());
+ else
+ PHI->eraseFromParent();
+ }
+
+ // Introduce new phi nodes in the second loop header to ensure
+ // exiting the first and jumping to the header of the second does not break
+ // the SSA property of the phis originally in the first loop. See also the
+ // comment above.
+ Instruction *L1HeaderIP = &FC1.Header->front();
+ for (PHINode *LCPHI : OriginalFC0PHIs) {
+ int L1LatchBBIdx = LCPHI->getBasicBlockIndex(FC1.Latch);
+ assert(L1LatchBBIdx >= 0 &&
+ "Expected loop carried value to be rewired at this point!");
+
+ Value *LCV = LCPHI->getIncomingValue(L1LatchBBIdx);
+
+ PHINode *L1HeaderPHI = PHINode::Create(
+ LCV->getType(), 2, LCPHI->getName() + ".afterFC0", L1HeaderIP);
+ L1HeaderPHI->addIncoming(LCV, FC0.Latch);
+ L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()),
+ FC0.ExitingBlock);
+
+ LCPHI->setIncomingValue(L1LatchBBIdx, L1HeaderPHI);
+ }
+
+ // Update the latches
+
+ // Replace latch terminator destinations.
+ FC0.Latch->getTerminator()->replaceUsesOfWith(FC0.Header, FC1.Header);
+ FC1.Latch->getTerminator()->replaceUsesOfWith(FC1.Header, FC0.Header);
+
+ // If FC0.Latch and FC0.ExitingBlock are the same then we have already
+ // performed the updates above.
+ if (FC0.Latch != FC0.ExitingBlock)
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(
+ DominatorTree::Insert, FC0.Latch, FC1.Header));
+
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Delete,
+ FC0.Latch, FC0.Header));
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Insert,
+ FC1.Latch, FC0.Header));
+ TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Delete,
+ FC1.Latch, FC1.Header));
+
+    // All done. Apply the updates to the dominator tree and clean up.
+
+ assert(succ_begin(FC1GuardBlock) == succ_end(FC1GuardBlock) &&
+ "FC1GuardBlock has successors!!");
+ assert(pred_begin(FC1GuardBlock) == pred_end(FC1GuardBlock) &&
+ "FC1GuardBlock has predecessors!!");
+
+ // Update DT/PDT
+ DTU.applyUpdates(TreeUpdates);
+
+ LI.removeBlock(FC1.Preheader);
+ DTU.deleteBB(FC1.Preheader);
+ DTU.deleteBB(FC0.ExitBlock);
+ DTU.flush();
+
+ // Is there a way to keep SE up-to-date so we don't need to forget the loops
+ // and rebuild the information in subsequent passes of fusion?
+ SE.forgetLoop(FC1.L);
+ SE.forgetLoop(FC0.L);
+
+ // Merge the loops.
+ SmallVector<BasicBlock *, 8> Blocks(FC1.L->block_begin(),
+ FC1.L->block_end());
+ for (BasicBlock *BB : Blocks) {
+ FC0.L->addBlockEntry(BB);
+ FC1.L->removeBlockFromLoop(BB);
+ if (LI.getLoopFor(BB) != FC1.L)
+ continue;
+ LI.changeLoopFor(BB, FC0.L);
+ }
+ while (!FC1.L->empty()) {
+ const auto &ChildLoopIt = FC1.L->begin();
+ Loop *ChildLoop = *ChildLoopIt;
+ FC1.L->removeChildLoop(ChildLoopIt);
+ FC0.L->addChildLoop(ChildLoop);
+ }
+
+ // Delete the now empty loop L1.
+ LI.erase(FC1.L);
+
+#ifndef NDEBUG
+ assert(!verifyFunction(*FC0.Header->getParent(), &errs()));
+ assert(DT.verify(DominatorTree::VerificationLevel::Fast));
+ assert(PDT.verify());
+ LI.verify(DT);
+ SE.verify();
+#endif
LLVM_DEBUG(dbgs() << "Fusion done:\n");
@@ -1177,6 +1596,7 @@ struct LoopFuseLegacy : public FunctionPass {
return LF.fuseLoops(F);
}
};
+} // namespace
PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) {
auto &LI = AM.getResult<LoopAnalysis>(F);
diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index e561494f19cf..dd477e800693 100644
--- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -41,6 +41,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -77,16 +78,20 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
+#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/LoopPassManager.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -102,6 +107,7 @@ using namespace llvm;
STATISTIC(NumMemSet, "Number of memset's formed from loop stores");
STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores");
+STATISTIC(NumBCmp, "Number of memcmp's formed from loop 2xload+eq-compare");
static cl::opt<bool> UseLIRCodeSizeHeurs(
"use-lir-code-size-heurs",
@@ -111,6 +117,26 @@ static cl::opt<bool> UseLIRCodeSizeHeurs(
namespace {
+// FIXME: reinventing the wheel much? Is there a cleaner solution?
+struct PMAbstraction {
+ virtual void markLoopAsDeleted(Loop *L) = 0;
+ virtual ~PMAbstraction() = default;
+};
+struct LegacyPMAbstraction : PMAbstraction {
+ LPPassManager &LPM;
+ LegacyPMAbstraction(LPPassManager &LPM) : LPM(LPM) {}
+ virtual ~LegacyPMAbstraction() = default;
+ void markLoopAsDeleted(Loop *L) override { LPM.markLoopAsDeleted(*L); }
+};
+struct NewPMAbstraction : PMAbstraction {
+ LPMUpdater &Updater;
+ NewPMAbstraction(LPMUpdater &Updater) : Updater(Updater) {}
+ virtual ~NewPMAbstraction() = default;
+ void markLoopAsDeleted(Loop *L) override {
+ Updater.markLoopAsDeleted(*L, L->getName());
+ }
+};
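+// Usage sketch: the legacy pass wraps its LPPassManager in LegacyPMAbstraction
+// and the new-PM pass wraps its LPMUpdater in NewPMAbstraction, so that the
+// transformation below only sees the common markLoopAsDeleted() interface.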
+
class LoopIdiomRecognize {
Loop *CurLoop = nullptr;
AliasAnalysis *AA;
@@ -120,6 +146,7 @@ class LoopIdiomRecognize {
TargetLibraryInfo *TLI;
const TargetTransformInfo *TTI;
const DataLayout *DL;
+ PMAbstraction &LoopDeleter;
OptimizationRemarkEmitter &ORE;
bool ApplyCodeSizeHeuristics;
@@ -128,9 +155,10 @@ public:
LoopInfo *LI, ScalarEvolution *SE,
TargetLibraryInfo *TLI,
const TargetTransformInfo *TTI,
- const DataLayout *DL,
+ const DataLayout *DL, PMAbstraction &LoopDeleter,
OptimizationRemarkEmitter &ORE)
- : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {}
+ : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL),
+ LoopDeleter(LoopDeleter), ORE(ORE) {}
bool runOnLoop(Loop *L);
@@ -144,6 +172,8 @@ private:
bool HasMemset;
bool HasMemsetPattern;
bool HasMemcpy;
+ bool HasMemCmp;
+ bool HasBCmp;
/// Return code for isLegalStore()
enum LegalStoreKind {
@@ -186,6 +216,32 @@ private:
bool runOnNoncountableLoop();
+ struct CmpLoopStructure {
+ Value *BCmpValue, *LatchCmpValue;
+ BasicBlock *HeaderBrEqualBB, *HeaderBrUnequalBB;
+ BasicBlock *LatchBrFinishBB, *LatchBrContinueBB;
+ };
+ bool matchBCmpLoopStructure(CmpLoopStructure &CmpLoop) const;
+ struct CmpOfLoads {
+ ICmpInst::Predicate BCmpPred;
+ Value *LoadSrcA, *LoadSrcB;
+ Value *LoadA, *LoadB;
+ };
+ bool matchBCmpOfLoads(Value *BCmpValue, CmpOfLoads &CmpOfLoads) const;
+ bool recognizeBCmpLoopControlFlow(const CmpOfLoads &CmpOfLoads,
+ CmpLoopStructure &CmpLoop) const;
+ bool recognizeBCmpLoopSCEV(uint64_t BCmpTyBytes, CmpOfLoads &CmpOfLoads,
+ const SCEV *&SrcA, const SCEV *&SrcB,
+ const SCEV *&Iterations) const;
+ bool detectBCmpIdiom(ICmpInst *&BCmpInst, CmpInst *&LatchCmpInst,
+ LoadInst *&LoadA, LoadInst *&LoadB, const SCEV *&SrcA,
+ const SCEV *&SrcB, const SCEV *&NBytes) const;
+ BasicBlock *transformBCmpControlFlow(ICmpInst *ComparedEqual);
+ void transformLoopToBCmp(ICmpInst *BCmpInst, CmpInst *LatchCmpInst,
+ LoadInst *LoadA, LoadInst *LoadB, const SCEV *SrcA,
+ const SCEV *SrcB, const SCEV *NBytes);
+ bool recognizeBCmp();
+
bool recognizePopcount();
void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst,
PHINode *CntPhi, Value *Var);
@@ -217,18 +273,20 @@ public:
LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+ *L->getHeader()->getParent());
const TargetTransformInfo *TTI =
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
*L->getHeader()->getParent());
const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout();
+ LegacyPMAbstraction LoopDeleter(LPM);
// For the old PM, we can't use OptimizationRemarkEmitter as an analysis
// pass. Function analyses need to be preserved across loop transformations
// but ORE cannot be preserved (see comment before the pass definition).
OptimizationRemarkEmitter ORE(L->getHeader()->getParent());
- LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, ORE);
+ LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, LoopDeleter, ORE);
return LIR.runOnLoop(L);
}
@@ -247,7 +305,7 @@ char LoopIdiomRecognizeLegacyPass::ID = 0;
PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
- LPMUpdater &) {
+ LPMUpdater &Updater) {
const auto *DL = &L.getHeader()->getModule()->getDataLayout();
const auto &FAM =
@@ -261,8 +319,9 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
"LoopIdiomRecognizePass: OptimizationRemarkEmitterAnalysis not cached "
"at a higher level");
+ NewPMAbstraction LoopDeleter(Updater);
LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL,
- *ORE);
+ LoopDeleter, *ORE);
if (!LIR.runOnLoop(&L))
return PreservedAnalyses::all();
@@ -299,7 +358,8 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
// Disable loop idiom recognition if the function's name is a common idiom.
StringRef Name = L->getHeader()->getParent()->getName();
- if (Name == "memset" || Name == "memcpy")
+ if (Name == "memset" || Name == "memcpy" || Name == "memcmp" ||
+ Name == "bcmp")
return false;
// Determine if code size heuristics need to be applied.
@@ -309,8 +369,10 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
HasMemset = TLI->has(LibFunc_memset);
HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
HasMemcpy = TLI->has(LibFunc_memcpy);
+ HasMemCmp = TLI->has(LibFunc_memcmp);
+ HasBCmp = TLI->has(LibFunc_bcmp);
- if (HasMemset || HasMemsetPattern || HasMemcpy)
+ if (HasMemset || HasMemsetPattern || HasMemcpy || HasMemCmp || HasBCmp)
if (SE->hasLoopInvariantBackedgeTakenCount(L))
return runOnCountableLoop();
@@ -961,7 +1023,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
GlobalValue::PrivateLinkage,
PatternValue, ".memset_pattern");
GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these.
- GV->setAlignment(16);
+ GV->setAlignment(Align(16));
Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy);
NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
}
@@ -1149,7 +1211,7 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
<< "] Noncountable Loop %"
<< CurLoop->getHeader()->getName() << "\n");
- return recognizePopcount() || recognizeAndInsertFFS();
+ return recognizeBCmp() || recognizePopcount() || recognizeAndInsertFFS();
}
/// Check if the given conditional branch is based on the comparison between
@@ -1823,3 +1885,811 @@ void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB,
// loop. The loop would otherwise not be deleted even if it becomes empty.
SE->forgetLoop(CurLoop);
}
+
+bool LoopIdiomRecognize::matchBCmpLoopStructure(
+ CmpLoopStructure &CmpLoop) const {
+ ICmpInst::Predicate BCmpPred;
+
+ // We are looking for the following basic layout:
+ // PreheaderBB: <preheader> ; preds = ???
+ // <...>
+ // br label %LoopHeaderBB
+ // LoopHeaderBB: <header,exiting> ; preds = %PreheaderBB,%LoopLatchBB
+ // <...>
+ // %BCmpValue = icmp <...>
+ // br i1 %BCmpValue, label %LoopLatchBB, label %Successor0
+ // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB
+ // <...>
+ // %LatchCmpValue = <are we done, or do next iteration?>
+ // br i1 %LatchCmpValue, label %Successor1, label %LoopHeaderBB
+ // Successor0: <exit> ; preds = %LoopHeaderBB
+ // <...>
+ // Successor1: <exit> ; preds = %LoopLatchBB
+ // <...>
+ //
+ // Successor0 and Successor1 may or may not be the same basic block.
+
+ // Match basic frame-work of this supposedly-comparison loop.
+ using namespace PatternMatch;
+ if (!match(CurLoop->getHeader()->getTerminator(),
+ m_Br(m_CombineAnd(m_ICmp(BCmpPred, m_Value(), m_Value()),
+ m_Value(CmpLoop.BCmpValue)),
+ CmpLoop.HeaderBrEqualBB, CmpLoop.HeaderBrUnequalBB)) ||
+ !match(CurLoop->getLoopLatch()->getTerminator(),
+ m_Br(m_CombineAnd(m_Cmp(), m_Value(CmpLoop.LatchCmpValue)),
+ CmpLoop.LatchBrFinishBB, CmpLoop.LatchBrContinueBB))) {
+ LLVM_DEBUG(dbgs() << "Basic control-flow layout unrecognized.\n");
+ return false;
+ }
+ LLVM_DEBUG(dbgs() << "Recognized basic control-flow layout.\n");
+ return true;
+}
+
+bool LoopIdiomRecognize::matchBCmpOfLoads(Value *BCmpValue,
+ CmpOfLoads &CmpOfLoads) const {
+ using namespace PatternMatch;
+ LLVM_DEBUG(dbgs() << "Analyzing header icmp " << *BCmpValue
+ << " as bcmp pattern.\n");
+
+ // Match bcmp-style loop header cmp. It must be an eq-icmp of loads. Example:
+ // %v0 = load <...>, <...>* %LoadSrcA
+ // %v1 = load <...>, <...>* %LoadSrcB
+ // %CmpLoop.BCmpValue = icmp eq <...> %v0, %v1
+  // There won't be any no-op bitcasts between the load and the icmp;
+  // they would have been transformed into a load of a bitcast.
+ // FIXME: {b,mem}cmp() calls have the same semantics as icmp. Match them too.
+ if (!match(BCmpValue,
+ m_ICmp(CmpOfLoads.BCmpPred,
+ m_CombineAnd(m_Load(m_Value(CmpOfLoads.LoadSrcA)),
+ m_Value(CmpOfLoads.LoadA)),
+ m_CombineAnd(m_Load(m_Value(CmpOfLoads.LoadSrcB)),
+ m_Value(CmpOfLoads.LoadB)))) ||
+ !ICmpInst::isEquality(CmpOfLoads.BCmpPred)) {
+ LLVM_DEBUG(dbgs() << "Loop header icmp did not match bcmp pattern.\n");
+ return false;
+ }
+ LLVM_DEBUG(dbgs() << "Recognized header icmp as bcmp pattern with loads:\n\t"
+ << *CmpOfLoads.LoadA << "\n\t" << *CmpOfLoads.LoadB
+ << "\n");
+ // FIXME: handle memcmp pattern?
+ return true;
+}
+
+bool LoopIdiomRecognize::recognizeBCmpLoopControlFlow(
+ const CmpOfLoads &CmpOfLoads, CmpLoopStructure &CmpLoop) const {
+ BasicBlock *LoopHeaderBB = CurLoop->getHeader();
+ BasicBlock *LoopLatchBB = CurLoop->getLoopLatch();
+
+ // Be wary, comparisons can be inverted, canonicalize order.
+ // If this 'element' comparison passed, we expect to proceed to the next elt.
+ if (CmpOfLoads.BCmpPred != ICmpInst::Predicate::ICMP_EQ)
+ std::swap(CmpLoop.HeaderBrEqualBB, CmpLoop.HeaderBrUnequalBB);
+ // The predicate on loop latch does not matter, just canonicalize some order.
+ if (CmpLoop.LatchBrContinueBB != LoopHeaderBB)
+ std::swap(CmpLoop.LatchBrFinishBB, CmpLoop.LatchBrContinueBB);
+
+ SmallVector<BasicBlock *, 2> ExitBlocks;
+
+ CurLoop->getUniqueExitBlocks(ExitBlocks);
+ assert(ExitBlocks.size() <= 2U && "Can't have more than two exit blocks.");
+
+ // Check that control-flow between blocks is as expected.
+ if (CmpLoop.HeaderBrEqualBB != LoopLatchBB ||
+ CmpLoop.LatchBrContinueBB != LoopHeaderBB ||
+ !is_contained(ExitBlocks, CmpLoop.HeaderBrUnequalBB) ||
+ !is_contained(ExitBlocks, CmpLoop.LatchBrFinishBB)) {
+ LLVM_DEBUG(dbgs() << "Loop control-flow not recognized.\n");
+ return false;
+ }
+
+ assert(!is_contained(ExitBlocks, CmpLoop.HeaderBrEqualBB) &&
+ !is_contained(ExitBlocks, CmpLoop.LatchBrContinueBB) &&
+ "Unexpected exit edges.");
+
+ LLVM_DEBUG(dbgs() << "Recognized loop control-flow.\n");
+
+ LLVM_DEBUG(dbgs() << "Performing side-effect analysis on the loop.\n");
+ assert(CurLoop->isLCSSAForm(*DT) && "Should only get LCSSA-form loops here.");
+  // No loop instruction may be used outside of the loop. Since we are in
+  // LCSSA form, we only need to check the successor blocks' PHI nodes'
+  // incoming values for incoming blocks that are the loop's basic blocks.
+ for (const BasicBlock *ExitBB : ExitBlocks) {
+ for (const PHINode &PHI : ExitBB->phis()) {
+ for (const BasicBlock *LoopBB :
+ make_filter_range(PHI.blocks(), [this](BasicBlock *PredecessorBB) {
+ return CurLoop->contains(PredecessorBB);
+ })) {
+ const auto *I =
+ dyn_cast<Instruction>(PHI.getIncomingValueForBlock(LoopBB));
+ if (I && CurLoop->contains(I)) {
+ LLVM_DEBUG(dbgs()
+ << "Loop contains instruction " << *I
+ << " which is used outside of the loop in basic block "
+ << ExitBB->getName() << " in phi node " << PHI << "\n");
+ return false;
+ }
+ }
+ }
+ }
+ // Similarly, the loop should not have any other observable side-effects
+ // other than the final comparison result.
+ for (BasicBlock *LoopBB : CurLoop->blocks()) {
+ for (Instruction &I : *LoopBB) {
+ if (isa<DbgInfoIntrinsic>(I)) // Ignore dbginfo.
+ continue; // FIXME: anything else? lifetime info?
+ if ((I.mayHaveSideEffects() || I.isAtomic() || I.isFenceLike()) &&
+ &I != CmpOfLoads.LoadA && &I != CmpOfLoads.LoadB) {
+ LLVM_DEBUG(
+ dbgs() << "Loop contains instruction with potential side-effects: "
+ << I << "\n");
+ return false;
+ }
+ }
+ }
+ LLVM_DEBUG(dbgs() << "No loop instructions deemed to have side-effects.\n");
+ return true;
+}
+
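+/// Check that both load pointers advance in lock-step by \p BCmpTyBytes per
+/// iteration, and derive (as the code below does):
+///   SrcA, SrcB - the SCEV start values of the two load pointers, and
+///   Iterations - the loop-latch exit count plus one.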
+bool LoopIdiomRecognize::recognizeBCmpLoopSCEV(uint64_t BCmpTyBytes,
+ CmpOfLoads &CmpOfLoads,
+ const SCEV *&SrcA,
+ const SCEV *&SrcB,
+ const SCEV *&Iterations) const {
+ // Try to compute SCEV of the loads, for this loop's scope.
+ const auto *ScevForSrcA = dyn_cast<SCEVAddRecExpr>(
+ SE->getSCEVAtScope(CmpOfLoads.LoadSrcA, CurLoop));
+ const auto *ScevForSrcB = dyn_cast<SCEVAddRecExpr>(
+ SE->getSCEVAtScope(CmpOfLoads.LoadSrcB, CurLoop));
+ if (!ScevForSrcA || !ScevForSrcB) {
+ LLVM_DEBUG(dbgs() << "Failed to get SCEV expressions for load sources.\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << "Got SCEV expressions (at loop scope) for loads:\n\t"
+ << *ScevForSrcA << "\n\t" << *ScevForSrcB << "\n");
+
+  // Loads must have the following SCEV exprs:
+  //   {%ptr,+,BCmpTyBytes}<%LoopHeaderBB>
+ const SCEV *RecStepForA = ScevForSrcA->getStepRecurrence(*SE);
+ const SCEV *RecStepForB = ScevForSrcB->getStepRecurrence(*SE);
+ if (!ScevForSrcA->isAffine() || !ScevForSrcB->isAffine() ||
+ ScevForSrcA->getLoop() != CurLoop || ScevForSrcB->getLoop() != CurLoop ||
+ RecStepForA != RecStepForB || !isa<SCEVConstant>(RecStepForA) ||
+ cast<SCEVConstant>(RecStepForA)->getAPInt() != BCmpTyBytes) {
+ LLVM_DEBUG(dbgs() << "Unsupported SCEV expressions for loads. Only support "
+ "affine SCEV expressions originating in the loop we "
+ "are analysing with identical constant positive step, "
+ "equal to the count of bytes compared. Got:\n\t"
+ << *RecStepForA << "\n\t" << *RecStepForB << "\n");
+ return false;
+ // FIXME: can support BCmpTyBytes > Step.
+ // But will need to account for the extra bytes compared at the end.
+ }
+
+ SrcA = ScevForSrcA->getStart();
+ SrcB = ScevForSrcB->getStart();
+ LLVM_DEBUG(dbgs() << "Got SCEV expressions for load sources:\n\t" << *SrcA
+ << "\n\t" << *SrcB << "\n");
+
+  // The load sources must be loop-invariant and dominate the loop header.
+ if (SrcA == SE->getCouldNotCompute() || SrcB == SE->getCouldNotCompute() ||
+ !SE->isAvailableAtLoopEntry(SrcA, CurLoop) ||
+ !SE->isAvailableAtLoopEntry(SrcB, CurLoop)) {
+ LLVM_DEBUG(dbgs() << "Unsupported SCEV expressions for loads, unavaliable "
+ "prior to loop header.\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << "SCEV expressions for loads are acceptable.\n");
+
+  // bcmp / memcmp take the length argument as size_t, so let's conservatively
+  // assume that the iteration count should not be wider than that.
+ Type *CmpFuncSizeTy = DL->getIntPtrType(SE->getContext());
+
+  // For how many iterations is the loop guaranteed not to exit via LoopLatch?
+  // This is one less than the maximal number of comparisons, i.e. n - 1.
+ const SCEV *LoopExitCount =
+ SE->getExitCount(CurLoop, CurLoop->getLoopLatch());
+ LLVM_DEBUG(dbgs() << "Got SCEV expression for loop latch exit count: "
+ << *LoopExitCount << "\n");
+  // The exit count, similarly, must be loop-invariant and dominate the loop
+  // header.
+ if (LoopExitCount == SE->getCouldNotCompute() ||
+ !LoopExitCount->getType()->isIntOrPtrTy() ||
+ LoopExitCount->getType()->getScalarSizeInBits() >
+ CmpFuncSizeTy->getScalarSizeInBits() ||
+ !SE->isAvailableAtLoopEntry(LoopExitCount, CurLoop)) {
+ LLVM_DEBUG(dbgs() << "Unsupported SCEV expression for loop latch exit.\n");
+ return false;
+ }
+
+ // LoopExitCount is always one less than the actual count of iterations.
+  // Do this before the cast; otherwise we will be stuck with 1 + zext(-1 + n).
+ Iterations = SE->getAddExpr(
+ LoopExitCount, SE->getOne(LoopExitCount->getType()), SCEV::FlagNUW);
+ assert(Iterations != SE->getCouldNotCompute() &&
+ "Shouldn't fail to increment by one.");
+
+ LLVM_DEBUG(dbgs() << "Computed iteration count: " << *Iterations << "\n");
+ return true;
+}
+
+/// Return true iff the bcmp idiom is detected in the loop.
+///
+/// Additionally:
+/// 1) \p BCmpInst is set to the root byte-comparison instruction.
+/// 2) \p LatchCmpInst is set to the comparison that controls the latch.
+/// 3) \p LoadA is set to the first LoadInst.
+/// 4) \p LoadB is set to the second LoadInst.
+/// 5) \p SrcA is set to the first source location that is being compared.
+/// 6) \p SrcB is set to the second source location that is being compared.
+/// 7) \p NBytes is set to the number of bytes to compare.
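+///
+/// A typical source-level shape of the recognized idiom (illustrative only):
+///   bool eq = true;
+///   for (size_t i = 0; i != n; ++i)
+///     if (a[i] != b[i]) { eq = false; break; }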
+bool LoopIdiomRecognize::detectBCmpIdiom(ICmpInst *&BCmpInst,
+ CmpInst *&LatchCmpInst,
+ LoadInst *&LoadA, LoadInst *&LoadB,
+ const SCEV *&SrcA, const SCEV *&SrcB,
+ const SCEV *&NBytes) const {
+ LLVM_DEBUG(dbgs() << "Recognizing bcmp idiom\n");
+
+ // Give up if the loop is not in normal form, or has more than 2 blocks.
+ if (!CurLoop->isLoopSimplifyForm() || CurLoop->getNumBlocks() > 2) {
+ LLVM_DEBUG(dbgs() << "Basic loop structure unrecognized.\n");
+ return false;
+ }
+ LLVM_DEBUG(dbgs() << "Recognized basic loop structure.\n");
+
+ CmpLoopStructure CmpLoop;
+ if (!matchBCmpLoopStructure(CmpLoop))
+ return false;
+
+ CmpOfLoads CmpOfLoads;
+ if (!matchBCmpOfLoads(CmpLoop.BCmpValue, CmpOfLoads))
+ return false;
+
+ if (!recognizeBCmpLoopControlFlow(CmpOfLoads, CmpLoop))
+ return false;
+
+ BCmpInst = cast<ICmpInst>(CmpLoop.BCmpValue); // FIXME: is there no
+ LatchCmpInst = cast<CmpInst>(CmpLoop.LatchCmpValue); // way to combine
+ LoadA = cast<LoadInst>(CmpOfLoads.LoadA); // these cast with
+ LoadB = cast<LoadInst>(CmpOfLoads.LoadB); // m_Value() matcher?
+
+ Type *BCmpValTy = BCmpInst->getOperand(0)->getType();
+ LLVMContext &Context = BCmpValTy->getContext();
+ uint64_t BCmpTyBits = DL->getTypeSizeInBits(BCmpValTy);
+ static constexpr uint64_t ByteTyBits = 8;
+
+ LLVM_DEBUG(dbgs() << "Got comparison between values of type " << *BCmpValTy
+ << " of size " << BCmpTyBits
+ << " bits (while byte = " << ByteTyBits << " bits).\n");
+  // The minimal unit of work for bcmp()/memcmp() is a byte. Therefore we must
+  // check that we are dealing with a multiple of a byte here.
+ if (BCmpTyBits % ByteTyBits != 0) {
+ LLVM_DEBUG(dbgs() << "Value size is not a multiple of byte.\n");
+ return false;
+    // FIXME: could still be done under a run-time check that the total bit
+    // count is a multiple of a byte, I guess? Or handle the remainder
+    // separately?
+ }
+
+ // Each comparison is done on this many bytes.
+ uint64_t BCmpTyBytes = BCmpTyBits / ByteTyBits;
+ LLVM_DEBUG(dbgs() << "Size is exactly " << BCmpTyBytes
+ << " bytes, eligible for bcmp conversion.\n");
+
+ const SCEV *Iterations;
+ if (!recognizeBCmpLoopSCEV(BCmpTyBytes, CmpOfLoads, SrcA, SrcB, Iterations))
+ return false;
+
+  // bcmp / memcmp take the length argument as size_t, so do the promotion now.
+ Type *CmpFuncSizeTy = DL->getIntPtrType(Context);
+ Iterations = SE->getNoopOrZeroExtend(Iterations, CmpFuncSizeTy);
+ assert(Iterations != SE->getCouldNotCompute() && "Promotion failed.");
+  // Note that this did not do a ptrtoint cast; we will need to do it manually.
+
+  // We will be comparing *bytes*, not BCmpTy, so we need to recalculate the
+  // size. It's a multiplication, and it *could* overflow. But for it to
+  // overflow we'd have to compare more bytes than can be represented by
+  // size_t, yet allocation functions also take size_t. So how would you
+  // produce such a buffer?
+ // FIXME: we likely need to actually check that we know this won't overflow,
+ // via llvm::computeOverflowForUnsignedMul().
+ NBytes = SE->getMulExpr(
+ Iterations, SE->getConstant(CmpFuncSizeTy, BCmpTyBytes), SCEV::FlagNUW);
+ assert(NBytes != SE->getCouldNotCompute() &&
+ "Shouldn't fail to increment by one.");
+
+ LLVM_DEBUG(dbgs() << "Computed total byte count: " << *NBytes << "\n");
+
+ if (LoadA->getPointerAddressSpace() != LoadB->getPointerAddressSpace() ||
+ LoadA->getPointerAddressSpace() != 0 || !LoadA->isSimple() ||
+ !LoadB->isSimple()) {
+ StringLiteral L("Unsupported loads in idiom - only support identical, "
+ "simple loads from address space 0.\n");
+ LLVM_DEBUG(dbgs() << L);
+ ORE.emit([&]() {
+ return OptimizationRemarkMissed(DEBUG_TYPE, "BCmpIdiomUnsupportedLoads",
+ BCmpInst->getDebugLoc(),
+ CurLoop->getHeader())
+ << L;
+ });
+ return false; // FIXME: support non-simple loads.
+ }
+
+ LLVM_DEBUG(dbgs() << "Recognized bcmp idiom\n");
+ ORE.emit([&]() {
+ return OptimizationRemarkAnalysis(DEBUG_TYPE, "RecognizedBCmpIdiom",
+ CurLoop->getStartLoc(),
+ CurLoop->getHeader())
+ << "Loop recognized as a bcmp idiom";
+ });
+
+ return true;
+}
+
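+/// Rewire the CFG after the loop body has been replaced by a single
+/// memcmp/bcmp call: the original loop is deleted, and \p ComparedEqual
+/// selects between two new blocks that feed the original exit successors.
+/// The numbered (0/6)..(6/6) sketches below show each step.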
+BasicBlock *
+LoopIdiomRecognize::transformBCmpControlFlow(ICmpInst *ComparedEqual) {
+ LLVM_DEBUG(dbgs() << "Transforming control-flow.\n");
+ SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
+
+ BasicBlock *PreheaderBB = CurLoop->getLoopPreheader();
+ BasicBlock *HeaderBB = CurLoop->getHeader();
+ BasicBlock *LoopLatchBB = CurLoop->getLoopLatch();
+ SmallString<32> LoopName = CurLoop->getName();
+ Function *Func = PreheaderBB->getParent();
+ LLVMContext &Context = Func->getContext();
+
+ // Before doing anything, drop SCEV info.
+ SE->forgetLoop(CurLoop);
+
+ // Here we start with: (0/6)
+ // PreheaderBB: <preheader> ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // br label %LoopHeaderBB
+ // LoopHeaderBB: <header,exiting> ; preds = %PreheaderBB,%LoopLatchBB
+ // <...>
+ // br i1 %<...>, label %LoopLatchBB, label %Successor0BB
+ // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB
+ // <...>
+ // br i1 %<...>, label %Successor1BB, label %LoopHeaderBB
+ // Successor0BB: <exit> ; preds = %LoopHeaderBB
+ // %S0PHI = phi <...> [ <...>, %LoopHeaderBB ]
+ // <...>
+ // Successor1BB: <exit> ; preds = %LoopLatchBB
+ // %S1PHI = phi <...> [ <...>, %LoopLatchBB ]
+ // <...>
+ //
+ // Successor0 and Successor1 may or may not be the same basic block.
+
+ // Decouple the edge between loop preheader basic block and loop header basic
+ // block. Thus the loop has become unreachable.
+ assert(cast<BranchInst>(PreheaderBB->getTerminator())->isUnconditional() &&
+ PreheaderBB->getTerminator()->getSuccessor(0) == HeaderBB &&
+ "Preheader bb must end with an unconditional branch to header bb.");
+ PreheaderBB->getTerminator()->eraseFromParent();
+ DTUpdates.push_back({DominatorTree::Delete, PreheaderBB, HeaderBB});
+
+ // Create a new preheader basic block before loop header basic block.
+ auto *PhonyPreheaderBB = BasicBlock::Create(
+ Context, LoopName + ".phonypreheaderbb", Func, HeaderBB);
+ // And insert an unconditional branch from phony preheader basic block to
+ // loop header basic block.
+ IRBuilder<>(PhonyPreheaderBB).CreateBr(HeaderBB);
+ DTUpdates.push_back({DominatorTree::Insert, PhonyPreheaderBB, HeaderBB});
+
+ // Create a *single* new empty block that we will substitute as a
+ // successor basic block for the loop's exits. This one is temporary.
+ // Much like phony preheader basic block, it is not connected.
+ auto *PhonySuccessorBB =
+ BasicBlock::Create(Context, LoopName + ".phonysuccessorbb", Func,
+ LoopLatchBB->getNextNode());
+ // That block must have *some* non-PHI instruction, or else deleteDeadLoop()
+  // will mess up the cleanup of dbginfo, and the verifier will complain.
+ IRBuilder<>(PhonySuccessorBB).CreateUnreachable();
+
+ // Create two new empty blocks that we will use to preserve the original
+  // loop exit control-flow, and preserve the incoming values in the PHI nodes
+  // in the loop's successor exit blocks. These will live on.
+ auto *ComparedUnequalBB =
+ BasicBlock::Create(Context, ComparedEqual->getName() + ".unequalbb", Func,
+ PhonySuccessorBB->getNextNode());
+ auto *ComparedEqualBB =
+ BasicBlock::Create(Context, ComparedEqual->getName() + ".equalbb", Func,
+ PhonySuccessorBB->getNextNode());
+
+ // By now we have: (1/6)
+ // PreheaderBB: ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // [no terminator instruction!]
+ // PhonyPreheaderBB: <preheader> ; No preds, UNREACHABLE!
+ // br label %LoopHeaderBB
+ // LoopHeaderBB: <header,exiting> ; preds = %PhonyPreheaderBB, %LoopLatchBB
+ // <...>
+ // br i1 %<...>, label %LoopLatchBB, label %Successor0BB
+ // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB
+ // <...>
+ // br i1 %<...>, label %Successor1BB, label %LoopHeaderBB
+ // PhonySuccessorBB: ; No preds, UNREACHABLE!
+ // unreachable
+ // EqualBB: ; No preds, UNREACHABLE!
+ // [no terminator instruction!]
+ // UnequalBB: ; No preds, UNREACHABLE!
+ // [no terminator instruction!]
+ // Successor0BB: <exit> ; preds = %LoopHeaderBB
+ // %S0PHI = phi <...> [ <...>, %LoopHeaderBB ]
+ // <...>
+ // Successor1BB: <exit> ; preds = %LoopLatchBB
+ // %S1PHI = phi <...> [ <...>, %LoopLatchBB ]
+ // <...>
+
+ // What is the mapping/replacement basic block for exiting out of the loop
+  // from either of the old loop's basic blocks?
+ auto GetReplacementBB = [this, ComparedEqualBB,
+ ComparedUnequalBB](const BasicBlock *OldBB) {
+ assert(CurLoop->contains(OldBB) && "Only for loop's basic blocks.");
+ if (OldBB == CurLoop->getLoopLatch()) // "all elements compared equal".
+ return ComparedEqualBB;
+ if (OldBB == CurLoop->getHeader()) // "element compared unequal".
+ return ComparedUnequalBB;
+ llvm_unreachable("Only had two basic blocks in loop.");
+ };
+
+ // What are the exits out of this loop?
+ SmallVector<Loop::Edge, 2> LoopExitEdges;
+ CurLoop->getExitEdges(LoopExitEdges);
+  assert(LoopExitEdges.size() == 2 && "Should have only two exit edges.");
+
+ // Populate new basic blocks, update the exiting control-flow, PHI nodes.
+ for (const Loop::Edge &Edge : LoopExitEdges) {
+ auto *OldLoopBB = const_cast<BasicBlock *>(Edge.first);
+ auto *SuccessorBB = const_cast<BasicBlock *>(Edge.second);
+ assert(CurLoop->contains(OldLoopBB) && !CurLoop->contains(SuccessorBB) &&
+ "Unexpected edge.");
+
+ // If we would exit the loop from this loop's basic block,
+ // what semantically would that mean? Did comparison succeed or fail?
+ BasicBlock *NewBB = GetReplacementBB(OldLoopBB);
+ assert(NewBB->empty() && "Should not get same new basic block here twice.");
+ IRBuilder<> Builder(NewBB);
+ Builder.SetCurrentDebugLocation(OldLoopBB->getTerminator()->getDebugLoc());
+ Builder.CreateBr(SuccessorBB);
+ DTUpdates.push_back({DominatorTree::Insert, NewBB, SuccessorBB});
+    // Also, be *REALLY* careful with PHI nodes in the successor basic block:
+    // update them to receive the same input value, but from the new basic
+    // block instead of the current loop's basic block.
+ SuccessorBB->replacePhiUsesWith(OldLoopBB, NewBB);
+    // Also, change the loop control-flow. This loop's basic block shall no
+    // longer exit from the loop to its original successor basic block, but to
+    // our new phony successor basic block. Note that the new successor will be
+    // the unique exit.
+ OldLoopBB->getTerminator()->replaceSuccessorWith(SuccessorBB,
+ PhonySuccessorBB);
+ DTUpdates.push_back({DominatorTree::Delete, OldLoopBB, SuccessorBB});
+ DTUpdates.push_back({DominatorTree::Insert, OldLoopBB, PhonySuccessorBB});
+ }
+
+ // Inform DomTree about edge changes. Note that LoopInfo is still out-of-date.
+ assert(DTUpdates.size() == 8 && "Update count prediction failed.");
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+ DTU.applyUpdates(DTUpdates);
+ DTUpdates.clear();
+
+ // By now we have: (2/6)
+ // PreheaderBB: ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // [no terminator instruction!]
+ // PhonyPreheaderBB: <preheader> ; No preds, UNREACHABLE!
+ // br label %LoopHeaderBB
+ // LoopHeaderBB: <header,exiting> ; preds = %PhonyPreheaderBB, %LoopLatchBB
+ // <...>
+ // br i1 %<...>, label %LoopLatchBB, label %PhonySuccessorBB
+ // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB
+ // <...>
+ // br i1 %<...>, label %PhonySuccessorBB, label %LoopHeaderBB
+ // PhonySuccessorBB: <uniq. exit> ; preds = %LoopHeaderBB, %LoopLatchBB
+ // unreachable
+ // EqualBB: ; No preds, UNREACHABLE!
+ // br label %Successor1BB
+ // UnequalBB: ; No preds, UNREACHABLE!
+ // br label %Successor0BB
+ // Successor0BB: ; preds = %UnequalBB
+ // %S0PHI = phi <...> [ <...>, %UnequalBB ]
+ // <...>
+ // Successor1BB: ; preds = %EqualBB
+ // %S0PHI = phi <...> [ <...>, %EqualBB ]
+ // <...>
+
+  // *Finally*, zap the original loop. Record its parent loop though.
+ Loop *ParentLoop = CurLoop->getParentLoop();
+ LLVM_DEBUG(dbgs() << "Deleting old loop.\n");
+ LoopDeleter.markLoopAsDeleted(CurLoop); // Mark as deleted *BEFORE* deleting!
+ deleteDeadLoop(CurLoop, DT, SE, LI); // And actually delete the loop.
+ CurLoop = nullptr;
+
+ // By now we have: (3/6)
+ // PreheaderBB: ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // [no terminator instruction!]
+ // PhonyPreheaderBB: ; No preds, UNREACHABLE!
+ // br label %PhonySuccessorBB
+ // PhonySuccessorBB: ; preds = %PhonyPreheaderBB
+ // unreachable
+ // EqualBB: ; No preds, UNREACHABLE!
+ // br label %Successor1BB
+ // UnequalBB: ; No preds, UNREACHABLE!
+ // br label %Successor0BB
+ // Successor0BB: ; preds = %UnequalBB
+ // %S0PHI = phi <...> [ <...>, %UnequalBB ]
+ // <...>
+ // Successor1BB: ; preds = %EqualBB
+ // %S0PHI = phi <...> [ <...>, %EqualBB ]
+ // <...>
+
+ // Now, actually restore the CFG.
+
+  // Insert an unconditional branch from the actual preheader basic block to
+  // the phony preheader basic block.
+ IRBuilder<>(PreheaderBB).CreateBr(PhonyPreheaderBB);
+ DTUpdates.push_back({DominatorTree::Insert, PhonyPreheaderBB, HeaderBB});
+ // Insert proper conditional branch from phony successor basic block to the
+ // "dispatch" basic blocks, which were used to preserve incoming values in
+ // original loop's successor basic blocks.
+ assert(isa<UnreachableInst>(PhonySuccessorBB->getTerminator()) &&
+ "Yep, that's the one we created to keep deleteDeadLoop() happy.");
+ PhonySuccessorBB->getTerminator()->eraseFromParent();
+ {
+ IRBuilder<> Builder(PhonySuccessorBB);
+ Builder.SetCurrentDebugLocation(ComparedEqual->getDebugLoc());
+ Builder.CreateCondBr(ComparedEqual, ComparedEqualBB, ComparedUnequalBB);
+ }
+ DTUpdates.push_back(
+ {DominatorTree::Insert, PhonySuccessorBB, ComparedEqualBB});
+ DTUpdates.push_back(
+ {DominatorTree::Insert, PhonySuccessorBB, ComparedUnequalBB});
+
+ BasicBlock *DispatchBB = PhonySuccessorBB;
+ DispatchBB->setName(LoopName + ".bcmpdispatchbb");
+
+ assert(DTUpdates.size() == 3 && "Update count prediction failed.");
+ DTU.applyUpdates(DTUpdates);
+ DTUpdates.clear();
+
+ // By now we have: (4/6)
+ // PreheaderBB: ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // br label %PhonyPreheaderBB
+ // PhonyPreheaderBB: ; preds = %PreheaderBB
+ // br label %DispatchBB
+ // DispatchBB: ; preds = %PhonyPreheaderBB
+ // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB
+ // EqualBB: ; preds = %DispatchBB
+ // br label %Successor1BB
+ // UnequalBB: ; preds = %DispatchBB
+ // br label %Successor0BB
+ // Successor0BB: ; preds = %UnequalBB
+ // %S0PHI = phi <...> [ <...>, %UnequalBB ]
+ // <...>
+ // Successor1BB: ; preds = %EqualBB
+ // %S0PHI = phi <...> [ <...>, %EqualBB ]
+ // <...>
+
+ // The basic CFG has been restored! Now let's merge redundant basic blocks.
+
+  // Merge the phony successor basic block into its only predecessor,
+  // the phony preheader basic block. It is completely redundant.
+ MergeBasicBlockIntoOnlyPred(DispatchBB, &DTU);
+
+ // By now we have: (5/6)
+ // PreheaderBB: ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // br label %DispatchBB
+ // DispatchBB: ; preds = %PreheaderBB
+ // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB
+ // EqualBB: ; preds = %DispatchBB
+ // br label %Successor1BB
+ // UnequalBB: ; preds = %DispatchBB
+ // br label %Successor0BB
+ // Successor0BB: ; preds = %UnequalBB
+ // %S0PHI = phi <...> [ <...>, %UnequalBB ]
+ // <...>
+ // Successor1BB: ; preds = %EqualBB
+ // %S0PHI = phi <...> [ <...>, %EqualBB ]
+ // <...>
+
+ // Was this loop nested?
+ if (!ParentLoop) {
+    // If the loop was *NOT* nested, then let's also merge the phony successor
+    // basic block into its only predecessor, the preheader basic block.
+ // Also, here we need to update LoopInfo.
+ LI->removeBlock(PreheaderBB);
+ MergeBasicBlockIntoOnlyPred(DispatchBB, &DTU);
+
+ // By now we have: (6/6)
+ // DispatchBB: ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB
+ // EqualBB: ; preds = %DispatchBB
+ // br label %Successor1BB
+ // UnequalBB: ; preds = %DispatchBB
+ // br label %Successor0BB
+ // Successor0BB: ; preds = %UnequalBB
+ // %S0PHI = phi <...> [ <...>, %UnequalBB ]
+ // <...>
+ // Successor1BB: ; preds = %EqualBB
+ // %S0PHI = phi <...> [ <...>, %EqualBB ]
+ // <...>
+
+ return DispatchBB;
+ }
+
+ // Otherwise, we need to "preserve" the LoopSimplify form of the deleted loop.
+ // To achieve that, we shall keep the preheader basic block (mainly so that
+ // the loop header block will be guaranteed to have a predecessor outside of
+  // the loop), and create a phony loop with these three new basic blocks.
+ Loop *PhonyLoop = LI->AllocateLoop();
+ ParentLoop->addChildLoop(PhonyLoop);
+ PhonyLoop->addBasicBlockToLoop(DispatchBB, *LI);
+ PhonyLoop->addBasicBlockToLoop(ComparedEqualBB, *LI);
+ PhonyLoop->addBasicBlockToLoop(ComparedUnequalBB, *LI);
+
+  // But we only have a preheader basic block, a header basic block and
+  // two exiting basic blocks. For a proper loop we also need a backedge from
+  // a non-header basic block to the header basic block.
+ // Let's just add a never-taken branch from both of the exiting basic blocks.
+ for (BasicBlock *BB : {ComparedEqualBB, ComparedUnequalBB}) {
+ BranchInst *OldTerminator = cast<BranchInst>(BB->getTerminator());
+ assert(OldTerminator->isUnconditional() && "That's the one we created.");
+ BasicBlock *SuccessorBB = OldTerminator->getSuccessor(0);
+
+ IRBuilder<> Builder(OldTerminator);
+ Builder.SetCurrentDebugLocation(OldTerminator->getDebugLoc());
+ Builder.CreateCondBr(ConstantInt::getTrue(Context), SuccessorBB,
+ DispatchBB);
+ OldTerminator->eraseFromParent();
+ // Yes, the backedge will never be taken. The control-flow is redundant.
+    // If it can be simplified further, other passes will take care of it.
+ DTUpdates.push_back({DominatorTree::Delete, BB, SuccessorBB});
+ DTUpdates.push_back({DominatorTree::Insert, BB, SuccessorBB});
+ DTUpdates.push_back({DominatorTree::Insert, BB, DispatchBB});
+ }
+ assert(DTUpdates.size() == 6 && "Update count prediction failed.");
+ DTU.applyUpdates(DTUpdates);
+ DTUpdates.clear();
+
+ // By now we have: (6/6)
+ // PreheaderBB: <preheader> ; preds = ???
+ // <...>
+ // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes)
+ // %ComparedEqual = icmp eq <...> %memcmp, 0
+ // br label %BCmpDispatchBB
+ // BCmpDispatchBB: <header> ; preds = %PreheaderBB
+ // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB
+ // EqualBB: <latch,exiting> ; preds = %BCmpDispatchBB
+ // br i1 %true, label %Successor1BB, label %BCmpDispatchBB
+ // UnequalBB: <latch,exiting> ; preds = %BCmpDispatchBB
+ // br i1 %true, label %Successor0BB, label %BCmpDispatchBB
+ // Successor0BB: ; preds = %UnequalBB
+ // %S0PHI = phi <...> [ <...>, %UnequalBB ]
+ // <...>
+ // Successor1BB: ; preds = %EqualBB
+ // %S0PHI = phi <...> [ <...>, %EqualBB ]
+ // <...>
+
+ // Finally fully DONE!
+ return DispatchBB;
+}
+
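+/// Materialize the comparison in the preheader: expand \p SrcA, \p SrcB and
+/// \p NBytes with SCEVExpander, emit a bcmp/memcmp call plus an equality
+/// icmp of its result against zero, and hand that off to
+/// transformBCmpControlFlow().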
+void LoopIdiomRecognize::transformLoopToBCmp(ICmpInst *BCmpInst,
+ CmpInst *LatchCmpInst,
+ LoadInst *LoadA, LoadInst *LoadB,
+ const SCEV *SrcA, const SCEV *SrcB,
+ const SCEV *NBytes) {
+  // We will be inserting before the terminator instruction of the preheader
+  // block.
+ IRBuilder<> Builder(CurLoop->getLoopPreheader()->getTerminator());
+
+ LLVM_DEBUG(dbgs() << "Transforming bcmp loop idiom into a call.\n");
+ LLVM_DEBUG(dbgs() << "Emitting new instructions.\n");
+
+  // Expand the SCEV expressions for both sources to compare, and produce a
+  // value for the byte length (beware of Iterations potentially being a
+  // pointer, and account for the element size being BCmpTyBytes bytes, which
+  // may not be 1 byte).
+ Value *PtrA, *PtrB, *Len;
+ {
+ SCEVExpander SExp(*SE, *DL, "LoopToBCmp");
+ SExp.setInsertPoint(&*Builder.GetInsertPoint());
+
+ auto HandlePtr = [&SExp](LoadInst *Load, const SCEV *Src) {
+ SExp.SetCurrentDebugLocation(DebugLoc());
+      // If the pointer operand of the original load had a dbgloc, use it.
+ if (const auto *I = dyn_cast<Instruction>(Load->getPointerOperand()))
+ SExp.SetCurrentDebugLocation(I->getDebugLoc());
+ return SExp.expandCodeFor(Src);
+ };
+ PtrA = HandlePtr(LoadA, SrcA);
+ PtrB = HandlePtr(LoadB, SrcB);
+
+    // For the len calculation, let's use the dbgloc of the loop's latch
+    // condition.
+ Builder.SetCurrentDebugLocation(LatchCmpInst->getDebugLoc());
+ SExp.SetCurrentDebugLocation(LatchCmpInst->getDebugLoc());
+ Len = SExp.expandCodeFor(NBytes);
+
+ Type *CmpFuncSizeTy = DL->getIntPtrType(Builder.getContext());
+ assert(SE->getTypeSizeInBits(Len->getType()) ==
+ DL->getTypeSizeInBits(CmpFuncSizeTy) &&
+ "Len should already have the correct size.");
+
+ // Make sure that iteration count is a number, insert ptrtoint cast if not.
+ if (Len->getType()->isPointerTy())
+ Len = Builder.CreatePtrToInt(Len, CmpFuncSizeTy);
+ assert(Len->getType() == CmpFuncSizeTy && "Should have correct type now.");
+
+ Len->setName(Len->getName() + ".bytecount");
+
+    // There is no legality check needed. We want to check that the memory
+    // regions [PtrA, PtrA+Len) and [PtrB, PtrB+Len) are fully identical.
+    // For them to be fully equal, they must match bit-by-bit. And likewise,
+    // for them to *NOT* be fully equal, they only have to differ by a single
+    // bit. The step of comparison (bits compared at once) simply does not
+    // matter.
+ }
+
+  // For the rest of the new instructions, the dbgloc should point at the
+  // value cmp.
+ Builder.SetCurrentDebugLocation(BCmpInst->getDebugLoc());
+
+ // Emit the comparison itself.
+ auto *CmpCall =
+ cast<CallInst>(HasBCmp ? emitBCmp(PtrA, PtrB, Len, Builder, *DL, TLI)
+ : emitMemCmp(PtrA, PtrB, Len, Builder, *DL, TLI));
+ // FIXME: add {B,Mem}CmpInst with MemoryCompareInst
+ // (based on MemIntrinsicBase) as base?
+ // FIXME: propagate metadata from loads? (alignments, AS, TBAA, ...)
+
+ // {b,mem}cmp returned 0 if they were equal, or non-zero if not equal.
+ auto *ComparedEqual = cast<ICmpInst>(Builder.CreateICmpEQ(
+ CmpCall, ConstantInt::get(CmpCall->getType(), 0),
+ PtrA->getName() + ".vs." + PtrB->getName() + ".eqcmp"));
+
+ BasicBlock *BB = transformBCmpControlFlow(ComparedEqual);
+ Builder.ClearInsertionPoint();
+
+ // We're done.
+ LLVM_DEBUG(dbgs() << "Transformed loop bcmp idiom into a call.\n");
+ ORE.emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "TransformedBCmpIdiomToCall",
+ CmpCall->getDebugLoc(), BB)
+ << "Transformed bcmp idiom into a call to "
+ << ore::NV("NewFunction", CmpCall->getCalledFunction())
+ << "() function";
+ });
+ ++NumBCmp;
+}
+
+/// Recognizes a bcmp idiom in a non-countable loop.
+///
+/// If detected, transforms the relevant code to issue the bcmp (or memcmp)
+/// intrinsic function call, and returns true; otherwise, returns false.
+bool LoopIdiomRecognize::recognizeBCmp() {
+ if (!HasMemCmp && !HasBCmp)
+ return false;
+
+ ICmpInst *BCmpInst;
+ CmpInst *LatchCmpInst;
+ LoadInst *LoadA, *LoadB;
+ const SCEV *SrcA, *SrcB, *NBytes;
+ if (!detectBCmpIdiom(BCmpInst, LatchCmpInst, LoadA, LoadB, SrcA, SrcB,
+ NBytes)) {
+ LLVM_DEBUG(dbgs() << "bcmp idiom recognition failed.\n");
+ return false;
+ }
+
+ transformLoopToBCmp(BCmpInst, LatchCmpInst, LoadA, LoadB, SrcA, SrcB, NBytes);
+ return true;
+}
diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp
index 31191b52895c..368b9d4e8df1 100644
--- a/lib/Transforms/Scalar/LoopInstSimplify.cpp
+++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp
@@ -192,7 +192,8 @@ public:
getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
*L->getHeader()->getParent());
const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+ *L->getHeader()->getParent());
MemorySSA *MSSA = nullptr;
Optional<MemorySSAUpdater> MSSAU;
if (EnableMSSALoopDependency) {
@@ -233,7 +234,7 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
auto PA = getLoopPassPreservedAnalyses();
PA.preserveSet<CFGAnalyses>();
- if (EnableMSSALoopDependency)
+ if (AR.MSSA)
PA.preserve<MemorySSAAnalysis>();
return PA;
}
diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp
index 9a42365adc1b..1af4b21b432e 100644
--- a/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -410,8 +410,6 @@ public:
void removeChildLoop(Loop *OuterLoop, Loop *InnerLoop);
private:
- void splitInnerLoopLatch(Instruction *);
- void splitInnerLoopHeader();
bool adjustLoopLinks();
void adjustLoopPreheaders();
bool adjustLoopBranches();
@@ -1226,7 +1224,7 @@ bool LoopInterchangeTransform::transform() {
if (InnerLoop->getSubLoops().empty()) {
BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
- LLVM_DEBUG(dbgs() << "Calling Split Inner Loop\n");
+ LLVM_DEBUG(dbgs() << "Splitting the inner loop latch\n");
PHINode *InductionPHI = getInductionVariable(InnerLoop, SE);
if (!InductionPHI) {
LLVM_DEBUG(dbgs() << "Failed to find the point to split loop latch \n");
@@ -1242,11 +1240,55 @@ bool LoopInterchangeTransform::transform() {
if (&InductionPHI->getParent()->front() != InductionPHI)
InductionPHI->moveBefore(&InductionPHI->getParent()->front());
- // Split at the place were the induction variable is
- // incremented/decremented.
- // TODO: This splitting logic may not work always. Fix this.
- splitInnerLoopLatch(InnerIndexVar);
- LLVM_DEBUG(dbgs() << "splitInnerLoopLatch done\n");
+    // Create a new latch block for the inner loop. We split at the
+    // current latch's terminator and then move the exit condition, together
+    // with every operand feeding it that is neither loop-invariant nor the
+    // induction PHI, into the new latch block.
+ BasicBlock *NewLatch =
+ SplitBlock(InnerLoop->getLoopLatch(),
+ InnerLoop->getLoopLatch()->getTerminator(), DT, LI);
+
+ SmallSetVector<Instruction *, 4> WorkList;
+ unsigned i = 0;
+ auto MoveInstructions = [&i, &WorkList, this, InductionPHI, NewLatch]() {
+ for (; i < WorkList.size(); i++) {
+      // Duplicate the instruction and move the copy into the new latch.
+      // Redirect the uses that must now refer to the copy.
+ Instruction *NewI = WorkList[i]->clone();
+ NewI->insertBefore(NewLatch->getFirstNonPHI());
+ assert(!NewI->mayHaveSideEffects() &&
+ "Moving instructions with side-effects may change behavior of "
+ "the loop nest!");
+ for (auto UI = WorkList[i]->use_begin(), UE = WorkList[i]->use_end();
+ UI != UE;) {
+ Use &U = *UI++;
+ Instruction *UserI = cast<Instruction>(U.getUser());
+ if (!InnerLoop->contains(UserI->getParent()) ||
+ UserI->getParent() == NewLatch || UserI == InductionPHI)
+ U.set(NewI);
+ }
+      // Add the moved instruction's operands to the worklist, skipping any
+      // that are outside the inner loop or are the induction PHI.
+ for (Value *Op : WorkList[i]->operands()) {
+ Instruction *OpI = dyn_cast<Instruction>(Op);
+ if (!OpI ||
+ this->LI->getLoopFor(OpI->getParent()) != this->InnerLoop ||
+ OpI == InductionPHI)
+ continue;
+ WorkList.insert(OpI);
+ }
+ }
+ };
+
+ // FIXME: Should we interchange when we have a constant condition?
+ Instruction *CondI = dyn_cast<Instruction>(
+ cast<BranchInst>(InnerLoop->getLoopLatch()->getTerminator())
+ ->getCondition());
+ if (CondI)
+ WorkList.insert(CondI);
+ MoveInstructions();
+ WorkList.insert(cast<Instruction>(InnerIndexVar));
+ MoveInstructions();
// Splits the inner loops phi nodes out into a separate basic block.
BasicBlock *InnerLoopHeader = InnerLoop->getHeader();
@@ -1263,10 +1305,6 @@ bool LoopInterchangeTransform::transform() {
return true;
}
-void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) {
- SplitBlock(InnerLoop->getLoopLatch(), Inc, DT, LI);
-}
-
/// \brief Move all instructions except the terminator from FromBB right before
/// InsertBefore
static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) {
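
    As a rough, hypothetical illustration of the kind of loop nest LoopInterchange
    operates on (not from the patch's tests), the latch-splitting rewrite above clones
    the inner latch's exit condition, and its loop-local operands, into a freshly split
    latch block before the two loops are swapped:

    // A classic interchange candidate: swapping the 'I' and 'J' loops turns
    // the strided, column-major accesses into contiguous row-major ones.
    static void transposeAdd(int A[256][256], const int B[256][256]) {
      for (int I = 0; I < 256; ++I)     // outer loop
        for (int J = 0; J < 256; ++J)   // inner loop
          A[J][I] += B[J][I];           // strided before interchange
    }
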
diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp
index 2b3d5e0ce9b7..e8dc879a184b 100644
--- a/lib/Transforms/Scalar/LoopLoadElimination.cpp
+++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -435,7 +435,8 @@ public:
PH->getTerminator());
Value *Initial = new LoadInst(
Cand.Load->getType(), InitialPtr, "load_initial",
- /* isVolatile */ false, Cand.Load->getAlignment(), PH->getTerminator());
+ /* isVolatile */ false, MaybeAlign(Cand.Load->getAlignment()),
+ PH->getTerminator());
PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded",
&L->getHeader()->front());
diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp
index 507a1e251ca6..885c0e8f4b8b 100644
--- a/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/lib/Transforms/Scalar/LoopPredication.cpp
@@ -543,7 +543,7 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
if (AA->pointsToConstantMemory(LI->getOperand(0)) ||
- LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr)
+ LI->hasMetadata(LLVMContext::MD_invariant_load))
return true;
return false;
}
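
    The hasMetadata() change above is a pure readability cleanup; a small sketch of the
    same query in isolation (illustrative helper name, not part of the patch):

    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/LLVMContext.h"

    // hasMetadata(KindID) expresses the same check as comparing getMetadata()
    // against nullptr, just more directly.
    static bool isInvariantLoad(const llvm::LoadInst *LI) {
      return LI->hasMetadata(llvm::LLVMContext::MD_invariant_load);
    }
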
diff --git a/lib/Transforms/Scalar/LoopRerollPass.cpp b/lib/Transforms/Scalar/LoopRerollPass.cpp
index 166b57f20b43..96e2c2a3ac6b 100644
--- a/lib/Transforms/Scalar/LoopRerollPass.cpp
+++ b/lib/Transforms/Scalar/LoopRerollPass.cpp
@@ -1644,7 +1644,8 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) {
AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+ *L->getHeader()->getParent());
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
PreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp
index e009947690af..94517996df39 100644
--- a/lib/Transforms/Scalar/LoopRotation.cpp
+++ b/lib/Transforms/Scalar/LoopRotation.cpp
@@ -55,7 +55,7 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM,
AR.MSSA->verifyMemorySSA();
auto PA = getLoopPassPreservedAnalyses();
- if (EnableMSSALoopDependency)
+ if (AR.MSSA)
PA.preserve<MemorySSAAnalysis>();
return PA;
}
@@ -94,17 +94,15 @@ public:
auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
- auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
- auto *SE = SEWP ? &SEWP->getSE() : nullptr;
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
const SimplifyQuery SQ = getBestSimplifyQuery(*this, F);
Optional<MemorySSAUpdater> MSSAU;
if (EnableMSSALoopDependency) {
MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
MSSAU = MemorySSAUpdater(MSSA);
}
- return LoopRotation(L, LI, TTI, AC, DT, SE,
+ return LoopRotation(L, LI, TTI, AC, &DT, &SE,
MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, SQ,
false, MaxHeaderSize, false);
}
diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
index 046f4c8af492..299f3fc5fb19 100644
--- a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
+++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
@@ -690,7 +690,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &LPMU) {
Optional<MemorySSAUpdater> MSSAU;
- if (EnableMSSALoopDependency && AR.MSSA)
+ if (AR.MSSA)
MSSAU = MemorySSAUpdater(AR.MSSA);
bool DeleteCurrentLoop = false;
if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE,
@@ -702,7 +702,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM,
LPMU.markLoopAsDeleted(L, "loop-simplifycfg");
auto PA = getLoopPassPreservedAnalyses();
- if (EnableMSSALoopDependency)
+ if (AR.MSSA)
PA.preserve<MemorySSAAnalysis>();
return PA;
}
diff --git a/lib/Transforms/Scalar/LoopSink.cpp b/lib/Transforms/Scalar/LoopSink.cpp
index 975452e13f09..65e0dee0225a 100644
--- a/lib/Transforms/Scalar/LoopSink.cpp
+++ b/lib/Transforms/Scalar/LoopSink.cpp
@@ -230,12 +230,9 @@ static bool sinkInstruction(Loop &L, Instruction &I,
IC->setName(I.getName());
IC->insertBefore(&*N->getFirstInsertionPt());
// Replaces uses of I with IC in N
- for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); UI != UE;) {
- Use &U = *UI++;
- auto *I = cast<Instruction>(U.getUser());
- if (I->getParent() == N)
- U.set(IC);
- }
+ I.replaceUsesWithIf(IC, [N](Use &U) {
+ return cast<Instruction>(U.getUser())->getParent() == N;
+ });
// Replaces uses of I with IC in blocks dominated by N
replaceDominatedUsesWith(&I, IC, DT, N);
LLVM_DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName()
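
    The rewrite above relies on Value::replaceUsesWithIf, which takes a predicate over
    each Use; a minimal sketch of the same pattern (helper name is illustrative):

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Instruction.h"
    #include "llvm/IR/Value.h"

    // Redirect to NewV only those uses of OldV whose user lives in block BB;
    // every other use keeps referring to OldV.
    static void redirectUsesInBlock(llvm::Value *OldV, llvm::Value *NewV,
                                    llvm::BasicBlock *BB) {
      OldV->replaceUsesWithIf(NewV, [BB](llvm::Use &U) {
        return llvm::cast<llvm::Instruction>(U.getUser())->getParent() == BB;
      });
    }
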
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 59a387a186b8..7f119175c4a8 100644
--- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1386,7 +1386,9 @@ void Cost::RateFormula(const Formula &F,
// Treat every new register that exceeds TTI.getNumberOfRegisters() - 1 as
// additional instruction (at least fill).
- unsigned TTIRegNum = TTI->getNumberOfRegisters(false) - 1;
+  // TODO: Do we need to distinguish between register classes?
+ unsigned TTIRegNum = TTI->getNumberOfRegisters(
+ TTI->getRegisterClassForType(false, F.getType())) - 1;
if (C.NumRegs > TTIRegNum) {
// Cost already exceeded TTIRegNum, then only newly added register can add
// new instructions.
@@ -3165,6 +3167,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
LLVM_DEBUG(dbgs() << "Concealed chain head: " << *Head.UserInst << "\n");
return;
}
+ assert(IVSrc && "Failed to find IV chain source");
LLVM_DEBUG(dbgs() << "Generate chain at: " << *IVSrc << "\n");
Type *IVTy = IVSrc->getType();
@@ -3265,12 +3268,12 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
// requirements for both N and i at the same time. Limiting this code to
// equality icmps is not a problem because all interesting loops use
// equality icmps, thanks to IndVarSimplify.
- if (ICmpInst *CI = dyn_cast<ICmpInst>(UserInst))
+ if (ICmpInst *CI = dyn_cast<ICmpInst>(UserInst)) {
+ // If CI can be saved in some target, like replaced inside hardware loop
+ // in PowerPC, no need to generate initial formulae for it.
+ if (SaveCmp && CI == dyn_cast<ICmpInst>(ExitBranch->getCondition()))
+ continue;
if (CI->isEquality()) {
- // If CI can be saved in some target, like replaced inside hardware loop
- // in PowerPC, no need to generate initial formulae for it.
- if (SaveCmp && CI == dyn_cast<ICmpInst>(ExitBranch->getCondition()))
- continue;
// Swap the operands if needed to put the OperandValToReplace on the
// left, for consistency.
Value *NV = CI->getOperand(1);
@@ -3298,6 +3301,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
Factors.insert(-(uint64_t)Factors[i]);
Factors.insert(-1);
}
+ }
// Get or create an LSRUse.
std::pair<size_t, int64_t> P = getUse(S, Kind, AccessTy);
@@ -4834,6 +4838,7 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() {
}
}
}
+ assert(Best && "Failed to find best LSRUse candidate");
LLVM_DEBUG(dbgs() << "Narrowing the search space by assuming " << *Best
<< " will yield profitable reuse.\n");
@@ -5740,7 +5745,8 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
*L->getHeader()->getParent());
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
*L->getHeader()->getParent());
- auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+ *L->getHeader()->getParent());
return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, LibInfo);
}
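
    The register-pressure tweak above queries the register class that actually backs
    the formula's type; a short sketch of that pairing (helper name is illustrative):

    #include "llvm/Analysis/TargetTransformInfo.h"

    // Ask how many registers exist in the class that would hold a value of
    // type Ty, instead of the flat scalar count used before.
    static unsigned availableRegsForType(const llvm::TargetTransformInfo &TTI,
                                         llvm::Type *Ty) {
      unsigned ClassID = TTI.getRegisterClassForType(/*Vector=*/false, Ty);
      return TTI.getNumberOfRegisters(ClassID);
    }
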
diff --git a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
index 86891eb451bb..8d88be420314 100644
--- a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
+++ b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
@@ -166,7 +166,7 @@ static bool computeUnrollAndJamCount(
bool UseUpperBound = false;
bool ExplicitUnroll = computeUnrollCount(
L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount,
- OuterTripMultiple, OuterLoopSize, UP, UseUpperBound);
+ /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, UseUpperBound);
if (ExplicitUnroll || UseUpperBound) {
// If the user explicitly set the loop as unrolled, dont UnJ it. Leave it
// for the unroller instead.
@@ -293,9 +293,9 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
if (Latch != Exit || SubLoopLatch != SubLoopExit)
return LoopUnrollResult::Unmodified;
- TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
- L, SE, TTI, nullptr, nullptr, OptLevel,
- None, None, None, None, None, None);
+ TargetTransformInfo::UnrollingPreferences UP =
+ gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, OptLevel, None,
+ None, None, None, None, None, None, None);
if (AllowUnrollAndJam.getNumOccurrences() > 0)
UP.UnrollAndJam = AllowUnrollAndJam;
if (UnrollAndJamThreshold.getNumOccurrences() > 0)
diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp
index 2fa7436213dd..a6d4164c3645 100644
--- a/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -178,7 +178,9 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, int OptLevel,
Optional<unsigned> UserThreshold, Optional<unsigned> UserCount,
Optional<bool> UserAllowPartial, Optional<bool> UserRuntime,
- Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling) {
+ Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling,
+ Optional<bool> UserAllowProfileBasedPeeling,
+ Optional<unsigned> UserFullUnrollMaxCount) {
TargetTransformInfo::UnrollingPreferences UP;
// Set up the defaults
@@ -202,6 +204,7 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
UP.UpperBound = false;
UP.AllowPeeling = true;
UP.UnrollAndJam = false;
+ UP.PeelProfiledIterations = true;
UP.UnrollAndJamInnerLoopThreshold = 60;
// Override with any target specific settings
@@ -257,6 +260,10 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
UP.UpperBound = *UserUpperBound;
if (UserAllowPeeling.hasValue())
UP.AllowPeeling = *UserAllowPeeling;
+ if (UserAllowProfileBasedPeeling.hasValue())
+ UP.PeelProfiledIterations = *UserAllowProfileBasedPeeling;
+ if (UserFullUnrollMaxCount.hasValue())
+ UP.FullUnrollMaxCount = *UserFullUnrollMaxCount;
return UP;
}
@@ -730,7 +737,7 @@ bool llvm::computeUnrollCount(
Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI,
ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues,
OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount,
- unsigned &TripMultiple, unsigned LoopSize,
+ bool MaxOrZero, unsigned &TripMultiple, unsigned LoopSize,
TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) {
// Check for explicit Count.
@@ -781,18 +788,34 @@ bool llvm::computeUnrollCount(
// Also we need to check if we exceed FullUnrollMaxCount.
// If using the upper bound to unroll, TripMultiple should be set to 1 because
// we do not know when loop may exit.
- // MaxTripCount and ExactTripCount cannot both be non zero since we only
+
+ // We can unroll by the upper bound amount if it's generally allowed or if
+ // we know that the loop is executed either the upper bound or zero times.
+ // (MaxOrZero unrolling keeps only the first loop test, so the number of
+ // loop tests remains the same compared to the non-unrolled version, whereas
+ // the generic upper bound unrolling keeps all but the last loop test so the
+ // number of loop tests goes up which may end up being worse on targets with
+ // constrained branch predictor resources so is controlled by an option.)
+ // In addition we only unroll small upper bounds.
+ unsigned FullUnrollMaxTripCount = MaxTripCount;
+ if (!(UP.UpperBound || MaxOrZero) ||
+ FullUnrollMaxTripCount > UnrollMaxUpperBound)
+ FullUnrollMaxTripCount = 0;
+
+ // UnrollByMaxCount and ExactTripCount cannot both be non zero since we only
// compute the former when the latter is zero.
unsigned ExactTripCount = TripCount;
- assert((ExactTripCount == 0 || MaxTripCount == 0) &&
- "ExtractTripCount and MaxTripCount cannot both be non zero.");
- unsigned FullUnrollTripCount = ExactTripCount ? ExactTripCount : MaxTripCount;
+ assert((ExactTripCount == 0 || FullUnrollMaxTripCount == 0) &&
+         "ExactTripCount and UnrollByMaxCount cannot both be non zero.");
+
+ unsigned FullUnrollTripCount =
+ ExactTripCount ? ExactTripCount : FullUnrollMaxTripCount;
UP.Count = FullUnrollTripCount;
if (FullUnrollTripCount && FullUnrollTripCount <= UP.FullUnrollMaxCount) {
// When computing the unrolled size, note that BEInsns are not replicated
// like the rest of the loop body.
if (getUnrolledLoopSize(LoopSize, UP) < UP.Threshold) {
- UseUpperBound = (MaxTripCount == FullUnrollTripCount);
+ UseUpperBound = (FullUnrollMaxTripCount == FullUnrollTripCount);
TripCount = FullUnrollTripCount;
TripMultiple = UP.UpperBound ? 1 : TripMultiple;
return ExplicitUnroll;
@@ -806,7 +829,7 @@ bool llvm::computeUnrollCount(
unsigned Boost =
getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost);
if (Cost->UnrolledCost < UP.Threshold * Boost / 100) {
- UseUpperBound = (MaxTripCount == FullUnrollTripCount);
+ UseUpperBound = (FullUnrollMaxTripCount == FullUnrollTripCount);
TripCount = FullUnrollTripCount;
TripMultiple = UP.UpperBound ? 1 : TripMultiple;
return ExplicitUnroll;
@@ -882,6 +905,8 @@ bool llvm::computeUnrollCount(
"because "
"unrolled size is too large.";
});
+ LLVM_DEBUG(dbgs() << " partially unrolling with count: " << UP.Count
+ << "\n");
return ExplicitUnroll;
}
assert(TripCount == 0 &&
@@ -903,6 +928,12 @@ bool llvm::computeUnrollCount(
return false;
}
+ // Don't unroll a small upper bound loop unless user or TTI asked to do so.
+ if (MaxTripCount && !UP.Force && MaxTripCount < UnrollMaxUpperBound) {
+ UP.Count = 0;
+ return false;
+ }
+
// Check if the runtime trip count is too small when profile is available.
if (L->getHeader()->getParent()->hasProfileData()) {
if (auto ProfileTripCount = getLoopEstimatedTripCount(L)) {
@@ -966,7 +997,11 @@ bool llvm::computeUnrollCount(
if (UP.Count > UP.MaxCount)
UP.Count = UP.MaxCount;
- LLVM_DEBUG(dbgs() << " partially unrolling with count: " << UP.Count
+
+ if (MaxTripCount && UP.Count > MaxTripCount)
+ UP.Count = MaxTripCount;
+
+ LLVM_DEBUG(dbgs() << " runtime unrolling with count: " << UP.Count
<< "\n");
if (UP.Count < 2)
UP.Count = 0;
@@ -976,13 +1011,14 @@ bool llvm::computeUnrollCount(
static LoopUnrollResult tryToUnrollLoop(
Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
const TargetTransformInfo &TTI, AssumptionCache &AC,
- OptimizationRemarkEmitter &ORE,
- BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI,
- bool PreserveLCSSA, int OptLevel,
+ OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
+ ProfileSummaryInfo *PSI, bool PreserveLCSSA, int OptLevel,
bool OnlyWhenForced, bool ForgetAllSCEV, Optional<unsigned> ProvidedCount,
Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial,
Optional<bool> ProvidedRuntime, Optional<bool> ProvidedUpperBound,
- Optional<bool> ProvidedAllowPeeling) {
+ Optional<bool> ProvidedAllowPeeling,
+ Optional<bool> ProvidedAllowProfileBasedPeeling,
+ Optional<unsigned> ProvidedFullUnrollMaxCount) {
LLVM_DEBUG(dbgs() << "Loop Unroll: F["
<< L->getHeader()->getParent()->getName() << "] Loop %"
<< L->getHeader()->getName() << "\n");
@@ -1007,7 +1043,8 @@ static LoopUnrollResult tryToUnrollLoop(
TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
L, SE, TTI, BFI, PSI, OptLevel, ProvidedThreshold, ProvidedCount,
ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound,
- ProvidedAllowPeeling);
+ ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling,
+ ProvidedFullUnrollMaxCount);
// Exit early if unrolling is disabled. For OptForSize, we pick the loop size
// as threshold later on.
@@ -1028,10 +1065,10 @@ static LoopUnrollResult tryToUnrollLoop(
return LoopUnrollResult::Unmodified;
}
- // When optimizing for size, use LoopSize as threshold, to (fully) unroll
- // loops, if it does not increase code size.
+ // When optimizing for size, use LoopSize + 1 as threshold (we use < Threshold
+ // later), to (fully) unroll loops, if it does not increase code size.
if (OptForSize)
- UP.Threshold = std::max(UP.Threshold, LoopSize);
+ UP.Threshold = std::max(UP.Threshold, LoopSize + 1);
if (NumInlineCandidates != 0) {
LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n");
@@ -1040,7 +1077,6 @@ static LoopUnrollResult tryToUnrollLoop(
// Find trip count and trip multiple if count is not available
unsigned TripCount = 0;
- unsigned MaxTripCount = 0;
unsigned TripMultiple = 1;
// If there are multiple exiting blocks but one of them is the latch, use the
// latch for the trip count estimation. Otherwise insist on a single exiting
@@ -1070,28 +1106,18 @@ static LoopUnrollResult tryToUnrollLoop(
// Try to find the trip count upper bound if we cannot find the exact trip
// count.
+ unsigned MaxTripCount = 0;
bool MaxOrZero = false;
if (!TripCount) {
MaxTripCount = SE.getSmallConstantMaxTripCount(L);
MaxOrZero = SE.isBackedgeTakenCountMaxOrZero(L);
- // We can unroll by the upper bound amount if it's generally allowed or if
- // we know that the loop is executed either the upper bound or zero times.
- // (MaxOrZero unrolling keeps only the first loop test, so the number of
- // loop tests remains the same compared to the non-unrolled version, whereas
- // the generic upper bound unrolling keeps all but the last loop test so the
- // number of loop tests goes up which may end up being worse on targets with
- // constrained branch predictor resources so is controlled by an option.)
- // In addition we only unroll small upper bounds.
- if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) {
- MaxTripCount = 0;
- }
}
// computeUnrollCount() decides whether it is beneficial to use upper bound to
// fully unroll the loop.
bool UseUpperBound = false;
bool IsCountSetExplicitly = computeUnrollCount(
- L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount,
+ L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero,
TripMultiple, LoopSize, UP, UseUpperBound);
if (!UP.Count)
return LoopUnrollResult::Unmodified;
@@ -1139,7 +1165,7 @@ static LoopUnrollResult tryToUnrollLoop(
// If the loop was peeled, we already "used up" the profile information
// we had, so we don't want to unroll or peel again.
if (UnrollResult != LoopUnrollResult::FullyUnrolled &&
- (IsCountSetExplicitly || UP.PeelCount))
+ (IsCountSetExplicitly || (UP.PeelProfiledIterations && UP.PeelCount)))
L->setLoopAlreadyUnrolled();
return UnrollResult;
@@ -1169,18 +1195,24 @@ public:
Optional<bool> ProvidedRuntime;
Optional<bool> ProvidedUpperBound;
Optional<bool> ProvidedAllowPeeling;
+ Optional<bool> ProvidedAllowProfileBasedPeeling;
+ Optional<unsigned> ProvidedFullUnrollMaxCount;
LoopUnroll(int OptLevel = 2, bool OnlyWhenForced = false,
bool ForgetAllSCEV = false, Optional<unsigned> Threshold = None,
Optional<unsigned> Count = None,
Optional<bool> AllowPartial = None, Optional<bool> Runtime = None,
Optional<bool> UpperBound = None,
- Optional<bool> AllowPeeling = None)
+ Optional<bool> AllowPeeling = None,
+ Optional<bool> AllowProfileBasedPeeling = None,
+ Optional<unsigned> ProvidedFullUnrollMaxCount = None)
: LoopPass(ID), OptLevel(OptLevel), OnlyWhenForced(OnlyWhenForced),
ForgetAllSCEV(ForgetAllSCEV), ProvidedCount(std::move(Count)),
ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial),
ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound),
- ProvidedAllowPeeling(AllowPeeling) {
+ ProvidedAllowPeeling(AllowPeeling),
+ ProvidedAllowProfileBasedPeeling(AllowProfileBasedPeeling),
+ ProvidedFullUnrollMaxCount(ProvidedFullUnrollMaxCount) {
initializeLoopUnrollPass(*PassRegistry::getPassRegistry());
}
@@ -1203,10 +1235,11 @@ public:
bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
LoopUnrollResult Result = tryToUnrollLoop(
- L, DT, LI, SE, TTI, AC, ORE, nullptr, nullptr,
- PreserveLCSSA, OptLevel, OnlyWhenForced,
- ForgetAllSCEV, ProvidedCount, ProvidedThreshold, ProvidedAllowPartial,
- ProvidedRuntime, ProvidedUpperBound, ProvidedAllowPeeling);
+ L, DT, LI, SE, TTI, AC, ORE, nullptr, nullptr, PreserveLCSSA, OptLevel,
+ OnlyWhenForced, ForgetAllSCEV, ProvidedCount, ProvidedThreshold,
+ ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound,
+ ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling,
+ ProvidedFullUnrollMaxCount);
if (Result == LoopUnrollResult::FullyUnrolled)
LPM.markLoopAsDeleted(*L);
@@ -1283,14 +1316,16 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
std::string LoopName = L.getName();
- bool Changed =
- tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE,
- /*BFI*/ nullptr, /*PSI*/ nullptr,
- /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced,
- ForgetSCEV, /*Count*/ None,
- /*Threshold*/ None, /*AllowPartial*/ false,
- /*Runtime*/ false, /*UpperBound*/ false,
- /*AllowPeeling*/ false) != LoopUnrollResult::Unmodified;
+ bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE,
+ /*BFI*/ nullptr, /*PSI*/ nullptr,
+ /*PreserveLCSSA*/ true, OptLevel,
+ OnlyWhenForced, ForgetSCEV, /*Count*/ None,
+ /*Threshold*/ None, /*AllowPartial*/ false,
+ /*Runtime*/ false, /*UpperBound*/ false,
+ /*AllowPeeling*/ false,
+ /*AllowProfileBasedPeeling*/ false,
+ /*FullUnrollMaxCount*/ None) !=
+ LoopUnrollResult::Unmodified;
if (!Changed)
return PreservedAnalyses::all();
@@ -1430,7 +1465,8 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,
/*PreserveLCSSA*/ true, UnrollOpts.OptLevel, UnrollOpts.OnlyWhenForced,
UnrollOpts.ForgetSCEV, /*Count*/ None,
/*Threshold*/ None, UnrollOpts.AllowPartial, UnrollOpts.AllowRuntime,
- UnrollOpts.AllowUpperBound, LocalAllowPeeling);
+ UnrollOpts.AllowUpperBound, LocalAllowPeeling,
+ UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount);
Changed |= Result != LoopUnrollResult::Unmodified;
// The parent must not be damaged by unrolling!
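
    A hypothetical example (not from the patch's tests) of the upper-bound / max-or-zero
    situation the relocated comment in computeUnrollCount describes:

    // The exact trip count here is data-dependent, but ScalarEvolution knows
    // an upper bound of 8; upper-bound (or max-or-zero) full unrolling can
    // still flatten the loop while keeping each early-exit test.
    static int firstSet(const unsigned char *Bits) {
      for (int I = 0; I < 8; ++I)   // max trip count 8, exact count unknown
        if (Bits[I])
          return I;
      return -1;
    }
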
diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp
index b5b8e720069c..b410df0c5f68 100644
--- a/lib/Transforms/Scalar/LoopUnswitch.cpp
+++ b/lib/Transforms/Scalar/LoopUnswitch.cpp
@@ -420,7 +420,8 @@ enum OperatorChain {
/// cost of creating an entirely new loop.
static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
OperatorChain &ParentChain,
- DenseMap<Value *, Value *> &Cache) {
+ DenseMap<Value *, Value *> &Cache,
+ MemorySSAUpdater *MSSAU) {
auto CacheIt = Cache.find(Cond);
if (CacheIt != Cache.end())
return CacheIt->second;
@@ -438,7 +439,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
// TODO: Handle: br (VARIANT|INVARIANT).
// Hoist simple values out.
- if (L->makeLoopInvariant(Cond, Changed)) {
+ if (L->makeLoopInvariant(Cond, Changed, nullptr, MSSAU)) {
Cache[Cond] = Cond;
return Cond;
}
@@ -478,7 +479,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
// which will cause the branch to go away in one loop and the condition to
// simplify in the other one.
if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed,
- ParentChain, Cache)) {
+ ParentChain, Cache, MSSAU)) {
Cache[Cond] = LHS;
return LHS;
}
@@ -486,7 +487,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
// operand(1).
ParentChain = NewChain;
if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed,
- ParentChain, Cache)) {
+ ParentChain, Cache, MSSAU)) {
Cache[Cond] = RHS;
return RHS;
}
@@ -500,12 +501,12 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
/// an invariant piece, return the invariant along with the operator chain type.
/// Otherwise, return null.
-static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond,
- Loop *L,
- bool &Changed) {
+static std::pair<Value *, OperatorChain>
+FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
+ MemorySSAUpdater *MSSAU) {
DenseMap<Value *, Value *> Cache;
OperatorChain OpChain = OC_OpChainNone;
- Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache);
+ Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU);
// In case we do find a LIV, it can not be obtained by walking up a mixed
// operator chain.
@@ -525,7 +526,7 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
if (EnableMSSALoopDependency) {
MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
- MSSAU = make_unique<MemorySSAUpdater>(MSSA);
+ MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
assert(DT && "Cannot update MemorySSA without a valid DomTree.");
}
currentLoop = L;
@@ -694,8 +695,9 @@ bool LoopUnswitch::processCurrentLoop() {
}
for (IntrinsicInst *Guard : Guards) {
- Value *LoopCond =
- FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first;
+ Value *LoopCond = FindLIVLoopCondition(Guard->getOperand(0), currentLoop,
+ Changed, MSSAU.get())
+ .first;
if (LoopCond &&
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
// NB! Unswitching (if successful) could have erased some of the
@@ -735,8 +737,9 @@ bool LoopUnswitch::processCurrentLoop() {
if (BI->isConditional()) {
// See if this, or some part of it, is loop invariant. If so, we can
// unswitch on it if we desire.
- Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
- currentLoop, Changed).first;
+ Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop,
+ Changed, MSSAU.get())
+ .first;
if (LoopCond && !EqualityPropUnSafe(*LoopCond) &&
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
++NumBranches;
@@ -748,7 +751,7 @@ bool LoopUnswitch::processCurrentLoop() {
Value *LoopCond;
OperatorChain OpChain;
std::tie(LoopCond, OpChain) =
- FindLIVLoopCondition(SC, currentLoop, Changed);
+ FindLIVLoopCondition(SC, currentLoop, Changed, MSSAU.get());
unsigned NumCases = SI->getNumCases();
if (LoopCond && NumCases) {
@@ -808,8 +811,9 @@ bool LoopUnswitch::processCurrentLoop() {
for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end();
BBI != E; ++BBI)
if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
- Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
- currentLoop, Changed).first;
+ Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop,
+ Changed, MSSAU.get())
+ .first;
if (LoopCond && UnswitchIfProfitable(LoopCond,
ConstantInt::getTrue(Context))) {
++NumSelects;
@@ -1123,8 +1127,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
if (!BI->isConditional())
return false;
- Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
- currentLoop, Changed).first;
+ Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop,
+ Changed, MSSAU.get())
+ .first;
// Unswitch only if the trivial condition itself is an LIV (not
// partial LIV which could occur in and/or)
@@ -1157,8 +1162,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
return true;
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
// If this isn't switching on an invariant condition, we can't unswitch it.
- Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
- currentLoop, Changed).first;
+ Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop,
+ Changed, MSSAU.get())
+ .first;
// Unswitch only if the trivial condition itself is an LIV (not
// partial LIV which could occur in and/or)
@@ -1240,6 +1246,9 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
LoopBlocks.clear();
NewBlocks.clear();
+ if (MSSAU && VerifyMemorySSA)
+ MSSA->verifyMemorySSA();
+
// First step, split the preheader and exit blocks, and add these blocks to
// the LoopBlocks list.
BasicBlock *NewPreheader =
@@ -1607,36 +1616,30 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
// If BI's parent is the only pred of the successor, fold the two blocks
// together.
BasicBlock *Pred = BI->getParent();
+ (void)Pred;
BasicBlock *Succ = BI->getSuccessor(0);
BasicBlock *SinglePred = Succ->getSinglePredecessor();
if (!SinglePred) continue; // Nothing to do.
assert(SinglePred == Pred && "CFG broken");
- LLVM_DEBUG(dbgs() << "Merging blocks: " << Pred->getName() << " <- "
- << Succ->getName() << "\n");
-
- // Resolve any single entry PHI nodes in Succ.
- while (PHINode *PN = dyn_cast<PHINode>(Succ->begin()))
- ReplaceUsesOfWith(PN, PN->getIncomingValue(0), Worklist, L, LPM,
- MSSAU.get());
-
- // If Succ has any successors with PHI nodes, update them to have
- // entries coming from Pred instead of Succ.
- Succ->replaceAllUsesWith(Pred);
-
- // Move all of the successor contents from Succ to Pred.
- Pred->getInstList().splice(BI->getIterator(), Succ->getInstList(),
- Succ->begin(), Succ->end());
- if (MSSAU)
- MSSAU->moveAllAfterMergeBlocks(Succ, Pred, BI);
+ // Make the LPM and Worklist updates specific to LoopUnswitch.
LPM->deleteSimpleAnalysisValue(BI, L);
RemoveFromWorklist(BI, Worklist);
- BI->eraseFromParent();
-
- // Remove Succ from the loop tree.
- LI->removeBlock(Succ);
LPM->deleteSimpleAnalysisValue(Succ, L);
- Succ->eraseFromParent();
+ auto SuccIt = Succ->begin();
+ while (PHINode *PN = dyn_cast<PHINode>(SuccIt++)) {
+ for (unsigned It = 0, E = PN->getNumOperands(); It != E; ++It)
+ if (Instruction *Use = dyn_cast<Instruction>(PN->getOperand(It)))
+ Worklist.push_back(Use);
+ for (User *U : PN->users())
+ Worklist.push_back(cast<Instruction>(U));
+ LPM->deleteSimpleAnalysisValue(PN, L);
+ RemoveFromWorklist(PN, Worklist);
+ ++NumSimplify;
+ }
+ // Merge the block and make the remaining analyses updates.
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+ MergeBlockIntoPredecessor(Succ, &DTU, LI, MSSAU.get());
++NumSimplify;
continue;
}
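
    For context, a hypothetical source-level shape of the loop-invariant condition that
    FindLIVLoopCondition hunts for (not taken from the patch's tests):

    // 'UseScale' never changes inside the loop, so unswitching can hoist the
    // test and emit two specialized loop copies, one per branch direction.
    static void scaleOrCopy(float *Dst, const float *Src, int N, bool UseScale) {
      for (int I = 0; I < N; ++I) {
        if (UseScale)
          Dst[I] = Src[I] * 2.0f;
        else
          Dst[I] = Src[I];
      }
    }
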
diff --git a/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/lib/Transforms/Scalar/LoopVersioningLICM.cpp
index 896dd8bcb922..2ccb7cae3079 100644
--- a/lib/Transforms/Scalar/LoopVersioningLICM.cpp
+++ b/lib/Transforms/Scalar/LoopVersioningLICM.cpp
@@ -112,37 +112,6 @@ static cl::opt<unsigned> LVLoopDepthThreshold(
"LoopVersioningLICM's threshold for maximum allowed loop nest/depth"),
cl::init(2), cl::Hidden);
-/// Create MDNode for input string.
-static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
- LLVMContext &Context = TheLoop->getHeader()->getContext();
- Metadata *MDs[] = {
- MDString::get(Context, Name),
- ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))};
- return MDNode::get(Context, MDs);
-}
-
-/// Set input string into loop metadata by keeping other values intact.
-void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *MDString,
- unsigned V) {
- SmallVector<Metadata *, 4> MDs(1);
- // If the loop already has metadata, retain it.
- MDNode *LoopID = TheLoop->getLoopID();
- if (LoopID) {
- for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
- MDNode *Node = cast<MDNode>(LoopID->getOperand(i));
- MDs.push_back(Node);
- }
- }
- // Add new metadata.
- MDs.push_back(createStringMetadata(TheLoop, MDString, V));
- // Replace current metadata node with new one.
- LLVMContext &Context = TheLoop->getHeader()->getContext();
- MDNode *NewLoopID = MDNode::get(Context, MDs);
- // Set operand 0 to refer to the loop id itself.
- NewLoopID->replaceOperandWith(0, NewLoopID);
- TheLoop->setLoopID(NewLoopID);
-}
-
namespace {
struct LoopVersioningLICM : public LoopPass {
diff --git a/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
new file mode 100644
index 000000000000..d0fcf38b5a7b
--- /dev/null
+++ b/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
@@ -0,0 +1,170 @@
+//===- LowerConstantIntrinsics.cpp - Lower constant intrinsic calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers all remaining 'objectsize' and 'is.constant' intrinsic
+// calls and provides constant propagation and basic CFG cleanup on the
+// result.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "lower-is-constant-intrinsic"
+
+STATISTIC(IsConstantIntrinsicsHandled,
+ "Number of 'is.constant' intrinsic calls handled");
+STATISTIC(ObjectSizeIntrinsicsHandled,
+ "Number of 'objectsize' intrinsic calls handled");
+
+static Value *lowerIsConstantIntrinsic(IntrinsicInst *II) {
+ Value *Op = II->getOperand(0);
+
+ return isa<Constant>(Op) ? ConstantInt::getTrue(II->getType())
+ : ConstantInt::getFalse(II->getType());
+}
+
+static bool replaceConditionalBranchesOnConstant(Instruction *II,
+ Value *NewValue) {
+ bool HasDeadBlocks = false;
+ SmallSetVector<Instruction *, 8> Worklist;
+ replaceAndRecursivelySimplify(II, NewValue, nullptr, nullptr, nullptr,
+ &Worklist);
+ for (auto I : Worklist) {
+ BranchInst *BI = dyn_cast<BranchInst>(I);
+ if (!BI)
+ continue;
+ if (BI->isUnconditional())
+ continue;
+
+ BasicBlock *Target, *Other;
+ if (match(BI->getOperand(0), m_Zero())) {
+ Target = BI->getSuccessor(1);
+ Other = BI->getSuccessor(0);
+ } else if (match(BI->getOperand(0), m_One())) {
+ Target = BI->getSuccessor(0);
+ Other = BI->getSuccessor(1);
+ } else {
+ Target = nullptr;
+ Other = nullptr;
+ }
+ if (Target && Target != Other) {
+ BasicBlock *Source = BI->getParent();
+ Other->removePredecessor(Source);
+ BI->eraseFromParent();
+ BranchInst::Create(Target, Source);
+ if (pred_begin(Other) == pred_end(Other))
+ HasDeadBlocks = true;
+ }
+ }
+ return HasDeadBlocks;
+}
+
+static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo *TLI) {
+ bool HasDeadBlocks = false;
+ const auto &DL = F.getParent()->getDataLayout();
+ SmallVector<WeakTrackingVH, 8> Worklist;
+
+ ReversePostOrderTraversal<Function *> RPOT(&F);
+ for (BasicBlock *BB : RPOT) {
+    for (Instruction &I : *BB) {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
+ if (!II)
+ continue;
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::is_constant:
+ case Intrinsic::objectsize:
+ Worklist.push_back(WeakTrackingVH(&I));
+ break;
+ }
+ }
+ }
+  for (WeakTrackingVH &VH : Worklist) {
+    // Items on the worklist can be mutated by earlier recursive replaces.
+    // That may have erased the intrinsic as dead (VH is null) or replaced it
+    // in place with a different instruction.
+ if (!VH)
+ continue;
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*VH);
+ if (!II)
+ continue;
+ Value *NewValue;
+ switch (II->getIntrinsicID()) {
+ default:
+ continue;
+ case Intrinsic::is_constant:
+ NewValue = lowerIsConstantIntrinsic(II);
+ IsConstantIntrinsicsHandled++;
+ break;
+ case Intrinsic::objectsize:
+ NewValue = lowerObjectSizeCall(II, DL, TLI, true);
+ ObjectSizeIntrinsicsHandled++;
+ break;
+ }
+ HasDeadBlocks |= replaceConditionalBranchesOnConstant(II, NewValue);
+ }
+ if (HasDeadBlocks)
+ removeUnreachableBlocks(F);
+ return !Worklist.empty();
+}
+
+PreservedAnalyses
+LowerConstantIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) {
+ if (lowerConstantIntrinsics(F, AM.getCachedResult<TargetLibraryAnalysis>(F)))
+ return PreservedAnalyses::none();
+
+ return PreservedAnalyses::all();
+}
+
+namespace {
+/// Legacy pass for lowering is.constant intrinsics out of the IR.
+///
+/// When this pass is run over a function, it converts is.constant intrinsics
+/// into 'true' or 'false'. This complements the normal constant folding to
+/// 'true' that is done as part of the Instruction Simplify passes.
+class LowerConstantIntrinsics : public FunctionPass {
+public:
+ static char ID;
+ LowerConstantIntrinsics() : FunctionPass(ID) {
+ initializeLowerConstantIntrinsicsPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
+ const TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
+ return lowerConstantIntrinsics(F, TLI);
+ }
+};
+} // namespace
+
+char LowerConstantIntrinsics::ID = 0;
+INITIALIZE_PASS(LowerConstantIntrinsics, "lower-constant-intrinsics",
+ "Lower constant intrinsics", false, false)
+
+FunctionPass *llvm::createLowerConstantIntrinsicsPass() {
+ return new LowerConstantIntrinsics();
+}
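
    A hedged, hypothetical example of where such leftover calls come from (with
    Clang/GCC at the source level; not part of the patch):

    // With Clang/GCC, __builtin_constant_p becomes llvm.is.constant. If
    // earlier passes left the call behind, this pass folds it to true/false
    // and then removes whichever branch arm becomes unreachable.
    static unsigned pickPath(unsigned X) {
      if (__builtin_constant_p(X))
        return X * 2u;  // survives only when X folds to a compile-time constant
      return X + 1u;
    }
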
diff --git a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index 0d67c0d740ec..d85f20b3f80c 100644
--- a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -26,6 +26,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/MisExpect.h"
using namespace llvm;
@@ -71,15 +72,20 @@ static bool handleSwitchExpect(SwitchInst &SI) {
unsigned n = SI.getNumCases(); // +1 for default case.
SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight);
- if (Case == *SI.case_default())
- Weights[0] = LikelyBranchWeight;
- else
- Weights[Case.getCaseIndex() + 1] = LikelyBranchWeight;
+ uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
+ Weights[Index] = LikelyBranchWeight;
+
+ SI.setMetadata(
+ LLVMContext::MD_misexpect,
+ MDBuilder(CI->getContext())
+ .createMisExpect(Index, LikelyBranchWeight, UnlikelyBranchWeight));
+
+ SI.setCondition(ArgValue);
+ misexpect::checkFrontendInstrumentation(SI);
SI.setMetadata(LLVMContext::MD_prof,
MDBuilder(CI->getContext()).createBranchWeights(Weights));
- SI.setCondition(ArgValue);
return true;
}
@@ -155,7 +161,7 @@ static void handlePhiDef(CallInst *Expect) {
return Result;
};
- auto *PhiDef = dyn_cast<PHINode>(V);
+ auto *PhiDef = cast<PHINode>(V);
// Get the first dominating conditional branch of the operand
// i's incoming block.
@@ -280,19 +286,28 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
MDBuilder MDB(CI->getContext());
MDNode *Node;
+ MDNode *ExpNode;
if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
- (Predicate == CmpInst::ICMP_EQ))
+ (Predicate == CmpInst::ICMP_EQ)) {
Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight);
- else
+ ExpNode = MDB.createMisExpect(0, LikelyBranchWeight, UnlikelyBranchWeight);
+ } else {
Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight);
+ ExpNode = MDB.createMisExpect(1, LikelyBranchWeight, UnlikelyBranchWeight);
+ }
- BSI.setMetadata(LLVMContext::MD_prof, Node);
+ BSI.setMetadata(LLVMContext::MD_misexpect, ExpNode);
if (CmpI)
CmpI->setOperand(0, ArgValue);
else
BSI.setCondition(ArgValue);
+
+ misexpect::checkFrontendInstrumentation(BSI);
+
+ BSI.setMetadata(LLVMContext::MD_prof, Node);
+
return true;
}
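
    For context, a hypothetical source-level origin of the expect intrinsic this pass
    lowers (Clang/GCC builtin; not taken from the patch's tests):

    // With Clang/GCC, __builtin_expect becomes llvm.expect; lowering turns it
    // into branch-weight metadata (plus the new misexpect annotation) and
    // strips the intrinsic so the branch condition is used directly.
    static int checkStatus(int Err) {
      if (__builtin_expect(Err != 0, 0))   // annotated as the unlikely path
        return -1;
      return 0;
    }
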
diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 5a055139be4f..2364748efb05 100644
--- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -69,90 +69,6 @@ STATISTIC(NumMemSetInfer, "Number of memsets inferred");
STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy");
STATISTIC(NumCpyToSet, "Number of memcpys converted to memset");
-static int64_t GetOffsetFromIndex(const GEPOperator *GEP, unsigned Idx,
- bool &VariableIdxFound,
- const DataLayout &DL) {
- // Skip over the first indices.
- gep_type_iterator GTI = gep_type_begin(GEP);
- for (unsigned i = 1; i != Idx; ++i, ++GTI)
- /*skip along*/;
-
- // Compute the offset implied by the rest of the indices.
- int64_t Offset = 0;
- for (unsigned i = Idx, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
- ConstantInt *OpC = dyn_cast<ConstantInt>(GEP->getOperand(i));
- if (!OpC)
- return VariableIdxFound = true;
- if (OpC->isZero()) continue; // No offset.
-
- // Handle struct indices, which add their field offset to the pointer.
- if (StructType *STy = GTI.getStructTypeOrNull()) {
- Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue());
- continue;
- }
-
- // Otherwise, we have a sequential type like an array or vector. Multiply
- // the index by the ElementSize.
- uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
- Offset += Size*OpC->getSExtValue();
- }
-
- return Offset;
-}
-
-/// Return true if Ptr1 is provably equal to Ptr2 plus a constant offset, and
-/// return that constant offset. For example, Ptr1 might be &A[42], and Ptr2
-/// might be &A[40]. In this case offset would be -8.
-static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,
- const DataLayout &DL) {
- Ptr1 = Ptr1->stripPointerCasts();
- Ptr2 = Ptr2->stripPointerCasts();
-
- // Handle the trivial case first.
- if (Ptr1 == Ptr2) {
- Offset = 0;
- return true;
- }
-
- GEPOperator *GEP1 = dyn_cast<GEPOperator>(Ptr1);
- GEPOperator *GEP2 = dyn_cast<GEPOperator>(Ptr2);
-
- bool VariableIdxFound = false;
-
- // If one pointer is a GEP and the other isn't, then see if the GEP is a
- // constant offset from the base, as in "P" and "gep P, 1".
- if (GEP1 && !GEP2 && GEP1->getOperand(0)->stripPointerCasts() == Ptr2) {
- Offset = -GetOffsetFromIndex(GEP1, 1, VariableIdxFound, DL);
- return !VariableIdxFound;
- }
-
- if (GEP2 && !GEP1 && GEP2->getOperand(0)->stripPointerCasts() == Ptr1) {
- Offset = GetOffsetFromIndex(GEP2, 1, VariableIdxFound, DL);
- return !VariableIdxFound;
- }
-
- // Right now we handle the case when Ptr1/Ptr2 are both GEPs with an identical
- // base. After that base, they may have some number of common (and
- // potentially variable) indices. After that they handle some constant
- // offset, which determines their offset from each other. At this point, we
- // handle no other case.
- if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0))
- return false;
-
- // Skip any common indices and track the GEP types.
- unsigned Idx = 1;
- for (; Idx != GEP1->getNumOperands() && Idx != GEP2->getNumOperands(); ++Idx)
- if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx))
- break;
-
- int64_t Offset1 = GetOffsetFromIndex(GEP1, Idx, VariableIdxFound, DL);
- int64_t Offset2 = GetOffsetFromIndex(GEP2, Idx, VariableIdxFound, DL);
- if (VariableIdxFound) return false;
-
- Offset = Offset2-Offset1;
- return true;
-}
-
namespace {
/// Represents a range of memset'd bytes with the ByteVal value.
@@ -419,12 +335,12 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
break;
// Check to see if this store is to a constant offset from the start ptr.
- int64_t Offset;
- if (!IsPointerOffset(StartPtr, NextStore->getPointerOperand(), Offset,
- DL))
+ Optional<int64_t> Offset =
+ isPointerOffset(StartPtr, NextStore->getPointerOperand(), DL);
+ if (!Offset)
break;
- Ranges.addStore(Offset, NextStore);
+ Ranges.addStore(*Offset, NextStore);
} else {
MemSetInst *MSI = cast<MemSetInst>(BI);
@@ -433,11 +349,11 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
break;
// Check to see if this store is to a constant offset from the start ptr.
- int64_t Offset;
- if (!IsPointerOffset(StartPtr, MSI->getDest(), Offset, DL))
+ Optional<int64_t> Offset = isPointerOffset(StartPtr, MSI->getDest(), DL);
+ if (!Offset)
break;
- Ranges.addMemSet(Offset, MSI);
+ Ranges.addMemSet(*Offset, MSI);
}
}
@@ -597,9 +513,13 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P,
ToLift.push_back(C);
for (unsigned k = 0, e = C->getNumOperands(); k != e; ++k)
- if (auto *A = dyn_cast<Instruction>(C->getOperand(k)))
- if (A->getParent() == SI->getParent())
+ if (auto *A = dyn_cast<Instruction>(C->getOperand(k))) {
+ if (A->getParent() == SI->getParent()) {
+ // Cannot hoist user of P above P
+          if (A == P)
+            return false;
Args.insert(A);
+ }
+ }
}
// We made it, we need to lift
@@ -979,7 +899,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
// If the destination wasn't sufficiently aligned then increase its alignment.
if (!isDestSufficientlyAligned) {
assert(isa<AllocaInst>(cpyDest) && "Can only increase alloca alignment!");
- cast<AllocaInst>(cpyDest)->setAlignment(srcAlign);
+ cast<AllocaInst>(cpyDest)->setAlignment(MaybeAlign(srcAlign));
}
// Drop any cached information about the call, because we may have changed
@@ -1516,7 +1436,7 @@ bool MemCpyOptLegacyPass::runOnFunction(Function &F) {
return false;
auto *MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
- auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto LookupAliasAnalysis = [this]() -> AliasAnalysis & {
return getAnalysis<AAResultsWrapperPass>().getAAResults();
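
    A hypothetical example of the store pattern that tryMergingIntoMemset (now using the
    Optional-returning isPointerOffset) collapses into one memset (not from the patch's
    tests; assumes a typical 4-byte int):

    // Four adjacent zero stores at constant offsets from the same base are
    // collected into one range and emitted as a single 16-byte memset.
    struct Quad { int A, B, C, D; };
    static void clearQuad(Quad *Q) {
      Q->A = 0;
      Q->B = 0;
      Q->C = 0;
      Q->D = 0;
    }
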
diff --git a/lib/Transforms/Scalar/MergeICmps.cpp b/lib/Transforms/Scalar/MergeICmps.cpp
index 3d047a193267..98a45b391319 100644
--- a/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/lib/Transforms/Scalar/MergeICmps.cpp
@@ -897,7 +897,7 @@ public:
bool runOnFunction(Function &F) override {
if (skipFunction(F)) return false;
- const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
// MergeICmps does not need the DominatorTree, but we update it if it's
// already available.
diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
index 30645f4400e3..9799ea7960ec 100644
--- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
+++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
@@ -14,9 +14,11 @@
// diamond (hammock) and merges them into a single load in the header. Similar
// it sinks and merges two stores to the tail block (footer). The algorithm
// iterates over the instructions of one side of the diamond and attempts to
-// find a matching load/store on the other side. It hoists / sinks when it
-// thinks it safe to do so. This optimization helps with eg. hiding load
-// latencies, triggering if-conversion, and reducing static code size.
+// find a matching load/store on the other side. A new tail/footer block may be
+// inserted if the tail/footer block has additional predecessors (beyond the
+// two that form the diamond). It hoists / sinks when it thinks it is safe to
+// do so. This optimization helps with e.g. hiding load latencies, triggering
+// if-conversion, and reducing static code size.
//
// NOTE: This code no longer performs load hoisting, it is subsumed by GVNHoist.
//
@@ -103,7 +105,9 @@ class MergedLoadStoreMotion {
// Control is enforced by the check Size0 * Size1 < MagicCompileTimeControl.
const int MagicCompileTimeControl = 250;
+ const bool SplitFooterBB;
public:
+ MergedLoadStoreMotion(bool SplitFooterBB) : SplitFooterBB(SplitFooterBB) {}
bool run(Function &F, AliasAnalysis &AA);
private:
@@ -114,7 +118,9 @@ private:
PHINode *getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1);
bool isStoreSinkBarrierInRange(const Instruction &Start,
const Instruction &End, MemoryLocation Loc);
- bool sinkStore(BasicBlock *BB, StoreInst *SinkCand, StoreInst *ElseInst);
+ bool canSinkStoresAndGEPs(StoreInst *S0, StoreInst *S1) const;
+ void sinkStoresAndGEPs(BasicBlock *BB, StoreInst *SinkCand,
+ StoreInst *ElseInst);
bool mergeStores(BasicBlock *BB);
};
} // end anonymous namespace
@@ -217,74 +223,82 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0,
}
///
+/// Check if 2 stores can be sunk together with corresponding GEPs
+///
+bool MergedLoadStoreMotion::canSinkStoresAndGEPs(StoreInst *S0,
+ StoreInst *S1) const {
+ auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand());
+ auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand());
+ return A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() &&
+ (A0->getParent() == S0->getParent()) && A1->hasOneUse() &&
+ (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0);
+}
+
+///
/// Merge two stores to same address and sink into \p BB
///
/// Also sinks GEP instruction computing the store address
///
-bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0,
- StoreInst *S1) {
+void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0,
+ StoreInst *S1) {
// Only one definition?
auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand());
auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand());
- if (A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() &&
- (A0->getParent() == S0->getParent()) && A1->hasOneUse() &&
- (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0)) {
- LLVM_DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump();
- dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n";
- dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n");
- // Hoist the instruction.
- BasicBlock::iterator InsertPt = BB->getFirstInsertionPt();
- // Intersect optional metadata.
- S0->andIRFlags(S1);
- S0->dropUnknownNonDebugMetadata();
-
- // Create the new store to be inserted at the join point.
- StoreInst *SNew = cast<StoreInst>(S0->clone());
- Instruction *ANew = A0->clone();
- SNew->insertBefore(&*InsertPt);
- ANew->insertBefore(SNew);
-
- assert(S0->getParent() == A0->getParent());
- assert(S1->getParent() == A1->getParent());
-
- // New PHI operand? Use it.
- if (PHINode *NewPN = getPHIOperand(BB, S0, S1))
- SNew->setOperand(0, NewPN);
- S0->eraseFromParent();
- S1->eraseFromParent();
- A0->replaceAllUsesWith(ANew);
- A0->eraseFromParent();
- A1->replaceAllUsesWith(ANew);
- A1->eraseFromParent();
- return true;
- }
- return false;
+ LLVM_DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump();
+ dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n";
+ dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n");
+ // Hoist the instruction.
+ BasicBlock::iterator InsertPt = BB->getFirstInsertionPt();
+ // Intersect optional metadata.
+ S0->andIRFlags(S1);
+ S0->dropUnknownNonDebugMetadata();
+
+ // Create the new store to be inserted at the join point.
+ StoreInst *SNew = cast<StoreInst>(S0->clone());
+ Instruction *ANew = A0->clone();
+ SNew->insertBefore(&*InsertPt);
+ ANew->insertBefore(SNew);
+
+ assert(S0->getParent() == A0->getParent());
+ assert(S1->getParent() == A1->getParent());
+
+ // New PHI operand? Use it.
+ if (PHINode *NewPN = getPHIOperand(BB, S0, S1))
+ SNew->setOperand(0, NewPN);
+ S0->eraseFromParent();
+ S1->eraseFromParent();
+ A0->replaceAllUsesWith(ANew);
+ A0->eraseFromParent();
+ A1->replaceAllUsesWith(ANew);
+ A1->eraseFromParent();
}
///
/// True when two stores are equivalent and can sink into the footer
///
-/// Starting from a diamond tail block, iterate over the instructions in one
-/// predecessor block and try to match a store in the second predecessor.
+/// Starting from a diamond head block, iterate over the instructions in one
+/// successor block and try to match a store in the second successor.
///
-bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) {
+bool MergedLoadStoreMotion::mergeStores(BasicBlock *HeadBB) {
bool MergedStores = false;
- assert(T && "Footer of a diamond cannot be empty");
-
- pred_iterator PI = pred_begin(T), E = pred_end(T);
- assert(PI != E);
- BasicBlock *Pred0 = *PI;
- ++PI;
- BasicBlock *Pred1 = *PI;
- ++PI;
+ BasicBlock *TailBB = getDiamondTail(HeadBB);
+ BasicBlock *SinkBB = TailBB;
+ assert(SinkBB && "Footer of a diamond cannot be empty");
+
+ succ_iterator SI = succ_begin(HeadBB);
+ assert(SI != succ_end(HeadBB) && "Diamond head cannot have zero successors");
+ BasicBlock *Pred0 = *SI;
+ ++SI;
+ assert(SI != succ_end(HeadBB) && "Diamond head cannot have single successor");
+ BasicBlock *Pred1 = *SI;
// tail block of a diamond/hammock?
if (Pred0 == Pred1)
return false; // No.
- if (PI != E)
- return false; // No. More than 2 predecessors.
-
- // #Instructions in Succ1 for Compile Time Control
+  // Bail out early if we cannot merge into the footer BB.
+ if (!SplitFooterBB && TailBB->hasNPredecessorsOrMore(3))
+ return false;
+ // #Instructions in Pred1 for Compile Time Control
auto InstsNoDbg = Pred1->instructionsWithoutDebug();
int Size1 = std::distance(InstsNoDbg.begin(), InstsNoDbg.end());
int NStores = 0;
@@ -304,14 +318,23 @@ bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) {
if (NStores * Size1 >= MagicCompileTimeControl)
break;
if (StoreInst *S1 = canSinkFromBlock(Pred1, S0)) {
- bool Res = sinkStore(T, S0, S1);
- MergedStores |= Res;
- // Don't attempt to sink below stores that had to stick around
- // But after removal of a store and some of its feeding
- // instruction search again from the beginning since the iterator
- // is likely stale at this point.
- if (!Res)
+ if (!canSinkStoresAndGEPs(S0, S1))
+        // Don't attempt to sink below stores that had to stick around.
+        // After removal of a store and some of its feeding instructions,
+        // search again from the beginning since the iterator is likely stale
+        // at this point.
break;
+
+ if (SinkBB == TailBB && TailBB->hasNPredecessorsOrMore(3)) {
+ // We have more than 2 predecessors. Insert a new block
+ // postdominating 2 predecessors we're going to sink from.
+ SinkBB = SplitBlockPredecessors(TailBB, {Pred0, Pred1}, ".sink.split");
+ if (!SinkBB)
+ break;
+ }
+
+ MergedStores = true;
+ sinkStoresAndGEPs(SinkBB, S0, S1);
RBI = Pred0->rbegin();
RBE = Pred0->rend();
LLVM_DEBUG(dbgs() << "Search again\n"; Instruction *I = &*RBI; I->dump());
@@ -328,13 +351,15 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) {
// Merge unconditional branches, allowing PRE to catch more
// optimization opportunities.
+  // This loop doesn't care about newly inserted/split blocks
+  // since they will never be diamond heads.
for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE;) {
BasicBlock *BB = &*FI++;
// Hoist equivalent loads and sink stores
// outside diamonds when possible
if (isDiamondHead(BB)) {
- Changed |= mergeStores(getDiamondTail(BB));
+ Changed |= mergeStores(BB);
}
}
return Changed;
@@ -342,9 +367,11 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) {
namespace {
class MergedLoadStoreMotionLegacyPass : public FunctionPass {
+ const bool SplitFooterBB;
public:
static char ID; // Pass identification, replacement for typeid
- MergedLoadStoreMotionLegacyPass() : FunctionPass(ID) {
+ MergedLoadStoreMotionLegacyPass(bool SplitFooterBB = false)
+ : FunctionPass(ID), SplitFooterBB(SplitFooterBB) {
initializeMergedLoadStoreMotionLegacyPassPass(
*PassRegistry::getPassRegistry());
}
@@ -355,13 +382,14 @@ public:
bool runOnFunction(Function &F) override {
if (skipFunction(F))
return false;
- MergedLoadStoreMotion Impl;
+ MergedLoadStoreMotion Impl(SplitFooterBB);
return Impl.run(F, getAnalysis<AAResultsWrapperPass>().getAAResults());
}
private:
void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
+ if (!SplitFooterBB)
+ AU.setPreservesCFG();
AU.addRequired<AAResultsWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
}
@@ -373,8 +401,8 @@ char MergedLoadStoreMotionLegacyPass::ID = 0;
///
/// createMergedLoadStoreMotionPass - The public interface to this file.
///
-FunctionPass *llvm::createMergedLoadStoreMotionPass() {
- return new MergedLoadStoreMotionLegacyPass();
+FunctionPass *llvm::createMergedLoadStoreMotionPass(bool SplitFooterBB) {
+ return new MergedLoadStoreMotionLegacyPass(SplitFooterBB);
}
INITIALIZE_PASS_BEGIN(MergedLoadStoreMotionLegacyPass, "mldst-motion",
@@ -385,13 +413,14 @@ INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion",
PreservedAnalyses
MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) {
- MergedLoadStoreMotion Impl;
+ MergedLoadStoreMotion Impl(Options.SplitFooterBB);
auto &AA = AM.getResult<AAManager>(F);
if (!Impl.run(F, AA))
return PreservedAnalyses::all();
PreservedAnalyses PA;
- PA.preserveSet<CFGAnalyses>();
+ if (!Options.SplitFooterBB)
+ PA.preserveSet<CFGAnalyses>();
PA.preserve<GlobalsAA>();
return PA;
}
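
A minimal usage sketch for the new knob above (assumption: the wrapper function and pass-manager setup below are invented for illustration; only createMergedLoadStoreMotionPass(bool) comes from this patch). Because splitting the footer edits the CFG, the pass intentionally stops reporting setPreservesCFG() in this configuration.

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar.h"

// Schedule merged load/store motion with footer splitting enabled.
static void addMergedLoadStoreMotion(llvm::legacy::FunctionPassManager &FPM) {
  FPM.add(llvm::createMergedLoadStoreMotionPass(/*SplitFooterBB=*/true));
}
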
diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp
index 94436b55752a..1260bd39cdee 100644
--- a/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -170,7 +170,7 @@ bool NaryReassociateLegacyPass::runOnFunction(Function &F) {
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
return Impl.runImpl(F, AC, DT, SE, TLI, TTI);
diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp
index 08ac2b666fce..b213264de557 100644
--- a/lib/Transforms/Scalar/NewGVN.cpp
+++ b/lib/Transforms/Scalar/NewGVN.cpp
@@ -89,6 +89,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
@@ -122,6 +123,7 @@
using namespace llvm;
using namespace llvm::GVNExpression;
using namespace llvm::VNCoercion;
+using namespace llvm::PatternMatch;
#define DEBUG_TYPE "newgvn"
@@ -656,7 +658,7 @@ public:
TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA,
const DataLayout &DL)
: F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL),
- PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)),
+ PredInfo(std::make_unique<PredicateInfo>(F, *DT, *AC)),
SQ(DL, TLI, DT, AC, /*CtxI=*/nullptr, /*UseInstrInfo=*/false) {}
bool runGVN();
@@ -1332,7 +1334,7 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,
E->setOpcode(0);
E->op_push_back(PointerOp);
if (LI)
- E->setAlignment(LI->getAlignment());
+ E->setAlignment(MaybeAlign(LI->getAlignment()));
// TODO: Value number heap versions. We may be able to discover
// things alias analysis can't on it's own (IE that a store and a
@@ -1637,8 +1639,11 @@ const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) const {
if (AA->doesNotAccessMemory(CI)) {
return createCallExpression(CI, TOPClass->getMemoryLeader());
} else if (AA->onlyReadsMemory(CI)) {
- MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI);
- return createCallExpression(CI, DefiningAccess);
+ if (auto *MA = MSSA->getMemoryAccess(CI)) {
+ auto *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(MA);
+ return createCallExpression(CI, DefiningAccess);
+ } else // MSSA determined that CI does not access memory.
+ return createCallExpression(CI, TOPClass->getMemoryLeader());
}
return nullptr;
}
@@ -1754,7 +1759,7 @@ NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps,
return true;
});
// If we are left with no operands, it's dead.
- if (empty(Filtered)) {
+ if (Filtered.empty()) {
// If it has undef at this point, it means there are no-non-undef arguments,
// and thus, the value of the phi node must be undef.
if (HasUndef) {
@@ -2464,9 +2469,9 @@ Value *NewGVN::findConditionEquivalence(Value *Cond) const {
// Process the outgoing edges of a block for reachability.
void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) {
// Evaluate reachability of terminator instruction.
- BranchInst *BR;
- if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) {
- Value *Cond = BR->getCondition();
+ Value *Cond;
+ BasicBlock *TrueSucc, *FalseSucc;
+ if (match(TI, m_Br(m_Value(Cond), TrueSucc, FalseSucc))) {
Value *CondEvaluated = findConditionEquivalence(Cond);
if (!CondEvaluated) {
if (auto *I = dyn_cast<Instruction>(Cond)) {
@@ -2479,8 +2484,6 @@ void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) {
}
}
ConstantInt *CI;
- BasicBlock *TrueSucc = BR->getSuccessor(0);
- BasicBlock *FalseSucc = BR->getSuccessor(1);
if (CondEvaluated && (CI = dyn_cast<ConstantInt>(CondEvaluated))) {
if (CI->isOne()) {
LLVM_DEBUG(dbgs() << "Condition for Terminator " << *TI
@@ -4196,7 +4199,7 @@ bool NewGVNLegacyPass::runOnFunction(Function &F) {
return false;
return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
&getAnalysis<MemorySSAWrapperPass>().getMSSA(),
F.getParent()->getDataLayout())
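
Since the NewGVN hunk above replaces the explicit BranchInst handling with PatternMatch, here is a small standalone sketch of the m_Br overload it relies on (assumption: the helper name is invented; the matcher and argument types are the ones used in the patch).

#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"

// Returns true only for conditional branches, binding the condition and both
// successors; unconditional branches and other terminators do not match.
static bool matchConditionalBranch(llvm::Instruction *TI, llvm::Value *&Cond,
                                   llvm::BasicBlock *&TrueSucc,
                                   llvm::BasicBlock *&FalseSucc) {
  using namespace llvm::PatternMatch;
  return match(TI, m_Br(m_Value(Cond), TrueSucc, FalseSucc));
}
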
diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
index 039123218544..68a0f5151ad5 100644
--- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
+++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
@@ -161,7 +161,7 @@ public:
return false;
TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
const TargetTransformInfo *TTI =
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
return runPartiallyInlineLibCalls(F, TLI, TTI);
diff --git a/lib/Transforms/Scalar/PlaceSafepoints.cpp b/lib/Transforms/Scalar/PlaceSafepoints.cpp
index b544f0a39ea8..beb299272ed8 100644
--- a/lib/Transforms/Scalar/PlaceSafepoints.cpp
+++ b/lib/Transforms/Scalar/PlaceSafepoints.cpp
@@ -131,7 +131,7 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass {
SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
for (Loop *I : *LI) {
runOnLoopAndSubLoops(I);
}
@@ -240,7 +240,7 @@ static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header,
static bool mustBeFiniteCountedLoop(Loop *L, ScalarEvolution *SE,
BasicBlock *Pred) {
// A conservative bound on the loop as a whole.
- const SCEV *MaxTrips = SE->getMaxBackedgeTakenCount(L);
+ const SCEV *MaxTrips = SE->getConstantMaxBackedgeTakenCount(L);
if (MaxTrips != SE->getCouldNotCompute() &&
SE->getUnsignedRange(MaxTrips).getUnsignedMax().isIntN(
CountedLoopTripWidth))
@@ -478,7 +478,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) {
return false;
const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
bool Modified = false;
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index fa8c9e2a5fe4..124f625ef7b6 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -861,7 +861,7 @@ static Value *NegateValue(Value *V, Instruction *BI,
// this use. We do this by moving it to the entry block (if it is a
// non-instruction value) or right after the definition. These negates will
// be zapped by reassociate later, so we don't need much finesse here.
- BinaryOperator *TheNeg = cast<BinaryOperator>(U);
+ Instruction *TheNeg = cast<Instruction>(U);
// Verify that the negate is in this function, V might be a constant expr.
if (TheNeg->getParent()->getParent() != BI->getParent()->getParent())
@@ -1938,88 +1938,132 @@ void ReassociatePass::EraseInst(Instruction *I) {
MadeChange = true;
}
-// Canonicalize expressions of the following form:
-// x + (-Constant * y) -> x - (Constant * y)
-// x - (-Constant * y) -> x + (Constant * y)
-Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) {
- if (!I->hasOneUse() || I->getType()->isVectorTy())
- return nullptr;
-
- // Must be a fmul or fdiv instruction.
- unsigned Opcode = I->getOpcode();
- if (Opcode != Instruction::FMul && Opcode != Instruction::FDiv)
- return nullptr;
-
- auto *C0 = dyn_cast<ConstantFP>(I->getOperand(0));
- auto *C1 = dyn_cast<ConstantFP>(I->getOperand(1));
-
- // Both operands are constant, let it get constant folded away.
- if (C0 && C1)
- return nullptr;
-
- ConstantFP *CF = C0 ? C0 : C1;
-
- // Must have one constant operand.
- if (!CF)
- return nullptr;
+/// Recursively analyze an expression to build a list of instructions that have
+/// negative floating-point constant operands. The caller can then transform
+/// the list to create positive constants for better reassociation and CSE.
+static void getNegatibleInsts(Value *V,
+ SmallVectorImpl<Instruction *> &Candidates) {
+ // Handle only one-use instructions. Combining negations does not justify
+ // replicating instructions.
+ Instruction *I;
+ if (!match(V, m_OneUse(m_Instruction(I))))
+ return;
- // Must be a negative ConstantFP.
- if (!CF->isNegative())
- return nullptr;
+ // Handle expressions of multiplications and divisions.
+ // TODO: This could look through floating-point casts.
+ const APFloat *C;
+ switch (I->getOpcode()) {
+ case Instruction::FMul:
+ // Not expecting non-canonical code here. Bail out and wait.
+ if (match(I->getOperand(0), m_Constant()))
+ break;
- // User must be a binary operator with one or more uses.
- Instruction *User = I->user_back();
- if (!isa<BinaryOperator>(User) || User->use_empty())
- return nullptr;
+ if (match(I->getOperand(1), m_APFloat(C)) && C->isNegative()) {
+ Candidates.push_back(I);
+ LLVM_DEBUG(dbgs() << "FMul with negative constant: " << *I << '\n');
+ }
+ getNegatibleInsts(I->getOperand(0), Candidates);
+ getNegatibleInsts(I->getOperand(1), Candidates);
+ break;
+ case Instruction::FDiv:
+ // Not expecting non-canonical code here. Bail out and wait.
+ if (match(I->getOperand(0), m_Constant()) &&
+ match(I->getOperand(1), m_Constant()))
+ break;
- unsigned UserOpcode = User->getOpcode();
- if (UserOpcode != Instruction::FAdd && UserOpcode != Instruction::FSub)
- return nullptr;
+ if ((match(I->getOperand(0), m_APFloat(C)) && C->isNegative()) ||
+ (match(I->getOperand(1), m_APFloat(C)) && C->isNegative())) {
+ Candidates.push_back(I);
+ LLVM_DEBUG(dbgs() << "FDiv with negative constant: " << *I << '\n');
+ }
+ getNegatibleInsts(I->getOperand(0), Candidates);
+ getNegatibleInsts(I->getOperand(1), Candidates);
+ break;
+ default:
+ break;
+ }
+}
- // Subtraction is not commutative. Explicitly, the following transform is
- // not valid: (-Constant * y) - x -> x + (Constant * y)
- if (!User->isCommutative() && User->getOperand(1) != I)
+/// Given an fadd/fsub with an operand that is a one-use instruction
+/// (the fadd/fsub), try to change negative floating-point constants into
+/// positive constants to increase potential for reassociation and CSE.
+Instruction *ReassociatePass::canonicalizeNegFPConstantsForOp(Instruction *I,
+ Instruction *Op,
+ Value *OtherOp) {
+ assert((I->getOpcode() == Instruction::FAdd ||
+ I->getOpcode() == Instruction::FSub) && "Expected fadd/fsub");
+
+ // Collect instructions with negative FP constants from the subtree that ends
+ // in Op.
+ SmallVector<Instruction *, 4> Candidates;
+ getNegatibleInsts(Op, Candidates);
+ if (Candidates.empty())
return nullptr;
// Don't canonicalize x + (-Constant * y) -> x - (Constant * y), if the
// resulting subtract will be broken up later. This can get us into an
// infinite loop during reassociation.
- if (UserOpcode == Instruction::FAdd && ShouldBreakUpSubtract(User))
+ bool IsFSub = I->getOpcode() == Instruction::FSub;
+ bool NeedsSubtract = !IsFSub && Candidates.size() % 2 == 1;
+ if (NeedsSubtract && ShouldBreakUpSubtract(I))
return nullptr;
- // Change the sign of the constant.
- APFloat Val = CF->getValueAPF();
- Val.changeSign();
- I->setOperand(C0 ? 0 : 1, ConstantFP::get(CF->getContext(), Val));
-
- // Canonicalize I to RHS to simplify the next bit of logic. E.g.,
- // ((-Const*y) + x) -> (x + (-Const*y)).
- if (User->getOperand(0) == I && User->isCommutative())
- cast<BinaryOperator>(User)->swapOperands();
-
- Value *Op0 = User->getOperand(0);
- Value *Op1 = User->getOperand(1);
- BinaryOperator *NI;
- switch (UserOpcode) {
- default:
- llvm_unreachable("Unexpected Opcode!");
- case Instruction::FAdd:
- NI = BinaryOperator::CreateFSub(Op0, Op1);
- NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags());
- break;
- case Instruction::FSub:
- NI = BinaryOperator::CreateFAdd(Op0, Op1);
- NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags());
- break;
+ for (Instruction *Negatible : Candidates) {
+ const APFloat *C;
+ if (match(Negatible->getOperand(0), m_APFloat(C))) {
+ assert(!match(Negatible->getOperand(1), m_Constant()) &&
+ "Expecting only 1 constant operand");
+ assert(C->isNegative() && "Expected negative FP constant");
+ Negatible->setOperand(0, ConstantFP::get(Negatible->getType(), abs(*C)));
+ MadeChange = true;
+ }
+ if (match(Negatible->getOperand(1), m_APFloat(C))) {
+ assert(!match(Negatible->getOperand(0), m_Constant()) &&
+ "Expecting only 1 constant operand");
+ assert(C->isNegative() && "Expected negative FP constant");
+ Negatible->setOperand(1, ConstantFP::get(Negatible->getType(), abs(*C)));
+ MadeChange = true;
+ }
}
+ assert(MadeChange == true && "Negative constant candidate was not changed");
- NI->insertBefore(User);
- NI->setName(User->getName());
- User->replaceAllUsesWith(NI);
- NI->setDebugLoc(I->getDebugLoc());
+ // Negations cancelled out.
+ if (Candidates.size() % 2 == 0)
+ return I;
+
+ // Negate the final operand in the expression by flipping the opcode of this
+ // fadd/fsub.
+ assert(Candidates.size() % 2 == 1 && "Expected odd number");
+ IRBuilder<> Builder(I);
+ Value *NewInst = IsFSub ? Builder.CreateFAddFMF(OtherOp, Op, I)
+ : Builder.CreateFSubFMF(OtherOp, Op, I);
+ I->replaceAllUsesWith(NewInst);
RedoInsts.insert(I);
- MadeChange = true;
- return NI;
+ return dyn_cast<Instruction>(NewInst);
+}
+
+/// Canonicalize expressions that contain a negative floating-point constant
+/// of the following form:
+/// OtherOp + (subtree) -> OtherOp {+/-} (canonical subtree)
+/// (subtree) + OtherOp -> OtherOp {+/-} (canonical subtree)
+/// OtherOp - (subtree) -> OtherOp {+/-} (canonical subtree)
+///
+/// The fadd/fsub opcode may be switched to allow folding a negation into the
+/// input instruction.
+Instruction *ReassociatePass::canonicalizeNegFPConstants(Instruction *I) {
+ LLVM_DEBUG(dbgs() << "Combine negations for: " << *I << '\n');
+ Value *X;
+ Instruction *Op;
+ if (match(I, m_FAdd(m_Value(X), m_OneUse(m_Instruction(Op)))))
+ if (Instruction *R = canonicalizeNegFPConstantsForOp(I, Op, X))
+ I = R;
+ if (match(I, m_FAdd(m_OneUse(m_Instruction(Op)), m_Value(X))))
+ if (Instruction *R = canonicalizeNegFPConstantsForOp(I, Op, X))
+ I = R;
+ if (match(I, m_FSub(m_Value(X), m_OneUse(m_Instruction(Op)))))
+ if (Instruction *R = canonicalizeNegFPConstantsForOp(I, Op, X))
+ I = R;
+ return I;
}
/// Inspect and optimize the given instruction. Note that erasing
@@ -2042,16 +2086,16 @@ void ReassociatePass::OptimizeInst(Instruction *I) {
I = NI;
}
- // Canonicalize negative constants out of expressions.
- if (Instruction *Res = canonicalizeNegConstExpr(I))
- I = Res;
-
// Commute binary operators, to canonicalize the order of their operands.
// This can potentially expose more CSE opportunities, and makes writing other
// transformations simpler.
if (I->isCommutative())
canonicalizeOperands(I);
+ // Canonicalize negative constants out of expressions.
+ if (Instruction *Res = canonicalizeNegFPConstants(I))
+ I = Res;
+
// Don't optimize floating-point instructions unless they are 'fast'.
if (I->getType()->isFPOrFPVectorTy() && !I->isFast())
return;
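
The rewrite above flips negative FP constants in a one-use fmul/fdiv subtree to their absolute values and, when an odd number of flips remains, absorbs the leftover negation by switching the enclosing fadd/fsub. A minimal standalone model of that parity rule (assumption: the enum and helper are invented for illustration, not code from the patch).

#include <cstddef>

enum class FPOpcode { FAdd, FSub };

// After every negative constant in the subtree has been replaced by its
// absolute value, an even number of flips cancels out, while an odd number
// leaves one pending negation that is folded into the enclosing opcode.
static FPOpcode opcodeAfterCanonicalization(FPOpcode Original,
                                            std::size_t NumFlips) {
  if (NumFlips % 2 == 0)
    return Original;
  return Original == FPOpcode::FAdd ? FPOpcode::FSub : FPOpcode::FAdd;
}
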
diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index c358258d24cf..48bbdd8d1b33 100644
--- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -172,8 +172,6 @@ public:
bool runOnModule(Module &M) override {
bool Changed = false;
- const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
for (Function &F : M) {
// Nothing to do for declarations.
if (F.isDeclaration() || F.empty())
@@ -186,6 +184,8 @@ public:
TargetTransformInfo &TTI =
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ const TargetLibraryInfo &TLI =
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
Changed |= Impl.runOnFunction(F, DT, TTI, TLI);
@@ -2530,7 +2530,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
// statepoints surviving this pass. This makes testing easier and the
// resulting IR less confusing to human readers.
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
- bool MadeChange = removeUnreachableBlocks(F, nullptr, &DTU);
+ bool MadeChange = removeUnreachableBlocks(F, &DTU);
// Flush the Dominator Tree.
DTU.getDomTree();
diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp
index 4093e50ce899..10fbdc8aacd2 100644
--- a/lib/Transforms/Scalar/SCCP.cpp
+++ b/lib/Transforms/Scalar/SCCP.cpp
@@ -191,7 +191,7 @@ public:
///
class SCCPSolver : public InstVisitor<SCCPSolver> {
const DataLayout &DL;
- const TargetLibraryInfo *TLI;
+ std::function<const TargetLibraryInfo &(Function &)> GetTLI;
SmallPtrSet<BasicBlock *, 8> BBExecutable; // The BBs that are executable.
DenseMap<Value *, LatticeVal> ValueState; // The state each value is in.
// The state each parameter is in.
@@ -268,8 +268,9 @@ public:
return {A->second.DT, A->second.PDT, DomTreeUpdater::UpdateStrategy::Lazy};
}
- SCCPSolver(const DataLayout &DL, const TargetLibraryInfo *tli)
- : DL(DL), TLI(tli) {}
+ SCCPSolver(const DataLayout &DL,
+ std::function<const TargetLibraryInfo &(Function &)> GetTLI)
+ : DL(DL), GetTLI(std::move(GetTLI)) {}
/// MarkBlockExecutable - This method can be used by clients to mark all of
/// the blocks that are known to be intrinsically live in the processed unit.
@@ -1290,7 +1291,7 @@ CallOverdefined:
// If we can constant fold this, mark the result of the call as a
// constant.
if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), F,
- Operands, TLI)) {
+ Operands, &GetTLI(*F))) {
// call -> undef.
if (isa<UndefValue>(C))
return;
@@ -1465,7 +1466,24 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
}
LatticeVal &LV = getValueState(&I);
- if (!LV.isUnknown()) continue;
+ if (!LV.isUnknown())
+ continue;
+
+      // There are two reasons a call can have an undef result:
+ // 1. It could be tracked.
+ // 2. It could be constant-foldable.
+ // Because of the way we solve return values, tracked calls must
+ // never be marked overdefined in ResolvedUndefsIn.
+ if (CallSite CS = CallSite(&I)) {
+ if (Function *F = CS.getCalledFunction())
+ if (TrackedRetVals.count(F))
+ continue;
+
+ // If the call is constant-foldable, we mark it overdefined because
+ // we do not know what return values are valid.
+ markOverdefined(&I);
+ return true;
+ }
// extractvalue is safe; check here because the argument is a struct.
if (isa<ExtractValueInst>(I))
@@ -1638,19 +1656,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {
case Instruction::Call:
case Instruction::Invoke:
case Instruction::CallBr:
- // There are two reasons a call can have an undef result
- // 1. It could be tracked.
- // 2. It could be constant-foldable.
- // Because of the way we solve return values, tracked calls must
- // never be marked overdefined in ResolvedUndefsIn.
- if (Function *F = CallSite(&I).getCalledFunction())
- if (TrackedRetVals.count(F))
- break;
-
- // If the call is constant-foldable, we mark it overdefined because
- // we do not know what return values are valid.
- markOverdefined(&I);
- return true;
+      llvm_unreachable("Call-like instructions should have been handled early");
default:
// If we don't know what should happen here, conservatively mark it
// overdefined.
@@ -1751,7 +1757,7 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {
[](const LatticeVal &LV) { return LV.isOverdefined(); }))
return false;
std::vector<Constant *> ConstVals;
- auto *ST = dyn_cast<StructType>(V->getType());
+ auto *ST = cast<StructType>(V->getType());
for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
LatticeVal V = IVs[i];
ConstVals.push_back(V.isConstant()
@@ -1796,7 +1802,8 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {
static bool runSCCP(Function &F, const DataLayout &DL,
const TargetLibraryInfo *TLI) {
LLVM_DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n");
- SCCPSolver Solver(DL, TLI);
+ SCCPSolver Solver(
+ DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; });
// Mark the first block of the function as being executable.
Solver.MarkBlockExecutable(&F.front());
@@ -1891,7 +1898,7 @@ public:
return false;
const DataLayout &DL = F.getParent()->getDataLayout();
const TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
return runSCCP(F, DL, TLI);
}
};
@@ -1924,6 +1931,27 @@ static void findReturnsToZap(Function &F,
return;
}
+ assert(
+ all_of(F.users(),
+ [&Solver](User *U) {
+ if (isa<Instruction>(U) &&
+ !Solver.isBlockExecutable(cast<Instruction>(U)->getParent()))
+ return true;
+ // Non-callsite uses are not impacted by zapping. Also, constant
+                 // uses (like blockaddresses) could stick around without being
+ // used in the underlying IR, meaning we do not have lattice
+ // values for them.
+ if (!CallSite(U))
+ return true;
+ if (U->getType()->isStructTy()) {
+ return all_of(
+ Solver.getStructLatticeValueFor(U),
+ [](const LatticeVal &LV) { return !LV.isOverdefined(); });
+ }
+ return !Solver.getLatticeValueFor(U).isOverdefined();
+ }) &&
+ "We can only zap functions where all live users have a concrete value");
+
for (BasicBlock &BB : F) {
if (CallInst *CI = BB.getTerminatingMustTailCall()) {
LLVM_DEBUG(dbgs() << "Can't zap return of the block due to present "
@@ -1974,9 +2002,10 @@ static void forceIndeterminateEdge(Instruction* I, SCCPSolver &Solver) {
}
bool llvm::runIPSCCP(
- Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI,
+ Module &M, const DataLayout &DL,
+ std::function<const TargetLibraryInfo &(Function &)> GetTLI,
function_ref<AnalysisResultsForFn(Function &)> getAnalysis) {
- SCCPSolver Solver(DL, TLI);
+ SCCPSolver Solver(DL, GetTLI);
// Loop over all functions, marking arguments to those with their addresses
// taken or that are external as overdefined.
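
The SCCP change above threads a per-function TLI callback into the solver instead of a single module-wide pointer. A hedged sketch of how a legacy-pass caller might build that callback (assumption: the wrapper function is invented; getTLI(F) is the per-function accessor the patch switches to, and the wrapper pass must outlive the returned lambda).

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"

// Per-function lookup lets TLI respect function-level attributes such as
// no-builtins instead of assuming one answer for the whole module.
static auto makeGetTLI(llvm::TargetLibraryInfoWrapperPass &TLIWP) {
  return [&TLIWP](llvm::Function &F) -> const llvm::TargetLibraryInfo & {
    return TLIWP.getTLI(F);
  };
}
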
diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp
index 33f90d0b01e4..74b8ff913050 100644
--- a/lib/Transforms/Scalar/SROA.cpp
+++ b/lib/Transforms/Scalar/SROA.cpp
@@ -959,14 +959,16 @@ private:
std::tie(UsedI, I) = Uses.pop_back_val();
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
- Size = std::max(Size, DL.getTypeStoreSize(LI->getType()));
+ Size = std::max(Size,
+ DL.getTypeStoreSize(LI->getType()).getFixedSize());
continue;
}
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
Value *Op = SI->getOperand(0);
if (Op == UsedI)
return SI;
- Size = std::max(Size, DL.getTypeStoreSize(Op->getType()));
+ Size = std::max(Size,
+ DL.getTypeStoreSize(Op->getType()).getFixedSize());
continue;
}
@@ -1197,7 +1199,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) {
// TODO: Allow recursive phi users.
// TODO: Allow stores.
BasicBlock *BB = PN.getParent();
- unsigned MaxAlign = 0;
+ MaybeAlign MaxAlign;
uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType());
APInt MaxSize(APWidth, 0);
bool HaveLoad = false;
@@ -1218,8 +1220,8 @@ static bool isSafePHIToSpeculate(PHINode &PN) {
if (BBI->mayWriteToMemory())
return false;
- uint64_t Size = DL.getTypeStoreSizeInBits(LI->getType());
- MaxAlign = std::max(MaxAlign, LI->getAlignment());
+ uint64_t Size = DL.getTypeStoreSize(LI->getType());
+ MaxAlign = std::max(MaxAlign, MaybeAlign(LI->getAlignment()));
MaxSize = MaxSize.ult(Size) ? APInt(APWidth, Size) : MaxSize;
HaveLoad = true;
}
@@ -1266,11 +1268,11 @@ static void speculatePHINodeLoads(PHINode &PN) {
PHINode *NewPN = PHIBuilder.CreatePHI(LoadTy, PN.getNumIncomingValues(),
PN.getName() + ".sroa.speculated");
- // Get the AA tags and alignment to use from one of the loads. It doesn't
+ // Get the AA tags and alignment to use from one of the loads. It does not
// matter which one we get and if any differ.
AAMDNodes AATags;
SomeLoad->getAAMetadata(AATags);
- unsigned Align = SomeLoad->getAlignment();
+ const MaybeAlign Align = MaybeAlign(SomeLoad->getAlignment());
// Rewrite all loads of the PN to use the new PHI.
while (!PN.use_empty()) {
@@ -1338,11 +1340,11 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) {
// Both operands to the select need to be dereferenceable, either
// absolutely (e.g. allocas) or at this point because we can see other
// accesses to it.
- if (!isSafeToLoadUnconditionally(TValue, LI->getType(), LI->getAlignment(),
- DL, LI))
+ if (!isSafeToLoadUnconditionally(TValue, LI->getType(),
+ MaybeAlign(LI->getAlignment()), DL, LI))
return false;
- if (!isSafeToLoadUnconditionally(FValue, LI->getType(), LI->getAlignment(),
- DL, LI))
+ if (!isSafeToLoadUnconditionally(FValue, LI->getType(),
+ MaybeAlign(LI->getAlignment()), DL, LI))
return false;
}
@@ -1368,8 +1370,8 @@ static void speculateSelectInstLoads(SelectInst &SI) {
NumLoadsSpeculated += 2;
// Transfer alignment and AA info if present.
- TL->setAlignment(LI->getAlignment());
- FL->setAlignment(LI->getAlignment());
+ TL->setAlignment(MaybeAlign(LI->getAlignment()));
+ FL->setAlignment(MaybeAlign(LI->getAlignment()));
AAMDNodes Tags;
LI->getAAMetadata(Tags);
@@ -1888,6 +1890,14 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
bool HaveCommonEltTy = true;
auto CheckCandidateType = [&](Type *Ty) {
if (auto *VTy = dyn_cast<VectorType>(Ty)) {
+      // Give up if this vector type's total size in bits differs from the
+      // candidates already collected.
+ if (!CandidateTys.empty()) {
+ VectorType *V = CandidateTys[0];
+ if (DL.getTypeSizeInBits(VTy) != DL.getTypeSizeInBits(V)) {
+ CandidateTys.clear();
+ return;
+ }
+ }
CandidateTys.push_back(VTy);
if (!CommonEltTy)
CommonEltTy = VTy->getElementType();
@@ -3110,7 +3120,7 @@ private:
unsigned LoadAlign = LI->getAlignment();
if (!LoadAlign)
LoadAlign = DL.getABITypeAlignment(LI->getType());
- LI->setAlignment(std::min(LoadAlign, getSliceAlign()));
+ LI->setAlignment(MaybeAlign(std::min(LoadAlign, getSliceAlign())));
continue;
}
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
@@ -3119,7 +3129,7 @@ private:
Value *Op = SI->getOperand(0);
StoreAlign = DL.getABITypeAlignment(Op->getType());
}
- SI->setAlignment(std::min(StoreAlign, getSliceAlign()));
+ SI->setAlignment(MaybeAlign(std::min(StoreAlign, getSliceAlign())));
continue;
}
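
Most of the SROA hunks above only wrap legacy unsigned alignments in MaybeAlign, where 0 means "unknown". A tiny standalone illustration of that mapping (assumption: the helper is invented).

#include "llvm/Support/Alignment.h"

// MaybeAlign(0) is the empty state, so the old "0 == unknown alignment"
// convention survives the type change without extra branching at call sites.
static llvm::MaybeAlign fromLegacyAlignment(unsigned LegacyAlign) {
  return llvm::MaybeAlign(LegacyAlign);
}
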
diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp
index 869cf00e0a89..1d2e40bf62be 100644
--- a/lib/Transforms/Scalar/Scalar.cpp
+++ b/lib/Transforms/Scalar/Scalar.cpp
@@ -79,6 +79,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeLoopVersioningLICMPass(Registry);
initializeLoopIdiomRecognizeLegacyPassPass(Registry);
initializeLowerAtomicLegacyPassPass(Registry);
+ initializeLowerConstantIntrinsicsPass(Registry);
initializeLowerExpectIntrinsicPass(Registry);
initializeLowerGuardIntrinsicLegacyPassPass(Registry);
initializeLowerWidenableConditionLegacyPassPass(Registry);
@@ -123,6 +124,10 @@ void LLVMAddAggressiveDCEPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createAggressiveDCEPass());
}
+void LLVMAddDCEPass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createDeadCodeEliminationPass());
+}
+
void LLVMAddBitTrackingDCEPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createBitTrackingDCEPass());
}
@@ -280,6 +285,10 @@ void LLVMAddBasicAliasAnalysisPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createBasicAAWrapperPass());
}
+void LLVMAddLowerConstantIntrinsicsPass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createLowerConstantIntrinsicsPass());
+}
+
void LLVMAddLowerExpectIntrinsicPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createLowerExpectIntrinsicPass());
}
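
The new C bindings above just forward to the C++ factory functions. A hedged usage sketch from C++ through the C API (assumption: the matching declarations are available in llvm-c/Transforms/Scalar.h alongside these definitions, and the pass manager handle comes from the usual C-API constructors).

#include "llvm-c/Transforms/Scalar.h"

// Registers the new lowering pass and plain DCE through the stable C API.
static void addLoweringPasses(LLVMPassManagerRef PM) {
  LLVMAddLowerConstantIntrinsicsPass(PM);
  LLVMAddDCEPass(PM);
}
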
diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index f6a12fb13142..41554fccdf08 100644
--- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -1121,7 +1121,7 @@ bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) {
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
bool Changed = false;
for (BasicBlock &B : F) {
for (BasicBlock::iterator I = B.begin(), IE = B.end(); I != IE;)
diff --git a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index aeac6f548b32..ac832b9b4567 100644
--- a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -1909,7 +1909,7 @@ static void unswitchNontrivialInvariants(
// We can only unswitch switches, conditional branches with an invariant
// condition, or combining invariant conditions with an instruction.
- assert((SI || BI->isConditional()) &&
+ assert((SI || (BI && BI->isConditional())) &&
"Can only unswitch switches and conditional branch!");
bool FullUnswitch = SI || BI->getCondition() == Invariants[0];
if (FullUnswitch)
@@ -2141,17 +2141,21 @@ static void unswitchNontrivialInvariants(
buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction,
*ClonedPH, *LoopPH);
DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH});
+
+ if (MSSAU) {
+ DT.applyUpdates(DTUpdates);
+ DTUpdates.clear();
+
+ // Perform MSSA cloning updates.
+ for (auto &VMap : VMaps)
+ MSSAU->updateForClonedLoop(LBRPO, ExitBlocks, *VMap,
+ /*IgnoreIncomingWithNoClones=*/true);
+ MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMaps, DT);
+ }
}
// Apply the updates accumulated above to get an up-to-date dominator tree.
DT.applyUpdates(DTUpdates);
- if (!FullUnswitch && MSSAU) {
- // Update MSSA for partial unswitch, after DT update.
- SmallVector<CFGUpdate, 1> Updates;
- Updates.push_back(
- {cfg::UpdateKind::Insert, SplitBB, ClonedPHs.begin()->second});
- MSSAU->applyInsertUpdates(Updates, DT);
- }
// Now that we have an accurate dominator tree, first delete the dead cloned
// blocks so that we can accurately build any cloned loops. It is important to
@@ -2720,7 +2724,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
return Cost * (SuccessorsCount - 1);
};
Instruction *BestUnswitchTI = nullptr;
- int BestUnswitchCost;
+ int BestUnswitchCost = 0;
ArrayRef<Value *> BestUnswitchInvariants;
for (auto &TerminatorAndInvariants : UnswitchCandidates) {
Instruction &TI = *TerminatorAndInvariants.first;
@@ -2752,6 +2756,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
BestUnswitchInvariants = Invariants;
}
}
+ assert(BestUnswitchTI && "Failed to find loop unswitch candidate");
if (BestUnswitchCost >= UnswitchThreshold) {
LLVM_DEBUG(dbgs() << "Cannot unswitch, lowest cost found: "
@@ -2880,7 +2885,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
auto PA = getLoopPassPreservedAnalyses();
- if (EnableMSSALoopDependency)
+ if (AR.MSSA)
PA.preserve<MemorySSAAnalysis>();
return PA;
}
diff --git a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp
index c13fb3e04516..e6db11f47ead 100644
--- a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp
+++ b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp
@@ -777,8 +777,10 @@ static bool tryToSpeculatePHIs(SmallVectorImpl<PHINode *> &PNs,
// speculation if the predecessor is an invoke. This doesn't seem
// fundamental and we should probably be splitting critical edges
// differently.
- if (isa<IndirectBrInst>(PredBB->getTerminator()) ||
- isa<InvokeInst>(PredBB->getTerminator())) {
+ const auto *TermInst = PredBB->getTerminator();
+ if (isa<IndirectBrInst>(TermInst) ||
+ isa<InvokeInst>(TermInst) ||
+ isa<CallBrInst>(TermInst)) {
LLVM_DEBUG(dbgs() << " Invalid: predecessor terminator: "
<< PredBB->getName() << "\n");
return false;
diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp
index e5400676c7e8..9791cf41f621 100644
--- a/lib/Transforms/Scalar/StructurizeCFG.cpp
+++ b/lib/Transforms/Scalar/StructurizeCFG.cpp
@@ -65,7 +65,7 @@ static cl::opt<bool> ForceSkipUniformRegions(
static cl::opt<bool>
RelaxedUniformRegions("structurizecfg-relaxed-uniform-regions", cl::Hidden,
cl::desc("Allow relaxed uniform region checks"),
- cl::init(false));
+ cl::init(true));
// Definition of the complex types used in this pass.
diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp
index f0b79079d817..b27a36b67d62 100644
--- a/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -341,7 +341,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
const DataLayout &DL = L->getModule()->getDataLayout();
if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) ||
!isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(),
- L->getAlignment(), DL, L))
+ MaybeAlign(L->getAlignment()), DL, L))
return false;
}
}
diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp
index 5fa371377c85..d85cc40c372a 100644
--- a/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -170,7 +170,8 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) {
bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
LoopInfo *LI, MemorySSAUpdater *MSSAU,
- MemoryDependenceResults *MemDep) {
+ MemoryDependenceResults *MemDep,
+ bool PredecessorWithTwoSuccessors) {
if (BB->hasAddressTaken())
return false;
@@ -185,9 +186,24 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
return false;
// Can't merge if there are multiple distinct successors.
- if (PredBB->getUniqueSuccessor() != BB)
+ if (!PredecessorWithTwoSuccessors && PredBB->getUniqueSuccessor() != BB)
return false;
+  // Currently only allow PredBB to have two successors, one being BB.
+  // Update PredBB's branch to target BB's only successor instead of BB.
+ BranchInst *PredBB_BI;
+ BasicBlock *NewSucc = nullptr;
+ unsigned FallThruPath;
+ if (PredecessorWithTwoSuccessors) {
+ if (!(PredBB_BI = dyn_cast<BranchInst>(PredBB->getTerminator())))
+ return false;
+ BranchInst *BB_JmpI = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!BB_JmpI || !BB_JmpI->isUnconditional())
+ return false;
+ NewSucc = BB_JmpI->getSuccessor(0);
+ FallThruPath = PredBB_BI->getSuccessor(0) == BB ? 0 : 1;
+ }
+
// Can't merge if there is PHI loop.
for (PHINode &PN : BB->phis())
for (Value *IncValue : PN.incoming_values())
@@ -227,18 +243,39 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
Updates.push_back({DominatorTree::Delete, PredBB, BB});
}
- if (MSSAU)
- MSSAU->moveAllAfterMergeBlocks(BB, PredBB, &*(BB->begin()));
+ Instruction *PTI = PredBB->getTerminator();
+ Instruction *STI = BB->getTerminator();
+ Instruction *Start = &*BB->begin();
+ // If there's nothing to move, mark the starting instruction as the last
+ // instruction in the block.
+ if (Start == STI)
+ Start = PTI;
+
+ // Move all definitions in the successor to the predecessor...
+ PredBB->getInstList().splice(PTI->getIterator(), BB->getInstList(),
+ BB->begin(), STI->getIterator());
- // Delete the unconditional branch from the predecessor...
- PredBB->getInstList().pop_back();
+ if (MSSAU)
+ MSSAU->moveAllAfterMergeBlocks(BB, PredBB, Start);
// Make all PHI nodes that referred to BB now refer to Pred as their
// source...
BB->replaceAllUsesWith(PredBB);
- // Move all definitions in the successor to the predecessor...
- PredBB->getInstList().splice(PredBB->end(), BB->getInstList());
+ if (PredecessorWithTwoSuccessors) {
+ // Delete the unconditional branch from BB.
+ BB->getInstList().pop_back();
+
+ // Update branch in the predecessor.
+ PredBB_BI->setSuccessor(FallThruPath, NewSucc);
+ } else {
+ // Delete the unconditional branch from the predecessor.
+ PredBB->getInstList().pop_back();
+
+ // Move terminator instruction.
+ PredBB->getInstList().splice(PredBB->end(), BB->getInstList());
+ }
+ // Add unreachable to now empty BB.
new UnreachableInst(BB->getContext(), BB);
// Eliminate duplicate dbg.values describing the entry PHI node post-splice.
@@ -274,11 +311,10 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
"applying corresponding DTU updates.");
DTU->applyUpdatesPermissive(Updates);
DTU->deleteBB(BB);
- }
-
- else {
+ } else {
BB->eraseFromParent(); // Nuke BB if DTU is nullptr.
}
+
return true;
}
@@ -365,11 +401,13 @@ llvm::SplitAllCriticalEdges(Function &F,
BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt,
DominatorTree *DT, LoopInfo *LI,
- MemorySSAUpdater *MSSAU) {
+ MemorySSAUpdater *MSSAU, const Twine &BBName) {
BasicBlock::iterator SplitIt = SplitPt->getIterator();
while (isa<PHINode>(SplitIt) || SplitIt->isEHPad())
++SplitIt;
- BasicBlock *New = Old->splitBasicBlock(SplitIt, Old->getName()+".split");
+ std::string Name = BBName.str();
+ BasicBlock *New = Old->splitBasicBlock(
+ SplitIt, Name.empty() ? Old->getName() + ".split" : Name);
// The new block lives in whichever loop the old one did. This preserves
// LCSSA as well, because we force the split point to be after any PHI nodes.
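
The new PredecessorWithTwoSuccessors mode above merges BB into a predecessor that ends in a conditional branch, retargeting that branch at BB's single successor. A call-site sketch (assumption: the wrapper and the nullptr analysis arguments are illustrative; the signature is the one added by this patch).

#include "llvm/Transforms/Utils/BasicBlockUtils.h"

// BB must end in an unconditional branch and PredBB's conditional branch must
// have BB as one of its two targets; on success the branch is rewired to BB's
// successor and BB is left with an unreachable terminator.
static bool foldIntoConditionalPredecessor(llvm::BasicBlock *BB,
                                           llvm::DomTreeUpdater *DTU) {
  return llvm::MergeBlockIntoPredecessor(BB, DTU, /*LI=*/nullptr,
                                         /*MSSAU=*/nullptr, /*MemDep=*/nullptr,
                                         /*PredecessorWithTwoSuccessors=*/true);
}
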
diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp
index 27f110e24f9c..71316ce8f758 100644
--- a/lib/Transforms/Utils/BuildLibCalls.cpp
+++ b/lib/Transforms/Utils/BuildLibCalls.cpp
@@ -88,6 +88,14 @@ static bool setDoesNotCapture(Function &F, unsigned ArgNo) {
return true;
}
+static bool setDoesNotAlias(Function &F, unsigned ArgNo) {
+ if (F.hasParamAttribute(ArgNo, Attribute::NoAlias))
+ return false;
+ F.addParamAttr(ArgNo, Attribute::NoAlias);
+ ++NumNoAlias;
+ return true;
+}
+
static bool setOnlyReadsMemory(Function &F, unsigned ArgNo) {
if (F.hasParamAttribute(ArgNo, Attribute::ReadOnly))
return false;
@@ -175,6 +183,9 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
return Changed;
case LibFunc_strcpy:
case LibFunc_strncpy:
+ Changed |= setDoesNotAlias(F, 0);
+ Changed |= setDoesNotAlias(F, 1);
+ LLVM_FALLTHROUGH;
case LibFunc_strcat:
case LibFunc_strncat:
Changed |= setReturnedArg(F, 0);
@@ -249,12 +260,14 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
case LibFunc_sprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 0);
+ Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
return Changed;
case LibFunc_snprintf:
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 0);
+ Changed |= setDoesNotAlias(F, 0);
Changed |= setDoesNotCapture(F, 2);
Changed |= setOnlyReadsMemory(F, 2);
return Changed;
@@ -291,11 +304,23 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
Changed |= setDoesNotCapture(F, 1);
return Changed;
case LibFunc_memcpy:
+ Changed |= setDoesNotAlias(F, 0);
+ Changed |= setDoesNotAlias(F, 1);
+ Changed |= setReturnedArg(F, 0);
+ Changed |= setDoesNotThrow(F);
+ Changed |= setDoesNotCapture(F, 1);
+ Changed |= setOnlyReadsMemory(F, 1);
+ return Changed;
case LibFunc_memmove:
Changed |= setReturnedArg(F, 0);
- LLVM_FALLTHROUGH;
+ Changed |= setDoesNotThrow(F);
+ Changed |= setDoesNotCapture(F, 1);
+ Changed |= setOnlyReadsMemory(F, 1);
+ return Changed;
case LibFunc_mempcpy:
case LibFunc_memccpy:
+ Changed |= setDoesNotAlias(F, 0);
+ Changed |= setDoesNotAlias(F, 1);
Changed |= setDoesNotThrow(F);
Changed |= setDoesNotCapture(F, 1);
Changed |= setOnlyReadsMemory(F, 1);
@@ -760,9 +785,8 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) {
}
}
-bool llvm::hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty,
- LibFunc DoubleFn, LibFunc FloatFn,
- LibFunc LongDoubleFn) {
+bool llvm::hasFloatFn(const TargetLibraryInfo *TLI, Type *Ty,
+ LibFunc DoubleFn, LibFunc FloatFn, LibFunc LongDoubleFn) {
switch (Ty->getTypeID()) {
case Type::HalfTyID:
return false;
@@ -775,10 +799,10 @@ bool llvm::hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty,
}
}
-StringRef llvm::getUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty,
- LibFunc DoubleFn, LibFunc FloatFn,
- LibFunc LongDoubleFn) {
- assert(hasUnaryFloatFn(TLI, Ty, DoubleFn, FloatFn, LongDoubleFn) &&
+StringRef llvm::getFloatFnName(const TargetLibraryInfo *TLI, Type *Ty,
+ LibFunc DoubleFn, LibFunc FloatFn,
+ LibFunc LongDoubleFn) {
+ assert(hasFloatFn(TLI, Ty, DoubleFn, FloatFn, LongDoubleFn) &&
"Cannot get name for unavailable function!");
switch (Ty->getTypeID()) {
@@ -827,6 +851,12 @@ Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL,
B.getInt8PtrTy(), castToCStr(Ptr, B), B, TLI);
}
+Value *llvm::emitStrDup(Value *Ptr, IRBuilder<> &B,
+ const TargetLibraryInfo *TLI) {
+ return emitLibCall(LibFunc_strdup, B.getInt8PtrTy(), B.getInt8PtrTy(),
+ castToCStr(Ptr, B), B, TLI);
+}
+
Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B,
const TargetLibraryInfo *TLI) {
Type *I8Ptr = B.getInt8PtrTy();
@@ -1045,24 +1075,28 @@ Value *llvm::emitUnaryFloatFnCall(Value *Op, const TargetLibraryInfo *TLI,
LibFunc LongDoubleFn, IRBuilder<> &B,
const AttributeList &Attrs) {
// Get the name of the function according to TLI.
- StringRef Name = getUnaryFloatFn(TLI, Op->getType(),
- DoubleFn, FloatFn, LongDoubleFn);
+ StringRef Name = getFloatFnName(TLI, Op->getType(),
+ DoubleFn, FloatFn, LongDoubleFn);
return emitUnaryFloatFnCallHelper(Op, Name, B, Attrs);
}
-Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name,
- IRBuilder<> &B, const AttributeList &Attrs) {
+static Value *emitBinaryFloatFnCallHelper(Value *Op1, Value *Op2,
+ StringRef Name, IRBuilder<> &B,
+ const AttributeList &Attrs) {
assert((Name != "") && "Must specify Name to emitBinaryFloatFnCall");
- SmallString<20> NameBuffer;
- appendTypeSuffix(Op1, Name, NameBuffer);
-
Module *M = B.GetInsertBlock()->getModule();
- FunctionCallee Callee = M->getOrInsertFunction(
- Name, Op1->getType(), Op1->getType(), Op2->getType());
- CallInst *CI = B.CreateCall(Callee, {Op1, Op2}, Name);
- CI->setAttributes(Attrs);
+ FunctionCallee Callee = M->getOrInsertFunction(Name, Op1->getType(),
+ Op1->getType(), Op2->getType());
+ CallInst *CI = B.CreateCall(Callee, { Op1, Op2 }, Name);
+
+ // The incoming attribute set may have come from a speculatable intrinsic, but
+ // is being replaced with a library call which is not allowed to be
+ // speculatable.
+ CI->setAttributes(Attrs.removeAttribute(B.getContext(),
+ AttributeList::FunctionIndex,
+ Attribute::Speculatable));
if (const Function *F =
dyn_cast<Function>(Callee.getCallee()->stripPointerCasts()))
CI->setCallingConv(F->getCallingConv());
@@ -1070,6 +1104,28 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name,
return CI;
}
+Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name,
+ IRBuilder<> &B, const AttributeList &Attrs) {
+ assert((Name != "") && "Must specify Name to emitBinaryFloatFnCall");
+
+ SmallString<20> NameBuffer;
+ appendTypeSuffix(Op1, Name, NameBuffer);
+
+ return emitBinaryFloatFnCallHelper(Op1, Op2, Name, B, Attrs);
+}
+
+Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2,
+ const TargetLibraryInfo *TLI,
+ LibFunc DoubleFn, LibFunc FloatFn,
+ LibFunc LongDoubleFn, IRBuilder<> &B,
+ const AttributeList &Attrs) {
+ // Get the name of the function according to TLI.
+ StringRef Name = getFloatFnName(TLI, Op1->getType(),
+ DoubleFn, FloatFn, LongDoubleFn);
+
+ return emitBinaryFloatFnCallHelper(Op1, Op2, Name, B, Attrs);
+}
+
Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B,
const TargetLibraryInfo *TLI) {
if (!TLI->has(LibFunc_putchar))
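
setDoesNotAlias above follows the same add-the-attribute-once pattern as the existing helpers, so the Changed flag only flips when something new was inferred. A generic sketch of that pattern (assumption: the helper name is invented; only the attribute APIs are real).

#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"

// Returns true only when the attribute was actually added, which keeps the
// caller's Changed accounting accurate across repeated inference runs.
static bool addParamAttrOnce(llvm::Function &F, unsigned ArgNo,
                             llvm::Attribute::AttrKind Kind) {
  if (F.hasParamAttribute(ArgNo, Kind))
    return false;
  F.addParamAttr(ArgNo, Kind);
  return true;
}
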
diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp b/lib/Transforms/Utils/BypassSlowDivision.cpp
index df299f673f65..9a6761040bd8 100644
--- a/lib/Transforms/Utils/BypassSlowDivision.cpp
+++ b/lib/Transforms/Utils/BypassSlowDivision.cpp
@@ -448,13 +448,17 @@ bool llvm::bypassSlowDivision(BasicBlock *BB,
DivCacheTy PerBBDivCache;
bool MadeChange = false;
- Instruction* Next = &*BB->begin();
+ Instruction *Next = &*BB->begin();
while (Next != nullptr) {
// We may add instructions immediately after I, but we want to skip over
// them.
- Instruction* I = Next;
+ Instruction *I = Next;
Next = Next->getNextNode();
+ // Ignore dead code to save time and avoid bugs.
+ if (I->hasNUses(0))
+ continue;
+
FastDivInsertionTask Task(I, BypassWidths);
if (Value *Replacement = Task.getReplacement(PerBBDivCache)) {
I->replaceAllUsesWith(Replacement);
diff --git a/lib/Transforms/Utils/CanonicalizeAliases.cpp b/lib/Transforms/Utils/CanonicalizeAliases.cpp
index 455fcbb1cf98..3c7c8d872595 100644
--- a/lib/Transforms/Utils/CanonicalizeAliases.cpp
+++ b/lib/Transforms/Utils/CanonicalizeAliases.cpp
@@ -33,6 +33,7 @@
#include "llvm/IR/Operator.h"
#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
using namespace llvm;
diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp
index 1026c9d37038..75e8963303c2 100644
--- a/lib/Transforms/Utils/CloneFunction.cpp
+++ b/lib/Transforms/Utils/CloneFunction.cpp
@@ -210,6 +210,21 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc,
RemapInstruction(&II, VMap,
ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges,
TypeMapper, Materializer);
+
+  // Register all DICompileUnits of the old parent module in the new parent
+  // module.
+  auto* OldModule = OldFunc->getParent();
+  auto* NewModule = NewFunc->getParent();
+  if (OldModule && NewModule && OldModule != NewModule &&
+      DIFinder.compile_unit_count()) {
+ auto* NMD = NewModule->getOrInsertNamedMetadata("llvm.dbg.cu");
+ // Avoid multiple insertions of the same DICompileUnit to NMD.
+ SmallPtrSet<const void*, 8> Visited;
+ for (auto* Operand : NMD->operands())
+ Visited.insert(Operand);
+ for (auto* Unit : DIFinder.compile_units())
+ // VMap.MD()[Unit] == Unit
+ if (Visited.insert(Unit).second)
+ NMD->addOperand(Unit);
+ }
}
/// Return a copy of the specified function and add it to that function's
diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp
index 7ddf59becba9..2c8c3abb2922 100644
--- a/lib/Transforms/Utils/CloneModule.cpp
+++ b/lib/Transforms/Utils/CloneModule.cpp
@@ -48,7 +48,7 @@ std::unique_ptr<Module> llvm::CloneModule(
function_ref<bool(const GlobalValue *)> ShouldCloneDefinition) {
// First off, we need to create the new module.
std::unique_ptr<Module> New =
- llvm::make_unique<Module>(M.getModuleIdentifier(), M.getContext());
+ std::make_unique<Module>(M.getModuleIdentifier(), M.getContext());
New->setSourceFileName(M.getSourceFileName());
New->setDataLayout(M.getDataLayout());
New->setTargetTriple(M.getTargetTriple());
@@ -181,13 +181,25 @@ std::unique_ptr<Module> llvm::CloneModule(
}
// And named metadata....
+ const auto* LLVM_DBG_CU = M.getNamedMetadata("llvm.dbg.cu");
for (Module::const_named_metadata_iterator I = M.named_metadata_begin(),
E = M.named_metadata_end();
I != E; ++I) {
const NamedMDNode &NMD = *I;
NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName());
- for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i)
- NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap));
+ if (&NMD == LLVM_DBG_CU) {
+ // Do not insert duplicate operands.
+ SmallPtrSet<const void*, 8> Visited;
+ for (const auto* Operand : NewNMD->operands())
+ Visited.insert(Operand);
+ for (const auto* Operand : NMD.operands()) {
+ auto* MappedOperand = MapMetadata(Operand, VMap);
+ if (Visited.insert(MappedOperand).second)
+ NewNMD->addOperand(MappedOperand);
+ }
+ } else
+ for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i)
+ NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap));
}
return New;
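
The llvm.dbg.cu special case above only guards against inserting the same compile unit twice when cloning into a module that may already reference it. A generic sketch of that de-duplicating insert (assumption: the helper is invented; it mirrors the Visited set in the hunk).

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Metadata.h"

// Appends each operand at most once, even when the named metadata node already
// contains some of them from an earlier clone.
static void addOperandsOnce(llvm::NamedMDNode *NMD,
                            llvm::ArrayRef<llvm::MDNode *> NewOps) {
  llvm::SmallPtrSet<const llvm::MDNode *, 8> Seen;
  for (const llvm::MDNode *Existing : NMD->operands())
    Seen.insert(Existing);
  for (llvm::MDNode *Op : NewOps)
    if (Seen.insert(Op).second)
      NMD->addOperand(Op);
}
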
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp
index fa6d3f8ae873..0298ff9a395f 100644
--- a/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/lib/Transforms/Utils/CodeExtractor.cpp
@@ -293,10 +293,8 @@ static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
CommonExitBlock = Succ;
continue;
}
- if (CommonExitBlock == Succ)
- continue;
-
- return true;
+ if (CommonExitBlock != Succ)
+ return true;
}
return false;
};
@@ -307,52 +305,79 @@ static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
return CommonExitBlock;
}
-bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
- Instruction *Addr) const {
- AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
- Function *Func = (*Blocks.begin())->getParent();
- for (BasicBlock &BB : *Func) {
- if (Blocks.count(&BB))
- continue;
- for (Instruction &II : BB) {
- if (isa<DbgInfoIntrinsic>(II))
- continue;
+CodeExtractorAnalysisCache::CodeExtractorAnalysisCache(Function &F) {
+ for (BasicBlock &BB : F) {
+ for (Instruction &II : BB.instructionsWithoutDebug())
+ if (auto *AI = dyn_cast<AllocaInst>(&II))
+ Allocas.push_back(AI);
- unsigned Opcode = II.getOpcode();
- Value *MemAddr = nullptr;
- switch (Opcode) {
- case Instruction::Store:
- case Instruction::Load: {
- if (Opcode == Instruction::Store) {
- StoreInst *SI = cast<StoreInst>(&II);
- MemAddr = SI->getPointerOperand();
- } else {
- LoadInst *LI = cast<LoadInst>(&II);
- MemAddr = LI->getPointerOperand();
- }
- // Global variable can not be aliased with locals.
- if (dyn_cast<Constant>(MemAddr))
- break;
- Value *Base = MemAddr->stripInBoundsConstantOffsets();
- if (!isa<AllocaInst>(Base) || Base == AI)
- return false;
+ findSideEffectInfoForBlock(BB);
+ }
+}
+
+void CodeExtractorAnalysisCache::findSideEffectInfoForBlock(BasicBlock &BB) {
+ for (Instruction &II : BB.instructionsWithoutDebug()) {
+ unsigned Opcode = II.getOpcode();
+ Value *MemAddr = nullptr;
+ switch (Opcode) {
+ case Instruction::Store:
+ case Instruction::Load: {
+ if (Opcode == Instruction::Store) {
+ StoreInst *SI = cast<StoreInst>(&II);
+ MemAddr = SI->getPointerOperand();
+ } else {
+ LoadInst *LI = cast<LoadInst>(&II);
+ MemAddr = LI->getPointerOperand();
+ }
+ // Global variable can not be aliased with locals.
+ if (dyn_cast<Constant>(MemAddr))
break;
+ Value *Base = MemAddr->stripInBoundsConstantOffsets();
+ if (!isa<AllocaInst>(Base)) {
+ SideEffectingBlocks.insert(&BB);
+ return;
}
- default: {
- IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
- if (IntrInst) {
- if (IntrInst->isLifetimeStartOrEnd())
- break;
- return false;
- }
- // Treat all the other cases conservatively if it has side effects.
- if (II.mayHaveSideEffects())
- return false;
+ BaseMemAddrs[&BB].insert(Base);
+ break;
+ }
+ default: {
+ IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
+ if (IntrInst) {
+ if (IntrInst->isLifetimeStartOrEnd())
+ break;
+ SideEffectingBlocks.insert(&BB);
+ return;
}
+ // Treat all the other cases conservatively if it has side effects.
+ if (II.mayHaveSideEffects()) {
+ SideEffectingBlocks.insert(&BB);
+ return;
}
}
+ }
}
+}
+bool CodeExtractorAnalysisCache::doesBlockContainClobberOfAddr(
+ BasicBlock &BB, AllocaInst *Addr) const {
+ if (SideEffectingBlocks.count(&BB))
+ return true;
+ auto It = BaseMemAddrs.find(&BB);
+ if (It != BaseMemAddrs.end())
+ return It->second.count(Addr);
+ return false;
+}
+
+bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
+ const CodeExtractorAnalysisCache &CEAC, Instruction *Addr) const {
+ AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
+ Function *Func = (*Blocks.begin())->getParent();
+ for (BasicBlock &BB : *Func) {
+ if (Blocks.count(&BB))
+ continue;
+ if (CEAC.doesBlockContainClobberOfAddr(BB, AI))
+ return false;
+ }
return true;
}
@@ -415,7 +440,8 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
// outline region. If there are not other untracked uses of the address, return
// the pair of markers if found; otherwise return a pair of nullptr.
CodeExtractor::LifetimeMarkerInfo
-CodeExtractor::getLifetimeMarkers(Instruction *Addr,
+CodeExtractor::getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
+ Instruction *Addr,
BasicBlock *ExitBlock) const {
LifetimeMarkerInfo Info;
@@ -447,7 +473,7 @@ CodeExtractor::getLifetimeMarkers(Instruction *Addr,
Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd);
// Do legality check.
if ((Info.SinkLifeStart || Info.HoistLifeEnd) &&
- !isLegalToShrinkwrapLifetimeMarkers(Addr))
+ !isLegalToShrinkwrapLifetimeMarkers(CEAC, Addr))
return {};
// Check to see if we have a place to do hoisting, if not, bail.
@@ -457,7 +483,8 @@ CodeExtractor::getLifetimeMarkers(Instruction *Addr,
return Info;
}
-void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
+void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC,
+ ValueSet &SinkCands, ValueSet &HoistCands,
BasicBlock *&ExitBlock) const {
Function *Func = (*Blocks.begin())->getParent();
ExitBlock = getCommonExitBlock(Blocks);
@@ -478,74 +505,104 @@ void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
return true;
};
- for (BasicBlock &BB : *Func) {
- if (Blocks.count(&BB))
+ // Look up allocas in the original function in CodeExtractorAnalysisCache, as
+ // this is much faster than walking all the instructions.
+ for (AllocaInst *AI : CEAC.getAllocas()) {
+ BasicBlock *BB = AI->getParent();
+ if (Blocks.count(BB))
continue;
- for (Instruction &II : BB) {
- auto *AI = dyn_cast<AllocaInst>(&II);
- if (!AI)
- continue;
- LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(AI, ExitBlock);
- bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
- if (Moved) {
- LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
- SinkCands.insert(AI);
- continue;
- }
+ // As a prior call to extractCodeRegion() may have shrinkwrapped the alloca,
+ // check whether it is actually still in the original function.
+ Function *AIFunc = BB->getParent();
+ if (AIFunc != Func)
+ continue;
- // Follow any bitcasts.
- SmallVector<Instruction *, 2> Bitcasts;
- SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
- for (User *U : AI->users()) {
- if (U->stripInBoundsConstantOffsets() == AI) {
- Instruction *Bitcast = cast<Instruction>(U);
- LifetimeMarkerInfo LMI = getLifetimeMarkers(Bitcast, ExitBlock);
- if (LMI.LifeStart) {
- Bitcasts.push_back(Bitcast);
- BitcastLifetimeInfo.push_back(LMI);
- continue;
- }
- }
+ LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(CEAC, AI, ExitBlock);
+ bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
+ if (Moved) {
+ LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
+ SinkCands.insert(AI);
+ continue;
+ }
- // Found unknown use of AI.
- if (!definedInRegion(Blocks, U)) {
- Bitcasts.clear();
- break;
+ // Follow any bitcasts.
+ SmallVector<Instruction *, 2> Bitcasts;
+ SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
+ for (User *U : AI->users()) {
+ if (U->stripInBoundsConstantOffsets() == AI) {
+ Instruction *Bitcast = cast<Instruction>(U);
+ LifetimeMarkerInfo LMI = getLifetimeMarkers(CEAC, Bitcast, ExitBlock);
+ if (LMI.LifeStart) {
+ Bitcasts.push_back(Bitcast);
+ BitcastLifetimeInfo.push_back(LMI);
+ continue;
}
}
- // Either no bitcasts reference the alloca or there are unknown uses.
- if (Bitcasts.empty())
- continue;
+ // Found unknown use of AI.
+ if (!definedInRegion(Blocks, U)) {
+ Bitcasts.clear();
+ break;
+ }
+ }
- LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
- SinkCands.insert(AI);
- for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
- Instruction *BitcastAddr = Bitcasts[I];
- const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
- assert(LMI.LifeStart &&
- "Unsafe to sink bitcast without lifetime markers");
- moveOrIgnoreLifetimeMarkers(LMI);
- if (!definedInRegion(Blocks, BitcastAddr)) {
- LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
- << "\n");
- SinkCands.insert(BitcastAddr);
- }
+ // Either no bitcasts reference the alloca or there are unknown uses.
+ if (Bitcasts.empty())
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
+ SinkCands.insert(AI);
+ for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
+ Instruction *BitcastAddr = Bitcasts[I];
+ const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
+ assert(LMI.LifeStart &&
+ "Unsafe to sink bitcast without lifetime markers");
+ moveOrIgnoreLifetimeMarkers(LMI);
+ if (!definedInRegion(Blocks, BitcastAddr)) {
+ LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
+ << "\n");
+ SinkCands.insert(BitcastAddr);
}
}
}
}
+bool CodeExtractor::isEligible() const {
+ if (Blocks.empty())
+ return false;
+ BasicBlock *Header = *Blocks.begin();
+ Function *F = Header->getParent();
+
+ // For functions with varargs, check that varargs handling is only done in the
+ // outlined function, i.e. vastart and vaend are only used in outlined blocks.
+ if (AllowVarArgs && F->getFunctionType()->isVarArg()) {
+ auto containsVarArgIntrinsic = [](const Instruction &I) {
+ if (const CallInst *CI = dyn_cast<CallInst>(&I))
+ if (const Function *Callee = CI->getCalledFunction())
+ return Callee->getIntrinsicID() == Intrinsic::vastart ||
+ Callee->getIntrinsicID() == Intrinsic::vaend;
+ return false;
+ };
+
+ for (auto &BB : *F) {
+ if (Blocks.count(&BB))
+ continue;
+ if (llvm::any_of(BB, containsVarArgIntrinsic))
+ return false;
+ }
+ }
+ return true;
+}
+
void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
const ValueSet &SinkCands) const {
for (BasicBlock *BB : Blocks) {
// If a used value is defined outside the region, it's an input. If an
// instruction is used outside the region, it's an output.
for (Instruction &II : *BB) {
- for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
- ++OI) {
- Value *V = *OI;
+ for (auto &OI : II.operands()) {
+ Value *V = OI;
if (!SinkCands.count(V) && definedInCaller(Blocks, V))
Inputs.insert(V);
}
@@ -904,12 +961,12 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
// within the new function. This must be done before we lose track of which
// blocks were originally in the code region.
std::vector<User *> Users(header->user_begin(), header->user_end());
- for (unsigned i = 0, e = Users.size(); i != e; ++i)
+ for (auto &U : Users)
// The BasicBlock which contains the branch is not in the region;
// modify the branch target to point to the new block.
- if (Instruction *I = dyn_cast<Instruction>(Users[i]))
- if (I->isTerminator() && !Blocks.count(I->getParent()) &&
- I->getParent()->getParent() == oldFunction)
+ if (Instruction *I = dyn_cast<Instruction>(U))
+ if (I->isTerminator() && I->getFunction() == oldFunction &&
+ !Blocks.count(I->getParent()))
I->replaceUsesOfWith(header, newHeader);
return newFunction;
@@ -1277,13 +1334,6 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {
// Insert this basic block into the new function
newBlocks.push_back(Block);
-
- // Remove @llvm.assume calls that were moved to the new function from the
- // old function's assumption cache.
- if (AC)
- for (auto &I : *Block)
- if (match(&I, m_Intrinsic<Intrinsic::assume>()))
- AC->unregisterAssumption(cast<CallInst>(&I));
}
}
@@ -1332,7 +1382,8 @@ void CodeExtractor::calculateNewCallTerminatorWeights(
MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
}
-Function *CodeExtractor::extractCodeRegion() {
+Function *
+CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
if (!isEligible())
return nullptr;
@@ -1341,27 +1392,6 @@ Function *CodeExtractor::extractCodeRegion() {
BasicBlock *header = *Blocks.begin();
Function *oldFunction = header->getParent();
- // For functions with varargs, check that varargs handling is only done in the
- // outlined function, i.e vastart and vaend are only used in outlined blocks.
- if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) {
- auto containsVarArgIntrinsic = [](Instruction &I) {
- if (const CallInst *CI = dyn_cast<CallInst>(&I))
- if (const Function *F = CI->getCalledFunction())
- return F->getIntrinsicID() == Intrinsic::vastart ||
- F->getIntrinsicID() == Intrinsic::vaend;
- return false;
- };
-
- for (auto &BB : *oldFunction) {
- if (Blocks.count(&BB))
- continue;
- if (llvm::any_of(BB, containsVarArgIntrinsic))
- return nullptr;
- }
- }
- ValueSet inputs, outputs, SinkingCands, HoistingCands;
- BasicBlock *CommonExit = nullptr;
-
// Calculate the entry frequency of the new function before we change the root
// block.
BlockFrequency EntryFreq;
@@ -1375,6 +1405,15 @@ Function *CodeExtractor::extractCodeRegion() {
}
}
+ if (AC) {
+ // Remove @llvm.assume calls that were moved to the new function from the
+ // old function's assumption cache.
+ for (BasicBlock *Block : Blocks)
+ for (auto &I : *Block)
+ if (match(&I, m_Intrinsic<Intrinsic::assume>()))
+ AC->unregisterAssumption(cast<CallInst>(&I));
+ }
+
// If we have any return instructions in the region, split those blocks so
// that the return is not in the region.
splitReturnBlocks();
@@ -1428,7 +1467,9 @@ Function *CodeExtractor::extractCodeRegion() {
}
newFuncRoot->getInstList().push_back(BranchI);
- findAllocas(SinkingCands, HoistingCands, CommonExit);
+ ValueSet inputs, outputs, SinkingCands, HoistingCands;
+ BasicBlock *CommonExit = nullptr;
+ findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
assert(HoistingCands.empty() || CommonExit);
// Find inputs to, outputs from the code region.
@@ -1563,5 +1604,17 @@ Function *CodeExtractor::extractCodeRegion() {
});
LLVM_DEBUG(if (verifyFunction(*oldFunction))
report_fatal_error("verification of oldFunction failed!"));
+ LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, AC))
+ report_fatal_error("Stale Asumption cache for old Function!"));
return newFunction;
}
+
+bool CodeExtractor::verifyAssumptionCache(const Function& F,
+ AssumptionCache *AC) {
+ for (auto AssumeVH : AC->assumptions()) {
+ CallInst *I = cast<CallInst>(AssumeVH);
+ if (I->getFunction() != &F)
+ return true;
+ }
+ return false;
+}
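The hunks above change CodeExtractor's interface: callers now build a per-function
CodeExtractorAnalysisCache once and hand it to every extractCodeRegion() call, and the
assumption-cache cleanup moved from moveCodeToFunction() into extractCodeRegion() itself.
The following is a minimal usage sketch under those assumptions; the surrounding function
and variable names (outlineRegions, Regions) are illustrative, not part of the patch.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
using namespace llvm;

// Sketch: outline several single-entry regions of F, reusing one analysis
// cache so the function body is only scanned for side effects once.
static void outlineRegions(Function &F, DominatorTree &DT,
                           ArrayRef<SmallVector<BasicBlock *, 8>> Regions) {
  CodeExtractorAnalysisCache CEAC(F); // caches side-effecting blocks of F
  for (const SmallVector<BasicBlock *, 8> &Blocks : Regions) {
    CodeExtractor CE(Blocks, &DT);
    if (!CE.isEligible())
      continue;
    if (Function *Outlined = CE.extractCodeRegion(CEAC))
      (void)Outlined; // e.g. hand off to further processing
  }
}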
diff --git a/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/lib/Transforms/Utils/EntryExitInstrumenter.cpp
index 4aa40eeadda4..57e2ff0251a9 100644
--- a/lib/Transforms/Utils/EntryExitInstrumenter.cpp
+++ b/lib/Transforms/Utils/EntryExitInstrumenter.cpp
@@ -24,7 +24,7 @@ static void insertCall(Function &CurFn, StringRef Func,
if (Func == "mcount" ||
Func == ".mcount" ||
- Func == "\01__gnu_mcount_nc" ||
+ Func == "llvm.arm.gnu.eabi.mcount" ||
Func == "\01_mcount" ||
Func == "\01mcount" ||
Func == "__mcount" ||
diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp
index 0e203f4e075d..ad36790b8c6a 100644
--- a/lib/Transforms/Utils/Evaluator.cpp
+++ b/lib/Transforms/Utils/Evaluator.cpp
@@ -469,7 +469,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst,
return false; // Cannot handle array allocs.
}
Type *Ty = AI->getAllocatedType();
- AllocaTmps.push_back(llvm::make_unique<GlobalVariable>(
+ AllocaTmps.push_back(std::make_unique<GlobalVariable>(
Ty, false, GlobalValue::InternalLinkage, UndefValue::get(Ty),
AI->getName(), /*TLMode=*/GlobalValue::NotThreadLocal,
AI->getType()->getPointerAddressSpace()));
diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp
index 0c52e6f3703b..893f23eb6048 100644
--- a/lib/Transforms/Utils/FlattenCFG.cpp
+++ b/lib/Transforms/Utils/FlattenCFG.cpp
@@ -67,7 +67,7 @@ public:
/// Before:
/// ......
/// %cmp10 = fcmp une float %tmp1, %tmp2
-/// br i1 %cmp1, label %if.then, label %lor.rhs
+/// br i1 %cmp10, label %if.then, label %lor.rhs
///
/// lor.rhs:
/// ......
@@ -251,8 +251,8 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) {
bool EverChanged = false;
for (; CurrBlock != FirstCondBlock;
CurrBlock = CurrBlock->getSinglePredecessor()) {
- BranchInst *BI = dyn_cast<BranchInst>(CurrBlock->getTerminator());
- CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
+ auto *BI = cast<BranchInst>(CurrBlock->getTerminator());
+ auto *CI = dyn_cast<CmpInst>(BI->getCondition());
if (!CI)
continue;
@@ -278,7 +278,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) {
// Do the transformation.
BasicBlock *CB;
- BranchInst *PBI = dyn_cast<BranchInst>(FirstCondBlock->getTerminator());
+ BranchInst *PBI = cast<BranchInst>(FirstCondBlock->getTerminator());
bool Iteration = true;
IRBuilder<>::InsertPointGuard Guard(Builder);
Value *PC = PBI->getCondition();
@@ -444,7 +444,7 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) {
FirstEntryBlock->getInstList().pop_back();
FirstEntryBlock->getInstList()
.splice(FirstEntryBlock->end(), SecondEntryBlock->getInstList());
- BranchInst *PBI = dyn_cast<BranchInst>(FirstEntryBlock->getTerminator());
+ BranchInst *PBI = cast<BranchInst>(FirstEntryBlock->getTerminator());
Value *CC = PBI->getCondition();
BasicBlock *SaveInsertBB = Builder.GetInsertBlock();
BasicBlock::iterator SaveInsertPt = Builder.GetInsertPoint();
@@ -453,6 +453,16 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) {
PBI->replaceUsesOfWith(CC, NC);
Builder.SetInsertPoint(SaveInsertBB, SaveInsertPt);
+ // Handle PHI node to replace its predecessors to FirstEntryBlock.
+ for (BasicBlock *Succ : successors(PBI)) {
+ for (PHINode &Phi : Succ->phis()) {
+ for (unsigned i = 0, e = Phi.getNumIncomingValues(); i != e; ++i) {
+ if (Phi.getIncomingBlock(i) == SecondEntryBlock)
+ Phi.setIncomingBlock(i, FirstEntryBlock);
+ }
+ }
+ }
+
// Remove IfTrue1
if (IfTrue1 != FirstEntryBlock) {
IfTrue1->dropAllReferences();
diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp
index c9cc0990f237..76b4635ad501 100644
--- a/lib/Transforms/Utils/FunctionImportUtils.cpp
+++ b/lib/Transforms/Utils/FunctionImportUtils.cpp
@@ -210,7 +210,7 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) {
if (Function *F = dyn_cast<Function>(&GV)) {
if (!F->isDeclaration()) {
for (auto &S : VI.getSummaryList()) {
- FunctionSummary *FS = dyn_cast<FunctionSummary>(S->getBaseObject());
+ auto *FS = cast<FunctionSummary>(S->getBaseObject());
if (FS->modulePath() == M.getModuleIdentifier()) {
F->setEntryCount(Function::ProfileCount(FS->entryCount(),
Function::PCT_Synthetic));
diff --git a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp
index 8041e66e6c4c..ea93f99d69e3 100644
--- a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp
+++ b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp
@@ -25,8 +25,8 @@ ImportedFunctionsInliningStatistics::createInlineGraphNode(const Function &F) {
auto &ValueLookup = NodesMap[F.getName()];
if (!ValueLookup) {
- ValueLookup = llvm::make_unique<InlineGraphNode>();
- ValueLookup->Imported = F.getMetadata("thinlto_src_module") != nullptr;
+ ValueLookup = std::make_unique<InlineGraphNode>();
+ ValueLookup->Imported = F.hasMetadata("thinlto_src_module");
}
return *ValueLookup;
}
@@ -64,7 +64,7 @@ void ImportedFunctionsInliningStatistics::setModuleInfo(const Module &M) {
if (F.isDeclaration())
continue;
AllFunctions++;
- ImportedFunctions += int(F.getMetadata("thinlto_src_module") != nullptr);
+ ImportedFunctions += int(F.hasMetadata("thinlto_src_module"));
}
}
static std::string getStatString(const char *Msg, int32_t Fraction, int32_t All,
diff --git a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
index 8c67d1dc6eb3..ed28fffc22b5 100644
--- a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
+++ b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
@@ -533,7 +533,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
}
bool LibCallsShrinkWrapLegacyPass::runOnFunction(Function &F) {
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
return runImpl(F, TLI, DT);
diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp
index 39b6b889f91c..5bcd05757ec1 100644
--- a/lib/Transforms/Utils/Local.cpp
+++ b/lib/Transforms/Utils/Local.cpp
@@ -324,8 +324,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
Value *Address = IBI->getAddress();
IBI->eraseFromParent();
if (DeleteDeadConditions)
+ // Delete pointer cast instructions.
RecursivelyDeleteTriviallyDeadInstructions(Address, TLI);
+ // Also zap the blockaddress constant if there are no users remaining,
+ // otherwise the destination is still marked as having its address taken.
+ if (BA->use_empty())
+ BA->destroyConstant();
+
// If we didn't find our destination in the IBI successor list, then we
// have undefined behavior. Replace the unconditional branch with an
// 'unreachable' instruction.
@@ -633,17 +639,6 @@ bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB,
// Control Flow Graph Restructuring.
//
-/// RemovePredecessorAndSimplify - Like BasicBlock::removePredecessor, this
-/// method is called when we're about to delete Pred as a predecessor of BB. If
-/// BB contains any PHI nodes, this drops the entries in the PHI nodes for Pred.
-///
-/// Unlike the removePredecessor method, this attempts to simplify uses of PHI
-/// nodes that collapse into identity values. For example, if we have:
-/// x = phi(1, 0, 0, 0)
-/// y = and x, z
-///
-/// .. and delete the predecessor corresponding to the '1', this will attempt to
-/// recursively fold the and to 0.
void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred,
DomTreeUpdater *DTU) {
// This only adjusts blocks with PHI nodes.
@@ -672,10 +667,6 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred,
DTU->applyUpdatesPermissive({{DominatorTree::Delete, Pred, BB}});
}
-/// MergeBasicBlockIntoOnlyPred - DestBB is a block with one predecessor and its
-/// predecessor is known to have one successor (DestBB!). Eliminate the edge
-/// between them, moving the instructions in the predecessor into DestBB and
-/// deleting the predecessor block.
void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB,
DomTreeUpdater *DTU) {
@@ -755,15 +746,14 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB,
}
}
-/// CanMergeValues - Return true if we can choose one of these values to use
-/// in place of the other. Note that we will always choose the non-undef
-/// value to keep.
+/// Return true if we can choose one of these values to use in place of the
+/// other. Note that we will always choose the non-undef value to keep.
static bool CanMergeValues(Value *First, Value *Second) {
return First == Second || isa<UndefValue>(First) || isa<UndefValue>(Second);
}
-/// CanPropagatePredecessorsForPHIs - Return true if we can fold BB, an
-/// almost-empty BB ending in an unconditional branch to Succ, into Succ.
+/// Return true if we can fold BB, an almost-empty BB ending in an unconditional
+/// branch to Succ, into Succ.
///
/// Assumption: Succ is the single successor for BB.
static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) {
@@ -956,11 +946,6 @@ static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB,
replaceUndefValuesInPhi(PN, IncomingValues);
}
-/// TryToSimplifyUncondBranchFromEmptyBlock - BB is known to contain an
-/// unconditional branch, and contains no instructions other than PHI nodes,
-/// potential side-effect free intrinsics and the branch. If possible,
-/// eliminate BB by rewriting all the predecessors to branch to the successor
-/// block and return true. If we can't transform, return false.
bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
DomTreeUpdater *DTU) {
assert(BB != &BB->getParent()->getEntryBlock() &&
@@ -1088,10 +1073,6 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
return true;
}
-/// EliminateDuplicatePHINodes - Check for and eliminate duplicate PHI
-/// nodes in this block. This doesn't try to be clever about PHI nodes
-/// which differ only in the order of the incoming values, but instcombine
-/// orders them so it usually won't matter.
bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) {
// This implementation doesn't currently consider undef operands
// specially. Theoretically, two phis which are identical except for
@@ -1151,10 +1132,10 @@ bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) {
/// often possible though. If alignment is important, a more reliable approach
/// is to simply align all global variables and allocation instructions to
/// their preferred alignment from the beginning.
-static unsigned enforceKnownAlignment(Value *V, unsigned Align,
+static unsigned enforceKnownAlignment(Value *V, unsigned Alignment,
unsigned PrefAlign,
const DataLayout &DL) {
- assert(PrefAlign > Align);
+ assert(PrefAlign > Alignment);
V = V->stripPointerCasts();
@@ -1165,36 +1146,36 @@ static unsigned enforceKnownAlignment(Value *V, unsigned Align,
// stripPointerCasts recurses through infinite layers of bitcasts,
// while computeKnownBits is not allowed to traverse more than 6
// levels.
- Align = std::max(AI->getAlignment(), Align);
- if (PrefAlign <= Align)
- return Align;
+ Alignment = std::max(AI->getAlignment(), Alignment);
+ if (PrefAlign <= Alignment)
+ return Alignment;
// If the preferred alignment is greater than the natural stack alignment
// then don't round up. This avoids dynamic stack realignment.
- if (DL.exceedsNaturalStackAlignment(PrefAlign))
- return Align;
- AI->setAlignment(PrefAlign);
+ if (DL.exceedsNaturalStackAlignment(Align(PrefAlign)))
+ return Alignment;
+ AI->setAlignment(MaybeAlign(PrefAlign));
return PrefAlign;
}
if (auto *GO = dyn_cast<GlobalObject>(V)) {
// TODO: as above, this shouldn't be necessary.
- Align = std::max(GO->getAlignment(), Align);
- if (PrefAlign <= Align)
- return Align;
+ Alignment = std::max(GO->getAlignment(), Alignment);
+ if (PrefAlign <= Alignment)
+ return Alignment;
// If there is a large requested alignment and we can, bump up the alignment
// of the global. If the memory we set aside for the global may not be the
// memory used by the final program then it is impossible for us to reliably
// enforce the preferred alignment.
if (!GO->canIncreaseAlignment())
- return Align;
+ return Alignment;
- GO->setAlignment(PrefAlign);
+ GO->setAlignment(MaybeAlign(PrefAlign));
return PrefAlign;
}
- return Align;
+ return Alignment;
}
unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign,
@@ -1397,7 +1378,12 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
/// Determine whether this alloca is either a VLA or an array.
static bool isArray(AllocaInst *AI) {
return AI->isArrayAllocation() ||
- AI->getType()->getElementType()->isArrayTy();
+ (AI->getAllocatedType() && AI->getAllocatedType()->isArrayTy());
+}
+
+/// Determine whether this alloca is a structure.
+static bool isStructure(AllocaInst *AI) {
+ return AI->getAllocatedType() && AI->getAllocatedType()->isStructTy();
}
/// LowerDbgDeclare - Lowers llvm.dbg.declare intrinsics into appropriate set
@@ -1422,7 +1408,7 @@ bool llvm::LowerDbgDeclare(Function &F) {
// stored on the stack, while the dbg.declare can only describe
// the stack slot (and at a lexical-scope granularity). Later
// passes will attempt to elide the stack slot.
- if (!AI || isArray(AI))
+ if (!AI || isArray(AI) || isStructure(AI))
continue;
// A volatile load/store means that the alloca can't be elided anyway.
@@ -1591,15 +1577,10 @@ static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress,
DIExpr->getElement(0) != dwarf::DW_OP_deref)
return;
- // Insert the offset immediately after the first deref.
+ // Insert the offset before the first deref.
// We could just change the offset argument of dbg.value, but it's unsigned...
- if (Offset) {
- SmallVector<uint64_t, 4> Ops;
- Ops.push_back(dwarf::DW_OP_deref);
- DIExpression::appendOffset(Ops, Offset);
- Ops.append(DIExpr->elements_begin() + 1, DIExpr->elements_end());
- DIExpr = Builder.createExpression(Ops);
- }
+ if (Offset)
+ DIExpr = DIExpression::prepend(DIExpr, 0, Offset);
Builder.insertDbgValueIntrinsic(NewAddress, DIVar, DIExpr, Loc, DVI);
DVI->eraseFromParent();
@@ -1957,18 +1938,24 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap,
return NumInstrsRemoved;
}
-/// changeToCall - Convert the specified invoke into a normal call.
-static void changeToCall(InvokeInst *II, DomTreeUpdater *DTU = nullptr) {
- SmallVector<Value*, 8> Args(II->arg_begin(), II->arg_end());
+CallInst *llvm::createCallMatchingInvoke(InvokeInst *II) {
+ SmallVector<Value *, 8> Args(II->arg_begin(), II->arg_end());
SmallVector<OperandBundleDef, 1> OpBundles;
II->getOperandBundlesAsDefs(OpBundles);
- CallInst *NewCall = CallInst::Create(
- II->getFunctionType(), II->getCalledValue(), Args, OpBundles, "", II);
- NewCall->takeName(II);
+ CallInst *NewCall = CallInst::Create(II->getFunctionType(),
+ II->getCalledValue(), Args, OpBundles);
NewCall->setCallingConv(II->getCallingConv());
NewCall->setAttributes(II->getAttributes());
NewCall->setDebugLoc(II->getDebugLoc());
NewCall->copyMetadata(*II);
+ return NewCall;
+}
+
+/// changeToCall - Convert the specified invoke into a normal call.
+void llvm::changeToCall(InvokeInst *II, DomTreeUpdater *DTU) {
+ CallInst *NewCall = createCallMatchingInvoke(II);
+ NewCall->takeName(II);
+ NewCall->insertBefore(II);
II->replaceAllUsesWith(NewCall);
// Follow the call by a branch to the normal destination.
@@ -2223,12 +2210,10 @@ void llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) {
/// removeUnreachableBlocks - Remove blocks that are not reachable, even
/// if they are in a dead cycle. Return true if a change was made, false
-/// otherwise. If `LVI` is passed, this function preserves LazyValueInfo
-/// after modifying the CFG.
-bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI,
- DomTreeUpdater *DTU,
+/// otherwise.
+bool llvm::removeUnreachableBlocks(Function &F, DomTreeUpdater *DTU,
MemorySSAUpdater *MSSAU) {
- SmallPtrSet<BasicBlock*, 16> Reachable;
+ SmallPtrSet<BasicBlock *, 16> Reachable;
bool Changed = markAliveBlocks(F, Reachable, DTU);
// If there are unreachable blocks in the CFG...
@@ -2236,21 +2221,21 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI,
return Changed;
assert(Reachable.size() < F.size());
- NumRemoved += F.size()-Reachable.size();
+ NumRemoved += F.size() - Reachable.size();
SmallSetVector<BasicBlock *, 8> DeadBlockSet;
- for (Function::iterator I = ++F.begin(), E = F.end(); I != E; ++I) {
- auto *BB = &*I;
- if (Reachable.count(BB))
+ for (BasicBlock &BB : F) {
+ // Skip reachable basic blocks.
+ if (Reachable.count(&BB))
continue;
- DeadBlockSet.insert(BB);
+ DeadBlockSet.insert(&BB);
}
if (MSSAU)
MSSAU->removeBlocks(DeadBlockSet);
// Loop over all of the basic blocks that are not reachable, dropping all of
- // their internal references. Update DTU and LVI if available.
+ // their internal references. Update DTU if available.
std::vector<DominatorTree::UpdateType> Updates;
for (auto *BB : DeadBlockSet) {
for (BasicBlock *Successor : successors(BB)) {
@@ -2259,26 +2244,18 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI,
if (DTU)
Updates.push_back({DominatorTree::Delete, BB, Successor});
}
- if (LVI)
- LVI->eraseBlock(BB);
BB->dropAllReferences();
- }
- for (Function::iterator I = ++F.begin(); I != F.end();) {
- auto *BB = &*I;
- if (Reachable.count(BB)) {
- ++I;
- continue;
- }
if (DTU) {
- // Remove the terminator of BB to clear the successor list of BB.
- if (BB->getTerminator())
- BB->getInstList().pop_back();
+ Instruction *TI = BB->getTerminator();
+ assert(TI && "Basic block should have a terminator");
+ // Terminators like invoke can have users. We have to replace their users,
+ // before removing them.
+ if (!TI->use_empty())
+ TI->replaceAllUsesWith(UndefValue::get(TI->getType()));
+ TI->eraseFromParent();
new UnreachableInst(BB->getContext(), BB);
assert(succ_empty(BB) && "The successor list of BB isn't empty before "
"applying corresponding DTU updates.");
- ++I;
- } else {
- I = F.getBasicBlockList().erase(I);
}
}
@@ -2294,7 +2271,11 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI,
}
if (!Deleted)
return false;
+ } else {
+ for (auto *BB : DeadBlockSet)
+ BB->eraseFromParent();
}
+
return true;
}
@@ -2363,6 +2344,9 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
K->setMetadata(Kind,
MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
break;
+ case LLVMContext::MD_preserve_access_index:
+ // Preserve !preserve.access.index in K.
+ break;
}
}
// Set !invariant.group from J if J has it. If both instructions have it
@@ -2385,10 +2369,61 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J,
LLVMContext::MD_invariant_group, LLVMContext::MD_align,
LLVMContext::MD_dereferenceable,
LLVMContext::MD_dereferenceable_or_null,
- LLVMContext::MD_access_group};
+ LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index};
combineMetadata(K, J, KnownIDs, KDominatesJ);
}
+void llvm::copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source) {
+ SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
+ Source.getAllMetadata(MD);
+ MDBuilder MDB(Dest.getContext());
+ Type *NewType = Dest.getType();
+ const DataLayout &DL = Source.getModule()->getDataLayout();
+ for (const auto &MDPair : MD) {
+ unsigned ID = MDPair.first;
+ MDNode *N = MDPair.second;
+ // Note, essentially every kind of metadata should be preserved here! This
+ // routine is supposed to clone a load instruction changing *only its type*.
+ // The only metadata it makes sense to drop is metadata which is invalidated
+ // when the pointer type changes. This should essentially never be the case
+ // in LLVM, but we explicitly switch over only known metadata to be
+ // conservatively correct. If you are adding metadata to LLVM which pertains
+ // to loads, you almost certainly want to add it here.
+ switch (ID) {
+ case LLVMContext::MD_dbg:
+ case LLVMContext::MD_tbaa:
+ case LLVMContext::MD_prof:
+ case LLVMContext::MD_fpmath:
+ case LLVMContext::MD_tbaa_struct:
+ case LLVMContext::MD_invariant_load:
+ case LLVMContext::MD_alias_scope:
+ case LLVMContext::MD_noalias:
+ case LLVMContext::MD_nontemporal:
+ case LLVMContext::MD_mem_parallel_loop_access:
+ case LLVMContext::MD_access_group:
+ // All of these directly apply.
+ Dest.setMetadata(ID, N);
+ break;
+
+ case LLVMContext::MD_nonnull:
+ copyNonnullMetadata(Source, N, Dest);
+ break;
+
+ case LLVMContext::MD_align:
+ case LLVMContext::MD_dereferenceable:
+ case LLVMContext::MD_dereferenceable_or_null:
+ // These only directly apply if the new type is also a pointer.
+ if (NewType->isPointerTy())
+ Dest.setMetadata(ID, N);
+ break;
+
+ case LLVMContext::MD_range:
+ copyRangeMetadata(DL, Source, N, Dest);
+ break;
+ }
+ }
+}
+
void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
auto *ReplInst = dyn_cast<Instruction>(Repl);
if (!ReplInst)
@@ -2417,7 +2452,7 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
LLVMContext::MD_noalias, LLVMContext::MD_range,
LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load,
LLVMContext::MD_invariant_group, LLVMContext::MD_nonnull,
- LLVMContext::MD_access_group};
+ LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index};
combineMetadata(ReplInst, I, KnownIDs, false);
}
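The new copyMetadataForLoad() above is meant for transforms that re-issue a load at a
different type and want to keep every kind of metadata that is still meaningful. Below is
a hedged sketch of that usage; it assumes the helper is declared in
llvm/Transforms/Utils/Local.h alongside this definition, and the rewrite itself
(bitcasting the pointer and re-loading) is only illustrative.

#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

// Sketch: clone OldLoad as a load of NewTy and carry over the metadata that
// remains valid for the new type (e.g. !nonnull and !range are rewritten or
// dropped by the helper as needed).
static LoadInst *rewriteLoadToNewType(LoadInst &OldLoad, Type *NewTy) {
  IRBuilder<> B(&OldLoad);
  Value *Ptr = B.CreateBitCast(
      OldLoad.getPointerOperand(),
      NewTy->getPointerTo(OldLoad.getPointerAddressSpace()));
  LoadInst *NewLoad = B.CreateAlignedLoad(NewTy, Ptr, OldLoad.getAlignment(),
                                          OldLoad.getName() + ".cast");
  copyMetadataForLoad(*NewLoad, OldLoad);
  return NewLoad;
}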
diff --git a/lib/Transforms/Utils/LoopRotationUtils.cpp b/lib/Transforms/Utils/LoopRotationUtils.cpp
index 37389a695b45..889ea5ca9970 100644
--- a/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -615,30 +615,9 @@ bool LoopRotate::simplifyLoopLatch(Loop *L) {
LLVM_DEBUG(dbgs() << "Folding loop latch " << Latch->getName() << " into "
<< LastExit->getName() << "\n");
- // Hoist the instructions from Latch into LastExit.
- Instruction *FirstLatchInst = &*(Latch->begin());
- LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(),
- Latch->begin(), Jmp->getIterator());
-
- // Update MemorySSA
- if (MSSAU)
- MSSAU->moveAllAfterMergeBlocks(Latch, LastExit, FirstLatchInst);
-
- unsigned FallThruPath = BI->getSuccessor(0) == Latch ? 0 : 1;
- BasicBlock *Header = Jmp->getSuccessor(0);
- assert(Header == L->getHeader() && "expected a backward branch");
-
- // Remove Latch from the CFG so that LastExit becomes the new Latch.
- BI->setSuccessor(FallThruPath, Header);
- Latch->replaceSuccessorsPhiUsesWith(LastExit);
- Jmp->eraseFromParent();
-
- // Nuke the Latch block.
- assert(Latch->empty() && "unable to evacuate Latch");
- LI->removeBlock(Latch);
- if (DT)
- DT->eraseNode(Latch);
- Latch->eraseFromParent();
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+ MergeBlockIntoPredecessor(Latch, &DTU, LI, MSSAU, nullptr,
+ /*PredecessorWithTwoSuccessors=*/true);
if (MSSAU && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp
index 7e6da02d5707..d0f89dc54bfb 100644
--- a/lib/Transforms/Utils/LoopSimplify.cpp
+++ b/lib/Transforms/Utils/LoopSimplify.cpp
@@ -808,7 +808,7 @@ bool LoopSimplify::runOnFunction(Function &F) {
auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
if (MSSAAnalysis) {
MSSA = &MSSAAnalysis->getMSSA();
- MSSAU = make_unique<MemorySSAUpdater>(MSSA);
+ MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
}
}
@@ -835,12 +835,19 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F,
DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
ScalarEvolution *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F);
AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
+ auto *MSSAAnalysis = AM.getCachedResult<MemorySSAAnalysis>(F);
+ std::unique_ptr<MemorySSAUpdater> MSSAU;
+ if (MSSAAnalysis) {
+ auto *MSSA = &MSSAAnalysis->getMSSA();
+ MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+ }
+
// Note that we don't preserve LCSSA in the new PM; if you need it, run LCSSA
- // after simplifying the loops. MemorySSA is not preserved either.
+ // after simplifying the loops. MemorySSA is preserved if it exists.
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
Changed |=
- simplifyLoop(*I, DT, LI, SE, AC, nullptr, /*PreserveLCSSA*/ false);
+ simplifyLoop(*I, DT, LI, SE, AC, MSSAU.get(), /*PreserveLCSSA*/ false);
if (!Changed)
return PreservedAnalyses::all();
@@ -853,6 +860,8 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F,
PA.preserve<SCEVAA>();
PA.preserve<ScalarEvolutionAnalysis>();
PA.preserve<DependenceAnalysis>();
+ if (MSSAAnalysis)
+ PA.preserve<MemorySSAAnalysis>();
// BPI maps conditional terminators to probabilities; LoopSimplify can insert
// blocks, but it does so only by splitting existing blocks and edges. This
// results in the interesting property that all new terminators inserted are
diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp
index e39ade523714..a7590fc32545 100644
--- a/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/lib/Transforms/Utils/LoopUnroll.cpp
@@ -711,7 +711,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
auto setDest = [LoopExit, ContinueOnTrue](BasicBlock *Src, BasicBlock *Dest,
ArrayRef<BasicBlock *> NextBlocks,
- BasicBlock *CurrentHeader,
+ BasicBlock *BlockInLoop,
bool NeedConditional) {
auto *Term = cast<BranchInst>(Src->getTerminator());
if (NeedConditional) {
@@ -723,7 +723,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
if (Dest != LoopExit) {
BasicBlock *BB = Src;
for (BasicBlock *Succ : successors(BB)) {
- if (Succ == CurrentHeader)
+ // Preserve the incoming value from BB if we are jumping to the block
+ // in the current loop.
+ if (Succ == BlockInLoop)
continue;
for (PHINode &Phi : Succ->phis())
Phi.removeIncomingValue(BB, false);
@@ -794,7 +796,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
// unconditional branch for some iterations.
NeedConditional = false;
- setDest(Headers[i], Dest, Headers, Headers[i], NeedConditional);
+ setDest(Headers[i], Dest, Headers, HeaderSucc[i], NeedConditional);
}
// Set up latches to branch to the new header in the unrolled iterations or
@@ -868,7 +870,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
assert(!DT || !UnrollVerifyDomtree ||
DT->verify(DominatorTree::VerificationLevel::Fast));
- DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
// Merge adjacent basic blocks, if possible.
for (BasicBlock *Latch : Latches) {
BranchInst *Term = dyn_cast<BranchInst>(Latch->getTerminator());
@@ -888,6 +890,8 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
}
}
}
+ // Apply updates to the DomTree.
+ DT = &DTU.getDomTree();
// At this point, the code is well formed. We now simplify the unrolled loop,
// doing constant propagation and dead code elimination as we go.
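Both this change and the LoopUnrollAndJam change below switch from an eager DomTreeUpdater
to a lazy one: CFG edits queue their updates and the tree is only brought up to date when
getDomTree() is called. A small sketch of that pattern follows; the function and variable
names are illustrative, and the include path for DomTreeUpdater has moved between llvm/IR
and llvm/Analysis across releases, so adjust as needed.

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/IR/Dominators.h"
using namespace llvm;

// Sketch: batch dominator-tree updates while edges are being removed, then
// flush them once at the end instead of updating after every change.
static DominatorTree &
deleteEdgesLazily(DominatorTree &DT,
                  ArrayRef<std::pair<BasicBlock *, BasicBlock *>> DeadEdges) {
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  SmallVector<DominatorTree::UpdateType, 8> Updates;
  for (const auto &E : DeadEdges)
    Updates.push_back({DominatorTree::Delete, E.first, E.second});
  // The permissive form tolerates duplicate and no-op updates.
  DTU.applyUpdatesPermissive(Updates);
  // Pending updates are applied here; the returned tree is valid again.
  return DTU.getDomTree();
}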
diff --git a/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/lib/Transforms/Utils/LoopUnrollAndJam.cpp
index ff49d83f25c5..bf2e87b0d49f 100644
--- a/lib/Transforms/Utils/LoopUnrollAndJam.cpp
+++ b/lib/Transforms/Utils/LoopUnrollAndJam.cpp
@@ -517,6 +517,7 @@ LoopUnrollResult llvm::UnrollAndJamLoop(
movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
}
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
// Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
// new ones required.
if (Count != 1) {
@@ -530,7 +531,7 @@ LoopUnrollResult llvm::UnrollAndJamLoop(
ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
SubLoopBlocksLast.back(), AftBlocksFirst[0]);
- DT->applyUpdates(DTUpdates);
+ DTU.applyUpdatesPermissive(DTUpdates);
}
// Merge adjacent basic blocks, if possible.
@@ -538,7 +539,6 @@ LoopUnrollResult llvm::UnrollAndJamLoop(
MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
- DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
while (!MergeBlocks.empty()) {
BasicBlock *BB = *MergeBlocks.begin();
BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
@@ -555,6 +555,8 @@ LoopUnrollResult llvm::UnrollAndJamLoop(
} else
MergeBlocks.erase(BB);
}
+ // Apply updates to the DomTree.
+ DT = &DTU.getDomTree();
// At this point, the code is well formed. We now do a quick sweep over the
// inserted code, doing constant propagation and dead code elimination as we
diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp
index 005306cf1898..58e42074f963 100644
--- a/lib/Transforms/Utils/LoopUnrollPeel.cpp
+++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp
@@ -62,9 +62,11 @@ static cl::opt<unsigned> UnrollForcePeelCount(
cl::desc("Force a peel count regardless of profiling information."));
static cl::opt<bool> UnrollPeelMultiDeoptExit(
- "unroll-peel-multi-deopt-exit", cl::init(false), cl::Hidden,
+ "unroll-peel-multi-deopt-exit", cl::init(true), cl::Hidden,
cl::desc("Allow peeling of loops with multiple deopt exits."));
+static const char *PeeledCountMetaData = "llvm.loop.peeled.count";
+
// Designates that a Phi is estimated to become invariant after an "infinite"
// number of loop iterations (i.e. only may become an invariant if the loop is
// fully unrolled).
@@ -275,6 +277,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
LLVM_DEBUG(dbgs() << "Force-peeling first " << UnrollForcePeelCount
<< " iterations.\n");
UP.PeelCount = UnrollForcePeelCount;
+ UP.PeelProfiledIterations = true;
return;
}
@@ -282,6 +285,13 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
if (!UP.AllowPeeling)
return;
+ unsigned AlreadyPeeled = 0;
+ if (auto Peeled = getOptionalIntLoopAttribute(L, PeeledCountMetaData))
+ AlreadyPeeled = *Peeled;
+ // Stop if we already peeled off the maximum number of iterations.
+ if (AlreadyPeeled >= UnrollPeelMaxCount)
+ return;
+
// Here we try to get rid of Phis which become invariants after 1, 2, ..., N
// iterations of the loop. For this we compute the number for iterations after
// which every Phi is guaranteed to become an invariant, and try to peel the
@@ -317,11 +327,14 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount);
// Consider max peel count limitation.
assert(DesiredPeelCount > 0 && "Wrong loop size estimation?");
- LLVM_DEBUG(dbgs() << "Peel " << DesiredPeelCount
- << " iteration(s) to turn"
- << " some Phis into invariants.\n");
- UP.PeelCount = DesiredPeelCount;
- return;
+ if (DesiredPeelCount + AlreadyPeeled <= UnrollPeelMaxCount) {
+ LLVM_DEBUG(dbgs() << "Peel " << DesiredPeelCount
+ << " iteration(s) to turn"
+ << " some Phis into invariants.\n");
+ UP.PeelCount = DesiredPeelCount;
+ UP.PeelProfiledIterations = false;
+ return;
+ }
}
}
@@ -330,6 +343,9 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
if (TripCount)
return;
+ // Do not apply profile-based peeling if it is disabled.
+ if (!UP.PeelProfiledIterations)
+ return;
// If we don't know the trip count, but have reason to believe the average
// trip count is low, peeling should be beneficial, since we will usually
// hit the peeled section.
@@ -344,7 +360,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
<< "\n");
if (*PeelCount) {
- if ((*PeelCount <= UnrollPeelMaxCount) &&
+ if ((*PeelCount + AlreadyPeeled <= UnrollPeelMaxCount) &&
(LoopSize * (*PeelCount + 1) <= UP.Threshold)) {
LLVM_DEBUG(dbgs() << "Peeling first " << *PeelCount
<< " iterations.\n");
@@ -352,6 +368,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
return;
}
LLVM_DEBUG(dbgs() << "Requested peel count: " << *PeelCount << "\n");
+ LLVM_DEBUG(dbgs() << "Already peel count: " << AlreadyPeeled << "\n");
LLVM_DEBUG(dbgs() << "Max peel count: " << UnrollPeelMaxCount << "\n");
LLVM_DEBUG(dbgs() << "Peel cost: " << LoopSize * (*PeelCount + 1)
<< "\n");
@@ -364,88 +381,77 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
/// iteration.
/// This sets the branch weights for the latch of the recently peeled off loop
/// iteration correctly.
-/// Our goal is to make sure that:
-/// a) The total weight of all the copies of the loop body is preserved.
-/// b) The total weight of the loop exit is preserved.
-/// c) The body weight is reasonably distributed between the peeled iterations.
+/// Let F be the weight of the edge from the latch to the header.
+/// Let E be the weight of the edge from the latch to the exit.
+/// Then F/(F+E) is the probability of taking the backedge and E/(F+E) is the
+/// probability of exiting, so the estimated trip count is TC = F / E.
+/// For the I-th peeled-off iteration (counting from 0) we set the weights of
+/// the peeled latch to (TC - I, 1). This gives a reasonable distribution: the
+/// probability of exiting grows with each peeled iteration, while the
+/// estimated trip count of the remaining loop decreases by I.
+/// To avoid division rounding we simply multiply both parts of the weights by
+/// E and use (F - I * E, E).
///
/// \param Header The copy of the header block that belongs to next iteration.
/// \param LatchBR The copy of the latch branch that belongs to this iteration.
-/// \param IterNumber The serial number of the iteration that was just
-/// peeled off.
-/// \param AvgIters The average number of iterations we expect the loop to have.
-/// \param[in,out] PeeledHeaderWeight The total number of dynamic loop
-/// iterations that are unaccounted for. As an input, it represents the number
-/// of times we expect to enter the header of the iteration currently being
-/// peeled off. The output is the number of times we expect to enter the
-/// header of the next iteration.
+/// \param[in,out] FallThroughWeight The weight of the edge from latch to
+/// header before peeling (in) and after peeling off one iteration (out).
static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
- unsigned IterNumber, unsigned AvgIters,
- uint64_t &PeeledHeaderWeight) {
- if (!PeeledHeaderWeight)
+ uint64_t ExitWeight,
+ uint64_t &FallThroughWeight) {
+ // A FallThroughWeight of 0 means that there are no branch weights on the
+ // original latch block or that the estimated trip count is zero.
+ if (!FallThroughWeight)
return;
- // FIXME: Pick a more realistic distribution.
- // Currently the proportion of weight we assign to the fall-through
- // side of the branch drops linearly with the iteration number, and we use
- // a 0.9 fudge factor to make the drop-off less sharp...
- uint64_t FallThruWeight =
- PeeledHeaderWeight * ((float)(AvgIters - IterNumber) / AvgIters * 0.9);
- uint64_t ExitWeight = PeeledHeaderWeight - FallThruWeight;
- PeeledHeaderWeight -= ExitWeight;
unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1);
MDBuilder MDB(LatchBR->getContext());
MDNode *WeightNode =
- HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThruWeight)
- : MDB.createBranchWeights(FallThruWeight, ExitWeight);
+ HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
+ : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
+ FallThroughWeight =
+ FallThroughWeight > ExitWeight ? FallThroughWeight - ExitWeight : 1;
}
/// Initialize the weights.
///
/// \param Header The header block.
/// \param LatchBR The latch branch.
-/// \param AvgIters The average number of iterations we expect the loop to have.
-/// \param[out] ExitWeight The # of times the edge from Latch to Exit is taken.
-/// \param[out] CurHeaderWeight The # of times the header is executed.
+/// \param[out] ExitWeight The weight of the edge from Latch to Exit.
+/// \param[out] FallThroughWeight The weight of the edge from Latch to Header.
static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
- unsigned AvgIters, uint64_t &ExitWeight,
- uint64_t &CurHeaderWeight) {
+ uint64_t &ExitWeight,
+ uint64_t &FallThroughWeight) {
uint64_t TrueWeight, FalseWeight;
if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
return;
unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;
- // The # of times the loop body executes is the sum of the exit block
- // is taken and the # of times the backedges are taken.
- CurHeaderWeight = TrueWeight + FalseWeight;
+ FallThroughWeight = HeaderIdx ? FalseWeight : TrueWeight;
}
/// Update the weights of original Latch block after peeling off all iterations.
///
/// \param Header The header block.
/// \param LatchBR The latch branch.
-/// \param ExitWeight The weight of the edge from Latch to Exit block.
-/// \param CurHeaderWeight The # of time the header is executed.
+/// \param ExitWeight The weight of the edge from Latch to Exit.
+/// \param FallThroughWeight The weight of the edge from Latch to Header.
static void fixupBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
- uint64_t ExitWeight, uint64_t CurHeaderWeight) {
- // Adjust the branch weights on the loop exit.
- if (!ExitWeight)
+ uint64_t ExitWeight,
+ uint64_t FallThroughWeight) {
+ // A FallThroughWeight of 0 means that there are no branch weights on the
+ // original latch block or that the estimated trip count is zero.
+ if (!FallThroughWeight)
return;
- // The backedge count is the difference of current header weight and
- // current loop exit weight. If the current header weight is smaller than
- // the current loop exit weight, we mark the loop backedge weight as 1.
- uint64_t BackEdgeWeight = 0;
- if (ExitWeight < CurHeaderWeight)
- BackEdgeWeight = CurHeaderWeight - ExitWeight;
- else
- BackEdgeWeight = 1;
+ // Set the branch weights on the loop exit.
MDBuilder MDB(LatchBR->getContext());
unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
MDNode *WeightNode =
- HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight)
- : MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
+ HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
+ : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
}
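To make the scheme above concrete, a purely illustrative example: suppose the original
latch carries !prof weights of F = 90 toward the header and E = 10 toward the exit, i.e.
an estimated trip count of TC = F / E = 9, and three iterations are peeled.

  initBranchWeights:    ExitWeight = 10, FallThroughWeight = 90
  peeled iteration 0:   latch copy gets weights (90, 10), FallThroughWeight -> 80
  peeled iteration 1:   latch copy gets weights (80, 10), FallThroughWeight -> 70
  peeled iteration 2:   latch copy gets weights (70, 10), FallThroughWeight -> 60
  fixupBranchWeights:   remaining latch gets (60, 10), estimated trip count 6

If more iterations were peeled than the estimated trip count, updateBranchWeights clamps
FallThroughWeight to 1 rather than letting it underflow.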
@@ -586,11 +592,30 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
DenseMap<BasicBlock *, BasicBlock *> ExitIDom;
if (DT) {
+ // We'd like to determine the idom of each exit block after peeling one
+ // iteration.
+ // Let Exit be an exit block.
+ // Let ExitingSet be the set of predecessors of Exit; these are the exiting
+ // blocks.
+ // Let Latch' and ExitingSet' be the copies created by peeling.
+ // We'd like to find idom'(Exit), the idom of Exit after peeling.
+ // It is evident that idom'(Exit) will be the nearest common dominator
+ // of ExitingSet and ExitingSet'.
+ // idom(Exit) is the nearest common dominator of ExitingSet.
+ // idom(Exit)' is the nearest common dominator of ExitingSet'.
+ // Taking into account that we have a single Latch, Latch' will dominate
+ // Header and idom(Exit).
+ // So idom'(Exit) is the nearest common dominator of idom(Exit)' and Latch'.
+ // All these basic blocks are in the same loop, so what we find is
+ // (the nearest common dominator of idom(Exit) and Latch)'.
+ // In the loop below we remember the nearest common dominator of idom(Exit)
+ // and Latch so that we can update the idom of Exit later.
assert(L->hasDedicatedExits() && "No dedicated exits?");
for (auto Edge : ExitEdges) {
if (ExitIDom.count(Edge.second))
continue;
- BasicBlock *BB = DT->getNode(Edge.second)->getIDom()->getBlock();
+ BasicBlock *BB = DT->findNearestCommonDominator(
+ DT->getNode(Edge.second)->getIDom()->getBlock(), Latch);
assert(L->contains(BB) && "IDom is not in a loop");
ExitIDom[Edge.second] = BB;
}
@@ -659,23 +684,14 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
// newly created branches.
BranchInst *LatchBR =
cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator());
- uint64_t ExitWeight = 0, CurHeaderWeight = 0;
- initBranchWeights(Header, LatchBR, PeelCount, ExitWeight, CurHeaderWeight);
+ uint64_t ExitWeight = 0, FallThroughWeight = 0;
+ initBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight);
// For each peeled-off iteration, make a copy of the loop.
for (unsigned Iter = 0; Iter < PeelCount; ++Iter) {
SmallVector<BasicBlock *, 8> NewBlocks;
ValueToValueMapTy VMap;
- // Subtract the exit weight from the current header weight -- the exit
- // weight is exactly the weight of the previous iteration's header.
- // FIXME: due to the way the distribution is constructed, we need a
- // guard here to make sure we don't end up with non-positive weights.
- if (ExitWeight < CurHeaderWeight)
- CurHeaderWeight -= ExitWeight;
- else
- CurHeaderWeight = 1;
-
cloneLoopBlocks(L, Iter, InsertTop, InsertBot, ExitEdges, NewBlocks,
LoopBlocks, VMap, LVMap, DT, LI);
@@ -697,8 +713,7 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
}
auto *LatchBRCopy = cast<BranchInst>(VMap[LatchBR]);
- updateBranchWeights(InsertBot, LatchBRCopy, Iter,
- PeelCount, ExitWeight);
+ updateBranchWeights(InsertBot, LatchBRCopy, ExitWeight, FallThroughWeight);
// Remove Loop metadata from the latch branch instruction
// because it is not the Loop's latch branch anymore.
LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr);
@@ -724,7 +739,13 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
}
- fixupBranchWeights(Header, LatchBR, ExitWeight, CurHeaderWeight);
+ fixupBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight);
+
+ // Update metadata with the count of peeled-off iterations.
+ unsigned AlreadyPeeled = 0;
+ if (auto Peeled = getOptionalIntLoopAttribute(L, PeeledCountMetaData))
+ AlreadyPeeled = *Peeled;
+ addStringMetadataToLoop(L, PeeledCountMetaData, AlreadyPeeled + PeelCount);
if (Loop *ParentLoop = L->getParentLoop())
L = ParentLoop;
diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp
index ec226e65f650..b4d7f35d2d9a 100644
--- a/lib/Transforms/Utils/LoopUtils.cpp
+++ b/lib/Transforms/Utils/LoopUtils.cpp
@@ -19,6 +19,7 @@
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/MustExecute.h"
#include "llvm/Analysis/ScalarEvolution.h"
@@ -45,6 +46,7 @@ using namespace llvm::PatternMatch;
#define DEBUG_TYPE "loop-utils"
static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced";
+static const char *LLVMLoopDisableLICM = "llvm.licm.disable";
bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI,
MemorySSAUpdater *MSSAU,
@@ -169,6 +171,8 @@ void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) {
AU.addPreserved<SCEVAAWrapperPass>();
AU.addRequired<ScalarEvolutionWrapperPass>();
AU.addPreserved<ScalarEvolutionWrapperPass>();
+ // FIXME: When all loop passes preserve MemorySSA, it can be required and
+ // preserved here instead of the individual handling in each pass.
}
/// Manually defined generic "LoopPass" dependency initialization. This is used
@@ -189,6 +193,54 @@ void llvm::initializeLoopPassPass(PassRegistry &Registry) {
INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+ INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+}
+
+/// Create an MDNode for the input string.
+static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
+ LLVMContext &Context = TheLoop->getHeader()->getContext();
+ Metadata *MDs[] = {
+ MDString::get(Context, Name),
+ ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))};
+ return MDNode::get(Context, MDs);
+}
+
+/// Set the input string into the loop metadata, keeping other values intact.
+/// If the string is already present in the loop metadata, update its value
+/// only if it differs.
+void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
+ unsigned V) {
+ SmallVector<Metadata *, 4> MDs(1);
+ // If the loop already has metadata, retain it.
+ MDNode *LoopID = TheLoop->getLoopID();
+ if (LoopID) {
+ for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
+ MDNode *Node = cast<MDNode>(LoopID->getOperand(i));
+ // If it is of the form key = value, try to parse it.
+ if (Node->getNumOperands() == 2) {
+ MDString *S = dyn_cast<MDString>(Node->getOperand(0));
+ if (S && S->getString().equals(StringMD)) {
+ ConstantInt *IntMD =
+ mdconst::extract_or_null<ConstantInt>(Node->getOperand(1));
+ if (IntMD && IntMD->getSExtValue() == V)
+ // It is already in place. Do nothing.
+ return;
+ // We need to update the value, so just skip it here; it will
+ // be added after copying the other existing nodes.
+ continue;
+ }
+ }
+ MDs.push_back(Node);
+ }
+ }
+ // Add new metadata.
+ MDs.push_back(createStringMetadata(TheLoop, StringMD, V));
+ // Replace current metadata node with new one.
+ LLVMContext &Context = TheLoop->getHeader()->getContext();
+ MDNode *NewLoopID = MDNode::get(Context, MDs);
+ // Set operand 0 to refer to the loop id itself.
+ NewLoopID->replaceOperandWith(0, NewLoopID);
+ TheLoop->setLoopID(NewLoopID);
}
/// Find string metadata for loop
@@ -332,6 +384,10 @@ bool llvm::hasDisableAllTransformsHint(const Loop *L) {
return getBooleanLoopAttribute(L, LLVMLoopDisableNonforced);
}
+bool llvm::hasDisableLICMTransformsHint(const Loop *L) {
+ return getBooleanLoopAttribute(L, LLVMLoopDisableLICM);
+}
+
TransformationMode llvm::hasUnrollTransformation(Loop *L) {
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable"))
return TM_SuppressedByUser;
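A short usage sketch of the new addStringMetadataToLoop() above, mirroring how the peeling
change earlier in this patch uses it; the literal key and value are illustrative, and the
header locations (LoopUtils.h / LoopInfo.h) are assumed from this tree.

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
using namespace llvm;

// Sketch: record that three iterations of L were peeled off, then read the
// value back the way computePeelCount() does earlier in this patch.
static unsigned notePeeledIterations(Loop *L) {
  // Either appends a ("llvm.loop.peeled.count", i32 3) pair to the loop ID or
  // updates the value in place, and rebuilds the self-referential !llvm.loop
  // node, conceptually:
  //   !0 = distinct !{!0, !1}
  //   !1 = !{!"llvm.loop.peeled.count", i32 3}
  addStringMetadataToLoop(L, "llvm.loop.peeled.count", /*V=*/3);

  unsigned AlreadyPeeled = 0;
  if (auto Peeled = getOptionalIntLoopAttribute(L, "llvm.loop.peeled.count"))
    AlreadyPeeled = *Peeled; // == 3 after the call above
  return AlreadyPeeled;
}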
diff --git a/lib/Transforms/Utils/LoopVersioning.cpp b/lib/Transforms/Utils/LoopVersioning.cpp
index a9a480a4b7f9..5d7759056c7d 100644
--- a/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/lib/Transforms/Utils/LoopVersioning.cpp
@@ -92,8 +92,8 @@ void LoopVersioning::versionLoop(
// Create empty preheader for the loop (and after cloning for the
// non-versioned loop).
BasicBlock *PH =
- SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI);
- PH->setName(VersionedLoop->getHeader()->getName() + ".ph");
+ SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
+ nullptr, VersionedLoop->getHeader()->getName() + ".ph");
// Clone the loop including the preheader.
//
diff --git a/lib/Transforms/Utils/MetaRenamer.cpp b/lib/Transforms/Utils/MetaRenamer.cpp
index c0b7edc547fd..60bb2775a194 100644
--- a/lib/Transforms/Utils/MetaRenamer.cpp
+++ b/lib/Transforms/Utils/MetaRenamer.cpp
@@ -121,15 +121,14 @@ namespace {
}
// Rename all functions
- const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
for (auto &F : M) {
StringRef Name = F.getName();
LibFunc Tmp;
// Leave library functions alone because their presence or absence could
// affect the behavior of other passes.
if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) ||
- TLI.getLibFunc(F, Tmp))
+ getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F).getLibFunc(
+ F, Tmp))
continue;
// Leave @main alone. The output of -metarenamer might be passed to
diff --git a/lib/Transforms/Utils/MisExpect.cpp b/lib/Transforms/Utils/MisExpect.cpp
new file mode 100644
index 000000000000..26d3402bd279
--- /dev/null
+++ b/lib/Transforms/Utils/MisExpect.cpp
@@ -0,0 +1,177 @@
+//===--- MisExpect.cpp - Check the use of llvm.expect with PGO data -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This contains code to emit warnings for potentially incorrect usage of the
+// llvm.expect intrinsic. This utility extracts the threshold values from
+// metadata associated with the instrumented Branch or Switch instruction. The
+// threshold values are then used to determine if a warning should be emitted.
+//
+// MisExpect metadata is generated when llvm.expect intrinsics are lowered;
+// see LowerExpectIntrinsic.cpp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/MisExpect.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <functional>
+#include <numeric>
+
+#define DEBUG_TYPE "misexpect"
+
+using namespace llvm;
+using namespace misexpect;
+
+namespace llvm {
+
+// Command line option to enable/disable the warning when profile data suggests
+// a mismatch with the use of the llvm.expect intrinsic
+static cl::opt<bool> PGOWarnMisExpect(
+ "pgo-warn-misexpect", cl::init(false), cl::Hidden,
+ cl::desc("Use this option to turn on/off "
+ "warnings about incorrect usage of llvm.expect intrinsics."));
+
+} // namespace llvm
+
+namespace {
+
+Instruction *getOprndOrInst(Instruction *I) {
+ assert(I != nullptr && "MisExpect target Instruction cannot be nullptr");
+ Instruction *Ret = nullptr;
+ if (auto *B = dyn_cast<BranchInst>(I)) {
+ Ret = dyn_cast<Instruction>(B->getCondition());
+ }
+ // TODO: Find a way to resolve condition location for switches
+ // Using the condition of the switch seems to often resolve to an earlier
+ // point in the program, i.e. the calculation of the switch condition, rather
+ // than the switch's location in the source code. Thus, we should use the
+ // instruction to get source code locations rather than the condition to
+ // improve diagnostic output, such as the caret. If the same problem exists
+ // for branch instructions, then we should remove this function and directly
+ // use the instruction.
+ //
+ // else if (auto S = dyn_cast<SwitchInst>(I)) {
+ // Ret = I;
+ //}
+ return Ret ? Ret : I;
+}
+
+void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
+ uint64_t ProfCount, uint64_t TotalCount) {
+ double PercentageCorrect = (double)ProfCount / TotalCount;
+ auto PerString =
+ formatv("{0:P} ({1} / {2})", PercentageCorrect, ProfCount, TotalCount);
+ auto RemStr = formatv(
+ "Potential performance regression from use of the llvm.expect intrinsic: "
+ "Annotation was correct on {0} of profiled executions.",
+ PerString);
+ Twine Msg(PerString);
+ Instruction *Cond = getOprndOrInst(I);
+ if (PGOWarnMisExpect)
+ Ctx.diagnose(DiagnosticInfoMisExpect(Cond, Msg));
+ OptimizationRemarkEmitter ORE(I->getParent()->getParent());
+ ORE.emit(OptimizationRemark(DEBUG_TYPE, "misexpect", Cond) << RemStr.str());
+}
+
+} // namespace
+
+namespace llvm {
+namespace misexpect {
+
+void verifyMisExpect(Instruction *I, const SmallVector<uint32_t, 4> &Weights,
+ LLVMContext &Ctx) {
+ if (auto *MisExpectData = I->getMetadata(LLVMContext::MD_misexpect)) {
+ auto *MisExpectDataName = dyn_cast<MDString>(MisExpectData->getOperand(0));
+ if (MisExpectDataName &&
+ MisExpectDataName->getString().equals("misexpect")) {
+ LLVM_DEBUG(llvm::dbgs() << "------------------\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "Function: " << I->getFunction()->getName() << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "Instruction: " << *I << ":\n");
+ LLVM_DEBUG(for (int Idx = 0, Size = Weights.size(); Idx < Size; ++Idx) {
+ llvm::dbgs() << "Weights[" << Idx << "] = " << Weights[Idx] << "\n";
+ });
+
+ // extract values from misexpect metadata
+ const auto *IndexCint =
+ mdconst::dyn_extract<ConstantInt>(MisExpectData->getOperand(1));
+ const auto *LikelyCInt =
+ mdconst::dyn_extract<ConstantInt>(MisExpectData->getOperand(2));
+ const auto *UnlikelyCInt =
+ mdconst::dyn_extract<ConstantInt>(MisExpectData->getOperand(3));
+
+ if (!IndexCint || !LikelyCInt || !UnlikelyCInt)
+ return;
+
+ const uint64_t Index = IndexCint->getZExtValue();
+ const uint64_t LikelyBranchWeight = LikelyCInt->getZExtValue();
+ const uint64_t UnlikelyBranchWeight = UnlikelyCInt->getZExtValue();
+ const uint64_t ProfileCount = Weights[Index];
+ const uint64_t CaseTotal = std::accumulate(
+ Weights.begin(), Weights.end(), (uint64_t)0, std::plus<uint64_t>());
+ const uint64_t NumUnlikelyTargets = Weights.size() - 1;
+
+ const uint64_t TotalBranchWeight =
+ LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets);
+
+ const llvm::BranchProbability LikelyThreshold(LikelyBranchWeight,
+ TotalBranchWeight);
+ uint64_t ScaledThreshold = LikelyThreshold.scale(CaseTotal);
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unlikely Targets: " << NumUnlikelyTargets << ":\n");
+ LLVM_DEBUG(llvm::dbgs() << "Profile Count: " << ProfileCount << ":\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "Scaled Threshold: " << ScaledThreshold << ":\n");
+ LLVM_DEBUG(llvm::dbgs() << "------------------\n");
+ if (ProfileCount < ScaledThreshold)
+ emitMisexpectDiagnostic(I, Ctx, ProfileCount, CaseTotal);
+ }
+ }
+}
+
+void checkFrontendInstrumentation(Instruction &I) {
+ if (auto *MD = I.getMetadata(LLVMContext::MD_prof)) {
+ unsigned NOps = MD->getNumOperands();
+
+ // Only emit misexpect diagnostics if at least 2 branch weights are present.
+ // Less than 2 branch weights means that the profiling metadata is:
+ // 1) incorrect/corrupted
+ // 2) not branch weight metadata
+ // 3) completely deterministic
+ // In these cases we should not emit any diagnostic related to misexpect.
+ if (NOps < 3)
+ return;
+
+ // Operand 0 is a string tag "branch_weights"
+ if (MDString *Tag = dyn_cast<MDString>(MD->getOperand(0))) {
+ if (Tag->getString().equals("branch_weights")) {
+ SmallVector<uint32_t, 4> RealWeights(NOps - 1);
+ for (unsigned i = 1; i < NOps; i++) {
+ ConstantInt *Value =
+ mdconst::dyn_extract<ConstantInt>(MD->getOperand(i));
+ RealWeights[i - 1] = Value->getZExtValue();
+ }
+ verifyMisExpect(&I, RealWeights, I.getContext());
+ }
+ }
+ }
+}
+
+} // namespace misexpect
+} // namespace llvm
+#undef DEBUG_TYPE
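
To make the threshold arithmetic in verifyMisExpect() concrete, here is a standalone sketch with hypothetical numbers. The 2000/1 weights mirror the usual llvm.expect lowering defaults but should be treated as assumptions, as should the profile counts.

#include "llvm/Support/BranchProbability.h"
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical branch-weight metadata left behind when llvm.expect was
  // lowered: a two-way branch with one unlikely target.
  const uint64_t LikelyBranchWeight = 2000, UnlikelyBranchWeight = 1;
  const uint64_t NumUnlikelyTargets = 1;
  // Hypothetical PGO counts for the same branch.
  const uint64_t CaseTotal = 1000000, ProfileCount = 900000;

  const uint64_t TotalBranchWeight =
      LikelyBranchWeight + UnlikelyBranchWeight * NumUnlikelyTargets;
  llvm::BranchProbability LikelyThreshold(LikelyBranchWeight,
                                          TotalBranchWeight);
  // Scale the 2000/2001 probability onto the observed total: ~999500 here.
  uint64_t ScaledThreshold = LikelyThreshold.scale(CaseTotal);
  if (ProfileCount < ScaledThreshold)
    std::printf("would diagnose: %llu < %llu\n",
                (unsigned long long)ProfileCount,
                (unsigned long long)ScaledThreshold);
  return 0;
}
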
diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp
index c84beceee191..1ef3757017a8 100644
--- a/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/lib/Transforms/Utils/ModuleUtils.cpp
@@ -73,7 +73,7 @@ static void appendToUsedList(Module &M, StringRef Name, ArrayRef<GlobalValue *>
SmallPtrSet<Constant *, 16> InitAsSet;
SmallVector<Constant *, 16> Init;
if (GV) {
- ConstantArray *CA = dyn_cast<ConstantArray>(GV->getInitializer());
+ auto *CA = cast<ConstantArray>(GV->getInitializer());
for (auto &Op : CA->operands()) {
Constant *C = cast_or_null<Constant>(Op);
if (InitAsSet.insert(C).second)
diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp
index bdf24d80bd17..44859eafb9c1 100644
--- a/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/lib/Transforms/Utils/PredicateInfo.cpp
@@ -125,8 +125,10 @@ static bool valueComesBefore(OrderedInstructions &OI, const Value *A,
// necessary to compare uses/defs in the same block. Doing so allows us to walk
// the minimum number of instructions necessary to compute our def/use ordering.
struct ValueDFS_Compare {
+ DominatorTree &DT;
OrderedInstructions &OI;
- ValueDFS_Compare(OrderedInstructions &OI) : OI(OI) {}
+ ValueDFS_Compare(DominatorTree &DT, OrderedInstructions &OI)
+ : DT(DT), OI(OI) {}
bool operator()(const ValueDFS &A, const ValueDFS &B) const {
if (&A == &B)
@@ -136,7 +138,9 @@ struct ValueDFS_Compare {
// comesbefore to see what the real ordering is, because they are in the
// same basic block.
- bool SameBlock = std::tie(A.DFSIn, A.DFSOut) == std::tie(B.DFSIn, B.DFSOut);
+ assert((A.DFSIn != B.DFSIn || A.DFSOut == B.DFSOut) &&
+ "Equal DFS-in numbers imply equal out numbers");
+ bool SameBlock = A.DFSIn == B.DFSIn;
// We want to put the def that will get used for a given set of phi uses,
// before those phi uses.
@@ -145,9 +149,11 @@ struct ValueDFS_Compare {
if (SameBlock && A.LocalNum == LN_Last && B.LocalNum == LN_Last)
return comparePHIRelated(A, B);
+ bool isADef = A.Def;
+ bool isBDef = B.Def;
if (!SameBlock || A.LocalNum != LN_Middle || B.LocalNum != LN_Middle)
- return std::tie(A.DFSIn, A.DFSOut, A.LocalNum, A.Def, A.U) <
- std::tie(B.DFSIn, B.DFSOut, B.LocalNum, B.Def, B.U);
+ return std::tie(A.DFSIn, A.LocalNum, isADef) <
+ std::tie(B.DFSIn, B.LocalNum, isBDef);
return localComesBefore(A, B);
}
@@ -164,10 +170,35 @@ struct ValueDFS_Compare {
// For two phi related values, return the ordering.
bool comparePHIRelated(const ValueDFS &A, const ValueDFS &B) const {
- auto &ABlockEdge = getBlockEdge(A);
- auto &BBlockEdge = getBlockEdge(B);
- // Now sort by block edge and then defs before uses.
- return std::tie(ABlockEdge, A.Def, A.U) < std::tie(BBlockEdge, B.Def, B.U);
+ BasicBlock *ASrc, *ADest, *BSrc, *BDest;
+ std::tie(ASrc, ADest) = getBlockEdge(A);
+ std::tie(BSrc, BDest) = getBlockEdge(B);
+
+#ifndef NDEBUG
+ // This function should only be used for values in the same BB, check that.
+ DomTreeNode *DomASrc = DT.getNode(ASrc);
+ DomTreeNode *DomBSrc = DT.getNode(BSrc);
+ assert(DomASrc->getDFSNumIn() == (unsigned)A.DFSIn &&
+ "DFS numbers for A should match the ones of the source block");
+ assert(DomBSrc->getDFSNumIn() == (unsigned)B.DFSIn &&
+ "DFS numbers for B should match the ones of the source block");
+ assert(A.DFSIn == B.DFSIn && "Values must be in the same block");
+#endif
+ (void)ASrc;
+ (void)BSrc;
+
+ // Use DFS numbers to compare destination blocks, to guarantee a
+ // deterministic order.
+ DomTreeNode *DomADest = DT.getNode(ADest);
+ DomTreeNode *DomBDest = DT.getNode(BDest);
+ unsigned AIn = DomADest->getDFSNumIn();
+ unsigned BIn = DomBDest->getDFSNumIn();
+ bool isADef = A.Def;
+ bool isBDef = B.Def;
+ assert((!A.Def || !A.U) && (!B.Def || !B.U) &&
+ "Def and U cannot be set at the same time");
+ // Now sort by edge destination and then defs before uses.
+ return std::tie(AIn, isADef) < std::tie(BIn, isBDef);
}
// Get the definition of an instruction that occurs in the middle of a block.
@@ -306,10 +337,11 @@ void collectCmpOps(CmpInst *Comparison, SmallVectorImpl<Value *> &CmpOperands) {
}
// Add Op, PB to the list of value infos for Op, and mark Op to be renamed.
-void PredicateInfo::addInfoFor(SmallPtrSetImpl<Value *> &OpsToRename, Value *Op,
+void PredicateInfo::addInfoFor(SmallVectorImpl<Value *> &OpsToRename, Value *Op,
PredicateBase *PB) {
- OpsToRename.insert(Op);
auto &OperandInfo = getOrCreateValueInfo(Op);
+ if (OperandInfo.Infos.empty())
+ OpsToRename.push_back(Op);
AllInfos.push_back(PB);
OperandInfo.Infos.push_back(PB);
}
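
The switch from SmallPtrSet to SmallVector in this hunk is about iteration-order determinism: pointer-set order can vary from run to run, while a vector preserves insertion order, and the OperandInfo.Infos.empty() test doubles as a "first time we see this operand" check. A minimal standalone sketch of the same pattern, with hypothetical names:

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <vector>

// Sketch: record each key once, in first-insertion order, by letting the
// emptiness of its per-key info vector stand in for a "seen" flag.
struct Info { std::vector<int> Entries; };

static void addInfoForSketch(llvm::DenseMap<int, Info> &Table,
                             llvm::SmallVectorImpl<int> &KeysInOrder,
                             int Key, int Payload) {
  Info &I = Table[Key];
  if (I.Entries.empty())
    KeysInOrder.push_back(Key); // first occurrence of Key
  I.Entries.push_back(Payload);
}
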
@@ -317,7 +349,7 @@ void PredicateInfo::addInfoFor(SmallPtrSetImpl<Value *> &OpsToRename, Value *Op,
// Process an assume instruction and place relevant operations we want to rename
// into OpsToRename.
void PredicateInfo::processAssume(IntrinsicInst *II, BasicBlock *AssumeBB,
- SmallPtrSetImpl<Value *> &OpsToRename) {
+ SmallVectorImpl<Value *> &OpsToRename) {
// See if we have a comparison we support
SmallVector<Value *, 8> CmpOperands;
SmallVector<Value *, 2> ConditionsToProcess;
@@ -357,7 +389,7 @@ void PredicateInfo::processAssume(IntrinsicInst *II, BasicBlock *AssumeBB,
// Process a block terminating branch, and place relevant operations to be
// renamed into OpsToRename.
void PredicateInfo::processBranch(BranchInst *BI, BasicBlock *BranchBB,
- SmallPtrSetImpl<Value *> &OpsToRename) {
+ SmallVectorImpl<Value *> &OpsToRename) {
BasicBlock *FirstBB = BI->getSuccessor(0);
BasicBlock *SecondBB = BI->getSuccessor(1);
SmallVector<BasicBlock *, 2> SuccsToProcess;
@@ -427,7 +459,7 @@ void PredicateInfo::processBranch(BranchInst *BI, BasicBlock *BranchBB,
// Process a block terminating switch, and place relevant operations to be
// renamed into OpsToRename.
void PredicateInfo::processSwitch(SwitchInst *SI, BasicBlock *BranchBB,
- SmallPtrSetImpl<Value *> &OpsToRename) {
+ SmallVectorImpl<Value *> &OpsToRename) {
Value *Op = SI->getCondition();
if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse())
return;
@@ -457,7 +489,7 @@ void PredicateInfo::buildPredicateInfo() {
DT.updateDFSNumbers();
// Collect operands to rename from all conditional branch terminators, as well
// as assume statements.
- SmallPtrSet<Value *, 8> OpsToRename;
+ SmallVector<Value *, 8> OpsToRename;
for (auto DTN : depth_first(DT.getRootNode())) {
BasicBlock *BranchBB = DTN->getBlock();
if (auto *BI = dyn_cast<BranchInst>(BranchBB->getTerminator())) {
@@ -524,7 +556,7 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter,
if (isa<PredicateWithEdge>(ValInfo)) {
IRBuilder<> B(getBranchTerminator(ValInfo));
Function *IF = getCopyDeclaration(F.getParent(), Op->getType());
- if (empty(IF->users()))
+ if (IF->users().empty())
CreatedDeclarations.insert(IF);
CallInst *PIC =
B.CreateCall(IF, Op, Op->getName() + "." + Twine(Counter++));
@@ -536,7 +568,7 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter,
"Should not have gotten here without it being an assume");
IRBuilder<> B(PAssume->AssumeInst);
Function *IF = getCopyDeclaration(F.getParent(), Op->getType());
- if (empty(IF->users()))
+ if (IF->users().empty())
CreatedDeclarations.insert(IF);
CallInst *PIC = B.CreateCall(IF, Op);
PredicateMap.insert({PIC, ValInfo});
@@ -565,14 +597,8 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter,
//
// TODO: Use this algorithm to perform fast single-variable renaming in
// promotememtoreg and memoryssa.
-void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) {
- // Sort OpsToRename since we are going to iterate it.
- SmallVector<Value *, 8> OpsToRename(OpSet.begin(), OpSet.end());
- auto Comparator = [&](const Value *A, const Value *B) {
- return valueComesBefore(OI, A, B);
- };
- llvm::sort(OpsToRename, Comparator);
- ValueDFS_Compare Compare(OI);
+void PredicateInfo::renameUses(SmallVectorImpl<Value *> &OpsToRename) {
+ ValueDFS_Compare Compare(DT, OI);
// Compute liveness, and rename in O(uses) per Op.
for (auto *Op : OpsToRename) {
LLVM_DEBUG(dbgs() << "Visiting " << *Op << "\n");
@@ -772,7 +798,7 @@ static void replaceCreatedSSACopys(PredicateInfo &PredInfo, Function &F) {
bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) {
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto PredInfo = make_unique<PredicateInfo>(F, DT, AC);
+ auto PredInfo = std::make_unique<PredicateInfo>(F, DT, AC);
PredInfo->print(dbgs());
if (VerifyPredicateInfo)
PredInfo->verifyPredicateInfo();
@@ -786,7 +812,7 @@ PreservedAnalyses PredicateInfoPrinterPass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
OS << "PredicateInfo for function: " << F.getName() << "\n";
- auto PredInfo = make_unique<PredicateInfo>(F, DT, AC);
+ auto PredInfo = std::make_unique<PredicateInfo>(F, DT, AC);
PredInfo->print(OS);
replaceCreatedSSACopys(*PredInfo, F);
@@ -845,7 +871,7 @@ PreservedAnalyses PredicateInfoVerifierPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
- make_unique<PredicateInfo>(F, DT, AC)->verifyPredicateInfo();
+ std::make_unique<PredicateInfo>(F, DT, AC)->verifyPredicateInfo();
return PreservedAnalyses::all();
}
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 11651d040dc0..3a5e3293ed4f 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -94,6 +94,12 @@ static cl::opt<unsigned> PHINodeFoldingThreshold(
cl::desc(
"Control the amount of phi node folding to perform (default = 2)"));
+static cl::opt<unsigned> TwoEntryPHINodeFoldingThreshold(
+ "two-entry-phi-node-folding-threshold", cl::Hidden, cl::init(4),
+ cl::desc("Control the maximal total instruction cost that we are willing "
+ "to speculatively execute to fold a 2-entry PHI node into a "
+ "select (default = 4)"));
+
static cl::opt<bool> DupRet(
"simplifycfg-dup-ret", cl::Hidden, cl::init(false),
cl::desc("Duplicate return instructions into unconditional branches"));
@@ -332,7 +338,7 @@ static unsigned ComputeSpeculationCost(const User *I,
/// CostRemaining, false is returned and CostRemaining is undefined.
static bool DominatesMergePoint(Value *V, BasicBlock *BB,
SmallPtrSetImpl<Instruction *> &AggressiveInsts,
- unsigned &CostRemaining,
+ int &BudgetRemaining,
const TargetTransformInfo &TTI,
unsigned Depth = 0) {
// It is possible to hit a zero-cost cycle (phi/gep instructions for example),
@@ -375,7 +381,7 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB,
if (!isSafeToSpeculativelyExecute(I))
return false;
- unsigned Cost = ComputeSpeculationCost(I, TTI);
+ BudgetRemaining -= ComputeSpeculationCost(I, TTI);
// Allow exactly one instruction to be speculated regardless of its cost
// (as long as it is safe to do so).
@@ -383,17 +389,14 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB,
// or other expensive operation. The speculation of an expensive instruction
// is expected to be undone in CodeGenPrepare if the speculation has not
// enabled further IR optimizations.
- if (Cost > CostRemaining &&
+ if (BudgetRemaining < 0 &&
(!SpeculateOneExpensiveInst || !AggressiveInsts.empty() || Depth > 0))
return false;
- // Avoid unsigned wrap.
- CostRemaining = (Cost > CostRemaining) ? 0 : CostRemaining - Cost;
-
// Okay, we can only really hoist these out if their operands do
// not take us over the cost threshold.
for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i)
- if (!DominatesMergePoint(*i, BB, AggressiveInsts, CostRemaining, TTI,
+ if (!DominatesMergePoint(*i, BB, AggressiveInsts, BudgetRemaining, TTI,
Depth + 1))
return false;
// Okay, it's safe to do this! Remember this instruction.
@@ -629,8 +632,7 @@ private:
/// vector.
/// One "Extra" case is allowed to differ from the other.
void gather(Value *V) {
- Instruction *I = dyn_cast<Instruction>(V);
- bool isEQ = (I->getOpcode() == Instruction::Or);
+ bool isEQ = (cast<Instruction>(V)->getOpcode() == Instruction::Or);
// Keep a stack (SmallVector for efficiency) for depth-first traversal
SmallVector<Value *, 8> DFT;
@@ -1313,7 +1315,8 @@ static bool HoistThenElseCodeToIf(BranchInst *BI,
LLVMContext::MD_dereferenceable,
LLVMContext::MD_dereferenceable_or_null,
LLVMContext::MD_mem_parallel_loop_access,
- LLVMContext::MD_access_group};
+ LLVMContext::MD_access_group,
+ LLVMContext::MD_preserve_access_index};
combineMetadata(I1, I2, KnownIDs, true);
// I1 and I2 are being combined into a single instruction. Its debug
@@ -1420,6 +1423,20 @@ HoistTerminator:
return true;
}
+// Return true if the given instruction is a lifetime start or end marker.
+static bool isLifeTimeMarker(const Instruction *I) {
+ if (auto II = dyn_cast<IntrinsicInst>(I)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::lifetime_start:
+ case Intrinsic::lifetime_end:
+ return true;
+ }
+ }
+ return false;
+}
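
A side note on this helper: trees of roughly this vintage already expose Instruction::isLifetimeStartOrEnd(), so an equivalent one-liner may be available; treat that availability as an assumption to verify against this exact revision.

#include "llvm/IR/Instruction.h"

// Equivalent check, assuming Instruction::isLifetimeStartOrEnd() exists in
// this revision (it matches lifetime.start and lifetime.end intrinsic calls).
static bool isLifeTimeMarkerAlt(const llvm::Instruction *I) {
  return I->isLifetimeStartOrEnd();
}
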
+
// All instructions in Insts belong to different blocks that all unconditionally
// branch to a common successor. Analyze each instruction and return true if it
// would be possible to sink them into their successor, creating one common
@@ -1474,20 +1491,25 @@ static bool canSinkInstructions(
return false;
}
- // Because SROA can't handle speculating stores of selects, try not
- // to sink loads or stores of allocas when we'd have to create a PHI for
- // the address operand. Also, because it is likely that loads or stores
- // of allocas will disappear when Mem2Reg/SROA is run, don't sink them.
+ // Because SROA can't handle speculating stores of selects, try not to sink
+ // loads, stores or lifetime markers of allocas when we'd have to create a
+ // PHI for the address operand. Also, because it is likely that loads or
+ // stores of allocas will disappear when Mem2Reg/SROA is run, don't sink
+ // them.
// This can cause code churn which can have unintended consequences down
// the line - see https://llvm.org/bugs/show_bug.cgi?id=30244.
// FIXME: This is a workaround for a deficiency in SROA - see
// https://llvm.org/bugs/show_bug.cgi?id=30188
if (isa<StoreInst>(I0) && any_of(Insts, [](const Instruction *I) {
- return isa<AllocaInst>(I->getOperand(1));
+ return isa<AllocaInst>(I->getOperand(1)->stripPointerCasts());
}))
return false;
if (isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) {
- return isa<AllocaInst>(I->getOperand(0));
+ return isa<AllocaInst>(I->getOperand(0)->stripPointerCasts());
+ }))
+ return false;
+ if (isLifeTimeMarker(I0) && any_of(Insts, [](const Instruction *I) {
+ return isa<AllocaInst>(I->getOperand(1)->stripPointerCasts());
}))
return false;
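
The stripPointerCasts() additions above matter when the alloca is only reachable through a bitcast or other no-op cast. A standalone sketch of the kind of check being tightened (helper name hypothetical):

#include "llvm/IR/Instructions.h"

using namespace llvm;

// Sketch: treat a store through a bitcast of an alloca the same as a store
// directly to the alloca, which is what the stripPointerCasts() calls above
// achieve for the sinking heuristic.
static bool storesToAllocaSketch(const StoreInst &SI) {
  const Value *Ptr = SI.getPointerOperand()->stripPointerCasts();
  return isa<AllocaInst>(Ptr);
}
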
@@ -1959,7 +1981,7 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
SmallVector<Instruction *, 4> SpeculatedDbgIntrinsics;
- unsigned SpeculationCost = 0;
+ unsigned SpeculatedInstructions = 0;
Value *SpeculatedStoreValue = nullptr;
StoreInst *SpeculatedStore = nullptr;
for (BasicBlock::iterator BBI = ThenBB->begin(),
@@ -1974,8 +1996,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
// Only speculatively execute a single instruction (not counting the
// terminator) for now.
- ++SpeculationCost;
- if (SpeculationCost > 1)
+ ++SpeculatedInstructions;
+ if (SpeculatedInstructions > 1)
return false;
// Don't hoist the instruction if it's unsafe or expensive.
@@ -2012,8 +2034,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
E = SinkCandidateUseCounts.end();
I != E; ++I)
if (I->first->hasNUses(I->second)) {
- ++SpeculationCost;
- if (SpeculationCost > 1)
+ ++SpeculatedInstructions;
+ if (SpeculatedInstructions > 1)
return false;
}
@@ -2053,8 +2075,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
// getting expanded into Instructions.
// FIXME: This doesn't account for how many operations are combined in the
// constant expression.
- ++SpeculationCost;
- if (SpeculationCost > 1)
+ ++SpeculatedInstructions;
+ if (SpeculatedInstructions > 1)
return false;
}
@@ -2302,10 +2324,8 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
// instructions. While we are at it, keep track of the instructions
// that need to be moved to the dominating block.
SmallPtrSet<Instruction *, 4> AggressiveInsts;
- unsigned MaxCostVal0 = PHINodeFoldingThreshold,
- MaxCostVal1 = PHINodeFoldingThreshold;
- MaxCostVal0 *= TargetTransformInfo::TCC_Basic;
- MaxCostVal1 *= TargetTransformInfo::TCC_Basic;
+ int BudgetRemaining =
+ TwoEntryPHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic;
for (BasicBlock::iterator II = BB->begin(); isa<PHINode>(II);) {
PHINode *PN = cast<PHINode>(II++);
@@ -2316,9 +2336,9 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
}
if (!DominatesMergePoint(PN->getIncomingValue(0), BB, AggressiveInsts,
- MaxCostVal0, TTI) ||
+ BudgetRemaining, TTI) ||
!DominatesMergePoint(PN->getIncomingValue(1), BB, AggressiveInsts,
- MaxCostVal1, TTI))
+ BudgetRemaining, TTI))
return false;
}
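
The MaxCostVal0/MaxCostVal1 pair becomes a single signed BudgetRemaining shared across both incoming values; going negative is now the stop signal, which replaces the old unsigned-wrap guard. A standalone sketch of the pattern (the helper is hypothetical, the TTI call mirrors the one used above):

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Sketch: spend one signed budget across candidate instructions and bail out
// as soon as it goes negative, instead of clamping an unsigned remainder.
static bool fitsBudgetSketch(const BasicBlock &BB,
                             const TargetTransformInfo &TTI, int Budget) {
  for (const Instruction &I : BB) {
    Budget -= TTI.getUserCost(&I);
    if (Budget < 0)
      return false;
  }
  return true;
}
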
@@ -2328,12 +2348,24 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
if (!PN)
return true;
- // Don't fold i1 branches on PHIs which contain binary operators. These can
- // often be turned into switches and other things.
+ // Return true if at least one of these is a 'not', and another is either
+ // a 'not' too, or a constant.
+ auto CanHoistNotFromBothValues = [](Value *V0, Value *V1) {
+ if (!match(V0, m_Not(m_Value())))
+ std::swap(V0, V1);
+ auto Invertible = m_CombineOr(m_Not(m_Value()), m_AnyIntegralConstant());
+ return match(V0, m_Not(m_Value())) && match(V1, Invertible);
+ };
+
+ // Don't fold i1 branches on PHIs which contain binary operators, unless one
+ // of the incoming values is an 'not' and another one is freely invertible.
+ // These can often be turned into switches and other things.
if (PN->getType()->isIntegerTy(1) &&
(isa<BinaryOperator>(PN->getIncomingValue(0)) ||
isa<BinaryOperator>(PN->getIncomingValue(1)) ||
- isa<BinaryOperator>(IfCond)))
+ isa<BinaryOperator>(IfCond)) &&
+ !CanHoistNotFromBothValues(PN->getIncomingValue(0),
+ PN->getIncomingValue(1)))
return false;
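
The CanHoistNotFromBothValues heuristic above whitelists PHIs where a 'not' can be hoisted out of both incoming values once the PHI becomes a select. As a quick standalone sanity check of the underlying identity (plain C++, not LLVM API):

#include <cassert>
#include <cstdint>

// Selecting between two inverted values equals inverting the select of the
// originals, so a 'not' can be hoisted out of both select arms.
int main() {
  for (uint32_t a : {0u, 1u, 0xdeadbeefu})
    for (uint32_t b : {0u, 7u, 0xffffffffu})
      for (bool c : {false, true})
        assert((c ? ~a : ~b) == ~(c ? a : b));
  return 0;
}
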
// If all PHI nodes are promotable, check to make sure that all instructions
@@ -2368,6 +2400,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
return false;
}
}
+ assert(DomBlock && "Failed to find root DomBlock");
LLVM_DEBUG(dbgs() << "FOUND IF CONDITION! " << *IfCond
<< " T: " << IfTrue->getName()
@@ -2913,42 +2946,8 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB,
BasicBlock *QTB, BasicBlock *QFB,
BasicBlock *PostBB, Value *Address,
bool InvertPCond, bool InvertQCond,
- const DataLayout &DL) {
- auto IsaBitcastOfPointerType = [](const Instruction &I) {
- return Operator::getOpcode(&I) == Instruction::BitCast &&
- I.getType()->isPointerTy();
- };
-
- // If we're not in aggressive mode, we only optimize if we have some
- // confidence that by optimizing we'll allow P and/or Q to be if-converted.
- auto IsWorthwhile = [&](BasicBlock *BB) {
- if (!BB)
- return true;
- // Heuristic: if the block can be if-converted/phi-folded and the
- // instructions inside are all cheap (arithmetic/GEPs), it's worthwhile to
- // thread this store.
- unsigned N = 0;
- for (auto &I : BB->instructionsWithoutDebug()) {
- // Cheap instructions viable for folding.
- if (isa<BinaryOperator>(I) || isa<GetElementPtrInst>(I) ||
- isa<StoreInst>(I))
- ++N;
- // Free instructions.
- else if (I.isTerminator() || IsaBitcastOfPointerType(I))
- continue;
- else
- return false;
- }
- // The store we want to merge is counted in N, so add 1 to make sure
- // we're counting the instructions that would be left.
- return N <= (PHINodeFoldingThreshold + 1);
- };
-
- if (!MergeCondStoresAggressively &&
- (!IsWorthwhile(PTB) || !IsWorthwhile(PFB) || !IsWorthwhile(QTB) ||
- !IsWorthwhile(QFB)))
- return false;
-
+ const DataLayout &DL,
+ const TargetTransformInfo &TTI) {
// For every pointer, there must be exactly two stores, one coming from
// PTB or PFB, and the other from QTB or QFB. We don't support more than one
// store (to any address) in PTB,PFB or QTB,QFB.
@@ -2989,6 +2988,46 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB,
if (&*I != PStore && I->mayReadOrWriteMemory())
return false;
+ // If we're not in aggressive mode, we only optimize if we have some
+ // confidence that by optimizing we'll allow P and/or Q to be if-converted.
+ auto IsWorthwhile = [&](BasicBlock *BB, ArrayRef<StoreInst *> FreeStores) {
+ if (!BB)
+ return true;
+ // Heuristic: if the block can be if-converted/phi-folded and the
+ // instructions inside are all cheap (arithmetic/GEPs), it's worthwhile to
+ // thread this store.
+ int BudgetRemaining =
+ PHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic;
+ for (auto &I : BB->instructionsWithoutDebug()) {
+ // Consider terminator instruction to be free.
+ if (I.isTerminator())
+ continue;
+ // If this is one of the stores that we want to speculate out of this BB,
+ // then don't count its cost; consider it to be free.
+ if (auto *S = dyn_cast<StoreInst>(&I))
+ if (llvm::find(FreeStores, S) != FreeStores.end())
+ continue;
+ // Else, we have a white-list of instructions that we are okay speculating.
+ if (!isa<BinaryOperator>(I) && !isa<GetElementPtrInst>(I))
+ return false; // Not in white-list - not worthwhile folding.
+ // And finally, if this is a non-free instruction that we are okay
+ // speculating, ensure that we consider the speculation budget.
+ BudgetRemaining -= TTI.getUserCost(&I);
+ if (BudgetRemaining < 0)
+ return false; // Eagerly refuse to fold as soon as we're out of budget.
+ }
+ assert(BudgetRemaining >= 0 &&
+ "When we run out of budget we will eagerly return from within the "
+ "per-instruction loop.");
+ return true;
+ };
+
+ const SmallVector<StoreInst *, 2> FreeStores = {PStore, QStore};
+ if (!MergeCondStoresAggressively &&
+ (!IsWorthwhile(PTB, FreeStores) || !IsWorthwhile(PFB, FreeStores) ||
+ !IsWorthwhile(QTB, FreeStores) || !IsWorthwhile(QFB, FreeStores)))
+ return false;
+
// If PostBB has more than two predecessors, we need to split it so we can
// sink the store.
if (std::next(pred_begin(PostBB), 2) != pred_end(PostBB)) {
@@ -3048,15 +3087,15 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB,
// store that doesn't execute.
if (MinAlignment != 0) {
// Choose the minimum of all non-zero alignments.
- SI->setAlignment(MinAlignment);
+ SI->setAlignment(Align(MinAlignment));
} else if (MaxAlignment != 0) {
// Choose the minimal alignment between the non-zero alignment and the ABI
// default alignment for the type of the stored value.
- SI->setAlignment(std::min(MaxAlignment, TypeAlignment));
+ SI->setAlignment(Align(std::min(MaxAlignment, TypeAlignment)));
} else {
// If both alignments are zero, use ABI default alignment for the type of
// the stored value.
- SI->setAlignment(TypeAlignment);
+ SI->setAlignment(Align(TypeAlignment));
}
QStore->eraseFromParent();
@@ -3066,7 +3105,8 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB,
}
static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI,
- const DataLayout &DL) {
+ const DataLayout &DL,
+ const TargetTransformInfo &TTI) {
// The intention here is to find diamonds or triangles (see below) where each
// conditional block contains a store to the same address. Both of these
// stores are conditional, so they can't be unconditionally sunk. But it may
@@ -3168,7 +3208,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI,
bool Changed = false;
for (auto *Address : CommonAddresses)
Changed |= mergeConditionalStoreToAddress(
- PTB, PFB, QTB, QFB, PostBB, Address, InvertPCond, InvertQCond, DL);
+ PTB, PFB, QTB, QFB, PostBB, Address, InvertPCond, InvertQCond, DL, TTI);
return Changed;
}
@@ -3177,7 +3217,8 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI,
/// that PBI and BI are both conditional branches, and BI is in one of the
/// successor blocks of PBI - PBI branches to BI.
static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
- const DataLayout &DL) {
+ const DataLayout &DL,
+ const TargetTransformInfo &TTI) {
assert(PBI->isConditional() && BI->isConditional());
BasicBlock *BB = BI->getParent();
@@ -3233,7 +3274,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
// If both branches are conditional and both contain stores to the same
// address, remove the stores from the conditionals and create a conditional
// merged store at the end.
- if (MergeCondStores && mergeConditionalStores(PBI, BI, DL))
+ if (MergeCondStores && mergeConditionalStores(PBI, BI, DL, TTI))
return true;
// If this is a conditional branch in an empty block, and if any
@@ -3697,12 +3738,17 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder,
BasicBlock *BB = BI->getParent();
+ // MSAN does not like undefs as a branch condition, which can be introduced
+ // by the "explicit branch" below.
+ if (ExtraCase && BB->getParent()->hasFnAttribute(Attribute::SanitizeMemory))
+ return false;
+
LLVM_DEBUG(dbgs() << "Converting 'icmp' chain with " << Values.size()
<< " cases into SWITCH. BB is:\n"
<< *BB);
// If there are any extra values that couldn't be folded into the switch
- // then we evaluate them with an explicit branch first. Split the block
+ // then we evaluate them with an explicit branch first. Split the block
// right before the condbr to handle it.
if (ExtraCase) {
BasicBlock *NewBB =
@@ -3851,7 +3897,7 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) {
// Simplify resume that is only used by a single (non-phi) landing pad.
bool SimplifyCFGOpt::SimplifySingleResume(ResumeInst *RI) {
BasicBlock *BB = RI->getParent();
- LandingPadInst *LPInst = dyn_cast<LandingPadInst>(BB->getFirstNonPHI());
+ auto *LPInst = cast<LandingPadInst>(BB->getFirstNonPHI());
assert(RI->getValue() == LPInst &&
"Resume must unwind the exception that caused control to here");
@@ -4178,23 +4224,22 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) {
IRBuilder<> Builder(TI);
if (auto *BI = dyn_cast<BranchInst>(TI)) {
if (BI->isUnconditional()) {
- if (BI->getSuccessor(0) == BB) {
- new UnreachableInst(TI->getContext(), TI);
- TI->eraseFromParent();
- Changed = true;
- }
+ assert(BI->getSuccessor(0) == BB && "Incorrect CFG");
+ new UnreachableInst(TI->getContext(), TI);
+ TI->eraseFromParent();
+ Changed = true;
} else {
Value* Cond = BI->getCondition();
if (BI->getSuccessor(0) == BB) {
Builder.CreateAssumption(Builder.CreateNot(Cond));
Builder.CreateBr(BI->getSuccessor(1));
- EraseTerminatorAndDCECond(BI);
- } else if (BI->getSuccessor(1) == BB) {
+ } else {
+ assert(BI->getSuccessor(1) == BB && "Incorrect CFG");
Builder.CreateAssumption(Cond);
Builder.CreateBr(BI->getSuccessor(0));
- EraseTerminatorAndDCECond(BI);
- Changed = true;
}
+ EraseTerminatorAndDCECond(BI);
+ Changed = true;
}
} else if (auto *SI = dyn_cast<SwitchInst>(TI)) {
SwitchInstProfUpdateWrapper SU(*SI);
@@ -4276,6 +4321,17 @@ static bool CasesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
return true;
}
+static void createUnreachableSwitchDefault(SwitchInst *Switch) {
+ LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n");
+ BasicBlock *NewDefaultBlock =
+ SplitBlockPredecessors(Switch->getDefaultDest(), Switch->getParent(), "");
+ Switch->setDefaultDest(&*NewDefaultBlock);
+ SplitBlock(&*NewDefaultBlock, &NewDefaultBlock->front());
+ auto *NewTerminator = NewDefaultBlock->getTerminator();
+ new UnreachableInst(Switch->getContext(), NewTerminator);
+ EraseTerminatorAndDCECond(NewTerminator);
+}
+
/// Turn a switch with two reachable destinations into an integer range
/// comparison and branch.
static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) {
@@ -4384,6 +4440,11 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) {
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
}
+ // Clean up the default block - it may have phis or other instructions before
+ // the unreachable terminator.
+ if (!HasDefault)
+ createUnreachableSwitchDefault(SI);
+
// Drop the switch.
SI->eraseFromParent();
@@ -4428,14 +4489,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC,
if (HasDefault && DeadCases.empty() &&
NumUnknownBits < 64 /* avoid overflow */ &&
SI->getNumCases() == (1ULL << NumUnknownBits)) {
- LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n");
- BasicBlock *NewDefault =
- SplitBlockPredecessors(SI->getDefaultDest(), SI->getParent(), "");
- SI->setDefaultDest(&*NewDefault);
- SplitBlock(&*NewDefault, &NewDefault->front());
- auto *OldTI = NewDefault->getTerminator();
- new UnreachableInst(SI->getContext(), OldTI);
- EraseTerminatorAndDCECond(OldTI);
+ createUnreachableSwitchDefault(SI);
return true;
}
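
A worked example of the "switch default is dead" condition checked just above (numbers hypothetical): if known-bits analysis says only NumUnknownBits of the condition can vary, then 2^NumUnknownBits distinct cases cover every reachable value and the default destination can be rewritten to unreachable.

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t NumUnknownBits = 2; // e.g. the switch condition is (x & 3)
  const uint64_t NumCases = 4;       // cases 0, 1, 2 and 3 are all present
  // Same shape as the guard above: avoid the 64-bit shift overflow and
  // require the case count to cover the whole reachable value space.
  assert(NumUnknownBits < 64 && NumCases == (1ULL << NumUnknownBits));
  return 0;
}
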
@@ -5031,7 +5085,7 @@ SwitchLookupTable::SwitchLookupTable(
Array->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
// Set the alignment to that of an array items. We will be only loading one
// value out of it.
- Array->setAlignment(DL.getPrefTypeAlignment(ValueType));
+ Array->setAlignment(Align(DL.getPrefTypeAlignment(ValueType)));
Kind = ArrayKind;
}
@@ -5260,7 +5314,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
// Figure out the corresponding result for each case value and phi node in the
// common destination, as well as the min and max case values.
- assert(!empty(SI->cases()));
+ assert(!SI->cases().empty());
SwitchInst::CaseIt CI = SI->case_begin();
ConstantInt *MinCaseVal = CI->getCaseValue();
ConstantInt *MaxCaseVal = CI->getCaseValue();
@@ -5892,7 +5946,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator()))
if (PBI != BI && PBI->isConditional())
- if (SimplifyCondBranchToCondBranch(PBI, BI, DL))
+ if (SimplifyCondBranchToCondBranch(PBI, BI, DL, TTI))
return requestResimplify();
// Look for diamond patterns.
@@ -5900,7 +5954,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
if (BasicBlock *PrevBB = allPredecessorsComeFromSameSource(BB))
if (BranchInst *PBI = dyn_cast<BranchInst>(PrevBB->getTerminator()))
if (PBI != BI && PBI->isConditional())
- if (mergeConditionalStores(PBI, BI, DL))
+ if (mergeConditionalStores(PBI, BI, DL, TTI))
return requestResimplify();
return false;
diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp
index e0def81d5eee..0324993a8203 100644
--- a/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -35,6 +35,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
@@ -47,7 +48,6 @@ static cl::opt<bool>
cl::desc("Enable unsafe double to float "
"shrinking for math lib calls"));
-
//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//
@@ -177,7 +177,8 @@ static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len,
if (!isOnlyUsedInComparisonWithZero(CI))
return false;
- if (!isDereferenceableAndAlignedPointer(Str, 1, APInt(64, Len), DL))
+ if (!isDereferenceableAndAlignedPointer(Str, Align::None(), APInt(64, Len),
+ DL))
return false;
if (CI->getFunction()->hasFnAttribute(Attribute::SanitizeMemory))
@@ -186,6 +187,67 @@ static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len,
return true;
}
+static void annotateDereferenceableBytes(CallInst *CI,
+ ArrayRef<unsigned> ArgNos,
+ uint64_t DereferenceableBytes) {
+ const Function *F = CI->getCaller();
+ if (!F)
+ return;
+ for (unsigned ArgNo : ArgNos) {
+ uint64_t DerefBytes = DereferenceableBytes;
+ unsigned AS = CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace();
+ if (!llvm::NullPointerIsDefined(F, AS) ||
+ CI->paramHasAttr(ArgNo, Attribute::NonNull))
+ DerefBytes = std::max(CI->getDereferenceableOrNullBytes(
+ ArgNo + AttributeList::FirstArgIndex),
+ DereferenceableBytes);
+
+ if (CI->getDereferenceableBytes(ArgNo + AttributeList::FirstArgIndex) <
+ DerefBytes) {
+ CI->removeParamAttr(ArgNo, Attribute::Dereferenceable);
+ if (!llvm::NullPointerIsDefined(F, AS) ||
+ CI->paramHasAttr(ArgNo, Attribute::NonNull))
+ CI->removeParamAttr(ArgNo, Attribute::DereferenceableOrNull);
+ CI->addParamAttr(ArgNo, Attribute::getWithDereferenceableBytes(
+ CI->getContext(), DerefBytes));
+ }
+ }
+}
+
+static void annotateNonNullBasedOnAccess(CallInst *CI,
+ ArrayRef<unsigned> ArgNos) {
+ Function *F = CI->getCaller();
+ if (!F)
+ return;
+
+ for (unsigned ArgNo : ArgNos) {
+ if (CI->paramHasAttr(ArgNo, Attribute::NonNull))
+ continue;
+ unsigned AS = CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace();
+ if (llvm::NullPointerIsDefined(F, AS))
+ continue;
+
+ CI->addParamAttr(ArgNo, Attribute::NonNull);
+ annotateDereferenceableBytes(CI, ArgNo, 1);
+ }
+}
+
+static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> ArgNos,
+ Value *Size, const DataLayout &DL) {
+ if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size)) {
+ annotateNonNullBasedOnAccess(CI, ArgNos);
+ annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue());
+ } else if (isKnownNonZero(Size, DL)) {
+ annotateNonNullBasedOnAccess(CI, ArgNos);
+ const APInt *X, *Y;
+ uint64_t DerefMin = 1;
+ if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) {
+ DerefMin = std::min(X->getZExtValue(), Y->getZExtValue());
+ annotateDereferenceableBytes(CI, ArgNos, DerefMin);
+ }
+ }
+}
+
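
A concrete reading of the select branch in annotateNonNullAndDereferenceable above: when the size operand is a select between two constants and is known non-zero, only the smaller constant can be promised as dereferenceable. Standalone illustration with hypothetical values:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical call: memcpy(dst, src, flag ? 64 : 16). Either size is
// possible at runtime, so only min(64, 16) bytes may be annotated as
// dereferenceable on both pointer arguments.
int main() {
  const uint64_t X = 64, Y = 16;
  const uint64_t DerefMin = std::min(X, Y);
  assert(DerefMin == 16);
  return 0;
}
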
//===----------------------------------------------------------------------===//
// String and Memory Library Call Optimizations
//===----------------------------------------------------------------------===//
@@ -194,10 +256,13 @@ Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilder<> &B) {
// Extract some information from the instruction
Value *Dst = CI->getArgOperand(0);
Value *Src = CI->getArgOperand(1);
+ annotateNonNullBasedOnAccess(CI, {0, 1});
// See if we can get the length of the input string.
uint64_t Len = GetStringLength(Src);
- if (Len == 0)
+ if (Len)
+ annotateDereferenceableBytes(CI, 1, Len);
+ else
return nullptr;
--Len; // Unbias length.
@@ -232,24 +297,34 @@ Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) {
// Extract some information from the instruction.
Value *Dst = CI->getArgOperand(0);
Value *Src = CI->getArgOperand(1);
+ Value *Size = CI->getArgOperand(2);
uint64_t Len;
+ annotateNonNullBasedOnAccess(CI, 0);
+ if (isKnownNonZero(Size, DL))
+ annotateNonNullBasedOnAccess(CI, 1);
// We don't do anything if length is not constant.
- if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
+ ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size);
+ if (LengthArg) {
Len = LengthArg->getZExtValue();
- else
+ // strncat(x, c, 0) -> x
+ if (!Len)
+ return Dst;
+ } else {
return nullptr;
+ }
// See if we can get the length of the input string.
uint64_t SrcLen = GetStringLength(Src);
- if (SrcLen == 0)
+ if (SrcLen) {
+ annotateDereferenceableBytes(CI, 1, SrcLen);
+ --SrcLen; // Unbias length.
+ } else {
return nullptr;
- --SrcLen; // Unbias length.
+ }
- // Handle the simple, do-nothing cases:
// strncat(x, "", c) -> x
- // strncat(x, c, 0) -> x
- if (SrcLen == 0 || Len == 0)
+ if (SrcLen == 0)
return Dst;
// We don't optimize this case.
@@ -265,13 +340,18 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
FunctionType *FT = Callee->getFunctionType();
Value *SrcStr = CI->getArgOperand(0);
+ annotateNonNullBasedOnAccess(CI, 0);
// If the second operand is non-constant, see if we can compute the length
// of the input string and turn this into memchr.
ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
if (!CharC) {
uint64_t Len = GetStringLength(SrcStr);
- if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32.
+ if (Len)
+ annotateDereferenceableBytes(CI, 0, Len);
+ else
+ return nullptr;
+ if (!FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32.
return nullptr;
return emitMemChr(SrcStr, CI->getArgOperand(1), // include nul.
@@ -304,6 +384,7 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) {
Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) {
Value *SrcStr = CI->getArgOperand(0);
ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+ annotateNonNullBasedOnAccess(CI, 0);
// Cannot fold anything if we're not looking for a constant.
if (!CharC)
@@ -351,7 +432,12 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) {
// strcmp(P, "x") -> memcmp(P, "x", 2)
uint64_t Len1 = GetStringLength(Str1P);
+ if (Len1)
+ annotateDereferenceableBytes(CI, 0, Len1);
uint64_t Len2 = GetStringLength(Str2P);
+ if (Len2)
+ annotateDereferenceableBytes(CI, 1, Len2);
+
if (Len1 && Len2) {
return emitMemCmp(Str1P, Str2P,
ConstantInt::get(DL.getIntPtrType(CI->getContext()),
@@ -374,17 +460,22 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) {
TLI);
}
+ annotateNonNullBasedOnAccess(CI, {0, 1});
return nullptr;
}
Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) {
- Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1);
+ Value *Str1P = CI->getArgOperand(0);
+ Value *Str2P = CI->getArgOperand(1);
+ Value *Size = CI->getArgOperand(2);
if (Str1P == Str2P) // strncmp(x,x,n) -> 0
return ConstantInt::get(CI->getType(), 0);
+ if (isKnownNonZero(Size, DL))
+ annotateNonNullBasedOnAccess(CI, {0, 1});
// Get the length argument if it is constant.
uint64_t Length;
- if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
+ if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size))
Length = LengthArg->getZExtValue();
else
return nullptr;
@@ -393,7 +484,7 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) {
return ConstantInt::get(CI->getType(), 0);
if (Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1)
- return emitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI);
+ return emitMemCmp(Str1P, Str2P, Size, B, DL, TLI);
StringRef Str1, Str2;
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
@@ -415,7 +506,11 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) {
CI->getType());
uint64_t Len1 = GetStringLength(Str1P);
+ if (Len1)
+ annotateDereferenceableBytes(CI, 0, Len1);
uint64_t Len2 = GetStringLength(Str2P);
+ if (Len2)
+ annotateDereferenceableBytes(CI, 1, Len2);
// strncmp to memcmp
if (!HasStr1 && HasStr2) {
@@ -437,20 +532,38 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) {
return nullptr;
}
+Value *LibCallSimplifier::optimizeStrNDup(CallInst *CI, IRBuilder<> &B) {
+ Value *Src = CI->getArgOperand(0);
+ ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+ uint64_t SrcLen = GetStringLength(Src);
+ if (SrcLen && Size) {
+ annotateDereferenceableBytes(CI, 0, SrcLen);
+ if (SrcLen <= Size->getZExtValue() + 1)
+ return emitStrDup(Src, B, TLI);
+ }
+
+ return nullptr;
+}
+
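
A standalone check (POSIX libc assumed, not part of the patch) of the equivalence optimizeStrNDup relies on: GetStringLength returns strlen+1, so the SrcLen <= Size+1 test means n >= strlen(s), in which case strndup(s, n) copies the whole string and behaves like strdup(s).

#include <assert.h>
#include <stdlib.h>
#include <string.h>

int main() {
  const char *S = "hello";  // strlen(S) == 5
  char *A = strndup(S, 8);  // 8 >= strlen(S), so the full string is copied
  char *B = strdup(S);
  assert(strcmp(A, B) == 0);
  free(A);
  free(B);
  return 0;
}
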
Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) {
Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);
if (Dst == Src) // strcpy(x,x) -> x
return Src;
-
+
+ annotateNonNullBasedOnAccess(CI, {0, 1});
// See if we can get the length of the input string.
uint64_t Len = GetStringLength(Src);
- if (Len == 0)
+ if (Len)
+ annotateDereferenceableBytes(CI, 1, Len);
+ else
return nullptr;
// We have enough information to now generate the memcpy call to do the
// copy for us. Make a memcpy to copy the nul byte with align = 1.
- B.CreateMemCpy(Dst, 1, Src, 1,
- ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
+ CallInst *NewCI =
+ B.CreateMemCpy(Dst, 1, Src, 1,
+ ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
+ NewCI->setAttributes(CI->getAttributes());
return Dst;
}
@@ -464,7 +577,9 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) {
// See if we can get the length of the input string.
uint64_t Len = GetStringLength(Src);
- if (Len == 0)
+ if (Len)
+ annotateDereferenceableBytes(CI, 1, Len);
+ else
return nullptr;
Type *PT = Callee->getFunctionType()->getParamType(0);
@@ -474,7 +589,8 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) {
// We have enough information to now generate the memcpy call to do the
// copy for us. Make a memcpy to copy the nul byte with align = 1.
- B.CreateMemCpy(Dst, 1, Src, 1, LenV);
+ CallInst *NewCI = B.CreateMemCpy(Dst, 1, Src, 1, LenV);
+ NewCI->setAttributes(CI->getAttributes());
return DstEnd;
}
@@ -482,37 +598,47 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
Value *Dst = CI->getArgOperand(0);
Value *Src = CI->getArgOperand(1);
- Value *LenOp = CI->getArgOperand(2);
+ Value *Size = CI->getArgOperand(2);
+ annotateNonNullBasedOnAccess(CI, 0);
+ if (isKnownNonZero(Size, DL))
+ annotateNonNullBasedOnAccess(CI, 1);
+
+ uint64_t Len;
+ if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size))
+ Len = LengthArg->getZExtValue();
+ else
+ return nullptr;
+
+ // strncpy(x, y, 0) -> x
+ if (Len == 0)
+ return Dst;
// See if we can get the length of the input string.
uint64_t SrcLen = GetStringLength(Src);
- if (SrcLen == 0)
+ if (SrcLen) {
+ annotateDereferenceableBytes(CI, 1, SrcLen);
+ --SrcLen; // Unbias length.
+ } else {
return nullptr;
- --SrcLen;
+ }
if (SrcLen == 0) {
// strncpy(x, "", y) -> memset(align 1 x, '\0', y)
- B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1);
+ CallInst *NewCI = B.CreateMemSet(Dst, B.getInt8('\0'), Size, 1);
+ AttrBuilder ArgAttrs(CI->getAttributes().getParamAttributes(0));
+ NewCI->setAttributes(NewCI->getAttributes().addParamAttributes(
+ CI->getContext(), 0, ArgAttrs));
return Dst;
}
- uint64_t Len;
- if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp))
- Len = LengthArg->getZExtValue();
- else
- return nullptr;
-
- if (Len == 0)
- return Dst; // strncpy(x, y, 0) -> x
-
// Let strncpy handle the zero padding
if (Len > SrcLen + 1)
return nullptr;
Type *PT = Callee->getFunctionType()->getParamType(0);
// strncpy(x, s, c) -> memcpy(align 1 x, align 1 s, c) [s and c are constant]
- B.CreateMemCpy(Dst, 1, Src, 1, ConstantInt::get(DL.getIntPtrType(PT), Len));
-
+ CallInst *NewCI = B.CreateMemCpy(Dst, 1, Src, 1, ConstantInt::get(DL.getIntPtrType(PT), Len));
+ NewCI->setAttributes(CI->getAttributes());
return Dst;
}
@@ -608,7 +734,10 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilder<> &B,
}
Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) {
- return optimizeStringLength(CI, B, 8);
+ if (Value *V = optimizeStringLength(CI, B, 8))
+ return V;
+ annotateNonNullBasedOnAccess(CI, 0);
+ return nullptr;
}
Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilder<> &B) {
@@ -756,21 +885,35 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) {
Value *StrChr = emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI);
return StrChr ? B.CreateBitCast(StrChr, CI->getType()) : nullptr;
}
+
+ annotateNonNullBasedOnAccess(CI, {0, 1});
+ return nullptr;
+}
+
+Value *LibCallSimplifier::optimizeMemRChr(CallInst *CI, IRBuilder<> &B) {
+ if (isKnownNonZero(CI->getOperand(2), DL))
+ annotateNonNullBasedOnAccess(CI, 0);
return nullptr;
}
Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) {
Value *SrcStr = CI->getArgOperand(0);
+ Value *Size = CI->getArgOperand(2);
+ annotateNonNullAndDereferenceable(CI, 0, Size, DL);
ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
- ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
+ ConstantInt *LenC = dyn_cast<ConstantInt>(Size);
// memchr(x, y, 0) -> null
- if (LenC && LenC->isZero())
- return Constant::getNullValue(CI->getType());
+ if (LenC) {
+ if (LenC->isZero())
+ return Constant::getNullValue(CI->getType());
+ } else {
+ // From now on we need at least constant length and string.
+ return nullptr;
+ }
- // From now on we need at least constant length and string.
StringRef Str;
- if (!LenC || !getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false))
+ if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false))
return nullptr;
// Truncate the string to LenC. If Str is smaller than LenC we will still only
@@ -913,6 +1056,7 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS,
Ret = 1;
return ConstantInt::get(CI->getType(), Ret);
}
+
return nullptr;
}
@@ -925,12 +1069,19 @@ Value *LibCallSimplifier::optimizeMemCmpBCmpCommon(CallInst *CI,
if (LHS == RHS) // memcmp(s,s,x) -> 0
return Constant::getNullValue(CI->getType());
+ annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL);
// Handle constant lengths.
- if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size))
- if (Value *Res = optimizeMemCmpConstantSize(CI, LHS, RHS,
- LenC->getZExtValue(), B, DL))
- return Res;
+ ConstantInt *LenC = dyn_cast<ConstantInt>(Size);
+ if (!LenC)
+ return nullptr;
+ // memcmp(d,s,0) -> 0
+ if (LenC->getZExtValue() == 0)
+ return Constant::getNullValue(CI->getType());
+
+ if (Value *Res =
+ optimizeMemCmpConstantSize(CI, LHS, RHS, LenC->getZExtValue(), B, DL))
+ return Res;
return nullptr;
}
@@ -939,9 +1090,9 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) {
return V;
// memcmp(x, y, Len) == 0 -> bcmp(x, y, Len) == 0
- // `bcmp` can be more efficient than memcmp because it only has to know that
- // there is a difference, not where it is.
- if (isOnlyUsedInZeroEqualityComparison(CI) && TLI->has(LibFunc_bcmp)) {
+ // bcmp can be more efficient than memcmp because it only has to know that
+ // there is a difference, not how different one value is from the other.
+ if (TLI->has(LibFunc_bcmp) && isOnlyUsedInZeroEqualityComparison(CI)) {
Value *LHS = CI->getArgOperand(0);
Value *RHS = CI->getArgOperand(1);
Value *Size = CI->getArgOperand(2);
@@ -956,16 +1107,37 @@ Value *LibCallSimplifier::optimizeBCmp(CallInst *CI, IRBuilder<> &B) {
}
Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) {
+ Value *Size = CI->getArgOperand(2);
+ annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL);
+ if (isa<IntrinsicInst>(CI))
+ return nullptr;
+
// memcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n)
- B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1,
- CI->getArgOperand(2));
+ CallInst *NewCI =
+ B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, Size);
+ NewCI->setAttributes(CI->getAttributes());
return CI->getArgOperand(0);
}
+Value *LibCallSimplifier::optimizeMemPCpy(CallInst *CI, IRBuilder<> &B) {
+ Value *Dst = CI->getArgOperand(0);
+ Value *N = CI->getArgOperand(2);
+ // mempcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n), x + n
+ CallInst *NewCI = B.CreateMemCpy(Dst, 1, CI->getArgOperand(1), 1, N);
+ NewCI->setAttributes(CI->getAttributes());
+ return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, N);
+}
+
Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) {
+ Value *Size = CI->getArgOperand(2);
+ annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL);
+ if (isa<IntrinsicInst>(CI))
+ return nullptr;
+
// memmove(x, y, n) -> llvm.memmove(align 1 x, align 1 y, n)
- B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1,
- CI->getArgOperand(2));
+ CallInst *NewCI =
+ B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, Size);
+ NewCI->setAttributes(CI->getAttributes());
return CI->getArgOperand(0);
}
@@ -1003,25 +1175,29 @@ Value *LibCallSimplifier::foldMallocMemset(CallInst *Memset, IRBuilder<> &B) {
B.SetInsertPoint(Malloc->getParent(), ++Malloc->getIterator());
const DataLayout &DL = Malloc->getModule()->getDataLayout();
IntegerType *SizeType = DL.getIntPtrType(B.GetInsertBlock()->getContext());
- Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1),
- Malloc->getArgOperand(0), Malloc->getAttributes(),
- B, *TLI);
- if (!Calloc)
- return nullptr;
-
- Malloc->replaceAllUsesWith(Calloc);
- eraseFromParent(Malloc);
+ if (Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1),
+ Malloc->getArgOperand(0),
+ Malloc->getAttributes(), B, *TLI)) {
+ substituteInParent(Malloc, Calloc);
+ return Calloc;
+ }
- return Calloc;
+ return nullptr;
}
Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) {
+ Value *Size = CI->getArgOperand(2);
+ annotateNonNullAndDereferenceable(CI, 0, Size, DL);
+ if (isa<IntrinsicInst>(CI))
+ return nullptr;
+
if (auto *Calloc = foldMallocMemset(CI, B))
return Calloc;
// memset(p, v, n) -> llvm.memset(align 1 p, v, n)
Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false);
- B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1);
+ CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val, Size, 1);
+ NewCI->setAttributes(CI->getAttributes());
return CI->getArgOperand(0);
}
@@ -1096,21 +1272,18 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B,
if (!V[0] || (isBinary && !V[1]))
return nullptr;
- StringRef CalleeNm = CalleeFn->getName();
- AttributeList CalleeAt = CalleeFn->getAttributes();
- bool CalleeIn = CalleeFn->isIntrinsic();
-
// If call isn't an intrinsic, check that it isn't within a function with the
// same name as the float version of this call, otherwise the result is an
// infinite loop. For example, from MinGW-w64:
//
// float expf(float val) { return (float) exp((double) val); }
- if (!CalleeIn) {
- const Function *Fn = CI->getFunction();
- StringRef FnName = Fn->getName();
- if (FnName.back() == 'f' &&
- FnName.size() == (CalleeNm.size() + 1) &&
- FnName.startswith(CalleeNm))
+ StringRef CalleeName = CalleeFn->getName();
+ bool IsIntrinsic = CalleeFn->isIntrinsic();
+ if (!IsIntrinsic) {
+ StringRef CallerName = CI->getFunction()->getName();
+ if (!CallerName.empty() && CallerName.back() == 'f' &&
+ CallerName.size() == (CalleeName.size() + 1) &&
+ CallerName.startswith(CalleeName))
return nullptr;
}
@@ -1120,16 +1293,16 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B,
// g((double) float) -> (double) gf(float)
Value *R;
- if (CalleeIn) {
+ if (IsIntrinsic) {
Module *M = CI->getModule();
Intrinsic::ID IID = CalleeFn->getIntrinsicID();
Function *Fn = Intrinsic::getDeclaration(M, IID, B.getFloatTy());
R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]);
+ } else {
+ AttributeList CalleeAttrs = CalleeFn->getAttributes();
+ R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], CalleeName, B, CalleeAttrs)
+ : emitUnaryFloatFnCall(V[0], CalleeName, B, CalleeAttrs);
}
- else
- R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], CalleeNm, B, CalleeAt)
- : emitUnaryFloatFnCall(V[0], CalleeNm, B, CalleeAt);
-
return B.CreateFPExt(R, B.getDoubleTy());
}
@@ -1234,9 +1407,25 @@ static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) {
return InnerChain[Exp];
}
+// Return a properly extended 32-bit integer if the operation is an itofp.
+static Value *getIntToFPVal(Value *I2F, IRBuilder<> &B) {
+ if (isa<SIToFPInst>(I2F) || isa<UIToFPInst>(I2F)) {
+ Value *Op = cast<Instruction>(I2F)->getOperand(0);
+ // Make sure that the exponent fits inside an int32_t,
+    // thus avoiding range issues that the FP representation does not have.
+ unsigned BitWidth = Op->getType()->getPrimitiveSizeInBits();
+ if (BitWidth < 32 ||
+ (BitWidth == 32 && isa<SIToFPInst>(I2F)))
+ return isa<SIToFPInst>(I2F) ? B.CreateSExt(Op, B.getInt32Ty())
+ : B.CreateZExt(Op, B.getInt32Ty());
+ }
+
+ return nullptr;
+}
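// Why the helper accepts any signed source up to 32 bits but only unsigned
// sources narrower than 32 bits: the rewritten calls (ldexp, llvm.powi) take
// a signed 32-bit exponent, and a full 32-bit unsigned value may not fit.
#include <climits>
#include <cstdio>
int main() {
  unsigned u = UINT_MAX;  // representable in a 32-bit unsigned source
  std::printf("%u as int32 -> %d\n", u, (int)u); // wraps to -1: rejected
  int s = INT_MIN;        // any 32-bit signed value round-trips
  std::printf("%d as int32 -> %d\n", s, (int)s);
  return 0;
}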
+
/// Use exp{,2}(x * y) for pow(exp{,2}(x), y);
-/// exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x);
-/// exp2(log2(n) * x) for pow(n, x).
+/// ldexp(1.0, x) for pow(2.0, itofp(x)); exp2(n * x) for pow(2.0 ** n, x);
+/// exp10(x) for pow(10.0, x); exp2(log2(n) * x) for pow(n, x).
Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
AttributeList Attrs = Pow->getCalledFunction()->getAttributes();
@@ -1269,9 +1458,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
StringRef ExpName;
Intrinsic::ID ID;
Value *ExpFn;
- LibFunc LibFnFloat;
- LibFunc LibFnDouble;
- LibFunc LibFnLongDouble;
+ LibFunc LibFnFloat, LibFnDouble, LibFnLongDouble;
switch (LibFn) {
default:
@@ -1305,9 +1492,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
// elimination cannot be trusted to remove it, since it may have side
// effects (e.g., errno). When the only consumer for the original
// exp{,2}() is pow(), then it has to be explicitly erased.
- BaseFn->replaceAllUsesWith(ExpFn);
- eraseFromParent(BaseFn);
-
+ substituteInParent(BaseFn, ExpFn);
return ExpFn;
}
}
@@ -1318,8 +1503,18 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
if (!match(Pow->getArgOperand(0), m_APFloat(BaseF)))
return nullptr;
+ // pow(2.0, itofp(x)) -> ldexp(1.0, x)
+ if (match(Base, m_SpecificFP(2.0)) &&
+ (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) &&
+ hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) {
+ if (Value *ExpoI = getIntToFPVal(Expo, B))
+ return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), ExpoI, TLI,
+ LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl,
+ B, Attrs);
+ }
+
// pow(2.0 ** n, x) -> exp2(n * x)
- if (hasUnaryFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) {
+ if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) {
APFloat BaseR = APFloat(1.0);
BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored);
BaseR = BaseR / *BaseF;
@@ -1344,7 +1539,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
// pow(10.0, x) -> exp10(x)
// TODO: There is no exp10() intrinsic yet, but some day there shall be one.
if (match(Base, m_SpecificFP(10.0)) &&
- hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l))
+ hasFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l))
return emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f,
LibFunc_exp10l, B, Attrs);
@@ -1359,17 +1554,15 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) {
if (Log) {
Value *FMul = B.CreateFMul(Log, Expo, "mul");
- if (Pow->doesNotAccessMemory()) {
+ if (Pow->doesNotAccessMemory())
return B.CreateCall(Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty),
FMul, "exp2");
- } else {
- if (hasUnaryFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f,
- LibFunc_exp2l))
- return emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f,
- LibFunc_exp2l, B, Attrs);
- }
+ else if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l))
+ return emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f,
+ LibFunc_exp2l, B, Attrs);
}
}
+
return nullptr;
}
@@ -1384,8 +1577,7 @@ static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno,
}
// Otherwise, use the libcall for sqrt().
- if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf,
- LibFunc_sqrtl))
+ if (hasFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl))
// TODO: We also should check that the target can in fact lower the sqrt()
// libcall. We currently have no way to ask this question, so we ask if
// the target has a sqrt() libcall, which is not exactly the same.
@@ -1452,7 +1644,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
bool Ignored;
// Bail out if simplifying libcalls to pow() is disabled.
- if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl))
+ if (!hasFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl))
return nullptr;
// Propagate the math semantics from the call to any created instructions.
@@ -1480,8 +1672,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
if (match(Expo, m_SpecificFP(-1.0)))
return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal");
- // pow(x, 0.0) -> 1.0
- if (match(Expo, m_SpecificFP(0.0)))
+ // pow(x, +/-0.0) -> 1.0
+ if (match(Expo, m_AnyZeroFP()))
return ConstantFP::get(Ty, 1.0);
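// The broader match (m_AnyZeroFP) is backed by C99/IEEE-754 semantics for
// pow(): pow(x, +0.0) and pow(x, -0.0) are exactly 1.0 for every x, even when
// x is NaN or infinite, so both zeros can be folded the same way.
#include <cmath>
#include <cstdio>
int main() {
  const double xs[] = {2.5, -3.0, INFINITY, NAN};
  for (double x : xs)
    std::printf("pow(%g, +0)=%g  pow(%g, -0)=%g\n", x, std::pow(x, 0.0),
                x, std::pow(x, -0.0));
  return 0;
}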
// pow(x, 1.0) -> x
@@ -1558,16 +1750,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
// powf(x, itofp(y)) -> powi(x, y)
if (AllowApprox && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo))) {
- Value *IntExpo = cast<Instruction>(Expo)->getOperand(0);
- Value *NewExpo = nullptr;
- unsigned BitWidth = IntExpo->getType()->getPrimitiveSizeInBits();
- if (isa<SIToFPInst>(Expo) && BitWidth == 32)
- NewExpo = IntExpo;
- else if (BitWidth < 32)
- NewExpo = isa<SIToFPInst>(Expo) ? B.CreateSExt(IntExpo, B.getInt32Ty())
- : B.CreateZExt(IntExpo, B.getInt32Ty());
- if (NewExpo)
- return createPowWithIntegerExponent(Base, NewExpo, M, B);
+ if (Value *ExpoI = getIntToFPVal(Expo, B))
+ return createPowWithIntegerExponent(Base, ExpoI, M, B);
}
return Shrunk;
@@ -1575,45 +1759,25 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
- Value *Ret = nullptr;
StringRef Name = Callee->getName();
- if (UnsafeFPShrink && Name == "exp2" && hasFloatVersion(Name))
+ Value *Ret = nullptr;
+ if (UnsafeFPShrink && Name == TLI->getName(LibFunc_exp2) &&
+ hasFloatVersion(Name))
Ret = optimizeUnaryDoubleFP(CI, B, true);
+ Type *Ty = CI->getType();
Value *Op = CI->getArgOperand(0);
+
// Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32
// Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32
- LibFunc LdExp = LibFunc_ldexpl;
- if (Op->getType()->isFloatTy())
- LdExp = LibFunc_ldexpf;
- else if (Op->getType()->isDoubleTy())
- LdExp = LibFunc_ldexp;
-
- if (TLI->has(LdExp)) {
- Value *LdExpArg = nullptr;
- if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) {
- if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32)
- LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty());
- } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) {
- if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32)
- LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty());
- }
-
- if (LdExpArg) {
- Constant *One = ConstantFP::get(CI->getContext(), APFloat(1.0f));
- if (!Op->getType()->isFloatTy())
- One = ConstantExpr::getFPExtend(One, Op->getType());
-
- Module *M = CI->getModule();
- FunctionCallee NewCallee = M->getOrInsertFunction(
- TLI->getName(LdExp), Op->getType(), Op->getType(), B.getInt32Ty());
- CallInst *CI = B.CreateCall(NewCallee, {One, LdExpArg});
- if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts()))
- CI->setCallingConv(F->getCallingConv());
-
- return CI;
- }
+ if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) &&
+ hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) {
+ if (Value *Exp = getIntToFPVal(Op, B))
+ return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI,
+ LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl,
+ B, CI->getCalledFunction()->getAttributes());
}
+
return Ret;
}
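// Numeric sanity check for the ldexp() rewrites: for an integer n that fits
// in an int, exp2((double)n), pow(2.0, (double)n) and ldexp(1.0, n) agree,
// since ldexp(1.0, n) is exactly 1.0 * 2^n.
#include <cmath>
#include <cstdio>
int main() {
  for (int n = -6; n <= 6; n += 3)
    std::printf("n=%+d  exp2=%g  pow=%g  ldexp=%g\n", n, std::exp2((double)n),
                std::pow(2.0, (double)n), std::ldexp(1.0, n));
  return 0;
}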
@@ -1644,48 +1808,155 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
return B.CreateCall(F, { CI->getArgOperand(0), CI->getArgOperand(1) });
}
-Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
- Function *Callee = CI->getCalledFunction();
+Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilder<> &B) {
+ Function *LogFn = Log->getCalledFunction();
+ AttributeList Attrs = LogFn->getAttributes();
+ StringRef LogNm = LogFn->getName();
+ Intrinsic::ID LogID = LogFn->getIntrinsicID();
+ Module *Mod = Log->getModule();
+ Type *Ty = Log->getType();
Value *Ret = nullptr;
- StringRef Name = Callee->getName();
- if (UnsafeFPShrink && hasFloatVersion(Name))
- Ret = optimizeUnaryDoubleFP(CI, B, true);
- if (!CI->isFast())
- return Ret;
- Value *Op1 = CI->getArgOperand(0);
- auto *OpC = dyn_cast<CallInst>(Op1);
+ if (UnsafeFPShrink && hasFloatVersion(LogNm))
+ Ret = optimizeUnaryDoubleFP(Log, B, true);
// The earlier call must also be 'fast' in order to do these transforms.
- if (!OpC || !OpC->isFast())
+ CallInst *Arg = dyn_cast<CallInst>(Log->getArgOperand(0));
+ if (!Log->isFast() || !Arg || !Arg->isFast() || !Arg->hasOneUse())
return Ret;
- // log(pow(x,y)) -> y*log(x)
- // This is only applicable to log, log2, log10.
- if (Name != "log" && Name != "log2" && Name != "log10")
+ LibFunc LogLb, ExpLb, Exp2Lb, Exp10Lb, PowLb;
+
+ // This is only applicable to log(), log2(), log10().
+ if (TLI->getLibFunc(LogNm, LogLb))
+ switch (LogLb) {
+ case LibFunc_logf:
+ LogID = Intrinsic::log;
+ ExpLb = LibFunc_expf;
+ Exp2Lb = LibFunc_exp2f;
+ Exp10Lb = LibFunc_exp10f;
+ PowLb = LibFunc_powf;
+ break;
+ case LibFunc_log:
+ LogID = Intrinsic::log;
+ ExpLb = LibFunc_exp;
+ Exp2Lb = LibFunc_exp2;
+ Exp10Lb = LibFunc_exp10;
+ PowLb = LibFunc_pow;
+ break;
+ case LibFunc_logl:
+ LogID = Intrinsic::log;
+ ExpLb = LibFunc_expl;
+ Exp2Lb = LibFunc_exp2l;
+ Exp10Lb = LibFunc_exp10l;
+ PowLb = LibFunc_powl;
+ break;
+ case LibFunc_log2f:
+ LogID = Intrinsic::log2;
+ ExpLb = LibFunc_expf;
+ Exp2Lb = LibFunc_exp2f;
+ Exp10Lb = LibFunc_exp10f;
+ PowLb = LibFunc_powf;
+ break;
+ case LibFunc_log2:
+ LogID = Intrinsic::log2;
+ ExpLb = LibFunc_exp;
+ Exp2Lb = LibFunc_exp2;
+ Exp10Lb = LibFunc_exp10;
+ PowLb = LibFunc_pow;
+ break;
+ case LibFunc_log2l:
+ LogID = Intrinsic::log2;
+ ExpLb = LibFunc_expl;
+ Exp2Lb = LibFunc_exp2l;
+ Exp10Lb = LibFunc_exp10l;
+ PowLb = LibFunc_powl;
+ break;
+ case LibFunc_log10f:
+ LogID = Intrinsic::log10;
+ ExpLb = LibFunc_expf;
+ Exp2Lb = LibFunc_exp2f;
+ Exp10Lb = LibFunc_exp10f;
+ PowLb = LibFunc_powf;
+ break;
+ case LibFunc_log10:
+ LogID = Intrinsic::log10;
+ ExpLb = LibFunc_exp;
+ Exp2Lb = LibFunc_exp2;
+ Exp10Lb = LibFunc_exp10;
+ PowLb = LibFunc_pow;
+ break;
+ case LibFunc_log10l:
+ LogID = Intrinsic::log10;
+ ExpLb = LibFunc_expl;
+ Exp2Lb = LibFunc_exp2l;
+ Exp10Lb = LibFunc_exp10l;
+ PowLb = LibFunc_powl;
+ break;
+ default:
+ return Ret;
+ }
+ else if (LogID == Intrinsic::log || LogID == Intrinsic::log2 ||
+ LogID == Intrinsic::log10) {
+ if (Ty->getScalarType()->isFloatTy()) {
+ ExpLb = LibFunc_expf;
+ Exp2Lb = LibFunc_exp2f;
+ Exp10Lb = LibFunc_exp10f;
+ PowLb = LibFunc_powf;
+ } else if (Ty->getScalarType()->isDoubleTy()) {
+ ExpLb = LibFunc_exp;
+ Exp2Lb = LibFunc_exp2;
+ Exp10Lb = LibFunc_exp10;
+ PowLb = LibFunc_pow;
+ } else
+ return Ret;
+ } else
return Ret;
IRBuilder<>::FastMathFlagGuard Guard(B);
- FastMathFlags FMF;
- FMF.setFast();
- B.setFastMathFlags(FMF);
+ B.setFastMathFlags(FastMathFlags::getFast());
+
+ Intrinsic::ID ArgID = Arg->getIntrinsicID();
+ LibFunc ArgLb = NotLibFunc;
+ TLI->getLibFunc(Arg, ArgLb);
+
+ // log(pow(x,y)) -> y*log(x)
+ if (ArgLb == PowLb || ArgID == Intrinsic::pow) {
+ Value *LogX =
+ Log->doesNotAccessMemory()
+ ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty),
+ Arg->getOperand(0), "log")
+ : emitUnaryFloatFnCall(Arg->getOperand(0), LogNm, B, Attrs);
+ Value *MulY = B.CreateFMul(Arg->getArgOperand(1), LogX, "mul");
+ // Since pow() may have side effects, e.g. errno,
+ // dead code elimination may not be trusted to remove it.
+ substituteInParent(Arg, MulY);
+ return MulY;
+ }
+
+ // log(exp{,2,10}(y)) -> y*log({e,2,10})
+ // TODO: There is no exp10() intrinsic yet.
+ if (ArgLb == ExpLb || ArgLb == Exp2Lb || ArgLb == Exp10Lb ||
+ ArgID == Intrinsic::exp || ArgID == Intrinsic::exp2) {
+ Constant *Eul;
+ if (ArgLb == ExpLb || ArgID == Intrinsic::exp)
+ // FIXME: Add more precise value of e for long double.
+ Eul = ConstantFP::get(Log->getType(), numbers::e);
+ else if (ArgLb == Exp2Lb || ArgID == Intrinsic::exp2)
+ Eul = ConstantFP::get(Log->getType(), 2.0);
+ else
+ Eul = ConstantFP::get(Log->getType(), 10.0);
+ Value *LogE = Log->doesNotAccessMemory()
+ ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty),
+ Eul, "log")
+ : emitUnaryFloatFnCall(Eul, LogNm, B, Attrs);
+ Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul");
+ // Since exp() may have side effects, e.g. errno,
+ // dead code elimination may not be trusted to remove it.
+ substituteInParent(Arg, MulY);
+ return MulY;
+ }
- LibFunc Func;
- Function *F = OpC->getCalledFunction();
- if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) &&
- Func == LibFunc_pow) || F->getIntrinsicID() == Intrinsic::pow))
- return B.CreateFMul(OpC->getArgOperand(1),
- emitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,
- Callee->getAttributes()), "mul");
-
- // log(exp2(y)) -> y*log(2)
- if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) &&
- TLI->has(Func) && Func == LibFunc_exp2)
- return B.CreateFMul(
- OpC->getArgOperand(0),
- emitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0),
- Callee->getName(), B, Callee->getAttributes()),
- "logmul");
return Ret;
}
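// Quick check of the identities the rewritten optimizeLog() relies on (they
// are applied only under fast-math, where errno and the x <= 0 corner cases
// are disregarded); the pairs agree up to rounding:
//   log(pow(x, y))  == y * log(x)
//   log2(exp2(y))   == y * log2(2)  == y
//   log10(exp(y))   == y * log10(e)
#include <cmath>
#include <cstdio>
int main() {
  double x = 3.0, y = 1.75, e = std::exp(1.0);
  std::printf("%.15g  %.15g\n", std::log(std::pow(x, y)), y * std::log(x));
  std::printf("%.15g  %.15g\n", std::log2(std::exp2(y)), y);
  std::printf("%.15g  %.15g\n", std::log10(std::exp(y)), y * std::log10(e));
  return 0;
}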
@@ -2137,6 +2408,7 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) {
return New;
}
+ annotateNonNullBasedOnAccess(CI, 0);
return nullptr;
}
@@ -2231,21 +2503,21 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) {
return New;
}
+ annotateNonNullBasedOnAccess(CI, {0, 1});
return nullptr;
}
Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B) {
- // Check for a fixed format string.
- StringRef FormatStr;
- if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr))
- return nullptr;
-
// Check for size
ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
if (!Size)
return nullptr;
uint64_t N = Size->getZExtValue();
+ // Check for a fixed format string.
+ StringRef FormatStr;
+ if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr))
+ return nullptr;
// If we just have a format string (nothing else crazy) transform it.
if (CI->getNumArgOperands() == 3) {
@@ -2318,6 +2590,8 @@ Value *LibCallSimplifier::optimizeSnPrintF(CallInst *CI, IRBuilder<> &B) {
return V;
}
+ if (isKnownNonZero(CI->getOperand(1), DL))
+ annotateNonNullBasedOnAccess(CI, 0);
return nullptr;
}
@@ -2503,6 +2777,7 @@ Value *LibCallSimplifier::optimizeFRead(CallInst *CI, IRBuilder<> &B) {
}
Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) {
+ annotateNonNullBasedOnAccess(CI, 0);
if (!CI->use_empty())
return nullptr;
@@ -2515,6 +2790,12 @@ Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) {
return nullptr;
}
+Value *LibCallSimplifier::optimizeBCopy(CallInst *CI, IRBuilder<> &B) {
+ // bcopy(src, dst, n) -> llvm.memmove(dst, src, n)
+ return B.CreateMemMove(CI->getArgOperand(1), 1, CI->getArgOperand(0), 1,
+ CI->getArgOperand(2));
+}
+
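// bcopy() differs from memmove() only in taking (src, dst, n) rather than
// (dst, src, n); both allow overlapping buffers. That is why the fold above
// passes operand 1 as the destination and operand 0 as the source.
#include <cstring>
#include <cstdio>
int main() {
  char buf[] = "abcdef";
  std::memmove(buf + 2, buf, 4); // equivalent to bcopy(buf, buf + 2, 4)
  std::printf("%s\n", buf);      // prints "ababcd"
  return 0;
}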
bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) {
LibFunc Func;
SmallString<20> FloatFuncName = FuncName;
@@ -2557,6 +2838,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,
return optimizeStrLen(CI, Builder);
case LibFunc_strpbrk:
return optimizeStrPBrk(CI, Builder);
+ case LibFunc_strndup:
+ return optimizeStrNDup(CI, Builder);
case LibFunc_strtol:
case LibFunc_strtod:
case LibFunc_strtof:
@@ -2573,12 +2856,16 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,
return optimizeStrStr(CI, Builder);
case LibFunc_memchr:
return optimizeMemChr(CI, Builder);
+ case LibFunc_memrchr:
+ return optimizeMemRChr(CI, Builder);
case LibFunc_bcmp:
return optimizeBCmp(CI, Builder);
case LibFunc_memcmp:
return optimizeMemCmp(CI, Builder);
case LibFunc_memcpy:
return optimizeMemCpy(CI, Builder);
+ case LibFunc_mempcpy:
+ return optimizeMemPCpy(CI, Builder);
case LibFunc_memmove:
return optimizeMemMove(CI, Builder);
case LibFunc_memset:
@@ -2587,6 +2874,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,
return optimizeRealloc(CI, Builder);
case LibFunc_wcslen:
return optimizeWcslen(CI, Builder);
+ case LibFunc_bcopy:
+ return optimizeBCopy(CI, Builder);
default:
break;
}
@@ -2626,11 +2915,21 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
case LibFunc_sqrt:
case LibFunc_sqrtl:
return optimizeSqrt(CI, Builder);
+ case LibFunc_logf:
case LibFunc_log:
+ case LibFunc_logl:
+ case LibFunc_log10f:
case LibFunc_log10:
+ case LibFunc_log10l:
+ case LibFunc_log1pf:
case LibFunc_log1p:
+ case LibFunc_log1pl:
+ case LibFunc_log2f:
case LibFunc_log2:
+ case LibFunc_log2l:
+ case LibFunc_logbf:
case LibFunc_logb:
+ case LibFunc_logbl:
return optimizeLog(CI, Builder);
case LibFunc_tan:
case LibFunc_tanf:
@@ -2721,10 +3020,18 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
case Intrinsic::exp2:
return optimizeExp2(CI, Builder);
case Intrinsic::log:
+ case Intrinsic::log2:
+ case Intrinsic::log10:
return optimizeLog(CI, Builder);
case Intrinsic::sqrt:
return optimizeSqrt(CI, Builder);
// TODO: Use foldMallocMemset() with memset intrinsic.
+ case Intrinsic::memset:
+ return optimizeMemSet(CI, Builder);
+ case Intrinsic::memcpy:
+ return optimizeMemCpy(CI, Builder);
+ case Intrinsic::memmove:
+ return optimizeMemMove(CI, Builder);
default:
return nullptr;
}
@@ -2740,8 +3047,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
IRBuilder<> TmpBuilder(SimplifiedCI);
if (Value *V = optimizeStringMemoryLibCall(SimplifiedCI, TmpBuilder)) {
// If we were able to further simplify, remove the now redundant call.
- SimplifiedCI->replaceAllUsesWith(V);
- eraseFromParent(SimplifiedCI);
+ substituteInParent(SimplifiedCI, V);
return V;
}
}
@@ -2898,7 +3204,9 @@ FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI,
uint64_t Len = GetStringLength(CI->getArgOperand(*StrOp));
// If the length is 0 we don't know how long it is and so we can't
// remove the check.
- if (Len == 0)
+ if (Len)
+ annotateDereferenceableBytes(CI, *StrOp, Len);
+ else
return false;
return ObjSizeCI->getZExtValue() >= Len;
}
@@ -2915,8 +3223,9 @@ FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI,
Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI,
IRBuilder<> &B) {
if (isFortifiedCallFoldable(CI, 3, 2)) {
- B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1,
- CI->getArgOperand(2));
+ CallInst *NewCI = B.CreateMemCpy(
+ CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, CI->getArgOperand(2));
+ NewCI->setAttributes(CI->getAttributes());
return CI->getArgOperand(0);
}
return nullptr;
@@ -2925,8 +3234,9 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI,
Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI,
IRBuilder<> &B) {
if (isFortifiedCallFoldable(CI, 3, 2)) {
- B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1,
- CI->getArgOperand(2));
+ CallInst *NewCI = B.CreateMemMove(
+ CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, CI->getArgOperand(2));
+ NewCI->setAttributes(CI->getAttributes());
return CI->getArgOperand(0);
}
return nullptr;
@@ -2938,7 +3248,9 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI,
if (isFortifiedCallFoldable(CI, 3, 2)) {
Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false);
- B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1);
+ CallInst *NewCI =
+ B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1);
+ NewCI->setAttributes(CI->getAttributes());
return CI->getArgOperand(0);
}
return nullptr;
@@ -2974,7 +3286,9 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,
  // Maybe we can still fold __st[rp]cpy_chk to __memcpy_chk.
uint64_t Len = GetStringLength(Src);
- if (Len == 0)
+ if (Len)
+ annotateDereferenceableBytes(CI, 1, Len);
+ else
return nullptr;
Type *SizeTTy = DL.getIntPtrType(CI->getContext());
diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp
index 456724779b43..5d380dcf231c 100644
--- a/lib/Transforms/Utils/SymbolRewriter.cpp
+++ b/lib/Transforms/Utils/SymbolRewriter.cpp
@@ -380,11 +380,11 @@ parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
// TODO see if there is a more elegant solution to selecting the rewrite
// descriptor type
if (!Target.empty())
- DL->push_back(llvm::make_unique<ExplicitRewriteFunctionDescriptor>(
+ DL->push_back(std::make_unique<ExplicitRewriteFunctionDescriptor>(
Source, Target, Naked));
else
DL->push_back(
- llvm::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform));
+ std::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform));
return true;
}
@@ -442,11 +442,11 @@ parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
}
if (!Target.empty())
- DL->push_back(llvm::make_unique<ExplicitRewriteGlobalVariableDescriptor>(
+ DL->push_back(std::make_unique<ExplicitRewriteGlobalVariableDescriptor>(
Source, Target,
/*Naked*/ false));
else
- DL->push_back(llvm::make_unique<PatternRewriteGlobalVariableDescriptor>(
+ DL->push_back(std::make_unique<PatternRewriteGlobalVariableDescriptor>(
Source, Transform));
return true;
@@ -505,11 +505,11 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
}
if (!Target.empty())
- DL->push_back(llvm::make_unique<ExplicitRewriteNamedAliasDescriptor>(
+ DL->push_back(std::make_unique<ExplicitRewriteNamedAliasDescriptor>(
Source, Target,
/*Naked*/ false));
else
- DL->push_back(llvm::make_unique<PatternRewriteNamedAliasDescriptor>(
+ DL->push_back(std::make_unique<PatternRewriteNamedAliasDescriptor>(
Source, Transform));
return true;
diff --git a/lib/Transforms/Utils/VNCoercion.cpp b/lib/Transforms/Utils/VNCoercion.cpp
index a77bf50fe10b..591e1fd2dbee 100644
--- a/lib/Transforms/Utils/VNCoercion.cpp
+++ b/lib/Transforms/Utils/VNCoercion.cpp
@@ -431,7 +431,7 @@ Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy,
PtrVal = Builder.CreateBitCast(PtrVal, DestPTy);
LoadInst *NewLoad = Builder.CreateLoad(DestTy, PtrVal);
NewLoad->takeName(SrcVal);
- NewLoad->setAlignment(SrcVal->getAlignment());
+ NewLoad->setAlignment(MaybeAlign(SrcVal->getAlignment()));
LLVM_DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n");
LLVM_DEBUG(dbgs() << "TO: " << *NewLoad << "\n");
diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp
index fbc3407c301f..da68d3713b40 100644
--- a/lib/Transforms/Utils/ValueMapper.cpp
+++ b/lib/Transforms/Utils/ValueMapper.cpp
@@ -27,8 +27,8 @@
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
-#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalObject.h"
+#include "llvm/IR/GlobalIndirectSymbol.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instruction.h"
@@ -66,7 +66,7 @@ struct WorklistEntry {
enum EntryKind {
MapGlobalInit,
MapAppendingVar,
- MapGlobalAliasee,
+ MapGlobalIndirectSymbol,
RemapFunction
};
struct GVInitTy {
@@ -77,9 +77,9 @@ struct WorklistEntry {
GlobalVariable *GV;
Constant *InitPrefix;
};
- struct GlobalAliaseeTy {
- GlobalAlias *GA;
- Constant *Aliasee;
+ struct GlobalIndirectSymbolTy {
+ GlobalIndirectSymbol *GIS;
+ Constant *Target;
};
unsigned Kind : 2;
@@ -89,7 +89,7 @@ struct WorklistEntry {
union {
GVInitTy GVInit;
AppendingGVTy AppendingGV;
- GlobalAliaseeTy GlobalAliasee;
+ GlobalIndirectSymbolTy GlobalIndirectSymbol;
Function *RemapF;
} Data;
};
@@ -161,8 +161,8 @@ public:
bool IsOldCtorDtor,
ArrayRef<Constant *> NewMembers,
unsigned MCID);
- void scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee,
- unsigned MCID);
+ void scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, Constant &Target,
+ unsigned MCID);
void scheduleRemapFunction(Function &F, unsigned MCID);
void flush();
@@ -172,7 +172,7 @@ private:
void mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix,
bool IsOldCtorDtor,
ArrayRef<Constant *> NewMembers);
- void mapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee);
+ void mapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, Constant &Target);
void remapFunction(Function &F, ValueToValueMapTy &VM);
ValueToValueMapTy &getVM() { return *MCs[CurrentMCID].VM; }
@@ -774,20 +774,6 @@ Metadata *MDNodeMapper::mapTopLevelUniquedNode(const MDNode &FirstN) {
return *getMappedOp(&FirstN);
}
-namespace {
-
-struct MapMetadataDisabler {
- ValueToValueMapTy &VM;
-
- MapMetadataDisabler(ValueToValueMapTy &VM) : VM(VM) {
- VM.disableMapMetadata();
- }
-
- ~MapMetadataDisabler() { VM.enableMapMetadata(); }
-};
-
-} // end anonymous namespace
-
Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) {
// If the value already exists in the map, use it.
if (Optional<Metadata *> NewMD = getVM().getMappedMD(MD))
@@ -802,9 +788,6 @@ Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) {
return const_cast<Metadata *>(MD);
if (auto *CMD = dyn_cast<ConstantAsMetadata>(MD)) {
- // Disallow recursion into metadata mapping through mapValue.
- MapMetadataDisabler MMD(getVM());
-
// Don't memoize ConstantAsMetadata. Instead of lasting until the
// LLVMContext is destroyed, they can be deleted when the GlobalValue they
// reference is destructed. These aren't super common, so the extra
@@ -846,9 +829,9 @@ void Mapper::flush() {
AppendingInits.resize(PrefixSize);
break;
}
- case WorklistEntry::MapGlobalAliasee:
- E.Data.GlobalAliasee.GA->setAliasee(
- mapConstant(E.Data.GlobalAliasee.Aliasee));
+ case WorklistEntry::MapGlobalIndirectSymbol:
+ E.Data.GlobalIndirectSymbol.GIS->setIndirectSymbol(
+ mapConstant(E.Data.GlobalIndirectSymbol.Target));
break;
case WorklistEntry::RemapFunction:
remapFunction(*E.Data.RemapF);
@@ -1041,16 +1024,16 @@ void Mapper::scheduleMapAppendingVariable(GlobalVariable &GV,
AppendingInits.append(NewMembers.begin(), NewMembers.end());
}
-void Mapper::scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee,
- unsigned MCID) {
- assert(AlreadyScheduled.insert(&GA).second && "Should not reschedule");
+void Mapper::scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS,
+ Constant &Target, unsigned MCID) {
+ assert(AlreadyScheduled.insert(&GIS).second && "Should not reschedule");
assert(MCID < MCs.size() && "Invalid mapping context");
WorklistEntry WE;
- WE.Kind = WorklistEntry::MapGlobalAliasee;
+ WE.Kind = WorklistEntry::MapGlobalIndirectSymbol;
WE.MCID = MCID;
- WE.Data.GlobalAliasee.GA = &GA;
- WE.Data.GlobalAliasee.Aliasee = &Aliasee;
+ WE.Data.GlobalIndirectSymbol.GIS = &GIS;
+ WE.Data.GlobalIndirectSymbol.Target = &Target;
Worklist.push_back(WE);
}
@@ -1147,9 +1130,10 @@ void ValueMapper::scheduleMapAppendingVariable(GlobalVariable &GV,
GV, InitPrefix, IsOldCtorDtor, NewMembers, MCID);
}
-void ValueMapper::scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee,
- unsigned MCID) {
- getAsMapper(pImpl)->scheduleMapGlobalAliasee(GA, Aliasee, MCID);
+void ValueMapper::scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS,
+ Constant &Target,
+ unsigned MCID) {
+ getAsMapper(pImpl)->scheduleMapGlobalIndirectSymbol(GIS, Target, MCID);
}
void ValueMapper::scheduleRemapFunction(Function &F, unsigned MCID) {
diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index 4273080ddd91..f44976c723ec 100644
--- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -147,7 +147,7 @@ private:
static const unsigned MaxDepth = 3;
bool isConsecutiveAccess(Value *A, Value *B);
- bool areConsecutivePointers(Value *PtrA, Value *PtrB, const APInt &PtrDelta,
+ bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta,
unsigned Depth = 0) const;
bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta,
unsigned Depth) const;
@@ -336,14 +336,29 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
}
bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB,
- const APInt &PtrDelta,
- unsigned Depth) const {
+ APInt PtrDelta, unsigned Depth) const {
unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType());
APInt OffsetA(PtrBitWidth, 0);
APInt OffsetB(PtrBitWidth, 0);
PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
+ unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
+
+ if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
+ return false;
+
+  // In case we have to shrink the pointer,
+  // stripAndAccumulateInBoundsConstantOffsets should properly handle a
+  // possible overflow, and the value should fit into the smallest data type
+  // used in the cast/gep chain.
+ assert(OffsetA.getMinSignedBits() <= NewPtrBitWidth &&
+ OffsetB.getMinSignedBits() <= NewPtrBitWidth);
+
+ OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
+ OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
+ PtrDelta = PtrDelta.sextOrTrunc(NewPtrBitWidth);
+
APInt OffsetDelta = OffsetB - OffsetA;
// Check if they are based on the same pointer. That makes the offsets
@@ -650,7 +665,7 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
    // We can ignore the alias if we have a load/store pair and the load
// is known to be invariant. The load cannot be clobbered by the store.
auto IsInvariantLoad = [](const LoadInst *LI) -> bool {
- return LI->getMetadata(LLVMContext::MD_invariant_load);
+ return LI->hasMetadata(LLVMContext::MD_invariant_load);
};
// We can ignore the alias as long as the load comes before the store,
@@ -1077,7 +1092,7 @@ bool Vectorizer::vectorizeLoadChain(
LoadInst *L0 = cast<LoadInst>(Chain[0]);
// If the vector has an int element, default to int for the whole load.
- Type *LoadTy;
+ Type *LoadTy = nullptr;
for (const auto &V : Chain) {
LoadTy = cast<LoadInst>(V)->getType();
if (LoadTy->isIntOrIntVectorTy())
@@ -1089,6 +1104,7 @@ bool Vectorizer::vectorizeLoadChain(
break;
}
}
+ assert(LoadTy && "Can't determine LoadInst type from chain");
unsigned Sz = DL.getTypeSizeInBits(LoadTy);
unsigned AS = L0->getPointerAddressSpace();
diff --git a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 6ef8dc2d3cd7..f43842be5357 100644
--- a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -13,7 +13,10 @@
// pass. It should be easy to create an analysis pass around it if there
// is a need (but D45420 needs to happen first).
//
+#include "llvm/Transforms/Vectorize/LoopVectorize.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IntrinsicInst.h"
@@ -47,38 +50,6 @@ static const unsigned MaxInterleaveFactor = 16;
namespace llvm {
-#ifndef NDEBUG
-static void debugVectorizationFailure(const StringRef DebugMsg,
- Instruction *I) {
- dbgs() << "LV: Not vectorizing: " << DebugMsg;
- if (I != nullptr)
- dbgs() << " " << *I;
- else
- dbgs() << '.';
- dbgs() << '\n';
-}
-#endif
-
-OptimizationRemarkAnalysis createLVMissedAnalysis(const char *PassName,
- StringRef RemarkName,
- Loop *TheLoop,
- Instruction *I) {
- Value *CodeRegion = TheLoop->getHeader();
- DebugLoc DL = TheLoop->getStartLoc();
-
- if (I) {
- CodeRegion = I->getParent();
- // If there is no debug location attached to the instruction, revert back to
- // using the loop's.
- if (I->getDebugLoc())
- DL = I->getDebugLoc();
- }
-
- OptimizationRemarkAnalysis R(PassName, RemarkName, DL, CodeRegion);
- R << "loop not vectorized: ";
- return R;
-}
-
bool LoopVectorizeHints::Hint::validate(unsigned Val) {
switch (Kind) {
case HK_WIDTH:
@@ -88,6 +59,7 @@ bool LoopVectorizeHints::Hint::validate(unsigned Val) {
case HK_FORCE:
return (Val <= 1);
case HK_ISVECTORIZED:
+ case HK_PREDICATE:
return (Val == 0 || Val == 1);
}
return false;
@@ -99,7 +71,9 @@ LoopVectorizeHints::LoopVectorizeHints(const Loop *L,
: Width("vectorize.width", VectorizerParams::VectorizationFactor, HK_WIDTH),
Interleave("interleave.count", InterleaveOnlyWhenForced, HK_UNROLL),
Force("vectorize.enable", FK_Undefined, HK_FORCE),
- IsVectorized("isvectorized", 0, HK_ISVECTORIZED), TheLoop(L), ORE(ORE) {
+ IsVectorized("isvectorized", 0, HK_ISVECTORIZED),
+ Predicate("vectorize.predicate.enable", 0, HK_PREDICATE), TheLoop(L),
+ ORE(ORE) {
// Populate values with existing loop metadata.
getHintsFromMetadata();
@@ -250,7 +224,7 @@ void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) {
return;
unsigned Val = C->getZExtValue();
- Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized};
+ Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized, &Predicate};
for (auto H : Hints) {
if (Name == H->Name) {
if (H->validate(Val))
@@ -435,7 +409,8 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
const ValueToValueMap &Strides =
getSymbolicStrides() ? *getSymbolicStrides() : ValueToValueMap();
- int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, true, false);
+ bool CanAddPredicate = !TheLoop->getHeader()->getParent()->hasOptSize();
+ int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, CanAddPredicate, false);
if (Stride == 1 || Stride == -1)
return Stride;
return 0;
@@ -445,14 +420,6 @@ bool LoopVectorizationLegality::isUniform(Value *V) {
return LAI->isUniform(V);
}
-void LoopVectorizationLegality::reportVectorizationFailure(
- const StringRef DebugMsg, const StringRef OREMsg,
- const StringRef ORETag, Instruction *I) const {
- LLVM_DEBUG(debugVectorizationFailure(DebugMsg, I));
- ORE->emit(createLVMissedAnalysis(Hints->vectorizeAnalysisPassName(),
- ORETag, TheLoop, I) << OREMsg);
-}
-
bool LoopVectorizationLegality::canVectorizeOuterLoop() {
assert(!TheLoop->empty() && "We are not vectorizing an outer loop.");
// Store the result and return it at the end instead of exiting early, in case
@@ -467,7 +434,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() {
if (!Br) {
reportVectorizationFailure("Unsupported basic block terminator",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -486,7 +453,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() {
!LI->isLoopHeader(Br->getSuccessor(1))) {
reportVectorizationFailure("Unsupported conditional branch",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -500,7 +467,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() {
TheLoop /*context outer loop*/)) {
reportVectorizationFailure("Outer loop contains divergent loops",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -511,7 +478,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() {
if (!setupOuterLoopInductions()) {
reportVectorizationFailure("Unsupported outer loop Phi(s)",
"Unsupported outer loop Phi(s)",
- "UnsupportedPhi");
+ "UnsupportedPhi", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -618,7 +585,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
!PhiTy->isPointerTy()) {
reportVectorizationFailure("Found a non-int non-pointer PHI",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
return false;
}
@@ -631,6 +598,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// Unsafe cyclic dependencies with header phis are identified during
// legalization for reduction, induction and first order
// recurrences.
+ AllowedExit.insert(&I);
continue;
}
@@ -638,7 +606,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
if (Phi->getNumIncomingValues() != 2) {
reportVectorizationFailure("Found an invalid PHI",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood", Phi);
+ "CFGNotUnderstood", ORE, TheLoop, Phi);
return false;
}
@@ -690,7 +658,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
reportVectorizationFailure("Found an unidentified PHI",
"value that could not be identified as "
"reduction is used outside the loop",
- "NonReductionValueUsedOutsideLoop", Phi);
+ "NonReductionValueUsedOutsideLoop", ORE, TheLoop, Phi);
return false;
} // end of PHI handling
@@ -721,11 +689,11 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
"library call cannot be vectorized. "
"Try compiling with -fno-math-errno, -ffast-math, "
"or similar flags",
- "CantVectorizeLibcall", CI);
+ "CantVectorizeLibcall", ORE, TheLoop, CI);
} else {
reportVectorizationFailure("Found a non-intrinsic callsite",
"call instruction cannot be vectorized",
- "CantVectorizeLibcall", CI);
+ "CantVectorizeLibcall", ORE, TheLoop, CI);
}
return false;
}
@@ -740,7 +708,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(i)), TheLoop)) {
reportVectorizationFailure("Found unvectorizable intrinsic",
"intrinsic instruction cannot be vectorized",
- "CantVectorizeIntrinsic", CI);
+ "CantVectorizeIntrinsic", ORE, TheLoop, CI);
return false;
}
}
@@ -753,7 +721,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
isa<ExtractElementInst>(I)) {
reportVectorizationFailure("Found unvectorizable type",
"instruction return type cannot be vectorized",
- "CantVectorizeInstructionReturnType", &I);
+ "CantVectorizeInstructionReturnType", ORE, TheLoop, &I);
return false;
}
@@ -763,7 +731,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
if (!VectorType::isValidElementType(T)) {
reportVectorizationFailure("Store instruction cannot be vectorized",
"store instruction cannot be vectorized",
- "CantVectorizeStore", ST);
+ "CantVectorizeStore", ORE, TheLoop, ST);
return false;
}
@@ -773,12 +741,13 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// Arbitrarily try a vector of 2 elements.
Type *VecTy = VectorType::get(T, /*NumElements=*/2);
assert(VecTy && "did not find vectorized version of stored type");
- unsigned Alignment = getLoadStoreAlignment(ST);
- if (!TTI->isLegalNTStore(VecTy, Alignment)) {
+ const MaybeAlign Alignment = getLoadStoreAlignment(ST);
+ assert(Alignment && "Alignment should be set");
+ if (!TTI->isLegalNTStore(VecTy, *Alignment)) {
reportVectorizationFailure(
"nontemporal store instruction cannot be vectorized",
"nontemporal store instruction cannot be vectorized",
- "CantVectorizeNontemporalStore", ST);
+ "CantVectorizeNontemporalStore", ORE, TheLoop, ST);
return false;
}
}
@@ -789,12 +758,13 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// supported on the target (arbitrarily try a vector of 2 elements).
Type *VecTy = VectorType::get(I.getType(), /*NumElements=*/2);
assert(VecTy && "did not find vectorized version of load type");
- unsigned Alignment = getLoadStoreAlignment(LD);
- if (!TTI->isLegalNTLoad(VecTy, Alignment)) {
+ const MaybeAlign Alignment = getLoadStoreAlignment(LD);
+ assert(Alignment && "Alignment should be set");
+ if (!TTI->isLegalNTLoad(VecTy, *Alignment)) {
reportVectorizationFailure(
"nontemporal load instruction cannot be vectorized",
"nontemporal load instruction cannot be vectorized",
- "CantVectorizeNontemporalLoad", LD);
+ "CantVectorizeNontemporalLoad", ORE, TheLoop, LD);
return false;
}
}
@@ -823,7 +793,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
}
reportVectorizationFailure("Value cannot be used outside the loop",
"value cannot be used outside the loop",
- "ValueUsedOutsideLoop", &I);
+ "ValueUsedOutsideLoop", ORE, TheLoop, &I);
return false;
}
} // next instr.
@@ -833,12 +803,12 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
if (Inductions.empty()) {
reportVectorizationFailure("Did not find one integer induction var",
"loop induction variable could not be identified",
- "NoInductionVariable");
+ "NoInductionVariable", ORE, TheLoop);
return false;
} else if (!WidestIndTy) {
reportVectorizationFailure("Did not find one integer induction var",
"integer loop induction variable could not be identified",
- "NoIntegerInductionVariable");
+ "NoIntegerInductionVariable", ORE, TheLoop);
return false;
} else {
LLVM_DEBUG(dbgs() << "LV: Did not find one integer induction var.\n");
@@ -869,7 +839,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
if (LAI->hasDependenceInvolvingLoopInvariantAddress()) {
reportVectorizationFailure("Stores to a uniform address",
"write to a loop invariant address could not be vectorized",
- "CantVectorizeStoreToLoopInvariantAddress");
+ "CantVectorizeStoreToLoopInvariantAddress", ORE, TheLoop);
return false;
}
Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
@@ -905,7 +875,7 @@ bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) {
}
bool LoopVectorizationLegality::blockCanBePredicated(
- BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs) {
+ BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs, bool PreserveGuards) {
const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel();
for (Instruction &I : *BB) {
@@ -924,7 +894,7 @@ bool LoopVectorizationLegality::blockCanBePredicated(
// !llvm.mem.parallel_loop_access implies if-conversion safety.
// Otherwise, record that the load needs (real or emulated) masking
// and let the cost model decide.
- if (!IsAnnotatedParallel)
+ if (!IsAnnotatedParallel || PreserveGuards)
MaskedOp.insert(LI);
continue;
}
@@ -953,23 +923,41 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
if (!EnableIfConversion) {
reportVectorizationFailure("If-conversion is disabled",
"if-conversion is disabled",
- "IfConversionDisabled");
+ "IfConversionDisabled",
+ ORE, TheLoop);
return false;
}
assert(TheLoop->getNumBlocks() > 1 && "Single block loops are vectorizable");
- // A list of pointers that we can safely read and write to.
+ // A list of pointers which are known to be dereferenceable within scope of
+ // the loop body for each iteration of the loop which executes. That is,
+ // the memory pointed to can be dereferenced (with the access size implied by
+ // the value's type) unconditionally within the loop header without
+ // introducing a new fault.
SmallPtrSet<Value *, 8> SafePointes;
// Collect safe addresses.
for (BasicBlock *BB : TheLoop->blocks()) {
- if (blockNeedsPredication(BB))
+ if (!blockNeedsPredication(BB)) {
+ for (Instruction &I : *BB)
+ if (auto *Ptr = getLoadStorePointerOperand(&I))
+ SafePointes.insert(Ptr);
continue;
+ }
- for (Instruction &I : *BB)
- if (auto *Ptr = getLoadStorePointerOperand(&I))
- SafePointes.insert(Ptr);
+  // For a block which requires predication, an address may be safe to access
+ // in the loop w/o predication if we can prove dereferenceability facts
+ // sufficient to ensure it'll never fault within the loop. For the moment,
+ // we restrict this to loads; stores are more complicated due to
+ // concurrency restrictions.
+ ScalarEvolution &SE = *PSE.getSE();
+ for (Instruction &I : *BB) {
+ LoadInst *LI = dyn_cast<LoadInst>(&I);
+ if (LI && !mustSuppressSpeculation(*LI) &&
+ isDereferenceableAndAlignedInLoop(LI, TheLoop, SE, *DT))
+ SafePointes.insert(LI->getPointerOperand());
+ }
}
// Collect the blocks that need predication.
@@ -979,7 +967,8 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
if (!isa<BranchInst>(BB->getTerminator())) {
reportVectorizationFailure("Loop contains a switch statement",
"loop contains a switch statement",
- "LoopContainsSwitch", BB->getTerminator());
+ "LoopContainsSwitch", ORE, TheLoop,
+ BB->getTerminator());
return false;
}
@@ -989,14 +978,16 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
reportVectorizationFailure(
"Control flow cannot be substituted for a select",
"control flow cannot be substituted for a select",
- "NoCFGForSelect", BB->getTerminator());
+ "NoCFGForSelect", ORE, TheLoop,
+ BB->getTerminator());
return false;
}
} else if (BB != Header && !canIfConvertPHINodes(BB)) {
reportVectorizationFailure(
"Control flow cannot be substituted for a select",
"control flow cannot be substituted for a select",
- "NoCFGForSelect", BB->getTerminator());
+ "NoCFGForSelect", ORE, TheLoop,
+ BB->getTerminator());
return false;
}
}
@@ -1026,7 +1017,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
if (!Lp->getLoopPreheader()) {
reportVectorizationFailure("Loop doesn't have a legal pre-header",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -1037,7 +1028,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
if (Lp->getNumBackEdges() != 1) {
reportVectorizationFailure("The loop must have a single backedge",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -1048,7 +1039,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
if (!Lp->getExitingBlock()) {
reportVectorizationFailure("The loop must have an exiting block",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -1061,7 +1052,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
if (Lp->getExitingBlock() != Lp->getLoopLatch()) {
reportVectorizationFailure("The exiting block is not the loop latch",
"loop control flow is not understood by vectorizer",
- "CFGNotUnderstood");
+ "CFGNotUnderstood", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -1124,7 +1115,8 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) {
if (!canVectorizeOuterLoop()) {
reportVectorizationFailure("Unsupported outer loop",
"unsupported outer loop",
- "UnsupportedOuterLoop");
+ "UnsupportedOuterLoop",
+ ORE, TheLoop);
// TODO: Implement DoExtraAnalysis when subsequent legal checks support
// outer loops.
return false;
@@ -1176,7 +1168,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) {
if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) {
reportVectorizationFailure("Too many SCEV checks needed",
"Too many SCEV assumptions need to be made and checked at runtime",
- "TooManySCEVRunTimeChecks");
+ "TooManySCEVRunTimeChecks", ORE, TheLoop);
if (DoExtraAnalysis)
Result = false;
else
@@ -1190,7 +1182,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) {
return Result;
}
-bool LoopVectorizationLegality::canFoldTailByMasking() {
+bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n");
@@ -1199,22 +1191,21 @@ bool LoopVectorizationLegality::canFoldTailByMasking() {
"No primary induction, cannot fold tail by masking",
"Missing a primary induction variable in the loop, which is "
"needed in order to fold tail by masking as required.",
- "NoPrimaryInduction");
+ "NoPrimaryInduction", ORE, TheLoop);
return false;
}
- // TODO: handle reductions when tail is folded by masking.
- if (!Reductions.empty()) {
- reportVectorizationFailure(
- "Loop has reductions, cannot fold tail by masking",
- "Cannot fold tail by masking in the presence of reductions.",
- "ReductionFoldingTailByMasking");
- return false;
- }
+ SmallPtrSet<const Value *, 8> ReductionLiveOuts;
- // TODO: handle outside users when tail is folded by masking.
+ for (auto &Reduction : *getReductionVars())
+ ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
+
+ // TODO: handle non-reduction outside users when tail is folded by masking.
for (auto *AE : AllowedExit) {
- // Check that all users of allowed exit values are inside the loop.
+ // Check that all users of allowed exit values are inside the loop or
+ // are the live-out of a reduction.
+ if (ReductionLiveOuts.count(AE))
+ continue;
for (User *U : AE->users()) {
Instruction *UI = cast<Instruction>(U);
if (TheLoop->contains(UI))
@@ -1222,7 +1213,7 @@ bool LoopVectorizationLegality::canFoldTailByMasking() {
reportVectorizationFailure(
"Cannot fold tail by masking, loop has an outside user for",
"Cannot fold tail by masking in the presence of live outs.",
- "LiveOutFoldingTailByMasking", UI);
+ "LiveOutFoldingTailByMasking", ORE, TheLoop, UI);
return false;
}
}
@@ -1233,11 +1224,12 @@ bool LoopVectorizationLegality::canFoldTailByMasking() {
// Check and mark all blocks for predication, including those that ordinarily
// do not need predication such as the header block.
for (BasicBlock *BB : TheLoop->blocks()) {
- if (!blockCanBePredicated(BB, SafePointers)) {
+    if (!blockCanBePredicated(BB, SafePointers, /* PreserveGuards= */ true)) {
reportVectorizationFailure(
"Cannot fold tail by masking as required",
"control flow cannot be substituted for a select",
- "NoCFGForSelect", BB->getTerminator());
+ "NoCFGForSelect", ORE, TheLoop,
+ BB->getTerminator());
return false;
}
}
diff --git a/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 97077cce83e3..a5e85f27fabf 100644
--- a/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -228,11 +228,11 @@ public:
/// Plan how to best vectorize, return the best VF and its cost, or None if
/// vectorization and interleaving should be avoided up front.
- Optional<VectorizationFactor> plan(bool OptForSize, unsigned UserVF);
+ Optional<VectorizationFactor> plan(unsigned UserVF);
/// Use the VPlan-native path to plan how to best vectorize, return the best
/// VF and its cost.
- VectorizationFactor planInVPlanNativePath(bool OptForSize, unsigned UserVF);
+ VectorizationFactor planInVPlanNativePath(unsigned UserVF);
/// Finalize the best decision and dispose of all other VPlans.
void setBestPlan(unsigned VF, unsigned UF);
diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp
index 46265e3f3e13..8f0bf70f873c 100644
--- a/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -177,6 +177,14 @@ static cl::opt<unsigned> TinyTripCountVectorThreshold(
"value are vectorized only if no scalar iteration overheads "
"are incurred."));
+// Indicates that an epilogue is undesired and that predication is preferred.
+// This means that the vectorizer will try to fold the loop-tail (epilogue)
+// into the loop and predicate the loop body accordingly.
+static cl::opt<bool> PreferPredicateOverEpilog(
+ "prefer-predicate-over-epilog", cl::init(false), cl::Hidden,
+ cl::desc("Indicate that an epilogue is undesired, predication should be "
+ "used instead."));
+
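// Besides this global flag, the same preference can be requested per loop via
// the new "vectorize.predicate.enable" hint (i.e. the
// llvm.loop.vectorize.predicate.enable loop metadata). Assuming clang's
// matching loop pragma, vectorize_predicate, a source-level request would
// look like:
void saxpy(int n, float a, const float *x, float *y) {
#pragma clang loop vectorize_predicate(enable)
  for (int i = 0; i < n; ++i)
    y[i] = a * x[i] + y[i];
}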
static cl::opt<bool> MaximizeBandwidth(
"vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden,
cl::desc("Maximize bandwidth when selecting vectorization factor which "
@@ -347,6 +355,29 @@ static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) {
: ConstantFP::get(Ty, C);
}
+/// Returns "best known" trip count for the specified loop \p L as defined by
+/// the following procedure:
+/// 1) Returns exact trip count if it is known.
+/// 2) Returns expected trip count according to profile data if any.
+/// 3) Returns upper bound estimate if it is known.
+/// 4) Returns None if all of the above failed.
+static Optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE, Loop *L) {
+ // Check if exact trip count is known.
+ if (unsigned ExpectedTC = SE.getSmallConstantTripCount(L))
+ return ExpectedTC;
+
+ // Check if there is an expected trip count available from profile data.
+ if (LoopVectorizeWithBlockFrequency)
+ if (auto EstimatedTC = getLoopEstimatedTripCount(L))
+ return EstimatedTC;
+
+ // Check if upper bound estimate is known.
+ if (unsigned ExpectedTC = SE.getSmallConstantMaxTripCount(L))
+ return ExpectedTC;
+
+ return None;
+}
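// Sketch of how a caller might use this helper (the real call sites are
// elsewhere in this file; hasTinyTripCount is an illustrative name, not part
// of the patch):
static bool hasTinyTripCount(ScalarEvolution &SE, Loop *L) {
  if (Optional<unsigned> ExpectedTC = getSmallBestKnownTC(SE, L))
    return *ExpectedTC < TinyTripCountVectorThreshold;
  return false; // trip count unknown: do not treat the loop as tiny
}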
+
namespace llvm {
/// InnerLoopVectorizer vectorizes loops which contain only one basic
@@ -795,6 +826,59 @@ void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr)
B.SetCurrentDebugLocation(DebugLoc());
}
+/// Write a record \p DebugMsg about vectorization failure to the debug
+/// output stream. If \p I is passed, it is an instruction that prevents
+/// vectorization.
+#ifndef NDEBUG
+static void debugVectorizationFailure(const StringRef DebugMsg,
+ Instruction *I) {
+ dbgs() << "LV: Not vectorizing: " << DebugMsg;
+ if (I != nullptr)
+ dbgs() << " " << *I;
+ else
+ dbgs() << '.';
+ dbgs() << '\n';
+}
+#endif
+
+/// Create an analysis remark that explains why vectorization failed
+///
+/// \p PassName is the name of the pass (e.g. can be AlwaysPrint). \p
+/// RemarkName is the identifier for the remark. If \p I is passed it is an
+/// instruction that prevents vectorization. Otherwise \p TheLoop is used for
+/// the location of the remark. \return the remark object that can be
+/// streamed to.
+static OptimizationRemarkAnalysis createLVAnalysis(const char *PassName,
+ StringRef RemarkName, Loop *TheLoop, Instruction *I) {
+ Value *CodeRegion = TheLoop->getHeader();
+ DebugLoc DL = TheLoop->getStartLoc();
+
+ if (I) {
+ CodeRegion = I->getParent();
+ // If there is no debug location attached to the instruction, revert back to
+ // using the loop's.
+ if (I->getDebugLoc())
+ DL = I->getDebugLoc();
+ }
+
+ OptimizationRemarkAnalysis R(PassName, RemarkName, DL, CodeRegion);
+ R << "loop not vectorized: ";
+ return R;
+}
+
+namespace llvm {
+
+void reportVectorizationFailure(const StringRef DebugMsg,
+ const StringRef OREMsg, const StringRef ORETag,
+ OptimizationRemarkEmitter *ORE, Loop *TheLoop, Instruction *I) {
+ LLVM_DEBUG(debugVectorizationFailure(DebugMsg, I));
+ LoopVectorizeHints Hints(TheLoop, true /* doesn't matter */, *ORE);
+ ORE->emit(createLVAnalysis(Hints.vectorizeAnalysisPassName(),
+ ORETag, TheLoop, I) << OREMsg);
+}
+
+} // end namespace llvm
+
#ifndef NDEBUG
/// \return string containing a file name and a line # for the given loop.
static std::string getDebugLocString(const Loop *L) {
@@ -836,6 +920,26 @@ void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To,
namespace llvm {
+// Loop vectorization cost-model hints how the scalar epilogue loop should be
+// lowered.
+enum ScalarEpilogueLowering {
+
+ // The default: allowing scalar epilogues.
+ CM_ScalarEpilogueAllowed,
+
+ // Vectorization with OptForSize: don't allow epilogues.
+ CM_ScalarEpilogueNotAllowedOptSize,
+
+  // A special case of vectorization with OptForSize: loops with a very small
+ // trip count are considered for vectorization under OptForSize, thereby
+ // making sure the cost of their loop body is dominant, free of runtime
+ // guards and scalar iteration overheads.
+ CM_ScalarEpilogueNotAllowedLowTripLoop,
+
+ // Loop hint predicate indicating an epilogue is undesired.
+ CM_ScalarEpilogueNotNeededUsePredicate
+};
+
/// LoopVectorizationCostModel - estimates the expected speedups due to
/// vectorization.
/// In many cases vectorization is not profitable. This can happen because of
@@ -845,20 +949,26 @@ namespace llvm {
/// different operations.
class LoopVectorizationCostModel {
public:
- LoopVectorizationCostModel(Loop *L, PredicatedScalarEvolution &PSE,
- LoopInfo *LI, LoopVectorizationLegality *Legal,
+ LoopVectorizationCostModel(ScalarEpilogueLowering SEL, Loop *L,
+ PredicatedScalarEvolution &PSE, LoopInfo *LI,
+ LoopVectorizationLegality *Legal,
const TargetTransformInfo &TTI,
const TargetLibraryInfo *TLI, DemandedBits *DB,
AssumptionCache *AC,
OptimizationRemarkEmitter *ORE, const Function *F,
const LoopVectorizeHints *Hints,
InterleavedAccessInfo &IAI)
- : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
- AC(AC), ORE(ORE), TheFunction(F), Hints(Hints), InterleaveInfo(IAI) {}
+ : ScalarEpilogueStatus(SEL), TheLoop(L), PSE(PSE), LI(LI), Legal(Legal),
+ TTI(TTI), TLI(TLI), DB(DB), AC(AC), ORE(ORE), TheFunction(F),
+ Hints(Hints), InterleaveInfo(IAI) {}
/// \return An upper bound for the vectorization factor, or None if
/// vectorization and interleaving should be avoided up front.
- Optional<unsigned> computeMaxVF(bool OptForSize);
+ Optional<unsigned> computeMaxVF();
+
+ /// \return True if runtime checks are required for vectorization, and false
+ /// otherwise.
+ bool runtimeChecksRequired();
/// \return The most profitable vectorization factor and the cost of that VF.
/// This method checks every power of two up to MaxVF. If UserVF is not ZERO
@@ -881,8 +991,7 @@ public:
/// If interleave count has been specified by metadata it will be returned.
/// Otherwise, the interleave count is computed and returned. VF and LoopCost
/// are the selected vectorization factor and the cost of the selected VF.
- unsigned selectInterleaveCount(bool OptForSize, unsigned VF,
- unsigned LoopCost);
+ unsigned selectInterleaveCount(unsigned VF, unsigned LoopCost);
/// Memory access instruction may be vectorized in more than one way.
/// Form of instruction after vectorization depends on cost.
@@ -897,10 +1006,11 @@ public:
/// of a loop.
struct RegisterUsage {
/// Holds the number of loop invariant values that are used in the loop.
- unsigned LoopInvariantRegs;
-
+ /// The key is ClassID of target-provided register class.
+ SmallMapVector<unsigned, unsigned, 4> LoopInvariantRegs;
/// Holds the maximum number of concurrent live intervals in the loop.
- unsigned MaxLocalUsers;
+ /// The key is ClassID of target-provided register class.
+ SmallMapVector<unsigned, unsigned, 4> MaxLocalUsers;
};
/// \return Returns information about the register usages of the loop for the
@@ -1080,14 +1190,16 @@ public:
/// Returns true if the target machine supports masked store operation
/// for the given \p DataType and kind of access to \p Ptr.
- bool isLegalMaskedStore(Type *DataType, Value *Ptr) {
- return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedStore(DataType);
+ bool isLegalMaskedStore(Type *DataType, Value *Ptr, MaybeAlign Alignment) {
+ return Legal->isConsecutivePtr(Ptr) &&
+ TTI.isLegalMaskedStore(DataType, Alignment);
}
/// Returns true if the target machine supports masked load operation
/// for the given \p DataType and kind of access to \p Ptr.
- bool isLegalMaskedLoad(Type *DataType, Value *Ptr) {
- return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedLoad(DataType);
+ bool isLegalMaskedLoad(Type *DataType, Value *Ptr, MaybeAlign Alignment) {
+ return Legal->isConsecutivePtr(Ptr) &&
+ TTI.isLegalMaskedLoad(DataType, Alignment);
}
/// Returns true if the target machine supports masked scatter operation
@@ -1157,11 +1269,14 @@ public:
/// to handle accesses with gaps, and there is nothing preventing us from
/// creating a scalar epilogue.
bool requiresScalarEpilogue() const {
- return IsScalarEpilogueAllowed && InterleaveInfo.requiresScalarEpilogue();
+ return isScalarEpilogueAllowed() && InterleaveInfo.requiresScalarEpilogue();
}
- /// Returns true if a scalar epilogue is not allowed due to optsize.
- bool isScalarEpilogueAllowed() const { return IsScalarEpilogueAllowed; }
+  /// Returns true if a scalar epilogue is allowed, i.e. not disabled due to
+  /// optsize or a loop hint annotation.
+ bool isScalarEpilogueAllowed() const {
+ return ScalarEpilogueStatus == CM_ScalarEpilogueAllowed;
+ }
/// Returns true if all loop blocks should be masked to fold tail loop.
bool foldTailByMasking() const { return FoldTailByMasking; }
@@ -1187,7 +1302,7 @@ private:
/// \return An upper bound for the vectorization factor, larger than zero.
/// One is returned if vectorization should best be avoided due to cost.
- unsigned computeFeasibleMaxVF(bool OptForSize, unsigned ConstTripCount);
+ unsigned computeFeasibleMaxVF(unsigned ConstTripCount);
/// The vectorization cost is a combination of the cost itself and a boolean
/// indicating whether any of the contributing operations will actually
@@ -1246,15 +1361,6 @@ private:
/// should be used.
bool useEmulatedMaskMemRefHack(Instruction *I);
- /// Create an analysis remark that explains why vectorization failed
- ///
- /// \p RemarkName is the identifier for the remark. \return the remark object
- /// that can be streamed to.
- OptimizationRemarkAnalysis createMissedAnalysis(StringRef RemarkName) {
- return createLVMissedAnalysis(Hints->vectorizeAnalysisPassName(),
- RemarkName, TheLoop);
- }
-
/// Map of scalar integer values to the smallest bitwidth they can be legally
/// represented as. The vector equivalents of these values should be truncated
/// to this type.
@@ -1270,13 +1376,13 @@ private:
SmallPtrSet<BasicBlock *, 4> PredicatedBBsAfterVectorization;
/// Records whether it is allowed to have the original scalar loop execute at
- /// least once. This may be needed as a fallback loop in case runtime
+ /// least once. This may be needed as a fallback loop in case runtime
/// aliasing/dependence checks fail, or to handle the tail/remainder
/// iterations when the trip count is unknown or doesn't divide by the VF,
/// or as a peel-loop to handle gaps in interleave-groups.
/// Under optsize and when the trip count is very small we don't allow any
/// iterations to execute in the scalar loop.
- bool IsScalarEpilogueAllowed = true;
+ ScalarEpilogueLowering ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
/// All blocks of loop are to be masked to fold tail of scalar iterations.
bool FoldTailByMasking = false;
@@ -1496,7 +1602,7 @@ struct LoopVectorize : public FunctionPass {
auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto *BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI();
auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- auto *TLI = TLIP ? &TLIP->getTLI() : nullptr;
+ auto *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
@@ -2253,12 +2359,11 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr,
Type *ScalarDataTy = getMemInstValueType(Instr);
Type *DataTy = VectorType::get(ScalarDataTy, VF);
Value *Ptr = getLoadStorePointerOperand(Instr);
- unsigned Alignment = getLoadStoreAlignment(Instr);
// An alignment of 0 means target abi alignment. We need to use the scalar's
// target abi alignment in such a case.
const DataLayout &DL = Instr->getModule()->getDataLayout();
- if (!Alignment)
- Alignment = DL.getABITypeAlignment(ScalarDataTy);
+ const Align Alignment =
+ DL.getValueOrABITypeAlignment(getLoadStoreAlignment(Instr), ScalarDataTy);
unsigned AddressSpace = getLoadStoreAddressSpace(Instr);
// Determine if the pointer operand of the access is either consecutive or
@@ -2322,8 +2427,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr,
if (CreateGatherScatter) {
Value *MaskPart = isMaskRequired ? Mask[Part] : nullptr;
Value *VectorGep = getOrCreateVectorValue(Ptr, Part);
- NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment,
- MaskPart);
+ NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep,
+ Alignment.value(), MaskPart);
} else {
if (Reverse) {
// If we store to reverse consecutive memory locations, then we need
@@ -2334,10 +2439,11 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr,
}
auto *VecPtr = CreateVecPtr(Part, Ptr);
if (isMaskRequired)
- NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment,
- Mask[Part]);
+ NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr,
+ Alignment.value(), Mask[Part]);
else
- NewSI = Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment);
+ NewSI =
+ Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment.value());
}
addMetadata(NewSI, SI);
}
@@ -2352,18 +2458,18 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr,
if (CreateGatherScatter) {
Value *MaskPart = isMaskRequired ? Mask[Part] : nullptr;
Value *VectorGep = getOrCreateVectorValue(Ptr, Part);
- NewLI = Builder.CreateMaskedGather(VectorGep, Alignment, MaskPart,
+ NewLI = Builder.CreateMaskedGather(VectorGep, Alignment.value(), MaskPart,
nullptr, "wide.masked.gather");
addMetadata(NewLI, LI);
} else {
auto *VecPtr = CreateVecPtr(Part, Ptr);
if (isMaskRequired)
- NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part],
+ NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment.value(), Mask[Part],
UndefValue::get(DataTy),
"wide.masked.load");
else
- NewLI =
- Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment, "wide.load");
+ NewLI = Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment.value(),
+ "wide.load");
// Add metadata to the load, but setVectorValue to the reverse shuffle.
addMetadata(NewLI, LI);
@@ -2615,8 +2721,9 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
if (C->isZero())
return;
- assert(!Cost->foldTailByMasking() &&
- "Cannot SCEV check stride or overflow when folding tail");
+ assert(!BB->getParent()->hasOptSize() &&
+ "Cannot SCEV check stride or overflow when optimizing for size");
+
// Create a new block containing the stride check.
BB->setName("vector.scevcheck");
auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph");
@@ -2649,7 +2756,20 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) {
if (!MemRuntimeCheck)
return;
- assert(!Cost->foldTailByMasking() && "Cannot check memory when folding tail");
+ if (BB->getParent()->hasOptSize()) {
+ assert(Cost->Hints->getForce() == LoopVectorizeHints::FK_Enabled &&
+ "Cannot emit memory checks when optimizing for size, unless forced "
+ "to vectorize.");
+ ORE->emit([&]() {
+ return OptimizationRemarkAnalysis(DEBUG_TYPE, "VectorizationCodeSize",
+ L->getStartLoc(), L->getHeader())
+ << "Code-size may be reduced by not forcing "
+ "vectorization, or by source-code modifications "
+ "eliminating the need for runtime checks "
+ "(e.g., adding 'restrict').";
+ });
+ }
+
// Create a new block containing the memory check.
BB->setName("vector.memcheck");
auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph");
@@ -2666,7 +2786,7 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) {
// We currently don't use LoopVersioning for the actual loop cloning but we
// still use it to add the noalias metadata.
- LVer = llvm::make_unique<LoopVersioning>(*Legal->getLAI(), OrigLoop, LI, DT,
+ LVer = std::make_unique<LoopVersioning>(*Legal->getLAI(), OrigLoop, LI, DT,
PSE.getSE());
LVer->prepareNoAliasMetadata();
}
@@ -3598,6 +3718,26 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
setDebugLocFromInst(Builder, LoopExitInst);
+ // If tail is folded by masking, the vector value to leave the loop should be
+ // a Select choosing between the vectorized LoopExitInst and vectorized Phi,
+ // instead of the former.
+ if (Cost->foldTailByMasking()) {
+ for (unsigned Part = 0; Part < UF; ++Part) {
+ Value *VecLoopExitInst =
+ VectorLoopValueMap.getVectorValue(LoopExitInst, Part);
+ Value *Sel = nullptr;
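+      // When folding the tail, the vectorized exit value feeds exactly one
+      // select per part (all other users are phis); find that select.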
+ for (User *U : VecLoopExitInst->users()) {
+ if (isa<SelectInst>(U)) {
+ assert(!Sel && "Reduction exit feeding two selects");
+ Sel = U;
+ } else
+ assert(isa<PHINode>(U) && "Reduction exit must feed Phi's or select");
+ }
+ assert(Sel && "Reduction exit feeds no select");
+ VectorLoopValueMap.resetVectorValue(LoopExitInst, Part, Sel);
+ }
+ }
+
// If the vector reduction can be performed in a smaller type, we truncate
// then extend the loop exit value to enable InstCombine to evaluate the
// entire expression in the smaller type.
@@ -4064,7 +4204,7 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) {
case Instruction::FCmp: {
// Widen compares. Generate vector compares.
bool FCmp = (I.getOpcode() == Instruction::FCmp);
- auto *Cmp = dyn_cast<CmpInst>(&I);
+ auto *Cmp = cast<CmpInst>(&I);
setDebugLocFromInst(Builder, Cmp);
for (unsigned Part = 0; Part < UF; ++Part) {
Value *A = getOrCreateVectorValue(Cmp->getOperand(0), Part);
@@ -4097,7 +4237,7 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) {
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast: {
- auto *CI = dyn_cast<CastInst>(&I);
+ auto *CI = cast<CastInst>(&I);
setDebugLocFromInst(Builder, CI);
/// Vectorize casts.
@@ -4421,9 +4561,10 @@ bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I, unsigne
"Widening decision should be ready at this moment");
return WideningDecision == CM_Scalarize;
}
+ const MaybeAlign Alignment = getLoadStoreAlignment(I);
return isa<LoadInst>(I) ?
- !(isLegalMaskedLoad(Ty, Ptr) || isLegalMaskedGather(Ty))
- : !(isLegalMaskedStore(Ty, Ptr) || isLegalMaskedScatter(Ty));
+ !(isLegalMaskedLoad(Ty, Ptr, Alignment) || isLegalMaskedGather(Ty))
+ : !(isLegalMaskedStore(Ty, Ptr, Alignment) || isLegalMaskedScatter(Ty));
}
case Instruction::UDiv:
case Instruction::SDiv:
@@ -4452,10 +4593,10 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(Instruction *I,
// Check if masking is required.
// A Group may need masking for one of two reasons: it resides in a block that
// needs predication, or it was decided to use masking to deal with gaps.
- bool PredicatedAccessRequiresMasking =
+ bool PredicatedAccessRequiresMasking =
Legal->blockNeedsPredication(I->getParent()) && Legal->isMaskRequired(I);
- bool AccessWithGapsRequiresMasking =
- Group->requiresScalarEpilogue() && !IsScalarEpilogueAllowed;
+ bool AccessWithGapsRequiresMasking =
+ Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed();
if (!PredicatedAccessRequiresMasking && !AccessWithGapsRequiresMasking)
return true;
@@ -4466,8 +4607,9 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(Instruction *I,
"Masked interleave-groups for predicated accesses are not enabled.");
auto *Ty = getMemInstValueType(I);
- return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty)
- : TTI.isLegalMaskedStore(Ty);
+ const MaybeAlign Alignment = getLoadStoreAlignment(I);
+ return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty, Alignment)
+ : TTI.isLegalMaskedStore(Ty, Alignment);
}
bool LoopVectorizationCostModel::memoryInstructionCanBeWidened(Instruction *I,
@@ -4675,82 +4817,96 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) {
Uniforms[VF].insert(Worklist.begin(), Worklist.end());
}
-Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) {
- if (Legal->getRuntimePointerChecking()->Need && TTI.hasBranchDivergence()) {
- // TODO: It may by useful to do since it's still likely to be dynamically
- // uniform if the target can skip.
- LLVM_DEBUG(
- dbgs() << "LV: Not inserting runtime ptr check for divergent target");
-
- ORE->emit(
- createMissedAnalysis("CantVersionLoopWithDivergentTarget")
- << "runtime pointer checks needed. Not enabled for divergent target");
-
- return None;
- }
-
- unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
- if (!OptForSize) // Remaining checks deal with scalar loop when OptForSize.
- return computeFeasibleMaxVF(OptForSize, TC);
+bool LoopVectorizationCostModel::runtimeChecksRequired() {
+ LLVM_DEBUG(dbgs() << "LV: Performing code size checks.\n");
if (Legal->getRuntimePointerChecking()->Need) {
- ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize")
- << "runtime pointer checks needed. Enable vectorization of this "
- "loop with '#pragma clang loop vectorize(enable)' when "
- "compiling with -Os/-Oz");
- LLVM_DEBUG(
- dbgs()
- << "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n");
- return None;
+ reportVectorizationFailure("Runtime ptr check is required with -Os/-Oz",
+ "runtime pointer checks needed. Enable vectorization of this "
+ "loop with '#pragma clang loop vectorize(enable)' when "
+ "compiling with -Os/-Oz",
+ "CantVersionLoopWithOptForSize", ORE, TheLoop);
+ return true;
}
if (!PSE.getUnionPredicate().getPredicates().empty()) {
- ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize")
- << "runtime SCEV checks needed. Enable vectorization of this "
- "loop with '#pragma clang loop vectorize(enable)' when "
- "compiling with -Os/-Oz");
- LLVM_DEBUG(
- dbgs()
- << "LV: Aborting. Runtime SCEV check is required with -Os/-Oz.\n");
- return None;
+ reportVectorizationFailure("Runtime SCEV check is required with -Os/-Oz",
+ "runtime SCEV checks needed. Enable vectorization of this "
+ "loop with '#pragma clang loop vectorize(enable)' when "
+ "compiling with -Os/-Oz",
+ "CantVersionLoopWithOptForSize", ORE, TheLoop);
+ return true;
}
// FIXME: Avoid specializing for stride==1 instead of bailing out.
if (!Legal->getLAI()->getSymbolicStrides().empty()) {
- ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize")
- << "runtime stride == 1 checks needed. Enable vectorization of "
- "this loop with '#pragma clang loop vectorize(enable)' when "
- "compiling with -Os/-Oz");
- LLVM_DEBUG(
- dbgs()
- << "LV: Aborting. Runtime stride check is required with -Os/-Oz.\n");
+ reportVectorizationFailure("Runtime stride check is required with -Os/-Oz",
+ "runtime stride == 1 checks needed. Enable vectorization of "
+ "this loop with '#pragma clang loop vectorize(enable)' when "
+ "compiling with -Os/-Oz",
+ "CantVersionLoopWithOptForSize", ORE, TheLoop);
+ return true;
+ }
+
+ return false;
+}
+
+Optional<unsigned> LoopVectorizationCostModel::computeMaxVF() {
+ if (Legal->getRuntimePointerChecking()->Need && TTI.hasBranchDivergence()) {
+    // TODO: It may be useful to do this, since it's still likely to be
+    // dynamically uniform if the target can skip.
+ reportVectorizationFailure(
+ "Not inserting runtime ptr check for divergent target",
+ "runtime pointer checks needed. Not enabled for divergent target",
+ "CantVersionLoopWithDivergentTarget", ORE, TheLoop);
return None;
}
- // If we optimize the program for size, avoid creating the tail loop.
+ unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
-
if (TC == 1) {
- ORE->emit(createMissedAnalysis("SingleIterationLoop")
- << "loop trip count is one, irrelevant for vectorization");
- LLVM_DEBUG(dbgs() << "LV: Aborting, single iteration (non) loop.\n");
+ reportVectorizationFailure("Single iteration (non) loop",
+ "loop trip count is one, irrelevant for vectorization",
+ "SingleIterationLoop", ORE, TheLoop);
return None;
}
- // Record that scalar epilogue is not allowed.
- LLVM_DEBUG(dbgs() << "LV: Not allowing scalar epilogue due to -Os/-Oz.\n");
+ switch (ScalarEpilogueStatus) {
+ case CM_ScalarEpilogueAllowed:
+ return computeFeasibleMaxVF(TC);
+ case CM_ScalarEpilogueNotNeededUsePredicate:
+ LLVM_DEBUG(
+ dbgs() << "LV: vector predicate hint/switch found.\n"
+ << "LV: Not allowing scalar epilogue, creating predicated "
+ << "vector loop.\n");
+ break;
+ case CM_ScalarEpilogueNotAllowedLowTripLoop:
+ // fallthrough as a special case of OptForSize
+ case CM_ScalarEpilogueNotAllowedOptSize:
+ if (ScalarEpilogueStatus == CM_ScalarEpilogueNotAllowedOptSize)
+ LLVM_DEBUG(
+ dbgs() << "LV: Not allowing scalar epilogue due to -Os/-Oz.\n");
+ else
+ LLVM_DEBUG(dbgs() << "LV: Not allowing scalar epilogue due to low trip "
+ << "count.\n");
+
+    // Bail if runtime checks are required, which are not good when optimizing
+ // for size.
+ if (runtimeChecksRequired())
+ return None;
+ break;
+ }
- IsScalarEpilogueAllowed = !OptForSize;
+ // Now try the tail folding
- // We don't create an epilogue when optimizing for size.
// Invalidate interleave groups that require an epilogue if we can't mask
// the interleave-group.
- if (!useMaskedInterleavedAccesses(TTI))
+ if (!useMaskedInterleavedAccesses(TTI))
InterleaveInfo.invalidateGroupsRequiringScalarEpilogue();
- unsigned MaxVF = computeFeasibleMaxVF(OptForSize, TC);
-
+ unsigned MaxVF = computeFeasibleMaxVF(TC);
if (TC > 0 && TC % MaxVF == 0) {
+ // Accept MaxVF if we do not have a tail.
LLVM_DEBUG(dbgs() << "LV: No tail will remain for any chosen VF.\n");
return MaxVF;
}
@@ -4759,28 +4915,30 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) {
// found modulo the vectorization factor is not zero, try to fold the tail
// by masking.
// FIXME: look for a smaller MaxVF that does divide TC rather than masking.
- if (Legal->canFoldTailByMasking()) {
+ if (Legal->prepareToFoldTailByMasking()) {
FoldTailByMasking = true;
return MaxVF;
}
if (TC == 0) {
- ORE->emit(
- createMissedAnalysis("UnknownLoopCountComplexCFG")
- << "unable to calculate the loop count due to complex control flow");
+ reportVectorizationFailure(
+ "Unable to calculate the loop count due to complex control flow",
+ "unable to calculate the loop count due to complex control flow",
+ "UnknownLoopCountComplexCFG", ORE, TheLoop);
return None;
}
- ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize")
- << "cannot optimize for size and vectorize at the same time. "
- "Enable vectorization of this loop with '#pragma clang loop "
- "vectorize(enable)' when compiling with -Os/-Oz");
+ reportVectorizationFailure(
+ "Cannot optimize for size and vectorize at the same time.",
+ "cannot optimize for size and vectorize at the same time. "
+ "Enable vectorization of this loop with '#pragma clang loop "
+ "vectorize(enable)' when compiling with -Os/-Oz",
+ "NoTailLoopWithOptForSize", ORE, TheLoop);
return None;
}
unsigned
-LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize,
- unsigned ConstTripCount) {
+LoopVectorizationCostModel::computeFeasibleMaxVF(unsigned ConstTripCount) {
MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);
unsigned SmallestType, WidestType;
std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes();
@@ -4818,8 +4976,8 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize,
}
unsigned MaxVF = MaxVectorSize;
- if (TTI.shouldMaximizeVectorBandwidth(OptForSize) ||
- (MaximizeBandwidth && !OptForSize)) {
+ if (TTI.shouldMaximizeVectorBandwidth(!isScalarEpilogueAllowed()) ||
+ (MaximizeBandwidth && isScalarEpilogueAllowed())) {
// Collect all viable vectorization factors larger than the default MaxVF
// (i.e. MaxVectorSize).
SmallVector<unsigned, 8> VFs;
@@ -4832,9 +4990,14 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize,
// Select the largest VF which doesn't require more registers than existing
// ones.
- unsigned TargetNumRegisters = TTI.getNumberOfRegisters(true);
for (int i = RUs.size() - 1; i >= 0; --i) {
- if (RUs[i].MaxLocalUsers <= TargetNumRegisters) {
+ bool Selected = true;
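+      // Reject this VF if any register class needs more registers than the
+      // target provides for that class.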
+ for (auto& pair : RUs[i].MaxLocalUsers) {
+ unsigned TargetNumRegisters = TTI.getNumberOfRegisters(pair.first);
+ if (pair.second > TargetNumRegisters)
+ Selected = false;
+ }
+ if (Selected) {
MaxVF = VFs[i];
break;
}
@@ -4886,10 +5049,9 @@ LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) {
}
if (!EnableCondStoresVectorization && NumPredStores) {
- ORE->emit(createMissedAnalysis("ConditionalStore")
- << "store that is conditionally executed prevents vectorization");
- LLVM_DEBUG(
- dbgs() << "LV: No vectorization. There are conditional stores.\n");
+ reportVectorizationFailure("There are conditional stores.",
+ "store that is conditionally executed prevents vectorization",
+ "ConditionalStore", ORE, TheLoop);
Width = 1;
Cost = ScalarCost;
}
@@ -4958,8 +5120,7 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() {
return {MinWidth, MaxWidth};
}
-unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize,
- unsigned VF,
+unsigned LoopVectorizationCostModel::selectInterleaveCount(unsigned VF,
unsigned LoopCost) {
// -- The interleave heuristics --
// We interleave the loop in order to expose ILP and reduce the loop overhead.
@@ -4975,8 +5136,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize,
// 3. We don't interleave if we think that we will spill registers to memory
// due to the increased register pressure.
- // When we optimize for size, we don't interleave.
- if (OptForSize)
+ if (!isScalarEpilogueAllowed())
return 1;
// We used the distance for the interleave count.
@@ -4988,22 +5148,12 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize,
if (TC > 1 && TC < TinyTripCountInterleaveThreshold)
return 1;
- unsigned TargetNumRegisters = TTI.getNumberOfRegisters(VF > 1);
- LLVM_DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters
- << " registers\n");
-
- if (VF == 1) {
- if (ForceTargetNumScalarRegs.getNumOccurrences() > 0)
- TargetNumRegisters = ForceTargetNumScalarRegs;
- } else {
- if (ForceTargetNumVectorRegs.getNumOccurrences() > 0)
- TargetNumRegisters = ForceTargetNumVectorRegs;
- }
-
RegisterUsage R = calculateRegisterUsage({VF})[0];
// We divide by these constants so assume that we have at least one
// instruction that uses at least one register.
- R.MaxLocalUsers = std::max(R.MaxLocalUsers, 1U);
+ for (auto& pair : R.MaxLocalUsers) {
+ pair.second = std::max(pair.second, 1U);
+ }
// We calculate the interleave count using the following formula.
// Subtract the number of loop invariants from the number of available
@@ -5016,13 +5166,35 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize,
// We also want power of two interleave counts to ensure that the induction
// variable of the vector loop wraps to zero, when tail is folded by masking;
// this currently happens when OptForSize, in which case IC is set to 1 above.
- unsigned IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs) /
- R.MaxLocalUsers);
+ unsigned IC = UINT_MAX;
- // Don't count the induction variable as interleaved.
- if (EnableIndVarRegisterHeur)
- IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs - 1) /
- std::max(1U, (R.MaxLocalUsers - 1)));
+ for (auto& pair : R.MaxLocalUsers) {
+ unsigned TargetNumRegisters = TTI.getNumberOfRegisters(pair.first);
+ LLVM_DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters
+ << " registers of "
+ << TTI.getRegisterClassName(pair.first) << " register class\n");
+ if (VF == 1) {
+ if (ForceTargetNumScalarRegs.getNumOccurrences() > 0)
+ TargetNumRegisters = ForceTargetNumScalarRegs;
+ } else {
+ if (ForceTargetNumVectorRegs.getNumOccurrences() > 0)
+ TargetNumRegisters = ForceTargetNumVectorRegs;
+ }
+ unsigned MaxLocalUsers = pair.second;
+ unsigned LoopInvariantRegs = 0;
+ if (R.LoopInvariantRegs.find(pair.first) != R.LoopInvariantRegs.end())
+ LoopInvariantRegs = R.LoopInvariantRegs[pair.first];
+
+    unsigned TmpIC = PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs) /
+                                   MaxLocalUsers);
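+    // For example, 32 registers in this class, 2 loop-invariant values and 5
+    // concurrent live values yields PowerOf2Floor((32 - 2) / 5) = 4 here.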
+ // Don't count the induction variable as interleaved.
+ if (EnableIndVarRegisterHeur) {
+ TmpIC =
+ PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs - 1) /
+ std::max(1U, (MaxLocalUsers - 1)));
+ }
+
+ IC = std::min(IC, TmpIC);
+ }
// Clamp the interleave ranges to reasonable counts.
unsigned MaxInterleaveCount = TTI.getMaxInterleaveFactor(VF);
@@ -5036,6 +5208,14 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize,
MaxInterleaveCount = ForceTargetMaxVectorInterleaveFactor;
}
+ // If the trip count is constant, limit the interleave count to be less than
+ // the trip count divided by VF.
+ if (TC > 0) {
+ assert(TC >= VF && "VF exceeds trip count?");
+ if ((TC / VF) < MaxInterleaveCount)
+ MaxInterleaveCount = (TC / VF);
+ }
+
// If we did not calculate the cost for VF (because the user selected the VF)
// then we calculate the cost of VF here.
if (LoopCost == 0)
@@ -5044,7 +5224,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize,
assert(LoopCost && "Non-zero loop cost expected");
// Clamp the calculated IC to be between the 1 and the max interleave count
- // that the target allows.
+ // that the target and trip count allows.
if (IC > MaxInterleaveCount)
IC = MaxInterleaveCount;
else if (IC < 1)
@@ -5196,7 +5376,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) {
const DataLayout &DL = TheFunction->getParent()->getDataLayout();
SmallVector<RegisterUsage, 8> RUs(VFs.size());
- SmallVector<unsigned, 8> MaxUsages(VFs.size(), 0);
+ SmallVector<SmallMapVector<unsigned, unsigned, 4>, 8> MaxUsages(VFs.size());
LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n");
@@ -5226,21 +5406,45 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) {
// For each VF find the maximum usage of registers.
for (unsigned j = 0, e = VFs.size(); j < e; ++j) {
+ // Count the number of live intervals.
+ SmallMapVector<unsigned, unsigned, 4> RegUsage;
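+      // RegUsage maps each target register class (ClassID) to the number of
+      // registers of that class needed by the values live here for VFs[j].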
+
if (VFs[j] == 1) {
- MaxUsages[j] = std::max(MaxUsages[j], OpenIntervals.size());
- continue;
+ for (auto Inst : OpenIntervals) {
+ unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType());
+ if (RegUsage.find(ClassID) == RegUsage.end())
+ RegUsage[ClassID] = 1;
+ else
+ RegUsage[ClassID] += 1;
+ }
+ } else {
+ collectUniformsAndScalars(VFs[j]);
+ for (auto Inst : OpenIntervals) {
+ // Skip ignored values for VF > 1.
+ if (VecValuesToIgnore.find(Inst) != VecValuesToIgnore.end())
+ continue;
+ if (isScalarAfterVectorization(Inst, VFs[j])) {
+ unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType());
+ if (RegUsage.find(ClassID) == RegUsage.end())
+ RegUsage[ClassID] = 1;
+ else
+ RegUsage[ClassID] += 1;
+ } else {
+ unsigned ClassID = TTI.getRegisterClassForType(true, Inst->getType());
+ if (RegUsage.find(ClassID) == RegUsage.end())
+ RegUsage[ClassID] = GetRegUsage(Inst->getType(), VFs[j]);
+ else
+ RegUsage[ClassID] += GetRegUsage(Inst->getType(), VFs[j]);
+ }
+ }
}
- collectUniformsAndScalars(VFs[j]);
- // Count the number of live intervals.
- unsigned RegUsage = 0;
- for (auto Inst : OpenIntervals) {
- // Skip ignored values for VF > 1.
- if (VecValuesToIgnore.find(Inst) != VecValuesToIgnore.end() ||
- isScalarAfterVectorization(Inst, VFs[j]))
- continue;
- RegUsage += GetRegUsage(Inst->getType(), VFs[j]);
+
+ for (auto& pair : RegUsage) {
+ if (MaxUsages[j].find(pair.first) != MaxUsages[j].end())
+          MaxUsages[j][pair.first] =
+              std::max(MaxUsages[j][pair.first], pair.second);
+ else
+ MaxUsages[j][pair.first] = pair.second;
}
- MaxUsages[j] = std::max(MaxUsages[j], RegUsage);
}
LLVM_DEBUG(dbgs() << "LV(REG): At #" << i << " Interval # "
@@ -5251,18 +5455,34 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) {
}
for (unsigned i = 0, e = VFs.size(); i < e; ++i) {
- unsigned Invariant = 0;
- if (VFs[i] == 1)
- Invariant = LoopInvariants.size();
- else {
- for (auto Inst : LoopInvariants)
- Invariant += GetRegUsage(Inst->getType(), VFs[i]);
+ SmallMapVector<unsigned, unsigned, 4> Invariant;
+
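+    // A scalar VF counts one register per invariant; vector VFs count the
+    // registers needed for the widened type.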
+ for (auto Inst : LoopInvariants) {
+ unsigned Usage = VFs[i] == 1 ? 1 : GetRegUsage(Inst->getType(), VFs[i]);
+ unsigned ClassID = TTI.getRegisterClassForType(VFs[i] > 1, Inst->getType());
+ if (Invariant.find(ClassID) == Invariant.end())
+ Invariant[ClassID] = Usage;
+ else
+ Invariant[ClassID] += Usage;
}
- LLVM_DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n');
- LLVM_DEBUG(dbgs() << "LV(REG): Found max usage: " << MaxUsages[i] << '\n');
- LLVM_DEBUG(dbgs() << "LV(REG): Found invariant usage: " << Invariant
- << '\n');
+ LLVM_DEBUG({
+ dbgs() << "LV(REG): VF = " << VFs[i] << '\n';
+ dbgs() << "LV(REG): Found max usage: " << MaxUsages[i].size()
+ << " item\n";
+ for (const auto &pair : MaxUsages[i]) {
+ dbgs() << "LV(REG): RegisterClass: "
+ << TTI.getRegisterClassName(pair.first) << ", " << pair.second
+ << " registers\n";
+ }
+ dbgs() << "LV(REG): Found invariant usage: " << Invariant.size()
+ << " item\n";
+ for (const auto &pair : Invariant) {
+ dbgs() << "LV(REG): RegisterClass: "
+ << TTI.getRegisterClassName(pair.first) << ", " << pair.second
+ << " registers\n";
+ }
+ });
RU.LoopInvariantRegs = Invariant;
RU.MaxLocalUsers = MaxUsages[i];
@@ -5511,7 +5731,6 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I,
Type *ValTy = getMemInstValueType(I);
auto SE = PSE.getSE();
- unsigned Alignment = getLoadStoreAlignment(I);
unsigned AS = getLoadStoreAddressSpace(I);
Value *Ptr = getLoadStorePointerOperand(I);
Type *PtrTy = ToVectorTy(Ptr->getType(), VF);
@@ -5525,9 +5744,9 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I,
// Don't pass *I here, since it is scalar but will actually be part of a
// vectorized loop where the user of it is a vectorized instruction.
- Cost += VF *
- TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment,
- AS);
+ const MaybeAlign Alignment = getLoadStoreAlignment(I);
+ Cost += VF * TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(),
+ Alignment ? Alignment->value() : 0, AS);
// Get the overhead of the extractelement and insertelement instructions
// we might create due to scalarization.
@@ -5552,18 +5771,20 @@ unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I,
unsigned VF) {
Type *ValTy = getMemInstValueType(I);
Type *VectorTy = ToVectorTy(ValTy, VF);
- unsigned Alignment = getLoadStoreAlignment(I);
Value *Ptr = getLoadStorePointerOperand(I);
unsigned AS = getLoadStoreAddressSpace(I);
int ConsecutiveStride = Legal->isConsecutivePtr(Ptr);
assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) &&
"Stride should be 1 or -1 for consecutive memory access");
+ const MaybeAlign Alignment = getLoadStoreAlignment(I);
unsigned Cost = 0;
if (Legal->isMaskRequired(I))
- Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS);
+ Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy,
+ Alignment ? Alignment->value() : 0, AS);
else
- Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS, I);
+ Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy,
+ Alignment ? Alignment->value() : 0, AS, I);
bool Reverse = ConsecutiveStride < 0;
if (Reverse)
@@ -5575,33 +5796,37 @@ unsigned LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I,
unsigned VF) {
Type *ValTy = getMemInstValueType(I);
Type *VectorTy = ToVectorTy(ValTy, VF);
- unsigned Alignment = getLoadStoreAlignment(I);
+ const MaybeAlign Alignment = getLoadStoreAlignment(I);
unsigned AS = getLoadStoreAddressSpace(I);
if (isa<LoadInst>(I)) {
return TTI.getAddressComputationCost(ValTy) +
- TTI.getMemoryOpCost(Instruction::Load, ValTy, Alignment, AS) +
+ TTI.getMemoryOpCost(Instruction::Load, ValTy,
+ Alignment ? Alignment->value() : 0, AS) +
TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VectorTy);
}
StoreInst *SI = cast<StoreInst>(I);
bool isLoopInvariantStoreValue = Legal->isUniform(SI->getValueOperand());
return TTI.getAddressComputationCost(ValTy) +
- TTI.getMemoryOpCost(Instruction::Store, ValTy, Alignment, AS) +
- (isLoopInvariantStoreValue ? 0 : TTI.getVectorInstrCost(
- Instruction::ExtractElement,
- VectorTy, VF - 1));
+ TTI.getMemoryOpCost(Instruction::Store, ValTy,
+ Alignment ? Alignment->value() : 0, AS) +
+ (isLoopInvariantStoreValue
+ ? 0
+ : TTI.getVectorInstrCost(Instruction::ExtractElement, VectorTy,
+ VF - 1));
}
unsigned LoopVectorizationCostModel::getGatherScatterCost(Instruction *I,
unsigned VF) {
Type *ValTy = getMemInstValueType(I);
Type *VectorTy = ToVectorTy(ValTy, VF);
- unsigned Alignment = getLoadStoreAlignment(I);
+ const MaybeAlign Alignment = getLoadStoreAlignment(I);
Value *Ptr = getLoadStorePointerOperand(I);
return TTI.getAddressComputationCost(VectorTy) +
TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr,
- Legal->isMaskRequired(I), Alignment);
+ Legal->isMaskRequired(I),
+ Alignment ? Alignment->value() : 0);
}
unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
@@ -5626,8 +5851,8 @@ unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
}
// Calculate the cost of the whole interleaved group.
- bool UseMaskForGaps =
- Group->requiresScalarEpilogue() && !IsScalarEpilogueAllowed;
+ bool UseMaskForGaps =
+ Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed();
unsigned Cost = TTI.getInterleavedMemoryOpCost(
I->getOpcode(), WideVecTy, Group->getFactor(), Indices,
Group->getAlignment(), AS, Legal->isMaskRequired(I), UseMaskForGaps);
@@ -5648,11 +5873,12 @@ unsigned LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I,
// moment.
if (VF == 1) {
Type *ValTy = getMemInstValueType(I);
- unsigned Alignment = getLoadStoreAlignment(I);
+ const MaybeAlign Alignment = getLoadStoreAlignment(I);
unsigned AS = getLoadStoreAddressSpace(I);
return TTI.getAddressComputationCost(ValTy) +
- TTI.getMemoryOpCost(I->getOpcode(), ValTy, Alignment, AS, I);
+ TTI.getMemoryOpCost(I->getOpcode(), ValTy,
+ Alignment ? Alignment->value() : 0, AS, I);
}
return getWideningCost(I, VF);
}
@@ -6167,8 +6393,7 @@ static unsigned determineVPlanVF(const unsigned WidestVectorRegBits,
}
VectorizationFactor
-LoopVectorizationPlanner::planInVPlanNativePath(bool OptForSize,
- unsigned UserVF) {
+LoopVectorizationPlanner::planInVPlanNativePath(unsigned UserVF) {
unsigned VF = UserVF;
// Outer loop handling: They may require CFG and instruction level
// transformations before even evaluating whether vectorization is profitable.
@@ -6207,10 +6432,9 @@ LoopVectorizationPlanner::planInVPlanNativePath(bool OptForSize,
return VectorizationFactor::Disabled();
}
-Optional<VectorizationFactor> LoopVectorizationPlanner::plan(bool OptForSize,
- unsigned UserVF) {
+Optional<VectorizationFactor> LoopVectorizationPlanner::plan(unsigned UserVF) {
assert(OrigLoop->empty() && "Inner loop expected.");
- Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(OptForSize);
+ Optional<unsigned> MaybeMaxVF = CM.computeMaxVF();
if (!MaybeMaxVF) // Cases that should not to be vectorized nor interleaved.
return None;
@@ -6840,8 +7064,15 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(unsigned MinVF,
// If the tail is to be folded by masking, the primary induction variable
// needs to be represented in VPlan for it to model early-exit masking.
- if (CM.foldTailByMasking())
+ // Also, both the Phi and the live-out instruction of each reduction are
+ // required in order to introduce a select between them in VPlan.
+ if (CM.foldTailByMasking()) {
NeedDef.insert(Legal->getPrimaryInduction());
+ for (auto &Reduction : *Legal->getReductionVars()) {
+ NeedDef.insert(Reduction.first);
+ NeedDef.insert(Reduction.second.getLoopExitInstr());
+ }
+ }
// Collect instructions from the original loop that will become trivially dead
// in the vectorized loop. We don't need to vectorize these instructions. For
@@ -6873,7 +7104,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
// Create a dummy pre-entry VPBasicBlock to start building the VPlan.
VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry");
- auto Plan = llvm::make_unique<VPlan>(VPBB);
+ auto Plan = std::make_unique<VPlan>(VPBB);
VPRecipeBuilder RecipeBuilder(OrigLoop, TLI, Legal, CM, Builder);
// Represent values that will have defs inside VPlan.
@@ -6968,6 +7199,18 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
VPBlockUtils::disconnectBlocks(PreEntry, Entry);
delete PreEntry;
+ // Finally, if tail is folded by masking, introduce selects between the phi
+ // and the live-out instruction of each reduction, at the end of the latch.
+ if (CM.foldTailByMasking()) {
+ Builder.setInsertPoint(VPBB);
+ auto *Cond = RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan);
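+    // Masked-off lanes take the phi value, leaving the partial reduction
+    // unchanged for lanes that do not execute.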
+ for (auto &Reduction : *Legal->getReductionVars()) {
+ VPValue *Phi = Plan->getVPValue(Reduction.first);
+ VPValue *Red = Plan->getVPValue(Reduction.second.getLoopExitInstr());
+ Builder.createNaryOp(Instruction::Select, {Cond, Red, Phi});
+ }
+ }
+
std::string PlanName;
raw_string_ostream RSO(PlanName);
unsigned VF = Range.Start;
@@ -6993,7 +7236,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
assert(EnableVPlanNativePath && "VPlan-native path is not enabled.");
// Create new empty VPlan
- auto Plan = llvm::make_unique<VPlan>();
+ auto Plan = std::make_unique<VPlan>();
// Build hierarchical CFG
VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan);
@@ -7199,6 +7442,20 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
State.ILV->vectorizeMemoryInstruction(&Instr, &MaskValues);
}
+static ScalarEpilogueLowering
+getScalarEpilogueLowering(Function *F, Loop *L, LoopVectorizeHints &Hints,
+ ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) {
+ ScalarEpilogueLowering SEL = CM_ScalarEpilogueAllowed;
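+  // The optsize restriction takes precedence over the predication preference,
+  // and a forced-vectorization hint bypasses the optsize restriction.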
+ if (Hints.getForce() != LoopVectorizeHints::FK_Enabled &&
+ (F->hasOptSize() ||
+ llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI)))
+ SEL = CM_ScalarEpilogueNotAllowedOptSize;
+ else if (PreferPredicateOverEpilog || Hints.getPredicate())
+ SEL = CM_ScalarEpilogueNotNeededUsePredicate;
+
+ return SEL;
+}
+
// Process the loop in the VPlan-native vectorization path. This path builds
// VPlan upfront in the vectorization pipeline, which allows to apply
// VPlan-to-VPlan transformations from the very beginning without modifying the
@@ -7213,7 +7470,9 @@ static bool processLoopInVPlanNativePath(
assert(EnableVPlanNativePath && "VPlan-native path is disabled.");
Function *F = L->getHeader()->getParent();
InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL->getLAI());
- LoopVectorizationCostModel CM(L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F,
+ ScalarEpilogueLowering SEL = getScalarEpilogueLowering(F, L, Hints, PSI, BFI);
+
+ LoopVectorizationCostModel CM(SEL, L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F,
&Hints, IAI);
// Use the planner for outer loop vectorization.
// TODO: CM is not used at this point inside the planner. Turn CM into an
@@ -7223,15 +7482,8 @@ static bool processLoopInVPlanNativePath(
// Get user vectorization factor.
const unsigned UserVF = Hints.getWidth();
- // Check the function attributes and profiles to find out if this function
- // should be optimized for size.
- bool OptForSize =
- Hints.getForce() != LoopVectorizeHints::FK_Enabled &&
- (F->hasOptSize() ||
- llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI));
-
// Plan how to best vectorize, return the best VF and its cost.
- const VectorizationFactor VF = LVP.planInVPlanNativePath(OptForSize, UserVF);
+ const VectorizationFactor VF = LVP.planInVPlanNativePath(UserVF);
// If we are stress testing VPlan builds, do not attempt to generate vector
// code. Masked vector code generation support will follow soon.
@@ -7310,10 +7562,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// Check the function attributes and profiles to find out if this function
// should be optimized for size.
- bool OptForSize =
- Hints.getForce() != LoopVectorizeHints::FK_Enabled &&
- (F->hasOptSize() ||
- llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI));
+ ScalarEpilogueLowering SEL = getScalarEpilogueLowering(F, L, Hints, PSI, BFI);
// Entrance to the VPlan-native vectorization path. Outer loops are processed
// here. They may require CFG and instruction level transformations before
@@ -7325,36 +7574,11 @@ bool LoopVectorizePass::processLoop(Loop *L) {
ORE, BFI, PSI, Hints);
assert(L->empty() && "Inner loop expected.");
+
// Check the loop for a trip count threshold: vectorize loops with a tiny trip
// count by optimizing for size, to minimize overheads.
- // Prefer constant trip counts over profile data, over upper bound estimate.
- unsigned ExpectedTC = 0;
- bool HasExpectedTC = false;
- if (const SCEVConstant *ConstExits =
- dyn_cast<SCEVConstant>(SE->getBackedgeTakenCount(L))) {
- const APInt &ExitsCount = ConstExits->getAPInt();
- // We are interested in small values for ExpectedTC. Skip over those that
- // can't fit an unsigned.
- if (ExitsCount.ult(std::numeric_limits<unsigned>::max())) {
- ExpectedTC = static_cast<unsigned>(ExitsCount.getZExtValue()) + 1;
- HasExpectedTC = true;
- }
- }
- // ExpectedTC may be large because it's bound by a variable. Check
- // profiling information to validate we should vectorize.
- if (!HasExpectedTC && LoopVectorizeWithBlockFrequency) {
- auto EstimatedTC = getLoopEstimatedTripCount(L);
- if (EstimatedTC) {
- ExpectedTC = *EstimatedTC;
- HasExpectedTC = true;
- }
- }
- if (!HasExpectedTC) {
- ExpectedTC = SE->getSmallConstantMaxTripCount(L);
- HasExpectedTC = (ExpectedTC > 0);
- }
-
- if (HasExpectedTC && ExpectedTC < TinyTripCountVectorThreshold) {
+ auto ExpectedTC = getSmallBestKnownTC(*SE, L);
+ if (ExpectedTC && *ExpectedTC < TinyTripCountVectorThreshold) {
LLVM_DEBUG(dbgs() << "LV: Found a loop with a very small trip count. "
<< "This loop is worth vectorizing only if no scalar "
<< "iteration overheads are incurred.");
@@ -7362,10 +7586,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
LLVM_DEBUG(dbgs() << " But vectorizing was explicitly forced.\n");
else {
LLVM_DEBUG(dbgs() << "\n");
- // Loops with a very small trip count are considered for vectorization
- // under OptForSize, thereby making sure the cost of their loop body is
- // dominant, free of runtime guards and scalar iteration overheads.
- OptForSize = true;
+ SEL = CM_ScalarEpilogueNotAllowedLowTripLoop;
}
}
@@ -7374,11 +7595,10 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// an integer loop and the vector instructions selected are purely integer
// vector instructions?
if (F->hasFnAttribute(Attribute::NoImplicitFloat)) {
- LLVM_DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat"
- "attribute is used.\n");
- ORE->emit(createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(),
- "NoImplicitFloat", L)
- << "loop not vectorized due to NoImplicitFloat attribute");
+ reportVectorizationFailure(
+ "Can't vectorize when the NoImplicitFloat attribute is used",
+ "loop not vectorized due to NoImplicitFloat attribute",
+ "NoImplicitFloat", ORE, L);
Hints.emitRemarkWithHints();
return false;
}
@@ -7389,11 +7609,10 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// additional fp-math flags can help.
if (Hints.isPotentiallyUnsafe() &&
TTI->isFPVectorizationPotentiallyUnsafe()) {
- LLVM_DEBUG(
- dbgs() << "LV: Potentially unsafe FP op prevents vectorization.\n");
- ORE->emit(
- createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(), "UnsafeFP", L)
- << "loop not vectorized due to unsafe FP support.");
+ reportVectorizationFailure(
+ "Potentially unsafe FP op prevents vectorization",
+ "loop not vectorized due to unsafe FP support.",
+ "UnsafeFP", ORE, L);
Hints.emitRemarkWithHints();
return false;
}
@@ -7411,8 +7630,8 @@ bool LoopVectorizePass::processLoop(Loop *L) {
}
// Use the cost model.
- LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F,
- &Hints, IAI);
+ LoopVectorizationCostModel CM(SEL, L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE,
+ F, &Hints, IAI);
CM.collectValuesToIgnore();
// Use the planner for vectorization.
@@ -7422,7 +7641,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
unsigned UserVF = Hints.getWidth();
// Plan how to best vectorize, return the best VF and its cost.
- Optional<VectorizationFactor> MaybeVF = LVP.plan(OptForSize, UserVF);
+ Optional<VectorizationFactor> MaybeVF = LVP.plan(UserVF);
VectorizationFactor VF = VectorizationFactor::Disabled();
unsigned IC = 1;
@@ -7431,7 +7650,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
if (MaybeVF) {
VF = *MaybeVF;
// Select the interleave count.
- IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost);
+ IC = CM.selectInterleaveCount(VF.Width, VF.Cost);
}
// Identify the diagnostic messages that should be produced.
@@ -7609,7 +7828,8 @@ bool LoopVectorizePass::runImpl(
// The second condition is necessary because, even if the target has no
// vector registers, loop vectorization may still enable scalar
// interleaving.
- if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2)
+ if (!TTI->getNumberOfRegisters(TTI->getRegisterClassForType(true)) &&
+ TTI->getMaxInterleaveFactor(1) < 2)
return false;
bool Changed = false;
diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 27a86c0bca91..974eff9974d9 100644
--- a/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -194,10 +194,13 @@ static bool allSameBlock(ArrayRef<Value *> VL) {
return true;
}
-/// \returns True if all of the values in \p VL are constants.
+/// \returns True if all of the values in \p VL are constants (but not
+/// globals/constant expressions).
static bool allConstant(ArrayRef<Value *> VL) {
+ // Constant expressions and globals can't be vectorized like normal integer/FP
+ // constants.
for (Value *i : VL)
- if (!isa<Constant>(i))
+ if (!isa<Constant>(i) || isa<ConstantExpr>(i) || isa<GlobalValue>(i))
return false;
return true;
}
@@ -486,6 +489,7 @@ namespace slpvectorizer {
/// Bottom Up SLP Vectorizer.
class BoUpSLP {
struct TreeEntry;
+ struct ScheduleData;
public:
using ValueList = SmallVector<Value *, 8>;
@@ -614,6 +618,15 @@ public:
/// vectorizable. We do not vectorize such trees.
bool isTreeTinyAndNotFullyVectorizable() const;
+ /// Assume that a legal-sized 'or'-reduction of shifted/zexted loaded values
+ /// can be load combined in the backend. Load combining may not be allowed in
+ /// the IR optimizer, so we do not want to alter the pattern. For example,
+ /// partially transforming a scalar bswap() pattern into vector code is
+ /// effectively impossible for the backend to undo.
+ /// TODO: If load combining is allowed in the IR optimizer, this analysis
+ /// may not be necessary.
+ bool isLoadCombineReductionCandidate(unsigned ReductionOpcode) const;
+
OptimizationRemarkEmitter *getORE() { return ORE; }
/// This structure holds any data we need about the edges being traversed
@@ -1117,6 +1130,14 @@ public:
#endif
};
+ /// Checks if the instruction is marked for deletion.
+ bool isDeleted(Instruction *I) const { return DeletedInstructions.count(I); }
+
+  /// Marks the operands of the given values for later deletion by replacing
+  /// them with Undefs.
+ void eraseInstructions(ArrayRef<Value *> AV);
+
+ ~BoUpSLP();
+
private:
/// Checks if all users of \p I are the part of the vectorization tree.
bool areAllUsersVectorized(Instruction *I) const;
@@ -1153,8 +1174,7 @@ private:
/// Set the Builder insert point to one after the last instruction in
/// the bundle
- void setInsertPointAfterBundle(ArrayRef<Value *> VL,
- const InstructionsState &S);
+ void setInsertPointAfterBundle(TreeEntry *E);
/// \returns a vector from a collection of scalars in \p VL.
Value *Gather(ArrayRef<Value *> VL, VectorType *Ty);
@@ -1220,27 +1240,37 @@ private:
/// reordering of operands during buildTree_rec() and vectorizeTree().
SmallVector<ValueList, 2> Operands;
+ /// The main/alternate instruction.
+ Instruction *MainOp = nullptr;
+ Instruction *AltOp = nullptr;
+
public:
/// Set this bundle's \p OpIdx'th operand to \p OpVL.
- void setOperand(unsigned OpIdx, ArrayRef<Value *> OpVL,
- ArrayRef<unsigned> ReuseShuffleIndices) {
+ void setOperand(unsigned OpIdx, ArrayRef<Value *> OpVL) {
if (Operands.size() < OpIdx + 1)
Operands.resize(OpIdx + 1);
assert(Operands[OpIdx].size() == 0 && "Already resized?");
Operands[OpIdx].resize(Scalars.size());
for (unsigned Lane = 0, E = Scalars.size(); Lane != E; ++Lane)
- Operands[OpIdx][Lane] = (!ReuseShuffleIndices.empty())
- ? OpVL[ReuseShuffleIndices[Lane]]
- : OpVL[Lane];
- }
-
- /// If there is a user TreeEntry, then set its operand.
- void trySetUserTEOperand(const EdgeInfo &UserTreeIdx,
- ArrayRef<Value *> OpVL,
- ArrayRef<unsigned> ReuseShuffleIndices) {
- if (UserTreeIdx.UserTE)
- UserTreeIdx.UserTE->setOperand(UserTreeIdx.EdgeIdx, OpVL,
- ReuseShuffleIndices);
+ Operands[OpIdx][Lane] = OpVL[Lane];
+ }
+
+ /// Set the operands of this bundle in their original order.
+ void setOperandsInOrder() {
+ assert(Operands.empty() && "Already initialized?");
+ auto *I0 = cast<Instruction>(Scalars[0]);
+ Operands.resize(I0->getNumOperands());
+ unsigned NumLanes = Scalars.size();
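+      // Operands[OpIdx][Lane] is the OpIdx-th operand of the Lane-th scalar,
+      // taken directly from the IR with no operand reordering applied.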
+ for (unsigned OpIdx = 0, NumOperands = I0->getNumOperands();
+ OpIdx != NumOperands; ++OpIdx) {
+ Operands[OpIdx].resize(NumLanes);
+ for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+ auto *I = cast<Instruction>(Scalars[Lane]);
+ assert(I->getNumOperands() == NumOperands &&
+ "Expected same number of operands");
+ Operands[OpIdx][Lane] = I->getOperand(OpIdx);
+ }
+ }
}
/// \returns the \p OpIdx operand of this TreeEntry.
@@ -1249,6 +1279,9 @@ private:
return Operands[OpIdx];
}
+ /// \returns the number of operands.
+ unsigned getNumOperands() const { return Operands.size(); }
+
/// \return the single \p OpIdx operand.
Value *getSingleOperand(unsigned OpIdx) const {
assert(OpIdx < Operands.size() && "Off bounds");
@@ -1256,6 +1289,58 @@ private:
return Operands[OpIdx][0];
}
+ /// Some of the instructions in the list have alternate opcodes.
+ bool isAltShuffle() const {
+ return getOpcode() != getAltOpcode();
+ }
+
+ bool isOpcodeOrAlt(Instruction *I) const {
+ unsigned CheckedOpcode = I->getOpcode();
+ return (getOpcode() == CheckedOpcode ||
+ getAltOpcode() == CheckedOpcode);
+ }
+
+    /// Chooses the correct key for scheduling data. If \p Op has the same (or
+    /// alternate) opcode as this entry's main opcode, the key is \p Op.
+    /// Otherwise the key is the entry's main instruction.
+ Value *isOneOf(Value *Op) const {
+ auto *I = dyn_cast<Instruction>(Op);
+ if (I && isOpcodeOrAlt(I))
+ return Op;
+ return MainOp;
+ }
+
+ void setOperations(const InstructionsState &S) {
+ MainOp = S.MainOp;
+ AltOp = S.AltOp;
+ }
+
+ Instruction *getMainOp() const {
+ return MainOp;
+ }
+
+ Instruction *getAltOp() const {
+ return AltOp;
+ }
+
+ /// The main/alternate opcodes for the list of instructions.
+ unsigned getOpcode() const {
+ return MainOp ? MainOp->getOpcode() : 0;
+ }
+
+ unsigned getAltOpcode() const {
+ return AltOp ? AltOp->getOpcode() : 0;
+ }
+
+    /// Update the operations state of this entry if a reorder occurred.
+ bool updateStateIfReorder() {
+ if (ReorderIndices.empty())
+ return false;
+ InstructionsState S = getSameOpcode(Scalars, ReorderIndices.front());
+ setOperations(S);
+ return true;
+ }
+
#ifndef NDEBUG
/// Debug printer.
LLVM_DUMP_METHOD void dump() const {
@@ -1269,6 +1354,8 @@ private:
for (Value *V : Scalars)
dbgs().indent(2) << *V << "\n";
dbgs() << "NeedToGather: " << NeedToGather << "\n";
+ dbgs() << "MainOp: " << *MainOp << "\n";
+ dbgs() << "AltOp: " << *AltOp << "\n";
dbgs() << "VectorizedValue: ";
if (VectorizedValue)
dbgs() << *VectorizedValue;
@@ -1279,12 +1366,12 @@ private:
if (ReuseShuffleIndices.empty())
dbgs() << "Emtpy";
else
- for (unsigned Idx : ReuseShuffleIndices)
- dbgs() << Idx << ", ";
+ for (unsigned ReuseIdx : ReuseShuffleIndices)
+ dbgs() << ReuseIdx << ", ";
dbgs() << "\n";
dbgs() << "ReorderIndices: ";
- for (unsigned Idx : ReorderIndices)
- dbgs() << Idx << ", ";
+ for (unsigned ReorderIdx : ReorderIndices)
+ dbgs() << ReorderIdx << ", ";
dbgs() << "\n";
dbgs() << "UserTreeIndices: ";
for (const auto &EInfo : UserTreeIndices)
@@ -1295,11 +1382,13 @@ private:
};
/// Create a new VectorizableTree entry.
- TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized,
+ TreeEntry *newTreeEntry(ArrayRef<Value *> VL, Optional<ScheduleData *> Bundle,
+ const InstructionsState &S,
const EdgeInfo &UserTreeIdx,
ArrayRef<unsigned> ReuseShuffleIndices = None,
ArrayRef<unsigned> ReorderIndices = None) {
- VectorizableTree.push_back(llvm::make_unique<TreeEntry>(VectorizableTree));
+ bool Vectorized = (bool)Bundle;
+ VectorizableTree.push_back(std::make_unique<TreeEntry>(VectorizableTree));
TreeEntry *Last = VectorizableTree.back().get();
Last->Idx = VectorizableTree.size() - 1;
Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end());
@@ -1307,11 +1396,22 @@ private:
Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(),
ReuseShuffleIndices.end());
Last->ReorderIndices = ReorderIndices;
+ Last->setOperations(S);
if (Vectorized) {
for (int i = 0, e = VL.size(); i != e; ++i) {
assert(!getTreeEntry(VL[i]) && "Scalar already in tree!");
- ScalarToTreeEntry[VL[i]] = Last->Idx;
- }
+ ScalarToTreeEntry[VL[i]] = Last;
+ }
+ // Update the scheduler bundle to point to this TreeEntry.
+ unsigned Lane = 0;
+ for (ScheduleData *BundleMember = Bundle.getValue(); BundleMember;
+ BundleMember = BundleMember->NextInBundle) {
+ BundleMember->TE = Last;
+ BundleMember->Lane = Lane;
+ ++Lane;
+ }
+ assert((!Bundle.getValue() || Lane == VL.size()) &&
+ "Bundle and VL out of sync");
} else {
MustGather.insert(VL.begin(), VL.end());
}
@@ -1319,7 +1419,6 @@ private:
if (UserTreeIdx.UserTE)
Last->UserTreeIndices.push_back(UserTreeIdx);
- Last->trySetUserTEOperand(UserTreeIdx, VL, ReuseShuffleIndices);
return Last;
}
@@ -1340,19 +1439,19 @@ private:
TreeEntry *getTreeEntry(Value *V) {
auto I = ScalarToTreeEntry.find(V);
if (I != ScalarToTreeEntry.end())
- return VectorizableTree[I->second].get();
+ return I->second;
return nullptr;
}
const TreeEntry *getTreeEntry(Value *V) const {
auto I = ScalarToTreeEntry.find(V);
if (I != ScalarToTreeEntry.end())
- return VectorizableTree[I->second].get();
+ return I->second;
return nullptr;
}
/// Maps a specific scalar to its tree entry.
- SmallDenseMap<Value*, int> ScalarToTreeEntry;
+ SmallDenseMap<Value*, TreeEntry *> ScalarToTreeEntry;
/// A list of scalars that we found that we need to keep as scalars.
ValueSet MustGather;
@@ -1408,15 +1507,14 @@ private:
/// This is required to ensure that there are no incorrect collisions in the
/// AliasCache, which can happen if a new instruction is allocated at the
/// same address as a previously deleted instruction.
- void eraseInstruction(Instruction *I) {
- I->removeFromParent();
- I->dropAllReferences();
- DeletedInstructions.emplace_back(I);
+ void eraseInstruction(Instruction *I, bool ReplaceOpsWithUndef = false) {
+ auto It = DeletedInstructions.try_emplace(I, ReplaceOpsWithUndef).first;
+ It->getSecond() = It->getSecond() && ReplaceOpsWithUndef;
}
/// Temporary store for deleted instructions. Instructions will be deleted
/// eventually when the BoUpSLP is destructed.
- SmallVector<unique_value, 8> DeletedInstructions;
+ DenseMap<Instruction *, bool> DeletedInstructions;
/// A list of values that need to extracted out of the tree.
/// This list holds pairs of (Internal Scalar : External User). External User
@@ -1453,6 +1551,8 @@ private:
UnscheduledDepsInBundle = UnscheduledDeps;
clearDependencies();
OpValue = OpVal;
+ TE = nullptr;
+ Lane = -1;
}
/// Returns true if the dependency information has been calculated.
@@ -1559,6 +1659,12 @@ private:
/// Opcode of the current instruction in the schedule data.
Value *OpValue = nullptr;
+
+ /// The TreeEntry that this instruction corresponds to.
+ TreeEntry *TE = nullptr;
+
+ /// The lane of this node in the TreeEntry.
+ int Lane = -1;
};
#ifndef NDEBUG
@@ -1633,10 +1739,9 @@ private:
continue;
}
// Handle the def-use chain dependencies.
- for (Use &U : BundleMember->Inst->operands()) {
- auto *I = dyn_cast<Instruction>(U.get());
- if (!I)
- continue;
+
+ // Decrement the unscheduled counter and insert to ready list if ready.
+ auto &&DecrUnsched = [this, &ReadyList](Instruction *I) {
doForAllOpcodes(I, [&ReadyList](ScheduleData *OpDef) {
if (OpDef && OpDef->hasValidDependencies() &&
OpDef->incrementUnscheduledDeps(-1) == 0) {
@@ -1651,6 +1756,24 @@ private:
<< "SLP: gets ready (def): " << *DepBundle << "\n");
}
});
+ };
+
+ // If BundleMember is a vector bundle, its operands may have been
+      // reordered during buildTree(). We therefore need to get its operands
+ // through the TreeEntry.
+ if (TreeEntry *TE = BundleMember->TE) {
+ int Lane = BundleMember->Lane;
+ assert(Lane >= 0 && "Lane not set");
+ for (unsigned OpIdx = 0, NumOperands = TE->getNumOperands();
+ OpIdx != NumOperands; ++OpIdx)
+ if (auto *I = dyn_cast<Instruction>(TE->getOperand(OpIdx)[Lane]))
+ DecrUnsched(I);
+ } else {
+ // If BundleMember is a stand-alone instruction, no operand reordering
+ // has taken place, so we directly access its operands.
+ for (Use &U : BundleMember->Inst->operands())
+ if (auto *I = dyn_cast<Instruction>(U.get()))
+ DecrUnsched(I);
}
// Handle the memory dependencies.
for (ScheduleData *MemoryDepSD : BundleMember->MemoryDependencies) {
@@ -1697,8 +1820,11 @@ private:
/// Checks if a bundle of instructions can be scheduled, i.e. has no
/// cyclic dependencies. This is only a dry-run, no instructions are
/// actually moved at this stage.
- bool tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP,
- const InstructionsState &S);
+ /// \returns the scheduling bundle. The returned Optional value is non-None
+ /// if \p VL is allowed to be scheduled.
+ Optional<ScheduleData *>
+ tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP,
+ const InstructionsState &S);
/// Un-bundles a group of instructions.
void cancelScheduling(ArrayRef<Value *> VL, Value *OpValue);
@@ -1945,6 +2071,30 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits {
} // end namespace llvm
+BoUpSLP::~BoUpSLP() {
+ for (const auto &Pair : DeletedInstructions) {
+    // Replace operands of ignored instructions with Undefs if they were
+ // marked for deletion.
+ if (Pair.getSecond()) {
+ Value *Undef = UndefValue::get(Pair.getFirst()->getType());
+ Pair.getFirst()->replaceAllUsesWith(Undef);
+ }
+ Pair.getFirst()->dropAllReferences();
+ }
+ for (const auto &Pair : DeletedInstructions) {
+ assert(Pair.getFirst()->use_empty() &&
+ "trying to erase instruction with users.");
+ Pair.getFirst()->eraseFromParent();
+ }
+}
+
+void BoUpSLP::eraseInstructions(ArrayRef<Value *> AV) {
+ for (auto *V : AV) {
+ if (auto *I = dyn_cast<Instruction>(V))
+      eraseInstruction(I, /*ReplaceOpsWithUndef=*/true);
+  }
+}
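To summarize the deferred-deletion contract introduced here (a sketch only; V stands for the BoUpSLP instance and BB for a block being scanned, as at the call sites in this patch): erasure merely marks instructions, later traversals must consult isDeleted(), and the destructor above does the real clean-up in two passes so that no instruction is erased while another marked instruction still uses it.

    V.eraseInstructions(IgnoreList);   // mark only; the IR is left untouched for now
    for (Instruction &I : *BB) {
      if (V.isDeleted(&I))
        continue;                      // skip anything a previous vectorization marked
      // process I as usual
    }
    // ~BoUpSLP(): pass 1 replaces uses of flagged instructions with undef and drops
    // all references; pass 2 erases them, asserting use_empty() on each.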
+
void BoUpSLP::buildTree(ArrayRef<Value *> Roots,
ArrayRef<Value *> UserIgnoreLst) {
ExtraValueToDebugLocsMap ExternallyUsedValues;
@@ -2026,28 +2176,28 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
InstructionsState S = getSameOpcode(VL);
if (Depth == RecursionMaxDepth) {
LLVM_DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
// Don't handle vectors.
if (S.OpValue->getType()->isVectorTy()) {
LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue))
if (SI->getValueOperand()->getType()->isVectorTy()) {
LLVM_DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
// If all of the operands are identical or constant we have a simple solution.
if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.getOpcode()) {
LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
@@ -2055,11 +2205,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// the same block.
// Don't vectorize ephemeral values.
- for (unsigned i = 0, e = VL.size(); i != e; ++i) {
- if (EphValues.count(VL[i])) {
- LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *VL[i]
+ for (Value *V : VL) {
+ if (EphValues.count(V)) {
+ LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V
<< ") is ephemeral.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
}
@@ -2069,7 +2219,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.OpValue << ".\n");
if (!E->isSame(VL)) {
LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
// Record the reuse of the tree node. FIXME, currently this is only used to
@@ -2077,19 +2227,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
E->UserTreeIndices.push_back(UserTreeIdx);
LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue
<< ".\n");
- E->trySetUserTEOperand(UserTreeIdx, VL, None);
return;
}
// Check that none of the instructions in the bundle are already in the tree.
- for (unsigned i = 0, e = VL.size(); i != e; ++i) {
- auto *I = dyn_cast<Instruction>(VL[i]);
+ for (Value *V : VL) {
+ auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
if (getTreeEntry(I)) {
- LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *VL[i]
+ LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V
<< ") is already in tree.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
}
@@ -2097,10 +2246,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// If any of the scalars is marked as a value that needs to stay scalar, then
// we need to gather the scalars.
// The reduction nodes (stored in UserIgnoreList) also should stay scalar.
- for (unsigned i = 0, e = VL.size(); i != e; ++i) {
- if (MustGather.count(VL[i]) || is_contained(UserIgnoreList, VL[i])) {
+ for (Value *V : VL) {
+ if (MustGather.count(V) || is_contained(UserIgnoreList, V)) {
LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
}
@@ -2114,7 +2263,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// Don't go into unreachable blocks. They may contain instructions with
// dependency cycles which confuse the final scheduling.
LLVM_DEBUG(dbgs() << "SLP: bundle in unreachable block.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
@@ -2128,13 +2277,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (Res.second)
UniqueValues.emplace_back(V);
}
- if (UniqueValues.size() == VL.size()) {
+ size_t NumUniqueScalarValues = UniqueValues.size();
+ if (NumUniqueScalarValues == VL.size()) {
ReuseShuffleIndicies.clear();
} else {
LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n");
- if (UniqueValues.size() <= 1 || !llvm::isPowerOf2_32(UniqueValues.size())) {
+ if (NumUniqueScalarValues <= 1 ||
+ !llvm::isPowerOf2_32(NumUniqueScalarValues)) {
LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n");
- newTreeEntry(VL, false, UserTreeIdx);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
return;
}
VL = UniqueValues;
@@ -2142,16 +2293,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
auto &BSRef = BlocksSchedules[BB];
if (!BSRef)
- BSRef = llvm::make_unique<BlockScheduling>(BB);
+ BSRef = std::make_unique<BlockScheduling>(BB);
BlockScheduling &BS = *BSRef.get();
- if (!BS.tryScheduleBundle(VL, this, S)) {
+ Optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S);
+ if (!Bundle) {
LLVM_DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n");
assert((!BS.getScheduleData(VL0) ||
!BS.getScheduleData(VL0)->isPartOfBundle()) &&
"tryScheduleBundle should cancelScheduling on failure");
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
return;
}
LLVM_DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n");
@@ -2160,7 +2313,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
(unsigned) Instruction::ShuffleVector : S.getOpcode();
switch (ShuffleOrOp) {
case Instruction::PHI: {
- PHINode *PH = dyn_cast<PHINode>(VL0);
+ auto *PH = cast<PHINode>(VL0);
// Check for terminator values (e.g. invoke).
for (unsigned j = 0; j < VL.size(); ++j)
@@ -2172,23 +2325,29 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
LLVM_DEBUG(dbgs()
<< "SLP: Need to swizzle PHINodes (terminator use).\n");
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
return;
}
}
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE =
+ newTreeEntry(VL, Bundle, S, UserTreeIdx, ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n");
+ // Keeps the reordered operands to avoid code duplication.
+ SmallVector<ValueList, 2> OperandsVec;
for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
ValueList Operands;
// Prepare the operand vector.
for (Value *j : VL)
Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock(
PH->getIncomingBlock(i)));
-
- buildTree_rec(Operands, Depth + 1, {TE, i});
+ TE->setOperand(i, Operands);
+ OperandsVec.push_back(Operands);
}
+ for (unsigned OpIdx = 0, OpE = OperandsVec.size(); OpIdx != OpE; ++OpIdx)
+ buildTree_rec(OperandsVec[OpIdx], Depth + 1, {TE, OpIdx});
return;
}
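The operands are recorded on the TreeEntry before recursing because the scheduler no longer reads a bundled instruction's operands directly: once they may have been reordered, the per-lane operands come from the entry. A condensed view of the consumer added to the scheduling code earlier in this patch:

    if (TreeEntry *TE = BundleMember->TE) {
      int Lane = BundleMember->Lane;
      for (unsigned OpIdx = 0, NumOps = TE->getNumOperands(); OpIdx != NumOps; ++OpIdx)
        if (auto *I = dyn_cast<Instruction>(TE->getOperand(OpIdx)[Lane]))
          DecrUnsched(I); // operand (OpIdx, Lane) exactly as setOperand() stored it
    }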
case Instruction::ExtractValue:
@@ -2198,13 +2357,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (Reuse) {
LLVM_DEBUG(dbgs() << "SLP: Reusing or shuffling extract sequence.\n");
++NumOpsWantToKeepOriginalOrder;
- newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx,
+ newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
// This is a special case, as it does not gather, but at the same time
// we are not extending buildTree_rec() towards the operands.
ValueList Op0;
Op0.assign(VL.size(), VL0->getOperand(0));
- VectorizableTree.back()->setOperand(0, Op0, ReuseShuffleIndicies);
+ VectorizableTree.back()->setOperand(0, Op0);
return;
}
if (!CurrentOrder.empty()) {
@@ -2220,17 +2379,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
auto StoredCurrentOrderAndNum =
NumOpsWantToKeepOrder.try_emplace(CurrentOrder).first;
++StoredCurrentOrderAndNum->getSecond();
- newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, ReuseShuffleIndicies,
+ newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies,
StoredCurrentOrderAndNum->getFirst());
// This is a special case, as it does not gather, but at the same time
// we are not extending buildTree_rec() towards the operands.
ValueList Op0;
Op0.assign(VL.size(), VL0->getOperand(0));
- VectorizableTree.back()->setOperand(0, Op0, ReuseShuffleIndicies);
+ VectorizableTree.back()->setOperand(0, Op0);
return;
}
LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n");
- newTreeEntry(VL, /*Vectorized=*/false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
BS.cancelScheduling(VL, VL0);
return;
}
@@ -2246,7 +2407,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (DL->getTypeSizeInBits(ScalarTy) !=
DL->getTypeAllocSizeInBits(ScalarTy)) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n");
return;
}
@@ -2259,7 +2421,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
auto *L = cast<LoadInst>(V);
if (!L->isSimple()) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n");
return;
}
@@ -2289,15 +2452,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (CurrentOrder.empty()) {
        // Original loads are consecutive and do not require reordering.
++NumOpsWantToKeepOriginalOrder;
- newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx,
- ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S,
+ UserTreeIdx, ReuseShuffleIndicies);
+ TE->setOperandsInOrder();
LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n");
} else {
// Need to reorder.
auto I = NumOpsWantToKeepOrder.try_emplace(CurrentOrder).first;
++I->getSecond();
- newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx,
- ReuseShuffleIndicies, I->getFirst());
+ TreeEntry *TE =
+ newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies, I->getFirst());
+ TE->setOperandsInOrder();
LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n");
}
return;
@@ -2306,7 +2472,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n");
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
return;
}
case Instruction::ZExt:
@@ -2322,24 +2489,27 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
case Instruction::FPTrunc:
case Instruction::BitCast: {
Type *SrcTy = VL0->getOperand(0)->getType();
- for (unsigned i = 0; i < VL.size(); ++i) {
- Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType();
+ for (Value *V : VL) {
+ Type *Ty = cast<Instruction>(V)->getOperand(0)->getType();
if (Ty != SrcTy || !isValidElementType(Ty)) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs()
<< "SLP: Gathering casts with different src types.\n");
return;
}
}
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n");
+ TE->setOperandsInOrder();
for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
ValueList Operands;
// Prepare the operand vector.
- for (Value *j : VL)
- Operands.push_back(cast<Instruction>(j)->getOperand(i));
+ for (Value *V : VL)
+ Operands.push_back(cast<Instruction>(V)->getOperand(i));
buildTree_rec(Operands, Depth + 1, {TE, i});
}
@@ -2351,19 +2521,21 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate();
CmpInst::Predicate SwapP0 = CmpInst::getSwappedPredicate(P0);
Type *ComparedTy = VL0->getOperand(0)->getType();
- for (unsigned i = 1, e = VL.size(); i < e; ++i) {
- CmpInst *Cmp = cast<CmpInst>(VL[i]);
+ for (Value *V : VL) {
+ CmpInst *Cmp = cast<CmpInst>(V);
if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) ||
Cmp->getOperand(0)->getType() != ComparedTy) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs()
<< "SLP: Gathering cmp with different predicate.\n");
return;
}
}
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n");
ValueList Left, Right;
@@ -2384,7 +2556,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
Right.push_back(RHS);
}
}
-
+ TE->setOperand(0, Left);
+ TE->setOperand(1, Right);
buildTree_rec(Left, Depth + 1, {TE, 0});
buildTree_rec(Right, Depth + 1, {TE, 1});
return;
@@ -2409,7 +2582,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
case Instruction::And:
case Instruction::Or:
case Instruction::Xor: {
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of un/bin op.\n");
// Sort operands of the instructions so that each side is more likely to
@@ -2417,11 +2591,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) {
ValueList Left, Right;
reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE);
+ TE->setOperand(0, Left);
+ TE->setOperand(1, Right);
buildTree_rec(Left, Depth + 1, {TE, 0});
buildTree_rec(Right, Depth + 1, {TE, 1});
return;
}
+ TE->setOperandsInOrder();
for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
ValueList Operands;
// Prepare the operand vector.
@@ -2434,11 +2611,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
case Instruction::GetElementPtr: {
// We don't combine GEPs with complicated (nested) indexing.
- for (unsigned j = 0; j < VL.size(); ++j) {
- if (cast<Instruction>(VL[j])->getNumOperands() != 2) {
+ for (Value *V : VL) {
+ if (cast<Instruction>(V)->getNumOperands() != 2) {
LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n");
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
return;
}
}
@@ -2446,58 +2624,64 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// We can't combine several GEPs into one vector if they operate on
// different types.
Type *Ty0 = VL0->getOperand(0)->getType();
- for (unsigned j = 0; j < VL.size(); ++j) {
- Type *CurTy = cast<Instruction>(VL[j])->getOperand(0)->getType();
+ for (Value *V : VL) {
+ Type *CurTy = cast<Instruction>(V)->getOperand(0)->getType();
if (Ty0 != CurTy) {
LLVM_DEBUG(dbgs()
<< "SLP: not-vectorizable GEP (different types).\n");
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
return;
}
}
// We don't combine GEPs with non-constant indexes.
- for (unsigned j = 0; j < VL.size(); ++j) {
- auto Op = cast<Instruction>(VL[j])->getOperand(1);
+ for (Value *V : VL) {
+ auto Op = cast<Instruction>(V)->getOperand(1);
if (!isa<ConstantInt>(Op)) {
LLVM_DEBUG(dbgs()
<< "SLP: not-vectorizable GEP (non-constant indexes).\n");
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
return;
}
}
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n");
+ TE->setOperandsInOrder();
for (unsigned i = 0, e = 2; i < e; ++i) {
ValueList Operands;
// Prepare the operand vector.
- for (Value *j : VL)
- Operands.push_back(cast<Instruction>(j)->getOperand(i));
+ for (Value *V : VL)
+ Operands.push_back(cast<Instruction>(V)->getOperand(i));
buildTree_rec(Operands, Depth + 1, {TE, i});
}
return;
}
case Instruction::Store: {
- // Check if the stores are consecutive or of we need to swizzle them.
+ // Check if the stores are consecutive or if we need to swizzle them.
for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n");
return;
}
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n");
ValueList Operands;
- for (Value *j : VL)
- Operands.push_back(cast<Instruction>(j)->getOperand(0));
-
+ for (Value *V : VL)
+ Operands.push_back(cast<Instruction>(V)->getOperand(0));
+ TE->setOperandsInOrder();
buildTree_rec(Operands, Depth + 1, {TE, 0});
return;
}
@@ -2509,7 +2693,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
if (!isTriviallyVectorizable(ID)) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
return;
}
@@ -2519,14 +2704,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
for (unsigned j = 0; j != NumArgs; ++j)
if (hasVectorInstrinsicScalarOpd(ID, j))
ScalarArgs[j] = CI->getArgOperand(j);
- for (unsigned i = 1, e = VL.size(); i != e; ++i) {
- CallInst *CI2 = dyn_cast<CallInst>(VL[i]);
+ for (Value *V : VL) {
+ CallInst *CI2 = dyn_cast<CallInst>(V);
if (!CI2 || CI2->getCalledFunction() != Int ||
getVectorIntrinsicIDForCall(CI2, TLI) != ID ||
!CI->hasIdenticalOperandBundleSchema(*CI2)) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i]
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
+ LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *V
<< "\n");
return;
}
@@ -2537,7 +2723,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
Value *A1J = CI2->getArgOperand(j);
if (ScalarArgs[j] != A1J) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI
<< " argument " << ScalarArgs[j] << "!=" << A1J
<< "\n");
@@ -2551,19 +2738,22 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
CI->op_begin() + CI->getBundleOperandsEndIndex(),
CI2->op_begin() + CI2->getBundleOperandsStartIndex())) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:"
- << *CI << "!=" << *VL[i] << '\n');
+ << *CI << "!=" << *V << '\n');
return;
}
}
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
+ TE->setOperandsInOrder();
for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) {
ValueList Operands;
// Prepare the operand vector.
- for (Value *j : VL) {
- CallInst *CI2 = dyn_cast<CallInst>(j);
+ for (Value *V : VL) {
+ auto *CI2 = cast<CallInst>(V);
Operands.push_back(CI2->getArgOperand(i));
}
buildTree_rec(Operands, Depth + 1, {TE, i});
@@ -2575,27 +2765,32 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// then do not vectorize this instruction.
if (!S.isAltShuffle()) {
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n");
return;
}
- auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n");
// Reorder operands if reordering would enable vectorization.
if (isa<BinaryOperator>(VL0)) {
ValueList Left, Right;
reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE);
+ TE->setOperand(0, Left);
+ TE->setOperand(1, Right);
buildTree_rec(Left, Depth + 1, {TE, 0});
buildTree_rec(Right, Depth + 1, {TE, 1});
return;
}
+ TE->setOperandsInOrder();
for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
ValueList Operands;
// Prepare the operand vector.
- for (Value *j : VL)
- Operands.push_back(cast<Instruction>(j)->getOperand(i));
+ for (Value *V : VL)
+ Operands.push_back(cast<Instruction>(V)->getOperand(i));
buildTree_rec(Operands, Depth + 1, {TE, i});
}
@@ -2603,7 +2798,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
default:
BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies);
+ newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n");
return;
}
@@ -2738,7 +2934,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
return ReuseShuffleCost +
TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0);
}
- if (getSameOpcode(VL).getOpcode() == Instruction::ExtractElement &&
+ if (E->getOpcode() == Instruction::ExtractElement &&
allSameType(VL) && allSameBlock(VL)) {
Optional<TargetTransformInfo::ShuffleKind> ShuffleKind = isShuffle(VL);
if (ShuffleKind.hasValue()) {
@@ -2761,11 +2957,10 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
}
return ReuseShuffleCost + getGatherCost(VL);
}
- InstructionsState S = getSameOpcode(VL);
- assert(S.getOpcode() && allSameType(VL) && allSameBlock(VL) && "Invalid VL");
- Instruction *VL0 = cast<Instruction>(S.OpValue);
- unsigned ShuffleOrOp = S.isAltShuffle() ?
- (unsigned) Instruction::ShuffleVector : S.getOpcode();
+ assert(E->getOpcode() && allSameType(VL) && allSameBlock(VL) && "Invalid VL");
+ Instruction *VL0 = E->getMainOp();
+ unsigned ShuffleOrOp =
+ E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode();
switch (ShuffleOrOp) {
case Instruction::PHI:
return 0;
@@ -2851,7 +3046,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
case Instruction::BitCast: {
Type *SrcTy = VL0->getOperand(0)->getType();
int ScalarEltCost =
- TTI->getCastInstrCost(S.getOpcode(), ScalarTy, SrcTy, VL0);
+ TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, VL0);
if (NeedToShuffleReuses) {
ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost;
}
@@ -2864,7 +3059,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
// Check if the values are candidates to demote.
if (!MinBWs.count(VL0) || VecTy != SrcVecTy) {
VecCost = ReuseShuffleCost +
- TTI->getCastInstrCost(S.getOpcode(), VecTy, SrcVecTy, VL0);
+ TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy, VL0);
}
return VecCost - ScalarCost;
}
@@ -2872,14 +3067,14 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
case Instruction::ICmp:
case Instruction::Select: {
// Calculate the cost of this instruction.
- int ScalarEltCost = TTI->getCmpSelInstrCost(S.getOpcode(), ScalarTy,
+ int ScalarEltCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy,
Builder.getInt1Ty(), VL0);
if (NeedToShuffleReuses) {
ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost;
}
VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
int ScalarCost = VecTy->getNumElements() * ScalarEltCost;
- int VecCost = TTI->getCmpSelInstrCost(S.getOpcode(), VecTy, MaskTy, VL0);
+ int VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, VL0);
return ReuseShuffleCost + VecCost - ScalarCost;
}
case Instruction::FNeg:
@@ -2940,12 +3135,12 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
SmallVector<const Value *, 4> Operands(VL0->operand_values());
int ScalarEltCost = TTI->getArithmeticInstrCost(
- S.getOpcode(), ScalarTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands);
+ E->getOpcode(), ScalarTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands);
if (NeedToShuffleReuses) {
ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost;
}
int ScalarCost = VecTy->getNumElements() * ScalarEltCost;
- int VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy, Op1VK,
+ int VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy, Op1VK,
Op2VK, Op1VP, Op2VP, Operands);
return ReuseShuffleCost + VecCost - ScalarCost;
}
@@ -3027,11 +3222,11 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
return ReuseShuffleCost + VecCallCost - ScalarCallCost;
}
case Instruction::ShuffleVector: {
- assert(S.isAltShuffle() &&
- ((Instruction::isBinaryOp(S.getOpcode()) &&
- Instruction::isBinaryOp(S.getAltOpcode())) ||
- (Instruction::isCast(S.getOpcode()) &&
- Instruction::isCast(S.getAltOpcode()))) &&
+ assert(E->isAltShuffle() &&
+ ((Instruction::isBinaryOp(E->getOpcode()) &&
+ Instruction::isBinaryOp(E->getAltOpcode())) ||
+ (Instruction::isCast(E->getOpcode()) &&
+ Instruction::isCast(E->getAltOpcode()))) &&
"Invalid Shuffle Vector Operand");
int ScalarCost = 0;
if (NeedToShuffleReuses) {
@@ -3046,25 +3241,25 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
I, TargetTransformInfo::TCK_RecipThroughput);
}
}
- for (Value *i : VL) {
- Instruction *I = cast<Instruction>(i);
- assert(S.isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
+ for (Value *V : VL) {
+ Instruction *I = cast<Instruction>(V);
+ assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
ScalarCost += TTI->getInstructionCost(
I, TargetTransformInfo::TCK_RecipThroughput);
}
// VecCost is equal to sum of the cost of creating 2 vectors
// and the cost of creating shuffle.
int VecCost = 0;
- if (Instruction::isBinaryOp(S.getOpcode())) {
- VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy);
- VecCost += TTI->getArithmeticInstrCost(S.getAltOpcode(), VecTy);
+ if (Instruction::isBinaryOp(E->getOpcode())) {
+ VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy);
+ VecCost += TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy);
} else {
- Type *Src0SclTy = S.MainOp->getOperand(0)->getType();
- Type *Src1SclTy = S.AltOp->getOperand(0)->getType();
+ Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType();
+ Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType();
VectorType *Src0Ty = VectorType::get(Src0SclTy, VL.size());
VectorType *Src1Ty = VectorType::get(Src1SclTy, VL.size());
- VecCost = TTI->getCastInstrCost(S.getOpcode(), VecTy, Src0Ty);
- VecCost += TTI->getCastInstrCost(S.getAltOpcode(), VecTy, Src1Ty);
+ VecCost = TTI->getCastInstrCost(E->getOpcode(), VecTy, Src0Ty);
+ VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty);
}
VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0);
return ReuseShuffleCost + VecCost - ScalarCost;
@@ -3098,6 +3293,43 @@ bool BoUpSLP::isFullyVectorizableTinyTree() const {
return true;
}
+bool BoUpSLP::isLoadCombineReductionCandidate(unsigned RdxOpcode) const {
+ if (RdxOpcode != Instruction::Or)
+ return false;
+
+ unsigned NumElts = VectorizableTree[0]->Scalars.size();
+ Value *FirstReduced = VectorizableTree[0]->Scalars[0];
+
+ // Look past the reduction to find a source value. Arbitrarily follow the
+ // path through operand 0 of any 'or'. Also, peek through optional
+ // shift-left-by-constant.
+ Value *ZextLoad = FirstReduced;
+ while (match(ZextLoad, m_Or(m_Value(), m_Value())) ||
+ match(ZextLoad, m_Shl(m_Value(), m_Constant())))
+ ZextLoad = cast<BinaryOperator>(ZextLoad)->getOperand(0);
+
+ // Check if the input to the reduction is an extended load.
+ Value *LoadPtr;
+ if (!match(ZextLoad, m_ZExt(m_Load(m_Value(LoadPtr)))))
+ return false;
+
+ // Require that the total load bit width is a legal integer type.
+ // For example, <8 x i8> --> i64 is a legal integer on a 64-bit target.
+ // But <16 x i8> --> i128 is not, so the backend probably can't reduce it.
+ Type *SrcTy = LoadPtr->getType()->getPointerElementType();
+ unsigned LoadBitWidth = SrcTy->getIntegerBitWidth() * NumElts;
+ LLVMContext &Context = FirstReduced->getContext();
+ if (!TTI->isTypeLegal(IntegerType::get(Context, LoadBitWidth)))
+ return false;
+
+ // Everything matched - assume that we can fold the whole sequence using
+ // load combining.
+ LLVM_DEBUG(dbgs() << "SLP: Assume load combining for scalar reduction of "
+ << *(cast<Instruction>(FirstReduced)) << "\n");
+
+ return true;
+}
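To make the new guard concrete, the scalar shape it accepts is an 'or' reduction over zero-extended loads, each optionally shifted left by a constant, i.e. the classic byte-assembly idiom (illustrative C-level source only; the function itself matches the IR with PatternMatch):

    // uint32_t Combined = (uint32_t)P[0]
    //                   | ((uint32_t)P[1] << 8)
    //                   | ((uint32_t)P[2] << 16)
    //                   | ((uint32_t)P[3] << 24);
    // Here NumElts = 4 loads of i8, so LoadBitWidth = 4 * 8 = 32, a legal integer
    // on typical targets: the function returns true and the caller skips SLP
    // vectorization, leaving the pattern to the backend's load combining. A
    // 16 x i8 variant would need i128, the legality check fails, and SLP
    // proceeds as usual.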
+
bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const {
// We can vectorize the tree if its size is greater than or equal to the
// minimum size specified by the MinTreeSize command line option.
@@ -3319,16 +3551,16 @@ void BoUpSLP::reorderInputsAccordingToOpcode(
Right = Ops.getVL(1);
}
-void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL,
- const InstructionsState &S) {
+void BoUpSLP::setInsertPointAfterBundle(TreeEntry *E) {
// Get the basic block this bundle is in. All instructions in the bundle
// should be in this block.
- auto *Front = cast<Instruction>(S.OpValue);
+ auto *Front = E->getMainOp();
auto *BB = Front->getParent();
- assert(llvm::all_of(make_range(VL.begin(), VL.end()), [=](Value *V) -> bool {
- auto *I = cast<Instruction>(V);
- return !S.isOpcodeOrAlt(I) || I->getParent() == BB;
- }));
+ assert(llvm::all_of(make_range(E->Scalars.begin(), E->Scalars.end()),
+ [=](Value *V) -> bool {
+ auto *I = cast<Instruction>(V);
+ return !E->isOpcodeOrAlt(I) || I->getParent() == BB;
+ }));
// The last instruction in the bundle in program order.
Instruction *LastInst = nullptr;
@@ -3339,7 +3571,7 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL,
// bundle. The end of the bundle is marked by null ScheduleData.
if (BlocksSchedules.count(BB)) {
auto *Bundle =
- BlocksSchedules[BB]->getScheduleData(isOneOf(S, VL.back()));
+ BlocksSchedules[BB]->getScheduleData(E->isOneOf(E->Scalars.back()));
if (Bundle && Bundle->isPartOfBundle())
for (; Bundle; Bundle = Bundle->NextInBundle)
if (Bundle->OpValue == Bundle->Inst)
@@ -3365,14 +3597,15 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL,
  // we both exit early from buildTree_rec and the bundle is out-of-order
// (causing us to iterate all the way to the end of the block).
if (!LastInst) {
- SmallPtrSet<Value *, 16> Bundle(VL.begin(), VL.end());
+ SmallPtrSet<Value *, 16> Bundle(E->Scalars.begin(), E->Scalars.end());
for (auto &I : make_range(BasicBlock::iterator(Front), BB->end())) {
- if (Bundle.erase(&I) && S.isOpcodeOrAlt(&I))
+ if (Bundle.erase(&I) && E->isOpcodeOrAlt(&I))
LastInst = &I;
if (Bundle.empty())
break;
}
}
+ assert(LastInst && "Failed to find last instruction in bundle");
// Set the insertion point after the last instruction in the bundle. Set the
// debug location to Front.
@@ -3385,7 +3618,7 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) {
// Generate the 'InsertElement' instruction.
for (unsigned i = 0; i < Ty->getNumElements(); ++i) {
Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i));
- if (Instruction *Insrt = dyn_cast<Instruction>(Vec)) {
+ if (auto *Insrt = dyn_cast<InsertElementInst>(Vec)) {
GatherSeq.insert(Insrt);
CSEBlocks.insert(Insrt->getParent());
@@ -3494,8 +3727,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return E->VectorizedValue;
}
- InstructionsState S = getSameOpcode(E->Scalars);
- Instruction *VL0 = cast<Instruction>(S.OpValue);
+ Instruction *VL0 = E->getMainOp();
Type *ScalarTy = VL0->getType();
if (StoreInst *SI = dyn_cast<StoreInst>(VL0))
ScalarTy = SI->getValueOperand()->getType();
@@ -3504,7 +3736,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty();
if (E->NeedToGather) {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
auto *V = Gather(E->Scalars, VecTy);
if (NeedToShuffleReuses) {
V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy),
@@ -3518,11 +3750,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return V;
}
- unsigned ShuffleOrOp = S.isAltShuffle() ?
- (unsigned) Instruction::ShuffleVector : S.getOpcode();
+ unsigned ShuffleOrOp =
+ E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode();
switch (ShuffleOrOp) {
case Instruction::PHI: {
- PHINode *PH = dyn_cast<PHINode>(VL0);
+ auto *PH = cast<PHINode>(VL0);
Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI());
Builder.SetCurrentDebugLocation(PH->getDebugLoc());
PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues());
@@ -3577,7 +3809,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
E->VectorizedValue = V;
return V;
}
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
auto *V = Gather(E->Scalars, VecTy);
if (NeedToShuffleReuses) {
V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy),
@@ -3612,7 +3844,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
E->VectorizedValue = NewV;
return NewV;
}
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
auto *V = Gather(E->Scalars, VecTy);
if (NeedToShuffleReuses) {
V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy),
@@ -3637,7 +3869,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast: {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
Value *InVec = vectorizeTree(E->getOperand(0));
@@ -3646,7 +3878,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return E->VectorizedValue;
}
- CastInst *CI = dyn_cast<CastInst>(VL0);
+ auto *CI = cast<CastInst>(VL0);
Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
if (NeedToShuffleReuses) {
V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy),
@@ -3658,7 +3890,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
case Instruction::FCmp:
case Instruction::ICmp: {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
Value *L = vectorizeTree(E->getOperand(0));
Value *R = vectorizeTree(E->getOperand(1));
@@ -3670,7 +3902,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate();
Value *V;
- if (S.getOpcode() == Instruction::FCmp)
+ if (E->getOpcode() == Instruction::FCmp)
V = Builder.CreateFCmp(P0, L, R);
else
V = Builder.CreateICmp(P0, L, R);
@@ -3685,7 +3917,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return V;
}
case Instruction::Select: {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
Value *Cond = vectorizeTree(E->getOperand(0));
Value *True = vectorizeTree(E->getOperand(1));
@@ -3706,7 +3938,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return V;
}
case Instruction::FNeg: {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
Value *Op = vectorizeTree(E->getOperand(0));
@@ -3716,7 +3948,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
Value *V = Builder.CreateUnOp(
- static_cast<Instruction::UnaryOps>(S.getOpcode()), Op);
+ static_cast<Instruction::UnaryOps>(E->getOpcode()), Op);
propagateIRFlags(V, E->Scalars, VL0);
if (auto *I = dyn_cast<Instruction>(V))
V = propagateMetadata(I, E->Scalars);
@@ -3748,7 +3980,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor: {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
Value *LHS = vectorizeTree(E->getOperand(0));
Value *RHS = vectorizeTree(E->getOperand(1));
@@ -3759,7 +3991,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
Value *V = Builder.CreateBinOp(
- static_cast<Instruction::BinaryOps>(S.getOpcode()), LHS, RHS);
+ static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
+ RHS);
propagateIRFlags(V, E->Scalars, VL0);
if (auto *I = dyn_cast<Instruction>(V))
V = propagateMetadata(I, E->Scalars);
@@ -3776,12 +4009,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::Load: {
// Loads are inserted at the head of the tree because we don't want to
// sink them all the way down past store instructions.
- bool IsReorder = !E->ReorderIndices.empty();
- if (IsReorder) {
- S = getSameOpcode(E->Scalars, E->ReorderIndices.front());
- VL0 = cast<Instruction>(S.OpValue);
- }
- setInsertPointAfterBundle(E->Scalars, S);
+ bool IsReorder = E->updateStateIfReorder();
+ if (IsReorder)
+ VL0 = E->getMainOp();
+ setInsertPointAfterBundle(E);
LoadInst *LI = cast<LoadInst>(VL0);
Type *ScalarLoadTy = LI->getType();
@@ -3797,11 +4028,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
if (getTreeEntry(PO))
ExternalUses.push_back(ExternalUser(PO, cast<User>(VecPtr), 0));
- unsigned Alignment = LI->getAlignment();
+ MaybeAlign Alignment = MaybeAlign(LI->getAlignment());
LI = Builder.CreateLoad(VecTy, VecPtr);
- if (!Alignment) {
- Alignment = DL->getABITypeAlignment(ScalarLoadTy);
- }
+ if (!Alignment)
+ Alignment = MaybeAlign(DL->getABITypeAlignment(ScalarLoadTy));
LI->setAlignment(Alignment);
Value *V = propagateMetadata(LI, E->Scalars);
if (IsReorder) {
@@ -3824,7 +4054,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
unsigned Alignment = SI->getAlignment();
unsigned AS = SI->getPointerAddressSpace();
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
Value *VecValue = vectorizeTree(E->getOperand(0));
Value *ScalarPtr = SI->getPointerOperand();
@@ -3840,7 +4070,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
if (!Alignment)
Alignment = DL->getABITypeAlignment(SI->getValueOperand()->getType());
- ST->setAlignment(Alignment);
+ ST->setAlignment(Align(Alignment));
Value *V = propagateMetadata(ST, E->Scalars);
if (NeedToShuffleReuses) {
V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy),
@@ -3851,7 +4081,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return V;
}
case Instruction::GetElementPtr: {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
Value *Op0 = vectorizeTree(E->getOperand(0));
@@ -3878,13 +4108,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
case Instruction::Call: {
CallInst *CI = cast<CallInst>(VL0);
- setInsertPointAfterBundle(E->Scalars, S);
- Function *FI;
+ setInsertPointAfterBundle(E);
+
Intrinsic::ID IID = Intrinsic::not_intrinsic;
- Value *ScalarArg = nullptr;
- if (CI && (FI = CI->getCalledFunction())) {
+ if (Function *FI = CI->getCalledFunction())
IID = FI->getIntrinsicID();
- }
+
+ Value *ScalarArg = nullptr;
std::vector<Value *> OpVecs;
for (int j = 0, e = CI->getNumArgOperands(); j < e; ++j) {
ValueList OpVL;
@@ -3926,20 +4156,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return V;
}
case Instruction::ShuffleVector: {
- assert(S.isAltShuffle() &&
- ((Instruction::isBinaryOp(S.getOpcode()) &&
- Instruction::isBinaryOp(S.getAltOpcode())) ||
- (Instruction::isCast(S.getOpcode()) &&
- Instruction::isCast(S.getAltOpcode()))) &&
+ assert(E->isAltShuffle() &&
+ ((Instruction::isBinaryOp(E->getOpcode()) &&
+ Instruction::isBinaryOp(E->getAltOpcode())) ||
+ (Instruction::isCast(E->getOpcode()) &&
+ Instruction::isCast(E->getAltOpcode()))) &&
"Invalid Shuffle Vector Operand");
- Value *LHS, *RHS;
- if (Instruction::isBinaryOp(S.getOpcode())) {
- setInsertPointAfterBundle(E->Scalars, S);
+ Value *LHS = nullptr, *RHS = nullptr;
+ if (Instruction::isBinaryOp(E->getOpcode())) {
+ setInsertPointAfterBundle(E);
LHS = vectorizeTree(E->getOperand(0));
RHS = vectorizeTree(E->getOperand(1));
} else {
- setInsertPointAfterBundle(E->Scalars, S);
+ setInsertPointAfterBundle(E);
LHS = vectorizeTree(E->getOperand(0));
}
@@ -3949,16 +4179,16 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
Value *V0, *V1;
- if (Instruction::isBinaryOp(S.getOpcode())) {
+ if (Instruction::isBinaryOp(E->getOpcode())) {
V0 = Builder.CreateBinOp(
- static_cast<Instruction::BinaryOps>(S.getOpcode()), LHS, RHS);
+ static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS);
V1 = Builder.CreateBinOp(
- static_cast<Instruction::BinaryOps>(S.getAltOpcode()), LHS, RHS);
+ static_cast<Instruction::BinaryOps>(E->getAltOpcode()), LHS, RHS);
} else {
V0 = Builder.CreateCast(
- static_cast<Instruction::CastOps>(S.getOpcode()), LHS, VecTy);
+ static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);
V1 = Builder.CreateCast(
- static_cast<Instruction::CastOps>(S.getAltOpcode()), LHS, VecTy);
+ static_cast<Instruction::CastOps>(E->getAltOpcode()), LHS, VecTy);
}
// Create shuffle to take alternate operations from the vector.
@@ -3969,8 +4199,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
SmallVector<Constant *, 8> Mask(e);
for (unsigned i = 0; i < e; ++i) {
auto *OpInst = cast<Instruction>(E->Scalars[i]);
- assert(S.isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode");
- if (OpInst->getOpcode() == S.getAltOpcode()) {
+ assert(E->isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode");
+ if (OpInst->getOpcode() == E->getAltOpcode()) {
Mask[i] = Builder.getInt32(e + i);
AltScalars.push_back(E->Scalars[i]);
} else {
@@ -4136,20 +4366,18 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) {
for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) {
Value *Scalar = Entry->Scalars[Lane];
+#ifndef NDEBUG
Type *Ty = Scalar->getType();
if (!Ty->isVoidTy()) {
-#ifndef NDEBUG
for (User *U : Scalar->users()) {
LLVM_DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n");
- // It is legal to replace users in the ignorelist by undef.
+ // It is legal to delete users in the ignorelist.
assert((getTreeEntry(U) || is_contained(UserIgnoreList, U)) &&
- "Replacing out-of-tree value with undef");
+ "Deleting out-of-tree value");
}
-#endif
- Value *Undef = UndefValue::get(Ty);
- Scalar->replaceAllUsesWith(Undef);
}
+#endif
LLVM_DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n");
eraseInstruction(cast<Instruction>(Scalar));
}
@@ -4165,7 +4393,7 @@ void BoUpSLP::optimizeGatherSequence() {
<< " gather sequences instructions.\n");
// LICM InsertElementInst sequences.
for (Instruction *I : GatherSeq) {
- if (!isa<InsertElementInst>(I) && !isa<ShuffleVectorInst>(I))
+ if (isDeleted(I))
continue;
// Check if this block is inside a loop.
@@ -4219,6 +4447,8 @@ void BoUpSLP::optimizeGatherSequence() {
// For all instructions in blocks containing gather sequences:
for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e;) {
Instruction *In = &*it++;
+ if (isDeleted(In))
+ continue;
if (!isa<InsertElementInst>(In) && !isa<ExtractElementInst>(In))
continue;
@@ -4245,11 +4475,11 @@ void BoUpSLP::optimizeGatherSequence() {
// Groups the instructions to a bundle (which is then a single scheduling entity)
// and schedules instructions until the bundle gets ready.
-bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL,
- BoUpSLP *SLP,
- const InstructionsState &S) {
+Optional<BoUpSLP::ScheduleData *>
+BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP,
+ const InstructionsState &S) {
if (isa<PHINode>(S.OpValue))
- return true;
+ return nullptr;
// Initialize the instruction bundle.
Instruction *OldScheduleEnd = ScheduleEnd;
@@ -4262,7 +4492,7 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL,
// instructions of the bundle.
for (Value *V : VL) {
if (!extendSchedulingRegion(V, S))
- return false;
+ return None;
}
for (Value *V : VL) {
@@ -4308,6 +4538,7 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL,
resetSchedule();
initialFillReadyList(ReadyInsts);
}
+ assert(Bundle && "Failed to find schedule bundle");
LLVM_DEBUG(dbgs() << "SLP: try schedule bundle " << *Bundle << " in block "
<< BB->getName() << "\n");
@@ -4329,9 +4560,9 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL,
}
if (!Bundle->isReady()) {
cancelScheduling(VL, S.OpValue);
- return false;
+ return None;
}
- return true;
+ return Bundle;
}
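The three possible results are worth spelling out, since None and a present-but-null value now mean different things (a summary of the contract as consumed by buildTree_rec and newTreeEntry):

    Optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S);
    // None              -> the bundle cannot be scheduled; the caller gathers.
    // Optional(nullptr) -> PHIs: nothing to schedule, but vectorization may
    //                      proceed; newTreeEntry() sees a non-None value and its
    //                      bundle walk is simply empty.
    // Optional(Bundle)  -> scheduled; newTreeEntry() sets the TE/Lane
    //                      back-pointers on every ScheduleData in the bundle.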
void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL,
@@ -4364,7 +4595,7 @@ void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL,
BoUpSLP::ScheduleData *BoUpSLP::BlockScheduling::allocateScheduleDataChunks() {
// Allocate a new ScheduleData for the instruction.
if (ChunkPos >= ChunkSize) {
- ScheduleDataChunks.push_back(llvm::make_unique<ScheduleData[]>(ChunkSize));
+ ScheduleDataChunks.push_back(std::make_unique<ScheduleData[]>(ChunkSize));
ChunkPos = 0;
}
return &(ScheduleDataChunks.back()[ChunkPos++]);
@@ -4977,7 +5208,7 @@ struct SLPVectorizer : public FunctionPass {
auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- auto *TLI = TLIP ? &TLIP->getTLI() : nullptr;
+ auto *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
@@ -5052,7 +5283,7 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_,
// If the target claims to have no vector registers don't attempt
// vectorization.
- if (!TTI->getNumberOfRegisters(true))
+ if (!TTI->getNumberOfRegisters(TTI->getRegisterClassForType(true)))
return false;
// Don't vectorize when the attribute NoImplicitFloat is used.
@@ -5100,19 +5331,6 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_,
return Changed;
}
-/// Check that the Values in the slice in VL array are still existent in
-/// the WeakTrackingVH array.
-/// Vectorization of part of the VL array may cause later values in the VL array
-/// to become invalid. We track when this has happened in the WeakTrackingVH
-/// array.
-static bool hasValueBeenRAUWed(ArrayRef<Value *> VL,
- ArrayRef<WeakTrackingVH> VH, unsigned SliceBegin,
- unsigned SliceSize) {
- VL = VL.slice(SliceBegin, SliceSize);
- VH = VH.slice(SliceBegin, SliceSize);
- return !std::equal(VL.begin(), VL.end(), VH.begin());
-}
-
bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
unsigned VecRegSize) {
const unsigned ChainLen = Chain.size();
@@ -5124,20 +5342,20 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
if (!isPowerOf2_32(Sz) || VF < 2)
return false;
- // Keep track of values that were deleted by vectorizing in the loop below.
- const SmallVector<WeakTrackingVH, 8> TrackValues(Chain.begin(), Chain.end());
-
bool Changed = false;
// Look for profitable vectorizable trees at all offsets, starting at zero.
for (unsigned i = 0, e = ChainLen; i + VF <= e; ++i) {
+ ArrayRef<Value *> Operands = Chain.slice(i, VF);
// Check that a previous iteration of this loop did not delete the Value.
- if (hasValueBeenRAUWed(Chain, TrackValues, i, VF))
+ if (llvm::any_of(Operands, [&R](Value *V) {
+ auto *I = dyn_cast<Instruction>(V);
+ return I && R.isDeleted(I);
+ }))
continue;
LLVM_DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i
<< "\n");
- ArrayRef<Value *> Operands = Chain.slice(i, VF);
R.buildTree(Operands);
if (R.isTreeTinyAndNotFullyVectorizable())
@@ -5329,12 +5547,8 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
bool CandidateFound = false;
int MinCost = SLPCostThreshold;
- // Keep track of values that were deleted by vectorizing in the loop below.
- SmallVector<WeakTrackingVH, 8> TrackValues(VL.begin(), VL.end());
-
unsigned NextInst = 0, MaxInst = VL.size();
- for (unsigned VF = MaxVF; NextInst + 1 < MaxInst && VF >= MinVF;
- VF /= 2) {
+ for (unsigned VF = MaxVF; NextInst + 1 < MaxInst && VF >= MinVF; VF /= 2) {
// No actual vectorization should happen, if number of parts is the same as
// provided vectorization factor (i.e. the scalar type is used for vector
// code during codegen).
@@ -5352,13 +5566,16 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
if (!isPowerOf2_32(OpsWidth) || OpsWidth < 2)
break;
+ ArrayRef<Value *> Ops = VL.slice(I, OpsWidth);
// Check that a previous iteration of this loop did not delete the Value.
- if (hasValueBeenRAUWed(VL, TrackValues, I, OpsWidth))
+ if (llvm::any_of(Ops, [&R](Value *V) {
+ auto *I = dyn_cast<Instruction>(V);
+ return I && R.isDeleted(I);
+ }))
continue;
LLVM_DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations "
<< "\n");
- ArrayRef<Value *> Ops = VL.slice(I, OpsWidth);
R.buildTree(Ops);
Optional<ArrayRef<unsigned>> Order = R.bestOrder();
@@ -5571,7 +5788,7 @@ class HorizontalReduction {
Value *createOp(IRBuilder<> &Builder, const Twine &Name) const {
assert(isVectorizable() &&
"Expected add|fadd or min/max reduction operation.");
- Value *Cmp;
+ Value *Cmp = nullptr;
switch (Kind) {
case RK_Arithmetic:
return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
@@ -5579,23 +5796,23 @@ class HorizontalReduction {
case RK_Min:
Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSLT(LHS, RHS)
: Builder.CreateFCmpOLT(LHS, RHS);
- break;
+ return Builder.CreateSelect(Cmp, LHS, RHS, Name);
case RK_Max:
Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS)
: Builder.CreateFCmpOGT(LHS, RHS);
- break;
+ return Builder.CreateSelect(Cmp, LHS, RHS, Name);
case RK_UMin:
assert(Opcode == Instruction::ICmp && "Expected integer types.");
Cmp = Builder.CreateICmpULT(LHS, RHS);
- break;
+ return Builder.CreateSelect(Cmp, LHS, RHS, Name);
case RK_UMax:
assert(Opcode == Instruction::ICmp && "Expected integer types.");
Cmp = Builder.CreateICmpUGT(LHS, RHS);
- break;
+ return Builder.CreateSelect(Cmp, LHS, RHS, Name);
case RK_None:
- llvm_unreachable("Unknown reduction operation.");
+ break;
}
- return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+ llvm_unreachable("Unknown reduction operation.");
}
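As a usage note, every min/max kind now returns its compare-plus-select directly from the switch; for RK_Max over signed integers, for example, the emitted IR is (types elided, scalar or vector depending on the operands):

    // %cmp = icmp sgt %lhs, %rhs
    // %res = select %cmp, %lhs, %rhs
    // RK_UMin/RK_UMax use icmp ult/ugt, and the floating-point min/max kinds use
    // fcmp olt/ogt, all feeding the same select.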
public:
@@ -6203,6 +6420,8 @@ public:
}
if (V.isTreeTinyAndNotFullyVectorizable())
break;
+ if (V.isLoadCombineReductionCandidate(ReductionData.getOpcode()))
+ break;
V.computeMinimumValueSizes();
@@ -6275,6 +6494,9 @@ public:
}
// Update users.
ReductionRoot->replaceAllUsesWith(VectorizedTree);
+    // Mark all scalar reduction ops for deletion; they are replaced by the
+ // vector reductions.
+ V.eraseInstructions(IgnoreList);
}
return VectorizedTree != nullptr;
}
@@ -6323,7 +6545,7 @@ private:
IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost;
int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost;
- int ScalarReduxCost;
+ int ScalarReduxCost = 0;
switch (ReductionData.getKind()) {
case RK_Arithmetic:
ScalarReduxCost =
@@ -6429,10 +6651,9 @@ static bool findBuildVector(InsertElementInst *LastInsertElem,
/// \return true if it matches.
static bool findBuildAggregate(InsertValueInst *IV,
SmallVectorImpl<Value *> &BuildVectorOpds) {
- Value *V;
do {
BuildVectorOpds.push_back(IV->getInsertedValueOperand());
- V = IV->getAggregateOperand();
+ Value *V = IV->getAggregateOperand();
if (isa<UndefValue>(V))
break;
IV = dyn_cast<InsertValueInst>(V);
@@ -6530,18 +6751,13 @@ static bool tryToVectorizeHorReductionOrInstOperands(
// horizontal reduction.
// Interrupt the process if the Root instruction itself was vectorized or all
// sub-trees not higher that RecursionMaxDepth were analyzed/vectorized.
- SmallVector<std::pair<WeakTrackingVH, unsigned>, 8> Stack(1, {Root, 0});
+ SmallVector<std::pair<Instruction *, unsigned>, 8> Stack(1, {Root, 0});
SmallPtrSet<Value *, 8> VisitedInstrs;
bool Res = false;
while (!Stack.empty()) {
- Value *V;
+ Instruction *Inst;
unsigned Level;
- std::tie(V, Level) = Stack.pop_back_val();
- if (!V)
- continue;
- auto *Inst = dyn_cast<Instruction>(V);
- if (!Inst)
- continue;
+ std::tie(Inst, Level) = Stack.pop_back_val();
auto *BI = dyn_cast<BinaryOperator>(Inst);
auto *SI = dyn_cast<SelectInst>(Inst);
if (BI || SI) {
@@ -6582,8 +6798,8 @@ static bool tryToVectorizeHorReductionOrInstOperands(
for (auto *Op : Inst->operand_values())
if (VisitedInstrs.insert(Op).second)
if (auto *I = dyn_cast<Instruction>(Op))
- if (!isa<PHINode>(I) && I->getParent() == BB)
- Stack.emplace_back(Op, Level);
+ if (!isa<PHINode>(I) && !R.isDeleted(I) && I->getParent() == BB)
+ Stack.emplace_back(I, Level);
}
return Res;
}
@@ -6652,11 +6868,10 @@ bool SLPVectorizerPass::vectorizeCmpInst(CmpInst *CI, BasicBlock *BB,
}
bool SLPVectorizerPass::vectorizeSimpleInstructions(
- SmallVectorImpl<WeakVH> &Instructions, BasicBlock *BB, BoUpSLP &R) {
+ SmallVectorImpl<Instruction *> &Instructions, BasicBlock *BB, BoUpSLP &R) {
bool OpsChanged = false;
- for (auto &VH : reverse(Instructions)) {
- auto *I = dyn_cast_or_null<Instruction>(VH);
- if (!I)
+ for (auto *I : reverse(Instructions)) {
+ if (R.isDeleted(I))
continue;
if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I))
OpsChanged |= vectorizeInsertValueInst(LastInsertValue, BB, R);
@@ -6685,7 +6900,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
if (!P)
break;
- if (!VisitedInstrs.count(P))
+ if (!VisitedInstrs.count(P) && !R.isDeleted(P))
Incoming.push_back(P);
}
@@ -6729,9 +6944,12 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
VisitedInstrs.clear();
- SmallVector<WeakVH, 8> PostProcessInstructions;
+ SmallVector<Instruction *, 8> PostProcessInstructions;
SmallDenseSet<Instruction *, 4> KeyNodes;
for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
+ // Skip instructions marked for deletion.
+ if (R.isDeleted(&*it))
+ continue;
// We may go through BB multiple times so skip the one we have checked.
if (!VisitedInstrs.insert(&*it).second) {
if (it->use_empty() && KeyNodes.count(&*it) > 0 &&
@@ -6811,10 +7029,16 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) {
LLVM_DEBUG(dbgs() << "SLP: Analyzing a getelementptr list of length "
<< Entry.second.size() << ".\n");
- // We process the getelementptr list in chunks of 16 (like we do for
- // stores) to minimize compile-time.
- for (unsigned BI = 0, BE = Entry.second.size(); BI < BE; BI += 16) {
- auto Len = std::min<unsigned>(BE - BI, 16);
+ // Process the GEP list in chunks suitable for the target's supported
+ // vector size. If a vector register can't hold 1 element, we are done.
+ unsigned MaxVecRegSize = R.getMaxVecRegSize();
+ unsigned EltSize = R.getVectorElementSize(Entry.second[0]);
+ if (MaxVecRegSize < EltSize)
+ continue;
+
+ unsigned MaxElts = MaxVecRegSize / EltSize;
+ for (unsigned BI = 0, BE = Entry.second.size(); BI < BE; BI += MaxElts) {
+ auto Len = std::min<unsigned>(BE - BI, MaxElts);
auto GEPList = makeArrayRef(&Entry.second[BI], Len);
// Initialize a set of candidate getelementptrs. Note that we use a
@@ -6824,10 +7048,10 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) {
SetVector<Value *> Candidates(GEPList.begin(), GEPList.end());
// Some of the candidates may have already been vectorized after we
- // initially collected them. If so, the WeakTrackingVHs will have
- // nullified the
- // values, so remove them from the set of candidates.
- Candidates.remove(nullptr);
+ // initially collected them. If so, they are marked as deleted, so remove
+ // them from the set of candidates.
+ Candidates.remove_if(
+ [&R](Value *I) { return R.isDeleted(cast<Instruction>(I)); });
// Remove from the set of candidates all pairs of getelementptrs with
// constant differences. Such getelementptrs are likely not good
@@ -6835,18 +7059,18 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) {
// computed from the other. We also ensure all candidate getelementptr
// indices are unique.
for (int I = 0, E = GEPList.size(); I < E && Candidates.size() > 1; ++I) {
- auto *GEPI = cast<GetElementPtrInst>(GEPList[I]);
+ auto *GEPI = GEPList[I];
if (!Candidates.count(GEPI))
continue;
auto *SCEVI = SE->getSCEV(GEPList[I]);
for (int J = I + 1; J < E && Candidates.size() > 1; ++J) {
- auto *GEPJ = cast<GetElementPtrInst>(GEPList[J]);
+ auto *GEPJ = GEPList[J];
auto *SCEVJ = SE->getSCEV(GEPList[J]);
if (isa<SCEVConstant>(SE->getMinusSCEV(SCEVI, SCEVJ))) {
- Candidates.remove(GEPList[I]);
- Candidates.remove(GEPList[J]);
+ Candidates.remove(GEPI);
+ Candidates.remove(GEPJ);
} else if (GEPI->idx_begin()->get() == GEPJ->idx_begin()->get()) {
- Candidates.remove(GEPList[J]);
+ Candidates.remove(GEPJ);
}
}
}
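As a worked example of the new chunk size: with a 128-bit vector register and 64-bit GEP results (the usual case for 64-bit pointers), MaxElts is 128 / 64 = 2, versus the fixed chunk of 16 used before. The numbers below are hypothetical and only illustrate the arithmetic:

#include <algorithm>
#include <cstdio>

int main() {
  unsigned MaxVecRegSize = 128; // bits; R.getMaxVecRegSize() in the patch
  unsigned EltSize = 64;        // bits; R.getVectorElementSize(GEP)
  if (MaxVecRegSize < EltSize)
    return 0;                   // not even one element fits in a register
  unsigned MaxElts = MaxVecRegSize / EltSize; // == 2 here
  unsigned NumGEPs = 5;         // pretend five getelementptrs were collected
  for (unsigned BI = 0; BI < NumGEPs; BI += MaxElts)
    std::printf("chunk [%u, %u)\n", BI, std::min(NumGEPs, BI + MaxElts));
  // prints chunk [0, 2), chunk [2, 4), chunk [4, 5)
  return 0;
}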
diff --git a/lib/Transforms/Vectorize/VPlan.cpp b/lib/Transforms/Vectorize/VPlan.cpp
index 517d759d7bfc..4b80d1fb20aa 100644
--- a/lib/Transforms/Vectorize/VPlan.cpp
+++ b/lib/Transforms/Vectorize/VPlan.cpp
@@ -283,6 +283,12 @@ iplist<VPRecipeBase>::iterator VPRecipeBase::eraseFromParent() {
return getParent()->getRecipeList().erase(getIterator());
}
+void VPRecipeBase::moveAfter(VPRecipeBase *InsertPos) {
+ InsertPos->getParent()->getRecipeList().splice(
+ std::next(InsertPos->getIterator()), getParent()->getRecipeList(),
+ getIterator());
+}
+
void VPInstruction::generateInstruction(VPTransformState &State,
unsigned Part) {
IRBuilder<> &Builder = State.Builder;
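The new moveAfter above is a thin wrapper over ilist splice, which relinks the node in place: nothing is copied or destroyed and other iterators stay valid. A minimal sketch of the same operation on a plain iplist with a hypothetical Node type:

#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"
#include <iterator>

namespace {
struct Node : llvm::ilist_node<Node> {
  int Id;
  explicit Node(int Id) : Id(Id) {}
};
} // namespace

// Move *N so it immediately follows *Pos, even if the two nodes currently
// live in different lists.
static void moveNodeAfter(llvm::iplist<Node> &From, Node *N,
                          llvm::iplist<Node> &To, Node *Pos) {
  To.splice(std::next(Pos->getIterator()), From, N->getIterator());
}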
@@ -309,6 +315,14 @@ void VPInstruction::generateInstruction(VPTransformState &State,
State.set(this, V, Part);
break;
}
+ case Instruction::Select: {
+ Value *Cond = State.get(getOperand(0), Part);
+ Value *Op1 = State.get(getOperand(1), Part);
+ Value *Op2 = State.get(getOperand(2), Part);
+ Value *V = Builder.CreateSelect(Cond, Op1, Op2);
+ State.set(this, V, Part);
+ break;
+ }
default:
llvm_unreachable("Unsupported opcode for instruction");
}
@@ -728,7 +742,7 @@ void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New,
auto NewIGIter = Old2New.find(IG);
if (NewIGIter == Old2New.end())
Old2New[IG] = new InterleaveGroup<VPInstruction>(
- IG->getFactor(), IG->isReverse(), IG->getAlignment());
+ IG->getFactor(), IG->isReverse(), Align(IG->getAlignment()));
if (Inst == IG->getInsertPos())
Old2New[IG]->setInsertPos(VPInst);
@@ -736,7 +750,8 @@ void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New,
InterleaveGroupMap[VPInst] = Old2New[IG];
InterleaveGroupMap[VPInst]->insertMember(
VPInst, IG->getIndex(Inst),
- IG->isReverse() ? (-1) * int(IG->getFactor()) : IG->getFactor());
+ Align(IG->isReverse() ? (-1) * int(IG->getFactor())
+ : IG->getFactor()));
}
} else if (VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block))
visitRegion(Region, Old2New, IAI);
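The two hunks above route the raw unsigned alignment through the llvm::Align type from the alignment migration; its constructor asserts that the value is a non-zero power of two, which is the safety the wrapper adds over a bare integer. A generic illustration, unrelated to interleave groups:

#include "llvm/Support/Alignment.h"
#include <cstdint>
using namespace llvm;

// Round Offset up to a 16-byte boundary; Align(16) would assert if the
// requested alignment were 0 or not a power of two (e.g. Align(12)).
static uint64_t roundUpTo16(uint64_t Offset) {
  Align A(16);
  return alignTo(Offset, A);
}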
diff --git a/lib/Transforms/Vectorize/VPlan.h b/lib/Transforms/Vectorize/VPlan.h
index 8a06412ad590..44d8a198f27e 100644
--- a/lib/Transforms/Vectorize/VPlan.h
+++ b/lib/Transforms/Vectorize/VPlan.h
@@ -615,6 +615,10 @@ public:
/// the specified recipe.
void insertBefore(VPRecipeBase *InsertPos);
+ /// Unlink this recipe from its current VPBasicBlock and insert it into
+ /// the VPBasicBlock that MovePos lives in, right after MovePos.
+ void moveAfter(VPRecipeBase *MovePos);
+
/// This method unlinks 'this' from the containing basic block and deletes it.
///
/// \returns an iterator pointing to the element after the erased one
diff --git a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp
index 7ed7d21b6caa..b22d3190d654 100644
--- a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp
+++ b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp
@@ -21,7 +21,7 @@ void VPlanHCFGTransforms::VPInstructionsToVPRecipes(
LoopVectorizationLegality::InductionList *Inductions,
SmallPtrSetImpl<Instruction *> &DeadInstructions) {
- VPRegionBlock *TopRegion = dyn_cast<VPRegionBlock>(Plan->getEntry());
+ auto *TopRegion = cast<VPRegionBlock>(Plan->getEntry());
ReversePostOrderTraversal<VPBlockBase *> RPOT(TopRegion->getEntry());
// Condition bit VPValues get deleted during transformation to VPRecipes.
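Switching from dyn_cast to cast above encodes the invariant that the plan entry is always a VPRegionBlock: cast asserts on a type mismatch instead of returning a null pointer the old code never checked. A small stand-in hierarchy showing the difference (illustration only, not the VPlan classes):

#include "llvm/Support/Casting.h"

struct Block {
  enum Kind { RegionK, BasicK } K;
  explicit Block(Kind K) : K(K) {}
};
struct Region : Block {
  Region() : Block(RegionK) {}
  static bool classof(const Block *B) { return B->K == RegionK; }
};

static void useBlock(Block *B) {
  // cast<> states an invariant: in +Asserts builds it aborts if B is not a
  // Region, instead of handing back null.
  Region *Must = llvm::cast<Region>(B);
  (void)Must;

  // dyn_cast<> is for genuinely optional cases and needs the null check.
  if (Region *Maybe = llvm::dyn_cast<Region>(B))
    (void)Maybe;
}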
diff --git a/lib/Transforms/Vectorize/VPlanSLP.cpp b/lib/Transforms/Vectorize/VPlanSLP.cpp
index e5ab24e52df6..9019ed15ec5f 100644
--- a/lib/Transforms/Vectorize/VPlanSLP.cpp
+++ b/lib/Transforms/Vectorize/VPlanSLP.cpp
@@ -346,11 +346,14 @@ SmallVector<VPlanSlp::MultiNodeOpTy, 4> VPlanSlp::reorderMultiNodeOps() {
void VPlanSlp::dumpBundle(ArrayRef<VPValue *> Values) {
dbgs() << " Ops: ";
- for (auto Op : Values)
- if (auto *Instr = cast_or_null<VPInstruction>(Op)->getUnderlyingInstr())
- dbgs() << *Instr << " | ";
- else
- dbgs() << " nullptr | ";
+ for (auto Op : Values) {
+ if (auto *VPInstr = cast_or_null<VPInstruction>(Op))
+ if (auto *Instr = VPInstr->getUnderlyingInstr()) {
+ dbgs() << *Instr << " | ";
+ continue;
+ }
+ dbgs() << " nullptr | ";
+ }
dbgs() << "\n";
}
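The dumpBundle fix above matters because cast_or_null tolerates a null input but still returns null, so chaining ->getUnderlyingInstr() onto it, as the removed lines did, dereferences null whenever a bundle slot is empty. A minimal sketch of the same null-checking idiom on plain IR values, using dyn_cast_or_null (illustration only):

#include "llvm/IR/Instruction.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

// Print the instruction behind V, or a placeholder, where V may be null.
static void printMaybe(Value *V) {
  if (auto *I = dyn_cast_or_null<Instruction>(V)) {
    dbgs() << *I << " | ";
    return;
  }
  dbgs() << " nullptr | ";
}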