author     Dimitry Andric <dim@FreeBSD.org>    2023-07-26 19:03:47 +0000
committer  Dimitry Andric <dim@FreeBSD.org>    2023-07-26 19:04:23 +0000
commit     7fa27ce4a07f19b07799a767fc29416f3b625afb (patch)
tree       27825c83636c4de341eb09a74f49f5d38a15d165 /llvm/lib/Transforms
parent     e3b557809604d036af6e00c60f012c2025b59a5e (diff)
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp362
-rw-r--r--llvm/lib/Transforms/CFGuard/CFGuard.cpp2
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroCleanup.cpp10
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp4
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroElide.cpp44
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroFrame.cpp608
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroInternal.h15
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroSplit.cpp188
-rw-r--r--llvm/lib/Transforms/Coroutines/Coroutines.cpp26
-rw-r--r--llvm/lib/Transforms/IPO/AlwaysInliner.cpp136
-rw-r--r--llvm/lib/Transforms/IPO/Annotation2Metadata.cpp34
-rw-r--r--llvm/lib/Transforms/IPO/ArgumentPromotion.cpp72
-rw-r--r--llvm/lib/Transforms/IPO/Attributor.cpp1093
-rw-r--r--llvm/lib/Transforms/IPO/AttributorAttributes.cpp3316
-rw-r--r--llvm/lib/Transforms/IPO/BlockExtractor.cpp2
-rw-r--r--llvm/lib/Transforms/IPO/CalledValuePropagation.cpp32
-rw-r--r--llvm/lib/Transforms/IPO/ConstantMerge.cpp31
-rw-r--r--llvm/lib/Transforms/IPO/CrossDSOCFI.cpp20
-rw-r--r--llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp40
-rw-r--r--llvm/lib/Transforms/IPO/ElimAvailExtern.cpp110
-rw-r--r--llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp52
-rw-r--r--llvm/lib/Transforms/IPO/ExtractGV.cpp26
-rw-r--r--llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp31
-rw-r--r--llvm/lib/Transforms/IPO/FunctionAttrs.cpp308
-rw-r--r--llvm/lib/Transforms/IPO/FunctionImport.cpp362
-rw-r--r--llvm/lib/Transforms/IPO/FunctionSpecialization.cpp707
-rw-r--r--llvm/lib/Transforms/IPO/GlobalDCE.cpp110
-rw-r--r--llvm/lib/Transforms/IPO/GlobalOpt.cpp368
-rw-r--r--llvm/lib/Transforms/IPO/GlobalSplit.cpp35
-rw-r--r--llvm/lib/Transforms/IPO/HotColdSplitting.cpp57
-rw-r--r--llvm/lib/Transforms/IPO/IPO.cpp97
-rw-r--r--llvm/lib/Transforms/IPO/IROutliner.cpp75
-rw-r--r--llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp36
-rw-r--r--llvm/lib/Transforms/IPO/InlineSimple.cpp118
-rw-r--r--llvm/lib/Transforms/IPO/Inliner.cpp558
-rw-r--r--llvm/lib/Transforms/IPO/Internalize.cpp63
-rw-r--r--llvm/lib/Transforms/IPO/LoopExtractor.cpp4
-rw-r--r--llvm/lib/Transforms/IPO/LowerTypeTests.cpp197
-rw-r--r--llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp3277
-rw-r--r--llvm/lib/Transforms/IPO/MergeFunctions.cpp28
-rw-r--r--llvm/lib/Transforms/IPO/ModuleInliner.cpp9
-rw-r--r--llvm/lib/Transforms/IPO/OpenMPOpt.cpp699
-rw-r--r--llvm/lib/Transforms/IPO/PartialInlining.cpp76
-rw-r--r--llvm/lib/Transforms/IPO/PassManagerBuilder.cpp517
-rw-r--r--llvm/lib/Transforms/IPO/SCCP.cpp132
-rw-r--r--llvm/lib/Transforms/IPO/SampleProfile.cpp569
-rw-r--r--llvm/lib/Transforms/IPO/SampleProfileProbe.cpp110
-rw-r--r--llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp29
-rw-r--r--llvm/lib/Transforms/IPO/StripSymbols.cpp209
-rw-r--r--llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp55
-rw-r--r--llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp162
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp102
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp487
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp18
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp596
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp314
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp1160
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineInternal.h49
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp139
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp244
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp12
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp448
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp67
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp192
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp258
-rw-r--r--llvm/lib/Transforms/InstCombine/InstructionCombining.cpp1311
-rw-r--r--llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp323
-rw-r--r--llvm/lib/Transforms/Instrumentation/BlockCoverageInference.cpp368
-rw-r--r--llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp10
-rw-r--r--llvm/lib/Transforms/Instrumentation/CFGMST.h303
-rw-r--r--llvm/lib/Transforms/Instrumentation/CGProfile.cpp23
-rw-r--r--llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp22
-rw-r--r--llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp31
-rw-r--r--llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp20
-rw-r--r--llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp319
-rw-r--r--llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp65
-rw-r--r--llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp3
-rw-r--r--llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp106
-rw-r--r--llvm/lib/Transforms/Instrumentation/Instrumentation.cpp5
-rw-r--r--llvm/lib/Transforms/Instrumentation/KCFI.cpp20
-rw-r--r--llvm/lib/Transforms/Instrumentation/MemProfiler.cpp315
-rw-r--r--llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp335
-rw-r--r--llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp813
-rw-r--r--llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp2
-rw-r--r--llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp187
-rw-r--r--llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp36
-rw-r--r--llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp6
-rw-r--r--llvm/lib/Transforms/ObjCARC/ObjCARC.h2
-rw-r--r--llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp2
-rw-r--r--llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp10
-rw-r--r--llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp71
-rw-r--r--llvm/lib/Transforms/Scalar/ADCE.cpp94
-rw-r--r--llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp56
-rw-r--r--llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp8
-rw-r--r--llvm/lib/Transforms/Scalar/BDCE.cpp36
-rw-r--r--llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp39
-rw-r--r--llvm/lib/Transforms/Scalar/ConstantHoisting.cpp130
-rw-r--r--llvm/lib/Transforms/Scalar/ConstraintElimination.cpp702
-rw-r--r--llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp105
-rw-r--r--llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp50
-rw-r--r--llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp169
-rw-r--r--llvm/lib/Transforms/Scalar/DivRemPairs.cpp45
-rw-r--r--llvm/lib/Transforms/Scalar/EarlyCSE.cpp74
-rw-r--r--llvm/lib/Transforms/Scalar/Float2Int.cpp47
-rw-r--r--llvm/lib/Transforms/Scalar/GVN.cpp296
-rw-r--r--llvm/lib/Transforms/Scalar/GVNHoist.cpp57
-rw-r--r--llvm/lib/Transforms/Scalar/GVNSink.cpp41
-rw-r--r--llvm/lib/Transforms/Scalar/GuardWidening.cpp222
-rw-r--r--llvm/lib/Transforms/Scalar/IndVarSimplify.cpp261
-rw-r--r--llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp503
-rw-r--r--llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp157
-rw-r--r--llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp15
-rw-r--r--llvm/lib/Transforms/Scalar/JumpThreading.cpp404
-rw-r--r--llvm/lib/Transforms/Scalar/LICM.cpp517
-rw-r--r--llvm/lib/Transforms/Scalar/LoopDeletion.cpp71
-rw-r--r--llvm/lib/Transforms/Scalar/LoopDistribute.cpp57
-rw-r--r--llvm/lib/Transforms/Scalar/LoopFlatten.cpp97
-rw-r--r--llvm/lib/Transforms/Scalar/LoopFuse.cpp64
-rw-r--r--llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp159
-rw-r--r--llvm/lib/Transforms/Scalar/LoopInterchange.cpp60
-rw-r--r--llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp94
-rw-r--r--llvm/lib/Transforms/Scalar/LoopPassManager.cpp4
-rw-r--r--llvm/lib/Transforms/Scalar/LoopPredication.cpp39
-rw-r--r--llvm/lib/Transforms/Scalar/LoopRerollPass.cpp45
-rw-r--r--llvm/lib/Transforms/Scalar/LoopRotation.cpp15
-rw-r--r--llvm/lib/Transforms/Scalar/LoopSink.cpp40
-rw-r--r--llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp390
-rw-r--r--llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp74
-rw-r--r--llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp41
-rw-r--r--llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp65
-rw-r--r--llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp3
-rw-r--r--llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp295
-rw-r--r--llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp471
-rw-r--r--llvm/lib/Transforms/Scalar/MergeICmps.cpp13
-rw-r--r--llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp23
-rw-r--r--llvm/lib/Transforms/Scalar/NaryReassociate.cpp14
-rw-r--r--llvm/lib/Transforms/Scalar/NewGVN.cpp94
-rw-r--r--llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp476
-rw-r--r--llvm/lib/Transforms/Scalar/Reassociate.cpp160
-rw-r--r--llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp245
-rw-r--r--llvm/lib/Transforms/Scalar/SCCP.cpp52
-rw-r--r--llvm/lib/Transforms/Scalar/SROA.cpp627
-rw-r--r--llvm/lib/Transforms/Scalar/Scalar.cpp232
-rw-r--r--llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp28
-rw-r--r--llvm/lib/Transforms/Scalar/Scalarizer.cpp774
-rw-r--r--llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp129
-rw-r--r--llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp695
-rw-r--r--llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp28
-rw-r--r--llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp50
-rw-r--r--llvm/lib/Transforms/Scalar/StructurizeCFG.cpp41
-rw-r--r--llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp45
-rw-r--r--llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp322
-rw-r--r--llvm/lib/Transforms/Utils/AddDiscriminators.cpp31
-rw-r--r--llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp64
-rw-r--r--llvm/lib/Transforms/Utils/BasicBlockUtils.cpp282
-rw-r--r--llvm/lib/Transforms/Utils/BuildLibCalls.cpp88
-rw-r--r--llvm/lib/Transforms/Utils/BypassSlowDivision.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/CallGraphUpdater.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/CallPromotionUtils.cpp1
-rw-r--r--llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/CloneFunction.cpp17
-rw-r--r--llvm/lib/Transforms/Utils/CodeExtractor.cpp43
-rw-r--r--llvm/lib/Transforms/Utils/CodeLayout.cpp770
-rw-r--r--llvm/lib/Transforms/Utils/CountVisits.cpp25
-rw-r--r--llvm/lib/Transforms/Utils/CtorUtils.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/Debugify.cpp83
-rw-r--r--llvm/lib/Transforms/Utils/DemoteRegToStack.cpp1
-rw-r--r--llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp16
-rw-r--r--llvm/lib/Transforms/Utils/EscapeEnumerator.cpp4
-rw-r--r--llvm/lib/Transforms/Utils/Evaluator.cpp40
-rw-r--r--llvm/lib/Transforms/Utils/FlattenCFG.cpp13
-rw-r--r--llvm/lib/Transforms/Utils/FunctionComparator.cpp63
-rw-r--r--llvm/lib/Transforms/Utils/InjectTLIMappings.cpp65
-rw-r--r--llvm/lib/Transforms/Utils/InlineFunction.cpp123
-rw-r--r--llvm/lib/Transforms/Utils/InstructionNamer.cpp33
-rw-r--r--llvm/lib/Transforms/Utils/LCSSA.cpp46
-rw-r--r--llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp107
-rw-r--r--llvm/lib/Transforms/Utils/Local.cpp293
-rw-r--r--llvm/lib/Transforms/Utils/LoopPeel.cpp50
-rw-r--r--llvm/lib/Transforms/Utils/LoopRotationUtils.cpp7
-rw-r--r--llvm/lib/Transforms/Utils/LoopSimplify.cpp20
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnroll.cpp42
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp6
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp78
-rw-r--r--llvm/lib/Transforms/Utils/LoopVersioning.cpp56
-rw-r--r--llvm/lib/Transforms/Utils/LowerAtomic.cpp6
-rw-r--r--llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp75
-rw-r--r--llvm/lib/Transforms/Utils/Mem2Reg.cpp12
-rw-r--r--llvm/lib/Transforms/Utils/MemoryOpRemark.cpp6
-rw-r--r--llvm/lib/Transforms/Utils/MetaRenamer.cpp85
-rw-r--r--llvm/lib/Transforms/Utils/ModuleUtils.cpp20
-rw-r--r--llvm/lib/Transforms/Utils/MoveAutoInit.cpp231
-rw-r--r--llvm/lib/Transforms/Utils/NameAnonGlobals.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp81
-rw-r--r--llvm/lib/Transforms/Utils/SCCPSolver.cpp511
-rw-r--r--llvm/lib/Transforms/Utils/SSAUpdater.cpp28
-rw-r--r--llvm/lib/Transforms/Utils/SampleProfileInference.cpp36
-rw-r--r--llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp442
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyCFG.cpp261
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyIndVar.cpp26
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp218
-rw-r--r--llvm/lib/Transforms/Utils/SizeOpts.cpp6
-rw-r--r--llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp39
-rw-r--r--llvm/lib/Transforms/Utils/SymbolRewriter.cpp43
-rw-r--r--llvm/lib/Transforms/Utils/UnifyLoopExits.cpp4
-rw-r--r--llvm/lib/Transforms/Utils/Utils.cpp27
-rw-r--r--llvm/lib/Transforms/Utils/VNCoercion.cpp192
-rw-r--r--llvm/lib/Transforms/Utils/ValueMapper.cpp10
-rw-r--r--llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp2049
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp188
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h64
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorize.cpp2346
-rw-r--r--llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp4239
-rw-r--r--llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h24
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.cpp116
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.h680
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanCFG.h1
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp25
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h5
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp385
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp327
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.h29
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanValue.h18
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp45
-rw-r--r--llvm/lib/Transforms/Vectorize/VectorCombine.cpp85
-rw-r--r--llvm/lib/Transforms/Vectorize/Vectorize.cpp19
228 files changed, 29185 insertions, 20475 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 473b41241b8a..34c8a380448e 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -18,6 +18,8 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -27,6 +29,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -64,7 +67,6 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
// shift amount.
auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
Value *&ShAmt) {
- Value *SubAmt;
unsigned Width = V->getType()->getScalarSizeInBits();
// fshl(ShVal0, ShVal1, ShAmt)
@@ -72,8 +74,7 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
if (match(V, m_OneUse(m_c_Or(
m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
m_LShr(m_Value(ShVal1),
- m_Sub(m_SpecificInt(Width), m_Value(SubAmt))))))) {
- if (ShAmt == SubAmt) // TODO: Use m_Specific
+ m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
return Intrinsic::fshl;
}
@@ -81,9 +82,8 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
// == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
if (match(V,
m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
- m_Value(SubAmt))),
- m_LShr(m_Value(ShVal1), m_Value(ShAmt)))))) {
- if (ShAmt == SubAmt) // TODO: Use m_Specific
+ m_Value(ShAmt))),
+ m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
return Intrinsic::fshr;
}
@@ -305,7 +305,7 @@ static bool tryToRecognizePopCount(Instruction &I) {
Value *MulOp0;
// Matching "(i * 0x01010101...) >> 24".
if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
- match(Op1, m_SpecificInt(MaskShift))) {
+ match(Op1, m_SpecificInt(MaskShift))) {
Value *ShiftOp0;
// Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
@@ -398,51 +398,6 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
return true;
}
-/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
-/// pessimistic codegen that has to account for setting errno and can enable
-/// vectorization.
-static bool
-foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) {
- // Match a call to sqrt mathlib function.
- auto *Call = dyn_cast<CallInst>(&I);
- if (!Call)
- return false;
-
- Module *M = Call->getModule();
- LibFunc Func;
- if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
- return false;
-
- if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
- return false;
-
- // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
- // (because NNAN or the operand arg must not be less than -0.0) and (2) we
- // would not end up lowering to a libcall anyway (which could change the value
- // of errno), then:
- // (1) errno won't be set.
- // (2) it is safe to convert this to an intrinsic call.
- Type *Ty = Call->getType();
- Value *Arg = Call->getArgOperand(0);
- if (TTI.haveFastSqrt(Ty) &&
- (Call->hasNoNaNs() || CannotBeOrderedLessThanZero(Arg, &TLI))) {
- IRBuilder<> Builder(&I);
- IRBuilderBase::FastMathFlagGuard Guard(Builder);
- Builder.setFastMathFlags(Call->getFastMathFlags());
-
- Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
- Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
- I.replaceAllUsesWith(NewSqrt);
-
- // Explicitly erase the old call because a call with side effects is not
- // trivially dead.
- I.eraseFromParent();
- return true;
- }
-
- return false;
-}
-
// Check if this array of constants represents a cttz table.
// Iterate over the elements from \p Table by trying to find/match all
// the numbers from 0 to \p InputBits that should represent cttz results.
@@ -613,7 +568,7 @@ struct LoadOps {
LoadInst *RootInsert = nullptr;
bool FoundRoot = false;
uint64_t LoadSize = 0;
- Value *Shift = nullptr;
+ const APInt *Shift = nullptr;
Type *ZextType;
AAMDNodes AATags;
};
@@ -623,7 +578,7 @@ struct LoadOps {
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
AliasAnalysis &AA) {
- Value *ShAmt2 = nullptr;
+ const APInt *ShAmt2 = nullptr;
Value *X;
Instruction *L1, *L2;
@@ -631,7 +586,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
if (match(V, m_OneUse(m_c_Or(
m_Value(X),
m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
- m_Value(ShAmt2)))))) ||
+ m_APInt(ShAmt2)))))) ||
match(V, m_OneUse(m_Or(m_Value(X),
m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
@@ -642,11 +597,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
// Check if the pattern has loads
LoadInst *LI1 = LOps.Root;
- Value *ShAmt1 = LOps.Shift;
+ const APInt *ShAmt1 = LOps.Shift;
if (LOps.FoundRoot == false &&
(match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
- m_Value(ShAmt1)))))) {
+ m_APInt(ShAmt1)))))) {
LI1 = dyn_cast<LoadInst>(L1);
}
LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -721,12 +676,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
std::swap(ShAmt1, ShAmt2);
// Find Shifts values.
- const APInt *Temp;
uint64_t Shift1 = 0, Shift2 = 0;
- if (ShAmt1 && match(ShAmt1, m_APInt(Temp)))
- Shift1 = Temp->getZExtValue();
- if (ShAmt2 && match(ShAmt2, m_APInt(Temp)))
- Shift2 = Temp->getZExtValue();
+ if (ShAmt1)
+ Shift1 = ShAmt1->getZExtValue();
+ if (ShAmt2)
+ Shift2 = ShAmt2->getZExtValue();
// First load is always LI1. This is where we put the new load.
// Use the merged load size available from LI1 for forward loads.
@@ -768,7 +722,8 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
// pattern which suggests that the loads can be combined. The one and only use
// of the loads is to form a wider load.
static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
- TargetTransformInfo &TTI, AliasAnalysis &AA) {
+ TargetTransformInfo &TTI, AliasAnalysis &AA,
+ const DominatorTree &DT) {
// Only consider load chains of scalar values.
if (isa<VectorType>(I.getType()))
return false;
@@ -793,17 +748,18 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
if (!Allowed || !Fast)
return false;
- // Make sure the Load pointer of type GEP/non-GEP is above insert point
- Instruction *Inst = dyn_cast<Instruction>(LI1->getPointerOperand());
- if (Inst && Inst->getParent() == LI1->getParent() &&
- !Inst->comesBefore(LOps.RootInsert))
- Inst->moveBefore(LOps.RootInsert);
-
- // New load can be generated
+ // Get the Index and Ptr for the new GEP.
Value *Load1Ptr = LI1->getPointerOperand();
Builder.SetInsertPoint(LOps.RootInsert);
- Value *NewPtr = Builder.CreateBitCast(Load1Ptr, WiderType->getPointerTo(AS));
- NewLoad = Builder.CreateAlignedLoad(WiderType, NewPtr, LI1->getAlign(),
+ if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
+ APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
+ Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
+ DL, Offset1, /* AllowNonInbounds */ true);
+ Load1Ptr = Builder.CreateGEP(Builder.getInt8Ty(), Load1Ptr,
+ Builder.getInt32(Offset1.getZExtValue()));
+ }
+ // Generate wider load.
+ NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
LI1->isVolatile(), "");
NewLoad->takeName(LI1);
// Set the New Load AATags Metadata.
@@ -818,18 +774,254 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
// Check if shift needed. We need to shift with the amount of load1
// shift if not zero.
if (LOps.Shift)
- NewOp = Builder.CreateShl(NewOp, LOps.Shift);
+ NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
I.replaceAllUsesWith(NewOp);
return true;
}
+// Calculate GEP Stride and accumulated const ModOffset. Return Stride and
+// ModOffset
+static std::pair<APInt, APInt>
+getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
+ unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
+ std::optional<APInt> Stride;
+ APInt ModOffset(BW, 0);
+ // Return a minimum gep stride, greatest common divisor of consective gep
+ // index scales(c.f. Bézout's identity).
+ while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
+ MapVector<Value *, APInt> VarOffsets;
+ if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
+ break;
+
+ for (auto [V, Scale] : VarOffsets) {
+ // Only keep a power of two factor for non-inbounds
+ if (!GEP->isInBounds())
+ Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());
+
+ if (!Stride)
+ Stride = Scale;
+ else
+ Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
+ }
+
+ PtrOp = GEP->getPointerOperand();
+ }
+
+ // Check whether pointer arrives back at Global Variable via at least one GEP.
+ // Even if it doesn't, we can check by alignment.
+ if (!isa<GlobalVariable>(PtrOp) || !Stride)
+ return {APInt(BW, 1), APInt(BW, 0)};
+
+ // In consideration of signed GEP indices, non-negligible offset become
+ // remainder of division by minimum GEP stride.
+ ModOffset = ModOffset.srem(*Stride);
+ if (ModOffset.isNegative())
+ ModOffset += *Stride;
+
+ return {*Stride, ModOffset};
+}
+
+/// If C is a constant patterned array and all valid loaded results for given
+/// alignment are same to a constant, return that constant.
+static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
+ auto *LI = dyn_cast<LoadInst>(&I);
+ if (!LI || LI->isVolatile())
+ return false;
+
+ // We can only fold the load if it is from a constant global with definitive
+ // initializer. Skip expensive logic if this is not the case.
+ auto *PtrOp = LI->getPointerOperand();
+ auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
+ if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+ return false;
+
+ // Bail for large initializers in excess of 4K to avoid too many scans.
+ Constant *C = GV->getInitializer();
+ uint64_t GVSize = DL.getTypeAllocSize(C->getType());
+ if (!GVSize || 4096 < GVSize)
+ return false;
+
+ Type *LoadTy = LI->getType();
+ unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
+ auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
+
+ // Any possible offset could be multiple of GEP stride. And any valid
+ // offset is multiple of load alignment, so checking only multiples of bigger
+ // one is sufficient to say results' equality.
+ if (auto LA = LI->getAlign();
+ LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
+ ConstOffset = APInt(BW, 0);
+ Stride = APInt(BW, LA.value());
+ }
+
+ Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
+ if (!Ca)
+ return false;
+
+ unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
+ for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
+ if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
+ return false;
+
+ I.replaceAllUsesWith(Ca);
+
+ return true;
+}
+
+/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
+/// pessimistic codegen that has to account for setting errno and can enable
+/// vectorization.
+static bool foldSqrt(CallInst *Call, TargetTransformInfo &TTI,
+ TargetLibraryInfo &TLI, AssumptionCache &AC,
+ DominatorTree &DT) {
+ Module *M = Call->getModule();
+
+ // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
+ // (because NNAN or the operand arg must not be less than -0.0) and (2) we
+ // would not end up lowering to a libcall anyway (which could change the value
+ // of errno), then:
+ // (1) errno won't be set.
+ // (2) it is safe to convert this to an intrinsic call.
+ Type *Ty = Call->getType();
+ Value *Arg = Call->getArgOperand(0);
+ if (TTI.haveFastSqrt(Ty) &&
+ (Call->hasNoNaNs() ||
+ cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, Call,
+ &DT))) {
+ IRBuilder<> Builder(Call);
+ IRBuilderBase::FastMathFlagGuard Guard(Builder);
+ Builder.setFastMathFlags(Call->getFastMathFlags());
+
+ Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
+ Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
+ Call->replaceAllUsesWith(NewSqrt);
+
+ // Explicitly erase the old call because a call with side effects is not
+ // trivially dead.
+ Call->eraseFromParent();
+ return true;
+ }
+
+ return false;
+}
+
+/// Try to expand strcmp(P, "x") calls.
+static bool expandStrcmp(CallInst *CI, DominatorTree &DT, bool &MadeCFGChange) {
+ Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1);
+
+ // Trivial cases are optimized during inst combine
+ if (Str1P == Str2P)
+ return false;
+
+ StringRef Str1, Str2;
+ bool HasStr1 = getConstantStringInfo(Str1P, Str1);
+ bool HasStr2 = getConstantStringInfo(Str2P, Str2);
+
+ Value *NonConstantP = nullptr;
+ StringRef ConstantStr;
+
+ if (!HasStr1 && HasStr2 && Str2.size() == 1) {
+ NonConstantP = Str1P;
+ ConstantStr = Str2;
+ } else if (!HasStr2 && HasStr1 && Str1.size() == 1) {
+ NonConstantP = Str2P;
+ ConstantStr = Str1;
+ } else {
+ return false;
+ }
+
+ // Check if strcmp result is only used in a comparison with zero
+ if (!isOnlyUsedInZeroComparison(CI))
+ return false;
+
+ // For strcmp(P, "x") do the following transformation:
+ //
+ // (before)
+ // dst = strcmp(P, "x")
+ //
+ // (after)
+ // v0 = P[0] - 'x'
+ // [if v0 == 0]
+ // v1 = P[1]
+ // dst = phi(v0, v1)
+ //
+
+ IRBuilder<> B(CI->getParent());
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+
+ Type *RetType = CI->getType();
+
+ B.SetInsertPoint(CI);
+ BasicBlock *InitialBB = B.GetInsertBlock();
+ Value *Str1FirstCharacterValue =
+ B.CreateZExt(B.CreateLoad(B.getInt8Ty(), NonConstantP), RetType);
+ Value *Str2FirstCharacterValue =
+ ConstantInt::get(RetType, static_cast<unsigned char>(ConstantStr[0]));
+ Value *FirstCharacterSub =
+ B.CreateNSWSub(Str1FirstCharacterValue, Str2FirstCharacterValue);
+ Value *IsFirstCharacterSubZero =
+ B.CreateICmpEQ(FirstCharacterSub, ConstantInt::get(RetType, 0));
+ Instruction *IsFirstCharacterSubZeroBBTerminator = SplitBlockAndInsertIfThen(
+ IsFirstCharacterSubZero, CI, /*Unreachable*/ false,
+ /*BranchWeights*/ nullptr, &DTU);
+
+ B.SetInsertPoint(IsFirstCharacterSubZeroBBTerminator);
+ B.GetInsertBlock()->setName("strcmp_expand_sub_is_zero");
+ BasicBlock *IsFirstCharacterSubZeroBB = B.GetInsertBlock();
+ Value *Str1SecondCharacterValue = B.CreateZExt(
+ B.CreateLoad(B.getInt8Ty(), B.CreateConstInBoundsGEP1_64(
+ B.getInt8Ty(), NonConstantP, 1)),
+ RetType);
+
+ B.SetInsertPoint(CI);
+ B.GetInsertBlock()->setName("strcmp_expand_sub_join");
+
+ PHINode *Result = B.CreatePHI(RetType, 2);
+ Result->addIncoming(FirstCharacterSub, InitialBB);
+ Result->addIncoming(Str1SecondCharacterValue, IsFirstCharacterSubZeroBB);
+
+ CI->replaceAllUsesWith(Result);
+ CI->eraseFromParent();
+
+ MadeCFGChange = true;
+
+ return true;
+}
+
+static bool foldLibraryCalls(Instruction &I, TargetTransformInfo &TTI,
+ TargetLibraryInfo &TLI, DominatorTree &DT,
+ AssumptionCache &AC, bool &MadeCFGChange) {
+ CallInst *CI = dyn_cast<CallInst>(&I);
+ if (!CI)
+ return false;
+
+ LibFunc Func;
+ Module *M = I.getModule();
+ if (!TLI.getLibFunc(*CI, Func) || !isLibFuncEmittable(M, &TLI, Func))
+ return false;
+
+ switch (Func) {
+ case LibFunc_sqrt:
+ case LibFunc_sqrtf:
+ case LibFunc_sqrtl:
+ return foldSqrt(CI, TTI, TLI, AC, DT);
+ case LibFunc_strcmp:
+ return expandStrcmp(CI, DT, MadeCFGChange);
+ default:
+ break;
+ }
+
+ return false;
+}
+
/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
TargetTransformInfo &TTI,
- TargetLibraryInfo &TLI, AliasAnalysis &AA) {
+ TargetLibraryInfo &TLI, AliasAnalysis &AA,
+ AssumptionCache &AC, bool &MadeCFGChange) {
bool MadeChange = false;
for (BasicBlock &BB : F) {
// Ignore unreachable basic blocks.
@@ -849,11 +1041,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
MadeChange |= tryToRecognizePopCount(I);
MadeChange |= tryToFPToSat(I, TTI);
MadeChange |= tryToRecognizeTableBasedCttz(I);
- MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA);
+ MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
+ MadeChange |= foldPatternedLoads(I, DL);
// NOTE: This function introduces erasing of the instruction `I`, so it
// needs to be called at the end of this sequence, otherwise we may make
// bugs.
- MadeChange |= foldSqrt(I, TTI, TLI);
+ MadeChange |= foldLibraryCalls(I, TTI, TLI, DT, AC, MadeCFGChange);
}
}
@@ -869,12 +1062,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
/// handled in the callers of this function.
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
TargetLibraryInfo &TLI, DominatorTree &DT,
- AliasAnalysis &AA) {
+ AliasAnalysis &AA, bool &ChangedCFG) {
bool MadeChange = false;
const DataLayout &DL = F.getParent()->getDataLayout();
TruncInstCombine TIC(AC, TLI, DL, DT);
MadeChange |= TIC.run(F);
- MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA);
+ MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, ChangedCFG);
return MadeChange;
}
@@ -885,12 +1078,21 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
auto &AA = AM.getResult<AAManager>(F);
- if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
+
+ bool MadeCFGChange = false;
+
+ if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
// No changes, all analyses are preserved.
return PreservedAnalyses::all();
}
+
// Mark all the analyses that instcombine updates as preserved.
PreservedAnalyses PA;
- PA.preserveSet<CFGAnalyses>();
+
+ if (MadeCFGChange)
+ PA.preserve<DominatorTreeAnalysis>();
+ else
+ PA.preserveSet<CFGAnalyses>();
+
return PA;
}
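
The expandStrcmp hunk above rewrites strcmp(P, "x") calls whose result only feeds a comparison with zero. The following is a plain C++ sketch (not the pass itself; the function name and test inputs are illustrative) of what the expanded control flow computes: v0 = P[0] - 'x', then P[1] if v0 is zero, with only the final zero comparison consumed.

// Illustrative only: mirrors the shape of the IR produced by expandStrcmp
// for `strcmp(p, "x") == 0`.
#include <cassert>
#include <cstring>

static bool strcmp_with_x_is_zero(const char *p) {
  int v0 = (unsigned char)p[0] - (unsigned char)'x'; // first-character difference
  int result = v0;
  if (v0 == 0)                    // first characters match; p[1] decides
    result = (unsigned char)p[1]; // zero iff p is exactly "x"
  return result == 0;             // only the comparison with zero is used
}

int main() {
  const char *tests[] = {"x", "xy", "a", ""};
  for (const char *t : tests)
    assert(strcmp_with_x_is_zero(t) == (std::strcmp(t, "x") == 0));
  return 0;
}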
diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
index bebaa6cb5969..bf823ac55497 100644
--- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp
+++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
@@ -15,12 +15,12 @@
#include "llvm/Transforms/CFGuard.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/TargetParser/Triple.h"
using namespace llvm;
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index 81b43a2ab2c2..29978bef661c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -127,10 +127,16 @@ PreservedAnalyses CoroCleanupPass::run(Module &M,
FunctionPassManager FPM;
FPM.addPass(SimplifyCFGPass());
+ PreservedAnalyses FuncPA;
+ FuncPA.preserveSet<CFGAnalyses>();
+
Lowerer L(M);
- for (auto &F : M)
- if (L.lower(F))
+ for (auto &F : M) {
+ if (L.lower(F)) {
+ FAM.invalidate(F, FuncPA);
FPM.run(F, FAM);
+ }
+ }
return PreservedAnalyses::none();
}
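
A minimal sketch of the invalidate-then-run pattern adopted in the hunk above, assuming LLVM's new pass manager APIs (the helper name `cleanupLowered` and the `lowerOne` callback are stand-ins, not CoroCleanup's actual interface): a module-level transform that rewrites a function outside of any function pass should tell the FunctionAnalysisManager which analyses went stale before re-running function passes on it; preserving only CFGAnalyses records that the lowering rewrote instructions without changing the CFG.

// Sketch only, not the CoroCleanup implementation.
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"

using namespace llvm;

static void cleanupLowered(Module &M, FunctionAnalysisManager &FAM,
                           function_ref<bool(Function &)> lowerOne) {
  FunctionPassManager FPM;
  FPM.addPass(SimplifyCFGPass());

  // The lowering leaves the CFG intact, so CFG-only analyses may be kept.
  PreservedAnalyses FuncPA;
  FuncPA.preserveSet<CFGAnalyses>();

  for (Function &F : M)
    if (lowerOne(F)) {           // hypothetical stand-in for Lowerer::lower
      FAM.invalidate(F, FuncPA); // drop analyses the lowering invalidated
      FPM.run(F, FAM);           // then let function passes clean up
    }
}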
diff --git a/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp b/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp
index 974123fe36a1..3e71e58bb1de 100644
--- a/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp
@@ -26,7 +26,7 @@ PreservedAnalyses CoroConditionalWrapper::run(Module &M,
void CoroConditionalWrapper::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
OS << "coro-cond";
- OS << "(";
+ OS << '(';
PM.printPipeline(OS, MapClassName2PassName);
- OS << ")";
+ OS << ')';
}
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index f032c568449b..d78ab1c1ea28 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -12,6 +12,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Support/ErrorHandling.h"
@@ -46,7 +47,8 @@ struct Lowerer : coro::LowererBase {
AAResults &AA);
bool shouldElide(Function *F, DominatorTree &DT) const;
void collectPostSplitCoroIds(Function *F);
- bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT);
+ bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT,
+ OptimizationRemarkEmitter &ORE);
bool hasEscapePath(const CoroBeginInst *,
const SmallPtrSetImpl<BasicBlock *> &) const;
};
@@ -299,7 +301,7 @@ void Lowerer::collectPostSplitCoroIds(Function *F) {
}
bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
- DominatorTree &DT) {
+ DominatorTree &DT, OptimizationRemarkEmitter &ORE) {
CoroBegins.clear();
CoroAllocs.clear();
ResumeAddr.clear();
@@ -343,6 +345,24 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
replaceWithConstant(ResumeAddrConstant, ResumeAddr);
bool ShouldElide = shouldElide(CoroId->getFunction(), DT);
+ if (!ShouldElide)
+ ORE.emit([&]() {
+ if (auto FrameSizeAndAlign =
+ getFrameLayout(cast<Function>(ResumeAddrConstant)))
+ return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
+ << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
+ << "' not elided in '"
+ << ore::NV("caller", CoroId->getFunction()->getName())
+ << "' (frame_size="
+ << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align="
+ << ore::NV("align", FrameSizeAndAlign->second.value()) << ")";
+ else
+ return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
+ << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
+ << "' not elided in '"
+ << ore::NV("caller", CoroId->getFunction()->getName())
+ << "' (frame_size=unknown, align=unknown)";
+ });
auto *DestroyAddrConstant = Resumers->getAggregateElement(
ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);
@@ -363,6 +383,23 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
<< "Elide " << CoroId->getCoroutine()->getName() << " in "
<< CoroId->getFunction()->getName() << "\n";
#endif
+ ORE.emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "CoroElide", CoroId)
+ << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
+ << "' elided in '"
+ << ore::NV("caller", CoroId->getFunction()->getName())
+ << "' (frame_size="
+ << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align="
+ << ore::NV("align", FrameSizeAndAlign->second.value()) << ")";
+ });
+ } else {
+ ORE.emit([&]() {
+ return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
+ << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
+ << "' not elided in '"
+ << ore::NV("caller", CoroId->getFunction()->getName())
+ << "' (frame_size=unknown, align=unknown)";
+ });
}
}
@@ -387,10 +424,11 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
AAResults &AA = AM.getResult<AAManager>(F);
DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
bool Changed = false;
for (auto *CII : L.CoroIds)
- Changed |= L.processCoroId(CII, AA, DT);
+ Changed |= L.processCoroId(CII, AA, DT, ORE);
return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
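
The remarks added above can be surfaced in the usual ways, for example with `-Rpass=coro-elide` / `-Rpass-missed=coro-elide` on the clang command line or with `-fsave-optimization-record`. With the format strings in this hunk, the output reads roughly as follows; the coroutine and caller names and the frame numbers here are made up purely for illustration:

remark: 'producer' elided in 'consumer' (frame_size=40, align=8) [-Rpass=coro-elide]
remark: 'other_coro' not elided in 'consumer' (frame_size=unknown, align=unknown) [-Rpass-missed=coro-elide]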
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index e98c601648e0..1f373270f951 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -16,6 +16,7 @@
#include "CoroInternal.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Analysis/PtrUseVisitor.h"
@@ -37,6 +38,7 @@
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include <algorithm>
+#include <deque>
#include <optional>
using namespace llvm;
@@ -87,7 +89,7 @@ public:
// crosses a suspend point.
//
namespace {
-struct SuspendCrossingInfo {
+class SuspendCrossingInfo {
BlockToIndexMapping Mapping;
struct BlockData {
@@ -96,20 +98,30 @@ struct SuspendCrossingInfo {
bool Suspend = false;
bool End = false;
bool KillLoop = false;
+ bool Changed = false;
};
SmallVector<BlockData, SmallVectorThreshold> Block;
- iterator_range<succ_iterator> successors(BlockData const &BD) const {
+ iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
- return llvm::successors(BB);
+ return llvm::predecessors(BB);
}
BlockData &getBlockData(BasicBlock *BB) {
return Block[Mapping.blockToIndex(BB)];
}
+ /// Compute the BlockData for the current function in one iteration.
+ /// Returns whether the BlockData changes in this iteration.
+ /// Initialize - Whether this is the first iteration, we can optimize
+ /// the initial case a little bit by manual loop switch.
+ template <bool Initialize = false> bool computeBlockData();
+
+public:
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void dump() const;
void dump(StringRef Label, BitVector const &BV) const;
+#endif
SuspendCrossingInfo(Function &F, coro::Shape &Shape);
@@ -211,6 +223,72 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
}
#endif
+template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() {
+ const size_t N = Mapping.size();
+ bool Changed = false;
+
+ for (size_t I = 0; I < N; ++I) {
+ auto &B = Block[I];
+
+ // We don't need to count the predecessors when initialization.
+ if constexpr (!Initialize)
+ // If all the predecessors of the current Block don't change,
+ // the BlockData for the current block must not change too.
+ if (all_of(predecessors(B), [this](BasicBlock *BB) {
+ return !Block[Mapping.blockToIndex(BB)].Changed;
+ })) {
+ B.Changed = false;
+ continue;
+ }
+
+ // Saved Consumes and Kills bitsets so that it is easy to see
+ // if anything changed after propagation.
+ auto SavedConsumes = B.Consumes;
+ auto SavedKills = B.Kills;
+
+ for (BasicBlock *PI : predecessors(B)) {
+ auto PrevNo = Mapping.blockToIndex(PI);
+ auto &P = Block[PrevNo];
+
+ // Propagate Kills and Consumes from predecessors into B.
+ B.Consumes |= P.Consumes;
+ B.Kills |= P.Kills;
+
+ // If block P is a suspend block, it should propagate kills into block
+ // B for every block P consumes.
+ if (P.Suspend)
+ B.Kills |= P.Consumes;
+ }
+
+ if (B.Suspend) {
+ // If block S is a suspend block, it should kill all of the blocks it
+ // consumes.
+ B.Kills |= B.Consumes;
+ } else if (B.End) {
+ // If block B is an end block, it should not propagate kills as the
+ // blocks following coro.end() are reached during initial invocation
+ // of the coroutine while all the data are still available on the
+ // stack or in the registers.
+ B.Kills.reset();
+ } else {
+ // This is reached when B block it not Suspend nor coro.end and it
+ // need to make sure that it is not in the kill set.
+ B.KillLoop |= B.Kills[I];
+ B.Kills.reset(I);
+ }
+
+ if constexpr (!Initialize) {
+ B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
+ Changed |= B.Changed;
+ }
+ }
+
+ if constexpr (Initialize)
+ return true;
+
+ return Changed;
+}
+
SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
: Mapping(F) {
const size_t N = Mapping.size();
@@ -222,6 +300,7 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
B.Consumes.resize(N);
B.Kills.resize(N);
B.Consumes.set(I);
+ B.Changed = true;
}
// Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
@@ -246,73 +325,123 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
markSuspendBlock(Save);
}
- // Iterate propagating consumes and kills until they stop changing.
- int Iteration = 0;
- (void)Iteration;
+ computeBlockData</*Initialize=*/true>();
- bool Changed;
- do {
- LLVM_DEBUG(dbgs() << "iteration " << ++Iteration);
- LLVM_DEBUG(dbgs() << "==============\n");
-
- Changed = false;
- for (size_t I = 0; I < N; ++I) {
- auto &B = Block[I];
- for (BasicBlock *SI : successors(B)) {
-
- auto SuccNo = Mapping.blockToIndex(SI);
-
- // Saved Consumes and Kills bitsets so that it is easy to see
- // if anything changed after propagation.
- auto &S = Block[SuccNo];
- auto SavedConsumes = S.Consumes;
- auto SavedKills = S.Kills;
-
- // Propagate Kills and Consumes from block B into its successor S.
- S.Consumes |= B.Consumes;
- S.Kills |= B.Kills;
-
- // If block B is a suspend block, it should propagate kills into the
- // its successor for every block B consumes.
- if (B.Suspend) {
- S.Kills |= B.Consumes;
- }
- if (S.Suspend) {
- // If block S is a suspend block, it should kill all of the blocks it
- // consumes.
- S.Kills |= S.Consumes;
- } else if (S.End) {
- // If block S is an end block, it should not propagate kills as the
- // blocks following coro.end() are reached during initial invocation
- // of the coroutine while all the data are still available on the
- // stack or in the registers.
- S.Kills.reset();
- } else {
- // This is reached when S block it not Suspend nor coro.end and it
- // need to make sure that it is not in the kill set.
- S.KillLoop |= S.Kills[SuccNo];
- S.Kills.reset(SuccNo);
- }
+ while (computeBlockData())
+ ;
+
+ LLVM_DEBUG(dump());
+}
- // See if anything changed.
- Changed |= (S.Kills != SavedKills) || (S.Consumes != SavedConsumes);
+namespace {
- if (S.Kills != SavedKills) {
- LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName()
- << "\n");
- LLVM_DEBUG(dump("S.Kills", S.Kills));
- LLVM_DEBUG(dump("SavedKills", SavedKills));
- }
- if (S.Consumes != SavedConsumes) {
- LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI << "\n");
- LLVM_DEBUG(dump("S.Consume", S.Consumes));
- LLVM_DEBUG(dump("SavedCons", SavedConsumes));
+// RematGraph is used to construct a DAG for rematerializable instructions
+// When the constructor is invoked with a candidate instruction (which is
+// materializable) it builds a DAG of materializable instructions from that
+// point.
+// Typically, for each instruction identified as re-materializable across a
+// suspend point, a RematGraph will be created.
+struct RematGraph {
+ // Each RematNode in the graph contains the edges to instructions providing
+ // operands in the current node.
+ struct RematNode {
+ Instruction *Node;
+ SmallVector<RematNode *> Operands;
+ RematNode() = default;
+ RematNode(Instruction *V) : Node(V) {}
+ };
+
+ RematNode *EntryNode;
+ using RematNodeMap =
+ SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
+ RematNodeMap Remats;
+ const std::function<bool(Instruction &)> &MaterializableCallback;
+ SuspendCrossingInfo &Checker;
+
+ RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
+ Instruction *I, SuspendCrossingInfo &Checker)
+ : MaterializableCallback(MaterializableCallback), Checker(Checker) {
+ std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
+ EntryNode = FirstNode.get();
+ std::deque<std::unique_ptr<RematNode>> WorkList;
+ addNode(std::move(FirstNode), WorkList, cast<User>(I));
+ while (WorkList.size()) {
+ std::unique_ptr<RematNode> N = std::move(WorkList.front());
+ WorkList.pop_front();
+ addNode(std::move(N), WorkList, cast<User>(I));
+ }
+ }
+
+ void addNode(std::unique_ptr<RematNode> NUPtr,
+ std::deque<std::unique_ptr<RematNode>> &WorkList,
+ User *FirstUse) {
+ RematNode *N = NUPtr.get();
+ if (Remats.count(N->Node))
+ return;
+
+ // We haven't see this node yet - add to the list
+ Remats[N->Node] = std::move(NUPtr);
+ for (auto &Def : N->Node->operands()) {
+ Instruction *D = dyn_cast<Instruction>(Def.get());
+ if (!D || !MaterializableCallback(*D) ||
+ !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
+ continue;
+
+ if (Remats.count(D)) {
+ // Already have this in the graph
+ N->Operands.push_back(Remats[D].get());
+ continue;
+ }
+
+ bool NoMatch = true;
+ for (auto &I : WorkList) {
+ if (I->Node == D) {
+ NoMatch = false;
+ N->Operands.push_back(I.get());
+ break;
}
}
+ if (NoMatch) {
+ // Create a new node
+ std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D);
+ N->Operands.push_back(ChildNode.get());
+ WorkList.push_back(std::move(ChildNode));
+ }
}
- } while (Changed);
- LLVM_DEBUG(dump());
-}
+ }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ void dump() const {
+ dbgs() << "Entry (";
+ if (EntryNode->Node->getParent()->hasName())
+ dbgs() << EntryNode->Node->getParent()->getName();
+ else
+ EntryNode->Node->getParent()->printAsOperand(dbgs(), false);
+ dbgs() << ") : " << *EntryNode->Node << "\n";
+ for (auto &E : Remats) {
+ dbgs() << *(E.first) << "\n";
+ for (RematNode *U : E.second->Operands)
+ dbgs() << " " << *U->Node << "\n";
+ }
+ }
+#endif
+};
+} // end anonymous namespace
+
+namespace llvm {
+
+template <> struct GraphTraits<RematGraph *> {
+ using NodeRef = RematGraph::RematNode *;
+ using ChildIteratorType = RematGraph::RematNode **;
+
+ static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
+ static ChildIteratorType child_begin(NodeRef N) {
+ return N->Operands.begin();
+ }
+ static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
+};
+
+} // end namespace llvm
#undef DEBUG_TYPE // "coro-suspend-crossing"
#define DEBUG_TYPE "coro-frame"
@@ -425,6 +554,15 @@ static void dumpSpills(StringRef Title, const SpillInfo &Spills) {
I->dump();
}
}
+static void dumpRemats(
+ StringRef Title,
+ const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
+ dbgs() << "------------- " << Title << "--------------\n";
+ for (const auto &E : RM) {
+ E.second->dump();
+ dbgs() << "--\n";
+ }
+}
static void dumpAllocas(const SmallVectorImpl<AllocaInfo> &Allocas) {
dbgs() << "------------- Allocas --------------\n";
@@ -637,10 +775,10 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F,
return;
}
- // Because there are pathes from the lifetime.start to coro.end
+ // Because there are paths from the lifetime.start to coro.end
// for each alloca, the liferanges for every alloca is overlaped
// in the blocks who contain coro.end and the successor blocks.
- // So we choose to skip there blocks when we calculates the liferange
+ // So we choose to skip there blocks when we calculate the liferange
// for each alloca. It should be reasonable since there shouldn't be uses
// in these blocks and the coroutine frame shouldn't be used outside the
// coroutine body.
@@ -820,7 +958,7 @@ void FrameTypeBuilder::finish(StructType *Ty) {
static void cacheDIVar(FrameDataInfo &FrameData,
DenseMap<Value *, DILocalVariable *> &DIVarCache) {
for (auto *V : FrameData.getAllDefs()) {
- if (DIVarCache.find(V) != DIVarCache.end())
+ if (DIVarCache.contains(V))
continue;
auto DDIs = FindDbgDeclareUses(V);
@@ -852,18 +990,8 @@ static StringRef solveTypeName(Type *Ty) {
return "__floating_type_";
}
- if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
- if (PtrTy->isOpaque())
- return "PointerType";
- Type *PointeeTy = PtrTy->getNonOpaquePointerElementType();
- auto Name = solveTypeName(PointeeTy);
- if (Name == "UnknownType")
- return "PointerType";
- SmallString<16> Buffer;
- Twine(Name + "_Ptr").toStringRef(Buffer);
- auto *MDName = MDString::get(Ty->getContext(), Buffer.str());
- return MDName->getString();
- }
+ if (Ty->isPointerTy())
+ return "PointerType";
if (Ty->isStructTy()) {
if (!cast<StructType>(Ty)->hasName())
@@ -1043,7 +1171,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape,
dwarf::DW_ATE_unsigned_char)});
for (auto *V : FrameData.getAllDefs()) {
- if (DIVarCache.find(V) == DIVarCache.end())
+ if (!DIVarCache.contains(V))
continue;
auto Index = FrameData.getFieldIndex(V);
@@ -1075,7 +1203,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape,
// fields confilicts with each other.
unsigned UnknownTypeNum = 0;
for (unsigned Index = 0; Index < FrameTy->getNumElements(); Index++) {
- if (OffsetCache.find(Index) == OffsetCache.end())
+ if (!OffsetCache.contains(Index))
continue;
std::string Name;
@@ -1090,7 +1218,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape,
AlignInBits = OffsetCache[Index].first * 8;
OffsetInBits = OffsetCache[Index].second * 8;
- if (NameCache.find(Index) != NameCache.end()) {
+ if (NameCache.contains(Index)) {
Name = NameCache[Index].str();
DITy = TyCache[Index];
} else {
@@ -1282,7 +1410,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape,
// function call or any of the memory intrinsics, we check whether this
// instruction is prior to CoroBegin. To answer question 3, we track the offsets
// of all aliases created for the alloca prior to CoroBegin but used after
-// CoroBegin. llvm::Optional is used to be able to represent the case when the
+// CoroBegin. std::optional is used to be able to represent the case when the
// offset is unknown (e.g. when you have a PHINode that takes in different
// offset values). We cannot handle unknown offsets and will assert. This is the
// potential issue left out. An ideal solution would likely require a
@@ -1586,11 +1714,12 @@ static void createFramePtr(coro::Shape &Shape) {
static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
auto *CB = Shape.CoroBegin;
LLVMContext &C = CB->getContext();
+ Function *F = CB->getFunction();
IRBuilder<> Builder(C);
StructType *FrameTy = Shape.FrameTy;
Value *FramePtr = Shape.FramePtr;
- DominatorTree DT(*CB->getFunction());
- SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache;
+ DominatorTree DT(*F);
+ SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap;
// Create a GEP with the given index into the coroutine frame for the original
// value Orig. Appends an extra 0 index for array-allocas, preserving the
@@ -1723,6 +1852,21 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
SpillAlignment, E.first->getName() + Twine(".reload"));
TinyPtrVector<DbgDeclareInst *> DIs = FindDbgDeclareUses(Def);
+ // Try best to find dbg.declare. If the spill is a temp, there may not
+ // be a direct dbg.declare. Walk up the load chain to find one from an
+ // alias.
+ if (F->getSubprogram()) {
+ auto *CurDef = Def;
+ while (DIs.empty() && isa<LoadInst>(CurDef)) {
+ auto *LdInst = cast<LoadInst>(CurDef);
+ // Only consider ptr to ptr same type load.
+ if (LdInst->getPointerOperandType() != LdInst->getType())
+ break;
+ CurDef = LdInst->getPointerOperand();
+ DIs = FindDbgDeclareUses(CurDef);
+ }
+ }
+
for (DbgDeclareInst *DDI : DIs) {
bool AllowUnresolved = false;
// This dbg.declare is preserved for all coro-split function
@@ -1734,16 +1878,10 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
&*Builder.GetInsertPoint());
// This dbg.declare is for the main function entry point. It
// will be deleted in all coro-split functions.
- coro::salvageDebugInfo(DbgPtrAllocaCache, DDI, Shape.OptimizeFrame);
+ coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame);
}
}
- // Salvage debug info on any dbg.addr that we see. We do not insert them
- // into each block where we have a use though.
- if (auto *DI = dyn_cast<DbgAddrIntrinsic>(U)) {
- coro::salvageDebugInfo(DbgPtrAllocaCache, DI, Shape.OptimizeFrame);
- }
-
// If we have a single edge PHINode, remove it and replace it with a
// reload from the coroutine frame. (We already took care of multi edge
// PHINodes by rewriting them in the rewritePHIs function).
@@ -1813,11 +1951,13 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
DVI->replaceUsesOfWith(Alloca, G);
for (Instruction *I : UsersToUpdate) {
- // It is meaningless to remain the lifetime intrinsics refer for the
+ // It is meaningless to retain the lifetime intrinsics refer for the
// member of coroutine frames and the meaningless lifetime intrinsics
// are possible to block further optimizations.
- if (I->isLifetimeStartOrEnd())
+ if (I->isLifetimeStartOrEnd()) {
+ I->eraseFromParent();
continue;
+ }
I->replaceUsesOfWith(Alloca, G);
}
@@ -2089,11 +2229,12 @@ static void rewritePHIs(Function &F) {
rewritePHIs(*BB);
}
+/// Default materializable callback
// Check for instructions that we can recreate on resume as opposed to spill
// the result into a coroutine frame.
-static bool materializable(Instruction &V) {
- return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
- isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
+bool coro::defaultMaterializable(Instruction &V) {
+ return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
+ isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
}
// Check for structural coroutine intrinsics that should not be spilled into
@@ -2103,41 +2244,82 @@ static bool isCoroutineStructureIntrinsic(Instruction &I) {
isa<CoroSuspendInst>(&I);
}
-// For every use of the value that is across suspend point, recreate that value
-// after a suspend point.
-static void rewriteMaterializableInstructions(IRBuilder<> &IRB,
- const SpillInfo &Spills) {
- for (const auto &E : Spills) {
- Value *Def = E.first;
- BasicBlock *CurrentBlock = nullptr;
+// For each instruction identified as materializable across the suspend point,
+// and its associated DAG of other rematerializable instructions,
+// recreate the DAG of instructions after the suspend point.
+static void rewriteMaterializableInstructions(
+ const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
+ &AllRemats) {
+ // This has to be done in 2 phases
+ // Do the remats and record the required defs to be replaced in the
+ // original use instructions
+ // Once all the remats are complete, replace the uses in the final
+ // instructions with the new defs
+ typedef struct {
+ Instruction *Use;
+ Instruction *Def;
+ Instruction *Remat;
+ } ProcessNode;
+
+ SmallVector<ProcessNode> FinalInstructionsToProcess;
+
+ for (const auto &E : AllRemats) {
+ Instruction *Use = E.first;
Instruction *CurrentMaterialization = nullptr;
- for (Instruction *U : E.second) {
- // If we have not seen this block, materialize the value.
- if (CurrentBlock != U->getParent()) {
+ RematGraph *RG = E.second.get();
+ ReversePostOrderTraversal<RematGraph *> RPOT(RG);
+ SmallVector<Instruction *> InstructionsToProcess;
+
+ // If the target use is actually a suspend instruction then we have to
+ // insert the remats into the end of the predecessor (there should only be
+ // one). This is so that suspend blocks always have the suspend instruction
+ // as the first instruction.
+ auto InsertPoint = &*Use->getParent()->getFirstInsertionPt();
+ if (isa<AnyCoroSuspendInst>(Use)) {
+ BasicBlock *SuspendPredecessorBlock =
+ Use->getParent()->getSinglePredecessor();
+ assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
+ InsertPoint = SuspendPredecessorBlock->getTerminator();
+ }
- bool IsInCoroSuspendBlock = isa<AnyCoroSuspendInst>(U);
- CurrentBlock = U->getParent();
- auto *InsertBlock = IsInCoroSuspendBlock
- ? CurrentBlock->getSinglePredecessor()
- : CurrentBlock;
- CurrentMaterialization = cast<Instruction>(Def)->clone();
- CurrentMaterialization->setName(Def->getName());
- CurrentMaterialization->insertBefore(
- IsInCoroSuspendBlock ? InsertBlock->getTerminator()
- : &*InsertBlock->getFirstInsertionPt());
- }
- if (auto *PN = dyn_cast<PHINode>(U)) {
- assert(PN->getNumIncomingValues() == 1 &&
- "unexpected number of incoming "
- "values in the PHINode");
- PN->replaceAllUsesWith(CurrentMaterialization);
- PN->eraseFromParent();
- continue;
- }
- // Replace all uses of Def in the current instruction with the
- // CurrentMaterialization for the block.
- U->replaceUsesOfWith(Def, CurrentMaterialization);
+ // Note: skip the first instruction as this is the actual use that we're
+ // rematerializing everything for.
+ auto I = RPOT.begin();
+ ++I;
+ for (; I != RPOT.end(); ++I) {
+ Instruction *D = (*I)->Node;
+ CurrentMaterialization = D->clone();
+ CurrentMaterialization->setName(D->getName());
+ CurrentMaterialization->insertBefore(InsertPoint);
+ InsertPoint = CurrentMaterialization;
+
+ // Replace all uses of Def in the instructions being added as part of this
+ // rematerialization group
+ for (auto &I : InstructionsToProcess)
+ I->replaceUsesOfWith(D, CurrentMaterialization);
+
+ // Don't replace the final use at this point as this can cause problems
+ // for other materializations. Instead, for any final use that uses a
+      // def that's being rematerialized, record the replacement values.
+ for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i)
+ if (Use->getOperand(i) == D) // Is this operand pointing to oldval?
+ FinalInstructionsToProcess.push_back(
+ {Use, D, CurrentMaterialization});
+
+ InstructionsToProcess.push_back(CurrentMaterialization);
+ }
+ }
+
+  // Finally, replace the uses with the defs that we've just rematerialized.
+ for (auto &R : FinalInstructionsToProcess) {
+ if (auto *PN = dyn_cast<PHINode>(R.Use)) {
+ assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
+ "values in the PHINode");
+ PN->replaceAllUsesWith(R.Remat);
+ PN->eraseFromParent();
+ continue;
}
+ R.Use->replaceUsesOfWith(R.Def, R.Remat);
}
}
@@ -2407,10 +2589,7 @@ static void eliminateSwiftErrorArgument(Function &F, Argument &Arg,
IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
auto ArgTy = cast<PointerType>(Arg.getType());
- // swifterror arguments are required to have pointer-to-pointer type,
- // so create a pointer-typed alloca with opaque pointers.
- auto ValueTy = ArgTy->isOpaque() ? PointerType::getUnqual(F.getContext())
- : ArgTy->getNonOpaquePointerElementType();
+ auto ValueTy = PointerType::getUnqual(F.getContext());
// Reduce to the alloca case:
@@ -2523,6 +2702,9 @@ static void sinkSpillUsesAfterCoroBegin(Function &F,
/// hence minimizing the amount of data we end up putting on the frame.
static void sinkLifetimeStartMarkers(Function &F, coro::Shape &Shape,
SuspendCrossingInfo &Checker) {
+ if (F.hasOptNone())
+ return;
+
DominatorTree DT(F);
// Collect all possible basic blocks which may dominate all uses of allocas.
@@ -2635,7 +2817,7 @@ static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape,
}
void coro::salvageDebugInfo(
- SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> &DbgPtrAllocaCache,
+ SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap,
DbgVariableIntrinsic *DVI, bool OptimizeFrame) {
Function *F = DVI->getFunction();
IRBuilder<> Builder(F->getContext());
@@ -2652,7 +2834,7 @@ void coro::salvageDebugInfo(
while (auto *Inst = dyn_cast_or_null<Instruction>(Storage)) {
if (auto *LdInst = dyn_cast<LoadInst>(Inst)) {
- Storage = LdInst->getOperand(0);
+ Storage = LdInst->getPointerOperand();
// FIXME: This is a heuristic that works around the fact that
// LLVM IR debug intrinsics cannot yet distinguish between
// memory and value locations: Because a dbg.declare(alloca) is
@@ -2662,7 +2844,7 @@ void coro::salvageDebugInfo(
if (!SkipOutermostLoad)
Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore);
} else if (auto *StInst = dyn_cast<StoreInst>(Inst)) {
- Storage = StInst->getOperand(0);
+ Storage = StInst->getValueOperand();
} else {
SmallVector<uint64_t, 16> Ops;
SmallVector<Value *, 0> AdditionalValues;
@@ -2682,38 +2864,44 @@ void coro::salvageDebugInfo(
if (!Storage)
return;
- // Store a pointer to the coroutine frame object in an alloca so it
- // is available throughout the function when producing unoptimized
- // code. Extending the lifetime this way is correct because the
- // variable has been declared by a dbg.declare intrinsic.
- //
- // Avoid to create the alloca would be eliminated by optimization
- // passes and the corresponding dbg.declares would be invalid.
- if (!OptimizeFrame)
- if (auto *Arg = dyn_cast<llvm::Argument>(Storage)) {
- auto &Cached = DbgPtrAllocaCache[Storage];
- if (!Cached) {
- Cached = Builder.CreateAlloca(Storage->getType(), 0, nullptr,
- Arg->getName() + ".debug");
- Builder.CreateStore(Storage, Cached);
- }
- Storage = Cached;
- // FIXME: LLVM lacks nuanced semantics to differentiate between
- // memory and direct locations at the IR level. The backend will
- // turn a dbg.declare(alloca, ..., DIExpression()) into a memory
- // location. Thus, if there are deref and offset operations in the
- // expression, we need to add a DW_OP_deref at the *start* of the
- // expression to first load the contents of the alloca before
- // adjusting it with the expression.
- Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore);
+ auto *StorageAsArg = dyn_cast<Argument>(Storage);
+ const bool IsSwiftAsyncArg =
+ StorageAsArg && StorageAsArg->hasAttribute(Attribute::SwiftAsync);
+
+ // Swift async arguments are described by an entry value of the ABI-defined
+ // register containing the coroutine context.
+ if (IsSwiftAsyncArg && !Expr->isEntryValue())
+ Expr = DIExpression::prepend(Expr, DIExpression::EntryValue);
+
+ // If the coroutine frame is an Argument, store it in an alloca to improve
+ // its availability (e.g. registers may be clobbered).
+ // Avoid this if optimizations are enabled (they would remove the alloca) or
+ // if the value is guaranteed to be available through other means (e.g. swift
+ // ABI guarantees).
+ if (StorageAsArg && !OptimizeFrame && !IsSwiftAsyncArg) {
+ auto &Cached = ArgToAllocaMap[StorageAsArg];
+ if (!Cached) {
+ Cached = Builder.CreateAlloca(Storage->getType(), 0, nullptr,
+ Storage->getName() + ".debug");
+ Builder.CreateStore(Storage, Cached);
}
+ Storage = Cached;
+ // FIXME: LLVM lacks nuanced semantics to differentiate between
+ // memory and direct locations at the IR level. The backend will
+ // turn a dbg.declare(alloca, ..., DIExpression()) into a memory
+ // location. Thus, if there are deref and offset operations in the
+ // expression, we need to add a DW_OP_deref at the *start* of the
+ // expression to first load the contents of the alloca before
+ // adjusting it with the expression.
+ Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore);
+ }
DVI->replaceVariableLocationOp(OriginalStorage, Storage);
DVI->setExpression(Expr);
// We only hoist dbg.declare today since it doesn't make sense to hoist
- // dbg.value or dbg.addr since they do not have the same function wide
- // guarantees that dbg.declare does.
- if (!isa<DbgValueInst>(DVI) && !isa<DbgAddrIntrinsic>(DVI)) {
+ // dbg.value since it does not have the same function wide guarantees that
+ // dbg.declare does.
+ if (isa<DbgDeclareInst>(DVI)) {
Instruction *InsertPt = nullptr;
if (auto *I = dyn_cast<Instruction>(Storage))
InsertPt = I->getInsertionPointAfterDef();
@@ -2724,7 +2912,71 @@ void coro::salvageDebugInfo(
}
}
-void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
+static void doRematerializations(
+ Function &F, SuspendCrossingInfo &Checker,
+ const std::function<bool(Instruction &)> &MaterializableCallback) {
+ if (F.hasOptNone())
+ return;
+
+ SpillInfo Spills;
+
+  // See if there are materializable instructions across suspend points. We
+  // record these as the starting points from which we also identify
+  // materializable defs used by these instructions.
+ for (Instruction &I : instructions(F)) {
+ if (!MaterializableCallback(I))
+ continue;
+ for (User *U : I.users())
+ if (Checker.isDefinitionAcrossSuspend(I, U))
+ Spills[&I].push_back(cast<Instruction>(U));
+ }
+
+  // Process each of the identified rematerializable instructions
+  // and add predecessor instructions that can also be rematerialized.
+  // This is actually a graph of instructions since we could potentially
+  // have multiple uses of a def in the set of predecessor instructions.
+  // The approach here is to maintain a graph of instructions for each
+  // bottom-level instruction - where we have a unique set of instructions
+  // (nodes) and edges between them. We then walk the graph in reverse
+  // post-order to insert the clones past the suspend point, while keeping
+  // the ordering correct. We also rely on CSE removing duplicate defs for
+  // remats of different instructions with a def in common (rather than
+  // maintaining more complex graphs for each suspend point).
+
+  // We can do this by adding new nodes to the list for each suspend
+  // point, then using standard GraphTraits to drive a reverse post-order
+  // traversal when we insert the nodes after the suspend point.
+ SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
+ for (auto &E : Spills) {
+ for (Instruction *U : E.second) {
+ // Don't process a user twice (this can happen if the instruction uses
+ // more than one rematerializable def)
+ if (AllRemats.count(U))
+ continue;
+
+ // Constructor creates the whole RematGraph for the given Use
+ auto RematUPtr =
+ std::make_unique<RematGraph>(MaterializableCallback, U, Checker);
+
+ LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
+ ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
+ for (auto I = RPOT.begin(); I != RPOT.end();
+ ++I) { (*I)->Node->dump(); } dbgs()
+ << "\n";);
+
+ AllRemats[U] = std::move(RematUPtr);
+ }
+ }
+
+ // Rewrite materializable instructions to be materialized at the use
+ // point.
+ LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
+ rewriteMaterializableInstructions(AllRemats);
+}
+
+void coro::buildCoroutineFrame(
+ Function &F, Shape &Shape,
+ const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
eliminateSwiftError(F, Shape);
@@ -2775,35 +3027,11 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
// Build suspend crossing info.
SuspendCrossingInfo Checker(F, Shape);
- IRBuilder<> Builder(F.getContext());
+ doRematerializations(F, Checker, MaterializableCallback);
+
FrameDataInfo FrameData;
SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas;
SmallVector<Instruction*, 4> DeadInstructions;
-
- {
- SpillInfo Spills;
- for (int Repeat = 0; Repeat < 4; ++Repeat) {
- // See if there are materializable instructions across suspend points.
- // FIXME: We can use a worklist to track the possible materialize
- // instructions instead of iterating the whole function again and again.
- for (Instruction &I : instructions(F))
- if (materializable(I)) {
- for (User *U : I.users())
- if (Checker.isDefinitionAcrossSuspend(I, U))
- Spills[&I].push_back(cast<Instruction>(U));
- }
-
- if (Spills.empty())
- break;
-
- // Rewrite materializable instructions to be materialized at the use
- // point.
- LLVM_DEBUG(dumpSpills("Materializations", Spills));
- rewriteMaterializableInstructions(Builder, Spills);
- Spills.clear();
- }
- }
-
if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon &&
Shape.ABI != coro::ABI::RetconOnce)
sinkLifetimeStartMarkers(F, Shape, Checker);
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index 032361c22045..067fb6bba47e 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -25,10 +25,13 @@ bool declaresIntrinsics(const Module &M,
const std::initializer_list<StringRef>);
void replaceCoroFree(CoroIdInst *CoroId, bool Elide);
-/// Recover a dbg.declare prepared by the frontend and emit an alloca
-/// holding a pointer to the coroutine frame.
+/// Attempts to rewrite the location operand of debug intrinsics in terms of
+/// the coroutine frame pointer, folding pointer offsets into the DIExpression
+/// of the intrinsic.
+/// If the frame pointer is an Argument, store it into an alloca if
+/// OptimizeFrame is false.
void salvageDebugInfo(
- SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> &DbgPtrAllocaCache,
+ SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap,
DbgVariableIntrinsic *DVI, bool OptimizeFrame);
// Keeps data and helper functions for lowering coroutine intrinsics.
@@ -124,7 +127,6 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
};
struct AsyncLoweringStorage {
- FunctionType *AsyncFuncTy;
Value *Context;
CallingConv::ID AsyncCC;
unsigned ContextArgNo;
@@ -261,7 +263,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
void buildFrom(Function &F);
};
-void buildCoroutineFrame(Function &F, Shape &Shape);
+bool defaultMaterializable(Instruction &V);
+void buildCoroutineFrame(
+ Function &F, Shape &Shape,
+ const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
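
The new buildCoroutineFrame signature makes the rematerialization policy a parameter rather than a hard-coded predicate. A minimal sketch of how a caller could combine coro::defaultMaterializable with an extra case (the FreezeInst clause is purely illustrative, and F and Shape are assumed to come from the surrounding coro-split context):

  auto Materializable = [](Instruction &I) {
    // Keep the stock policy and, purely for illustration, also allow freeze
    // instructions to be recreated after the suspend point.
    return coro::defaultMaterializable(I) || isa<FreezeInst>(&I);
  };
  coro::buildCoroutineFrame(F, Shape, Materializable);
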
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 1171878f749a..39e909bf3316 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -31,6 +31,7 @@
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/IR/Argument.h"
@@ -299,6 +300,26 @@ static void markCoroutineAsDone(IRBuilder<> &Builder, const coro::Shape &Shape,
auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
Shape.FrameTy->getTypeAtIndex(coro::Shape::SwitchFieldIndex::Resume)));
Builder.CreateStore(NullPtr, GepIndex);
+
+  // If the coroutine doesn't have an unwind coro end, we can omit the store
+  // to the final suspend point, since the fact that the coroutine is
+  // suspended at the final suspend point can be inferred from the nullness
+  // of ResumeFnAddr.
+  // However, we can't skip it if the coroutine has an unwind coro end: a
+  // coroutine that reaches the unwind coro end is considered suspended at
+  // the final suspend point (its ResumeFnAddr is null) even though it hasn't
+  // actually completed yet. We need the IndexVal for the final suspend point
+  // to keep the states distinguishable.
+ if (Shape.SwitchLowering.HasUnwindCoroEnd &&
+ Shape.SwitchLowering.HasFinalSuspend) {
+ assert(cast<CoroSuspendInst>(Shape.CoroSuspends.back())->isFinal() &&
+ "The final suspend should only live in the last position of "
+ "CoroSuspends.");
+ ConstantInt *IndexVal = Shape.getIndex(Shape.CoroSuspends.size() - 1);
+ auto *FinalIndex = Builder.CreateStructGEP(
+ Shape.FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
+
+ Builder.CreateStore(IndexVal, FinalIndex);
+ }
}
/// Replace an unwind call to llvm.coro.end.
@@ -396,17 +417,7 @@ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
// The coroutine should be marked done if it reaches the final suspend
// point.
markCoroutineAsDone(Builder, Shape, FramePtr);
- }
-
- // If the coroutine don't have unwind coro end, we could omit the store to
- // the final suspend point since we could infer the coroutine is suspended
- // at the final suspend point by the nullness of ResumeFnAddr.
- // However, we can't skip it if the coroutine have unwind coro end. Since
- // the coroutine reaches unwind coro end is considered suspended at the
- // final suspend point (the ResumeFnAddr is null) but in fact the coroutine
- // didn't complete yet. We need the IndexVal for the final suspend point
- // to make the states clear.
- if (!S->isFinal() || Shape.SwitchLowering.HasUnwindCoroEnd) {
+ } else {
auto *GepIndex = Builder.CreateStructGEP(
FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
Builder.CreateStore(IndexVal, GepIndex);
@@ -565,7 +576,7 @@ void CoroCloner::replaceRetconOrAsyncSuspendUses() {
if (NewS->use_empty()) return;
// Otherwise, we need to create an aggregate.
- Value *Agg = UndefValue::get(NewS->getType());
+ Value *Agg = PoisonValue::get(NewS->getType());
for (size_t I = 0, E = Args.size(); I != E; ++I)
Agg = Builder.CreateInsertValue(Agg, Args[I], I);
@@ -623,20 +634,13 @@ static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape,
return;
Value *CachedSlot = nullptr;
auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * {
- if (CachedSlot) {
- assert(cast<PointerType>(CachedSlot->getType())
- ->isOpaqueOrPointeeTypeMatches(ValueTy) &&
- "multiple swifterror slots in function with different types");
+ if (CachedSlot)
return CachedSlot;
- }
// Check if the function has a swifterror argument.
for (auto &Arg : F.args()) {
if (Arg.isSwiftError()) {
CachedSlot = &Arg;
- assert(cast<PointerType>(Arg.getType())
- ->isOpaqueOrPointeeTypeMatches(ValueTy) &&
- "swifterror argument does not have expected type");
return &Arg;
}
}
@@ -679,19 +683,26 @@ static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape,
}
}
+/// Returns all DbgVariableIntrinsic in F.
+static SmallVector<DbgVariableIntrinsic *, 8>
+collectDbgVariableIntrinsics(Function &F) {
+ SmallVector<DbgVariableIntrinsic *, 8> Intrinsics;
+ for (auto &I : instructions(F))
+ if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I))
+ Intrinsics.push_back(DVI);
+ return Intrinsics;
+}
+
void CoroCloner::replaceSwiftErrorOps() {
::replaceSwiftErrorOps(*NewF, Shape, &VMap);
}
void CoroCloner::salvageDebugInfo() {
- SmallVector<DbgVariableIntrinsic *, 8> Worklist;
- SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache;
- for (auto &BB : *NewF)
- for (auto &I : BB)
- if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I))
- Worklist.push_back(DVI);
+ SmallVector<DbgVariableIntrinsic *, 8> Worklist =
+ collectDbgVariableIntrinsics(*NewF);
+ SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap;
for (DbgVariableIntrinsic *DVI : Worklist)
- coro::salvageDebugInfo(DbgPtrAllocaCache, DVI, Shape.OptimizeFrame);
+ coro::salvageDebugInfo(ArgToAllocaMap, DVI, Shape.OptimizeFrame);
// Remove all salvaged dbg.declare intrinsics that became
// either unreachable or stale due to the CoroSplit transformation.
@@ -886,7 +897,7 @@ void CoroCloner::create() {
// frame.
SmallVector<Instruction *> DummyArgs;
for (Argument &A : OrigF.args()) {
- DummyArgs.push_back(new FreezeInst(UndefValue::get(A.getType())));
+ DummyArgs.push_back(new FreezeInst(PoisonValue::get(A.getType())));
VMap[&A] = DummyArgs.back();
}
@@ -1044,7 +1055,7 @@ void CoroCloner::create() {
// All uses of the arguments should have been resolved by this point,
// so we can safely remove the dummy values.
for (Instruction *DummyArg : DummyArgs) {
- DummyArg->replaceAllUsesWith(UndefValue::get(DummyArg->getType()));
+ DummyArg->replaceAllUsesWith(PoisonValue::get(DummyArg->getType()));
DummyArg->deleteValue();
}
@@ -1231,8 +1242,11 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
// instruction. Suspend instruction represented by a switch, track the PHI
// values and select the correct case successor when possible.
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
+ // There is nothing to simplify.
+ if (isa<ReturnInst>(InitialInst))
+ return false;
+
DenseMap<Value *, Value *> ResolvedValues;
- BasicBlock *UnconditionalSucc = nullptr;
assert(InitialInst->getModule());
const DataLayout &DL = InitialInst->getModule()->getDataLayout();
@@ -1262,39 +1276,35 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
Instruction *I = InitialInst;
while (I->isTerminator() || isa<CmpInst>(I)) {
if (isa<ReturnInst>(I)) {
- if (I != InitialInst) {
- // If InitialInst is an unconditional branch,
- // remove PHI values that come from basic block of InitialInst
- if (UnconditionalSucc)
- UnconditionalSucc->removePredecessor(InitialInst->getParent(), true);
- ReplaceInstWithInst(InitialInst, I->clone());
- }
+ ReplaceInstWithInst(InitialInst, I->clone());
return true;
}
+
if (auto *BR = dyn_cast<BranchInst>(I)) {
- if (BR->isUnconditional()) {
- BasicBlock *Succ = BR->getSuccessor(0);
- if (I == InitialInst)
- UnconditionalSucc = Succ;
- scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
- I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime());
- continue;
+ unsigned SuccIndex = 0;
+ if (BR->isConditional()) {
+        // Handle the case where the condition of the conditional branch is
+        // constant, e.g.,
+ //
+ // br i1 false, label %cleanup, label %CoroEnd
+ //
+        // This can happen during the transformation, and we can continue
+        // simplifying in that case.
+ ConstantInt *Cond = TryResolveConstant(BR->getCondition());
+ if (!Cond)
+ return false;
+
+ SuccIndex = Cond->isOne() ? 0 : 1;
}
- BasicBlock *BB = BR->getParent();
- // Handle the case the condition of the conditional branch is constant.
- // e.g.,
- //
- // br i1 false, label %cleanup, label %CoroEnd
- //
- // It is possible during the transformation. We could continue the
- // simplifying in this case.
- if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) {
- // Handle this branch in next iteration.
- I = BB->getTerminator();
- continue;
- }
- } else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
+ BasicBlock *Succ = BR->getSuccessor(SuccIndex);
+ scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
+ I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime());
+
+ continue;
+ }
+
+ if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
// If the case number of suspended switch instruction is reduced to
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
auto *BR = dyn_cast<BranchInst>(
@@ -1318,13 +1328,14 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
if (!ConstResult)
return false;
- CondCmp->replaceAllUsesWith(ConstResult);
- CondCmp->eraseFromParent();
+ ResolvedValues[BR->getCondition()] = ConstResult;
// Handle this branch in next iteration.
I = BR;
continue;
- } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
+ }
+
+ if (auto *SI = dyn_cast<SwitchInst>(I)) {
ConstantInt *Cond = TryResolveConstant(SI->getCondition());
if (!Cond)
return false;
@@ -1337,6 +1348,7 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
return false;
}
+
return false;
}
@@ -1889,7 +1901,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
if (ReturnPHIs.size() == 1) {
RetV = CastedContinuation;
} else {
- RetV = UndefValue::get(RetTy);
+ RetV = PoisonValue::get(RetTy);
RetV = Builder.CreateInsertValue(RetV, CastedContinuation, 0);
for (size_t I = 1, E = ReturnPHIs.size(); I != E; ++I)
RetV = Builder.CreateInsertValue(RetV, ReturnPHIs[I], I);
@@ -1929,10 +1941,10 @@ namespace {
};
}
-static coro::Shape splitCoroutine(Function &F,
- SmallVectorImpl<Function *> &Clones,
- TargetTransformInfo &TTI,
- bool OptimizeFrame) {
+static coro::Shape
+splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
+ TargetTransformInfo &TTI, bool OptimizeFrame,
+ std::function<bool(Instruction &)> MaterializableCallback) {
PrettyStackTraceFunction prettyStackTrace(F);
// The suspend-crossing algorithm in buildCoroutineFrame get tripped
@@ -1944,7 +1956,7 @@ static coro::Shape splitCoroutine(Function &F,
return Shape;
simplifySuspendPoints(Shape);
- buildCoroutineFrame(F, Shape);
+ buildCoroutineFrame(F, Shape, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);
// If there are no suspend points, no split required, just remove
@@ -1970,25 +1982,12 @@ static coro::Shape splitCoroutine(Function &F,
// This invalidates SwiftErrorOps in the Shape.
replaceSwiftErrorOps(F, Shape, nullptr);
- // Finally, salvage the llvm.dbg.{declare,addr} in our original function that
- // point into the coroutine frame. We only do this for the current function
- // since the Cloner salvaged debug info for us in the new coroutine funclets.
- SmallVector<DbgVariableIntrinsic *, 8> Worklist;
- SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache;
- for (auto &BB : F) {
- for (auto &I : BB) {
- if (auto *DDI = dyn_cast<DbgDeclareInst>(&I)) {
- Worklist.push_back(DDI);
- continue;
- }
- if (auto *DDI = dyn_cast<DbgAddrIntrinsic>(&I)) {
- Worklist.push_back(DDI);
- continue;
- }
- }
- }
- for (auto *DDI : Worklist)
- coro::salvageDebugInfo(DbgPtrAllocaCache, DDI, Shape.OptimizeFrame);
+ // Salvage debug intrinsics that point into the coroutine frame in the
+ // original function. The Cloner has already salvaged debug info in the new
+ // coroutine funclets.
+ SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap;
+ for (auto *DDI : collectDbgVariableIntrinsics(F))
+ coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame);
return Shape;
}
@@ -2104,6 +2103,10 @@ static void addPrepareFunction(const Module &M,
Fns.push_back(PrepareFn);
}
+CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
+ : MaterializableCallback(coro::defaultMaterializable),
+ OptimizeFrame(OptimizeFrame) {}
+
PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
CGSCCAnalysisManager &AM,
LazyCallGraph &CG, CGSCCUpdateResult &UR) {
@@ -2142,10 +2145,19 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
F.setSplittedCoroutine();
SmallVector<Function *, 4> Clones;
- const coro::Shape Shape = splitCoroutine(
- F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame);
+ auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+ const coro::Shape Shape =
+ splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
+ OptimizeFrame, MaterializableCallback);
updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);
+ ORE.emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "CoroSplit", &F)
+ << "Split '" << ore::NV("function", F.getName())
+ << "' (frame_size=" << ore::NV("frame_size", Shape.FrameSize)
+ << ", align=" << ore::NV("align", Shape.FrameAlign.value()) << ")";
+ });
+
if (!Shape.CoroSuspends.empty()) {
// Run the CGSCC pipeline on the original and newly split functions.
UR.CWorklist.insert(&C);
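
The run() method above now also emits an optimization remark for every split coroutine. A minimal sketch of how the pass is typically scheduled under the new pass manager so that the remark fires (the ModulePassManager MPM and its analysis registration are assumed, not part of this change):

  CGSCCPassManager CGPM;
  CGPM.addPass(CoroSplitPass(/*OptimizeFrame=*/true));
  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));

With opt, a flag such as -pass-remarks=coro-split should then print the "Split '<function>' (frame_size=..., align=...)" message constructed above.
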
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index ce4262e593b6..cde74c5e693b 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -596,20 +596,6 @@ static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
if (!AsyncFuncPtrAddr)
fail(I, "llvm.coro.id.async async function pointer not a global", V);
-
- if (AsyncFuncPtrAddr->getType()->isOpaquePointerTy())
- return;
-
- auto *StructTy = cast<StructType>(
- AsyncFuncPtrAddr->getType()->getNonOpaquePointerElementType());
- if (StructTy->isOpaque() || !StructTy->isPacked() ||
- StructTy->getNumElements() != 2 ||
- !StructTy->getElementType(0)->isIntegerTy(32) ||
- !StructTy->getElementType(1)->isIntegerTy(32))
- fail(I,
- "llvm.coro.id.async async function pointer argument's type is not "
- "<{i32, i32}>",
- V);
}
void CoroIdAsyncInst::checkWellFormed() const {
@@ -625,19 +611,15 @@ void CoroIdAsyncInst::checkWellFormed() const {
static void checkAsyncContextProjectFunction(const Instruction *I,
Function *F) {
auto *FunTy = cast<FunctionType>(F->getValueType());
- Type *Int8Ty = Type::getInt8Ty(F->getContext());
- auto *RetPtrTy = dyn_cast<PointerType>(FunTy->getReturnType());
- if (!RetPtrTy || !RetPtrTy->isOpaqueOrPointeeTypeMatches(Int8Ty))
+ if (!FunTy->getReturnType()->isPointerTy())
fail(I,
"llvm.coro.suspend.async resume function projection function must "
- "return an i8* type",
+ "return a ptr type",
F);
- if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy() ||
- !cast<PointerType>(FunTy->getParamType(0))
- ->isOpaqueOrPointeeTypeMatches(Int8Ty))
+ if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
fail(I,
"llvm.coro.suspend.async resume function projection function must "
- "take one i8* type as parameter",
+ "take one ptr type as parameter",
F);
}
diff --git a/llvm/lib/Transforms/IPO/AlwaysInliner.cpp b/llvm/lib/Transforms/IPO/AlwaysInliner.cpp
index 09286482edff..cc375f9badcd 100644
--- a/llvm/lib/Transforms/IPO/AlwaysInliner.cpp
+++ b/llvm/lib/Transforms/IPO/AlwaysInliner.cpp
@@ -28,16 +28,13 @@ using namespace llvm;
#define DEBUG_TYPE "inline"
-PreservedAnalyses AlwaysInlinerPass::run(Module &M,
- ModuleAnalysisManager &MAM) {
- // Add inline assumptions during code generation.
- FunctionAnalysisManager &FAM =
- MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
- auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
- return FAM.getResult<AssumptionAnalysis>(F);
- };
- auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);
+namespace {
+bool AlwaysInlineImpl(
+ Module &M, bool InsertLifetime, ProfileSummaryInfo &PSI,
+ function_ref<AssumptionCache &(Function &)> GetAssumptionCache,
+ function_ref<AAResults &(Function &)> GetAAR,
+ function_ref<BlockFrequencyInfo &(Function &)> GetBFI) {
SmallSetVector<CallBase *, 16> Calls;
bool Changed = false;
SmallVector<Function *, 16> InlinedFunctions;
@@ -65,14 +62,12 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M,
DebugLoc DLoc = CB->getDebugLoc();
BasicBlock *Block = CB->getParent();
- InlineFunctionInfo IFI(
- /*cg=*/nullptr, GetAssumptionCache, &PSI,
- &FAM.getResult<BlockFrequencyAnalysis>(*Caller),
- &FAM.getResult<BlockFrequencyAnalysis>(F));
+ InlineFunctionInfo IFI(GetAssumptionCache, &PSI,
+ GetBFI ? &GetBFI(*Caller) : nullptr,
+ GetBFI ? &GetBFI(F) : nullptr);
- InlineResult Res =
- InlineFunction(*CB, IFI, /*MergeAttributes=*/true,
- &FAM.getResult<AAManager>(F), InsertLifetime);
+ InlineResult Res = InlineFunction(*CB, IFI, /*MergeAttributes=*/true,
+ &GetAAR(F), InsertLifetime);
if (!Res.isSuccess()) {
ORE.emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc,
@@ -127,48 +122,52 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M,
}
}
- return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+ return Changed;
}
-namespace {
-
-/// Inliner pass which only handles "always inline" functions.
-///
-/// Unlike the \c AlwaysInlinerPass, this uses the more heavyweight \c Inliner
-/// base class to provide several facilities such as array alloca merging.
-class AlwaysInlinerLegacyPass : public LegacyInlinerBase {
+struct AlwaysInlinerLegacyPass : public ModulePass {
+ bool InsertLifetime;
-public:
- AlwaysInlinerLegacyPass() : LegacyInlinerBase(ID, /*InsertLifetime*/ true) {
- initializeAlwaysInlinerLegacyPassPass(*PassRegistry::getPassRegistry());
- }
+ AlwaysInlinerLegacyPass()
+ : AlwaysInlinerLegacyPass(/*InsertLifetime*/ true) {}
AlwaysInlinerLegacyPass(bool InsertLifetime)
- : LegacyInlinerBase(ID, InsertLifetime) {
+ : ModulePass(ID), InsertLifetime(InsertLifetime) {
initializeAlwaysInlinerLegacyPassPass(*PassRegistry::getPassRegistry());
}
/// Main run interface method. We override here to avoid calling skipSCC().
- bool runOnSCC(CallGraphSCC &SCC) override { return inlineCalls(SCC); }
+ bool runOnModule(Module &M) override {
+
+ auto &PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
+ auto GetAAR = [&](Function &F) -> AAResults & {
+ return getAnalysis<AAResultsWrapperPass>(F).getAAResults();
+ };
+ auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
+ return getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ };
+
+ return AlwaysInlineImpl(M, InsertLifetime, PSI, GetAssumptionCache, GetAAR,
+ /*GetBFI*/ nullptr);
+ }
static char ID; // Pass identification, replacement for typeid
- InlineCost getInlineCost(CallBase &CB) override;
-
- using llvm::Pass::doFinalization;
- bool doFinalization(CallGraph &CG) override {
- return removeDeadFunctions(CG, /*AlwaysInlineOnly=*/true);
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<ProfileSummaryInfoWrapperPass>();
}
};
-}
+
+} // namespace
char AlwaysInlinerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(AlwaysInlinerLegacyPass, "always-inline",
"Inliner for always_inline functions", false, false)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(AlwaysInlinerLegacyPass, "always-inline",
"Inliner for always_inline functions", false, false)
@@ -176,46 +175,23 @@ Pass *llvm::createAlwaysInlinerLegacyPass(bool InsertLifetime) {
return new AlwaysInlinerLegacyPass(InsertLifetime);
}
-/// Get the inline cost for the always-inliner.
-///
-/// The always inliner *only* handles functions which are marked with the
-/// attribute to force inlining. As such, it is dramatically simpler and avoids
-/// using the powerful (but expensive) inline cost analysis. Instead it uses
-/// a very simple and boring direct walk of the instructions looking for
-/// impossible-to-inline constructs.
-///
-/// Note, it would be possible to go to some lengths to cache the information
-/// computed here, but as we only expect to do this for relatively few and
-/// small functions which have the explicit attribute to force inlining, it is
-/// likely not worth it in practice.
-InlineCost AlwaysInlinerLegacyPass::getInlineCost(CallBase &CB) {
- Function *Callee = CB.getCalledFunction();
-
- // Only inline direct calls to functions with always-inline attributes
- // that are viable for inlining.
- if (!Callee)
- return InlineCost::getNever("indirect call");
-
- // When callee coroutine function is inlined into caller coroutine function
- // before coro-split pass,
- // coro-early pass can not handle this quiet well.
- // So we won't inline the coroutine function if it have not been unsplited
- if (Callee->isPresplitCoroutine())
- return InlineCost::getNever("unsplited coroutine call");
-
- // FIXME: We shouldn't even get here for declarations.
- if (Callee->isDeclaration())
- return InlineCost::getNever("no definition");
-
- if (!CB.hasFnAttr(Attribute::AlwaysInline))
- return InlineCost::getNever("no alwaysinline attribute");
-
- if (Callee->hasFnAttribute(Attribute::AlwaysInline) && CB.isNoInline())
- return InlineCost::getNever("noinline call site attribute");
-
- auto IsViable = isInlineViable(*Callee);
- if (!IsViable.isSuccess())
- return InlineCost::getNever(IsViable.getFailureReason());
-
- return InlineCost::getAlways("always inliner");
+PreservedAnalyses AlwaysInlinerPass::run(Module &M,
+ ModuleAnalysisManager &MAM) {
+ FunctionAnalysisManager &FAM =
+ MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
+ return FAM.getResult<AssumptionAnalysis>(F);
+ };
+ auto GetBFI = [&](Function &F) -> BlockFrequencyInfo & {
+ return FAM.getResult<BlockFrequencyAnalysis>(F);
+ };
+ auto GetAAR = [&](Function &F) -> AAResults & {
+ return FAM.getResult<AAManager>(F);
+ };
+ auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);
+
+ bool Changed = AlwaysInlineImpl(M, InsertLifetime, PSI, GetAssumptionCache,
+ GetAAR, GetBFI);
+
+ return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
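
With the legacy wrapper reduced to a plain ModulePass and the shared AlwaysInlineImpl above, the new-pass-manager entry point is driven in the usual way. A self-contained sketch of running it over a module, using standard PassBuilder boilerplate that is assumed rather than taken from this patch:

  #include "llvm/IR/Module.h"
  #include "llvm/Passes/PassBuilder.h"
  #include "llvm/Transforms/IPO/AlwaysInliner.h"
  using namespace llvm;

  static void runAlwaysInliner(Module &M) {
    LoopAnalysisManager LAM;
    FunctionAnalysisManager FAM;
    CGSCCAnalysisManager CGAM;
    ModuleAnalysisManager MAM;
    PassBuilder PB;
    PB.registerModuleAnalyses(MAM);
    PB.registerCGSCCAnalyses(CGAM);
    PB.registerFunctionAnalyses(FAM);
    PB.registerLoopAnalyses(LAM);
    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

    ModulePassManager MPM;
    MPM.addPass(AlwaysInlinerPass(/*InsertLifetime=*/true));
    // ProfileSummaryAnalysis and friends are resolved through MAM at run time.
    MPM.run(M, MAM);
  }
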
diff --git a/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp b/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp
index 6cc04544cabc..40cc00d2c78c 100644
--- a/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp
+++ b/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp
@@ -17,8 +17,6 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Transforms/IPO.h"
using namespace llvm;
@@ -64,36 +62,8 @@ static bool convertAnnotation2Metadata(Module &M) {
return true;
}
-namespace {
-struct Annotation2MetadataLegacy : public ModulePass {
- static char ID;
-
- Annotation2MetadataLegacy() : ModulePass(ID) {
- initializeAnnotation2MetadataLegacyPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override { return convertAnnotation2Metadata(M); }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
-};
-
-} // end anonymous namespace
-
-char Annotation2MetadataLegacy::ID = 0;
-
-INITIALIZE_PASS_BEGIN(Annotation2MetadataLegacy, DEBUG_TYPE,
- "Annotation2Metadata", false, false)
-INITIALIZE_PASS_END(Annotation2MetadataLegacy, DEBUG_TYPE,
- "Annotation2Metadata", false, false)
-
-ModulePass *llvm::createAnnotation2MetadataLegacyPass() {
- return new Annotation2MetadataLegacy();
-}
-
PreservedAnalyses Annotation2MetadataPass::run(Module &M,
ModuleAnalysisManager &AM) {
- convertAnnotation2Metadata(M);
- return PreservedAnalyses::all();
+ return convertAnnotation2Metadata(M) ? PreservedAnalyses::none()
+ : PreservedAnalyses::all();
}
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index dd1a3b78a378..824da6395f2e 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -67,6 +67,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include <algorithm>
#include <cassert>
@@ -97,49 +98,11 @@ using OffsetAndArgPart = std::pair<int64_t, ArgPart>;
static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
Value *Ptr, Type *ResElemTy, int64_t Offset) {
- // For non-opaque pointers, try to create a "nice" GEP if possible, otherwise
- // fall back to an i8 GEP to a specific offset.
- unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
- APInt OrigOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
- if (!Ptr->getType()->isOpaquePointerTy()) {
- Type *OrigElemTy = Ptr->getType()->getNonOpaquePointerElementType();
- if (OrigOffset == 0 && OrigElemTy == ResElemTy)
- return Ptr;
-
- if (OrigElemTy->isSized()) {
- APInt TmpOffset = OrigOffset;
- Type *TmpTy = OrigElemTy;
- SmallVector<APInt> IntIndices =
- DL.getGEPIndicesForOffset(TmpTy, TmpOffset);
- if (TmpOffset == 0) {
- // Try to add trailing zero indices to reach the right type.
- while (TmpTy != ResElemTy) {
- Type *NextTy = GetElementPtrInst::getTypeAtIndex(TmpTy, (uint64_t)0);
- if (!NextTy)
- break;
-
- IntIndices.push_back(APInt::getZero(
- isa<StructType>(TmpTy) ? 32 : OrigOffset.getBitWidth()));
- TmpTy = NextTy;
- }
-
- SmallVector<Value *> Indices;
- for (const APInt &Index : IntIndices)
- Indices.push_back(IRB.getInt(Index));
-
- if (OrigOffset != 0 || TmpTy == ResElemTy) {
- Ptr = IRB.CreateGEP(OrigElemTy, Ptr, Indices);
- return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
- }
- }
- }
+ if (Offset != 0) {
+ APInt APOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
+ Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(APOffset));
}
-
- if (OrigOffset != 0) {
- Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy(AddrSpace));
- Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(OrigOffset));
- }
- return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
+ return Ptr;
}
/// DoPromotion - This method actually performs the promotion of the specified
@@ -220,6 +183,8 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
// pass in the loaded pointers.
SmallVector<Value *, 16> Args;
const DataLayout &DL = F->getParent()->getDataLayout();
+ SmallVector<WeakTrackingVH, 16> DeadArgs;
+
while (!F->use_empty()) {
CallBase &CB = cast<CallBase>(*F->user_back());
assert(CB.getCalledFunction() == F);
@@ -246,15 +211,25 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
if (Pair.second.MustExecInstr) {
LI->setAAMetadata(Pair.second.MustExecInstr->getAAMetadata());
LI->copyMetadata(*Pair.second.MustExecInstr,
- {LLVMContext::MD_range, LLVMContext::MD_nonnull,
- LLVMContext::MD_dereferenceable,
+ {LLVMContext::MD_dereferenceable,
LLVMContext::MD_dereferenceable_or_null,
- LLVMContext::MD_align, LLVMContext::MD_noundef,
+ LLVMContext::MD_noundef,
LLVMContext::MD_nontemporal});
+ // Only transfer poison-generating metadata if we also have
+ // !noundef.
+ // TODO: Without !noundef, we could merge this metadata across
+ // all promoted loads.
+ if (LI->hasMetadata(LLVMContext::MD_noundef))
+ LI->copyMetadata(*Pair.second.MustExecInstr,
+ {LLVMContext::MD_range, LLVMContext::MD_nonnull,
+ LLVMContext::MD_align});
}
Args.push_back(LI);
ArgAttrVec.push_back(AttributeSet());
}
+ } else {
+ assert(ArgsToPromote.count(&*I) && I->use_empty());
+ DeadArgs.emplace_back(AI->get());
}
}
@@ -297,6 +272,8 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
CB.eraseFromParent();
}
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadArgs);
+
// Since we have now created the new function, splice the body of the old
// function right into the new function, leaving the old rotting hulk of the
// function empty.
@@ -766,6 +743,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
// Check to see which arguments are promotable. If an argument is promotable,
// add it to ArgsToPromote.
DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
+ unsigned NumArgsAfterPromote = F->getFunctionType()->getNumParams();
for (Argument *PtrArg : PointerArgs) {
// Replace sret attribute with noalias. This reduces register pressure by
// avoiding a register copy.
@@ -789,6 +767,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
Types.push_back(Pair.second.Ty);
if (areTypesABICompatible(Types, *F, TTI)) {
+ NumArgsAfterPromote += ArgParts.size() - 1;
ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
}
}
@@ -798,6 +777,9 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
if (ArgsToPromote.empty())
return nullptr;
+ if (NumArgsAfterPromote > TTI.getMaxNumArgs())
+ return nullptr;
+
return doPromotion(F, FAM, ArgsToPromote);
}
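
With opaque pointers, the simplified createByteGEP above emits at most one byte-addressed GEP and no bitcasts. A small usage sketch, assuming IRB, DL and BasePtr from the promotion context and using i64 purely as an example part type:

  // Loading the part at byte offset 8 of a promoted argument now lowers to:
  //   %off = getelementptr i8, ptr %base, i64 8
  //   %val = load i64, ptr %off
  Value *PartPtr =
      createByteGEP(IRB, DL, BasePtr, IRB.getInt64Ty(), /*Offset=*/8);
  LoadInst *PartVal = IRB.CreateLoad(IRB.getInt64Ty(), PartPtr, "part");
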
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index b9134ce26e80..847d07a49dee 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -15,16 +15,17 @@
#include "llvm/Transforms/IPO/Attributor.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/MustExecute.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantFold.h"
@@ -35,14 +36,15 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/ValueHandle.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/ModRef.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
@@ -98,11 +100,6 @@ static cl::opt<unsigned, true> MaxInitializationChainLengthX(
cl::location(MaxInitializationChainLength), cl::init(1024));
unsigned llvm::MaxInitializationChainLength;
-static cl::opt<bool> VerifyMaxFixpointIterations(
- "attributor-max-iterations-verify", cl::Hidden,
- cl::desc("Verify that max-iterations is a tight bound for a fixpoint"),
- cl::init(false));
-
static cl::opt<bool> AnnotateDeclarationCallSites(
"attributor-annotate-decl-cs", cl::Hidden,
cl::desc("Annotate call sites of function declarations."), cl::init(false));
@@ -188,6 +185,11 @@ ChangeStatus &llvm::operator&=(ChangeStatus &L, ChangeStatus R) {
}
///}
+bool AA::isGPU(const Module &M) {
+ Triple T(M.getTargetTriple());
+ return T.isAMDGPU() || T.isNVPTX();
+}
+
bool AA::isNoSyncInst(Attributor &A, const Instruction &I,
const AbstractAttribute &QueryingAA) {
// We are looking for volatile instructions or non-relaxed atomics.
@@ -202,9 +204,10 @@ bool AA::isNoSyncInst(Attributor &A, const Instruction &I,
if (AANoSync::isNoSyncIntrinsic(&I))
return true;
- const auto &NoSyncAA = A.getAAFor<AANoSync>(
- QueryingAA, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
- return NoSyncAA.isAssumedNoSync();
+ bool IsKnownNoSync;
+ return AA::hasAssumedIRAttr<Attribute::NoSync>(
+ A, &QueryingAA, IRPosition::callsite_function(*CB),
+ DepClassTy::OPTIONAL, IsKnownNoSync);
}
if (!I.mayReadOrWriteMemory())
@@ -218,12 +221,12 @@ bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA,
// TODO: See the AAInstanceInfo class comment.
if (!ForAnalysisOnly)
return false;
- auto &InstanceInfoAA = A.getAAFor<AAInstanceInfo>(
+ auto *InstanceInfoAA = A.getAAFor<AAInstanceInfo>(
QueryingAA, IRPosition::value(V), DepClassTy::OPTIONAL);
- return InstanceInfoAA.isAssumedUniqueForAnalysis();
+ return InstanceInfoAA && InstanceInfoAA->isAssumedUniqueForAnalysis();
}
-Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty,
+Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty,
const TargetLibraryInfo *TLI,
const DataLayout &DL,
AA::RangeTy *RangePtr) {
@@ -234,17 +237,31 @@ Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty,
auto *GV = dyn_cast<GlobalVariable>(&Obj);
if (!GV)
return nullptr;
- if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer()))
- return nullptr;
- if (!GV->hasInitializer())
- return UndefValue::get(&Ty);
+
+ bool UsedAssumedInformation = false;
+ Constant *Initializer = nullptr;
+ if (A.hasGlobalVariableSimplificationCallback(*GV)) {
+ auto AssumedGV = A.getAssumedInitializerFromCallBack(
+ *GV, /* const AbstractAttribute *AA */ nullptr, UsedAssumedInformation);
+ Initializer = *AssumedGV;
+ if (!Initializer)
+ return nullptr;
+ } else {
+ if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer()))
+ return nullptr;
+ if (!GV->hasInitializer())
+ return UndefValue::get(&Ty);
+
+ if (!Initializer)
+ Initializer = GV->getInitializer();
+ }
if (RangePtr && !RangePtr->offsetOrSizeAreUnknown()) {
APInt Offset = APInt(64, RangePtr->Offset);
- return ConstantFoldLoadFromConst(GV->getInitializer(), &Ty, Offset, DL);
+ return ConstantFoldLoadFromConst(Initializer, &Ty, Offset, DL);
}
- return ConstantFoldLoadFromUniformValue(GV->getInitializer(), &Ty);
+ return ConstantFoldLoadFromUniformValue(Initializer, &Ty);
}
bool AA::isValidInScope(const Value &V, const Function *Scope) {
@@ -396,6 +413,18 @@ static bool getPotentialCopiesOfMemoryValue(
NullOnly = false;
};
+ auto AdjustWrittenValueType = [&](const AAPointerInfo::Access &Acc,
+ Value &V) {
+ Value *AdjV = AA::getWithType(V, *I.getType());
+ if (!AdjV) {
+ LLVM_DEBUG(dbgs() << "Underlying object written but stored value "
+ "cannot be converted to read type: "
+ << *Acc.getRemoteInst() << " : " << *I.getType()
+ << "\n";);
+ }
+ return AdjV;
+ };
+
auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) {
if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead()))
return true;
@@ -417,7 +446,10 @@ static bool getPotentialCopiesOfMemoryValue(
if (IsLoad) {
assert(isa<LoadInst>(I) && "Expected load or store instruction only!");
if (!Acc.isWrittenValueUnknown()) {
- NewCopies.push_back(Acc.getWrittenValue());
+ Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue());
+ if (!V)
+ return false;
+ NewCopies.push_back(V);
NewCopyOrigins.push_back(Acc.getRemoteInst());
return true;
}
@@ -428,7 +460,10 @@ static bool getPotentialCopiesOfMemoryValue(
<< *Acc.getRemoteInst() << "\n";);
return false;
}
- NewCopies.push_back(SI->getValueOperand());
+ Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand());
+ if (!V)
+ return false;
+ NewCopies.push_back(V);
NewCopyOrigins.push_back(SI);
} else {
assert(isa<StoreInst>(I) && "Expected load or store instruction only!");
@@ -449,10 +484,13 @@ static bool getPotentialCopiesOfMemoryValue(
bool HasBeenWrittenTo = false;
AA::RangeTy Range;
- auto &PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj),
+ auto *PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj),
DepClassTy::NONE);
- if (!PI.forallInterferingAccesses(A, QueryingAA, I, CheckAccess,
- HasBeenWrittenTo, Range)) {
+ if (!PI ||
+ !PI->forallInterferingAccesses(A, QueryingAA, I,
+ /* FindInterferingWrites */ IsLoad,
+ /* FindInterferingReads */ !IsLoad,
+ CheckAccess, HasBeenWrittenTo, Range)) {
LLVM_DEBUG(
dbgs()
<< "Failed to verify all interfering accesses for underlying object: "
@@ -463,7 +501,7 @@ static bool getPotentialCopiesOfMemoryValue(
if (IsLoad && !HasBeenWrittenTo && !Range.isUnassigned()) {
const DataLayout &DL = A.getDataLayout();
Value *InitialValue =
- AA::getInitialValueForObj(Obj, *I.getType(), TLI, DL, &Range);
+ AA::getInitialValueForObj(A, Obj, *I.getType(), TLI, DL, &Range);
if (!InitialValue) {
LLVM_DEBUG(dbgs() << "Could not determine required initial value of "
"underlying object, abort!\n");
@@ -480,14 +518,14 @@ static bool getPotentialCopiesOfMemoryValue(
NewCopyOrigins.push_back(nullptr);
}
- PIs.push_back(&PI);
+ PIs.push_back(PI);
return true;
};
- const auto &AAUO = A.getAAFor<AAUnderlyingObjects>(
+ const auto *AAUO = A.getAAFor<AAUnderlyingObjects>(
QueryingAA, IRPosition::value(Ptr), DepClassTy::OPTIONAL);
- if (!AAUO.forallUnderlyingObjects(Pred)) {
+ if (!AAUO || !AAUO->forallUnderlyingObjects(Pred)) {
LLVM_DEBUG(
dbgs() << "Underlying objects stored into could not be determined\n";);
return false;
@@ -530,27 +568,37 @@ bool AA::getPotentialCopiesOfStoredValue(
static bool isAssumedReadOnlyOrReadNone(Attributor &A, const IRPosition &IRP,
const AbstractAttribute &QueryingAA,
bool RequireReadNone, bool &IsKnown) {
+ if (RequireReadNone) {
+ if (AA::hasAssumedIRAttr<Attribute::ReadNone>(
+ A, &QueryingAA, IRP, DepClassTy::OPTIONAL, IsKnown,
+ /* IgnoreSubsumingPositions */ true))
+ return true;
+ } else if (AA::hasAssumedIRAttr<Attribute::ReadOnly>(
+ A, &QueryingAA, IRP, DepClassTy::OPTIONAL, IsKnown,
+ /* IgnoreSubsumingPositions */ true))
+ return true;
IRPosition::Kind Kind = IRP.getPositionKind();
if (Kind == IRPosition::IRP_FUNCTION || Kind == IRPosition::IRP_CALL_SITE) {
- const auto &MemLocAA =
+ const auto *MemLocAA =
A.getAAFor<AAMemoryLocation>(QueryingAA, IRP, DepClassTy::NONE);
- if (MemLocAA.isAssumedReadNone()) {
- IsKnown = MemLocAA.isKnownReadNone();
+ if (MemLocAA && MemLocAA->isAssumedReadNone()) {
+ IsKnown = MemLocAA->isKnownReadNone();
if (!IsKnown)
- A.recordDependence(MemLocAA, QueryingAA, DepClassTy::OPTIONAL);
+ A.recordDependence(*MemLocAA, QueryingAA, DepClassTy::OPTIONAL);
return true;
}
}
- const auto &MemBehaviorAA =
+ const auto *MemBehaviorAA =
A.getAAFor<AAMemoryBehavior>(QueryingAA, IRP, DepClassTy::NONE);
- if (MemBehaviorAA.isAssumedReadNone() ||
- (!RequireReadNone && MemBehaviorAA.isAssumedReadOnly())) {
- IsKnown = RequireReadNone ? MemBehaviorAA.isKnownReadNone()
- : MemBehaviorAA.isKnownReadOnly();
+ if (MemBehaviorAA &&
+ (MemBehaviorAA->isAssumedReadNone() ||
+ (!RequireReadNone && MemBehaviorAA->isAssumedReadOnly()))) {
+ IsKnown = RequireReadNone ? MemBehaviorAA->isKnownReadNone()
+ : MemBehaviorAA->isKnownReadOnly();
if (!IsKnown)
- A.recordDependence(MemBehaviorAA, QueryingAA, DepClassTy::OPTIONAL);
+ A.recordDependence(*MemBehaviorAA, QueryingAA, DepClassTy::OPTIONAL);
return true;
}
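
This hunk illustrates the calling convention used throughout the Attributor changes in this commit: getAAFor now returns a possibly-null pointer rather than a reference, so callers must check the result before use. A minimal sketch of the updated idiom, with QueryingAA and IRP assumed from the surrounding query:

  const auto *MemBehaviorAA =
      A.getAAFor<AAMemoryBehavior>(QueryingAA, IRP, DepClassTy::NONE);
  // A null result means no abstract attribute is available for this position;
  // treat that conservatively instead of dereferencing.
  if (MemBehaviorAA && MemBehaviorAA->isAssumedReadOnly())
    A.recordDependence(*MemBehaviorAA, QueryingAA, DepClassTy::OPTIONAL);
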
@@ -574,7 +622,7 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
const AbstractAttribute &QueryingAA,
const AA::InstExclusionSetTy *ExclusionSet,
std::function<bool(const Function &F)> GoBackwardsCB) {
- LLVM_DEBUG({
+ DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, {
dbgs() << "[AA] isPotentiallyReachable @" << ToFn.getName() << " from "
<< FromI << " [GBCB: " << bool(GoBackwardsCB) << "][#ExS: "
<< (ExclusionSet ? std::to_string(ExclusionSet->size()) : "none")
@@ -584,6 +632,19 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
dbgs() << *ES << "\n";
});
+ // We know kernels (generally) cannot be called from within the module. Thus,
+ // for reachability we would need to step back from a kernel which would allow
+ // us to reach anything anyway. Even if a kernel is invoked from another
+ // kernel, values like allocas and shared memory are not accessible. We
+ // implicitly check for this situation to avoid costly lookups.
+ if (GoBackwardsCB && &ToFn != FromI.getFunction() &&
+ !GoBackwardsCB(*FromI.getFunction()) && ToFn.hasFnAttribute("kernel") &&
+ FromI.getFunction()->hasFnAttribute("kernel")) {
+ LLVM_DEBUG(dbgs() << "[AA] assume kernel cannot be reached from within the "
+ "module; success\n";);
+ return false;
+ }
+
// If we can go arbitrarily backwards we will eventually reach an entry point
// that can reach ToI. Only if a set of blocks through which we cannot go is
// provided, or once we track internal functions not accessible from the
@@ -611,10 +672,10 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
return true;
LLVM_DEBUG(dbgs() << "[AA] check " << *ToI << " from " << *CurFromI
<< " intraprocedurally\n");
- const auto &ReachabilityAA = A.getAAFor<AAIntraFnReachability>(
+ const auto *ReachabilityAA = A.getAAFor<AAIntraFnReachability>(
QueryingAA, IRPosition::function(ToFn), DepClassTy::OPTIONAL);
- bool Result =
- ReachabilityAA.isAssumedReachable(A, *CurFromI, *ToI, ExclusionSet);
+ bool Result = !ReachabilityAA || ReachabilityAA->isAssumedReachable(
+ A, *CurFromI, *ToI, ExclusionSet);
LLVM_DEBUG(dbgs() << "[AA] " << *CurFromI << " "
<< (Result ? "can potentially " : "cannot ") << "reach "
<< *ToI << " [Intra]\n");
@@ -624,11 +685,11 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
bool Result = true;
if (!ToFn.isDeclaration() && ToI) {
- const auto &ToReachabilityAA = A.getAAFor<AAIntraFnReachability>(
+ const auto *ToReachabilityAA = A.getAAFor<AAIntraFnReachability>(
QueryingAA, IRPosition::function(ToFn), DepClassTy::OPTIONAL);
const Instruction &EntryI = ToFn.getEntryBlock().front();
- Result =
- ToReachabilityAA.isAssumedReachable(A, EntryI, *ToI, ExclusionSet);
+ Result = !ToReachabilityAA || ToReachabilityAA->isAssumedReachable(
+ A, EntryI, *ToI, ExclusionSet);
LLVM_DEBUG(dbgs() << "[AA] Entry " << EntryI << " of @" << ToFn.getName()
<< " " << (Result ? "can potentially " : "cannot ")
<< "reach @" << *ToI << " [ToFn]\n");
@@ -637,10 +698,10 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
if (Result) {
// The entry of the ToFn can reach the instruction ToI. If the current
// instruction is already known to reach the ToFn.
- const auto &FnReachabilityAA = A.getAAFor<AAInterFnReachability>(
+ const auto *FnReachabilityAA = A.getAAFor<AAInterFnReachability>(
QueryingAA, IRPosition::function(*FromFn), DepClassTy::OPTIONAL);
- Result = FnReachabilityAA.instructionCanReach(A, *CurFromI, ToFn,
- ExclusionSet);
+ Result = !FnReachabilityAA || FnReachabilityAA->instructionCanReach(
+ A, *CurFromI, ToFn, ExclusionSet);
LLVM_DEBUG(dbgs() << "[AA] " << *CurFromI << " in @" << FromFn->getName()
<< " " << (Result ? "can potentially " : "cannot ")
<< "reach @" << ToFn.getName() << " [FromFn]\n");
@@ -649,11 +710,11 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
}
// TODO: Check assumed nounwind.
- const auto &ReachabilityAA = A.getAAFor<AAIntraFnReachability>(
+ const auto *ReachabilityAA = A.getAAFor<AAIntraFnReachability>(
QueryingAA, IRPosition::function(*FromFn), DepClassTy::OPTIONAL);
auto ReturnInstCB = [&](Instruction &Ret) {
- bool Result =
- ReachabilityAA.isAssumedReachable(A, *CurFromI, Ret, ExclusionSet);
+ bool Result = !ReachabilityAA || ReachabilityAA->isAssumedReachable(
+ A, *CurFromI, Ret, ExclusionSet);
LLVM_DEBUG(dbgs() << "[AA][Ret] " << *CurFromI << " "
<< (Result ? "can potentially " : "cannot ") << "reach "
<< Ret << " [Intra]\n");
@@ -743,14 +804,15 @@ bool AA::isAssumedThreadLocalObject(Attributor &A, Value &Obj,
<< "' is thread local; stack objects are thread local.\n");
return true;
}
- const auto &NoCaptureAA = A.getAAFor<AANoCapture>(
- QueryingAA, IRPosition::value(Obj), DepClassTy::OPTIONAL);
+ bool IsKnownNoCapture;
+ bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, &QueryingAA, IRPosition::value(Obj), DepClassTy::OPTIONAL,
+ IsKnownNoCapture);
LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj << "' is "
- << (NoCaptureAA.isAssumedNoCapture() ? "" : "not")
- << " thread local; "
- << (NoCaptureAA.isAssumedNoCapture() ? "non-" : "")
+ << (IsAssumedNoCapture ? "" : "not") << " thread local; "
+ << (IsAssumedNoCapture ? "non-" : "")
<< "captured stack object.\n");
- return NoCaptureAA.isAssumedNoCapture();
+ return IsAssumedNoCapture;
}
if (auto *GV = dyn_cast<GlobalVariable>(&Obj)) {
if (GV->isConstant()) {
@@ -831,9 +893,9 @@ bool AA::isPotentiallyAffectedByBarrier(Attributor &A,
return false;
};
- const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
+ const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
QueryingAA, IRPosition::value(*Ptr), DepClassTy::OPTIONAL);
- if (!UnderlyingObjsAA.forallUnderlyingObjects(Pred))
+ if (!UnderlyingObjsAA || !UnderlyingObjsAA->forallUnderlyingObjects(Pred))
return true;
}
return false;
@@ -848,38 +910,42 @@ static bool isEqualOrWorse(const Attribute &New, const Attribute &Old) {
}
/// Return true if the information provided by \p Attr was added to the
-/// attribute list \p Attrs. This is only the case if it was not already present
-/// in \p Attrs at the position describe by \p PK and \p AttrIdx.
+/// attribute set \p AttrSet. This is only the case if it was not already
+/// present in \p AttrSet.
static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr,
- AttributeList &Attrs, int AttrIdx,
- bool ForceReplace = false) {
+ AttributeSet AttrSet, bool ForceReplace,
+ AttrBuilder &AB) {
if (Attr.isEnumAttribute()) {
Attribute::AttrKind Kind = Attr.getKindAsEnum();
- if (Attrs.hasAttributeAtIndex(AttrIdx, Kind))
- if (!ForceReplace &&
- isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind)))
- return false;
- Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr);
+ if (AttrSet.hasAttribute(Kind))
+ return false;
+ AB.addAttribute(Kind);
return true;
}
if (Attr.isStringAttribute()) {
StringRef Kind = Attr.getKindAsString();
- if (Attrs.hasAttributeAtIndex(AttrIdx, Kind))
- if (!ForceReplace &&
- isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind)))
+ if (AttrSet.hasAttribute(Kind)) {
+ if (!ForceReplace)
return false;
- Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr);
+ }
+ AB.addAttribute(Kind, Attr.getValueAsString());
return true;
}
if (Attr.isIntAttribute()) {
Attribute::AttrKind Kind = Attr.getKindAsEnum();
- if (Attrs.hasAttributeAtIndex(AttrIdx, Kind))
- if (!ForceReplace &&
- isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind)))
+ if (!ForceReplace && Kind == Attribute::Memory) {
+ MemoryEffects ME = Attr.getMemoryEffects() & AttrSet.getMemoryEffects();
+ if (ME == AttrSet.getMemoryEffects())
return false;
- Attrs = Attrs.removeAttributeAtIndex(Ctx, AttrIdx, Kind);
- Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr);
+ AB.addMemoryAttr(ME);
+ return true;
+ }
+ if (AttrSet.hasAttribute(Kind)) {
+ if (!ForceReplace && isEqualOrWorse(Attr, AttrSet.getAttribute(Kind)))
+ return false;
+ }
+ AB.addAttribute(Attr);
return true;
}
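The new int-attribute path above intersects memory effects and only records an update when the intersection is strictly smaller than what the attribute set already holds. The following is a small standalone sketch of that merge rule, using a plain bitmask instead of LLVM's MemoryEffects; the names are illustrative only.

#include <cassert>
#include <cstdint>

enum : uint8_t { MRead = 1, MWrite = 2, MReadWrite = MRead | MWrite };

// Intersect the proposed effects with the existing ones; report whether the
// existing record became strictly more precise (and therefore needs rewriting).
bool mergeEffects(uint8_t &Existing, uint8_t Proposed) {
  uint8_t Combined = Existing & Proposed;
  if (Combined == Existing)
    return false; // nothing gained, keep the current attribute untouched
  Existing = Combined;
  return true;
}

int main() {
  uint8_t E = MReadWrite;
  assert(mergeEffects(E, MRead) && E == MRead); // read+write tightened to read
  assert(!mergeEffects(E, MReadWrite));         // no improvement, no update
}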
@@ -933,7 +999,7 @@ Argument *IRPosition::getAssociatedArgument() const {
// If no callbacks were found, or none used the underlying call site operand
// exclusively, use the direct callee argument if available.
- const Function *Callee = CB.getCalledFunction();
+ auto *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
if (Callee && Callee->arg_size() > unsigned(ArgNo))
return Callee->getArg(ArgNo);
@@ -955,63 +1021,168 @@ ChangeStatus AbstractAttribute::update(Attributor &A) {
return HasChanged;
}
+bool Attributor::getAttrsFromAssumes(const IRPosition &IRP,
+ Attribute::AttrKind AK,
+ SmallVectorImpl<Attribute> &Attrs) {
+ assert(IRP.getPositionKind() != IRPosition::IRP_INVALID &&
+ "Did expect a valid position!");
+ MustBeExecutedContextExplorer *Explorer =
+ getInfoCache().getMustBeExecutedContextExplorer();
+ if (!Explorer)
+ return false;
+
+ Value &AssociatedValue = IRP.getAssociatedValue();
+
+ const Assume2KnowledgeMap &A2K =
+ getInfoCache().getKnowledgeMap().lookup({&AssociatedValue, AK});
+
+ // Check if we found any potential assume use, if not we don't need to create
+ // explorer iterators.
+ if (A2K.empty())
+ return false;
+
+ LLVMContext &Ctx = AssociatedValue.getContext();
+ unsigned AttrsSize = Attrs.size();
+ auto EIt = Explorer->begin(IRP.getCtxI()),
+ EEnd = Explorer->end(IRP.getCtxI());
+ for (const auto &It : A2K)
+ if (Explorer->findInContextOf(It.first, EIt, EEnd))
+ Attrs.push_back(Attribute::get(Ctx, AK, It.second.Max));
+ return AttrsSize != Attrs.size();
+}
+
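getAttrsFromAssumes, now an Attributor member, first performs a cheap lookup in a map keyed by (value, attribute kind) and only builds the context-exploration iterators when that lookup is non-empty. A standalone sketch of that early-exit shape with an ordinary std::map is shown below; the key and value types here are invented for illustration.

#include <map>
#include <utility>
#include <vector>

using Key = std::pair<const void *, int>; // (value, attribute kind)
using Knowledge = std::vector<int>;       // e.g. known "max" operands

// Returns true if any knowledge was appended for the queried key.
bool collectFromAssumes(const std::map<Key, Knowledge> &KnowledgeMap,
                        const void *Value, int Kind, std::vector<int> &Out) {
  auto It = KnowledgeMap.find({Value, Kind});
  if (It == KnowledgeMap.end() || It->second.empty())
    return false; // cheap exit; no expensive iterators are built
  // Only now would the costly context filtering happen; the sketch keeps it all.
  auto Before = Out.size();
  Out.insert(Out.end(), It->second.begin(), It->second.end());
  return Out.size() != Before;
}

int main() {
  std::map<Key, Knowledge> KM;
  int X = 0;
  KM[{&X, /*Align kind=*/1}] = {16};
  std::vector<int> Out;
  return (collectFromAssumes(KM, &X, 1, Out) && Out.size() == 1) ? 0 : 1;
}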
+template <typename DescTy>
ChangeStatus
-IRAttributeManifest::manifestAttrs(Attributor &A, const IRPosition &IRP,
- const ArrayRef<Attribute> &DeducedAttrs,
- bool ForceReplace) {
- Function *ScopeFn = IRP.getAnchorScope();
- IRPosition::Kind PK = IRP.getPositionKind();
-
- // In the following some generic code that will manifest attributes in
- // DeducedAttrs if they improve the current IR. Due to the different
- // annotation positions we use the underlying AttributeList interface.
-
- AttributeList Attrs;
- switch (PK) {
- case IRPosition::IRP_INVALID:
+Attributor::updateAttrMap(const IRPosition &IRP,
+ const ArrayRef<DescTy> &AttrDescs,
+ function_ref<bool(const DescTy &, AttributeSet,
+ AttributeMask &, AttrBuilder &)>
+ CB) {
+ if (AttrDescs.empty())
+ return ChangeStatus::UNCHANGED;
+ switch (IRP.getPositionKind()) {
case IRPosition::IRP_FLOAT:
+ case IRPosition::IRP_INVALID:
return ChangeStatus::UNCHANGED;
- case IRPosition::IRP_ARGUMENT:
- case IRPosition::IRP_FUNCTION:
- case IRPosition::IRP_RETURNED:
- Attrs = ScopeFn->getAttributes();
- break;
- case IRPosition::IRP_CALL_SITE:
- case IRPosition::IRP_CALL_SITE_RETURNED:
- case IRPosition::IRP_CALL_SITE_ARGUMENT:
- Attrs = cast<CallBase>(IRP.getAnchorValue()).getAttributes();
+ default:
break;
- }
+ };
+
+ AttributeList AL;
+ Value *AttrListAnchor = IRP.getAttrListAnchor();
+ auto It = AttrsMap.find(AttrListAnchor);
+ if (It == AttrsMap.end())
+ AL = IRP.getAttrList();
+ else
+ AL = It->getSecond();
- ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
LLVMContext &Ctx = IRP.getAnchorValue().getContext();
- for (const Attribute &Attr : DeducedAttrs) {
- if (!addIfNotExistent(Ctx, Attr, Attrs, IRP.getAttrIdx(), ForceReplace))
- continue;
+ auto AttrIdx = IRP.getAttrIdx();
+ AttributeSet AS = AL.getAttributes(AttrIdx);
+ AttributeMask AM;
+ AttrBuilder AB(Ctx);
- HasChanged = ChangeStatus::CHANGED;
- }
+ ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
+ for (const DescTy &AttrDesc : AttrDescs)
+ if (CB(AttrDesc, AS, AM, AB))
+ HasChanged = ChangeStatus::CHANGED;
if (HasChanged == ChangeStatus::UNCHANGED)
- return HasChanged;
+ return ChangeStatus::UNCHANGED;
- switch (PK) {
- case IRPosition::IRP_ARGUMENT:
- case IRPosition::IRP_FUNCTION:
- case IRPosition::IRP_RETURNED:
- ScopeFn->setAttributes(Attrs);
- break;
- case IRPosition::IRP_CALL_SITE:
- case IRPosition::IRP_CALL_SITE_RETURNED:
- case IRPosition::IRP_CALL_SITE_ARGUMENT:
- cast<CallBase>(IRP.getAnchorValue()).setAttributes(Attrs);
- break;
- case IRPosition::IRP_INVALID:
- case IRPosition::IRP_FLOAT:
- break;
+ AL = AL.removeAttributesAtIndex(Ctx, AttrIdx, AM);
+ AL = AL.addAttributesAtIndex(Ctx, AttrIdx, AB);
+ AttrsMap[AttrListAnchor] = AL;
+ return ChangeStatus::CHANGED;
+}
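updateAttrMap is now the single place that edits attribute lists: it fetches a cached list keyed by the anchor value, lets a per-entry callback fill a removal mask and an add-builder, and applies both in one step, writing the result back into the cache. A standalone sketch of that accumulate-then-apply pattern with ordinary containers follows; all names are made up and this is not the Attributor API itself.

#include <functional>
#include <map>
#include <set>
#include <string>

using AttrSet = std::set<std::string>;

struct AttrEdit {
  AttrSet RemoveMask; // attributes to drop
  AttrSet AddBuilder; // attributes to add
};

// Run CB over every requested change, then apply removals and additions once.
bool updateAttrs(std::map<int, AttrSet> &Cache, int Anchor,
                 const AttrSet &Requests,
                 const std::function<bool(const std::string &, const AttrSet &,
                                          AttrEdit &)> &CB) {
  AttrSet Current = Cache[Anchor]; // cached copy, not the "IR" itself
  AttrEdit Edit;
  bool Changed = false;
  for (const std::string &R : Requests)
    Changed |= CB(R, Current, Edit);
  if (!Changed)
    return false;
  for (const std::string &A : Edit.RemoveMask)
    Current.erase(A);
  Current.insert(Edit.AddBuilder.begin(), Edit.AddBuilder.end());
  Cache[Anchor] = Current; // manifested later, in one batch per anchor
  return true;
}

int main() {
  std::map<int, AttrSet> Cache;
  auto AddIfMissing = [](const std::string &R, const AttrSet &Cur, AttrEdit &E) {
    if (Cur.count(R))
      return false;
    E.AddBuilder.insert(R);
    return true;
  };
  return updateAttrs(Cache, /*Anchor=*/0, {"nofree", "nosync"}, AddIfMissing) ? 0 : 1;
}

Batching the edits per anchor, rather than rewriting the attribute list on every deduction, is what allows hasAttr, getAttrs, removeAttrs, and manifestAttrs below to share one code path.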
+
+bool Attributor::hasAttr(const IRPosition &IRP,
+ ArrayRef<Attribute::AttrKind> AttrKinds,
+ bool IgnoreSubsumingPositions,
+ Attribute::AttrKind ImpliedAttributeKind) {
+ bool Implied = false;
+ bool HasAttr = false;
+ auto HasAttrCB = [&](const Attribute::AttrKind &Kind, AttributeSet AttrSet,
+ AttributeMask &, AttrBuilder &) {
+ if (AttrSet.hasAttribute(Kind)) {
+ Implied |= Kind != ImpliedAttributeKind;
+ HasAttr = true;
+ }
+ return false;
+ };
+ for (const IRPosition &EquivIRP : SubsumingPositionIterator(IRP)) {
+ updateAttrMap<Attribute::AttrKind>(EquivIRP, AttrKinds, HasAttrCB);
+ if (HasAttr)
+ break;
+ // The first position returned by the SubsumingPositionIterator is
+ // always the position itself. If we ignore subsuming positions we
+ // are done after the first iteration.
+ if (IgnoreSubsumingPositions)
+ break;
+ Implied = true;
+ }
+ if (!HasAttr) {
+ Implied = true;
+ SmallVector<Attribute> Attrs;
+ for (Attribute::AttrKind AK : AttrKinds)
+ if (getAttrsFromAssumes(IRP, AK, Attrs)) {
+ HasAttr = true;
+ break;
+ }
}
- return HasChanged;
+ // Check if we should manifest the implied attribute kind at the IRP.
+ if (ImpliedAttributeKind != Attribute::None && HasAttr && Implied)
+ manifestAttrs(IRP, {Attribute::get(IRP.getAnchorValue().getContext(),
+ ImpliedAttributeKind)});
+ return HasAttr;
+}
+
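The subtle part of the new hasAttr is the Implied flag: when the attribute is only found on a subsuming position (or via assumes) and an ImpliedAttributeKind was requested, the deduced kind is materialized at the queried position on the spot. The following is a simplified standalone sketch of that one case, with invented names and plain containers rather than IRPosition.

#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Positions are ordered most-specific first, mirroring SubsumingPositionIterator.
bool hasAttrAndMaybeManifest(std::vector<std::set<std::string>> &Positions,
                             const std::string &Kind, bool IgnoreSubsuming,
                             bool ManifestImplied) {
  bool HasAttr = false, Implied = false;
  for (std::size_t I = 0, E = Positions.size(); I != E; ++I) {
    if (Positions[I].count(Kind)) {
      HasAttr = true;
      Implied = I != 0; // found only on a broader (subsuming) position
      break;
    }
    if (IgnoreSubsuming)
      break; // only the position itself may be consulted
  }
  // Record the deduced kind directly at the queried position when it was only
  // implied, so later queries see it without re-deriving it.
  if (HasAttr && Implied && ManifestImplied)
    Positions.front().insert(Kind);
  return HasAttr;
}

int main() {
  std::vector<std::set<std::string>> Pos = {{}, {"nounwind"}};
  bool Found = hasAttrAndMaybeManifest(Pos, "nounwind", false, true);
  return (Found && Pos.front().count("nounwind")) ? 0 : 1;
}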
+void Attributor::getAttrs(const IRPosition &IRP,
+ ArrayRef<Attribute::AttrKind> AttrKinds,
+ SmallVectorImpl<Attribute> &Attrs,
+ bool IgnoreSubsumingPositions) {
+ auto CollectAttrCB = [&](const Attribute::AttrKind &Kind,
+ AttributeSet AttrSet, AttributeMask &,
+ AttrBuilder &) {
+ if (AttrSet.hasAttribute(Kind))
+ Attrs.push_back(AttrSet.getAttribute(Kind));
+ return false;
+ };
+ for (const IRPosition &EquivIRP : SubsumingPositionIterator(IRP)) {
+ updateAttrMap<Attribute::AttrKind>(EquivIRP, AttrKinds, CollectAttrCB);
+ // The first position returned by the SubsumingPositionIterator is
+ // always the position itself. If we ignore subsuming positions we
+ // are done after the first iteration.
+ if (IgnoreSubsumingPositions)
+ break;
+ }
+ for (Attribute::AttrKind AK : AttrKinds)
+ getAttrsFromAssumes(IRP, AK, Attrs);
+}
+
+ChangeStatus
+Attributor::removeAttrs(const IRPosition &IRP,
+ const ArrayRef<Attribute::AttrKind> &AttrKinds) {
+ auto RemoveAttrCB = [&](const Attribute::AttrKind &Kind, AttributeSet AttrSet,
+ AttributeMask &AM, AttrBuilder &) {
+ if (!AttrSet.hasAttribute(Kind))
+ return false;
+ AM.addAttribute(Kind);
+ return true;
+ };
+ return updateAttrMap<Attribute::AttrKind>(IRP, AttrKinds, RemoveAttrCB);
+}
+
+ChangeStatus Attributor::manifestAttrs(const IRPosition &IRP,
+ const ArrayRef<Attribute> &Attrs,
+ bool ForceReplace) {
+ LLVMContext &Ctx = IRP.getAnchorValue().getContext();
+ auto AddAttrCB = [&](const Attribute &Attr, AttributeSet AttrSet,
+ AttributeMask &, AttrBuilder &AB) {
+ return addIfNotExistent(Ctx, Attr, AttrSet, ForceReplace, AB);
+ };
+ return updateAttrMap<Attribute>(IRP, Attrs, AddAttrCB);
}
const IRPosition IRPosition::EmptyKey(DenseMapInfo<void *>::getEmptyKey());
@@ -1021,7 +1192,7 @@ const IRPosition
SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) {
IRPositions.emplace_back(IRP);
- // Helper to determine if operand bundles on a call site are benin or
+ // Helper to determine if operand bundles on a call site are benign or
// potentially problematic. We handle only llvm.assume for now.
auto CanIgnoreOperandBundles = [](const CallBase &CB) {
return (isa<IntrinsicInst>(CB) &&
@@ -1043,7 +1214,7 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) {
// TODO: We need to look at the operand bundles similar to the redirection
// in CallBase.
if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB))
- if (const Function *Callee = CB->getCalledFunction())
+ if (auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand()))
IRPositions.emplace_back(IRPosition::function(*Callee));
return;
case IRPosition::IRP_CALL_SITE_RETURNED:
@@ -1051,7 +1222,8 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) {
// TODO: We need to look at the operand bundles similar to the redirection
// in CallBase.
if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) {
- if (const Function *Callee = CB->getCalledFunction()) {
+ if (auto *Callee =
+ dyn_cast_if_present<Function>(CB->getCalledOperand())) {
IRPositions.emplace_back(IRPosition::returned(*Callee));
IRPositions.emplace_back(IRPosition::function(*Callee));
for (const Argument &Arg : Callee->args())
@@ -1071,7 +1243,7 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) {
// TODO: We need to look at the operand bundles similar to the redirection
// in CallBase.
if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) {
- const Function *Callee = CB->getCalledFunction();
+ auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand());
if (Callee) {
if (Argument *Arg = IRP.getAssociatedArgument())
IRPositions.emplace_back(IRPosition::argument(*Arg));
@@ -1084,85 +1256,6 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) {
}
}
-bool IRPosition::hasAttr(ArrayRef<Attribute::AttrKind> AKs,
- bool IgnoreSubsumingPositions, Attributor *A) const {
- SmallVector<Attribute, 4> Attrs;
- for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this)) {
- for (Attribute::AttrKind AK : AKs)
- if (EquivIRP.getAttrsFromIRAttr(AK, Attrs))
- return true;
- // The first position returned by the SubsumingPositionIterator is
- // always the position itself. If we ignore subsuming positions we
- // are done after the first iteration.
- if (IgnoreSubsumingPositions)
- break;
- }
- if (A)
- for (Attribute::AttrKind AK : AKs)
- if (getAttrsFromAssumes(AK, Attrs, *A))
- return true;
- return false;
-}
-
-void IRPosition::getAttrs(ArrayRef<Attribute::AttrKind> AKs,
- SmallVectorImpl<Attribute> &Attrs,
- bool IgnoreSubsumingPositions, Attributor *A) const {
- for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this)) {
- for (Attribute::AttrKind AK : AKs)
- EquivIRP.getAttrsFromIRAttr(AK, Attrs);
- // The first position returned by the SubsumingPositionIterator is
- // always the position itself. If we ignore subsuming positions we
- // are done after the first iteration.
- if (IgnoreSubsumingPositions)
- break;
- }
- if (A)
- for (Attribute::AttrKind AK : AKs)
- getAttrsFromAssumes(AK, Attrs, *A);
-}
-
-bool IRPosition::getAttrsFromIRAttr(Attribute::AttrKind AK,
- SmallVectorImpl<Attribute> &Attrs) const {
- if (getPositionKind() == IRP_INVALID || getPositionKind() == IRP_FLOAT)
- return false;
-
- AttributeList AttrList;
- if (const auto *CB = dyn_cast<CallBase>(&getAnchorValue()))
- AttrList = CB->getAttributes();
- else
- AttrList = getAssociatedFunction()->getAttributes();
-
- bool HasAttr = AttrList.hasAttributeAtIndex(getAttrIdx(), AK);
- if (HasAttr)
- Attrs.push_back(AttrList.getAttributeAtIndex(getAttrIdx(), AK));
- return HasAttr;
-}
-
-bool IRPosition::getAttrsFromAssumes(Attribute::AttrKind AK,
- SmallVectorImpl<Attribute> &Attrs,
- Attributor &A) const {
- assert(getPositionKind() != IRP_INVALID && "Did expect a valid position!");
- Value &AssociatedValue = getAssociatedValue();
-
- const Assume2KnowledgeMap &A2K =
- A.getInfoCache().getKnowledgeMap().lookup({&AssociatedValue, AK});
-
- // Check if we found any potential assume use, if not we don't need to create
- // explorer iterators.
- if (A2K.empty())
- return false;
-
- LLVMContext &Ctx = AssociatedValue.getContext();
- unsigned AttrsSize = Attrs.size();
- MustBeExecutedContextExplorer &Explorer =
- A.getInfoCache().getMustBeExecutedContextExplorer();
- auto EIt = Explorer.begin(getCtxI()), EEnd = Explorer.end(getCtxI());
- for (const auto &It : A2K)
- if (Explorer.findInContextOf(It.first, EIt, EEnd))
- Attrs.push_back(Attribute::get(Ctx, AK, It.second.Max));
- return AttrsSize != Attrs.size();
-}
-
void IRPosition::verify() {
#ifdef EXPENSIVE_CHECKS
switch (getPositionKind()) {
@@ -1285,35 +1378,67 @@ std::optional<Value *> Attributor::getAssumedSimplified(
}
bool Attributor::getAssumedSimplifiedValues(
- const IRPosition &IRP, const AbstractAttribute *AA,
+ const IRPosition &InitialIRP, const AbstractAttribute *AA,
SmallVectorImpl<AA::ValueAndContext> &Values, AA::ValueScope S,
- bool &UsedAssumedInformation) {
- // First check all callbacks provided by outside AAs. If any of them returns
- // a non-null value that is different from the associated value, or
- // std::nullopt, we assume it's simplified.
- const auto &SimplificationCBs = SimplificationCallbacks.lookup(IRP);
- for (const auto &CB : SimplificationCBs) {
- std::optional<Value *> CBResult = CB(IRP, AA, UsedAssumedInformation);
- if (!CBResult.has_value())
- continue;
- Value *V = *CBResult;
- if (!V)
- return false;
- if ((S & AA::ValueScope::Interprocedural) ||
- AA::isValidInScope(*V, IRP.getAnchorScope()))
- Values.push_back(AA::ValueAndContext{*V, nullptr});
- else
- return false;
- }
- if (!SimplificationCBs.empty())
- return true;
+ bool &UsedAssumedInformation, bool RecurseForSelectAndPHI) {
+ SmallPtrSet<Value *, 8> Seen;
+ SmallVector<IRPosition, 8> Worklist;
+ Worklist.push_back(InitialIRP);
+ while (!Worklist.empty()) {
+ const IRPosition &IRP = Worklist.pop_back_val();
+
+ // First check all callbacks provided by outside AAs. If any of them returns
+ // a non-null value that is different from the associated value, or
+ // std::nullopt, we assume it's simplified.
+ int NV = Values.size();
+ const auto &SimplificationCBs = SimplificationCallbacks.lookup(IRP);
+ for (const auto &CB : SimplificationCBs) {
+ std::optional<Value *> CBResult = CB(IRP, AA, UsedAssumedInformation);
+ if (!CBResult.has_value())
+ continue;
+ Value *V = *CBResult;
+ if (!V)
+ return false;
+ if ((S & AA::ValueScope::Interprocedural) ||
+ AA::isValidInScope(*V, IRP.getAnchorScope()))
+ Values.push_back(AA::ValueAndContext{*V, nullptr});
+ else
+ return false;
+ }
+ if (SimplificationCBs.empty()) {
+ // If no high-level/outside simplification occurred, use
+ // AAPotentialValues.
+ const auto *PotentialValuesAA =
+ getOrCreateAAFor<AAPotentialValues>(IRP, AA, DepClassTy::OPTIONAL);
+      if (PotentialValuesAA &&
+          PotentialValuesAA->getAssumedSimplifiedValues(*this, Values, S)) {
+ UsedAssumedInformation |= !PotentialValuesAA->isAtFixpoint();
+ } else if (IRP.getPositionKind() != IRPosition::IRP_RETURNED) {
+ Values.push_back({IRP.getAssociatedValue(), IRP.getCtxI()});
+ } else {
+ // TODO: We could visit all returns and add the operands.
+ return false;
+ }
+ }
- // If no high-level/outside simplification occurred, use AAPotentialValues.
- const auto &PotentialValuesAA =
- getOrCreateAAFor<AAPotentialValues>(IRP, AA, DepClassTy::OPTIONAL);
- if (!PotentialValuesAA.getAssumedSimplifiedValues(*this, Values, S))
- return false;
- UsedAssumedInformation |= !PotentialValuesAA.isAtFixpoint();
+ if (!RecurseForSelectAndPHI)
+ break;
+
+ for (int I = NV, E = Values.size(); I < E; ++I) {
+ Value *V = Values[I].getValue();
+ if (!isa<PHINode>(V) && !isa<SelectInst>(V))
+ continue;
+ if (!Seen.insert(V).second)
+ continue;
+ // Move the last element to this slot.
+ Values[I] = Values[E - 1];
+ // Eliminate the last slot, adjust the indices.
+ Values.pop_back();
+ --E;
+ --I;
+ // Add a new value (select or phi) to the worklist.
+ Worklist.push_back(IRPosition::value(*V));
+ }
+ }
return true;
}
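getAssumedSimplifiedValues now optionally recurses: any PHI or select that appears among the simplified values is removed from the result by swapping it with the last element and popping, then pushed onto a worklist so its own simplified values are collected instead, with a Seen set preventing cycles. A standalone sketch of that expansion loop is below; the Value type is a toy stand-in, not LLVM's.

#include <string>
#include <unordered_set>
#include <vector>

struct Value {
  std::string Name;
  bool IsCompound = false;       // stands in for "is a PHI or select"
  std::vector<Value *> Operands; // what the compound value may resolve to
};

std::vector<Value *> collectSimplified(Value &Root) {
  std::vector<Value *> Values;
  std::unordered_set<Value *> Seen;
  std::vector<Value *> Worklist{&Root};
  while (!Worklist.empty()) {
    Value *V = Worklist.back();
    Worklist.pop_back();
    int NV = (int)Values.size(); // first index produced in this round
    if (V->IsCompound)
      Values.insert(Values.end(), V->Operands.begin(), V->Operands.end());
    else
      Values.push_back(V);
    for (int I = NV, E = (int)Values.size(); I < E; ++I) {
      Value *W = Values[I];
      if (!W->IsCompound || !Seen.insert(W).second)
        continue;
      Values[I] = Values[E - 1]; // move the last element into this slot
      Values.pop_back();         // drop the now-duplicated tail slot
      --E;
      --I;                       // revisit the element we just moved here
      Worklist.push_back(W);     // expand the compound value later
    }
  }
  return Values;
}

int main() {
  Value A{"a"}, B{"b"};
  Value Phi{"phi", true, {&A, &B}};
  return collectSimplified(Phi).size() == 2 ? 0 : 1;
}

The swap-and-pop keeps the result compact without preserving order, which is acceptable here because callers treat the collected values as a set.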
@@ -1325,7 +1450,8 @@ std::optional<Value *> Attributor::translateArgumentToCallSiteContent(
if (*V == nullptr || isa<Constant>(*V))
return V;
if (auto *Arg = dyn_cast<Argument>(*V))
- if (CB.getCalledFunction() == Arg->getParent())
+ if (CB.getCalledOperand() == Arg->getParent() &&
+ CB.arg_size() > Arg->getArgNo())
if (!Arg->hasPointeeInMemoryValueAttr())
return getAssumedSimplified(
IRPosition::callsite_argument(CB, Arg->getArgNo()), AA,
@@ -1346,6 +1472,8 @@ bool Attributor::isAssumedDead(const AbstractAttribute &AA,
const AAIsDead *FnLivenessAA,
bool &UsedAssumedInformation,
bool CheckBBLivenessOnly, DepClassTy DepClass) {
+ if (!Configuration.UseLiveness)
+ return false;
const IRPosition &IRP = AA.getIRPosition();
if (!Functions.count(IRP.getAnchorScope()))
return false;
@@ -1358,6 +1486,8 @@ bool Attributor::isAssumedDead(const Use &U,
const AAIsDead *FnLivenessAA,
bool &UsedAssumedInformation,
bool CheckBBLivenessOnly, DepClassTy DepClass) {
+ if (!Configuration.UseLiveness)
+ return false;
Instruction *UserI = dyn_cast<Instruction>(U.getUser());
if (!UserI)
return isAssumedDead(IRPosition::value(*U.get()), QueryingAA, FnLivenessAA,
@@ -1384,12 +1514,12 @@ bool Attributor::isAssumedDead(const Use &U,
} else if (StoreInst *SI = dyn_cast<StoreInst>(UserI)) {
if (!CheckBBLivenessOnly && SI->getPointerOperand() != U.get()) {
const IRPosition IRP = IRPosition::inst(*SI);
- const AAIsDead &IsDeadAA =
+ const AAIsDead *IsDeadAA =
getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE);
- if (IsDeadAA.isRemovableStore()) {
+ if (IsDeadAA && IsDeadAA->isRemovableStore()) {
if (QueryingAA)
- recordDependence(IsDeadAA, *QueryingAA, DepClass);
- if (!IsDeadAA.isKnown(AAIsDead::IS_REMOVABLE))
+ recordDependence(*IsDeadAA, *QueryingAA, DepClass);
+ if (!IsDeadAA->isKnown(AAIsDead::IS_REMOVABLE))
UsedAssumedInformation = true;
return true;
}
@@ -1406,6 +1536,8 @@ bool Attributor::isAssumedDead(const Instruction &I,
bool &UsedAssumedInformation,
bool CheckBBLivenessOnly, DepClassTy DepClass,
bool CheckForDeadStore) {
+ if (!Configuration.UseLiveness)
+ return false;
const IRPosition::CallBaseContext *CBCtx =
QueryingAA ? QueryingAA->getCallBaseContext() : nullptr;
@@ -1414,11 +1546,11 @@ bool Attributor::isAssumedDead(const Instruction &I,
const Function &F = *I.getFunction();
if (!FnLivenessAA || FnLivenessAA->getAnchorScope() != &F)
- FnLivenessAA = &getOrCreateAAFor<AAIsDead>(IRPosition::function(F, CBCtx),
- QueryingAA, DepClassTy::NONE);
+ FnLivenessAA = getOrCreateAAFor<AAIsDead>(IRPosition::function(F, CBCtx),
+ QueryingAA, DepClassTy::NONE);
// Don't use recursive reasoning.
- if (QueryingAA == FnLivenessAA)
+ if (!FnLivenessAA || QueryingAA == FnLivenessAA)
return false;
// If we have a context instruction and a liveness AA we use it.
@@ -1435,25 +1567,25 @@ bool Attributor::isAssumedDead(const Instruction &I,
return false;
const IRPosition IRP = IRPosition::inst(I, CBCtx);
- const AAIsDead &IsDeadAA =
+ const AAIsDead *IsDeadAA =
getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE);
// Don't use recursive reasoning.
- if (QueryingAA == &IsDeadAA)
+ if (!IsDeadAA || QueryingAA == IsDeadAA)
return false;
- if (IsDeadAA.isAssumedDead()) {
+ if (IsDeadAA->isAssumedDead()) {
if (QueryingAA)
- recordDependence(IsDeadAA, *QueryingAA, DepClass);
- if (!IsDeadAA.isKnownDead())
+ recordDependence(*IsDeadAA, *QueryingAA, DepClass);
+ if (!IsDeadAA->isKnownDead())
UsedAssumedInformation = true;
return true;
}
- if (CheckForDeadStore && isa<StoreInst>(I) && IsDeadAA.isRemovableStore()) {
+ if (CheckForDeadStore && isa<StoreInst>(I) && IsDeadAA->isRemovableStore()) {
if (QueryingAA)
- recordDependence(IsDeadAA, *QueryingAA, DepClass);
- if (!IsDeadAA.isKnownDead())
+ recordDependence(*IsDeadAA, *QueryingAA, DepClass);
+ if (!IsDeadAA->isKnownDead())
UsedAssumedInformation = true;
return true;
}
@@ -1466,6 +1598,8 @@ bool Attributor::isAssumedDead(const IRPosition &IRP,
const AAIsDead *FnLivenessAA,
bool &UsedAssumedInformation,
bool CheckBBLivenessOnly, DepClassTy DepClass) {
+ if (!Configuration.UseLiveness)
+ return false;
// Don't check liveness for constants, e.g. functions, used as (floating)
// values since the context instruction and such is here meaningless.
if (IRP.getPositionKind() == IRPosition::IRP_FLOAT &&
@@ -1486,14 +1620,14 @@ bool Attributor::isAssumedDead(const IRPosition &IRP,
// If we haven't succeeded we query the specific liveness info for the IRP.
const AAIsDead *IsDeadAA;
if (IRP.getPositionKind() == IRPosition::IRP_CALL_SITE)
- IsDeadAA = &getOrCreateAAFor<AAIsDead>(
+ IsDeadAA = getOrCreateAAFor<AAIsDead>(
IRPosition::callsite_returned(cast<CallBase>(IRP.getAssociatedValue())),
QueryingAA, DepClassTy::NONE);
else
- IsDeadAA = &getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE);
+ IsDeadAA = getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE);
// Don't use recursive reasoning.
- if (QueryingAA == IsDeadAA)
+ if (!IsDeadAA || QueryingAA == IsDeadAA)
return false;
if (IsDeadAA->isAssumedDead()) {
@@ -1511,13 +1645,15 @@ bool Attributor::isAssumedDead(const BasicBlock &BB,
const AbstractAttribute *QueryingAA,
const AAIsDead *FnLivenessAA,
DepClassTy DepClass) {
+ if (!Configuration.UseLiveness)
+ return false;
const Function &F = *BB.getParent();
if (!FnLivenessAA || FnLivenessAA->getAnchorScope() != &F)
- FnLivenessAA = &getOrCreateAAFor<AAIsDead>(IRPosition::function(F),
- QueryingAA, DepClassTy::NONE);
+ FnLivenessAA = getOrCreateAAFor<AAIsDead>(IRPosition::function(F),
+ QueryingAA, DepClassTy::NONE);
// Don't use recursive reasoning.
- if (QueryingAA == FnLivenessAA)
+ if (!FnLivenessAA || QueryingAA == FnLivenessAA)
return false;
if (FnLivenessAA->isAssumedDead(&BB)) {
@@ -1570,8 +1706,8 @@ bool Attributor::checkForAllUses(
const Function *ScopeFn = IRP.getAnchorScope();
const auto *LivenessAA =
- ScopeFn ? &getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*ScopeFn),
- DepClassTy::NONE)
+ ScopeFn ? getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*ScopeFn),
+ DepClassTy::NONE)
: nullptr;
while (!Worklist.empty()) {
@@ -1777,49 +1913,26 @@ bool Attributor::shouldPropagateCallBaseContext(const IRPosition &IRP) {
return EnableCallSiteSpecific;
}
-bool Attributor::checkForAllReturnedValuesAndReturnInsts(
- function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> Pred,
- const AbstractAttribute &QueryingAA) {
+bool Attributor::checkForAllReturnedValues(function_ref<bool(Value &)> Pred,
+ const AbstractAttribute &QueryingAA,
+ AA::ValueScope S,
+ bool RecurseForSelectAndPHI) {
const IRPosition &IRP = QueryingAA.getIRPosition();
- // Since we need to provide return instructions we have to have an exact
- // definition.
const Function *AssociatedFunction = IRP.getAssociatedFunction();
if (!AssociatedFunction)
return false;
- // If this is a call site query we use the call site specific return values
- // and liveness information.
- // TODO: use the function scope once we have call site AAReturnedValues.
- const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction);
- const auto &AARetVal =
- getAAFor<AAReturnedValues>(QueryingAA, QueryIRP, DepClassTy::REQUIRED);
- if (!AARetVal.getState().isValidState())
- return false;
-
- return AARetVal.checkForAllReturnedValuesAndReturnInsts(Pred);
-}
-
-bool Attributor::checkForAllReturnedValues(
- function_ref<bool(Value &)> Pred, const AbstractAttribute &QueryingAA) {
-
- const IRPosition &IRP = QueryingAA.getIRPosition();
- const Function *AssociatedFunction = IRP.getAssociatedFunction();
- if (!AssociatedFunction)
- return false;
-
- // TODO: use the function scope once we have call site AAReturnedValues.
- const IRPosition &QueryIRP = IRPosition::function(
- *AssociatedFunction, QueryingAA.getCallBaseContext());
- const auto &AARetVal =
- getAAFor<AAReturnedValues>(QueryingAA, QueryIRP, DepClassTy::REQUIRED);
- if (!AARetVal.getState().isValidState())
+ bool UsedAssumedInformation = false;
+ SmallVector<AA::ValueAndContext> Values;
+ if (!getAssumedSimplifiedValues(
+ IRPosition::returned(*AssociatedFunction), &QueryingAA, Values, S,
+ UsedAssumedInformation, RecurseForSelectAndPHI))
return false;
- return AARetVal.checkForAllReturnedValuesAndReturnInsts(
- [&](Value &RV, const SmallSetVector<ReturnInst *, 4> &) {
- return Pred(RV);
- });
+ return llvm::all_of(Values, [&](const AA::ValueAndContext &VAC) {
+ return Pred(*VAC.getValue());
+ });
}
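checkForAllReturnedValues no longer goes through a dedicated AAReturnedValues attribute; it simplifies the function's returned position and then requires the predicate to hold for every candidate value, failing conservatively when the candidates cannot be gathered. A standalone sketch of that shape, with a dummy value source in place of getAssumedSimplifiedValues:

#include <algorithm>
#include <functional>
#include <optional>
#include <vector>

// Stand-in for simplifying the returned position: nullopt means the set of
// returned values could not be determined.
std::optional<std::vector<int>> simplifiedReturnedValues(bool Known) {
  if (!Known)
    return std::nullopt;
  return std::vector<int>{1, 2, 3};
}

bool checkForAllReturnedValues(bool Known,
                               const std::function<bool(int)> &Pred) {
  auto Values = simplifiedReturnedValues(Known);
  if (!Values)
    return false; // be conservative when the values are unknown
  return std::all_of(Values->begin(), Values->end(), Pred);
}

int main() {
  auto Positive = [](int V) { return V > 0; };
  return (checkForAllReturnedValues(true, Positive) &&
          !checkForAllReturnedValues(false, Positive))
             ? 0
             : 1;
}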
static bool checkForAllInstructionsImpl(
@@ -1863,12 +1976,11 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
if (!Fn || Fn->isDeclaration())
return false;
- // TODO: use the function scope once we have call site AAReturnedValues.
const IRPosition &QueryIRP = IRPosition::function(*Fn);
const auto *LivenessAA =
- (CheckBBLivenessOnly || CheckPotentiallyDead)
+ CheckPotentiallyDead
? nullptr
- : &(getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE));
+ : (getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE));
auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*Fn);
if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, &QueryingAA,
@@ -1895,21 +2007,21 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
bool Attributor::checkForAllReadWriteInstructions(
function_ref<bool(Instruction &)> Pred, AbstractAttribute &QueryingAA,
bool &UsedAssumedInformation) {
+ TimeTraceScope TS("checkForAllReadWriteInstructions");
const Function *AssociatedFunction =
QueryingAA.getIRPosition().getAssociatedFunction();
if (!AssociatedFunction)
return false;
- // TODO: use the function scope once we have call site AAReturnedValues.
const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction);
- const auto &LivenessAA =
+ const auto *LivenessAA =
getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE);
for (Instruction *I :
InfoCache.getReadOrWriteInstsForFunction(*AssociatedFunction)) {
// Skip dead instructions.
- if (isAssumedDead(IRPosition::inst(*I), &QueryingAA, &LivenessAA,
+ if (isAssumedDead(IRPosition::inst(*I), &QueryingAA, LivenessAA,
UsedAssumedInformation))
continue;
@@ -1954,11 +2066,9 @@ void Attributor::runTillFixpoint() {
dbgs() << "[Attributor] InvalidAA: " << *InvalidAA
<< " has " << InvalidAA->Deps.size()
<< " required & optional dependences\n");
- while (!InvalidAA->Deps.empty()) {
- const auto &Dep = InvalidAA->Deps.back();
- InvalidAA->Deps.pop_back();
- AbstractAttribute *DepAA = cast<AbstractAttribute>(Dep.getPointer());
- if (Dep.getInt() == unsigned(DepClassTy::OPTIONAL)) {
+ for (auto &DepIt : InvalidAA->Deps) {
+ AbstractAttribute *DepAA = cast<AbstractAttribute>(DepIt.getPointer());
+ if (DepIt.getInt() == unsigned(DepClassTy::OPTIONAL)) {
DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE,
dbgs() << " - recompute: " << *DepAA);
Worklist.insert(DepAA);
@@ -1973,16 +2083,16 @@ void Attributor::runTillFixpoint() {
else
ChangedAAs.push_back(DepAA);
}
+ InvalidAA->Deps.clear();
}
// Add all abstract attributes that are potentially dependent on one that
// changed to the work list.
- for (AbstractAttribute *ChangedAA : ChangedAAs)
- while (!ChangedAA->Deps.empty()) {
- Worklist.insert(
- cast<AbstractAttribute>(ChangedAA->Deps.back().getPointer()));
- ChangedAA->Deps.pop_back();
- }
+ for (AbstractAttribute *ChangedAA : ChangedAAs) {
+ for (auto &DepIt : ChangedAA->Deps)
+ Worklist.insert(cast<AbstractAttribute>(DepIt.getPointer()));
+ ChangedAA->Deps.clear();
+ }
LLVM_DEBUG(dbgs() << "[Attributor] #Iteration: " << IterationCounter
<< ", Worklist+Dependent size: " << Worklist.size()
@@ -2019,8 +2129,7 @@ void Attributor::runTillFixpoint() {
QueryAAsAwaitingUpdate.end());
QueryAAsAwaitingUpdate.clear();
- } while (!Worklist.empty() &&
- (IterationCounter++ < MaxIterations || VerifyMaxFixpointIterations));
+ } while (!Worklist.empty() && (IterationCounter++ < MaxIterations));
if (IterationCounter > MaxIterations && !Functions.empty()) {
auto Remark = [&](OptimizationRemarkMissed ORM) {
@@ -2053,11 +2162,9 @@ void Attributor::runTillFixpoint() {
NumAttributesTimedOut++;
}
- while (!ChangedAA->Deps.empty()) {
- ChangedAAs.push_back(
- cast<AbstractAttribute>(ChangedAA->Deps.back().getPointer()));
- ChangedAA->Deps.pop_back();
- }
+ for (auto &DepIt : ChangedAA->Deps)
+ ChangedAAs.push_back(cast<AbstractAttribute>(DepIt.getPointer()));
+ ChangedAA->Deps.clear();
}
LLVM_DEBUG({
@@ -2065,13 +2172,6 @@ void Attributor::runTillFixpoint() {
dbgs() << "\n[Attributor] Finalized " << Visited.size()
<< " abstract attributes.\n";
});
-
- if (VerifyMaxFixpointIterations && IterationCounter != MaxIterations) {
- errs() << "\n[Attributor] Fixpoint iteration done after: "
- << IterationCounter << "/" << MaxIterations << " iterations\n";
- llvm_unreachable("The fixpoint was not reached with exactly the number of "
- "specified iterations!");
- }
}
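In the loops above, Deps changed from a vector drained with pop_back into a set-like container: dependence edges are deduplicated on insertion, and the whole edge list is walked once and then cleared. A standalone sketch of the new draining shape, with plain std::set standing in for the actual dependence container:

#include <set>
#include <vector>

struct AA {
  std::set<AA *> Deps; // set-like: duplicate edges collapse on insert
};

// Requeue everything that depended on Changed, then drop the edges in one go.
void propagate(AA &Changed, std::vector<AA *> &Worklist) {
  for (AA *Dep : Changed.Deps)
    Worklist.push_back(Dep);
  Changed.Deps.clear(); // replaces the old pop_back() draining loop
}

int main() {
  AA A, B, C;
  A.Deps.insert(&B);
  A.Deps.insert(&B); // a duplicate edge is not stored twice
  A.Deps.insert(&C);
  std::vector<AA *> WL;
  propagate(A, WL);
  return (WL.size() == 2 && A.Deps.empty()) ? 0 : 1;
}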
void Attributor::registerForUpdate(AbstractAttribute &AA) {
@@ -2141,17 +2241,31 @@ ChangeStatus Attributor::manifestAttributes() {
(void)NumFinalAAs;
if (NumFinalAAs != DG.SyntheticRoot.Deps.size()) {
- for (unsigned u = NumFinalAAs; u < DG.SyntheticRoot.Deps.size(); ++u)
+ auto DepIt = DG.SyntheticRoot.Deps.begin();
+ for (unsigned u = 0; u < NumFinalAAs; ++u)
+ ++DepIt;
+ for (unsigned u = NumFinalAAs; u < DG.SyntheticRoot.Deps.size();
+ ++u, ++DepIt) {
errs() << "Unexpected abstract attribute: "
- << cast<AbstractAttribute>(DG.SyntheticRoot.Deps[u].getPointer())
- << " :: "
- << cast<AbstractAttribute>(DG.SyntheticRoot.Deps[u].getPointer())
+ << cast<AbstractAttribute>(DepIt->getPointer()) << " :: "
+ << cast<AbstractAttribute>(DepIt->getPointer())
->getIRPosition()
.getAssociatedValue()
<< "\n";
+ }
llvm_unreachable("Expected the final number of abstract attributes to "
"remain unchanged!");
}
+
+ for (auto &It : AttrsMap) {
+ AttributeList &AL = It.getSecond();
+ const IRPosition &IRP =
+ isa<Function>(It.getFirst())
+ ? IRPosition::function(*cast<Function>(It.getFirst()))
+ : IRPosition::callsite_function(*cast<CallBase>(It.getFirst()));
+ IRP.setAttrList(AL);
+ }
+
return ManifestChange;
}
@@ -2271,9 +2385,9 @@ ChangeStatus Attributor::cleanupIR() {
if (CB->isArgOperand(U)) {
unsigned Idx = CB->getArgOperandNo(U);
CB->removeParamAttr(Idx, Attribute::NoUndef);
- Function *Fn = CB->getCalledFunction();
- if (Fn && Fn->arg_size() > Idx)
- Fn->removeParamAttr(Idx, Attribute::NoUndef);
+ auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand());
+ if (Callee && Callee->arg_size() > Idx)
+ Callee->removeParamAttr(Idx, Attribute::NoUndef);
}
}
if (isa<Constant>(NewV) && isa<BranchInst>(U->getUser())) {
@@ -2484,9 +2598,9 @@ ChangeStatus Attributor::run() {
}
ChangeStatus Attributor::updateAA(AbstractAttribute &AA) {
- TimeTraceScope TimeScope(
- AA.getName() + std::to_string(AA.getIRPosition().getPositionKind()) +
- "::updateAA");
+ TimeTraceScope TimeScope("updateAA", [&]() {
+ return AA.getName() + std::to_string(AA.getIRPosition().getPositionKind());
+ });
assert(Phase == AttributorPhase::UPDATE &&
"We can update AA only in the update stage!");
@@ -2672,7 +2786,10 @@ bool Attributor::isValidFunctionSignatureRewrite(
ACS.getInstruction()->getType() !=
ACS.getCalledFunction()->getReturnType())
return false;
- if (ACS.getCalledOperand()->getType() != Fn->getType())
+ if (cast<CallBase>(ACS.getInstruction())->getCalledOperand()->getType() !=
+ Fn->getType())
+ return false;
+ if (ACS.getNumArgOperands() != Fn->arg_size())
return false;
// Forbid must-tail calls for now.
return !ACS.isCallbackCall() && !ACS.getInstruction()->isMustTailCall();
@@ -2698,7 +2815,8 @@ bool Attributor::isValidFunctionSignatureRewrite(
// Avoid callbacks for now.
bool UsedAssumedInformation = false;
if (!checkForAllCallSites(CallSiteCanBeChanged, *Fn, true, nullptr,
- UsedAssumedInformation)) {
+ UsedAssumedInformation,
+ /* CheckPotentiallyDead */ true)) {
LLVM_DEBUG(dbgs() << "[Attributor] Cannot rewrite all call sites\n");
return false;
}
@@ -3041,7 +3159,8 @@ void InformationCache::initializeInformationCache(const Function &CF,
AddToAssumeUsesMap(*Assume->getArgOperand(0));
} else if (cast<CallInst>(I).isMustTailCall()) {
FI.ContainsMustTailCall = true;
- if (const Function *Callee = cast<CallInst>(I).getCalledFunction())
+ if (auto *Callee = dyn_cast_if_present<Function>(
+ cast<CallInst>(I).getCalledOperand()))
getFunctionInfo(*Callee).CalledViaMustTail = true;
}
[[fallthrough]];
@@ -3077,10 +3196,6 @@ void InformationCache::initializeInformationCache(const Function &CF,
InlineableFunctions.insert(&F);
}
-AAResults *InformationCache::getAAResultsForFunction(const Function &F) {
- return AG.getAnalysis<AAManager>(F);
-}
-
InformationCache::FunctionInfo::~FunctionInfo() {
// The instruction vectors are allocated using a BumpPtrAllocator, we need to
// manually destroy them.
@@ -3111,11 +3226,21 @@ void Attributor::rememberDependences() {
DI.DepClass == DepClassTy::OPTIONAL) &&
"Expected required or optional dependence (1 bit)!");
auto &DepAAs = const_cast<AbstractAttribute &>(*DI.FromAA).Deps;
- DepAAs.push_back(AbstractAttribute::DepTy(
+ DepAAs.insert(AbstractAttribute::DepTy(
const_cast<AbstractAttribute *>(DI.ToAA), unsigned(DI.DepClass)));
}
}
+template <Attribute::AttrKind AK, typename AAType>
+void Attributor::checkAndQueryIRAttr(const IRPosition &IRP,
+ AttributeSet Attrs) {
+ bool IsKnown;
+ if (!Attrs.hasAttribute(AK))
+ if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE,
+ IsKnown))
+ getOrCreateAAFor<AAType>(IRP);
+}
+
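checkAndQueryIRAttr is the new seeding helper used throughout identifyDefaultAbstractAttributes below: an abstract attribute is only created when the IR attribute is neither already present nor already derivable through hasAssumedIRAttr, which keeps the initial AA population small. A standalone sketch of that guard, with invented names and callbacks in place of the Attributor machinery:

#include <functional>
#include <set>
#include <string>

// Seed a deduction only when the fact is not already available.
void checkAndQuery(const std::set<std::string> &ExistingAttrs,
                   const std::string &Kind,
                   const std::function<bool(const std::string &)> &AlreadyAssumed,
                   const std::function<void(const std::string &)> &Seed) {
  if (ExistingAttrs.count(Kind))
    return; // written in the IR already
  if (AlreadyAssumed(Kind))
    return; // derivable without creating a new abstract attribute
  Seed(Kind); // otherwise start a deduction for it
}

int main() {
  std::set<std::string> Attrs = {"nounwind"};
  int Seeded = 0;
  auto AlreadyAssumed = [](const std::string &) { return false; };
  auto Seed = [&](const std::string &) { ++Seeded; };
  checkAndQuery(Attrs, "nounwind", AlreadyAssumed, Seed); // skipped
  checkAndQuery(Attrs, "nofree", AlreadyAssumed, Seed);   // seeded
  return Seeded == 1 ? 0 : 1;
}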
void Attributor::identifyDefaultAbstractAttributes(Function &F) {
if (!VisitedFunctions.insert(&F).second)
return;
@@ -3134,89 +3259,114 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
}
IRPosition FPos = IRPosition::function(F);
+ bool IsIPOAmendable = isFunctionIPOAmendable(F);
+ auto Attrs = F.getAttributes();
+ auto FnAttrs = Attrs.getFnAttrs();
// Check for dead BasicBlocks in every function.
// We need dead instruction detection because we do not want to deal with
// broken IR in which SSA rules do not apply.
getOrCreateAAFor<AAIsDead>(FPos);
- // Every function might be "will-return".
- getOrCreateAAFor<AAWillReturn>(FPos);
-
- // Every function might contain instructions that cause "undefined behavior".
+ // Every function might contain instructions that cause "undefined
+ // behavior".
getOrCreateAAFor<AAUndefinedBehavior>(FPos);
- // Every function can be nounwind.
- getOrCreateAAFor<AANoUnwind>(FPos);
+ // Every function might be applicable for Heap-To-Stack conversion.
+ if (EnableHeapToStack)
+ getOrCreateAAFor<AAHeapToStack>(FPos);
- // Every function might be marked "nosync"
- getOrCreateAAFor<AANoSync>(FPos);
+ // Every function might be "must-progress".
+ checkAndQueryIRAttr<Attribute::MustProgress, AAMustProgress>(FPos, FnAttrs);
// Every function might be "no-free".
- getOrCreateAAFor<AANoFree>(FPos);
+ checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(FPos, FnAttrs);
- // Every function might be "no-return".
- getOrCreateAAFor<AANoReturn>(FPos);
+ // Every function might be "will-return".
+ checkAndQueryIRAttr<Attribute::WillReturn, AAWillReturn>(FPos, FnAttrs);
- // Every function might be "no-recurse".
- getOrCreateAAFor<AANoRecurse>(FPos);
+ // Everything that is visible from the outside (=function, argument, return
+ // positions), cannot be changed if the function is not IPO amendable. We can
+ // however analyse the code inside.
+ if (IsIPOAmendable) {
- // Every function might be "readnone/readonly/writeonly/...".
- getOrCreateAAFor<AAMemoryBehavior>(FPos);
+ // Every function can be nounwind.
+ checkAndQueryIRAttr<Attribute::NoUnwind, AANoUnwind>(FPos, FnAttrs);
- // Every function can be "readnone/argmemonly/inaccessiblememonly/...".
- getOrCreateAAFor<AAMemoryLocation>(FPos);
+ // Every function might be marked "nosync"
+ checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs);
- // Every function can track active assumptions.
- getOrCreateAAFor<AAAssumptionInfo>(FPos);
+ // Every function might be "no-return".
+ checkAndQueryIRAttr<Attribute::NoReturn, AANoReturn>(FPos, FnAttrs);
- // Every function might be applicable for Heap-To-Stack conversion.
- if (EnableHeapToStack)
- getOrCreateAAFor<AAHeapToStack>(FPos);
+ // Every function might be "no-recurse".
+ checkAndQueryIRAttr<Attribute::NoRecurse, AANoRecurse>(FPos, FnAttrs);
- // Return attributes are only appropriate if the return type is non void.
- Type *ReturnType = F.getReturnType();
- if (!ReturnType->isVoidTy()) {
- // Argument attribute "returned" --- Create only one per function even
- // though it is an argument attribute.
- getOrCreateAAFor<AAReturnedValues>(FPos);
+ // Every function can be "non-convergent".
+ if (Attrs.hasFnAttr(Attribute::Convergent))
+ getOrCreateAAFor<AANonConvergent>(FPos);
- IRPosition RetPos = IRPosition::returned(F);
+ // Every function might be "readnone/readonly/writeonly/...".
+ getOrCreateAAFor<AAMemoryBehavior>(FPos);
- // Every returned value might be dead.
- getOrCreateAAFor<AAIsDead>(RetPos);
+ // Every function can be "readnone/argmemonly/inaccessiblememonly/...".
+ getOrCreateAAFor<AAMemoryLocation>(FPos);
- // Every function might be simplified.
- bool UsedAssumedInformation = false;
- getAssumedSimplified(RetPos, nullptr, UsedAssumedInformation,
- AA::Intraprocedural);
+ // Every function can track active assumptions.
+ getOrCreateAAFor<AAAssumptionInfo>(FPos);
- // Every returned value might be marked noundef.
- getOrCreateAAFor<AANoUndef>(RetPos);
+ // Return attributes are only appropriate if the return type is non void.
+ Type *ReturnType = F.getReturnType();
+ if (!ReturnType->isVoidTy()) {
+ IRPosition RetPos = IRPosition::returned(F);
+ AttributeSet RetAttrs = Attrs.getRetAttrs();
- if (ReturnType->isPointerTy()) {
+ // Every returned value might be dead.
+ getOrCreateAAFor<AAIsDead>(RetPos);
- // Every function with pointer return type might be marked align.
- getOrCreateAAFor<AAAlign>(RetPos);
+ // Every function might be simplified.
+ bool UsedAssumedInformation = false;
+ getAssumedSimplified(RetPos, nullptr, UsedAssumedInformation,
+ AA::Intraprocedural);
+
+ // Every returned value might be marked noundef.
+ checkAndQueryIRAttr<Attribute::NoUndef, AANoUndef>(RetPos, RetAttrs);
+
+ if (ReturnType->isPointerTy()) {
- // Every function with pointer return type might be marked nonnull.
- getOrCreateAAFor<AANonNull>(RetPos);
+ // Every function with pointer return type might be marked align.
+ getOrCreateAAFor<AAAlign>(RetPos);
- // Every function with pointer return type might be marked noalias.
- getOrCreateAAFor<AANoAlias>(RetPos);
+ // Every function with pointer return type might be marked nonnull.
+ checkAndQueryIRAttr<Attribute::NonNull, AANonNull>(RetPos, RetAttrs);
- // Every function with pointer return type might be marked
- // dereferenceable.
- getOrCreateAAFor<AADereferenceable>(RetPos);
+ // Every function with pointer return type might be marked noalias.
+ checkAndQueryIRAttr<Attribute::NoAlias, AANoAlias>(RetPos, RetAttrs);
+
+ // Every function with pointer return type might be marked
+ // dereferenceable.
+ getOrCreateAAFor<AADereferenceable>(RetPos);
+ } else if (AttributeFuncs::isNoFPClassCompatibleType(ReturnType)) {
+ getOrCreateAAFor<AANoFPClass>(RetPos);
+ }
}
}
for (Argument &Arg : F.args()) {
IRPosition ArgPos = IRPosition::argument(Arg);
+ auto ArgNo = Arg.getArgNo();
+ AttributeSet ArgAttrs = Attrs.getParamAttrs(ArgNo);
+
+ if (!IsIPOAmendable) {
+ if (Arg.getType()->isPointerTy())
+ // Every argument with pointer type might be marked nofree.
+ checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(ArgPos, ArgAttrs);
+ continue;
+ }
- // Every argument might be simplified. We have to go through the Attributor
- // interface though as outside AAs can register custom simplification
- // callbacks.
+ // Every argument might be simplified. We have to go through the
+ // Attributor interface though as outside AAs can register custom
+ // simplification callbacks.
bool UsedAssumedInformation = false;
getAssumedSimplified(ArgPos, /* AA */ nullptr, UsedAssumedInformation,
AA::Intraprocedural);
@@ -3225,14 +3375,14 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
getOrCreateAAFor<AAIsDead>(ArgPos);
// Every argument might be marked noundef.
- getOrCreateAAFor<AANoUndef>(ArgPos);
+ checkAndQueryIRAttr<Attribute::NoUndef, AANoUndef>(ArgPos, ArgAttrs);
if (Arg.getType()->isPointerTy()) {
// Every argument with pointer type might be marked nonnull.
- getOrCreateAAFor<AANonNull>(ArgPos);
+ checkAndQueryIRAttr<Attribute::NonNull, AANonNull>(ArgPos, ArgAttrs);
// Every argument with pointer type might be marked noalias.
- getOrCreateAAFor<AANoAlias>(ArgPos);
+ checkAndQueryIRAttr<Attribute::NoAlias, AANoAlias>(ArgPos, ArgAttrs);
// Every argument with pointer type might be marked dereferenceable.
getOrCreateAAFor<AADereferenceable>(ArgPos);
@@ -3241,17 +3391,20 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
getOrCreateAAFor<AAAlign>(ArgPos);
// Every argument with pointer type might be marked nocapture.
- getOrCreateAAFor<AANoCapture>(ArgPos);
+ checkAndQueryIRAttr<Attribute::NoCapture, AANoCapture>(ArgPos, ArgAttrs);
// Every argument with pointer type might be marked
// "readnone/readonly/writeonly/..."
getOrCreateAAFor<AAMemoryBehavior>(ArgPos);
// Every argument with pointer type might be marked nofree.
- getOrCreateAAFor<AANoFree>(ArgPos);
+ checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(ArgPos, ArgAttrs);
- // Every argument with pointer type might be privatizable (or promotable)
+ // Every argument with pointer type might be privatizable (or
+ // promotable)
getOrCreateAAFor<AAPrivatizablePtr>(ArgPos);
+ } else if (AttributeFuncs::isNoFPClassCompatibleType(Arg.getType())) {
+ getOrCreateAAFor<AANoFPClass>(ArgPos);
}
}
@@ -3264,7 +3417,7 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
// users. The return value might be dead if there are no live users.
getOrCreateAAFor<AAIsDead>(CBInstPos);
- Function *Callee = CB.getCalledFunction();
+ Function *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
// TODO: Even if the callee is not known now we might be able to simplify
// the call/callee.
if (!Callee)
@@ -3280,16 +3433,20 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
return true;
if (!Callee->getReturnType()->isVoidTy() && !CB.use_empty()) {
-
IRPosition CBRetPos = IRPosition::callsite_returned(CB);
bool UsedAssumedInformation = false;
getAssumedSimplified(CBRetPos, nullptr, UsedAssumedInformation,
AA::Intraprocedural);
+
+ if (AttributeFuncs::isNoFPClassCompatibleType(Callee->getReturnType()))
+ getOrCreateAAFor<AANoFPClass>(CBInstPos);
}
+ const AttributeList &CBAttrs = CBFnPos.getAttrList();
for (int I = 0, E = CB.arg_size(); I < E; ++I) {
IRPosition CBArgPos = IRPosition::callsite_argument(CB, I);
+ AttributeSet CBArgAttrs = CBAttrs.getParamAttrs(I);
// Every call site argument might be dead.
getOrCreateAAFor<AAIsDead>(CBArgPos);
@@ -3302,19 +3459,26 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
AA::Intraprocedural);
// Every call site argument might be marked "noundef".
- getOrCreateAAFor<AANoUndef>(CBArgPos);
+ checkAndQueryIRAttr<Attribute::NoUndef, AANoUndef>(CBArgPos, CBArgAttrs);
+
+ Type *ArgTy = CB.getArgOperand(I)->getType();
+
+ if (!ArgTy->isPointerTy()) {
+ if (AttributeFuncs::isNoFPClassCompatibleType(ArgTy))
+ getOrCreateAAFor<AANoFPClass>(CBArgPos);
- if (!CB.getArgOperand(I)->getType()->isPointerTy())
continue;
+ }
// Call site argument attribute "non-null".
- getOrCreateAAFor<AANonNull>(CBArgPos);
+ checkAndQueryIRAttr<Attribute::NonNull, AANonNull>(CBArgPos, CBArgAttrs);
// Call site argument attribute "nocapture".
- getOrCreateAAFor<AANoCapture>(CBArgPos);
+ checkAndQueryIRAttr<Attribute::NoCapture, AANoCapture>(CBArgPos,
+ CBArgAttrs);
// Call site argument attribute "no-alias".
- getOrCreateAAFor<AANoAlias>(CBArgPos);
+ checkAndQueryIRAttr<Attribute::NoAlias, AANoAlias>(CBArgPos, CBArgAttrs);
// Call site argument attribute "dereferenceable".
getOrCreateAAFor<AADereferenceable>(CBArgPos);
@@ -3324,10 +3488,11 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
// Call site argument attribute
// "readnone/readonly/writeonly/..."
- getOrCreateAAFor<AAMemoryBehavior>(CBArgPos);
+ if (!CBAttrs.hasParamAttr(I, Attribute::ReadNone))
+ getOrCreateAAFor<AAMemoryBehavior>(CBArgPos);
// Call site argument attribute "nofree".
- getOrCreateAAFor<AANoFree>(CBArgPos);
+ checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(CBArgPos, CBArgAttrs);
}
return true;
};
@@ -3344,18 +3509,21 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
assert(Success && "Expected the check call to be successful!");
auto LoadStorePred = [&](Instruction &I) -> bool {
- if (isa<LoadInst>(I)) {
- getOrCreateAAFor<AAAlign>(
- IRPosition::value(*cast<LoadInst>(I).getPointerOperand()));
+ if (auto *LI = dyn_cast<LoadInst>(&I)) {
+ getOrCreateAAFor<AAAlign>(IRPosition::value(*LI->getPointerOperand()));
if (SimplifyAllLoads)
getAssumedSimplified(IRPosition::value(I), nullptr,
UsedAssumedInformation, AA::Intraprocedural);
+ getOrCreateAAFor<AAAddressSpace>(
+ IRPosition::value(*LI->getPointerOperand()));
} else {
auto &SI = cast<StoreInst>(I);
getOrCreateAAFor<AAIsDead>(IRPosition::inst(I));
getAssumedSimplified(IRPosition::value(*SI.getValueOperand()), nullptr,
UsedAssumedInformation, AA::Intraprocedural);
getOrCreateAAFor<AAAlign>(IRPosition::value(*SI.getPointerOperand()));
+ getOrCreateAAFor<AAAddressSpace>(
+ IRPosition::value(*SI.getPointerOperand()));
}
return true;
};
@@ -3461,7 +3629,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS,
return OS;
}
-void AbstractAttribute::print(raw_ostream &OS) const {
+void AbstractAttribute::print(Attributor *A, raw_ostream &OS) const {
OS << "[";
OS << getName();
OS << "] for CtxI ";
@@ -3473,7 +3641,7 @@ void AbstractAttribute::print(raw_ostream &OS) const {
} else
OS << "<<null inst>>";
- OS << " at position " << getIRPosition() << " with state " << getAsStr()
+ OS << " at position " << getIRPosition() << " with state " << getAsStr(A)
<< '\n';
}
@@ -3679,11 +3847,11 @@ template <> struct GraphTraits<AADepGraphNode *> {
using EdgeRef = PointerIntPair<AADepGraphNode *, 1>;
static NodeRef getEntryNode(AADepGraphNode *DGN) { return DGN; }
- static NodeRef DepGetVal(DepTy &DT) { return DT.getPointer(); }
+ static NodeRef DepGetVal(const DepTy &DT) { return DT.getPointer(); }
using ChildIteratorType =
- mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>;
- using ChildEdgeIteratorType = TinyPtrVector<DepTy>::iterator;
+ mapped_iterator<AADepGraphNode::DepSetTy::iterator, decltype(&DepGetVal)>;
+ using ChildEdgeIteratorType = AADepGraphNode::DepSetTy::iterator;
static ChildIteratorType child_begin(NodeRef N) { return N->child_begin(); }
@@ -3695,7 +3863,7 @@ struct GraphTraits<AADepGraph *> : public GraphTraits<AADepGraphNode *> {
static NodeRef getEntryNode(AADepGraph *DG) { return DG->GetEntryNode(); }
using nodes_iterator =
- mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>;
+ mapped_iterator<AADepGraphNode::DepSetTy::iterator, decltype(&DepGetVal)>;
static nodes_iterator nodes_begin(AADepGraph *DG) { return DG->begin(); }
@@ -3715,98 +3883,3 @@ template <> struct DOTGraphTraits<AADepGraph *> : public DefaultDOTGraphTraits {
};
} // end namespace llvm
-
-namespace {
-
-struct AttributorLegacyPass : public ModulePass {
- static char ID;
-
- AttributorLegacyPass() : ModulePass(ID) {
- initializeAttributorLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- AnalysisGetter AG;
- SetVector<Function *> Functions;
- for (Function &F : M)
- Functions.insert(&F);
-
- CallGraphUpdater CGUpdater;
- BumpPtrAllocator Allocator;
- InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr);
- return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater,
- /* DeleteFns*/ true,
- /* IsModulePass */ true);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- // FIXME: Think about passes we will preserve and add them here.
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- }
-};
-
-struct AttributorCGSCCLegacyPass : public CallGraphSCCPass {
- static char ID;
-
- AttributorCGSCCLegacyPass() : CallGraphSCCPass(ID) {
- initializeAttributorCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnSCC(CallGraphSCC &SCC) override {
- if (skipSCC(SCC))
- return false;
-
- SetVector<Function *> Functions;
- for (CallGraphNode *CGN : SCC)
- if (Function *Fn = CGN->getFunction())
- if (!Fn->isDeclaration())
- Functions.insert(Fn);
-
- if (Functions.empty())
- return false;
-
- AnalysisGetter AG;
- CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph());
- CallGraphUpdater CGUpdater;
- CGUpdater.initialize(CG, SCC);
- Module &M = *Functions.back()->getParent();
- BumpPtrAllocator Allocator;
- InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions);
- return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater,
- /* DeleteFns */ false,
- /* IsModulePass */ false);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- // FIXME: Think about passes we will preserve and add them here.
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- CallGraphSCCPass::getAnalysisUsage(AU);
- }
-};
-
-} // end anonymous namespace
-
-Pass *llvm::createAttributorLegacyPass() { return new AttributorLegacyPass(); }
-Pass *llvm::createAttributorCGSCCLegacyPass() {
- return new AttributorCGSCCLegacyPass();
-}
-
-char AttributorLegacyPass::ID = 0;
-char AttributorCGSCCLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(AttributorLegacyPass, "attributor",
- "Deduce and propagate attributes", false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(AttributorLegacyPass, "attributor",
- "Deduce and propagate attributes", false, false)
-INITIALIZE_PASS_BEGIN(AttributorCGSCCLegacyPass, "attributor-cgscc",
- "Deduce and propagate attributes (CGSCC pass)", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
-INITIALIZE_PASS_END(AttributorCGSCCLegacyPass, "attributor-cgscc",
- "Deduce and propagate attributes (CGSCC pass)", false,
- false)
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 001ef55ba472..3a9a89d61355 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -24,6 +24,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
@@ -38,6 +39,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Assumptions.h"
+#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -52,6 +54,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
@@ -156,10 +159,11 @@ PIPE_OPERATOR(AAIsDead)
PIPE_OPERATOR(AANoUnwind)
PIPE_OPERATOR(AANoSync)
PIPE_OPERATOR(AANoRecurse)
+PIPE_OPERATOR(AANonConvergent)
PIPE_OPERATOR(AAWillReturn)
PIPE_OPERATOR(AANoReturn)
-PIPE_OPERATOR(AAReturnedValues)
PIPE_OPERATOR(AANonNull)
+PIPE_OPERATOR(AAMustProgress)
PIPE_OPERATOR(AANoAlias)
PIPE_OPERATOR(AADereferenceable)
PIPE_OPERATOR(AAAlign)
@@ -177,11 +181,13 @@ PIPE_OPERATOR(AAUndefinedBehavior)
PIPE_OPERATOR(AAPotentialConstantValues)
PIPE_OPERATOR(AAPotentialValues)
PIPE_OPERATOR(AANoUndef)
+PIPE_OPERATOR(AANoFPClass)
PIPE_OPERATOR(AACallEdges)
PIPE_OPERATOR(AAInterFnReachability)
PIPE_OPERATOR(AAPointerInfo)
PIPE_OPERATOR(AAAssumptionInfo)
PIPE_OPERATOR(AAUnderlyingObjects)
+PIPE_OPERATOR(AAAddressSpace)
#undef PIPE_OPERATOR
@@ -196,6 +202,19 @@ ChangeStatus clampStateAndIndicateChange<DerefState>(DerefState &S,
} // namespace llvm
+static bool mayBeInCycle(const CycleInfo *CI, const Instruction *I,
+ bool HeaderOnly, Cycle **CPtr = nullptr) {
+ if (!CI)
+ return true;
+ auto *BB = I->getParent();
+ auto *C = CI->getCycle(BB);
+ if (!C)
+ return false;
+ if (CPtr)
+ *CPtr = C;
+ return !HeaderOnly || BB == C->getHeader();
+}
+
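mayBeInCycle answers conservatively: without cycle info every instruction may be in a cycle, and with it the query can optionally be restricted to cycle headers. Below is a standalone sketch of the same decision tree over a toy cycle structure; it is not LLVM's CycleInfo and the block numbering is invented.

struct ToyCycle {
  int HeaderBlock;
};

struct ToyCycleInfo {
  // Returns the innermost cycle containing Block, or nullptr if none.
  const ToyCycle *getCycle(int Block) const {
    static const ToyCycle Loop{/*HeaderBlock=*/1};
    return (Block == 1 || Block == 2) ? &Loop : nullptr;
  }
};

bool mayBeInCycle(const ToyCycleInfo *CI, int Block, bool HeaderOnly) {
  if (!CI)
    return true; // no analysis: assume the worst
  const ToyCycle *C = CI->getCycle(Block);
  if (!C)
    return false; // provably outside any cycle
  return !HeaderOnly || Block == C->HeaderBlock;
}

int main() {
  ToyCycleInfo CI;
  bool Ok = mayBeInCycle(nullptr, 3, false) // conservative
            && !mayBeInCycle(&CI, 3, false) // outside the loop
            && mayBeInCycle(&CI, 2, false)  // in the loop body
            && !mayBeInCycle(&CI, 2, true); // but not the header
  return Ok ? 0 : 1;
}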
/// Checks if a type could have padding bytes.
static bool isDenselyPacked(Type *Ty, const DataLayout &DL) {
// There is no size information, so be conservative.
@@ -317,12 +336,14 @@ stripAndAccumulateOffsets(Attributor &A, const AbstractAttribute &QueryingAA,
auto AttributorAnalysis = [&](Value &V, APInt &ROffset) -> bool {
const IRPosition &Pos = IRPosition::value(V);
// Only track dependence if we are going to use the assumed info.
- const AAValueConstantRange &ValueConstantRangeAA =
+ const AAValueConstantRange *ValueConstantRangeAA =
A.getAAFor<AAValueConstantRange>(QueryingAA, Pos,
UseAssumed ? DepClassTy::OPTIONAL
: DepClassTy::NONE);
- ConstantRange Range = UseAssumed ? ValueConstantRangeAA.getAssumed()
- : ValueConstantRangeAA.getKnown();
+ if (!ValueConstantRangeAA)
+ return false;
+ ConstantRange Range = UseAssumed ? ValueConstantRangeAA->getAssumed()
+ : ValueConstantRangeAA->getKnown();
if (Range.isFullSet())
return false;
@@ -355,7 +376,9 @@ getMinimalBaseOfPointer(Attributor &A, const AbstractAttribute &QueryingAA,
/// Clamp the information known for all returned values of a function
/// (identified by \p QueryingAA) into \p S.
-template <typename AAType, typename StateType = typename AAType::StateType>
+template <typename AAType, typename StateType = typename AAType::StateType,
+ Attribute::AttrKind IRAttributeKind = Attribute::None,
+ bool RecurseForSelectAndPHI = true>
static void clampReturnedValueStates(
Attributor &A, const AAType &QueryingAA, StateType &S,
const IRPosition::CallBaseContext *CBContext = nullptr) {
@@ -376,11 +399,20 @@ static void clampReturnedValueStates(
// Callback for each possibly returned value.
auto CheckReturnValue = [&](Value &RV) -> bool {
const IRPosition &RVPos = IRPosition::value(RV, CBContext);
- const AAType &AA =
+ // If possible, use the hasAssumedIRAttr interface.
+ if (IRAttributeKind != Attribute::None) {
+ bool IsKnown;
+ return AA::hasAssumedIRAttr<IRAttributeKind>(
+ A, &QueryingAA, RVPos, DepClassTy::REQUIRED, IsKnown);
+ }
+
+ const AAType *AA =
A.getAAFor<AAType>(QueryingAA, RVPos, DepClassTy::REQUIRED);
- LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr()
- << " @ " << RVPos << "\n");
- const StateType &AAS = AA.getState();
+ if (!AA)
+ return false;
+ LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV
+ << " AA: " << AA->getAsStr(&A) << " @ " << RVPos << "\n");
+ const StateType &AAS = AA->getState();
if (!T)
T = StateType::getBestState(AAS);
*T &= AAS;
@@ -389,7 +421,9 @@ static void clampReturnedValueStates(
return T->isValidState();
};
- if (!A.checkForAllReturnedValues(CheckReturnValue, QueryingAA))
+ if (!A.checkForAllReturnedValues(CheckReturnValue, QueryingAA,
+ AA::ValueScope::Intraprocedural,
+ RecurseForSelectAndPHI))
S.indicatePessimisticFixpoint();
else if (T)
S ^= *T;
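Two API shifts recur throughout this patch and are both visible in the hunk above: A.getAAFor<> now returns a pointer that may be null instead of a reference, and plain boolean IR attributes are queried through AA::hasAssumedIRAttr<Kind>, selected by the new IRAttributeKind template parameter. The stand-alone sketch below is not part of the patch; all names are invented, not LLVM interfaces. It only mirrors the dispatch shape: a sentinel template argument picks the cheap boolean path, and a failed lookup makes the caller pessimize.

#include <iostream>
#include <optional>

enum class Kind { None, NonNull };

struct BoolState {
  bool Assumed = true;                        // optimistic best state
  void clampWith(const BoolState &O) { Assumed &= O.Assumed; }
};

// Toy "attribute manager": the lookup may fail to produce a state at all,
// which is what the new null check on getAAFor<>'s result guards against.
std::optional<BoolState> lookupState(int ValueId) {
  if (ValueId % 2)
    return BoolState{};
  return std::nullopt;
}
bool hasAssumedAttr(Kind, int ValueId) { return ValueId != 0; }

template <Kind AttrKind = Kind::None>
bool clampOneValue(int ValueId, BoolState &S) {
  // AttrKind is a compile-time constant, so this branch folds away, just
  // like the `IRAttributeKind != Attribute::None` checks in the patch.
  if (AttrKind != Kind::None)
    return hasAssumedAttr(AttrKind, ValueId);
  std::optional<BoolState> AA = lookupState(ValueId);
  if (!AA)
    return false;                             // lookup failed: give up
  S.clampWith(*AA);
  return S.Assumed;
}

int main() {
  BoolState S;
  std::cout << clampOneValue<Kind::NonNull>(3, S) << ' '
            << clampOneValue(2, S) << '\n';   // boolean path, then state path
}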
@@ -399,7 +433,9 @@ namespace {
/// Helper class for generic deduction: return value -> returned position.
template <typename AAType, typename BaseType,
typename StateType = typename BaseType::StateType,
- bool PropagateCallBaseContext = false>
+ bool PropagateCallBaseContext = false,
+ Attribute::AttrKind IRAttributeKind = Attribute::None,
+ bool RecurseForSelectAndPHI = true>
struct AAReturnedFromReturnedValues : public BaseType {
AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A)
: BaseType(IRP, A) {}
@@ -407,7 +443,7 @@ struct AAReturnedFromReturnedValues : public BaseType {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
StateType S(StateType::getBestState(this->getState()));
- clampReturnedValueStates<AAType, StateType>(
+ clampReturnedValueStates<AAType, StateType, IRAttributeKind, RecurseForSelectAndPHI>(
A, *this, S,
PropagateCallBaseContext ? this->getCallBaseContext() : nullptr);
    // TODO: If we know we visited all returned values, and thus none are assumed
@@ -418,7 +454,8 @@ struct AAReturnedFromReturnedValues : public BaseType {
/// Clamp the information known at all call sites for a given argument
/// (identified by \p QueryingAA) into \p S.
-template <typename AAType, typename StateType = typename AAType::StateType>
+template <typename AAType, typename StateType = typename AAType::StateType,
+ Attribute::AttrKind IRAttributeKind = Attribute::None>
static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
StateType &S) {
LLVM_DEBUG(dbgs() << "[Attributor] Clamp call site argument states for "
@@ -442,11 +479,21 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID)
return false;
- const AAType &AA =
+ // If possible, use the hasAssumedIRAttr interface.
+ if (IRAttributeKind != Attribute::None) {
+ bool IsKnown;
+ return AA::hasAssumedIRAttr<IRAttributeKind>(
+ A, &QueryingAA, ACSArgPos, DepClassTy::REQUIRED, IsKnown);
+ }
+
+ const AAType *AA =
A.getAAFor<AAType>(QueryingAA, ACSArgPos, DepClassTy::REQUIRED);
+ if (!AA)
+ return false;
LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction()
- << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n");
- const StateType &AAS = AA.getState();
+ << " AA: " << AA->getAsStr(&A) << " @" << ACSArgPos
+ << "\n");
+ const StateType &AAS = AA->getState();
if (!T)
T = StateType::getBestState(AAS);
*T &= AAS;
@@ -466,7 +513,8 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
/// This function is the bridge between argument position and the call base
/// context.
template <typename AAType, typename BaseType,
- typename StateType = typename AAType::StateType>
+ typename StateType = typename AAType::StateType,
+ Attribute::AttrKind IRAttributeKind = Attribute::None>
bool getArgumentStateFromCallBaseContext(Attributor &A,
BaseType &QueryingAttribute,
IRPosition &Pos, StateType &State) {
@@ -478,12 +526,21 @@ bool getArgumentStateFromCallBaseContext(Attributor &A,
int ArgNo = Pos.getCallSiteArgNo();
assert(ArgNo >= 0 && "Invalid Arg No!");
+ const IRPosition CBArgPos = IRPosition::callsite_argument(*CBContext, ArgNo);
+
+ // If possible, use the hasAssumedIRAttr interface.
+ if (IRAttributeKind != Attribute::None) {
+ bool IsKnown;
+ return AA::hasAssumedIRAttr<IRAttributeKind>(
+ A, &QueryingAttribute, CBArgPos, DepClassTy::REQUIRED, IsKnown);
+ }
- const auto &AA = A.getAAFor<AAType>(
- QueryingAttribute, IRPosition::callsite_argument(*CBContext, ArgNo),
- DepClassTy::REQUIRED);
+ const auto *AA =
+ A.getAAFor<AAType>(QueryingAttribute, CBArgPos, DepClassTy::REQUIRED);
+ if (!AA)
+ return false;
const StateType &CBArgumentState =
- static_cast<const StateType &>(AA.getState());
+ static_cast<const StateType &>(AA->getState());
  LLVM_DEBUG(dbgs() << "[Attributor] Bridging Call site context to argument"
<< "Position:" << Pos << "CB Arg state:" << CBArgumentState
@@ -497,7 +554,8 @@ bool getArgumentStateFromCallBaseContext(Attributor &A,
/// Helper class for generic deduction: call site argument -> argument position.
template <typename AAType, typename BaseType,
typename StateType = typename AAType::StateType,
- bool BridgeCallBaseContext = false>
+ bool BridgeCallBaseContext = false,
+ Attribute::AttrKind IRAttributeKind = Attribute::None>
struct AAArgumentFromCallSiteArguments : public BaseType {
AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A)
: BaseType(IRP, A) {}
@@ -508,12 +566,14 @@ struct AAArgumentFromCallSiteArguments : public BaseType {
if (BridgeCallBaseContext) {
bool Success =
- getArgumentStateFromCallBaseContext<AAType, BaseType, StateType>(
+ getArgumentStateFromCallBaseContext<AAType, BaseType, StateType,
+ IRAttributeKind>(
A, *this, this->getIRPosition(), S);
if (Success)
return clampStateAndIndicateChange<StateType>(this->getState(), S);
}
- clampCallSiteArgumentStates<AAType, StateType>(A, *this, S);
+ clampCallSiteArgumentStates<AAType, StateType, IRAttributeKind>(A, *this,
+ S);
  // TODO: If we know we visited all incoming values, and thus none are assumed
// dead, we can take the known information from the state T.
@@ -524,7 +584,8 @@ struct AAArgumentFromCallSiteArguments : public BaseType {
/// Helper class for generic replication: function returned -> cs returned.
template <typename AAType, typename BaseType,
typename StateType = typename BaseType::StateType,
- bool IntroduceCallBaseContext = false>
+ bool IntroduceCallBaseContext = false,
+ Attribute::AttrKind IRAttributeKind = Attribute::None>
struct AACallSiteReturnedFromReturned : public BaseType {
AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A)
: BaseType(IRP, A) {}
@@ -549,8 +610,20 @@ struct AACallSiteReturnedFromReturned : public BaseType {
IRPosition FnPos = IRPosition::returned(
*AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr);
- const AAType &AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(S, AA.getState());
+
+ // If possible, use the hasAssumedIRAttr interface.
+ if (IRAttributeKind != Attribute::None) {
+ bool IsKnown;
+ if (!AA::hasAssumedIRAttr<IRAttributeKind>(A, this, FnPos,
+ DepClassTy::REQUIRED, IsKnown))
+ return S.indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+
+ const AAType *AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
+ if (!AA)
+ return S.indicatePessimisticFixpoint();
+ return clampStateAndIndicateChange(S, AA->getState());
}
};
@@ -585,16 +658,17 @@ static void followUsesInContext(AAType &AA, Attributor &A,
template <class AAType, typename StateType = typename AAType::StateType>
static void followUsesInMBEC(AAType &AA, Attributor &A, StateType &S,
Instruction &CtxI) {
+ MustBeExecutedContextExplorer *Explorer =
+ A.getInfoCache().getMustBeExecutedContextExplorer();
+ if (!Explorer)
+ return;
// Container for (transitive) uses of the associated value.
SetVector<const Use *> Uses;
for (const Use &U : AA.getIRPosition().getAssociatedValue().uses())
Uses.insert(&U);
- MustBeExecutedContextExplorer &Explorer =
- A.getInfoCache().getMustBeExecutedContextExplorer();
-
- followUsesInContext<AAType>(AA, A, Explorer, &CtxI, Uses, S);
+ followUsesInContext<AAType>(AA, A, *Explorer, &CtxI, Uses, S);
if (S.isAtFixpoint())
return;
@@ -639,7 +713,7 @@ static void followUsesInMBEC(AAType &AA, Attributor &A, StateType &S,
// }
// }
- Explorer.checkForAllContext(&CtxI, Pred);
+ Explorer->checkForAllContext(&CtxI, Pred);
for (const BranchInst *Br : BrInsts) {
StateType ParentState;
@@ -651,7 +725,7 @@ static void followUsesInMBEC(AAType &AA, Attributor &A, StateType &S,
StateType ChildState;
size_t BeforeSize = Uses.size();
- followUsesInContext(AA, A, Explorer, &BB->front(), Uses, ChildState);
+ followUsesInContext(AA, A, *Explorer, &BB->front(), Uses, ChildState);
// Erase uses which only appear in the child.
for (auto It = Uses.begin() + BeforeSize; It != Uses.end();)
@@ -855,7 +929,7 @@ protected:
for (unsigned Index : LocalList->getSecond()) {
for (auto &R : AccessList[Index]) {
Range &= R;
- if (Range.offsetOrSizeAreUnknown())
+ if (Range.offsetAndSizeAreUnknown())
break;
}
}
@@ -887,10 +961,8 @@ ChangeStatus AA::PointerInfo::State::addAccess(
}
auto AddToBins = [&](const AAPointerInfo::RangeList &ToAdd) {
- LLVM_DEBUG(
- if (ToAdd.size())
- dbgs() << "[AAPointerInfo] Inserting access in new offset bins\n";
- );
+ LLVM_DEBUG(if (ToAdd.size()) dbgs()
+ << "[AAPointerInfo] Inserting access in new offset bins\n";);
for (auto Key : ToAdd) {
LLVM_DEBUG(dbgs() << " key " << Key << "\n");
@@ -923,10 +995,8 @@ ChangeStatus AA::PointerInfo::State::addAccess(
// from the offset bins.
AAPointerInfo::RangeList ToRemove;
AAPointerInfo::RangeList::set_difference(ExistingRanges, NewRanges, ToRemove);
- LLVM_DEBUG(
- if (ToRemove.size())
- dbgs() << "[AAPointerInfo] Removing access from old offset bins\n";
- );
+ LLVM_DEBUG(if (ToRemove.size()) dbgs()
+ << "[AAPointerInfo] Removing access from old offset bins\n";);
for (auto Key : ToRemove) {
LLVM_DEBUG(dbgs() << " key " << Key << "\n");
@@ -1011,7 +1081,7 @@ struct AAPointerInfoImpl
AAPointerInfoImpl(const IRPosition &IRP, Attributor &A) : BaseTy(IRP) {}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return std::string("PointerInfo ") +
(isValidState() ? (std::string("#") +
std::to_string(OffsetBins.size()) + " bins")
@@ -1032,6 +1102,7 @@ struct AAPointerInfoImpl
bool forallInterferingAccesses(
Attributor &A, const AbstractAttribute &QueryingAA, Instruction &I,
+ bool FindInterferingWrites, bool FindInterferingReads,
function_ref<bool(const Access &, bool)> UserCB, bool &HasBeenWrittenTo,
AA::RangeTy &Range) const override {
HasBeenWrittenTo = false;
@@ -1040,15 +1111,27 @@ struct AAPointerInfoImpl
SmallVector<std::pair<const Access *, bool>, 8> InterferingAccesses;
Function &Scope = *I.getFunction();
- const auto &NoSyncAA = A.getAAFor<AANoSync>(
- QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL);
+ bool IsKnownNoSync;
+ bool IsAssumedNoSync = AA::hasAssumedIRAttr<Attribute::NoSync>(
+ A, &QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL,
+ IsKnownNoSync);
const auto *ExecDomainAA = A.lookupAAFor<AAExecutionDomain>(
- IRPosition::function(Scope), &QueryingAA, DepClassTy::OPTIONAL);
- bool AllInSameNoSyncFn = NoSyncAA.isAssumedNoSync();
+ IRPosition::function(Scope), &QueryingAA, DepClassTy::NONE);
+ bool AllInSameNoSyncFn = IsAssumedNoSync;
bool InstIsExecutedByInitialThreadOnly =
ExecDomainAA && ExecDomainAA->isExecutedByInitialThreadOnly(I);
+
+ // If the function is not ending in aligned barriers, we need the stores to
+ // be in aligned barriers. The load being in one is not sufficient since the
+ // store might be executed by a thread that disappears after, causing the
+ // aligned barrier guarding the load to unblock and the load to read a value
+ // that has no CFG path to the load.
bool InstIsExecutedInAlignedRegion =
- ExecDomainAA && ExecDomainAA->isExecutedInAlignedRegion(A, I);
+ FindInterferingReads && ExecDomainAA &&
+ ExecDomainAA->isExecutedInAlignedRegion(A, I);
+
+ if (InstIsExecutedInAlignedRegion || InstIsExecutedByInitialThreadOnly)
+ A.recordDependence(*ExecDomainAA, QueryingAA, DepClassTy::OPTIONAL);
InformationCache &InfoCache = A.getInfoCache();
bool IsThreadLocalObj =
@@ -1063,14 +1146,25 @@ struct AAPointerInfoImpl
auto CanIgnoreThreadingForInst = [&](const Instruction &I) -> bool {
if (IsThreadLocalObj || AllInSameNoSyncFn)
return true;
- if (!ExecDomainAA)
+ const auto *FnExecDomainAA =
+ I.getFunction() == &Scope
+ ? ExecDomainAA
+ : A.lookupAAFor<AAExecutionDomain>(
+ IRPosition::function(*I.getFunction()), &QueryingAA,
+ DepClassTy::NONE);
+ if (!FnExecDomainAA)
return false;
if (InstIsExecutedInAlignedRegion ||
- ExecDomainAA->isExecutedInAlignedRegion(A, I))
+ (FindInterferingWrites &&
+ FnExecDomainAA->isExecutedInAlignedRegion(A, I))) {
+ A.recordDependence(*FnExecDomainAA, QueryingAA, DepClassTy::OPTIONAL);
return true;
+ }
if (InstIsExecutedByInitialThreadOnly &&
- ExecDomainAA->isExecutedByInitialThreadOnly(I))
+ FnExecDomainAA->isExecutedByInitialThreadOnly(I)) {
+ A.recordDependence(*FnExecDomainAA, QueryingAA, DepClassTy::OPTIONAL);
return true;
+ }
return false;
};
@@ -1084,13 +1178,13 @@ struct AAPointerInfoImpl
};
// TODO: Use inter-procedural reachability and dominance.
- const auto &NoRecurseAA = A.getAAFor<AANoRecurse>(
- QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL);
+ bool IsKnownNoRecurse;
+ AA::hasAssumedIRAttr<Attribute::NoRecurse>(
+ A, this, IRPosition::function(Scope), DepClassTy::OPTIONAL,
+ IsKnownNoRecurse);
- const bool FindInterferingWrites = I.mayReadFromMemory();
- const bool FindInterferingReads = I.mayWriteToMemory();
const bool UseDominanceReasoning =
- FindInterferingWrites && NoRecurseAA.isKnownNoRecurse();
+ FindInterferingWrites && IsKnownNoRecurse;
const DominatorTree *DT =
InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(Scope);
@@ -1098,8 +1192,7 @@ struct AAPointerInfoImpl
// outlive a GPU kernel. This is true for shared, constant, and local
// globals on AMD and NVIDIA GPUs.
auto HasKernelLifetime = [&](Value *V, Module &M) {
- Triple T(M.getTargetTriple());
- if (!(T.isAMDGPU() || T.isNVPTX()))
+ if (!AA::isGPU(M))
return false;
switch (AA::GPUAddressSpace(V->getType()->getPointerAddressSpace())) {
case AA::GPUAddressSpace::Shared:
@@ -1122,9 +1215,10 @@ struct AAPointerInfoImpl
// If the alloca containing function is not recursive the alloca
// must be dead in the callee.
const Function *AIFn = AI->getFunction();
- const auto &NoRecurseAA = A.getAAFor<AANoRecurse>(
- *this, IRPosition::function(*AIFn), DepClassTy::OPTIONAL);
- if (NoRecurseAA.isAssumedNoRecurse()) {
+ bool IsKnownNoRecurse;
+ if (AA::hasAssumedIRAttr<Attribute::NoRecurse>(
+ A, this, IRPosition::function(*AIFn), DepClassTy::OPTIONAL,
+ IsKnownNoRecurse)) {
IsLiveInCalleeCB = [AIFn](const Function &Fn) { return AIFn != &Fn; };
}
} else if (auto *GV = dyn_cast<GlobalValue>(&getAssociatedValue())) {
@@ -1220,7 +1314,7 @@ struct AAPointerInfoImpl
if (!WriteChecked && HasBeenWrittenTo &&
Acc.getRemoteInst()->getFunction() != &Scope) {
- const auto &FnReachabilityAA = A.getAAFor<AAInterFnReachability>(
+ const auto *FnReachabilityAA = A.getAAFor<AAInterFnReachability>(
QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL);
// Without going backwards in the call tree, can we reach the access
@@ -1228,7 +1322,8 @@ struct AAPointerInfoImpl
// itself either.
bool Inserted = ExclusionSet.insert(&I).second;
- if (!FnReachabilityAA.instructionCanReach(
+ if (!FnReachabilityAA ||
+ !FnReachabilityAA->instructionCanReach(
A, *LeastDominatingWriteInst,
*Acc.getRemoteInst()->getFunction(), &ExclusionSet))
WriteChecked = true;
@@ -1337,7 +1432,10 @@ struct AAPointerInfoImpl
O << " --> " << *Acc.getRemoteInst()
<< "\n";
if (!Acc.isWrittenValueYetUndetermined()) {
- if (Acc.getWrittenValue())
+ if (isa_and_nonnull<Function>(Acc.getWrittenValue()))
+ O << " - c: func " << Acc.getWrittenValue()->getName()
+ << "\n";
+ else if (Acc.getWrittenValue())
O << " - c: " << *Acc.getWrittenValue() << "\n";
else
O << " - c: <unknown>\n";
@@ -1450,22 +1548,22 @@ bool AAPointerInfoFloating::collectConstantsForGEP(Attributor &A,
// combination of elements, picked one each from these sets, is separately
// added to the original set of offsets, thus resulting in more offsets.
for (const auto &VI : VariableOffsets) {
- auto &PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>(
+ auto *PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>(
*this, IRPosition::value(*VI.first), DepClassTy::OPTIONAL);
- if (!PotentialConstantsAA.isValidState()) {
+ if (!PotentialConstantsAA || !PotentialConstantsAA->isValidState()) {
UsrOI.setUnknown();
return true;
}
// UndefValue is treated as a zero, which leaves Union as is.
- if (PotentialConstantsAA.undefIsContained())
+ if (PotentialConstantsAA->undefIsContained())
continue;
// We need at least one constant in every set to compute an actual offset.
// Otherwise, we end up pessimizing AAPointerInfo by respecting offsets that
// don't actually exist. In other words, the absence of constant values
// implies that the operation can be assumed dead for now.
- auto &AssumedSet = PotentialConstantsAA.getAssumedSet();
+ auto &AssumedSet = PotentialConstantsAA->getAssumedSet();
if (AssumedSet.empty())
return false;
@@ -1602,16 +1700,6 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
return true;
}
- auto mayBeInCycleHeader = [](const CycleInfo *CI, const Instruction *I) {
- if (!CI)
- return true;
- auto *BB = I->getParent();
- auto *C = CI->getCycle(BB);
- if (!C)
- return false;
- return BB == C->getHeader();
- };
-
// Check if the PHI operand is not dependent on the PHI itself. Every
// recurrence is a cyclic net of PHIs in the data flow, and has an
// equivalent Cycle in the control flow. One of those PHIs must be in the
@@ -1619,7 +1707,7 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
// Cycles reported by CycleInfo. It is sufficient to check the PHIs in
// every Cycle header; if such a node is marked unknown, this will
// eventually propagate through the whole net of PHIs in the recurrence.
- if (mayBeInCycleHeader(CI, cast<Instruction>(Usr))) {
+ if (mayBeInCycle(CI, cast<Instruction>(Usr), /* HeaderOnly */ true)) {
auto BaseOI = It->getSecond();
BaseOI.addToAll(Offset.getZExtValue());
if (IsFirstPHIUser || BaseOI == UsrOI) {
@@ -1681,6 +1769,8 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
return false;
} else {
auto PredIt = pred_begin(IntrBB);
+ if (PredIt == pred_end(IntrBB))
+ return false;
if ((*PredIt) != BB)
return false;
if (++PredIt != pred_end(IntrBB))
@@ -1780,11 +1870,14 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
return true;
if (CB->isArgOperand(&U)) {
unsigned ArgNo = CB->getArgOperandNo(&U);
- const auto &CSArgPI = A.getAAFor<AAPointerInfo>(
+ const auto *CSArgPI = A.getAAFor<AAPointerInfo>(
*this, IRPosition::callsite_argument(*CB, ArgNo),
DepClassTy::REQUIRED);
- Changed = translateAndAddState(A, CSArgPI, OffsetInfoMap[CurPtr], *CB) |
- Changed;
+ if (!CSArgPI)
+ return false;
+ Changed =
+ translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB) |
+ Changed;
return isValidState();
}
LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB
@@ -1845,13 +1938,6 @@ struct AAPointerInfoArgument final : AAPointerInfoFloating {
AAPointerInfoArgument(const IRPosition &IRP, Attributor &A)
: AAPointerInfoFloating(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AAPointerInfoFloating::initialize(A);
- if (getAnchorScope()->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
AAPointerInfoImpl::trackPointerInfoStatistics(getIRPosition());
@@ -1900,19 +1986,18 @@ struct AAPointerInfoCallSiteArgument final : AAPointerInfoFloating {
Argument *Arg = getAssociatedArgument();
if (Arg) {
const IRPosition &ArgPos = IRPosition::argument(*Arg);
- auto &ArgAA =
+ auto *ArgAA =
A.getAAFor<AAPointerInfo>(*this, ArgPos, DepClassTy::REQUIRED);
- if (ArgAA.getState().isValidState())
- return translateAndAddStateFromCallee(A, ArgAA,
+ if (ArgAA && ArgAA->getState().isValidState())
+ return translateAndAddStateFromCallee(A, *ArgAA,
*cast<CallBase>(getCtxI()));
if (!Arg->getParent()->isDeclaration())
return indicatePessimisticFixpoint();
}
- const auto &NoCaptureAA =
- A.getAAFor<AANoCapture>(*this, getIRPosition(), DepClassTy::OPTIONAL);
-
- if (!NoCaptureAA.isAssumedNoCapture())
+ bool IsKnownNoCapture;
+ if (!AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnownNoCapture))
return indicatePessimisticFixpoint();
bool IsKnown = false;
@@ -1948,7 +2033,15 @@ namespace {
struct AANoUnwindImpl : AANoUnwind {
AANoUnwindImpl(const IRPosition &IRP, Attributor &A) : AANoUnwind(IRP, A) {}
- const std::string getAsStr() const override {
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+ A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
+ }
+
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "nounwind" : "may-unwind";
}
@@ -1960,13 +2053,14 @@ struct AANoUnwindImpl : AANoUnwind {
(unsigned)Instruction::CatchSwitch, (unsigned)Instruction::Resume};
auto CheckForNoUnwind = [&](Instruction &I) {
- if (!I.mayThrow())
+ if (!I.mayThrow(/* IncludePhaseOneUnwind */ true))
return true;
if (const auto *CB = dyn_cast<CallBase>(&I)) {
- const auto &NoUnwindAA = A.getAAFor<AANoUnwind>(
- *this, IRPosition::callsite_function(*CB), DepClassTy::REQUIRED);
- return NoUnwindAA.isAssumedNoUnwind();
+ bool IsKnownNoUnwind;
+ return AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+ A, this, IRPosition::callsite_function(*CB), DepClassTy::REQUIRED,
+ IsKnownNoUnwind);
}
return false;
};
@@ -1993,14 +2087,6 @@ struct AANoUnwindCallSite final : AANoUnwindImpl {
AANoUnwindCallSite(const IRPosition &IRP, Attributor &A)
: AANoUnwindImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoUnwindImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -2009,263 +2095,15 @@ struct AANoUnwindCallSite final : AANoUnwindImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AANoUnwind>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
- }
-
- /// See AbstractAttribute::trackStatistics()
- void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); }
-};
-} // namespace
-
-/// --------------------- Function Return Values -------------------------------
-
-namespace {
-/// "Attribute" that collects all potential returned values and the return
-/// instructions that they arise from.
-///
-/// If there is a unique returned value R, the manifest method will:
-/// - mark R with the "returned" attribute, if R is an argument.
-class AAReturnedValuesImpl : public AAReturnedValues, public AbstractState {
-
- /// Mapping of values potentially returned by the associated function to the
- /// return instructions that might return them.
- MapVector<Value *, SmallSetVector<ReturnInst *, 4>> ReturnedValues;
-
- /// State flags
- ///
- ///{
- bool IsFixed = false;
- bool IsValidState = true;
- ///}
-
-public:
- AAReturnedValuesImpl(const IRPosition &IRP, Attributor &A)
- : AAReturnedValues(IRP, A) {}
-
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- // Reset the state.
- IsFixed = false;
- IsValidState = true;
- ReturnedValues.clear();
-
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration()) {
- indicatePessimisticFixpoint();
- return;
- }
- assert(!F->getReturnType()->isVoidTy() &&
- "Did not expect a void return type!");
-
- // The map from instruction opcodes to those instructions in the function.
- auto &OpcodeInstMap = A.getInfoCache().getOpcodeInstMapForFunction(*F);
-
- // Look through all arguments, if one is marked as returned we are done.
- for (Argument &Arg : F->args()) {
- if (Arg.hasReturnedAttr()) {
- auto &ReturnInstSet = ReturnedValues[&Arg];
- if (auto *Insts = OpcodeInstMap.lookup(Instruction::Ret))
- for (Instruction *RI : *Insts)
- ReturnInstSet.insert(cast<ReturnInst>(RI));
-
- indicateOptimisticFixpoint();
- return;
- }
- }
-
- if (!A.isFunctionIPOAmendable(*F))
- indicatePessimisticFixpoint();
- }
-
- /// See AbstractAttribute::manifest(...).
- ChangeStatus manifest(Attributor &A) override;
-
- /// See AbstractAttribute::getState(...).
- AbstractState &getState() override { return *this; }
-
- /// See AbstractAttribute::getState(...).
- const AbstractState &getState() const override { return *this; }
-
- /// See AbstractAttribute::updateImpl(Attributor &A).
- ChangeStatus updateImpl(Attributor &A) override;
-
- llvm::iterator_range<iterator> returned_values() override {
- return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end());
- }
-
- llvm::iterator_range<const_iterator> returned_values() const override {
- return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end());
- }
-
- /// Return the number of potential return values, -1 if unknown.
- size_t getNumReturnValues() const override {
- return isValidState() ? ReturnedValues.size() : -1;
- }
-
- /// Return an assumed unique return value if a single candidate is found. If
- /// there cannot be one, return a nullptr. If it is not clear yet, return
- /// std::nullopt.
- std::optional<Value *> getAssumedUniqueReturnValue(Attributor &A) const;
-
- /// See AbstractState::checkForAllReturnedValues(...).
- bool checkForAllReturnedValuesAndReturnInsts(
- function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> Pred)
- const override;
-
- /// Pretty print the attribute similar to the IR representation.
- const std::string getAsStr() const override;
-
- /// See AbstractState::isAtFixpoint().
- bool isAtFixpoint() const override { return IsFixed; }
-
- /// See AbstractState::isValidState().
- bool isValidState() const override { return IsValidState; }
-
- /// See AbstractState::indicateOptimisticFixpoint(...).
- ChangeStatus indicateOptimisticFixpoint() override {
- IsFixed = true;
- return ChangeStatus::UNCHANGED;
- }
-
- ChangeStatus indicatePessimisticFixpoint() override {
- IsFixed = true;
- IsValidState = false;
- return ChangeStatus::CHANGED;
- }
-};
-
-ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) {
- ChangeStatus Changed = ChangeStatus::UNCHANGED;
-
- // Bookkeeping.
- assert(isValidState());
- STATS_DECLTRACK(KnownReturnValues, FunctionReturn,
- "Number of function with known return values");
-
- // Check if we have an assumed unique return value that we could manifest.
- std::optional<Value *> UniqueRV = getAssumedUniqueReturnValue(A);
-
- if (!UniqueRV || !*UniqueRV)
- return Changed;
-
- // Bookkeeping.
- STATS_DECLTRACK(UniqueReturnValue, FunctionReturn,
- "Number of function with unique return");
- // If the assumed unique return value is an argument, annotate it.
- if (auto *UniqueRVArg = dyn_cast<Argument>(*UniqueRV)) {
- if (UniqueRVArg->getType()->canLosslesslyBitCastTo(
- getAssociatedFunction()->getReturnType())) {
- getIRPosition() = IRPosition::argument(*UniqueRVArg);
- Changed = IRAttribute::manifest(A);
- }
- }
- return Changed;
-}
-
-const std::string AAReturnedValuesImpl::getAsStr() const {
- return (isAtFixpoint() ? "returns(#" : "may-return(#") +
- (isValidState() ? std::to_string(getNumReturnValues()) : "?") + ")";
-}
-
-std::optional<Value *>
-AAReturnedValuesImpl::getAssumedUniqueReturnValue(Attributor &A) const {
- // If checkForAllReturnedValues provides a unique value, ignoring potential
- // undef values that can also be present, it is assumed to be the actual
- // return value and forwarded to the caller of this method. If there are
- // multiple, a nullptr is returned indicating there cannot be a unique
- // returned value.
- std::optional<Value *> UniqueRV;
- Type *Ty = getAssociatedFunction()->getReturnType();
-
- auto Pred = [&](Value &RV) -> bool {
- UniqueRV = AA::combineOptionalValuesInAAValueLatice(UniqueRV, &RV, Ty);
- return UniqueRV != std::optional<Value *>(nullptr);
- };
-
- if (!A.checkForAllReturnedValues(Pred, *this))
- UniqueRV = nullptr;
-
- return UniqueRV;
-}
-
-bool AAReturnedValuesImpl::checkForAllReturnedValuesAndReturnInsts(
- function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> Pred)
- const {
- if (!isValidState())
- return false;
-
- // Check all returned values but ignore call sites as long as we have not
- // encountered an overdefined one during an update.
- for (const auto &It : ReturnedValues) {
- Value *RV = It.first;
- if (!Pred(*RV, It.second))
- return false;
- }
-
- return true;
-}
-
-ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
- ChangeStatus Changed = ChangeStatus::UNCHANGED;
-
- SmallVector<AA::ValueAndContext> Values;
- bool UsedAssumedInformation = false;
- auto ReturnInstCB = [&](Instruction &I) {
- ReturnInst &Ret = cast<ReturnInst>(I);
- Values.clear();
- if (!A.getAssumedSimplifiedValues(IRPosition::value(*Ret.getReturnValue()),
- *this, Values, AA::Intraprocedural,
- UsedAssumedInformation))
- Values.push_back({*Ret.getReturnValue(), Ret});
-
- for (auto &VAC : Values) {
- assert(AA::isValidInScope(*VAC.getValue(), Ret.getFunction()) &&
- "Assumed returned value should be valid in function scope!");
- if (ReturnedValues[VAC.getValue()].insert(&Ret))
- Changed = ChangeStatus::CHANGED;
- }
- return true;
- };
-
- // Discover returned values from all live returned instructions in the
- // associated function.
- if (!A.checkForAllInstructions(ReturnInstCB, *this, {Instruction::Ret},
- UsedAssumedInformation))
- return indicatePessimisticFixpoint();
- return Changed;
-}
-
-struct AAReturnedValuesFunction final : public AAReturnedValuesImpl {
- AAReturnedValuesFunction(const IRPosition &IRP, Attributor &A)
- : AAReturnedValuesImpl(IRP, A) {}
-
- /// See AbstractAttribute::trackStatistics()
- void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(returned) }
-};
-
-/// Returned values information for a call sites.
-struct AAReturnedValuesCallSite final : AAReturnedValuesImpl {
- AAReturnedValuesCallSite(const IRPosition &IRP, Attributor &A)
- : AAReturnedValuesImpl(IRP, A) {}
-
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites instead of
- // redirecting requests to the callee.
- llvm_unreachable("Abstract attributes for returned values are not "
- "supported for call sites yet!");
- }
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
+ bool IsKnownNoUnwind;
+ if (AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+ A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoUnwind))
+ return ChangeStatus::UNCHANGED;
return indicatePessimisticFixpoint();
}
/// See AbstractAttribute::trackStatistics()
- void trackStatistics() const override {}
+ void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); }
};
} // namespace
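The AANoUnwind hunks above boil down to one rule: a function can be marked nounwind when none of its instructions may throw, where a call instruction inherits the answer from its callee and the call-site attribute simply defers to the callee function. Below is a toy restatement of that rule, not part of the patch, using an invented module representation.

#include <iostream>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

struct SafeOp {};                            // arithmetic, loads, stores, ...
struct MayThrowOp {};                        // e.g. an unguarded `throw`
struct CallOp { std::string Callee; };
using Op = std::variant<SafeOp, MayThrowOp, CallOp>;
using Module = std::unordered_map<std::string, std::vector<Op>>;

// NOTE: a single pass over an optimistic cache is only a sketch; the real
// Attributor iterates to a fixpoint so optimistic assumptions get revisited.
bool isNoUnwind(const Module &M, const std::string &F,
                std::unordered_map<std::string, bool> &Cache) {
  auto It = Cache.find(F);
  if (It != Cache.end())
    return It->second;
  Cache[F] = true;                           // optimistic, like the Attributor
  bool Result = true;
  for (const Op &O : M.at(F)) {
    if (std::holds_alternative<SafeOp>(O))
      continue;                              // cannot throw, keep going
    if (const auto *C = std::get_if<CallOp>(&O))
      Result = isNoUnwind(M, C->Callee, Cache);  // defer to the callee
    else
      Result = false;                        // the instruction itself may throw
    if (!Result)
      break;
  }
  return Cache[F] = Result;
}

int main() {
  Module M = {{"leaf", {SafeOp{}}},
              {"caller", {SafeOp{}, CallOp{"leaf"}}},
              {"thrower", {MayThrowOp{}}}};
  std::unordered_map<std::string, bool> Cache;
  std::cout << isNoUnwind(M, "caller", Cache) << ' '
            << isNoUnwind(M, "thrower", Cache) << '\n';  // prints "1 0"
}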
@@ -2334,7 +2172,15 @@ namespace {
struct AANoSyncImpl : AANoSync {
AANoSyncImpl(const IRPosition &IRP, Attributor &A) : AANoSync(IRP, A) {}
- const std::string getAsStr() const override {
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::NoSync>(A, nullptr, getIRPosition(),
+ DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
+ }
+
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "nosync" : "may-sync";
}
@@ -2381,14 +2227,6 @@ struct AANoSyncCallSite final : AANoSyncImpl {
AANoSyncCallSite(const IRPosition &IRP, Attributor &A)
: AANoSyncImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoSyncImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -2397,8 +2235,11 @@ struct AANoSyncCallSite final : AANoSyncImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AANoSync>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
+    bool IsKnownNoSync;
+ if (AA::hasAssumedIRAttr<Attribute::NoSync>(
+            A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoSync))
+ return ChangeStatus::UNCHANGED;
+ return indicatePessimisticFixpoint();
}
/// See AbstractAttribute::trackStatistics()
@@ -2412,16 +2253,21 @@ namespace {
struct AANoFreeImpl : public AANoFree {
AANoFreeImpl(const IRPosition &IRP, Attributor &A) : AANoFree(IRP, A) {}
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::NoFree>(A, nullptr, getIRPosition(),
+ DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
+ }
+
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
auto CheckForNoFree = [&](Instruction &I) {
- const auto &CB = cast<CallBase>(I);
- if (CB.hasFnAttr(Attribute::NoFree))
- return true;
-
- const auto &NoFreeAA = A.getAAFor<AANoFree>(
- *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
- return NoFreeAA.isAssumedNoFree();
+ bool IsKnown;
+ return AA::hasAssumedIRAttr<Attribute::NoFree>(
+ A, this, IRPosition::callsite_function(cast<CallBase>(I)),
+ DepClassTy::REQUIRED, IsKnown);
};
bool UsedAssumedInformation = false;
@@ -2432,7 +2278,7 @@ struct AANoFreeImpl : public AANoFree {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "nofree" : "may-free";
}
};
@@ -2450,14 +2296,6 @@ struct AANoFreeCallSite final : AANoFreeImpl {
AANoFreeCallSite(const IRPosition &IRP, Attributor &A)
: AANoFreeImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoFreeImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -2466,8 +2304,11 @@ struct AANoFreeCallSite final : AANoFreeImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AANoFree>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
+ bool IsKnown;
+ if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, FnPos,
+ DepClassTy::REQUIRED, IsKnown))
+ return ChangeStatus::UNCHANGED;
+ return indicatePessimisticFixpoint();
}
/// See AbstractAttribute::trackStatistics()
@@ -2486,9 +2327,10 @@ struct AANoFreeFloating : AANoFreeImpl {
ChangeStatus updateImpl(Attributor &A) override {
const IRPosition &IRP = getIRPosition();
- const auto &NoFreeAA = A.getAAFor<AANoFree>(
- *this, IRPosition::function_scope(IRP), DepClassTy::OPTIONAL);
- if (NoFreeAA.isAssumedNoFree())
+ bool IsKnown;
+ if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this,
+ IRPosition::function_scope(IRP),
+ DepClassTy::OPTIONAL, IsKnown))
return ChangeStatus::UNCHANGED;
Value &AssociatedValue = getIRPosition().getAssociatedValue();
@@ -2501,10 +2343,10 @@ struct AANoFreeFloating : AANoFreeImpl {
return true;
unsigned ArgNo = CB->getArgOperandNo(&U);
- const auto &NoFreeArg = A.getAAFor<AANoFree>(
- *this, IRPosition::callsite_argument(*CB, ArgNo),
- DepClassTy::REQUIRED);
- return NoFreeArg.isAssumedNoFree();
+ bool IsKnown;
+ return AA::hasAssumedIRAttr<Attribute::NoFree>(
+ A, this, IRPosition::callsite_argument(*CB, ArgNo),
+ DepClassTy::REQUIRED, IsKnown);
}
if (isa<GetElementPtrInst>(UserI) || isa<BitCastInst>(UserI) ||
@@ -2550,8 +2392,11 @@ struct AANoFreeCallSiteArgument final : AANoFreeFloating {
if (!Arg)
return indicatePessimisticFixpoint();
const IRPosition &ArgPos = IRPosition::argument(*Arg);
- auto &ArgAA = A.getAAFor<AANoFree>(*this, ArgPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), ArgAA.getState());
+ bool IsKnown;
+ if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, ArgPos,
+ DepClassTy::REQUIRED, IsKnown))
+ return ChangeStatus::UNCHANGED;
+ return indicatePessimisticFixpoint();
}
/// See AbstractAttribute::trackStatistics()
@@ -2593,6 +2438,39 @@ struct AANoFreeCallSiteReturned final : AANoFreeFloating {
} // namespace
/// ------------------------ NonNull Argument Attribute ------------------------
+
+bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP,
+ Attribute::AttrKind ImpliedAttributeKind,
+ bool IgnoreSubsumingPositions) {
+ SmallVector<Attribute::AttrKind, 2> AttrKinds;
+ AttrKinds.push_back(Attribute::NonNull);
+ if (!NullPointerIsDefined(IRP.getAnchorScope(),
+ IRP.getAssociatedType()->getPointerAddressSpace()))
+ AttrKinds.push_back(Attribute::Dereferenceable);
+ if (A.hasAttr(IRP, AttrKinds, IgnoreSubsumingPositions, Attribute::NonNull))
+ return true;
+
+ if (IRP.getPositionKind() == IRP_RETURNED)
+ return false;
+
+ DominatorTree *DT = nullptr;
+ AssumptionCache *AC = nullptr;
+ InformationCache &InfoCache = A.getInfoCache();
+ if (const Function *Fn = IRP.getAnchorScope()) {
+ if (!Fn->isDeclaration()) {
+ DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*Fn);
+ AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*Fn);
+ }
+ }
+
+ if (!isKnownNonZero(&IRP.getAssociatedValue(), A.getDataLayout(), 0, AC,
+ IRP.getCtxI(), DT))
+ return false;
+ A.manifestAttrs(IRP, {Attribute::get(IRP.getAnchorValue().getContext(),
+ Attribute::NonNull)});
+ return true;
+}
+
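The new AANonNull::isImpliedByIR above follows a check-derive-manifest pattern: first look for an existing nonnull attribute (or dereferenceable, when null is not a defined address in the pointer's address space), otherwise try to prove the fact with isKnownNonZero and write the attribute back so later queries get it for free. A stand-alone toy version of that flow follows; it is not part of the patch, the names and data structures are invented, and the address-space guard and dominator/assumption caches are omitted.

#include <iostream>
#include <set>
#include <string>
#include <utility>

using AttrTable = std::set<std::pair<std::string, std::string>>;

struct Val { std::string Name; bool IsAddressOfObject; };

bool isImpliedNonNull(AttrTable &Attrs, const Val &V) {
  if (Attrs.count({V.Name, "nonnull"}) ||
      Attrs.count({V.Name, "dereferenceable"}))
    return true;                              // already annotated
  if (!V.IsAddressOfObject)                   // stand-in for isKnownNonZero()
    return false;
  Attrs.insert({V.Name, "nonnull"});          // manifest so later queries are free
  return true;
}

int main() {
  AttrTable Attrs = {{"arg0", "dereferenceable"}};
  Val Arg0{"arg0", false}, Local{"localptr", true}, Unknown{"p", false};
  std::cout << isImpliedNonNull(Attrs, Arg0) << isImpliedNonNull(Attrs, Local)
            << isImpliedNonNull(Attrs, Unknown) << '\n';     // prints "110"
  std::cout << Attrs.count({"localptr", "nonnull"}) << '\n'; // now cached: 1
}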
namespace {
static int64_t getKnownNonNullAndDerefBytesForUse(
Attributor &A, const AbstractAttribute &QueryingAA, Value &AssociatedValue,
@@ -2641,10 +2519,13 @@ static int64_t getKnownNonNullAndDerefBytesForUse(
IRPosition IRP = IRPosition::callsite_argument(*CB, ArgNo);
// As long as we only use known information there is no need to track
// dependences here.
- auto &DerefAA =
+ bool IsKnownNonNull;
+ AA::hasAssumedIRAttr<Attribute::NonNull>(A, &QueryingAA, IRP,
+ DepClassTy::NONE, IsKnownNonNull);
+ IsNonNull |= IsKnownNonNull;
+ auto *DerefAA =
A.getAAFor<AADereferenceable>(QueryingAA, IRP, DepClassTy::NONE);
- IsNonNull |= DerefAA.isKnownNonNull();
- return DerefAA.getKnownDereferenceableBytes();
+ return DerefAA ? DerefAA->getKnownDereferenceableBytes() : 0;
}
std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I);
@@ -2673,43 +2554,16 @@ static int64_t getKnownNonNullAndDerefBytesForUse(
}
struct AANonNullImpl : AANonNull {
- AANonNullImpl(const IRPosition &IRP, Attributor &A)
- : AANonNull(IRP, A),
- NullIsDefined(NullPointerIsDefined(
- getAnchorScope(),
- getAssociatedValue().getType()->getPointerAddressSpace())) {}
+ AANonNullImpl(const IRPosition &IRP, Attributor &A) : AANonNull(IRP, A) {}
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
Value &V = *getAssociatedValue().stripPointerCasts();
- if (!NullIsDefined &&
- hasAttr({Attribute::NonNull, Attribute::Dereferenceable},
- /* IgnoreSubsumingPositions */ false, &A)) {
- indicateOptimisticFixpoint();
- return;
- }
-
if (isa<ConstantPointerNull>(V)) {
indicatePessimisticFixpoint();
return;
}
- AANonNull::initialize(A);
-
- bool CanBeNull, CanBeFreed;
- if (V.getPointerDereferenceableBytes(A.getDataLayout(), CanBeNull,
- CanBeFreed)) {
- if (!CanBeNull) {
- indicateOptimisticFixpoint();
- return;
- }
- }
-
- if (isa<GlobalValue>(V)) {
- indicatePessimisticFixpoint();
- return;
- }
-
if (Instruction *CtxI = getCtxI())
followUsesInMBEC(*this, A, getState(), *CtxI);
}
@@ -2726,13 +2580,9 @@ struct AANonNullImpl : AANonNull {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "nonnull" : "may-null";
}
-
- /// Flag to determine if the underlying value can be null and still allow
- /// valid accesses.
- const bool NullIsDefined;
};
/// NonNull attribute for a floating value.
@@ -2742,48 +2592,39 @@ struct AANonNullFloating : public AANonNullImpl {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
- const DataLayout &DL = A.getDataLayout();
+ auto CheckIRP = [&](const IRPosition &IRP) {
+ bool IsKnownNonNull;
+ return AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, *this, IRP, DepClassTy::OPTIONAL, IsKnownNonNull);
+ };
bool Stripped;
bool UsedAssumedInformation = false;
+ Value *AssociatedValue = &getAssociatedValue();
SmallVector<AA::ValueAndContext> Values;
if (!A.getAssumedSimplifiedValues(getIRPosition(), *this, Values,
- AA::AnyScope, UsedAssumedInformation)) {
- Values.push_back({getAssociatedValue(), getCtxI()});
+ AA::AnyScope, UsedAssumedInformation))
Stripped = false;
- } else {
- Stripped = Values.size() != 1 ||
- Values.front().getValue() != &getAssociatedValue();
- }
-
- DominatorTree *DT = nullptr;
- AssumptionCache *AC = nullptr;
- InformationCache &InfoCache = A.getInfoCache();
- if (const Function *Fn = getAnchorScope()) {
- DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*Fn);
- AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*Fn);
+ else
+ Stripped =
+ Values.size() != 1 || Values.front().getValue() != AssociatedValue;
+
+ if (!Stripped) {
+ // If we haven't stripped anything we might still be able to use a
+ // different AA, but only if the IRP changes. Effectively when we
+ // interpret this not as a call site value but as a floating/argument
+ // value.
+ const IRPosition AVIRP = IRPosition::value(*AssociatedValue);
+ if (AVIRP == getIRPosition() || !CheckIRP(AVIRP))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
}
- AANonNull::StateType T;
- auto VisitValueCB = [&](Value &V, const Instruction *CtxI) -> bool {
- const auto &AA = A.getAAFor<AANonNull>(*this, IRPosition::value(V),
- DepClassTy::REQUIRED);
- if (!Stripped && this == &AA) {
- if (!isKnownNonZero(&V, DL, 0, AC, CtxI, DT))
- T.indicatePessimisticFixpoint();
- } else {
- // Use abstract attribute information.
- const AANonNull::StateType &NS = AA.getState();
- T ^= NS;
- }
- return T.isValidState();
- };
-
for (const auto &VAC : Values)
- if (!VisitValueCB(*VAC.getValue(), VAC.getCtxI()))
+ if (!CheckIRP(IRPosition::value(*VAC.getValue())))
return indicatePessimisticFixpoint();
- return clampStateAndIndicateChange(getState(), T);
+ return ChangeStatus::UNCHANGED;
}
/// See AbstractAttribute::trackStatistics()
@@ -2792,12 +2633,14 @@ struct AANonNullFloating : public AANonNullImpl {
/// NonNull attribute for function return value.
struct AANonNullReturned final
- : AAReturnedFromReturnedValues<AANonNull, AANonNull> {
+ : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType,
+ false, AANonNull::IRAttributeKind> {
AANonNullReturned(const IRPosition &IRP, Attributor &A)
- : AAReturnedFromReturnedValues<AANonNull, AANonNull>(IRP, A) {}
+ : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType,
+ false, Attribute::NonNull>(IRP, A) {}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "nonnull" : "may-null";
}
@@ -2807,9 +2650,13 @@ struct AANonNullReturned final
/// NonNull attribute for function argument.
struct AANonNullArgument final
- : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl> {
+ : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl,
+ AANonNull::StateType, false,
+ AANonNull::IRAttributeKind> {
AANonNullArgument(const IRPosition &IRP, Attributor &A)
- : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl>(IRP, A) {}
+ : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl,
+ AANonNull::StateType, false,
+ AANonNull::IRAttributeKind>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nonnull) }
@@ -2825,23 +2672,118 @@ struct AANonNullCallSiteArgument final : AANonNullFloating {
/// NonNull attribute for a call site return position.
struct AANonNullCallSiteReturned final
- : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl> {
+ : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl,
+ AANonNull::StateType, false,
+ AANonNull::IRAttributeKind> {
AANonNullCallSiteReturned(const IRPosition &IRP, Attributor &A)
- : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl>(IRP, A) {}
+ : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl,
+ AANonNull::StateType, false,
+ AANonNull::IRAttributeKind>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) }
};
} // namespace
+/// ------------------------ Must-Progress Attributes --------------------------
+namespace {
+struct AAMustProgressImpl : public AAMustProgress {
+ AAMustProgressImpl(const IRPosition &IRP, Attributor &A)
+ : AAMustProgress(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::MustProgress>(
+ A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
+ }
+
+ /// See AbstractAttribute::getAsStr()
+ const std::string getAsStr(Attributor *A) const override {
+ return getAssumed() ? "mustprogress" : "may-not-progress";
+ }
+};
+
+struct AAMustProgressFunction final : AAMustProgressImpl {
+ AAMustProgressFunction(const IRPosition &IRP, Attributor &A)
+ : AAMustProgressImpl(IRP, A) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ bool IsKnown;
+ if (AA::hasAssumedIRAttr<Attribute::WillReturn>(
+ A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnown)) {
+ if (IsKnown)
+ return indicateOptimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+
+ auto CheckForMustProgress = [&](AbstractCallSite ACS) {
+ IRPosition IPos = IRPosition::callsite_function(*ACS.getInstruction());
+ bool IsKnownMustProgress;
+ return AA::hasAssumedIRAttr<Attribute::MustProgress>(
+ A, this, IPos, DepClassTy::REQUIRED, IsKnownMustProgress,
+ /* IgnoreSubsumingPositions */ true);
+ };
+
+ bool AllCallSitesKnown = true;
+ if (!A.checkForAllCallSites(CheckForMustProgress, *this,
+ /* RequireAllCallSites */ true,
+ AllCallSitesKnown))
+ return indicatePessimisticFixpoint();
+
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FN_ATTR(mustprogress)
+ }
+};
+
+/// MustProgress attribute deduction for a call sites.
+struct AAMustProgressCallSite final : AAMustProgressImpl {
+ AAMustProgressCallSite(const IRPosition &IRP, Attributor &A)
+ : AAMustProgressImpl(IRP, A) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // TODO: Once we have call site specific value information we can provide
+ // call site specific liveness information and then it makes
+ // sense to specialize attributes for call sites arguments instead of
+ // redirecting requests to the callee argument.
+ const IRPosition &FnPos = IRPosition::function(*getAnchorScope());
+ bool IsKnownMustProgress;
+ if (!AA::hasAssumedIRAttr<Attribute::MustProgress>(
+ A, this, FnPos, DepClassTy::REQUIRED, IsKnownMustProgress))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CS_ATTR(mustprogress);
+ }
+};
+} // namespace
+
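Nearly every hasAssumedIRAttr call introduced by this patch carries an IsKnown out-parameter: the return value says whether the attribute is at least optimistically assumed, while IsKnown reports whether it is already proven. AAMustProgressFunction::updateImpl above reacts differently to the two cases, and the small sketch below, which is not part of the patch and uses invented names, isolates just that control flow.

#include <iostream>
#include <optional>

enum class Fact { No, Assumed, Known };
enum class Status { Unchanged, OptimisticFixpoint };

bool query(Fact F, bool &IsKnown) {       // mirrors hasAssumedIRAttr's shape
  IsKnown = (F == Fact::Known);
  return F != Fact::No;                   // true means "assumed or better"
}

// Returns a result if the willreturn query settles the update, otherwise
// std::nullopt so the caller can fall back to the call-site check, which is
// the control flow of AAMustProgressFunction::updateImpl above.
std::optional<Status> tryWillReturnShortcut(Fact WillReturn) {
  bool IsKnown;
  if (!query(WillReturn, IsKnown))
    return std::nullopt;                  // not even assumed: try other routes
  return IsKnown ? Status::OptimisticFixpoint  // proven: stop iterating
                 : Status::Unchanged;          // assumed: keep dependence alive
}

int main() {
  std::cout << tryWillReturnShortcut(Fact::Known).has_value() << ' '
            << tryWillReturnShortcut(Fact::No).has_value() << '\n';  // "1 0"
}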
/// ------------------------ No-Recurse Attributes ----------------------------
namespace {
struct AANoRecurseImpl : public AANoRecurse {
AANoRecurseImpl(const IRPosition &IRP, Attributor &A) : AANoRecurse(IRP, A) {}
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::NoRecurse>(
+ A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
+ }
+
/// See AbstractAttribute::getAsStr()
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "norecurse" : "may-recurse";
}
};
@@ -2855,10 +2797,13 @@ struct AANoRecurseFunction final : AANoRecurseImpl {
// If all live call sites are known to be no-recurse, we are as well.
auto CallSitePred = [&](AbstractCallSite ACS) {
- const auto &NoRecurseAA = A.getAAFor<AANoRecurse>(
- *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
- DepClassTy::NONE);
- return NoRecurseAA.isKnownNoRecurse();
+ bool IsKnownNoRecurse;
+ if (!AA::hasAssumedIRAttr<Attribute::NoRecurse>(
+ A, this,
+ IRPosition::function(*ACS.getInstruction()->getFunction()),
+ DepClassTy::NONE, IsKnownNoRecurse))
+ return false;
+ return IsKnownNoRecurse;
};
bool UsedAssumedInformation = false;
if (A.checkForAllCallSites(CallSitePred, *this, true,
@@ -2873,10 +2818,10 @@ struct AANoRecurseFunction final : AANoRecurseImpl {
return ChangeStatus::UNCHANGED;
}
- const AAInterFnReachability &EdgeReachability =
+ const AAInterFnReachability *EdgeReachability =
A.getAAFor<AAInterFnReachability>(*this, getIRPosition(),
DepClassTy::REQUIRED);
- if (EdgeReachability.canReach(A, *getAnchorScope()))
+ if (EdgeReachability && EdgeReachability->canReach(A, *getAnchorScope()))
return indicatePessimisticFixpoint();
return ChangeStatus::UNCHANGED;
}
@@ -2889,14 +2834,6 @@ struct AANoRecurseCallSite final : AANoRecurseImpl {
AANoRecurseCallSite(const IRPosition &IRP, Attributor &A)
: AANoRecurseImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoRecurseImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -2905,8 +2842,11 @@ struct AANoRecurseCallSite final : AANoRecurseImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AANoRecurse>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
+ bool IsKnownNoRecurse;
+ if (!AA::hasAssumedIRAttr<Attribute::NoRecurse>(
+ A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoRecurse))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
}
/// See AbstractAttribute::trackStatistics()
@@ -2914,6 +2854,62 @@ struct AANoRecurseCallSite final : AANoRecurseImpl {
};
} // namespace
+/// ------------------------ No-Convergent Attribute --------------------------
+
+namespace {
+struct AANonConvergentImpl : public AANonConvergent {
+ AANonConvergentImpl(const IRPosition &IRP, Attributor &A)
+ : AANonConvergent(IRP, A) {}
+
+ /// See AbstractAttribute::getAsStr()
+ const std::string getAsStr(Attributor *A) const override {
+ return getAssumed() ? "non-convergent" : "may-be-convergent";
+ }
+};
+
+struct AANonConvergentFunction final : AANonConvergentImpl {
+ AANonConvergentFunction(const IRPosition &IRP, Attributor &A)
+ : AANonConvergentImpl(IRP, A) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ // If all function calls are known to not be convergent, we are not
+ // convergent.
+ auto CalleeIsNotConvergent = [&](Instruction &Inst) {
+ CallBase &CB = cast<CallBase>(Inst);
+ auto *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
+ if (!Callee || Callee->isIntrinsic()) {
+ return false;
+ }
+ if (Callee->isDeclaration()) {
+ return !Callee->hasFnAttribute(Attribute::Convergent);
+ }
+ const auto *ConvergentAA = A.getAAFor<AANonConvergent>(
+ *this, IRPosition::function(*Callee), DepClassTy::REQUIRED);
+ return ConvergentAA && ConvergentAA->isAssumedNotConvergent();
+ };
+
+ bool UsedAssumedInformation = false;
+ if (!A.checkForAllCallLikeInstructions(CalleeIsNotConvergent, *this,
+ UsedAssumedInformation)) {
+ return indicatePessimisticFixpoint();
+ }
+ return ChangeStatus::UNCHANGED;
+ }
+
+ ChangeStatus manifest(Attributor &A) override {
+ if (isKnownNotConvergent() &&
+ A.hasAttr(getIRPosition(), Attribute::Convergent)) {
+ A.removeAttrs(getIRPosition(), {Attribute::Convergent});
+ return ChangeStatus::CHANGED;
+ }
+ return ChangeStatus::UNCHANGED;
+ }
+
+ void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(convergent) }
+};
+} // namespace
+
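The new AANonConvergent deduction above encodes a simple rule: a function's convergent attribute can be dropped once every call in it targets a known, non-intrinsic callee that is itself not convergent, while indirect calls and intrinsics stay conservative. Below is a toy restatement with an invented module representation; it is not part of the patch, and the real Attributor runs this reasoning to a fixpoint across all functions rather than in a single pass.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Func {
  bool Convergent = true;                    // conservative default
  bool Intrinsic = false;
  std::vector<std::string> Calls;            // direct callees only
};
using Module = std::unordered_map<std::string, Func>;

bool calleesAreNonConvergent(const Module &M, const Func &F) {
  for (const std::string &Name : F.Calls) {
    auto It = M.find(Name);
    if (It == M.end() || It->second.Intrinsic)
      return false;                          // unknown or intrinsic callee
    if (It->second.Convergent)
      return false;                          // callee may still be convergent
  }
  return true;
}

int main() {
  Module M;
  M["leaf"] = {/*Convergent=*/false};
  M["kernel"] = {/*Convergent=*/true, /*Intrinsic=*/false, {"leaf"}};
  if (calleesAreNonConvergent(M, M["kernel"]))
    M["kernel"].Convergent = false;          // the manifest() step of the patch
  std::cout << M["kernel"].Convergent << '\n';  // prints 0
}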
/// -------------------- Undefined-Behavior Attributes ------------------------
namespace {
@@ -3009,7 +3005,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior {
// Check nonnull and noundef argument attribute violation for each
// callsite.
CallBase &CB = cast<CallBase>(I);
- Function *Callee = CB.getCalledFunction();
+ auto *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
if (!Callee)
return true;
for (unsigned idx = 0; idx < CB.arg_size(); idx++) {
@@ -3030,9 +3026,10 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior {
// (3) Simplified to null pointer where known to be nonnull.
// The argument is a poison value and violate noundef attribute.
IRPosition CalleeArgumentIRP = IRPosition::callsite_argument(CB, idx);
- auto &NoUndefAA =
- A.getAAFor<AANoUndef>(*this, CalleeArgumentIRP, DepClassTy::NONE);
- if (!NoUndefAA.isKnownNoUndef())
+ bool IsKnownNoUndef;
+ AA::hasAssumedIRAttr<Attribute::NoUndef>(
+ A, this, CalleeArgumentIRP, DepClassTy::NONE, IsKnownNoUndef);
+ if (!IsKnownNoUndef)
continue;
bool UsedAssumedInformation = false;
std::optional<Value *> SimplifiedVal =
@@ -3049,9 +3046,10 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior {
if (!ArgVal->getType()->isPointerTy() ||
!isa<ConstantPointerNull>(**SimplifiedVal))
continue;
- auto &NonNullAA =
- A.getAAFor<AANonNull>(*this, CalleeArgumentIRP, DepClassTy::NONE);
- if (NonNullAA.isKnownNonNull())
+ bool IsKnownNonNull;
+ AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, CalleeArgumentIRP, DepClassTy::NONE, IsKnownNonNull);
+ if (IsKnownNonNull)
KnownUBInsts.insert(&I);
}
return true;
@@ -3081,9 +3079,11 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior {
// position has nonnull attribute (because the returned value is
// poison).
if (isa<ConstantPointerNull>(*SimplifiedRetValue)) {
- auto &NonNullAA = A.getAAFor<AANonNull>(
- *this, IRPosition::returned(*getAnchorScope()), DepClassTy::NONE);
- if (NonNullAA.isKnownNonNull())
+ bool IsKnownNonNull;
+ AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, IRPosition::returned(*getAnchorScope()), DepClassTy::NONE,
+ IsKnownNonNull);
+ if (IsKnownNonNull)
KnownUBInsts.insert(&I);
}
@@ -3108,9 +3108,10 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior {
if (!getAnchorScope()->getReturnType()->isVoidTy()) {
const IRPosition &ReturnIRP = IRPosition::returned(*getAnchorScope());
if (!A.isAssumedDead(ReturnIRP, this, nullptr, UsedAssumedInformation)) {
- auto &RetPosNoUndefAA =
- A.getAAFor<AANoUndef>(*this, ReturnIRP, DepClassTy::NONE);
- if (RetPosNoUndefAA.isKnownNoUndef())
+ bool IsKnownNoUndef;
+ AA::hasAssumedIRAttr<Attribute::NoUndef>(
+ A, this, ReturnIRP, DepClassTy::NONE, IsKnownNoUndef);
+ if (IsKnownNoUndef)
A.checkForAllInstructions(InspectReturnInstForUB, *this,
{Instruction::Ret}, UsedAssumedInformation,
/* CheckBBLivenessOnly */ true);
@@ -3161,7 +3162,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior {
}
/// See AbstractAttribute::getAsStr()
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "undefined-behavior" : "no-ub";
}
@@ -3284,20 +3285,15 @@ struct AAWillReturnImpl : public AAWillReturn {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- AAWillReturn::initialize(A);
-
- if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ true)) {
- indicateOptimisticFixpoint();
- return;
- }
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::WillReturn>(
+ A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
}
/// Check for `mustprogress` and `readonly` as they imply `willreturn`.
bool isImpliedByMustprogressAndReadonly(Attributor &A, bool KnownOnly) {
- // Check for `mustprogress` in the scope and the associated function which
- // might be different if this is a call site.
- if ((!getAnchorScope() || !getAnchorScope()->mustProgress()) &&
- (!getAssociatedFunction() || !getAssociatedFunction()->mustProgress()))
+ if (!A.hasAttr(getIRPosition(), {Attribute::MustProgress}))
return false;
bool IsKnown;
@@ -3313,15 +3309,17 @@ struct AAWillReturnImpl : public AAWillReturn {
auto CheckForWillReturn = [&](Instruction &I) {
IRPosition IPos = IRPosition::callsite_function(cast<CallBase>(I));
- const auto &WillReturnAA =
- A.getAAFor<AAWillReturn>(*this, IPos, DepClassTy::REQUIRED);
- if (WillReturnAA.isKnownWillReturn())
- return true;
- if (!WillReturnAA.isAssumedWillReturn())
+ bool IsKnown;
+ if (AA::hasAssumedIRAttr<Attribute::WillReturn>(
+ A, this, IPos, DepClassTy::REQUIRED, IsKnown)) {
+ if (IsKnown)
+ return true;
+ } else {
return false;
- const auto &NoRecurseAA =
- A.getAAFor<AANoRecurse>(*this, IPos, DepClassTy::REQUIRED);
- return NoRecurseAA.isAssumedNoRecurse();
+ }
+ bool IsKnownNoRecurse;
+ return AA::hasAssumedIRAttr<Attribute::NoRecurse>(
+ A, this, IPos, DepClassTy::REQUIRED, IsKnownNoRecurse);
};
bool UsedAssumedInformation = false;
@@ -3333,7 +3331,7 @@ struct AAWillReturnImpl : public AAWillReturn {
}
/// See AbstractAttribute::getAsStr()
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "willreturn" : "may-noreturn";
}
};
@@ -3347,7 +3345,8 @@ struct AAWillReturnFunction final : AAWillReturnImpl {
AAWillReturnImpl::initialize(A);
Function *F = getAnchorScope();
- if (!F || F->isDeclaration() || mayContainUnboundedCycle(*F, A))
+ assert(F && "Did expect an anchor function");
+ if (F->isDeclaration() || mayContainUnboundedCycle(*F, A))
indicatePessimisticFixpoint();
}
@@ -3360,14 +3359,6 @@ struct AAWillReturnCallSite final : AAWillReturnImpl {
AAWillReturnCallSite(const IRPosition &IRP, Attributor &A)
: AAWillReturnImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AAWillReturnImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || !A.isFunctionIPOAmendable(*F))
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ false))
@@ -3379,8 +3370,11 @@ struct AAWillReturnCallSite final : AAWillReturnImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AAWillReturn>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
+ bool IsKnown;
+ if (AA::hasAssumedIRAttr<Attribute::WillReturn>(
+ A, this, FnPos, DepClassTy::REQUIRED, IsKnown))
+ return ChangeStatus::UNCHANGED;
+ return indicatePessimisticFixpoint();
}
/// See AbstractAttribute::trackStatistics()
@@ -3414,22 +3408,18 @@ template <typename ToTy> struct ReachabilityQueryInfo {
/// Constructor replacement to ensure unique and stable sets are used for the
/// cache.
ReachabilityQueryInfo(Attributor &A, const Instruction &From, const ToTy &To,
- const AA::InstExclusionSetTy *ES)
+ const AA::InstExclusionSetTy *ES, bool MakeUnique)
: From(&From), To(&To), ExclusionSet(ES) {
- if (ExclusionSet && !ExclusionSet->empty()) {
- ExclusionSet =
- A.getInfoCache().getOrCreateUniqueBlockExecutionSet(ExclusionSet);
- } else {
+ if (!ES || ES->empty()) {
ExclusionSet = nullptr;
+ } else if (MakeUnique) {
+ ExclusionSet = A.getInfoCache().getOrCreateUniqueBlockExecutionSet(ES);
}
}
ReachabilityQueryInfo(const ReachabilityQueryInfo &RQI)
- : From(RQI.From), To(RQI.To), ExclusionSet(RQI.ExclusionSet) {
- assert(RQI.Result == Reachable::No &&
- "Didn't expect to copy an explored RQI!");
- }
+ : From(RQI.From), To(RQI.To), ExclusionSet(RQI.ExclusionSet) {}
};
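// [Editor's note: the sketch below is illustrative and not part of the patch.]
// With the MakeUnique parameter above, ReachabilityQueryInfo can serve both as
// a cheap stack-allocated probe key and as a permanent, uniqued cache entry.
// The checkQueryCache/rememberResult logic further below then keeps two kinds
// of entries: exact (From, To, ExclusionSet) results and plain (From, To)
// results, where a plain "not reachable" answers any query regardless of its
// exclusion set. A minimal standalone model of that two-level cache, using
// assumed integer node ids instead of Instructions:
#include <map>
#include <optional>
#include <set>
#include <tuple>
#include <utility>

enum class Reach { No, Yes };

struct ReachCache {
  // Plain results; a cached "No" here is valid for every exclusion set.
  std::map<std::pair<int, int>, Reach> Plain;
  // Results that actually depended on a particular exclusion set.
  std::map<std::tuple<int, int, std::set<int>>, Reach> WithExclusion;

  std::optional<Reach> lookup(int From, int To,
                              const std::set<int> &Exclusion) const {
    auto P = Plain.find(std::make_pair(From, To));
    if (P != Plain.end()) {
      // A plain "No" answers every query; a plain "Yes" only answers queries
      // that do not restrict the search with an exclusion set.
      if (P->second == Reach::No || Exclusion.empty())
        return P->second;
    }
    auto E = WithExclusion.find(std::make_tuple(From, To, Exclusion));
    if (E != WithExclusion.end())
      return E->second;
    return std::nullopt;            // Not cached; the caller must compute it.
  }

  void remember(int From, int To, const std::set<int> &Exclusion, Reach R,
                bool UsedExclusionSet) {
    // "Yes", or any result that never consulted the set, is safe to cache
    // without the set in the key.
    if (R == Reach::Yes || !UsedExclusionSet)
      Plain.emplace(std::make_pair(From, To), R);
    // A "No" that depended on the exclusion set must keep the set in the key.
    if (R == Reach::No && UsedExclusionSet)
      WithExclusion.emplace(std::make_tuple(From, To, Exclusion), R);
  }
};
// [End of editor's sketch.]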
namespace llvm {
@@ -3482,8 +3472,7 @@ template <typename BaseTy, typename ToTy>
struct CachedReachabilityAA : public BaseTy {
using RQITy = ReachabilityQueryInfo<ToTy>;
- CachedReachabilityAA<BaseTy, ToTy>(const IRPosition &IRP, Attributor &A)
- : BaseTy(IRP, A) {}
+ CachedReachabilityAA(const IRPosition &IRP, Attributor &A) : BaseTy(IRP, A) {}
/// See AbstractAttribute::isQueryAA.
bool isQueryAA() const override { return true; }
@@ -3492,7 +3481,8 @@ struct CachedReachabilityAA : public BaseTy {
ChangeStatus updateImpl(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED;
InUpdate = true;
- for (RQITy *RQI : QueryVector) {
+ for (unsigned u = 0, e = QueryVector.size(); u < e; ++u) {
+ RQITy *RQI = QueryVector[u];
if (RQI->Result == RQITy::Reachable::No && isReachableImpl(A, *RQI))
Changed = ChangeStatus::CHANGED;
}
@@ -3503,39 +3493,78 @@ struct CachedReachabilityAA : public BaseTy {
virtual bool isReachableImpl(Attributor &A, RQITy &RQI) = 0;
bool rememberResult(Attributor &A, typename RQITy::Reachable Result,
- RQITy &RQI) {
- if (Result == RQITy::Reachable::No) {
- if (!InUpdate)
- A.registerForUpdate(*this);
- return false;
- }
- assert(RQI.Result == RQITy::Reachable::No && "Already reachable?");
+ RQITy &RQI, bool UsedExclusionSet) {
RQI.Result = Result;
- return true;
+
+ // Remove the temporary RQI from the cache.
+ if (!InUpdate)
+ QueryCache.erase(&RQI);
+
+ // Insert a plain RQI (w/o exclusion set) if that makes sense. Two options:
+ // 1) If it is reachable, it doesn't matter if we have an exclusion set for
+ // this query. 2) We did not use the exclusion set, potentially because
+ // there is none.
+ if (Result == RQITy::Reachable::Yes || !UsedExclusionSet) {
+ RQITy PlainRQI(RQI.From, RQI.To);
+ if (!QueryCache.count(&PlainRQI)) {
+ RQITy *RQIPtr = new (A.Allocator) RQITy(RQI.From, RQI.To);
+ RQIPtr->Result = Result;
+ QueryVector.push_back(RQIPtr);
+ QueryCache.insert(RQIPtr);
+ }
+ }
+
+ // Check if we need to insert a new permanent RQI with the exclusion set.
+ if (!InUpdate && Result != RQITy::Reachable::Yes && UsedExclusionSet) {
+ assert((!RQI.ExclusionSet || !RQI.ExclusionSet->empty()) &&
+ "Did not expect empty set!");
+ RQITy *RQIPtr = new (A.Allocator)
+ RQITy(A, *RQI.From, *RQI.To, RQI.ExclusionSet, true);
+ assert(RQIPtr->Result == RQITy::Reachable::No && "Already reachable?");
+ RQIPtr->Result = Result;
+ assert(!QueryCache.count(RQIPtr));
+ QueryVector.push_back(RQIPtr);
+ QueryCache.insert(RQIPtr);
+ }
+
+ if (Result == RQITy::Reachable::No && !InUpdate)
+ A.registerForUpdate(*this);
+ return Result == RQITy::Reachable::Yes;
}
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
// TODO: Return the number of reachable queries.
return "#queries(" + std::to_string(QueryVector.size()) + ")";
}
- RQITy *checkQueryCache(Attributor &A, RQITy &StackRQI,
- typename RQITy::Reachable &Result) {
+ bool checkQueryCache(Attributor &A, RQITy &StackRQI,
+ typename RQITy::Reachable &Result) {
if (!this->getState().isValidState()) {
Result = RQITy::Reachable::Yes;
- return nullptr;
+ return true;
+ }
+
+ // If we have an exclusion set we might be able to find our answer by
+ // ignoring it first.
+ if (StackRQI.ExclusionSet) {
+ RQITy PlainRQI(StackRQI.From, StackRQI.To);
+ auto It = QueryCache.find(&PlainRQI);
+ if (It != QueryCache.end() && (*It)->Result == RQITy::Reachable::No) {
+ Result = RQITy::Reachable::No;
+ return true;
+ }
}
auto It = QueryCache.find(&StackRQI);
if (It != QueryCache.end()) {
Result = (*It)->Result;
- return nullptr;
+ return true;
}
- RQITy *RQIPtr = new (A.Allocator) RQITy(StackRQI);
- QueryVector.push_back(RQIPtr);
- QueryCache.insert(RQIPtr);
- return RQIPtr;
+ // Insert a temporary for recursive queries. We will replace it with a
+ // permanent entry later.
+ QueryCache.insert(&StackRQI);
+ return false;
}
private:
@@ -3546,8 +3575,9 @@ private:
struct AAIntraFnReachabilityFunction final
: public CachedReachabilityAA<AAIntraFnReachability, Instruction> {
+ using Base = CachedReachabilityAA<AAIntraFnReachability, Instruction>;
AAIntraFnReachabilityFunction(const IRPosition &IRP, Attributor &A)
- : CachedReachabilityAA<AAIntraFnReachability, Instruction>(IRP, A) {}
+ : Base(IRP, A) {}
bool isAssumedReachable(
Attributor &A, const Instruction &From, const Instruction &To,
@@ -3556,23 +3586,39 @@ struct AAIntraFnReachabilityFunction final
if (&From == &To)
return true;
- RQITy StackRQI(A, From, To, ExclusionSet);
+ RQITy StackRQI(A, From, To, ExclusionSet, false);
typename RQITy::Reachable Result;
- if (RQITy *RQIPtr = NonConstThis->checkQueryCache(A, StackRQI, Result)) {
- return NonConstThis->isReachableImpl(A, *RQIPtr);
- }
+ if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
+ return NonConstThis->isReachableImpl(A, StackRQI);
return Result == RQITy::Reachable::Yes;
}
+ ChangeStatus updateImpl(Attributor &A) override {
+    // We only depend on liveness. DeadEdges is all we care about; check if any
+    // of them changed.
+ auto *LivenessAA =
+ A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (LivenessAA && llvm::all_of(DeadEdges, [&](const auto &DeadEdge) {
+ return LivenessAA->isEdgeDead(DeadEdge.first, DeadEdge.second);
+ })) {
+ return ChangeStatus::UNCHANGED;
+ }
+ DeadEdges.clear();
+ return Base::updateImpl(A);
+ }
+
bool isReachableImpl(Attributor &A, RQITy &RQI) override {
const Instruction *Origin = RQI.From;
+ bool UsedExclusionSet = false;
- auto WillReachInBlock = [=](const Instruction &From, const Instruction &To,
+ auto WillReachInBlock = [&](const Instruction &From, const Instruction &To,
const AA::InstExclusionSetTy *ExclusionSet) {
const Instruction *IP = &From;
while (IP && IP != &To) {
- if (ExclusionSet && IP != Origin && ExclusionSet->count(IP))
+ if (ExclusionSet && IP != Origin && ExclusionSet->count(IP)) {
+ UsedExclusionSet = true;
break;
+ }
IP = IP->getNextNode();
}
return IP == &To;
@@ -3587,7 +3633,12 @@ struct AAIntraFnReachabilityFunction final
// possible.
if (FromBB == ToBB &&
WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet))
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
+ return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
+
+ // Check if reaching the ToBB block is sufficient or if even that would not
+ // ensure reaching the target. In the latter case we are done.
+ if (!WillReachInBlock(ToBB->front(), *RQI.To, RQI.ExclusionSet))
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks;
if (RQI.ExclusionSet)
@@ -3598,40 +3649,80 @@ struct AAIntraFnReachabilityFunction final
if (ExclusionBlocks.count(FromBB) &&
!WillReachInBlock(*RQI.From, *FromBB->getTerminator(),
RQI.ExclusionSet))
- return rememberResult(A, RQITy::Reachable::No, RQI);
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
SmallPtrSet<const BasicBlock *, 16> Visited;
SmallVector<const BasicBlock *, 16> Worklist;
Worklist.push_back(FromBB);
- auto &LivenessAA =
+ DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> LocalDeadEdges;
+ auto *LivenessAA =
A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
while (!Worklist.empty()) {
const BasicBlock *BB = Worklist.pop_back_val();
if (!Visited.insert(BB).second)
continue;
for (const BasicBlock *SuccBB : successors(BB)) {
- if (LivenessAA.isEdgeDead(BB, SuccBB))
+ if (LivenessAA && LivenessAA->isEdgeDead(BB, SuccBB)) {
+ LocalDeadEdges.insert({BB, SuccBB});
continue;
- if (SuccBB == ToBB &&
- WillReachInBlock(SuccBB->front(), *RQI.To, RQI.ExclusionSet))
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
- if (ExclusionBlocks.count(SuccBB))
+ }
+ // We checked before if we just need to reach the ToBB block.
+ if (SuccBB == ToBB)
+ return rememberResult(A, RQITy::Reachable::Yes, RQI,
+ UsedExclusionSet);
+ if (ExclusionBlocks.count(SuccBB)) {
+ UsedExclusionSet = true;
continue;
+ }
Worklist.push_back(SuccBB);
}
}
- return rememberResult(A, RQITy::Reachable::No, RQI);
+ DeadEdges.insert(LocalDeadEdges.begin(), LocalDeadEdges.end());
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {}
+
+private:
+  // Set of assumed dead edges we used in the last query. If any of them
+  // changes, we update the state.
+ DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> DeadEdges;
};
} // namespace
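// [Editor's note: the sketch below is illustrative and not part of the patch.]
// The reworked isReachableImpl above is a forward worklist search over basic
// blocks: it stops as soon as the target block is reached (the target-internal
// check was hoisted out earlier), skips successors in the exclusion set while
// recording that the set was actually consulted, and skips edges the liveness
// AA reports dead, remembering them in DeadEdges so the cached answer can be
// invalidated if liveness later changes. A self-contained version of that
// traversal over an assumed adjacency-list CFG with integer block ids:
#include <set>
#include <utility>
#include <vector>

// Returns true if \p To is reachable from \p From when blocks in \p Excluded
// may not be entered and edges in \p DeadEdges may not be taken.
static bool isReachable(const std::vector<std::vector<int>> &Succ, int From,
                        int To, const std::set<int> &Excluded,
                        const std::set<std::pair<int, int>> &DeadEdges) {
  if (From == To)
    return true;
  std::vector<int> Worklist{From};
  std::set<int> Visited;
  while (!Worklist.empty()) {
    int BB = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(BB).second)
      continue;                          // Already explored this block.
    for (int SuccBB : Succ[BB]) {
      if (DeadEdges.count({BB, SuccBB}))
        continue;                        // Edge assumed dead: never taken.
      if (SuccBB == To)
        return true;                     // Reaching the target block suffices.
      if (Excluded.count(SuccBB))
        continue;                        // Must not pass through this block.
      Worklist.push_back(SuccBB);
    }
  }
  return false;                          // No live, non-excluded path exists.
}
// [End of editor's sketch.]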
/// ------------------------ NoAlias Argument Attribute ------------------------
+bool AANoAlias::isImpliedByIR(Attributor &A, const IRPosition &IRP,
+ Attribute::AttrKind ImpliedAttributeKind,
+ bool IgnoreSubsumingPositions) {
+ assert(ImpliedAttributeKind == Attribute::NoAlias &&
+ "Unexpected attribute kind");
+ Value *Val = &IRP.getAssociatedValue();
+ if (IRP.getPositionKind() != IRP_CALL_SITE_ARGUMENT) {
+ if (isa<AllocaInst>(Val))
+ return true;
+ } else {
+ IgnoreSubsumingPositions = true;
+ }
+
+ if (isa<UndefValue>(Val))
+ return true;
+
+ if (isa<ConstantPointerNull>(Val) &&
+ !NullPointerIsDefined(IRP.getAnchorScope(),
+ Val->getType()->getPointerAddressSpace()))
+ return true;
+
+ if (A.hasAttr(IRP, {Attribute::ByVal, Attribute::NoAlias},
+ IgnoreSubsumingPositions, Attribute::NoAlias))
+ return true;
+
+ return false;
+}
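// [Editor's note: the sketch below is illustrative and not part of the patch.]
// The new AANoAlias::isImpliedByIR above answers the cheap, attribute-level
// part of the question without materializing an abstract attribute: allocas
// (outside call-site-argument positions), undef, null in an address space
// where null is not a defined pointer, and values already carrying byval or
// noalias are all noalias immediately. A compact restatement of that decision
// as a plain predicate over assumed boolean inputs standing in for the IR
// queries:
static bool isTriviallyNoAlias(bool IsCallSiteArgument, bool IsAlloca,
                               bool IsUndef, bool IsNullPointer,
                               bool NullIsDefinedInAS, bool HasByValOrNoAlias) {
  // A fresh stack object cannot alias anything else, but at a call-site
  // argument position the query is about the callee's view, so skip it there.
  if (!IsCallSiteArgument && IsAlloca)
    return true;
  if (IsUndef)
    return true;                       // Undef may be treated as any value.
  if (IsNullPointer && !NullIsDefinedInAS)
    return true;                       // Null cannot alias accessible memory.
  return HasByValOrNoAlias;            // Existing IR attributes settle the rest.
}
// [End of editor's sketch.]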
+
namespace {
struct AANoAliasImpl : AANoAlias {
AANoAliasImpl(const IRPosition &IRP, Attributor &A) : AANoAlias(IRP, A) {
@@ -3639,7 +3730,7 @@ struct AANoAliasImpl : AANoAlias {
"Noalias is a pointer attribute");
}
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "noalias" : "may-alias";
}
};
@@ -3649,39 +3740,6 @@ struct AANoAliasFloating final : AANoAliasImpl {
AANoAliasFloating(const IRPosition &IRP, Attributor &A)
: AANoAliasImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoAliasImpl::initialize(A);
- Value *Val = &getAssociatedValue();
- do {
- CastInst *CI = dyn_cast<CastInst>(Val);
- if (!CI)
- break;
- Value *Base = CI->getOperand(0);
- if (!Base->hasOneUse())
- break;
- Val = Base;
- } while (true);
-
- if (!Val->getType()->isPointerTy()) {
- indicatePessimisticFixpoint();
- return;
- }
-
- if (isa<AllocaInst>(Val))
- indicateOptimisticFixpoint();
- else if (isa<ConstantPointerNull>(Val) &&
- !NullPointerIsDefined(getAnchorScope(),
- Val->getType()->getPointerAddressSpace()))
- indicateOptimisticFixpoint();
- else if (Val != &getAssociatedValue()) {
- const auto &ValNoAliasAA = A.getAAFor<AANoAlias>(
- *this, IRPosition::value(*Val), DepClassTy::OPTIONAL);
- if (ValNoAliasAA.isKnownNoAlias())
- indicateOptimisticFixpoint();
- }
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Implement this.
@@ -3696,18 +3754,14 @@ struct AANoAliasFloating final : AANoAliasImpl {
/// NoAlias attribute for an argument.
struct AANoAliasArgument final
- : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl> {
- using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl>;
+ : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl,
+ AANoAlias::StateType, false,
+ Attribute::NoAlias> {
+ using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl,
+ AANoAlias::StateType, false,
+ Attribute::NoAlias>;
AANoAliasArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- Base::initialize(A);
- // See callsite argument attribute and callee argument attribute.
- if (hasAttr({Attribute::ByVal}))
- indicateOptimisticFixpoint();
- }
-
/// See AbstractAttribute::update(...).
ChangeStatus updateImpl(Attributor &A) override {
// We have to make sure no-alias on the argument does not break
@@ -3716,10 +3770,10 @@ struct AANoAliasArgument final
// function, otherwise we give up for now.
// If the function is no-sync, no-alias cannot break synchronization.
- const auto &NoSyncAA =
- A.getAAFor<AANoSync>(*this, IRPosition::function_scope(getIRPosition()),
- DepClassTy::OPTIONAL);
- if (NoSyncAA.isAssumedNoSync())
+    bool IsKnownNoSync;
+    if (AA::hasAssumedIRAttr<Attribute::NoSync>(
+            A, this, IRPosition::function_scope(getIRPosition()),
+            DepClassTy::OPTIONAL, IsKnownNoSync))
return Base::updateImpl(A);
// If the argument is read-only, no-alias cannot break synchronization.
@@ -3752,19 +3806,6 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
AANoAliasCallSiteArgument(const IRPosition &IRP, Attributor &A)
: AANoAliasImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- // See callsite argument attribute and callee argument attribute.
- const auto &CB = cast<CallBase>(getAnchorValue());
- if (CB.paramHasAttr(getCallSiteArgNo(), Attribute::NoAlias))
- indicateOptimisticFixpoint();
- Value &Val = getAssociatedValue();
- if (isa<ConstantPointerNull>(Val) &&
- !NullPointerIsDefined(getAnchorScope(),
- Val.getType()->getPointerAddressSpace()))
- indicateOptimisticFixpoint();
- }
-
/// Determine if the underlying value may alias with the call site argument
/// \p OtherArgNo of \p ICS (= the underlying call site).
bool mayAliasWithArgument(Attributor &A, AAResults *&AAR,
@@ -3779,27 +3820,29 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
if (!ArgOp->getType()->isPtrOrPtrVectorTy())
return false;
- auto &CBArgMemBehaviorAA = A.getAAFor<AAMemoryBehavior>(
+ auto *CBArgMemBehaviorAA = A.getAAFor<AAMemoryBehavior>(
*this, IRPosition::callsite_argument(CB, OtherArgNo), DepClassTy::NONE);
// If the argument is readnone, there is no read-write aliasing.
- if (CBArgMemBehaviorAA.isAssumedReadNone()) {
- A.recordDependence(CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL);
+ if (CBArgMemBehaviorAA && CBArgMemBehaviorAA->isAssumedReadNone()) {
+ A.recordDependence(*CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL);
return false;
}
// If the argument is readonly and the underlying value is readonly, there
// is no read-write aliasing.
bool IsReadOnly = MemBehaviorAA.isAssumedReadOnly();
- if (CBArgMemBehaviorAA.isAssumedReadOnly() && IsReadOnly) {
+ if (CBArgMemBehaviorAA && CBArgMemBehaviorAA->isAssumedReadOnly() &&
+ IsReadOnly) {
A.recordDependence(MemBehaviorAA, *this, DepClassTy::OPTIONAL);
- A.recordDependence(CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL);
+ A.recordDependence(*CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL);
return false;
}
// We have to utilize actual alias analysis queries so we need the object.
if (!AAR)
- AAR = A.getInfoCache().getAAResultsForFunction(*getAnchorScope());
+ AAR = A.getInfoCache().getAnalysisResultForFunction<AAManager>(
+ *getAnchorScope());
// Try to rule it out at the call site.
bool IsAliasing = !AAR || !AAR->isNoAlias(&getAssociatedValue(), ArgOp);
@@ -3811,10 +3854,8 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
return IsAliasing;
}
- bool
- isKnownNoAliasDueToNoAliasPreservation(Attributor &A, AAResults *&AAR,
- const AAMemoryBehavior &MemBehaviorAA,
- const AANoAlias &NoAliasAA) {
+ bool isKnownNoAliasDueToNoAliasPreservation(
+ Attributor &A, AAResults *&AAR, const AAMemoryBehavior &MemBehaviorAA) {
// We can deduce "noalias" if the following conditions hold.
// (i) Associated value is assumed to be noalias in the definition.
// (ii) Associated value is assumed to be no-capture in all the uses
@@ -3822,24 +3863,14 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
// (iii) There is no other pointer argument which could alias with the
// value.
- bool AssociatedValueIsNoAliasAtDef = NoAliasAA.isAssumedNoAlias();
- if (!AssociatedValueIsNoAliasAtDef) {
- LLVM_DEBUG(dbgs() << "[AANoAlias] " << getAssociatedValue()
- << " is not no-alias at the definition\n");
- return false;
- }
-
auto IsDereferenceableOrNull = [&](Value *O, const DataLayout &DL) {
- const auto &DerefAA = A.getAAFor<AADereferenceable>(
+ const auto *DerefAA = A.getAAFor<AADereferenceable>(
*this, IRPosition::value(*O), DepClassTy::OPTIONAL);
- return DerefAA.getAssumedDereferenceableBytes();
+ return DerefAA ? DerefAA->getAssumedDereferenceableBytes() : 0;
};
- A.recordDependence(NoAliasAA, *this, DepClassTy::OPTIONAL);
-
const IRPosition &VIRP = IRPosition::value(getAssociatedValue());
const Function *ScopeFn = VIRP.getAnchorScope();
- auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, VIRP, DepClassTy::NONE);
// Check whether the value is captured in the scope using AANoCapture.
// Look at CFG and check only uses possibly executed before this
// callsite.
@@ -3859,11 +3890,10 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
unsigned ArgNo = CB->getArgOperandNo(&U);
- const auto &NoCaptureAA = A.getAAFor<AANoCapture>(
- *this, IRPosition::callsite_argument(*CB, ArgNo),
- DepClassTy::OPTIONAL);
-
- if (NoCaptureAA.isAssumedNoCapture())
+ bool IsKnownNoCapture;
+ if (AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, IRPosition::callsite_argument(*CB, ArgNo),
+ DepClassTy::OPTIONAL, IsKnownNoCapture))
return true;
}
}
@@ -3891,7 +3921,12 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
llvm_unreachable("unknown UseCaptureKind");
};
- if (!NoCaptureAA.isAssumedNoCaptureMaybeReturned()) {
+ bool IsKnownNoCapture;
+ const AANoCapture *NoCaptureAA = nullptr;
+ bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, VIRP, DepClassTy::NONE, IsKnownNoCapture, false, &NoCaptureAA);
+ if (!IsAssumedNoCapture &&
+ (!NoCaptureAA || !NoCaptureAA->isAssumedNoCaptureMaybeReturned())) {
if (!A.checkForAllUses(UsePred, *this, getAssociatedValue())) {
LLVM_DEBUG(
dbgs() << "[AANoAliasCSArg] " << getAssociatedValue()
@@ -3899,7 +3934,8 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
return false;
}
}
- A.recordDependence(NoCaptureAA, *this, DepClassTy::OPTIONAL);
+ if (NoCaptureAA)
+ A.recordDependence(*NoCaptureAA, *this, DepClassTy::OPTIONAL);
// Check there is no other pointer argument which could alias with the
// value passed at this call site.
@@ -3916,20 +3952,25 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
ChangeStatus updateImpl(Attributor &A) override {
// If the argument is readnone we are done as there are no accesses via the
// argument.
- auto &MemBehaviorAA =
+ auto *MemBehaviorAA =
A.getAAFor<AAMemoryBehavior>(*this, getIRPosition(), DepClassTy::NONE);
- if (MemBehaviorAA.isAssumedReadNone()) {
- A.recordDependence(MemBehaviorAA, *this, DepClassTy::OPTIONAL);
+ if (MemBehaviorAA && MemBehaviorAA->isAssumedReadNone()) {
+ A.recordDependence(*MemBehaviorAA, *this, DepClassTy::OPTIONAL);
return ChangeStatus::UNCHANGED;
}
+ bool IsKnownNoAlias;
const IRPosition &VIRP = IRPosition::value(getAssociatedValue());
- const auto &NoAliasAA =
- A.getAAFor<AANoAlias>(*this, VIRP, DepClassTy::NONE);
+ if (!AA::hasAssumedIRAttr<Attribute::NoAlias>(
+ A, this, VIRP, DepClassTy::REQUIRED, IsKnownNoAlias)) {
+ LLVM_DEBUG(dbgs() << "[AANoAlias] " << getAssociatedValue()
+ << " is not no-alias at the definition\n");
+ return indicatePessimisticFixpoint();
+ }
AAResults *AAR = nullptr;
- if (isKnownNoAliasDueToNoAliasPreservation(A, AAR, MemBehaviorAA,
- NoAliasAA)) {
+ if (MemBehaviorAA &&
+ isKnownNoAliasDueToNoAliasPreservation(A, AAR, *MemBehaviorAA)) {
LLVM_DEBUG(
dbgs() << "[AANoAlias] No-Alias deduced via no-alias preservation\n");
return ChangeStatus::UNCHANGED;
@@ -3947,14 +3988,6 @@ struct AANoAliasReturned final : AANoAliasImpl {
AANoAliasReturned(const IRPosition &IRP, Attributor &A)
: AANoAliasImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoAliasImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
@@ -3969,14 +4002,18 @@ struct AANoAliasReturned final : AANoAliasImpl {
return false;
const IRPosition &RVPos = IRPosition::value(RV);
- const auto &NoAliasAA =
- A.getAAFor<AANoAlias>(*this, RVPos, DepClassTy::REQUIRED);
- if (!NoAliasAA.isAssumedNoAlias())
+ bool IsKnownNoAlias;
+ if (!AA::hasAssumedIRAttr<Attribute::NoAlias>(
+ A, this, RVPos, DepClassTy::REQUIRED, IsKnownNoAlias))
return false;
- const auto &NoCaptureAA =
- A.getAAFor<AANoCapture>(*this, RVPos, DepClassTy::REQUIRED);
- return NoCaptureAA.isAssumedNoCaptureMaybeReturned();
+ bool IsKnownNoCapture;
+ const AANoCapture *NoCaptureAA = nullptr;
+ bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, RVPos, DepClassTy::REQUIRED, IsKnownNoCapture, false,
+ &NoCaptureAA);
+ return IsAssumedNoCapture ||
+ (NoCaptureAA && NoCaptureAA->isAssumedNoCaptureMaybeReturned());
};
if (!A.checkForAllReturnedValues(CheckReturnValue, *this))
@@ -3994,14 +4031,6 @@ struct AANoAliasCallSiteReturned final : AANoAliasImpl {
AANoAliasCallSiteReturned(const IRPosition &IRP, Attributor &A)
: AANoAliasImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoAliasImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -4010,8 +4039,11 @@ struct AANoAliasCallSiteReturned final : AANoAliasImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::returned(*F);
- auto &FnAA = A.getAAFor<AANoAlias>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
+ bool IsKnownNoAlias;
+ if (!AA::hasAssumedIRAttr<Attribute::NoAlias>(
+ A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoAlias))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
}
/// See AbstractAttribute::trackStatistics()
@@ -4025,13 +4057,6 @@ namespace {
struct AAIsDeadValueImpl : public AAIsDead {
AAIsDeadValueImpl(const IRPosition &IRP, Attributor &A) : AAIsDead(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- if (auto *Scope = getAnchorScope())
- if (!A.isRunOn(*Scope))
- indicatePessimisticFixpoint();
- }
-
/// See AAIsDead::isAssumedDead().
bool isAssumedDead() const override { return isAssumed(IS_DEAD); }
@@ -4055,7 +4080,7 @@ struct AAIsDeadValueImpl : public AAIsDead {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return isAssumedDead() ? "assumed-dead" : "assumed-live";
}
@@ -4097,12 +4122,11 @@ struct AAIsDeadValueImpl : public AAIsDead {
return false;
const IRPosition &CallIRP = IRPosition::callsite_function(*CB);
- const auto &NoUnwindAA =
- A.getAndUpdateAAFor<AANoUnwind>(*this, CallIRP, DepClassTy::NONE);
- if (!NoUnwindAA.isAssumedNoUnwind())
+
+ bool IsKnownNoUnwind;
+ if (!AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+ A, this, CallIRP, DepClassTy::OPTIONAL, IsKnownNoUnwind))
return false;
- if (!NoUnwindAA.isKnownNoUnwind())
- A.recordDependence(NoUnwindAA, *this, DepClassTy::OPTIONAL);
bool IsKnown;
return AA::isAssumedReadOnly(A, CallIRP, *this, IsKnown);
@@ -4124,13 +4148,22 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl {
Instruction *I = dyn_cast<Instruction>(&getAssociatedValue());
if (!isAssumedSideEffectFree(A, I)) {
- if (!isa_and_nonnull<StoreInst>(I))
+ if (!isa_and_nonnull<StoreInst>(I) && !isa_and_nonnull<FenceInst>(I))
indicatePessimisticFixpoint();
else
removeAssumedBits(HAS_NO_EFFECT);
}
}
+ bool isDeadFence(Attributor &A, FenceInst &FI) {
+ const auto *ExecDomainAA = A.lookupAAFor<AAExecutionDomain>(
+ IRPosition::function(*FI.getFunction()), *this, DepClassTy::NONE);
+ if (!ExecDomainAA || !ExecDomainAA->isNoOpFence(FI))
+ return false;
+ A.recordDependence(*ExecDomainAA, *this, DepClassTy::OPTIONAL);
+ return true;
+ }
+
bool isDeadStore(Attributor &A, StoreInst &SI,
SmallSetVector<Instruction *, 8> *AssumeOnlyInst = nullptr) {
// Lang ref now states volatile store is not UB/dead, let's skip them.
@@ -4161,12 +4194,14 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl {
return true;
if (auto *LI = dyn_cast<LoadInst>(V)) {
if (llvm::all_of(LI->uses(), [&](const Use &U) {
- return InfoCache.isOnlyUsedByAssume(
- cast<Instruction>(*U.getUser())) ||
- A.isAssumedDead(U, this, nullptr, UsedAssumedInformation);
+ auto &UserI = cast<Instruction>(*U.getUser());
+ if (InfoCache.isOnlyUsedByAssume(UserI)) {
+ if (AssumeOnlyInst)
+ AssumeOnlyInst->insert(&UserI);
+ return true;
+ }
+ return A.isAssumedDead(U, this, nullptr, UsedAssumedInformation);
})) {
- if (AssumeOnlyInst)
- AssumeOnlyInst->insert(LI);
return true;
}
}
@@ -4177,12 +4212,15 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
Instruction *I = dyn_cast<Instruction>(&getAssociatedValue());
if (isa_and_nonnull<StoreInst>(I))
if (isValidState())
return "assumed-dead-store";
- return AAIsDeadValueImpl::getAsStr();
+ if (isa_and_nonnull<FenceInst>(I))
+ if (isValidState())
+ return "assumed-dead-fence";
+ return AAIsDeadValueImpl::getAsStr(A);
}
/// See AbstractAttribute::updateImpl(...).
@@ -4191,6 +4229,9 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl {
if (auto *SI = dyn_cast_or_null<StoreInst>(I)) {
if (!isDeadStore(A, *SI))
return indicatePessimisticFixpoint();
+ } else if (auto *FI = dyn_cast_or_null<FenceInst>(I)) {
+ if (!isDeadFence(A, *FI))
+ return indicatePessimisticFixpoint();
} else {
if (!isAssumedSideEffectFree(A, I))
return indicatePessimisticFixpoint();
@@ -4226,6 +4267,11 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl {
}
return ChangeStatus::CHANGED;
}
+ if (auto *FI = dyn_cast<FenceInst>(I)) {
+ assert(isDeadFence(A, *FI));
+ A.deleteAfterManifest(*FI);
+ return ChangeStatus::CHANGED;
+ }
if (isAssumedSideEffectFree(A, I) && !isa<InvokeInst>(I)) {
A.deleteAfterManifest(*I);
return ChangeStatus::CHANGED;
@@ -4248,13 +4294,6 @@ struct AAIsDeadArgument : public AAIsDeadFloating {
AAIsDeadArgument(const IRPosition &IRP, Attributor &A)
: AAIsDeadFloating(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AAIsDeadFloating::initialize(A);
- if (!A.isFunctionIPOAmendable(*getAnchorScope()))
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::manifest(...).
ChangeStatus manifest(Attributor &A) override {
Argument &Arg = *getAssociatedArgument();
@@ -4293,8 +4332,10 @@ struct AAIsDeadCallSiteArgument : public AAIsDeadValueImpl {
if (!Arg)
return indicatePessimisticFixpoint();
const IRPosition &ArgPos = IRPosition::argument(*Arg);
- auto &ArgAA = A.getAAFor<AAIsDead>(*this, ArgPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), ArgAA.getState());
+ auto *ArgAA = A.getAAFor<AAIsDead>(*this, ArgPos, DepClassTy::REQUIRED);
+ if (!ArgAA)
+ return indicatePessimisticFixpoint();
+ return clampStateAndIndicateChange(getState(), ArgAA->getState());
}
/// See AbstractAttribute::manifest(...).
@@ -4355,7 +4396,7 @@ struct AAIsDeadCallSiteReturned : public AAIsDeadFloating {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return isAssumedDead()
? "assumed-dead"
: (getAssumed() ? "assumed-dead-users" : "assumed-live");
@@ -4416,10 +4457,7 @@ struct AAIsDeadFunction : public AAIsDead {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
Function *F = getAnchorScope();
- if (!F || F->isDeclaration() || !A.isRunOn(*F)) {
- indicatePessimisticFixpoint();
- return;
- }
+ assert(F && "Did expect an anchor function");
if (!isAssumedDeadInternalFunction(A)) {
ToBeExploredFrom.insert(&F->getEntryBlock().front());
assumeLive(A, F->getEntryBlock());
@@ -4435,7 +4473,7 @@ struct AAIsDeadFunction : public AAIsDead {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return "Live[#BB " + std::to_string(AssumedLiveBlocks.size()) + "/" +
std::to_string(getAnchorScope()->size()) + "][#TBEP " +
std::to_string(ToBeExploredFrom.size()) + "][#KDE " +
@@ -4465,9 +4503,10 @@ struct AAIsDeadFunction : public AAIsDead {
auto *CB = dyn_cast<CallBase>(DeadEndI);
if (!CB)
continue;
- const auto &NoReturnAA = A.getAndUpdateAAFor<AANoReturn>(
- *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
- bool MayReturn = !NoReturnAA.isAssumedNoReturn();
+ bool IsKnownNoReturn;
+ bool MayReturn = !AA::hasAssumedIRAttr<Attribute::NoReturn>(
+ A, this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL,
+ IsKnownNoReturn);
if (MayReturn && (!Invoke2CallAllowed || !isa<InvokeInst>(CB)))
continue;
@@ -4564,7 +4603,7 @@ struct AAIsDeadFunction : public AAIsDead {
// functions. It can however cause dead functions to be treated as live.
for (const Instruction &I : BB)
if (const auto *CB = dyn_cast<CallBase>(&I))
- if (const Function *F = CB->getCalledFunction())
+ if (auto *F = dyn_cast_if_present<Function>(CB->getCalledOperand()))
if (F->hasLocalLinkage())
A.markLiveInternalFunction(*F);
return true;
@@ -4590,10 +4629,10 @@ identifyAliveSuccessors(Attributor &A, const CallBase &CB,
SmallVectorImpl<const Instruction *> &AliveSuccessors) {
const IRPosition &IPos = IRPosition::callsite_function(CB);
- const auto &NoReturnAA =
- A.getAndUpdateAAFor<AANoReturn>(AA, IPos, DepClassTy::OPTIONAL);
- if (NoReturnAA.isAssumedNoReturn())
- return !NoReturnAA.isKnownNoReturn();
+ bool IsKnownNoReturn;
+ if (AA::hasAssumedIRAttr<Attribute::NoReturn>(
+ A, &AA, IPos, DepClassTy::OPTIONAL, IsKnownNoReturn))
+ return !IsKnownNoReturn;
if (CB.isTerminator())
AliveSuccessors.push_back(&CB.getSuccessor(0)->front());
else
@@ -4615,10 +4654,11 @@ identifyAliveSuccessors(Attributor &A, const InvokeInst &II,
AliveSuccessors.push_back(&II.getUnwindDest()->front());
} else {
const IRPosition &IPos = IRPosition::callsite_function(II);
- const auto &AANoUnw =
- A.getAndUpdateAAFor<AANoUnwind>(AA, IPos, DepClassTy::OPTIONAL);
- if (AANoUnw.isAssumedNoUnwind()) {
- UsedAssumedInformation |= !AANoUnw.isKnownNoUnwind();
+
+ bool IsKnownNoUnwind;
+ if (AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+ A, &AA, IPos, DepClassTy::OPTIONAL, IsKnownNoUnwind)) {
+ UsedAssumedInformation |= !IsKnownNoUnwind;
} else {
AliveSuccessors.push_back(&II.getUnwindDest()->front());
}
@@ -4829,25 +4869,21 @@ struct AADereferenceableImpl : AADereferenceable {
void initialize(Attributor &A) override {
Value &V = *getAssociatedValue().stripPointerCasts();
SmallVector<Attribute, 4> Attrs;
- getAttrs({Attribute::Dereferenceable, Attribute::DereferenceableOrNull},
- Attrs, /* IgnoreSubsumingPositions */ false, &A);
+ A.getAttrs(getIRPosition(),
+ {Attribute::Dereferenceable, Attribute::DereferenceableOrNull},
+ Attrs, /* IgnoreSubsumingPositions */ false);
for (const Attribute &Attr : Attrs)
takeKnownDerefBytesMaximum(Attr.getValueAsInt());
- const IRPosition &IRP = this->getIRPosition();
- NonNullAA = &A.getAAFor<AANonNull>(*this, IRP, DepClassTy::NONE);
+ // Ensure we initialize the non-null AA (if necessary).
+ bool IsKnownNonNull;
+ AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnownNonNull);
bool CanBeNull, CanBeFreed;
takeKnownDerefBytesMaximum(V.getPointerDereferenceableBytes(
A.getDataLayout(), CanBeNull, CanBeFreed));
- bool IsFnInterface = IRP.isFnInterfaceKind();
- Function *FnScope = IRP.getAnchorScope();
- if (IsFnInterface && (!FnScope || !A.isFunctionIPOAmendable(*FnScope))) {
- indicatePessimisticFixpoint();
- return;
- }
-
if (Instruction *CtxI = getCtxI())
followUsesInMBEC(*this, A, getState(), *CtxI);
}
@@ -4894,17 +4930,24 @@ struct AADereferenceableImpl : AADereferenceable {
/// See AbstractAttribute::manifest(...).
ChangeStatus manifest(Attributor &A) override {
ChangeStatus Change = AADereferenceable::manifest(A);
- if (isAssumedNonNull() && hasAttr(Attribute::DereferenceableOrNull)) {
- removeAttrs({Attribute::DereferenceableOrNull});
+ bool IsKnownNonNull;
+ bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, getIRPosition(), DepClassTy::NONE, IsKnownNonNull);
+ if (IsAssumedNonNull &&
+ A.hasAttr(getIRPosition(), Attribute::DereferenceableOrNull)) {
+ A.removeAttrs(getIRPosition(), {Attribute::DereferenceableOrNull});
return ChangeStatus::CHANGED;
}
return Change;
}
- void getDeducedAttributes(LLVMContext &Ctx,
+ void getDeducedAttributes(Attributor &A, LLVMContext &Ctx,
SmallVectorImpl<Attribute> &Attrs) const override {
// TODO: Add *_globally support
- if (isAssumedNonNull())
+ bool IsKnownNonNull;
+ bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, getIRPosition(), DepClassTy::NONE, IsKnownNonNull);
+ if (IsAssumedNonNull)
Attrs.emplace_back(Attribute::getWithDereferenceableBytes(
Ctx, getAssumedDereferenceableBytes()));
else
@@ -4913,14 +4956,20 @@ struct AADereferenceableImpl : AADereferenceable {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
if (!getAssumedDereferenceableBytes())
return "unknown-dereferenceable";
+ bool IsKnownNonNull;
+ bool IsAssumedNonNull = false;
+ if (A)
+ IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>(
+ *A, this, getIRPosition(), DepClassTy::NONE, IsKnownNonNull);
return std::string("dereferenceable") +
- (isAssumedNonNull() ? "" : "_or_null") +
+ (IsAssumedNonNull ? "" : "_or_null") +
(isAssumedGlobal() ? "_globally" : "") + "<" +
std::to_string(getKnownDereferenceableBytes()) + "-" +
- std::to_string(getAssumedDereferenceableBytes()) + ">";
+ std::to_string(getAssumedDereferenceableBytes()) + ">" +
+ (!A ? " [non-null is unknown]" : "");
}
};
@@ -4931,7 +4980,6 @@ struct AADereferenceableFloating : AADereferenceableImpl {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
-
bool Stripped;
bool UsedAssumedInformation = false;
SmallVector<AA::ValueAndContext> Values;
@@ -4955,10 +5003,10 @@ struct AADereferenceableFloating : AADereferenceableImpl {
A, *this, &V, DL, Offset, /* GetMinOffset */ false,
/* AllowNonInbounds */ true);
- const auto &AA = A.getAAFor<AADereferenceable>(
+ const auto *AA = A.getAAFor<AADereferenceable>(
*this, IRPosition::value(*Base), DepClassTy::REQUIRED);
int64_t DerefBytes = 0;
- if (!Stripped && this == &AA) {
+ if (!AA || (!Stripped && this == AA)) {
// Use IR information if we did not strip anything.
// TODO: track globally.
bool CanBeNull, CanBeFreed;
@@ -4966,7 +5014,7 @@ struct AADereferenceableFloating : AADereferenceableImpl {
Base->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
T.GlobalState.indicatePessimisticFixpoint();
} else {
- const DerefState &DS = AA.getState();
+ const DerefState &DS = AA->getState();
DerefBytes = DS.DerefBytesState.getAssumed();
T.GlobalState &= DS.GlobalState;
}
@@ -4981,7 +5029,7 @@ struct AADereferenceableFloating : AADereferenceableImpl {
T.takeAssumedDerefBytesMinimum(
std::max(int64_t(0), DerefBytes - OffsetSExt));
- if (this == &AA) {
+ if (this == AA) {
if (!Stripped) {
// If nothing was stripped IR information is all we got.
T.takeKnownDerefBytesMaximum(
@@ -5016,9 +5064,10 @@ struct AADereferenceableFloating : AADereferenceableImpl {
/// Dereferenceable attribute for a return value.
struct AADereferenceableReturned final
: AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl> {
+ using Base =
+ AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl>;
AADereferenceableReturned(const IRPosition &IRP, Attributor &A)
- : AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl>(
- IRP, A) {}
+ : Base(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
@@ -5095,8 +5144,9 @@ static unsigned getKnownAlignForUse(Attributor &A, AAAlign &QueryingAA,
IRPosition IRP = IRPosition::callsite_argument(*CB, ArgNo);
// As long as we only use known information there is no need to track
// dependences here.
- auto &AlignAA = A.getAAFor<AAAlign>(QueryingAA, IRP, DepClassTy::NONE);
- MA = MaybeAlign(AlignAA.getKnownAlign());
+ auto *AlignAA = A.getAAFor<AAAlign>(QueryingAA, IRP, DepClassTy::NONE);
+ if (AlignAA)
+ MA = MaybeAlign(AlignAA->getKnownAlign());
}
const DataLayout &DL = A.getDataLayout();
@@ -5122,7 +5172,7 @@ static unsigned getKnownAlignForUse(Attributor &A, AAAlign &QueryingAA,
// gcd(Offset, Alignment) is an alignment.
uint32_t gcd = std::gcd(uint32_t(abs((int32_t)Offset)), Alignment);
- Alignment = llvm::PowerOf2Floor(gcd);
+ Alignment = llvm::bit_floor(gcd);
}
}
@@ -5135,20 +5185,13 @@ struct AAAlignImpl : AAAlign {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
SmallVector<Attribute, 4> Attrs;
- getAttrs({Attribute::Alignment}, Attrs);
+ A.getAttrs(getIRPosition(), {Attribute::Alignment}, Attrs);
for (const Attribute &Attr : Attrs)
takeKnownMaximum(Attr.getValueAsInt());
Value &V = *getAssociatedValue().stripPointerCasts();
takeKnownMaximum(V.getPointerAlignment(A.getDataLayout()).value());
- if (getIRPosition().isFnInterfaceKind() &&
- (!getAnchorScope() ||
- !A.isFunctionIPOAmendable(*getAssociatedFunction()))) {
- indicatePessimisticFixpoint();
- return;
- }
-
if (Instruction *CtxI = getCtxI())
followUsesInMBEC(*this, A, getState(), *CtxI);
}
@@ -5193,7 +5236,7 @@ struct AAAlignImpl : AAAlign {
// to avoid making the alignment explicit if it did not improve.
/// See AbstractAttribute::getDeducedAttributes
- void getDeducedAttributes(LLVMContext &Ctx,
+ void getDeducedAttributes(Attributor &A, LLVMContext &Ctx,
SmallVectorImpl<Attribute> &Attrs) const override {
if (getAssumedAlign() > 1)
Attrs.emplace_back(
@@ -5213,7 +5256,7 @@ struct AAAlignImpl : AAAlign {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return "align<" + std::to_string(getKnownAlign().value()) + "-" +
std::to_string(getAssumedAlign().value()) + ">";
}
@@ -5243,9 +5286,9 @@ struct AAAlignFloating : AAAlignImpl {
auto VisitValueCB = [&](Value &V) -> bool {
if (isa<UndefValue>(V) || isa<ConstantPointerNull>(V))
return true;
- const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V),
+ const auto *AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V),
DepClassTy::REQUIRED);
- if (!Stripped && this == &AA) {
+ if (!AA || (!Stripped && this == AA)) {
int64_t Offset;
unsigned Alignment = 1;
if (const Value *Base =
@@ -5258,7 +5301,7 @@ struct AAAlignFloating : AAAlignImpl {
uint32_t gcd =
std::gcd(uint32_t(abs((int32_t)Offset)), uint32_t(PA.value()));
- Alignment = llvm::PowerOf2Floor(gcd);
+ Alignment = llvm::bit_floor(gcd);
} else {
Alignment = V.getPointerAlignment(DL).value();
}
@@ -5267,7 +5310,7 @@ struct AAAlignFloating : AAAlignImpl {
T.indicatePessimisticFixpoint();
} else {
// Use abstract attribute information.
- const AAAlign::StateType &DS = AA.getState();
+ const AAAlign::StateType &DS = AA->getState();
T ^= DS;
}
return T.isValidState();
@@ -5293,14 +5336,6 @@ struct AAAlignReturned final
using Base = AAReturnedFromReturnedValues<AAAlign, AAAlignImpl>;
AAAlignReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- Base::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(aligned) }
};
@@ -5351,9 +5386,10 @@ struct AAAlignCallSiteArgument final : AAAlignFloating {
if (Argument *Arg = getAssociatedArgument()) {
// We only take known information from the argument
// so we do not need to track a dependence.
- const auto &ArgAlignAA = A.getAAFor<AAAlign>(
+ const auto *ArgAlignAA = A.getAAFor<AAAlign>(
*this, IRPosition::argument(*Arg), DepClassTy::NONE);
- takeKnownMaximum(ArgAlignAA.getKnownAlign().value());
+ if (ArgAlignAA)
+ takeKnownMaximum(ArgAlignAA->getKnownAlign().value());
}
return Changed;
}
@@ -5369,14 +5405,6 @@ struct AAAlignCallSiteReturned final
AAAlignCallSiteReturned(const IRPosition &IRP, Attributor &A)
: Base(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- Base::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(align); }
};
@@ -5389,14 +5417,14 @@ struct AANoReturnImpl : public AANoReturn {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- AANoReturn::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::NoReturn>(
+ A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "noreturn" : "may-return";
}
@@ -5425,17 +5453,6 @@ struct AANoReturnCallSite final : AANoReturnImpl {
AANoReturnCallSite(const IRPosition &IRP, Attributor &A)
: AANoReturnImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AANoReturnImpl::initialize(A);
- if (Function *F = getAssociatedFunction()) {
- const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos, DepClassTy::REQUIRED);
- if (!FnAA.isAssumedNoReturn())
- indicatePessimisticFixpoint();
- }
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -5444,8 +5461,11 @@ struct AANoReturnCallSite final : AANoReturnImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
+ bool IsKnownNoReturn;
+ if (!AA::hasAssumedIRAttr<Attribute::NoReturn>(
+ A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoReturn))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
}
/// See AbstractAttribute::trackStatistics()
@@ -5477,6 +5497,15 @@ struct AAInstanceInfoImpl : public AAInstanceInfo {
indicateOptimisticFixpoint();
return;
}
+ if (auto *I = dyn_cast<Instruction>(&V)) {
+ const auto *CI =
+ A.getInfoCache().getAnalysisResultForFunction<CycleAnalysis>(
+ *I->getFunction());
+ if (mayBeInCycle(CI, I, /* HeaderOnly */ false)) {
+ indicatePessimisticFixpoint();
+ return;
+ }
+ }
}
/// See AbstractAttribute::updateImpl(...).
@@ -5495,9 +5524,10 @@ struct AAInstanceInfoImpl : public AAInstanceInfo {
if (!Scope)
return indicateOptimisticFixpoint();
- auto &NoRecurseAA = A.getAAFor<AANoRecurse>(
- *this, IRPosition::function(*Scope), DepClassTy::OPTIONAL);
- if (NoRecurseAA.isAssumedNoRecurse())
+ bool IsKnownNoRecurse;
+ if (AA::hasAssumedIRAttr<Attribute::NoRecurse>(
+ A, this, IRPosition::function(*Scope), DepClassTy::OPTIONAL,
+ IsKnownNoRecurse))
return Changed;
auto UsePred = [&](const Use &U, bool &Follow) {
@@ -5514,15 +5544,16 @@ struct AAInstanceInfoImpl : public AAInstanceInfo {
if (auto *CB = dyn_cast<CallBase>(UserI)) {
       // This check does not guarantee uniqueness, but for now it ensures that
       // we cannot end up with two versions of \p U thinking it was one.
- if (!CB->getCalledFunction() ||
- !CB->getCalledFunction()->hasLocalLinkage())
+ auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand());
+ if (!Callee || !Callee->hasLocalLinkage())
return true;
if (!CB->isArgOperand(&U))
return false;
- const auto &ArgInstanceInfoAA = A.getAAFor<AAInstanceInfo>(
+ const auto *ArgInstanceInfoAA = A.getAAFor<AAInstanceInfo>(
*this, IRPosition::callsite_argument(*CB, CB->getArgOperandNo(&U)),
DepClassTy::OPTIONAL);
- if (!ArgInstanceInfoAA.isAssumedUniqueForAnalysis())
+ if (!ArgInstanceInfoAA ||
+ !ArgInstanceInfoAA->isAssumedUniqueForAnalysis())
return false;
// If this call base might reach the scope again we might forward the
// argument back here. This is very conservative.
@@ -5554,7 +5585,7 @@ struct AAInstanceInfoImpl : public AAInstanceInfo {
}
/// See AbstractState::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return isAssumedUniqueForAnalysis() ? "<unique [fAa]>" : "<unknown>";
}
@@ -5589,9 +5620,11 @@ struct AAInstanceInfoCallSiteArgument final : AAInstanceInfoImpl {
if (!Arg)
return indicatePessimisticFixpoint();
const IRPosition &ArgPos = IRPosition::argument(*Arg);
- auto &ArgAA =
+ auto *ArgAA =
A.getAAFor<AAInstanceInfo>(*this, ArgPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), ArgAA.getState());
+ if (!ArgAA)
+ return indicatePessimisticFixpoint();
+ return clampStateAndIndicateChange(getState(), ArgAA->getState());
}
};
@@ -5621,6 +5654,95 @@ struct AAInstanceInfoCallSiteReturned final : AAInstanceInfoFloating {
} // namespace
/// ----------------------- Variable Capturing ---------------------------------
+bool AANoCapture::isImpliedByIR(Attributor &A, const IRPosition &IRP,
+ Attribute::AttrKind ImpliedAttributeKind,
+ bool IgnoreSubsumingPositions) {
+ assert(ImpliedAttributeKind == Attribute::NoCapture &&
+ "Unexpected attribute kind");
+ Value &V = IRP.getAssociatedValue();
+ if (!IRP.isArgumentPosition())
+ return V.use_empty();
+
+ // You cannot "capture" null in the default address space.
+ if (isa<UndefValue>(V) || (isa<ConstantPointerNull>(V) &&
+ V.getType()->getPointerAddressSpace() == 0)) {
+ return true;
+ }
+
+ if (A.hasAttr(IRP, {Attribute::NoCapture},
+ /* IgnoreSubsumingPositions */ true, Attribute::NoCapture))
+ return true;
+
+ if (IRP.getPositionKind() == IRP_CALL_SITE_ARGUMENT)
+ if (Argument *Arg = IRP.getAssociatedArgument())
+ if (A.hasAttr(IRPosition::argument(*Arg),
+ {Attribute::NoCapture, Attribute::ByVal},
+ /* IgnoreSubsumingPositions */ true)) {
+ A.manifestAttrs(IRP,
+ Attribute::get(V.getContext(), Attribute::NoCapture));
+ return true;
+ }
+
+ if (const Function *F = IRP.getAssociatedFunction()) {
+ // Check what state the associated function can actually capture.
+ AANoCapture::StateType State;
+ determineFunctionCaptureCapabilities(IRP, *F, State);
+ if (State.isKnown(NO_CAPTURE)) {
+ A.manifestAttrs(IRP,
+ Attribute::get(V.getContext(), Attribute::NoCapture));
+ return true;
+ }
+ }
+
+ return false;
+}
+
+/// Set the NOT_CAPTURED_IN_MEM and NOT_CAPTURED_IN_RET bits in \p Known
+/// depending on the ability of the function associated with \p IRP to capture
+/// state in memory and through "returning/throwing", respectively.
+void AANoCapture::determineFunctionCaptureCapabilities(const IRPosition &IRP,
+ const Function &F,
+ BitIntegerState &State) {
+ // TODO: Once we have memory behavior attributes we should use them here.
+
+ // If we know we cannot communicate or write to memory, we do not care about
+ // ptr2int anymore.
+ bool ReadOnly = F.onlyReadsMemory();
+ bool NoThrow = F.doesNotThrow();
+ bool IsVoidReturn = F.getReturnType()->isVoidTy();
+ if (ReadOnly && NoThrow && IsVoidReturn) {
+ State.addKnownBits(NO_CAPTURE);
+ return;
+ }
+
+  // A function cannot capture state in memory if it only reads memory; it can,
+  // however, return/throw state, and that state might be influenced by the
+  // pointer value, e.g., loading from a returned pointer might reveal a bit.
+ if (ReadOnly)
+ State.addKnownBits(NOT_CAPTURED_IN_MEM);
+
+  // A function cannot communicate state back if it does not throw
+  // exceptions and does not return values.
+ if (NoThrow && IsVoidReturn)
+ State.addKnownBits(NOT_CAPTURED_IN_RET);
+
+ // Check existing "returned" attributes.
+ int ArgNo = IRP.getCalleeArgNo();
+ if (!NoThrow || ArgNo < 0 ||
+ !F.getAttributes().hasAttrSomewhere(Attribute::Returned))
+ return;
+
+ for (unsigned U = 0, E = F.arg_size(); U < E; ++U)
+ if (F.hasParamAttribute(U, Attribute::Returned)) {
+ if (U == unsigned(ArgNo))
+ State.removeAssumedBits(NOT_CAPTURED_IN_RET);
+ else if (ReadOnly)
+ State.addKnownBits(NO_CAPTURE);
+ else
+ State.addKnownBits(NOT_CAPTURED_IN_RET);
+ break;
+ }
+}
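// [Editor's note: the sketch below is illustrative and not part of the patch.]
// determineFunctionCaptureCapabilities above seeds the capture state purely
// from function-level facts: readonly + nothrow + void return means no capture
// at all; readonly alone rules out capturing through memory; nothrow + void
// return rules out capturing through the return/unwind path; and a 'returned'
// parameter keeps the return path open only for that one argument. The
// standalone sketch below models just the known bits, with assumed names:
#include <cstdint>

enum CaptureBits : uint8_t {
  NOT_CAPTURED_IN_MEM = 1 << 0,
  NOT_CAPTURED_IN_RET = 1 << 1,
  NO_CAPTURE = NOT_CAPTURED_IN_MEM | NOT_CAPTURED_IN_RET,
};

// Known capture bits for argument \p ArgNo of a function with the given
// properties; ReturnedArgNo is -1 if no parameter carries 'returned'.
static uint8_t knownCaptureBits(bool ReadOnly, bool NoThrow, bool IsVoidReturn,
                                int ArgNo, int ReturnedArgNo) {
  if (ReadOnly && NoThrow && IsVoidReturn)
    return NO_CAPTURE;                  // Nothing can escape at all.
  uint8_t Known = 0;
  if (ReadOnly)
    Known |= NOT_CAPTURED_IN_MEM;       // No writes, so no escape into memory.
  if (NoThrow && IsVoidReturn)
    Known |= NOT_CAPTURED_IN_RET;       // Nothing flows out via return/unwind.
  // With nothrow and a 'returned' parameter, the return path carries exactly
  // that one argument; every other argument cannot escape through it.
  if (NoThrow && ReturnedArgNo >= 0 && ArgNo >= 0 && ArgNo != ReturnedArgNo)
    Known |= ReadOnly ? NO_CAPTURE : NOT_CAPTURED_IN_RET;
  return Known;
}
// [End of editor's sketch.]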
namespace {
/// A class to hold the state of for no-capture attributes.
@@ -5629,39 +5751,17 @@ struct AANoCaptureImpl : public AANoCapture {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- if (hasAttr(getAttrKind(), /* IgnoreSubsumingPositions */ true)) {
- indicateOptimisticFixpoint();
- return;
- }
- Function *AnchorScope = getAnchorScope();
- if (isFnInterfaceKind() &&
- (!AnchorScope || !A.isFunctionIPOAmendable(*AnchorScope))) {
- indicatePessimisticFixpoint();
- return;
- }
-
- // You cannot "capture" null in the default address space.
- if (isa<ConstantPointerNull>(getAssociatedValue()) &&
- getAssociatedValue().getType()->getPointerAddressSpace() == 0) {
- indicateOptimisticFixpoint();
- return;
- }
-
- const Function *F =
- isArgumentPosition() ? getAssociatedFunction() : AnchorScope;
-
- // Check what state the associated function can actually capture.
- if (F)
- determineFunctionCaptureCapabilities(getIRPosition(), *F, *this);
- else
- indicatePessimisticFixpoint();
+ bool IsKnown;
+ assert(!AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown));
+ (void)IsKnown;
}
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override;
/// see AbstractAttribute::isAssumedNoCaptureMaybeReturned(...).
- void getDeducedAttributes(LLVMContext &Ctx,
+ void getDeducedAttributes(Attributor &A, LLVMContext &Ctx,
SmallVectorImpl<Attribute> &Attrs) const override {
if (!isAssumedNoCaptureMaybeReturned())
return;
@@ -5674,51 +5774,8 @@ struct AANoCaptureImpl : public AANoCapture {
}
}
- /// Set the NOT_CAPTURED_IN_MEM and NOT_CAPTURED_IN_RET bits in \p Known
- /// depending on the ability of the function associated with \p IRP to capture
- /// state in memory and through "returning/throwing", respectively.
- static void determineFunctionCaptureCapabilities(const IRPosition &IRP,
- const Function &F,
- BitIntegerState &State) {
- // TODO: Once we have memory behavior attributes we should use them here.
-
- // If we know we cannot communicate or write to memory, we do not care about
- // ptr2int anymore.
- if (F.onlyReadsMemory() && F.doesNotThrow() &&
- F.getReturnType()->isVoidTy()) {
- State.addKnownBits(NO_CAPTURE);
- return;
- }
-
- // A function cannot capture state in memory if it only reads memory, it can
- // however return/throw state and the state might be influenced by the
- // pointer value, e.g., loading from a returned pointer might reveal a bit.
- if (F.onlyReadsMemory())
- State.addKnownBits(NOT_CAPTURED_IN_MEM);
-
- // A function cannot communicate state back if it does not through
- // exceptions and doesn not return values.
- if (F.doesNotThrow() && F.getReturnType()->isVoidTy())
- State.addKnownBits(NOT_CAPTURED_IN_RET);
-
- // Check existing "returned" attributes.
- int ArgNo = IRP.getCalleeArgNo();
- if (F.doesNotThrow() && ArgNo >= 0) {
- for (unsigned u = 0, e = F.arg_size(); u < e; ++u)
- if (F.hasParamAttribute(u, Attribute::Returned)) {
- if (u == unsigned(ArgNo))
- State.removeAssumedBits(NOT_CAPTURED_IN_RET);
- else if (F.onlyReadsMemory())
- State.addKnownBits(NO_CAPTURE);
- else
- State.addKnownBits(NOT_CAPTURED_IN_RET);
- break;
- }
- }
- }
-
/// See AbstractState::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
if (isKnownNoCapture())
return "known not-captured";
if (isAssumedNoCapture())
@@ -5771,12 +5828,15 @@ struct AANoCaptureImpl : public AANoCapture {
const IRPosition &CSArgPos = IRPosition::callsite_argument(*CB, ArgNo);
// If we have a abstract no-capture attribute for the argument we can use
// it to justify a non-capture attribute here. This allows recursion!
- auto &ArgNoCaptureAA =
- A.getAAFor<AANoCapture>(*this, CSArgPos, DepClassTy::REQUIRED);
- if (ArgNoCaptureAA.isAssumedNoCapture())
+ bool IsKnownNoCapture;
+ const AANoCapture *ArgNoCaptureAA = nullptr;
+ bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, CSArgPos, DepClassTy::REQUIRED, IsKnownNoCapture, false,
+ &ArgNoCaptureAA);
+ if (IsAssumedNoCapture)
return isCapturedIn(State, /* Memory */ false, /* Integer */ false,
/* Return */ false);
- if (ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) {
+ if (ArgNoCaptureAA && ArgNoCaptureAA->isAssumedNoCaptureMaybeReturned()) {
Follow = true;
return isCapturedIn(State, /* Memory */ false, /* Integer */ false,
/* Return */ false);
@@ -5830,37 +5890,35 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) {
// TODO: we could do this in a more sophisticated way inside
// AAReturnedValues, e.g., track all values that escape through returns
// directly somehow.
- auto CheckReturnedArgs = [&](const AAReturnedValues &RVAA) {
- if (!RVAA.getState().isValidState())
+ auto CheckReturnedArgs = [&](bool &UsedAssumedInformation) {
+ SmallVector<AA::ValueAndContext> Values;
+ if (!A.getAssumedSimplifiedValues(IRPosition::returned(*F), this, Values,
+ AA::ValueScope::Intraprocedural,
+ UsedAssumedInformation))
return false;
bool SeenConstant = false;
- for (const auto &It : RVAA.returned_values()) {
- if (isa<Constant>(It.first)) {
+ for (const AA::ValueAndContext &VAC : Values) {
+ if (isa<Constant>(VAC.getValue())) {
if (SeenConstant)
return false;
SeenConstant = true;
- } else if (!isa<Argument>(It.first) ||
- It.first == getAssociatedArgument())
+ } else if (!isa<Argument>(VAC.getValue()) ||
+ VAC.getValue() == getAssociatedArgument())
return false;
}
return true;
};
- const auto &NoUnwindAA =
- A.getAAFor<AANoUnwind>(*this, FnPos, DepClassTy::OPTIONAL);
- if (NoUnwindAA.isAssumedNoUnwind()) {
+ bool IsKnownNoUnwind;
+ if (AA::hasAssumedIRAttr<Attribute::NoUnwind>(
+ A, this, FnPos, DepClassTy::OPTIONAL, IsKnownNoUnwind)) {
bool IsVoidTy = F->getReturnType()->isVoidTy();
- const AAReturnedValues *RVAA =
- IsVoidTy ? nullptr
- : &A.getAAFor<AAReturnedValues>(*this, FnPos,
-
- DepClassTy::OPTIONAL);
- if (IsVoidTy || CheckReturnedArgs(*RVAA)) {
+ bool UsedAssumedInformation = false;
+ if (IsVoidTy || CheckReturnedArgs(UsedAssumedInformation)) {
T.addKnownBits(NOT_CAPTURED_IN_RET);
if (T.isKnown(NOT_CAPTURED_IN_MEM))
return ChangeStatus::UNCHANGED;
- if (NoUnwindAA.isKnownNoUnwind() &&
- (IsVoidTy || RVAA->getState().isAtFixpoint())) {
+ if (IsKnownNoUnwind && (IsVoidTy || !UsedAssumedInformation)) {
addKnownBits(NOT_CAPTURED_IN_RET);
if (isKnown(NOT_CAPTURED_IN_MEM))
return indicateOptimisticFixpoint();
@@ -5869,9 +5927,9 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) {
}
auto IsDereferenceableOrNull = [&](Value *O, const DataLayout &DL) {
- const auto &DerefAA = A.getAAFor<AADereferenceable>(
+ const auto *DerefAA = A.getAAFor<AADereferenceable>(
*this, IRPosition::value(*O), DepClassTy::OPTIONAL);
- return DerefAA.getAssumedDereferenceableBytes();
+ return DerefAA && DerefAA->getAssumedDereferenceableBytes();
};
auto UseCheck = [&](const Use &U, bool &Follow) -> bool {
@@ -5913,14 +5971,6 @@ struct AANoCaptureCallSiteArgument final : AANoCaptureImpl {
AANoCaptureCallSiteArgument(const IRPosition &IRP, Attributor &A)
: AANoCaptureImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- if (Argument *Arg = getAssociatedArgument())
- if (Arg->hasByValAttr())
- indicateOptimisticFixpoint();
- AANoCaptureImpl::initialize(A);
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -5931,8 +5981,15 @@ struct AANoCaptureCallSiteArgument final : AANoCaptureImpl {
if (!Arg)
return indicatePessimisticFixpoint();
const IRPosition &ArgPos = IRPosition::argument(*Arg);
- auto &ArgAA = A.getAAFor<AANoCapture>(*this, ArgPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), ArgAA.getState());
+ bool IsKnownNoCapture;
+ const AANoCapture *ArgAA = nullptr;
+ if (AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, ArgPos, DepClassTy::REQUIRED, IsKnownNoCapture, false,
+ &ArgAA))
+ return ChangeStatus::UNCHANGED;
+ if (!ArgAA || !ArgAA->isAssumedNoCaptureMaybeReturned())
+ return indicatePessimisticFixpoint();
+ return clampStateAndIndicateChange(getState(), ArgAA->getState());
}
/// See AbstractAttribute::trackStatistics()
@@ -6023,7 +6080,7 @@ struct AAValueSimplifyImpl : AAValueSimplify {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
LLVM_DEBUG({
dbgs() << "SAV: " << (bool)SimplifiedAssociatedValue << " ";
if (SimplifiedAssociatedValue && *SimplifiedAssociatedValue)
@@ -6156,19 +6213,21 @@ struct AAValueSimplifyImpl : AAValueSimplify {
return false;
// This will also pass the call base context.
- const auto &AA =
+ const auto *AA =
A.getAAFor<AAType>(*this, getIRPosition(), DepClassTy::NONE);
+ if (!AA)
+ return false;
- std::optional<Constant *> COpt = AA.getAssumedConstant(A);
+ std::optional<Constant *> COpt = AA->getAssumedConstant(A);
if (!COpt) {
SimplifiedAssociatedValue = std::nullopt;
- A.recordDependence(AA, *this, DepClassTy::OPTIONAL);
+ A.recordDependence(*AA, *this, DepClassTy::OPTIONAL);
return true;
}
if (auto *C = *COpt) {
SimplifiedAssociatedValue = C;
- A.recordDependence(AA, *this, DepClassTy::OPTIONAL);
+ A.recordDependence(*AA, *this, DepClassTy::OPTIONAL);
return true;
}
return false;
@@ -6215,11 +6274,10 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl {
void initialize(Attributor &A) override {
AAValueSimplifyImpl::initialize(A);
- if (!getAnchorScope() || getAnchorScope()->isDeclaration())
- indicatePessimisticFixpoint();
- if (hasAttr({Attribute::InAlloca, Attribute::Preallocated,
- Attribute::StructRet, Attribute::Nest, Attribute::ByVal},
- /* IgnoreSubsumingPositions */ true))
+ if (A.hasAttr(getIRPosition(),
+ {Attribute::InAlloca, Attribute::Preallocated,
+ Attribute::StructRet, Attribute::Nest, Attribute::ByVal},
+ /* IgnoreSubsumingPositions */ true))
indicatePessimisticFixpoint();
}
@@ -6266,7 +6324,7 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl {
bool Success;
bool UsedAssumedInformation = false;
if (hasCallBaseContext() &&
- getCallBaseContext()->getCalledFunction() == Arg->getParent())
+ getCallBaseContext()->getCalledOperand() == Arg->getParent())
Success = PredForCallSite(
AbstractCallSite(&getCallBaseContext()->getCalledOperandUse()));
else
@@ -6401,10 +6459,7 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl {
void initialize(Attributor &A) override {
AAValueSimplifyImpl::initialize(A);
Function *Fn = getAssociatedFunction();
- if (!Fn) {
- indicatePessimisticFixpoint();
- return;
- }
+ assert(Fn && "Expected an associated function");
for (Argument &Arg : Fn->args()) {
if (Arg.hasReturnedAttr()) {
auto IRP = IRPosition::callsite_argument(*cast<CallBase>(getCtxI()),
@@ -6421,26 +6476,7 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
- auto Before = SimplifiedAssociatedValue;
- auto &RetAA = A.getAAFor<AAReturnedValues>(
- *this, IRPosition::function(*getAssociatedFunction()),
- DepClassTy::REQUIRED);
- auto PredForReturned =
- [&](Value &RetVal, const SmallSetVector<ReturnInst *, 4> &RetInsts) {
- bool UsedAssumedInformation = false;
- std::optional<Value *> CSRetVal =
- A.translateArgumentToCallSiteContent(
- &RetVal, *cast<CallBase>(getCtxI()), *this,
- UsedAssumedInformation);
- SimplifiedAssociatedValue = AA::combineOptionalValuesInAAValueLatice(
- SimplifiedAssociatedValue, CSRetVal, getAssociatedType());
- return SimplifiedAssociatedValue != std::optional<Value *>(nullptr);
- };
- if (!RetAA.checkForAllReturnedValuesAndReturnInsts(PredForReturned))
- if (!askSimplifiedValueForOtherAAs(A))
return indicatePessimisticFixpoint();
- return Before == SimplifiedAssociatedValue ? ChangeStatus::UNCHANGED
- : ChangeStatus ::CHANGED;
}
void trackStatistics() const override {
@@ -6581,7 +6617,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack {
SCB);
}
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
unsigned NumH2SMallocs = 0, NumInvalidMallocs = 0;
for (const auto &It : AllocationInfos) {
if (It.second->Status == AllocationInfo::INVALID)
@@ -6773,10 +6809,10 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
const Function *F = getAnchorScope();
const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F);
- const auto &LivenessAA =
+ const auto *LivenessAA =
A.getAAFor<AAIsDead>(*this, IRPosition::function(*F), DepClassTy::NONE);
- MustBeExecutedContextExplorer &Explorer =
+ MustBeExecutedContextExplorer *Explorer =
A.getInfoCache().getMustBeExecutedContextExplorer();
bool StackIsAccessibleByOtherThreads =
@@ -6813,7 +6849,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
// No need to analyze dead calls, ignore them instead.
bool UsedAssumedInformation = false;
- if (A.isAssumedDead(*DI.CB, this, &LivenessAA, UsedAssumedInformation,
+ if (A.isAssumedDead(*DI.CB, this, LivenessAA, UsedAssumedInformation,
/* CheckBBLivenessOnly */ true))
continue;
@@ -6855,9 +6891,9 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
// doesn't apply as the pointer could be shared and needs to be places in
// "shareable" memory.
if (!StackIsAccessibleByOtherThreads) {
- auto &NoSyncAA =
- A.getAAFor<AANoSync>(*this, getIRPosition(), DepClassTy::OPTIONAL);
- if (!NoSyncAA.isAssumedNoSync()) {
+ bool IsKnownNoSync;
+ if (!AA::hasAssumedIRAttr<Attribute::NoSync>(
+ A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnownNoSync)) {
LLVM_DEBUG(
dbgs() << "[H2S] found an escaping use, stack is not accessible by "
"other threads and function is not nosync:\n");
@@ -6902,7 +6938,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
return false;
}
Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode();
- if (!Explorer.findInContextOf(UniqueFree, CtxI)) {
+ if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) {
LLVM_DEBUG(
dbgs()
<< "[H2S] unique free call might not be executed with the allocation "
@@ -6938,22 +6974,21 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
}
unsigned ArgNo = CB->getArgOperandNo(&U);
+ auto CBIRP = IRPosition::callsite_argument(*CB, ArgNo);
- const auto &NoCaptureAA = A.getAAFor<AANoCapture>(
- *this, IRPosition::callsite_argument(*CB, ArgNo),
- DepClassTy::OPTIONAL);
+ bool IsKnownNoCapture;
+ bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, CBIRP, DepClassTy::OPTIONAL, IsKnownNoCapture);
// If a call site argument use is nofree, we are fine.
- const auto &ArgNoFreeAA = A.getAAFor<AANoFree>(
- *this, IRPosition::callsite_argument(*CB, ArgNo),
- DepClassTy::OPTIONAL);
+ bool IsKnownNoFree;
+ bool IsAssumedNoFree = AA::hasAssumedIRAttr<Attribute::NoFree>(
+ A, this, CBIRP, DepClassTy::OPTIONAL, IsKnownNoFree);
- bool MaybeCaptured = !NoCaptureAA.isAssumedNoCapture();
- bool MaybeFreed = !ArgNoFreeAA.isAssumedNoFree();
- if (MaybeCaptured ||
+ if (!IsAssumedNoCapture ||
(AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared &&
- MaybeFreed)) {
- AI.HasPotentiallyFreeingUnknownUses |= MaybeFreed;
+ !IsAssumedNoFree)) {
+ AI.HasPotentiallyFreeingUnknownUses |= !IsAssumedNoFree;
// Emit a missed remark if this is missed OpenMP globalization.
auto Remark = [&](OptimizationRemarkMissed ORM) {
@@ -6984,7 +7019,14 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
ValidUsesOnly = false;
return true;
};
- if (!A.checkForAllUses(Pred, *this, *AI.CB))
+ if (!A.checkForAllUses(Pred, *this, *AI.CB, /* CheckBBLivenessOnly */ false,
+ DepClassTy::OPTIONAL, /* IgnoreDroppableUses */ true,
+ [&](const Use &OldU, const Use &NewU) {
+ auto *SI = dyn_cast<StoreInst>(OldU.getUser());
+ return !SI || StackIsAccessibleByOtherThreads ||
+ AA::isAssumedThreadLocalObject(
+ A, *SI->getPointerOperand(), *this);
+ }))
return false;
return ValidUsesOnly;
};
@@ -7018,7 +7060,8 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
}
std::optional<APInt> Size = getSize(A, *this, AI);
- if (MaxHeapToStackSize != -1) {
+ if (AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared &&
+ MaxHeapToStackSize != -1) {
if (!Size || Size->ugt(MaxHeapToStackSize)) {
LLVM_DEBUG({
if (!Size)
@@ -7078,7 +7121,8 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr {
}
/// Identify the type we can chose for a private copy of the underlying
- /// argument. None means it is not clear yet, nullptr means there is none.
+ /// argument. std::nullopt means it is not clear yet, nullptr means there is
+ /// none.
virtual std::optional<Type *> identifyPrivatizableType(Attributor &A) = 0;
/// Return a privatizable type that encloses both T0 and T1.
@@ -7098,7 +7142,7 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr {
return PrivatizableType;
}
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return isAssumedPrivatizablePtr() ? "[priv]" : "[no-priv]";
}
@@ -7118,7 +7162,8 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
// rewrite them), there is no need to check them explicitly.
bool UsedAssumedInformation = false;
SmallVector<Attribute, 1> Attrs;
- getAttrs({Attribute::ByVal}, Attrs, /* IgnoreSubsumingPositions */ true);
+ A.getAttrs(getIRPosition(), {Attribute::ByVal}, Attrs,
+ /* IgnoreSubsumingPositions */ true);
if (!Attrs.empty() &&
A.checkForAllCallSites([](AbstractCallSite ACS) { return true; }, *this,
true, UsedAssumedInformation))
@@ -7141,9 +7186,11 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
return false;
// Check that all call sites agree on a type.
- auto &PrivCSArgAA =
+ auto *PrivCSArgAA =
A.getAAFor<AAPrivatizablePtr>(*this, ACSArgPos, DepClassTy::REQUIRED);
- std::optional<Type *> CSTy = PrivCSArgAA.getPrivatizableType();
+ if (!PrivCSArgAA)
+ return false;
+ std::optional<Type *> CSTy = PrivCSArgAA->getPrivatizableType();
LLVM_DEBUG({
dbgs() << "[AAPrivatizablePtr] ACSPos: " << ACSArgPos << ", CSTy: ";
@@ -7191,7 +7238,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
DepClassTy::OPTIONAL);
// Avoid arguments with padding for now.
- if (!getIRPosition().hasAttr(Attribute::ByVal) &&
+ if (!A.hasAttr(getIRPosition(), Attribute::ByVal) &&
!isDenselyPacked(*PrivatizableType, A.getInfoCache().getDL())) {
LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] Padding detected\n");
return indicatePessimisticFixpoint();
@@ -7216,7 +7263,9 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
auto CallSiteCheck = [&](AbstractCallSite ACS) {
CallBase *CB = ACS.getInstruction();
return TTI->areTypesABICompatible(
- CB->getCaller(), CB->getCalledFunction(), ReplacementTypes);
+ CB->getCaller(),
+ dyn_cast_if_present<Function>(CB->getCalledOperand()),
+ ReplacementTypes);
};
bool UsedAssumedInformation = false;
if (!A.checkForAllCallSites(CallSiteCheck, *this, true,
@@ -7264,10 +7313,10 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
if (CBArgNo != int(ArgNo))
continue;
- const auto &CBArgPrivAA = A.getAAFor<AAPrivatizablePtr>(
+ const auto *CBArgPrivAA = A.getAAFor<AAPrivatizablePtr>(
*this, IRPosition::argument(CBArg), DepClassTy::REQUIRED);
- if (CBArgPrivAA.isValidState()) {
- auto CBArgPrivTy = CBArgPrivAA.getPrivatizableType();
+ if (CBArgPrivAA && CBArgPrivAA->isValidState()) {
+ auto CBArgPrivTy = CBArgPrivAA->getPrivatizableType();
if (!CBArgPrivTy)
continue;
if (*CBArgPrivTy == PrivatizableType)
@@ -7298,23 +7347,23 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
assert(DCArgNo >= 0 && unsigned(DCArgNo) < DC->arg_size() &&
"Expected a direct call operand for callback call operand");
+ Function *DCCallee =
+ dyn_cast_if_present<Function>(DC->getCalledOperand());
LLVM_DEBUG({
dbgs() << "[AAPrivatizablePtr] Argument " << *Arg
<< " check if be privatized in the context of its parent ("
<< Arg->getParent()->getName()
<< ")\n[AAPrivatizablePtr] because it is an argument in a "
"direct call of ("
- << DCArgNo << "@" << DC->getCalledFunction()->getName()
- << ").\n";
+ << DCArgNo << "@" << DCCallee->getName() << ").\n";
});
- Function *DCCallee = DC->getCalledFunction();
if (unsigned(DCArgNo) < DCCallee->arg_size()) {
- const auto &DCArgPrivAA = A.getAAFor<AAPrivatizablePtr>(
+ const auto *DCArgPrivAA = A.getAAFor<AAPrivatizablePtr>(
*this, IRPosition::argument(*DCCallee->getArg(DCArgNo)),
DepClassTy::REQUIRED);
- if (DCArgPrivAA.isValidState()) {
- auto DCArgPrivTy = DCArgPrivAA.getPrivatizableType();
+ if (DCArgPrivAA && DCArgPrivAA->isValidState()) {
+ auto DCArgPrivTy = DCArgPrivAA->getPrivatizableType();
if (!DCArgPrivTy)
return true;
if (*DCArgPrivTy == PrivatizableType)
@@ -7328,7 +7377,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
<< Arg->getParent()->getName()
<< ")\n[AAPrivatizablePtr] because it is an argument in a "
"direct call of ("
- << ACS.getInstruction()->getCalledFunction()->getName()
+ << ACS.getInstruction()->getCalledOperand()->getName()
<< ").\n[AAPrivatizablePtr] for which the argument "
"privatization is not compatible.\n";
});
@@ -7479,7 +7528,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
Argument *Arg = getAssociatedArgument();
// Query AAAlign attribute for alignment of associated argument to
// determine the best alignment of loads.
- const auto &AlignAA =
+ const auto *AlignAA =
A.getAAFor<AAAlign>(*this, IRPosition::value(*Arg), DepClassTy::NONE);
// Callback to repair the associated function. A new alloca is placed at the
@@ -7510,13 +7559,13 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl {
// of the privatizable type are loaded prior to the call and passed to the
// new function version.
Attributor::ArgumentReplacementInfo::ACSRepairCBTy ACSRepairCB =
- [=, &AlignAA](const Attributor::ArgumentReplacementInfo &ARI,
- AbstractCallSite ACS,
- SmallVectorImpl<Value *> &NewArgOperands) {
+ [=](const Attributor::ArgumentReplacementInfo &ARI,
+ AbstractCallSite ACS, SmallVectorImpl<Value *> &NewArgOperands) {
// When no alignment is specified for the load instruction,
// natural alignment is assumed.
createReplacementValues(
- AlignAA.getAssumedAlign(), *PrivatizableType, ACS,
+ AlignAA ? AlignAA->getAssumedAlign() : Align(0),
+ *PrivatizableType, ACS,
ACS.getCallArgOperand(ARI.getReplacedArg().getArgNo()),
NewArgOperands);
};
@@ -7568,10 +7617,10 @@ struct AAPrivatizablePtrFloating : public AAPrivatizablePtrImpl {
if (CI->isOne())
return AI->getAllocatedType();
if (auto *Arg = dyn_cast<Argument>(Obj)) {
- auto &PrivArgAA = A.getAAFor<AAPrivatizablePtr>(
+ auto *PrivArgAA = A.getAAFor<AAPrivatizablePtr>(
*this, IRPosition::argument(*Arg), DepClassTy::REQUIRED);
- if (PrivArgAA.isAssumedPrivatizablePtr())
- return PrivArgAA.getPrivatizableType();
+ if (PrivArgAA && PrivArgAA->isAssumedPrivatizablePtr())
+ return PrivArgAA->getPrivatizableType();
}
LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] Underlying object neither valid "
@@ -7593,7 +7642,7 @@ struct AAPrivatizablePtrCallSiteArgument final
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- if (getIRPosition().hasAttr(Attribute::ByVal))
+ if (A.hasAttr(getIRPosition(), Attribute::ByVal))
indicateOptimisticFixpoint();
}
@@ -7606,15 +7655,17 @@ struct AAPrivatizablePtrCallSiteArgument final
return indicatePessimisticFixpoint();
const IRPosition &IRP = getIRPosition();
- auto &NoCaptureAA =
- A.getAAFor<AANoCapture>(*this, IRP, DepClassTy::REQUIRED);
- if (!NoCaptureAA.isAssumedNoCapture()) {
+ bool IsKnownNoCapture;
+ bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, IRP, DepClassTy::REQUIRED, IsKnownNoCapture);
+ if (!IsAssumedNoCapture) {
LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] pointer might be captured!\n");
return indicatePessimisticFixpoint();
}
- auto &NoAliasAA = A.getAAFor<AANoAlias>(*this, IRP, DepClassTy::REQUIRED);
- if (!NoAliasAA.isAssumedNoAlias()) {
+ bool IsKnownNoAlias;
+ if (!AA::hasAssumedIRAttr<Attribute::NoAlias>(
+ A, this, IRP, DepClassTy::REQUIRED, IsKnownNoAlias)) {
LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] pointer might alias!\n");
return indicatePessimisticFixpoint();
}
@@ -7679,16 +7730,16 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
intersectAssumedBits(BEST_STATE);
- getKnownStateFromValue(getIRPosition(), getState());
+ getKnownStateFromValue(A, getIRPosition(), getState());
AAMemoryBehavior::initialize(A);
}
/// Return the memory behavior information encoded in the IR for \p IRP.
- static void getKnownStateFromValue(const IRPosition &IRP,
+ static void getKnownStateFromValue(Attributor &A, const IRPosition &IRP,
BitIntegerState &State,
bool IgnoreSubsumingPositions = false) {
SmallVector<Attribute, 2> Attrs;
- IRP.getAttrs(AttrKinds, Attrs, IgnoreSubsumingPositions);
+ A.getAttrs(IRP, AttrKinds, Attrs, IgnoreSubsumingPositions);
for (const Attribute &Attr : Attrs) {
switch (Attr.getKindAsEnum()) {
case Attribute::ReadNone:
@@ -7714,7 +7765,7 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior {
}
/// See AbstractAttribute::getDeducedAttributes(...).
- void getDeducedAttributes(LLVMContext &Ctx,
+ void getDeducedAttributes(Attributor &A, LLVMContext &Ctx,
SmallVectorImpl<Attribute> &Attrs) const override {
assert(Attrs.size() == 0);
if (isAssumedReadNone())
@@ -7728,29 +7779,30 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior {
/// See AbstractAttribute::manifest(...).
ChangeStatus manifest(Attributor &A) override {
- if (hasAttr(Attribute::ReadNone, /* IgnoreSubsumingPositions */ true))
- return ChangeStatus::UNCHANGED;
-
const IRPosition &IRP = getIRPosition();
+ if (A.hasAttr(IRP, Attribute::ReadNone,
+ /* IgnoreSubsumingPositions */ true))
+ return ChangeStatus::UNCHANGED;
+
// Check if we would improve the existing attributes first.
SmallVector<Attribute, 4> DeducedAttrs;
- getDeducedAttributes(IRP.getAnchorValue().getContext(), DeducedAttrs);
+ getDeducedAttributes(A, IRP.getAnchorValue().getContext(), DeducedAttrs);
if (llvm::all_of(DeducedAttrs, [&](const Attribute &Attr) {
- return IRP.hasAttr(Attr.getKindAsEnum(),
- /* IgnoreSubsumingPositions */ true);
+ return A.hasAttr(IRP, Attr.getKindAsEnum(),
+ /* IgnoreSubsumingPositions */ true);
}))
return ChangeStatus::UNCHANGED;
// Clear existing attributes.
- IRP.removeAttrs(AttrKinds);
+ A.removeAttrs(IRP, AttrKinds);
// Use the generic manifest method.
return IRAttribute::manifest(A);
}
/// See AbstractState::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
if (isAssumedReadNone())
return "readnone";
if (isAssumedReadOnly())
@@ -7807,15 +7859,10 @@ struct AAMemoryBehaviorArgument : AAMemoryBehaviorFloating {
// TODO: Make IgnoreSubsumingPositions a property of an IRAttribute so we
// can query it when we use has/getAttr. That would allow us to reuse the
// initialize of the base class here.
- bool HasByVal =
- IRP.hasAttr({Attribute::ByVal}, /* IgnoreSubsumingPositions */ true);
- getKnownStateFromValue(IRP, getState(),
+ bool HasByVal = A.hasAttr(IRP, {Attribute::ByVal},
+ /* IgnoreSubsumingPositions */ true);
+ getKnownStateFromValue(A, IRP, getState(),
/* IgnoreSubsumingPositions */ HasByVal);
-
- // Initialize the use vector with all direct uses of the associated value.
- Argument *Arg = getAssociatedArgument();
- if (!Arg || !A.isFunctionIPOAmendable(*(Arg->getParent())))
- indicatePessimisticFixpoint();
}
ChangeStatus manifest(Attributor &A) override {
@@ -7825,10 +7872,12 @@ struct AAMemoryBehaviorArgument : AAMemoryBehaviorFloating {
// TODO: From readattrs.ll: "inalloca parameters are always
// considered written"
- if (hasAttr({Attribute::InAlloca, Attribute::Preallocated})) {
+ if (A.hasAttr(getIRPosition(),
+ {Attribute::InAlloca, Attribute::Preallocated})) {
removeKnownBits(NO_WRITES);
removeAssumedBits(NO_WRITES);
}
+ A.removeAttrs(getIRPosition(), AttrKinds);
return AAMemoryBehaviorFloating::manifest(A);
}
@@ -7874,9 +7923,11 @@ struct AAMemoryBehaviorCallSiteArgument final : AAMemoryBehaviorArgument {
// redirecting requests to the callee argument.
Argument *Arg = getAssociatedArgument();
const IRPosition &ArgPos = IRPosition::argument(*Arg);
- auto &ArgAA =
+ auto *ArgAA =
A.getAAFor<AAMemoryBehavior>(*this, ArgPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), ArgAA.getState());
+ if (!ArgAA)
+ return indicatePessimisticFixpoint();
+ return clampStateAndIndicateChange(getState(), ArgAA->getState());
}
/// See AbstractAttribute::trackStatistics()
@@ -7898,11 +7949,7 @@ struct AAMemoryBehaviorCallSiteReturned final : AAMemoryBehaviorFloating {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
AAMemoryBehaviorImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
}
-
/// See AbstractAttribute::manifest(...).
ChangeStatus manifest(Attributor &A) override {
// We do not annotate returned values.
@@ -7935,16 +7982,9 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl {
else if (isAssumedWriteOnly())
ME = MemoryEffects::writeOnly();
- // Intersect with existing memory attribute, as we currently deduce the
- // location and modref portion separately.
- MemoryEffects ExistingME = F.getMemoryEffects();
- ME &= ExistingME;
- if (ME == ExistingME)
- return ChangeStatus::UNCHANGED;
-
- return IRAttributeManifest::manifestAttrs(
- A, getIRPosition(), Attribute::getWithMemoryEffects(F.getContext(), ME),
- /*ForceReplace*/ true);
+ A.removeAttrs(getIRPosition(), AttrKinds);
+ return A.manifestAttrs(getIRPosition(),
+ Attribute::getWithMemoryEffects(F.getContext(), ME));
}
/// See AbstractAttribute::trackStatistics()
@@ -7963,14 +8003,6 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl {
AAMemoryBehaviorCallSite(const IRPosition &IRP, Attributor &A)
: AAMemoryBehaviorImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AAMemoryBehaviorImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -7979,9 +8011,11 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA =
+ auto *FnAA =
A.getAAFor<AAMemoryBehavior>(*this, FnPos, DepClassTy::REQUIRED);
- return clampStateAndIndicateChange(getState(), FnAA.getState());
+ if (!FnAA)
+ return indicatePessimisticFixpoint();
+ return clampStateAndIndicateChange(getState(), FnAA->getState());
}
/// See AbstractAttribute::manifest(...).
@@ -7996,17 +8030,9 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl {
else if (isAssumedWriteOnly())
ME = MemoryEffects::writeOnly();
- // Intersect with existing memory attribute, as we currently deduce the
- // location and modref portion separately.
- MemoryEffects ExistingME = CB.getMemoryEffects();
- ME &= ExistingME;
- if (ME == ExistingME)
- return ChangeStatus::UNCHANGED;
-
- return IRAttributeManifest::manifestAttrs(
- A, getIRPosition(),
- Attribute::getWithMemoryEffects(CB.getContext(), ME),
- /*ForceReplace*/ true);
+ A.removeAttrs(getIRPosition(), AttrKinds);
+ return A.manifestAttrs(
+ getIRPosition(), Attribute::getWithMemoryEffects(CB.getContext(), ME));
}
/// See AbstractAttribute::trackStatistics()
@@ -8030,10 +8056,12 @@ ChangeStatus AAMemoryBehaviorFunction::updateImpl(Attributor &A) {
// the local state. No further analysis is required as the other memory
// state is as optimistic as it gets.
if (const auto *CB = dyn_cast<CallBase>(&I)) {
- const auto &MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(
+ const auto *MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(
*this, IRPosition::callsite_function(*CB), DepClassTy::REQUIRED);
- intersectAssumedBits(MemBehaviorAA.getAssumed());
- return !isAtFixpoint();
+ if (MemBehaviorAA) {
+ intersectAssumedBits(MemBehaviorAA->getAssumed());
+ return !isAtFixpoint();
+ }
}
// Remove access kind modifiers if necessary.
@@ -8066,12 +8094,14 @@ ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) {
AAMemoryBehavior::base_t FnMemAssumedState =
AAMemoryBehavior::StateType::getWorstState();
if (!Arg || !Arg->hasByValAttr()) {
- const auto &FnMemAA =
+ const auto *FnMemAA =
A.getAAFor<AAMemoryBehavior>(*this, FnPos, DepClassTy::OPTIONAL);
- FnMemAssumedState = FnMemAA.getAssumed();
- S.addKnownBits(FnMemAA.getKnown());
- if ((S.getAssumed() & FnMemAA.getAssumed()) == S.getAssumed())
- return ChangeStatus::UNCHANGED;
+ if (FnMemAA) {
+ FnMemAssumedState = FnMemAA->getAssumed();
+ S.addKnownBits(FnMemAA->getKnown());
+ if ((S.getAssumed() & FnMemAA->getAssumed()) == S.getAssumed())
+ return ChangeStatus::UNCHANGED;
+ }
}
// The current assumed state used to determine a change.
@@ -8081,9 +8111,14 @@ ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) {
// it is, any information derived would be irrelevant anyway as we cannot
// check the potential aliases introduced by the capture. However, no need
// to fall back to anythign less optimistic than the function state.
- const auto &ArgNoCaptureAA =
- A.getAAFor<AANoCapture>(*this, IRP, DepClassTy::OPTIONAL);
- if (!ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) {
+ bool IsKnownNoCapture;
+ const AANoCapture *ArgNoCaptureAA = nullptr;
+ bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, IRP, DepClassTy::OPTIONAL, IsKnownNoCapture, false,
+ &ArgNoCaptureAA);
+
+ if (!IsAssumedNoCapture &&
+ (!ArgNoCaptureAA || !ArgNoCaptureAA->isAssumedNoCaptureMaybeReturned())) {
S.intersectAssumedBits(FnMemAssumedState);
return (AssumedState != getAssumed()) ? ChangeStatus::CHANGED
: ChangeStatus::UNCHANGED;
@@ -8137,9 +8172,10 @@ bool AAMemoryBehaviorFloating::followUsersOfUseIn(Attributor &A, const Use &U,
// need to check call users.
if (U.get()->getType()->isPointerTy()) {
unsigned ArgNo = CB->getArgOperandNo(&U);
- const auto &ArgNoCaptureAA = A.getAAFor<AANoCapture>(
- *this, IRPosition::callsite_argument(*CB, ArgNo), DepClassTy::OPTIONAL);
- return !ArgNoCaptureAA.isAssumedNoCapture();
+ bool IsKnownNoCapture;
+ return !AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, IRPosition::callsite_argument(*CB, ArgNo),
+ DepClassTy::OPTIONAL, IsKnownNoCapture);
}
return true;
@@ -8195,11 +8231,13 @@ void AAMemoryBehaviorFloating::analyzeUseIn(Attributor &A, const Use &U,
Pos = IRPosition::callsite_argument(*CB, CB->getArgOperandNo(&U));
else
Pos = IRPosition::callsite_function(*CB);
- const auto &MemBehaviorAA =
+ const auto *MemBehaviorAA =
A.getAAFor<AAMemoryBehavior>(*this, Pos, DepClassTy::OPTIONAL);
+ if (!MemBehaviorAA)
+ break;
// "assumed" has at most the same bits as the MemBehaviorAA assumed
// and at least "known".
- intersectAssumedBits(MemBehaviorAA.getAssumed());
+ intersectAssumedBits(MemBehaviorAA->getAssumed());
return;
}
};
@@ -8286,7 +8324,7 @@ struct AAMemoryLocationImpl : public AAMemoryLocation {
UseArgMemOnly = !AnchorFn->hasLocalLinkage();
SmallVector<Attribute, 2> Attrs;
- IRP.getAttrs({Attribute::Memory}, Attrs, IgnoreSubsumingPositions);
+ A.getAttrs(IRP, {Attribute::Memory}, Attrs, IgnoreSubsumingPositions);
for (const Attribute &Attr : Attrs) {
// TODO: We can map MemoryEffects to Attributor locations more precisely.
MemoryEffects ME = Attr.getMemoryEffects();
@@ -8304,11 +8342,10 @@ struct AAMemoryLocationImpl : public AAMemoryLocation {
else {
// Remove location information, only keep read/write info.
ME = MemoryEffects(ME.getModRef());
- IRAttributeManifest::manifestAttrs(
- A, IRP,
- Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(),
- ME),
- /*ForceReplace*/ true);
+ A.manifestAttrs(IRP,
+ Attribute::getWithMemoryEffects(
+ IRP.getAnchorValue().getContext(), ME),
+ /*ForceReplace*/ true);
}
continue;
}
@@ -8319,11 +8356,10 @@ struct AAMemoryLocationImpl : public AAMemoryLocation {
else {
// Remove location information, only keep read/write info.
ME = MemoryEffects(ME.getModRef());
- IRAttributeManifest::manifestAttrs(
- A, IRP,
- Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(),
- ME),
- /*ForceReplace*/ true);
+ A.manifestAttrs(IRP,
+ Attribute::getWithMemoryEffects(
+ IRP.getAnchorValue().getContext(), ME),
+ /*ForceReplace*/ true);
}
continue;
}
@@ -8331,7 +8367,7 @@ struct AAMemoryLocationImpl : public AAMemoryLocation {
}
/// See AbstractAttribute::getDeducedAttributes(...).
- void getDeducedAttributes(LLVMContext &Ctx,
+ void getDeducedAttributes(Attributor &A, LLVMContext &Ctx,
SmallVectorImpl<Attribute> &Attrs) const override {
// TODO: We can map Attributor locations to MemoryEffects more precisely.
assert(Attrs.size() == 0);
@@ -8359,27 +8395,13 @@ struct AAMemoryLocationImpl : public AAMemoryLocation {
const IRPosition &IRP = getIRPosition();
SmallVector<Attribute, 1> DeducedAttrs;
- getDeducedAttributes(IRP.getAnchorValue().getContext(), DeducedAttrs);
+ getDeducedAttributes(A, IRP.getAnchorValue().getContext(), DeducedAttrs);
if (DeducedAttrs.size() != 1)
return ChangeStatus::UNCHANGED;
MemoryEffects ME = DeducedAttrs[0].getMemoryEffects();
- // Intersect with existing memory attribute, as we currently deduce the
- // location and modref portion separately.
- SmallVector<Attribute, 1> ExistingAttrs;
- IRP.getAttrs({Attribute::Memory}, ExistingAttrs,
- /* IgnoreSubsumingPositions */ true);
- if (ExistingAttrs.size() == 1) {
- MemoryEffects ExistingME = ExistingAttrs[0].getMemoryEffects();
- ME &= ExistingME;
- if (ME == ExistingME)
- return ChangeStatus::UNCHANGED;
- }
-
- return IRAttributeManifest::manifestAttrs(
- A, IRP,
- Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(), ME),
- /*ForceReplace*/ true);
+ return A.manifestAttrs(IRP, Attribute::getWithMemoryEffects(
+ IRP.getAnchorValue().getContext(), ME));
}
/// See AAMemoryLocation::checkForAllAccessesToMemoryKind(...).
@@ -8492,13 +8514,16 @@ protected:
if (!Accesses)
Accesses = new (Allocator) AccessSet();
Changed |= Accesses->insert(AccessInfo{I, Ptr, AK}).second;
+ if (MLK == NO_UNKOWN_MEM)
+ MLK = NO_LOCATIONS;
State.removeAssumedBits(MLK);
}
/// Determine the underlying locations kinds for \p Ptr, e.g., globals or
/// arguments, and update the state and access map accordingly.
void categorizePtrValue(Attributor &A, const Instruction &I, const Value &Ptr,
- AAMemoryLocation::StateType &State, bool &Changed);
+ AAMemoryLocation::StateType &State, bool &Changed,
+ unsigned AccessAS = 0);
/// Used to allocate access sets.
BumpPtrAllocator &Allocator;
@@ -8506,14 +8531,24 @@ protected:
void AAMemoryLocationImpl::categorizePtrValue(
Attributor &A, const Instruction &I, const Value &Ptr,
- AAMemoryLocation::StateType &State, bool &Changed) {
+ AAMemoryLocation::StateType &State, bool &Changed, unsigned AccessAS) {
LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Categorize pointer locations for "
<< Ptr << " ["
<< getMemoryLocationsAsStr(State.getAssumed()) << "]\n");
auto Pred = [&](Value &Obj) {
+ unsigned ObjectAS = Obj.getType()->getPointerAddressSpace();
// TODO: recognize the TBAA used for constant accesses.
MemoryLocationsKind MLK = NO_LOCATIONS;
+
+ // Filter accesses to constant (GPU) memory if we have an AS at the access
+ // site or the object is known to actually have the associated AS.
+ if ((AccessAS == (unsigned)AA::GPUAddressSpace::Constant ||
+ (ObjectAS == (unsigned)AA::GPUAddressSpace::Constant &&
+ isIdentifiedObject(&Obj))) &&
+ AA::isGPU(*I.getModule()))
+ return true;
+
if (isa<UndefValue>(&Obj))
return true;
if (isa<Argument>(&Obj)) {
@@ -8537,15 +8572,16 @@ void AAMemoryLocationImpl::categorizePtrValue(
else
MLK = NO_GLOBAL_EXTERNAL_MEM;
} else if (isa<ConstantPointerNull>(&Obj) &&
- !NullPointerIsDefined(getAssociatedFunction(),
- Ptr.getType()->getPointerAddressSpace())) {
+ (!NullPointerIsDefined(getAssociatedFunction(), AccessAS) ||
+ !NullPointerIsDefined(getAssociatedFunction(), ObjectAS))) {
return true;
} else if (isa<AllocaInst>(&Obj)) {
MLK = NO_LOCAL_MEM;
} else if (const auto *CB = dyn_cast<CallBase>(&Obj)) {
- const auto &NoAliasAA = A.getAAFor<AANoAlias>(
- *this, IRPosition::callsite_returned(*CB), DepClassTy::OPTIONAL);
- if (NoAliasAA.isAssumedNoAlias())
+ bool IsKnownNoAlias;
+ if (AA::hasAssumedIRAttr<Attribute::NoAlias>(
+ A, this, IRPosition::callsite_returned(*CB), DepClassTy::OPTIONAL,
+ IsKnownNoAlias))
MLK = NO_MALLOCED_MEM;
else
MLK = NO_UNKOWN_MEM;
@@ -8556,15 +8592,15 @@ void AAMemoryLocationImpl::categorizePtrValue(
assert(MLK != NO_LOCATIONS && "No location specified!");
LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Ptr value can be categorized: "
<< Obj << " -> " << getMemoryLocationsAsStr(MLK) << "\n");
- updateStateAndAccessesMap(getState(), MLK, &I, &Obj, Changed,
+ updateStateAndAccessesMap(State, MLK, &I, &Obj, Changed,
getAccessKindFromInst(&I));
return true;
};
- const auto &AA = A.getAAFor<AAUnderlyingObjects>(
+ const auto *AA = A.getAAFor<AAUnderlyingObjects>(
*this, IRPosition::value(Ptr), DepClassTy::OPTIONAL);
- if (!AA.forallUnderlyingObjects(Pred, AA::Intraprocedural)) {
+ if (!AA || !AA->forallUnderlyingObjects(Pred, AA::Intraprocedural)) {
LLVM_DEBUG(
dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n");
updateStateAndAccessesMap(State, NO_UNKOWN_MEM, &I, nullptr, Changed,
@@ -8589,10 +8625,10 @@ void AAMemoryLocationImpl::categorizeArgumentPointerLocations(
// Skip readnone arguments.
const IRPosition &ArgOpIRP = IRPosition::callsite_argument(CB, ArgNo);
- const auto &ArgOpMemLocationAA =
+ const auto *ArgOpMemLocationAA =
A.getAAFor<AAMemoryBehavior>(*this, ArgOpIRP, DepClassTy::OPTIONAL);
- if (ArgOpMemLocationAA.isAssumedReadNone())
+ if (ArgOpMemLocationAA && ArgOpMemLocationAA->isAssumedReadNone())
continue;
// Categorize potentially accessed pointer arguments as if there was an
@@ -8613,22 +8649,27 @@ AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I,
if (auto *CB = dyn_cast<CallBase>(&I)) {
// First check if we assume any memory is access is visible.
- const auto &CBMemLocationAA = A.getAAFor<AAMemoryLocation>(
+ const auto *CBMemLocationAA = A.getAAFor<AAMemoryLocation>(
*this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Categorize call site: " << I
<< " [" << CBMemLocationAA << "]\n");
+ if (!CBMemLocationAA) {
+ updateStateAndAccessesMap(AccessedLocs, NO_UNKOWN_MEM, &I, nullptr,
+ Changed, getAccessKindFromInst(&I));
+ return NO_UNKOWN_MEM;
+ }
- if (CBMemLocationAA.isAssumedReadNone())
+ if (CBMemLocationAA->isAssumedReadNone())
return NO_LOCATIONS;
- if (CBMemLocationAA.isAssumedInaccessibleMemOnly()) {
+ if (CBMemLocationAA->isAssumedInaccessibleMemOnly()) {
updateStateAndAccessesMap(AccessedLocs, NO_INACCESSIBLE_MEM, &I, nullptr,
Changed, getAccessKindFromInst(&I));
return AccessedLocs.getAssumed();
}
uint32_t CBAssumedNotAccessedLocs =
- CBMemLocationAA.getAssumedNotAccessedLocation();
+ CBMemLocationAA->getAssumedNotAccessedLocation();
// Set the argmemonly and global bit as we handle them separately below.
uint32_t CBAssumedNotAccessedLocsNoArgMem =
@@ -8651,7 +8692,7 @@ AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I,
getAccessKindFromInst(&I));
return true;
};
- if (!CBMemLocationAA.checkForAllAccessesToMemoryKind(
+ if (!CBMemLocationAA->checkForAllAccessesToMemoryKind(
AccessPred, inverseLocation(NO_GLOBAL_MEM, false, false)))
return AccessedLocs.getWorstState();
}
@@ -8676,7 +8717,8 @@ AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I,
LLVM_DEBUG(
dbgs() << "[AAMemoryLocation] Categorize memory access with pointer: "
<< I << " [" << *Ptr << "]\n");
- categorizePtrValue(A, I, *Ptr, AccessedLocs, Changed);
+ categorizePtrValue(A, I, *Ptr, AccessedLocs, Changed,
+ Ptr->getType()->getPointerAddressSpace());
return AccessedLocs.getAssumed();
}
@@ -8695,14 +8737,14 @@ struct AAMemoryLocationFunction final : public AAMemoryLocationImpl {
/// See AbstractAttribute::updateImpl(Attributor &A).
ChangeStatus updateImpl(Attributor &A) override {
- const auto &MemBehaviorAA =
+ const auto *MemBehaviorAA =
A.getAAFor<AAMemoryBehavior>(*this, getIRPosition(), DepClassTy::NONE);
- if (MemBehaviorAA.isAssumedReadNone()) {
- if (MemBehaviorAA.isKnownReadNone())
+ if (MemBehaviorAA && MemBehaviorAA->isAssumedReadNone()) {
+ if (MemBehaviorAA->isKnownReadNone())
return indicateOptimisticFixpoint();
assert(isAssumedReadNone() &&
"AAMemoryLocation was not read-none but AAMemoryBehavior was!");
- A.recordDependence(MemBehaviorAA, *this, DepClassTy::OPTIONAL);
+ A.recordDependence(*MemBehaviorAA, *this, DepClassTy::OPTIONAL);
return ChangeStatus::UNCHANGED;
}
@@ -8747,14 +8789,6 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl {
AAMemoryLocationCallSite(const IRPosition &IRP, Attributor &A)
: AAMemoryLocationImpl(IRP, A) {}
- /// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {
- AAMemoryLocationImpl::initialize(A);
- Function *F = getAssociatedFunction();
- if (!F || F->isDeclaration())
- indicatePessimisticFixpoint();
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
@@ -8763,8 +8797,10 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl {
// redirecting requests to the callee argument.
Function *F = getAssociatedFunction();
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA =
+ auto *FnAA =
A.getAAFor<AAMemoryLocation>(*this, FnPos, DepClassTy::REQUIRED);
+ if (!FnAA)
+ return indicatePessimisticFixpoint();
bool Changed = false;
auto AccessPred = [&](const Instruction *I, const Value *Ptr,
AccessKind Kind, MemoryLocationsKind MLK) {
@@ -8772,7 +8808,7 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl {
getAccessKindFromInst(I));
return true;
};
- if (!FnAA.checkForAllAccessesToMemoryKind(AccessPred, ALL_LOCATIONS))
+ if (!FnAA->checkForAllAccessesToMemoryKind(AccessPred, ALL_LOCATIONS))
return indicatePessimisticFixpoint();
return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
}
@@ -8808,7 +8844,7 @@ struct AAValueConstantRangeImpl : AAValueConstantRange {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
std::string Str;
llvm::raw_string_ostream OS(Str);
OS << "range(" << getBitWidth() << ")<";
@@ -9023,15 +9059,6 @@ struct AAValueConstantRangeArgument final
AAValueConstantRangeArgument(const IRPosition &IRP, Attributor &A)
: Base(IRP, A) {}
- /// See AbstractAttribute::initialize(..).
- void initialize(Attributor &A) override {
- if (!getAnchorScope() || getAnchorScope()->isDeclaration()) {
- indicatePessimisticFixpoint();
- } else {
- Base::initialize(A);
- }
- }
-
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
STATS_DECLTRACK_ARG_ATTR(value_range)
@@ -9052,7 +9079,10 @@ struct AAValueConstantRangeReturned
: Base(IRP, A) {}
/// See AbstractAttribute::initialize(...).
- void initialize(Attributor &A) override {}
+ void initialize(Attributor &A) override {
+ if (!A.isFunctionIPOAmendable(*getAssociatedFunction()))
+ indicatePessimisticFixpoint();
+ }
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
@@ -9141,17 +9171,21 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy())
return false;
- auto &LHSAA = A.getAAFor<AAValueConstantRange>(
+ auto *LHSAA = A.getAAFor<AAValueConstantRange>(
*this, IRPosition::value(*LHS, getCallBaseContext()),
DepClassTy::REQUIRED);
- QuerriedAAs.push_back(&LHSAA);
- auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI);
+ if (!LHSAA)
+ return false;
+ QuerriedAAs.push_back(LHSAA);
+ auto LHSAARange = LHSAA->getAssumedConstantRange(A, CtxI);
- auto &RHSAA = A.getAAFor<AAValueConstantRange>(
+ auto *RHSAA = A.getAAFor<AAValueConstantRange>(
*this, IRPosition::value(*RHS, getCallBaseContext()),
DepClassTy::REQUIRED);
- QuerriedAAs.push_back(&RHSAA);
- auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI);
+ if (!RHSAA)
+ return false;
+ QuerriedAAs.push_back(RHSAA);
+ auto RHSAARange = RHSAA->getAssumedConstantRange(A, CtxI);
auto AssumedRange = LHSAARange.binaryOp(BinOp->getOpcode(), RHSAARange);
@@ -9184,12 +9218,14 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
if (!OpV->getType()->isIntegerTy())
return false;
- auto &OpAA = A.getAAFor<AAValueConstantRange>(
+ auto *OpAA = A.getAAFor<AAValueConstantRange>(
*this, IRPosition::value(*OpV, getCallBaseContext()),
DepClassTy::REQUIRED);
- QuerriedAAs.push_back(&OpAA);
- T.unionAssumed(
- OpAA.getAssumed().castOp(CastI->getOpcode(), getState().getBitWidth()));
+ if (!OpAA)
+ return false;
+ QuerriedAAs.push_back(OpAA);
+ T.unionAssumed(OpAA->getAssumed().castOp(CastI->getOpcode(),
+ getState().getBitWidth()));
return T.isValidState();
}
@@ -9224,16 +9260,20 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy())
return false;
- auto &LHSAA = A.getAAFor<AAValueConstantRange>(
+ auto *LHSAA = A.getAAFor<AAValueConstantRange>(
*this, IRPosition::value(*LHS, getCallBaseContext()),
DepClassTy::REQUIRED);
- QuerriedAAs.push_back(&LHSAA);
- auto &RHSAA = A.getAAFor<AAValueConstantRange>(
+ if (!LHSAA)
+ return false;
+ QuerriedAAs.push_back(LHSAA);
+ auto *RHSAA = A.getAAFor<AAValueConstantRange>(
*this, IRPosition::value(*RHS, getCallBaseContext()),
DepClassTy::REQUIRED);
- QuerriedAAs.push_back(&RHSAA);
- auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI);
- auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI);
+ if (!RHSAA)
+ return false;
+ QuerriedAAs.push_back(RHSAA);
+ auto LHSAARange = LHSAA->getAssumedConstantRange(A, CtxI);
+ auto RHSAARange = RHSAA->getAssumedConstantRange(A, CtxI);
// If one of them is empty set, we can't decide.
if (LHSAARange.isEmptySet() || RHSAARange.isEmptySet())
@@ -9260,8 +9300,10 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
else
T.unionAssumed(ConstantRange(/* BitWidth */ 1, /* isFullSet */ true));
- LLVM_DEBUG(dbgs() << "[AAValueConstantRange] " << *CmpI << " " << LHSAA
- << " " << RHSAA << "\n");
+ LLVM_DEBUG(dbgs() << "[AAValueConstantRange] " << *CmpI << " after "
+ << (MustTrue ? "true" : (MustFalse ? "false" : "unknown"))
+ << ": " << T << "\n\t" << *LHSAA << "\t<op>\n\t"
+ << *RHSAA);
// TODO: Track a known state too.
return T.isValidState();
@@ -9287,12 +9329,15 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
Value *VPtr = *SimplifiedOpV;
// If the value is not instruction, we query AA to Attributor.
- const auto &AA = A.getAAFor<AAValueConstantRange>(
+ const auto *AA = A.getAAFor<AAValueConstantRange>(
*this, IRPosition::value(*VPtr, getCallBaseContext()),
DepClassTy::REQUIRED);
// Clamp operator is not used to utilize a program point CtxI.
- T.unionAssumed(AA.getAssumedConstantRange(A, CtxI));
+ if (AA)
+ T.unionAssumed(AA->getAssumedConstantRange(A, CtxI));
+ else
+ return false;
return T.isValidState();
}
@@ -9454,12 +9499,12 @@ struct AAPotentialConstantValuesImpl : AAPotentialConstantValues {
return false;
if (!IRP.getAssociatedType()->isIntegerTy())
return false;
- auto &PotentialValuesAA = A.getAAFor<AAPotentialConstantValues>(
+ auto *PotentialValuesAA = A.getAAFor<AAPotentialConstantValues>(
*this, IRP, DepClassTy::REQUIRED);
- if (!PotentialValuesAA.getState().isValidState())
+ if (!PotentialValuesAA || !PotentialValuesAA->getState().isValidState())
return false;
- ContainsUndef = PotentialValuesAA.getState().undefIsContained();
- S = PotentialValuesAA.getState().getAssumedSet();
+ ContainsUndef = PotentialValuesAA->getState().undefIsContained();
+ S = PotentialValuesAA->getState().getAssumedSet();
return true;
}
@@ -9483,7 +9528,7 @@ struct AAPotentialConstantValuesImpl : AAPotentialConstantValues {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
std::string Str;
llvm::raw_string_ostream OS(Str);
OS << getState();
@@ -9506,15 +9551,6 @@ struct AAPotentialConstantValuesArgument final
AAPotentialConstantValuesArgument(const IRPosition &IRP, Attributor &A)
: Base(IRP, A) {}
- /// See AbstractAttribute::initialize(..).
- void initialize(Attributor &A) override {
- if (!getAnchorScope() || getAnchorScope()->isDeclaration()) {
- indicatePessimisticFixpoint();
- } else {
- Base::initialize(A);
- }
- }
-
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
STATS_DECLTRACK_ARG_ATTR(potential_values)
@@ -9529,6 +9565,12 @@ struct AAPotentialConstantValuesReturned
AAPotentialConstantValuesReturned(const IRPosition &IRP, Attributor &A)
: Base(IRP, A) {}
+ void initialize(Attributor &A) override {
+ if (!A.isFunctionIPOAmendable(*getAssociatedFunction()))
+ indicatePessimisticFixpoint();
+ Base::initialize(A);
+ }
+
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
STATS_DECLTRACK_FNRET_ATTR(potential_values)
@@ -9958,9 +10000,11 @@ struct AAPotentialConstantValuesCallSiteArgument
ChangeStatus updateImpl(Attributor &A) override {
Value &V = getAssociatedValue();
auto AssumedBefore = getAssumed();
- auto &AA = A.getAAFor<AAPotentialConstantValues>(
+ auto *AA = A.getAAFor<AAPotentialConstantValues>(
*this, IRPosition::value(V), DepClassTy::REQUIRED);
- const auto &S = AA.getAssumed();
+ if (!AA)
+ return indicatePessimisticFixpoint();
+ const auto &S = AA->getAssumed();
unionAssumed(S);
return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
@@ -9971,27 +10015,39 @@ struct AAPotentialConstantValuesCallSiteArgument
STATS_DECLTRACK_CSARG_ATTR(potential_values)
}
};
+} // namespace
/// ------------------------ NoUndef Attribute ---------------------------------
+bool AANoUndef::isImpliedByIR(Attributor &A, const IRPosition &IRP,
+ Attribute::AttrKind ImpliedAttributeKind,
+ bool IgnoreSubsumingPositions) {
+ assert(ImpliedAttributeKind == Attribute::NoUndef &&
+ "Unexpected attribute kind");
+ if (A.hasAttr(IRP, {Attribute::NoUndef}, IgnoreSubsumingPositions,
+ Attribute::NoUndef))
+ return true;
+
+ Value &Val = IRP.getAssociatedValue();
+ if (IRP.getPositionKind() != IRPosition::IRP_RETURNED &&
+ isGuaranteedNotToBeUndefOrPoison(&Val)) {
+ LLVMContext &Ctx = Val.getContext();
+ A.manifestAttrs(IRP, Attribute::get(Ctx, Attribute::NoUndef));
+ return true;
+ }
+
+ return false;
+}
+
+namespace {
struct AANoUndefImpl : AANoUndef {
AANoUndefImpl(const IRPosition &IRP, Attributor &A) : AANoUndef(IRP, A) {}
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- if (getIRPosition().hasAttr({Attribute::NoUndef})) {
- indicateOptimisticFixpoint();
- return;
- }
Value &V = getAssociatedValue();
if (isa<UndefValue>(V))
indicatePessimisticFixpoint();
- else if (isa<FreezeInst>(V))
- indicateOptimisticFixpoint();
- else if (getPositionKind() != IRPosition::IRP_RETURNED &&
- isGuaranteedNotToBeUndefOrPoison(&V))
- indicateOptimisticFixpoint();
- else
- AANoUndef::initialize(A);
+ assert(!isImpliedByIR(A, getIRPosition(), Attribute::NoUndef));
}
/// See followUsesInMBEC
@@ -10015,7 +10071,7 @@ struct AANoUndefImpl : AANoUndef {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return getAssumed() ? "noundef" : "may-undef-or-poison";
}
@@ -10052,33 +10108,39 @@ struct AANoUndefFloating : public AANoUndefImpl {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
+ auto VisitValueCB = [&](const IRPosition &IRP) -> bool {
+ bool IsKnownNoUndef;
+ return AA::hasAssumedIRAttr<Attribute::NoUndef>(
+ A, this, IRP, DepClassTy::REQUIRED, IsKnownNoUndef);
+ };
- SmallVector<AA::ValueAndContext> Values;
+ bool Stripped;
bool UsedAssumedInformation = false;
+ Value *AssociatedValue = &getAssociatedValue();
+ SmallVector<AA::ValueAndContext> Values;
if (!A.getAssumedSimplifiedValues(getIRPosition(), *this, Values,
- AA::AnyScope, UsedAssumedInformation)) {
- Values.push_back({getAssociatedValue(), getCtxI()});
+ AA::AnyScope, UsedAssumedInformation))
+ Stripped = false;
+ else
+ Stripped =
+ Values.size() != 1 || Values.front().getValue() != AssociatedValue;
+
+ if (!Stripped) {
+ // If we haven't stripped anything we might still be able to use a
+ // different AA, but only if the IRP changes. That effectively happens when
+ // we interpret this not as a call site value but as a floating/argument
+ // value.
+ const IRPosition AVIRP = IRPosition::value(*AssociatedValue);
+ if (AVIRP == getIRPosition() || !VisitValueCB(AVIRP))
+ return indicatePessimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
}
- StateType T;
- auto VisitValueCB = [&](Value &V, const Instruction *CtxI) -> bool {
- const auto &AA = A.getAAFor<AANoUndef>(*this, IRPosition::value(V),
- DepClassTy::REQUIRED);
- if (this == &AA) {
- T.indicatePessimisticFixpoint();
- } else {
- const AANoUndef::StateType &S =
- static_cast<const AANoUndef::StateType &>(AA.getState());
- T ^= S;
- }
- return T.isValidState();
- };
-
for (const auto &VAC : Values)
- if (!VisitValueCB(*VAC.getValue(), VAC.getCtxI()))
+ if (!VisitValueCB(IRPosition::value(*VAC.getValue())))
return indicatePessimisticFixpoint();
- return clampStateAndIndicateChange(getState(), T);
+ return ChangeStatus::UNCHANGED;
}
/// See AbstractAttribute::trackStatistics()
@@ -10086,18 +10148,26 @@ struct AANoUndefFloating : public AANoUndefImpl {
};
struct AANoUndefReturned final
- : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl> {
+ : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl,
+ AANoUndef::StateType, false,
+ Attribute::NoUndef> {
AANoUndefReturned(const IRPosition &IRP, Attributor &A)
- : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl>(IRP, A) {}
+ : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl,
+ AANoUndef::StateType, false,
+ Attribute::NoUndef>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noundef) }
};
struct AANoUndefArgument final
- : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl> {
+ : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl,
+ AANoUndef::StateType, false,
+ Attribute::NoUndef> {
AANoUndefArgument(const IRPosition &IRP, Attributor &A)
- : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl>(IRP, A) {}
+ : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl,
+ AANoUndef::StateType, false,
+ Attribute::NoUndef>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noundef) }
@@ -10112,14 +10182,173 @@ struct AANoUndefCallSiteArgument final : AANoUndefFloating {
};
struct AANoUndefCallSiteReturned final
- : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl> {
+ : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl,
+ AANoUndef::StateType, false,
+ Attribute::NoUndef> {
AANoUndefCallSiteReturned(const IRPosition &IRP, Attributor &A)
- : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl>(IRP, A) {}
+ : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl,
+ AANoUndef::StateType, false,
+ Attribute::NoUndef>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noundef) }
};
+/// ------------------------ NoFPClass Attribute -------------------------------
+
+struct AANoFPClassImpl : AANoFPClass {
+ AANoFPClassImpl(const IRPosition &IRP, Attributor &A) : AANoFPClass(IRP, A) {}
+
+ void initialize(Attributor &A) override {
+ const IRPosition &IRP = getIRPosition();
+
+ Value &V = IRP.getAssociatedValue();
+ if (isa<UndefValue>(V)) {
+ indicateOptimisticFixpoint();
+ return;
+ }
+
+ SmallVector<Attribute> Attrs;
+ A.getAttrs(getIRPosition(), {Attribute::NoFPClass}, Attrs, false);
+ for (const auto &Attr : Attrs) {
+ addKnownBits(Attr.getNoFPClass());
+ return;
+ }
+
+ const DataLayout &DL = A.getDataLayout();
+ if (getPositionKind() != IRPosition::IRP_RETURNED) {
+ KnownFPClass KnownFPClass = computeKnownFPClass(&V, DL);
+ addKnownBits(~KnownFPClass.KnownFPClasses);
+ }
+
+ if (Instruction *CtxI = getCtxI())
+ followUsesInMBEC(*this, A, getState(), *CtxI);
+ }
+
+ /// See followUsesInMBEC
+ bool followUseInMBEC(Attributor &A, const Use *U, const Instruction *I,
+ AANoFPClass::StateType &State) {
+ const Value *UseV = U->get();
+ const DominatorTree *DT = nullptr;
+ AssumptionCache *AC = nullptr;
+ const TargetLibraryInfo *TLI = nullptr;
+ InformationCache &InfoCache = A.getInfoCache();
+
+ if (Function *F = getAnchorScope()) {
+ DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F);
+ AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F);
+ TLI = InfoCache.getTargetLibraryInfoForFunction(*F);
+ }
+
+ const DataLayout &DL = A.getDataLayout();
+
+ KnownFPClass KnownFPClass =
+ computeKnownFPClass(UseV, DL,
+ /*InterestedClasses=*/fcAllFlags,
+ /*Depth=*/0, TLI, AC, I, DT);
+ State.addKnownBits(~KnownFPClass.KnownFPClasses);
+
+ bool TrackUse = false;
+ return TrackUse;
+ }
+
+ const std::string getAsStr(Attributor *A) const override {
+ std::string Result = "nofpclass";
+ raw_string_ostream OS(Result);
+ OS << getAssumedNoFPClass();
+ return Result;
+ }
+
+ void getDeducedAttributes(Attributor &A, LLVMContext &Ctx,
+ SmallVectorImpl<Attribute> &Attrs) const override {
+ Attrs.emplace_back(Attribute::getWithNoFPClass(Ctx, getAssumedNoFPClass()));
+ }
+};
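// A minimal standalone sketch (not part of this patch) of the mask logic the
// nofpclass deduction above relies on: computeKnownFPClass reports the FP
// classes a value may still take, and the attribute records the complement.
// The enum values below are illustrative stand-ins, not LLVM's FPClassTest.
#include <cstdint>
#include <cstdio>

enum FPClassBits : uint32_t {
  fcNan = 1u << 0,
  fcInf = 1u << 1,
  fcNormal = 1u << 2,
  fcSubnormal = 1u << 3,
  fcZero = 1u << 4,
  fcAllFlags = fcNan | fcInf | fcNormal | fcSubnormal | fcZero,
};

// Classes the value can provably never be, i.e. what nofpclass would record.
uint32_t noFPClassMask(uint32_t MayBeClasses) {
  return ~MayBeClasses & fcAllFlags;
}

int main() {
  uint32_t MayBe = fcNormal | fcSubnormal | fcZero; // proven not NaN/Inf
  std::printf("nofpclass mask: 0x%x\n", (unsigned)noFPClassMask(MayBe));
}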
+
+struct AANoFPClassFloating : public AANoFPClassImpl {
+ AANoFPClassFloating(const IRPosition &IRP, Attributor &A)
+ : AANoFPClassImpl(IRP, A) {}
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ SmallVector<AA::ValueAndContext> Values;
+ bool UsedAssumedInformation = false;
+ if (!A.getAssumedSimplifiedValues(getIRPosition(), *this, Values,
+ AA::AnyScope, UsedAssumedInformation)) {
+ Values.push_back({getAssociatedValue(), getCtxI()});
+ }
+
+ StateType T;
+ auto VisitValueCB = [&](Value &V, const Instruction *CtxI) -> bool {
+ const auto *AA = A.getAAFor<AANoFPClass>(*this, IRPosition::value(V),
+ DepClassTy::REQUIRED);
+ if (!AA || this == AA) {
+ T.indicatePessimisticFixpoint();
+ } else {
+ const AANoFPClass::StateType &S =
+ static_cast<const AANoFPClass::StateType &>(AA->getState());
+ T ^= S;
+ }
+ return T.isValidState();
+ };
+
+ for (const auto &VAC : Values)
+ if (!VisitValueCB(*VAC.getValue(), VAC.getCtxI()))
+ return indicatePessimisticFixpoint();
+
+ return clampStateAndIndicateChange(getState(), T);
+ }
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FNRET_ATTR(nofpclass)
+ }
+};
+
+struct AANoFPClassReturned final
+ : AAReturnedFromReturnedValues<AANoFPClass, AANoFPClassImpl,
+ AANoFPClassImpl::StateType, false, Attribute::None, false> {
+ AANoFPClassReturned(const IRPosition &IRP, Attributor &A)
+ : AAReturnedFromReturnedValues<AANoFPClass, AANoFPClassImpl,
+ AANoFPClassImpl::StateType, false, Attribute::None, false>(
+ IRP, A) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FNRET_ATTR(nofpclass)
+ }
+};
+
+struct AANoFPClassArgument final
+ : AAArgumentFromCallSiteArguments<AANoFPClass, AANoFPClassImpl> {
+ AANoFPClassArgument(const IRPosition &IRP, Attributor &A)
+ : AAArgumentFromCallSiteArguments<AANoFPClass, AANoFPClassImpl>(IRP, A) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nofpclass) }
+};
+
+struct AANoFPClassCallSiteArgument final : AANoFPClassFloating {
+ AANoFPClassCallSiteArgument(const IRPosition &IRP, Attributor &A)
+ : AANoFPClassFloating(IRP, A) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSARG_ATTR(nofpclass)
+ }
+};
+
+struct AANoFPClassCallSiteReturned final
+ : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl> {
+ AANoFPClassCallSiteReturned(const IRPosition &IRP, Attributor &A)
+ : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl>(IRP, A) {}
+
+ /// See AbstractAttribute::trackStatistics()
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSRET_ATTR(nofpclass)
+ }
+};
+
struct AACallEdgesImpl : public AACallEdges {
AACallEdgesImpl(const IRPosition &IRP, Attributor &A) : AACallEdges(IRP, A) {}
@@ -10133,7 +10362,7 @@ struct AACallEdgesImpl : public AACallEdges {
return HasUnknownCalleeNonAsm;
}
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return "CallEdges[" + std::to_string(HasUnknownCallee) + "," +
std::to_string(CalledFunctions.size()) + "]";
}
@@ -10191,6 +10420,11 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
SmallVector<AA::ValueAndContext> Values;
// Process any value that we might call.
auto ProcessCalledOperand = [&](Value *V, Instruction *CtxI) {
+ if (isa<Constant>(V)) {
+ VisitValue(*V, CtxI);
+ return;
+ }
+
bool UsedAssumedInformation = false;
Values.clear();
if (!A.getAssumedSimplifiedValues(IRPosition::value(*V), *this, Values,
@@ -10246,14 +10480,16 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
auto ProcessCallInst = [&](Instruction &Inst) {
CallBase &CB = cast<CallBase>(Inst);
- auto &CBEdges = A.getAAFor<AACallEdges>(
+ auto *CBEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
- if (CBEdges.hasNonAsmUnknownCallee())
+ if (!CBEdges)
+ return false;
+ if (CBEdges->hasNonAsmUnknownCallee())
setHasUnknownCallee(true, Change);
- if (CBEdges.hasUnknownCallee())
+ if (CBEdges->hasUnknownCallee())
setHasUnknownCallee(false, Change);
- for (Function *F : CBEdges.getOptimisticEdges())
+ for (Function *F : CBEdges->getOptimisticEdges())
addCalledFunction(F, Change);
return true;
@@ -10277,8 +10513,9 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
struct AAInterFnReachabilityFunction
: public CachedReachabilityAA<AAInterFnReachability, Function> {
+ using Base = CachedReachabilityAA<AAInterFnReachability, Function>;
AAInterFnReachabilityFunction(const IRPosition &IRP, Attributor &A)
- : CachedReachabilityAA<AAInterFnReachability, Function>(IRP, A) {}
+ : Base(IRP, A) {}
bool instructionCanReach(
Attributor &A, const Instruction &From, const Function &To,
@@ -10287,10 +10524,10 @@ struct AAInterFnReachabilityFunction
assert(From.getFunction() == getAnchorScope() && "Queried the wrong AA!");
auto *NonConstThis = const_cast<AAInterFnReachabilityFunction *>(this);
- RQITy StackRQI(A, From, To, ExclusionSet);
+ RQITy StackRQI(A, From, To, ExclusionSet, false);
typename RQITy::Reachable Result;
- if (RQITy *RQIPtr = NonConstThis->checkQueryCache(A, StackRQI, Result))
- return NonConstThis->isReachableImpl(A, *RQIPtr);
+ if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
+ return NonConstThis->isReachableImpl(A, StackRQI);
return Result == RQITy::Reachable::Yes;
}
@@ -10305,59 +10542,61 @@ struct AAInterFnReachabilityFunction
if (!Visited)
Visited = &LocalVisited;
- const auto &IntraFnReachability = A.getAAFor<AAIntraFnReachability>(
- *this, IRPosition::function(*RQI.From->getFunction()),
- DepClassTy::OPTIONAL);
-
- // Determine call like instructions that we can reach from the inst.
- SmallVector<CallBase *> ReachableCallBases;
- auto CheckCallBase = [&](Instruction &CBInst) {
- if (IntraFnReachability.isAssumedReachable(A, *RQI.From, CBInst,
- RQI.ExclusionSet))
- ReachableCallBases.push_back(cast<CallBase>(&CBInst));
- return true;
- };
-
- bool UsedAssumedInformation = false;
- if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this,
- UsedAssumedInformation,
- /* CheckBBLivenessOnly */ true))
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
-
- for (CallBase *CB : ReachableCallBases) {
- auto &CBEdges = A.getAAFor<AACallEdges>(
+ auto CheckReachableCallBase = [&](CallBase *CB) {
+ auto *CBEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
- if (!CBEdges.getState().isValidState())
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
+ if (!CBEdges || !CBEdges->getState().isValidState())
+ return false;
// TODO Check To backwards in this case.
- if (CBEdges.hasUnknownCallee())
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
+ if (CBEdges->hasUnknownCallee())
+ return false;
- for (Function *Fn : CBEdges.getOptimisticEdges()) {
+ for (Function *Fn : CBEdges->getOptimisticEdges()) {
if (Fn == RQI.To)
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
+ return false;
if (!Visited->insert(Fn).second)
continue;
if (Fn->isDeclaration()) {
if (Fn->hasFnAttribute(Attribute::NoCallback))
continue;
// TODO Check To backwards in this case.
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
+ return false;
}
const AAInterFnReachability *InterFnReachability = this;
if (Fn != getAnchorScope())
- InterFnReachability = &A.getAAFor<AAInterFnReachability>(
+ InterFnReachability = A.getAAFor<AAInterFnReachability>(
*this, IRPosition::function(*Fn), DepClassTy::OPTIONAL);
const Instruction &FnFirstInst = Fn->getEntryBlock().front();
- if (InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To,
+ if (!InterFnReachability ||
+ InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To,
RQI.ExclusionSet, Visited))
- return rememberResult(A, RQITy::Reachable::Yes, RQI);
+ return false;
}
- }
+ return true;
+ };
+
+ const auto *IntraFnReachability = A.getAAFor<AAIntraFnReachability>(
+ *this, IRPosition::function(*RQI.From->getFunction()),
+ DepClassTy::OPTIONAL);
+
+ // Determine call-like instructions that we can reach from the instruction.
+ auto CheckCallBase = [&](Instruction &CBInst) {
+ if (!IntraFnReachability || !IntraFnReachability->isAssumedReachable(
+ A, *RQI.From, CBInst, RQI.ExclusionSet))
+ return true;
+ return CheckReachableCallBase(cast<CallBase>(&CBInst));
+ };
+
+ bool UsedExclusionSet = /* conservative */ true;
+ bool UsedAssumedInformation = false;
+ if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this,
+ UsedAssumedInformation,
+ /* CheckBBLivenessOnly */ true))
+ return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
- return rememberResult(A, RQITy::Reachable::No, RQI);
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
}
void trackStatistics() const override {}
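// A simplified, self-contained analogue of the interprocedural reachability
// query above: a depth-first search over known call edges with a visited set.
// In the real code, unknown or invalid call edges conservatively make the
// answer "reachable"; the plain map below only models known edges.
#include <map>
#include <set>
#include <string>
#include <vector>

using CallGraph = std::map<std::string, std::vector<std::string>>;

bool mayReach(const CallGraph &CG, const std::string &From,
              const std::string &To, std::set<std::string> &Visited) {
  if (From == To)
    return true;
  if (!Visited.insert(From).second)
    return false; // already explored this function
  auto It = CG.find(From);
  if (It == CG.end())
    return false; // no known callees
  for (const std::string &Callee : It->second)
    if (mayReach(CG, Callee, To, Visited))
      return true;
  return false;
}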
@@ -10376,16 +10615,18 @@ askForAssumedConstant(Attributor &A, const AbstractAttribute &QueryingAA,
return nullptr;
// This will also pass the call base context.
- const auto &AA = A.getAAFor<AAType>(QueryingAA, IRP, DepClassTy::NONE);
+ const auto *AA = A.getAAFor<AAType>(QueryingAA, IRP, DepClassTy::NONE);
+ if (!AA)
+ return nullptr;
- std::optional<Constant *> COpt = AA.getAssumedConstant(A);
+ std::optional<Constant *> COpt = AA->getAssumedConstant(A);
if (!COpt.has_value()) {
- A.recordDependence(AA, QueryingAA, DepClassTy::OPTIONAL);
+ A.recordDependence(*AA, QueryingAA, DepClassTy::OPTIONAL);
return std::nullopt;
}
if (auto *C = *COpt) {
- A.recordDependence(AA, QueryingAA, DepClassTy::OPTIONAL);
+ A.recordDependence(*AA, QueryingAA, DepClassTy::OPTIONAL);
return C;
}
return nullptr;
@@ -10432,7 +10673,7 @@ struct AAPotentialValuesImpl : AAPotentialValues {
}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
std::string Str;
llvm::raw_string_ostream OS(Str);
OS << getState();
@@ -10454,9 +10695,9 @@ struct AAPotentialValuesImpl : AAPotentialValues {
return nullptr;
}
- void addValue(Attributor &A, StateType &State, Value &V,
- const Instruction *CtxI, AA::ValueScope S,
- Function *AnchorScope) const {
+ virtual void addValue(Attributor &A, StateType &State, Value &V,
+ const Instruction *CtxI, AA::ValueScope S,
+ Function *AnchorScope) const {
IRPosition ValIRP = IRPosition::value(V);
if (auto *CB = dyn_cast_or_null<CallBase>(CtxI)) {
@@ -10474,12 +10715,12 @@ struct AAPotentialValuesImpl : AAPotentialValues {
std::optional<Value *> SimpleV =
askOtherAA<AAValueConstantRange>(A, *this, ValIRP, Ty);
if (SimpleV.has_value() && !*SimpleV) {
- auto &PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>(
+ auto *PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>(
*this, ValIRP, DepClassTy::OPTIONAL);
- if (PotentialConstantsAA.isValidState()) {
- for (const auto &It : PotentialConstantsAA.getAssumedSet())
+ if (PotentialConstantsAA && PotentialConstantsAA->isValidState()) {
+ for (const auto &It : PotentialConstantsAA->getAssumedSet())
State.unionAssumed({{*ConstantInt::get(&Ty, It), nullptr}, S});
- if (PotentialConstantsAA.undefIsContained())
+ if (PotentialConstantsAA->undefIsContained())
State.unionAssumed({{*UndefValue::get(&Ty), nullptr}, S});
return;
}
@@ -10586,14 +10827,23 @@ struct AAPotentialValuesImpl : AAPotentialValues {
return ChangeStatus::UNCHANGED;
}
- bool getAssumedSimplifiedValues(Attributor &A,
- SmallVectorImpl<AA::ValueAndContext> &Values,
- AA::ValueScope S) const override {
+ bool getAssumedSimplifiedValues(
+ Attributor &A, SmallVectorImpl<AA::ValueAndContext> &Values,
+ AA::ValueScope S, bool RecurseForSelectAndPHI = false) const override {
if (!isValidState())
return false;
+ bool UsedAssumedInformation = false;
for (const auto &It : getAssumedSet())
- if (It.second & S)
+ if (It.second & S) {
+ if (RecurseForSelectAndPHI && (isa<PHINode>(It.first.getValue()) ||
+ isa<SelectInst>(It.first.getValue()))) {
+ if (A.getAssumedSimplifiedValues(
+ IRPosition::inst(*cast<Instruction>(It.first.getValue())),
+ this, Values, S, UsedAssumedInformation))
+ continue;
+ }
Values.push_back(It.first);
+ }
assert(!undefIsContained() && "Undef should be an explicit value!");
return true;
}
@@ -10607,7 +10857,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
ChangeStatus updateImpl(Attributor &A) override {
auto AssumedBefore = getAssumed();
- genericValueTraversal(A);
+ genericValueTraversal(A, &getAssociatedValue());
return (AssumedBefore == getAssumed()) ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
@@ -10677,9 +10927,11 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
// The index is the operand that we assume is not null.
unsigned PtrIdx = LHSIsNull;
- auto &PtrNonNullAA = A.getAAFor<AANonNull>(
- *this, IRPosition::value(*(PtrIdx ? RHS : LHS)), DepClassTy::REQUIRED);
- if (!PtrNonNullAA.isAssumedNonNull())
+ bool IsKnownNonNull;
+ bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, IRPosition::value(*(PtrIdx ? RHS : LHS)), DepClassTy::REQUIRED,
+ IsKnownNonNull);
+ if (!IsAssumedNonNull)
return false;
// The new value depends on the predicate, true for != and false for ==.
@@ -10743,7 +10995,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
InformationCache &InfoCache = A.getInfoCache();
if (InfoCache.isOnlyUsedByAssume(LI)) {
if (!llvm::all_of(PotentialValueOrigins, [&](Instruction *I) {
- if (!I)
+ if (!I || isa<AssumeInst>(I))
return true;
if (auto *SI = dyn_cast<StoreInst>(I))
return A.isAssumedDead(SI->getOperandUse(0), this,
@@ -10797,21 +11049,37 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
auto GetLivenessInfo = [&](const Function &F) -> LivenessInfo & {
LivenessInfo &LI = LivenessAAs[&F];
if (!LI.LivenessAA)
- LI.LivenessAA = &A.getAAFor<AAIsDead>(*this, IRPosition::function(F),
- DepClassTy::NONE);
+ LI.LivenessAA = A.getAAFor<AAIsDead>(*this, IRPosition::function(F),
+ DepClassTy::NONE);
return LI;
};
if (&PHI == &getAssociatedValue()) {
LivenessInfo &LI = GetLivenessInfo(*PHI.getFunction());
+ const auto *CI =
+ A.getInfoCache().getAnalysisResultForFunction<CycleAnalysis>(
+ *PHI.getFunction());
+
+ Cycle *C = nullptr;
+ bool CyclePHI = mayBeInCycle(CI, &PHI, /* HeaderOnly */ true, &C);
for (unsigned u = 0, e = PHI.getNumIncomingValues(); u < e; u++) {
BasicBlock *IncomingBB = PHI.getIncomingBlock(u);
- if (LI.LivenessAA->isEdgeDead(IncomingBB, PHI.getParent())) {
+ if (LI.LivenessAA &&
+ LI.LivenessAA->isEdgeDead(IncomingBB, PHI.getParent())) {
LI.AnyDead = true;
continue;
}
- Worklist.push_back(
- {{*PHI.getIncomingValue(u), IncomingBB->getTerminator()}, II.S});
+ Value *V = PHI.getIncomingValue(u);
+ if (V == &PHI)
+ continue;
+
+ // If the incoming value is not the PHI but an instruction in the same
+ // cycle, we might have multiple versions of it flying around.
+ if (CyclePHI && isa<Instruction>(V) &&
+ (!C || C->contains(cast<Instruction>(V)->getParent())))
+ return false;
+
+ Worklist.push_back({{*V, IncomingBB->getTerminator()}, II.S});
}
return true;
}
@@ -10866,11 +11134,10 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F);
const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F);
auto *AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F);
- OptimizationRemarkEmitter *ORE = nullptr;
const DataLayout &DL = I.getModule()->getDataLayout();
SimplifyQuery Q(DL, TLI, DT, AC, &I);
- Value *NewV = simplifyInstructionWithOperands(&I, NewOps, Q, ORE);
+ Value *NewV = simplifyInstructionWithOperands(&I, NewOps, Q);
if (!NewV || NewV == &I)
return false;
@@ -10902,10 +11169,9 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
return false;
}
- void genericValueTraversal(Attributor &A) {
+ void genericValueTraversal(Attributor &A, Value *InitialV) {
SmallMapVector<const Function *, LivenessInfo, 4> LivenessAAs;
- Value *InitialV = &getAssociatedValue();
SmallSet<ItemInfo, 16> Visited;
SmallVector<ItemInfo, 16> Worklist;
Worklist.push_back({{*InitialV, getCtxI()}, AA::AnyScope});
@@ -10937,14 +11203,15 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
if (V->getType()->isPointerTy()) {
NewV = AA::getWithType(*V->stripPointerCasts(), *V->getType());
} else {
- auto *CB = dyn_cast<CallBase>(V);
- if (CB && CB->getCalledFunction()) {
- for (Argument &Arg : CB->getCalledFunction()->args())
- if (Arg.hasReturnedAttr()) {
- NewV = CB->getArgOperand(Arg.getArgNo());
- break;
- }
- }
+ if (auto *CB = dyn_cast<CallBase>(V))
+ if (auto *Callee =
+ dyn_cast_if_present<Function>(CB->getCalledOperand())) {
+ for (Argument &Arg : Callee->args())
+ if (Arg.hasReturnedAttr()) {
+ NewV = CB->getArgOperand(Arg.getArgNo());
+ break;
+ }
+ }
}
if (NewV && NewV != V) {
Worklist.push_back({{*NewV, CtxI}, S});
@@ -11062,25 +11329,127 @@ struct AAPotentialValuesArgument final : AAPotentialValuesImpl {
}
};
-struct AAPotentialValuesReturned
- : AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl> {
- using Base =
- AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl>;
+struct AAPotentialValuesReturned : public AAPotentialValuesFloating {
+ using Base = AAPotentialValuesFloating;
AAPotentialValuesReturned(const IRPosition &IRP, Attributor &A)
: Base(IRP, A) {}
/// See AbstractAttribute::initialize(..).
void initialize(Attributor &A) override {
- if (A.hasSimplificationCallback(getIRPosition()))
+ Function *F = getAssociatedFunction();
+ if (!F || F->isDeclaration() || F->getReturnType()->isVoidTy()) {
indicatePessimisticFixpoint();
- else
- AAPotentialValues::initialize(A);
+ return;
+ }
+
+ for (Argument &Arg : F->args())
+ if (Arg.hasReturnedAttr()) {
+ addValue(A, getState(), Arg, nullptr, AA::AnyScope, F);
+ ReturnedArg = &Arg;
+ break;
+ }
+ if (!A.isFunctionIPOAmendable(*F) ||
+ A.hasSimplificationCallback(getIRPosition())) {
+ if (!ReturnedArg)
+ indicatePessimisticFixpoint();
+ else
+ indicateOptimisticFixpoint();
+ }
+ }
+
+ /// See AbstractAttribute::updateImpl(...).
+ ChangeStatus updateImpl(Attributor &A) override {
+ auto AssumedBefore = getAssumed();
+ bool UsedAssumedInformation = false;
+
+ SmallVector<AA::ValueAndContext> Values;
+ Function *AnchorScope = getAnchorScope();
+ auto HandleReturnedValue = [&](Value &V, Instruction *CtxI,
+ bool AddValues) {
+ for (AA::ValueScope S : {AA::Interprocedural, AA::Intraprocedural}) {
+ Values.clear();
+ if (!A.getAssumedSimplifiedValues(IRPosition::value(V), this, Values, S,
+ UsedAssumedInformation,
+ /* RecurseForSelectAndPHI */ true))
+ return false;
+ if (!AddValues)
+ continue;
+ for (const AA::ValueAndContext &VAC : Values)
+ addValue(A, getState(), *VAC.getValue(),
+ VAC.getCtxI() ? VAC.getCtxI() : CtxI, S, AnchorScope);
+ }
+ return true;
+ };
+
+ if (ReturnedArg) {
+ HandleReturnedValue(*ReturnedArg, nullptr, true);
+ } else {
+ auto RetInstPred = [&](Instruction &RetI) {
+ bool AddValues = true;
+ if (isa<PHINode>(RetI.getOperand(0)) ||
+ isa<SelectInst>(RetI.getOperand(0))) {
+ addValue(A, getState(), *RetI.getOperand(0), &RetI, AA::AnyScope,
+ AnchorScope);
+ AddValues = false;
+ }
+ return HandleReturnedValue(*RetI.getOperand(0), &RetI, AddValues);
+ };
+
+ if (!A.checkForAllInstructions(RetInstPred, *this, {Instruction::Ret},
+ UsedAssumedInformation,
+ /* CheckBBLivenessOnly */ true))
+ return indicatePessimisticFixpoint();
+ }
+
+ return (AssumedBefore == getAssumed()) ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
+
+ void addValue(Attributor &A, StateType &State, Value &V,
+ const Instruction *CtxI, AA::ValueScope S,
+ Function *AnchorScope) const override {
+ Function *F = getAssociatedFunction();
+ if (auto *CB = dyn_cast<CallBase>(&V))
+ if (CB->getCalledOperand() == F)
+ return;
+ Base::addValue(A, State, V, CtxI, S, AnchorScope);
}
ChangeStatus manifest(Attributor &A) override {
- // We queried AAValueSimplify for the returned values so they will be
- // replaced if a simplified form was found. Nothing to do here.
- return ChangeStatus::UNCHANGED;
+ if (ReturnedArg)
+ return ChangeStatus::UNCHANGED;
+ SmallVector<AA::ValueAndContext> Values;
+ if (!getAssumedSimplifiedValues(A, Values, AA::ValueScope::Intraprocedural,
+ /* RecurseForSelectAndPHI */ true))
+ return ChangeStatus::UNCHANGED;
+ Value *NewVal = getSingleValue(A, *this, getIRPosition(), Values);
+ if (!NewVal)
+ return ChangeStatus::UNCHANGED;
+
+ ChangeStatus Changed = ChangeStatus::UNCHANGED;
+ if (auto *Arg = dyn_cast<Argument>(NewVal)) {
+ STATS_DECLTRACK(UniqueReturnValue, FunctionReturn,
+ "Number of function with unique return");
+ Changed |= A.manifestAttrs(
+ IRPosition::argument(*Arg),
+ {Attribute::get(Arg->getContext(), Attribute::Returned)});
+ STATS_DECLTRACK_ARG_ATTR(returned);
+ }
+
+ auto RetInstPred = [&](Instruction &RetI) {
+ Value *RetOp = RetI.getOperand(0);
+ if (isa<UndefValue>(RetOp) || RetOp == NewVal)
+ return true;
+ if (AA::isValidAtPosition({*NewVal, RetI}, A.getInfoCache()))
+ if (A.changeUseAfterManifest(RetI.getOperandUse(0), *NewVal))
+ Changed = ChangeStatus::CHANGED;
+ return true;
+ };
+ bool UsedAssumedInformation = false;
+ (void)A.checkForAllInstructions(RetInstPred, *this, {Instruction::Ret},
+ UsedAssumedInformation,
+ /* CheckBBLivenessOnly */ true);
+ return Changed;
}
ChangeStatus indicatePessimisticFixpoint() override {
@@ -11088,9 +11457,11 @@ struct AAPotentialValuesReturned
}
/// See AbstractAttribute::trackStatistics()
- void trackStatistics() const override {
- STATS_DECLTRACK_FNRET_ATTR(potential_values)
- }
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FNRET_ATTR(potential_values)
+ }
+
+ /// The argument with an existing `returned` attribute.
+ Argument *ReturnedArg = nullptr;
};
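// A hedged sketch of the reduction that manifest() above depends on: the
// returns are only rewritten when every simplified return value collapses to
// a single value. The helper below is a standalone stand-in; the real
// getSingleValue helper additionally considers contexts and undef values.
#include <optional>
#include <vector>

template <typename T>
std::optional<T> uniqueValueOrNone(const std::vector<T> &Values) {
  std::optional<T> Unique;
  for (const T &V : Values) {
    if (!Unique)
      Unique = V;
    else if (*Unique != V)
      return std::nullopt; // more than one distinct value, give up
  }
  return Unique; // an empty input also yields std::nullopt
}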
struct AAPotentialValuesFunction : AAPotentialValuesImpl {
@@ -11162,7 +11533,7 @@ struct AAPotentialValuesCallSiteReturned : AAPotentialValuesImpl {
SmallVector<AA::ValueAndContext> ArgValues;
IRPosition IRP = IRPosition::value(*V);
if (auto *Arg = dyn_cast<Argument>(V))
- if (Arg->getParent() == CB->getCalledFunction())
+ if (Arg->getParent() == CB->getCalledOperand())
IRP = IRPosition::callsite_argument(*CB, Arg->getArgNo());
if (recurseForValue(A, IRP, AA::AnyScope))
continue;
@@ -11228,12 +11599,26 @@ struct AAAssumptionInfoImpl : public AAAssumptionInfo {
const DenseSet<StringRef> &Known)
: AAAssumptionInfo(IRP, A, Known) {}
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ // Don't manifest a universal set if it somehow made it here.
+ if (getKnown().isUniversal())
+ return ChangeStatus::UNCHANGED;
+
+ const IRPosition &IRP = getIRPosition();
+ return A.manifestAttrs(
+ IRP,
+ Attribute::get(IRP.getAnchorValue().getContext(), AssumptionAttrKey,
+ llvm::join(getAssumed().getSet(), ",")),
+ /* ForceReplace */ true);
+ }
+
bool hasAssumption(const StringRef Assumption) const override {
return isValidState() && setContains(Assumption);
}
/// See AbstractAttribute::getAsStr()
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
const SetContents &Known = getKnown();
const SetContents &Assumed = getAssumed();
@@ -11264,31 +11649,18 @@ struct AAAssumptionInfoFunction final : AAAssumptionInfoImpl {
: AAAssumptionInfoImpl(IRP, A,
getAssumptions(*IRP.getAssociatedFunction())) {}
- /// See AbstractAttribute::manifest(...).
- ChangeStatus manifest(Attributor &A) override {
- const auto &Assumptions = getKnown();
-
- // Don't manifest a universal set if it somehow made it here.
- if (Assumptions.isUniversal())
- return ChangeStatus::UNCHANGED;
-
- Function *AssociatedFunction = getAssociatedFunction();
-
- bool Changed = addAssumptions(*AssociatedFunction, Assumptions.getSet());
-
- return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
bool Changed = false;
auto CallSitePred = [&](AbstractCallSite ACS) {
- const auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
+ const auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
*this, IRPosition::callsite_function(*ACS.getInstruction()),
DepClassTy::REQUIRED);
+ if (!AssumptionAA)
+ return false;
// Get the set of assumptions shared by all of this function's callers.
- Changed |= getIntersection(AssumptionAA.getAssumed());
+ Changed |= getIntersection(AssumptionAA->getAssumed());
return !getAssumed().empty() || !getKnown().empty();
};
@@ -11319,24 +11691,14 @@ struct AAAssumptionInfoCallSite final : AAAssumptionInfoImpl {
A.getAAFor<AAAssumptionInfo>(*this, FnPos, DepClassTy::REQUIRED);
}
- /// See AbstractAttribute::manifest(...).
- ChangeStatus manifest(Attributor &A) override {
- // Don't manifest a universal set if it somehow made it here.
- if (getKnown().isUniversal())
- return ChangeStatus::UNCHANGED;
-
- CallBase &AssociatedCall = cast<CallBase>(getAssociatedValue());
- bool Changed = addAssumptions(AssociatedCall, getAssumed().getSet());
-
- return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
- }
-
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
const IRPosition &FnPos = IRPosition::function(*getAnchorScope());
- auto &AssumptionAA =
+ auto *AssumptionAA =
A.getAAFor<AAAssumptionInfo>(*this, FnPos, DepClassTy::REQUIRED);
- bool Changed = getIntersection(AssumptionAA.getAssumed());
+ if (!AssumptionAA)
+ return indicatePessimisticFixpoint();
+ bool Changed = getIntersection(AssumptionAA->getAssumed());
return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
}
@@ -11360,7 +11722,7 @@ private:
AACallGraphNode *AACallEdgeIterator::operator*() const {
return static_cast<AACallGraphNode *>(const_cast<AACallEdges *>(
- &A.getOrCreateAAFor<AACallEdges>(IRPosition::function(**I))));
+ A.getOrCreateAAFor<AACallEdges>(IRPosition::function(**I))));
}
void AttributorCallGraph::print() { llvm::WriteGraph(outs(), this); }
@@ -11374,7 +11736,7 @@ struct AAUnderlyingObjectsImpl
AAUnderlyingObjectsImpl(const IRPosition &IRP, Attributor &A) : BaseTy(IRP) {}
/// See AbstractAttribute::getAsStr().
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *A) const override {
return std::string("UnderlyingObjects ") +
(isValidState()
? (std::string("inter #") +
@@ -11409,24 +11771,33 @@ struct AAUnderlyingObjectsImpl
auto *Obj = VAC.getValue();
Value *UO = getUnderlyingObject(Obj);
if (UO && UO != VAC.getValue() && SeenObjects.insert(UO).second) {
- const auto &OtherAA = A.getAAFor<AAUnderlyingObjects>(
+ const auto *OtherAA = A.getAAFor<AAUnderlyingObjects>(
*this, IRPosition::value(*UO), DepClassTy::OPTIONAL);
auto Pred = [&Values](Value &V) {
Values.emplace_back(V, nullptr);
return true;
};
- if (!OtherAA.forallUnderlyingObjects(Pred, Scope))
+ if (!OtherAA || !OtherAA->forallUnderlyingObjects(Pred, Scope))
llvm_unreachable(
"The forall call should not return false at this position");
continue;
}
- if (isa<SelectInst>(Obj) || isa<PHINode>(Obj)) {
+ if (isa<SelectInst>(Obj)) {
Changed |= handleIndirect(A, *Obj, UnderlyingObjects, Scope);
continue;
}
+ if (auto *PHI = dyn_cast<PHINode>(Obj)) {
+ // Explicitly look through PHIs as we do not care about dynamic
+ // uniqueness.
+ for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) {
+ Changed |= handleIndirect(A, *PHI->getIncomingValue(u),
+ UnderlyingObjects, Scope);
+ }
+ continue;
+ }
Changed |= UnderlyingObjects.insert(Obj);
}
@@ -11464,13 +11835,13 @@ private:
SmallSetVector<Value *, 8> &UnderlyingObjects,
AA::ValueScope Scope) {
bool Changed = false;
- const auto &AA = A.getAAFor<AAUnderlyingObjects>(
+ const auto *AA = A.getAAFor<AAUnderlyingObjects>(
*this, IRPosition::value(V), DepClassTy::OPTIONAL);
auto Pred = [&](Value &V) {
Changed |= UnderlyingObjects.insert(&V);
return true;
};
- if (!AA.forallUnderlyingObjects(Pred, Scope))
+ if (!AA || !AA->forallUnderlyingObjects(Pred, Scope))
llvm_unreachable(
"The forall call should not return false at this position");
return Changed;
@@ -11516,14 +11887,190 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl {
AAUnderlyingObjectsFunction(const IRPosition &IRP, Attributor &A)
: AAUnderlyingObjectsImpl(IRP, A) {}
};
-}
+} // namespace
+
+/// ------------------------ Address Space ------------------------------------
+namespace {
+struct AAAddressSpaceImpl : public AAAddressSpace {
+ AAAddressSpaceImpl(const IRPosition &IRP, Attributor &A)
+ : AAAddressSpace(IRP, A) {}
+
+ int32_t getAddressSpace() const override {
+ assert(isValidState() && "the AA is invalid");
+ return AssumedAddressSpace;
+ }
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ assert(getAssociatedType()->isPtrOrPtrVectorTy() &&
+ "Associated value is not a pointer");
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+ int32_t OldAddressSpace = AssumedAddressSpace;
+ auto *AUO = A.getOrCreateAAFor<AAUnderlyingObjects>(getIRPosition(), this,
+ DepClassTy::REQUIRED);
+ auto Pred = [&](Value &Obj) {
+ if (isa<UndefValue>(&Obj))
+ return true;
+ return takeAddressSpace(Obj.getType()->getPointerAddressSpace());
+ };
+
+ if (!AUO->forallUnderlyingObjects(Pred))
+ return indicatePessimisticFixpoint();
+
+ return OldAddressSpace == AssumedAddressSpace ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ Value *AssociatedValue = &getAssociatedValue();
+ Value *OriginalValue = peelAddrspacecast(AssociatedValue);
+ if (getAddressSpace() == NoAddressSpace ||
+ static_cast<uint32_t>(getAddressSpace()) ==
+ getAssociatedType()->getPointerAddressSpace())
+ return ChangeStatus::UNCHANGED;
+
+ Type *NewPtrTy = PointerType::get(getAssociatedType()->getContext(),
+ static_cast<uint32_t>(getAddressSpace()));
+ bool UseOriginalValue =
+ OriginalValue->getType()->getPointerAddressSpace() ==
+ static_cast<uint32_t>(getAddressSpace());
+
+ bool Changed = false;
+
+ auto MakeChange = [&](Instruction *I, Use &U) {
+ Changed = true;
+ if (UseOriginalValue) {
+ A.changeUseAfterManifest(U, *OriginalValue);
+ return;
+ }
+ Instruction *CastInst = new AddrSpaceCastInst(OriginalValue, NewPtrTy);
+ CastInst->insertBefore(cast<Instruction>(I));
+ A.changeUseAfterManifest(U, *CastInst);
+ };
+
+ auto Pred = [&](const Use &U, bool &) {
+ if (U.get() != AssociatedValue)
+ return true;
+ auto *Inst = dyn_cast<Instruction>(U.getUser());
+ if (!Inst)
+ return true;
+ // This is a workaround to make sure we only change uses from the
+ // corresponding CGSCC if the AA is run on a CGSCC instead of the entire
+ // module.
+ if (!A.isRunOn(Inst->getFunction()))
+ return true;
+ if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst))
+ MakeChange(Inst, const_cast<Use &>(U));
+ return true;
+ };
+
+ // It doesn't matter if we can't check all uses, as we can simply and
+ // conservatively ignore those that cannot be visited.
+ (void)A.checkForAllUses(Pred, *this, getAssociatedValue(),
+ /* CheckBBLivenessOnly */ true);
+
+ return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr(Attributor *A) const override {
+ if (!isValidState())
+ return "addrspace(<invalid>)";
+ return "addrspace(" +
+ (AssumedAddressSpace == NoAddressSpace
+ ? "none"
+ : std::to_string(AssumedAddressSpace)) +
+ ")";
+ }
+
+private:
+ int32_t AssumedAddressSpace = NoAddressSpace;
+
+ bool takeAddressSpace(int32_t AS) {
+ if (AssumedAddressSpace == NoAddressSpace) {
+ AssumedAddressSpace = AS;
+ return true;
+ }
+ return AssumedAddressSpace == AS;
+ }
+
+ static Value *peelAddrspacecast(Value *V) {
+ if (auto *I = dyn_cast<AddrSpaceCastInst>(V))
+ return peelAddrspacecast(I->getPointerOperand());
+ if (auto *C = dyn_cast<ConstantExpr>(V))
+ if (C->getOpcode() == Instruction::AddrSpaceCast)
+ return peelAddrspacecast(C->getOperand(0));
+ return V;
+ }
+};
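// The private takeAddressSpace helper above is a tiny lattice meet: start with
// "no address space yet", adopt the first one observed, and fail on conflict.
// A standalone sketch of the same operation (names are illustrative only):
#include <cstdint>
#include <optional>

struct AddressSpaceState {
  std::optional<int32_t> AS; // unset until the first underlying object is seen
  bool Valid = true;

  bool meet(int32_t Incoming) {
    if (!AS) {
      AS = Incoming; // first observation, adopt it
      return true;
    }
    if (*AS == Incoming)
      return true;
    Valid = false; // conflicting address spaces, no single answer
    return false;
  }
};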
+
+struct AAAddressSpaceFloating final : AAAddressSpaceImpl {
+ AAAddressSpaceFloating(const IRPosition &IRP, Attributor &A)
+ : AAAddressSpaceImpl(IRP, A) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(addrspace);
+ }
+};
+
+struct AAAddressSpaceReturned final : AAAddressSpaceImpl {
+ AAAddressSpaceReturned(const IRPosition &IRP, Attributor &A)
+ : AAAddressSpaceImpl(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ // TODO: We do not rewrite this position for now because it would require
+ // rewriting the function signature and all call sites.
+ (void)indicatePessimisticFixpoint();
+ }
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FNRET_ATTR(addrspace);
+ }
+};
+
+struct AAAddressSpaceCallSiteReturned final : AAAddressSpaceImpl {
+ AAAddressSpaceCallSiteReturned(const IRPosition &IRP, Attributor &A)
+ : AAAddressSpaceImpl(IRP, A) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSRET_ATTR(addrspace);
+ }
+};
+
+struct AAAddressSpaceArgument final : AAAddressSpaceImpl {
+ AAAddressSpaceArgument(const IRPosition &IRP, Attributor &A)
+ : AAAddressSpaceImpl(IRP, A) {}
+
+ void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(addrspace); }
+};
+
+struct AAAddressSpaceCallSiteArgument final : AAAddressSpaceImpl {
+ AAAddressSpaceCallSiteArgument(const IRPosition &IRP, Attributor &A)
+ : AAAddressSpaceImpl(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ // TODO: We do not rewrite call site arguments for now because it would
+ // require rewriting the function signature of the callee.
+ (void)indicatePessimisticFixpoint();
+ }
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSARG_ATTR(addrspace);
+ }
+};
+} // namespace
-const char AAReturnedValues::ID = 0;
const char AANoUnwind::ID = 0;
const char AANoSync::ID = 0;
const char AANoFree::ID = 0;
const char AANonNull::ID = 0;
+const char AAMustProgress::ID = 0;
const char AANoRecurse::ID = 0;
+const char AANonConvergent::ID = 0;
const char AAWillReturn::ID = 0;
const char AAUndefinedBehavior::ID = 0;
const char AANoAlias::ID = 0;
@@ -11543,11 +12090,13 @@ const char AAValueConstantRange::ID = 0;
const char AAPotentialConstantValues::ID = 0;
const char AAPotentialValues::ID = 0;
const char AANoUndef::ID = 0;
+const char AANoFPClass::ID = 0;
const char AACallEdges::ID = 0;
const char AAInterFnReachability::ID = 0;
const char AAPointerInfo::ID = 0;
const char AAAssumptionInfo::ID = 0;
const char AAUnderlyingObjects::ID = 0;
+const char AAAddressSpace::ID = 0;
// Macro magic to create the static generator function for attributes that
// follow the naming scheme.
@@ -11647,10 +12196,10 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoSync)
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoRecurse)
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAWillReturn)
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoReturn)
-CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReturnedValues)
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation)
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges)
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMustProgress)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias)
@@ -11663,7 +12212,9 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueConstantRange)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPotentialConstantValues)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPotentialValues)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUndef)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFPClass)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPointerInfo)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAddressSpace)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
@@ -11672,6 +12223,7 @@ CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior)
+CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIntraFnReachability)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAInterFnReachability)
diff --git a/llvm/lib/Transforms/IPO/BlockExtractor.cpp b/llvm/lib/Transforms/IPO/BlockExtractor.cpp
index a68cf7db7c85..0c406aa9822e 100644
--- a/llvm/lib/Transforms/IPO/BlockExtractor.cpp
+++ b/llvm/lib/Transforms/IPO/BlockExtractor.cpp
@@ -17,8 +17,6 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MemoryBuffer.h"
diff --git a/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp b/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp
index 64bfcb2a9a9f..2c8756c07f87 100644
--- a/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp
+++ b/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp
@@ -21,8 +21,6 @@
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/MDBuilder.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
@@ -405,33 +403,3 @@ PreservedAnalyses CalledValuePropagationPass::run(Module &M,
runCVP(M);
return PreservedAnalyses::all();
}
-
-namespace {
-class CalledValuePropagationLegacyPass : public ModulePass {
-public:
- static char ID;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
-
- CalledValuePropagationLegacyPass() : ModulePass(ID) {
- initializeCalledValuePropagationLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
- return runCVP(M);
- }
-};
-} // namespace
-
-char CalledValuePropagationLegacyPass::ID = 0;
-INITIALIZE_PASS(CalledValuePropagationLegacyPass, "called-value-propagation",
- "Called Value Propagation", false, false)
-
-ModulePass *llvm::createCalledValuePropagationPass() {
- return new CalledValuePropagationLegacyPass();
-}
diff --git a/llvm/lib/Transforms/IPO/ConstantMerge.cpp b/llvm/lib/Transforms/IPO/ConstantMerge.cpp
index 77bc377f4514..29052c8d997e 100644
--- a/llvm/lib/Transforms/IPO/ConstantMerge.cpp
+++ b/llvm/lib/Transforms/IPO/ConstantMerge.cpp
@@ -28,8 +28,6 @@
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/IPO.h"
#include <algorithm>
@@ -251,32 +249,3 @@ PreservedAnalyses ConstantMergePass::run(Module &M, ModuleAnalysisManager &) {
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
-
-namespace {
-
-struct ConstantMergeLegacyPass : public ModulePass {
- static char ID; // Pass identification, replacement for typeid
-
- ConstantMergeLegacyPass() : ModulePass(ID) {
- initializeConstantMergeLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- // For this pass, process all of the globals in the module, eliminating
- // duplicate constants.
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
- return mergeConstants(M);
- }
-};
-
-} // end anonymous namespace
-
-char ConstantMergeLegacyPass::ID = 0;
-
-INITIALIZE_PASS(ConstantMergeLegacyPass, "constmerge",
- "Merge Duplicate Global Constants", false, false)
-
-ModulePass *llvm::createConstantMergePass() {
- return new ConstantMergeLegacyPass();
-}
diff --git a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
index 4fe7bb6c757c..93d15f59a036 100644
--- a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
+++ b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
@@ -14,7 +14,6 @@
#include "llvm/Transforms/IPO/CrossDSOCFI.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalObject.h"
@@ -23,8 +22,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO.h"
using namespace llvm;
@@ -35,28 +33,16 @@ STATISTIC(NumTypeIds, "Number of unique type identifiers");
namespace {
-struct CrossDSOCFI : public ModulePass {
- static char ID;
- CrossDSOCFI() : ModulePass(ID) {
- initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry());
- }
-
+struct CrossDSOCFI {
MDNode *VeryLikelyWeights;
ConstantInt *extractNumericTypeId(MDNode *MD);
void buildCFICheck(Module &M);
- bool runOnModule(Module &M) override;
+ bool runOnModule(Module &M);
};
} // anonymous namespace
-INITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false,
- false)
-INITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false)
-char CrossDSOCFI::ID = 0;
-
-ModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; }
-
/// Extracts a numeric type identifier from an MDNode containing type metadata.
ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) {
// This check excludes vtables for classes inside anonymous namespaces.
diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
index bf2c65a2402c..01834015f3fd 100644
--- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -16,9 +16,11 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Transforms/IPO/DeadArgumentElimination.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Argument.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
@@ -43,7 +45,6 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
-#include "llvm/Transforms/IPO/DeadArgumentElimination.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <utility>
@@ -85,6 +86,11 @@ public:
virtual bool shouldHackArguments() const { return false; }
};
+bool isMustTailCalleeAnalyzable(const CallBase &CB) {
+ assert(CB.isMustTailCall());
+ return CB.getCalledFunction() && !CB.getCalledFunction()->isDeclaration();
+}
+
} // end anonymous namespace
char DAE::ID = 0;
@@ -520,8 +526,16 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) {
for (const BasicBlock &BB : F) {
// If we have any returns of `musttail` results - the signature can't
// change
- if (BB.getTerminatingMustTailCall() != nullptr)
+ if (const auto *TC = BB.getTerminatingMustTailCall()) {
HasMustTailCalls = true;
+ // In addition, if the called function is not locally defined (or unknown,
+ // if this is an indirect call), we can't change the callsite and thus
+ // can't change this function's signature either.
+ if (!isMustTailCalleeAnalyzable(*TC)) {
+ markLive(F);
+ return;
+ }
+ }
}
if (HasMustTailCalls) {
@@ -1081,6 +1095,26 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) {
return true;
}
+void DeadArgumentEliminationPass::propagateVirtMustcallLiveness(
+ const Module &M) {
+ // If a function was marked "live", and it has musttail callers, they in turn
+ // can't change either.
+ LiveFuncSet NewLiveFuncs(LiveFunctions);
+ while (!NewLiveFuncs.empty()) {
+ LiveFuncSet Temp;
+ for (const auto *F : NewLiveFuncs)
+ for (const auto *U : F->users())
+ if (const auto *CB = dyn_cast<CallBase>(U))
+ if (CB->isMustTailCall())
+ if (!LiveFunctions.count(CB->getParent()->getParent()))
+ Temp.insert(CB->getParent()->getParent());
+ NewLiveFuncs.clear();
+ NewLiveFuncs.insert(Temp.begin(), Temp.end());
+ for (const auto *F : Temp)
+ markLive(*F);
+ }
+}
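// propagateVirtMustcallLiveness is a worklist fixed point: whenever a function
// becomes live, every caller reaching it through a musttail call becomes live
// as well. A minimal standalone analogue over names (the real pass walks
// Function::users() and CallBase::isMustTailCall()):
#include <map>
#include <set>
#include <string>
#include <vector>

// Maps a callee to the functions that musttail-call it.
using MustTailCallers = std::map<std::string, std::vector<std::string>>;

void propagateLiveness(const MustTailCallers &Callers,
                       std::set<std::string> &Live) {
  std::vector<std::string> Worklist(Live.begin(), Live.end());
  while (!Worklist.empty()) {
    std::string F = Worklist.back();
    Worklist.pop_back();
    auto It = Callers.find(F);
    if (It == Callers.end())
      continue;
    for (const std::string &Caller : It->second)
      if (Live.insert(Caller).second) // newly live, keep propagating
        Worklist.push_back(Caller);
  }
}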
+
PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
ModuleAnalysisManager &) {
bool Changed = false;
@@ -1101,6 +1135,8 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
for (auto &F : M)
surveyFunction(F);
+ propagateVirtMustcallLiveness(M);
+
// Now, remove all dead arguments and return values from each function in
// turn. We use make_early_inc_range here because functions will probably get
// removed (i.e. replaced by new ones).
diff --git a/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp b/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp
index 7f138d206fac..2b34d3b5a56e 100644
--- a/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp
+++ b/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp
@@ -12,24 +12,82 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/ElimAvailExtern.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Constant.h"
+#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/GlobalStatus.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace llvm;
#define DEBUG_TYPE "elim-avail-extern"
-STATISTIC(NumFunctions, "Number of functions removed");
+cl::opt<bool> ConvertToLocal(
+ "avail-extern-to-local", cl::Hidden,
+ cl::desc("Convert available_externally into locals, renaming them "
+ "to avoid link-time clashes."));
+
+STATISTIC(NumRemovals, "Number of functions removed");
+STATISTIC(NumConversions, "Number of functions converted");
STATISTIC(NumVariables, "Number of global variables removed");
+void deleteFunction(Function &F) {
+ // This will set the linkage to external
+ F.deleteBody();
+ ++NumRemovals;
+}
+
+/// Create a copy of the ThinLTO import, mark it local, and redirect direct
+/// calls to the copy. Only direct calls are replaced, so that, e.g., indirect
+/// call function pointer tests still use the global identity of the function.
+///
+/// Currently, Value Profiling ("VP") MD_prof data isn't updated to refer to the
+/// clone's GUID (which will be different, because the name and linkage are
+/// different), under the assumption that the last consumer of this data sits
+/// upstream in the pipeline (e.g. ICP).
+static void convertToLocalCopy(Module &M, Function &F) {
+ assert(F.hasAvailableExternallyLinkage());
+ assert(!F.isDeclaration());
+ // If we can't find a single use that's a call, just delete the function.
+ if (F.uses().end() == llvm::find_if(F.uses(), [&](Use &U) {
+ return isa<CallBase>(U.getUser());
+ }))
+ return deleteFunction(F);
+
+ auto OrigName = F.getName().str();
+ // Build a new name. We still need the old name (see below).
+ // We could just rely on internal linkage allowing two modules to have
+ // internal functions with the same name, but that creates more trouble than
+ // necessary, e.g. when distinguishing profiles or debugging. Instead, we
+ // append the module identifier.
+ auto NewName = OrigName + ".__uniq" + getUniqueModuleId(&M);
+ F.setName(NewName);
+ if (auto *SP = F.getSubprogram())
+ SP->replaceLinkageName(MDString::get(F.getParent()->getContext(), NewName));
+
+ F.setLinkage(GlobalValue::InternalLinkage);
+ // Now make a declaration for the old name. We'll use it if there are non-call
+ // uses. For those, it would be incorrect to replace them with the local copy:
+ // for example, one such use could be taking the address of the function and
+ // passing it to an external function, which, in turn, might compare the
+ // function pointer to the original (non-local) function pointer, e.g. as part
+ // of indirect call promotion.
+ auto *Decl =
+ Function::Create(F.getFunctionType(), GlobalValue::ExternalLinkage,
+ F.getAddressSpace(), OrigName, F.getParent());
+ F.replaceUsesWithIf(Decl,
+ [&](Use &U) { return !isa<CallBase>(U.getUser()); });
+ ++NumConversions;
+}
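// The policy implemented by convertToLocalCopy above: direct calls are
// redirected to the renamed local copy, while address-taken and other uses
// keep referring to a declaration with the original name so function-pointer
// identity is preserved. A standalone sketch of that use partitioning
// (UseKind/UseSite are illustrative types, not LLVM's):
#include <vector>

enum class UseKind { DirectCall, Other };
struct UseSite { UseKind Kind; };

struct RedirectionPlan {
  std::vector<UseSite *> ToLocalCopy;  // direct calls
  std::vector<UseSite *> KeepExternal; // address-taken and other uses
};

RedirectionPlan planRedirection(std::vector<UseSite> &Uses) {
  RedirectionPlan Plan;
  for (UseSite &U : Uses)
    (U.Kind == UseKind::DirectCall ? Plan.ToLocalCopy : Plan.KeepExternal)
        .push_back(&U);
  return Plan;
}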
+
static bool eliminateAvailableExternally(Module &M) {
bool Changed = false;
@@ -45,19 +103,21 @@ static bool eliminateAvailableExternally(Module &M) {
}
GV.removeDeadConstantUsers();
GV.setLinkage(GlobalValue::ExternalLinkage);
- NumVariables++;
+ ++NumVariables;
Changed = true;
}
// Drop the bodies of available externally functions.
- for (Function &F : M) {
- if (!F.hasAvailableExternallyLinkage())
+ for (Function &F : llvm::make_early_inc_range(M)) {
+ if (F.isDeclaration() || !F.hasAvailableExternallyLinkage())
continue;
- if (!F.isDeclaration())
- // This will set the linkage to external
- F.deleteBody();
+
+ if (ConvertToLocal)
+ convertToLocalCopy(M, F);
+ else
+ deleteFunction(F);
+
F.removeDeadConstantUsers();
- NumFunctions++;
Changed = true;
}
@@ -70,33 +130,3 @@ EliminateAvailableExternallyPass::run(Module &M, ModuleAnalysisManager &) {
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
-
-namespace {
-
-struct EliminateAvailableExternallyLegacyPass : public ModulePass {
- static char ID; // Pass identification, replacement for typeid
-
- EliminateAvailableExternallyLegacyPass() : ModulePass(ID) {
- initializeEliminateAvailableExternallyLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- // run - Do the EliminateAvailableExternally pass on the specified module,
- // optionally updating the specified callgraph to reflect the changes.
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
- return eliminateAvailableExternally(M);
- }
-};
-
-} // end anonymous namespace
-
-char EliminateAvailableExternallyLegacyPass::ID = 0;
-
-INITIALIZE_PASS(EliminateAvailableExternallyLegacyPass, "elim-avail-extern",
- "Eliminate Available Externally Globals", false, false)
-
-ModulePass *llvm::createEliminateAvailableExternallyPass() {
- return new EliminateAvailableExternallyLegacyPass();
-}
diff --git a/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp
new file mode 100644
index 000000000000..fa56a5b564ae
--- /dev/null
+++ b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp
@@ -0,0 +1,52 @@
+//===- EmbedBitcodePass.cpp - Pass that embeds the bitcode into a global---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/EmbedBitcodePass.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/Bitcode/BitcodeWriterPass.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
+#include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#include <memory>
+#include <string>
+
+using namespace llvm;
+
+PreservedAnalyses EmbedBitcodePass::run(Module &M, ModuleAnalysisManager &AM) {
+ if (M.getGlobalVariable("llvm.embedded.module", /*AllowInternal=*/true))
+ report_fatal_error("Can only embed the module once",
+ /*gen_crash_diag=*/false);
+
+ Triple T(M.getTargetTriple());
+ if (T.getObjectFormat() != Triple::ELF)
+ report_fatal_error(
+ "EmbedBitcode pass currently only supports ELF object format",
+ /*gen_crash_diag=*/false);
+
+ std::unique_ptr<Module> NewModule = CloneModule(M);
+ MPM.run(*NewModule, AM);
+
+ std::string Data;
+ raw_string_ostream OS(Data);
+ if (IsThinLTO)
+ ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(*NewModule, AM);
+ else
+ BitcodeWriterPass(OS, /*ShouldPreserveUseListOrder=*/false, EmitLTOSummary)
+ .run(*NewModule, AM);
+
+ embedBufferInModule(M, MemoryBufferRef(Data, "ModuleData"), ".llvm.lto");
+
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Transforms/IPO/ExtractGV.cpp b/llvm/lib/Transforms/IPO/ExtractGV.cpp
index d5073eed2fef..6414ea69c9f7 100644
--- a/llvm/lib/Transforms/IPO/ExtractGV.cpp
+++ b/llvm/lib/Transforms/IPO/ExtractGV.cpp
@@ -36,7 +36,7 @@ static void makeVisible(GlobalValue &GV, bool Delete) {
}
// Map linkonce* to weak* so that llvm doesn't drop this GV.
- switch(GV.getLinkage()) {
+ switch (GV.getLinkage()) {
default:
llvm_unreachable("Unexpected linkage");
case GlobalValue::LinkOnceAnyLinkage:
@@ -48,10 +48,9 @@ static void makeVisible(GlobalValue &GV, bool Delete) {
}
}
-
- /// If deleteS is true, this pass deletes the specified global values.
- /// Otherwise, it deletes as much of the module as possible, except for the
- /// global values specified.
+/// If deleteS is true, this pass deletes the specified global values.
+/// Otherwise, it deletes as much of the module as possible, except for the
+/// global values specified.
ExtractGVPass::ExtractGVPass(std::vector<GlobalValue *> &GVs, bool deleteS,
bool keepConstInit)
: Named(GVs.begin(), GVs.end()), deleteStuff(deleteS),
@@ -129,5 +128,22 @@ PreservedAnalyses ExtractGVPass::run(Module &M, ModuleAnalysisManager &) {
}
}
+ // Visit the IFuncs.
+ for (GlobalIFunc &IF : llvm::make_early_inc_range(M.ifuncs())) {
+ bool Delete = deleteStuff == (bool)Named.count(&IF);
+ makeVisible(IF, Delete);
+
+ if (!Delete)
+ continue;
+
+ auto *FuncType = dyn_cast<FunctionType>(IF.getValueType());
+ IF.removeFromParent();
+ llvm::Value *Declaration =
+ Function::Create(FuncType, GlobalValue::ExternalLinkage,
+ IF.getAddressSpace(), IF.getName(), &M);
+ IF.replaceAllUsesWith(Declaration);
+ delete &IF;
+ }
+
return PreservedAnalyses::none();
}
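The ifunc handling added above boils down to one transformation: an ifunc that
is not kept cannot simply become a declaration (an ifunc always carries a
resolver), so it is replaced by an external function declaration of the same
type. A reduced sketch of that step (illustration only; the helper name is
invented):

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalIFunc.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Swap IF for a plain external declaration so remaining users still reference
// a callable symbol, then destroy the ifunc itself.
static void replaceIFuncWithDeclaration(GlobalIFunc &IF, Module &M) {
  auto *FuncType = cast<FunctionType>(IF.getValueType());
  IF.removeFromParent();
  Function *Decl = Function::Create(FuncType, GlobalValue::ExternalLinkage,
                                    IF.getAddressSpace(), IF.getName(), &M);
  IF.replaceAllUsesWith(Decl);
  delete &IF;
}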
diff --git a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp
index b10c2ea13469..74931e1032d1 100644
--- a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp
@@ -9,8 +9,6 @@
#include "llvm/Transforms/IPO/ForceFunctionAttrs.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -80,32 +78,3 @@ PreservedAnalyses ForceFunctionAttrsPass::run(Module &M,
// Just conservatively invalidate analyses, this isn't likely to be important.
return PreservedAnalyses::none();
}
-
-namespace {
-struct ForceFunctionAttrsLegacyPass : public ModulePass {
- static char ID; // Pass identification, replacement for typeid
- ForceFunctionAttrsLegacyPass() : ModulePass(ID) {
- initializeForceFunctionAttrsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (!hasForceAttributes())
- return false;
-
- for (Function &F : M.functions())
- forceAttributes(F);
-
- // Conservatively assume we changed something.
- return true;
- }
-};
-}
-
-char ForceFunctionAttrsLegacyPass::ID = 0;
-INITIALIZE_PASS(ForceFunctionAttrsLegacyPass, "forceattrs",
- "Force set function attributes", false, false)
-
-Pass *llvm::createForceFunctionAttrsLegacyPass() {
- return new ForceFunctionAttrsLegacyPass();
-}
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index 3f61dbe3354e..34299f9dbb23 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -50,8 +50,6 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
@@ -154,7 +152,7 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody,
// If it's not an identified object, it might be an argument.
if (!isIdentifiedObject(UO))
ME |= MemoryEffects::argMemOnly(MR);
- ME |= MemoryEffects(MemoryEffects::Other, MR);
+ ME |= MemoryEffects(IRMemLocation::Other, MR);
};
// Scan the function body for instructions that may read or write memory.
for (Instruction &I : instructions(F)) {
@@ -181,17 +179,17 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody,
if (isa<PseudoProbeInst>(I))
continue;
- ME |= CallME.getWithoutLoc(MemoryEffects::ArgMem);
+ ME |= CallME.getWithoutLoc(IRMemLocation::ArgMem);
// If the call accesses captured memory (currently part of "other") and
// an argument is captured (currently not tracked), then it may also
// access argument memory.
- ModRefInfo OtherMR = CallME.getModRef(MemoryEffects::Other);
+ ModRefInfo OtherMR = CallME.getModRef(IRMemLocation::Other);
ME |= MemoryEffects::argMemOnly(OtherMR);
// Check whether all pointer arguments point to local memory, and
// ignore calls that only access local memory.
- ModRefInfo ArgMR = CallME.getModRef(MemoryEffects::ArgMem);
+ ModRefInfo ArgMR = CallME.getModRef(IRMemLocation::ArgMem);
if (ArgMR != ModRefInfo::NoModRef) {
for (const Use &U : Call->args()) {
const Value *Arg = U;
@@ -640,7 +638,7 @@ determinePointerAccessAttrs(Argument *A,
if (Visited.insert(&UU).second)
Worklist.push_back(&UU);
}
-
+
if (CB.doesNotAccessMemory())
continue;
@@ -723,18 +721,18 @@ static void addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes,
continue;
// There is nothing to do if an argument is already marked as 'returned'.
- if (llvm::any_of(F->args(),
- [](const Argument &Arg) { return Arg.hasReturnedAttr(); }))
+ if (F->getAttributes().hasAttrSomewhere(Attribute::Returned))
continue;
- auto FindRetArg = [&]() -> Value * {
- Value *RetArg = nullptr;
+ auto FindRetArg = [&]() -> Argument * {
+ Argument *RetArg = nullptr;
for (BasicBlock &BB : *F)
if (auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator())) {
// Note that stripPointerCasts should look through functions with
// returned arguments.
- Value *RetVal = Ret->getReturnValue()->stripPointerCasts();
- if (!isa<Argument>(RetVal) || RetVal->getType() != F->getReturnType())
+ auto *RetVal =
+ dyn_cast<Argument>(Ret->getReturnValue()->stripPointerCasts());
+ if (!RetVal || RetVal->getType() != F->getReturnType())
return nullptr;
if (!RetArg)
@@ -746,9 +744,8 @@ static void addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes,
return RetArg;
};
- if (Value *RetArg = FindRetArg()) {
- auto *A = cast<Argument>(RetArg);
- A->addAttr(Attribute::Returned);
+ if (Argument *RetArg = FindRetArg()) {
+ RetArg->addAttr(Attribute::Returned);
++NumReturned;
Changed.insert(F);
}
@@ -1379,7 +1376,7 @@ static bool InstrBreaksNonConvergent(Instruction &I,
/// Helper for NoUnwind inference predicate InstrBreaksAttribute.
static bool InstrBreaksNonThrowing(Instruction &I, const SCCNodeSet &SCCNodes) {
- if (!I.mayThrow())
+ if (!I.mayThrow(/* IncludePhaseOneUnwind */ true))
return false;
if (const auto *CI = dyn_cast<CallInst>(&I)) {
if (Function *Callee = CI->getCalledFunction()) {
@@ -1410,6 +1407,61 @@ static bool InstrBreaksNoFree(Instruction &I, const SCCNodeSet &SCCNodes) {
return true;
}
+// Return true if this is an atomic which has an ordering stronger than
+// unordered. Note that this is different than the predicate we use in
+// Attributor. Here we chose to be conservative and consider monotonic
+// operations potentially synchronizing. We generally don't do much with
+// monotonic operations, so this is simply risk reduction.
+static bool isOrderedAtomic(Instruction *I) {
+ if (!I->isAtomic())
+ return false;
+
+ if (auto *FI = dyn_cast<FenceInst>(I))
+ // All legal orderings for fence are stronger than monotonic.
+ return FI->getSyncScopeID() != SyncScope::SingleThread;
+ else if (isa<AtomicCmpXchgInst>(I) || isa<AtomicRMWInst>(I))
+ return true;
+ else if (auto *SI = dyn_cast<StoreInst>(I))
+ return !SI->isUnordered();
+ else if (auto *LI = dyn_cast<LoadInst>(I))
+ return !LI->isUnordered();
+ else {
+ llvm_unreachable("unknown atomic instruction?");
+ }
+}
+
+static bool InstrBreaksNoSync(Instruction &I, const SCCNodeSet &SCCNodes) {
+ // Volatile may synchronize
+ if (I.isVolatile())
+ return true;
+
+  // An ordered atomic may synchronize. (See comment above on monotonic.)
+ if (isOrderedAtomic(&I))
+ return true;
+
+ auto *CB = dyn_cast<CallBase>(&I);
+ if (!CB)
+    // Non-call-site cases are covered by the two checks above.
+ return false;
+
+ if (CB->hasFnAttr(Attribute::NoSync))
+ return false;
+
+  // Non-volatile memset/memcpy/memmoves are nosync.
+ // NOTE: Only intrinsics with volatile flags should be handled here. All
+ // others should be marked in Intrinsics.td.
+ if (auto *MI = dyn_cast<MemIntrinsic>(&I))
+ if (!MI->isVolatile())
+ return false;
+
+ // Speculatively assume in SCC.
+ if (Function *Callee = CB->getCalledFunction())
+ if (SCCNodes.contains(Callee))
+ return false;
+
+ return true;
+}
+
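+// Illustrative sketch (not part of this change): how the two predicates above
+// would classify a few hand-built instructions. Everything below is invented
+// demo scaffolding (needs llvm/IR/IRBuilder.h and llvm/IR/Module.h); only the
+// IRBuilder calls and the stated classifications are meant to be accurate.
+static void nosyncClassificationDemo(LLVMContext &Ctx) {
+  Module M("nosync-demo", Ctx);
+  FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FnTy, GlobalValue::InternalLinkage, "demo", M);
+  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
+  IRBuilder<> B(BB);
+  AllocaInst *Slot = B.CreateAlloca(B.getInt32Ty());
+
+  // Plain load: not volatile, not atomic -> does not break nosync.
+  B.CreateLoad(B.getInt32Ty(), Slot);
+
+  // Volatile load: InstrBreaksNoSync returns true (volatile may synchronize).
+  B.CreateLoad(B.getInt32Ty(), Slot, /*isVolatile=*/true);
+
+  // Monotonic atomic load: isOrderedAtomic treats it as ordered, so it breaks
+  // nosync under the conservative policy described above.
+  LoadInst *AtomicLoad = B.CreateLoad(B.getInt32Ty(), Slot);
+  AtomicLoad->setAtomic(AtomicOrdering::Monotonic);
+
+  // Single-thread fence: isOrderedAtomic returns false for it, so it does not
+  // break nosync on its own.
+  B.CreateFence(AtomicOrdering::SequentiallyConsistent, SyncScope::SingleThread);
+
+  B.CreateRetVoid();
+}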
/// Attempt to remove convergent function attribute when possible.
///
/// Returns true if any changes to function attributes were made.
@@ -1441,9 +1493,7 @@ static void inferConvergent(const SCCNodeSet &SCCNodes,
}
/// Infer attributes from all functions in the SCC by scanning every
-/// instruction for compliance to the attribute assumptions. Currently it
-/// does:
-/// - addition of NoUnwind attribute
+/// instruction for compliance to the attribute assumptions.
///
/// Returns true if any changes to function attributes were made.
static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes,
@@ -1495,6 +1545,22 @@ static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes,
},
/* RequiresExactDefinition= */ true});
+ AI.registerAttrInference(AttributeInferer::InferenceDescriptor{
+ Attribute::NoSync,
+ // Skip already marked functions.
+ [](const Function &F) { return F.hasNoSync(); },
+ // Instructions that break nosync assumption.
+ [&SCCNodes](Instruction &I) {
+ return InstrBreaksNoSync(I, SCCNodes);
+ },
+ [](Function &F) {
+ LLVM_DEBUG(dbgs()
+ << "Adding nosync attr to fn " << F.getName() << "\n");
+ F.setNoSync();
+ ++NumNoSync;
+ },
+ /* RequiresExactDefinition= */ true});
+
// Perform all the requested attribute inference actions.
AI.run(SCCNodes, Changed);
}
@@ -1622,83 +1688,6 @@ static void addWillReturn(const SCCNodeSet &SCCNodes,
}
}
-// Return true if this is an atomic which has an ordering stronger than
-// unordered. Note that this is different than the predicate we use in
-// Attributor. Here we chose to be conservative and consider monotonic
-// operations potentially synchronizing. We generally don't do much with
-// monotonic operations, so this is simply risk reduction.
-static bool isOrderedAtomic(Instruction *I) {
- if (!I->isAtomic())
- return false;
-
- if (auto *FI = dyn_cast<FenceInst>(I))
- // All legal orderings for fence are stronger than monotonic.
- return FI->getSyncScopeID() != SyncScope::SingleThread;
- else if (isa<AtomicCmpXchgInst>(I) || isa<AtomicRMWInst>(I))
- return true;
- else if (auto *SI = dyn_cast<StoreInst>(I))
- return !SI->isUnordered();
- else if (auto *LI = dyn_cast<LoadInst>(I))
- return !LI->isUnordered();
- else {
- llvm_unreachable("unknown atomic instruction?");
- }
-}
-
-static bool InstrBreaksNoSync(Instruction &I, const SCCNodeSet &SCCNodes) {
- // Volatile may synchronize
- if (I.isVolatile())
- return true;
-
- // An ordered atomic may synchronize. (See comment about on monotonic.)
- if (isOrderedAtomic(&I))
- return true;
-
- auto *CB = dyn_cast<CallBase>(&I);
- if (!CB)
- // Non call site cases covered by the two checks above
- return false;
-
- if (CB->hasFnAttr(Attribute::NoSync))
- return false;
-
- // Non volatile memset/memcpy/memmoves are nosync
- // NOTE: Only intrinsics with volatile flags should be handled here. All
- // others should be marked in Intrinsics.td.
- if (auto *MI = dyn_cast<MemIntrinsic>(&I))
- if (!MI->isVolatile())
- return false;
-
- // Speculatively assume in SCC.
- if (Function *Callee = CB->getCalledFunction())
- if (SCCNodes.contains(Callee))
- return false;
-
- return true;
-}
-
-// Infer the nosync attribute.
-static void addNoSyncAttr(const SCCNodeSet &SCCNodes,
- SmallSet<Function *, 8> &Changed) {
- AttributeInferer AI;
- AI.registerAttrInference(AttributeInferer::InferenceDescriptor{
- Attribute::NoSync,
- // Skip already marked functions.
- [](const Function &F) { return F.hasNoSync(); },
- // Instructions that break nosync assumption.
- [&SCCNodes](Instruction &I) {
- return InstrBreaksNoSync(I, SCCNodes);
- },
- [](Function &F) {
- LLVM_DEBUG(dbgs()
- << "Adding nosync attr to fn " << F.getName() << "\n");
- F.setNoSync();
- ++NumNoSync;
- },
- /* RequiresExactDefinition= */ true});
- AI.run(SCCNodes, Changed);
-}
-
static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) {
SCCNodesResult Res;
Res.HasUnknownCall = false;
@@ -1756,8 +1745,6 @@ deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) {
addNoRecurseAttrs(Nodes.SCCNodes, Changed);
}
- addNoSyncAttr(Nodes.SCCNodes, Changed);
-
// Finally, infer the maximal set of attributes from the ones we've inferred
// above. This is handling the cases where one attribute on a signature
// implies another, but for implementation reasons the inference rule for
@@ -1774,6 +1761,13 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C,
CGSCCAnalysisManager &AM,
LazyCallGraph &CG,
CGSCCUpdateResult &) {
+ // Skip non-recursive functions if requested.
+ if (C.size() == 1 && SkipNonRecursive) {
+ LazyCallGraph::Node &N = *C.begin();
+ if (!N->lookup(N))
+ return PreservedAnalyses::all();
+ }
+
FunctionAnalysisManager &FAM =
AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
@@ -1819,40 +1813,12 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C,
return PA;
}
-namespace {
-
-struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass {
- // Pass identification, replacement for typeid
- static char ID;
-
- PostOrderFunctionAttrsLegacyPass() : CallGraphSCCPass(ID) {
- initializePostOrderFunctionAttrsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnSCC(CallGraphSCC &SCC) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<AssumptionCacheTracker>();
- getAAResultsAnalysisUsage(AU);
- CallGraphSCCPass::getAnalysisUsage(AU);
- }
-};
-
-} // end anonymous namespace
-
-char PostOrderFunctionAttrsLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrsLegacyPass, "function-attrs",
- "Deduce function attributes", false, false)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
-INITIALIZE_PASS_END(PostOrderFunctionAttrsLegacyPass, "function-attrs",
- "Deduce function attributes", false, false)
-
-Pass *llvm::createPostOrderFunctionAttrsLegacyPass() {
- return new PostOrderFunctionAttrsLegacyPass();
+void PostOrderFunctionAttrsPass::printPipeline(
+ raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
+ static_cast<PassInfoMixin<PostOrderFunctionAttrsPass> *>(this)->printPipeline(
+ OS, MapClassName2PassName);
+ if (SkipNonRecursive)
+ OS << "<skip-non-recursive>";
}
template <typename AARGetterT>
@@ -1865,48 +1831,6 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) {
return !deriveAttrsInPostOrder(Functions, AARGetter).empty();
}
-bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) {
- if (skipSCC(SCC))
- return false;
- return runImpl(SCC, LegacyAARGetter(*this));
-}
-
-namespace {
-
-struct ReversePostOrderFunctionAttrsLegacyPass : public ModulePass {
- // Pass identification, replacement for typeid
- static char ID;
-
- ReversePostOrderFunctionAttrsLegacyPass() : ModulePass(ID) {
- initializeReversePostOrderFunctionAttrsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<CallGraphWrapperPass>();
- AU.addPreserved<CallGraphWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
-char ReversePostOrderFunctionAttrsLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(ReversePostOrderFunctionAttrsLegacyPass,
- "rpo-function-attrs", "Deduce function attributes in RPO",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
-INITIALIZE_PASS_END(ReversePostOrderFunctionAttrsLegacyPass,
- "rpo-function-attrs", "Deduce function attributes in RPO",
- false, false)
-
-Pass *llvm::createReversePostOrderFunctionAttrsPass() {
- return new ReversePostOrderFunctionAttrsLegacyPass();
-}
-
static bool addNoRecurseAttrsTopDown(Function &F) {
// We check the preconditions for the function prior to calling this to avoid
// the cost of building up a reversible post-order list. We assert them here
@@ -1939,7 +1863,7 @@ static bool addNoRecurseAttrsTopDown(Function &F) {
return true;
}
-static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) {
+static bool deduceFunctionAttributeInRPO(Module &M, LazyCallGraph &CG) {
// We only have a post-order SCC traversal (because SCCs are inherently
// discovered in post-order), so we accumulate them in a vector and then walk
// it in reverse. This is simpler than using the RPO iterator infrastructure
@@ -1947,17 +1871,18 @@ static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) {
// graph. We can also cheat egregiously because we're primarily interested in
// synthesizing norecurse and so we can only save the singular SCCs as SCCs
// with multiple functions in them will clearly be recursive.
- SmallVector<Function *, 16> Worklist;
- for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) {
- if (I->size() != 1)
- continue;
- Function *F = I->front()->getFunction();
- if (F && !F->isDeclaration() && !F->doesNotRecurse() &&
- F->hasInternalLinkage())
- Worklist.push_back(F);
+ SmallVector<Function *, 16> Worklist;
+ CG.buildRefSCCs();
+ for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) {
+ for (LazyCallGraph::SCC &SCC : RC) {
+ if (SCC.size() != 1)
+ continue;
+ Function &F = SCC.begin()->getFunction();
+ if (!F.isDeclaration() && !F.doesNotRecurse() && F.hasInternalLinkage())
+ Worklist.push_back(&F);
+ }
}
-
bool Changed = false;
for (auto *F : llvm::reverse(Worklist))
Changed |= addNoRecurseAttrsTopDown(*F);
@@ -1965,23 +1890,14 @@ static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) {
return Changed;
}
-bool ReversePostOrderFunctionAttrsLegacyPass::runOnModule(Module &M) {
- if (skipModule(M))
- return false;
-
- auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
-
- return deduceFunctionAttributeInRPO(M, CG);
-}
-
PreservedAnalyses
ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) {
- auto &CG = AM.getResult<CallGraphAnalysis>(M);
+ auto &CG = AM.getResult<LazyCallGraphAnalysis>(M);
if (!deduceFunctionAttributeInRPO(M, CG))
return PreservedAnalyses::all();
PreservedAnalyses PA;
- PA.preserve<CallGraphAnalysis>();
+ PA.preserve<LazyCallGraphAnalysis>();
return PA;
}
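With the legacy PassManager wrappers removed above, function-attrs is reachable
only through the new pass manager. A minimal driver sketch follows (not part of
this change; it assumes the pipeline parameter is registered under the same name
the printPipeline hook above emits):

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Error.h"

using namespace llvm;

// Build the analysis managers, parse a textual pipeline enabling the new
// skip-non-recursive behaviour, and run it over a module.
static void runFunctionAttrs(Module &M) {
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  ModulePassManager MPM;
  if (Error Err = PB.parsePassPipeline(
          MPM, "cgscc(function-attrs<skip-non-recursive>)"))
    report_fatal_error(std::move(Err));
  MPM.run(M, MAM);
}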
diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp
index 7c994657e5c8..f635b14cd2a9 100644
--- a/llvm/lib/Transforms/IPO/FunctionImport.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp
@@ -30,9 +30,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndex.h"
#include "llvm/IRReader/IRReader.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Linker/IRMover.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -159,39 +157,37 @@ static std::unique_ptr<Module> loadFile(const std::string &FileName,
return Result;
}
-/// Given a list of possible callee implementation for a call site, select one
-/// that fits the \p Threshold.
-///
-/// FIXME: select "best" instead of first that fits. But what is "best"?
-/// - The smallest: more likely to be inlined.
-/// - The one with the least outgoing edges (already well optimized).
-/// - One from a module already being imported from in order to reduce the
-/// number of source modules parsed/linked.
-/// - One that has PGO data attached.
-/// - [insert you fancy metric here]
-static const GlobalValueSummary *
-selectCallee(const ModuleSummaryIndex &Index,
- ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList,
- unsigned Threshold, StringRef CallerModulePath,
- FunctionImporter::ImportFailureReason &Reason,
- GlobalValue::GUID GUID) {
- Reason = FunctionImporter::ImportFailureReason::None;
- auto It = llvm::find_if(
+/// Given a list of possible callee implementations for a call site, qualify the
+/// legality of importing each. The return is a range of pairs. Each pair
+/// corresponds to a candidate. The first value is the ImportFailureReason for
+/// that candidate, the second is the candidate.
+static auto qualifyCalleeCandidates(
+ const ModuleSummaryIndex &Index,
+ ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList,
+ StringRef CallerModulePath) {
+ return llvm::map_range(
CalleeSummaryList,
- [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) {
+ [&Index, CalleeSummaryList,
+ CallerModulePath](const std::unique_ptr<GlobalValueSummary> &SummaryPtr)
+ -> std::pair<FunctionImporter::ImportFailureReason,
+ const GlobalValueSummary *> {
auto *GVSummary = SummaryPtr.get();
- if (!Index.isGlobalValueLive(GVSummary)) {
- Reason = FunctionImporter::ImportFailureReason::NotLive;
- return false;
- }
+ if (!Index.isGlobalValueLive(GVSummary))
+ return {FunctionImporter::ImportFailureReason::NotLive, GVSummary};
- if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) {
- Reason = FunctionImporter::ImportFailureReason::InterposableLinkage;
- // There is no point in importing these, we can't inline them
- return false;
- }
+ if (GlobalValue::isInterposableLinkage(GVSummary->linkage()))
+ return {FunctionImporter::ImportFailureReason::InterposableLinkage,
+ GVSummary};
- auto *Summary = cast<FunctionSummary>(GVSummary->getBaseObject());
+ auto *Summary = dyn_cast<FunctionSummary>(GVSummary->getBaseObject());
+
+ // Ignore any callees that aren't actually functions. This could happen
+ // in the case of GUID hash collisions. It could also happen in theory
+ // for SamplePGO profiles collected on old versions of the code after
+ // renaming, since we synthesize edges to any inlined callees appearing
+ // in the profile.
+ if (!Summary)
+ return {FunctionImporter::ImportFailureReason::GlobalVar, GVSummary};
// If this is a local function, make sure we import the copy
// in the caller's module. The only time a local function can
@@ -205,119 +201,174 @@ selectCallee(const ModuleSummaryIndex &Index,
// a local in another module.
if (GlobalValue::isLocalLinkage(Summary->linkage()) &&
CalleeSummaryList.size() > 1 &&
- Summary->modulePath() != CallerModulePath) {
- Reason =
- FunctionImporter::ImportFailureReason::LocalLinkageNotInModule;
- return false;
- }
-
- if ((Summary->instCount() > Threshold) &&
- !Summary->fflags().AlwaysInline && !ForceImportAll) {
- Reason = FunctionImporter::ImportFailureReason::TooLarge;
- return false;
- }
+ Summary->modulePath() != CallerModulePath)
+ return {
+ FunctionImporter::ImportFailureReason::LocalLinkageNotInModule,
+ GVSummary};
// Skip if it isn't legal to import (e.g. may reference unpromotable
// locals).
- if (Summary->notEligibleToImport()) {
- Reason = FunctionImporter::ImportFailureReason::NotEligible;
- return false;
- }
+ if (Summary->notEligibleToImport())
+ return {FunctionImporter::ImportFailureReason::NotEligible,
+ GVSummary};
- // Don't bother importing if we can't inline it anyway.
- if (Summary->fflags().NoInline && !ForceImportAll) {
- Reason = FunctionImporter::ImportFailureReason::NoInline;
- return false;
- }
-
- return true;
+ return {FunctionImporter::ImportFailureReason::None, GVSummary};
});
- if (It == CalleeSummaryList.end())
- return nullptr;
+}
+
+/// Given a list of possible callee implementations for a call site, select one
+/// that fits the \p Threshold. If none are found, the Reason will give the last
+/// reason for the failure (last, in the order of CalleeSummaryList entries).
+///
+/// FIXME: select "best" instead of first that fits. But what is "best"?
+/// - The smallest: more likely to be inlined.
+/// - The one with the least outgoing edges (already well optimized).
+/// - One from a module already being imported from in order to reduce the
+/// number of source modules parsed/linked.
+/// - One that has PGO data attached.
+/// - [insert your fancy metric here]
+static const GlobalValueSummary *
+selectCallee(const ModuleSummaryIndex &Index,
+ ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList,
+ unsigned Threshold, StringRef CallerModulePath,
+ FunctionImporter::ImportFailureReason &Reason) {
+ auto QualifiedCandidates =
+ qualifyCalleeCandidates(Index, CalleeSummaryList, CallerModulePath);
+ for (auto QualifiedValue : QualifiedCandidates) {
+ Reason = QualifiedValue.first;
+ if (Reason != FunctionImporter::ImportFailureReason::None)
+ continue;
+ auto *Summary =
+ cast<FunctionSummary>(QualifiedValue.second->getBaseObject());
+
+ if ((Summary->instCount() > Threshold) && !Summary->fflags().AlwaysInline &&
+ !ForceImportAll) {
+ Reason = FunctionImporter::ImportFailureReason::TooLarge;
+ continue;
+ }
- return cast<GlobalValueSummary>(It->get());
+ // Don't bother importing if we can't inline it anyway.
+ if (Summary->fflags().NoInline && !ForceImportAll) {
+ Reason = FunctionImporter::ImportFailureReason::NoInline;
+ continue;
+ }
+
+ return Summary;
+ }
+ return nullptr;
}
namespace {
-using EdgeInfo =
- std::tuple<const GlobalValueSummary *, unsigned /* Threshold */>;
+using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */>;
} // anonymous namespace
-static bool shouldImportGlobal(const ValueInfo &VI,
- const GVSummaryMapTy &DefinedGVSummaries) {
- const auto &GVS = DefinedGVSummaries.find(VI.getGUID());
- if (GVS == DefinedGVSummaries.end())
- return true;
- // We should not skip import if the module contains a definition with
- // interposable linkage type. This is required for correctness in
- // the situation with two following conditions:
- // * the def with interposable linkage is non-prevailing,
- // * there is a prevailing def available for import and marked read-only.
- // In this case, the non-prevailing def will be converted to a declaration,
- // while the prevailing one becomes internal, thus no definitions will be
- // available for linking. In order to prevent undefined symbol link error,
- // the prevailing definition must be imported.
- // FIXME: Consider adding a check that the suitable prevailing definition
- // exists and marked read-only.
- if (VI.getSummaryList().size() > 1 &&
- GlobalValue::isInterposableLinkage(GVS->second->linkage()))
- return true;
-
- return false;
-}
+/// Import globals referenced by a function or other globals that are being
+/// imported, if importing such global is possible.
+class GlobalsImporter final {
+ const ModuleSummaryIndex &Index;
+ const GVSummaryMapTy &DefinedGVSummaries;
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ IsPrevailing;
+ FunctionImporter::ImportMapTy &ImportList;
+ StringMap<FunctionImporter::ExportSetTy> *const ExportLists;
+
+ bool shouldImportGlobal(const ValueInfo &VI) {
+ const auto &GVS = DefinedGVSummaries.find(VI.getGUID());
+ if (GVS == DefinedGVSummaries.end())
+ return true;
+ // We should not skip import if the module contains a non-prevailing
+ // definition with interposable linkage type. This is required for
+ // correctness in the situation where there is a prevailing def available
+ // for import and marked read-only. In this case, the non-prevailing def
+ // will be converted to a declaration, while the prevailing one becomes
+ // internal, thus no definitions will be available for linking. In order to
+ // prevent undefined symbol link error, the prevailing definition must be
+ // imported.
+ // FIXME: Consider adding a check that the suitable prevailing definition
+ // exists and marked read-only.
+ if (VI.getSummaryList().size() > 1 &&
+ GlobalValue::isInterposableLinkage(GVS->second->linkage()) &&
+ !IsPrevailing(VI.getGUID(), GVS->second))
+ return true;
-static void computeImportForReferencedGlobals(
- const GlobalValueSummary &Summary, const ModuleSummaryIndex &Index,
- const GVSummaryMapTy &DefinedGVSummaries,
- SmallVectorImpl<EdgeInfo> &Worklist,
- FunctionImporter::ImportMapTy &ImportList,
- StringMap<FunctionImporter::ExportSetTy> *ExportLists) {
- for (const auto &VI : Summary.refs()) {
- if (!shouldImportGlobal(VI, DefinedGVSummaries)) {
- LLVM_DEBUG(
- dbgs() << "Ref ignored! Target already in destination module.\n");
- continue;
- }
+ return false;
+ }
- LLVM_DEBUG(dbgs() << " ref -> " << VI << "\n");
-
- // If this is a local variable, make sure we import the copy
- // in the caller's module. The only time a local variable can
- // share an entry in the index is if there is a local with the same name
- // in another module that had the same source file name (in a different
- // directory), where each was compiled in their own directory so there
- // was not distinguishing path.
- auto LocalNotInModule = [&](const GlobalValueSummary *RefSummary) -> bool {
- return GlobalValue::isLocalLinkage(RefSummary->linkage()) &&
- RefSummary->modulePath() != Summary.modulePath();
- };
+ void
+ onImportingSummaryImpl(const GlobalValueSummary &Summary,
+ SmallVectorImpl<const GlobalVarSummary *> &Worklist) {
+ for (const auto &VI : Summary.refs()) {
+ if (!shouldImportGlobal(VI)) {
+ LLVM_DEBUG(
+ dbgs() << "Ref ignored! Target already in destination module.\n");
+ continue;
+ }
- for (const auto &RefSummary : VI.getSummaryList())
- if (isa<GlobalVarSummary>(RefSummary.get()) &&
- Index.canImportGlobalVar(RefSummary.get(), /* AnalyzeRefs */ true) &&
- !LocalNotInModule(RefSummary.get())) {
+ LLVM_DEBUG(dbgs() << " ref -> " << VI << "\n");
+
+ // If this is a local variable, make sure we import the copy
+ // in the caller's module. The only time a local variable can
+ // share an entry in the index is if there is a local with the same name
+ // in another module that had the same source file name (in a different
+ // directory), where each was compiled in their own directory so there
+      // was no distinguishing path.
+ auto LocalNotInModule =
+ [&](const GlobalValueSummary *RefSummary) -> bool {
+ return GlobalValue::isLocalLinkage(RefSummary->linkage()) &&
+ RefSummary->modulePath() != Summary.modulePath();
+ };
+
+ for (const auto &RefSummary : VI.getSummaryList()) {
+ const auto *GVS = dyn_cast<GlobalVarSummary>(RefSummary.get());
+ // Functions could be referenced by global vars - e.g. a vtable; but we
+ // don't currently imagine a reason those would be imported here, rather
+ // than as part of the logic deciding which functions to import (i.e.
+ // based on profile information). Should we decide to handle them here,
+ // we can refactor accordingly at that time.
+ if (!GVS || !Index.canImportGlobalVar(GVS, /* AnalyzeRefs */ true) ||
+ LocalNotInModule(GVS))
+ continue;
auto ILI = ImportList[RefSummary->modulePath()].insert(VI.getGUID());
// Only update stat and exports if we haven't already imported this
// variable.
if (!ILI.second)
break;
NumImportedGlobalVarsThinLink++;
- // Any references made by this variable will be marked exported later,
- // in ComputeCrossModuleImport, after import decisions are complete,
- // which is more efficient than adding them here.
+ // Any references made by this variable will be marked exported
+ // later, in ComputeCrossModuleImport, after import decisions are
+ // complete, which is more efficient than adding them here.
if (ExportLists)
(*ExportLists)[RefSummary->modulePath()].insert(VI);
// If variable is not writeonly we attempt to recursively analyze
// its references in order to import referenced constants.
- if (!Index.isWriteOnly(cast<GlobalVarSummary>(RefSummary.get())))
- Worklist.emplace_back(RefSummary.get(), 0);
+ if (!Index.isWriteOnly(GVS))
+ Worklist.emplace_back(GVS);
break;
}
+ }
}
-}
+
+public:
+ GlobalsImporter(
+ const ModuleSummaryIndex &Index, const GVSummaryMapTy &DefinedGVSummaries,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ IsPrevailing,
+ FunctionImporter::ImportMapTy &ImportList,
+ StringMap<FunctionImporter::ExportSetTy> *ExportLists)
+ : Index(Index), DefinedGVSummaries(DefinedGVSummaries),
+ IsPrevailing(IsPrevailing), ImportList(ImportList),
+ ExportLists(ExportLists) {}
+
+ void onImportingSummary(const GlobalValueSummary &Summary) {
+ SmallVector<const GlobalVarSummary *, 128> Worklist;
+ onImportingSummaryImpl(Summary, Worklist);
+ while (!Worklist.empty())
+ onImportingSummaryImpl(*Worklist.pop_back_val(), Worklist);
+ }
+};
static const char *
getFailureName(FunctionImporter::ImportFailureReason Reason) {
@@ -348,12 +399,13 @@ getFailureName(FunctionImporter::ImportFailureReason Reason) {
static void computeImportForFunction(
const FunctionSummary &Summary, const ModuleSummaryIndex &Index,
const unsigned Threshold, const GVSummaryMapTy &DefinedGVSummaries,
- SmallVectorImpl<EdgeInfo> &Worklist,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing,
+ SmallVectorImpl<EdgeInfo> &Worklist, GlobalsImporter &GVImporter,
FunctionImporter::ImportMapTy &ImportList,
StringMap<FunctionImporter::ExportSetTy> *ExportLists,
FunctionImporter::ImportThresholdsTy &ImportThresholds) {
- computeImportForReferencedGlobals(Summary, Index, DefinedGVSummaries,
- Worklist, ImportList, ExportLists);
+ GVImporter.onImportingSummary(Summary);
static int ImportCount = 0;
for (const auto &Edge : Summary.calls()) {
ValueInfo VI = Edge.first;
@@ -432,7 +484,7 @@ static void computeImportForFunction(
FunctionImporter::ImportFailureReason Reason;
CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold,
- Summary.modulePath(), Reason, VI.getGUID());
+ Summary.modulePath(), Reason);
if (!CalleeSummary) {
// Update with new larger threshold if this was a retry (otherwise
// we would have already inserted with NewThreshold above). Also
@@ -519,12 +571,17 @@ static void computeImportForFunction(
/// as well as the list of "exports", i.e. the list of symbols referenced from
/// another module (that may require promotion).
static void ComputeImportForModule(
- const GVSummaryMapTy &DefinedGVSummaries, const ModuleSummaryIndex &Index,
- StringRef ModName, FunctionImporter::ImportMapTy &ImportList,
+ const GVSummaryMapTy &DefinedGVSummaries,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing,
+ const ModuleSummaryIndex &Index, StringRef ModName,
+ FunctionImporter::ImportMapTy &ImportList,
StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) {
// Worklist contains the list of function imported in this module, for which
// we will analyse the callees and may import further down the callgraph.
SmallVector<EdgeInfo, 128> Worklist;
+ GlobalsImporter GVI(Index, DefinedGVSummaries, isPrevailing, ImportList,
+ ExportLists);
FunctionImporter::ImportThresholdsTy ImportThresholds;
// Populate the worklist with the import for the functions in the current
@@ -546,8 +603,8 @@ static void ComputeImportForModule(
continue;
LLVM_DEBUG(dbgs() << "Initialize import for " << VI << "\n");
computeImportForFunction(*FuncSummary, Index, ImportInstrLimit,
- DefinedGVSummaries, Worklist, ImportList,
- ExportLists, ImportThresholds);
+ DefinedGVSummaries, isPrevailing, Worklist, GVI,
+ ImportList, ExportLists, ImportThresholds);
}
// Process the newly imported functions and add callees to the worklist.
@@ -558,11 +615,8 @@ static void ComputeImportForModule(
if (auto *FS = dyn_cast<FunctionSummary>(Summary))
computeImportForFunction(*FS, Index, Threshold, DefinedGVSummaries,
- Worklist, ImportList, ExportLists,
- ImportThresholds);
- else
- computeImportForReferencedGlobals(*Summary, Index, DefinedGVSummaries,
- Worklist, ImportList, ExportLists);
+ isPrevailing, Worklist, GVI, ImportList,
+ ExportLists, ImportThresholds);
}
// Print stats about functions considered but rejected for importing
@@ -632,17 +686,23 @@ checkVariableImport(const ModuleSummaryIndex &Index,
// Checks that all GUIDs of read/writeonly vars we see in export lists
  // are also in the import lists. Otherwise we may face linker undefs,
// because readonly and writeonly vars are internalized in their
- // source modules.
- auto IsReadOrWriteOnlyVar = [&](StringRef ModulePath, const ValueInfo &VI) {
+ // source modules. The exception would be if it has a linkage type indicating
+ // that there may have been a copy existing in the importing module (e.g.
+ // linkonce_odr). In that case we cannot accurately do this checking.
+ auto IsReadOrWriteOnlyVarNeedingImporting = [&](StringRef ModulePath,
+ const ValueInfo &VI) {
auto *GVS = dyn_cast_or_null<GlobalVarSummary>(
Index.findSummaryInModule(VI, ModulePath));
- return GVS && (Index.isReadOnly(GVS) || Index.isWriteOnly(GVS));
+ return GVS && (Index.isReadOnly(GVS) || Index.isWriteOnly(GVS)) &&
+ !(GVS->linkage() == GlobalValue::AvailableExternallyLinkage ||
+ GVS->linkage() == GlobalValue::WeakODRLinkage ||
+ GVS->linkage() == GlobalValue::LinkOnceODRLinkage);
};
for (auto &ExportPerModule : ExportLists)
for (auto &VI : ExportPerModule.second)
if (!FlattenedImports.count(VI.getGUID()) &&
- IsReadOrWriteOnlyVar(ExportPerModule.first(), VI))
+ IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first(), VI))
return false;
return true;
@@ -653,6 +713,8 @@ checkVariableImport(const ModuleSummaryIndex &Index,
void llvm::ComputeCrossModuleImport(
const ModuleSummaryIndex &Index,
const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing,
StringMap<FunctionImporter::ImportMapTy> &ImportLists,
StringMap<FunctionImporter::ExportSetTy> &ExportLists) {
// For each module that has function defined, compute the import/export lists.
@@ -660,7 +722,7 @@ void llvm::ComputeCrossModuleImport(
auto &ImportList = ImportLists[DefinedGVSummaries.first()];
LLVM_DEBUG(dbgs() << "Computing import for Module '"
<< DefinedGVSummaries.first() << "'\n");
- ComputeImportForModule(DefinedGVSummaries.second, Index,
+ ComputeImportForModule(DefinedGVSummaries.second, isPrevailing, Index,
DefinedGVSummaries.first(), ImportList,
&ExportLists);
}
@@ -759,7 +821,10 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index,
/// Compute all the imports for the given module in the Index.
void llvm::ComputeCrossModuleImportForModule(
- StringRef ModulePath, const ModuleSummaryIndex &Index,
+ StringRef ModulePath,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing,
+ const ModuleSummaryIndex &Index,
FunctionImporter::ImportMapTy &ImportList) {
// Collect the list of functions this module defines.
// GUID -> Summary
@@ -768,7 +833,8 @@ void llvm::ComputeCrossModuleImportForModule(
// Compute the import list for this module.
LLVM_DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n");
- ComputeImportForModule(FunctionSummaryMap, Index, ModulePath, ImportList);
+ ComputeImportForModule(FunctionSummaryMap, isPrevailing, Index, ModulePath,
+ ImportList);
#ifndef NDEBUG
dumpImportListForModule(Index, ModulePath, ImportList);
@@ -1373,8 +1439,9 @@ Expected<bool> FunctionImporter::importFunctions(
if (Error Err = Mover.move(std::move(SrcModule),
GlobalsToImport.getArrayRef(), nullptr,
/*IsPerformingImport=*/true))
- report_fatal_error(Twine("Function Import: link error: ") +
- toString(std::move(Err)));
+ return createStringError(errc::invalid_argument,
+ Twine("Function Import: link error: ") +
+ toString(std::move(Err)));
ImportedCount += GlobalsToImport.size();
NumImportedModules++;
@@ -1394,7 +1461,9 @@ Expected<bool> FunctionImporter::importFunctions(
return ImportedCount;
}
-static bool doImportingForModule(Module &M) {
+static bool doImportingForModule(
+ Module &M, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing) {
if (SummaryFile.empty())
report_fatal_error("error: -function-import requires -summary-file\n");
Expected<std::unique_ptr<ModuleSummaryIndex>> IndexPtrOrErr =
@@ -1415,8 +1484,8 @@ static bool doImportingForModule(Module &M) {
ComputeCrossModuleImportForModuleFromIndex(M.getModuleIdentifier(), *Index,
ImportList);
else
- ComputeCrossModuleImportForModule(M.getModuleIdentifier(), *Index,
- ImportList);
+ ComputeCrossModuleImportForModule(M.getModuleIdentifier(), isPrevailing,
+ *Index, ImportList);
// Conservatively mark all internal values as promoted. This interface is
// only used when doing importing via the function importing pass. The pass
@@ -1434,7 +1503,7 @@ static bool doImportingForModule(Module &M) {
if (renameModuleForThinLTO(M, *Index, /*ClearDSOLocalOnDeclarations=*/false,
/*GlobalsToImport=*/nullptr)) {
errs() << "Error renaming module\n";
- return false;
+ return true;
}
// Perform the import now.
@@ -1449,15 +1518,22 @@ static bool doImportingForModule(Module &M) {
if (!Result) {
logAllUnhandledErrors(Result.takeError(), errs(),
"Error importing module: ");
- return false;
+ return true;
}
- return *Result;
+ return true;
}
PreservedAnalyses FunctionImportPass::run(Module &M,
ModuleAnalysisManager &AM) {
- if (!doImportingForModule(M))
+ // This is only used for testing the function import pass via opt, where we
+ // don't have prevailing information from the LTO context available, so just
+ // conservatively assume everything is prevailing (which is fine for the very
+ // limited use of prevailing checking in this pass).
+ auto isPrevailing = [](GlobalValue::GUID, const GlobalValueSummary *) {
+ return true;
+ };
+ if (!doImportingForModule(M, isPrevailing))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
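Callers of the import computation now have to supply the isPrevailing callback
threaded through above. A minimal sketch (not part of this change; the wrapper
name is invented) mirroring the conservative lambda used by FunctionImportPass:

#include "llvm/IR/ModuleSummaryIndex.h"
#include "llvm/Transforms/IPO/FunctionImport.h"

using namespace llvm;

// Compute the import list for one module while treating every summary as
// prevailing. This is only appropriate when no LTO prevailing information is
// available, e.g. the opt-based testing path described above.
static void computeImportsConservatively(StringRef ModulePath,
                                         const ModuleSummaryIndex &Index,
                                         FunctionImporter::ImportMapTy &Imports) {
  auto IsPrevailing = [](GlobalValue::GUID, const GlobalValueSummary *) {
    return true;
  };
  ComputeCrossModuleImportForModule(ModulePath, IsPrevailing, Index, Imports);
}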
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 4a7efb28e853..3d6c501e4596 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -48,11 +48,13 @@
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InlineCost.h"
-#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Cloning.h"
@@ -64,42 +66,324 @@ using namespace llvm;
#define DEBUG_TYPE "function-specialization"
-STATISTIC(NumFuncSpecialized, "Number of functions specialized");
+STATISTIC(NumSpecsCreated, "Number of specializations created");
-static cl::opt<bool> ForceFunctionSpecialization(
- "force-function-specialization", cl::init(false), cl::Hidden,
- cl::desc("Force function specialization for every call site with a "
- "constant argument"));
+static cl::opt<bool> ForceSpecialization(
+ "force-specialization", cl::init(false), cl::Hidden, cl::desc(
+ "Force function specialization for every call site with a constant "
+ "argument"));
-static cl::opt<unsigned> MaxClonesThreshold(
- "func-specialization-max-clones", cl::Hidden,
- cl::desc("The maximum number of clones allowed for a single function "
- "specialization"),
- cl::init(3));
+static cl::opt<unsigned> MaxClones(
+ "funcspec-max-clones", cl::init(3), cl::Hidden, cl::desc(
+ "The maximum number of clones allowed for a single function "
+ "specialization"));
-static cl::opt<unsigned> SmallFunctionThreshold(
- "func-specialization-size-threshold", cl::Hidden,
- cl::desc("Don't specialize functions that have less than this theshold "
- "number of instructions"),
- cl::init(100));
+static cl::opt<unsigned> MaxIncomingPhiValues(
+ "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
+ "The maximum number of incoming values a PHI node can have to be "
+ "considered during the specialization bonus estimation"));
-static cl::opt<unsigned>
- AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden,
- cl::desc("Average loop iteration count cost"),
- cl::init(10));
+static cl::opt<unsigned> MinFunctionSize(
+ "funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
+ "Don't specialize functions that have less than this number of "
+ "instructions"));
-static cl::opt<bool> SpecializeOnAddresses(
- "func-specialization-on-address", cl::init(false), cl::Hidden,
- cl::desc("Enable function specialization on the address of global values"));
+static cl::opt<bool> SpecializeOnAddress(
+ "funcspec-on-address", cl::init(false), cl::Hidden, cl::desc(
+ "Enable function specialization on the address of global values"));
// Disabled by default as it can significantly increase compilation times.
//
// https://llvm-compile-time-tracker.com
// https://github.com/nikic/llvm-compile-time-tracker
-static cl::opt<bool> EnableSpecializationForLiteralConstant(
- "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
- cl::desc("Enable specialization of functions that take a literal constant "
- "as an argument."));
+static cl::opt<bool> SpecializeLiteralConstant(
+ "funcspec-for-literal-constant", cl::init(false), cl::Hidden, cl::desc(
+ "Enable specialization of functions that take a literal constant as an "
+ "argument"));
+
+// Estimates the instruction cost of all the basic blocks in \p WorkList.
+// The successors of such blocks are added to the list as long as they are
+// executable and they have a unique predecessor. \p WorkList represents
+// the basic blocks of a specialization which become dead once we replace
+// instructions that are known to be constants. The aim here is to estimate
+// the combination of size and latency savings in comparison to the non
+// specialized version of the function.
+static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
+ DenseSet<BasicBlock *> &DeadBlocks,
+ ConstMap &KnownConstants, SCCPSolver &Solver,
+ BlockFrequencyInfo &BFI,
+ TargetTransformInfo &TTI) {
+ Cost Bonus = 0;
+
+ // Accumulate the instruction cost of each basic block weighted by frequency.
+ while (!WorkList.empty()) {
+ BasicBlock *BB = WorkList.pop_back_val();
+
+ uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() /
+ BFI.getEntryFreq();
+ if (!Weight)
+ continue;
+
+ // These blocks are considered dead as far as the InstCostVisitor is
+    // concerned. They haven't been proven dead yet by the Solver, but may
+    // become dead if we propagate the constant specialization arguments.
+ if (!DeadBlocks.insert(BB).second)
+ continue;
+
+ for (Instruction &I : *BB) {
+ // Disregard SSA copies.
+ if (auto *II = dyn_cast<IntrinsicInst>(&I))
+ if (II->getIntrinsicID() == Intrinsic::ssa_copy)
+ continue;
+ // If it's a known constant we have already accounted for it.
+ if (KnownConstants.contains(&I))
+ continue;
+
+ Bonus += Weight *
+ TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
+ << " after user " << I << "\n");
+ }
+
+ // Keep adding dead successors to the list as long as they are
+ // executable and they have a unique predecessor.
+ for (BasicBlock *SuccBB : successors(BB))
+ if (Solver.isBlockExecutable(SuccBB) &&
+ SuccBB->getUniquePredecessor() == BB)
+ WorkList.push_back(SuccBB);
+ }
+ return Bonus;
+}
+
+static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
+ if (auto *C = dyn_cast<Constant>(V))
+ return C;
+ if (auto It = KnownConstants.find(V); It != KnownConstants.end())
+ return It->second;
+ return nullptr;
+}
+
+Cost InstCostVisitor::getBonusFromPendingPHIs() {
+ Cost Bonus = 0;
+ while (!PendingPHIs.empty()) {
+ Instruction *Phi = PendingPHIs.pop_back_val();
+ Bonus += getUserBonus(Phi);
+ }
+ return Bonus;
+}
+
+Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
+ // Cache the iterator before visiting.
+ LastVisited = Use ? KnownConstants.insert({Use, C}).first
+ : KnownConstants.end();
+
+ if (auto *I = dyn_cast<SwitchInst>(User))
+ return estimateSwitchInst(*I);
+
+ if (auto *I = dyn_cast<BranchInst>(User))
+ return estimateBranchInst(*I);
+
+ C = visit(*User);
+ if (!C)
+ return 0;
+
+ KnownConstants.insert({User, C});
+
+ uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() /
+ BFI.getEntryFreq();
+ if (!Weight)
+ return 0;
+
+ Cost Bonus = Weight *
+ TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency);
+
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
+ << " for user " << *User << "\n");
+
+ for (auto *U : User->users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ if (UI != User && Solver.isBlockExecutable(UI->getParent()))
+ Bonus += getUserBonus(UI, User, C);
+
+ return Bonus;
+}
+
+Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ if (I.getCondition() != LastVisited->first)
+ return 0;
+
+ auto *C = dyn_cast<ConstantInt>(LastVisited->second);
+ if (!C)
+ return 0;
+
+ BasicBlock *Succ = I.findCaseValue(C)->getCaseSuccessor();
+ // Initialize the worklist with the dead basic blocks. These are the
+ // destination labels which are different from the one corresponding
+ // to \p C. They should be executable and have a unique predecessor.
+ SmallVector<BasicBlock *> WorkList;
+ for (const auto &Case : I.cases()) {
+ BasicBlock *BB = Case.getCaseSuccessor();
+ if (BB == Succ || !Solver.isBlockExecutable(BB) ||
+ BB->getUniquePredecessor() != I.getParent())
+ continue;
+ WorkList.push_back(BB);
+ }
+
+ return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
+ TTI);
+}
+
+Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ if (I.getCondition() != LastVisited->first)
+ return 0;
+
+ BasicBlock *Succ = I.getSuccessor(LastVisited->second->isOneValue());
+ // Initialize the worklist with the dead successor as long as
+ // it is executable and has a unique predecessor.
+ SmallVector<BasicBlock *> WorkList;
+ if (Solver.isBlockExecutable(Succ) &&
+ Succ->getUniquePredecessor() == I.getParent())
+ WorkList.push_back(Succ);
+
+ return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
+ TTI);
+}
+
+Constant *InstCostVisitor::visitPHINode(PHINode &I) {
+ if (I.getNumIncomingValues() > MaxIncomingPhiValues)
+ return nullptr;
+
+ bool Inserted = VisitedPHIs.insert(&I).second;
+ Constant *Const = nullptr;
+
+ for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
+ Value *V = I.getIncomingValue(Idx);
+ if (auto *Inst = dyn_cast<Instruction>(V))
+ if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
+ continue;
+ Constant *C = findConstantFor(V, KnownConstants);
+ if (!C) {
+ if (Inserted)
+ PendingPHIs.push_back(&I);
+ return nullptr;
+ }
+ if (!Const)
+ Const = C;
+ else if (C != Const)
+ return nullptr;
+ }
+ return Const;
+}
+
+Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
+ return LastVisited->second;
+ return nullptr;
+}
+
+Constant *InstCostVisitor::visitCallBase(CallBase &I) {
+ Function *F = I.getCalledFunction();
+ if (!F || !canConstantFoldCallTo(&I, F))
+ return nullptr;
+
+ SmallVector<Constant *, 8> Operands;
+ Operands.reserve(I.getNumOperands());
+
+ for (unsigned Idx = 0, E = I.getNumOperands() - 1; Idx != E; ++Idx) {
+ Value *V = I.getOperand(Idx);
+ Constant *C = findConstantFor(V, KnownConstants);
+ if (!C)
+ return nullptr;
+ Operands.push_back(C);
+ }
+
+ auto Ops = ArrayRef(Operands.begin(), Operands.end());
+ return ConstantFoldCall(&I, F, Ops);
+}
+
+Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ if (isa<ConstantPointerNull>(LastVisited->second))
+ return nullptr;
+ return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
+}
+
+Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
+ SmallVector<Constant *, 8> Operands;
+ Operands.reserve(I.getNumOperands());
+
+ for (unsigned Idx = 0, E = I.getNumOperands(); Idx != E; ++Idx) {
+ Value *V = I.getOperand(Idx);
+ Constant *C = findConstantFor(V, KnownConstants);
+ if (!C)
+ return nullptr;
+ Operands.push_back(C);
+ }
+
+ auto Ops = ArrayRef(Operands.begin(), Operands.end());
+ return ConstantFoldInstOperands(&I, Ops, DL);
+}
+
+Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ if (I.getCondition() != LastVisited->first)
+ return nullptr;
+
+ Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
+ : I.getTrueValue();
+ Constant *C = findConstantFor(V, KnownConstants);
+ return C;
+}
+
+Constant *InstCostVisitor::visitCastInst(CastInst &I) {
+ return ConstantFoldCastOperand(I.getOpcode(), LastVisited->second,
+ I.getType(), DL);
+}
+
+Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ bool Swap = I.getOperand(1) == LastVisited->first;
+ Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
+ Constant *Other = findConstantFor(V, KnownConstants);
+ if (!Other)
+ return nullptr;
+
+ Constant *Const = LastVisited->second;
+ return Swap ?
+ ConstantFoldCompareInstOperands(I.getPredicate(), Other, Const, DL)
+ : ConstantFoldCompareInstOperands(I.getPredicate(), Const, Other, DL);
+}
+
+Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL);
+}
+
+Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
+ bool Swap = I.getOperand(1) == LastVisited->first;
+ Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
+ Constant *Other = findConstantFor(V, KnownConstants);
+ if (!Other)
+ return nullptr;
+
+ Constant *Const = LastVisited->second;
+ return dyn_cast_or_null<Constant>(Swap ?
+ simplifyBinOp(I.getOpcode(), Other, Const, SimplifyQuery(DL))
+ : simplifyBinOp(I.getOpcode(), Const, Other, SimplifyQuery(DL)));
+}
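+// Illustrative sketch (not part of this change): the weighting that getUserBonus
+// and estimateBasicBlocks above apply to every instruction they account for.
+// It assumes the Cost alias and the headers already used by InstCostVisitor.
+static Cost scaledInstructionCost(Instruction *I, BlockFrequencyInfo &BFI,
+                                  TargetTransformInfo &TTI) {
+  // Scale the size-and-latency cost by how often the block runs relative to
+  // the entry block; blocks that are never expected to run contribute nothing.
+  uint64_t Weight =
+      BFI.getBlockFreq(I->getParent()).getFrequency() / BFI.getEntryFreq();
+  if (!Weight)
+    return 0;
+  return Weight *
+         TTI.getInstructionCost(I, TargetTransformInfo::TCK_SizeAndLatency);
+}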
Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca,
CallInst *Call) {
@@ -125,6 +409,10 @@ Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca,
// Bail if there is any other unknown usage.
return nullptr;
}
+
+ if (!StoreValue)
+ return nullptr;
+
return getCandidateConstant(StoreValue);
}
@@ -165,49 +453,37 @@ Constant *FunctionSpecializer::getConstantStackValue(CallInst *Call,
// ret void
// }
//
-void FunctionSpecializer::promoteConstantStackValues() {
- // Iterate over the argument tracked functions see if there
- // are any new constant values for the call instruction via
- // stack variables.
- for (Function &F : M) {
- if (!Solver.isArgumentTrackedFunction(&F))
+// See if there are any new constant values for the callers of \p F via
+// stack variables and promote them to global variables.
+void FunctionSpecializer::promoteConstantStackValues(Function *F) {
+ for (User *U : F->users()) {
+
+ auto *Call = dyn_cast<CallInst>(U);
+ if (!Call)
continue;
- for (auto *User : F.users()) {
+ if (!Solver.isBlockExecutable(Call->getParent()))
+ continue;
- auto *Call = dyn_cast<CallInst>(User);
- if (!Call)
- continue;
+ for (const Use &U : Call->args()) {
+ unsigned Idx = Call->getArgOperandNo(&U);
+ Value *ArgOp = Call->getArgOperand(Idx);
+ Type *ArgOpType = ArgOp->getType();
- if (!Solver.isBlockExecutable(Call->getParent()))
+ if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
continue;
- bool Changed = false;
- for (const Use &U : Call->args()) {
- unsigned Idx = Call->getArgOperandNo(&U);
- Value *ArgOp = Call->getArgOperand(Idx);
- Type *ArgOpType = ArgOp->getType();
-
- if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
- continue;
-
- auto *ConstVal = getConstantStackValue(Call, ArgOp);
- if (!ConstVal)
- continue;
-
- Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
- GlobalValue::InternalLinkage, ConstVal,
- "funcspec.arg");
- if (ArgOpType != ConstVal->getType())
- GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
+ auto *ConstVal = getConstantStackValue(Call, ArgOp);
+ if (!ConstVal)
+ continue;
- Call->setArgOperand(Idx, GV);
- Changed = true;
- }
+ Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
+ GlobalValue::InternalLinkage, ConstVal,
+ "funcspec.arg");
+ if (ArgOpType != ConstVal->getType())
+ GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
- // Add the changed CallInst to Solver Worklist
- if (Changed)
- Solver.visitCall(*Call);
+ Call->setArgOperand(Idx, GV);
}
}
}
@@ -230,7 +506,7 @@ static void removeSSACopy(Function &F) {
/// Remove any ssa_copy intrinsics that may have been introduced.
void FunctionSpecializer::cleanUpSSA() {
- for (Function *F : SpecializedFuncs)
+ for (Function *F : Specializations)
removeSSACopy(*F);
}
@@ -249,6 +525,16 @@ template <> struct llvm::DenseMapInfo<SpecSig> {
}
};
+FunctionSpecializer::~FunctionSpecializer() {
+ LLVM_DEBUG(
+ if (NumSpecsCreated > 0)
+ dbgs() << "FnSpecialization: Created " << NumSpecsCreated
+ << " specializations in module " << M.getName() << "\n");
+ // Eliminate dead code.
+ removeDeadFunctions();
+ cleanUpSSA();
+}
+
/// Attempt to specialize functions in the module to enable constant
/// propagation across function boundaries.
///
@@ -262,17 +548,37 @@ bool FunctionSpecializer::run() {
if (!isCandidateFunction(&F))
continue;
- auto Cost = getSpecializationCost(&F);
- if (!Cost.isValid()) {
- LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialization cost for "
- << F.getName() << "\n");
- continue;
+ auto [It, Inserted] = FunctionMetrics.try_emplace(&F);
+ CodeMetrics &Metrics = It->second;
+ // Analyze the function.
+ if (Inserted) {
+ SmallPtrSet<const Value *, 32> EphValues;
+ CodeMetrics::collectEphemeralValues(&F, &GetAC(F), EphValues);
+ for (BasicBlock &BB : F)
+ Metrics.analyzeBasicBlock(&BB, GetTTI(F), EphValues);
}
+ // If the code metrics reveal that we shouldn't duplicate the function,
+ // or if the code size implies that this function is easy to get inlined,
+ // then we shouldn't specialize it.
+ if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() ||
+ (!ForceSpecialization && !F.hasFnAttribute(Attribute::NoInline) &&
+ Metrics.NumInsts < MinFunctionSize))
+ continue;
+
+ // TODO: For now only consider recursive functions when running multiple
+ // times. This should change if specialization on literal constants gets
+ // enabled.
+ if (!Inserted && !Metrics.isRecursive && !SpecializeLiteralConstant)
+ continue;
+
LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
- << F.getName() << " is " << Cost << "\n");
+ << F.getName() << " is " << Metrics.NumInsts << "\n");
+
+ if (Inserted && Metrics.isRecursive)
+ promoteConstantStackValues(&F);
- if (!findSpecializations(&F, Cost, AllSpecs, SM)) {
+ if (!findSpecializations(&F, Metrics.NumInsts, AllSpecs, SM)) {
LLVM_DEBUG(
dbgs() << "FnSpecialization: No possible specializations found for "
<< F.getName() << "\n");
@@ -292,11 +598,11 @@ bool FunctionSpecializer::run() {
// Choose the most profitable specialisations, which fit in the module
// specialization budget, which is derived from maximum number of
// specializations per specialization candidate function.
- auto CompareGain = [&AllSpecs](unsigned I, unsigned J) {
- return AllSpecs[I].Gain > AllSpecs[J].Gain;
+ auto CompareScore = [&AllSpecs](unsigned I, unsigned J) {
+ return AllSpecs[I].Score > AllSpecs[J].Score;
};
const unsigned NSpecs =
- std::min(NumCandidates * MaxClonesThreshold, unsigned(AllSpecs.size()));
+ std::min(NumCandidates * MaxClones, unsigned(AllSpecs.size()));
SmallVector<unsigned> BestSpecs(NSpecs + 1);
std::iota(BestSpecs.begin(), BestSpecs.begin() + NSpecs, 0);
if (AllSpecs.size() > NSpecs) {
@@ -305,11 +611,11 @@ bool FunctionSpecializer::run() {
<< "FnSpecialization: Specializing the "
<< NSpecs
<< " most profitable candidates.\n");
- std::make_heap(BestSpecs.begin(), BestSpecs.begin() + NSpecs, CompareGain);
+ std::make_heap(BestSpecs.begin(), BestSpecs.begin() + NSpecs, CompareScore);
for (unsigned I = NSpecs, N = AllSpecs.size(); I < N; ++I) {
BestSpecs[NSpecs] = I;
- std::push_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain);
- std::pop_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain);
+ std::push_heap(BestSpecs.begin(), BestSpecs.end(), CompareScore);
+ std::pop_heap(BestSpecs.begin(), BestSpecs.end(), CompareScore);
}
}
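The loop above is the standard bounded top-N selection idiom: seed a heap with the first NSpecs indices, ordered so the lowest-scoring kept index sits at the front, then for each remaining candidate push it in and immediately pop the worst of the N+1. A minimal standalone C++ sketch of the same idiom, with made-up names and scores (not taken from this patch):

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> Score = {7, 42, 3, 19, 25, 4}; // one score per candidate
  const unsigned N = 3;                           // specialization budget

  // With this comparator the heap front is the lowest-scoring kept index,
  // i.e. the cheapest element to evict.
  auto Cmp = [&Score](unsigned I, unsigned J) { return Score[I] > Score[J]; };

  std::vector<unsigned> Best(N + 1);
  std::iota(Best.begin(), Best.begin() + N, 0);   // seed with candidates 0..N-1
  std::make_heap(Best.begin(), Best.begin() + N, Cmp);
  for (unsigned I = N; I < Score.size(); ++I) {
    Best[N] = I;                                  // tentatively add candidate I
    std::push_heap(Best.begin(), Best.end(), Cmp);
    std::pop_heap(Best.begin(), Best.end(), Cmp); // drop the worst of the N+1
  }
  for (unsigned I = 0; I < N; ++I)
    std::printf("kept candidate %u (score %d)\n", Best[I], Score[Best[I]]);
  return 0;
}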
@@ -317,7 +623,7 @@ bool FunctionSpecializer::run() {
for (unsigned I = 0; I < NSpecs; ++I) {
const Spec &S = AllSpecs[BestSpecs[I]];
dbgs() << "FnSpecialization: Function " << S.F->getName()
- << " , gain " << S.Gain << "\n";
+ << " , score " << S.Score << "\n";
for (const ArgInfo &Arg : S.Sig.Args)
dbgs() << "FnSpecialization: FormalArg = "
<< Arg.Formal->getNameOrAsOperand()
@@ -353,12 +659,37 @@ bool FunctionSpecializer::run() {
updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End);
}
- promoteConstantStackValues();
- LLVM_DEBUG(if (NbFunctionsSpecialized) dbgs()
- << "FnSpecialization: Specialized " << NbFunctionsSpecialized
- << " functions in module " << M.getName() << "\n");
+ for (Function *F : Clones) {
+ if (F->getReturnType()->isVoidTy())
+ continue;
+ if (F->getReturnType()->isStructTy()) {
+ auto *STy = cast<StructType>(F->getReturnType());
+ if (!Solver.isStructLatticeConstant(F, STy))
+ continue;
+ } else {
+ auto It = Solver.getTrackedRetVals().find(F);
+ assert(It != Solver.getTrackedRetVals().end() &&
+ "Return value ought to be tracked");
+ if (SCCPSolver::isOverdefined(It->second))
+ continue;
+ }
+ for (User *U : F->users()) {
+ if (auto *CS = dyn_cast<CallBase>(U)) {
+ // The user instruction does not call our function.
+ if (CS->getCalledFunction() != F)
+ continue;
+ Solver.resetLatticeValueFor(CS);
+ }
+ }
+ }
+
+ // Rerun the solver to notify the users of the modified callsites.
+ Solver.solveWhileResolvedUndefs();
+
+ for (Function *F : OriginalFuncs)
+ if (FunctionMetrics[F].isRecursive)
+ promoteConstantStackValues(F);
- NumFuncSpecialized += NbFunctionsSpecialized;
return true;
}
@@ -373,24 +704,6 @@ void FunctionSpecializer::removeDeadFunctions() {
FullySpecialized.clear();
}
-// Compute the code metrics for function \p F.
-CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) {
- auto I = FunctionMetrics.insert({F, CodeMetrics()});
- CodeMetrics &Metrics = I.first->second;
- if (I.second) {
- // The code metrics were not cached.
- SmallPtrSet<const Value *, 32> EphValues;
- CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
- for (BasicBlock &BB : *F)
- Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
-
- LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
- << F->getName() << " is " << Metrics.NumInsts
- << " instructions\n");
- }
- return Metrics;
-}
-
/// Clone the function \p F and remove the ssa_copy intrinsics added by
/// the SCCPSolver in the cloned version.
static Function *cloneCandidateFunction(Function *F) {
@@ -400,13 +713,13 @@ static Function *cloneCandidateFunction(Function *F) {
return Clone;
}
-bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost,
+bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
SmallVectorImpl<Spec> &AllSpecs,
SpecMap &SM) {
// A mapping from a specialisation signature to the index of the respective
// entry in the all specialisation array. Used to ensure uniqueness of
// specialisations.
- DenseMap<SpecSig, unsigned> UM;
+ DenseMap<SpecSig, unsigned> UniqueSpecs;
// Get a list of interesting arguments.
SmallVector<Argument *> Args;
@@ -417,7 +730,6 @@ bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost,
if (Args.empty())
return false;
- bool Found = false;
for (User *U : F->users()) {
if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
continue;
@@ -454,7 +766,7 @@ bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost,
continue;
// Check if we have encountered the same specialisation already.
- if (auto It = UM.find(S); It != UM.end()) {
+ if (auto It = UniqueSpecs.find(S); It != UniqueSpecs.end()) {
// Existing specialisation. Add the call to the list to rewrite, unless
// it's a recursive call. A specialisation, generated because of a
// recursive call may end up as not the best specialisation for all
@@ -467,42 +779,42 @@ bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost,
AllSpecs[Index].CallSites.push_back(&CS);
} else {
// Calculate the specialisation gain.
- InstructionCost Gain = 0 - Cost;
+ Cost Score = 0;
+ InstCostVisitor Visitor = getInstCostVisitorFor(F);
for (ArgInfo &A : S.Args)
- Gain +=
- getSpecializationBonus(A.Formal, A.Actual, Solver.getLoopInfo(*F));
+ Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
+ Score += Visitor.getBonusFromPendingPHIs();
+
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
+ << Score << "\n");
// Discard unprofitable specialisations.
- if (!ForceFunctionSpecialization && Gain <= 0)
+ if (!ForceSpecialization && Score <= SpecCost)
continue;
// Create a new specialisation entry.
- auto &Spec = AllSpecs.emplace_back(F, S, Gain);
+ auto &Spec = AllSpecs.emplace_back(F, S, Score);
if (CS.getFunction() != F)
Spec.CallSites.push_back(&CS);
const unsigned Index = AllSpecs.size() - 1;
- UM[S] = Index;
+ UniqueSpecs[S] = Index;
if (auto [It, Inserted] = SM.try_emplace(F, Index, Index + 1); !Inserted)
It->second.second = Index + 1;
- Found = true;
}
}
- return Found;
+ return !UniqueSpecs.empty();
}
bool FunctionSpecializer::isCandidateFunction(Function *F) {
- if (F->isDeclaration())
+ if (F->isDeclaration() || F->arg_empty())
return false;
if (F->hasFnAttribute(Attribute::NoDuplicate))
return false;
- if (!Solver.isArgumentTrackedFunction(F))
- return false;
-
// Do not specialize the cloned function again.
- if (SpecializedFuncs.contains(F))
+ if (Specializations.contains(F))
return false;
// If we're optimizing the function for size, we shouldn't specialize it.
@@ -524,86 +836,50 @@ bool FunctionSpecializer::isCandidateFunction(Function *F) {
return true;
}
-Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) {
+Function *FunctionSpecializer::createSpecialization(Function *F,
+ const SpecSig &S) {
Function *Clone = cloneCandidateFunction(F);
+ // The original function does not necessarily have internal linkage, but the
+ // clone must.
+ Clone->setLinkage(GlobalValue::InternalLinkage);
+
// Initialize the lattice state of the arguments of the function clone,
// marking the argument on which we specialized the function constant
// with the given value.
- Solver.markArgInFuncSpecialization(Clone, S.Args);
-
- Solver.addArgumentTrackedFunction(Clone);
+ Solver.setLatticeValueForSpecializationArguments(Clone, S.Args);
Solver.markBlockExecutable(&Clone->front());
+ Solver.addArgumentTrackedFunction(Clone);
+ Solver.addTrackedFunction(Clone);
// Mark all the specialized functions
- SpecializedFuncs.insert(Clone);
- NbFunctionsSpecialized++;
+ Specializations.insert(Clone);
+ ++NumSpecsCreated;
return Clone;
}
-/// Compute and return the cost of specializing function \p F.
-InstructionCost FunctionSpecializer::getSpecializationCost(Function *F) {
- CodeMetrics &Metrics = analyzeFunction(F);
- // If the code metrics reveal that we shouldn't duplicate the function, we
- // shouldn't specialize it. Set the specialization cost to Invalid.
- // Or if the lines of codes implies that this function is easy to get
- // inlined so that we shouldn't specialize it.
- if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() ||
- (!ForceFunctionSpecialization &&
- !F->hasFnAttribute(Attribute::NoInline) &&
- Metrics.NumInsts < SmallFunctionThreshold))
- return InstructionCost::getInvalid();
-
- // Otherwise, set the specialization cost to be the cost of all the
- // instructions in the function.
- return Metrics.NumInsts * InlineConstants::getInstrCost();
-}
-
-static InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
- const LoopInfo &LI) {
- auto *I = dyn_cast_or_null<Instruction>(U);
- // If not an instruction we do not know how to evaluate.
- // Keep minimum possible cost for now so that it doesnt affect
- // specialization.
- if (!I)
- return std::numeric_limits<unsigned>::min();
-
- InstructionCost Cost =
- TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency);
-
- // Increase the cost if it is inside the loop.
- unsigned LoopDepth = LI.getLoopDepth(I->getParent());
- Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth);
-
- // Traverse recursively if there are more uses.
- // TODO: Any other instructions to be added here?
- if (I->mayReadFromMemory() || I->isCast())
- for (auto *User : I->users())
- Cost += getUserBonus(User, TTI, LI);
-
- return Cost;
-}
-
/// Compute a bonus for replacing argument \p A with constant \p C.
-InstructionCost
-FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
- const LoopInfo &LI) {
- Function *F = A->getParent();
- auto &TTI = (GetTTI)(*F);
+Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
+ InstCostVisitor &Visitor) {
LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
<< C->getNameOrAsOperand() << "\n");
- InstructionCost TotalCost = 0;
- for (auto *U : A->users()) {
- TotalCost += getUserBonus(U, TTI, LI);
- LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
- TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
- }
+ Cost TotalCost = 0;
+ for (auto *U : A->users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ if (Solver.isBlockExecutable(UI->getParent()))
+ TotalCost += Visitor.getUserBonus(UI, A, C);
+
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus "
+ << TotalCost << " for argument " << *A << "\n");
// The below heuristic is only concerned with exposing inlining
// opportunities via indirect call promotion. If the argument is not a
// (potentially casted) function pointer, give up.
+ //
+ // TODO: Perhaps we should consider checking such inlining opportunities
+ // while traversing the users of the specialization arguments?
Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts());
if (!CalledFunction)
return TotalCost;
@@ -661,16 +937,9 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
if (A->user_empty())
return false;
- // For now, don't attempt to specialize functions based on the values of
- // composite types.
- Type *ArgTy = A->getType();
- if (!ArgTy->isSingleValueType())
- return false;
-
- // Specialization of integer and floating point types needs to be explicitly
- // enabled.
- if (!EnableSpecializationForLiteralConstant &&
- (ArgTy->isIntegerTy() || ArgTy->isFloatingPointTy()))
+ Type *Ty = A->getType();
+ if (!Ty->isPointerTy() && (!SpecializeLiteralConstant ||
+ (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isStructTy())))
return false;
// SCCP solver does not record an argument that will be constructed on
@@ -678,54 +947,46 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
if (A->hasByValAttr() && !A->getParent()->onlyReadsMemory())
return false;
+ // For non-argument-tracked functions every argument is overdefined.
+ if (!Solver.isArgumentTrackedFunction(A->getParent()))
+ return true;
+
 // Check the lattice value and decide if we should attempt to specialize
 // based on this argument. No point in specialization if the lattice value
 // is already a constant.
- const ValueLatticeElement &LV = Solver.getLatticeValueFor(A);
- if (LV.isUnknownOrUndef() || LV.isConstant() ||
- (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) {
- LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter "
- << A->getNameOrAsOperand() << " is already constant\n");
- return false;
- }
-
- LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter "
- << A->getNameOrAsOperand() << "\n");
-
- return true;
+ bool IsOverdefined = Ty->isStructTy()
+ ? any_of(Solver.getStructLatticeValueFor(A), SCCPSolver::isOverdefined)
+ : SCCPSolver::isOverdefined(Solver.getLatticeValueFor(A));
+
+ LLVM_DEBUG(
+ if (IsOverdefined)
+ dbgs() << "FnSpecialization: Found interesting parameter "
+ << A->getNameOrAsOperand() << "\n";
+ else
+ dbgs() << "FnSpecialization: Nothing to do, parameter "
+ << A->getNameOrAsOperand() << " is already constant\n";
+ );
+ return IsOverdefined;
}
-/// Check if the valuy \p V (an actual argument) is a constant or can only
+/// Check if the value \p V (an actual argument) is a constant or can only
/// have a constant value. Return that constant.
Constant *FunctionSpecializer::getCandidateConstant(Value *V) {
if (isa<PoisonValue>(V))
return nullptr;
- // TrackValueOfGlobalVariable only tracks scalar global variables.
- if (auto *GV = dyn_cast<GlobalVariable>(V)) {
- // Check if we want to specialize on the address of non-constant
- // global values.
- if (!GV->isConstant() && !SpecializeOnAddresses)
- return nullptr;
-
- if (!GV->getValueType()->isSingleValueType())
- return nullptr;
- }
-
// Select for possible specialisation values that are constants or
// are deduced to be constants or constant ranges with a single element.
Constant *C = dyn_cast<Constant>(V);
- if (!C) {
- const ValueLatticeElement &LV = Solver.getLatticeValueFor(V);
- if (LV.isConstant())
- C = LV.getConstant();
- else if (LV.isConstantRange() && LV.getConstantRange().isSingleElement()) {
- assert(V->getType()->isIntegerTy() && "Non-integral constant range");
- C = Constant::getIntegerValue(V->getType(),
- *LV.getConstantRange().getSingleElement());
- } else
+ if (!C)
+ C = Solver.getConstantOrNull(V);
+
+ // Don't specialize on (anything derived from) the address of a non-constant
+ // global variable, unless explicitly enabled.
+ if (C && C->getType()->isPointerTy() && !C->isNullValue())
+ if (auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(C));
+ GV && !(GV->isConstant() || SpecializeOnAddress))
return nullptr;
- }
return C;
}
@@ -747,7 +1008,7 @@ void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin,
// Find the best matching specialisation.
const Spec *BestSpec = nullptr;
for (const Spec &S : make_range(Begin, End)) {
- if (!S.Clone || (BestSpec && S.Gain <= BestSpec->Gain))
+ if (!S.Clone || (BestSpec && S.Score <= BestSpec->Score))
continue;
if (any_of(S.Sig.Args, [CS, this](const ArgInfo &Arg) {
@@ -772,7 +1033,7 @@ void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin,
// If the function has been completely specialized, the original function
// is no longer needed. Mark it unreachable.
- if (NCallsLeft == 0) {
+ if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F)) {
Solver.markFunctionUnreachable(F);
FullySpecialized.insert(F);
}
diff --git a/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/llvm/lib/Transforms/IPO/GlobalDCE.cpp
index 2f2bb174a8c8..e36d524d7667 100644
--- a/llvm/lib/Transforms/IPO/GlobalDCE.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalDCE.cpp
@@ -21,8 +21,6 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/CtorUtils.h"
@@ -42,47 +40,6 @@ STATISTIC(NumIFuncs, "Number of indirect functions removed");
STATISTIC(NumVariables, "Number of global variables removed");
STATISTIC(NumVFuncs, "Number of virtual functions removed");
-namespace {
- class GlobalDCELegacyPass : public ModulePass {
- public:
- static char ID; // Pass identification, replacement for typeid
- GlobalDCELegacyPass() : ModulePass(ID) {
- initializeGlobalDCELegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- // run - Do the GlobalDCE pass on the specified module, optionally updating
- // the specified callgraph to reflect the changes.
- //
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- // We need a minimally functional dummy module analysis manager. It needs
- // to at least know about the possibility of proxying a function analysis
- // manager.
- FunctionAnalysisManager DummyFAM;
- ModuleAnalysisManager DummyMAM;
- DummyMAM.registerPass(
- [&] { return FunctionAnalysisManagerModuleProxy(DummyFAM); });
-
- auto PA = Impl.run(M, DummyMAM);
- return !PA.areAllPreserved();
- }
-
- private:
- GlobalDCEPass Impl;
- };
-}
-
-char GlobalDCELegacyPass::ID = 0;
-INITIALIZE_PASS(GlobalDCELegacyPass, "globaldce",
- "Dead Global Elimination", false, false)
-
-// Public interface to the GlobalDCEPass.
-ModulePass *llvm::createGlobalDCEPass() {
- return new GlobalDCELegacyPass();
-}
-
/// Returns true if F is effectively empty.
static bool isEmptyFunction(Function *F) {
// Skip external functions.
@@ -163,12 +120,6 @@ void GlobalDCEPass::ScanVTables(Module &M) {
SmallVector<MDNode *, 2> Types;
LLVM_DEBUG(dbgs() << "Building type info -> vtable map\n");
- auto *LTOPostLinkMD =
- cast_or_null<ConstantAsMetadata>(M.getModuleFlag("LTOPostLink"));
- bool LTOPostLink =
- LTOPostLinkMD &&
- (cast<ConstantInt>(LTOPostLinkMD->getValue())->getZExtValue() != 0);
-
for (GlobalVariable &GV : M.globals()) {
Types.clear();
GV.getMetadata(LLVMContext::MD_type, Types);
@@ -195,7 +146,7 @@ void GlobalDCEPass::ScanVTables(Module &M) {
if (auto GO = dyn_cast<GlobalObject>(&GV)) {
GlobalObject::VCallVisibility TypeVis = GO->getVCallVisibility();
if (TypeVis == GlobalObject::VCallVisibilityTranslationUnit ||
- (LTOPostLink &&
+ (InLTOPostLink &&
TypeVis == GlobalObject::VCallVisibilityLinkageUnit)) {
LLVM_DEBUG(dbgs() << GV.getName() << " is safe for VFE\n");
VFESafeVTables.insert(&GV);
@@ -236,29 +187,36 @@ void GlobalDCEPass::ScanTypeCheckedLoadIntrinsics(Module &M) {
LLVM_DEBUG(dbgs() << "Scanning type.checked.load intrinsics\n");
Function *TypeCheckedLoadFunc =
M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
-
- if (!TypeCheckedLoadFunc)
- return;
-
- for (auto *U : TypeCheckedLoadFunc->users()) {
- auto CI = dyn_cast<CallInst>(U);
- if (!CI)
- continue;
-
- auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
- Value *TypeIdValue = CI->getArgOperand(2);
- auto *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
-
- if (Offset) {
- ScanVTableLoad(CI->getFunction(), TypeId, Offset->getZExtValue());
- } else {
- // type.checked.load with a non-constant offset, so assume every entry in
- // every matching vtable is used.
- for (const auto &VTableInfo : TypeIdMap[TypeId]) {
- VFESafeVTables.erase(VTableInfo.first);
+ Function *TypeCheckedLoadRelativeFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative));
+
+ auto scan = [&](Function *CheckedLoadFunc) {
+ if (!CheckedLoadFunc)
+ return;
+
+ for (auto *U : CheckedLoadFunc->users()) {
+ auto CI = dyn_cast<CallInst>(U);
+ if (!CI)
+ continue;
+
+ auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+ Value *TypeIdValue = CI->getArgOperand(2);
+ auto *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
+
+ if (Offset) {
+ ScanVTableLoad(CI->getFunction(), TypeId, Offset->getZExtValue());
+ } else {
+ // type.checked.load with a non-constant offset, so assume every entry
+ // in every matching vtable is used.
+ for (const auto &VTableInfo : TypeIdMap[TypeId]) {
+ VFESafeVTables.erase(VTableInfo.first);
+ }
}
}
- }
+ };
+
+ scan(TypeCheckedLoadFunc);
+ scan(TypeCheckedLoadRelativeFunc);
}
void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) {
@@ -271,7 +229,7 @@ void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) {
// Don't attempt VFE in that case.
auto *Val = mdconst::dyn_extract_or_null<ConstantInt>(
M.getModuleFlag("Virtual Function Elim"));
- if (!Val || Val->getZExtValue() == 0)
+ if (!Val || Val->isZero())
return;
ScanVTables(M);
@@ -458,3 +416,11 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
+
+void GlobalDCEPass::printPipeline(
+ raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
+ static_cast<PassInfoMixin<GlobalDCEPass> *>(this)->printPipeline(
+ OS, MapClassName2PassName);
+ if (InLTOPostLink)
+ OS << "<vfe-linkage-unit-visibility>";
+}
diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
index 0317a8bcb6bc..1ccc523ead8a 100644
--- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
@@ -53,8 +53,6 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
@@ -206,8 +204,10 @@ CleanupPointerRootUsers(GlobalVariable *GV,
// chain of computation and the store to the global in Dead[n].second.
SmallVector<std::pair<Instruction *, Instruction *>, 32> Dead;
+ SmallVector<User *> Worklist(GV->users());
// Constants can't be pointers to dynamically allocated memory.
- for (User *U : llvm::make_early_inc_range(GV->users())) {
+ while (!Worklist.empty()) {
+ User *U = Worklist.pop_back_val();
if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
Value *V = SI->getValueOperand();
if (isa<Constant>(V)) {
@@ -235,18 +235,8 @@ CleanupPointerRootUsers(GlobalVariable *GV,
Dead.push_back(std::make_pair(I, MTI));
}
} else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
- if (CE->use_empty()) {
- CE->destroyConstant();
- Changed = true;
- }
- } else if (Constant *C = dyn_cast<Constant>(U)) {
- if (isSafeToDestroyConstant(C)) {
- C->destroyConstant();
- // This could have invalidated UI, start over from scratch.
- Dead.clear();
- CleanupPointerRootUsers(GV, GetTLI);
- return true;
- }
+ if (isa<GEPOperator>(CE))
+ append_range(Worklist, CE->users());
}
}
@@ -268,6 +258,7 @@ CleanupPointerRootUsers(GlobalVariable *GV,
}
}
+ GV->removeDeadConstantUsers();
return Changed;
}
@@ -335,10 +326,19 @@ static bool CleanupConstantGlobalUsers(GlobalVariable *GV,
return Changed;
}
+/// Part of the global at a specific offset, which is only accessed through
+/// loads and stores with the given type.
+struct GlobalPart {
+ Type *Ty;
+ Constant *Initializer = nullptr;
+ bool IsLoaded = false;
+ bool IsStored = false;
+};
+
/// Look at all uses of the global and determine which (offset, type) pairs it
/// can be split into.
-static bool collectSRATypes(DenseMap<uint64_t, Type *> &Types, GlobalValue *GV,
- const DataLayout &DL) {
+static bool collectSRATypes(DenseMap<uint64_t, GlobalPart> &Parts,
+ GlobalVariable *GV, const DataLayout &DL) {
SmallVector<Use *, 16> Worklist;
SmallPtrSet<Use *, 16> Visited;
auto AppendUses = [&](Value *V) {
@@ -373,14 +373,41 @@ static bool collectSRATypes(DenseMap<uint64_t, Type *> &Types, GlobalValue *GV,
// TODO: We currently require that all accesses at a given offset must
// use the same type. This could be relaxed.
Type *Ty = getLoadStoreType(V);
- auto It = Types.try_emplace(Offset.getZExtValue(), Ty).first;
- if (Ty != It->second)
+ const auto &[It, Inserted] =
+ Parts.try_emplace(Offset.getZExtValue(), GlobalPart{Ty});
+ if (Ty != It->second.Ty)
return false;
+ if (Inserted) {
+ It->second.Initializer =
+ ConstantFoldLoadFromConst(GV->getInitializer(), Ty, Offset, DL);
+ if (!It->second.Initializer) {
+ LLVM_DEBUG(dbgs() << "Global SRA: Failed to evaluate initializer of "
+ << *GV << " with type " << *Ty << " at offset "
+ << Offset.getZExtValue());
+ return false;
+ }
+ }
+
// Scalable types not currently supported.
if (isa<ScalableVectorType>(Ty))
return false;
+ auto IsStored = [](Value *V, Constant *Initializer) {
+ auto *SI = dyn_cast<StoreInst>(V);
+ if (!SI)
+ return false;
+
+ Constant *StoredConst = dyn_cast<Constant>(SI->getOperand(0));
+ if (!StoredConst)
+ return true;
+
+ // Don't consider stores that only write the initializer value.
+ return Initializer != StoredConst;
+ };
+
+ It->second.IsLoaded |= isa<LoadInst>(V);
+ It->second.IsStored |= IsStored(V, It->second.Initializer);
continue;
}
@@ -410,6 +437,7 @@ static void transferSRADebugInfo(GlobalVariable *GV, GlobalVariable *NGV,
DIExpression *Expr = GVE->getExpression();
int64_t CurVarOffsetInBytes = 0;
uint64_t CurVarOffsetInBits = 0;
+ uint64_t FragmentEndInBits = FragmentOffsetInBits + FragmentSizeInBits;
// Calculate the offset (Bytes), Continue if unknown.
if (!Expr->extractIfOffset(CurVarOffsetInBytes))
@@ -423,27 +451,50 @@ static void transferSRADebugInfo(GlobalVariable *GV, GlobalVariable *NGV,
CurVarOffsetInBits = CHAR_BIT * (uint64_t)CurVarOffsetInBytes;
// Current var starts after the fragment, ignore.
- if (CurVarOffsetInBits >= (FragmentOffsetInBits + FragmentSizeInBits))
+ if (CurVarOffsetInBits >= FragmentEndInBits)
continue;
uint64_t CurVarSize = Var->getType()->getSizeInBits();
+ uint64_t CurVarEndInBits = CurVarOffsetInBits + CurVarSize;
// Current variable ends before start of fragment, ignore.
- if (CurVarSize != 0 &&
- (CurVarOffsetInBits + CurVarSize) <= FragmentOffsetInBits)
+ if (CurVarSize != 0 && /* CurVarSize is known */
+ CurVarEndInBits <= FragmentOffsetInBits)
continue;
- // Current variable fits in the fragment.
- if (CurVarOffsetInBits == FragmentOffsetInBits &&
- CurVarSize == FragmentSizeInBits)
- Expr = DIExpression::get(Expr->getContext(), {});
- // If the FragmentSize is smaller than the variable,
+ // Current variable fits in (not greater than) the fragment,
+ // does not need fragment expression.
+ if (CurVarSize != 0 && /* CurVarSize is known */
+ CurVarOffsetInBits >= FragmentOffsetInBits &&
+ CurVarEndInBits <= FragmentEndInBits) {
+ uint64_t CurVarOffsetInFragment =
+ (CurVarOffsetInBits - FragmentOffsetInBits) / 8;
+ if (CurVarOffsetInFragment != 0)
+ Expr = DIExpression::get(Expr->getContext(), {dwarf::DW_OP_plus_uconst,
+ CurVarOffsetInFragment});
+ else
+ Expr = DIExpression::get(Expr->getContext(), {});
+ auto *NGVE =
+ DIGlobalVariableExpression::get(GVE->getContext(), Var, Expr);
+ NGV->addDebugInfo(NGVE);
+ continue;
+ }
+ // Current variable does not fit in single fragment,
// emit a fragment expression.
- else if (FragmentSizeInBits < VarSize) {
+ if (FragmentSizeInBits < VarSize) {
+ if (CurVarOffsetInBits > FragmentOffsetInBits)
+ continue;
+ uint64_t CurVarFragmentOffsetInBits =
+ FragmentOffsetInBits - CurVarOffsetInBits;
+ uint64_t CurVarFragmentSizeInBits = FragmentSizeInBits;
+ if (CurVarSize != 0 && CurVarEndInBits < FragmentEndInBits)
+ CurVarFragmentSizeInBits -= (FragmentEndInBits - CurVarEndInBits);
+ if (CurVarOffsetInBits)
+ Expr = DIExpression::get(Expr->getContext(), {});
if (auto E = DIExpression::createFragmentExpression(
- Expr, FragmentOffsetInBits, FragmentSizeInBits))
+ Expr, CurVarFragmentOffsetInBits, CurVarFragmentSizeInBits))
Expr = *E;
else
- return;
+ continue;
}
auto *NGVE = DIGlobalVariableExpression::get(GVE->getContext(), Var, Expr);
NGV->addDebugInfo(NGVE);
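The branch structure above reduces to classifying how a debug variable's bit range overlaps the bit range that the split-off global takes out of the original aggregate: fully contained variables get at most a DW_OP_plus_uconst, partially covered ones get a fragment expression clipped to the variable's end. A simplified, self-contained model of that arithmetic (a hypothetical helper, not part of the patch; it leaves out DIExpression construction and the aggregate's total size):

#include <cassert>
#include <cstdint>
#include <optional>

struct Overlap {
  bool NeedsFragment;    // true: would emit DW_OP_LLVM_fragment(OffsetInBits, SizeInBits)
  uint64_t OffsetInBits; // contained: variable's bit offset inside the new global
                         // fragment:  new global's bit offset inside the variable
  uint64_t SizeInBits;
};

static std::optional<Overlap> classify(uint64_t VarOff, uint64_t VarSize,
                                       uint64_t FragOff, uint64_t FragSize) {
  uint64_t FragEnd = FragOff + FragSize;
  uint64_t VarEnd = VarOff + VarSize;
  // No overlap at all: the debug entry is dropped for this new global.
  if (VarOff >= FragEnd || (VarSize != 0 && VarEnd <= FragOff))
    return std::nullopt;
  // Variable of known size sits entirely inside the new global.
  if (VarSize != 0 && VarOff >= FragOff && VarEnd <= FragEnd)
    return Overlap{false, VarOff - FragOff, VarSize};
  // Only handle variables that start at or before the new global's range.
  if (VarOff > FragOff)
    return std::nullopt;
  uint64_t Size = FragSize;
  if (VarSize != 0 && VarEnd < FragEnd)
    Size -= FragEnd - VarEnd; // clip the fragment to the variable's end
  return Overlap{true, FragOff - VarOff, Size};
}

int main() {
  // A 128-bit variable; the new global covers bits [32, 64) of the old one.
  auto O = classify(/*VarOff=*/0, /*VarSize=*/128, /*FragOff=*/32, /*FragSize=*/32);
  assert(O && O->NeedsFragment && O->OffsetInBits == 32 && O->SizeInBits == 32);
  // A 16-bit variable at bit 48 sits entirely inside that same new global,
  // i.e. at byte offset 2 (DW_OP_plus_uconst 2 in the real code).
  auto P = classify(48, 16, 32, 32);
  assert(P && !P->NeedsFragment && P->OffsetInBits == 16);
  (void)O;
  (void)P;
  return 0;
}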
@@ -459,52 +510,45 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) {
assert(GV->hasLocalLinkage());
// Collect types to split into.
- DenseMap<uint64_t, Type *> Types;
- if (!collectSRATypes(Types, GV, DL) || Types.empty())
+ DenseMap<uint64_t, GlobalPart> Parts;
+ if (!collectSRATypes(Parts, GV, DL) || Parts.empty())
return nullptr;
// Make sure we don't SRA back to the same type.
- if (Types.size() == 1 && Types.begin()->second == GV->getValueType())
+ if (Parts.size() == 1 && Parts.begin()->second.Ty == GV->getValueType())
return nullptr;
- // Don't perform SRA if we would have to split into many globals.
- if (Types.size() > 16)
+ // Don't perform SRA if we would have to split into many globals. Ignore
+ // parts that are either only loaded or only stored, because we expect them
+ // to be optimized away.
+ unsigned NumParts = count_if(Parts, [](const auto &Pair) {
+ return Pair.second.IsLoaded && Pair.second.IsStored;
+ });
+ if (NumParts > 16)
return nullptr;
// Sort by offset.
- SmallVector<std::pair<uint64_t, Type *>, 16> TypesVector;
- append_range(TypesVector, Types);
+ SmallVector<std::tuple<uint64_t, Type *, Constant *>, 16> TypesVector;
+ for (const auto &Pair : Parts) {
+ TypesVector.push_back(
+ {Pair.first, Pair.second.Ty, Pair.second.Initializer});
+ }
sort(TypesVector, llvm::less_first());
// Check that the types are non-overlapping.
uint64_t Offset = 0;
- for (const auto &Pair : TypesVector) {
+ for (const auto &[OffsetForTy, Ty, _] : TypesVector) {
// Overlaps with previous type.
- if (Pair.first < Offset)
+ if (OffsetForTy < Offset)
return nullptr;
- Offset = Pair.first + DL.getTypeAllocSize(Pair.second);
+ Offset = OffsetForTy + DL.getTypeAllocSize(Ty);
}
// Some accesses go beyond the end of the global, don't bother.
if (Offset > DL.getTypeAllocSize(GV->getValueType()))
return nullptr;
- // Collect initializers for new globals.
- Constant *OrigInit = GV->getInitializer();
- DenseMap<uint64_t, Constant *> Initializers;
- for (const auto &Pair : Types) {
- Constant *NewInit = ConstantFoldLoadFromConst(OrigInit, Pair.second,
- APInt(64, Pair.first), DL);
- if (!NewInit) {
- LLVM_DEBUG(dbgs() << "Global SRA: Failed to evaluate initializer of "
- << *GV << " with type " << *Pair.second << " at offset "
- << Pair.first << "\n");
- return nullptr;
- }
- Initializers.insert({Pair.first, NewInit});
- }
-
LLVM_DEBUG(dbgs() << "PERFORMING GLOBAL SRA ON: " << *GV << "\n");
// Get the alignment of the global, either explicit or target-specific.
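The offset bookkeeping above (sort the collected parts by offset, then require each part to start at or past the end of the previous one) is easy to check in isolation; a small illustrative sketch with made-up offsets and allocation sizes:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Each part is an (offset, alloc size) pair, in bytes.
static bool partsAreDisjoint(std::vector<std::pair<uint64_t, uint64_t>> Parts) {
  std::sort(Parts.begin(), Parts.end()); // sort by offset
  uint64_t End = 0;
  for (const auto &[Off, Size] : Parts) {
    if (Off < End)
      return false; // overlaps the previous part
    End = Off + Size;
  }
  return true;
}

int main() {
  assert(partsAreDisjoint({{0, 4}, {8, 8}, {4, 4}}));
  assert(!partsAreDisjoint({{0, 8}, {4, 4}})); // second part overlaps the first
  return 0;
}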
@@ -515,26 +559,24 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) {
// Create replacement globals.
DenseMap<uint64_t, GlobalVariable *> NewGlobals;
unsigned NameSuffix = 0;
- for (auto &Pair : TypesVector) {
- uint64_t Offset = Pair.first;
- Type *Ty = Pair.second;
+ for (auto &[OffsetForTy, Ty, Initializer] : TypesVector) {
GlobalVariable *NGV = new GlobalVariable(
*GV->getParent(), Ty, false, GlobalVariable::InternalLinkage,
- Initializers[Offset], GV->getName() + "." + Twine(NameSuffix++), GV,
+ Initializer, GV->getName() + "." + Twine(NameSuffix++), GV,
GV->getThreadLocalMode(), GV->getAddressSpace());
NGV->copyAttributesFrom(GV);
- NewGlobals.insert({Offset, NGV});
+ NewGlobals.insert({OffsetForTy, NGV});
// Calculate the known alignment of the field. If the original aggregate
// had 256 byte alignment for example, something might depend on that:
// propagate info to each field.
- Align NewAlign = commonAlignment(StartAlignment, Offset);
+ Align NewAlign = commonAlignment(StartAlignment, OffsetForTy);
if (NewAlign > DL.getABITypeAlign(Ty))
NGV->setAlignment(NewAlign);
// Copy over the debug info for the variable.
- transferSRADebugInfo(GV, NGV, Offset * 8, DL.getTypeAllocSizeInBits(Ty),
- VarSize);
+ transferSRADebugInfo(GV, NGV, OffsetForTy * 8,
+ DL.getTypeAllocSizeInBits(Ty), VarSize);
}
// Replace uses of the original global with uses of the new global.
@@ -621,8 +663,9 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V,
if (II->getCalledOperand() != V) {
return false; // Not calling the ptr
}
- } else if (const BitCastInst *CI = dyn_cast<BitCastInst>(U)) {
- if (!AllUsesOfValueWillTrapIfNull(CI, PHIs)) return false;
+ } else if (const AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(U)) {
+ if (!AllUsesOfValueWillTrapIfNull(CI, PHIs))
+ return false;
} else if (const GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(U)) {
if (!AllUsesOfValueWillTrapIfNull(GEPI, PHIs)) return false;
} else if (const PHINode *PN = dyn_cast<PHINode>(U)) {
@@ -735,10 +778,9 @@ static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) {
UI = V->user_begin();
}
}
- } else if (CastInst *CI = dyn_cast<CastInst>(I)) {
- Changed |= OptimizeAwayTrappingUsesOfValue(CI,
- ConstantExpr::getCast(CI->getOpcode(),
- NewV, CI->getType()));
+ } else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(I)) {
+ Changed |= OptimizeAwayTrappingUsesOfValue(
+ CI, ConstantExpr::getAddrSpaceCast(NewV, CI->getType()));
if (CI->use_empty()) {
Changed = true;
CI->eraseFromParent();
@@ -803,7 +845,8 @@ static bool OptimizeAwayTrappingUsesOfLoads(
assert((isa<PHINode>(GlobalUser) || isa<SelectInst>(GlobalUser) ||
isa<ConstantExpr>(GlobalUser) || isa<CmpInst>(GlobalUser) ||
isa<BitCastInst>(GlobalUser) ||
- isa<GetElementPtrInst>(GlobalUser)) &&
+ isa<GetElementPtrInst>(GlobalUser) ||
+ isa<AddrSpaceCastInst>(GlobalUser)) &&
"Only expect load and stores!");
}
}
@@ -976,7 +1019,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI,
cast<StoreInst>(InitBool->user_back())->eraseFromParent();
delete InitBool;
} else
- GV->getParent()->getGlobalList().insert(GV->getIterator(), InitBool);
+ GV->getParent()->insertGlobalVariable(GV->getIterator(), InitBool);
// Now the GV is dead, nuke it and the allocation..
GV->eraseFromParent();
@@ -1103,9 +1146,6 @@ optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal,
nullptr /* F */,
GV->getInitializer()->getType()->getPointerAddressSpace())) {
if (Constant *SOVC = dyn_cast<Constant>(StoredOnceVal)) {
- if (GV->getInitializer()->getType() != SOVC->getType())
- SOVC = ConstantExpr::getBitCast(SOVC, GV->getInitializer()->getType());
-
// Optimize away any trapping uses of the loaded value.
if (OptimizeAwayTrappingUsesOfLoads(GV, SOVC, DL, GetTLI))
return true;
@@ -1158,7 +1198,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) {
GV->getThreadLocalMode(),
GV->getType()->getAddressSpace());
NewGV->copyAttributesFrom(GV);
- GV->getParent()->getGlobalList().insert(GV->getIterator(), NewGV);
+ GV->getParent()->insertGlobalVariable(GV->getIterator(), NewGV);
Constant *InitVal = GV->getInitializer();
assert(InitVal->getType() != Type::getInt1Ty(GV->getContext()) &&
@@ -1330,18 +1370,6 @@ static bool isPointerValueDeadOnEntryToFunction(
SmallVector<LoadInst *, 4> Loads;
SmallVector<StoreInst *, 4> Stores;
for (auto *U : GV->users()) {
- if (Operator::getOpcode(U) == Instruction::BitCast) {
- for (auto *UU : U->users()) {
- if (auto *LI = dyn_cast<LoadInst>(UU))
- Loads.push_back(LI);
- else if (auto *SI = dyn_cast<StoreInst>(UU))
- Stores.push_back(SI);
- else
- return false;
- }
- continue;
- }
-
Instruction *I = dyn_cast<Instruction>(U);
if (!I)
return false;
@@ -1391,62 +1419,6 @@ static bool isPointerValueDeadOnEntryToFunction(
return true;
}
-/// C may have non-instruction users. Can all of those users be turned into
-/// instructions?
-static bool allNonInstructionUsersCanBeMadeInstructions(Constant *C) {
- // We don't do this exhaustively. The most common pattern that we really need
- // to care about is a constant GEP or constant bitcast - so just looking
- // through one single ConstantExpr.
- //
- // The set of constants that this function returns true for must be able to be
- // handled by makeAllConstantUsesInstructions.
- for (auto *U : C->users()) {
- if (isa<Instruction>(U))
- continue;
- if (!isa<ConstantExpr>(U))
- // Non instruction, non-constantexpr user; cannot convert this.
- return false;
- for (auto *UU : U->users())
- if (!isa<Instruction>(UU))
- // A constantexpr used by another constant. We don't try and recurse any
- // further but just bail out at this point.
- return false;
- }
-
- return true;
-}
-
-/// C may have non-instruction users, and
-/// allNonInstructionUsersCanBeMadeInstructions has returned true. Convert the
-/// non-instruction users to instructions.
-static void makeAllConstantUsesInstructions(Constant *C) {
- SmallVector<ConstantExpr*,4> Users;
- for (auto *U : C->users()) {
- if (isa<ConstantExpr>(U))
- Users.push_back(cast<ConstantExpr>(U));
- else
- // We should never get here; allNonInstructionUsersCanBeMadeInstructions
- // should not have returned true for C.
- assert(
- isa<Instruction>(U) &&
- "Can't transform non-constantexpr non-instruction to instruction!");
- }
-
- SmallVector<Value*,4> UUsers;
- for (auto *U : Users) {
- UUsers.clear();
- append_range(UUsers, U->users());
- for (auto *UU : UUsers) {
- Instruction *UI = cast<Instruction>(UU);
- Instruction *NewU = U->getAsInstruction(UI);
- UI->replaceUsesOfWith(U, NewU);
- }
- // We've replaced all the uses, so destroy the constant. (destroyConstant
- // will update value handles and metadata.)
- U->destroyConstant();
- }
-}
-
// For a global variable with one store, if the store dominates any loads,
// those loads will always load the stored value (as opposed to the
// initializer), even in the presence of recursion.
@@ -1504,7 +1476,6 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS,
GV->getValueType()->isSingleValueType() &&
GV->getType()->getAddressSpace() == 0 &&
!GV->isExternallyInitialized() &&
- allNonInstructionUsersCanBeMadeInstructions(GV) &&
GS.AccessingFunction->doesNotRecurse() &&
isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV,
LookupDomTree)) {
@@ -1520,8 +1491,6 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS,
if (!isa<UndefValue>(GV->getInitializer()))
new StoreInst(GV->getInitializer(), Alloca, &FirstI);
- makeAllConstantUsesInstructions(GV);
-
GV->replaceAllUsesWith(Alloca);
GV->eraseFromParent();
++NumLocalized;
@@ -2142,15 +2111,22 @@ static void setUsedInitializer(GlobalVariable &V,
return;
}
+ // Get address space of pointers in the array of pointers.
+ const Type *UsedArrayType = V.getValueType();
+ const auto *VAT = cast<ArrayType>(UsedArrayType);
+ const auto *VEPT = cast<PointerType>(VAT->getArrayElementType());
+
// Type of pointer to the array of pointers.
- PointerType *Int8PtrTy = Type::getInt8PtrTy(V.getContext(), 0);
+ PointerType *Int8PtrTy =
+ Type::getInt8PtrTy(V.getContext(), VEPT->getAddressSpace());
SmallVector<Constant *, 8> UsedArray;
for (GlobalValue *GV : Init) {
- Constant *Cast
- = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy);
+ Constant *Cast =
+ ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy);
UsedArray.push_back(Cast);
}
+
// Sort to get deterministic order.
array_pod_sort(UsedArray.begin(), UsedArray.end(), compareNames);
ArrayType *ATy = ArrayType::get(Int8PtrTy, UsedArray.size());
@@ -2241,22 +2217,11 @@ static bool hasUseOtherThanLLVMUsed(GlobalAlias &GA, const LLVMUsed &U) {
return !U.usedCount(&GA) && !U.compilerUsedCount(&GA);
}
-static bool hasMoreThanOneUseOtherThanLLVMUsed(GlobalValue &V,
- const LLVMUsed &U) {
- unsigned N = 2;
- assert((!U.usedCount(&V) || !U.compilerUsedCount(&V)) &&
- "We should have removed the duplicated "
- "element from llvm.compiler.used");
- if (U.usedCount(&V) || U.compilerUsedCount(&V))
- ++N;
- return V.hasNUsesOrMore(N);
-}
-
-static bool mayHaveOtherReferences(GlobalAlias &GA, const LLVMUsed &U) {
- if (!GA.hasLocalLinkage())
+static bool mayHaveOtherReferences(GlobalValue &GV, const LLVMUsed &U) {
+ if (!GV.hasLocalLinkage())
return true;
- return U.usedCount(&GA) || U.compilerUsedCount(&GA);
+ return U.usedCount(&GV) || U.compilerUsedCount(&GV);
}
static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U,
@@ -2270,21 +2235,16 @@ static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U,
if (!mayHaveOtherReferences(GA, U))
return Ret;
- // If the aliasee has internal linkage, give it the name and linkage
- // of the alias, and delete the alias. This turns:
+ // If the aliasee has internal linkage and no other references (e.g.,
+ // @llvm.used, @llvm.compiler.used), give it the name and linkage of the
+ // alias, and delete the alias. This turns:
// define internal ... @f(...)
// @a = alias ... @f
// into:
// define ... @a(...)
Constant *Aliasee = GA.getAliasee();
GlobalValue *Target = cast<GlobalValue>(Aliasee->stripPointerCasts());
- if (!Target->hasLocalLinkage())
- return Ret;
-
- // Do not perform the transform if multiple aliases potentially target the
- // aliasee. This check also ensures that it is safe to replace the section
- // and other attributes of the aliasee with those of the alias.
- if (hasMoreThanOneUseOtherThanLLVMUsed(*Target, U))
+ if (mayHaveOtherReferences(*Target, U))
return Ret;
RenameTarget = true;
@@ -2360,7 +2320,7 @@ OptimizeGlobalAliases(Module &M,
continue;
// Delete the alias.
- M.getAliasList().erase(&J);
+ M.eraseAlias(&J);
++NumAliasesRemoved;
Changed = true;
}
@@ -2562,65 +2522,3 @@ PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) {
PA.preserveSet<CFGAnalyses>();
return PA;
}
-
-namespace {
-
-struct GlobalOptLegacyPass : public ModulePass {
- static char ID; // Pass identification, replacement for typeid
-
- GlobalOptLegacyPass() : ModulePass(ID) {
- initializeGlobalOptLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- auto &DL = M.getDataLayout();
- auto LookupDomTree = [this](Function &F) -> DominatorTree & {
- return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
- };
- auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
- return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- };
- auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
- return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- };
-
- auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
- return this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
- };
-
- auto ChangedCFGCallback = [&LookupDomTree](Function &F) {
- auto &DT = LookupDomTree(F);
- DT.recalculate(F);
- };
-
- return optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, LookupDomTree,
- ChangedCFGCallback, nullptr);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<BlockFrequencyInfoWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
-char GlobalOptLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(GlobalOptLegacyPass, "globalopt",
- "Global Variable Optimizer", false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_END(GlobalOptLegacyPass, "globalopt",
- "Global Variable Optimizer", false, false)
-
-ModulePass *llvm::createGlobalOptimizerPass() {
- return new GlobalOptLegacyPass();
-}
diff --git a/llvm/lib/Transforms/IPO/GlobalSplit.cpp b/llvm/lib/Transforms/IPO/GlobalSplit.cpp
index 7d9e6135b2eb..84e9c219f935 100644
--- a/llvm/lib/Transforms/IPO/GlobalSplit.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalSplit.cpp
@@ -29,8 +29,6 @@
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/IPO.h"
#include <cstdint>
@@ -149,8 +147,12 @@ static bool splitGlobals(Module &M) {
M.getFunction(Intrinsic::getName(Intrinsic::type_test));
Function *TypeCheckedLoadFunc =
M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
+ Function *TypeCheckedLoadRelativeFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative));
if ((!TypeTestFunc || TypeTestFunc->use_empty()) &&
- (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
+ (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) &&
+ (!TypeCheckedLoadRelativeFunc ||
+ TypeCheckedLoadRelativeFunc->use_empty()))
return false;
bool Changed = false;
@@ -159,33 +161,6 @@ static bool splitGlobals(Module &M) {
return Changed;
}
-namespace {
-
-struct GlobalSplit : public ModulePass {
- static char ID;
-
- GlobalSplit() : ModulePass(ID) {
- initializeGlobalSplitPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- return splitGlobals(M);
- }
-};
-
-} // end anonymous namespace
-
-char GlobalSplit::ID = 0;
-
-INITIALIZE_PASS(GlobalSplit, "globalsplit", "Global splitter", false, false)
-
-ModulePass *llvm::createGlobalSplitPass() {
- return new GlobalSplit;
-}
-
PreservedAnalyses GlobalSplitPass::run(Module &M, ModuleAnalysisManager &AM) {
if (!splitGlobals(M))
return PreservedAnalyses::all();
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 95e8ae0fd22f..599ace9ca79f 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -46,8 +46,6 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -169,23 +167,6 @@ static bool markFunctionCold(Function &F, bool UpdateEntryCount = false) {
return Changed;
}
-class HotColdSplittingLegacyPass : public ModulePass {
-public:
- static char ID;
- HotColdSplittingLegacyPass() : ModulePass(ID) {
- initializeHotColdSplittingLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<BlockFrequencyInfoWrapperPass>();
- AU.addRequired<ProfileSummaryInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addUsedIfAvailable<AssumptionCacheTracker>();
- }
-
- bool runOnModule(Module &M) override;
-};
-
} // end anonymous namespace
/// Check whether \p F is inherently cold.
@@ -713,32 +694,6 @@ bool HotColdSplitting::run(Module &M) {
return Changed;
}
-bool HotColdSplittingLegacyPass::runOnModule(Module &M) {
- if (skipModule(M))
- return false;
- ProfileSummaryInfo *PSI =
- &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
- auto GTTI = [this](Function &F) -> TargetTransformInfo & {
- return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- };
- auto GBFI = [this](Function &F) {
- return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
- };
- std::unique_ptr<OptimizationRemarkEmitter> ORE;
- std::function<OptimizationRemarkEmitter &(Function &)> GetORE =
- [&ORE](Function &F) -> OptimizationRemarkEmitter & {
- ORE.reset(new OptimizationRemarkEmitter(&F));
- return *ORE;
- };
- auto LookupAC = [this](Function &F) -> AssumptionCache * {
- if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>())
- return ACT->lookupAssumptionCache(F);
- return nullptr;
- };
-
- return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M);
-}
-
PreservedAnalyses
HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) {
auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
@@ -769,15 +724,3 @@ HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) {
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
-
-char HotColdSplittingLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(HotColdSplittingLegacyPass, "hotcoldsplit",
- "Hot Cold Splitting", false, false)
-INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
-INITIALIZE_PASS_END(HotColdSplittingLegacyPass, "hotcoldsplit",
- "Hot Cold Splitting", false, false)
-
-ModulePass *llvm::createHotColdSplittingPass() {
- return new HotColdSplittingLegacyPass();
-}
diff --git a/llvm/lib/Transforms/IPO/IPO.cpp b/llvm/lib/Transforms/IPO/IPO.cpp
index 4163c448dc8f..5ad1289277a7 100644
--- a/llvm/lib/Transforms/IPO/IPO.cpp
+++ b/llvm/lib/Transforms/IPO/IPO.cpp
@@ -12,9 +12,6 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm-c/Transforms/IPO.h"
-#include "llvm-c/Initialization.h"
-#include "llvm/IR/LegacyPassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
@@ -23,104 +20,10 @@
using namespace llvm;
void llvm::initializeIPO(PassRegistry &Registry) {
- initializeAnnotation2MetadataLegacyPass(Registry);
- initializeCalledValuePropagationLegacyPassPass(Registry);
- initializeConstantMergeLegacyPassPass(Registry);
- initializeCrossDSOCFIPass(Registry);
initializeDAEPass(Registry);
initializeDAHPass(Registry);
- initializeForceFunctionAttrsLegacyPassPass(Registry);
- initializeGlobalDCELegacyPassPass(Registry);
- initializeGlobalOptLegacyPassPass(Registry);
- initializeGlobalSplitPass(Registry);
- initializeHotColdSplittingLegacyPassPass(Registry);
- initializeIROutlinerLegacyPassPass(Registry);
initializeAlwaysInlinerLegacyPassPass(Registry);
- initializeSimpleInlinerPass(Registry);
- initializeInferFunctionAttrsLegacyPassPass(Registry);
- initializeInternalizeLegacyPassPass(Registry);
initializeLoopExtractorLegacyPassPass(Registry);
initializeSingleLoopExtractorPass(Registry);
- initializeMergeFunctionsLegacyPassPass(Registry);
- initializePartialInlinerLegacyPassPass(Registry);
- initializeAttributorLegacyPassPass(Registry);
- initializeAttributorCGSCCLegacyPassPass(Registry);
- initializePostOrderFunctionAttrsLegacyPassPass(Registry);
- initializeReversePostOrderFunctionAttrsLegacyPassPass(Registry);
- initializeIPSCCPLegacyPassPass(Registry);
- initializeStripDeadPrototypesLegacyPassPass(Registry);
- initializeStripSymbolsPass(Registry);
- initializeStripDebugDeclarePass(Registry);
- initializeStripDeadDebugInfoPass(Registry);
- initializeStripNonDebugSymbolsPass(Registry);
initializeBarrierNoopPass(Registry);
- initializeEliminateAvailableExternallyLegacyPassPass(Registry);
-}
-
-void LLVMInitializeIPO(LLVMPassRegistryRef R) {
- initializeIPO(*unwrap(R));
-}
-
-void LLVMAddCalledValuePropagationPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createCalledValuePropagationPass());
-}
-
-void LLVMAddConstantMergePass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createConstantMergePass());
-}
-
-void LLVMAddDeadArgEliminationPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createDeadArgEliminationPass());
-}
-
-void LLVMAddFunctionAttrsPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createPostOrderFunctionAttrsLegacyPass());
-}
-
-void LLVMAddFunctionInliningPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createFunctionInliningPass());
-}
-
-void LLVMAddAlwaysInlinerPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(llvm::createAlwaysInlinerLegacyPass());
-}
-
-void LLVMAddGlobalDCEPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createGlobalDCEPass());
-}
-
-void LLVMAddGlobalOptimizerPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createGlobalOptimizerPass());
-}
-
-void LLVMAddIPSCCPPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createIPSCCPPass());
-}
-
-void LLVMAddMergeFunctionsPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createMergeFunctionsPass());
-}
-
-void LLVMAddInternalizePass(LLVMPassManagerRef PM, unsigned AllButMain) {
- auto PreserveMain = [=](const GlobalValue &GV) {
- return AllButMain && GV.getName() == "main";
- };
- unwrap(PM)->add(createInternalizePass(PreserveMain));
-}
-
-void LLVMAddInternalizePassWithMustPreservePredicate(
- LLVMPassManagerRef PM,
- void *Context,
- LLVMBool (*Pred)(LLVMValueRef, void *)) {
- unwrap(PM)->add(createInternalizePass([=](const GlobalValue &GV) {
- return Pred(wrap(&GV), Context) == 0 ? false : true;
- }));
-}
-
-void LLVMAddStripDeadPrototypesPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createStripDeadPrototypesPass());
-}
-
-void LLVMAddStripSymbolsPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createStripSymbolsPass());
}
diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp
index f5c52e5c7f5d..e258299c6a4c 100644
--- a/llvm/lib/Transforms/IPO/IROutliner.cpp
+++ b/llvm/lib/Transforms/IPO/IROutliner.cpp
@@ -22,8 +22,6 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include <optional>
@@ -179,10 +177,8 @@ static void getSortedConstantKeys(std::vector<Value *> &SortedKeys,
stable_sort(SortedKeys, [](const Value *LHS, const Value *RHS) {
assert(LHS && RHS && "Expected non void values.");
- const ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS);
- const ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS);
- assert(RHSC && "Not a constant integer in return value?");
- assert(LHSC && "Not a constant integer in return value?");
+ const ConstantInt *LHSC = cast<ConstantInt>(LHS);
+ const ConstantInt *RHSC = cast<ConstantInt>(RHS);
return LHSC->getLimitedValue() < RHSC->getLimitedValue();
});
@@ -590,7 +586,7 @@ collectRegionsConstants(OutlinableRegion &Region,
// While this value is a register, it might not have been previously,
// make sure we don't already have a constant mapped to this global value
// number.
- if (GVNToConstant.find(GVN) != GVNToConstant.end())
+ if (GVNToConstant.contains(GVN))
ConstantsTheSame = false;
NotSame.insert(GVN);
@@ -818,7 +814,7 @@ static void mapInputsToGVNs(IRSimilarityCandidate &C,
// replacement.
for (Value *Input : CurrentInputs) {
assert(Input && "Have a nullptr as an input");
- if (OutputMappings.find(Input) != OutputMappings.end())
+ if (OutputMappings.contains(Input))
Input = OutputMappings.find(Input)->second;
assert(C.getGVN(Input) && "Could not find a numbering for the given input");
EndInputNumbers.push_back(*C.getGVN(Input));
@@ -840,7 +836,7 @@ remapExtractedInputs(const ArrayRef<Value *> ArgInputs,
// Get the global value number for each input that will be extracted as an
// argument by the code extractor, remapping if needed for reloaded values.
for (Value *Input : ArgInputs) {
- if (OutputMappings.find(Input) != OutputMappings.end())
+ if (OutputMappings.contains(Input))
Input = OutputMappings.find(Input)->second;
RemappedArgInputs.insert(Input);
}
@@ -1332,7 +1328,7 @@ findExtractedOutputToOverallOutputMapping(Module &M, OutlinableRegion &Region,
unsigned AggArgIdx = 0;
for (unsigned Jdx = TypeIndex; Jdx < ArgumentSize; Jdx++) {
- if (Group.ArgumentTypes[Jdx] != PointerType::getUnqual(Output->getType()))
+ if (!isa<PointerType>(Group.ArgumentTypes[Jdx]))
continue;
if (AggArgsUsed.contains(Jdx))
@@ -1483,8 +1479,7 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) {
}
// If it is a constant, we simply add it to the argument list as a value.
- if (Region.AggArgToConstant.find(AggArgIdx) !=
- Region.AggArgToConstant.end()) {
+ if (Region.AggArgToConstant.contains(AggArgIdx)) {
Constant *CST = Region.AggArgToConstant.find(AggArgIdx)->second;
LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to value "
<< *CST << "\n");
@@ -1818,8 +1813,7 @@ replaceArgumentUses(OutlinableRegion &Region,
for (unsigned ArgIdx = 0; ArgIdx < Region.ExtractedFunction->arg_size();
ArgIdx++) {
- assert(Region.ExtractedArgToAgg.find(ArgIdx) !=
- Region.ExtractedArgToAgg.end() &&
+ assert(Region.ExtractedArgToAgg.contains(ArgIdx) &&
"No mapping from extracted to outlined?");
unsigned AggArgIdx = Region.ExtractedArgToAgg.find(ArgIdx)->second;
Argument *AggArg = Group.OutlinedFunction->getArg(AggArgIdx);
@@ -2700,7 +2694,7 @@ void IROutliner::updateOutputMapping(OutlinableRegion &Region,
if (!OutputIdx)
return;
- if (OutputMappings.find(Outputs[*OutputIdx]) == OutputMappings.end()) {
+ if (!OutputMappings.contains(Outputs[*OutputIdx])) {
LLVM_DEBUG(dbgs() << "Mapping extracted output " << *LI << " to "
<< *Outputs[*OutputIdx] << "\n");
OutputMappings.insert(std::make_pair(LI, Outputs[*OutputIdx]));
@@ -3024,46 +3018,6 @@ bool IROutliner::run(Module &M) {
return doOutline(M) > 0;
}
-// Pass Manager Boilerplate
-namespace {
-class IROutlinerLegacyPass : public ModulePass {
-public:
- static char ID;
- IROutlinerLegacyPass() : ModulePass(ID) {
- initializeIROutlinerLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<IRSimilarityIdentifierWrapperPass>();
- }
-
- bool runOnModule(Module &M) override;
-};
-} // namespace
-
-bool IROutlinerLegacyPass::runOnModule(Module &M) {
- if (skipModule(M))
- return false;
-
- std::unique_ptr<OptimizationRemarkEmitter> ORE;
- auto GORE = [&ORE](Function &F) -> OptimizationRemarkEmitter & {
- ORE.reset(new OptimizationRemarkEmitter(&F));
- return *ORE;
- };
-
- auto GTTI = [this](Function &F) -> TargetTransformInfo & {
- return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- };
-
- auto GIRSI = [this](Module &) -> IRSimilarityIdentifier & {
- return this->getAnalysis<IRSimilarityIdentifierWrapperPass>().getIRSI();
- };
-
- return IROutliner(GTTI, GIRSI, GORE).run(M);
-}
-
PreservedAnalyses IROutlinerPass::run(Module &M, ModuleAnalysisManager &AM) {
auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
@@ -3088,14 +3042,3 @@ PreservedAnalyses IROutlinerPass::run(Module &M, ModuleAnalysisManager &AM) {
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
-
-char IROutlinerLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(IROutlinerLegacyPass, "iroutliner", "IR Outliner", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(IRSimilarityIdentifierWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(IROutlinerLegacyPass, "iroutliner", "IR Outliner", false,
- false)
-
-ModulePass *llvm::createIROutlinerPass() { return new IROutlinerLegacyPass(); }
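
Most of the IROutliner hunks above replace explicit find() != end() comparisons with DenseMap::contains(). A minimal sketch of the equivalence, assuming an LLVM recent enough to provide contains() (which this change relies on); the map contents are invented:

  #include "llvm/ADT/DenseMap.h"
  #include "llvm/Support/raw_ostream.h"

  int main() {
    llvm::DenseMap<unsigned, int> GVNToConstant;
    GVNToConstant[7] = 42;

    // Old spelling: iterator comparison against end().
    bool HasOld = GVNToConstant.find(7) != GVNToConstant.end();
    // New spelling: a direct, intent-revealing membership test.
    bool HasNew = GVNToConstant.contains(7);

    llvm::outs() << (HasOld ? "found" : "missing") << ' '
                 << (HasNew ? "found" : "missing") << '\n'; // "found found"
    return 0;
  }
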
diff --git a/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp b/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp
index 76f8f1a7a482..18d5911d10f1 100644
--- a/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp
@@ -10,7 +10,6 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
@@ -52,38 +51,3 @@ PreservedAnalyses InferFunctionAttrsPass::run(Module &M,
// out all the passes.
return PreservedAnalyses::none();
}
-
-namespace {
-struct InferFunctionAttrsLegacyPass : public ModulePass {
- static char ID; // Pass identification, replacement for typeid
- InferFunctionAttrsLegacyPass() : ModulePass(ID) {
- initializeInferFunctionAttrsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
- return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- };
- return inferAllPrototypeAttributes(M, GetTLI);
- }
-};
-}
-
-char InferFunctionAttrsLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(InferFunctionAttrsLegacyPass, "inferattrs",
- "Infer set function attributes", false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(InferFunctionAttrsLegacyPass, "inferattrs",
- "Infer set function attributes", false, false)
-
-Pass *llvm::createInferFunctionAttrsLegacyPass() {
- return new InferFunctionAttrsLegacyPass();
-}
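
With the legacy wrapper pass deleted, only the new-PM InferFunctionAttrsPass remains, and under the new pass manager a module pass reaches per-function analyses such as TargetLibraryInfo through the FunctionAnalysisManagerModuleProxy. A sketch of that pattern with a hypothetical pass name:

  #include "llvm/Analysis/TargetLibraryInfo.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PassManager.h"

  namespace {
  // Hypothetical module pass; only the proxy/analysis plumbing is the point.
  struct CountKnownLibFuncsPass
      : llvm::PassInfoMixin<CountKnownLibFuncsPass> {
    llvm::PreservedAnalyses run(llvm::Module &M,
                                llvm::ModuleAnalysisManager &AM) {
      auto &FAM =
          AM.getResult<llvm::FunctionAnalysisManagerModuleProxy>(M).getManager();
      unsigned Known = 0;
      for (llvm::Function &F : M) {
        const llvm::TargetLibraryInfo &TLI =
            FAM.getResult<llvm::TargetLibraryAnalysis>(F);
        llvm::LibFunc LF;
        if (F.isDeclaration() && TLI.getLibFunc(F, LF))
          ++Known; // F's prototype matches a recognized library function.
      }
      (void)Known; // A real pass would act on this; we only analyze here.
      return llvm::PreservedAnalyses::all();
    }
  };
  } // namespace
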
diff --git a/llvm/lib/Transforms/IPO/InlineSimple.cpp b/llvm/lib/Transforms/IPO/InlineSimple.cpp
deleted file mode 100644
index eba0d6636d6c..000000000000
--- a/llvm/lib/Transforms/IPO/InlineSimple.cpp
+++ /dev/null
@@ -1,118 +0,0 @@
-//===- InlineSimple.cpp - Code to perform simple function inlining --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements bottom-up inlining of functions into callees.
-//
-//===----------------------------------------------------------------------===//
-
-#include "llvm/Analysis/AssumptionCache.h"
-#include "llvm/Analysis/InlineCost.h"
-#include "llvm/Analysis/OptimizationRemarkEmitter.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Transforms/IPO.h"
-#include "llvm/Transforms/IPO/Inliner.h"
-
-using namespace llvm;
-
-#define DEBUG_TYPE "inline"
-
-namespace {
-
-/// Actual inliner pass implementation.
-///
-/// The common implementation of the inlining logic is shared between this
-/// inliner pass and the always inliner pass. The two passes use different cost
-/// analyses to determine when to inline.
-class SimpleInliner : public LegacyInlinerBase {
-
- InlineParams Params;
-
-public:
- SimpleInliner() : LegacyInlinerBase(ID), Params(llvm::getInlineParams()) {
- initializeSimpleInlinerPass(*PassRegistry::getPassRegistry());
- }
-
- explicit SimpleInliner(InlineParams Params)
- : LegacyInlinerBase(ID), Params(std::move(Params)) {
- initializeSimpleInlinerPass(*PassRegistry::getPassRegistry());
- }
-
- static char ID; // Pass identification, replacement for typeid
-
- InlineCost getInlineCost(CallBase &CB) override {
- Function *Callee = CB.getCalledFunction();
- TargetTransformInfo &TTI = TTIWP->getTTI(*Callee);
-
- bool RemarksEnabled = false;
- const auto &BBs = *CB.getCaller();
- if (!BBs.empty()) {
- auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBs.front());
- if (DI.isEnabled())
- RemarksEnabled = true;
- }
- OptimizationRemarkEmitter ORE(CB.getCaller());
-
- std::function<AssumptionCache &(Function &)> GetAssumptionCache =
- [&](Function &F) -> AssumptionCache & {
- return ACT->getAssumptionCache(F);
- };
- return llvm::getInlineCost(CB, Params, TTI, GetAssumptionCache, GetTLI,
- /*GetBFI=*/nullptr, PSI,
- RemarksEnabled ? &ORE : nullptr);
- }
-
- bool runOnSCC(CallGraphSCC &SCC) override;
- void getAnalysisUsage(AnalysisUsage &AU) const override;
-
-private:
- TargetTransformInfoWrapperPass *TTIWP;
-
-};
-
-} // end anonymous namespace
-
-char SimpleInliner::ID = 0;
-INITIALIZE_PASS_BEGIN(SimpleInliner, "inline", "Function Integration/Inlining",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(SimpleInliner, "inline", "Function Integration/Inlining",
- false, false)
-
-Pass *llvm::createFunctionInliningPass() { return new SimpleInliner(); }
-
-Pass *llvm::createFunctionInliningPass(int Threshold) {
- return new SimpleInliner(llvm::getInlineParams(Threshold));
-}
-
-Pass *llvm::createFunctionInliningPass(unsigned OptLevel,
- unsigned SizeOptLevel,
- bool DisableInlineHotCallSite) {
- auto Param = llvm::getInlineParams(OptLevel, SizeOptLevel);
- if (DisableInlineHotCallSite)
- Param.HotCallSiteThreshold = 0;
- return new SimpleInliner(Param);
-}
-
-Pass *llvm::createFunctionInliningPass(InlineParams &Params) {
- return new SimpleInliner(Params);
-}
-
-bool SimpleInliner::runOnSCC(CallGraphSCC &SCC) {
- TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>();
- return LegacyInlinerBase::runOnSCC(SCC);
-}
-
-void SimpleInliner::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- LegacyInlinerBase::getAnalysisUsage(AU);
-}
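
With InlineSimple.cpp gone there is no legacy '-inline' wrapper left, so the CGSCC inliner is scheduled through the new pass manager. A sketch, under the assumption that the surrounding tool already has a Module in hand, of wiring the analysis managers and parsing the textual pipeline name:

  #include "llvm/Analysis/CGSCCPassManager.h"
  #include "llvm/Analysis/LoopAnalysisManager.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PassManager.h"
  #include "llvm/Passes/PassBuilder.h"
  #include "llvm/Support/Error.h"

  int runInlinerOn(llvm::Module &M) {
    llvm::PassBuilder PB;
    llvm::LoopAnalysisManager LAM;
    llvm::FunctionAnalysisManager FAM;
    llvm::CGSCCAnalysisManager CGAM;
    llvm::ModuleAnalysisManager MAM;

    // Register all analyses and cross-register the proxies between managers.
    PB.registerModuleAnalyses(MAM);
    PB.registerCGSCCAnalyses(CGAM);
    PB.registerFunctionAnalyses(FAM);
    PB.registerLoopAnalyses(LAM);
    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

    // Schedule the CGSCC inliner by its textual pipeline name.
    llvm::ModulePassManager MPM;
    if (llvm::Error Err = PB.parsePassPipeline(MPM, "cgscc(inline)")) {
      llvm::consumeError(std::move(Err)); // a real tool would report this
      return 1;
    }

    llvm::PreservedAnalyses PA = MPM.run(M, MAM);
    (void)PA;
    return 0;
  }
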
diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp
index 5bcfc38c585b..3e00aebce372 100644
--- a/llvm/lib/Transforms/IPO/Inliner.cpp
+++ b/llvm/lib/Transforms/IPO/Inliner.cpp
@@ -27,7 +27,6 @@
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CGSCCPassManager.h"
-#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/LazyCallGraph.h"
@@ -71,20 +70,7 @@ using namespace llvm;
#define DEBUG_TYPE "inline"
STATISTIC(NumInlined, "Number of functions inlined");
-STATISTIC(NumCallsDeleted, "Number of call sites deleted, not inlined");
STATISTIC(NumDeleted, "Number of functions deleted because all callers found");
-STATISTIC(NumMergedAllocas, "Number of allocas merged together");
-
-/// Flag to disable manual alloca merging.
-///
-/// Merging of allocas was originally done as a stack-size saving technique
-/// prior to LLVM's code generator having support for stack coloring based on
-/// lifetime markers. It is now in the process of being removed. To experiment
-/// with disabling it and relying fully on lifetime marker based stack
-/// coloring, you can pass this flag to LLVM.
-static cl::opt<bool>
- DisableInlinedAllocaMerging("disable-inlined-alloca-merging",
- cl::init(false), cl::Hidden);
static cl::opt<int> IntraSCCCostMultiplier(
"intra-scc-cost-multiplier", cl::init(2), cl::Hidden,
@@ -108,9 +94,6 @@ static cl::opt<bool>
EnablePostSCCAdvisorPrinting("enable-scc-inline-advisor-printing",
cl::init(false), cl::Hidden);
-namespace llvm {
-extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
-}
static cl::opt<std::string> CGSCCInlineReplayFile(
"cgscc-inline-replay", cl::init(""), cl::value_desc("filename"),
@@ -163,174 +146,6 @@ static cl::opt<CallSiteFormat::Format> CGSCCInlineReplayFormat(
"<Line Number>:<Column Number>.<Discriminator> (default)")),
cl::desc("How cgscc inline replay file is formatted"), cl::Hidden);
-LegacyInlinerBase::LegacyInlinerBase(char &ID) : CallGraphSCCPass(ID) {}
-
-LegacyInlinerBase::LegacyInlinerBase(char &ID, bool InsertLifetime)
- : CallGraphSCCPass(ID), InsertLifetime(InsertLifetime) {}
-
-/// For this class, we declare that we require and preserve the call graph.
-/// If the derived class implements this method, it should
-/// always explicitly call the implementation here.
-void LegacyInlinerBase::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<ProfileSummaryInfoWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- getAAResultsAnalysisUsage(AU);
- CallGraphSCCPass::getAnalysisUsage(AU);
-}
-
-using InlinedArrayAllocasTy = DenseMap<ArrayType *, std::vector<AllocaInst *>>;
-
-/// Look at all of the allocas that we inlined through this call site. If we
-/// have already inlined other allocas through other calls into this function,
-/// then we know that they have disjoint lifetimes and that we can merge them.
-///
-/// There are many heuristics possible for merging these allocas, and the
-/// different options have different tradeoffs. One thing that we *really*
-/// don't want to hurt is SRoA: once inlining happens, often allocas are no
-/// longer address taken and so they can be promoted.
-///
-/// Our "solution" for that is to only merge allocas whose outermost type is an
-/// array type. These are usually not promoted because someone is using a
-/// variable index into them. These are also often the most important ones to
-/// merge.
-///
-/// A better solution would be to have real memory lifetime markers in the IR
-/// and not have the inliner do any merging of allocas at all. This would
-/// allow the backend to do proper stack slot coloring of all allocas that
-/// *actually make it to the backend*, which is really what we want.
-///
-/// Because we don't have this information, we do this simple and useful hack.
-static void mergeInlinedArrayAllocas(Function *Caller, InlineFunctionInfo &IFI,
- InlinedArrayAllocasTy &InlinedArrayAllocas,
- int InlineHistory) {
- SmallPtrSet<AllocaInst *, 16> UsedAllocas;
-
- // When processing our SCC, check to see if the call site was inlined from
- // some other call site. For example, if we're processing "A" in this code:
- // A() { B() }
- // B() { x = alloca ... C() }
- // C() { y = alloca ... }
- // Assume that C was not inlined into B initially, and so we're processing A
- // and decide to inline B into A. Doing this makes an alloca available for
- // reuse and makes a callsite (C) available for inlining. When we process
- // the C call site we don't want to do any alloca merging between X and Y
- // because their scopes are not disjoint. We could make this smarter by
- // keeping track of the inline history for each alloca in the
- // InlinedArrayAllocas but this isn't likely to be a significant win.
- if (InlineHistory != -1) // Only do merging for top-level call sites in SCC.
- return;
-
- // Loop over all the allocas we have so far and see if they can be merged with
- // a previously inlined alloca. If not, remember that we had it.
- for (unsigned AllocaNo = 0, E = IFI.StaticAllocas.size(); AllocaNo != E;
- ++AllocaNo) {
- AllocaInst *AI = IFI.StaticAllocas[AllocaNo];
-
- // Don't bother trying to merge array allocations (they will usually be
- // canonicalized to be an allocation *of* an array), or allocations whose
- // type is not itself an array (because we're afraid of pessimizing SRoA).
- ArrayType *ATy = dyn_cast<ArrayType>(AI->getAllocatedType());
- if (!ATy || AI->isArrayAllocation())
- continue;
-
- // Get the list of all available allocas for this array type.
- std::vector<AllocaInst *> &AllocasForType = InlinedArrayAllocas[ATy];
-
- // Loop over the allocas in AllocasForType to see if we can reuse one. Note
- // that we have to be careful not to reuse the same "available" alloca for
- // multiple different allocas that we just inlined; we use the 'UsedAllocas'
- // set to keep track of which "available" allocas are being used by this
- // function. Also, AllocasForType can be empty of course!
- bool MergedAwayAlloca = false;
- for (AllocaInst *AvailableAlloca : AllocasForType) {
- Align Align1 = AI->getAlign();
- Align Align2 = AvailableAlloca->getAlign();
-
- // The available alloca has to be in the right function, not in some other
- // function in this SCC.
- if (AvailableAlloca->getParent() != AI->getParent())
- continue;
-
- // If the inlined function already uses this alloca then we can't reuse
- // it.
- if (!UsedAllocas.insert(AvailableAlloca).second)
- continue;
-
- // Otherwise, we *can* reuse it, RAUW AI into AvailableAlloca and declare
- // success!
- LLVM_DEBUG(dbgs() << " ***MERGED ALLOCA: " << *AI
- << "\n\t\tINTO: " << *AvailableAlloca << '\n');
-
- // Move affected dbg.declare calls immediately after the new alloca to
- // avoid the situation when a dbg.declare precedes its alloca.
- if (auto *L = LocalAsMetadata::getIfExists(AI))
- if (auto *MDV = MetadataAsValue::getIfExists(AI->getContext(), L))
- for (User *U : MDV->users())
- if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U))
- DDI->moveBefore(AvailableAlloca->getNextNode());
-
- AI->replaceAllUsesWith(AvailableAlloca);
-
- if (Align1 > Align2)
- AvailableAlloca->setAlignment(AI->getAlign());
-
- AI->eraseFromParent();
- MergedAwayAlloca = true;
- ++NumMergedAllocas;
- IFI.StaticAllocas[AllocaNo] = nullptr;
- break;
- }
-
- // If we already nuked the alloca, we're done with it.
- if (MergedAwayAlloca)
- continue;
-
- // If we were unable to merge away the alloca either because there are no
- // allocas of the right type available or because we reused them all
- // already, remember that this alloca came from an inlined function and mark
- // it used so we don't reuse it for other allocas from this inline
- // operation.
- AllocasForType.push_back(AI);
- UsedAllocas.insert(AI);
- }
-}
-
-/// If it is possible to inline the specified call site,
-/// do so and update the CallGraph for this operation.
-///
-/// This function also does some basic book-keeping to update the IR. The
-/// InlinedArrayAllocas map keeps track of any allocas that are already
-/// available from other functions inlined into the caller. If we are able to
-/// inline this call site we attempt to reuse already available allocas or add
-/// any new allocas to the set if not possible.
-static InlineResult inlineCallIfPossible(
- CallBase &CB, InlineFunctionInfo &IFI,
- InlinedArrayAllocasTy &InlinedArrayAllocas, int InlineHistory,
- bool InsertLifetime, function_ref<AAResults &(Function &)> &AARGetter,
- ImportedFunctionsInliningStatistics &ImportedFunctionsStats) {
- Function *Callee = CB.getCalledFunction();
- Function *Caller = CB.getCaller();
-
- AAResults &AAR = AARGetter(*Callee);
-
- // Try to inline the function. Get the list of static allocas that were
- // inlined.
- InlineResult IR =
- InlineFunction(CB, IFI,
- /*MergeAttributes=*/true, &AAR, InsertLifetime);
- if (!IR.isSuccess())
- return IR;
-
- if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No)
- ImportedFunctionsStats.recordInline(*Caller, *Callee);
-
- if (!DisableInlinedAllocaMerging)
- mergeInlinedArrayAllocas(Caller, IFI, InlinedArrayAllocas, InlineHistory);
-
- return IR; // success
-}
-
/// Return true if the specified inline history ID
/// indicates an inline history that includes the specified function.
static bool inlineHistoryIncludes(
@@ -346,361 +161,6 @@ static bool inlineHistoryIncludes(
return false;
}
-bool LegacyInlinerBase::doInitialization(CallGraph &CG) {
- if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No)
- ImportedFunctionsStats.setModuleInfo(CG.getModule());
- return false; // No changes to CallGraph.
-}
-
-bool LegacyInlinerBase::runOnSCC(CallGraphSCC &SCC) {
- if (skipSCC(SCC))
- return false;
- return inlineCalls(SCC);
-}
-
-static bool
-inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
- std::function<AssumptionCache &(Function &)> GetAssumptionCache,
- ProfileSummaryInfo *PSI,
- std::function<const TargetLibraryInfo &(Function &)> GetTLI,
- bool InsertLifetime,
- function_ref<InlineCost(CallBase &CB)> GetInlineCost,
- function_ref<AAResults &(Function &)> AARGetter,
- ImportedFunctionsInliningStatistics &ImportedFunctionsStats) {
- SmallPtrSet<Function *, 8> SCCFunctions;
- LLVM_DEBUG(dbgs() << "Inliner visiting SCC:");
- for (CallGraphNode *Node : SCC) {
- Function *F = Node->getFunction();
- if (F)
- SCCFunctions.insert(F);
- LLVM_DEBUG(dbgs() << " " << (F ? F->getName() : "INDIRECTNODE"));
- }
-
- // Scan through and identify all call sites ahead of time so that we only
- // inline call sites in the original functions, not call sites that result
- // from inlining other functions.
- SmallVector<std::pair<CallBase *, int>, 16> CallSites;
-
- // When inlining a callee produces new call sites, we want to keep track of
- // the fact that they were inlined from the callee. This allows us to avoid
- // infinite inlining in some obscure cases. To represent this, we use an
- // index into the InlineHistory vector.
- SmallVector<std::pair<Function *, int>, 8> InlineHistory;
-
- for (CallGraphNode *Node : SCC) {
- Function *F = Node->getFunction();
- if (!F || F->isDeclaration())
- continue;
-
- OptimizationRemarkEmitter ORE(F);
- for (BasicBlock &BB : *F)
- for (Instruction &I : BB) {
- auto *CB = dyn_cast<CallBase>(&I);
- // If this isn't a call, or it is a call to an intrinsic, it can
- // never be inlined.
- if (!CB || isa<IntrinsicInst>(I))
- continue;
-
- // If this is a direct call to an external function, we can never inline
- // it. If it is an indirect call, inlining may resolve it to be a
- // direct call, so we keep it.
- if (Function *Callee = CB->getCalledFunction())
- if (Callee->isDeclaration()) {
- using namespace ore;
-
- setInlineRemark(*CB, "unavailable definition");
- ORE.emit([&]() {
- return OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I)
- << NV("Callee", Callee) << " will not be inlined into "
- << NV("Caller", CB->getCaller())
- << " because its definition is unavailable"
- << setIsVerbose();
- });
- continue;
- }
-
- CallSites.push_back(std::make_pair(CB, -1));
- }
- }
-
- LLVM_DEBUG(dbgs() << ": " << CallSites.size() << " call sites.\n");
-
- // If there are no calls in this function, exit early.
- if (CallSites.empty())
- return false;
-
- // Now that we have all of the call sites, move the ones to functions in the
- // current SCC to the end of the list.
- unsigned FirstCallInSCC = CallSites.size();
- for (unsigned I = 0; I < FirstCallInSCC; ++I)
- if (Function *F = CallSites[I].first->getCalledFunction())
- if (SCCFunctions.count(F))
- std::swap(CallSites[I--], CallSites[--FirstCallInSCC]);
-
- InlinedArrayAllocasTy InlinedArrayAllocas;
- InlineFunctionInfo InlineInfo(&CG, GetAssumptionCache, PSI);
-
- // Now that we have all of the call sites, loop over them and inline them if
- // it looks profitable to do so.
- bool Changed = false;
- bool LocalChange;
- do {
- LocalChange = false;
- // Iterate over the outer loop because inlining functions can cause indirect
- // calls to become direct calls.
- // CallSites may be modified inside so ranged for loop can not be used.
- for (unsigned CSi = 0; CSi != CallSites.size(); ++CSi) {
- auto &P = CallSites[CSi];
- CallBase &CB = *P.first;
- const int InlineHistoryID = P.second;
-
- Function *Caller = CB.getCaller();
- Function *Callee = CB.getCalledFunction();
-
- // We can only inline direct calls to non-declarations.
- if (!Callee || Callee->isDeclaration())
- continue;
-
- bool IsTriviallyDead = isInstructionTriviallyDead(&CB, &GetTLI(*Caller));
-
- if (!IsTriviallyDead) {
- // If this call site was obtained by inlining another function, verify
- // that the include path for the function did not include the callee
- // itself. If so, we'd be recursively inlining the same function,
- // which would provide the same callsites, which would cause us to
- // infinitely inline.
- if (InlineHistoryID != -1 &&
- inlineHistoryIncludes(Callee, InlineHistoryID, InlineHistory)) {
- setInlineRemark(CB, "recursive");
- continue;
- }
- }
-
- // FIXME for new PM: because of the old PM we currently generate ORE and
- // in turn BFI on demand. With the new PM, the ORE dependency should
- // just become a regular analysis dependency.
- OptimizationRemarkEmitter ORE(Caller);
-
- auto OIC = shouldInline(CB, GetInlineCost, ORE);
- // If the policy determines that we should inline this function,
- // delete the call instead.
- if (!OIC)
- continue;
-
- // If this call site is dead and it is to a readonly function, we should
- // just delete the call instead of trying to inline it, regardless of
- // size. This happens because IPSCCP propagates the result out of the
- // call and then we're left with the dead call.
- if (IsTriviallyDead) {
- LLVM_DEBUG(dbgs() << " -> Deleting dead call: " << CB << "\n");
- // Update the call graph by deleting the edge from Callee to Caller.
- setInlineRemark(CB, "trivially dead");
- CG[Caller]->removeCallEdgeFor(CB);
- CB.eraseFromParent();
- ++NumCallsDeleted;
- } else {
- // Get DebugLoc to report. CB will be invalid after Inliner.
- DebugLoc DLoc = CB.getDebugLoc();
- BasicBlock *Block = CB.getParent();
-
- // Attempt to inline the function.
- using namespace ore;
-
- InlineResult IR = inlineCallIfPossible(
- CB, InlineInfo, InlinedArrayAllocas, InlineHistoryID,
- InsertLifetime, AARGetter, ImportedFunctionsStats);
- if (!IR.isSuccess()) {
- setInlineRemark(CB, std::string(IR.getFailureReason()) + "; " +
- inlineCostStr(*OIC));
- ORE.emit([&]() {
- return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc,
- Block)
- << NV("Callee", Callee) << " will not be inlined into "
- << NV("Caller", Caller) << ": "
- << NV("Reason", IR.getFailureReason());
- });
- continue;
- }
- ++NumInlined;
-
- emitInlinedIntoBasedOnCost(ORE, DLoc, Block, *Callee, *Caller, *OIC);
-
- // If inlining this function gave us any new call sites, throw them
- // onto our worklist to process. They are useful inline candidates.
- if (!InlineInfo.InlinedCalls.empty()) {
- // Create a new inline history entry for this, so that we remember
- // that these new callsites came about due to inlining Callee.
- int NewHistoryID = InlineHistory.size();
- InlineHistory.push_back(std::make_pair(Callee, InlineHistoryID));
-
-#ifndef NDEBUG
- // Make sure there are no duplicates in the inline candidates. This could
- // happen when a callsite is simplified to reuse the return value of
- // another callsite during function cloning; the other callsite is then
- // reconsidered here.
- DenseSet<CallBase *> DbgCallSites;
- for (auto &II : CallSites)
- DbgCallSites.insert(II.first);
-#endif
-
- for (Value *Ptr : InlineInfo.InlinedCalls) {
-#ifndef NDEBUG
- assert(DbgCallSites.count(dyn_cast<CallBase>(Ptr)) == 0);
-#endif
- CallSites.push_back(
- std::make_pair(dyn_cast<CallBase>(Ptr), NewHistoryID));
- }
- }
- }
-
- // If we inlined or deleted the last possible call site to the function,
- // delete the function body now.
- if (Callee && Callee->use_empty() && Callee->hasLocalLinkage() &&
- // TODO: Can remove if in SCC now.
- !SCCFunctions.count(Callee) &&
- // The function may be apparently dead, but if there are indirect
- // callgraph references to the node, we cannot delete it yet, this
- // could invalidate the CGSCC iterator.
- CG[Callee]->getNumReferences() == 0) {
- LLVM_DEBUG(dbgs() << " -> Deleting dead function: "
- << Callee->getName() << "\n");
- CallGraphNode *CalleeNode = CG[Callee];
-
- // Remove any call graph edges from the callee to its callees.
- CalleeNode->removeAllCalledFunctions();
-
- // Removing the node for callee from the call graph and delete it.
- delete CG.removeFunctionFromModule(CalleeNode);
- ++NumDeleted;
- }
-
- // Remove this call site from the list. If possible, use
- // swap/pop_back for efficiency, but do not use it if doing so would
- // move a call site to a function in this SCC before the
- // 'FirstCallInSCC' barrier.
- if (SCC.isSingular()) {
- CallSites[CSi] = CallSites.back();
- CallSites.pop_back();
- } else {
- CallSites.erase(CallSites.begin() + CSi);
- }
- --CSi;
-
- Changed = true;
- LocalChange = true;
- }
- } while (LocalChange);
-
- return Changed;
-}
-
-bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) {
- CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
- ACT = &getAnalysis<AssumptionCacheTracker>();
- PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
- GetTLI = [&](Function &F) -> const TargetLibraryInfo & {
- return getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- };
- auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
- return ACT->getAssumptionCache(F);
- };
- return inlineCallsImpl(
- SCC, CG, GetAssumptionCache, PSI, GetTLI, InsertLifetime,
- [&](CallBase &CB) { return getInlineCost(CB); }, LegacyAARGetter(*this),
- ImportedFunctionsStats);
-}
-
-/// Remove now-dead linkonce functions at the end of
-/// processing to avoid breaking the SCC traversal.
-bool LegacyInlinerBase::doFinalization(CallGraph &CG) {
- if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No)
- ImportedFunctionsStats.dump(InlinerFunctionImportStats ==
- InlinerFunctionImportStatsOpts::Verbose);
- return removeDeadFunctions(CG);
-}
-
-/// Remove dead functions that are not included in DNR (Do Not Remove) list.
-bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG,
- bool AlwaysInlineOnly) {
- SmallVector<CallGraphNode *, 16> FunctionsToRemove;
- SmallVector<Function *, 16> DeadFunctionsInComdats;
-
- auto RemoveCGN = [&](CallGraphNode *CGN) {
- // Remove any call graph edges from the function to its callees.
- CGN->removeAllCalledFunctions();
-
- // Remove any edges from the external node to the function's call graph
- // node. These edges might have been made irrelevant due to
- // optimization of the program.
- CG.getExternalCallingNode()->removeAnyCallEdgeTo(CGN);
-
- // Removing the node for callee from the call graph and delete it.
- FunctionsToRemove.push_back(CGN);
- };
-
- // Scan for all of the functions, looking for ones that should now be removed
- // from the program. Insert the dead ones in the FunctionsToRemove set.
- for (const auto &I : CG) {
- CallGraphNode *CGN = I.second.get();
- Function *F = CGN->getFunction();
- if (!F || F->isDeclaration())
- continue;
-
- // Handle the case when this function is called and we only want to care
- // about always-inline functions. This is a bit of a hack to share code
- // between here and the InlineAlways pass.
- if (AlwaysInlineOnly && !F->hasFnAttribute(Attribute::AlwaysInline))
- continue;
-
- // If the only remaining users of the function are dead constants, remove
- // them.
- F->removeDeadConstantUsers();
-
- if (!F->isDefTriviallyDead())
- continue;
-
- // It is unsafe to drop a function with discardable linkage from a COMDAT
- // without also dropping the other members of the COMDAT.
- // The inliner doesn't visit non-function entities which are in COMDAT
- // groups so it is unsafe to do so *unless* the linkage is local.
- if (!F->hasLocalLinkage()) {
- if (F->hasComdat()) {
- DeadFunctionsInComdats.push_back(F);
- continue;
- }
- }
-
- RemoveCGN(CGN);
- }
- if (!DeadFunctionsInComdats.empty()) {
- // Filter out the functions whose comdats remain alive.
- filterDeadComdatFunctions(DeadFunctionsInComdats);
- // Remove the rest.
- for (Function *F : DeadFunctionsInComdats)
- RemoveCGN(CG[F]);
- }
-
- if (FunctionsToRemove.empty())
- return false;
-
- // Now that we know which functions to delete, do so. We didn't want to do
- // this inline, because that would invalidate our CallGraph::iterator
- // objects. :(
- //
- // Note that the iteration order is not stable here, but that is fine: it
- // does not matter in which order the functions are deleted.
- array_pod_sort(FunctionsToRemove.begin(), FunctionsToRemove.end());
- FunctionsToRemove.erase(
- std::unique(FunctionsToRemove.begin(), FunctionsToRemove.end()),
- FunctionsToRemove.end());
- for (CallGraphNode *CGN : FunctionsToRemove) {
- delete CG.removeFunctionFromModule(CGN);
- ++NumDeleted;
- }
- return true;
-}
-
InlineAdvisor &
InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM,
FunctionAnalysisManager &FAM, Module &M) {
@@ -729,8 +189,7 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM,
CGSCCInlineReplayFallback,
{CGSCCInlineReplayFormat}},
/*EmitRemarks=*/true,
- InlineContext{LTOPhase,
- InlinePass::ReplayCGSCCInliner});
+ InlineContext{LTOPhase, InlinePass::ReplayCGSCCInliner});
return *OwnedAdvisor;
}
@@ -871,9 +330,12 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
if (InlineHistoryID != -1 &&
inlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) {
- LLVM_DEBUG(dbgs() << "Skipping inlining due to history: "
- << F.getName() << " -> " << Callee.getName() << "\n");
+ LLVM_DEBUG(dbgs() << "Skipping inlining due to history: " << F.getName()
+ << " -> " << Callee.getName() << "\n");
setInlineRemark(*CB, "recursive");
+ // Set noinline so that we don't forget this decision across CGSCC
+ // iterations.
+ CB->setIsNoInline();
continue;
}
@@ -911,7 +373,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
// Setup the data structure used to plumb customization into the
// `InlineFunction` routine.
InlineFunctionInfo IFI(
- /*cg=*/nullptr, GetAssumptionCache, PSI,
+ GetAssumptionCache, PSI,
&FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())),
&FAM.getResult<BlockFrequencyAnalysis>(Callee));
@@ -1193,13 +655,13 @@ void ModuleInlinerWrapperPass::printPipeline(
// on Params and Mode).
if (!MPM.isEmpty()) {
MPM.printPipeline(OS, MapClassName2PassName);
- OS << ",";
+ OS << ',';
}
OS << "cgscc(";
if (MaxDevirtIterations != 0)
OS << "devirt<" << MaxDevirtIterations << ">(";
PM.printPipeline(OS, MapClassName2PassName);
if (MaxDevirtIterations != 0)
- OS << ")";
- OS << ")";
+ OS << ')';
+ OS << ')';
}
diff --git a/llvm/lib/Transforms/IPO/Internalize.cpp b/llvm/lib/Transforms/IPO/Internalize.cpp
index 85b1a8303d33..0b8fde6489f8 100644
--- a/llvm/lib/Transforms/IPO/Internalize.cpp
+++ b/llvm/lib/Transforms/IPO/Internalize.cpp
@@ -19,19 +19,18 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/Internalize.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/GlobPattern.h"
#include "llvm/Support/LineIterator.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO.h"
using namespace llvm;
@@ -183,9 +182,8 @@ void InternalizePass::checkComdat(
Info.External = true;
}
-bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) {
+bool InternalizePass::internalizeModule(Module &M) {
bool Changed = false;
- CallGraphNode *ExternalNode = CG ? CG->getExternalCallingNode() : nullptr;
SmallVector<GlobalValue *, 4> Used;
collectUsedGlobalVariables(M, Used, false);
@@ -242,10 +240,6 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) {
continue;
Changed = true;
- if (ExternalNode)
- // Remove a callgraph edge from the external node to this function.
- ExternalNode->removeOneAbstractEdgeTo((*CG)[&I]);
-
++NumFunctions;
LLVM_DEBUG(dbgs() << "Internalizing func " << I.getName() << "\n");
}
@@ -277,55 +271,8 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) {
InternalizePass::InternalizePass() : MustPreserveGV(PreserveAPIList()) {}
PreservedAnalyses InternalizePass::run(Module &M, ModuleAnalysisManager &AM) {
- if (!internalizeModule(M, AM.getCachedResult<CallGraphAnalysis>(M)))
+ if (!internalizeModule(M))
return PreservedAnalyses::all();
- PreservedAnalyses PA;
- PA.preserve<CallGraphAnalysis>();
- return PA;
-}
-
-namespace {
-class InternalizeLegacyPass : public ModulePass {
- // Client-supplied callback to control whether a symbol must be preserved.
- std::function<bool(const GlobalValue &)> MustPreserveGV;
-
-public:
- static char ID; // Pass identification, replacement for typeid
-
- InternalizeLegacyPass() : ModulePass(ID), MustPreserveGV(PreserveAPIList()) {}
-
- InternalizeLegacyPass(std::function<bool(const GlobalValue &)> MustPreserveGV)
- : ModulePass(ID), MustPreserveGV(std::move(MustPreserveGV)) {
- initializeInternalizeLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- CallGraphWrapperPass *CGPass =
- getAnalysisIfAvailable<CallGraphWrapperPass>();
- CallGraph *CG = CGPass ? &CGPass->getCallGraph() : nullptr;
- return internalizeModule(M, MustPreserveGV, CG);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addPreserved<CallGraphWrapperPass>();
- }
-};
-}
-
-char InternalizeLegacyPass::ID = 0;
-INITIALIZE_PASS(InternalizeLegacyPass, "internalize",
- "Internalize Global Symbols", false, false)
-
-ModulePass *llvm::createInternalizePass() {
- return new InternalizeLegacyPass();
-}
-
-ModulePass *llvm::createInternalizePass(
- std::function<bool(const GlobalValue &)> MustPreserveGV) {
- return new InternalizeLegacyPass(std::move(MustPreserveGV));
+ return PreservedAnalyses::none();
}
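
After this change the internalizer no longer consults or updates the call graph; behavior is controlled entirely by the MustPreserveGV predicate of the new-PM pass. A short sketch of scheduling it with a custom predicate (the preserved symbol is illustrative, and the surrounding pipeline setup is assumed to exist elsewhere):

  #include "llvm/IR/GlobalValue.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PassManager.h"
  #include "llvm/Transforms/IPO/Internalize.h"

  void scheduleInternalize(llvm::ModulePassManager &MPM) {
    // Keep "main" (and anything else the predicate approves) external; every
    // other definition gets local linkage when the pass runs.
    MPM.addPass(llvm::InternalizePass([](const llvm::GlobalValue &GV) {
      return GV.getName() == "main";
    }));
  }
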
diff --git a/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/llvm/lib/Transforms/IPO/LoopExtractor.cpp
index ad1927c09803..9a5876f85ba7 100644
--- a/llvm/lib/Transforms/IPO/LoopExtractor.cpp
+++ b/llvm/lib/Transforms/IPO/LoopExtractor.cpp
@@ -283,8 +283,8 @@ void LoopExtractorPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (NumLoops == 1)
OS << "single";
- OS << ">";
+ OS << '>';
}
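
This hunk and the ModuleInlinerWrapperPass::printPipeline hunk above make the same small change: single characters are written through raw_ostream's char overload rather than as one-character string literals. For context, a toy printPipeline in the same shape as LoopExtractorPass's; the pass and its option are hypothetical:

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/PassManager.h"
  #include "llvm/Support/raw_ostream.h"

  namespace {
  // Hypothetical pass used only to show the printPipeline shape.
  struct ToyExtractorPass : llvm::PassInfoMixin<ToyExtractorPass> {
    unsigned NumLoops = 1;

    llvm::PreservedAnalyses run(llvm::Module &, llvm::ModuleAnalysisManager &) {
      return llvm::PreservedAnalyses::all();
    }

    void printPipeline(llvm::raw_ostream &OS,
                       llvm::function_ref<llvm::StringRef(llvm::StringRef)>
                           MapClassName2PassName) {
      OS << MapClassName2PassName("ToyExtractorPass");
      OS << '<'; // single characters go through the char overload
      if (NumLoops == 1)
        OS << "single";
      OS << '>'; // e.g. "toy-extractor<single>" after name mapping
    }
  };
  } // namespace
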
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index ddfcace6acf8..9b4b3efd7283 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -24,7 +24,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TinyPtrVector.h"
-#include "llvm/ADT/Triple.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
@@ -51,12 +51,11 @@
#include "llvm/IR/ModuleSummaryIndexYAML.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
@@ -69,6 +68,7 @@
#include "llvm/Support/TrailingObjects.h"
#include "llvm/Support/YAMLTraits.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -172,7 +172,7 @@ BitSetInfo BitSetBuilder::build() {
BSI.AlignLog2 = 0;
if (Mask != 0)
- BSI.AlignLog2 = countTrailingZeros(Mask);
+ BSI.AlignLog2 = llvm::countr_zero(Mask);
// Build the compressed bitset while normalizing the offsets against the
// computed alignment.
@@ -242,7 +242,7 @@ bool lowertypetests::isJumpTableCanonical(Function *F) {
return false;
auto *CI = mdconst::extract_or_null<ConstantInt>(
F->getParent()->getModuleFlag("CFI Canonical Jump Tables"));
- if (!CI || CI->getZExtValue() != 0)
+ if (!CI || !CI->isZero())
return true;
return F->hasFnAttribute("cfi-canonical-jump-table");
}
@@ -406,6 +406,15 @@ class LowerTypeTestsModule {
Triple::OSType OS;
Triple::ObjectFormatType ObjectFormat;
+ // Determines which kind of Thumb jump table we generate. If arch is
+ // either 'arm' or 'thumb' we need to find this out, because
+ // selectJumpTableArmEncoding may decide to use Thumb in either case.
+ bool CanUseArmJumpTable = false, CanUseThumbBWJumpTable = false;
+
+ // The jump table type we ended up deciding on. (Usually the same as
+ // Arch, except that 'arm' and 'thumb' are often interchangeable.)
+ Triple::ArchType JumpTableArch = Triple::UnknownArch;
+
IntegerType *Int1Ty = Type::getInt1Ty(M.getContext());
IntegerType *Int8Ty = Type::getInt8Ty(M.getContext());
PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
@@ -481,6 +490,8 @@ class LowerTypeTestsModule {
void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds,
ArrayRef<GlobalTypeMember *> Globals);
+ Triple::ArchType
+ selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions);
unsigned getJumpTableEntrySize();
Type *getJumpTableEntryType();
void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS,
@@ -518,7 +529,8 @@ class LowerTypeTestsModule {
void replaceDirectCalls(Value *Old, Value *New);
public:
- LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary,
+ LowerTypeTestsModule(Module &M, ModuleAnalysisManager &AM,
+ ModuleSummaryIndex *ExportSummary,
const ModuleSummaryIndex *ImportSummary,
bool DropTypeTests);
@@ -526,7 +538,7 @@ public:
// Lower the module using the action and summary passed as command line
// arguments. For testing purposes only.
- static bool runForTesting(Module &M);
+ static bool runForTesting(Module &M, ModuleAnalysisManager &AM);
};
} // end anonymous namespace
@@ -686,7 +698,7 @@ static bool isKnownTypeIdMember(Metadata *TypeId, const DataLayout &DL,
}
if (auto GEP = dyn_cast<GEPOperator>(V)) {
- APInt APOffset(DL.getPointerSizeInBits(0), 0);
+ APInt APOffset(DL.getIndexSizeInBits(0), 0);
bool Result = GEP->accumulateConstantOffset(DL, APOffset);
if (!Result)
return false;
@@ -1182,31 +1194,36 @@ static const unsigned kX86JumpTableEntrySize = 8;
static const unsigned kX86IBTJumpTableEntrySize = 16;
static const unsigned kARMJumpTableEntrySize = 4;
static const unsigned kARMBTIJumpTableEntrySize = 8;
+static const unsigned kARMv6MJumpTableEntrySize = 16;
static const unsigned kRISCVJumpTableEntrySize = 8;
unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
- switch (Arch) {
- case Triple::x86:
- case Triple::x86_64:
- if (const auto *MD = mdconst::extract_or_null<ConstantInt>(
+ switch (JumpTableArch) {
+ case Triple::x86:
+ case Triple::x86_64:
+ if (const auto *MD = mdconst::extract_or_null<ConstantInt>(
M.getModuleFlag("cf-protection-branch")))
- if (MD->getZExtValue())
- return kX86IBTJumpTableEntrySize;
- return kX86JumpTableEntrySize;
- case Triple::arm:
- case Triple::thumb:
+ if (MD->getZExtValue())
+ return kX86IBTJumpTableEntrySize;
+ return kX86JumpTableEntrySize;
+ case Triple::arm:
+ return kARMJumpTableEntrySize;
+ case Triple::thumb:
+ if (CanUseThumbBWJumpTable)
return kARMJumpTableEntrySize;
- case Triple::aarch64:
- if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
+ else
+ return kARMv6MJumpTableEntrySize;
+ case Triple::aarch64:
+ if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
M.getModuleFlag("branch-target-enforcement")))
- if (BTE->getZExtValue())
- return kARMBTIJumpTableEntrySize;
- return kARMJumpTableEntrySize;
- case Triple::riscv32:
- case Triple::riscv64:
- return kRISCVJumpTableEntrySize;
- default:
- report_fatal_error("Unsupported architecture for jump tables");
+ if (BTE->getZExtValue())
+ return kARMBTIJumpTableEntrySize;
+ return kARMJumpTableEntrySize;
+ case Triple::riscv32:
+ case Triple::riscv64:
+ return kRISCVJumpTableEntrySize;
+ default:
+ report_fatal_error("Unsupported architecture for jump tables");
}
}
@@ -1223,7 +1240,7 @@ void LowerTypeTestsModule::createJumpTableEntry(
bool Endbr = false;
if (const auto *MD = mdconst::extract_or_null<ConstantInt>(
Dest->getParent()->getModuleFlag("cf-protection-branch")))
- Endbr = MD->getZExtValue() != 0;
+ Endbr = !MD->isZero();
if (Endbr)
AsmOS << (JumpTableArch == Triple::x86 ? "endbr32\n" : "endbr64\n");
AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n";
@@ -1240,7 +1257,32 @@ void LowerTypeTestsModule::createJumpTableEntry(
AsmOS << "bti c\n";
AsmOS << "b $" << ArgIndex << "\n";
} else if (JumpTableArch == Triple::thumb) {
- AsmOS << "b.w $" << ArgIndex << "\n";
+ if (!CanUseThumbBWJumpTable) {
+ // In Armv6-M, this sequence will generate a branch without corrupting
+ // any registers. We use two stack words; in the second, we construct the
+ // address we'll pop into pc, and the first is used to save and restore
+ // r0 which we use as a temporary register.
+ //
+ // To support position-independent use cases, the offset of the target
+ // function is stored as a relative offset (which will expand into an
+ // R_ARM_REL32 relocation in ELF, and presumably the equivalent in other
+ // object file types), and added to pc after we load it. (The alternative
+ // B.W is automatically pc-relative.)
+ //
+ // There are five 16-bit Thumb instructions here, so the .balign 4 adds a
+ // sixth halfword of padding, and then the offset consumes a further 4
+ // bytes, for a total of 16, which is very convenient since entries in
+ // this jump table need to have power-of-two size.
+ AsmOS << "push {r0,r1}\n"
+ << "ldr r0, 1f\n"
+ << "0: add r0, r0, pc\n"
+ << "str r0, [sp, #4]\n"
+ << "pop {r0,pc}\n"
+ << ".balign 4\n"
+ << "1: .word $" << ArgIndex << " - (0b + 4)\n";
+ } else {
+ AsmOS << "b.w $" << ArgIndex << "\n";
+ }
} else if (JumpTableArch == Triple::riscv32 ||
JumpTableArch == Triple::riscv64) {
AsmOS << "tail $" << ArgIndex << "@plt\n";
@@ -1325,11 +1367,27 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr(
F->getAddressSpace(), "", &M);
replaceCfiUses(F, PlaceholderFn, IsJumpTableCanonical);
- Constant *Target = ConstantExpr::getSelect(
- ConstantExpr::getICmp(CmpInst::ICMP_NE, F,
- Constant::getNullValue(F->getType())),
- JT, Constant::getNullValue(F->getType()));
- PlaceholderFn->replaceAllUsesWith(Target);
+ convertUsersOfConstantsToInstructions(PlaceholderFn);
+ // Don't use range based loop, because use list will be modified.
+ while (!PlaceholderFn->use_empty()) {
+ Use &U = *PlaceholderFn->use_begin();
+ auto *InsertPt = dyn_cast<Instruction>(U.getUser());
+ assert(InsertPt && "Non-instruction users should have been eliminated");
+ auto *PN = dyn_cast<PHINode>(InsertPt);
+ if (PN)
+ InsertPt = PN->getIncomingBlock(U)->getTerminator();
+ IRBuilder Builder(InsertPt);
+ Value *ICmp = Builder.CreateICmp(CmpInst::ICMP_NE, F,
+ Constant::getNullValue(F->getType()));
+ Value *Select = Builder.CreateSelect(ICmp, JT,
+ Constant::getNullValue(F->getType()));
+ // For phi nodes, we need to update the incoming value for all operands
+ // with the same predecessor.
+ if (PN)
+ PN->setIncomingValueForBlock(InsertPt->getParent(), Select);
+ else
+ U.set(Select);
+ }
PlaceholderFn->eraseFromParent();
}
@@ -1352,12 +1410,19 @@ static bool isThumbFunction(Function *F, Triple::ArchType ModuleArch) {
// Each jump table must be either ARM or Thumb as a whole for the bit-test math
// to work. Pick one that matches the majority of members to minimize interop
// veneers inserted by the linker.
-static Triple::ArchType
-selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions,
- Triple::ArchType ModuleArch) {
- if (ModuleArch != Triple::arm && ModuleArch != Triple::thumb)
- return ModuleArch;
+Triple::ArchType LowerTypeTestsModule::selectJumpTableArmEncoding(
+ ArrayRef<GlobalTypeMember *> Functions) {
+ if (Arch != Triple::arm && Arch != Triple::thumb)
+ return Arch;
+
+ if (!CanUseThumbBWJumpTable && CanUseArmJumpTable) {
+ // In architectures that provide Arm and Thumb-1 but not Thumb-2,
+ // we should always prefer the Arm jump table format, because the
+ // Thumb-1 one is larger and slower.
+ return Triple::arm;
+ }
+ // Otherwise, go with majority vote.
unsigned ArmCount = 0, ThumbCount = 0;
for (const auto GTM : Functions) {
if (!GTM->isJumpTableCanonical()) {
@@ -1368,7 +1433,7 @@ selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions,
}
Function *F = cast<Function>(GTM->getGlobal());
- ++(isThumbFunction(F, ModuleArch) ? ThumbCount : ArmCount);
+ ++(isThumbFunction(F, Arch) ? ThumbCount : ArmCount);
}
return ArmCount > ThumbCount ? Triple::arm : Triple::thumb;
@@ -1381,8 +1446,6 @@ void LowerTypeTestsModule::createJumpTable(
SmallVector<Value *, 16> AsmArgs;
AsmArgs.reserve(Functions.size() * 2);
- Triple::ArchType JumpTableArch = selectJumpTableArmEncoding(Functions, Arch);
-
for (GlobalTypeMember *GTM : Functions)
createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs,
cast<Function>(GTM->getGlobal()));
@@ -1399,9 +1462,11 @@ void LowerTypeTestsModule::createJumpTable(
F->addFnAttr("target-features", "-thumb-mode");
if (JumpTableArch == Triple::thumb) {
F->addFnAttr("target-features", "+thumb-mode");
- // Thumb jump table assembly needs Thumb2. The following attribute is added
- // by Clang for -march=armv7.
- F->addFnAttr("target-cpu", "cortex-a8");
+ if (CanUseThumbBWJumpTable) {
+ // Thumb jump table assembly needs Thumb2. The following attribute is
+ // added by Clang for -march=armv7.
+ F->addFnAttr("target-cpu", "cortex-a8");
+ }
}
// When -mbranch-protection= is used, the inline asm adds a BTI. Suppress BTI
// for the function to avoid double BTI. This is a no-op without
@@ -1521,6 +1586,10 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
// FIXME: find a better way to represent the jumptable in the IR.
assert(!Functions.empty());
+ // Decide on the jump table encoding, so that we know how big the
+ // entries will be.
+ JumpTableArch = selectJumpTableArmEncoding(Functions);
+
// Build a simple layout based on the regular layout of jump tables.
DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
unsigned EntrySize = getJumpTableEntrySize();
@@ -1706,18 +1775,31 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
/// Lower all type tests in this module.
LowerTypeTestsModule::LowerTypeTestsModule(
- Module &M, ModuleSummaryIndex *ExportSummary,
+ Module &M, ModuleAnalysisManager &AM, ModuleSummaryIndex *ExportSummary,
const ModuleSummaryIndex *ImportSummary, bool DropTypeTests)
: M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary),
DropTypeTests(DropTypeTests || ClDropTypeTests) {
assert(!(ExportSummary && ImportSummary));
Triple TargetTriple(M.getTargetTriple());
Arch = TargetTriple.getArch();
+ if (Arch == Triple::arm)
+ CanUseArmJumpTable = true;
+ if (Arch == Triple::arm || Arch == Triple::thumb) {
+ auto &FAM =
+ AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ for (Function &F : M) {
+ auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
+ if (TTI.hasArmWideBranch(false))
+ CanUseArmJumpTable = true;
+ if (TTI.hasArmWideBranch(true))
+ CanUseThumbBWJumpTable = true;
+ }
+ }
OS = TargetTriple.getOS();
ObjectFormat = TargetTriple.getObjectFormat();
}
-bool LowerTypeTestsModule::runForTesting(Module &M) {
+bool LowerTypeTestsModule::runForTesting(Module &M, ModuleAnalysisManager &AM) {
ModuleSummaryIndex Summary(/*HaveGVs=*/false);
// Handle the command-line summary arguments. This code is for testing
@@ -1735,7 +1817,8 @@ bool LowerTypeTestsModule::runForTesting(Module &M) {
bool Changed =
LowerTypeTestsModule(
- M, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
+ M, AM,
+ ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr,
/*DropTypeTests*/ false)
.lower();
@@ -2186,9 +2269,9 @@ bool LowerTypeTestsModule::lower() {
unsigned MaxUniqueId = 0;
for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I);
MI != GlobalClasses.member_end(); ++MI) {
- if (auto *MD = MI->dyn_cast<Metadata *>())
+ if (auto *MD = dyn_cast_if_present<Metadata *>(*MI))
MaxUniqueId = std::max(MaxUniqueId, TypeIdInfo[MD].UniqueId);
- else if (auto *BF = MI->dyn_cast<ICallBranchFunnel *>())
+ else if (auto *BF = dyn_cast_if_present<ICallBranchFunnel *>(*MI))
MaxUniqueId = std::max(MaxUniqueId, BF->UniqueId);
}
Sets.emplace_back(I, MaxUniqueId);
@@ -2204,12 +2287,12 @@ bool LowerTypeTestsModule::lower() {
for (GlobalClassesTy::member_iterator MI =
GlobalClasses.member_begin(S.first);
MI != GlobalClasses.member_end(); ++MI) {
- if (MI->is<Metadata *>())
- TypeIds.push_back(MI->get<Metadata *>());
- else if (MI->is<GlobalTypeMember *>())
- Globals.push_back(MI->get<GlobalTypeMember *>());
+ if (isa<Metadata *>(*MI))
+ TypeIds.push_back(cast<Metadata *>(*MI));
+ else if (isa<GlobalTypeMember *>(*MI))
+ Globals.push_back(cast<GlobalTypeMember *>(*MI));
else
- ICallBranchFunnels.push_back(MI->get<ICallBranchFunnel *>());
+ ICallBranchFunnels.push_back(cast<ICallBranchFunnel *>(*MI));
}
// Order type identifiers by unique ID for determinism. This ordering is
@@ -2298,10 +2381,10 @@ PreservedAnalyses LowerTypeTestsPass::run(Module &M,
ModuleAnalysisManager &AM) {
bool Changed;
if (UseCommandLine)
- Changed = LowerTypeTestsModule::runForTesting(M);
+ Changed = LowerTypeTestsModule::runForTesting(M, AM);
else
Changed =
- LowerTypeTestsModule(M, ExportSummary, ImportSummary, DropTypeTests)
+ LowerTypeTestsModule(M, AM, ExportSummary, ImportSummary, DropTypeTests)
.lower();
if (!Changed)
return PreservedAnalyses::all();
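
One small change above computes AlignLog2 with llvm::countr_zero(Mask) instead of countTrailingZeros(Mask): the OR of all member offsets has exactly as many trailing zero bits as the largest power-of-two alignment shared by every offset. A standalone illustration in plain C++20 with toy offsets:

  #include <bit>
  #include <cstdint>
  #include <cstdio>

  int main() {
    const uint64_t Offsets[] = {0, 8, 24, 40};
    uint64_t Mask = 0;
    for (uint64_t O : Offsets)
      Mask |= O;

    // All offsets are multiples of 8, so the shared alignment is 1 << 3.
    unsigned AlignLog2 = Mask == 0 ? 0 : std::countr_zero(Mask);
    std::printf("AlignLog2 = %u\n", AlignLog2); // prints "AlignLog2 = 3"
    return 0;
  }
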
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
new file mode 100644
index 000000000000..f835fb26fcb8
--- /dev/null
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -0,0 +1,3277 @@
+//==-- MemProfContextDisambiguation.cpp - Disambiguate contexts -------------=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements support for context disambiguation of allocation
+// calls for profile guided heap optimization. Specifically, it uses Memprof
+// profiles which indicate context specific allocation behavior (currently
+// distinguishing cold vs hot memory allocations). Cloning is performed to
+// expose the cold allocation call contexts, and the allocation calls are
+// subsequently annotated with an attribute for later transformation.
+//
+// The transformations can be performed either directly on IR (regular LTO), or
+// on a ThinLTO index (and later applied to the IR during the ThinLTO backend).
+// Both types of LTO operate on the same base graph representation, which
+// uses CRTP to support either IR or Index formats.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/MemProfContextDisambiguation.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/MemoryProfileInfo.h"
+#include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/ModuleSummaryIndex.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include <sstream>
+#include <vector>
+using namespace llvm;
+using namespace llvm::memprof;
+
+#define DEBUG_TYPE "memprof-context-disambiguation"
+
+STATISTIC(FunctionClonesAnalysis,
+ "Number of function clones created during whole program analysis");
+STATISTIC(FunctionClonesThinBackend,
+ "Number of function clones created during ThinLTO backend");
+STATISTIC(FunctionsClonedThinBackend,
+ "Number of functions that had clones created during ThinLTO backend");
+STATISTIC(AllocTypeNotCold, "Number of not cold static allocations (possibly "
+ "cloned) during whole program analysis");
+STATISTIC(AllocTypeCold, "Number of cold static allocations (possibly cloned) "
+ "during whole program analysis");
+STATISTIC(AllocTypeNotColdThinBackend,
+ "Number of not cold static allocations (possibly cloned) during "
+ "ThinLTO backend");
+STATISTIC(AllocTypeColdThinBackend, "Number of cold static allocations "
+ "(possibly cloned) during ThinLTO backend");
+STATISTIC(OrigAllocsThinBackend,
+ "Number of original (not cloned) allocations with memprof profiles "
+ "during ThinLTO backend");
+STATISTIC(
+ AllocVersionsThinBackend,
+ "Number of allocation versions (including clones) during ThinLTO backend");
+STATISTIC(MaxAllocVersionsThinBackend,
+ "Maximum number of allocation versions created for an original "
+ "allocation during ThinLTO backend");
+STATISTIC(UnclonableAllocsThinBackend,
+ "Number of unclonable ambiguous allocations during ThinLTO backend");
+
+static cl::opt<std::string> DotFilePathPrefix(
+ "memprof-dot-file-path-prefix", cl::init(""), cl::Hidden,
+ cl::value_desc("filename"),
+ cl::desc("Specify the path prefix of the MemProf dot files."));
+
+static cl::opt<bool> ExportToDot("memprof-export-to-dot", cl::init(false),
+ cl::Hidden,
+ cl::desc("Export graph to dot files."));
+
+static cl::opt<bool>
+ DumpCCG("memprof-dump-ccg", cl::init(false), cl::Hidden,
+ cl::desc("Dump CallingContextGraph to stdout after each stage."));
+
+static cl::opt<bool>
+ VerifyCCG("memprof-verify-ccg", cl::init(false), cl::Hidden,
+ cl::desc("Perform verification checks on CallingContextGraph."));
+
+static cl::opt<bool>
+ VerifyNodes("memprof-verify-nodes", cl::init(false), cl::Hidden,
+ cl::desc("Perform frequent verification checks on nodes."));
+
+static cl::opt<std::string> MemProfImportSummary(
+ "memprof-import-summary",
+ cl::desc("Import summary to use for testing the ThinLTO backend via opt"),
+ cl::Hidden);
+
+// Indicate we are linking with an allocator that supports hot/cold operator
+// new interfaces.
+cl::opt<bool> SupportsHotColdNew(
+ "supports-hot-cold-new", cl::init(false), cl::Hidden,
+ cl::desc("Linking with hot/cold operator new interfaces"));
+
+namespace {
+/// CRTP base for graphs built from either IR or ThinLTO summary index.
+///
+/// The graph represents the call contexts in all memprof metadata on allocation
+/// calls, with nodes for the allocations themselves, as well as for the calls
+/// in each context. The graph is initially built from the allocation memprof
+/// metadata (or summary) MIBs. It is then updated to match calls with callsite
+/// metadata onto the nodes, updating it to reflect any inlining performed on
+/// those calls.
+///
+/// Each MIB (representing an allocation's call context with allocation
+/// behavior) is assigned a unique context id during the graph build. The edges
+/// and nodes in the graph are decorated with the context ids they carry. This
+/// is used to correctly update the graph when cloning is performed so that we
+/// can uniquify the context for a single (possibly cloned) allocation.
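+///
+/// As an illustrative sketch (functions A-D are hypothetical, not from any
+/// real profile): if an allocation in C is reached via contexts A->B->C and
+/// D->B->C, the two MIBs get distinct context ids, say 1 and 2. The B->C edge
+/// then carries {1, 2} while the A->B and D->B edges carry {1} and {2}
+/// respectively, which is what later allows cloning B and C to separate a
+/// cold context from a not-cold one.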
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+class CallsiteContextGraph {
+public:
+ CallsiteContextGraph() = default;
+ CallsiteContextGraph(const CallsiteContextGraph &) = default;
+ CallsiteContextGraph(CallsiteContextGraph &&) = default;
+
+ /// Main entry point to perform analysis and transformations on graph.
+ bool process();
+
+ /// Perform cloning on the graph necessary to uniquely identify the allocation
+ /// behavior of an allocation based on its context.
+ void identifyClones();
+
+ /// Assign callsite clones to functions, cloning functions as needed to
+ /// accommodate the combinations of their callsite clones reached by callers.
+ /// For regular LTO this clones functions and callsites in the IR, but for
+ /// ThinLTO the cloning decisions are noted in the summaries and later applied
+ /// in applyImport.
+ bool assignFunctions();
+
+ void dump() const;
+ void print(raw_ostream &OS) const;
+
+ friend raw_ostream &operator<<(raw_ostream &OS,
+ const CallsiteContextGraph &CCG) {
+ CCG.print(OS);
+ return OS;
+ }
+
+ friend struct GraphTraits<
+ const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *>;
+ friend struct DOTGraphTraits<
+ const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *>;
+
+ void exportToDot(std::string Label) const;
+
+ /// Represents a function clone via FuncTy pointer and clone number pair.
+ struct FuncInfo final
+ : public std::pair<FuncTy *, unsigned /*Clone number*/> {
+ using Base = std::pair<FuncTy *, unsigned>;
+ FuncInfo(const Base &B) : Base(B) {}
+ FuncInfo(FuncTy *F = nullptr, unsigned CloneNo = 0) : Base(F, CloneNo) {}
+ explicit operator bool() const { return this->first != nullptr; }
+ FuncTy *func() const { return this->first; }
+ unsigned cloneNo() const { return this->second; }
+ };
+
+ /// Represents a callsite clone via CallTy and clone number pair.
+ struct CallInfo final : public std::pair<CallTy, unsigned /*Clone number*/> {
+ using Base = std::pair<CallTy, unsigned>;
+ CallInfo(const Base &B) : Base(B) {}
+ CallInfo(CallTy Call = nullptr, unsigned CloneNo = 0)
+ : Base(Call, CloneNo) {}
+ explicit operator bool() const { return (bool)this->first; }
+ CallTy call() const { return this->first; }
+ unsigned cloneNo() const { return this->second; }
+ void setCloneNo(unsigned N) { this->second = N; }
+ void print(raw_ostream &OS) const {
+ if (!operator bool()) {
+ assert(!cloneNo());
+ OS << "null Call";
+ return;
+ }
+ call()->print(OS);
+ OS << "\t(clone " << cloneNo() << ")";
+ }
+ void dump() const {
+ print(dbgs());
+ dbgs() << "\n";
+ }
+ friend raw_ostream &operator<<(raw_ostream &OS, const CallInfo &Call) {
+ Call.print(OS);
+ return OS;
+ }
+ };
+
+ struct ContextEdge;
+
+ /// Node in the Callsite Context Graph
+ struct ContextNode {
+ // Keep this for now since in the IR case, where we have an Instruction*,
+ // whether the call is an allocation is not as immediately discoverable.
+ // Used for printing richer information when dumping the graph.
+ bool IsAllocation;
+
+ // Keeps track of when the Call was reset to null because there was
+ // recursion.
+ bool Recursive = false;
+
+ // The corresponding allocation or interior call.
+ CallInfo Call;
+
+ // For alloc nodes this is a unique id assigned when constructed, and for
+ // callsite stack nodes it is the original stack id when the node is
+ // constructed from the memprof MIB metadata on the alloc nodes. Note that
+ // this is only used when matching callsite metadata onto the stack nodes
+ // created when processing the allocation memprof MIBs, and for labeling
+ // nodes in the dot graph. Therefore we don't bother to assign a value for
+ // clones.
+ uint64_t OrigStackOrAllocId = 0;
+
+ // This will be formed by ORing together the AllocationType enum values
+ // for contexts including this node.
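+ // E.g. a node included in both NotCold and Cold contexts will have
+ // AllocTypes == (NotCold | Cold).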
+ uint8_t AllocTypes = 0;
+
+ // Edges to all callees in the profiled call stacks.
+ // TODO: Should this be a map (from Callee node) for more efficient lookup?
+ std::vector<std::shared_ptr<ContextEdge>> CalleeEdges;
+
+ // Edges to all callers in the profiled call stacks.
+ // TODO: Should this be a map (from Caller node) for more efficient lookup?
+ std::vector<std::shared_ptr<ContextEdge>> CallerEdges;
+
+ // The set of IDs for contexts including this node.
+ DenseSet<uint32_t> ContextIds;
+
+ // List of clones of this ContextNode, initially empty.
+ std::vector<ContextNode *> Clones;
+
+ // If a clone, points to the original uncloned node.
+ ContextNode *CloneOf = nullptr;
+
+ ContextNode(bool IsAllocation) : IsAllocation(IsAllocation), Call() {}
+
+ ContextNode(bool IsAllocation, CallInfo C)
+ : IsAllocation(IsAllocation), Call(C) {}
+
+ void addClone(ContextNode *Clone) {
+ if (CloneOf) {
+ CloneOf->Clones.push_back(Clone);
+ Clone->CloneOf = CloneOf;
+ } else {
+ Clones.push_back(Clone);
+ assert(!Clone->CloneOf);
+ Clone->CloneOf = this;
+ }
+ }
+
+ ContextNode *getOrigNode() {
+ if (!CloneOf)
+ return this;
+ return CloneOf;
+ }
+
+ void addOrUpdateCallerEdge(ContextNode *Caller, AllocationType AllocType,
+ unsigned int ContextId);
+
+ ContextEdge *findEdgeFromCallee(const ContextNode *Callee);
+ ContextEdge *findEdgeFromCaller(const ContextNode *Caller);
+ void eraseCalleeEdge(const ContextEdge *Edge);
+ void eraseCallerEdge(const ContextEdge *Edge);
+
+ void setCall(CallInfo C) { Call = C; }
+
+ bool hasCall() const { return (bool)Call.call(); }
+
+ void printCall(raw_ostream &OS) const { Call.print(OS); }
+
+ // True if this node was effectively removed from the graph, in which case
+ // its context id set, caller edges, and callee edges should all be empty.
+ bool isRemoved() const {
+ assert(ContextIds.empty() ==
+ (CalleeEdges.empty() && CallerEdges.empty()));
+ return ContextIds.empty();
+ }
+
+ void dump() const;
+ void print(raw_ostream &OS) const;
+
+ friend raw_ostream &operator<<(raw_ostream &OS, const ContextNode &Node) {
+ Node.print(OS);
+ return OS;
+ }
+ };
+
+ /// Edge in the Callsite Context Graph from a ContextNode N to a caller or
+ /// callee.
+ struct ContextEdge {
+ ContextNode *Callee;
+ ContextNode *Caller;
+
+ // This will be formed by ORing together the AllocationType enum values
+ // for contexts including this edge.
+ uint8_t AllocTypes = 0;
+
+ // The set of IDs for contexts including this edge.
+ DenseSet<uint32_t> ContextIds;
+
+ ContextEdge(ContextNode *Callee, ContextNode *Caller, uint8_t AllocType,
+ DenseSet<uint32_t> ContextIds)
+ : Callee(Callee), Caller(Caller), AllocTypes(AllocType),
+ ContextIds(ContextIds) {}
+
+ DenseSet<uint32_t> &getContextIds() { return ContextIds; }
+
+ void dump() const;
+ void print(raw_ostream &OS) const;
+
+ friend raw_ostream &operator<<(raw_ostream &OS, const ContextEdge &Edge) {
+ Edge.print(OS);
+ return OS;
+ }
+ };
+
+ /// Helper to remove callee edges that have allocation type None (due to not
+ /// carrying any context ids) after transformations.
+ void removeNoneTypeCalleeEdges(ContextNode *Node);
+
+protected:
+ /// Get a list of nodes corresponding to the stack ids in the given callsite
+ /// context.
+ template <class NodeT, class IteratorT>
+ std::vector<uint64_t>
+ getStackIdsWithContextNodes(CallStack<NodeT, IteratorT> &CallsiteContext);
+
+ /// Adds nodes for the given allocation and any stack ids on its memprof MIB
+ /// metadata (or summary).
+ ContextNode *addAllocNode(CallInfo Call, const FuncTy *F);
+
+ /// Adds nodes for the given MIB stack ids.
+ template <class NodeT, class IteratorT>
+ void addStackNodesForMIB(ContextNode *AllocNode,
+ CallStack<NodeT, IteratorT> &StackContext,
+ CallStack<NodeT, IteratorT> &CallsiteContext,
+ AllocationType AllocType);
+
+ /// Matches all callsite metadata (or summary) to the nodes created for
+ /// allocation memprof MIB metadata, synthesizing new nodes to reflect any
+ /// inlining performed on those callsite instructions.
+ void updateStackNodes();
+
+ /// Update the graph to conservatively handle any callsite stack nodes that
+ /// target multiple different callee functions.
+ void handleCallsitesWithMultipleTargets();
+
+ /// Save lists of calls with MemProf metadata in each function, for faster
+ /// iteration.
+ std::vector<std::pair<FuncTy *, std::vector<CallInfo>>>
+ FuncToCallsWithMetadata;
+
+ /// Map from callsite node to the enclosing caller function.
+ std::map<const ContextNode *, const FuncTy *> NodeToCallingFunc;
+
+private:
+ using EdgeIter = typename std::vector<std::shared_ptr<ContextEdge>>::iterator;
+
+ using CallContextInfo = std::tuple<CallTy, std::vector<uint64_t>,
+ const FuncTy *, DenseSet<uint32_t>>;
+
+ /// Assigns the given Node to calls at or inlined into the location with
+ /// the Node's stack id, after post order traversing and processing its
+ /// caller nodes. Uses the call information recorded in the given
+ /// StackIdToMatchingCalls map, and creates new nodes for inlined sequences
+ /// as needed. Called by updateStackNodes which sets up the given
+ /// StackIdToMatchingCalls map.
+ void assignStackNodesPostOrder(
+ ContextNode *Node, DenseSet<const ContextNode *> &Visited,
+ DenseMap<uint64_t, std::vector<CallContextInfo>> &StackIdToMatchingCalls);
+
+ /// Duplicates the given set of context ids, updating the provided
+ /// map from each original id with the newly generated context ids,
+ /// and returning the new duplicated id set.
+ DenseSet<uint32_t> duplicateContextIds(
+ const DenseSet<uint32_t> &StackSequenceContextIds,
+ DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds);
+
+ /// Propagates all duplicated context ids across the graph.
+ void propagateDuplicateContextIds(
+ const DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds);
+
+ /// Connect the NewNode to OrigNode's callees if TowardsCallee is true,
+ /// else to its callers. Also updates OrigNode's edges to remove any context
+ /// ids moved to the newly created edge.
+ void connectNewNode(ContextNode *NewNode, ContextNode *OrigNode,
+ bool TowardsCallee);
+
+ /// Get the stack id corresponding to the given Id or Index (for IR this will
+ /// return itself, for a summary index this will return the id recorded in the
+ /// index for that stack id index value).
+ uint64_t getStackId(uint64_t IdOrIndex) const {
+ return static_cast<const DerivedCCG *>(this)->getStackId(IdOrIndex);
+ }
+
+ /// Returns true if the given call targets the given function.
+ bool calleeMatchesFunc(CallTy Call, const FuncTy *Func) {
+ return static_cast<DerivedCCG *>(this)->calleeMatchesFunc(Call, Func);
+ }
+
+ /// Get a list of nodes corresponding to the stack ids in the given
+ /// callsite's context.
+ std::vector<uint64_t> getStackIdsWithContextNodesForCall(CallTy Call) {
+ return static_cast<DerivedCCG *>(this)->getStackIdsWithContextNodesForCall(
+ Call);
+ }
+
+ /// Get the last stack id in the context for callsite.
+ uint64_t getLastStackId(CallTy Call) {
+ return static_cast<DerivedCCG *>(this)->getLastStackId(Call);
+ }
+
+ /// Update the allocation call to record type of allocated memory.
+ void updateAllocationCall(CallInfo &Call, AllocationType AllocType) {
+ AllocType == AllocationType::Cold ? AllocTypeCold++ : AllocTypeNotCold++;
+ static_cast<DerivedCCG *>(this)->updateAllocationCall(Call, AllocType);
+ }
+
+ /// Update non-allocation call to invoke (possibly cloned) function
+ /// CalleeFunc.
+ void updateCall(CallInfo &CallerCall, FuncInfo CalleeFunc) {
+ static_cast<DerivedCCG *>(this)->updateCall(CallerCall, CalleeFunc);
+ }
+
+ /// Clone the given function for the given callsite, recording mapping of all
+ /// of the functions tracked calls to their new versions in the CallMap.
+ /// Assigns new clones to clone number CloneNo.
+ FuncInfo cloneFunctionForCallsite(
+ FuncInfo &Func, CallInfo &Call, std::map<CallInfo, CallInfo> &CallMap,
+ std::vector<CallInfo> &CallsWithMetadataInFunc, unsigned CloneNo) {
+ return static_cast<DerivedCCG *>(this)->cloneFunctionForCallsite(
+ Func, Call, CallMap, CallsWithMetadataInFunc, CloneNo);
+ }
+
+ /// Gets a label to use in the dot graph for the given call clone in the given
+ /// function.
+ std::string getLabel(const FuncTy *Func, const CallTy Call,
+ unsigned CloneNo) const {
+ return static_cast<const DerivedCCG *>(this)->getLabel(Func, Call, CloneNo);
+ }
+
+ /// Helpers to find the node corresponding to the given call or stackid.
+ ContextNode *getNodeForInst(const CallInfo &C);
+ ContextNode *getNodeForAlloc(const CallInfo &C);
+ ContextNode *getNodeForStackId(uint64_t StackId);
+
+ /// Removes the node information recorded for the given call.
+ void unsetNodeForInst(const CallInfo &C);
+
+ /// Computes the alloc type corresponding to the given context ids, by
+ /// unioning their recorded alloc types.
+ uint8_t computeAllocType(DenseSet<uint32_t> &ContextIds);
+
+ /// Returns the allocation type of the intersection of the contexts of two
+ /// nodes (based on their provided context id sets), optimized for the case
+ /// when Node1Ids is smaller than Node2Ids.
+ uint8_t intersectAllocTypesImpl(const DenseSet<uint32_t> &Node1Ids,
+ const DenseSet<uint32_t> &Node2Ids);
+
+ /// Returns the allocation type of the intersection of the contexts of two
+ /// nodes (based on their provided context id sets).
+ uint8_t intersectAllocTypes(const DenseSet<uint32_t> &Node1Ids,
+ const DenseSet<uint32_t> &Node2Ids);
+
+ /// Create a clone of Edge's callee and move Edge to that new callee node,
+ /// performing the necessary context id and allocation type updates.
+ /// If callee's caller edge iterator is supplied, it is updated when removing
+ /// the edge from that list.
+ ContextNode *
+ moveEdgeToNewCalleeClone(const std::shared_ptr<ContextEdge> &Edge,
+ EdgeIter *CallerEdgeI = nullptr);
+
+ /// Change the callee of Edge to existing callee clone NewCallee, performing
+ /// the necessary context id and allocation type updates.
+ /// If callee's caller edge iterator is supplied, it is updated when removing
+ /// the edge from that list.
+ void moveEdgeToExistingCalleeClone(const std::shared_ptr<ContextEdge> &Edge,
+ ContextNode *NewCallee,
+ EdgeIter *CallerEdgeI = nullptr,
+ bool NewClone = false);
+
+ /// Recursively perform cloning on the graph for the given Node and its
+ /// callers, in order to uniquely identify the allocation behavior of an
+ /// allocation given its context.
+ void identifyClones(ContextNode *Node,
+ DenseSet<const ContextNode *> &Visited);
+
+ /// Map from each context ID to the AllocationType assigned to that context.
+ std::map<uint32_t, AllocationType> ContextIdToAllocationType;
+
+ /// Identifies the context node created for a stack id when adding the MIB
+ /// contexts to the graph. This is used to locate the context nodes when
+ /// trying to assign the corresponding callsites with those stack ids to these
+ /// nodes.
+ std::map<uint64_t, ContextNode *> StackEntryIdToContextNodeMap;
+
+ /// Maps to track the calls to their corresponding nodes in the graph.
+ MapVector<CallInfo, ContextNode *> AllocationCallToContextNodeMap;
+ MapVector<CallInfo, ContextNode *> NonAllocationCallToContextNodeMap;
+
+ /// Owner of all ContextNode unique_ptrs.
+ std::vector<std::unique_ptr<ContextNode>> NodeOwner;
+
+ /// Perform sanity checks on graph when requested.
+ void check() const;
+
+ /// Keeps track of the last unique context id assigned.
+ unsigned int LastContextId = 0;
+};
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+using ContextNode =
+ typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode;
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+using ContextEdge =
+ typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge;
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+using FuncInfo =
+ typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::FuncInfo;
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+using CallInfo =
+ typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::CallInfo;
+
+/// CRTP derived class for graphs built from IR (regular LTO).
+class ModuleCallsiteContextGraph
+ : public CallsiteContextGraph<ModuleCallsiteContextGraph, Function,
+ Instruction *> {
+public:
+ ModuleCallsiteContextGraph(
+ Module &M,
+ function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter);
+
+private:
+ friend CallsiteContextGraph<ModuleCallsiteContextGraph, Function,
+ Instruction *>;
+
+ uint64_t getStackId(uint64_t IdOrIndex) const;
+ bool calleeMatchesFunc(Instruction *Call, const Function *Func);
+ uint64_t getLastStackId(Instruction *Call);
+ std::vector<uint64_t> getStackIdsWithContextNodesForCall(Instruction *Call);
+ void updateAllocationCall(CallInfo &Call, AllocationType AllocType);
+ void updateCall(CallInfo &CallerCall, FuncInfo CalleeFunc);
+ CallsiteContextGraph<ModuleCallsiteContextGraph, Function,
+ Instruction *>::FuncInfo
+ cloneFunctionForCallsite(FuncInfo &Func, CallInfo &Call,
+ std::map<CallInfo, CallInfo> &CallMap,
+ std::vector<CallInfo> &CallsWithMetadataInFunc,
+ unsigned CloneNo);
+ std::string getLabel(const Function *Func, const Instruction *Call,
+ unsigned CloneNo) const;
+
+ const Module &Mod;
+ function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;
+};
+
+/// Represents a call in the summary index graph, which can either be an
+/// allocation or an interior callsite node in an allocation's context.
+/// Holds a pointer to the corresponding data structure in the index.
+struct IndexCall : public PointerUnion<CallsiteInfo *, AllocInfo *> {
+ IndexCall() : PointerUnion() {}
+ IndexCall(std::nullptr_t) : IndexCall() {}
+ IndexCall(CallsiteInfo *StackNode) : PointerUnion(StackNode) {}
+ IndexCall(AllocInfo *AllocNode) : PointerUnion(AllocNode) {}
+ IndexCall(PointerUnion PT) : PointerUnion(PT) {}
+
+ IndexCall *operator->() { return this; }
+
+ PointerUnion<CallsiteInfo *, AllocInfo *> getBase() const { return *this; }
+
+ void print(raw_ostream &OS) const {
+ if (auto *AI = llvm::dyn_cast_if_present<AllocInfo *>(getBase())) {
+ OS << *AI;
+ } else {
+ auto *CI = llvm::dyn_cast_if_present<CallsiteInfo *>(getBase());
+ assert(CI);
+ OS << *CI;
+ }
+ }
+};
+
+/// CRTP derived class for graphs built from summary index (ThinLTO).
+class IndexCallsiteContextGraph
+ : public CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
+ IndexCall> {
+public:
+ IndexCallsiteContextGraph(
+ ModuleSummaryIndex &Index,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing);
+
+private:
+ friend CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
+ IndexCall>;
+
+ uint64_t getStackId(uint64_t IdOrIndex) const;
+ bool calleeMatchesFunc(IndexCall &Call, const FunctionSummary *Func);
+ uint64_t getLastStackId(IndexCall &Call);
+ std::vector<uint64_t> getStackIdsWithContextNodesForCall(IndexCall &Call);
+ void updateAllocationCall(CallInfo &Call, AllocationType AllocType);
+ void updateCall(CallInfo &CallerCall, FuncInfo CalleeFunc);
+ CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
+ IndexCall>::FuncInfo
+ cloneFunctionForCallsite(FuncInfo &Func, CallInfo &Call,
+ std::map<CallInfo, CallInfo> &CallMap,
+ std::vector<CallInfo> &CallsWithMetadataInFunc,
+ unsigned CloneNo);
+ std::string getLabel(const FunctionSummary *Func, const IndexCall &Call,
+ unsigned CloneNo) const;
+
+ // Saves the mapping from function summaries containing memprof records back
+ // to their VI, for use in checking and debugging.
+ std::map<const FunctionSummary *, ValueInfo> FSToVIMap;
+
+ const ModuleSummaryIndex &Index;
+};
+} // namespace
+
+namespace llvm {
+template <>
+struct DenseMapInfo<typename CallsiteContextGraph<
+ ModuleCallsiteContextGraph, Function, Instruction *>::CallInfo>
+ : public DenseMapInfo<std::pair<Instruction *, unsigned>> {};
+template <>
+struct DenseMapInfo<typename CallsiteContextGraph<
+ IndexCallsiteContextGraph, FunctionSummary, IndexCall>::CallInfo>
+ : public DenseMapInfo<std::pair<IndexCall, unsigned>> {};
+template <>
+struct DenseMapInfo<IndexCall>
+ : public DenseMapInfo<PointerUnion<CallsiteInfo *, AllocInfo *>> {};
+} // end namespace llvm
+
+namespace {
+
+struct FieldSeparator {
+ bool Skip = true;
+ const char *Sep;
+
+ FieldSeparator(const char *Sep = ", ") : Sep(Sep) {}
+};
+
+raw_ostream &operator<<(raw_ostream &OS, FieldSeparator &FS) {
+ if (FS.Skip) {
+ FS.Skip = false;
+ return OS;
+ }
+ return OS << FS.Sep;
+}
+
+// Map the uint8_t alloc types (which may contain NotCold|Cold) to the alloc
+// type we should actually use on the corresponding allocation.
+// If we can't clone a node that has NotCold+Cold alloc type, we will fall
+// back to using NotCold. So don't bother cloning to distinguish NotCold+Cold
+// from NotCold.
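+// E.g. allocTypeToUse(NotCold | Cold) and allocTypeToUse(NotCold) both yield
+// NotCold, while allocTypeToUse(Cold) yields Cold.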
+AllocationType allocTypeToUse(uint8_t AllocTypes) {
+ assert(AllocTypes != (uint8_t)AllocationType::None);
+ if (AllocTypes ==
+ ((uint8_t)AllocationType::NotCold | (uint8_t)AllocationType::Cold))
+ return AllocationType::NotCold;
+ else
+ return (AllocationType)AllocTypes;
+}
+
+// Helper to check if the alloc types for all edges recorded in the
+// InAllocTypes vector match the alloc types for all edges in the Edges
+// vector.
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+bool allocTypesMatch(
+ const std::vector<uint8_t> &InAllocTypes,
+ const std::vector<std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>>>
+ &Edges) {
+ return std::equal(
+ InAllocTypes.begin(), InAllocTypes.end(), Edges.begin(),
+ [](const uint8_t &l,
+ const std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>> &r) {
+ // Can share if one of the edges is None type - don't
+ // care about the type along that edge as it doesn't
+ // exist for those context ids.
+ if (l == (uint8_t)AllocationType::None ||
+ r->AllocTypes == (uint8_t)AllocationType::None)
+ return true;
+ return allocTypeToUse(l) == allocTypeToUse(r->AllocTypes);
+ });
+}
+
+} // end anonymous namespace
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode *
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getNodeForInst(
+ const CallInfo &C) {
+ ContextNode *Node = getNodeForAlloc(C);
+ if (Node)
+ return Node;
+
+ return NonAllocationCallToContextNodeMap.lookup(C);
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode *
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getNodeForAlloc(
+ const CallInfo &C) {
+ return AllocationCallToContextNodeMap.lookup(C);
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode *
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getNodeForStackId(
+ uint64_t StackId) {
+ auto StackEntryNode = StackEntryIdToContextNodeMap.find(StackId);
+ if (StackEntryNode != StackEntryIdToContextNodeMap.end())
+ return StackEntryNode->second;
+ return nullptr;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::unsetNodeForInst(
+ const CallInfo &C) {
+ AllocationCallToContextNodeMap.erase(C) ||
+ NonAllocationCallToContextNodeMap.erase(C);
+ assert(!AllocationCallToContextNodeMap.count(C) &&
+ !NonAllocationCallToContextNodeMap.count(C));
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
+ addOrUpdateCallerEdge(ContextNode *Caller, AllocationType AllocType,
+ unsigned int ContextId) {
+ for (auto &Edge : CallerEdges) {
+ if (Edge->Caller == Caller) {
+ Edge->AllocTypes |= (uint8_t)AllocType;
+ Edge->getContextIds().insert(ContextId);
+ return;
+ }
+ }
+ std::shared_ptr<ContextEdge> Edge = std::make_shared<ContextEdge>(
+ this, Caller, (uint8_t)AllocType, DenseSet<uint32_t>({ContextId}));
+ CallerEdges.push_back(Edge);
+ Caller->CalleeEdges.push_back(Edge);
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<
+ DerivedCCG, FuncTy, CallTy>::removeNoneTypeCalleeEdges(ContextNode *Node) {
+ for (auto EI = Node->CalleeEdges.begin(); EI != Node->CalleeEdges.end();) {
+ auto Edge = *EI;
+ if (Edge->AllocTypes == (uint8_t)AllocationType::None) {
+ assert(Edge->ContextIds.empty());
+ Edge->Callee->eraseCallerEdge(Edge.get());
+ EI = Node->CalleeEdges.erase(EI);
+ } else
+ ++EI;
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge *
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
+ findEdgeFromCallee(const ContextNode *Callee) {
+ for (const auto &Edge : CalleeEdges)
+ if (Edge->Callee == Callee)
+ return Edge.get();
+ return nullptr;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge *
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
+ findEdgeFromCaller(const ContextNode *Caller) {
+ for (const auto &Edge : CallerEdges)
+ if (Edge->Caller == Caller)
+ return Edge.get();
+ return nullptr;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
+ eraseCalleeEdge(const ContextEdge *Edge) {
+ auto EI =
+ std::find_if(CalleeEdges.begin(), CalleeEdges.end(),
+ [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) {
+ return CalleeEdge.get() == Edge;
+ });
+ assert(EI != CalleeEdges.end());
+ CalleeEdges.erase(EI);
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
+ eraseCallerEdge(const ContextEdge *Edge) {
+ auto EI =
+ std::find_if(CallerEdges.begin(), CallerEdges.end(),
+ [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) {
+ return CallerEdge.get() == Edge;
+ });
+ assert(EI != CallerEdges.end());
+ CallerEdges.erase(EI);
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+uint8_t CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::computeAllocType(
+ DenseSet<uint32_t> &ContextIds) {
+ uint8_t BothTypes =
+ (uint8_t)AllocationType::Cold | (uint8_t)AllocationType::NotCold;
+ uint8_t AllocType = (uint8_t)AllocationType::None;
+ for (auto Id : ContextIds) {
+ AllocType |= (uint8_t)ContextIdToAllocationType[Id];
+ // Bail early if alloc type reached both, no further refinement.
+ if (AllocType == BothTypes)
+ return AllocType;
+ }
+ return AllocType;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+uint8_t
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::intersectAllocTypesImpl(
+ const DenseSet<uint32_t> &Node1Ids, const DenseSet<uint32_t> &Node2Ids) {
+ uint8_t BothTypes =
+ (uint8_t)AllocationType::Cold | (uint8_t)AllocationType::NotCold;
+ uint8_t AllocType = (uint8_t)AllocationType::None;
+ for (auto Id : Node1Ids) {
+ if (!Node2Ids.count(Id))
+ continue;
+ AllocType |= (uint8_t)ContextIdToAllocationType[Id];
+ // Bail early if alloc type reached both, no further refinement.
+ if (AllocType == BothTypes)
+ return AllocType;
+ }
+ return AllocType;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+uint8_t CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::intersectAllocTypes(
+ const DenseSet<uint32_t> &Node1Ids, const DenseSet<uint32_t> &Node2Ids) {
+ if (Node1Ids.size() < Node2Ids.size())
+ return intersectAllocTypesImpl(Node1Ids, Node2Ids);
+ else
+ return intersectAllocTypesImpl(Node2Ids, Node1Ids);
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode *
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addAllocNode(
+ CallInfo Call, const FuncTy *F) {
+ assert(!getNodeForAlloc(Call));
+ NodeOwner.push_back(
+ std::make_unique<ContextNode>(/*IsAllocation=*/true, Call));
+ ContextNode *AllocNode = NodeOwner.back().get();
+ AllocationCallToContextNodeMap[Call] = AllocNode;
+ NodeToCallingFunc[AllocNode] = F;
+ // Use LastContextId as a unique id for MIB allocation nodes.
+ AllocNode->OrigStackOrAllocId = LastContextId;
+ // Alloc type should be updated as we add in the MIBs. We should assert
+ // afterwards that it is not still None.
+ AllocNode->AllocTypes = (uint8_t)AllocationType::None;
+
+ return AllocNode;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+template <class NodeT, class IteratorT>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addStackNodesForMIB(
+ ContextNode *AllocNode, CallStack<NodeT, IteratorT> &StackContext,
+ CallStack<NodeT, IteratorT> &CallsiteContext, AllocationType AllocType) {
+ // Treat the hot alloc type as NotCold until disambiguation for "hot" is
+ // implemented.
+ if (AllocType == AllocationType::Hot)
+ AllocType = AllocationType::NotCold;
+
+ ContextIdToAllocationType[++LastContextId] = AllocType;
+
+ // Update alloc type and context ids for this MIB.
+ AllocNode->AllocTypes |= (uint8_t)AllocType;
+ AllocNode->ContextIds.insert(LastContextId);
+
+ // Now add or update nodes for each stack id in alloc's context.
+ // Later when processing the stack ids on non-alloc callsites we will adjust
+ // for any inlining in the context.
+ ContextNode *PrevNode = AllocNode;
+ // Look for recursion (direct recursion should have been collapsed by
+ // module summary analysis, here we should just be detecting mutual
+ // recursion). Mark these nodes so we don't try to clone.
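+ // E.g. in a mutually recursive context A->B->A->B->alloc (A and B being
+ // hypothetical functions), the stack id for the A->B callsite appears twice,
+ // so its node is flagged Recursive and excluded from cloning.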
+ SmallSet<uint64_t, 8> StackIdSet;
+ // Skip any on the allocation call (inlining).
+ for (auto ContextIter = StackContext.beginAfterSharedPrefix(CallsiteContext);
+ ContextIter != StackContext.end(); ++ContextIter) {
+ auto StackId = getStackId(*ContextIter);
+ ContextNode *StackNode = getNodeForStackId(StackId);
+ if (!StackNode) {
+ NodeOwner.push_back(
+ std::make_unique<ContextNode>(/*IsAllocation=*/false));
+ StackNode = NodeOwner.back().get();
+ StackEntryIdToContextNodeMap[StackId] = StackNode;
+ StackNode->OrigStackOrAllocId = StackId;
+ }
+ auto Ins = StackIdSet.insert(StackId);
+ if (!Ins.second)
+ StackNode->Recursive = true;
+ StackNode->ContextIds.insert(LastContextId);
+ StackNode->AllocTypes |= (uint8_t)AllocType;
+ PrevNode->addOrUpdateCallerEdge(StackNode, AllocType, LastContextId);
+ PrevNode = StackNode;
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+DenseSet<uint32_t>
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::duplicateContextIds(
+ const DenseSet<uint32_t> &StackSequenceContextIds,
+ DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds) {
+ DenseSet<uint32_t> NewContextIds;
+ for (auto OldId : StackSequenceContextIds) {
+ NewContextIds.insert(++LastContextId);
+ OldToNewContextIds[OldId].insert(LastContextId);
+ assert(ContextIdToAllocationType.count(OldId));
+ // The new context has the same allocation type as original.
+ ContextIdToAllocationType[LastContextId] = ContextIdToAllocationType[OldId];
+ }
+ return NewContextIds;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::
+ propagateDuplicateContextIds(
+ const DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds) {
+ // Build a set of duplicated context ids corresponding to the input id set.
+ auto GetNewIds = [&OldToNewContextIds](const DenseSet<uint32_t> &ContextIds) {
+ DenseSet<uint32_t> NewIds;
+ for (auto Id : ContextIds)
+ if (auto NewId = OldToNewContextIds.find(Id);
+ NewId != OldToNewContextIds.end())
+ NewIds.insert(NewId->second.begin(), NewId->second.end());
+ return NewIds;
+ };
+
+ // Recursively update context ids sets along caller edges.
+ auto UpdateCallers = [&](ContextNode *Node,
+ DenseSet<const ContextEdge *> &Visited,
+ auto &&UpdateCallers) -> void {
+ for (const auto &Edge : Node->CallerEdges) {
+ auto Inserted = Visited.insert(Edge.get());
+ if (!Inserted.second)
+ continue;
+ ContextNode *NextNode = Edge->Caller;
+ DenseSet<uint32_t> NewIdsToAdd = GetNewIds(Edge->getContextIds());
+ // Only need to recursively iterate to NextNode via this caller edge if
+ // it resulted in any added ids to NextNode.
+ if (!NewIdsToAdd.empty()) {
+ Edge->getContextIds().insert(NewIdsToAdd.begin(), NewIdsToAdd.end());
+ NextNode->ContextIds.insert(NewIdsToAdd.begin(), NewIdsToAdd.end());
+ UpdateCallers(NextNode, Visited, UpdateCallers);
+ }
+ }
+ };
+
+ DenseSet<const ContextEdge *> Visited;
+ for (auto &Entry : AllocationCallToContextNodeMap) {
+ auto *Node = Entry.second;
+ // Update ids on the allocation nodes before calling the recursive
+ // update along caller edges, since this simplifies the logic during
+ // that traversal.
+ DenseSet<uint32_t> NewIdsToAdd = GetNewIds(Node->ContextIds);
+ Node->ContextIds.insert(NewIdsToAdd.begin(), NewIdsToAdd.end());
+ UpdateCallers(Node, Visited, UpdateCallers);
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::connectNewNode(
+ ContextNode *NewNode, ContextNode *OrigNode, bool TowardsCallee) {
+ // Make a copy of the context ids, since this will be adjusted below as they
+ // are moved.
+ DenseSet<uint32_t> RemainingContextIds = NewNode->ContextIds;
+ auto &OrigEdges =
+ TowardsCallee ? OrigNode->CalleeEdges : OrigNode->CallerEdges;
+ // Increment iterator in loop so that we can remove edges as needed.
+ for (auto EI = OrigEdges.begin(); EI != OrigEdges.end();) {
+ auto Edge = *EI;
+ // Remove any matching context ids from Edge, returning the set that was
+ // found and removed; these become the new edge's context ids. Also update
+ // the remaining (not found) ids.
+ DenseSet<uint32_t> NewEdgeContextIds, NotFoundContextIds;
+ set_subtract(Edge->getContextIds(), RemainingContextIds, NewEdgeContextIds,
+ NotFoundContextIds);
+ RemainingContextIds.swap(NotFoundContextIds);
+ // If no matching context ids for this edge, skip it.
+ if (NewEdgeContextIds.empty()) {
+ ++EI;
+ continue;
+ }
+ if (TowardsCallee) {
+ auto NewEdge = std::make_shared<ContextEdge>(
+ Edge->Callee, NewNode, computeAllocType(NewEdgeContextIds),
+ NewEdgeContextIds);
+ NewNode->CalleeEdges.push_back(NewEdge);
+ NewEdge->Callee->CallerEdges.push_back(NewEdge);
+ } else {
+ auto NewEdge = std::make_shared<ContextEdge>(
+ NewNode, Edge->Caller, computeAllocType(NewEdgeContextIds),
+ NewEdgeContextIds);
+ NewNode->CallerEdges.push_back(NewEdge);
+ NewEdge->Caller->CalleeEdges.push_back(NewEdge);
+ }
+ // Remove old edge if context ids empty.
+ if (Edge->getContextIds().empty()) {
+ if (TowardsCallee) {
+ Edge->Callee->eraseCallerEdge(Edge.get());
+ EI = OrigNode->CalleeEdges.erase(EI);
+ } else {
+ Edge->Caller->eraseCalleeEdge(Edge.get());
+ EI = OrigNode->CallerEdges.erase(EI);
+ }
+ continue;
+ }
+ ++EI;
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::
+ assignStackNodesPostOrder(ContextNode *Node,
+ DenseSet<const ContextNode *> &Visited,
+ DenseMap<uint64_t, std::vector<CallContextInfo>>
+ &StackIdToMatchingCalls) {
+ auto Inserted = Visited.insert(Node);
+ if (!Inserted.second)
+ return;
+ // Post order traversal. Iterate over a copy since we may add nodes and
+ // therefore new callers during the recursive call, invalidating any
+ // iterator over the original edge vector. We don't need to process these
+ // new nodes as they were already processed on creation.
+ auto CallerEdges = Node->CallerEdges;
+ for (auto &Edge : CallerEdges) {
+ // Skip any that have been removed during the recursion.
+ if (!Edge)
+ continue;
+ assignStackNodesPostOrder(Edge->Caller, Visited, StackIdToMatchingCalls);
+ }
+
+ // If this node's stack id is in the map, update the graph to contain new
+ // nodes representing any inlining at interior callsites. Note we move the
+ // associated context ids over to the new nodes.
+
+ // Ignore this node if it is for an allocation or we didn't record any
+ // stack id lists ending at it.
+ if (Node->IsAllocation ||
+ !StackIdToMatchingCalls.count(Node->OrigStackOrAllocId))
+ return;
+
+ auto &Calls = StackIdToMatchingCalls[Node->OrigStackOrAllocId];
+ // Handle the simple case first. A single call with a single stack id.
+ // In this case there is no need to create any new context nodes, simply
+ // assign the context node for stack id to this Call.
+ if (Calls.size() == 1) {
+ auto &[Call, Ids, Func, SavedContextIds] = Calls[0];
+ if (Ids.size() == 1) {
+ assert(SavedContextIds.empty());
+ // It should be this Node
+ assert(Node == getNodeForStackId(Ids[0]));
+ if (Node->Recursive)
+ return;
+ Node->setCall(Call);
+ NonAllocationCallToContextNodeMap[Call] = Node;
+ NodeToCallingFunc[Node] = Func;
+ return;
+ }
+ }
+
+ // Find the node for the last stack id, which should be the same
+ // across all calls recorded for this id, and is this node's id.
+ uint64_t LastId = Node->OrigStackOrAllocId;
+ ContextNode *LastNode = getNodeForStackId(LastId);
+ // We should only have kept stack ids that had nodes.
+ assert(LastNode);
+
+ for (unsigned I = 0; I < Calls.size(); I++) {
+ auto &[Call, Ids, Func, SavedContextIds] = Calls[I];
+ // Skip any for which we didn't assign any ids; these don't get a node in
+ // the graph.
+ if (SavedContextIds.empty())
+ continue;
+
+ assert(LastId == Ids.back());
+
+ ContextNode *FirstNode = getNodeForStackId(Ids[0]);
+ assert(FirstNode);
+
+ // Recompute the context ids for this stack id sequence (the
+ // intersection of the context ids of the corresponding nodes).
+ // Start with the ids we saved in the map for this call, which could be
+ // duplicated context ids. We have to recompute as we might have overlap
+ // between the saved context ids for different last nodes, and
+ // removed them already during the post order traversal.
+ set_intersect(SavedContextIds, FirstNode->ContextIds);
+ ContextNode *PrevNode = nullptr;
+ for (auto Id : Ids) {
+ ContextNode *CurNode = getNodeForStackId(Id);
+ // We should only have kept stack ids that had nodes and weren't
+ // recursive.
+ assert(CurNode);
+ assert(!CurNode->Recursive);
+ if (!PrevNode) {
+ PrevNode = CurNode;
+ continue;
+ }
+ auto *Edge = CurNode->findEdgeFromCallee(PrevNode);
+ if (!Edge) {
+ SavedContextIds.clear();
+ break;
+ }
+ PrevNode = CurNode;
+ set_intersect(SavedContextIds, Edge->getContextIds());
+
+ // If we now have no context ids for clone, skip this call.
+ if (SavedContextIds.empty())
+ break;
+ }
+ if (SavedContextIds.empty())
+ continue;
+
+ // Create new context node.
+ NodeOwner.push_back(
+ std::make_unique<ContextNode>(/*IsAllocation=*/false, Call));
+ ContextNode *NewNode = NodeOwner.back().get();
+ NodeToCallingFunc[NewNode] = Func;
+ NonAllocationCallToContextNodeMap[Call] = NewNode;
+ NewNode->ContextIds = SavedContextIds;
+ NewNode->AllocTypes = computeAllocType(NewNode->ContextIds);
+
+ // Connect to callees of innermost stack frame in inlined call chain.
+ // This updates context ids for FirstNode's callees to reflect those
+ // moved to NewNode.
+ connectNewNode(NewNode, FirstNode, /*TowardsCallee=*/true);
+
+ // Connect to callers of outermost stack frame in inlined call chain.
+ // This updates context ids for LastNode's callers to reflect those
+ // moved to NewNode.
+ connectNewNode(NewNode, LastNode, /*TowardsCallee=*/false);
+
+ // Now we need to remove context ids from edges/nodes between First and
+ // Last Node.
+ PrevNode = nullptr;
+ for (auto Id : Ids) {
+ ContextNode *CurNode = getNodeForStackId(Id);
+ // We should only have kept stack ids that had nodes.
+ assert(CurNode);
+
+ // Remove the context ids moved to NewNode from CurNode, and the
+ // edge from the prior node.
+ set_subtract(CurNode->ContextIds, NewNode->ContextIds);
+ if (PrevNode) {
+ auto *PrevEdge = CurNode->findEdgeFromCallee(PrevNode);
+ assert(PrevEdge);
+ set_subtract(PrevEdge->getContextIds(), NewNode->ContextIds);
+ if (PrevEdge->getContextIds().empty()) {
+ PrevNode->eraseCallerEdge(PrevEdge);
+ CurNode->eraseCalleeEdge(PrevEdge);
+ }
+ }
+ PrevNode = CurNode;
+ }
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::updateStackNodes() {
+ // Map of stack id to all calls with that as the last (outermost caller)
+ // callsite id that has a context node (some might not due to pruning
+ // performed during matching of the allocation profile contexts).
+ // The CallContextInfo contains the Call and a list of its stack ids with
+ // ContextNodes, the function containing Call, and the set of context ids
+ // the analysis will eventually identify for use in any new node created
+ // for that callsite.
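+ // E.g. a callsite that was inlined through one or more callers carries a
+ // stack id for each frame in the inlined sequence, and is recorded here
+ // under its last (outermost caller) id.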
+ DenseMap<uint64_t, std::vector<CallContextInfo>> StackIdToMatchingCalls;
+ for (auto &[Func, CallsWithMetadata] : FuncToCallsWithMetadata) {
+ for (auto &Call : CallsWithMetadata) {
+ // Ignore allocations, already handled.
+ if (AllocationCallToContextNodeMap.count(Call))
+ continue;
+ auto StackIdsWithContextNodes =
+ getStackIdsWithContextNodesForCall(Call.call());
+ // If there were no nodes created for MIBs on allocs (maybe this was in
+ // the unambiguous part of the MIB stack that was pruned), ignore.
+ if (StackIdsWithContextNodes.empty())
+ continue;
+ // Otherwise, record this Call along with the list of ids for the last
+ // (outermost caller) stack id with a node.
+ StackIdToMatchingCalls[StackIdsWithContextNodes.back()].push_back(
+ {Call.call(), StackIdsWithContextNodes, Func, {}});
+ }
+ }
+
+ // First make a pass through all stack ids that correspond to a call,
+ // as identified in the above loop. Compute the context ids corresponding to
+ // each of these calls when they correspond to multiple stack ids due to
+ // inlining. Perform any duplication of context ids required when
+ // there is more than one call with the same stack ids. Their (possibly newly
+ // duplicated) context ids are saved in the StackIdToMatchingCalls map.
+ DenseMap<uint32_t, DenseSet<uint32_t>> OldToNewContextIds;
+ for (auto &It : StackIdToMatchingCalls) {
+ auto &Calls = It.getSecond();
+ // Skip single calls with a single stack id. These don't need a new node.
+ if (Calls.size() == 1) {
+ auto &Ids = std::get<1>(Calls[0]);
+ if (Ids.size() == 1)
+ continue;
+ }
+ // To get the best and maximal matching of inlined calls to context node
+ // sequences, we sort the vectors of stack ids in descending order of
+ // length, and within each length, lexicographically by stack id. The
+ // latter is so that we can specially handle calls that have identical stack
+ // id sequences (either due to cloning or artificially because of the MIB
+ // context pruning).
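+ // E.g. the stack id sequences [1,2,3], [1,3], [1,2], [1,3] would be ordered
+ // as [1,2,3], [1,2], [1,3], [1,3], keeping identical sequences adjacent.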
+ std::stable_sort(Calls.begin(), Calls.end(),
+ [](const CallContextInfo &A, const CallContextInfo &B) {
+ auto &IdsA = std::get<1>(A);
+ auto &IdsB = std::get<1>(B);
+ return IdsA.size() > IdsB.size() ||
+ (IdsA.size() == IdsB.size() && IdsA < IdsB);
+ });
+
+ // Find the node for the last stack id, which should be the same
+ // across all calls recorded for this id, and is the id for this
+ // entry in the StackIdToMatchingCalls map.
+ uint64_t LastId = It.getFirst();
+ ContextNode *LastNode = getNodeForStackId(LastId);
+ // We should only have kept stack ids that had nodes.
+ assert(LastNode);
+
+ if (LastNode->Recursive)
+ continue;
+
+ // Initialize the context ids with the last node's. We will subsequently
+ // refine the context ids by computing the intersection along all edges.
+ DenseSet<uint32_t> LastNodeContextIds = LastNode->ContextIds;
+ assert(!LastNodeContextIds.empty());
+
+ for (unsigned I = 0; I < Calls.size(); I++) {
+ auto &[Call, Ids, Func, SavedContextIds] = Calls[I];
+ assert(SavedContextIds.empty());
+ assert(LastId == Ids.back());
+
+ // First compute the context ids for this stack id sequence (the
+ // intersection of the context ids of the corresponding nodes).
+ // Start with the remaining saved ids for the last node.
+ assert(!LastNodeContextIds.empty());
+ DenseSet<uint32_t> StackSequenceContextIds = LastNodeContextIds;
+
+ ContextNode *PrevNode = LastNode;
+ ContextNode *CurNode = LastNode;
+ bool Skip = false;
+
+ // Iterate backwards through the stack Ids, starting after the last Id
+ // in the list, which was handled once outside for all Calls.
+ for (auto IdIter = Ids.rbegin() + 1; IdIter != Ids.rend(); IdIter++) {
+ auto Id = *IdIter;
+ CurNode = getNodeForStackId(Id);
+ // We should only have kept stack ids that had nodes.
+ assert(CurNode);
+
+ if (CurNode->Recursive) {
+ Skip = true;
+ break;
+ }
+
+ auto *Edge = CurNode->findEdgeFromCaller(PrevNode);
+ // If there is no edge then the nodes belong to different MIB contexts,
+ // and we should skip this inlined context sequence. For example, this
+ // particular inlined context may include stack ids A->B, and we may
+ // indeed have nodes for both A and B, but it is possible that they were
+ // never profiled in sequence in a single MIB for any allocation (i.e.
+ // we might have profiled an allocation that involves the callsite A,
+ // but through a different one of its callee callsites, and we might
+ // have profiled an allocation that involves callsite B, but reached
+ // from a different caller callsite).
+ if (!Edge) {
+ Skip = true;
+ break;
+ }
+ PrevNode = CurNode;
+
+ // Update the context ids, which is the intersection of the ids along
+ // all edges in the sequence.
+ set_intersect(StackSequenceContextIds, Edge->getContextIds());
+
+ // If we now have no context ids for clone, skip this call.
+ if (StackSequenceContextIds.empty()) {
+ Skip = true;
+ break;
+ }
+ }
+ if (Skip)
+ continue;
+
+ // If some of this call's stack ids did not have corresponding nodes (due
+ // to pruning), don't include any context ids for contexts that extend
+ // beyond these nodes. Otherwise we would be matching part of unrelated /
+ // not fully matching stack contexts. To do this, subtract any context ids
+ // found in caller nodes of the last node found above.
+ if (Ids.back() != getLastStackId(Call)) {
+ for (const auto &PE : CurNode->CallerEdges) {
+ set_subtract(StackSequenceContextIds, PE->getContextIds());
+ if (StackSequenceContextIds.empty())
+ break;
+ }
+ // If we now have no context ids for clone, skip this call.
+ if (StackSequenceContextIds.empty())
+ continue;
+ }
+
+ // Check if the next set of stack ids is the same (since the Calls vector
+ // of tuples is sorted by the stack ids we can just look at the next one).
+ bool DuplicateContextIds = false;
+ if (I + 1 < Calls.size()) {
+ auto NextIds = std::get<1>(Calls[I + 1]);
+ DuplicateContextIds = Ids == NextIds;
+ }
+
+ // If we don't have duplicate context ids, then we can assign all the
+ // context ids computed for the original node sequence to this call.
+ // If there are duplicate calls with the same stack ids then we synthesize
+ // new context ids that are duplicates of the originals. These are
+ // assigned to SavedContextIds, which is a reference into the map entry
+ // for this call, allowing us to access these ids later on.
+ OldToNewContextIds.reserve(OldToNewContextIds.size() +
+ StackSequenceContextIds.size());
+ SavedContextIds =
+ DuplicateContextIds
+ ? duplicateContextIds(StackSequenceContextIds, OldToNewContextIds)
+ : StackSequenceContextIds;
+ assert(!SavedContextIds.empty());
+
+ if (!DuplicateContextIds) {
+ // Update saved last node's context ids to remove those that are
+ // assigned to other calls, so that it is ready for the next call at
+ // this stack id.
+ set_subtract(LastNodeContextIds, StackSequenceContextIds);
+ if (LastNodeContextIds.empty())
+ break;
+ }
+ }
+ }
+
+ // Propagate the duplicate context ids over the graph.
+ propagateDuplicateContextIds(OldToNewContextIds);
+
+ if (VerifyCCG)
+ check();
+
+ // Now perform a post-order traversal over the graph, starting with the
+ // allocation nodes, essentially processing nodes from callers to callees.
+ // For any that contains an id in the map, update the graph to contain new
+ // nodes representing any inlining at interior callsites. Note we move the
+ // associated context ids over to the new nodes.
+ DenseSet<const ContextNode *> Visited;
+ for (auto &Entry : AllocationCallToContextNodeMap)
+ assignStackNodesPostOrder(Entry.second, Visited, StackIdToMatchingCalls);
+}
+
+uint64_t ModuleCallsiteContextGraph::getLastStackId(Instruction *Call) {
+ CallStack<MDNode, MDNode::op_iterator> CallsiteContext(
+ Call->getMetadata(LLVMContext::MD_callsite));
+ return CallsiteContext.back();
+}
+
+uint64_t IndexCallsiteContextGraph::getLastStackId(IndexCall &Call) {
+ assert(isa<CallsiteInfo *>(Call.getBase()));
+ CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator>
+ CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase()));
+ // Need to convert index into stack id.
+ return Index.getStackIdAtIndex(CallsiteContext.back());
+}
+
+static const std::string MemProfCloneSuffix = ".memprof.";
+
+static std::string getMemProfFuncName(Twine Base, unsigned CloneNo) {
+ // We use CloneNo == 0 to refer to the original version, which doesn't get
+ // renamed with a suffix.
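+ // E.g. clone 2 of a function "foo" is named "foo.memprof.2".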
+ if (!CloneNo)
+ return Base.str();
+ return (Base + MemProfCloneSuffix + Twine(CloneNo)).str();
+}
+
+std::string ModuleCallsiteContextGraph::getLabel(const Function *Func,
+ const Instruction *Call,
+ unsigned CloneNo) const {
+ return (Twine(Call->getFunction()->getName()) + " -> " +
+ cast<CallBase>(Call)->getCalledFunction()->getName())
+ .str();
+}
+
+std::string IndexCallsiteContextGraph::getLabel(const FunctionSummary *Func,
+ const IndexCall &Call,
+ unsigned CloneNo) const {
+ auto VI = FSToVIMap.find(Func);
+ assert(VI != FSToVIMap.end());
+ if (isa<AllocInfo *>(Call.getBase()))
+ return (VI->second.name() + " -> alloc").str();
+ else {
+ auto *Callsite = dyn_cast_if_present<CallsiteInfo *>(Call.getBase());
+ return (VI->second.name() + " -> " +
+ getMemProfFuncName(Callsite->Callee.name(),
+ Callsite->Clones[CloneNo]))
+ .str();
+ }
+}
+
+std::vector<uint64_t>
+ModuleCallsiteContextGraph::getStackIdsWithContextNodesForCall(
+ Instruction *Call) {
+ CallStack<MDNode, MDNode::op_iterator> CallsiteContext(
+ Call->getMetadata(LLVMContext::MD_callsite));
+ return getStackIdsWithContextNodes<MDNode, MDNode::op_iterator>(
+ CallsiteContext);
+}
+
+std::vector<uint64_t>
+IndexCallsiteContextGraph::getStackIdsWithContextNodesForCall(IndexCall &Call) {
+ assert(isa<CallsiteInfo *>(Call.getBase()));
+ CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator>
+ CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase()));
+ return getStackIdsWithContextNodes<CallsiteInfo,
+ SmallVector<unsigned>::const_iterator>(
+ CallsiteContext);
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+template <class NodeT, class IteratorT>
+std::vector<uint64_t>
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getStackIdsWithContextNodes(
+ CallStack<NodeT, IteratorT> &CallsiteContext) {
+ std::vector<uint64_t> StackIds;
+ for (auto IdOrIndex : CallsiteContext) {
+ auto StackId = getStackId(IdOrIndex);
+ ContextNode *Node = getNodeForStackId(StackId);
+ if (!Node)
+ break;
+ StackIds.push_back(StackId);
+ }
+ return StackIds;
+}
+
+ModuleCallsiteContextGraph::ModuleCallsiteContextGraph(
+ Module &M, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter)
+ : Mod(M), OREGetter(OREGetter) {
+ for (auto &F : M) {
+ std::vector<CallInfo> CallsWithMetadata;
+ for (auto &BB : F) {
+ for (auto &I : BB) {
+ if (!isa<CallBase>(I))
+ continue;
+ if (auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof)) {
+ CallsWithMetadata.push_back(&I);
+ auto *AllocNode = addAllocNode(&I, &F);
+ auto *CallsiteMD = I.getMetadata(LLVMContext::MD_callsite);
+ assert(CallsiteMD);
+ CallStack<MDNode, MDNode::op_iterator> CallsiteContext(CallsiteMD);
+ // Add all of the MIBs and their stack nodes.
+ for (auto &MDOp : MemProfMD->operands()) {
+ auto *MIBMD = cast<const MDNode>(MDOp);
+ MDNode *StackNode = getMIBStackNode(MIBMD);
+ assert(StackNode);
+ CallStack<MDNode, MDNode::op_iterator> StackContext(StackNode);
+ addStackNodesForMIB<MDNode, MDNode::op_iterator>(
+ AllocNode, StackContext, CallsiteContext,
+ getMIBAllocType(MIBMD));
+ }
+ assert(AllocNode->AllocTypes != (uint8_t)AllocationType::None);
+ // Memprof and callsite metadata on memory allocations no longer
+ // needed.
+ I.setMetadata(LLVMContext::MD_memprof, nullptr);
+ I.setMetadata(LLVMContext::MD_callsite, nullptr);
+ }
+ // For callsite metadata, add to list for this function for later use.
+ else if (I.getMetadata(LLVMContext::MD_callsite))
+ CallsWithMetadata.push_back(&I);
+ }
+ }
+ if (!CallsWithMetadata.empty())
+ FuncToCallsWithMetadata.push_back({&F, CallsWithMetadata});
+ }
+
+ if (DumpCCG) {
+ dbgs() << "CCG before updating call stack chains:\n";
+ dbgs() << *this;
+ }
+
+ if (ExportToDot)
+ exportToDot("prestackupdate");
+
+ updateStackNodes();
+
+ handleCallsitesWithMultipleTargets();
+
+ // Strip off remaining callsite metadata, no longer needed.
+ for (auto &FuncEntry : FuncToCallsWithMetadata)
+ for (auto &Call : FuncEntry.second)
+ Call.call()->setMetadata(LLVMContext::MD_callsite, nullptr);
+}
+
+IndexCallsiteContextGraph::IndexCallsiteContextGraph(
+ ModuleSummaryIndex &Index,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing)
+ : Index(Index) {
+ for (auto &I : Index) {
+ auto VI = Index.getValueInfo(I);
+ for (auto &S : VI.getSummaryList()) {
+ // We should only add the prevailing nodes. Otherwise we may try to clone
+ // in a weak copy that won't be linked (and may be different from the
+ // prevailing version).
+ // We currently only keep the memprof summary on the prevailing copy when
+ // building the combined index, as a space optimization; however, don't
+ // rely on this optimization. The linker doesn't resolve local linkage
+ // values, so don't check whether those are prevailing.
+ if (!GlobalValue::isLocalLinkage(S->linkage()) &&
+ !isPrevailing(VI.getGUID(), S.get()))
+ continue;
+ auto *FS = dyn_cast<FunctionSummary>(S.get());
+ if (!FS)
+ continue;
+ std::vector<CallInfo> CallsWithMetadata;
+ if (!FS->allocs().empty()) {
+ for (auto &AN : FS->mutableAllocs()) {
+ // This can happen because of recursion elimination handling that
+ // currently exists in ModuleSummaryAnalysis. Skip these for now.
+ // We still added them to the summary because we need to be able to
+ // correlate properly in applyImport in the backends.
+ if (AN.MIBs.empty())
+ continue;
+ CallsWithMetadata.push_back({&AN});
+ auto *AllocNode = addAllocNode({&AN}, FS);
+ // Pass an empty CallStack to the CallsiteContext (second)
+ // parameter, since for ThinLTO we already collapsed out the inlined
+ // stack ids on the allocation call during ModuleSummaryAnalysis.
+ CallStack<MIBInfo, SmallVector<unsigned>::const_iterator>
+ EmptyContext;
+ // Now add all of the MIBs and their stack nodes.
+ for (auto &MIB : AN.MIBs) {
+ CallStack<MIBInfo, SmallVector<unsigned>::const_iterator>
+ StackContext(&MIB);
+ addStackNodesForMIB<MIBInfo, SmallVector<unsigned>::const_iterator>(
+ AllocNode, StackContext, EmptyContext, MIB.AllocType);
+ }
+ assert(AllocNode->AllocTypes != (uint8_t)AllocationType::None);
+ // Initialize version 0 on the summary alloc node to the current alloc
+ // type, unless it has both types, in which case make it default, so
+ // that if we aren't able to clone, the original version always ends
+ // up with the default allocation behavior.
+ AN.Versions[0] = (uint8_t)allocTypeToUse(AllocNode->AllocTypes);
+ }
+ }
+ // For callsite metadata, add to list for this function for later use.
+ if (!FS->callsites().empty())
+ for (auto &SN : FS->mutableCallsites())
+ CallsWithMetadata.push_back({&SN});
+
+ if (!CallsWithMetadata.empty())
+ FuncToCallsWithMetadata.push_back({FS, CallsWithMetadata});
+
+ if (!FS->allocs().empty() || !FS->callsites().empty())
+ FSToVIMap[FS] = VI;
+ }
+ }
+
+ if (DumpCCG) {
+ dbgs() << "CCG before updating call stack chains:\n";
+ dbgs() << *this;
+ }
+
+ if (ExportToDot)
+ exportToDot("prestackupdate");
+
+ updateStackNodes();
+
+ handleCallsitesWithMultipleTargets();
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy,
+ CallTy>::handleCallsitesWithMultipleTargets() {
+ // Look for and work around callsites that call multiple functions.
+ // This can happen for indirect calls, which need better handling, and in
+ // rarer cases (e.g. macro expansion).
+ // TODO: To fix this for indirect calls we will want to perform speculative
+ // devirtualization using either the normal PGO info with ICP, or using the
+ // information in the profiled MemProf contexts. We can do this prior to
+ // this transformation for regular LTO, and for ThinLTO we can simulate that
+ // effect in the summary and perform the actual speculative devirtualization
+ // while cloning in the ThinLTO backend.
+ for (auto Entry = NonAllocationCallToContextNodeMap.begin();
+ Entry != NonAllocationCallToContextNodeMap.end();) {
+ auto *Node = Entry->second;
+ assert(Node->Clones.empty());
+ // Check all node callees and see if in the same function.
+ bool Removed = false;
+ auto Call = Node->Call.call();
+ for (auto &Edge : Node->CalleeEdges) {
+ if (!Edge->Callee->hasCall())
+ continue;
+ assert(NodeToCallingFunc.count(Edge->Callee));
+ // Check if the called function matches that of the callee node.
+ if (calleeMatchesFunc(Call, NodeToCallingFunc[Edge->Callee]))
+ continue;
+ // Work around by setting Node to have a null call, so it gets
+ // skipped during cloning. Otherwise assignFunctions will assert
+ // because its data structures are not designed to handle this case.
+ Entry = NonAllocationCallToContextNodeMap.erase(Entry);
+ Node->setCall(CallInfo());
+ Removed = true;
+ break;
+ }
+ if (!Removed)
+ Entry++;
+ }
+}
+
+uint64_t ModuleCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const {
+ // In the Module (IR) case this is already the Id.
+ return IdOrIndex;
+}
+
+uint64_t IndexCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const {
+ // In the Index case this is an index into the stack id list in the summary
+ // index, convert it to an Id.
+ return Index.getStackIdAtIndex(IdOrIndex);
+}
+
+bool ModuleCallsiteContextGraph::calleeMatchesFunc(Instruction *Call,
+ const Function *Func) {
+ auto *CB = dyn_cast<CallBase>(Call);
+ if (!CB->getCalledOperand())
+ return false;
+ auto *CalleeVal = CB->getCalledOperand()->stripPointerCasts();
+ auto *CalleeFunc = dyn_cast<Function>(CalleeVal);
+ if (CalleeFunc == Func)
+ return true;
+ auto *Alias = dyn_cast<GlobalAlias>(CalleeVal);
+ return Alias && Alias->getAliasee() == Func;
+}
+
+bool IndexCallsiteContextGraph::calleeMatchesFunc(IndexCall &Call,
+ const FunctionSummary *Func) {
+ ValueInfo Callee =
+ dyn_cast_if_present<CallsiteInfo *>(Call.getBase())->Callee;
+ // If there is no summary list then this is a call to an externally defined
+ // symbol.
+ AliasSummary *Alias =
+ Callee.getSummaryList().empty()
+ ? nullptr
+ : dyn_cast<AliasSummary>(Callee.getSummaryList()[0].get());
+ assert(FSToVIMap.count(Func));
+ return Callee == FSToVIMap[Func] ||
+ // If callee is an alias, check the aliasee, since only function
+ // summary base objects will contain the stack node summaries and thus
+ // get a context node.
+ (Alias && Alias->getAliaseeVI() == FSToVIMap[Func]);
+}
+
+static std::string getAllocTypeString(uint8_t AllocTypes) {
+ if (!AllocTypes)
+ return "None";
+ std::string Str;
+ if (AllocTypes & (uint8_t)AllocationType::NotCold)
+ Str += "NotCold";
+ if (AllocTypes & (uint8_t)AllocationType::Cold)
+ Str += "Cold";
+ return Str;
+}
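+// For example, getAllocTypeString((uint8_t)AllocationType::NotCold |
+// (uint8_t)AllocationType::Cold) returns "NotColdCold", matching the bitmask
+// encoding used for the AllocTypes fields on nodes and edges.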
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::dump()
+ const {
+ print(dbgs());
+ dbgs() << "\n";
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::print(
+ raw_ostream &OS) const {
+ OS << "Node " << this << "\n";
+ OS << "\t";
+ printCall(OS);
+ if (Recursive)
+ OS << " (recursive)";
+ OS << "\n";
+ OS << "\tAllocTypes: " << getAllocTypeString(AllocTypes) << "\n";
+ OS << "\tContextIds:";
+ std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end());
+ std::sort(SortedIds.begin(), SortedIds.end());
+ for (auto Id : SortedIds)
+ OS << " " << Id;
+ OS << "\n";
+ OS << "\tCalleeEdges:\n";
+ for (auto &Edge : CalleeEdges)
+ OS << "\t\t" << *Edge << "\n";
+ OS << "\tCallerEdges:\n";
+ for (auto &Edge : CallerEdges)
+ OS << "\t\t" << *Edge << "\n";
+ if (!Clones.empty()) {
+ OS << "\tClones: ";
+ FieldSeparator FS;
+ for (auto *Clone : Clones)
+ OS << FS << Clone;
+ OS << "\n";
+ } else if (CloneOf) {
+ OS << "\tClone of " << CloneOf << "\n";
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge::dump()
+ const {
+ print(dbgs());
+ dbgs() << "\n";
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge::print(
+ raw_ostream &OS) const {
+ OS << "Edge from Callee " << Callee << " to Caller: " << Caller
+ << " AllocTypes: " << getAllocTypeString(AllocTypes);
+ OS << " ContextIds:";
+ std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end());
+ std::sort(SortedIds.begin(), SortedIds.end());
+ for (auto Id : SortedIds)
+ OS << " " << Id;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::dump() const {
+ print(dbgs());
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::print(
+ raw_ostream &OS) const {
+ OS << "Callsite Context Graph:\n";
+ using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *;
+ for (const auto Node : nodes<GraphType>(this)) {
+ if (Node->isRemoved())
+ continue;
+ Node->print(OS);
+ OS << "\n";
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+static void checkEdge(
+ const std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>> &Edge) {
+ // Confirm that alloc type is not None and that we have at least one context
+ // id.
+ assert(Edge->AllocTypes != (uint8_t)AllocationType::None);
+ assert(!Edge->ContextIds.empty());
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+static void checkNode(const ContextNode<DerivedCCG, FuncTy, CallTy> *Node,
+ bool CheckEdges = true) {
+ if (Node->isRemoved())
+ return;
+ // Node's context ids should be the union of both its callee and caller edge
+ // context ids.
+ if (Node->CallerEdges.size()) {
+ auto EI = Node->CallerEdges.begin();
+ auto &FirstEdge = *EI;
+ EI++;
+ DenseSet<uint32_t> CallerEdgeContextIds(FirstEdge->ContextIds);
+ for (; EI != Node->CallerEdges.end(); EI++) {
+ const auto &Edge = *EI;
+ if (CheckEdges)
+ checkEdge<DerivedCCG, FuncTy, CallTy>(Edge);
+ set_union(CallerEdgeContextIds, Edge->ContextIds);
+ }
+ // The node can have more context ids than its caller edges if some contexts
+ // terminate at the node while others continue on to callers.
+ assert(Node->ContextIds == CallerEdgeContextIds ||
+ set_is_subset(CallerEdgeContextIds, Node->ContextIds));
+ }
+ if (Node->CalleeEdges.size()) {
+ auto EI = Node->CalleeEdges.begin();
+ auto &FirstEdge = *EI;
+ EI++;
+ DenseSet<uint32_t> CalleeEdgeContextIds(FirstEdge->ContextIds);
+ for (; EI != Node->CalleeEdges.end(); EI++) {
+ const auto &Edge = *EI;
+ if (CheckEdges)
+ checkEdge<DerivedCCG, FuncTy, CallTy>(Edge);
+ set_union(CalleeEdgeContextIds, Edge->ContextIds);
+ }
+ assert(Node->ContextIds == CalleeEdgeContextIds);
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::check() const {
+ using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *;
+ for (const auto Node : nodes<GraphType>(this)) {
+ checkNode<DerivedCCG, FuncTy, CallTy>(Node, /*CheckEdges=*/false);
+ for (auto &Edge : Node->CallerEdges)
+ checkEdge<DerivedCCG, FuncTy, CallTy>(Edge);
+ }
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+struct GraphTraits<const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *> {
+ using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *;
+ using NodeRef = const ContextNode<DerivedCCG, FuncTy, CallTy> *;
+
+ using NodePtrTy = std::unique_ptr<ContextNode<DerivedCCG, FuncTy, CallTy>>;
+ static NodeRef getNode(const NodePtrTy &P) { return P.get(); }
+
+ using nodes_iterator =
+ mapped_iterator<typename std::vector<NodePtrTy>::const_iterator,
+ decltype(&getNode)>;
+
+ static nodes_iterator nodes_begin(GraphType G) {
+ return nodes_iterator(G->NodeOwner.begin(), &getNode);
+ }
+
+ static nodes_iterator nodes_end(GraphType G) {
+ return nodes_iterator(G->NodeOwner.end(), &getNode);
+ }
+
+ static NodeRef getEntryNode(GraphType G) {
+ return G->NodeOwner.begin()->get();
+ }
+
+ using EdgePtrTy = std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>>;
+ static const ContextNode<DerivedCCG, FuncTy, CallTy> *
+ GetCallee(const EdgePtrTy &P) {
+ return P->Callee;
+ }
+
+ using ChildIteratorType =
+ mapped_iterator<typename std::vector<std::shared_ptr<ContextEdge<
+ DerivedCCG, FuncTy, CallTy>>>::const_iterator,
+ decltype(&GetCallee)>;
+
+ static ChildIteratorType child_begin(NodeRef N) {
+ return ChildIteratorType(N->CalleeEdges.begin(), &GetCallee);
+ }
+
+ static ChildIteratorType child_end(NodeRef N) {
+ return ChildIteratorType(N->CalleeEdges.end(), &GetCallee);
+ }
+};
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+struct DOTGraphTraits<const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *>
+ : public DefaultDOTGraphTraits {
+ DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
+
+ using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *;
+ using GTraits = GraphTraits<GraphType>;
+ using NodeRef = typename GTraits::NodeRef;
+ using ChildIteratorType = typename GTraits::ChildIteratorType;
+
+ static std::string getNodeLabel(NodeRef Node, GraphType G) {
+ std::string LabelString =
+ (Twine("OrigId: ") + (Node->IsAllocation ? "Alloc" : "") +
+ Twine(Node->OrigStackOrAllocId))
+ .str();
+ LabelString += "\n";
+ if (Node->hasCall()) {
+ auto Func = G->NodeToCallingFunc.find(Node);
+ assert(Func != G->NodeToCallingFunc.end());
+ LabelString +=
+ G->getLabel(Func->second, Node->Call.call(), Node->Call.cloneNo());
+ } else {
+ LabelString += "null call";
+ if (Node->Recursive)
+ LabelString += " (recursive)";
+ else
+ LabelString += " (external)";
+ }
+ return LabelString;
+ }
+
+ static std::string getNodeAttributes(NodeRef Node, GraphType) {
+ std::string AttributeString = (Twine("tooltip=\"") + getNodeId(Node) + " " +
+ getContextIds(Node->ContextIds) + "\"")
+ .str();
+ AttributeString +=
+ (Twine(",fillcolor=\"") + getColor(Node->AllocTypes) + "\"").str();
+ if (Node->CloneOf) {
+ AttributeString += ",color=\"blue\"";
+ AttributeString += ",style=\"filled,bold,dashed\"";
+ } else
+ AttributeString += ",style=\"filled\"";
+ return AttributeString;
+ }
+
+ static std::string getEdgeAttributes(NodeRef, ChildIteratorType ChildIter,
+ GraphType) {
+ auto &Edge = *(ChildIter.getCurrent());
+ return (Twine("tooltip=\"") + getContextIds(Edge->ContextIds) + "\"" +
+ Twine(",fillcolor=\"") + getColor(Edge->AllocTypes) + "\"")
+ .str();
+ }
+
+ // Since the NodeOwner list includes nodes that are no longer connected to
+ // the graph, skip them here.
+ static bool isNodeHidden(NodeRef Node, GraphType) {
+ return Node->isRemoved();
+ }
+
+private:
+ static std::string getContextIds(const DenseSet<uint32_t> &ContextIds) {
+ std::string IdString = "ContextIds:";
+ if (ContextIds.size() < 100) {
+ std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end());
+ std::sort(SortedIds.begin(), SortedIds.end());
+ for (auto Id : SortedIds)
+ IdString += (" " + Twine(Id)).str();
+ } else {
+ IdString += (" (" + Twine(ContextIds.size()) + " ids)").str();
+ }
+ return IdString;
+ }
+
+ static std::string getColor(uint8_t AllocTypes) {
+ if (AllocTypes == (uint8_t)AllocationType::NotCold)
+ // Color "brown1" actually looks like a lighter red.
+ return "brown1";
+ if (AllocTypes == (uint8_t)AllocationType::Cold)
+ return "cyan";
+ if (AllocTypes ==
+ ((uint8_t)AllocationType::NotCold | (uint8_t)AllocationType::Cold))
+ // Lighter purple.
+ return "mediumorchid1";
+ return "gray";
+ }
+
+ static std::string getNodeId(NodeRef Node) {
+ std::stringstream SStream;
+ SStream << std::hex << "N0x" << (unsigned long long)Node;
+ std::string Result = SStream.str();
+ return Result;
+ }
+};
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::exportToDot(
+ std::string Label) const {
+ WriteGraph(this, "", false, Label,
+ DotFilePathPrefix + "ccg." + Label + ".dot");
+}
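+// For example, exportToDot("prestackupdate") writes the graph to
+// "<DotFilePathPrefix>ccg.prestackupdate.dot", where DotFilePathPrefix is the
+// dot file path prefix configured earlier in this file.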
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode *
+CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::moveEdgeToNewCalleeClone(
+ const std::shared_ptr<ContextEdge> &Edge, EdgeIter *CallerEdgeI) {
+ ContextNode *Node = Edge->Callee;
+ NodeOwner.push_back(
+ std::make_unique<ContextNode>(Node->IsAllocation, Node->Call));
+ ContextNode *Clone = NodeOwner.back().get();
+ Node->addClone(Clone);
+ assert(NodeToCallingFunc.count(Node));
+ NodeToCallingFunc[Clone] = NodeToCallingFunc[Node];
+ moveEdgeToExistingCalleeClone(Edge, Clone, CallerEdgeI, /*NewClone=*/true);
+ return Clone;
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::
+ moveEdgeToExistingCalleeClone(const std::shared_ptr<ContextEdge> &Edge,
+ ContextNode *NewCallee, EdgeIter *CallerEdgeI,
+ bool NewClone) {
+ // NewCallee and Edge's current callee must be clones of the same original
+ // node (Edge's current callee may be the original node too).
+ assert(NewCallee->getOrigNode() == Edge->Callee->getOrigNode());
+ auto &EdgeContextIds = Edge->getContextIds();
+ ContextNode *OldCallee = Edge->Callee;
+ if (CallerEdgeI)
+ *CallerEdgeI = OldCallee->CallerEdges.erase(*CallerEdgeI);
+ else
+ OldCallee->eraseCallerEdge(Edge.get());
+ Edge->Callee = NewCallee;
+ NewCallee->CallerEdges.push_back(Edge);
+ // Don't need to update Edge's context ids since we are simply reconnecting
+ // it.
+ set_subtract(OldCallee->ContextIds, EdgeContextIds);
+ NewCallee->ContextIds.insert(EdgeContextIds.begin(), EdgeContextIds.end());
+ NewCallee->AllocTypes |= Edge->AllocTypes;
+ OldCallee->AllocTypes = computeAllocType(OldCallee->ContextIds);
+ // OldCallee alloc type should be None iff its context id set is now empty.
+ assert((OldCallee->AllocTypes == (uint8_t)AllocationType::None) ==
+ OldCallee->ContextIds.empty());
+ // Now walk the old callee node's callee edges and move Edge's context ids
+ // over to the corresponding edge into the clone (which is created here if
+ // this is a newly created clone).
+ for (auto &OldCalleeEdge : OldCallee->CalleeEdges) {
+ // The context ids moving to the new callee are the subset of this edge's
+ // context ids and the context ids on the caller edge being moved.
+ DenseSet<uint32_t> EdgeContextIdsToMove =
+ set_intersection(OldCalleeEdge->getContextIds(), EdgeContextIds);
+ set_subtract(OldCalleeEdge->getContextIds(), EdgeContextIdsToMove);
+ OldCalleeEdge->AllocTypes =
+ computeAllocType(OldCalleeEdge->getContextIds());
+ if (!NewClone) {
+ // Update context ids / alloc type on corresponding edge to NewCallee.
+ // There is a chance this may not exist if we are reusing an existing
+ // clone, specifically during function assignment, where we would have
+ // removed none type edges after creating the clone. If we can't find
+ // a corresponding edge there, fall through to the cloning below.
+ if (auto *NewCalleeEdge =
+ NewCallee->findEdgeFromCallee(OldCalleeEdge->Callee)) {
+ NewCalleeEdge->getContextIds().insert(EdgeContextIdsToMove.begin(),
+ EdgeContextIdsToMove.end());
+ NewCalleeEdge->AllocTypes |= computeAllocType(EdgeContextIdsToMove);
+ continue;
+ }
+ }
+ auto NewEdge = std::make_shared<ContextEdge>(
+ OldCalleeEdge->Callee, NewCallee,
+ computeAllocType(EdgeContextIdsToMove), EdgeContextIdsToMove);
+ NewCallee->CalleeEdges.push_back(NewEdge);
+ NewEdge->Callee->CallerEdges.push_back(NewEdge);
+ }
+ if (VerifyCCG) {
+ checkNode<DerivedCCG, FuncTy, CallTy>(OldCallee, /*CheckEdges=*/false);
+ checkNode<DerivedCCG, FuncTy, CallTy>(NewCallee, /*CheckEdges=*/false);
+ for (const auto &OldCalleeEdge : OldCallee->CalleeEdges)
+ checkNode<DerivedCCG, FuncTy, CallTy>(OldCalleeEdge->Callee,
+ /*CheckEdges=*/false);
+ for (const auto &NewCalleeEdge : NewCallee->CalleeEdges)
+ checkNode<DerivedCCG, FuncTy, CallTy>(NewCalleeEdge->Callee,
+ /*CheckEdges=*/false);
+ }
+}
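+// Note on the invariant maintained above: after the move, the context ids
+// carried by Edge have been removed from OldCallee and added to NewCallee,
+// and for each of OldCallee's callee edges the intersecting ids have been
+// split out onto the corresponding (possibly newly created) edge from
+// NewCallee, so each node's ContextIds remain the union of its callee edges'
+// ContextIds (as verified by checkNode).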
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones() {
+ DenseSet<const ContextNode *> Visited;
+ for (auto &Entry : AllocationCallToContextNodeMap)
+ identifyClones(Entry.second, Visited);
+}
+
+ // Helper function to check whether an AllocType is cold, notcold, or both.
+bool checkColdOrNotCold(uint8_t AllocType) {
+ return (AllocType == (uint8_t)AllocationType::Cold) ||
+ (AllocType == (uint8_t)AllocationType::NotCold) ||
+ (AllocType ==
+ ((uint8_t)AllocationType::Cold | (uint8_t)AllocationType::NotCold));
+}
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones(
+ ContextNode *Node, DenseSet<const ContextNode *> &Visited) {
+ if (VerifyNodes)
+ checkNode<DerivedCCG, FuncTy, CallTy>(Node);
+ assert(!Node->CloneOf);
+
+ // If Node has a null call, then either it wasn't found in the module (regular
+ // LTO) or summary index (ThinLTO), or there were other conditions blocking
+ // cloning (e.g. recursion, calls multiple targets, etc).
+ // Do this here so that we don't try to recursively clone callers below, which
+ // isn't useful at least for this node.
+ if (!Node->hasCall())
+ return;
+
+#ifndef NDEBUG
+ auto Insert =
+#endif
+ Visited.insert(Node);
+ // We should not have visited this node yet.
+ assert(Insert.second);
+ // The recursive call to identifyClones may delete the current edge from the
+ // CallerEdges vector. Make a copy and iterate on that, which is simpler than
+ // passing in an iterator and having the recursive call erase from it. Other
+ // edges may also get removed during the recursion; those will have null
+ // Callee and Caller pointers (and are deleted later), so we skip them below.
+ {
+ auto CallerEdges = Node->CallerEdges;
+ for (auto &Edge : CallerEdges) {
+ // Skip any that have been removed by an earlier recursive call.
+ if (Edge->Callee == nullptr && Edge->Caller == nullptr) {
+ assert(!std::count(Node->CallerEdges.begin(), Node->CallerEdges.end(),
+ Edge));
+ continue;
+ }
+ // Ignore any caller we previously visited via another edge.
+ if (!Visited.count(Edge->Caller) && !Edge->Caller->CloneOf) {
+ identifyClones(Edge->Caller, Visited);
+ }
+ }
+ }
+
+ // Check if we reached an unambiguous call or have only a single caller.
+ if (hasSingleAllocType(Node->AllocTypes) || Node->CallerEdges.size() <= 1)
+ return;
+
+ // We need to clone.
+
+ // Try to keep the original version as alloc type NotCold. This will make
+ // cases with indirect calls or any other situation with an unknown call to
+ // the original function get the default behavior. We do this by sorting the
+ // CallerEdges of the Node we will clone by alloc type.
+ //
+ // Give NotCold edges the lowest sort priority so those edges are at the end
+ // of the caller edges vector and stay on the original version (since the
+ // code below clones greedily until it finds all remaining edges have the
+ // same type and leaves the remaining ones on the original Node).
+ //
+ // We shouldn't actually have any None type edges, so the sorting priority for
+ // that is arbitrary, and we assert in that case below.
+ const unsigned AllocTypeCloningPriority[] = {/*None*/ 3, /*NotCold*/ 4,
+ /*Cold*/ 1,
+ /*NotColdCold*/ 2};
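+ // For example, caller edges whose AllocTypes are Cold, NotCold|Cold, and
+ // NotCold sort in that order (priorities 1, 2, and 4 respectively), so the
+ // NotCold edge ends up last and stays on the original Node.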
+ std::stable_sort(Node->CallerEdges.begin(), Node->CallerEdges.end(),
+ [&](const std::shared_ptr<ContextEdge> &A,
+ const std::shared_ptr<ContextEdge> &B) {
+ assert(checkColdOrNotCold(A->AllocTypes) &&
+ checkColdOrNotCold(B->AllocTypes));
+
+ if (A->AllocTypes == B->AllocTypes)
+ // Use the first context id for each edge as a
+ // tie-breaker.
+ return *A->ContextIds.begin() < *B->ContextIds.begin();
+ return AllocTypeCloningPriority[A->AllocTypes] <
+ AllocTypeCloningPriority[B->AllocTypes];
+ });
+
+ assert(Node->AllocTypes != (uint8_t)AllocationType::None);
+
+ // Iterate until we find no more opportunities for disambiguating the alloc
+ // types via cloning. In most cases this loop will terminate once the Node
+ // has a single allocation type, in which case no more cloning is needed.
+ // We need to be able to remove Edge from CallerEdges, so need to adjust
+ // iterator inside the loop.
+ for (auto EI = Node->CallerEdges.begin(); EI != Node->CallerEdges.end();) {
+ auto CallerEdge = *EI;
+
+ // See if cloning the prior caller edge left this node with a single alloc
+ // type or a single caller. In that case no more cloning of Node is needed.
+ if (hasSingleAllocType(Node->AllocTypes) || Node->CallerEdges.size() <= 1)
+ break;
+
+ // Compute the node callee edge alloc types corresponding to the context ids
+ // for this caller edge.
+ std::vector<uint8_t> CalleeEdgeAllocTypesForCallerEdge;
+ CalleeEdgeAllocTypesForCallerEdge.reserve(Node->CalleeEdges.size());
+ for (auto &CalleeEdge : Node->CalleeEdges)
+ CalleeEdgeAllocTypesForCallerEdge.push_back(intersectAllocTypes(
+ CalleeEdge->getContextIds(), CallerEdge->getContextIds()));
+
+ // Don't clone if doing so will not disambiguate any alloc types amongst
+ // caller edges (including the callee edges that would be cloned).
+ // Otherwise we will simply move all edges to the clone.
+ //
+ // First check if by cloning we will disambiguate the caller allocation
+ // type from node's allocation type. Query allocTypeToUse so that we don't
+ // bother cloning to distinguish NotCold+Cold from NotCold. Note that
+ // neither of these should be None type.
+ //
+ // Then check if by cloning node at least one of the callee edges will be
+ // disambiguated by splitting out different context ids.
+ assert(CallerEdge->AllocTypes != (uint8_t)AllocationType::None);
+ assert(Node->AllocTypes != (uint8_t)AllocationType::None);
+ if (allocTypeToUse(CallerEdge->AllocTypes) ==
+ allocTypeToUse(Node->AllocTypes) &&
+ allocTypesMatch<DerivedCCG, FuncTy, CallTy>(
+ CalleeEdgeAllocTypesForCallerEdge, Node->CalleeEdges)) {
+ ++EI;
+ continue;
+ }
+
+ // First see if we can use an existing clone. Check each clone and its
+ // callee edges for matching alloc types.
+ ContextNode *Clone = nullptr;
+ for (auto *CurClone : Node->Clones) {
+ if (allocTypeToUse(CurClone->AllocTypes) !=
+ allocTypeToUse(CallerEdge->AllocTypes))
+ continue;
+
+ if (!allocTypesMatch<DerivedCCG, FuncTy, CallTy>(
+ CalleeEdgeAllocTypesForCallerEdge, CurClone->CalleeEdges))
+ continue;
+ Clone = CurClone;
+ break;
+ }
+
+ // The edge iterator is adjusted when we move the CallerEdge to the clone.
+ if (Clone)
+ moveEdgeToExistingCalleeClone(CallerEdge, Clone, &EI);
+ else
+ Clone = moveEdgeToNewCalleeClone(CallerEdge, &EI);
+
+ assert(EI == Node->CallerEdges.end() ||
+ Node->AllocTypes != (uint8_t)AllocationType::None);
+ // Sanity check that no alloc types on clone or its edges are None.
+ assert(Clone->AllocTypes != (uint8_t)AllocationType::None);
+ assert(llvm::none_of(
+ Clone->CallerEdges, [&](const std::shared_ptr<ContextEdge> &E) {
+ return E->AllocTypes == (uint8_t)AllocationType::None;
+ }));
+ }
+
+ // Cloning may have resulted in some cloned callee edges with type None,
+ // because they aren't carrying any contexts. Remove those edges.
+ for (auto *Clone : Node->Clones) {
+ removeNoneTypeCalleeEdges(Clone);
+ if (VerifyNodes)
+ checkNode<DerivedCCG, FuncTy, CallTy>(Clone);
+ }
+ // We should still have some context ids on the original Node.
+ assert(!Node->ContextIds.empty());
+
+ // Remove any callee edges that ended up with alloc type None after creating
+ // clones and updating callee edges.
+ removeNoneTypeCalleeEdges(Node);
+
+ // Sanity check that no alloc types on node or edges are None.
+ assert(Node->AllocTypes != (uint8_t)AllocationType::None);
+ assert(llvm::none_of(Node->CalleeEdges,
+ [&](const std::shared_ptr<ContextEdge> &E) {
+ return E->AllocTypes == (uint8_t)AllocationType::None;
+ }));
+ assert(llvm::none_of(Node->CallerEdges,
+ [&](const std::shared_ptr<ContextEdge> &E) {
+ return E->AllocTypes == (uint8_t)AllocationType::None;
+ }));
+
+ if (VerifyNodes)
+ checkNode<DerivedCCG, FuncTy, CallTy>(Node);
+}
+
+void ModuleCallsiteContextGraph::updateAllocationCall(
+ CallInfo &Call, AllocationType AllocType) {
+ std::string AllocTypeString = getAllocTypeAttributeString(AllocType);
+ auto A = llvm::Attribute::get(Call.call()->getFunction()->getContext(),
+ "memprof", AllocTypeString);
+ cast<CallBase>(Call.call())->addFnAttr(A);
+ OREGetter(Call.call()->getFunction())
+ .emit(OptimizationRemark(DEBUG_TYPE, "MemprofAttribute", Call.call())
+ << ore::NV("AllocationCall", Call.call()) << " in clone "
+ << ore::NV("Caller", Call.call()->getFunction())
+ << " marked with memprof allocation attribute "
+ << ore::NV("Attribute", AllocTypeString));
+}
+
+void IndexCallsiteContextGraph::updateAllocationCall(CallInfo &Call,
+ AllocationType AllocType) {
+ auto *AI = Call.call().dyn_cast<AllocInfo *>();
+ assert(AI);
+ assert(AI->Versions.size() > Call.cloneNo());
+ AI->Versions[Call.cloneNo()] = (uint8_t)AllocType;
+}
+
+void ModuleCallsiteContextGraph::updateCall(CallInfo &CallerCall,
+ FuncInfo CalleeFunc) {
+ if (CalleeFunc.cloneNo() > 0)
+ cast<CallBase>(CallerCall.call())->setCalledFunction(CalleeFunc.func());
+ OREGetter(CallerCall.call()->getFunction())
+ .emit(OptimizationRemark(DEBUG_TYPE, "MemprofCall", CallerCall.call())
+ << ore::NV("Call", CallerCall.call()) << " in clone "
+ << ore::NV("Caller", CallerCall.call()->getFunction())
+ << " assigned to call function clone "
+ << ore::NV("Callee", CalleeFunc.func()));
+}
+
+void IndexCallsiteContextGraph::updateCall(CallInfo &CallerCall,
+ FuncInfo CalleeFunc) {
+ auto *CI = CallerCall.call().dyn_cast<CallsiteInfo *>();
+ assert(CI &&
+ "Caller cannot be an allocation which should not have profiled calls");
+ assert(CI->Clones.size() > CallerCall.cloneNo());
+ CI->Clones[CallerCall.cloneNo()] = CalleeFunc.cloneNo();
+}
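+// Note that the index variant above does not rewrite any IR; the chosen
+// callee function clone number is simply recorded in the summary
+// CallsiteInfo::Clones slot for this call's clone, and the actual calls are
+// rewritten later in the backends (see applyImport below).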
+
+CallsiteContextGraph<ModuleCallsiteContextGraph, Function,
+ Instruction *>::FuncInfo
+ModuleCallsiteContextGraph::cloneFunctionForCallsite(
+ FuncInfo &Func, CallInfo &Call, std::map<CallInfo, CallInfo> &CallMap,
+ std::vector<CallInfo> &CallsWithMetadataInFunc, unsigned CloneNo) {
+ // Use existing LLVM facilities for cloning and obtaining the Call in the clone.
+ ValueToValueMapTy VMap;
+ auto *NewFunc = CloneFunction(Func.func(), VMap);
+ std::string Name = getMemProfFuncName(Func.func()->getName(), CloneNo);
+ assert(!Func.func()->getParent()->getFunction(Name));
+ NewFunc->setName(Name);
+ for (auto &Inst : CallsWithMetadataInFunc) {
+ // This map always has the initial version in it.
+ assert(Inst.cloneNo() == 0);
+ CallMap[Inst] = {cast<Instruction>(VMap[Inst.call()]), CloneNo};
+ }
+ OREGetter(Func.func())
+ .emit(OptimizationRemark(DEBUG_TYPE, "MemprofClone", Func.func())
+ << "created clone " << ore::NV("NewFunction", NewFunc));
+ return {NewFunc, CloneNo};
+}
+
+CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
+ IndexCall>::FuncInfo
+IndexCallsiteContextGraph::cloneFunctionForCallsite(
+ FuncInfo &Func, CallInfo &Call, std::map<CallInfo, CallInfo> &CallMap,
+ std::vector<CallInfo> &CallsWithMetadataInFunc, unsigned CloneNo) {
+ // Check how many clones we have of Call (and therefore function).
+ // The next clone number is the current size of the versions array.
+ // Confirm this matches the CloneNo provided by the caller, which is based on
+ // the number of function clones we have.
+ assert(CloneNo ==
+ (Call.call().is<AllocInfo *>()
+ ? Call.call().dyn_cast<AllocInfo *>()->Versions.size()
+ : Call.call().dyn_cast<CallsiteInfo *>()->Clones.size()));
+ // Walk all the instructions in this function. Create a new version for
+ // each (by adding an entry to the Versions/Clones summary array), and copy
+ // over the version being called for the function clone being cloned here.
+ // Additionally, add an entry to the CallMap for the new function clone,
+ // mapping the original call (clone 0, what is in CallsWithMetadataInFunc)
+ // to the new call clone.
+ for (auto &Inst : CallsWithMetadataInFunc) {
+ // This map always has the initial version in it.
+ assert(Inst.cloneNo() == 0);
+ if (auto *AI = Inst.call().dyn_cast<AllocInfo *>()) {
+ assert(AI->Versions.size() == CloneNo);
+ // We assign the allocation type later (in updateAllocationCall); just add
+ // an entry for it here.
+ AI->Versions.push_back(0);
+ } else {
+ auto *CI = Inst.call().dyn_cast<CallsiteInfo *>();
+ assert(CI && CI->Clones.size() == CloneNo);
+ // We assign the clone number later (in updateCall); just add an entry for
+ // it here.
+ CI->Clones.push_back(0);
+ }
+ CallMap[Inst] = {Inst.call(), CloneNo};
+ }
+ return {Func.func(), CloneNo};
+}
+
+// This method assigns cloned callsites to functions, cloning the functions as
+// needed. The assignment is greedy and proceeds roughly as follows:
+//
+// For each function Func:
+// For each call with graph Node having clones:
+// Initialize ClonesWorklist to Node and its clones
+// Initialize NodeCloneCount to 0
+// While ClonesWorklist is not empty:
+// Clone = pop front ClonesWorklist
+// NodeCloneCount++
+// If Func has been cloned less than NodeCloneCount times:
+// If NodeCloneCount is 1:
+// Assign Clone to original Func
+// Continue
+// Create a new function clone
+// If other callers not assigned to call a function clone yet:
+// Assign them to call new function clone
+// Continue
+// Assign any other caller calling the cloned version to new clone
+//
+// For each caller of Clone:
+// If caller is assigned to call a specific function clone:
+// If we cannot assign Clone to that function clone:
+// Create new callsite Clone NewClone
+// Add NewClone to ClonesWorklist
+// Continue
+// Assign Clone to existing caller's called function clone
+// Else:
+// If Clone not already assigned to a function clone:
+// Assign to first function clone without assignment
+// Assign caller to selected function clone
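+//
+// Illustrative (hypothetical) example of the above: suppose function Foo
+// contains an allocation callsite whose graph node was given one clone during
+// identifyClones (say cold contexts vs. notcold contexts). The first callsite
+// copy is assigned to the original Foo; the second exceeds the number of
+// existing function clones, so cloneFunctionForCallsite creates a new clone
+// of Foo, and the callers reaching the cold contexts are recorded in
+// CallsiteToCalleeFuncCloneMap as calling that new clone. Callers that were
+// already committed to a different function clone may force additional
+// callsite clones, which are pushed onto ClonesWorklist and handled the same
+// way.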
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::assignFunctions() {
+ bool Changed = false;
+
+ // Keep track of the assignment of nodes (callsites) to function clones they
+ // call.
+ DenseMap<ContextNode *, FuncInfo> CallsiteToCalleeFuncCloneMap;
+
+ // Update caller node to call function version CalleeFunc, by recording the
+ // assignment in CallsiteToCalleeFuncCloneMap.
+ auto RecordCalleeFuncOfCallsite = [&](ContextNode *Caller,
+ const FuncInfo &CalleeFunc) {
+ assert(Caller->hasCall());
+ CallsiteToCalleeFuncCloneMap[Caller] = CalleeFunc;
+ };
+
+ // Walk all functions for which we saw calls with memprof metadata, and handle
+ // cloning for each of its calls.
+ for (auto &[Func, CallsWithMetadata] : FuncToCallsWithMetadata) {
+ FuncInfo OrigFunc(Func);
+ // Map from each clone of OrigFunc to a map of remappings of each call of
+ // interest (from original uncloned call to the corresponding cloned call in
+ // that function clone).
+ std::map<FuncInfo, std::map<CallInfo, CallInfo>> FuncClonesToCallMap;
+ for (auto &Call : CallsWithMetadata) {
+ ContextNode *Node = getNodeForInst(Call);
+ // Skip call if we do not have a node for it (all uses of its stack ids
+ // were either on inlined chains or pruned from the MIBs), or if we did
+ // not create any clones for it.
+ if (!Node || Node->Clones.empty())
+ continue;
+ assert(Node->hasCall() &&
+ "Not having a call should have prevented cloning");
+
+ // Track the assignment of function clones to clones of the current
+ // callsite Node being handled.
+ std::map<FuncInfo, ContextNode *> FuncCloneToCurNodeCloneMap;
+
+ // Assign callsite version CallsiteClone to function version FuncClone,
+ // and also assign (possibly cloned) Call to CallsiteClone.
+ auto AssignCallsiteCloneToFuncClone = [&](const FuncInfo &FuncClone,
+ CallInfo &Call,
+ ContextNode *CallsiteClone,
+ bool IsAlloc) {
+ // Record the clone of callsite node assigned to this function clone.
+ FuncCloneToCurNodeCloneMap[FuncClone] = CallsiteClone;
+
+ assert(FuncClonesToCallMap.count(FuncClone));
+ std::map<CallInfo, CallInfo> &CallMap = FuncClonesToCallMap[FuncClone];
+ CallInfo CallClone(Call);
+ if (CallMap.count(Call))
+ CallClone = CallMap[Call];
+ CallsiteClone->setCall(CallClone);
+ };
+
+ // Keep track of the clones of callsite Node that need to be assigned to
+ // function clones. This list may be expanded in the loop body below if we
+ // find additional cloning is required.
+ std::deque<ContextNode *> ClonesWorklist;
+ // Ignore original Node if we moved all of its contexts to clones.
+ if (!Node->ContextIds.empty())
+ ClonesWorklist.push_back(Node);
+ ClonesWorklist.insert(ClonesWorklist.end(), Node->Clones.begin(),
+ Node->Clones.end());
+
+ // Now walk through all of the clones of this callsite Node that we need,
+ // and determine the assignment to a corresponding clone of the current
+ // function (creating new function clones as needed).
+ unsigned NodeCloneCount = 0;
+ while (!ClonesWorklist.empty()) {
+ ContextNode *Clone = ClonesWorklist.front();
+ ClonesWorklist.pop_front();
+ NodeCloneCount++;
+ if (VerifyNodes)
+ checkNode<DerivedCCG, FuncTy, CallTy>(Clone);
+
+ // Need to create a new function clone if we have more callsite clones
+ // than existing function clones, which would have been assigned to an
+ // earlier clone in the list (we assign callsite clones to function
+ // clones greedily).
+ if (FuncClonesToCallMap.size() < NodeCloneCount) {
+ // If this is the first callsite copy, assign to original function.
+ if (NodeCloneCount == 1) {
+ // Since FuncClonesToCallMap is empty in this case, no clones have
+ // been created for this function yet, and no callers should have
+ // been assigned a function clone for this callee node yet.
+ assert(llvm::none_of(
+ Clone->CallerEdges, [&](const std::shared_ptr<ContextEdge> &E) {
+ return CallsiteToCalleeFuncCloneMap.count(E->Caller);
+ }));
+ // Initialize with empty call map, assign Clone to original function
+ // and its callers, and skip to the next clone.
+ FuncClonesToCallMap[OrigFunc] = {};
+ AssignCallsiteCloneToFuncClone(
+ OrigFunc, Call, Clone,
+ AllocationCallToContextNodeMap.count(Call));
+ for (auto &CE : Clone->CallerEdges) {
+ // Ignore any caller that does not have a recorded callsite Call.
+ if (!CE->Caller->hasCall())
+ continue;
+ RecordCalleeFuncOfCallsite(CE->Caller, OrigFunc);
+ }
+ continue;
+ }
+
+ // First locate which copy of OrigFunc to clone again. If a caller
+ // of this callsite clone was already assigned to call a particular
+ // function clone, we need to redirect all of those callers to the
+ // new function clone, and update their other callees within this
+ // function.
+ FuncInfo PreviousAssignedFuncClone;
+ auto EI = llvm::find_if(
+ Clone->CallerEdges, [&](const std::shared_ptr<ContextEdge> &E) {
+ return CallsiteToCalleeFuncCloneMap.count(E->Caller);
+ });
+ bool CallerAssignedToCloneOfFunc = false;
+ if (EI != Clone->CallerEdges.end()) {
+ const std::shared_ptr<ContextEdge> &Edge = *EI;
+ PreviousAssignedFuncClone =
+ CallsiteToCalleeFuncCloneMap[Edge->Caller];
+ CallerAssignedToCloneOfFunc = true;
+ }
+
+ // Clone function and save it along with the CallInfo map created
+ // during cloning in the FuncClonesToCallMap.
+ std::map<CallInfo, CallInfo> NewCallMap;
+ unsigned CloneNo = FuncClonesToCallMap.size();
+ assert(CloneNo > 0 && "Clone 0 is the original function, which "
+ "should already exist in the map");
+ FuncInfo NewFuncClone = cloneFunctionForCallsite(
+ OrigFunc, Call, NewCallMap, CallsWithMetadata, CloneNo);
+ FuncClonesToCallMap.emplace(NewFuncClone, std::move(NewCallMap));
+ FunctionClonesAnalysis++;
+ Changed = true;
+
+ // If no caller callsites were already assigned to a clone of this
+ // function, we can simply assign this clone to the new func clone
+ // and update all callers to it, then skip to the next clone.
+ if (!CallerAssignedToCloneOfFunc) {
+ AssignCallsiteCloneToFuncClone(
+ NewFuncClone, Call, Clone,
+ AllocationCallToContextNodeMap.count(Call));
+ for (auto &CE : Clone->CallerEdges) {
+ // Ignore any caller that does not have a recorded callsite Call.
+ if (!CE->Caller->hasCall())
+ continue;
+ RecordCalleeFuncOfCallsite(CE->Caller, NewFuncClone);
+ }
+ continue;
+ }
+
+ // We may need to do additional node cloning in this case.
+ // Reset the CallsiteToCalleeFuncCloneMap entry for any callers
+ // that were previously assigned to call PreviousAssignedFuncClone,
+ // to record that they now call NewFuncClone.
+ for (auto CE : Clone->CallerEdges) {
+ // Ignore any caller that does not have a recorded callsite Call.
+ if (!CE->Caller->hasCall())
+ continue;
+
+ if (!CallsiteToCalleeFuncCloneMap.count(CE->Caller) ||
+ // We subsequently fall through to later handling that
+ // will perform any additional cloning required for
+ // callers that were calling other function clones.
+ CallsiteToCalleeFuncCloneMap[CE->Caller] !=
+ PreviousAssignedFuncClone)
+ continue;
+
+ RecordCalleeFuncOfCallsite(CE->Caller, NewFuncClone);
+
+ // If we are cloning a function that was already assigned to some
+ // callers, then essentially we are creating new callsite clones
+ // of the other callsites in that function that are reached by those
+ // callers. Clone the other callees of the current callsite's caller
+ // that were already assigned to PreviousAssignedFuncClone
+ // accordingly. This is important since we subsequently update the
+ // calls from the nodes in the graph and their assignments to callee
+ // functions recorded in CallsiteToCalleeFuncCloneMap.
+ for (auto CalleeEdge : CE->Caller->CalleeEdges) {
+ // Skip any that have been removed on an earlier iteration when
+ // cleaning up newly None type callee edges.
+ if (!CalleeEdge)
+ continue;
+ ContextNode *Callee = CalleeEdge->Callee;
+ // Skip the current callsite, since we are looking for other
+ // callsites the Caller calls, and also skip any callsite that does
+ // not have a recorded callsite Call.
+ if (Callee == Clone || !Callee->hasCall())
+ continue;
+ ContextNode *NewClone = moveEdgeToNewCalleeClone(CalleeEdge);
+ removeNoneTypeCalleeEdges(NewClone);
+ // Moving the edge may have resulted in some none type
+ // callee edges on the original Callee.
+ removeNoneTypeCalleeEdges(Callee);
+ assert(NewClone->AllocTypes != (uint8_t)AllocationType::None);
+ // If the Callee node was already assigned to call a specific
+ // function version, make sure its new clone is assigned to call
+ // that same function clone.
+ if (CallsiteToCalleeFuncCloneMap.count(Callee))
+ RecordCalleeFuncOfCallsite(
+ NewClone, CallsiteToCalleeFuncCloneMap[Callee]);
+ // Update NewClone with the new Call clone of this callsite's Call
+ // created for the new function clone created earlier.
+ // Recall that we have already ensured when building the graph
+ // that each caller can only call callsites within the same
+ // function, so we are guaranteed that Callee Call is in the
+ // current OrigFunc.
+ // CallMap is set up as indexed by original Call at clone 0.
+ CallInfo OrigCall(Callee->getOrigNode()->Call);
+ OrigCall.setCloneNo(0);
+ std::map<CallInfo, CallInfo> &CallMap =
+ FuncClonesToCallMap[NewFuncClone];
+ assert(CallMap.count(OrigCall));
+ CallInfo NewCall(CallMap[OrigCall]);
+ assert(NewCall);
+ NewClone->setCall(NewCall);
+ }
+ }
+ // Fall through to handling below to perform the recording of the
+ // function for this callsite clone. This enables handling of cases
+ // where the callers were assigned to different clones of a function.
+ }
+
+ // See if we can use an existing function clone. Walk through
+ // all caller edges to see if any have already been assigned to
+ // a clone of this callsite's function. If we can use it, do so. If not,
+ // because that function clone is already assigned to a different clone
+ // of this callsite, then we need to clone again.
+ // Basically, this checking is needed to handle the case where different
+ // caller functions/callsites may need versions of this function
+ // containing different mixes of callsite clones across the different
+ // callsites within the function. If that happens, we need to create
+ // additional function clones to handle the various combinations.
+ //
+ // Keep track of any new clones of this callsite created by the
+ // following loop, as well as any existing clone that we decided to
+ // assign this clone to.
+ std::map<FuncInfo, ContextNode *> FuncCloneToNewCallsiteCloneMap;
+ FuncInfo FuncCloneAssignedToCurCallsiteClone;
+ // We need to be able to remove Edge from CallerEdges, so need to adjust
+ // iterator in the loop.
+ for (auto EI = Clone->CallerEdges.begin();
+ EI != Clone->CallerEdges.end();) {
+ auto Edge = *EI;
+ // Ignore any caller that does not have a recorded callsite Call.
+ if (!Edge->Caller->hasCall()) {
+ EI++;
+ continue;
+ }
+ // If this caller is already assigned to call a version of OrigFunc, we
+ // need to ensure we can assign this callsite clone to that function clone.
+ if (CallsiteToCalleeFuncCloneMap.count(Edge->Caller)) {
+ FuncInfo FuncCloneCalledByCaller =
+ CallsiteToCalleeFuncCloneMap[Edge->Caller];
+ // First we need to confirm that this function clone is available
+ // for use by this callsite node clone.
+ //
+ // While FuncCloneToCurNodeCloneMap is built only for this Node and
+ // its callsite clones, one of those callsite clones X could have
+ // been assigned to the same function clone called by Edge's caller
+ // - if Edge's caller calls another callsite within Node's original
+ // function, and that callsite has another caller reaching clone X.
+ // We need to clone Node again in this case.
+ if ((FuncCloneToCurNodeCloneMap.count(FuncCloneCalledByCaller) &&
+ FuncCloneToCurNodeCloneMap[FuncCloneCalledByCaller] !=
+ Clone) ||
+ // Detect when we have multiple callers of this callsite that
+ // have already been assigned to specific, and different, clones
+ // of OrigFunc (due to other unrelated callsites in Func they
+ // reach via call contexts). Is this Clone of callsite Node
+ // assigned to a different clone of OrigFunc? If so, clone Node
+ // again.
+ (FuncCloneAssignedToCurCallsiteClone &&
+ FuncCloneAssignedToCurCallsiteClone !=
+ FuncCloneCalledByCaller)) {
+ // We need to use a different newly created callsite clone, in
+ // order to assign it to another new function clone on a
+ // subsequent iteration over the Clones array (adjusted below).
+ // Note we specifically do not reset the
+ // CallsiteToCalleeFuncCloneMap entry for this caller, so that
+ // when this new clone is processed later we know which version of
+ // the function to copy (so that other callsite clones we have
+ // assigned to that function clone are properly cloned over). See
+ // comments in the function cloning handling earlier.
+
+ // Check if we already have cloned this callsite again while
+ // walking through caller edges, for a caller calling the same
+ // function clone. If so, we can move this edge to that new clone
+ // rather than creating yet another new clone.
+ if (FuncCloneToNewCallsiteCloneMap.count(
+ FuncCloneCalledByCaller)) {
+ ContextNode *NewClone =
+ FuncCloneToNewCallsiteCloneMap[FuncCloneCalledByCaller];
+ moveEdgeToExistingCalleeClone(Edge, NewClone, &EI);
+ // Cleanup any none type edges cloned over.
+ removeNoneTypeCalleeEdges(NewClone);
+ } else {
+ // Create a new callsite clone.
+ ContextNode *NewClone = moveEdgeToNewCalleeClone(Edge, &EI);
+ removeNoneTypeCalleeEdges(NewClone);
+ FuncCloneToNewCallsiteCloneMap[FuncCloneCalledByCaller] =
+ NewClone;
+ // Add to list of clones and process later.
+ ClonesWorklist.push_back(NewClone);
+ assert(EI == Clone->CallerEdges.end() ||
+ Clone->AllocTypes != (uint8_t)AllocationType::None);
+ assert(NewClone->AllocTypes != (uint8_t)AllocationType::None);
+ }
+ // Moving the caller edge may have resulted in some none type
+ // callee edges.
+ removeNoneTypeCalleeEdges(Clone);
+ // We will handle the newly created callsite clone in a subsequent
+ // iteration over this Node's Clones. Continue here since we
+ // already adjusted iterator EI while moving the edge.
+ continue;
+ }
+
+ // Otherwise, we can use the function clone already assigned to this
+ // caller.
+ if (!FuncCloneAssignedToCurCallsiteClone) {
+ FuncCloneAssignedToCurCallsiteClone = FuncCloneCalledByCaller;
+ // Assign Clone to FuncCloneCalledByCaller
+ AssignCallsiteCloneToFuncClone(
+ FuncCloneCalledByCaller, Call, Clone,
+ AllocationCallToContextNodeMap.count(Call));
+ } else
+ // Don't need to do anything - callsite is already calling this
+ // function clone.
+ assert(FuncCloneAssignedToCurCallsiteClone ==
+ FuncCloneCalledByCaller);
+
+ } else {
+ // We have not already assigned this caller to a version of
+ // OrigFunc. Do the assignment now.
+
+ // First check if we have already assigned this callsite clone to a
+ // clone of OrigFunc for another caller during this iteration over
+ // its caller edges.
+ if (!FuncCloneAssignedToCurCallsiteClone) {
+ // Find first function in FuncClonesToCallMap without an assigned
+ // clone of this callsite Node. We should always have one
+ // available at this point due to the earlier cloning when the
+ // FuncClonesToCallMap size was smaller than the clone number.
+ for (auto &CF : FuncClonesToCallMap) {
+ if (!FuncCloneToCurNodeCloneMap.count(CF.first)) {
+ FuncCloneAssignedToCurCallsiteClone = CF.first;
+ break;
+ }
+ }
+ assert(FuncCloneAssignedToCurCallsiteClone);
+ // Assign Clone to FuncCloneAssignedToCurCallsiteClone
+ AssignCallsiteCloneToFuncClone(
+ FuncCloneAssignedToCurCallsiteClone, Call, Clone,
+ AllocationCallToContextNodeMap.count(Call));
+ } else
+ assert(FuncCloneToCurNodeCloneMap
+ [FuncCloneAssignedToCurCallsiteClone] == Clone);
+ // Update callers to record function version called.
+ RecordCalleeFuncOfCallsite(Edge->Caller,
+ FuncCloneAssignedToCurCallsiteClone);
+ }
+
+ EI++;
+ }
+ }
+ if (VerifyCCG) {
+ checkNode<DerivedCCG, FuncTy, CallTy>(Node);
+ for (const auto &PE : Node->CalleeEdges)
+ checkNode<DerivedCCG, FuncTy, CallTy>(PE->Callee);
+ for (const auto &CE : Node->CallerEdges)
+ checkNode<DerivedCCG, FuncTy, CallTy>(CE->Caller);
+ for (auto *Clone : Node->Clones) {
+ checkNode<DerivedCCG, FuncTy, CallTy>(Clone);
+ for (const auto &PE : Clone->CalleeEdges)
+ checkNode<DerivedCCG, FuncTy, CallTy>(PE->Callee);
+ for (const auto &CE : Clone->CallerEdges)
+ checkNode<DerivedCCG, FuncTy, CallTy>(CE->Caller);
+ }
+ }
+ }
+ }
+
+ auto UpdateCalls = [&](ContextNode *Node,
+ DenseSet<const ContextNode *> &Visited,
+ auto &&UpdateCalls) {
+ auto Inserted = Visited.insert(Node);
+ if (!Inserted.second)
+ return;
+
+ for (auto *Clone : Node->Clones)
+ UpdateCalls(Clone, Visited, UpdateCalls);
+
+ for (auto &Edge : Node->CallerEdges)
+ UpdateCalls(Edge->Caller, Visited, UpdateCalls);
+
+ // Skip if either no call to update, or if we ended up with no context ids
+ // (we moved all edges onto other clones).
+ if (!Node->hasCall() || Node->ContextIds.empty())
+ return;
+
+ if (Node->IsAllocation) {
+ updateAllocationCall(Node->Call, allocTypeToUse(Node->AllocTypes));
+ return;
+ }
+
+ if (!CallsiteToCalleeFuncCloneMap.count(Node))
+ return;
+
+ auto CalleeFunc = CallsiteToCalleeFuncCloneMap[Node];
+ updateCall(Node->Call, CalleeFunc);
+ };
+
+ // Performs DFS traversal starting from allocation nodes to update calls to
+ // reflect cloning decisions recorded earlier. For regular LTO this will
+ // update the actual calls in the IR to call the appropriate function clone
+ // (and add attributes to allocation calls), whereas for ThinLTO the decisions
+ // are recorded in the summary entries.
+ DenseSet<const ContextNode *> Visited;
+ for (auto &Entry : AllocationCallToContextNodeMap)
+ UpdateCalls(Entry.second, Visited, UpdateCalls);
+
+ return Changed;
+}
+
+static SmallVector<std::unique_ptr<ValueToValueMapTy>, 4> createFunctionClones(
+ Function &F, unsigned NumClones, Module &M, OptimizationRemarkEmitter &ORE,
+ std::map<const Function *, SmallPtrSet<const GlobalAlias *, 1>>
+ &FuncToAliasMap) {
+ // The first "clone" is the original copy; we should only call this if we
+ // need to create new clones.
+ assert(NumClones > 1);
+ SmallVector<std::unique_ptr<ValueToValueMapTy>, 4> VMaps;
+ VMaps.reserve(NumClones - 1);
+ FunctionsClonedThinBackend++;
+ for (unsigned I = 1; I < NumClones; I++) {
+ VMaps.emplace_back(std::make_unique<ValueToValueMapTy>());
+ auto *NewF = CloneFunction(&F, *VMaps.back());
+ FunctionClonesThinBackend++;
+ // Strip memprof and callsite metadata from clone as they are no longer
+ // needed.
+ for (auto &BB : *NewF) {
+ for (auto &Inst : BB) {
+ Inst.setMetadata(LLVMContext::MD_memprof, nullptr);
+ Inst.setMetadata(LLVMContext::MD_callsite, nullptr);
+ }
+ }
+ std::string Name = getMemProfFuncName(F.getName(), I);
+ auto *PrevF = M.getFunction(Name);
+ if (PrevF) {
+ // We might have created this when adjusting a callsite in another
+ // function. It should be a declaration.
+ assert(PrevF->isDeclaration());
+ NewF->takeName(PrevF);
+ PrevF->replaceAllUsesWith(NewF);
+ PrevF->eraseFromParent();
+ } else
+ NewF->setName(Name);
+ ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofClone", &F)
+ << "created clone " << ore::NV("NewFunction", NewF));
+
+ // Now handle aliases to this function, and clone those as well.
+ if (!FuncToAliasMap.count(&F))
+ continue;
+ for (auto *A : FuncToAliasMap[&F]) {
+ std::string Name = getMemProfFuncName(A->getName(), I);
+ auto *PrevA = M.getNamedAlias(Name);
+ auto *NewA = GlobalAlias::create(A->getValueType(),
+ A->getType()->getPointerAddressSpace(),
+ A->getLinkage(), Name, NewF);
+ NewA->copyAttributesFrom(A);
+ if (PrevA) {
+ // We might have created this when adjusting a callsite in another
+ // function. It should be a declaration.
+ assert(PrevA->isDeclaration());
+ NewA->takeName(PrevA);
+ PrevA->replaceAllUsesWith(NewA);
+ PrevA->eraseFromParent();
+ }
+ }
+ }
+ return VMaps;
+}
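+// Naming note: assuming the MemProfCloneSuffix defined earlier in this file
+// (".memprof."), getMemProfFuncName gives clone I of function "foo" the name
+// "foo.memprof.<I>", e.g. "foo.memprof.1", and any alias "bar" to "foo" gets
+// a matching "bar.memprof.<I>" created above that points at the new clone.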
+
+// Locate the summary for F. This is complicated by the fact that it might
+// have been internalized or promoted.
+static ValueInfo findValueInfoForFunc(const Function &F, const Module &M,
+ const ModuleSummaryIndex *ImportSummary) {
+ // FIXME: Ideally we would retain the original GUID in some fashion on the
+ // function (e.g. as metadata), but for now do our best to locate the
+ // summary without that information.
+ ValueInfo TheFnVI = ImportSummary->getValueInfo(F.getGUID());
+ if (!TheFnVI)
+ // See if the function was internalized, by checking the index directly
+ // with the original name (this avoids the name adjustment done by
+ // getGUID() for internal symbols).
+ TheFnVI = ImportSummary->getValueInfo(GlobalValue::getGUID(F.getName()));
+ if (TheFnVI)
+ return TheFnVI;
+ // Now query with the original name before any promotion was performed.
+ StringRef OrigName =
+ ModuleSummaryIndex::getOriginalNameBeforePromote(F.getName());
+ std::string OrigId = GlobalValue::getGlobalIdentifier(
+ OrigName, GlobalValue::InternalLinkage, M.getSourceFileName());
+ TheFnVI = ImportSummary->getValueInfo(GlobalValue::getGUID(OrigId));
+ if (TheFnVI)
+ return TheFnVI;
+ // Could be a promoted local imported from another module. We need to pass
+ // down more info here to find the original module id. For now, try with
+ // the OrigName which might have been stored in the OidGuidMap in the
+ // index. This would not work if there were same-named locals in multiple
+ // modules, however.
+ auto OrigGUID =
+ ImportSummary->getGUIDFromOriginalID(GlobalValue::getGUID(OrigName));
+ if (OrigGUID)
+ TheFnVI = ImportSummary->getValueInfo(OrigGUID);
+ return TheFnVI;
+}
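+// In summary, the lookup above tries, in order: the function's current GUID,
+// the GUID of its plain current name (covers internalized locals), the GUID
+// of its pre-promotion name qualified as a local of this module, and finally
+// the index's original-id mapping for promoted locals imported from another
+// module.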
+
+bool MemProfContextDisambiguation::applyImport(Module &M) {
+ assert(ImportSummary);
+ bool Changed = false;
+
+ auto IsMemProfClone = [](const Function &F) {
+ return F.getName().contains(MemProfCloneSuffix);
+ };
+
+ // We also need to clone any aliases that reference cloned functions, because
+ // the modified callsites may invoke via the alias. Keep track of the aliases
+ // for each function.
+ std::map<const Function *, SmallPtrSet<const GlobalAlias *, 1>>
+ FuncToAliasMap;
+ for (auto &A : M.aliases()) {
+ auto *Aliasee = A.getAliaseeObject();
+ if (auto *F = dyn_cast<Function>(Aliasee))
+ FuncToAliasMap[F].insert(&A);
+ }
+
+ for (auto &F : M) {
+ if (F.isDeclaration() || IsMemProfClone(F))
+ continue;
+
+ OptimizationRemarkEmitter ORE(&F);
+
+ SmallVector<std::unique_ptr<ValueToValueMapTy>, 4> VMaps;
+ bool ClonesCreated = false;
+ unsigned NumClonesCreated = 0;
+ auto CloneFuncIfNeeded = [&](unsigned NumClones) {
+ // We should at least have version 0 which is the original copy.
+ assert(NumClones > 0);
+ // If only one copy is needed, use the original.
+ if (NumClones == 1)
+ return;
+ // If we already performed cloning of this function, confirm that the
+ // requested number of clones matches (the thin link should ensure the
+ // number of clones for each constituent callsite is consistent within
+ // each function), before returning.
+ if (ClonesCreated) {
+ assert(NumClonesCreated == NumClones);
+ return;
+ }
+ VMaps = createFunctionClones(F, NumClones, M, ORE, FuncToAliasMap);
+ // The first "clone" is the original copy, which doesn't have a VMap.
+ assert(VMaps.size() == NumClones - 1);
+ Changed = true;
+ ClonesCreated = true;
+ NumClonesCreated = NumClones;
+ };
+
+ // Locate the summary for F.
+ ValueInfo TheFnVI = findValueInfoForFunc(F, M, ImportSummary);
+ // If not found, this could be an imported local (see comment in
+ // findValueInfoForFunc). Skip for now as it will be cloned in its original
+ // module (where it would have been promoted to global scope so should
+ // satisfy any reference in this module).
+ if (!TheFnVI)
+ continue;
+
+ auto *GVSummary =
+ ImportSummary->findSummaryInModule(TheFnVI, M.getModuleIdentifier());
+ if (!GVSummary)
+ // Must have been imported, use the first summary (might be multiple if
+ // this was a linkonce_odr).
+ GVSummary = TheFnVI.getSummaryList().front().get();
+
+ // If this was an imported alias skip it as we won't have the function
+ // summary, and it should be cloned in the original module.
+ if (isa<AliasSummary>(GVSummary))
+ continue;
+
+ auto *FS = cast<FunctionSummary>(GVSummary->getBaseObject());
+
+ if (FS->allocs().empty() && FS->callsites().empty())
+ continue;
+
+ auto SI = FS->callsites().begin();
+ auto AI = FS->allocs().begin();
+
+ // Assume for now that the instructions are in the exact same order
+ // as when the summary was created, but confirm this is correct by
+ // matching the stack ids.
+ for (auto &BB : F) {
+ for (auto &I : BB) {
+ auto *CB = dyn_cast<CallBase>(&I);
+ // Same handling as when creating module summary.
+ if (!mayHaveMemprofSummary(CB))
+ continue;
+
+ CallStack<MDNode, MDNode::op_iterator> CallsiteContext(
+ I.getMetadata(LLVMContext::MD_callsite));
+ auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof);
+
+ // Include allocs that were already assigned a memprof function
+ // attribute in the statistics.
+ if (CB->getAttributes().hasFnAttr("memprof")) {
+ assert(!MemProfMD);
+ CB->getAttributes().getFnAttr("memprof").getValueAsString() == "cold"
+ ? AllocTypeColdThinBackend++
+ : AllocTypeNotColdThinBackend++;
+ OrigAllocsThinBackend++;
+ AllocVersionsThinBackend++;
+ if (!MaxAllocVersionsThinBackend)
+ MaxAllocVersionsThinBackend = 1;
+ // Remove any remaining callsite metadata; we can then skip the rest of
+ // the handling for this instruction, since no cloning is needed.
+ I.setMetadata(LLVMContext::MD_callsite, nullptr);
+ continue;
+ }
+
+ if (MemProfMD) {
+ // Consult the next alloc node.
+ assert(AI != FS->allocs().end());
+ auto &AllocNode = *(AI++);
+
+ // Sanity check that the MIB stack ids match between the summary and
+ // instruction metadata.
+ auto MIBIter = AllocNode.MIBs.begin();
+ for (auto &MDOp : MemProfMD->operands()) {
+ assert(MIBIter != AllocNode.MIBs.end());
+ LLVM_ATTRIBUTE_UNUSED auto StackIdIndexIter =
+ MIBIter->StackIdIndices.begin();
+ auto *MIBMD = cast<const MDNode>(MDOp);
+ MDNode *StackMDNode = getMIBStackNode(MIBMD);
+ assert(StackMDNode);
+ SmallVector<unsigned> StackIdsFromMetadata;
+ CallStack<MDNode, MDNode::op_iterator> StackContext(StackMDNode);
+ for (auto ContextIter =
+ StackContext.beginAfterSharedPrefix(CallsiteContext);
+ ContextIter != StackContext.end(); ++ContextIter) {
+ // If this is a direct recursion, simply skip the duplicate
+ // entries, to be consistent with how the summary ids were
+ // generated during ModuleSummaryAnalysis.
+ if (!StackIdsFromMetadata.empty() &&
+ StackIdsFromMetadata.back() == *ContextIter)
+ continue;
+ assert(StackIdIndexIter != MIBIter->StackIdIndices.end());
+ assert(ImportSummary->getStackIdAtIndex(*StackIdIndexIter) ==
+ *ContextIter);
+ StackIdIndexIter++;
+ }
+ MIBIter++;
+ }
+
+ // Perform cloning if not yet done.
+ CloneFuncIfNeeded(/*NumClones=*/AllocNode.Versions.size());
+
+ OrigAllocsThinBackend++;
+ AllocVersionsThinBackend += AllocNode.Versions.size();
+ if (MaxAllocVersionsThinBackend < AllocNode.Versions.size())
+ MaxAllocVersionsThinBackend = AllocNode.Versions.size();
+
+ // If there is only one version, we didn't end up considering this
+ // function for cloning; in that case the alloc will still have None
+ // type or should have gotten the default NotCold. Skip it, but only
+ // after calling the clone helper, since that performs sanity checks
+ // confirming we haven't yet decided that cloning is needed.
+ if (AllocNode.Versions.size() == 1) {
+ assert((AllocationType)AllocNode.Versions[0] ==
+ AllocationType::NotCold ||
+ (AllocationType)AllocNode.Versions[0] ==
+ AllocationType::None);
+ UnclonableAllocsThinBackend++;
+ continue;
+ }
+
+ // All versions should have a singular allocation type.
+ assert(llvm::none_of(AllocNode.Versions, [](uint8_t Type) {
+ return Type == ((uint8_t)AllocationType::NotCold |
+ (uint8_t)AllocationType::Cold);
+ }));
+
+ // Update the allocation types per the summary info.
+ for (unsigned J = 0; J < AllocNode.Versions.size(); J++) {
+ // Ignore any that didn't get an assigned allocation type.
+ if (AllocNode.Versions[J] == (uint8_t)AllocationType::None)
+ continue;
+ AllocationType AllocTy = (AllocationType)AllocNode.Versions[J];
+ AllocTy == AllocationType::Cold ? AllocTypeColdThinBackend++
+ : AllocTypeNotColdThinBackend++;
+ std::string AllocTypeString = getAllocTypeAttributeString(AllocTy);
+ auto A = llvm::Attribute::get(F.getContext(), "memprof",
+ AllocTypeString);
+ CallBase *CBClone;
+ // Copy 0 is the original function.
+ if (!J)
+ CBClone = CB;
+ else
+ // Since VMaps are only created for new clones, we index with
+ // clone J-1 (J==0 is the original clone and does not have a VMaps
+ // entry).
+ CBClone = cast<CallBase>((*VMaps[J - 1])[CB]);
+ CBClone->addFnAttr(A);
+ ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofAttribute", CBClone)
+ << ore::NV("AllocationCall", CBClone) << " in clone "
+ << ore::NV("Caller", CBClone->getFunction())
+ << " marked with memprof allocation attribute "
+ << ore::NV("Attribute", AllocTypeString));
+ }
+ } else if (!CallsiteContext.empty()) {
+ // Consult the next callsite node.
+ assert(SI != FS->callsites().end());
+ auto &StackNode = *(SI++);
+
+#ifndef NDEBUG
+ // Sanity check that the stack ids match between the summary and
+ // instruction metadata.
+ auto StackIdIndexIter = StackNode.StackIdIndices.begin();
+ for (auto StackId : CallsiteContext) {
+ assert(StackIdIndexIter != StackNode.StackIdIndices.end());
+ assert(ImportSummary->getStackIdAtIndex(*StackIdIndexIter) ==
+ StackId);
+ StackIdIndexIter++;
+ }
+#endif
+
+ // Perform cloning if not yet done.
+ CloneFuncIfNeeded(/*NumClones=*/StackNode.Clones.size());
+
+ // Should have skipped indirect calls via mayHaveMemprofSummary.
+ assert(CB->getCalledFunction());
+ assert(!IsMemProfClone(*CB->getCalledFunction()));
+
+ // Update the calls per the summary info.
+ // Save orig name since it gets updated in the first iteration
+ // below.
+ auto CalleeOrigName = CB->getCalledFunction()->getName();
+ for (unsigned J = 0; J < StackNode.Clones.size(); J++) {
+ // Do nothing if this version calls the original version of its
+ // callee.
+ if (!StackNode.Clones[J])
+ continue;
+ auto NewF = M.getOrInsertFunction(
+ getMemProfFuncName(CalleeOrigName, StackNode.Clones[J]),
+ CB->getCalledFunction()->getFunctionType());
+ CallBase *CBClone;
+ // Copy 0 is the original function.
+ if (!J)
+ CBClone = CB;
+ else
+ CBClone = cast<CallBase>((*VMaps[J - 1])[CB]);
+ CBClone->setCalledFunction(NewF);
+ ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofCall", CBClone)
+ << ore::NV("Call", CBClone) << " in clone "
+ << ore::NV("Caller", CBClone->getFunction())
+ << " assigned to call function clone "
+ << ore::NV("Callee", NewF.getCallee()));
+ }
+ }
+ // Memprof and callsite metadata on memory allocations are no longer needed.
+ I.setMetadata(LLVMContext::MD_memprof, nullptr);
+ I.setMetadata(LLVMContext::MD_callsite, nullptr);
+ }
+ }
+ }
+
+ return Changed;
+}
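For orientation, the per-version loops above map each summarized allocation or callsite version onto the matching function clone: version 0 is the original function, while version J > 0 is reached through VMaps[J - 1]. The following is a minimal, self-contained sketch of the allocation-attribute half of that mapping (plain C++, not the LLVM API; the enum values and the "cold"/"notcold" strings are assumptions used only for illustration):

#include <cstdint>
#include <cstdio>
#include <vector>

enum class AllocationType : uint8_t { None = 0, NotCold = 1, Cold = 2 };

int main() {
  // One entry per clone of the enclosing function, as in AllocNode.Versions.
  std::vector<uint8_t> Versions = {(uint8_t)AllocationType::NotCold,
                                   (uint8_t)AllocationType::Cold,
                                   (uint8_t)AllocationType::None};
  for (unsigned J = 0; J < Versions.size(); ++J) {
    if (Versions[J] == (uint8_t)AllocationType::None)
      continue; // this version keeps no attribute
    const char *Attr =
        (AllocationType)Versions[J] == AllocationType::Cold ? "cold" : "notcold";
    // Clone 0 is the original call site; clone J is looked up via VMaps[J - 1].
    std::printf("clone %u gets memprof=\"%s\"\n", J, Attr);
  }
  return 0;
}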
+
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::process() {
+ if (DumpCCG) {
+ dbgs() << "CCG before cloning:\n";
+ dbgs() << *this;
+ }
+ if (ExportToDot)
+ exportToDot("postbuild");
+
+ if (VerifyCCG) {
+ check();
+ }
+
+ identifyClones();
+
+ if (VerifyCCG) {
+ check();
+ }
+
+ if (DumpCCG) {
+ dbgs() << "CCG after cloning:\n";
+ dbgs() << *this;
+ }
+ if (ExportToDot)
+ exportToDot("cloned");
+
+ bool Changed = assignFunctions();
+
+ if (DumpCCG) {
+ dbgs() << "CCG after assigning function clones:\n";
+ dbgs() << *this;
+ }
+ if (ExportToDot)
+ exportToDot("clonefuncassign");
+
+ return Changed;
+}
+
+bool MemProfContextDisambiguation::processModule(
+ Module &M,
+ function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
+
+ // If we have an import summary, then the cloning decisions were made during
+ // the thin link on the index. Apply them and return.
+ if (ImportSummary)
+ return applyImport(M);
+
+ // TODO: If/when other types of memprof cloning are enabled beyond just for
+ // hot and cold, we will need to change this to individually control the
+ // AllocationType passed to addStackNodesForMIB during CCG construction.
+ // Note that we specifically check this after applying imports above, so that
+ // the option doesn't need to be passed to distributed ThinLTO backend
+ // clang processes, which won't necessarily have visibility into the linker
+ // dependences. Instead the information is communicated from the LTO link to
+ // the backends via the combined summary index.
+ if (!SupportsHotColdNew)
+ return false;
+
+ ModuleCallsiteContextGraph CCG(M, OREGetter);
+ return CCG.process();
+}
+
+MemProfContextDisambiguation::MemProfContextDisambiguation(
+ const ModuleSummaryIndex *Summary)
+ : ImportSummary(Summary) {
+ if (ImportSummary) {
+ // The MemProfImportSummary should only be used for testing ThinLTO
+ // distributed backend handling via opt, in which case we don't have a
+ // summary from the pass pipeline.
+ assert(MemProfImportSummary.empty());
+ return;
+ }
+ if (MemProfImportSummary.empty())
+ return;
+
+ auto ReadSummaryFile =
+ errorOrToExpected(MemoryBuffer::getFile(MemProfImportSummary));
+ if (!ReadSummaryFile) {
+ logAllUnhandledErrors(ReadSummaryFile.takeError(), errs(),
+ "Error loading file '" + MemProfImportSummary +
+ "': ");
+ return;
+ }
+ auto ImportSummaryForTestingOrErr = getModuleSummaryIndex(**ReadSummaryFile);
+ if (!ImportSummaryForTestingOrErr) {
+ logAllUnhandledErrors(ImportSummaryForTestingOrErr.takeError(), errs(),
+ "Error parsing file '" + MemProfImportSummary +
+ "': ");
+ return;
+ }
+ ImportSummaryForTesting = std::move(*ImportSummaryForTestingOrErr);
+ ImportSummary = ImportSummaryForTesting.get();
+}
+
+PreservedAnalyses MemProfContextDisambiguation::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
+ return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
+ };
+ if (!processModule(M, OREGetter))
+ return PreservedAnalyses::all();
+ return PreservedAnalyses::none();
+}
+
+void MemProfContextDisambiguation::run(
+ ModuleSummaryIndex &Index,
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ isPrevailing) {
+ // TODO: If/when other types of memprof cloning are enabled beyond just for
+ // hot and cold, we will need to change this to individually control the
+ // AllocationType passed to addStackNodesForMIB during CCG construction.
+ // The index was set from the option, so these should be in sync.
+ assert(Index.withSupportsHotColdNew() == SupportsHotColdNew);
+ if (!SupportsHotColdNew)
+ return;
+
+ IndexCallsiteContextGraph CCG(Index, isPrevailing);
+ CCG.process();
+}
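Stepping back from the patch itself, the PreservedAnalyses-returning run() above follows the standard new-pass-manager pattern of reaching per-function analyses from a module pass through the FunctionAnalysisManagerModuleProxy. Below is a hedged sketch of that pattern; the pass name is hypothetical and the skeleton is illustrative only, not part of the change.

#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
using namespace llvm;

struct ExampleRemarkPass : PassInfoMixin<ExampleRemarkPass> {
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
    // The proxy hands back the inner FunctionAnalysisManager.
    auto &FAM =
        AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
    for (Function &F : M) {
      if (F.isDeclaration())
        continue;
      // Per-function results are fetched lazily, as the OREGetter lambda does.
      OptimizationRemarkEmitter &ORE =
          FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
      (void)ORE; // a real pass would emit remarks here
    }
    return PreservedAnalyses::all();
  }
};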
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index 590f62ca58dd..feda5d6459cb 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -112,8 +112,6 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -294,34 +292,8 @@ private:
// there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
};
-
-class MergeFunctionsLegacyPass : public ModulePass {
-public:
- static char ID;
-
- MergeFunctionsLegacyPass(): ModulePass(ID) {
- initializeMergeFunctionsLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- MergeFunctions MF;
- return MF.runOnModule(M);
- }
-};
-
} // end anonymous namespace
-char MergeFunctionsLegacyPass::ID = 0;
-INITIALIZE_PASS(MergeFunctionsLegacyPass, "mergefunc",
- "Merge Functions", false, false)
-
-ModulePass *llvm::createMergeFunctionsPass() {
- return new MergeFunctionsLegacyPass();
-}
-
PreservedAnalyses MergeFunctionsPass::run(Module &M,
ModuleAnalysisManager &AM) {
MergeFunctions MF;
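With the legacy wrapper and createMergeFunctionsPass() removed above, the pass is only reachable through the new pass manager. A sketch of scheduling it programmatically is shown below (standard PassBuilder boilerplate, included only for orientation); from the command line the equivalent is believed to be `opt -passes=mergefunc`.

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/IPO/MergeFunctions.h"
using namespace llvm;

static void runMergeFunctions(Module &M) {
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  ModulePassManager MPM;
  MPM.addPass(MergeFunctionsPass()); // the remaining, new-PM entry point
  MPM.run(M, MAM);
}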
diff --git a/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
index ee382657f5e6..5e91ab80d750 100644
--- a/llvm/lib/Transforms/IPO/ModuleInliner.cpp
+++ b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
@@ -138,17 +138,12 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
//
// TODO: There is a huge amount of duplicate code between the module inliner and
// the SCC inliner, which needs some refactoring.
- auto Calls = getInlineOrder(FAM, Params);
+ auto Calls = getInlineOrder(FAM, Params, MAM, M);
assert(Calls != nullptr && "Expected an initialized InlineOrder");
// Populate the initial list of calls in this module.
for (Function &F : M) {
auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
- // We want to generally process call sites top-down in order for
- // simplifications stemming from replacing the call with the returned value
- // after inlining to be visible to subsequent inlining decisions.
- // FIXME: Using instructions sequence is a really bad way to do this.
- // Instead we should do an actual RPO walk of the function body.
for (Instruction &I : instructions(F))
if (auto *CB = dyn_cast<CallBase>(&I))
if (Function *Callee = CB->getCalledFunction()) {
@@ -213,7 +208,7 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
// Setup the data structure used to plumb customization into the
// `InlineFunction` routine.
InlineFunctionInfo IFI(
- /*cg=*/nullptr, GetAssumptionCache, PSI,
+ GetAssumptionCache, PSI,
&FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())),
&FAM.getResult<BlockFrequencyAnalysis>(Callee));
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index bee154dab10f..588f3901e3cb 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -22,8 +22,10 @@
#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
@@ -36,6 +38,8 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instruction.h"
@@ -44,7 +48,7 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
-#include "llvm/InitializePasses.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/IPO/Attributor.h"
@@ -188,9 +192,9 @@ struct AAICVTracker;
struct OMPInformationCache : public InformationCache {
OMPInformationCache(Module &M, AnalysisGetter &AG,
BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
- KernelSet &Kernels)
+ bool OpenMPPostLink)
: InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
- Kernels(Kernels) {
+ OpenMPPostLink(OpenMPPostLink) {
OMPBuilder.initialize();
initializeRuntimeFunctions(M);
@@ -417,7 +421,7 @@ struct OMPInformationCache : public InformationCache {
// TODO: We directly convert uses into proper calls and unknown uses.
for (Use &U : RFI.Declaration->uses()) {
if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
- if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) {
+ if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
++NumUses;
}
@@ -448,6 +452,24 @@ struct OMPInformationCache : public InformationCache {
CI->setCallingConv(Fn->getCallingConv());
}
+ // Helper function to determine if it's legal to create a call to the runtime
+ // functions.
+ bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
+ // We can always emit calls if we haven't yet linked in the runtime.
+ if (!OpenMPPostLink)
+ return true;
+
+ // Once the runtime has already been linked in, we cannot emit calls to
+ // any undefined functions.
+ for (RuntimeFunction Fn : Fns) {
+ RuntimeFunctionInfo &RFI = RFIs[Fn];
+
+ if (RFI.Declaration && RFI.Declaration->isDeclaration())
+ return false;
+ }
+ return true;
+ }
+
/// Helper to initialize all runtime function information for those defined
/// in OpenMPKinds.def.
void initializeRuntimeFunctions(Module &M) {
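The runtimeFnsAvailable() helper added above encodes a simple rule: before the device runtime has been linked, new runtime calls may always be introduced; afterwards, every required function must already have a definition. A plain-C++ sketch of that rule follows (the struct and names are illustrative, not the OpenMPOpt API):

#include <vector>

struct RuntimeFnInfo {
  bool Declared = false;      // a declaration exists in the module
  bool HasDefinition = false; // a body was linked in
};

// Mirror of the rule: pre-link we may always emit calls; post-link every
// required runtime function must already be defined.
static bool runtimeFnsAvailable(bool OpenMPPostLink,
                                const std::vector<RuntimeFnInfo> &Fns) {
  if (!OpenMPPostLink)
    return true;
  for (const RuntimeFnInfo &Fn : Fns)
    if (Fn.Declared && !Fn.HasDefinition)
      return false;
  return true;
}

int main() {
  std::vector<RuntimeFnInfo> Needed = {{true, true}, {true, false}};
  // Post-link with one undefined runtime function: not available.
  return runtimeFnsAvailable(/*OpenMPPostLink=*/true, Needed) ? 0 : 1;
}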
@@ -518,11 +540,11 @@ struct OMPInformationCache : public InformationCache {
// TODO: We should attach the attributes defined in OMPKinds.def.
}
- /// Collection of known kernels (\see Kernel) in the module.
- KernelSet &Kernels;
-
/// Collection of known OpenMP runtime functions.
DenseSet<const Function *> RTLFunctions;
+
+ /// Indicates if we have already linked in the OpenMP device library.
+ bool OpenMPPostLink = false;
};
template <typename Ty, bool InsertInvalidates = true>
@@ -808,7 +830,7 @@ struct OpenMPOpt {
return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
}
- /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
+ /// Run all OpenMP optimizations on the underlying SCC.
bool run(bool IsModulePass) {
if (SCC.empty())
return false;
@@ -816,8 +838,7 @@ struct OpenMPOpt {
bool Changed = false;
LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
- << " functions in a slice with "
- << OMPInfoCache.ModuleSlice.size() << " functions\n");
+ << " functions\n");
if (IsModulePass) {
Changed |= runAttributor(IsModulePass);
@@ -882,7 +903,7 @@ struct OpenMPOpt {
/// Print OpenMP GPU kernels for testing.
void printKernels() const {
for (Function *F : SCC) {
- if (!OMPInfoCache.Kernels.count(F))
+ if (!omp::isKernel(*F))
continue;
auto Remark = [&](OptimizationRemarkAnalysis ORA) {
@@ -1412,7 +1433,10 @@ private:
Changed |= WasSplit;
return WasSplit;
};
- RFI.foreachUse(SCC, SplitMemTransfers);
+ if (OMPInfoCache.runtimeFnsAvailable(
+ {OMPRTL___tgt_target_data_begin_mapper_issue,
+ OMPRTL___tgt_target_data_begin_mapper_wait}))
+ RFI.foreachUse(SCC, SplitMemTransfers);
return Changed;
}
@@ -1681,37 +1705,27 @@ private:
};
if (!ReplVal) {
- for (Use *U : *UV)
+ auto *DT =
+ OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
+ if (!DT)
+ return false;
+ Instruction *IP = nullptr;
+ for (Use *U : *UV) {
if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
+ if (IP)
+ IP = DT->findNearestCommonDominator(IP, CI);
+ else
+ IP = CI;
if (!CanBeMoved(*CI))
continue;
-
- // If the function is a kernel, dedup will move
- // the runtime call right after the kernel init callsite. Otherwise,
- // it will move it to the beginning of the caller function.
- if (isKernel(F)) {
- auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
- auto *KernelInitUV = KernelInitRFI.getUseVector(F);
-
- if (KernelInitUV->empty())
- continue;
-
- assert(KernelInitUV->size() == 1 &&
- "Expected a single __kmpc_target_init in kernel\n");
-
- CallInst *KernelInitCI =
- getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
- assert(KernelInitCI &&
- "Expected a call to __kmpc_target_init in kernel\n");
-
- CI->moveAfter(KernelInitCI);
- } else
- CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
- ReplVal = CI;
- break;
+ if (!ReplVal)
+ ReplVal = CI;
}
+ }
if (!ReplVal)
return false;
+ assert(IP && "Expected insertion point!");
+ cast<Instruction>(ReplVal)->moveBefore(IP);
}
// If we use a call as a replacement value we need to make sure the ident is
@@ -1809,9 +1823,6 @@ private:
///
///{{
- /// Check if \p F is a kernel, hence entry point for target offloading.
- bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
-
/// Cache to remember the unique kernel for a function.
DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
@@ -1920,7 +1931,8 @@ public:
};
Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
- if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F))
+ if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
+ !OMPInfoCache.CGSCC->contains(&F))
return nullptr;
// Use a scope to keep the lifetime of the CachedKernel short.
@@ -2095,12 +2107,6 @@ struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
using Base = StateWrapper<BooleanState, AbstractAttribute>;
AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
- void initialize(Attributor &A) override {
- Function *F = getAnchorScope();
- if (!F || !A.isFunctionIPOAmendable(*F))
- indicatePessimisticFixpoint();
- }
-
/// Returns true if value is assumed to be tracked.
bool isAssumedTracked() const { return getAssumed(); }
@@ -2146,7 +2152,9 @@ struct AAICVTrackerFunction : public AAICVTracker {
: AAICVTracker(IRP, A) {}
// FIXME: come up with better string.
- const std::string getAsStr() const override { return "ICVTrackerFunction"; }
+ const std::string getAsStr(Attributor *) const override {
+ return "ICVTrackerFunction";
+ }
// FIXME: come up with some stats.
void trackStatistics() const override {}
@@ -2242,11 +2250,12 @@ struct AAICVTrackerFunction : public AAICVTracker {
if (CalledFunction->isDeclaration())
return nullptr;
- const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
+ const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
*this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
- if (ICVTrackingAA.isAssumedTracked()) {
- std::optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
+ if (ICVTrackingAA->isAssumedTracked()) {
+ std::optional<Value *> URV =
+ ICVTrackingAA->getUniqueReplacementValue(ICV);
if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
OMPInfoCache)))
return URV;
@@ -2337,7 +2346,7 @@ struct AAICVTrackerFunctionReturned : AAICVTracker {
: AAICVTracker(IRP, A) {}
// FIXME: come up with better string.
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *) const override {
return "ICVTrackerFunctionReturned";
}
@@ -2362,10 +2371,10 @@ struct AAICVTrackerFunctionReturned : AAICVTracker {
ChangeStatus updateImpl(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED;
- const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
+ const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
- if (!ICVTrackingAA.isAssumedTracked())
+ if (!ICVTrackingAA->isAssumedTracked())
return indicatePessimisticFixpoint();
for (InternalControlVar ICV : TrackableICVs) {
@@ -2374,7 +2383,7 @@ struct AAICVTrackerFunctionReturned : AAICVTracker {
auto CheckReturnInst = [&](Instruction &I) {
std::optional<Value *> NewReplVal =
- ICVTrackingAA.getReplacementValue(ICV, &I, A);
+ ICVTrackingAA->getReplacementValue(ICV, &I, A);
// If we found a second ICV value there is no unique returned value.
if (UniqueICVValue && UniqueICVValue != NewReplVal)
@@ -2407,9 +2416,7 @@ struct AAICVTrackerCallSite : AAICVTracker {
: AAICVTracker(IRP, A) {}
void initialize(Attributor &A) override {
- Function *F = getAnchorScope();
- if (!F || !A.isFunctionIPOAmendable(*F))
- indicatePessimisticFixpoint();
+ assert(getAnchorScope() && "Expected anchor function");
// We only initialize this AA for getters, so we need to know which ICV it
// gets.
@@ -2438,7 +2445,9 @@ struct AAICVTrackerCallSite : AAICVTracker {
}
// FIXME: come up with better string.
- const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
+ const std::string getAsStr(Attributor *) const override {
+ return "ICVTrackerCallSite";
+ }
// FIXME: come up with some stats.
void trackStatistics() const override {}
@@ -2447,15 +2456,15 @@ struct AAICVTrackerCallSite : AAICVTracker {
std::optional<Value *> ReplVal;
ChangeStatus updateImpl(Attributor &A) override {
- const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
+ const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
// We don't have any information, so we assume it changes the ICV.
- if (!ICVTrackingAA.isAssumedTracked())
+ if (!ICVTrackingAA->isAssumedTracked())
return indicatePessimisticFixpoint();
std::optional<Value *> NewReplVal =
- ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
+ ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
if (ReplVal == NewReplVal)
return ChangeStatus::UNCHANGED;
@@ -2477,7 +2486,7 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker {
: AAICVTracker(IRP, A) {}
// FIXME: come up with better string.
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *) const override {
return "ICVTrackerCallSiteReturned";
}
@@ -2503,18 +2512,18 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker {
ChangeStatus updateImpl(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED;
- const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
+ const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
*this, IRPosition::returned(*getAssociatedFunction()),
DepClassTy::REQUIRED);
// We don't have any information, so we assume it changes the ICV.
- if (!ICVTrackingAA.isAssumedTracked())
+ if (!ICVTrackingAA->isAssumedTracked())
return indicatePessimisticFixpoint();
for (InternalControlVar ICV : TrackableICVs) {
std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
std::optional<Value *> NewReplVal =
- ICVTrackingAA.getUniqueReplacementValue(ICV);
+ ICVTrackingAA->getUniqueReplacementValue(ICV);
if (ReplVal == NewReplVal)
continue;
@@ -2530,26 +2539,28 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
: AAExecutionDomain(IRP, A) {}
- ~AAExecutionDomainFunction() {
- delete RPOT;
- }
+ ~AAExecutionDomainFunction() { delete RPOT; }
void initialize(Attributor &A) override {
- if (getAnchorScope()->isDeclaration()) {
- indicatePessimisticFixpoint();
- return;
- }
- RPOT = new ReversePostOrderTraversal<Function *>(getAnchorScope());
+ Function *F = getAnchorScope();
+ assert(F && "Expected anchor function");
+ RPOT = new ReversePostOrderTraversal<Function *>(F);
}
- const std::string getAsStr() const override {
- unsigned TotalBlocks = 0, InitialThreadBlocks = 0;
+ const std::string getAsStr(Attributor *) const override {
+ unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
for (auto &It : BEDMap) {
+ if (!It.getFirst())
+ continue;
TotalBlocks++;
InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
+ AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
+ It.getSecond().IsReachingAlignedBarrierOnly;
}
return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
- std::to_string(TotalBlocks) + " executed by initial thread only";
+ std::to_string(AlignedBlocks) + " of " +
+ std::to_string(TotalBlocks) +
+ " executed by initial thread / aligned";
}
/// See AbstractAttribute::trackStatistics().
@@ -2572,7 +2583,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
SmallPtrSet<CallBase *, 16> DeletedBarriers;
auto HandleAlignedBarrier = [&](CallBase *CB) {
- const ExecutionDomainTy &ED = CEDMap[CB];
+ const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
if (!ED.IsReachedFromAlignedBarrierOnly ||
ED.EncounteredNonLocalSideEffect)
return;
@@ -2596,6 +2607,8 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
CallBase *LastCB = Worklist.pop_back_val();
if (!Visited.insert(LastCB))
continue;
+ if (LastCB->getFunction() != getAnchorScope())
+ continue;
if (!DeletedBarriers.count(LastCB)) {
A.deleteAfterManifest(*LastCB);
continue;
@@ -2603,7 +2616,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
// The final aligned barrier (LastCB) reaching the kernel end was
// removed already. This means we can go one step further and remove
// the barriers encountered last before (LastCB).
- const ExecutionDomainTy &LastED = CEDMap[LastCB];
+ const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
Worklist.append(LastED.AlignedBarriers.begin(),
LastED.AlignedBarriers.end());
}
@@ -2619,14 +2632,17 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
for (auto *CB : AlignedBarriers)
HandleAlignedBarrier(CB);
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
// Handle the "kernel end barrier" for kernels too.
- if (OMPInfoCache.Kernels.count(getAnchorScope()))
+ if (omp::isKernel(*getAnchorScope()))
HandleAlignedBarrier(nullptr);
return Changed;
}
+ bool isNoOpFence(const FenceInst &FI) const override {
+ return getState().isValidState() && !NonNoOpFences.count(&FI);
+ }
+
/// Merge barrier and assumption information from \p PredED into the successor
/// \p ED.
void
@@ -2636,12 +2652,12 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
/// Merge all information from \p PredED into the successor \p ED. If
/// \p InitialEdgeOnly is set, only the initial edge will enter the block
/// represented by \p ED from this predecessor.
- void mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
+ bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
const ExecutionDomainTy &PredED,
bool InitialEdgeOnly = false);
/// Accumulate information for the entry block in \p EntryBBED.
- void handleEntryBB(Attributor &A, ExecutionDomainTy &EntryBBED);
+ bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
/// See AbstractAttribute::updateImpl.
ChangeStatus updateImpl(Attributor &A) override;
@@ -2651,14 +2667,18 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
if (!isValidState())
return false;
+ assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
}
bool isExecutedInAlignedRegion(Attributor &A,
const Instruction &I) const override {
- if (!isValidState() || isa<CallBase>(I))
+ assert(I.getFunction() == getAnchorScope() &&
+ "Instruction is out of scope!");
+ if (!isValidState())
return false;
+ bool ForwardIsOk = true;
const Instruction *CurI;
// Check forward until a call or the block end is reached.
@@ -2667,15 +2687,18 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
auto *CB = dyn_cast<CallBase>(CurI);
if (!CB)
continue;
- const auto &It = CEDMap.find(CB);
+ if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
+ return true;
+ const auto &It = CEDMap.find({CB, PRE});
if (It == CEDMap.end())
continue;
- if (!It->getSecond().IsReachedFromAlignedBarrierOnly)
- return false;
+ if (!It->getSecond().IsReachingAlignedBarrierOnly)
+ ForwardIsOk = false;
+ break;
} while ((CurI = CurI->getNextNonDebugInstruction()));
- if (!CurI && !BEDMap.lookup(I.getParent()).IsReachedFromAlignedBarrierOnly)
- return false;
+ if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
+ ForwardIsOk = false;
// Check backward until a call or the block beginning is reached.
CurI = &I;
@@ -2683,33 +2706,30 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
auto *CB = dyn_cast<CallBase>(CurI);
if (!CB)
continue;
- const auto &It = CEDMap.find(CB);
+ if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
+ return true;
+ const auto &It = CEDMap.find({CB, POST});
if (It == CEDMap.end())
continue;
- if (!AA::isNoSyncInst(A, *CB, *this)) {
- if (It->getSecond().IsReachedFromAlignedBarrierOnly)
- break;
- return false;
- }
-
- Function *Callee = CB->getCalledFunction();
- if (!Callee || Callee->isDeclaration())
- return false;
- const auto &EDAA = A.getAAFor<AAExecutionDomain>(
- *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
- if (!EDAA.getState().isValidState())
- return false;
- if (!EDAA.getFunctionExecutionDomain().IsReachedFromAlignedBarrierOnly)
- return false;
- break;
+ if (It->getSecond().IsReachedFromAlignedBarrierOnly)
+ break;
+ return false;
} while ((CurI = CurI->getPrevNonDebugInstruction()));
- if (!CurI &&
- !llvm::all_of(
- predecessors(I.getParent()), [&](const BasicBlock *PredBB) {
- return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
- })) {
+ // Delayed decision on the forward pass to allow aligned barrier detection
+ // in the backwards traversal.
+ if (!ForwardIsOk)
return false;
+
+ if (!CurI) {
+ const BasicBlock *BB = I.getParent();
+ if (BB == &BB->getParent()->getEntryBlock())
+ return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
+ if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
+ return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
+ })) {
+ return false;
+ }
}
// On neither traversal did we find anything but aligned barriers.
@@ -2721,15 +2741,16 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
"No request should be made against an invalid state!");
return BEDMap.lookup(&BB);
}
- ExecutionDomainTy getExecutionDomain(const CallBase &CB) const override {
+ std::pair<ExecutionDomainTy, ExecutionDomainTy>
+ getExecutionDomain(const CallBase &CB) const override {
assert(isValidState() &&
"No request should be made against an invalid state!");
- return CEDMap.lookup(&CB);
+ return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
}
ExecutionDomainTy getFunctionExecutionDomain() const override {
assert(isValidState() &&
"No request should be made against an invalid state!");
- return BEDMap.lookup(nullptr);
+ return InterProceduralED;
}
///}
@@ -2778,12 +2799,28 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
return false;
};
+ /// Mapping containing information about the function for other AAs.
+ ExecutionDomainTy InterProceduralED;
+
+ enum Direction { PRE = 0, POST = 1 };
/// Mapping containing information per block.
DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
- DenseMap<const CallBase *, ExecutionDomainTy> CEDMap;
+ DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
+ CEDMap;
SmallSetVector<CallBase *, 16> AlignedBarriers;
ReversePostOrderTraversal<Function *> *RPOT = nullptr;
+
+ /// Set \p R to \p V and report true if that changed \p R.
+ static bool setAndRecord(bool &R, bool V) {
+ bool Eq = (R == V);
+ R = V;
+ return !Eq;
+ }
+
+ /// Collection of fences known to be non-no-op. All fences not in this set
+ /// can be assumed to be no-ops.
+ SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
};
void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
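The setAndRecord() helper above is the change-tracking idiom used throughout the rewritten updateImpl: it writes the new value and reports whether anything actually changed, so repeated identical merges do not spuriously report progress. A standalone illustration (plain C++, not tied to the Attributor):

#include <cstdio>

static bool setAndRecord(bool &R, bool V) {
  bool Eq = (R == V);
  R = V;
  return !Eq; // true iff the stored flag actually changed
}

int main() {
  bool Changed = false;
  bool ReachedFromAlignedBarrierOnly = true;
  // Merging in a non-aligned predecessor clears the flag and records a change.
  Changed |= setAndRecord(ReachedFromAlignedBarrierOnly, false);
  // A second, identical merge is a no-op and does not report a change again.
  bool SecondMergeChanged = setAndRecord(ReachedFromAlignedBarrierOnly, false);
  std::printf("Changed=%d SecondMergeChanged=%d Flag=%d\n", Changed,
              SecondMergeChanged, ReachedFromAlignedBarrierOnly);
  return 0;
}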
@@ -2795,62 +2832,82 @@ void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
ED.addAlignedBarrier(A, *AB);
}
-void AAExecutionDomainFunction::mergeInPredecessor(
+bool AAExecutionDomainFunction::mergeInPredecessor(
Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
bool InitialEdgeOnly) {
- ED.IsExecutedByInitialThreadOnly =
- InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
- ED.IsExecutedByInitialThreadOnly);
-
- ED.IsReachedFromAlignedBarrierOnly = ED.IsReachedFromAlignedBarrierOnly &&
- PredED.IsReachedFromAlignedBarrierOnly;
- ED.EncounteredNonLocalSideEffect =
- ED.EncounteredNonLocalSideEffect | PredED.EncounteredNonLocalSideEffect;
+
+ bool Changed = false;
+ Changed |=
+ setAndRecord(ED.IsExecutedByInitialThreadOnly,
+ InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
+ ED.IsExecutedByInitialThreadOnly));
+
+ Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
+ ED.IsReachedFromAlignedBarrierOnly &&
+ PredED.IsReachedFromAlignedBarrierOnly);
+ Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
+ ED.EncounteredNonLocalSideEffect |
+ PredED.EncounteredNonLocalSideEffect);
+ // Do not track assumptions and barriers as part of Changed.
if (ED.IsReachedFromAlignedBarrierOnly)
mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
else
ED.clearAssumeInstAndAlignedBarriers();
+ return Changed;
}
-void AAExecutionDomainFunction::handleEntryBB(Attributor &A,
+bool AAExecutionDomainFunction::handleCallees(Attributor &A,
ExecutionDomainTy &EntryBBED) {
- SmallVector<ExecutionDomainTy> PredExecDomains;
+ SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
auto PredForCallSite = [&](AbstractCallSite ACS) {
- const auto &EDAA = A.getAAFor<AAExecutionDomain>(
+ const auto *EDAA = A.getAAFor<AAExecutionDomain>(
*this, IRPosition::function(*ACS.getInstruction()->getFunction()),
DepClassTy::OPTIONAL);
- if (!EDAA.getState().isValidState())
+ if (!EDAA || !EDAA->getState().isValidState())
return false;
- PredExecDomains.emplace_back(
- EDAA.getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
+ CallSiteEDs.emplace_back(
+ EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
return true;
};
+ ExecutionDomainTy ExitED;
bool AllCallSitesKnown;
if (A.checkForAllCallSites(PredForCallSite, *this,
/* RequiresAllCallSites */ true,
AllCallSitesKnown)) {
- for (const auto &PredED : PredExecDomains)
- mergeInPredecessor(A, EntryBBED, PredED);
+ for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
+ mergeInPredecessor(A, EntryBBED, CSInED);
+ ExitED.IsReachingAlignedBarrierOnly &=
+ CSOutED.IsReachingAlignedBarrierOnly;
+ }
} else {
// We could not find all predecessors, so this is either a kernel or a
// function with external linkage (or with some other weird uses).
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
- if (OMPInfoCache.Kernels.count(getAnchorScope())) {
+ if (omp::isKernel(*getAnchorScope())) {
EntryBBED.IsExecutedByInitialThreadOnly = false;
EntryBBED.IsReachedFromAlignedBarrierOnly = true;
EntryBBED.EncounteredNonLocalSideEffect = false;
+ ExitED.IsReachingAlignedBarrierOnly = true;
} else {
EntryBBED.IsExecutedByInitialThreadOnly = false;
EntryBBED.IsReachedFromAlignedBarrierOnly = false;
EntryBBED.EncounteredNonLocalSideEffect = true;
+ ExitED.IsReachingAlignedBarrierOnly = false;
}
}
+ bool Changed = false;
auto &FnED = BEDMap[nullptr];
- FnED.IsReachingAlignedBarrierOnly &=
- EntryBBED.IsReachedFromAlignedBarrierOnly;
+ Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
+ FnED.IsReachedFromAlignedBarrierOnly &
+ EntryBBED.IsReachedFromAlignedBarrierOnly);
+ Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
+ FnED.IsReachingAlignedBarrierOnly &
+ ExitED.IsReachingAlignedBarrierOnly);
+ Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
+ EntryBBED.IsExecutedByInitialThreadOnly);
+ return Changed;
}
ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
@@ -2860,36 +2917,28 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
// Helper to deal with an aligned barrier encountered during the forward
// traversal. \p CB is the aligned barrier, \p ED is the execution domain when
// it was encountered.
- auto HandleAlignedBarrier = [&](CallBase *CB, ExecutionDomainTy &ED) {
- if (CB)
- Changed |= AlignedBarriers.insert(CB);
+ auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
+ Changed |= AlignedBarriers.insert(&CB);
// First, update the barrier ED kept in the separate CEDMap.
- auto &CallED = CEDMap[CB];
- mergeInPredecessor(A, CallED, ED);
+ auto &CallInED = CEDMap[{&CB, PRE}];
+ Changed |= mergeInPredecessor(A, CallInED, ED);
+ CallInED.IsReachingAlignedBarrierOnly = true;
// Next adjust the ED we use for the traversal.
ED.EncounteredNonLocalSideEffect = false;
ED.IsReachedFromAlignedBarrierOnly = true;
// Aligned barrier collection has to come last.
ED.clearAssumeInstAndAlignedBarriers();
- if (CB)
- ED.addAlignedBarrier(A, *CB);
+ ED.addAlignedBarrier(A, CB);
+ auto &CallOutED = CEDMap[{&CB, POST}];
+ Changed |= mergeInPredecessor(A, CallOutED, ED);
};
- auto &LivenessAA =
+ auto *LivenessAA =
A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
- // Set \p R to \V and report true if that changed \p R.
- auto SetAndRecord = [&](bool &R, bool V) {
- bool Eq = (R == V);
- R = V;
- return !Eq;
- };
-
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
-
Function *F = getAnchorScope();
BasicBlock &EntryBB = F->getEntryBlock();
- bool IsKernel = OMPInfoCache.Kernels.count(F);
+ bool IsKernel = omp::isKernel(*F);
SmallVector<Instruction *> SyncInstWorklist;
for (auto &RIt : *RPOT) {
@@ -2899,18 +2948,19 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
// TODO: We use local reasoning since we don't have a divergence analysis
// running as well. We could basically allow uniform branches here.
bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
+ bool IsExplicitlyAligned = IsEntryBB && IsKernel;
ExecutionDomainTy ED;
// Propagate "incoming edges" into information about this block.
if (IsEntryBB) {
- handleEntryBB(A, ED);
+ Changed |= handleCallees(A, ED);
} else {
// For live non-entry blocks we only propagate
// information via live edges.
- if (LivenessAA.isAssumedDead(&BB))
+ if (LivenessAA && LivenessAA->isAssumedDead(&BB))
continue;
for (auto *PredBB : predecessors(&BB)) {
- if (LivenessAA.isEdgeDead(PredBB, &BB))
+ if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
continue;
bool InitialEdgeOnly = isInitialThreadOnlyEdge(
A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
@@ -2922,7 +2972,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
// information to calls.
for (Instruction &I : BB) {
bool UsedAssumedInformation;
- if (A.isAssumedDead(I, *this, &LivenessAA, UsedAssumedInformation,
+ if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
/* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
/* CheckForDeadStore */ true))
continue;
@@ -2939,6 +2989,33 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
continue;
}
+ if (auto *FI = dyn_cast<FenceInst>(&I)) {
+ if (!ED.EncounteredNonLocalSideEffect) {
+ // An aligned fence without non-local side-effects is a no-op.
+ if (ED.IsReachedFromAlignedBarrierOnly)
+ continue;
+ // A non-aligned fence without non-local side-effects is a no-op
+ // if the ordering only publishes non-local side-effects (or less).
+ switch (FI->getOrdering()) {
+ case AtomicOrdering::NotAtomic:
+ continue;
+ case AtomicOrdering::Unordered:
+ continue;
+ case AtomicOrdering::Monotonic:
+ continue;
+ case AtomicOrdering::Acquire:
+ break;
+ case AtomicOrdering::Release:
+ continue;
+ case AtomicOrdering::AcquireRelease:
+ break;
+ case AtomicOrdering::SequentiallyConsistent:
+ break;
+ };
+ }
+ NonNoOpFences.insert(FI);
+ }
+
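The switch above classifies which fence orderings can still be treated as no-ops once aligned-barrier-only reachability is lost but no non-local side-effect has been seen. A self-contained sketch of just that classification (the enum is a local stand-in for llvm::AtomicOrdering, used only for illustration):

#include <cstdint>

enum class Ordering : uint8_t {
  NotAtomic, Unordered, Monotonic, Acquire, Release, AcquireRelease,
  SequentiallyConsistent
};

// Assuming no prior non-local side-effect: relaxed and release-only fences
// have nothing to publish and are treated as no-ops; acquire-flavoured and
// seq_cst fences are recorded as non-no-op.
static bool fenceIsNoOpHere(Ordering O) {
  switch (O) {
  case Ordering::NotAtomic:
  case Ordering::Unordered:
  case Ordering::Monotonic:
  case Ordering::Release:
    return true;
  case Ordering::Acquire:
  case Ordering::AcquireRelease:
  case Ordering::SequentiallyConsistent:
    return false;
  }
  return false;
}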
auto *CB = dyn_cast<CallBase>(&I);
bool IsNoSync = AA::isNoSyncInst(A, I, *this);
bool IsAlignedBarrier =
@@ -2946,14 +3023,16 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
AlignedBarrierLastInBlock &= IsNoSync;
+ IsExplicitlyAligned &= IsNoSync;
// Next we check for calls. Aligned barriers are handled
// explicitly, everything else is kept for the backward traversal and will
// also affect our state.
if (CB) {
if (IsAlignedBarrier) {
- HandleAlignedBarrier(CB, ED);
+ HandleAlignedBarrier(*CB, ED);
AlignedBarrierLastInBlock = true;
+ IsExplicitlyAligned = true;
continue;
}
@@ -2971,20 +3050,20 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
// Record how we entered the call, then accumulate the effect of the
// call in ED for potential use by the callee.
- auto &CallED = CEDMap[CB];
- mergeInPredecessor(A, CallED, ED);
+ auto &CallInED = CEDMap[{CB, PRE}];
+ Changed |= mergeInPredecessor(A, CallInED, ED);
// If we have a sync-definition we can check if it starts/ends in an
// aligned barrier. If we are unsure we assume any sync breaks
// alignment.
Function *Callee = CB->getCalledFunction();
if (!IsNoSync && Callee && !Callee->isDeclaration()) {
- const auto &EDAA = A.getAAFor<AAExecutionDomain>(
+ const auto *EDAA = A.getAAFor<AAExecutionDomain>(
*this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
- if (EDAA.getState().isValidState()) {
- const auto &CalleeED = EDAA.getFunctionExecutionDomain();
+ if (EDAA && EDAA->getState().isValidState()) {
+ const auto &CalleeED = EDAA->getFunctionExecutionDomain();
ED.IsReachedFromAlignedBarrierOnly =
- CalleeED.IsReachedFromAlignedBarrierOnly;
+ CalleeED.IsReachedFromAlignedBarrierOnly;
AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
ED.EncounteredNonLocalSideEffect |=
@@ -2992,19 +3071,27 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
else
ED.EncounteredNonLocalSideEffect =
CalleeED.EncounteredNonLocalSideEffect;
- if (!CalleeED.IsReachingAlignedBarrierOnly)
+ if (!CalleeED.IsReachingAlignedBarrierOnly) {
+ Changed |=
+ setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
SyncInstWorklist.push_back(&I);
+ }
if (CalleeED.IsReachedFromAlignedBarrierOnly)
mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
+ auto &CallOutED = CEDMap[{CB, POST}];
+ Changed |= mergeInPredecessor(A, CallOutED, ED);
continue;
}
}
- ED.IsReachedFromAlignedBarrierOnly =
- IsNoSync && ED.IsReachedFromAlignedBarrierOnly;
+ if (!IsNoSync) {
+ ED.IsReachedFromAlignedBarrierOnly = false;
+ Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
+ SyncInstWorklist.push_back(&I);
+ }
AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
- if (!IsNoSync)
- SyncInstWorklist.push_back(&I);
+ auto &CallOutED = CEDMap[{CB, POST}];
+ Changed |= mergeInPredecessor(A, CallOutED, ED);
}
if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
@@ -3013,7 +3100,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
// If we have a callee we try to use fine-grained information to
// determine local side-effects.
if (CB) {
- const auto &MemAA = A.getAAFor<AAMemoryLocation>(
+ const auto *MemAA = A.getAAFor<AAMemoryLocation>(
*this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
auto AccessPred = [&](const Instruction *I, const Value *Ptr,
@@ -3021,13 +3108,14 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
AAMemoryLocation::MemoryLocationsKind) {
return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
};
- if (MemAA.getState().isValidState() &&
- MemAA.checkForAllAccessesToMemoryKind(
+ if (MemAA && MemAA->getState().isValidState() &&
+ MemAA->checkForAllAccessesToMemoryKind(
AccessPred, AAMemoryLocation::ALL_LOCATIONS))
continue;
}
- if (!I.mayHaveSideEffects() && OMPInfoCache.isOnlyUsedByAssume(I))
+ auto &InfoCache = A.getInfoCache();
+ if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
continue;
if (auto *LI = dyn_cast<LoadInst>(&I))
@@ -3039,18 +3127,28 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
ED.EncounteredNonLocalSideEffect = true;
}
+ bool IsEndAndNotReachingAlignedBarriersOnly = false;
if (!isa<UnreachableInst>(BB.getTerminator()) &&
!BB.getTerminator()->getNumSuccessors()) {
- auto &FnED = BEDMap[nullptr];
- mergeInPredecessor(A, FnED, ED);
+ Changed |= mergeInPredecessor(A, InterProceduralED, ED);
- if (IsKernel)
- HandleAlignedBarrier(nullptr, ED);
+ auto &FnED = BEDMap[nullptr];
+ if (IsKernel && !IsExplicitlyAligned)
+ FnED.IsReachingAlignedBarrierOnly = false;
+ Changed |= mergeInPredecessor(A, FnED, ED);
+
+ if (!FnED.IsReachingAlignedBarrierOnly) {
+ IsEndAndNotReachingAlignedBarriersOnly = true;
+ SyncInstWorklist.push_back(BB.getTerminator());
+ auto &BBED = BEDMap[&BB];
+ Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
+ }
}
ExecutionDomainTy &StoredED = BEDMap[&BB];
- ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly;
+ ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
+ !IsEndAndNotReachingAlignedBarriersOnly;
// Check if we computed anything different as part of the forward
// traversal. We do not take assumptions and aligned barriers into account
@@ -3074,36 +3172,38 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
while (!SyncInstWorklist.empty()) {
Instruction *SyncInst = SyncInstWorklist.pop_back_val();
Instruction *CurInst = SyncInst;
- bool HitAlignedBarrier = false;
+ bool HitAlignedBarrierOrKnownEnd = false;
while ((CurInst = CurInst->getPrevNode())) {
auto *CB = dyn_cast<CallBase>(CurInst);
if (!CB)
continue;
- auto &CallED = CEDMap[CB];
- if (SetAndRecord(CallED.IsReachingAlignedBarrierOnly, false))
- Changed = true;
- HitAlignedBarrier = AlignedBarriers.count(CB);
- if (HitAlignedBarrier)
+ auto &CallOutED = CEDMap[{CB, POST}];
+ Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
+ auto &CallInED = CEDMap[{CB, PRE}];
+ HitAlignedBarrierOrKnownEnd =
+ AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
+ if (HitAlignedBarrierOrKnownEnd)
break;
+ Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
}
- if (HitAlignedBarrier)
+ if (HitAlignedBarrierOrKnownEnd)
continue;
BasicBlock *SyncBB = SyncInst->getParent();
for (auto *PredBB : predecessors(SyncBB)) {
- if (LivenessAA.isEdgeDead(PredBB, SyncBB))
+ if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
continue;
if (!Visited.insert(PredBB))
continue;
- SyncInstWorklist.push_back(PredBB->getTerminator());
auto &PredED = BEDMap[PredBB];
- if (SetAndRecord(PredED.IsReachingAlignedBarrierOnly, false))
+ if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
Changed = true;
+ SyncInstWorklist.push_back(PredBB->getTerminator());
+ }
}
if (SyncBB != &EntryBB)
continue;
- auto &FnED = BEDMap[nullptr];
- if (SetAndRecord(FnED.IsReachingAlignedBarrierOnly, false))
- Changed = true;
+ Changed |=
+ setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
}
return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
@@ -3146,7 +3246,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
: AAHeapToShared(IRP, A) {}
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *) const override {
return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
" malloc calls eligible.";
}
@@ -3261,7 +3361,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
auto *SharedMem = new GlobalVariable(
*M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
- UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
+ PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
GlobalValue::NotThreadLocal,
static_cast<unsigned>(AddressSpace::Shared));
auto *NewBuffer =
@@ -3270,7 +3370,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
auto Remark = [&](OptimizationRemark OR) {
return OR << "Replaced globalized variable with "
<< ore::NV("SharedMemory", AllocSize->getZExtValue())
- << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
+ << (AllocSize->isOne() ? " byte " : " bytes ")
<< "of shared memory.";
};
A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
@@ -3278,7 +3378,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
MaybeAlign Alignment = CB->getRetAlign();
assert(Alignment &&
"HeapToShared on allocation without alignment attribute");
- SharedMem->setAlignment(MaybeAlign(Alignment));
+ SharedMem->setAlignment(*Alignment);
A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
A.deleteAfterManifest(*CB);
@@ -3315,9 +3415,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
MallocCalls.remove(CB);
continue;
}
- const auto &ED = A.getAAFor<AAExecutionDomain>(
+ const auto *ED = A.getAAFor<AAExecutionDomain>(
*this, IRPosition::function(*F), DepClassTy::REQUIRED);
- if (!ED.isExecutedByInitialThreadOnly(*CB))
+ if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
MallocCalls.remove(CB);
}
}
@@ -3346,7 +3446,7 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
void trackStatistics() const override {}
/// See AbstractAttribute::getAsStr()
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *) const override {
if (!isValidState())
return "<invalid>";
return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
@@ -3456,22 +3556,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
[&](const IRPosition &IRP, const AbstractAttribute *AA,
bool &UsedAssumedInformation) -> std::optional<Value *> {
- // IRP represents the "use generic state machine" argument of an
- // __kmpc_target_init call. We will answer this one with the internal
- // state. As long as we are not in an invalid state, we will create a
- // custom state machine so the value should be a `i1 false`. If we are
- // in an invalid state, we won't change the value that is in the IR.
- if (!ReachedKnownParallelRegions.isValidState())
- return nullptr;
- // If we have disabled state machine rewrites, don't make a custom one.
- if (DisableOpenMPOptStateMachineRewrite)
return nullptr;
- if (AA)
- A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
- UsedAssumedInformation = !isAtFixpoint();
- auto *FalseVal =
- ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
- return FalseVal;
};
Attributor::SimplifictionCallbackTy ModeSimplifyCB =
@@ -3622,10 +3707,11 @@ struct AAKernelInfoFunction : AAKernelInfo {
Function *Kernel = getAnchorScope();
Module &M = *Kernel->getParent();
Type *Int8Ty = Type::getInt8Ty(M.getContext());
- new GlobalVariable(M, Int8Ty, /* isConstant */ true,
- GlobalValue::WeakAnyLinkage,
- ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0),
- Kernel->getName() + "_nested_parallelism");
+ auto *GV = new GlobalVariable(
+ M, Int8Ty, /* isConstant */ true, GlobalValue::WeakAnyLinkage,
+ ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0),
+ Kernel->getName() + "_nested_parallelism");
+ GV->setVisibility(GlobalValue::HiddenVisibility);
// If we can we change the execution mode to SPMD-mode otherwise we build a
// custom state machine.
@@ -3914,6 +4000,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ // We cannot change to SPMD mode if the runtime functions aren't available.
+ if (!OMPInfoCache.runtimeFnsAvailable(
+ {OMPRTL___kmpc_get_hardware_thread_id_in_block,
+ OMPRTL___kmpc_barrier_simple_spmd}))
+ return false;
+
if (!SPMDCompatibilityTracker.isAssumed()) {
for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
if (!NonCompatibleI)
@@ -3951,7 +4043,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
auto *CB = cast<CallBase>(Kernel->user_back());
Kernel = CB->getCaller();
}
- assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
+ assert(omp::isKernel(*Kernel) && "Expected kernel function!");
// Check if the kernel is already in SPMD mode, if so, return success.
GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
@@ -4021,6 +4113,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (!ReachedKnownParallelRegions.isValidState())
return ChangeStatus::UNCHANGED;
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ if (!OMPInfoCache.runtimeFnsAvailable(
+ {OMPRTL___kmpc_get_hardware_num_threads_in_block,
+ OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
+ OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
+ return ChangeStatus::UNCHANGED;
+
const int InitModeArgNo = 1;
const int InitUseStateMachineArgNo = 2;
@@ -4167,7 +4266,6 @@ struct AAKernelInfoFunction : AAKernelInfo {
BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
Module &M = *Kernel->getParent();
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
FunctionCallee BlockHwSizeFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
@@ -4220,10 +4318,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (WorkFnAI->getType()->getPointerAddressSpace() !=
(unsigned int)AddressSpace::Generic) {
WorkFnAI = new AddrSpaceCastInst(
- WorkFnAI,
- PointerType::getWithSamePointeeType(
- cast<PointerType>(WorkFnAI->getType()),
- (unsigned int)AddressSpace::Generic),
+ WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
WorkFnAI->getName() + ".generic", StateMachineBeginBB);
WorkFnAI->setDebugLoc(DLoc);
}
@@ -4345,19 +4440,20 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (!I.mayWriteToMemory())
return true;
if (auto *SI = dyn_cast<StoreInst>(&I)) {
- const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
+ const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
*this, IRPosition::value(*SI->getPointerOperand()),
DepClassTy::OPTIONAL);
- auto &HS = A.getAAFor<AAHeapToStack>(
+ auto *HS = A.getAAFor<AAHeapToStack>(
*this, IRPosition::function(*I.getFunction()),
DepClassTy::OPTIONAL);
- if (UnderlyingObjsAA.forallUnderlyingObjects([&](Value &Obj) {
+ if (UnderlyingObjsAA &&
+ UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
if (AA::isAssumedThreadLocalObject(A, Obj, *this))
return true;
// Check for AAHeapToStack moved objects which must not be
// guarded.
auto *CB = dyn_cast<CallBase>(&Obj);
- return CB && HS.isAssumedHeapToStack(*CB);
+ return CB && HS && HS->isAssumedHeapToStack(*CB);
}))
return true;
}
@@ -4392,14 +4488,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
// we cannot fix the internal spmd-zation state either.
int SPMD = 0, Generic = 0;
for (auto *Kernel : ReachingKernelEntries) {
- auto &CBAA = A.getAAFor<AAKernelInfo>(
+ auto *CBAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
- if (CBAA.SPMDCompatibilityTracker.isValidState() &&
- CBAA.SPMDCompatibilityTracker.isAssumed())
+ if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
+ CBAA->SPMDCompatibilityTracker.isAssumed())
++SPMD;
else
++Generic;
- if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
+ if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
UsedAssumedInformationFromReachingKernels = true;
}
if (SPMD != 0 && Generic != 0)
@@ -4413,14 +4509,16 @@ struct AAKernelInfoFunction : AAKernelInfo {
bool AllSPMDStatesWereFixed = true;
auto CheckCallInst = [&](Instruction &I) {
auto &CB = cast<CallBase>(I);
- auto &CBAA = A.getAAFor<AAKernelInfo>(
+ auto *CBAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
- getState() ^= CBAA.getState();
- AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
+ if (!CBAA)
+ return false;
+ getState() ^= CBAA->getState();
+ AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
AllParallelRegionStatesWereFixed &=
- CBAA.ReachedKnownParallelRegions.isAtFixpoint();
+ CBAA->ReachedKnownParallelRegions.isAtFixpoint();
AllParallelRegionStatesWereFixed &=
- CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
+ CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
return true;
};
@@ -4460,10 +4558,10 @@ private:
assert(Caller && "Caller is nullptr");
- auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
+ auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
- if (CAA.ReachingKernelEntries.isValidState()) {
- ReachingKernelEntries ^= CAA.ReachingKernelEntries;
+ if (CAA && CAA->ReachingKernelEntries.isValidState()) {
+ ReachingKernelEntries ^= CAA->ReachingKernelEntries;
return true;
}
@@ -4491,9 +4589,9 @@ private:
assert(Caller && "Caller is nullptr");
- auto &CAA =
+ auto *CAA =
A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
- if (CAA.ParallelLevels.isValidState()) {
+ if (CAA && CAA->ParallelLevels.isValidState()) {
// Any function that is called by `__kmpc_parallel_51` will not be
// folded as the parallel level in the function is updated. In order to
// get it right, all the analysis would depend on the implementation. That
@@ -4504,7 +4602,7 @@ private:
return true;
}
- ParallelLevels ^= CAA.ParallelLevels;
+ ParallelLevels ^= CAA->ParallelLevels;
return true;
}
@@ -4538,11 +4636,11 @@ struct AAKernelInfoCallSite : AAKernelInfo {
CallBase &CB = cast<CallBase>(getAssociatedValue());
Function *Callee = getAssociatedFunction();
- auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
+ auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
*this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
// Check for SPMD-mode assumptions.
- if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
+ if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
SPMDCompatibilityTracker.indicateOptimisticFixpoint();
indicateOptimisticFixpoint();
}
@@ -4567,8 +4665,9 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// Unknown callees might contain parallel regions, except if they have
// an appropriate assumption attached.
- if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
- AssumptionAA.hasAssumption("omp_no_parallelism")))
+ if (!AssumptionAA ||
+ !(AssumptionAA->hasAssumption("omp_no_openmp") ||
+ AssumptionAA->hasAssumption("omp_no_parallelism")))
ReachedUnknownParallelRegions.insert(&CB);
// If SPMDCompatibilityTracker is not fixed, we need to give up on the
@@ -4643,11 +4742,11 @@ struct AAKernelInfoCallSite : AAKernelInfo {
CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
ReachedKnownParallelRegions.insert(ParallelRegion);
/// Check nested parallelism
- auto &FnAA = A.getAAFor<AAKernelInfo>(
+ auto *FnAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
- NestedParallelism |= !FnAA.getState().isValidState() ||
- !FnAA.ReachedKnownParallelRegions.empty() ||
- !FnAA.ReachedUnknownParallelRegions.empty();
+ NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
+ !FnAA->ReachedKnownParallelRegions.empty() ||
+ !FnAA->ReachedUnknownParallelRegions.empty();
break;
}
// The condition above should usually get the parallel region function
@@ -4691,10 +4790,12 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// If F is not a runtime function, propagate the AAKernelInfo of the callee.
if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
const IRPosition &FnPos = IRPosition::function(*F);
- auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
- if (getState() == FnAA.getState())
+ auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+ if (!FnAA)
+ return indicatePessimisticFixpoint();
+ if (getState() == FnAA->getState())
return ChangeStatus::UNCHANGED;
- getState() = FnAA.getState();
+ getState() = FnAA->getState();
return ChangeStatus::CHANGED;
}
@@ -4707,9 +4808,9 @@ struct AAKernelInfoCallSite : AAKernelInfo {
CallBase &CB = cast<CallBase>(getAssociatedValue());
- auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
+ auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
*this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
- auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
+ auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
*this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
RuntimeFunction RF = It->getSecond();
@@ -4718,13 +4819,15 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// If neither HeapToStack nor HeapToShared assume the call is removed,
// assume SPMD incompatibility.
case OMPRTL___kmpc_alloc_shared:
- if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
- !HeapToSharedAA.isAssumedHeapToShared(CB))
+ if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
+ (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
SPMDCompatibilityTracker.insert(&CB);
break;
case OMPRTL___kmpc_free_shared:
- if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
- !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
+ if ((!HeapToStackAA ||
+ !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
+ (!HeapToSharedAA ||
+ !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
SPMDCompatibilityTracker.insert(&CB);
break;
default:
@@ -4770,7 +4873,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
: AAFoldRuntimeCall(IRP, A) {}
/// See AbstractAttribute::getAsStr()
- const std::string getAsStr() const override {
+ const std::string getAsStr(Attributor *) const override {
if (!isValidState())
return "<invalid>";
@@ -4883,28 +4986,29 @@ private:
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
- auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
- if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
+ if (!CallerKernelInfoAA ||
+ !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
return indicatePessimisticFixpoint();
- for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
- auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+ for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
+ auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
DepClassTy::REQUIRED);
- if (!AA.isValidState()) {
+ if (!AA || !AA->isValidState()) {
SimplifiedValue = nullptr;
return indicatePessimisticFixpoint();
}
- if (AA.SPMDCompatibilityTracker.isAssumed()) {
- if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+ if (AA->SPMDCompatibilityTracker.isAssumed()) {
+ if (AA->SPMDCompatibilityTracker.isAtFixpoint())
++KnownSPMDCount;
else
++AssumedSPMDCount;
} else {
- if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+ if (AA->SPMDCompatibilityTracker.isAtFixpoint())
++KnownNonSPMDCount;
else
++AssumedNonSPMDCount;
@@ -4943,16 +5047,17 @@ private:
ChangeStatus foldParallelLevel(Attributor &A) {
std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
- auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
- if (!CallerKernelInfoAA.ParallelLevels.isValidState())
+ if (!CallerKernelInfoAA ||
+ !CallerKernelInfoAA->ParallelLevels.isValidState())
return indicatePessimisticFixpoint();
- if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
+ if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
return indicatePessimisticFixpoint();
- if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
+ if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
assert(!SimplifiedValue &&
"SimplifiedValue should keep none at this point");
return ChangeStatus::UNCHANGED;
@@ -4960,19 +5065,19 @@ private:
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
- for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
- auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+ for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
+ auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
DepClassTy::REQUIRED);
- if (!AA.SPMDCompatibilityTracker.isValidState())
+ if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
return indicatePessimisticFixpoint();
- if (AA.SPMDCompatibilityTracker.isAssumed()) {
- if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+ if (AA->SPMDCompatibilityTracker.isAssumed()) {
+ if (AA->SPMDCompatibilityTracker.isAtFixpoint())
++KnownSPMDCount;
else
++AssumedSPMDCount;
} else {
- if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+ if (AA->SPMDCompatibilityTracker.isAtFixpoint())
++KnownNonSPMDCount;
else
++AssumedNonSPMDCount;
@@ -5005,14 +5110,15 @@ private:
int32_t CurrentAttrValue = -1;
std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
- auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
- if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
+ if (!CallerKernelInfoAA ||
+ !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
return indicatePessimisticFixpoint();
// Iterate over the kernels that reach this function
- for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
+ for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
if (NextAttrVal == -1 ||
@@ -5135,6 +5241,8 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
if (!DisableOpenMPOptDeglobalization)
A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
+ if (F.hasFnAttribute(Attribute::Convergent))
+ A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
for (auto &I : instructions(F)) {
if (auto *LI = dyn_cast<LoadInst>(&I)) {
@@ -5147,6 +5255,10 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
continue;
}
+ if (auto *FI = dyn_cast<FenceInst>(&I)) {
+ A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
+ continue;
+ }
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
if (II->getIntrinsicID() == Intrinsic::assume) {
A.getOrCreateAAFor<AAPotentialValues>(
@@ -5304,6 +5416,8 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
});
};
+ bool Changed = false;
+
// Create internal copies of each function if this is a kernel Module. This
// allows interprocedural passes to see every call edge.
DenseMap<Function *, Function *> InternalizedMap;
@@ -5319,7 +5433,8 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
}
}
- Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
+ Changed |=
+ Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
}
// Look at every function in the Module unless it was internalized.
@@ -5332,7 +5447,7 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
}
if (SCC.empty())
- return PreservedAnalyses::all();
+ return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
AnalysisGetter AG(FAM);
@@ -5343,7 +5458,9 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
BumpPtrAllocator Allocator;
CallGraphUpdater CGUpdater;
- OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels);
+ bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+ LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
+ OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
unsigned MaxFixpointIterations =
(isOpenMPDevice(M)) ? SetFixpointIterations : 32;
@@ -5356,11 +5473,14 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
AC.OREGetter = OREGetter;
AC.PassName = DEBUG_TYPE;
AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
+ AC.IPOAmendableCB = [](const Function &F) {
+ return F.hasFnAttribute("kernel");
+ };
Attributor A(Functions, InfoCache, AC);
OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
- bool Changed = OMPOpt.run(true);
+ Changed |= OMPOpt.run(true);
// Optionally inline device functions for potentially better performance.
if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
@@ -5417,9 +5537,11 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
CallGraphUpdater CGUpdater;
CGUpdater.initialize(CG, C, AM, UR);
+ bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+ LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
SetVector<Function *> Functions(SCC.begin(), SCC.end());
OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
- /*CGSCC*/ &Functions, Kernels);
+ /*CGSCC*/ &Functions, PostLink);
unsigned MaxFixpointIterations =
(isOpenMPDevice(M)) ? SetFixpointIterations : 32;
@@ -5447,6 +5569,8 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
return PreservedAnalyses::all();
}
+bool llvm::omp::isKernel(Function &Fn) { return Fn.hasFnAttribute("kernel"); }
+
KernelSet llvm::omp::getDeviceKernels(Module &M) {
// TODO: Create a more cross-platform way of determining device kernels.
NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
@@ -5467,6 +5591,7 @@ KernelSet llvm::omp::getDeviceKernels(Module &M) {
if (!KernelFn)
continue;
+ assert(isKernel(*KernelFn) && "Inconsistent kernel function annotation");
++NumOpenMPTargetRegionKernels;
Kernels.insert(KernelFn);
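The OpenMPOpt hunks above share one pattern: the Attributor queries (getAAFor, getOrCreateAAFor) now yield a pointer that can be null instead of a reference, and every caller has to treat a missing result as the pessimistic state. Below is a minimal standalone sketch of that caller-side discipline; the registry and KernelInfo names are illustrative stand-ins, not the real Attributor types.

// Standalone sketch (illustrative names, not the LLVM Attributor API): a
// lookup that may return nullptr, and a caller that falls back to the
// pessimistic answer when no analysis result is available.
#include <iostream>
#include <string>
#include <unordered_map>

struct KernelInfo {
  bool SPMDCompatible = false; // assumed property of the kernel
};

class InfoRegistry {
  std::unordered_map<std::string, KernelInfo> Infos;

public:
  void seed(const std::string &Fn, KernelInfo KI) { Infos[Fn] = KI; }
  // May return nullptr, e.g. when no info was ever created for Fn.
  const KernelInfo *lookup(const std::string &Fn) const {
    auto It = Infos.find(Fn);
    return It == Infos.end() ? nullptr : &It->second;
  }
};

// Mirrors the added `if (!CBAA) ...` checks: missing info means assume the worst.
bool isAssumedSPMD(const InfoRegistry &R, const std::string &Fn) {
  const KernelInfo *KI = R.lookup(Fn);
  return KI && KI->SPMDCompatible;
}

int main() {
  InfoRegistry R;
  KernelInfo KI;
  KI.SPMDCompatible = true;
  R.seed("kernel_a", KI);
  std::cout << isAssumedSPMD(R, "kernel_a") << '\n'; // 1
  std::cout << isAssumedSPMD(R, "kernel_b") << '\n'; // 0: no info, pessimistic
}

The fallback lives at every use site, which is exactly what the pointer-plus-null-check style in the hunks above enforces.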
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 310e4d4164a5..b88ba2dec24b 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -14,6 +14,7 @@
#include "llvm/Transforms/IPO/PartialInlining.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
@@ -41,8 +42,6 @@
#include "llvm/IR/Operator.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/User.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/BlockFrequency.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/Casting.h"
@@ -342,52 +341,6 @@ private:
OptimizationRemarkEmitter &ORE) const;
};
-struct PartialInlinerLegacyPass : public ModulePass {
- static char ID; // Pass identification, replacement for typeid
-
- PartialInlinerLegacyPass() : ModulePass(ID) {
- initializePartialInlinerLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<ProfileSummaryInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- AssumptionCacheTracker *ACT = &getAnalysis<AssumptionCacheTracker>();
- TargetTransformInfoWrapperPass *TTIWP =
- &getAnalysis<TargetTransformInfoWrapperPass>();
- ProfileSummaryInfo &PSI =
- getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
-
- auto GetAssumptionCache = [&ACT](Function &F) -> AssumptionCache & {
- return ACT->getAssumptionCache(F);
- };
-
- auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * {
- return ACT->lookupAssumptionCache(F);
- };
-
- auto GetTTI = [&TTIWP](Function &F) -> TargetTransformInfo & {
- return TTIWP->getTTI(F);
- };
-
- auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
- return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- };
-
- return PartialInlinerImpl(GetAssumptionCache, LookupAssumptionCache, GetTTI,
- GetTLI, PSI)
- .run(M);
- }
-};
-
} // end anonymous namespace
std::unique_ptr<FunctionOutliningMultiRegionInfo>
@@ -1027,7 +980,7 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner(
// Go through all Outline Candidate Regions and update all BasicBlock
// information.
- for (FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegionInfo :
+ for (const FunctionOutliningMultiRegionInfo::OutlineRegionInfo &RegionInfo :
OI->ORI) {
SmallVector<BasicBlock *, 8> Region;
for (BasicBlock *BB : RegionInfo.Region)
@@ -1226,14 +1179,14 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() {
ToExtract.push_back(ClonedOI->NonReturnBlock);
OutlinedRegionCost += PartialInlinerImpl::computeBBInlineCost(
ClonedOI->NonReturnBlock, ClonedFuncTTI);
- for (BasicBlock &BB : *ClonedFunc)
- if (!ToBeInlined(&BB) && &BB != ClonedOI->NonReturnBlock) {
- ToExtract.push_back(&BB);
+ for (BasicBlock *BB : depth_first(&ClonedFunc->getEntryBlock()))
+ if (!ToBeInlined(BB) && BB != ClonedOI->NonReturnBlock) {
+ ToExtract.push_back(BB);
// FIXME: the code extractor may hoist/sink more code
// into the outlined function which may make the outlining
// overhead (the difference of the outlined function cost
// and OutliningRegionCost) look larger.
- OutlinedRegionCost += computeBBInlineCost(&BB, ClonedFuncTTI);
+ OutlinedRegionCost += computeBBInlineCost(BB, ClonedFuncTTI);
}
// Extract the body of the if.
@@ -1429,7 +1382,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {
OR << ore::NV("Callee", Cloner.OrigFunc) << " partially inlined into "
<< ore::NV("Caller", CB->getCaller());
- InlineFunctionInfo IFI(nullptr, GetAssumptionCache, &PSI);
+ InlineFunctionInfo IFI(GetAssumptionCache, &PSI);
// We can only forward varargs when we outlined a single region, else we
// bail on vararg functions.
if (!InlineFunction(*CB, IFI, /*MergeAttributes=*/false, nullptr, true,
@@ -1497,21 +1450,6 @@ bool PartialInlinerImpl::run(Module &M) {
return Changed;
}
-char PartialInlinerLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner",
- "Partial Inliner", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner",
- "Partial Inliner", false, false)
-
-ModulePass *llvm::createPartialInliningPass() {
- return new PartialInlinerLegacyPass();
-}
-
PreservedAnalyses PartialInlinerPass::run(Module &M,
ModuleAnalysisManager &AM) {
auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
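The hunk that replaces the plain block loop with depth_first(&ClonedFunc->getEntryBlock()) means only blocks reachable from the entry are considered for extraction, so unreachable blocks no longer inflate ToExtract or OutlinedRegionCost. A small standalone reachability walk with the same effect, using integers in place of basic blocks (an assumed toy graph, not LLVM IR):

// Standalone reachability sketch (plain integers stand in for basic blocks).
#include <iostream>
#include <vector>

std::vector<int> reachableFrom(int Entry,
                               const std::vector<std::vector<int>> &Succs) {
  std::vector<bool> Seen(Succs.size(), false);
  std::vector<int> Order;
  std::vector<int> Stack{Entry};
  while (!Stack.empty()) {
    int N = Stack.back();
    Stack.pop_back();
    if (Seen[N])
      continue;
    Seen[N] = true;
    Order.push_back(N); // visited: reachable from Entry
    for (int S : Succs[N])
      Stack.push_back(S);
  }
  return Order;
}

int main() {
  // Node 3 is not reachable from the entry node 0, so it never shows up,
  // just as an unreachable basic block is no longer added to ToExtract.
  std::vector<std::vector<int>> Succs = {{1, 2}, {2}, {}, {2}};
  for (int N : reachableFrom(0, Succs))
    std::cout << N << ' '; // prints: 0 2 1
  std::cout << '\n';
}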
diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
deleted file mode 100644
index 6b91c8494f39..000000000000
--- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
+++ /dev/null
@@ -1,517 +0,0 @@
-//===- PassManagerBuilder.cpp - Build Standard Pass -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the PassManagerBuilder class, which is used to set up a
-// "standard" optimization sequence suitable for languages like C and C++.
-//
-//===----------------------------------------------------------------------===//
-
-#include "llvm/Transforms/IPO/PassManagerBuilder.h"
-#include "llvm-c/Transforms/PassManagerBuilder.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Analysis/GlobalsModRef.h"
-#include "llvm/Analysis/ScopedNoAliasAA.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
-#include "llvm/IR/LegacyPassManager.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/ManagedStatic.h"
-#include "llvm/Target/CGPassBuilderOption.h"
-#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
-#include "llvm/Transforms/IPO.h"
-#include "llvm/Transforms/IPO/Attributor.h"
-#include "llvm/Transforms/IPO/ForceFunctionAttrs.h"
-#include "llvm/Transforms/IPO/FunctionAttrs.h"
-#include "llvm/Transforms/IPO/InferFunctionAttrs.h"
-#include "llvm/Transforms/InstCombine/InstCombine.h"
-#include "llvm/Transforms/Instrumentation.h"
-#include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Scalar/GVN.h"
-#include "llvm/Transforms/Scalar/LICM.h"
-#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
-#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h"
-#include "llvm/Transforms/Utils.h"
-#include "llvm/Transforms/Vectorize.h"
-
-using namespace llvm;
-
-PassManagerBuilder::PassManagerBuilder() {
- OptLevel = 2;
- SizeLevel = 0;
- LibraryInfo = nullptr;
- Inliner = nullptr;
- DisableUnrollLoops = false;
- SLPVectorize = false;
- LoopVectorize = true;
- LoopsInterleaved = true;
- LicmMssaOptCap = SetLicmMssaOptCap;
- LicmMssaNoAccForPromotionCap = SetLicmMssaNoAccForPromotionCap;
- DisableGVNLoadPRE = false;
- ForgetAllSCEVInLoopUnroll = ForgetSCEVInLoopUnroll;
- VerifyInput = false;
- VerifyOutput = false;
- MergeFunctions = false;
- DivergentTarget = false;
- CallGraphProfile = true;
-}
-
-PassManagerBuilder::~PassManagerBuilder() {
- delete LibraryInfo;
- delete Inliner;
-}
-
-void PassManagerBuilder::addInitialAliasAnalysisPasses(
- legacy::PassManagerBase &PM) const {
- // Add TypeBasedAliasAnalysis before BasicAliasAnalysis so that
- // BasicAliasAnalysis wins if they disagree. This is intended to help
- // support "obvious" type-punning idioms.
- PM.add(createTypeBasedAAWrapperPass());
- PM.add(createScopedNoAliasAAWrapperPass());
-}
-
-void PassManagerBuilder::populateFunctionPassManager(
- legacy::FunctionPassManager &FPM) {
- // Add LibraryInfo if we have some.
- if (LibraryInfo)
- FPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo));
-
- if (OptLevel == 0) return;
-
- addInitialAliasAnalysisPasses(FPM);
-
- // Lower llvm.expect to metadata before attempting transforms.
- // Compare/branch metadata may alter the behavior of passes like SimplifyCFG.
- FPM.add(createLowerExpectIntrinsicPass());
- FPM.add(createCFGSimplificationPass());
- FPM.add(createSROAPass());
- FPM.add(createEarlyCSEPass());
-}
-
-void PassManagerBuilder::addFunctionSimplificationPasses(
- legacy::PassManagerBase &MPM) {
- // Start of function pass.
- // Break up aggregate allocas, using SSAUpdater.
- assert(OptLevel >= 1 && "Calling function optimizer with no optimization level!");
- MPM.add(createSROAPass());
- MPM.add(createEarlyCSEPass(true /* Enable mem-ssa. */)); // Catch trivial redundancies
-
- if (OptLevel > 1) {
- // Speculative execution if the target has divergent branches; otherwise nop.
- MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass());
-
- MPM.add(createJumpThreadingPass()); // Thread jumps.
- MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals
- }
- MPM.add(
- createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp(
- true))); // Merge & remove BBs
- // Combine silly seq's
- MPM.add(createInstructionCombiningPass());
- if (SizeLevel == 0)
- MPM.add(createLibCallsShrinkWrapPass());
-
- // TODO: Investigate the cost/benefit of tail call elimination on debugging.
- if (OptLevel > 1)
- MPM.add(createTailCallEliminationPass()); // Eliminate tail calls
- MPM.add(
- createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp(
- true))); // Merge & remove BBs
- MPM.add(createReassociatePass()); // Reassociate expressions
-
- // Begin the loop pass pipeline.
-
- // The simple loop unswitch pass relies on separate cleanup passes. Schedule
- // them first so when we re-process a loop they run before other loop
- // passes.
- MPM.add(createLoopInstSimplifyPass());
- MPM.add(createLoopSimplifyCFGPass());
-
- // Try to remove as much code from the loop header as possible,
- // to reduce amount of IR that will have to be duplicated. However,
- // do not perform speculative hoisting the first time as LICM
- // will destroy metadata that may not need to be destroyed if run
- // after loop rotation.
- // TODO: Investigate promotion cap for O1.
- MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap,
- /*AllowSpeculation=*/false));
- // Rotate Loop - disable header duplication at -Oz
- MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1, false));
- // TODO: Investigate promotion cap for O1.
- MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap,
- /*AllowSpeculation=*/true));
- MPM.add(createSimpleLoopUnswitchLegacyPass(OptLevel == 3));
- // FIXME: We break the loop pass pipeline here in order to do full
- // simplifycfg. Eventually loop-simplifycfg should be enhanced to replace the
- // need for this.
- MPM.add(createCFGSimplificationPass(
- SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
- MPM.add(createInstructionCombiningPass());
- // We resume loop passes creating a second loop pipeline here.
- MPM.add(createLoopIdiomPass()); // Recognize idioms like memset.
- MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars
- MPM.add(createLoopDeletionPass()); // Delete dead loops
-
- // Unroll small loops and perform peeling.
- MPM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops,
- ForgetAllSCEVInLoopUnroll));
- // This ends the loop pass pipelines.
-
- // Break up allocas that may now be splittable after loop unrolling.
- MPM.add(createSROAPass());
-
- if (OptLevel > 1) {
- MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds
- MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies
- }
- MPM.add(createSCCPPass()); // Constant prop with SCCP
-
- // Delete dead bit computations (instcombine runs after to fold away the dead
- // computations, and then ADCE will run later to exploit any new DCE
- // opportunities that creates).
- MPM.add(createBitTrackingDCEPass()); // Delete dead bit computations
-
- // Run instcombine after redundancy elimination to exploit opportunities
- // opened up by them.
- MPM.add(createInstructionCombiningPass());
- if (OptLevel > 1) {
- MPM.add(createJumpThreadingPass()); // Thread jumps
- MPM.add(createCorrelatedValuePropagationPass());
- }
- MPM.add(createAggressiveDCEPass()); // Delete dead instructions
-
- MPM.add(createMemCpyOptPass()); // Remove memcpy / form memset
- // TODO: Investigate if this is too expensive at O1.
- if (OptLevel > 1) {
- MPM.add(createDeadStoreEliminationPass()); // Delete dead stores
- MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap,
- /*AllowSpeculation=*/true));
- }
-
- // Merge & remove BBs and sink & hoist common instructions.
- MPM.add(createCFGSimplificationPass(
- SimplifyCFGOptions().hoistCommonInsts(true).sinkCommonInsts(true)));
- // Clean up after everything.
- MPM.add(createInstructionCombiningPass());
-}
-
-/// FIXME: Should LTO cause any differences to this set of passes?
-void PassManagerBuilder::addVectorPasses(legacy::PassManagerBase &PM,
- bool IsFullLTO) {
- PM.add(createLoopVectorizePass(!LoopsInterleaved, !LoopVectorize));
-
- if (IsFullLTO) {
- // The vectorizer may have significantly shortened a loop body; unroll
- // again. Unroll small loops to hide loop backedge latency and saturate any
- // parallel execution resources of an out-of-order processor. We also then
- // need to clean up redundancies and loop invariant code.
- // FIXME: It would be really good to use a loop-integrated instruction
- // combiner for cleanup here so that the unrolling and LICM can be pipelined
- // across the loop nests.
- PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops,
- ForgetAllSCEVInLoopUnroll));
- PM.add(createWarnMissedTransformationsPass());
- }
-
- if (!IsFullLTO) {
- // Eliminate loads by forwarding stores from the previous iteration to loads
- // of the current iteration.
- PM.add(createLoopLoadEliminationPass());
- }
- // Cleanup after the loop optimization passes.
- PM.add(createInstructionCombiningPass());
-
- // Now that we've formed fast to execute loop structures, we do further
- // optimizations. These are run afterward as they might block doing complex
- // analyses and transforms such as what are needed for loop vectorization.
-
- // Cleanup after loop vectorization, etc. Simplification passes like CVP and
- // GVN, loop transforms, and others have already run, so it's now better to
- // convert to more optimized IR using more aggressive simplify CFG options.
- // The extra sinking transform can create larger basic blocks, so do this
- // before SLP vectorization.
- PM.add(createCFGSimplificationPass(SimplifyCFGOptions()
- .forwardSwitchCondToPhi(true)
- .convertSwitchRangeToICmp(true)
- .convertSwitchToLookupTable(true)
- .needCanonicalLoops(false)
- .hoistCommonInsts(true)
- .sinkCommonInsts(true)));
-
- if (IsFullLTO) {
- PM.add(createSCCPPass()); // Propagate exposed constants
- PM.add(createInstructionCombiningPass()); // Clean up again
- PM.add(createBitTrackingDCEPass());
- }
-
- // Optimize parallel scalar instruction chains into SIMD instructions.
- if (SLPVectorize) {
- PM.add(createSLPVectorizerPass());
- }
-
- // Enhance/cleanup vector code.
- PM.add(createVectorCombinePass());
-
- if (!IsFullLTO) {
- PM.add(createInstructionCombiningPass());
-
- // Unroll small loops
- PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops,
- ForgetAllSCEVInLoopUnroll));
-
- if (!DisableUnrollLoops) {
- // LoopUnroll may generate some redundency to cleanup.
- PM.add(createInstructionCombiningPass());
-
- // Runtime unrolling will introduce runtime check in loop prologue. If the
- // unrolled loop is a inner loop, then the prologue will be inside the
- // outer loop. LICM pass can help to promote the runtime check out if the
- // checked value is loop invariant.
- PM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap,
- /*AllowSpeculation=*/true));
- }
-
- PM.add(createWarnMissedTransformationsPass());
- }
-
- // After vectorization and unrolling, assume intrinsics may tell us more
- // about pointer alignments.
- PM.add(createAlignmentFromAssumptionsPass());
-
- if (IsFullLTO)
- PM.add(createInstructionCombiningPass());
-}
-
-void PassManagerBuilder::populateModulePassManager(
- legacy::PassManagerBase &MPM) {
- MPM.add(createAnnotation2MetadataLegacyPass());
-
- // Allow forcing function attributes as a debugging and tuning aid.
- MPM.add(createForceFunctionAttrsLegacyPass());
-
- // If all optimizations are disabled, just run the always-inline pass and,
- // if enabled, the function merging pass.
- if (OptLevel == 0) {
- if (Inliner) {
- MPM.add(Inliner);
- Inliner = nullptr;
- }
-
- // FIXME: The BarrierNoopPass is a HACK! The inliner pass above implicitly
- // creates a CGSCC pass manager, but we don't want to add extensions into
- // that pass manager. To prevent this we insert a no-op module pass to reset
- // the pass manager to get the same behavior as EP_OptimizerLast in non-O0
- // builds. The function merging pass is
- if (MergeFunctions)
- MPM.add(createMergeFunctionsPass());
- return;
- }
-
- // Add LibraryInfo if we have some.
- if (LibraryInfo)
- MPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo));
-
- addInitialAliasAnalysisPasses(MPM);
-
- // Infer attributes about declarations if possible.
- MPM.add(createInferFunctionAttrsLegacyPass());
-
- if (OptLevel > 2)
- MPM.add(createCallSiteSplittingPass());
-
- MPM.add(createIPSCCPPass()); // IP SCCP
- MPM.add(createCalledValuePropagationPass());
-
- MPM.add(createGlobalOptimizerPass()); // Optimize out global vars
- // Promote any localized global vars.
- MPM.add(createPromoteMemoryToRegisterPass());
-
- MPM.add(createDeadArgEliminationPass()); // Dead argument elimination
-
- MPM.add(createInstructionCombiningPass()); // Clean up after IPCP & DAE
- MPM.add(
- createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp(
- true))); // Clean up after IPCP & DAE
-
- // We add a module alias analysis pass here. In part due to bugs in the
- // analysis infrastructure this "works" in that the analysis stays alive
- // for the entire SCC pass run below.
- MPM.add(createGlobalsAAWrapperPass());
-
- // Start of CallGraph SCC passes.
- bool RunInliner = false;
- if (Inliner) {
- MPM.add(Inliner);
- Inliner = nullptr;
- RunInliner = true;
- }
-
- MPM.add(createPostOrderFunctionAttrsLegacyPass());
-
- addFunctionSimplificationPasses(MPM);
-
- // FIXME: This is a HACK! The inliner pass above implicitly creates a CGSCC
- // pass manager that we are specifically trying to avoid. To prevent this
- // we must insert a no-op module pass to reset the pass manager.
- MPM.add(createBarrierNoopPass());
-
- if (OptLevel > 1)
- // Remove avail extern fns and globals definitions if we aren't
- // compiling an object file for later LTO. For LTO we want to preserve
- // these so they are eligible for inlining at link-time. Note if they
- // are unreferenced they will be removed by GlobalDCE later, so
- // this only impacts referenced available externally globals.
- // Eventually they will be suppressed during codegen, but eliminating
- // here enables more opportunity for GlobalDCE as it may make
- // globals referenced by available external functions dead
- // and saves running remaining passes on the eliminated functions.
- MPM.add(createEliminateAvailableExternallyPass());
-
- MPM.add(createReversePostOrderFunctionAttrsPass());
-
- // The inliner performs some kind of dead code elimination as it goes,
- // but there are cases that are not really caught by it. We might
- // at some point consider teaching the inliner about them, but it
- // is OK for now to run GlobalOpt + GlobalDCE in tandem as their
- // benefits generally outweight the cost, making the whole pipeline
- // faster.
- if (RunInliner) {
- MPM.add(createGlobalOptimizerPass());
- MPM.add(createGlobalDCEPass());
- }
-
- // We add a fresh GlobalsModRef run at this point. This is particularly
- // useful as the above will have inlined, DCE'ed, and function-attr
- // propagated everything. We should at this point have a reasonably minimal
- // and richly annotated call graph. By computing aliasing and mod/ref
- // information for all local globals here, the late loop passes and notably
- // the vectorizer will be able to use them to help recognize vectorizable
- // memory operations.
- //
- // Note that this relies on a bug in the pass manager which preserves
- // a module analysis into a function pass pipeline (and throughout it) so
- // long as the first function pass doesn't invalidate the module analysis.
- // Thus both Float2Int and LoopRotate have to preserve AliasAnalysis for
- // this to work. Fortunately, it is trivial to preserve AliasAnalysis
- // (doing nothing preserves it as it is required to be conservatively
- // correct in the face of IR changes).
- MPM.add(createGlobalsAAWrapperPass());
-
- MPM.add(createFloat2IntPass());
- MPM.add(createLowerConstantIntrinsicsPass());
-
- // Re-rotate loops in all our loop nests. These may have fallout out of
- // rotated form due to GVN or other transformations, and the vectorizer relies
- // on the rotated form. Disable header duplication at -Oz.
- MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1, false));
-
- // Distribute loops to allow partial vectorization. I.e. isolate dependences
- // into separate loop that would otherwise inhibit vectorization. This is
- // currently only performed for loops marked with the metadata
- // llvm.loop.distribute=true or when -enable-loop-distribute is specified.
- MPM.add(createLoopDistributePass());
-
- addVectorPasses(MPM, /* IsFullLTO */ false);
-
- // FIXME: We shouldn't bother with this anymore.
- MPM.add(createStripDeadPrototypesPass()); // Get rid of dead prototypes
-
- // GlobalOpt already deletes dead functions and globals, at -O2 try a
- // late pass of GlobalDCE. It is capable of deleting dead cycles.
- if (OptLevel > 1) {
- MPM.add(createGlobalDCEPass()); // Remove dead fns and globals.
- MPM.add(createConstantMergePass()); // Merge dup global constants
- }
-
- if (MergeFunctions)
- MPM.add(createMergeFunctionsPass());
-
- // LoopSink pass sinks instructions hoisted by LICM, which serves as a
- // canonicalization pass that enables other optimizations. As a result,
- // LoopSink pass needs to be a very late IR pass to avoid undoing LICM
- // result too early.
- MPM.add(createLoopSinkPass());
- // Get rid of LCSSA nodes.
- MPM.add(createInstSimplifyLegacyPass());
-
- // This hoists/decomposes div/rem ops. It should run after other sink/hoist
- // passes to avoid re-sinking, but before SimplifyCFG because it can allow
- // flattening of blocks.
- MPM.add(createDivRemPairsPass());
-
- // LoopSink (and other loop passes since the last simplifyCFG) might have
- // resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
- MPM.add(createCFGSimplificationPass(
- SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
-}
-
-LLVMPassManagerBuilderRef LLVMPassManagerBuilderCreate() {
- PassManagerBuilder *PMB = new PassManagerBuilder();
- return wrap(PMB);
-}
-
-void LLVMPassManagerBuilderDispose(LLVMPassManagerBuilderRef PMB) {
- PassManagerBuilder *Builder = unwrap(PMB);
- delete Builder;
-}
-
-void
-LLVMPassManagerBuilderSetOptLevel(LLVMPassManagerBuilderRef PMB,
- unsigned OptLevel) {
- PassManagerBuilder *Builder = unwrap(PMB);
- Builder->OptLevel = OptLevel;
-}
-
-void
-LLVMPassManagerBuilderSetSizeLevel(LLVMPassManagerBuilderRef PMB,
- unsigned SizeLevel) {
- PassManagerBuilder *Builder = unwrap(PMB);
- Builder->SizeLevel = SizeLevel;
-}
-
-void
-LLVMPassManagerBuilderSetDisableUnitAtATime(LLVMPassManagerBuilderRef PMB,
- LLVMBool Value) {
- // NOTE: The DisableUnitAtATime switch has been removed.
-}
-
-void
-LLVMPassManagerBuilderSetDisableUnrollLoops(LLVMPassManagerBuilderRef PMB,
- LLVMBool Value) {
- PassManagerBuilder *Builder = unwrap(PMB);
- Builder->DisableUnrollLoops = Value;
-}
-
-void
-LLVMPassManagerBuilderSetDisableSimplifyLibCalls(LLVMPassManagerBuilderRef PMB,
- LLVMBool Value) {
- // NOTE: The simplify-libcalls pass has been removed.
-}
-
-void
-LLVMPassManagerBuilderUseInlinerWithThreshold(LLVMPassManagerBuilderRef PMB,
- unsigned Threshold) {
- PassManagerBuilder *Builder = unwrap(PMB);
- Builder->Inliner = createFunctionInliningPass(Threshold);
-}
-
-void
-LLVMPassManagerBuilderPopulateFunctionPassManager(LLVMPassManagerBuilderRef PMB,
- LLVMPassManagerRef PM) {
- PassManagerBuilder *Builder = unwrap(PMB);
- legacy::FunctionPassManager *FPM = unwrap<legacy::FunctionPassManager>(PM);
- Builder->populateFunctionPassManager(*FPM);
-}
-
-void
-LLVMPassManagerBuilderPopulateModulePassManager(LLVMPassManagerBuilderRef PMB,
- LLVMPassManagerRef PM) {
- PassManagerBuilder *Builder = unwrap(PMB);
- legacy::PassManagerBase *MPM = unwrap(PM);
- Builder->populateModulePassManager(*MPM);
-}
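With PassManagerBuilder and its LLVMPassManagerBuilder* C bindings removed, the default optimization pipeline is built through the new pass manager instead. Below is a sketch of the usual replacement, assuming LLVM headers and libraries and an existing Module; the wrapper function name and the O2 level are illustrative choices, not part of this commit.

// Sketch of the new-pass-manager replacement for the deleted builder.
#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"

void runDefaultO2Pipeline(llvm::Module &M) {
  llvm::LoopAnalysisManager LAM;
  llvm::FunctionAnalysisManager FAM;
  llvm::CGSCCAnalysisManager CGAM;
  llvm::ModuleAnalysisManager MAM;

  llvm::PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  // Roughly corresponds to what populateModulePassManager used to assemble
  // for the legacy pass manager at -O2.
  llvm::ModulePassManager MPM =
      PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O2);
  MPM.run(M, MAM);
}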
diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp
index 5c1582ddfdae..e2e6364df906 100644
--- a/llvm/lib/Transforms/IPO/SCCP.cpp
+++ b/llvm/lib/Transforms/IPO/SCCP.cpp
@@ -13,14 +13,14 @@
#include "llvm/Transforms/IPO/SCCP.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/AssumptionCache.h"
-#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/Analysis/ValueTracking.h"
-#include "llvm/InitializePasses.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/CommandLine.h"
@@ -42,8 +42,8 @@ STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable");
STATISTIC(NumInstReplaced,
"Number of instructions replaced with (simpler) instruction");
-static cl::opt<unsigned> FuncSpecializationMaxIters(
- "func-specialization-max-iters", cl::init(1), cl::Hidden, cl::desc(
+static cl::opt<unsigned> FuncSpecMaxIters(
+ "funcspec-max-iters", cl::init(1), cl::Hidden, cl::desc(
"The maximum number of iterations function specialization is run"));
static void findReturnsToZap(Function &F,
@@ -111,10 +111,12 @@ static bool runIPSCCP(
std::function<const TargetLibraryInfo &(Function &)> GetTLI,
std::function<TargetTransformInfo &(Function &)> GetTTI,
std::function<AssumptionCache &(Function &)> GetAC,
- function_ref<AnalysisResultsForFn(Function &)> getAnalysis,
+ std::function<DominatorTree &(Function &)> GetDT,
+ std::function<BlockFrequencyInfo &(Function &)> GetBFI,
bool IsFuncSpecEnabled) {
SCCPSolver Solver(DL, GetTLI, M.getContext());
- FunctionSpecializer Specializer(Solver, M, FAM, GetTLI, GetTTI, GetAC);
+ FunctionSpecializer Specializer(Solver, M, FAM, GetBFI, GetTLI, GetTTI,
+ GetAC);
// Loop over all functions, marking arguments to those with their addresses
// taken or that are external as overdefined.
@@ -122,7 +124,9 @@ static bool runIPSCCP(
if (F.isDeclaration())
continue;
- Solver.addAnalysis(F, getAnalysis(F));
+ DominatorTree &DT = GetDT(F);
+ AssumptionCache &AC = GetAC(F);
+ Solver.addPredicateInfo(F, DT, AC);
// Determine if we can track the function's return values. If so, add the
// function to the solver's set of return-tracked functions.
@@ -158,7 +162,7 @@ static bool runIPSCCP(
if (IsFuncSpecEnabled) {
unsigned Iters = 0;
- while (Iters++ < FuncSpecializationMaxIters && Specializer.run());
+ while (Iters++ < FuncSpecMaxIters && Specializer.run());
}
// Iterate over all of the instructions in the module, replacing them with
@@ -187,8 +191,8 @@ static bool runIPSCCP(
if (ME == MemoryEffects::unknown())
return AL;
- ME |= MemoryEffects(MemoryEffects::Other,
- ME.getModRef(MemoryEffects::ArgMem));
+ ME |= MemoryEffects(IRMemLocation::Other,
+ ME.getModRef(IRMemLocation::ArgMem));
return AL.addFnAttribute(
F.getContext(),
Attribute::getWithMemoryEffects(F.getContext(), ME));
@@ -223,10 +227,9 @@ static bool runIPSCCP(
BB, InsertedValues, NumInstRemoved, NumInstReplaced);
}
- DomTreeUpdater DTU = IsFuncSpecEnabled && Specializer.isClonedFunction(&F)
- ? DomTreeUpdater(DomTreeUpdater::UpdateStrategy::Lazy)
- : Solver.getDTU(F);
-
+ DominatorTree *DT = FAM->getCachedResult<DominatorTreeAnalysis>(F);
+ PostDominatorTree *PDT = FAM->getCachedResult<PostDominatorTreeAnalysis>(F);
+ DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
// Change dead blocks to unreachable. We do it after replacing constants
// in all executable blocks, because changeToUnreachable may remove PHI
// nodes in executable blocks we found values for. The function's entry
@@ -292,13 +295,6 @@ static bool runIPSCCP(
if (!CB || CB->getCalledFunction() != F)
continue;
- // Limit to cases where the return value is guaranteed to be neither
- // poison nor undef. Poison will be outside any range and currently
- // values outside of the specified range cause immediate undefined
- // behavior.
- if (!isGuaranteedNotToBeUndefOrPoison(CB, nullptr, CB))
- continue;
-
// Do not touch existing metadata for now.
// TODO: We should be able to take the intersection of the existing
// metadata and the inferred range.
@@ -338,9 +334,14 @@ static bool runIPSCCP(
// Remove the returned attribute for zapped functions and the
// corresponding call sites.
+ // Also remove any attributes that convert an undef return value into
+ // immediate undefined behavior
+ AttributeMask UBImplyingAttributes =
+ AttributeFuncs::getUBImplyingAttributes();
for (Function *F : FuncZappedReturn) {
for (Argument &A : F->args())
F->removeParamAttr(A.getArgNo(), Attribute::Returned);
+ F->removeRetAttrs(UBImplyingAttributes);
for (Use &U : F->uses()) {
CallBase *CB = dyn_cast<CallBase>(U.getUser());
if (!CB) {
@@ -354,6 +355,7 @@ static bool runIPSCCP(
for (Use &Arg : CB->args())
CB->removeParamAttr(CB->getArgOperandNo(&Arg), Attribute::Returned);
+ CB->removeRetAttrs(UBImplyingAttributes);
}
}
@@ -368,9 +370,9 @@ static bool runIPSCCP(
while (!GV->use_empty()) {
StoreInst *SI = cast<StoreInst>(GV->user_back());
SI->eraseFromParent();
- MadeChanges = true;
}
- M.getGlobalList().erase(GV);
+ MadeChanges = true;
+ M.eraseGlobalVariable(GV);
++NumGlobalConst;
}
@@ -389,15 +391,15 @@ PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) {
auto GetAC = [&FAM](Function &F) -> AssumptionCache & {
return FAM.getResult<AssumptionAnalysis>(F);
};
- auto getAnalysis = [&FAM, this](Function &F) -> AnalysisResultsForFn {
- DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
- return {
- std::make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)),
- &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F),
- isFuncSpecEnabled() ? &FAM.getResult<LoopAnalysis>(F) : nullptr };
+ auto GetDT = [&FAM](Function &F) -> DominatorTree & {
+ return FAM.getResult<DominatorTreeAnalysis>(F);
};
+ auto GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & {
+ return FAM.getResult<BlockFrequencyAnalysis>(F);
+ };
+
- if (!runIPSCCP(M, DL, &FAM, GetTLI, GetTTI, GetAC, getAnalysis,
+ if (!runIPSCCP(M, DL, &FAM, GetTLI, GetTTI, GetAC, GetDT, GetBFI,
isFuncSpecEnabled()))
return PreservedAnalyses::all();
@@ -407,73 +409,3 @@ PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) {
PA.preserve<FunctionAnalysisManagerModuleProxy>();
return PA;
}
-
-namespace {
-
-//===--------------------------------------------------------------------===//
-//
-/// IPSCCP Class - This class implements interprocedural Sparse Conditional
-/// Constant Propagation.
-///
-class IPSCCPLegacyPass : public ModulePass {
-public:
- static char ID;
-
- IPSCCPLegacyPass() : ModulePass(ID) {
- initializeIPSCCPLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
- const DataLayout &DL = M.getDataLayout();
- auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
- return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- };
- auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
- return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- };
- auto GetAC = [this](Function &F) -> AssumptionCache & {
- return this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- };
- auto getAnalysis = [this](Function &F) -> AnalysisResultsForFn {
- DominatorTree &DT =
- this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
- return {
- std::make_unique<PredicateInfo>(
- F, DT,
- this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
- F)),
- nullptr, // We cannot preserve the LI, DT or PDT with the legacy pass
- nullptr, // manager, so set them to nullptr.
- nullptr};
- };
-
- return runIPSCCP(M, DL, nullptr, GetTLI, GetTTI, GetAC, getAnalysis, false);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
-char IPSCCPLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp",
- "Interprocedural Sparse Conditional Constant Propagation",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp",
- "Interprocedural Sparse Conditional Constant Propagation",
- false, false)
-
-// createIPSCCPPass - This is the public interface to this file.
-ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); }
-
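The runIPSCCP change above swaps the packed AnalysisResultsForFn bundle for per-analysis std::function getters (GetDT, GetBFI) that the solver invokes lazily, one function at a time. A standalone analogue of that callback style, with toy DomTree and BlockFreq structs standing in for the real analyses:

// Standalone analogue (illustrative, not LLVM code) of passing analysis
// getters as callbacks instead of a pre-packaged result bundle.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct DomTree { int Height; };
struct BlockFreq { double EntryFreq; };

void runSolver(const std::vector<std::string> &Funcs,
               std::function<DomTree &(const std::string &)> GetDT,
               std::function<BlockFreq &(const std::string &)> GetBFI) {
  // The solver asks for each analysis only when it actually needs it.
  for (const auto &F : Funcs)
    std::cout << F << ": dom height " << GetDT(F).Height << ", entry freq "
              << GetBFI(F).EntryFreq << '\n';
}

int main() {
  std::unordered_map<std::string, DomTree> DTs;
  std::unordered_map<std::string, BlockFreq> BFIs;
  DTs["foo"] = {3};
  DTs["bar"] = {1};
  BFIs["foo"] = {8.0};
  BFIs["bar"] = {1.0};
  auto GetDT = [&](const std::string &F) -> DomTree & { return DTs[F]; };
  auto GetBFI = [&](const std::string &F) -> BlockFreq & { return BFIs[F]; };
  runSolver({"foo", "bar"}, GetDT, GetBFI);
}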
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 93b368fd72a6..a53baecd4776 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -35,9 +35,9 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
-#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ReplayInlineAdvisor.h"
@@ -58,8 +58,6 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PseudoProbe.h"
#include "llvm/IR/ValueSymbolTable.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/SampleProf.h"
#include "llvm/ProfileData/SampleProfReader.h"
@@ -67,6 +65,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/ProfiledCallGraph.h"
@@ -129,6 +128,11 @@ static cl::opt<std::string> SampleProfileRemappingFile(
"sample-profile-remapping-file", cl::init(""), cl::value_desc("filename"),
cl::desc("Profile remapping file loaded by -sample-profile"), cl::Hidden);
+static cl::opt<bool> SalvageStaleProfile(
+ "salvage-stale-profile", cl::Hidden, cl::init(false),
+ cl::desc("Salvage stale profile by fuzzy matching and use the remapped "
+ "location for sample profile query."));
+
static cl::opt<bool> ReportProfileStaleness(
"report-profile-staleness", cl::Hidden, cl::init(false),
cl::desc("Compute and report stale profile statistical metrics."));
@@ -138,6 +142,11 @@ static cl::opt<bool> PersistProfileStaleness(
cl::desc("Compute stale profile statistical metrics and write it into the "
"native object file(.llvm_stats section)."));
+static cl::opt<bool> FlattenProfileForMatching(
+ "flatten-profile-for-matching", cl::Hidden, cl::init(true),
+ cl::desc(
+ "Use flattened profile for stale profile detection and matching."));
+
static cl::opt<bool> ProfileSampleAccurate(
"profile-sample-accurate", cl::Hidden, cl::init(false),
cl::desc("If the sample profile is accurate, we will mark all un-sampled "
@@ -173,9 +182,6 @@ static cl::opt<bool>
cl::desc("Process functions in a top-down order "
"defined by the profiled call graph when "
"-sample-profile-top-down-load is on."));
-cl::opt<bool>
- SortProfiledSCC("sort-profiled-scc-member", cl::init(true), cl::Hidden,
- cl::desc("Sort profiled recursion by edge weights."));
static cl::opt<bool> ProfileSizeInline(
"sample-profile-inline-size", cl::Hidden, cl::init(false),
@@ -191,6 +197,11 @@ static cl::opt<bool> DisableSampleLoaderInlining(
"pass, and merge (or scale) profiles (as configured by "
"--sample-profile-merge-inlinee)."));
+namespace llvm {
+cl::opt<bool>
+ SortProfiledSCC("sort-profiled-scc-member", cl::init(true), cl::Hidden,
+ cl::desc("Sort profiled recursion by edge weights."));
+
cl::opt<int> ProfileInlineGrowthLimit(
"sample-profile-inline-growth-limit", cl::Hidden, cl::init(12),
cl::desc("The size growth ratio limit for priority-based sample profile "
@@ -214,6 +225,7 @@ cl::opt<int> SampleHotCallSiteThreshold(
cl::opt<int> SampleColdCallSiteThreshold(
"sample-profile-cold-inline-threshold", cl::Hidden, cl::init(45),
cl::desc("Threshold for inlining cold callsites"));
+} // namespace llvm
static cl::opt<unsigned> ProfileICPRelativeHotness(
"sample-profile-icp-relative-hotness", cl::Hidden, cl::init(25),
@@ -307,7 +319,9 @@ static cl::opt<bool> AnnotateSampleProfileInlinePhase(
cl::desc("Annotate LTO phase (prelink / postlink), or main (no LTO) for "
"sample-profile inline pass name."));
+namespace llvm {
extern cl::opt<bool> EnableExtTspBlockPlacement;
+}
namespace {
@@ -428,6 +442,11 @@ class SampleProfileMatcher {
Module &M;
SampleProfileReader &Reader;
const PseudoProbeManager *ProbeManager;
+ SampleProfileMap FlattenedProfiles;
+ // For each function, the matcher generates a map, of which each entry is a
+ // mapping from the source location of the current build to the source location in
+ // the profile.
+ StringMap<LocToLocMap> FuncMappings;
// Profile mismatching statistics.
uint64_t TotalProfiledCallsites = 0;
@@ -442,9 +461,43 @@ class SampleProfileMatcher {
public:
SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
const PseudoProbeManager *ProbeManager)
- : M(M), Reader(Reader), ProbeManager(ProbeManager) {}
- void detectProfileMismatch();
- void detectProfileMismatch(const Function &F, const FunctionSamples &FS);
+ : M(M), Reader(Reader), ProbeManager(ProbeManager) {
+ if (FlattenProfileForMatching) {
+ ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
+ FunctionSamples::ProfileIsCS);
+ }
+ }
+ void runOnModule();
+
+private:
+ FunctionSamples *getFlattenedSamplesFor(const Function &F) {
+ StringRef CanonFName = FunctionSamples::getCanonicalFnName(F);
+ auto It = FlattenedProfiles.find(CanonFName);
+ if (It != FlattenedProfiles.end())
+ return &It->second;
+ return nullptr;
+ }
+ void runOnFunction(const Function &F, const FunctionSamples &FS);
+ void countProfileMismatches(
+ const FunctionSamples &FS,
+ const std::unordered_set<LineLocation, LineLocationHash>
+ &MatchedCallsiteLocs,
+ uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites);
+
+ LocToLocMap &getIRToProfileLocationMap(const Function &F) {
+ auto Ret = FuncMappings.try_emplace(
+ FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
+ return Ret.first->second;
+ }
+ void distributeIRToProfileLocationMap();
+ void distributeIRToProfileLocationMap(FunctionSamples &FS);
+ void populateProfileCallsites(
+ const FunctionSamples &FS,
+ StringMap<std::set<LineLocation>> &CalleeToCallsitesMap);
+ void runStaleProfileMatching(
+ const std::map<LineLocation, StringRef> &IRLocations,
+ StringMap<std::set<LineLocation>> &CalleeToCallsitesMap,
+ LocToLocMap &IRToProfileLocationMap);
};
/// Sample profile pass.
@@ -452,15 +505,16 @@ public:
/// This pass reads profile data from the file specified by
/// -sample-profile-file and annotates every affected function with the
/// profile information found in that file.
-class SampleProfileLoader final
- : public SampleProfileLoaderBaseImpl<BasicBlock> {
+class SampleProfileLoader final : public SampleProfileLoaderBaseImpl<Function> {
public:
SampleProfileLoader(
StringRef Name, StringRef RemapName, ThinOrFullLTOPhase LTOPhase,
+ IntrusiveRefCntPtr<vfs::FileSystem> FS,
std::function<AssumptionCache &(Function &)> GetAssumptionCache,
std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo,
std::function<const TargetLibraryInfo &(Function &)> GetTLI)
- : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)),
+ : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
+ std::move(FS)),
GetAC(std::move(GetAssumptionCache)),
GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)),
LTOPhase(LTOPhase),
@@ -471,13 +525,12 @@ public:
bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr);
bool runOnModule(Module &M, ModuleAnalysisManager *AM,
- ProfileSummaryInfo *_PSI, CallGraph *CG);
+ ProfileSummaryInfo *_PSI, LazyCallGraph &CG);
protected:
bool runOnFunction(Function &F, ModuleAnalysisManager *AM);
bool emitAnnotations(Function &F);
ErrorOr<uint64_t> getInstWeight(const Instruction &I) override;
- ErrorOr<uint64_t> getProbeWeight(const Instruction &I);
const FunctionSamples *findCalleeFunctionSamples(const CallBase &I) const;
const FunctionSamples *
findFunctionSamples(const Instruction &I) const override;
@@ -512,8 +565,8 @@ protected:
void promoteMergeNotInlinedContextSamples(
MapVector<CallBase *, const FunctionSamples *> NonInlinedCallSites,
const Function &F);
- std::vector<Function *> buildFunctionOrder(Module &M, CallGraph *CG);
- std::unique_ptr<ProfiledCallGraph> buildProfiledCallGraph(CallGraph &CG);
+ std::vector<Function *> buildFunctionOrder(Module &M, LazyCallGraph &CG);
+ std::unique_ptr<ProfiledCallGraph> buildProfiledCallGraph(Module &M);
void generateMDProfMetadata(Function &F);
/// Map from function name to Function *. Used to find the function from
@@ -573,9 +626,6 @@ protected:
// External inline advisor used to replay inline decision from remarks.
std::unique_ptr<InlineAdvisor> ExternalInlineAdvisor;
- // A pseudo probe helper to correlate the imported sample counts.
- std::unique_ptr<PseudoProbeManager> ProbeManager;
-
// A helper to implement the sample profile matching algorithm.
std::unique_ptr<SampleProfileMatcher> MatchingManager;
@@ -586,6 +636,50 @@ private:
};
} // end anonymous namespace
+namespace llvm {
+template <>
+inline bool SampleProfileInference<Function>::isExit(const BasicBlock *BB) {
+ return succ_empty(BB);
+}
+
+template <>
+inline void SampleProfileInference<Function>::findUnlikelyJumps(
+ const std::vector<const BasicBlockT *> &BasicBlocks,
+ BlockEdgeMap &Successors, FlowFunction &Func) {
+ for (auto &Jump : Func.Jumps) {
+ const auto *BB = BasicBlocks[Jump.Source];
+ const auto *Succ = BasicBlocks[Jump.Target];
+ const Instruction *TI = BB->getTerminator();
+    // Check if a block ends with InvokeInst and mark the non-taken branch
+    // unlikely. In that case, block Succ should be a landing pad.
+ if (Successors[BB].size() == 2 && Successors[BB].back() == Succ) {
+ if (isa<InvokeInst>(TI)) {
+ Jump.IsUnlikely = true;
+ }
+ }
+ const Instruction *SuccTI = Succ->getTerminator();
+    // Check if the target block contains UnreachableInst and mark it unlikely.
+ if (SuccTI->getNumSuccessors() == 0) {
+ if (isa<UnreachableInst>(SuccTI)) {
+ Jump.IsUnlikely = true;
+ }
+ }
+ }
+}
+
+template <>
+void SampleProfileLoaderBaseImpl<Function>::computeDominanceAndLoopInfo(
+ Function &F) {
+ DT.reset(new DominatorTree);
+ DT->recalculate(F);
+
+ PDT.reset(new PostDominatorTree(F));
+
+ LI.reset(new LoopInfo);
+ LI->analyze(*DT);
+}
+} // namespace llvm
+
ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
if (FunctionSamples::ProfileIsProbeBased)
return getProbeWeight(Inst);
@@ -614,68 +708,6 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
return getInstWeightImpl(Inst);
}
-// Here use error_code to represent: 1) The dangling probe. 2) Ignore the weight
-// of non-probe instruction. So if all instructions of the BB give error_code,
-// tell the inference algorithm to infer the BB weight.
-ErrorOr<uint64_t> SampleProfileLoader::getProbeWeight(const Instruction &Inst) {
- assert(FunctionSamples::ProfileIsProbeBased &&
- "Profile is not pseudo probe based");
- std::optional<PseudoProbe> Probe = extractProbe(Inst);
- // Ignore the non-probe instruction. If none of the instruction in the BB is
- // probe, we choose to infer the BB's weight.
- if (!Probe)
- return std::error_code();
-
- const FunctionSamples *FS = findFunctionSamples(Inst);
- // If none of the instruction has FunctionSample, we choose to return zero
- // value sample to indicate the BB is cold. This could happen when the
- // instruction is from inlinee and no profile data is found.
- // FIXME: This should not be affected by the source drift issue as 1) if the
- // newly added function is top-level inliner, it won't match the CFG checksum
- // in the function profile or 2) if it's the inlinee, the inlinee should have
- // a profile, otherwise it wouldn't be inlined. For non-probe based profile,
- // we can improve it by adding a switch for profile-sample-block-accurate for
- // block level counts in the future.
- if (!FS)
- return 0;
-
- // For non-CS profile, If a direct call/invoke instruction is inlined in
- // profile (findCalleeFunctionSamples returns non-empty result), but not
- // inlined here, it means that the inlined callsite has no sample, thus the
- // call instruction should have 0 count.
- // For CS profile, the callsite count of previously inlined callees is
- // populated with the entry count of the callees.
- if (!FunctionSamples::ProfileIsCS)
- if (const auto *CB = dyn_cast<CallBase>(&Inst))
- if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB))
- return 0;
-
- const ErrorOr<uint64_t> &R = FS->findSamplesAt(Probe->Id, 0);
- if (R) {
- uint64_t Samples = R.get() * Probe->Factor;
- bool FirstMark = CoverageTracker.markSamplesUsed(FS, Probe->Id, 0, Samples);
- if (FirstMark) {
- ORE->emit([&]() {
- OptimizationRemarkAnalysis Remark(DEBUG_TYPE, "AppliedSamples", &Inst);
- Remark << "Applied " << ore::NV("NumSamples", Samples);
- Remark << " samples from profile (ProbeId=";
- Remark << ore::NV("ProbeId", Probe->Id);
- Remark << ", Factor=";
- Remark << ore::NV("Factor", Probe->Factor);
- Remark << ", OriginalSamples=";
- Remark << ore::NV("OriginalSamples", R.get());
- Remark << ")";
- return Remark;
- });
- }
- LLVM_DEBUG(dbgs() << " " << Probe->Id << ":" << Inst
- << " - weight: " << R.get() << " - factor: "
- << format("%0.2f", Probe->Factor) << ")\n");
- return Samples;
- }
- return R;
-}
-
/// Get the FunctionSamples for a call instruction.
///
/// The FunctionSamples of a call/invoke instruction \p Inst is the inlined
@@ -1041,8 +1073,8 @@ void SampleProfileLoader::findExternalInlineCandidate(
DenseSet<GlobalValue::GUID> &InlinedGUIDs,
const StringMap<Function *> &SymbolMap, uint64_t Threshold) {
- // If ExternalInlineAdvisor wants to inline an external function
- // make sure it's imported
+  // If ExternalInlineAdvisor (ReplayInlineAdvisor) wants to inline an external
+  // function, make sure it's imported.
if (CB && getExternalInlineAdvisorShouldInline(*CB)) {
// Samples may not exist for replayed function, if so
// just add the direct GUID and move on
@@ -1055,7 +1087,13 @@ void SampleProfileLoader::findExternalInlineCandidate(
Threshold = 0;
}
- assert(Samples && "expect non-null caller profile");
+  // In some rare cases, the call instruction could be changed after being
+  // pushed into the inline candidate queue; this is because earlier inlining
+  // may expose constant propagation, which can change an indirect call into a
+  // direct call. When this happens, we may fail to find matching function
+  // samples for the candidate later, even if a match was found when the
+  // candidate was enqueued.
+ if (!Samples)
+ return;
// For AutoFDO profile, retrieve candidate profiles by walking over
// the nested inlinee profiles.
@@ -1255,7 +1293,7 @@ bool SampleProfileLoader::tryInlineCandidate(
if (!Cost)
return false;
- InlineFunctionInfo IFI(nullptr, GetAC);
+ InlineFunctionInfo IFI(GetAC);
IFI.UpdateProfile = false;
InlineResult IR = InlineFunction(CB, IFI,
/*MergeAttributes=*/true);
@@ -1784,9 +1822,10 @@ bool SampleProfileLoader::emitAnnotations(Function &F) {
if (!ProbeManager->profileIsValid(F, *Samples)) {
LLVM_DEBUG(
dbgs() << "Profile is invalid due to CFG mismatch for Function "
- << F.getName());
+ << F.getName() << "\n");
++NumMismatchedProfile;
- return false;
+ if (!SalvageStaleProfile)
+ return false;
}
++NumMatchedProfile;
} else {
@@ -1813,7 +1852,7 @@ bool SampleProfileLoader::emitAnnotations(Function &F) {
}
std::unique_ptr<ProfiledCallGraph>
-SampleProfileLoader::buildProfiledCallGraph(CallGraph &CG) {
+SampleProfileLoader::buildProfiledCallGraph(Module &M) {
std::unique_ptr<ProfiledCallGraph> ProfiledCG;
if (FunctionSamples::ProfileIsCS)
ProfiledCG = std::make_unique<ProfiledCallGraph>(*ContextTracker);
@@ -1823,18 +1862,17 @@ SampleProfileLoader::buildProfiledCallGraph(CallGraph &CG) {
// Add all functions into the profiled call graph even if they are not in
// the profile. This makes sure functions missing from the profile still
// gets a chance to be processed.
- for (auto &Node : CG) {
- const auto *F = Node.first;
- if (!F || F->isDeclaration() || !F->hasFnAttribute("use-sample-profile"))
+ for (Function &F : M) {
+ if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
continue;
- ProfiledCG->addProfiledFunction(FunctionSamples::getCanonicalFnName(*F));
+ ProfiledCG->addProfiledFunction(FunctionSamples::getCanonicalFnName(F));
}
return ProfiledCG;
}
std::vector<Function *>
-SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) {
+SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
std::vector<Function *> FunctionOrderList;
FunctionOrderList.reserve(M.size());
@@ -1842,7 +1880,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) {
errs() << "WARNING: -use-profiled-call-graph ignored, should be used "
"together with -sample-profile-top-down-load.\n";
- if (!ProfileTopDownLoad || CG == nullptr) {
+ if (!ProfileTopDownLoad) {
if (ProfileMergeInlinee) {
// Disable ProfileMergeInlinee if profile is not loaded in top down order,
// because the profile for a function may be used for the profile
@@ -1858,8 +1896,6 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) {
return FunctionOrderList;
}
- assert(&CG->getModule() == &M);
-
if (UseProfiledCallGraph || (FunctionSamples::ProfileIsCS &&
!UseProfiledCallGraph.getNumOccurrences())) {
// Use profiled call edges to augment the top-down order. There are cases
@@ -1910,7 +1946,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) {
// static call edges are not so important when they don't correspond to a
// context in the profile.
- std::unique_ptr<ProfiledCallGraph> ProfiledCG = buildProfiledCallGraph(*CG);
+ std::unique_ptr<ProfiledCallGraph> ProfiledCG = buildProfiledCallGraph(M);
scc_iterator<ProfiledCallGraph *> CGI = scc_begin(ProfiledCG.get());
while (!CGI.isAtEnd()) {
auto Range = *CGI;
@@ -1927,25 +1963,27 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) {
++CGI;
}
} else {
- scc_iterator<CallGraph *> CGI = scc_begin(CG);
- while (!CGI.isAtEnd()) {
- for (CallGraphNode *Node : *CGI) {
- auto *F = Node->getFunction();
- if (F && !F->isDeclaration() && F->hasFnAttribute("use-sample-profile"))
- FunctionOrderList.push_back(F);
+ CG.buildRefSCCs();
+ for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) {
+ for (LazyCallGraph::SCC &C : RC) {
+ for (LazyCallGraph::Node &N : C) {
+ Function &F = N.getFunction();
+ if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile"))
+ FunctionOrderList.push_back(&F);
+ }
}
- ++CGI;
}
}
+ std::reverse(FunctionOrderList.begin(), FunctionOrderList.end());
+
LLVM_DEBUG({
dbgs() << "Function processing order:\n";
- for (auto F : reverse(FunctionOrderList)) {
+ for (auto F : FunctionOrderList) {
dbgs() << F->getName() << "\n";
}
});
- std::reverse(FunctionOrderList.begin(), FunctionOrderList.end());
return FunctionOrderList;
}
@@ -1954,7 +1992,7 @@ bool SampleProfileLoader::doInitialization(Module &M,
auto &Ctx = M.getContext();
auto ReaderOrErr = SampleProfileReader::create(
- Filename, Ctx, FSDiscriminatorPass::Base, RemappingFilename);
+ Filename, Ctx, *FS, FSDiscriminatorPass::Base, RemappingFilename);
if (std::error_code EC = ReaderOrErr.getError()) {
std::string Msg = "Could not open profile: " + EC.message();
Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
@@ -2016,6 +2054,16 @@ bool SampleProfileLoader::doInitialization(Module &M,
UsePreInlinerDecision = true;
}
+  // Enable stale profile matching by default for probe-based profiles. The
+  // matching currently relies on detecting a checksum mismatch, which is only
+  // available in pseudo-probe mode. Removing the checksum check could cause
+  // regressions for some cases, so further tuning might be needed if we want
+  // to enable it for all cases.
+ if (Reader->profileIsProbeBased() &&
+ !SalvageStaleProfile.getNumOccurrences()) {
+ SalvageStaleProfile = true;
+ }
+
if (!Reader->profileIsCS()) {
// Non-CS profile should be fine without a function size budget for the
// inliner since the contexts in the profile are either all from inlining
@@ -2046,7 +2094,8 @@ bool SampleProfileLoader::doInitialization(Module &M,
}
}
- if (ReportProfileStaleness || PersistProfileStaleness) {
+ if (ReportProfileStaleness || PersistProfileStaleness ||
+ SalvageStaleProfile) {
MatchingManager =
std::make_unique<SampleProfileMatcher>(M, *Reader, ProbeManager.get());
}
@@ -2054,8 +2103,167 @@ bool SampleProfileLoader::doInitialization(Module &M,
return true;
}
-void SampleProfileMatcher::detectProfileMismatch(const Function &F,
- const FunctionSamples &FS) {
+void SampleProfileMatcher::countProfileMismatches(
+ const FunctionSamples &FS,
+ const std::unordered_set<LineLocation, LineLocationHash>
+ &MatchedCallsiteLocs,
+ uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) {
+
+ auto isInvalidLineOffset = [](uint32_t LineOffset) {
+ return LineOffset & 0x8000;
+ };
+
+  // Check if there are any callsites in the profile that do not match any IR
+  // callsite; those callsite samples will be discarded.
+ for (auto &I : FS.getBodySamples()) {
+ const LineLocation &Loc = I.first;
+ if (isInvalidLineOffset(Loc.LineOffset))
+ continue;
+
+ uint64_t Count = I.second.getSamples();
+ if (!I.second.getCallTargets().empty()) {
+ TotalCallsiteSamples += Count;
+ FuncProfiledCallsites++;
+ if (!MatchedCallsiteLocs.count(Loc)) {
+ MismatchedCallsiteSamples += Count;
+ FuncMismatchedCallsites++;
+ }
+ }
+ }
+
+ for (auto &I : FS.getCallsiteSamples()) {
+ const LineLocation &Loc = I.first;
+ if (isInvalidLineOffset(Loc.LineOffset))
+ continue;
+
+ uint64_t Count = 0;
+ for (auto &FM : I.second) {
+ Count += FM.second.getHeadSamplesEstimate();
+ }
+ TotalCallsiteSamples += Count;
+ FuncProfiledCallsites++;
+ if (!MatchedCallsiteLocs.count(Loc)) {
+ MismatchedCallsiteSamples += Count;
+ FuncMismatchedCallsites++;
+ }
+ }
+}
+
+// Populate the anchors (direct callee names) from the profile.
+void SampleProfileMatcher::populateProfileCallsites(
+ const FunctionSamples &FS,
+ StringMap<std::set<LineLocation>> &CalleeToCallsitesMap) {
+ for (const auto &I : FS.getBodySamples()) {
+ const auto &Loc = I.first;
+ const auto &CTM = I.second.getCallTargets();
+ // Filter out possible indirect calls, use direct callee name as anchor.
+ if (CTM.size() == 1) {
+ StringRef CalleeName = CTM.begin()->first();
+ const auto &Candidates = CalleeToCallsitesMap.try_emplace(
+ CalleeName, std::set<LineLocation>());
+ Candidates.first->second.insert(Loc);
+ }
+ }
+
+ for (const auto &I : FS.getCallsiteSamples()) {
+ const LineLocation &Loc = I.first;
+ const auto &CalleeMap = I.second;
+ // Filter out possible indirect calls, use direct callee name as anchor.
+ if (CalleeMap.size() == 1) {
+ StringRef CalleeName = CalleeMap.begin()->first;
+ const auto &Candidates = CalleeToCallsitesMap.try_emplace(
+ CalleeName, std::set<LineLocation>());
+ Candidates.first->second.insert(Loc);
+ }
+ }
+}
+
+// Call target name anchor based profile fuzzy matching.
+// Input:
+//   For IR locations, the anchor is the callee name of a direct callsite; for
+// profile locations, it is the call target name for BodySamples or the
+// inlinee's profile name for CallsiteSamples.
+// Matching heuristic:
+//   First match all the anchors in lexical order, then split the non-anchor
+// locations between the two anchors evenly; the first half is matched based on
+// the start anchor and the second half based on the end anchor.
+// For example, given:
+// IR locations: [1, 2(foo), 3, 5, 6(bar), 7]
+// Profile locations: [1, 2, 3(foo), 4, 7, 8(bar), 9]
+// The matching gives:
+// [1, 2(foo), 3, 5, 6(bar), 7]
+// | | | | | |
+// [1, 2, 3(foo), 4, 7, 8(bar), 9]
+// The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
+void SampleProfileMatcher::runStaleProfileMatching(
+ const std::map<LineLocation, StringRef> &IRLocations,
+ StringMap<std::set<LineLocation>> &CalleeToCallsitesMap,
+ LocToLocMap &IRToProfileLocationMap) {
+ assert(IRToProfileLocationMap.empty() &&
+ "Run stale profile matching only once per function");
+
+ auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) {
+ // Skip the unchanged location mapping to save memory.
+ if (From != To)
+ IRToProfileLocationMap.insert({From, To});
+ };
+
+ // Use function's beginning location as the initial anchor.
+ int32_t LocationDelta = 0;
+ SmallVector<LineLocation> LastMatchedNonAnchors;
+
+ for (const auto &IR : IRLocations) {
+ const auto &Loc = IR.first;
+ StringRef CalleeName = IR.second;
+ bool IsMatchedAnchor = false;
+ // Match the anchor location in lexical order.
+ if (!CalleeName.empty()) {
+ auto ProfileAnchors = CalleeToCallsitesMap.find(CalleeName);
+ if (ProfileAnchors != CalleeToCallsitesMap.end() &&
+ !ProfileAnchors->second.empty()) {
+ auto CI = ProfileAnchors->second.begin();
+ const auto Candidate = *CI;
+ ProfileAnchors->second.erase(CI);
+ InsertMatching(Loc, Candidate);
+ LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName
+ << " is matched from " << Loc << " to " << Candidate
+ << "\n");
+ LocationDelta = Candidate.LineOffset - Loc.LineOffset;
+
+ // Match backwards for non-anchor locations.
+ // The locations in LastMatchedNonAnchors have been matched forwards
+        // based on the previous anchor; split it evenly and overwrite the
+ // second half based on the current anchor.
+ for (size_t I = (LastMatchedNonAnchors.size() + 1) / 2;
+ I < LastMatchedNonAnchors.size(); I++) {
+ const auto &L = LastMatchedNonAnchors[I];
+ uint32_t CandidateLineOffset = L.LineOffset + LocationDelta;
+ LineLocation Candidate(CandidateLineOffset, L.Discriminator);
+ InsertMatching(L, Candidate);
+ LLVM_DEBUG(dbgs() << "Location is rematched backwards from " << L
+ << " to " << Candidate << "\n");
+ }
+
+ IsMatchedAnchor = true;
+ LastMatchedNonAnchors.clear();
+ }
+ }
+
+ // Match forwards for non-anchor locations.
+ if (!IsMatchedAnchor) {
+ uint32_t CandidateLineOffset = Loc.LineOffset + LocationDelta;
+ LineLocation Candidate(CandidateLineOffset, Loc.Discriminator);
+ InsertMatching(Loc, Candidate);
+ LLVM_DEBUG(dbgs() << "Location is matched from " << Loc << " to "
+ << Candidate << "\n");
+ LastMatchedNonAnchors.emplace_back(Loc);
+ }
+ }
+}
+
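As a rough standalone sketch of the heuristic described in the comment above (not the LLVM implementation itself; it uses plain integer offsets and std containers in place of LineLocation/StringMap, and matchLocations/ProfileAnchors are illustrative names):

#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>

// Anchors (direct callee names) are matched in lexical order; non-anchor
// offsets are shifted by the delta of the most recently matched anchor, and
// the second half of each non-anchor run is rematched once the next anchor
// is found.
std::map<uint32_t, uint32_t>
matchLocations(const std::map<uint32_t, std::string> &IRLocs,
               std::map<std::string, std::set<uint32_t>> &ProfileAnchors) {
  std::map<uint32_t, uint32_t> Mapping;
  auto Insert = [&](uint32_t From, uint32_t To) {
    if (From == To)
      Mapping.erase(From); // Identity mappings are not kept, to save memory.
    else
      Mapping[From] = To;
  };
  int32_t Delta = 0;
  std::vector<uint32_t> LastNonAnchors;
  for (const auto &[Loc, Callee] : IRLocs) {
    auto It = Callee.empty() ? ProfileAnchors.end()
                             : ProfileAnchors.find(Callee);
    if (It != ProfileAnchors.end() && !It->second.empty()) {
      // Anchor: consume the lexically first unmatched profile location for
      // this callee and update the running delta.
      uint32_t Candidate = *It->second.begin();
      It->second.erase(It->second.begin());
      Insert(Loc, Candidate);
      Delta = (int32_t)Candidate - (int32_t)Loc;
      // Rematch the second half of the previous non-anchor run against the
      // new anchor.
      for (size_t I = (LastNonAnchors.size() + 1) / 2;
           I < LastNonAnchors.size(); ++I)
        Insert(LastNonAnchors[I], LastNonAnchors[I] + Delta);
      LastNonAnchors.clear();
    } else {
      // Non-anchor: match forwards using the current delta.
      Insert(Loc, Loc + Delta);
      LastNonAnchors.push_back(Loc);
    }
  }
  return Mapping;
}

On the example above this yields 2->3, 3->4, 5->7, 6->8 and 7->9, matching the documented output.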
+void SampleProfileMatcher::runOnFunction(const Function &F,
+ const FunctionSamples &FS) {
+ bool IsFuncHashMismatch = false;
if (FunctionSamples::ProfileIsProbeBased) {
uint64_t Count = FS.getTotalSamples();
TotalFuncHashSamples += Count;
@@ -2063,16 +2271,24 @@ void SampleProfileMatcher::detectProfileMismatch(const Function &F,
if (!ProbeManager->profileIsValid(F, FS)) {
MismatchedFuncHashSamples += Count;
NumMismatchedFuncHash++;
- return;
+ IsFuncHashMismatch = true;
}
}
std::unordered_set<LineLocation, LineLocationHash> MatchedCallsiteLocs;
+  // The value of the map is the direct callsite's callee name; an empty
+  // StringRef is used for non-direct-call sites.
+ std::map<LineLocation, StringRef> IRLocations;
- // Go through all the callsites on the IR and flag the callsite if the target
- // name is the same as the one in the profile.
+ // Extract profile matching anchors and profile mismatch metrics in the IR.
for (auto &BB : F) {
for (auto &I : BB) {
+ // TODO: Support line-number based location(AutoFDO).
+ if (FunctionSamples::ProfileIsProbeBased && isa<PseudoProbeInst>(&I)) {
+ if (std::optional<PseudoProbe> Probe = extractProbe(I))
+ IRLocations.emplace(LineLocation(Probe->Id, 0), StringRef());
+ }
+
if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
continue;
@@ -2084,6 +2300,17 @@ void SampleProfileMatcher::detectProfileMismatch(const Function &F,
if (Function *Callee = CB->getCalledFunction())
CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName());
+      // Forcibly overwrite the callee name in case a non-call location was
+      // written before.
+ auto R = IRLocations.emplace(IRCallsite, CalleeName);
+ R.first->second = CalleeName;
+ assert((!FunctionSamples::ProfileIsProbeBased || R.second ||
+ R.first->second == CalleeName) &&
+ "Overwrite non-call or different callee name location for "
+ "pseudo probe callsite");
+
+ // Go through all the callsites on the IR and flag the callsite if the
+ // target name is the same as the one in the profile.
const auto CTM = FS.findCallTargetMapAt(IRCallsite);
const auto CallsiteFS = FS.findFunctionSamplesMapAt(IRCallsite);
@@ -2105,55 +2332,54 @@ void SampleProfileMatcher::detectProfileMismatch(const Function &F,
}
}
- auto isInvalidLineOffset = [](uint32_t LineOffset) {
- return LineOffset & 0x8000;
- };
+ // Detect profile mismatch for profile staleness metrics report.
+ if (ReportProfileStaleness || PersistProfileStaleness) {
+ uint64_t FuncMismatchedCallsites = 0;
+ uint64_t FuncProfiledCallsites = 0;
+ countProfileMismatches(FS, MatchedCallsiteLocs, FuncMismatchedCallsites,
+ FuncProfiledCallsites);
+ TotalProfiledCallsites += FuncProfiledCallsites;
+ NumMismatchedCallsites += FuncMismatchedCallsites;
+ LLVM_DEBUG({
+ if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
+ FuncMismatchedCallsites)
+ dbgs() << "Function checksum is matched but there are "
+ << FuncMismatchedCallsites << "/" << FuncProfiledCallsites
+ << " mismatched callsites.\n";
+ });
+ }
- // Check if there are any callsites in the profile that does not match to any
- // IR callsites, those callsite samples will be discarded.
- for (auto &I : FS.getBodySamples()) {
- const LineLocation &Loc = I.first;
- if (isInvalidLineOffset(Loc.LineOffset))
- continue;
+ if (IsFuncHashMismatch && SalvageStaleProfile) {
+ LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
+ << "\n");
- uint64_t Count = I.second.getSamples();
- if (!I.second.getCallTargets().empty()) {
- TotalCallsiteSamples += Count;
- TotalProfiledCallsites++;
- if (!MatchedCallsiteLocs.count(Loc)) {
- MismatchedCallsiteSamples += Count;
- NumMismatchedCallsites++;
- }
- }
- }
+ StringMap<std::set<LineLocation>> CalleeToCallsitesMap;
+ populateProfileCallsites(FS, CalleeToCallsitesMap);
- for (auto &I : FS.getCallsiteSamples()) {
- const LineLocation &Loc = I.first;
- if (isInvalidLineOffset(Loc.LineOffset))
- continue;
+    // The matching result will be saved to IRToProfileLocationMap; create a
+    // new map for each function.
+ auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
- uint64_t Count = 0;
- for (auto &FM : I.second) {
- Count += FM.second.getHeadSamplesEstimate();
- }
- TotalCallsiteSamples += Count;
- TotalProfiledCallsites++;
- if (!MatchedCallsiteLocs.count(Loc)) {
- MismatchedCallsiteSamples += Count;
- NumMismatchedCallsites++;
- }
+ runStaleProfileMatching(IRLocations, CalleeToCallsitesMap,
+ IRToProfileLocationMap);
}
}
-void SampleProfileMatcher::detectProfileMismatch() {
+void SampleProfileMatcher::runOnModule() {
for (auto &F : M) {
if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
continue;
- FunctionSamples *FS = Reader.getSamplesFor(F);
+ FunctionSamples *FS = nullptr;
+ if (FlattenProfileForMatching)
+ FS = getFlattenedSamplesFor(F);
+ else
+ FS = Reader.getSamplesFor(F);
if (!FS)
continue;
- detectProfileMismatch(F, *FS);
+ runOnFunction(F, *FS);
}
+ if (SalvageStaleProfile)
+ distributeIRToProfileLocationMap();
if (ReportProfileStaleness) {
if (FunctionSamples::ProfileIsProbeBased) {
@@ -2196,8 +2422,31 @@ void SampleProfileMatcher::detectProfileMismatch() {
}
}
+void SampleProfileMatcher::distributeIRToProfileLocationMap(
+ FunctionSamples &FS) {
+ const auto ProfileMappings = FuncMappings.find(FS.getName());
+ if (ProfileMappings != FuncMappings.end()) {
+ FS.setIRToProfileLocationMap(&(ProfileMappings->second));
+ }
+
+ for (auto &Inlinees : FS.getCallsiteSamples()) {
+ for (auto FS : Inlinees.second) {
+ distributeIRToProfileLocationMap(FS.second);
+ }
+ }
+}
+
+// Use a central place to distribute the matching results. Outlined and inlined
+// profiles with the same function name will be set to the same pointer.
+void SampleProfileMatcher::distributeIRToProfileLocationMap() {
+ for (auto &I : Reader.getProfiles()) {
+ distributeIRToProfileLocationMap(I.second);
+ }
+}
+
bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
- ProfileSummaryInfo *_PSI, CallGraph *CG) {
+ ProfileSummaryInfo *_PSI,
+ LazyCallGraph &CG) {
GUIDToFuncNameMapper Mapper(M, *Reader, GUIDToFuncNameMap);
PSI = _PSI;
@@ -2240,8 +2489,10 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
assert(SymbolMap.count(StringRef()) == 0 &&
"No empty StringRef should be added in SymbolMap");
- if (ReportProfileStaleness || PersistProfileStaleness)
- MatchingManager->detectProfileMismatch();
+ if (ReportProfileStaleness || PersistProfileStaleness ||
+ SalvageStaleProfile) {
+ MatchingManager->runOnModule();
+ }
bool retval = false;
for (auto *F : buildFunctionOrder(M, CG)) {
@@ -2327,6 +2578,11 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM)
return emitAnnotations(F);
return false;
}
+SampleProfileLoaderPass::SampleProfileLoaderPass(
+ std::string File, std::string RemappingFile, ThinOrFullLTOPhase LTOPhase,
+ IntrusiveRefCntPtr<vfs::FileSystem> FS)
+ : ProfileFileName(File), ProfileRemappingFileName(RemappingFile),
+ LTOPhase(LTOPhase), FS(std::move(FS)) {}
PreservedAnalyses SampleProfileLoaderPass::run(Module &M,
ModuleAnalysisManager &AM) {
@@ -2343,18 +2599,21 @@ PreservedAnalyses SampleProfileLoaderPass::run(Module &M,
return FAM.getResult<TargetLibraryAnalysis>(F);
};
+ if (!FS)
+ FS = vfs::getRealFileSystem();
+
SampleProfileLoader SampleLoader(
ProfileFileName.empty() ? SampleProfileFile : ProfileFileName,
ProfileRemappingFileName.empty() ? SampleProfileRemappingFile
: ProfileRemappingFileName,
- LTOPhase, GetAssumptionCache, GetTTI, GetTLI);
+ LTOPhase, FS, GetAssumptionCache, GetTTI, GetTLI);
if (!SampleLoader.doInitialization(M, &FAM))
return PreservedAnalyses::all();
ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
- CallGraph &CG = AM.getResult<CallGraphAnalysis>(M);
- if (!SampleLoader.runOnModule(M, &AM, PSI, &CG))
+ LazyCallGraph &CG = AM.getResult<LazyCallGraphAnalysis>(M);
+ if (!SampleLoader.runOnModule(M, &AM, PSI, CG))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
diff --git a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp
index c4844dbe7f3c..0a42de7224b4 100644
--- a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp
@@ -13,6 +13,7 @@
#include "llvm/Transforms/IPO/SampleProfileProbe.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/EHUtils.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
@@ -32,7 +33,7 @@
#include <vector>
using namespace llvm;
-#define DEBUG_TYPE "sample-profile-probe"
+#define DEBUG_TYPE "pseudo-probe"
STATISTIC(ArtificialDbgLine,
"Number of probes that have an artificial debug line");
@@ -55,11 +56,7 @@ static uint64_t getCallStackHash(const DILocation *DIL) {
while (InlinedAt) {
Hash ^= MD5Hash(std::to_string(InlinedAt->getLine()));
Hash ^= MD5Hash(std::to_string(InlinedAt->getColumn()));
- const DISubprogram *SP = InlinedAt->getScope()->getSubprogram();
- // Use linkage name for C++ if possible.
- auto Name = SP->getLinkageName();
- if (Name.empty())
- Name = SP->getName();
+ auto Name = InlinedAt->getSubprogramLinkageName();
Hash ^= MD5Hash(Name);
InlinedAt = InlinedAt->getInlinedAt();
}
@@ -169,47 +166,6 @@ void PseudoProbeVerifier::verifyProbeFactors(
}
}
-PseudoProbeManager::PseudoProbeManager(const Module &M) {
- if (NamedMDNode *FuncInfo = M.getNamedMetadata(PseudoProbeDescMetadataName)) {
- for (const auto *Operand : FuncInfo->operands()) {
- const auto *MD = cast<MDNode>(Operand);
- auto GUID =
- mdconst::dyn_extract<ConstantInt>(MD->getOperand(0))->getZExtValue();
- auto Hash =
- mdconst::dyn_extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
- GUIDToProbeDescMap.try_emplace(GUID, PseudoProbeDescriptor(GUID, Hash));
- }
- }
-}
-
-const PseudoProbeDescriptor *
-PseudoProbeManager::getDesc(const Function &F) const {
- auto I = GUIDToProbeDescMap.find(
- Function::getGUID(FunctionSamples::getCanonicalFnName(F)));
- return I == GUIDToProbeDescMap.end() ? nullptr : &I->second;
-}
-
-bool PseudoProbeManager::moduleIsProbed(const Module &M) const {
- return M.getNamedMetadata(PseudoProbeDescMetadataName);
-}
-
-bool PseudoProbeManager::profileIsValid(const Function &F,
- const FunctionSamples &Samples) const {
- const auto *Desc = getDesc(F);
- if (!Desc) {
- LLVM_DEBUG(dbgs() << "Probe descriptor missing for Function " << F.getName()
- << "\n");
- return false;
- } else {
- if (Desc->getFunctionHash() != Samples.getFunctionHash()) {
- LLVM_DEBUG(dbgs() << "Hash mismatch for Function " << F.getName()
- << "\n");
- return false;
- }
- }
- return true;
-}
-
SampleProfileProber::SampleProfileProber(Function &Func,
const std::string &CurModuleUniqueId)
: F(&Func), CurModuleUniqueId(CurModuleUniqueId) {
@@ -253,8 +209,14 @@ void SampleProfileProber::computeCFGHash() {
}
void SampleProfileProber::computeProbeIdForBlocks() {
+ DenseSet<BasicBlock *> KnownColdBlocks;
+ computeEHOnlyBlocks(*F, KnownColdBlocks);
+  // Insert pseudo probes into non-cold blocks only. This reduces IR size as
+  // well as binary size while retaining the profile quality.
for (auto &BB : *F) {
- BlockProbeIds[&BB] = ++LastProbeId;
+ ++LastProbeId;
+ if (!KnownColdBlocks.contains(&BB))
+ BlockProbeIds[&BB] = LastProbeId;
}
}
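A minimal non-LLVM sketch of the ID assignment above (assignProbeIds and the integer block IDs are illustrative): the ID counter still advances for every block, but only blocks not known to be cold are recorded in BlockProbeIds.

#include <cstdint>
#include <map>
#include <set>
#include <vector>

// Every block consumes an ID; cold (e.g. EH-only) blocks simply get no probe.
std::map<int, uint32_t> assignProbeIds(const std::vector<int> &Blocks,
                                       const std::set<int> &ColdBlocks) {
  std::map<int, uint32_t> BlockProbeIds;
  uint32_t LastProbeId = 0;
  for (int BB : Blocks) {
    ++LastProbeId;
    if (!ColdBlocks.count(BB))
      BlockProbeIds[BB] = LastProbeId;
  }
  return BlockProbeIds;
}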
@@ -283,9 +245,16 @@ uint32_t SampleProfileProber::getCallsiteId(const Instruction *Call) const {
void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) {
Module *M = F.getParent();
MDBuilder MDB(F.getContext());
- // Compute a GUID without considering the function's linkage type. This is
- // fine since function name is the only key in the profile database.
- uint64_t Guid = Function::getGUID(F.getName());
+  // Since the GUIDs from the probe descriptor and the inline stack are
+  // computed separately, we need to make sure their names are consistent, so
+  // here we also use the name from debug info.
+ StringRef FName = F.getName();
+ if (auto *SP = F.getSubprogram()) {
+ FName = SP->getLinkageName();
+ if (FName.empty())
+ FName = SP->getName();
+ }
+ uint64_t Guid = Function::getGUID(FName);
// Assign an artificial debug line to a probe that doesn't come with a real
// line. A probe not having a debug line will get an incomplete inline
@@ -339,6 +308,14 @@ void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) {
Builder.getInt64(PseudoProbeFullDistributionFactor)};
auto *Probe = Builder.CreateCall(ProbeFn, Args);
AssignDebugLoc(Probe);
+  // Reset the DWARF discriminator if the debug location carries one. The
+  // discriminator field may be used by FS-AFDO later in the pipeline.
+ if (auto DIL = Probe->getDebugLoc()) {
+ if (DIL->getDiscriminator()) {
+ DIL = DIL->cloneWithDiscriminator(0);
+ Probe->setDebugLoc(DIL);
+ }
+ }
}
// Probe both direct calls and indirect calls. Direct calls are probed so that
@@ -351,12 +328,13 @@ void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) {
? (uint32_t)PseudoProbeType::DirectCall
: (uint32_t)PseudoProbeType::IndirectCall;
AssignDebugLoc(Call);
- // Levarge the 32-bit discriminator field of debug data to store the ID and
- // type of a callsite probe. This gets rid of the dependency on plumbing a
- // customized metadata through the codegen pipeline.
- uint32_t V = PseudoProbeDwarfDiscriminator::packProbeData(
- Index, Type, 0, PseudoProbeDwarfDiscriminator::FullDistributionFactor);
if (auto DIL = Call->getDebugLoc()) {
+      // Leverage the 32-bit discriminator field of debug data to store the ID
+ // and type of a callsite probe. This gets rid of the dependency on
+ // plumbing a customized metadata through the codegen pipeline.
+ uint32_t V = PseudoProbeDwarfDiscriminator::packProbeData(
+ Index, Type, 0,
+ PseudoProbeDwarfDiscriminator::FullDistributionFactor);
DIL = DIL->cloneWithDiscriminator(V);
Call->setDebugLoc(DIL);
}
@@ -368,28 +346,10 @@ void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) {
// - FunctionHash.
// - FunctionName
auto Hash = getFunctionHash();
- auto *MD = MDB.createPseudoProbeDesc(Guid, Hash, &F);
+ auto *MD = MDB.createPseudoProbeDesc(Guid, Hash, FName);
auto *NMD = M->getNamedMetadata(PseudoProbeDescMetadataName);
assert(NMD && "llvm.pseudo_probe_desc should be pre-created");
NMD->addOperand(MD);
-
- // Preserve a comdat group to hold all probes materialized later. This
- // allows that when the function is considered dead and removed, the
- // materialized probes are disposed too.
- // Imported functions are defined in another module. They do not need
- // the following handling since same care will be taken for them in their
- // original module. The pseudo probes inserted into an imported functions
- // above will naturally not be emitted since the imported function is free
- // from object emission. However they will be emitted together with the
- // inliner functions that the imported function is inlined into. We are not
- // creating a comdat group for an import function since it's useless anyway.
- if (!F.isDeclarationForLinker()) {
- if (TM) {
- auto Triple = TM->getTargetTriple();
- if (Triple.supportsCOMDAT() && TM->getFunctionSections())
- getOrCreateFunctionComdat(F, Triple);
- }
- }
}
PreservedAnalyses SampleProfileProbePass::run(Module &M,
diff --git a/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp b/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp
index 0f2412dce1c9..53d5b18dcead 100644
--- a/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp
+++ b/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp
@@ -16,8 +16,6 @@
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Transforms/IPO.h"
using namespace llvm;
@@ -56,30 +54,3 @@ PreservedAnalyses StripDeadPrototypesPass::run(Module &M,
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
-
-namespace {
-
-class StripDeadPrototypesLegacyPass : public ModulePass {
-public:
- static char ID; // Pass identification, replacement for typeid
- StripDeadPrototypesLegacyPass() : ModulePass(ID) {
- initializeStripDeadPrototypesLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
- bool runOnModule(Module &M) override {
- if (skipModule(M))
- return false;
-
- return stripDeadPrototypes(M);
- }
-};
-
-} // end anonymous namespace
-
-char StripDeadPrototypesLegacyPass::ID = 0;
-INITIALIZE_PASS(StripDeadPrototypesLegacyPass, "strip-dead-prototypes",
- "Strip Unused Function Prototypes", false, false)
-
-ModulePass *llvm::createStripDeadPrototypesPass() {
- return new StripDeadPrototypesLegacyPass();
-}
diff --git a/llvm/lib/Transforms/IPO/StripSymbols.cpp b/llvm/lib/Transforms/IPO/StripSymbols.cpp
index 34f8c4316cca..147513452789 100644
--- a/llvm/lib/Transforms/IPO/StripSymbols.cpp
+++ b/llvm/lib/Transforms/IPO/StripSymbols.cpp
@@ -30,110 +30,12 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/TypeFinder.h"
#include "llvm/IR/ValueSymbolTable.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/StripSymbols.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
-namespace {
- class StripSymbols : public ModulePass {
- bool OnlyDebugInfo;
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit StripSymbols(bool ODI = false)
- : ModulePass(ID), OnlyDebugInfo(ODI) {
- initializeStripSymbolsPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
- };
-
- class StripNonDebugSymbols : public ModulePass {
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit StripNonDebugSymbols()
- : ModulePass(ID) {
- initializeStripNonDebugSymbolsPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
- };
-
- class StripDebugDeclare : public ModulePass {
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit StripDebugDeclare()
- : ModulePass(ID) {
- initializeStripDebugDeclarePass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
- };
-
- class StripDeadDebugInfo : public ModulePass {
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit StripDeadDebugInfo()
- : ModulePass(ID) {
- initializeStripDeadDebugInfoPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
- };
-}
-
-char StripSymbols::ID = 0;
-INITIALIZE_PASS(StripSymbols, "strip",
- "Strip all symbols from a module", false, false)
-
-ModulePass *llvm::createStripSymbolsPass(bool OnlyDebugInfo) {
- return new StripSymbols(OnlyDebugInfo);
-}
-
-char StripNonDebugSymbols::ID = 0;
-INITIALIZE_PASS(StripNonDebugSymbols, "strip-nondebug",
- "Strip all symbols, except dbg symbols, from a module",
- false, false)
-
-ModulePass *llvm::createStripNonDebugSymbolsPass() {
- return new StripNonDebugSymbols();
-}
-
-char StripDebugDeclare::ID = 0;
-INITIALIZE_PASS(StripDebugDeclare, "strip-debug-declare",
- "Strip all llvm.dbg.declare intrinsics", false, false)
-
-ModulePass *llvm::createStripDebugDeclarePass() {
- return new StripDebugDeclare();
-}
-
-char StripDeadDebugInfo::ID = 0;
-INITIALIZE_PASS(StripDeadDebugInfo, "strip-dead-debug-info",
- "Strip debug info for unused symbols", false, false)
-
-ModulePass *llvm::createStripDeadDebugInfoPass() {
- return new StripDeadDebugInfo();
-}
-
/// OnlyUsedBy - Return true if V is only used by Usr.
static bool OnlyUsedBy(Value *V, Value *Usr) {
for (User *U : V->users())
@@ -234,24 +136,6 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) {
return true;
}
-bool StripSymbols::runOnModule(Module &M) {
- if (skipModule(M))
- return false;
-
- bool Changed = false;
- Changed |= StripDebugInfo(M);
- if (!OnlyDebugInfo)
- Changed |= StripSymbolNames(M, false);
- return Changed;
-}
-
-bool StripNonDebugSymbols::runOnModule(Module &M) {
- if (skipModule(M))
- return false;
-
- return StripSymbolNames(M, true);
-}
-
static bool stripDebugDeclareImpl(Module &M) {
Function *Declare = M.getFunction("llvm.dbg.declare");
@@ -290,50 +174,6 @@ static bool stripDebugDeclareImpl(Module &M) {
return true;
}
-bool StripDebugDeclare::runOnModule(Module &M) {
- if (skipModule(M))
- return false;
- return stripDebugDeclareImpl(M);
-}
-
-/// Collects compilation units referenced by functions or lexical scopes.
-/// Accepts any DIScope and uses recursive bottom-up approach to reach either
-/// DISubprogram or DILexicalBlockBase.
-static void
-collectCUsWithScope(const DIScope *Scope, std::set<DICompileUnit *> &LiveCUs,
- SmallPtrSet<const DIScope *, 8> &VisitedScopes) {
- if (!Scope)
- return;
-
- auto InS = VisitedScopes.insert(Scope);
- if (!InS.second)
- return;
-
- if (const auto *SP = dyn_cast<DISubprogram>(Scope)) {
- if (SP->getUnit())
- LiveCUs.insert(SP->getUnit());
- return;
- }
- if (const auto *LB = dyn_cast<DILexicalBlockBase>(Scope)) {
- const DISubprogram *SP = LB->getSubprogram();
- if (SP && SP->getUnit())
- LiveCUs.insert(SP->getUnit());
- return;
- }
-
- collectCUsWithScope(Scope->getScope(), LiveCUs, VisitedScopes);
-}
-
-static void
-collectCUsForInlinedFuncs(const DILocation *Loc,
- std::set<DICompileUnit *> &LiveCUs,
- SmallPtrSet<const DIScope *, 8> &VisitedScopes) {
- if (!Loc || !Loc->getInlinedAt())
- return;
- collectCUsWithScope(Loc->getScope(), LiveCUs, VisitedScopes);
- collectCUsForInlinedFuncs(Loc->getInlinedAt(), LiveCUs, VisitedScopes);
-}
-
static bool stripDeadDebugInfoImpl(Module &M) {
bool Changed = false;
@@ -361,19 +201,15 @@ static bool stripDeadDebugInfoImpl(Module &M) {
}
std::set<DICompileUnit *> LiveCUs;
- SmallPtrSet<const DIScope *, 8> VisitedScopes;
- // Any CU is live if is referenced from a subprogram metadata that is attached
- // to a function defined or inlined in the module.
- for (const Function &Fn : M.functions()) {
- collectCUsWithScope(Fn.getSubprogram(), LiveCUs, VisitedScopes);
- for (const_inst_iterator I = inst_begin(&Fn), E = inst_end(&Fn); I != E;
- ++I) {
- if (!I->getDebugLoc())
- continue;
- const DILocation *DILoc = I->getDebugLoc().get();
- collectCUsForInlinedFuncs(DILoc, LiveCUs, VisitedScopes);
- }
+ DebugInfoFinder LiveCUFinder;
+ for (const Function &F : M.functions()) {
+ if (auto *SP = cast_or_null<DISubprogram>(F.getSubprogram()))
+ LiveCUFinder.processSubprogram(SP);
+ for (const Instruction &I : instructions(F))
+ LiveCUFinder.processInstruction(M, I);
}
+ auto FoundCUs = LiveCUFinder.compile_units();
+ LiveCUs.insert(FoundCUs.begin(), FoundCUs.end());
bool HasDeadCUs = false;
for (DICompileUnit *DIC : F.compile_units()) {
@@ -424,39 +260,34 @@ static bool stripDeadDebugInfoImpl(Module &M) {
return Changed;
}
-/// Remove any debug info for global variables/functions in the given module for
-/// which said global variable/function no longer exists (i.e. is null).
-///
-/// Debugging information is encoded in llvm IR using metadata. This is designed
-/// such a way that debug info for symbols preserved even if symbols are
-/// optimized away by the optimizer. This special pass removes debug info for
-/// such symbols.
-bool StripDeadDebugInfo::runOnModule(Module &M) {
- if (skipModule(M))
- return false;
- return stripDeadDebugInfoImpl(M);
-}
-
PreservedAnalyses StripSymbolsPass::run(Module &M, ModuleAnalysisManager &AM) {
StripDebugInfo(M);
StripSymbolNames(M, false);
- return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
PreservedAnalyses StripNonDebugSymbolsPass::run(Module &M,
ModuleAnalysisManager &AM) {
StripSymbolNames(M, true);
- return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
PreservedAnalyses StripDebugDeclarePass::run(Module &M,
ModuleAnalysisManager &AM) {
stripDebugDeclareImpl(M);
- return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
PreservedAnalyses StripDeadDebugInfoPass::run(Module &M,
ModuleAnalysisManager &AM) {
stripDeadDebugInfoImpl(M);
- return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
diff --git a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
index 670097010085..fc1e70b1b3d3 100644
--- a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
+++ b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
@@ -18,9 +18,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Object/ModuleSymbolTable.h"
-#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
@@ -148,6 +146,14 @@ void promoteTypeIds(Module &M, StringRef ModuleId) {
}
}
+ if (Function *TypeCheckedLoadRelativeFunc = M.getFunction(
+ Intrinsic::getName(Intrinsic::type_checked_load_relative))) {
+ for (const Use &U : TypeCheckedLoadRelativeFunc->uses()) {
+ auto CI = cast<CallInst>(U.getUser());
+ ExternalizeTypeId(CI, 2);
+ }
+ }
+
for (GlobalObject &GO : M.global_objects()) {
SmallVector<MDNode *, 1> MDs;
GO.getMetadata(LLVMContext::MD_type, MDs);
@@ -196,6 +202,13 @@ void simplifyExternals(Module &M) {
F.eraseFromParent();
}
+ for (GlobalIFunc &I : llvm::make_early_inc_range(M.ifuncs())) {
+ if (I.use_empty())
+ I.eraseFromParent();
+ else
+ assert(I.getResolverFunction() && "ifunc misses its resolver function");
+ }
+
for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) {
if (GV.isDeclaration() && GV.use_empty()) {
GV.eraseFromParent();
@@ -246,6 +259,16 @@ static void cloneUsedGlobalVariables(const Module &SrcM, Module &DestM,
appendToUsed(DestM, NewUsed);
}
+#ifndef NDEBUG
+static bool enableUnifiedLTO(Module &M) {
+ bool UnifiedLTO = false;
+ if (auto *MD =
+ mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("UnifiedLTO")))
+ UnifiedLTO = MD->getZExtValue();
+ return UnifiedLTO;
+}
+#endif
+
// If it's possible to split M into regular and thin LTO parts, do so and write
// a multi-module bitcode file with the two parts to OS. Otherwise, write only a
// regular LTO bitcode file to OS.
@@ -254,18 +277,20 @@ void splitAndWriteThinLTOBitcode(
function_ref<AAResults &(Function &)> AARGetter, Module &M) {
std::string ModuleId = getUniqueModuleId(&M);
if (ModuleId.empty()) {
+ assert(!enableUnifiedLTO(M));
// We couldn't generate a module ID for this module, write it out as a
// regular LTO module with an index for summary-based dead stripping.
ProfileSummaryInfo PSI(M);
M.addModuleFlag(Module::Error, "ThinLTO", uint32_t(0));
ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, &PSI);
- WriteBitcodeToFile(M, OS, /*ShouldPreserveUseListOrder=*/false, &Index);
+ WriteBitcodeToFile(M, OS, /*ShouldPreserveUseListOrder=*/false, &Index,
+ /*UnifiedLTO=*/false);
if (ThinLinkOS)
// We don't have a ThinLTO part, but still write the module to the
// ThinLinkOS if requested so that the expected output file is produced.
WriteBitcodeToFile(M, *ThinLinkOS, /*ShouldPreserveUseListOrder=*/false,
- &Index);
+ &Index, /*UnifiedLTO=*/false);
return;
}
@@ -503,15 +528,17 @@ bool hasTypeMetadata(Module &M) {
return false;
}
-void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS,
+bool writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS,
function_ref<AAResults &(Function &)> AARGetter,
Module &M, const ModuleSummaryIndex *Index) {
std::unique_ptr<ModuleSummaryIndex> NewIndex = nullptr;
// See if this module has any type metadata. If so, we try to split it
// or at least promote type ids to enable WPD.
if (hasTypeMetadata(M)) {
- if (enableSplitLTOUnit(M))
- return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M);
+ if (enableSplitLTOUnit(M)) {
+ splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M);
+ return true;
+ }
// Promote type ids as needed for index-based WPD.
std::string ModuleId = getUniqueModuleId(&M);
if (!ModuleId.empty()) {
@@ -544,6 +571,7 @@ void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS,
// given OS.
if (ThinLinkOS && Index)
writeThinLinkBitcodeToFile(M, *ThinLinkOS, *Index, ModHash);
+ return false;
}
} // anonymous namespace
@@ -552,10 +580,11 @@ PreservedAnalyses
llvm::ThinLTOBitcodeWriterPass::run(Module &M, ModuleAnalysisManager &AM) {
FunctionAnalysisManager &FAM =
AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
- writeThinLTOBitcode(OS, ThinLinkOS,
- [&FAM](Function &F) -> AAResults & {
- return FAM.getResult<AAManager>(F);
- },
- M, &AM.getResult<ModuleSummaryIndexAnalysis>(M));
- return PreservedAnalyses::all();
+ bool Changed = writeThinLTOBitcode(
+ OS, ThinLinkOS,
+ [&FAM](Function &F) -> AAResults & {
+ return FAM.getResult<AAManager>(F);
+ },
+ M, &AM.getResult<ModuleSummaryIndexAnalysis>(M));
+ return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 487a0a4a97f7..d33258642365 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -58,7 +58,6 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -84,9 +83,6 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndexYAML.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/PassRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Errc.h"
@@ -94,6 +90,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GlobPattern.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -259,7 +256,7 @@ wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
if (I < B.size())
BitsUsed |= B[I];
if (BitsUsed != 0xff)
- return (MinByte + I) * 8 + countTrailingZeros(uint8_t(~BitsUsed));
+ return (MinByte + I) * 8 + llvm::countr_zero(uint8_t(~BitsUsed));
}
} else {
// Find a free (Size/8) byte region in each member of Used.
@@ -313,9 +310,10 @@ void wholeprogramdevirt::setAfterReturnValues(
}
}
-VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
+VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM)
: Fn(Fn), TM(TM),
- IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
+ IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()),
+ WasDevirt(false) {}
namespace {
@@ -379,6 +377,7 @@ namespace {
// conditions
// 1) All summaries are live.
// 2) All function summaries indicate it's unreachable
+// 3) There is no non-function with the same GUID (which is rare)
bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
// Returns false if ValueInfo is absent, or the summary list is empty
@@ -391,12 +390,13 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
// In general either all summaries should be live or all should be dead.
if (!Summary->isLive())
return false;
- if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) {
+ if (auto *FS = dyn_cast<FunctionSummary>(Summary->getBaseObject())) {
if (!FS->fflags().MustBeUnreachable)
return false;
}
- // Do nothing if a non-function has the same GUID (which is rare).
- // This is correct since non-function summaries are not relevant.
+ // Be conservative if a non-function has the same GUID (which is rare).
+ else
+ return false;
}
// All function summaries are live and all of them agree that the function is
   // unreachable.
@@ -567,6 +567,10 @@ struct DevirtModule {
// optimize a call more than once.
SmallPtrSet<CallBase *, 8> OptimizedCalls;
+ // Store calls that had their ptrauth bundle removed. They are to be deleted
+ // at the end of the optimization.
+ SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved;
+
// This map keeps track of the number of "unsafe" uses of a loaded function
// pointer. The key is the associated llvm.type.test intrinsic call generated
// by this pass. An unsafe use is one that calls the loaded function pointer
@@ -761,7 +765,7 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
return FAM.getResult<DominatorTreeAnalysis>(F);
};
if (UseCommandLine) {
- if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
+ if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
@@ -892,8 +896,7 @@ static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
// DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
const auto &ModPaths = Summary->modulePaths();
if (ClSummaryAction != PassSummaryAction::Import &&
- ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) ==
- ModPaths.end())
+ !ModPaths.contains(ModuleSummaryIndex::getRegularLTOModuleName()))
return createStringError(
errc::invalid_argument,
"combined summary should contain Regular LTO module");
@@ -958,7 +961,7 @@ void DevirtModule::buildTypeIdentifierMap(
std::vector<VTableBits> &Bits,
DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
DenseMap<GlobalVariable *, VTableBits *> GVToBits;
- Bits.reserve(M.getGlobalList().size());
+ Bits.reserve(M.global_size());
SmallVector<MDNode *, 2> Types;
for (GlobalVariable &GV : M.globals()) {
Types.clear();
@@ -1003,11 +1006,17 @@ bool DevirtModule::tryFindVirtualCallTargets(
return false;
Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
- TM.Offset + ByteOffset, M);
+ TM.Offset + ByteOffset, M, TM.Bits->GV);
if (!Ptr)
return false;
- auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
+ auto C = Ptr->stripPointerCasts();
+ // Make sure this is a function or alias to a function.
+ auto Fn = dyn_cast<Function>(C);
+ auto A = dyn_cast<GlobalAlias>(C);
+ if (!Fn && A)
+ Fn = dyn_cast<Function>(A->getAliasee());
+
if (!Fn)
return false;
@@ -1024,7 +1033,11 @@ bool DevirtModule::tryFindVirtualCallTargets(
if (mustBeUnreachableFunction(Fn, ExportSummary))
continue;
- TargetsForSlot.push_back({Fn, &TM});
+ // Save the symbol used in the vtable to use as the devirtualization
+ // target.
+ auto GV = dyn_cast<GlobalValue>(C);
+ assert(GV);
+ TargetsForSlot.push_back({GV, &TM});
}
// Give up if we couldn't find any targets.
@@ -1156,6 +1169,14 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
// !callees metadata.
CB.setMetadata(LLVMContext::MD_prof, nullptr);
CB.setMetadata(LLVMContext::MD_callees, nullptr);
+ if (CB.getCalledOperand() &&
+ CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
+ auto *NewCS =
+ CallBase::removeOperandBundle(&CB, LLVMContext::OB_ptrauth, &CB);
+ CB.replaceAllUsesWith(NewCS);
+ // Schedule for deletion at the end of pass run.
+ CallsWithPtrAuthBundleRemoved.push_back(&CB);
+ }
}
// This use is no longer unsafe.
@@ -1205,7 +1226,7 @@ bool DevirtModule::trySingleImplDevirt(
WholeProgramDevirtResolution *Res) {
// See if the program contains a single implementation of this virtual
// function.
- Function *TheFn = TargetsForSlot[0].Fn;
+ auto *TheFn = TargetsForSlot[0].Fn;
for (auto &&Target : TargetsForSlot)
if (TheFn != Target.Fn)
return false;
@@ -1379,9 +1400,20 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
IsExported = true;
if (CSInfo.AllCallSitesDevirted)
return;
+
+ std::map<CallBase *, CallBase *> CallBases;
for (auto &&VCallSite : CSInfo.CallSites) {
CallBase &CB = VCallSite.CB;
+ if (CallBases.find(&CB) != CallBases.end()) {
+ // When finding devirtualizable calls, it's possible to find the same
+ // vtable passed to multiple llvm.type.test or llvm.type.checked.load
+ // calls, which can cause duplicate call sites to be recorded in
+ // [Const]CallSites. If we've already found one of these
+ // call instances, just ignore it. It will be replaced later.
+ continue;
+ }
+
// Jump tables are only profitable if the retpoline mitigation is enabled.
Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
if (!FSAttr.isValid() ||
@@ -1428,8 +1460,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
Attrs.getRetAttrs(), NewArgAttrs));
- CB.replaceAllUsesWith(NewCS);
- CB.eraseFromParent();
+ CallBases[&CB] = NewCS;
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
@@ -1439,6 +1470,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
// retpoline mitigation, which would mean that they are lowered to
// llvm.type.test and therefore require an llvm.type.test resolution for the
// type identifier.
+
+ std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) {
+ CBs.first->replaceAllUsesWith(CBs.second);
+ CBs.first->eraseFromParent();
+ });
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
@@ -1451,23 +1487,30 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs(
// Evaluate each function and store the result in each target's RetVal
// field.
for (VirtualCallTarget &Target : TargetsForSlot) {
- if (Target.Fn->arg_size() != Args.size() + 1)
+    // TODO: Skip for now if the vtable symbol was an alias to a function; we
+    // need to evaluate whether it would be correct to analyze the aliasee
+    // function for this optimization.
+ auto Fn = dyn_cast<Function>(Target.Fn);
+ if (!Fn)
+ return false;
+
+ if (Fn->arg_size() != Args.size() + 1)
return false;
Evaluator Eval(M.getDataLayout(), nullptr);
SmallVector<Constant *, 2> EvalArgs;
EvalArgs.push_back(
- Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
+ Constant::getNullValue(Fn->getFunctionType()->getParamType(0)));
for (unsigned I = 0; I != Args.size(); ++I) {
- auto *ArgTy = dyn_cast<IntegerType>(
- Target.Fn->getFunctionType()->getParamType(I + 1));
+ auto *ArgTy =
+ dyn_cast<IntegerType>(Fn->getFunctionType()->getParamType(I + 1));
if (!ArgTy)
return false;
EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
}
Constant *RetVal;
- if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
+ if (!Eval.EvaluateFunction(Fn, RetVal, EvalArgs) ||
!isa<ConstantInt>(RetVal))
return false;
Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
@@ -1675,8 +1718,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
OREGetter, IsBitSet);
} else {
- Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
- Value *Val = B.CreateLoad(RetType, ValAddr);
+ Value *Val = B.CreateLoad(RetType, Addr);
NumVirtConstProp++;
Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
OREGetter, Val);
@@ -1688,8 +1730,14 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
bool DevirtModule::tryVirtualConstProp(
MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
WholeProgramDevirtResolution *Res, VTableSlot Slot) {
+ // TODO: Skip for now if the vtable symbol was an alias to a function;
+ // we need to evaluate whether it would be correct to analyze the aliasee
+ // function for this optimization.
+ auto Fn = dyn_cast<Function>(TargetsForSlot[0].Fn);
+ if (!Fn)
+ return false;
// This only works if the function returns an integer.
- auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
+ auto RetType = dyn_cast<IntegerType>(Fn->getReturnType());
if (!RetType)
return false;
unsigned BitWidth = RetType->getBitWidth();
@@ -1707,11 +1755,18 @@ bool DevirtModule::tryVirtualConstProp(
// inline all implementations of the virtual function into each call site,
// rather than using function attributes to perform local optimization.
for (VirtualCallTarget &Target : TargetsForSlot) {
- if (Target.Fn->isDeclaration() ||
- !computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn))
+ // TODO: Skip for now if the vtable symbol was an alias to a function;
+ // we need to evaluate whether it would be correct to analyze the aliasee
+ // function for this optimization.
+ auto Fn = dyn_cast<Function>(Target.Fn);
+ if (!Fn)
+ return false;
+
+ if (Fn->isDeclaration() ||
+ !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn))
.doesNotAccessMemory() ||
- Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
- Target.Fn->getReturnType() != RetType)
+ Fn->arg_empty() || !Fn->arg_begin()->use_empty() ||
+ Fn->getReturnType() != RetType)
return false;
}
@@ -1947,9 +2002,23 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
// This helps avoid unnecessary spills.
IRBuilder<> LoadB(
(LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
- Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
- Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
- Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+
+ Value *LoadedValue = nullptr;
+ if (TypeCheckedLoadFunc->getIntrinsicID() ==
+ Intrinsic::type_checked_load_relative) {
+ Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
+ Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty));
+ LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr);
+ LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy);
+ GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy);
+ LoadedValue = LoadB.CreateAdd(GEP, LoadedValue);
+ LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy);
+ } else {
+ Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
+ Value *GEPPtr =
+ LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
+ LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+ }
for (Instruction *LoadedPtr : LoadedPtrs) {
LoadedPtr->replaceAllUsesWith(LoadedValue);
@@ -2130,6 +2199,8 @@ bool DevirtModule::run() {
M.getFunction(Intrinsic::getName(Intrinsic::type_test));
Function *TypeCheckedLoadFunc =
M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
+ Function *TypeCheckedLoadRelativeFunc =
+ M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative));
Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
// Normally if there are no users of the devirtualization intrinsics in the
@@ -2138,7 +2209,9 @@ bool DevirtModule::run() {
if (!ExportSummary &&
(!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
AssumeFunc->use_empty()) &&
- (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
+ (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) &&
+ (!TypeCheckedLoadRelativeFunc ||
+ TypeCheckedLoadRelativeFunc->use_empty()))
return false;
// Rebuild type metadata into a map for easy lookup.
@@ -2152,6 +2225,9 @@ bool DevirtModule::run() {
if (TypeCheckedLoadFunc)
scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
+ if (TypeCheckedLoadRelativeFunc)
+ scanTypeCheckedLoadUsers(TypeCheckedLoadRelativeFunc);
+
if (ImportSummary) {
for (auto &S : CallSlots)
importResolution(S.first, S.second);
@@ -2219,7 +2295,7 @@ bool DevirtModule::run() {
// For each (type, offset) pair:
bool DidVirtualConstProp = false;
- std::map<std::string, Function*> DevirtTargets;
+ std::map<std::string, GlobalValue *> DevirtTargets;
for (auto &S : CallSlots) {
// Search each of the members of the type identifier for the virtual
// function implementation at offset S.first.ByteOffset, and add to
@@ -2274,7 +2350,14 @@ bool DevirtModule::run() {
if (RemarksEnabled) {
// Generate remarks for each devirtualized function.
for (const auto &DT : DevirtTargets) {
- Function *F = DT.second;
+ GlobalValue *GV = DT.second;
+ auto F = dyn_cast<Function>(GV);
+ if (!F) {
+ auto A = dyn_cast<GlobalAlias>(GV);
+ assert(A && isa<Function>(A->getAliasee()));
+ F = dyn_cast<Function>(A->getAliasee());
+ assert(F);
+ }
using namespace ore;
OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
@@ -2299,6 +2382,9 @@ bool DevirtModule::run() {
for (GlobalVariable &GV : M.globals())
GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
+ for (auto *CI : CallsWithPtrAuthBundleRemoved)
+ CI->eraseFromParent();
+
return true;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index b68efc993723..91ca44e0f11e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -797,7 +797,7 @@ static Value *checkForNegativeOperand(BinaryOperator &I,
// LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2))
// ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
- if (C1->countTrailingZeros() == 0)
+ if (C1->countr_zero() == 0)
if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
Value *NewOr = Builder.CreateOr(Z, ~(*C2));
return Builder.CreateSub(RHS, NewOr, "sub");
@@ -880,8 +880,15 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
// ~X + C --> (C-1) - X
- if (match(Op0, m_Not(m_Value(X))))
- return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X);
+ if (match(Op0, m_Not(m_Value(X)))) {
+ // ~X + C has NSW and (C-1) won't overflow => (C-1)-X can have NSW
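+ // e.g. C == 5: ~X + 5 == (-X - 1) + 5 == 4 - X == (5 - 1) - X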
+ auto *COne = ConstantInt::get(Op1C->getType(), 1);
+ bool WillNotSOV = willNotOverflowSignedSub(Op1C, COne, Add);
+ BinaryOperator *Res =
+ BinaryOperator::CreateSub(ConstantExpr::getSub(Op1C, COne), X);
+ Res->setHasNoSignedWrap(Add.hasNoSignedWrap() && WillNotSOV);
+ return Res;
+ }
// (iN X s>> (N - 1)) + 1 --> zext (X > -1)
const APInt *C;
@@ -975,6 +982,16 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
}
}
+ // Fold (add (zext (add X, -1)), 1) -> (zext X) if X is non-zero.
+ // TODO: There's a general form for any constant on the outer add.
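+ // e.g. i8 X == 5: zext(5 + -1) + 1 == 4 + 1 == 5 == zext(5); X != 0 ensures
+ // the inner add does not wrap from 0 to 255.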
+ if (C->isOne()) {
+ if (match(Op0, m_ZExt(m_Add(m_Value(X), m_AllOnes())))) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&Add);
+ if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return new ZExtInst(X, Ty);
+ }
+ }
+
return nullptr;
}
@@ -1366,6 +1383,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *X = foldNoWrapAdd(I, Builder))
return X;
+ if (Instruction *R = foldBinOpShiftWithShift(I))
+ return R;
+
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1))
@@ -1421,6 +1441,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
Value *Sub = Builder.CreateSub(A, B);
return BinaryOperator::CreateAdd(Sub, ConstantExpr::getAdd(C1, C2));
}
+
+ // Canonicalize a constant sub operand as an add operand for better folding:
+ // (C1 - A) + B --> (B - A) + C1
+ if (match(&I, m_c_Add(m_OneUse(m_Sub(m_ImmConstant(C1), m_Value(A))),
+ m_Value(B)))) {
+ Value *Sub = Builder.CreateSub(B, A, "reass.sub");
+ return BinaryOperator::CreateAdd(Sub, C1);
+ }
}
// X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
@@ -1439,7 +1467,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) &&
- C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countLeadingZeros())) {
+ C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) {
Constant *NewMask = ConstantInt::get(RHS->getType(), *C1 - 1);
return BinaryOperator::CreateAnd(A, NewMask);
}
@@ -1451,6 +1479,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))))
return new ZExtInst(B, LHS->getType());
+ // zext(A) + sext(A) --> 0 if A is i1
+ if (match(&I, m_c_BinOp(m_ZExt(m_Value(A)), m_SExt(m_Deferred(A)))) &&
+ A->getType()->isIntOrIntVectorTy(1))
+ return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+
// A+B --> A|B iff A and B have no bits set in common.
if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
return BinaryOperator::CreateOr(LHS, RHS);
@@ -1515,7 +1548,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
const APInt *NegPow2C;
if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))),
m_Value(B)))) {
- Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros());
+ Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countr_zero());
Value *Shl = Builder.CreateShl(A, ShiftAmtC);
return BinaryOperator::CreateSub(B, Shl);
}
@@ -1536,6 +1569,13 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *Ashr = foldAddToAshr(I))
return Ashr;
+ // min(A, B) + max(A, B) => A + B.
+ if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
+ m_c_SMin(m_Deferred(A), m_Deferred(B))),
+ m_c_Add(m_UMax(m_Value(A), m_Value(B)),
+ m_c_UMin(m_Deferred(A), m_Deferred(B))))))
+ return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
+
// TODO(jingyue): Consider willNotOverflowSignedAdd and
// willNotOverflowUnsignedAdd to reduce the number of invocations of
// computeKnownBits.
@@ -1575,6 +1615,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
{Builder.CreateOr(A, B)}));
+ if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
+ return Res;
+
+ if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+ return Res;
+
return Changed ? &I : nullptr;
}
@@ -1786,6 +1832,20 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return replaceInstUsesWith(I, V);
}
+ // minimum(X, Y) + maximum(X, Y) => X + Y.
+ if (match(&I,
+ m_c_FAdd(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+ m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+ m_Deferred(Y))))) {
+ BinaryOperator *Result = BinaryOperator::CreateFAddFMF(X, Y, &I);
+ // We cannot preserve ninf if the nnan flag is not set.
+ // If X is NaN and Y is Inf, the original program computed NaN + NaN,
+ // while the optimized version computes NaN + Inf, which is poison under ninf.
+ if (!Result->hasNoNaNs())
+ Result->setHasNoInfs(false);
+ return Result;
+ }
+
return nullptr;
}
@@ -1956,8 +2016,17 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
Constant *C2;
// C-(X+C2) --> (C-C2)-X
- if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2))))
- return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
+ if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) {
+ // If C-C2 does not overflow, and both C-(X+C2) and (X+C2) have NSW,
+ // then (C-C2)-X can have NSW.
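+ // e.g. C == 10, C2 == 3: 10 - (X + 3) == (10 - 3) - X == 7 - X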
+ bool WillNotSOV = willNotOverflowSignedSub(C, C2, I);
+ BinaryOperator *Res =
+ BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
+ auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
+ Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() &&
+ WillNotSOV);
+ return Res;
+ }
}
auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * {
@@ -2325,7 +2394,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
const APInt *AddC, *AndC;
if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) &&
match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) {
- unsigned Cttz = AddC->countTrailingZeros();
+ unsigned Cttz = AddC->countr_zero();
APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz));
if ((HighMask & *AndC).isZero())
return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC)));
@@ -2388,6 +2457,21 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
return replaceInstUsesWith(I, Mul);
}
+ // max(X,Y) nsw/nuw - min(X,Y) --> abs(X nsw - Y)
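+ // e.g. X = 3, Y = 7: smax(3, 7) - smin(3, 7) == 7 - 3 == 4 == abs(3 - 7)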
+ if (match(Op0, m_OneUse(m_c_SMax(m_Value(X), m_Value(Y)))) &&
+ match(Op1, m_OneUse(m_c_SMin(m_Specific(X), m_Specific(Y))))) {
+ if (I.hasNoUnsignedWrap() || I.hasNoSignedWrap()) {
+ Value *Sub =
+ Builder.CreateSub(X, Y, "sub", /*HasNUW=*/false, /*HasNSW=*/true);
+ Value *Call =
+ Builder.CreateBinaryIntrinsic(Intrinsic::abs, Sub, Builder.getTrue());
+ return replaceInstUsesWith(I, Call);
+ }
+ }
+
+ if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+ return Res;
+
return TryToNarrowDeduceFlags();
}
@@ -2567,7 +2651,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
// Note that if this fsub was really an fneg, the fadd with -0.0 will get
// killed later. We still limit that particular transform with 'hasOneUse'
// because an fneg is assumed better/cheaper than a generic fsub.
- if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) {
+ if (I.hasNoSignedZeros() || cannotBeNegativeZero(Op0, SQ.DL, SQ.TLI)) {
if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) {
Value *NewSub = Builder.CreateFSubFMF(Y, X, &I);
return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 97a001b2ed32..8a1fb6b7f17e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -625,7 +625,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
return RHS;
}
- if (Mask & BMask_Mixed) {
+ if (Mask & (BMask_Mixed | BMask_NotMixed)) {
+ // Mixed:
// (icmp eq (A & B), C) & (icmp eq (A & D), E)
// We already know that B & C == C && D & E == E.
// If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of
@@ -636,24 +637,50 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
// We can't simply use C and E because we might actually handle
// (icmp ne (A & B), B) & (icmp eq (A & D), D)
// with B and D, having a single bit set.
+
+ // NotMixed:
+ // (icmp ne (A & B), C) & (icmp ne (A & D), E)
+ // -> (icmp ne (A & (B & D)), (C & E))
+ // Check the intersection (B & D) for inequality.
+ // Assume that (B & D) == B || (B & D) == D, i.e. B/D is a subset of D/B,
+ // and (B & D) & (C ^ E) == 0, i.e. the bits of C and E that are covered by
+ // both B and D do not contradict.
+ // Note that we can assume (~B & C) == 0 && (~D & E) == 0; a previous
+ // transform should have deleted these icmps if that did not hold.
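+ // e.g. with i4 constants B = 12, C = 4, D = 4, E = 4:
+ // ((A & 12) != 4) & ((A & 4) != 4) --> (A & 4) != 4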
+
const APInt *OldConstC, *OldConstE;
if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE)))
return nullptr;
- const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC;
- const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE;
+ auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * {
+ CC = IsNot ? CmpInst::getInversePredicate(CC) : CC;
+ const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC;
+ const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE;
- // If there is a conflict, we should actually return a false for the
- // whole construct.
- if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
- return ConstantInt::get(LHS->getType(), !IsAnd);
+ if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
+ return IsNot ? nullptr : ConstantInt::get(LHS->getType(), !IsAnd);
- Value *NewOr1 = Builder.CreateOr(B, D);
- Value *NewAnd = Builder.CreateAnd(A, NewOr1);
- Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE);
- return Builder.CreateICmp(NewCC, NewAnd, NewOr2);
- }
+ if (IsNot && !ConstB->isSubsetOf(*ConstD) && !ConstD->isSubsetOf(*ConstB))
+ return nullptr;
+ APInt BD, CE;
+ if (IsNot) {
+ BD = *ConstB & *ConstD;
+ CE = ConstC & ConstE;
+ } else {
+ BD = *ConstB | *ConstD;
+ CE = ConstC | ConstE;
+ }
+ Value *NewAnd = Builder.CreateAnd(A, BD);
+ Value *CEVal = ConstantInt::get(A->getType(), CE);
+ return Builder.CreateICmp(CC, CEVal, NewAnd);
+ };
+
+ if (Mask & BMask_Mixed)
+ return FoldBMixed(NewCC, false);
+ if (Mask & BMask_NotMixed) // could also be an 'else' of the Mixed case
+ return FoldBMixed(NewCC, true);
+ }
return nullptr;
}
@@ -928,6 +955,108 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
return nullptr;
}
+/// Try to fold (icmp(A & B) == 0) & (icmp(A & D) != E) into (icmp A u< D) iff
+/// B is a contiguous set of ones starting from the most significant bit
+/// (negative power of 2), D and E are equal, and D is a contiguous set of ones
+/// starting at the most significant zero bit in B. Parameter B supports masking
+/// using undef/poison in either scalar or vector values.
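+/// For example, with i8 operands B = 0xF0 and D = E = 0x0C:
+///   ((A & 0xF0) == 0) & ((A & 0x0C) != 0x0C)  -->  (A u< 0x0C)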
+static Value *foldNegativePower2AndShiftedMask(
+ Value *A, Value *B, Value *D, Value *E, ICmpInst::Predicate PredL,
+ ICmpInst::Predicate PredR, InstCombiner::BuilderTy &Builder) {
+ assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
+ "Expected equality predicates for masked type of icmps.");
+ if (PredL != ICmpInst::ICMP_EQ || PredR != ICmpInst::ICMP_NE)
+ return nullptr;
+
+ if (!match(B, m_NegatedPower2()) || !match(D, m_ShiftedMask()) ||
+ !match(E, m_ShiftedMask()))
+ return nullptr;
+
+ // Test scalar arguments for conversion. B has been validated earlier to be a
+ // negative power of two and thus is guaranteed to have one or more contiguous
+ // ones starting from the MSB followed by zero or more contiguous zeros. D has
+ // been validated earlier to be a shifted set of one or more contiguous ones.
+ // In order to match, B leading ones and D leading zeros should be equal. The
+ // predicate that B be a negative power of 2 prevents the condition of there
+ // ever being zero leading ones. Thus 0 == 0 cannot occur. The predicate that
+ // D always be a shifted mask prevents the condition of D equaling 0. This
+ // prevents matching the condition where B contains the maximum number of
+ // leading one bits (-1) and D contains the maximum number of leading zero
+ // bits (0).
+ auto isReducible = [](const Value *B, const Value *D, const Value *E) {
+ const APInt *BCst, *DCst, *ECst;
+ return match(B, m_APIntAllowUndef(BCst)) && match(D, m_APInt(DCst)) &&
+ match(E, m_APInt(ECst)) && *DCst == *ECst &&
+ (isa<UndefValue>(B) ||
+ (BCst->countLeadingOnes() == DCst->countLeadingZeros()));
+ };
+
+ // Test vector type arguments for conversion.
+ if (const auto *BVTy = dyn_cast<VectorType>(B->getType())) {
+ const auto *BFVTy = dyn_cast<FixedVectorType>(BVTy);
+ const auto *BConst = dyn_cast<Constant>(B);
+ const auto *DConst = dyn_cast<Constant>(D);
+ const auto *EConst = dyn_cast<Constant>(E);
+
+ if (!BFVTy || !BConst || !DConst || !EConst)
+ return nullptr;
+
+ for (unsigned I = 0; I != BFVTy->getNumElements(); ++I) {
+ const auto *BElt = BConst->getAggregateElement(I);
+ const auto *DElt = DConst->getAggregateElement(I);
+ const auto *EElt = EConst->getAggregateElement(I);
+
+ if (!BElt || !DElt || !EElt)
+ return nullptr;
+ if (!isReducible(BElt, DElt, EElt))
+ return nullptr;
+ }
+ } else {
+ // Test scalar type arguments for conversion.
+ if (!isReducible(B, D, E))
+ return nullptr;
+ }
+ return Builder.CreateICmp(ICmpInst::ICMP_ULT, A, D);
+}
+
+/// Try to fold ((icmp X u< P) & (icmp(X & M) != M)) or ((icmp X s> -1) &
+/// (icmp(X & M) != M)) into (icmp X u< M), where P is a power of 2, M < P, and
+/// M is a contiguous shifted mask whose highest set bit sits immediately below
+/// the single set bit of P. SGT is supported because, when P is the largest
+/// representable power of
+/// 2, an earlier optimization converts the expression into (icmp X s> -1).
+/// Parameter P supports masking using undef/poison in either scalar or vector
+/// values.
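+/// For example, (X u< 16) & ((X & 12) != 12) folds to (X u< 12).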
+static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
+ bool JoinedByAnd,
+ InstCombiner::BuilderTy &Builder) {
+ if (!JoinedByAnd)
+ return nullptr;
+ Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
+ ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
+ CmpPred1 = Cmp1->getPredicate();
+ // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
+ // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
+ // SignMask) == 0).
+ std::optional<std::pair<unsigned, unsigned>> MaskPair =
+ getMaskedTypeForICmpPair(A, B, C, D, E, Cmp0, Cmp1, CmpPred0, CmpPred1);
+ if (!MaskPair)
+ return nullptr;
+
+ const auto compareBMask = BMask_NotMixed | BMask_NotAllOnes;
+ unsigned CmpMask0 = MaskPair->first;
+ unsigned CmpMask1 = MaskPair->second;
+ if ((CmpMask0 & Mask_AllZeros) && (CmpMask1 == compareBMask)) {
+ if (Value *V = foldNegativePower2AndShiftedMask(A, B, D, E, CmpPred0,
+ CmpPred1, Builder))
+ return V;
+ } else if ((CmpMask0 == compareBMask) && (CmpMask1 & Mask_AllZeros)) {
+ if (Value *V = foldNegativePower2AndShiftedMask(A, D, B, C, CmpPred1,
+ CmpPred0, Builder))
+ return V;
+ }
+ return nullptr;
+}
+
/// Commuted variants are assumed to be handled by calling this function again
/// with the parameters swapped.
static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
@@ -1313,9 +1442,44 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
return Right;
}
+ // Turn at least two fcmps with constants into llvm.is.fpclass.
+ //
+ // If we can represent a combined value test with one class call, we can
+ // potentially eliminate 4-6 instructions. If we can represent a test with a
+ // single fcmp with fneg and fabs, that's likely a better canonical form.
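+ // e.g. (fcmp oeq x, +inf) | (fcmp oeq x, -inf)
+ // -> llvm.is.fpclass(x, fcPosInf|fcNegInf)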
+ if (LHS->hasOneUse() && RHS->hasOneUse()) {
+ auto [ClassValRHS, ClassMaskRHS] =
+ fcmpToClassTest(PredR, *RHS->getFunction(), RHS0, RHS1);
+ if (ClassValRHS) {
+ auto [ClassValLHS, ClassMaskLHS] =
+ fcmpToClassTest(PredL, *LHS->getFunction(), LHS0, LHS1);
+ if (ClassValLHS == ClassValRHS) {
+ unsigned CombinedMask = IsAnd ? (ClassMaskLHS & ClassMaskRHS)
+ : (ClassMaskLHS | ClassMaskRHS);
+ return Builder.CreateIntrinsic(
+ Intrinsic::is_fpclass, {ClassValLHS->getType()},
+ {ClassValLHS, Builder.getInt32(CombinedMask)});
+ }
+ }
+ }
+
return nullptr;
}
+/// Match an fcmp against a special value that performs a test possible by
+/// llvm.is.fpclass.
+static bool matchIsFPClassLikeFCmp(Value *Op, Value *&ClassVal,
+ uint64_t &ClassMask) {
+ auto *FCmp = dyn_cast<FCmpInst>(Op);
+ if (!FCmp || !FCmp->hasOneUse())
+ return false;
+
+ std::tie(ClassVal, ClassMask) =
+ fcmpToClassTest(FCmp->getPredicate(), *FCmp->getParent()->getParent(),
+ FCmp->getOperand(0), FCmp->getOperand(1));
+ return ClassVal != nullptr;
+}
+
/// or (is_fpclass x, mask0), (is_fpclass x, mask1)
/// -> is_fpclass x, (mask0 | mask1)
/// and (is_fpclass x, mask0), (is_fpclass x, mask1)
@@ -1324,13 +1488,25 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
/// -> is_fpclass x, (mask0 ^ mask1)
Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO,
Value *Op0, Value *Op1) {
- Value *ClassVal;
+ Value *ClassVal0 = nullptr;
+ Value *ClassVal1 = nullptr;
uint64_t ClassMask0, ClassMask1;
- if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>(
- m_Value(ClassVal), m_ConstantInt(ClassMask0)))) &&
+ // Restrict to folding one fcmp into one is.fpclass for now, don't introduce a
+ // new class.
+ //
+ // TODO: Support forming is.fpclass out of 2 separate fcmps when codegen is
+ // better.
+
+ bool IsLHSClass =
+ match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>(
+ m_Value(ClassVal0), m_ConstantInt(ClassMask0))));
+ bool IsRHSClass =
match(Op1, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>(
- m_Specific(ClassVal), m_ConstantInt(ClassMask1))))) {
+ m_Value(ClassVal1), m_ConstantInt(ClassMask1))));
+ if ((((IsLHSClass || matchIsFPClassLikeFCmp(Op0, ClassVal0, ClassMask0)) &&
+ (IsRHSClass || matchIsFPClassLikeFCmp(Op1, ClassVal1, ClassMask1)))) &&
+ ClassVal0 == ClassVal1) {
unsigned NewClassMask;
switch (BO.getOpcode()) {
case Instruction::And:
@@ -1346,11 +1522,24 @@ Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO,
llvm_unreachable("not a binary logic operator");
}
- // TODO: Also check for special fcmps
- auto *II = cast<IntrinsicInst>(Op0);
- II->setArgOperand(
- 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask));
- return replaceInstUsesWith(BO, II);
+ if (IsLHSClass) {
+ auto *II = cast<IntrinsicInst>(Op0);
+ II->setArgOperand(
+ 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask));
+ return replaceInstUsesWith(BO, II);
+ }
+
+ if (IsRHSClass) {
+ auto *II = cast<IntrinsicInst>(Op1);
+ II->setArgOperand(
+ 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask));
+ return replaceInstUsesWith(BO, II);
+ }
+
+ CallInst *NewClass =
+ Builder.CreateIntrinsic(Intrinsic::is_fpclass, {ClassVal0->getType()},
+ {ClassVal0, Builder.getInt32(NewClassMask)});
+ return replaceInstUsesWith(BO, NewClass);
}
return nullptr;
@@ -1523,6 +1712,39 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding");
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+
+ // fold bitwise(A >> BW - 1, zext(icmp)) (BW is the scalar bit width of
+ // A's type)
+ // -> bitwise(zext(A < 0), zext(icmp))
+ // -> zext(bitwise(A < 0, icmp))
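+ // e.g. (lshr i32 A, 31) | zext(icmp eq B, C)
+ // -> zext((icmp slt A, 0) | (icmp eq B, C))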
+ auto FoldBitwiseICmpZeroWithICmp = [&](Value *Op0,
+ Value *Op1) -> Instruction * {
+ ICmpInst::Predicate Pred;
+ Value *A;
+ bool IsMatched =
+ match(Op0,
+ m_OneUse(m_LShr(
+ m_Value(A),
+ m_SpecificInt(Op0->getType()->getScalarSizeInBits() - 1)))) &&
+ match(Op1, m_OneUse(m_ZExt(m_ICmp(Pred, m_Value(), m_Value()))));
+
+ if (!IsMatched)
+ return nullptr;
+
+ auto *ICmpL =
+ Builder.CreateICmpSLT(A, Constant::getNullValue(A->getType()));
+ auto *ICmpR = cast<ZExtInst>(Op1)->getOperand(0);
+ auto *BitwiseOp = Builder.CreateBinOp(LogicOpc, ICmpL, ICmpR);
+
+ return new ZExtInst(BitwiseOp, Op0->getType());
+ };
+
+ if (auto *Ret = FoldBitwiseICmpZeroWithICmp(Op0, Op1))
+ return Ret;
+
+ if (auto *Ret = FoldBitwiseICmpZeroWithICmp(Op1, Op0))
+ return Ret;
+
CastInst *Cast0 = dyn_cast<CastInst>(Op0);
if (!Cast0)
return nullptr;
@@ -1906,16 +2128,16 @@ static Instruction *canonicalizeLogicFirst(BinaryOperator &I,
return nullptr;
unsigned Width = Ty->getScalarSizeInBits();
- unsigned LastOneMath = Width - C2->countTrailingZeros();
+ unsigned LastOneMath = Width - C2->countr_zero();
switch (OpC) {
case Instruction::And:
- if (C->countLeadingOnes() < LastOneMath)
+ if (C->countl_one() < LastOneMath)
return nullptr;
break;
case Instruction::Xor:
case Instruction::Or:
- if (C->countLeadingZeros() < LastOneMath)
+ if (C->countl_zero() < LastOneMath)
return nullptr;
break;
default:
@@ -1923,7 +2145,51 @@ static Instruction *canonicalizeLogicFirst(BinaryOperator &I,
}
Value *NewBinOp = Builder.CreateBinOp(OpC, X, ConstantInt::get(Ty, *C));
- return BinaryOperator::CreateAdd(NewBinOp, ConstantInt::get(Ty, *C2));
+ return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, NewBinOp,
+ ConstantInt::get(Ty, *C2), Op0);
+}
+
+// binop(shift(ShiftedC1, ShAmt), shift(ShiftedC2, add(ShAmt, AddC))) ->
+// shift(binop(ShiftedC1, shift(ShiftedC2, AddC)), ShAmt)
+// where both shifts are the same and AddC is a valid shift amount.
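+// e.g. (shl 1, ShAmt) | (shl 2, (ShAmt + 3)) --> shl (1 | (2 << 3)), ShAmt,
+// i.e. shl 17, ShAmt.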
+Instruction *InstCombinerImpl::foldBinOpOfDisplacedShifts(BinaryOperator &I) {
+ assert((I.isBitwiseLogicOp() || I.getOpcode() == Instruction::Add) &&
+ "Unexpected opcode");
+
+ Value *ShAmt;
+ Constant *ShiftedC1, *ShiftedC2, *AddC;
+ Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ if (!match(&I,
+ m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)),
+ m_Shift(m_ImmConstant(ShiftedC2),
+ m_Add(m_Deferred(ShAmt), m_ImmConstant(AddC))))))
+ return nullptr;
+
+ // Make sure the add constant is a valid shift amount.
+ if (!match(AddC,
+ m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(BitWidth, BitWidth))))
+ return nullptr;
+
+ // Avoid constant expressions.
+ auto *Op0Inst = dyn_cast<Instruction>(I.getOperand(0));
+ auto *Op1Inst = dyn_cast<Instruction>(I.getOperand(1));
+ if (!Op0Inst || !Op1Inst)
+ return nullptr;
+
+ // Both shifts must be the same.
+ Instruction::BinaryOps ShiftOp =
+ static_cast<Instruction::BinaryOps>(Op0Inst->getOpcode());
+ if (ShiftOp != Op1Inst->getOpcode())
+ return nullptr;
+
+ // For adds, only left shifts are supported.
+ if (I.getOpcode() == Instruction::Add && ShiftOp != Instruction::Shl)
+ return nullptr;
+
+ Value *NewC = Builder.CreateBinOp(
+ I.getOpcode(), ShiftedC1, Builder.CreateBinOp(ShiftOp, ShiftedC2, AddC));
+ return BinaryOperator::Create(ShiftOp, NewC, ShAmt);
}
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
@@ -1964,6 +2230,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (Value *V = SimplifyBSwap(I, Builder))
return replaceInstUsesWith(I, V);
+ if (Instruction *R = foldBinOpShiftWithShift(I))
+ return R;
+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Value *X, *Y;
@@ -2033,7 +2302,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) {
// If we add zeros to every bit below a mask, the add has no effect:
// (X + AddC) & LowMaskC --> X & LowMaskC
- unsigned Ctlz = C->countLeadingZeros();
+ unsigned Ctlz = C->countl_zero();
APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz));
if ((*AddC & LowMask).isZero())
return BinaryOperator::CreateAnd(X, Op1);
@@ -2150,7 +2419,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
const APInt *C3 = C;
Value *X;
if (C3->isPowerOf2()) {
- Constant *Log2C3 = ConstantInt::get(Ty, C3->countTrailingZeros());
+ Constant *Log2C3 = ConstantInt::get(Ty, C3->countr_zero());
if (match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)),
m_ImmConstant(C2)))) &&
match(C1, m_Power2())) {
@@ -2407,6 +2676,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1))
return Folded;
+ if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
+ return Res;
+
return nullptr;
}
@@ -2718,34 +2990,47 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
return nullptr;
}
-// (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1)
-// (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1)
-static Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS,
- bool IsAnd, bool IsLogical,
- IRBuilderBase &Builder) {
+// (icmp eq X, C) | (icmp ult Other, (X - C)) -> (icmp ule Other, (X - (C + 1)))
+// (icmp ne X, C) & (icmp uge Other, (X - C)) -> (icmp ugt Other, (X - (C + 1)))
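+// e.g. with C == 3: (X == 3) | (Other u< (X - 3)) --> (Other u<= (X - 4));
+// when X == 3, X - 4 is all-ones and the ule is trivially true.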
+static Value *foldAndOrOfICmpEqConstantAndICmp(ICmpInst *LHS, ICmpInst *RHS,
+ bool IsAnd, bool IsLogical,
+ IRBuilderBase &Builder) {
+ Value *LHS0 = LHS->getOperand(0);
+ Value *RHS0 = RHS->getOperand(0);
+ Value *RHS1 = RHS->getOperand(1);
+
ICmpInst::Predicate LPred =
IsAnd ? LHS->getInversePredicate() : LHS->getPredicate();
ICmpInst::Predicate RPred =
IsAnd ? RHS->getInversePredicate() : RHS->getPredicate();
- Value *LHS0 = LHS->getOperand(0);
- if (LPred != ICmpInst::ICMP_EQ || !match(LHS->getOperand(1), m_Zero()) ||
+
+ const APInt *CInt;
+ if (LPred != ICmpInst::ICMP_EQ ||
+ !match(LHS->getOperand(1), m_APIntAllowUndef(CInt)) ||
!LHS0->getType()->isIntOrIntVectorTy() ||
!(LHS->hasOneUse() || RHS->hasOneUse()))
return nullptr;
+ auto MatchRHSOp = [LHS0, CInt](const Value *RHSOp) {
+ return match(RHSOp,
+ m_Add(m_Specific(LHS0), m_SpecificIntAllowUndef(-*CInt))) ||
+ (CInt->isZero() && RHSOp == LHS0);
+ };
+
Value *Other;
- if (RPred == ICmpInst::ICMP_ULT && RHS->getOperand(1) == LHS0)
- Other = RHS->getOperand(0);
- else if (RPred == ICmpInst::ICMP_UGT && RHS->getOperand(0) == LHS0)
- Other = RHS->getOperand(1);
+ if (RPred == ICmpInst::ICMP_ULT && MatchRHSOp(RHS1))
+ Other = RHS0;
+ else if (RPred == ICmpInst::ICMP_UGT && MatchRHSOp(RHS0))
+ Other = RHS1;
else
return nullptr;
if (IsLogical)
Other = Builder.CreateFreeze(Other);
+
return Builder.CreateICmp(
IsAnd ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE,
- Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())),
+ Builder.CreateSub(LHS0, ConstantInt::get(LHS0->getType(), *CInt + 1)),
Other);
}
@@ -2792,12 +3077,12 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return V;
if (Value *V =
- foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
+ foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
return V;
// We can treat logical like bitwise here, because both operands are used on
// the LHS, and as such poison from both will propagate.
- if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd,
- /*IsLogical*/ false, Builder))
+ if (Value *V = foldAndOrOfICmpEqConstantAndICmp(RHS, LHS, IsAnd,
+ /*IsLogical*/ false, Builder))
return V;
if (Value *V =
@@ -2836,6 +3121,9 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder))
return V;
+ if (Value *V = foldPowerOf2AndShiftedMask(LHS, RHS, IsAnd, Builder))
+ return V;
+
// TODO: Verify whether this is safe for logical and/or.
if (!IsLogical) {
if (Value *X = foldUnsignedUnderflowCheck(LHS, RHS, IsAnd, Q, Builder))
@@ -2849,7 +3137,7 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
// (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0)
- // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs.
+ // TODO: Remove this and below when foldLogOpOfMaskedICmps can handle undefs.
if (!IsLogical && PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) &&
PredL == PredR && match(LHS1, m_ZeroInt()) && match(RHS1, m_ZeroInt()) &&
LHS0->getType() == RHS0->getType()) {
@@ -2858,6 +3146,16 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Constant::getNullValue(NewOr->getType()));
}
+ // (icmp ne A, -1) | (icmp ne B, -1) --> (icmp ne (A&B), -1)
+ // (icmp eq A, -1) & (icmp eq B, -1) --> (icmp eq (A&B), -1)
+ if (!IsLogical && PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) &&
+ PredL == PredR && match(LHS1, m_AllOnes()) && match(RHS1, m_AllOnes()) &&
+ LHS0->getType() == RHS0->getType()) {
+ Value *NewAnd = Builder.CreateAnd(LHS0, RHS0);
+ return Builder.CreateICmp(PredL, NewAnd,
+ Constant::getAllOnesValue(LHS0->getType()));
+ }
+
// This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
if (!LHSC || !RHSC)
return nullptr;
@@ -2998,6 +3296,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *Concat = matchOrConcat(I, Builder))
return replaceInstUsesWith(I, Concat);
+ if (Instruction *R = foldBinOpShiftWithShift(I))
+ return R;
+
Value *X, *Y;
const APInt *CV;
if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&
@@ -3416,6 +3717,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1))
return Folded;
+ if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
+ return Res;
+
return nullptr;
}
@@ -3715,6 +4019,24 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor,
return nullptr;
}
+static bool canFreelyInvert(InstCombiner &IC, Value *Op,
+ Instruction *IgnoredUser) {
+ auto *I = dyn_cast<Instruction>(Op);
+ return I && IC.isFreeToInvert(I, /*WillInvertAllUses=*/true) &&
+ InstCombiner::canFreelyInvertAllUsersOf(I, IgnoredUser);
+}
+
+static Value *freelyInvert(InstCombinerImpl &IC, Value *Op,
+ Instruction *IgnoredUser) {
+ auto *I = cast<Instruction>(Op);
+ IC.Builder.SetInsertPoint(&*I->getInsertionPointAfterDef());
+ Value *NotOp = IC.Builder.CreateNot(Op, Op->getName() + ".not");
+ Op->replaceUsesWithIf(NotOp,
+ [NotOp](Use &U) { return U.getUser() != NotOp; });
+ IC.freelyInvertAllUsersOf(NotOp, IgnoredUser);
+ return NotOp;
+}
+
// Transform
// z = ~(x &/| y)
// into:
@@ -3739,28 +4061,11 @@ bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) {
return false;
// And can the operands be adapted?
- for (Value *Op : {Op0, Op1})
- if (!(InstCombiner::isFreeToInvert(Op, /*WillInvertAllUses=*/true) &&
- (match(Op, m_ImmConstant()) ||
- (isa<Instruction>(Op) &&
- InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op),
- /*IgnoredUser=*/&I)))))
- return false;
+ if (!canFreelyInvert(*this, Op0, &I) || !canFreelyInvert(*this, Op1, &I))
+ return false;
- for (Value **Op : {&Op0, &Op1}) {
- Value *NotOp;
- if (auto *C = dyn_cast<Constant>(*Op)) {
- NotOp = ConstantExpr::getNot(C);
- } else {
- Builder.SetInsertPoint(
- &*cast<Instruction>(*Op)->getInsertionPointAfterDef());
- NotOp = Builder.CreateNot(*Op, (*Op)->getName() + ".not");
- (*Op)->replaceUsesWithIf(
- NotOp, [NotOp](Use &U) { return U.getUser() != NotOp; });
- freelyInvertAllUsersOf(NotOp, /*IgnoredUser=*/&I);
- }
- *Op = NotOp;
- }
+ Op0 = freelyInvert(*this, Op0, &I);
+ Op1 = freelyInvert(*this, Op1, &I);
Builder.SetInsertPoint(I.getInsertionPointAfterDef());
Value *NewLogicOp;
@@ -3794,20 +4099,11 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) {
Value *NotOp0 = nullptr;
Value *NotOp1 = nullptr;
Value **OpToInvert = nullptr;
- if (match(Op0, m_Not(m_Value(NotOp0))) &&
- InstCombiner::isFreeToInvert(Op1, /*WillInvertAllUses=*/true) &&
- (match(Op1, m_ImmConstant()) ||
- (isa<Instruction>(Op1) &&
- InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op1),
- /*IgnoredUser=*/&I)))) {
+ if (match(Op0, m_Not(m_Value(NotOp0))) && canFreelyInvert(*this, Op1, &I)) {
Op0 = NotOp0;
OpToInvert = &Op1;
} else if (match(Op1, m_Not(m_Value(NotOp1))) &&
- InstCombiner::isFreeToInvert(Op0, /*WillInvertAllUses=*/true) &&
- (match(Op0, m_ImmConstant()) ||
- (isa<Instruction>(Op0) &&
- InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op0),
- /*IgnoredUser=*/&I)))) {
+ canFreelyInvert(*this, Op0, &I)) {
Op1 = NotOp1;
OpToInvert = &Op0;
} else
@@ -3817,19 +4113,7 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) {
if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
return false;
- if (auto *C = dyn_cast<Constant>(*OpToInvert)) {
- *OpToInvert = ConstantExpr::getNot(C);
- } else {
- Builder.SetInsertPoint(
- &*cast<Instruction>(*OpToInvert)->getInsertionPointAfterDef());
- Value *NotOpToInvert =
- Builder.CreateNot(*OpToInvert, (*OpToInvert)->getName() + ".not");
- (*OpToInvert)->replaceUsesWithIf(NotOpToInvert, [NotOpToInvert](Use &U) {
- return U.getUser() != NotOpToInvert;
- });
- freelyInvertAllUsersOf(NotOpToInvert, /*IgnoredUser=*/&I);
- *OpToInvert = NotOpToInvert;
- }
+ *OpToInvert = freelyInvert(*this, *OpToInvert, &I);
Builder.SetInsertPoint(&*I.getInsertionPointAfterDef());
Value *NewBinOp;
@@ -3896,8 +4180,8 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y))))
return BinaryOperator::CreateAShr(X, Y);
- // Bit-hack form of a signbit test:
- // iN ~X >>s (N-1) --> sext i1 (X > -1) to iN
+ // Bit-hack form of a signbit test for iN type:
+ // ~(X >>s (N - 1)) --> sext i1 (X > -1) to iN
unsigned FullShift = Ty->getScalarSizeInBits() - 1;
if (match(NotVal, m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))))) {
Value *IsNotNeg = Builder.CreateIsNotNeg(X, "isnotneg");
@@ -4071,6 +4355,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Instruction *R = foldNot(I))
return R;
+ if (Instruction *R = foldBinOpShiftWithShift(I))
+ return R;
+
// Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M)
// This it a special case in haveNoCommonBitsSet, but the computeKnownBits
// calls in there are unnecessary as SimplifyDemandedInstructionBits should
@@ -4280,6 +4567,23 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
}
}
+ // (A & B) ^ (A | C) --> A ? ~B : C -- There are 4 commuted variants.
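+ // If A is true the xor is B ^ true == ~B; if A is false it is false ^ C == C.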
+ if (I.getType()->isIntOrIntVectorTy(1) &&
+ match(Op0, m_OneUse(m_LogicalAnd(m_Value(A), m_Value(B)))) &&
+ match(Op1, m_OneUse(m_LogicalOr(m_Value(C), m_Value(D))))) {
+ bool NeedFreeze = isa<SelectInst>(Op0) && isa<SelectInst>(Op1) && B == D;
+ if (B == C || B == D)
+ std::swap(A, B);
+ if (A == C)
+ std::swap(C, D);
+ if (A == D) {
+ if (NeedFreeze)
+ A = Builder.CreateFreeze(A);
+ Value *NotB = Builder.CreateNot(B);
+ return SelectInst::Create(A, NotB, C);
+ }
+ }
+
if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
if (Value *V = foldXorOfICmps(LHS, RHS, I))
@@ -4313,5 +4617,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Instruction *Folded = canonicalizeConditionalNegationViaMathToSelect(I))
return Folded;
+ if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
+ return Res;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
index e73667f9c02e..cba282cea72b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
@@ -116,24 +116,10 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
return &RMWI;
}
- AtomicOrdering Ordering = RMWI.getOrdering();
- assert(Ordering != AtomicOrdering::NotAtomic &&
- Ordering != AtomicOrdering::Unordered &&
+ assert(RMWI.getOrdering() != AtomicOrdering::NotAtomic &&
+ RMWI.getOrdering() != AtomicOrdering::Unordered &&
"AtomicRMWs don't make sense with Unordered or NotAtomic");
- // Any atomicrmw xchg with no uses can be converted to a atomic store if the
- // ordering is compatible.
- if (RMWI.getOperation() == AtomicRMWInst::Xchg &&
- RMWI.use_empty()) {
- if (Ordering != AtomicOrdering::Release &&
- Ordering != AtomicOrdering::Monotonic)
- return nullptr;
- new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(),
- /*isVolatile*/ false, RMWI.getAlign(), Ordering,
- RMWI.getSyncScopeID(), &RMWI);
- return eraseInstFromFunction(RMWI);
- }
-
if (!isIdempotentRMW(RMWI))
return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index fbf1327143a8..d3ec6a7aa667 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -27,6 +27,7 @@
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
@@ -439,9 +440,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
ElementCount VF = WideLoadTy->getElementCount();
- Constant *EC =
- ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
- Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
+ Value *RunTimeVF = Builder.CreateElementCount(Builder.getInt32Ty(), VF);
Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
Value *Extract =
Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
@@ -533,16 +532,15 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType()));
}
- // If the operand is a select with constant arm(s), try to hoist ctlz/cttz.
- if (auto *Sel = dyn_cast<SelectInst>(Op0))
- if (Instruction *R = IC.FoldOpIntoSelect(II, Sel))
- return R;
-
if (IsTZ) {
// cttz(-x) -> cttz(x)
if (match(Op0, m_Neg(m_Value(X))))
return IC.replaceOperand(II, 0, X);
+ // cttz(-x & x) -> cttz(x)
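+ // (-x & x) isolates the lowest set bit of x, which is all that cttz counts.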
+ if (match(Op0, m_c_And(m_Neg(m_Value(X)), m_Deferred(X))))
+ return IC.replaceOperand(II, 0, X);
+
// cttz(sext(x)) -> cttz(zext(x))
if (match(Op0, m_OneUse(m_SExt(m_Value(X))))) {
auto *Zext = IC.Builder.CreateZExt(X, II.getType());
@@ -599,8 +597,7 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
}
// Add range metadata since known bits can't completely reflect what we know.
- // TODO: Handle splat vectors.
- auto *IT = dyn_cast<IntegerType>(Op0->getType());
+ auto *IT = cast<IntegerType>(Op0->getType()->getScalarType());
if (IT && IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) {
Metadata *LowAndHigh[] = {
ConstantAsMetadata::get(ConstantInt::get(IT, DefiniteZeros)),
@@ -657,11 +654,6 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
return CastInst::Create(Instruction::ZExt, NarrowPop, Ty);
}
- // If the operand is a select with constant arm(s), try to hoist ctpop.
- if (auto *Sel = dyn_cast<SelectInst>(Op0))
- if (Instruction *R = IC.FoldOpIntoSelect(II, Sel))
- return R;
-
KnownBits Known(BitWidth);
IC.computeKnownBits(Op0, Known, 0, &II);
@@ -683,12 +675,8 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
Constant::getNullValue(Ty)),
Ty);
- // FIXME: Try to simplify vectors of integers.
- auto *IT = dyn_cast<IntegerType>(Ty);
- if (!IT)
- return nullptr;
-
// Add range metadata since known bits can't completely reflect what we know.
+ auto *IT = cast<IntegerType>(Ty->getScalarType());
unsigned MinCount = Known.countMinPopulation();
unsigned MaxCount = Known.countMaxPopulation();
if (IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) {
@@ -830,10 +818,204 @@ InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) {
return nullptr;
}
+static bool inputDenormalIsIEEE(const Function &F, const Type *Ty) {
+ Ty = Ty->getScalarType();
+ return F.getDenormalMode(Ty->getFltSemantics()).Input == DenormalMode::IEEE;
+}
+
+static bool inputDenormalIsDAZ(const Function &F, const Type *Ty) {
+ Ty = Ty->getScalarType();
+ return F.getDenormalMode(Ty->getFltSemantics()).inputsAreZero();
+}
+
+/// \returns the compare predicate type if the test performed by
+/// llvm.is.fpclass(x, \p Mask) is equivalent to fcmp o__ x, 0.0 with the
+/// floating-point environment assumed for \p F for type \p Ty
+static FCmpInst::Predicate fpclassTestIsFCmp0(FPClassTest Mask,
+ const Function &F, Type *Ty) {
+ switch (static_cast<unsigned>(Mask)) {
+ case fcZero:
+ if (inputDenormalIsIEEE(F, Ty))
+ return FCmpInst::FCMP_OEQ;
+ break;
+ case fcZero | fcSubnormal:
+ if (inputDenormalIsDAZ(F, Ty))
+ return FCmpInst::FCMP_OEQ;
+ break;
+ case fcPositive | fcNegZero:
+ if (inputDenormalIsIEEE(F, Ty))
+ return FCmpInst::FCMP_OGE;
+ break;
+ case fcPositive | fcNegZero | fcNegSubnormal:
+ if (inputDenormalIsDAZ(F, Ty))
+ return FCmpInst::FCMP_OGE;
+ break;
+ case fcPosSubnormal | fcPosNormal | fcPosInf:
+ if (inputDenormalIsIEEE(F, Ty))
+ return FCmpInst::FCMP_OGT;
+ break;
+ case fcNegative | fcPosZero:
+ if (inputDenormalIsIEEE(F, Ty))
+ return FCmpInst::FCMP_OLE;
+ break;
+ case fcNegative | fcPosZero | fcPosSubnormal:
+ if (inputDenormalIsDAZ(F, Ty))
+ return FCmpInst::FCMP_OLE;
+ break;
+ case fcNegSubnormal | fcNegNormal | fcNegInf:
+ if (inputDenormalIsIEEE(F, Ty))
+ return FCmpInst::FCMP_OLT;
+ break;
+ case fcPosNormal | fcPosInf:
+ if (inputDenormalIsDAZ(F, Ty))
+ return FCmpInst::FCMP_OGT;
+ break;
+ case fcNegNormal | fcNegInf:
+ if (inputDenormalIsDAZ(F, Ty))
+ return FCmpInst::FCMP_OLT;
+ break;
+ case ~fcZero & ~fcNan:
+ if (inputDenormalIsIEEE(F, Ty))
+ return FCmpInst::FCMP_ONE;
+ break;
+ case ~(fcZero | fcSubnormal) & ~fcNan:
+ if (inputDenormalIsDAZ(F, Ty))
+ return FCmpInst::FCMP_ONE;
+ break;
+ default:
+ break;
+ }
+
+ return FCmpInst::BAD_FCMP_PREDICATE;
+}
+
+Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) {
+ Value *Src0 = II.getArgOperand(0);
+ Value *Src1 = II.getArgOperand(1);
+ const ConstantInt *CMask = cast<ConstantInt>(Src1);
+ FPClassTest Mask = static_cast<FPClassTest>(CMask->getZExtValue());
+ const bool IsUnordered = (Mask & fcNan) == fcNan;
+ const bool IsOrdered = (Mask & fcNan) == fcNone;
+ const FPClassTest OrderedMask = Mask & ~fcNan;
+ const FPClassTest OrderedInvertedMask = ~OrderedMask & ~fcNan;
+
+ const bool IsStrict = II.isStrictFP();
+
+ Value *FNegSrc;
+ if (match(Src0, m_FNeg(m_Value(FNegSrc)))) {
+ // is.fpclass (fneg x), mask -> is.fpclass x, (fneg mask)
+
+ II.setArgOperand(1, ConstantInt::get(Src1->getType(), fneg(Mask)));
+ return replaceOperand(II, 0, FNegSrc);
+ }
+
+ Value *FAbsSrc;
+ if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) {
+ II.setArgOperand(1, ConstantInt::get(Src1->getType(), fabs(Mask)));
+ return replaceOperand(II, 0, FAbsSrc);
+ }
+
+ // TODO: is.fpclass(x, fcInf) -> fabs(x) == inf
+
+ if ((OrderedMask == fcPosInf || OrderedMask == fcNegInf) &&
+ (IsOrdered || IsUnordered) && !IsStrict) {
+ // is.fpclass(x, fcPosInf) -> fcmp oeq x, +inf
+ // is.fpclass(x, fcNegInf) -> fcmp oeq x, -inf
+ // is.fpclass(x, fcPosInf|fcNan) -> fcmp ueq x, +inf
+ // is.fpclass(x, fcNegInf|fcNan) -> fcmp ueq x, -inf
+ Constant *Inf =
+ ConstantFP::getInfinity(Src0->getType(), OrderedMask == fcNegInf);
+ Value *EqInf = IsUnordered ? Builder.CreateFCmpUEQ(Src0, Inf)
+ : Builder.CreateFCmpOEQ(Src0, Inf);
+
+ EqInf->takeName(&II);
+ return replaceInstUsesWith(II, EqInf);
+ }
+
+ if ((OrderedInvertedMask == fcPosInf || OrderedInvertedMask == fcNegInf) &&
+ (IsOrdered || IsUnordered) && !IsStrict) {
+ // is.fpclass(x, ~fcPosInf) -> fcmp one x, +inf
+ // is.fpclass(x, ~fcNegInf) -> fcmp one x, -inf
+ // is.fpclass(x, ~fcPosInf|fcNan) -> fcmp une x, +inf
+ // is.fpclass(x, ~fcNegInf|fcNan) -> fcmp une x, -inf
+ Constant *Inf = ConstantFP::getInfinity(Src0->getType(),
+ OrderedInvertedMask == fcNegInf);
+ Value *NeInf = IsUnordered ? Builder.CreateFCmpUNE(Src0, Inf)
+ : Builder.CreateFCmpONE(Src0, Inf);
+ NeInf->takeName(&II);
+ return replaceInstUsesWith(II, NeInf);
+ }
+
+ if (Mask == fcNan && !IsStrict) {
+ // Equivalent of isnan. Replace with standard fcmp if we don't care about FP
+ // exceptions.
+ Value *IsNan =
+ Builder.CreateFCmpUNO(Src0, ConstantFP::getZero(Src0->getType()));
+ IsNan->takeName(&II);
+ return replaceInstUsesWith(II, IsNan);
+ }
+
+ if (Mask == (~fcNan & fcAllFlags) && !IsStrict) {
+ // Equivalent of !isnan. Replace with standard fcmp.
+ Value *FCmp =
+ Builder.CreateFCmpORD(Src0, ConstantFP::getZero(Src0->getType()));
+ FCmp->takeName(&II);
+ return replaceInstUsesWith(II, FCmp);
+ }
+
+ FCmpInst::Predicate PredType = FCmpInst::BAD_FCMP_PREDICATE;
+
+ // Try to replace with an fcmp with 0
+ //
+ // is.fpclass(x, fcZero) -> fcmp oeq x, 0.0
+ // is.fpclass(x, fcZero | fcNan) -> fcmp ueq x, 0.0
+ // is.fpclass(x, ~fcZero & ~fcNan) -> fcmp one x, 0.0
+ // is.fpclass(x, ~fcZero) -> fcmp une x, 0.0
+ //
+ // is.fpclass(x, fcPosSubnormal | fcPosNormal | fcPosInf) -> fcmp ogt x, 0.0
+ // is.fpclass(x, fcPositive | fcNegZero) -> fcmp oge x, 0.0
+ //
+ // is.fpclass(x, fcNegSubnormal | fcNegNormal | fcNegInf) -> fcmp olt x, 0.0
+ // is.fpclass(x, fcNegative | fcPosZero) -> fcmp ole x, 0.0
+ //
+ if (!IsStrict && (IsOrdered || IsUnordered) &&
+ (PredType = fpclassTestIsFCmp0(OrderedMask, *II.getFunction(),
+ Src0->getType())) !=
+ FCmpInst::BAD_FCMP_PREDICATE) {
+ Constant *Zero = ConstantFP::getZero(Src0->getType());
+ // Equivalent of == 0.
+ Value *FCmp = Builder.CreateFCmp(
+ IsUnordered ? FCmpInst::getUnorderedPredicate(PredType) : PredType,
+ Src0, Zero);
+
+ FCmp->takeName(&II);
+ return replaceInstUsesWith(II, FCmp);
+ }
+
+ KnownFPClass Known = computeKnownFPClass(
+ Src0, DL, Mask, 0, &getTargetLibraryInfo(), &AC, &II, &DT);
+
+ // Clear test bits we know must be false from the source value.
+ // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other
+ // fp_class (ninf x), ninf|pinf|other -> fp_class (ninf x), other
+ if ((Mask & Known.KnownFPClasses) != Mask) {
+ II.setArgOperand(
+ 1, ConstantInt::get(Src1->getType(), Mask & Known.KnownFPClasses));
+ return &II;
+ }
+
+ // If none of the tests which can return false are possible, fold to true.
+ // fp_class (nnan x), ~(qnan|snan) -> true
+ // fp_class (ninf x), ~(ninf|pinf) -> true
+ if (Mask == Known.KnownFPClasses)
+ return replaceInstUsesWith(II, ConstantInt::get(II.getType(), true));
+
+ return nullptr;
+}
+
static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
- const DataLayout &DL,
- AssumptionCache *AC,
- DominatorTree *DT) {
+ const DataLayout &DL, AssumptionCache *AC,
+ DominatorTree *DT) {
KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT);
if (Known.isNonNegative())
return false;
@@ -848,6 +1030,19 @@ static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL);
}
+/// Return true if two values \p Op0 and \p Op1 are known to have the same sign.
+static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI,
+ const DataLayout &DL, AssumptionCache *AC,
+ DominatorTree *DT) {
+ std::optional<bool> Known1 = getKnownSign(Op1, CxtI, DL, AC, DT);
+ if (!Known1)
+ return false;
+ std::optional<bool> Known0 = getKnownSign(Op0, CxtI, DL, AC, DT);
+ if (!Known0)
+ return false;
+ return *Known0 == *Known1;
+}
+
/// Try to canonicalize min/max(X + C0, C1) as min/max(X, C1 - C0) + C0. This
/// can trigger other combines.
static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
@@ -991,7 +1186,8 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II,
/// If this min/max has a constant operand and an operand that is a matching
/// min/max with a constant operand, constant-fold the 2 constant operands.
-static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) {
+static Value *reassociateMinMaxWithConstants(IntrinsicInst *II,
+ IRBuilderBase &Builder) {
Intrinsic::ID MinMaxID = II->getIntrinsicID();
auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
if (!LHS || LHS->getIntrinsicID() != MinMaxID)
@@ -1004,12 +1200,10 @@ static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) {
// max (max X, C0), C1 --> max X, (max C0, C1) --> max X, NewC
ICmpInst::Predicate Pred = MinMaxIntrinsic::getPredicate(MinMaxID);
- Constant *CondC = ConstantExpr::getICmp(Pred, C0, C1);
- Constant *NewC = ConstantExpr::getSelect(CondC, C0, C1);
-
- Module *Mod = II->getModule();
- Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType());
- return CallInst::Create(MinMax, {LHS->getArgOperand(0), NewC});
+ Value *CondC = Builder.CreateICmp(Pred, C0, C1);
+ Value *NewC = Builder.CreateSelect(CondC, C0, C1);
+ return Builder.CreateIntrinsic(MinMaxID, II->getType(),
+ {LHS->getArgOperand(0), NewC});
}
/// If this min/max has a matching min/max operand with a constant, try to push
@@ -1149,15 +1343,60 @@ foldShuffledIntrinsicOperands(IntrinsicInst *II,
return new ShuffleVectorInst(NewIntrinsic, Mask);
}
+/// Fold the following cases and accepts bswap and bitreverse intrinsics:
+/// bswap(logic_op(bswap(x), y)) --> logic_op(x, bswap(y))
+/// bswap(logic_op(bswap(x), bswap(y))) --> logic_op(x, y) (ignores multiuse)
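+/// e.g. bswap(bswap(x) & y) --> x & bswap(y), because byte swapping
+/// distributes over bitwise logic operations.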
+template <Intrinsic::ID IntrID>
+static Instruction *foldBitOrderCrossLogicOp(Value *V,
+ InstCombiner::BuilderTy &Builder) {
+ static_assert(IntrID == Intrinsic::bswap || IntrID == Intrinsic::bitreverse,
+ "This helper only supports BSWAP and BITREVERSE intrinsics");
+
+ Value *X, *Y;
+ // Find bitwise logic op. Check that it is a BinaryOperator explicitly so we
+ // don't match ConstantExpr that aren't meaningful for this transform.
+ if (match(V, m_OneUse(m_BitwiseLogic(m_Value(X), m_Value(Y)))) &&
+ isa<BinaryOperator>(V)) {
+ Value *OldReorderX, *OldReorderY;
+ BinaryOperator::BinaryOps Op = cast<BinaryOperator>(V)->getOpcode();
+
+ // If both X and Y are bswap/bitreverse, the transform reduces the number
+ // of instructions even if there's multiuse.
+ // If only one operand is bswap/bitreverse, we need to ensure the operand
+ // has only one use.
+ if (match(X, m_Intrinsic<IntrID>(m_Value(OldReorderX))) &&
+ match(Y, m_Intrinsic<IntrID>(m_Value(OldReorderY)))) {
+ return BinaryOperator::Create(Op, OldReorderX, OldReorderY);
+ }
+
+ if (match(X, m_OneUse(m_Intrinsic<IntrID>(m_Value(OldReorderX))))) {
+ Value *NewReorder = Builder.CreateUnaryIntrinsic(IntrID, Y);
+ return BinaryOperator::Create(Op, OldReorderX, NewReorder);
+ }
+
+ if (match(Y, m_OneUse(m_Intrinsic<IntrID>(m_Value(OldReorderY))))) {
+ Value *NewReorder = Builder.CreateUnaryIntrinsic(IntrID, X);
+ return BinaryOperator::Create(Op, NewReorder, OldReorderY);
+ }
+ }
+ return nullptr;
+}
+
/// CallInst simplification. This mostly only handles folding of intrinsic
/// instructions. For normal calls, it allows visitCallBase to do the heavy
/// lifting.
Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// Don't try to simplify calls without uses. It will not do anything useful,
// but will result in the following folds being skipped.
- if (!CI.use_empty())
- if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI)))
+ if (!CI.use_empty()) {
+ SmallVector<Value *, 4> Args;
+ Args.reserve(CI.arg_size());
+ for (Value *Op : CI.args())
+ Args.push_back(Op);
+ if (Value *V = simplifyCall(&CI, CI.getCalledOperand(), Args,
+ SQ.getWithInstruction(&CI)))
return replaceInstUsesWith(CI, V);
+ }
if (Value *FreedOp = getFreedOperand(&CI, &TLI))
return visitFree(CI, FreedOp);
@@ -1176,7 +1415,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// not a multiple of element size then behavior is undefined.
if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II))
if (ConstantInt *NumBytes = dyn_cast<ConstantInt>(AMI->getLength()))
- if (NumBytes->getSExtValue() < 0 ||
+ if (NumBytes->isNegative() ||
(NumBytes->getZExtValue() % AMI->getElementSizeInBytes() != 0)) {
CreateNonTerminatorUnreachable(AMI);
assert(AMI->getType()->isVoidTy() &&
@@ -1267,10 +1506,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Intrinsic::ID IID = II->getIntrinsicID();
switch (IID) {
- case Intrinsic::objectsize:
- if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false))
+ case Intrinsic::objectsize: {
+ SmallVector<Instruction *> InsertedInstructions;
+ if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false,
+ &InsertedInstructions)) {
+ for (Instruction *Inserted : InsertedInstructions)
+ Worklist.add(Inserted);
return replaceInstUsesWith(CI, V);
+ }
return nullptr;
+ }
case Intrinsic::abs: {
Value *IIOperand = II->getArgOperand(0);
bool IntMinIsPoison = cast<Constant>(II->getArgOperand(1))->isOneValue();
@@ -1377,6 +1622,46 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
+ // (umax X, (xor X, Pow2))
+ // -> (or X, Pow2)
+ // (umin X, (xor X, Pow2))
+ // -> (and X, ~Pow2)
+ // (smax X, (xor X, Pos_Pow2))
+ // -> (or X, Pos_Pow2)
+ // (smin X, (xor X, Pos_Pow2))
+ // -> (and X, ~Pos_Pow2)
+ // (smax X, (xor X, Neg_Pow2))
+ // -> (and X, ~Neg_Pow2)
+ // (smin X, (xor X, Neg_Pow2))
+ // -> (or X, Neg_Pow2)
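+  // For instance (illustrative), with Pow2 == 16:
+  //   umax(X, X ^ 16) --> X | 16   (the larger of the pair has bit 4 set)
+  //   umin(X, X ^ 16) --> X & ~16  (the smaller of the pair has bit 4 clear)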
+ if ((match(I0, m_c_Xor(m_Specific(I1), m_Value(X))) ||
+ match(I1, m_c_Xor(m_Specific(I0), m_Value(X)))) &&
+ isKnownToBeAPowerOfTwo(X, /* OrZero */ true)) {
+ bool UseOr = IID == Intrinsic::smax || IID == Intrinsic::umax;
+ bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin;
+
+ if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
+ auto KnownSign = getKnownSign(X, II, DL, &AC, &DT);
+ if (KnownSign == std::nullopt) {
+ UseOr = false;
+ UseAndN = false;
+ } else if (*KnownSign /* true is Signed. */) {
+ UseOr ^= true;
+ UseAndN ^= true;
+ Type *Ty = I0->getType();
+        // A negative power of 2 must be IntMin. It's possible to prove that a
+        // value is negative and a power of 2 without actually having its known
+        // bits, so just materialize the value by hand.
+ X = Constant::getIntegerValue(
+ Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits()));
+ }
+ }
+ if (UseOr)
+ return BinaryOperator::CreateOr(I0, X);
+ else if (UseAndN)
+ return BinaryOperator::CreateAnd(I0, Builder.CreateNot(X));
+ }
+
// If we can eliminate ~A and Y is free to invert:
// max ~A, Y --> ~(min A, ~Y)
//
@@ -1436,13 +1721,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *SAdd = matchSAddSubSat(*II))
return SAdd;
- if (match(I1, m_ImmConstant()))
- if (auto *Sel = dyn_cast<SelectInst>(I0))
- if (Instruction *R = FoldOpIntoSelect(*II, Sel))
- return R;
-
- if (Instruction *NewMinMax = reassociateMinMaxWithConstants(II))
- return NewMinMax;
+ if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder))
+ return replaceInstUsesWith(*II, NewMinMax);
if (Instruction *R = reassociateMinMaxWithConstantInOperand(II, Builder))
return R;
@@ -1453,15 +1733,21 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
break;
}
case Intrinsic::bitreverse: {
+ Value *IIOperand = II->getArgOperand(0);
// bitrev (zext i1 X to ?) --> X ? SignBitC : 0
Value *X;
- if (match(II->getArgOperand(0), m_ZExt(m_Value(X))) &&
+ if (match(IIOperand, m_ZExt(m_Value(X))) &&
X->getType()->isIntOrIntVectorTy(1)) {
Type *Ty = II->getType();
APInt SignBit = APInt::getSignMask(Ty->getScalarSizeInBits());
return SelectInst::Create(X, ConstantInt::get(Ty, SignBit),
ConstantInt::getNullValue(Ty));
}
+
+ if (Instruction *crossLogicOpFold =
+ foldBitOrderCrossLogicOp<Intrinsic::bitreverse>(IIOperand, Builder))
+ return crossLogicOpFold;
+
break;
}
case Intrinsic::bswap: {
@@ -1511,6 +1797,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *V = Builder.CreateLShr(X, CV);
return new TruncInst(V, IIOperand->getType());
}
+
+ if (Instruction *crossLogicOpFold =
+ foldBitOrderCrossLogicOp<Intrinsic::bswap>(IIOperand, Builder)) {
+ return crossLogicOpFold;
+ }
+
break;
}
case Intrinsic::masked_load:
@@ -1616,6 +1908,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Function *Bswap = Intrinsic::getDeclaration(Mod, Intrinsic::bswap, Ty);
return CallInst::Create(Bswap, { Op0 });
}
+ if (Instruction *BitOp =
+ matchBSwapOrBitReverse(*II, /*MatchBSwaps*/ true,
+ /*MatchBitReversals*/ true))
+ return BitOp;
}
// Left or right might be masked.
@@ -1983,7 +2279,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
case Intrinsic::copysign: {
Value *Mag = II->getArgOperand(0), *Sign = II->getArgOperand(1);
- if (SignBitMustBeZero(Sign, &TLI)) {
+ if (SignBitMustBeZero(Sign, DL, &TLI)) {
// If we know that the sign argument is positive, reduce to FABS:
// copysign Mag, +Sign --> fabs Mag
Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II);
@@ -2079,6 +2375,42 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
break;
}
+ case Intrinsic::ldexp: {
+ // ldexp(ldexp(x, a), b) -> ldexp(x, a + b)
+ //
+ // The danger is if the first ldexp would overflow to infinity or underflow
+ // to zero, but the combined exponent avoids it. We ignore this with
+ // reassoc.
+ //
+ // It's also safe to fold if we know both exponents are >= 0 or <= 0 since
+ // it would just double down on the overflow/underflow which would occur
+ // anyway.
+ //
+ // TODO: Could do better if we had range tracking for the input value
+ // exponent. Also could broaden sign check to cover == 0 case.
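+    // Illustrative instance: ldexp(ldexp(2.0, 3), 4) == ldexp(2.0, 7) == 256.0;
+    // with both exponents non-negative, the combined form can only overflow
+    // where the nested form would already have overflowed.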
+ Value *Src = II->getArgOperand(0);
+ Value *Exp = II->getArgOperand(1);
+ Value *InnerSrc;
+ Value *InnerExp;
+ if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ldexp>(
+ m_Value(InnerSrc), m_Value(InnerExp)))) &&
+ Exp->getType() == InnerExp->getType()) {
+ FastMathFlags FMF = II->getFastMathFlags();
+ FastMathFlags InnerFlags = cast<FPMathOperator>(Src)->getFastMathFlags();
+
+ if ((FMF.allowReassoc() && InnerFlags.allowReassoc()) ||
+ signBitMustBeTheSame(Exp, InnerExp, II, DL, &AC, &DT)) {
+        // TODO: Adding nsw/nuw is probably safe if the integer type exceeds
+        // the exponent width.
+ Value *NewExp = Builder.CreateAdd(InnerExp, Exp);
+ II->setArgOperand(1, NewExp);
+ II->setFastMathFlags(InnerFlags); // Or the inner flags.
+ return replaceOperand(*II, 0, InnerSrc);
+ }
+ }
+
+ break;
+ }
case Intrinsic::ptrauth_auth:
case Intrinsic::ptrauth_resign: {
// (sign|resign) + (auth|resign) can be folded by omitting the middle
@@ -2380,12 +2712,34 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
isValidAssumeForContext(II, LHS, &DT)) {
MDNode *MD = MDNode::get(II->getContext(), std::nullopt);
LHS->setMetadata(LLVMContext::MD_nonnull, MD);
+ LHS->setMetadata(LLVMContext::MD_noundef, MD);
return RemoveConditionFromAssume(II);
// TODO: apply nonnull return attributes to calls and invokes
// TODO: apply range metadata for range check patterns?
}
+  // Separate storage assumptions apply to the underlying allocations, not to
+  // any particular pointer within them. When evaluating the hints for AA
+  // purposes we call getUnderlyingObject on them; by precomputing the answers
+  // here we avoid having to do so repeatedly there.
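+  // Illustrative example (hypothetical values): a bundle such as
+  //   call void @llvm.assume(i1 true) ["separate_storage"(ptr %p.off, ptr %q)]
+  // where %p.off is a gep based on the allocation %p, is rewritten to refer to
+  // %p directly.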
+ for (unsigned Idx = 0; Idx < II->getNumOperandBundles(); Idx++) {
+ OperandBundleUse OBU = II->getOperandBundleAt(Idx);
+ if (OBU.getTagName() == "separate_storage") {
+ assert(OBU.Inputs.size() == 2);
+ auto MaybeSimplifyHint = [&](const Use &U) {
+ Value *Hint = U.get();
+ // Not having a limit is safe because InstCombine removes unreachable
+ // code.
+ Value *UnderlyingObject = getUnderlyingObject(Hint, /*MaxLookup*/ 0);
+ if (Hint != UnderlyingObject)
+ replaceUse(const_cast<Use &>(U), UnderlyingObject);
+ };
+ MaybeSimplifyHint(OBU.Inputs[0]);
+ MaybeSimplifyHint(OBU.Inputs[1]);
+ }
+ }
+
// Convert nonnull assume like:
// %A = icmp ne i32* %PTR, null
// call void @llvm.assume(i1 %A)
@@ -2479,6 +2833,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Known.isAllOnes() && isAssumeWithEmptyBundle(cast<AssumeInst>(*II)))
return eraseInstFromFunction(*II);
+ // assume(false) is unreachable.
+ if (match(IIOperand, m_CombineOr(m_Zero(), m_Undef()))) {
+ CreateNonTerminatorUnreachable(II);
+ return eraseInstFromFunction(*II);
+ }
+
// Update the cache of affected values for this assumption (we might be
// here because we just simplified the condition).
AC.updateAffectedValues(cast<AssumeInst>(II));
@@ -2545,7 +2905,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
for (i = 0; i != SubVecNumElts; ++i)
WidenMask.push_back(i);
for (; i != VecNumElts; ++i)
- WidenMask.push_back(UndefMaskElem);
+ WidenMask.push_back(PoisonMaskElem);
Value *WidenShuffle = Builder.CreateShuffleVector(SubVec, WidenMask);
@@ -2840,7 +3200,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
int Sz = Mask.size();
SmallBitVector UsedIndices(Sz);
for (int Idx : Mask) {
- if (Idx == UndefMaskElem || UsedIndices.test(Idx))
+ if (Idx == PoisonMaskElem || UsedIndices.test(Idx))
break;
UsedIndices.set(Idx);
}
@@ -2852,6 +3212,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
break;
}
+ case Intrinsic::is_fpclass: {
+ if (Instruction *I = foldIntrinsicIsFPClass(*II))
+ return I;
+ break;
+ }
default: {
// Handle target specific intrinsics
std::optional<Instruction *> V = targetInstCombineIntrinsic(*II);
@@ -2861,6 +3226,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
+ // Try to fold intrinsic into select operands. This is legal if:
+ // * The intrinsic is speculatable.
+ // * The select condition is not a vector, or the intrinsic does not
+ // perform cross-lane operations.
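+  // A minimal illustration: ctpop(select i1 %c, i8 3, i8 8)
+  //   --> select i1 %c, i8 2, i8 1
+  // since both arms fold to constants once the intrinsic is pushed inside.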
+ switch (IID) {
+ case Intrinsic::ctlz:
+ case Intrinsic::cttz:
+ case Intrinsic::ctpop:
+ case Intrinsic::umin:
+ case Intrinsic::umax:
+ case Intrinsic::smin:
+ case Intrinsic::smax:
+ case Intrinsic::usub_sat:
+ case Intrinsic::uadd_sat:
+ case Intrinsic::ssub_sat:
+ case Intrinsic::sadd_sat:
+ for (Value *Op : II->args())
+ if (auto *Sel = dyn_cast<SelectInst>(Op))
+ if (Instruction *R = FoldOpIntoSelect(*II, Sel))
+ return R;
+ [[fallthrough]];
+ default:
+ break;
+ }
+
if (Instruction *Shuf = foldShuffledIntrinsicOperands(II, Builder))
return Shuf;
@@ -2907,49 +3297,6 @@ Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) {
return visitCallBase(CBI);
}
-/// If this cast does not affect the value passed through the varargs area, we
-/// can eliminate the use of the cast.
-static bool isSafeToEliminateVarargsCast(const CallBase &Call,
- const DataLayout &DL,
- const CastInst *const CI,
- const int ix) {
- if (!CI->isLosslessCast())
- return false;
-
- // If this is a GC intrinsic, avoid munging types. We need types for
- // statepoint reconstruction in SelectionDAG.
- // TODO: This is probably something which should be expanded to all
- // intrinsics since the entire point of intrinsics is that
- // they are understandable by the optimizer.
- if (isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) ||
- isa<GCResultInst>(Call))
- return false;
-
- // Opaque pointers are compatible with any byval types.
- PointerType *SrcTy = cast<PointerType>(CI->getOperand(0)->getType());
- if (SrcTy->isOpaque())
- return true;
-
- // The size of ByVal or InAlloca arguments is derived from the type, so we
- // can't change to a type with a different size. If the size were
- // passed explicitly we could avoid this check.
- if (!Call.isPassPointeeByValueArgument(ix))
- return true;
-
- // The transform currently only handles type replacement for byval, not other
- // type-carrying attributes.
- if (!Call.isByValArgument(ix))
- return false;
-
- Type *SrcElemTy = SrcTy->getNonOpaquePointerElementType();
- Type *DstElemTy = Call.getParamByValType(ix);
- if (!SrcElemTy->isSized() || !DstElemTy->isSized())
- return false;
- if (DL.getTypeAllocSize(SrcElemTy) != DL.getTypeAllocSize(DstElemTy))
- return false;
- return true;
-}
-
Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) {
if (!CI->getCalledFunction()) return nullptr;
@@ -2965,7 +3312,7 @@ Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) {
auto InstCombineErase = [this](Instruction *I) {
eraseInstFromFunction(*I);
};
- LibCallSimplifier Simplifier(DL, &TLI, ORE, BFI, PSI, InstCombineRAUW,
+ LibCallSimplifier Simplifier(DL, &TLI, &AC, ORE, BFI, PSI, InstCombineRAUW,
InstCombineErase);
if (Value *With = Simplifier.optimizeCall(CI, Builder)) {
++NumSimplified;
@@ -3198,32 +3545,6 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
if (IntrinsicInst *II = findInitTrampoline(Callee))
return transformCallThroughTrampoline(Call, *II);
- // TODO: Drop this transform once opaque pointer transition is done.
- FunctionType *FTy = Call.getFunctionType();
- if (FTy->isVarArg()) {
- int ix = FTy->getNumParams();
- // See if we can optimize any arguments passed through the varargs area of
- // the call.
- for (auto I = Call.arg_begin() + FTy->getNumParams(), E = Call.arg_end();
- I != E; ++I, ++ix) {
- CastInst *CI = dyn_cast<CastInst>(*I);
- if (CI && isSafeToEliminateVarargsCast(Call, DL, CI, ix)) {
- replaceUse(*I, CI->getOperand(0));
-
- // Update the byval type to match the pointer type.
- // Not necessary for opaque pointers.
- PointerType *NewTy = cast<PointerType>(CI->getOperand(0)->getType());
- if (!NewTy->isOpaque() && Call.isByValArgument(ix)) {
- Call.removeParamAttr(ix, Attribute::ByVal);
- Call.addParamAttr(ix, Attribute::getWithByValType(
- Call.getContext(),
- NewTy->getNonOpaquePointerElementType()));
- }
- Changed = true;
- }
- }
- }
-
if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
InlineAsm *IA = cast<InlineAsm>(Callee);
if (!IA->canThrow()) {
@@ -3381,13 +3702,17 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
}
/// If the callee is a constexpr cast of a function, attempt to move the cast to
-/// the arguments of the call/callbr/invoke.
+/// the arguments of the call/invoke.
+/// CallBrInst is not supported.
bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
auto *Callee =
dyn_cast<Function>(Call.getCalledOperand()->stripPointerCasts());
if (!Callee)
return false;
+ assert(!isa<CallBrInst>(Call) &&
+ "CallBr's don't have a single point after a def to insert at");
+
// If this is a call to a thunk function, don't remove the cast. Thunks are
// used to transparently forward all incoming parameters and outgoing return
// values, so it's important to leave the cast in place.
@@ -3433,7 +3758,7 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
return false; // Attribute not compatible with transformed value.
}
- // If the callbase is an invoke/callbr instruction, and the return value is
+ // If the callbase is an invoke instruction, and the return value is
// used by a PHI node in a successor, we cannot change the return type of
// the call because there is no place to put the cast instruction (without
// breaking the critical edge). Bail out in this case.
@@ -3441,8 +3766,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
BasicBlock *PhisNotSupportedBlock = nullptr;
if (auto *II = dyn_cast<InvokeInst>(Caller))
PhisNotSupportedBlock = II->getNormalDest();
- if (auto *CB = dyn_cast<CallBrInst>(Caller))
- PhisNotSupportedBlock = CB->getDefaultDest();
if (PhisNotSupportedBlock)
for (User *U : Caller->users())
if (PHINode *PN = dyn_cast<PHINode>(U))
@@ -3490,24 +3813,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
if (CallerPAL.hasParamAttr(i, Attribute::ByVal) !=
Callee->getAttributes().hasParamAttr(i, Attribute::ByVal))
return false; // Cannot transform to or from byval.
-
- // If the parameter is passed as a byval argument, then we have to have a
- // sized type and the sized type has to have the same size as the old type.
- if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) {
- PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy);
- if (!ParamPTy)
- return false;
-
- if (!ParamPTy->isOpaque()) {
- Type *ParamElTy = ParamPTy->getNonOpaquePointerElementType();
- if (!ParamElTy->isSized())
- return false;
-
- Type *CurElTy = Call.getParamByValType(i);
- if (DL.getTypeAllocSize(CurElTy) != DL.getTypeAllocSize(ParamElTy))
- return false;
- }
- }
}
if (Callee->isDeclaration()) {
@@ -3568,16 +3873,8 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
// type. Note that we made sure all incompatible ones are safe to drop.
AttributeMask IncompatibleAttrs = AttributeFuncs::typeIncompatible(
ParamTy, AttributeFuncs::ASK_SAFE_TO_DROP);
- if (CallerPAL.hasParamAttr(i, Attribute::ByVal) &&
- !ParamTy->isOpaquePointerTy()) {
- AttrBuilder AB(Ctx, CallerPAL.getParamAttrs(i).removeAttributes(
- Ctx, IncompatibleAttrs));
- AB.addByValAttr(ParamTy->getNonOpaquePointerElementType());
- ArgAttrs.push_back(AttributeSet::get(Ctx, AB));
- } else {
- ArgAttrs.push_back(
- CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs));
- }
+ ArgAttrs.push_back(
+ CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs));
}
// If the function takes more arguments than the call was taking, add them
@@ -3626,9 +3923,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
NewCall = Builder.CreateInvoke(Callee, II->getNormalDest(),
II->getUnwindDest(), Args, OpBundles);
- } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) {
- NewCall = Builder.CreateCallBr(Callee, CBI->getDefaultDest(),
- CBI->getIndirectDests(), Args, OpBundles);
} else {
NewCall = Builder.CreateCall(Callee, Args, OpBundles);
cast<CallInst>(NewCall)->setTailCallKind(
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 3f851a2b2182..5c84f666616d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -25,166 +25,6 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
-/// Analyze 'Val', seeing if it is a simple linear expression.
-/// If so, decompose it, returning some value X, such that Val is
-/// X*Scale+Offset.
-///
-static Value *decomposeSimpleLinearExpr(Value *Val, unsigned &Scale,
- uint64_t &Offset) {
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Val)) {
- Offset = CI->getZExtValue();
- Scale = 0;
- return ConstantInt::get(Val->getType(), 0);
- }
-
- if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) {
- // Cannot look past anything that might overflow.
- // We specifically require nuw because we store the Scale in an unsigned
- // and perform an unsigned divide on it.
- OverflowingBinaryOperator *OBI = dyn_cast<OverflowingBinaryOperator>(Val);
- if (OBI && !OBI->hasNoUnsignedWrap()) {
- Scale = 1;
- Offset = 0;
- return Val;
- }
-
- if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
- if (I->getOpcode() == Instruction::Shl) {
- // This is a value scaled by '1 << the shift amt'.
- Scale = UINT64_C(1) << RHS->getZExtValue();
- Offset = 0;
- return I->getOperand(0);
- }
-
- if (I->getOpcode() == Instruction::Mul) {
- // This value is scaled by 'RHS'.
- Scale = RHS->getZExtValue();
- Offset = 0;
- return I->getOperand(0);
- }
-
- if (I->getOpcode() == Instruction::Add) {
- // We have X+C. Check to see if we really have (X*C2)+C1,
- // where C1 is divisible by C2.
- unsigned SubScale;
- Value *SubVal =
- decomposeSimpleLinearExpr(I->getOperand(0), SubScale, Offset);
- Offset += RHS->getZExtValue();
- Scale = SubScale;
- return SubVal;
- }
- }
- }
-
- // Otherwise, we can't look past this.
- Scale = 1;
- Offset = 0;
- return Val;
-}
-
-/// If we find a cast of an allocation instruction, try to eliminate the cast by
-/// moving the type information into the alloc.
-Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI,
- AllocaInst &AI) {
- PointerType *PTy = cast<PointerType>(CI.getType());
- // Opaque pointers don't have an element type we could replace with.
- if (PTy->isOpaque())
- return nullptr;
-
- IRBuilderBase::InsertPointGuard Guard(Builder);
- Builder.SetInsertPoint(&AI);
-
- // Get the type really allocated and the type casted to.
- Type *AllocElTy = AI.getAllocatedType();
- Type *CastElTy = PTy->getNonOpaquePointerElementType();
- if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr;
-
- // This optimisation does not work for cases where the cast type
- // is scalable and the allocated type is not. This because we need to
- // know how many times the casted type fits into the allocated type.
- // For the opposite case where the allocated type is scalable and the
- // cast type is not this leads to poor code quality due to the
- // introduction of 'vscale' into the calculations. It seems better to
- // bail out for this case too until we've done a proper cost-benefit
- // analysis.
- bool AllocIsScalable = isa<ScalableVectorType>(AllocElTy);
- bool CastIsScalable = isa<ScalableVectorType>(CastElTy);
- if (AllocIsScalable != CastIsScalable) return nullptr;
-
- Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy);
- Align CastElTyAlign = DL.getABITypeAlign(CastElTy);
- if (CastElTyAlign < AllocElTyAlign) return nullptr;
-
- // If the allocation has multiple uses, only promote it if we are strictly
- // increasing the alignment of the resultant allocation. If we keep it the
- // same, we open the door to infinite loops of various kinds.
- if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr;
-
- // The alloc and cast types should be either both fixed or both scalable.
- uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinValue();
- uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinValue();
- if (CastElTySize == 0 || AllocElTySize == 0) return nullptr;
-
- // If the allocation has multiple uses, only promote it if we're not
- // shrinking the amount of memory being allocated.
- uint64_t AllocElTyStoreSize =
- DL.getTypeStoreSize(AllocElTy).getKnownMinValue();
- uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinValue();
- if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr;
-
- // See if we can satisfy the modulus by pulling a scale out of the array
- // size argument.
- unsigned ArraySizeScale;
- uint64_t ArrayOffset;
- Value *NumElements = // See if the array size is a decomposable linear expr.
- decomposeSimpleLinearExpr(AI.getOperand(0), ArraySizeScale, ArrayOffset);
-
- // If we can now satisfy the modulus, by using a non-1 scale, we really can
- // do the xform.
- if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 ||
- (AllocElTySize*ArrayOffset ) % CastElTySize != 0) return nullptr;
-
- // We don't currently support arrays of scalable types.
- assert(!AllocIsScalable || (ArrayOffset == 1 && ArraySizeScale == 0));
-
- unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize;
- Value *Amt = nullptr;
- if (Scale == 1) {
- Amt = NumElements;
- } else {
- Amt = ConstantInt::get(AI.getArraySize()->getType(), Scale);
- // Insert before the alloca, not before the cast.
- Amt = Builder.CreateMul(Amt, NumElements);
- }
-
- if (uint64_t Offset = (AllocElTySize*ArrayOffset)/CastElTySize) {
- Value *Off = ConstantInt::get(AI.getArraySize()->getType(),
- Offset, true);
- Amt = Builder.CreateAdd(Amt, Off);
- }
-
- AllocaInst *New = Builder.CreateAlloca(CastElTy, AI.getAddressSpace(), Amt);
- New->setAlignment(AI.getAlign());
- New->takeName(&AI);
- New->setUsedWithInAlloca(AI.isUsedWithInAlloca());
- New->setMetadata(LLVMContext::MD_DIAssignID,
- AI.getMetadata(LLVMContext::MD_DIAssignID));
-
- replaceAllDbgUsesWith(AI, *New, *New, DT);
-
- // If the allocation has multiple real uses, insert a cast and change all
- // things that used it to use the new cast. This will also hack on CI, but it
- // will die soon.
- if (!AI.hasOneUse()) {
- // New is the allocation instruction, pointer typed. AI is the original
- // allocation instruction, also pointer typed. Thus, cast to use is BitCast.
- Value *NewCast = Builder.CreateBitCast(New, AI.getType(), "tmpcast");
- replaceInstUsesWith(AI, NewCast);
- eraseInstFromFunction(AI);
- }
- return replaceInstUsesWith(CI, New);
-}
-
/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
/// true for, actually insert the code to evaluate the expression.
Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
@@ -252,6 +92,20 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
Res = CastInst::Create(
static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
break;
+ case Instruction::Call:
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ llvm_unreachable("Unsupported call!");
+ case Intrinsic::vscale: {
+ Function *Fn =
+ Intrinsic::getDeclaration(I->getModule(), Intrinsic::vscale, {Ty});
+ Res = CallInst::Create(Fn->getFunctionType(), Fn);
+ break;
+ }
+ }
+ }
+ break;
default:
// TODO: Can handle more cases here.
llvm_unreachable("Unreachable!");
@@ -294,6 +148,10 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
Value *Src = CI.getOperand(0);
Type *Ty = CI.getType();
+ if (auto *SrcC = dyn_cast<Constant>(Src))
+ if (Constant *Res = ConstantFoldCastOperand(CI.getOpcode(), SrcC, Ty, DL))
+ return replaceInstUsesWith(CI, Res);
+
// Try to eliminate a cast of a cast.
if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast
if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) {
@@ -501,16 +359,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// If the integer type can hold the max FP value, it is safe to cast
// directly to that type. Otherwise, we may create poison via overflow
// that did not exist in the original code.
- //
- // The max FP value is pow(2, MaxExponent) * (1 + MaxFraction), so we need
- // at least one more bit than the MaxExponent to hold the max FP value.
Type *InputTy = I->getOperand(0)->getType()->getScalarType();
const fltSemantics &Semantics = InputTy->getFltSemantics();
- uint32_t MinBitWidth = APFloatBase::semanticsMaxExponent(Semantics);
- // Extra sign bit needed.
- if (I->getOpcode() == Instruction::FPToSI)
- ++MinBitWidth;
- return Ty->getScalarSizeInBits() > MinBitWidth;
+ uint32_t MinBitWidth =
+ APFloatBase::semanticsIntSizeInBits(Semantics,
+ I->getOpcode() == Instruction::FPToSI);
+ return Ty->getScalarSizeInBits() >= MinBitWidth;
}
default:
// TODO: Can handle more cases here.
@@ -881,13 +735,12 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
Value *And = Builder.CreateAnd(X, MaskC);
return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
}
- if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_Constant(C)),
+ if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_ImmConstant(C)),
m_Deferred(X))))) {
// trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0
Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1));
Constant *MaskC = ConstantExpr::getShl(One, C);
- MaskC = ConstantExpr::getOr(MaskC, One);
- Value *And = Builder.CreateAnd(X, MaskC);
+ Value *And = Builder.CreateAnd(X, Builder.CreateOr(MaskC, One));
return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
}
}
@@ -904,11 +757,18 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
// removed by the trunc.
if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
APInt(SrcWidth, MaxShiftAmt)))) {
+ auto GetNewShAmt = [&](unsigned Width) {
+ Constant *MaxAmt = ConstantInt::get(SrcTy, Width - 1, false);
+ Constant *Cmp =
+ ConstantFoldCompareInstOperands(ICmpInst::ICMP_ULT, C, MaxAmt, DL);
+ Constant *ShAmt = ConstantFoldSelectInstruction(Cmp, C, MaxAmt);
+ return ConstantFoldCastOperand(Instruction::Trunc, ShAmt, A->getType(),
+ DL);
+ };
+
// trunc (lshr (sext A), C) --> ashr A, C
if (A->getType() == DestTy) {
- Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false);
- Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
- ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+ Constant *ShAmt = GetNewShAmt(DestWidth);
ShAmt = Constant::mergeUndefsWith(ShAmt, C);
return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
: BinaryOperator::CreateAShr(A, ShAmt);
@@ -916,9 +776,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
// The types are mismatched, so create a cast after shifting:
// trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
if (Src->hasOneUse()) {
- Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false);
- Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
- ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+ Constant *ShAmt = GetNewShAmt(AWidth);
Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact);
return CastInst::CreateIntegerCast(Shift, DestTy, true);
}
@@ -998,7 +856,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
}
}
- if (match(Src, m_VScale(DL))) {
+ if (match(Src, m_VScale())) {
if (Trunc.getFunction() &&
Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
Attribute Attr =
@@ -1217,6 +1075,13 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
return false;
return true;
}
+ case Instruction::Call:
+  // llvm.vscale() can always be executed in a larger type, because the
+  // value is automatically zero-extended.
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
+ if (II->getIntrinsicID() == Intrinsic::vscale)
+ return true;
+ return false;
default:
// TODO: Can handle more cases here.
return false;
@@ -1226,7 +1091,8 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
// If this zero extend is only used by a truncate, let the truncate be
// eliminated before we try to optimize this zext.
- if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back()))
+ if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back()) &&
+ !isa<Constant>(Zext.getOperand(0)))
return nullptr;
// If one of the common conversion will work, do it.
@@ -1340,7 +1206,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
return BinaryOperator::CreateAnd(X, ZextC);
}
- if (match(Src, m_VScale(DL))) {
+ if (match(Src, m_VScale())) {
if (Zext.getFunction() &&
Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
Attribute Attr =
@@ -1402,7 +1268,7 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp,
if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) {
// sext ((x & 2^n) == 0) -> (x >> n) - 1
// sext ((x & 2^n) != 2^n) -> (x >> n) - 1
- unsigned ShiftAmt = KnownZeroMask.countTrailingZeros();
+ unsigned ShiftAmt = KnownZeroMask.countr_zero();
// Perform a right shift to place the desired bit in the LSB.
if (ShiftAmt)
In = Builder.CreateLShr(In,
@@ -1416,7 +1282,7 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp,
} else {
// sext ((x & 2^n) != 0) -> (x << bitwidth-n) a>> bitwidth-1
// sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1
- unsigned ShiftAmt = KnownZeroMask.countLeadingZeros();
+ unsigned ShiftAmt = KnownZeroMask.countl_zero();
// Perform a left shift to place the desired bit in the MSB.
if (ShiftAmt)
In = Builder.CreateShl(In,
@@ -1611,7 +1477,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
}
}
- if (match(Src, m_VScale(DL))) {
+ if (match(Src, m_VScale())) {
if (Sext.getFunction() &&
Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
Attribute Attr =
@@ -2687,57 +2553,6 @@ Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI,
return RetVal;
}
-static Instruction *convertBitCastToGEP(BitCastInst &CI, IRBuilderBase &Builder,
- const DataLayout &DL) {
- Value *Src = CI.getOperand(0);
- PointerType *SrcPTy = cast<PointerType>(Src->getType());
- PointerType *DstPTy = cast<PointerType>(CI.getType());
-
- // Bitcasts involving opaque pointers cannot be converted into a GEP.
- if (SrcPTy->isOpaque() || DstPTy->isOpaque())
- return nullptr;
-
- Type *DstElTy = DstPTy->getNonOpaquePointerElementType();
- Type *SrcElTy = SrcPTy->getNonOpaquePointerElementType();
-
- // When the type pointed to is not sized the cast cannot be
- // turned into a gep.
- if (!SrcElTy->isSized())
- return nullptr;
-
- // If the source and destination are pointers, and this cast is equivalent
- // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep.
- // This can enhance SROA and other transforms that want type-safe pointers.
- unsigned NumZeros = 0;
- while (SrcElTy && SrcElTy != DstElTy) {
- SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0);
- ++NumZeros;
- }
-
- // If we found a path from the src to dest, create the getelementptr now.
- if (SrcElTy == DstElTy) {
- SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0));
- GetElementPtrInst *GEP = GetElementPtrInst::Create(
- SrcPTy->getNonOpaquePointerElementType(), Src, Idxs);
-
- // If the source pointer is dereferenceable, then assume it points to an
- // allocated object and apply "inbounds" to the GEP.
- bool CanBeNull, CanBeFreed;
- if (Src->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed)) {
- // In a non-default address space (not 0), a null pointer can not be
- // assumed inbounds, so ignore that case (dereferenceable_or_null).
- // The reason is that 'null' is not treated differently in these address
- // spaces, and we consequently ignore the 'gep inbounds' special case
- // for 'null' which allows 'inbounds' on 'null' if the indices are
- // zeros.
- if (SrcPTy->getAddressSpace() == 0 || !CanBeNull)
- GEP->setIsInBounds();
- }
- return GEP;
- }
- return nullptr;
-}
-
Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
// If the operands are integer typed then apply the integer transforms,
// otherwise just apply the common ones.
@@ -2750,19 +2565,6 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
if (DestTy == Src->getType())
return replaceInstUsesWith(CI, Src);
- if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) {
- // If we are casting a alloca to a pointer to a type of the same
- // size, rewrite the allocation instruction to allocate the "right" type.
- // There is no need to modify malloc calls because it is their bitcast that
- // needs to be cleaned up.
- if (AllocaInst *AI = dyn_cast<AllocaInst>(Src))
- if (Instruction *V = PromoteCastOfAllocation(CI, *AI))
- return V;
-
- if (Instruction *I = convertBitCastToGEP(CI, Builder, DL))
- return I;
- }
-
if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) {
// Beware: messing with this target-specific oddity may cause trouble.
if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) {
@@ -2905,23 +2707,5 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
}
Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
- // If the destination pointer element type is not the same as the source's
- // first do a bitcast to the destination type, and then the addrspacecast.
- // This allows the cast to be exposed to other transforms.
- Value *Src = CI.getOperand(0);
- PointerType *SrcTy = cast<PointerType>(Src->getType()->getScalarType());
- PointerType *DestTy = cast<PointerType>(CI.getType()->getScalarType());
-
- if (!SrcTy->hasSameElementTypeAs(DestTy)) {
- Type *MidTy =
- PointerType::getWithSamePointeeType(DestTy, SrcTy->getAddressSpace());
- // Handle vectors of pointers.
- if (VectorType *VT = dyn_cast<VectorType>(CI.getType()))
- MidTy = VectorType::get(MidTy, VT->getElementCount());
-
- Value *NewBitCast = Builder.CreateBitCast(Src, MidTy);
- return new AddrSpaceCastInst(NewBitCast, CI.getType());
- }
-
return commonPointerCastTransforms(CI);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 1480a0ff9e2f..656f04370e17 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -14,6 +14,7 @@
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -198,7 +199,11 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
}
// If the element is masked, handle it.
- if (AndCst) Elt = ConstantExpr::getAnd(Elt, AndCst);
+ if (AndCst) {
+ Elt = ConstantFoldBinaryOpOperands(Instruction::And, Elt, AndCst, DL);
+ if (!Elt)
+ return nullptr;
+ }
// Find out if the comparison would be true or false for the i'th element.
Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt,
@@ -276,14 +281,14 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
// order the state machines in complexity of the generated code.
Value *Idx = GEP->getOperand(2);
- // If the index is larger than the pointer size of the target, truncate the
- // index down like the GEP would do implicitly. We don't have to do this for
- // an inbounds GEP because the index can't be out of range.
+ // If the index is larger than the pointer offset size of the target, truncate
+ // the index down like the GEP would do implicitly. We don't have to do this
+ // for an inbounds GEP because the index can't be out of range.
if (!GEP->isInBounds()) {
- Type *IntPtrTy = DL.getIntPtrType(GEP->getType());
- unsigned PtrSize = IntPtrTy->getIntegerBitWidth();
- if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > PtrSize)
- Idx = Builder.CreateTrunc(Idx, IntPtrTy);
+ Type *PtrIdxTy = DL.getIndexType(GEP->getType());
+ unsigned OffsetSize = PtrIdxTy->getIntegerBitWidth();
+ if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > OffsetSize)
+ Idx = Builder.CreateTrunc(Idx, PtrIdxTy);
}
// If inbounds keyword is not present, Idx * ElementSize can overflow.
@@ -295,10 +300,10 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
// We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
unsigned ElementSize =
DL.getTypeAllocSize(Init->getType()->getArrayElementType());
- auto MaskIdx = [&](Value* Idx){
- if (!GEP->isInBounds() && countTrailingZeros(ElementSize) != 0) {
+ auto MaskIdx = [&](Value *Idx) {
+ if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) {
Value *Mask = ConstantInt::get(Idx->getType(), -1);
- Mask = Builder.CreateLShr(Mask, countTrailingZeros(ElementSize));
+ Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize));
Idx = Builder.CreateAnd(Idx, Mask);
}
return Idx;
@@ -533,7 +538,8 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V,
/// pointer.
static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
const DataLayout &DL,
- SetVector<Value *> &Explored) {
+ SetVector<Value *> &Explored,
+ InstCombiner &IC) {
// Perform all the substitutions. This is a bit tricky because we can
// have cycles in our use-def chains.
// 1. Create the PHI nodes without any incoming values.
@@ -562,7 +568,7 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
// Create all the other instructions.
for (Value *Val : Explored) {
- if (NewInsts.find(Val) != NewInsts.end())
+ if (NewInsts.contains(Val))
continue;
if (auto *CI = dyn_cast<CastInst>(Val)) {
@@ -610,7 +616,7 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) {
Value *NewIncoming = PHI->getIncomingValue(I);
- if (NewInsts.find(NewIncoming) != NewInsts.end())
+ if (NewInsts.contains(NewIncoming))
NewIncoming = NewInsts[NewIncoming];
NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I));
@@ -635,7 +641,10 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
Val->getName() + ".ptr");
NewVal = Builder.CreateBitOrPointerCast(
NewVal, Val->getType(), Val->getName() + ".conv");
- Val->replaceAllUsesWith(NewVal);
+ IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal);
+ // Add old instruction to worklist for DCE. We don't directly remove it
+ // here because the original compare is one of the users.
+ IC.addToWorklist(cast<Instruction>(Val));
}
return NewInsts[Start];
@@ -688,7 +697,8 @@ getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) {
/// between GEPLHS and RHS.
static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
ICmpInst::Predicate Cond,
- const DataLayout &DL) {
+ const DataLayout &DL,
+ InstCombiner &IC) {
// FIXME: Support vector of pointers.
if (GEPLHS->getType()->isVectorTy())
return nullptr;
@@ -712,7 +722,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
// can't have overflow on either side. We can therefore re-write
// this as:
// OFFSET1 cmp OFFSET2
- Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes);
+ Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes, IC);
// RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written
// GEP having PtrBase as the pointer base, and has returned in NewRHS the
@@ -740,7 +750,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
RHS = RHS->stripPointerCasts();
Value *PtrBase = GEPLHS->getOperand(0);
- if (PtrBase == RHS && GEPLHS->isInBounds()) {
+ if (PtrBase == RHS && (GEPLHS->isInBounds() || ICmpInst::isEquality(Cond))) {
// ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0).
Value *Offset = EmitGEPOffset(GEPLHS);
return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
@@ -831,7 +841,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// Otherwise, the base pointers are different and the indices are
// different. Try convert this to an indexed compare by looking through
// PHIs/casts.
- return transformToIndexedCompare(GEPLHS, RHS, Cond, DL);
+ return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this);
}
// If one of the GEPs has all zero indices, recurse.
@@ -883,7 +893,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// Only lower this if the icmp is the only user of the GEP or if we expect
// the result to fold to a constant!
- if (GEPsInBounds && (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) &&
+ if ((GEPsInBounds || CmpInst::isEquality(Cond)) &&
+ (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) &&
(isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) {
// ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2)
Value *L = EmitGEPOffset(GEPLHS);
@@ -894,13 +905,10 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// Try convert this to an indexed compare by looking through PHIs/casts as a
// last resort.
- return transformToIndexedCompare(GEPLHS, RHS, Cond, DL);
+ return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this);
}
-Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
- const AllocaInst *Alloca) {
- assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
-
+bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
// It would be tempting to fold away comparisons between allocas and any
// pointer not based on that alloca (e.g. an argument). However, even
// though such pointers cannot alias, they can still compare equal.
@@ -909,67 +917,72 @@ Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
// doesn't escape we can argue that it's impossible to guess its value, and we
// can therefore act as if any such guesses are wrong.
//
- // The code below checks that the alloca doesn't escape, and that it's only
- // used in a comparison once (the current instruction). The
- // single-comparison-use condition ensures that we're trivially folding all
- // comparisons against the alloca consistently, and avoids the risk of
- // erroneously folding a comparison of the pointer with itself.
-
- unsigned MaxIter = 32; // Break cycles and bound to constant-time.
+ // However, we need to ensure that this folding is consistent: We can't fold
+ // one comparison to false, and then leave a different comparison against the
+ // same value alone (as it might evaluate to true at runtime, leading to a
+ // contradiction). As such, this code ensures that all comparisons are folded
+ // at the same time, and there are no other escapes.
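+  // For example (illustrative): if a non-escaping alloca %a is compared for
+  // equality with the same pointer %p at two sites, folding one comparison to
+  // false while leaving the other in place could yield contradictory answers,
+  // so either every such comparison is folded here or none is.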
+
+ struct CmpCaptureTracker : public CaptureTracker {
+ AllocaInst *Alloca;
+ bool Captured = false;
+ /// The value of the map is a bit mask of which icmp operands the alloca is
+ /// used in.
+ SmallMapVector<ICmpInst *, unsigned, 4> ICmps;
+
+ CmpCaptureTracker(AllocaInst *Alloca) : Alloca(Alloca) {}
+
+ void tooManyUses() override { Captured = true; }
+
+ bool captured(const Use *U) override {
+ auto *ICmp = dyn_cast<ICmpInst>(U->getUser());
+ // We need to check that U is based *only* on the alloca, and doesn't
+ // have other contributions from a select/phi operand.
+ // TODO: We could check whether getUnderlyingObjects() reduces to one
+ // object, which would allow looking through phi nodes.
+ if (ICmp && ICmp->isEquality() && getUnderlyingObject(*U) == Alloca) {
+ // Collect equality icmps of the alloca, and don't treat them as
+ // captures.
+ auto Res = ICmps.insert({ICmp, 0});
+ Res.first->second |= 1u << U->getOperandNo();
+ return false;
+ }
- SmallVector<const Use *, 32> Worklist;
- for (const Use &U : Alloca->uses()) {
- if (Worklist.size() >= MaxIter)
- return nullptr;
- Worklist.push_back(&U);
- }
+ Captured = true;
+ return true;
+ }
+ };
- unsigned NumCmps = 0;
- while (!Worklist.empty()) {
- assert(Worklist.size() <= MaxIter);
- const Use *U = Worklist.pop_back_val();
- const Value *V = U->getUser();
- --MaxIter;
+ CmpCaptureTracker Tracker(Alloca);
+ PointerMayBeCaptured(Alloca, &Tracker);
+ if (Tracker.Captured)
+ return false;
- if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) ||
- isa<SelectInst>(V)) {
- // Track the uses.
- } else if (isa<LoadInst>(V)) {
- // Loading from the pointer doesn't escape it.
- continue;
- } else if (const auto *SI = dyn_cast<StoreInst>(V)) {
- // Storing *to* the pointer is fine, but storing the pointer escapes it.
- if (SI->getValueOperand() == U->get())
- return nullptr;
- continue;
- } else if (isa<ICmpInst>(V)) {
- if (NumCmps++)
- return nullptr; // Found more than one cmp.
- continue;
- } else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) {
- switch (Intrin->getIntrinsicID()) {
- // These intrinsics don't escape or compare the pointer. Memset is safe
- // because we don't allow ptrtoint. Memcpy and memmove are safe because
- // we don't allow stores, so src cannot point to V.
- case Intrinsic::lifetime_start: case Intrinsic::lifetime_end:
- case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset:
- continue;
- default:
- return nullptr;
- }
- } else {
- return nullptr;
+ bool Changed = false;
+ for (auto [ICmp, Operands] : Tracker.ICmps) {
+ switch (Operands) {
+ case 1:
+ case 2: {
+ // The alloca is only used in one icmp operand. Assume that the
+ // equality is false.
+ auto *Res = ConstantInt::get(
+ ICmp->getType(), ICmp->getPredicate() == ICmpInst::ICMP_NE);
+ replaceInstUsesWith(*ICmp, Res);
+ eraseInstFromFunction(*ICmp);
+ Changed = true;
+ break;
}
- for (const Use &U : V->uses()) {
- if (Worklist.size() >= MaxIter)
- return nullptr;
- Worklist.push_back(&U);
+ case 3:
+ // Both icmp operands are based on the alloca, so this is comparing
+ // pointer offsets, without leaking any information about the address
+ // of the alloca. Ignore such comparisons.
+ break;
+ default:
+ llvm_unreachable("Cannot happen");
}
}
- auto *Res = ConstantInt::get(ICI.getType(),
- !CmpInst::isTrueWhenEqual(ICI.getPredicate()));
- return replaceInstUsesWith(ICI, Res);
+ return Changed;
}
/// Fold "icmp pred (X+C), X".
@@ -1058,9 +1071,9 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A,
int Shift;
if (IsAShr && AP1.isNegative())
- Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes();
+ Shift = AP1.countl_one() - AP2.countl_one();
else
- Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros();
+ Shift = AP1.countl_zero() - AP2.countl_zero();
if (Shift > 0) {
if (IsAShr && AP1 == AP2.ashr(Shift)) {
@@ -1097,7 +1110,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A,
if (AP2.isZero())
return nullptr;
- unsigned AP2TrailingZeros = AP2.countTrailingZeros();
+ unsigned AP2TrailingZeros = AP2.countr_zero();
if (!AP1 && AP2TrailingZeros != 0)
return getICmp(
@@ -1108,7 +1121,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A,
return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
// Get the distance between the lowest bits that are set.
- int Shift = AP1.countTrailingZeros() - AP2TrailingZeros;
+ int Shift = AP1.countr_zero() - AP2TrailingZeros;
if (Shift > 0 && AP2.shl(Shift) == AP1)
return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
@@ -1143,7 +1156,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
// If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
if (!CI2->getValue().isPowerOf2())
return nullptr;
- unsigned NewWidth = CI2->getValue().countTrailingZeros();
+ unsigned NewWidth = CI2->getValue().countr_zero();
if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31)
return nullptr;
@@ -1295,6 +1308,48 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
return new ICmpInst(Pred, X, Cmp.getOperand(1));
}
+  // (icmp eq/ne (mul X, Y), 0) -> (icmp eq/ne X/Y, 0) if we know that X/Y is
+  // odd, or that X/Y is non-zero and the multiplication does not overflow.
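+  // Illustrative sketch: because an odd X is invertible modulo 2^N,
+  //   %m = mul i8 %x, %y      ; %x known to be odd
+  //   %c = icmp eq i8 %m, 0   -->  icmp eq i8 %y, 0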
+ if (match(Cmp.getOperand(0), m_Mul(m_Value(X), m_Value(Y))) &&
+ ICmpInst::isEquality(Pred)) {
+
+ KnownBits XKnown = computeKnownBits(X, 0, &Cmp);
+ // if X % 2 != 0
+ // (icmp eq/ne Y)
+ if (XKnown.countMaxTrailingZeros() == 0)
+ return new ICmpInst(Pred, Y, Cmp.getOperand(1));
+
+ KnownBits YKnown = computeKnownBits(Y, 0, &Cmp);
+ // if Y % 2 != 0
+ // (icmp eq/ne X)
+ if (YKnown.countMaxTrailingZeros() == 0)
+ return new ICmpInst(Pred, X, Cmp.getOperand(1));
+
+ auto *BO0 = cast<OverflowingBinaryOperator>(Cmp.getOperand(0));
+ if (BO0->hasNoUnsignedWrap() || BO0->hasNoSignedWrap()) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
+ // `isKnownNonZero` does more analysis than just `!KnownBits.One.isZero()`
+    // but to avoid unnecessary work, first just check if this is an obvious case.
+
+ // if X non-zero and NoOverflow(X * Y)
+ // (icmp eq/ne Y)
+ if (!XKnown.One.isZero() || isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(Pred, Y, Cmp.getOperand(1));
+
+ // if Y non-zero and NoOverflow(X * Y)
+ // (icmp eq/ne X)
+ if (!YKnown.One.isZero() || isKnownNonZero(Y, DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(Pred, X, Cmp.getOperand(1));
+ }
+ // Note, we are skipping cases:
+ // if Y % 2 != 0 AND X % 2 != 0
+ // (false/true)
+ // if X non-zero and Y non-zero and NoOverflow(X * Y)
+ // (false/true)
+ // Those can be simplified later as we would have already replaced the (icmp
+ // eq/ne (mul X, Y)) with (icmp eq/ne X/Y) and if X/Y is known non-zero that
+ // will fold to a constant elsewhere.
+ }
return nullptr;
}
@@ -1331,17 +1386,18 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
if (auto *Phi = dyn_cast<PHINode>(Op0))
if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
- Type *Ty = Cmp.getType();
- Builder.SetInsertPoint(Phi);
- PHINode *NewPhi =
- Builder.CreatePHI(Ty, Phi->getNumOperands());
- for (BasicBlock *Predecessor : predecessors(Phi->getParent())) {
- auto *Input =
- cast<Constant>(Phi->getIncomingValueForBlock(Predecessor));
- auto *BoolInput = ConstantExpr::getCompare(Pred, Input, C);
- NewPhi->addIncoming(BoolInput, Predecessor);
+ SmallVector<Constant *> Ops;
+ for (Value *V : Phi->incoming_values()) {
+ Constant *Res =
+ ConstantFoldCompareInstOperands(Pred, cast<Constant>(V), C, DL);
+ if (!Res)
+ return nullptr;
+ Ops.push_back(Res);
}
- NewPhi->takeName(&Cmp);
+ Builder.SetInsertPoint(Phi);
+ PHINode *NewPhi = Builder.CreatePHI(Cmp.getType(), Phi->getNumOperands());
+ for (auto [V, Pred] : zip(Ops, Phi->blocks()))
+ NewPhi->addIncoming(V, Pred);
return replaceInstUsesWith(Cmp, NewPhi);
}
@@ -1369,11 +1425,8 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
if (TrueBB == FalseBB)
return nullptr;
- // Try to simplify this compare to T/F based on the dominating condition.
- std::optional<bool> Imp =
- isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB);
- if (Imp)
- return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp));
+  // We already checked simple implication in InstSimplify; only handle complex
+  // cases here.
CmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1);
@@ -1475,7 +1528,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
KnownBits Known = computeKnownBits(X, 0, &Cmp);
// If all the high bits are known, we can do this xform.
- if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) {
+ if ((Known.Zero | Known.One).countl_one() >= SrcBits - DstBits) {
// Pull in the high bits from known-ones set.
APInt NewRHS = C.zext(SrcBits);
NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
@@ -1781,17 +1834,12 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
++UsesRemoved;
// Compute A & ((1 << B) | 1)
- Value *NewOr = nullptr;
- if (auto *C = dyn_cast<Constant>(B)) {
- if (UsesRemoved >= 1)
- NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One);
- } else {
- if (UsesRemoved >= 3)
- NewOr = Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(),
- /*HasNUW=*/true),
- One, Or->getName());
- }
- if (NewOr) {
+ unsigned RequireUsesRemoved = match(B, m_ImmConstant()) ? 1 : 3;
+ if (UsesRemoved >= RequireUsesRemoved) {
+ Value *NewOr =
+ Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(),
+ /*HasNUW=*/true),
+ One, Or->getName());
Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName());
return replaceOperand(Cmp, 0, NewAnd);
}
@@ -1819,6 +1867,15 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType()));
}
+  // (X & -X) < 0 --> X == MinSignedC
+  // (X & -X) > -1 --> X != MinSignedC
+ if (match(And, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) {
+ Constant *MinSignedC = ConstantInt::get(
+ X->getType(),
+ APInt::getSignedMinValue(X->getType()->getScalarSizeInBits()));
+ auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
+ return new ICmpInst(NewPred, X, MinSignedC);
+ }
}
// TODO: These all require that Y is constant too, so refactor with the above.
@@ -1846,6 +1903,30 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1))));
}
+ // If we are testing the intersection of 2 select-of-nonzero-constants with no
+ // common bits set, it's the same as checking if exactly one select condition
+ // is set:
+ // ((A ? TC : FC) & (B ? TC : FC)) == 0 --> xor A, B
+ // ((A ? TC : FC) & (B ? TC : FC)) != 0 --> not(xor A, B)
+ // TODO: Generalize for non-constant values.
+ // TODO: Handle signed/unsigned predicates.
+ // TODO: Handle other bitwise logic connectors.
+ // TODO: Extend to handle a non-zero compare constant.
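+  // Worked instance (illustrative) with TC = 1, FC = 2:
+  //   (A ? 1 : 2) & (B ? 1 : 2) is 1 or 2 when A == B and 0 when A != B,
+  //   so comparing the 'and' against 0 is exactly xor A, B.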
+ if (C.isZero() && (Pred == CmpInst::ICMP_EQ || And->hasOneUse())) {
+ assert(Cmp.isEquality() && "Not expecting non-equality predicates");
+ Value *A, *B;
+ const APInt *TC, *FC;
+ if (match(X, m_Select(m_Value(A), m_APInt(TC), m_APInt(FC))) &&
+ match(Y,
+ m_Select(m_Value(B), m_SpecificInt(*TC), m_SpecificInt(*FC))) &&
+ !TC->isZero() && !FC->isZero() && !TC->intersects(*FC)) {
+ Value *R = Builder.CreateXor(A, B);
+ if (Pred == CmpInst::ICMP_NE)
+ R = Builder.CreateNot(R);
+ return replaceInstUsesWith(Cmp, R);
+ }
+ }
+
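Illustrative check of the select-intersection equivalence above, using arbitrary nonzero constants 4 and 3 that share no bits (a sketch, not part of the diff):

    #include <cassert>

    int main() {
      const unsigned TC = 4, FC = 3; // nonzero, TC & FC == 0
      for (int A = 0; A <= 1; ++A)
        for (int B = 0; B <= 1; ++B) {
          unsigned And = (A ? TC : FC) & (B ? TC : FC);
          assert((And == 0) == ((A ^ B) != 0)); // == 0  <->  xor A, B
          assert((And != 0) == ((A ^ B) == 0)); // != 0  <->  not(xor A, B)
        }
      return 0;
    }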
// ((zext i1 X) & Y) == 0 --> !((trunc Y) & X)
// ((zext i1 X) & Y) != 0 --> ((trunc Y) & X)
// ((zext i1 X) & Y) == 1 --> ((trunc Y) & X)
@@ -1863,6 +1944,59 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
return nullptr;
}
+/// Fold icmp eq/ne (or (xor (X1, X2), xor(X3, X4))), 0.
+static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or,
+ InstCombiner::BuilderTy &Builder) {
+ // Are we using xors to bitwise check for one or more pairs of (in)equalities?
+ // Convert to a shorter form that has more potential to be folded even
+ // further.
+ // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
+ // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
+ // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) == 0 -->
+ // (X1 == X2) && (X3 == X4) && (X5 == X6)
+ // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) != 0 -->
+ // (X1 != X2) || (X3 != X4) || (X5 != X6)
+ // TODO: Implement for sub
+ SmallVector<std::pair<Value *, Value *>, 2> CmpValues;
+ SmallVector<Value *, 16> WorkList(1, Or);
+
+ while (!WorkList.empty()) {
+ auto MatchOrOperatorArgument = [&](Value *OrOperatorArgument) {
+ Value *Lhs, *Rhs;
+
+ if (match(OrOperatorArgument,
+ m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) {
+ CmpValues.emplace_back(Lhs, Rhs);
+ } else {
+ WorkList.push_back(OrOperatorArgument);
+ }
+ };
+
+ Value *CurrentValue = WorkList.pop_back_val();
+ Value *OrOperatorLhs, *OrOperatorRhs;
+
+ if (!match(CurrentValue,
+ m_Or(m_Value(OrOperatorLhs), m_Value(OrOperatorRhs)))) {
+ return nullptr;
+ }
+
+ MatchOrOperatorArgument(OrOperatorRhs);
+ MatchOrOperatorArgument(OrOperatorLhs);
+ }
+
+ ICmpInst::Predicate Pred = Cmp.getPredicate();
+ auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
+ Value *LhsCmp = Builder.CreateICmp(Pred, CmpValues.rbegin()->first,
+ CmpValues.rbegin()->second);
+
+ for (auto It = CmpValues.rbegin() + 1; It != CmpValues.rend(); ++It) {
+ Value *RhsCmp = Builder.CreateICmp(Pred, It->first, It->second);
+ LhsCmp = Builder.CreateBinOp(BOpc, LhsCmp, RhsCmp);
+ }
+
+ return LhsCmp;
+}
+
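A small exhaustive sketch of the or-of-xors equivalence this helper relies on, restricted to a tiny value domain for brevity (illustrative only):

    #include <cassert>

    int main() {
      for (unsigned x1 = 0; x1 < 4; ++x1)
        for (unsigned x2 = 0; x2 < 4; ++x2)
          for (unsigned x3 = 0; x3 < 4; ++x3)
            for (unsigned x4 = 0; x4 < 4; ++x4) {
              bool OrIsZero = ((x1 ^ x2) | (x3 ^ x4)) == 0;
              assert(OrIsZero == ((x1 == x2) && (x3 == x4)));   // == 0 form
              assert(!OrIsZero == ((x1 != x2) || (x3 != x4)));  // != 0 form
            }
      return 0;
    }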
/// Fold icmp (or X, Y), C.
Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
BinaryOperator *Or,
@@ -1909,6 +2043,30 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
return new ICmpInst(NewPred, X, NewC);
}
+ const APInt *OrC;
+ // icmp(X | OrC, C) --> icmp(X, 0)
+ if (C.isNonNegative() && match(Or, m_Or(m_Value(X), m_APInt(OrC)))) {
+ switch (Pred) {
+ // X | OrC s< C --> X s< 0 iff OrC s>= C s>= 0
+ case ICmpInst::ICMP_SLT:
+ // X | OrC s>= C --> X s>= 0 iff OrC s>= C s>= 0
+ case ICmpInst::ICMP_SGE:
+ if (OrC->sge(C))
+ return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+ break;
+ // X | OrC s<= C --> X s< 0 iff OrC s> C s>= 0
+ case ICmpInst::ICMP_SLE:
+ // X | OrC s> C --> X s>= 0 iff OrC s> C s>= 0
+ case ICmpInst::ICMP_SGT:
+ if (OrC->sgt(C))
+ return new ICmpInst(ICmpInst::getFlippedStrictnessPredicate(Pred), X,
+ ConstantInt::getNullValue(X->getType()));
+ break;
+ default:
+ break;
+ }
+ }
+
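Quick check of the signed-compare-through-or identity, with OrC = 12 and C = 5 picked only to satisfy OrC s>= C s>= 0 (a sketch under those assumptions):

    #include <cassert>
    #include <cstdint>

    int main() {
      const int8_t OrC = 12, C = 5; // OrC s>= C s>= 0
      for (int v = -128; v <= 127; ++v) {
        int8_t X = static_cast<int8_t>(v);
        int8_t Or = static_cast<int8_t>(X | OrC);
        assert((Or < C) == (X < 0));   // X | OrC s< C  <->  X s< 0
        assert((Or >= C) == (X >= 0)); // X | OrC s>= C <->  X s>= 0
      }
      return 0;
    }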
if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse())
return nullptr;
@@ -1924,18 +2082,8 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
return BinaryOperator::Create(BOpc, CmpP, CmpQ);
}
- // Are we using xors to bitwise check for a pair of (in)equalities? Convert to
- // a shorter form that has more potential to be folded even further.
- Value *X1, *X2, *X3, *X4;
- if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) &&
- match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) {
- // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
- // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
- Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2);
- Value *Cmp34 = Builder.CreateICmp(Pred, X3, X4);
- auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
- return BinaryOperator::Create(BOpc, Cmp12, Cmp34);
- }
+ if (Value *V = foldICmpOrXorChain(Cmp, Or, Builder))
+ return replaceInstUsesWith(Cmp, V);
return nullptr;
}
@@ -1969,21 +2117,29 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy));
}
- if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap()))
+ if (MulC->isZero())
return nullptr;
- // If the multiply does not wrap, try to divide the compare constant by the
- // multiplication factor.
+ // If the multiply does not wrap or the constant is odd, try to divide the
+ // compare constant by the multiplication factor.
if (Cmp.isEquality()) {
- // (mul nsw X, MulC) == C --> X == C /s MulC
+ // (mul nsw X, MulC) eq/ne C --> X eq/ne C /s MulC
if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) {
Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC));
return new ICmpInst(Pred, X, NewC);
}
- // (mul nuw X, MulC) == C --> X == C /u MulC
- if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) {
- Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC));
- return new ICmpInst(Pred, X, NewC);
+
+ // C % MulC == 0 is weaker than we could use if MulC is odd, because it is
+ // correct to transform whenever MulC * N == C, including with overflow. E.g.,
+ // with i8, (icmp eq (mul X, 5), 101) -> (icmp eq X, 225), but since
+ // 101 % 5 != 0, we miss that case.
+ if (C.urem(*MulC).isZero()) {
+ // (mul nuw X, MulC) eq/ne C --> X eq/ne C /u MulC
+ // (mul X, OddC) eq/ne N * C --> X eq/ne N
+ if ((*MulC & 1).isOne() || Mul->hasNoUnsignedWrap()) {
+ Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC));
+ return new ICmpInst(Pred, X, NewC);
+ }
}
}
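The worked example from the comment above, verified by brute force (illustrative only); multiplication by an odd constant is invertible mod 2^8, so each equality has exactly one solution:

    #include <cassert>
    #include <cstdint>

    int main() {
      int Solutions101 = 0, Solutions35 = 0;
      for (unsigned v = 0; v < 256; ++v) {
        uint8_t X = static_cast<uint8_t>(v);
        uint8_t Mul = static_cast<uint8_t>(X * 5u);
        if (Mul == 101) { ++Solutions101; assert(X == 225); } // case the comment says is missed
        if (Mul == 35)  { ++Solutions35;  assert(X == 7);   } // 35 % 5 == 0, handled above
      }
      assert(Solutions101 == 1 && Solutions35 == 1);
      return 0;
    }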
@@ -1992,27 +2148,32 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
// (X * MulC) > C --> X > (C / MulC)
// TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE?
Constant *NewC = nullptr;
- if (Mul->hasNoSignedWrap()) {
+ if (Mul->hasNoSignedWrap() && ICmpInst::isSigned(Pred)) {
// MININT / -1 --> overflow.
if (C.isMinSignedValue() && MulC->isAllOnes())
return nullptr;
if (MulC->isNegative())
Pred = ICmpInst::getSwappedPredicate(Pred);
- if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE)
+ if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
NewC = ConstantInt::get(
MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP));
- if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT)
+ } else {
+ assert((Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) &&
+ "Unexpected predicate");
NewC = ConstantInt::get(
MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN));
- } else {
- assert(Mul->hasNoUnsignedWrap() && "Expected mul nuw");
- if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)
+ }
+ } else if (Mul->hasNoUnsignedWrap() && ICmpInst::isUnsigned(Pred)) {
+ if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) {
NewC = ConstantInt::get(
MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP));
- if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)
+ } else {
+ assert((Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
+ "Unexpected predicate");
NewC = ConstantInt::get(
MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN));
+ }
}
return NewC ? new ICmpInst(Pred, X, NewC) : nullptr;
@@ -2070,6 +2231,32 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal)))
return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal);
+ ICmpInst::Predicate Pred = Cmp.getPredicate();
+ // (icmp pred (shl nuw&nsw X, Y), Csle0)
+ // -> (icmp pred X, Csle0)
+ //
+ // The idea is that the nuw/nsw flags essentially freeze the sign bit through
+ // the shift, so the comparison can be performed on X directly.
+ if (C.sle(0) && Shl->hasNoUnsignedWrap() && Shl->hasNoSignedWrap())
+ return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1));
+
+ // (icmp eq/ne (shl nuw|nsw X, Y), 0)
+ // -> (icmp eq/ne X, 0)
+ if (ICmpInst::isEquality(Pred) && C.isZero() &&
+ (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap()))
+ return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1));
+
+ // (icmp slt (shl nsw X, Y), 0/1)
+ // -> (icmp slt X, 0/1)
+ // (icmp sgt (shl nsw X, Y), 0/-1)
+ // -> (icmp sgt X, 0/-1)
+ //
+ // NB: sge/sle with a constant will canonicalize to sgt/slt.
+ if (Shl->hasNoSignedWrap() &&
+ (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT))
+ if (C.isZero() || (Pred == ICmpInst::ICMP_SGT ? C.isAllOnes() : C.isOne()))
+ return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1));
+
const APInt *ShiftAmt;
if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
return foldICmpShlOne(Cmp, Shl, C);
@@ -2080,7 +2267,6 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
if (ShiftAmt->uge(TypeBits))
return nullptr;
- ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Shl->getOperand(0);
Type *ShType = Shl->getType();
@@ -2107,11 +2293,6 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1;
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
- // If this is a signed comparison to 0 and the shift is sign preserving,
- // use the shift LHS operand instead; isSignTest may change 'Pred', so only
- // do that if we're sure to not continue on in this function.
- if (isSignTest(Pred, C))
- return new ICmpInst(Pred, X, Constant::getNullValue(ShType));
}
// NUW guarantees that we are only shifting out zero bits from the high bits,
@@ -2189,7 +2370,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
// free on the target. It has the additional benefit of comparing to a
// smaller constant that may be more target-friendly.
unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1);
- if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt &&
+ if (Shl->hasOneUse() && Amt != 0 && C.countr_zero() >= Amt &&
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
if (auto *ShVTy = dyn_cast<VectorType>(ShType))
@@ -2237,9 +2418,8 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
assert(ShiftValC->uge(C) && "Expected simplify of compare");
assert((IsUGT || !C.isZero()) && "Expected X u< 0 to simplify");
- unsigned CmpLZ =
- IsUGT ? C.countLeadingZeros() : (C - 1).countLeadingZeros();
- unsigned ShiftLZ = ShiftValC->countLeadingZeros();
+ unsigned CmpLZ = IsUGT ? C.countl_zero() : (C - 1).countl_zero();
+ unsigned ShiftLZ = ShiftValC->countl_zero();
Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ);
auto NewPred = IsUGT ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
return new ICmpInst(NewPred, Shr->getOperand(1), NewC);
@@ -3184,18 +3364,30 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
}
break;
}
- case Instruction::And: {
- const APInt *BOC;
- if (match(BOp1, m_APInt(BOC))) {
- // If we have ((X & C) == C), turn it into ((X & C) != 0).
- if (C == *BOC && C.isPowerOf2())
- return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
- BO, Constant::getNullValue(RHS->getType()));
- }
- break;
- }
case Instruction::UDiv:
- if (C.isZero()) {
+ case Instruction::SDiv:
+ if (BO->isExact()) {
+ // div exact X, Y eq/ne 0 -> X eq/ne 0
+ // div exact X, Y eq/ne 1 -> X eq/ne Y
+ // div exact X, Y eq/ne C ->
+ // if Y * C never-overflow && OneUse:
+ // -> Y * C eq/ne X
+ if (C.isZero())
+ return new ICmpInst(Pred, BOp0, Constant::getNullValue(BO->getType()));
+ else if (C.isOne())
+ return new ICmpInst(Pred, BOp0, BOp1);
+ else if (BO->hasOneUse()) {
+ OverflowResult OR = computeOverflow(
+ Instruction::Mul, BO->getOpcode() == Instruction::SDiv, BOp1,
+ Cmp.getOperand(1), BO);
+ if (OR == OverflowResult::NeverOverflows) {
+ Value *YC =
+ Builder.CreateMul(BOp1, ConstantInt::get(BO->getType(), C));
+ return new ICmpInst(Pred, YC, BOp0);
+ }
+ }
+ }
+ if (BO->getOpcode() == Instruction::UDiv && C.isZero()) {
// (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A)
auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
return new ICmpInst(NewPred, BOp1, BOp0);
@@ -3207,6 +3399,44 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
return nullptr;
}
+static Instruction *foldCtpopPow2Test(ICmpInst &I, IntrinsicInst *CtpopLhs,
+ const APInt &CRhs,
+ InstCombiner::BuilderTy &Builder,
+ const SimplifyQuery &Q) {
+ assert(CtpopLhs->getIntrinsicID() == Intrinsic::ctpop &&
+ "Non-ctpop intrin in ctpop fold");
+ if (!CtpopLhs->hasOneUse())
+ return nullptr;
+
+ // Power of 2 test:
+ // isPow2OrZero : ctpop(X) u< 2
+ // isPow2 : ctpop(X) == 1
+ // NotPow2OrZero: ctpop(X) u> 1
+ // NotPow2 : ctpop(X) != 1
+ // If we know any bit of X can be folded to:
+ // IsPow2 : X & (~Bit) == 0
+ // NotPow2 : X & (~Bit) != 0
+ const ICmpInst::Predicate Pred = I.getPredicate();
+ if (((I.isEquality() || Pred == ICmpInst::ICMP_UGT) && CRhs == 1) ||
+ (Pred == ICmpInst::ICMP_ULT && CRhs == 2)) {
+ Value *Op = CtpopLhs->getArgOperand(0);
+ KnownBits OpKnown = computeKnownBits(Op, Q.DL,
+ /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+ // No need to check for count > 1, that should be already constant folded.
+ if (OpKnown.countMinPopulation() == 1) {
+ Value *And = Builder.CreateAnd(
+ Op, Constant::getIntegerValue(Op->getType(), ~(OpKnown.One)));
+ return new ICmpInst(
+ (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_ULT)
+ ? ICmpInst::ICMP_EQ
+ : ICmpInst::ICMP_NE,
+ And, Constant::getNullValue(Op->getType()));
+ }
+ }
+
+ return nullptr;
+}
+
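Illustrative check of the known-bit power-of-two test, pretending bit 0 is the known-one bit (that choice, and the local popcount helper, are assumptions made only for this sketch):

    #include <cassert>
    #include <cstdint>

    static int popcount8(uint8_t X) {
      int N = 0;
      for (; X; X = static_cast<uint8_t>(X & (X - 1)))
        ++N;
      return N;
    }

    int main() {
      for (unsigned v = 0; v < 256; ++v) {
        uint8_t X = static_cast<uint8_t>(v | 1);             // bit 0 treated as known one
        bool IsPow2 = popcount8(X) == 1;                     // ctpop(X) == 1
        bool Folded = (X & static_cast<uint8_t>(~1u)) == 0;  // X & ~KnownOneBit == 0
        assert(IsPow2 == Folded);
      }
      return 0;
    }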
/// Fold an equality icmp with LLVM intrinsic and constant operand.
Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) {
@@ -3227,6 +3457,11 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
return new ICmpInst(Pred, II->getArgOperand(0),
ConstantInt::get(Ty, C.byteSwap()));
+ case Intrinsic::bitreverse:
+ // bitreverse(A) == C -> A == bitreverse(C)
+ return new ICmpInst(Pred, II->getArgOperand(0),
+ ConstantInt::get(Ty, C.reverseBits()));
+
case Intrinsic::ctlz:
case Intrinsic::cttz: {
// ctz(A) == bitwidth(A) -> A == 0 and likewise for !=
@@ -3277,15 +3512,22 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
}
break;
+ case Intrinsic::umax:
case Intrinsic::uadd_sat: {
// uadd.sat(a, b) == 0 -> (a | b) == 0
- if (C.isZero()) {
+ // umax(a, b) == 0 -> (a | b) == 0
+ if (C.isZero() && II->hasOneUse()) {
Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1));
return new ICmpInst(Pred, Or, Constant::getNullValue(Ty));
}
break;
}
+ case Intrinsic::ssub_sat:
+ // ssub.sat(a, b) == 0 -> a == b
+ if (C.isZero())
+ return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1));
+ break;
case Intrinsic::usub_sat: {
// usub.sat(a, b) == 0 -> a <= b
if (C.isZero()) {
@@ -3303,7 +3545,9 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
}
/// Fold an icmp with LLVM intrinsics
-static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) {
+static Instruction *
+foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp,
+ InstCombiner::BuilderTy &Builder) {
assert(Cmp.isEquality());
ICmpInst::Predicate Pred = Cmp.getPredicate();
@@ -3321,16 +3565,32 @@ static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) {
// original values.
return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0));
case Intrinsic::fshl:
- case Intrinsic::fshr:
+ case Intrinsic::fshr: {
// If both operands are rotated by same amount, just compare the
// original values.
if (IIOp0->getOperand(0) != IIOp0->getOperand(1))
break;
if (IIOp1->getOperand(0) != IIOp1->getOperand(1))
break;
- if (IIOp0->getOperand(2) != IIOp1->getOperand(2))
- break;
- return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0));
+ if (IIOp0->getOperand(2) == IIOp1->getOperand(2))
+ return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0));
+
+ // rotate(X, AmtX) == rotate(Y, AmtY)
+ // -> rotate(X, AmtX - AmtY) == Y
+ // Do this if either both rotates have one use or if only one has one use
+ // and AmtX/AmtY are constants.
+ unsigned OneUses = IIOp0->hasOneUse() + IIOp1->hasOneUse();
+ if (OneUses == 2 ||
+ (OneUses == 1 && match(IIOp0->getOperand(2), m_ImmConstant()) &&
+ match(IIOp1->getOperand(2), m_ImmConstant()))) {
+ Value *SubAmt =
+ Builder.CreateSub(IIOp0->getOperand(2), IIOp1->getOperand(2));
+ Value *CombinedRotate = Builder.CreateIntrinsic(
+ Op0->getType(), IIOp0->getIntrinsicID(),
+ {IIOp0->getOperand(0), IIOp0->getOperand(0), SubAmt});
+ return new ICmpInst(Pred, IIOp1->getOperand(0), CombinedRotate);
+ }
+ } break;
default:
break;
}
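Exhaustive i8 check of the rotate-difference identity used above; rotl8 stands in for fshl with equal operands and is written only for this sketch:

    #include <cassert>
    #include <cstdint>

    static uint8_t rotl8(uint8_t X, unsigned Amt) {
      Amt &= 7;
      return static_cast<uint8_t>((X << Amt) | (X >> ((8 - Amt) & 7)));
    }

    int main() {
      for (unsigned X = 0; X < 256; ++X)
        for (unsigned Y = 0; Y < 256; ++Y)
          for (unsigned AmtX = 0; AmtX < 8; ++AmtX)
            for (unsigned AmtY = 0; AmtY < 8; ++AmtY) {
              bool Original = rotl8((uint8_t)X, AmtX) == rotl8((uint8_t)Y, AmtY);
              bool Combined = rotl8((uint8_t)X, AmtX - AmtY) == (uint8_t)Y;
              assert(Original == Combined);
            }
      return 0;
    }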
@@ -3421,16 +3681,119 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
return foldICmpBinOpEqualityWithConstant(Cmp, BO, C);
}
+static Instruction *
+foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred,
+ SaturatingInst *II, const APInt &C,
+ InstCombiner::BuilderTy &Builder) {
+ // This transform may end up producing more than one instruction for the
+ // intrinsic, so limit it to one user of the intrinsic.
+ if (!II->hasOneUse())
+ return nullptr;
+
+ // Let Y = [add/sub]_sat(X, C) pred C2
+ // SatVal = The saturating value for the operation
+ // WillWrap = Whether or not the operation will underflow / overflow
+ // => Y = (WillWrap ? SatVal : (X binop C)) pred C2
+ // => Y = WillWrap ? (SatVal pred C2) : ((X binop C) pred C2)
+ //
+ // When (SatVal pred C2) is true, then
+ // Y = WillWrap ? true : ((X binop C) pred C2)
+ // => Y = WillWrap || ((X binop C) pred C2)
+ // else
+ // Y = WillWrap ? false : ((X binop C) pred C2)
+ // => Y = !WillWrap ? ((X binop C) pred C2) : false
+ // => Y = !WillWrap && ((X binop C) pred C2)
+ Value *Op0 = II->getOperand(0);
+ Value *Op1 = II->getOperand(1);
+
+ const APInt *COp1;
+ // This transform only works when the intrinsic has an integral constant or
+ // splat vector as the second operand.
+ if (!match(Op1, m_APInt(COp1)))
+ return nullptr;
+
+ APInt SatVal;
+ switch (II->getIntrinsicID()) {
+ default:
+ llvm_unreachable(
+ "This function only works with usub_sat and uadd_sat for now!");
+ case Intrinsic::uadd_sat:
+ SatVal = APInt::getAllOnes(C.getBitWidth());
+ break;
+ case Intrinsic::usub_sat:
+ SatVal = APInt::getZero(C.getBitWidth());
+ break;
+ }
+
+ // Check (SatVal pred C2)
+ bool SatValCheck = ICmpInst::compare(SatVal, C, Pred);
+
+ // !WillWrap.
+ ConstantRange C1 = ConstantRange::makeExactNoWrapRegion(
+ II->getBinaryOp(), *COp1, II->getNoWrapKind());
+
+ // WillWrap.
+ if (SatValCheck)
+ C1 = C1.inverse();
+
+ ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C);
+ if (II->getBinaryOp() == Instruction::Add)
+ C2 = C2.sub(*COp1);
+ else
+ C2 = C2.add(*COp1);
+
+ Instruction::BinaryOps CombiningOp =
+ SatValCheck ? Instruction::BinaryOps::Or : Instruction::BinaryOps::And;
+
+ std::optional<ConstantRange> Combination;
+ if (CombiningOp == Instruction::BinaryOps::Or)
+ Combination = C1.exactUnionWith(C2);
+ else /* CombiningOp == Instruction::BinaryOps::And */
+ Combination = C1.exactIntersectWith(C2);
+
+ if (!Combination)
+ return nullptr;
+
+ CmpInst::Predicate EquivPred;
+ APInt EquivInt;
+ APInt EquivOffset;
+
+ Combination->getEquivalentICmp(EquivPred, EquivInt, EquivOffset);
+
+ return new ICmpInst(
+ EquivPred,
+ Builder.CreateAdd(Op0, ConstantInt::get(Op1->getType(), EquivOffset)),
+ ConstantInt::get(Op1->getType(), EquivInt));
+}
+
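One concrete member of this fold family, checked by brute force; the constants 5 and 10 are arbitrary, and uadd_sat8 is a local stand-in for the intrinsic (a sketch, not the pass's actual output):

    #include <cassert>
    #include <cstdint>

    static uint8_t uadd_sat8(uint8_t A, uint8_t B) {
      unsigned Sum = unsigned(A) + unsigned(B);
      return Sum > 255 ? 255 : static_cast<uint8_t>(Sum);
    }

    int main() {
      // uadd.sat(x, 5) u< 10 is equivalent to a plain compare on x, here x u< 5.
      for (unsigned v = 0; v < 256; ++v) {
        uint8_t X = static_cast<uint8_t>(v);
        assert((uadd_sat8(X, 5) < 10) == (X < 5));
      }
      return 0;
    }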
/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
IntrinsicInst *II,
const APInt &C) {
+ ICmpInst::Predicate Pred = Cmp.getPredicate();
+
+ // Handle folds that apply for any kind of icmp.
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::uadd_sat:
+ case Intrinsic::usub_sat:
+ if (auto *Folded = foldICmpUSubSatOrUAddSatWithConstant(
+ Pred, cast<SaturatingInst>(II), C, Builder))
+ return Folded;
+ break;
+ case Intrinsic::ctpop: {
+ const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
+ if (Instruction *R = foldCtpopPow2Test(Cmp, II, C, Builder, Q))
+ return R;
+ } break;
+ }
+
if (Cmp.isEquality())
return foldICmpEqIntrinsicWithConstant(Cmp, II, C);
Type *Ty = II->getType();
unsigned BitWidth = C.getBitWidth();
- ICmpInst::Predicate Pred = Cmp.getPredicate();
switch (II->getIntrinsicID()) {
case Intrinsic::ctpop: {
// (ctpop X > BitWidth - 1) --> X == -1
@@ -3484,6 +3847,21 @@ Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
}
break;
}
+ case Intrinsic::ssub_sat:
+ // ssub.sat(a, b) spred 0 -> a spred b
+ if (ICmpInst::isSigned(Pred)) {
+ if (C.isZero())
+ return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1));
+ // X s<= 0 is canonicalized to X s< 1
+ if (Pred == ICmpInst::ICMP_SLT && C.isOne())
+ return new ICmpInst(ICmpInst::ICMP_SLE, II->getArgOperand(0),
+ II->getArgOperand(1));
+ // X s>= 0 is canonicalized to X s> -1
+ if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes())
+ return new ICmpInst(ICmpInst::ICMP_SGE, II->getArgOperand(0),
+ II->getArgOperand(1));
+ }
+ break;
default:
break;
}
@@ -4014,20 +4392,60 @@ Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
return Res;
}
-static Instruction *foldICmpXNegX(ICmpInst &I) {
+static Instruction *foldICmpXNegX(ICmpInst &I,
+ InstCombiner::BuilderTy &Builder) {
CmpInst::Predicate Pred;
Value *X;
- if (!match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X))))
- return nullptr;
+ if (match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) {
+
+ if (ICmpInst::isSigned(Pred))
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ else if (ICmpInst::isUnsigned(Pred))
+ Pred = ICmpInst::getSignedPredicate(Pred);
+ // else for equality-comparisons just keep the predicate.
+
+ return ICmpInst::Create(Instruction::ICmp, Pred, X,
+ Constant::getNullValue(X->getType()), I.getName());
+ }
+
+ // A value is not equal to its negation unless that value is 0 or
+ // MinSignedValue, i.e.: a != -a --> (a & MaxSignedVal) != 0
+ if (match(&I, m_c_ICmp(Pred, m_OneUse(m_Neg(m_Value(X))), m_Deferred(X))) &&
+ ICmpInst::isEquality(Pred)) {
+ Type *Ty = X->getType();
+ uint32_t BitWidth = Ty->getScalarSizeInBits();
+ Constant *MaxSignedVal =
+ ConstantInt::get(Ty, APInt::getSignedMaxValue(BitWidth));
+ Value *And = Builder.CreateAnd(X, MaxSignedVal);
+ Constant *Zero = Constant::getNullValue(Ty);
+ return CmpInst::Create(Instruction::ICmp, Pred, And, Zero);
+ }
+
+ return nullptr;
+}
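Sanity check that a value equals its wrapping negation only at 0 and the minimum signed value, which is the masked form used above (illustrative i8 sketch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int v = -128; v <= 127; ++v) {
        int8_t A = static_cast<int8_t>(v);
        int8_t NegA = static_cast<int8_t>(-v);  // wrapping negate
        bool NotEqual = A != NegA;
        bool Folded = (A & 0x7f) != 0;          // a & MaxSignedVal
        assert(NotEqual == Folded);             // only 0 and INT8_MIN equal their negation
      }
      return 0;
    }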
- if (ICmpInst::isSigned(Pred))
+static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
+ InstCombinerImpl &IC) {
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;
+ // Normalize xor operand as operand 0.
+ CmpInst::Predicate Pred = I.getPredicate();
+ if (match(Op1, m_c_Xor(m_Specific(Op0), m_Value()))) {
+ std::swap(Op0, Op1);
Pred = ICmpInst::getSwappedPredicate(Pred);
- else if (ICmpInst::isUnsigned(Pred))
- Pred = ICmpInst::getSignedPredicate(Pred);
- // else for equality-comparisons just keep the predicate.
+ }
+ if (!match(Op0, m_c_Xor(m_Specific(Op1), m_Value(A))))
+ return nullptr;
- return ICmpInst::Create(Instruction::ICmp, Pred, X,
- Constant::getNullValue(X->getType()), I.getName());
+ // icmp (X ^ Y_NonZero) u>= X --> icmp (X ^ Y_NonZero) u> X
+ // icmp (X ^ Y_NonZero) u<= X --> icmp (X ^ Y_NonZero) u< X
+ // icmp (X ^ Y_NonZero) s>= X --> icmp (X ^ Y_NonZero) s> X
+ // icmp (X ^ Y_NonZero) s<= X --> icmp (X ^ Y_NonZero) s< X
+ CmpInst::Predicate PredOut = CmpInst::getStrictPredicate(Pred);
+ if (PredOut != Pred &&
+ isKnownNonZero(A, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(PredOut, Op0, Op1);
+
+ return nullptr;
}
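Trivial but reassuring check that xor with a known-non-zero value always changes X, so the non-strict predicates collapse to strict ones (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned y = 1; y < 256; ++y) {                  // Y known non-zero
          uint8_t Xor = static_cast<uint8_t>(x ^ y);
          assert(Xor != (uint8_t)x);                          // xor never leaves X unchanged
          assert((Xor >= (uint8_t)x) == (Xor > (uint8_t)x));  // u>= collapses to u>
          assert((Xor <= (uint8_t)x) == (Xor < (uint8_t)x));  // u<= collapses to u<
        }
      return 0;
    }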
/// Try to fold icmp (binop), X or icmp X, (binop).
@@ -4045,7 +4463,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
if (!BO0 && !BO1)
return nullptr;
- if (Instruction *NewICmp = foldICmpXNegX(I))
+ if (Instruction *NewICmp = foldICmpXNegX(I, Builder))
return NewICmp;
const CmpInst::Predicate Pred = I.getPredicate();
@@ -4326,17 +4744,41 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
ConstantExpr::getNeg(RHSC));
}
+ if (Instruction * R = foldICmpXorXX(I, Q, *this))
+ return R;
+
{
- // Try to remove shared constant multiplier from equality comparison:
- // X * C == Y * C (with no overflowing/aliasing) --> X == Y
- Value *X, *Y;
- const APInt *C;
- if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 &&
- match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality())
- if (!C->countTrailingZeros() ||
- (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) ||
- (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
- return new ICmpInst(Pred, X, Y);
+ // Try to remove shared multiplier from comparison:
+ // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
+ Value *X, *Y, *Z;
+ if (Pred == ICmpInst::getUnsignedPredicate(Pred) &&
+ ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
+ match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
+ (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
+ match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) {
+ bool NonZero;
+ if (ICmpInst::isEquality(Pred)) {
+ KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+ // if Z % 2 != 0
+ // X * Z eq/ne Y * Z -> X eq/ne Y
+ if (ZKnown.countMaxTrailingZeros() == 0)
+ return new ICmpInst(Pred, X, Y);
+ NonZero = !ZKnown.One.isZero() ||
+ isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+ // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
+ // X * Z eq/ne Y * Z -> X eq/ne Y
+ if (NonZero && BO0 && BO1 && BO0->hasNoSignedWrap() &&
+ BO1->hasNoSignedWrap())
+ return new ICmpInst(Pred, X, Y);
+ } else
+ NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+
+ // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
+ // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
+ if (NonZero && BO0 && BO1 && BO0->hasNoUnsignedWrap() &&
+ BO1->hasNoUnsignedWrap())
+ return new ICmpInst(Pred, X, Y);
+ }
}
BinaryOperator *SRem = nullptr;
@@ -4405,7 +4847,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
!C->isOne()) {
// icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask)
// Mask = -1 >> count-trailing-zeros(C).
- if (unsigned TZs = C->countTrailingZeros()) {
+ if (unsigned TZs = C->countr_zero()) {
Constant *Mask = ConstantInt::get(
BO0->getType(),
APInt::getLowBitsSet(C->getBitWidth(), C->getBitWidth() - TZs));
@@ -4569,6 +5011,59 @@ static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) {
return nullptr;
}
+// Canonicalize checking for a power-of-2-or-zero value:
+static Instruction *foldICmpPow2Test(ICmpInst &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ const CmpInst::Predicate Pred = I.getPredicate();
+ Value *A = nullptr;
+ bool CheckIs;
+ if (I.isEquality()) {
+ // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
+ // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants)
+ if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()),
+ m_Deferred(A)))) ||
+ !match(Op1, m_ZeroInt()))
+ A = nullptr;
+
+ // (A & -A) == A --> ctpop(A) < 2 (four commuted variants)
+ // (-A & A) != A --> ctpop(A) > 1 (four commuted variants)
+ if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1)))))
+ A = Op1;
+ else if (match(Op1,
+ m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0)))))
+ A = Op0;
+
+ CheckIs = Pred == ICmpInst::ICMP_EQ;
+ } else if (ICmpInst::isUnsigned(Pred)) {
+ // (A ^ (A-1)) u>= A --> ctpop(A) < 2 (two commuted variants)
+ // ((A-1) ^ A) u< A --> ctpop(A) > 1 (two commuted variants)
+
+ if ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) &&
+ match(Op0, m_OneUse(m_c_Xor(m_Add(m_Specific(Op1), m_AllOnes()),
+ m_Specific(Op1))))) {
+ A = Op1;
+ CheckIs = Pred == ICmpInst::ICMP_UGE;
+ } else if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) &&
+ match(Op1, m_OneUse(m_c_Xor(m_Add(m_Specific(Op0), m_AllOnes()),
+ m_Specific(Op0))))) {
+ A = Op0;
+ CheckIs = Pred == ICmpInst::ICMP_ULE;
+ }
+ }
+
+ if (A) {
+ Type *Ty = A->getType();
+ CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A);
+ return CheckIs ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop,
+ ConstantInt::get(Ty, 2))
+ : new ICmpInst(ICmpInst::ICMP_UGT, CtPop,
+ ConstantInt::get(Ty, 1));
+ }
+
+ return nullptr;
+}
+
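Exhaustive i8 check of both power-of-two-or-zero patterns handled here; the local popcount helper exists only for this sketch:

    #include <cassert>
    #include <cstdint>

    static int popcount8(uint8_t X) {
      int N = 0;
      for (; X; X = static_cast<uint8_t>(X & (X - 1)))
        ++N;
      return N;
    }

    int main() {
      for (unsigned v = 0; v < 256; ++v) {
        uint8_t A = static_cast<uint8_t>(v);
        uint8_t AM1 = static_cast<uint8_t>(A - 1);       // wraps to 0xff for A == 0
        bool Pow2OrZero = popcount8(A) < 2;              // ctpop(A) u< 2
        assert(((A & AM1) == 0) == Pow2OrZero);          // (A & (A-1)) == 0
        assert(((uint8_t)(A ^ AM1) >= A) == Pow2OrZero); // (A ^ (A-1)) u>= A
      }
      return 0;
    }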
Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
if (!I.isEquality())
return nullptr;
@@ -4604,6 +5099,21 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
}
}
+ // canonicalize:
+ // (icmp eq/ne (and X, C), X)
+ // -> (icmp eq/ne (and X, ~C), 0)
+ {
+ Constant *CMask;
+ A = nullptr;
+ if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_ImmConstant(CMask)))))
+ A = Op1;
+ else if (match(Op1, m_OneUse(m_And(m_Specific(Op0), m_ImmConstant(CMask)))))
+ A = Op0;
+ if (A)
+ return new ICmpInst(Pred, Builder.CreateAnd(A, Builder.CreateNot(CMask)),
+ Constant::getNullValue(A->getType()));
+ }
+
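Quick check of the canonicalization above, i.e. (X & C) == X holds exactly when X has no bits outside C (illustrative sketch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned c = 0; c < 256; ++c) {
          bool Original = (x & c) == x;                           // (and X, C) == X
          bool Canonical = (x & static_cast<uint8_t>(~c)) == 0;   // (and X, ~C) == 0
          assert(Original == Canonical);
        }
      return 0;
    }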
if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) {
// A == (A^B) -> B == 0
Value *OtherVal = A == Op0 ? B : A;
@@ -4659,22 +5169,36 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
// (B & (Pow2C-1)) != zext A --> A != trunc B
const APInt *MaskC;
if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) &&
- MaskC->countTrailingOnes() == A->getType()->getScalarSizeInBits())
+ MaskC->countr_one() == A->getType()->getScalarSizeInBits())
return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
+ }
- // Test if 2 values have different or same signbits:
- // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
- // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
+ // Test if 2 values have different or same signbits:
+ // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
+ // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
+ // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0
+ // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1
+ Instruction *ExtI;
+ if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
unsigned OpWidth = Op0->getType()->getScalarSizeInBits();
+ Instruction *ShiftI;
Value *X, *Y;
ICmpInst::Predicate Pred2;
- if (match(Op0, m_LShr(m_Value(X), m_SpecificIntAllowUndef(OpWidth - 1))) &&
+ if (match(Op0, m_CombineAnd(m_Instruction(ShiftI),
+ m_Shr(m_Value(X),
+ m_SpecificIntAllowUndef(OpWidth - 1)))) &&
match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) &&
Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) {
- Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
- Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) :
- Builder.CreateIsNotNeg(Xor);
- return replaceInstUsesWith(I, R);
+ unsigned ExtOpc = ExtI->getOpcode();
+ unsigned ShiftOpc = ShiftI->getOpcode();
+ if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) ||
+ (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) {
+ Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
+ Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor)
+ : Builder.CreateIsNotNeg(Xor);
+ return replaceInstUsesWith(I, R);
+ }
}
}
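Brute-force check of the sign-bit comparison identity at i8 width (illustrative only, shown for the zext/lshr pairing):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int xv = -128; xv <= 127; ++xv)
        for (int yv = -128; yv <= 127; ++yv) {
          int8_t X = static_cast<int8_t>(xv), Y = static_cast<int8_t>(yv);
          uint8_t Lshr = static_cast<uint8_t>(static_cast<uint8_t>(X) >> 7); // X u>> BW-1
          uint8_t ZextCmp = (Y > -1) ? 1 : 0;                                // zext (Y s> -1)
          int8_t Xor = static_cast<int8_t>(X ^ Y);
          assert((Lshr == ZextCmp) == (Xor < 0));   // equal  <->  signs differ
          assert((Lshr != ZextCmp) == (Xor > -1));  // differ <->  signs agree
        }
      return 0;
    }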
@@ -4737,33 +5261,9 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
}
}
- if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I))
+ if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I, Builder))
return ICmp;
- // Canonicalize checking for a power-of-2-or-zero value:
- // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
- // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants)
- if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()),
- m_Deferred(A)))) ||
- !match(Op1, m_ZeroInt()))
- A = nullptr;
-
- // (A & -A) == A --> ctpop(A) < 2 (four commuted variants)
- // (-A & A) != A --> ctpop(A) > 1 (four commuted variants)
- if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1)))))
- A = Op1;
- else if (match(Op1,
- m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0)))))
- A = Op0;
-
- if (A) {
- Type *Ty = A->getType();
- CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A);
- return Pred == ICmpInst::ICMP_EQ
- ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2))
- : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
- }
-
// Match icmp eq (trunc (lshr A, BW), (ashr (trunc A), BW-1)), which checks the
// top BW/2 + 1 bits are all the same. Create "A >=s INT_MIN && A <=s INT_MAX",
// which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps
@@ -4794,11 +5294,23 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
return new ICmpInst(CmpInst::getInversePredicate(Pred), Op1,
ConstantInt::getNullValue(Op1->getType()));
+ // Canonicalize:
+ // icmp eq/ne X, OneUse(rotate-right(X))
+ // -> icmp eq/ne X, rotate-left(X)
+ // We generally try to convert rotate-right -> rotate-left, this just
+ // canonicalizes another case.
+ CmpInst::Predicate PredUnused = Pred;
+ if (match(&I, m_c_ICmp(PredUnused, m_Value(A),
+ m_OneUse(m_Intrinsic<Intrinsic::fshr>(
+ m_Deferred(A), m_Deferred(A), m_Value(B))))))
+ return new ICmpInst(
+ Pred, A,
+ Builder.CreateIntrinsic(Op0->getType(), Intrinsic::fshl, {A, A, B}));
+
return nullptr;
}
-static Instruction *foldICmpWithTrunc(ICmpInst &ICmp,
- InstCombiner::BuilderTy &Builder) {
+Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
ICmpInst::Predicate Pred = ICmp.getPredicate();
Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1);
@@ -4836,6 +5348,25 @@ static Instruction *foldICmpWithTrunc(ICmpInst &ICmp,
return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
}
+ if (auto *II = dyn_cast<IntrinsicInst>(X)) {
+ if (II->getIntrinsicID() == Intrinsic::cttz ||
+ II->getIntrinsicID() == Intrinsic::ctlz) {
+ unsigned MaxRet = SrcBits;
+ // If the "is_zero_poison" argument is set, then we know at least
+ // one bit is set in the input, so the result is always at least one
+ // less than the full bitwidth of that input.
+ if (match(II->getArgOperand(1), m_One()))
+ MaxRet--;
+
+ // Make sure the destination is wide enough to hold the largest output of
+ // the intrinsic.
+ if (llvm::Log2_32(MaxRet) + 1 <= Op0->getType()->getScalarSizeInBits())
+ if (Instruction *I =
+ foldICmpIntrinsicWithConstant(ICmp, II, C->zext(SrcBits)))
+ return I;
+ }
+ }
+
return nullptr;
}
@@ -4855,10 +5386,19 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0));
bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1));
- // If we have mismatched casts, treat the zext of a non-negative source as
- // a sext to simulate matching casts. Otherwise, we are done.
- // TODO: Can we handle some predicates (equality) without non-negative?
if (IsZext0 != IsZext1) {
+ // If X and Y are both i1
+ // (icmp eq/ne (zext X) (sext Y))
+ // eq -> (icmp eq (or X, Y), 0)
+ // ne -> (icmp ne (or X, Y), 0)
+ if (ICmp.isEquality() && X->getType()->isIntOrIntVectorTy(1) &&
+ Y->getType()->isIntOrIntVectorTy(1))
+ return new ICmpInst(ICmp.getPredicate(), Builder.CreateOr(X, Y),
+ Constant::getNullValue(X->getType()));
+
+ // If we have mismatched casts, treat the zext of a non-negative source as
+ // a sext to simulate matching casts. Otherwise, we are done.
+ // TODO: Can we handle some predicates (equality) without non-negative?
if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) ||
(IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT)))
IsSignedExt = true;
@@ -4993,7 +5533,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
}
- if (Instruction *R = foldICmpWithTrunc(ICmp, Builder))
+ if (Instruction *R = foldICmpWithTrunc(ICmp))
return R;
return foldICmpWithZextOrSext(ICmp);
@@ -5153,7 +5693,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
return nullptr;
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
const APInt &CVal = CI->getValue();
- if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth)
+ if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
return nullptr;
} else {
// In this case we could have the operand of the binary operation
@@ -5334,44 +5874,18 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
// bits doesn't impact the outcome of the comparison, because any value
// greater than the RHS must differ in a bit higher than these due to carry.
case ICmpInst::ICMP_UGT:
- return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes());
+ return APInt::getBitsSetFrom(BitWidth, RHS->countr_one());
// Similarly, for a ULT comparison, we don't care about the trailing zeros.
// Any value less than the RHS must differ in a higher bit because of carries.
case ICmpInst::ICMP_ULT:
- return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros());
+ return APInt::getBitsSetFrom(BitWidth, RHS->countr_zero());
default:
return APInt::getAllOnes(BitWidth);
}
}
-/// Check if the order of \p Op0 and \p Op1 as operands in an ICmpInst
-/// should be swapped.
-/// The decision is based on how many times these two operands are reused
-/// as subtract operands and their positions in those instructions.
-/// The rationale is that several architectures use the same instruction for
-/// both subtract and cmp. Thus, it is better if the order of those operands
-/// match.
-/// \return true if Op0 and Op1 should be swapped.
-static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) {
- // Filter out pointer values as those cannot appear directly in subtract.
- // FIXME: we may want to go through inttoptrs or bitcasts.
- if (Op0->getType()->isPointerTy())
- return false;
- // If a subtract already has the same operands as a compare, swapping would be
- // bad. If a subtract has the same operands as a compare but in reverse order,
- // then swapping is good.
- int GoodToSwap = 0;
- for (const User *U : Op0->users()) {
- if (match(U, m_Sub(m_Specific(Op1), m_Specific(Op0))))
- GoodToSwap++;
- else if (match(U, m_Sub(m_Specific(Op0), m_Specific(Op1))))
- GoodToSwap--;
- }
- return GoodToSwap > 0;
-}
-
/// Check that one use is in the same block as the definition and all
/// other uses are in blocks dominated by a given block.
///
@@ -5638,14 +6152,14 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
const APInt *C1;
if (match(LHS, m_Shl(m_Power2(C1), m_Value(X)))) {
Type *XTy = X->getType();
- unsigned Log2C1 = C1->countTrailingZeros();
+ unsigned Log2C1 = C1->countr_zero();
APInt C2 = Op0KnownZeroInverted;
APInt C2Pow2 = (C2 & ~(*C1 - 1)) + *C1;
if (C2Pow2.isPowerOf2()) {
// iff (C1 is pow2) & ((C2 & ~(C1-1)) + C1) is pow2):
// ((C1 << X) & C2) == 0 -> X >= (Log2(C2+C1) - Log2(C1))
// ((C1 << X) & C2) != 0 -> X < (Log2(C2+C1) - Log2(C1))
- unsigned Log2C2 = C2Pow2.countTrailingZeros();
+ unsigned Log2C2 = C2Pow2.countr_zero();
auto *CmpC = ConstantInt::get(XTy, Log2C2 - Log2C1);
auto NewPred =
Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT;
@@ -5653,6 +6167,12 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
}
}
}
+
+ // Op0 eq C_Pow2 -> Op0 ne 0 if Op0 is known to be C_Pow2 or zero.
+ if (Op1Known.isConstant() && Op1Known.getConstant().isPowerOf2() &&
+ (Op0Known & Op1Known) == Op0Known)
+ return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0,
+ ConstantInt::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_ULT: {
@@ -5733,8 +6253,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
/// If one operand of an icmp is effectively a bool (value range of {0,1}),
/// then try to reduce patterns based on that limit.
-static Instruction *foldICmpUsingBoolRange(ICmpInst &I,
- InstCombiner::BuilderTy &Builder) {
+Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
Value *X, *Y;
ICmpInst::Predicate Pred;
@@ -5750,6 +6269,60 @@ static Instruction *foldICmpUsingBoolRange(ICmpInst &I,
Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE)
return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);
+ const APInt *C;
+ if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+ match(I.getOperand(1), m_APInt(C)) &&
+ X->getType()->isIntOrIntVectorTy(1) &&
+ Y->getType()->isIntOrIntVectorTy(1)) {
+ unsigned BitWidth = C->getBitWidth();
+ Pred = I.getPredicate();
+ APInt Zero = APInt::getZero(BitWidth);
+ APInt MinusOne = APInt::getAllOnes(BitWidth);
+ APInt One(BitWidth, 1);
+ if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) ||
+ (C->slt(Zero) && Pred == ICmpInst::ICMP_SLT))
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) ||
+ (C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT))
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+
+ if (I.getOperand(0)->hasOneUse()) {
+ APInt NewC = *C;
+ // canonicalize predicate to eq/ne
+ if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) ||
+ (*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) {
+ // x s< 0 in [-1, 1] --> x == -1
+ // x u> 1 (or any const != 0, != -1) in [-1, 1] --> x == -1
+ NewC = MinusOne;
+ Pred = ICmpInst::ICMP_EQ;
+ } else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) ||
+ (*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) {
+ // x s> -1 in [-1, 1] --> x != -1
+ // x u< -1 in [-1, 1] --> x != -1
+ Pred = ICmpInst::ICMP_NE;
+ } else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) {
+ // x s> 0 in [-1, 1] --> x == 1
+ NewC = One;
+ Pred = ICmpInst::ICMP_EQ;
+ } else if (*C == One && Pred == ICmpInst::ICMP_SLT) {
+ // x s< 1 in [-1, 1] --> x != 1
+ Pred = ICmpInst::ICMP_NE;
+ }
+
+ if (NewC == MinusOne) {
+ if (Pred == ICmpInst::ICMP_EQ)
+ return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y);
+ if (Pred == ICmpInst::ICMP_NE)
+ return BinaryOperator::CreateOr(X, Builder.CreateNot(Y));
+ } else if (NewC == One) {
+ if (Pred == ICmpInst::ICMP_EQ)
+ return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y));
+ if (Pred == ICmpInst::ICMP_NE)
+ return BinaryOperator::CreateOr(Builder.CreateNot(X), Y);
+ }
+ }
+ }
+
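Exhaustive check over the four boolean combinations of the [-1, 1] facts used above; the zext/sext sum is modeled with ordinary integer arithmetic (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int X = 0; X <= 1; ++X)
        for (int Y = 0; Y <= 1; ++Y) {
          int8_t Sum = static_cast<int8_t>(X + (Y ? -1 : 0)); // zext(X) + sext(Y), in [-1, 1]
          assert(Sum >= -1 && Sum <= 1);
          assert((Sum == -1) == (!X && Y));  // x == -1  -->  and (not X), Y
          assert((Sum == 1) == (X && !Y));   // x == 1   -->  and X, (not Y)
          assert((Sum < 0) == (Sum == -1));  // x s< 0 canonicalizes to x == -1
          assert((Sum > 0) == (Sum == 1));   // x s> 0 canonicalizes to x == 1
        }
      return 0;
    }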
return nullptr;
}
@@ -6162,8 +6735,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
/// Orders the operands of the compare so that they are listed from most
/// complex to least complex. This puts constants before unary operators,
/// before binary operators.
- if (Op0Cplxity < Op1Cplxity ||
- (Op0Cplxity == Op1Cplxity && swapMayExposeCSEOpportunities(Op0, Op1))) {
+ if (Op0Cplxity < Op1Cplxity) {
I.swapOperands();
std::swap(Op0, Op1);
Changed = true;
@@ -6205,7 +6777,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpWithDominatingICmp(I))
return Res;
- if (Instruction *Res = foldICmpUsingBoolRange(I, Builder))
+ if (Instruction *Res = foldICmpUsingBoolRange(I))
return Res;
if (Instruction *Res = foldICmpUsingKnownBits(I))
@@ -6288,15 +6860,46 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I))
return NI;
+ // In case of a comparison with two select instructions having the same
+ // condition, check whether one of the resulting branches can be simplified.
+ // If so, just compare the other branch and select the appropriate result.
+ // For example:
+ // %tmp1 = select i1 %cmp, i32 %y, i32 %x
+ // %tmp2 = select i1 %cmp, i32 %z, i32 %x
+ // %cmp2 = icmp slt i32 %tmp2, %tmp1
+ // When %cmp is false, both selects yield %x, so the icmp is trivially false;
+ // when %cmp is true, the result depends only on the comparison of the true
+ // values. Thus, transform this into:
+ // %cmp2 = icmp slt i32 %z, %y
+ // %sel = select i1 %cmp, i1 %cmp2, i1 false
+ // The symmetric case (simplifiable true values) is handled analogously.
+ {
+ Value *Cond, *A, *B, *C, *D;
+ if (match(Op0, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) &&
+ match(Op1, m_Select(m_Specific(Cond), m_Value(C), m_Value(D))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ // Check whether comparison of TrueValues can be simplified
+ if (Value *Res = simplifyICmpInst(Pred, A, C, SQ)) {
+ Value *NewICMP = Builder.CreateICmp(Pred, B, D);
+ return SelectInst::Create(Cond, Res, NewICMP);
+ }
+ // Check whether comparison of FalseValues can be simplified
+ if (Value *Res = simplifyICmpInst(Pred, B, D, SQ)) {
+ Value *NewICMP = Builder.CreateICmp(Pred, A, C);
+ return SelectInst::Create(Cond, NewICMP, Res);
+ }
+ }
+ }
+
// Try to optimize equality comparisons against alloca-based pointers.
if (Op0->getType()->isPointerTy() && I.isEquality()) {
assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0)))
- if (Instruction *New = foldAllocaCmp(I, Alloca))
- return New;
+ if (foldAllocaCmp(Alloca))
+ return nullptr;
if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1)))
- if (Instruction *New = foldAllocaCmp(I, Alloca))
- return New;
+ if (foldAllocaCmp(Alloca))
+ return nullptr;
}
if (Instruction *Res = foldICmpBitCast(I))
@@ -6363,6 +6966,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpEquality(I))
return Res;
+ if (Instruction *Res = foldICmpPow2Test(I, Builder))
+ return Res;
+
if (Instruction *Res = foldICmpOfUAddOv(I))
return Res;
@@ -6717,7 +7323,7 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
Mode.Input == DenormalMode::PositiveZero) {
auto replaceFCmp = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
- Constant *Zero = ConstantFP::getNullValue(X->getType());
+ Constant *Zero = ConstantFP::getZero(X->getType());
return new FCmpInst(P, X, Zero, "", I);
};
@@ -6813,7 +7419,7 @@ static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
// Replace the negated operand with 0.0:
// fcmp Pred Op0, -Op0 --> fcmp Pred Op0, 0.0
- Constant *Zero = ConstantFP::getNullValue(Op0->getType());
+ Constant *Zero = ConstantFP::getZero(Op0->getType());
return new FCmpInst(Pred, Op0, Zero, "", &I);
}
@@ -6863,11 +7469,13 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
// If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand,
// then canonicalize the operand to 0.0.
if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
- if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI))
- return replaceOperand(I, 0, ConstantFP::getNullValue(OpType));
+ if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, DL, &TLI, 0,
+ &AC, &I, &DT))
+ return replaceOperand(I, 0, ConstantFP::getZero(OpType));
- if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI))
- return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
+ if (!match(Op1, m_PosZeroFP()) &&
+ isKnownNeverNaN(Op1, DL, &TLI, 0, &AC, &I, &DT))
+ return replaceOperand(I, 1, ConstantFP::getZero(OpType));
}
// fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
@@ -6896,7 +7504,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
// The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0:
// fcmp Pred X, -0.0 --> fcmp Pred X, 0.0
if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP()))
- return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
+ return replaceOperand(I, 1, ConstantFP::getZero(OpType));
// Ignore signbit of bitcasted int when comparing equality to FP 0.0:
// fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
@@ -6985,11 +7593,11 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
case FCmpInst::FCMP_ONE:
// X is ordered and not equal to an impossible constant --> ordered
return new FCmpInst(FCmpInst::FCMP_ORD, X,
- ConstantFP::getNullValue(X->getType()));
+ ConstantFP::getZero(X->getType()));
case FCmpInst::FCMP_UEQ:
// X is unordered or equal to an impossible constant --> unordered
return new FCmpInst(FCmpInst::FCMP_UNO, X,
- ConstantFP::getNullValue(X->getType()));
+ ConstantFP::getZero(X->getType()));
case FCmpInst::FCMP_UNE:
// X is unordered or not equal to an impossible constant --> true
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f4e88b122383..701579e1de48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -150,7 +150,6 @@ public:
Instruction *visitPHINode(PHINode &PN);
Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP);
Instruction *visitGEPOfGEP(GetElementPtrInst &GEP, GEPOperator *Src);
- Instruction *visitGEPOfBitcast(BitCastInst *BCI, GetElementPtrInst &GEP);
Instruction *visitAllocaInst(AllocaInst &AI);
Instruction *visitAllocSite(Instruction &FI);
Instruction *visitFree(CallInst &FI, Value *FreedOp);
@@ -330,8 +329,7 @@ private:
Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
Instruction *matchSAddSubSat(IntrinsicInst &MinMax1);
Instruction *foldNot(BinaryOperator &I);
-
- void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr);
+ Instruction *foldBinOpOfDisplacedShifts(BinaryOperator &I);
/// Determine if a pair of casts can be replaced by a single cast.
///
@@ -378,6 +376,7 @@ private:
Instruction *foldLShrOverflowBit(BinaryOperator &I);
Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV);
Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II);
+ Instruction *foldIntrinsicIsFPClass(IntrinsicInst &II);
Instruction *foldFPSignBitOps(BinaryOperator &I);
Instruction *foldFDivConstantDivisor(BinaryOperator &I);
@@ -393,12 +392,12 @@ public:
/// without having to rewrite the CFG from within InstCombine.
void CreateNonTerminatorUnreachable(Instruction *InsertAt) {
auto &Ctx = InsertAt->getContext();
- new StoreInst(ConstantInt::getTrue(Ctx),
- PoisonValue::get(Type::getInt1PtrTy(Ctx)),
- InsertAt);
+ auto *SI = new StoreInst(ConstantInt::getTrue(Ctx),
+ PoisonValue::get(Type::getInt1PtrTy(Ctx)),
+ /*isVolatile*/ false, Align(1));
+ InsertNewInstBefore(SI, *InsertAt);
}
-
/// Combiner aware instruction erasure.
///
/// When dealing with an instruction that has side effects or produces a void
@@ -411,12 +410,11 @@ public:
// Make sure that we reprocess all operands now that we reduced their
// use counts.
- for (Use &Operand : I.operands())
- if (auto *Inst = dyn_cast<Instruction>(Operand))
- Worklist.add(Inst);
-
+ SmallVector<Value *> Ops(I.operands());
Worklist.remove(&I);
I.eraseFromParent();
+ for (Value *Op : Ops)
+ Worklist.handleUseCountDecrement(Op);
MadeIRChange = true;
return nullptr; // Don't do anything with FI
}
@@ -450,6 +448,18 @@ public:
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
Value *RHS);
+ // (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
+ // -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C)
+ // (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt))
+ // -> (BinOp (logic_shift (BinOp X, Y)), Mask)
+ Instruction *foldBinOpShiftWithShift(BinaryOperator &I);
+
+ /// Tries to simplify binops of select and cast of the select condition.
+ ///
+ /// (Binop (cast C), (select C, T, F))
+ /// -> (select C, C0, C1)
+ Instruction *foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I);
+
/// This tries to simplify binary operations by factorizing out common terms
/// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
Value *tryFactorizationFolds(BinaryOperator &I);
@@ -549,7 +559,7 @@ public:
ICmpInst::Predicate Cond, Instruction &I);
Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI,
Value *RHS, const ICmpInst &I);
- Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca);
+ bool foldAllocaCmp(AllocaInst *Alloca);
Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
GetElementPtrInst *GEP,
GlobalVariable *GV, CmpInst &ICI,
@@ -564,6 +574,7 @@ public:
Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
Instruction *foldICmpWithConstant(ICmpInst &Cmp);
+ Instruction *foldICmpUsingBoolRange(ICmpInst &I);
Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
@@ -623,6 +634,7 @@ public:
Instruction *foldICmpEqIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II,
const APInt &C);
Instruction *foldICmpBitCast(ICmpInst &Cmp);
+ Instruction *foldICmpWithTrunc(ICmpInst &Cmp);
// Helpers of visitSelectInst().
Instruction *foldSelectOfBools(SelectInst &SI);
@@ -634,10 +646,11 @@ public:
SelectPatternFlavor SPF2, Value *C);
Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI);
+ bool replaceInInstruction(Value *V, Value *Old, Value *New,
+ unsigned Depth = 0);
Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
bool isSigned, bool Inside);
- Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI);
bool mergeStoreIntoSuccessor(StoreInst &SI);
/// Given an initial instruction, check to see if it is the root of a
@@ -651,10 +664,12 @@ public:
Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned);
- /// Returns a value X such that Val = X * Scale, or null if none.
- ///
- /// If the multiplication is known not to overflow then NoSignedWrap is set.
- Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap);
+ bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock);
+
+ bool removeInstructionsBeforeUnreachable(Instruction &I);
+ bool handleUnreachableFrom(Instruction *I);
+ bool handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
+ void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr);
};
class Negator final {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 41bc65620ff6..6aa20ee26b9a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -32,7 +32,7 @@ STATISTIC(NumDeadStore, "Number of dead stores eliminated");
STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global");
static cl::opt<unsigned> MaxCopiedFromConstantUsers(
- "instcombine-max-copied-from-constant-users", cl::init(128),
+ "instcombine-max-copied-from-constant-users", cl::init(300),
cl::desc("Maximum users to visit in copy from constant transform"),
cl::Hidden);
@@ -219,7 +219,7 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC,
// Now that I is pointing to the first non-allocation-inst in the block,
// insert our getelementptr instruction...
//
- Type *IdxTy = IC.getDataLayout().getIntPtrType(AI.getType());
+ Type *IdxTy = IC.getDataLayout().getIndexType(AI.getType());
Value *NullIdx = Constant::getNullValue(IdxTy);
Value *Idx[2] = {NullIdx, NullIdx};
Instruction *GEP = GetElementPtrInst::CreateInBounds(
@@ -235,11 +235,12 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC,
if (isa<UndefValue>(AI.getArraySize()))
return IC.replaceInstUsesWith(AI, Constant::getNullValue(AI.getType()));
- // Ensure that the alloca array size argument has type intptr_t, so that
- // any casting is exposed early.
- Type *IntPtrTy = IC.getDataLayout().getIntPtrType(AI.getType());
- if (AI.getArraySize()->getType() != IntPtrTy) {
- Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), IntPtrTy, false);
+ // Ensure that the alloca array size argument has type equal to the offset
+ // size of the alloca() pointer, which, in the typical case, is intptr_t,
+ // so that any casting is exposed early.
+ Type *PtrIdxTy = IC.getDataLayout().getIndexType(AI.getType());
+ if (AI.getArraySize()->getType() != PtrIdxTy) {
+ Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), PtrIdxTy, false);
return IC.replaceOperand(AI, 0, V);
}
@@ -259,8 +260,8 @@ namespace {
// instruction.
class PointerReplacer {
public:
- PointerReplacer(InstCombinerImpl &IC, Instruction &Root)
- : IC(IC), Root(Root) {}
+ PointerReplacer(InstCombinerImpl &IC, Instruction &Root, unsigned SrcAS)
+ : IC(IC), Root(Root), FromAS(SrcAS) {}
bool collectUsers();
void replacePointer(Value *V);
@@ -273,11 +274,21 @@ private:
return I == &Root || Worklist.contains(I);
}
+ bool isEqualOrValidAddrSpaceCast(const Instruction *I,
+ unsigned FromAS) const {
+ const auto *ASC = dyn_cast<AddrSpaceCastInst>(I);
+ if (!ASC)
+ return false;
+ unsigned ToAS = ASC->getDestAddressSpace();
+ return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS);
+ }
+
SmallPtrSet<Instruction *, 32> ValuesToRevisit;
SmallSetVector<Instruction *, 4> Worklist;
MapVector<Value *, Value *> WorkMap;
InstCombinerImpl &IC;
Instruction &Root;
+ unsigned FromAS;
};
} // end anonymous namespace
@@ -341,6 +352,8 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) {
if (MI->isVolatile())
return false;
Worklist.insert(Inst);
+ } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) {
+ Worklist.insert(Inst);
} else if (Inst->isLifetimeStartOrEnd()) {
continue;
} else {
@@ -391,9 +404,8 @@ void PointerReplacer::replace(Instruction *I) {
} else if (auto *BC = dyn_cast<BitCastInst>(I)) {
auto *V = getReplacement(BC->getOperand(0));
assert(V && "Operand not replaced");
- auto *NewT = PointerType::getWithSamePointeeType(
- cast<PointerType>(BC->getType()),
- V->getType()->getPointerAddressSpace());
+ auto *NewT = PointerType::get(BC->getType()->getContext(),
+ V->getType()->getPointerAddressSpace());
auto *NewI = new BitCastInst(V, NewT);
IC.InsertNewInstWith(NewI, *BC);
NewI->takeName(BC);
@@ -426,6 +438,22 @@ void PointerReplacer::replace(Instruction *I) {
IC.eraseInstFromFunction(*MemCpy);
WorkMap[MemCpy] = NewI;
+ } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) {
+ auto *V = getReplacement(ASC->getPointerOperand());
+ assert(V && "Operand not replaced");
+ assert(isEqualOrValidAddrSpaceCast(
+ ASC, V->getType()->getPointerAddressSpace()) &&
+ "Invalid address space cast!");
+ auto *NewV = V;
+ if (V->getType()->getPointerAddressSpace() !=
+ ASC->getType()->getPointerAddressSpace()) {
+ auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), "");
+ NewI->takeName(ASC);
+ IC.InsertNewInstWith(NewI, *ASC);
+ NewV = NewI;
+ }
+ IC.replaceInstUsesWith(*ASC, NewV);
+ IC.eraseInstFromFunction(*ASC);
} else {
llvm_unreachable("should never reach here");
}
@@ -435,7 +463,7 @@ void PointerReplacer::replacePointer(Value *V) {
#ifndef NDEBUG
auto *PT = cast<PointerType>(Root.getType());
auto *NT = cast<PointerType>(V->getType());
- assert(PT != NT && PT->hasSameElementTypeAs(NT) && "Invalid usage");
+ assert(PT != NT && "Invalid usage");
#endif
WorkMap[&Root] = V;
@@ -518,7 +546,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
return NewI;
}
- PointerReplacer PtrReplacer(*this, AI);
+ PointerReplacer PtrReplacer(*this, AI, SrcAddrSpace);
if (PtrReplacer.collectUsers()) {
for (Instruction *Delete : ToDelete)
eraseInstFromFunction(*Delete);
@@ -739,6 +767,11 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
// the knowledge that padding exists for the rest of the pipeline.
const DataLayout &DL = IC.getDataLayout();
auto *SL = DL.getStructLayout(ST);
+
+ // Don't unpack a structure containing a scalable vector.
+ if (SL->getSizeInBits().isScalable())
+ return nullptr;
+
if (SL->hasPadding())
return nullptr;
@@ -979,17 +1012,15 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC,
// If we're indexing into an object with a variable index for the memory
// access, but the object has only one element, we can assume that the index
// will always be zero. If we replace the GEP, return it.
-template <typename T>
static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr,
- T &MemI) {
+ Instruction &MemI) {
if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) {
unsigned Idx;
if (canReplaceGEPIdxWithZero(IC, GEPI, &MemI, Idx)) {
Instruction *NewGEPI = GEPI->clone();
NewGEPI->setOperand(Idx,
ConstantInt::get(GEPI->getOperand(Idx)->getType(), 0));
- NewGEPI->insertBefore(GEPI);
- MemI.setOperand(MemI.getPointerOperandIndex(), NewGEPI);
+ IC.InsertNewInstBefore(NewGEPI, *GEPI);
return NewGEPI;
}
}
@@ -1024,6 +1055,8 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
Value *Op = LI.getOperand(0);
+ if (Value *Res = simplifyLoadInst(&LI, Op, SQ.getWithInstruction(&LI)))
+ return replaceInstUsesWith(LI, Res);
// Try to canonicalize the loaded type.
if (Instruction *Res = combineLoadToOperationType(*this, LI))
@@ -1036,10 +1069,8 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
LI.setAlignment(KnownAlign);
// Replace GEP indices if possible.
- if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) {
- Worklist.push(NewGEPI);
- return &LI;
- }
+ if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI))
+ return replaceOperand(LI, 0, NewGEPI);
if (Instruction *Res = unpackLoadToAggregate(*this, LI))
return Res;
@@ -1065,13 +1096,7 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
// load null/undef -> unreachable
// TODO: Consider a target hook for valid address spaces for this xforms.
if (canSimplifyNullLoadOrGEP(LI, Op)) {
- // Insert a new store to null instruction before the load to indicate
- // that this code is not reachable. We do this instead of inserting
- // an unreachable instruction directly because we cannot modify the
- // CFG.
- StoreInst *SI = new StoreInst(PoisonValue::get(LI.getType()),
- Constant::getNullValue(Op->getType()), &LI);
- SI->setDebugLoc(LI.getDebugLoc());
+ CreateNonTerminatorUnreachable(&LI);
return replaceInstUsesWith(LI, PoisonValue::get(LI.getType()));
}
@@ -1261,6 +1286,11 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) {
// the knowledge that padding exists for the rest of the pipeline.
const DataLayout &DL = IC.getDataLayout();
auto *SL = DL.getStructLayout(ST);
+
+ // Don't unpack a structure containing a scalable vector.
+ if (SL->getSizeInBits().isScalable())
+ return false;
+
if (SL->hasPadding())
return false;
@@ -1443,10 +1473,8 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
return eraseInstFromFunction(SI);
// Replace GEP indices if possible.
- if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) {
- Worklist.push(NewGEPI);
- return &SI;
- }
+ if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI))
+ return replaceOperand(SI, 1, NewGEPI);
// Don't hack volatile/ordered stores.
// FIXME: Some bits are legal for ordered atomic stores; needs refactoring.
@@ -1530,6 +1558,16 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
return nullptr; // Do not modify these!
}
+ // This is a non-terminator unreachable marker. Don't remove it.
+ if (isa<UndefValue>(Ptr)) {
+ // Remove all instructions after the marker and guaranteed-to-transfer
+ // instructions before the marker.
+ if (handleUnreachableFrom(SI.getNextNode()) ||
+ removeInstructionsBeforeUnreachable(SI))
+ return &SI;
+ return nullptr;
+ }
+
// store undef, Ptr -> noop
// FIXME: This is technically incorrect because it might overwrite a poison
// value. Change to PoisonValue once #52930 is resolved.
@@ -1571,6 +1609,17 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
if (!OtherBr || BBI == OtherBB->begin())
return false;
+ auto OtherStoreIsMergeable = [&](StoreInst *OtherStore) -> bool {
+ if (!OtherStore ||
+ OtherStore->getPointerOperand() != SI.getPointerOperand())
+ return false;
+
+ auto *SIVTy = SI.getValueOperand()->getType();
+ auto *OSVTy = OtherStore->getValueOperand()->getType();
+ return CastInst::isBitOrNoopPointerCastable(OSVTy, SIVTy, DL) &&
+ SI.hasSameSpecialState(OtherStore);
+ };
+
// If the other block ends in an unconditional branch, check for the 'if then
// else' case. There is an instruction before the branch.
StoreInst *OtherStore = nullptr;
@@ -1586,8 +1635,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
// If this isn't a store, isn't a store to the same location, or is not the
// right kind of store, bail out.
OtherStore = dyn_cast<StoreInst>(BBI);
- if (!OtherStore || OtherStore->getOperand(1) != SI.getOperand(1) ||
- !SI.isSameOperationAs(OtherStore))
+ if (!OtherStoreIsMergeable(OtherStore))
return false;
} else {
// Otherwise, the other block ended with a conditional branch. If one of the
@@ -1601,12 +1649,10 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
// lives in OtherBB.
for (;; --BBI) {
// Check to see if we find the matching store.
- if ((OtherStore = dyn_cast<StoreInst>(BBI))) {
- if (OtherStore->getOperand(1) != SI.getOperand(1) ||
- !SI.isSameOperationAs(OtherStore))
- return false;
+ OtherStore = dyn_cast<StoreInst>(BBI);
+ if (OtherStoreIsMergeable(OtherStore))
break;
- }
+
// If we find something that may be using or overwriting the stored
// value, or if we run out of instructions, we can't do the transform.
if (BBI->mayReadFromMemory() || BBI->mayThrow() ||
@@ -1624,14 +1670,17 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
}
// Insert a PHI node now if we need it.
- Value *MergedVal = OtherStore->getOperand(0);
+ Value *MergedVal = OtherStore->getValueOperand();
// The debug locations of the original instructions might differ. Merge them.
DebugLoc MergedLoc = DILocation::getMergedLocation(SI.getDebugLoc(),
OtherStore->getDebugLoc());
- if (MergedVal != SI.getOperand(0)) {
- PHINode *PN = PHINode::Create(MergedVal->getType(), 2, "storemerge");
- PN->addIncoming(SI.getOperand(0), SI.getParent());
- PN->addIncoming(OtherStore->getOperand(0), OtherBB);
+ if (MergedVal != SI.getValueOperand()) {
+ PHINode *PN =
+ PHINode::Create(SI.getValueOperand()->getType(), 2, "storemerge");
+ PN->addIncoming(SI.getValueOperand(), SI.getParent());
+ Builder.SetInsertPoint(OtherStore);
+ PN->addIncoming(Builder.CreateBitOrPointerCast(MergedVal, PN->getType()),
+ OtherBB);
MergedVal = InsertNewInstBefore(PN, DestBB->front());
PN->setDebugLoc(MergedLoc);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 97f129e200de..50458e2773e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -185,6 +185,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
return nullptr;
}
+static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
+ bool AssumeNonZero, bool DoFold);
+
Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V =
@@ -270,7 +273,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (match(Op0, m_ZExtOrSExt(m_Value(X))) &&
match(Op1, m_APIntAllowUndef(NegPow2C))) {
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
- unsigned ShiftAmt = NegPow2C->countTrailingZeros();
+ unsigned ShiftAmt = NegPow2C->countr_zero();
if (ShiftAmt >= BitWidth - SrcWidth) {
Value *N = Builder.CreateNeg(X, X->getName() + ".neg");
Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z");
@@ -471,6 +474,40 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
+ if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+ return Res;
+
+ // min(X, Y) * max(X, Y) => X * Y.
+ if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
+ m_c_SMin(m_Deferred(X), m_Deferred(Y))),
+ m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
+ m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
+ return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);
+
+ // (mul Op0 Op1):
+ // if Log2(Op0) folds away ->
+ // (shl Op1, Log2(Op0))
+ // if Log2(Op1) folds away ->
+ // (shl Op0, Log2(Op1))
+ if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ false)) {
+ Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ true);
+ BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
+ // We can only propegate nuw flag.
+ Shl->setHasNoUnsignedWrap(HasNUW);
+ return Shl;
+ }
+ if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ false)) {
+ Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ true);
+ BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
+ // We can only propagate the nuw flag.
+ Shl->setHasNoUnsignedWrap(HasNUW);
+ return Shl;
+ }
+
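
A minimal standalone C++ check (not part of the patch) of the arithmetic behind the new mul-to-shl fold: when log2 of one operand folds away -- here it is literally a power of two, 1 << k -- the multiply equals a left shift modulo 2^32. The test values are arbitrary demo inputs.

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t k = 0; k < 32; ++k) {
    uint32_t Op0 = 1u << k;                      // log2(Op0) folds away: Op0 is 2^k
    for (uint32_t Op1 : {0u, 1u, 7u, 0x1234u, 0xDEADBEEFu, 0xFFFFFFFFu})
      assert(Op0 * Op1 == (Op1 << k));           // mul == shl, modulo 2^32
  }
  return 0;
}
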
bool Changed = false;
if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) {
Changed = true;
@@ -765,6 +802,20 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
I.hasNoSignedZeros() && match(Start, m_Zero()))
return replaceInstUsesWith(I, Start);
+ // minimum(X, Y) * maximum(X, Y) => X * Y.
+ if (match(&I,
+ m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+ m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+ m_Deferred(Y))))) {
+ BinaryOperator *Result = BinaryOperator::CreateFMulFMF(X, Y, &I);
+ // We cannot preserve ninf if nnan flag is not set.
+ // If X is NaN and Y is Inf then in original program we had NaN * NaN,
+ // while in optimized version NaN * Inf and this is a poison with ninf flag.
+ if (!Result->hasNoNaNs())
+ Result->setHasNoInfs(false);
+ return Result;
+ }
+
return nullptr;
}
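
A standalone C++ sketch (not part of the patch) of why the new FMul fold drops ninf when nnan is absent. my_minimum/my_maximum are hand-written NaN-propagating stand-ins for llvm.minimum/llvm.maximum, written only for this demo: with X = NaN and Y = Inf the original expression multiplies NaN by NaN, while the folded one multiplies NaN by Inf.

#include <cmath>
#include <cstdio>
#include <limits>

// NaN-propagating stand-ins for llvm.minimum / llvm.maximum (demo only).
static double my_minimum(double a, double b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<double>::quiet_NaN();
  return a < b ? a : b;
}
static double my_maximum(double a, double b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<double>::quiet_NaN();
  return a > b ? a : b;
}

int main() {
  double X = std::numeric_limits<double>::quiet_NaN();
  double Y = std::numeric_limits<double>::infinity();
  // Original form multiplies NaN * NaN; the folded form multiplies NaN * Inf,
  // so an ninf assumption that held before the fold no longer holds after it.
  std::printf("original: %f * %f\n", my_minimum(X, Y), my_maximum(X, Y));
  std::printf("folded:   %f * %f\n", X, Y);
  return 0;
}
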
@@ -976,9 +1027,9 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
ConstantInt::get(Ty, Product));
}
+ APInt Quotient(C2->getBitWidth(), /*val=*/0ULL, IsSigned);
if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) ||
(!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) {
- APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
// (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1.
if (isMultiple(*C2, *C1, Quotient, IsSigned)) {
@@ -1003,7 +1054,6 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
C1->ult(C1->getBitWidth() - 1)) ||
(!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) &&
C1->ult(C1->getBitWidth()))) {
- APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
APInt C1Shifted = APInt::getOneBitSet(
C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue()));
@@ -1026,6 +1076,23 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
}
}
+ // Distribute div over add to eliminate a matching div/mul pair:
+ // ((X * C2) + C1) / C2 --> X + C1/C2
+ // We need a multiple of the divisor for a signed add constant, but
+ // unsigned is fine with any constant pair.
+ if (IsSigned &&
+ match(Op0, m_NSWAdd(m_NSWMul(m_Value(X), m_SpecificInt(*C2)),
+ m_APInt(C1))) &&
+ isMultiple(*C1, *C2, Quotient, IsSigned)) {
+ return BinaryOperator::CreateNSWAdd(X, ConstantInt::get(Ty, Quotient));
+ }
+ if (!IsSigned &&
+ match(Op0, m_NUWAdd(m_NUWMul(m_Value(X), m_SpecificInt(*C2)),
+ m_APInt(C1)))) {
+ return BinaryOperator::CreateNUWAdd(X,
+ ConstantInt::get(Ty, C1->udiv(*C2)));
+ }
+
if (!C2->isZero()) // avoid X udiv 0
if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I))
return FoldedDiv;
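
A standalone C++ spot-check (not part of the patch) of the distribution rule added above, ((X * C2) + C1) / C2 == X + C1 / C2, in a range where neither the multiply nor the add can wrap (matching the nuw requirement). The constants are arbitrary demo values.

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C1 = 23, C2 = 7;                // arbitrary demo constants
  for (uint32_t X = 0; X < 100000; ++X)
    assert((X * C2 + C1) / C2 == X + C1 / C2);   // no wrap in this range (nuw)
  return 0;
}
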
@@ -1121,7 +1188,7 @@ static const unsigned MaxDepth = 6;
// actual instructions, otherwise return a non-null dummy value. Return nullptr
// on failure.
static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
- bool DoFold) {
+ bool AssumeNonZero, bool DoFold) {
auto IfFold = [DoFold](function_ref<Value *()> Fn) {
if (!DoFold)
return reinterpret_cast<Value *>(-1);
@@ -1147,14 +1214,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// FIXME: Require one use?
Value *X, *Y;
if (match(Op, m_ZExt(m_Value(X))))
- if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
+ if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
// log2(X << Y) -> log2(X) + Y
// FIXME: Require one use unless X is 1?
- if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
- if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
- return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+ if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
+ auto *BO = cast<OverflowingBinaryOperator>(Op);
+ // nuw will be set if the `shl` is trivially non-zero.
+ if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
+ if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+ }
// log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
// FIXME: missed optimization: if one of the hands of select is/contains
@@ -1162,8 +1233,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// FIXME: can both hands contain undef?
// FIXME: Require one use?
if (SelectInst *SI = dyn_cast<SelectInst>(Op))
- if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
- if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
+ if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
+ AssumeNonZero, DoFold))
+ if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
+ AssumeNonZero, DoFold))
return IfFold([&]() {
return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
});
@@ -1171,13 +1244,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// log2(umin(X, Y)) -> umin(log2(X), log2(Y))
// log2(umax(X, Y)) -> umax(log2(X), log2(Y))
auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
- if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned())
- if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold))
- if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold))
+ if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
+ // Use AssumeNonZero as false here. Otherwise we can hit the case where
+ // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
+ if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
+ if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
return IfFold([&]() {
- return Builder.CreateBinaryIntrinsic(
- MinMax->getIntrinsicID(), LogX, LogY);
+ return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
+ LogY);
});
+ }
return nullptr;
}
@@ -1297,8 +1375,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
}
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
- if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
- Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
+ if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
+ /*DoFold*/ false)) {
+ Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
+ /*AssumeNonZero*/ true, /*DoFold*/ true);
return replaceInstUsesWith(
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
}
@@ -1359,7 +1439,8 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
// (sext X) sdiv C --> sext (X sdiv C)
Value *Op0Src;
if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) &&
- Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) {
+ Op0Src->getType()->getScalarSizeInBits() >=
+ Op1C->getSignificantBits()) {
// In the general case, we need to make sure that the dividend is not the
// minimum signed value because dividing that by -1 is UB. But here, we
@@ -1402,7 +1483,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
KnownBits KnownDividend = computeKnownBits(Op0, 0, &I);
if (!I.isExact() &&
(match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) &&
- KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) {
+ KnownDividend.countMinTrailingZeros() >= Op1C->countr_zero()) {
I.setIsExact();
return &I;
}
@@ -1681,6 +1762,111 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return nullptr;
}
+// A variety of transforms for:
+// (urem/srem (mul X, Y), (mul X, Z))
+// (urem/srem (shl X, Y), (shl X, Z))
+// (urem/srem (shl Y, X), (shl Z, X))
+// NB: The shift cases are really just extensions of the mul case. We treat
+// shift as Val * (1 << Amt).
+static Instruction *simplifyIRemMulShl(BinaryOperator &I,
+ InstCombinerImpl &IC) {
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
+ APInt Y, Z;
+ bool ShiftByX = false;
+
+ // If V is not nullptr, it will be matched using m_Specific.
+ auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool {
+ const APInt *Tmp = nullptr;
+ if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) ||
+ (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
+ C = *Tmp;
+ else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
+ (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
+ C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
+ if (Tmp != nullptr)
+ return true;
+
+ // Reset `V` so we don't start with specific value on next match attempt.
+ V = nullptr;
+ return false;
+ };
+
+ auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool {
+ const APInt *Tmp = nullptr;
+ if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) ||
+ (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) {
+ C = *Tmp;
+ return true;
+ }
+
+ // Reset `V` so we don't start with specific value on next match attempt.
+ V = nullptr;
+ return false;
+ };
+
+ if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) {
+ // pass
+ } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) {
+ ShiftByX = true;
+ } else {
+ return nullptr;
+ }
+
+ bool IsSRem = I.getOpcode() == Instruction::SRem;
+
+ OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0);
+ // TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >=
+ // Z or Z >= Y.
+ bool BO0HasNSW = BO0->hasNoSignedWrap();
+ bool BO0HasNUW = BO0->hasNoUnsignedWrap();
+ bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW;
+
+ APInt RemYZ = IsSRem ? Y.srem(Z) : Y.urem(Z);
+ // (rem (mul nuw/nsw X, Y), (mul X, Z))
+ // if (rem Y, Z) == 0
+ // -> 0
+ if (RemYZ.isZero() && BO0NoWrap)
+ return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
+
+ // Helper function to emit either (RemSimplificationC << X) or
+ // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as
+ // (shl V, X) or (mul V, X) respectively.
+ auto CreateMulOrShift =
+ [&](const APInt &RemSimplificationC) -> BinaryOperator * {
+ Value *RemSimplification =
+ ConstantInt::get(I.getType(), RemSimplificationC);
+ return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X)
+ : BinaryOperator::CreateMul(X, RemSimplification);
+ };
+
+ OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1);
+ bool BO1HasNSW = BO1->hasNoSignedWrap();
+ bool BO1HasNUW = BO1->hasNoUnsignedWrap();
+ bool BO1NoWrap = IsSRem ? BO1HasNSW : BO1HasNUW;
+ // (rem (mul X, Y), (mul nuw/nsw X, Z))
+ // if (rem Y, Z) == Y
+ // -> (mul nuw/nsw X, Y)
+ if (RemYZ == Y && BO1NoWrap) {
+ BinaryOperator *BO = CreateMulOrShift(Y);
+ // Copy any overflow flags from Op0.
+ BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
+ BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
+ return BO;
+ }
+
+ // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
+ // if Y >= Z
+ // -> (mul {nuw} nsw X, (rem Y, Z))
+ if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
+ BinaryOperator *BO = CreateMulOrShift(RemYZ);
+ BO->setHasNoSignedWrap();
+ BO->setHasNoUnsignedWrap(BO0HasNUW);
+ return BO;
+ }
+
+ return nullptr;
+}
+
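
A standalone C++ spot-check (not part of the patch) of the identity simplifyIRemMulShl relies on: with non-wrapping operands, (X*Y) % (X*Z) == X * (Y % Z), which covers the 0, X*Y, and X*(Y%Z) cases handled above. The ranges are kept small only so that no multiplication can wrap in the demo.

#include <cassert>
#include <cstdint>

int main() {
  // Small ranges: no multiplication wraps (the nuw/nsw precondition of the
  // fold), and X, Z are non-zero so the remainder is defined.
  for (uint32_t X = 1; X < 200; ++X)
    for (uint32_t Y = 0; Y < 200; ++Y)
      for (uint32_t Z = 1; Z < 200; ++Z)
        assert((X * Y) % (X * Z) == X * (Y % Z));
  return 0;
}
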
/// This function implements the transforms common to both integer remainder
/// instructions (urem and srem). It is called by the visitors to those integer
/// remainder instructions.
@@ -1733,6 +1919,9 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) {
}
}
+ if (Instruction *R = simplifyIRemMulShl(I, *this))
+ return R;
+
return nullptr;
}
@@ -1782,8 +1971,21 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
// urem Op0, (sext i1 X) --> (Op0 == -1) ? 0 : Op0
Value *X;
if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
- Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty));
- return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0);
+ Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen");
+ Value *Cmp =
+ Builder.CreateICmpEQ(FrozenOp0, ConstantInt::getAllOnesValue(Ty));
+ return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0);
+ }
+
+ // For "(X + 1) % Op1" and if (X u< Op1) => (X + 1) == Op1 ? 0 : X + 1 .
+ if (match(Op0, m_Add(m_Value(X), m_One()))) {
+ Value *Val =
+ simplifyICmpInst(ICmpInst::ICMP_ULT, X, Op1, SQ.getWithInstruction(&I));
+ if (Val && match(Val, m_One())) {
+ Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen");
+ Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, Op1);
+ return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0);
+ }
}
return nullptr;
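
A standalone C++ spot-check (not part of the patch) of the new urem fold: under the precondition X u< Op1, (X + 1) % Op1 equals (X + 1) == Op1 ? 0 : X + 1. The loop bounds are arbitrary demo values.

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t Op1 = 1; Op1 < 300; ++Op1)
    for (uint32_t X = 0; X < Op1; ++X) {         // precondition: X u< Op1
      uint32_t Rem = (X + 1) % Op1;
      uint32_t Sel = (X + 1 == Op1) ? 0u : X + 1;
      assert(Rem == Sel);
    }
  return 0;
}
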
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 7f59729f0085..2f6aa85062a5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -316,7 +316,7 @@ Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) {
for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) {
if (auto *NewOp =
simplifyIntToPtrRoundTripCast(PN.getIncomingValue(OpNum))) {
- PN.setIncomingValue(OpNum, NewOp);
+ replaceOperand(PN, OpNum, NewOp);
OperandWithRoundTripCast = true;
}
}
@@ -745,6 +745,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) {
LLVMContext::MD_dereferenceable,
LLVMContext::MD_dereferenceable_or_null,
LLVMContext::MD_access_group,
+ LLVMContext::MD_noundef,
};
for (unsigned ID : KnownIDs)
@@ -1388,11 +1389,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
// If all PHI operands are the same operation, pull them through the PHI,
// reducing code size.
- if (isa<Instruction>(PN.getIncomingValue(0)) &&
- isa<Instruction>(PN.getIncomingValue(1)) &&
- cast<Instruction>(PN.getIncomingValue(0))->getOpcode() ==
- cast<Instruction>(PN.getIncomingValue(1))->getOpcode() &&
- PN.getIncomingValue(0)->hasOneUser())
+ auto *Inst0 = dyn_cast<Instruction>(PN.getIncomingValue(0));
+ auto *Inst1 = dyn_cast<Instruction>(PN.getIncomingValue(1));
+ if (Inst0 && Inst1 && Inst0->getOpcode() == Inst1->getOpcode() &&
+ Inst0->hasOneUser())
if (Instruction *Result = foldPHIArgOpIntoPHI(PN))
return Result;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index e7d8208f94fd..661c50062223 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -98,7 +98,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
// +0.0 compares equal to -0.0, and so it does not behave as required for this
// transform. Bail out if we can not exclude that possibility.
if (isa<FPMathOperator>(BO))
- if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI))
+ if (!BO->hasNoSignedZeros() &&
+ !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI))
return nullptr;
// BO = binop Y, X
@@ -386,6 +387,32 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp});
}
}
+
+ // select c, (ldexp v, e0), (ldexp v, e1) -> ldexp v, (select c, e0, e1)
+ // select c, (ldexp v0, e), (ldexp v1, e) -> ldexp (select c, v0, v1), e
+ //
+ // select c, (ldexp v0, e0), (ldexp v1, e1) ->
+ // ldexp (select c, v0, v1), (select c, e0, e1)
+ if (TII->getIntrinsicID() == Intrinsic::ldexp) {
+ Value *LdexpVal0 = TII->getArgOperand(0);
+ Value *LdexpExp0 = TII->getArgOperand(1);
+ Value *LdexpVal1 = FII->getArgOperand(0);
+ Value *LdexpExp1 = FII->getArgOperand(1);
+ if (LdexpExp0->getType() == LdexpExp1->getType()) {
+ FPMathOperator *SelectFPOp = cast<FPMathOperator>(&SI);
+ FastMathFlags FMF = cast<FPMathOperator>(TII)->getFastMathFlags();
+ FMF &= cast<FPMathOperator>(FII)->getFastMathFlags();
+ FMF |= SelectFPOp->getFastMathFlags();
+
+ Value *SelectVal = Builder.CreateSelect(Cond, LdexpVal0, LdexpVal1);
+ Value *SelectExp = Builder.CreateSelect(Cond, LdexpExp0, LdexpExp1);
+
+ CallInst *NewLdexp = Builder.CreateIntrinsic(
+ TII->getType(), Intrinsic::ldexp, {SelectVal, SelectExp});
+ NewLdexp->setFastMathFlags(FMF);
+ return replaceInstUsesWith(SI, NewLdexp);
+ }
+ }
}
// icmp with a common operand also can have the common operand
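
A minimal standalone C++ sketch (not part of the patch) of the ldexp/select reassociation added above, using std::ldexp as a stand-in for llvm.ldexp; the values and exponents are invented for the demo.

#include <cassert>
#include <cmath>

int main() {
  double v0 = 1.5, v1 = -3.25;                   // arbitrary demo values
  int e0 = 4, e1 = -2;
  for (bool c : {false, true}) {
    double sel = c ? std::ldexp(v0, e0) : std::ldexp(v1, e1);
    double folded = std::ldexp(c ? v0 : v1, c ? e0 : e1);
    assert(sel == folded);
  }
  return 0;
}
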
@@ -429,6 +456,21 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
!OtherOpF->getType()->isVectorTy()))
return nullptr;
+ // If we are sinking div/rem after a select, we may need to freeze the
+ // condition because div/rem may induce immediate UB with a poison operand.
+ // For example, the following transform is not safe if Cond can ever be poison
+ // because we can replace poison with zero and then we have div-by-zero that
+ // didn't exist in the original code:
+ // Cond ? x/y : x/z --> x / (Cond ? y : z)
+ auto *BO = dyn_cast<BinaryOperator>(TI);
+ if (BO && BO->isIntDivRem() && !isGuaranteedNotToBePoison(Cond)) {
+ // A udiv/urem with a common divisor is safe because UB can only occur with
+ // div-by-zero, and that would be present in the original code.
+ if (BO->getOpcode() == Instruction::SDiv ||
+ BO->getOpcode() == Instruction::SRem || MatchIsOpZero)
+ Cond = Builder.CreateFreeze(Cond);
+ }
+
// If we reach here, they do have operations in common.
Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
SI.getName() + ".v", &SI);
@@ -461,7 +503,7 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) {
/// optimization.
Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
Value *FalseVal) {
- // See the comment above GetSelectFoldableOperands for a description of the
+ // See the comment above getSelectFoldableOperands for a description of the
// transformation we are doing here.
auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal,
Value *FalseVal,
@@ -496,7 +538,7 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
if (!isa<Constant>(OOp) ||
(OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
- Swapped ? OOp : C);
+ Swapped ? OOp : C, "", &SI);
if (isa<FPMathOperator>(&SI))
cast<Instruction>(NewSel)->setFastMathFlags(FMF);
NewSel->takeName(TVI);
@@ -569,6 +611,44 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp,
}
/// We want to turn:
+/// (select (icmp eq (and X, C1), 0), 0, (shl [nsw/nuw] X, C2));
+/// iff C1 is a mask and the number of its leading zeros is equal to C2
+/// into:
+/// shl X, C2
+static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal,
+ Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ ICmpInst::Predicate Pred;
+ Value *AndVal;
+ if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero())))
+ return nullptr;
+
+ if (Pred == ICmpInst::ICMP_NE) {
+ Pred = ICmpInst::ICMP_EQ;
+ std::swap(TVal, FVal);
+ }
+
+ Value *X;
+ const APInt *C2, *C1;
+ if (Pred != ICmpInst::ICMP_EQ ||
+ !match(AndVal, m_And(m_Value(X), m_APInt(C1))) ||
+ !match(TVal, m_Zero()) || !match(FVal, m_Shl(m_Specific(X), m_APInt(C2))))
+ return nullptr;
+
+ if (!C1->isMask() ||
+ C1->countLeadingZeros() != static_cast<unsigned>(C2->getZExtValue()))
+ return nullptr;
+
+ auto *FI = dyn_cast<Instruction>(FVal);
+ if (!FI)
+ return nullptr;
+
+ FI->setHasNoSignedWrap(false);
+ FI->setHasNoUnsignedWrap(false);
+ return FVal;
+}
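
A standalone C++ spot-check (not part of the patch) of why the select is redundant in foldSelectICmpAndZeroShl: when C1 is a low-bit mask whose number of leading zeros equals C2, (X & C1) == 0 already forces X << C2 to be zero. The constants and the value-scattering multiplier are invented for the demo.

#include <cassert>
#include <cstdint>

int main() {
  const unsigned BitWidth = 32, C2 = 8;
  const uint32_t C1 = (1u << (BitWidth - C2)) - 1;        // mask with exactly C2 leading zeros
  for (uint64_t i = 0; i < (1u << 20); ++i) {
    uint32_t X = static_cast<uint32_t>(i * 2654435761u);  // scatter demo values
    uint32_t Sel = ((X & C1) == 0) ? 0u : (X << C2);
    assert(Sel == (X << C2));                             // the select is redundant
  }
  return 0;
}
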
+
+/// We want to turn:
/// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1
/// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0
/// into:
@@ -935,10 +1015,53 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}
+/// Try to match patterns with select and subtract as absolute difference.
+static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ auto *TI = dyn_cast<Instruction>(TVal);
+ auto *FI = dyn_cast<Instruction>(FVal);
+ if (!TI || !FI)
+ return nullptr;
+
+ // Normalize predicate to gt/lt rather than ge/le.
+ ICmpInst::Predicate Pred = Cmp->getStrictPredicate();
+ Value *A = Cmp->getOperand(0);
+ Value *B = Cmp->getOperand(1);
+
+ // Normalize "A - B" as the true value of the select.
+ if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) {
+ std::swap(FI, TI);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ // With any pair of no-wrap subtracts:
+ // (A > B) ? (A - B) : (B - A) --> abs(A - B)
+ if (Pred == CmpInst::ICMP_SGT &&
+ match(TI, m_Sub(m_Specific(A), m_Specific(B))) &&
+ match(FI, m_Sub(m_Specific(B), m_Specific(A))) &&
+ (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) &&
+ (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) {
+ // The remaining subtract is not "nuw" any more.
+ // If there's one use of the subtract (no other use than the use we are
+ // about to replace), then we know that the sub is "nsw" in this context
+ // even if it was only "nuw" before. If there's another use, then we can't
+ // add "nsw" to the existing instruction because it may not be safe in the
+ // other user's context.
+ TI->setHasNoUnsignedWrap(false);
+ if (!TI->hasNoSignedWrap())
+ TI->setHasNoSignedWrap(TI->hasOneUse());
+ return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue());
+ }
+
+ return nullptr;
+}
+
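
A standalone C++ spot-check (not part of the patch) of the absolute-difference pattern matched by foldAbsDiff: when the subtractions cannot wrap, (A > B) ? (A - B) : (B - A) equals |A - B|. The range is an arbitrary demo choice.

#include <cassert>
#include <cstdlib>

int main() {
  // Small range, so neither A - B nor B - A can overflow (the nsw/nuw
  // precondition of the fold).
  for (int A = -300; A <= 300; ++A)
    for (int B = -300; B <= 300; ++B) {
      int Sel = (A > B) ? (A - B) : (B - A);
      assert(Sel == std::abs(A - B));
    }
  return 0;
}
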
/// Fold the following code sequence:
/// \code
/// int a = ctlz(x & -x);
// x ? 31 - a : a;
+// // or
+// x ? 31 - a : 32;
/// \code
///
/// into:
@@ -953,15 +1076,19 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal,
if (ICI->getPredicate() == ICmpInst::ICMP_NE)
std::swap(TrueVal, FalseVal);
+ Value *Ctlz;
if (!match(FalseVal,
- m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1))))
+ m_Xor(m_Value(Ctlz), m_SpecificInt(BitWidth - 1))))
return nullptr;
- if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>()))
+ if (!match(Ctlz, m_Intrinsic<Intrinsic::ctlz>()))
+ return nullptr;
+
+ if (TrueVal != Ctlz && !match(TrueVal, m_SpecificInt(BitWidth)))
return nullptr;
Value *X = ICI->getOperand(0);
- auto *II = cast<IntrinsicInst>(TrueVal);
+ auto *II = cast<IntrinsicInst>(Ctlz);
if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X)))))
return nullptr;
@@ -1038,99 +1165,6 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
return nullptr;
}
-/// Return true if we find and adjust an icmp+select pattern where the compare
-/// is with a constant that can be incremented or decremented to match the
-/// minimum or maximum idiom.
-static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
- ICmpInst::Predicate Pred = Cmp.getPredicate();
- Value *CmpLHS = Cmp.getOperand(0);
- Value *CmpRHS = Cmp.getOperand(1);
- Value *TrueVal = Sel.getTrueValue();
- Value *FalseVal = Sel.getFalseValue();
-
- // We may move or edit the compare, so make sure the select is the only user.
- const APInt *CmpC;
- if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC)))
- return false;
-
- // These transforms only work for selects of integers or vector selects of
- // integer vectors.
- Type *SelTy = Sel.getType();
- auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType());
- if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy())
- return false;
-
- Constant *AdjustedRHS;
- if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT)
- AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1);
- else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT)
- AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1);
- else
- return false;
-
- // X > C ? X : C+1 --> X < C+1 ? C+1 : X
- // X < C ? X : C-1 --> X > C-1 ? C-1 : X
- if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) ||
- (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) {
- ; // Nothing to do here. Values match without any sign/zero extension.
- }
- // Types do not match. Instead of calculating this with mixed types, promote
- // all to the larger type. This enables scalar evolution to analyze this
- // expression.
- else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) {
- Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy);
-
- // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X
- // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X
- // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X
- // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X
- if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) {
- CmpLHS = TrueVal;
- AdjustedRHS = SextRHS;
- } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) &&
- SextRHS == TrueVal) {
- CmpLHS = FalseVal;
- AdjustedRHS = SextRHS;
- } else if (Cmp.isUnsigned()) {
- Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy);
- // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X
- // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X
- // zext + signed compare cannot be changed:
- // 0xff <s 0x00, but 0x00ff >s 0x0000
- if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) {
- CmpLHS = TrueVal;
- AdjustedRHS = ZextRHS;
- } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) &&
- ZextRHS == TrueVal) {
- CmpLHS = FalseVal;
- AdjustedRHS = ZextRHS;
- } else {
- return false;
- }
- } else {
- return false;
- }
- } else {
- return false;
- }
-
- Pred = ICmpInst::getSwappedPredicate(Pred);
- CmpRHS = AdjustedRHS;
- std::swap(FalseVal, TrueVal);
- Cmp.setPredicate(Pred);
- Cmp.setOperand(0, CmpLHS);
- Cmp.setOperand(1, CmpRHS);
- Sel.setOperand(1, TrueVal);
- Sel.setOperand(2, FalseVal);
- Sel.swapProfMetadata();
-
- // Move the compare instruction right before the select instruction. Otherwise
- // the sext/zext value may be defined after the compare instruction uses it.
- Cmp.moveBefore(&Sel);
-
- return true;
-}
-
static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
InstCombinerImpl &IC) {
Value *LHS, *RHS;
@@ -1182,8 +1216,8 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;
}
-static bool replaceInInstruction(Value *V, Value *Old, Value *New,
- InstCombiner &IC, unsigned Depth = 0) {
+bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New,
+ unsigned Depth) {
// Conservatively limit replacement to two instructions upwards.
if (Depth == 2)
return false;
@@ -1195,10 +1229,11 @@ static bool replaceInInstruction(Value *V, Value *Old, Value *New,
bool Changed = false;
for (Use &U : I->operands()) {
if (U == Old) {
- IC.replaceUse(U, New);
+ replaceUse(U, New);
+ Worklist.add(I);
Changed = true;
} else {
- Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1);
+ Changed |= replaceInInstruction(U, Old, New, Depth + 1);
}
}
return Changed;
@@ -1254,7 +1289,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
// FIXME: Support vectors.
if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
!Cmp.getType()->isVectorTy())
- if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this))
+ if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS))
return &Sel;
}
if (TrueVal != CmpRHS &&
@@ -1593,13 +1628,32 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal,
return nullptr;
}
-static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) {
+static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
+ InstCombiner::BuilderTy &Builder) {
const APInt *CmpC;
Value *V;
CmpInst::Predicate Pred;
if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC))))
return nullptr;
+ // Match clamp away from min/max value as a max/min operation.
+ Value *TVal = SI.getTrueValue();
+ Value *FVal = SI.getFalseValue();
+ if (Pred == ICmpInst::ICMP_EQ && V == FVal) {
+ // (V == UMIN) ? UMIN+1 : V --> umax(V, UMIN+1)
+ if (CmpC->isMinValue() && match(TVal, m_SpecificInt(*CmpC + 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::umax, V, TVal);
+ // (V == UMAX) ? UMAX-1 : V --> umin(V, UMAX-1)
+ if (CmpC->isMaxValue() && match(TVal, m_SpecificInt(*CmpC - 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::umin, V, TVal);
+ // (V == SMIN) ? SMIN+1 : V --> smax(V, SMIN+1)
+ if (CmpC->isMinSignedValue() && match(TVal, m_SpecificInt(*CmpC + 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::smax, V, TVal);
+ // (V == SMAX) ? SMAX-1 : V --> smin(V, SMAX-1)
+ if (CmpC->isMaxSignedValue() && match(TVal, m_SpecificInt(*CmpC - 1)))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::smin, V, TVal);
+ }
+
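
A standalone C++ spot-check (not part of the patch) of the clamp-away-from-min/max folds above, exhaustive over 8-bit values; std::min/std::max stand in for the umin/umax/smin/smax intrinsics in this demo.

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (int i = 0; i < 256; ++i) {
    uint8_t V = static_cast<uint8_t>(i);
    int8_t S = static_cast<int8_t>(V);
    assert(((V == 0) ? uint8_t(1) : V) == std::max<uint8_t>(V, 1));        // umax
    assert(((V == 255) ? uint8_t(254) : V) == std::min<uint8_t>(V, 254));  // umin
    assert(((S == -128) ? int8_t(-127) : S) == std::max<int8_t>(S, -127)); // smax
    assert(((S == 127) ? int8_t(126) : S) == std::min<int8_t>(S, 126));    // smin
  }
  return 0;
}
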
BinaryOperator *BO;
const APInt *C;
CmpInst::Predicate CPred;
@@ -1632,7 +1686,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this))
return NewSPF;
- if (Value *V = foldSelectInstWithICmpConst(SI, ICI))
+ if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder))
return replaceInstUsesWith(SI, V);
if (Value *V = canonicalizeClampLike(SI, *ICI, Builder))
@@ -1642,18 +1696,17 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
return NewSel;
- bool Changed = adjustMinMax(SI, *ICI);
-
if (Value *V = foldSelectICmpAnd(SI, ICI, Builder))
return replaceInstUsesWith(SI, V);
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
+ bool Changed = false;
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
- if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
+ if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS) && !isa<Constant>(CmpLHS)) {
if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
// Transform (X == C) ? X : Y -> (X == C) ? C : Y
SI.setOperand(1, CmpRHS);
@@ -1683,7 +1736,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
// FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring
// decomposeBitTestICmp() might help.
- {
+ if (TrueVal->getType()->isIntOrIntVectorTy()) {
unsigned BitWidth =
DL.getTypeSizeInBits(TrueVal->getType()->getScalarType());
APInt MinSignedValue = APInt::getSignedMinValue(BitWidth);
@@ -1735,6 +1788,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder))
return V;
+ if (Value *V = foldSelectICmpAndZeroShl(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder))
return V;
@@ -1756,6 +1812,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
+ if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
return Changed ? &SI : nullptr;
}
@@ -2418,7 +2477,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
// in the case of a shuffle with no undefined mask elements.
ArrayRef<int> Mask;
if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
- !is_contained(Mask, UndefMaskElem) &&
+ !is_contained(Mask, PoisonMaskElem) &&
cast<ShuffleVectorInst>(TVal)->isSelect()) {
if (X == FVal) {
// select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X)
@@ -2432,7 +2491,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
}
}
if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
- !is_contained(Mask, UndefMaskElem) &&
+ !is_contained(Mask, PoisonMaskElem) &&
cast<ShuffleVectorInst>(FVal)->isSelect()) {
if (X == TVal) {
// select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y)
@@ -2965,6 +3024,14 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
return replaceOperand(SI, 0, A);
+ // select a, (select ~a, true, b), false -> select a, b, false
+ if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) &&
+ match(FalseVal, m_Zero()))
+ return replaceOperand(SI, 1, B);
+ // select a, true, (select ~a, b, false) -> select a, true, b
+ if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) &&
+ match(TrueVal, m_One()))
+ return replaceOperand(SI, 2, B);
// ~(A & B) & (A | B) --> A ^ B
if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))),
@@ -3077,6 +3144,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return nullptr;
}
+// Return true if we can safely remove the select instruction for the
+// std::bit_ceil pattern.
+static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
+ const APInt *Cond1, Value *CtlzOp,
+ unsigned BitWidth) {
+ // The challenge in recognizing std::bit_ceil(X) is that the operand is used
+ // for the CTLZ proper and select condition, each possibly with some
+ // operation like add and sub.
+ //
+ // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the
+ // select instruction would select 1, which allows us to get rid of the select
+ // instruction.
+ //
+ // To see if we can do so, we do some symbolic execution with ConstantRange.
+ // Specifically, we compute the range of values that Cond0 could take when
+ // Cond == false. Then we successively transform the range until we obtain
+ // the range of values that CtlzOp could take.
+ //
+ // Conceptually, we follow the def-use chain backward from Cond0 while
+ // transforming the range for Cond0 until we meet the common ancestor of Cond0
+ // and CtlzOp. Then we follow the def-use chain forward until we obtain the
+ // range for CtlzOp. That said, we only follow at most one ancestor from
+// Cond0. Likewise, we only follow at most one ancestor from CtlzOp.
+
+ ConstantRange CR = ConstantRange::makeExactICmpRegion(
+ CmpInst::getInversePredicate(Pred), *Cond1);
+
+ // Match the operation that's used to compute CtlzOp from CommonAncestor. If
+ // CtlzOp == CommonAncestor, return true as no operation is needed. If a
+ // match is found, execute the operation on CR, update CR, and return true.
+ // Otherwise, return false.
+ auto MatchForward = [&](Value *CommonAncestor) {
+ const APInt *C = nullptr;
+ if (CtlzOp == CommonAncestor)
+ return true;
+ if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
+ CR = CR.add(*C);
+ return true;
+ }
+ if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+ CR = ConstantRange(*C).sub(CR);
+ return true;
+ }
+ if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
+ CR = CR.binaryNot();
+ return true;
+ }
+ return false;
+ };
+
+ const APInt *C = nullptr;
+ Value *CommonAncestor;
+ if (MatchForward(Cond0)) {
+ // Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated.
+ } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
+ CR = CR.sub(*C);
+ if (!MatchForward(CommonAncestor))
+ return false;
+ // Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated.
+ } else {
+ return false;
+ }
+
+ // Return true if all the values in the range are either 0 or negative (if
+ // treated as signed). We do so by evaluating:
+ //
+ // CR - 1 u>= (1 << (BitWidth - 1)) - 1.
+ APInt IntMax = APInt::getSignMask(BitWidth) - 1;
+ CR = CR.sub(APInt(BitWidth, 1));
+ return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
+}
+
+// Transform the std::bit_ceil(X) pattern like:
+//
+// %dec = add i32 %x, -1
+// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+// %sub = sub i32 32, %ctlz
+// %shl = shl i32 1, %sub
+// %ugt = icmp ugt i32 %x, 1
+// %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+// %dec = add i32 %x, -1
+// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+// %neg = sub i32 0, %ctlz
+// %masked = and i32 %neg, 31
+// %shl = shl i32 1, %masked
+//
+// Note that the select is optimized away while the shift count is masked with
+// 31. We handle some variations of the input operand like std::bit_ceil(X +
+// 1).
+static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
+ Type *SelType = SI.getType();
+ unsigned BitWidth = SelType->getScalarSizeInBits();
+
+ Value *FalseVal = SI.getFalseValue();
+ Value *TrueVal = SI.getTrueValue();
+ ICmpInst::Predicate Pred;
+ const APInt *Cond1;
+ Value *Cond0, *Ctlz, *CtlzOp;
+ if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
+ return nullptr;
+
+ if (match(TrueVal, m_One())) {
+ std::swap(FalseVal, TrueVal);
+ Pred = CmpInst::getInversePredicate(Pred);
+ }
+
+ if (!match(FalseVal, m_One()) ||
+ !match(TrueVal,
+ m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+ m_Value(Ctlz)))))) ||
+ !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
+ !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+ return nullptr;
+
+ // Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a
+ // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
+ // is an integer constant. Masking with BitWidth-1 comes free on some
+ // hardware as part of the shift instruction.
+ Value *Neg = Builder.CreateNeg(Ctlz);
+ Value *Masked =
+ Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+ return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+ Masked);
+}
+
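
A standalone C++20 spot-check (not part of the patch) of the bit_ceil rewrite: for every x in a sample range, the original select-based pattern and the branchless 1 << (-ctlz & 31) form agree. std::countl_zero stands in for llvm.ctlz with a defined result at zero; the range bound is an arbitrary demo choice.

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 0; x <= (1u << 20); ++x) {
    uint32_t dec = x - 1;                         // wraps to 0xFFFFFFFF for x == 0
    uint32_t ctlz = std::countl_zero(dec);        // defined at zero (returns 32)
    // Original select-based std::bit_ceil pattern.
    uint32_t sel = (x > 1) ? (1u << (32 - ctlz)) : 1u;
    // Branchless form produced by the fold: 1 << (-ctlz & 31).
    uint32_t masked = (0u - ctlz) & 31u;
    assert(sel == (1u << masked));
  }
  return 0;
}
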
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3253,6 +3448,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
std::swap(NewT, NewF);
Value *NewSI =
Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI);
+ if (Gep->isInBounds())
+ return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI});
return GetElementPtrInst::Create(ElementType, Ptr, {NewSI});
};
if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal))
@@ -3364,25 +3561,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
- auto canMergeSelectThroughBinop = [](BinaryOperator *BO) {
- // The select might be preventing a division by 0.
- switch (BO->getOpcode()) {
- default:
- return true;
- case Instruction::SRem:
- case Instruction::URem:
- case Instruction::SDiv:
- case Instruction::UDiv:
- return false;
- }
- };
-
// Try to simplify a binop sandwiched between 2 selects with the same
- // condition.
+ // condition. This is not valid for div/rem because the select might be
+ // preventing a division-by-zero.
+ // TODO: A div/rem restriction is conservative; use something like
+ // isSafeToSpeculativelyExecute().
// select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z)
BinaryOperator *TrueBO;
- if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) &&
- canMergeSelectThroughBinop(TrueBO)) {
+ if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && !TrueBO->isIntDivRem()) {
if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) {
if (TrueBOSI->getCondition() == CondVal) {
replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue());
@@ -3401,8 +3587,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
// select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W))
BinaryOperator *FalseBO;
- if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) &&
- canMergeSelectThroughBinop(FalseBO)) {
+ if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && !FalseBO->isIntDivRem()) {
if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) {
if (FalseBOSI->getCondition() == CondVal) {
replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue());
@@ -3516,5 +3701,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (sinkNotIntoOtherHandOfLogicalOp(SI))
return &SI;
+ if (Instruction *I = foldBitCeil(SI, Builder))
+ return I;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index ec505381cc86..89dad455f015 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -322,15 +322,20 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
}
-/// If we have a shift-by-constant of a bitwise logic op that itself has a
-/// shift-by-constant operand with identical opcode, we may be able to convert
-/// that into 2 independent shifts followed by the logic op. This eliminates a
-/// a use of an intermediate value (reduces dependency chain).
-static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
+/// If we have a shift-by-constant of a bin op (bitwise logic op or add/sub w/
+/// shl) that itself has a shift-by-constant operand with identical opcode, we
+/// may be able to convert that into 2 independent shifts followed by the logic
+/// op. This eliminates a use of an intermediate value (reduces dependency
+/// chain).
+static Instruction *foldShiftOfShiftedBinOp(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
assert(I.isShift() && "Expected a shift as input");
- auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0));
- if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
+ auto *BinInst = dyn_cast<BinaryOperator>(I.getOperand(0));
+ if (!BinInst ||
+ (!BinInst->isBitwiseLogicOp() &&
+ BinInst->getOpcode() != Instruction::Add &&
+ BinInst->getOpcode() != Instruction::Sub) ||
+ !BinInst->hasOneUse())
return nullptr;
Constant *C0, *C1;
@@ -338,6 +343,12 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
return nullptr;
Instruction::BinaryOps ShiftOpcode = I.getOpcode();
+ // Transform for add/sub only works with shl.
+ if ((BinInst->getOpcode() == Instruction::Add ||
+ BinInst->getOpcode() == Instruction::Sub) &&
+ ShiftOpcode != Instruction::Shl)
+ return nullptr;
+
Type *Ty = I.getType();
// Find a matching one-use shift by constant. The fold is not valid if the sum
@@ -352,19 +363,25 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
};
- // Logic ops are commutative, so check each operand for a match.
- if (matchFirstShift(LogicInst->getOperand(0)))
- Y = LogicInst->getOperand(1);
- else if (matchFirstShift(LogicInst->getOperand(1)))
- Y = LogicInst->getOperand(0);
- else
+ // Logic ops and Add are commutative, so check each operand for a match. Sub
+ // is not, so we cannot reorder if we match operand(1) and need to keep the
+ // operands in their original positions.
+ bool FirstShiftIsOp1 = false;
+ if (matchFirstShift(BinInst->getOperand(0)))
+ Y = BinInst->getOperand(1);
+ else if (matchFirstShift(BinInst->getOperand(1))) {
+ Y = BinInst->getOperand(0);
+ FirstShiftIsOp1 = BinInst->getOpcode() == Instruction::Sub;
+ } else
return nullptr;
- // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
+ // shift (binop (shift X, C0), Y), C1 -> binop (shift X, C0+C1), (shift Y, C1)
Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1);
- return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
+ Value *Op1 = FirstShiftIsOp1 ? NewShift2 : NewShift1;
+ Value *Op2 = FirstShiftIsOp1 ? NewShift1 : NewShift2;
+ return BinaryOperator::Create(BinInst->getOpcode(), Op1, Op2);
}
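
A standalone C++ spot-check (not part of the patch) of the shl-distribution identity that foldShiftOfShiftedBinOp now also applies to add/sub, including the operand-order caveat for sub; the shift amounts satisfy C0 + C1 < 32 as the fold requires, and the test values are arbitrary.

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C0 = 3, C1 = 5;                 // C0 + C1 < 32, as the fold requires
  for (uint32_t X : {0u, 1u, 0x1234u, 0x89ABCDEFu})
    for (uint32_t Y : {0u, 7u, 0x100u, 0xFFFFFFFFu}) {
      // shl distributes over add/sub modulo 2^32.
      assert((((X << C0) + Y) << C1) == ((X << (C0 + C1)) + (Y << C1)));
      assert((((X << C0) - Y) << C1) == ((X << (C0 + C1)) - (Y << C1)));
      // For sub, operand order is preserved when the inner shift is operand 1.
      assert(((Y - (X << C0)) << C1) == ((Y << C1) - (X << (C0 + C1))));
    }
  return 0;
}
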
Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
@@ -463,9 +480,12 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
return replaceOperand(I, 1, Rem);
}
- if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder))
+ if (Instruction *Logic = foldShiftOfShiftedBinOp(I, Builder))
return Logic;
+ if (match(Op1, m_Or(m_Value(), m_SpecificInt(BitWidth - 1))))
+ return replaceOperand(I, 1, ConstantInt::get(Ty, BitWidth - 1));
+
return nullptr;
}
@@ -570,8 +590,7 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
const APInt *MulConst;
// We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`)
return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
- MulConst->isNegatedPowerOf2() &&
- MulConst->countTrailingZeros() == NumBits;
+ MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits;
}
}
}
@@ -900,8 +919,10 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
// Replace the uses of the original add with a zext of the
// NarrowAdd's result. Note that all users at this stage are known to
// be ShAmt-sized truncs, or the lshr itself.
- if (!Add->hasOneUse())
+ if (!Add->hasOneUse()) {
replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty));
+ eraseInstFromFunction(*AddInst);
+ }
// Replace the LShr with a zext of the overflow check.
return new ZExtInst(Overflow, Ty);
@@ -1133,6 +1154,14 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
return BinaryOperator::CreateLShr(
ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
+ // Canonicalize "extract lowest set bit" using cttz to and-with-negate:
+ // 1 << (cttz X) --> -X & X
+ if (match(Op1,
+ m_OneUse(m_Intrinsic<Intrinsic::cttz>(m_Value(X), m_Value())))) {
+ Value *NegX = Builder.CreateNeg(X, "neg");
+ return BinaryOperator::CreateAnd(NegX, X);
+ }
+
// The only way to shift out the 1 is with an over-shift, so that would
// be poison with or without "nuw". Undef is excluded because (undef << X)
// is not undef (it is zero).
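
A standalone C++20 spot-check (not part of the patch) of the new canonicalization 1 << cttz(X) == -X & X for non-zero X; std::countr_zero stands in for llvm.cttz, and the range bound is an arbitrary demo choice.

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 1; x <= (1u << 20); ++x) {   // x == 0 excluded: cttz(0) is a full-width shift
    uint32_t viaShift = 1u << std::countr_zero(x);
    uint32_t viaAnd = (0u - x) & x;              // -X & X
    assert(viaShift == viaAnd);
  }
  return 0;
}
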
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 77d675422966..00eece9534b0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -168,7 +168,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care
// about the high bits of the operands.
auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
@@ -195,7 +195,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
- Known = LHSKnown & RHSKnown;
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, DL, &AC, CxtI, &DT);
// If the client is only demanding bits that we know, return the known
// constant.
@@ -224,7 +225,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
- Known = LHSKnown | RHSKnown;
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, DL, &AC, CxtI, &DT);
// If the client is only demanding bits that we know, return the known
// constant.
@@ -262,7 +264,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
- Known = LHSKnown ^ RHSKnown;
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, DL, &AC, CxtI, &DT);
// If the client is only demanding bits that we know, return the known
// constant.
@@ -381,7 +384,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I;
// Only known if known in both the LHS and RHS.
- Known = KnownBits::commonBits(LHSKnown, RHSKnown);
+ Known = LHSKnown.intersectWith(RHSKnown);
break;
}
case Instruction::Trunc: {
@@ -393,7 +396,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// The shift amount must be valid (not poison) in the narrow type, and
// it must not be greater than the high bits demanded of the result.
if (C->ult(VTy->getScalarSizeInBits()) &&
- C->ule(DemandedMask.countLeadingZeros())) {
+ C->ule(DemandedMask.countl_zero())) {
// trunc (lshr X, C) --> lshr (trunc X), C
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.SetInsertPoint(I);
@@ -508,7 +511,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
@@ -517,7 +520,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If low order bits are not demanded and known to be zero in one operand,
// then we don't need to demand them from the other operand, since they
// can't cause overflow into any bits that are demanded in the result.
- unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes();
+ unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
APInt DemandedFromLHS = DemandedFromOps;
DemandedFromLHS.clearLowBits(NTZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
@@ -539,7 +542,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
case Instruction::Sub: {
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
@@ -548,7 +551,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If low order bits are not demanded and are known to be zero in RHS,
// then we don't need to demand them from LHS, since they can't cause a
// borrow from any bits that are demanded in the result.
- unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes();
+ unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
APInt DemandedFromLHS = DemandedFromOps;
DemandedFromLHS.clearLowBits(NTZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
@@ -578,10 +581,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
// If we demand exactly one bit N and we have "X * (C' << N)" where C' is
// odd (has LSB set), then the left-shifted low bit of X is the answer.
- unsigned CTZ = DemandedMask.countTrailingZeros();
+ unsigned CTZ = DemandedMask.countr_zero();
const APInt *C;
- if (match(I->getOperand(1), m_APInt(C)) &&
- C->countTrailingZeros() == CTZ) {
+ if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) {
Constant *ShiftC = ConstantInt::get(VTy, CTZ);
Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
return InsertNewInstWith(Shl, *I);
@@ -619,7 +621,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
Value *X;
Constant *C;
- if (DemandedMask.countTrailingZeros() >= ShiftAmt &&
+ if (DemandedMask.countr_zero() >= ShiftAmt &&
match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC);
@@ -642,29 +644,15 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
- bool SignBitZero = Known.Zero.isSignBitSet();
- bool SignBitOne = Known.One.isSignBitSet();
- Known.Zero <<= ShiftAmt;
- Known.One <<= ShiftAmt;
- // low bits known zero.
- if (ShiftAmt)
- Known.Zero.setLowBits(ShiftAmt);
-
- // If this shift has "nsw" keyword, then the result is either a poison
- // value or has the same sign bit as the first operand.
- if (IOp->hasNoSignedWrap()) {
- if (SignBitZero)
- Known.Zero.setSignBit();
- else if (SignBitOne)
- Known.One.setSignBit();
- if (Known.hasConflict())
- return UndefValue::get(VTy);
- }
+ Known = KnownBits::shl(Known,
+ KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
+ /* NUW */ IOp->hasNoUnsignedWrap(),
+ /* NSW */ IOp->hasNoSignedWrap());
} else {
// This is a variable shift, so we can't shift the demand mask by a known
// amount. But if we are not demanding high bits, then we are not
// demanding those bits from the pre-shifted operand either.
- if (unsigned CTLZ = DemandedMask.countLeadingZeros()) {
+ if (unsigned CTLZ = DemandedMask.countl_zero()) {
APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) {
// We can't guarantee that nsw/nuw hold after simplifying the operand.
@@ -683,11 +671,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If we are just demanding the shifted sign bit and below, then this can
// be treated as an ASHR in disguise.
- if (DemandedMask.countLeadingZeros() >= ShiftAmt) {
+ if (DemandedMask.countl_zero() >= ShiftAmt) {
// If we only want bits that already match the signbit then we don't
// need to shift.
- unsigned NumHiDemandedBits =
- BitWidth - DemandedMask.countTrailingZeros();
+ unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
unsigned SignBits =
ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
if (SignBits >= NumHiDemandedBits)
@@ -734,7 +721,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If we only want bits that already match the signbit then we don't need
// to shift.
- unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros();
+ unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
if (SignBits >= NumHiDemandedBits)
return I->getOperand(0);
@@ -757,7 +744,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
// If any of the high bits are demanded, we should set the sign bit as
// demanded.
- if (DemandedMask.countLeadingZeros() <= ShiftAmt)
+ if (DemandedMask.countl_zero() <= ShiftAmt)
DemandedMaskIn.setSignBit();
// If the shift is exact, then it does demand the low bits (and knows that
@@ -797,7 +784,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
// TODO: Take the demanded mask of the result into account.
- unsigned RHSTrailingZeros = SA->countTrailingZeros();
+ unsigned RHSTrailingZeros = SA->countr_zero();
APInt DemandedMaskIn =
APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) {
@@ -807,9 +794,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I;
}
- // Increase high zero bits from the input.
- Known.Zero.setHighBits(std::min(
- BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros));
+ Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA),
+ cast<BinaryOperator>(I)->isExact());
} else {
computeKnownBits(I, Known, Depth, CxtI);
}
@@ -851,25 +837,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
}
- // The sign bit is the LHS's sign bit, except when the result of the
- // remainder is zero.
- if (DemandedMask.isSignBitSet()) {
- computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
- // If it's known zero, our sign bit is also zero.
- if (LHSKnown.isNonNegative())
- Known.makeNonNegative();
- }
+ computeKnownBits(I, Known, Depth, CxtI);
break;
}
case Instruction::URem: {
- KnownBits Known2(BitWidth);
APInt AllOnes = APInt::getAllOnes(BitWidth);
- if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) ||
- SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, AllOnes, LHSKnown, Depth + 1) ||
+ SimplifyDemandedBits(I, 1, AllOnes, RHSKnown, Depth + 1))
return I;
- unsigned Leaders = Known2.countMinLeadingZeros();
- Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
+ Known = KnownBits::urem(LHSKnown, RHSKnown);
break;
}
case Instruction::Call: {
@@ -897,8 +874,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
case Intrinsic::bswap: {
// If the only bits demanded come from one byte of the bswap result,
// just shift the input byte into position to eliminate the bswap.
- unsigned NLZ = DemandedMask.countLeadingZeros();
- unsigned NTZ = DemandedMask.countTrailingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
+ unsigned NTZ = DemandedMask.countr_zero();
// Round NTZ down to the next byte. If we have 11 trailing zeros, then
// we need all the bits down to bit 8. Likewise, round NLZ. If we
@@ -935,9 +912,28 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt));
APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
- if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) ||
- SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
- return I;
+ if (I->getOperand(0) != I->getOperand(1)) {
+ if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
+ return I;
+ } else { // fshl is a rotate
+ // Avoid converting rotate into funnel shift.
+ // Only simplify if one operand is constant.
+ LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I);
+ if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) &&
+ !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) {
+ replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One));
+ return I;
+ }
+
+ RHSKnown = computeKnownBits(I->getOperand(1), Depth + 1, I);
+ if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) &&
+ !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) {
+ replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One));
+ return I;
+ }
+ }
Known.Zero = LHSKnown.Zero.shl(ShiftAmt) |
RHSKnown.Zero.lshr(BitWidth - ShiftAmt);
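For context on the rotate special case above, here is a standalone model (the fshl32 helper is hypothetical, i32 assumed) of the funnel-shift semantics the demanded-bits split uses; when both operands are the same value the intrinsic is a rotate, which the new code avoids breaking apart into two different operands:

  #include <cassert>
  #include <cstdint>

  static uint32_t fshl32(uint32_t a, uint32_t b, unsigned s) {
    s &= 31; // the intrinsic takes the shift amount modulo the bit width
    if (s == 0)
      return a;
    return (a << s) | (b >> (32 - s));
  }

  int main() {
    uint32_t a = 0x12345678, b = 0x9ABCDEF0;
    // Distinct operands: a supplies the high bits, b the low bits.
    assert(fshl32(a, b, 8) == 0x3456789A);
    // Equal operands: fshl(x, x, s) is rotl(x, s).
    assert(fshl32(a, a, 8) == 0x34567812);
    return 0;
  }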
@@ -951,7 +947,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// The lowest non-zero bit of DemandMask is higher than the highest
// non-zero bit of C.
const APInt *C;
- unsigned CTZ = DemandedMask.countTrailingZeros();
+ unsigned CTZ = DemandedMask.countr_zero();
if (match(II->getArgOperand(1), m_APInt(C)) &&
CTZ >= C->getActiveBits())
return II->getArgOperand(0);
@@ -963,9 +959,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// non-one bit of C.
// This comes from using DeMorgans on the above umax example.
const APInt *C;
- unsigned CTZ = DemandedMask.countTrailingZeros();
+ unsigned CTZ = DemandedMask.countr_zero();
if (match(II->getArgOperand(1), m_APInt(C)) &&
- CTZ >= C->getBitWidth() - C->countLeadingOnes())
+ CTZ >= C->getBitWidth() - C->countl_one())
return II->getArgOperand(0);
break;
}
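The umin counterpart above only fires when every demanded bit falls in a region where the constant is all ones. A brute-force i8 check (illustrative values, not LLVM code: C = 0xF3 has four leading ones, and DemandedMask = 0xB0 has countr_zero of 4, satisfying the guard) of why returning the first operand is then safe:

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  int main() {
    const uint8_t C = 0xF3, DemandedMask = 0xB0;
    for (unsigned x = 0; x < 256; ++x) {
      // Within the demanded bits, umin(x, C) and x are bit-for-bit identical,
      // because whenever x > C the top bits of x must match C's all-ones run.
      uint8_t umin = std::min<uint8_t>((uint8_t)x, C);
      assert((umin & DemandedMask) == (x & DemandedMask));
    }
    return 0;
  }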
@@ -1014,6 +1010,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown & RHSKnown;
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1033,6 +1030,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown | RHSKnown;
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1054,6 +1052,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown ^ RHSKnown;
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1071,7 +1070,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
break;
}
case Instruction::Add: {
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
// If an operand adds zeros to every bit below the highest demanded bit,
@@ -1084,10 +1083,13 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
+ bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+ Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown);
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
break;
}
case Instruction::Sub: {
- unsigned NLZ = DemandedMask.countLeadingZeros();
+ unsigned NLZ = DemandedMask.countl_zero();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
// If an operand subtracts zeros from every bit below the highest demanded
@@ -1096,6 +1098,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
+ bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+ computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+ Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown);
+ computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
break;
}
case Instruction::AShr: {
@@ -1541,7 +1547,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// Found constant vector with single element - convert to insertelement.
if (Op && Value) {
Instruction *New = InsertElementInst::Create(
- Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx),
+ Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx),
Shuffle->getName());
InsertNewInstWith(New, *Shuffle);
return New;
@@ -1552,7 +1558,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
SmallVector<int, 16> Elts;
for (unsigned i = 0; i < VWidth; ++i) {
if (UndefElts[i])
- Elts.push_back(UndefMaskElem);
+ Elts.push_back(PoisonMaskElem);
else
Elts.push_back(Shuffle->getMaskValue(i));
}
@@ -1653,7 +1659,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// corresponding input elements are undef.
for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
- if (SubUndef.countPopulation() == Ratio)
+ if (SubUndef.popcount() == Ratio)
UndefElts.setBit(OutIdx);
}
} else {
@@ -1712,6 +1718,54 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// UB/poison potential, but that should be refined.
BinaryOperator *BO;
if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) {
+ Value *X = BO->getOperand(0);
+ Value *Y = BO->getOperand(1);
+
+ // Look for an equivalent binop except that one operand has been shuffled.
+ // If the demand for this binop only includes elements that are the same as
+ // the other binop, then we may be able to replace this binop with a use of
+ // the earlier one.
+ //
+ // Example:
+ // %other_bo = bo (shuf X, {0}), Y
+ // %this_extracted_bo = extelt (bo X, Y), 0
+ // -->
+ // %other_bo = bo (shuf X, {0}), Y
+ // %this_extracted_bo = extelt %other_bo, 0
+ //
+ // TODO: Handle demand of an arbitrary single element or more than one
+ // element instead of just element 0.
+ // TODO: Unlike general demanded elements transforms, this should be safe
+ // for any (div/rem/shift) opcode too.
+ if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() &&
+ BO->hasOneUse()) {
+
+ auto findShufBO = [&](bool MatchShufAsOp0) -> User * {
+ // Try to use shuffle-of-operand in place of an operand:
+ // bo X, Y --> bo (shuf X), Y
+ // bo X, Y --> bo X, (shuf Y)
+ BinaryOperator::BinaryOps Opcode = BO->getOpcode();
+ Value *ShufOp = MatchShufAsOp0 ? X : Y;
+ Value *OtherOp = MatchShufAsOp0 ? Y : X;
+ for (User *U : OtherOp->users()) {
+ auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_ZeroMask());
+ if (BO->isCommutative()
+ ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
+ : MatchShufAsOp0
+ ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
+ : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf)))
+ if (DT.dominates(U, I))
+ return U;
+ }
+ return nullptr;
+ };
+
+ if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true))
+ return ShufBO;
+ if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false))
+ return ShufBO;
+ }
+
simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
simplifyAndSetOp(I, 1, DemandedElts, UndefElts2);
@@ -1723,7 +1777,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// If we've proven all of the lanes undef, return an undef value.
// TODO: Intersect w/demanded lanes
if (UndefElts.isAllOnes())
- return UndefValue::get(I->getType());;
+ return UndefValue::get(I->getType());
return MadeChange ? I : nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 61e62adbe327..4a5ffef2b08e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -171,8 +171,11 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
}
}
- for (auto *E : Extracts)
+ for (auto *E : Extracts) {
replaceInstUsesWith(*E, scalarPHI);
+ // Add old extract to worklist for DCE.
+ addToWorklist(E);
+ }
return &EI;
}
@@ -384,7 +387,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
/// return it with the canonical type if it isn't already canonical. We
/// arbitrarily pick 64 bit as our canonical type. The actual bitwidth doesn't
/// matter, we just want a consistent type to simplify CSE.
-ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) {
+static ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) {
const unsigned IndexBW = IndexC->getType()->getBitWidth();
if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64)
return nullptr;
@@ -543,16 +546,16 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
->getNumElements();
if (SrcIdx < 0)
- return replaceInstUsesWith(EI, UndefValue::get(EI.getType()));
+ return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
if (SrcIdx < (int)LHSWidth)
Src = SVI->getOperand(0);
else {
SrcIdx -= LHSWidth;
Src = SVI->getOperand(1);
}
- Type *Int32Ty = Type::getInt32Ty(EI.getContext());
+ Type *Int64Ty = Type::getInt64Ty(EI.getContext());
return ExtractElementInst::Create(
- Src, ConstantInt::get(Int32Ty, SrcIdx, false));
+ Src, ConstantInt::get(Int64Ty, SrcIdx, false));
}
} else if (auto *CI = dyn_cast<CastInst>(I)) {
// Canonicalize extractelement(cast) -> cast(extractelement).
@@ -594,6 +597,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
true /* AllowMultipleUsers */)) {
if (V != SrcVec) {
+ Worklist.addValue(SrcVec);
SrcVec->replaceAllUsesWith(V);
return &EI;
}
@@ -640,11 +644,11 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
return false;
unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue();
- if (isa<UndefValue>(ScalarOp)) { // inserting undef into vector.
+ if (isa<PoisonValue>(ScalarOp)) { // inserting poison into vector.
// We can handle this if the vector we are inserting into is
// transitively ok.
if (collectSingleShuffleElements(VecOp, LHS, RHS, Mask)) {
- // If so, update the mask to reflect the inserted undef.
+ // If so, update the mask to reflect the inserted poison.
Mask[InsertedIdx] = -1;
return true;
}
@@ -680,7 +684,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
/// If we have insertion into a vector that is wider than the vector that we
/// are extracting from, try to widen the source vector to allow a single
/// shufflevector to replace one or more insert/extract pairs.
-static void replaceExtractElements(InsertElementInst *InsElt,
+static bool replaceExtractElements(InsertElementInst *InsElt,
ExtractElementInst *ExtElt,
InstCombinerImpl &IC) {
auto *InsVecType = cast<FixedVectorType>(InsElt->getType());
@@ -691,7 +695,7 @@ static void replaceExtractElements(InsertElementInst *InsElt,
// The inserted-to vector must be wider than the extracted-from vector.
if (InsVecType->getElementType() != ExtVecType->getElementType() ||
NumExtElts >= NumInsElts)
- return;
+ return false;
// Create a shuffle mask to widen the extended-from vector using poison
// values. The mask selects all of the values of the original vector followed
@@ -719,7 +723,7 @@ static void replaceExtractElements(InsertElementInst *InsElt,
// that will delete our widening shuffle. This would trigger another attempt
// here to create that shuffle, and we spin forever.
if (InsertionBlock != InsElt->getParent())
- return;
+ return false;
// TODO: This restriction matches the check in visitInsertElementInst() and
// prevents an infinite loop caused by not turning the extract/insert pair
@@ -727,7 +731,7 @@ static void replaceExtractElements(InsertElementInst *InsElt,
// folds for shufflevectors because we're afraid to generate shuffle masks
// that the backend can't handle.
if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back()))
- return;
+ return false;
auto *WideVec = new ShuffleVectorInst(ExtVecOp, ExtendMask);
@@ -747,9 +751,14 @@ static void replaceExtractElements(InsertElementInst *InsElt,
if (!OldExt || OldExt->getParent() != WideVec->getParent())
continue;
auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1));
- NewExt->insertAfter(OldExt);
+ IC.InsertNewInstWith(NewExt, *OldExt);
IC.replaceInstUsesWith(*OldExt, NewExt);
+ // Add the old extracts to the worklist for DCE. We can't remove the
+ // extracts directly, because they may still be used by the calling code.
+ IC.addToWorklist(OldExt);
}
+
+ return true;
}
/// We are building a shuffle to create V, which is a sequence of insertelement,
@@ -764,7 +773,7 @@ using ShuffleOps = std::pair<Value *, Value *>;
static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
Value *PermittedRHS,
- InstCombinerImpl &IC) {
+ InstCombinerImpl &IC, bool &Rerun) {
assert(V->getType()->isVectorTy() && "Invalid shuffle!");
unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements();
@@ -795,13 +804,14 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
// otherwise we'd end up with a shuffle of three inputs.
if (EI->getOperand(0) == PermittedRHS || PermittedRHS == nullptr) {
Value *RHS = EI->getOperand(0);
- ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC);
+ ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC, Rerun);
assert(LR.second == nullptr || LR.second == RHS);
if (LR.first->getType() != RHS->getType()) {
// Although we are giving up for now, see if we can create extracts
// that match the inserts for another round of combining.
- replaceExtractElements(IEI, EI, IC);
+ if (replaceExtractElements(IEI, EI, IC))
+ Rerun = true;
// We tried our best, but we can't find anything compatible with RHS
// further up the chain. Return a trivial shuffle.
@@ -1129,6 +1139,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
/// It should be transformed to:
/// %0 = insertvalue { i8, i32 } undef, i8 %y, 0
Instruction *InstCombinerImpl::visitInsertValueInst(InsertValueInst &I) {
+ if (Value *V = simplifyInsertValueInst(
+ I.getAggregateOperand(), I.getInsertedValueOperand(), I.getIndices(),
+ SQ.getWithInstruction(&I)))
+ return replaceInstUsesWith(I, V);
+
bool IsRedundant = false;
ArrayRef<unsigned int> FirstIndices = I.getIndices();
@@ -1235,22 +1250,22 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) {
if (FirstIE == &InsElt)
return nullptr;
- // If we are not inserting into an undef vector, make sure we've seen an
+ // If we are not inserting into a poison vector, make sure we've seen an
// insert into every element.
// TODO: If the base vector is not undef, it might be better to create a splat
// and then a select-shuffle (blend) with the base vector.
- if (!match(FirstIE->getOperand(0), m_Undef()))
+ if (!match(FirstIE->getOperand(0), m_Poison()))
if (!ElementPresent.all())
return nullptr;
// Create the insert + shuffle.
- Type *Int32Ty = Type::getInt32Ty(InsElt.getContext());
+ Type *Int64Ty = Type::getInt64Ty(InsElt.getContext());
PoisonValue *PoisonVec = PoisonValue::get(VecTy);
- Constant *Zero = ConstantInt::get(Int32Ty, 0);
+ Constant *Zero = ConstantInt::get(Int64Ty, 0);
if (!cast<ConstantInt>(FirstIE->getOperand(2))->isZero())
FirstIE = InsertElementInst::Create(PoisonVec, SplatVal, Zero, "", &InsElt);
- // Splat from element 0, but replace absent elements with undef in the mask.
+ // Splat from element 0, but replace absent elements with poison in the mask.
SmallVector<int, 16> Mask(NumElements, 0);
for (unsigned i = 0; i != NumElements; ++i)
if (!ElementPresent[i])
@@ -1339,7 +1354,7 @@ static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) {
// (demanded elements analysis may unset it later).
return nullptr;
} else {
- assert(OldMask[i] == UndefMaskElem &&
+ assert(OldMask[i] == PoisonMaskElem &&
"Unexpected shuffle mask element for identity shuffle");
NewMask[i] = IdxC;
}
@@ -1465,10 +1480,10 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) {
}
++ValI;
}
- // Remaining values are filled with 'undef' values.
+ // Remaining values are filled with 'poison' values.
for (unsigned I = 0; I < NumElts; ++I) {
if (!Values[I]) {
- Values[I] = UndefValue::get(InsElt.getType()->getElementType());
+ Values[I] = PoisonValue::get(InsElt.getType()->getElementType());
Mask[I] = I;
}
}
@@ -1676,16 +1691,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
// Try to form a shuffle from a chain of extract-insert ops.
if (isShuffleRootCandidate(IE)) {
- SmallVector<int, 16> Mask;
- ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this);
-
- // The proposed shuffle may be trivial, in which case we shouldn't
- // perform the combine.
- if (LR.first != &IE && LR.second != &IE) {
- // We now have a shuffle of LHS, RHS, Mask.
- if (LR.second == nullptr)
- LR.second = UndefValue::get(LR.first->getType());
- return new ShuffleVectorInst(LR.first, LR.second, Mask);
+ bool Rerun = true;
+ while (Rerun) {
+ Rerun = false;
+
+ SmallVector<int, 16> Mask;
+ ShuffleOps LR =
+ collectShuffleElements(&IE, Mask, nullptr, *this, Rerun);
+
+ // The proposed shuffle may be trivial, in which case we shouldn't
+ // perform the combine.
+ if (LR.first != &IE && LR.second != &IE) {
+ // We now have a shuffle of LHS, RHS, Mask.
+ if (LR.second == nullptr)
+ LR.second = PoisonValue::get(LR.first->getType());
+ return new ShuffleVectorInst(LR.first, LR.second, Mask);
+ }
}
}
}
@@ -1815,9 +1836,9 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
/// Rebuild a new instruction just like 'I' but with the new operands given.
/// In the event of type mismatch, the type of the operands is correct.
-static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) {
- // We don't want to use the IRBuilder here because we want the replacement
- // instructions to appear next to 'I', not the builder's insertion point.
+static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps,
+ IRBuilderBase &Builder) {
+ Builder.SetInsertPoint(I);
switch (I->getOpcode()) {
case Instruction::Add:
case Instruction::FAdd:
@@ -1839,28 +1860,29 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) {
case Instruction::Xor: {
BinaryOperator *BO = cast<BinaryOperator>(I);
assert(NewOps.size() == 2 && "binary operator with #ops != 2");
- BinaryOperator *New =
- BinaryOperator::Create(cast<BinaryOperator>(I)->getOpcode(),
- NewOps[0], NewOps[1], "", BO);
- if (isa<OverflowingBinaryOperator>(BO)) {
- New->setHasNoUnsignedWrap(BO->hasNoUnsignedWrap());
- New->setHasNoSignedWrap(BO->hasNoSignedWrap());
- }
- if (isa<PossiblyExactOperator>(BO)) {
- New->setIsExact(BO->isExact());
+ Value *New = Builder.CreateBinOp(cast<BinaryOperator>(I)->getOpcode(),
+ NewOps[0], NewOps[1]);
+ if (auto *NewI = dyn_cast<Instruction>(New)) {
+ if (isa<OverflowingBinaryOperator>(BO)) {
+ NewI->setHasNoUnsignedWrap(BO->hasNoUnsignedWrap());
+ NewI->setHasNoSignedWrap(BO->hasNoSignedWrap());
+ }
+ if (isa<PossiblyExactOperator>(BO)) {
+ NewI->setIsExact(BO->isExact());
+ }
+ if (isa<FPMathOperator>(BO))
+ NewI->copyFastMathFlags(I);
}
- if (isa<FPMathOperator>(BO))
- New->copyFastMathFlags(I);
return New;
}
case Instruction::ICmp:
assert(NewOps.size() == 2 && "icmp with #ops != 2");
- return new ICmpInst(I, cast<ICmpInst>(I)->getPredicate(),
- NewOps[0], NewOps[1]);
+ return Builder.CreateICmp(cast<ICmpInst>(I)->getPredicate(), NewOps[0],
+ NewOps[1]);
case Instruction::FCmp:
assert(NewOps.size() == 2 && "fcmp with #ops != 2");
- return new FCmpInst(I, cast<FCmpInst>(I)->getPredicate(),
- NewOps[0], NewOps[1]);
+ return Builder.CreateFCmp(cast<FCmpInst>(I)->getPredicate(), NewOps[0],
+ NewOps[1]);
case Instruction::Trunc:
case Instruction::ZExt:
case Instruction::SExt:
@@ -1876,27 +1898,26 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) {
I->getType()->getScalarType(),
cast<VectorType>(NewOps[0]->getType())->getElementCount());
assert(NewOps.size() == 1 && "cast with #ops != 1");
- return CastInst::Create(cast<CastInst>(I)->getOpcode(), NewOps[0], DestTy,
- "", I);
+ return Builder.CreateCast(cast<CastInst>(I)->getOpcode(), NewOps[0],
+ DestTy);
}
case Instruction::GetElementPtr: {
Value *Ptr = NewOps[0];
ArrayRef<Value*> Idx = NewOps.slice(1);
- GetElementPtrInst *GEP = GetElementPtrInst::Create(
- cast<GetElementPtrInst>(I)->getSourceElementType(), Ptr, Idx, "", I);
- GEP->setIsInBounds(cast<GetElementPtrInst>(I)->isInBounds());
- return GEP;
+ return Builder.CreateGEP(cast<GEPOperator>(I)->getSourceElementType(),
+ Ptr, Idx, "",
+ cast<GEPOperator>(I)->isInBounds());
}
}
llvm_unreachable("failed to rebuild vector instructions");
}
-static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) {
+static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask,
+ IRBuilderBase &Builder) {
// Mask.size() does not need to be equal to the number of vector elements.
assert(V->getType()->isVectorTy() && "can't reorder non-vector elements");
Type *EltTy = V->getType()->getScalarType();
- Type *I32Ty = IntegerType::getInt32Ty(V->getContext());
if (match(V, m_Undef()))
return UndefValue::get(FixedVectorType::get(EltTy, Mask.size()));
@@ -1950,15 +1971,14 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) {
// as well. E.g. GetElementPtr may have scalar operands even if the
// return value is a vector, so we need to examine the operand type.
if (I->getOperand(i)->getType()->isVectorTy())
- V = evaluateInDifferentElementOrder(I->getOperand(i), Mask);
+ V = evaluateInDifferentElementOrder(I->getOperand(i), Mask, Builder);
else
V = I->getOperand(i);
NewOps.push_back(V);
NeedsRebuild |= (V != I->getOperand(i));
}
- if (NeedsRebuild) {
- return buildNew(I, NewOps);
- }
+ if (NeedsRebuild)
+ return buildNew(I, NewOps, Builder);
return I;
}
case Instruction::InsertElement: {
@@ -1979,11 +1999,12 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) {
// If element is not in Mask, no need to handle the operand 1 (element to
// be inserted). Just evaluate values in operand 0 according to Mask.
if (!Found)
- return evaluateInDifferentElementOrder(I->getOperand(0), Mask);
+ return evaluateInDifferentElementOrder(I->getOperand(0), Mask, Builder);
- Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask);
- return InsertElementInst::Create(V, I->getOperand(1),
- ConstantInt::get(I32Ty, Index), "", I);
+ Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask,
+ Builder);
+ Builder.SetInsertPoint(I);
+ return Builder.CreateInsertElement(V, I->getOperand(1), Index);
}
}
llvm_unreachable("failed to reorder elements of vector instruction!");
@@ -2140,7 +2161,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) {
ConstantExpr::getShuffleVector(IdC, C, Mask);
bool MightCreatePoisonOrUB =
- is_contained(Mask, UndefMaskElem) &&
+ is_contained(Mask, PoisonMaskElem) &&
(Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode));
if (MightCreatePoisonOrUB)
NewC = InstCombiner::getSafeVectorConstantForBinop(BOpcode, NewC, true);
@@ -2154,7 +2175,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) {
// An undef shuffle mask element may propagate as an undef constant element in
// the new binop. That would produce poison where the original code might not.
// If we already made a safe constant, then there's no danger.
- if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB)
+ if (is_contained(Mask, PoisonMaskElem) && !MightCreatePoisonOrUB)
NewBO->dropPoisonGeneratingFlags();
return NewBO;
}
@@ -2178,8 +2199,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf,
// Insert into element 0 of an undef vector.
UndefValue *UndefVec = UndefValue::get(Shuf.getType());
- Constant *Zero = Builder.getInt32(0);
- Value *NewIns = Builder.CreateInsertElement(UndefVec, X, Zero);
+ Value *NewIns = Builder.CreateInsertElement(UndefVec, X, (uint64_t)0);
// Splat from element 0. Any mask element that is undefined remains undefined.
// For example:
@@ -2189,7 +2209,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf,
cast<FixedVectorType>(Shuf.getType())->getNumElements();
SmallVector<int, 16> NewMask(NumMaskElts, 0);
for (unsigned i = 0; i != NumMaskElts; ++i)
- if (Mask[i] == UndefMaskElem)
+ if (Mask[i] == PoisonMaskElem)
NewMask[i] = Mask[i];
return new ShuffleVectorInst(NewIns, NewMask);
@@ -2274,7 +2294,7 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) {
// mask element, the result is undefined, but it is not poison or undefined
// behavior. That is not necessarily true for div/rem/shift.
bool MightCreatePoisonOrUB =
- is_contained(Mask, UndefMaskElem) &&
+ is_contained(Mask, PoisonMaskElem) &&
(Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc));
if (MightCreatePoisonOrUB)
NewC = InstCombiner::getSafeVectorConstantForBinop(BOpc, NewC,
@@ -2325,7 +2345,7 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) {
NewI->andIRFlags(B1);
if (DropNSW)
NewI->setHasNoSignedWrap(false);
- if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB)
+ if (is_contained(Mask, PoisonMaskElem) && !MightCreatePoisonOrUB)
NewI->dropPoisonGeneratingFlags();
}
return replaceInstUsesWith(Shuf, NewBO);
@@ -2361,7 +2381,7 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf,
SrcType->getScalarSizeInBits() / DestType->getScalarSizeInBits();
ArrayRef<int> Mask = Shuf.getShuffleMask();
for (unsigned i = 0, e = Mask.size(); i != e; ++i) {
- if (Mask[i] == UndefMaskElem)
+ if (Mask[i] == PoisonMaskElem)
continue;
uint64_t LSBIndex = IsBigEndian ? (i + 1) * TruncRatio - 1 : i * TruncRatio;
assert(LSBIndex <= INT32_MAX && "Overflowed 32-bits");
@@ -2407,37 +2427,51 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf,
return SelectInst::Create(NarrowCond, NarrowX, NarrowY);
}
-/// Canonicalize FP negate after shuffle.
-static Instruction *foldFNegShuffle(ShuffleVectorInst &Shuf,
- InstCombiner::BuilderTy &Builder) {
- Instruction *FNeg0;
+/// Canonicalize FP negate/abs after shuffle.
+static Instruction *foldShuffleOfUnaryOps(ShuffleVectorInst &Shuf,
+ InstCombiner::BuilderTy &Builder) {
+ auto *S0 = dyn_cast<Instruction>(Shuf.getOperand(0));
Value *X;
- if (!match(Shuf.getOperand(0), m_CombineAnd(m_Instruction(FNeg0),
- m_FNeg(m_Value(X)))))
+ if (!S0 || !match(S0, m_CombineOr(m_FNeg(m_Value(X)), m_FAbs(m_Value(X)))))
return nullptr;
- // shuffle (fneg X), Mask --> fneg (shuffle X, Mask)
- if (FNeg0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) {
+ bool IsFNeg = S0->getOpcode() == Instruction::FNeg;
+
+ // Match 1-input (unary) shuffle.
+ // shuffle (fneg/fabs X), Mask --> fneg/fabs (shuffle X, Mask)
+ if (S0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) {
Value *NewShuf = Builder.CreateShuffleVector(X, Shuf.getShuffleMask());
- return UnaryOperator::CreateFNegFMF(NewShuf, FNeg0);
+ if (IsFNeg)
+ return UnaryOperator::CreateFNegFMF(NewShuf, S0);
+
+ Function *FAbs = Intrinsic::getDeclaration(Shuf.getModule(),
+ Intrinsic::fabs, Shuf.getType());
+ CallInst *NewF = CallInst::Create(FAbs, {NewShuf});
+ NewF->setFastMathFlags(S0->getFastMathFlags());
+ return NewF;
}
- Instruction *FNeg1;
+ // Match 2-input (binary) shuffle.
+ auto *S1 = dyn_cast<Instruction>(Shuf.getOperand(1));
Value *Y;
- if (!match(Shuf.getOperand(1), m_CombineAnd(m_Instruction(FNeg1),
- m_FNeg(m_Value(Y)))))
+ if (!S1 || !match(S1, m_CombineOr(m_FNeg(m_Value(Y)), m_FAbs(m_Value(Y)))) ||
+ S0->getOpcode() != S1->getOpcode() ||
+ (!S0->hasOneUse() && !S1->hasOneUse()))
return nullptr;
- // shuffle (fneg X), (fneg Y), Mask --> fneg (shuffle X, Y, Mask)
- if (FNeg0->hasOneUse() || FNeg1->hasOneUse()) {
- Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask());
- Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewShuf);
- NewFNeg->copyIRFlags(FNeg0);
- NewFNeg->andIRFlags(FNeg1);
- return NewFNeg;
+ // shuf (fneg/fabs X), (fneg/fabs Y), Mask --> fneg/fabs (shuf X, Y, Mask)
+ Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask());
+ Instruction *NewF;
+ if (IsFNeg) {
+ NewF = UnaryOperator::CreateFNeg(NewShuf);
+ } else {
+ Function *FAbs = Intrinsic::getDeclaration(Shuf.getModule(),
+ Intrinsic::fabs, Shuf.getType());
+ NewF = CallInst::Create(FAbs, {NewShuf});
}
-
- return nullptr;
+ NewF->copyIRFlags(S0);
+ NewF->andIRFlags(S1);
+ return NewF;
}
/// Canonicalize casts after shuffle.
@@ -2533,7 +2567,7 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) {
for (unsigned i = 0; i != NumElts; ++i) {
int ExtractMaskElt = Shuf.getMaskValue(i);
int MaskElt = Mask[i];
- NewMask[i] = ExtractMaskElt == UndefMaskElem ? ExtractMaskElt : MaskElt;
+ NewMask[i] = ExtractMaskElt == PoisonMaskElem ? ExtractMaskElt : MaskElt;
}
return new ShuffleVectorInst(X, Y, NewMask);
}
@@ -2699,7 +2733,8 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
// splatting the first element of the result of the BinOp
Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) {
if (!match(SVI.getOperand(1), m_Undef()) ||
- !match(SVI.getShuffleMask(), m_ZeroMask()))
+ !match(SVI.getShuffleMask(), m_ZeroMask()) ||
+ !SVI.getOperand(0)->hasOneUse())
return nullptr;
Value *Op0 = SVI.getOperand(0);
@@ -2759,7 +2794,6 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
}
ArrayRef<int> Mask = SVI.getShuffleMask();
- Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
// Peek through a bitcasted shuffle operand by scaling the mask. If the
// simulated shuffle can simplify, then this shuffle is unnecessary:
@@ -2815,7 +2849,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
if (Instruction *I = narrowVectorSelect(SVI, Builder))
return I;
- if (Instruction *I = foldFNegShuffle(SVI, Builder))
+ if (Instruction *I = foldShuffleOfUnaryOps(SVI, Builder))
return I;
if (Instruction *I = foldCastShuffle(SVI, Builder))
@@ -2840,7 +2874,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
return I;
if (match(RHS, m_Undef()) && canEvaluateShuffled(LHS, Mask)) {
- Value *V = evaluateInDifferentElementOrder(LHS, Mask);
+ Value *V = evaluateInDifferentElementOrder(LHS, Mask, Builder);
return replaceInstUsesWith(SVI, V);
}
@@ -2916,15 +2950,15 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
unsigned SrcElemsPerTgtElem = TgtElemBitWidth / SrcElemBitWidth;
assert(SrcElemsPerTgtElem);
BegIdx /= SrcElemsPerTgtElem;
- bool BCAlreadyExists = NewBCs.find(CastSrcTy) != NewBCs.end();
+ bool BCAlreadyExists = NewBCs.contains(CastSrcTy);
auto *NewBC =
BCAlreadyExists
? NewBCs[CastSrcTy]
: Builder.CreateBitCast(V, CastSrcTy, SVI.getName() + ".bc");
if (!BCAlreadyExists)
NewBCs[CastSrcTy] = NewBC;
- auto *Ext = Builder.CreateExtractElement(
- NewBC, ConstantInt::get(Int32Ty, BegIdx), SVI.getName() + ".extract");
+ auto *Ext = Builder.CreateExtractElement(NewBC, BegIdx,
+ SVI.getName() + ".extract");
// The shufflevector isn't being replaced: the bitcast that used it
// is. InstCombine will visit the newly-created instructions.
replaceInstUsesWith(*BC, Ext);
@@ -3042,7 +3076,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
for (unsigned i = 0; i < VWidth; ++i) {
int eltMask;
if (Mask[i] < 0) {
- // This element is an undef value.
+ // This element is a poison value.
eltMask = -1;
} else if (Mask[i] < (int)LHSWidth) {
// This element is from left hand side vector operand.
@@ -3051,27 +3085,27 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
// new mask value for the element.
if (newLHS != LHS) {
eltMask = LHSMask[Mask[i]];
- // If the value selected is an undef value, explicitly specify it
+ // If the value selected is a poison value, explicitly specify it
// with a -1 mask value.
- if (eltMask >= (int)LHSOp0Width && isa<UndefValue>(LHSOp1))
+ if (eltMask >= (int)LHSOp0Width && isa<PoisonValue>(LHSOp1))
eltMask = -1;
} else
eltMask = Mask[i];
} else {
// This element is from right hand side vector operand
//
- // If the value selected is an undef value, explicitly specify it
+ // If the value selected is a poison value, explicitly specify it
// with a -1 mask value. (case 1)
- if (match(RHS, m_Undef()))
+ if (match(RHS, m_Poison()))
eltMask = -1;
// If RHS is going to be replaced (case 3 or 4), calculate the
// new mask value for the element.
else if (newRHS != RHS) {
eltMask = RHSMask[Mask[i]-LHSWidth];
- // If the value selected is an undef value, explicitly specify it
+ // If the value selected is a poison value, explicitly specify it
// with a -1 mask value.
if (eltMask >= (int)RHSOp0Width) {
- assert(match(RHSShuffle->getOperand(1), m_Undef()) &&
+ assert(match(RHSShuffle->getOperand(1), m_Poison()) &&
"should have been check above");
eltMask = -1;
}
@@ -3102,7 +3136,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
// or is a splat, do the replacement.
if (isSplat || newMask == LHSMask || newMask == RHSMask || newMask == Mask) {
if (!newRHS)
- newRHS = UndefValue::get(newLHS->getType());
+ newRHS = PoisonValue::get(newLHS->getType());
return new ShuffleVectorInst(newLHS, newRHS, newMask);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index fb6f4f96ea48..afd6e034f46d 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -33,8 +33,6 @@
//===----------------------------------------------------------------------===//
#include "InstCombineInternal.h"
-#include "llvm-c/Initialization.h"
-#include "llvm-c/Transforms/InstCombine.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
@@ -47,7 +45,6 @@
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/ConstantFolding.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
@@ -70,6 +67,7 @@
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
@@ -78,7 +76,6 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
-#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
@@ -117,6 +114,11 @@ using namespace llvm::PatternMatch;
STATISTIC(NumWorklistIterations,
"Number of instruction combining iterations performed");
+STATISTIC(NumOneIteration, "Number of functions with one iteration");
+STATISTIC(NumTwoIterations, "Number of functions with two iterations");
+STATISTIC(NumThreeIterations, "Number of functions with three iterations");
+STATISTIC(NumFourOrMoreIterations,
+ "Number of functions with four or more iterations");
STATISTIC(NumCombined , "Number of insts combined");
STATISTIC(NumConstProp, "Number of constant folds");
@@ -129,7 +131,6 @@ DEBUG_COUNTER(VisitCounter, "instcombine-visit",
"Controls which instructions are visited");
// FIXME: these limits eventually should be as low as 2.
-static constexpr unsigned InstCombineDefaultMaxIterations = 1000;
#ifndef NDEBUG
static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100;
#else
@@ -144,11 +145,6 @@ static cl::opt<unsigned> MaxSinkNumUsers(
"instcombine-max-sink-users", cl::init(32),
cl::desc("Maximum number of undroppable users for instruction sinking"));
-static cl::opt<unsigned> LimitMaxIterations(
- "instcombine-max-iterations",
- cl::desc("Limit the maximum number of instruction combining iterations"),
- cl::init(InstCombineDefaultMaxIterations));
-
static cl::opt<unsigned> InfiniteLoopDetectionThreshold(
"instcombine-infinite-loop-threshold",
cl::desc("Number of instruction combining iterations considered an "
@@ -203,6 +199,10 @@ std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
return std::nullopt;
}
+bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+ return TTI.isValidAddrSpaceCast(FromAS, ToAS);
+}
+
Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
return llvm::emitGEPOffset(&Builder, DL, GEP);
}
@@ -360,13 +360,17 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
// (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC)
Type *DestTy = C1->getType();
Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy);
- Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2);
+ Constant *FoldedC =
+ ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, IC.getDataLayout());
+ if (!FoldedC)
+ return false;
+
IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0));
IC.replaceOperand(*BinOp1, 1, FoldedC);
return true;
}
-// Simplifies IntToPtr/PtrToInt RoundTrip Cast To BitCast.
+// Simplifies IntToPtr/PtrToInt RoundTrip Cast.
// inttoptr ( ptrtoint (x) ) --> x
Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
auto *IntToPtr = dyn_cast<IntToPtrInst>(Val);
@@ -378,10 +382,8 @@ Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
CastTy->getPointerAddressSpace() ==
PtrToInt->getSrcTy()->getPointerAddressSpace() &&
DL.getTypeSizeInBits(PtrToInt->getSrcTy()) ==
- DL.getTypeSizeInBits(PtrToInt->getDestTy())) {
- return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy,
- "", PtrToInt);
- }
+ DL.getTypeSizeInBits(PtrToInt->getDestTy()))
+ return PtrToInt->getOperand(0);
}
return nullptr;
}
@@ -732,6 +734,207 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
return RetVal;
}
+// (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
+// IFF
+// 1) the logic_shifts match
+// 2) either BinOp1 is `and`, or the binop/shift pair distributes freely
+// and either BinOp2 is `and` or
+// (logic_shift (inv_logic_shift C1, C), C) == C1
+//
+// -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C)
+//
+// (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt))
+// IFF
+// 1) the logic_shifts match
+// 2) BinOp1 == BinOp2 (if BinOp == `add`, then also requires `shl`).
+//
+// -> (BinOp (logic_shift (BinOp X, Y)), Mask)
+Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
+ auto IsValidBinOpc = [](unsigned Opc) {
+ switch (Opc) {
+ default:
+ return false;
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ case Instruction::Add:
+ // Skip Sub as we only match constant masks which will canonicalize to use
+ // add.
+ return true;
+ }
+ };
+
+ // Check if we can distribute binop arbitrarily. `add` + `lshr` has extra
+ // constraints.
+ auto IsCompletelyDistributable = [](unsigned BinOpc1, unsigned BinOpc2,
+ unsigned ShOpc) {
+ return (BinOpc1 != Instruction::Add && BinOpc2 != Instruction::Add) ||
+ ShOpc == Instruction::Shl;
+ };
+
+ auto GetInvShift = [](unsigned ShOpc) {
+ return ShOpc == Instruction::LShr ? Instruction::Shl : Instruction::LShr;
+ };
+
+ auto CanDistributeBinops = [&](unsigned BinOpc1, unsigned BinOpc2,
+ unsigned ShOpc, Constant *CMask,
+ Constant *CShift) {
+ // If the BinOp1 is `and` we don't need to check the mask.
+ if (BinOpc1 == Instruction::And)
+ return true;
+
+ // For all other possible transforms we need a completely distributable
+ // binop/shift pair (anything but `add` + `lshr`).
+ if (!IsCompletelyDistributable(BinOpc1, BinOpc2, ShOpc))
+ return false;
+
+ // If BinOp2 is `and`, any mask works (this only really helps for non-splat
+ // vecs, otherwise the mask will be simplified and the following check will
+ // handle it).
+ if (BinOpc2 == Instruction::And)
+ return true;
+
+ // Otherwise, we need a mask that meets the requirement below.
+ // (logic_shift (inv_logic_shift Mask, ShAmt), ShAmt) == Mask
+ return ConstantExpr::get(
+ ShOpc, ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift),
+ CShift) == CMask;
+ };
+
+ auto MatchBinOp = [&](unsigned ShOpnum) -> Instruction * {
+ Constant *CMask, *CShift;
+ Value *X, *Y, *ShiftedX, *Mask, *Shift;
+ if (!match(I.getOperand(ShOpnum),
+ m_OneUse(m_LogicalShift(m_Value(Y), m_Value(Shift)))))
+ return nullptr;
+ if (!match(I.getOperand(1 - ShOpnum),
+ m_BinOp(m_Value(ShiftedX), m_Value(Mask))))
+ return nullptr;
+
+ if (!match(ShiftedX,
+ m_OneUse(m_LogicalShift(m_Value(X), m_Specific(Shift)))))
+ return nullptr;
+
+ // Make sure we are matching instruction shifts and not ConstantExpr
+ auto *IY = dyn_cast<Instruction>(I.getOperand(ShOpnum));
+ auto *IX = dyn_cast<Instruction>(ShiftedX);
+ if (!IY || !IX)
+ return nullptr;
+
+ // LHS and RHS need same shift opcode
+ unsigned ShOpc = IY->getOpcode();
+ if (ShOpc != IX->getOpcode())
+ return nullptr;
+
+ // Make sure binop is real instruction and not ConstantExpr
+ auto *BO2 = dyn_cast<Instruction>(I.getOperand(1 - ShOpnum));
+ if (!BO2)
+ return nullptr;
+
+ unsigned BinOpc = BO2->getOpcode();
+ // Make sure we have valid binops.
+ if (!IsValidBinOpc(I.getOpcode()) || !IsValidBinOpc(BinOpc))
+ return nullptr;
+
+ // If BinOp1 == BinOp2 and it is a bitwise op, or it is shl with add, then
+ // just distribute to drop the shift, regardless of the constants.
+ if (BinOpc == I.getOpcode() &&
+ IsCompletelyDistributable(I.getOpcode(), BinOpc, ShOpc)) {
+ Value *NewBinOp2 = Builder.CreateBinOp(I.getOpcode(), X, Y);
+ Value *NewBinOp1 = Builder.CreateBinOp(
+ static_cast<Instruction::BinaryOps>(ShOpc), NewBinOp2, Shift);
+ return BinaryOperator::Create(I.getOpcode(), NewBinOp1, Mask);
+ }
+
+ // Otherwise we can only distribute by constant shifting the mask, so
+ // ensure we have constants.
+ if (!match(Shift, m_ImmConstant(CShift)))
+ return nullptr;
+ if (!match(Mask, m_ImmConstant(CMask)))
+ return nullptr;
+
+ // Check if we can distribute the binops.
+ if (!CanDistributeBinops(I.getOpcode(), BinOpc, ShOpc, CMask, CShift))
+ return nullptr;
+
+ Constant *NewCMask = ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift);
+ Value *NewBinOp2 = Builder.CreateBinOp(
+ static_cast<Instruction::BinaryOps>(BinOpc), X, NewCMask);
+ Value *NewBinOp1 = Builder.CreateBinOp(I.getOpcode(), Y, NewBinOp2);
+ return BinaryOperator::Create(static_cast<Instruction::BinaryOps>(ShOpc),
+ NewBinOp1, CShift);
+ };
+
+ if (Instruction *R = MatchBinOp(0))
+ return R;
+ return MatchBinOp(1);
+}
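One concrete instance of the first pattern documented above (Binop1 = or, Binop2 = and, logic_shift = shl by 4, C1 = 0xF0), checked with plain integer arithmetic rather than the pass itself:

  #include <cassert>
  #include <cstdint>

  int main() {
    // (or (and (shl X, 4), 0xF0), (shl Y, 4))
    //   == (shl (or (and X, (lshr 0xF0, 4)), Y), 4)
    for (uint32_t x : {0u, 1u, 0xABu, 0xFFFFFFFFu}) {
      for (uint32_t y : {0u, 7u, 0x1234u, 0xFFFFFFFFu}) {
        uint32_t lhs = ((x << 4) & 0xF0u) | (y << 4);
        uint32_t rhs = ((x & (0xF0u >> 4)) | y) << 4;
        assert(lhs == rhs);
      }
    }
    return 0;
  }

Shifting the mask through the inverse shift is what lets one of the two original shifts disappear.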
+
+// (Binop (zext C), (select C, T, F))
+// -> (select C, (binop 1, T), (binop 0, F))
+//
+// (Binop (sext C), (select C, T, F))
+// -> (select C, (binop -1, T), (binop 0, F))
+//
+// Attempt to simplify binary operations into a select with folded args, when
+// one operand of the binop is a select instruction and the other operand is a
+// zext/sext of the select condition.
+Instruction *
+InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) {
+ // TODO: this simplification may be extended to any speculatable instruction,
+ // not just binops, and would possibly be handled better in FoldOpIntoSelect.
+ Instruction::BinaryOps Opc = I.getOpcode();
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ Value *A, *CondVal, *TrueVal, *FalseVal;
+ Value *CastOp;
+
+ auto MatchSelectAndCast = [&](Value *CastOp, Value *SelectOp) {
+ return match(CastOp, m_ZExtOrSExt(m_Value(A))) &&
+ A->getType()->getScalarSizeInBits() == 1 &&
+ match(SelectOp, m_Select(m_Value(CondVal), m_Value(TrueVal),
+ m_Value(FalseVal)));
+ };
+
+ // Make sure one side of the binop is a select instruction, and the other is a
+ // zero/sign extension operating on an i1.
+ if (MatchSelectAndCast(LHS, RHS))
+ CastOp = LHS;
+ else if (MatchSelectAndCast(RHS, LHS))
+ CastOp = RHS;
+ else
+ return nullptr;
+
+ auto NewFoldedConst = [&](bool IsTrueArm, Value *V) {
+ bool IsCastOpRHS = (CastOp == RHS);
+ bool IsZExt = isa<ZExtInst>(CastOp);
+ Constant *C;
+
+ if (IsTrueArm) {
+ C = Constant::getNullValue(V->getType());
+ } else if (IsZExt) {
+ unsigned BitWidth = V->getType()->getScalarSizeInBits();
+ C = Constant::getIntegerValue(V->getType(), APInt(BitWidth, 1));
+ } else {
+ C = Constant::getAllOnesValue(V->getType());
+ }
+
+ return IsCastOpRHS ? Builder.CreateBinOp(Opc, V, C)
+ : Builder.CreateBinOp(Opc, C, V);
+ };
+
+  // If the value used in the zext/sext is the select condition, or the
+  // negation of the select condition, the binop can be simplified.
+ if (CondVal == A)
+ return SelectInst::Create(CondVal, NewFoldedConst(false, TrueVal),
+ NewFoldedConst(true, FalseVal));
+
+ if (match(A, m_Not(m_Specific(CondVal))))
+ return SelectInst::Create(CondVal, NewFoldedConst(true, TrueVal),
+ NewFoldedConst(false, FalseVal));
+
+ return nullptr;
+}
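// Editorial sketch (not part of this patch) of the fold above on plain C++
// values: when the zext'd boolean is also the select condition, each arm of
// the select sees a fixed constant (1 or 0) for the cast, so the binop can be
// folded into the arms. Helper names are illustrative only.
constexpr int zextPlusSelect(bool c, int t, int f) {
  return (c ? 1 : 0) + (c ? t : f);   // (add (zext C), (select C, T, F))
}
constexpr int foldedSelect(bool c, int t, int f) {
  return c ? (1 + t) : (0 + f);       // (select C, (add 1, T), (add 0, F))
}
static_assert(zextPlusSelect(true, 7, 9) == foldedSelect(true, 7, 9), "");
static_assert(zextPlusSelect(false, 7, 9) == foldedSelect(false, 7, 9), "");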
+
Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
@@ -948,6 +1151,7 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
/// Freely adapt every user of V as-if V was changed to !V.
/// WARNING: only if canFreelyInvertAllUsersOf() said this can be done.
void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) {
+ assert(!isa<Constant>(I) && "Shouldn't invert users of constant");
for (User *U : make_early_inc_range(I->users())) {
if (U == IgnoredUser)
continue; // Don't consider this user.
@@ -1033,63 +1237,39 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) {
return SelectInst::Create(X, TVal, FVal);
}
-static Constant *constantFoldOperationIntoSelectOperand(
- Instruction &I, SelectInst *SI, Value *SO) {
- auto *ConstSO = dyn_cast<Constant>(SO);
- if (!ConstSO)
- return nullptr;
-
+static Constant *constantFoldOperationIntoSelectOperand(Instruction &I,
+ SelectInst *SI,
+ bool IsTrueArm) {
SmallVector<Constant *> ConstOps;
for (Value *Op : I.operands()) {
- if (Op == SI)
- ConstOps.push_back(ConstSO);
- else if (auto *C = dyn_cast<Constant>(Op))
- ConstOps.push_back(C);
- else
- llvm_unreachable("Operands should be select or constant");
- }
- return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout());
-}
+ CmpInst::Predicate Pred;
+ Constant *C = nullptr;
+ if (Op == SI) {
+ C = dyn_cast<Constant>(IsTrueArm ? SI->getTrueValue()
+ : SI->getFalseValue());
+ } else if (match(SI->getCondition(),
+ m_ICmp(Pred, m_Specific(Op), m_Constant(C))) &&
+ Pred == (IsTrueArm ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) &&
+ isGuaranteedNotToBeUndefOrPoison(C)) {
+ // Pass
+ } else {
+ C = dyn_cast<Constant>(Op);
+ }
+ if (C == nullptr)
+ return nullptr;
-static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,
- InstCombiner::BuilderTy &Builder) {
- if (auto *Cast = dyn_cast<CastInst>(&I))
- return Builder.CreateCast(Cast->getOpcode(), SO, I.getType());
-
- if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
- assert(canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())) &&
- "Expected constant-foldable intrinsic");
- Intrinsic::ID IID = II->getIntrinsicID();
- if (II->arg_size() == 1)
- return Builder.CreateUnaryIntrinsic(IID, SO);
-
- // This works for real binary ops like min/max (where we always expect the
- // constant operand to be canonicalized as op1) and unary ops with a bonus
- // constant argument like ctlz/cttz.
- // TODO: Handle non-commutative binary intrinsics as below for binops.
- assert(II->arg_size() == 2 && "Expected binary intrinsic");
- assert(isa<Constant>(II->getArgOperand(1)) && "Expected constant operand");
- return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1));
+ ConstOps.push_back(C);
}
- if (auto *EI = dyn_cast<ExtractElementInst>(&I))
- return Builder.CreateExtractElement(SO, EI->getIndexOperand());
-
- assert(I.isBinaryOp() && "Unexpected opcode for select folding");
-
- // Figure out if the constant is the left or the right argument.
- bool ConstIsRHS = isa<Constant>(I.getOperand(1));
- Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS));
-
- Value *Op0 = SO, *Op1 = ConstOperand;
- if (!ConstIsRHS)
- std::swap(Op0, Op1);
+ return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout());
+}
- Value *NewBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), Op0,
- Op1, SO->getName() + ".op");
- if (auto *NewBOI = dyn_cast<Instruction>(NewBO))
- NewBOI->copyIRFlags(&I);
- return NewBO;
+static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
+ Value *NewOp, InstCombiner &IC) {
+ Instruction *Clone = I.clone();
+ Clone->replaceUsesOfWith(SI, NewOp);
+ IC.InsertNewInstBefore(Clone, *SI);
+ return Clone;
}
Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
@@ -1122,56 +1302,17 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
return nullptr;
}
- // Test if a CmpInst instruction is used exclusively by a select as
- // part of a minimum or maximum operation. If so, refrain from doing
- // any other folding. This helps out other analyses which understand
- // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
- // and CodeGen. And in this case, at least one of the comparison
- // operands has at least one user besides the compare (the select),
- // which would often largely negate the benefit of folding anyway.
- if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) {
- if (CI->hasOneUse()) {
- Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1);
-
- // FIXME: This is a hack to avoid infinite looping with min/max patterns.
- // We have to ensure that vector constants that only differ with
- // undef elements are treated as equivalent.
- auto areLooselyEqual = [](Value *A, Value *B) {
- if (A == B)
- return true;
-
- // Test for vector constants.
- Constant *ConstA, *ConstB;
- if (!match(A, m_Constant(ConstA)) || !match(B, m_Constant(ConstB)))
- return false;
-
- // TODO: Deal with FP constants?
- if (!A->getType()->isIntOrIntVectorTy() || A->getType() != B->getType())
- return false;
-
- // Compare for equality including undefs as equal.
- auto *Cmp = ConstantExpr::getCompare(ICmpInst::ICMP_EQ, ConstA, ConstB);
- const APInt *C;
- return match(Cmp, m_APIntAllowUndef(C)) && C->isOne();
- };
-
- if ((areLooselyEqual(TV, Op0) && areLooselyEqual(FV, Op1)) ||
- (areLooselyEqual(FV, Op0) && areLooselyEqual(TV, Op1)))
- return nullptr;
- }
- }
-
// Make sure that one of the select arms constant folds successfully.
- Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, TV);
- Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, FV);
+ Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ true);
+ Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ false);
if (!NewTV && !NewFV)
return nullptr;
// Create an instruction for the arm that did not fold.
if (!NewTV)
- NewTV = foldOperationIntoSelectOperand(Op, TV, Builder);
+ NewTV = foldOperationIntoSelectOperand(Op, SI, TV, *this);
if (!NewFV)
- NewFV = foldOperationIntoSelectOperand(Op, FV, Builder);
+ NewFV = foldOperationIntoSelectOperand(Op, SI, FV, *this);
return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
}
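// Editorial sketch (not part of this patch) of the updated constant folding
// above: besides substituting the select arm itself, an operand can be
// replaced by a constant in the arm where the select condition pins it (the
// true arm of an icmp eq, or the false arm of an icmp ne, against that
// constant). Helper names and sample values are illustrative only.
constexpr int opOnSelect(int x, int t) {
  return ((x == 42) ? 10 : t) + x;    // add (select (icmp eq x, 42), 10, t), x
}
constexpr int foldedArms(int x, int t) {
  return (x == 42) ? 52 : (t + x);    // true arm folded with x known to be 42
}
static_assert(opOnSelect(42, 5) == foldedArms(42, 5), "");
static_assert(opOnSelect(7, 5) == foldedArms(7, 5), "");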
@@ -1263,6 +1404,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
PHINode *NewPN = PHINode::Create(I.getType(), PN->getNumIncomingValues());
InsertNewInstBefore(NewPN, *PN);
NewPN->takeName(PN);
+ NewPN->setDebugLoc(PN->getDebugLoc());
// If we are going to have to insert a new computation, do so right before the
// predecessor's terminator.
@@ -1291,6 +1433,10 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
replaceInstUsesWith(*User, NewPN);
eraseInstFromFunction(*User);
}
+
+ replaceAllDbgUsesWith(const_cast<PHINode &>(*PN),
+ const_cast<PHINode &>(*NewPN),
+ const_cast<PHINode &>(*PN), DT);
return replaceInstUsesWith(I, NewPN);
}
@@ -1301,7 +1447,7 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
auto *Phi0 = dyn_cast<PHINode>(BO.getOperand(0));
auto *Phi1 = dyn_cast<PHINode>(BO.getOperand(1));
if (!Phi0 || !Phi1 || !Phi0->hasOneUse() || !Phi1->hasOneUse() ||
- Phi0->getNumOperands() != 2 || Phi1->getNumOperands() != 2)
+ Phi0->getNumOperands() != Phi1->getNumOperands())
return nullptr;
// TODO: Remove the restriction for binop being in the same block as the phis.
@@ -1309,6 +1455,51 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
BO.getParent() != Phi1->getParent())
return nullptr;
+  // Fold when phi0 and phi1 have matching incoming blocks and, for each
+  // block, one of the two incoming values is the identity constant of the
+  // binary operator; that pair then reduces to the other incoming value.
+ // For example:
+ // %phi0 = phi i32 [0, %bb0], [%i, %bb1]
+ // %phi1 = phi i32 [%j, %bb0], [0, %bb1]
+ // %add = add i32 %phi0, %phi1
+ // ==>
+ // %add = phi i32 [%j, %bb0], [%i, %bb1]
+ Constant *C = ConstantExpr::getBinOpIdentity(BO.getOpcode(), BO.getType(),
+ /*AllowRHSConstant*/ false);
+ if (C) {
+ SmallVector<Value *, 4> NewIncomingValues;
+ auto CanFoldIncomingValuePair = [&](std::tuple<Use &, Use &> T) {
+ auto &Phi0Use = std::get<0>(T);
+ auto &Phi1Use = std::get<1>(T);
+ if (Phi0->getIncomingBlock(Phi0Use) != Phi1->getIncomingBlock(Phi1Use))
+ return false;
+ Value *Phi0UseV = Phi0Use.get();
+ Value *Phi1UseV = Phi1Use.get();
+ if (Phi0UseV == C)
+ NewIncomingValues.push_back(Phi1UseV);
+ else if (Phi1UseV == C)
+ NewIncomingValues.push_back(Phi0UseV);
+ else
+ return false;
+ return true;
+ };
+
+ if (all_of(zip(Phi0->operands(), Phi1->operands()),
+ CanFoldIncomingValuePair)) {
+ PHINode *NewPhi =
+ PHINode::Create(Phi0->getType(), Phi0->getNumOperands());
+ assert(NewIncomingValues.size() == Phi0->getNumOperands() &&
+ "The number of collected incoming values should equal the number "
+ "of the original PHINode operands!");
+ for (unsigned I = 0; I < Phi0->getNumOperands(); I++)
+ NewPhi->addIncoming(NewIncomingValues[I], Phi0->getIncomingBlock(I));
+ return NewPhi;
+ }
+ }
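// Editorial sketch (not part of this patch) of the identity-constant fold
// above, modelling the two predecessors with a boolean; helper names are
// illustrative only. With C == 0 (the identity of add), each incoming pair
// reduces to the non-identity value, which is what the new phi carries.
constexpr int addOfPhis(bool fromBB0, int i, int j) {
  int phi0 = fromBB0 ? 0 : i;        // %phi0 = phi i32 [0, %bb0], [%i, %bb1]
  int phi1 = fromBB0 ? j : 0;        // %phi1 = phi i32 [%j, %bb0], [0, %bb1]
  return phi0 + phi1;                // %add  = add i32 %phi0, %phi1
}
constexpr int foldedPhi(bool fromBB0, int i, int j) {
  return fromBB0 ? j : i;            // %add  = phi i32 [%j, %bb0], [%i, %bb1]
}
static_assert(addOfPhis(true, 3, 5) == foldedPhi(true, 3, 5), "");
static_assert(addOfPhis(false, 3, 5) == foldedPhi(false, 3, 5), "");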
+
+ if (Phi0->getNumOperands() != 2 || Phi1->getNumOperands() != 2)
+ return nullptr;
+
// Match a pair of incoming constants for one of the predecessor blocks.
BasicBlock *ConstBB, *OtherBB;
Constant *C0, *C1;
@@ -1374,28 +1565,6 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
return nullptr;
}
-/// Given a pointer type and a constant offset, determine whether or not there
-/// is a sequence of GEP indices into the pointed type that will land us at the
-/// specified offset. If so, fill them into NewIndices and return the resultant
-/// element type, otherwise return null.
-static Type *findElementAtOffset(PointerType *PtrTy, int64_t IntOffset,
- SmallVectorImpl<Value *> &NewIndices,
- const DataLayout &DL) {
- // Only used by visitGEPOfBitcast(), which is skipped for opaque pointers.
- Type *Ty = PtrTy->getNonOpaquePointerElementType();
- if (!Ty->isSized())
- return nullptr;
-
- APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), IntOffset);
- SmallVector<APInt> Indices = DL.getGEPIndicesForOffset(Ty, Offset);
- if (!Offset.isZero())
- return nullptr;
-
- for (const APInt &Index : Indices)
- NewIndices.push_back(ConstantInt::get(PtrTy->getContext(), Index));
- return Ty;
-}
-
static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
// If this GEP has only 0 indices, it is the same pointer as
// Src. If Src is not a trivial GEP too, don't combine
@@ -1406,248 +1575,6 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
return true;
}
-/// Return a value X such that Val = X * Scale, or null if none.
-/// If the multiplication is known not to overflow, then NoSignedWrap is set.
-Value *InstCombinerImpl::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) {
- assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!");
- assert(cast<IntegerType>(Val->getType())->getBitWidth() ==
- Scale.getBitWidth() && "Scale not compatible with value!");
-
- // If Val is zero or Scale is one then Val = Val * Scale.
- if (match(Val, m_Zero()) || Scale == 1) {
- NoSignedWrap = true;
- return Val;
- }
-
- // If Scale is zero then it does not divide Val.
- if (Scale.isMinValue())
- return nullptr;
-
- // Look through chains of multiplications, searching for a constant that is
- // divisible by Scale. For example, descaling X*(Y*(Z*4)) by a factor of 4
- // will find the constant factor 4 and produce X*(Y*Z). Descaling X*(Y*8) by
- // a factor of 4 will produce X*(Y*2). The principle of operation is to bore
- // down from Val:
- //
- // Val = M1 * X || Analysis starts here and works down
- // M1 = M2 * Y || Doesn't descend into terms with more
- // M2 = Z * 4 \/ than one use
- //
- // Then to modify a term at the bottom:
- //
- // Val = M1 * X
- // M1 = Z * Y || Replaced M2 with Z
- //
- // Then to work back up correcting nsw flags.
-
- // Op - the term we are currently analyzing. Starts at Val then drills down.
- // Replaced with its descaled value before exiting from the drill down loop.
- Value *Op = Val;
-
- // Parent - initially null, but after drilling down notes where Op came from.
- // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the
- // 0'th operand of Val.
- std::pair<Instruction *, unsigned> Parent;
-
- // Set if the transform requires a descaling at deeper levels that doesn't
- // overflow.
- bool RequireNoSignedWrap = false;
-
- // Log base 2 of the scale. Negative if not a power of 2.
- int32_t logScale = Scale.exactLogBase2();
-
- for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
- // If Op is a constant divisible by Scale then descale to the quotient.
- APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth.
- APInt::sdivrem(CI->getValue(), Scale, Quotient, Remainder);
- if (!Remainder.isMinValue())
- // Not divisible by Scale.
- return nullptr;
- // Replace with the quotient in the parent.
- Op = ConstantInt::get(CI->getType(), Quotient);
- NoSignedWrap = true;
- break;
- }
-
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) {
- if (BO->getOpcode() == Instruction::Mul) {
- // Multiplication.
- NoSignedWrap = BO->hasNoSignedWrap();
- if (RequireNoSignedWrap && !NoSignedWrap)
- return nullptr;
-
- // There are three cases for multiplication: multiplication by exactly
- // the scale, multiplication by a constant different to the scale, and
- // multiplication by something else.
- Value *LHS = BO->getOperand(0);
- Value *RHS = BO->getOperand(1);
-
- if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
- // Multiplication by a constant.
- if (CI->getValue() == Scale) {
- // Multiplication by exactly the scale, replace the multiplication
- // by its left-hand side in the parent.
- Op = LHS;
- break;
- }
-
- // Otherwise drill down into the constant.
- if (!Op->hasOneUse())
- return nullptr;
-
- Parent = std::make_pair(BO, 1);
- continue;
- }
-
- // Multiplication by something else. Drill down into the left-hand side
- // since that's where the reassociate pass puts the good stuff.
- if (!Op->hasOneUse())
- return nullptr;
-
- Parent = std::make_pair(BO, 0);
- continue;
- }
-
- if (logScale > 0 && BO->getOpcode() == Instruction::Shl &&
- isa<ConstantInt>(BO->getOperand(1))) {
- // Multiplication by a power of 2.
- NoSignedWrap = BO->hasNoSignedWrap();
- if (RequireNoSignedWrap && !NoSignedWrap)
- return nullptr;
-
- Value *LHS = BO->getOperand(0);
- int32_t Amt = cast<ConstantInt>(BO->getOperand(1))->
- getLimitedValue(Scale.getBitWidth());
- // Op = LHS << Amt.
-
- if (Amt == logScale) {
- // Multiplication by exactly the scale, replace the multiplication
- // by its left-hand side in the parent.
- Op = LHS;
- break;
- }
- if (Amt < logScale || !Op->hasOneUse())
- return nullptr;
-
- // Multiplication by more than the scale. Reduce the multiplying amount
- // by the scale in the parent.
- Parent = std::make_pair(BO, 1);
- Op = ConstantInt::get(BO->getType(), Amt - logScale);
- break;
- }
- }
-
- if (!Op->hasOneUse())
- return nullptr;
-
- if (CastInst *Cast = dyn_cast<CastInst>(Op)) {
- if (Cast->getOpcode() == Instruction::SExt) {
- // Op is sign-extended from a smaller type, descale in the smaller type.
- unsigned SmallSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
- APInt SmallScale = Scale.trunc(SmallSize);
- // Suppose Op = sext X, and we descale X as Y * SmallScale. We want to
- // descale Op as (sext Y) * Scale. In order to have
- // sext (Y * SmallScale) = (sext Y) * Scale
- // some conditions need to hold however: SmallScale must sign-extend to
- // Scale and the multiplication Y * SmallScale should not overflow.
- if (SmallScale.sext(Scale.getBitWidth()) != Scale)
- // SmallScale does not sign-extend to Scale.
- return nullptr;
- assert(SmallScale.exactLogBase2() == logScale);
- // Require that Y * SmallScale must not overflow.
- RequireNoSignedWrap = true;
-
- // Drill down through the cast.
- Parent = std::make_pair(Cast, 0);
- Scale = SmallScale;
- continue;
- }
-
- if (Cast->getOpcode() == Instruction::Trunc) {
- // Op is truncated from a larger type, descale in the larger type.
- // Suppose Op = trunc X, and we descale X as Y * sext Scale. Then
- // trunc (Y * sext Scale) = (trunc Y) * Scale
- // always holds. However (trunc Y) * Scale may overflow even if
- // trunc (Y * sext Scale) does not, so nsw flags need to be cleared
- // from this point up in the expression (see later).
- if (RequireNoSignedWrap)
- return nullptr;
-
- // Drill down through the cast.
- unsigned LargeSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
- Parent = std::make_pair(Cast, 0);
- Scale = Scale.sext(LargeSize);
- if (logScale + 1 == (int32_t)Cast->getType()->getPrimitiveSizeInBits())
- logScale = -1;
- assert(Scale.exactLogBase2() == logScale);
- continue;
- }
- }
-
- // Unsupported expression, bail out.
- return nullptr;
- }
-
- // If Op is zero then Val = Op * Scale.
- if (match(Op, m_Zero())) {
- NoSignedWrap = true;
- return Op;
- }
-
- // We know that we can successfully descale, so from here on we can safely
- // modify the IR. Op holds the descaled version of the deepest term in the
- // expression. NoSignedWrap is 'true' if multiplying Op by Scale is known
- // not to overflow.
-
- if (!Parent.first)
- // The expression only had one term.
- return Op;
-
- // Rewrite the parent using the descaled version of its operand.
- assert(Parent.first->hasOneUse() && "Drilled down when more than one use!");
- assert(Op != Parent.first->getOperand(Parent.second) &&
- "Descaling was a no-op?");
- replaceOperand(*Parent.first, Parent.second, Op);
- Worklist.push(Parent.first);
-
- // Now work back up the expression correcting nsw flags. The logic is based
- // on the following observation: if X * Y is known not to overflow as a signed
- // multiplication, and Y is replaced by a value Z with smaller absolute value,
- // then X * Z will not overflow as a signed multiplication either. As we work
- // our way up, having NoSignedWrap 'true' means that the descaled value at the
- // current level has strictly smaller absolute value than the original.
- Instruction *Ancestor = Parent.first;
- do {
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Ancestor)) {
- // If the multiplication wasn't nsw then we can't say anything about the
- // value of the descaled multiplication, and we have to clear nsw flags
- // from this point on up.
- bool OpNoSignedWrap = BO->hasNoSignedWrap();
- NoSignedWrap &= OpNoSignedWrap;
- if (NoSignedWrap != OpNoSignedWrap) {
- BO->setHasNoSignedWrap(NoSignedWrap);
- Worklist.push(Ancestor);
- }
- } else if (Ancestor->getOpcode() == Instruction::Trunc) {
- // The fact that the descaled input to the trunc has smaller absolute
- // value than the original input doesn't tell us anything useful about
- // the absolute values of the truncations.
- NoSignedWrap = false;
- }
- assert((Ancestor->getOpcode() != Instruction::SExt || NoSignedWrap) &&
- "Failed to keep proper track of nsw flags while drilling down?");
-
- if (Ancestor == Val)
- // Got to the top, all done!
- return Val;
-
- // Move up one level in the expression.
- assert(Ancestor->hasOneUse() && "Drilled down when more than one use!");
- Ancestor = Ancestor->user_back();
- } while (true);
-}
-
Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
if (!isa<VectorType>(Inst.getType()))
return nullptr;
@@ -1748,9 +1675,9 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
// TODO: Allow arbitrary shuffles by shuffling after binop?
// That might be legal, but we have to deal with poison.
if (LShuf->isSelect() &&
- !is_contained(LShuf->getShuffleMask(), UndefMaskElem) &&
+ !is_contained(LShuf->getShuffleMask(), PoisonMaskElem) &&
RShuf->isSelect() &&
- !is_contained(RShuf->getShuffleMask(), UndefMaskElem)) {
+ !is_contained(RShuf->getShuffleMask(), PoisonMaskElem)) {
// Example:
// LHS = shuffle V1, V2, <0, 5, 6, 3>
// RHS = shuffle V2, V1, <0, 5, 6, 3>
@@ -1991,50 +1918,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src))
return nullptr;
- if (Src->getResultElementType() == GEP.getSourceElementType() &&
- Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 &&
- Src->hasOneUse()) {
- Value *GO1 = GEP.getOperand(1);
- Value *SO1 = Src->getOperand(1);
-
- if (LI) {
- // Try to reassociate loop invariant GEP chains to enable LICM.
- if (Loop *L = LI->getLoopFor(GEP.getParent())) {
- // Reassociate the two GEPs if SO1 is variant in the loop and GO1 is
- // invariant: this breaks the dependence between GEPs and allows LICM
- // to hoist the invariant part out of the loop.
- if (L->isLoopInvariant(GO1) && !L->isLoopInvariant(SO1)) {
- // The swapped GEPs are inbounds if both original GEPs are inbounds
- // and the sign of the offsets is the same. For simplicity, only
- // handle both offsets being non-negative.
- bool IsInBounds = Src->isInBounds() && GEP.isInBounds() &&
- isKnownNonNegative(SO1, DL, 0, &AC, &GEP, &DT) &&
- isKnownNonNegative(GO1, DL, 0, &AC, &GEP, &DT);
- // Put NewSrc at same location as %src.
- Builder.SetInsertPoint(cast<Instruction>(Src));
- Value *NewSrc = Builder.CreateGEP(GEP.getSourceElementType(),
- Src->getPointerOperand(), GO1,
- Src->getName(), IsInBounds);
- GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
- GEP.getSourceElementType(), NewSrc, {SO1});
- NewGEP->setIsInBounds(IsInBounds);
- return NewGEP;
- }
- }
- }
- }
-
- // Note that if our source is a gep chain itself then we wait for that
- // chain to be resolved before we perform this transformation. This
- // avoids us creating a TON of code in some cases.
- if (auto *SrcGEP = dyn_cast<GEPOperator>(Src->getOperand(0)))
- if (SrcGEP->getNumOperands() == 2 && shouldMergeGEPs(*Src, *SrcGEP))
- return nullptr; // Wait until our source is folded to completion.
-
// For constant GEPs, use a more general offset-based folding approach.
- // Only do this for opaque pointers, as the result element type may change.
Type *PtrTy = Src->getType()->getScalarType();
- if (PtrTy->isOpaquePointerTy() && GEP.hasAllConstantIndices() &&
+ if (GEP.hasAllConstantIndices() &&
(Src->hasOneUse() || Src->hasAllConstantIndices())) {
// Split Src into a variable part and a constant suffix.
gep_type_iterator GTI = gep_type_begin(*Src);
@@ -2077,13 +1963,11 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
// If both GEP are constant-indexed, and cannot be merged in either way,
// convert them to a GEP of i8.
if (Src->hasAllConstantIndices())
- return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))
- ? GetElementPtrInst::CreateInBounds(
- Builder.getInt8Ty(), Src->getOperand(0),
- Builder.getInt(OffsetOld), GEP.getName())
- : GetElementPtrInst::Create(
- Builder.getInt8Ty(), Src->getOperand(0),
- Builder.getInt(OffsetOld), GEP.getName());
+ return replaceInstUsesWith(
+ GEP, Builder.CreateGEP(
+ Builder.getInt8Ty(), Src->getOperand(0),
+ Builder.getInt(OffsetOld), "",
+ isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))));
return nullptr;
}
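// Editorial sketch (not part of this patch): when both GEPs have only
// constant indices, merging them amounts to adding their byte offsets from
// the base pointer, which is what the i8 GEP built above encodes. The buffer
// size and offsets here are arbitrary sample values.
constexpr bool byteOffsetsCompose() {
  unsigned char buf[64] = {};
  return (buf + 12) + 20 == buf + 32; // gep(gep(p, 12), 20) == gep i8 p, 32
}
static_assert(byteOffsetsCompose(), "byte offsets of constant GEPs add up");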
@@ -2100,13 +1984,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
IsInBounds &= Idx.isNonNegative() == ConstIndices[0].isNonNegative();
}
- return IsInBounds
- ? GetElementPtrInst::CreateInBounds(Src->getSourceElementType(),
- Src->getOperand(0), Indices,
- GEP.getName())
- : GetElementPtrInst::Create(Src->getSourceElementType(),
- Src->getOperand(0), Indices,
- GEP.getName());
+ return replaceInstUsesWith(
+ GEP, Builder.CreateGEP(Src->getSourceElementType(), Src->getOperand(0),
+ Indices, "", IsInBounds));
}
if (Src->getResultElementType() != GEP.getSourceElementType())
@@ -2160,118 +2040,10 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
}
if (!Indices.empty())
- return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))
- ? GetElementPtrInst::CreateInBounds(
- Src->getSourceElementType(), Src->getOperand(0), Indices,
- GEP.getName())
- : GetElementPtrInst::Create(Src->getSourceElementType(),
- Src->getOperand(0), Indices,
- GEP.getName());
-
- return nullptr;
-}
-
-// Note that we may have also stripped an address space cast in between.
-Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI,
- GetElementPtrInst &GEP) {
- // With opaque pointers, there is no pointer element type we can use to
- // adjust the GEP type.
- PointerType *SrcType = cast<PointerType>(BCI->getSrcTy());
- if (SrcType->isOpaque())
- return nullptr;
-
- Type *GEPEltType = GEP.getSourceElementType();
- Type *SrcEltType = SrcType->getNonOpaquePointerElementType();
- Value *SrcOp = BCI->getOperand(0);
-
- // GEP directly using the source operand if this GEP is accessing an element
- // of a bitcasted pointer to vector or array of the same dimensions:
- // gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z
- // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z
- auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy,
- const DataLayout &DL) {
- auto *VecVTy = cast<FixedVectorType>(VecTy);
- return ArrTy->getArrayElementType() == VecVTy->getElementType() &&
- ArrTy->getArrayNumElements() == VecVTy->getNumElements() &&
- DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy);
- };
- if (GEP.getNumOperands() == 3 &&
- ((GEPEltType->isArrayTy() && isa<FixedVectorType>(SrcEltType) &&
- areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) ||
- (isa<FixedVectorType>(GEPEltType) && SrcEltType->isArrayTy() &&
- areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) {
-
- // Create a new GEP here, as using `setOperand()` followed by
- // `setSourceElementType()` won't actually update the type of the
- // existing GEP Value. Causing issues if this Value is accessed when
- // constructing an AddrSpaceCastInst
- SmallVector<Value *, 8> Indices(GEP.indices());
- Value *NGEP =
- Builder.CreateGEP(SrcEltType, SrcOp, Indices, "", GEP.isInBounds());
- NGEP->takeName(&GEP);
-
- // Preserve GEP address space to satisfy users
- if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace())
- return new AddrSpaceCastInst(NGEP, GEP.getType());
-
- return replaceInstUsesWith(GEP, NGEP);
- }
-
- // See if we can simplify:
- // X = bitcast A* to B*
- // Y = gep X, <...constant indices...>
- // into a gep of the original struct. This is important for SROA and alias
- // analysis of unions. If "A" is also a bitcast, wait for A/X to be merged.
- unsigned OffsetBits = DL.getIndexTypeSizeInBits(GEP.getType());
- APInt Offset(OffsetBits, 0);
-
- // If the bitcast argument is an allocation, The bitcast is for convertion
- // to actual type of allocation. Removing such bitcasts, results in having
- // GEPs with i8* base and pure byte offsets. That means GEP is not aware of
- // struct or array hierarchy.
- // By avoiding such GEPs, phi translation and MemoryDependencyAnalysis have
- // a better chance to succeed.
- if (!isa<BitCastInst>(SrcOp) && GEP.accumulateConstantOffset(DL, Offset) &&
- !isAllocationFn(SrcOp, &TLI)) {
- // If this GEP instruction doesn't move the pointer, just replace the GEP
- // with a bitcast of the real input to the dest type.
- if (!Offset) {
- // If the bitcast is of an allocation, and the allocation will be
- // converted to match the type of the cast, don't touch this.
- if (isa<AllocaInst>(SrcOp)) {
- // See if the bitcast simplifies, if so, don't nuke this GEP yet.
- if (Instruction *I = visitBitCast(*BCI)) {
- if (I != BCI) {
- I->takeName(BCI);
- I->insertInto(BCI->getParent(), BCI->getIterator());
- replaceInstUsesWith(*BCI, I);
- }
- return &GEP;
- }
- }
-
- if (SrcType->getPointerAddressSpace() != GEP.getAddressSpace())
- return new AddrSpaceCastInst(SrcOp, GEP.getType());
- return new BitCastInst(SrcOp, GEP.getType());
- }
-
- // Otherwise, if the offset is non-zero, we need to find out if there is a
- // field at Offset in 'A's type. If so, we can pull the cast through the
- // GEP.
- SmallVector<Value *, 8> NewIndices;
- if (findElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices, DL)) {
- Value *NGEP = Builder.CreateGEP(SrcEltType, SrcOp, NewIndices, "",
- GEP.isInBounds());
-
- if (NGEP->getType() == GEP.getType())
- return replaceInstUsesWith(GEP, NGEP);
- NGEP->takeName(&GEP);
-
- if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace())
- return new AddrSpaceCastInst(NGEP, GEP.getType());
- return new BitCastInst(NGEP, GEP.getType());
- }
- }
+ return replaceInstUsesWith(
+ GEP, Builder.CreateGEP(
+ Src->getSourceElementType(), Src->getOperand(0), Indices, "",
+ isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))));
return nullptr;
}
@@ -2497,192 +2269,6 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
if (GEPType->isVectorTy())
return nullptr;
- // Handle gep(bitcast x) and gep(gep x, 0, 0, 0).
- Value *StrippedPtr = PtrOp->stripPointerCasts();
- PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType());
-
- // TODO: The basic approach of these folds is not compatible with opaque
- // pointers, because we can't use bitcasts as a hint for a desirable GEP
- // type. Instead, we should perform canonicalization directly on the GEP
- // type. For now, skip these.
- if (StrippedPtr != PtrOp && !StrippedPtrTy->isOpaque()) {
- bool HasZeroPointerIndex = false;
- Type *StrippedPtrEltTy = StrippedPtrTy->getNonOpaquePointerElementType();
-
- if (auto *C = dyn_cast<ConstantInt>(GEP.getOperand(1)))
- HasZeroPointerIndex = C->isZero();
-
- // Transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ...
- // into : GEP [10 x i8]* X, i32 0, ...
- //
- // Likewise, transform: GEP (bitcast i8* X to [0 x i8]*), i32 0, ...
- // into : GEP i8* X, ...
- //
- // This occurs when the program declares an array extern like "int X[];"
- if (HasZeroPointerIndex) {
- if (auto *CATy = dyn_cast<ArrayType>(GEPEltType)) {
- // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ?
- if (CATy->getElementType() == StrippedPtrEltTy) {
- // -> GEP i8* X, ...
- SmallVector<Value *, 8> Idx(drop_begin(GEP.indices()));
- GetElementPtrInst *Res = GetElementPtrInst::Create(
- StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName());
- Res->setIsInBounds(GEP.isInBounds());
- if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace())
- return Res;
- // Insert Res, and create an addrspacecast.
- // e.g.,
- // GEP (addrspacecast i8 addrspace(1)* X to [0 x i8]*), i32 0, ...
- // ->
- // %0 = GEP i8 addrspace(1)* X, ...
- // addrspacecast i8 addrspace(1)* %0 to i8*
- return new AddrSpaceCastInst(Builder.Insert(Res), GEPType);
- }
-
- if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrEltTy)) {
- // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ?
- if (CATy->getElementType() == XATy->getElementType()) {
- // -> GEP [10 x i8]* X, i32 0, ...
- // At this point, we know that the cast source type is a pointer
- // to an array of the same type as the destination pointer
- // array. Because the array type is never stepped over (there
- // is a leading zero) we can fold the cast into this GEP.
- if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) {
- GEP.setSourceElementType(XATy);
- return replaceOperand(GEP, 0, StrippedPtr);
- }
- // Cannot replace the base pointer directly because StrippedPtr's
- // address space is different. Instead, create a new GEP followed by
- // an addrspacecast.
- // e.g.,
- // GEP (addrspacecast [10 x i8] addrspace(1)* X to [0 x i8]*),
- // i32 0, ...
- // ->
- // %0 = GEP [10 x i8] addrspace(1)* X, ...
- // addrspacecast i8 addrspace(1)* %0 to i8*
- SmallVector<Value *, 8> Idx(GEP.indices());
- Value *NewGEP =
- Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx,
- GEP.getName(), GEP.isInBounds());
- return new AddrSpaceCastInst(NewGEP, GEPType);
- }
- }
- }
- } else if (GEP.getNumOperands() == 2 && !IsGEPSrcEleScalable) {
- // Skip if GEP source element type is scalable. The type alloc size is
- // unknown at compile-time.
- // Transform things like: %t = getelementptr i32*
- // bitcast ([2 x i32]* %str to i32*), i32 %V into: %t1 = getelementptr [2
- // x i32]* %str, i32 0, i32 %V; bitcast
- if (StrippedPtrEltTy->isArrayTy() &&
- DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) ==
- DL.getTypeAllocSize(GEPEltType)) {
- Type *IdxType = DL.getIndexType(GEPType);
- Value *Idx[2] = {Constant::getNullValue(IdxType), GEP.getOperand(1)};
- Value *NewGEP = Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx,
- GEP.getName(), GEP.isInBounds());
-
- // V and GEP are both pointer types --> BitCast
- return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType);
- }
-
- // Transform things like:
- // %V = mul i64 %N, 4
- // %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V
- // into: %t1 = getelementptr i32* %arr, i32 %N; bitcast
- if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) {
- // Check that changing the type amounts to dividing the index by a scale
- // factor.
- uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue();
- uint64_t SrcSize =
- DL.getTypeAllocSize(StrippedPtrEltTy).getFixedValue();
- if (ResSize && SrcSize % ResSize == 0) {
- Value *Idx = GEP.getOperand(1);
- unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
- uint64_t Scale = SrcSize / ResSize;
-
- // Earlier transforms ensure that the index has the right type
- // according to Data Layout, which considerably simplifies the
- // logic by eliminating implicit casts.
- assert(Idx->getType() == DL.getIndexType(GEPType) &&
- "Index type does not match the Data Layout preferences");
-
- bool NSW;
- if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) {
- // Successfully decomposed Idx as NewIdx * Scale, form a new GEP.
- // If the multiplication NewIdx * Scale may overflow then the new
- // GEP may not be "inbounds".
- Value *NewGEP =
- Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, NewIdx,
- GEP.getName(), GEP.isInBounds() && NSW);
-
- // The NewGEP must be pointer typed, so must the old one -> BitCast
- return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP,
- GEPType);
- }
- }
- }
-
- // Similarly, transform things like:
- // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp
- // (where tmp = 8*tmp2) into:
- // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast
- if (GEPEltType->isSized() && StrippedPtrEltTy->isSized() &&
- StrippedPtrEltTy->isArrayTy()) {
- // Check that changing to the array element type amounts to dividing the
- // index by a scale factor.
- uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue();
- uint64_t ArrayEltSize =
- DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType())
- .getFixedValue();
- if (ResSize && ArrayEltSize % ResSize == 0) {
- Value *Idx = GEP.getOperand(1);
- unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
- uint64_t Scale = ArrayEltSize / ResSize;
-
- // Earlier transforms ensure that the index has the right type
- // according to the Data Layout, which considerably simplifies
- // the logic by eliminating implicit casts.
- assert(Idx->getType() == DL.getIndexType(GEPType) &&
- "Index type does not match the Data Layout preferences");
-
- bool NSW;
- if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) {
- // Successfully decomposed Idx as NewIdx * Scale, form a new GEP.
- // If the multiplication NewIdx * Scale may overflow then the new
- // GEP may not be "inbounds".
- Type *IndTy = DL.getIndexType(GEPType);
- Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx};
-
- Value *NewGEP =
- Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Off,
- GEP.getName(), GEP.isInBounds() && NSW);
- // The NewGEP must be pointer typed, so must the old one -> BitCast
- return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP,
- GEPType);
- }
- }
- }
- }
- }
-
- // addrspacecast between types is canonicalized as a bitcast, then an
- // addrspacecast. To take advantage of the below bitcast + struct GEP, look
- // through the addrspacecast.
- Value *ASCStrippedPtrOp = PtrOp;
- if (auto *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) {
- // X = bitcast A addrspace(1)* to B addrspace(1)*
- // Y = addrspacecast A addrspace(1)* to B addrspace(2)*
- // Z = gep Y, <...constant indices...>
- // Into an addrspacecasted GEP of the struct.
- if (auto *BC = dyn_cast<BitCastInst>(ASC->getOperand(0)))
- ASCStrippedPtrOp = BC;
- }
-
- if (auto *BCI = dyn_cast<BitCastInst>(ASCStrippedPtrOp))
- if (Instruction *I = visitGEPOfBitcast(BCI, GEP))
- return I;
-
if (!GEP.isInBounds()) {
unsigned IdxWidth =
DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace());
@@ -2690,12 +2276,13 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
Value *UnderlyingPtrOp =
PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL,
BasePtrOffset);
- if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) {
+ bool CanBeNull, CanBeFreed;
+ uint64_t DerefBytes = UnderlyingPtrOp->getPointerDereferenceableBytes(
+ DL, CanBeNull, CanBeFreed);
+ if (!CanBeNull && !CanBeFreed && DerefBytes != 0) {
if (GEP.accumulateConstantOffset(DL, BasePtrOffset) &&
BasePtrOffset.isNonNegative()) {
- APInt AllocSize(
- IdxWidth,
- DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinValue());
+ APInt AllocSize(IdxWidth, DerefBytes);
if (BasePtrOffset.ule(AllocSize)) {
return GetElementPtrInst::CreateInBounds(
GEP.getSourceElementType(), PtrOp, Indices, GEP.getName());
@@ -2881,8 +2468,11 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
if (II->getIntrinsicID() == Intrinsic::objectsize) {
- Value *Result =
- lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/true);
+ SmallVector<Instruction *> InsertedInstructions;
+ Value *Result = lowerObjectSizeCall(
+ II, DL, &TLI, AA, /*MustSucceed=*/true, &InsertedInstructions);
+ for (Instruction *Inserted : InsertedInstructions)
+ Worklist.add(Inserted);
replaceInstUsesWith(*I, Result);
eraseInstFromFunction(*I);
Users[i] = nullptr; // Skip examining in the next loop.
@@ -3089,50 +2679,27 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) {
return nullptr;
}
-static bool isMustTailCall(Value *V) {
- if (auto *CI = dyn_cast<CallInst>(V))
- return CI->isMustTailCall();
- return false;
-}
-
Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
- if (RI.getNumOperands() == 0) // ret void
- return nullptr;
-
- Value *ResultOp = RI.getOperand(0);
- Type *VTy = ResultOp->getType();
- if (!VTy->isIntegerTy() || isa<Constant>(ResultOp))
- return nullptr;
-
- // Don't replace result of musttail calls.
- if (isMustTailCall(ResultOp))
- return nullptr;
-
- // There might be assume intrinsics dominating this return that completely
- // determine the value. If so, constant fold it.
- KnownBits Known = computeKnownBits(ResultOp, 0, &RI);
- if (Known.isConstant())
- return replaceOperand(RI, 0,
- Constant::getIntegerValue(VTy, Known.getConstant()));
-
+ // Nothing for now.
return nullptr;
}
// WARNING: keep in sync with SimplifyCFGOpt::simplifyUnreachable()!
-Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) {
+bool InstCombinerImpl::removeInstructionsBeforeUnreachable(Instruction &I) {
// Try to remove the previous instruction if it must lead to unreachable.
// This includes instructions like stores and "llvm.assume" that may not get
// removed by simple dead code elimination.
+ bool Changed = false;
while (Instruction *Prev = I.getPrevNonDebugInstruction()) {
// While we theoretically can erase EH, that would result in a block that
// used to start with an EH no longer starting with EH, which is invalid.
// To make it valid, we'd need to fixup predecessors to no longer refer to
// this block, but that changes CFG, which is not allowed in InstCombine.
if (Prev->isEHPad())
- return nullptr; // Can not drop any more instructions. We're done here.
+ break; // Can not drop any more instructions. We're done here.
if (!isGuaranteedToTransferExecutionToSuccessor(Prev))
- return nullptr; // Can not drop any more instructions. We're done here.
+ break; // Can not drop any more instructions. We're done here.
// Otherwise, this instruction can be freely erased,
// even if it is not side-effect free.
@@ -3140,9 +2707,13 @@ Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) {
// another unreachable block), so convert those to poison.
replaceInstUsesWith(*Prev, PoisonValue::get(Prev->getType()));
eraseInstFromFunction(*Prev);
+ Changed = true;
}
- assert(I.getParent()->sizeWithoutDebug() == 1 && "The block is now empty.");
- // FIXME: recurse into unconditional predecessors?
+ return Changed;
+}
+
+Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) {
+ removeInstructionsBeforeUnreachable(I);
return nullptr;
}
@@ -3175,6 +2746,57 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) {
return nullptr;
}
+// Under the assumption that I is unreachable, remove it and the instructions
+// that follow it.
+bool InstCombinerImpl::handleUnreachableFrom(Instruction *I) {
+ bool Changed = false;
+ BasicBlock *BB = I->getParent();
+ for (Instruction &Inst : make_early_inc_range(
+ make_range(std::next(BB->getTerminator()->getReverseIterator()),
+ std::next(I->getReverseIterator())))) {
+ if (!Inst.use_empty() && !Inst.getType()->isTokenTy()) {
+ replaceInstUsesWith(Inst, PoisonValue::get(Inst.getType()));
+ Changed = true;
+ }
+ if (Inst.isEHPad() || Inst.getType()->isTokenTy())
+ continue;
+ eraseInstFromFunction(Inst);
+ Changed = true;
+ }
+
+ // Replace phi node operands in successor blocks with poison.
+ for (BasicBlock *Succ : successors(BB))
+ for (PHINode &PN : Succ->phis())
+ for (Use &U : PN.incoming_values())
+ if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) {
+ replaceUse(U, PoisonValue::get(PN.getType()));
+ addToWorklist(&PN);
+ Changed = true;
+ }
+
+ // TODO: Successor blocks may also be dead.
+ return Changed;
+}
+
+bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
+ BasicBlock *LiveSucc) {
+ bool Changed = false;
+ for (BasicBlock *Succ : successors(BB)) {
+ // The live successor isn't dead.
+ if (Succ == LiveSucc)
+ continue;
+
+ if (!all_of(predecessors(Succ), [&](BasicBlock *Pred) {
+ return DT.dominates(BasicBlockEdge(BB, Succ),
+ BasicBlockEdge(Pred, Succ));
+ }))
+ continue;
+
+ Changed |= handleUnreachableFrom(&Succ->front());
+ }
+ return Changed;
+}
+
Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
if (BI.isUnconditional())
return visitUnconditionalBranchInst(BI);
@@ -3218,6 +2840,14 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return &BI;
}
+ if (isa<UndefValue>(Cond) &&
+ handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr))
+ return &BI;
+ if (auto *CI = dyn_cast<ConstantInt>(Cond))
+ if (handlePotentiallyDeadSuccessors(BI.getParent(),
+ BI.getSuccessor(!CI->getZExtValue())))
+ return &BI;
+
return nullptr;
}
@@ -3236,6 +2866,14 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
return replaceOperand(SI, 0, Op0);
}
+ if (isa<UndefValue>(Cond) &&
+ handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr))
+ return &SI;
+ if (auto *CI = dyn_cast<ConstantInt>(Cond))
+ if (handlePotentiallyDeadSuccessors(
+ SI.getParent(), SI.findCaseValue(CI)->getCaseSuccessor()))
+ return &SI;
+
KnownBits Known = computeKnownBits(Cond, 0, &SI);
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
@@ -3243,10 +2881,10 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
// Compute the number of leading bits we can ignore.
// TODO: A better way to determine this would use ComputeNumSignBits().
for (const auto &C : SI.cases()) {
- LeadingKnownZeros = std::min(
- LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros());
- LeadingKnownOnes = std::min(
- LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes());
+ LeadingKnownZeros =
+ std::min(LeadingKnownZeros, C.getCaseValue()->getValue().countl_zero());
+ LeadingKnownOnes =
+ std::min(LeadingKnownOnes, C.getCaseValue()->getValue().countl_one());
}
unsigned NewWidth = Known.getBitWidth() - std::max(LeadingKnownZeros, LeadingKnownOnes);
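// Editorial sketch (not part of this patch) of the width computation above,
// using C++20 std::countl_zero in place of APInt::countl_zero and ignoring
// the known bits of the condition and the leading-ones case. For 32-bit case
// values {2, 3, 5} at least 29 leading zero bits are shared, so 3 bits of the
// condition are enough to distinguish the cases. Names are illustrative only.
#include <algorithm>
#include <bit>
#include <cstdint>
#include <initializer_list>
constexpr unsigned sharedLeadingZeros(std::initializer_list<std::uint32_t> Cases) {
  unsigned LeadingZeros = 32;
  for (std::uint32_t C : Cases)
    LeadingZeros = std::min(LeadingZeros, unsigned(std::countl_zero(C)));
  return LeadingZeros;
}
static_assert(32 - sharedLeadingZeros({2, 3, 5}) == 3,
              "three low bits suffice for these cases");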
@@ -3412,6 +3050,11 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
return R;
if (LoadInst *L = dyn_cast<LoadInst>(Agg)) {
+ // Bail out if the aggregate contains scalable vector type
+ if (auto *STy = dyn_cast<StructType>(Agg->getType());
+ STy && STy->containsScalableVectorType())
+ return nullptr;
+
// If the (non-volatile) load only has one use, we can rewrite this to a
// load from a GEP. This reduces the size of the load. If a load is used
// only by extractvalue instructions then this either must have been
@@ -3965,6 +3608,17 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) {
return Changed;
}
+// Check if any direct or bitcast user of this value is a shuffle instruction.
+static bool isUsedWithinShuffleVector(Value *V) {
+ for (auto *U : V->users()) {
+ if (isa<ShuffleVectorInst>(U))
+ return true;
+ else if (match(U, m_BitCast(m_Specific(V))) && isUsedWithinShuffleVector(U))
+ return true;
+ }
+ return false;
+}
+
Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) {
Value *Op0 = I.getOperand(0);
@@ -4014,8 +3668,14 @@ Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) {
return BestValue;
};
- if (match(Op0, m_Undef()))
+ if (match(Op0, m_Undef())) {
+ // Don't fold freeze(undef/poison) if it's used as a vector operand in
+ // a shuffle. This may improve codegen for shuffles that allow
+ // unspecified inputs.
+ if (isUsedWithinShuffleVector(&I))
+ return nullptr;
return replaceInstUsesWith(I, getUndefReplacement(I.getType()));
+ }
Constant *C;
if (match(Op0, m_Constant(C)) && C->containsUndefOrPoisonElement()) {
@@ -4078,8 +3738,8 @@ static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) {
/// beginning of DestBlock, which can only happen if it's safe to move the
/// instruction past all of the instructions between it and the end of its
/// block.
-static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock,
- TargetLibraryInfo &TLI) {
+bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
+ BasicBlock *DestBlock) {
BasicBlock *SrcBlock = I->getParent();
// Cannot move control-flow-involving, volatile loads, vaarg, etc.
@@ -4126,10 +3786,13 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock,
return false;
}
- I->dropDroppableUses([DestBlock](const Use *U) {
- if (auto *I = dyn_cast<Instruction>(U->getUser()))
- return I->getParent() != DestBlock;
- return true;
+ I->dropDroppableUses([&](const Use *U) {
+ auto *I = dyn_cast<Instruction>(U->getUser());
+ if (I && I->getParent() != DestBlock) {
+ Worklist.add(I);
+ return true;
+ }
+ return false;
});
/// FIXME: We could remove droppable uses that are not dominated by
/// the new position.
@@ -4227,23 +3890,6 @@ bool InstCombinerImpl::run() {
if (!DebugCounter::shouldExecute(VisitCounter))
continue;
- // Instruction isn't dead, see if we can constant propagate it.
- if (!I->use_empty() &&
- (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) {
- if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) {
- LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I
- << '\n');
-
- // Add operands to the worklist.
- replaceInstUsesWith(*I, C);
- ++NumConstProp;
- if (isInstructionTriviallyDead(I, &TLI))
- eraseInstFromFunction(*I);
- MadeIRChange = true;
- continue;
- }
- }
-
// See if we can trivially sink this instruction to its user if we can
// prove that the successor is not executed more frequently than our block.
// Return the UserBlock if successful.
@@ -4319,7 +3965,7 @@ bool InstCombinerImpl::run() {
if (OptBB) {
auto *UserParent = *OptBB;
// Okay, the CFG is simple enough, try to sink this instruction.
- if (TryToSinkInstruction(I, UserParent, TLI)) {
+ if (tryToSinkInstruction(I, UserParent)) {
LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n');
MadeIRChange = true;
// We'll add uses of the sunk instruction below, but since
@@ -4520,15 +4166,21 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
// Recursively visit successors. If this is a branch or switch on a
// constant, only visit the reachable successor.
Instruction *TI = BB->getTerminator();
- if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
- if (BI->isConditional() && isa<ConstantInt>(BI->getCondition())) {
- bool CondVal = cast<ConstantInt>(BI->getCondition())->getZExtValue();
+ if (BranchInst *BI = dyn_cast<BranchInst>(TI); BI && BI->isConditional()) {
+ if (isa<UndefValue>(BI->getCondition()))
+ // Branch on undef is UB.
+ continue;
+ if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition())) {
+ bool CondVal = Cond->getZExtValue();
BasicBlock *ReachableBB = BI->getSuccessor(!CondVal);
Worklist.push_back(ReachableBB);
continue;
}
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
- if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) {
+ if (isa<UndefValue>(SI->getCondition()))
+ // Switch on undef is UB.
+ continue;
+ if (auto *Cond = dyn_cast<ConstantInt>(SI->getCondition())) {
Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor());
continue;
}
@@ -4584,7 +4236,6 @@ static bool combineInstructionsOverFunction(
DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) {
auto &DL = F.getParent()->getDataLayout();
- MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue());
/// Builder - This is an IRBuilder that automatically inserts new
/// instructions into the worklist when they are created.
@@ -4601,13 +4252,6 @@ static bool combineInstructionsOverFunction(
bool MadeIRChange = false;
if (ShouldLowerDbgDeclare)
MadeIRChange = LowerDbgDeclare(F);
- // LowerDbgDeclare calls RemoveRedundantDbgInstrs, but LowerDbgDeclare will
- // almost never return true when running an assignment tracking build. Take
- // this opportunity to do some clean up for assignment tracking builds too.
- if (!MadeIRChange && isAssignmentTrackingEnabled(*F.getParent())) {
- for (auto &BB : F)
- RemoveRedundantDbgInstrs(&BB);
- }
// Iterate while there is work to do.
unsigned Iteration = 0;
@@ -4643,13 +4287,29 @@ static bool combineInstructionsOverFunction(
MadeIRChange = true;
}
+ if (Iteration == 1)
+ ++NumOneIteration;
+ else if (Iteration == 2)
+ ++NumTwoIterations;
+ else if (Iteration == 3)
+ ++NumThreeIterations;
+ else
+ ++NumFourOrMoreIterations;
+
return MadeIRChange;
}
-InstCombinePass::InstCombinePass() : MaxIterations(LimitMaxIterations) {}
+InstCombinePass::InstCombinePass(InstCombineOptions Opts) : Options(Opts) {}
-InstCombinePass::InstCombinePass(unsigned MaxIterations)
- : MaxIterations(MaxIterations) {}
+void InstCombinePass::printPipeline(
+ raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
+ static_cast<PassInfoMixin<InstCombinePass> *>(this)->printPipeline(
+ OS, MapClassName2PassName);
+ OS << '<';
+ OS << "max-iterations=" << Options.MaxIterations << ";";
+ OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info";
+ OS << '>';
+}
PreservedAnalyses InstCombinePass::run(Function &F,
FunctionAnalysisManager &AM) {
@@ -4659,7 +4319,11 @@ PreservedAnalyses InstCombinePass::run(Function &F,
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+ // TODO: Only use LoopInfo when the option is set. This requires that the
+ // callers in the pass pipeline explicitly set the option.
auto *LI = AM.getCachedResult<LoopAnalysis>(F);
+ if (!LI && Options.UseLoopInfo)
+ LI = &AM.getResult<LoopAnalysis>(F);
auto *AA = &AM.getResult<AAManager>(F);
auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
@@ -4669,7 +4333,7 @@ PreservedAnalyses InstCombinePass::run(Function &F,
&AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
- BFI, PSI, MaxIterations, LI))
+ BFI, PSI, Options.MaxIterations, LI))
// No changes, all analyses are preserved.
return PreservedAnalyses::all();
@@ -4718,18 +4382,13 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
nullptr;
return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
- BFI, PSI, MaxIterations, LI);
+ BFI, PSI,
+ InstCombineDefaultMaxIterations, LI);
}
char InstructionCombiningPass::ID = 0;
-InstructionCombiningPass::InstructionCombiningPass()
- : FunctionPass(ID), MaxIterations(InstCombineDefaultMaxIterations) {
- initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry());
-}
-
-InstructionCombiningPass::InstructionCombiningPass(unsigned MaxIterations)
- : FunctionPass(ID), MaxIterations(MaxIterations) {
+InstructionCombiningPass::InstructionCombiningPass() : FunctionPass(ID) {
initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry());
}
@@ -4752,18 +4411,6 @@ void llvm::initializeInstCombine(PassRegistry &Registry) {
initializeInstructionCombiningPassPass(Registry);
}
-void LLVMInitializeInstCombine(LLVMPassRegistryRef R) {
- initializeInstructionCombiningPassPass(*unwrap(R));
-}
-
FunctionPass *llvm::createInstructionCombiningPass() {
return new InstructionCombiningPass();
}
-
-FunctionPass *llvm::createInstructionCombiningPass(unsigned MaxIterations) {
- return new InstructionCombiningPass(MaxIterations);
-}
-
-void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createInstructionCombiningPass());
-}
diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
index 599eeeabc143..bde5fba20f3b 100644
--- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -24,7 +24,6 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/MemoryBuiltins.h"
@@ -70,6 +69,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h"
@@ -492,7 +492,7 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize,
bool IsMIPS64 = TargetTriple.isMIPS64();
bool IsArmOrThumb = TargetTriple.isARM() || TargetTriple.isThumb();
bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64;
- bool IsLoongArch64 = TargetTriple.getArch() == Triple::loongarch64;
+ bool IsLoongArch64 = TargetTriple.isLoongArch64();
bool IsRISCV64 = TargetTriple.getArch() == Triple::riscv64;
bool IsWindows = TargetTriple.isOSWindows();
bool IsFuchsia = TargetTriple.isOSFuchsia();
@@ -656,6 +656,7 @@ struct AddressSanitizer {
: UseAfterReturn),
SSGI(SSGI) {
C = &(M.getContext());
+ DL = &M.getDataLayout();
LongSize = M.getDataLayout().getPointerSizeInBits();
IntptrTy = Type::getIntNTy(*C, LongSize);
Int8PtrTy = Type::getInt8PtrTy(*C);
@@ -667,17 +668,8 @@ struct AddressSanitizer {
assert(this->UseAfterReturn != AsanDetectStackUseAfterReturnMode::Invalid);
}
- uint64_t getAllocaSizeInBytes(const AllocaInst &AI) const {
- uint64_t ArraySize = 1;
- if (AI.isArrayAllocation()) {
- const ConstantInt *CI = dyn_cast<ConstantInt>(AI.getArraySize());
- assert(CI && "non-constant array size");
- ArraySize = CI->getZExtValue();
- }
- Type *Ty = AI.getAllocatedType();
- uint64_t SizeInBytes =
- AI.getModule()->getDataLayout().getTypeAllocSize(Ty);
- return SizeInBytes * ArraySize;
+ TypeSize getAllocaSizeInBytes(const AllocaInst &AI) const {
+ return *AI.getAllocationSize(AI.getModule()->getDataLayout());
}
/// Check if we want (and can) handle this alloca.
@@ -692,19 +684,27 @@ struct AddressSanitizer {
const DataLayout &DL);
void instrumentPointerComparisonOrSubtraction(Instruction *I);
void instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore,
- Value *Addr, uint32_t TypeSize, bool IsWrite,
+ Value *Addr, MaybeAlign Alignment,
+ uint32_t TypeStoreSize, bool IsWrite,
Value *SizeArgument, bool UseCalls, uint32_t Exp);
Instruction *instrumentAMDGPUAddress(Instruction *OrigIns,
Instruction *InsertBefore, Value *Addr,
- uint32_t TypeSize, bool IsWrite,
+ uint32_t TypeStoreSize, bool IsWrite,
Value *SizeArgument);
void instrumentUnusualSizeOrAlignment(Instruction *I,
Instruction *InsertBefore, Value *Addr,
- uint32_t TypeSize, bool IsWrite,
+ TypeSize TypeStoreSize, bool IsWrite,
Value *SizeArgument, bool UseCalls,
uint32_t Exp);
+ void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, const DataLayout &DL,
+ Type *IntptrTy, Value *Mask, Value *EVL,
+ Value *Stride, Instruction *I, Value *Addr,
+ MaybeAlign Alignment, unsigned Granularity,
+ Type *OpType, bool IsWrite,
+ Value *SizeArgument, bool UseCalls,
+ uint32_t Exp);
Value *createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong,
- Value *ShadowValue, uint32_t TypeSize);
+ Value *ShadowValue, uint32_t TypeStoreSize);
Instruction *generateCrashCode(Instruction *InsertBefore, Value *Addr,
bool IsWrite, size_t AccessSizeIndex,
Value *SizeArgument, uint32_t Exp);
@@ -724,7 +724,7 @@ private:
bool LooksLikeCodeInBug11395(Instruction *I);
bool GlobalIsLinkerInitialized(GlobalVariable *G);
bool isSafeAccess(ObjectSizeOffsetVisitor &ObjSizeVis, Value *Addr,
- uint64_t TypeSize) const;
+ TypeSize TypeStoreSize) const;
/// Helper to cleanup per-function state.
struct FunctionStateRAII {
@@ -743,6 +743,7 @@ private:
};
LLVMContext *C;
+ const DataLayout *DL;
Triple TargetTriple;
int LongSize;
bool CompileKernel;
@@ -1040,7 +1041,9 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
/// Collect Alloca instructions we want (and can) handle.
void visitAllocaInst(AllocaInst &AI) {
- if (!ASan.isInterestingAlloca(AI)) {
+ // FIXME: Handle scalable vectors instead of ignoring them.
+ if (!ASan.isInterestingAlloca(AI) ||
+ isa<ScalableVectorType>(AI.getAllocatedType())) {
if (AI.isStaticAlloca()) {
// Skip over allocas that are present *before* the first instrumented
// alloca, we don't want to move those around.
@@ -1133,10 +1136,10 @@ void AddressSanitizerPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<AddressSanitizerPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (Options.CompileKernel)
OS << "kernel";
- OS << ">";
+ OS << '>';
}
AddressSanitizerPass::AddressSanitizerPass(
@@ -1176,8 +1179,8 @@ PreservedAnalyses AddressSanitizerPass::run(Module &M,
return PA;
}
-static size_t TypeSizeToSizeIndex(uint32_t TypeSize) {
- size_t Res = countTrailingZeros(TypeSize / 8);
+static size_t TypeStoreSizeToSizeIndex(uint32_t TypeSize) {
+ size_t Res = llvm::countr_zero(TypeSize / 8);
assert(Res < kNumberOfAccessSizes);
return Res;
}
@@ -1227,7 +1230,7 @@ Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) {
// Instrument memset/memmove/memcpy
void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) {
- IRBuilder<> IRB(MI);
+ InstrumentationIRBuilder IRB(MI);
if (isa<MemTransferInst>(MI)) {
IRB.CreateCall(
isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy,
@@ -1254,7 +1257,7 @@ bool AddressSanitizer::isInterestingAlloca(const AllocaInst &AI) {
bool IsInteresting =
(AI.getAllocatedType()->isSized() &&
// alloca() may be called with 0 size, ignore it.
- ((!AI.isStaticAlloca()) || getAllocaSizeInBytes(AI) > 0) &&
+ ((!AI.isStaticAlloca()) || !getAllocaSizeInBytes(AI).isZero()) &&
// We are only interested in allocas not promotable to registers.
// Promotable allocas are common under -O0.
(!ClSkipPromotableAllocas || !isAllocaPromotable(&AI)) &&
@@ -1326,9 +1329,12 @@ void AddressSanitizer::getInterestingMemoryOperands(
XCHG->getCompareOperand()->getType(),
std::nullopt);
} else if (auto CI = dyn_cast<CallInst>(I)) {
- if (CI->getIntrinsicID() == Intrinsic::masked_load ||
- CI->getIntrinsicID() == Intrinsic::masked_store) {
- bool IsWrite = CI->getIntrinsicID() == Intrinsic::masked_store;
+ switch (CI->getIntrinsicID()) {
+ case Intrinsic::masked_load:
+ case Intrinsic::masked_store:
+ case Intrinsic::masked_gather:
+ case Intrinsic::masked_scatter: {
+ bool IsWrite = CI->getType()->isVoidTy();
// Masked store has an initial operand for the value.
unsigned OpOffset = IsWrite ? 1 : 0;
if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads)
@@ -1344,7 +1350,76 @@ void AddressSanitizer::getInterestingMemoryOperands(
Alignment = Op->getMaybeAlignValue();
Value *Mask = CI->getOperand(2 + OpOffset);
Interesting.emplace_back(I, OpOffset, IsWrite, Ty, Alignment, Mask);
- } else {
+ break;
+ }
+ case Intrinsic::masked_expandload:
+ case Intrinsic::masked_compressstore: {
+ bool IsWrite = CI->getIntrinsicID() == Intrinsic::masked_compressstore;
+ unsigned OpOffset = IsWrite ? 1 : 0;
+ if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads)
+ return;
+ auto BasePtr = CI->getOperand(OpOffset);
+ if (ignoreAccess(I, BasePtr))
+ return;
+ MaybeAlign Alignment = BasePtr->getPointerAlignment(*DL);
+ Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType();
+
+ IRBuilder IB(I);
+ Value *Mask = CI->getOperand(1 + OpOffset);
+ // Use the popcount of Mask as the effective vector length.
+ Type *ExtTy = VectorType::get(IntptrTy, cast<VectorType>(Ty));
+ Value *ExtMask = IB.CreateZExt(Mask, ExtTy);
+ Value *EVL = IB.CreateAddReduce(ExtMask);
+ Value *TrueMask = ConstantInt::get(Mask->getType(), 1);
+ Interesting.emplace_back(I, OpOffset, IsWrite, Ty, Alignment, TrueMask,
+ EVL);
+ break;
+ }
+ case Intrinsic::vp_load:
+ case Intrinsic::vp_store:
+ case Intrinsic::experimental_vp_strided_load:
+ case Intrinsic::experimental_vp_strided_store: {
+ auto *VPI = cast<VPIntrinsic>(CI);
+ unsigned IID = CI->getIntrinsicID();
+ bool IsWrite = CI->getType()->isVoidTy();
+ if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads)
+ return;
+ unsigned PtrOpNo = *VPI->getMemoryPointerParamPos(IID);
+ Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType();
+ MaybeAlign Alignment = VPI->getOperand(PtrOpNo)->getPointerAlignment(*DL);
+ Value *Stride = nullptr;
+ if (IID == Intrinsic::experimental_vp_strided_store ||
+ IID == Intrinsic::experimental_vp_strided_load) {
+ Stride = VPI->getOperand(PtrOpNo + 1);
+ // Use the pointer alignment as the element alignment if the stride is a
+ // multiple of the pointer alignment. Otherwise, the element alignment
+ // should be Align(1).
+ unsigned PointerAlign = Alignment.valueOrOne().value();
+ if (!isa<ConstantInt>(Stride) ||
+ cast<ConstantInt>(Stride)->getZExtValue() % PointerAlign != 0)
+ Alignment = Align(1);
+ }
+ Interesting.emplace_back(I, PtrOpNo, IsWrite, Ty, Alignment,
+ VPI->getMaskParam(), VPI->getVectorLengthParam(),
+ Stride);
+ break;
+ }
+ case Intrinsic::vp_gather:
+ case Intrinsic::vp_scatter: {
+ auto *VPI = cast<VPIntrinsic>(CI);
+ unsigned IID = CI->getIntrinsicID();
+ bool IsWrite = IID == Intrinsic::vp_scatter;
+ if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads)
+ return;
+ unsigned PtrOpNo = *VPI->getMemoryPointerParamPos(IID);
+ Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType();
+ MaybeAlign Alignment = VPI->getPointerAlignment();
+ Interesting.emplace_back(I, PtrOpNo, IsWrite, Ty, Alignment,
+ VPI->getMaskParam(),
+ VPI->getVectorLengthParam());
+ break;
+ }
+ default:
for (unsigned ArgNo = 0; ArgNo < CI->arg_size(); ArgNo++) {
if (!ClInstrumentByval || !CI->isByValArgument(ArgNo) ||
ignoreAccess(I, CI->getArgOperand(ArgNo)))
@@ -1416,57 +1491,94 @@ void AddressSanitizer::instrumentPointerComparisonOrSubtraction(
static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I,
Instruction *InsertBefore, Value *Addr,
MaybeAlign Alignment, unsigned Granularity,
- uint32_t TypeSize, bool IsWrite,
+ TypeSize TypeStoreSize, bool IsWrite,
Value *SizeArgument, bool UseCalls,
uint32_t Exp) {
// Instrument a 1-, 2-, 4-, 8-, or 16- byte access with one check
// if the data is properly aligned.
- if ((TypeSize == 8 || TypeSize == 16 || TypeSize == 32 || TypeSize == 64 ||
- TypeSize == 128) &&
- (!Alignment || *Alignment >= Granularity || *Alignment >= TypeSize / 8))
- return Pass->instrumentAddress(I, InsertBefore, Addr, TypeSize, IsWrite,
- nullptr, UseCalls, Exp);
- Pass->instrumentUnusualSizeOrAlignment(I, InsertBefore, Addr, TypeSize,
+ if (!TypeStoreSize.isScalable()) {
+ const auto FixedSize = TypeStoreSize.getFixedValue();
+ switch (FixedSize) {
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ case 128:
+ if (!Alignment || *Alignment >= Granularity ||
+ *Alignment >= FixedSize / 8)
+ return Pass->instrumentAddress(I, InsertBefore, Addr, Alignment,
+ FixedSize, IsWrite, nullptr, UseCalls,
+ Exp);
+ }
+ }
+ Pass->instrumentUnusualSizeOrAlignment(I, InsertBefore, Addr, TypeStoreSize,
IsWrite, nullptr, UseCalls, Exp);
}
-static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass,
- const DataLayout &DL, Type *IntptrTy,
- Value *Mask, Instruction *I,
- Value *Addr, MaybeAlign Alignment,
- unsigned Granularity, Type *OpType,
- bool IsWrite, Value *SizeArgument,
- bool UseCalls, uint32_t Exp) {
- auto *VTy = cast<FixedVectorType>(OpType);
- uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
- unsigned Num = VTy->getNumElements();
+void AddressSanitizer::instrumentMaskedLoadOrStore(
+ AddressSanitizer *Pass, const DataLayout &DL, Type *IntptrTy, Value *Mask,
+ Value *EVL, Value *Stride, Instruction *I, Value *Addr,
+ MaybeAlign Alignment, unsigned Granularity, Type *OpType, bool IsWrite,
+ Value *SizeArgument, bool UseCalls, uint32_t Exp) {
+ auto *VTy = cast<VectorType>(OpType);
+ TypeSize ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
auto Zero = ConstantInt::get(IntptrTy, 0);
- for (unsigned Idx = 0; Idx < Num; ++Idx) {
- Value *InstrumentedAddress = nullptr;
- Instruction *InsertBefore = I;
- if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
- // dyn_cast as we might get UndefValue
- if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
- if (Masked->isZero())
- // Mask is constant false, so no instrumentation needed.
- continue;
- // If we have a true or undef value, fall through to doInstrumentAddress
- // with InsertBefore == I
- }
+
+ IRBuilder IB(I);
+ Instruction *LoopInsertBefore = I;
+ if (EVL) {
+ // The end argument of SplitBlockAndInsertForEachLane is assumed to be
+ // bigger than zero, so we should check whether EVL is zero here.
+ Type *EVLType = EVL->getType();
+ Value *IsEVLZero = IB.CreateICmpNE(EVL, ConstantInt::get(EVLType, 0));
+ LoopInsertBefore = SplitBlockAndInsertIfThen(IsEVLZero, I, false);
+ IB.SetInsertPoint(LoopInsertBefore);
+ // Cast EVL to IntptrTy.
+ EVL = IB.CreateZExtOrTrunc(EVL, IntptrTy);
+ // To avoid undefined behavior when extracting with an out-of-range index,
+ // use the minimum of EVL and the element count as the trip count.
+ Value *EC = IB.CreateElementCount(IntptrTy, VTy->getElementCount());
+ EVL = IB.CreateBinaryIntrinsic(Intrinsic::umin, EVL, EC);
+ } else {
+ EVL = IB.CreateElementCount(IntptrTy, VTy->getElementCount());
+ }
+
+ // Cast Stride to IntptrTy.
+ if (Stride)
+ Stride = IB.CreateZExtOrTrunc(Stride, IntptrTy);
+
+ SplitBlockAndInsertForEachLane(EVL, LoopInsertBefore,
+ [&](IRBuilderBase &IRB, Value *Index) {
+ Value *MaskElem = IRB.CreateExtractElement(Mask, Index);
+ if (auto *MaskElemC = dyn_cast<ConstantInt>(MaskElem)) {
+ if (MaskElemC->isZero())
+ // No check
+ return;
+ // Unconditional check
} else {
- IRBuilder<> IRB(I);
- Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
- Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
- InsertBefore = ThenTerm;
+ // Conditional check
+ Instruction *ThenTerm = SplitBlockAndInsertIfThen(
+ MaskElem, &*IRB.GetInsertPoint(), false);
+ IRB.SetInsertPoint(ThenTerm);
}
- IRBuilder<> IRB(InsertBefore);
- InstrumentedAddress =
- IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)});
- doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
- Granularity, ElemTypeSize, IsWrite, SizeArgument,
- UseCalls, Exp);
- }
+ Value *InstrumentedAddress;
+ if (isa<VectorType>(Addr->getType())) {
+ assert(
+ cast<VectorType>(Addr->getType())->getElementType()->isPointerTy() &&
+ "Expected vector of pointer.");
+ InstrumentedAddress = IRB.CreateExtractElement(Addr, Index);
+ } else if (Stride) {
+ Index = IRB.CreateMul(Index, Stride);
+ Addr = IRB.CreateBitCast(Addr, Type::getInt8PtrTy(*C));
+ InstrumentedAddress = IRB.CreateGEP(Type::getInt8Ty(*C), Addr, {Index});
+ } else {
+ InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index});
+ }
+ doInstrumentAddress(Pass, I, &*IRB.GetInsertPoint(),
+ InstrumentedAddress, Alignment, Granularity,
+ ElemTypeSize, IsWrite, SizeArgument, UseCalls, Exp);
+ });
}
void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,
@@ -1492,7 +1604,7 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,
// dynamically initialized global is always valid.
GlobalVariable *G = dyn_cast<GlobalVariable>(getUnderlyingObject(Addr));
if (G && (!ClInitializers || GlobalIsLinkerInitialized(G)) &&
- isSafeAccess(ObjSizeVis, Addr, O.TypeSize)) {
+ isSafeAccess(ObjSizeVis, Addr, O.TypeStoreSize)) {
NumOptimizedAccessesToGlobalVar++;
return;
}
@@ -1501,7 +1613,7 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,
if (ClOpt && ClOptStack) {
// A direct inbounds access to a stack variable is always valid.
if (isa<AllocaInst>(getUnderlyingObject(Addr)) &&
- isSafeAccess(ObjSizeVis, Addr, O.TypeSize)) {
+ isSafeAccess(ObjSizeVis, Addr, O.TypeStoreSize)) {
NumOptimizedAccessesToStackVar++;
return;
}
@@ -1514,12 +1626,13 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,
unsigned Granularity = 1 << Mapping.Scale;
if (O.MaybeMask) {
- instrumentMaskedLoadOrStore(this, DL, IntptrTy, O.MaybeMask, O.getInsn(),
- Addr, O.Alignment, Granularity, O.OpType,
- O.IsWrite, nullptr, UseCalls, Exp);
+ instrumentMaskedLoadOrStore(this, DL, IntptrTy, O.MaybeMask, O.MaybeEVL,
+ O.MaybeStride, O.getInsn(), Addr, O.Alignment,
+ Granularity, O.OpType, O.IsWrite, nullptr,
+ UseCalls, Exp);
} else {
doInstrumentAddress(this, O.getInsn(), O.getInsn(), Addr, O.Alignment,
- Granularity, O.TypeSize, O.IsWrite, nullptr, UseCalls,
+ Granularity, O.TypeStoreSize, O.IsWrite, nullptr, UseCalls,
Exp);
}
}
@@ -1529,7 +1642,7 @@ Instruction *AddressSanitizer::generateCrashCode(Instruction *InsertBefore,
size_t AccessSizeIndex,
Value *SizeArgument,
uint32_t Exp) {
- IRBuilder<> IRB(InsertBefore);
+ InstrumentationIRBuilder IRB(InsertBefore);
Value *ExpVal = Exp == 0 ? nullptr : ConstantInt::get(IRB.getInt32Ty(), Exp);
CallInst *Call = nullptr;
if (SizeArgument) {
@@ -1554,15 +1667,15 @@ Instruction *AddressSanitizer::generateCrashCode(Instruction *InsertBefore,
Value *AddressSanitizer::createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong,
Value *ShadowValue,
- uint32_t TypeSize) {
+ uint32_t TypeStoreSize) {
size_t Granularity = static_cast<size_t>(1) << Mapping.Scale;
// Addr & (Granularity - 1)
Value *LastAccessedByte =
IRB.CreateAnd(AddrLong, ConstantInt::get(IntptrTy, Granularity - 1));
// (Addr & (Granularity - 1)) + size - 1
- if (TypeSize / 8 > 1)
+ if (TypeStoreSize / 8 > 1)
LastAccessedByte = IRB.CreateAdd(
- LastAccessedByte, ConstantInt::get(IntptrTy, TypeSize / 8 - 1));
+ LastAccessedByte, ConstantInt::get(IntptrTy, TypeStoreSize / 8 - 1));
// (uint8_t) ((Addr & (Granularity-1)) + size - 1)
LastAccessedByte =
IRB.CreateIntCast(LastAccessedByte, ShadowValue->getType(), false);
@@ -1572,7 +1685,7 @@ Value *AddressSanitizer::createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong,
Instruction *AddressSanitizer::instrumentAMDGPUAddress(
Instruction *OrigIns, Instruction *InsertBefore, Value *Addr,
- uint32_t TypeSize, bool IsWrite, Value *SizeArgument) {
+ uint32_t TypeStoreSize, bool IsWrite, Value *SizeArgument) {
// Do not instrument unsupported addrspaces.
if (isUnsupportedAMDGPUAddrspace(Addr))
return nullptr;
@@ -1595,18 +1708,19 @@ Instruction *AddressSanitizer::instrumentAMDGPUAddress(
void AddressSanitizer::instrumentAddress(Instruction *OrigIns,
Instruction *InsertBefore, Value *Addr,
- uint32_t TypeSize, bool IsWrite,
+ MaybeAlign Alignment,
+ uint32_t TypeStoreSize, bool IsWrite,
Value *SizeArgument, bool UseCalls,
uint32_t Exp) {
if (TargetTriple.isAMDGPU()) {
InsertBefore = instrumentAMDGPUAddress(OrigIns, InsertBefore, Addr,
- TypeSize, IsWrite, SizeArgument);
+ TypeStoreSize, IsWrite, SizeArgument);
if (!InsertBefore)
return;
}
- IRBuilder<> IRB(InsertBefore);
- size_t AccessSizeIndex = TypeSizeToSizeIndex(TypeSize);
+ InstrumentationIRBuilder IRB(InsertBefore);
+ size_t AccessSizeIndex = TypeStoreSizeToSizeIndex(TypeStoreSize);
const ASanAccessInfo AccessInfo(IsWrite, CompileKernel, AccessSizeIndex);
if (UseCalls && ClOptimizeCallbacks) {
@@ -1631,17 +1745,19 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns,
}
Type *ShadowTy =
- IntegerType::get(*C, std::max(8U, TypeSize >> Mapping.Scale));
+ IntegerType::get(*C, std::max(8U, TypeStoreSize >> Mapping.Scale));
Type *ShadowPtrTy = PointerType::get(ShadowTy, 0);
Value *ShadowPtr = memToShadow(AddrLong, IRB);
- Value *ShadowValue =
- IRB.CreateLoad(ShadowTy, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy));
+ const uint64_t ShadowAlign =
+ std::max<uint64_t>(Alignment.valueOrOne().value() >> Mapping.Scale, 1);
+ Value *ShadowValue = IRB.CreateAlignedLoad(
+ ShadowTy, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy), Align(ShadowAlign));
Value *Cmp = IRB.CreateIsNotNull(ShadowValue);
size_t Granularity = 1ULL << Mapping.Scale;
Instruction *CrashTerm = nullptr;
- if (ClAlwaysSlowPath || (TypeSize < 8 * Granularity)) {
+ if (ClAlwaysSlowPath || (TypeStoreSize < 8 * Granularity)) {
// We use branch weights for the slow path check, to indicate that the slow
// path is rarely taken. This seems to be the case for SPEC benchmarks.
Instruction *CheckTerm = SplitBlockAndInsertIfThen(
@@ -1649,7 +1765,7 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns,
assert(cast<BranchInst>(CheckTerm)->isUnconditional());
BasicBlock *NextBB = CheckTerm->getSuccessor(0);
IRB.SetInsertPoint(CheckTerm);
- Value *Cmp2 = createSlowPathCmp(IRB, AddrLong, ShadowValue, TypeSize);
+ Value *Cmp2 = createSlowPathCmp(IRB, AddrLong, ShadowValue, TypeStoreSize);
if (Recover) {
CrashTerm = SplitBlockAndInsertIfThen(Cmp2, CheckTerm, false);
} else {
@@ -1665,7 +1781,8 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns,
Instruction *Crash = generateCrashCode(CrashTerm, AddrLong, IsWrite,
AccessSizeIndex, SizeArgument, Exp);
- Crash->setDebugLoc(OrigIns->getDebugLoc());
+ if (OrigIns->getDebugLoc())
+ Crash->setDebugLoc(OrigIns->getDebugLoc());
}
// Instrument unusual size or unusual alignment.
@@ -1673,10 +1790,12 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns,
// and the last bytes. We call __asan_report_*_n(addr, real_size) to be able
// to report the actual access size.
void AddressSanitizer::instrumentUnusualSizeOrAlignment(
- Instruction *I, Instruction *InsertBefore, Value *Addr, uint32_t TypeSize,
+ Instruction *I, Instruction *InsertBefore, Value *Addr, TypeSize TypeStoreSize,
bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp) {
- IRBuilder<> IRB(InsertBefore);
- Value *Size = ConstantInt::get(IntptrTy, TypeSize / 8);
+ InstrumentationIRBuilder IRB(InsertBefore);
+ Value *NumBits = IRB.CreateTypeSize(IntptrTy, TypeStoreSize);
+ Value *Size = IRB.CreateLShr(NumBits, ConstantInt::get(IntptrTy, 3));
+
Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy);
if (UseCalls) {
if (Exp == 0)
@@ -1686,11 +1805,13 @@ void AddressSanitizer::instrumentUnusualSizeOrAlignment(
IRB.CreateCall(AsanMemoryAccessCallbackSized[IsWrite][1],
{AddrLong, Size, ConstantInt::get(IRB.getInt32Ty(), Exp)});
} else {
+ Value *SizeMinusOne = IRB.CreateSub(Size, ConstantInt::get(IntptrTy, 1));
Value *LastByte = IRB.CreateIntToPtr(
- IRB.CreateAdd(AddrLong, ConstantInt::get(IntptrTy, TypeSize / 8 - 1)),
+ IRB.CreateAdd(AddrLong, SizeMinusOne),
Addr->getType());
- instrumentAddress(I, InsertBefore, Addr, 8, IsWrite, Size, false, Exp);
- instrumentAddress(I, InsertBefore, LastByte, 8, IsWrite, Size, false, Exp);
+ instrumentAddress(I, InsertBefore, Addr, {}, 8, IsWrite, Size, false, Exp);
+ instrumentAddress(I, InsertBefore, LastByte, {}, 8, IsWrite, Size, false,
+ Exp);
}
}
@@ -2306,7 +2427,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
G->getThreadLocalMode(), G->getAddressSpace());
NewGlobal->copyAttributesFrom(G);
NewGlobal->setComdat(G->getComdat());
- NewGlobal->setAlignment(MaybeAlign(getMinRedzoneSizeForGlobal()));
+ NewGlobal->setAlignment(Align(getMinRedzoneSizeForGlobal()));
// Don't fold globals with redzones. ODR violation detector and redzone
// poisoning implicitly creates a dependence on the global's address, so it
// is no longer valid for it to be marked unnamed_addr.
@@ -3485,7 +3606,11 @@ void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) {
// base object. For example, it is a field access or an array access with
// constant inbounds index.
bool AddressSanitizer::isSafeAccess(ObjectSizeOffsetVisitor &ObjSizeVis,
- Value *Addr, uint64_t TypeSize) const {
+ Value *Addr, TypeSize TypeStoreSize) const {
+ if (TypeStoreSize.isScalable())
+ // TODO: We can use vscale_range to convert a scalable value to an
+ // upper bound on the access size.
+ return false;
SizeOffsetType SizeOffset = ObjSizeVis.compute(Addr);
if (!ObjSizeVis.bothKnown(SizeOffset)) return false;
uint64_t Size = SizeOffset.first.getZExtValue();
@@ -3495,5 +3620,5 @@ bool AddressSanitizer::isSafeAccess(ObjectSizeOffsetVisitor &ObjSizeVis,
// . Size >= Offset (unsigned)
// . Size - Offset >= NeededSize (unsigned)
return Offset >= 0 && Size >= uint64_t(Offset) &&
- Size - uint64_t(Offset) >= TypeSize / 8;
+ Size - uint64_t(Offset) >= TypeStoreSize / 8;
}
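For context on the instrumentAddress()/createSlowPathCmp() changes above: the emitted IR boils down to a shadow-byte test plus a slow-path comparison against the last accessed byte in the granule. The sketch below restates that predicate in plain C++; it is not code from this patch, and asanWouldReport, kScale and kOffset are invented names whose example values merely stand in for whatever getShadowMapping() selects for the target.

#include <cstdint>

// Illustrative sketch only. kScale = 3 gives 8-byte shadow granules; the
// offset shown is the common Linux x86-64 default, used here as an example.
constexpr unsigned kScale = 3;
constexpr uintptr_t kOffset = 0x7fff8000;

bool asanWouldReport(uintptr_t Addr, uint32_t TypeStoreSizeBytes) {
  uint8_t Shadow = *reinterpret_cast<uint8_t *>((Addr >> kScale) + kOffset);
  if (Shadow == 0)
    return false;                    // whole granule addressable: fast path
  // Slow path: the granule is partially addressable; Shadow holds the number
  // of addressable bytes, so compare it with the last byte we would touch.
  uintptr_t LastAccessedByte =
      (Addr & ((1u << kScale) - 1)) + TypeStoreSizeBytes - 1;
  return LastAccessedByte >= Shadow;
}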
diff --git a/llvm/lib/Transforms/Instrumentation/BlockCoverageInference.cpp b/llvm/lib/Transforms/Instrumentation/BlockCoverageInference.cpp
new file mode 100644
index 000000000000..0e49984c6ee3
--- /dev/null
+++ b/llvm/lib/Transforms/Instrumentation/BlockCoverageInference.cpp
@@ -0,0 +1,368 @@
+//===-- BlockCoverageInference.cpp - Minimal Execution Coverage -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Our algorithm works by first identifying a subset of nodes that must always
+// be instrumented. We call these nodes ambiguous because knowing the coverage
+// of all remaining nodes is not enough to infer their coverage status.
+//
+// In general a node v is ambiguous if there exist two entry-to-terminal paths
+// P_1 and P_2 such that:
+// 1. v not in P_1 but P_1 visits a predecessor of v, and
+// 2. v not in P_2 but P_2 visits a successor of v.
+//
+// If a node v is not ambiguous, then if condition 1 fails, we can infer v’s
+// coverage from the coverage of its predecessors, or if condition 2 fails, we
+// can infer v’s coverage from the coverage of its successors.
+//
+// Sadly, there are example CFGs where it is not possible to infer all nodes
+// from the ambiguous nodes alone. Our algorithm selects a minimum number of
+// extra nodes to add to the ambiguous nodes to form a valid instrumentation S.
+//
+// Details on this algorithm can be found in https://arxiv.org/abs/2208.13907
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/CRC.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "pgo-block-coverage"
+
+STATISTIC(NumFunctions, "Number of total functions that BCI has processed");
+STATISTIC(NumIneligibleFunctions,
+ "Number of functions for which BCI cannot run on");
+STATISTIC(NumBlocks, "Number of total basic blocks that BCI has processed");
+STATISTIC(NumInstrumentedBlocks,
+ "Number of basic blocks instrumented for coverage");
+
+BlockCoverageInference::BlockCoverageInference(const Function &F,
+ bool ForceInstrumentEntry)
+ : F(F), ForceInstrumentEntry(ForceInstrumentEntry) {
+ findDependencies();
+ assert(!ForceInstrumentEntry || shouldInstrumentBlock(F.getEntryBlock()));
+
+ ++NumFunctions;
+ for (auto &BB : F) {
+ ++NumBlocks;
+ if (shouldInstrumentBlock(BB))
+ ++NumInstrumentedBlocks;
+ }
+}
+
+BlockCoverageInference::BlockSet
+BlockCoverageInference::getDependencies(const BasicBlock &BB) const {
+ assert(BB.getParent() == &F);
+ BlockSet Dependencies;
+ auto It = PredecessorDependencies.find(&BB);
+ if (It != PredecessorDependencies.end())
+ Dependencies.set_union(It->second);
+ It = SuccessorDependencies.find(&BB);
+ if (It != SuccessorDependencies.end())
+ Dependencies.set_union(It->second);
+ return Dependencies;
+}
+
+uint64_t BlockCoverageInference::getInstrumentedBlocksHash() const {
+ JamCRC JC;
+ uint64_t Index = 0;
+ for (auto &BB : F) {
+ if (shouldInstrumentBlock(BB)) {
+ uint8_t Data[8];
+ support::endian::write64le(Data, Index);
+ JC.update(Data);
+ }
+ Index++;
+ }
+ return JC.getCRC();
+}
+
+bool BlockCoverageInference::shouldInstrumentBlock(const BasicBlock &BB) const {
+ assert(BB.getParent() == &F);
+ auto It = PredecessorDependencies.find(&BB);
+ if (It != PredecessorDependencies.end() && It->second.size())
+ return false;
+ It = SuccessorDependencies.find(&BB);
+ if (It != SuccessorDependencies.end() && It->second.size())
+ return false;
+ return true;
+}
+
+void BlockCoverageInference::findDependencies() {
+ assert(PredecessorDependencies.empty() && SuccessorDependencies.empty());
+ // Empirical analysis shows that this algorithm finishes within 5 seconds for
+ // functions with fewer than 1.5K blocks.
+ if (F.hasFnAttribute(Attribute::NoReturn) || F.size() > 1500) {
+ ++NumIneligibleFunctions;
+ return;
+ }
+
+ SmallVector<const BasicBlock *, 4> TerminalBlocks;
+ for (auto &BB : F)
+ if (succ_empty(&BB))
+ TerminalBlocks.push_back(&BB);
+
+ // Traverse the CFG backwards from the terminal blocks to make sure every
+ // block can reach some terminal block. Otherwise this algorithm will not work
+ // and we must fall back to instrumenting every block.
+ df_iterator_default_set<const BasicBlock *> Visited;
+ for (auto *BB : TerminalBlocks)
+ for (auto *N : inverse_depth_first_ext(BB, Visited))
+ (void)N;
+ if (F.size() != Visited.size()) {
+ ++NumIneligibleFunctions;
+ return;
+ }
+
+ // The current implementation for computing `PredecessorDependencies` and
+ // `SuccessorDependencies` runs in quadratic time with respect to the number
+ // of basic blocks. While we do have a more complicated linear time algorithm
+ // in https://arxiv.org/abs/2208.13907 we do not know if it will give a
+ // significant speedup in practice given that most functions tend to be
+ // relatively small in size for intended use cases.
+ auto &EntryBlock = F.getEntryBlock();
+ for (auto &BB : F) {
+ // The set of blocks that are reachable while avoiding BB.
+ BlockSet ReachableFromEntry, ReachableFromTerminal;
+ getReachableAvoiding(EntryBlock, BB, /*IsForward=*/true,
+ ReachableFromEntry);
+ for (auto *TerminalBlock : TerminalBlocks)
+ getReachableAvoiding(*TerminalBlock, BB, /*IsForward=*/false,
+ ReachableFromTerminal);
+
+ auto Preds = predecessors(&BB);
+ bool HasSuperReachablePred = llvm::any_of(Preds, [&](auto *Pred) {
+ return ReachableFromEntry.count(Pred) &&
+ ReachableFromTerminal.count(Pred);
+ });
+ if (!HasSuperReachablePred)
+ for (auto *Pred : Preds)
+ if (ReachableFromEntry.count(Pred))
+ PredecessorDependencies[&BB].insert(Pred);
+
+ auto Succs = successors(&BB);
+ bool HasSuperReachableSucc = llvm::any_of(Succs, [&](auto *Succ) {
+ return ReachableFromEntry.count(Succ) &&
+ ReachableFromTerminal.count(Succ);
+ });
+ if (!HasSuperReachableSucc)
+ for (auto *Succ : Succs)
+ if (ReachableFromTerminal.count(Succ))
+ SuccessorDependencies[&BB].insert(Succ);
+ }
+
+ if (ForceInstrumentEntry) {
+ // Force the entry block to be instrumented by clearing the blocks it can
+ // infer coverage from.
+ PredecessorDependencies[&EntryBlock].clear();
+ SuccessorDependencies[&EntryBlock].clear();
+ }
+
+ // Construct a graph where blocks are connected if there is a mutual
+ // dependency between them. This graph has a special property that it contains
+ // only paths.
+ DenseMap<const BasicBlock *, BlockSet> AdjacencyList;
+ for (auto &BB : F) {
+ for (auto *Succ : successors(&BB)) {
+ if (SuccessorDependencies[&BB].count(Succ) &&
+ PredecessorDependencies[Succ].count(&BB)) {
+ AdjacencyList[&BB].insert(Succ);
+ AdjacencyList[Succ].insert(&BB);
+ }
+ }
+ }
+
+ // Given a path with at least one node, return the next node on the path.
+ auto getNextOnPath = [&](BlockSet &Path) -> const BasicBlock * {
+ assert(Path.size());
+ auto &Neighbors = AdjacencyList[Path.back()];
+ if (Path.size() == 1) {
+ // This is the first node on the path, return its neighbor.
+ assert(Neighbors.size() == 1);
+ return Neighbors.front();
+ } else if (Neighbors.size() == 2) {
+ // This is the middle of the path, find the neighbor that is not on the
+ // path already.
+ assert(Path.size() >= 2);
+ return Path.count(Neighbors[0]) ? Neighbors[1] : Neighbors[0];
+ }
+ // This is the end of the path.
+ assert(Neighbors.size() == 1);
+ return nullptr;
+ };
+
+ // Remove all cycles in the inferencing graph.
+ for (auto &BB : F) {
+ if (AdjacencyList[&BB].size() == 1) {
+ // We found the head of some path.
+ BlockSet Path;
+ Path.insert(&BB);
+ while (const BasicBlock *Next = getNextOnPath(Path))
+ Path.insert(Next);
+ LLVM_DEBUG(dbgs() << "Found path: " << getBlockNames(Path) << "\n");
+
+ // Remove these nodes from the graph so we don't discover this path again.
+ for (auto *BB : Path)
+ AdjacencyList[BB].clear();
+
+ // Finally, remove the cycles.
+ if (PredecessorDependencies[Path.front()].size()) {
+ for (auto *BB : Path)
+ if (BB != Path.back())
+ SuccessorDependencies[BB].clear();
+ } else {
+ for (auto *BB : Path)
+ if (BB != Path.front())
+ PredecessorDependencies[BB].clear();
+ }
+ }
+ }
+ LLVM_DEBUG(dump(dbgs()));
+}
+
+void BlockCoverageInference::getReachableAvoiding(const BasicBlock &Start,
+ const BasicBlock &Avoid,
+ bool IsForward,
+ BlockSet &Reachable) const {
+ df_iterator_default_set<const BasicBlock *> Visited;
+ Visited.insert(&Avoid);
+ if (IsForward) {
+ auto Range = depth_first_ext(&Start, Visited);
+ Reachable.insert(Range.begin(), Range.end());
+ } else {
+ auto Range = inverse_depth_first_ext(&Start, Visited);
+ Reachable.insert(Range.begin(), Range.end());
+ }
+}
+
+namespace llvm {
+class DotFuncBCIInfo {
+private:
+ const BlockCoverageInference *BCI;
+ const DenseMap<const BasicBlock *, bool> *Coverage;
+
+public:
+ DotFuncBCIInfo(const BlockCoverageInference *BCI,
+ const DenseMap<const BasicBlock *, bool> *Coverage)
+ : BCI(BCI), Coverage(Coverage) {}
+
+ const Function &getFunction() { return BCI->F; }
+
+ bool isInstrumented(const BasicBlock *BB) const {
+ return BCI->shouldInstrumentBlock(*BB);
+ }
+
+ bool isCovered(const BasicBlock *BB) const {
+ return Coverage && Coverage->lookup(BB);
+ }
+
+ bool isDependent(const BasicBlock *Src, const BasicBlock *Dest) const {
+ return BCI->getDependencies(*Src).count(Dest);
+ }
+};
+
+template <>
+struct GraphTraits<DotFuncBCIInfo *> : public GraphTraits<const BasicBlock *> {
+ static NodeRef getEntryNode(DotFuncBCIInfo *Info) {
+ return &(Info->getFunction().getEntryBlock());
+ }
+
+ // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
+ using nodes_iterator = pointer_iterator<Function::const_iterator>;
+
+ static nodes_iterator nodes_begin(DotFuncBCIInfo *Info) {
+ return nodes_iterator(Info->getFunction().begin());
+ }
+
+ static nodes_iterator nodes_end(DotFuncBCIInfo *Info) {
+ return nodes_iterator(Info->getFunction().end());
+ }
+
+ static size_t size(DotFuncBCIInfo *Info) {
+ return Info->getFunction().size();
+ }
+};
+
+template <>
+struct DOTGraphTraits<DotFuncBCIInfo *> : public DefaultDOTGraphTraits {
+
+ DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
+
+ static std::string getGraphName(DotFuncBCIInfo *Info) {
+ return "BCI CFG for " + Info->getFunction().getName().str();
+ }
+
+ std::string getNodeLabel(const BasicBlock *Node, DotFuncBCIInfo *Info) {
+ return Node->getName().str();
+ }
+
+ std::string getEdgeAttributes(const BasicBlock *Src, const_succ_iterator I,
+ DotFuncBCIInfo *Info) {
+ const BasicBlock *Dest = *I;
+ if (Info->isDependent(Src, Dest))
+ return "color=red";
+ if (Info->isDependent(Dest, Src))
+ return "color=blue";
+ return "";
+ }
+
+ std::string getNodeAttributes(const BasicBlock *Node, DotFuncBCIInfo *Info) {
+ std::string Result;
+ if (Info->isInstrumented(Node))
+ Result += "style=filled,fillcolor=gray";
+ if (Info->isCovered(Node))
+ Result += std::string(Result.empty() ? "" : ",") + "color=red";
+ return Result;
+ }
+};
+
+} // namespace llvm
+
+void BlockCoverageInference::viewBlockCoverageGraph(
+ const DenseMap<const BasicBlock *, bool> *Coverage) const {
+ DotFuncBCIInfo Info(this, Coverage);
+ WriteGraph(&Info, "BCI", false,
+ "Block Coverage Inference for " + F.getName());
+}
+
+void BlockCoverageInference::dump(raw_ostream &OS) const {
+ OS << "Minimal block coverage for function \'" << F.getName()
+ << "\' (Instrumented=*)\n";
+ for (auto &BB : F) {
+ OS << (shouldInstrumentBlock(BB) ? "* " : " ") << BB.getName() << "\n";
+ auto It = PredecessorDependencies.find(&BB);
+ if (It != PredecessorDependencies.end() && It->second.size())
+ OS << " PredDeps = " << getBlockNames(It->second) << "\n";
+ It = SuccessorDependencies.find(&BB);
+ if (It != SuccessorDependencies.end() && It->second.size())
+ OS << " SuccDeps = " << getBlockNames(It->second) << "\n";
+ }
+ OS << " Instrumented Blocks Hash = 0x"
+ << Twine::utohexstr(getInstrumentedBlocksHash()) << "\n";
+}
+
+std::string
+BlockCoverageInference::getBlockNames(ArrayRef<const BasicBlock *> BBs) {
+ std::string Result;
+ raw_string_ostream OS(Result);
+ OS << "[";
+ if (!BBs.empty()) {
+ OS << BBs.front()->getName();
+ BBs = BBs.drop_front();
+ }
+ for (auto *BB : BBs)
+ OS << ", " << BB->getName();
+ OS << "]";
+ return OS.str();
+}
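To make the dependency sets above concrete: a block whose PredecessorDependencies or SuccessorDependencies set is non-empty is left uninstrumented, and its coverage can later be recovered from the blocks it depends on. The sketch below shows one plausible fixed-point propagation over getDependencies(); it is illustrative only (the actual consumer sits in the PGO instrumentation code and may differ), and inferCoverage is an invented name.

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
using namespace llvm;

// Sketch: start from the profiled (instrumented) blocks and mark a block
// covered as soon as any block it depends on is covered, until a fixpoint.
static DenseMap<const BasicBlock *, bool>
inferCoverage(const Function &F, const BlockCoverageInference &BCI,
              const DenseMap<const BasicBlock *, bool> &Profiled) {
  DenseMap<const BasicBlock *, bool> Covered(Profiled);
  bool Changed = true;
  while (Changed) {
    Changed = false;
    for (const BasicBlock &BB : F) {
      if (Covered.lookup(&BB))
        continue;
      for (const BasicBlock *Dep : BCI.getDependencies(BB)) {
        if (Covered.lookup(Dep)) {
          Covered[&BB] = true;
          Changed = true;
          break;
        }
      }
    }
  }
  return Covered;
}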
diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
index 8b1d39ad412f..709095184af5 100644
--- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
+++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
@@ -23,8 +23,6 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -56,7 +54,7 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal,
const DataLayout &DL, TargetLibraryInfo &TLI,
ObjectSizeOffsetEvaluator &ObjSizeEval,
BuilderTy &IRB, ScalarEvolution &SE) {
- uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType());
+ TypeSize NeededSize = DL.getTypeStoreSize(InstVal->getType());
LLVM_DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize)
<< " bytes\n");
@@ -71,8 +69,8 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal,
Value *Offset = SizeOffset.second;
ConstantInt *SizeCI = dyn_cast<ConstantInt>(Size);
- Type *IntTy = DL.getIntPtrType(Ptr->getType());
- Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize);
+ Type *IndexTy = DL.getIndexType(Ptr->getType());
+ Value *NeededSizeVal = IRB.CreateTypeSize(IndexTy, NeededSize);
auto SizeRange = SE.getUnsignedRange(SE.getSCEV(Size));
auto OffsetRange = SE.getUnsignedRange(SE.getSCEV(Offset));
@@ -97,7 +95,7 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal,
Value *Or = IRB.CreateOr(Cmp2, Cmp3);
if ((!SizeCI || SizeCI->getValue().slt(0)) &&
!SizeRange.getSignedMin().isNonNegative()) {
- Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0));
+ Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IndexTy, 0));
Or = IRB.CreateOr(Cmp1, Or);
}
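For reference, the condition getBoundsCheckCond() assembles from the object size/offset pair and NeededSize is, conceptually, the predicate below. This is a plain-C++ restatement with invented names, not the pass's code; the real version also skips comparisons that the SCEV unsigned/signed ranges show to be unnecessary.

#include <cstdint>

// Sketch of the out-of-bounds predicate: the access traps when the offset is
// negative, points past the object, or leaves fewer than NeededSize bytes
// before the end of the object.
bool accessOutOfBounds(int64_t Offset, uint64_t ObjSize, uint64_t NeededSize) {
  if (Offset < 0)
    return true;                         // before the start of the object
  if (ObjSize < static_cast<uint64_t>(Offset))
    return true;                         // past the end of the object
  return ObjSize - static_cast<uint64_t>(Offset) < NeededSize;
}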
diff --git a/llvm/lib/Transforms/Instrumentation/CFGMST.h b/llvm/lib/Transforms/Instrumentation/CFGMST.h
deleted file mode 100644
index 2abe8d12de3c..000000000000
--- a/llvm/lib/Transforms/Instrumentation/CFGMST.h
+++ /dev/null
@@ -1,303 +0,0 @@
-//===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a Union-find algorithm to compute the Minimum
-// Spanning Tree for a given CFG.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
-#define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
-
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Analysis/BlockFrequencyInfo.h"
-#include "llvm/Analysis/BranchProbabilityInfo.h"
-#include "llvm/Analysis/CFG.h"
-#include "llvm/Support/BranchProbability.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-#include <utility>
-#include <vector>
-
-#define DEBUG_TYPE "cfgmst"
-
-using namespace llvm;
-
-namespace llvm {
-
-/// A union-find based Minimum Spanning Tree for CFG
-///
-/// Implements a Union-find algorithm to compute the Minimum Spanning Tree
-/// for a given CFG.
-template <class Edge, class BBInfo> class CFGMST {
-public:
- Function &F;
-
- // Store all the edges in the CFG. It may contain some stale edges
- // when Removed is set.
- std::vector<std::unique_ptr<Edge>> AllEdges;
-
- // This map records the auxiliary information for each BB.
- DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
-
- // Whether the function has an exit block with no successors.
- // (For a function with an infinite loop, this block may be absent.)
- bool ExitBlockFound = false;
-
- // Find the root group of G and compress the path from G to the root.
- BBInfo *findAndCompressGroup(BBInfo *G) {
- if (G->Group != G)
- G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
- return static_cast<BBInfo *>(G->Group);
- }
-
- // Union BB1 and BB2 into the same group and return true.
- // Returns false if BB1 and BB2 are already in the same group.
- bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
- BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
- BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
-
- if (BB1G == BB2G)
- return false;
-
- // Make the smaller-rank tree a direct child of the root of the higher-rank tree.
- if (BB1G->Rank < BB2G->Rank)
- BB1G->Group = BB2G;
- else {
- BB2G->Group = BB1G;
- // If the ranks are the same, increment the rank of one root by one.
- if (BB1G->Rank == BB2G->Rank)
- BB1G->Rank++;
- }
- return true;
- }
-
- // Given a BB, return the auxiliary information.
- BBInfo &getBBInfo(const BasicBlock *BB) const {
- auto It = BBInfos.find(BB);
- assert(It->second.get() != nullptr);
- return *It->second.get();
- }
-
- // Given a BB, return the auxiliary information if it's available.
- BBInfo *findBBInfo(const BasicBlock *BB) const {
- auto It = BBInfos.find(BB);
- if (It == BBInfos.end())
- return nullptr;
- return It->second.get();
- }
-
- // Traverse the CFG using a stack. Find all the edges and assign weights.
- // Edges with large weights are put into the MST first, so they are less
- // likely to be instrumented.
- void buildEdges() {
- LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
-
- const BasicBlock *Entry = &(F.getEntryBlock());
- uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
- // If we want to instrument the entry count, lower the weight to 0.
- if (InstrumentFuncEntry)
- EntryWeight = 0;
- Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
- *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
- uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
-
- // Add a fake edge to the entry.
- EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
- LLVM_DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName()
- << " w = " << EntryWeight << "\n");
-
- // Special handling for single BB functions.
- if (succ_empty(Entry)) {
- addEdge(Entry, nullptr, EntryWeight);
- return;
- }
-
- static const uint32_t CriticalEdgeMultiplier = 1000;
-
- for (BasicBlock &BB : F) {
- Instruction *TI = BB.getTerminator();
- uint64_t BBWeight =
- (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2);
- uint64_t Weight = 2;
- if (int successors = TI->getNumSuccessors()) {
- for (int i = 0; i != successors; ++i) {
- BasicBlock *TargetBB = TI->getSuccessor(i);
- bool Critical = isCriticalEdge(TI, i);
- uint64_t scaleFactor = BBWeight;
- if (Critical) {
- if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
- scaleFactor *= CriticalEdgeMultiplier;
- else
- scaleFactor = UINT64_MAX;
- }
- if (BPI != nullptr)
- Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
- if (Weight == 0)
- Weight++;
- auto *E = &addEdge(&BB, TargetBB, Weight);
- E->IsCritical = Critical;
- LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to "
- << TargetBB->getName() << " w=" << Weight << "\n");
-
- // Keep track of entry/exit edges:
- if (&BB == Entry) {
- if (Weight > MaxEntryOutWeight) {
- MaxEntryOutWeight = Weight;
- EntryOutgoing = E;
- }
- }
-
- auto *TargetTI = TargetBB->getTerminator();
- if (TargetTI && !TargetTI->getNumSuccessors()) {
- if (Weight > MaxExitInWeight) {
- MaxExitInWeight = Weight;
- ExitIncoming = E;
- }
- }
- }
- } else {
- ExitBlockFound = true;
- Edge *ExitO = &addEdge(&BB, nullptr, BBWeight);
- if (BBWeight > MaxExitOutWeight) {
- MaxExitOutWeight = BBWeight;
- ExitOutgoing = ExitO;
- }
- LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to fake exit"
- << " w = " << BBWeight << "\n");
- }
- }
-
- // Entry/exit edge adjustment heuristic:
- // prefer instrumenting the entry edge over the exit edge
- // if possible. Those exit edges may never have a chance to be
- // executed (for instance the program is an event handling loop)
- // before the profile is asynchronously dumped.
- //
- // If EntryIncoming and ExitOutgoing have similar weights, make sure
- // ExitOutgoing is selected as the min-edge. Similarly, if EntryOutgoing
- // and ExitIncoming have similar weights, make sure ExitIncoming becomes
- // the min-edge.
- uint64_t EntryInWeight = EntryWeight;
-
- if (EntryInWeight >= MaxExitOutWeight &&
- EntryInWeight * 2 < MaxExitOutWeight * 3) {
- EntryIncoming->Weight = MaxExitOutWeight;
- ExitOutgoing->Weight = EntryInWeight + 1;
- }
-
- if (MaxEntryOutWeight >= MaxExitInWeight &&
- MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
- EntryOutgoing->Weight = MaxExitInWeight;
- ExitIncoming->Weight = MaxEntryOutWeight + 1;
- }
- }
-
- // Sort CFG edges based on their weight.
- void sortEdgesByWeight() {
- llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1,
- const std::unique_ptr<Edge> &Edge2) {
- return Edge1->Weight > Edge2->Weight;
- });
- }
-
- // Traverse all the edges and compute the Minimum Weight Spanning Tree
- // using union-find algorithm.
- void computeMinimumSpanningTree() {
- // First, put all the critical edges with a landing pad as the Dest into
- // the MST. This works around the insufficient support for splitting
- // critical edges when the destination BB is a landing pad.
- for (auto &Ei : AllEdges) {
- if (Ei->Removed)
- continue;
- if (Ei->IsCritical) {
- if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
- if (unionGroups(Ei->SrcBB, Ei->DestBB))
- Ei->InMST = true;
- }
- }
- }
-
- for (auto &Ei : AllEdges) {
- if (Ei->Removed)
- continue;
- // If we detect infinite loops, force
- // instrumenting the entry edge:
- if (!ExitBlockFound && Ei->SrcBB == nullptr)
- continue;
- if (unionGroups(Ei->SrcBB, Ei->DestBB))
- Ei->InMST = true;
- }
- }
-
- // Dump the Debug information about the instrumentation.
- void dumpEdges(raw_ostream &OS, const Twine &Message) const {
- if (!Message.str().empty())
- OS << Message << "\n";
- OS << " Number of Basic Blocks: " << BBInfos.size() << "\n";
- for (auto &BI : BBInfos) {
- const BasicBlock *BB = BI.first;
- OS << " BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << " "
- << BI.second->infoString() << "\n";
- }
-
- OS << " Number of Edges: " << AllEdges.size()
- << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
- uint32_t Count = 0;
- for (auto &EI : AllEdges)
- OS << " Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
- << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
- }
-
- // Add an edge to AllEdges with weight W.
- Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) {
- uint32_t Index = BBInfos.size();
- auto Iter = BBInfos.end();
- bool Inserted;
- std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
- if (Inserted) {
- // Newly inserted, update the real info.
- Iter->second = std::move(std::make_unique<BBInfo>(Index));
- Index++;
- }
- std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
- if (Inserted)
- // Newly inserted, update the real info.
- Iter->second = std::move(std::make_unique<BBInfo>(Index));
- AllEdges.emplace_back(new Edge(Src, Dest, W));
- return *AllEdges.back();
- }
-
- BranchProbabilityInfo *BPI;
- BlockFrequencyInfo *BFI;
-
- // Whether the function entry will always be instrumented.
- bool InstrumentFuncEntry;
-
-public:
- CFGMST(Function &Func, bool InstrumentFuncEntry_,
- BranchProbabilityInfo *BPI_ = nullptr,
- BlockFrequencyInfo *BFI_ = nullptr)
- : F(Func), BPI(BPI_), BFI(BFI_),
- InstrumentFuncEntry(InstrumentFuncEntry_) {
- buildEdges();
- sortEdgesByWeight();
- computeMinimumSpanningTree();
- if (AllEdges.size() > 1 && InstrumentFuncEntry)
- std::iter_swap(std::move(AllEdges.begin()),
- std::move(AllEdges.begin() + AllEdges.size() - 1));
- }
-};
-
-} // end namespace llvm
-
-#undef DEBUG_TYPE // "cfgmst"
-
-#endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
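CFGMST.h is not gone: per the GCOVProfiling.cpp include change below, it now lives at llvm/Transforms/Instrumentation/CFGMST.h. Its core is the union-find in findAndCompressGroup()/unionGroups(); the self-contained sketch below shows the same path-compression and union-by-rank scheme with integer ids and invented names.

#include <numeric>
#include <utility>
#include <vector>

// Sketch of the union-find used to grow the spanning tree: find() compresses
// paths, unite() links by rank and reports whether the two ids were in
// different groups (i.e. whether the edge joins two trees).
struct UnionFind {
  std::vector<int> Parent, Rank;
  explicit UnionFind(int N) : Parent(N), Rank(N, 0) {
    std::iota(Parent.begin(), Parent.end(), 0);
  }
  int find(int X) {
    if (Parent[X] != X)
      Parent[X] = find(Parent[X]);       // path compression
    return Parent[X];
  }
  bool unite(int A, int B) {
    A = find(A);
    B = find(B);
    if (A == B)
      return false;                      // already in the same group
    if (Rank[A] < Rank[B])
      std::swap(A, B);                   // attach the lower-rank root
    Parent[B] = A;
    if (Rank[A] == Rank[B])
      ++Rank[A];
    return true;
  }
};

With edges sorted by decreasing weight, computeMinimumSpanningTree() is then just: for each edge, if unite(src, dest) succeeds the edge joins the MST; per the comments above, the edges that end up outside the MST are the ones that get instrumented.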
diff --git a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp
index 1c630e9ee424..d53e12ad1ff5 100644
--- a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp
+++ b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp
@@ -15,7 +15,6 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Transforms/Instrumentation.h"
#include <optional>
@@ -46,8 +45,7 @@ addModuleFlags(Module &M,
}
static bool runCGProfilePass(
- Module &M, function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
- function_ref<TargetTransformInfo &(Function &)> GetTTI, bool LazyBFI) {
+ Module &M, FunctionAnalysisManager &FAM) {
MapVector<std::pair<Function *, Function *>, uint64_t> Counts;
InstrProfSymtab Symtab;
auto UpdateCounts = [&](TargetTransformInfo &TTI, Function *F,
@@ -64,15 +62,13 @@ static bool runCGProfilePass(
(void)(bool) Symtab.create(M);
for (auto &F : M) {
// Avoid extra cost of running passes for BFI when the function doesn't have
- // entry count. Since LazyBlockFrequencyInfoPass only exists in LPM, check
- // if using LazyBlockFrequencyInfoPass.
- // TODO: Remove LazyBFI when LazyBlockFrequencyInfoPass is available in NPM.
- if (F.isDeclaration() || (LazyBFI && !F.getEntryCount()))
+ // entry count.
+ if (F.isDeclaration() || !F.getEntryCount())
continue;
- auto &BFI = GetBFI(F);
+ auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
if (BFI.getEntryFreq() == 0)
continue;
- TargetTransformInfo &TTI = GetTTI(F);
+ TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
for (auto &BB : F) {
std::optional<uint64_t> BBCount = BFI.getBlockProfileCount(&BB);
if (!BBCount)
@@ -105,14 +101,7 @@ static bool runCGProfilePass(
PreservedAnalyses CGProfilePass::run(Module &M, ModuleAnalysisManager &MAM) {
FunctionAnalysisManager &FAM =
MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
- auto GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & {
- return FAM.getResult<BlockFrequencyAnalysis>(F);
- };
- auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & {
- return FAM.getResult<TargetIRAnalysis>(F);
- };
-
- runCGProfilePass(M, GetBFI, GetTTI, false);
+ runCGProfilePass(M, FAM);
return PreservedAnalyses::all();
}
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index a072ba278fce..3e3be536defc 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -30,7 +30,6 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBuffer.h"
@@ -1888,8 +1887,7 @@ void CHR::fixupBranch(Region *R, CHRScope *Scope,
assert((IsTrueBiased || Scope->FalseBiasedRegions.count(R)) &&
"Must be truthy or falsy");
auto *BI = cast<BranchInst>(R->getEntry()->getTerminator());
- assert(BranchBiasMap.find(R) != BranchBiasMap.end() &&
- "Must be in the bias map");
+ assert(BranchBiasMap.contains(R) && "Must be in the bias map");
BranchProbability Bias = BranchBiasMap[R];
assert(Bias >= getCHRBiasThreshold() && "Must be highly biased");
// Take the min.
@@ -1931,8 +1929,7 @@ void CHR::fixupSelect(SelectInst *SI, CHRScope *Scope,
bool IsTrueBiased = Scope->TrueBiasedSelects.count(SI);
assert((IsTrueBiased ||
Scope->FalseBiasedSelects.count(SI)) && "Must be biased");
- assert(SelectBiasMap.find(SI) != SelectBiasMap.end() &&
- "Must be in the bias map");
+ assert(SelectBiasMap.contains(SI) && "Must be in the bias map");
BranchProbability Bias = SelectBiasMap[SI];
assert(Bias >= getCHRBiasThreshold() && "Must be highly biased");
// Take the min.
@@ -1962,11 +1959,8 @@ void CHR::addToMergedCondition(bool IsTrueBiased, Value *Cond,
Cond = IRB.CreateXor(ConstantInt::getTrue(F.getContext()), Cond);
}
- // Select conditions can be poison, while branching on poison is immediate
- // undefined behavior. As such, we need to freeze potentially poisonous
- // conditions derived from selects.
- if (isa<SelectInst>(BranchOrSelect) &&
- !isGuaranteedNotToBeUndefOrPoison(Cond))
+ // Freeze potentially poisonous conditions.
+ if (!isGuaranteedNotToBeUndefOrPoison(Cond))
Cond = IRB.CreateFreeze(Cond);
// Use logical and to avoid propagating poison from later conditions.
@@ -2080,10 +2074,14 @@ ControlHeightReductionPass::ControlHeightReductionPass() {
PreservedAnalyses ControlHeightReductionPass::run(
Function &F,
FunctionAnalysisManager &FAM) {
+ auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+ auto PPSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+ // If there is no profile summary, we should not do CHR.
+ if (!PPSI || !PPSI->hasProfileSummary())
+ return PreservedAnalyses::all();
+ auto &PSI = *PPSI;
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
- auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
- auto &PSI = *MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
auto &RI = FAM.getResult<RegionInfoAnalysis>(F);
auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
bool Changed = CHR(F, BFI, DT, PSI, RI, ORE).run();
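The addToMergedCondition() change above broadens the freeze to any condition that is not provably free of undef/poison before it is merged with a poison-blocking logical and. A minimal sketch of that merge step using the public IRBuilder API follows; the names and factoring are illustrative, not the pass's actual shape.

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Sketch: negate falsy-biased conditions, freeze anything that might be
// undef/poison (branching on poison is immediate UB), then merge with a
// logical and so later conditions cannot propagate poison.
static Value *mergeBiasedCondition(IRBuilder<> &IRB, Value *MergedCondition,
                                   Value *Cond, bool IsTrueBiased) {
  if (!IsTrueBiased)
    Cond = IRB.CreateXor(ConstantInt::getTrue(IRB.getContext()), Cond);
  if (!isGuaranteedNotToBeUndefOrPoison(Cond))
    Cond = IRB.CreateFreeze(Cond);
  // Emitted as: select i1 MergedCondition, i1 Cond, i1 false
  return IRB.CreateLogicalAnd(MergedCondition, Cond);
}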
diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index e9614b48fde7..8caee5bed8ed 100644
--- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -67,12 +67,13 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/ADT/iterator.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
@@ -96,14 +97,13 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SpecialCaseList.h"
#include "llvm/Support/VirtualFileSystem.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -305,6 +305,14 @@ const MemoryMapParams Linux_X86_64_MemoryMapParams = {
};
// NOLINTEND(readability-identifier-naming)
+// loongarch64 Linux
+const MemoryMapParams Linux_LoongArch64_MemoryMapParams = {
+ 0, // AndMask (not used)
+ 0x500000000000, // XorMask
+ 0, // ShadowBase (not used)
+ 0x100000000000, // OriginBase
+};
+
namespace {
class DFSanABIList {
@@ -1128,6 +1136,9 @@ bool DataFlowSanitizer::initializeModule(Module &M) {
case Triple::x86_64:
MapParams = &Linux_X86_64_MemoryMapParams;
break;
+ case Triple::loongarch64:
+ MapParams = &Linux_LoongArch64_MemoryMapParams;
+ break;
default:
report_fatal_error("unsupported architecture");
}
@@ -1256,7 +1267,7 @@ void DataFlowSanitizer::addGlobalNameSuffix(GlobalValue *GV) {
size_t Pos = Asm.find(SearchStr);
if (Pos != std::string::npos) {
Asm.replace(Pos, SearchStr.size(), ".symver " + GVName + Suffix + ",");
- Pos = Asm.find("@");
+ Pos = Asm.find('@');
if (Pos == std::string::npos)
report_fatal_error(Twine("unsupported .symver: ", Asm));
@@ -2156,9 +2167,8 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowFast(
ShadowSize == 4 ? Type::getInt32Ty(*DFS.Ctx) : Type::getInt64Ty(*DFS.Ctx);
IRBuilder<> IRB(Pos);
- Value *WideAddr = IRB.CreateBitCast(ShadowAddr, WideShadowTy->getPointerTo());
Value *CombinedWideShadow =
- IRB.CreateAlignedLoad(WideShadowTy, WideAddr, ShadowAlign);
+ IRB.CreateAlignedLoad(WideShadowTy, ShadowAddr, ShadowAlign);
unsigned WideShadowBitWidth = WideShadowTy->getIntegerBitWidth();
const uint64_t BytesPerWideShadow = WideShadowBitWidth / DFS.ShadowWidthBits;
@@ -2195,10 +2205,10 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowFast(
// shadow).
for (uint64_t ByteOfs = BytesPerWideShadow; ByteOfs < Size;
ByteOfs += BytesPerWideShadow) {
- WideAddr = IRB.CreateGEP(WideShadowTy, WideAddr,
- ConstantInt::get(DFS.IntptrTy, 1));
+ ShadowAddr = IRB.CreateGEP(WideShadowTy, ShadowAddr,
+ ConstantInt::get(DFS.IntptrTy, 1));
Value *NextWideShadow =
- IRB.CreateAlignedLoad(WideShadowTy, WideAddr, ShadowAlign);
+ IRB.CreateAlignedLoad(WideShadowTy, ShadowAddr, ShadowAlign);
CombinedWideShadow = IRB.CreateOr(CombinedWideShadow, NextWideShadow);
if (ShouldTrackOrigins) {
Value *NextOrigin = DFS.loadNextOrigin(Pos, OriginAlign, &OriginAddr);
@@ -2526,8 +2536,9 @@ void DFSanFunction::storeOrigin(Instruction *Pos, Value *Addr, uint64_t Size,
ConstantInt::get(DFS.IntptrTy, Size), Origin});
} else {
Value *Cmp = convertToBool(CollapsedShadow, IRB, "_dfscmp");
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
Instruction *CheckTerm = SplitBlockAndInsertIfThen(
- Cmp, &*IRB.GetInsertPoint(), false, DFS.OriginStoreWeights, &DT);
+ Cmp, &*IRB.GetInsertPoint(), false, DFS.OriginStoreWeights, &DTU);
IRBuilder<> IRBNew(CheckTerm);
paintOrigin(IRBNew, updateOrigin(Origin, IRBNew), StoreOriginAddr, Size,
OriginAlignment);
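For context on the loadShadowFast() hunk above (which now indexes the shadow through ShadowAddr directly instead of a bitcast temporary): the fast path loads the shadow for the access in wide words and ORs them together, so a single comparison decides whether any byte of the range carries a label. A plain-C++ sketch of that combining step, assuming the default of one shadow byte per application byte and 64-bit wide shadow words; the names are invented.

#include <cstdint>
#include <cstring>

// Sketch: OR consecutive 64-bit shadow words covering the accessed range.
// A nonzero result means at least one byte of the range is labeled.
uint64_t combineWideShadow(const uint8_t *ShadowAddr, uint64_t SizeInBytes) {
  constexpr uint64_t BytesPerWideShadow = 8;
  uint64_t Combined = 0;
  for (uint64_t ByteOfs = 0; ByteOfs < SizeInBytes;
       ByteOfs += BytesPerWideShadow) {
    uint64_t Word;
    std::memcpy(&Word, ShadowAddr + ByteOfs, sizeof(Word));
    Combined |= Word;
  }
  return Combined;
}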
diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
index 9f3ca8b02fd9..21f0b1a92293 100644
--- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
@@ -13,7 +13,6 @@
//
//===----------------------------------------------------------------------===//
-#include "CFGMST.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
@@ -21,10 +20,10 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
@@ -38,6 +37,7 @@
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Instrumentation.h"
+#include "llvm/Transforms/Instrumentation/CFGMST.h"
#include "llvm/Transforms/Instrumentation/GCOVProfiler.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <algorithm>
@@ -919,15 +919,21 @@ bool GCOVProfiler::emitProfileNotes(
IRBuilder<> Builder(E.Place, E.Place->getFirstInsertionPt());
Value *V = Builder.CreateConstInBoundsGEP2_64(
Counters->getValueType(), Counters, 0, I);
+ // Disable sanitizers to decrease size bloat. We don't expect
+ // sanitizers to catch interesting issues.
+ Instruction *Inst;
if (Options.Atomic) {
- Builder.CreateAtomicRMW(AtomicRMWInst::Add, V, Builder.getInt64(1),
- MaybeAlign(), AtomicOrdering::Monotonic);
+ Inst = Builder.CreateAtomicRMW(AtomicRMWInst::Add, V,
+ Builder.getInt64(1), MaybeAlign(),
+ AtomicOrdering::Monotonic);
} else {
- Value *Count =
+ LoadInst *OldCount =
Builder.CreateLoad(Builder.getInt64Ty(), V, "gcov_ctr");
- Count = Builder.CreateAdd(Count, Builder.getInt64(1));
- Builder.CreateStore(Count, V);
+ OldCount->setNoSanitizeMetadata();
+ Value *NewCount = Builder.CreateAdd(OldCount, Builder.getInt64(1));
+ Inst = Builder.CreateStore(NewCount, V);
}
+ Inst->setNoSanitizeMetadata();
}
}
}
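As a rough illustration of the two counter-update forms the GCOV hunk above emits, expressed as ordinary C++ rather than IR (the IR version additionally attaches !nosanitize metadata to the counter accesses, which is what the new setNoSanitizeMetadata() calls do); the function name is an assumption for this sketch.
#include <atomic>
#include <cstdint>

void bumpGcovCounter(std::atomic<uint64_t> &Counter, bool AtomicUpdate) {
  if (AtomicUpdate) {
    // Corresponds to the monotonic atomicrmw add.
    Counter.fetch_add(1, std::memory_order_relaxed);
  } else {
    // Corresponds to the non-atomic load / add / store sequence.
    uint64_t Old = Counter.load(std::memory_order_relaxed);
    Counter.store(Old + 1, std::memory_order_relaxed);
  }
}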
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
index 34c61f83ad30..28db47a19092 100644
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -1,4 +1,4 @@
-//===- HWAddressSanitizer.cpp - detector of uninitialized reads -------===//
+//===- HWAddressSanitizer.cpp - memory access error detector --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -17,7 +17,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/StackSafetyAnalysis.h"
@@ -50,6 +49,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/MemoryTaggingSupport.h"
@@ -136,14 +136,6 @@ static cl::opt<bool>
cl::desc("detect use after scope within function"),
cl::Hidden, cl::init(false));
-static cl::opt<bool> ClUARRetagToZero(
- "hwasan-uar-retag-to-zero",
- cl::desc("Clear alloca tags before returning from the function to allow "
- "non-instrumented and instrumented function calls mix. When set "
- "to false, allocas are retagged before returning from the "
- "function to detect use after return."),
- cl::Hidden, cl::init(true));
-
static cl::opt<bool> ClGenerateTagsWithCalls(
"hwasan-generate-tags-with-calls",
cl::desc("generate new tags with runtime library calls"), cl::Hidden,
@@ -247,7 +239,9 @@ bool shouldInstrumentStack(const Triple &TargetTriple) {
}
bool shouldInstrumentWithCalls(const Triple &TargetTriple) {
- return ClInstrumentWithCalls || TargetTriple.getArch() == Triple::x86_64;
+ return ClInstrumentWithCalls.getNumOccurrences()
+ ? ClInstrumentWithCalls
+ : TargetTriple.getArch() == Triple::x86_64;
}
bool mightUseStackSafetyAnalysis(bool DisableOptimization) {
@@ -282,7 +276,7 @@ public:
void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; }
- bool sanitizeFunction(Function &F, FunctionAnalysisManager &FAM);
+ void sanitizeFunction(Function &F, FunctionAnalysisManager &FAM);
void initializeModule();
void createHwasanCtorComdat();
@@ -313,16 +307,15 @@ public:
void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size);
Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag);
Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong);
- bool instrumentStack(memtag::StackInfo &Info, Value *StackTag,
+ bool instrumentStack(memtag::StackInfo &Info, Value *StackTag, Value *UARTag,
const DominatorTree &DT, const PostDominatorTree &PDT,
const LoopInfo &LI);
Value *readRegister(IRBuilder<> &IRB, StringRef Name);
bool instrumentLandingPads(SmallVectorImpl<Instruction *> &RetVec);
Value *getNextTagWithCall(IRBuilder<> &IRB);
Value *getStackBaseTag(IRBuilder<> &IRB);
- Value *getAllocaTag(IRBuilder<> &IRB, Value *StackTag, AllocaInst *AI,
- unsigned AllocaNo);
- Value *getUARTag(IRBuilder<> &IRB, Value *StackTag);
+ Value *getAllocaTag(IRBuilder<> &IRB, Value *StackTag, unsigned AllocaNo);
+ Value *getUARTag(IRBuilder<> &IRB);
Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty);
Value *applyTagMask(IRBuilder<> &IRB, Value *OldTag);
@@ -344,8 +337,6 @@ private:
Module &M;
const StackSafetyGlobalInfo *SSI;
Triple TargetTriple;
- FunctionCallee HWAsanMemmove, HWAsanMemcpy, HWAsanMemset;
- FunctionCallee HWAsanHandleVfork;
/// This struct defines the shadow mapping using the rule:
/// shadow = (mem >> Scale) + Offset.
@@ -387,6 +378,7 @@ private:
bool InstrumentStack;
bool DetectUseAfterScope;
bool UsePageAliases;
+ bool UseMatchAllCallback;
std::optional<uint8_t> MatchAllTag;
@@ -398,6 +390,9 @@ private:
FunctionCallee HwasanMemoryAccessCallback[2][kNumberOfAccessSizes];
FunctionCallee HwasanMemoryAccessCallbackSized[2];
+ FunctionCallee HwasanMemmove, HwasanMemcpy, HwasanMemset;
+ FunctionCallee HwasanHandleVfork;
+
FunctionCallee HwasanTagMemoryFunc;
FunctionCallee HwasanGenerateTagFunc;
FunctionCallee HwasanRecordFrameRecordFunc;
@@ -420,12 +415,9 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M,
SSI = &MAM.getResult<StackSafetyGlobalAnalysis>(M);
HWAddressSanitizer HWASan(M, Options.CompileKernel, Options.Recover, SSI);
- bool Modified = false;
auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
for (Function &F : M)
- Modified |= HWASan.sanitizeFunction(F, FAM);
- if (!Modified)
- return PreservedAnalyses::all();
+ HWASan.sanitizeFunction(F, FAM);
PreservedAnalyses PA = PreservedAnalyses::none();
// GlobalsAA is considered stateless and does not get invalidated unless
@@ -438,12 +430,12 @@ void HWAddressSanitizerPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<HWAddressSanitizerPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (Options.CompileKernel)
OS << "kernel;";
if (Options.Recover)
OS << "recover";
- OS << ">";
+ OS << '>';
}
void HWAddressSanitizer::createHwasanCtorComdat() {
@@ -594,6 +586,7 @@ void HWAddressSanitizer::initializeModule() {
} else if (CompileKernel) {
MatchAllTag = 0xFF;
}
+ UseMatchAllCallback = !CompileKernel && MatchAllTag.has_value();
// If we don't have personality function support, fall back to landing pads.
InstrumentLandingPads = ClInstrumentLandingPads.getNumOccurrences()
@@ -631,51 +624,73 @@ void HWAddressSanitizer::initializeModule() {
void HWAddressSanitizer::initializeCallbacks(Module &M) {
IRBuilder<> IRB(*C);
+ const std::string MatchAllStr = UseMatchAllCallback ? "_match_all" : "";
+ FunctionType *HwasanMemoryAccessCallbackSizedFnTy,
+ *HwasanMemoryAccessCallbackFnTy, *HwasanMemTransferFnTy,
+ *HwasanMemsetFnTy;
+ if (UseMatchAllCallback) {
+ HwasanMemoryAccessCallbackSizedFnTy =
+ FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false);
+ HwasanMemoryAccessCallbackFnTy =
+ FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false);
+ HwasanMemTransferFnTy = FunctionType::get(
+ Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy, Int8Ty}, false);
+ HwasanMemsetFnTy = FunctionType::get(
+ Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}, false);
+ } else {
+ HwasanMemoryAccessCallbackSizedFnTy =
+ FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false);
+ HwasanMemoryAccessCallbackFnTy =
+ FunctionType::get(VoidTy, {IntptrTy}, false);
+ HwasanMemTransferFnTy =
+ FunctionType::get(Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy}, false);
+ HwasanMemsetFnTy =
+ FunctionType::get(Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy}, false);
+ }
+
for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) {
const std::string TypeStr = AccessIsWrite ? "store" : "load";
const std::string EndingStr = Recover ? "_noabort" : "";
HwasanMemoryAccessCallbackSized[AccessIsWrite] = M.getOrInsertFunction(
- ClMemoryAccessCallbackPrefix + TypeStr + "N" + EndingStr,
- FunctionType::get(IRB.getVoidTy(), {IntptrTy, IntptrTy}, false));
+ ClMemoryAccessCallbackPrefix + TypeStr + "N" + MatchAllStr + EndingStr,
+ HwasanMemoryAccessCallbackSizedFnTy);
for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes;
AccessSizeIndex++) {
HwasanMemoryAccessCallback[AccessIsWrite][AccessSizeIndex] =
- M.getOrInsertFunction(
- ClMemoryAccessCallbackPrefix + TypeStr +
- itostr(1ULL << AccessSizeIndex) + EndingStr,
- FunctionType::get(IRB.getVoidTy(), {IntptrTy}, false));
+ M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + TypeStr +
+ itostr(1ULL << AccessSizeIndex) +
+ MatchAllStr + EndingStr,
+ HwasanMemoryAccessCallbackFnTy);
}
}
- HwasanTagMemoryFunc = M.getOrInsertFunction(
- "__hwasan_tag_memory", IRB.getVoidTy(), Int8PtrTy, Int8Ty, IntptrTy);
+ const std::string MemIntrinCallbackPrefix =
+ (CompileKernel && !ClKasanMemIntrinCallbackPrefix)
+ ? std::string("")
+ : ClMemoryAccessCallbackPrefix;
+
+ HwasanMemmove = M.getOrInsertFunction(
+ MemIntrinCallbackPrefix + "memmove" + MatchAllStr, HwasanMemTransferFnTy);
+ HwasanMemcpy = M.getOrInsertFunction(
+ MemIntrinCallbackPrefix + "memcpy" + MatchAllStr, HwasanMemTransferFnTy);
+ HwasanMemset = M.getOrInsertFunction(
+ MemIntrinCallbackPrefix + "memset" + MatchAllStr, HwasanMemsetFnTy);
+
+ HwasanTagMemoryFunc = M.getOrInsertFunction("__hwasan_tag_memory", VoidTy,
+ Int8PtrTy, Int8Ty, IntptrTy);
HwasanGenerateTagFunc =
M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty);
- HwasanRecordFrameRecordFunc = M.getOrInsertFunction(
- "__hwasan_add_frame_record", IRB.getVoidTy(), Int64Ty);
+ HwasanRecordFrameRecordFunc =
+ M.getOrInsertFunction("__hwasan_add_frame_record", VoidTy, Int64Ty);
- ShadowGlobal = M.getOrInsertGlobal("__hwasan_shadow",
- ArrayType::get(IRB.getInt8Ty(), 0));
+ ShadowGlobal =
+ M.getOrInsertGlobal("__hwasan_shadow", ArrayType::get(Int8Ty, 0));
- const std::string MemIntrinCallbackPrefix =
- (CompileKernel && !ClKasanMemIntrinCallbackPrefix)
- ? std::string("")
- : ClMemoryAccessCallbackPrefix;
- HWAsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove",
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy);
- HWAsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy",
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy);
- HWAsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset",
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt32Ty(), IntptrTy);
-
- HWAsanHandleVfork =
- M.getOrInsertFunction("__hwasan_handle_vfork", IRB.getVoidTy(), IntptrTy);
+ HwasanHandleVfork =
+ M.getOrInsertFunction("__hwasan_handle_vfork", VoidTy, IntptrTy);
}
Value *HWAddressSanitizer::getOpaqueNoopCast(IRBuilder<> &IRB, Value *Val) {
@@ -788,7 +803,7 @@ static unsigned getPointerOperandIndex(Instruction *I) {
}
static size_t TypeSizeToSizeIndex(uint32_t TypeSize) {
- size_t Res = countTrailingZeros(TypeSize / 8);
+ size_t Res = llvm::countr_zero(TypeSize / 8);
assert(Res < kNumberOfAccessSizes);
return Res;
}
@@ -847,8 +862,8 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
IRBuilder<> IRB(InsertBefore);
Value *PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy);
- Value *PtrTag = IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift),
- IRB.getInt8Ty());
+ Value *PtrTag =
+ IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift), Int8Ty);
Value *AddrLong = untagPointer(IRB, PtrLong);
Value *Shadow = memToShadow(AddrLong, IRB);
Value *MemTag = IRB.CreateLoad(Int8Ty, Shadow);
@@ -897,7 +912,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
case Triple::x86_64:
// The signal handler will find the data address in rdi.
Asm = InlineAsm::get(
- FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false),
+ FunctionType::get(VoidTy, {PtrLong->getType()}, false),
"int3\nnopl " +
itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)) +
"(%rax)",
@@ -908,7 +923,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
case Triple::aarch64_be:
// The signal handler will find the data address in x0.
Asm = InlineAsm::get(
- FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false),
+ FunctionType::get(VoidTy, {PtrLong->getType()}, false),
"brk #" + itostr(0x900 + (AccessInfo & HWASanAccessInfo::RuntimeMask)),
"{x0}",
/*hasSideEffects=*/true);
@@ -916,7 +931,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
case Triple::riscv64:
// The signal handler will find the data address in x10.
Asm = InlineAsm::get(
- FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false),
+ FunctionType::get(VoidTy, {PtrLong->getType()}, false),
"ebreak\naddiw x0, x11, " +
itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)),
"{x10}",
@@ -943,17 +958,35 @@ bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) {
void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) {
IRBuilder<> IRB(MI);
if (isa<MemTransferInst>(MI)) {
- IRB.CreateCall(
- isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+ if (UseMatchAllCallback) {
+ IRB.CreateCall(
+ isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy,
+ {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+ IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false),
+ ConstantInt::get(Int8Ty, *MatchAllTag)});
+ } else {
+ IRB.CreateCall(
+ isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy,
+ {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+ IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+ }
} else if (isa<MemSetInst>(MI)) {
- IRB.CreateCall(
- HWAsanMemset,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+ if (UseMatchAllCallback) {
+ IRB.CreateCall(
+ HwasanMemset,
+ {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+ IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false),
+ ConstantInt::get(Int8Ty, *MatchAllTag)});
+ } else {
+ IRB.CreateCall(
+ HwasanMemset,
+ {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+ IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+ }
}
MI->eraseFromParent();
}
@@ -967,23 +1000,40 @@ bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) {
return false; // FIXME
IRBuilder<> IRB(O.getInsn());
- if (isPowerOf2_64(O.TypeSize) &&
- (O.TypeSize / 8 <= (1ULL << (kNumberOfAccessSizes - 1))) &&
+ if (!O.TypeStoreSize.isScalable() && isPowerOf2_64(O.TypeStoreSize) &&
+ (O.TypeStoreSize / 8 <= (1ULL << (kNumberOfAccessSizes - 1))) &&
(!O.Alignment || *O.Alignment >= Mapping.getObjectAlignment() ||
- *O.Alignment >= O.TypeSize / 8)) {
- size_t AccessSizeIndex = TypeSizeToSizeIndex(O.TypeSize);
+ *O.Alignment >= O.TypeStoreSize / 8)) {
+ size_t AccessSizeIndex = TypeSizeToSizeIndex(O.TypeStoreSize);
if (InstrumentWithCalls) {
- IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex],
- IRB.CreatePointerCast(Addr, IntptrTy));
+ if (UseMatchAllCallback) {
+ IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex],
+ {IRB.CreatePointerCast(Addr, IntptrTy),
+ ConstantInt::get(Int8Ty, *MatchAllTag)});
+ } else {
+ IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex],
+ IRB.CreatePointerCast(Addr, IntptrTy));
+ }
} else if (OutlinedChecks) {
instrumentMemAccessOutline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn());
} else {
instrumentMemAccessInline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn());
}
} else {
- IRB.CreateCall(HwasanMemoryAccessCallbackSized[O.IsWrite],
- {IRB.CreatePointerCast(Addr, IntptrTy),
- ConstantInt::get(IntptrTy, O.TypeSize / 8)});
+ if (UseMatchAllCallback) {
+ IRB.CreateCall(
+ HwasanMemoryAccessCallbackSized[O.IsWrite],
+ {IRB.CreatePointerCast(Addr, IntptrTy),
+ IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize),
+ ConstantInt::get(IntptrTy, 8)),
+ ConstantInt::get(Int8Ty, *MatchAllTag)});
+ } else {
+ IRB.CreateCall(
+ HwasanMemoryAccessCallbackSized[O.IsWrite],
+ {IRB.CreatePointerCast(Addr, IntptrTy),
+ IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize),
+ ConstantInt::get(IntptrTy, 8))});
+ }
}
untagPointerOperand(O.getInsn(), Addr);
@@ -996,14 +1046,15 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag,
if (!UseShortGranules)
Size = AlignedSize;
- Value *JustTag = IRB.CreateTrunc(Tag, IRB.getInt8Ty());
+ Tag = IRB.CreateTrunc(Tag, Int8Ty);
if (InstrumentWithCalls) {
IRB.CreateCall(HwasanTagMemoryFunc,
- {IRB.CreatePointerCast(AI, Int8PtrTy), JustTag,
+ {IRB.CreatePointerCast(AI, Int8PtrTy), Tag,
ConstantInt::get(IntptrTy, AlignedSize)});
} else {
size_t ShadowSize = Size >> Mapping.Scale;
- Value *ShadowPtr = memToShadow(IRB.CreatePointerCast(AI, IntptrTy), IRB);
+ Value *AddrLong = untagPointer(IRB, IRB.CreatePointerCast(AI, IntptrTy));
+ Value *ShadowPtr = memToShadow(AddrLong, IRB);
// If this memset is not inlined, it will be intercepted in the hwasan
// runtime library. That's OK, because the interceptor skips the checks if
// the address is in the shadow region.
@@ -1011,14 +1062,14 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag,
// llvm.memset right here into either a sequence of stores, or a call to
// hwasan_tag_memory.
if (ShadowSize)
- IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, Align(1));
+ IRB.CreateMemSet(ShadowPtr, Tag, ShadowSize, Align(1));
if (Size != AlignedSize) {
const uint8_t SizeRemainder = Size % Mapping.getObjectAlignment().value();
IRB.CreateStore(ConstantInt::get(Int8Ty, SizeRemainder),
IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize));
- IRB.CreateStore(JustTag, IRB.CreateConstGEP1_32(
- Int8Ty, IRB.CreateBitCast(AI, Int8PtrTy),
- AlignedSize - 1));
+ IRB.CreateStore(Tag, IRB.CreateConstGEP1_32(
+ Int8Ty, IRB.CreatePointerCast(AI, Int8PtrTy),
+ AlignedSize - 1));
}
}
}
@@ -1037,21 +1088,18 @@ unsigned HWAddressSanitizer::retagMask(unsigned AllocaNo) {
// mask allocated (temporally) nearby. The program that generated this list
// can be found at:
// https://github.com/google/sanitizers/blob/master/hwaddress-sanitizer/sort_masks.py
- static unsigned FastMasks[] = {0, 128, 64, 192, 32, 96, 224, 112, 240,
- 48, 16, 120, 248, 56, 24, 8, 124, 252,
- 60, 28, 12, 4, 126, 254, 62, 30, 14,
- 6, 2, 127, 63, 31, 15, 7, 3, 1};
+ static const unsigned FastMasks[] = {
+ 0, 128, 64, 192, 32, 96, 224, 112, 240, 48, 16, 120,
+ 248, 56, 24, 8, 124, 252, 60, 28, 12, 4, 126, 254,
+ 62, 30, 14, 6, 2, 127, 63, 31, 15, 7, 3, 1};
return FastMasks[AllocaNo % std::size(FastMasks)];
}
Value *HWAddressSanitizer::applyTagMask(IRBuilder<> &IRB, Value *OldTag) {
- if (TargetTriple.getArch() == Triple::x86_64) {
- Constant *TagMask = ConstantInt::get(IntptrTy, TagMaskByte);
- Value *NewTag = IRB.CreateAnd(OldTag, TagMask);
- return NewTag;
- }
- // aarch64 uses 8-bit tags, so no mask is needed.
- return OldTag;
+ if (TagMaskByte == 0xFF)
+ return OldTag; // No need to clear the tag byte.
+ return IRB.CreateAnd(OldTag,
+ ConstantInt::get(OldTag->getType(), TagMaskByte));
}
Value *HWAddressSanitizer::getNextTagWithCall(IRBuilder<> &IRB) {
@@ -1060,7 +1108,7 @@ Value *HWAddressSanitizer::getNextTagWithCall(IRBuilder<> &IRB) {
Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) {
if (ClGenerateTagsWithCalls)
- return getNextTagWithCall(IRB);
+ return nullptr;
if (StackBaseTag)
return StackBaseTag;
// Extract some entropy from the stack pointer for the tags.
@@ -1075,19 +1123,20 @@ Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) {
}
Value *HWAddressSanitizer::getAllocaTag(IRBuilder<> &IRB, Value *StackTag,
- AllocaInst *AI, unsigned AllocaNo) {
+ unsigned AllocaNo) {
if (ClGenerateTagsWithCalls)
return getNextTagWithCall(IRB);
- return IRB.CreateXor(StackTag,
- ConstantInt::get(IntptrTy, retagMask(AllocaNo)));
+ return IRB.CreateXor(
+ StackTag, ConstantInt::get(StackTag->getType(), retagMask(AllocaNo)));
}
-Value *HWAddressSanitizer::getUARTag(IRBuilder<> &IRB, Value *StackTag) {
- if (ClUARRetagToZero)
- return ConstantInt::get(IntptrTy, 0);
- if (ClGenerateTagsWithCalls)
- return getNextTagWithCall(IRB);
- return IRB.CreateXor(StackTag, ConstantInt::get(IntptrTy, TagMaskByte));
+Value *HWAddressSanitizer::getUARTag(IRBuilder<> &IRB) {
+ Value *StackPointerLong = getSP(IRB);
+ Value *UARTag =
+ applyTagMask(IRB, IRB.CreateLShr(StackPointerLong, PointerTagShift));
+
+ UARTag->setName("hwasan.uar.tag");
+ return UARTag;
}
// Add a tag to an address.
@@ -1117,12 +1166,12 @@ Value *HWAddressSanitizer::untagPointer(IRBuilder<> &IRB, Value *PtrLong) {
// Kernel addresses have 0xFF in the most significant byte.
UntaggedPtrLong =
IRB.CreateOr(PtrLong, ConstantInt::get(PtrLong->getType(),
- 0xFFULL << PointerTagShift));
+ TagMaskByte << PointerTagShift));
} else {
// Userspace addresses have 0x00.
- UntaggedPtrLong =
- IRB.CreateAnd(PtrLong, ConstantInt::get(PtrLong->getType(),
- ~(0xFFULL << PointerTagShift)));
+ UntaggedPtrLong = IRB.CreateAnd(
+ PtrLong, ConstantInt::get(PtrLong->getType(),
+ ~(TagMaskByte << PointerTagShift)));
}
return UntaggedPtrLong;
}
@@ -1135,8 +1184,7 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) {
Function *ThreadPointerFunc =
Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
Value *SlotPtr = IRB.CreatePointerCast(
- IRB.CreateConstGEP1_32(IRB.getInt8Ty(),
- IRB.CreateCall(ThreadPointerFunc), 0x30),
+ IRB.CreateConstGEP1_32(Int8Ty, IRB.CreateCall(ThreadPointerFunc), 0x30),
Ty->getPointerTo(0));
return SlotPtr;
}
@@ -1162,8 +1210,7 @@ Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) {
M, Intrinsic::frameaddress,
IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace()));
CachedSP = IRB.CreatePtrToInt(
- IRB.CreateCall(GetStackPointerFn,
- {Constant::getNullValue(IRB.getInt32Ty())}),
+ IRB.CreateCall(GetStackPointerFn, {Constant::getNullValue(Int32Ty)}),
IntptrTy);
}
return CachedSP;
@@ -1280,7 +1327,7 @@ bool HWAddressSanitizer::instrumentLandingPads(
for (auto *LP : LandingPadVec) {
IRBuilder<> IRB(LP->getNextNode());
IRB.CreateCall(
- HWAsanHandleVfork,
+ HwasanHandleVfork,
{readRegister(IRB, (TargetTriple.getArch() == Triple::x86_64) ? "rsp"
: "sp")});
}
@@ -1293,7 +1340,7 @@ static bool isLifetimeIntrinsic(Value *V) {
}
bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
- Value *StackTag,
+ Value *StackTag, Value *UARTag,
const DominatorTree &DT,
const PostDominatorTree &PDT,
const LoopInfo &LI) {
@@ -1311,9 +1358,10 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
IRBuilder<> IRB(AI->getNextNode());
// Replace uses of the alloca with tagged address.
- Value *Tag = getAllocaTag(IRB, StackTag, AI, N);
+ Value *Tag = getAllocaTag(IRB, StackTag, N);
Value *AILong = IRB.CreatePointerCast(AI, IntptrTy);
- Value *Replacement = tagPointer(IRB, AI->getType(), AILong, Tag);
+ Value *AINoTagLong = untagPointer(IRB, AILong);
+ Value *Replacement = tagPointer(IRB, AI->getType(), AINoTagLong, Tag);
std::string Name =
AI->hasName() ? AI->getName().str() : "alloca." + itostr(N);
Replacement->setName(Name + ".hwasan");
@@ -1340,7 +1388,7 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
llvm::for_each(Info.LifetimeStart, HandleLifetime);
llvm::for_each(Info.LifetimeEnd, HandleLifetime);
- AI->replaceUsesWithIf(Replacement, [AICast, AILong](Use &U) {
+ AI->replaceUsesWithIf(Replacement, [AICast, AILong](const Use &U) {
auto *User = U.getUser();
return User != AILong && User != AICast && !isLifetimeIntrinsic(User);
});
@@ -1359,9 +1407,8 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
auto TagEnd = [&](Instruction *Node) {
IRB.SetInsertPoint(Node);
- Value *UARTag = getUARTag(IRB, StackTag);
// When untagging, use the `AlignedSize` because we need to set the tags
- // for the entire alloca to zero. If we used `Size` here, we would
+ // for the entire alloca to the original tag. If we used `Size` here, we would
// keep the last granule tagged, and store zero in the last byte of the
// last granule, due to how short granules are implemented.
tagAlloca(IRB, AI, UARTag, AlignedSize);
@@ -1402,13 +1449,13 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
return true;
}
-bool HWAddressSanitizer::sanitizeFunction(Function &F,
+void HWAddressSanitizer::sanitizeFunction(Function &F,
FunctionAnalysisManager &FAM) {
if (&F == HwasanCtorFunction)
- return false;
+ return;
if (!F.hasFnAttribute(Attribute::SanitizeHWAddress))
- return false;
+ return;
LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n");
@@ -1436,22 +1483,19 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F,
initializeCallbacks(*F.getParent());
- bool Changed = false;
-
if (!LandingPadVec.empty())
- Changed |= instrumentLandingPads(LandingPadVec);
+ instrumentLandingPads(LandingPadVec);
if (SInfo.AllocasToInstrument.empty() && F.hasPersonalityFn() &&
F.getPersonalityFn()->getName() == kHwasanPersonalityThunkName) {
// __hwasan_personality_thunk is a no-op for functions without an
// instrumented stack, so we can drop it.
F.setPersonalityFn(nullptr);
- Changed = true;
}
if (SInfo.AllocasToInstrument.empty() && OperandsToInstrument.empty() &&
IntrinToInstrument.empty())
- return Changed;
+ return;
assert(!ShadowBase);
@@ -1466,9 +1510,9 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F,
const DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
const PostDominatorTree &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
const LoopInfo &LI = FAM.getResult<LoopAnalysis>(F);
- Value *StackTag =
- ClGenerateTagsWithCalls ? nullptr : getStackBaseTag(EntryIRB);
- instrumentStack(SInfo, StackTag, DT, PDT, LI);
+ Value *StackTag = getStackBaseTag(EntryIRB);
+ Value *UARTag = getUARTag(EntryIRB);
+ instrumentStack(SInfo, StackTag, UARTag, DT, PDT, LI);
}
// If we split the entry block, move any allocas that were originally in the
@@ -1495,8 +1539,6 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F,
ShadowBase = nullptr;
StackBaseTag = nullptr;
CachedSP = nullptr;
-
- return true;
}
void HWAddressSanitizer::instrumentGlobal(GlobalVariable *GV, uint8_t Tag) {
@@ -1605,11 +1647,14 @@ void HWAddressSanitizer::instrumentGlobals() {
Hasher.final(Hash);
uint8_t Tag = Hash[0];
+ assert(TagMaskByte >= 16);
+
for (GlobalVariable *GV : Globals) {
- Tag &= TagMaskByte;
- // Skip tag 0 in order to avoid collisions with untagged memory.
- if (Tag == 0)
- Tag = 1;
+ // Don't allow globals to be tagged with something that looks like a
+ // short-granule tag, otherwise we lose inter-granule overflow detection, as
+ // the fast path shadow-vs-address check succeeds.
+ if (Tag < 16 || Tag > TagMaskByte)
+ Tag = 16;
instrumentGlobal(GV, Tag++);
}
}
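For context on the HWAddressSanitizer changes above, a minimal sketch of the tag arithmetic, assuming an AArch64-style layout with the tag stored in the top byte of the pointer (PointerTagShift == 56); the constants and names are illustrative, since the real pass derives them from the target configuration.
#include <cstdint>

constexpr unsigned kPointerTagShift = 56; // assumed AArch64-style layout

uint64_t extractTag(uint64_t PtrLong) {
  return PtrLong >> kPointerTagShift; // 8-bit tag, zero-extended
}

uint64_t applyTagMask(uint64_t Tag, uint64_t TagMaskByte) {
  // Mirrors the rewritten applyTagMask: a full 0xFF mask needs no masking.
  return TagMaskByte == 0xFF ? Tag : (Tag & TagMaskByte);
}

// Roughly what the new getUARTag computes: instead of retagging to zero or
// calling the runtime, reuse the tag bits of the stack pointer, masked for
// targets that reserve part of the tag byte.
uint64_t computeUARTag(uint64_t StackPointerLong, uint64_t TagMaskByte) {
  return applyTagMask(StackPointerLong >> kPointerTagShift, TagMaskByte);
}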
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index b66e761d53b0..5c9799235017 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -104,25 +104,24 @@ static cl::opt<bool>
namespace {
-// The class for main data structure to promote indirect calls to conditional
-// direct calls.
-class ICallPromotionFunc {
+// Promote indirect calls to conditional direct calls, keeping track of
+// thresholds.
+class IndirectCallPromoter {
private:
Function &F;
- Module *M;
// Symtab that maps indirect call profile values to function names and
// defines.
- InstrProfSymtab *Symtab;
+ InstrProfSymtab *const Symtab;
- bool SamplePGO;
+ const bool SamplePGO;
OptimizationRemarkEmitter &ORE;
// A struct that records the direct target and its call count.
struct PromotionCandidate {
- Function *TargetFunction;
- uint64_t Count;
+ Function *const TargetFunction;
+ const uint64_t Count;
PromotionCandidate(Function *F, uint64_t C) : TargetFunction(F), Count(C) {}
};
@@ -143,11 +142,11 @@ private:
uint64_t &TotalCount);
public:
- ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab,
- bool SamplePGO, OptimizationRemarkEmitter &ORE)
- : F(Func), M(Modu), Symtab(Symtab), SamplePGO(SamplePGO), ORE(ORE) {}
- ICallPromotionFunc(const ICallPromotionFunc &) = delete;
- ICallPromotionFunc &operator=(const ICallPromotionFunc &) = delete;
+ IndirectCallPromoter(Function &Func, InstrProfSymtab *Symtab, bool SamplePGO,
+ OptimizationRemarkEmitter &ORE)
+ : F(Func), Symtab(Symtab), SamplePGO(SamplePGO), ORE(ORE) {}
+ IndirectCallPromoter(const IndirectCallPromoter &) = delete;
+ IndirectCallPromoter &operator=(const IndirectCallPromoter &) = delete;
bool processFunction(ProfileSummaryInfo *PSI);
};
@@ -156,8 +155,8 @@ public:
// Indirect-call promotion heuristic. The direct targets are sorted based on
// the count. Stop at the first target that is not promoted.
-std::vector<ICallPromotionFunc::PromotionCandidate>
-ICallPromotionFunc::getPromotionCandidatesForCallSite(
+std::vector<IndirectCallPromoter::PromotionCandidate>
+IndirectCallPromoter::getPromotionCandidatesForCallSite(
const CallBase &CB, const ArrayRef<InstrProfValueData> &ValueDataRef,
uint64_t TotalCount, uint32_t NumCandidates) {
std::vector<PromotionCandidate> Ret;
@@ -276,7 +275,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
}
// Promote indirect-call to conditional direct-call for one callsite.
-uint32_t ICallPromotionFunc::tryToPromote(
+uint32_t IndirectCallPromoter::tryToPromote(
CallBase &CB, const std::vector<PromotionCandidate> &Candidates,
uint64_t &TotalCount) {
uint32_t NumPromoted = 0;
@@ -295,7 +294,7 @@ uint32_t ICallPromotionFunc::tryToPromote(
// Traverse all the indirect-call callsites and get the value profile
// annotation to perform indirect-call promotion.
-bool ICallPromotionFunc::processFunction(ProfileSummaryInfo *PSI) {
+bool IndirectCallPromoter::processFunction(ProfileSummaryInfo *PSI) {
bool Changed = false;
ICallPromotionAnalysis ICallAnalysis;
for (auto *CB : findIndirectCalls(F)) {
@@ -319,16 +318,15 @@ bool ICallPromotionFunc::processFunction(ProfileSummaryInfo *PSI) {
if (TotalCount == 0 || NumPromoted == NumVals)
continue;
// Otherwise, write the un-promoted records back as the value profile annotation.
- annotateValueSite(*M, *CB, ICallProfDataRef.slice(NumPromoted), TotalCount,
- IPVK_IndirectCallTarget, NumCandidates);
+ annotateValueSite(*F.getParent(), *CB, ICallProfDataRef.slice(NumPromoted),
+ TotalCount, IPVK_IndirectCallTarget, NumCandidates);
}
return Changed;
}
// A wrapper function that does the actual work.
-static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI,
- bool InLTO, bool SamplePGO,
- ModuleAnalysisManager *AM = nullptr) {
+static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, bool InLTO,
+ bool SamplePGO, ModuleAnalysisManager &MAM) {
if (DisableICP)
return false;
InstrProfSymtab Symtab;
@@ -342,19 +340,12 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI,
if (F.isDeclaration() || F.hasOptNone())
continue;
- std::unique_ptr<OptimizationRemarkEmitter> OwnedORE;
- OptimizationRemarkEmitter *ORE;
- if (AM) {
- auto &FAM =
- AM->getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
- ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
- } else {
- OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F);
- ORE = OwnedORE.get();
- }
+ auto &FAM =
+ MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
- ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO, *ORE);
- bool FuncChanged = ICallPromotion.processFunction(PSI);
+ IndirectCallPromoter CallPromoter(F, &Symtab, SamplePGO, ORE);
+ bool FuncChanged = CallPromoter.processFunction(PSI);
if (ICPDUMPAFTER && FuncChanged) {
LLVM_DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs()));
LLVM_DEBUG(dbgs() << "\n");
@@ -369,11 +360,11 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI,
}
PreservedAnalyses PGOIndirectCallPromotion::run(Module &M,
- ModuleAnalysisManager &AM) {
- ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
+ ModuleAnalysisManager &MAM) {
+ ProfileSummaryInfo *PSI = &MAM.getResult<ProfileSummaryAnalysis>(M);
if (!promoteIndirectCalls(M, PSI, InLTO | ICPLTOMode,
- SamplePGO | ICPSamplePGOMode, &AM))
+ SamplePGO | ICPSamplePGOMode, MAM))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
diff --git a/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp b/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp
index d7561c193aa3..6882dd83f429 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp
@@ -15,9 +15,6 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/PassRegistry.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index c0409206216e..a7b1953ce81c 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -16,7 +16,6 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
@@ -47,6 +46,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include <algorithm>
@@ -421,6 +421,9 @@ bool InstrProfiling::lowerIntrinsics(Function *F) {
} else if (auto *IPI = dyn_cast<InstrProfIncrementInst>(&Instr)) {
lowerIncrement(IPI);
MadeChange = true;
+ } else if (auto *IPC = dyn_cast<InstrProfTimestampInst>(&Instr)) {
+ lowerTimestamp(IPC);
+ MadeChange = true;
} else if (auto *IPC = dyn_cast<InstrProfCoverInst>(&Instr)) {
lowerCover(IPC);
MadeChange = true;
@@ -510,6 +513,7 @@ static bool containsProfilingIntrinsics(Module &M) {
return containsIntrinsic(llvm::Intrinsic::instrprof_cover) ||
containsIntrinsic(llvm::Intrinsic::instrprof_increment) ||
containsIntrinsic(llvm::Intrinsic::instrprof_increment_step) ||
+ containsIntrinsic(llvm::Intrinsic::instrprof_timestamp) ||
containsIntrinsic(llvm::Intrinsic::instrprof_value_profile);
}
@@ -540,18 +544,19 @@ bool InstrProfiling::run(
// the instrumented function. This is counting the number of instrumented
// target value sites to enter it as a field in the profile data variable.
for (Function &F : M) {
- InstrProfIncrementInst *FirstProfIncInst = nullptr;
+ InstrProfInstBase *FirstProfInst = nullptr;
for (BasicBlock &BB : F)
for (auto I = BB.begin(), E = BB.end(); I != E; I++)
if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(I))
computeNumValueSiteCounts(Ind);
- else if (FirstProfIncInst == nullptr)
- FirstProfIncInst = dyn_cast<InstrProfIncrementInst>(I);
+ else if (FirstProfInst == nullptr &&
+ (isa<InstrProfIncrementInst>(I) || isa<InstrProfCoverInst>(I)))
+ FirstProfInst = dyn_cast<InstrProfInstBase>(I);
// Value profiling intrinsic lowering requires per-function profile data
// variable to be created first.
- if (FirstProfIncInst != nullptr)
- static_cast<void>(getOrCreateRegionCounters(FirstProfIncInst));
+ if (FirstProfInst != nullptr)
+ static_cast<void>(getOrCreateRegionCounters(FirstProfInst));
}
for (Function &F : M)
@@ -669,6 +674,9 @@ Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) {
auto *Counters = getOrCreateRegionCounters(I);
IRBuilder<> Builder(I);
+ if (isa<InstrProfTimestampInst>(I))
+ Counters->setAlignment(Align(8));
+
auto *Addr = Builder.CreateConstInBoundsGEP2_32(
Counters->getValueType(), Counters, 0, I->getIndex()->getZExtValue());
@@ -710,6 +718,21 @@ void InstrProfiling::lowerCover(InstrProfCoverInst *CoverInstruction) {
CoverInstruction->eraseFromParent();
}
+void InstrProfiling::lowerTimestamp(
+ InstrProfTimestampInst *TimestampInstruction) {
+ assert(TimestampInstruction->getIndex()->isZeroValue() &&
+ "timestamp probes are always the first probe for a function");
+ auto &Ctx = M->getContext();
+ auto *TimestampAddr = getCounterAddress(TimestampInstruction);
+ IRBuilder<> Builder(TimestampInstruction);
+ auto *CalleeTy =
+ FunctionType::get(Type::getVoidTy(Ctx), TimestampAddr->getType(), false);
+ auto Callee = M->getOrInsertFunction(
+ INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_SET_TIMESTAMP), CalleeTy);
+ Builder.CreateCall(Callee, {TimestampAddr});
+ TimestampInstruction->eraseFromParent();
+}
+
void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) {
auto *Addr = getCounterAddress(Inc);
@@ -823,6 +846,72 @@ static inline bool shouldRecordFunctionAddr(Function *F) {
return F->hasAddressTaken() || F->hasLinkOnceLinkage();
}
+static inline bool shouldUsePublicSymbol(Function *Fn) {
+ // It isn't legal to make an alias of this function at all
+ if (Fn->isDeclarationForLinker())
+ return true;
+
+ // Symbols with local linkage can just use the symbol directly without
+ // introducing relocations
+ if (Fn->hasLocalLinkage())
+ return true;
+
+ // PGO + ThinLTO + CFI cause duplicate symbols to be introduced due to some
+ // unfavorable interaction between the new alias and the alias renaming done
+ // in LowerTypeTests under ThinLTO. This affects comdat functions that would
+ // normally be deduplicated: the renaming scheme creates unique names for each
+ // alias, resulting in duplicated symbols. In the future, we should update the
+ // CFI-related passes to migrate these aliases to the same module in which the
+ // jump-table they refer to will be defined.
+ if (Fn->hasMetadata(LLVMContext::MD_type))
+ return true;
+
+ // For comdat functions, an alias would need the same linkage as the original
+ // function and hidden visibility. There is no point in adding an alias with
+ // identical linkage and visibility to avoid introducing symbolic relocations.
+ if (Fn->hasComdat() &&
+ (Fn->getVisibility() == GlobalValue::VisibilityTypes::HiddenVisibility))
+ return true;
+
+ // It's OK to use an alias.
+ return false;
+}
+
+static inline Constant *getFuncAddrForProfData(Function *Fn) {
+ auto *Int8PtrTy = Type::getInt8PtrTy(Fn->getContext());
+ // Store a nullptr in __llvm_profd, if we shouldn't use a real address
+ if (!shouldRecordFunctionAddr(Fn))
+ return ConstantPointerNull::get(Int8PtrTy);
+
+ // If we can't use an alias, we must use the public symbol, even though this
+ // may require a symbolic relocation.
+ if (shouldUsePublicSymbol(Fn))
+ return ConstantExpr::getBitCast(Fn, Int8PtrTy);
+
+ // When possible use a private alias to avoid symbolic relocations.
+ auto *GA = GlobalAlias::create(GlobalValue::LinkageTypes::PrivateLinkage,
+ Fn->getName() + ".local", Fn);
+
+ // When the instrumented function is a COMDAT function, we cannot use a
+ // private alias. If we did, we would create a reference to a local label in
+ // this function's section. If this version of the function isn't selected by
+ // the linker, then the metadata would introduce a reference to a discarded
+ // section. So, for COMDAT functions, we need to adjust the linkage of the
+ // alias. Using hidden visibility avoids a dynamic relocation and an entry in
+ // the dynamic symbol table.
+ //
+ // Note that this handles COMDAT functions with visibility other than Hidden,
+ // since that case is covered in shouldUsePublicSymbol()
+ if (Fn->hasComdat()) {
+ GA->setLinkage(Fn->getLinkage());
+ GA->setVisibility(GlobalValue::VisibilityTypes::HiddenVisibility);
+ }
+
+ // appendToCompilerUsed(*Fn->getParent(), {GA});
+
+ return ConstantExpr::getBitCast(GA, Int8PtrTy);
+}
+
static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) {
// Don't do this for Darwin. compiler-rt uses linker magic.
if (TT.isOSDarwin())
@@ -1014,9 +1103,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
};
auto *DataTy = StructType::get(Ctx, ArrayRef(DataTypes));
- Constant *FunctionAddr = shouldRecordFunctionAddr(Fn)
- ? ConstantExpr::getBitCast(Fn, Int8PtrTy)
- : ConstantPointerNull::get(Int8PtrTy);
+ Constant *FunctionAddr = getFuncAddrForProfData(Fn);
Constant *Int16ArrayVals[IPVK_Last + 1];
for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
@@ -1116,6 +1203,7 @@ void InstrProfiling::emitVNodes() {
Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName());
VNodesVar->setSection(
getInstrProfSectionName(IPSK_vnodes, TT.getObjectFormat()));
+ VNodesVar->setAlignment(M->getDataLayout().getABITypeAlign(VNodesTy));
// VNodesVar is used by runtime but not referenced via relocation by other
// sections. Conservatively make it linker retained.
UsedVars.push_back(VNodesVar);
diff --git a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
index ab72650ae801..806afc8fcdf7 100644
--- a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
@@ -12,12 +12,9 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Instrumentation.h"
-#include "llvm-c/Initialization.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/PassRegistry.h"
+#include "llvm/TargetParser/Triple.h"
using namespace llvm;
diff --git a/llvm/lib/Transforms/Instrumentation/KCFI.cpp b/llvm/lib/Transforms/Instrumentation/KCFI.cpp
index 7978c766f0f0..b1a26880c701 100644
--- a/llvm/lib/Transforms/Instrumentation/KCFI.cpp
+++ b/llvm/lib/Transforms/Instrumentation/KCFI.cpp
@@ -24,10 +24,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
-#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
@@ -76,6 +73,7 @@ PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) {
IntegerType *Int32Ty = Type::getInt32Ty(Ctx);
MDNode *VeryUnlikelyWeights =
MDBuilder(Ctx).createBranchWeights(1, (1U << 20) - 1);
+ Triple T(M.getTargetTriple());
for (CallInst *CI : KCFICalls) {
// Get the expected hash value.
@@ -96,14 +94,24 @@ PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) {
// Emit a check and trap if the target hash doesn't match.
IRBuilder<> Builder(Call);
- Value *HashPtr = Builder.CreateConstInBoundsGEP1_32(
- Int32Ty, Call->getCalledOperand(), -1);
+ Value *FuncPtr = Call->getCalledOperand();
+ // ARM uses the least significant bit of the function pointer to select
+ // between ARM and Thumb modes for the callee. Instructions are always
+ // at least 16-bit aligned, so clear the LSB before we compute the hash
+ // location.
+ if (T.isARM() || T.isThumb()) {
+ FuncPtr = Builder.CreateIntToPtr(
+ Builder.CreateAnd(Builder.CreatePtrToInt(FuncPtr, Int32Ty),
+ ConstantInt::get(Int32Ty, -2)),
+ FuncPtr->getType());
+ }
+ Value *HashPtr = Builder.CreateConstInBoundsGEP1_32(Int32Ty, FuncPtr, -1);
Value *Test = Builder.CreateICmpNE(Builder.CreateLoad(Int32Ty, HashPtr),
ConstantInt::get(Int32Ty, ExpectedHash));
Instruction *ThenTerm =
SplitBlockAndInsertIfThen(Test, Call, false, VeryUnlikelyWeights);
Builder.SetInsertPoint(ThenTerm);
- Builder.CreateCall(Intrinsic::getDeclaration(&M, Intrinsic::trap));
+ Builder.CreateCall(Intrinsic::getDeclaration(&M, Intrinsic::debugtrap));
++NumKCFIChecks;
}
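For context on the KCFI hunk above, a sketch in plain C++ of the check the pass emits before an indirect call, including the new ARM/Thumb handling; kcfiCheck is an assumed stand-in for the inline IR sequence, not a real runtime function.
#include <cstdint>
#include <cstring>

bool kcfiCheck(uintptr_t FuncPtr, uint32_t ExpectedHash, bool IsArmOrThumb) {
  // On ARM/Thumb the low bit of a function pointer selects the instruction
  // set, so it is cleared before locating the type hash that the compiler
  // places immediately before the function entry.
  if (IsArmOrThumb)
    FuncPtr &= ~static_cast<uintptr_t>(1);
  uint32_t StoredHash;
  std::memcpy(&StoredHash,
              reinterpret_cast<const void *>(FuncPtr - sizeof(uint32_t)),
              sizeof(StoredHash));
  return StoredHash == ExpectedHash; // a mismatch now raises a debugtrap
}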
diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
index 2a1601fab45f..789ed005d03d 100644
--- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
@@ -18,10 +18,12 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Triple.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/MemoryProfileInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
@@ -30,18 +32,30 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/ProfileData/InstrProfReader.h"
+#include "llvm/Support/BLAKE3.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/HashBuilder.h"
+#include "llvm/Support/VirtualFileSystem.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <map>
+#include <set>
using namespace llvm;
+using namespace llvm::memprof;
#define DEBUG_TYPE "memprof"
+namespace llvm {
+extern cl::opt<bool> PGOWarnMissing;
+extern cl::opt<bool> NoPGOWarnMismatch;
+extern cl::opt<bool> NoPGOWarnMismatchComdatWeak;
+} // namespace llvm
+
constexpr int LLVM_MEM_PROFILER_VERSION = 1;
// Size of memory mapped to a single shadow location.
@@ -130,6 +144,7 @@ STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
STATISTIC(NumSkippedStackReads, "Number of non-instrumented stack reads");
STATISTIC(NumSkippedStackWrites, "Number of non-instrumented stack writes");
+STATISTIC(NumOfMemProfMissing, "Number of functions without memory profile.");
namespace {
@@ -603,3 +618,297 @@ bool MemProfiler::instrumentFunction(Function &F) {
return FunctionModified;
}
+
+static void addCallsiteMetadata(Instruction &I,
+ std::vector<uint64_t> &InlinedCallStack,
+ LLVMContext &Ctx) {
+ I.setMetadata(LLVMContext::MD_callsite,
+ buildCallstackMetadata(InlinedCallStack, Ctx));
+}
+
+static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset,
+ uint32_t Column) {
+ llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::support::endianness::little>
+ HashBuilder;
+ HashBuilder.add(Function, LineOffset, Column);
+ llvm::BLAKE3Result<8> Hash = HashBuilder.final();
+ uint64_t Id;
+ std::memcpy(&Id, Hash.data(), sizeof(Hash));
+ return Id;
+}
+
+static uint64_t computeStackId(const memprof::Frame &Frame) {
+ return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column);
+}
+
+static void addCallStack(CallStackTrie &AllocTrie,
+ const AllocationInfo *AllocInfo) {
+ SmallVector<uint64_t> StackIds;
+ for (const auto &StackFrame : AllocInfo->CallStack)
+ StackIds.push_back(computeStackId(StackFrame));
+ auto AllocType = getAllocType(AllocInfo->Info.getTotalLifetimeAccessDensity(),
+ AllocInfo->Info.getAllocCount(),
+ AllocInfo->Info.getTotalLifetime());
+ AllocTrie.addCallStack(AllocType, StackIds);
+}
+
+// Helper to compare the InlinedCallStack computed from an instruction's debug
+// info to a list of Frames from profile data (either the allocation data or a
+// callsite). For callsites, the StartIndex to use in the Frame array may be
+// non-zero.
+static bool
+stackFrameIncludesInlinedCallStack(ArrayRef<Frame> ProfileCallStack,
+ ArrayRef<uint64_t> InlinedCallStack,
+ unsigned StartIndex = 0) {
+ auto StackFrame = ProfileCallStack.begin() + StartIndex;
+ auto InlCallStackIter = InlinedCallStack.begin();
+ for (; StackFrame != ProfileCallStack.end() &&
+ InlCallStackIter != InlinedCallStack.end();
+ ++StackFrame, ++InlCallStackIter) {
+ uint64_t StackId = computeStackId(*StackFrame);
+ if (StackId != *InlCallStackIter)
+ return false;
+ }
+ // Return true if we found and matched all stack ids from the call
+ // instruction.
+ return InlCallStackIter == InlinedCallStack.end();
+}
+
+static void readMemprof(Module &M, Function &F,
+ IndexedInstrProfReader *MemProfReader,
+ const TargetLibraryInfo &TLI) {
+ auto &Ctx = M.getContext();
+
+ auto FuncName = getPGOFuncName(F);
+ auto FuncGUID = Function::getGUID(FuncName);
+ Expected<memprof::MemProfRecord> MemProfResult =
+ MemProfReader->getMemProfRecord(FuncGUID);
+ if (Error E = MemProfResult.takeError()) {
+ handleAllErrors(std::move(E), [&](const InstrProfError &IPE) {
+ auto Err = IPE.get();
+ bool SkipWarning = false;
+ LLVM_DEBUG(dbgs() << "Error in reading profile for Func " << FuncName
+ << ": ");
+ if (Err == instrprof_error::unknown_function) {
+ NumOfMemProfMissing++;
+ SkipWarning = !PGOWarnMissing;
+ LLVM_DEBUG(dbgs() << "unknown function");
+ } else if (Err == instrprof_error::hash_mismatch) {
+ SkipWarning =
+ NoPGOWarnMismatch ||
+ (NoPGOWarnMismatchComdatWeak &&
+ (F.hasComdat() ||
+ F.getLinkage() == GlobalValue::AvailableExternallyLinkage));
+ LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")");
+ }
+
+ if (SkipWarning)
+ return;
+
+ std::string Msg = (IPE.message() + Twine(" ") + F.getName().str() +
+ Twine(" Hash = ") + std::to_string(FuncGUID))
+ .str();
+
+ Ctx.diagnose(
+ DiagnosticInfoPGOProfile(M.getName().data(), Msg, DS_Warning));
+ });
+ return;
+ }
+
+ // Build maps of the location hash to all profile data with that leaf location
+ // (allocation info and the callsites).
+ std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo;
+ // For the callsites we need to record the index of the associated frame in
+ // the frame array (see comments below where the map entries are added).
+ std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>>
+ LocHashToCallSites;
+ const auto MemProfRec = std::move(MemProfResult.get());
+ for (auto &AI : MemProfRec.AllocSites) {
+ // Associate the allocation info with the leaf frame. The later matching
+ // code will match any inlined call sequences in the IR with a longer prefix
+ // of call stack frames.
+ uint64_t StackId = computeStackId(AI.CallStack[0]);
+ LocHashToAllocInfo[StackId].insert(&AI);
+ }
+ for (auto &CS : MemProfRec.CallSites) {
+ // Need to record all frames from leaf up to and including this function,
+ // as any of these may or may not have been inlined at this point.
+ unsigned Idx = 0;
+ for (auto &StackFrame : CS) {
+ uint64_t StackId = computeStackId(StackFrame);
+ LocHashToCallSites[StackId].insert(std::make_pair(&CS, Idx++));
+ // Once we find this function, we can stop recording.
+ if (StackFrame.Function == FuncGUID)
+ break;
+ }
+ assert(Idx <= CS.size() && CS[Idx - 1].Function == FuncGUID);
+ }
+
+ auto GetOffset = [](const DILocation *DIL) {
+ return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) &
+ 0xffff;
+ };
+
+ // Now walk the instructions, looking up the associated profile data using
+ // debug locations.
+ for (auto &BB : F) {
+ for (auto &I : BB) {
+ if (I.isDebugOrPseudoInst())
+ continue;
+ // We are only interested in calls (allocation or interior call stack
+ // context calls).
+ auto *CI = dyn_cast<CallBase>(&I);
+ if (!CI)
+ continue;
+ auto *CalledFunction = CI->getCalledFunction();
+ if (CalledFunction && CalledFunction->isIntrinsic())
+ continue;
+ // List of call stack ids computed from the location hashes on debug
+ // locations (leaf to inlined at root).
+ std::vector<uint64_t> InlinedCallStack;
+ // Was the leaf location found in one of the profile maps?
+ bool LeafFound = false;
+ // If leaf was found in a map, iterators pointing to its location in both
+ // of the maps. It might exist in neither, one, or both (the latter case
+ // can happen because we don't currently have discriminators to
+ // distinguish the case when a single line/col maps to both an allocation
+ // and another callsite).
+ std::map<uint64_t, std::set<const AllocationInfo *>>::iterator
+ AllocInfoIter;
+ std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *,
+ unsigned>>>::iterator CallSitesIter;
+ for (const DILocation *DIL = I.getDebugLoc(); DIL != nullptr;
+ DIL = DIL->getInlinedAt()) {
+ // Use C++ linkage name if possible. Need to compile with
+ // -fdebug-info-for-profiling to get linkage name.
+ StringRef Name = DIL->getScope()->getSubprogram()->getLinkageName();
+ if (Name.empty())
+ Name = DIL->getScope()->getSubprogram()->getName();
+ auto CalleeGUID = Function::getGUID(Name);
+ auto StackId =
+ computeStackId(CalleeGUID, GetOffset(DIL), DIL->getColumn());
+ // LeafFound will only be false on the first iteration, since we either
+ // set it true or break out of the loop below.
+ if (!LeafFound) {
+ AllocInfoIter = LocHashToAllocInfo.find(StackId);
+ CallSitesIter = LocHashToCallSites.find(StackId);
+ // Check if the leaf is in one of the maps. If not, no need to look
+ // further at this call.
+ if (AllocInfoIter == LocHashToAllocInfo.end() &&
+ CallSitesIter == LocHashToCallSites.end())
+ break;
+ LeafFound = true;
+ }
+ InlinedCallStack.push_back(StackId);
+ }
+ // If leaf not in either of the maps, skip inst.
+ if (!LeafFound)
+ continue;
+
+ // First add !memprof metadata from allocation info, if we found the
+ // instruction's leaf location in that map, and if the rest of the
+ // instruction's locations match the prefix Frame locations on an
+ // allocation context with the same leaf.
+ if (AllocInfoIter != LocHashToAllocInfo.end()) {
+ // Only consider allocations via new, to reduce unnecessary metadata,
+ // since those are the only allocations that will be targeted initially.
+ if (!isNewLikeFn(CI, &TLI))
+ continue;
+ // We may match this instruction's location list to multiple MIB
+ // contexts. Add them to a Trie specialized for trimming the contexts to
+ // the minimal needed to disambiguate contexts with unique behavior.
+ CallStackTrie AllocTrie;
+ for (auto *AllocInfo : AllocInfoIter->second) {
+ // Check the full inlined call stack against this one.
+ // If we found and thus matched all frames on the call, include
+ // this MIB.
+ if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack,
+ InlinedCallStack))
+ addCallStack(AllocTrie, AllocInfo);
+ }
+ // We might not have matched any to the full inlined call stack.
+ // But if we did, create and attach metadata, or a function attribute if
+ // all contexts have identical profiled behavior.
+ if (!AllocTrie.empty()) {
+ // MemprofMDAttached will be false if a function attribute was
+ // attached.
+ bool MemprofMDAttached = AllocTrie.buildAndAttachMIBMetadata(CI);
+ assert(MemprofMDAttached == I.hasMetadata(LLVMContext::MD_memprof));
+ if (MemprofMDAttached) {
+ // Add callsite metadata for the instruction's location list so that
+ // it is simpler later on to identify which part of the MIB contexts
+ // are from this particular instruction (including during inlining,
+ // when the callsite metadata will be updated appropriately).
+ // FIXME: can this be changed to strip out the matching stack
+ // context ids from the MIB contexts and not add any callsite
+ // metadata here to save space?
+ addCallsiteMetadata(I, InlinedCallStack, Ctx);
+ }
+ }
+ continue;
+ }
+
+ // Otherwise, add callsite metadata. If we reach here then we found the
+ // instruction's leaf location in the callsites map and not the allocation
+ // map.
+ assert(CallSitesIter != LocHashToCallSites.end());
+ for (auto CallStackIdx : CallSitesIter->second) {
+ // If we found and thus matched all frames on the call, create and
+ // attach call stack metadata.
+ if (stackFrameIncludesInlinedCallStack(
+ *CallStackIdx.first, InlinedCallStack, CallStackIdx.second)) {
+ addCallsiteMetadata(I, InlinedCallStack, Ctx);
+ // Only need to find one with a matching call stack and add a single
+ // callsite metadata.
+ break;
+ }
+ }
+ }
+ }
+}
+
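As a rough aside on the matching loop above: stackFrameIncludesInlinedCallStack (defined earlier in this file) essentially checks that the instruction's inlined call stack, built leaf-first from the DILocation chain, appears in order at the start of a profiled context, optionally after skipping a few leading frames. A minimal standalone sketch of that idea, where ProfiledFrame and contextMatchesInlinedStack are illustrative stand-ins rather than the LLVM types:

#include <cstdint>
#include <vector>

// Stand-in for a profiled memprof frame; the real frame also carries the
// callee GUID, line offset, column and an is-inline flag, from which the
// stack id is computed.
struct ProfiledFrame { uint64_t StackId; };

// True if the instruction's inlined call stack (leaf first) matches the
// profiled context starting at StartIndex.
static bool contextMatchesInlinedStack(
    const std::vector<ProfiledFrame> &Context,
    const std::vector<uint64_t> &InlinedCallStack, size_t StartIndex = 0) {
  if (Context.size() < StartIndex + InlinedCallStack.size())
    return false;
  for (size_t I = 0, E = InlinedCallStack.size(); I != E; ++I)
    if (Context[StartIndex + I].StackId != InlinedCallStack[I])
      return false;
  return true;
}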
+MemProfUsePass::MemProfUsePass(std::string MemoryProfileFile,
+ IntrusiveRefCntPtr<vfs::FileSystem> FS)
+ : MemoryProfileFileName(MemoryProfileFile), FS(FS) {
+ if (!FS)
+ this->FS = vfs::getRealFileSystem();
+}
+
+PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) {
+ LLVM_DEBUG(dbgs() << "Read in memory profile:");
+ auto &Ctx = M.getContext();
+ auto ReaderOrErr = IndexedInstrProfReader::create(MemoryProfileFileName, *FS);
+ if (Error E = ReaderOrErr.takeError()) {
+ handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) {
+ Ctx.diagnose(
+ DiagnosticInfoPGOProfile(MemoryProfileFileName.data(), EI.message()));
+ });
+ return PreservedAnalyses::all();
+ }
+
+ std::unique_ptr<IndexedInstrProfReader> MemProfReader =
+ std::move(ReaderOrErr.get());
+ if (!MemProfReader) {
+ Ctx.diagnose(DiagnosticInfoPGOProfile(
+ MemoryProfileFileName.data(), StringRef("Cannot get MemProfReader")));
+ return PreservedAnalyses::all();
+ }
+
+ if (!MemProfReader->hasMemoryProfile()) {
+ Ctx.diagnose(DiagnosticInfoPGOProfile(MemoryProfileFileName.data(),
+ "Not a memory profile"));
+ return PreservedAnalyses::all();
+ }
+
+ auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+
+ for (auto &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
+ readMemprof(M, F, MemProfReader.get(), TLI);
+ }
+
+ return PreservedAnalyses::none();
+}
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index fe8b8ce0dc86..83d90049abc3 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -122,6 +122,10 @@
/// Arbitrary sized accesses are handled with:
/// __msan_metadata_ptr_for_load_n(ptr, size)
/// __msan_metadata_ptr_for_store_n(ptr, size);
+/// Note that the sanitizer code has to deal with how shadow/origin pairs
+/// returned by these functions are represented in different ABIs. In
+/// the X86_64 ABI they are returned in RDX:RAX, and in the SystemZ ABI they
+/// are written to memory pointed to by a hidden parameter.
/// - TLS variables are stored in a single per-task struct. A call to a
/// function __msan_get_context_state() returning a pointer to that struct
/// is inserted into every instrumented function before the entry block;
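A hedged sketch of the ABI note added above, written as plain C++ prototypes. The callback names are the KMSAN entry points created by this file; the MsanMetadata struct name is illustrative, and on any given target all of the callbacks share a single shape (the two shapes are shown together only for contrast):

#include <cstdint>

// { shadow, origin }, mirroring the { i8*, i32* } struct the pass builds.
struct MsanMetadata {
  uint8_t *Shadow;
  uint32_t *Origin;
};

// X86_64-style ABI: the small aggregate is returned by value (in RDX:RAX).
extern "C" MsanMetadata __msan_metadata_ptr_for_load_1(void *Addr);

// SystemZ-style ABI: the aggregate is written through a leading hidden
// out-parameter, and the instrumentation loads the pair back from that slot.
extern "C" void __msan_metadata_ptr_for_store_1(MsanMetadata *Out, void *Addr);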
@@ -135,7 +139,7 @@
/// Also, KMSAN currently ignores uninitialized memory passed into inline asm
/// calls, making sure we're on the safe side wrt. possible false positives.
///
-/// KernelMemorySanitizer only supports X86_64 at the moment.
+/// KernelMemorySanitizer only supports X86_64 and SystemZ at the moment.
///
//
// FIXME: This sanitizer does not yet handle scalable vectors
@@ -152,11 +156,11 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CallingConv.h"
@@ -190,6 +194,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -434,6 +439,14 @@ static const MemoryMapParams Linux_AArch64_MemoryMapParams = {
0x0200000000000, // OriginBase
};
+// loongarch64 Linux
+static const MemoryMapParams Linux_LoongArch64_MemoryMapParams = {
+ 0, // AndMask (not used)
+ 0x500000000000, // XorMask
+ 0, // ShadowBase (not used)
+ 0x100000000000, // OriginBase
+};
+
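Roughly, MSan maps an application address to its shadow by clearing AndMask bits, XOR-ing with XorMask and adding ShadowBase, with origins at a further OriginBase offset (kept 4-byte aligned). With the loongarch64 values above, where AndMask and ShadowBase are unused, that reduces to the minimal sketch below; the helper names are illustrative, not part of the pass:

#include <cstdint>

constexpr uint64_t kXorMask = 0x500000000000ULL;    // Linux/loongarch64
constexpr uint64_t kOriginBase = 0x100000000000ULL;

static uint64_t shadowFor(uint64_t AppAddr) { return AppAddr ^ kXorMask; }

static uint64_t originFor(uint64_t AppAddr) {
  // Origins are 4-byte values stored at a fixed offset from the shadow.
  return (shadowFor(AppAddr) + kOriginBase) & ~uint64_t(3);
}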
// aarch64 FreeBSD
static const MemoryMapParams FreeBSD_AArch64_MemoryMapParams = {
0x1800000000000, // AndMask
@@ -491,6 +504,11 @@ static const PlatformMemoryMapParams Linux_ARM_MemoryMapParams = {
&Linux_AArch64_MemoryMapParams,
};
+static const PlatformMemoryMapParams Linux_LoongArch_MemoryMapParams = {
+ nullptr,
+ &Linux_LoongArch64_MemoryMapParams,
+};
+
static const PlatformMemoryMapParams FreeBSD_ARM_MemoryMapParams = {
nullptr,
&FreeBSD_AArch64_MemoryMapParams,
@@ -543,6 +561,10 @@ private:
void createKernelApi(Module &M, const TargetLibraryInfo &TLI);
void createUserspaceApi(Module &M, const TargetLibraryInfo &TLI);
+ template <typename... ArgsTy>
+ FunctionCallee getOrInsertMsanMetadataFunction(Module &M, StringRef Name,
+ ArgsTy... Args);
+
/// True if we're compiling the Linux kernel.
bool CompileKernel;
/// Track origins (allocation points) of uninitialized values.
@@ -550,6 +572,7 @@ private:
bool Recover;
bool EagerChecks;
+ Triple TargetTriple;
LLVMContext *C;
Type *IntptrTy;
Type *OriginTy;
@@ -620,13 +643,18 @@ private:
/// Functions for poisoning/unpoisoning local variables
FunctionCallee MsanPoisonAllocaFn, MsanUnpoisonAllocaFn;
- /// Each of the MsanMetadataPtrXxx functions returns a pair of shadow/origin
- /// pointers.
+ /// Pair of shadow/origin pointers.
+ Type *MsanMetadata;
+
+ /// Each of the MsanMetadataPtrXxx functions returns a MsanMetadata.
FunctionCallee MsanMetadataPtrForLoadN, MsanMetadataPtrForStoreN;
FunctionCallee MsanMetadataPtrForLoad_1_8[4];
FunctionCallee MsanMetadataPtrForStore_1_8[4];
FunctionCallee MsanInstrumentAsmStoreFn;
+ /// Storage for return values of the MsanMetadataPtrXxx functions.
+ Value *MsanMetadataAlloca;
+
/// Helper to choose between different MsanMetadataPtrXxx().
FunctionCallee getKmsanShadowOriginAccessFn(bool isStore, int size);
@@ -706,7 +734,7 @@ void MemorySanitizerPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<MemorySanitizerPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (Options.Recover)
OS << "recover;";
if (Options.Kernel)
@@ -714,7 +742,7 @@ void MemorySanitizerPass::printPipeline(
if (Options.EagerChecks)
OS << "eager-checks;";
OS << "track-origins=" << Options.TrackOrigins;
- OS << ">";
+ OS << '>';
}
/// Create a non-const global initialized with the given string.
@@ -729,6 +757,21 @@ static GlobalVariable *createPrivateConstGlobalForString(Module &M,
GlobalValue::PrivateLinkage, StrConst, "");
}
+template <typename... ArgsTy>
+FunctionCallee
+MemorySanitizer::getOrInsertMsanMetadataFunction(Module &M, StringRef Name,
+ ArgsTy... Args) {
+ if (TargetTriple.getArch() == Triple::systemz) {
+ // SystemZ ABI: shadow/origin pair is returned via a hidden parameter.
+ return M.getOrInsertFunction(Name, Type::getVoidTy(*C),
+ PointerType::get(MsanMetadata, 0),
+ std::forward<ArgsTy>(Args)...);
+ }
+
+ return M.getOrInsertFunction(Name, MsanMetadata,
+ std::forward<ArgsTy>(Args)...);
+}
+
/// Create KMSAN API callbacks.
void MemorySanitizer::createKernelApi(Module &M, const TargetLibraryInfo &TLI) {
IRBuilder<> IRB(*C);
@@ -758,25 +801,25 @@ void MemorySanitizer::createKernelApi(Module &M, const TargetLibraryInfo &TLI) {
MsanGetContextStateFn = M.getOrInsertFunction(
"__msan_get_context_state", PointerType::get(MsanContextStateTy, 0));
- Type *RetTy = StructType::get(PointerType::get(IRB.getInt8Ty(), 0),
- PointerType::get(IRB.getInt32Ty(), 0));
+ MsanMetadata = StructType::get(PointerType::get(IRB.getInt8Ty(), 0),
+ PointerType::get(IRB.getInt32Ty(), 0));
for (int ind = 0, size = 1; ind < 4; ind++, size <<= 1) {
std::string name_load =
"__msan_metadata_ptr_for_load_" + std::to_string(size);
std::string name_store =
"__msan_metadata_ptr_for_store_" + std::to_string(size);
- MsanMetadataPtrForLoad_1_8[ind] = M.getOrInsertFunction(
- name_load, RetTy, PointerType::get(IRB.getInt8Ty(), 0));
- MsanMetadataPtrForStore_1_8[ind] = M.getOrInsertFunction(
- name_store, RetTy, PointerType::get(IRB.getInt8Ty(), 0));
+ MsanMetadataPtrForLoad_1_8[ind] = getOrInsertMsanMetadataFunction(
+ M, name_load, PointerType::get(IRB.getInt8Ty(), 0));
+ MsanMetadataPtrForStore_1_8[ind] = getOrInsertMsanMetadataFunction(
+ M, name_store, PointerType::get(IRB.getInt8Ty(), 0));
}
- MsanMetadataPtrForLoadN = M.getOrInsertFunction(
- "__msan_metadata_ptr_for_load_n", RetTy,
- PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty());
- MsanMetadataPtrForStoreN = M.getOrInsertFunction(
- "__msan_metadata_ptr_for_store_n", RetTy,
+ MsanMetadataPtrForLoadN = getOrInsertMsanMetadataFunction(
+ M, "__msan_metadata_ptr_for_load_n", PointerType::get(IRB.getInt8Ty(), 0),
+ IRB.getInt64Ty());
+ MsanMetadataPtrForStoreN = getOrInsertMsanMetadataFunction(
+ M, "__msan_metadata_ptr_for_store_n",
PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty());
// Functions for poisoning and unpoisoning memory.
@@ -927,6 +970,8 @@ FunctionCallee MemorySanitizer::getKmsanShadowOriginAccessFn(bool isStore,
void MemorySanitizer::initializeModule(Module &M) {
auto &DL = M.getDataLayout();
+ TargetTriple = Triple(M.getTargetTriple());
+
bool ShadowPassed = ClShadowBase.getNumOccurrences() > 0;
bool OriginPassed = ClOriginBase.getNumOccurrences() > 0;
// Check the overrides first
@@ -937,7 +982,6 @@ void MemorySanitizer::initializeModule(Module &M) {
CustomMapParams.OriginBase = ClOriginBase;
MapParams = &CustomMapParams;
} else {
- Triple TargetTriple(M.getTargetTriple());
switch (TargetTriple.getOS()) {
case Triple::FreeBSD:
switch (TargetTriple.getArch()) {
@@ -986,6 +1030,9 @@ void MemorySanitizer::initializeModule(Module &M) {
case Triple::aarch64_be:
MapParams = Linux_ARM_MemoryMapParams.bits64;
break;
+ case Triple::loongarch64:
+ MapParams = Linux_LoongArch_MemoryMapParams.bits64;
+ break;
default:
report_fatal_error("unsupported architecture");
}
@@ -1056,10 +1103,14 @@ struct MemorySanitizerVisitor;
static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan,
MemorySanitizerVisitor &Visitor);
-static unsigned TypeSizeToSizeIndex(unsigned TypeSize) {
- if (TypeSize <= 8)
+static unsigned TypeSizeToSizeIndex(TypeSize TS) {
+ if (TS.isScalable())
+ // Scalable types unconditionally take slowpaths.
+ return kNumberOfAccessSizes;
+ unsigned TypeSizeFixed = TS.getFixedValue();
+ if (TypeSizeFixed <= 8)
return 0;
- return Log2_32_Ceil((TypeSize + 7) / 8);
+ return Log2_32_Ceil((TypeSizeFixed + 7) / 8);
}
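A small worked example of the index computed above: it selects among the per-size runtime callbacks (for KMSAN, the __msan_metadata_ptr_for_{load,store}_{1,2,4,8} family declared earlier), and anything that does not fit a slot, now including scalable vectors, falls back to the *_n variants. A self-contained sketch, assuming kNumberOfAccessSizes is 4 as in this file:

#include <cstdint>

constexpr unsigned kNumberOfAccessSizes = 4; // slots for 1, 2, 4 and 8 bytes

// Mirrors TypeSizeToSizeIndex() for fixed-size types: ceil(log2(bytes)).
static unsigned sizeIndexForBits(uint64_t Bits) {
  if (Bits <= 8)
    return 0;
  unsigned Bytes = static_cast<unsigned>((Bits + 7) / 8);
  unsigned Index = 0;
  while ((1u << Index) < Bytes)
    ++Index;
  return Index;
}

// sizeIndexForBits(32)  == 2  -> the 4-byte callback
// sizeIndexForBits(64)  == 3  -> the 8-byte callback
// sizeIndexForBits(128) == 4  -> >= kNumberOfAccessSizes, use the *_n variant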
namespace {
@@ -1178,13 +1229,30 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
/// Fill memory range with the given origin value.
void paintOrigin(IRBuilder<> &IRB, Value *Origin, Value *OriginPtr,
- unsigned Size, Align Alignment) {
+ TypeSize TS, Align Alignment) {
const DataLayout &DL = F.getParent()->getDataLayout();
const Align IntptrAlignment = DL.getABITypeAlign(MS.IntptrTy);
unsigned IntptrSize = DL.getTypeStoreSize(MS.IntptrTy);
assert(IntptrAlignment >= kMinOriginAlignment);
assert(IntptrSize >= kOriginSize);
+    // Note: The loop-based formulation works for fixed-length vectors too;
+    // however, we prefer to unroll and specialize alignment below.
+ if (TS.isScalable()) {
+ Value *Size = IRB.CreateTypeSize(IRB.getInt32Ty(), TS);
+ Value *RoundUp = IRB.CreateAdd(Size, IRB.getInt32(kOriginSize - 1));
+ Value *End = IRB.CreateUDiv(RoundUp, IRB.getInt32(kOriginSize));
+ auto [InsertPt, Index] =
+ SplitBlockAndInsertSimpleForLoop(End, &*IRB.GetInsertPoint());
+ IRB.SetInsertPoint(InsertPt);
+
+ Value *GEP = IRB.CreateGEP(MS.OriginTy, OriginPtr, Index);
+ IRB.CreateAlignedStore(Origin, GEP, kMinOriginAlignment);
+ return;
+ }
+
+ unsigned Size = TS.getFixedValue();
+
unsigned Ofs = 0;
Align CurrentAlignment = Alignment;
if (Alignment >= IntptrAlignment && IntptrSize > kOriginSize) {
@@ -1212,7 +1280,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Value *OriginPtr, Align Alignment) {
const DataLayout &DL = F.getParent()->getDataLayout();
const Align OriginAlignment = std::max(kMinOriginAlignment, Alignment);
- unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType());
+ TypeSize StoreSize = DL.getTypeStoreSize(Shadow->getType());
Value *ConvertedShadow = convertShadowToScalar(Shadow, IRB);
if (auto *ConstantShadow = dyn_cast<Constant>(ConvertedShadow)) {
if (!ClCheckConstantShadow || ConstantShadow->isZeroValue()) {
@@ -1229,7 +1297,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// Fallback to runtime check, which still can be optimized out later.
}
- unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType());
+ TypeSize TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType());
unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits);
if (instrumentWithCalls(ConvertedShadow) &&
SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) {
@@ -1325,7 +1393,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
void materializeOneCheck(IRBuilder<> &IRB, Value *ConvertedShadow,
Value *Origin) {
const DataLayout &DL = F.getParent()->getDataLayout();
- unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType());
+ TypeSize TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType());
unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits);
if (instrumentWithCalls(ConvertedShadow) &&
SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) {
@@ -1443,6 +1511,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
MS.RetvalOriginTLS =
IRB.CreateGEP(MS.MsanContextStateTy, ContextState,
{Zero, IRB.getInt32(6)}, "retval_origin");
+ if (MS.TargetTriple.getArch() == Triple::systemz)
+ MS.MsanMetadataAlloca = IRB.CreateAlloca(MS.MsanMetadata, 0u);
}
/// Add MemorySanitizer instrumentation to a function.
@@ -1505,8 +1575,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
const DataLayout &DL = F.getParent()->getDataLayout();
if (VectorType *VT = dyn_cast<VectorType>(OrigTy)) {
uint32_t EltSize = DL.getTypeSizeInBits(VT->getElementType());
- return FixedVectorType::get(IntegerType::get(*MS.C, EltSize),
- cast<FixedVectorType>(VT)->getNumElements());
+ return VectorType::get(IntegerType::get(*MS.C, EltSize),
+ VT->getElementCount());
}
if (ArrayType *AT = dyn_cast<ArrayType>(OrigTy)) {
return ArrayType::get(getShadowTy(AT->getElementType()),
@@ -1524,14 +1594,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
return IntegerType::get(*MS.C, TypeSize);
}
- /// Flatten a vector type.
- Type *getShadowTyNoVec(Type *ty) {
- if (VectorType *vt = dyn_cast<VectorType>(ty))
- return IntegerType::get(*MS.C,
- vt->getPrimitiveSizeInBits().getFixedValue());
- return ty;
- }
-
/// Extract combined shadow of struct elements as a bool
Value *collapseStructShadow(StructType *Struct, Value *Shadow,
IRBuilder<> &IRB) {
@@ -1541,8 +1603,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
for (unsigned Idx = 0; Idx < Struct->getNumElements(); Idx++) {
// Combine by ORing together each element's bool shadow
Value *ShadowItem = IRB.CreateExtractValue(Shadow, Idx);
- Value *ShadowInner = convertShadowToScalar(ShadowItem, IRB);
- Value *ShadowBool = convertToBool(ShadowInner, IRB);
+ Value *ShadowBool = convertToBool(ShadowItem, IRB);
if (Aggregator != FalseVal)
Aggregator = IRB.CreateOr(Aggregator, ShadowBool);
@@ -1578,11 +1639,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
return collapseStructShadow(Struct, V, IRB);
if (ArrayType *Array = dyn_cast<ArrayType>(V->getType()))
return collapseArrayShadow(Array, V, IRB);
- Type *Ty = V->getType();
- Type *NoVecTy = getShadowTyNoVec(Ty);
- if (Ty == NoVecTy)
- return V;
- return IRB.CreateBitCast(V, NoVecTy);
+ if (isa<VectorType>(V->getType())) {
+ if (isa<ScalableVectorType>(V->getType()))
+ return convertShadowToScalar(IRB.CreateOrReduce(V), IRB);
+ unsigned BitWidth =
+ V->getType()->getPrimitiveSizeInBits().getFixedValue();
+ return IRB.CreateBitCast(V, IntegerType::get(*MS.C, BitWidth));
+ }
+ return V;
}
// Convert a scalar value to an i1 by comparing with 0
@@ -1597,28 +1661,28 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
Type *ptrToIntPtrType(Type *PtrTy) const {
- if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(PtrTy)) {
- return FixedVectorType::get(ptrToIntPtrType(VectTy->getElementType()),
- VectTy->getNumElements());
+ if (VectorType *VectTy = dyn_cast<VectorType>(PtrTy)) {
+ return VectorType::get(ptrToIntPtrType(VectTy->getElementType()),
+ VectTy->getElementCount());
}
assert(PtrTy->isIntOrPtrTy());
return MS.IntptrTy;
}
Type *getPtrToShadowPtrType(Type *IntPtrTy, Type *ShadowTy) const {
- if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(IntPtrTy)) {
- return FixedVectorType::get(
+ if (VectorType *VectTy = dyn_cast<VectorType>(IntPtrTy)) {
+ return VectorType::get(
getPtrToShadowPtrType(VectTy->getElementType(), ShadowTy),
- VectTy->getNumElements());
+ VectTy->getElementCount());
}
assert(IntPtrTy == MS.IntptrTy);
return ShadowTy->getPointerTo();
}
Constant *constToIntPtr(Type *IntPtrTy, uint64_t C) const {
- if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(IntPtrTy)) {
- return ConstantDataVector::getSplat(
- VectTy->getNumElements(), constToIntPtr(VectTy->getElementType(), C));
+ if (VectorType *VectTy = dyn_cast<VectorType>(IntPtrTy)) {
+ return ConstantVector::getSplat(
+ VectTy->getElementCount(), constToIntPtr(VectTy->getElementType(), C));
}
assert(IntPtrTy == MS.IntptrTy);
return ConstantInt::get(MS.IntptrTy, C);
@@ -1681,24 +1745,37 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
return std::make_pair(ShadowPtr, OriginPtr);
}
+ template <typename... ArgsTy>
+ Value *createMetadataCall(IRBuilder<> &IRB, FunctionCallee Callee,
+ ArgsTy... Args) {
+ if (MS.TargetTriple.getArch() == Triple::systemz) {
+ IRB.CreateCall(Callee,
+ {MS.MsanMetadataAlloca, std::forward<ArgsTy>(Args)...});
+ return IRB.CreateLoad(MS.MsanMetadata, MS.MsanMetadataAlloca);
+ }
+
+ return IRB.CreateCall(Callee, {std::forward<ArgsTy>(Args)...});
+ }
+
std::pair<Value *, Value *> getShadowOriginPtrKernelNoVec(Value *Addr,
IRBuilder<> &IRB,
Type *ShadowTy,
bool isStore) {
Value *ShadowOriginPtrs;
const DataLayout &DL = F.getParent()->getDataLayout();
- int Size = DL.getTypeStoreSize(ShadowTy);
+ TypeSize Size = DL.getTypeStoreSize(ShadowTy);
FunctionCallee Getter = MS.getKmsanShadowOriginAccessFn(isStore, Size);
Value *AddrCast =
IRB.CreatePointerCast(Addr, PointerType::get(IRB.getInt8Ty(), 0));
if (Getter) {
- ShadowOriginPtrs = IRB.CreateCall(Getter, AddrCast);
+ ShadowOriginPtrs = createMetadataCall(IRB, Getter, AddrCast);
} else {
Value *SizeVal = ConstantInt::get(MS.IntptrTy, Size);
- ShadowOriginPtrs = IRB.CreateCall(isStore ? MS.MsanMetadataPtrForStoreN
- : MS.MsanMetadataPtrForLoadN,
- {AddrCast, SizeVal});
+ ShadowOriginPtrs = createMetadataCall(
+ IRB,
+ isStore ? MS.MsanMetadataPtrForStoreN : MS.MsanMetadataPtrForLoadN,
+ AddrCast, SizeVal);
}
Value *ShadowPtr = IRB.CreateExtractValue(ShadowOriginPtrs, 0);
ShadowPtr = IRB.CreatePointerCast(ShadowPtr, PointerType::get(ShadowTy, 0));
@@ -1714,14 +1791,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
IRBuilder<> &IRB,
Type *ShadowTy,
bool isStore) {
- FixedVectorType *VectTy = dyn_cast<FixedVectorType>(Addr->getType());
+ VectorType *VectTy = dyn_cast<VectorType>(Addr->getType());
if (!VectTy) {
assert(Addr->getType()->isPointerTy());
return getShadowOriginPtrKernelNoVec(Addr, IRB, ShadowTy, isStore);
}
    // TODO: Support callbacks with vectors of addresses.
- unsigned NumElements = VectTy->getNumElements();
+ unsigned NumElements = cast<FixedVectorType>(VectTy)->getNumElements();
Value *ShadowPtrs = ConstantInt::getNullValue(
FixedVectorType::get(ShadowTy->getPointerTo(), NumElements));
Value *OriginPtrs = nullptr;
@@ -2367,9 +2444,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Constant *ConstOrigin = dyn_cast<Constant>(OpOrigin);
// No point in adding something that might result in 0 origin value.
if (!ConstOrigin || !ConstOrigin->isNullValue()) {
- Value *FlatShadow = MSV->convertShadowToScalar(OpShadow, IRB);
- Value *Cond =
- IRB.CreateICmpNE(FlatShadow, MSV->getCleanShadow(FlatShadow));
+ Value *Cond = MSV->convertToBool(OpShadow, IRB);
Origin = IRB.CreateSelect(Cond, OpOrigin, Origin);
}
}
@@ -2434,8 +2509,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
if (dstTy->isIntegerTy() && srcTy->isIntegerTy())
return IRB.CreateIntCast(V, dstTy, Signed);
if (dstTy->isVectorTy() && srcTy->isVectorTy() &&
- cast<FixedVectorType>(dstTy)->getNumElements() ==
- cast<FixedVectorType>(srcTy)->getNumElements())
+ cast<VectorType>(dstTy)->getElementCount() ==
+ cast<VectorType>(srcTy)->getElementCount())
return IRB.CreateIntCast(V, dstTy, Signed);
Value *V1 = IRB.CreateBitCast(V, Type::getIntNTy(*MS.C, srcSizeInBits));
Value *V2 =
@@ -2487,7 +2562,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
if (ConstantInt *Elt =
dyn_cast<ConstantInt>(ConstArg->getAggregateElement(Idx))) {
const APInt &V = Elt->getValue();
- APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros();
+ APInt V2 = APInt(V.getBitWidth(), 1) << V.countr_zero();
Elements.push_back(ConstantInt::get(EltTy, V2));
} else {
Elements.push_back(ConstantInt::get(EltTy, 1));
@@ -2497,7 +2572,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
} else {
if (ConstantInt *Elt = dyn_cast<ConstantInt>(ConstArg)) {
const APInt &V = Elt->getValue();
- APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros();
+ APInt V2 = APInt(V.getBitWidth(), 1) << V.countr_zero();
ShadowMul = ConstantInt::get(Ty, V2);
} else {
ShadowMul = ConstantInt::get(Ty, 1);
@@ -3356,7 +3431,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
Type *ShadowTy = getShadowTy(&I);
- Type *ElementShadowTy = cast<FixedVectorType>(ShadowTy)->getElementType();
+ Type *ElementShadowTy = cast<VectorType>(ShadowTy)->getElementType();
auto [ShadowPtr, OriginPtr] =
getShadowOriginPtr(Ptr, IRB, ElementShadowTy, {}, /*isStore*/ false);
@@ -3382,7 +3457,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Value *Shadow = getShadow(Values);
Type *ElementShadowTy =
- getShadowTy(cast<FixedVectorType>(Values->getType())->getElementType());
+ getShadowTy(cast<VectorType>(Values->getType())->getElementType());
auto [ShadowPtr, OriginPtrs] =
getShadowOriginPtr(Ptr, IRB, ElementShadowTy, {}, /*isStore*/ true);
@@ -3415,7 +3490,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
Type *ShadowTy = getShadowTy(&I);
- Type *ElementShadowTy = cast<FixedVectorType>(ShadowTy)->getElementType();
+ Type *ElementShadowTy = cast<VectorType>(ShadowTy)->getElementType();
auto [ShadowPtrs, OriginPtrs] = getShadowOriginPtr(
Ptrs, IRB, ElementShadowTy, Alignment, /*isStore*/ false);
@@ -3448,7 +3523,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Value *Shadow = getShadow(Values);
Type *ElementShadowTy =
- getShadowTy(cast<FixedVectorType>(Values->getType())->getElementType());
+ getShadowTy(cast<VectorType>(Values->getType())->getElementType());
auto [ShadowPtrs, OriginPtrs] = getShadowOriginPtr(
Ptrs, IRB, ElementShadowTy, Alignment, /*isStore*/ true);
@@ -3520,8 +3595,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Value *MaskedPassThruShadow = IRB.CreateAnd(
getShadow(PassThru), IRB.CreateSExt(IRB.CreateNeg(Mask), ShadowTy));
- Value *ConvertedShadow = convertShadowToScalar(MaskedPassThruShadow, IRB);
- Value *NotNull = convertToBool(ConvertedShadow, IRB, "_mscmp");
+ Value *NotNull = convertToBool(MaskedPassThruShadow, IRB, "_mscmp");
Value *PtrOrigin = IRB.CreateLoad(MS.OriginTy, OriginPtr);
Value *Origin = IRB.CreateSelect(NotNull, getOrigin(PassThru), PtrOrigin);
@@ -3645,11 +3719,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOrigin(&I, getOrigin(&I, 0));
}
+ void handleIsFpClass(IntrinsicInst &I) {
+ IRBuilder<> IRB(&I);
+ Value *Shadow = getShadow(&I, 0);
+ setShadow(&I, IRB.CreateICmpNE(Shadow, getCleanShadow(Shadow)));
+ setOrigin(&I, getOrigin(&I, 0));
+ }
+
void visitIntrinsicInst(IntrinsicInst &I) {
switch (I.getIntrinsicID()) {
case Intrinsic::abs:
handleAbsIntrinsic(I);
break;
+ case Intrinsic::is_fpclass:
+ handleIsFpClass(I);
+ break;
case Intrinsic::lifetime_start:
handleLifetimeStart(I);
break;
@@ -4391,11 +4475,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// Origins are always i32, so any vector conditions must be flattened.
// FIXME: consider tracking vector origins for app vectors?
if (B->getType()->isVectorTy()) {
- Type *FlatTy = getShadowTyNoVec(B->getType());
- B = IRB.CreateICmpNE(IRB.CreateBitCast(B, FlatTy),
- ConstantInt::getNullValue(FlatTy));
- Sb = IRB.CreateICmpNE(IRB.CreateBitCast(Sb, FlatTy),
- ConstantInt::getNullValue(FlatTy));
+ B = convertToBool(B, IRB);
+ Sb = convertToBool(Sb, IRB);
}
// a = select b, c, d
// Oa = Sb ? Ob : (b ? Oc : Od)
@@ -4490,9 +4571,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
if (!ElemTy->isSized())
return;
- int Size = DL.getTypeStoreSize(ElemTy);
Value *Ptr = IRB.CreatePointerCast(Operand, IRB.getInt8PtrTy());
- Value *SizeVal = ConstantInt::get(MS.IntptrTy, Size);
+ Value *SizeVal =
+ IRB.CreateTypeSize(MS.IntptrTy, DL.getTypeStoreSize(ElemTy));
IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Ptr, SizeVal});
}
@@ -4600,8 +4681,8 @@ struct VarArgAMD64Helper : public VarArgHelper {
Function &F;
MemorySanitizer &MS;
MemorySanitizerVisitor &MSV;
- Value *VAArgTLSCopy = nullptr;
- Value *VAArgTLSOriginCopy = nullptr;
+ AllocaInst *VAArgTLSCopy = nullptr;
+ AllocaInst *VAArgTLSOriginCopy = nullptr;
Value *VAArgOverflowSize = nullptr;
SmallVector<CallInst *, 16> VAStartInstrumentationList;
@@ -4721,7 +4802,7 @@ struct VarArgAMD64Helper : public VarArgHelper {
IRB.CreateAlignedStore(Shadow, ShadowBase, kShadowTLSAlignment);
if (MS.TrackOrigins) {
Value *Origin = MSV.getOrigin(A);
- unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType());
+ TypeSize StoreSize = DL.getTypeStoreSize(Shadow->getType());
MSV.paintOrigin(IRB, Origin, OriginBase, StoreSize,
std::max(kShadowTLSAlignment, kMinOriginAlignment));
}
@@ -4797,11 +4878,20 @@ struct VarArgAMD64Helper : public VarArgHelper {
Value *CopySize = IRB.CreateAdd(
ConstantInt::get(MS.IntptrTy, AMD64FpEndOffset), VAArgOverflowSize);
VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize);
- IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize);
+ VAArgTLSCopy->setAlignment(kShadowTLSAlignment);
+ IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()),
+ CopySize, kShadowTLSAlignment, false);
+
+ Value *SrcSize = IRB.CreateBinaryIntrinsic(
+ Intrinsic::umin, CopySize,
+ ConstantInt::get(MS.IntptrTy, kParamTLSSize));
+ IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS,
+ kShadowTLSAlignment, SrcSize);
if (MS.TrackOrigins) {
VAArgTLSOriginCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize);
- IRB.CreateMemCpy(VAArgTLSOriginCopy, Align(8), MS.VAArgOriginTLS,
- Align(8), CopySize);
+ VAArgTLSOriginCopy->setAlignment(kShadowTLSAlignment);
+ IRB.CreateMemCpy(VAArgTLSOriginCopy, kShadowTLSAlignment,
+ MS.VAArgOriginTLS, kShadowTLSAlignment, SrcSize);
}
}
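The same memset-plus-clamped-memcpy pattern recurs in every va_arg helper touched below (MIPS64, AArch64, PowerPC64, SystemZ). The apparent intent, sketched here assuming kParamTLSSize is the 800-byte MSan TLS slot size used elsewhere in this file, is that CopySize comes from the callee's va_list layout and may exceed the fixed-size __msan_va_arg_tls buffer, so the backup is zero-filled first and at most kParamTLSSize bytes are copied; backupVAArgTLS below is an illustrative helper, not part of the pass:

#include <algorithm>
#include <cstdint>
#include <cstring>

constexpr uint64_t kParamTLSSize = 800; // size of the per-thread vararg TLS

static void backupVAArgTLS(uint8_t *Copy, uint64_t CopySize,
                           const uint8_t *VAArgTLS) {
  std::memset(Copy, 0, CopySize);                    // keep the tail defined
  std::memcpy(Copy, VAArgTLS,
              std::min(CopySize, kParamTLSSize));    // never read past the TLS
}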
@@ -4859,7 +4949,7 @@ struct VarArgMIPS64Helper : public VarArgHelper {
Function &F;
MemorySanitizer &MS;
MemorySanitizerVisitor &MSV;
- Value *VAArgTLSCopy = nullptr;
+ AllocaInst *VAArgTLSCopy = nullptr;
Value *VAArgSize = nullptr;
SmallVector<CallInst *, 16> VAStartInstrumentationList;
@@ -4944,7 +5034,15 @@ struct VarArgMIPS64Helper : public VarArgHelper {
// If there is a va_start in this function, make a backup copy of
// va_arg_tls somewhere in the function entry block.
VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize);
- IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize);
+ VAArgTLSCopy->setAlignment(kShadowTLSAlignment);
+ IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()),
+ CopySize, kShadowTLSAlignment, false);
+
+ Value *SrcSize = IRB.CreateBinaryIntrinsic(
+ Intrinsic::umin, CopySize,
+ ConstantInt::get(MS.IntptrTy, kParamTLSSize));
+ IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS,
+ kShadowTLSAlignment, SrcSize);
}
// Instrument va_start.
@@ -4986,7 +5084,7 @@ struct VarArgAArch64Helper : public VarArgHelper {
Function &F;
MemorySanitizer &MS;
MemorySanitizerVisitor &MSV;
- Value *VAArgTLSCopy = nullptr;
+ AllocaInst *VAArgTLSCopy = nullptr;
Value *VAArgOverflowSize = nullptr;
SmallVector<CallInst *, 16> VAStartInstrumentationList;
@@ -5130,7 +5228,15 @@ struct VarArgAArch64Helper : public VarArgHelper {
Value *CopySize = IRB.CreateAdd(
ConstantInt::get(MS.IntptrTy, AArch64VAEndOffset), VAArgOverflowSize);
VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize);
- IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize);
+ VAArgTLSCopy->setAlignment(kShadowTLSAlignment);
+ IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()),
+ CopySize, kShadowTLSAlignment, false);
+
+ Value *SrcSize = IRB.CreateBinaryIntrinsic(
+ Intrinsic::umin, CopySize,
+ ConstantInt::get(MS.IntptrTy, kParamTLSSize));
+ IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS,
+ kShadowTLSAlignment, SrcSize);
}
Value *GrArgSize = ConstantInt::get(MS.IntptrTy, kAArch64GrArgSize);
@@ -5230,7 +5336,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper {
Function &F;
MemorySanitizer &MS;
MemorySanitizerVisitor &MSV;
- Value *VAArgTLSCopy = nullptr;
+ AllocaInst *VAArgTLSCopy = nullptr;
Value *VAArgSize = nullptr;
SmallVector<CallInst *, 16> VAStartInstrumentationList;
@@ -5373,8 +5479,17 @@ struct VarArgPowerPC64Helper : public VarArgHelper {
if (!VAStartInstrumentationList.empty()) {
// If there is a va_start in this function, make a backup copy of
// va_arg_tls somewhere in the function entry block.
+
VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize);
- IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize);
+ VAArgTLSCopy->setAlignment(kShadowTLSAlignment);
+ IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()),
+ CopySize, kShadowTLSAlignment, false);
+
+ Value *SrcSize = IRB.CreateBinaryIntrinsic(
+ Intrinsic::umin, CopySize,
+ ConstantInt::get(MS.IntptrTy, kParamTLSSize));
+ IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS,
+ kShadowTLSAlignment, SrcSize);
}
// Instrument va_start.
@@ -5416,8 +5531,9 @@ struct VarArgSystemZHelper : public VarArgHelper {
Function &F;
MemorySanitizer &MS;
MemorySanitizerVisitor &MSV;
- Value *VAArgTLSCopy = nullptr;
- Value *VAArgTLSOriginCopy = nullptr;
+ bool IsSoftFloatABI;
+ AllocaInst *VAArgTLSCopy = nullptr;
+ AllocaInst *VAArgTLSOriginCopy = nullptr;
Value *VAArgOverflowSize = nullptr;
SmallVector<CallInst *, 16> VAStartInstrumentationList;
@@ -5434,9 +5550,10 @@ struct VarArgSystemZHelper : public VarArgHelper {
VarArgSystemZHelper(Function &F, MemorySanitizer &MS,
MemorySanitizerVisitor &MSV)
- : F(F), MS(MS), MSV(MSV) {}
+ : F(F), MS(MS), MSV(MSV),
+ IsSoftFloatABI(F.getFnAttribute("use-soft-float").getValueAsBool()) {}
- ArgKind classifyArgument(Type *T, bool IsSoftFloatABI) {
+ ArgKind classifyArgument(Type *T) {
// T is a SystemZABIInfo::classifyArgumentType() output, and there are
// only a few possibilities of what it can be. In particular, enums, single
// element structs and large types have already been taken care of.
@@ -5474,9 +5591,6 @@ struct VarArgSystemZHelper : public VarArgHelper {
}
void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override {
- bool IsSoftFloatABI = CB.getCalledFunction()
- ->getFnAttribute("use-soft-float")
- .getValueAsBool();
unsigned GpOffset = SystemZGpOffset;
unsigned FpOffset = SystemZFpOffset;
unsigned VrIndex = 0;
@@ -5487,7 +5601,7 @@ struct VarArgSystemZHelper : public VarArgHelper {
// SystemZABIInfo does not produce ByVal parameters.
assert(!CB.paramHasAttr(ArgNo, Attribute::ByVal));
Type *T = A->getType();
- ArgKind AK = classifyArgument(T, IsSoftFloatABI);
+ ArgKind AK = classifyArgument(T);
if (AK == ArgKind::Indirect) {
T = PointerType::get(T, 0);
AK = ArgKind::GeneralPurpose;
@@ -5587,7 +5701,7 @@ struct VarArgSystemZHelper : public VarArgHelper {
IRB.CreateStore(Shadow, ShadowBase);
if (MS.TrackOrigins) {
Value *Origin = MSV.getOrigin(A);
- unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType());
+ TypeSize StoreSize = DL.getTypeStoreSize(Shadow->getType());
MSV.paintOrigin(IRB, Origin, OriginBase, StoreSize,
kMinOriginAlignment);
}
@@ -5642,11 +5756,15 @@ struct VarArgSystemZHelper : public VarArgHelper {
MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), Alignment,
/*isStore*/ true);
// TODO(iii): copy only fragments filled by visitCallBase()
+ // TODO(iii): support packed-stack && !use-soft-float
+ // For use-soft-float functions, it is enough to copy just the GPRs.
+ unsigned RegSaveAreaSize =
+ IsSoftFloatABI ? SystemZGpEndOffset : SystemZRegSaveAreaSize;
IRB.CreateMemCpy(RegSaveAreaShadowPtr, Alignment, VAArgTLSCopy, Alignment,
- SystemZRegSaveAreaSize);
+ RegSaveAreaSize);
if (MS.TrackOrigins)
IRB.CreateMemCpy(RegSaveAreaOriginPtr, Alignment, VAArgTLSOriginCopy,
- Alignment, SystemZRegSaveAreaSize);
+ Alignment, RegSaveAreaSize);
}
void copyOverflowArea(IRBuilder<> &IRB, Value *VAListTag) {
@@ -5688,11 +5806,20 @@ struct VarArgSystemZHelper : public VarArgHelper {
IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, SystemZOverflowOffset),
VAArgOverflowSize);
VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize);
- IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize);
+ VAArgTLSCopy->setAlignment(kShadowTLSAlignment);
+ IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()),
+ CopySize, kShadowTLSAlignment, false);
+
+ Value *SrcSize = IRB.CreateBinaryIntrinsic(
+ Intrinsic::umin, CopySize,
+ ConstantInt::get(MS.IntptrTy, kParamTLSSize));
+ IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS,
+ kShadowTLSAlignment, SrcSize);
if (MS.TrackOrigins) {
VAArgTLSOriginCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize);
- IRB.CreateMemCpy(VAArgTLSOriginCopy, Align(8), MS.VAArgOriginTLS,
- Align(8), CopySize);
+ VAArgTLSOriginCopy->setAlignment(kShadowTLSAlignment);
+ IRB.CreateMemCpy(VAArgTLSOriginCopy, kShadowTLSAlignment,
+ MS.VAArgOriginTLS, kShadowTLSAlignment, SrcSize);
}
}
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 4d4eb6f8ce80..3c8f25d73c62 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -48,7 +48,6 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
-#include "CFGMST.h"
#include "ValueProfileCollector.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
@@ -56,17 +55,13 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/CFG.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/MemoryBuiltins.h"
-#include "llvm/Analysis/MemoryProfileInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
@@ -78,6 +73,7 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
@@ -99,7 +95,6 @@
#include "llvm/IR/Value.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/InstrProfReader.h"
-#include "llvm/Support/BLAKE3.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CRC.h"
#include "llvm/Support/Casting.h"
@@ -109,27 +104,27 @@
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GraphWriter.h"
-#include "llvm/Support/HashBuilder.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Instrumentation.h"
+#include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
+#include "llvm/Transforms/Instrumentation/CFGMST.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/MisExpect.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
-#include <map>
#include <memory>
#include <numeric>
#include <optional>
-#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace llvm;
-using namespace llvm::memprof;
using ProfileCount = Function::ProfileCount;
using VPCandidateInfo = ValueProfileCollector::CandidateInfo;
@@ -144,7 +139,6 @@ STATISTIC(NumOfPGOSplit, "Number of critical edge splits.");
STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts.");
STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile.");
STATISTIC(NumOfPGOMissing, "Number of functions without profile.");
-STATISTIC(NumOfMemProfMissing, "Number of functions without memory profile.");
STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations.");
STATISTIC(NumOfCSPGOInstrument, "Number of edges instrumented in CSPGO.");
STATISTIC(NumOfCSPGOSelectInsts,
@@ -159,6 +153,7 @@ STATISTIC(NumOfCSPGOFunc,
STATISTIC(NumOfCSPGOMismatch,
"Number of functions having mismatch profile in CSPGO.");
STATISTIC(NumOfCSPGOMissing, "Number of functions without profile in CSPGO.");
+STATISTIC(NumCoveredBlocks, "Number of basic blocks that were executed");
// Command line option to specify the file to read profile from. This is
// mainly used for testing.
@@ -200,31 +195,31 @@ static cl::opt<bool> DoComdatRenaming(
cl::desc("Append function hash to the name of COMDAT function to avoid "
"function hash mismatch due to the preinliner"));
+namespace llvm {
// Command line option to enable/disable the warning about missing profile
// information.
-static cl::opt<bool>
- PGOWarnMissing("pgo-warn-missing-function", cl::init(false), cl::Hidden,
- cl::desc("Use this option to turn on/off "
- "warnings about missing profile data for "
- "functions."));
+cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function", cl::init(false),
+ cl::Hidden,
+ cl::desc("Use this option to turn on/off "
+ "warnings about missing profile data for "
+ "functions."));
-namespace llvm {
// Command line option to enable/disable the warning about a hash mismatch in
// the profile data.
cl::opt<bool>
NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden,
cl::desc("Use this option to turn off/on "
"warnings about profile cfg mismatch."));
-} // namespace llvm
// Command line option to enable/disable the warning about a hash mismatch in
// the profile data for Comdat functions, which often turns out to be false
// positive due to the pre-instrumentation inline.
-static cl::opt<bool> NoPGOWarnMismatchComdatWeak(
+cl::opt<bool> NoPGOWarnMismatchComdatWeak(
"no-pgo-warn-mismatch-comdat-weak", cl::init(true), cl::Hidden,
cl::desc("The option is used to turn on/off "
"warnings about hash mismatch for comdat "
"or weak functions."));
+} // namespace llvm
// Command line option to enable/disable select instruction instrumentation.
static cl::opt<bool>
@@ -268,6 +263,19 @@ static cl::opt<bool> PGOFunctionEntryCoverage(
cl::desc(
"Use this option to enable function entry coverage instrumentation."));
+static cl::opt<bool> PGOBlockCoverage(
+ "pgo-block-coverage",
+ cl::desc("Use this option to enable basic block coverage instrumentation"));
+
+static cl::opt<bool>
+ PGOViewBlockCoverageGraph("pgo-view-block-coverage-graph",
+ cl::desc("Create a dot file of CFGs with block "
+ "coverage inference information"));
+
+static cl::opt<bool> PGOTemporalInstrumentation(
+ "pgo-temporal-instrumentation",
+ cl::desc("Use this option to enable temporal instrumentation"));
+
static cl::opt<bool>
PGOFixEntryCount("pgo-fix-entry-count", cl::init(true), cl::Hidden,
cl::desc("Fix function entry count in profile use."));
@@ -305,10 +313,6 @@ static cl::opt<unsigned> PGOFunctionSizeThreshold(
"pgo-function-size-threshold", cl::Hidden,
cl::desc("Do not instrument functions smaller than this threshold."));
-static cl::opt<bool> MatchMemProf(
- "pgo-match-memprof", cl::init(true), cl::Hidden,
- cl::desc("Perform matching and annotation of memprof profiles."));
-
static cl::opt<unsigned> PGOFunctionCriticalEdgeThreshold(
"pgo-critical-edge-threshold", cl::init(20000), cl::Hidden,
cl::desc("Do not instrument functions with the number of critical edges "
@@ -344,7 +348,7 @@ static std::string getBranchCondString(Instruction *TI) {
std::string result;
raw_string_ostream OS(result);
- OS << CmpInst::getPredicateName(CI->getPredicate()) << "_";
+ OS << CI->getPredicate() << "_";
CI->getOperand(0)->getType()->print(OS, true);
Value *RHS = CI->getOperand(1);
@@ -383,6 +387,10 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
if (PGOFunctionEntryCoverage)
ProfileVersion |=
VARIANT_MASK_BYTE_COVERAGE | VARIANT_MASK_FUNCTION_ENTRY_ONLY;
+ if (PGOBlockCoverage)
+ ProfileVersion |= VARIANT_MASK_BYTE_COVERAGE;
+ if (PGOTemporalInstrumentation)
+ ProfileVersion |= VARIANT_MASK_TEMPORAL_PROF;
auto IRLevelVersionVariable = new GlobalVariable(
M, IntTy64, true, GlobalValue::WeakAnyLinkage,
Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), VarName);
@@ -415,35 +423,37 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
GlobalVariable *FuncNameVar = nullptr;
uint64_t FuncHash = 0;
PGOUseFunc *UseFunc = nullptr;
+ bool HasSingleByteCoverage;
- SelectInstVisitor(Function &Func) : F(Func) {}
+ SelectInstVisitor(Function &Func, bool HasSingleByteCoverage)
+ : F(Func), HasSingleByteCoverage(HasSingleByteCoverage) {}
- void countSelects(Function &Func) {
+ void countSelects() {
NSIs = 0;
Mode = VM_counting;
- visit(Func);
+ visit(F);
}
// Visit the IR stream and instrument all select instructions. \p
// Ind is a pointer to the counter index variable; \p TotalNC
// is the total number of counters; \p FNV is the pointer to the
// PGO function name var; \p FHash is the function hash.
- void instrumentSelects(Function &Func, unsigned *Ind, unsigned TotalNC,
- GlobalVariable *FNV, uint64_t FHash) {
+ void instrumentSelects(unsigned *Ind, unsigned TotalNC, GlobalVariable *FNV,
+ uint64_t FHash) {
Mode = VM_instrument;
CurCtrIdx = Ind;
TotalNumCtrs = TotalNC;
FuncHash = FHash;
FuncNameVar = FNV;
- visit(Func);
+ visit(F);
}
// Visit the IR stream and annotate all select instructions.
- void annotateSelects(Function &Func, PGOUseFunc *UF, unsigned *Ind) {
+ void annotateSelects(PGOUseFunc *UF, unsigned *Ind) {
Mode = VM_annotate;
UseFunc = UF;
CurCtrIdx = Ind;
- visit(Func);
+ visit(F);
}
void instrumentOneSelectInst(SelectInst &SI);
@@ -457,52 +467,41 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
unsigned getNumOfSelectInsts() const { return NSIs; }
};
-} // end anonymous namespace
-
-namespace {
-
-/// An MST based instrumentation for PGO
-///
-/// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO
-/// in the function level.
+/// This class implements the CFG edges for the Minimum Spanning Tree (MST)
+/// based instrumentation.
+/// Note that the CFG can be a multi-graph. So there might be multiple edges
+/// with the same SrcBB and DestBB.
struct PGOEdge {
- // This class implements the CFG edges. Note the CFG can be a multi-graph.
- // So there might be multiple edges with same SrcBB and DestBB.
- const BasicBlock *SrcBB;
- const BasicBlock *DestBB;
+ BasicBlock *SrcBB;
+ BasicBlock *DestBB;
uint64_t Weight;
bool InMST = false;
bool Removed = false;
bool IsCritical = false;
- PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W = 1)
+ PGOEdge(BasicBlock *Src, BasicBlock *Dest, uint64_t W = 1)
: SrcBB(Src), DestBB(Dest), Weight(W) {}
- // Return the information string of an edge.
+ /// Return the information string of an edge.
std::string infoString() const {
return (Twine(Removed ? "-" : " ") + (InMST ? " " : "*") +
- (IsCritical ? "c" : " ") + " W=" + Twine(Weight)).str();
+ (IsCritical ? "c" : " ") + " W=" + Twine(Weight))
+ .str();
}
};
-// This class stores the auxiliary information for each BB.
-struct BBInfo {
- BBInfo *Group;
+/// This class stores the auxiliary information for each BB in the MST.
+struct PGOBBInfo {
+ PGOBBInfo *Group;
uint32_t Index;
uint32_t Rank = 0;
- BBInfo(unsigned IX) : Group(this), Index(IX) {}
+ PGOBBInfo(unsigned IX) : Group(this), Index(IX) {}
- // Return the information string of this object.
+ /// Return the information string of this object.
std::string infoString() const {
return (Twine("Index=") + Twine(Index)).str();
}
-
- // Empty function -- only applicable to UseBBInfo.
- void addOutEdge(PGOEdge *E LLVM_ATTRIBUTE_UNUSED) {}
-
- // Empty function -- only applicable to UseBBInfo.
- void addInEdge(PGOEdge *E LLVM_ATTRIBUTE_UNUSED) {}
};
// This class implements the CFG edges. Note the CFG can be a multi-graph.
@@ -534,6 +533,16 @@ public:
// The Minimum Spanning Tree of function CFG.
CFGMST<Edge, BBInfo> MST;
+ const std::optional<BlockCoverageInference> BCI;
+
+ static std::optional<BlockCoverageInference>
+ constructBCI(Function &Func, bool HasSingleByteCoverage,
+ bool InstrumentFuncEntry) {
+ if (HasSingleByteCoverage)
+ return BlockCoverageInference(Func, InstrumentFuncEntry);
+ return {};
+ }
+
// Collect all the BBs that will be instrumented, and store them in
// InstrumentBBs.
void getInstrumentBBs(std::vector<BasicBlock *> &InstrumentBBs);
@@ -549,9 +558,9 @@ public:
BBInfo *findBBInfo(const BasicBlock *BB) const { return MST.findBBInfo(BB); }
// Dump edges and BB information.
- void dumpInfo(std::string Str = "") const {
- MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " +
- Twine(FunctionHash) + "\t" + Str);
+ void dumpInfo(StringRef Str = "") const {
+ MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName +
+ " Hash: " + Twine(FunctionHash) + "\t" + Str);
}
FuncPGOInstrumentation(
@@ -559,12 +568,16 @@ public:
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,
BlockFrequencyInfo *BFI = nullptr, bool IsCS = false,
- bool InstrumentFuncEntry = true)
+ bool InstrumentFuncEntry = true, bool HasSingleByteCoverage = false)
: F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), VPC(Func, TLI),
- TLI(TLI), ValueSites(IPVK_Last + 1), SIVisitor(Func),
- MST(F, InstrumentFuncEntry, BPI, BFI) {
+ TLI(TLI), ValueSites(IPVK_Last + 1),
+ SIVisitor(Func, HasSingleByteCoverage),
+ MST(F, InstrumentFuncEntry, BPI, BFI),
+ BCI(constructBCI(Func, HasSingleByteCoverage, InstrumentFuncEntry)) {
+ if (BCI && PGOViewBlockCoverageGraph)
+ BCI->viewBlockCoverageGraph();
// This should be done before CFG hash computation.
- SIVisitor.countSelects(Func);
+ SIVisitor.countSelects();
ValueSites[IPVK_MemOPSize] = VPC.get(IPVK_MemOPSize);
if (!IsCS) {
NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
@@ -637,7 +650,11 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
updateJCH((uint64_t)SIVisitor.getNumOfSelectInsts());
updateJCH((uint64_t)ValueSites[IPVK_IndirectCallTarget].size());
updateJCH((uint64_t)ValueSites[IPVK_MemOPSize].size());
- updateJCH((uint64_t)MST.AllEdges.size());
+ if (BCI) {
+ updateJCH(BCI->getInstrumentedBlocksHash());
+ } else {
+ updateJCH((uint64_t)MST.AllEdges.size());
+ }
// Hash format for context sensitive profile. Reserve 4 bits for other
// information.
@@ -725,11 +742,18 @@ void FuncPGOInstrumentation<Edge, BBInfo>::renameComdatFunction() {
}
}
-// Collect all the BBs that will be instruments and return them in
-// InstrumentBBs and setup InEdges/OutEdge for UseBBInfo.
+/// Collect all the BBs that will be instrumented and add them to
+/// `InstrumentBBs`.
template <class Edge, class BBInfo>
void FuncPGOInstrumentation<Edge, BBInfo>::getInstrumentBBs(
std::vector<BasicBlock *> &InstrumentBBs) {
+ if (BCI) {
+ for (auto &BB : F)
+ if (BCI->shouldInstrumentBlock(BB))
+ InstrumentBBs.push_back(&BB);
+ return;
+ }
+
// Use a worklist as we will update the vector during the iteration.
std::vector<Edge *> EdgeList;
EdgeList.reserve(MST.AllEdges.size());
@@ -741,18 +765,6 @@ void FuncPGOInstrumentation<Edge, BBInfo>::getInstrumentBBs(
if (InstrBB)
InstrumentBBs.push_back(InstrBB);
}
-
- // Set up InEdges/OutEdges for all BBs.
- for (auto &E : MST.AllEdges) {
- if (E->Removed)
- continue;
- const BasicBlock *SrcBB = E->SrcBB;
- const BasicBlock *DestBB = E->DestBB;
- BBInfo &SrcInfo = getBBInfo(SrcBB);
- BBInfo &DestInfo = getBBInfo(DestBB);
- SrcInfo.addOutEdge(E.get());
- DestInfo.addInEdge(E.get());
- }
}
// Given a CFG E to be instrumented, find which BB to place the instrumented
@@ -762,8 +774,8 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) {
if (E->InMST || E->Removed)
return nullptr;
- BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
- BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
+ BasicBlock *SrcBB = E->SrcBB;
+ BasicBlock *DestBB = E->DestBB;
// For a fake edge, instrument the real BB.
if (SrcBB == nullptr)
return DestBB;
@@ -852,12 +864,15 @@ static void instrumentOneFunc(
BlockFrequencyInfo *BFI,
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
bool IsCS) {
- // Split indirectbr critical edges here before computing the MST rather than
- // later in getInstrBB() to avoid invalidating it.
- SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI);
+ if (!PGOBlockCoverage) {
+ // Split indirectbr critical edges here before computing the MST rather than
+ // later in getInstrBB() to avoid invalidating it.
+ SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI);
+ }
- FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(
- F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry);
+ FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
+ F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry,
+ PGOBlockCoverage);
Type *I8PtrTy = Type::getInt8PtrTy(M->getContext());
auto Name = ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy);
@@ -880,6 +895,18 @@ static void instrumentOneFunc(
InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
uint32_t I = 0;
+ if (PGOTemporalInstrumentation) {
+ NumCounters += PGOBlockCoverage ? 8 : 1;
+ auto &EntryBB = F.getEntryBlock();
+ IRBuilder<> Builder(&EntryBB, EntryBB.getFirstInsertionPt());
+ // llvm.instrprof.timestamp(i8* <name>, i64 <hash>, i32 <num-counters>,
+ // i32 <index>)
+ Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::instrprof_timestamp),
+ {Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I)});
+ I += PGOBlockCoverage ? 8 : 1;
+ }
+
for (auto *InstrBB : InstrumentBBs) {
IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt());
assert(Builder.GetInsertPoint() != InstrBB->end() &&
@@ -887,12 +914,14 @@ static void instrumentOneFunc(
// llvm.instrprof.increment(i8* <name>, i64 <hash>, i32 <num-counters>,
// i32 <index>)
Builder.CreateCall(
- Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment),
+ Intrinsic::getDeclaration(M, PGOBlockCoverage
+ ? Intrinsic::instrprof_cover
+ : Intrinsic::instrprof_increment),
{Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I++)});
}
// Now instrument select instructions:
- FuncInfo.SIVisitor.instrumentSelects(F, &I, NumCounters, FuncInfo.FuncNameVar,
+ FuncInfo.SIVisitor.instrumentSelects(&I, NumCounters, FuncInfo.FuncNameVar,
FuncInfo.FunctionHash);
assert(I == NumCounters);
@@ -947,12 +976,11 @@ namespace {
// This class represents a CFG edge in profile use compilation.
struct PGOUseEdge : public PGOEdge {
+ using PGOEdge::PGOEdge;
+
bool CountValid = false;
uint64_t CountValue = 0;
- PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W = 1)
- : PGOEdge(Src, Dest, W) {}
-
// Set edge count value
void setEdgeCount(uint64_t Value) {
CountValue = Value;
@@ -971,7 +999,7 @@ struct PGOUseEdge : public PGOEdge {
using DirectEdges = SmallVector<PGOUseEdge *, 2>;
// This class stores the auxiliary information for each BB.
-struct UseBBInfo : public BBInfo {
+struct PGOUseBBInfo : public PGOBBInfo {
uint64_t CountValue = 0;
bool CountValid;
int32_t UnknownCountInEdge = 0;
@@ -979,10 +1007,7 @@ struct UseBBInfo : public BBInfo {
DirectEdges InEdges;
DirectEdges OutEdges;
- UseBBInfo(unsigned IX) : BBInfo(IX), CountValid(false) {}
-
- UseBBInfo(unsigned IX, uint64_t C)
- : BBInfo(IX), CountValue(C), CountValid(true) {}
+ PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
// Set the profile count value for this BB.
void setBBInfoCount(uint64_t Value) {
@@ -993,8 +1018,9 @@ struct UseBBInfo : public BBInfo {
// Return the information string of this object.
std::string infoString() const {
if (!CountValid)
- return BBInfo::infoString();
- return (Twine(BBInfo::infoString()) + " Count=" + Twine(CountValue)).str();
+ return PGOBBInfo::infoString();
+ return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue))
+ .str();
}
// Add an OutEdge and update the edge count.
@@ -1030,22 +1056,25 @@ public:
PGOUseFunc(Function &Func, Module *Modu, TargetLibraryInfo &TLI,
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFIin,
- ProfileSummaryInfo *PSI, bool IsCS, bool InstrumentFuncEntry)
+ ProfileSummaryInfo *PSI, bool IsCS, bool InstrumentFuncEntry,
+ bool HasSingleByteCoverage)
: F(Func), M(Modu), BFI(BFIin), PSI(PSI),
FuncInfo(Func, TLI, ComdatMembers, false, BPI, BFIin, IsCS,
- InstrumentFuncEntry),
+ InstrumentFuncEntry, HasSingleByteCoverage),
FreqAttr(FFA_Normal), IsCS(IsCS) {}
+ void handleInstrProfError(Error Err, uint64_t MismatchedFuncSum);
+
// Read counts for the instrumented BB from profile.
bool readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros,
InstrProfRecord::CountPseudoKind &PseudoKind);
- // Read memprof data for the instrumented function from profile.
- bool readMemprof(IndexedInstrProfReader *PGOReader);
-
// Populate the counts for all BBs.
void populateCounters();
+ // Set block coverage based on profile coverage values.
+ void populateCoverage(IndexedInstrProfReader *PGOReader);
+
// Set the branch weights based on the count values.
void setBranchWeights();
@@ -1071,22 +1100,21 @@ public:
InstrProfRecord &getProfileRecord() { return ProfileRecord; }
// Return the auxiliary BB information.
- UseBBInfo &getBBInfo(const BasicBlock *BB) const {
+ PGOUseBBInfo &getBBInfo(const BasicBlock *BB) const {
return FuncInfo.getBBInfo(BB);
}
// Return the auxiliary BB information if available.
- UseBBInfo *findBBInfo(const BasicBlock *BB) const {
+ PGOUseBBInfo *findBBInfo(const BasicBlock *BB) const {
return FuncInfo.findBBInfo(BB);
}
Function &getFunc() const { return F; }
- void dumpInfo(std::string Str = "") const {
- FuncInfo.dumpInfo(Str);
- }
+ void dumpInfo(StringRef Str = "") const { FuncInfo.dumpInfo(Str); }
uint64_t getProgramMaxCount() const { return ProgramMaxCount; }
+
private:
Function &F;
Module *M;
@@ -1094,7 +1122,7 @@ private:
ProfileSummaryInfo *PSI;
// This member stores the shared information with class PGOGenFunc.
- FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo;
+ FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> FuncInfo;
// The maximum count value in the profile. This is only used in PGO use
// compilation.
@@ -1122,9 +1150,6 @@ private:
// one unknown edge.
void setEdgeCount(DirectEdges &Edges, uint64_t Value);
- // Return FuncName string;
- std::string getFuncName() const { return FuncInfo.FuncName; }
-
// Set the hot/cold inline hints based on the count values.
// FIXME: This function should be removed once the functionality in
// the inliner is implemented.
@@ -1138,6 +1163,24 @@ private:
} // end anonymous namespace
+/// Set up InEdges/OutEdges for all BBs in the MST.
+static void
+setupBBInfoEdges(FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> &FuncInfo) {
+ // This is not required when there is block coverage inference.
+ if (FuncInfo.BCI)
+ return;
+ for (auto &E : FuncInfo.MST.AllEdges) {
+ if (E->Removed)
+ continue;
+ const BasicBlock *SrcBB = E->SrcBB;
+ const BasicBlock *DestBB = E->DestBB;
+ PGOUseBBInfo &SrcInfo = FuncInfo.getBBInfo(SrcBB);
+ PGOUseBBInfo &DestInfo = FuncInfo.getBBInfo(DestBB);
+ SrcInfo.addOutEdge(E.get());
+ DestInfo.addInEdge(E.get());
+ }
+}
+
// Visit all the edges and assign the count value for the instrumented
// edges and the BB. Return false on error.
bool PGOUseFunc::setInstrumentedCounts(
@@ -1145,6 +1188,9 @@ bool PGOUseFunc::setInstrumentedCounts(
std::vector<BasicBlock *> InstrumentBBs;
FuncInfo.getInstrumentBBs(InstrumentBBs);
+
+ setupBBInfoEdges(FuncInfo);
+
unsigned NumCounters =
InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
// The number of counters here should match the number of counters
@@ -1158,7 +1204,7 @@ bool PGOUseFunc::setInstrumentedCounts(
uint32_t I = 0;
for (BasicBlock *InstrBB : InstrumentBBs) {
uint64_t CountValue = CountFromProfile[I++];
- UseBBInfo &Info = getBBInfo(InstrBB);
+ PGOUseBBInfo &Info = getBBInfo(InstrBB);
// If we reach here, we know that we have some nonzero count
// values in this function. The entry count should not be 0.
// Fix it if necessary.
@@ -1183,7 +1229,7 @@ bool PGOUseFunc::setInstrumentedCounts(
if (E->Removed || E->InMST)
continue;
const BasicBlock *SrcBB = E->SrcBB;
- UseBBInfo &SrcInfo = getBBInfo(SrcBB);
+ PGOUseBBInfo &SrcInfo = getBBInfo(SrcBB);
// If only one out-edge, the edge profile count should be the same as BB
// profile count.
@@ -1191,7 +1237,7 @@ bool PGOUseFunc::setInstrumentedCounts(
setEdgeCount(E.get(), SrcInfo.CountValue);
else {
const BasicBlock *DestBB = E->DestBB;
- UseBBInfo &DestInfo = getBBInfo(DestBB);
+ PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
// If only one in-edge, the edge profile count should be the same as BB
// profile count.
if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
@@ -1222,8 +1268,7 @@ void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) {
}
// Emit function metadata indicating PGO profile mismatch.
-static void annotateFunctionWithHashMismatch(Function &F,
- LLVMContext &ctx) {
+static void annotateFunctionWithHashMismatch(Function &F, LLVMContext &ctx) {
const char MetadataName[] = "instr_prof_hash_mismatch";
SmallVector<Metadata *, 2> Names;
// If this metadata already exists, ignore.
@@ -1231,7 +1276,7 @@ static void annotateFunctionWithHashMismatch(Function &F,
if (Existing) {
MDTuple *Tuple = cast<MDTuple>(Existing);
for (const auto &N : Tuple->operands()) {
- if (cast<MDString>(N.get())->getString() == MetadataName)
+ if (N.equalsStr(MetadataName))
return;
Names.push_back(N.get());
}
@@ -1243,255 +1288,44 @@ static void annotateFunctionWithHashMismatch(Function &F,
F.setMetadata(LLVMContext::MD_annotation, MD);
}
-static void addCallsiteMetadata(Instruction &I,
- std::vector<uint64_t> &InlinedCallStack,
- LLVMContext &Ctx) {
- I.setMetadata(LLVMContext::MD_callsite,
- buildCallstackMetadata(InlinedCallStack, Ctx));
-}
-
-static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset,
- uint32_t Column) {
- llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::support::endianness::little>
- HashBuilder;
- HashBuilder.add(Function, LineOffset, Column);
- llvm::BLAKE3Result<8> Hash = HashBuilder.final();
- uint64_t Id;
- std::memcpy(&Id, Hash.data(), sizeof(Hash));
- return Id;
-}
-
-static uint64_t computeStackId(const memprof::Frame &Frame) {
- return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column);
-}
-
-static void addCallStack(CallStackTrie &AllocTrie,
- const AllocationInfo *AllocInfo) {
- SmallVector<uint64_t> StackIds;
- for (auto StackFrame : AllocInfo->CallStack)
- StackIds.push_back(computeStackId(StackFrame));
- auto AllocType = getAllocType(AllocInfo->Info.getMaxAccessCount(),
- AllocInfo->Info.getMinSize(),
- AllocInfo->Info.getMinLifetime());
- AllocTrie.addCallStack(AllocType, StackIds);
-}
-
-// Helper to compare the InlinedCallStack computed from an instruction's debug
-// info to a list of Frames from profile data (either the allocation data or a
-// callsite). For callsites, the StartIndex to use in the Frame array may be
-// non-zero.
-static bool
-stackFrameIncludesInlinedCallStack(ArrayRef<Frame> ProfileCallStack,
- ArrayRef<uint64_t> InlinedCallStack,
- unsigned StartIndex = 0) {
- auto StackFrame = ProfileCallStack.begin() + StartIndex;
- auto InlCallStackIter = InlinedCallStack.begin();
- for (; StackFrame != ProfileCallStack.end() &&
- InlCallStackIter != InlinedCallStack.end();
- ++StackFrame, ++InlCallStackIter) {
- uint64_t StackId = computeStackId(*StackFrame);
- if (StackId != *InlCallStackIter)
- return false;
- }
- // Return true if we found and matched all stack ids from the call
- // instruction.
- return InlCallStackIter == InlinedCallStack.end();
-}
-
-bool PGOUseFunc::readMemprof(IndexedInstrProfReader *PGOReader) {
- if (!MatchMemProf)
- return true;
-
- auto &Ctx = M->getContext();
-
- auto FuncGUID = Function::getGUID(FuncInfo.FuncName);
- Expected<memprof::MemProfRecord> MemProfResult =
- PGOReader->getMemProfRecord(FuncGUID);
- if (Error E = MemProfResult.takeError()) {
- handleAllErrors(std::move(E), [&](const InstrProfError &IPE) {
- auto Err = IPE.get();
- bool SkipWarning = false;
- LLVM_DEBUG(dbgs() << "Error in reading profile for Func "
- << FuncInfo.FuncName << ": ");
- if (Err == instrprof_error::unknown_function) {
- NumOfMemProfMissing++;
- SkipWarning = !PGOWarnMissing;
- LLVM_DEBUG(dbgs() << "unknown function");
- } else if (Err == instrprof_error::hash_mismatch) {
- SkipWarning =
- NoPGOWarnMismatch ||
- (NoPGOWarnMismatchComdatWeak &&
- (F.hasComdat() ||
- F.getLinkage() == GlobalValue::AvailableExternallyLinkage));
- LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")");
- }
-
- if (SkipWarning)
- return;
-
- std::string Msg =
- (IPE.message() + Twine(" ") + F.getName().str() + Twine(" Hash = ") +
- std::to_string(FuncInfo.FunctionHash))
- .str();
-
- Ctx.diagnose(
- DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
- });
- return false;
- }
-
- // Build maps of the location hash to all profile data with that leaf location
- // (allocation info and the callsites).
- std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo;
- // For the callsites we need to record the index of the associated frame in
- // the frame array (see comments below where the map entries are added).
- std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>>
- LocHashToCallSites;
- const auto MemProfRec = std::move(MemProfResult.get());
- for (auto &AI : MemProfRec.AllocSites) {
- // Associate the allocation info with the leaf frame. The later matching
- // code will match any inlined call sequences in the IR with a longer prefix
- // of call stack frames.
- uint64_t StackId = computeStackId(AI.CallStack[0]);
- LocHashToAllocInfo[StackId].insert(&AI);
- }
- for (auto &CS : MemProfRec.CallSites) {
- // Need to record all frames from leaf up to and including this function,
- // as any of these may or may not have been inlined at this point.
- unsigned Idx = 0;
- for (auto &StackFrame : CS) {
- uint64_t StackId = computeStackId(StackFrame);
- LocHashToCallSites[StackId].insert(std::make_pair(&CS, Idx++));
- // Once we find this function, we can stop recording.
- if (StackFrame.Function == FuncGUID)
- break;
+void PGOUseFunc::handleInstrProfError(Error Err, uint64_t MismatchedFuncSum) {
+ handleAllErrors(std::move(Err), [&](const InstrProfError &IPE) {
+ auto &Ctx = M->getContext();
+ auto Err = IPE.get();
+ bool SkipWarning = false;
+ LLVM_DEBUG(dbgs() << "Error in reading profile for Func "
+ << FuncInfo.FuncName << ": ");
+ if (Err == instrprof_error::unknown_function) {
+ IsCS ? NumOfCSPGOMissing++ : NumOfPGOMissing++;
+ SkipWarning = !PGOWarnMissing;
+ LLVM_DEBUG(dbgs() << "unknown function");
+ } else if (Err == instrprof_error::hash_mismatch ||
+ Err == instrprof_error::malformed) {
+ IsCS ? NumOfCSPGOMismatch++ : NumOfPGOMismatch++;
+ SkipWarning =
+ NoPGOWarnMismatch ||
+ (NoPGOWarnMismatchComdatWeak &&
+ (F.hasComdat() || F.getLinkage() == GlobalValue::WeakAnyLinkage ||
+ F.getLinkage() == GlobalValue::AvailableExternallyLinkage));
+ LLVM_DEBUG(dbgs() << "hash mismatch (hash= " << FuncInfo.FunctionHash
+ << " skip=" << SkipWarning << ")");
+ // Emit function metadata indicating PGO profile mismatch.
+ annotateFunctionWithHashMismatch(F, M->getContext());
}
- assert(Idx <= CS.size() && CS[Idx - 1].Function == FuncGUID);
- }
-
- auto GetOffset = [](const DILocation *DIL) {
- return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) &
- 0xffff;
- };
-
- // Now walk the instructions, looking up the associated profile data using
- // debug locations.
- for (auto &BB : F) {
- for (auto &I : BB) {
- if (I.isDebugOrPseudoInst())
- continue;
- // We are only interested in calls (allocation or interior call stack
- // context calls).
- auto *CI = dyn_cast<CallBase>(&I);
- if (!CI)
- continue;
- auto *CalledFunction = CI->getCalledFunction();
- if (CalledFunction && CalledFunction->isIntrinsic())
- continue;
- // List of call stack ids computed from the location hashes on debug
- // locations (leaf to inlined at root).
- std::vector<uint64_t> InlinedCallStack;
- // Was the leaf location found in one of the profile maps?
- bool LeafFound = false;
- // If leaf was found in a map, iterators pointing to its location in both
- // of the maps. It might exist in neither, one, or both (the latter case
- // can happen because we don't currently have discriminators to
- // distinguish the case when a single line/col maps to both an allocation
- // and another callsite).
- std::map<uint64_t, std::set<const AllocationInfo *>>::iterator
- AllocInfoIter;
- std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *,
- unsigned>>>::iterator CallSitesIter;
- for (const DILocation *DIL = I.getDebugLoc(); DIL != nullptr;
- DIL = DIL->getInlinedAt()) {
- // Use C++ linkage name if possible. Need to compile with
- // -fdebug-info-for-profiling to get linkage name.
- StringRef Name = DIL->getScope()->getSubprogram()->getLinkageName();
- if (Name.empty())
- Name = DIL->getScope()->getSubprogram()->getName();
- auto CalleeGUID = Function::getGUID(Name);
- auto StackId =
- computeStackId(CalleeGUID, GetOffset(DIL), DIL->getColumn());
- // LeafFound will only be false on the first iteration, since we either
- // set it true or break out of the loop below.
- if (!LeafFound) {
- AllocInfoIter = LocHashToAllocInfo.find(StackId);
- CallSitesIter = LocHashToCallSites.find(StackId);
- // Check if the leaf is in one of the maps. If not, no need to look
- // further at this call.
- if (AllocInfoIter == LocHashToAllocInfo.end() &&
- CallSitesIter == LocHashToCallSites.end())
- break;
- LeafFound = true;
- }
- InlinedCallStack.push_back(StackId);
- }
- // If leaf not in either of the maps, skip inst.
- if (!LeafFound)
- continue;
- // First add !memprof metadata from allocation info, if we found the
- // instruction's leaf location in that map, and if the rest of the
- // instruction's locations match the prefix Frame locations on an
- // allocation context with the same leaf.
- if (AllocInfoIter != LocHashToAllocInfo.end()) {
- // Only consider allocations via new, to reduce unnecessary metadata,
- // since those are the only allocations that will be targeted initially.
- if (!isNewLikeFn(CI, &FuncInfo.TLI))
- continue;
- // We may match this instruction's location list to multiple MIB
- // contexts. Add them to a Trie specialized for trimming the contexts to
- // the minimal needed to disambiguate contexts with unique behavior.
- CallStackTrie AllocTrie;
- for (auto *AllocInfo : AllocInfoIter->second) {
- // Check the full inlined call stack against this one.
- // If we found and thus matched all frames on the call, include
- // this MIB.
- if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack,
- InlinedCallStack))
- addCallStack(AllocTrie, AllocInfo);
- }
- // We might not have matched any to the full inlined call stack.
- // But if we did, create and attach metadata, or a function attribute if
- // all contexts have identical profiled behavior.
- if (!AllocTrie.empty()) {
- // MemprofMDAttached will be false if a function attribute was
- // attached.
- bool MemprofMDAttached = AllocTrie.buildAndAttachMIBMetadata(CI);
- assert(MemprofMDAttached == I.hasMetadata(LLVMContext::MD_memprof));
- if (MemprofMDAttached) {
- // Add callsite metadata for the instruction's location list so that
- // it is simpler later on to identify which part of the MIB contexts
- // are from this particular instruction (including during inlining,
- // when the callsite metadata will be updated appropriately).
- // FIXME: can this be changed to strip out the matching stack
- // context ids from the MIB contexts and not add any callsite
- // metadata here to save space?
- addCallsiteMetadata(I, InlinedCallStack, Ctx);
- }
- }
- continue;
- }
+ LLVM_DEBUG(dbgs() << " IsCS=" << IsCS << "\n");
+ if (SkipWarning)
+ return;
- // Otherwise, add callsite metadata. If we reach here then we found the
- // instruction's leaf location in the callsites map and not the allocation
- // map.
- assert(CallSitesIter != LocHashToCallSites.end());
- for (auto CallStackIdx : CallSitesIter->second) {
- // If we found and thus matched all frames on the call, create and
- // attach call stack metadata.
- if (stackFrameIncludesInlinedCallStack(
- *CallStackIdx.first, InlinedCallStack, CallStackIdx.second)) {
- addCallsiteMetadata(I, InlinedCallStack, Ctx);
- // Only need to find one with a matching call stack and add a single
- // callsite metadata.
- break;
- }
- }
- }
- }
+ std::string Msg =
+ IPE.message() + std::string(" ") + F.getName().str() +
+ std::string(" Hash = ") + std::to_string(FuncInfo.FunctionHash) +
+ std::string(" up to ") + std::to_string(MismatchedFuncSum) +
+ std::string(" count discarded");
- return true;
+ Ctx.diagnose(
+ DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
+ });
}
// Read the profile from ProfileFileName and assign the value to the
@@ -1504,42 +1338,7 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros,
Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord(
FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum);
if (Error E = Result.takeError()) {
- handleAllErrors(std::move(E), [&](const InstrProfError &IPE) {
- auto Err = IPE.get();
- bool SkipWarning = false;
- LLVM_DEBUG(dbgs() << "Error in reading profile for Func "
- << FuncInfo.FuncName << ": ");
- if (Err == instrprof_error::unknown_function) {
- IsCS ? NumOfCSPGOMissing++ : NumOfPGOMissing++;
- SkipWarning = !PGOWarnMissing;
- LLVM_DEBUG(dbgs() << "unknown function");
- } else if (Err == instrprof_error::hash_mismatch ||
- Err == instrprof_error::malformed) {
- IsCS ? NumOfCSPGOMismatch++ : NumOfPGOMismatch++;
- SkipWarning =
- NoPGOWarnMismatch ||
- (NoPGOWarnMismatchComdatWeak &&
- (F.hasComdat() || F.getLinkage() == GlobalValue::WeakAnyLinkage ||
- F.getLinkage() == GlobalValue::AvailableExternallyLinkage));
- LLVM_DEBUG(dbgs() << "hash mismatch (hash= " << FuncInfo.FunctionHash
- << " skip=" << SkipWarning << ")");
- // Emit function metadata indicating PGO profile mismatch.
- annotateFunctionWithHashMismatch(F, M->getContext());
- }
-
- LLVM_DEBUG(dbgs() << " IsCS=" << IsCS << "\n");
- if (SkipWarning)
- return;
-
- std::string Msg =
- IPE.message() + std::string(" ") + F.getName().str() +
- std::string(" Hash = ") + std::to_string(FuncInfo.FunctionHash) +
- std::string(" up to ") + std::to_string(MismatchedFuncSum) +
- std::string(" count discarded");
-
- Ctx.diagnose(
- DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
- });
+ handleInstrProfError(std::move(E), MismatchedFuncSum);
return false;
}
ProfileRecord = std::move(Result.get());
@@ -1569,8 +1368,9 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros,
dbgs() << "Inconsistent number of counts, skipping this function");
Ctx.diagnose(DiagnosticInfoPGOProfile(
M->getName().data(),
- Twine("Inconsistent number of counts in ") + F.getName().str()
- + Twine(": the profile may be stale or there is a function name collision."),
+ Twine("Inconsistent number of counts in ") + F.getName().str() +
+ Twine(": the profile may be stale or there is a function name "
+ "collision."),
DS_Warning));
return false;
}
@@ -1578,6 +1378,113 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros,
return true;
}
+void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
+ uint64_t MismatchedFuncSum = 0;
+ Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord(
+ FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum);
+ if (auto Err = Result.takeError()) {
+ handleInstrProfError(std::move(Err), MismatchedFuncSum);
+ return;
+ }
+
+ std::vector<uint64_t> &CountsFromProfile = Result.get().Counts;
+ DenseMap<const BasicBlock *, bool> Coverage;
+ unsigned Index = 0;
+ for (auto &BB : F)
+ if (FuncInfo.BCI->shouldInstrumentBlock(BB))
+ Coverage[&BB] = (CountsFromProfile[Index++] != 0);
+ assert(Index == CountsFromProfile.size());
+
+ // For each B in InverseDependencies[A], if A is covered then B is covered.
+ DenseMap<const BasicBlock *, DenseSet<const BasicBlock *>>
+ InverseDependencies;
+ for (auto &BB : F) {
+ for (auto *Dep : FuncInfo.BCI->getDependencies(BB)) {
+ // If Dep is covered then BB is covered.
+ InverseDependencies[Dep].insert(&BB);
+ }
+ }
+
+ // Infer coverage of the non-instrumented blocks using a flood-fill algorithm.
+ std::stack<const BasicBlock *> CoveredBlocksToProcess;
+ for (auto &[BB, IsCovered] : Coverage)
+ if (IsCovered)
+ CoveredBlocksToProcess.push(BB);
+
+ while (!CoveredBlocksToProcess.empty()) {
+ auto *CoveredBlock = CoveredBlocksToProcess.top();
+ assert(Coverage[CoveredBlock]);
+ CoveredBlocksToProcess.pop();
+ for (auto *BB : InverseDependencies[CoveredBlock]) {
+ // If CoveredBlock is covered then BB is covered.
+ if (Coverage[BB])
+ continue;
+ Coverage[BB] = true;
+ CoveredBlocksToProcess.push(BB);
+ }
+ }
+
+ // Annotate block coverage.
+ MDBuilder MDB(F.getContext());
+ // We set the entry count to 10000 if the entry block is covered so that BFI
+ // can propagate a fraction of this count to the other covered blocks.
+ F.setEntryCount(Coverage[&F.getEntryBlock()] ? 10000 : 0);
+ for (auto &BB : F) {
+ // For a block A and its successor B, we set the edge weight as follows:
+ // If A is covered and B is covered, set weight=1.
+ // If A is covered and B is uncovered, set weight=0.
+ // If A is uncovered, set weight=1.
+ // This setup will allow BFI to give nonzero profile counts to only covered
+ // blocks.
+ SmallVector<unsigned, 4> Weights;
+ for (auto *Succ : successors(&BB))
+ Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
+ if (Weights.size() >= 2)
+ BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
+ MDB.createBranchWeights(Weights));
+ }
+
+ unsigned NumCorruptCoverage = 0;
+ DominatorTree DT(F);
+ LoopInfo LI(DT);
+ BranchProbabilityInfo BPI(F, LI);
+ BlockFrequencyInfo BFI(F, BPI, LI);
+ auto IsBlockDead = [&](const BasicBlock &BB) -> std::optional<bool> {
+ if (auto C = BFI.getBlockProfileCount(&BB))
+ return C == 0;
+ return {};
+ };
+ LLVM_DEBUG(dbgs() << "Block Coverage: (Instrumented=*, Covered=X)\n");
+ for (auto &BB : F) {
+ LLVM_DEBUG(dbgs() << (FuncInfo.BCI->shouldInstrumentBlock(BB) ? "* " : " ")
+ << (Coverage[&BB] ? "X " : " ") << " " << BB.getName()
+ << "\n");
+ // In some cases it is possible to find a covered block that has no covered
+ // successors, e.g., when a block calls a function that may call exit(). In
+ // those cases, BFI could find its successor to be covered while BCI could
+ // find its successor to be dead.
+ if (Coverage[&BB] == IsBlockDead(BB).value_or(false)) {
+ LLVM_DEBUG(
+ dbgs() << "Found inconsistent block coverage for " << BB.getName()
+ << ": BCI=" << (Coverage[&BB] ? "Covered" : "Dead") << " BFI="
+ << (IsBlockDead(BB).value() ? "Dead" : "Covered") << "\n");
+ ++NumCorruptCoverage;
+ }
+ if (Coverage[&BB])
+ ++NumCoveredBlocks;
+ }
+ if (PGOVerifyBFI && NumCorruptCoverage) {
+ auto &Ctx = M->getContext();
+ Ctx.diagnose(DiagnosticInfoPGOProfile(
+ M->getName().data(),
+ Twine("Found inconsistent block coverage for function ") + F.getName() +
+ " in " + Twine(NumCorruptCoverage) + " blocks.",
+ DS_Warning));
+ }
+ if (PGOViewBlockCoverageGraph)
+ FuncInfo.BCI->viewBlockCoverageGraph(&Coverage);
+}
+
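The coverage computation above only reads counters for the instrumented blocks and then floods coverage through the inverse dependencies supplied by block-coverage inference; the resulting map is what drives the 1/0 successor weights so that BFI assigns nonzero frequencies only to covered blocks. A minimal standalone sketch of that flood-fill step, using plain STL containers and hypothetical stand-ins (BlockId, InstrumentedBlocks, InverseDependencies) in place of the LLVM types used here:

#include <cstdint>
#include <map>
#include <set>
#include <stack>
#include <vector>

using BlockId = unsigned;

// Infer coverage for every block: a block with a nonzero counter is covered,
// and coverage propagates along the inverse-dependency edges.
std::map<BlockId, bool>
inferCoverage(const std::vector<BlockId> &InstrumentedBlocks,
              const std::vector<uint64_t> &CountsFromProfile,
              const std::map<BlockId, std::set<BlockId>> &InverseDependencies) {
  std::map<BlockId, bool> Coverage;
  for (size_t I = 0; I < InstrumentedBlocks.size(); ++I)
    Coverage[InstrumentedBlocks[I]] = CountsFromProfile[I] != 0;

  // Flood-fill: if A is covered, every B in InverseDependencies[A] is covered.
  std::stack<BlockId> Worklist;
  for (const auto &[BB, IsCovered] : Coverage)
    if (IsCovered)
      Worklist.push(BB);
  while (!Worklist.empty()) {
    BlockId Covered = Worklist.top();
    Worklist.pop();
    auto It = InverseDependencies.find(Covered);
    if (It == InverseDependencies.end())
      continue;
    for (BlockId BB : It->second) {
      if (Coverage[BB])
        continue;
      Coverage[BB] = true;
      Worklist.push(BB);
    }
  }
  return Coverage;
}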
// Populate the counters from instrumented BBs to all BBs.
// In the end of this operation, all BBs should have a valid count value.
void PGOUseFunc::populateCounters() {
@@ -1590,7 +1497,7 @@ void PGOUseFunc::populateCounters() {
// For efficient traversal, it's better to start from the end as most
// of the instrumented edges are at the end.
for (auto &BB : reverse(F)) {
- UseBBInfo *Count = findBBInfo(&BB);
+ PGOUseBBInfo *Count = findBBInfo(&BB);
if (Count == nullptr)
continue;
if (!Count->CountValid) {
@@ -1629,7 +1536,7 @@ void PGOUseFunc::populateCounters() {
}
LLVM_DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n");
- (void) NumPasses;
+ (void)NumPasses;
#ifndef NDEBUG
// Assert every BB has a valid counter.
for (auto &BB : F) {
@@ -1655,7 +1562,7 @@ void PGOUseFunc::populateCounters() {
markFunctionAttributes(FuncEntryCount, FuncMaxCount);
// Now annotate select instructions
- FuncInfo.SIVisitor.annotateSelects(F, this, &CountPosition);
+ FuncInfo.SIVisitor.annotateSelects(this, &CountPosition);
assert(CountPosition == ProfileCountSize);
LLVM_DEBUG(FuncInfo.dumpInfo("after reading profile."));
@@ -1679,7 +1586,7 @@ void PGOUseFunc::setBranchWeights() {
continue;
// We have a non-zero Branch BB.
- const UseBBInfo &BBCountInfo = getBBInfo(&BB);
+ const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
unsigned Size = BBCountInfo.OutEdges.size();
SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
uint64_t MaxCount = 0;
@@ -1704,11 +1611,11 @@ void PGOUseFunc::setBranchWeights() {
// when there is no exit block and the code exits via a noreturn function.
auto &Ctx = M->getContext();
Ctx.diagnose(DiagnosticInfoPGOProfile(
- M->getName().data(),
- Twine("Profile in ") + F.getName().str() +
- Twine(" partially ignored") +
- Twine(", possibly due to the lack of a return path."),
- DS_Warning));
+ M->getName().data(),
+ Twine("Profile in ") + F.getName().str() +
+ Twine(" partially ignored") +
+ Twine(", possibly due to the lack of a return path."),
+ DS_Warning));
}
}
}
@@ -1730,15 +1637,13 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
// duplication.
if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
Instruction *TI = BB.getTerminator();
- const UseBBInfo &BBCountInfo = getBBInfo(&BB);
+ const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
}
}
}
void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) {
- if (PGOFunctionEntryCoverage)
- return;
Module *M = F.getParent();
IRBuilder<> Builder(&SI);
Type *Int64Ty = Builder.getInt64Ty();
@@ -1771,7 +1676,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
}
void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
- if (!PGOInstrSelect)
+ if (!PGOInstrSelect || PGOFunctionEntryCoverage || HasSingleByteCoverage)
return;
// FIXME: do not handle this yet.
if (SI.getCondition()->getType()->isVectorTy())
@@ -1815,8 +1720,8 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) {
Ctx.diagnose(DiagnosticInfoPGOProfile(
M->getName().data(),
Twine("Inconsistent number of value sites for ") +
- Twine(ValueProfKindDescr[Kind]) +
- Twine(" profiling in \"") + F.getName().str() +
+ Twine(ValueProfKindDescr[Kind]) + Twine(" profiling in \"") +
+ F.getName().str() +
Twine("\", possibly due to the use of a stale profile."),
DS_Warning));
return;
@@ -1907,17 +1812,20 @@ static bool InstrumentAllFunctions(
}
PreservedAnalyses
-PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &AM) {
+PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &MAM) {
createProfileFileNameVar(M, CSInstrName);
// The variable in a comdat may be discarded by LTO. Ensure the declaration
// will be retained.
appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true));
- return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserve<FunctionAnalysisManagerModuleProxy>();
+ PA.preserveSet<AllAnalysesOn<Function>>();
+ return PA;
}
PreservedAnalyses PGOInstrumentationGen::run(Module &M,
- ModuleAnalysisManager &AM) {
- auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ ModuleAnalysisManager &MAM) {
+ auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
auto LookupTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
return FAM.getResult<TargetLibraryAnalysis>(F);
};
@@ -1991,7 +1899,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
BlockFrequencyInfo NBFI(F, NBPI, LI);
// bool PrintFunc = false;
bool HotBBOnly = PGOVerifyHotBFI;
- std::string Msg;
+ StringRef Msg;
OptimizationRemarkEmitter ORE(&F);
unsigned BBNum = 0, BBMisMatchNum = 0, NonZeroBBNum = 0;
@@ -2059,6 +1967,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
static bool annotateAllFunctions(
Module &M, StringRef ProfileFileName, StringRef ProfileRemappingFileName,
+ vfs::FileSystem &FS,
function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
@@ -2066,8 +1975,8 @@ static bool annotateAllFunctions(
LLVM_DEBUG(dbgs() << "Read in profile counters: ");
auto &Ctx = M.getContext();
// Read the counter array from file.
- auto ReaderOrErr =
- IndexedInstrProfReader::create(ProfileFileName, ProfileRemappingFileName);
+ auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName, FS,
+ ProfileRemappingFileName);
if (Error E = ReaderOrErr.takeError()) {
handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) {
Ctx.diagnose(
@@ -2087,17 +1996,11 @@ static bool annotateAllFunctions(
return false;
// TODO: might need to change the warning once the clang option is finalized.
- if (!PGOReader->isIRLevelProfile() && !PGOReader->hasMemoryProfile()) {
+ if (!PGOReader->isIRLevelProfile()) {
Ctx.diagnose(DiagnosticInfoPGOProfile(
ProfileFileName.data(), "Not an IR level instrumentation profile"));
return false;
}
- if (PGOReader->hasSingleByteCoverage()) {
- Ctx.diagnose(DiagnosticInfoPGOProfile(
- ProfileFileName.data(),
- "Cannot use coverage profiles for optimization"));
- return false;
- }
if (PGOReader->functionEntryOnly()) {
Ctx.diagnose(DiagnosticInfoPGOProfile(
ProfileFileName.data(),
@@ -2123,25 +2026,25 @@ static bool annotateAllFunctions(
bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
if (PGOInstrumentEntry.getNumOccurrences() > 0)
InstrumentFuncEntry = PGOInstrumentEntry;
+ bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
for (auto &F : M) {
if (skipPGO(F))
continue;
auto &TLI = LookupTLI(F);
auto *BPI = LookupBPI(F);
auto *BFI = LookupBFI(F);
- // Split indirectbr critical edges here before computing the MST rather than
- // later in getInstrBB() to avoid invalidating it.
- SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI);
+ if (!HasSingleByteCoverage) {
+ // Split indirectbr critical edges here before computing the MST rather
+ // than later in getInstrBB() to avoid invalidating it.
+ SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI,
+ BFI);
+ }
PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, PSI, IsCS,
- InstrumentFuncEntry);
- // Read and match memprof first since we do this via debug info and can
- // match even if there is an IR mismatch detected for regular PGO below.
- if (PGOReader->hasMemoryProfile())
- Func.readMemprof(PGOReader.get());
-
- if (!PGOReader->isIRLevelProfile())
+ InstrumentFuncEntry, HasSingleByteCoverage);
+ if (HasSingleByteCoverage) {
+ Func.populateCoverage(PGOReader.get());
continue;
-
+ }
// When PseudoKind is set to a value other than InstrProfRecord::NotPseudo,
// it means the profile for the function is unrepresentative and this
// function is actually hot / warm. We will reset the function hot / cold
@@ -2249,21 +2152,24 @@ static bool annotateAllFunctions(
return true;
}
-PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename,
- std::string RemappingFilename,
- bool IsCS)
+PGOInstrumentationUse::PGOInstrumentationUse(
+ std::string Filename, std::string RemappingFilename, bool IsCS,
+ IntrusiveRefCntPtr<vfs::FileSystem> VFS)
: ProfileFileName(std::move(Filename)),
- ProfileRemappingFileName(std::move(RemappingFilename)), IsCS(IsCS) {
+ ProfileRemappingFileName(std::move(RemappingFilename)), IsCS(IsCS),
+ FS(std::move(VFS)) {
if (!PGOTestProfileFile.empty())
ProfileFileName = PGOTestProfileFile;
if (!PGOTestProfileRemappingFile.empty())
ProfileRemappingFileName = PGOTestProfileRemappingFile;
+ if (!FS)
+ FS = vfs::getRealFileSystem();
}
PreservedAnalyses PGOInstrumentationUse::run(Module &M,
- ModuleAnalysisManager &AM) {
+ ModuleAnalysisManager &MAM) {
- auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
auto LookupTLI = [&FAM](Function &F) -> TargetLibraryInfo & {
return FAM.getResult<TargetLibraryAnalysis>(F);
};
@@ -2274,9 +2180,9 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M,
return &FAM.getResult<BlockFrequencyAnalysis>(F);
};
- auto *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
+ auto *PSI = &MAM.getResult<ProfileSummaryAnalysis>(M);
- if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName,
+ if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, *FS,
LookupTLI, LookupBPI, LookupBFI, PSI, IsCS))
return PreservedAnalyses::all();
@@ -2285,7 +2191,7 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M,
static std::string getSimpleNodeName(const BasicBlock *Node) {
if (!Node->getName().empty())
- return std::string(Node->getName());
+ return Node->getName().str();
std::string SimpleNodeName;
raw_string_ostream OS(SimpleNodeName);
@@ -2294,8 +2200,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
}
void llvm::setProfMetadata(Module *M, Instruction *TI,
- ArrayRef<uint64_t> EdgeCounts,
- uint64_t MaxCount) {
+ ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
MDBuilder MDB(M->getContext());
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
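The EdgeCounts here are 64-bit profile counts, while branch_weights operands are 32-bit, so the counts are divided by a common scale derived from MaxCount before being attached. A rough sketch of that idea (the exact cap used by the in-tree calculateCountScale may differ; UINT32_MAX is an assumption for illustration):

#include <cstdint>
#include <limits>
#include <vector>

// Scale 64-bit edge counts so the largest fits a 32-bit weight operand while
// preserving the ratios between edges.
std::vector<uint32_t> scaleEdgeCounts(const std::vector<uint64_t> &EdgeCounts,
                                      uint64_t MaxCount) {
  const uint64_t Cap = std::numeric_limits<uint32_t>::max();
  uint64_t Scale = MaxCount <= Cap ? 1 : MaxCount / Cap + 1;
  std::vector<uint32_t> Weights;
  for (uint64_t C : EdgeCounts)
    Weights.push_back(static_cast<uint32_t>(C / Scale));
  return Weights;
}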
@@ -2384,7 +2289,7 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
raw_string_ostream OS(Result);
OS << getSimpleNodeName(Node) << ":\\l";
- UseBBInfo *BI = Graph->findBBInfo(Node);
+ PGOUseBBInfo *BI = Graph->findBBInfo(Node);
OS << "Count : ";
if (BI && BI->CountValid)
OS << BI->CountValue << "\\l";
diff --git a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
index 35db8483fc91..2906fe190984 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
@@ -317,7 +317,7 @@ bool MemOPSizeOpt::perform(MemOp MO) {
}
if (!SeenSizeId.insert(V).second) {
- errs() << "Invalid Profile Data in Function " << Func.getName()
+ errs() << "warning: Invalid Profile Data in Function " << Func.getName()
<< ": Two identical values in MemOp value counts.\n";
return false;
}
diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp
index 142b9c38e5fc..d83a3a991c89 100644
--- a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp
+++ b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp
@@ -15,8 +15,9 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
@@ -31,15 +32,19 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Transforms/Instrumentation.h"
+#include "llvm/Support/SpecialCaseList.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/VirtualFileSystem.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <array>
#include <cstdint>
+#include <memory>
using namespace llvm;
@@ -49,7 +54,7 @@ namespace {
//===--- Constants --------------------------------------------------------===//
-constexpr uint32_t kVersionBase = 1; // occupies lower 16 bits
+constexpr uint32_t kVersionBase = 2; // occupies lower 16 bits
constexpr uint32_t kVersionPtrSizeRel = (1u << 16); // offsets are pointer-sized
constexpr int kCtorDtorPriority = 2;
@@ -59,7 +64,6 @@ class MetadataInfo {
public:
const StringRef FunctionPrefix;
const StringRef SectionSuffix;
- const uint32_t FeatureMask;
static const MetadataInfo Covered;
static const MetadataInfo Atomics;
@@ -67,16 +71,13 @@ public:
private:
// Forbid construction elsewhere.
explicit constexpr MetadataInfo(StringRef FunctionPrefix,
- StringRef SectionSuffix, uint32_t Feature)
- : FunctionPrefix(FunctionPrefix), SectionSuffix(SectionSuffix),
- FeatureMask(Feature) {}
+ StringRef SectionSuffix)
+ : FunctionPrefix(FunctionPrefix), SectionSuffix(SectionSuffix) {}
};
-const MetadataInfo MetadataInfo::Covered{"__sanitizer_metadata_covered",
- kSanitizerBinaryMetadataCoveredSection,
- kSanitizerBinaryMetadataNone};
-const MetadataInfo MetadataInfo::Atomics{"__sanitizer_metadata_atomics",
- kSanitizerBinaryMetadataAtomicsSection,
- kSanitizerBinaryMetadataAtomics};
+const MetadataInfo MetadataInfo::Covered{
+ "__sanitizer_metadata_covered", kSanitizerBinaryMetadataCoveredSection};
+const MetadataInfo MetadataInfo::Atomics{
+ "__sanitizer_metadata_atomics", kSanitizerBinaryMetadataAtomicsSection};
// The only instances of MetadataInfo are the constants above, so a set of
// them may simply store pointers to them. To deterministically generate code,
@@ -89,6 +90,11 @@ cl::opt<bool> ClWeakCallbacks(
"sanitizer-metadata-weak-callbacks",
cl::desc("Declare callbacks extern weak, and only call if non-null."),
cl::Hidden, cl::init(true));
+cl::opt<bool>
+ ClNoSanitize("sanitizer-metadata-nosanitize-attr",
+ cl::desc("Mark some metadata features uncovered in functions "
+ "with associated no_sanitize attributes."),
+ cl::Hidden, cl::init(true));
cl::opt<bool> ClEmitCovered("sanitizer-metadata-covered",
cl::desc("Emit PCs for covered functions."),
@@ -120,24 +126,20 @@ transformOptionsFromCl(SanitizerBinaryMetadataOptions &&Opts) {
class SanitizerBinaryMetadata {
public:
- SanitizerBinaryMetadata(Module &M, SanitizerBinaryMetadataOptions Opts)
+ SanitizerBinaryMetadata(Module &M, SanitizerBinaryMetadataOptions Opts,
+ std::unique_ptr<SpecialCaseList> Ignorelist)
: Mod(M), Options(transformOptionsFromCl(std::move(Opts))),
- TargetTriple(M.getTargetTriple()), IRB(M.getContext()) {
+ Ignorelist(std::move(Ignorelist)), TargetTriple(M.getTargetTriple()),
+ IRB(M.getContext()) {
// FIXME: Make it work with other formats.
assert(TargetTriple.isOSBinFormatELF() && "ELF only");
+ assert(!(TargetTriple.isNVPTX() || TargetTriple.isAMDGPU()) &&
+ "Device targets are not supported");
}
bool run();
private:
- // Return enabled feature mask of per-instruction metadata.
- uint32_t getEnabledPerInstructionFeature() const {
- uint32_t FeatureMask = 0;
- if (Options.Atomics)
- FeatureMask |= MetadataInfo::Atomics.FeatureMask;
- return FeatureMask;
- }
-
uint32_t getVersion() const {
uint32_t Version = kVersionBase;
const auto CM = Mod.getCodeModel();
@@ -156,7 +158,7 @@ private:
// to determine if a memory operation is atomic or not in modules compiled
// with SanitizerBinaryMetadata.
bool runOn(Instruction &I, MetadataInfoSet &MIS, MDBuilder &MDB,
- uint32_t &FeatureMask);
+ uint64_t &FeatureMask);
// Get start/end section marker pointer.
GlobalVariable *getSectionMarker(const Twine &MarkerName, Type *Ty);
@@ -170,10 +172,16 @@ private:
// Returns the section end marker name.
Twine getSectionEnd(StringRef SectionSuffix);
+ // Returns true if the access to the address should be considered "atomic".
+ bool pretendAtomicAccess(const Value *Addr);
+
Module &Mod;
const SanitizerBinaryMetadataOptions Options;
+ std::unique_ptr<SpecialCaseList> Ignorelist;
const Triple TargetTriple;
IRBuilder<> IRB;
+ BumpPtrAllocator Alloc;
+ UniqueStringSaver StringPool{Alloc};
};
bool SanitizerBinaryMetadata::run() {
@@ -218,17 +226,23 @@ bool SanitizerBinaryMetadata::run() {
(MI->FunctionPrefix + "_del").str(), InitTypes, InitArgs,
/*VersionCheckName=*/StringRef(), /*Weak=*/ClWeakCallbacks)
.first;
- Constant *CtorData = nullptr;
- Constant *DtorData = nullptr;
+ Constant *CtorComdatKey = nullptr;
+ Constant *DtorComdatKey = nullptr;
if (TargetTriple.supportsCOMDAT()) {
- // Use COMDAT to deduplicate constructor/destructor function.
+ // Use COMDAT to deduplicate constructor/destructor function. The COMDAT
+ // key needs to be a non-local linkage.
Ctor->setComdat(Mod.getOrInsertComdat(Ctor->getName()));
Dtor->setComdat(Mod.getOrInsertComdat(Dtor->getName()));
- CtorData = Ctor;
- DtorData = Dtor;
+ Ctor->setLinkage(GlobalValue::ExternalLinkage);
+ Dtor->setLinkage(GlobalValue::ExternalLinkage);
+ // DSOs should _not_ call another constructor/destructor!
+ Ctor->setVisibility(GlobalValue::HiddenVisibility);
+ Dtor->setVisibility(GlobalValue::HiddenVisibility);
+ CtorComdatKey = Ctor;
+ DtorComdatKey = Dtor;
}
- appendToGlobalCtors(Mod, Ctor, kCtorDtorPriority, CtorData);
- appendToGlobalDtors(Mod, Dtor, kCtorDtorPriority, DtorData);
+ appendToGlobalCtors(Mod, Ctor, kCtorDtorPriority, CtorComdatKey);
+ appendToGlobalDtors(Mod, Dtor, kCtorDtorPriority, DtorComdatKey);
}
return true;
@@ -239,6 +253,8 @@ void SanitizerBinaryMetadata::runOn(Function &F, MetadataInfoSet &MIS) {
return;
if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation))
return;
+ if (Ignorelist && Ignorelist->inSection("metadata", "fun", F.getName()))
+ return;
// Don't touch available_externally functions, their actual body is elsewhere.
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage)
return;
@@ -247,18 +263,18 @@ void SanitizerBinaryMetadata::runOn(Function &F, MetadataInfoSet &MIS) {
// The metadata features enabled for this function, stored along covered
// metadata (if enabled).
- uint32_t FeatureMask = getEnabledPerInstructionFeature();
+ uint64_t FeatureMask = 0;
// Don't emit unnecessary covered metadata for all functions to save space.
bool RequiresCovered = false;
- // We can only understand if we need to set UAR feature after looking
- // at the instructions. So we need to check instructions even if FeatureMask
- // is empty.
- if (FeatureMask || Options.UAR) {
+
+ if (Options.Atomics || Options.UAR) {
for (BasicBlock &BB : F)
for (Instruction &I : BB)
RequiresCovered |= runOn(I, MIS, MDB, FeatureMask);
}
+ if (ClNoSanitize && F.hasFnAttribute("no_sanitize_thread"))
+ FeatureMask &= ~kSanitizerBinaryMetadataAtomics;
if (F.isVarArg())
FeatureMask &= ~kSanitizerBinaryMetadataUAR;
if (FeatureMask & kSanitizerBinaryMetadataUAR) {
@@ -274,9 +290,8 @@ void SanitizerBinaryMetadata::runOn(Function &F, MetadataInfoSet &MIS) {
const auto *MI = &MetadataInfo::Covered;
MIS.insert(MI);
const StringRef Section = getSectionName(MI->SectionSuffix);
- // The feature mask will be placed after the size (32 bit) of the function,
- // so in total one covered entry will use `sizeof(void*) + 4 + 4`.
- Constant *CFM = IRB.getInt32(FeatureMask);
+ // The feature mask will be placed after the function size.
+ Constant *CFM = IRB.getInt64(FeatureMask);
F.setMetadata(LLVMContext::MD_pcsections,
MDB.createPCSections({{Section, {CFM}}}));
}
@@ -338,23 +353,80 @@ bool useAfterReturnUnsafe(Instruction &I) {
return false;
}
+bool SanitizerBinaryMetadata::pretendAtomicAccess(const Value *Addr) {
+ if (!Addr)
+ return false;
+
+ Addr = Addr->stripInBoundsOffsets();
+ auto *GV = dyn_cast<GlobalVariable>(Addr);
+ if (!GV)
+ return false;
+
+ // Some compiler-generated accesses are known to be racy; to avoid false
+ // positives in data-race analysis, pretend they're atomic.
+ if (GV->hasSection()) {
+ const auto OF = Triple(Mod.getTargetTriple()).getObjectFormat();
+ const auto ProfSec =
+ getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false);
+ if (GV->getSection().endswith(ProfSec))
+ return true;
+ }
+ if (GV->getName().startswith("__llvm_gcov") ||
+ GV->getName().startswith("__llvm_gcda"))
+ return true;
+
+ return false;
+}
+
+// Returns true if the memory at `Addr` may be shared with other threads.
+bool maybeSharedMutable(const Value *Addr) {
+ // By default assume memory may be shared.
+ if (!Addr)
+ return true;
+
+ if (isa<AllocaInst>(getUnderlyingObject(Addr)) &&
+ !PointerMayBeCaptured(Addr, true, true))
+ return false; // Object is on stack but does not escape.
+
+ Addr = Addr->stripInBoundsOffsets();
+ if (auto *GV = dyn_cast<GlobalVariable>(Addr)) {
+ if (GV->isConstant())
+ return false; // Shared, but not mutable.
+ }
+
+ return true;
+}
+
bool SanitizerBinaryMetadata::runOn(Instruction &I, MetadataInfoSet &MIS,
- MDBuilder &MDB, uint32_t &FeatureMask) {
+ MDBuilder &MDB, uint64_t &FeatureMask) {
SmallVector<const MetadataInfo *, 1> InstMetadata;
bool RequiresCovered = false;
+ // Only call if at least 1 type of metadata is requested.
+ assert(Options.UAR || Options.Atomics);
+
if (Options.UAR && !(FeatureMask & kSanitizerBinaryMetadataUAR)) {
if (useAfterReturnUnsafe(I))
FeatureMask |= kSanitizerBinaryMetadataUAR;
}
- if (Options.Atomics && I.mayReadOrWriteMemory()) {
- auto SSID = getAtomicSyncScopeID(&I);
- if (SSID.has_value() && *SSID != SyncScope::SingleThread) {
- NumMetadataAtomics++;
- InstMetadata.push_back(&MetadataInfo::Atomics);
+ if (Options.Atomics) {
+ const Value *Addr = nullptr;
+ if (auto *SI = dyn_cast<StoreInst>(&I))
+ Addr = SI->getPointerOperand();
+ else if (auto *LI = dyn_cast<LoadInst>(&I))
+ Addr = LI->getPointerOperand();
+
+ if (I.mayReadOrWriteMemory() && maybeSharedMutable(Addr)) {
+ auto SSID = getAtomicSyncScopeID(&I);
+ if ((SSID.has_value() && *SSID != SyncScope::SingleThread) ||
+ pretendAtomicAccess(Addr)) {
+ NumMetadataAtomics++;
+ InstMetadata.push_back(&MetadataInfo::Atomics);
+ }
+ FeatureMask |= kSanitizerBinaryMetadataAtomics;
+ RequiresCovered = true;
}
- RequiresCovered = true;
}
// Attach MD_pcsections to instruction.
@@ -381,8 +453,9 @@ SanitizerBinaryMetadata::getSectionMarker(const Twine &MarkerName, Type *Ty) {
}
StringRef SanitizerBinaryMetadata::getSectionName(StringRef SectionSuffix) {
- // FIXME: Other TargetTriple (req. string pool)
- return SectionSuffix;
+ // FIXME: Other TargetTriples.
+ // Request ULEB128 encoding for all integer constants.
+ return StringPool.save(SectionSuffix + "!C");
}
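Per the comment above, the "!C" section suffix requests ULEB128 encoding for the integer constants placed in the PC section (here the 64-bit feature mask), which keeps small masks to a byte or two. For reference, a self-contained sketch of the encoding itself (not the LLVM Support API):

#include <cstdint>
#include <vector>

// ULEB128: emit 7 bits per byte, least-significant group first, setting the
// high bit of every byte except the last.
std::vector<uint8_t> encodeULEB128Bytes(uint64_t Value) {
  std::vector<uint8_t> Bytes;
  do {
    uint8_t Byte = Value & 0x7f;
    Value >>= 7;
    if (Value != 0)
      Byte |= 0x80;
    Bytes.push_back(Byte);
  } while (Value != 0);
  return Bytes;
}
// e.g. encodeULEB128Bytes(624485) yields {0xE5, 0x8E, 0x26}.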
Twine SanitizerBinaryMetadata::getSectionStart(StringRef SectionSuffix) {
@@ -396,12 +469,20 @@ Twine SanitizerBinaryMetadata::getSectionEnd(StringRef SectionSuffix) {
} // namespace
SanitizerBinaryMetadataPass::SanitizerBinaryMetadataPass(
- SanitizerBinaryMetadataOptions Opts)
- : Options(std::move(Opts)) {}
+ SanitizerBinaryMetadataOptions Opts, ArrayRef<std::string> IgnorelistFiles)
+ : Options(std::move(Opts)), IgnorelistFiles(std::move(IgnorelistFiles)) {}
PreservedAnalyses
SanitizerBinaryMetadataPass::run(Module &M, AnalysisManager<Module> &AM) {
- SanitizerBinaryMetadata Pass(M, Options);
+ std::unique_ptr<SpecialCaseList> Ignorelist;
+ if (!IgnorelistFiles.empty()) {
+ Ignorelist = SpecialCaseList::createOrDie(IgnorelistFiles,
+ *vfs::getRealFileSystem());
+ if (Ignorelist->inSection("metadata", "src", M.getSourceFileName()))
+ return PreservedAnalyses::all();
+ }
+
+ SanitizerBinaryMetadata Pass(M, Options, std::move(Ignorelist));
if (Pass.run())
return PreservedAnalyses::none();
return PreservedAnalyses::all();
diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
index 23a88c3cfba2..f22918141f6e 100644
--- a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
+++ b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
@@ -13,13 +13,12 @@
#include "llvm/Transforms/Instrumentation/SanitizerCoverage.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/Triple.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
@@ -28,11 +27,10 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SpecialCaseList.h"
#include "llvm/Support/VirtualFileSystem.h"
-#include "llvm/Transforms/Instrumentation.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -250,10 +248,6 @@ private:
std::pair<Value *, Value *> CreateSecStartEnd(Module &M, const char *Section,
Type *Ty);
- void SetNoSanitizeMetadata(Instruction *I) {
- I->setMetadata(LLVMContext::MD_nosanitize, MDNode::get(*C, std::nullopt));
- }
-
std::string getSectionName(const std::string &Section) const;
std::string getSectionStart(const std::string &Section) const;
std::string getSectionEnd(const std::string &Section) const;
@@ -809,7 +803,7 @@ void ModuleSanitizerCoverage::InjectCoverageForIndirectCalls(
assert(Options.TracePC || Options.TracePCGuard ||
Options.Inline8bitCounters || Options.InlineBoolFlag);
for (auto *I : IndirCalls) {
- IRBuilder<> IRB(I);
+ InstrumentationIRBuilder IRB(I);
CallBase &CB = cast<CallBase>(*I);
Value *Callee = CB.getCalledOperand();
if (isa<InlineAsm>(Callee))
@@ -826,7 +820,7 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch(
Function &, ArrayRef<Instruction *> SwitchTraceTargets) {
for (auto *I : SwitchTraceTargets) {
if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) {
- IRBuilder<> IRB(I);
+ InstrumentationIRBuilder IRB(I);
SmallVector<Constant *, 16> Initializers;
Value *Cond = SI->getCondition();
if (Cond->getType()->getScalarSizeInBits() >
@@ -864,7 +858,7 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch(
void ModuleSanitizerCoverage::InjectTraceForDiv(
Function &, ArrayRef<BinaryOperator *> DivTraceTargets) {
for (auto *BO : DivTraceTargets) {
- IRBuilder<> IRB(BO);
+ InstrumentationIRBuilder IRB(BO);
Value *A1 = BO->getOperand(1);
if (isa<ConstantInt>(A1)) continue;
if (!A1->getType()->isIntegerTy())
@@ -882,7 +876,7 @@ void ModuleSanitizerCoverage::InjectTraceForDiv(
void ModuleSanitizerCoverage::InjectTraceForGep(
Function &, ArrayRef<GetElementPtrInst *> GepTraceTargets) {
for (auto *GEP : GepTraceTargets) {
- IRBuilder<> IRB(GEP);
+ InstrumentationIRBuilder IRB(GEP);
for (Use &Idx : GEP->indices())
if (!isa<ConstantInt>(Idx) && Idx->getType()->isIntegerTy())
IRB.CreateCall(SanCovTraceGepFunction,
@@ -904,7 +898,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores(
Type *PointerType[5] = {Int8PtrTy, Int16PtrTy, Int32PtrTy, Int64PtrTy,
Int128PtrTy};
for (auto *LI : Loads) {
- IRBuilder<> IRB(LI);
+ InstrumentationIRBuilder IRB(LI);
auto Ptr = LI->getPointerOperand();
int Idx = CallbackIdx(LI->getType());
if (Idx < 0)
@@ -913,7 +907,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores(
IRB.CreatePointerCast(Ptr, PointerType[Idx]));
}
for (auto *SI : Stores) {
- IRBuilder<> IRB(SI);
+ InstrumentationIRBuilder IRB(SI);
auto Ptr = SI->getPointerOperand();
int Idx = CallbackIdx(SI->getValueOperand()->getType());
if (Idx < 0)
@@ -927,7 +921,7 @@ void ModuleSanitizerCoverage::InjectTraceForCmp(
Function &, ArrayRef<Instruction *> CmpTraceTargets) {
for (auto *I : CmpTraceTargets) {
if (ICmpInst *ICMP = dyn_cast<ICmpInst>(I)) {
- IRBuilder<> IRB(ICMP);
+ InstrumentationIRBuilder IRB(ICMP);
Value *A0 = ICMP->getOperand(0);
Value *A1 = ICMP->getOperand(1);
if (!A0->getType()->isIntegerTy())
@@ -994,8 +988,8 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
auto Load = IRB.CreateLoad(Int8Ty, CounterPtr);
auto Inc = IRB.CreateAdd(Load, ConstantInt::get(Int8Ty, 1));
auto Store = IRB.CreateStore(Inc, CounterPtr);
- SetNoSanitizeMetadata(Load);
- SetNoSanitizeMetadata(Store);
+ Load->setNoSanitizeMetadata();
+ Store->setNoSanitizeMetadata();
}
if (Options.InlineBoolFlag) {
auto FlagPtr = IRB.CreateGEP(
@@ -1006,8 +1000,8 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
SplitBlockAndInsertIfThen(IRB.CreateIsNull(Load), &*IP, false);
IRBuilder<> ThenIRB(ThenTerm);
auto Store = ThenIRB.CreateStore(ConstantInt::getTrue(Int1Ty), FlagPtr);
- SetNoSanitizeMetadata(Load);
- SetNoSanitizeMetadata(Store);
+ Load->setNoSanitizeMetadata();
+ Store->setNoSanitizeMetadata();
}
if (Options.StackDepth && IsEntryBB && !IsLeafFunc) {
// Check stack depth. If it's the deepest so far, record it.
@@ -1023,8 +1017,8 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
auto ThenTerm = SplitBlockAndInsertIfThen(IsStackLower, &*IP, false);
IRBuilder<> ThenIRB(ThenTerm);
auto Store = ThenIRB.CreateStore(FrameAddrInt, SanCovLowestStack);
- SetNoSanitizeMetadata(LowestStack);
- SetNoSanitizeMetadata(Store);
+ LowestStack->setNoSanitizeMetadata();
+ Store->setNoSanitizeMetadata();
}
}
diff --git a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
index a127e81ce643..ce35eefb63fa 100644
--- a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
@@ -689,7 +689,7 @@ static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) {
// replaced back with intrinsics. If that becomes wrong at some point,
// we will need to call e.g. __tsan_memset to avoid the intrinsics.
bool ThreadSanitizer::instrumentMemIntrinsic(Instruction *I) {
- IRBuilder<> IRB(I);
+ InstrumentationIRBuilder IRB(I);
if (MemSetInst *M = dyn_cast<MemSetInst>(I)) {
IRB.CreateCall(
MemsetFn,
@@ -813,8 +813,6 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) {
int ThreadSanitizer::getMemoryAccessFuncIndex(Type *OrigTy, Value *Addr,
const DataLayout &DL) {
assert(OrigTy->isSized());
- assert(
- cast<PointerType>(Addr->getType())->isOpaqueOrPointeeTypeMatches(OrigTy));
uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
if (TypeSize != 8 && TypeSize != 16 &&
TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
@@ -822,7 +820,7 @@ int ThreadSanitizer::getMemoryAccessFuncIndex(Type *OrigTy, Value *Addr,
// Ignore all unusual sizes.
return -1;
}
- size_t Idx = countTrailingZeros(TypeSize / 8);
+ size_t Idx = llvm::countr_zero(TypeSize / 8);
assert(Idx < kNumberOfAccessSizes);
return Idx;
}
diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARC.h b/llvm/lib/Transforms/ObjCARC/ObjCARC.h
index d4570ff908f1..9e68bd574851 100644
--- a/llvm/lib/Transforms/ObjCARC/ObjCARC.h
+++ b/llvm/lib/Transforms/ObjCARC/ObjCARC.h
@@ -22,9 +22,9 @@
#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H
#define LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/ObjCARCAnalysisUtils.h"
#include "llvm/Analysis/ObjCARCUtil.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/Transforms/Utils/Local.h"
namespace llvm {
diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp
index ab90ef090ae0..c397ab63f388 100644
--- a/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp
+++ b/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp
@@ -31,9 +31,9 @@
#include "ProvenanceAnalysis.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/ObjCARCUtil.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Operator.h"
diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
index a374958f9707..adf86526ebf1 100644
--- a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
+++ b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
@@ -36,7 +36,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/ObjCARCAliasAnalysis.h"
#include "llvm/Analysis/ObjCARCAnalysisUtils.h"
#include "llvm/Analysis/ObjCARCInstKind.h"
@@ -46,6 +45,7 @@
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
@@ -933,8 +933,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst,
if (IsNullOrUndef(CI->getArgOperand(0))) {
Changed = true;
new StoreInst(ConstantInt::getTrue(CI->getContext()),
- UndefValue::get(Type::getInt1PtrTy(CI->getContext())), CI);
- Value *NewValue = UndefValue::get(CI->getType());
+ PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI);
+ Value *NewValue = PoisonValue::get(CI->getType());
LLVM_DEBUG(
dbgs() << "A null pointer-to-weak-pointer is undefined behavior."
"\nOld = "
@@ -952,9 +952,9 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst,
IsNullOrUndef(CI->getArgOperand(1))) {
Changed = true;
new StoreInst(ConstantInt::getTrue(CI->getContext()),
- UndefValue::get(Type::getInt1PtrTy(CI->getContext())), CI);
+ PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI);
- Value *NewValue = UndefValue::get(CI->getType());
+ Value *NewValue = PoisonValue::get(CI->getType());
LLVM_DEBUG(
dbgs() << "A null pointer-to-weak-pointer is undefined behavior."
"\nOld = "
diff --git a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp
index 2fa25a79ae9d..23855231c5b9 100644
--- a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp
+++ b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp
@@ -42,40 +42,21 @@ bool ProvenanceAnalysis::relatedSelect(const SelectInst *A,
const Value *B) {
// If the values are Selects with the same condition, we can do a more precise
// check: just check for relations between the values on corresponding arms.
- if (const SelectInst *SB = dyn_cast<SelectInst>(B)) {
+ if (const SelectInst *SB = dyn_cast<SelectInst>(B))
if (A->getCondition() == SB->getCondition())
return related(A->getTrueValue(), SB->getTrueValue()) ||
related(A->getFalseValue(), SB->getFalseValue());
- // Check both arms of B individually. Return false if neither arm is related
- // to A.
- if (!(related(SB->getTrueValue(), A) || related(SB->getFalseValue(), A)))
- return false;
- }
-
// Check both arms of the Select node individually.
return related(A->getTrueValue(), B) || related(A->getFalseValue(), B);
}
bool ProvenanceAnalysis::relatedPHI(const PHINode *A,
const Value *B) {
-
- auto comparePHISources = [this](const PHINode *PNA, const Value *B) -> bool {
- // Check each unique source of the PHI node against B.
- SmallPtrSet<const Value *, 4> UniqueSrc;
- for (Value *PV1 : PNA->incoming_values()) {
- if (UniqueSrc.insert(PV1).second && related(PV1, B))
- return true;
- }
-
- // All of the arms checked out.
- return false;
- };
-
- if (const PHINode *PNB = dyn_cast<PHINode>(B)) {
- // If the values are PHIs in the same block, we can do a more precise as
- // well as efficient check: just check for relations between the values on
- // corresponding edges.
+ // If the values are PHIs in the same block, we can do a more precise as well
+ // as efficient check: just check for relations between the values on
+ // corresponding edges.
+ if (const PHINode *PNB = dyn_cast<PHINode>(B))
if (PNB->getParent() == A->getParent()) {
for (unsigned i = 0, e = A->getNumIncomingValues(); i != e; ++i)
if (related(A->getIncomingValue(i),
@@ -84,11 +65,15 @@ bool ProvenanceAnalysis::relatedPHI(const PHINode *A,
return false;
}
- if (!comparePHISources(PNB, A))
- return false;
+ // Check each unique source of the PHI node against B.
+ SmallPtrSet<const Value *, 4> UniqueSrc;
+ for (Value *PV1 : A->incoming_values()) {
+ if (UniqueSrc.insert(PV1).second && related(PV1, B))
+ return true;
}
- return comparePHISources(A, B);
+ // All of the arms checked out.
+ return false;
}
/// Test if the value of P, or any value covered by its provenance, is ever
@@ -140,19 +125,22 @@ bool ProvenanceAnalysis::relatedCheck(const Value *A, const Value *B) {
bool BIsIdentified = IsObjCIdentifiedObject(B);
// An ObjC-Identified object can't alias a load if it is never locally stored.
-
- // Check for an obvious escape.
- if ((AIsIdentified && isa<LoadInst>(B) && !IsStoredObjCPointer(A)) ||
- (BIsIdentified && isa<LoadInst>(A) && !IsStoredObjCPointer(B)))
- return false;
-
- if ((AIsIdentified && isa<LoadInst>(B)) ||
- (BIsIdentified && isa<LoadInst>(A)))
- return true;
-
- // Both pointers are identified and escapes aren't an evident problem.
- if (AIsIdentified && BIsIdentified && !isa<LoadInst>(A) && !isa<LoadInst>(B))
- return false;
+ if (AIsIdentified) {
+ // Check for an obvious escape.
+ if (isa<LoadInst>(B))
+ return IsStoredObjCPointer(A);
+ if (BIsIdentified) {
+ // Check for an obvious escape.
+ if (isa<LoadInst>(A))
+ return IsStoredObjCPointer(B);
+ // Both pointers are identified and escapes aren't an evident problem.
+ return false;
+ }
+ } else if (BIsIdentified) {
+ // Check for an obvious escape.
+ if (isa<LoadInst>(A))
+ return IsStoredObjCPointer(B);
+ }
// Special handling for PHI and Select.
if (const PHINode *PN = dyn_cast<PHINode>(A))
@@ -179,15 +167,12 @@ bool ProvenanceAnalysis::related(const Value *A, const Value *B) {
// Begin by inserting a conservative value into the map. If the insertion
// fails, we have the answer already. If it succeeds, leave it there until we
// compute the real answer to guard against recursive queries.
- if (A > B) std::swap(A, B);
std::pair<CachedResultsTy::iterator, bool> Pair =
CachedResults.insert(std::make_pair(ValuePairTy(A, B), true));
if (!Pair.second)
return Pair.first->second;
bool Result = relatedCheck(A, B);
- assert(relatedCheck(B, A) == Result &&
- "relatedCheck result depending on order of parameters!");
CachedResults[ValuePairTy(A, B)] = Result;
return Result;
}
diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp
index 253293582945..24354211341f 100644
--- a/llvm/lib/Transforms/Scalar/ADCE.cpp
+++ b/llvm/lib/Transforms/Scalar/ADCE.cpp
@@ -26,6 +26,7 @@
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/IteratedDominanceFrontier.h"
+#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
@@ -42,14 +43,11 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstddef>
@@ -113,6 +111,12 @@ struct BlockInfoType {
bool terminatorIsLive() const { return TerminatorLiveInfo->Live; }
};
+struct ADCEChanged {
+ bool ChangedAnything = false;
+ bool ChangedNonDebugInstr = false;
+ bool ChangedControlFlow = false;
+};
+
class AggressiveDeadCodeElimination {
Function &F;
@@ -179,7 +183,7 @@ class AggressiveDeadCodeElimination {
/// Remove instructions not marked live, return if any instruction was
/// removed.
- bool removeDeadInstructions();
+ ADCEChanged removeDeadInstructions();
/// Identify connected sections of the control flow graph which have
/// dead terminators and rewrite the control flow graph to remove them.
@@ -197,12 +201,12 @@ public:
PostDominatorTree &PDT)
: F(F), DT(DT), PDT(PDT) {}
- bool performDeadCodeElimination();
+ ADCEChanged performDeadCodeElimination();
};
} // end anonymous namespace
-bool AggressiveDeadCodeElimination::performDeadCodeElimination() {
+ADCEChanged AggressiveDeadCodeElimination::performDeadCodeElimination() {
initialize();
markLiveInstructions();
return removeDeadInstructions();
@@ -504,9 +508,10 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() {
// Routines to update the CFG and SSA information before removing dead code.
//
//===----------------------------------------------------------------------===//
-bool AggressiveDeadCodeElimination::removeDeadInstructions() {
+ADCEChanged AggressiveDeadCodeElimination::removeDeadInstructions() {
+ ADCEChanged Changed;
// Updates control and dataflow around dead blocks
- bool RegionsUpdated = updateDeadRegions();
+ Changed.ChangedControlFlow = updateDeadRegions();
LLVM_DEBUG({
for (Instruction &I : instructions(F)) {
@@ -554,6 +559,8 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() {
continue;
// Fallthrough and drop the intrinsic.
+ } else {
+ Changed.ChangedNonDebugInstr = true;
}
// Prepare to delete.
@@ -569,7 +576,9 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() {
I->eraseFromParent();
}
- return !Worklist.empty() || RegionsUpdated;
+ Changed.ChangedAnything = Changed.ChangedControlFlow || !Worklist.empty();
+
+ return Changed;
}
// A dead region is the set of dead blocks with a common live post-dominator.
@@ -699,62 +708,25 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) {
// to update analysis if it is already available.
auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);
auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
- if (!AggressiveDeadCodeElimination(F, DT, PDT).performDeadCodeElimination())
+ ADCEChanged Changed =
+ AggressiveDeadCodeElimination(F, DT, PDT).performDeadCodeElimination();
+ if (!Changed.ChangedAnything)
return PreservedAnalyses::all();
PreservedAnalyses PA;
- // TODO: We could track if we have actually done CFG changes.
- if (!RemoveControlFlowFlag)
+ if (!Changed.ChangedControlFlow) {
PA.preserveSet<CFGAnalyses>();
- else {
- PA.preserve<DominatorTreeAnalysis>();
- PA.preserve<PostDominatorTreeAnalysis>();
- }
- return PA;
-}
-
-namespace {
-
-struct ADCELegacyPass : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
-
- ADCELegacyPass() : FunctionPass(ID) {
- initializeADCELegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- // ADCE does not need DominatorTree, but require DominatorTree here
- // to update analysis if it is already available.
- auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
- auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
- return AggressiveDeadCodeElimination(F, DT, PDT)
- .performDeadCodeElimination();
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<PostDominatorTreeWrapperPass>();
- if (!RemoveControlFlowFlag)
- AU.setPreservesCFG();
- else {
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<PostDominatorTreeWrapperPass>();
+ if (!Changed.ChangedNonDebugInstr) {
+ // Only removing debug instructions does not affect MemorySSA.
+ //
+ // Therefore we preserve MemorySSA when only removing debug instructions
+ // since otherwise later passes may behave differently which then makes
+ // the presence of debug info affect code generation.
+ PA.preserve<MemorySSAAnalysis>();
}
- AU.addPreserved<GlobalsAAWrapperPass>();
}
-};
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<PostDominatorTreeAnalysis>();
-} // end anonymous namespace
-
-char ADCELegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(ADCELegacyPass, "adce",
- "Aggressive Dead Code Elimination", false, false)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_END(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination",
- false, false)
-
-FunctionPass *llvm::createAggressiveDCEPass() { return new ADCELegacyPass(); }
+ return PA;
+}
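The ADCE hunks above replace the pass's single boolean result with an ADCEChanged record so that run() can scale the preserved-analyses set to what actually changed: everything is preserved when nothing changed, CFG analyses are kept when only straight-line code was touched, and MemorySSA is additionally kept when only debug instructions were removed. Below is a rough sketch of that reporting pattern; ExamplePass, ChangeSummary and the placeholder transform() are invented for illustration and only the analysis IDs come from the patch.

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"

using namespace llvm;

namespace {
// Change summary mirroring the ADCEChanged struct introduced above.
struct ChangeSummary {
  bool ChangedAnything = false;
  bool ChangedNonDebugInstr = false;
  bool ChangedControlFlow = false;
};

struct ExamplePass : PassInfoMixin<ExamplePass> {
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM) {
    ChangeSummary Changed = transform(F);
    if (!Changed.ChangedAnything)
      return PreservedAnalyses::all();

    PreservedAnalyses PA;
    if (!Changed.ChangedControlFlow) {
      // No blocks or edges were touched, so all CFG-based analyses survive.
      PA.preserveSet<CFGAnalyses>();
      // Removing only debug instructions cannot invalidate MemorySSA.
      if (!Changed.ChangedNonDebugInstr)
        PA.preserve<MemorySSAAnalysis>();
    }
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<PostDominatorTreeAnalysis>();
    return PA;
  }

  // Placeholder for the actual rewrite; a real pass records what it changed.
  ChangeSummary transform(Function &F) { return ChangeSummary(); }
};
} // namespace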
diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
index f419f7bd769f..b259c76fc3a5 100644
--- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -28,13 +28,10 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
-#define AA_NAME "alignment-from-assumptions"
-#define DEBUG_TYPE AA_NAME
+#define DEBUG_TYPE "alignment-from-assumptions"
using namespace llvm;
STATISTIC(NumLoadAlignChanged,
@@ -44,46 +41,6 @@ STATISTIC(NumStoreAlignChanged,
STATISTIC(NumMemIntAlignChanged,
"Number of memory intrinsics changed by alignment assumptions");
-namespace {
-struct AlignmentFromAssumptions : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- AlignmentFromAssumptions() : FunctionPass(ID) {
- initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
-
- AU.setPreservesCFG();
- AU.addPreserved<AAResultsWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<ScalarEvolutionWrapperPass>();
- }
-
- AlignmentFromAssumptionsPass Impl;
-};
-}
-
-char AlignmentFromAssumptions::ID = 0;
-static const char aip_name[] = "Alignment from assumptions";
-INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
- aip_name, false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
- aip_name, false, false)
-
-FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
- return new AlignmentFromAssumptions();
-}
-
// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
@@ -317,17 +274,6 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
return true;
}
-bool AlignmentFromAssumptions::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
-
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-
- return Impl.runImpl(F, AC, SE, DT);
-}
-
bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
ScalarEvolution *SE_,
DominatorTree *DT_) {
diff --git a/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp b/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp
index 79f7e253d45b..b182f46cc515 100644
--- a/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp
+++ b/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp
@@ -16,7 +16,6 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/MemoryOpRemark.h"
using namespace llvm;
@@ -58,7 +57,12 @@ static void runImpl(Function &F, const TargetLibraryInfo &TLI) {
for (const MDOperand &Op :
I.getMetadata(LLVMContext::MD_annotation)->operands()) {
- auto Iter = Mapping.insert({cast<MDString>(Op.get())->getString(), 0});
+ StringRef AnnotationStr =
+ isa<MDString>(Op.get())
+ ? cast<MDString>(Op.get())->getString()
+ : cast<MDString>(cast<MDTuple>(Op.get())->getOperand(0).get())
+ ->getString();
+ auto Iter = Mapping.insert({AnnotationStr, 0});
Iter.first->second++;
}
}
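The AnnotationRemarks hunk above now accepts MD_annotation operands in two shapes: a bare MDString, or an MDTuple whose first element is the annotation string (the remaining tuple elements carry extra context). The following is a small sketch of that extraction as a standalone helper; the helper name is invented, but the casts are the same ones used in the patch.

#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"

using namespace llvm;

// Return the annotation string for one MD_annotation operand, which may be
// either a bare MDString or a tuple whose first element is the string.
static StringRef getAnnotationString(const MDOperand &Op) {
  if (auto *Str = dyn_cast<MDString>(Op.get()))
    return Str->getString();
  // Otherwise the operand is expected to be a tuple like {"ann", ...}.
  return cast<MDString>(cast<MDTuple>(Op.get())->getOperand(0).get())
      ->getString();
}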
diff --git a/llvm/lib/Transforms/Scalar/BDCE.cpp b/llvm/lib/Transforms/Scalar/BDCE.cpp
index 187927b3dede..1fa2c75b0f42 100644
--- a/llvm/lib/Transforms/Scalar/BDCE.cpp
+++ b/llvm/lib/Transforms/Scalar/BDCE.cpp
@@ -23,11 +23,8 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
@@ -116,7 +113,7 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) {
const uint32_t SrcBitSize = SE->getSrcTy()->getScalarSizeInBits();
auto *const DstTy = SE->getDestTy();
const uint32_t DestBitSize = DstTy->getScalarSizeInBits();
- if (Demanded.countLeadingZeros() >= (DestBitSize - SrcBitSize)) {
+ if (Demanded.countl_zero() >= (DestBitSize - SrcBitSize)) {
clearAssumptionsOfUsers(SE, DB);
IRBuilder<> Builder(SE);
I.replaceAllUsesWith(
@@ -173,34 +170,3 @@ PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserveSet<CFGAnalyses>();
return PA;
}
-
-namespace {
-struct BDCELegacyPass : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- BDCELegacyPass() : FunctionPass(ID) {
- initializeBDCELegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- auto &DB = getAnalysis<DemandedBitsWrapperPass>().getDemandedBits();
- return bitTrackingDCE(F, DB);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<DemandedBitsWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-};
-}
-
-char BDCELegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(BDCELegacyPass, "bdce",
- "Bit-Tracking Dead Code Elimination", false, false)
-INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)
-INITIALIZE_PASS_END(BDCELegacyPass, "bdce",
- "Bit-Tracking Dead Code Elimination", false, false)
-
-FunctionPass *llvm::createBitTrackingDCEPass() { return new BDCELegacyPass(); }
diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
index 6665a927826d..aeb7c5d461f0 100644
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -535,45 +535,6 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI,
return Changed;
}
-namespace {
-struct CallSiteSplittingLegacyPass : public FunctionPass {
- static char ID;
- CallSiteSplittingLegacyPass() : FunctionPass(ID) {
- initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- FunctionPass::getAnalysisUsage(AU);
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- return doCallSiteSplitting(F, TLI, TTI, DT);
- }
-};
-} // namespace
-
-char CallSiteSplittingLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting",
- "Call-site splitting", false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting",
- "Call-site splitting", false, false)
-FunctionPass *llvm::createCallSiteSplittingPass() {
- return new CallSiteSplittingLegacyPass();
-}
-
PreservedAnalyses CallSiteSplittingPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
index 8858545bbc5d..611e64bd0976 100644
--- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -155,16 +155,19 @@ bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) {
Fn.getEntryBlock(),
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI());
- if (MadeChange) {
- LLVM_DEBUG(dbgs() << "********** Function after Constant Hoisting: "
- << Fn.getName() << '\n');
- LLVM_DEBUG(dbgs() << Fn);
- }
LLVM_DEBUG(dbgs() << "********** End Constant Hoisting **********\n");
return MadeChange;
}
+void ConstantHoistingPass::collectMatInsertPts(
+ const RebasedConstantListType &RebasedConstants,
+ SmallVectorImpl<Instruction *> &MatInsertPts) const {
+ for (const RebasedConstantInfo &RCI : RebasedConstants)
+ for (const ConstantUser &U : RCI.Uses)
+ MatInsertPts.emplace_back(findMatInsertPt(U.Inst, U.OpndIdx));
+}
+
/// Find the constant materialization insertion point.
Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst,
unsigned Idx) const {
@@ -312,14 +315,15 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
/// Find an insertion point that dominates all uses.
SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint(
- const ConstantInfo &ConstInfo) const {
+ const ConstantInfo &ConstInfo,
+ const ArrayRef<Instruction *> MatInsertPts) const {
assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry.");
// Collect all basic blocks.
SetVector<BasicBlock *> BBs;
SetVector<Instruction *> InsertPts;
- for (auto const &RCI : ConstInfo.RebasedConstants)
- for (auto const &U : RCI.Uses)
- BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent());
+
+ for (Instruction *MatInsertPt : MatInsertPts)
+ BBs.insert(MatInsertPt->getParent());
if (BBs.count(Entry)) {
InsertPts.insert(&Entry->front());
@@ -328,12 +332,8 @@ SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint(
if (BFI) {
findBestInsertionSet(*DT, *BFI, Entry, BBs);
- for (auto *BB : BBs) {
- BasicBlock::iterator InsertPt = BB->begin();
- for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt)
- ;
- InsertPts.insert(&*InsertPt);
- }
+ for (BasicBlock *BB : BBs)
+ InsertPts.insert(&*BB->getFirstInsertionPt());
return InsertPts;
}
@@ -410,8 +410,8 @@ void ConstantHoistingPass::collectConstantCandidates(
// Get offset from the base GV.
PointerType *GVPtrTy = cast<PointerType>(BaseGV->getType());
- IntegerType *PtrIntTy = DL->getIntPtrType(*Ctx, GVPtrTy->getAddressSpace());
- APInt Offset(DL->getTypeSizeInBits(PtrIntTy), /*val*/0, /*isSigned*/true);
+ IntegerType *OffsetTy = DL->getIndexType(*Ctx, GVPtrTy->getAddressSpace());
+ APInt Offset(DL->getTypeSizeInBits(OffsetTy), /*val*/ 0, /*isSigned*/ true);
auto *GEPO = cast<GEPOperator>(ConstExpr);
// TODO: If we have a mix of inbounds and non-inbounds GEPs, then basing a
@@ -432,7 +432,7 @@ void ConstantHoistingPass::collectConstantCandidates(
// to be cheaper than compute it by <Base + Offset>, which can be lowered to
// an ADD instruction or folded into Load/Store instruction.
InstructionCost Cost =
- TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy,
+ TTI->getIntImmCostInst(Instruction::Add, 1, Offset, OffsetTy,
TargetTransformInfo::TCK_SizeAndLatency, Inst);
ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV];
ConstCandMapType::iterator Itr;
@@ -751,45 +751,41 @@ static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) {
/// Emit materialization code for all rebased constants and update their
/// users.
void ConstantHoistingPass::emitBaseConstants(Instruction *Base,
- Constant *Offset,
- Type *Ty,
- const ConstantUser &ConstUser) {
+ UserAdjustment *Adj) {
Instruction *Mat = Base;
// The same offset can be dereferenced to different types in nested struct.
- if (!Offset && Ty && Ty != Base->getType())
- Offset = ConstantInt::get(Type::getInt32Ty(*Ctx), 0);
+ if (!Adj->Offset && Adj->Ty && Adj->Ty != Base->getType())
+ Adj->Offset = ConstantInt::get(Type::getInt32Ty(*Ctx), 0);
- if (Offset) {
- Instruction *InsertionPt = findMatInsertPt(ConstUser.Inst,
- ConstUser.OpndIdx);
- if (Ty) {
+ if (Adj->Offset) {
+ if (Adj->Ty) {
// Constant being rebased is a ConstantExpr.
- PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx,
- cast<PointerType>(Ty)->getAddressSpace());
- Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", InsertionPt);
- Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base,
- Offset, "mat_gep", InsertionPt);
- Mat = new BitCastInst(Mat, Ty, "mat_bitcast", InsertionPt);
+ PointerType *Int8PtrTy = Type::getInt8PtrTy(
+ *Ctx, cast<PointerType>(Adj->Ty)->getAddressSpace());
+ Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", Adj->MatInsertPt);
+ Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Adj->Offset,
+ "mat_gep", Adj->MatInsertPt);
+ Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", Adj->MatInsertPt);
} else
// Constant being rebased is a ConstantInt.
- Mat = BinaryOperator::Create(Instruction::Add, Base, Offset,
- "const_mat", InsertionPt);
+ Mat = BinaryOperator::Create(Instruction::Add, Base, Adj->Offset,
+ "const_mat", Adj->MatInsertPt);
LLVM_DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0)
- << " + " << *Offset << ") in BB "
+ << " + " << *Adj->Offset << ") in BB "
<< Mat->getParent()->getName() << '\n'
<< *Mat << '\n');
- Mat->setDebugLoc(ConstUser.Inst->getDebugLoc());
+ Mat->setDebugLoc(Adj->User.Inst->getDebugLoc());
}
- Value *Opnd = ConstUser.Inst->getOperand(ConstUser.OpndIdx);
+ Value *Opnd = Adj->User.Inst->getOperand(Adj->User.OpndIdx);
// Visit constant integer.
if (isa<ConstantInt>(Opnd)) {
- LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
- if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat) && Offset)
+ LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n');
+ if (!updateOperand(Adj->User.Inst, Adj->User.OpndIdx, Mat) && Adj->Offset)
Mat->eraseFromParent();
- LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n');
+ LLVM_DEBUG(dbgs() << "To : " << *Adj->User.Inst << '\n');
return;
}
@@ -809,9 +805,9 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base,
<< "To : " << *ClonedCastInst << '\n');
}
- LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
- updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ClonedCastInst);
- LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n');
+ LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n');
+ updateOperand(Adj->User.Inst, Adj->User.OpndIdx, ClonedCastInst);
+ LLVM_DEBUG(dbgs() << "To : " << *Adj->User.Inst << '\n');
return;
}
@@ -819,28 +815,27 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base,
if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) {
if (isa<GEPOperator>(ConstExpr)) {
// Operand is a ConstantGEP, replace it.
- updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat);
+ updateOperand(Adj->User.Inst, Adj->User.OpndIdx, Mat);
return;
}
// Aside from constant GEPs, only constant cast expressions are collected.
assert(ConstExpr->isCast() && "ConstExpr should be a cast");
- Instruction *ConstExprInst = ConstExpr->getAsInstruction(
- findMatInsertPt(ConstUser.Inst, ConstUser.OpndIdx));
+ Instruction *ConstExprInst = ConstExpr->getAsInstruction(Adj->MatInsertPt);
ConstExprInst->setOperand(0, Mat);
// Use the same debug location as the instruction we are about to update.
- ConstExprInst->setDebugLoc(ConstUser.Inst->getDebugLoc());
+ ConstExprInst->setDebugLoc(Adj->User.Inst->getDebugLoc());
LLVM_DEBUG(dbgs() << "Create instruction: " << *ConstExprInst << '\n'
<< "From : " << *ConstExpr << '\n');
- LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
- if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ConstExprInst)) {
+ LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n');
+ if (!updateOperand(Adj->User.Inst, Adj->User.OpndIdx, ConstExprInst)) {
ConstExprInst->eraseFromParent();
- if (Offset)
+ if (Adj->Offset)
Mat->eraseFromParent();
}
- LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n');
+ LLVM_DEBUG(dbgs() << "To : " << *Adj->User.Inst << '\n');
return;
}
}
@@ -851,8 +846,11 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
bool MadeChange = false;
SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec =
BaseGV ? ConstGEPInfoMap[BaseGV] : ConstIntInfoVec;
- for (auto const &ConstInfo : ConstInfoVec) {
- SetVector<Instruction *> IPSet = findConstantInsertionPoint(ConstInfo);
+ for (const consthoist::ConstantInfo &ConstInfo : ConstInfoVec) {
+ SmallVector<Instruction *, 4> MatInsertPts;
+ collectMatInsertPts(ConstInfo.RebasedConstants, MatInsertPts);
+ SetVector<Instruction *> IPSet =
+ findConstantInsertionPoint(ConstInfo, MatInsertPts);
// We can have an empty set if the function contains unreachable blocks.
if (IPSet.empty())
continue;
@@ -862,22 +860,21 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
unsigned NotRebasedNum = 0;
for (Instruction *IP : IPSet) {
// First, collect constants depending on this IP of the base.
- unsigned Uses = 0;
- using RebasedUse = std::tuple<Constant *, Type *, ConstantUser>;
- SmallVector<RebasedUse, 4> ToBeRebased;
+ UsesNum = 0;
+ SmallVector<UserAdjustment, 4> ToBeRebased;
+ unsigned MatCtr = 0;
for (auto const &RCI : ConstInfo.RebasedConstants) {
+ UsesNum += RCI.Uses.size();
for (auto const &U : RCI.Uses) {
- Uses++;
- BasicBlock *OrigMatInsertBB =
- findMatInsertPt(U.Inst, U.OpndIdx)->getParent();
+ Instruction *MatInsertPt = MatInsertPts[MatCtr++];
+ BasicBlock *OrigMatInsertBB = MatInsertPt->getParent();
// If Base constant is to be inserted in multiple places,
// generate rebase for U using the Base dominating U.
if (IPSet.size() == 1 ||
DT->dominates(IP->getParent(), OrigMatInsertBB))
- ToBeRebased.push_back(RebasedUse(RCI.Offset, RCI.Ty, U));
+ ToBeRebased.emplace_back(RCI.Offset, RCI.Ty, MatInsertPt, U);
}
}
- UsesNum = Uses;
// If only few constants depend on this IP of base, skip rebasing,
// assuming the base and the rebased have the same materialization cost.
@@ -905,15 +902,12 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
<< *Base << '\n');
// Emit materialization code for rebased constants depending on this IP.
- for (auto const &R : ToBeRebased) {
- Constant *Off = std::get<0>(R);
- Type *Ty = std::get<1>(R);
- ConstantUser U = std::get<2>(R);
- emitBaseConstants(Base, Off, Ty, U);
+ for (UserAdjustment &R : ToBeRebased) {
+ emitBaseConstants(Base, &R);
ReBasesNum++;
// Use the same debug location as the last user of the constant.
Base->setDebugLoc(DILocation::getMergedLocation(
- Base->getDebugLoc(), U.Inst->getDebugLoc()));
+ Base->getDebugLoc(), R.User.Inst->getDebugLoc()));
}
assert(!Base->use_empty() && "The use list is empty!?");
assert(isa<Instruction>(Base->user_back()) &&
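The ConstantHoisting hunks above fold the old (Offset, Ty, ConstantUser) tuple and the repeated findMatInsertPt() lookups into a single UserAdjustment record whose insertion point is computed once up front by collectMatInsertPts(). The struct itself is declared in the ConstantHoisting header, which is not part of this file's diff; the shape below is inferred from how it is constructed and used here and may differ in detail from the real declaration.

// Inferred shape of consthoist::UserAdjustment (would live next to
// ConstantUser in llvm/Transforms/Scalar/ConstantHoisting.h).
struct UserAdjustment {
  Constant *Offset;         // rebase offset; null when no add is needed
  Type *Ty;                 // set when the rebased constant is a ConstantExpr
  Instruction *MatInsertPt; // precomputed materialization insertion point
  ConstantUser User;        // instruction/operand pair being rewritten
  UserAdjustment(Constant *Offset, Type *Ty, Instruction *MatInsertPt,
                 ConstantUser User)
      : Offset(Offset), Ty(Ty), MatInsertPt(MatInsertPt), User(User) {}
};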
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 12fcb6aa9846..15628d32280d 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstraintSystem.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
@@ -26,13 +27,18 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
+#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cmath>
+#include <optional>
#include <string>
using namespace llvm;
@@ -48,6 +54,10 @@ static cl::opt<unsigned>
MaxRows("constraint-elimination-max-rows", cl::init(500), cl::Hidden,
cl::desc("Maximum number of rows to keep in constraint system"));
+static cl::opt<bool> DumpReproducers(
+ "constraint-elimination-dump-reproducers", cl::init(false), cl::Hidden,
+ cl::desc("Dump IR to reproduce successful transformations."));
+
static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max();
static int64_t MinSignedConstraintValue = std::numeric_limits<int64_t>::min();
@@ -65,7 +75,86 @@ static int64_t addWithOverflow(int64_t A, int64_t B) {
return Result;
}
+static Instruction *getContextInstForUse(Use &U) {
+ Instruction *UserI = cast<Instruction>(U.getUser());
+ if (auto *Phi = dyn_cast<PHINode>(UserI))
+ UserI = Phi->getIncomingBlock(U)->getTerminator();
+ return UserI;
+}
+
namespace {
+/// Represents either
+/// * a condition that holds on entry to a block (=conditional fact)
+/// * an assume (=assume fact)
+/// * a use of a compare instruction to simplify.
+/// It also tracks the Dominator DFS in and out numbers for each entry.
+struct FactOrCheck {
+ union {
+ Instruction *Inst;
+ Use *U;
+ };
+ unsigned NumIn;
+ unsigned NumOut;
+ bool HasInst;
+ bool Not;
+
+ FactOrCheck(DomTreeNode *DTN, Instruction *Inst, bool Not)
+ : Inst(Inst), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()),
+ HasInst(true), Not(Not) {}
+
+ FactOrCheck(DomTreeNode *DTN, Use *U)
+ : U(U), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()),
+ HasInst(false), Not(false) {}
+
+ static FactOrCheck getFact(DomTreeNode *DTN, Instruction *Inst,
+ bool Not = false) {
+ return FactOrCheck(DTN, Inst, Not);
+ }
+
+ static FactOrCheck getCheck(DomTreeNode *DTN, Use *U) {
+ return FactOrCheck(DTN, U);
+ }
+
+ static FactOrCheck getCheck(DomTreeNode *DTN, CallInst *CI) {
+ return FactOrCheck(DTN, CI, false);
+ }
+
+ bool isCheck() const {
+ return !HasInst ||
+ match(Inst, m_Intrinsic<Intrinsic::ssub_with_overflow>());
+ }
+
+ Instruction *getContextInst() const {
+ if (HasInst)
+ return Inst;
+ return getContextInstForUse(*U);
+ }
+ Instruction *getInstructionToSimplify() const {
+ assert(isCheck());
+ if (HasInst)
+ return Inst;
+ // The use may have been simplified to a constant already.
+ return dyn_cast<Instruction>(*U);
+ }
+ bool isConditionFact() const { return !isCheck() && isa<CmpInst>(Inst); }
+};
+
+/// Keep state required to build worklist.
+struct State {
+ DominatorTree &DT;
+ SmallVector<FactOrCheck, 64> WorkList;
+
+ State(DominatorTree &DT) : DT(DT) {}
+
+ /// Process block \p BB and add known facts to work-list.
+ void addInfoFor(BasicBlock &BB);
+
+ /// Returns true if we can add a known condition from BB to its successor
+ /// block Succ.
+ bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const {
+ return DT.dominates(BasicBlockEdge(&BB, Succ), Succ);
+ }
+};
class ConstraintInfo;
@@ -100,12 +189,13 @@ struct ConstraintTy {
SmallVector<SmallVector<int64_t, 8>> ExtraInfo;
bool IsSigned = false;
- bool IsEq = false;
ConstraintTy() = default;
- ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned)
- : Coefficients(Coefficients), IsSigned(IsSigned) {}
+ ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned, bool IsEq,
+ bool IsNe)
+ : Coefficients(Coefficients), IsSigned(IsSigned), IsEq(IsEq), IsNe(IsNe) {
+ }
unsigned size() const { return Coefficients.size(); }
@@ -114,6 +204,21 @@ struct ConstraintTy {
/// Returns true if all preconditions for this list of constraints are
/// satisfied given \p CS and the corresponding \p Value2Index mapping.
bool isValid(const ConstraintInfo &Info) const;
+
+ bool isEq() const { return IsEq; }
+
+ bool isNe() const { return IsNe; }
+
+ /// Check if the current constraint is implied by the given ConstraintSystem.
+ ///
+ /// \return true or false if the constraint is proven to be respectively true,
+ /// or false. When the constraint cannot be proven to be either true or false,
+ /// std::nullopt is returned.
+ std::optional<bool> isImpliedBy(const ConstraintSystem &CS) const;
+
+private:
+ bool IsEq = false;
+ bool IsNe = false;
};
/// Wrapper encapsulating separate constraint systems and corresponding value
@@ -123,8 +228,6 @@ struct ConstraintTy {
/// based on signed-ness, certain conditions can be transferred between the two
/// systems.
class ConstraintInfo {
- DenseMap<Value *, unsigned> UnsignedValue2Index;
- DenseMap<Value *, unsigned> SignedValue2Index;
ConstraintSystem UnsignedCS;
ConstraintSystem SignedCS;
@@ -132,13 +235,14 @@ class ConstraintInfo {
const DataLayout &DL;
public:
- ConstraintInfo(const DataLayout &DL) : DL(DL) {}
+ ConstraintInfo(const DataLayout &DL, ArrayRef<Value *> FunctionArgs)
+ : UnsignedCS(FunctionArgs), SignedCS(FunctionArgs), DL(DL) {}
DenseMap<Value *, unsigned> &getValue2Index(bool Signed) {
- return Signed ? SignedValue2Index : UnsignedValue2Index;
+ return Signed ? SignedCS.getValue2Index() : UnsignedCS.getValue2Index();
}
const DenseMap<Value *, unsigned> &getValue2Index(bool Signed) const {
- return Signed ? SignedValue2Index : UnsignedValue2Index;
+ return Signed ? SignedCS.getValue2Index() : UnsignedCS.getValue2Index();
}
ConstraintSystem &getCS(bool Signed) {
@@ -235,9 +339,8 @@ static bool canUseSExt(ConstantInt *CI) {
}
static Decomposition
-decomposeGEP(GetElementPtrInst &GEP,
- SmallVectorImpl<PreconditionTy> &Preconditions, bool IsSigned,
- const DataLayout &DL) {
+decomposeGEP(GEPOperator &GEP, SmallVectorImpl<PreconditionTy> &Preconditions,
+ bool IsSigned, const DataLayout &DL) {
// Do not reason about pointers where the index size is larger than 64 bits,
// as the coefficients used to encode constraints are 64 bit integers.
if (DL.getIndexTypeSizeInBits(GEP.getPointerOperand()->getType()) > 64)
@@ -257,7 +360,7 @@ decomposeGEP(GetElementPtrInst &GEP,
// Handle the (gep (gep ....), C) case by incrementing the constant
// coefficient of the inner GEP, if C is a constant.
- auto *InnerGEP = dyn_cast<GetElementPtrInst>(GEP.getPointerOperand());
+ auto *InnerGEP = dyn_cast<GEPOperator>(GEP.getPointerOperand());
if (VariableOffsets.empty() && InnerGEP && InnerGEP->getNumOperands() == 2) {
auto Result = decompose(InnerGEP, Preconditions, IsSigned, DL);
Result.add(ConstantOffset.getSExtValue());
@@ -320,6 +423,13 @@ static Decomposition decompose(Value *V,
if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1))))
return MergeResults(Op0, Op1, IsSigned);
+ ConstantInt *CI;
+ if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI)))) {
+ auto Result = decompose(Op0, Preconditions, IsSigned, DL);
+ Result.mul(CI->getSExtValue());
+ return Result;
+ }
+
return V;
}
@@ -329,7 +439,7 @@ static Decomposition decompose(Value *V,
return int64_t(CI->getZExtValue());
}
- if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
+ if (auto *GEP = dyn_cast<GEPOperator>(V))
return decomposeGEP(*GEP, Preconditions, IsSigned, DL);
Value *Op0;
@@ -363,10 +473,17 @@ static Decomposition decompose(Value *V,
return MergeResults(Op0, CI, true);
}
+ // Decompose or as an add if there are no common bits between the operands.
+ if (match(V, m_Or(m_Value(Op0), m_ConstantInt(CI))) &&
+ haveNoCommonBitsSet(Op0, CI, DL)) {
+ return MergeResults(Op0, CI, IsSigned);
+ }
+
if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) {
- int64_t Mult = int64_t(std::pow(int64_t(2), CI->getSExtValue()));
+ if (CI->getSExtValue() < 0 || CI->getSExtValue() >= 64)
+ return {V, IsKnownNonNegative};
auto Result = decompose(Op1, Preconditions, IsSigned, DL);
- Result.mul(Mult);
+ Result.mul(int64_t{1} << CI->getSExtValue());
return Result;
}
@@ -390,6 +507,8 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
SmallVectorImpl<Value *> &NewVariables) const {
assert(NewVariables.empty() && "NewVariables must be empty when passed in");
bool IsEq = false;
+ bool IsNe = false;
+
// Try to convert Pred to one of ULE/SLT/SLE/SLT.
switch (Pred) {
case CmpInst::ICMP_UGT:
@@ -409,10 +528,13 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
}
break;
case CmpInst::ICMP_NE:
- if (!match(Op1, m_Zero()))
- return {};
- Pred = CmpInst::getSwappedPredicate(CmpInst::ICMP_UGT);
- std::swap(Op0, Op1);
+ if (match(Op1, m_Zero())) {
+ Pred = CmpInst::getSwappedPredicate(CmpInst::ICMP_UGT);
+ std::swap(Op0, Op1);
+ } else {
+ IsNe = true;
+ Pred = CmpInst::ICMP_ULE;
+ }
break;
default:
break;
@@ -459,11 +581,10 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
// subtracting all coefficients from B.
ConstraintTy Res(
SmallVector<int64_t, 8>(Value2Index.size() + NewVariables.size() + 1, 0),
- IsSigned);
+ IsSigned, IsEq, IsNe);
// Collect variables that are known to be positive in all uses in the
// constraint.
DenseMap<Value *, bool> KnownNonNegativeVariables;
- Res.IsEq = IsEq;
auto &R = Res.Coefficients;
for (const auto &KV : VariablesA) {
R[GetOrAddIndex(KV.Variable)] += KV.Coefficient;
@@ -473,7 +594,9 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
}
for (const auto &KV : VariablesB) {
- R[GetOrAddIndex(KV.Variable)] -= KV.Coefficient;
+ if (SubOverflow(R[GetOrAddIndex(KV.Variable)], KV.Coefficient,
+ R[GetOrAddIndex(KV.Variable)]))
+ return {};
auto I =
KnownNonNegativeVariables.insert({KV.Variable, KV.IsKnownNonNegative});
I.first->second &= KV.IsKnownNonNegative;
@@ -501,8 +624,8 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
// Add extra constraints for variables that are known positive.
for (auto &KV : KnownNonNegativeVariables) {
- if (!KV.second || (Value2Index.find(KV.first) == Value2Index.end() &&
- NewIndexMap.find(KV.first) == NewIndexMap.end()))
+ if (!KV.second ||
+ (!Value2Index.contains(KV.first) && !NewIndexMap.contains(KV.first)))
continue;
SmallVector<int64_t, 8> C(Value2Index.size() + NewVariables.size() + 1, 0);
C[GetOrAddIndex(KV.first)] = -1;
@@ -524,7 +647,7 @@ ConstraintTy ConstraintInfo::getConstraintForSolving(CmpInst::Predicate Pred,
SmallVector<Value *> NewVariables;
ConstraintTy R = getConstraint(Pred, Op0, Op1, NewVariables);
- if (R.IsEq || !NewVariables.empty())
+ if (!NewVariables.empty())
return {};
return R;
}
@@ -536,10 +659,54 @@ bool ConstraintTy::isValid(const ConstraintInfo &Info) const {
});
}
+std::optional<bool>
+ConstraintTy::isImpliedBy(const ConstraintSystem &CS) const {
+ bool IsConditionImplied = CS.isConditionImplied(Coefficients);
+
+ if (IsEq || IsNe) {
+ auto NegatedOrEqual = ConstraintSystem::negateOrEqual(Coefficients);
+ bool IsNegatedOrEqualImplied =
+ !NegatedOrEqual.empty() && CS.isConditionImplied(NegatedOrEqual);
+
+ // In order to check that `%a == %b` is true (equality), both conditions `%a
+ // >= %b` and `%a <= %b` must hold true. When checking for equality (`IsEq`
+ // is true), we return true if they both hold, false in the other cases.
+ if (IsConditionImplied && IsNegatedOrEqualImplied)
+ return IsEq;
+
+ auto Negated = ConstraintSystem::negate(Coefficients);
+ bool IsNegatedImplied = !Negated.empty() && CS.isConditionImplied(Negated);
+
+ auto StrictLessThan = ConstraintSystem::toStrictLessThan(Coefficients);
+ bool IsStrictLessThanImplied =
+ !StrictLessThan.empty() && CS.isConditionImplied(StrictLessThan);
+
+ // In order to check that `%a != %b` is true (non-equality), either
+ // condition `%a > %b` or `%a < %b` must hold true. When checking for
+ // non-equality (`IsNe` is true), we return true if one of the two holds,
+ // false in the other cases.
+ if (IsNegatedImplied || IsStrictLessThanImplied)
+ return IsNe;
+
+ return std::nullopt;
+ }
+
+ if (IsConditionImplied)
+ return true;
+
+ auto Negated = ConstraintSystem::negate(Coefficients);
+ auto IsNegatedImplied = !Negated.empty() && CS.isConditionImplied(Negated);
+ if (IsNegatedImplied)
+ return false;
+
+ // Neither the condition nor its negated holds, did not prove anything.
+ return std::nullopt;
+}
+
bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A,
Value *B) const {
auto R = getConstraintForSolving(Pred, A, B);
- return R.Preconditions.empty() && !R.empty() &&
+ return R.isValid(*this) &&
getCS(R.IsSigned).isConditionImplied(R.Coefficients);
}
@@ -568,11 +735,15 @@ void ConstraintInfo::transferToOtherSystem(
if (doesHold(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0)))
addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack);
break;
- case CmpInst::ICMP_SGT:
+ case CmpInst::ICMP_SGT: {
if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1)))
addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), NumIn,
NumOut, DFSInStack);
+ if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0)))
+ addFact(CmpInst::ICMP_UGT, A, B, NumIn, NumOut, DFSInStack);
+
break;
+ }
case CmpInst::ICMP_SGE:
if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) {
addFact(CmpInst::ICMP_UGE, A, B, NumIn, NumOut, DFSInStack);
@@ -581,77 +752,13 @@ void ConstraintInfo::transferToOtherSystem(
}
}
-namespace {
-/// Represents either
-/// * a condition that holds on entry to a block (=conditional fact)
-/// * an assume (=assume fact)
-/// * an instruction to simplify.
-/// It also tracks the Dominator DFS in and out numbers for each entry.
-struct FactOrCheck {
- Instruction *Inst;
- unsigned NumIn;
- unsigned NumOut;
- bool IsCheck;
- bool Not;
-
- FactOrCheck(DomTreeNode *DTN, Instruction *Inst, bool IsCheck, bool Not)
- : Inst(Inst), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()),
- IsCheck(IsCheck), Not(Not) {}
-
- static FactOrCheck getFact(DomTreeNode *DTN, Instruction *Inst,
- bool Not = false) {
- return FactOrCheck(DTN, Inst, false, Not);
- }
-
- static FactOrCheck getCheck(DomTreeNode *DTN, Instruction *Inst) {
- return FactOrCheck(DTN, Inst, true, false);
- }
-
- bool isAssumeFact() const {
- if (!IsCheck && isa<IntrinsicInst>(Inst)) {
- assert(match(Inst, m_Intrinsic<Intrinsic::assume>()));
- return true;
- }
- return false;
- }
-
- bool isConditionFact() const { return !IsCheck && isa<CmpInst>(Inst); }
-};
-
-/// Keep state required to build worklist.
-struct State {
- DominatorTree &DT;
- SmallVector<FactOrCheck, 64> WorkList;
-
- State(DominatorTree &DT) : DT(DT) {}
-
- /// Process block \p BB and add known facts to work-list.
- void addInfoFor(BasicBlock &BB);
-
- /// Returns true if we can add a known condition from BB to its successor
- /// block Succ.
- bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const {
- return DT.dominates(BasicBlockEdge(&BB, Succ), Succ);
- }
-};
-
-} // namespace
-
#ifndef NDEBUG
-static void dumpWithNames(const ConstraintSystem &CS,
- DenseMap<Value *, unsigned> &Value2Index) {
- SmallVector<std::string> Names(Value2Index.size(), "");
- for (auto &KV : Value2Index) {
- Names[KV.second - 1] = std::string("%") + KV.first->getName().str();
- }
- CS.dump(Names);
-}
-static void dumpWithNames(ArrayRef<int64_t> C,
- DenseMap<Value *, unsigned> &Value2Index) {
- ConstraintSystem CS;
+static void dumpConstraint(ArrayRef<int64_t> C,
+ const DenseMap<Value *, unsigned> &Value2Index) {
+ ConstraintSystem CS(Value2Index);
CS.addVariableRowFill(C);
- dumpWithNames(CS, Value2Index);
+ CS.dump();
}
#endif
@@ -661,12 +768,24 @@ void State::addInfoFor(BasicBlock &BB) {
// Queue conditions and assumes.
for (Instruction &I : BB) {
if (auto Cmp = dyn_cast<ICmpInst>(&I)) {
- WorkList.push_back(FactOrCheck::getCheck(DT.getNode(&BB), Cmp));
+ for (Use &U : Cmp->uses()) {
+ auto *UserI = getContextInstForUse(U);
+ auto *DTN = DT.getNode(UserI->getParent());
+ if (!DTN)
+ continue;
+ WorkList.push_back(FactOrCheck::getCheck(DTN, &U));
+ }
continue;
}
if (match(&I, m_Intrinsic<Intrinsic::ssub_with_overflow>())) {
- WorkList.push_back(FactOrCheck::getCheck(DT.getNode(&BB), &I));
+ WorkList.push_back(
+ FactOrCheck::getCheck(DT.getNode(&BB), cast<CallInst>(&I)));
+ continue;
+ }
+
+ if (isa<MinMaxIntrinsic>(&I)) {
+ WorkList.push_back(FactOrCheck::getFact(DT.getNode(&BB), &I));
continue;
}
@@ -748,7 +867,160 @@ void State::addInfoFor(BasicBlock &BB) {
FactOrCheck::getFact(DT.getNode(Br->getSuccessor(1)), CmpI, true));
}
-static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) {
+namespace {
+/// Helper to keep track of a condition and if it should be treated as negated
+/// for reproducer construction.
+/// Pred == Predicate::BAD_ICMP_PREDICATE indicates that this entry is a
+/// placeholder to keep the ReproducerCondStack in sync with DFSInStack.
+struct ReproducerEntry {
+ ICmpInst::Predicate Pred;
+ Value *LHS;
+ Value *RHS;
+
+ ReproducerEntry(ICmpInst::Predicate Pred, Value *LHS, Value *RHS)
+ : Pred(Pred), LHS(LHS), RHS(RHS) {}
+};
+} // namespace
+
+/// Helper function to generate a reproducer function for simplifying \p Cond.
+/// The reproducer function contains a series of @llvm.assume calls, one for
+/// each condition in \p Stack. For each condition, the operand instructions are
+/// cloned until we reach operands that have an entry in \p Value2Index. Those
+/// will then be added as function arguments. \p DT is used to order cloned
+/// instructions. The reproducer function will get added to \p M, if it is
+/// non-null. Otherwise no reproducer function is generated.
+static void generateReproducer(CmpInst *Cond, Module *M,
+ ArrayRef<ReproducerEntry> Stack,
+ ConstraintInfo &Info, DominatorTree &DT) {
+ if (!M)
+ return;
+
+ LLVMContext &Ctx = Cond->getContext();
+
+ LLVM_DEBUG(dbgs() << "Creating reproducer for " << *Cond << "\n");
+
+ ValueToValueMapTy Old2New;
+ SmallVector<Value *> Args;
+ SmallPtrSet<Value *, 8> Seen;
+ // Traverse Cond and its operands recursively until we reach a value that's in
+  // Value2Index or not an instruction, or not an operation that
+  // ConstraintElimination can decompose. Such values will be considered as
+  // external inputs to the reproducer; they are collected and added as function
+ // arguments later.
+ auto CollectArguments = [&](ArrayRef<Value *> Ops, bool IsSigned) {
+ auto &Value2Index = Info.getValue2Index(IsSigned);
+ SmallVector<Value *, 4> WorkList(Ops);
+ while (!WorkList.empty()) {
+ Value *V = WorkList.pop_back_val();
+ if (!Seen.insert(V).second)
+ continue;
+ if (Old2New.find(V) != Old2New.end())
+ continue;
+ if (isa<Constant>(V))
+ continue;
+
+ auto *I = dyn_cast<Instruction>(V);
+ if (Value2Index.contains(V) || !I ||
+ !isa<CmpInst, BinaryOperator, GEPOperator, CastInst>(V)) {
+ Old2New[V] = V;
+ Args.push_back(V);
+ LLVM_DEBUG(dbgs() << " found external input " << *V << "\n");
+ } else {
+ append_range(WorkList, I->operands());
+ }
+ }
+ };
+
+ for (auto &Entry : Stack)
+ if (Entry.Pred != ICmpInst::BAD_ICMP_PREDICATE)
+ CollectArguments({Entry.LHS, Entry.RHS}, ICmpInst::isSigned(Entry.Pred));
+ CollectArguments(Cond, ICmpInst::isSigned(Cond->getPredicate()));
+
+ SmallVector<Type *> ParamTys;
+ for (auto *P : Args)
+ ParamTys.push_back(P->getType());
+
+ FunctionType *FTy = FunctionType::get(Cond->getType(), ParamTys,
+ /*isVarArg=*/false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage,
+ Cond->getModule()->getName() +
+ Cond->getFunction()->getName() + "repro",
+ M);
+ // Add arguments to the reproducer function for each external value collected.
+ for (unsigned I = 0; I < Args.size(); ++I) {
+ F->getArg(I)->setName(Args[I]->getName());
+ Old2New[Args[I]] = F->getArg(I);
+ }
+
+ BasicBlock *Entry = BasicBlock::Create(Ctx, "entry", F);
+ IRBuilder<> Builder(Entry);
+ Builder.CreateRet(Builder.getTrue());
+ Builder.SetInsertPoint(Entry->getTerminator());
+
+ // Clone instructions in \p Ops and their operands recursively until reaching
+  // a value in Value2Index (external input to the reproducer). Update Old2New
+ // mapping for the original and cloned instructions. Sort instructions to
+ // clone by dominance, then insert the cloned instructions in the function.
+ auto CloneInstructions = [&](ArrayRef<Value *> Ops, bool IsSigned) {
+ SmallVector<Value *, 4> WorkList(Ops);
+ SmallVector<Instruction *> ToClone;
+ auto &Value2Index = Info.getValue2Index(IsSigned);
+ while (!WorkList.empty()) {
+ Value *V = WorkList.pop_back_val();
+ if (Old2New.find(V) != Old2New.end())
+ continue;
+
+ auto *I = dyn_cast<Instruction>(V);
+ if (!Value2Index.contains(V) && I) {
+ Old2New[V] = nullptr;
+ ToClone.push_back(I);
+ append_range(WorkList, I->operands());
+ }
+ }
+
+ sort(ToClone,
+ [&DT](Instruction *A, Instruction *B) { return DT.dominates(A, B); });
+ for (Instruction *I : ToClone) {
+ Instruction *Cloned = I->clone();
+ Old2New[I] = Cloned;
+ Old2New[I]->setName(I->getName());
+ Cloned->insertBefore(&*Builder.GetInsertPoint());
+ Cloned->dropUnknownNonDebugMetadata();
+ Cloned->setDebugLoc({});
+ }
+ };
+
+ // Materialize the assumptions for the reproducer using the entries in Stack.
+ // That is, first clone the operands of the condition recursively until we
+ // reach an external input to the reproducer and add them to the reproducer
+ // function. Then add an ICmp for the condition (with the inverse predicate if
+  // the entry is negated) and an assume using the ICmp.
+ for (auto &Entry : Stack) {
+ if (Entry.Pred == ICmpInst::BAD_ICMP_PREDICATE)
+ continue;
+
+ LLVM_DEBUG(
+ dbgs() << " Materializing assumption icmp " << Entry.Pred << ' ';
+ Entry.LHS->printAsOperand(dbgs(), /*PrintType=*/true); dbgs() << ", ";
+ Entry.RHS->printAsOperand(dbgs(), /*PrintType=*/false); dbgs() << "\n");
+ CloneInstructions({Entry.LHS, Entry.RHS}, CmpInst::isSigned(Entry.Pred));
+
+ auto *Cmp = Builder.CreateICmp(Entry.Pred, Entry.LHS, Entry.RHS);
+ Builder.CreateAssumption(Cmp);
+ }
+
+ // Finally, clone the condition to reproduce and remap instruction operands in
+ // the reproducer using Old2New.
+ CloneInstructions(Cond, CmpInst::isSigned(Cond->getPredicate()));
+ Entry->getTerminator()->setOperand(0, Cond);
+ remapInstructionsInBlocks({Entry}, Old2New);
+
+ assert(!verifyFunction(*F, &dbgs()));
+}
+
+static std::optional<bool> checkCondition(CmpInst *Cmp, ConstraintInfo &Info,
+ unsigned NumIn, unsigned NumOut,
+ Instruction *ContextInst) {
LLVM_DEBUG(dbgs() << "Checking " << *Cmp << "\n");
CmpInst::Predicate Pred = Cmp->getPredicate();
@@ -758,7 +1030,7 @@ static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) {
auto R = Info.getConstraintForSolving(Pred, A, B);
if (R.empty() || !R.isValid(Info)){
LLVM_DEBUG(dbgs() << " failed to decompose condition\n");
- return false;
+ return std::nullopt;
}
auto &CSToUse = Info.getCS(R.IsSigned);
@@ -773,39 +1045,107 @@ static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) {
CSToUse.popLastConstraint();
});
- bool Changed = false;
- if (CSToUse.isConditionImplied(R.Coefficients)) {
+ if (auto ImpliedCondition = R.isImpliedBy(CSToUse)) {
if (!DebugCounter::shouldExecute(EliminatedCounter))
- return false;
+ return std::nullopt;
LLVM_DEBUG({
- dbgs() << "Condition " << *Cmp << " implied by dominating constraints\n";
- dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned));
+ if (*ImpliedCondition) {
+ dbgs() << "Condition " << *Cmp;
+ } else {
+ auto InversePred = Cmp->getInversePredicate();
+ dbgs() << "Condition " << CmpInst::getPredicateName(InversePred) << " "
+ << *A << ", " << *B;
+ }
+ dbgs() << " implied by dominating constraints\n";
+ CSToUse.dump();
});
- Constant *TrueC =
- ConstantInt::getTrue(CmpInst::makeCmpResultType(Cmp->getType()));
- Cmp->replaceUsesWithIf(TrueC, [](Use &U) {
+ return ImpliedCondition;
+ }
+
+ return std::nullopt;
+}
+
+static bool checkAndReplaceCondition(
+ CmpInst *Cmp, ConstraintInfo &Info, unsigned NumIn, unsigned NumOut,
+ Instruction *ContextInst, Module *ReproducerModule,
+ ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT) {
+ auto ReplaceCmpWithConstant = [&](CmpInst *Cmp, bool IsTrue) {
+ generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
+ Constant *ConstantC = ConstantInt::getBool(
+ CmpInst::makeCmpResultType(Cmp->getType()), IsTrue);
+ Cmp->replaceUsesWithIf(ConstantC, [&DT, NumIn, NumOut,
+ ContextInst](Use &U) {
+ auto *UserI = getContextInstForUse(U);
+ auto *DTN = DT.getNode(UserI->getParent());
+ if (!DTN || DTN->getDFSNumIn() < NumIn || DTN->getDFSNumOut() > NumOut)
+ return false;
+ if (UserI->getParent() == ContextInst->getParent() &&
+ UserI->comesBefore(ContextInst))
+ return false;
+
// Conditions in an assume trivially simplify to true. Skip uses
// in assume calls to not destroy the available information.
auto *II = dyn_cast<IntrinsicInst>(U.getUser());
return !II || II->getIntrinsicID() != Intrinsic::assume;
});
NumCondsRemoved++;
+ return true;
+ };
+
+ if (auto ImpliedCondition =
+ checkCondition(Cmp, Info, NumIn, NumOut, ContextInst))
+ return ReplaceCmpWithConstant(Cmp, *ImpliedCondition);
+ return false;
+}
+
+static void
+removeEntryFromStack(const StackEntry &E, ConstraintInfo &Info,
+ Module *ReproducerModule,
+ SmallVectorImpl<ReproducerEntry> &ReproducerCondStack,
+ SmallVectorImpl<StackEntry> &DFSInStack) {
+ Info.popLastConstraint(E.IsSigned);
+ // Remove variables in the system that went out of scope.
+ auto &Mapping = Info.getValue2Index(E.IsSigned);
+ for (Value *V : E.ValuesToRelease)
+ Mapping.erase(V);
+ Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size());
+ DFSInStack.pop_back();
+ if (ReproducerModule)
+ ReproducerCondStack.pop_back();
+}
+
+/// Check if the first condition for an AND implies the second.
+static bool checkAndSecondOpImpliedByFirst(
+ FactOrCheck &CB, ConstraintInfo &Info, Module *ReproducerModule,
+ SmallVectorImpl<ReproducerEntry> &ReproducerCondStack,
+ SmallVectorImpl<StackEntry> &DFSInStack) {
+ CmpInst::Predicate Pred;
+ Value *A, *B;
+ Instruction *And = CB.getContextInst();
+ if (!match(And->getOperand(0), m_ICmp(Pred, m_Value(A), m_Value(B))))
+ return false;
+
+ // Optimistically add fact from first condition.
+ unsigned OldSize = DFSInStack.size();
+ Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
+ if (OldSize == DFSInStack.size())
+ return false;
+
+ bool Changed = false;
+ // Check if the second condition can be simplified now.
+ if (auto ImpliedCondition =
+ checkCondition(cast<ICmpInst>(And->getOperand(1)), Info, CB.NumIn,
+ CB.NumOut, CB.getContextInst())) {
+ And->setOperand(1, ConstantInt::getBool(And->getType(), *ImpliedCondition));
Changed = true;
}
- if (CSToUse.isConditionImplied(ConstraintSystem::negate(R.Coefficients))) {
- if (!DebugCounter::shouldExecute(EliminatedCounter))
- return false;
- LLVM_DEBUG({
- dbgs() << "Condition !" << *Cmp << " implied by dominating constraints\n";
- dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned));
- });
- Constant *FalseC =
- ConstantInt::getFalse(CmpInst::makeCmpResultType(Cmp->getType()));
- Cmp->replaceAllUsesWith(FalseC);
- NumCondsRemoved++;
- Changed = true;
+ // Remove entries again.
+ while (OldSize < DFSInStack.size()) {
+ StackEntry E = DFSInStack.back();
+ removeEntryFromStack(E, Info, ReproducerModule, ReproducerCondStack,
+ DFSInStack);
}
return Changed;
}
@@ -817,10 +1157,12 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B,
// hold.
SmallVector<Value *> NewVariables;
auto R = getConstraint(Pred, A, B, NewVariables);
- if (!R.isValid(*this))
+
+ // TODO: Support non-equality for facts as well.
+ if (!R.isValid(*this) || R.isNe())
return;
- LLVM_DEBUG(dbgs() << "Adding '" << CmpInst::getPredicateName(Pred) << " ";
+ LLVM_DEBUG(dbgs() << "Adding '" << Pred << " ";
A->printAsOperand(dbgs(), false); dbgs() << ", ";
B->printAsOperand(dbgs(), false); dbgs() << "'\n");
bool Added = false;
@@ -842,14 +1184,14 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B,
LLVM_DEBUG({
dbgs() << " constraint: ";
- dumpWithNames(R.Coefficients, getValue2Index(R.IsSigned));
+ dumpConstraint(R.Coefficients, getValue2Index(R.IsSigned));
dbgs() << "\n";
});
DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned,
std::move(ValuesToRelease));
- if (R.IsEq) {
+ if (R.isEq()) {
// Also add the inverted constraint for equality constraints.
for (auto &Coeff : R.Coefficients)
Coeff *= -1;
@@ -921,12 +1263,17 @@ tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info,
return Changed;
}
-static bool eliminateConstraints(Function &F, DominatorTree &DT) {
+static bool eliminateConstraints(Function &F, DominatorTree &DT,
+ OptimizationRemarkEmitter &ORE) {
bool Changed = false;
DT.updateDFSNumbers();
-
- ConstraintInfo Info(F.getParent()->getDataLayout());
+ SmallVector<Value *> FunctionArgs;
+ for (Value &Arg : F.args())
+ FunctionArgs.push_back(&Arg);
+ ConstraintInfo Info(F.getParent()->getDataLayout(), FunctionArgs);
State S(DT);
+ std::unique_ptr<Module> ReproducerModule(
+ DumpReproducers ? new Module(F.getName(), F.getContext()) : nullptr);
// First, collect conditions implied by branches and blocks with their
// Dominator DFS in and out numbers.
@@ -961,7 +1308,9 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) {
return true;
if (B.isConditionFact())
return false;
- return A.Inst->comesBefore(B.Inst);
+ auto *InstA = A.getContextInst();
+ auto *InstB = B.getContextInst();
+ return InstA->comesBefore(InstB);
}
return A.NumIn < B.NumIn;
});
@@ -970,6 +1319,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) {
// Finally, process ordered worklist and eliminate implied conditions.
SmallVector<StackEntry, 16> DFSInStack;
+ SmallVector<ReproducerEntry> ReproducerCondStack;
for (FactOrCheck &CB : S.WorkList) {
// First, pop entries from the stack that are out-of-scope for CB. Remove
// the corresponding entry from the constraint system.
@@ -983,61 +1333,96 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) {
break;
LLVM_DEBUG({
dbgs() << "Removing ";
- dumpWithNames(Info.getCS(E.IsSigned).getLastConstraint(),
- Info.getValue2Index(E.IsSigned));
+ dumpConstraint(Info.getCS(E.IsSigned).getLastConstraint(),
+ Info.getValue2Index(E.IsSigned));
dbgs() << "\n";
});
-
- Info.popLastConstraint(E.IsSigned);
- // Remove variables in the system that went out of scope.
- auto &Mapping = Info.getValue2Index(E.IsSigned);
- for (Value *V : E.ValuesToRelease)
- Mapping.erase(V);
- Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size());
- DFSInStack.pop_back();
+ removeEntryFromStack(E, Info, ReproducerModule.get(), ReproducerCondStack,
+ DFSInStack);
}
- LLVM_DEBUG({
- dbgs() << "Processing ";
- if (CB.IsCheck)
- dbgs() << "condition to simplify: " << *CB.Inst;
- else
- dbgs() << "fact to add to the system: " << *CB.Inst;
- dbgs() << "\n";
- });
+ LLVM_DEBUG(dbgs() << "Processing ");
// For a block, check if any CmpInsts become known based on the current set
// of constraints.
- if (CB.IsCheck) {
- if (auto *II = dyn_cast<WithOverflowInst>(CB.Inst)) {
+ if (CB.isCheck()) {
+ Instruction *Inst = CB.getInstructionToSimplify();
+ if (!Inst)
+ continue;
+ LLVM_DEBUG(dbgs() << "condition to simplify: " << *Inst << "\n");
+ if (auto *II = dyn_cast<WithOverflowInst>(Inst)) {
Changed |= tryToSimplifyOverflowMath(II, Info, ToRemove);
- } else if (auto *Cmp = dyn_cast<ICmpInst>(CB.Inst)) {
- Changed |= checkAndReplaceCondition(Cmp, Info);
+ } else if (auto *Cmp = dyn_cast<ICmpInst>(Inst)) {
+ bool Simplified = checkAndReplaceCondition(
+ Cmp, Info, CB.NumIn, CB.NumOut, CB.getContextInst(),
+ ReproducerModule.get(), ReproducerCondStack, S.DT);
+ if (!Simplified && match(CB.getContextInst(),
+ m_LogicalAnd(m_Value(), m_Specific(Inst)))) {
+ Simplified =
+ checkAndSecondOpImpliedByFirst(CB, Info, ReproducerModule.get(),
+ ReproducerCondStack, DFSInStack);
+ }
+ Changed |= Simplified;
}
continue;
}
- ICmpInst::Predicate Pred;
- Value *A, *B;
- Value *Cmp = CB.Inst;
- match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
- if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
+ LLVM_DEBUG(dbgs() << "fact to add to the system: " << *CB.Inst << "\n");
+ auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) {
if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) {
LLVM_DEBUG(
dbgs()
<< "Skip adding constraint because system has too many rows.\n");
- continue;
+ return;
+ }
+
+ Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
+ if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size())
+ ReproducerCondStack.emplace_back(Pred, A, B);
+
+ Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
+ if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size()) {
+ // Add dummy entries to ReproducerCondStack to keep it in sync with
+ // DFSInStack.
+ for (unsigned I = 0,
+ E = (DFSInStack.size() - ReproducerCondStack.size());
+ I < E; ++I) {
+ ReproducerCondStack.emplace_back(ICmpInst::BAD_ICMP_PREDICATE,
+ nullptr, nullptr);
+ }
}
+ };
+ ICmpInst::Predicate Pred;
+ if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
+ Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
+ AddFact(Pred, MinMax, MinMax->getLHS());
+ AddFact(Pred, MinMax, MinMax->getRHS());
+ continue;
+ }
+
+ Value *A, *B;
+ Value *Cmp = CB.Inst;
+ match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
+ if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
// Use the inverse predicate if required.
if (CB.Not)
Pred = CmpInst::getInversePredicate(Pred);
- Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
- Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
+ AddFact(Pred, A, B);
}
}
+ if (ReproducerModule && !ReproducerModule->functions().empty()) {
+ std::string S;
+ raw_string_ostream StringS(S);
+ ReproducerModule->print(StringS, nullptr);
+ StringS.flush();
+ OptimizationRemark Rem(DEBUG_TYPE, "Reproducer", &F);
+ Rem << ore::NV("module") << S;
+ ORE.emit(Rem);
+ }
+
#ifndef NDEBUG
unsigned SignedEntries =
count_if(DFSInStack, [](const StackEntry &E) { return E.IsSigned; });
@@ -1055,7 +1440,8 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) {
PreservedAnalyses ConstraintEliminationPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
- if (!eliminateConstraints(F, DT))
+ auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+ if (!eliminateConstraints(F, DT, ORE))
return PreservedAnalyses::all();
PreservedAnalyses PA;
diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 90b4b521e7de..48b27a1ea0a2 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -36,11 +36,8 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <optional>
@@ -97,60 +94,33 @@ STATISTIC(NumMinMax, "Number of llvm.[us]{min,max} intrinsics removed");
STATISTIC(NumUDivURemsNarrowedExpanded,
"Number of bound udiv's/urem's expanded");
-namespace {
-
- class CorrelatedValuePropagation : public FunctionPass {
- public:
- static char ID;
-
- CorrelatedValuePropagation(): FunctionPass(ID) {
- initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<LazyValueInfoWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<LazyValueInfoWrapperPass>();
- }
- };
-
-} // end anonymous namespace
-
-char CorrelatedValuePropagation::ID = 0;
-
-INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation",
- "Value Propagation", false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
-INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation",
- "Value Propagation", false, false)
-
-// Public interface to the Value Propagation pass
-Pass *llvm::createCorrelatedValuePropagationPass() {
- return new CorrelatedValuePropagation();
-}
-
static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
- if (S->getType()->isVectorTy()) return false;
- if (isa<Constant>(S->getCondition())) return false;
-
- Constant *C = LVI->getConstant(S->getCondition(), S);
- if (!C) return false;
+ if (S->getType()->isVectorTy() || isa<Constant>(S->getCondition()))
+ return false;
- ConstantInt *CI = dyn_cast<ConstantInt>(C);
- if (!CI) return false;
+ bool Changed = false;
+ for (Use &U : make_early_inc_range(S->uses())) {
+ auto *I = cast<Instruction>(U.getUser());
+ Constant *C;
+ if (auto *PN = dyn_cast<PHINode>(I))
+ C = LVI->getConstantOnEdge(S->getCondition(), PN->getIncomingBlock(U),
+ I->getParent(), I);
+ else
+ C = LVI->getConstant(S->getCondition(), I);
+
+ auto *CI = dyn_cast_or_null<ConstantInt>(C);
+ if (!CI)
+ continue;
- Value *ReplaceWith = CI->isOne() ? S->getTrueValue() : S->getFalseValue();
- S->replaceAllUsesWith(ReplaceWith);
- S->eraseFromParent();
+ U.set(CI->isOne() ? S->getTrueValue() : S->getFalseValue());
+ Changed = true;
+ ++NumSelects;
+ }
- ++NumSelects;
+ if (Changed && S->use_empty())
+ S->eraseFromParent();
- return true;
+ return Changed;
}
/// Try to simplify a phi with constant incoming values that match the edge
@@ -698,7 +668,7 @@ enum class Domain { NonNegative, NonPositive, Unknown };
static Domain getDomain(const ConstantRange &CR) {
if (CR.isAllNonNegative())
return Domain::NonNegative;
- if (CR.icmp(ICmpInst::ICMP_SLE, APInt::getNullValue(CR.getBitWidth())))
+ if (CR.icmp(ICmpInst::ICMP_SLE, APInt::getZero(CR.getBitWidth())))
return Domain::NonPositive;
return Domain::Unknown;
}
@@ -717,7 +687,6 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,
// What is the smallest bit width that can accommodate the entire value ranges
// of both of the operands?
- std::array<std::optional<ConstantRange>, 2> CRs;
unsigned MinSignedBits =
std::max(LCR.getMinSignedBits(), RCR.getMinSignedBits());
@@ -804,10 +773,18 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
IRBuilder<> B(Instr);
Value *ExpandedOp;
- if (IsRem) {
+ if (XCR.icmp(ICmpInst::ICMP_UGE, YCR)) {
+ // If X is between Y and 2*Y the result is known.
+ if (IsRem)
+ ExpandedOp = B.CreateNUWSub(X, Y);
+ else
+ ExpandedOp = ConstantInt::get(Instr->getType(), 1);
+ } else if (IsRem) {
// NOTE: this transformation introduces two uses of X,
// but it may be undef so we must freeze it first.
- Value *FrozenX = B.CreateFreeze(X, X->getName() + ".frozen");
+ Value *FrozenX = X;
+ if (!isGuaranteedNotToBeUndefOrPoison(X))
+ FrozenX = B.CreateFreeze(X, X->getName() + ".frozen");
auto *AdjX = B.CreateNUWSub(FrozenX, Y, Instr->getName() + ".urem");
auto *Cmp =
B.CreateICmp(ICmpInst::ICMP_ULT, FrozenX, Y, Instr->getName() + ".cmp");
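// Illustrative sketch, not from the patch: with unsigned X and Y known to
// satisfy Y <= X < 2*Y (the case the range checks above establish), the
// bounded udiv/urem collapse to a constant and a single subtraction. The
// function names below are hypothetical.
static unsigned expandedURem(unsigned X, unsigned Y) { return X - Y; } // X % Y
static unsigned expandedUDiv(unsigned /*X*/, unsigned /*Y*/) { return 1; } // X / Y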
@@ -1008,7 +985,8 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
if (SDI->getType()->isVectorTy())
return false;
- ConstantRange LRange = LVI->getConstantRangeAtUse(SDI->getOperandUse(0));
+ ConstantRange LRange =
+ LVI->getConstantRangeAtUse(SDI->getOperandUse(0), /*UndefAllowed*/ false);
unsigned OrigWidth = SDI->getType()->getIntegerBitWidth();
ConstantRange NegOneOrZero =
ConstantRange(APInt(OrigWidth, (uint64_t)-1, true), APInt(OrigWidth, 1));
@@ -1040,7 +1018,8 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
return false;
const Use &Base = SDI->getOperandUse(0);
- if (!LVI->getConstantRangeAtUse(Base).isAllNonNegative())
+ if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false)
+ .isAllNonNegative())
return false;
++NumSExt;
@@ -1222,16 +1201,6 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
return FnChanged;
}
-bool CorrelatedValuePropagation::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
-
- LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
- DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-
- return runImpl(F, LVI, DT, getBestSimplifyQuery(*this, F));
-}
-
PreservedAnalyses
CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) {
LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);
diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
index 658d0fcb53fa..f2efe60bdf88 100644
--- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
@@ -70,11 +70,8 @@
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/SSAUpdaterBulk.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
@@ -168,51 +165,8 @@ private:
OptimizationRemarkEmitter *ORE;
};
-class DFAJumpThreadingLegacyPass : public FunctionPass {
-public:
- static char ID; // Pass identification
- DFAJumpThreadingLegacyPass() : FunctionPass(ID) {}
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- AssumptionCache *AC =
- &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- TargetTransformInfo *TTI =
- &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- OptimizationRemarkEmitter *ORE =
- &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
-
- return DFAJumpThreading(AC, DT, TTI, ORE).run(F);
- }
-};
} // end anonymous namespace
-char DFAJumpThreadingLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(DFAJumpThreadingLegacyPass, "dfa-jump-threading",
- "DFA Jump Threading", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_END(DFAJumpThreadingLegacyPass, "dfa-jump-threading",
- "DFA Jump Threading", false, false)
-
-// Public interface to the DFA Jump Threading pass
-FunctionPass *llvm::createDFAJumpThreadingPass() {
- return new DFAJumpThreadingLegacyPass();
-}
-
namespace {
/// Create a new basic block and sink \p SIToSink into it.
@@ -625,7 +579,7 @@ private:
continue;
PathsType SuccPaths = paths(Succ, Visited, PathDepth + 1);
- for (PathType Path : SuccPaths) {
+ for (const PathType &Path : SuccPaths) {
PathType NewPath(Path);
NewPath.push_front(BB);
Res.push_back(NewPath);
@@ -978,7 +932,7 @@ private:
SSAUpdaterBulk SSAUpdate;
SmallVector<Use *, 16> UsesToRename;
- for (auto KV : NewDefs) {
+ for (const auto &KV : NewDefs) {
Instruction *I = KV.first;
BasicBlock *BB = I->getParent();
std::vector<Instruction *> Cloned = KV.second;
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 9c0b4d673145..d3fbe49439a8 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -69,15 +69,12 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -462,10 +459,10 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI,
"Should not hit the entry block because SI must be dominated by LI");
for (BasicBlock *Pred : predecessors(B)) {
PHITransAddr PredAddr = Addr;
- if (PredAddr.NeedsPHITranslationFromBlock(B)) {
- if (!PredAddr.IsPotentiallyPHITranslatable())
+ if (PredAddr.needsPHITranslationFromBlock(B)) {
+ if (!PredAddr.isPotentiallyPHITranslatable())
return false;
- if (PredAddr.PHITranslateValue(B, Pred, DT, false))
+ if (!PredAddr.translateValue(B, Pred, DT, false))
return false;
}
Value *TranslatedPtr = PredAddr.getAddr();
@@ -485,41 +482,75 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI,
return true;
}
-static void shortenAssignment(Instruction *Inst, uint64_t OldOffsetInBits,
- uint64_t OldSizeInBits, uint64_t NewSizeInBits,
- bool IsOverwriteEnd) {
- DIExpression::FragmentInfo DeadFragment;
- DeadFragment.SizeInBits = OldSizeInBits - NewSizeInBits;
- DeadFragment.OffsetInBits =
+static void shortenAssignment(Instruction *Inst, Value *OriginalDest,
+ uint64_t OldOffsetInBits, uint64_t OldSizeInBits,
+ uint64_t NewSizeInBits, bool IsOverwriteEnd) {
+ const DataLayout &DL = Inst->getModule()->getDataLayout();
+ uint64_t DeadSliceSizeInBits = OldSizeInBits - NewSizeInBits;
+ uint64_t DeadSliceOffsetInBits =
OldOffsetInBits + (IsOverwriteEnd ? NewSizeInBits : 0);
-
- auto CreateDeadFragExpr = [Inst, DeadFragment]() {
- // FIXME: This should be using the DIExpression in the Alloca's dbg.assign
- // for the variable, since that could also contain a fragment?
- return *DIExpression::createFragmentExpression(
- DIExpression::get(Inst->getContext(), std::nullopt),
+ auto SetDeadFragExpr = [](DbgAssignIntrinsic *DAI,
+ DIExpression::FragmentInfo DeadFragment) {
+ // createFragmentExpression expects an offset relative to the existing
+ // fragment offset if there is one.
+ uint64_t RelativeOffset = DeadFragment.OffsetInBits -
+ DAI->getExpression()
+ ->getFragmentInfo()
+ .value_or(DIExpression::FragmentInfo(0, 0))
+ .OffsetInBits;
+ if (auto NewExpr = DIExpression::createFragmentExpression(
+ DAI->getExpression(), RelativeOffset, DeadFragment.SizeInBits)) {
+ DAI->setExpression(*NewExpr);
+ return;
+ }
+ // Failed to create a fragment expression for this so discard the value,
+ // making this a kill location.
+ auto *Expr = *DIExpression::createFragmentExpression(
+ DIExpression::get(DAI->getContext(), std::nullopt),
DeadFragment.OffsetInBits, DeadFragment.SizeInBits);
+ DAI->setExpression(Expr);
+ DAI->setKillLocation();
};
// A DIAssignID to use so that the inserted dbg.assign intrinsics do not
// link to any instructions. Created in the loop below (once).
DIAssignID *LinkToNothing = nullptr;
+ LLVMContext &Ctx = Inst->getContext();
+ auto GetDeadLink = [&Ctx, &LinkToNothing]() {
+ if (!LinkToNothing)
+ LinkToNothing = DIAssignID::getDistinct(Ctx);
+ return LinkToNothing;
+ };
// Insert an unlinked dbg.assign intrinsic for the dead fragment after each
- // overlapping dbg.assign intrinsic.
- for (auto *DAI : at::getAssignmentMarkers(Inst)) {
- if (auto FragInfo = DAI->getExpression()->getFragmentInfo()) {
- if (!DIExpression::fragmentsOverlap(*FragInfo, DeadFragment))
- continue;
+ // overlapping dbg.assign intrinsic. The loop invalidates the iterators
+ // returned by getAssignmentMarkers so save a copy of the markers to iterate
+ // over.
+ auto LinkedRange = at::getAssignmentMarkers(Inst);
+ SmallVector<DbgAssignIntrinsic *> Linked(LinkedRange.begin(),
+ LinkedRange.end());
+ for (auto *DAI : Linked) {
+ std::optional<DIExpression::FragmentInfo> NewFragment;
+ if (!at::calculateFragmentIntersect(DL, OriginalDest, DeadSliceOffsetInBits,
+ DeadSliceSizeInBits, DAI,
+ NewFragment) ||
+ !NewFragment) {
+ // We couldn't calculate the intersecting fragment for some reason. Be
+ // cautious and unlink the whole assignment from the store.
+ DAI->setKillAddress();
+ DAI->setAssignId(GetDeadLink());
+ continue;
}
+ // No intersect.
+ if (NewFragment->SizeInBits == 0)
+ continue;
// Fragments overlap: insert a new dbg.assign for this dead part.
auto *NewAssign = cast<DbgAssignIntrinsic>(DAI->clone());
NewAssign->insertAfter(DAI);
- if (!LinkToNothing)
- LinkToNothing = DIAssignID::getDistinct(Inst->getContext());
- NewAssign->setAssignId(LinkToNothing);
- NewAssign->setExpression(CreateDeadFragExpr());
+ NewAssign->setAssignId(GetDeadLink());
+ if (NewFragment)
+ SetDeadFragExpr(NewAssign, *NewFragment);
NewAssign->setKillAddress();
}
}
@@ -596,8 +627,8 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart,
DeadIntrinsic->setLength(TrimmedLength);
DeadIntrinsic->setDestAlignment(PrefAlign);
+ Value *OrigDest = DeadIntrinsic->getRawDest();
if (!IsOverwriteEnd) {
- Value *OrigDest = DeadIntrinsic->getRawDest();
Type *Int8PtrTy =
Type::getInt8PtrTy(DeadIntrinsic->getContext(),
OrigDest->getType()->getPointerAddressSpace());
@@ -616,7 +647,7 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart,
}
// Update attached dbg.assign intrinsics. Assume 8-bit byte.
- shortenAssignment(DeadI, DeadStart * 8, DeadSize * 8, NewSize * 8,
+ shortenAssignment(DeadI, OrigDest, DeadStart * 8, DeadSize * 8, NewSize * 8,
IsOverwriteEnd);
// Finally update start and size of dead access.
@@ -730,7 +761,7 @@ tryToMergePartialOverlappingStores(StoreInst *KillingI, StoreInst *DeadI,
}
namespace {
-// Returns true if \p I is an intrisnic that does not read or write memory.
+// Returns true if \p I is an intrinsic that does not read or write memory.
bool isNoopIntrinsic(Instruction *I) {
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
@@ -740,7 +771,6 @@ bool isNoopIntrinsic(Instruction *I) {
case Intrinsic::launder_invariant_group:
case Intrinsic::assume:
return true;
- case Intrinsic::dbg_addr:
case Intrinsic::dbg_declare:
case Intrinsic::dbg_label:
case Intrinsic::dbg_value:
@@ -2039,7 +2069,6 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
const LoopInfo &LI) {
bool MadeChange = false;
- MSSA.ensureOptimizedUses();
DSEState State(F, AA, MSSA, DT, PDT, AC, TLI, LI);
// For each store:
for (unsigned I = 0; I < State.MemDefs.size(); I++) {
@@ -2241,79 +2270,3 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserve<LoopAnalysis>();
return PA;
}
-
-namespace {
-
-/// A legacy pass for the legacy pass manager that wraps \c DSEPass.
-class DSELegacyPass : public FunctionPass {
-public:
- static char ID; // Pass identification, replacement for typeid
-
- DSELegacyPass() : FunctionPass(ID) {
- initializeDSELegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
- PostDominatorTree &PDT =
- getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
- AssumptionCache &AC =
- getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-
- bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, AC, TLI, LI);
-
-#ifdef LLVM_ENABLE_STATS
- if (AreStatisticsEnabled())
- for (auto &I : instructions(F))
- NumRemainingStores += isa<StoreInst>(&I);
-#endif
-
- return Changed;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<PostDominatorTreeWrapperPass>();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addPreserved<PostDominatorTreeWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addRequired<AssumptionCacheTracker>();
- }
-};
-
-} // end anonymous namespace
-
-char DSELegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false,
- false)
-
-FunctionPass *llvm::createDeadStoreEliminationPass() {
- return new DSELegacyPass();
-}
diff --git a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp
index 303951643a0b..57d3f312186e 100644
--- a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp
+++ b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp
@@ -21,10 +21,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/DebugCounter.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BypassSlowDivision.h"
#include <optional>
@@ -371,6 +368,10 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
Mul->insertAfter(RemInst);
Sub->insertAfter(Mul);
+ // If DivInst has the exact flag, remove it. Otherwise this optimization
+ // may replace a well-defined value 'X % Y' with poison.
+ DivInst->dropPoisonGeneratingFlags();
+
// If X can be undef, X should be frozen first.
// For example, let's assume that Y = 1 & X = undef:
// %div = sdiv undef, 1 // %div = undef
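// Illustrative sketch, not from the patch: the decomposition recomputes the
// remainder from the quotient, so an 'exact' division (poison whenever the
// remainder is non-zero) must have the flag dropped before being reused here.
// The function name below is hypothetical.
static int decomposedSRem(int X, int Y) {
  int Div = X / Y;    // corresponds to DivInst, now without poison-generating flags
  return X - Div * Y; // the Mul and Sub inserted above; equals X % Y
}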
@@ -413,44 +414,6 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,
// Pass manager boilerplate below here.
-namespace {
-struct DivRemPairsLegacyPass : public FunctionPass {
- static char ID;
- DivRemPairsLegacyPass() : FunctionPass(ID) {
- initializeDivRemPairsLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.setPreservesCFG();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- FunctionPass::getAnalysisUsage(AU);
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- return optimizeDivRem(F, TTI, DT);
- }
-};
-} // namespace
-
-char DivRemPairsLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(DivRemPairsLegacyPass, "div-rem-pairs",
- "Hoist/decompose integer division and remainder", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_END(DivRemPairsLegacyPass, "div-rem-pairs",
- "Hoist/decompose integer division and remainder", false,
- false)
-FunctionPass *llvm::createDivRemPairsPass() {
- return new DivRemPairsLegacyPass();
-}
-
PreservedAnalyses DivRemPairsPass::run(Function &F,
FunctionAnalysisManager &FAM) {
TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index 26821c7ee81e..67e8e82e408f 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -218,6 +218,19 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
return true;
}
+static unsigned hashCallInst(CallInst *CI) {
+ // Don't CSE convergent calls in different basic blocks, because they
+ // implicitly depend on the set of threads that is currently executing.
+ if (CI->isConvergent()) {
+ return hash_combine(
+ CI->getOpcode(), CI->getParent(),
+ hash_combine_range(CI->value_op_begin(), CI->value_op_end()));
+ }
+ return hash_combine(
+ CI->getOpcode(),
+ hash_combine_range(CI->value_op_begin(), CI->value_op_end()));
+}
+
static unsigned getHashValueImpl(SimpleValue Val) {
Instruction *Inst = Val.Inst;
// Hash in all of the operands as pointers.
@@ -318,6 +331,11 @@ static unsigned getHashValueImpl(SimpleValue Val) {
return hash_combine(GCR->getOpcode(), GCR->getOperand(0),
GCR->getBasePtr(), GCR->getDerivedPtr());
+ // Don't CSE convergent calls in different basic blocks, because they
+ // implicitly depend on the set of threads that is currently executing.
+ if (CallInst *CI = dyn_cast<CallInst>(Inst))
+ return hashCallInst(CI);
+
// Mix in the opcode.
return hash_combine(
Inst->getOpcode(),
@@ -344,8 +362,16 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
if (LHSI->getOpcode() != RHSI->getOpcode())
return false;
- if (LHSI->isIdenticalToWhenDefined(RHSI))
+ if (LHSI->isIdenticalToWhenDefined(RHSI)) {
+ // Convergent calls implicitly depend on the set of threads that is
+ // currently executing, so conservatively return false if they are in
+ // different basic blocks.
+ if (CallInst *CI = dyn_cast<CallInst>(LHSI);
+ CI && CI->isConvergent() && LHSI->getParent() != RHSI->getParent())
+ return false;
+
return true;
+ }
// If we're not strictly identical, we still might be a commutable instruction
if (BinaryOperator *LHSBinOp = dyn_cast<BinaryOperator>(LHSI)) {
@@ -508,15 +534,21 @@ unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) {
Instruction *Inst = Val.Inst;
// Hash all of the operands as pointers and mix in the opcode.
- return hash_combine(
- Inst->getOpcode(),
- hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
+ return hashCallInst(cast<CallInst>(Inst));
}
bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) {
- Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst;
if (LHS.isSentinel() || RHS.isSentinel())
- return LHSI == RHSI;
+ return LHS.Inst == RHS.Inst;
+
+ CallInst *LHSI = cast<CallInst>(LHS.Inst);
+ CallInst *RHSI = cast<CallInst>(RHS.Inst);
+
+ // Convergent calls implicitly depend on the set of threads that is
+ // currently executing, so conservatively return false if they are in
+ // different basic blocks.
+ if (LHSI->isConvergent() && LHSI->getParent() != RHSI->getParent())
+ return false;
return LHSI->isIdenticalTo(RHSI);
}
@@ -578,12 +610,13 @@ public:
unsigned Generation = 0;
int MatchingId = -1;
bool IsAtomic = false;
+ bool IsLoad = false;
LoadValue() = default;
LoadValue(Instruction *Inst, unsigned Generation, unsigned MatchingId,
- bool IsAtomic)
+ bool IsAtomic, bool IsLoad)
: DefInst(Inst), Generation(Generation), MatchingId(MatchingId),
- IsAtomic(IsAtomic) {}
+ IsAtomic(IsAtomic), IsLoad(IsLoad) {}
};
using LoadMapAllocator =
@@ -802,17 +835,7 @@ private:
Type *getValueType() const {
// TODO: handle target-specific intrinsics.
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
- switch (II->getIntrinsicID()) {
- case Intrinsic::masked_load:
- return II->getType();
- case Intrinsic::masked_store:
- return II->getArgOperand(0)->getType();
- default:
- return nullptr;
- }
- }
- return getLoadStoreType(Inst);
+ return Inst->getAccessType();
}
bool mayReadFromMemory() const {
@@ -1476,6 +1499,9 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
+ if (InVal.IsLoad)
+ if (auto *I = dyn_cast<Instruction>(Op))
+ combineMetadataForCSE(I, &Inst, false);
if (!Inst.use_empty())
Inst.replaceAllUsesWith(Op);
salvageKnowledge(&Inst, &AC);
@@ -1490,7 +1516,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
AvailableLoads.insert(MemInst.getPointerOperand(),
LoadValue(&Inst, CurrentGeneration,
MemInst.getMatchingId(),
- MemInst.isAtomic()));
+ MemInst.isAtomic(),
+ MemInst.isLoad()));
LastStore = nullptr;
continue;
}
@@ -1614,7 +1641,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
AvailableLoads.insert(MemInst.getPointerOperand(),
LoadValue(&Inst, CurrentGeneration,
MemInst.getMatchingId(),
- MemInst.isAtomic()));
+ MemInst.isAtomic(),
+ MemInst.isLoad()));
// Remember that this was the last unordered store we saw for DSE. We
// don't yet handle DSE on ordered or volatile stores since we don't
@@ -1710,10 +1738,10 @@ void EarlyCSEPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<EarlyCSEPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (UseMemorySSA)
OS << "memssa";
- OS << ">";
+ OS << '>';
}
namespace {
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index f66d1b914b0b..ccca8bcc1a56 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -20,12 +20,9 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include <deque>
#define DEBUG_TYPE "float2int"
@@ -49,35 +46,6 @@ MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
cl::desc("Max integer bitwidth to consider in float2int"
"(default=64)"));
-namespace {
- struct Float2IntLegacyPass : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- Float2IntLegacyPass() : FunctionPass(ID) {
- initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- return Impl.runImpl(F, DT);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-
- private:
- Float2IntPass Impl;
- };
-}
-
-char Float2IntLegacyPass::ID = 0;
-INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false)
-
// Given a FCmp predicate, return a matching ICmp predicate if one
// exists, otherwise return BAD_ICMP_PREDICATE.
static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
@@ -187,7 +155,7 @@ void Float2IntPass::walkBackwards() {
Instruction *I = Worklist.back();
Worklist.pop_back();
- if (SeenInsts.find(I) != SeenInsts.end())
+ if (SeenInsts.contains(I))
// Seen already.
continue;
@@ -371,7 +339,7 @@ bool Float2IntPass::validateAndTransform() {
ConvertedToTy = I->getType();
for (User *U : I->users()) {
Instruction *UI = dyn_cast<Instruction>(U);
- if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
+ if (!UI || !SeenInsts.contains(UI)) {
LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
Fail = true;
break;
@@ -391,8 +359,9 @@ bool Float2IntPass::validateAndTransform() {
// The number of bits required is the maximum of the upper and
// lower limits, plus one so it can be signed.
- unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
- R.getUpper().getMinSignedBits()) + 1;
+ unsigned MinBW = std::max(R.getLower().getSignificantBits(),
+ R.getUpper().getSignificantBits()) +
+ 1;
LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
// If we've run off the realms of the exactly representable integers,
@@ -427,7 +396,7 @@ bool Float2IntPass::validateAndTransform() {
}
Value *Float2IntPass::convert(Instruction *I, Type *ToTy) {
- if (ConvertedInsts.find(I) != ConvertedInsts.end())
+ if (ConvertedInsts.contains(I))
// Already converted this instruction.
return ConvertedInsts[I];
@@ -528,9 +497,6 @@ bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
return Modified;
}
-namespace llvm {
-FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }
-
PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {
const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
if (!runImpl(F, DT))
@@ -540,4 +506,3 @@ PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserveSet<CFGAnalyses>();
return PA;
}
-} // End namespace llvm
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 6158894e3437..03e8a2507b45 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -94,6 +94,8 @@ STATISTIC(NumGVNSimpl, "Number of instructions simplified");
STATISTIC(NumGVNEqProp, "Number of equalities propagated");
STATISTIC(NumPRELoad, "Number of loads PRE'd");
STATISTIC(NumPRELoopLoad, "Number of loop loads PRE'd");
+STATISTIC(NumPRELoadMoved2CEPred,
+ "Number of loads moved to predecessor of a critical edge in PRE");
STATISTIC(IsValueFullyAvailableInBlockNumSpeculationsMax,
"Number of blocks speculated as available in "
@@ -127,6 +129,11 @@ static cl::opt<uint32_t> MaxNumVisitedInsts(
cl::desc("Max number of visited instructions when trying to find "
"dominating value of select dependency (default = 100)"));
+static cl::opt<uint32_t> MaxNumInsnsPerBlock(
+ "gvn-max-num-insns", cl::Hidden, cl::init(100),
+ cl::desc("Max number of instructions to scan in each basic block in GVN "
+ "(default = 100)"));
+
struct llvm::GVNPass::Expression {
uint32_t opcode;
bool commutative = false;
@@ -416,10 +423,9 @@ GVNPass::Expression GVNPass::ValueTable::createGEPExpr(GetElementPtrInst *GEP) {
unsigned BitWidth = DL.getIndexTypeSizeInBits(PtrTy);
MapVector<Value *, APInt> VariableOffsets;
APInt ConstantOffset(BitWidth, 0);
- if (PtrTy->isOpaquePointerTy() &&
- GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) {
- // For opaque pointers, convert into offset representation, to recognize
- // equivalent address calculations that use different type encoding.
+ if (GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) {
+ // Convert into offset representation, to recognize equivalent address
+ // calculations that use different type encoding.
LLVMContext &Context = GEP->getContext();
E.opcode = GEP->getOpcode();
E.type = nullptr;
@@ -432,8 +438,8 @@ GVNPass::Expression GVNPass::ValueTable::createGEPExpr(GetElementPtrInst *GEP) {
E.varargs.push_back(
lookupOrAdd(ConstantInt::get(Context, ConstantOffset)));
} else {
- // If converting to offset representation fails (for typed pointers and
- // scalable vectors), fall back to type-based implementation:
+ // If converting to offset representation fails (for scalable vectors),
+ // fall back to type-based implementation:
E.opcode = GEP->getOpcode();
E.type = GEP->getSourceElementType();
for (Use &Op : GEP->operands())
@@ -461,28 +467,34 @@ void GVNPass::ValueTable::add(Value *V, uint32_t num) {
}
uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) {
- if (AA->doesNotAccessMemory(C) &&
- // FIXME: Currently the calls which may access the thread id may
- // be considered as not accessing the memory. But this is
- // problematic for coroutines, since coroutines may resume in a
- // different thread. So we disable the optimization here for the
- // correctness. However, it may block many other correct
- // optimizations. Revert this one when we detect the memory
- // accessing kind more precisely.
- !C->getFunction()->isPresplitCoroutine()) {
+ // FIXME: Currently the calls which may access the thread id may
+ // be considered as not accessing the memory. But this is
+ // problematic for coroutines, since coroutines may resume in a
+ // different thread. So we disable the optimization here for the
+ // correctness. However, it may block many other correct
+ // optimizations. Revert this one when we detect the memory
+ // accessing kind more precisely.
+ if (C->getFunction()->isPresplitCoroutine()) {
+ valueNumbering[C] = nextValueNumber;
+ return nextValueNumber++;
+ }
+
+ // Do not combine convergent calls since they implicitly depend on the set of
+ // threads that is currently executing, and they might be in different basic
+ // blocks.
+ if (C->isConvergent()) {
+ valueNumbering[C] = nextValueNumber;
+ return nextValueNumber++;
+ }
+
+ if (AA->doesNotAccessMemory(C)) {
Expression exp = createExpr(C);
uint32_t e = assignExpNewValueNum(exp).first;
valueNumbering[C] = e;
return e;
- } else if (MD && AA->onlyReadsMemory(C) &&
- // FIXME: Currently the calls which may access the thread id may
- // be considered as not accessing the memory. But this is
- // problematic for coroutines, since coroutines may resume in a
- // different thread. So we disable the optimization here for the
- // correctness. However, it may block many other correct
- // optimizations. Revert this one when we detect the memory
- // accessing kind more precisely.
- !C->getFunction()->isPresplitCoroutine()) {
+ }
+
+ if (MD && AA->onlyReadsMemory(C)) {
Expression exp = createExpr(C);
auto ValNum = assignExpNewValueNum(exp);
if (ValNum.second) {
@@ -572,10 +584,10 @@ uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) {
uint32_t v = lookupOrAdd(cdep);
valueNumbering[C] = v;
return v;
- } else {
- valueNumbering[C] = nextValueNumber;
- return nextValueNumber++;
}
+
+ valueNumbering[C] = nextValueNumber;
+ return nextValueNumber++;
}
/// Returns true if a value number exists for the specified value.
@@ -708,10 +720,8 @@ void GVNPass::ValueTable::erase(Value *V) {
/// verifyRemoved - Verify that the value is removed from all internal data
/// structures.
void GVNPass::ValueTable::verifyRemoved(const Value *V) const {
- for (DenseMap<Value*, uint32_t>::const_iterator
- I = valueNumbering.begin(), E = valueNumbering.end(); I != E; ++I) {
- assert(I->first != V && "Inst still occurs in value numbering map!");
- }
+ assert(!valueNumbering.contains(V) &&
+ "Inst still occurs in value numbering map!");
}
//===----------------------------------------------------------------------===//
@@ -772,7 +782,7 @@ void GVNPass::printPipeline(
static_cast<PassInfoMixin<GVNPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (Options.AllowPRE != std::nullopt)
OS << (*Options.AllowPRE ? "" : "no-") << "pre;";
if (Options.AllowLoadPRE != std::nullopt)
@@ -782,7 +792,7 @@ void GVNPass::printPipeline(
<< "split-backedge-load-pre;";
if (Options.AllowMemDep != std::nullopt)
OS << (*Options.AllowMemDep ? "" : "no-") << "memdep";
- OS << ">";
+ OS << '>';
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -930,6 +940,18 @@ static bool IsValueFullyAvailableInBlock(
return !UnavailableBB;
}
+/// If the specified OldValue exists in ValuesPerBlock, replace its value with
+/// NewValue.
+static void replaceValuesPerBlockEntry(
+ SmallVectorImpl<AvailableValueInBlock> &ValuesPerBlock, Value *OldValue,
+ Value *NewValue) {
+ for (AvailableValueInBlock &V : ValuesPerBlock) {
+ if ((V.AV.isSimpleValue() && V.AV.getSimpleValue() == OldValue) ||
+ (V.AV.isCoercedLoadValue() && V.AV.getCoercedLoadValue() == OldValue))
+ V = AvailableValueInBlock::get(V.BB, NewValue);
+ }
+}
+
/// Given a set of loads specified by ValuesPerBlock,
/// construct SSA form, allowing us to eliminate Load. This returns the value
/// that should be used at Load's definition site.
@@ -986,7 +1008,7 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load,
if (isSimpleValue()) {
Res = getSimpleValue();
if (Res->getType() != LoadTy) {
- Res = getStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL);
+ Res = getValueForLoad(Res, Offset, LoadTy, InsertPt, DL);
LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset
<< " " << *getSimpleValue() << '\n'
@@ -997,14 +1019,23 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load,
LoadInst *CoercedLoad = getCoercedLoadValue();
if (CoercedLoad->getType() == LoadTy && Offset == 0) {
Res = CoercedLoad;
+ combineMetadataForCSE(CoercedLoad, Load, false);
} else {
- Res = getLoadValueForLoad(CoercedLoad, Offset, LoadTy, InsertPt, DL);
- // We would like to use gvn.markInstructionForDeletion here, but we can't
- // because the load is already memoized into the leader map table that GVN
- // tracks. It is potentially possible to remove the load from the table,
- // but then there all of the operations based on it would need to be
- // rehashed. Just leave the dead load around.
- gvn.getMemDep().removeInstruction(CoercedLoad);
+ Res = getValueForLoad(CoercedLoad, Offset, LoadTy, InsertPt, DL);
+ // We are adding a new user for this load, for which the original
+ // metadata may not hold. Additionally, the new load may have a different
+ // size and type, so their metadata cannot be combined in any
+ // straightforward way.
+ // Drop all metadata that is not known to cause immediate UB on violation,
+ // unless the load has !noundef, in which case all metadata violations
+ // will be promoted to UB.
+ // TODO: We can combine noalias/alias.scope metadata here, because it is
+ // independent of the load type.
+ if (!CoercedLoad->hasMetadata(LLVMContext::MD_noundef))
+ CoercedLoad->dropUnknownNonDebugMetadata(
+ {LLVMContext::MD_dereferenceable,
+ LLVMContext::MD_dereferenceable_or_null,
+ LLVMContext::MD_invariant_load, LLVMContext::MD_invariant_group});
LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset
<< " " << *getCoercedLoadValue() << '\n'
<< *Res << '\n'
@@ -1314,9 +1345,67 @@ void GVNPass::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps,
"post condition violation");
}
+/// Given the following code, v1 is partially available on some edges, but not
+/// available on the edge from PredBB. This function tries to find if there is
+/// another identical load in the other successor of PredBB.
+///
+/// v0 = load %addr
+/// br %LoadBB
+///
+/// LoadBB:
+/// v1 = load %addr
+/// ...
+///
+/// PredBB:
+/// ...
+/// br %cond, label %LoadBB, label %SuccBB
+///
+/// SuccBB:
+/// v2 = load %addr
+/// ...
+///
+LoadInst *GVNPass::findLoadToHoistIntoPred(BasicBlock *Pred, BasicBlock *LoadBB,
+ LoadInst *Load) {
+ // For simplicity we only handle the case where Pred has exactly two successors.
+ auto *Term = Pred->getTerminator();
+ if (Term->getNumSuccessors() != 2 || Term->isExceptionalTerminator())
+ return nullptr;
+ auto *SuccBB = Term->getSuccessor(0);
+ if (SuccBB == LoadBB)
+ SuccBB = Term->getSuccessor(1);
+ if (!SuccBB->getSinglePredecessor())
+ return nullptr;
+
+ unsigned int NumInsts = MaxNumInsnsPerBlock;
+ for (Instruction &Inst : *SuccBB) {
+ if (Inst.isDebugOrPseudoInst())
+ continue;
+ if (--NumInsts == 0)
+ return nullptr;
+
+ if (!Inst.isIdenticalTo(Load))
+ continue;
+
+ MemDepResult Dep = MD->getDependency(&Inst);
+ // If an identical load doesn't depend on any local instructions, it can
+ // be safely moved to PredBB.
+ // Also check for the implicit control flow instructions. See the comments
+ // in PerformLoadPRE for details.
+ if (Dep.isNonLocal() && !ICF->isDominatedByICFIFromSameBlock(&Inst))
+ return cast<LoadInst>(&Inst);
+
+ // Otherwise something in the same BB clobbers the memory, so we can't
+ // move this load and the later one to PredBB.
+ return nullptr;
+ }
+
+ return nullptr;
+}
+
void GVNPass::eliminatePartiallyRedundantLoad(
LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
- MapVector<BasicBlock *, Value *> &AvailableLoads) {
+ MapVector<BasicBlock *, Value *> &AvailableLoads,
+ MapVector<BasicBlock *, LoadInst *> *CriticalEdgePredAndLoad) {
for (const auto &AvailableLoad : AvailableLoads) {
BasicBlock *UnavailableBlock = AvailableLoad.first;
Value *LoadPtr = AvailableLoad.second;
@@ -1370,10 +1459,29 @@ void GVNPass::eliminatePartiallyRedundantLoad(
AvailableValueInBlock::get(UnavailableBlock, NewLoad));
MD->invalidateCachedPointerInfo(LoadPtr);
LLVM_DEBUG(dbgs() << "GVN INSERTED " << *NewLoad << '\n');
+
+ // For each PredBB in CriticalEdgePredAndLoad we need to replace the uses of
+ // the old load instruction with the newly created load instruction.
+ if (CriticalEdgePredAndLoad) {
+ auto I = CriticalEdgePredAndLoad->find(UnavailableBlock);
+ if (I != CriticalEdgePredAndLoad->end()) {
+ ++NumPRELoadMoved2CEPred;
+ ICF->insertInstructionTo(NewLoad, UnavailableBlock);
+ LoadInst *OldLoad = I->second;
+ combineMetadataForCSE(NewLoad, OldLoad, false);
+ OldLoad->replaceAllUsesWith(NewLoad);
+ replaceValuesPerBlockEntry(ValuesPerBlock, OldLoad, NewLoad);
+ if (uint32_t ValNo = VN.lookup(OldLoad, false))
+ removeFromLeaderTable(ValNo, OldLoad, OldLoad->getParent());
+ VN.erase(OldLoad);
+ removeInstruction(OldLoad);
+ }
+ }
}
// Perform PHI construction.
Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this);
+ // ConstructSSAForLoadSet is responsible for combining metadata.
Load->replaceAllUsesWith(V);
if (isa<PHINode>(V))
V->takeName(Load);
@@ -1456,7 +1564,12 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
for (BasicBlock *UnavailableBB : UnavailableBlocks)
FullyAvailableBlocks[UnavailableBB] = AvailabilityState::Unavailable;
- SmallVector<BasicBlock *, 4> CriticalEdgePred;
+ // The edge from Pred to LoadBB is a critical edge and will be split.
+ SmallVector<BasicBlock *, 4> CriticalEdgePredSplit;
+ // The edge from Pred to LoadBB is a critical edge, and the other successor
+ // of Pred contains a load that can be moved into Pred. This data structure
+ // maps each such Pred to the movable load.
+ MapVector<BasicBlock *, LoadInst *> CriticalEdgePredAndLoad;
for (BasicBlock *Pred : predecessors(LoadBB)) {
// If any predecessor block is an EH pad that does not allow non-PHI
// instructions before the terminator, we can't PRE the load.
@@ -1496,7 +1609,10 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
return false;
}
- CriticalEdgePred.push_back(Pred);
+ if (LoadInst *LI = findLoadToHoistIntoPred(Pred, LoadBB, Load))
+ CriticalEdgePredAndLoad[Pred] = LI;
+ else
+ CriticalEdgePredSplit.push_back(Pred);
} else {
// Only add the predecessors that will not be split for now.
PredLoads[Pred] = nullptr;
@@ -1504,31 +1620,38 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
}
// Decide whether PRE is profitable for this load.
- unsigned NumUnavailablePreds = PredLoads.size() + CriticalEdgePred.size();
+ unsigned NumInsertPreds = PredLoads.size() + CriticalEdgePredSplit.size();
+ unsigned NumUnavailablePreds = NumInsertPreds +
+ CriticalEdgePredAndLoad.size();
assert(NumUnavailablePreds != 0 &&
"Fully available value should already be eliminated!");
+ (void)NumUnavailablePreds;
- // If this load is unavailable in multiple predecessors, reject it.
+ // If we need to insert new load in multiple predecessors, reject it.
// FIXME: If we could restructure the CFG, we could make a common pred with
// all the preds that don't have an available Load and insert a new load into
// that one block.
- if (NumUnavailablePreds != 1)
+ if (NumInsertPreds > 1)
return false;
// Now we know where we will insert load. We must ensure that it is safe
// to speculatively execute the load at that points.
if (MustEnsureSafetyOfSpeculativeExecution) {
- if (CriticalEdgePred.size())
+ if (CriticalEdgePredSplit.size())
if (!isSafeToSpeculativelyExecute(Load, LoadBB->getFirstNonPHI(), AC, DT))
return false;
for (auto &PL : PredLoads)
if (!isSafeToSpeculativelyExecute(Load, PL.first->getTerminator(), AC,
DT))
return false;
+ for (auto &CEP : CriticalEdgePredAndLoad)
+ if (!isSafeToSpeculativelyExecute(Load, CEP.first->getTerminator(), AC,
+ DT))
+ return false;
}
// Split critical edges, and update the unavailable predecessors accordingly.
- for (BasicBlock *OrigPred : CriticalEdgePred) {
+ for (BasicBlock *OrigPred : CriticalEdgePredSplit) {
BasicBlock *NewPred = splitCriticalEdges(OrigPred, LoadBB);
assert(!PredLoads.count(OrigPred) && "Split edges shouldn't be in map!");
PredLoads[NewPred] = nullptr;
@@ -1536,6 +1659,9 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
<< LoadBB->getName() << '\n');
}
+ for (auto &CEP : CriticalEdgePredAndLoad)
+ PredLoads[CEP.first] = nullptr;
+
// Check if the load can safely be moved to all the unavailable predecessors.
bool CanDoPRE = true;
const DataLayout &DL = Load->getModule()->getDataLayout();
@@ -1555,8 +1681,8 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
BasicBlock *Cur = Load->getParent();
while (Cur != LoadBB) {
PHITransAddr Address(LoadPtr, DL, AC);
- LoadPtr = Address.PHITranslateWithInsertion(
- Cur, Cur->getSinglePredecessor(), *DT, NewInsts);
+ LoadPtr = Address.translateWithInsertion(Cur, Cur->getSinglePredecessor(),
+ *DT, NewInsts);
if (!LoadPtr) {
CanDoPRE = false;
break;
@@ -1566,8 +1692,8 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
if (LoadPtr) {
PHITransAddr Address(LoadPtr, DL, AC);
- LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, *DT,
- NewInsts);
+ LoadPtr = Address.translateWithInsertion(LoadBB, UnavailablePred, *DT,
+ NewInsts);
}
// If we couldn't find or insert a computation of this phi translated value,
// we fail PRE.
@@ -1592,7 +1718,7 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
}
// HINT: Don't revert the edge-splitting as following transformation may
// also need to split these critical edges.
- return !CriticalEdgePred.empty();
+ return !CriticalEdgePredSplit.empty();
}
// Okay, we can eliminate this load by inserting a reload in the predecessor
@@ -1617,7 +1743,8 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
VN.lookupOrAdd(I);
}
- eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, PredLoads);
+ eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, PredLoads,
+ &CriticalEdgePredAndLoad);
++NumPRELoad;
return true;
}
@@ -1696,7 +1823,8 @@ bool GVNPass::performLoopLoadPRE(LoadInst *Load,
AvailableLoads[Preheader] = LoadPtr;
LLVM_DEBUG(dbgs() << "GVN REMOVING PRE LOOP LOAD: " << *Load << '\n');
- eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, AvailableLoads);
+ eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, AvailableLoads,
+ /*CriticalEdgePredAndLoad*/ nullptr);
++NumPRELoopLoad;
return true;
}
@@ -1772,6 +1900,7 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
// Perform PHI construction.
Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this);
+ // ConstructSSAForLoadSet is responsible for combining metadata.
Load->replaceAllUsesWith(V);
if (isa<PHINode>(V))
@@ -1823,7 +1952,7 @@ static bool impliesEquivalanceIfTrue(CmpInst* Cmp) {
if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
return true;
if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
- return true;;
+ return true;
// TODO: Handle vector floating point constants
}
return false;
@@ -1849,7 +1978,7 @@ static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
return true;
if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
- return true;;
+ return true;
// TODO: Handle vector floating point constants
}
return false;
@@ -1907,10 +2036,14 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
MSSAU->insertDef(cast<MemoryDef>(NewDef), /*RenameUses=*/false);
}
}
- if (isAssumeWithEmptyBundle(*IntrinsicI))
+ if (isAssumeWithEmptyBundle(*IntrinsicI)) {
markInstructionForDeletion(IntrinsicI);
+ return true;
+ }
return false;
- } else if (isa<Constant>(V)) {
+ }
+
+ if (isa<Constant>(V)) {
// If it's not false, and constant, it must evaluate to true. This means our
// assume is assume(true), and thus, pointless, and we don't want to do
// anything more here.
@@ -2043,8 +2176,8 @@ bool GVNPass::processLoad(LoadInst *L) {
Value *AvailableValue = AV->MaterializeAdjustedValue(L, L, *this);
- // Replace the load!
- patchAndReplaceAllUsesWith(L, AvailableValue);
+ // MaterializeAdjustedValue is responsible for combining metadata.
+ L->replaceAllUsesWith(AvailableValue);
markInstructionForDeletion(L);
if (MSSAU)
MSSAU->removeMemoryAccess(L);
@@ -2543,7 +2676,9 @@ bool GVNPass::processInstruction(Instruction *I) {
// Failure, just remember this instance for future use.
addToLeaderTable(Num, I, I->getParent());
return false;
- } else if (Repl == I) {
+ }
+
+ if (Repl == I) {
// If I was the result of a shortcut PRE, it might already be in the table
// and the best replacement for itself. Nothing to do.
return false;
@@ -2669,12 +2804,7 @@ bool GVNPass::processBlock(BasicBlock *BB) {
LLVM_DEBUG(dbgs() << "GVN removed: " << *I << '\n');
salvageKnowledge(I, AC);
salvageDebugInfo(*I);
- if (MD) MD->removeInstruction(I);
- if (MSSAU)
- MSSAU->removeMemoryAccess(I);
- LLVM_DEBUG(verifyRemoved(I));
- ICF->removeInstruction(I);
- I->eraseFromParent();
+ removeInstruction(I);
}
InstrsToErase.clear();
@@ -2765,9 +2895,6 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) {
// We don't currently value number ANY inline asm calls.
if (CallB->isInlineAsm())
return false;
- // Don't do PRE on convergent calls.
- if (CallB->isConvergent())
- return false;
}
uint32_t ValNo = VN.lookup(CurInst);
@@ -2855,7 +2982,9 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) {
PREInstr = CurInst->clone();
if (!performScalarPREInsertion(PREInstr, PREPred, CurrentBlock, ValNo)) {
// If we failed insertion, make sure we remove the instruction.
- LLVM_DEBUG(verifyRemoved(PREInstr));
+#ifndef NDEBUG
+ verifyRemoved(PREInstr);
+#endif
PREInstr->deleteValue();
return false;
}
@@ -2894,15 +3023,7 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) {
removeFromLeaderTable(ValNo, CurInst, CurrentBlock);
LLVM_DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n');
- if (MD)
- MD->removeInstruction(CurInst);
- if (MSSAU)
- MSSAU->removeMemoryAccess(CurInst);
- LLVM_DEBUG(verifyRemoved(CurInst));
- // FIXME: Intended to be markInstructionForDeletion(CurInst), but it causes
- // some assertion failures.
- ICF->removeInstruction(CurInst);
- CurInst->eraseFromParent();
+ removeInstruction(CurInst);
++NumGVNInstr;
return true;
@@ -2998,6 +3119,17 @@ void GVNPass::cleanupGlobalSets() {
InvalidBlockRPONumbers = true;
}
+void GVNPass::removeInstruction(Instruction *I) {
+ if (MD) MD->removeInstruction(I);
+ if (MSSAU)
+ MSSAU->removeMemoryAccess(I);
+#ifndef NDEBUG
+ verifyRemoved(I);
+#endif
+ ICF->removeInstruction(I);
+ I->eraseFromParent();
+}
+
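
A minimal standalone sketch (plain C++, not LLVM's API) of the pattern behind the new removeInstruction helper: every side table that may still reference the instruction is notified in one place before the instruction is erased, so call sites cannot forget an update. The stub names below are made up for the sketch.

#include <cstdio>
#include <iterator>
#include <list>
#include <string>

// Stand-ins for the analyses GVN keeps in sync (MemoryDependenceResults,
// MemorySSAUpdater and ImplicitControlFlowTracking in the real pass).
struct AnalysisStub {
  const char *Name;
  void removeInstruction(const std::string &Inst) {
    std::printf("%s forgets %s\n", Name, Inst.c_str());
  }
};

struct MiniGVN {
  std::list<std::string> Block; // stand-in for a basic block
  AnalysisStub *MD = nullptr, *MSSAU = nullptr, *ICF = nullptr;

  // One helper used by every call site, so no update can be forgotten.
  void removeInstruction(std::list<std::string>::iterator I) {
    if (MD)
      MD->removeInstruction(*I);
    if (MSSAU)
      MSSAU->removeInstruction(*I);
    if (ICF)
      ICF->removeInstruction(*I);
    Block.erase(I); // erase only after all analyses have been updated
  }
};

int main() {
  AnalysisStub MD{"MemDep"}, MSSAU{"MemorySSA"}, ICF{"ICF"};
  MiniGVN GVN{{"load", "add", "store"}, &MD, &MSSAU, &ICF};
  GVN.removeInstruction(std::next(GVN.Block.begin()));
  std::printf("%zu instructions left\n", GVN.Block.size());
  return 0;
}
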
/// Verify that the specified instruction does not occur in our
/// internal data structures.
void GVNPass::verifyRemoved(const Instruction *Inst) const {
diff --git a/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
index bbff497b7d92..b564f00eb9d1 100644
--- a/llvm/lib/Transforms/Scalar/GVNHoist.cpp
+++ b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
@@ -62,13 +62,10 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
@@ -519,39 +516,6 @@ private:
std::pair<unsigned, unsigned> hoistExpressions(Function &F);
};
-class GVNHoistLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- GVNHoistLegacyPass() : FunctionPass(ID) {
- initializeGVNHoistLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto &MD = getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
- auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
-
- GVNHoist G(&DT, &PDT, &AA, &MD, &MSSA);
- return G.run(F);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<PostDominatorTreeWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<MemoryDependenceWrapperPass>();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-};
-
bool GVNHoist::run(Function &F) {
NumFuncArgs = F.arg_size();
VN.setDomTree(DT);
@@ -808,15 +772,20 @@ bool GVNHoist::valueAnticipable(CHIArgs C, Instruction *TI) const {
void GVNHoist::checkSafety(CHIArgs C, BasicBlock *BB, GVNHoist::InsKind K,
SmallVectorImpl<CHIArg> &Safe) {
int NumBBsOnAllPaths = MaxNumberOfBBSInPath;
+ const Instruction *T = BB->getTerminator();
for (auto CHI : C) {
Instruction *Insn = CHI.I;
if (!Insn) // No instruction was inserted in this CHI.
continue;
+ // If the Terminator is some kind of "exotic terminator" that produces a
+ // value (such as InvokeInst, CallBrInst, or CatchSwitchInst) which the CHI
+ // uses, it is not safe to hoist the use above the def.
+ if (!T->use_empty() && is_contained(Insn->operands(), cast<const Value>(T)))
+ continue;
if (K == InsKind::Scalar) {
if (safeToHoistScalar(BB, Insn->getParent(), NumBBsOnAllPaths))
Safe.push_back(CHI);
} else {
- auto *T = BB->getTerminator();
if (MemoryUseOrDef *UD = MSSA->getMemoryAccess(Insn))
if (safeToHoistLdSt(T, Insn, UD, K, NumBBsOnAllPaths))
Safe.push_back(CHI);
@@ -1251,17 +1220,3 @@ PreservedAnalyses GVNHoistPass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-char GVNHoistLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(GVNHoistLegacyPass, "gvn-hoist",
- "Early GVN Hoisting of Expressions", false, false)
-INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_END(GVNHoistLegacyPass, "gvn-hoist",
- "Early GVN Hoisting of Expressions", false, false)
-
-FunctionPass *llvm::createGVNHoistPass() { return new GVNHoistLegacyPass(); }
diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp
index 5fb8a77051fb..26a6978656e6 100644
--- a/llvm/lib/Transforms/Scalar/GVNSink.cpp
+++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp
@@ -54,8 +54,6 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ArrayRecycler.h"
#include "llvm/Support/AtomicOrdering.h"
@@ -63,7 +61,6 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/GVNExpression.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -154,7 +151,7 @@ public:
void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) {
for (auto II = Insts.begin(); II != Insts.end();) {
- if (!llvm::is_contained(Blocks, (*II)->getParent())) {
+ if (!Blocks.contains((*II)->getParent())) {
ActiveBlocks.remove((*II)->getParent());
II = Insts.erase(II);
} else {
@@ -272,7 +269,7 @@ public:
auto VI = Values.begin();
while (BI != Blocks.end()) {
assert(VI != Values.end());
- if (!llvm::is_contained(NewBlocks, *BI)) {
+ if (!NewBlocks.contains(*BI)) {
BI = Blocks.erase(BI);
VI = Values.erase(VI);
} else {
@@ -886,29 +883,6 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks,
NumRemoved += Insts.size() - 1;
}
-////////////////////////////////////////////////////////////////////////////////
-// Pass machinery / boilerplate
-
-class GVNSinkLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- GVNSinkLegacyPass() : FunctionPass(ID) {
- initializeGVNSinkLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- GVNSink G;
- return G.run(F);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-};
-
} // end anonymous namespace
PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) {
@@ -917,14 +891,3 @@ PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) {
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
-
-char GVNSinkLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(GVNSinkLegacyPass, "gvn-sink",
- "Early GVN sinking of Expressions", false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_END(GVNSinkLegacyPass, "gvn-sink",
- "Early GVN sinking of Expressions", false, false)
-
-FunctionPass *llvm::createGVNSinkPass() { return new GVNSinkLegacyPass(); }
diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
index abe0babc3f12..62b40a23e38c 100644
--- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp
+++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
@@ -69,6 +69,7 @@ using namespace llvm;
STATISTIC(GuardsEliminated, "Number of eliminated guards");
STATISTIC(CondBranchEliminated, "Number of eliminated conditional branches");
+STATISTIC(FreezeAdded, "Number of freeze instructions introduced");
static cl::opt<bool>
WidenBranchGuards("guard-widening-widen-branch-guards", cl::Hidden,
@@ -113,6 +114,23 @@ static void eliminateGuard(Instruction *GuardInst, MemorySSAUpdater *MSSAU) {
++GuardsEliminated;
}
+/// Find a point at which the widened condition of \p Guard should be inserted.
+/// When it is represented as an intrinsic call, we can do it right before the
+/// call instruction. However, when we are dealing with a widenable branch, we
+/// must account for the following situation: widening should not turn a
+/// loop-invariant condition into a loop-variant one. This means that if the
+/// widenable.condition() call is invariant (w.r.t. any loop), the new wide
+/// condition should stay invariant. Otherwise there can be a miscompile, like
+/// the one described at https://github.com/llvm/llvm-project/issues/60234. The
+/// safest way to do it is to expand the new condition at WC's block.
+static Instruction *findInsertionPointForWideCondition(Instruction *Guard) {
+ Value *Condition, *WC;
+ BasicBlock *IfTrue, *IfFalse;
+ if (parseWidenableBranch(Guard, Condition, WC, IfTrue, IfFalse))
+ return cast<Instruction>(WC);
+ return Guard;
+}
+
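
A rough standalone illustration of the invariance concern described above, assuming made-up helpers widenableCond() and extraCheck() that both happen to be loop-invariant; it only shows why the placement of the combined condition matters, not how the pass builds it.

#include <cstdio>

// Made-up stand-ins for a loop-invariant widenable condition and an extra
// check being widened into it; neither is LLVM API.
static bool widenableCond() { return true; }
static bool extraCheck() { return true; }

int main() {
  // Expanding the widened condition at the widenable condition's definition
  // (here: before the loop) keeps it loop-invariant; it is computed once.
  bool Wide = widenableCond() && extraCheck();
  int Sum = 0;
  for (int I = 0; I < 8; ++I) {
    if (!Wide)
      break; // would deoptimize
    Sum += I;
  }

  // Expanding it at the guard inside the loop would recompute the conjunction
  // every iteration, turning an invariant condition into a loop-variant one.
  for (int I = 0; I < 8; ++I) {
    bool WideAtGuard = widenableCond() && extraCheck();
    if (!WideAtGuard)
      break;
    Sum += I;
  }
  std::printf("%d\n", Sum);
  return 0;
}
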
class GuardWideningImpl {
DominatorTree &DT;
PostDominatorTree *PDT;
@@ -170,16 +188,16 @@ class GuardWideningImpl {
bool InvertCond);
/// Helper to check if \p V can be hoisted to \p InsertPos.
- bool isAvailableAt(const Value *V, const Instruction *InsertPos) const {
+ bool canBeHoistedTo(const Value *V, const Instruction *InsertPos) const {
SmallPtrSet<const Instruction *, 8> Visited;
- return isAvailableAt(V, InsertPos, Visited);
+ return canBeHoistedTo(V, InsertPos, Visited);
}
- bool isAvailableAt(const Value *V, const Instruction *InsertPos,
- SmallPtrSetImpl<const Instruction *> &Visited) const;
+ bool canBeHoistedTo(const Value *V, const Instruction *InsertPos,
+ SmallPtrSetImpl<const Instruction *> &Visited) const;
/// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c
- /// isAvailableAt returned true.
+ /// canBeHoistedTo returned true.
void makeAvailableAt(Value *V, Instruction *InsertPos) const;
/// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try
@@ -192,6 +210,10 @@ class GuardWideningImpl {
bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt,
Value *&Result, bool InvertCondition);
+  /// Adds a freeze to Orig and pushes it as far as possible, very aggressively.
+  /// Also replaces all uses of the frozen instruction with the frozen version.
+ Value *freezeAndPush(Value *Orig, Instruction *InsertPt);
+
/// Represents a range check of the form \c Base + \c Offset u< \c Length,
/// with the constraint that \c Length is not negative. \c CheckInst is the
/// pre-existing instruction in the IR that computes the result of this range
@@ -263,8 +285,8 @@ class GuardWideningImpl {
void widenGuard(Instruction *ToWiden, Value *NewCondition,
bool InvertCondition) {
Value *Result;
-
- widenCondCommon(getCondition(ToWiden), NewCondition, ToWiden, Result,
+ Instruction *InsertPt = findInsertionPointForWideCondition(ToWiden);
+ widenCondCommon(getCondition(ToWiden), NewCondition, InsertPt, Result,
InvertCondition);
if (isGuardAsWidenableBranch(ToWiden)) {
setWidenableBranchCond(cast<BranchInst>(ToWiden), Result);
@@ -422,7 +444,10 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr,
HoistingOutOfLoop = true;
}
- if (!isAvailableAt(getCondition(DominatedInstr), DominatingGuard))
+ auto *WideningPoint = findInsertionPointForWideCondition(DominatingGuard);
+ if (!canBeHoistedTo(getCondition(DominatedInstr), WideningPoint))
+ return WS_IllegalOrNegative;
+ if (!canBeHoistedTo(getCondition(DominatingGuard), WideningPoint))
return WS_IllegalOrNegative;
// If the guard was conditional executed, it may never be reached
@@ -440,30 +465,70 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr,
if (HoistingOutOfLoop)
return WS_Positive;
- // Returns true if we might be hoisting above explicit control flow. Note
- // that this completely ignores implicit control flow (guards, calls which
- // throw, etc...). That choice appears arbitrary.
- auto MaybeHoistingOutOfIf = [&]() {
- auto *DominatingBlock = DominatingGuard->getParent();
- auto *DominatedBlock = DominatedInstr->getParent();
- if (isGuardAsWidenableBranch(DominatingGuard))
- DominatingBlock = cast<BranchInst>(DominatingGuard)->getSuccessor(0);
+  // For a given basic block \p BB, return the successor which is guaranteed or
+  // highly likely to be taken.
+  auto GetLikelySuccessor = [](const BasicBlock *BB) -> const BasicBlock * {
+ if (auto *UniqueSucc = BB->getUniqueSuccessor())
+ return UniqueSucc;
+ auto *Term = BB->getTerminator();
+ Value *Cond = nullptr;
+ const BasicBlock *IfTrue = nullptr, *IfFalse = nullptr;
+ using namespace PatternMatch;
+ if (!match(Term, m_Br(m_Value(Cond), m_BasicBlock(IfTrue),
+ m_BasicBlock(IfFalse))))
+ return nullptr;
+    // For constant conditions, only one dynamic successor is possible
+ if (auto *ConstCond = dyn_cast<ConstantInt>(Cond))
+ return ConstCond->isAllOnesValue() ? IfTrue : IfFalse;
+ // If one of successors ends with deopt, another one is likely.
+ if (IfFalse->getPostdominatingDeoptimizeCall())
+ return IfTrue;
+ if (IfTrue->getPostdominatingDeoptimizeCall())
+ return IfFalse;
+    // TODO: Use branch frequency metadata to allow hoisting through non-deopt
+ // branches?
+ return nullptr;
+ };
+
+ // Returns true if we might be hoisting above explicit control flow into a
+ // considerably hotter block. Note that this completely ignores implicit
+ // control flow (guards, calls which throw, etc...). That choice appears
+ // arbitrary (we assume that implicit control flow exits are all rare).
+ auto MaybeHoistingToHotterBlock = [&]() {
+ const auto *DominatingBlock = DominatingGuard->getParent();
+ const auto *DominatedBlock = DominatedInstr->getParent();
+
+ // Descend as low as we can, always taking the likely successor.
+ assert(DT.isReachableFromEntry(DominatingBlock) && "Unreached code");
+ assert(DT.isReachableFromEntry(DominatedBlock) && "Unreached code");
+ assert(DT.dominates(DominatingBlock, DominatedBlock) && "No dominance");
+ while (DominatedBlock != DominatingBlock) {
+ auto *LikelySucc = GetLikelySuccessor(DominatingBlock);
+ // No likely successor?
+ if (!LikelySucc)
+ break;
+ // Only go down the dominator tree.
+ if (!DT.properlyDominates(DominatingBlock, LikelySucc))
+ break;
+ DominatingBlock = LikelySucc;
+ }
- // Same Block?
+ // Found?
if (DominatedBlock == DominatingBlock)
return false;
- // Obvious successor (common loop header/preheader case)
- if (DominatedBlock == DominatingBlock->getUniqueSuccessor())
- return false;
+ // We followed the likely successor chain and went past the dominated
+ // block. It means that the dominated guard is in dead/very cold code.
+ if (!DT.dominates(DominatingBlock, DominatedBlock))
+ return true;
// TODO: diamond, triangle cases
if (!PDT) return true;
return !PDT->dominates(DominatedBlock, DominatingBlock);
};
- return MaybeHoistingOutOfIf() ? WS_IllegalOrNegative : WS_Neutral;
+ return MaybeHoistingToHotterBlock() ? WS_IllegalOrNegative : WS_Neutral;
}
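
A simplified standalone model of the likely-successor walk used by MaybeHoistingToHotterBlock above: starting from the dominating guard's block, keep stepping to the single likely successor; reaching the dominated block means the second guard sits on the hot path. The dominator-tree checks and the deopt/constant-branch heuristics are reduced here to a hand-written successor map.

#include <cstdio>
#include <map>
#include <optional>
#include <string>

// Toy CFG: each block maps to its single "likely" successor, if any.
// (In the pass this comes from unique successors, constant branches, or
// branches whose other side post-dominates a deoptimize call.)
static std::optional<std::string>
likelySuccessor(const std::map<std::string, std::string> &Likely,
                const std::string &BB) {
  auto It = Likely.find(BB);
  if (It == Likely.end())
    return std::nullopt;
  return It->second;
}

int main() {
  std::map<std::string, std::string> Likely = {
      {"entry", "guard1.taken"}, // the other edge leads to a deopt block
      {"guard1.taken", "loop.preheader"}};

  std::string Dominating = "entry";
  const std::string Dominated = "loop.preheader";

  // Descend through likely successors; if we reach the dominated block, the
  // second guard is on the hot path and widening is considered neutral.
  while (Dominating != Dominated) {
    auto Next = likelySuccessor(Likely, Dominating);
    if (!Next)
      break; // no likely successor: we might be hoisting into a hotter block
    Dominating = *Next;
  }
  std::printf("stopped at %s -> %s\n", Dominating.c_str(),
              Dominating == Dominated ? "same hot path" : "maybe hotter block");
  return 0;
}
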
-bool GuardWideningImpl::isAvailableAt(
+bool GuardWideningImpl::canBeHoistedTo(
const Value *V, const Instruction *Loc,
SmallPtrSetImpl<const Instruction *> &Visited) const {
auto *Inst = dyn_cast<Instruction>(V);
@@ -482,7 +547,7 @@ bool GuardWideningImpl::isAvailableAt(
assert(DT.isReachableFromEntry(Inst->getParent()) &&
"We did a DFS from the block entry!");
return all_of(Inst->operands(),
- [&](Value *Op) { return isAvailableAt(Op, Loc, Visited); });
+ [&](Value *Op) { return canBeHoistedTo(Op, Loc, Visited); });
}
void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const {
@@ -491,14 +556,115 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const {
return;
assert(isSafeToSpeculativelyExecute(Inst, Loc, &AC, &DT) &&
- !Inst->mayReadFromMemory() && "Should've checked with isAvailableAt!");
+ !Inst->mayReadFromMemory() &&
+ "Should've checked with canBeHoistedTo!");
for (Value *Op : Inst->operands())
makeAvailableAt(Op, Loc);
Inst->moveBefore(Loc);
- // If we moved instruction before guard we must clean poison generating flags.
- Inst->dropPoisonGeneratingFlags();
+}
+
+// Return the instruction before which we can insert a freeze for the value V,
+// as close to its def as possible. If there is no place to add one, return nullptr.
+static Instruction *getFreezeInsertPt(Value *V, const DominatorTree &DT) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return &*DT.getRoot()->getFirstNonPHIOrDbgOrAlloca();
+
+ auto *Res = I->getInsertionPointAfterDef();
+ // If there is no place to add freeze - return nullptr.
+ if (!Res || !DT.dominates(I, Res))
+ return nullptr;
+
+ // If there is a User dominated by original I, then it should be dominated
+ // by Freeze instruction as well.
+ if (any_of(I->users(), [&](User *U) {
+ Instruction *User = cast<Instruction>(U);
+ return Res != User && DT.dominates(I, User) && !DT.dominates(Res, User);
+ }))
+ return nullptr;
+ return Res;
+}
+
+Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) {
+ if (isGuaranteedNotToBePoison(Orig, nullptr, InsertPt, &DT))
+ return Orig;
+ Instruction *InsertPtAtDef = getFreezeInsertPt(Orig, DT);
+ if (!InsertPtAtDef)
+ return new FreezeInst(Orig, "gw.freeze", InsertPt);
+ if (isa<Constant>(Orig) || isa<GlobalValue>(Orig))
+ return new FreezeInst(Orig, "gw.freeze", InsertPtAtDef);
+
+ SmallSet<Value *, 16> Visited;
+ SmallVector<Value *, 16> Worklist;
+ SmallSet<Instruction *, 16> DropPoisonFlags;
+ SmallVector<Value *, 16> NeedFreeze;
+ DenseMap<Value *, FreezeInst *> CacheOfFreezes;
+
+  // The data structures here are a bit overloaded: Visited also contains a
+  // constant/GV once we have met it; in that case CacheOfFreezes holds a freeze
+  // for it if one is required.
+ auto handleConstantOrGlobal = [&](Use &U) {
+ Value *Def = U.get();
+ if (!isa<Constant>(Def) && !isa<GlobalValue>(Def))
+ return false;
+
+ if (Visited.insert(Def).second) {
+ if (isGuaranteedNotToBePoison(Def, nullptr, InsertPt, &DT))
+ return true;
+ CacheOfFreezes[Def] = new FreezeInst(Def, Def->getName() + ".gw.fr",
+ getFreezeInsertPt(Def, DT));
+ }
+
+ if (CacheOfFreezes.count(Def))
+ U.set(CacheOfFreezes[Def]);
+ return true;
+ };
+
+ Worklist.push_back(Orig);
+ while (!Worklist.empty()) {
+ Value *V = Worklist.pop_back_val();
+ if (!Visited.insert(V).second)
+ continue;
+
+ if (isGuaranteedNotToBePoison(V, nullptr, InsertPt, &DT))
+ continue;
+
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I || canCreateUndefOrPoison(cast<Operator>(I),
+ /*ConsiderFlagsAndMetadata*/ false)) {
+ NeedFreeze.push_back(V);
+ continue;
+ }
+ // Check all operands. If for any of them we cannot insert Freeze,
+ // stop here. Otherwise, iterate.
+ if (any_of(I->operands(), [&](Value *Op) {
+ return isa<Instruction>(Op) && !getFreezeInsertPt(Op, DT);
+ })) {
+ NeedFreeze.push_back(I);
+ continue;
+ }
+ DropPoisonFlags.insert(I);
+ for (Use &U : I->operands())
+ if (!handleConstantOrGlobal(U))
+ Worklist.push_back(U.get());
+ }
+ for (Instruction *I : DropPoisonFlags)
+ I->dropPoisonGeneratingFlagsAndMetadata();
+
+ Value *Result = Orig;
+ for (Value *V : NeedFreeze) {
+ auto *FreezeInsertPt = getFreezeInsertPt(V, DT);
+ FreezeInst *FI = new FreezeInst(V, V->getName() + ".gw.fr", FreezeInsertPt);
+ ++FreezeAdded;
+ if (V == Orig)
+ Result = FI;
+    V->replaceUsesWithIf(
+        FI, [&](const Use &U) -> bool { return U.getUser() != FI; });
+ }
+
+ return Result;
}
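
A compact standalone model of the worklist in freezeAndPush: instead of freezing the whole widened condition, walk its operands, drop poison-generating flags on instructions that cannot create poison by themselves, and insert freezes only at the leaves that can. The handling of constants/globals, the freeze-insertion-point checks, and the use rewriting are omitted; the node structure is invented for the sketch.

#include <cstdio>
#include <set>
#include <string>
#include <vector>

// Toy expression node: an operation with operands, a flag saying whether the
// op itself can introduce poison (ignoring nsw/nuw-style flags, which the
// pass simply drops), and markers for what the transform did to it.
struct Node {
  std::string Name;
  bool CanCreatePoison = false; // e.g. a value we cannot look through
  bool PoisonFlagsDropped = false;
  bool Frozen = false;
  std::vector<Node *> Ops;
};

// Push freezes down to the leaves that can actually produce poison instead of
// freezing the whole expression.
static void freezeAndPush(Node &Root) {
  std::vector<Node *> Worklist{&Root};
  std::set<Node *> Visited;
  while (!Worklist.empty()) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(N).second)
      continue;
    if (N->CanCreatePoison) {
      N->Frozen = true; // a freeze would be inserted right after the def
      continue;
    }
    // Safe to look through: drop poison-generating flags and recurse.
    N->PoisonFlagsDropped = true;
    for (Node *Op : N->Ops)
      Worklist.push_back(Op);
  }
}

int main() {
  Node A{"load.a", /*CanCreatePoison=*/true};
  Node B{"const.b"};
  Node Add{"add.nsw", false, false, false, {&A, &B}};
  Node Cmp{"icmp", false, false, false, {&Add}};
  freezeAndPush(Cmp);
  std::printf("%s frozen=%d, %s flags dropped=%d\n", A.Name.c_str(), A.Frozen,
              Add.Name.c_str(), Add.PoisonFlagsDropped);
  return 0;
}
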
bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
@@ -532,6 +698,8 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
if (InsertPt) {
ConstantInt *NewRHS =
ConstantInt::get(Cond0->getContext(), NewRHSAP);
+ assert(canBeHoistedTo(LHS, InsertPt) && "must be");
+ makeAvailableAt(LHS, InsertPt);
Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk");
}
return true;
@@ -558,6 +726,7 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
}
assert(Result && "Failed to find result value");
Result->setName("wide.chk");
+ Result = freezeAndPush(Result, InsertPt);
}
return true;
}
@@ -570,6 +739,7 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
makeAvailableAt(Cond1, InsertPt);
if (InvertCondition)
Cond1 = BinaryOperator::CreateNot(Cond1, "inverted", InsertPt);
+ Cond1 = freezeAndPush(Cond1, InsertPt);
Result = BinaryOperator::CreateAnd(Cond0, Cond1, "wide.chk", InsertPt);
}
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index c834e51b5f29..40475d9563b2 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -64,15 +64,12 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -93,15 +90,6 @@ STATISTIC(NumLFTR , "Number of loop exit tests replaced");
STATISTIC(NumElimExt , "Number of IV sign/zero extends eliminated");
STATISTIC(NumElimIV , "Number of congruent IVs eliminated");
-// Trip count verification can be enabled by default under NDEBUG if we
-// implement a strong expression equivalence checker in SCEV. Until then, we
-// use the verify-indvars flag, which may assert in some cases.
-static cl::opt<bool> VerifyIndvars(
- "verify-indvars", cl::Hidden,
- cl::desc("Verify the ScalarEvolution result after running indvars. Has no "
- "effect in release builds. (Note: this adds additional SCEV "
- "queries potentially changing the analysis result)"));
-
static cl::opt<ReplaceExitVal> ReplaceExitValue(
"replexitval", cl::Hidden, cl::init(OnlyCheapRepl),
cl::desc("Choose the strategy to replace exit value in IndVarSimplify"),
@@ -416,8 +404,8 @@ bool IndVarSimplify::rewriteNonIntegerIVs(Loop *L) {
PHIs.push_back(&PN);
bool Changed = false;
- for (unsigned i = 0, e = PHIs.size(); i != e; ++i)
- if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHIs[i]))
+ for (WeakTrackingVH &PHI : PHIs)
+ if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHI))
Changed |= handleFloatingPointIV(L, PN);
// If the loop previously had floating-point IV, ScalarEvolution
@@ -759,50 +747,6 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) {
return Phi != getLoopPhiForCounter(IncV, L);
}
-/// Return true if undefined behavior would provable be executed on the path to
-/// OnPathTo if Root produced a posion result. Note that this doesn't say
-/// anything about whether OnPathTo is actually executed or whether Root is
-/// actually poison. This can be used to assess whether a new use of Root can
-/// be added at a location which is control equivalent with OnPathTo (such as
-/// immediately before it) without introducing UB which didn't previously
-/// exist. Note that a false result conveys no information.
-static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
- Instruction *OnPathTo,
- DominatorTree *DT) {
- // Basic approach is to assume Root is poison, propagate poison forward
- // through all users we can easily track, and then check whether any of those
- // users are provable UB and must execute before out exiting block might
- // exit.
-
- // The set of all recursive users we've visited (which are assumed to all be
- // poison because of said visit)
- SmallSet<const Value *, 16> KnownPoison;
- SmallVector<const Instruction*, 16> Worklist;
- Worklist.push_back(Root);
- while (!Worklist.empty()) {
- const Instruction *I = Worklist.pop_back_val();
-
- // If we know this must trigger UB on a path leading our target.
- if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo))
- return true;
-
- // If we can't analyze propagation through this instruction, just skip it
- // and transitive users. Safe as false is a conservative result.
- if (I != Root && !any_of(I->operands(), [&KnownPoison](const Use &U) {
- return KnownPoison.contains(U) && propagatesPoison(U);
- }))
- continue;
-
- if (KnownPoison.insert(I).second)
- for (const User *User : I->users())
- Worklist.push_back(cast<Instruction>(User));
- }
-
- // Might be non-UB, or might have a path we couldn't prove must execute on
- // way to exiting bb.
- return false;
-}
-
/// Recursive helper for hasConcreteDef(). Unfortunately, this currently boils
/// down to checking that all operands are constant and listing instructions
/// that may hide undef.
@@ -845,20 +789,6 @@ static bool hasConcreteDef(Value *V) {
return hasConcreteDefImpl(V, Visited, 0);
}
-/// Return true if this IV has any uses other than the (soon to be rewritten)
-/// loop exit test.
-static bool AlmostDeadIV(PHINode *Phi, BasicBlock *LatchBlock, Value *Cond) {
- int LatchIdx = Phi->getBasicBlockIndex(LatchBlock);
- Value *IncV = Phi->getIncomingValue(LatchIdx);
-
- for (User *U : Phi->users())
- if (U != Cond && U != IncV) return false;
-
- for (User *U : IncV->users())
- if (U != Cond && U != Phi) return false;
- return true;
-}
-
/// Return true if the given phi is a "counter" in L. A counter is an
 /// add recurrence (of integer or pointer type) with an arbitrary start, and a
/// step of 1. Note that L must have exactly one latch.
@@ -910,10 +840,6 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
if (!isLoopCounter(Phi, L, SE))
continue;
- // Avoid comparing an integer IV against a pointer Limit.
- if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy())
- continue;
-
const auto *AR = cast<SCEVAddRecExpr>(SE->getSCEV(Phi));
// AR may be a pointer type, while BECount is an integer type.
@@ -949,9 +875,9 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
const SCEV *Init = AR->getStart();
- if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) {
+ if (BestPhi && !isAlmostDeadIV(BestPhi, LatchBlock, Cond)) {
// Don't force a live loop counter if another IV can be used.
- if (AlmostDeadIV(Phi, LatchBlock, Cond))
+ if (isAlmostDeadIV(Phi, LatchBlock, Cond))
continue;
// Prefer to count-from-zero. This is a more "canonical" counter form. It
@@ -979,78 +905,29 @@ static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB,
const SCEV *ExitCount, bool UsePostInc, Loop *L,
SCEVExpander &Rewriter, ScalarEvolution *SE) {
assert(isLoopCounter(IndVar, L, SE));
+ assert(ExitCount->getType()->isIntegerTy() && "exit count must be integer");
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar));
- const SCEV *IVInit = AR->getStart();
assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride");
- // IVInit may be a pointer while ExitCount is an integer when FindLoopCounter
- // finds a valid pointer IV. Sign extend ExitCount in order to materialize a
- // GEP. Avoid running SCEVExpander on a new pointer value, instead reusing
- // the existing GEPs whenever possible.
- if (IndVar->getType()->isPointerTy() &&
- !ExitCount->getType()->isPointerTy()) {
- // IVOffset will be the new GEP offset that is interpreted by GEP as a
- // signed value. ExitCount on the other hand represents the loop trip count,
- // which is an unsigned value. FindLoopCounter only allows induction
- // variables that have a positive unit stride of one. This means we don't
- // have to handle the case of negative offsets (yet) and just need to zero
- // extend ExitCount.
- Type *OfsTy = SE->getEffectiveSCEVType(IVInit->getType());
- const SCEV *IVOffset = SE->getTruncateOrZeroExtend(ExitCount, OfsTy);
- if (UsePostInc)
- IVOffset = SE->getAddExpr(IVOffset, SE->getOne(OfsTy));
-
- // Expand the code for the iteration count.
- assert(SE->isLoopInvariant(IVOffset, L) &&
- "Computed iteration count is not loop invariant!");
-
- const SCEV *IVLimit = SE->getAddExpr(IVInit, IVOffset);
- BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
- return Rewriter.expandCodeFor(IVLimit, IndVar->getType(), BI);
- } else {
- // In any other case, convert both IVInit and ExitCount to integers before
- // comparing. This may result in SCEV expansion of pointers, but in practice
- // SCEV will fold the pointer arithmetic away as such:
- // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc).
- //
- // Valid Cases: (1) both integers is most common; (2) both may be pointers
- // for simple memset-style loops.
- //
- // IVInit integer and ExitCount pointer would only occur if a canonical IV
- // were generated on top of case #2, which is not expected.
-
- // For unit stride, IVCount = Start + ExitCount with 2's complement
- // overflow.
-
- // For integer IVs, truncate the IV before computing IVInit + BECount,
- // unless we know apriori that the limit must be a constant when evaluated
- // in the bitwidth of the IV. We prefer (potentially) keeping a truncate
- // of the IV in the loop over a (potentially) expensive expansion of the
- // widened exit count add(zext(add)) expression.
- if (SE->getTypeSizeInBits(IVInit->getType())
- > SE->getTypeSizeInBits(ExitCount->getType())) {
- if (isa<SCEVConstant>(IVInit) && isa<SCEVConstant>(ExitCount))
- ExitCount = SE->getZeroExtendExpr(ExitCount, IVInit->getType());
- else
- IVInit = SE->getTruncateExpr(IVInit, ExitCount->getType());
- }
-
- const SCEV *IVLimit = SE->getAddExpr(IVInit, ExitCount);
-
- if (UsePostInc)
- IVLimit = SE->getAddExpr(IVLimit, SE->getOne(IVLimit->getType()));
-
- // Expand the code for the iteration count.
- assert(SE->isLoopInvariant(IVLimit, L) &&
- "Computed iteration count is not loop invariant!");
- // Ensure that we generate the same type as IndVar, or a smaller integer
- // type. In the presence of null pointer values, we have an integer type
- // SCEV expression (IVInit) for a pointer type IV value (IndVar).
- Type *LimitTy = ExitCount->getType()->isPointerTy() ?
- IndVar->getType() : ExitCount->getType();
- BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
- return Rewriter.expandCodeFor(IVLimit, LimitTy, BI);
+ // For integer IVs, truncate the IV before computing the limit unless we
+  // know a priori that the limit must be a constant when evaluated in the
+ // bitwidth of the IV. We prefer (potentially) keeping a truncate of the
+ // IV in the loop over a (potentially) expensive expansion of the widened
+ // exit count add(zext(add)) expression.
+ if (IndVar->getType()->isIntegerTy() &&
+ SE->getTypeSizeInBits(AR->getType()) >
+ SE->getTypeSizeInBits(ExitCount->getType())) {
+ const SCEV *IVInit = AR->getStart();
+ if (!isa<SCEVConstant>(IVInit) || !isa<SCEVConstant>(ExitCount))
+ AR = cast<SCEVAddRecExpr>(SE->getTruncateExpr(AR, ExitCount->getType()));
}
+
+ const SCEVAddRecExpr *ARBase = UsePostInc ? AR->getPostIncExpr(*SE) : AR;
+ const SCEV *IVLimit = ARBase->evaluateAtIteration(ExitCount, *SE);
+ assert(SE->isLoopInvariant(IVLimit, L) &&
+ "Computed iteration count is not loop invariant!");
+ return Rewriter.expandCodeFor(IVLimit, ARBase->getType(),
+ ExitingBB->getTerminator());
}
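
The rewritten genLoopLimit above leans on evaluateAtIteration: for an affine add-recurrence {Start,+,Step}, the value at iteration N is Start + Step*N, and the post-increment form is {Start+Step,+,Step}. A small standalone model of that arithmetic (the real code works on SCEVs, may truncate a wide IV first, and expands the result with SCEVExpander):

#include <cstdint>
#include <cstdio>

// Minimal affine add-recurrence {Start,+,Step}, enough to mirror what
// evaluateAtIteration does for the unit-stride counters handled here.
struct AffineAddRec {
  int64_t Start;
  int64_t Step;

  // Value of the recurrence at iteration N: Start + Step * N.
  int64_t evaluateAtIteration(int64_t N) const { return Start + Step * N; }

  // Post-increment form, i.e. the value *after* each iteration's increment.
  AffineAddRec postIncExpr() const { return {Start + Step, Step}; }
};

int main() {
  // A unit-stride counter starting at 7 with an exit count of 10 iterations.
  AffineAddRec IV{7, 1};
  int64_t ExitCount = 10;
  std::printf("pre-inc limit  = %lld\n",
              (long long)IV.evaluateAtIteration(ExitCount)); // 17
  std::printf("post-inc limit = %lld\n",
              (long long)IV.postIncExpr().evaluateAtIteration(ExitCount)); // 18
  return 0;
}
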
/// This method rewrites the exit condition of the loop to be a canonical !=
@@ -1148,8 +1025,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
// a truncate within in.
bool Extended = false;
const SCEV *IV = SE->getSCEV(CmpIndVar);
- const SCEV *TruncatedIV = SE->getTruncateExpr(SE->getSCEV(CmpIndVar),
- ExitCnt->getType());
+ const SCEV *TruncatedIV = SE->getTruncateExpr(IV, ExitCnt->getType());
const SCEV *ZExtTrunc =
SE->getZeroExtendExpr(TruncatedIV, CmpIndVar->getType());
@@ -1359,14 +1235,16 @@ createInvariantCond(const Loop *L, BasicBlock *ExitingBB,
const ScalarEvolution::LoopInvariantPredicate &LIP,
SCEVExpander &Rewriter) {
ICmpInst::Predicate InvariantPred = LIP.Pred;
- BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
- Rewriter.setInsertPoint(BI);
+ BasicBlock *Preheader = L->getLoopPreheader();
+ assert(Preheader && "Preheader doesn't exist");
+ Rewriter.setInsertPoint(Preheader->getTerminator());
auto *LHSV = Rewriter.expandCodeFor(LIP.LHS);
auto *RHSV = Rewriter.expandCodeFor(LIP.RHS);
bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
if (ExitIfTrue)
InvariantPred = ICmpInst::getInversePredicate(InvariantPred);
- IRBuilder<> Builder(BI);
+ IRBuilder<> Builder(Preheader->getTerminator());
+ BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
return Builder.CreateICmp(InvariantPred, LHSV, RHSV,
BI->getCondition()->getName());
}
@@ -1519,7 +1397,6 @@ static bool optimizeLoopExitWithUnknownExitCount(
auto *NewCond = *Replaced;
if (auto *NCI = dyn_cast<Instruction>(NewCond)) {
NCI->setName(OldCond->getName() + ".first_iter");
- NCI->moveBefore(cast<Instruction>(OldCond));
}
LLVM_DEBUG(dbgs() << "Unknown exit count: Replacing " << *OldCond
<< " with " << *NewCond << "\n");
@@ -2022,16 +1899,6 @@ bool IndVarSimplify::run(Loop *L) {
if (!L->isLoopSimplifyForm())
return false;
-#ifndef NDEBUG
- // Used below for a consistency check only
- // Note: Since the result returned by ScalarEvolution may depend on the order
- // in which previous results are added to its cache, the call to
- // getBackedgeTakenCount() may change following SCEV queries.
- const SCEV *BackedgeTakenCount;
- if (VerifyIndvars)
- BackedgeTakenCount = SE->getBackedgeTakenCount(L);
-#endif
-
bool Changed = false;
// If there are any floating-point recurrences, attempt to
// transform them to use integer recurrences.
@@ -2180,27 +2047,8 @@ bool IndVarSimplify::run(Loop *L) {
// Check a post-condition.
assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
"Indvars did not preserve LCSSA!");
-
- // Verify that LFTR, and any other change have not interfered with SCEV's
- // ability to compute trip count. We may have *changed* the exit count, but
- // only by reducing it.
-#ifndef NDEBUG
- if (VerifyIndvars && !isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
- SE->forgetLoop(L);
- const SCEV *NewBECount = SE->getBackedgeTakenCount(L);
- if (SE->getTypeSizeInBits(BackedgeTakenCount->getType()) <
- SE->getTypeSizeInBits(NewBECount->getType()))
- NewBECount = SE->getTruncateOrNoop(NewBECount,
- BackedgeTakenCount->getType());
- else
- BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount,
- NewBECount->getType());
- assert(!SE->isKnownPredicate(ICmpInst::ICMP_ULT, BackedgeTakenCount,
- NewBECount) && "indvars must preserve SCEV");
- }
if (VerifyMemorySSA && MSSAU)
MSSAU->getMemorySSA()->verifyMemorySSA();
-#endif
return Changed;
}
@@ -2222,54 +2070,3 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-namespace {
-
-struct IndVarSimplifyLegacyPass : public LoopPass {
- static char ID; // Pass identification, replacement for typeid
-
- IndVarSimplifyLegacyPass() : LoopPass(ID) {
- initializeIndVarSimplifyLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
-
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- auto *TLI = TLIP ? &TLIP->getTLI(*L->getHeader()->getParent()) : nullptr;
- auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
- auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr;
- const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
- auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
- MemorySSA *MSSA = nullptr;
- if (MSSAAnalysis)
- MSSA = &MSSAAnalysis->getMSSA();
-
- IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA, AllowIVWidening);
- return IVS.run(L);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addPreserved<MemorySSAWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
-};
-
-} // end anonymous namespace
-
-char IndVarSimplifyLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars",
- "Induction Variable Simplification", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_END(IndVarSimplifyLegacyPass, "indvars",
- "Induction Variable Simplification", false, false)
-
-Pass *llvm::createIndVarSimplifyPass() {
- return new IndVarSimplifyLegacyPass();
-}
diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
index 52a4bc8a9f24..b52589baeee7 100644
--- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -72,8 +72,6 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
@@ -81,7 +79,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -121,6 +119,16 @@ static cl::opt<bool> AllowNarrowLatchCondition(
cl::desc("If set to true, IRCE may eliminate wide range checks in loops "
"with narrow latch condition."));
+static cl::opt<unsigned> MaxTypeSizeForOverflowCheck(
+ "irce-max-type-size-for-overflow-check", cl::Hidden, cl::init(32),
+ cl::desc(
+ "Maximum size of range check type for which can be produced runtime "
+ "overflow check of its limit's computation"));
+
+static cl::opt<bool>
+ PrintScaledBoundaryRangeChecks("irce-print-scaled-boundary-range-checks",
+ cl::Hidden, cl::init(false));
+
static const char *ClonedLoopTag = "irce.loop.clone";
#define DEBUG_TYPE "irce"
@@ -145,14 +153,23 @@ class InductiveRangeCheck {
Use *CheckUse = nullptr;
static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE,
- Value *&Index, Value *&Length,
- bool &IsSigned);
+ const SCEVAddRecExpr *&Index,
+ const SCEV *&End);
static void
extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse,
SmallVectorImpl<InductiveRangeCheck> &Checks,
SmallPtrSetImpl<Value *> &Visited);
+ static bool parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS,
+ ICmpInst::Predicate Pred, ScalarEvolution &SE,
+ const SCEVAddRecExpr *&Index,
+ const SCEV *&End);
+
+ static bool reassociateSubLHS(Loop *L, Value *VariantLHS, Value *InvariantRHS,
+ ICmpInst::Predicate Pred, ScalarEvolution &SE,
+ const SCEVAddRecExpr *&Index, const SCEV *&End);
+
public:
const SCEV *getBegin() const { return Begin; }
const SCEV *getStep() const { return Step; }
@@ -219,10 +236,9 @@ public:
///
/// NB! There may be conditions feeding into \p BI that aren't inductive range
/// checks, and hence don't end up in \p Checks.
- static void
- extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE,
- BranchProbabilityInfo *BPI,
- SmallVectorImpl<InductiveRangeCheck> &Checks);
+ static void extractRangeChecksFromBranch(
+ BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
+ SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed);
};
struct LoopStructure;
@@ -250,48 +266,16 @@ public:
bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop);
};
-class IRCELegacyPass : public FunctionPass {
-public:
- static char ID;
-
- IRCELegacyPass() : FunctionPass(ID) {
- initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<BranchProbabilityInfoWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addPreserved<ScalarEvolutionWrapperPass>();
- }
-
- bool runOnFunction(Function &F) override;
-};
-
} // end anonymous namespace
-char IRCELegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce",
- "Inductive range check elimination", false, false)
-INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination",
- false, false)
-
/// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` cannot
-/// be interpreted as a range check, return false and set `Index` and `Length`
-/// to `nullptr`. Otherwise set `Index` to the value being range checked, and
-/// set `Length` to the upper limit `Index` is being range checked.
-bool
-InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
- ScalarEvolution &SE, Value *&Index,
- Value *&Length, bool &IsSigned) {
+/// be interpreted as a range check, return false. Otherwise set `Index` to the
+/// SCEV being range checked, and set `End` to the upper or lower limit `Index`
+/// is being range checked.
+bool InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+ ScalarEvolution &SE,
+ const SCEVAddRecExpr *&Index,
+ const SCEV *&End) {
auto IsLoopInvariant = [&SE, L](Value *V) {
return SE.isLoopInvariant(SE.getSCEV(V), L);
};
@@ -300,47 +284,79 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
+ // Canonicalize to the `Index Pred Invariant` comparison
+ if (IsLoopInvariant(LHS)) {
+ std::swap(LHS, RHS);
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ } else if (!IsLoopInvariant(RHS))
+ // Both LHS and RHS are loop variant
+ return false;
+
+ if (parseIvAgaisntLimit(L, LHS, RHS, Pred, SE, Index, End))
+ return true;
+
+ if (reassociateSubLHS(L, LHS, RHS, Pred, SE, Index, End))
+ return true;
+
+ // TODO: support ReassociateAddLHS
+ return false;
+}
+
+// Try to parse range check in the form of "IV vs Limit"
+bool InductiveRangeCheck::parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS,
+ ICmpInst::Predicate Pred,
+ ScalarEvolution &SE,
+ const SCEVAddRecExpr *&Index,
+ const SCEV *&End) {
+
+ auto SIntMaxSCEV = [&](Type *T) {
+ unsigned BitWidth = cast<IntegerType>(T)->getBitWidth();
+ return SE.getConstant(APInt::getSignedMaxValue(BitWidth));
+ };
+
+ const auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(LHS));
+ if (!AddRec)
+ return false;
+
+ // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
+ // We can potentially do much better here.
+  // If we want to adjust the upper bound for the unsigned range check as we do
+  // for the signed one, we will need to pick the unsigned max.
switch (Pred) {
default:
return false;
- case ICmpInst::ICMP_SLE:
- std::swap(LHS, RHS);
- [[fallthrough]];
case ICmpInst::ICMP_SGE:
- IsSigned = true;
if (match(RHS, m_ConstantInt<0>())) {
- Index = LHS;
- return true; // Lower.
+ Index = AddRec;
+ End = SIntMaxSCEV(Index->getType());
+ return true;
}
return false;
- case ICmpInst::ICMP_SLT:
- std::swap(LHS, RHS);
- [[fallthrough]];
case ICmpInst::ICMP_SGT:
- IsSigned = true;
if (match(RHS, m_ConstantInt<-1>())) {
- Index = LHS;
- return true; // Lower.
- }
-
- if (IsLoopInvariant(LHS)) {
- Index = RHS;
- Length = LHS;
- return true; // Upper.
+ Index = AddRec;
+ End = SIntMaxSCEV(Index->getType());
+ return true;
}
return false;
+ case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT:
- std::swap(LHS, RHS);
- [[fallthrough]];
- case ICmpInst::ICMP_UGT:
- IsSigned = false;
- if (IsLoopInvariant(LHS)) {
- Index = RHS;
- Length = LHS;
- return true; // Both lower and upper.
+ Index = AddRec;
+ End = SE.getSCEV(RHS);
+ return true;
+
+ case ICmpInst::ICMP_SLE:
+ case ICmpInst::ICMP_ULE:
+ const SCEV *One = SE.getOne(RHS->getType());
+ const SCEV *RHSS = SE.getSCEV(RHS);
+ bool Signed = Pred == ICmpInst::ICMP_SLE;
+ if (SE.willNotOverflow(Instruction::BinaryOps::Add, Signed, RHSS, One)) {
+ Index = AddRec;
+ End = SE.getAddExpr(RHSS, One);
+ return true;
}
return false;
}
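
A toy standalone model of the strengthening performed by parseIvAgaisntLimit: every recognized comparison is normalized to a half-open safe range [0, End) on the induction variable, with "iv >= 0" widened to End = SINT_MAX and "iv <= L" turned into End = L + 1 only when that increment provably cannot overflow. The enum and int64_t bounds below are stand-ins, not the pass's types.

#include <cstdint>
#include <cstdio>
#include <limits>
#include <optional>

enum class Pred { SGE0, SLT, ULT, SLE, ULE };

// Returns the exclusive upper bound End of the safe range [0, End), or
// nullopt if the check cannot be normalized.
static std::optional<int64_t> upperEnd(Pred P, int64_t Limit) {
  switch (P) {
  case Pred::SGE0: // "0 <= iv" strengthened to "0 <= iv < SINT_MAX"
    return std::numeric_limits<int64_t>::max();
  case Pred::SLT:
  case Pred::ULT: // "iv < Limit"  ->  End = Limit
    return Limit;
  case Pred::SLE:
  case Pred::ULE: // "iv <= Limit" ->  End = Limit + 1, if it cannot overflow
    if (Limit == std::numeric_limits<int64_t>::max())
      return std::nullopt;
    return Limit + 1;
  }
  return std::nullopt;
}

int main() {
  std::printf("iv <  100 -> End %lld\n", (long long)*upperEnd(Pred::SLT, 100));
  std::printf("iv <= 100 -> End %lld\n", (long long)*upperEnd(Pred::SLE, 100));
  return 0;
}
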
@@ -348,6 +364,126 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
llvm_unreachable("default clause returns!");
}
+// Try to parse range check in the form of "IV - Offset vs Limit" or "Offset -
+// IV vs Limit"
+bool InductiveRangeCheck::reassociateSubLHS(
+ Loop *L, Value *VariantLHS, Value *InvariantRHS, ICmpInst::Predicate Pred,
+ ScalarEvolution &SE, const SCEVAddRecExpr *&Index, const SCEV *&End) {
+ Value *LHS, *RHS;
+ if (!match(VariantLHS, m_Sub(m_Value(LHS), m_Value(RHS))))
+ return false;
+
+ const SCEV *IV = SE.getSCEV(LHS);
+ const SCEV *Offset = SE.getSCEV(RHS);
+ const SCEV *Limit = SE.getSCEV(InvariantRHS);
+
+ bool OffsetSubtracted = false;
+ if (SE.isLoopInvariant(IV, L))
+ // "Offset - IV vs Limit"
+ std::swap(IV, Offset);
+ else if (SE.isLoopInvariant(Offset, L))
+ // "IV - Offset vs Limit"
+ OffsetSubtracted = true;
+ else
+ return false;
+
+ const auto *AddRec = dyn_cast<SCEVAddRecExpr>(IV);
+ if (!AddRec)
+ return false;
+
+ // In order to turn "IV - Offset < Limit" into "IV < Limit + Offset", we need
+  // to be able to freely move values from the left side of the inequality to the
+  // right side (just as in normal linear arithmetic). Overflows make things much more
+ // complicated, so we want to avoid this.
+ //
+ // Let's prove that the initial subtraction doesn't overflow with all IV's
+ // values from the safe range constructed for that check.
+ //
+ // [Case 1] IV - Offset < Limit
+ // It doesn't overflow if:
+ // SINT_MIN <= IV - Offset <= SINT_MAX
+ // In terms of scaled SINT we need to prove:
+ // SINT_MIN + Offset <= IV <= SINT_MAX + Offset
+ // Safe range will be constructed:
+ // 0 <= IV < Limit + Offset
+ // It means that 'IV - Offset' doesn't underflow, because:
+ // SINT_MIN + Offset < 0 <= IV
+ // and doesn't overflow:
+ // IV < Limit + Offset <= SINT_MAX + Offset
+ //
+ // [Case 2] Offset - IV > Limit
+ // It doesn't overflow if:
+ // SINT_MIN <= Offset - IV <= SINT_MAX
+ // In terms of scaled SINT we need to prove:
+ // -SINT_MIN >= IV - Offset >= -SINT_MAX
+ // Offset - SINT_MIN >= IV >= Offset - SINT_MAX
+ // Safe range will be constructed:
+ // 0 <= IV < Offset - Limit
+ // It means that 'Offset - IV' doesn't underflow, because
+ // Offset - SINT_MAX < 0 <= IV
+ // and doesn't overflow:
+ // IV < Offset - Limit <= Offset - SINT_MIN
+ //
+ // For the computed upper boundary of the IV's range (Offset +/- Limit) we
+ // don't know exactly whether it overflows or not. So if we can't prove this
+  // fact at compile time, we scale the boundary computations to a wider type
+  // with the intention of adding a runtime overflow check.
+
+ auto getExprScaledIfOverflow = [&](Instruction::BinaryOps BinOp,
+ const SCEV *LHS,
+ const SCEV *RHS) -> const SCEV * {
+ const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
+ SCEV::NoWrapFlags, unsigned);
+ switch (BinOp) {
+ default:
+ llvm_unreachable("Unsupported binary op");
+ case Instruction::Add:
+ Operation = &ScalarEvolution::getAddExpr;
+ break;
+ case Instruction::Sub:
+ Operation = &ScalarEvolution::getMinusSCEV;
+ break;
+ }
+
+ if (SE.willNotOverflow(BinOp, ICmpInst::isSigned(Pred), LHS, RHS,
+ cast<Instruction>(VariantLHS)))
+ return (SE.*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0);
+
+ // We couldn't prove that the expression does not overflow.
+  // Then scale it to a wider type to check overflow at runtime.
+ auto *Ty = cast<IntegerType>(LHS->getType());
+ if (Ty->getBitWidth() > MaxTypeSizeForOverflowCheck)
+ return nullptr;
+
+ auto WideTy = IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
+ return (SE.*Operation)(SE.getSignExtendExpr(LHS, WideTy),
+ SE.getSignExtendExpr(RHS, WideTy), SCEV::FlagAnyWrap,
+ 0);
+ };
+
+ if (OffsetSubtracted)
+ // "IV - Offset < Limit" -> "IV" < Offset + Limit
+ Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Offset, Limit);
+ else {
+ // "Offset - IV > Limit" -> "IV" < Offset - Limit
+ Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Sub, Offset, Limit);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
+ // "Expr <= Limit" -> "Expr < Limit + 1"
+ if (Pred == ICmpInst::ICMP_SLE && Limit)
+ Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Limit,
+ SE.getOne(Limit->getType()));
+ if (Limit) {
+ Index = AddRec;
+ End = Limit;
+ return true;
+ }
+ }
+ return false;
+}
+
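
A small standalone illustration of the widening trick used by getExprScaledIfOverflow above: when the reassociated bound Offset + Limit cannot be proven not to overflow in the range-check type, compute it in a type twice as wide so that a later runtime check against the narrow type's bounds stays meaningful. Plain fixed-width integers stand in for SCEVs here.

#include <cstdint>
#include <cstdio>

// Rewritten bound for "IV - Offset < Limit": Offset + Limit. Computing it in
// i64 is exact for any pair of i32 inputs, so overflow of the i32 form can be
// detected afterwards.
static int64_t scaledBound(int32_t Offset, int32_t Limit) {
  return (int64_t)Offset + (int64_t)Limit;
}

static bool fitsInI32(int64_t V) { return V >= INT32_MIN && V <= INT32_MAX; }

int main() {
  int32_t Offset = 2000000000, Limit = 500000000;
  int64_t Bound = scaledBound(Offset, Limit); // 2500000000, not an i32 value
  // This is the kind of condition a runtime overflow check would test before
  // trusting the rewritten "IV < Offset + Limit" bound.
  std::printf("bound %lld fits in i32: %s\n", (long long)Bound,
              fitsInI32(Bound) ? "yes" : "no");
  return 0;
}
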
void InductiveRangeCheck::extractRangeChecksFromCond(
Loop *L, ScalarEvolution &SE, Use &ConditionUse,
SmallVectorImpl<InductiveRangeCheck> &Checks,
@@ -369,32 +505,17 @@ void InductiveRangeCheck::extractRangeChecksFromCond(
if (!ICI)
return;
- Value *Length = nullptr, *Index;
- bool IsSigned;
- if (!parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned))
+ const SCEV *End = nullptr;
+ const SCEVAddRecExpr *IndexAddRec = nullptr;
+ if (!parseRangeCheckICmp(L, ICI, SE, IndexAddRec, End))
return;
- const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index));
- bool IsAffineIndex =
- IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine();
+ assert(IndexAddRec && "IndexAddRec was not computed");
+ assert(End && "End was not computed");
- if (!IsAffineIndex)
+ if ((IndexAddRec->getLoop() != L) || !IndexAddRec->isAffine())
return;
- const SCEV *End = nullptr;
- // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
- // We can potentially do much better here.
- if (Length)
- End = SE.getSCEV(Length);
- else {
- // So far we can only reach this point for Signed range check. This may
- // change in future. In this case we will need to pick Unsigned max for the
- // unsigned range check.
- unsigned BitWidth = cast<IntegerType>(IndexAddRec->getType())->getBitWidth();
- const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
- End = SIntMax;
- }
-
InductiveRangeCheck IRC;
IRC.End = End;
IRC.Begin = IndexAddRec->getStart();
@@ -405,16 +526,29 @@ void InductiveRangeCheck::extractRangeChecksFromCond(
void InductiveRangeCheck::extractRangeChecksFromBranch(
BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
- SmallVectorImpl<InductiveRangeCheck> &Checks) {
+ SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed) {
if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
return;
+ unsigned IndexLoopSucc = L->contains(BI->getSuccessor(0)) ? 0 : 1;
+ assert(L->contains(BI->getSuccessor(IndexLoopSucc)) &&
+ "No edges coming to loop?");
BranchProbability LikelyTaken(15, 16);
if (!SkipProfitabilityChecks && BPI &&
- BPI->getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken)
+ BPI->getEdgeProbability(BI->getParent(), IndexLoopSucc) < LikelyTaken)
return;
+  // IRCE expects the branch's true edge to lead into the loop. Invert the
+  // branch in the opposite case.
+ if (IndexLoopSucc != 0) {
+ IRBuilder<> Builder(BI);
+ InvertBranch(BI, Builder);
+ if (BPI)
+ BPI->swapSuccEdgesProbabilities(BI->getParent());
+ Changed = true;
+ }
+
SmallPtrSet<Value *, 8> Visited;
InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0),
Checks, Visited);
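
A tiny standalone model of the branch normalization above: if the in-loop edge is not the true successor, invert the condition and swap the successors (and, in the pass, the edge probabilities) so that later analysis can assume the true edge stays in the loop. The struct below is invented for the sketch.

#include <cstdio>
#include <string>
#include <utility>

// Toy conditional branch: condition plus (true, false) targets.
struct CondBranch {
  bool Cond;
  std::string IfTrue, IfFalse;
};

// Mirrors what InvertBranch plus swapSuccEdgesProbabilities achieves.
static void invertBranch(CondBranch &BR) {
  BR.Cond = !BR.Cond;
  std::swap(BR.IfTrue, BR.IfFalse);
}

int main() {
  CondBranch BR{/*Cond=*/false, /*IfTrue=*/"out.of.bounds", /*IfFalse=*/"in.loop"};
  if (BR.IfTrue != "in.loop")
    invertBranch(BR);
  std::printf("cond=%d true->%s false->%s\n", BR.Cond, BR.IfTrue.c_str(),
              BR.IfFalse.c_str());
  return 0;
}
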
@@ -622,7 +756,7 @@ class LoopConstrainer {
// Information about the original loop we started out with.
Loop &OriginalLoop;
- const SCEV *LatchTakenCount = nullptr;
+ const IntegerType *ExitCountTy = nullptr;
BasicBlock *OriginalPreheader = nullptr;
// The preheader of the main loop. This may or may not be different from
@@ -671,8 +805,7 @@ static bool isSafeDecreasingBound(const SCEV *Start,
LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
- LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred)
- << "\n");
+ LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n");
LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
bool IsSigned = ICmpInst::isSigned(Pred);
@@ -719,8 +852,7 @@ static bool isSafeIncreasingBound(const SCEV *Start,
LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
- LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred)
- << "\n");
+ LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n");
LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
bool IsSigned = ICmpInst::isSigned(Pred);
@@ -746,6 +878,19 @@ static bool isSafeIncreasingBound(const SCEV *Start,
SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
}
+/// Returns an estimate of the loop's maximum latch taken count, of the
+/// narrowest available type. If the latch block has such an estimate, it is
+/// returned. Otherwise, we use the maximum exit count of the whole loop (which
+/// is potentially of a wider type than the latch check itself); that is still
+/// better than no estimate.
+static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
+ const Loop &L) {
+ const SCEV *FromBlock =
+ SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
+ if (isa<SCEVCouldNotCompute>(FromBlock))
+ return SE.getSymbolicMaxBackedgeTakenCount(&L);
+ return FromBlock;
+}
+
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
const char *&FailureReason) {
@@ -788,11 +933,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
return std::nullopt;
}
- const SCEV *LatchCount = SE.getExitCount(&L, Latch);
- if (isa<SCEVCouldNotCompute>(LatchCount)) {
+ const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
+ if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
FailureReason = "could not compute latch count";
return std::nullopt;
}
+ assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
+ ScalarEvolution::LoopInvariant &&
+ "loop variant exit count doesn't make sense!");
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *LeftValue = ICI->getOperand(0);
@@ -1017,10 +1165,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
}
BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
- assert(SE.getLoopDisposition(LatchCount, &L) ==
- ScalarEvolution::LoopInvariant &&
- "loop variant exit count doesn't make sense!");
-
assert(!L.contains(LatchExit) && "expected an exit block!");
const DataLayout &DL = Preheader->getModule()->getDataLayout();
SCEVExpander Expander(SE, DL, "irce");
@@ -1062,14 +1206,11 @@ static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE,
std::optional<LoopConstrainer::SubRanges>
LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
- IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType());
-
auto *RTy = cast<IntegerType>(Range.getType());
-
// We only support wide range checks and narrow latches.
- if (!AllowNarrowLatchCondition && RTy != Ty)
+ if (!AllowNarrowLatchCondition && RTy != ExitCountTy)
return std::nullopt;
- if (RTy->getBitWidth() < Ty->getBitWidth())
+ if (RTy->getBitWidth() < ExitCountTy->getBitWidth())
return std::nullopt;
LoopConstrainer::SubRanges Result;
@@ -1403,10 +1544,12 @@ Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
bool LoopConstrainer::run() {
BasicBlock *Preheader = nullptr;
- LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch);
+ const SCEV *MaxBETakenCount =
+ getNarrowestLatchMaxTakenCountEstimate(SE, OriginalLoop);
Preheader = OriginalLoop.getLoopPreheader();
- assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr &&
+ assert(!isa<SCEVCouldNotCompute>(MaxBETakenCount) && Preheader != nullptr &&
"preconditions!");
+ ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
OriginalPreheader = Preheader;
MainLoopPreheader = Preheader;
@@ -1574,6 +1717,27 @@ bool LoopConstrainer::run() {
CanonicalizeLoop(PostL, false);
CanonicalizeLoop(&OriginalLoop, true);
+ /// At this point:
+ /// - We've broken a "main loop" out of the loop in a way that the "main loop"
+ /// runs with the induction variable in a subset of [Begin, End).
+ /// - There is no overflow when computing the "main loop" exit limit.
+ /// - The max latch-taken count of the loop is bounded.
+ /// Together, these guarantee that the induction variable will not overflow
+ /// while iterating in the "main loop".
+ if (auto BO = dyn_cast<BinaryOperator>(MainLoopStructure.IndVarBase))
+ if (IsSignedPredicate)
+ BO->setHasNoSignedWrap(true);
+ /// TODO: support the unsigned predicate.
+ /// To add the NUW flag we need to prove that both operands of BO are
+ /// non-negative. E.g.:
+ /// ...
+ /// %iv.next = add nsw i32 %iv, -1
+ /// %cmp = icmp ult i32 %iv.next, %n
+ /// br i1 %cmp, label %loopexit, label %loop
+ ///
+ /// -1 is UINT_MAX when interpreted as unsigned. Adding anything but zero to
+ /// it will overflow; therefore the NUW flag is not legal here.
+
return true;
}
@@ -1588,11 +1752,13 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
// if latch check is more narrow.
auto *IVType = dyn_cast<IntegerType>(IndVar->getType());
auto *RCType = dyn_cast<IntegerType>(getBegin()->getType());
+ auto *EndType = dyn_cast<IntegerType>(getEnd()->getType());
// Do not work with pointer types.
if (!IVType || !RCType)
return std::nullopt;
if (IVType->getBitWidth() > RCType->getBitWidth())
return std::nullopt;
+
// IndVar is of the form "A + B * I" (where "I" is the canonical induction
// variable, that may or may not exist as a real llvm::Value in the loop) and
// this inductive range check is a range check on the "C + D * I" ("C" is
@@ -1631,6 +1797,7 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
assert(!D->getValue()->isZero() && "Recurrence with zero step?");
unsigned BitWidth = RCType->getBitWidth();
const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
+ const SCEV *SIntMin = SE.getConstant(APInt::getSignedMinValue(BitWidth));
// Subtract Y from X so that it does not go through border of the IV
// iteration space. Mathematically, it is equivalent to:
@@ -1682,6 +1849,7 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
// This function returns SCEV equal to 1 if X is non-negative 0 otherwise.
auto SCEVCheckNonNegative = [&](const SCEV *X) {
const Loop *L = IndVar->getLoop();
+ const SCEV *Zero = SE.getZero(X->getType());
const SCEV *One = SE.getOne(X->getType());
// Can we trivially prove that X is a non-negative or negative value?
if (isKnownNonNegativeInLoop(X, L, SE))
@@ -1693,6 +1861,25 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
const SCEV *NegOne = SE.getNegativeSCEV(One);
return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One);
};
+
+ // This function returns a SCEV equal to 1 if X will not overflow in terms of
+ // the range check type, and 0 otherwise.
+ auto SCEVCheckWillNotOverflow = [&](const SCEV *X) {
+ // X doesn't overflow if SINT_MAX >= X.
+ // Then if (SINT_MAX - X) >= 0, X doesn't overflow
+ const SCEV *SIntMaxExt = SE.getSignExtendExpr(SIntMax, X->getType());
+ const SCEV *OverflowCheck =
+ SCEVCheckNonNegative(SE.getMinusSCEV(SIntMaxExt, X));
+
+ // X doesn't underflow if X >= SINT_MIN.
+ // Then if (X - SINT_MIN) >= 0, X doesn't underflow
+ const SCEV *SIntMinExt = SE.getSignExtendExpr(SIntMin, X->getType());
+ const SCEV *UnderflowCheck =
+ SCEVCheckNonNegative(SE.getMinusSCEV(X, SIntMinExt));
+
+ return SE.getMulExpr(OverflowCheck, UnderflowCheck);
+ };
+
// FIXME: Current implementation of ClampedSubtract implicitly assumes that
// X is non-negative (in sense of a signed value). We need to re-implement
// this function in a way that it will correctly handle negative X as well.
@@ -1702,10 +1889,35 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
// Note that this may pessimize elimination of unsigned range checks against
// negative values.
const SCEV *REnd = getEnd();
- const SCEV *EndIsNonNegative = SCEVCheckNonNegative(REnd);
+ const SCEV *EndWillNotOverflow = SE.getOne(RCType);
+
+ auto PrintRangeCheck = [&](raw_ostream &OS) {
+ auto L = IndVar->getLoop();
+ OS << "irce: in function ";
+ OS << L->getHeader()->getParent()->getName();
+ OS << ", in ";
+ L->print(OS);
+ OS << "there is range check with scaled boundary:\n";
+ print(OS);
+ };
+
+ if (EndType->getBitWidth() > RCType->getBitWidth()) {
+ assert(EndType->getBitWidth() == RCType->getBitWidth() * 2);
+ if (PrintScaledBoundaryRangeChecks)
+ PrintRangeCheck(errs());
+ // End is computed with an extended type but will be truncated to the
+ // narrower type of the range check. Therefore we need a check that the
+ // result will not overflow in terms of the narrow type.
+ EndWillNotOverflow =
+ SE.getTruncateExpr(SCEVCheckWillNotOverflow(REnd), RCType);
+ REnd = SE.getTruncateExpr(REnd, RCType);
+ }
+
+ const SCEV *RuntimeChecks =
+ SE.getMulExpr(SCEVCheckNonNegative(REnd), EndWillNotOverflow);
+ const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), RuntimeChecks);
+ const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), RuntimeChecks);
- const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), EndIsNonNegative);
- const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), EndIsNonNegative);
return InductiveRangeCheck::Range(Begin, End);
}
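The guards built above follow a pattern worth spelling out: each runtime condition is encoded as a SCEV that evaluates to 1 when the condition holds and 0 otherwise, independent conditions are combined by multiplication, and the computed Begin/End bounds are multiplied by that product, so a single failed check collapses the safe iteration range to empty. Below is a minimal, self-contained sketch of the same arithmetic on plain integers rather than the SCEV machinery; the 32-bit-to-16-bit narrowing is an assumption chosen purely for illustration.

#include <cstdint>
#include <cstdio>

// 1 if X is non-negative, 0 otherwise (mirrors SCEVCheckNonNegative).
static int64_t checkNonNegative(int64_t X) { return X >= 0 ? 1 : 0; }

// 1 if the wide value X fits into the narrow signed type (int16_t here),
// 0 otherwise, using the same two subtractions as SCEVCheckWillNotOverflow:
// no overflow if INT16_MAX - X >= 0, no underflow if X - INT16_MIN >= 0.
// The arithmetic is done in int64_t, which cannot overflow for int32_t input.
static int64_t checkFitsInNarrow(int32_t X) {
  return checkNonNegative((int64_t)INT16_MAX - X) *
         checkNonNegative((int64_t)X - INT16_MIN);
}

int main() {
  for (int32_t End : {100, 40000, -5}) {
    // Product of all runtime checks: 1 only if End is non-negative *and*
    // survives truncation to the narrow type.
    int64_t RuntimeChecks = checkNonNegative(End) * checkFitsInNarrow(End);
    // Gating the bound by the product: any failed check zeroes it out.
    std::printf("End=%d  checks=%lld  gated=%lld\n", End,
                (long long)RuntimeChecks, (long long)End * RuntimeChecks);
  }
  return 0;
}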
@@ -1825,39 +2037,6 @@ PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) {
return getLoopPassPreservedAnalyses();
}
-bool IRCELegacyPass::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
-
- ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- BranchProbabilityInfo &BPI =
- getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI);
-
- bool Changed = false;
-
- for (const auto &L : LI) {
- Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr,
- /*PreserveLCSSA=*/false);
- Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
- }
-
- SmallPriorityWorklist<Loop *, 4> Worklist;
- appendLoopsToWorklist(LI, Worklist);
- auto LPMAddNewLoop = [&](Loop *NL, bool IsSubloop) {
- if (!IsSubloop)
- appendLoopsToWorklist(*NL, Worklist);
- };
-
- while (!Worklist.empty()) {
- Loop *L = Worklist.pop_back_val();
- Changed |= IRCE.run(L, LPMAddNewLoop);
- }
- return Changed;
-}
-
bool
InductiveRangeCheckElimination::isProfitableToTransform(const Loop &L,
LoopStructure &LS) {
@@ -1904,14 +2083,15 @@ bool InductiveRangeCheckElimination::run(
LLVMContext &Context = Preheader->getContext();
SmallVector<InductiveRangeCheck, 16> RangeChecks;
+ bool Changed = false;
for (auto *BBI : L->getBlocks())
if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI,
- RangeChecks);
+ RangeChecks, Changed);
if (RangeChecks.empty())
- return false;
+ return Changed;
auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
OS << "irce: looking at loop "; L->print(OS);
@@ -1932,16 +2112,15 @@ bool InductiveRangeCheckElimination::run(
if (!MaybeLoopStructure) {
LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: "
<< FailureReason << "\n";);
- return false;
+ return Changed;
}
LoopStructure LS = *MaybeLoopStructure;
if (!isProfitableToTransform(*L, LS))
- return false;
+ return Changed;
const SCEVAddRecExpr *IndVar =
cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep)));
std::optional<InductiveRangeCheck::Range> SafeIterRange;
- Instruction *ExprInsertPt = Preheader->getTerminator();
SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate;
// Basing on the type of latch predicate, we interpret the IV iteration range
@@ -1951,7 +2130,6 @@ bool InductiveRangeCheckElimination::run(
auto IntersectRange =
LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange;
- IRBuilder<> B(ExprInsertPt);
for (InductiveRangeCheck &IRC : RangeChecks) {
auto Result = IRC.computeSafeIterationSpace(SE, IndVar,
LS.IsSignedPredicate);
@@ -1967,12 +2145,13 @@ bool InductiveRangeCheckElimination::run(
}
if (!SafeIterRange)
- return false;
+ return Changed;
LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, *SafeIterRange);
- bool Changed = LC.run();
- if (Changed) {
+ if (LC.run()) {
+ Changed = true;
+
auto PrintConstrainedLoopInfo = [L]() {
dbgs() << "irce: in function ";
dbgs() << L->getHeader()->getParent()->getName() << ": ";
@@ -1997,7 +2176,3 @@ bool InductiveRangeCheckElimination::run(
return Changed;
}
-
-Pass *llvm::createInductiveRangeCheckEliminationPass() {
- return new IRCELegacyPass();
-}
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 114738a35fd1..c2b5a12fd63f 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -76,14 +76,14 @@
// Second, IR rewriting in Step 2 also needs to be circular. For example,
// converting %y to addrspace(3) requires the compiler to know the converted
// %y2, but converting %y2 needs the converted %y. To address this complication,
-// we break these cycles using "undef" placeholders. When converting an
+// we break these cycles using "poison" placeholders. When converting an
// instruction `I` to a new address space, if its operand `Op` is not converted
-// yet, we let `I` temporarily use `undef` and fix all the uses of undef later.
+// yet, we let `I` temporarily use `poison` and fix all the uses later.
// For instance, our algorithm first converts %y to
-// %y' = phi float addrspace(3)* [ %input, undef ]
+// %y' = phi float addrspace(3)* [ %input, poison ]
// Then, it converts %y2 to
// %y2' = getelementptr %y', 1
-// Finally, it fixes the undef in %y' so that
+// Finally, it fixes the poison in %y' so that
// %y' = phi float addrspace(3)* [ %input, %y2' ]
//
//===----------------------------------------------------------------------===//
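The placeholder approach described above is a general two-pass pattern: clone every value, temporarily plugging a placeholder into any operand whose clone does not exist yet, and remember those uses so a second pass can patch them once all clones are available. Here is a small stand-alone sketch of that pattern with toy data structures (not the LLVM classes), where nullptr plays the role of poison:

#include <cstdio>
#include <unordered_map>
#include <utility>
#include <vector>

// A toy value graph: each node has one operand, and "y" / "y2" reference each
// other, mimicking the PHI/GEP cycle described above.
struct Node {
  const char *Name;
  Node *Operand = nullptr;
};

int main() {
  Node Y{"y"}, Y2{"y2"};
  Y.Operand = &Y2;
  Y2.Operand = &Y;

  std::unordered_map<Node *, Node *> Cloned;                    // old -> new
  std::vector<std::pair<Node *, Node *>> PlaceholderUsesToFix;  // (new, old operand)

  // Pass 1: clone every node; operands whose clone is not ready yet get a
  // placeholder (nullptr here, poison in the pass) and are recorded.
  for (Node *Old : {&Y, &Y2}) {
    Node *New = new Node{Old->Name};
    auto It = Cloned.find(Old->Operand);
    if (It != Cloned.end()) {
      New->Operand = It->second;
    } else {
      New->Operand = nullptr;                       // placeholder
      PlaceholderUsesToFix.push_back({New, Old->Operand});
    }
    Cloned[Old] = New;
  }

  // Pass 2: replace each placeholder with the now-available clone.
  for (auto &[New, OldOperand] : PlaceholderUsesToFix)
    New->Operand = Cloned[OldOperand];

  std::printf("clone of %s points to clone of %s\n", Cloned[&Y]->Name,
              Cloned[&Y]->Operand->Name);
  // Clones are intentionally leaked in this tiny sketch.
  return 0;
}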
@@ -206,7 +206,7 @@ class InferAddressSpacesImpl {
Instruction *I, unsigned NewAddrSpace,
const ValueToValueMapTy &ValueWithNewAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS,
- SmallVectorImpl<const Use *> *UndefUsesToFix) const;
+ SmallVectorImpl<const Use *> *PoisonUsesToFix) const;
// Changes the flat address expressions in function F to point to specific
// address spaces if InferredAddrSpace says so. Postorder is the postorder of
@@ -233,7 +233,7 @@ class InferAddressSpacesImpl {
Value *V, unsigned NewAddrSpace,
const ValueToValueMapTy &ValueWithNewAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS,
- SmallVectorImpl<const Use *> *UndefUsesToFix) const;
+ SmallVectorImpl<const Use *> *PoisonUsesToFix) const;
unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const;
unsigned getPredicatedAddrSpace(const Value &V, Value *Opnd) const;
@@ -256,6 +256,12 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
false, false)
+static Type *getPtrOrVecOfPtrsWithNewAS(Type *Ty, unsigned NewAddrSpace) {
+ assert(Ty->isPtrOrPtrVectorTy());
+ PointerType *NPT = PointerType::get(Ty->getContext(), NewAddrSpace);
+ return Ty->getWithNewType(NPT);
+}
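For reference, the helper covers both scalar pointers and vectors of pointers; a hedged sketch of the intended mapping, assuming an existing LLVMContext named Ctx:

//   getPtrOrVecOfPtrsWithNewAS(PointerType::get(Ctx, /*AS=*/0), /*NewAS=*/3)
//       --> ptr addrspace(3)
//   getPtrOrVecOfPtrsWithNewAS(FixedVectorType::get(PointerType::get(Ctx, 0), 4), 3)
//       --> <4 x ptr addrspace(3)>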
+
// Check whether that's a no-op pointer bitcast using a pair of
// `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over
// different address spaces.
@@ -301,14 +307,14 @@ static bool isAddressExpression(const Value &V, const DataLayout &DL,
switch (Op->getOpcode()) {
case Instruction::PHI:
- assert(Op->getType()->isPointerTy());
+ assert(Op->getType()->isPtrOrPtrVectorTy());
return true;
case Instruction::BitCast:
case Instruction::AddrSpaceCast:
case Instruction::GetElementPtr:
return true;
case Instruction::Select:
- return Op->getType()->isPointerTy();
+ return Op->getType()->isPtrOrPtrVectorTy();
case Instruction::Call: {
const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
return II && II->getIntrinsicID() == Intrinsic::ptrmask;
@@ -373,6 +379,24 @@ bool InferAddressSpacesImpl::rewriteIntrinsicOperands(IntrinsicInst *II,
case Intrinsic::ptrmask:
// This is handled as an address expression, not as a use memory operation.
return false;
+ case Intrinsic::masked_gather: {
+ Type *RetTy = II->getType();
+ Type *NewPtrTy = NewV->getType();
+ Function *NewDecl =
+ Intrinsic::getDeclaration(M, II->getIntrinsicID(), {RetTy, NewPtrTy});
+ II->setArgOperand(0, NewV);
+ II->setCalledFunction(NewDecl);
+ return true;
+ }
+ case Intrinsic::masked_scatter: {
+ Type *ValueTy = II->getOperand(0)->getType();
+ Type *NewPtrTy = NewV->getType();
+ Function *NewDecl =
+ Intrinsic::getDeclaration(M, II->getIntrinsicID(), {ValueTy, NewPtrTy});
+ II->setArgOperand(1, NewV);
+ II->setCalledFunction(NewDecl);
+ return true;
+ }
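These two intrinsics are overloaded on their pointer-vector type, which is why a fresh declaration has to be fetched rather than just swapping the operand. A hedged before/after sketch for a gather whose pointer vector is inferred to live in addrspace(3); the value names are illustrative:

//   before: call <4 x i32> @llvm.masked.gather.v4i32.v4p0(
//               <4 x ptr> %p, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
//   after:  call <4 x i32> @llvm.masked.gather.v4i32.v4p3(
//               <4 x ptr addrspace(3)> %p.new, i32 4, <4 x i1> %mask, <4 x i32> %passthru)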
default: {
Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
if (!Rewrite)
@@ -394,6 +418,14 @@ void InferAddressSpacesImpl::collectRewritableIntrinsicOperands(
appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
PostorderStack, Visited);
break;
+ case Intrinsic::masked_gather:
+ appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
+ PostorderStack, Visited);
+ break;
+ case Intrinsic::masked_scatter:
+ appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(1),
+ PostorderStack, Visited);
+ break;
default:
SmallVector<int, 2> OpIndexes;
if (TTI->collectFlatAddressOperands(OpIndexes, IID)) {
@@ -412,7 +444,7 @@ void InferAddressSpacesImpl::collectRewritableIntrinsicOperands(
void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack(
Value *V, PostorderStackTy &PostorderStack,
DenseSet<Value *> &Visited) const {
- assert(V->getType()->isPointerTy());
+ assert(V->getType()->isPtrOrPtrVectorTy());
// Generic addressing expressions may be hidden in nested constant
// expressions.
@@ -460,8 +492,7 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
// addressing calculations may also be faster.
for (Instruction &I : instructions(F)) {
if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
- if (!GEP->getType()->isVectorTy())
- PushPtrOperand(GEP->getPointerOperand());
+ PushPtrOperand(GEP->getPointerOperand());
} else if (auto *LI = dyn_cast<LoadInst>(&I))
PushPtrOperand(LI->getPointerOperand());
else if (auto *SI = dyn_cast<StoreInst>(&I))
@@ -480,14 +511,12 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
} else if (auto *II = dyn_cast<IntrinsicInst>(&I))
collectRewritableIntrinsicOperands(II, PostorderStack, Visited);
else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) {
- // FIXME: Handle vectors of pointers
- if (Cmp->getOperand(0)->getType()->isPointerTy()) {
+ if (Cmp->getOperand(0)->getType()->isPtrOrPtrVectorTy()) {
PushPtrOperand(Cmp->getOperand(0));
PushPtrOperand(Cmp->getOperand(1));
}
} else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) {
- if (!ASC->getType()->isVectorTy())
- PushPtrOperand(ASC->getPointerOperand());
+ PushPtrOperand(ASC->getPointerOperand());
} else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
PushPtrOperand(
@@ -521,16 +550,15 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
// A helper function for cloneInstructionWithNewAddressSpace. Returns the clone
// of OperandUse.get() in the new address space. If the clone is not ready yet,
-// returns an undef in the new address space as a placeholder.
-static Value *operandWithNewAddressSpaceOrCreateUndef(
+// returns poison in the new address space as a placeholder.
+static Value *operandWithNewAddressSpaceOrCreatePoison(
const Use &OperandUse, unsigned NewAddrSpace,
const ValueToValueMapTy &ValueWithNewAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS,
- SmallVectorImpl<const Use *> *UndefUsesToFix) {
+ SmallVectorImpl<const Use *> *PoisonUsesToFix) {
Value *Operand = OperandUse.get();
- Type *NewPtrTy = PointerType::getWithSamePointeeType(
- cast<PointerType>(Operand->getType()), NewAddrSpace);
+ Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAddrSpace);
if (Constant *C = dyn_cast<Constant>(Operand))
return ConstantExpr::getAddrSpaceCast(C, NewPtrTy);
@@ -543,23 +571,22 @@ static Value *operandWithNewAddressSpaceOrCreateUndef(
if (I != PredicatedAS.end()) {
// Insert an addrspacecast on that operand before the user.
unsigned NewAS = I->second;
- Type *NewPtrTy = PointerType::getWithSamePointeeType(
- cast<PointerType>(Operand->getType()), NewAS);
+ Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS);
auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy);
NewI->insertBefore(Inst);
NewI->setDebugLoc(Inst->getDebugLoc());
return NewI;
}
- UndefUsesToFix->push_back(&OperandUse);
- return UndefValue::get(NewPtrTy);
+ PoisonUsesToFix->push_back(&OperandUse);
+ return PoisonValue::get(NewPtrTy);
}
// Returns a clone of `I` with its operands converted to those specified in
// ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an
// operand whose address space needs to be modified might not exist in
-// ValueWithNewAddrSpace. In that case, uses undef as a placeholder operand and
-// adds that operand use to UndefUsesToFix so that caller can fix them later.
+// ValueWithNewAddrSpace. In that case, uses poison as a placeholder operand and
+// adds that operand use to PoisonUsesToFix so that caller can fix them later.
//
// Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast
// from a pointer whose type already matches. Therefore, this function returns a
@@ -571,9 +598,8 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
Instruction *I, unsigned NewAddrSpace,
const ValueToValueMapTy &ValueWithNewAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS,
- SmallVectorImpl<const Use *> *UndefUsesToFix) const {
- Type *NewPtrType = PointerType::getWithSamePointeeType(
- cast<PointerType>(I->getType()), NewAddrSpace);
+ SmallVectorImpl<const Use *> *PoisonUsesToFix) const {
+ Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace);
if (I->getOpcode() == Instruction::AddrSpaceCast) {
Value *Src = I->getOperand(0);
@@ -590,9 +616,9 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
// Technically the intrinsic ID is a pointer typed argument, so specially
// handle calls early.
assert(II->getIntrinsicID() == Intrinsic::ptrmask);
- Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef(
+ Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace,
- PredicatedAS, UndefUsesToFix);
+ PredicatedAS, PoisonUsesToFix);
Value *Rewrite =
TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr);
if (Rewrite) {
@@ -607,8 +633,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
if (AS != UninitializedAddressSpace) {
// For the assumed address space, insert an `addrspacecast` to make that
// explicit.
- Type *NewPtrTy = PointerType::getWithSamePointeeType(
- cast<PointerType>(I->getType()), AS);
+ Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS);
auto *NewI = new AddrSpaceCastInst(I, NewPtrTy);
NewI->insertAfter(I);
return NewI;
@@ -617,19 +642,19 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
// Computes the converted pointer operands.
SmallVector<Value *, 4> NewPointerOperands;
for (const Use &OperandUse : I->operands()) {
- if (!OperandUse.get()->getType()->isPointerTy())
+ if (!OperandUse.get()->getType()->isPtrOrPtrVectorTy())
NewPointerOperands.push_back(nullptr);
else
- NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef(
+ NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreatePoison(
OperandUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS,
- UndefUsesToFix));
+ PoisonUsesToFix));
}
switch (I->getOpcode()) {
case Instruction::BitCast:
return new BitCastInst(NewPointerOperands[0], NewPtrType);
case Instruction::PHI: {
- assert(I->getType()->isPointerTy());
+ assert(I->getType()->isPtrOrPtrVectorTy());
PHINode *PHI = cast<PHINode>(I);
PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues());
for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) {
@@ -648,7 +673,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
return NewGEP;
}
case Instruction::Select:
- assert(I->getType()->isPointerTy());
+ assert(I->getType()->isPtrOrPtrVectorTy());
return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],
NewPointerOperands[2], "", nullptr, I);
case Instruction::IntToPtr: {
@@ -674,10 +699,10 @@ static Value *cloneConstantExprWithNewAddressSpace(
ConstantExpr *CE, unsigned NewAddrSpace,
const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL,
const TargetTransformInfo *TTI) {
- Type *TargetType = CE->getType()->isPointerTy()
- ? PointerType::getWithSamePointeeType(
- cast<PointerType>(CE->getType()), NewAddrSpace)
- : CE->getType();
+ Type *TargetType =
+ CE->getType()->isPtrOrPtrVectorTy()
+ ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace)
+ : CE->getType();
if (CE->getOpcode() == Instruction::AddrSpaceCast) {
// Because CE is flat, the source address space must be specific.
@@ -694,18 +719,6 @@ static Value *cloneConstantExprWithNewAddressSpace(
return ConstantExpr::getAddrSpaceCast(CE, TargetType);
}
- if (CE->getOpcode() == Instruction::Select) {
- Constant *Src0 = CE->getOperand(1);
- Constant *Src1 = CE->getOperand(2);
- if (Src0->getType()->getPointerAddressSpace() ==
- Src1->getType()->getPointerAddressSpace()) {
-
- return ConstantExpr::getSelect(
- CE->getOperand(0), ConstantExpr::getAddrSpaceCast(Src0, TargetType),
- ConstantExpr::getAddrSpaceCast(Src1, TargetType));
- }
- }
-
if (CE->getOpcode() == Instruction::IntToPtr) {
assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI));
Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0);
@@ -758,19 +771,19 @@ static Value *cloneConstantExprWithNewAddressSpace(
// ValueWithNewAddrSpace. This function is called on every flat address
// expression whose address space needs to be modified, in postorder.
//
-// See cloneInstructionWithNewAddressSpace for the meaning of UndefUsesToFix.
+// See cloneInstructionWithNewAddressSpace for the meaning of PoisonUsesToFix.
Value *InferAddressSpacesImpl::cloneValueWithNewAddressSpace(
Value *V, unsigned NewAddrSpace,
const ValueToValueMapTy &ValueWithNewAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS,
- SmallVectorImpl<const Use *> *UndefUsesToFix) const {
+ SmallVectorImpl<const Use *> *PoisonUsesToFix) const {
// All values in Postorder are flat address expressions.
assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
isAddressExpression(*V, *DL, TTI));
if (Instruction *I = dyn_cast<Instruction>(V)) {
Value *NewV = cloneInstructionWithNewAddressSpace(
- I, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, UndefUsesToFix);
+ I, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, PoisonUsesToFix);
if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) {
if (NewI->getParent() == nullptr) {
NewI->insertBefore(I);
@@ -1114,7 +1127,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
// operands are converted, the clone is naturally in the new address space by
// construction.
ValueToValueMapTy ValueWithNewAddrSpace;
- SmallVector<const Use *, 32> UndefUsesToFix;
+ SmallVector<const Use *, 32> PoisonUsesToFix;
for (Value* V : Postorder) {
unsigned NewAddrSpace = InferredAddrSpace.lookup(V);
@@ -1126,7 +1139,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
Value *New =
cloneValueWithNewAddressSpace(V, NewAddrSpace, ValueWithNewAddrSpace,
- PredicatedAS, &UndefUsesToFix);
+ PredicatedAS, &PoisonUsesToFix);
if (New)
ValueWithNewAddrSpace[V] = New;
}
@@ -1135,16 +1148,16 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
if (ValueWithNewAddrSpace.empty())
return false;
- // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace.
- for (const Use *UndefUse : UndefUsesToFix) {
- User *V = UndefUse->getUser();
+ // Fixes all the poison uses generated by cloneInstructionWithNewAddressSpace.
+ for (const Use *PoisonUse : PoisonUsesToFix) {
+ User *V = PoisonUse->getUser();
User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V));
if (!NewV)
continue;
- unsigned OperandNo = UndefUse->getOperandNo();
- assert(isa<UndefValue>(NewV->getOperand(OperandNo)));
- NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get()));
+ unsigned OperandNo = PoisonUse->getOperandNo();
+ assert(isa<PoisonValue>(NewV->getOperand(OperandNo)));
+ NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(PoisonUse->get()));
}
SmallVector<Instruction *, 16> DeadInstructions;
@@ -1238,20 +1251,6 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
unsigned NewAS = NewV->getType()->getPointerAddressSpace();
if (ASC->getDestAddressSpace() == NewAS) {
- if (!cast<PointerType>(ASC->getType())
- ->hasSameElementTypeAs(
- cast<PointerType>(NewV->getType()))) {
- BasicBlock::iterator InsertPos;
- if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
- InsertPos = std::next(NewVInst->getIterator());
- else if (Instruction *VInst = dyn_cast<Instruction>(V))
- InsertPos = std::next(VInst->getIterator());
- else
- InsertPos = ASC->getIterator();
-
- NewV = CastInst::Create(Instruction::BitCast, NewV,
- ASC->getType(), "", &*InsertPos);
- }
ASC->replaceAllUsesWith(NewV);
DeadInstructions.push_back(ASC);
continue;
diff --git a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp
index 4644905adba3..ee9452ce1c7d 100644
--- a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp
+++ b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp
@@ -11,7 +11,6 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/InstructionSimplify.h"
-#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
@@ -26,8 +25,7 @@ using namespace llvm;
STATISTIC(NumSimplified, "Number of redundant instructions removed");
-static bool runImpl(Function &F, const SimplifyQuery &SQ,
- OptimizationRemarkEmitter *ORE) {
+static bool runImpl(Function &F, const SimplifyQuery &SQ) {
SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2;
bool Changed = false;
@@ -51,7 +49,7 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ,
DeadInstsInBB.push_back(&I);
Changed = true;
} else if (!I.use_empty()) {
- if (Value *V = simplifyInstruction(&I, SQ, ORE)) {
+ if (Value *V = simplifyInstruction(&I, SQ)) {
// Mark all uses for resimplification next time round the loop.
for (User *U : I.users())
Next->insert(cast<Instruction>(U));
@@ -88,7 +86,6 @@ struct InstSimplifyLegacyPass : public FunctionPass {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
}
/// Remove instructions that simplify.
@@ -102,11 +99,9 @@ struct InstSimplifyLegacyPass : public FunctionPass {
&getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
AssumptionCache *AC =
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- OptimizationRemarkEmitter *ORE =
- &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
const DataLayout &DL = F.getParent()->getDataLayout();
const SimplifyQuery SQ(DL, TLI, DT, AC);
- return runImpl(F, SQ, ORE);
+ return runImpl(F, SQ);
}
};
} // namespace
@@ -117,7 +112,6 @@ INITIALIZE_PASS_BEGIN(InstSimplifyLegacyPass, "instsimplify",
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(InstSimplifyLegacyPass, "instsimplify",
"Remove redundant instructions", false, false)
@@ -131,10 +125,9 @@ PreservedAnalyses InstSimplifyPass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
- auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
const DataLayout &DL = F.getParent()->getDataLayout();
const SimplifyQuery SQ(DL, &TLI, &DT, &AC);
- bool Changed = runImpl(F, SQ, &ORE);
+ bool Changed = runImpl(F, SQ);
if (!Changed)
return PreservedAnalyses::all();
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index f41eaed2e3e7..24390f1b54f6 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -23,7 +23,6 @@
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/ConstantFolding.h"
-#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -31,6 +30,7 @@
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -40,6 +40,7 @@
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
@@ -57,15 +58,12 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/BlockFrequency.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -114,68 +112,6 @@ static cl::opt<bool> ThreadAcrossLoopHeaders(
cl::desc("Allow JumpThreading to thread across loop headers, for testing"),
cl::init(false), cl::Hidden);
-
-namespace {
-
- /// This pass performs 'jump threading', which looks at blocks that have
- /// multiple predecessors and multiple successors. If one or more of the
- /// predecessors of the block can be proven to always jump to one of the
- /// successors, we forward the edge from the predecessor to the successor by
- /// duplicating the contents of this block.
- ///
- /// An example of when this can occur is code like this:
- ///
- /// if () { ...
- /// X = 4;
- /// }
- /// if (X < 3) {
- ///
- /// In this case, the unconditional branch at the end of the first if can be
- /// revectored to the false side of the second if.
- class JumpThreading : public FunctionPass {
- JumpThreadingPass Impl;
-
- public:
- static char ID; // Pass identification
-
- JumpThreading(int T = -1) : FunctionPass(ID), Impl(T) {
- initializeJumpThreadingPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<LazyValueInfoWrapperPass>();
- AU.addPreserved<LazyValueInfoWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- }
-
- void releaseMemory() override { Impl.releaseMemory(); }
- };
-
-} // end anonymous namespace
-
-char JumpThreading::ID = 0;
-
-INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading",
- "Jump Threading", false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_END(JumpThreading, "jump-threading",
- "Jump Threading", false, false)
-
-// Public interface to the Jump Threading pass
-FunctionPass *llvm::createJumpThreadingPass(int Threshold) {
- return new JumpThreading(Threshold);
-}
-
JumpThreadingPass::JumpThreadingPass(int T) {
DefaultBBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T);
}
@@ -306,102 +242,81 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
}
}
-/// runOnFunction - Toplevel algorithm.
-bool JumpThreading::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
- auto TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- // Jump Threading has no sense for the targets with divergent CF
- if (TTI->hasBranchDivergence())
- return false;
- auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
- auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy);
- std::unique_ptr<BlockFrequencyInfo> BFI;
- std::unique_ptr<BranchProbabilityInfo> BPI;
- if (F.hasProfileData()) {
- LoopInfo LI{*DT};
- BPI.reset(new BranchProbabilityInfo(F, LI, TLI));
- BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
- }
-
- bool Changed = Impl.runImpl(F, TLI, TTI, LVI, AA, &DTU, F.hasProfileData(),
- std::move(BFI), std::move(BPI));
- if (PrintLVIAfterJumpThreading) {
- dbgs() << "LVI for function '" << F.getName() << "':\n";
- LVI->printLVI(F, DTU.getDomTree(), dbgs());
- }
- return Changed;
-}
-
PreservedAnalyses JumpThreadingPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
// Jump Threading has no sense for the targets with divergent CF
- if (TTI.hasBranchDivergence())
+ if (TTI.hasBranchDivergence(&F))
return PreservedAnalyses::all();
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
- auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &LVI = AM.getResult<LazyValueAnalysis>(F);
auto &AA = AM.getResult<AAManager>(F);
- DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
-
- std::unique_ptr<BlockFrequencyInfo> BFI;
- std::unique_ptr<BranchProbabilityInfo> BPI;
- if (F.hasProfileData()) {
- LoopInfo LI{DT};
- BPI.reset(new BranchProbabilityInfo(F, LI, &TLI));
- BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
- }
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
- bool Changed = runImpl(F, &TLI, &TTI, &LVI, &AA, &DTU, F.hasProfileData(),
- std::move(BFI), std::move(BPI));
+ bool Changed =
+ runImpl(F, &AM, &TLI, &TTI, &LVI, &AA,
+ std::make_unique<DomTreeUpdater>(
+ &DT, nullptr, DomTreeUpdater::UpdateStrategy::Lazy),
+ std::nullopt, std::nullopt);
if (PrintLVIAfterJumpThreading) {
dbgs() << "LVI for function '" << F.getName() << "':\n";
- LVI.printLVI(F, DTU.getDomTree(), dbgs());
+ LVI.printLVI(F, getDomTreeUpdater()->getDomTree(), dbgs());
}
if (!Changed)
return PreservedAnalyses::all();
- PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
- PA.preserve<LazyValueAnalysis>();
- return PA;
+
+
+ getDomTreeUpdater()->flush();
+
+#if defined(EXPENSIVE_CHECKS)
+ assert(getDomTreeUpdater()->getDomTree().verify(
+ DominatorTree::VerificationLevel::Full) &&
+ "DT broken after JumpThreading");
+ assert((!getDomTreeUpdater()->hasPostDomTree() ||
+ getDomTreeUpdater()->getPostDomTree().verify(
+ PostDominatorTree::VerificationLevel::Full)) &&
+ "PDT broken after JumpThreading");
+#else
+ assert(getDomTreeUpdater()->getDomTree().verify(
+ DominatorTree::VerificationLevel::Fast) &&
+ "DT broken after JumpThreading");
+ assert((!getDomTreeUpdater()->hasPostDomTree() ||
+ getDomTreeUpdater()->getPostDomTree().verify(
+ PostDominatorTree::VerificationLevel::Fast)) &&
+ "PDT broken after JumpThreading");
+#endif
+
+ return getPreservedAnalysis();
}
-bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
+bool JumpThreadingPass::runImpl(Function &F_, FunctionAnalysisManager *FAM_,
+ TargetLibraryInfo *TLI_,
TargetTransformInfo *TTI_, LazyValueInfo *LVI_,
- AliasAnalysis *AA_, DomTreeUpdater *DTU_,
- bool HasProfileData_,
- std::unique_ptr<BlockFrequencyInfo> BFI_,
- std::unique_ptr<BranchProbabilityInfo> BPI_) {
- LLVM_DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n");
+ AliasAnalysis *AA_,
+ std::unique_ptr<DomTreeUpdater> DTU_,
+ std::optional<BlockFrequencyInfo *> BFI_,
+ std::optional<BranchProbabilityInfo *> BPI_) {
+ LLVM_DEBUG(dbgs() << "Jump threading on function '" << F_.getName() << "'\n");
+ F = &F_;
+ FAM = FAM_;
TLI = TLI_;
TTI = TTI_;
LVI = LVI_;
AA = AA_;
- DTU = DTU_;
- BFI.reset();
- BPI.reset();
- // When profile data is available, we need to update edge weights after
- // successful jump threading, which requires both BPI and BFI being available.
- HasProfileData = HasProfileData_;
- auto *GuardDecl = F.getParent()->getFunction(
+ DTU = std::move(DTU_);
+ BFI = BFI_;
+ BPI = BPI_;
+ auto *GuardDecl = F->getParent()->getFunction(
Intrinsic::getName(Intrinsic::experimental_guard));
HasGuards = GuardDecl && !GuardDecl->use_empty();
- if (HasProfileData) {
- BPI = std::move(BPI_);
- BFI = std::move(BFI_);
- }
// Reduce the number of instructions duplicated when optimizing strictly for
// size.
if (BBDuplicateThreshold.getNumOccurrences())
BBDupThreshold = BBDuplicateThreshold;
- else if (F.hasFnAttribute(Attribute::MinSize))
+ else if (F->hasFnAttribute(Attribute::MinSize))
BBDupThreshold = 3;
else
BBDupThreshold = DefaultBBDupThreshold;
@@ -412,22 +327,22 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
assert(DTU && "DTU isn't passed into JumpThreading before using it.");
assert(DTU->hasDomTree() && "JumpThreading relies on DomTree to proceed.");
DominatorTree &DT = DTU->getDomTree();
- for (auto &BB : F)
+ for (auto &BB : *F)
if (!DT.isReachableFromEntry(&BB))
Unreachable.insert(&BB);
if (!ThreadAcrossLoopHeaders)
- findLoopHeaders(F);
+ findLoopHeaders(*F);
bool EverChanged = false;
bool Changed;
do {
Changed = false;
- for (auto &BB : F) {
+ for (auto &BB : *F) {
if (Unreachable.count(&BB))
continue;
while (processBlock(&BB)) // Thread all of the branches we can over BB.
- Changed = true;
+ Changed = ChangedSinceLastAnalysisUpdate = true;
// Jump threading may have introduced redundant debug values into BB
// which should be removed.
@@ -437,7 +352,7 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
// Stop processing BB if it's the entry or is now deleted. The following
// routines attempt to eliminate BB and locating a suitable replacement
// for the entry is non-trivial.
- if (&BB == &F.getEntryBlock() || DTU->isBBPendingDeletion(&BB))
+ if (&BB == &F->getEntryBlock() || DTU->isBBPendingDeletion(&BB))
continue;
if (pred_empty(&BB)) {
@@ -448,8 +363,8 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
<< '\n');
LoopHeaders.erase(&BB);
LVI->eraseBlock(&BB);
- DeleteDeadBlock(&BB, DTU);
- Changed = true;
+ DeleteDeadBlock(&BB, DTU.get());
+ Changed = ChangedSinceLastAnalysisUpdate = true;
continue;
}
@@ -464,12 +379,12 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
// Don't alter Loop headers and latches to ensure another pass can
// detect and transform nested loops later.
!LoopHeaders.count(&BB) && !LoopHeaders.count(Succ) &&
- TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) {
+ TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU.get())) {
RemoveRedundantDbgInstrs(Succ);
// BB is valid for cleanup here because we passed in DTU. F remains
// BB's parent until a DTU->getDomTree() event.
LVI->eraseBlock(&BB);
- Changed = true;
+ Changed = ChangedSinceLastAnalysisUpdate = true;
}
}
}
@@ -1140,8 +1055,8 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) {
<< "' folding terminator: " << *BB->getTerminator()
<< '\n');
++NumFolds;
- ConstantFoldTerminator(BB, true, nullptr, DTU);
- if (HasProfileData)
+ ConstantFoldTerminator(BB, true, nullptr, DTU.get());
+ if (auto *BPI = getBPI())
BPI->eraseBlock(BB);
return true;
}
@@ -1296,7 +1211,7 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) {
FICond->eraseFromParent();
DTU->applyUpdatesPermissive({{DominatorTree::Delete, BB, RemoveSucc}});
- if (HasProfileData)
+ if (auto *BPI = getBPI())
BPI->eraseBlock(BB);
return true;
}
@@ -1740,7 +1655,7 @@ bool JumpThreadingPass::processThreadableEdges(Value *Cond, BasicBlock *BB,
++NumFolds;
Term->eraseFromParent();
DTU->applyUpdatesPermissive(Updates);
- if (HasProfileData)
+ if (auto *BPI = getBPI())
BPI->eraseBlock(BB);
// If the condition is now dead due to the removal of the old terminator,
@@ -1993,7 +1908,7 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) {
LoopHeaders.insert(BB);
LVI->eraseBlock(SinglePred);
- MergeBasicBlockIntoOnlyPred(BB, DTU);
+ MergeBasicBlockIntoOnlyPred(BB, DTU.get());
// Now that BB is merged into SinglePred (i.e. SinglePred code followed by
// BB code within one basic block `BB`), we need to invalidate the LVI
@@ -2038,6 +1953,7 @@ void JumpThreadingPass::updateSSA(
// PHI insertion, of which we are prepared to do, clean these up now.
SSAUpdater SSAUpdate;
SmallVector<Use *, 16> UsesToRename;
+ SmallVector<DbgValueInst *, 4> DbgValues;
for (Instruction &I : *BB) {
// Scan all uses of this instruction to see if it is used outside of its
@@ -2053,8 +1969,16 @@ void JumpThreadingPass::updateSSA(
UsesToRename.push_back(&U);
}
+ // Find debug values outside of the block
+ findDbgValues(DbgValues, &I);
+ DbgValues.erase(remove_if(DbgValues,
+ [&](const DbgValueInst *DbgVal) {
+ return DbgVal->getParent() == BB;
+ }),
+ DbgValues.end());
+
// If there are no uses outside the block, we're done with this instruction.
- if (UsesToRename.empty())
+ if (UsesToRename.empty() && DbgValues.empty())
continue;
LLVM_DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n");
@@ -2067,6 +1991,11 @@ void JumpThreadingPass::updateSSA(
while (!UsesToRename.empty())
SSAUpdate.RewriteUse(*UsesToRename.pop_back_val());
+ if (!DbgValues.empty()) {
+ SSAUpdate.UpdateDebugValues(&I, DbgValues);
+ DbgValues.clear();
+ }
+
LLVM_DEBUG(dbgs() << "\n");
}
}
@@ -2298,6 +2227,11 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB,
LLVM_DEBUG(dbgs() << " Threading through '" << PredBB->getName() << "' and '"
<< BB->getName() << "'\n");
+ // Build BPI/BFI before any changes are made to IR.
+ bool HasProfile = doesBlockHaveProfileData(BB);
+ auto *BFI = getOrCreateBFI(HasProfile);
+ auto *BPI = getOrCreateBPI(BFI != nullptr);
+
BranchInst *CondBr = cast<BranchInst>(BB->getTerminator());
BranchInst *PredBBBranch = cast<BranchInst>(PredBB->getTerminator());
@@ -2307,7 +2241,8 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB,
NewBB->moveAfter(PredBB);
// Set the block frequency of NewBB.
- if (HasProfileData) {
+ if (BFI) {
+ assert(BPI && "It's expected BPI to exist along with BFI");
auto NewBBFreq = BFI->getBlockFreq(PredPredBB) *
BPI->getEdgeProbability(PredPredBB, PredBB);
BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
@@ -2320,7 +2255,7 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB,
cloneInstructions(PredBB->begin(), PredBB->end(), NewBB, PredPredBB);
// Copy the edge probabilities from PredBB to NewBB.
- if (HasProfileData)
+ if (BPI)
BPI->copyEdgeProbabilities(PredBB, NewBB);
// Update the terminator of PredPredBB to jump to NewBB instead of PredBB.
@@ -2404,6 +2339,11 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB,
assert(!LoopHeaders.count(BB) && !LoopHeaders.count(SuccBB) &&
"Don't thread across loop headers");
+ // Build BPI/BFI before any changes are made to IR.
+ bool HasProfile = doesBlockHaveProfileData(BB);
+ auto *BFI = getOrCreateBFI(HasProfile);
+ auto *BPI = getOrCreateBPI(BFI != nullptr);
+
// And finally, do it! Start by factoring the predecessors if needed.
BasicBlock *PredBB;
if (PredBBs.size() == 1)
@@ -2427,7 +2367,8 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB,
NewBB->moveAfter(PredBB);
// Set the block frequency of NewBB.
- if (HasProfileData) {
+ if (BFI) {
+ assert(BPI && "It's expected BPI to exist along with BFI");
auto NewBBFreq =
BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB);
BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
@@ -2469,7 +2410,7 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB,
SimplifyInstructionsInBlock(NewBB, TLI);
// Update the edge weight from BB to SuccBB, which should be less than before.
- updateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB);
+ updateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB, BFI, BPI, HasProfile);
// Threaded an edge!
++NumThreads;
@@ -2486,10 +2427,13 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB,
// Collect the frequencies of all predecessors of BB, which will be used to
// update the edge weight of the result of splitting predecessors.
DenseMap<BasicBlock *, BlockFrequency> FreqMap;
- if (HasProfileData)
+ auto *BFI = getBFI();
+ if (BFI) {
+ auto *BPI = getOrCreateBPI(true);
for (auto *Pred : Preds)
FreqMap.insert(std::make_pair(
Pred, BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB)));
+ }
// In the case when BB is a LandingPad block we create 2 new predecessors
// instead of just one.
@@ -2508,10 +2452,10 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB,
for (auto *Pred : predecessors(NewBB)) {
Updates.push_back({DominatorTree::Delete, Pred, BB});
Updates.push_back({DominatorTree::Insert, Pred, NewBB});
- if (HasProfileData) // Update frequencies between Pred -> NewBB.
+ if (BFI) // Update frequencies between Pred -> NewBB.
NewBBFreq += FreqMap.lookup(Pred);
}
- if (HasProfileData) // Apply the summed frequency to NewBB.
+ if (BFI) // Apply the summed frequency to NewBB.
BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
}
@@ -2521,7 +2465,9 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB,
bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) {
const Instruction *TI = BB->getTerminator();
- assert(TI->getNumSuccessors() > 1 && "not a split");
+ if (!TI || TI->getNumSuccessors() < 2)
+ return false;
+
return hasValidBranchWeightMD(*TI);
}
@@ -2531,11 +2477,18 @@ bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) {
void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
BasicBlock *BB,
BasicBlock *NewBB,
- BasicBlock *SuccBB) {
- if (!HasProfileData)
+ BasicBlock *SuccBB,
+ BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI,
+ bool HasProfile) {
+ assert(((BFI && BPI) || (!BFI && !BPI)) &&
+ "Both BFI & BPI should either be set or unset");
+
+ if (!BFI) {
+ assert(!HasProfile &&
+ "It's expected to have BFI/BPI when profile info exists");
return;
-
- assert(BFI && BPI && "BFI & BPI should have been created here");
+ }
// As the edge from PredBB to BB is deleted, we have to update the block
// frequency of BB.
@@ -2608,7 +2561,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
// FIXME this locally as well so that BPI and BFI are consistent as well. We
// shouldn't make edges extremely likely or unlikely based solely on static
// estimation.
- if (BBSuccProbs.size() >= 2 && doesBlockHaveProfileData(BB)) {
+ if (BBSuccProbs.size() >= 2 && HasProfile) {
SmallVector<uint32_t, 4> Weights;
for (auto Prob : BBSuccProbs)
Weights.push_back(Prob.getNumerator());
@@ -2690,6 +2643,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
// mapping and using it to remap operands in the cloned instructions.
for (; BI != BB->end(); ++BI) {
Instruction *New = BI->clone();
+ New->insertInto(PredBB, OldPredBranch->getIterator());
// Remap operands to patch up intra-block references.
for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i)
@@ -2707,7 +2661,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
{BB->getModule()->getDataLayout(), TLI, nullptr, nullptr, New})) {
ValueMapping[&*BI] = IV;
if (!New->mayHaveSideEffects()) {
- New->deleteValue();
+ New->eraseFromParent();
New = nullptr;
}
} else {
@@ -2716,7 +2670,6 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
if (New) {
// Otherwise, insert the new instruction into the block.
New->setName(BI->getName());
- New->insertInto(PredBB, OldPredBranch->getIterator());
// Update Dominance from simplified New instruction operands.
for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i)
if (BasicBlock *SuccBB = dyn_cast<BasicBlock>(New->getOperand(i)))
@@ -2740,7 +2693,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
// Remove the unconditional branch at the end of the PredBB block.
OldPredBranch->eraseFromParent();
- if (HasProfileData)
+ if (auto *BPI = getBPI())
BPI->copyEdgeProbabilities(BB, PredBB);
DTU->applyUpdatesPermissive(Updates);
@@ -2777,21 +2730,30 @@ void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
BI->copyMetadata(*SI, {LLVMContext::MD_prof});
SIUse->setIncomingValue(Idx, SI->getFalseValue());
SIUse->addIncoming(SI->getTrueValue(), NewBB);
- // Set the block frequency of NewBB.
- if (HasProfileData) {
- uint64_t TrueWeight, FalseWeight;
- if (extractBranchWeights(*SI, TrueWeight, FalseWeight) &&
- (TrueWeight + FalseWeight) != 0) {
- SmallVector<BranchProbability, 2> BP;
- BP.emplace_back(BranchProbability::getBranchProbability(
- TrueWeight, TrueWeight + FalseWeight));
- BP.emplace_back(BranchProbability::getBranchProbability(
- FalseWeight, TrueWeight + FalseWeight));
+
+ uint64_t TrueWeight = 1;
+ uint64_t FalseWeight = 1;
+ // Copy probabilities from 'SI' to the created conditional branch in 'Pred'.
+ if (extractBranchWeights(*SI, TrueWeight, FalseWeight) &&
+ (TrueWeight + FalseWeight) != 0) {
+ SmallVector<BranchProbability, 2> BP;
+ BP.emplace_back(BranchProbability::getBranchProbability(
+ TrueWeight, TrueWeight + FalseWeight));
+ BP.emplace_back(BranchProbability::getBranchProbability(
+ FalseWeight, TrueWeight + FalseWeight));
+ // Update BPI if exists.
+ if (auto *BPI = getBPI())
BPI->setEdgeProbability(Pred, BP);
+ }
+ // Set the block frequency of NewBB.
+ if (auto *BFI = getBFI()) {
+ if ((TrueWeight + FalseWeight) == 0) {
+ TrueWeight = 1;
+ FalseWeight = 1;
}
-
- auto NewBBFreq =
- BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, NewBB);
+ BranchProbability PredToNewBBProb = BranchProbability::getBranchProbability(
+ TrueWeight, TrueWeight + FalseWeight);
+ auto NewBBFreq = BFI->getBlockFreq(Pred) * PredToNewBBProb;
BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
}
@@ -3112,3 +3074,93 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard,
}
return true;
}
+
+PreservedAnalyses JumpThreadingPass::getPreservedAnalysis() const {
+ PreservedAnalyses PA;
+ PA.preserve<LazyValueAnalysis>();
+ PA.preserve<DominatorTreeAnalysis>();
+
+ // TODO: We would like to preserve BPI/BFI. Enable once all paths update them.
+ // TODO: Would be nice to verify BPI/BFI consistency as well.
+ return PA;
+}
+
+template <typename AnalysisT>
+typename AnalysisT::Result *JumpThreadingPass::runExternalAnalysis() {
+ assert(FAM && "Can't run external analysis without FunctionAnalysisManager");
+
+ // If there have been no changes since the last call to 'runExternalAnalysis',
+ // then all analyses are either up to date or explicitly invalidated. Just go
+ // ahead and run the "external" analysis.
+ if (!ChangedSinceLastAnalysisUpdate) {
+ assert(!DTU->hasPendingUpdates() &&
+ "Lost update of 'ChangedSinceLastAnalysisUpdate'?");
+ // Run the "external" analysis.
+ return &FAM->getResult<AnalysisT>(*F);
+ }
+ ChangedSinceLastAnalysisUpdate = false;
+
+ auto PA = getPreservedAnalysis();
+ // TODO: This shouldn't be needed once 'getPreservedAnalysis' reports BPI/BFI
+ // as preserved.
+ PA.preserve<BranchProbabilityAnalysis>();
+ PA.preserve<BlockFrequencyAnalysis>();
+ // Report everything except explicitly preserved as invalid.
+ FAM->invalidate(*F, PA);
+ // Update DT/PDT.
+ DTU->flush();
+ // Make sure DT/PDT are valid before running "external" analysis.
+ assert(DTU->getDomTree().verify(DominatorTree::VerificationLevel::Fast));
+ assert((!DTU->hasPostDomTree() ||
+ DTU->getPostDomTree().verify(
+ PostDominatorTree::VerificationLevel::Fast)));
+ // Run the "external" analysis.
+ auto *Result = &FAM->getResult<AnalysisT>(*F);
+ // Update analysis JumpThreading depends on and not explicitly preserved.
+ TTI = &FAM->getResult<TargetIRAnalysis>(*F);
+ TLI = &FAM->getResult<TargetLibraryAnalysis>(*F);
+ AA = &FAM->getResult<AAManager>(*F);
+
+ return Result;
+}
+
+BranchProbabilityInfo *JumpThreadingPass::getBPI() {
+ if (!BPI) {
+ assert(FAM && "Can't create BPI without FunctionAnalysisManager");
+ BPI = FAM->getCachedResult<BranchProbabilityAnalysis>(*F);
+ }
+ return *BPI;
+}
+
+BlockFrequencyInfo *JumpThreadingPass::getBFI() {
+ if (!BFI) {
+ assert(FAM && "Can't create BFI without FunctionAnalysisManager");
+ BFI = FAM->getCachedResult<BlockFrequencyAnalysis>(*F);
+ }
+ return *BFI;
+}
+
+// Important note on the validity of BPI/BFI. JumpThreading tries to preserve
+// BPI/BFI as it goes. Thus, if a cached instance exists, it will be updated.
+// Otherwise, a new instance of BPI/BFI is created (up to date by definition).
+BranchProbabilityInfo *JumpThreadingPass::getOrCreateBPI(bool Force) {
+ auto *Res = getBPI();
+ if (Res)
+ return Res;
+
+ if (Force)
+ BPI = runExternalAnalysis<BranchProbabilityAnalysis>();
+
+ return *BPI;
+}
+
+BlockFrequencyInfo *JumpThreadingPass::getOrCreateBFI(bool Force) {
+ auto *Res = getBFI();
+ if (Res)
+ return Res;
+
+ if (Force)
+ BFI = runExternalAnalysis<BlockFrequencyAnalysis>();
+
+ return *BFI;
+}
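Taken together, the intended calling pattern matches what the threading entry points above already do: decide whether profile data is present, force BFI only in that case, and require BPI whenever BFI exists. Roughly:

//   bool HasProfile = doesBlockHaveProfileData(BB);
//   auto *BFI = getOrCreateBFI(HasProfile);     // computed only when forced
//   auto *BPI = getOrCreateBPI(BFI != nullptr); // BPI must accompany BFI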
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 2865dece8723..f8fab03f151d 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -44,7 +44,6 @@
#include "llvm/Analysis/AliasSetTracker.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CaptureTracking.h"
-#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
#include "llvm/Analysis/Loads.h"
@@ -68,6 +67,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PatternMatch.h"
@@ -102,6 +102,12 @@ STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk");
STATISTIC(NumPromotionCandidates, "Number of promotion candidates");
STATISTIC(NumLoadPromoted, "Number of load-only promotions");
STATISTIC(NumLoadStorePromoted, "Number of load and store promotions");
+STATISTIC(NumMinMaxHoisted,
+ "Number of min/max expressions hoisted out of the loop");
+STATISTIC(NumGEPsHoisted,
+ "Number of geps reassociated and hoisted out of the loop");
+STATISTIC(NumAddSubHoisted, "Number of add/subtract expressions reassociated "
+ "and hoisted out of the loop");
/// Memory promotion is enabled by default.
static cl::opt<bool>
@@ -145,10 +151,10 @@ cl::opt<unsigned> llvm::SetLicmMssaNoAccForPromotionCap(
"enable memory promotion."));
static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI);
-static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop,
- const LoopSafetyInfo *SafetyInfo,
- TargetTransformInfo *TTI, bool &FreeInLoop,
- bool LoopNestMode);
+static bool isNotUsedOrFoldableInLoop(const Instruction &I, const Loop *CurLoop,
+ const LoopSafetyInfo *SafetyInfo,
+ TargetTransformInfo *TTI,
+ bool &FoldableInLoop, bool LoopNestMode);
static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop,
BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo,
MemorySSAUpdater &MSSAU, ScalarEvolution *SE,
@@ -163,9 +169,15 @@ static bool isSafeToExecuteUnconditionally(
AssumptionCache *AC, bool AllowSpeculation);
static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU,
Loop *CurLoop, Instruction &I,
- SinkAndHoistLICMFlags &Flags);
+ SinkAndHoistLICMFlags &Flags,
+ bool InvariantGroup);
static bool pointerInvalidatedByBlock(BasicBlock &BB, MemorySSA &MSSA,
MemoryUse &MU);
+/// Aggregates various functions for hoisting computations out of the loop.
+static bool hoistArithmetics(Instruction &I, Loop &L,
+ ICFLoopSafetyInfo &SafetyInfo,
+ MemorySSAUpdater &MSSAU, AssumptionCache *AC,
+ DominatorTree *DT);
static Instruction *cloneInstructionInExitBlock(
Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI,
const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU);
@@ -280,9 +292,6 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM,
return PreservedAnalyses::all();
auto PA = getLoopPassPreservedAnalyses();
-
- PA.preserve<DominatorTreeAnalysis>();
- PA.preserve<LoopAnalysis>();
PA.preserve<MemorySSAAnalysis>();
return PA;
@@ -293,9 +302,9 @@ void LICMPass::printPipeline(
static_cast<PassInfoMixin<LICMPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
OS << (Opts.AllowSpeculation ? "" : "no-") << "allowspeculation";
- OS << ">";
+ OS << '>';
}
PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM,
@@ -334,9 +343,9 @@ void LNICMPass::printPipeline(
static_cast<PassInfoMixin<LNICMPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
OS << (Opts.AllowSpeculation ? "" : "no-") << "allowspeculation";
- OS << ">";
+ OS << '>';
}
char LegacyLICMPass::ID = 0;
@@ -351,32 +360,21 @@ INITIALIZE_PASS_END(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false,
false)
Pass *llvm::createLICMPass() { return new LegacyLICMPass(); }
-Pass *llvm::createLICMPass(unsigned LicmMssaOptCap,
- unsigned LicmMssaNoAccForPromotionCap,
- bool LicmAllowSpeculation) {
- return new LegacyLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap,
- LicmAllowSpeculation);
-}
-llvm::SinkAndHoistLICMFlags::SinkAndHoistLICMFlags(bool IsSink, Loop *L,
- MemorySSA *MSSA)
+llvm::SinkAndHoistLICMFlags::SinkAndHoistLICMFlags(bool IsSink, Loop &L,
+ MemorySSA &MSSA)
: SinkAndHoistLICMFlags(SetLicmMssaOptCap, SetLicmMssaNoAccForPromotionCap,
IsSink, L, MSSA) {}
llvm::SinkAndHoistLICMFlags::SinkAndHoistLICMFlags(
unsigned LicmMssaOptCap, unsigned LicmMssaNoAccForPromotionCap, bool IsSink,
- Loop *L, MemorySSA *MSSA)
+ Loop &L, MemorySSA &MSSA)
: LicmMssaOptCap(LicmMssaOptCap),
LicmMssaNoAccForPromotionCap(LicmMssaNoAccForPromotionCap),
IsSink(IsSink) {
- assert(((L != nullptr) == (MSSA != nullptr)) &&
- "Unexpected values for SinkAndHoistLICMFlags");
- if (!MSSA)
- return;
-
unsigned AccessCapCount = 0;
- for (auto *BB : L->getBlocks())
- if (const auto *Accesses = MSSA->getBlockAccesses(BB))
+ for (auto *BB : L.getBlocks())
+ if (const auto *Accesses = MSSA.getBlockAccesses(BB))
for (const auto &MA : *Accesses) {
(void)MA;
++AccessCapCount;
@@ -400,7 +398,6 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,
bool Changed = false;
assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form.");
- MSSA->ensureOptimizedUses();
// If this loop has metadata indicating that LICM is not to be performed then
// just exit.
@@ -426,7 +423,7 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,
MemorySSAUpdater MSSAU(MSSA);
SinkAndHoistLICMFlags Flags(LicmMssaOptCap, LicmMssaNoAccForPromotionCap,
- /*IsSink=*/true, L, MSSA);
+ /*IsSink=*/true, *L, *MSSA);
// Get the preheader block to move instructions into...
BasicBlock *Preheader = L->getLoopPreheader();
@@ -581,14 +578,15 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
// outside of the loop. In this case, it doesn't even matter if the
// operands of the instruction are loop invariant.
//
- bool FreeInLoop = false;
+ bool FoldableInLoop = false;
bool LoopNestMode = OutermostLoop != nullptr;
if (!I.mayHaveSideEffects() &&
- isNotUsedOrFreeInLoop(I, LoopNestMode ? OutermostLoop : CurLoop,
- SafetyInfo, TTI, FreeInLoop, LoopNestMode) &&
+ isNotUsedOrFoldableInLoop(I, LoopNestMode ? OutermostLoop : CurLoop,
+ SafetyInfo, TTI, FoldableInLoop,
+ LoopNestMode) &&
canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE)) {
if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) {
- if (!FreeInLoop) {
+ if (!FoldableInLoop) {
++II;
salvageDebugInfo(I);
eraseInstruction(I, *SafetyInfo, MSSAU);
@@ -881,6 +879,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
LoopBlocksRPO Worklist(CurLoop);
Worklist.perform(LI);
bool Changed = false;
+ BasicBlock *Preheader = CurLoop->getLoopPreheader();
for (BasicBlock *BB : Worklist) {
// Only need to process the contents of this block if it is not part of a
// subloop (which would already have been processed).
@@ -888,21 +887,6 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
continue;
for (Instruction &I : llvm::make_early_inc_range(*BB)) {
- // Try constant folding this instruction. If all the operands are
- // constants, it is technically hoistable, but it would be better to
- // just fold it.
- if (Constant *C = ConstantFoldInstruction(
- &I, I.getModule()->getDataLayout(), TLI)) {
- LLVM_DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C
- << '\n');
- // FIXME MSSA: Such replacements may make accesses unoptimized (D51960).
- I.replaceAllUsesWith(C);
- if (isInstructionTriviallyDead(&I, TLI))
- eraseInstruction(I, *SafetyInfo, MSSAU);
- Changed = true;
- continue;
- }
-
// Try hoisting the instruction out to the preheader. We can only do
// this if all of the operands of the instruction are loop invariant and
// if it is safe to hoist the instruction. We also check block frequency
@@ -914,8 +898,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE) &&
isSafeToExecuteUnconditionally(
I, DT, TLI, CurLoop, SafetyInfo, ORE,
- CurLoop->getLoopPreheader()->getTerminator(), AC,
- AllowSpeculation)) {
+ Preheader->getTerminator(), AC, AllowSpeculation)) {
hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo,
MSSAU, SE, ORE);
HoistedInstructions.push_back(&I);
@@ -983,6 +966,13 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
}
}
+ // Try to reassociate instructions so that part of computations can be
+ // done out of loop.
+ if (hoistArithmetics(I, *CurLoop, *SafetyInfo, MSSAU, AC, DT)) {
+ Changed = true;
+ continue;
+ }
+
// Remember possibly hoistable branches so we can actually hoist them
// later if needed.
if (BranchInst *BI = dyn_cast<BranchInst>(&I))
@@ -1147,6 +1137,20 @@ bool isOnlyMemoryAccess(const Instruction *I, const Loop *L,
}
}
+static MemoryAccess *getClobberingMemoryAccess(MemorySSA &MSSA,
+ BatchAAResults &BAA,
+ SinkAndHoistLICMFlags &Flags,
+ MemoryUseOrDef *MA) {
+ // See declaration of SetLicmMssaOptCap for usage details.
+ if (Flags.tooManyClobberingCalls())
+ return MA->getDefiningAccess();
+
+ MemoryAccess *Source =
+ MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(MA, BAA);
+ Flags.incrementClobberingCalls();
+ return Source;
+}
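+// Illustrative usage sketch (editorial note, not part of the upstream patch):
+// callers below pair this helper with a BatchAAResults instance and then test
+// whether the returned clobber lies inside the current loop, e.g.:
+//   BatchAAResults BAA(*AA);
+//   MemoryAccess *Clobber =
+//       getClobberingMemoryAccess(*MSSA, BAA, Flags, MSSA->getMemoryAccess(SI));
+//   bool ClobberedInLoop = !MSSA->isLiveOnEntryDef(Clobber) &&
+//                          CurLoop->contains(Clobber->getBlock());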
+
bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
Loop *CurLoop, MemorySSAUpdater &MSSAU,
bool TargetExecutesOncePerLoop,
@@ -1176,8 +1180,12 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
if (isLoadInvariantInLoop(LI, DT, CurLoop))
return true;
+ auto MU = cast<MemoryUse>(MSSA->getMemoryAccess(LI));
+
+ bool InvariantGroup = LI->hasMetadata(LLVMContext::MD_invariant_group);
+
bool Invalidated = pointerInvalidatedByLoop(
- MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop, I, Flags);
+ MSSA, MU, CurLoop, I, Flags, InvariantGroup);
// Check loop-invariant address because this may also be a sinkable load
// whose address is not necessarily loop-invariant.
if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand()))
@@ -1210,12 +1218,17 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
// Assumes don't actually alias anything or throw
return true;
- if (match(CI, m_Intrinsic<Intrinsic::experimental_widenable_condition>()))
- // Widenable conditions don't actually alias anything or throw
- return true;
-
// Handle simple cases by querying alias analysis.
MemoryEffects Behavior = AA->getMemoryEffects(CI);
+
+ // FIXME: we don't fully handle the semantics of thread locals, so the
+ // addresses of thread locals appear as fake constants in coroutines. For
+ // now, refuse to treat onlyReadsMemory calls in coroutines as constant,
+ // since a thread-local access may be hidden inside an onlyReadsMemory call.
+ // Remove this check once thread-local semantics are handled properly.
+ if (Behavior.onlyReadsMemory() && CI->getFunction()->isPresplitCoroutine())
+ return false;
+
if (Behavior.doesNotAccessMemory())
return true;
if (Behavior.onlyReadsMemory()) {
@@ -1228,7 +1241,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
if (Op->getType()->isPointerTy() &&
pointerInvalidatedByLoop(
MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(CI)), CurLoop, I,
- Flags))
+ Flags, /*InvariantGroup=*/false))
return false;
return true;
}
@@ -1258,21 +1271,30 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
// arbitrary number of reads in the loop.
if (isOnlyMemoryAccess(SI, CurLoop, MSSAU))
return true;
- // If there are more accesses than the Promotion cap or no "quota" to
- // check clobber, then give up as we're not walking a list that long.
- if (Flags.tooManyMemoryAccesses() || Flags.tooManyClobberingCalls())
+ // If there are more accesses than the Promotion cap, then give up as we're
+ // not walking a list that long.
+ if (Flags.tooManyMemoryAccesses())
+ return false;
+
+ auto *SIMD = MSSA->getMemoryAccess(SI);
+ BatchAAResults BAA(*AA);
+ auto *Source = getClobberingMemoryAccess(*MSSA, BAA, Flags, SIMD);
+ // Make sure there are no clobbers inside the loop.
+ if (!MSSA->isLiveOnEntryDef(Source) &&
+ CurLoop->contains(Source->getBlock()))
return false;
+
// If there are interfering Uses (i.e. their defining access is in the
// loop), or ordered loads (stored as Defs!), don't move this store.
// Could do better here, but this is conservatively correct.
// TODO: Cache set of Uses on the first walk in runOnLoop, update when
// moving accesses. Can also extend to dominating uses.
- auto *SIMD = MSSA->getMemoryAccess(SI);
for (auto *BB : CurLoop->getBlocks())
if (auto *Accesses = MSSA->getBlockAccesses(BB)) {
for (const auto &MA : *Accesses)
if (const auto *MU = dyn_cast<MemoryUse>(&MA)) {
- auto *MD = MU->getDefiningAccess();
+ auto *MD = getClobberingMemoryAccess(*MSSA, BAA, Flags,
+ const_cast<MemoryUse *>(MU));
if (!MSSA->isLiveOnEntryDef(MD) &&
CurLoop->contains(MD->getBlock()))
return false;
@@ -1293,17 +1315,13 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
// Check if the call may read from the memory location written
// to by SI. Check CI's attributes and arguments; the number of
// such checks performed is limited above by NoOfMemAccTooLarge.
- ModRefInfo MRI = AA->getModRefInfo(CI, MemoryLocation::get(SI));
+ ModRefInfo MRI = BAA.getModRefInfo(CI, MemoryLocation::get(SI));
if (isModOrRefSet(MRI))
return false;
}
}
}
- auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI);
- Flags.incrementClobberingCalls();
- // If there are no clobbering Defs in the loop, store is safe to hoist.
- return MSSA->isLiveOnEntryDef(Source) ||
- !CurLoop->contains(Source->getBlock());
+ return true;
}
assert(!I.mayReadOrWriteMemory() && "unhandled aliasing");
@@ -1326,13 +1344,12 @@ static bool isTriviallyReplaceablePHI(const PHINode &PN, const Instruction &I) {
return true;
}
-/// Return true if the instruction is free in the loop.
-static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,
+/// Return true if the instruction is foldable in the loop.
+static bool isFoldableInLoop(const Instruction &I, const Loop *CurLoop,
const TargetTransformInfo *TTI) {
- InstructionCost CostI =
- TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
-
if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
+ InstructionCost CostI =
+ TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
if (CostI != TargetTransformInfo::TCC_Free)
return false;
// For a GEP, we cannot simply use getInstructionCost because currently
@@ -1349,7 +1366,7 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,
return true;
}
- return CostI == TargetTransformInfo::TCC_Free;
+ return false;
}
/// Return true if the only users of this instruction are outside of
@@ -1358,12 +1375,12 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,
///
/// We also return true if the instruction could be folded away in lowering.
/// (e.g., a GEP can be folded into a load as an addressing mode in the loop).
-static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop,
- const LoopSafetyInfo *SafetyInfo,
- TargetTransformInfo *TTI, bool &FreeInLoop,
- bool LoopNestMode) {
+static bool isNotUsedOrFoldableInLoop(const Instruction &I, const Loop *CurLoop,
+ const LoopSafetyInfo *SafetyInfo,
+ TargetTransformInfo *TTI,
+ bool &FoldableInLoop, bool LoopNestMode) {
const auto &BlockColors = SafetyInfo->getBlockColors();
- bool IsFree = isFreeInLoop(I, CurLoop, TTI);
+ bool IsFoldable = isFoldableInLoop(I, CurLoop, TTI);
for (const User *U : I.users()) {
const Instruction *UI = cast<Instruction>(U);
if (const PHINode *PN = dyn_cast<PHINode>(UI)) {
@@ -1390,8 +1407,8 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop,
}
if (CurLoop->contains(UI)) {
- if (IsFree) {
- FreeInLoop = true;
+ if (IsFoldable) {
+ FoldableInLoop = true;
continue;
}
return false;
@@ -1490,7 +1507,7 @@ static void moveInstructionBefore(Instruction &I, Instruction &Dest,
MSSAU.getMemorySSA()->getMemoryAccess(&I)))
MSSAU.moveToPlace(OldMemAcc, Dest.getParent(), MemorySSA::BeforeTerminator);
if (SE)
- SE->forgetValue(&I);
+ SE->forgetBlockAndLoopDispositions(&I);
}
static Instruction *sinkThroughTriviallyReplaceablePHI(
@@ -1695,6 +1712,8 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT,
// The PHI must be trivially replaceable.
Instruction *New = sinkThroughTriviallyReplaceablePHI(
PN, &I, LI, SunkCopies, SafetyInfo, CurLoop, MSSAU);
+ // As we sink the instruction out of the BB, drop its debug location.
+ New->dropLocation();
PN->replaceAllUsesWith(New);
eraseInstruction(*PN, *SafetyInfo, MSSAU);
Changed = true;
@@ -1729,7 +1748,7 @@ static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop,
// time in isGuaranteedToExecute if we don't actually have anything to
// drop. It is a compile time optimization, not required for correctness.
!SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop))
- I.dropUndefImplyingAttrsAndUnknownMetadata();
+ I.dropUBImplyingAttrsAndMetadata();
if (isa<PHINode>(I))
// Move the new node to the end of the phi list in the destination block.
@@ -1915,6 +1934,8 @@ bool isNotVisibleOnUnwindInLoop(const Value *Object, const Loop *L,
isNotCapturedBeforeOrInLoop(Object, L, DT);
}
+// We don't consider globals as writable: while the physical memory is
+// writable, we may not have the provenance to perform the write.
bool isWritableObject(const Value *Object) {
// TODO: Alloca might not be writable after its lifetime ends.
// See https://github.com/llvm/llvm-project/issues/51838.
@@ -1925,9 +1946,6 @@ bool isWritableObject(const Value *Object) {
if (auto *A = dyn_cast<Argument>(Object))
return A->hasByValAttr();
- if (auto *G = dyn_cast<GlobalVariable>(Object))
- return !G->isConstant();
-
// TODO: Noalias has nothing to do with writability, this should check for
// an allocator function.
return isNoAliasCall(Object);
@@ -2203,7 +2221,7 @@ bool llvm::promoteLoopAccessesToScalars(
});
// Look at all the loop uses, and try to merge their locations.
- std::vector<const DILocation *> LoopUsesLocs;
+ std::vector<DILocation *> LoopUsesLocs;
for (auto *U : LoopUses)
LoopUsesLocs.push_back(U->getDebugLoc().get());
auto DL = DebugLoc(DILocation::getMergedLocations(LoopUsesLocs));
@@ -2330,19 +2348,24 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) {
static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU,
Loop *CurLoop, Instruction &I,
- SinkAndHoistLICMFlags &Flags) {
+ SinkAndHoistLICMFlags &Flags,
+ bool InvariantGroup) {
// For hoisting, use the walker to determine safety
if (!Flags.getIsSink()) {
- MemoryAccess *Source;
- // See declaration of SetLicmMssaOptCap for usage details.
- if (Flags.tooManyClobberingCalls())
- Source = MU->getDefiningAccess();
- else {
- Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(MU);
- Flags.incrementClobberingCalls();
- }
+ // If hoisting an invariant group load, we only need to check that there
+ // is no store to the loaded pointer between the start of the loop and
+ // the load (since all loaded values must be the same).
+
+ // That holds in either of two cases:
+ // 1) the clobbering memory access is outside the loop, or
+ // 2) the clobbering access is the MemoryPhi in the loop header.
+
+ BatchAAResults BAA(MSSA->getAA());
+ MemoryAccess *Source = getClobberingMemoryAccess(*MSSA, BAA, Flags, MU);
return !MSSA->isLiveOnEntryDef(Source) &&
- CurLoop->contains(Source->getBlock());
+ CurLoop->contains(Source->getBlock()) &&
+           !(InvariantGroup && Source->getBlock() == CurLoop->getHeader() &&
+             isa<MemoryPhi>(Source));
}
// For sinking, we'd need to check all Defs below this use. The getClobbering
@@ -2383,6 +2406,304 @@ bool pointerInvalidatedByBlock(BasicBlock &BB, MemorySSA &MSSA, MemoryUse &MU) {
return false;
}
+/// Try to simplify things like (A < INV_1 AND A < INV_2) into
+/// (A < min(INV_1, INV_2)), if INV_1 and INV_2 are both loop invariants whose
+/// minimum can be computed outside of the loop, and A is not loop-invariant.
+static bool hoistMinMax(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
+ MemorySSAUpdater &MSSAU) {
+ bool Inverse = false;
+ using namespace PatternMatch;
+ Value *Cond1, *Cond2;
+ if (match(&I, m_LogicalOr(m_Value(Cond1), m_Value(Cond2)))) {
+ Inverse = true;
+ } else if (match(&I, m_LogicalAnd(m_Value(Cond1), m_Value(Cond2)))) {
+ // Do nothing
+ } else
+ return false;
+
+ auto MatchICmpAgainstInvariant = [&](Value *C, ICmpInst::Predicate &P,
+ Value *&LHS, Value *&RHS) {
+ if (!match(C, m_OneUse(m_ICmp(P, m_Value(LHS), m_Value(RHS)))))
+ return false;
+ if (!LHS->getType()->isIntegerTy())
+ return false;
+ if (!ICmpInst::isRelational(P))
+ return false;
+ if (L.isLoopInvariant(LHS)) {
+ std::swap(LHS, RHS);
+ P = ICmpInst::getSwappedPredicate(P);
+ }
+ if (L.isLoopInvariant(LHS) || !L.isLoopInvariant(RHS))
+ return false;
+ if (Inverse)
+ P = ICmpInst::getInversePredicate(P);
+ return true;
+ };
+ ICmpInst::Predicate P1, P2;
+ Value *LHS1, *LHS2, *RHS1, *RHS2;
+ if (!MatchICmpAgainstInvariant(Cond1, P1, LHS1, RHS1) ||
+ !MatchICmpAgainstInvariant(Cond2, P2, LHS2, RHS2))
+ return false;
+ if (P1 != P2 || LHS1 != LHS2)
+ return false;
+
+ // Everything is fine, we can do the transform.
+ bool UseMin = ICmpInst::isLT(P1) || ICmpInst::isLE(P1);
+ assert(
+ (UseMin || ICmpInst::isGT(P1) || ICmpInst::isGE(P1)) &&
+ "Relational predicate is either less (or equal) or greater (or equal)!");
+ Intrinsic::ID id = ICmpInst::isSigned(P1)
+ ? (UseMin ? Intrinsic::smin : Intrinsic::smax)
+ : (UseMin ? Intrinsic::umin : Intrinsic::umax);
+ auto *Preheader = L.getLoopPreheader();
+ assert(Preheader && "Loop is not in simplify form?");
+ IRBuilder<> Builder(Preheader->getTerminator());
+ // We are about to create a new guaranteed use for RHS2, which might not have
+ // existed before (if it was a non-taken input of a logical and/or
+ // instruction). If it was poison, we need to freeze it. Note that no new
+ // uses of LHS and RHS1 are introduced, so they don't need this.
+ if (isa<SelectInst>(I))
+ RHS2 = Builder.CreateFreeze(RHS2, RHS2->getName() + ".fr");
+ Value *NewRHS = Builder.CreateBinaryIntrinsic(
+ id, RHS1, RHS2, nullptr, StringRef("invariant.") +
+ (ICmpInst::isSigned(P1) ? "s" : "u") +
+ (UseMin ? "min" : "max"));
+ Builder.SetInsertPoint(&I);
+ ICmpInst::Predicate P = P1;
+ if (Inverse)
+ P = ICmpInst::getInversePredicate(P);
+ Value *NewCond = Builder.CreateICmp(P, LHS1, NewRHS);
+ NewCond->takeName(&I);
+ I.replaceAllUsesWith(NewCond);
+ eraseInstruction(I, SafetyInfo, MSSAU);
+ eraseInstruction(*cast<Instruction>(Cond1), SafetyInfo, MSSAU);
+ eraseInstruction(*cast<Instruction>(Cond2), SafetyInfo, MSSAU);
+ return true;
+}
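+// Editorial illustration (not from the upstream patch): at the source level,
+// assuming `a` and `b` are loop-invariant and `x` varies per iteration, the
+// transform above rewrites
+//   if (x < a && x < b) ...        // two compares inside the loop
+// into
+//   t = min(a, b);                 // computed once in the preheader
+//   if (x < t) ...                 // single compare inside the loop
+// using smin/smax or umin/umax depending on the signedness of the predicate.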
+
+/// Reassociate gep (gep ptr, idx1), idx2 to gep (gep ptr, idx2), idx1 if
+/// this allows hoisting the inner GEP.
+static bool hoistGEP(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
+ MemorySSAUpdater &MSSAU, AssumptionCache *AC,
+ DominatorTree *DT) {
+ auto *GEP = dyn_cast<GetElementPtrInst>(&I);
+ if (!GEP)
+ return false;
+
+ auto *Src = dyn_cast<GetElementPtrInst>(GEP->getPointerOperand());
+ if (!Src || !Src->hasOneUse() || !L.contains(Src))
+ return false;
+
+ Value *SrcPtr = Src->getPointerOperand();
+ auto LoopInvariant = [&](Value *V) { return L.isLoopInvariant(V); };
+ if (!L.isLoopInvariant(SrcPtr) || !all_of(GEP->indices(), LoopInvariant))
+ return false;
+
+ // This can only happen if !AllowSpeculation, otherwise this would already be
+ // handled.
+ // FIXME: Should we respect AllowSpeculation in these reassociation folds?
+ // The flag exists to prevent metadata dropping, which is not relevant here.
+ if (all_of(Src->indices(), LoopInvariant))
+ return false;
+
+ // The swapped GEPs are inbounds if both original GEPs are inbounds
+ // and the sign of the offsets is the same. For simplicity, only
+ // handle both offsets being non-negative.
+ const DataLayout &DL = GEP->getModule()->getDataLayout();
+ auto NonNegative = [&](Value *V) {
+ return isKnownNonNegative(V, DL, 0, AC, GEP, DT);
+ };
+ bool IsInBounds = Src->isInBounds() && GEP->isInBounds() &&
+ all_of(Src->indices(), NonNegative) &&
+ all_of(GEP->indices(), NonNegative);
+
+ BasicBlock *Preheader = L.getLoopPreheader();
+ IRBuilder<> Builder(Preheader->getTerminator());
+ Value *NewSrc = Builder.CreateGEP(GEP->getSourceElementType(), SrcPtr,
+ SmallVector<Value *>(GEP->indices()),
+ "invariant.gep", IsInBounds);
+ Builder.SetInsertPoint(GEP);
+ Value *NewGEP = Builder.CreateGEP(Src->getSourceElementType(), NewSrc,
+ SmallVector<Value *>(Src->indices()), "gep",
+ IsInBounds);
+ GEP->replaceAllUsesWith(NewGEP);
+ eraseInstruction(*GEP, SafetyInfo, MSSAU);
+ eraseInstruction(*Src, SafetyInfo, MSSAU);
+ return true;
+}
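+// Editorial illustration (not from the upstream patch): with a loop-invariant
+// base pointer `p`, a loop-invariant index `inv`, and a loop-variant index
+// `i`, the reassociation above rewrites
+//   q = gep (gep p, i), inv        // both GEPs inside the loop
+// into
+//   r = gep p, inv                 // hoisted to the preheader
+//   q = gep r, i                   // only this GEP remains in the loop
+// and keeps the inbounds flag only when both offsets are known non-negative.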
+
+/// Try to turn things like "LV + C1 < C2" into "LV < C2 - C1". Here
+/// C1 and C2 are loop-invariant and LV is loop-variant.
+static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS,
+ Value *InvariantRHS, ICmpInst &ICmp, Loop &L,
+ ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU,
+ AssumptionCache *AC, DominatorTree *DT) {
+ assert(ICmpInst::isSigned(Pred) && "Not supported yet!");
+ assert(!L.isLoopInvariant(VariantLHS) && "Precondition.");
+ assert(L.isLoopInvariant(InvariantRHS) && "Precondition.");
+
+ // Try to represent VariantLHS as a sum of invariant and variant operands.
+ using namespace PatternMatch;
+ Value *VariantOp, *InvariantOp;
+ if (!match(VariantLHS, m_NSWAdd(m_Value(VariantOp), m_Value(InvariantOp))))
+ return false;
+
+ // LHS itself is loop-variant; try to represent it in the form
+ // "VariantOp + InvariantOp". If that is possible, we can reassociate.
+ if (L.isLoopInvariant(VariantOp))
+ std::swap(VariantOp, InvariantOp);
+ if (L.isLoopInvariant(VariantOp) || !L.isLoopInvariant(InvariantOp))
+ return false;
+
+ // In order to turn "LV + C1 < C2" into "LV < C2 - C1", we need to be able to
+ // freely move values from the left side of the inequality to the right side
+ // (just as in normal linear arithmetic). Overflows make things much more
+ // complicated, so we want to avoid this.
+ auto &DL = L.getHeader()->getModule()->getDataLayout();
+ bool ProvedNoOverflowAfterReassociate =
+ computeOverflowForSignedSub(InvariantRHS, InvariantOp, DL, AC, &ICmp,
+ DT) == llvm::OverflowResult::NeverOverflows;
+ if (!ProvedNoOverflowAfterReassociate)
+ return false;
+ auto *Preheader = L.getLoopPreheader();
+ assert(Preheader && "Loop is not in simplify form?");
+ IRBuilder<> Builder(Preheader->getTerminator());
+ Value *NewCmpOp = Builder.CreateSub(InvariantRHS, InvariantOp, "invariant.op",
+ /*HasNUW*/ false, /*HasNSW*/ true);
+ ICmp.setPredicate(Pred);
+ ICmp.setOperand(0, VariantOp);
+ ICmp.setOperand(1, NewCmpOp);
+ eraseInstruction(cast<Instruction>(*VariantLHS), SafetyInfo, MSSAU);
+ return true;
+}
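+// Editorial illustration (not from the upstream patch): for a loop-variant
+// `x` and loop-invariant `c1`, `c2`, once computeOverflowForSignedSub proves
+// that `c2 - c1` cannot overflow, the rewrite above turns
+//   x + c1 < c2                    // add recomputed every iteration
+// into
+//   t = c2 - c1;                   // computed once in the preheader
+//   x < t                          // only the compare stays in the loop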
+
+/// Try to reassociate and hoist the following two patterns:
+/// LV - C1 < C2 --> LV < C1 + C2,
+/// C1 - LV < C2 --> LV > C1 - C2.
+static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS,
+ Value *InvariantRHS, ICmpInst &ICmp, Loop &L,
+ ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU,
+ AssumptionCache *AC, DominatorTree *DT) {
+ assert(ICmpInst::isSigned(Pred) && "Not supported yet!");
+ assert(!L.isLoopInvariant(VariantLHS) && "Precondition.");
+ assert(L.isLoopInvariant(InvariantRHS) && "Precondition.");
+
+ // Try to represent VariantLHS as a difference of variant and invariant operands.
+ using namespace PatternMatch;
+ Value *VariantOp, *InvariantOp;
+ if (!match(VariantLHS, m_NSWSub(m_Value(VariantOp), m_Value(InvariantOp))))
+ return false;
+
+ bool VariantSubtracted = false;
+ // LHS itself is loop-variant; try to represent it in the form
+ // "VariantOp - InvariantOp". If that is possible, we can reassociate. If the
+ // variant operand goes with a minus sign, we use a slightly different scheme.
+ if (L.isLoopInvariant(VariantOp)) {
+ std::swap(VariantOp, InvariantOp);
+ VariantSubtracted = true;
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+ if (L.isLoopInvariant(VariantOp) || !L.isLoopInvariant(InvariantOp))
+ return false;
+
+ // In order to turn "LV - C1 < C2" into "LV < C2 + C1", we need to be able to
+ // freely move values from the left side of the inequality to the right side
+ // (just as in normal linear arithmetic). Overflows make things much more
+ // complicated, so we want to avoid this. Likewise, for "C1 - LV < C2" we
+ // need to prove that "C1 - C2" does not overflow.
+ auto &DL = L.getHeader()->getModule()->getDataLayout();
+ if (VariantSubtracted) {
+ // C1 - LV < C2 --> LV > C1 - C2
+ if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, DL, AC, &ICmp,
+ DT) != llvm::OverflowResult::NeverOverflows)
+ return false;
+ } else {
+ // LV - C1 < C2 --> LV < C1 + C2
+ if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, DL, AC, &ICmp,
+ DT) != llvm::OverflowResult::NeverOverflows)
+ return false;
+ }
+ auto *Preheader = L.getLoopPreheader();
+ assert(Preheader && "Loop is not in simplify form?");
+ IRBuilder<> Builder(Preheader->getTerminator());
+ Value *NewCmpOp =
+ VariantSubtracted
+ ? Builder.CreateSub(InvariantOp, InvariantRHS, "invariant.op",
+ /*HasNUW*/ false, /*HasNSW*/ true)
+ : Builder.CreateAdd(InvariantOp, InvariantRHS, "invariant.op",
+ /*HasNUW*/ false, /*HasNSW*/ true);
+ ICmp.setPredicate(Pred);
+ ICmp.setOperand(0, VariantOp);
+ ICmp.setOperand(1, NewCmpOp);
+ eraseInstruction(cast<Instruction>(*VariantLHS), SafetyInfo, MSSAU);
+ return true;
+}
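+// Editorial illustration (not from the upstream patch): with the same
+// invariance requirements and the no-overflow checks above, the two patterns
+// become
+//   x - c1 < c2   ->   t = c1 + c2 (preheader);   x < t
+//   c1 - x < c2   ->   t = c1 - c2 (preheader);   x > t
+// so only a single compare against a hoisted value remains in the loop.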
+
+/// Reassociate and hoist add/sub expressions.
+static bool hoistAddSub(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
+ MemorySSAUpdater &MSSAU, AssumptionCache *AC,
+ DominatorTree *DT) {
+ using namespace PatternMatch;
+ ICmpInst::Predicate Pred;
+ Value *LHS, *RHS;
+ if (!match(&I, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
+ return false;
+
+ // TODO: Support unsigned predicates?
+ if (!ICmpInst::isSigned(Pred))
+ return false;
+
+ // Put variant operand to LHS position.
+ if (L.isLoopInvariant(LHS)) {
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+ // We want to delete the initial operation after reassociation, so only do it
+ // if it has no other uses.
+ if (L.isLoopInvariant(LHS) || !L.isLoopInvariant(RHS) || !LHS->hasOneUse())
+ return false;
+
+ // TODO: We could use a smarter context, taking the common dominator of all
+ // of I's users instead of I itself.
+ if (hoistAdd(Pred, LHS, RHS, cast<ICmpInst>(I), L, SafetyInfo, MSSAU, AC, DT))
+ return true;
+
+ if (hoistSub(Pred, LHS, RHS, cast<ICmpInst>(I), L, SafetyInfo, MSSAU, AC, DT))
+ return true;
+
+ return false;
+}
+
+static bool hoistArithmetics(Instruction &I, Loop &L,
+ ICFLoopSafetyInfo &SafetyInfo,
+ MemorySSAUpdater &MSSAU, AssumptionCache *AC,
+ DominatorTree *DT) {
+ // Optimize complex patterns, such as (x < INV1 && x < INV2), turning them
+ // into (x < min(INV1, INV2)), and hoisting the invariant part of this
+ // expression out of the loop.
+ if (hoistMinMax(I, L, SafetyInfo, MSSAU)) {
+ ++NumHoisted;
+ ++NumMinMaxHoisted;
+ return true;
+ }
+
+ // Try to hoist GEPs by reassociation.
+ if (hoistGEP(I, L, SafetyInfo, MSSAU, AC, DT)) {
+ ++NumHoisted;
+ ++NumGEPsHoisted;
+ return true;
+ }
+
+ // Try to hoist add/sub's by reassociation.
+ if (hoistAddSub(I, L, SafetyInfo, MSSAU, AC, DT)) {
+ ++NumHoisted;
+ ++NumAddSubHoisted;
+ return true;
+ }
+
+ return false;
+}
+
/// Little predicate that returns true if the specified basic block is in
/// a subloop of the current one, not the current one itself.
///
diff --git a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
index 7e4dbace043a..c041e3621a16 100644
--- a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp
@@ -26,8 +26,6 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -73,7 +71,7 @@ static bool isLoopDead(Loop *L, ScalarEvolution &SE,
// of the loop.
bool AllEntriesInvariant = true;
bool AllOutgoingValuesSame = true;
- if (!L->hasNoExitBlocks()) {
+ if (ExitBlock) {
for (PHINode &P : ExitBlock->phis()) {
Value *incoming = P.getIncomingValueForBlock(ExitingBlocks[0]);
@@ -488,6 +486,14 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT,
LLVM_DEBUG(dbgs() << "Deletion requires at most one exit block.\n");
return LoopDeletionResult::Unmodified;
}
+
+ // We can't directly branch to an EH pad. Don't bother handling this edge
+ // case.
+ if (ExitBlock && ExitBlock->isEHPad()) {
+ LLVM_DEBUG(dbgs() << "Cannot delete loop exiting to EH pad.\n");
+ return LoopDeletionResult::Unmodified;
+ }
+
// Finally, we have to check that the loop really is dead.
bool Changed = false;
if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader, LI)) {
@@ -539,62 +545,3 @@ PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM,
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-namespace {
-class LoopDeletionLegacyPass : public LoopPass {
-public:
- static char ID; // Pass ID, replacement for typeid
- LoopDeletionLegacyPass() : LoopPass(ID) {
- initializeLoopDeletionLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- // Possibly eliminate loop L if it is dead.
- bool runOnLoop(Loop *L, LPPassManager &) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addPreserved<MemorySSAWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
-};
-}
-
-char LoopDeletionLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LoopDeletionLegacyPass, "loop-deletion",
- "Delete dead loops", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_END(LoopDeletionLegacyPass, "loop-deletion",
- "Delete dead loops", false, false)
-
-Pass *llvm::createLoopDeletionPass() { return new LoopDeletionLegacyPass(); }
-
-bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
- if (skipLoop(L))
- return false;
- DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
- MemorySSA *MSSA = nullptr;
- if (MSSAAnalysis)
- MSSA = &MSSAAnalysis->getMSSA();
- // For the old PM, we can't use OptimizationRemarkEmitter as an analysis
- // pass. Function analyses need to be preserved across loop transformations
- // but ORE cannot be preserved (see comment before the pass definition).
- OptimizationRemarkEmitter ORE(L->getHeader()->getParent());
-
- LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: ");
- LLVM_DEBUG(L->dump());
-
- LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI, MSSA, ORE);
-
- // If we can prove the backedge isn't taken, just break it and be done. This
- // leaves the loop structure in place which means it can handle dispatching
- // to the right exit based on whatever loop invariant structure remains.
- if (Result != LoopDeletionResult::Deleted)
- Result = merge(Result, breakBackedgeIfNotTaken(L, DT, SE, LI, MSSA, ORE));
-
- if (Result == LoopDeletionResult::Deleted)
- LPM.markLoopAsDeleted(*L);
-
- return Result != LoopDeletionResult::Unmodified;
-}
diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
index 7b52b7dca85f..27196e46ca56 100644
--- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
@@ -52,13 +52,10 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -995,45 +992,6 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT,
return Changed;
}
-namespace {
-
-/// The pass class.
-class LoopDistributeLegacy : public FunctionPass {
-public:
- static char ID;
-
- LoopDistributeLegacy() : FunctionPass(ID) {
- // The default is set by the caller.
- initializeLoopDistributeLegacyPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs();
-
- return runImpl(F, LI, DT, SE, ORE, LAIs);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addRequired<LoopAccessLegacyAnalysis>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
PreservedAnalyses LoopDistributePass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &LI = AM.getResult<LoopAnalysis>(F);
@@ -1050,18 +1008,3 @@ PreservedAnalyses LoopDistributePass::run(Function &F,
PA.preserve<DominatorTreeAnalysis>();
return PA;
}
-
-char LoopDistributeLegacy::ID;
-
-static const char ldist_name[] = "Loop Distribution";
-
-INITIALIZE_PASS_BEGIN(LoopDistributeLegacy, LDIST_NAME, ldist_name, false,
- false)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_END(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, false)
-
-FunctionPass *llvm::createLoopDistributePass() { return new LoopDistributeLegacy(); }
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 7d9ce8d35e0b..edc8a4956dd1 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -65,11 +65,8 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -318,12 +315,12 @@ static bool verifyTripCount(Value *RHS, Loop *L,
return false;
}
- // The Extend=false flag is used for getTripCountFromExitCount as we want
- // to verify and match it with the pattern matched tripcount. Please note
- // that overflow checks are performed in checkOverflow, but are first tried
- // to avoid by widening the IV.
+ // Evaluating in the trip count's type cannot overflow here: the overflow
+ // checks are performed in checkOverflow, after first trying to avoid them
+ // by widening the IV.
const SCEV *SCEVTripCount =
- SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
+ SE->getTripCountFromExitCount(BackedgeTakenCount,
+ BackedgeTakenCount->getType(), L);
const SCEV *SCEVRHS = SE->getSCEV(RHS);
if (SCEVRHS == SCEVTripCount)
@@ -336,7 +333,8 @@ static bool verifyTripCount(Value *RHS, Loop *L,
// Find the extended backedge taken count and extended trip count using
// SCEV. One of these should now match the RHS of the compare.
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
- SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
+ SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt,
+ RHS->getType(), L);
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
@@ -918,20 +916,6 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
}
-bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
- AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U,
- MemorySSAUpdater *MSSAU) {
- bool Changed = false;
- for (Loop *InnerLoop : LN.getLoops()) {
- auto *OuterLoop = InnerLoop->getParentLoop();
- if (!OuterLoop)
- continue;
- FlattenInfo FI(OuterLoop, InnerLoop);
- Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
- }
- return Changed;
-}
-
PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
@@ -949,8 +933,14 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
// in simplified form, and also needs LCSSA. Running
// this pass will simplify all loops that contain inner loops,
// regardless of whether anything ends up being flattened.
- Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
- MSSAU ? &*MSSAU : nullptr);
+ for (Loop *InnerLoop : LN.getLoops()) {
+ auto *OuterLoop = InnerLoop->getParentLoop();
+ if (!OuterLoop)
+ continue;
+ FlattenInfo FI(OuterLoop, InnerLoop);
+ Changed |= FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
+ MSSAU ? &*MSSAU : nullptr);
+ }
if (!Changed)
return PreservedAnalyses::all();
@@ -963,60 +953,3 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-namespace {
-class LoopFlattenLegacyPass : public FunctionPass {
-public:
- static char ID; // Pass ID, replacement for typeid
- LoopFlattenLegacyPass() : FunctionPass(ID) {
- initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- // Possibly flatten loop L into its child.
- bool runOnFunction(Function &F) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- getLoopAnalysisUsage(AU);
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addPreserved<TargetTransformInfoWrapperPass>();
- AU.addRequired<AssumptionCacheTracker>();
- AU.addPreserved<AssumptionCacheTracker>();
- AU.addPreserved<MemorySSAWrapperPass>();
- }
-};
-} // namespace
-
-char LoopFlattenLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
- false, false)
-
-FunctionPass *llvm::createLoopFlattenPass() {
- return new LoopFlattenLegacyPass();
-}
-
-bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
- ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
- auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
- auto *TTI = &TTIP.getTTI(F);
- auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>();
-
- std::optional<MemorySSAUpdater> MSSAU;
- if (MSSA)
- MSSAU = MemorySSAUpdater(&MSSA->getMSSA());
-
- bool Changed = false;
- for (Loop *L : *LI) {
- auto LN = LoopNest::getLoopNest(*L, *SE);
- Changed |=
- Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, MSSAU ? &*MSSAU : nullptr);
- }
- return Changed;
-}
diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
index 0eecec373736..d35b562be0aa 100644
--- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
@@ -57,12 +57,9 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Verifier.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CodeMoverUtils.h"
@@ -2061,51 +2058,6 @@ private:
return FC0.L;
}
};
-
-struct LoopFuseLegacy : public FunctionPass {
-
- static char ID;
-
- LoopFuseLegacy() : FunctionPass(ID) {
- initializeLoopFuseLegacyPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequiredID(LoopSimplifyID);
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<PostDominatorTreeWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<DependenceAnalysisWrapperPass>();
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
-
- AU.addPreserved<ScalarEvolutionWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<PostDominatorTreeWrapperPass>();
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI();
- auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
- auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- const TargetTransformInfo &TTI =
- getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- const DataLayout &DL = F.getParent()->getDataLayout();
-
- LoopFuser LF(LI, DT, DI, SE, PDT, ORE, DL, AC, TTI);
- return LF.fuseLoops(F);
- }
-};
} // namespace
PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) {
@@ -2142,19 +2094,3 @@ PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserve<LoopAnalysis>();
return PA;
}
-
-char LoopFuseLegacy::ID = 0;
-
-INITIALIZE_PASS_BEGIN(LoopFuseLegacy, "loop-fusion", "Loop Fusion", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(LoopFuseLegacy, "loop-fusion", "Loop Fusion", false, false)
-
-FunctionPass *llvm::createLoopFusePass() { return new LoopFuseLegacy(); }
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 035cbdf595a8..8572a442e784 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -84,14 +84,11 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -254,62 +251,8 @@ private:
/// @}
};
-
-class LoopIdiomRecognizeLegacyPass : public LoopPass {
-public:
- static char ID;
-
- explicit LoopIdiomRecognizeLegacyPass() : LoopPass(ID) {
- initializeLoopIdiomRecognizeLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (DisableLIRP::All)
- return false;
-
- if (skipLoop(L))
- return false;
-
- AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
- *L->getHeader()->getParent());
- const TargetTransformInfo *TTI =
- &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
- *L->getHeader()->getParent());
- const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout();
- auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
- MemorySSA *MSSA = nullptr;
- if (MSSAAnalysis)
- MSSA = &MSSAAnalysis->getMSSA();
-
- // For the old PM, we can't use OptimizationRemarkEmitter as an analysis
- // pass. Function analyses need to be preserved across loop transformations
- // but ORE cannot be preserved (see comment before the pass definition).
- OptimizationRemarkEmitter ORE(L->getHeader()->getParent());
-
- LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, MSSA, DL, ORE);
- return LIR.runOnLoop(L);
- }
-
- /// This transformation requires natural loop information & requires that
- /// loop preheaders be inserted into the CFG.
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
-};
-
} // end anonymous namespace
-char LoopIdiomRecognizeLegacyPass::ID = 0;
-
PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &) {
@@ -334,16 +277,6 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
return PA;
}
-INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom",
- "Recognize loop idioms", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(LoopIdiomRecognizeLegacyPass, "loop-idiom",
- "Recognize loop idioms", false, false)
-
-Pass *llvm::createLoopIdiomPass() { return new LoopIdiomRecognizeLegacyPass(); }
-
static void deleteDeadInstruction(Instruction *I) {
I->replaceAllUsesWith(PoisonValue::get(I->getType()));
I->eraseFromParent();
@@ -1050,33 +983,6 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
return SE->getMinusSCEV(Start, Index);
}
-/// Compute trip count from the backedge taken count.
-static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
- Loop *CurLoop, const DataLayout *DL,
- ScalarEvolution *SE) {
- const SCEV *TripCountS = nullptr;
- // The # stored bytes is (BECount+1). Expand the trip count out to
- // pointer size if it isn't already.
- //
- // If we're going to need to zero extend the BE count, check if we can add
- // one to it prior to zero extending without overflow. Provided this is safe,
- // it allows better simplification of the +1.
- if (DL->getTypeSizeInBits(BECount->getType()) <
- DL->getTypeSizeInBits(IntPtr) &&
- SE->isLoopEntryGuardedByCond(
- CurLoop, ICmpInst::ICMP_NE, BECount,
- SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
- TripCountS = SE->getZeroExtendExpr(
- SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
- IntPtr);
- } else {
- TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
- SE->getOne(IntPtr), SCEV::FlagNUW);
- }
-
- return TripCountS;
-}
-
/// Compute the number of bytes as a SCEV from the backedge taken count.
///
/// This also maps the SCEV into the provided type and tries to handle the
@@ -1084,8 +990,8 @@ static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
const SCEV *StoreSizeSCEV, Loop *CurLoop,
const DataLayout *DL, ScalarEvolution *SE) {
- const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
-
+ const SCEV *TripCountSCEV =
+ SE->getTripCountFromExitCount(BECount, IntPtr, CurLoop);
return SE->getMulExpr(TripCountSCEV,
SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
SCEV::FlagNUW);
@@ -1168,20 +1074,24 @@ bool LoopIdiomRecognize::processLoopStridedStore(
Value *NumBytes =
Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
+ if (!SplatValue && !isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16))
+ return Changed;
+
+ AAMDNodes AATags = TheStore->getAAMetadata();
+ for (Instruction *Store : Stores)
+ AATags = AATags.merge(Store->getAAMetadata());
+ if (auto CI = dyn_cast<ConstantInt>(NumBytes))
+ AATags = AATags.extendTo(CI->getZExtValue());
+ else
+ AATags = AATags.extendTo(-1);
+
CallInst *NewCall;
if (SplatValue) {
- AAMDNodes AATags = TheStore->getAAMetadata();
- for (Instruction *Store : Stores)
- AATags = AATags.merge(Store->getAAMetadata());
- if (auto CI = dyn_cast<ConstantInt>(NumBytes))
- AATags = AATags.extendTo(CI->getZExtValue());
- else
- AATags = AATags.extendTo(-1);
-
NewCall = Builder.CreateMemSet(
BasePtr, SplatValue, NumBytes, MaybeAlign(StoreAlignment),
/*isVolatile=*/false, AATags.TBAA, AATags.Scope, AATags.NoAlias);
- } else if (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16)) {
+ } else {
+ assert(isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16));
// Everything is emitted in default address space
Type *Int8PtrTy = DestInt8PtrTy;
@@ -1199,8 +1109,17 @@ bool LoopIdiomRecognize::processLoopStridedStore(
GV->setAlignment(Align(16));
Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy);
NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
- } else
- return Changed;
+
+ // Set the TBAA info if present.
+ if (AATags.TBAA)
+ NewCall->setMetadata(LLVMContext::MD_tbaa, AATags.TBAA);
+
+ if (AATags.Scope)
+ NewCall->setMetadata(LLVMContext::MD_alias_scope, AATags.Scope);
+
+ if (AATags.NoAlias)
+ NewCall->setMetadata(LLVMContext::MD_noalias, AATags.NoAlias);
+ }
NewCall->setDebugLoc(TheStore->getDebugLoc());
@@ -2471,7 +2390,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() {
// intrinsic/shift we'll use are not cheap. Note that we are okay with *just*
// making the loop countable, even if nothing else changes.
IntrinsicCostAttributes Attrs(
- IntrID, Ty, {UndefValue::get(Ty), /*is_zero_undef=*/Builder.getTrue()});
+ IntrID, Ty, {PoisonValue::get(Ty), /*is_zero_poison=*/Builder.getTrue()});
InstructionCost Cost = TTI->getIntrinsicInstrCost(Attrs, CostKind);
if (Cost > TargetTransformInfo::TCC_Basic) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE
@@ -2487,6 +2406,24 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() {
// Ok, transform appears worthwhile.
MadeChange = true;
+ if (!isGuaranteedNotToBeUndefOrPoison(BitPos)) {
+ // BitMask may be computed from BitPos; freeze BitPos so that we can safely
+ // increase its use count.
+ Instruction *InsertPt = nullptr;
+ if (auto *BitPosI = dyn_cast<Instruction>(BitPos))
+ InsertPt = BitPosI->getInsertionPointAfterDef();
+ else
+ InsertPt = &*DT->getRoot()->getFirstNonPHIOrDbgOrAlloca();
+ if (!InsertPt)
+ return false;
+ FreezeInst *BitPosFrozen =
+ new FreezeInst(BitPos, BitPos->getName() + ".fr", InsertPt);
+ BitPos->replaceUsesWithIf(BitPosFrozen, [BitPosFrozen](Use &U) {
+ return U.getUser() != BitPosFrozen;
+ });
+ BitPos = BitPosFrozen;
+ }
+
// Step 1: Compute the loop trip count.
Value *LowBitMask = Builder.CreateAdd(BitMask, Constant::getAllOnesValue(Ty),
@@ -2495,7 +2432,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() {
Builder.CreateOr(LowBitMask, BitMask, BitPos->getName() + ".mask");
Value *XMasked = Builder.CreateAnd(X, Mask, X->getName() + ".masked");
CallInst *XMaskedNumLeadingZeros = Builder.CreateIntrinsic(
- IntrID, Ty, {XMasked, /*is_zero_undef=*/Builder.getTrue()},
+ IntrID, Ty, {XMasked, /*is_zero_poison=*/Builder.getTrue()},
/*FMFSource=*/nullptr, XMasked->getName() + ".numleadingzeros");
Value *XMaskedNumActiveBits = Builder.CreateSub(
ConstantInt::get(Ty, Ty->getScalarSizeInBits()), XMaskedNumLeadingZeros,
@@ -2825,7 +2762,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
// intrinsic we'll use are not cheap. Note that we are okay with *just*
// making the loop countable, even if nothing else changes.
IntrinsicCostAttributes Attrs(
- IntrID, Ty, {UndefValue::get(Ty), /*is_zero_undef=*/Builder.getFalse()});
+ IntrID, Ty, {PoisonValue::get(Ty), /*is_zero_poison=*/Builder.getFalse()});
InstructionCost Cost = TTI->getIntrinsicInstrCost(Attrs, CostKind);
if (Cost > TargetTransformInfo::TCC_Basic) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE
@@ -2843,7 +2780,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
// Step 1: Compute the loop's final IV value / trip count.
CallInst *ValNumLeadingZeros = Builder.CreateIntrinsic(
- IntrID, Ty, {Val, /*is_zero_undef=*/Builder.getFalse()},
+ IntrID, Ty, {Val, /*is_zero_poison=*/Builder.getFalse()},
/*FMFSource=*/nullptr, Val->getName() + ".numleadingzeros");
Value *ValNumActiveBits = Builder.CreateSub(
ConstantInt::get(Ty, Ty->getScalarSizeInBits()), ValNumLeadingZeros,
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
index 0a7c62113c7f..91286ebcea33 100644
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -30,20 +30,16 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
-#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -187,8 +183,7 @@ static void interChangeDependencies(CharMatrix &DepMatrix, unsigned FromIndx,
// if the direction matrix, after the same permutation is applied to its
// columns, has no ">" direction as the leftmost non-"=" direction in any row.
static bool isLexicographicallyPositive(std::vector<char> &DV) {
- for (unsigned Level = 0; Level < DV.size(); ++Level) {
- unsigned char Direction = DV[Level];
+ for (unsigned char Direction : DV) {
if (Direction == '<')
return true;
if (Direction == '>' || Direction == '*')
@@ -736,7 +731,6 @@ bool LoopInterchangeLegality::findInductionAndReductions(
if (!L->getLoopLatch() || !L->getLoopPredecessor())
return false;
for (PHINode &PHI : L->getHeader()->phis()) {
- RecurrenceDescriptor RD;
InductionDescriptor ID;
if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID))
Inductions.push_back(&PHI);
@@ -1105,8 +1099,7 @@ LoopInterchangeProfitability::isProfitablePerLoopCacheAnalysis(
// This is the new cost model returned from loop cache analysis.
// A smaller index means the loop should be placed an outer loop, and vice
// versa.
- if (CostMap.find(InnerLoop) != CostMap.end() &&
- CostMap.find(OuterLoop) != CostMap.end()) {
+ if (CostMap.contains(InnerLoop) && CostMap.contains(OuterLoop)) {
unsigned InnerIndex = 0, OuterIndex = 0;
InnerIndex = CostMap.find(InnerLoop)->second;
OuterIndex = CostMap.find(OuterLoop)->second;
@@ -1692,12 +1685,11 @@ bool LoopInterchangeTransform::adjustLoopBranches() {
// latch. In that case, we need to create LCSSA phis for them, because after
// interchanging they will be defined in the new inner loop and used in the
// new outer loop.
- IRBuilder<> Builder(OuterLoopHeader->getContext());
SmallVector<Instruction *, 4> MayNeedLCSSAPhis;
for (Instruction &I :
make_range(OuterLoopHeader->begin(), std::prev(OuterLoopHeader->end())))
MayNeedLCSSAPhis.push_back(&I);
- formLCSSAForInstructions(MayNeedLCSSAPhis, *DT, *LI, SE, Builder);
+ formLCSSAForInstructions(MayNeedLCSSAPhis, *DT, *LI, SE);
return true;
}
@@ -1716,52 +1708,6 @@ bool LoopInterchangeTransform::adjustLoopLinks() {
return Changed;
}
-namespace {
-/// Main LoopInterchange Pass.
-struct LoopInterchangeLegacyPass : public LoopPass {
- static char ID;
-
- LoopInterchangeLegacyPass() : LoopPass(ID) {
- initializeLoopInterchangeLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DependenceAnalysisWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
-
- getLoopAnalysisUsage(AU);
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
-
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *DI = &getAnalysis<DependenceAnalysisWrapperPass>().getDI();
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- std::unique_ptr<CacheCost> CC = nullptr;
- return LoopInterchange(SE, LI, DI, DT, CC, ORE).run(L);
- }
-};
-} // namespace
-
-char LoopInterchangeLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(LoopInterchangeLegacyPass, "loop-interchange",
- "Interchanges loops for cache reuse", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-
-INITIALIZE_PASS_END(LoopInterchangeLegacyPass, "loop-interchange",
- "Interchanges loops for cache reuse", false, false)
-
-Pass *llvm::createLoopInterchangePass() {
- return new LoopInterchangeLegacyPass();
-}
-
PreservedAnalyses LoopInterchangePass::run(LoopNest &LN,
LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
index b615a0a0a9c0..179ccde8d035 100644
--- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -46,13 +46,10 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopVersioning.h"
@@ -91,8 +88,9 @@ struct StoreToLoadForwardingCandidate {
StoreToLoadForwardingCandidate(LoadInst *Load, StoreInst *Store)
: Load(Load), Store(Store) {}
- /// Return true if the dependence from the store to the load has a
- /// distance of one. E.g. A[i+1] = A[i]
+ /// Return true if the dependence from the store to the load has an
+ /// absolute distance of one.
+ /// E.g. A[i+1] = A[i] (or A[i-1] = A[i] for a descending loop)
bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE,
Loop *L) const {
Value *LoadPtr = Load->getPointerOperand();
@@ -106,11 +104,19 @@ struct StoreToLoadForwardingCandidate {
DL.getTypeSizeInBits(getLoadStoreType(Store)) &&
"Should be a known dependence");
- // Currently we only support accesses with unit stride. FIXME: we should be
- // able to handle non unit stirde as well as long as the stride is equal to
- // the dependence distance.
- if (getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0) != 1 ||
- getPtrStride(PSE, LoadType, StorePtr, L).value_or(0) != 1)
+ int64_t StrideLoad = getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0);
+ int64_t StrideStore = getPtrStride(PSE, LoadType, StorePtr, L).value_or(0);
+ if (!StrideLoad || !StrideStore || StrideLoad != StrideStore)
+ return false;
+
+ // TODO: This check for stride values other than 1 and -1 can be eliminated.
+ // However, doing so may cause the LoopAccessAnalysis to overcompensate,
+ // generating numerous non-wrap runtime checks that may undermine the
+ // benefits of load elimination. To safely support non-unit strides, we would
+ // need either to ensure that the processed case does not require these
+ // additional checks, or to improve the LAA to handle them more efficiently,
+ // or both.
+ if (std::abs(StrideLoad) != 1)
return false;
unsigned TypeByteSize = DL.getTypeAllocSize(const_cast<Type *>(LoadType));
@@ -123,7 +129,7 @@ struct StoreToLoadForwardingCandidate {
auto *Dist = cast<SCEVConstant>(
PSE.getSE()->getMinusSCEV(StorePtrSCEV, LoadPtrSCEV));
const APInt &Val = Dist->getAPInt();
- return Val == TypeByteSize;
+ return Val == TypeByteSize * StrideLoad;
}
Value *getLoadPtr() const { return Load->getPointerOperand(); }
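
The net effect of this hunk is that descending unit-stride loops now qualify for store-to-load forwarding as well. A sketch of the two shapes accepted after the change (plain C++, illustrative only):

// Ascending, stride +1: distance between store and load is +sizeof(int).
void fwd_up(int *A, int N) {
  for (int i = 0; i < N - 1; ++i)
    A[i + 1] = A[i] + 1;
}

// Descending, stride -1: same absolute distance of one element, previously
// rejected because only a stride of +1 was accepted.
void fwd_down(int *A, int N) {
  for (int i = N - 1; i > 0; --i)
    A[i - 1] = A[i] + 1;
}
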
@@ -658,70 +664,6 @@ static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI,
return Changed;
}
-namespace {
-
-/// The pass. Most of the work is delegated to the per-loop
-/// LoadEliminationForLoop class.
-class LoopLoadElimination : public FunctionPass {
-public:
- static char ID;
-
- LoopLoadElimination() : FunctionPass(ID) {
- initializeLoopLoadEliminationPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
- auto *BFI = (PSI && PSI->hasProfileSummary()) ?
- &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
- nullptr;
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
-
- // Process each loop nest in the function.
- return eliminateLoadsAcrossLoops(F, LI, DT, BFI, PSI, SE, /*AC*/ nullptr,
- LAIs);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequiredID(LoopSimplifyID);
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addRequired<LoopAccessLegacyAnalysis>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addRequired<ProfileSummaryInfoWrapperPass>();
- LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
- }
-};
-
-} // end anonymous namespace
-
-char LoopLoadElimination::ID;
-
-static const char LLE_name[] = "Loop Load Elimination";
-
-INITIALIZE_PASS_BEGIN(LoopLoadElimination, LLE_OPTION, LLE_name, false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
-INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass)
-INITIALIZE_PASS_END(LoopLoadElimination, LLE_OPTION, LLE_name, false, false)
-
-FunctionPass *llvm::createLoopLoadEliminationPass() {
- return new LoopLoadElimination();
-}
-
PreservedAnalyses LoopLoadEliminationPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &LI = AM.getResult<LoopAnalysis>(F);
@@ -744,5 +686,7 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F,
return PreservedAnalyses::all();
PreservedAnalyses PA;
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<LoopAnalysis>();
return PA;
}
diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
index c98b94b56e48..2c8a3351281b 100644
--- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
@@ -59,7 +59,7 @@ void PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
P->printPipeline(OS, MapClassName2PassName);
}
if (Idx + 1 < Size)
- OS << ",";
+ OS << ',';
}
}
@@ -193,7 +193,7 @@ void FunctionToLoopPassAdaptor::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
OS << (UseMemorySSA ? "loop-mssa(" : "loop(");
Pass->printPipeline(OS, MapClassName2PassName);
- OS << ")";
+ OS << ')';
}
PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F,
FunctionAnalysisManager &AM) {
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index 49c0fff84d81..12852ae5c460 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -623,7 +623,8 @@ std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
GuardStart, GuardLimit);
IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
- return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
+ return Builder.CreateFreeze(
+ Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}
std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
@@ -671,7 +672,8 @@ std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
SE->getOne(Ty));
IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
- return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
+ return Builder.CreateFreeze(
+ Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}
static void normalizePredicate(ScalarEvolution *SE, Loop *L,
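
At the source level, the two widenICmpRangeCheck helpers trade a per-iteration guard for a single loop-invariant condition; the added CreateFreeze keeps a possibly-poison operand in that combined condition from reaching the branch. A conceptual sketch (plain C++, all names made up, not how the pass literally rewrites IR):

extern void deopt();            // stands in for the guard's deoptimization path
extern void use(int);

// Before: the range check runs on every iteration.
void guarded(const int *A, unsigned N, unsigned Len) {
  for (unsigned I = 0; I < N; ++I) {
    if (!(I < Len)) { deopt(); return; }
    use(A[I]);
  }
}

// After widening: one check of the first iteration plus one check of the
// largest value the induction variable can take inside the loop.
void widened(const int *A, unsigned N, unsigned Len) {
  bool Safe = (N == 0) || (0 < Len && N - 1 < Len);
  for (unsigned I = 0; I < N; ++I) {
    if (!Safe) { deopt(); return; }
    use(A[I]);
  }
}
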
@@ -863,7 +865,19 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
BI->setCondition(AllChecks);
if (InsertAssumesOfPredicatedGuardsConditions) {
Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
- Builder.CreateAssumption(Cond);
+ // If this block has other predecessors, we might not be able to use Cond.
+ // In this case, create a Phi where every other input is `true` and the
+ // input from the guard block is Cond.
+ Value *AssumeCond = Cond;
+ if (!IfTrueBB->getUniquePredecessor()) {
+ auto *GuardBB = BI->getParent();
+ auto *PN = Builder.CreatePHI(Cond->getType(), pred_size(IfTrueBB),
+ "assume.cond");
+ for (auto *Pred : predecessors(IfTrueBB))
+ PN->addIncoming(Pred == GuardBB ? Cond : Builder.getTrue(), Pred);
+ AssumeCond = PN;
+ }
+ Builder.CreateAssumption(AssumeCond);
}
RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
assert(isGuardAsWidenableBranch(BI) &&
@@ -1161,6 +1175,11 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
if (ChangedLoop)
SE->forgetLoop(L);
+ // The insertion point for the widening should be at the widenable call, not
+ // at the WidenableBR. If we do this at the WidenableBR, we can incorrectly
+ // change a loop-invariant condition to a loop-varying one.
+ auto *IP = cast<Instruction>(WidenableBR->getCondition());
+
// The use of umin(all analyzeable exits) instead of latch is subtle, but
// important for profitability. We may have a loop which hasn't been fully
// canonicalized just yet. If the exit we chose to widen is provably never
@@ -1170,21 +1189,9 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
!SE->isLoopInvariant(MinEC, L) ||
- !Rewriter.isSafeToExpandAt(MinEC, WidenableBR))
+ !Rewriter.isSafeToExpandAt(MinEC, IP))
return ChangedLoop;
- // Subtlety: We need to avoid inserting additional uses of the WC. We know
- // that it can only have one transitive use at the moment, and thus moving
- // that use to just before the branch and inserting code before it and then
- // modifying the operand is legal.
- auto *IP = cast<Instruction>(WidenableBR->getCondition());
- // Here we unconditionally modify the IR, so after this point we should return
- // only `true`!
- IP->moveBefore(WidenableBR);
- if (MSSAU)
- if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP))
- MSSAU->moveToPlace(MUD, WidenableBR->getParent(),
- MemorySSA::BeforeTerminator);
Rewriter.setInsertPoint(IP);
IRBuilder<> B(IP);
diff --git a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
index a0b3189c7e09..7f62526a4f6d 100644
--- a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
@@ -39,13 +39,10 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopReroll.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -157,22 +154,6 @@ namespace {
IL_End
};
- class LoopRerollLegacyPass : public LoopPass {
- public:
- static char ID; // Pass ID, replacement for typeid
-
- LoopRerollLegacyPass() : LoopPass(ID) {
- initializeLoopRerollLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
- };
-
class LoopReroll {
public:
LoopReroll(AliasAnalysis *AA, LoopInfo *LI, ScalarEvolution *SE,
@@ -490,17 +471,6 @@ namespace {
} // end anonymous namespace
-char LoopRerollLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(LoopRerollLegacyPass, "loop-reroll", "Reroll loops",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(LoopRerollLegacyPass, "loop-reroll", "Reroll loops", false,
- false)
-
-Pass *llvm::createLoopRerollPass() { return new LoopRerollLegacyPass; }
-
// Returns true if the provided instruction is used outside the given loop.
// This operates like Instruction::isUsedOutsideOfBlock, but considers PHIs in
// non-loop blocks to be outside the loop.
@@ -1700,21 +1670,6 @@ bool LoopReroll::runOnLoop(Loop *L) {
return Changed;
}
-bool LoopRerollLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
- if (skipLoop(L))
- return false;
-
- auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
- *L->getHeader()->getParent());
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
-
- return LoopReroll(AA, LI, SE, TLI, DT, PreserveLCSSA).runOnLoop(L);
-}
-
PreservedAnalyses LoopRerollPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
diff --git a/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/llvm/lib/Transforms/Scalar/LoopRotation.cpp
index ba735adc5b27..eee855058706 100644
--- a/llvm/lib/Transforms/Scalar/LoopRotation.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopRotation.cpp
@@ -43,6 +43,21 @@ LoopRotatePass::LoopRotatePass(bool EnableHeaderDuplication, bool PrepareForLTO)
: EnableHeaderDuplication(EnableHeaderDuplication),
PrepareForLTO(PrepareForLTO) {}
+void LoopRotatePass::printPipeline(
+ raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
+ static_cast<PassInfoMixin<LoopRotatePass> *>(this)->printPipeline(
+ OS, MapClassName2PassName);
+ OS << "<";
+ if (!EnableHeaderDuplication)
+ OS << "no-";
+ OS << "header-duplication;";
+
+ if (!PrepareForLTO)
+ OS << "no-";
+ OS << "prepare-for-lto";
+ OS << ">";
+}
+
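
With this, the textual pipeline form of the pass reflects both options. Assuming the pass keeps its registered name of loop-rotate, a rotation pass created with header duplication enabled and PrepareForLTO disabled would be expected to print roughly:

  loop-rotate<header-duplication;no-prepare-for-lto>
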
PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &) {
diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp
index 21025b0bdb33..597c159682c5 100644
--- a/llvm/lib/Transforms/Scalar/LoopSink.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp
@@ -177,13 +177,27 @@ static bool sinkInstruction(
SmallPtrSet<BasicBlock *, 2> BBs;
for (auto &U : I.uses()) {
Instruction *UI = cast<Instruction>(U.getUser());
- // We cannot sink I to PHI-uses.
- if (isa<PHINode>(UI))
- return false;
+
// We cannot sink I if it has uses outside of the loop.
if (!L.contains(LI.getLoopFor(UI->getParent())))
return false;
- BBs.insert(UI->getParent());
+
+ if (!isa<PHINode>(UI)) {
+ BBs.insert(UI->getParent());
+ continue;
+ }
+
+ // We cannot sink I to PHI-uses; instead, look through the PHI to find the
+ // incoming block of the value being used.
+ PHINode *PN = dyn_cast<PHINode>(UI);
+ BasicBlock *PhiBB = PN->getIncomingBlock(U);
+
+ // If the value's incoming block is the loop preheader itself, there is no
+ // place to sink to; bail out.
+ if (L.getLoopPreheader() == PhiBB)
+ return false;
+
+ BBs.insert(PhiBB);
}
// findBBsToSinkInto is O(BBs.size() * ColdLoopBBs.size()). We cap the max
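
A source-level picture of what sinking through a PHI buys (plain C++, illustrative only, names made up): a value computed in the preheader that only feeds one incoming path of a merge can now be placed on that path instead of being rejected outright.

extern bool coldPath(int);
extern void consume(int);

// Before sinking: X is computed in the preheader even though it only matters
// on the path that feeds one incoming value of the merge below.
void beforeSink(int A, int N) {
  int X = A * 37 + 11;                // stands for a loop-invariant computation
  for (int I = 0; I < N; ++I) {
    int M = coldPath(I) ? X : 0;      // merge point: conceptually a PHI
    consume(M);
  }
}

// After sinking through the PHI: the computation lands in the block feeding
// that incoming value, so it only executes when the path is taken.
void afterSink(int A, int N) {
  for (int I = 0; I < N; ++I) {
    int M = coldPath(I) ? A * 37 + 11 : 0;
    consume(M);
  }
}
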
@@ -238,9 +252,11 @@ static bool sinkInstruction(
}
}
- // Replaces uses of I with IC in N
+ // Replaces uses of I with IC in N, except PHI-uses, which are taken care
+ // of by the definitions in the PHI's incoming blocks.
I.replaceUsesWithIf(IC, [N](Use &U) {
- return cast<Instruction>(U.getUser())->getParent() == N;
+ Instruction *UIToReplace = cast<Instruction>(U.getUser());
+ return UIToReplace->getParent() == N && !isa<PHINode>(UIToReplace);
});
// Replaces uses of I with IC in blocks dominated by N
replaceDominatedUsesWith(&I, IC, DT, N);
@@ -283,7 +299,7 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI,
return false;
MemorySSAUpdater MSSAU(&MSSA);
- SinkAndHoistLICMFlags LICMFlags(/*IsSink=*/true, &L, &MSSA);
+ SinkAndHoistLICMFlags LICMFlags(/*IsSink=*/true, L, MSSA);
bool Changed = false;
@@ -323,6 +339,11 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI,
}
PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) {
+ // Enable LoopSink only when runtime profile is available.
+ // With static profile, the sinking decision may be sub-optimal.
+ if (!F.hasProfileData())
+ return PreservedAnalyses::all();
+
LoopInfo &LI = FAM.getResult<LoopAnalysis>(F);
// Nothing to do if there are no loops.
if (LI.empty())
@@ -348,11 +369,6 @@ PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) {
if (!Preheader)
continue;
- // Enable LoopSink only when runtime profile is available.
- // With static profile, the sinking decision may be sub-optimal.
- if (!Preheader->getParent()->hasProfileData())
- continue;
-
// Note that we don't pass SCEV here because it is only used to invalidate
// loops in SCEV and we don't preserve (or request) SCEV at all making that
// unnecessary.
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 4c89f947d7fc..a4369b83e732 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -799,7 +799,7 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
/// value, and mutate S to point to a new SCEV with that value excluded.
static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- if (C->getAPInt().getMinSignedBits() <= 64) {
+ if (C->getAPInt().getSignificantBits() <= 64) {
S = SE.getConstant(C->getType(), 0);
return C->getValue()->getSExtValue();
}
@@ -896,9 +896,14 @@ static bool isAddressUse(const TargetTransformInfo &TTI,
/// Return the type of the memory being accessed.
static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
Instruction *Inst, Value *OperandVal) {
- MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace);
+ MemAccessTy AccessTy = MemAccessTy::getUnknown(Inst->getContext());
+
+ // First get the type of memory being accessed.
+ if (Type *Ty = Inst->getAccessType())
+ AccessTy.MemTy = Ty;
+
+ // Then get the pointer address space.
if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
- AccessTy.MemTy = SI->getOperand(0)->getType();
AccessTy.AddrSpace = SI->getPointerAddressSpace();
} else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
AccessTy.AddrSpace = LI->getPointerAddressSpace();
@@ -923,7 +928,6 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
II->getArgOperand(0)->getType()->getPointerAddressSpace();
break;
case Intrinsic::masked_store:
- AccessTy.MemTy = II->getOperand(0)->getType();
AccessTy.AddrSpace =
II->getArgOperand(1)->getType()->getPointerAddressSpace();
break;
@@ -976,6 +980,7 @@ static bool isHighCostExpansion(const SCEV *S,
switch (S->getSCEVType()) {
case scUnknown:
case scConstant:
+ case scVScale:
return false;
case scTruncate:
return isHighCostExpansion(cast<SCEVTruncateExpr>(S)->getOperand(),
@@ -1414,7 +1419,7 @@ void Cost::RateFormula(const Formula &F,
C.ImmCost += 64; // Handle symbolic values conservatively.
// TODO: This should probably be the pointer size.
else if (Offset != 0)
- C.ImmCost += APInt(64, Offset, true).getMinSignedBits();
+ C.ImmCost += APInt(64, Offset, true).getSignificantBits();
// Check with target if this offset with this instruction is
// specifically not supported.
@@ -2498,7 +2503,7 @@ LSRInstance::OptimizeLoopTermCond() {
if (C->isOne() || C->isMinusOne())
goto decline_post_inc;
// Avoid weird situations.
- if (C->getValue().getMinSignedBits() >= 64 ||
+ if (C->getValue().getSignificantBits() >= 64 ||
C->getValue().isMinSignedValue())
goto decline_post_inc;
// Check for possible scaled-address reuse.
@@ -2508,13 +2513,13 @@ LSRInstance::OptimizeLoopTermCond() {
int64_t Scale = C->getSExtValue();
if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr,
/*BaseOffset=*/0,
- /*HasBaseReg=*/false, Scale,
+ /*HasBaseReg=*/true, Scale,
AccessTy.AddrSpace))
goto decline_post_inc;
Scale = -Scale;
if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr,
/*BaseOffset=*/0,
- /*HasBaseReg=*/false, Scale,
+ /*HasBaseReg=*/true, Scale,
AccessTy.AddrSpace))
goto decline_post_inc;
}
@@ -2660,8 +2665,7 @@ LSRUse *
LSRInstance::FindUseWithSimilarFormula(const Formula &OrigF,
const LSRUse &OrigLU) {
// Search all uses for the formula. This could be more clever.
- for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
- LSRUse &LU = Uses[LUIdx];
+ for (LSRUse &LU : Uses) {
// Check whether this use is close enough to OrigLU, to see whether it's
// worthwhile looking through its formulae.
// Ignore ICmpZero uses because they may contain formulae generated by
@@ -2703,6 +2707,8 @@ void LSRInstance::CollectInterestingTypesAndFactors() {
SmallVector<const SCEV *, 4> Worklist;
for (const IVStrideUse &U : IU) {
const SCEV *Expr = IU.getExpr(U);
+ if (!Expr)
+ continue;
// Collect interesting types.
Types.insert(SE.getEffectiveSCEVType(Expr->getType()));
@@ -2740,13 +2746,13 @@ void LSRInstance::CollectInterestingTypesAndFactors() {
if (const SCEVConstant *Factor =
dyn_cast_or_null<SCEVConstant>(getExactSDiv(NewStride, OldStride,
SE, true))) {
- if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero())
+ if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero())
Factors.insert(Factor->getAPInt().getSExtValue());
} else if (const SCEVConstant *Factor =
dyn_cast_or_null<SCEVConstant>(getExactSDiv(OldStride,
NewStride,
SE, true))) {
- if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero())
+ if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero())
Factors.insert(Factor->getAPInt().getSExtValue());
}
}
@@ -2812,9 +2818,10 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) {
/// SCEVUnknown, we simply return the rightmost SCEV operand.
static const SCEV *getExprBase(const SCEV *S) {
switch (S->getSCEVType()) {
- default: // uncluding scUnknown.
+ default: // including scUnknown.
return S;
case scConstant:
+ case scVScale:
return nullptr;
case scTruncate:
return getExprBase(cast<SCEVTruncateExpr>(S)->getOperand());
@@ -3175,7 +3182,7 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
return false;
- if (IncConst->getAPInt().getMinSignedBits() > 64)
+ if (IncConst->getAPInt().getSignificantBits() > 64)
return false;
MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
@@ -3320,6 +3327,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
}
const SCEV *S = IU.getExpr(U);
+ if (!S)
+ continue;
PostIncLoopSet TmpPostIncLoops = U.getPostIncLoops();
// Equality (== and !=) ICmps are special. We can rewrite (i == N) as
@@ -3352,6 +3361,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
// S is normalized, so normalize N before folding it into S
// to keep the result normalized.
N = normalizeForPostIncUse(N, TmpPostIncLoops, SE);
+ if (!N)
+ continue;
Kind = LSRUse::ICmpZero;
S = SE.getMinusSCEV(N, S);
} else if (L->isLoopInvariant(NV) &&
@@ -3366,6 +3377,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
// SCEV can't compute the difference of two unknown pointers.
N = SE.getUnknown(NV);
N = normalizeForPostIncUse(N, TmpPostIncLoops, SE);
+ if (!N)
+ continue;
Kind = LSRUse::ICmpZero;
S = SE.getMinusSCEV(N, S);
assert(!isa<SCEVCouldNotCompute>(S));
@@ -3494,8 +3507,8 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() {
if (const Instruction *Inst = dyn_cast<Instruction>(V)) {
// Look for instructions defined outside the loop.
if (L->contains(Inst)) continue;
- } else if (isa<UndefValue>(V))
- // Undef doesn't have a live range, so it doesn't matter.
+ } else if (isa<Constant>(V))
+ // Constants can be re-materialized.
continue;
for (const Use &U : V->uses()) {
const Instruction *UserInst = dyn_cast<Instruction>(U.getUser());
@@ -4137,6 +4150,29 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) {
}
}
+/// Extend/Truncate \p Expr to \p ToTy considering post-inc uses in \p Loops.
+/// For all PostIncLoopSets in \p Loops, first de-normalize \p Expr, then
+/// perform the extension/truncate and normalize again, as the normalized form
+/// can result in folds that are not valid in the post-inc use contexts. The
+/// expressions for all PostIncLoopSets must match; otherwise, return nullptr.
+static const SCEV *
+getAnyExtendConsideringPostIncUses(ArrayRef<PostIncLoopSet> Loops,
+ const SCEV *Expr, Type *ToTy,
+ ScalarEvolution &SE) {
+ const SCEV *Result = nullptr;
+ for (auto &L : Loops) {
+ auto *DenormExpr = denormalizeForPostIncUse(Expr, L, SE);
+ const SCEV *NewDenormExpr = SE.getAnyExtendExpr(DenormExpr, ToTy);
+ const SCEV *New = normalizeForPostIncUse(NewDenormExpr, L, SE);
+ if (!New || (Result && New != Result))
+ return nullptr;
+ Result = New;
+ }
+
+ assert(Result && "failed to create expression");
+ return Result;
+}
+
/// Generate reuse formulae from different IV types.
void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
// Don't bother truncating symbolic values.
@@ -4156,6 +4192,10 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
[](const SCEV *S) { return S->getType()->isPointerTy(); }))
return;
+ SmallVector<PostIncLoopSet> Loops;
+ for (auto &LF : LU.Fixups)
+ Loops.push_back(LF.PostIncLoops);
+
for (Type *SrcTy : Types) {
if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) {
Formula F = Base;
@@ -4165,15 +4205,17 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
// initial node (maybe due to depth limitations), but it can do them while
// taking ext.
if (F.ScaledReg) {
- const SCEV *NewScaledReg = SE.getAnyExtendExpr(F.ScaledReg, SrcTy);
- if (NewScaledReg->isZero())
- continue;
+ const SCEV *NewScaledReg =
+ getAnyExtendConsideringPostIncUses(Loops, F.ScaledReg, SrcTy, SE);
+ if (!NewScaledReg || NewScaledReg->isZero())
+ continue;
F.ScaledReg = NewScaledReg;
}
bool HasZeroBaseReg = false;
for (const SCEV *&BaseReg : F.BaseRegs) {
- const SCEV *NewBaseReg = SE.getAnyExtendExpr(BaseReg, SrcTy);
- if (NewBaseReg->isZero()) {
+ const SCEV *NewBaseReg =
+ getAnyExtendConsideringPostIncUses(Loops, BaseReg, SrcTy, SE);
+ if (!NewBaseReg || NewBaseReg->isZero()) {
HasZeroBaseReg = true;
break;
}
@@ -4379,8 +4421,8 @@ void LSRInstance::GenerateCrossUseConstantOffsets() {
if ((C->getAPInt() + NewF.BaseOffset)
.abs()
.slt(std::abs(NewF.BaseOffset)) &&
- (C->getAPInt() + NewF.BaseOffset).countTrailingZeros() >=
- countTrailingZeros<uint64_t>(NewF.BaseOffset))
+ (C->getAPInt() + NewF.BaseOffset).countr_zero() >=
+ (unsigned)llvm::countr_zero<uint64_t>(NewF.BaseOffset))
goto skip_formula;
// Ok, looks good.
@@ -4982,6 +5024,32 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() {
LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs()));
}
+// Check if Best and Reg are SCEVs separated by a constant amount C, and if so,
+// whether the addressing offset +C would be legal where the negative offset -C
+// is not.
+static bool IsSimplerBaseSCEVForTarget(const TargetTransformInfo &TTI,
+ ScalarEvolution &SE, const SCEV *Best,
+ const SCEV *Reg,
+ MemAccessTy AccessType) {
+ if (Best->getType() != Reg->getType() ||
+ (isa<SCEVAddRecExpr>(Best) && isa<SCEVAddRecExpr>(Reg) &&
+ cast<SCEVAddRecExpr>(Best)->getLoop() !=
+ cast<SCEVAddRecExpr>(Reg)->getLoop()))
+ return false;
+ const auto *Diff = dyn_cast<SCEVConstant>(SE.getMinusSCEV(Best, Reg));
+ if (!Diff)
+ return false;
+
+ return TTI.isLegalAddressingMode(
+ AccessType.MemTy, /*BaseGV=*/nullptr,
+ /*BaseOffset=*/Diff->getAPInt().getSExtValue(),
+ /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace) &&
+ !TTI.isLegalAddressingMode(
+ AccessType.MemTy, /*BaseGV=*/nullptr,
+ /*BaseOffset=*/-Diff->getAPInt().getSExtValue(),
+ /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace);
+}
+
/// Pick a register which seems likely to be profitable, and then in any use
/// which has any reference to that register, delete all formulae which do not
/// reference that register.
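
A worked example of the kind of tie this breaks (plain C++, illustrative only): two addresses in the same loop that differ by a constant can be expressed off either base; if the target's addressing mode accepts the positive immediate offset but not the negative one, the base that keeps every offset non-negative is the better pick.

// With base &A[0], the two accesses are [base + 4*I] and [base + 4*I + 16].
// With base &A[4], they become [base + 4*I - 16] and [base + 4*I]; the -16
// immediate may not be encodable on targets with unsigned-offset addressing.
void sumPairs(const int *A, int *Out, int N) {
  for (int I = 0; I + 4 < N; ++I)
    Out[I] = A[I] + A[I + 4];
}
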
@@ -5010,6 +5078,19 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() {
Best = Reg;
BestNum = Count;
}
+
+ // If the scores are the same, but the Reg is simpler for the target
+ // (for example {x,+,1} as opposed to {x+C,+,1}, where the target can
+ // handle +C but not -C), opt for the simpler formula.
+ if (Count == BestNum) {
+ int LUIdx = RegUses.getUsedByIndices(Reg).find_first();
+ if (LUIdx >= 0 && Uses[LUIdx].Kind == LSRUse::Address &&
+ IsSimplerBaseSCEVForTarget(TTI, SE, Best, Reg,
+ Uses[LUIdx].AccessTy)) {
+ Best = Reg;
+ BestNum = Count;
+ }
+ }
}
}
assert(Best && "Failed to find best LSRUse candidate");
@@ -5497,6 +5578,13 @@ void LSRInstance::RewriteForPHI(
PHINode *PN, const LSRUse &LU, const LSRFixup &LF, const Formula &F,
SmallVectorImpl<WeakTrackingVH> &DeadInsts) const {
DenseMap<BasicBlock *, Value *> Inserted;
+
+ // Inserting instructions in the loop and using them as a PHI's input could
+ // break LCSSA if the PHI's parent block is not a loop exit (i.e. the
+ // corresponding incoming block is not loop-exiting), so collect all such
+ // instructions to form LCSSA for them later.
+ SmallVector<Instruction *, 4> InsertedNonLCSSAInsts;
+
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
if (PN->getIncomingValue(i) == LF.OperandValToReplace) {
bool needUpdateFixups = false;
@@ -5562,6 +5650,13 @@ void LSRInstance::RewriteForPHI(
FullV, LF.OperandValToReplace->getType(),
"tmp", BB->getTerminator());
+ // If the incoming block for this value is not in the loop, it means the
+ // current PHI is not in a loop exit, so we must create an LCSSA PHI for
+ // the inserted value.
+ if (auto *I = dyn_cast<Instruction>(FullV))
+ if (L->contains(I) && !L->contains(BB))
+ InsertedNonLCSSAInsts.push_back(I);
+
PN->setIncomingValue(i, FullV);
Pair.first->second = FullV;
}
@@ -5604,6 +5699,8 @@ void LSRInstance::RewriteForPHI(
}
}
}
+
+ formLCSSAForInstructions(InsertedNonLCSSAInsts, DT, LI, &SE);
}
/// Emit instructions for the leading candidate expression for this LSRUse (this
@@ -5643,6 +5740,36 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF,
DeadInsts.emplace_back(OperandIsInstr);
}
+// Try to hoist the IVInc to the loop header if all IVInc users are in
+// the loop header. This helps the backend generate post-index load/store
+// instructions when the latch block is different from the loop header block.
+static bool canHoistIVInc(const TargetTransformInfo &TTI, const LSRFixup &Fixup,
+ const LSRUse &LU, Instruction *IVIncInsertPos,
+ Loop *L) {
+ if (LU.Kind != LSRUse::Address)
+ return false;
+
+ // For now this code does the conservative optimization and only works for
+ // the header block. Later we can hoist the IVInc to the block that
+ // post-dominates all users.
+ BasicBlock *LHeader = L->getHeader();
+ if (IVIncInsertPos->getParent() == LHeader)
+ return false;
+
+ if (!Fixup.OperandValToReplace ||
+ any_of(Fixup.OperandValToReplace->users(), [&LHeader](User *U) {
+ Instruction *UI = cast<Instruction>(U);
+ return UI->getParent() != LHeader;
+ }))
+ return false;
+
+ Instruction *I = Fixup.UserInst;
+ Type *Ty = I->getType();
+ return Ty->isIntegerTy() &&
+ ((isa<LoadInst>(I) && TTI.isIndexedLoadLegal(TTI.MIM_PostInc, Ty)) ||
+ (isa<StoreInst>(I) && TTI.isIndexedStoreLegal(TTI.MIM_PostInc, Ty)));
+}
+
/// Rewrite all the fixup locations with new values, following the chosen
/// solution.
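
The pattern this is after, sketched at the source level (plain C++, illustrative only): when every user of the pointer IV sits in the header block, placing the increment right after those users lets the backend fold it into a post-indexed access, e.g. an AArch64 load of the form ldr w0, [x1], #4.

// The load of *P is in the loop header; advancing P right after the load (the
// post-increment) is what maps onto a post-indexed load once the increment is
// placed next to its only user instead of in a separate latch block.
int sumArray(const int *P, int N) {
  int S = 0;
  for (int I = 0; I < N; ++I)
    S += *P++;
  return S;
}
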
void LSRInstance::ImplementSolution(
@@ -5651,8 +5778,6 @@ void LSRInstance::ImplementSolution(
// we can remove them after we are done working.
SmallVector<WeakTrackingVH, 16> DeadInsts;
- Rewriter.setIVIncInsertPos(L, IVIncInsertPos);
-
// Mark phi nodes that terminate chains so the expander tries to reuse them.
for (const IVChain &Chain : IVChainVec) {
if (PHINode *PN = dyn_cast<PHINode>(Chain.tailUserInst()))
@@ -5662,6 +5787,11 @@ void LSRInstance::ImplementSolution(
// Expand the new value definitions and update the users.
for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx)
for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) {
+ Instruction *InsertPos =
+ canHoistIVInc(TTI, Fixup, Uses[LUIdx], IVIncInsertPos, L)
+ ? L->getHeader()->getTerminator()
+ : IVIncInsertPos;
+ Rewriter.setIVIncInsertPos(L, InsertPos);
Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], DeadInsts);
Changed = true;
}
@@ -5994,7 +6124,7 @@ struct SCEVDbgValueBuilder {
}
bool pushConst(const SCEVConstant *C) {
- if (C->getAPInt().getMinSignedBits() > 64)
+ if (C->getAPInt().getSignificantBits() > 64)
return false;
Expr.push_back(llvm::dwarf::DW_OP_consts);
Expr.push_back(C->getAPInt().getSExtValue());
@@ -6083,7 +6213,7 @@ struct SCEVDbgValueBuilder {
/// SCEV constant value is an identity function.
bool isIdentityFunction(uint64_t Op, const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- if (C->getAPInt().getMinSignedBits() > 64)
+ if (C->getAPInt().getSignificantBits() > 64)
return false;
int64_t I = C->getAPInt().getSExtValue();
switch (Op) {
@@ -6338,13 +6468,13 @@ static void UpdateDbgValueInst(DVIRecoveryRec &DVIRec,
}
}
-/// Cached location ops may be erased during LSR, in which case an undef is
+/// Cached location ops may be erased during LSR, in which case a poison is
/// required when restoring from the cache. The type of that location is no
-/// longer available, so just use int8. The undef will be replaced by one or
+/// longer available, so just use int8. The poison will be replaced by one or
/// more locations later when a SCEVDbgValueBuilder selects alternative
/// locations to use for the salvage.
-static Value *getValueOrUndef(WeakVH &VH, LLVMContext &C) {
- return (VH) ? VH : UndefValue::get(llvm::Type::getInt8Ty(C));
+static Value *getValueOrPoison(WeakVH &VH, LLVMContext &C) {
+ return (VH) ? VH : PoisonValue::get(llvm::Type::getInt8Ty(C));
}
/// Restore the DVI's pre-LSR arguments. Substitute undef for any erased values.
@@ -6363,12 +6493,12 @@ static void restorePreTransformState(DVIRecoveryRec &DVIRec) {
// this case was not present before, so force the location back to a single
// uncontained Value.
Value *CachedValue =
- getValueOrUndef(DVIRec.LocationOps[0], DVIRec.DVI->getContext());
+ getValueOrPoison(DVIRec.LocationOps[0], DVIRec.DVI->getContext());
DVIRec.DVI->setRawLocation(ValueAsMetadata::get(CachedValue));
} else {
SmallVector<ValueAsMetadata *, 3> MetadataLocs;
for (WeakVH VH : DVIRec.LocationOps) {
- Value *CachedValue = getValueOrUndef(VH, DVIRec.DVI->getContext());
+ Value *CachedValue = getValueOrPoison(VH, DVIRec.DVI->getContext());
MetadataLocs.push_back(ValueAsMetadata::get(CachedValue));
}
auto ValArrayRef = llvm::ArrayRef<llvm::ValueAsMetadata *>(MetadataLocs);
@@ -6431,7 +6561,7 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE,
// less DWARF ops than an iteration count-based expression.
if (std::optional<APInt> Offset =
SE.computeConstantDifference(DVIRec.SCEVs[i], SCEVInductionVar)) {
- if (Offset->getMinSignedBits() <= 64)
+ if (Offset->getSignificantBits() <= 64)
SalvageExpr->createOffsetExpr(Offset->getSExtValue(), LSRInductionVar);
} else if (!SalvageExpr->createIterCountExpr(DVIRec.SCEVs[i], IterCountExpr,
SE))
@@ -6607,7 +6737,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE,
return nullptr;
}
-static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *>>
+static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
const LoopInfo &LI) {
if (!L->isInnermost()) {
@@ -6626,16 +6756,13 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
}
BasicBlock *LoopLatch = L->getLoopLatch();
-
- // TODO: Can we do something for greater than and less than?
- // Terminating condition is foldable when it is an eq/ne icmp
- BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
- if (BI->isUnconditional())
+ BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
+ if (!BI || BI->isUnconditional())
return std::nullopt;
- Value *TermCond = BI->getCondition();
- if (!isa<ICmpInst>(TermCond) || !cast<ICmpInst>(TermCond)->isEquality()) {
- LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an "
- "ICmpInst::eq / ICmpInst::ne\n");
+ auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
+ if (!TermCond) {
+ LLVM_DEBUG(
+ dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
return std::nullopt;
}
if (!TermCond->hasOneUse()) {
@@ -6645,89 +6772,42 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
return std::nullopt;
}
- // For `IsToFold`, a primary IV can be replaced by other affine AddRec when it
- // is only used by the terminating condition. To check for this, we may need
- // to traverse through a chain of use-def until we can examine the final
- // usage.
- // *----------------------*
- // *---->| LoopHeader: |
- // | | PrimaryIV = phi ... |
- // | *----------------------*
- // | |
- // | |
- // | chain of
- // | single use
- // used by |
- // phi |
- // | Value
- // | / \
- // | chain of chain of
- // | single use single use
- // | / \
- // | / \
- // *- Value Value --> used by terminating condition
- auto IsToFold = [&](PHINode &PN) -> bool {
- Value *V = &PN;
-
- while (V->getNumUses() == 1)
- V = *V->user_begin();
-
- if (V->getNumUses() != 2)
- return false;
+ BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
+ Value *RHS = TermCond->getOperand(1);
+ if (!LHS || !L->isLoopInvariant(RHS))
+ // We could pattern match the inverse form of the icmp, but that is
+ // non-canonical, and this pass is running *very* late in the pipeline.
+ return std::nullopt;
- Value *VToPN = nullptr;
- Value *VToTermCond = nullptr;
- for (User *U : V->users()) {
- while (U->getNumUses() == 1) {
- if (isa<PHINode>(U))
- VToPN = U;
- if (U == TermCond)
- VToTermCond = U;
- U = *U->user_begin();
- }
- }
- return VToPN && VToTermCond;
- };
+ // Find the IV used by the current exit condition.
+ PHINode *ToFold;
+ Value *ToFoldStart, *ToFoldStep;
+ if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+ return std::nullopt;
- // If this is an IV which we could replace the terminating condition, return
- // the final value of the alternative IV on the last iteration.
- auto getAlternateIVEnd = [&](PHINode &PN) -> const SCEV * {
- // FIXME: This does not properly account for overflow.
- const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
- const SCEV *BECount = SE.getBackedgeTakenCount(L);
- const SCEV *TermValueS = SE.getAddExpr(
- AddRec->getOperand(0),
- SE.getTruncateOrZeroExtend(
- SE.getMulExpr(
- AddRec->getOperand(1),
- SE.getTruncateOrZeroExtend(
- SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
- AddRec->getOperand(1)->getType())),
- AddRec->getOperand(0)->getType()));
- const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
- SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
- if (!Expander.isSafeToExpand(TermValueS)) {
- LLVM_DEBUG(
- dbgs() << "Is not safe to expand terminating value for phi node" << PN
- << "\n");
- return nullptr;
- }
- return TermValueS;
- };
+ // If that IV isn't dead after we rewrite the exit condition in terms of
+ // another IV, there's no point in doing the transform.
+ if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
+ return std::nullopt;
+
+ const SCEV *BECount = SE.getBackedgeTakenCount(L);
+ const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+ SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
- PHINode *ToFold = nullptr;
PHINode *ToHelpFold = nullptr;
const SCEV *TermValueS = nullptr;
-
+ bool MustDropPoison = false;
for (PHINode &PN : L->getHeader()->phis()) {
+ if (ToFold == &PN)
+ continue;
+
if (!SE.isSCEVable(PN.getType())) {
LLVM_DEBUG(dbgs() << "IV of phi '" << PN
<< "' is not SCEV-able, not qualified for the "
"terminating condition folding.\n");
continue;
}
- const SCEV *S = SE.getSCEV(&PN);
- const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+ const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
// Only speculate on affine AddRec
if (!AddRec || !AddRec->isAffine()) {
LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
@@ -6736,12 +6816,63 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
continue;
}
- if (IsToFold(PN))
- ToFold = &PN;
- else if (auto P = getAlternateIVEnd(PN)) {
- ToHelpFold = &PN;
- TermValueS = P;
+ // Check that we can compute the value of AddRec on the exiting iteration
+ // without soundness problems. evaluateAtIteration internally needs to
+ // multiply the stride by the iteration number, which may wrap around.
+ // The issue here is subtle because computing the result accounting for
+ // wrap is insufficient. In order to use the result in an exit test, we
+ // must also know that AddRec doesn't take the same value on any previous
+ // iteration. The simplest case to consider is a candidate IV which is
+ // narrower than the trip count (and thus original IV), but this can
+ // also happen due to non-unit strides on the candidate IVs.
+ if (!AddRec->hasNoSelfWrap())
+ continue;
+
+ const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
+ const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
+ if (!Expander.isSafeToExpand(TermValueSLocal)) {
+ LLVM_DEBUG(
+ dbgs() << "Is not safe to expand terminating value for phi node" << PN
+ << "\n");
+ continue;
}
+
+ // The candidate IV may have been otherwise dead and poison from the
+ // very first iteration. If we can't disprove that, we can't use the IV.
+ if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
+ LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV "
+ << PN << "\n");
+ continue;
+ }
+
+ // The candidate IV may become poison on the last iteration. If this
+ // value is not branched on, this is a well defined program. We're
+ // about to add a new use to this IV, and we have to ensure we don't
+ // insert UB which didn't previously exist.
+ bool MustDropPoisonLocal = false;
+ Instruction *PostIncV =
+ cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
+ if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
+ &DT)) {
+ LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use"
+ << PN << "\n");
+
+ // If this is a complex recurrence with multiple instructions computing
+ // the backedge value, we might need to strip poison flags from all of
+ // them.
+ if (PostIncV->getOperand(0) != &PN)
+ continue;
+
+ // In order to perform the transform, we need to drop the poison generating
+ // flags on this instruction (if any).
+ MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
+ }
+
+ // We pick the last legal alternate IV. We could explore choosing an optimal
+ // alternate IV if we had a decent heuristic to do so.
+ ToHelpFold = &PN;
+ TermValueS = TermValueSLocal;
+ MustDropPoison = MustDropPoisonLocal;
}
LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
@@ -6757,7 +6888,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
if (!ToFold || !ToHelpFold)
return std::nullopt;
- return std::make_tuple(ToFold, ToHelpFold, TermValueS);
+ return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
}
static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
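
The overall rewrite this enables, sketched at the source level (plain C++, illustrative only, assuming N >= 0): an induction variable that exists only to drive the exit test is replaced by comparing another IV against its value on the final iteration, after which the original counter can be deleted.

extern void touch(int *);

// Before: I only exists to terminate the loop.
void beforeFold(int *P, long N) {
  for (long I = 0; I != N; ++I, ++P)
    touch(P);
}

// After: the exit test is rewritten against the pointer IV's end value.
void afterFold(int *P, long N) {
  for (int *End = P + N; P != End; ++P)
    touch(P);
}
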
@@ -6820,7 +6951,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
if (AllowTerminatingConditionFoldingAfterLSR) {
if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) {
- auto [ToFold, ToHelpFold, TermValueS] = *Opt;
+ auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
Changed = true;
NumTermFold++;
@@ -6838,6 +6969,10 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
(void)StartValue;
Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
+ // See comment in canFoldTermCondOfLoop on why this is sufficient.
+ if (MustDrop)
+ cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
+
// SCEVExpander for both use in preheader and latch
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
@@ -6859,11 +6994,12 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
- // FIXME: We are adding a use of an IV here without account for poison safety.
- // This is incorrect.
- Value *NewTermCond = LatchBuilder.CreateICmp(
- OldTermCond->getPredicate(), LoopValue, TermValue,
- "lsr_fold_term_cond.replaced_term_cond");
+ Value *NewTermCond =
+ LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
+ "lsr_fold_term_cond.replaced_term_cond");
+ // Swap successors to exit the loop body if the IV equals the new TermValue
+ if (BI->getSuccessor(0) == L->getHeader())
+ BI->swapSuccessors();
LLVM_DEBUG(dbgs() << "Old term-cond:\n"
<< *OldTermCond << "\n"
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
index 0ae26b494c5a..9c6e4ebf62a9 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
@@ -32,15 +32,11 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/PassRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -460,76 +456,6 @@ static bool tryToUnrollAndJamLoop(LoopNest &LN, DominatorTree &DT, LoopInfo &LI,
return DidSomething;
}
-namespace {
-
-class LoopUnrollAndJam : public LoopPass {
-public:
- static char ID; // Pass ID, replacement for typeid
- unsigned OptLevel;
-
- LoopUnrollAndJam(int OptLevel = 2) : LoopPass(ID), OptLevel(OptLevel) {
- initializeLoopUnrollAndJamPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
-
- auto *F = L->getHeader()->getParent();
- auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*F);
- auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(*F);
-
- LoopUnrollResult Result =
- tryToUnrollAndJamLoop(L, DT, LI, SE, TTI, AC, DI, ORE, OptLevel);
-
- if (Result == LoopUnrollResult::FullyUnrolled)
- LPM.markLoopAsDeleted(*L);
-
- return Result != LoopUnrollResult::Unmodified;
- }
-
- /// This transformation requires natural loop information & requires that
- /// loop preheaders be inserted into the CFG...
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DependenceAnalysisWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
-};
-
-} // end anonymous namespace
-
-char LoopUnrollAndJam::ID = 0;
-
-INITIALIZE_PASS_BEGIN(LoopUnrollAndJam, "loop-unroll-and-jam",
- "Unroll and Jam loops", false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
-INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_END(LoopUnrollAndJam, "loop-unroll-and-jam",
- "Unroll and Jam loops", false, false)
-
-Pass *llvm::createLoopUnrollAndJamPass(int OptLevel) {
- return new LoopUnrollAndJam(OptLevel);
-}
-
PreservedAnalyses LoopUnrollAndJamPass::run(LoopNest &LN,
LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
index 1a6065cb3f1a..335b489d3cb2 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -1124,7 +1124,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
const TargetTransformInfo &TTI, AssumptionCache &AC,
OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
ProfileSummaryInfo *PSI, bool PreserveLCSSA, int OptLevel,
- bool OnlyWhenForced, bool ForgetAllSCEV,
+ bool OnlyFullUnroll, bool OnlyWhenForced, bool ForgetAllSCEV,
std::optional<unsigned> ProvidedCount,
std::optional<unsigned> ProvidedThreshold,
std::optional<bool> ProvidedAllowPartial,
@@ -1133,6 +1133,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
std::optional<bool> ProvidedAllowPeeling,
std::optional<bool> ProvidedAllowProfileBasedPeeling,
std::optional<unsigned> ProvidedFullUnrollMaxCount) {
+
LLVM_DEBUG(dbgs() << "Loop Unroll: F["
<< L->getHeader()->getParent()->getName() << "] Loop %"
<< L->getHeader()->getName() << "\n");
@@ -1304,6 +1305,13 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
return LoopUnrollResult::Unmodified;
}
+ // Do not attempt partial/runtime unrolling in FullLoopUnrolling
+ if (OnlyFullUnroll && !(UP.Count >= MaxTripCount)) {
+ LLVM_DEBUG(
+ dbgs() << "Not attempting partial/runtime unroll in FullLoopUnroll.\n");
+ return LoopUnrollResult::Unmodified;
+ }
+
// At this point, UP.Runtime indicates that run-time unrolling is allowed.
// However, we only want to actually perform it if we don't know the trip
// count and the unroll count doesn't divide the known trip multiple.
@@ -1420,10 +1428,10 @@ public:
LoopUnrollResult Result = tryToUnrollLoop(
L, DT, LI, SE, TTI, AC, ORE, nullptr, nullptr, PreserveLCSSA, OptLevel,
- OnlyWhenForced, ForgetAllSCEV, ProvidedCount, ProvidedThreshold,
- ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound,
- ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling,
- ProvidedFullUnrollMaxCount);
+ /*OnlyFullUnroll*/ false, OnlyWhenForced, ForgetAllSCEV, ProvidedCount,
+ ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime,
+ ProvidedUpperBound, ProvidedAllowPeeling,
+ ProvidedAllowProfileBasedPeeling, ProvidedFullUnrollMaxCount);
if (Result == LoopUnrollResult::FullyUnrolled)
LPM.markLoopAsDeleted(*L);
@@ -1469,12 +1477,6 @@ Pass *llvm::createLoopUnrollPass(int OptLevel, bool OnlyWhenForced,
AllowPeeling == -1 ? std::nullopt : std::optional<bool>(AllowPeeling));
}
-Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced,
- bool ForgetAllSCEV) {
- return createLoopUnrollPass(OptLevel, OnlyWhenForced, ForgetAllSCEV, -1, -1,
- 0, 0, 0, 1);
-}
-
PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &Updater) {
@@ -1497,8 +1499,8 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
bool Changed =
tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, ORE,
/*BFI*/ nullptr, /*PSI*/ nullptr,
- /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced,
- ForgetSCEV, /*Count*/ std::nullopt,
+ /*PreserveLCSSA*/ true, OptLevel, /*OnlyFullUnroll*/ true,
+ OnlyWhenForced, ForgetSCEV, /*Count*/ std::nullopt,
/*Threshold*/ std::nullopt, /*AllowPartial*/ false,
/*Runtime*/ false, /*UpperBound*/ false,
/*AllowPeeling*/ true,
@@ -1623,8 +1625,9 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,
// flavors of unrolling during construction time (by setting UnrollOpts).
LoopUnrollResult Result = tryToUnrollLoop(
&L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI,
- /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, UnrollOpts.OnlyWhenForced,
- UnrollOpts.ForgetSCEV, /*Count*/ std::nullopt,
+ /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, /*OnlyFullUnroll*/ false,
+ UnrollOpts.OnlyWhenForced, UnrollOpts.ForgetSCEV,
+ /*Count*/ std::nullopt,
/*Threshold*/ std::nullopt, UnrollOpts.AllowPartial,
UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, LocalAllowPeeling,
UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount);
@@ -1651,7 +1654,7 @@ void LoopUnrollPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<LoopUnrollPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (UnrollOpts.AllowPartial != std::nullopt)
OS << (*UnrollOpts.AllowPartial ? "" : "no-") << "partial;";
if (UnrollOpts.AllowPeeling != std::nullopt)
@@ -1664,7 +1667,7 @@ void LoopUnrollPass::printPipeline(
OS << (*UnrollOpts.AllowProfileBasedPeeling ? "" : "no-")
<< "profile-peeling;";
if (UnrollOpts.FullUnrollMaxCount != std::nullopt)
- OS << "full-unroll-max=" << UnrollOpts.FullUnrollMaxCount << ";";
- OS << "O" << UnrollOpts.OptLevel;
- OS << ">";
+ OS << "full-unroll-max=" << UnrollOpts.FullUnrollMaxCount << ';';
+ OS << 'O' << UnrollOpts.OptLevel;
+ OS << '>';
}
diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
index 848be25a2fe0..13e06c79d0d7 100644
--- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
@@ -77,13 +77,10 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/LoopVersioning.h"
@@ -113,33 +110,6 @@ static cl::opt<unsigned> LVLoopDepthThreshold(
namespace {
-struct LoopVersioningLICMLegacyPass : public LoopPass {
- static char ID;
-
- LoopVersioningLICMLegacyPass() : LoopPass(ID) {
- initializeLoopVersioningLICMLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override;
-
- StringRef getPassName() const override { return "Loop Versioning for LICM"; }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequiredID(LCSSAID);
- AU.addRequired<LoopAccessLegacyAnalysis>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addRequiredID(LoopSimplifyID);
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addPreserved<AAResultsWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- }
-};
-
struct LoopVersioningLICM {
// We don't explicitly pass in LoopAccessInfo to the constructor since the
// loop versioning might return early due to instructions that are not safe
@@ -563,21 +533,6 @@ void LoopVersioningLICM::setNoAliasToLoop(Loop *VerLoop) {
}
}
-bool LoopVersioningLICMLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
- if (skipLoop(L))
- return false;
-
- AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- OptimizationRemarkEmitter *ORE =
- &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs();
-
- return LoopVersioningLICM(AA, SE, ORE, LAIs, LI, L).run(DT);
-}
-
bool LoopVersioningLICM::run(DominatorTree *DT) {
// Do not do the transformation if disabled by metadata.
if (hasLICMVersioningTransformation(CurLoop) & TM_Disable)
@@ -611,26 +566,6 @@ bool LoopVersioningLICM::run(DominatorTree *DT) {
return Changed;
}
-char LoopVersioningLICMLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(LoopVersioningLICMLegacyPass, "loop-versioning-licm",
- "Loop Versioning For LICM", false, false)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_END(LoopVersioningLICMLegacyPass, "loop-versioning-licm",
- "Loop Versioning For LICM", false, false)
-
-Pass *llvm::createLoopVersioningLICMPass() {
- return new LoopVersioningLICMLegacyPass();
-}
-
namespace llvm {
PreservedAnalyses LoopVersioningLICMPass::run(Loop &L, LoopAnalysisManager &AM,
diff --git a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
index ef22b0401b1b..b167120a906d 100644
--- a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp
@@ -29,6 +29,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>
@@ -136,10 +137,12 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo &TLI,
continue;
case Intrinsic::is_constant:
NewValue = lowerIsConstantIntrinsic(II);
+ LLVM_DEBUG(dbgs() << "Folding " << *II << " to " << *NewValue << "\n");
IsConstantIntrinsicsHandled++;
break;
case Intrinsic::objectsize:
NewValue = lowerObjectSizeCall(II, DL, &TLI, true);
+ LLVM_DEBUG(dbgs() << "Folding " << *II << " to " << *NewValue << "\n");
ObjectSizeIntrinsicsHandled++;
break;
}
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 17594b98c5bc..f46ea6a20afa 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -72,6 +72,11 @@ static cl::opt<bool> AllowContractEnabled(
cl::desc("Allow the use of FMAs if available and profitable. This may "
"result in different results, due to less rounding error."));
+static cl::opt<bool>
+ VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
+ cl::desc("Enable/disable matrix shape verification."),
+ cl::init(false));
+
enum class MatrixLayoutTy { ColumnMajor, RowMajor };
static cl::opt<MatrixLayoutTy> MatrixLayout(
@@ -267,7 +272,7 @@ class LowerMatrixIntrinsics {
unsigned D = isColumnMajor() ? NumColumns : NumRows;
for (unsigned J = 0; J < D; ++J)
- addVector(UndefValue::get(FixedVectorType::get(
+ addVector(PoisonValue::get(FixedVectorType::get(
EltTy, isColumnMajor() ? NumRows : NumColumns)));
}
@@ -535,6 +540,15 @@ public:
auto SIter = ShapeMap.find(V);
if (SIter != ShapeMap.end()) {
+ if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
+ SIter->second.NumColumns != Shape.NumColumns)) {
+ errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
+ << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
+ << Shape.NumColumns << ") for " << *V << "\n";
+ report_fatal_error(
+ "Matrix shape verification failed, compilation aborted!");
+ }
+
LLVM_DEBUG(dbgs() << " not overriding existing shape: "
<< SIter->second.NumRows << " "
<< SIter->second.NumColumns << " for " << *V << "\n");
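
The hunk above adds an opt-in check behind the new -verify-matrix-shapes flag: once a value has a recorded shape, recording a conflicting one aborts compilation. As a standalone sketch (not part of the patch; names are illustrative), the same bookkeeping looks roughly like this:

    // Minimal sketch of shape verification: remember the first shape recorded
    // for a value and reject any later, conflicting shape.
    #include <cstdio>
    #include <cstdlib>
    #include <map>
    #include <string>

    struct Shape {
      unsigned Rows = 0, Cols = 0;
    };

    std::map<std::string, Shape> ShapeMap; // keyed by value name for simplicity

    void setShape(const std::string &Name, Shape S, bool Verify) {
      auto It = ShapeMap.find(Name);
      if (It == ShapeMap.end()) {
        ShapeMap[Name] = S;
        return;
      }
      if (Verify && (It->second.Rows != S.Rows || It->second.Cols != S.Cols)) {
        std::fprintf(stderr, "Conflicting shapes (%ux%u vs %ux%u) for %s\n",
                     It->second.Rows, It->second.Cols, S.Rows, S.Cols,
                     Name.c_str());
        std::abort(); // analogous to report_fatal_error in the pass
      }
      // Otherwise keep the existing shape, as the pass does.
    }
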
@@ -838,10 +852,13 @@ public:
auto NewInst = distributeTransposes(
TAMA, {R, C}, TAMB, {R, C}, Builder,
[&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
- auto *FAdd =
- cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
- setShapeInfo(FAdd, Shape0);
- return FAdd;
+ bool IsFP = I.getType()->isFPOrFPVectorTy();
+ auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
+ : LocalBuilder.CreateAdd(T0, T1, "madd");
+
+ auto *Result = cast<Instruction>(Add);
+ setShapeInfo(Result, Shape0);
+ return Result;
});
updateShapeAndReplaceAllUsesWith(I, NewInst);
eraseFromParentAndMove(&I, II, BB);
@@ -978,13 +995,18 @@ public:
MatrixInsts.push_back(&I);
}
- // Second, try to fuse candidates.
+ // Second, try to lower any dot products
SmallPtrSet<Instruction *, 16> FusedInsts;
for (CallInst *CI : MaybeFusableInsts)
+ lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
+
+ // Third, try to fuse candidates.
+ for (CallInst *CI : MaybeFusableInsts)
LowerMatrixMultiplyFused(CI, FusedInsts);
+
Changed = !FusedInsts.empty();
- // Third, lower remaining instructions with shape information.
+ // Fourth, lower remaining instructions with shape information.
for (Instruction *Inst : MatrixInsts) {
if (FusedInsts.count(Inst))
continue;
@@ -1311,6 +1333,165 @@ public:
}
}
+  /// Special case for MatMul lowering. Prevents scalar loads of row-major
+  /// vectors. Lowers to a vector reduction add instead of sequential adds if
+  /// reassociation is enabled.
+ void lowerDotProduct(CallInst *MatMul,
+ SmallPtrSet<Instruction *, 16> &FusedInsts,
+ FastMathFlags FMF) {
+ if (FusedInsts.contains(MatMul) ||
+ MatrixLayout != MatrixLayoutTy::ColumnMajor)
+ return;
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
+ return;
+
+ Value *LHS = MatMul->getArgOperand(0);
+ Value *RHS = MatMul->getArgOperand(1);
+
+ Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
+ bool IsIntVec = ElementType->isIntegerTy();
+
+    // Floating point reductions require reassociation.
+ if (!IsIntVec && !FMF.allowReassoc())
+ return;
+
+ auto CanBeFlattened = [this](Value *Op) {
+ if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
+ return true;
+ return match(
+ Op, m_OneUse(m_CombineOr(
+ m_Load(m_Value()),
+ m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
+ m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(), m_SpecificInt(1))))));
+ };
+ // Returns the cost benefit of using \p Op with the dot product lowering. If
+ // the returned cost is < 0, the argument is cheaper to use in the
+ // dot-product lowering.
+ auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
+ if (!isa<Instruction>(Op))
+ return InstructionCost(0);
+
+ FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
+ Type *EltTy = VecTy->getElementType();
+
+ if (!CanBeFlattened(Op)) {
+ InstructionCost EmbedCost(0);
+ // Roughly estimate the cost for embedding the columns into a vector.
+ for (unsigned I = 1; I < N; ++I)
+ EmbedCost -=
+ TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+ std::nullopt, TTI::TCK_RecipThroughput);
+ return EmbedCost;
+ }
+
+ if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
+ InstructionCost OriginalCost =
+ TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
+ EltTy) *
+ N;
+ InstructionCost NewCost = TTI.getArithmeticInstrCost(
+ cast<Instruction>(Op)->getOpcode(), VecTy);
+ return NewCost - OriginalCost;
+ }
+
+ if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
+ // The transpose can be skipped for the dot product lowering, roughly
+ // estimate the savings as the cost of embedding the columns in a
+ // vector.
+ InstructionCost EmbedCost(0);
+ for (unsigned I = 1; I < N; ++I)
+ EmbedCost +=
+ TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+ std::nullopt, TTI::TCK_RecipThroughput);
+ return EmbedCost;
+ }
+
+ // Costs for loads.
+ if (N == 1)
+ return InstructionCost(0);
+
+ return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
+ N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
+ };
+ auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
+
+ // We compare the costs of a vector.reduce.add to sequential add.
+ int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
+ int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
+ InstructionCost ReductionCost =
+ TTI.getArithmeticReductionCost(
+ AddOpCode, cast<VectorType>(LHS->getType()),
+ IsIntVec ? std::nullopt : std::optional(FMF)) +
+ TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
+ InstructionCost SequentialAddCost =
+ TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
+ (LShape.NumColumns - 1) +
+ TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
+ (LShape.NumColumns);
+ if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
+ return;
+
+ FusedInsts.insert(MatMul);
+ IRBuilder<> Builder(MatMul);
+ auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
+ this](Value *Op) -> Value * {
+ // Matmul must be the only user of loads because we don't use LowerLoad
+ // for row vectors (LowerLoad results in scalar loads and shufflevectors
+ // instead of single vector load).
+ if (!CanBeFlattened(Op))
+ return Op;
+
+ if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
+ ShapeMap[Op] = ShapeMap[Op].t();
+ return Op;
+ }
+
+ FusedInsts.insert(cast<Instruction>(Op));
+ // If vector uses the builtin load, lower to a LoadInst
+ Value *Arg;
+ if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(Arg)))) {
+ auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
+ Op->replaceAllUsesWith(NewLoad);
+ cast<Instruction>(Op)->eraseFromParent();
+ return NewLoad;
+ } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(Arg)))) {
+ ToRemove.push_back(cast<Instruction>(Op));
+ return Arg;
+ }
+
+ return Op;
+ };
+ LHS = FlattenArg(LHS);
+
+ // Insert mul/fmul and llvm.vector.reduce.fadd
+ Value *Mul =
+ IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
+
+ Value *Result;
+ if (IsIntVec)
+ Result = Builder.CreateAddReduce(Mul);
+ else {
+ Result = Builder.CreateFAddReduce(
+ ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
+ 0.0),
+ Mul);
+ cast<Instruction>(Result)->setFastMathFlags(FMF);
+ }
+
+    // Pack the scalar back into a matrix and then replace the matmul inst.
+ Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
+ Result, uint64_t(0));
+ MatMul->replaceAllUsesWith(Result);
+ FusedInsts.insert(MatMul);
+ ToRemove.push_back(MatMul);
+ }
+
/// Compute \p Result += \p A * \p B for input matrices with left-associating
/// addition.
///
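
The lowerDotProduct addition above turns a 1xN row vector times Nx1 column vector llvm.matrix.multiply into an elementwise multiply followed by a vector reduction when the cost model favors it. As a standalone illustration (not part of the patch), the C++ sketch below shows why the floating-point case is gated on allowReassoc(): the reduction pairs partial sums tree-wise, which is only equivalent to the strictly left-to-right sequential form under reassociation.

    // Standalone sketch: sequential vs. reduction-style dot product.
    #include <cstddef>
    #include <vector>

    // Sequential lowering: (((a0*b0 + a1*b1) + a2*b2) + ...).
    float dotSequential(const std::vector<float> &A, const std::vector<float> &B) {
      float Acc = 0.0f;
      for (std::size_t I = 0; I < A.size(); ++I)
        Acc += A[I] * B[I];
      return Acc;
    }

    // Reduction-style lowering: multiply element-wise, then reduce pairwise.
    // For floats this reorders the additions, hence the reassoc requirement.
    float dotReduction(std::vector<float> A, const std::vector<float> &B) {
      for (std::size_t I = 0; I < A.size(); ++I)
        A[I] *= B[I];
      for (std::size_t Stride = 1; Stride < A.size(); Stride *= 2)
        for (std::size_t I = 0; I + Stride < A.size(); I += 2 * Stride)
          A[I] += A[I + Stride];
      return A.empty() ? 0.0f : A[0];
    }
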
@@ -1469,15 +1650,14 @@ public:
auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
AllocaInst *Alloca =
Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
- Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
- Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
+ Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
Load->getAlign(), LoadLoc.Size.getValue());
Builder.SetInsertPoint(Fusion, Fusion->begin());
PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
PHI->addIncoming(Load->getPointerOperand(), Check0);
PHI->addIncoming(Load->getPointerOperand(), Check1);
- PHI->addIncoming(BC, Copy);
+ PHI->addIncoming(Alloca, Copy);
// Adjust DT.
DTUpdates.push_back({DT->Insert, Check0, Check1});
@@ -2397,99 +2577,8 @@ void LowerMatrixIntrinsicsPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (Minimal)
OS << "minimal";
- OS << ">";
-}
-
-namespace {
-
-class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
- initializeLowerMatrixIntrinsicsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
- bool C = LMT.Visit();
- return C;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- }
-};
-} // namespace
-
-static const char pass_name[] = "Lower the matrix intrinsics";
-char LowerMatrixIntrinsicsLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
-
-Pass *llvm::createLowerMatrixIntrinsicsPass() {
- return new LowerMatrixIntrinsicsLegacyPass();
-}
-
-namespace {
-
-/// A lightweight version of the matrix lowering pass that only requires TTI.
-/// Advanced features that require DT, AA or ORE like tiling are disabled. This
-/// is used to lower matrix intrinsics if the main lowering pass is not run, for
-/// example with -O0.
-class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
- initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
- bool C = LMT.Visit();
- return C;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.setPreservesCFG();
- }
-};
-} // namespace
-
-static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
-char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
- "lower-matrix-intrinsics-minimal", pass_name_minimal,
- false, false)
-INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
- "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
- false)
-
-Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
- return new LowerMatrixIntrinsicsMinimalLegacyPass();
+ OS << '>';
}
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 64846484f936..68642a01b37c 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -46,13 +46,10 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>
@@ -72,6 +69,7 @@ STATISTIC(NumMemSetInfer, "Number of memsets inferred");
STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy");
STATISTIC(NumCpyToSet, "Number of memcpys converted to memset");
STATISTIC(NumCallSlot, "Number of call slot optimizations performed");
+STATISTIC(NumStackMove, "Number of stack-move optimizations performed");
namespace {
@@ -255,54 +253,6 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr,
// MemCpyOptLegacyPass Pass
//===----------------------------------------------------------------------===//
-namespace {
-
-class MemCpyOptLegacyPass : public FunctionPass {
- MemCpyOptPass Impl;
-
-public:
- static char ID; // Pass identification, replacement for typeid
-
- MemCpyOptLegacyPass() : FunctionPass(ID) {
- initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-
-private:
- // This transformation requires dominator postdominator info
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addPreserved<AAResultsWrapperPass>();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
-char MemCpyOptLegacyPass::ID = 0;
-
-/// The public interface to this file...
-FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); }
-
-INITIALIZE_PASS_BEGIN(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_END(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization",
- false, false)
-
// Check that V is either not accessible by the caller, or unwinding cannot
// occur between Start and End.
static bool mayBeVisibleThroughUnwinding(Value *V, Instruction *Start,
@@ -463,7 +413,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
// Check to see if this store is to a constant offset from the start ptr.
std::optional<int64_t> Offset =
- isPointerOffset(StartPtr, NextStore->getPointerOperand(), DL);
+ NextStore->getPointerOperand()->getPointerOffsetFrom(StartPtr, DL);
if (!Offset)
break;
@@ -477,7 +427,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
// Check to see if this store is to a constant offset from the start ptr.
std::optional<int64_t> Offset =
- isPointerOffset(StartPtr, MSI->getDest(), DL);
+ MSI->getDest()->getPointerOffsetFrom(StartPtr, DL);
if (!Offset)
break;
@@ -781,6 +731,23 @@ bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI,
return true;
}
+ // If this is a load-store pair from a stack slot to a stack slot, we
+ // might be able to perform the stack-move optimization just as we do for
+ // memcpys from an alloca to an alloca.
+ if (auto *DestAlloca = dyn_cast<AllocaInst>(SI->getPointerOperand())) {
+ if (auto *SrcAlloca = dyn_cast<AllocaInst>(LI->getPointerOperand())) {
+ if (performStackMoveOptzn(LI, SI, DestAlloca, SrcAlloca,
+ DL.getTypeStoreSize(T), BAA)) {
+ // Avoid invalidating the iterator.
+ BBI = SI->getNextNonDebugInstruction()->getIterator();
+ eraseInstruction(SI);
+ eraseInstruction(LI);
+ ++NumMemCpyInstr;
+ return true;
+ }
+ }
+ }
+
return false;
}
@@ -1200,8 +1167,14 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
// still want to eliminate the intermediate value, but we have to generate a
// memmove instead of memcpy.
bool UseMemMove = false;
- if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(MDep))))
+ if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(MDep)))) {
+ // Don't convert llvm.memcpy.inline into memmove because memmove can be
+ // lowered as a call, and that is not allowed for llvm.memcpy.inline (and
+ // there is no inline version of llvm.memmove)
+ if (isa<MemCpyInlineInst>(M))
+ return false;
UseMemMove = true;
+ }
// If all checks passed, then we can transform M.
LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy->memcpy src:\n"
@@ -1246,13 +1219,18 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
/// In other words, transform:
/// \code
/// memset(dst, c, dst_size);
+/// ...
/// memcpy(dst, src, src_size);
/// \endcode
/// into:
/// \code
-/// memcpy(dst, src, src_size);
+/// ...
/// memset(dst + src_size, c, dst_size <= src_size ? 0 : dst_size - src_size);
+/// memcpy(dst, src, src_size);
/// \endcode
+///
+/// The memset is sunk to just before the memcpy to ensure that src_size is
+/// present when emitting the simplified memset.
bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
MemSetInst *MemSet,
BatchAAResults &BAA) {
@@ -1300,6 +1278,15 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
IRBuilder<> Builder(MemCpy);
+ // Preserve the debug location of the old memset for the code emitted here
+ // related to the new memset. This is correct according to the rules in
+ // https://llvm.org/docs/HowToUpdateDebugInfo.html about "when to preserve an
+ // instruction location", given that we move the memset within the basic
+ // block.
+ assert(MemSet->getParent() == MemCpy->getParent() &&
+ "Preserving debug location based on moving memset within BB.");
+ Builder.SetCurrentDebugLocation(MemSet->getDebugLoc());
+
// If the sizes have different types, zext the smaller one.
if (DestSize->getType() != SrcSize->getType()) {
if (DestSize->getType()->getIntegerBitWidth() >
@@ -1323,9 +1310,8 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) &&
"MemCpy must be a MemoryDef");
- // The new memset is inserted after the memcpy, but it is known that its
- // defining access is the memset about to be removed which immediately
- // precedes the memcpy.
+ // The new memset is inserted before the memcpy, and it is known that the
+ // memcpy's defining access is the memset about to be removed.
auto *LastDef =
cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy));
auto *NewAccess = MSSAU->createMemoryAccessBefore(
@@ -1440,6 +1426,217 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
return true;
}
+// Attempts to optimize the pattern whereby memory is copied from an alloca to
+// another alloca, where the two allocas don't have conflicting mod/ref. If
+// successful, the two allocas can be merged into one and the transfer can be
+// deleted. This pattern is generated frequently in Rust, due to the ubiquity of
+// move operations in that language.
+//
+// Once we determine that the optimization is safe to perform, we replace all
+// uses of the destination alloca with the source alloca. We also "shrink wrap"
+// the lifetime markers of the single merged alloca to before the first use
+// and after the last use. Note that the "shrink wrapping" procedure is a safe
+// transformation only because we restrict the scope of this optimization to
+// allocas that aren't captured.
+bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
+ AllocaInst *DestAlloca,
+ AllocaInst *SrcAlloca, uint64_t Size,
+ BatchAAResults &BAA) {
+ LLVM_DEBUG(dbgs() << "Stack Move: Attempting to optimize:\n"
+ << *Store << "\n");
+
+ // Make sure the two allocas are in the same address space.
+ if (SrcAlloca->getAddressSpace() != DestAlloca->getAddressSpace()) {
+ LLVM_DEBUG(dbgs() << "Stack Move: Address space mismatch\n");
+ return false;
+ }
+
+ // 1. Check that copy is full. Calculate the static size of the allocas to be
+ // merged, bail out if we can't.
+ const DataLayout &DL = DestAlloca->getModule()->getDataLayout();
+ std::optional<TypeSize> SrcSize = SrcAlloca->getAllocationSize(DL);
+ if (!SrcSize || SrcSize->isScalable() || Size != SrcSize->getFixedValue()) {
+ LLVM_DEBUG(dbgs() << "Stack Move: Source alloca size mismatch\n");
+ return false;
+ }
+ std::optional<TypeSize> DestSize = DestAlloca->getAllocationSize(DL);
+ if (!DestSize || DestSize->isScalable() ||
+ Size != DestSize->getFixedValue()) {
+ LLVM_DEBUG(dbgs() << "Stack Move: Destination alloca size mismatch\n");
+ return false;
+ }
+
+ // 2-1. Check that src and dest are static allocas, which are not affected by
+ // stacksave/stackrestore.
+ if (!SrcAlloca->isStaticAlloca() || !DestAlloca->isStaticAlloca() ||
+ SrcAlloca->getParent() != Load->getParent() ||
+ SrcAlloca->getParent() != Store->getParent())
+ return false;
+
+ // 2-2. Check that src and dest are never captured, unescaped allocas. Also
+ // collect lifetime markers first/last users in order to shrink wrap the
+ // lifetimes, and instructions with noalias metadata to remove them.
+
+ SmallVector<Instruction *, 4> LifetimeMarkers;
+ Instruction *FirstUser = nullptr, *LastUser = nullptr;
+ SmallSet<Instruction *, 4> NoAliasInstrs;
+
+  // Recursively track the users and check whether any conflicting mod/ref
+  // exists.
+ auto IsDereferenceableOrNull = [](Value *V, const DataLayout &DL) -> bool {
+ bool CanBeNull, CanBeFreed;
+ return V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
+ };
+
+ auto CaptureTrackingWithModRef =
+ [&](Instruction *AI,
+ function_ref<bool(Instruction *)> ModRefCallback) -> bool {
+ SmallVector<Instruction *, 8> Worklist;
+ Worklist.push_back(AI);
+ unsigned MaxUsesToExplore = getDefaultMaxUsesToExploreForCaptureTracking();
+ Worklist.reserve(MaxUsesToExplore);
+ SmallSet<const Use *, 20> Visited;
+ while (!Worklist.empty()) {
+ Instruction *I = Worklist.back();
+ Worklist.pop_back();
+ for (const Use &U : I->uses()) {
+ if (Visited.size() >= MaxUsesToExplore) {
+ LLVM_DEBUG(
+ dbgs()
+ << "Stack Move: Exceeded max uses to see ModRef, bailing\n");
+ return false;
+ }
+ if (!Visited.insert(&U).second)
+ continue;
+ switch (DetermineUseCaptureKind(U, IsDereferenceableOrNull)) {
+ case UseCaptureKind::MAY_CAPTURE:
+ return false;
+ case UseCaptureKind::PASSTHROUGH:
+ // Instructions cannot have non-instruction users.
+ Worklist.push_back(cast<Instruction>(U.getUser()));
+ continue;
+ case UseCaptureKind::NO_CAPTURE: {
+ auto *UI = cast<Instruction>(U.getUser());
+ if (DestAlloca->getParent() != UI->getParent())
+ return false;
+ if (!FirstUser || UI->comesBefore(FirstUser))
+ FirstUser = UI;
+ if (!LastUser || LastUser->comesBefore(UI))
+ LastUser = UI;
+ if (UI->isLifetimeStartOrEnd()) {
+ // We note the locations of these intrinsic calls so that we can
+            // delete them later if the optimization succeeds; this is safe
+ // since both llvm.lifetime.start and llvm.lifetime.end intrinsics
+ // conceptually fill all the bytes of the alloca with an undefined
+ // value.
+ int64_t Size = cast<ConstantInt>(UI->getOperand(0))->getSExtValue();
+ if (Size < 0 || Size == DestSize) {
+ LifetimeMarkers.push_back(UI);
+ continue;
+ }
+ }
+ if (UI->hasMetadata(LLVMContext::MD_noalias))
+ NoAliasInstrs.insert(UI);
+ if (!ModRefCallback(UI))
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+ };
+
+ // 3. Check that dest has no Mod/Ref, except full size lifetime intrinsics,
+ // from the alloca to the Store.
+ ModRefInfo DestModRef = ModRefInfo::NoModRef;
+ MemoryLocation DestLoc(DestAlloca, LocationSize::precise(Size));
+ auto DestModRefCallback = [&](Instruction *UI) -> bool {
+ // We don't care about the store itself.
+ if (UI == Store)
+ return true;
+ ModRefInfo Res = BAA.getModRefInfo(UI, DestLoc);
+ // FIXME: For multi-BB cases, we need to see reachability from it to
+ // store.
+ // Bailout if Dest may have any ModRef before Store.
+ if (UI->comesBefore(Store) && isModOrRefSet(Res))
+ return false;
+ DestModRef |= BAA.getModRefInfo(UI, DestLoc);
+
+ return true;
+ };
+
+ if (!CaptureTrackingWithModRef(DestAlloca, DestModRefCallback))
+ return false;
+
+  // 4. Check that, from after the Load to the end of the BB,
+  // 4-1. if the dest has any Mod, src has no Ref, and
+  // 4-2. if the dest has any Ref, src has no Mod except full-sized lifetimes.
+ MemoryLocation SrcLoc(SrcAlloca, LocationSize::precise(Size));
+
+ auto SrcModRefCallback = [&](Instruction *UI) -> bool {
+ // Any ModRef before Load doesn't matter, also Load and Store can be
+ // ignored.
+ if (UI->comesBefore(Load) || UI == Load || UI == Store)
+ return true;
+ ModRefInfo Res = BAA.getModRefInfo(UI, SrcLoc);
+ if ((isModSet(DestModRef) && isRefSet(Res)) ||
+ (isRefSet(DestModRef) && isModSet(Res)))
+ return false;
+
+ return true;
+ };
+
+ if (!CaptureTrackingWithModRef(SrcAlloca, SrcModRefCallback))
+ return false;
+
+ // We can do the transformation. First, align the allocas appropriately.
+ SrcAlloca->setAlignment(
+ std::max(SrcAlloca->getAlign(), DestAlloca->getAlign()));
+
+ // Merge the two allocas.
+ DestAlloca->replaceAllUsesWith(SrcAlloca);
+ eraseInstruction(DestAlloca);
+
+ // Drop metadata on the source alloca.
+ SrcAlloca->dropUnknownNonDebugMetadata();
+
+  // "Shrink wrap" the lifetimes if the original lifetime intrinsics exist.
+ if (!LifetimeMarkers.empty()) {
+ LLVMContext &C = SrcAlloca->getContext();
+ IRBuilder<> Builder(C);
+
+ ConstantInt *AllocaSize = ConstantInt::get(Type::getInt64Ty(C), Size);
+    // Create a new lifetime.start marker before the first user of either
+    // alloca.
+ Builder.SetInsertPoint(FirstUser->getParent(), FirstUser->getIterator());
+ Builder.CreateLifetimeStart(SrcAlloca, AllocaSize);
+
+    // Create a new lifetime.end marker after the last user of either alloca.
+    // FIXME: If the last user is the terminator of the BB, we could insert the
+    // lifetime.end marker in the immediate post-dominator, but currently we do
+    // nothing.
+ if (!LastUser->isTerminator()) {
+ Builder.SetInsertPoint(LastUser->getParent(), ++LastUser->getIterator());
+ Builder.CreateLifetimeEnd(SrcAlloca, AllocaSize);
+ }
+
+ // Remove all other lifetime markers.
+ for (Instruction *I : LifetimeMarkers)
+ eraseInstruction(I);
+ }
+
+ // As this transformation can cause memory accesses that didn't previously
+ // alias to begin to alias one another, we remove !noalias metadata from any
+ // uses of either alloca. This is conservative, but more precision doesn't
+ // seem worthwhile right now.
+ for (Instruction *I : NoAliasInstrs)
+ I->setMetadata(LLVMContext::MD_noalias, nullptr);
+
+  LLVM_DEBUG(dbgs() << "Stack Move: Performed stack-move optimization\n");
+ NumStackMove++;
+ return true;
+}
+
/// Perform simplification of memcpy's. If we have memcpy A
/// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite
/// B to be a memcpy from X to Z (or potentially a memmove, depending on
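
performStackMoveOptzn above merges a pair of non-escaping, equally sized static allocas when one is only ever a full-size copy of the other and their uses don't conflict. As a rough standalone illustration (not part of the patch; the struct and function names are made up), this is the kind of source pattern that produces such an alloca-to-alloca copy:

    // Standalone sketch: two same-sized stack slots where one is just a full
    // copy of the other. When neither escapes and their uses don't conflict,
    // the pass can replace the destination alloca with the source alloca and
    // drop the copy entirely.
    #include <cstring>

    struct Big {
      char Data[256];
    };

    int moveLike(const Big &Init) {
      Big Src = Init;                        // first stack slot
      Big Dst;                               // second stack slot
      std::memcpy(&Dst, &Src, sizeof(Big));  // full-size copy between allocas
      return Dst.Data[0];                    // conceptually, Dst can become Src
    }
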
@@ -1484,8 +1681,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc, BAA);
// Try to turn a partially redundant memset + memcpy into
- // memcpy + smaller memset. We don't need the memcpy size for this.
- // The memcpy most post-dom the memset, so limit this to the same basic
+ // smaller memset + memcpy. We don't need the memcpy size for this.
+ // The memcpy must post-dom the memset, so limit this to the same basic
// block. A non-local generalization is likely not worthwhile.
if (auto *MD = dyn_cast<MemoryDef>(DestClobber))
if (auto *MDep = dyn_cast_or_null<MemSetInst>(MD->getMemoryInst()))
@@ -1496,13 +1693,14 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
MemoryAccess *SrcClobber = MSSA->getWalker()->getClobberingMemoryAccess(
AnyClobber, MemoryLocation::getForSource(M), BAA);
- // There are four possible optimizations we can do for memcpy:
+ // There are five possible optimizations we can do for memcpy:
// a) memcpy-memcpy xform which exposes redundance for DSE.
// b) call-memcpy xform for return slot optimization.
// c) memcpy from freshly alloca'd space or space that has just started
// its lifetime copies undefined data, and we can therefore eliminate
// the memcpy in favor of the data that was already at the destination.
// d) memcpy from a just-memset'd source can be turned into memset.
+ // e) elimination of memcpy via stack-move optimization.
if (auto *MD = dyn_cast<MemoryDef>(SrcClobber)) {
if (Instruction *MI = MD->getMemoryInst()) {
if (auto *CopySize = dyn_cast<ConstantInt>(M->getLength())) {
@@ -1521,7 +1719,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
}
}
if (auto *MDep = dyn_cast<MemCpyInst>(MI))
- return processMemCpyMemCpyDependence(M, MDep, BAA);
+ if (processMemCpyMemCpyDependence(M, MDep, BAA))
+ return true;
if (auto *MDep = dyn_cast<MemSetInst>(MI)) {
if (performMemCpyToMemSetOptzn(M, MDep, BAA)) {
LLVM_DEBUG(dbgs() << "Converted memcpy to memset\n");
@@ -1540,6 +1739,27 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
}
}
+ // If the transfer is from a stack slot to a stack slot, then we may be able
+ // to perform the stack-move optimization. See the comments in
+ // performStackMoveOptzn() for more details.
+ auto *DestAlloca = dyn_cast<AllocaInst>(M->getDest());
+ if (!DestAlloca)
+ return false;
+ auto *SrcAlloca = dyn_cast<AllocaInst>(M->getSource());
+ if (!SrcAlloca)
+ return false;
+ ConstantInt *Len = dyn_cast<ConstantInt>(M->getLength());
+ if (Len == nullptr)
+ return false;
+ if (performStackMoveOptzn(M, M, DestAlloca, SrcAlloca, Len->getZExtValue(),
+ BAA)) {
+ // Avoid invalidating the iterator.
+ BBI = M->getNextNonDebugInstruction()->getIterator();
+ eraseInstruction(M);
+ ++NumMemCpyInstr;
+ return true;
+ }
+
return false;
}
@@ -1623,24 +1843,110 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
// foo(*a)
// It would be invalid to transform the second memcpy into foo(*b).
if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep),
- MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB)))
+ MSSA->getMemoryAccess(MDep), CallAccess))
return false;
- Value *TmpCast = MDep->getSource();
- if (MDep->getSource()->getType() != ByValArg->getType()) {
- BitCastInst *TmpBitCast = new BitCastInst(MDep->getSource(), ByValArg->getType(),
- "tmpcast", &CB);
- // Set the tmpcast's DebugLoc to MDep's
- TmpBitCast->setDebugLoc(MDep->getDebugLoc());
- TmpCast = TmpBitCast;
- }
-
LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n"
<< " " << *MDep << "\n"
<< " " << CB << "\n");
// Otherwise we're good! Update the byval argument.
- CB.setArgOperand(ArgNo, TmpCast);
+ CB.setArgOperand(ArgNo, MDep->getSource());
+ ++NumMemCpyInstr;
+ return true;
+}
+
+/// This is called on memcpy dest pointer arguments attributed as immutable
+/// during the call. Try to use the memcpy source directly if all of the
+/// following conditions are satisfied.
+/// 1. The memcpy dst is neither modified during the call nor captured by the
+/// call (readonly, noalias and nocapture attributes on the call-site).
+/// 2. The memcpy dst is an alloca with known alignment & size.
+/// 2-1. The memcpy length == the alloca size, which ensures that the new
+/// pointer is dereferenceable for the required range.
+/// 2-2. The src pointer has alignment >= the alloca alignment, or it can be
+/// enforced to that alignment.
+/// 3. The memcpy dst and src are not modified between the memcpy and the call
+/// (checked via MemorySSA's clobber walk).
+/// 4. The memcpy src is not modified during the call (the ModRef check shows
+/// no Mod).
+bool MemCpyOptPass::processImmutArgument(CallBase &CB, unsigned ArgNo) {
+ // 1. Ensure passed argument is immutable during call.
+ if (!(CB.paramHasAttr(ArgNo, Attribute::NoAlias) &&
+ CB.paramHasAttr(ArgNo, Attribute::NoCapture)))
+ return false;
+ const DataLayout &DL = CB.getCaller()->getParent()->getDataLayout();
+ Value *ImmutArg = CB.getArgOperand(ArgNo);
+
+ // 2. Check that arg is alloca
+ // TODO: Even if the arg gets back to branches, we can remove memcpy if all
+ // the alloca alignments can be enforced to source alignment.
+ auto *AI = dyn_cast<AllocaInst>(ImmutArg->stripPointerCasts());
+ if (!AI)
+ return false;
+
+ std::optional<TypeSize> AllocaSize = AI->getAllocationSize(DL);
+ // Can't handle unknown size alloca.
+ // (e.g. Variable Length Array, Scalable Vector)
+ if (!AllocaSize || AllocaSize->isScalable())
+ return false;
+ MemoryLocation Loc(ImmutArg, LocationSize::precise(*AllocaSize));
+ MemoryUseOrDef *CallAccess = MSSA->getMemoryAccess(&CB);
+ if (!CallAccess)
+ return false;
+
+ MemCpyInst *MDep = nullptr;
+ BatchAAResults BAA(*AA);
+ MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
+ CallAccess->getDefiningAccess(), Loc, BAA);
+ if (auto *MD = dyn_cast<MemoryDef>(Clobber))
+ MDep = dyn_cast_or_null<MemCpyInst>(MD->getMemoryInst());
+
+ // If the immut argument isn't fed by a memcpy, ignore it. If it is fed by
+ // a memcpy, check that the arg equals the memcpy dest.
+ if (!MDep || MDep->isVolatile() || AI != MDep->getDest())
+ return false;
+
+ // The address space of the memcpy source must match the immut argument
+ if (MDep->getSource()->getType()->getPointerAddressSpace() !=
+ ImmutArg->getType()->getPointerAddressSpace())
+ return false;
+
+ // 2-1. The length of the memcpy must be equal to the size of the alloca.
+ auto *MDepLen = dyn_cast<ConstantInt>(MDep->getLength());
+ if (!MDepLen || AllocaSize != MDepLen->getValue())
+ return false;
+
+ // 2-2. the memcpy source align must be larger than or equal the alloca's
+ // align. If not so, we check to see if we can force the source of the memcpy
+ // to the alignment we need. If we fail, we bail out.
+ Align MemDepAlign = MDep->getSourceAlign().valueOrOne();
+ Align AllocaAlign = AI->getAlign();
+ if (MemDepAlign < AllocaAlign &&
+ getOrEnforceKnownAlignment(MDep->getSource(), AllocaAlign, DL, &CB, AC,
+ DT) < AllocaAlign)
+ return false;
+
+ // 3. Verify that the source doesn't change in between the memcpy and
+ // the call.
+ // memcpy(a <- b)
+ // *b = 42;
+ // foo(*a)
+ // It would be invalid to transform the second memcpy into foo(*b).
+ if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep),
+ MSSA->getMemoryAccess(MDep), CallAccess))
+ return false;
+
+ // 4. The memcpy src must not be modified during the call.
+ if (isModSet(AA->getModRefInfo(&CB, MemoryLocation::getForSource(MDep))))
+ return false;
+
+ LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to Immut src:\n"
+ << " " << *MDep << "\n"
+ << " " << CB << "\n");
+
+ // Otherwise we're good! Update the immut argument.
+ CB.setArgOperand(ArgNo, MDep->getSource());
++NumMemCpyInstr;
return true;
}
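
processImmutArgument above forwards the memcpy source when a call argument is effectively immutable for the duration of the call. A hedged, standalone C++ sketch of the source-level shape (names are illustrative; whether the rewrite actually fires depends on the readonly/noalias/nocapture attributes and the checks listed in the comment):

    // Standalone sketch: a caller copies a source object into a stack
    // temporary and passes the temporary to a callee that only reads it. At
    // the IR level the temporary is an alloca, the memcpy covers the whole
    // alloca, and nothing writes Src or Tmp between the copy and the call, so
    // the call can take &Src directly and the copy becomes dead.
    #include <cstring>

    struct Config {
      int Fields[32];
    };

    // Callee only reads through the pointer and does not stash it anywhere.
    int readOnlyUse(const Config *C) { return C->Fields[0]; }

    int caller(const Config &Src) {
      Config Tmp;                                // stack temporary (alloca)
      std::memcpy(&Tmp, &Src, sizeof(Config));   // defensive copy
      return readOnlyUse(&Tmp);                  // may be rewritten to use &Src
    }
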
@@ -1673,9 +1979,12 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) {
else if (auto *M = dyn_cast<MemMoveInst>(I))
RepeatInstruction = processMemMove(M);
else if (auto *CB = dyn_cast<CallBase>(I)) {
- for (unsigned i = 0, e = CB->arg_size(); i != e; ++i)
+ for (unsigned i = 0, e = CB->arg_size(); i != e; ++i) {
if (CB->isByValArgument(i))
MadeChange |= processByValArgument(*CB, i);
+ else if (CB->onlyReadsMemory(i))
+ MadeChange |= processImmutArgument(*CB, i);
+ }
}
// Reprocess the instruction if desired.
@@ -1730,17 +2039,3 @@ bool MemCpyOptPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
return MadeChange;
}
-
-/// This is the main transformation entry point for a function.
-bool MemCpyOptLegacyPass::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
-
- auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
-
- return Impl.runImpl(F, TLI, AA, AC, DT, MSSA);
-}
diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
index bcedb05890af..311a6435ba7c 100644
--- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
@@ -42,6 +42,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/MergeICmps.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
@@ -49,6 +50,7 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
@@ -157,7 +159,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
return {};
}
- APInt Offset = APInt(DL.getPointerTypeSizeInBits(Addr->getType()), 0);
+ APInt Offset = APInt(DL.getIndexTypeSizeInBits(Addr->getType()), 0);
Value *Base = Addr;
auto *GEP = dyn_cast<GetElementPtrInst>(Addr);
if (GEP) {
@@ -639,10 +641,11 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
if (Comparisons.size() == 1) {
LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n");
- Value *const LhsLoad =
- Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs);
- Value *const RhsLoad =
- Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
+ // Use clone to keep the metadata
+ Instruction *const LhsLoad = Builder.Insert(FirstCmp.Lhs().LoadI->clone());
+ Instruction *const RhsLoad = Builder.Insert(FirstCmp.Rhs().LoadI->clone());
+ LhsLoad->replaceUsesOfWith(LhsLoad->getOperand(0), Lhs);
+ RhsLoad->replaceUsesOfWith(RhsLoad->getOperand(0), Rhs);
// There are no blocks to merge, just do the comparison.
IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
} else {
diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
index 62e75d98448c..6c5453831ade 100644
--- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
+++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
@@ -78,6 +78,7 @@
#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
@@ -191,11 +192,16 @@ StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1,
MemoryLocation Loc0 = MemoryLocation::get(Store0);
MemoryLocation Loc1 = MemoryLocation::get(Store1);
- if (AA->isMustAlias(Loc0, Loc1) && Store0->isSameOperationAs(Store1) &&
+
+ if (AA->isMustAlias(Loc0, Loc1) &&
!isStoreSinkBarrierInRange(*Store1->getNextNode(), BB1->back(), Loc1) &&
- !isStoreSinkBarrierInRange(*Store0->getNextNode(), BB0->back(), Loc0)) {
+ !isStoreSinkBarrierInRange(*Store0->getNextNode(), BB0->back(), Loc0) &&
+ Store0->hasSameSpecialState(Store1) &&
+ CastInst::isBitOrNoopPointerCastable(
+ Store0->getValueOperand()->getType(),
+ Store1->getValueOperand()->getType(),
+ Store0->getModule()->getDataLayout()))
return Store1;
- }
}
return nullptr;
}
@@ -254,6 +260,13 @@ void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0,
S0->applyMergedLocation(S0->getDebugLoc(), S1->getDebugLoc());
S0->mergeDIAssignID(S1);
+  // Insert a bitcast for stores of conflicting value types (or just reuse the
+  // original value if the types already match).
+ IRBuilder<> Builder(S0);
+ auto Cast = Builder.CreateBitOrPointerCast(S0->getValueOperand(),
+ S1->getValueOperand()->getType());
+ S0->setOperand(0, Cast);
+
// Create the new store to be inserted at the join point.
StoreInst *SNew = cast<StoreInst>(S0->clone());
SNew->insertBefore(&*InsertPt);
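
The two MergedLoadStoreMotion hunks above allow stores that differ only in the type of the stored value to be sunk, by unifying them through CreateBitOrPointerCast. As a standalone illustration (not from the patch; it assumes C++20 for std::bit_cast), "before" is the kind of diamond that now qualifies, and "after" is a hand-written equivalent of the sunk form:

    // Standalone sketch: both sides fully overwrite the same 4-byte slot, one
    // with a float's bits and one with a uint32_t; a single store at the join
    // suffices once the stored values are cast to a common type.
    #include <bit>
    #include <cstdint>
    #include <cstring>

    void before(bool Flag, float F, std::uint32_t U, std::uint32_t *Slot) {
      if (Flag)
        std::memcpy(Slot, &F, sizeof F);  // store of a float's bit pattern
      else
        *Slot = U;                        // store of a 32-bit integer
    }

    // Hand-written equivalent of what sinking produces: one store at the join.
    void after(bool Flag, float F, std::uint32_t U, std::uint32_t *Slot) {
      std::uint32_t V = Flag ? std::bit_cast<std::uint32_t>(F) : U;
      *Slot = V;
    }
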
@@ -428,7 +441,7 @@ void MergedLoadStoreMotionPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<MergedLoadStoreMotionPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
OS << (Options.SplitFooterBB ? "" : "no-") << "split-footer-bb";
- OS << ">";
+ OS << '>';
}
diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
index 19bee4fa3879..9c3e9a2fd018 100644
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -351,9 +351,9 @@ Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) {
bool NaryReassociatePass::requiresSignExtension(Value *Index,
GetElementPtrInst *GEP) {
- unsigned PointerSizeInBits =
- DL->getPointerSizeInBits(GEP->getType()->getPointerAddressSpace());
- return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits;
+ unsigned IndexSizeInBits =
+ DL->getIndexSizeInBits(GEP->getType()->getPointerAddressSpace());
+ return cast<IntegerType>(Index->getType())->getBitWidth() < IndexSizeInBits;
}
GetElementPtrInst *
@@ -449,12 +449,12 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
return nullptr;
// NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0])));
- Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
- if (RHS->getType() != IntPtrTy)
- RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy);
+ Type *PtrIdxTy = DL->getIndexType(GEP->getType());
+ if (RHS->getType() != PtrIdxTy)
+ RHS = Builder.CreateSExtOrTrunc(RHS, PtrIdxTy);
if (IndexedSize != ElementSize) {
RHS = Builder.CreateMul(
- RHS, ConstantInt::get(IntPtrTy, IndexedSize / ElementSize));
+ RHS, ConstantInt::get(PtrIdxTy, IndexedSize / ElementSize));
}
GetElementPtrInst *NewGEP = cast<GetElementPtrInst>(
Builder.CreateGEP(GEP->getResultElementType(), Candidate, RHS));
diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index d3dba0c5f1d5..1af40e2c4e62 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -93,8 +93,6 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ArrayRecycler.h"
#include "llvm/Support/Casting.h"
@@ -104,7 +102,6 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVNExpression.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -1277,10 +1274,17 @@ const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) const {
const CallExpression *
NewGVN::createCallExpression(CallInst *CI, const MemoryAccess *MA) const {
// FIXME: Add operand bundles for calls.
- // FIXME: Allow commutative matching for intrinsics.
auto *E =
new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, MA);
setBasicExpressionInfo(CI, E);
+ if (CI->isCommutative()) {
+ // Ensure that commutative intrinsics that only differ by a permutation
+ // of their operands get the same value number by sorting the operand value
+ // numbers.
+ assert(CI->getNumOperands() >= 2 && "Unsupported commutative intrinsic!");
+ if (shouldSwapOperands(E->getOperand(0), E->getOperand(1)))
+ E->swapOperands(0, 1);
+ }
return E;
}
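
The createCallExpression change above sorts the first two operands of commutative calls so that permuted forms value-number identically. A standalone sketch (not from the patch) of the same canonicalize-before-hash idea:

    // Standalone sketch: canonicalize the operand pair of a commutative
    // operation before hashing, so op(a, b) and op(b, a) land in the same
    // congruence class. NewGVN does the analogous swap via shouldSwapOperands
    // on the call expression's first two operands.
    #include <cstddef>
    #include <functional>
    #include <utility>

    struct CommutativeKey {
      unsigned Opcode;
      unsigned LHS, RHS; // value numbers of the two operands
    };

    CommutativeKey canonicalize(CommutativeKey K) {
      if (K.LHS > K.RHS) // fixed tie-break: smaller value number first
        std::swap(K.LHS, K.RHS);
      return K;
    }

    std::size_t hashKey(const CommutativeKey &K) {
      CommutativeKey C = canonicalize(K);
      std::size_t H = std::hash<unsigned>{}(C.Opcode);
      H ^= std::hash<unsigned>{}(C.LHS) + 0x9e3779b9 + (H << 6) + (H >> 2);
      H ^= std::hash<unsigned>{}(C.RHS) + 0x9e3779b9 + (H << 6) + (H >> 2);
      return H; // hashKey({Op, a, b}) == hashKey({Op, b, a})
    }
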
@@ -1453,8 +1457,7 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr,
if (Offset >= 0) {
if (auto *C = dyn_cast<Constant>(
lookupOperandLeader(DepSI->getValueOperand()))) {
- if (Constant *Res =
- getConstantStoreValueForLoad(C, Offset, LoadType, DL)) {
+ if (Constant *Res = getConstantValueForLoad(C, Offset, LoadType, DL)) {
LLVM_DEBUG(dbgs() << "Coercing load from store " << *DepSI
<< " to constant " << *Res << "\n");
return createConstantExpression(Res);
@@ -1470,7 +1473,7 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr,
// We can coerce a constant load into a load.
if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI)))
if (auto *PossibleConstant =
- getConstantLoadValueForLoad(C, Offset, LoadType, DL)) {
+ getConstantValueForLoad(C, Offset, LoadType, DL)) {
LLVM_DEBUG(dbgs() << "Coercing load from load " << *LI
<< " to constant " << *PossibleConstant << "\n");
return createConstantExpression(PossibleConstant);
@@ -1617,6 +1620,12 @@ NewGVN::ExprResult NewGVN::performSymbolicCallEvaluation(Instruction *I) const {
if (CI->getFunction()->isPresplitCoroutine())
return ExprResult::none();
+ // Do not combine convergent calls since they implicitly depend on the set of
+ // threads that is currently executing, and they might be in different basic
+ // blocks.
+ if (CI->isConvergent())
+ return ExprResult::none();
+
if (AA->doesNotAccessMemory(CI)) {
return ExprResult::some(
createCallExpression(CI, TOPClass->getMemoryLeader()));
@@ -1992,6 +2001,7 @@ NewGVN::performSymbolicEvaluation(Value *V,
break;
case Instruction::BitCast:
case Instruction::AddrSpaceCast:
+ case Instruction::Freeze:
return createExpression(I);
break;
case Instruction::ICmp:
@@ -2739,10 +2749,10 @@ NewGVN::makePossiblePHIOfOps(Instruction *I,
return nullptr;
}
// No point in doing this for one-operand phis.
- if (OpPHI->getNumOperands() == 1) {
- OpPHI = nullptr;
- continue;
- }
+    // Since all PHIs for operands must be in the same block, they must all
+    // have the same number of operands, so we can just abort.
+ if (OpPHI->getNumOperands() == 1)
+ return nullptr;
}
if (!OpPHI)
@@ -3712,9 +3722,10 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) {
}
// Now insert something that simplifycfg will turn into an unreachable.
Type *Int8Ty = Type::getInt8Ty(BB->getContext());
- new StoreInst(PoisonValue::get(Int8Ty),
- Constant::getNullValue(Int8Ty->getPointerTo()),
- BB->getTerminator());
+ new StoreInst(
+ PoisonValue::get(Int8Ty),
+ Constant::getNullValue(PointerType::getUnqual(BB->getContext())),
+ BB->getTerminator());
}
void NewGVN::markInstructionForDeletion(Instruction *I) {
@@ -4208,61 +4219,6 @@ bool NewGVN::shouldSwapOperandsForIntrinsic(const Value *A, const Value *B,
return false;
}
-namespace {
-
-class NewGVNLegacyPass : public FunctionPass {
-public:
- // Pass identification, replacement for typeid.
- static char ID;
-
- NewGVNLegacyPass() : FunctionPass(ID) {
- initializeNewGVNLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-
-private:
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
-bool NewGVNLegacyPass::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
- return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
- &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
- &getAnalysis<AAResultsWrapperPass>().getAAResults(),
- &getAnalysis<MemorySSAWrapperPass>().getMSSA(),
- F.getParent()->getDataLayout())
- .runGVN();
-}
-
-char NewGVNLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(NewGVNLegacyPass, "newgvn", "Global Value Numbering",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_END(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false,
- false)
-
-// createGVNPass - The public interface to this file.
-FunctionPass *llvm::createNewGVNPass() { return new NewGVNLegacyPass(); }
-
PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) {
// Apparently the order in which we get these results matter for
// the old GVN (see Chandler's comment in GVN.cpp). I'll keep
diff --git a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
index e1cc3fc71c3e..0266eb1a9f50 100644
--- a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
+++ b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
@@ -47,6 +47,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Transforms/Scalar/PlaceSafepoints.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
@@ -67,7 +68,9 @@
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
-#define DEBUG_TYPE "safepoint-placement"
+using namespace llvm;
+
+#define DEBUG_TYPE "place-safepoints"
STATISTIC(NumEntrySafepoints, "Number of entry safepoints inserted");
STATISTIC(NumBackedgeSafepoints, "Number of backedge safepoints inserted");
@@ -77,8 +80,6 @@ STATISTIC(CallInLoop,
STATISTIC(FiniteExecution,
"Number of loops without safepoints finite execution");
-using namespace llvm;
-
// Ignore opportunities to avoid placing safepoints on backedges, useful for
// validation
static cl::opt<bool> AllBackedges("spp-all-backedges", cl::Hidden,
@@ -97,10 +98,10 @@ static cl::opt<bool> SplitBackedge("spp-split-backedge", cl::Hidden,
cl::init(false));
namespace {
-
/// An analysis pass whose purpose is to identify each of the backedges in
/// the function which require a safepoint poll to be inserted.
-struct PlaceBackedgeSafepointsImpl : public FunctionPass {
+class PlaceBackedgeSafepointsLegacyPass : public FunctionPass {
+public:
static char ID;
/// The output of the pass - gives a list of each backedge (described by
@@ -111,17 +112,14 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass {
/// the call-dependent placement opts.
bool CallSafepointsEnabled;
- ScalarEvolution *SE = nullptr;
- DominatorTree *DT = nullptr;
- LoopInfo *LI = nullptr;
- TargetLibraryInfo *TLI = nullptr;
-
- PlaceBackedgeSafepointsImpl(bool CallSafepoints = false)
+ PlaceBackedgeSafepointsLegacyPass(bool CallSafepoints = false)
: FunctionPass(ID), CallSafepointsEnabled(CallSafepoints) {
- initializePlaceBackedgeSafepointsImplPass(*PassRegistry::getPassRegistry());
+ initializePlaceBackedgeSafepointsLegacyPassPass(
+ *PassRegistry::getPassRegistry());
}
bool runOnLoop(Loop *);
+
void runOnLoopAndSubLoops(Loop *L) {
// Visit all the subloops
for (Loop *I : *L)
@@ -149,39 +147,245 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass {
// analysis are preserved.
AU.setPreservesAll();
}
+
+private:
+ ScalarEvolution *SE = nullptr;
+ DominatorTree *DT = nullptr;
+ LoopInfo *LI = nullptr;
+ TargetLibraryInfo *TLI = nullptr;
};
-}
+} // namespace
static cl::opt<bool> NoEntry("spp-no-entry", cl::Hidden, cl::init(false));
static cl::opt<bool> NoCall("spp-no-call", cl::Hidden, cl::init(false));
static cl::opt<bool> NoBackedge("spp-no-backedge", cl::Hidden, cl::init(false));
-namespace {
-struct PlaceSafepoints : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
+char PlaceBackedgeSafepointsLegacyPass::ID = 0;
- PlaceSafepoints() : FunctionPass(ID) {
- initializePlaceSafepointsPass(*PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override;
+INITIALIZE_PASS_BEGIN(PlaceBackedgeSafepointsLegacyPass,
+ "place-backedge-safepoints-impl",
+ "Place Backedge Safepoints", false, false)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(PlaceBackedgeSafepointsLegacyPass,
+ "place-backedge-safepoints-impl",
+ "Place Backedge Safepoints", false, false)
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- // We modify the graph wholesale (inlining, block insertion, etc). We
- // preserve nothing at the moment. We could potentially preserve dom tree
- // if that was worth doing
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- }
-};
-}
+static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header,
+ BasicBlock *Pred,
+ DominatorTree &DT,
+ const TargetLibraryInfo &TLI);
+
+static bool mustBeFiniteCountedLoop(Loop *L, ScalarEvolution *SE,
+ BasicBlock *Pred);
+
+static Instruction *findLocationForEntrySafepoint(Function &F,
+ DominatorTree &DT);
+
+static bool isGCSafepointPoll(Function &F);
+static bool shouldRewriteFunction(Function &F);
+static bool enableEntrySafepoints(Function &F);
+static bool enableBackedgeSafepoints(Function &F);
+static bool enableCallSafepoints(Function &F);
-// Insert a safepoint poll immediately before the given instruction. Does
-// not handle the parsability of state at the runtime call, that's the
-// callers job.
static void
InsertSafepointPoll(Instruction *InsertBefore,
std::vector<CallBase *> &ParsePointsNeeded /*rval*/,
const TargetLibraryInfo &TLI);
+bool PlaceBackedgeSafepointsLegacyPass::runOnLoop(Loop *L) {
+ // Loop through all loop latches (branches controlling backedges). We need
+ // to place a safepoint on every backedge (potentially).
+ // Note: In common usage, there will be only one edge due to LoopSimplify
+ // having run sometime earlier in the pipeline, but this code must be correct
+ // w.r.t. loops with multiple backedges.
+ BasicBlock *Header = L->getHeader();
+ SmallVector<BasicBlock *, 16> LoopLatches;
+ L->getLoopLatches(LoopLatches);
+ for (BasicBlock *Pred : LoopLatches) {
+ assert(L->contains(Pred));
+
+ // Make a policy decision about whether this loop needs a safepoint or
+ // not. Note that this is about unburdening the optimizer in loops, not
+ // avoiding the runtime cost of the actual safepoint.
+ if (!AllBackedges) {
+ if (mustBeFiniteCountedLoop(L, SE, Pred)) {
+ LLVM_DEBUG(dbgs() << "skipping safepoint placement in finite loop\n");
+ FiniteExecution++;
+ continue;
+ }
+ if (CallSafepointsEnabled &&
+ containsUnconditionalCallSafepoint(L, Header, Pred, *DT, *TLI)) {
+        // Note: This is only semantically legal since we won't do any further
+        // IPO or inlining before the actual call insertion. If we did, we
+        // might later lose this call safepoint.
+ LLVM_DEBUG(
+ dbgs()
+ << "skipping safepoint placement due to unconditional call\n");
+ CallInLoop++;
+ continue;
+ }
+ }
+
+ // TODO: We can create an inner loop which runs a finite number of
+ // iterations with an outer loop which contains a safepoint. This would
+ // not help runtime performance that much, but it might help our ability to
+ // optimize the inner loop.
+
+ // Safepoint insertion would involve creating a new basic block (as the
+ // target of the current backedge) which does the safepoint (of all live
+ // variables) and branches to the true header
+ Instruction *Term = Pred->getTerminator();
+
+ LLVM_DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term);
+
+ PollLocations.push_back(Term);
+ }
+
+ return false;
+}
+
+bool PlaceSafepointsPass::runImpl(Function &F, const TargetLibraryInfo &TLI) {
+ if (F.isDeclaration() || F.empty()) {
+ // This is a declaration, nothing to do. Must exit early to avoid crash in
+ // dom tree calculation
+ return false;
+ }
+
+ if (isGCSafepointPoll(F)) {
+    // This is the gc.safepoint_poll function itself, which gets inlined at
+    // each poll site, so placing polls inside it doesn't make any sense. Note
+    // that we do make any contained calls parseable after we inline a poll.
+ return false;
+ }
+
+ if (!shouldRewriteFunction(F))
+ return false;
+
+ bool Modified = false;
+
+ // In various bits below, we rely on the fact that uses are reachable from
+ // defs. When there are basic blocks unreachable from the entry, dominance
+  // and reachability queries return nonsensical results. Thus, we preprocess
+ // the function to ensure these properties hold.
+ Modified |= removeUnreachableBlocks(F);
+
+ // STEP 1 - Insert the safepoint polling locations. We do not need to
+ // actually insert parse points yet. That will be done for all polls and
+ // calls in a single pass.
+
+ DominatorTree DT;
+ DT.recalculate(F);
+
+ SmallVector<Instruction *, 16> PollsNeeded;
+ std::vector<CallBase *> ParsePointNeeded;
+
+ if (enableBackedgeSafepoints(F)) {
+ // Construct a pass manager to run the LoopPass backedge logic. We
+ // need the pass manager to handle scheduling all the loop passes
+ // appropriately. Doing this by hand is painful and just not worth messing
+ // with for the moment.
+ legacy::FunctionPassManager FPM(F.getParent());
+ bool CanAssumeCallSafepoints = enableCallSafepoints(F);
+ auto *PBS = new PlaceBackedgeSafepointsLegacyPass(CanAssumeCallSafepoints);
+ FPM.add(PBS);
+ FPM.run(F);
+
+ // We preserve dominance information when inserting the poll, otherwise
+ // we'd have to recalculate this on every insert
+ DT.recalculate(F);
+
+ auto &PollLocations = PBS->PollLocations;
+
+ auto OrderByBBName = [](Instruction *a, Instruction *b) {
+ return a->getParent()->getName() < b->getParent()->getName();
+ };
+    // We need the order of the list to be stable so that naming ends up stable
+ // when we split edges. This makes test cases much easier to write.
+ llvm::sort(PollLocations, OrderByBBName);
+
+ // We can sometimes end up with duplicate poll locations. This happens if
+    // a single loop is visited more than once. The fact that this happens
+    // seems wrong, but it does happen for the split-backedge.ll test case.
+ PollLocations.erase(std::unique(PollLocations.begin(), PollLocations.end()),
+ PollLocations.end());
+
+ // Insert a poll at each point the analysis pass identified
+ // The poll location must be the terminator of a loop latch block.
+ for (Instruction *Term : PollLocations) {
+ // We are inserting a poll, the function is modified
+ Modified = true;
+
+ if (SplitBackedge) {
+ // Split the backedge of the loop and insert the poll within that new
+ // basic block. This creates a loop with two latches per original
+ // latch (which is non-ideal), but this appears to be easier to
+ // optimize in practice than inserting the poll immediately before the
+ // latch test.
+
+ // Since this is a latch, at least one of the successors must dominate
+        // it. It's possible that we have a) duplicate edges to the same header
+        // and b) edges to distinct loop headers. We need to insert polls on
+        // each.
+ SetVector<BasicBlock *> Headers;
+ for (unsigned i = 0; i < Term->getNumSuccessors(); i++) {
+ BasicBlock *Succ = Term->getSuccessor(i);
+ if (DT.dominates(Succ, Term->getParent())) {
+ Headers.insert(Succ);
+ }
+ }
+ assert(!Headers.empty() && "poll location is not a loop latch?");
+
+ // The split loop structure here is so that we only need to recalculate
+ // the dominator tree once. Alternatively, we could just keep it up to
+ // date and use a more natural merged loop.
+ SetVector<BasicBlock *> SplitBackedges;
+ for (BasicBlock *Header : Headers) {
+ BasicBlock *NewBB = SplitEdge(Term->getParent(), Header, &DT);
+ PollsNeeded.push_back(NewBB->getTerminator());
+ NumBackedgeSafepoints++;
+ }
+ } else {
+ // Split the latch block itself, right before the terminator.
+ PollsNeeded.push_back(Term);
+ NumBackedgeSafepoints++;
+ }
+ }
+ }
+
+ if (enableEntrySafepoints(F)) {
+ if (Instruction *Location = findLocationForEntrySafepoint(F, DT)) {
+ PollsNeeded.push_back(Location);
+ Modified = true;
+ NumEntrySafepoints++;
+ }
+ // TODO: else we should assert that there was, in fact, a policy choice to
+    // not insert an entry safepoint poll.
+ }
+
+ // Now that we've identified all the needed safepoint poll locations, insert
+ // safepoint polls themselves.
+ for (Instruction *PollLocation : PollsNeeded) {
+ std::vector<CallBase *> RuntimeCalls;
+ InsertSafepointPoll(PollLocation, RuntimeCalls, TLI);
+ llvm::append_range(ParsePointNeeded, RuntimeCalls);
+ }
+
+ return Modified;
+}
+
+PreservedAnalyses PlaceSafepointsPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+
+ if (!runImpl(F, TLI))
+ return PreservedAnalyses::all();
+
+ // TODO: can we preserve more?
+ return PreservedAnalyses::none();
+}
+
static bool needsStatepoint(CallBase *Call, const TargetLibraryInfo &TLI) {
if (callsGCLeafFunction(Call, TLI))
return false;
@@ -306,58 +510,6 @@ static void scanInlinedCode(Instruction *Start, Instruction *End,
}
}
-bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) {
- // Loop through all loop latches (branches controlling backedges). We need
- // to place a safepoint on every backedge (potentially).
- // Note: In common usage, there will be only one edge due to LoopSimplify
- // having run sometime earlier in the pipeline, but this code must be correct
- // w.r.t. loops with multiple backedges.
- BasicBlock *Header = L->getHeader();
- SmallVector<BasicBlock*, 16> LoopLatches;
- L->getLoopLatches(LoopLatches);
- for (BasicBlock *Pred : LoopLatches) {
- assert(L->contains(Pred));
-
- // Make a policy decision about whether this loop needs a safepoint or
- // not. Note that this is about unburdening the optimizer in loops, not
- // avoiding the runtime cost of the actual safepoint.
- if (!AllBackedges) {
- if (mustBeFiniteCountedLoop(L, SE, Pred)) {
- LLVM_DEBUG(dbgs() << "skipping safepoint placement in finite loop\n");
- FiniteExecution++;
- continue;
- }
- if (CallSafepointsEnabled &&
- containsUnconditionalCallSafepoint(L, Header, Pred, *DT, *TLI)) {
- // Note: This is only semantically legal since we won't do any further
- // IPO or inlining before the actual call insertion.. If we hadn't, we
- // might latter loose this call safepoint.
- LLVM_DEBUG(
- dbgs()
- << "skipping safepoint placement due to unconditional call\n");
- CallInLoop++;
- continue;
- }
- }
-
- // TODO: We can create an inner loop which runs a finite number of
- // iterations with an outer loop which contains a safepoint. This would
- // not help runtime performance that much, but it might help our ability to
- // optimize the inner loop.
-
- // Safepoint insertion would involve creating a new basic block (as the
- // target of the current backedge) which does the safepoint (of all live
- // variables) and branches to the true header
- Instruction *Term = Pred->getTerminator();
-
- LLVM_DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term);
-
- PollLocations.push_back(Term);
- }
-
- return false;
-}
-
/// Returns true if an entry safepoint is not required before this callsite in
/// the caller function.
static bool doesNotRequireEntrySafepointBefore(CallBase *Call) {
@@ -463,161 +615,9 @@ static bool enableEntrySafepoints(Function &F) { return !NoEntry; }
static bool enableBackedgeSafepoints(Function &F) { return !NoBackedge; }
static bool enableCallSafepoints(Function &F) { return !NoCall; }
-bool PlaceSafepoints::runOnFunction(Function &F) {
- if (F.isDeclaration() || F.empty()) {
- // This is a declaration, nothing to do. Must exit early to avoid crash in
- // dom tree calculation
- return false;
- }
-
- if (isGCSafepointPoll(F)) {
- // Given we're inlining this inside of safepoint poll insertion, this
- // doesn't make any sense. Note that we do make any contained calls
- // parseable after we inline a poll.
- return false;
- }
-
- if (!shouldRewriteFunction(F))
- return false;
-
- const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
-
- bool Modified = false;
-
- // In various bits below, we rely on the fact that uses are reachable from
- // defs. When there are basic blocks unreachable from the entry, dominance
- // and reachablity queries return non-sensical results. Thus, we preprocess
- // the function to ensure these properties hold.
- Modified |= removeUnreachableBlocks(F);
-
- // STEP 1 - Insert the safepoint polling locations. We do not need to
- // actually insert parse points yet. That will be done for all polls and
- // calls in a single pass.
-
- DominatorTree DT;
- DT.recalculate(F);
-
- SmallVector<Instruction *, 16> PollsNeeded;
- std::vector<CallBase *> ParsePointNeeded;
-
- if (enableBackedgeSafepoints(F)) {
- // Construct a pass manager to run the LoopPass backedge logic. We
- // need the pass manager to handle scheduling all the loop passes
- // appropriately. Doing this by hand is painful and just not worth messing
- // with for the moment.
- legacy::FunctionPassManager FPM(F.getParent());
- bool CanAssumeCallSafepoints = enableCallSafepoints(F);
- auto *PBS = new PlaceBackedgeSafepointsImpl(CanAssumeCallSafepoints);
- FPM.add(PBS);
- FPM.run(F);
-
- // We preserve dominance information when inserting the poll, otherwise
- // we'd have to recalculate this on every insert
- DT.recalculate(F);
-
- auto &PollLocations = PBS->PollLocations;
-
- auto OrderByBBName = [](Instruction *a, Instruction *b) {
- return a->getParent()->getName() < b->getParent()->getName();
- };
- // We need the order of list to be stable so that naming ends up stable
- // when we split edges. This makes test cases much easier to write.
- llvm::sort(PollLocations, OrderByBBName);
-
- // We can sometimes end up with duplicate poll locations. This happens if
- // a single loop is visited more than once. The fact this happens seems
- // wrong, but it does happen for the split-backedge.ll test case.
- PollLocations.erase(std::unique(PollLocations.begin(),
- PollLocations.end()),
- PollLocations.end());
-
- // Insert a poll at each point the analysis pass identified
- // The poll location must be the terminator of a loop latch block.
- for (Instruction *Term : PollLocations) {
- // We are inserting a poll, the function is modified
- Modified = true;
-
- if (SplitBackedge) {
- // Split the backedge of the loop and insert the poll within that new
- // basic block. This creates a loop with two latches per original
- // latch (which is non-ideal), but this appears to be easier to
- // optimize in practice than inserting the poll immediately before the
- // latch test.
-
- // Since this is a latch, at least one of the successors must dominate
- // it. Its possible that we have a) duplicate edges to the same header
- // and b) edges to distinct loop headers. We need to insert pools on
- // each.
- SetVector<BasicBlock *> Headers;
- for (unsigned i = 0; i < Term->getNumSuccessors(); i++) {
- BasicBlock *Succ = Term->getSuccessor(i);
- if (DT.dominates(Succ, Term->getParent())) {
- Headers.insert(Succ);
- }
- }
- assert(!Headers.empty() && "poll location is not a loop latch?");
-
- // The split loop structure here is so that we only need to recalculate
- // the dominator tree once. Alternatively, we could just keep it up to
- // date and use a more natural merged loop.
- SetVector<BasicBlock *> SplitBackedges;
- for (BasicBlock *Header : Headers) {
- BasicBlock *NewBB = SplitEdge(Term->getParent(), Header, &DT);
- PollsNeeded.push_back(NewBB->getTerminator());
- NumBackedgeSafepoints++;
- }
- } else {
- // Split the latch block itself, right before the terminator.
- PollsNeeded.push_back(Term);
- NumBackedgeSafepoints++;
- }
- }
- }
-
- if (enableEntrySafepoints(F)) {
- if (Instruction *Location = findLocationForEntrySafepoint(F, DT)) {
- PollsNeeded.push_back(Location);
- Modified = true;
- NumEntrySafepoints++;
- }
- // TODO: else we should assert that there was, in fact, a policy choice to
- // not insert a entry safepoint poll.
- }
-
- // Now that we've identified all the needed safepoint poll locations, insert
- // safepoint polls themselves.
- for (Instruction *PollLocation : PollsNeeded) {
- std::vector<CallBase *> RuntimeCalls;
- InsertSafepointPoll(PollLocation, RuntimeCalls, TLI);
- llvm::append_range(ParsePointNeeded, RuntimeCalls);
- }
-
- return Modified;
-}
-
-char PlaceBackedgeSafepointsImpl::ID = 0;
-char PlaceSafepoints::ID = 0;
-
-FunctionPass *llvm::createPlaceSafepointsPass() {
- return new PlaceSafepoints();
-}
-
-INITIALIZE_PASS_BEGIN(PlaceBackedgeSafepointsImpl,
- "place-backedge-safepoints-impl",
- "Place Backedge Safepoints", false, false)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(PlaceBackedgeSafepointsImpl,
- "place-backedge-safepoints-impl",
- "Place Backedge Safepoints", false, false)
-
-INITIALIZE_PASS_BEGIN(PlaceSafepoints, "place-safepoints", "Place Safepoints",
- false, false)
-INITIALIZE_PASS_END(PlaceSafepoints, "place-safepoints", "Place Safepoints",
- false, false)
-
+// Insert a safepoint poll immediately before the given instruction. Does
+// not handle the parsability of state at the runtime call; that's the
+// caller's job.
static void
InsertSafepointPoll(Instruction *InsertBefore,
std::vector<CallBase *> &ParsePointsNeeded /*rval*/,
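
The new PlaceSafepointsPass wrapper above only asks the analysis manager for TargetLibraryAnalysis, so it can also be driven directly rather than via a parsed pipeline string. The placeSafepoints helper below is a hypothetical sketch, not part of the patch.

#include "llvm/IR/Function.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/Scalar/PlaceSafepoints.h"

using namespace llvm;

// Hypothetical helper: run the new-PM PlaceSafepointsPass on one function.
static void placeSafepoints(Function &F) {
  FunctionAnalysisManager FAM;
  PassBuilder PB;
  PB.registerFunctionAnalyses(FAM); // provides TargetLibraryAnalysis

  FunctionPassManager FPM;
  FPM.addPass(PlaceSafepointsPass());
  FPM.run(F, FAM);
}

Note that, as runImpl shows, the pass still builds a legacy::FunctionPassManager internally to schedule PlaceBackedgeSafepointsLegacyPass over the loop nest.
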
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 21628b61edd6..40c84e249523 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -52,6 +52,7 @@
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
@@ -70,6 +71,12 @@ STATISTIC(NumChanged, "Number of insts reassociated");
STATISTIC(NumAnnihil, "Number of expr tree annihilated");
STATISTIC(NumFactor , "Number of multiplies factored");
+static cl::opt<bool>
+ UseCSELocalOpt(DEBUG_TYPE "-use-cse-local",
+ cl::desc("Only reorder expressions within a basic block "
+ "when exposing CSE opportunities"),
+ cl::init(true), cl::Hidden);
+
#ifndef NDEBUG
/// Print out the expression identified in the Ops list.
static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) {
@@ -620,8 +627,7 @@ static bool LinearizeExprTree(Instruction *I,
// The leaves, repeated according to their weights, represent the linearized
// form of the expression.
- for (unsigned i = 0, e = LeafOrder.size(); i != e; ++i) {
- Value *V = LeafOrder[i];
+ for (Value *V : LeafOrder) {
LeafMap::iterator It = Leaves.find(V);
if (It == Leaves.end())
// Node initially thought to be a leaf wasn't.
@@ -683,10 +689,12 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
NotRewritable.insert(Ops[i].Op);
- // ExpressionChanged - Non-null if the rewritten expression differs from the
- // original in some non-trivial way, requiring the clearing of optional flags.
- // Flags are cleared from the operator in ExpressionChanged up to I inclusive.
- BinaryOperator *ExpressionChanged = nullptr;
+ // ExpressionChangedStart - Non-null if the rewritten expression differs from
+ // the original in some non-trivial way, requiring the clearing of optional
+ // flags. Flags are cleared from the operator in ExpressionChangedStart up to
+ // ExpressionChangedEnd inclusive.
+ BinaryOperator *ExpressionChangedStart = nullptr,
+ *ExpressionChangedEnd = nullptr;
for (unsigned i = 0; ; ++i) {
// The last operation (which comes earliest in the IR) is special as both
// operands will come from Ops, rather than just one with the other being
@@ -728,7 +736,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
}
LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n');
- ExpressionChanged = Op;
+ ExpressionChangedStart = Op;
+ if (!ExpressionChangedEnd)
+ ExpressionChangedEnd = Op;
MadeChange = true;
++NumChanged;
@@ -750,7 +760,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
if (BO && !NotRewritable.count(BO))
NodesToRewrite.push_back(BO);
Op->setOperand(1, NewRHS);
- ExpressionChanged = Op;
+ ExpressionChangedStart = Op;
+ if (!ExpressionChangedEnd)
+ ExpressionChangedEnd = Op;
}
LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n');
MadeChange = true;
@@ -787,7 +799,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n');
Op->setOperand(0, NewOp);
LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n');
- ExpressionChanged = Op;
+ ExpressionChangedStart = Op;
+ if (!ExpressionChangedEnd)
+ ExpressionChangedEnd = Op;
MadeChange = true;
++NumChanged;
Op = NewOp;
@@ -797,27 +811,36 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
// starting from the operator specified in ExpressionChanged, and compactify
// the operators to just before the expression root to guarantee that the
// expression tree is dominated by all of Ops.
- if (ExpressionChanged)
+ if (ExpressionChangedStart) {
+ bool ClearFlags = true;
do {
// Preserve FastMathFlags.
- if (isa<FPMathOperator>(I)) {
- FastMathFlags Flags = I->getFastMathFlags();
- ExpressionChanged->clearSubclassOptionalData();
- ExpressionChanged->setFastMathFlags(Flags);
- } else
- ExpressionChanged->clearSubclassOptionalData();
-
- if (ExpressionChanged == I)
+ if (ClearFlags) {
+ if (isa<FPMathOperator>(I)) {
+ FastMathFlags Flags = I->getFastMathFlags();
+ ExpressionChangedStart->clearSubclassOptionalData();
+ ExpressionChangedStart->setFastMathFlags(Flags);
+ } else
+ ExpressionChangedStart->clearSubclassOptionalData();
+ }
+
+ if (ExpressionChangedStart == ExpressionChangedEnd)
+ ClearFlags = false;
+ if (ExpressionChangedStart == I)
break;
// Discard any debug info related to the expressions that has changed (we
- // can leave debug infor related to the root, since the result of the
- // expression tree should be the same even after reassociation).
- replaceDbgUsesWithUndef(ExpressionChanged);
-
- ExpressionChanged->moveBefore(I);
- ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin());
+ // can leave debug info related to the root and any operation that didn't
+ // change, since the result of the expression tree should be the same
+ // even after reassociation).
+ if (ClearFlags)
+ replaceDbgUsesWithUndef(ExpressionChangedStart);
+
+ ExpressionChangedStart->moveBefore(I);
+ ExpressionChangedStart =
+ cast<BinaryOperator>(*ExpressionChangedStart->user_begin());
} while (true);
+ }
// Throw away any left over nodes from the original expression.
for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i)
@@ -1507,8 +1530,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I,
// Step 4: Reassemble the Ops
if (Changed) {
Ops.clear();
- for (unsigned int i = 0, e = Opnds.size(); i < e; i++) {
- XorOpnd &O = Opnds[i];
+ for (const XorOpnd &O : Opnds) {
if (O.isInvalid())
continue;
ValueEntry VE(getRank(O.getValue()), O.getValue());
@@ -1644,8 +1666,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// Add one to FactorOccurrences for each unique factor in this op.
SmallPtrSet<Value*, 8> Duplicates;
- for (unsigned i = 0, e = Factors.size(); i != e; ++i) {
- Value *Factor = Factors[i];
+ for (Value *Factor : Factors) {
if (!Duplicates.insert(Factor).second)
continue;
@@ -2048,7 +2069,7 @@ void ReassociatePass::EraseInst(Instruction *I) {
// blocks because it's a waste of time and also because it can
// lead to infinite loop due to LLVM's non-standard definition
// of dominance.
- if (ValueRankMap.find(Op) != ValueRankMap.end())
+ if (ValueRankMap.contains(Op))
RedoInsts.insert(Op);
}
@@ -2410,8 +2431,67 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
unsigned BestRank = 0;
std::pair<unsigned, unsigned> BestPair;
unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin;
- for (unsigned i = 0; i < Ops.size() - 1; ++i)
- for (unsigned j = i + 1; j < Ops.size(); ++j) {
+ unsigned LimitIdx = 0;
+  // With the CSE-driven heuristic, we are about to hoist two values to the
+  // beginning of the expression even though they may live very late in the
+  // CFG. When using the CSE-local heuristic we avoid creating dependences on
+  // completely unrelated parts of the CFG by limiting the expression
+  // reordering to the values that live in the first seen basic block.
+  // The main idea is that we want to avoid forming expressions that would
+  // become loop-dependent.
+ if (UseCSELocalOpt) {
+ const BasicBlock *FirstSeenBB = nullptr;
+ int StartIdx = Ops.size() - 1;
+ // Skip the first value of the expression since we need at least two
+ // values to materialize an expression. I.e., even if this value is
+ // anchored in a different basic block, the actual first sub expression
+ // will be anchored on the second value.
+ for (int i = StartIdx - 1; i != -1; --i) {
+ const Value *Val = Ops[i].Op;
+ const auto *CurrLeafInstr = dyn_cast<Instruction>(Val);
+ const BasicBlock *SeenBB = nullptr;
+ if (!CurrLeafInstr) {
+ // The value is free of any CFG dependencies.
+ // Do as if it lives in the entry block.
+ //
+ // We do this to make sure all the values falling on this path are
+ // seen through the same anchor point. The rationale is these values
+        // can be combined together to form a sub-expression free of any CFG
+        // dependencies, so we want them to stay together.
+ // We could be cleverer and postpone the anchor down to the first
+ // anchored value, but that's likely complicated to get right.
+ // E.g., we wouldn't want to do that if that means being stuck in a
+ // loop.
+ //
+ // For instance, we wouldn't want to change:
+ // res = arg1 op arg2 op arg3 op ... op loop_val1 op loop_val2 ...
+ // into
+ // res = loop_val1 op arg1 op arg2 op arg3 op ... op loop_val2 ...
+        // Because all the sub-expressions with arg2..N would be stuck between
+        // two loop-dependent values.
+ SeenBB = &I->getParent()->getParent()->getEntryBlock();
+ } else {
+ SeenBB = CurrLeafInstr->getParent();
+ }
+
+ if (!FirstSeenBB) {
+ FirstSeenBB = SeenBB;
+ continue;
+ }
+ if (FirstSeenBB != SeenBB) {
+ // ith value is in a different basic block.
+ // Rewind the index once to point to the last value on the same basic
+ // block.
+ LimitIdx = i + 1;
+ LLVM_DEBUG(dbgs() << "CSE reordering: Consider values between ["
+ << LimitIdx << ", " << StartIdx << "]\n");
+ break;
+ }
+ }
+ }
+ for (unsigned i = Ops.size() - 1; i > LimitIdx; --i) {
+    // j must be a signed type so it can step below zero and terminate the
+    // loop when LimitIdx is 0.
+ for (int j = i - 1; j >= (int)LimitIdx; --j) {
unsigned Score = 0;
Value *Op0 = Ops[i].Op;
Value *Op1 = Ops[j].Op;
@@ -2429,12 +2509,26 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
}
unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank);
+
+ // By construction, the operands are sorted in reverse order of their
+ // topological order.
+ // So we tend to form (sub) expressions with values that are close to
+ // each other.
+ //
+      // Now, to expose more CSE opportunities, we want to emit the pair of
+      // operands that occurs most often (as statically computed in
+      // BuildPairMap) as the first sub-expression.
+      //
+      // If two pairs occur equally often, we pick the one with the lowest
+      // rank, meaning the one with both operands appearing first in the
+      // topological order.
if (Score > Max || (Score == Max && MaxRank < BestRank)) {
- BestPair = {i, j};
+ BestPair = {j, i};
Max = Score;
BestRank = MaxRank;
}
}
+ }
if (Max > 1) {
auto Op0 = Ops[BestPair.first];
auto Op1 = Ops[BestPair.second];
@@ -2444,6 +2538,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
Ops.push_back(Op1);
}
}
+ LLVM_DEBUG(dbgs() << "RAOut after CSE reorder:\t"; PrintOps(I, Ops);
+ dbgs() << '\n');
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
RewriteExprTree(I, Ops);
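
For readers tracing the new loop bounds in ReassociateExpression, the following standalone sketch mirrors the restricted best-pair search: only indices in [LimitIdx, Ops.size() - 1] are considered, the inner index descends below the outer one, and the winner is stored as {j, i} so the lower index comes first, as in the BestPair change above. All names are hypothetical and the rank-based tie-breaking is omitted for brevity.

#include <utility>
#include <vector>

// Hypothetical mirror of the pair-selection loops added to
// ReassociateExpression: Score[i][j] plays the role of the PairMap lookup.
static std::pair<unsigned, unsigned>
pickBestPair(const std::vector<std::vector<unsigned>> &Score,
             unsigned LimitIdx) {
  unsigned Max = 1;
  std::pair<unsigned, unsigned> BestPair{0, 0};
  // Assumes at least two candidate operands, as in the real pass.
  for (unsigned i = Score.size() - 1; i > LimitIdx; --i) {
    // j is signed so it can step below zero and end the loop when LimitIdx
    // is 0, mirroring the comment in the patch.
    for (int j = i - 1; j >= (int)LimitIdx; --j) {
      if (Score[i][j] > Max) {
        Max = Score[i][j];
        BestPair = {(unsigned)j, i}; // lower index first
      }
    }
  }
  return BestPair;
}
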
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index bcb012b79c2e..908bda5709a0 100644
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -27,6 +27,7 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Argument.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CallingConv.h"
@@ -36,6 +37,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/GCStrategy.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
@@ -125,6 +127,9 @@ static cl::opt<bool> RematDerivedAtUses("rs4gc-remat-derived-at-uses",
/// constant physical memory: llvm.invariant.start.
static void stripNonValidData(Module &M);
+// Find the GC strategy for a function, or null if it doesn't have one.
+static std::unique_ptr<GCStrategy> findGCStrategy(Function &F);
+
static bool shouldRewriteStatepointsIn(Function &F);
PreservedAnalyses RewriteStatepointsForGC::run(Module &M,
@@ -162,76 +167,6 @@ PreservedAnalyses RewriteStatepointsForGC::run(Module &M,
namespace {
-class RewriteStatepointsForGCLegacyPass : public ModulePass {
- RewriteStatepointsForGC Impl;
-
-public:
- static char ID; // Pass identification, replacement for typeid
-
- RewriteStatepointsForGCLegacyPass() : ModulePass(ID), Impl() {
- initializeRewriteStatepointsForGCLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnModule(Module &M) override {
- bool Changed = false;
- for (Function &F : M) {
- // Nothing to do for declarations.
- if (F.isDeclaration() || F.empty())
- continue;
-
- // Policy choice says not to rewrite - the most common reason is that
- // we're compiling code without a GCStrategy.
- if (!shouldRewriteStatepointsIn(F))
- continue;
-
- TargetTransformInfo &TTI =
- getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- auto &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
-
- Changed |= Impl.runOnFunction(F, DT, TTI, TLI);
- }
-
- if (!Changed)
- return false;
-
- // stripNonValidData asserts that shouldRewriteStatepointsIn
- // returns true for at least one function in the module. Since at least
- // one function changed, we know that the precondition is satisfied.
- stripNonValidData(M);
- return true;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- // We add and rewrite a bunch of instructions, but don't really do much
- // else. We could in theory preserve a lot more analyses here.
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
-char RewriteStatepointsForGCLegacyPass::ID = 0;
-
-ModulePass *llvm::createRewriteStatepointsForGCLegacyPass() {
- return new RewriteStatepointsForGCLegacyPass();
-}
-
-INITIALIZE_PASS_BEGIN(RewriteStatepointsForGCLegacyPass,
- "rewrite-statepoints-for-gc",
- "Make relocations explicit at statepoints", false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(RewriteStatepointsForGCLegacyPass,
- "rewrite-statepoints-for-gc",
- "Make relocations explicit at statepoints", false, false)
-
-namespace {
-
struct GCPtrLivenessData {
/// Values defined in this block.
MapVector<BasicBlock *, SetVector<Value *>> KillSet;
@@ -311,37 +246,35 @@ static ArrayRef<Use> GetDeoptBundleOperands(const CallBase *Call) {
/// Compute the live-in set for every basic block in the function
static void computeLiveInValues(DominatorTree &DT, Function &F,
- GCPtrLivenessData &Data);
+ GCPtrLivenessData &Data, GCStrategy *GC);
/// Given results from the dataflow liveness computation, find the set of live
/// Values at a particular instruction.
static void findLiveSetAtInst(Instruction *inst, GCPtrLivenessData &Data,
- StatepointLiveSetTy &out);
+ StatepointLiveSetTy &out, GCStrategy *GC);
-// TODO: Once we can get to the GCStrategy, this becomes
-// std::optional<bool> isGCManagedPointer(const Type *Ty) const override {
+static bool isGCPointerType(Type *T, GCStrategy *GC) {
+ assert(GC && "GC Strategy for isGCPointerType cannot be null");
-static bool isGCPointerType(Type *T) {
- if (auto *PT = dyn_cast<PointerType>(T))
- // For the sake of this example GC, we arbitrarily pick addrspace(1) as our
- // GC managed heap. We know that a pointer into this heap needs to be
- // updated and that no other pointer does.
- return PT->getAddressSpace() == 1;
- return false;
+ if (!isa<PointerType>(T))
+ return false;
+
+ // conservative - same as StatepointLowering
+ return GC->isGCManagedPointer(T).value_or(true);
}
// Return true if this type is one which a) is a gc pointer or contains a GC
// pointer and b) is of a type this code expects to encounter as a live value.
// (The insertion code will assert that a type which matches (a) and not (b)
// is not encountered.)
-static bool isHandledGCPointerType(Type *T) {
+static bool isHandledGCPointerType(Type *T, GCStrategy *GC) {
// We fully support gc pointers
- if (isGCPointerType(T))
+ if (isGCPointerType(T, GC))
return true;
// We partially support vectors of gc pointers. The code will assert if it
// can't handle something.
if (auto VT = dyn_cast<VectorType>(T))
- if (isGCPointerType(VT->getElementType()))
+ if (isGCPointerType(VT->getElementType(), GC))
return true;
return false;
}
@@ -349,23 +282,24 @@ static bool isHandledGCPointerType(Type *T) {
#ifndef NDEBUG
/// Returns true if this type contains a gc pointer whether we know how to
/// handle that type or not.
-static bool containsGCPtrType(Type *Ty) {
- if (isGCPointerType(Ty))
+static bool containsGCPtrType(Type *Ty, GCStrategy *GC) {
+ if (isGCPointerType(Ty, GC))
return true;
if (VectorType *VT = dyn_cast<VectorType>(Ty))
- return isGCPointerType(VT->getScalarType());
+ return isGCPointerType(VT->getScalarType(), GC);
if (ArrayType *AT = dyn_cast<ArrayType>(Ty))
- return containsGCPtrType(AT->getElementType());
+ return containsGCPtrType(AT->getElementType(), GC);
if (StructType *ST = dyn_cast<StructType>(Ty))
- return llvm::any_of(ST->elements(), containsGCPtrType);
+ return llvm::any_of(ST->elements(),
+ [GC](Type *Ty) { return containsGCPtrType(Ty, GC); });
return false;
}
// Returns true if this is a type which a) is a gc pointer or contains a GC
// pointer and b) is of a type which the code doesn't expect (i.e. first class
// aggregates). Used to trip assertions.
-static bool isUnhandledGCPointerType(Type *Ty) {
- return containsGCPtrType(Ty) && !isHandledGCPointerType(Ty);
+static bool isUnhandledGCPointerType(Type *Ty, GCStrategy *GC) {
+ return containsGCPtrType(Ty, GC) && !isHandledGCPointerType(Ty, GC);
}
#endif
@@ -382,9 +316,9 @@ static std::string suffixed_name_or(Value *V, StringRef Suffix,
// live. Values used by that instruction are considered live.
static void analyzeParsePointLiveness(
DominatorTree &DT, GCPtrLivenessData &OriginalLivenessData, CallBase *Call,
- PartiallyConstructedSafepointRecord &Result) {
+ PartiallyConstructedSafepointRecord &Result, GCStrategy *GC) {
StatepointLiveSetTy LiveSet;
- findLiveSetAtInst(Call, OriginalLivenessData, LiveSet);
+ findLiveSetAtInst(Call, OriginalLivenessData, LiveSet, GC);
if (PrintLiveSet) {
dbgs() << "Live Variables:\n";
@@ -692,7 +626,7 @@ static Value *findBaseDefiningValue(Value *I, DefiningValueMapTy &Cache,
/// Returns the base defining value for this value.
static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache,
IsKnownBaseMapTy &KnownBases) {
- if (Cache.find(I) == Cache.end()) {
+ if (!Cache.contains(I)) {
auto *BDV = findBaseDefiningValue(I, Cache, KnownBases);
Cache[I] = BDV;
LLVM_DEBUG(dbgs() << "fBDV-cached: " << I->getName() << " -> "
@@ -700,7 +634,7 @@ static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache,
<< KnownBases[I] << "\n");
}
assert(Cache[I] != nullptr);
- assert(KnownBases.find(Cache[I]) != KnownBases.end() &&
+ assert(KnownBases.contains(Cache[I]) &&
"Cached value must be present in known bases map");
return Cache[I];
}
@@ -1289,9 +1223,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache,
if (!BdvSV->isZeroEltSplat())
UpdateOperand(1); // vector operand
else {
- // Never read, so just use undef
+ // Never read, so just use poison
Value *InVal = BdvSV->getOperand(1);
- BaseSV->setOperand(1, UndefValue::get(InVal->getType()));
+ BaseSV->setOperand(1, PoisonValue::get(InVal->getType()));
}
}
}
@@ -1385,20 +1319,21 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache,
static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData,
CallBase *Call,
PartiallyConstructedSafepointRecord &result,
- PointerToBaseTy &PointerToBase);
+ PointerToBaseTy &PointerToBase,
+ GCStrategy *GC);
static void recomputeLiveInValues(
Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate,
MutableArrayRef<struct PartiallyConstructedSafepointRecord> records,
- PointerToBaseTy &PointerToBase) {
+ PointerToBaseTy &PointerToBase, GCStrategy *GC) {
// TODO-PERF: reuse the original liveness, then simply run the dataflow
// again. The old values are still live and will help it stabilize quickly.
GCPtrLivenessData RevisedLivenessData;
- computeLiveInValues(DT, F, RevisedLivenessData);
+ computeLiveInValues(DT, F, RevisedLivenessData, GC);
for (size_t i = 0; i < records.size(); i++) {
struct PartiallyConstructedSafepointRecord &info = records[i];
- recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info,
- PointerToBase);
+ recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info, PointerToBase,
+ GC);
}
}
@@ -1522,7 +1457,7 @@ static AttributeList legalizeCallAttributes(LLVMContext &Ctx,
static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
ArrayRef<Value *> BasePtrs,
Instruction *StatepointToken,
- IRBuilder<> &Builder) {
+ IRBuilder<> &Builder, GCStrategy *GC) {
if (LiveVariables.empty())
return;
@@ -1542,8 +1477,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
// towards a single unified pointer type anyways, we can just cast everything
// to an i8* of the right address space. A bitcast is added later to convert
// gc_relocate to the actual value's type.
- auto getGCRelocateDecl = [&] (Type *Ty) {
- assert(isHandledGCPointerType(Ty));
+ auto getGCRelocateDecl = [&](Type *Ty) {
+ assert(isHandledGCPointerType(Ty, GC));
auto AS = Ty->getScalarType()->getPointerAddressSpace();
Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS);
if (auto *VT = dyn_cast<VectorType>(Ty))
@@ -1668,7 +1603,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
const SmallVectorImpl<Value *> &LiveVariables,
PartiallyConstructedSafepointRecord &Result,
std::vector<DeferredReplacement> &Replacements,
- const PointerToBaseTy &PointerToBase) {
+ const PointerToBaseTy &PointerToBase,
+ GCStrategy *GC) {
assert(BasePtrs.size() == LiveVariables.size());
// Then go ahead and use the builder do actually do the inserts. We insert
@@ -1901,7 +1837,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst();
Result.UnwindToken = ExceptionalToken;
- CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder);
+ CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder, GC);
// Generate gc relocates and returns for normal block
BasicBlock *NormalDest = II->getNormalDest();
@@ -1947,7 +1883,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
Result.StatepointToken = Token;
// Second, create a gc.relocate for every live variable
- CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder);
+ CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder, GC);
}
// Replace an existing gc.statepoint with a new one and a set of gc.relocates
@@ -1959,7 +1895,7 @@ static void
makeStatepointExplicit(DominatorTree &DT, CallBase *Call,
PartiallyConstructedSafepointRecord &Result,
std::vector<DeferredReplacement> &Replacements,
- const PointerToBaseTy &PointerToBase) {
+ const PointerToBaseTy &PointerToBase, GCStrategy *GC) {
const auto &LiveSet = Result.LiveSet;
// Convert to vector for efficient cross referencing.
@@ -1976,7 +1912,7 @@ makeStatepointExplicit(DominatorTree &DT, CallBase *Call,
// Do the actual rewriting and delete the old statepoint
makeStatepointExplicitImpl(Call, BaseVec, LiveVec, Result, Replacements,
- PointerToBase);
+ PointerToBase, GC);
}
// Helper function for the relocationViaAlloca.
@@ -2277,12 +2213,13 @@ static void insertUseHolderAfter(CallBase *Call, const ArrayRef<Value *> Values,
static void findLiveReferences(
Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate,
- MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) {
+ MutableArrayRef<struct PartiallyConstructedSafepointRecord> records,
+ GCStrategy *GC) {
GCPtrLivenessData OriginalLivenessData;
- computeLiveInValues(DT, F, OriginalLivenessData);
+ computeLiveInValues(DT, F, OriginalLivenessData, GC);
for (size_t i = 0; i < records.size(); i++) {
struct PartiallyConstructedSafepointRecord &info = records[i];
- analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info);
+ analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info, GC);
}
}
@@ -2684,6 +2621,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
SmallVectorImpl<CallBase *> &ToUpdate,
DefiningValueMapTy &DVCache,
IsKnownBaseMapTy &KnownBases) {
+ std::unique_ptr<GCStrategy> GC = findGCStrategy(F);
+
#ifndef NDEBUG
// Validate the input
std::set<CallBase *> Uniqued;
@@ -2718,9 +2657,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
SmallVector<Value *, 64> DeoptValues;
for (Value *Arg : GetDeoptBundleOperands(Call)) {
- assert(!isUnhandledGCPointerType(Arg->getType()) &&
+ assert(!isUnhandledGCPointerType(Arg->getType(), GC.get()) &&
"support for FCA unimplemented");
- if (isHandledGCPointerType(Arg->getType()))
+ if (isHandledGCPointerType(Arg->getType(), GC.get()))
DeoptValues.push_back(Arg);
}
@@ -2731,7 +2670,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// A) Identify all gc pointers which are statically live at the given call
// site.
- findLiveReferences(F, DT, ToUpdate, Records);
+ findLiveReferences(F, DT, ToUpdate, Records, GC.get());
/// Global mapping from live pointers to a base-defining-value.
PointerToBaseTy PointerToBase;
@@ -2782,7 +2721,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// By selecting base pointers, we've effectively inserted new uses. Thus, we
// need to rerun liveness. We may *also* have inserted new defs, but that's
// not the key issue.
- recomputeLiveInValues(F, DT, ToUpdate, Records, PointerToBase);
+ recomputeLiveInValues(F, DT, ToUpdate, Records, PointerToBase, GC.get());
if (PrintBasePointers) {
errs() << "Base Pairs: (w/Relocation)\n";
@@ -2842,7 +2781,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// the old statepoint calls as we go.)
for (size_t i = 0; i < Records.size(); i++)
makeStatepointExplicit(DT, ToUpdate[i], Records[i], Replacements,
- PointerToBase);
+ PointerToBase, GC.get());
ToUpdate.clear(); // prevent accident use of invalid calls.
@@ -2866,9 +2805,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// Do all the fixups of the original live variables to their relocated selves
SmallVector<Value *, 128> Live;
- for (size_t i = 0; i < Records.size(); i++) {
- PartiallyConstructedSafepointRecord &Info = Records[i];
-
+ for (const PartiallyConstructedSafepointRecord &Info : Records) {
// We can't simply save the live set from the original insertion. One of
// the live values might be the result of a call which needs a safepoint.
// That Value* no longer exists and we need to use the new gc_result.
@@ -2899,7 +2836,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
#ifndef NDEBUG
// Validation check
for (auto *Ptr : Live)
- assert(isHandledGCPointerType(Ptr->getType()) &&
+ assert(isHandledGCPointerType(Ptr->getType(), GC.get()) &&
"must be a gc pointer type");
#endif
@@ -3019,25 +2956,33 @@ static void stripNonValidDataFromBody(Function &F) {
}
}
- // Delete the invariant.start instructions and RAUW undef.
+ // Delete the invariant.start instructions and RAUW poison.
for (auto *II : InvariantStartInstructions) {
- II->replaceAllUsesWith(UndefValue::get(II->getType()));
+ II->replaceAllUsesWith(PoisonValue::get(II->getType()));
II->eraseFromParent();
}
}
+/// Looks up the GC strategy for a given function, returning null if the
+/// function doesn't have a GC tag. The returned strategy is owned by the
+/// caller.
+static std::unique_ptr<GCStrategy> findGCStrategy(Function &F) {
+ if (!F.hasGC())
+ return nullptr;
+
+ return getGCStrategy(F.getGC());
+}
+
/// Returns true if this function should be rewritten by this pass. The main
/// point of this function is as an extension point for custom logic.
static bool shouldRewriteStatepointsIn(Function &F) {
- // TODO: This should check the GCStrategy
- if (F.hasGC()) {
- const auto &FunctionGCName = F.getGC();
- const StringRef StatepointExampleName("statepoint-example");
- const StringRef CoreCLRName("coreclr");
- return (StatepointExampleName == FunctionGCName) ||
- (CoreCLRName == FunctionGCName);
- } else
+ if (!F.hasGC())
return false;
+
+ std::unique_ptr<GCStrategy> Strategy = findGCStrategy(F);
+
+ assert(Strategy && "GC strategy is required by function, but was not found");
+
+ return Strategy->useRS4GC();
}
static void stripNonValidData(Module &M) {
@@ -3216,7 +3161,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
/// the live-out set of the basic block
static void computeLiveInValues(BasicBlock::reverse_iterator Begin,
BasicBlock::reverse_iterator End,
- SetVector<Value *> &LiveTmp) {
+ SetVector<Value *> &LiveTmp, GCStrategy *GC) {
for (auto &I : make_range(Begin, End)) {
// KILL/Def - Remove this definition from LiveIn
LiveTmp.remove(&I);
@@ -3228,9 +3173,9 @@ static void computeLiveInValues(BasicBlock::reverse_iterator Begin,
// USE - Add to the LiveIn set for this instruction
for (Value *V : I.operands()) {
- assert(!isUnhandledGCPointerType(V->getType()) &&
+ assert(!isUnhandledGCPointerType(V->getType(), GC) &&
"support for FCA unimplemented");
- if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) {
+ if (isHandledGCPointerType(V->getType(), GC) && !isa<Constant>(V)) {
// The choice to exclude all things constant here is slightly subtle.
// There are two independent reasons:
// - We assume that things which are constant (from LLVM's definition)
@@ -3247,7 +3192,8 @@ static void computeLiveInValues(BasicBlock::reverse_iterator Begin,
}
}
-static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp) {
+static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp,
+ GCStrategy *GC) {
for (BasicBlock *Succ : successors(BB)) {
for (auto &I : *Succ) {
PHINode *PN = dyn_cast<PHINode>(&I);
@@ -3255,18 +3201,18 @@ static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp) {
break;
Value *V = PN->getIncomingValueForBlock(BB);
- assert(!isUnhandledGCPointerType(V->getType()) &&
+ assert(!isUnhandledGCPointerType(V->getType(), GC) &&
"support for FCA unimplemented");
- if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V))
+ if (isHandledGCPointerType(V->getType(), GC) && !isa<Constant>(V))
LiveTmp.insert(V);
}
}
}
-static SetVector<Value *> computeKillSet(BasicBlock *BB) {
+static SetVector<Value *> computeKillSet(BasicBlock *BB, GCStrategy *GC) {
SetVector<Value *> KillSet;
for (Instruction &I : *BB)
- if (isHandledGCPointerType(I.getType()))
+ if (isHandledGCPointerType(I.getType(), GC))
KillSet.insert(&I);
return KillSet;
}
@@ -3301,14 +3247,14 @@ static void checkBasicSSA(DominatorTree &DT, GCPtrLivenessData &Data,
#endif
static void computeLiveInValues(DominatorTree &DT, Function &F,
- GCPtrLivenessData &Data) {
+ GCPtrLivenessData &Data, GCStrategy *GC) {
SmallSetVector<BasicBlock *, 32> Worklist;
// Seed the liveness for each individual block
for (BasicBlock &BB : F) {
- Data.KillSet[&BB] = computeKillSet(&BB);
+ Data.KillSet[&BB] = computeKillSet(&BB, GC);
Data.LiveSet[&BB].clear();
- computeLiveInValues(BB.rbegin(), BB.rend(), Data.LiveSet[&BB]);
+ computeLiveInValues(BB.rbegin(), BB.rend(), Data.LiveSet[&BB], GC);
#ifndef NDEBUG
for (Value *Kill : Data.KillSet[&BB])
@@ -3316,7 +3262,7 @@ static void computeLiveInValues(DominatorTree &DT, Function &F,
#endif
Data.LiveOut[&BB] = SetVector<Value *>();
- computeLiveOutSeed(&BB, Data.LiveOut[&BB]);
+ computeLiveOutSeed(&BB, Data.LiveOut[&BB], GC);
Data.LiveIn[&BB] = Data.LiveSet[&BB];
Data.LiveIn[&BB].set_union(Data.LiveOut[&BB]);
Data.LiveIn[&BB].set_subtract(Data.KillSet[&BB]);
@@ -3368,7 +3314,7 @@ static void computeLiveInValues(DominatorTree &DT, Function &F,
}
static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data,
- StatepointLiveSetTy &Out) {
+ StatepointLiveSetTy &Out, GCStrategy *GC) {
BasicBlock *BB = Inst->getParent();
// Note: The copy is intentional and required
@@ -3379,8 +3325,8 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data,
// call result is not live (normal), nor are it's arguments
// (unless they're used again later). This adjustment is
// specifically what we need to relocate
- computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(),
- LiveOut);
+ computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(), LiveOut,
+ GC);
LiveOut.remove(Inst);
Out.insert(LiveOut.begin(), LiveOut.end());
}
@@ -3388,9 +3334,10 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data,
static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData,
CallBase *Call,
PartiallyConstructedSafepointRecord &Info,
- PointerToBaseTy &PointerToBase) {
+ PointerToBaseTy &PointerToBase,
+ GCStrategy *GC) {
StatepointLiveSetTy Updated;
- findLiveSetAtInst(Call, RevisedLivenessData, Updated);
+ findLiveSetAtInst(Call, RevisedLivenessData, Updated, GC);
// We may have base pointers which are now live that weren't before. We need
// to update the PointerToBase structure to reflect this.
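
With shouldRewriteStatepointsIn now deferring to GCStrategy::useRS4GC(), the pass no longer hard-codes the "statepoint-example" and "coreclr" strategy names. Below is a hedged sketch of how an out-of-tree strategy could opt in, assuming the UseStatepoints/UseRS4GC flags that GCStrategy exposes to subclasses; the class and registry name are invented.

#include "llvm/IR/GCStrategy.h"

using namespace llvm;

namespace {
// Hypothetical out-of-tree strategy that opts in to RewriteStatepointsForGC.
class MyStatepointGC : public GCStrategy {
public:
  MyStatepointGC() {
    UseStatepoints = true; // lower gc.statepoint instead of gcroot
    UseRS4GC = true;       // picked up via findGCStrategy()->useRS4GC()
  }
};
} // end anonymous namespace

// Functions carrying gc "my-statepoint-gc" would now be rewritten by the pass.
static GCRegistry::Add<MyStatepointGC>
    X("my-statepoint-gc", "hypothetical strategy that uses RS4GC");

The in-tree statepoint-example and coreclr strategies are expected to set the same flag, so their handling by this pass should be unchanged.
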
diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp
index 7b396c6ee074..fcdc503c54a4 100644
--- a/llvm/lib/Transforms/Scalar/SCCP.cpp
+++ b/llvm/lib/Transforms/Scalar/SCCP.cpp
@@ -41,7 +41,6 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@@ -136,54 +135,3 @@ PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserve<DominatorTreeAnalysis>();
return PA;
}
-
-namespace {
-
-//===--------------------------------------------------------------------===//
-//
-/// SCCP Class - This class uses the SCCPSolver to implement a per-function
-/// Sparse Conditional Constant Propagator.
-///
-class SCCPLegacyPass : public FunctionPass {
-public:
- // Pass identification, replacement for typeid
- static char ID;
-
- SCCPLegacyPass() : FunctionPass(ID) {
- initializeSCCPLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- }
-
- // runOnFunction - Run the Sparse Conditional Constant Propagation
- // algorithm, and return true if the function was modified.
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- const DataLayout &DL = F.getParent()->getDataLayout();
- const TargetLibraryInfo *TLI =
- &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- DomTreeUpdater DTU(DTWP ? &DTWP->getDomTree() : nullptr,
- DomTreeUpdater::UpdateStrategy::Lazy);
- return runSCCP(F, DL, TLI, DTU);
- }
-};
-
-} // end anonymous namespace
-
-char SCCPLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(SCCPLegacyPass, "sccp",
- "Sparse Conditional Constant Propagation", false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(SCCPLegacyPass, "sccp",
- "Sparse Conditional Constant Propagation", false, false)
-
-// createSCCPPass - This is the public interface to this file.
-FunctionPass *llvm::createSCCPPass() { return new SCCPLegacyPass(); }
-
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 8339981e1bdc..983a75e1d708 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -118,13 +118,79 @@ STATISTIC(NumVectorized, "Number of vectorized aggregates");
/// GEPs.
static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false),
cl::Hidden);
+/// Disable running mem2reg during SROA in order to test or debug SROA.
+static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false),
+ cl::Hidden);
namespace {
+
+/// Calculate the fragment of a variable to use when slicing a store
+/// based on the slice dimensions, existing fragment, and base storage
+/// fragment.
+/// Results:
+/// UseFrag - Use Target as the new fragment.
+/// UseNoFrag - The new slice already covers the whole variable.
+/// Skip - The new alloca slice doesn't include this variable.
+/// FIXME: Can we use calculateFragmentIntersect instead?
+enum FragCalcResult { UseFrag, UseNoFrag, Skip };
+static FragCalcResult
+calculateFragment(DILocalVariable *Variable,
+ uint64_t NewStorageSliceOffsetInBits,
+ uint64_t NewStorageSliceSizeInBits,
+ std::optional<DIExpression::FragmentInfo> StorageFragment,
+ std::optional<DIExpression::FragmentInfo> CurrentFragment,
+ DIExpression::FragmentInfo &Target) {
+  // If the base storage describes part of the variable, apply the offset and
+ // the size constraint.
+ if (StorageFragment) {
+ Target.SizeInBits =
+ std::min(NewStorageSliceSizeInBits, StorageFragment->SizeInBits);
+ Target.OffsetInBits =
+ NewStorageSliceOffsetInBits + StorageFragment->OffsetInBits;
+ } else {
+ Target.SizeInBits = NewStorageSliceSizeInBits;
+ Target.OffsetInBits = NewStorageSliceOffsetInBits;
+ }
+
+ // If this slice extracts the entirety of an independent variable from a
+ // larger alloca, do not produce a fragment expression, as the variable is
+ // not fragmented.
+ if (!CurrentFragment) {
+ if (auto Size = Variable->getSizeInBits()) {
+ // Treat the current fragment as covering the whole variable.
+ CurrentFragment = DIExpression::FragmentInfo(*Size, 0);
+ if (Target == CurrentFragment)
+ return UseNoFrag;
+ }
+ }
+
+ // No additional work to do if there isn't a fragment already, or there is
+ // but it already exactly describes the new assignment.
+ if (!CurrentFragment || *CurrentFragment == Target)
+ return UseFrag;
+
+ // Reject the target fragment if it doesn't fit wholly within the current
+ // fragment. TODO: We could instead chop up the target to fit in the case of
+ // a partial overlap.
+ if (Target.startInBits() < CurrentFragment->startInBits() ||
+ Target.endInBits() > CurrentFragment->endInBits())
+ return Skip;
+
+ // Target fits within the current fragment, return it.
+ return UseFrag;
+}
+
+static DebugVariable getAggregateVariable(DbgVariableIntrinsic *DVI) {
+ return DebugVariable(DVI->getVariable(), std::nullopt,
+ DVI->getDebugLoc().getInlinedAt());
+}
+
/// Find linked dbg.assign and generate a new one with the correct
/// FragmentInfo. Link Inst to the new dbg.assign. If Value is nullptr the
/// value component is copied from the old dbg.assign to the new.
/// \param OldAlloca Alloca for the variable before splitting.
-/// \param RelativeOffsetInBits Offset into \p OldAlloca relative to the
-/// offset prior to splitting (change in offset).
+/// \param IsSplit True if the store (not necessarily alloca)
+/// is being split.
+/// \param OldAllocaOffsetInBits Offset of the slice taken from OldAlloca.
/// \param SliceSizeInBits New number of bits being written to.
/// \param OldInst Instruction that is being split.
/// \param Inst New instruction performing this part of the
@@ -132,8 +198,8 @@ namespace {
/// \param Dest Store destination.
/// \param Value Stored value.
/// \param DL Datalayout.
-static void migrateDebugInfo(AllocaInst *OldAlloca,
- uint64_t RelativeOffsetInBits,
+static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit,
+ uint64_t OldAllocaOffsetInBits,
uint64_t SliceSizeInBits, Instruction *OldInst,
Instruction *Inst, Value *Dest, Value *Value,
const DataLayout &DL) {
@@ -144,7 +210,9 @@ static void migrateDebugInfo(AllocaInst *OldAlloca,
LLVM_DEBUG(dbgs() << " migrateDebugInfo\n");
LLVM_DEBUG(dbgs() << " OldAlloca: " << *OldAlloca << "\n");
- LLVM_DEBUG(dbgs() << " RelativeOffset: " << RelativeOffsetInBits << "\n");
+ LLVM_DEBUG(dbgs() << " IsSplit: " << IsSplit << "\n");
+ LLVM_DEBUG(dbgs() << " OldAllocaOffsetInBits: " << OldAllocaOffsetInBits
+ << "\n");
LLVM_DEBUG(dbgs() << " SliceSizeInBits: " << SliceSizeInBits << "\n");
LLVM_DEBUG(dbgs() << " OldInst: " << *OldInst << "\n");
LLVM_DEBUG(dbgs() << " Inst: " << *Inst << "\n");
@@ -152,44 +220,66 @@ static void migrateDebugInfo(AllocaInst *OldAlloca,
if (Value)
LLVM_DEBUG(dbgs() << " Value: " << *Value << "\n");
+ /// Map of aggregate variables to their fragment associated with OldAlloca.
+ DenseMap<DebugVariable, std::optional<DIExpression::FragmentInfo>>
+ BaseFragments;
+ for (auto *DAI : at::getAssignmentMarkers(OldAlloca))
+ BaseFragments[getAggregateVariable(DAI)] =
+ DAI->getExpression()->getFragmentInfo();
+
// The new inst needs a DIAssignID unique metadata tag (if OldInst has
// one). It shouldn't already have one: assert this assumption.
assert(!Inst->getMetadata(LLVMContext::MD_DIAssignID));
DIAssignID *NewID = nullptr;
auto &Ctx = Inst->getContext();
DIBuilder DIB(*OldInst->getModule(), /*AllowUnresolved*/ false);
- uint64_t AllocaSizeInBits = *OldAlloca->getAllocationSizeInBits(DL);
assert(OldAlloca->isStaticAlloca());
for (DbgAssignIntrinsic *DbgAssign : MarkerRange) {
LLVM_DEBUG(dbgs() << " existing dbg.assign is: " << *DbgAssign
<< "\n");
auto *Expr = DbgAssign->getExpression();
+ bool SetKillLocation = false;
- // Check if the dbg.assign already describes a fragment.
- auto GetCurrentFragSize = [AllocaSizeInBits, DbgAssign,
- Expr]() -> uint64_t {
- if (auto FI = Expr->getFragmentInfo())
- return FI->SizeInBits;
- if (auto VarSize = DbgAssign->getVariable()->getSizeInBits())
- return *VarSize;
- // The variable type has an unspecified size. This can happen in the
- // case of DW_TAG_unspecified_type types, e.g. std::nullptr_t. Because
- // there is no fragment and we do not know the size of the variable type,
- // we'll guess by looking at the alloca.
- return AllocaSizeInBits;
- };
- uint64_t CurrentFragSize = GetCurrentFragSize();
- bool MakeNewFragment = CurrentFragSize != SliceSizeInBits;
- assert(MakeNewFragment || RelativeOffsetInBits == 0);
-
- assert(SliceSizeInBits <= AllocaSizeInBits);
- if (MakeNewFragment) {
- assert(RelativeOffsetInBits + SliceSizeInBits <= CurrentFragSize);
- auto E = DIExpression::createFragmentExpression(
- Expr, RelativeOffsetInBits, SliceSizeInBits);
- assert(E && "Failed to create fragment expr!");
- Expr = *E;
+ if (IsSplit) {
+ std::optional<DIExpression::FragmentInfo> BaseFragment;
+ {
+ auto R = BaseFragments.find(getAggregateVariable(DbgAssign));
+ if (R == BaseFragments.end())
+ continue;
+ BaseFragment = R->second;
+ }
+ std::optional<DIExpression::FragmentInfo> CurrentFragment =
+ Expr->getFragmentInfo();
+ DIExpression::FragmentInfo NewFragment;
+ FragCalcResult Result = calculateFragment(
+ DbgAssign->getVariable(), OldAllocaOffsetInBits, SliceSizeInBits,
+ BaseFragment, CurrentFragment, NewFragment);
+
+ if (Result == Skip)
+ continue;
+ if (Result == UseFrag && !(NewFragment == CurrentFragment)) {
+ if (CurrentFragment) {
+ // Rewrite NewFragment to be relative to the existing one (this is
+ // what createFragmentExpression wants). CalculateFragment has
+ // already resolved the size for us. FIXME: Should it return the
+ // relative fragment too?
+ NewFragment.OffsetInBits -= CurrentFragment->OffsetInBits;
+ }
+ // Add the new fragment info to the existing expression if possible.
+ if (auto E = DIExpression::createFragmentExpression(
+ Expr, NewFragment.OffsetInBits, NewFragment.SizeInBits)) {
+ Expr = *E;
+ } else {
+ // Otherwise, add the new fragment info to an empty expression and
+ // discard the value component of this dbg.assign as the value cannot
+ // be computed with the new fragment.
+ Expr = *DIExpression::createFragmentExpression(
+ DIExpression::get(Expr->getContext(), std::nullopt),
+ NewFragment.OffsetInBits, NewFragment.SizeInBits);
+ SetKillLocation = true;
+ }
+ }
}
// If we haven't created a DIAssignID ID do that now and attach it to Inst.
@@ -198,11 +288,27 @@ static void migrateDebugInfo(AllocaInst *OldAlloca,
Inst->setMetadata(LLVMContext::MD_DIAssignID, NewID);
}
- Value = Value ? Value : DbgAssign->getValue();
+ ::Value *NewValue = Value ? Value : DbgAssign->getValue();
auto *NewAssign = DIB.insertDbgAssign(
- Inst, Value, DbgAssign->getVariable(), Expr, Dest,
+ Inst, NewValue, DbgAssign->getVariable(), Expr, Dest,
DIExpression::get(Ctx, std::nullopt), DbgAssign->getDebugLoc());
+ // If we've updated the value but the original dbg.assign has an arglist
+ // then kill it now - we can't use the requested new value.
+ // We can't replace the DIArgList with the new value as it'd leave
+ // the DIExpression in an invalid state (DW_OP_LLVM_arg operands without
+ // an arglist). And we can't keep the DIArgList in case the linked store
+ // is being split - in which case the DIArgList + expression may no longer
+ // be computing the correct value.
+ // This should be a very rare situation as it requires the value being
+ // stored to differ from the dbg.assign (i.e., the value has been
+ // represented differently in the debug intrinsic for some reason).
+ SetKillLocation |=
+ Value && (DbgAssign->hasArgList() ||
+ !DbgAssign->getExpression()->isSingleLocationExpression());
+ if (SetKillLocation)
+ NewAssign->setKillLocation();
+
// We could use more precision here at the cost of some additional (code)
// complexity - if the original dbg.assign was adjacent to its store, we
// could position this new dbg.assign adjacent to its store rather than the
@@ -888,11 +994,12 @@ private:
if (!IsOffsetKnown)
return PI.setAborted(&LI);
- if (isa<ScalableVectorType>(LI.getType()))
+ TypeSize Size = DL.getTypeStoreSize(LI.getType());
+ if (Size.isScalable())
return PI.setAborted(&LI);
- uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedValue();
- return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile());
+ return handleLoadOrStore(LI.getType(), LI, Offset, Size.getFixedValue(),
+ LI.isVolatile());
}
void visitStoreInst(StoreInst &SI) {
@@ -902,10 +1009,11 @@ private:
if (!IsOffsetKnown)
return PI.setAborted(&SI);
- if (isa<ScalableVectorType>(ValOp->getType()))
+ TypeSize StoreSize = DL.getTypeStoreSize(ValOp->getType());
+ if (StoreSize.isScalable())
return PI.setAborted(&SI);
- uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedValue();
+ uint64_t Size = StoreSize.getFixedValue();
// If this memory access can be shown to *statically* extend outside the
    // bounds of the allocation, its behavior is undefined, so simply
@@ -1520,12 +1628,6 @@ static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI,
IRB.SetInsertPoint(&LI);
- if (auto *TypedPtrTy = LI.getPointerOperandType();
- !TypedPtrTy->isOpaquePointerTy() && SI.getType() != TypedPtrTy) {
- TV = IRB.CreateBitOrPointerCast(TV, TypedPtrTy, "");
- FV = IRB.CreateBitOrPointerCast(FV, TypedPtrTy, "");
- }
-
LoadInst *TL =
IRB.CreateAlignedLoad(LI.getType(), TV, LI.getAlign(),
LI.getName() + ".sroa.speculate.load.true");
@@ -1581,22 +1683,19 @@ static void rewriteMemOpOfSelect(SelectInst &SI, T &I,
bool IsThen = SuccBB == HeadBI->getSuccessor(0);
int SuccIdx = IsThen ? 0 : 1;
auto *NewMemOpBB = SuccBB == Tail ? Head : SuccBB;
+ auto &CondMemOp = cast<T>(*I.clone());
if (NewMemOpBB != Head) {
NewMemOpBB->setName(Head->getName() + (IsThen ? ".then" : ".else"));
if (isa<LoadInst>(I))
++NumLoadsPredicated;
else
++NumStoresPredicated;
- } else
+ } else {
+ CondMemOp.dropUBImplyingAttrsAndMetadata();
++NumLoadsSpeculated;
- auto &CondMemOp = cast<T>(*I.clone());
+ }
CondMemOp.insertBefore(NewMemOpBB->getTerminator());
Value *Ptr = SI.getOperand(1 + SuccIdx);
- if (auto *PtrTy = Ptr->getType();
- !PtrTy->isOpaquePointerTy() &&
- PtrTy != CondMemOp.getPointerOperandType())
- Ptr = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
- Ptr, CondMemOp.getPointerOperandType(), "", &CondMemOp);
CondMemOp.setOperand(I.getPointerOperandIndex(), Ptr);
if (isa<LoadInst>(I)) {
CondMemOp.setName(I.getName() + (IsThen ? ".then" : ".else") + ".val");
@@ -1654,238 +1753,16 @@ static bool rewriteSelectInstMemOps(SelectInst &SI,
return CFGChanged;
}
-/// Build a GEP out of a base pointer and indices.
-///
-/// This will return the BasePtr if that is valid, or build a new GEP
-/// instruction using the IRBuilder if GEP-ing is needed.
-static Value *buildGEP(IRBuilderTy &IRB, Value *BasePtr,
- SmallVectorImpl<Value *> &Indices,
- const Twine &NamePrefix) {
- if (Indices.empty())
- return BasePtr;
-
- // A single zero index is a no-op, so check for this and avoid building a GEP
- // in that case.
- if (Indices.size() == 1 && cast<ConstantInt>(Indices.back())->isZero())
- return BasePtr;
-
- // buildGEP() is only called for non-opaque pointers.
- return IRB.CreateInBoundsGEP(
- BasePtr->getType()->getNonOpaquePointerElementType(), BasePtr, Indices,
- NamePrefix + "sroa_idx");
-}
-
-/// Get a natural GEP off of the BasePtr walking through Ty toward
-/// TargetTy without changing the offset of the pointer.
-///
-/// This routine assumes we've already established a properly offset GEP with
-/// Indices, and arrived at the Ty type. The goal is to continue to GEP with
-/// zero-indices down through type layers until we find one the same as
-/// TargetTy. If we can't find one with the same type, we at least try to use
-/// one with the same size. If none of that works, we just produce the GEP as
-/// indicated by Indices to have the correct offset.
-static Value *getNaturalGEPWithType(IRBuilderTy &IRB, const DataLayout &DL,
- Value *BasePtr, Type *Ty, Type *TargetTy,
- SmallVectorImpl<Value *> &Indices,
- const Twine &NamePrefix) {
- if (Ty == TargetTy)
- return buildGEP(IRB, BasePtr, Indices, NamePrefix);
-
- // Offset size to use for the indices.
- unsigned OffsetSize = DL.getIndexTypeSizeInBits(BasePtr->getType());
-
- // See if we can descend into a struct and locate a field with the correct
- // type.
- unsigned NumLayers = 0;
- Type *ElementTy = Ty;
- do {
- if (ElementTy->isPointerTy())
- break;
-
- if (ArrayType *ArrayTy = dyn_cast<ArrayType>(ElementTy)) {
- ElementTy = ArrayTy->getElementType();
- Indices.push_back(IRB.getIntN(OffsetSize, 0));
- } else if (VectorType *VectorTy = dyn_cast<VectorType>(ElementTy)) {
- ElementTy = VectorTy->getElementType();
- Indices.push_back(IRB.getInt32(0));
- } else if (StructType *STy = dyn_cast<StructType>(ElementTy)) {
- if (STy->element_begin() == STy->element_end())
- break; // Nothing left to descend into.
- ElementTy = *STy->element_begin();
- Indices.push_back(IRB.getInt32(0));
- } else {
- break;
- }
- ++NumLayers;
- } while (ElementTy != TargetTy);
- if (ElementTy != TargetTy)
- Indices.erase(Indices.end() - NumLayers, Indices.end());
-
- return buildGEP(IRB, BasePtr, Indices, NamePrefix);
-}
-
-/// Get a natural GEP from a base pointer to a particular offset and
-/// resulting in a particular type.
-///
-/// The goal is to produce a "natural" looking GEP that works with the existing
-/// composite types to arrive at the appropriate offset and element type for
-/// a pointer. TargetTy is the element type the returned GEP should point-to if
-/// possible. We recurse by decreasing Offset, adding the appropriate index to
-/// Indices, and setting Ty to the result subtype.
-///
-/// If no natural GEP can be constructed, this function returns null.
-static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL,
- Value *Ptr, APInt Offset, Type *TargetTy,
- SmallVectorImpl<Value *> &Indices,
- const Twine &NamePrefix) {
- PointerType *Ty = cast<PointerType>(Ptr->getType());
-
- // Don't consider any GEPs through an i8* as natural unless the TargetTy is
- // an i8.
- if (Ty == IRB.getInt8PtrTy(Ty->getAddressSpace()) && TargetTy->isIntegerTy(8))
- return nullptr;
-
- Type *ElementTy = Ty->getNonOpaquePointerElementType();
- if (!ElementTy->isSized())
- return nullptr; // We can't GEP through an unsized element.
-
- SmallVector<APInt> IntIndices = DL.getGEPIndicesForOffset(ElementTy, Offset);
- if (Offset != 0)
- return nullptr;
-
- for (const APInt &Index : IntIndices)
- Indices.push_back(IRB.getInt(Index));
- return getNaturalGEPWithType(IRB, DL, Ptr, ElementTy, TargetTy, Indices,
- NamePrefix);
-}
-
/// Compute an adjusted pointer from Ptr by Offset bytes where the
/// resulting pointer has PointerTy.
-///
-/// This tries very hard to compute a "natural" GEP which arrives at the offset
-/// and produces the pointer type desired. Where it cannot, it will try to use
-/// the natural GEP to arrive at the offset and bitcast to the type. Where that
-/// fails, it will try to use an existing i8* and GEP to the byte offset and
-/// bitcast to the type.
-///
-/// The strategy for finding the more natural GEPs is to peel off layers of the
-/// pointer, walking back through bit casts and GEPs, searching for a base
-/// pointer from which we can compute a natural GEP with the desired
-/// properties. The algorithm tries to fold as many constant indices into
-/// a single GEP as possible, thus making each GEP more independent of the
-/// surrounding code.
static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
APInt Offset, Type *PointerTy,
const Twine &NamePrefix) {
- // Create i8 GEP for opaque pointers.
- if (Ptr->getType()->isOpaquePointerTy()) {
- if (Offset != 0)
- Ptr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(Offset),
- NamePrefix + "sroa_idx");
- return IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, PointerTy,
- NamePrefix + "sroa_cast");
- }
-
- // Even though we don't look through PHI nodes, we could be called on an
- // instruction in an unreachable block, which may be on a cycle.
- SmallPtrSet<Value *, 4> Visited;
- Visited.insert(Ptr);
- SmallVector<Value *, 4> Indices;
-
- // We may end up computing an offset pointer that has the wrong type. If we
- // never are able to compute one directly that has the correct type, we'll
- // fall back to it, so keep it and the base it was computed from around here.
- Value *OffsetPtr = nullptr;
- Value *OffsetBasePtr;
-
- // Remember any i8 pointer we come across to re-use if we need to do a raw
- // byte offset.
- Value *Int8Ptr = nullptr;
- APInt Int8PtrOffset(Offset.getBitWidth(), 0);
-
- PointerType *TargetPtrTy = cast<PointerType>(PointerTy);
- Type *TargetTy = TargetPtrTy->getNonOpaquePointerElementType();
-
-  // Because of `addrspacecast`, `Ptr` (the storage pointer) may have a
-  // different address space from the expected `PointerTy` (the pointer to be
-  // used). Adjust the pointer type based on the original storage pointer.
- auto AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
- PointerTy = TargetTy->getPointerTo(AS);
-
- do {
- // First fold any existing GEPs into the offset.
- while (GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr)) {
- APInt GEPOffset(Offset.getBitWidth(), 0);
- if (!GEP->accumulateConstantOffset(DL, GEPOffset))
- break;
- Offset += GEPOffset;
- Ptr = GEP->getPointerOperand();
- if (!Visited.insert(Ptr).second)
- break;
- }
-
- // See if we can perform a natural GEP here.
- Indices.clear();
- if (Value *P = getNaturalGEPWithOffset(IRB, DL, Ptr, Offset, TargetTy,
- Indices, NamePrefix)) {
- // If we have a new natural pointer at the offset, clear out any old
- // offset pointer we computed. Unless it is the base pointer or
- // a non-instruction, we built a GEP we don't need. Zap it.
- if (OffsetPtr && OffsetPtr != OffsetBasePtr)
- if (Instruction *I = dyn_cast<Instruction>(OffsetPtr)) {
-          assert(I->use_empty() && "Built a GEP with uses somehow!");
- I->eraseFromParent();
- }
- OffsetPtr = P;
- OffsetBasePtr = Ptr;
- // If we also found a pointer of the right type, we're done.
- if (P->getType() == PointerTy)
- break;
- }
-
- // Stash this pointer if we've found an i8*.
- if (Ptr->getType()->isIntegerTy(8)) {
- Int8Ptr = Ptr;
- Int8PtrOffset = Offset;
- }
-
- // Peel off a layer of the pointer and update the offset appropriately.
- if (Operator::getOpcode(Ptr) == Instruction::BitCast) {
- Ptr = cast<Operator>(Ptr)->getOperand(0);
- } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(Ptr)) {
- if (GA->isInterposable())
- break;
- Ptr = GA->getAliasee();
- } else {
- break;
- }
- assert(Ptr->getType()->isPointerTy() && "Unexpected operand type!");
- } while (Visited.insert(Ptr).second);
-
- if (!OffsetPtr) {
- if (!Int8Ptr) {
- Int8Ptr = IRB.CreateBitCast(
- Ptr, IRB.getInt8PtrTy(PointerTy->getPointerAddressSpace()),
- NamePrefix + "sroa_raw_cast");
- Int8PtrOffset = Offset;
- }
-
- OffsetPtr = Int8PtrOffset == 0
- ? Int8Ptr
- : IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Int8Ptr,
- IRB.getInt(Int8PtrOffset),
- NamePrefix + "sroa_raw_idx");
- }
- Ptr = OffsetPtr;
-
- // On the off chance we were targeting i8*, guard the bitcast here.
- if (cast<PointerType>(Ptr->getType()) != TargetPtrTy) {
- Ptr = IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr,
- TargetPtrTy,
- NamePrefix + "sroa_cast");
- }
-
- return Ptr;
+ if (Offset != 0)
+ Ptr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(Offset),
+ NamePrefix + "sroa_idx");
+ return IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, PointerTy,
+ NamePrefix + "sroa_cast");
}
/// Compute the adjusted alignment for a load or store from an offset.
@@ -2126,6 +2003,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
// Collect the candidate types for vector-based promotion. Also track whether
// we have different element types.
SmallVector<VectorType *, 4> CandidateTys;
+ SetVector<Type *> LoadStoreTys;
Type *CommonEltTy = nullptr;
VectorType *CommonVecPtrTy = nullptr;
bool HaveVecPtrTy = false;
@@ -2159,15 +2037,40 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
}
}
};
- // Consider any loads or stores that are the exact size of the slice.
- for (const Slice &S : P)
- if (S.beginOffset() == P.beginOffset() &&
- S.endOffset() == P.endOffset()) {
- if (auto *LI = dyn_cast<LoadInst>(S.getUse()->getUser()))
- CheckCandidateType(LI->getType());
- else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser()))
- CheckCandidateType(SI->getValueOperand()->getType());
+ // Put load and store types into a set for de-duplication.
+ for (const Slice &S : P) {
+ Type *Ty;
+ if (auto *LI = dyn_cast<LoadInst>(S.getUse()->getUser()))
+ Ty = LI->getType();
+ else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser()))
+ Ty = SI->getValueOperand()->getType();
+ else
+ continue;
+ LoadStoreTys.insert(Ty);
+ // Consider any loads or stores that are the exact size of the slice.
+ if (S.beginOffset() == P.beginOffset() && S.endOffset() == P.endOffset())
+ CheckCandidateType(Ty);
+ }
+ // Consider additional vector types where the element type size is a
+ // multiple of load/store element size.
+ for (Type *Ty : LoadStoreTys) {
+ if (!VectorType::isValidElementType(Ty))
+ continue;
+ unsigned TypeSize = DL.getTypeSizeInBits(Ty).getFixedValue();
+ // Make a copy of CandidateTys and iterate through it, because we might
+ // append to CandidateTys in the loop.
+ SmallVector<VectorType *, 4> CandidateTysCopy = CandidateTys;
+ for (VectorType *&VTy : CandidateTysCopy) {
+ unsigned VectorSize = DL.getTypeSizeInBits(VTy).getFixedValue();
+ unsigned ElementSize =
+ DL.getTypeSizeInBits(VTy->getElementType()).getFixedValue();
+ if (TypeSize != VectorSize && TypeSize != ElementSize &&
+ VectorSize % TypeSize == 0) {
+ VectorType *NewVTy = VectorType::get(Ty, VectorSize / TypeSize, false);
+ CheckCandidateType(NewVTy);
+ }
}
+ }
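
As a rough numeric illustration of the widening condition in the loop above (not part of the patch; the numbers are made up): a 32-bit load against an existing 128-bit candidate <16 x i8> produces a new <4 x i32> candidate.

    // Illustrative arithmetic only; mirrors the condition
    // (TypeSize != VectorSize && TypeSize != ElementSize &&
    //  VectorSize % TypeSize == 0) from the loop above.
    #include <cstdio>

    int main() {
      unsigned TypeSize = 32;    // a loaded/stored i32, in bits
      unsigned VectorSize = 128; // an existing candidate, e.g. <16 x i8>
      unsigned ElementSize = 8;  // element size of that candidate
      if (TypeSize != VectorSize && TypeSize != ElementSize &&
          VectorSize % TypeSize == 0)
        std::printf("new candidate: <%u x i32>\n", VectorSize / TypeSize); // <4 x i32>
      return 0;
    }
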
// If we didn't find a vector type, nothing to do here.
if (CandidateTys.empty())
@@ -2195,7 +2098,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
// Rank the remaining candidate vector types. This is easy because we know
// they're all integer vectors. We sort by ascending number of elements.
- auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+ auto RankVectorTypesComp = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
(void)DL;
assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
@@ -2207,10 +2110,22 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
return cast<FixedVectorType>(RHSTy)->getNumElements() <
cast<FixedVectorType>(LHSTy)->getNumElements();
};
- llvm::sort(CandidateTys, RankVectorTypes);
- CandidateTys.erase(
- std::unique(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes),
- CandidateTys.end());
+ auto RankVectorTypesEq = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+ (void)DL;
+ assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
+ DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
+ "Cannot have vector types of different sizes!");
+ assert(RHSTy->getElementType()->isIntegerTy() &&
+ "All non-integer types eliminated!");
+ assert(LHSTy->getElementType()->isIntegerTy() &&
+ "All non-integer types eliminated!");
+ return cast<FixedVectorType>(RHSTy)->getNumElements() ==
+ cast<FixedVectorType>(LHSTy)->getNumElements();
+ };
+ llvm::sort(CandidateTys, RankVectorTypesComp);
+ CandidateTys.erase(std::unique(CandidateTys.begin(), CandidateTys.end(),
+ RankVectorTypesEq),
+ CandidateTys.end());
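
Splitting the single RankVectorTypes lambda into a comparator and an equality predicate matches the standard-library contracts: std::sort needs a strict weak ordering, while std::unique needs an equivalence relation. A minimal, non-LLVM illustration of the same sort-then-unique pattern (values stand in for element counts):

    // Sort with a comparator, de-duplicate with a separate equality predicate,
    // the same pattern as RankVectorTypesComp / RankVectorTypesEq above.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<unsigned> ElemCounts = {8, 4, 16, 4, 8};
      auto Comp = [](unsigned L, unsigned R) { return L < R; };  // for sort
      auto Eq = [](unsigned L, unsigned R) { return L == R; };   // for unique
      std::sort(ElemCounts.begin(), ElemCounts.end(), Comp);
      ElemCounts.erase(std::unique(ElemCounts.begin(), ElemCounts.end(), Eq),
                       ElemCounts.end());
      for (unsigned C : ElemCounts)
        std::printf("%u ", C); // prints: 4 8 16
      std::printf("\n");
      return 0;
    }
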
} else {
// The only way to have the same element type in every vector type is to
// have the same vector type. Check that and remove all but one.
@@ -2554,7 +2469,6 @@ class llvm::sroa::AllocaSliceRewriter
// original alloca.
uint64_t NewBeginOffset = 0, NewEndOffset = 0;
- uint64_t RelativeOffset = 0;
uint64_t SliceSize = 0;
bool IsSplittable = false;
bool IsSplit = false;
@@ -2628,14 +2542,13 @@ public:
NewBeginOffset = std::max(BeginOffset, NewAllocaBeginOffset);
NewEndOffset = std::min(EndOffset, NewAllocaEndOffset);
- RelativeOffset = NewBeginOffset - BeginOffset;
SliceSize = NewEndOffset - NewBeginOffset;
LLVM_DEBUG(dbgs() << " Begin:(" << BeginOffset << ", " << EndOffset
<< ") NewBegin:(" << NewBeginOffset << ", "
<< NewEndOffset << ") NewAllocaBegin:("
<< NewAllocaBeginOffset << ", " << NewAllocaEndOffset
<< ")\n");
- assert(IsSplit || RelativeOffset == 0);
+ assert(IsSplit || NewBeginOffset == BeginOffset);
OldUse = I->getUse();
OldPtr = cast<Instruction>(OldUse->get());
@@ -2898,8 +2811,8 @@ private:
Pass.DeadInsts.push_back(&SI);
// NOTE: Careful to use OrigV rather than V.
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, Store,
- Store->getPointerOperand(), OrigV, DL);
+ migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI,
+ Store, Store->getPointerOperand(), OrigV, DL);
LLVM_DEBUG(dbgs() << " to: " << *Store << "\n");
return true;
}
@@ -2923,8 +2836,9 @@ private:
if (AATags)
Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset));
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, Store,
- Store->getPointerOperand(), Store->getValueOperand(), DL);
+ migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI,
+ Store, Store->getPointerOperand(),
+ Store->getValueOperand(), DL);
Pass.DeadInsts.push_back(&SI);
LLVM_DEBUG(dbgs() << " to: " << *Store << "\n");
@@ -3002,8 +2916,9 @@ private:
if (NewSI->isAtomic())
NewSI->setAlignment(SI.getAlign());
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, NewSI,
- NewSI->getPointerOperand(), NewSI->getValueOperand(), DL);
+ migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI,
+ NewSI, NewSI->getPointerOperand(),
+ NewSI->getValueOperand(), DL);
Pass.DeadInsts.push_back(&SI);
deleteIfTriviallyDead(OldOp);
@@ -3103,8 +3018,8 @@ private:
if (AATags)
New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset));
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New,
- New->getRawDest(), nullptr, DL);
+ migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II,
+ New, New->getRawDest(), nullptr, DL);
LLVM_DEBUG(dbgs() << " to: " << *New << "\n");
return false;
@@ -3179,8 +3094,8 @@ private:
if (AATags)
New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset));
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New,
- New->getPointerOperand(), V, DL);
+ migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II,
+ New, New->getPointerOperand(), V, DL);
LLVM_DEBUG(dbgs() << " to: " << *New << "\n");
return !II.isVolatile();
@@ -3308,8 +3223,16 @@ private:
if (AATags)
New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset));
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New,
- DestPtr, nullptr, DL);
+ APInt Offset(DL.getIndexTypeSizeInBits(DestPtr->getType()), 0);
+ if (IsDest) {
+ migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8,
+ &II, New, DestPtr, nullptr, DL);
+ } else if (AllocaInst *Base = dyn_cast<AllocaInst>(
+ DestPtr->stripAndAccumulateConstantOffsets(
+ DL, Offset, /*AllowNonInbounds*/ true))) {
+ migrateDebugInfo(Base, IsSplit, Offset.getZExtValue() * 8,
+ SliceSize * 8, &II, New, DestPtr, nullptr, DL);
+ }
LLVM_DEBUG(dbgs() << " to: " << *New << "\n");
return false;
}
@@ -3397,8 +3320,18 @@ private:
if (AATags)
Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset));
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, Store,
- DstPtr, Src, DL);
+ APInt Offset(DL.getIndexTypeSizeInBits(DstPtr->getType()), 0);
+ if (IsDest) {
+
+ migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II,
+ Store, DstPtr, Src, DL);
+ } else if (AllocaInst *Base = dyn_cast<AllocaInst>(
+ DstPtr->stripAndAccumulateConstantOffsets(
+ DL, Offset, /*AllowNonInbounds*/ true))) {
+ migrateDebugInfo(Base, IsSplit, Offset.getZExtValue() * 8, SliceSize * 8,
+ &II, Store, DstPtr, Src, DL);
+ }
+
LLVM_DEBUG(dbgs() << " to: " << *Store << "\n");
return !II.isVolatile();
}
@@ -3760,23 +3693,22 @@ private:
APInt Offset(
DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace()), 0);
- if (AATags &&
- GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset))
+ GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset);
+ if (AATags)
Store->setAAMetadata(AATags.shift(Offset.getZExtValue()));
// migrateDebugInfo requires the base Alloca. Walk to it from this gep.
// If we cannot (because there's an intervening non-const or unbounded
// gep) then we wouldn't expect to see dbg.assign intrinsics linked to
// this instruction.
- APInt OffsetInBytes(DL.getTypeSizeInBits(Ptr->getType()), false);
- Value *Base = InBoundsGEP->stripAndAccumulateInBoundsConstantOffsets(
- DL, OffsetInBytes);
+ Value *Base = AggStore->getPointerOperand()->stripInBoundsOffsets();
if (auto *OldAI = dyn_cast<AllocaInst>(Base)) {
uint64_t SizeInBits =
DL.getTypeSizeInBits(Store->getValueOperand()->getType());
- migrateDebugInfo(OldAI, OffsetInBytes.getZExtValue() * 8, SizeInBits,
- AggStore, Store, Store->getPointerOperand(),
- Store->getValueOperand(), DL);
+ migrateDebugInfo(OldAI, /*IsSplit*/ true, Offset.getZExtValue() * 8,
+ SizeInBits, AggStore, Store,
+ Store->getPointerOperand(), Store->getValueOperand(),
+ DL);
} else {
assert(at::getAssignmentMarkers(Store).empty() &&
"AT: unexpected debug.assign linked to store through "
@@ -3799,6 +3731,9 @@ private:
getAdjustedAlignment(&SI, 0), DL, IRB);
Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca");
Visited.erase(&SI);
+ // The stores replacing SI each have markers describing fragments of the
+ // assignment so delete the assignment markers linked to SI.
+ at::deleteAssignmentMarkers(&SI);
SI.eraseFromParent();
return true;
}
@@ -4029,6 +3964,10 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
return nullptr;
const StructLayout *SL = DL.getStructLayout(STy);
+
+ if (SL->getSizeInBits().isScalable())
+ return nullptr;
+
if (Offset >= SL->getSizeInBytes())
return nullptr;
uint64_t EndOffset = Offset + Size;
@@ -4869,11 +4808,13 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
// Migrate debug information from the old alloca to the new alloca(s)
// and the individual partitions.
- TinyPtrVector<DbgVariableIntrinsic *> DbgDeclares = FindDbgAddrUses(&AI);
+ TinyPtrVector<DbgVariableIntrinsic *> DbgVariables;
+ for (auto *DbgDeclare : FindDbgDeclareUses(&AI))
+ DbgVariables.push_back(DbgDeclare);
for (auto *DbgAssign : at::getAssignmentMarkers(&AI))
- DbgDeclares.push_back(DbgAssign);
- for (DbgVariableIntrinsic *DbgDeclare : DbgDeclares) {
- auto *Expr = DbgDeclare->getExpression();
+ DbgVariables.push_back(DbgAssign);
+ for (DbgVariableIntrinsic *DbgVariable : DbgVariables) {
+ auto *Expr = DbgVariable->getExpression();
DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false);
uint64_t AllocaSize =
DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedValue();
@@ -4905,7 +4846,7 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
}
// The alloca may be larger than the variable.
- auto VarSize = DbgDeclare->getVariable()->getSizeInBits();
+ auto VarSize = DbgVariable->getVariable()->getSizeInBits();
if (VarSize) {
if (Size > *VarSize)
Size = *VarSize;
@@ -4925,18 +4866,18 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
// Remove any existing intrinsics on the new alloca describing
// the variable fragment.
- for (DbgVariableIntrinsic *OldDII : FindDbgAddrUses(Fragment.Alloca)) {
+ for (DbgDeclareInst *OldDII : FindDbgDeclareUses(Fragment.Alloca)) {
auto SameVariableFragment = [](const DbgVariableIntrinsic *LHS,
const DbgVariableIntrinsic *RHS) {
return LHS->getVariable() == RHS->getVariable() &&
LHS->getDebugLoc()->getInlinedAt() ==
RHS->getDebugLoc()->getInlinedAt();
};
- if (SameVariableFragment(OldDII, DbgDeclare))
+ if (SameVariableFragment(OldDII, DbgVariable))
OldDII->eraseFromParent();
}
- if (auto *DbgAssign = dyn_cast<DbgAssignIntrinsic>(DbgDeclare)) {
+ if (auto *DbgAssign = dyn_cast<DbgAssignIntrinsic>(DbgVariable)) {
if (!Fragment.Alloca->hasMetadata(LLVMContext::MD_DIAssignID)) {
Fragment.Alloca->setMetadata(
LLVMContext::MD_DIAssignID,
@@ -4950,8 +4891,8 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
LLVM_DEBUG(dbgs() << "Created new assign intrinsic: " << *NewAssign
<< "\n");
} else {
- DIB.insertDeclare(Fragment.Alloca, DbgDeclare->getVariable(),
- FragmentExpr, DbgDeclare->getDebugLoc(), &AI);
+ DIB.insertDeclare(Fragment.Alloca, DbgVariable->getVariable(),
+ FragmentExpr, DbgVariable->getDebugLoc(), &AI);
}
}
}
@@ -4996,8 +4937,9 @@ SROAPass::runOnAlloca(AllocaInst &AI) {
// Skip alloca forms that this analysis can't handle.
auto *AT = AI.getAllocatedType();
- if (AI.isArrayAllocation() || !AT->isSized() || isa<ScalableVectorType>(AT) ||
- DL.getTypeAllocSize(AT).getFixedValue() == 0)
+ TypeSize Size = DL.getTypeAllocSize(AT);
+ if (AI.isArrayAllocation() || !AT->isSized() || Size.isScalable() ||
+ Size.getFixedValue() == 0)
return {Changed, CFGChanged};
// First, split any FCA loads and stores touching this alloca to promote
@@ -5074,7 +5016,7 @@ bool SROAPass::deleteDeadInstructions(
// not be able to find it.
if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
DeletedAllocas.insert(AI);
- for (DbgVariableIntrinsic *OldDII : FindDbgAddrUses(AI))
+ for (DbgDeclareInst *OldDII : FindDbgDeclareUses(AI))
OldDII->eraseFromParent();
}
@@ -5107,8 +5049,13 @@ bool SROAPass::promoteAllocas(Function &F) {
NumPromoted += PromotableAllocas.size();
- LLVM_DEBUG(dbgs() << "Promoting allocas with mem2reg...\n");
- PromoteMemToReg(PromotableAllocas, DTU->getDomTree(), AC);
+ if (SROASkipMem2Reg) {
+ LLVM_DEBUG(dbgs() << "Not promoting allocas with mem2reg!\n");
+ } else {
+ LLVM_DEBUG(dbgs() << "Promoting allocas with mem2reg...\n");
+ PromoteMemToReg(PromotableAllocas, DTU->getDomTree(), AC);
+ }
+
PromotableAllocas.clear();
return true;
}
@@ -5120,16 +5067,16 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU,
DTU = &RunDTU;
AC = &RunAC;
+ const DataLayout &DL = F.getParent()->getDataLayout();
BasicBlock &EntryBB = F.getEntryBlock();
for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end());
I != E; ++I) {
if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
- if (isa<ScalableVectorType>(AI->getAllocatedType())) {
- if (isAllocaPromotable(AI))
- PromotableAllocas.push_back(AI);
- } else {
+ if (DL.getTypeAllocSize(AI->getAllocatedType()).isScalable() &&
+ isAllocaPromotable(AI))
+ PromotableAllocas.push_back(AI);
+ else
Worklist.insert(AI);
- }
}
}
@@ -5172,6 +5119,11 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU,
if (!Changed)
return PreservedAnalyses::all();
+ if (isAssignmentTrackingEnabled(*F.getParent())) {
+ for (auto &BB : F)
+ RemoveRedundantDbgInstrs(&BB);
+ }
+
PreservedAnalyses PA;
if (!CFGChanged)
PA.preserveSet<CFGAnalyses>();
@@ -5186,8 +5138,9 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT,
}
PreservedAnalyses SROAPass::run(Function &F, FunctionAnalysisManager &AM) {
- return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F),
- AM.getResult<AssumptionAnalysis>(F));
+ DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
+ return runImpl(F, DT, AC);
}
void SROAPass::printPipeline(
diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
index 8aee8d140a29..37b032e4d7c7 100644
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -12,76 +12,38 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/Scalar.h"
-#include "llvm-c/Initialization.h"
-#include "llvm-c/Transforms/Scalar.h"
-#include "llvm/Analysis/BasicAliasAnalysis.h"
-#include "llvm/Analysis/ScopedNoAliasAA.h"
-#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
-#include "llvm/IR/LegacyPassManager.h"
-#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
-#include "llvm/Transforms/Scalar/GVN.h"
-#include "llvm/Transforms/Scalar/Scalarizer.h"
-#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
using namespace llvm;
/// initializeScalarOptsPasses - Initialize all passes linked into the
/// ScalarOpts library.
void llvm::initializeScalarOpts(PassRegistry &Registry) {
- initializeADCELegacyPassPass(Registry);
- initializeBDCELegacyPassPass(Registry);
- initializeAlignmentFromAssumptionsPass(Registry);
- initializeCallSiteSplittingLegacyPassPass(Registry);
initializeConstantHoistingLegacyPassPass(Registry);
- initializeCorrelatedValuePropagationPass(Registry);
initializeDCELegacyPassPass(Registry);
- initializeDivRemPairsLegacyPassPass(Registry);
initializeScalarizerLegacyPassPass(Registry);
- initializeDSELegacyPassPass(Registry);
initializeGuardWideningLegacyPassPass(Registry);
initializeLoopGuardWideningLegacyPassPass(Registry);
initializeGVNLegacyPassPass(Registry);
- initializeNewGVNLegacyPassPass(Registry);
initializeEarlyCSELegacyPassPass(Registry);
initializeEarlyCSEMemSSALegacyPassPass(Registry);
initializeMakeGuardsExplicitLegacyPassPass(Registry);
- initializeGVNHoistLegacyPassPass(Registry);
- initializeGVNSinkLegacyPassPass(Registry);
initializeFlattenCFGLegacyPassPass(Registry);
- initializeIRCELegacyPassPass(Registry);
- initializeIndVarSimplifyLegacyPassPass(Registry);
initializeInferAddressSpacesPass(Registry);
initializeInstSimplifyLegacyPassPass(Registry);
- initializeJumpThreadingPass(Registry);
- initializeDFAJumpThreadingLegacyPassPass(Registry);
initializeLegacyLICMPassPass(Registry);
initializeLegacyLoopSinkPassPass(Registry);
- initializeLoopFuseLegacyPass(Registry);
initializeLoopDataPrefetchLegacyPassPass(Registry);
- initializeLoopDeletionLegacyPassPass(Registry);
- initializeLoopAccessLegacyAnalysisPass(Registry);
initializeLoopInstSimplifyLegacyPassPass(Registry);
- initializeLoopInterchangeLegacyPassPass(Registry);
- initializeLoopFlattenLegacyPassPass(Registry);
initializeLoopPredicationLegacyPassPass(Registry);
initializeLoopRotateLegacyPassPass(Registry);
initializeLoopStrengthReducePass(Registry);
- initializeLoopRerollLegacyPassPass(Registry);
initializeLoopUnrollPass(Registry);
- initializeLoopUnrollAndJamPass(Registry);
- initializeWarnMissedTransformationsLegacyPass(Registry);
- initializeLoopVersioningLICMLegacyPassPass(Registry);
- initializeLoopIdiomRecognizeLegacyPassPass(Registry);
initializeLowerAtomicLegacyPassPass(Registry);
initializeLowerConstantIntrinsicsPass(Registry);
initializeLowerExpectIntrinsicPass(Registry);
initializeLowerGuardIntrinsicLegacyPassPass(Registry);
- initializeLowerMatrixIntrinsicsLegacyPassPass(Registry);
- initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(Registry);
initializeLowerWidenableConditionLegacyPassPass(Registry);
- initializeMemCpyOptLegacyPassPass(Registry);
initializeMergeICmpsLegacyPassPass(Registry);
initializeMergedLoadStoreMotionLegacyPassPass(Registry);
initializeNaryReassociateLegacyPassPass(Registry);
@@ -89,9 +51,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeReassociateLegacyPassPass(Registry);
initializeRedundantDbgInstEliminationPass(Registry);
initializeRegToMemLegacyPass(Registry);
- initializeRewriteStatepointsForGCLegacyPassPass(Registry);
initializeScalarizeMaskedMemIntrinLegacyPassPass(Registry);
- initializeSCCPLegacyPassPass(Registry);
initializeSROALegacyPassPass(Registry);
initializeCFGSimplifyPassPass(Registry);
initializeStructurizeCFGLegacyPassPass(Registry);
@@ -102,196 +62,6 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeSeparateConstOffsetFromGEPLegacyPassPass(Registry);
initializeSpeculativeExecutionLegacyPassPass(Registry);
initializeStraightLineStrengthReduceLegacyPassPass(Registry);
- initializePlaceBackedgeSafepointsImplPass(Registry);
- initializePlaceSafepointsPass(Registry);
- initializeFloat2IntLegacyPassPass(Registry);
- initializeLoopDistributeLegacyPass(Registry);
- initializeLoopLoadEliminationPass(Registry);
+ initializePlaceBackedgeSafepointsLegacyPassPass(Registry);
initializeLoopSimplifyCFGLegacyPassPass(Registry);
- initializeLoopVersioningLegacyPassPass(Registry);
-}
-
-void LLVMAddLoopSimplifyCFGPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopSimplifyCFGPass());
-}
-
-void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) {
- initializeScalarOpts(*unwrap(R));
-}
-
-void LLVMAddAggressiveDCEPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createAggressiveDCEPass());
-}
-
-void LLVMAddDCEPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createDeadCodeEliminationPass());
-}
-
-void LLVMAddBitTrackingDCEPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createBitTrackingDCEPass());
-}
-
-void LLVMAddAlignmentFromAssumptionsPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createAlignmentFromAssumptionsPass());
-}
-
-void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createCFGSimplificationPass());
-}
-
-void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createDeadStoreEliminationPass());
-}
-
-void LLVMAddScalarizerPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createScalarizerPass());
-}
-
-void LLVMAddGVNPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createGVNPass());
-}
-
-void LLVMAddNewGVNPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createNewGVNPass());
-}
-
-void LLVMAddMergedLoadStoreMotionPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createMergedLoadStoreMotionPass());
-}
-
-void LLVMAddIndVarSimplifyPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createIndVarSimplifyPass());
-}
-
-void LLVMAddInstructionSimplifyPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createInstSimplifyLegacyPass());
-}
-
-void LLVMAddJumpThreadingPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createJumpThreadingPass());
-}
-
-void LLVMAddLoopSinkPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopSinkPass());
-}
-
-void LLVMAddLICMPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLICMPass());
-}
-
-void LLVMAddLoopDeletionPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopDeletionPass());
-}
-
-void LLVMAddLoopFlattenPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopFlattenPass());
-}
-
-void LLVMAddLoopIdiomPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopIdiomPass());
-}
-
-void LLVMAddLoopRotatePass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopRotatePass());
-}
-
-void LLVMAddLoopRerollPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopRerollPass());
-}
-
-void LLVMAddLoopUnrollPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopUnrollPass());
-}
-
-void LLVMAddLoopUnrollAndJamPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopUnrollAndJamPass());
-}
-
-void LLVMAddLowerAtomicPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLowerAtomicPass());
-}
-
-void LLVMAddMemCpyOptPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createMemCpyOptPass());
-}
-
-void LLVMAddPartiallyInlineLibCallsPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createPartiallyInlineLibCallsPass());
-}
-
-void LLVMAddReassociatePass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createReassociatePass());
-}
-
-void LLVMAddSCCPPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createSCCPPass());
-}
-
-void LLVMAddScalarReplAggregatesPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createSROAPass());
-}
-
-void LLVMAddScalarReplAggregatesPassSSA(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createSROAPass());
-}
-
-void LLVMAddScalarReplAggregatesPassWithThreshold(LLVMPassManagerRef PM,
- int Threshold) {
- unwrap(PM)->add(createSROAPass());
-}
-
-void LLVMAddSimplifyLibCallsPass(LLVMPassManagerRef PM) {
- // NOTE: The simplify-libcalls pass has been removed.
-}
-
-void LLVMAddTailCallEliminationPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createTailCallEliminationPass());
-}
-
-void LLVMAddDemoteMemoryToRegisterPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createDemoteRegisterToMemoryPass());
-}
-
-void LLVMAddVerifierPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createVerifierPass());
-}
-
-void LLVMAddCorrelatedValuePropagationPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createCorrelatedValuePropagationPass());
-}
-
-void LLVMAddEarlyCSEPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createEarlyCSEPass(false/*=UseMemorySSA*/));
-}
-
-void LLVMAddEarlyCSEMemSSAPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createEarlyCSEPass(true/*=UseMemorySSA*/));
-}
-
-void LLVMAddGVNHoistLegacyPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createGVNHoistPass());
-}
-
-void LLVMAddTypeBasedAliasAnalysisPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createTypeBasedAAWrapperPass());
-}
-
-void LLVMAddScopedNoAliasAAPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createScopedNoAliasAAWrapperPass());
-}
-
-void LLVMAddBasicAliasAnalysisPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createBasicAAWrapperPass());
-}
-
-void LLVMAddLowerConstantIntrinsicsPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLowerConstantIntrinsicsPass());
-}
-
-void LLVMAddLowerExpectIntrinsicPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLowerExpectIntrinsicPass());
-}
-
-void LLVMAddUnifyFunctionExitNodesPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createUnifyFunctionExitNodesPass());
}
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 1c8e4e3512dc..c01d03f64472 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -125,7 +125,7 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
// br label %else
//
// else: ; preds = %0, %cond.load
-// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
+// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
// %6 = extractelement <16 x i1> %mask, i32 1
// br i1 %6, label %cond.load1, label %else2
//
@@ -170,10 +170,6 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
// Adjust alignment for the scalar instruction.
const Align AdjustedAlignVal =
commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
- // Bitcast %addr from i8* to EltTy*
- Type *NewPtrType =
- EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
- Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
// The result vector
@@ -183,7 +179,7 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
continue;
- Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
VResult = Builder.CreateInsertElement(VResult, Load, Idx);
}
@@ -232,7 +228,7 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
CondBlock->setName("cond.load");
Builder.SetInsertPoint(CondBlock->getTerminator());
- Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
@@ -309,10 +305,6 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
// Adjust alignment for the scalar instruction.
const Align AdjustedAlignVal =
commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
- // Bitcast %addr from i8* to EltTy*
- Type *NewPtrType =
- EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
- Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
if (isConstantIntVector(Mask)) {
@@ -320,7 +312,7 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
continue;
Value *OneElt = Builder.CreateExtractElement(Src, Idx);
- Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
}
CI->eraseFromParent();
@@ -367,7 +359,7 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
Builder.SetInsertPoint(CondBlock->getTerminator());
Value *OneElt = Builder.CreateExtractElement(Src, Idx);
- Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
// Create "else" block, fill it in the next iteration
@@ -394,11 +386,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
-// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
+// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
// br label %else
//
// else:
-// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
+// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
@@ -653,16 +645,16 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
Value *VResult = PassThru;
// Shorten the way if the mask is a vector of constants.
- // Create a build_vector pattern, with loads/undefs as necessary and then
+ // Create a build_vector pattern, with loads/poisons as necessary and then
// shuffle blend with the pass through value.
if (isConstantIntVector(Mask)) {
unsigned MemIndex = 0;
VResult = PoisonValue::get(VecType);
- SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
+ SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Value *InsertElt;
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
- InsertElt = UndefValue::get(EltTy);
+ InsertElt = PoisonValue::get(EltTy);
ShuffleMask[Idx] = Idx + VectorWidth;
} else {
Value *NewPtr =
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 4aab88b74f10..86b55dfd304a 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -6,8 +6,9 @@
//
//===----------------------------------------------------------------------===//
//
-// This pass converts vector operations into scalar operations, in order
-// to expose optimization opportunities on the individual scalar operations.
+// This pass converts vector operations into scalar operations (or, optionally,
+// operations on smaller vector widths), in order to expose optimization
+// opportunities on the individual scalar operations.
// It is mainly intended for targets that do not have vector units, but it
// may also be useful for revectorizing code to different vector widths.
//
@@ -62,6 +63,16 @@ static cl::opt<bool> ClScalarizeLoadStore(
"scalarize-load-store", cl::init(false), cl::Hidden,
cl::desc("Allow the scalarizer pass to scalarize loads and store"));
+// Split vectors larger than this size into fragments, where each fragment is
+// either a vector no larger than this size or a scalar.
+//
+// Instructions with operands or results of different sizes that would be split
+// into a different number of fragments are currently left as-is.
+static cl::opt<unsigned> ClScalarizeMinBits(
+ "scalarize-min-bits", cl::init(0), cl::Hidden,
+ cl::desc("Instruct the scalarizer pass to attempt to keep values of a "
+ "minimum number of bits"));
+
namespace {
BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
@@ -88,6 +99,29 @@ using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;
+struct VectorSplit {
+ // The type of the vector.
+ FixedVectorType *VecTy = nullptr;
+
+ // The number of elements packed in a fragment (other than the remainder).
+ unsigned NumPacked = 0;
+
+ // The number of fragments (scalars or smaller vectors) into which the vector
+ // shall be split.
+ unsigned NumFragments = 0;
+
+ // The type of each complete fragment.
+ Type *SplitTy = nullptr;
+
+ // The type of the remainder (last) fragment; null if all fragments are
+ // complete.
+ Type *RemainderTy = nullptr;
+
+ Type *getFragmentType(unsigned I) const {
+ return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy;
+ }
+};
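
For intuition only, here is the split these fields would presumably describe for a <7 x i16> vector with a 32-bit minimum fragment size. getVectorSplit itself is not shown in this hunk, so the arithmetic below is an assumed model of the field semantics, not a statement about its implementation.

    // Assumed example: <7 x i16> split with a 32-bit minimum fragment size.
    #include <cstdio>

    int main() {
      unsigned NumElems = 7, ElemBits = 16, MinBits = 32;
      unsigned NumPacked = MinBits / ElemBits;                        // 2
      unsigned NumFragments = (NumElems + NumPacked - 1) / NumPacked; // 4
      unsigned RemainderElems = NumElems % NumPacked;                 // 1
      std::printf("NumPacked=%u NumFragments=%u\n", NumPacked, NumFragments);
      // SplitTy would be <2 x i16>; the remainder fragment holds one element
      // and, per the "scalars or smaller vectors" comment, is presumably a
      // plain i16 rather than <1 x i16>.
      std::printf("remainder elements=%u\n", RemainderElems);
      return 0;
    }
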
+
// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
@@ -97,23 +131,23 @@ public:
// Scatter V into Size components. If new instructions are needed,
// insert them before BBI in BB. If Cache is nonnull, use it to cache
// the results.
- Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, Type *PtrElemTy,
- ValueVector *cachePtr = nullptr);
+ Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
+ const VectorSplit &VS, ValueVector *cachePtr = nullptr);
// Return component I, creating a new Value for it if necessary.
Value *operator[](unsigned I);
// Return the number of components.
- unsigned size() const { return Size; }
+ unsigned size() const { return VS.NumFragments; }
private:
BasicBlock *BB;
BasicBlock::iterator BBI;
Value *V;
- Type *PtrElemTy;
+ VectorSplit VS;
+ bool IsPointer;
ValueVector *CachePtr;
ValueVector Tmp;
- unsigned Size;
};
// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
@@ -171,24 +205,74 @@ struct BinarySplitter {
struct VectorLayout {
VectorLayout() = default;
- // Return the alignment of element I.
- Align getElemAlign(unsigned I) {
- return commonAlignment(VecAlign, I * ElemSize);
+ // Return the alignment of fragment Frag.
+ Align getFragmentAlign(unsigned Frag) {
+ return commonAlignment(VecAlign, Frag * SplitSize);
}
- // The type of the vector.
- FixedVectorType *VecTy = nullptr;
-
- // The type of each element.
- Type *ElemTy = nullptr;
+ // The split of the underlying vector type.
+ VectorSplit VS;
// The alignment of the vector.
Align VecAlign;
- // The size of each element.
- uint64_t ElemSize = 0;
+ // The size of each (non-remainder) fragment in bytes.
+ uint64_t SplitSize = 0;
};
+/// Concatenate the given fragments to a single vector value of the type
+/// described in @p VS.
+static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
+ const VectorSplit &VS, Twine Name) {
+ unsigned NumElements = VS.VecTy->getNumElements();
+ SmallVector<int> ExtendMask;
+ SmallVector<int> InsertMask;
+
+ if (VS.NumPacked > 1) {
+ // Prepare the shufflevector masks once and re-use them for all
+ // fragments.
+ ExtendMask.resize(NumElements, -1);
+ for (unsigned I = 0; I < VS.NumPacked; ++I)
+ ExtendMask[I] = I;
+
+ InsertMask.resize(NumElements);
+ for (unsigned I = 0; I < NumElements; ++I)
+ InsertMask[I] = I;
+ }
+
+ Value *Res = PoisonValue::get(VS.VecTy);
+ for (unsigned I = 0; I < VS.NumFragments; ++I) {
+ Value *Fragment = Fragments[I];
+
+ unsigned NumPacked = VS.NumPacked;
+ if (I == VS.NumFragments - 1 && VS.RemainderTy) {
+ if (auto *RemVecTy = dyn_cast<FixedVectorType>(VS.RemainderTy))
+ NumPacked = RemVecTy->getNumElements();
+ else
+ NumPacked = 1;
+ }
+
+ if (NumPacked == 1) {
+ Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked,
+ Name + ".upto" + Twine(I));
+ } else {
+ Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask);
+ if (I == 0) {
+ Res = Fragment;
+ } else {
+ for (unsigned J = 0; J < NumPacked; ++J)
+ InsertMask[I * VS.NumPacked + J] = NumElements + J;
+ Res = Builder.CreateShuffleVector(Res, Fragment, InsertMask,
+ Name + ".upto" + Twine(I));
+ for (unsigned J = 0; J < NumPacked; ++J)
+ InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J;
+ }
+ }
+ }
+
+ return Res;
+}
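
The mask bookkeeping in concatenate() can be traced by hand. Below is a standalone sketch (no IR is built, and it is not part of this commit) that prints the InsertMask used for each fragment when NumElements is 6 and NumPacked is 2, i.e. a <6 x T> vector assembled from three <2 x T> fragments.

    // Mask bookkeeping only; mirrors the ExtendMask / InsertMask handling in
    // concatenate() for NumElements = 6, NumPacked = 2, NumFragments = 3.
    #include <cstdio>
    #include <vector>

    int main() {
      unsigned NumElements = 6, NumPacked = 2, NumFragments = 3;
      std::vector<int> ExtendMask(NumElements, -1);
      for (unsigned I = 0; I < NumPacked; ++I)
        ExtendMask[I] = I; // widen a fragment to full width: {0, 1, -1, -1, -1, -1}
      std::vector<int> InsertMask(NumElements);
      for (unsigned I = 0; I < NumElements; ++I)
        InsertMask[I] = I; // identity to start

      for (unsigned I = 1; I < NumFragments; ++I) {
        // Select the first NumPacked lanes of the widened fragment (indices
        // NumElements .. NumElements+NumPacked-1) into positions I*NumPacked ..
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * NumPacked + J] = NumElements + J;
        std::printf("fragment %u insert mask:", I);
        for (int M : InsertMask)
          std::printf(" %d", M);
        std::printf("\n");
        // Restore the identity entries before handling the next fragment.
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * NumPacked + J] = I * NumPacked + J;
      }
      // Output:
      //   fragment 1 insert mask: 0 1 6 7 4 5
      //   fragment 2 insert mask: 0 1 2 3 6 7
      return 0;
    }
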
+
template <typename T>
T getWithDefaultOverride(const cl::opt<T> &ClOption,
const std::optional<T> &DefaultOverride) {
@@ -205,8 +289,9 @@ public:
getWithDefaultOverride(ClScalarizeVariableInsertExtract,
Options.ScalarizeVariableInsertExtract)),
ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore,
- Options.ScalarizeLoadStore)) {
- }
+ Options.ScalarizeLoadStore)),
+ ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits,
+ Options.ScalarizeMinBits)) {}
bool visit(Function &F);
@@ -228,13 +313,15 @@ public:
bool visitLoadInst(LoadInst &LI);
bool visitStoreInst(StoreInst &SI);
bool visitCallInst(CallInst &ICI);
+ bool visitFreezeInst(FreezeInst &FI);
private:
- Scatterer scatter(Instruction *Point, Value *V, Type *PtrElemTy = nullptr);
- void gather(Instruction *Op, const ValueVector &CV);
+ Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS);
+ void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS);
void replaceUses(Instruction *Op, Value *CV);
bool canTransferMetadata(unsigned Kind);
void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
+ std::optional<VectorSplit> getVectorSplit(Type *Ty);
std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
const DataLayout &DL);
bool finish();
@@ -256,6 +343,7 @@ private:
const bool ScalarizeVariableInsertExtract;
const bool ScalarizeLoadStore;
+ const unsigned ScalarizeMinBits;
};
class ScalarizerLegacyPass : public FunctionPass {
@@ -284,42 +372,47 @@ INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
"Scalarize vector operations", false, false)
Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
- Type *PtrElemTy, ValueVector *cachePtr)
- : BB(bb), BBI(bbi), V(v), PtrElemTy(PtrElemTy), CachePtr(cachePtr) {
- Type *Ty = V->getType();
- if (Ty->isPointerTy()) {
- assert(cast<PointerType>(Ty)->isOpaqueOrPointeeTypeMatches(PtrElemTy) &&
- "Pointer element type mismatch");
- Ty = PtrElemTy;
+ const VectorSplit &VS, ValueVector *cachePtr)
+ : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
+ IsPointer = V->getType()->isPointerTy();
+ if (!CachePtr) {
+ Tmp.resize(VS.NumFragments, nullptr);
+ } else {
+ assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() ||
+ IsPointer) &&
+ "Inconsistent vector sizes");
+ if (VS.NumFragments > CachePtr->size())
+ CachePtr->resize(VS.NumFragments, nullptr);
}
- Size = cast<FixedVectorType>(Ty)->getNumElements();
- if (!CachePtr)
- Tmp.resize(Size, nullptr);
- else if (CachePtr->empty())
- CachePtr->resize(Size, nullptr);
- else
- assert(Size == CachePtr->size() && "Inconsistent vector sizes");
}
-// Return component I, creating a new Value for it if necessary.
-Value *Scatterer::operator[](unsigned I) {
- ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
+// Return fragment Frag, creating a new Value for it if necessary.
+Value *Scatterer::operator[](unsigned Frag) {
+ ValueVector &CV = CachePtr ? *CachePtr : Tmp;
// Try to reuse a previous value.
- if (CV[I])
- return CV[I];
+ if (CV[Frag])
+ return CV[Frag];
IRBuilder<> Builder(BB, BBI);
- if (PtrElemTy) {
- Type *VectorElemTy = cast<VectorType>(PtrElemTy)->getElementType();
- if (!CV[0]) {
- Type *NewPtrTy = PointerType::get(
- VectorElemTy, V->getType()->getPointerAddressSpace());
- CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");
- }
- if (I != 0)
- CV[I] = Builder.CreateConstGEP1_32(VectorElemTy, CV[0], I,
- V->getName() + ".i" + Twine(I));
+ if (IsPointer) {
+ if (Frag == 0)
+ CV[Frag] = V;
+ else
+ CV[Frag] = Builder.CreateConstGEP1_32(VS.SplitTy, V, Frag,
+ V->getName() + ".i" + Twine(Frag));
+ return CV[Frag];
+ }
+
+ Type *FragmentTy = VS.getFragmentType(Frag);
+
+ if (auto *VecTy = dyn_cast<FixedVectorType>(FragmentTy)) {
+ SmallVector<int> Mask;
+ for (unsigned J = 0; J < VecTy->getNumElements(); ++J)
+ Mask.push_back(Frag * VS.NumPacked + J);
+ CV[Frag] =
+ Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), Mask,
+ V->getName() + ".i" + Twine(Frag));
} else {
- // Search through a chain of InsertElementInsts looking for element I.
+ // Search through a chain of InsertElementInsts looking for element Frag.
// Record other elements in the cache. The new V is still suitable
// for all uncached indices.
while (true) {
@@ -331,20 +424,23 @@ Value *Scatterer::operator[](unsigned I) {
break;
unsigned J = Idx->getZExtValue();
V = Insert->getOperand(0);
- if (I == J) {
- CV[J] = Insert->getOperand(1);
- return CV[J];
- } else if (!CV[J]) {
+ if (Frag * VS.NumPacked == J) {
+ CV[Frag] = Insert->getOperand(1);
+ return CV[Frag];
+ }
+
+ if (VS.NumPacked == 1 && !CV[J]) {
// Only cache the first entry we find for each index we're not actively
// searching for. This prevents us from going too far up the chain and
// caching incorrect entries.
CV[J] = Insert->getOperand(1);
}
}
- CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I),
- V->getName() + ".i" + Twine(I));
+ CV[Frag] = Builder.CreateExtractElement(V, Frag * VS.NumPacked,
+ V->getName() + ".i" + Twine(Frag));
}
- return CV[I];
+
+ return CV[Frag];
}
bool ScalarizerLegacyPass::runOnFunction(Function &F) {
@@ -386,13 +482,13 @@ bool ScalarizerVisitor::visit(Function &F) {
// Return a scattered form of V that can be accessed by Point. V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
- Type *PtrElemTy) {
+ const VectorSplit &VS) {
if (Argument *VArg = dyn_cast<Argument>(V)) {
// Put the scattered form of arguments in the entry block,
// so that it can be used everywhere.
Function *F = VArg->getParent();
BasicBlock *BB = &F->getEntryBlock();
- return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[{V, PtrElemTy}]);
+ return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]);
}
if (Instruction *VOp = dyn_cast<Instruction>(V)) {
// When scalarizing PHI nodes we might try to examine/rewrite InsertElement
@@ -403,29 +499,30 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
// need to analyse them further.
if (!DT->isReachableFromEntry(VOp->getParent()))
return Scatterer(Point->getParent(), Point->getIterator(),
- PoisonValue::get(V->getType()), PtrElemTy);
+ PoisonValue::get(V->getType()), VS);
// Put the scattered form of an instruction directly after the
// instruction, skipping over PHI nodes and debug intrinsics.
BasicBlock *BB = VOp->getParent();
return Scatterer(
- BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V,
- PtrElemTy, &Scattered[{V, PtrElemTy}]);
+ BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, VS,
+ &Scattered[{V, VS.SplitTy}]);
}
// In the fallback case, just put the scattered before Point and
// keep the result local to Point.
- return Scatterer(Point->getParent(), Point->getIterator(), V, PtrElemTy);
+ return Scatterer(Point->getParent(), Point->getIterator(), V, VS);
}
// Replace Op with the gathered form of the components in CV. Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
-void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
+void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV,
+ const VectorSplit &VS) {
transferMetadataAndIRFlags(Op, CV);
// If we already have a scattered form of Op (created from ExtractElements
// of Op itself), replace them with the new form.
- ValueVector &SV = Scattered[{Op, nullptr}];
+ ValueVector &SV = Scattered[{Op, VS.SplitTy}];
if (!SV.empty()) {
for (unsigned I = 0, E = SV.size(); I != E; ++I) {
Value *V = SV[I];
@@ -483,23 +580,57 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
}
}
+// Determine how Ty is split, if at all.
+std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
+ VectorSplit Split;
+ Split.VecTy = dyn_cast<FixedVectorType>(Ty);
+ if (!Split.VecTy)
+ return {};
+
+ unsigned NumElems = Split.VecTy->getNumElements();
+ Type *ElemTy = Split.VecTy->getElementType();
+
+ if (NumElems == 1 || ElemTy->isPointerTy() ||
+ 2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) {
+ Split.NumPacked = 1;
+ Split.NumFragments = NumElems;
+ Split.SplitTy = ElemTy;
+ } else {
+ Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits();
+ if (Split.NumPacked >= NumElems)
+ return {};
+
+ Split.NumFragments = divideCeil(NumElems, Split.NumPacked);
+ Split.SplitTy = FixedVectorType::get(ElemTy, Split.NumPacked);
+
+ unsigned RemainderElems = NumElems % Split.NumPacked;
+ if (RemainderElems > 1)
+ Split.RemainderTy = FixedVectorType::get(ElemTy, RemainderElems);
+ else if (RemainderElems == 1)
+ Split.RemainderTy = ElemTy;
+ }
+
+ return Split;
+}
+
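
To make the effect of ScalarizeMinBits concrete, here is a standalone re-derivation of just the arithmetic above (plain C++; the struct and function names are invented for this sketch and are not part of the pass, and pointer elements, which the real code always fully scalarizes, are ignored):

    #include <cstdio>

    // Sketch of the fragment arithmetic only: how many elements are packed
    // per fragment and how many fragments result. (The real pass also bails
    // out when the whole vector already fits in a single fragment.)
    struct SplitSketch {
      unsigned NumPacked, NumFragments, RemainderElems;
    };

    static SplitSketch computeSplit(unsigned NumElems, unsigned ElemBits,
                                    unsigned ScalarizeMinBits) {
      if (NumElems == 1 || 2 * ElemBits > ScalarizeMinBits)
        return {1, NumElems, 0}; // fully scalarized
      unsigned NumPacked = ScalarizeMinBits / ElemBits;
      unsigned NumFragments = (NumElems + NumPacked - 1) / NumPacked;
      return {NumPacked, NumFragments, NumElems % NumPacked};
    }

    int main() {
      // e.g. a <6 x i16> value with -scalarize-min-bits=64:
      SplitSketch S = computeSplit(6, 16, 64);
      printf("%u packed per fragment, %u fragments, %u remainder elements\n",
             S.NumPacked, S.NumFragments, S.RemainderElems);
      return 0;
    }

For that example the result is 4 packed elements, 2 fragments, and a 2-element remainder, i.e. a <6 x i16> value becomes a <4 x i16> fragment plus a <2 x i16> remainder rather than six scalars.
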
// Try to fill in Layout from Ty, returning true on success. Alignment is
// the alignment of the vector, or std::nullopt if the ABI default should be
// used.
std::optional<VectorLayout>
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
const DataLayout &DL) {
+ std::optional<VectorSplit> VS = getVectorSplit(Ty);
+ if (!VS)
+ return {};
+
VectorLayout Layout;
- // Make sure we're dealing with a vector.
- Layout.VecTy = dyn_cast<FixedVectorType>(Ty);
- if (!Layout.VecTy)
- return std::nullopt;
- // Check that we're dealing with full-byte elements.
- Layout.ElemTy = Layout.VecTy->getElementType();
- if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy))
- return std::nullopt;
+ Layout.VS = *VS;
+ // Check that we're dealing with full-byte fragments.
+ if (!DL.typeSizeEqualsStoreSize(VS->SplitTy) ||
+ (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(VS->RemainderTy)))
+ return {};
Layout.VecAlign = Alignment;
- Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
+ Layout.SplitSize = DL.getTypeStoreSize(VS->SplitTy);
return Layout;
}
@@ -507,19 +638,27 @@ ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
// to create an instruction like I with operand X and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
- auto *VT = dyn_cast<FixedVectorType>(I.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(I.getType());
+ if (!VS)
return false;
- unsigned NumElems = VT->getNumElements();
+ std::optional<VectorSplit> OpVS;
+ if (I.getOperand(0)->getType() == I.getType()) {
+ OpVS = VS;
+ } else {
+ OpVS = getVectorSplit(I.getOperand(0)->getType());
+ if (!OpVS || VS->NumPacked != OpVS->NumPacked)
+ return false;
+ }
+
IRBuilder<> Builder(&I);
- Scatterer Op = scatter(&I, I.getOperand(0));
- assert(Op.size() == NumElems && "Mismatched unary operation");
+ Scatterer Op = scatter(&I, I.getOperand(0), *OpVS);
+ assert(Op.size() == VS->NumFragments && "Mismatched unary operation");
ValueVector Res;
- Res.resize(NumElems);
- for (unsigned Elem = 0; Elem < NumElems; ++Elem)
- Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem));
- gather(&I, Res);
+ Res.resize(VS->NumFragments);
+ for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag)
+ Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag));
+ gather(&I, Res, *VS);
return true;
}
@@ -527,24 +666,32 @@ bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
// to create an instruction like I with operands X and Y and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
- auto *VT = dyn_cast<FixedVectorType>(I.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(I.getType());
+ if (!VS)
return false;
- unsigned NumElems = VT->getNumElements();
+ std::optional<VectorSplit> OpVS;
+ if (I.getOperand(0)->getType() == I.getType()) {
+ OpVS = VS;
+ } else {
+ OpVS = getVectorSplit(I.getOperand(0)->getType());
+ if (!OpVS || VS->NumPacked != OpVS->NumPacked)
+ return false;
+ }
+
IRBuilder<> Builder(&I);
- Scatterer VOp0 = scatter(&I, I.getOperand(0));
- Scatterer VOp1 = scatter(&I, I.getOperand(1));
- assert(VOp0.size() == NumElems && "Mismatched binary operation");
- assert(VOp1.size() == NumElems && "Mismatched binary operation");
+ Scatterer VOp0 = scatter(&I, I.getOperand(0), *OpVS);
+ Scatterer VOp1 = scatter(&I, I.getOperand(1), *OpVS);
+ assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation");
+ assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation");
ValueVector Res;
- Res.resize(NumElems);
- for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
- Value *Op0 = VOp0[Elem];
- Value *Op1 = VOp1[Elem];
- Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem));
+ Res.resize(VS->NumFragments);
+ for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) {
+ Value *Op0 = VOp0[Frag];
+ Value *Op1 = VOp1[Frag];
+ Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag));
}
- gather(&I, Res);
+ gather(&I, Res, *VS);
return true;
}
@@ -552,18 +699,11 @@ static bool isTriviallyScalariable(Intrinsic::ID ID) {
return isTriviallyVectorizable(ID);
}
-// All of the current scalarizable intrinsics only have one mangled type.
-static Function *getScalarIntrinsicDeclaration(Module *M,
- Intrinsic::ID ID,
- ArrayRef<Type*> Tys) {
- return Intrinsic::getDeclaration(M, ID, Tys);
-}
-
/// If a call to a vector typed intrinsic function, split into a scalar call per
/// element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
- auto *VT = dyn_cast<FixedVectorType>(CI.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
+ if (!VS)
return false;
Function *F = CI.getCalledFunction();
@@ -574,26 +714,41 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
return false;
- unsigned NumElems = VT->getNumElements();
+ // unsigned NumElems = VT->getNumElements();
unsigned NumArgs = CI.arg_size();
ValueVector ScalarOperands(NumArgs);
SmallVector<Scatterer, 8> Scattered(NumArgs);
-
- Scattered.resize(NumArgs);
+ SmallVector<int> OverloadIdx(NumArgs, -1);
SmallVector<llvm::Type *, 3> Tys;
- Tys.push_back(VT->getScalarType());
+ // Add return type if intrinsic is overloaded on it.
+ if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+ Tys.push_back(VS->SplitTy);
// Assumes that any vector type has the same number of elements as the return
// vector type, which is true for all current intrinsics.
for (unsigned I = 0; I != NumArgs; ++I) {
Value *OpI = CI.getOperand(I);
- if (OpI->getType()->isVectorTy()) {
- Scattered[I] = scatter(&CI, OpI);
- assert(Scattered[I].size() == NumElems && "mismatched call operands");
- if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
- Tys.push_back(OpI->getType()->getScalarType());
+ if (auto *OpVecTy = dyn_cast<FixedVectorType>(OpI->getType())) {
+ assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements());
+ std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType());
+ if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
+ // The natural split of the operand doesn't match the result. This could
+ // happen when the operand's element type differs from the result's and
+ // the ScalarizeMinBits option is used.
+ //
+ // We could in principle handle this case as well, at the cost of
+ // complicating the scattering machinery to support multiple scattering
+ // granularities for a single value.
+ return false;
+ }
+
+ Scattered[I] = scatter(&CI, OpI, *OpVS);
+ if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+ OverloadIdx[I] = Tys.size();
+ Tys.push_back(OpVS->SplitTy);
+ }
} else {
ScalarOperands[I] = OpI;
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -601,49 +756,67 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
}
}
- ValueVector Res(NumElems);
+ ValueVector Res(VS->NumFragments);
ValueVector ScalarCallOps(NumArgs);
- Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, Tys);
+ Function *NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys);
IRBuilder<> Builder(&CI);
// Perform actual scalarization, taking care to preserve any scalar operands.
- for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
+ bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy;
ScalarCallOps.clear();
+ if (IsRemainder)
+ Tys[0] = VS->RemainderTy;
+
for (unsigned J = 0; J != NumArgs; ++J) {
- if (isVectorIntrinsicWithScalarOpAtArg(ID, J))
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
ScalarCallOps.push_back(ScalarOperands[J]);
- else
- ScalarCallOps.push_back(Scattered[J][Elem]);
+ } else {
+ ScalarCallOps.push_back(Scattered[J][I]);
+ if (IsRemainder && OverloadIdx[J] >= 0)
+ Tys[OverloadIdx[J]] = Scattered[J][I]->getType();
+ }
}
- Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
- CI.getName() + ".i" + Twine(Elem));
+ if (IsRemainder)
+ NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys);
+
+ Res[I] = Builder.CreateCall(NewIntrin, ScalarCallOps,
+ CI.getName() + ".i" + Twine(I));
}
- gather(&CI, Res);
+ gather(&CI, Res, *VS);
return true;
}
bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
- auto *VT = dyn_cast<FixedVectorType>(SI.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(SI.getType());
+ if (!VS)
return false;
- unsigned NumElems = VT->getNumElements();
+ std::optional<VectorSplit> CondVS;
+ if (isa<FixedVectorType>(SI.getCondition()->getType())) {
+ CondVS = getVectorSplit(SI.getCondition()->getType());
+ if (!CondVS || CondVS->NumPacked != VS->NumPacked) {
+ // This happens when ScalarizeMinBits is used.
+ return false;
+ }
+ }
+
IRBuilder<> Builder(&SI);
- Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
- Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
- assert(VOp1.size() == NumElems && "Mismatched select");
- assert(VOp2.size() == NumElems && "Mismatched select");
+ Scatterer VOp1 = scatter(&SI, SI.getOperand(1), *VS);
+ Scatterer VOp2 = scatter(&SI, SI.getOperand(2), *VS);
+ assert(VOp1.size() == VS->NumFragments && "Mismatched select");
+ assert(VOp2.size() == VS->NumFragments && "Mismatched select");
ValueVector Res;
- Res.resize(NumElems);
+ Res.resize(VS->NumFragments);
- if (SI.getOperand(0)->getType()->isVectorTy()) {
- Scatterer VOp0 = scatter(&SI, SI.getOperand(0));
- assert(VOp0.size() == NumElems && "Mismatched select");
- for (unsigned I = 0; I < NumElems; ++I) {
+ if (CondVS) {
+ Scatterer VOp0 = scatter(&SI, SI.getOperand(0), *CondVS);
+ assert(VOp0.size() == CondVS->NumFragments && "Mismatched select");
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
Value *Op0 = VOp0[I];
Value *Op1 = VOp1[I];
Value *Op2 = VOp2[I];
@@ -652,14 +825,14 @@ bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
}
} else {
Value *Op0 = SI.getOperand(0);
- for (unsigned I = 0; I < NumElems; ++I) {
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
Value *Op1 = VOp1[I];
Value *Op2 = VOp2[I];
Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
SI.getName() + ".i" + Twine(I));
}
}
- gather(&SI, Res);
+ gather(&SI, Res, *VS);
return true;
}
@@ -680,146 +853,194 @@ bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
}
bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
- auto *VT = dyn_cast<FixedVectorType>(GEPI.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(GEPI.getType());
+ if (!VS)
return false;
IRBuilder<> Builder(&GEPI);
- unsigned NumElems = VT->getNumElements();
unsigned NumIndices = GEPI.getNumIndices();
- // The base pointer might be scalar even if it's a vector GEP. In those cases,
- // splat the pointer into a vector value, and scatter that vector.
- Value *Op0 = GEPI.getOperand(0);
- if (!Op0->getType()->isVectorTy())
- Op0 = Builder.CreateVectorSplat(NumElems, Op0);
- Scatterer Base = scatter(&GEPI, Op0);
-
- SmallVector<Scatterer, 8> Ops;
- Ops.resize(NumIndices);
- for (unsigned I = 0; I < NumIndices; ++I) {
- Value *Op = GEPI.getOperand(I + 1);
-
- // The indices might be scalars even if it's a vector GEP. In those cases,
- // splat the scalar into a vector value, and scatter that vector.
- if (!Op->getType()->isVectorTy())
- Op = Builder.CreateVectorSplat(NumElems, Op);
-
- Ops[I] = scatter(&GEPI, Op);
+ // The base pointer and indices might be scalar even if it's a vector GEP.
+ SmallVector<Value *, 8> ScalarOps{1 + NumIndices};
+ SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices};
+
+ for (unsigned I = 0; I < 1 + NumIndices; ++I) {
+ if (auto *VecTy =
+ dyn_cast<FixedVectorType>(GEPI.getOperand(I)->getType())) {
+ std::optional<VectorSplit> OpVS = getVectorSplit(VecTy);
+ if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
+ // This can happen when ScalarizeMinBits is used.
+ return false;
+ }
+ ScatterOps[I] = scatter(&GEPI, GEPI.getOperand(I), *OpVS);
+ } else {
+ ScalarOps[I] = GEPI.getOperand(I);
+ }
}
ValueVector Res;
- Res.resize(NumElems);
- for (unsigned I = 0; I < NumElems; ++I) {
- SmallVector<Value *, 8> Indices;
- Indices.resize(NumIndices);
- for (unsigned J = 0; J < NumIndices; ++J)
- Indices[J] = Ops[J][I];
- Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices,
+ Res.resize(VS->NumFragments);
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
+ SmallVector<Value *, 8> SplitOps;
+ SplitOps.resize(1 + NumIndices);
+ for (unsigned J = 0; J < 1 + NumIndices; ++J) {
+ if (ScalarOps[J])
+ SplitOps[J] = ScalarOps[J];
+ else
+ SplitOps[J] = ScatterOps[J][I];
+ }
+ Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), SplitOps[0],
+ ArrayRef(SplitOps).drop_front(),
GEPI.getName() + ".i" + Twine(I));
if (GEPI.isInBounds())
if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
NewGEPI->setIsInBounds();
}
- gather(&GEPI, Res);
+ gather(&GEPI, Res, *VS);
return true;
}
bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
- auto *VT = dyn_cast<FixedVectorType>(CI.getDestTy());
- if (!VT)
+ std::optional<VectorSplit> DestVS = getVectorSplit(CI.getDestTy());
+ if (!DestVS)
+ return false;
+
+ std::optional<VectorSplit> SrcVS = getVectorSplit(CI.getSrcTy());
+ if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked)
return false;
- unsigned NumElems = VT->getNumElements();
IRBuilder<> Builder(&CI);
- Scatterer Op0 = scatter(&CI, CI.getOperand(0));
- assert(Op0.size() == NumElems && "Mismatched cast");
+ Scatterer Op0 = scatter(&CI, CI.getOperand(0), *SrcVS);
+ assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast");
ValueVector Res;
- Res.resize(NumElems);
- for (unsigned I = 0; I < NumElems; ++I)
- Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(),
- CI.getName() + ".i" + Twine(I));
- gather(&CI, Res);
+ Res.resize(DestVS->NumFragments);
+ for (unsigned I = 0; I < DestVS->NumFragments; ++I)
+ Res[I] =
+ Builder.CreateCast(CI.getOpcode(), Op0[I], DestVS->getFragmentType(I),
+ CI.getName() + ".i" + Twine(I));
+ gather(&CI, Res, *DestVS);
return true;
}
bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
- auto *DstVT = dyn_cast<FixedVectorType>(BCI.getDestTy());
- auto *SrcVT = dyn_cast<FixedVectorType>(BCI.getSrcTy());
- if (!DstVT || !SrcVT)
+ std::optional<VectorSplit> DstVS = getVectorSplit(BCI.getDestTy());
+ std::optional<VectorSplit> SrcVS = getVectorSplit(BCI.getSrcTy());
+ if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy)
return false;
- unsigned DstNumElems = DstVT->getNumElements();
- unsigned SrcNumElems = SrcVT->getNumElements();
+ const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy();
+
+ // Vectors of pointers are always fully scalarized.
+ assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1));
+
IRBuilder<> Builder(&BCI);
- Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
+ Scatterer Op0 = scatter(&BCI, BCI.getOperand(0), *SrcVS);
ValueVector Res;
- Res.resize(DstNumElems);
+ Res.resize(DstVS->NumFragments);
+
+ unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits();
+ unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits();
- if (DstNumElems == SrcNumElems) {
- for (unsigned I = 0; I < DstNumElems; ++I)
- Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(),
+ if (isPointerTy || DstSplitBits == SrcSplitBits) {
+ assert(DstVS->NumFragments == SrcVS->NumFragments);
+ for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
+ Res[I] = Builder.CreateBitCast(Op0[I], DstVS->getFragmentType(I),
BCI.getName() + ".i" + Twine(I));
- } else if (DstNumElems > SrcNumElems) {
- // <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the
- // individual elements to the destination.
- unsigned FanOut = DstNumElems / SrcNumElems;
- auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut);
+ }
+ } else if (SrcSplitBits % DstSplitBits == 0) {
+ // Convert each source fragment to the same-sized destination vector and
+ // then scatter the result to the destination.
+ VectorSplit MidVS;
+ MidVS.NumPacked = DstVS->NumPacked;
+ MidVS.NumFragments = SrcSplitBits / DstSplitBits;
+ MidVS.VecTy = FixedVectorType::get(DstVS->VecTy->getElementType(),
+ MidVS.NumPacked * MidVS.NumFragments);
+ MidVS.SplitTy = DstVS->SplitTy;
+
unsigned ResI = 0;
- for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
- Value *V = Op0[Op0I];
- Instruction *VI;
+ for (unsigned I = 0; I < SrcVS->NumFragments; ++I) {
+ Value *V = Op0[I];
+
// Look through any existing bitcasts before converting to <N x t2>.
// In the best case, the resulting conversion might be a no-op.
+ Instruction *VI;
while ((VI = dyn_cast<Instruction>(V)) &&
VI->getOpcode() == Instruction::BitCast)
V = VI->getOperand(0);
- V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast");
- Scatterer Mid = scatter(&BCI, V);
- for (unsigned MidI = 0; MidI < FanOut; ++MidI)
- Res[ResI++] = Mid[MidI];
+
+ V = Builder.CreateBitCast(V, MidVS.VecTy, V->getName() + ".cast");
+
+ Scatterer Mid = scatter(&BCI, V, MidVS);
+ for (unsigned J = 0; J < MidVS.NumFragments; ++J)
+ Res[ResI++] = Mid[J];
}
- } else {
- // <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2.
- unsigned FanIn = SrcNumElems / DstNumElems;
- auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn);
- unsigned Op0I = 0;
- for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
- Value *V = PoisonValue::get(MidTy);
- for (unsigned MidI = 0; MidI < FanIn; ++MidI)
- V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
- BCI.getName() + ".i" + Twine(ResI)
- + ".upto" + Twine(MidI));
- Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(),
- BCI.getName() + ".i" + Twine(ResI));
+ } else if (DstSplitBits % SrcSplitBits == 0) {
+ // Gather enough source fragments to make up a destination fragment and
+ // then convert to the destination type.
+ VectorSplit MidVS;
+ MidVS.NumFragments = DstSplitBits / SrcSplitBits;
+ MidVS.NumPacked = SrcVS->NumPacked;
+ MidVS.VecTy = FixedVectorType::get(SrcVS->VecTy->getElementType(),
+ MidVS.NumPacked * MidVS.NumFragments);
+ MidVS.SplitTy = SrcVS->SplitTy;
+
+ unsigned SrcI = 0;
+ SmallVector<Value *, 8> ConcatOps;
+ ConcatOps.resize(MidVS.NumFragments);
+ for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
+ for (unsigned J = 0; J < MidVS.NumFragments; ++J)
+ ConcatOps[J] = Op0[SrcI++];
+ Value *V = concatenate(Builder, ConcatOps, MidVS,
+ BCI.getName() + ".i" + Twine(I));
+ Res[I] = Builder.CreateBitCast(V, DstVS->getFragmentType(I),
+ BCI.getName() + ".i" + Twine(I));
}
+ } else {
+ return false;
}
- gather(&BCI, Res);
+
+ gather(&BCI, Res, *DstVS);
return true;
}
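
The three bitcast strategies above are chosen purely from the fragment ("split") bit widths of the two splits. A small standalone table of that decision (plain C++; the bit widths are made-up examples):

    #include <cstdio>

    // Sketch of the strategy selection in visitBitCastInst.
    int main() {
      struct Example { unsigned SrcBits, DstBits; } Cases[] = {
          {64, 64},  // equal: bitcast fragment by fragment
          {128, 32}, // source fans out into 4 destination fragments
          {16, 64},  // 4 source fragments are concatenated per destination
          {48, 32},  // neither width divides the other: not handled
      };
      for (const Example &E : Cases) {
        if (E.SrcBits == E.DstBits)
          printf("%3u -> %3u: elementwise bitcast\n", E.SrcBits, E.DstBits);
        else if (E.SrcBits % E.DstBits == 0)
          printf("%3u -> %3u: fan out, %u dst fragments per src fragment\n",
                 E.SrcBits, E.DstBits, E.SrcBits / E.DstBits);
        else if (E.DstBits % E.SrcBits == 0)
          printf("%3u -> %3u: fan in, concatenate %u src fragments\n",
                 E.SrcBits, E.DstBits, E.DstBits / E.SrcBits);
        else
          printf("%3u -> %3u: bail out\n", E.SrcBits, E.DstBits);
      }
      return 0;
    }
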
bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
- auto *VT = dyn_cast<FixedVectorType>(IEI.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(IEI.getType());
+ if (!VS)
return false;
- unsigned NumElems = VT->getNumElements();
IRBuilder<> Builder(&IEI);
- Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));
+ Scatterer Op0 = scatter(&IEI, IEI.getOperand(0), *VS);
Value *NewElt = IEI.getOperand(1);
Value *InsIdx = IEI.getOperand(2);
ValueVector Res;
- Res.resize(NumElems);
+ Res.resize(VS->NumFragments);
if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
- for (unsigned I = 0; I < NumElems; ++I)
- Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I];
+ unsigned Idx = CI->getZExtValue();
+ unsigned Fragment = Idx / VS->NumPacked;
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
+ if (I == Fragment) {
+ bool IsPacked = VS->NumPacked > 1;
+ if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
+ !VS->RemainderTy->isVectorTy())
+ IsPacked = false;
+ if (IsPacked) {
+ Res[I] =
+ Builder.CreateInsertElement(Op0[I], NewElt, Idx % VS->NumPacked);
+ } else {
+ Res[I] = NewElt;
+ }
+ } else {
+ Res[I] = Op0[I];
+ }
+ }
} else {
- if (!ScalarizeVariableInsertExtract)
+ // Never split a variable insertelement that isn't fully scalarized.
+ if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
return false;
- for (unsigned I = 0; I < NumElems; ++I) {
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
Value *ShouldReplace =
Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
InsIdx->getName() + ".is." + Twine(I));
@@ -829,31 +1050,39 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
}
}
- gather(&IEI, Res);
+ gather(&IEI, Res, *VS);
return true;
}
bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
- auto *VT = dyn_cast<FixedVectorType>(EEI.getOperand(0)->getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
+ if (!VS)
return false;
- unsigned NumSrcElems = VT->getNumElements();
IRBuilder<> Builder(&EEI);
- Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));
+ Scatterer Op0 = scatter(&EEI, EEI.getOperand(0), *VS);
Value *ExtIdx = EEI.getOperand(1);
if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
- Value *Res = Op0[CI->getValue().getZExtValue()];
+ unsigned Idx = CI->getZExtValue();
+ unsigned Fragment = Idx / VS->NumPacked;
+ Value *Res = Op0[Fragment];
+ bool IsPacked = VS->NumPacked > 1;
+ if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
+ !VS->RemainderTy->isVectorTy())
+ IsPacked = false;
+ if (IsPacked)
+ Res = Builder.CreateExtractElement(Res, Idx % VS->NumPacked);
replaceUses(&EEI, Res);
return true;
}
- if (!ScalarizeVariableInsertExtract)
+ // Never split a variable extractelement that isn't fully scalarized.
+ if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
return false;
- Value *Res = PoisonValue::get(VT->getElementType());
- for (unsigned I = 0; I < NumSrcElems; ++I) {
+ Value *Res = PoisonValue::get(VS->VecTy->getElementType());
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
Value *ShouldExtract =
Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
ExtIdx->getName() + ".is." + Twine(I));
@@ -866,51 +1095,52 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
}
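
Both the insertelement and extractelement paths above reduce a constant lane index to a (fragment, lane-within-fragment) pair, with a special case when the trailing remainder is a lone scalar. A standalone sketch of that mapping (plain C++; the vector shape is invented for illustration):

    #include <cstdio>

    // Sketch: reducing a constant lane index to (fragment, lane in fragment).
    // When the trailing remainder is a single scalar there is no inner index.
    int main() {
      const unsigned NumElems = 5, NumPacked = 2; // e.g. <5 x i32>, 64-bit min
      const unsigned NumFragments = (NumElems + NumPacked - 1) / NumPacked;
      const bool ScalarRemainder = (NumElems % NumPacked) == 1;

      for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
        unsigned Fragment = Idx / NumPacked;
        bool IsPacked =
            NumPacked > 1 &&
            !(ScalarRemainder && Fragment == NumFragments - 1);
        if (IsPacked)
          printf("lane %u -> fragment %u, element %u\n", Idx, Fragment,
                 Idx % NumPacked);
        else
          printf("lane %u -> fragment %u (scalar)\n", Idx, Fragment);
      }
      return 0;
    }
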
bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
- auto *VT = dyn_cast<FixedVectorType>(SVI.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(SVI.getType());
+ std::optional<VectorSplit> VSOp =
+ getVectorSplit(SVI.getOperand(0)->getType());
+ if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1)
return false;
- unsigned NumElems = VT->getNumElements();
- Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
- Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
+ Scatterer Op0 = scatter(&SVI, SVI.getOperand(0), *VSOp);
+ Scatterer Op1 = scatter(&SVI, SVI.getOperand(1), *VSOp);
ValueVector Res;
- Res.resize(NumElems);
+ Res.resize(VS->NumFragments);
- for (unsigned I = 0; I < NumElems; ++I) {
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
int Selector = SVI.getMaskValue(I);
if (Selector < 0)
- Res[I] = UndefValue::get(VT->getElementType());
+ Res[I] = PoisonValue::get(VS->VecTy->getElementType());
else if (unsigned(Selector) < Op0.size())
Res[I] = Op0[Selector];
else
Res[I] = Op1[Selector - Op0.size()];
}
- gather(&SVI, Res);
+ gather(&SVI, Res, *VS);
return true;
}
bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
- auto *VT = dyn_cast<FixedVectorType>(PHI.getType());
- if (!VT)
+ std::optional<VectorSplit> VS = getVectorSplit(PHI.getType());
+ if (!VS)
return false;
- unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
IRBuilder<> Builder(&PHI);
ValueVector Res;
- Res.resize(NumElems);
+ Res.resize(VS->NumFragments);
unsigned NumOps = PHI.getNumOperands();
- for (unsigned I = 0; I < NumElems; ++I)
- Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps,
+ for (unsigned I = 0; I < VS->NumFragments; ++I) {
+ Res[I] = Builder.CreatePHI(VS->getFragmentType(I), NumOps,
PHI.getName() + ".i" + Twine(I));
+ }
for (unsigned I = 0; I < NumOps; ++I) {
- Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I));
+ Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I), *VS);
BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
- for (unsigned J = 0; J < NumElems; ++J)
+ for (unsigned J = 0; J < VS->NumFragments; ++J)
cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
}
- gather(&PHI, Res);
+ gather(&PHI, Res, *VS);
return true;
}
@@ -925,17 +1155,17 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
if (!Layout)
return false;
- unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
IRBuilder<> Builder(&LI);
- Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), LI.getType());
+ Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), Layout->VS);
ValueVector Res;
- Res.resize(NumElems);
+ Res.resize(Layout->VS.NumFragments);
- for (unsigned I = 0; I < NumElems; ++I)
- Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I],
- Align(Layout->getElemAlign(I)),
+ for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
+ Res[I] = Builder.CreateAlignedLoad(Layout->VS.getFragmentType(I), Ptr[I],
+ Align(Layout->getFragmentAlign(I)),
LI.getName() + ".i" + Twine(I));
- gather(&LI, Res);
+ }
+ gather(&LI, Res, Layout->VS);
return true;
}
@@ -951,17 +1181,17 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
if (!Layout)
return false;
- unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
IRBuilder<> Builder(&SI);
- Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), FullValue->getType());
- Scatterer VVal = scatter(&SI, FullValue);
+ Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), Layout->VS);
+ Scatterer VVal = scatter(&SI, FullValue, Layout->VS);
ValueVector Stores;
- Stores.resize(NumElems);
- for (unsigned I = 0; I < NumElems; ++I) {
+ Stores.resize(Layout->VS.NumFragments);
+ for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
Value *Val = VVal[I];
Value *Ptr = VPtr[I];
- Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));
+ Stores[I] =
+ Builder.CreateAlignedStore(Val, Ptr, Layout->getFragmentAlign(I));
}
transferMetadataAndIRFlags(&SI, Stores);
return true;
@@ -971,6 +1201,12 @@ bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
return splitCall(CI);
}
+bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) {
+ return splitUnary(FI, [](IRBuilder<> &Builder, Value *Op, const Twine &Name) {
+ return Builder.CreateFreeze(Op, Name);
+ });
+}
+
// Delete the instructions that we scalarized. If a full vector result
// is still needed, recreate it using InsertElements.
bool ScalarizerVisitor::finish() {
@@ -983,17 +1219,19 @@ bool ScalarizerVisitor::finish() {
ValueVector &CV = *GMI.second;
if (!Op->use_empty()) {
// The value is still needed, so recreate it using a series of
- // InsertElements.
- Value *Res = PoisonValue::get(Op->getType());
+ // insertelements and/or shufflevectors.
+ Value *Res;
if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
BasicBlock *BB = Op->getParent();
- unsigned Count = Ty->getNumElements();
IRBuilder<> Builder(Op);
if (isa<PHINode>(Op))
Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
- for (unsigned I = 0; I < Count; ++I)
- Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
- Op->getName() + ".upto" + Twine(I));
+
+ VectorSplit VS = *getVectorSplit(Ty);
+ assert(VS.NumFragments == CV.size());
+
+ Res = concatenate(Builder, CV, VS, Op->getName());
+
Res->takeName(Op);
} else {
assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index 4fb90bcea4f0..89d0b7c33e0d 100644
--- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -162,7 +162,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
-#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -355,7 +354,6 @@ public:
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.setPreservesCFG();
@@ -374,14 +372,23 @@ private:
class SeparateConstOffsetFromGEP {
public:
SeparateConstOffsetFromGEP(
- DominatorTree *DT, ScalarEvolution *SE, LoopInfo *LI,
- TargetLibraryInfo *TLI,
+ DominatorTree *DT, LoopInfo *LI, TargetLibraryInfo *TLI,
function_ref<TargetTransformInfo &(Function &)> GetTTI, bool LowerGEP)
- : DT(DT), SE(SE), LI(LI), TLI(TLI), GetTTI(GetTTI), LowerGEP(LowerGEP) {}
+ : DT(DT), LI(LI), TLI(TLI), GetTTI(GetTTI), LowerGEP(LowerGEP) {}
bool run(Function &F);
private:
+ /// Track the operands of an add or sub.
+ using ExprKey = std::pair<Value *, Value *>;
+
+ /// Create a pair for use as a map key for a commutable operation.
+ static ExprKey createNormalizedCommutablePair(Value *A, Value *B) {
+ if (A < B)
+ return {A, B};
+ return {B, A};
+ }
+
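
The pair key replaces the SCEV add-expression that previously provided commutativity for free, so adds must normalize the operand order by hand while subs must not. A minimal standalone illustration (plain C++, with std::map standing in for DenseMap; all names here are invented for the sketch):

    #include <cassert>
    #include <map>
    #include <utility>
    #include <vector>

    // Sketch: ordered pointer pairs make the add table insensitive to operand
    // order, while sub keys stay ordered because a-b and b-a differ.
    using ExprKey = std::pair<void *, void *>;

    static ExprKey normalizedCommutablePair(void *A, void *B) {
      return A < B ? ExprKey{A, B} : ExprKey{B, A};
    }

    int main() {
      int X = 0, Y = 0; // stand-ins for llvm::Value objects
      std::map<ExprKey, std::vector<int>> DominatingAdds, DominatingSubs;

      DominatingAdds[normalizedCommutablePair(&X, &Y)].push_back(1);
      // Looking up the operands in the opposite order hits the same entry.
      assert(DominatingAdds.count(normalizedCommutablePair(&Y, &X)) == 1);

      DominatingSubs[{&X, &Y}].push_back(2);
      // x - y and y - x must not collide, so sub keys are left as-is.
      assert(DominatingSubs.count({&Y, &X}) == 0);
      return 0;
    }
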
/// Tries to split the given GEP into a variadic base and a constant offset,
/// and returns true if the splitting succeeds.
bool splitGEP(GetElementPtrInst *GEP);
@@ -428,7 +435,7 @@ private:
/// Returns true if the module changes.
///
/// Verified in @i32_add in split-gep.ll
- bool canonicalizeArrayIndicesToPointerSize(GetElementPtrInst *GEP);
+ bool canonicalizeArrayIndicesToIndexSize(GetElementPtrInst *GEP);
/// Optimize sext(a)+sext(b) to sext(a+b) when a+b can't sign overflow.
/// SeparateConstOffsetFromGEP distributes a sext to leaves before extracting
@@ -446,8 +453,8 @@ private:
/// Find the closest dominator of <Dominatee> that is equivalent to <Key>.
Instruction *findClosestMatchingDominator(
- const SCEV *Key, Instruction *Dominatee,
- DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs);
+ ExprKey Key, Instruction *Dominatee,
+ DenseMap<ExprKey, SmallVector<Instruction *, 2>> &DominatingExprs);
/// Verify F is free of dead code.
void verifyNoDeadCode(Function &F);
@@ -463,7 +470,6 @@ private:
const DataLayout *DL = nullptr;
DominatorTree *DT = nullptr;
- ScalarEvolution *SE;
LoopInfo *LI;
TargetLibraryInfo *TLI;
// Retrieved lazily since not always used.
@@ -473,8 +479,8 @@ private:
/// multiple GEPs with a single index.
bool LowerGEP;
- DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingAdds;
- DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingSubs;
+ DenseMap<ExprKey, SmallVector<Instruction *, 2>> DominatingAdds;
+ DenseMap<ExprKey, SmallVector<Instruction *, 2>> DominatingSubs;
};
} // end anonymous namespace
@@ -521,6 +527,12 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
!haveNoCommonBitsSet(LHS, RHS, DL, nullptr, BO, DT))
return false;
+ // FIXME: We don't currently support constants from the RHS of subs when
+ // zero-extended, because we need a way to zero-extend them before they
+ // are negated.
+ if (ZeroExtended && !SignExtended && BO->getOpcode() == Instruction::Sub)
+ return false;
+
// In addition, tracing into BO requires that its surrounding s/zext (if
// any) is distributable to both operands.
//
@@ -791,17 +803,17 @@ int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP,
.getSExtValue();
}
-bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToPointerSize(
+bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToIndexSize(
GetElementPtrInst *GEP) {
bool Changed = false;
- Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
+ Type *PtrIdxTy = DL->getIndexType(GEP->getType());
gep_type_iterator GTI = gep_type_begin(*GEP);
for (User::op_iterator I = GEP->op_begin() + 1, E = GEP->op_end();
I != E; ++I, ++GTI) {
// Skip struct member indices which must be i32.
if (GTI.isSequential()) {
- if ((*I)->getType() != IntPtrTy) {
- *I = CastInst::CreateIntegerCast(*I, IntPtrTy, true, "idxprom", GEP);
+ if ((*I)->getType() != PtrIdxTy) {
+ *I = CastInst::CreateIntegerCast(*I, PtrIdxTy, true, "idxprom", GEP);
Changed = true;
}
}
@@ -849,10 +861,8 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP,
void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs(
GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset) {
IRBuilder<> Builder(Variadic);
- Type *IntPtrTy = DL->getIntPtrType(Variadic->getType());
+ Type *PtrIndexTy = DL->getIndexType(Variadic->getType());
- Type *I8PtrTy =
- Builder.getInt8PtrTy(Variadic->getType()->getPointerAddressSpace());
Value *ResultPtr = Variadic->getOperand(0);
Loop *L = LI->getLoopFor(Variadic->getParent());
// Check if the base is not loop invariant or used more than once.
@@ -861,9 +871,6 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs(
!hasMoreThanOneUseInLoop(ResultPtr, L);
Value *FirstResult = nullptr;
- if (ResultPtr->getType() != I8PtrTy)
- ResultPtr = Builder.CreateBitCast(ResultPtr, I8PtrTy);
-
gep_type_iterator GTI = gep_type_begin(*Variadic);
// Create an ugly GEP for each sequential index. We don't create GEPs for
// structure indices, as they are accumulated in the constant offset index.
@@ -875,15 +882,16 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs(
if (CI->isZero())
continue;
- APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(),
+ APInt ElementSize = APInt(PtrIndexTy->getIntegerBitWidth(),
DL->getTypeAllocSize(GTI.getIndexedType()));
// Scale the index by element size.
if (ElementSize != 1) {
if (ElementSize.isPowerOf2()) {
Idx = Builder.CreateShl(
- Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2()));
+ Idx, ConstantInt::get(PtrIndexTy, ElementSize.logBase2()));
} else {
- Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize));
+ Idx =
+ Builder.CreateMul(Idx, ConstantInt::get(PtrIndexTy, ElementSize));
}
}
// Create an ugly GEP with a single index for each index.
@@ -896,7 +904,7 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs(
// Create a GEP with the constant offset index.
if (AccumulativeByteOffset != 0) {
- Value *Offset = ConstantInt::get(IntPtrTy, AccumulativeByteOffset);
+ Value *Offset = ConstantInt::get(PtrIndexTy, AccumulativeByteOffset);
ResultPtr =
Builder.CreateGEP(Builder.getInt8Ty(), ResultPtr, Offset, "uglygep");
} else
@@ -910,9 +918,6 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs(
if (isSwapCandidate && isLegalToSwapOperand(FirstGEP, SecondGEP, L))
swapGEPOperand(FirstGEP, SecondGEP);
- if (ResultPtr->getType() != Variadic->getType())
- ResultPtr = Builder.CreateBitCast(ResultPtr, Variadic->getType());
-
Variadic->replaceAllUsesWith(ResultPtr);
Variadic->eraseFromParent();
}
@@ -922,6 +927,9 @@ SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic,
int64_t AccumulativeByteOffset) {
IRBuilder<> Builder(Variadic);
Type *IntPtrTy = DL->getIntPtrType(Variadic->getType());
+ assert(IntPtrTy == DL->getIndexType(Variadic->getType()) &&
+ "Pointer type must match index type for arithmetic-based lowering of "
+ "split GEPs");
Value *ResultPtr = Builder.CreatePtrToInt(Variadic->getOperand(0), IntPtrTy);
gep_type_iterator GTI = gep_type_begin(*Variadic);
@@ -973,7 +981,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
if (GEP->hasAllConstantIndices())
return false;
- bool Changed = canonicalizeArrayIndicesToPointerSize(GEP);
+ bool Changed = canonicalizeArrayIndicesToIndexSize(GEP);
bool NeedsExtraction;
int64_t AccumulativeByteOffset = accumulateByteOffset(GEP, NeedsExtraction);
@@ -1057,7 +1065,15 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
if (LowerGEP) {
// As currently BasicAA does not analyze ptrtoint/inttoptr, do not lower to
// arithmetic operations if the target uses alias analysis in codegen.
- if (TTI.useAA())
+ // Additionally, pointers that aren't integral (and so can't be safely
+ // converted to integers) or those whose offset size is different from their
+ // pointer size (which means that doing integer arithmetic on them could
+ // affect that data) can't be lowered in this way.
+ unsigned AddrSpace = GEP->getPointerAddressSpace();
+ bool PointerHasExtraData = DL->getPointerSizeInBits(AddrSpace) !=
+ DL->getIndexSizeInBits(AddrSpace);
+ if (TTI.useAA() || DL->isNonIntegralAddressSpace(AddrSpace) ||
+ PointerHasExtraData)
lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset);
else
lowerToArithmetics(GEP, AccumulativeByteOffset);
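
The choice between the two lowering strategies can be summarized with a small standalone sketch (plain C++; the address-space properties are made up for illustration and are not taken from any real target):

    #include <cstdio>

    // Sketch of the decision above: a GEP is only rewritten into
    // ptrtoint/add/inttoptr arithmetic when the pointer is an ordinary
    // integer whose width matches its index type and the target does not
    // rely on alias analysis in codegen.
    int main() {
      struct AddrSpace {
        const char *Desc;
        unsigned PointerBits, IndexBits;
        bool NonIntegral, TargetUsesAA;
      } Spaces[] = {
          {"plain 64-bit pointers", 64, 64, false, false},
          {"non-integral GC pointers", 64, 64, true, false},
          {"160-bit fat pointers, 32-bit index", 160, 32, false, false},
      };
      for (const AddrSpace &AS : Spaces) {
        bool PointerHasExtraData = AS.PointerBits != AS.IndexBits;
        bool KeepPointerForm =
            AS.TargetUsesAA || AS.NonIntegral || PointerHasExtraData;
        printf("%s: %s\n", AS.Desc,
               KeepPointerForm ? "lowerToSingleIndexGEPs"
                               : "lowerToArithmetics");
      }
      return 0;
    }
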
@@ -1104,13 +1120,13 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
// used with unsigned integers later.
int64_t ElementTypeSizeOfGEP = static_cast<int64_t>(
DL->getTypeAllocSize(GEP->getResultElementType()));
- Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
+ Type *PtrIdxTy = DL->getIndexType(GEP->getType());
if (AccumulativeByteOffset % ElementTypeSizeOfGEP == 0) {
// Very likely. As long as %gep is naturally aligned, the byte offset we
// extracted should be a multiple of sizeof(*%gep).
int64_t Index = AccumulativeByteOffset / ElementTypeSizeOfGEP;
NewGEP = GetElementPtrInst::Create(GEP->getResultElementType(), NewGEP,
- ConstantInt::get(IntPtrTy, Index, true),
+ ConstantInt::get(PtrIdxTy, Index, true),
GEP->getName(), GEP);
NewGEP->copyMetadata(*GEP);
// Inherit the inbounds attribute of the original GEP.
@@ -1131,16 +1147,11 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
//
// Emit an uglygep in this case.
IRBuilder<> Builder(GEP);
- Type *I8PtrTy =
- Builder.getInt8Ty()->getPointerTo(GEP->getPointerAddressSpace());
-
NewGEP = cast<Instruction>(Builder.CreateGEP(
- Builder.getInt8Ty(), Builder.CreateBitCast(NewGEP, I8PtrTy),
- {ConstantInt::get(IntPtrTy, AccumulativeByteOffset, true)}, "uglygep",
+ Builder.getInt8Ty(), NewGEP,
+ {ConstantInt::get(PtrIdxTy, AccumulativeByteOffset, true)}, "uglygep",
GEPWasInBounds));
-
NewGEP->copyMetadata(*GEP);
- NewGEP = cast<Instruction>(Builder.CreateBitCast(NewGEP, GEP->getType()));
}
GEP->replaceAllUsesWith(NewGEP);
@@ -1153,13 +1164,12 @@ bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
};
- SeparateConstOffsetFromGEP Impl(DT, SE, LI, TLI, GetTTI, LowerGEP);
+ SeparateConstOffsetFromGEP Impl(DT, LI, TLI, GetTTI, LowerGEP);
return Impl.run(F);
}
@@ -1189,8 +1199,8 @@ bool SeparateConstOffsetFromGEP::run(Function &F) {
}
Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator(
- const SCEV *Key, Instruction *Dominatee,
- DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs) {
+ ExprKey Key, Instruction *Dominatee,
+ DenseMap<ExprKey, SmallVector<Instruction *, 2>> &DominatingExprs) {
auto Pos = DominatingExprs.find(Key);
if (Pos == DominatingExprs.end())
return nullptr;
@@ -1210,7 +1220,7 @@ Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator(
}
bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {
- if (!SE->isSCEVable(I->getType()))
+ if (!I->getType()->isIntOrIntVectorTy())
return false;
// Dom: LHS+RHS
@@ -1220,8 +1230,7 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {
Value *LHS = nullptr, *RHS = nullptr;
if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) {
if (LHS->getType() == RHS->getType()) {
- const SCEV *Key =
- SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
+ ExprKey Key = createNormalizedCommutablePair(LHS, RHS);
if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingAdds)) {
Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I);
NewSExt->takeName(I);
@@ -1232,9 +1241,8 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {
}
} else if (match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) {
if (LHS->getType() == RHS->getType()) {
- const SCEV *Key =
- SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
- if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingSubs)) {
+ if (auto *Dom =
+ findClosestMatchingDominator({LHS, RHS}, I, DominatingSubs)) {
Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I);
NewSExt->takeName(I);
I->replaceAllUsesWith(NewSExt);
@@ -1247,16 +1255,12 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {
// Add I to DominatingExprs if it's an add/sub that can't sign overflow.
if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS)))) {
if (programUndefinedIfPoison(I)) {
- const SCEV *Key =
- SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
+ ExprKey Key = createNormalizedCommutablePair(LHS, RHS);
DominatingAdds[Key].push_back(I);
}
} else if (match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) {
- if (programUndefinedIfPoison(I)) {
- const SCEV *Key =
- SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
- DominatingSubs[Key].push_back(I);
- }
+ if (programUndefinedIfPoison(I))
+ DominatingSubs[{LHS, RHS}].push_back(I);
}
return false;
}
@@ -1376,16 +1380,25 @@ void SeparateConstOffsetFromGEP::swapGEPOperand(GetElementPtrInst *First,
First->setIsInBounds(true);
}
+void SeparateConstOffsetFromGEPPass::printPipeline(
+ raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
+ static_cast<PassInfoMixin<SeparateConstOffsetFromGEPPass> *>(this)
+ ->printPipeline(OS, MapClassName2PassName);
+ OS << '<';
+ if (LowerGEP)
+ OS << "lower-gep";
+ OS << '>';
+}
+
PreservedAnalyses
SeparateConstOffsetFromGEPPass::run(Function &F, FunctionAnalysisManager &AM) {
auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
- auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
auto *LI = &AM.getResult<LoopAnalysis>(F);
auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
auto GetTTI = [&AM](Function &F) -> TargetTransformInfo & {
return AM.getResult<TargetIRAnalysis>(F);
};
- SeparateConstOffsetFromGEP Impl(DT, SE, LI, TLI, GetTTI, LowerGEP);
+ SeparateConstOffsetFromGEP Impl(DT, LI, TLI, GetTTI, LowerGEP);
if (!Impl.run(F))
return PreservedAnalyses::all();
PreservedAnalyses PA;
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index 7e08120f923d..ad7d34b61470 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -19,6 +19,7 @@
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
@@ -42,6 +43,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
@@ -73,11 +75,14 @@ using namespace llvm::PatternMatch;
STATISTIC(NumBranches, "Number of branches unswitched");
STATISTIC(NumSwitches, "Number of switches unswitched");
+STATISTIC(NumSelects, "Number of selects turned into branches for unswitching");
STATISTIC(NumGuards, "Number of guards turned into branches for unswitching");
STATISTIC(NumTrivial, "Number of unswitches that are trivial");
STATISTIC(
NumCostMultiplierSkipped,
"Number of unswitch candidates that had their cost multiplier skipped");
+STATISTIC(NumInvariantConditionsInjected,
+ "Number of invariant conditions injected and unswitched");
static cl::opt<bool> EnableNonTrivialUnswitch(
"enable-nontrivial-unswitch", cl::init(false), cl::Hidden,
@@ -118,15 +123,53 @@ static cl::opt<bool> FreezeLoopUnswitchCond(
cl::desc("If enabled, the freeze instruction will be added to condition "
"of loop unswitch to prevent miscompilation."));
+static cl::opt<bool> InjectInvariantConditions(
+ "simple-loop-unswitch-inject-invariant-conditions", cl::Hidden,
+ cl::desc("Whether we should inject new invariants and unswitch them to "
+ "eliminate some existing (non-invariant) conditions."),
+ cl::init(true));
+
+static cl::opt<unsigned> InjectInvariantConditionHotnesThreshold(
+ "simple-loop-unswitch-inject-invariant-condition-hotness-threshold",
+ cl::Hidden, cl::desc("Only try to inject loop invariant conditions and "
+ "unswitch on them to eliminate branches that are "
+ "not-taken 1/<this option> times or less."),
+ cl::init(16));
+
namespace {
+struct CompareDesc {
+ BranchInst *Term;
+ Value *Invariant;
+ BasicBlock *InLoopSucc;
+
+ CompareDesc(BranchInst *Term, Value *Invariant, BasicBlock *InLoopSucc)
+ : Term(Term), Invariant(Invariant), InLoopSucc(InLoopSucc) {}
+};
+
+struct InjectedInvariant {
+ ICmpInst::Predicate Pred;
+ Value *LHS;
+ Value *RHS;
+ BasicBlock *InLoopSucc;
+
+ InjectedInvariant(ICmpInst::Predicate Pred, Value *LHS, Value *RHS,
+ BasicBlock *InLoopSucc)
+ : Pred(Pred), LHS(LHS), RHS(RHS), InLoopSucc(InLoopSucc) {}
+};
+
struct NonTrivialUnswitchCandidate {
Instruction *TI = nullptr;
TinyPtrVector<Value *> Invariants;
std::optional<InstructionCost> Cost;
+ std::optional<InjectedInvariant> PendingInjection;
NonTrivialUnswitchCandidate(
Instruction *TI, ArrayRef<Value *> Invariants,
- std::optional<InstructionCost> Cost = std::nullopt)
- : TI(TI), Invariants(Invariants), Cost(Cost){};
+ std::optional<InstructionCost> Cost = std::nullopt,
+ std::optional<InjectedInvariant> PendingInjection = std::nullopt)
+ : TI(TI), Invariants(Invariants), Cost(Cost),
+ PendingInjection(PendingInjection) {};
+
+ bool hasPendingInjection() const { return PendingInjection.has_value(); }
};
} // end anonymous namespace.
@@ -434,10 +477,10 @@ static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader,
// Return the top-most loop containing ExitBB and having ExitBB as exiting block
// or the loop containing ExitBB, if there is no parent loop containing ExitBB
// as exiting block.
-static const Loop *getTopMostExitingLoop(const BasicBlock *ExitBB,
- const LoopInfo &LI) {
- const Loop *TopMost = LI.getLoopFor(ExitBB);
- const Loop *Current = TopMost;
+static Loop *getTopMostExitingLoop(const BasicBlock *ExitBB,
+ const LoopInfo &LI) {
+ Loop *TopMost = LI.getLoopFor(ExitBB);
+ Loop *Current = TopMost;
while (Current) {
if (Current->isLoopExiting(ExitBB))
TopMost = Current;
@@ -750,15 +793,32 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
Loop *OuterL = &L;
if (DefaultExitBB) {
- // Clear out the default destination temporarily to allow accurate
- // predecessor lists to be examined below.
- SI.setDefaultDest(nullptr);
// Check the loop containing this exit.
- Loop *ExitL = LI.getLoopFor(DefaultExitBB);
+ Loop *ExitL = getTopMostExitingLoop(DefaultExitBB, LI);
+ if (!ExitL || ExitL->contains(OuterL))
+ OuterL = ExitL;
+ }
+ for (unsigned Index : ExitCaseIndices) {
+ auto CaseI = SI.case_begin() + Index;
+ // Compute the outer loop from this exit.
+ Loop *ExitL = getTopMostExitingLoop(CaseI->getCaseSuccessor(), LI);
if (!ExitL || ExitL->contains(OuterL))
OuterL = ExitL;
}
+ if (SE) {
+ if (OuterL)
+ SE->forgetLoop(OuterL);
+ else
+ SE->forgetTopmostLoop(&L);
+ }
+
+ if (DefaultExitBB) {
+ // Clear out the default destination temporarily to allow accurate
+ // predecessor lists to be examined below.
+ SI.setDefaultDest(nullptr);
+ }
+
// Store the exit cases into a separate data structure and remove them from
// the switch.
SmallVector<std::tuple<ConstantInt *, BasicBlock *,
@@ -770,10 +830,6 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
// and don't disrupt the earlier indices.
for (unsigned Index : reverse(ExitCaseIndices)) {
auto CaseI = SI.case_begin() + Index;
- // Compute the outer loop from this exit.
- Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor());
- if (!ExitL || ExitL->contains(OuterL))
- OuterL = ExitL;
// Save the value of this case.
auto W = SIW.getSuccessorWeight(CaseI->getSuccessorIndex());
ExitCases.emplace_back(CaseI->getCaseValue(), CaseI->getCaseSuccessor(), W);
@@ -781,13 +837,6 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
SIW.removeCase(CaseI);
}
- if (SE) {
- if (OuterL)
- SE->forgetLoop(OuterL);
- else
- SE->forgetTopmostLoop(&L);
- }
-
// Check if after this all of the remaining cases point at the same
// successor.
BasicBlock *CommonSuccBB = nullptr;
@@ -2079,7 +2128,7 @@ static void unswitchNontrivialInvariants(
AssumptionCache &AC,
function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
- function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+ function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze) {
auto *ParentBB = TI.getParent();
BranchInst *BI = dyn_cast<BranchInst>(&TI);
SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
@@ -2160,7 +2209,9 @@ static void unswitchNontrivialInvariants(
SmallVector<BasicBlock *, 4> ExitBlocks;
L.getUniqueExitBlocks(ExitBlocks);
for (auto *ExitBB : ExitBlocks) {
- Loop *NewOuterExitL = LI.getLoopFor(ExitBB);
+ // ExitBB can be an exit block for several levels in the loop nest. Make
+ // sure we find the top most.
+ Loop *NewOuterExitL = getTopMostExitingLoop(ExitBB, LI);
if (!NewOuterExitL) {
// We exited the entire nest with this block, so we're done.
OuterExitL = nullptr;
@@ -2181,25 +2232,6 @@ static void unswitchNontrivialInvariants(
SE->forgetBlockAndLoopDispositions();
}
- bool InsertFreeze = false;
- if (FreezeLoopUnswitchCond) {
- ICFLoopSafetyInfo SafetyInfo;
- SafetyInfo.computeLoopSafetyInfo(&L);
- InsertFreeze = !SafetyInfo.isGuaranteedToExecute(TI, &DT, &L);
- }
-
- // Perform the isGuaranteedNotToBeUndefOrPoison() query before the transform,
- // otherwise the branch instruction will have been moved outside the loop
- // already, and may imply that a poison condition is always UB.
- Value *FullUnswitchCond = nullptr;
- if (FullUnswitch) {
- FullUnswitchCond =
- BI ? skipTrivialSelect(BI->getCondition()) : SI->getCondition();
- if (InsertFreeze)
- InsertFreeze = !isGuaranteedNotToBeUndefOrPoison(
- FullUnswitchCond, &AC, L.getLoopPreheader()->getTerminator(), &DT);
- }
-
// If the edge from this terminator to a successor dominates that successor,
// store a map from each block in its dominator subtree to it. This lets us
// tell when cloning for a particular successor if a block is dominated by
@@ -2274,10 +2306,11 @@ static void unswitchNontrivialInvariants(
BasicBlock *ClonedPH = ClonedPHs.begin()->second;
BI->setSuccessor(ClonedSucc, ClonedPH);
BI->setSuccessor(1 - ClonedSucc, LoopPH);
+ Value *Cond = skipTrivialSelect(BI->getCondition());
if (InsertFreeze)
- FullUnswitchCond = new FreezeInst(
- FullUnswitchCond, FullUnswitchCond->getName() + ".fr", BI);
- BI->setCondition(FullUnswitchCond);
+ Cond = new FreezeInst(
+ Cond, Cond->getName() + ".fr", BI);
+ BI->setCondition(Cond);
DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH});
} else {
assert(SI && "Must either be a branch or switch!");
@@ -2294,7 +2327,7 @@ static void unswitchNontrivialInvariants(
if (InsertFreeze)
SI->setCondition(new FreezeInst(
- FullUnswitchCond, FullUnswitchCond->getName() + ".fr", SI));
+ SI->getCondition(), SI->getCondition()->getName() + ".fr", SI));
// We need to use the set to populate domtree updates as even when there
// are multiple cases pointing at the same successor we only want to
@@ -2593,6 +2626,57 @@ static InstructionCost computeDomSubtreeCost(
return Cost;
}
+/// Turns a select instruction into implicit control flow branch,
+/// making the following replacement:
+///
+/// head:
+/// --code before select--
+/// select %cond, %trueval, %falseval
+/// --code after select--
+///
+/// into
+///
+/// head:
+/// --code before select--
+/// br i1 %cond, label %then, label %tail
+///
+/// then:
+/// br %tail
+///
+/// tail:
+/// phi [ %trueval, %then ], [ %falseval, %head]
+///   --code after select--
+///
+/// It also makes all relevant DT and LI updates, so that all structures are in
+/// valid state after this transform.
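+///
+/// Illustrative sketch (names hypothetical): a loop-invariant select such as
+///
+///   %v = select i1 %inv, i32 %a, i32 %b
+///
+/// becomes a branch on %inv feeding a phi, which the non-trivial unswitching
+/// logic can then hoist out of the loop like any other invariant branch.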
+static BranchInst *turnSelectIntoBranch(SelectInst *SI, DominatorTree &DT,
+ LoopInfo &LI, MemorySSAUpdater *MSSAU,
+ AssumptionCache *AC) {
+ LLVM_DEBUG(dbgs() << "Turning " << *SI << " into a branch.\n");
+ BasicBlock *HeadBB = SI->getParent();
+
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+ SplitBlockAndInsertIfThen(SI->getCondition(), SI, false,
+ SI->getMetadata(LLVMContext::MD_prof), &DTU, &LI);
+ auto *CondBr = cast<BranchInst>(HeadBB->getTerminator());
+ BasicBlock *ThenBB = CondBr->getSuccessor(0),
+ *TailBB = CondBr->getSuccessor(1);
+ if (MSSAU)
+ MSSAU->moveAllAfterSpliceBlocks(HeadBB, TailBB, SI);
+
+ PHINode *Phi = PHINode::Create(SI->getType(), 2, "unswitched.select", SI);
+ Phi->addIncoming(SI->getTrueValue(), ThenBB);
+ Phi->addIncoming(SI->getFalseValue(), HeadBB);
+ SI->replaceAllUsesWith(Phi);
+ SI->eraseFromParent();
+
+ if (MSSAU && VerifyMemorySSA)
+ MSSAU->getMemorySSA()->verifyMemorySSA();
+
+ ++NumSelects;
+ return CondBr;
+}
+
/// Turns a llvm.experimental.guard intrinsic into implicit control flow branch,
/// making the following replacement:
///
@@ -2624,15 +2708,10 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L,
if (MSSAU && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
- // Remove all CheckBB's successors from DomTree. A block can be seen among
- // successors more than once, but for DomTree it should be added only once.
- SmallPtrSet<BasicBlock *, 4> Successors;
- for (auto *Succ : successors(CheckBB))
- if (Successors.insert(Succ).second)
- DTUpdates.push_back({DominatorTree::Delete, CheckBB, Succ});
-
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
Instruction *DeoptBlockTerm =
- SplitBlockAndInsertIfThen(GI->getArgOperand(0), GI, true);
+ SplitBlockAndInsertIfThen(GI->getArgOperand(0), GI, true,
+ GI->getMetadata(LLVMContext::MD_prof), &DTU, &LI);
BranchInst *CheckBI = cast<BranchInst>(CheckBB->getTerminator());
// SplitBlockAndInsertIfThen inserts control flow that branches to
// DeoptBlockTerm if the condition is true. We want the opposite.
@@ -2649,20 +2728,6 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L,
GI->moveBefore(DeoptBlockTerm);
GI->setArgOperand(0, ConstantInt::getFalse(GI->getContext()));
- // Add new successors of CheckBB into DomTree.
- for (auto *Succ : successors(CheckBB))
- DTUpdates.push_back({DominatorTree::Insert, CheckBB, Succ});
-
- // Now the blocks that used to be CheckBB's successors are GuardedBlock's
- // successors.
- for (auto *Succ : Successors)
- DTUpdates.push_back({DominatorTree::Insert, GuardedBlock, Succ});
-
- // Make proper changes to DT.
- DT.applyUpdates(DTUpdates);
- // Inform LI of a new loop block.
- L.addBasicBlockToLoop(GuardedBlock, LI);
-
if (MSSAU) {
MemoryDef *MD = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(GI));
MSSAU->moveToPlace(MD, DeoptBlock, MemorySSA::BeforeTerminator);
@@ -2670,6 +2735,8 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L,
MSSAU->getMemorySSA()->verifyMemorySSA();
}
+ if (VerifyLoopInfo)
+ LI.verify(DT);
++NumGuards;
return CheckBI;
}
@@ -2700,9 +2767,10 @@ static int CalculateUnswitchCostMultiplier(
const BasicBlock *CondBlock = TI.getParent();
if (DT.dominates(CondBlock, Latch) &&
(isGuard(&TI) ||
- llvm::count_if(successors(&TI), [&L](const BasicBlock *SuccBB) {
- return L.contains(SuccBB);
- }) <= 1)) {
+ (TI.isTerminator() &&
+ llvm::count_if(successors(&TI), [&L](const BasicBlock *SuccBB) {
+ return L.contains(SuccBB);
+ }) <= 1))) {
NumCostMultiplierSkipped++;
return 1;
}
@@ -2711,12 +2779,17 @@ static int CalculateUnswitchCostMultiplier(
int SiblingsCount = (ParentL ? ParentL->getSubLoopsVector().size()
: std::distance(LI.begin(), LI.end()));
// Count amount of clones that all the candidates might cause during
- // unswitching. Branch/guard counts as 1, switch counts as log2 of its cases.
+ // unswitching. Branch/guard/select counts as 1, switch counts as log2 of its
+ // cases.
int UnswitchedClones = 0;
- for (auto Candidate : UnswitchCandidates) {
+ for (const auto &Candidate : UnswitchCandidates) {
const Instruction *CI = Candidate.TI;
const BasicBlock *CondBlock = CI->getParent();
bool SkipExitingSuccessors = DT.dominates(CondBlock, Latch);
+ if (isa<SelectInst>(CI)) {
+ UnswitchedClones++;
+ continue;
+ }
if (isGuard(CI)) {
if (!SkipExitingSuccessors)
UnswitchedClones++;
@@ -2766,6 +2839,24 @@ static bool collectUnswitchCandidates(
const Loop &L, const LoopInfo &LI, AAResults &AA,
const MemorySSAUpdater *MSSAU) {
assert(UnswitchCandidates.empty() && "Should be!");
+
+ auto AddUnswitchCandidatesForInst = [&](Instruction *I, Value *Cond) {
+ Cond = skipTrivialSelect(Cond);
+ if (isa<Constant>(Cond))
+ return;
+ if (L.isLoopInvariant(Cond)) {
+ UnswitchCandidates.push_back({I, {Cond}});
+ return;
+ }
+ if (match(Cond, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) {
+ TinyPtrVector<Value *> Invariants =
+ collectHomogenousInstGraphLoopInvariants(
+ L, *static_cast<Instruction *>(Cond), LI);
+ if (!Invariants.empty())
+ UnswitchCandidates.push_back({I, std::move(Invariants)});
+ }
+ };
+
// Whether or not we should also collect guards in the loop.
bool CollectGuards = false;
if (UnswitchGuards) {
@@ -2779,15 +2870,20 @@ static bool collectUnswitchCandidates(
if (LI.getLoopFor(BB) != &L)
continue;
- if (CollectGuards)
- for (auto &I : *BB)
- if (isGuard(&I)) {
- auto *Cond =
- skipTrivialSelect(cast<IntrinsicInst>(&I)->getArgOperand(0));
- // TODO: Support AND, OR conditions and partial unswitching.
- if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond))
- UnswitchCandidates.push_back({&I, {Cond}});
- }
+ for (auto &I : *BB) {
+ if (auto *SI = dyn_cast<SelectInst>(&I)) {
+ auto *Cond = SI->getCondition();
+ // Do not unswitch vector selects and logical and/or selects
+ if (Cond->getType()->isIntegerTy(1) && !SI->getType()->isIntegerTy(1))
+ AddUnswitchCandidatesForInst(SI, Cond);
+ } else if (CollectGuards && isGuard(&I)) {
+ auto *Cond =
+ skipTrivialSelect(cast<IntrinsicInst>(&I)->getArgOperand(0));
+ // TODO: Support AND, OR conditions and partial unswitching.
+ if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond))
+ UnswitchCandidates.push_back({&I, {Cond}});
+ }
+ }
if (auto *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
// We can only consider fully loop-invariant switch conditions as we need
@@ -2799,29 +2895,11 @@ static bool collectUnswitchCandidates(
}
auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
- if (!BI || !BI->isConditional() || isa<Constant>(BI->getCondition()) ||
+ if (!BI || !BI->isConditional() ||
BI->getSuccessor(0) == BI->getSuccessor(1))
continue;
- Value *Cond = skipTrivialSelect(BI->getCondition());
- if (isa<Constant>(Cond))
- continue;
-
- if (L.isLoopInvariant(Cond)) {
- UnswitchCandidates.push_back({BI, {Cond}});
- continue;
- }
-
- Instruction &CondI = *cast<Instruction>(Cond);
- if (match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) {
- TinyPtrVector<Value *> Invariants =
- collectHomogenousInstGraphLoopInvariants(L, CondI, LI);
- if (Invariants.empty())
- continue;
-
- UnswitchCandidates.push_back({BI, std::move(Invariants)});
- continue;
- }
+ AddUnswitchCandidatesForInst(BI, BI->getCondition());
}
if (MSSAU && !findOptionMDForLoop(&L, "llvm.loop.unswitch.partial.disable") &&
@@ -2844,6 +2922,303 @@ static bool collectUnswitchCandidates(
return !UnswitchCandidates.empty();
}
+/// Tries to canonicalize condition described by:
+///
+/// br (LHS pred RHS), label IfTrue, label IfFalse
+///
+/// into an equivalent form where `Pred` is a predicate we support for injected
+/// invariants (so far limited to ult), the LHS in canonicalized form is
+/// non-invariant, and the RHS is an invariant.
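+///
+/// For example (illustrative only): with i32 operands, a check of the form
+///
+///   br (x >=s 0), label %in_loop, label %exit
+///
+/// is canonicalized to
+///
+///   br (x <u 0x80000000), label %in_loop, label %exit
+///
+/// because the sign bit of x is clear exactly when x, viewed as unsigned, is
+/// below 2^31.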
+static void canonicalizeForInvariantConditionInjection(
+ ICmpInst::Predicate &Pred, Value *&LHS, Value *&RHS, BasicBlock *&IfTrue,
+ BasicBlock *&IfFalse, const Loop &L) {
+ if (!L.contains(IfTrue)) {
+ Pred = ICmpInst::getInversePredicate(Pred);
+ std::swap(IfTrue, IfFalse);
+ }
+
+ // Move loop-invariant argument to RHS position.
+ if (L.isLoopInvariant(LHS)) {
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ std::swap(LHS, RHS);
+ }
+
+ if (Pred == ICmpInst::ICMP_SGE && match(RHS, m_Zero())) {
+    // Turn "x >=s 0" into "x <u SIGNED_MIN".
+ Pred = ICmpInst::ICMP_ULT;
+ RHS = ConstantInt::get(
+ RHS->getContext(),
+ APInt::getSignedMinValue(RHS->getType()->getIntegerBitWidth()));
+ }
+}
+
+/// Returns true if the predicate described by (\p Pred, \p LHS, \p RHS),
+/// branching into the blocks (\p IfTrue, \p IfFalse), can be optimized by
+/// injecting a loop-invariant condition.
+static bool shouldTryInjectInvariantCondition(
+ const ICmpInst::Predicate Pred, const Value *LHS, const Value *RHS,
+ const BasicBlock *IfTrue, const BasicBlock *IfFalse, const Loop &L) {
+ if (L.isLoopInvariant(LHS) || !L.isLoopInvariant(RHS))
+ return false;
+ // TODO: Support other predicates.
+ if (Pred != ICmpInst::ICMP_ULT)
+ return false;
+ // TODO: Support non-loop-exiting branches?
+ if (!L.contains(IfTrue) || L.contains(IfFalse))
+ return false;
+ // FIXME: For some reason this causes problems with MSSA updates, need to
+ // investigate why. So far, just don't unswitch latch.
+ if (L.getHeader() == IfTrue)
+ return false;
+ return true;
+}
+
+/// Returns true if metadata on \p BI allows us to optimize branching into \p
+/// TakenSucc via injection of invariant conditions. The branch should be hot
+/// enough and not previously unswitched; this information comes from the
+/// metadata.
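+///
+/// For illustration (threshold value chosen arbitrarily): with a hotness
+/// threshold of 16, LikelyTaken is 15/16, so profile metadata such as
+///
+///   !{!"branch_weights", i32 9, i32 1}
+///
+/// on a branch whose first successor is \p TakenSucc is rejected
+/// (9/10 < 15/16), while weights of 99 vs. 1 are accepted (99/100 >= 15/16).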
+bool shouldTryInjectBasingOnMetadata(const BranchInst *BI,
+ const BasicBlock *TakenSucc) {
+  // Skip branches that have already been unswitched this way. After a
+  // successful unswitching of the injected condition, we will still have a
+  // copy of this loop which looks exactly the same as the original one. To
+  // prevent a second attempt at unswitching it in the same pass, mark this
+  // branch as "nothing to do here".
+ if (BI->hasMetadata("llvm.invariant.condition.injection.disabled"))
+ return false;
+ SmallVector<uint32_t> Weights;
+ if (!extractBranchWeights(*BI, Weights))
+ return false;
+ unsigned T = InjectInvariantConditionHotnesThreshold;
+ BranchProbability LikelyTaken(T - 1, T);
+
+ assert(Weights.size() == 2 && "Unexpected profile data!");
+ size_t Idx = BI->getSuccessor(0) == TakenSucc ? 0 : 1;
+ auto Num = Weights[Idx];
+ auto Denom = Weights[0] + Weights[1];
+ // Degenerate or overflowed metadata.
+ if (Denom == 0 || Num > Denom)
+ return false;
+ BranchProbability ActualTaken(Num, Denom);
+ if (LikelyTaken > ActualTaken)
+ return false;
+ return true;
+}
+
+/// Materialize pending invariant condition of the given candidate into IR. The
+/// injected loop-invariant condition implies the original loop-variant branch
+/// condition, so the materialization turns
+///
+/// loop_block:
+/// ...
+/// br i1 %variant_cond, label InLoopSucc, label OutOfLoopSucc
+///
+/// into
+///
+/// preheader:
+/// %invariant_cond = LHS pred RHS
+/// ...
+/// loop_block:
+/// br i1 %invariant_cond, label InLoopSucc, label OriginalCheck
+/// OriginalCheck:
+/// br i1 %variant_cond, label InLoopSucc, label OutOfLoopSucc
+/// ...
+static NonTrivialUnswitchCandidate
+injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L,
+ DominatorTree &DT, LoopInfo &LI,
+ AssumptionCache &AC, MemorySSAUpdater *MSSAU) {
+ assert(Candidate.hasPendingInjection() && "Nothing to inject!");
+ BasicBlock *Preheader = L.getLoopPreheader();
+ assert(Preheader && "Loop is not in simplified form?");
+ assert(LI.getLoopFor(Candidate.TI->getParent()) == &L &&
+ "Unswitching branch of inner loop!");
+
+ auto Pred = Candidate.PendingInjection->Pred;
+ auto *LHS = Candidate.PendingInjection->LHS;
+ auto *RHS = Candidate.PendingInjection->RHS;
+ auto *InLoopSucc = Candidate.PendingInjection->InLoopSucc;
+ auto *TI = cast<BranchInst>(Candidate.TI);
+ auto *BB = Candidate.TI->getParent();
+ auto *OutOfLoopSucc = InLoopSucc == TI->getSuccessor(0) ? TI->getSuccessor(1)
+ : TI->getSuccessor(0);
+ // FIXME: Remove this once limitation on successors is lifted.
+ assert(L.contains(InLoopSucc) && "Not supported yet!");
+ assert(!L.contains(OutOfLoopSucc) && "Not supported yet!");
+ auto &Ctx = BB->getContext();
+
+ IRBuilder<> Builder(Preheader->getTerminator());
+ assert(ICmpInst::isUnsigned(Pred) && "Not supported yet!");
+ if (LHS->getType() != RHS->getType()) {
+ if (LHS->getType()->getIntegerBitWidth() <
+ RHS->getType()->getIntegerBitWidth())
+ LHS = Builder.CreateZExt(LHS, RHS->getType(), LHS->getName() + ".wide");
+ else
+ RHS = Builder.CreateZExt(RHS, LHS->getType(), RHS->getName() + ".wide");
+ }
+ // Do not use builder here: CreateICmp may simplify this into a constant and
+ // unswitching will break. Better optimize it away later.
+ auto *InjectedCond =
+ ICmpInst::Create(Instruction::ICmp, Pred, LHS, RHS, "injected.cond",
+ Preheader->getTerminator());
+ auto *OldCond = TI->getCondition();
+
+ BasicBlock *CheckBlock = BasicBlock::Create(Ctx, BB->getName() + ".check",
+ BB->getParent(), InLoopSucc);
+ Builder.SetInsertPoint(TI);
+ auto *InvariantBr =
+ Builder.CreateCondBr(InjectedCond, InLoopSucc, CheckBlock);
+
+ Builder.SetInsertPoint(CheckBlock);
+ auto *NewTerm = Builder.CreateCondBr(OldCond, InLoopSucc, OutOfLoopSucc);
+
+ TI->eraseFromParent();
+ // Prevent infinite unswitching.
+ NewTerm->setMetadata("llvm.invariant.condition.injection.disabled",
+ MDNode::get(BB->getContext(), {}));
+
+ // Fixup phis.
+ for (auto &I : *InLoopSucc) {
+ auto *PN = dyn_cast<PHINode>(&I);
+ if (!PN)
+ break;
+ auto *Inc = PN->getIncomingValueForBlock(BB);
+ PN->addIncoming(Inc, CheckBlock);
+ }
+ OutOfLoopSucc->replacePhiUsesWith(BB, CheckBlock);
+
+ SmallVector<DominatorTree::UpdateType, 4> DTUpdates = {
+ { DominatorTree::Insert, BB, CheckBlock },
+ { DominatorTree::Insert, CheckBlock, InLoopSucc },
+ { DominatorTree::Insert, CheckBlock, OutOfLoopSucc },
+ { DominatorTree::Delete, BB, OutOfLoopSucc }
+ };
+
+ DT.applyUpdates(DTUpdates);
+ if (MSSAU)
+ MSSAU->applyUpdates(DTUpdates, DT);
+ L.addBasicBlockToLoop(CheckBlock, LI);
+
+#ifndef NDEBUG
+ DT.verify();
+ LI.verify(DT);
+ if (MSSAU && VerifyMemorySSA)
+ MSSAU->getMemorySSA()->verifyMemorySSA();
+#endif
+
+  // TODO: In fact, the cost of unswitching a new invariant candidate is
+  // *slightly* higher because we have just inserted a new block. Think about
+  // how to adjust the cost of injected candidates when it is first computed.
+ LLVM_DEBUG(dbgs() << "Injected a new loop-invariant branch " << *InvariantBr
+ << " and considering it for unswitching.");
+ ++NumInvariantConditionsInjected;
+ return NonTrivialUnswitchCandidate(InvariantBr, { InjectedCond },
+ Candidate.Cost);
+}
+
+/// Given chain of loop branch conditions looking like:
+/// br (Variant < Invariant1)
+/// br (Variant < Invariant2)
+/// br (Variant < Invariant3)
+/// ...
+/// collect set of invariant conditions on which we want to unswitch, which
+/// look like:
+/// Invariant1 <= Invariant2
+/// Invariant2 <= Invariant3
+/// ...
+/// Though they might not immediately exist in the IR, we can still inject them.
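+///
+/// For example (illustrative): if the loop checks "x <u C1" and then "x <u C2"
+/// on the path to the latch, the walk from the latch visits the "x <u C2"
+/// check first, and the pending injection produced for that pair is the
+/// invariant condition "C1 <=u C2" (the non-strict form of ult), attached to
+/// the "x <u C2" branch.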
+static bool insertCandidatesWithPendingInjections(
+ SmallVectorImpl<NonTrivialUnswitchCandidate> &UnswitchCandidates, Loop &L,
+ ICmpInst::Predicate Pred, ArrayRef<CompareDesc> Compares,
+ const DominatorTree &DT) {
+
+ assert(ICmpInst::isRelational(Pred));
+ assert(ICmpInst::isStrictPredicate(Pred));
+ if (Compares.size() < 2)
+ return false;
+ ICmpInst::Predicate NonStrictPred = ICmpInst::getNonStrictPredicate(Pred);
+ for (auto Prev = Compares.begin(), Next = Compares.begin() + 1;
+ Next != Compares.end(); ++Prev, ++Next) {
+ Value *LHS = Next->Invariant;
+ Value *RHS = Prev->Invariant;
+ BasicBlock *InLoopSucc = Prev->InLoopSucc;
+ InjectedInvariant ToInject(NonStrictPred, LHS, RHS, InLoopSucc);
+ NonTrivialUnswitchCandidate Candidate(Prev->Term, { LHS, RHS },
+ std::nullopt, std::move(ToInject));
+ UnswitchCandidates.push_back(std::move(Candidate));
+ }
+ return true;
+}
+
+/// Collect unswitch candidates for invariant conditions that are not
+/// immediately present in the loop. However, they can be injected into the
+/// code if we decide it's profitable.
+/// An example of such conditions is the following:
+///
+/// for (...) {
+/// x = load ...
+/// if (! x <u C1) break;
+/// if (! x <u C2) break;
+/// <do something>
+/// }
+///
+/// We can unswitch on the condition "C1 <=u C2". If that is true, then "x <u
+/// C1 <= C2" automatically implies "x <u C2", so we can get rid of one of the
+/// loop-variant checks in the unswitched loop version.
+static bool collectUnswitchCandidatesWithInjections(
+ SmallVectorImpl<NonTrivialUnswitchCandidate> &UnswitchCandidates,
+ IVConditionInfo &PartialIVInfo, Instruction *&PartialIVCondBranch, Loop &L,
+ const DominatorTree &DT, const LoopInfo &LI, AAResults &AA,
+ const MemorySSAUpdater *MSSAU) {
+ if (!InjectInvariantConditions)
+ return false;
+
+ if (!DT.isReachableFromEntry(L.getHeader()))
+ return false;
+ auto *Latch = L.getLoopLatch();
+ // Need to have a single latch and a preheader.
+ if (!Latch)
+ return false;
+ assert(L.getLoopPreheader() && "Must have a preheader!");
+
+ DenseMap<Value *, SmallVector<CompareDesc, 4> > CandidatesULT;
+  // Traverse the conditions that dominate the latch (and therefore dominate
+  // each other).
+ for (auto *DTN = DT.getNode(Latch); L.contains(DTN->getBlock());
+ DTN = DTN->getIDom()) {
+ ICmpInst::Predicate Pred;
+ Value *LHS = nullptr, *RHS = nullptr;
+ BasicBlock *IfTrue = nullptr, *IfFalse = nullptr;
+ auto *BB = DTN->getBlock();
+ // Ignore inner loops.
+ if (LI.getLoopFor(BB) != &L)
+ continue;
+ auto *Term = BB->getTerminator();
+ if (!match(Term, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
+ m_BasicBlock(IfTrue), m_BasicBlock(IfFalse))))
+ continue;
+ if (!LHS->getType()->isIntegerTy())
+ continue;
+ canonicalizeForInvariantConditionInjection(Pred, LHS, RHS, IfTrue, IfFalse,
+ L);
+ if (!shouldTryInjectInvariantCondition(Pred, LHS, RHS, IfTrue, IfFalse, L))
+ continue;
+ if (!shouldTryInjectBasingOnMetadata(cast<BranchInst>(Term), IfTrue))
+ continue;
+ // Strip ZEXT for unsigned predicate.
+ // TODO: once signed predicates are supported, also strip SEXT.
+ CompareDesc Desc(cast<BranchInst>(Term), RHS, IfTrue);
+ while (auto *Zext = dyn_cast<ZExtInst>(LHS))
+ LHS = Zext->getOperand(0);
+ CandidatesULT[LHS].push_back(Desc);
+ }
+
+ bool Found = false;
+ for (auto &It : CandidatesULT)
+ Found |= insertCandidatesWithPendingInjections(
+ UnswitchCandidates, L, ICmpInst::ICMP_ULT, It.second, DT);
+ return Found;
+}
+
static bool isSafeForNoNTrivialUnswitching(Loop &L, LoopInfo &LI) {
if (!L.isSafeToClone())
return false;
@@ -2943,6 +3318,10 @@ static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate(
// cost for that terminator.
auto ComputeUnswitchedCost = [&](Instruction &TI,
bool FullUnswitch) -> InstructionCost {
+ // Unswitching selects unswitches the entire loop.
+ if (isa<SelectInst>(TI))
+ return LoopCost;
+
BasicBlock &BB = *TI.getParent();
SmallPtrSet<BasicBlock *, 4> Visited;
@@ -3003,10 +3382,11 @@ static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate(
Instruction &TI = *Candidate.TI;
ArrayRef<Value *> Invariants = Candidate.Invariants;
BranchInst *BI = dyn_cast<BranchInst>(&TI);
- InstructionCost CandidateCost = ComputeUnswitchedCost(
- TI, /*FullUnswitch*/ !BI ||
- (Invariants.size() == 1 &&
- Invariants[0] == skipTrivialSelect(BI->getCondition())));
+ bool FullUnswitch =
+ !BI || Candidate.hasPendingInjection() ||
+ (Invariants.size() == 1 &&
+ Invariants[0] == skipTrivialSelect(BI->getCondition()));
+ InstructionCost CandidateCost = ComputeUnswitchedCost(TI, FullUnswitch);
// Calculate cost multiplier which is a tool to limit potentially
// exponential behavior of loop-unswitch.
if (EnableUnswitchCostMultiplier) {
@@ -3033,6 +3413,32 @@ static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate(
return *Best;
}
+// Insert a freeze on an unswitched branch if all of the following hold:
+// 1. The freeze-loop-unswitch-cond option is true.
+// 2. The branch may not execute in the loop pre-transformation. If a branch
+//    that may not execute could cause UB when it does execute, hoisting it
+//    outside of the loop would make it always cause UB. Insert a freeze to
+//    prevent this case.
+// 3. The branch condition may be poison or undef.
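+//
+// Illustrative sketch (hypothetical loop shape): in
+//
+//   for (...) { if (c) { if (inv) ... } }
+//
+// the branch on the invariant "inv" is not guaranteed to execute on every
+// iteration, so after it is hoisted to the preheader a poison "inv" would
+// turn a previously-unreached branch into immediate UB; branching on
+// "freeze i1 %inv" instead avoids that.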
+static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT,
+ AssumptionCache &AC) {
+ assert(isa<BranchInst>(TI) || isa<SwitchInst>(TI));
+ if (!FreezeLoopUnswitchCond)
+ return false;
+
+ ICFLoopSafetyInfo SafetyInfo;
+ SafetyInfo.computeLoopSafetyInfo(&L);
+ if (SafetyInfo.isGuaranteedToExecute(TI, &DT, &L))
+ return false;
+
+ Value *Cond;
+ if (BranchInst *BI = dyn_cast<BranchInst>(&TI))
+ Cond = skipTrivialSelect(BI->getCondition());
+ else
+ Cond = skipTrivialSelect(cast<SwitchInst>(&TI)->getCondition());
+ return !isGuaranteedNotToBeUndefOrPoison(
+ Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT);
+}
+
static bool unswitchBestCondition(
Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
AAResults &AA, TargetTransformInfo &TTI,
@@ -3044,9 +3450,13 @@ static bool unswitchBestCondition(
SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates;
IVConditionInfo PartialIVInfo;
Instruction *PartialIVCondBranch = nullptr;
+ collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
+ PartialIVCondBranch, L, LI, AA, MSSAU);
+ collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo,
+ PartialIVCondBranch, L, DT, LI, AA,
+ MSSAU);
// If we didn't find any candidates, we're done.
- if (!collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
- PartialIVCondBranch, L, LI, AA, MSSAU))
+ if (UnswitchCandidates.empty())
return false;
LLVM_DEBUG(
@@ -3065,18 +3475,36 @@ static bool unswitchBestCondition(
return false;
}
+ if (Best.hasPendingInjection())
+ Best = injectPendingInvariantConditions(Best, L, DT, LI, AC, MSSAU);
+ assert(!Best.hasPendingInjection() &&
+ "All injections should have been done by now!");
+
if (Best.TI != PartialIVCondBranch)
PartialIVInfo.InstToDuplicate.clear();
- // If the best candidate is a guard, turn it into a branch.
- if (isGuard(Best.TI))
- Best.TI =
- turnGuardIntoBranch(cast<IntrinsicInst>(Best.TI), L, DT, LI, MSSAU);
+ bool InsertFreeze;
+ if (auto *SI = dyn_cast<SelectInst>(Best.TI)) {
+ // If the best candidate is a select, turn it into a branch. Select
+ // instructions with a poison conditional do not propagate poison, but
+ // branching on poison causes UB. Insert a freeze on the select
+ // conditional to prevent UB after turning the select into a branch.
+ InsertFreeze = !isGuaranteedNotToBeUndefOrPoison(
+ SI->getCondition(), &AC, L.getLoopPreheader()->getTerminator(), &DT);
+ Best.TI = turnSelectIntoBranch(SI, DT, LI, MSSAU, &AC);
+ } else {
+ // If the best candidate is a guard, turn it into a branch.
+ if (isGuard(Best.TI))
+ Best.TI =
+ turnGuardIntoBranch(cast<IntrinsicInst>(Best.TI), L, DT, LI, MSSAU);
+ InsertFreeze = shouldInsertFreeze(L, *Best.TI, DT, AC);
+ }
LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << Best.Cost
<< ") terminator: " << *Best.TI << "\n");
unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT,
- LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB);
+ LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB,
+ InsertFreeze);
return true;
}
@@ -3124,6 +3552,8 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
return true;
}
+ const Function *F = L.getHeader()->getParent();
+
// Check whether we should continue with non-trivial conditions.
// EnableNonTrivialUnswitch: Global variable that forces non-trivial
// unswitching for testing and debugging.
@@ -3136,18 +3566,41 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
// branches even on targets that have divergence.
// https://bugs.llvm.org/show_bug.cgi?id=48819
bool ContinueWithNonTrivial =
- EnableNonTrivialUnswitch || (NonTrivial && !TTI.hasBranchDivergence());
+ EnableNonTrivialUnswitch || (NonTrivial && !TTI.hasBranchDivergence(F));
if (!ContinueWithNonTrivial)
return false;
// Skip non-trivial unswitching for optsize functions.
- if (L.getHeader()->getParent()->hasOptSize())
+ if (F->hasOptSize())
return false;
- // Skip cold loops, as unswitching them brings little benefit
- // but increases the code size
- if (PSI && PSI->hasProfileSummary() && BFI &&
- PSI->isFunctionColdInCallGraph(L.getHeader()->getParent(), *BFI)) {
+ // Returns true if Loop L's loop nest is cold, i.e. if the headers of L,
+ // of the loops L is nested in, and of the loops nested in L are all cold.
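+  // For example (hypothetical profile): if L's header is cold but one of its
+  // inner loops has a hot header, the nest is not considered cold and L is
+  // still a candidate for non-trivial unswitching.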
+ auto IsLoopNestCold = [&](const Loop *L) {
+ // Check L and all of its parent loops.
+ auto *Parent = L;
+ while (Parent) {
+ if (!PSI->isColdBlock(Parent->getHeader(), BFI))
+ return false;
+ Parent = Parent->getParentLoop();
+ }
+ // Next check all loops nested within L.
+ SmallVector<const Loop *, 4> Worklist;
+ Worklist.insert(Worklist.end(), L->getSubLoops().begin(),
+ L->getSubLoops().end());
+ while (!Worklist.empty()) {
+ auto *CurLoop = Worklist.pop_back_val();
+ if (!PSI->isColdBlock(CurLoop->getHeader(), BFI))
+ return false;
+ Worklist.insert(Worklist.end(), CurLoop->getSubLoops().begin(),
+ CurLoop->getSubLoops().end());
+ }
+ return true;
+ };
+
+ // Skip cold loops in cold loop nests, as unswitching them brings little
+ // benefit but increases the code size
+ if (PSI && PSI->hasProfileSummary() && BFI && IsLoopNestCold(&L)) {
LLVM_DEBUG(dbgs() << " Skip cold loop: " << L << "\n");
return false;
}
@@ -3249,10 +3702,10 @@ void SimpleLoopUnswitchPass::printPipeline(
static_cast<PassInfoMixin<SimpleLoopUnswitchPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
OS << (NonTrivial ? "" : "no-") << "nontrivial;";
OS << (Trivial ? "" : "no-") << "trivial";
- OS << ">";
+ OS << '>';
}
namespace {
diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
index e014f5d1eb04..7017f6adf3a2 100644
--- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
+++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -121,7 +121,7 @@ performBlockTailMerging(Function &F, ArrayRef<BasicBlock *> BBs,
// Now, go through each block (with the current terminator type)
// we've recorded, and rewrite it to branch to the new common block.
- const DILocation *CommonDebugLoc = nullptr;
+ DILocation *CommonDebugLoc = nullptr;
for (BasicBlock *BB : BBs) {
auto *Term = BB->getTerminator();
assert(Term->getOpcode() == CanonicalTerm->getOpcode() &&
@@ -228,8 +228,8 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 32> Edges;
FindFunctionBackedges(F, Edges);
SmallPtrSet<BasicBlock *, 16> UniqueLoopHeaders;
- for (unsigned i = 0, e = Edges.size(); i != e; ++i)
- UniqueLoopHeaders.insert(const_cast<BasicBlock *>(Edges[i].second));
+ for (const auto &Edge : Edges)
+ UniqueLoopHeaders.insert(const_cast<BasicBlock *>(Edge.second));
SmallVector<WeakVH, 16> LoopHeaders(UniqueLoopHeaders.begin(),
UniqueLoopHeaders.end());
@@ -338,8 +338,8 @@ void SimplifyCFGPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<SimplifyCFGPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
- OS << "bonus-inst-threshold=" << Options.BonusInstThreshold << ";";
+ OS << '<';
+ OS << "bonus-inst-threshold=" << Options.BonusInstThreshold << ';';
OS << (Options.ForwardSwitchCondToPhi ? "" : "no-") << "forward-switch-cond;";
OS << (Options.ConvertSwitchRangeToICmp ? "" : "no-")
<< "switch-range-to-icmp;";
@@ -347,8 +347,10 @@ void SimplifyCFGPass::printPipeline(
<< "switch-to-lookup;";
OS << (Options.NeedCanonicalLoop ? "" : "no-") << "keep-loops;";
OS << (Options.HoistCommonInsts ? "" : "no-") << "hoist-common-insts;";
- OS << (Options.SinkCommonInsts ? "" : "no-") << "sink-common-insts";
- OS << ">";
+ OS << (Options.SinkCommonInsts ? "" : "no-") << "sink-common-insts;";
+ OS << (Options.SpeculateBlocks ? "" : "no-") << "speculate-blocks;";
+ OS << (Options.SimplifyCondBranch ? "" : "no-") << "simplify-cond-branch";
+ OS << '>';
}
PreservedAnalyses SimplifyCFGPass::run(Function &F,
@@ -358,11 +360,6 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F,
DominatorTree *DT = nullptr;
if (RequireAndPreserveDomTree)
DT = &AM.getResult<DominatorTreeAnalysis>(F);
- if (F.hasFnAttribute(Attribute::OptForFuzzing)) {
- Options.setSimplifyCondBranch(false).setFoldTwoEntryPHINode(false);
- } else {
- Options.setSimplifyCondBranch(true).setFoldTwoEntryPHINode(true);
- }
if (!simplifyFunctionCFG(F, TTI, DT, Options))
return PreservedAnalyses::all();
PreservedAnalyses PA;
@@ -395,13 +392,6 @@ struct CFGSimplifyPass : public FunctionPass {
DominatorTree *DT = nullptr;
if (RequireAndPreserveDomTree)
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- if (F.hasFnAttribute(Attribute::OptForFuzzing)) {
- Options.setSimplifyCondBranch(false)
- .setFoldTwoEntryPHINode(false);
- } else {
- Options.setSimplifyCondBranch(true)
- .setFoldTwoEntryPHINode(true);
- }
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
return simplifyFunctionCFG(F, TTI, DT, Options);
diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
index 65f8d760ede3..e866fe681127 100644
--- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
+++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
@@ -152,7 +152,7 @@ bool SpeculativeExecutionLegacyPass::runOnFunction(Function &F) {
namespace llvm {
bool SpeculativeExecutionPass::runImpl(Function &F, TargetTransformInfo *TTI) {
- if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence()) {
+ if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence(&F)) {
LLVM_DEBUG(dbgs() << "Not running SpeculativeExecution because "
"TTI->hasBranchDivergence() is false.\n");
return false;
diff --git a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
index 70df0cec0dca..fdb41cb415df 100644
--- a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
@@ -484,9 +484,9 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
// = B + (sext(Idx) * sext(S)) * ElementSize
// = B + (sext(Idx) * ElementSize) * sext(S)
// Casting to IntegerType is safe because we skipped vector GEPs.
- IntegerType *IntPtrTy = cast<IntegerType>(DL->getIntPtrType(I->getType()));
+ IntegerType *PtrIdxTy = cast<IntegerType>(DL->getIndexType(I->getType()));
ConstantInt *ScaledIdx = ConstantInt::get(
- IntPtrTy, Idx->getSExtValue() * (int64_t)ElementSize, true);
+ PtrIdxTy, Idx->getSExtValue() * (int64_t)ElementSize, true);
allocateCandidatesAndFindBasis(Candidate::GEP, B, ScaledIdx, S, I);
}
@@ -549,18 +549,18 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
Value *ArrayIdx = GEP->getOperand(I);
uint64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType());
if (ArrayIdx->getType()->getIntegerBitWidth() <=
- DL->getPointerSizeInBits(GEP->getAddressSpace())) {
- // Skip factoring if ArrayIdx is wider than the pointer size, because
- // ArrayIdx is implicitly truncated to the pointer size.
+ DL->getIndexSizeInBits(GEP->getAddressSpace())) {
+ // Skip factoring if ArrayIdx is wider than the index size, because
+ // ArrayIdx is implicitly truncated to the index size.
factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
}
// When ArrayIdx is the sext of a value, we try to factor that value as
// well. Handling this case is important because array indices are
- // typically sign-extended to the pointer size.
+ // typically sign-extended to the pointer index size.
Value *TruncatedArrayIdx = nullptr;
if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) &&
TruncatedArrayIdx->getType()->getIntegerBitWidth() <=
- DL->getPointerSizeInBits(GEP->getAddressSpace())) {
+ DL->getIndexSizeInBits(GEP->getAddressSpace())) {
// Skip factoring if TruncatedArrayIdx is wider than the pointer size,
// because TruncatedArrayIdx is implicitly truncated to the pointer size.
factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
@@ -675,24 +675,24 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis(
}
case Candidate::GEP:
{
- Type *IntPtrTy = DL->getIntPtrType(C.Ins->getType());
- bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
- if (BumpWithUglyGEP) {
- // C = (char *)Basis + Bump
- unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
- Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS);
- Reduced = Builder.CreateBitCast(Basis.Ins, CharTy);
- Reduced =
- Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump, "", InBounds);
- Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType());
- } else {
- // C = gep Basis, Bump
- // Canonicalize bump to pointer size.
- Bump = Builder.CreateSExtOrTrunc(Bump, IntPtrTy);
- Reduced = Builder.CreateGEP(
- cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(),
- Basis.Ins, Bump, "", InBounds);
- }
+ Type *OffsetTy = DL->getIndexType(C.Ins->getType());
+ bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
+ if (BumpWithUglyGEP) {
+ // C = (char *)Basis + Bump
+ unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
+ Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS);
+ Reduced = Builder.CreateBitCast(Basis.Ins, CharTy);
+ Reduced =
+ Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump, "", InBounds);
+ Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType());
+ } else {
+ // C = gep Basis, Bump
+ // Canonicalize bump to pointer size.
+ Bump = Builder.CreateSExtOrTrunc(Bump, OffsetTy);
+ Reduced = Builder.CreateGEP(
+ cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), Basis.Ins,
+ Bump, "", InBounds);
+ }
break;
}
default:
diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
index 81d151c2904e..fac5695c7bea 100644
--- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
+++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
@@ -15,10 +15,10 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
-#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
#include "llvm/Analysis/RegionInfo.h"
#include "llvm/Analysis/RegionIterator.h"
#include "llvm/Analysis/RegionPass.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
@@ -239,12 +239,12 @@ class StructurizeCFG {
Type *Boolean;
ConstantInt *BoolTrue;
ConstantInt *BoolFalse;
- UndefValue *BoolUndef;
+ Value *BoolPoison;
Function *Func;
Region *ParentRegion;
- LegacyDivergenceAnalysis *DA = nullptr;
+ UniformityInfo *UA = nullptr;
DominatorTree *DT;
SmallVector<RegionNode *, 8> Order;
@@ -319,7 +319,7 @@ class StructurizeCFG {
public:
void init(Region *R);
bool run(Region *R, DominatorTree *DT);
- bool makeUniformRegion(Region *R, LegacyDivergenceAnalysis *DA);
+ bool makeUniformRegion(Region *R, UniformityInfo &UA);
};
class StructurizeCFGLegacyPass : public RegionPass {
@@ -339,8 +339,9 @@ public:
StructurizeCFG SCFG;
SCFG.init(R);
if (SkipUniformRegions) {
- LegacyDivergenceAnalysis *DA = &getAnalysis<LegacyDivergenceAnalysis>();
- if (SCFG.makeUniformRegion(R, DA))
+ UniformityInfo &UA =
+ getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
+ if (SCFG.makeUniformRegion(R, UA))
return false;
}
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
@@ -351,7 +352,7 @@ public:
void getAnalysisUsage(AnalysisUsage &AU) const override {
if (SkipUniformRegions)
- AU.addRequired<LegacyDivergenceAnalysis>();
+ AU.addRequired<UniformityInfoWrapperPass>();
AU.addRequiredID(LowerSwitchID);
AU.addRequired<DominatorTreeWrapperPass>();
@@ -366,7 +367,7 @@ char StructurizeCFGLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(StructurizeCFGLegacyPass, "structurizecfg",
"Structurize the CFG", false, false)
-INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis)
+INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(RegionInfoPass)
@@ -798,8 +799,6 @@ void StructurizeCFG::killTerminator(BasicBlock *BB) {
for (BasicBlock *Succ : successors(BB))
delPhiValues(BB, Succ);
- if (DA)
- DA->removeValue(Term);
Term->eraseFromParent();
}
@@ -957,7 +956,7 @@ void StructurizeCFG::wireFlow(bool ExitUseAllowed,
BasicBlock *Next = needPostfix(Flow, ExitUseAllowed);
// let it point to entry and next block
- BranchInst *Br = BranchInst::Create(Entry, Next, BoolUndef, Flow);
+ BranchInst *Br = BranchInst::Create(Entry, Next, BoolPoison, Flow);
Br->setDebugLoc(TermDL[Flow]);
Conditions.push_back(Br);
addPhiValues(Flow, Entry);
@@ -998,7 +997,7 @@ void StructurizeCFG::handleLoops(bool ExitUseAllowed,
// Create an extra loop end node
LoopEnd = needPrefix(false);
BasicBlock *Next = needPostfix(LoopEnd, ExitUseAllowed);
- BranchInst *Br = BranchInst::Create(Next, LoopStart, BoolUndef, LoopEnd);
+ BranchInst *Br = BranchInst::Create(Next, LoopStart, BoolPoison, LoopEnd);
Br->setDebugLoc(TermDL[LoopEnd]);
LoopConds.push_back(Br);
addPhiValues(LoopEnd, LoopStart);
@@ -1064,7 +1063,7 @@ void StructurizeCFG::rebuildSSA() {
}
static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID,
- const LegacyDivergenceAnalysis &DA) {
+ const UniformityInfo &UA) {
// Bool for if all sub-regions are uniform.
bool SubRegionsAreUniform = true;
// Count of how many direct children are conditional.
@@ -1076,7 +1075,7 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID,
if (!Br || !Br->isConditional())
continue;
- if (!DA.isUniform(Br))
+ if (!UA.isUniform(Br))
return false;
// One of our direct children is conditional.
@@ -1086,7 +1085,7 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID,
<< " has uniform terminator\n");
} else {
// Explicitly refuse to treat regions as uniform if they have non-uniform
- // subregions. We cannot rely on DivergenceAnalysis for branches in
+ // subregions. We cannot rely on UniformityAnalysis for branches in
// subregions because those branches may have been removed and re-created,
// so we look for our metadata instead.
//
@@ -1126,17 +1125,17 @@ void StructurizeCFG::init(Region *R) {
Boolean = Type::getInt1Ty(Context);
BoolTrue = ConstantInt::getTrue(Context);
BoolFalse = ConstantInt::getFalse(Context);
- BoolUndef = UndefValue::get(Boolean);
+ BoolPoison = PoisonValue::get(Boolean);
- this->DA = nullptr;
+ this->UA = nullptr;
}
-bool StructurizeCFG::makeUniformRegion(Region *R,
- LegacyDivergenceAnalysis *DA) {
+bool StructurizeCFG::makeUniformRegion(Region *R, UniformityInfo &UA) {
if (R->isTopLevelRegion())
return false;
- this->DA = DA;
+ this->UA = &UA;
+
// TODO: We could probably be smarter here with how we handle sub-regions.
// We currently rely on the fact that metadata is set by earlier invocations
// of the pass on sub-regions, and that this metadata doesn't get lost --
@@ -1144,7 +1143,7 @@ bool StructurizeCFG::makeUniformRegion(Region *R,
unsigned UniformMDKindID =
R->getEntry()->getContext().getMDKindID("structurizecfg.uniform");
- if (hasOnlyUniformBranches(R, UniformMDKindID, *DA)) {
+ if (hasOnlyUniformBranches(R, UniformMDKindID, UA)) {
LLVM_DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R
<< '\n');
diff --git a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
index 9e08954ef643..e53019768e88 100644
--- a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
+++ b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
@@ -13,7 +13,6 @@
#include "llvm/Transforms/Scalar/WarnMissedTransforms.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
using namespace llvm;
@@ -104,47 +103,3 @@ WarnMissedTransformationsPass::run(Function &F, FunctionAnalysisManager &AM) {
return PreservedAnalyses::all();
}
-
-// Legacy pass manager boilerplate
-namespace {
-class WarnMissedTransformationsLegacy : public FunctionPass {
-public:
- static char ID;
-
- explicit WarnMissedTransformationsLegacy() : FunctionPass(ID) {
- initializeWarnMissedTransformationsLegacyPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-
- warnAboutLeftoverTransformations(&F, &LI, &ORE);
- return false;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
-
- AU.setPreservesAll();
- }
-};
-} // end anonymous namespace
-
-char WarnMissedTransformationsLegacy::ID = 0;
-
-INITIALIZE_PASS_BEGIN(WarnMissedTransformationsLegacy, "transform-warning",
- "Warn about non-applied transformations", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_END(WarnMissedTransformationsLegacy, "transform-warning",
- "Warn about non-applied transformations", false, false)
-
-Pass *llvm::createWarnMissedTransformationsPass() {
- return new WarnMissedTransformationsLegacy();
-}
diff --git a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp
index 24972db404be..2195406c144c 100644
--- a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp
+++ b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp
@@ -16,7 +16,11 @@
#include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
#include "llvm/ADT/SparseBitVector.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Support/DataExtractor.h"
+#include "llvm/Support/MD5.h"
+#include "llvm/Support/MathExtras.h"
using namespace llvm;
@@ -179,11 +183,7 @@ static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
// Scan the format string to locate all specifiers, and mark the ones that
// specify a string, i.e, the "%s" specifier with optional '*' characters.
-static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) {
- StringRef Str;
- if (!getConstantStringInfo(Fmt, Str) || Str.empty())
- return;
-
+static void locateCStrings(SparseBitVector<8> &BV, StringRef Str) {
static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
size_t SpecPos = 0;
// Skip the first argument, the format string.
@@ -207,14 +207,320 @@ static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) {
}
}
-Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder,
- ArrayRef<Value *> Args) {
+// Helper struct to package the string-related data.
+struct StringData {
+ StringRef Str;
+ Value *RealSize = nullptr;
+ Value *AlignedSize = nullptr;
+ bool IsConst = true;
+
+ StringData(StringRef ST, Value *RS, Value *AS, bool IC)
+ : Str(ST), RealSize(RS), AlignedSize(AS), IsConst(IC) {}
+};
+
+// Calculates the frame size required for the current printf expansion and
+// allocates space in the printf buffer. A printf frame has the following
+// contents:
+// [ ControlDWord , format string/Hash , Arguments (each aligned to 8 bytes) ]
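+//
+// For illustration (argument values hypothetical): a call like
+//   printf("%d %s\n", 42, "hi")
+// with a constant format string reserves 4 (control dword) + 8 (low MD5 bits
+// of the format string) + 8 (the i32 argument, padded to 8 bytes) + 8 ("hi"
+// plus NUL, aligned up to 8) = 28 bytes.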
+static Value *callBufferedPrintfStart(
+ IRBuilder<> &Builder, ArrayRef<Value *> Args, Value *Fmt,
+ bool isConstFmtStr, SparseBitVector<8> &SpecIsCString,
+ SmallVectorImpl<StringData> &StringContents, Value *&ArgSize) {
+ Module *M = Builder.GetInsertBlock()->getModule();
+ Value *NonConstStrLen = nullptr;
+ Value *LenWithNull = nullptr;
+ Value *LenWithNullAligned = nullptr;
+ Value *TempAdd = nullptr;
+
+ // First 4 bytes to be reserved for control dword
+ size_t BufSize = 4;
+ if (isConstFmtStr)
+ // First 8 bytes of MD5 hash
+ BufSize += 8;
+ else {
+ LenWithNull = getStrlenWithNull(Builder, Fmt);
+
+ // Align the computed length to next 8 byte boundary
+ TempAdd = Builder.CreateAdd(LenWithNull,
+ ConstantInt::get(LenWithNull->getType(), 7U));
+ NonConstStrLen = Builder.CreateAnd(
+ TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
+
+ StringContents.push_back(
+ StringData(StringRef(), LenWithNull, NonConstStrLen, false));
+ }
+
+ for (size_t i = 1; i < Args.size(); i++) {
+ if (SpecIsCString.test(i)) {
+ StringRef ArgStr;
+ if (getConstantStringInfo(Args[i], ArgStr)) {
+ auto alignedLen = alignTo(ArgStr.size() + 1, 8);
+ StringContents.push_back(StringData(
+ ArgStr,
+ /*RealSize*/ nullptr, /*AlignedSize*/ nullptr, /*IsConst*/ true));
+ BufSize += alignedLen;
+ } else {
+ LenWithNull = getStrlenWithNull(Builder, Args[i]);
+
+ // Align the computed length to next 8 byte boundary
+ TempAdd = Builder.CreateAdd(
+ LenWithNull, ConstantInt::get(LenWithNull->getType(), 7U));
+ LenWithNullAligned = Builder.CreateAnd(
+ TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
+
+ if (NonConstStrLen) {
+ auto Val = Builder.CreateAdd(LenWithNullAligned, NonConstStrLen,
+ "cumulativeAdd");
+ NonConstStrLen = Val;
+ } else
+ NonConstStrLen = LenWithNullAligned;
+
+ StringContents.push_back(
+ StringData(StringRef(), LenWithNull, LenWithNullAligned, false));
+ }
+ } else {
+ int AllocSize = M->getDataLayout().getTypeAllocSize(Args[i]->getType());
+      // We end up expanding non-string arguments smaller than 8 bytes
+      // to 8 bytes.
+ BufSize += std::max(AllocSize, 8);
+ }
+ }
+
+ // calculate final size value to be passed to printf_alloc
+ Value *SizeToReserve = ConstantInt::get(Builder.getInt64Ty(), BufSize, false);
+ SmallVector<Value *, 1> Alloc_args;
+ if (NonConstStrLen)
+ SizeToReserve = Builder.CreateAdd(NonConstStrLen, SizeToReserve);
+
+ ArgSize = Builder.CreateTrunc(SizeToReserve, Builder.getInt32Ty());
+ Alloc_args.push_back(ArgSize);
+
+ // call the printf_alloc function
+ AttributeList Attr = AttributeList::get(
+ Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind);
+
+ Type *Tys_alloc[1] = {Builder.getInt32Ty()};
+ Type *I8Ptr =
+ Builder.getInt8PtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace());
+ FunctionType *FTy_alloc = FunctionType::get(I8Ptr, Tys_alloc, false);
+ auto PrintfAllocFn =
+ M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr);
+
+ return Builder.CreateCall(PrintfAllocFn, Alloc_args, "printf_alloc_fn");
+}
+
+// Prepare constant string argument to push onto the buffer
+static void processConstantStringArg(StringData *SD, IRBuilder<> &Builder,
+ SmallVectorImpl<Value *> &WhatToStore) {
+ std::string Str(SD->Str.str() + '\0');
+
+ DataExtractor Extractor(Str, /*IsLittleEndian=*/true, 8);
+ DataExtractor::Cursor Offset(0);
+ while (Offset && Offset.tell() < Str.size()) {
+ const uint64_t ReadSize = 4;
+ uint64_t ReadNow = std::min(ReadSize, Str.size() - Offset.tell());
+ uint64_t ReadBytes = 0;
+ switch (ReadNow) {
+ default:
+ llvm_unreachable("min(4, X) > 4?");
+ case 1:
+ ReadBytes = Extractor.getU8(Offset);
+ break;
+ case 2:
+ ReadBytes = Extractor.getU16(Offset);
+ break;
+ case 3:
+ ReadBytes = Extractor.getU24(Offset);
+ break;
+ case 4:
+ ReadBytes = Extractor.getU32(Offset);
+ break;
+ }
+ cantFail(Offset.takeError(), "failed to read bytes from constant array");
+
+ APInt IntVal(8 * ReadSize, ReadBytes);
+
+ // TODO: Should not bother aligning up.
+ if (ReadNow < ReadSize)
+ IntVal = IntVal.zext(8 * ReadSize);
+
+ Type *IntTy = Type::getIntNTy(Builder.getContext(), IntVal.getBitWidth());
+ WhatToStore.push_back(ConstantInt::get(IntTy, IntVal));
+ }
+ // Additional padding for 8 byte alignment
+ int Rem = (Str.size() % 8);
+ if (Rem > 0 && Rem <= 4)
+ WhatToStore.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));
+}
+
+static Value *processNonStringArg(Value *Arg, IRBuilder<> &Builder) {
+ const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
+ auto Ty = Arg->getType();
+
+ if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
+ if (IntTy->getBitWidth() < 64) {
+ return Builder.CreateZExt(Arg, Builder.getInt64Ty());
+ }
+ }
+
+ if (Ty->isFloatingPointTy()) {
+ if (DL.getTypeAllocSize(Ty) < 8) {
+ return Builder.CreateFPExt(Arg, Builder.getDoubleTy());
+ }
+ }
+
+ return Arg;
+}
+
+static void
+callBufferedPrintfArgPush(IRBuilder<> &Builder, ArrayRef<Value *> Args,
+ Value *PtrToStore, SparseBitVector<8> &SpecIsCString,
+ SmallVectorImpl<StringData> &StringContents,
+ bool IsConstFmtStr) {
+ Module *M = Builder.GetInsertBlock()->getModule();
+ const DataLayout &DL = M->getDataLayout();
+ auto StrIt = StringContents.begin();
+ size_t i = IsConstFmtStr ? 1 : 0;
+ for (; i < Args.size(); i++) {
+ SmallVector<Value *, 32> WhatToStore;
+ if ((i == 0) || SpecIsCString.test(i)) {
+ if (StrIt->IsConst) {
+ processConstantStringArg(StrIt, Builder, WhatToStore);
+ StrIt++;
+ } else {
+        // This copies the contents of the string; however, the next offset
+        // is at the aligned length, and the extra space that might be created
+        // by the alignment padding is not populated with any specific value
+        // here. This is safe as long as the runtime stays in sync with the
+        // offsets.
+ Builder.CreateMemCpy(PtrToStore, /*DstAlign*/ Align(1), Args[i],
+ /*SrcAlign*/ Args[i]->getPointerAlignment(DL),
+ StrIt->RealSize);
+
+ PtrToStore =
+ Builder.CreateInBoundsGEP(Builder.getInt8Ty(), PtrToStore,
+ {StrIt->AlignedSize}, "PrintBuffNextPtr");
+ LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:"
+ << *PtrToStore << '\n');
+
+ // done with current argument, move to next
+ StrIt++;
+ continue;
+ }
+ } else {
+ WhatToStore.push_back(processNonStringArg(Args[i], Builder));
+ }
+
+ for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) {
+ Value *toStore = WhatToStore[I];
+
+ StoreInst *StBuff = Builder.CreateStore(toStore, PtrToStore);
+ LLVM_DEBUG(dbgs() << "inserting store to printf buffer:" << *StBuff
+ << '\n');
+ (void)StBuff;
+ PtrToStore = Builder.CreateConstInBoundsGEP1_32(
+ Builder.getInt8Ty(), PtrToStore,
+ M->getDataLayout().getTypeAllocSize(toStore->getType()),
+ "PrintBuffNextPtr");
+ LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:" << *PtrToStore
+ << '\n');
+ }
+ }
+}
+
+Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, ArrayRef<Value *> Args,
+ bool IsBuffered) {
auto NumOps = Args.size();
assert(NumOps >= 1);
auto Fmt = Args[0];
SparseBitVector<8> SpecIsCString;
- locateCStrings(SpecIsCString, Fmt);
+ StringRef FmtStr;
+
+ if (getConstantStringInfo(Fmt, FmtStr))
+ locateCStrings(SpecIsCString, FmtStr);
+
+ if (IsBuffered) {
+ SmallVector<StringData, 8> StringContents;
+ Module *M = Builder.GetInsertBlock()->getModule();
+ LLVMContext &Ctx = Builder.getContext();
+ auto Int8Ty = Builder.getInt8Ty();
+ auto Int32Ty = Builder.getInt32Ty();
+ bool IsConstFmtStr = !FmtStr.empty();
+
+ Value *ArgSize = nullptr;
+ Value *Ptr =
+ callBufferedPrintfStart(Builder, Args, Fmt, IsConstFmtStr,
+ SpecIsCString, StringContents, ArgSize);
+
+    // The buffered version still follows the OpenCL printf standard for the
+    // printf return value, i.e., 0 on success and -1 on failure.
+ ConstantPointerNull *zeroIntPtr =
+ ConstantPointerNull::get(cast<PointerType>(Ptr->getType()));
+
+ auto *Cmp = cast<ICmpInst>(Builder.CreateICmpNE(Ptr, zeroIntPtr, ""));
+
+ BasicBlock *End = BasicBlock::Create(Ctx, "end.block",
+ Builder.GetInsertBlock()->getParent());
+ BasicBlock *ArgPush = BasicBlock::Create(
+ Ctx, "argpush.block", Builder.GetInsertBlock()->getParent());
+
+ BranchInst::Create(ArgPush, End, Cmp, Builder.GetInsertBlock());
+ Builder.SetInsertPoint(ArgPush);
+
+    // Create the control DWord and store it as the first entry. Its format:
+    //   Bit 0 (LSB) -> stream (1 if stderr, 0 if stdout; printf always
+    //                  outputs to stdout)
+    //   Bit 1       -> constant format string (1 if constant)
+    //   Bits 2-31   -> size of the printf data frame
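+    //
+    // For illustration (sizes hypothetical): a frame of ArgSize == 28 bytes
+    // with a constant format string yields
+    //   ControlDWord = (28 << 2) | 2 == 114 (0x72).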
+ auto ConstantTwo = Builder.getInt32(2);
+ auto ControlDWord = Builder.CreateShl(ArgSize, ConstantTwo);
+ if (IsConstFmtStr)
+ ControlDWord = Builder.CreateOr(ControlDWord, ConstantTwo);
+
+ Builder.CreateStore(ControlDWord, Ptr);
+
+ Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 4);
+
+    // Create an MD5 hash of the constant format string and push its low 64
+    // bits onto the buffer and into the metadata.
+ NamedMDNode *metaD = M->getOrInsertNamedMetadata("llvm.printf.fmts");
+ if (IsConstFmtStr) {
+ MD5 Hasher;
+ MD5::MD5Result Hash;
+ Hasher.update(FmtStr);
+ Hasher.final(Hash);
+
+      // Try sticking to the llvm.printf.fmts format, although we are not
+      // going to use the ID and argument size fields while printing.
+ std::string MetadataStr =
+ "0:0:" + llvm::utohexstr(Hash.low(), /*LowerCase=*/true) + "," +
+ FmtStr.str();
+ MDString *fmtStrArray = MDString::get(Ctx, MetadataStr);
+ MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
+ metaD->addOperand(myMD);
+
+ Builder.CreateStore(Builder.getInt64(Hash.low()), Ptr);
+ Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 8);
+ } else {
+      // Include a dummy metadata instance in case only non-constant format
+      // strings are used. This might be an absurd use case, but it needs to
+      // be done for completeness.
+ if (metaD->getNumOperands() == 0) {
+ MDString *fmtStrArray =
+ MDString::get(Ctx, "0:0:ffffffff,\"Non const format string\"");
+ MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
+ metaD->addOperand(myMD);
+ }
+ }
+
+    // Push the printf arguments onto the buffer.
+ callBufferedPrintfArgPush(Builder, Args, Ptr, SpecIsCString, StringContents,
+ IsConstFmtStr);
+
+ // End block, returns -1 on failure
+ BranchInst::Create(End, ArgPush);
+ Builder.SetInsertPoint(End);
+ return Builder.CreateSExt(Builder.CreateNot(Cmp), Int32Ty, "printf_result");
+ }
auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
Desc = appendString(Builder, Desc, Fmt, NumOps == 1);
diff --git a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp
index 56acdcc0bc3c..7d127400651e 100644
--- a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp
+++ b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp
@@ -85,33 +85,6 @@ static cl::opt<bool> NoDiscriminators(
"no-discriminators", cl::init(false),
cl::desc("Disable generation of discriminator information."));
-namespace {
-
-// The legacy pass of AddDiscriminators.
-struct AddDiscriminatorsLegacyPass : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
-
- AddDiscriminatorsLegacyPass() : FunctionPass(ID) {
- initializeAddDiscriminatorsLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-};
-
-} // end anonymous namespace
-
-char AddDiscriminatorsLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(AddDiscriminatorsLegacyPass, "add-discriminators",
- "Add DWARF path discriminators", false, false)
-INITIALIZE_PASS_END(AddDiscriminatorsLegacyPass, "add-discriminators",
- "Add DWARF path discriminators", false, false)
-
-// Create the legacy AddDiscriminatorsPass.
-FunctionPass *llvm::createAddDiscriminatorsPass() {
- return new AddDiscriminatorsLegacyPass();
-}
-
static bool shouldHaveDiscriminator(const Instruction *I) {
return !isa<IntrinsicInst>(I) || isa<MemIntrinsic>(I);
}
@@ -269,10 +242,6 @@ static bool addDiscriminators(Function &F) {
return Changed;
}
-bool AddDiscriminatorsLegacyPass::runOnFunction(Function &F) {
- return addDiscriminators(F);
-}
-
PreservedAnalyses AddDiscriminatorsPass::run(Function &F,
FunctionAnalysisManager &AM) {
if (!addDiscriminators(F))
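With the legacy wrapper gone, the pass is reached only through the new pass manager. Below is a sketch of the usual new-PM boilerplate for running it on a single function; the driver function itself is illustrative.

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/Utils/AddDiscriminators.h"

using namespace llvm;

void runAddDiscriminators(Function &F) {
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PassBuilder PB;
  // Register the analyses and cross-register the proxies between managers.
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  FunctionPassManager FPM;
  FPM.addPass(AddDiscriminatorsPass());
  FPM.run(F, FAM);
}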
diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
index d17c399ba798..45cf98e65a5a 100644
--- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
+++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
@@ -290,17 +290,20 @@ AssumeInst *llvm::buildAssumeFromInst(Instruction *I) {
return Builder.build();
}
-void llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC,
+bool llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC,
DominatorTree *DT) {
if (!EnableKnowledgeRetention || I->isTerminator())
- return;
+ return false;
+ bool Changed = false;
AssumeBuilderState Builder(I->getModule(), I, AC, DT);
Builder.addInstruction(I);
if (auto *Intr = Builder.build()) {
Intr->insertBefore(I);
+ Changed = true;
if (AC)
AC->registerAssumption(Intr);
}
+ return Changed;
}
AssumeInst *
@@ -563,57 +566,26 @@ PreservedAnalyses AssumeSimplifyPass::run(Function &F,
FunctionAnalysisManager &AM) {
if (!EnableKnowledgeRetention)
return PreservedAnalyses::all();
- simplifyAssumes(F, &AM.getResult<AssumptionAnalysis>(F),
- AM.getCachedResult<DominatorTreeAnalysis>(F));
- return PreservedAnalyses::all();
-}
-
-namespace {
-class AssumeSimplifyPassLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- AssumeSimplifyPassLegacyPass() : FunctionPass(ID) {
- initializeAssumeSimplifyPassLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override {
- if (skipFunction(F) || !EnableKnowledgeRetention)
- return false;
- AssumptionCache &AC =
- getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- DominatorTreeWrapperPass *DTWP =
- getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- return simplifyAssumes(F, &AC, DTWP ? &DTWP->getDomTree() : nullptr);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
-
- AU.setPreservesAll();
- }
-};
-} // namespace
-
-char AssumeSimplifyPassLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(AssumeSimplifyPassLegacyPass, "assume-simplify",
- "Assume Simplify", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_END(AssumeSimplifyPassLegacyPass, "assume-simplify",
- "Assume Simplify", false, false)
-
-FunctionPass *llvm::createAssumeSimplifyPass() {
- return new AssumeSimplifyPassLegacyPass();
+ if (!simplifyAssumes(F, &AM.getResult<AssumptionAnalysis>(F),
+ AM.getCachedResult<DominatorTreeAnalysis>(F)))
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
PreservedAnalyses AssumeBuilderPass::run(Function &F,
FunctionAnalysisManager &AM) {
AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
DominatorTree* DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
+ bool Changed = false;
for (Instruction &I : instructions(F))
- salvageKnowledge(&I, AC, DT);
- return PreservedAnalyses::all();
+ Changed |= salvageKnowledge(&I, AC, DT);
+ if (!Changed)
+    return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
namespace {
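Because salvageKnowledge now reports whether an llvm.assume was inserted, a caller can aggregate that status and invalidate analyses only when something changed. A sketch with an illustrative wrapper name follows.

#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"

using namespace llvm;

// Returns true if any llvm.assume was inserted, so the caller can decide
// which analyses to invalidate.
static bool salvageAllKnowledge(Function &F, AssumptionCache *AC,
                                DominatorTree *DT) {
  bool Changed = false;
  for (Instruction &I : instructions(F))
    Changed |= salvageKnowledge(&I, AC, DT);
  return Changed;
}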
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index 58a226fc601c..f06ea89cc61d 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -32,6 +32,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
@@ -379,8 +380,8 @@ bool llvm::MergeBlockSuccessorsIntoGivenBlocks(
///
/// Possible improvements:
/// - Check fully overlapping fragments and not only identical fragments.
-/// - Support dbg.addr, dbg.declare. dbg.label, and possibly other meta
-/// instructions being part of the sequence of consecutive instructions.
+/// - Support dbg.declare, dbg.label, and possibly other meta instructions
+///   being part of the sequence of consecutive instructions.
static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) {
SmallVector<DbgValueInst *, 8> ToBeRemoved;
SmallDenseSet<DebugVariable> VariableSet;
@@ -599,8 +600,8 @@ bool llvm::IsBlockFollowedByDeoptOrUnreachable(const BasicBlock *BB) {
unsigned Depth = 0;
while (BB && Depth++ < MaxDeoptOrUnreachableSuccessorCheckDepth &&
VisitedBlocks.insert(BB).second) {
- if (BB->getTerminatingDeoptimizeCall() ||
- isa<UnreachableInst>(BB->getTerminator()))
+ if (isa<UnreachableInst>(BB->getTerminator()) ||
+ BB->getTerminatingDeoptimizeCall())
return true;
BB = BB->getUniqueSuccessor();
}
@@ -1470,133 +1471,198 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB,
return cast<ReturnInst>(NewRet);
}
-static Instruction *
-SplitBlockAndInsertIfThenImpl(Value *Cond, Instruction *SplitBefore,
- bool Unreachable, MDNode *BranchWeights,
- DomTreeUpdater *DTU, DominatorTree *DT,
- LoopInfo *LI, BasicBlock *ThenBlock) {
- SmallVector<DominatorTree::UpdateType, 8> Updates;
- BasicBlock *Head = SplitBefore->getParent();
- BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator());
- if (DTU) {
- SmallPtrSet<BasicBlock *, 8> UniqueSuccessorsOfHead;
- Updates.push_back({DominatorTree::Insert, Head, Tail});
- Updates.reserve(Updates.size() + 2 * succ_size(Tail));
- for (BasicBlock *SuccessorOfHead : successors(Tail))
- if (UniqueSuccessorsOfHead.insert(SuccessorOfHead).second) {
- Updates.push_back({DominatorTree::Insert, Tail, SuccessorOfHead});
- Updates.push_back({DominatorTree::Delete, Head, SuccessorOfHead});
- }
- }
- Instruction *HeadOldTerm = Head->getTerminator();
- LLVMContext &C = Head->getContext();
- Instruction *CheckTerm;
- bool CreateThenBlock = (ThenBlock == nullptr);
- if (CreateThenBlock) {
- ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail);
- if (Unreachable)
- CheckTerm = new UnreachableInst(C, ThenBlock);
- else {
- CheckTerm = BranchInst::Create(Tail, ThenBlock);
- if (DTU)
- Updates.push_back({DominatorTree::Insert, ThenBlock, Tail});
- }
- CheckTerm->setDebugLoc(SplitBefore->getDebugLoc());
- } else
- CheckTerm = ThenBlock->getTerminator();
- BranchInst *HeadNewTerm =
- BranchInst::Create(/*ifTrue*/ ThenBlock, /*ifFalse*/ Tail, Cond);
- if (DTU)
- Updates.push_back({DominatorTree::Insert, Head, ThenBlock});
- HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights);
- ReplaceInstWithInst(HeadOldTerm, HeadNewTerm);
-
- if (DTU)
- DTU->applyUpdates(Updates);
- else if (DT) {
- if (DomTreeNode *OldNode = DT->getNode(Head)) {
- std::vector<DomTreeNode *> Children(OldNode->begin(), OldNode->end());
-
- DomTreeNode *NewNode = DT->addNewBlock(Tail, Head);
- for (DomTreeNode *Child : Children)
- DT->changeImmediateDominator(Child, NewNode);
-
- // Head dominates ThenBlock.
- if (CreateThenBlock)
- DT->addNewBlock(ThenBlock, Head);
- else
- DT->changeImmediateDominator(ThenBlock, Head);
- }
- }
-
- if (LI) {
- if (Loop *L = LI->getLoopFor(Head)) {
- L->addBasicBlockToLoop(ThenBlock, *LI);
- L->addBasicBlockToLoop(Tail, *LI);
- }
- }
-
- return CheckTerm;
-}
-
Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond,
Instruction *SplitBefore,
bool Unreachable,
MDNode *BranchWeights,
- DominatorTree *DT, LoopInfo *LI,
+ DomTreeUpdater *DTU, LoopInfo *LI,
BasicBlock *ThenBlock) {
- return SplitBlockAndInsertIfThenImpl(Cond, SplitBefore, Unreachable,
- BranchWeights,
- /*DTU=*/nullptr, DT, LI, ThenBlock);
+ SplitBlockAndInsertIfThenElse(
+ Cond, SplitBefore, &ThenBlock, /* ElseBlock */ nullptr,
+ /* UnreachableThen */ Unreachable,
+ /* UnreachableElse */ false, BranchWeights, DTU, LI);
+ return ThenBlock->getTerminator();
}
-Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond,
+
+Instruction *llvm::SplitBlockAndInsertIfElse(Value *Cond,
Instruction *SplitBefore,
bool Unreachable,
MDNode *BranchWeights,
DomTreeUpdater *DTU, LoopInfo *LI,
- BasicBlock *ThenBlock) {
- return SplitBlockAndInsertIfThenImpl(Cond, SplitBefore, Unreachable,
- BranchWeights, DTU, /*DT=*/nullptr, LI,
- ThenBlock);
+ BasicBlock *ElseBlock) {
+ SplitBlockAndInsertIfThenElse(
+ Cond, SplitBefore, /* ThenBlock */ nullptr, &ElseBlock,
+ /* UnreachableThen */ false,
+ /* UnreachableElse */ Unreachable, BranchWeights, DTU, LI);
+ return ElseBlock->getTerminator();
}
void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore,
Instruction **ThenTerm,
Instruction **ElseTerm,
MDNode *BranchWeights,
- DomTreeUpdater *DTU) {
- BasicBlock *Head = SplitBefore->getParent();
+ DomTreeUpdater *DTU, LoopInfo *LI) {
+ BasicBlock *ThenBlock = nullptr;
+ BasicBlock *ElseBlock = nullptr;
+ SplitBlockAndInsertIfThenElse(
+ Cond, SplitBefore, &ThenBlock, &ElseBlock, /* UnreachableThen */ false,
+ /* UnreachableElse */ false, BranchWeights, DTU, LI);
+
+ *ThenTerm = ThenBlock->getTerminator();
+ *ElseTerm = ElseBlock->getTerminator();
+}
+
+void llvm::SplitBlockAndInsertIfThenElse(
+ Value *Cond, Instruction *SplitBefore, BasicBlock **ThenBlock,
+ BasicBlock **ElseBlock, bool UnreachableThen, bool UnreachableElse,
+ MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI) {
+ assert((ThenBlock || ElseBlock) &&
+ "At least one branch block must be created");
+ assert((!UnreachableThen || !UnreachableElse) &&
+ "Split block tail must be reachable");
+ SmallVector<DominatorTree::UpdateType, 8> Updates;
SmallPtrSet<BasicBlock *, 8> UniqueOrigSuccessors;
- if (DTU)
+ BasicBlock *Head = SplitBefore->getParent();
+ if (DTU) {
UniqueOrigSuccessors.insert(succ_begin(Head), succ_end(Head));
+ Updates.reserve(4 + 2 * UniqueOrigSuccessors.size());
+ }
+ LLVMContext &C = Head->getContext();
BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator());
+ BasicBlock *TrueBlock = Tail;
+ BasicBlock *FalseBlock = Tail;
+ bool ThenToTailEdge = false;
+ bool ElseToTailEdge = false;
+
+  // Encapsulate the logic around creating and inserting a new block.
+ auto handleBlock = [&](BasicBlock **PBB, bool Unreachable, BasicBlock *&BB,
+ bool &ToTailEdge) {
+ if (PBB == nullptr)
+ return; // Do not create/insert a block.
+
+ if (*PBB)
+ BB = *PBB; // Caller supplied block, use it.
+ else {
+ // Create a new block.
+ BB = BasicBlock::Create(C, "", Head->getParent(), Tail);
+ if (Unreachable)
+ (void)new UnreachableInst(C, BB);
+ else {
+ (void)BranchInst::Create(Tail, BB);
+ ToTailEdge = true;
+ }
+ BB->getTerminator()->setDebugLoc(SplitBefore->getDebugLoc());
+ // Pass the new block back to the caller.
+ *PBB = BB;
+ }
+ };
+
+ handleBlock(ThenBlock, UnreachableThen, TrueBlock, ThenToTailEdge);
+ handleBlock(ElseBlock, UnreachableElse, FalseBlock, ElseToTailEdge);
+
Instruction *HeadOldTerm = Head->getTerminator();
- LLVMContext &C = Head->getContext();
- BasicBlock *ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail);
- BasicBlock *ElseBlock = BasicBlock::Create(C, "", Head->getParent(), Tail);
- *ThenTerm = BranchInst::Create(Tail, ThenBlock);
- (*ThenTerm)->setDebugLoc(SplitBefore->getDebugLoc());
- *ElseTerm = BranchInst::Create(Tail, ElseBlock);
- (*ElseTerm)->setDebugLoc(SplitBefore->getDebugLoc());
BranchInst *HeadNewTerm =
- BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/ElseBlock, Cond);
+ BranchInst::Create(/*ifTrue*/ TrueBlock, /*ifFalse*/ FalseBlock, Cond);
HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights);
ReplaceInstWithInst(HeadOldTerm, HeadNewTerm);
+
if (DTU) {
- SmallVector<DominatorTree::UpdateType, 8> Updates;
- Updates.reserve(4 + 2 * UniqueOrigSuccessors.size());
- for (BasicBlock *Succ : successors(Head)) {
- Updates.push_back({DominatorTree::Insert, Head, Succ});
- Updates.push_back({DominatorTree::Insert, Succ, Tail});
- }
+ Updates.emplace_back(DominatorTree::Insert, Head, TrueBlock);
+ Updates.emplace_back(DominatorTree::Insert, Head, FalseBlock);
+ if (ThenToTailEdge)
+ Updates.emplace_back(DominatorTree::Insert, TrueBlock, Tail);
+ if (ElseToTailEdge)
+ Updates.emplace_back(DominatorTree::Insert, FalseBlock, Tail);
for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors)
- Updates.push_back({DominatorTree::Insert, Tail, UniqueOrigSuccessor});
+ Updates.emplace_back(DominatorTree::Insert, Tail, UniqueOrigSuccessor);
for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors)
- Updates.push_back({DominatorTree::Delete, Head, UniqueOrigSuccessor});
+ Updates.emplace_back(DominatorTree::Delete, Head, UniqueOrigSuccessor);
DTU->applyUpdates(Updates);
}
+
+ if (LI) {
+ if (Loop *L = LI->getLoopFor(Head); L) {
+ if (ThenToTailEdge)
+ L->addBasicBlockToLoop(TrueBlock, *LI);
+ if (ElseToTailEdge)
+ L->addBasicBlockToLoop(FalseBlock, *LI);
+ L->addBasicBlockToLoop(Tail, *LI);
+ }
+ }
+}
+
+std::pair<Instruction*, Value*>
+llvm::SplitBlockAndInsertSimpleForLoop(Value *End, Instruction *SplitBefore) {
+ BasicBlock *LoopPred = SplitBefore->getParent();
+ BasicBlock *LoopBody = SplitBlock(SplitBefore->getParent(), SplitBefore);
+ BasicBlock *LoopExit = SplitBlock(SplitBefore->getParent(), SplitBefore);
+
+ auto *Ty = End->getType();
+ auto &DL = SplitBefore->getModule()->getDataLayout();
+ const unsigned Bitwidth = DL.getTypeSizeInBits(Ty);
+
+ IRBuilder<> Builder(LoopBody->getTerminator());
+ auto *IV = Builder.CreatePHI(Ty, 2, "iv");
+ auto *IVNext =
+ Builder.CreateAdd(IV, ConstantInt::get(Ty, 1), IV->getName() + ".next",
+ /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
+ auto *IVCheck = Builder.CreateICmpEQ(IVNext, End,
+ IV->getName() + ".check");
+ Builder.CreateCondBr(IVCheck, LoopExit, LoopBody);
+ LoopBody->getTerminator()->eraseFromParent();
+
+ // Populate the IV PHI.
+ IV->addIncoming(ConstantInt::get(Ty, 0), LoopPred);
+ IV->addIncoming(IVNext, LoopBody);
+
+ return std::make_pair(LoopBody->getFirstNonPHI(), IV);
+}
+
+void llvm::SplitBlockAndInsertForEachLane(ElementCount EC,
+ Type *IndexTy, Instruction *InsertBefore,
+ std::function<void(IRBuilderBase&, Value*)> Func) {
+
+ IRBuilder<> IRB(InsertBefore);
+
+ if (EC.isScalable()) {
+ Value *NumElements = IRB.CreateElementCount(IndexTy, EC);
+
+ auto [BodyIP, Index] =
+ SplitBlockAndInsertSimpleForLoop(NumElements, InsertBefore);
+
+ IRB.SetInsertPoint(BodyIP);
+ Func(IRB, Index);
+ return;
+ }
+
+ unsigned Num = EC.getFixedValue();
+ for (unsigned Idx = 0; Idx < Num; ++Idx) {
+ IRB.SetInsertPoint(InsertBefore);
+ Func(IRB, ConstantInt::get(IndexTy, Idx));
+ }
+}
+
+void llvm::SplitBlockAndInsertForEachLane(
+ Value *EVL, Instruction *InsertBefore,
+ std::function<void(IRBuilderBase &, Value *)> Func) {
+
+ IRBuilder<> IRB(InsertBefore);
+ Type *Ty = EVL->getType();
+
+ if (!isa<ConstantInt>(EVL)) {
+ auto [BodyIP, Index] = SplitBlockAndInsertSimpleForLoop(EVL, InsertBefore);
+ IRB.SetInsertPoint(BodyIP);
+ Func(IRB, Index);
+ return;
+ }
+
+ unsigned Num = cast<ConstantInt>(EVL)->getZExtValue();
+ for (unsigned Idx = 0; Idx < Num; ++Idx) {
+ IRB.SetInsertPoint(InsertBefore);
+ Func(IRB, ConstantInt::get(Ty, Idx));
+ }
}
BranchInst *llvm::GetIfCondition(BasicBlock *BB, BasicBlock *&IfTrue,
@@ -1997,3 +2063,17 @@ BasicBlock *llvm::CreateControlFlowHub(
return FirstGuardBlock;
}
+
+void llvm::InvertBranch(BranchInst *PBI, IRBuilderBase &Builder) {
+ Value *NewCond = PBI->getCondition();
+ // If this is a "cmp" instruction, only used for branching (and nowhere
+ // else), then we can simply invert the predicate.
+ if (NewCond->hasOneUse() && isa<CmpInst>(NewCond)) {
+ CmpInst *CI = cast<CmpInst>(NewCond);
+ CI->setPredicate(CI->getInversePredicate());
+ } else
+ NewCond = Builder.CreateNot(NewCond, NewCond->getName() + ".not");
+
+ PBI->setCondition(NewCond);
+ PBI->swapSuccessors();
+}
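A sketch of how the new SplitBlockAndInsertForEachLane helper might be used to emit per-lane code that also handles scalable element counts; the wrapper and the stored values are illustrative.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

// Store each lane index into Buffer[lane]; for scalable ECs this expands into
// a loop via SplitBlockAndInsertSimpleForLoop, for fixed ECs it is unrolled.
void storeLaneIndices(ElementCount EC, Value *Buffer,
                      Instruction *InsertBefore) {
  Type *I32 = Type::getInt32Ty(InsertBefore->getContext());
  SplitBlockAndInsertForEachLane(
      EC, I32, InsertBefore, [&](IRBuilderBase &B, Value *Index) {
        Value *Slot = B.CreateGEP(I32, Buffer, Index, "lane.slot");
        B.CreateStore(Index, Slot);
      });
}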
diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
index 1e21a2f85446..5de8ff84de77 100644
--- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
@@ -478,6 +478,8 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F,
case LibFunc_modfl:
Changed |= setDoesNotThrow(F);
Changed |= setWillReturn(F);
+ Changed |= setOnlyAccessesArgMemory(F);
+ Changed |= setOnlyWritesMemory(F);
Changed |= setDoesNotCapture(F, 1);
break;
case LibFunc_memcpy:
@@ -725,6 +727,8 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F,
case LibFunc_frexpl:
Changed |= setDoesNotThrow(F);
Changed |= setWillReturn(F);
+ Changed |= setOnlyAccessesArgMemory(F);
+ Changed |= setOnlyWritesMemory(F);
Changed |= setDoesNotCapture(F, 1);
break;
case LibFunc_fstatvfs:
@@ -1937,3 +1941,87 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, IRBuilderBase &B,
return CI;
}
+
+Value *llvm::emitHotColdNew(Value *Num, IRBuilderBase &B,
+ const TargetLibraryInfo *TLI, LibFunc NewFunc,
+ uint8_t HotCold) {
+ Module *M = B.GetInsertBlock()->getModule();
+ if (!isLibFuncEmittable(M, TLI, NewFunc))
+ return nullptr;
+
+ StringRef Name = TLI->getName(NewFunc);
+ FunctionCallee Func = M->getOrInsertFunction(Name, B.getInt8PtrTy(),
+ Num->getType(), B.getInt8Ty());
+ inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
+ CallInst *CI = B.CreateCall(Func, {Num, B.getInt8(HotCold)}, Name);
+
+ if (const Function *F =
+ dyn_cast<Function>(Func.getCallee()->stripPointerCasts()))
+ CI->setCallingConv(F->getCallingConv());
+
+ return CI;
+}
+
+Value *llvm::emitHotColdNewNoThrow(Value *Num, Value *NoThrow, IRBuilderBase &B,
+ const TargetLibraryInfo *TLI,
+ LibFunc NewFunc, uint8_t HotCold) {
+ Module *M = B.GetInsertBlock()->getModule();
+ if (!isLibFuncEmittable(M, TLI, NewFunc))
+ return nullptr;
+
+ StringRef Name = TLI->getName(NewFunc);
+ FunctionCallee Func =
+ M->getOrInsertFunction(Name, B.getInt8PtrTy(), Num->getType(),
+ NoThrow->getType(), B.getInt8Ty());
+ inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
+ CallInst *CI = B.CreateCall(Func, {Num, NoThrow, B.getInt8(HotCold)}, Name);
+
+ if (const Function *F =
+ dyn_cast<Function>(Func.getCallee()->stripPointerCasts()))
+ CI->setCallingConv(F->getCallingConv());
+
+ return CI;
+}
+
+Value *llvm::emitHotColdNewAligned(Value *Num, Value *Align, IRBuilderBase &B,
+ const TargetLibraryInfo *TLI,
+ LibFunc NewFunc, uint8_t HotCold) {
+ Module *M = B.GetInsertBlock()->getModule();
+ if (!isLibFuncEmittable(M, TLI, NewFunc))
+ return nullptr;
+
+ StringRef Name = TLI->getName(NewFunc);
+ FunctionCallee Func = M->getOrInsertFunction(
+ Name, B.getInt8PtrTy(), Num->getType(), Align->getType(), B.getInt8Ty());
+ inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
+ CallInst *CI = B.CreateCall(Func, {Num, Align, B.getInt8(HotCold)}, Name);
+
+ if (const Function *F =
+ dyn_cast<Function>(Func.getCallee()->stripPointerCasts()))
+ CI->setCallingConv(F->getCallingConv());
+
+ return CI;
+}
+
+Value *llvm::emitHotColdNewAlignedNoThrow(Value *Num, Value *Align,
+ Value *NoThrow, IRBuilderBase &B,
+ const TargetLibraryInfo *TLI,
+ LibFunc NewFunc, uint8_t HotCold) {
+ Module *M = B.GetInsertBlock()->getModule();
+ if (!isLibFuncEmittable(M, TLI, NewFunc))
+ return nullptr;
+
+ StringRef Name = TLI->getName(NewFunc);
+ FunctionCallee Func = M->getOrInsertFunction(
+ Name, B.getInt8PtrTy(), Num->getType(), Align->getType(),
+ NoThrow->getType(), B.getInt8Ty());
+ inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
+ CallInst *CI =
+ B.CreateCall(Func, {Num, Align, NoThrow, B.getInt8(HotCold)}, Name);
+
+ if (const Function *F =
+ dyn_cast<Function>(Func.getCallee()->stripPointerCasts()))
+ CI->setCallingConv(F->getCallingConv());
+
+ return CI;
+}
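A sketch of how a transform might call the new emitHotColdNew helper. The LibFunc enumerator spelled below is assumed to name the __hot_cold_t variant of operator new(size_t), and the wrapper itself is illustrative.

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"

using namespace llvm;

// Emit a hint-carrying allocation; LibFunc_Znwm12__hot_cold_t is assumed to be
// the enumerator for the __hot_cold_t variant of ::operator new(size_t).
Value *emitHintedNew(Value *Size, IRBuilderBase &B,
                     const TargetLibraryInfo *TLI, uint8_t Hint) {
  // Returns nullptr when the library function is not emittable for the target.
  return emitHotColdNew(Size, B, TLI, LibFunc_Znwm12__hot_cold_t, Hint);
}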
diff --git a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp
index 930a0bcbfac5..73a50b793e6d 100644
--- a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp
+++ b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp
@@ -202,7 +202,7 @@ bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) {
ConstantInt *C = dyn_cast<ConstantInt>(Op1);
if (!C && isa<BitCastInst>(Op1))
C = dyn_cast<ConstantInt>(cast<BitCastInst>(Op1)->getOperand(0));
- return C && C->getValue().getMinSignedBits() > BypassType->getBitWidth();
+ return C && C->getValue().getSignificantBits() > BypassType->getBitWidth();
}
case Instruction::PHI:
// Stop IR traversal in case of a crazy input code. This limits recursion
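getSignificantBits() is the current spelling of the former getMinSignedBits(); a small illustrative check of the values it returns (not part of this patch).

#include "llvm/ADT/APInt.h"
#include <cassert>

void significantBitsExamples() {
  using llvm::APInt;
  // -1 is representable in a single signed bit.
  assert(APInt(32, -1, /*isSigned=*/true).getSignificantBits() == 1);
  // 255 needs 9 bits as a signed value (8 magnitude bits plus the sign bit).
  assert(APInt(32, 255).getSignificantBits() == 9);
}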
diff --git a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
index d0b89ba2606e..d0b9884aa909 100644
--- a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
+++ b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
@@ -120,6 +120,8 @@ void CallGraphUpdater::removeFunction(Function &DeadFn) {
DeadCGN->removeAllCalledFunctions();
CGSCC->DeleteNode(DeadCGN);
}
+ if (FAM)
+ FAM->clear(DeadFn, DeadFn.getName());
}
void CallGraphUpdater::replaceFunctionWith(Function &OldFn, Function &NewFn) {
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index 4a82f9606d3f..b488e3bb0cbd 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -14,6 +14,7 @@
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
diff --git a/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp b/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp
index 4d622679dbdb..c24b6ed70405 100644
--- a/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp
+++ b/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp
@@ -31,8 +31,6 @@
#include "llvm/Transforms/Utils/CanonicalizeAliases.h"
#include "llvm/IR/Constants.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
using namespace llvm;
diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp
index 87822ee85c2b..d55208602b71 100644
--- a/llvm/lib/Transforms/Utils/CloneFunction.cpp
+++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp
@@ -470,9 +470,8 @@ void PruningFunctionCloner::CloneBlock(
// Nope, clone it now.
BasicBlock *NewBB;
- BBEntry = NewBB = BasicBlock::Create(BB->getContext());
- if (BB->hasName())
- NewBB->setName(BB->getName() + NameSuffix);
+ Twine NewName(BB->hasName() ? Twine(BB->getName()) + NameSuffix : "");
+ BBEntry = NewBB = BasicBlock::Create(BB->getContext(), NewName, NewFunc);
// It is only legal to clone a function if a block address within that
// function is never referenced outside of the function. Given that, we
@@ -498,6 +497,7 @@ void PruningFunctionCloner::CloneBlock(
++II) {
Instruction *NewInst = cloneInstruction(II);
+ NewInst->insertInto(NewBB, NewBB->end());
if (HostFuncIsStrictFP) {
// All function calls in the inlined function must get 'strictfp'
@@ -526,7 +526,7 @@ void PruningFunctionCloner::CloneBlock(
if (!NewInst->mayHaveSideEffects()) {
VMap[&*II] = V;
- NewInst->deleteValue();
+ NewInst->eraseFromParent();
continue;
}
}
@@ -535,7 +535,6 @@ void PruningFunctionCloner::CloneBlock(
if (II->hasName())
NewInst->setName(II->getName() + NameSuffix);
VMap[&*II] = NewInst; // Add instruction map to value.
- NewInst->insertInto(NewBB, NewBB->end());
if (isa<CallInst>(II) && !II->isDebugOrPseudoInst()) {
hasCalls = true;
hasMemProfMetadata |= II->hasMetadata(LLVMContext::MD_memprof);
@@ -683,8 +682,8 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc,
if (!NewBB)
continue; // Dead block.
- // Add the new block to the new function.
- NewFunc->insert(NewFunc->end(), NewBB);
+ // Move the new block to preserve the order in the original function.
+ NewBB->moveBefore(NewFunc->end());
// Handle PHI nodes specially, as we have to remove references to dead
// blocks.
@@ -937,8 +936,8 @@ void llvm::CloneAndPruneFunctionInto(
}
/// Remaps instructions in \p Blocks using the mapping in \p VMap.
-void llvm::remapInstructionsInBlocks(
- const SmallVectorImpl<BasicBlock *> &Blocks, ValueToValueMapTy &VMap) {
+void llvm::remapInstructionsInBlocks(ArrayRef<BasicBlock *> Blocks,
+ ValueToValueMapTy &VMap) {
// Rewrite the code to refer to itself.
for (auto *BB : Blocks)
for (auto &Inst : *BB)
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index c1fe10504e45..c390af351a69 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -918,6 +918,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::AllocKind:
case Attribute::PresplitCoroutine:
case Attribute::Memory:
+ case Attribute::NoFPClass:
continue;
// Those attributes should be safe to propagate to the extracted function.
case Attribute::AlwaysInline:
@@ -1091,32 +1092,20 @@ static void insertLifetimeMarkersSurroundingCall(
Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd,
CallInst *TheCall) {
LLVMContext &Ctx = M->getContext();
- auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
Instruction *Term = TheCall->getParent()->getTerminator();
- // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts
- // needed to satisfy this requirement so they may be reused.
- DenseMap<Value *, Value *> Bitcasts;
-
// Emit lifetime markers for the pointers given in \p Objects. Insert the
// markers before the call if \p InsertBefore, and after the call otherwise.
- auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects,
+ auto insertMarkers = [&](Intrinsic::ID MarkerFunc, ArrayRef<Value *> Objects,
bool InsertBefore) {
for (Value *Mem : Objects) {
assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() ==
TheCall->getFunction()) &&
"Input memory not defined in original function");
- Value *&MemAsI8Ptr = Bitcasts[Mem];
- if (!MemAsI8Ptr) {
- if (Mem->getType() == Int8PtrTy)
- MemAsI8Ptr = Mem;
- else
- MemAsI8Ptr =
- CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
- }
- auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr});
+ Function *Func = Intrinsic::getDeclaration(M, MarkerFunc, Mem->getType());
+ auto Marker = CallInst::Create(Func, {NegativeOne, Mem});
if (InsertBefore)
Marker->insertBefore(TheCall);
else
@@ -1125,15 +1114,13 @@ static void insertLifetimeMarkersSurroundingCall(
};
if (!LifetimesStart.empty()) {
- auto StartFn = llvm::Intrinsic::getDeclaration(
- M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
- insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true);
+ insertMarkers(Intrinsic::lifetime_start, LifetimesStart,
+ /*InsertBefore=*/true);
}
if (!LifetimesEnd.empty()) {
- auto EndFn = llvm::Intrinsic::getDeclaration(
- M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
- insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false);
+ insertMarkers(Intrinsic::lifetime_end, LifetimesEnd,
+ /*InsertBefore=*/false);
}
}
@@ -1663,14 +1650,14 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
}
}
- // Remove CondGuardInsts that will be moved to the new function from the old
- // function's assumption cache.
+ // Remove @llvm.assume calls that will be moved to the new function from the
+ // old function's assumption cache.
for (BasicBlock *Block : Blocks) {
for (Instruction &I : llvm::make_early_inc_range(*Block)) {
- if (auto *CI = dyn_cast<CondGuardInst>(&I)) {
+ if (auto *AI = dyn_cast<AssumeInst>(&I)) {
if (AC)
- AC->unregisterAssumption(CI);
- CI->eraseFromParent();
+ AC->unregisterAssumption(AI);
+ AI->eraseFromParent();
}
}
}
@@ -1864,7 +1851,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
const Function &NewFunc,
AssumptionCache *AC) {
for (auto AssumeVH : AC->assumptions()) {
- auto *I = dyn_cast_or_null<CondGuardInst>(AssumeVH);
+ auto *I = dyn_cast_or_null<CallInst>(AssumeVH);
if (!I)
continue;
@@ -1876,7 +1863,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
// that were previously in the old function, but that have now been moved
// to the new function.
for (auto AffectedValVH : AC->assumptionsFor(I->getOperand(0))) {
- auto *AffectedCI = dyn_cast_or_null<CondGuardInst>(AffectedValVH);
+ auto *AffectedCI = dyn_cast_or_null<CallInst>(AffectedValVH);
if (!AffectedCI)
continue;
if (AffectedCI->getFunction() != &OldFunc)
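A sketch of the bitcast-free lifetime-marker pattern used above, applied to a plain alloca; the helper function is illustrative.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Emit llvm.lifetime.start for AI without an i8* bitcast, letting the
// intrinsic be overloaded directly on the alloca's pointer type.
void emitLifetimeStart(AllocaInst *AI, Instruction *InsertBefore) {
  Module *M = InsertBefore->getModule();
  Function *Fn = Intrinsic::getDeclaration(M, Intrinsic::lifetime_start,
                                           AI->getType());
  IRBuilder<> B(InsertBefore);
  auto *NegOne = ConstantInt::getSigned(B.getInt64Ty(), -1); // unknown size
  B.CreateCall(Fn, {NegOne, AI});
}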
diff --git a/llvm/lib/Transforms/Utils/CodeLayout.cpp b/llvm/lib/Transforms/Utils/CodeLayout.cpp
index 9eb3aff3ffe8..ac74a1c116cc 100644
--- a/llvm/lib/Transforms/Utils/CodeLayout.cpp
+++ b/llvm/lib/Transforms/Utils/CodeLayout.cpp
@@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
-// ExtTSP - layout of basic blocks with i-cache optimization.
+// This file implements "cache-aware" layout algorithms for basic blocks and
+// functions in a binary.
//
// The algorithm tries to find a layout of nodes (basic blocks) of a given CFG
// optimizing jump locality and thus processor I-cache utilization. This is
@@ -41,12 +42,14 @@
#include "llvm/Transforms/Utils/CodeLayout.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include <cmath>
using namespace llvm;
#define DEBUG_TYPE "code-layout"
+namespace llvm {
cl::opt<bool> EnableExtTspBlockPlacement(
"enable-ext-tsp-block-placement", cl::Hidden, cl::init(false),
cl::desc("Enable machine block placement based on the ext-tsp model, "
@@ -56,6 +59,7 @@ cl::opt<bool> ApplyExtTspWithoutProfile(
"ext-tsp-apply-without-profile",
cl::desc("Whether to apply ext-tsp placement for instances w/o profile"),
cl::init(true), cl::Hidden);
+} // namespace llvm
// Algorithm-specific params. The values are tuned for the best performance
// of large-scale front-end bound binaries.
@@ -69,11 +73,11 @@ static cl::opt<double> ForwardWeightUncond(
static cl::opt<double> BackwardWeightCond(
"ext-tsp-backward-weight-cond", cl::ReallyHidden, cl::init(0.1),
- cl::desc("The weight of conditonal backward jumps for ExtTSP value"));
+ cl::desc("The weight of conditional backward jumps for ExtTSP value"));
static cl::opt<double> BackwardWeightUncond(
"ext-tsp-backward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
- cl::desc("The weight of unconditonal backward jumps for ExtTSP value"));
+ cl::desc("The weight of unconditional backward jumps for ExtTSP value"));
static cl::opt<double> FallthroughWeightCond(
"ext-tsp-fallthrough-weight-cond", cl::ReallyHidden, cl::init(1.0),
@@ -149,29 +153,30 @@ double extTSPScore(uint64_t SrcAddr, uint64_t SrcSize, uint64_t DstAddr,
/// A type of merging two chains, X and Y. The former chain is split into
/// X1 and X2 and then concatenated with Y in the order specified by the type.
-enum class MergeTypeTy : int { X_Y, X1_Y_X2, Y_X2_X1, X2_X1_Y };
+enum class MergeTypeT : int { X_Y, Y_X, X1_Y_X2, Y_X2_X1, X2_X1_Y };
/// The gain of merging two chains, that is, the Ext-TSP score of the merge
-/// together with the corresponfiding merge 'type' and 'offset'.
-class MergeGainTy {
-public:
- explicit MergeGainTy() = default;
- explicit MergeGainTy(double Score, size_t MergeOffset, MergeTypeTy MergeType)
+/// together with the corresponding merge 'type' and 'offset'.
+struct MergeGainT {
+ explicit MergeGainT() = default;
+ explicit MergeGainT(double Score, size_t MergeOffset, MergeTypeT MergeType)
: Score(Score), MergeOffset(MergeOffset), MergeType(MergeType) {}
double score() const { return Score; }
size_t mergeOffset() const { return MergeOffset; }
- MergeTypeTy mergeType() const { return MergeType; }
+ MergeTypeT mergeType() const { return MergeType; }
+
+ void setMergeType(MergeTypeT Ty) { MergeType = Ty; }
// Returns 'true' iff Other is preferred over this.
- bool operator<(const MergeGainTy &Other) const {
+ bool operator<(const MergeGainT &Other) const {
return (Other.Score > EPS && Other.Score > Score + EPS);
}
// Update the current gain if Other is preferred over this.
- void updateIfLessThan(const MergeGainTy &Other) {
+ void updateIfLessThan(const MergeGainT &Other) {
if (*this < Other)
*this = Other;
}
@@ -179,106 +184,102 @@ public:
private:
double Score{-1.0};
size_t MergeOffset{0};
- MergeTypeTy MergeType{MergeTypeTy::X_Y};
+ MergeTypeT MergeType{MergeTypeT::X_Y};
};
-class Jump;
-class Chain;
-class ChainEdge;
+struct JumpT;
+struct ChainT;
+struct ChainEdge;
-/// A node in the graph, typically corresponding to a basic block in CFG.
-class Block {
-public:
- Block(const Block &) = delete;
- Block(Block &&) = default;
- Block &operator=(const Block &) = delete;
- Block &operator=(Block &&) = default;
+/// A node in the graph, typically corresponding to a basic block in the CFG or
+/// a function in the call graph.
+struct NodeT {
+ NodeT(const NodeT &) = delete;
+ NodeT(NodeT &&) = default;
+ NodeT &operator=(const NodeT &) = delete;
+ NodeT &operator=(NodeT &&) = default;
+
+ explicit NodeT(size_t Index, uint64_t Size, uint64_t EC)
+ : Index(Index), Size(Size), ExecutionCount(EC) {}
+
+ bool isEntry() const { return Index == 0; }
+
+ // The total execution count of outgoing jumps.
+ uint64_t outCount() const;
+
+ // The total execution count of incoming jumps.
+ uint64_t inCount() const;
- // The original index of the block in CFG.
+ // The original index of the node in graph.
size_t Index{0};
- // The index of the block in the current chain.
+ // The index of the node in the current chain.
size_t CurIndex{0};
- // Size of the block in the binary.
+ // The size of the node in the binary.
uint64_t Size{0};
- // Execution count of the block in the profile data.
+ // The execution count of the node in the profile data.
uint64_t ExecutionCount{0};
- // Current chain of the node.
- Chain *CurChain{nullptr};
- // An offset of the block in the current chain.
+ // The current chain of the node.
+ ChainT *CurChain{nullptr};
+ // The offset of the node in the current chain.
mutable uint64_t EstimatedAddr{0};
- // Forced successor of the block in CFG.
- Block *ForcedSucc{nullptr};
- // Forced predecessor of the block in CFG.
- Block *ForcedPred{nullptr};
- // Outgoing jumps from the block.
- std::vector<Jump *> OutJumps;
- // Incoming jumps to the block.
- std::vector<Jump *> InJumps;
-
-public:
- explicit Block(size_t Index, uint64_t Size, uint64_t EC)
- : Index(Index), Size(Size), ExecutionCount(EC) {}
- bool isEntry() const { return Index == 0; }
+ // Forced successor of the node in the graph.
+ NodeT *ForcedSucc{nullptr};
+ // Forced predecessor of the node in the graph.
+ NodeT *ForcedPred{nullptr};
+ // Outgoing jumps from the node.
+ std::vector<JumpT *> OutJumps;
+ // Incoming jumps to the node.
+ std::vector<JumpT *> InJumps;
};
-/// An arc in the graph, typically corresponding to a jump between two blocks.
-class Jump {
-public:
- Jump(const Jump &) = delete;
- Jump(Jump &&) = default;
- Jump &operator=(const Jump &) = delete;
- Jump &operator=(Jump &&) = default;
-
- // Source block of the jump.
- Block *Source;
- // Target block of the jump.
- Block *Target;
+/// An arc in the graph, typically corresponding to a jump between two nodes.
+struct JumpT {
+ JumpT(const JumpT &) = delete;
+ JumpT(JumpT &&) = default;
+ JumpT &operator=(const JumpT &) = delete;
+ JumpT &operator=(JumpT &&) = default;
+
+ explicit JumpT(NodeT *Source, NodeT *Target, uint64_t ExecutionCount)
+ : Source(Source), Target(Target), ExecutionCount(ExecutionCount) {}
+
+ // Source node of the jump.
+ NodeT *Source;
+ // Target node of the jump.
+ NodeT *Target;
// Execution count of the arc in the profile data.
uint64_t ExecutionCount{0};
// Whether the jump corresponds to a conditional branch.
bool IsConditional{false};
-
-public:
- explicit Jump(Block *Source, Block *Target, uint64_t ExecutionCount)
- : Source(Source), Target(Target), ExecutionCount(ExecutionCount) {}
+ // The offset of the jump from the source node.
+ uint64_t Offset{0};
};
-/// A chain (ordered sequence) of blocks.
-class Chain {
-public:
- Chain(const Chain &) = delete;
- Chain(Chain &&) = default;
- Chain &operator=(const Chain &) = delete;
- Chain &operator=(Chain &&) = default;
+/// A chain (ordered sequence) of nodes in the graph.
+struct ChainT {
+ ChainT(const ChainT &) = delete;
+ ChainT(ChainT &&) = default;
+ ChainT &operator=(const ChainT &) = delete;
+ ChainT &operator=(ChainT &&) = default;
+
+ explicit ChainT(uint64_t Id, NodeT *Node)
+ : Id(Id), ExecutionCount(Node->ExecutionCount), Size(Node->Size),
+ Nodes(1, Node) {}
- explicit Chain(uint64_t Id, Block *Block)
- : Id(Id), Score(0), Blocks(1, Block) {}
+ size_t numBlocks() const { return Nodes.size(); }
- uint64_t id() const { return Id; }
+ double density() const { return static_cast<double>(ExecutionCount) / Size; }
- bool isEntry() const { return Blocks[0]->Index == 0; }
+ bool isEntry() const { return Nodes[0]->Index == 0; }
bool isCold() const {
- for (auto *Block : Blocks) {
- if (Block->ExecutionCount > 0)
+ for (NodeT *Node : Nodes) {
+ if (Node->ExecutionCount > 0)
return false;
}
return true;
}
- double score() const { return Score; }
-
- void setScore(double NewScore) { Score = NewScore; }
-
- const std::vector<Block *> &blocks() const { return Blocks; }
-
- size_t numBlocks() const { return Blocks.size(); }
-
- const std::vector<std::pair<Chain *, ChainEdge *>> &edges() const {
- return Edges;
- }
-
- ChainEdge *getEdge(Chain *Other) const {
+ ChainEdge *getEdge(ChainT *Other) const {
for (auto It : Edges) {
if (It.first == Other)
return It.second;
@@ -286,7 +287,7 @@ public:
return nullptr;
}
- void removeEdge(Chain *Other) {
+ void removeEdge(ChainT *Other) {
auto It = Edges.begin();
while (It != Edges.end()) {
if (It->first == Other) {
@@ -297,63 +298,68 @@ public:
}
}
- void addEdge(Chain *Other, ChainEdge *Edge) {
+ void addEdge(ChainT *Other, ChainEdge *Edge) {
Edges.push_back(std::make_pair(Other, Edge));
}
- void merge(Chain *Other, const std::vector<Block *> &MergedBlocks) {
- Blocks = MergedBlocks;
- // Update the block's chains
- for (size_t Idx = 0; Idx < Blocks.size(); Idx++) {
- Blocks[Idx]->CurChain = this;
- Blocks[Idx]->CurIndex = Idx;
+ void merge(ChainT *Other, const std::vector<NodeT *> &MergedBlocks) {
+ Nodes = MergedBlocks;
+ // Update the chain's data
+ ExecutionCount += Other->ExecutionCount;
+ Size += Other->Size;
+ Id = Nodes[0]->Index;
+ // Update the node's data
+ for (size_t Idx = 0; Idx < Nodes.size(); Idx++) {
+ Nodes[Idx]->CurChain = this;
+ Nodes[Idx]->CurIndex = Idx;
}
}
- void mergeEdges(Chain *Other);
+ void mergeEdges(ChainT *Other);
void clear() {
- Blocks.clear();
- Blocks.shrink_to_fit();
+ Nodes.clear();
+ Nodes.shrink_to_fit();
Edges.clear();
Edges.shrink_to_fit();
}
-private:
// Unique chain identifier.
uint64_t Id;
// Cached ext-tsp score for the chain.
- double Score;
- // Blocks of the chain.
- std::vector<Block *> Blocks;
+ double Score{0};
+ // The total execution count of the chain.
+ uint64_t ExecutionCount{0};
+ // The total size of the chain.
+ uint64_t Size{0};
+ // Nodes of the chain.
+ std::vector<NodeT *> Nodes;
// Adjacent chains and corresponding edges (lists of jumps).
- std::vector<std::pair<Chain *, ChainEdge *>> Edges;
+ std::vector<std::pair<ChainT *, ChainEdge *>> Edges;
};
-/// An edge in CFG representing jumps between two chains.
-/// When blocks are merged into chains, the edges are combined too so that
+/// An edge in the graph representing jumps between two chains.
+/// When nodes are merged into chains, the edges are combined too so that
/// there is always at most one edge between a pair of chains
-class ChainEdge {
-public:
+struct ChainEdge {
ChainEdge(const ChainEdge &) = delete;
ChainEdge(ChainEdge &&) = default;
ChainEdge &operator=(const ChainEdge &) = delete;
- ChainEdge &operator=(ChainEdge &&) = default;
+ ChainEdge &operator=(ChainEdge &&) = delete;
- explicit ChainEdge(Jump *Jump)
+ explicit ChainEdge(JumpT *Jump)
: SrcChain(Jump->Source->CurChain), DstChain(Jump->Target->CurChain),
Jumps(1, Jump) {}
- const std::vector<Jump *> &jumps() const { return Jumps; }
+ ChainT *srcChain() const { return SrcChain; }
- void changeEndpoint(Chain *From, Chain *To) {
- if (From == SrcChain)
- SrcChain = To;
- if (From == DstChain)
- DstChain = To;
- }
+ ChainT *dstChain() const { return DstChain; }
+
+ bool isSelfEdge() const { return SrcChain == DstChain; }
- void appendJump(Jump *Jump) { Jumps.push_back(Jump); }
+ const std::vector<JumpT *> &jumps() const { return Jumps; }
+
+ void appendJump(JumpT *Jump) { Jumps.push_back(Jump); }
void moveJumps(ChainEdge *Other) {
Jumps.insert(Jumps.end(), Other->Jumps.begin(), Other->Jumps.end());
@@ -361,15 +367,22 @@ public:
Other->Jumps.shrink_to_fit();
}
- bool hasCachedMergeGain(Chain *Src, Chain *Dst) const {
+ void changeEndpoint(ChainT *From, ChainT *To) {
+ if (From == SrcChain)
+ SrcChain = To;
+ if (From == DstChain)
+ DstChain = To;
+ }
+
+ bool hasCachedMergeGain(ChainT *Src, ChainT *Dst) const {
return Src == SrcChain ? CacheValidForward : CacheValidBackward;
}
- MergeGainTy getCachedMergeGain(Chain *Src, Chain *Dst) const {
+ MergeGainT getCachedMergeGain(ChainT *Src, ChainT *Dst) const {
return Src == SrcChain ? CachedGainForward : CachedGainBackward;
}
- void setCachedMergeGain(Chain *Src, Chain *Dst, MergeGainTy MergeGain) {
+ void setCachedMergeGain(ChainT *Src, ChainT *Dst, MergeGainT MergeGain) {
if (Src == SrcChain) {
CachedGainForward = MergeGain;
CacheValidForward = true;
@@ -384,31 +397,55 @@ public:
CacheValidBackward = false;
}
+ void setMergeGain(MergeGainT Gain) { CachedGain = Gain; }
+
+ MergeGainT getMergeGain() const { return CachedGain; }
+
+ double gain() const { return CachedGain.score(); }
+
private:
// Source chain.
- Chain *SrcChain{nullptr};
+ ChainT *SrcChain{nullptr};
// Destination chain.
- Chain *DstChain{nullptr};
- // Original jumps in the binary with correspinding execution counts.
- std::vector<Jump *> Jumps;
- // Cached ext-tsp value for merging the pair of chains.
- // Since the gain of merging (Src, Dst) and (Dst, Src) might be different,
- // we store both values here.
- MergeGainTy CachedGainForward;
- MergeGainTy CachedGainBackward;
+ ChainT *DstChain{nullptr};
+ // Original jumps in the binary with corresponding execution counts.
+ std::vector<JumpT *> Jumps;
+ // Cached gain value for merging the pair of chains.
+ MergeGainT CachedGain;
+
+  // Cached gain values for merging the pair of chains. Since the gain of
+  // merging (Src, Dst) and (Dst, Src) might be different, both values are
+  // stored here.
+ MergeGainT CachedGainForward;
+ MergeGainT CachedGainBackward;
// Whether the cached value must be recomputed.
bool CacheValidForward{false};
bool CacheValidBackward{false};
};
-void Chain::mergeEdges(Chain *Other) {
- assert(this != Other && "cannot merge a chain with itself");
+uint64_t NodeT::outCount() const {
+ uint64_t Count = 0;
+ for (JumpT *Jump : OutJumps) {
+ Count += Jump->ExecutionCount;
+ }
+ return Count;
+}
+uint64_t NodeT::inCount() const {
+ uint64_t Count = 0;
+ for (JumpT *Jump : InJumps) {
+ Count += Jump->ExecutionCount;
+ }
+ return Count;
+}
+
+void ChainT::mergeEdges(ChainT *Other) {
// Update edges adjacent to chain Other
for (auto EdgeIt : Other->Edges) {
- Chain *DstChain = EdgeIt.first;
+ ChainT *DstChain = EdgeIt.first;
ChainEdge *DstEdge = EdgeIt.second;
- Chain *TargetChain = DstChain == Other ? this : DstChain;
+ ChainT *TargetChain = DstChain == Other ? this : DstChain;
ChainEdge *CurEdge = getEdge(TargetChain);
if (CurEdge == nullptr) {
DstEdge->changeEndpoint(Other, this);
@@ -426,15 +463,14 @@ void Chain::mergeEdges(Chain *Other) {
}
}
-using BlockIter = std::vector<Block *>::const_iterator;
+using NodeIter = std::vector<NodeT *>::const_iterator;
-/// A wrapper around three chains of blocks; it is used to avoid extra
+/// A wrapper around three chains of nodes; it is used to avoid extra
/// instantiation of the vectors.
-class MergedChain {
-public:
- MergedChain(BlockIter Begin1, BlockIter End1, BlockIter Begin2 = BlockIter(),
- BlockIter End2 = BlockIter(), BlockIter Begin3 = BlockIter(),
- BlockIter End3 = BlockIter())
+struct MergedChain {
+ MergedChain(NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(),
+ NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(),
+ NodeIter End3 = NodeIter())
: Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
End3(End3) {}
@@ -447,8 +483,8 @@ public:
Func(*It);
}
- std::vector<Block *> getBlocks() const {
- std::vector<Block *> Result;
+ std::vector<NodeT *> getNodes() const {
+ std::vector<NodeT *> Result;
Result.reserve(std::distance(Begin1, End1) + std::distance(Begin2, End2) +
std::distance(Begin3, End3));
Result.insert(Result.end(), Begin1, End1);
@@ -457,42 +493,71 @@ public:
return Result;
}
- const Block *getFirstBlock() const { return *Begin1; }
+ const NodeT *getFirstNode() const { return *Begin1; }
private:
- BlockIter Begin1;
- BlockIter End1;
- BlockIter Begin2;
- BlockIter End2;
- BlockIter Begin3;
- BlockIter End3;
+ NodeIter Begin1;
+ NodeIter End1;
+ NodeIter Begin2;
+ NodeIter End2;
+ NodeIter Begin3;
+ NodeIter End3;
};
+/// Merge two chains of nodes respecting a given 'type' and 'offset'.
+///
+/// If MergeType is X_Y or Y_X, the result is a plain concatenation of the two
+/// chains. Otherwise, the first chain is split into two sub-chains at the
+/// given offset and concatenated with the second chain in the specified order.
+MergedChain mergeNodes(const std::vector<NodeT *> &X,
+ const std::vector<NodeT *> &Y, size_t MergeOffset,
+ MergeTypeT MergeType) {
+ // Split the first chain, X, into X1 and X2
+ NodeIter BeginX1 = X.begin();
+ NodeIter EndX1 = X.begin() + MergeOffset;
+ NodeIter BeginX2 = X.begin() + MergeOffset;
+ NodeIter EndX2 = X.end();
+ NodeIter BeginY = Y.begin();
+ NodeIter EndY = Y.end();
+
+ // Construct a new chain from the three existing ones
+ switch (MergeType) {
+ case MergeTypeT::X_Y:
+ return MergedChain(BeginX1, EndX2, BeginY, EndY);
+ case MergeTypeT::Y_X:
+ return MergedChain(BeginY, EndY, BeginX1, EndX2);
+ case MergeTypeT::X1_Y_X2:
+ return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
+ case MergeTypeT::Y_X2_X1:
+ return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
+ case MergeTypeT::X2_X1_Y:
+ return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
+ }
+ llvm_unreachable("unexpected chain merge type");
+}
+
/// The implementation of the ExtTSP algorithm.
class ExtTSPImpl {
- using EdgeT = std::pair<uint64_t, uint64_t>;
- using EdgeCountMap = std::vector<std::pair<EdgeT, uint64_t>>;
-
public:
- ExtTSPImpl(size_t NumNodes, const std::vector<uint64_t> &NodeSizes,
+ ExtTSPImpl(const std::vector<uint64_t> &NodeSizes,
const std::vector<uint64_t> &NodeCounts,
- const EdgeCountMap &EdgeCounts)
- : NumNodes(NumNodes) {
+ const std::vector<EdgeCountT> &EdgeCounts)
+ : NumNodes(NodeSizes.size()) {
initialize(NodeSizes, NodeCounts, EdgeCounts);
}
- /// Run the algorithm and return an optimized ordering of blocks.
+ /// Run the algorithm and return an optimized ordering of nodes.
void run(std::vector<uint64_t> &Result) {
- // Pass 1: Merge blocks with their mutually forced successors
+ // Pass 1: Merge nodes with their mutually forced successors
mergeForcedPairs();
// Pass 2: Merge pairs of chains while improving the ExtTSP objective
mergeChainPairs();
- // Pass 3: Merge cold blocks to reduce code size
+ // Pass 3: Merge cold nodes to reduce code size
mergeColdChains();
- // Collect blocks from all chains
+ // Collect nodes from all chains
concatChains(Result);
}
@@ -500,26 +565,26 @@ private:
/// Initialize the algorithm's data structures.
void initialize(const std::vector<uint64_t> &NodeSizes,
const std::vector<uint64_t> &NodeCounts,
- const EdgeCountMap &EdgeCounts) {
- // Initialize blocks
- AllBlocks.reserve(NumNodes);
- for (uint64_t Node = 0; Node < NumNodes; Node++) {
- uint64_t Size = std::max<uint64_t>(NodeSizes[Node], 1ULL);
- uint64_t ExecutionCount = NodeCounts[Node];
- // The execution count of the entry block is set to at least 1
- if (Node == 0 && ExecutionCount == 0)
+ const std::vector<EdgeCountT> &EdgeCounts) {
+ // Initialize nodes
+ AllNodes.reserve(NumNodes);
+ for (uint64_t Idx = 0; Idx < NumNodes; Idx++) {
+ uint64_t Size = std::max<uint64_t>(NodeSizes[Idx], 1ULL);
+ uint64_t ExecutionCount = NodeCounts[Idx];
+ // The execution count of the entry node is set to at least one
+ if (Idx == 0 && ExecutionCount == 0)
ExecutionCount = 1;
- AllBlocks.emplace_back(Node, Size, ExecutionCount);
+ AllNodes.emplace_back(Idx, Size, ExecutionCount);
}
- // Initialize jumps between blocks
+ // Initialize jumps between nodes
SuccNodes.resize(NumNodes);
PredNodes.resize(NumNodes);
std::vector<uint64_t> OutDegree(NumNodes, 0);
AllJumps.reserve(EdgeCounts.size());
for (auto It : EdgeCounts) {
- auto Pred = It.first.first;
- auto Succ = It.first.second;
+ uint64_t Pred = It.first.first;
+ uint64_t Succ = It.first.second;
OutDegree[Pred]++;
// Ignore self-edges
if (Pred == Succ)
@@ -527,16 +592,16 @@ private:
SuccNodes[Pred].push_back(Succ);
PredNodes[Succ].push_back(Pred);
- auto ExecutionCount = It.second;
+ uint64_t ExecutionCount = It.second;
if (ExecutionCount > 0) {
- auto &Block = AllBlocks[Pred];
- auto &SuccBlock = AllBlocks[Succ];
- AllJumps.emplace_back(&Block, &SuccBlock, ExecutionCount);
- SuccBlock.InJumps.push_back(&AllJumps.back());
- Block.OutJumps.push_back(&AllJumps.back());
+ NodeT &PredNode = AllNodes[Pred];
+ NodeT &SuccNode = AllNodes[Succ];
+ AllJumps.emplace_back(&PredNode, &SuccNode, ExecutionCount);
+ SuccNode.InJumps.push_back(&AllJumps.back());
+ PredNode.OutJumps.push_back(&AllJumps.back());
}
}
- for (auto &Jump : AllJumps) {
+ for (JumpT &Jump : AllJumps) {
assert(OutDegree[Jump.Source->Index] > 0);
Jump.IsConditional = OutDegree[Jump.Source->Index] > 1;
}
@@ -544,78 +609,78 @@ private:
// Initialize chains
AllChains.reserve(NumNodes);
HotChains.reserve(NumNodes);
- for (Block &Block : AllBlocks) {
- AllChains.emplace_back(Block.Index, &Block);
- Block.CurChain = &AllChains.back();
- if (Block.ExecutionCount > 0) {
+ for (NodeT &Node : AllNodes) {
+ AllChains.emplace_back(Node.Index, &Node);
+ Node.CurChain = &AllChains.back();
+ if (Node.ExecutionCount > 0) {
HotChains.push_back(&AllChains.back());
}
}
// Initialize chain edges
AllEdges.reserve(AllJumps.size());
- for (Block &Block : AllBlocks) {
- for (auto &Jump : Block.OutJumps) {
- auto SuccBlock = Jump->Target;
- ChainEdge *CurEdge = Block.CurChain->getEdge(SuccBlock->CurChain);
+ for (NodeT &PredNode : AllNodes) {
+ for (JumpT *Jump : PredNode.OutJumps) {
+ NodeT *SuccNode = Jump->Target;
+ ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
// this edge is already present in the graph
if (CurEdge != nullptr) {
- assert(SuccBlock->CurChain->getEdge(Block.CurChain) != nullptr);
+ assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
CurEdge->appendJump(Jump);
continue;
}
// this is a new edge
AllEdges.emplace_back(Jump);
- Block.CurChain->addEdge(SuccBlock->CurChain, &AllEdges.back());
- SuccBlock->CurChain->addEdge(Block.CurChain, &AllEdges.back());
+ PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
+ SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
}
}
}
- /// For a pair of blocks, A and B, block B is the forced successor of A,
+ /// For a pair of nodes, A and B, node B is the forced successor of A,
/// if (i) all jumps (based on profile) from A goes to B and (ii) all jumps
- /// to B are from A. Such blocks should be adjacent in the optimal ordering;
- /// the method finds and merges such pairs of blocks.
+ /// to B are from A. Such nodes should be adjacent in the optimal ordering;
+ /// the method finds and merges such pairs of nodes.
void mergeForcedPairs() {
// Find fallthroughs based on edge weights
- for (auto &Block : AllBlocks) {
- if (SuccNodes[Block.Index].size() == 1 &&
- PredNodes[SuccNodes[Block.Index][0]].size() == 1 &&
- SuccNodes[Block.Index][0] != 0) {
- size_t SuccIndex = SuccNodes[Block.Index][0];
- Block.ForcedSucc = &AllBlocks[SuccIndex];
- AllBlocks[SuccIndex].ForcedPred = &Block;
+ for (NodeT &Node : AllNodes) {
+ if (SuccNodes[Node.Index].size() == 1 &&
+ PredNodes[SuccNodes[Node.Index][0]].size() == 1 &&
+ SuccNodes[Node.Index][0] != 0) {
+ size_t SuccIndex = SuccNodes[Node.Index][0];
+ Node.ForcedSucc = &AllNodes[SuccIndex];
+ AllNodes[SuccIndex].ForcedPred = &Node;
}
}
// There might be 'cycles' in the forced dependencies, since profile
// data isn't 100% accurate. Typically this is observed in loops, when the
// loop edges are the hottest successors for the basic blocks of the loop.
- // Break the cycles by choosing the block with the smallest index as the
+ // Break the cycles by choosing the node with the smallest index as the
// head. This helps to keep the original order of the loops, which likely
// have already been rotated in the optimized manner.
- for (auto &Block : AllBlocks) {
- if (Block.ForcedSucc == nullptr || Block.ForcedPred == nullptr)
+ for (NodeT &Node : AllNodes) {
+ if (Node.ForcedSucc == nullptr || Node.ForcedPred == nullptr)
continue;
- auto SuccBlock = Block.ForcedSucc;
- while (SuccBlock != nullptr && SuccBlock != &Block) {
- SuccBlock = SuccBlock->ForcedSucc;
+ NodeT *SuccNode = Node.ForcedSucc;
+ while (SuccNode != nullptr && SuccNode != &Node) {
+ SuccNode = SuccNode->ForcedSucc;
}
- if (SuccBlock == nullptr)
+ if (SuccNode == nullptr)
continue;
// Break the cycle
- AllBlocks[Block.ForcedPred->Index].ForcedSucc = nullptr;
- Block.ForcedPred = nullptr;
+ AllNodes[Node.ForcedPred->Index].ForcedSucc = nullptr;
+ Node.ForcedPred = nullptr;
}
- // Merge blocks with their fallthrough successors
- for (auto &Block : AllBlocks) {
- if (Block.ForcedPred == nullptr && Block.ForcedSucc != nullptr) {
- auto CurBlock = &Block;
+ // Merge nodes with their fallthrough successors
+ for (NodeT &Node : AllNodes) {
+ if (Node.ForcedPred == nullptr && Node.ForcedSucc != nullptr) {
+ const NodeT *CurBlock = &Node;
while (CurBlock->ForcedSucc != nullptr) {
- const auto NextBlock = CurBlock->ForcedSucc;
- mergeChains(Block.CurChain, NextBlock->CurChain, 0, MergeTypeTy::X_Y);
+ const NodeT *NextBlock = CurBlock->ForcedSucc;
+ mergeChains(Node.CurChain, NextBlock->CurChain, 0, MergeTypeT::X_Y);
CurBlock = NextBlock;
}
}
@@ -625,23 +690,23 @@ private:
/// Merge pairs of chains while improving the ExtTSP objective.
void mergeChainPairs() {
/// Deterministically compare pairs of chains
- auto compareChainPairs = [](const Chain *A1, const Chain *B1,
- const Chain *A2, const Chain *B2) {
+ auto compareChainPairs = [](const ChainT *A1, const ChainT *B1,
+ const ChainT *A2, const ChainT *B2) {
if (A1 != A2)
- return A1->id() < A2->id();
- return B1->id() < B2->id();
+ return A1->Id < A2->Id;
+ return B1->Id < B2->Id;
};
while (HotChains.size() > 1) {
- Chain *BestChainPred = nullptr;
- Chain *BestChainSucc = nullptr;
- auto BestGain = MergeGainTy();
+ ChainT *BestChainPred = nullptr;
+ ChainT *BestChainSucc = nullptr;
+ MergeGainT BestGain;
// Iterate over all pairs of chains
- for (Chain *ChainPred : HotChains) {
+ for (ChainT *ChainPred : HotChains) {
// Get candidates for merging with the current chain
- for (auto EdgeIter : ChainPred->edges()) {
- Chain *ChainSucc = EdgeIter.first;
- class ChainEdge *ChainEdge = EdgeIter.second;
+ for (auto EdgeIt : ChainPred->Edges) {
+ ChainT *ChainSucc = EdgeIt.first;
+ ChainEdge *Edge = EdgeIt.second;
// Ignore loop edges
if (ChainPred == ChainSucc)
continue;
@@ -651,8 +716,7 @@ private:
continue;
// Compute the gain of merging the two chains
- MergeGainTy CurGain =
- getBestMergeGain(ChainPred, ChainSucc, ChainEdge);
+ MergeGainT CurGain = getBestMergeGain(ChainPred, ChainSucc, Edge);
if (CurGain.score() <= EPS)
continue;
@@ -677,43 +741,43 @@ private:
}
}
- /// Merge remaining blocks into chains w/o taking jump counts into
- /// consideration. This allows to maintain the original block order in the
- /// absense of profile data
+  /// Merge remaining nodes into chains w/o taking jump counts into
+  /// consideration. This allows maintaining the original node order in the
+  /// absence of profile data.
void mergeColdChains() {
for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) {
// Iterating in reverse order to make sure original fallthrough jumps are
// merged first; this might be beneficial for code size.
size_t NumSuccs = SuccNodes[SrcBB].size();
for (size_t Idx = 0; Idx < NumSuccs; Idx++) {
- auto DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1];
- auto SrcChain = AllBlocks[SrcBB].CurChain;
- auto DstChain = AllBlocks[DstBB].CurChain;
+ size_t DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1];
+ ChainT *SrcChain = AllNodes[SrcBB].CurChain;
+ ChainT *DstChain = AllNodes[DstBB].CurChain;
if (SrcChain != DstChain && !DstChain->isEntry() &&
- SrcChain->blocks().back()->Index == SrcBB &&
- DstChain->blocks().front()->Index == DstBB &&
+ SrcChain->Nodes.back()->Index == SrcBB &&
+ DstChain->Nodes.front()->Index == DstBB &&
SrcChain->isCold() == DstChain->isCold()) {
- mergeChains(SrcChain, DstChain, 0, MergeTypeTy::X_Y);
+ mergeChains(SrcChain, DstChain, 0, MergeTypeT::X_Y);
}
}
}
}
- /// Compute the Ext-TSP score for a given block order and a list of jumps.
+ /// Compute the Ext-TSP score for a given node order and a list of jumps.
double extTSPScore(const MergedChain &MergedBlocks,
- const std::vector<Jump *> &Jumps) const {
+ const std::vector<JumpT *> &Jumps) const {
if (Jumps.empty())
return 0.0;
uint64_t CurAddr = 0;
- MergedBlocks.forEach([&](const Block *BB) {
- BB->EstimatedAddr = CurAddr;
- CurAddr += BB->Size;
+ MergedBlocks.forEach([&](const NodeT *Node) {
+ Node->EstimatedAddr = CurAddr;
+ CurAddr += Node->Size;
});
double Score = 0;
- for (auto &Jump : Jumps) {
- const Block *SrcBlock = Jump->Source;
- const Block *DstBlock = Jump->Target;
+ for (JumpT *Jump : Jumps) {
+ const NodeT *SrcBlock = Jump->Source;
+ const NodeT *DstBlock = Jump->Target;
Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size,
DstBlock->EstimatedAddr, Jump->ExecutionCount,
Jump->IsConditional);
@@ -727,8 +791,8 @@ private:
/// computes the one having the largest increase in ExtTSP objective. The
/// result is a pair with the first element being the gain and the second
/// element being the corresponding merging type.
- MergeGainTy getBestMergeGain(Chain *ChainPred, Chain *ChainSucc,
- ChainEdge *Edge) const {
+ MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
+ ChainEdge *Edge) const {
if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) {
return Edge->getCachedMergeGain(ChainPred, ChainSucc);
}
@@ -742,22 +806,22 @@ private:
assert(!Jumps.empty() && "trying to merge chains w/o jumps");
// The object holds the best currently chosen gain of merging the two chains
- MergeGainTy Gain = MergeGainTy();
+ MergeGainT Gain = MergeGainT();
/// Given a merge offset and a list of merge types, try to merge two chains
/// and update Gain with a better alternative
auto tryChainMerging = [&](size_t Offset,
- const std::vector<MergeTypeTy> &MergeTypes) {
+ const std::vector<MergeTypeT> &MergeTypes) {
// Skip merging corresponding to concatenation w/o splitting
- if (Offset == 0 || Offset == ChainPred->blocks().size())
+ if (Offset == 0 || Offset == ChainPred->Nodes.size())
return;
// Skip merging if it breaks Forced successors
- auto BB = ChainPred->blocks()[Offset - 1];
- if (BB->ForcedSucc != nullptr)
+ NodeT *Node = ChainPred->Nodes[Offset - 1];
+ if (Node->ForcedSucc != nullptr)
return;
// Apply the merge, compute the corresponding gain, and update the best
// value, if the merge is beneficial
- for (const auto &MergeType : MergeTypes) {
+ for (const MergeTypeT &MergeType : MergeTypes) {
Gain.updateIfLessThan(
computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType));
}
@@ -765,36 +829,36 @@ private:
// Try to concatenate two chains w/o splitting
Gain.updateIfLessThan(
- computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeTy::X_Y));
+ computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeT::X_Y));
if (EnableChainSplitAlongJumps) {
- // Attach (a part of) ChainPred before the first block of ChainSucc
- for (auto &Jump : ChainSucc->blocks().front()->InJumps) {
- const auto SrcBlock = Jump->Source;
+ // Attach (a part of) ChainPred before the first node of ChainSucc
+ for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) {
+ const NodeT *SrcBlock = Jump->Source;
if (SrcBlock->CurChain != ChainPred)
continue;
size_t Offset = SrcBlock->CurIndex + 1;
- tryChainMerging(Offset, {MergeTypeTy::X1_Y_X2, MergeTypeTy::X2_X1_Y});
+ tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y});
}
- // Attach (a part of) ChainPred after the last block of ChainSucc
- for (auto &Jump : ChainSucc->blocks().back()->OutJumps) {
- const auto DstBlock = Jump->Source;
+ // Attach (a part of) ChainPred after the last node of ChainSucc
+ for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) {
+ const NodeT *DstBlock = Jump->Source;
if (DstBlock->CurChain != ChainPred)
continue;
size_t Offset = DstBlock->CurIndex;
- tryChainMerging(Offset, {MergeTypeTy::X1_Y_X2, MergeTypeTy::Y_X2_X1});
+ tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1});
}
}
// Try to break ChainPred in various ways and concatenate with ChainSucc
- if (ChainPred->blocks().size() <= ChainSplitThreshold) {
- for (size_t Offset = 1; Offset < ChainPred->blocks().size(); Offset++) {
+ if (ChainPred->Nodes.size() <= ChainSplitThreshold) {
+ for (size_t Offset = 1; Offset < ChainPred->Nodes.size(); Offset++) {
// Try to split the chain in different ways. In practice, applying
// X2_Y_X1 merging almost never provides benefits; thus, we exclude
// it from consideration to reduce the search space
- tryChainMerging(Offset, {MergeTypeTy::X1_Y_X2, MergeTypeTy::Y_X2_X1,
- MergeTypeTy::X2_X1_Y});
+ tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1,
+ MergeTypeT::X2_X1_Y});
}
}
Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain);
@@ -805,96 +869,66 @@ private:
/// merge 'type' and 'offset'.
///
/// The two chains are not modified in the method.
- MergeGainTy computeMergeGain(const Chain *ChainPred, const Chain *ChainSucc,
- const std::vector<Jump *> &Jumps,
- size_t MergeOffset,
- MergeTypeTy MergeType) const {
- auto MergedBlocks = mergeBlocks(ChainPred->blocks(), ChainSucc->blocks(),
- MergeOffset, MergeType);
-
- // Do not allow a merge that does not preserve the original entry block
+ MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc,
+ const std::vector<JumpT *> &Jumps,
+ size_t MergeOffset, MergeTypeT MergeType) const {
+ auto MergedBlocks =
+ mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
+
+ // Do not allow a merge that does not preserve the original entry point
if ((ChainPred->isEntry() || ChainSucc->isEntry()) &&
- !MergedBlocks.getFirstBlock()->isEntry())
- return MergeGainTy();
+ !MergedBlocks.getFirstNode()->isEntry())
+ return MergeGainT();
// The gain for the new chain
- auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->score();
- return MergeGainTy(NewGainScore, MergeOffset, MergeType);
- }
-
- /// Merge two chains of blocks respecting a given merge 'type' and 'offset'.
- ///
- /// If MergeType == 0, then the result is a concatenation of two chains.
- /// Otherwise, the first chain is cut into two sub-chains at the offset,
- /// and merged using all possible ways of concatenating three chains.
- MergedChain mergeBlocks(const std::vector<Block *> &X,
- const std::vector<Block *> &Y, size_t MergeOffset,
- MergeTypeTy MergeType) const {
- // Split the first chain, X, into X1 and X2
- BlockIter BeginX1 = X.begin();
- BlockIter EndX1 = X.begin() + MergeOffset;
- BlockIter BeginX2 = X.begin() + MergeOffset;
- BlockIter EndX2 = X.end();
- BlockIter BeginY = Y.begin();
- BlockIter EndY = Y.end();
-
- // Construct a new chain from the three existing ones
- switch (MergeType) {
- case MergeTypeTy::X_Y:
- return MergedChain(BeginX1, EndX2, BeginY, EndY);
- case MergeTypeTy::X1_Y_X2:
- return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
- case MergeTypeTy::Y_X2_X1:
- return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
- case MergeTypeTy::X2_X1_Y:
- return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
- }
- llvm_unreachable("unexpected chain merge type");
+ auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->Score;
+ return MergeGainT(NewGainScore, MergeOffset, MergeType);
}
/// Merge chain From into chain Into, update the list of active chains,
/// adjacency information, and the corresponding cached values.
- void mergeChains(Chain *Into, Chain *From, size_t MergeOffset,
- MergeTypeTy MergeType) {
+ void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset,
+ MergeTypeT MergeType) {
assert(Into != From && "a chain cannot be merged with itself");
- // Merge the blocks
- MergedChain MergedBlocks =
- mergeBlocks(Into->blocks(), From->blocks(), MergeOffset, MergeType);
- Into->merge(From, MergedBlocks.getBlocks());
+ // Merge the nodes
+ MergedChain MergedNodes =
+ mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
+ Into->merge(From, MergedNodes.getNodes());
+
+ // Merge the edges
Into->mergeEdges(From);
From->clear();
// Update cached ext-tsp score for the new chain
ChainEdge *SelfEdge = Into->getEdge(Into);
if (SelfEdge != nullptr) {
- MergedBlocks = MergedChain(Into->blocks().begin(), Into->blocks().end());
- Into->setScore(extTSPScore(MergedBlocks, SelfEdge->jumps()));
+ MergedNodes = MergedChain(Into->Nodes.begin(), Into->Nodes.end());
+ Into->Score = extTSPScore(MergedNodes, SelfEdge->jumps());
}
- // Remove chain From from the list of active chains
+ // Remove the chain from the list of active chains
llvm::erase_value(HotChains, From);
// Invalidate caches
- for (auto EdgeIter : Into->edges()) {
- EdgeIter.second->invalidateCache();
- }
+ for (auto EdgeIt : Into->Edges)
+ EdgeIt.second->invalidateCache();
}
- /// Concatenate all chains into a final order of blocks.
+ /// Concatenate all chains into the final order.
void concatChains(std::vector<uint64_t> &Order) {
- // Collect chains and calculate some stats for their sorting
- std::vector<Chain *> SortedChains;
- DenseMap<const Chain *, double> ChainDensity;
- for (auto &Chain : AllChains) {
- if (!Chain.blocks().empty()) {
+ // Collect chains and calculate density stats for their sorting
+ std::vector<const ChainT *> SortedChains;
+ DenseMap<const ChainT *, double> ChainDensity;
+ for (ChainT &Chain : AllChains) {
+ if (!Chain.Nodes.empty()) {
SortedChains.push_back(&Chain);
- // Using doubles to avoid overflow of ExecutionCount
+ // Using doubles to avoid overflow of ExecutionCounts
double Size = 0;
double ExecutionCount = 0;
- for (auto *Block : Chain.blocks()) {
- Size += static_cast<double>(Block->Size);
- ExecutionCount += static_cast<double>(Block->ExecutionCount);
+ for (NodeT *Node : Chain.Nodes) {
+ Size += static_cast<double>(Node->Size);
+ ExecutionCount += static_cast<double>(Node->ExecutionCount);
}
assert(Size > 0 && "a chain of zero size");
ChainDensity[&Chain] = ExecutionCount / Size;
@@ -903,24 +937,23 @@ private:
// Sorting chains by density in decreasing order
std::stable_sort(SortedChains.begin(), SortedChains.end(),
- [&](const Chain *C1, const Chain *C2) {
- // Make sure the original entry block is at the
+ [&](const ChainT *L, const ChainT *R) {
+ // Make sure the original entry point is at the
// beginning of the order
- if (C1->isEntry() != C2->isEntry()) {
- return C1->isEntry();
- }
+ if (L->isEntry() != R->isEntry())
+ return L->isEntry();
- const double D1 = ChainDensity[C1];
- const double D2 = ChainDensity[C2];
+ const double DL = ChainDensity[L];
+ const double DR = ChainDensity[R];
// Compare by density and break ties by chain identifiers
- return (D1 != D2) ? (D1 > D2) : (C1->id() < C2->id());
+ return (DL != DR) ? (DL > DR) : (L->Id < R->Id);
});
- // Collect the blocks in the order specified by their chains
+ // Collect the nodes in the order specified by their chains
Order.reserve(NumNodes);
- for (Chain *Chain : SortedChains) {
- for (Block *Block : Chain->blocks()) {
- Order.push_back(Block->Index);
+ for (const ChainT *Chain : SortedChains) {
+ for (NodeT *Node : Chain->Nodes) {
+ Order.push_back(Node->Index);
}
}
}
@@ -935,49 +968,47 @@ private:
/// Predecessors of each node.
std::vector<std::vector<uint64_t>> PredNodes;
- /// All basic blocks.
- std::vector<Block> AllBlocks;
+ /// All nodes (basic blocks) in the graph.
+ std::vector<NodeT> AllNodes;
- /// All jumps between blocks.
- std::vector<Jump> AllJumps;
+ /// All jumps between the nodes.
+ std::vector<JumpT> AllJumps;
- /// All chains of basic blocks.
- std::vector<Chain> AllChains;
+ /// All chains of nodes.
+ std::vector<ChainT> AllChains;
- /// All edges between chains.
+ /// All edges between the chains.
std::vector<ChainEdge> AllEdges;
/// Active chains. The vector gets updated at runtime when chains are merged.
- std::vector<Chain *> HotChains;
+ std::vector<ChainT *> HotChains;
};
} // end of anonymous namespace
-std::vector<uint64_t> llvm::applyExtTspLayout(
- const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) {
- size_t NumNodes = NodeSizes.size();
-
- // Verify correctness of the input data.
+std::vector<uint64_t>
+llvm::applyExtTspLayout(const std::vector<uint64_t> &NodeSizes,
+ const std::vector<uint64_t> &NodeCounts,
+ const std::vector<EdgeCountT> &EdgeCounts) {
+ // Verify correctness of the input data
assert(NodeCounts.size() == NodeSizes.size() && "Incorrect input");
- assert(NumNodes > 2 && "Incorrect input");
+ assert(NodeSizes.size() > 2 && "Incorrect input");
- // Apply the reordering algorithm.
- auto Alg = ExtTSPImpl(NumNodes, NodeSizes, NodeCounts, EdgeCounts);
+ // Apply the reordering algorithm
+ ExtTSPImpl Alg(NodeSizes, NodeCounts, EdgeCounts);
std::vector<uint64_t> Result;
Alg.run(Result);
- // Verify correctness of the output.
+ // Verify correctness of the output
assert(Result.front() == 0 && "Original entry point is not preserved");
- assert(Result.size() == NumNodes && "Incorrect size of reordered layout");
+ assert(Result.size() == NodeSizes.size() && "Incorrect size of layout");
return Result;
}
-double llvm::calcExtTspScore(
- const std::vector<uint64_t> &Order, const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) {
+double llvm::calcExtTspScore(const std::vector<uint64_t> &Order,
+ const std::vector<uint64_t> &NodeSizes,
+ const std::vector<uint64_t> &NodeCounts,
+ const std::vector<EdgeCountT> &EdgeCounts) {
// Estimate addresses of the blocks in memory
std::vector<uint64_t> Addr(NodeSizes.size(), 0);
for (size_t Idx = 1; Idx < Order.size(); Idx++) {
@@ -985,15 +1016,15 @@ double llvm::calcExtTspScore(
}
std::vector<uint64_t> OutDegree(NodeSizes.size(), 0);
for (auto It : EdgeCounts) {
- auto Pred = It.first.first;
+ uint64_t Pred = It.first.first;
OutDegree[Pred]++;
}
// Increase the score for each jump
double Score = 0;
for (auto It : EdgeCounts) {
- auto Pred = It.first.first;
- auto Succ = It.first.second;
+ uint64_t Pred = It.first.first;
+ uint64_t Succ = It.first.second;
uint64_t Count = It.second;
bool IsConditional = OutDegree[Pred] > 1;
Score += ::extTSPScore(Addr[Pred], NodeSizes[Pred], Addr[Succ], Count,
@@ -1002,10 +1033,9 @@ double llvm::calcExtTspScore(
return Score;
}
-double llvm::calcExtTspScore(
- const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) {
+double llvm::calcExtTspScore(const std::vector<uint64_t> &NodeSizes,
+ const std::vector<uint64_t> &NodeCounts,
+ const std::vector<EdgeCountT> &EdgeCounts) {
std::vector<uint64_t> Order(NodeSizes.size());
for (size_t Idx = 0; Idx < NodeSizes.size(); Idx++) {
Order[Idx] = Idx;
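
For readers following the chain-merging changes above: the merge types X_Y, X1_Y_X2, Y_X2_X1 and X2_X1_Y describe how the first chain, split at a given offset into X1 and X2, is interleaved with the second chain Y. Below is a minimal standalone sketch of that concatenation step, using plain std::vector in place of the patch's NodeT lists; it is an illustration, not the code being committed.

#include <cassert>
#include <cstddef>
#include <vector>

enum class MergeType { X_Y, X1_Y_X2, Y_X2_X1, X2_X1_Y };

// Append Src to Dst.
static void append(std::vector<int> &Dst, const std::vector<int> &Src) {
  Dst.insert(Dst.end(), Src.begin(), Src.end());
}

// Split X at Offset into X1/X2 and concatenate with Y in the order given by
// Type, mirroring the four merge orders tried by the chain-merging heuristic.
static std::vector<int> mergeChains(const std::vector<int> &X,
                                    const std::vector<int> &Y, size_t Offset,
                                    MergeType Type) {
  assert(Offset <= X.size() && "split offset out of range");
  std::vector<int> X1(X.begin(), X.begin() + Offset);
  std::vector<int> X2(X.begin() + Offset, X.end());
  std::vector<int> Result;
  switch (Type) {
  case MergeType::X_Y:
    append(Result, X1); append(Result, X2); append(Result, Y); break;
  case MergeType::X1_Y_X2:
    append(Result, X1); append(Result, Y); append(Result, X2); break;
  case MergeType::Y_X2_X1:
    append(Result, Y); append(Result, X2); append(Result, X1); break;
  case MergeType::X2_X1_Y:
    append(Result, X2); append(Result, X1); append(Result, Y); break;
  }
  return Result;
}

For example, with X = {0,1,2,3}, Y = {4,5} and Offset = 2, X1_Y_X2 yields {0,1,4,5,2,3}; the algorithm scores each candidate order with extTSPScore and keeps the most profitable one.
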
diff --git a/llvm/lib/Transforms/Utils/CountVisits.cpp b/llvm/lib/Transforms/Utils/CountVisits.cpp
new file mode 100644
index 000000000000..4faded8fc656
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/CountVisits.cpp
@@ -0,0 +1,25 @@
+//===- CountVisits.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CountVisits.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/PassManager.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "count-visits"
+
+STATISTIC(MaxVisited, "Max number of times we visited a function");
+
+PreservedAnalyses CountVisitsPass::run(Function &F, FunctionAnalysisManager &) {
+ uint32_t Count = Counts[F.getName()] + 1;
+ Counts[F.getName()] = Count;
+ if (Count > MaxVisited)
+ MaxVisited = Count;
+ return PreservedAnalyses::all();
+}
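
The new pass above is small enough to restate: it bumps a per-function counter each time the pipeline reaches a function and records the maximum. A toy standalone model of that bookkeeping follows; VisitCounter and the other names are illustrative only and not part of the patch.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Toy model of the visit counter: bump a per-function count and track the
// maximum, which the real pass reports through a STATISTIC.
struct VisitCounter {
  std::map<std::string, uint32_t> Counts;
  uint32_t MaxVisited = 0;

  void visit(const std::string &FuncName) {
    uint32_t Count = ++Counts[FuncName];
    if (Count > MaxVisited)
      MaxVisited = Count;
  }
};

int main() {
  VisitCounter C;
  for (const char *F : {"foo", "bar", "foo"})
    C.visit(F);
  std::cout << "max visits: " << C.MaxVisited << "\n"; // prints 2
  return 0;
}
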
diff --git a/llvm/lib/Transforms/Utils/CtorUtils.cpp b/llvm/lib/Transforms/Utils/CtorUtils.cpp
index c997f39508e3..e07c92df2265 100644
--- a/llvm/lib/Transforms/Utils/CtorUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CtorUtils.cpp
@@ -48,7 +48,7 @@ static void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemov
GlobalVariable *NGV =
new GlobalVariable(CA->getType(), GCL->isConstant(), GCL->getLinkage(),
CA, "", GCL->getThreadLocalMode());
- GCL->getParent()->getGlobalList().insert(GCL->getIterator(), NGV);
+ GCL->getParent()->insertGlobalVariable(GCL->getIterator(), NGV);
NGV->takeName(GCL);
// Nuke the old list, replacing any uses with the new one.
diff --git a/llvm/lib/Transforms/Utils/Debugify.cpp b/llvm/lib/Transforms/Utils/Debugify.cpp
index 989473693a0b..93cad0888a56 100644
--- a/llvm/lib/Transforms/Utils/Debugify.cpp
+++ b/llvm/lib/Transforms/Utils/Debugify.cpp
@@ -979,7 +979,9 @@ PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) {
collectDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass,
"ModuleDebugify (original debuginfo)",
NameOfWrappedPass);
- return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
ModulePass *createCheckDebugifyModulePass(
@@ -1027,45 +1029,58 @@ static bool isIgnoredPass(StringRef PassID) {
}
void DebugifyEachInstrumentation::registerCallbacks(
- PassInstrumentationCallbacks &PIC) {
- PIC.registerBeforeNonSkippedPassCallback([this](StringRef P, Any IR) {
- if (isIgnoredPass(P))
- return;
- if (const auto **F = any_cast<const Function *>(&IR))
- applyDebugify(*const_cast<Function *>(*F),
- Mode, DebugInfoBeforePass, P);
- else if (const auto **M = any_cast<const Module *>(&IR))
- applyDebugify(*const_cast<Module *>(*M),
- Mode, DebugInfoBeforePass, P);
- });
- PIC.registerAfterPassCallback([this](StringRef P, Any IR,
- const PreservedAnalyses &PassPA) {
+ PassInstrumentationCallbacks &PIC, ModuleAnalysisManager &MAM) {
+ PIC.registerBeforeNonSkippedPassCallback([this, &MAM](StringRef P, Any IR) {
if (isIgnoredPass(P))
return;
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
if (const auto **CF = any_cast<const Function *>(&IR)) {
- auto &F = *const_cast<Function *>(*CF);
- Module &M = *F.getParent();
- auto It = F.getIterator();
- if (Mode == DebugifyMode::SyntheticDebugInfo)
- checkDebugifyMetadata(M, make_range(It, std::next(It)), P,
- "CheckFunctionDebugify", /*Strip=*/true, DIStatsMap);
- else
- checkDebugInfoMetadata(
- M, make_range(It, std::next(It)), *DebugInfoBeforePass,
- "CheckModuleDebugify (original debuginfo)",
- P, OrigDIVerifyBugsReportFilePath);
+ Function &F = *const_cast<Function *>(*CF);
+ applyDebugify(F, Mode, DebugInfoBeforePass, P);
+ MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent())
+ .getManager()
+ .invalidate(F, PA);
} else if (const auto **CM = any_cast<const Module *>(&IR)) {
- auto &M = *const_cast<Module *>(*CM);
- if (Mode == DebugifyMode::SyntheticDebugInfo)
- checkDebugifyMetadata(M, M.functions(), P, "CheckModuleDebugify",
- /*Strip=*/true, DIStatsMap);
- else
- checkDebugInfoMetadata(
- M, M.functions(), *DebugInfoBeforePass,
- "CheckModuleDebugify (original debuginfo)",
- P, OrigDIVerifyBugsReportFilePath);
+ Module &M = *const_cast<Module *>(*CM);
+ applyDebugify(M, Mode, DebugInfoBeforePass, P);
+ MAM.invalidate(M, PA);
}
});
+ PIC.registerAfterPassCallback(
+ [this, &MAM](StringRef P, Any IR, const PreservedAnalyses &PassPA) {
+ if (isIgnoredPass(P))
+ return;
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ if (const auto **CF = any_cast<const Function *>(&IR)) {
+ auto &F = *const_cast<Function *>(*CF);
+ Module &M = *F.getParent();
+ auto It = F.getIterator();
+ if (Mode == DebugifyMode::SyntheticDebugInfo)
+ checkDebugifyMetadata(M, make_range(It, std::next(It)), P,
+ "CheckFunctionDebugify", /*Strip=*/true,
+ DIStatsMap);
+ else
+ checkDebugInfoMetadata(M, make_range(It, std::next(It)),
+ *DebugInfoBeforePass,
+ "CheckModuleDebugify (original debuginfo)",
+ P, OrigDIVerifyBugsReportFilePath);
+ MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent())
+ .getManager()
+ .invalidate(F, PA);
+ } else if (const auto **CM = any_cast<const Module *>(&IR)) {
+ Module &M = *const_cast<Module *>(*CM);
+ if (Mode == DebugifyMode::SyntheticDebugInfo)
+ checkDebugifyMetadata(M, M.functions(), P, "CheckModuleDebugify",
+ /*Strip=*/true, DIStatsMap);
+ else
+ checkDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass,
+ "CheckModuleDebugify (original debuginfo)",
+ P, OrigDIVerifyBugsReportFilePath);
+ MAM.invalidate(M, PA);
+ }
+ });
}
char DebugifyModulePass::ID = 0;
diff --git a/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp b/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp
index 086ea088dc5e..c894afee68a2 100644
--- a/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp
+++ b/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp
@@ -74,6 +74,7 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads,
V = new LoadInst(I.getType(), Slot, I.getName() + ".reload",
VolatileLoads,
PN->getIncomingBlock(i)->getTerminator());
+ Loads[PN->getIncomingBlock(i)] = V;
}
PN->setIncomingValue(i, V);
}
diff --git a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
index 53af1b1969c2..d424ebbef99d 100644
--- a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
+++ b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
@@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/EntryExitInstrumenter.h"
-#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Dominators.h"
@@ -16,9 +15,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/Transforms/Utils.h"
+#include "llvm/TargetParser/Triple.h"
using namespace llvm;
@@ -83,6 +80,13 @@ static void insertCall(Function &CurFn, StringRef Func,
}
static bool runOnFunction(Function &F, bool PostInlining) {
+ // The asm in a naked function may reasonably expect the argument registers
+ // and the return address register (if present) to be live. An inserted
+ // function call will clobber these registers. Simply skip naked functions for
+ // all targets.
+ if (F.hasFnAttribute(Attribute::Naked))
+ return false;
+
StringRef EntryAttr = PostInlining ? "instrument-function-entry-inlined"
: "instrument-function-entry";
@@ -145,8 +149,8 @@ void llvm::EntryExitInstrumenterPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<llvm::EntryExitInstrumenterPass> *>(this)
->printPipeline(OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (PostInlining)
OS << "post-inline";
- OS << ">";
+ OS << '>';
}
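
The naked-function bail-out added above exists because this instrumentation plants hook calls at function entry and before returns. The following hand-written illustration shows the intended shape of instrumented code; profile_enter/profile_exit are stand-ins, not the hooks the pass actually emits.

#include <cstdio>

// Stand-in profiling hooks; the real pass calls whatever functions the
// "instrument-function-entry"/"-exit" attributes name (e.g. mcount).
extern "C" void profile_enter(const char *Fn) { std::printf("enter %s\n", Fn); }
extern "C" void profile_exit(const char *Fn) { std::printf("exit %s\n", Fn); }

// Conceptually, instrumentation turns a function body into this: a hook call
// on entry and another one before every return.
int instrumented_add(int A, int B) {
  profile_enter("instrumented_add");
  int Result = A + B;
  profile_exit("instrumented_add");
  return Result;
}

// A naked function has no prologue/epilogue to hide such calls in, and an
// inserted call would clobber argument/return-address registers, hence the
// early return for Attribute::Naked.
int main() { return instrumented_add(2, 3) == 5 ? 0 : 1; }
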
diff --git a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
index 91053338df5f..88c838685bca 100644
--- a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
+++ b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
@@ -12,9 +12,9 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/EscapeEnumerator.h"
-#include "llvm/ADT/Triple.h"
-#include "llvm/Analysis/EHPersonalities.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Module.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
diff --git a/llvm/lib/Transforms/Utils/Evaluator.cpp b/llvm/lib/Transforms/Utils/Evaluator.cpp
index dc58bebd724b..23c1ca366a44 100644
--- a/llvm/lib/Transforms/Utils/Evaluator.cpp
+++ b/llvm/lib/Transforms/Utils/Evaluator.cpp
@@ -121,7 +121,7 @@ isSimpleEnoughValueToCommit(Constant *C,
}
void Evaluator::MutableValue::clear() {
- if (auto *Agg = Val.dyn_cast<MutableAggregate *>())
+ if (auto *Agg = dyn_cast_if_present<MutableAggregate *>(Val))
delete Agg;
Val = nullptr;
}
@@ -130,7 +130,7 @@ Constant *Evaluator::MutableValue::read(Type *Ty, APInt Offset,
const DataLayout &DL) const {
TypeSize TySize = DL.getTypeStoreSize(Ty);
const MutableValue *V = this;
- while (const auto *Agg = V->Val.dyn_cast<MutableAggregate *>()) {
+ while (const auto *Agg = dyn_cast_if_present<MutableAggregate *>(V->Val)) {
Type *AggTy = Agg->Ty;
std::optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset);
if (!Index || Index->uge(Agg->Elements.size()) ||
@@ -140,11 +140,11 @@ Constant *Evaluator::MutableValue::read(Type *Ty, APInt Offset,
V = &Agg->Elements[Index->getZExtValue()];
}
- return ConstantFoldLoadFromConst(V->Val.get<Constant *>(), Ty, Offset, DL);
+ return ConstantFoldLoadFromConst(cast<Constant *>(V->Val), Ty, Offset, DL);
}
bool Evaluator::MutableValue::makeMutable() {
- Constant *C = Val.get<Constant *>();
+ Constant *C = cast<Constant *>(Val);
Type *Ty = C->getType();
unsigned NumElements;
if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
@@ -171,10 +171,10 @@ bool Evaluator::MutableValue::write(Constant *V, APInt Offset,
MutableValue *MV = this;
while (Offset != 0 ||
!CastInst::isBitOrNoopPointerCastable(Ty, MV->getType(), DL)) {
- if (MV->Val.is<Constant *>() && !MV->makeMutable())
+ if (isa<Constant *>(MV->Val) && !MV->makeMutable())
return false;
- MutableAggregate *Agg = MV->Val.get<MutableAggregate *>();
+ MutableAggregate *Agg = cast<MutableAggregate *>(MV->Val);
Type *AggTy = Agg->Ty;
std::optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset);
if (!Index || Index->uge(Agg->Elements.size()) ||
@@ -413,16 +413,28 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB,
}
Constant *Val = getVal(MSI->getValue());
- APInt Len = LenC->getValue();
- while (Len != 0) {
- Constant *DestVal = ComputeLoadResult(GV, Val->getType(), Offset);
- if (DestVal != Val) {
- LLVM_DEBUG(dbgs() << "Memset is not a no-op at offset "
- << Offset << " of " << *GV << ".\n");
+ // Avoid the byte-per-byte scan if we're memsetting a zeroinitializer
+ // to zero.
+ if (!Val->isNullValue() || MutatedMemory.contains(GV) ||
+ !GV->hasDefinitiveInitializer() ||
+ !GV->getInitializer()->isNullValue()) {
+ APInt Len = LenC->getValue();
+ if (Len.ugt(64 * 1024)) {
+ LLVM_DEBUG(dbgs() << "Not evaluating large memset of size "
+ << Len << "\n");
return false;
}
- ++Offset;
- --Len;
+
+ while (Len != 0) {
+ Constant *DestVal = ComputeLoadResult(GV, Val->getType(), Offset);
+ if (DestVal != Val) {
+ LLVM_DEBUG(dbgs() << "Memset is not a no-op at offset "
+ << Offset << " of " << *GV << ".\n");
+ return false;
+ }
+ ++Offset;
+ --Len;
+ }
}
LLVM_DEBUG(dbgs() << "Ignoring no-op memset.\n");
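
The new early exit above avoids a byte-by-byte scan when the memset writes zeros over a global whose initializer is already all zeros and that has not been mutated yet, and it refuses to simulate very large memsets. Here is a compact standalone restatement of that decision logic, under the simplifying assumption that the relevant facts are already available as booleans.

#include <cstdint>

// Decide whether a memset can be treated as a no-op without scanning it
// byte by byte, following the structure of the check added above.
enum class MemsetDecision { NoOp, NeedsByteScan, TooLargeToEvaluate };

MemsetDecision classifyMemset(bool StoresZero, bool GlobalAlreadyMutated,
                              bool HasDefinitiveInitializer,
                              bool InitializerIsAllZero, uint64_t Length) {
  // Fast path: zero over an untouched, definitively zero-initialized global.
  if (StoresZero && !GlobalAlreadyMutated && HasDefinitiveInitializer &&
      InitializerIsAllZero)
    return MemsetDecision::NoOp;
  // Otherwise a byte scan would be needed; refuse if it would be too costly.
  if (Length > 64 * 1024)
    return MemsetDecision::TooLargeToEvaluate;
  return MemsetDecision::NeedsByteScan;
}

int main() {
  // Zero memset over a pristine zero-initialized global: nothing to do.
  return classifyMemset(true, false, true, true, 1 << 20) ==
                 MemsetDecision::NoOp
             ? 0
             : 1;
}
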
diff --git a/llvm/lib/Transforms/Utils/FlattenCFG.cpp b/llvm/lib/Transforms/Utils/FlattenCFG.cpp
index 2fb2ab82e41a..1925b91c4da7 100644
--- a/llvm/lib/Transforms/Utils/FlattenCFG.cpp
+++ b/llvm/lib/Transforms/Utils/FlattenCFG.cpp
@@ -487,17 +487,10 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) {
BasicBlock::iterator SaveInsertPt = Builder.GetInsertPoint();
Builder.SetInsertPoint(PBI);
if (InvertCond2) {
- // If this is a "cmp" instruction, only used for branching (and nowhere
- // else), then we can simply invert the predicate.
- auto Cmp2 = dyn_cast<CmpInst>(CInst2);
- if (Cmp2 && Cmp2->hasOneUse())
- Cmp2->setPredicate(Cmp2->getInversePredicate());
- else
- CInst2 = cast<Instruction>(Builder.CreateNot(CInst2));
- PBI->swapSuccessors();
+ InvertBranch(PBI, Builder);
}
- Value *NC = Builder.CreateBinOp(CombineOp, CInst1, CInst2);
- PBI->replaceUsesOfWith(CInst2, NC);
+ Value *NC = Builder.CreateBinOp(CombineOp, CInst1, PBI->getCondition());
+ PBI->replaceUsesOfWith(PBI->getCondition(), NC);
Builder.SetInsertPoint(SaveInsertBB, SaveInsertPt);
// Handle PHI node to replace its predecessors to FirstEntryBlock.
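
The InvertBranch helper used above relies on a simple identity: negating a branch condition and swapping the two successors leaves behavior unchanged. A small hand-written analogue at the source level:

#include <cassert>

// Original branch: condition C selects between Then and Else.
static int branchOriginal(bool C) { return C ? /*Then*/ 1 : /*Else*/ 2; }

// Inverted branch: negate the condition and swap the successors; the
// observable behavior is identical, which is what InvertBranch relies on.
static int branchInverted(bool C) { return !C ? /*Else*/ 2 : /*Then*/ 1; }

int main() {
  for (bool C : {false, true})
    assert(branchOriginal(C) == branchInverted(C));
  return 0;
}
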
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
index 3fa61ec68cd3..8daeb92130ba 100644
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -157,16 +157,31 @@ int FunctionComparator::cmpAttrs(const AttributeList L,
return 0;
}
-int FunctionComparator::cmpRangeMetadata(const MDNode *L,
- const MDNode *R) const {
+int FunctionComparator::cmpMetadata(const Metadata *L,
+ const Metadata *R) const {
+ // TODO: the following routine coerces the metadata contents into constants
+ // before comparison.
+ // It ignores any other cases, so metadata nodes may be considered equal
+ // even when they actually differ.
+ // To be precise here, we should compare the metadata nodes structurally.
+ auto *CL = dyn_cast<ConstantAsMetadata>(L);
+ auto *CR = dyn_cast<ConstantAsMetadata>(R);
+ if (CL == CR)
+ return 0;
+ if (!CL)
+ return -1;
+ if (!CR)
+ return 1;
+ return cmpConstants(CL->getValue(), CR->getValue());
+}
+
+int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
if (L == R)
return 0;
if (!L)
return -1;
if (!R)
return 1;
- // Range metadata is a sequence of numbers. Make sure they are the same
- // sequence.
// TODO: Note that as this is metadata, it is possible to drop and/or merge
// this data when considering functions to merge. Thus this comparison would
// return 0 (i.e. equivalent), but merging would become more complicated
@@ -175,10 +190,30 @@ int FunctionComparator::cmpRangeMetadata(const MDNode *L,
// function semantically.
if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
return Res;
- for (size_t I = 0; I < L->getNumOperands(); ++I) {
- ConstantInt *LLow = mdconst::extract<ConstantInt>(L->getOperand(I));
- ConstantInt *RLow = mdconst::extract<ConstantInt>(R->getOperand(I));
- if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue()))
+ for (size_t I = 0; I < L->getNumOperands(); ++I)
+ if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I)))
+ return Res;
+ return 0;
+}
+
+int FunctionComparator::cmpInstMetadata(Instruction const *L,
+ Instruction const *R) const {
+ /// This metadata affects other optimization passes by encoding assertions
+ /// or constraints.
+ /// Values that carry different expectations should be considered different.
+ SmallVector<std::pair<unsigned, MDNode *>> MDL, MDR;
+ L->getAllMetadataOtherThanDebugLoc(MDL);
+ R->getAllMetadataOtherThanDebugLoc(MDR);
+ if (MDL.size() > MDR.size())
+ return 1;
+ else if (MDL.size() < MDR.size())
+ return -1;
+ for (size_t I = 0, N = MDL.size(); I < N; ++I) {
+ auto const [KeyL, ML] = MDL[I];
+ auto const [KeyR, MR] = MDR[I];
+ if (int Res = cmpNumbers(KeyL, KeyR))
+ return Res;
+ if (int Res = cmpMDNode(ML, MR))
return Res;
}
return 0;
@@ -586,9 +621,7 @@ int FunctionComparator::cmpOperations(const Instruction *L,
if (int Res = cmpNumbers(LI->getSyncScopeID(),
cast<LoadInst>(R)->getSyncScopeID()))
return Res;
- return cmpRangeMetadata(
- LI->getMetadata(LLVMContext::MD_range),
- cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
+ return cmpInstMetadata(L, R);
}
if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
if (int Res =
@@ -616,8 +649,8 @@ int FunctionComparator::cmpOperations(const Instruction *L,
if (int Res = cmpNumbers(CI->getTailCallKind(),
cast<CallInst>(R)->getTailCallKind()))
return Res;
- return cmpRangeMetadata(L->getMetadata(LLVMContext::MD_range),
- R->getMetadata(LLVMContext::MD_range));
+ return cmpMDNode(L->getMetadata(LLVMContext::MD_range),
+ R->getMetadata(LLVMContext::MD_range));
}
if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
ArrayRef<unsigned> LIndices = IVI->getIndices();
@@ -715,8 +748,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
// When we have target data, we can reduce the GEP down to the value in bytes
// added to the address.
const DataLayout &DL = FnL->getParent()->getDataLayout();
- unsigned BitWidth = DL.getPointerSizeInBits(ASL);
- APInt OffsetL(BitWidth, 0), OffsetR(BitWidth, 0);
+ unsigned OffsetBitWidth = DL.getIndexSizeInBits(ASL);
+ APInt OffsetL(OffsetBitWidth, 0), OffsetR(OffsetBitWidth, 0);
if (GEPL->accumulateConstantOffset(DL, OffsetL) &&
GEPR->accumulateConstantOffset(DL, OffsetR))
return cmpAPInts(OffsetL, OffsetR);
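
The new cmpInstMetadata imposes a deterministic order on the metadata attached to two instructions: first by entry count, then pairwise by metadata kind, then by the constant-coerced node contents. Below is a standalone sketch of the same three-way comparison over simplified (kind, payload) pairs; the types are stand-ins, not LLVM's.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Returns -1/0/+1, like the comparator helpers in this file.
static int cmpNumbers(uint64_t L, uint64_t R) {
  return L < R ? -1 : (L > R ? 1 : 0);
}

// Simplified stand-in for an instruction's non-debug metadata: a list of
// (metadata kind, payload) pairs.
static int cmpInstMetadata(const std::vector<std::pair<unsigned, uint64_t>> &L,
                           const std::vector<std::pair<unsigned, uint64_t>> &R) {
  if (L.size() > R.size())
    return 1;
  if (L.size() < R.size())
    return -1;
  for (size_t I = 0, N = L.size(); I < N; ++I) {
    // Compare the metadata kind first, then the payload, mirroring the
    // KeyL/KeyR comparison and the cmpMDNode step in the real comparator.
    if (int Res = cmpNumbers(L[I].first, R[I].first))
      return Res;
    if (int Res = cmpNumbers(L[I].second, R[I].second))
      return Res;
  }
  return 0;
}

int main() {
  std::vector<std::pair<unsigned, uint64_t>> A{{1, 10}}, B{{1, 20}};
  return cmpInstMetadata(A, B) < 0 ? 0 : 1; // A orders before B
}
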
diff --git a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp
index 55bcb6f3b121..dab0be3a9fde 100644
--- a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp
+++ b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp
@@ -19,7 +19,6 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/InstIterator.h"
-#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace llvm;
@@ -40,7 +39,7 @@ STATISTIC(NumCompUsedAdded,
/// CI (other than void) need to be widened to a VectorType of VF
/// lanes.
static void addVariantDeclaration(CallInst &CI, const ElementCount &VF,
- const StringRef VFName) {
+ bool Predicate, const StringRef VFName) {
Module *M = CI.getModule();
// Add function declaration.
@@ -50,6 +49,8 @@ static void addVariantDeclaration(CallInst &CI, const ElementCount &VF,
Tys.push_back(ToVectorTy(ArgOperand->getType(), VF));
assert(!CI.getFunctionType()->isVarArg() &&
"VarArg functions are not supported.");
+ if (Predicate)
+ Tys.push_back(ToVectorTy(Type::getInt1Ty(RetTy->getContext()), VF));
FunctionType *FTy = FunctionType::get(RetTy, Tys, /*isVarArg=*/false);
Function *VectorF =
Function::Create(FTy, Function::ExternalLinkage, VFName, M);
@@ -89,19 +90,19 @@ static void addMappingsFromTLI(const TargetLibraryInfo &TLI, CallInst &CI) {
const SetVector<StringRef> OriginalSetOfMappings(Mappings.begin(),
Mappings.end());
- auto AddVariantDecl = [&](const ElementCount &VF) {
+ auto AddVariantDecl = [&](const ElementCount &VF, bool Predicate) {
const std::string TLIName =
- std::string(TLI.getVectorizedFunction(ScalarName, VF));
+ std::string(TLI.getVectorizedFunction(ScalarName, VF, Predicate));
if (!TLIName.empty()) {
- std::string MangledName =
- VFABI::mangleTLIVectorName(TLIName, ScalarName, CI.arg_size(), VF);
+ std::string MangledName = VFABI::mangleTLIVectorName(
+ TLIName, ScalarName, CI.arg_size(), VF, Predicate);
if (!OriginalSetOfMappings.count(MangledName)) {
Mappings.push_back(MangledName);
++NumCallInjected;
}
Function *VariantF = M->getFunction(TLIName);
if (!VariantF)
- addVariantDeclaration(CI, VF, TLIName);
+ addVariantDeclaration(CI, VF, Predicate, TLIName);
}
};
@@ -109,13 +110,15 @@ static void addMappingsFromTLI(const TargetLibraryInfo &TLI, CallInst &CI) {
ElementCount WidestFixedVF, WidestScalableVF;
TLI.getWidestVF(ScalarName, WidestFixedVF, WidestScalableVF);
- for (ElementCount VF = ElementCount::getFixed(2);
- ElementCount::isKnownLE(VF, WidestFixedVF); VF *= 2)
- AddVariantDecl(VF);
+ for (bool Predicated : {false, true}) {
+ for (ElementCount VF = ElementCount::getFixed(2);
+ ElementCount::isKnownLE(VF, WidestFixedVF); VF *= 2)
+ AddVariantDecl(VF, Predicated);
- // TODO: Add scalable variants once we're able to test them.
- assert(WidestScalableVF.isZero() &&
- "Scalable vector mappings not yet supported");
+ for (ElementCount VF = ElementCount::getScalable(2);
+ ElementCount::isKnownLE(VF, WidestScalableVF); VF *= 2)
+ AddVariantDecl(VF, Predicated);
+ }
VFABI::setVectorVariantNames(&CI, Mappings);
}
@@ -138,39 +141,3 @@ PreservedAnalyses InjectTLIMappings::run(Function &F,
// Even if the pass adds IR attributes, the analyses are preserved.
return PreservedAnalyses::all();
}
-
-////////////////////////////////////////////////////////////////////////////////
-// Legacy PM Implementation.
-////////////////////////////////////////////////////////////////////////////////
-bool InjectTLIMappingsLegacy::runOnFunction(Function &F) {
- const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- return runImpl(TLI, F);
-}
-
-void InjectTLIMappingsLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.setPreservesCFG();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addPreserved<TargetLibraryInfoWrapperPass>();
- AU.addPreserved<ScalarEvolutionWrapperPass>();
- AU.addPreserved<AAResultsWrapperPass>();
- AU.addPreserved<LoopAccessLegacyAnalysis>();
- AU.addPreserved<DemandedBitsWrapperPass>();
- AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Legacy Pass manager initialization
-////////////////////////////////////////////////////////////////////////////////
-char InjectTLIMappingsLegacy::ID = 0;
-
-INITIALIZE_PASS_BEGIN(InjectTLIMappingsLegacy, DEBUG_TYPE,
- "Inject TLI Mappings", false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(InjectTLIMappingsLegacy, DEBUG_TYPE, "Inject TLI Mappings",
- false, false)
-
-FunctionPass *llvm::createInjectTLIMappingsLegacyPass() {
- return new InjectTLIMappingsLegacy();
-}
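
With the legacy pass removed, the interesting functional change in this file is the query loop above: it now asks the TLI for masked (predicated) as well as unpredicated variants, and for scalable as well as fixed vectorization factors. A plain sketch of how that enumeration expands, with made-up widest-VF limits:

#include <iostream>

// Enumerate the (vectorization factor, predicated) combinations that the
// updated loop queries, for illustrative widest-VF limits of 8 (fixed) and
// 4 (scalable). The real code asks TLI.getVectorizedFunction for each pair.
int main() {
  const unsigned WidestFixedVF = 8, WidestScalableVF = 4;
  for (bool Predicated : {false, true}) {
    for (unsigned VF = 2; VF <= WidestFixedVF; VF *= 2)
      std::cout << "fixed VF=" << VF << (Predicated ? " masked" : "") << "\n";
    for (unsigned VF = 2; VF <= WidestScalableVF; VF *= 2)
      std::cout << "scalable VF=vscale x " << VF
                << (Predicated ? " masked" : "") << "\n";
  }
  return 0;
}
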
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index 399c9a43793f..f7b93fc8fd06 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -23,7 +23,6 @@
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CaptureTracking.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/MemoryProfileInfo.h"
#include "llvm/Analysis/ObjCARCAnalysisUtils.h"
@@ -42,6 +41,7 @@
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
@@ -99,10 +99,6 @@ PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining",
cl::init(false), cl::Hidden,
cl::desc("Convert align attributes to assumptions during inlining."));
-static cl::opt<bool> UpdateReturnAttributes(
- "update-return-attrs", cl::init(true), cl::Hidden,
- cl::desc("Update return attributes on calls within inlined body"));
-
static cl::opt<unsigned> InlinerAttributeWindow(
"max-inst-checked-for-throw-during-inlining", cl::Hidden,
cl::desc("the maximum number of instructions analyzed for may throw during "
@@ -879,9 +875,6 @@ static void propagateMemProfHelper(const CallBase *OrigCall,
// inlined callee's callsite metadata with that of the inlined call,
// and moving the subset of any memprof contexts to the inlined callee
// allocations if they match the new inlined call stack.
-// FIXME: Replace memprof metadata with function attribute if all MIB end up
-// having the same behavior. Do other context trimming/merging optimizations
-// too.
static void
propagateMemProfMetadata(Function *Callee, CallBase &CB,
bool ContainsMemProfMetadata,
@@ -1368,9 +1361,6 @@ static AttrBuilder IdentifyValidAttributes(CallBase &CB) {
}
static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) {
- if (!UpdateReturnAttributes)
- return;
-
AttrBuilder Valid = IdentifyValidAttributes(CB);
if (!Valid.hasAttributes())
return;
@@ -1460,84 +1450,10 @@ static void AddAlignmentAssumptions(CallBase &CB, InlineFunctionInfo &IFI) {
}
}
-/// Once we have cloned code over from a callee into the caller,
-/// update the specified callgraph to reflect the changes we made.
-/// Note that it's possible that not all code was copied over, so only
-/// some edges of the callgraph may remain.
-static void UpdateCallGraphAfterInlining(CallBase &CB,
- Function::iterator FirstNewBlock,
- ValueToValueMapTy &VMap,
- InlineFunctionInfo &IFI) {
- CallGraph &CG = *IFI.CG;
- const Function *Caller = CB.getCaller();
- const Function *Callee = CB.getCalledFunction();
- CallGraphNode *CalleeNode = CG[Callee];
- CallGraphNode *CallerNode = CG[Caller];
-
- // Since we inlined some uninlined call sites in the callee into the caller,
- // add edges from the caller to all of the callees of the callee.
- CallGraphNode::iterator I = CalleeNode->begin(), E = CalleeNode->end();
-
- // Consider the case where CalleeNode == CallerNode.
- CallGraphNode::CalledFunctionsVector CallCache;
- if (CalleeNode == CallerNode) {
- CallCache.assign(I, E);
- I = CallCache.begin();
- E = CallCache.end();
- }
-
- for (; I != E; ++I) {
- // Skip 'refererence' call records.
- if (!I->first)
- continue;
-
- const Value *OrigCall = *I->first;
-
- ValueToValueMapTy::iterator VMI = VMap.find(OrigCall);
- // Only copy the edge if the call was inlined!
- if (VMI == VMap.end() || VMI->second == nullptr)
- continue;
-
- // If the call was inlined, but then constant folded, there is no edge to
- // add. Check for this case.
- auto *NewCall = dyn_cast<CallBase>(VMI->second);
- if (!NewCall)
- continue;
-
- // We do not treat intrinsic calls like real function calls because we
- // expect them to become inline code; do not add an edge for an intrinsic.
- if (NewCall->getCalledFunction() &&
- NewCall->getCalledFunction()->isIntrinsic())
- continue;
-
- // Remember that this call site got inlined for the client of
- // InlineFunction.
- IFI.InlinedCalls.push_back(NewCall);
-
- // It's possible that inlining the callsite will cause it to go from an
- // indirect to a direct call by resolving a function pointer. If this
- // happens, set the callee of the new call site to a more precise
- // destination. This can also happen if the call graph node of the caller
- // was just unnecessarily imprecise.
- if (!I->second->getFunction())
- if (Function *F = NewCall->getCalledFunction()) {
- // Indirect call site resolved to direct call.
- CallerNode->addCalledFunction(NewCall, CG[F]);
-
- continue;
- }
-
- CallerNode->addCalledFunction(NewCall, I->second);
- }
-
- // Update the call graph by deleting the edge from Callee to Caller. We must
- // do this after the loop above in case Caller and Callee are the same.
- CallerNode->removeCallEdgeFor(*cast<CallBase>(&CB));
-}
-
static void HandleByValArgumentInit(Type *ByValType, Value *Dst, Value *Src,
Module *M, BasicBlock *InsertBlock,
- InlineFunctionInfo &IFI) {
+ InlineFunctionInfo &IFI,
+ Function *CalledFunc) {
IRBuilder<> Builder(InsertBlock, InsertBlock->begin());
Value *Size =
@@ -1546,8 +1462,15 @@ static void HandleByValArgumentInit(Type *ByValType, Value *Dst, Value *Src,
// Always generate a memcpy of alignment 1 here because we don't know
// the alignment of the src pointer. Other optimizations can infer
// better alignment.
- Builder.CreateMemCpy(Dst, /*DstAlign*/ Align(1), Src,
- /*SrcAlign*/ Align(1), Size);
+ CallInst *CI = Builder.CreateMemCpy(Dst, /*DstAlign*/ Align(1), Src,
+ /*SrcAlign*/ Align(1), Size);
+
+ // The verifier requires that all calls of debug-info-bearing functions
+ // from debug-info-bearing functions have a debug location (for inlining
+ // purposes). Assign a dummy location to satisfy the constraint.
+ if (!CI->getDebugLoc() && InsertBlock->getParent()->getSubprogram())
+ if (DISubprogram *SP = CalledFunc->getSubprogram())
+ CI->setDebugLoc(DILocation::get(SP->getContext(), 0, 0, SP));
}
/// When inlining a call site that has a byval argument,
@@ -1557,8 +1480,6 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg,
const Function *CalledFunc,
InlineFunctionInfo &IFI,
MaybeAlign ByValAlignment) {
- assert(cast<PointerType>(Arg->getType())
- ->isOpaqueOrPointeeTypeMatches(ByValType));
Function *Caller = TheCall->getFunction();
const DataLayout &DL = Caller->getParent()->getDataLayout();
@@ -1710,6 +1631,12 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI,
if (allocaWouldBeStaticInEntry(AI))
continue;
+ // Do not force a debug loc for pseudo probes, since they do not need to
+ // be debuggable. They are also expected to have a zero/null dwarf
+ // discriminator at this point, which forcing a location could violate.
+ if (isa<PseudoProbeInst>(BI))
+ continue;
+
BI->setDebugLoc(TheCallDL);
}
@@ -2242,7 +2169,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// Inject byval arguments initialization.
for (ByValInit &Init : ByValInits)
HandleByValArgumentInit(Init.Ty, Init.Dst, Init.Src, Caller->getParent(),
- &*FirstNewBlock, IFI);
+ &*FirstNewBlock, IFI, CalledFunc);
std::optional<OperandBundleUse> ParentDeopt =
CB.getOperandBundle(LLVMContext::OB_deopt);
@@ -2292,10 +2219,6 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
}
}
- // Update the callgraph if requested.
- if (IFI.CG)
- UpdateCallGraphAfterInlining(CB, FirstNewBlock, VMap, IFI);
-
// For 'nodebug' functions, the associated DISubprogram is always null.
// Conservatively avoid propagating the callsite debug location to
// instructions inlined from a function whose DISubprogram is not null.
@@ -2333,7 +2256,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
for (BasicBlock &NewBlock :
make_range(FirstNewBlock->getIterator(), Caller->end()))
for (Instruction &I : NewBlock)
- if (auto *II = dyn_cast<CondGuardInst>(&I))
+ if (auto *II = dyn_cast<AssumeInst>(&I))
IFI.GetAssumptionCache(*Caller).registerAssumption(II);
}
@@ -2701,7 +2624,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// call graph updates weren't requested, as those provide value handle based
// tracking of inlined call sites instead. Calls to intrinsics are not
// collected because they are not inlineable.
- if (InlinedFunctionInfo.ContainsCalls && !IFI.CG) {
+ if (InlinedFunctionInfo.ContainsCalls) {
// Otherwise just collect the raw call sites that were inlined.
for (BasicBlock &NewBB :
make_range(FirstNewBlock->getIterator(), Caller->end()))
@@ -2734,7 +2657,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
if (!CB.use_empty()) {
ReturnInst *R = Returns[0];
if (&CB == R->getReturnValue())
- CB.replaceAllUsesWith(UndefValue::get(CB.getType()));
+ CB.replaceAllUsesWith(PoisonValue::get(CB.getType()));
else
CB.replaceAllUsesWith(R->getReturnValue());
}
@@ -2846,7 +2769,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// using the return value of the call with the computed value.
if (!CB.use_empty()) {
if (&CB == Returns[0]->getReturnValue())
- CB.replaceAllUsesWith(UndefValue::get(CB.getType()));
+ CB.replaceAllUsesWith(PoisonValue::get(CB.getType()));
else
CB.replaceAllUsesWith(Returns[0]->getReturnValue());
}
diff --git a/llvm/lib/Transforms/Utils/InstructionNamer.cpp b/llvm/lib/Transforms/Utils/InstructionNamer.cpp
index f3499c9c8aed..3ae570cfeb77 100644
--- a/llvm/lib/Transforms/Utils/InstructionNamer.cpp
+++ b/llvm/lib/Transforms/Utils/InstructionNamer.cpp
@@ -17,9 +17,6 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/Transforms/Utils.h"
using namespace llvm;
@@ -41,35 +38,7 @@ void nameInstructions(Function &F) {
}
}
-struct InstNamer : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- InstNamer() : FunctionPass(ID) {
- initializeInstNamerPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &Info) const override {
- Info.setPreservesAll();
- }
-
- bool runOnFunction(Function &F) override {
- nameInstructions(F);
- return true;
- }
-};
-
- char InstNamer::ID = 0;
- } // namespace
-
-INITIALIZE_PASS(InstNamer, "instnamer",
- "Assign names to anonymous instructions", false, false)
-char &llvm::InstructionNamerID = InstNamer::ID;
-//===----------------------------------------------------------------------===//
-//
-// InstructionNamer - Give any unnamed non-void instructions "tmp" names.
-//
-FunctionPass *llvm::createInstructionNamerPass() {
- return new InstNamer();
-}
+} // namespace
PreservedAnalyses InstructionNamerPass::run(Function &F,
FunctionAnalysisManager &FAM) {
diff --git a/llvm/lib/Transforms/Utils/LCSSA.cpp b/llvm/lib/Transforms/Utils/LCSSA.cpp
index af79dc456ea6..c36b0533580b 100644
--- a/llvm/lib/Transforms/Utils/LCSSA.cpp
+++ b/llvm/lib/Transforms/Utils/LCSSA.cpp
@@ -40,7 +40,6 @@
#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/Dominators.h"
-#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PredIteratorCache.h"
@@ -77,15 +76,14 @@ static bool isExitBlock(BasicBlock *BB,
/// rewrite the uses.
bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
const DominatorTree &DT, const LoopInfo &LI,
- ScalarEvolution *SE, IRBuilderBase &Builder,
- SmallVectorImpl<PHINode *> *PHIsToRemove) {
+ ScalarEvolution *SE,
+ SmallVectorImpl<PHINode *> *PHIsToRemove,
+ SmallVectorImpl<PHINode *> *InsertedPHIs) {
SmallVector<Use *, 16> UsesToRewrite;
SmallSetVector<PHINode *, 16> LocalPHIsToRemove;
PredIteratorCache PredCache;
bool Changed = false;
- IRBuilderBase::InsertPointGuard InsertPtGuard(Builder);
-
// Cache the Loop ExitBlocks across this loop. We expect to get a lot of
// instructions within the same loops, computing the exit blocks is
// expensive, and we're not mutating the loop structure.
@@ -146,17 +144,14 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
SmallVector<PHINode *, 16> AddedPHIs;
SmallVector<PHINode *, 8> PostProcessPHIs;
- SmallVector<PHINode *, 4> InsertedPHIs;
- SSAUpdater SSAUpdate(&InsertedPHIs);
+ SmallVector<PHINode *, 4> LocalInsertedPHIs;
+ SSAUpdater SSAUpdate(&LocalInsertedPHIs);
SSAUpdate.Initialize(I->getType(), I->getName());
- // Force re-computation of I, as some users now need to use the new PHI
- // node.
- if (SE)
- SE->forgetValue(I);
-
// Insert the LCSSA phi's into all of the exit blocks dominated by the
// value, and add them to the Phi's map.
+ bool HasSCEV = SE && SE->isSCEVable(I->getType()) &&
+ SE->getExistingSCEV(I) != nullptr;
for (BasicBlock *ExitBB : ExitBlocks) {
if (!DT.dominates(DomNode, DT.getNode(ExitBB)))
continue;
@@ -164,9 +159,10 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
// If we already inserted something for this BB, don't reprocess it.
if (SSAUpdate.HasValueForBlock(ExitBB))
continue;
- Builder.SetInsertPoint(&ExitBB->front());
- PHINode *PN = Builder.CreatePHI(I->getType(), PredCache.size(ExitBB),
- I->getName() + ".lcssa");
+ PHINode *PN = PHINode::Create(I->getType(), PredCache.size(ExitBB),
+ I->getName() + ".lcssa", &ExitBB->front());
+ if (InsertedPHIs)
+ InsertedPHIs->push_back(PN);
// Get the debug location from the original instruction.
PN->setDebugLoc(I->getDebugLoc());
@@ -203,6 +199,13 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
if (auto *OtherLoop = LI.getLoopFor(ExitBB))
if (!L->contains(OtherLoop))
PostProcessPHIs.push_back(PN);
+
+ // If we have a cached SCEV for the original instruction, make sure the
+ // new LCSSA phi node is also cached. This makes sure that BECounts
+ // based on it will be invalidated when the LCSSA phi node is invalidated,
+ // which some passes rely on.
+ if (HasSCEV)
+ SE->getSCEV(PN);
}
// Rewrite all uses outside the loop in terms of the new PHIs we just
@@ -256,10 +259,12 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
// SSAUpdater might have inserted phi-nodes inside other loops. We'll need
// to post-process them to keep LCSSA form.
- for (PHINode *InsertedPN : InsertedPHIs) {
+ for (PHINode *InsertedPN : LocalInsertedPHIs) {
if (auto *OtherLoop = LI.getLoopFor(InsertedPN->getParent()))
if (!L->contains(OtherLoop))
PostProcessPHIs.push_back(InsertedPN);
+ if (InsertedPHIs)
+ InsertedPHIs->push_back(InsertedPN);
}
// Post process PHI instructions that were inserted into another disjoint
@@ -392,14 +397,7 @@ bool llvm::formLCSSA(Loop &L, const DominatorTree &DT, const LoopInfo *LI,
}
}
- IRBuilder<> Builder(L.getHeader()->getContext());
- Changed = formLCSSAForInstructions(Worklist, DT, *LI, SE, Builder);
-
- // If we modified the code, remove any caches about the loop from SCEV to
- // avoid dangling entries.
- // FIXME: This is a big hammer, can we clear the cache more selectively?
- if (SE && Changed)
- SE->forgetLoop(&L);
+ Changed = formLCSSAForInstructions(Worklist, DT, *LI, SE);
assert(L.isLCSSAForm(DT));
diff --git a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
index 5dd469c7af4b..cdcfb5050bff 100644
--- a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
+++ b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
@@ -28,6 +28,7 @@
#include "llvm/Transforms/Utils/LibCallsShrinkWrap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Constants.h"
@@ -37,8 +38,6 @@
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/MDBuilder.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cmath>
@@ -51,31 +50,10 @@ STATISTIC(NumWrappedOneCond, "Number of One-Condition Wrappers Inserted");
STATISTIC(NumWrappedTwoCond, "Number of Two-Condition Wrappers Inserted");
namespace {
-class LibCallsShrinkWrapLegacyPass : public FunctionPass {
-public:
- static char ID; // Pass identification, replacement for typeid
- explicit LibCallsShrinkWrapLegacyPass() : FunctionPass(ID) {
- initializeLibCallsShrinkWrapLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
- void getAnalysisUsage(AnalysisUsage &AU) const override;
- bool runOnFunction(Function &F) override;
-};
-}
-
-char LibCallsShrinkWrapLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LibCallsShrinkWrapLegacyPass, "libcalls-shrinkwrap",
- "Conditionally eliminate dead library calls", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(LibCallsShrinkWrapLegacyPass, "libcalls-shrinkwrap",
- "Conditionally eliminate dead library calls", false, false)
-
-namespace {
class LibCallsShrinkWrap : public InstVisitor<LibCallsShrinkWrap> {
public:
- LibCallsShrinkWrap(const TargetLibraryInfo &TLI, DominatorTree *DT)
- : TLI(TLI), DT(DT){};
+ LibCallsShrinkWrap(const TargetLibraryInfo &TLI, DomTreeUpdater &DTU)
+ : TLI(TLI), DTU(DTU){};
void visitCallInst(CallInst &CI) { checkCandidate(CI); }
bool perform() {
bool Changed = false;
@@ -101,14 +79,21 @@ private:
Value *generateTwoRangeCond(CallInst *CI, const LibFunc &Func);
Value *generateCondForPow(CallInst *CI, const LibFunc &Func);
+ // Create an OR of two conditions with given Arg and Arg2.
+ Value *createOrCond(CallInst *CI, Value *Arg, CmpInst::Predicate Cmp,
+ float Val, Value *Arg2, CmpInst::Predicate Cmp2,
+ float Val2) {
+ IRBuilder<> BBBuilder(CI);
+ auto Cond2 = createCond(BBBuilder, Arg2, Cmp2, Val2);
+ auto Cond1 = createCond(BBBuilder, Arg, Cmp, Val);
+ return BBBuilder.CreateOr(Cond1, Cond2);
+ }
+
// Create an OR of two conditions.
Value *createOrCond(CallInst *CI, CmpInst::Predicate Cmp, float Val,
CmpInst::Predicate Cmp2, float Val2) {
- IRBuilder<> BBBuilder(CI);
Value *Arg = CI->getArgOperand(0);
- auto Cond2 = createCond(BBBuilder, Arg, Cmp2, Val2);
- auto Cond1 = createCond(BBBuilder, Arg, Cmp, Val);
- return BBBuilder.CreateOr(Cond1, Cond2);
+ return createOrCond(CI, Arg, Cmp, Val, Arg, Cmp2, Val2);
}
// Create a single condition using IRBuilder.
@@ -117,18 +102,26 @@ private:
Constant *V = ConstantFP::get(BBBuilder.getContext(), APFloat(Val));
if (!Arg->getType()->isFloatTy())
V = ConstantExpr::getFPExtend(V, Arg->getType());
+ if (BBBuilder.GetInsertBlock()->getParent()->hasFnAttribute(Attribute::StrictFP))
+ BBBuilder.setIsFPConstrained(true);
return BBBuilder.CreateFCmp(Cmp, Arg, V);
}
+ // Create a single condition with given Arg.
+ Value *createCond(CallInst *CI, Value *Arg, CmpInst::Predicate Cmp,
+ float Val) {
+ IRBuilder<> BBBuilder(CI);
+ return createCond(BBBuilder, Arg, Cmp, Val);
+ }
+
// Create a single condition.
Value *createCond(CallInst *CI, CmpInst::Predicate Cmp, float Val) {
- IRBuilder<> BBBuilder(CI);
Value *Arg = CI->getArgOperand(0);
- return createCond(BBBuilder, Arg, Cmp, Val);
+ return createCond(CI, Arg, Cmp, Val);
}
const TargetLibraryInfo &TLI;
- DominatorTree *DT;
+ DomTreeUpdater &DTU;
SmallVector<CallInst *, 16> WorkList;
};
} // end anonymous namespace
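
For context on the condition helpers above: this pass wraps a dead library call (result unused, so only the errno side effect matters) in a condition that is true exactly when the argument would trigger a domain error. A hand-written source-level picture of the intended effect, using acos as an example; this is not literal pass output.

#include <cmath>

// Before: the result is unused, so the call is only needed for its
// domain-error side effect (setting errno when |x| > 1).
void before(double X) { (void)acos(X); }

// After shrink-wrapping: guard the call with the error condition, so the
// common in-domain case skips the libcall entirely.
void after(double X) {
  if (X < -1.0 || X > 1.0)
    (void)acos(X);
}
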
@@ -428,7 +421,6 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI,
Value *Base = CI->getArgOperand(0);
Value *Exp = CI->getArgOperand(1);
- IRBuilder<> BBBuilder(CI);
// Constant Base case.
if (ConstantFP *CF = dyn_cast<ConstantFP>(Base)) {
@@ -439,10 +431,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI,
}
++NumWrappedOneCond;
- Constant *V = ConstantFP::get(CI->getContext(), APFloat(127.0f));
- if (!Exp->getType()->isFloatTy())
- V = ConstantExpr::getFPExtend(V, Exp->getType());
- return BBBuilder.CreateFCmp(CmpInst::FCMP_OGT, Exp, V);
+ return createCond(CI, Exp, CmpInst::FCMP_OGT, 127.0f);
}
  // If the Base value is coming from an integer type.
@@ -467,16 +456,8 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI,
}
++NumWrappedTwoCond;
- Constant *V = ConstantFP::get(CI->getContext(), APFloat(UpperV));
- Constant *V0 = ConstantFP::get(CI->getContext(), APFloat(0.0f));
- if (!Exp->getType()->isFloatTy())
- V = ConstantExpr::getFPExtend(V, Exp->getType());
- if (!Base->getType()->isFloatTy())
- V0 = ConstantExpr::getFPExtend(V0, Exp->getType());
-
- Value *Cond = BBBuilder.CreateFCmp(CmpInst::FCMP_OGT, Exp, V);
- Value *Cond0 = BBBuilder.CreateFCmp(CmpInst::FCMP_OLE, Base, V0);
- return BBBuilder.CreateOr(Cond0, Cond);
+ return createOrCond(CI, Base, CmpInst::FCMP_OLE, 0.0f, Exp,
+ CmpInst::FCMP_OGT, UpperV);
}
LLVM_DEBUG(dbgs() << "Not handled pow(): base not from integer convert\n");
return nullptr;
@@ -489,7 +470,7 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) {
MDBuilder(CI->getContext()).createBranchWeights(1, 2000);
Instruction *NewInst =
- SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, DT);
+ SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, &DTU);
BasicBlock *CallBB = NewInst->getParent();
CallBB->setName("cdce.call");
BasicBlock *SuccBB = CallBB->getSingleSuccessor();
@@ -515,40 +496,21 @@ bool LibCallsShrinkWrap::perform(CallInst *CI) {
return performCallErrors(CI, Func);
}
-void LibCallsShrinkWrapLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
-}
-
static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
DominatorTree *DT) {
if (F.hasFnAttribute(Attribute::OptimizeForSize))
return false;
- LibCallsShrinkWrap CCDCE(TLI, DT);
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+ LibCallsShrinkWrap CCDCE(TLI, DTU);
CCDCE.visit(F);
bool Changed = CCDCE.perform();
-// Verify the dominator after we've updated it locally.
- assert(!DT || DT->verify(DominatorTree::VerificationLevel::Fast));
+ // Verify the dominator after we've updated it locally.
+ assert(!DT ||
+ DTU.getDomTree().verify(DominatorTree::VerificationLevel::Fast));
return Changed;
}
-bool LibCallsShrinkWrapLegacyPass::runOnFunction(Function &F) {
- auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
- return runImpl(F, TLI, DT);
-}
-
-namespace llvm {
-char &LibCallsShrinkWrapPassID = LibCallsShrinkWrapLegacyPass::ID;
-
-// Public interface to LibCallsShrinkWrap pass.
-FunctionPass *createLibCallsShrinkWrapPass() {
- return new LibCallsShrinkWrapLegacyPass();
-}
-
PreservedAnalyses LibCallsShrinkWrapPass::run(Function &F,
FunctionAnalysisManager &FAM) {
auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
@@ -559,4 +521,3 @@ PreservedAnalyses LibCallsShrinkWrapPass::run(Function &F,
PA.preserve<DominatorTreeAnalysis>();
return PA;
}
-}
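
For context on what the refactored helpers above end up emitting: libcalls-shrinkwrap guards a dead library call whose only remaining effect is setting errno, so the call executes only when its argument can actually trigger the error. A minimal plain-C++ sketch of the shape of the transformation (the expf example and the 88.72f bound are illustrative assumptions, not the exact constants the pass uses):

#include <cmath>
#include <cstdio>

// Before the pass: the result is unused, but expf may still set errno.
void before(float x) { (void)expf(x); }

// After the pass (conceptually): the call is wrapped in the condition built by
// the createCond/createOrCond helpers, so it runs only when overflow (and thus
// errno) is possible. Block names mirror cdce.cond/cdce.call in shrinkWrapCI.
void after(float x) {
  if (x > 88.72f)     // cdce.cond
    (void)expf(x);    // cdce.call
}

int main() {
  before(1.0f);
  after(1.0f);
  std::printf("done\n");
}
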
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 31cdd2ee56b9..f153ace5d3fc 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -25,7 +25,6 @@
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomTreeUpdater.h"
-#include "llvm/Analysis/EHPersonalities.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
@@ -47,6 +46,7 @@
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalObject.h"
@@ -201,16 +201,16 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
bool Changed = false;
// Figure out which case it goes to.
- for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) {
+ for (auto It = SI->case_begin(), End = SI->case_end(); It != End;) {
// Found case matching a constant operand?
- if (i->getCaseValue() == CI) {
- TheOnlyDest = i->getCaseSuccessor();
+ if (It->getCaseValue() == CI) {
+ TheOnlyDest = It->getCaseSuccessor();
break;
}
// Check to see if this branch is going to the same place as the default
// dest. If so, eliminate it as an explicit compare.
- if (i->getCaseSuccessor() == DefaultDest) {
+ if (It->getCaseSuccessor() == DefaultDest) {
MDNode *MD = getValidBranchWeightMDNode(*SI);
unsigned NCases = SI->getNumCases();
// Fold the case metadata into the default if there will be any branches
@@ -221,11 +221,11 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
extractBranchWeights(MD, Weights);
// Merge weight of this case to the default weight.
- unsigned idx = i->getCaseIndex();
+ unsigned Idx = It->getCaseIndex();
// TODO: Add overflow check.
- Weights[0] += Weights[idx+1];
+ Weights[0] += Weights[Idx + 1];
// Remove weight for this case.
- std::swap(Weights[idx+1], Weights.back());
+ std::swap(Weights[Idx + 1], Weights.back());
Weights.pop_back();
SI->setMetadata(LLVMContext::MD_prof,
MDBuilder(BB->getContext()).
@@ -234,14 +234,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Remove this entry.
BasicBlock *ParentBB = SI->getParent();
DefaultDest->removePredecessor(ParentBB);
- i = SI->removeCase(i);
- e = SI->case_end();
+ It = SI->removeCase(It);
+ End = SI->case_end();
// Removing this case may have made the condition constant. In that
// case, update CI and restart iteration through the cases.
if (auto *NewCI = dyn_cast<ConstantInt>(SI->getCondition())) {
CI = NewCI;
- i = SI->case_begin();
+ It = SI->case_begin();
}
Changed = true;
@@ -251,11 +251,11 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Otherwise, check to see if the switch only branches to one destination.
  // We do this by resetting "TheOnlyDest" to null when we find two non-equal
// destinations.
- if (i->getCaseSuccessor() != TheOnlyDest)
+ if (It->getCaseSuccessor() != TheOnlyDest)
TheOnlyDest = nullptr;
// Increment this iterator as we haven't removed the case.
- ++i;
+ ++It;
}
if (CI && !TheOnlyDest) {
@@ -424,18 +424,10 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I,
if (I->isEHPad())
return false;
- // We don't want debug info removed by anything this general, unless
- // debug info is empty.
- if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(I)) {
- if (DDI->getAddress())
- return false;
- return true;
- }
- if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(I)) {
- if (DVI->hasArgList() || DVI->getValue(0))
- return false;
- return true;
- }
+ // We don't want debug info removed by anything this general.
+ if (isa<DbgVariableIntrinsic>(I))
+ return false;
+
if (DbgLabelInst *DLI = dyn_cast<DbgLabelInst>(I)) {
if (DLI->getLabel())
return false;
@@ -555,7 +547,7 @@ bool llvm::RecursivelyDeleteTriviallyDeadInstructionsPermissive(
std::function<void(Value *)> AboutToDeleteCallback) {
unsigned S = 0, E = DeadInsts.size(), Alive = 0;
for (; S != E; ++S) {
- auto *I = dyn_cast<Instruction>(DeadInsts[S]);
+ auto *I = dyn_cast_or_null<Instruction>(DeadInsts[S]);
if (!I || !isInstructionTriviallyDead(I)) {
DeadInsts[S] = nullptr;
++Alive;
@@ -1231,12 +1223,10 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
// If the unconditional branch we replaced contains llvm.loop metadata, we
// add the metadata to the branch instructions in the predecessors.
- unsigned LoopMDKind = BB->getContext().getMDKindID("llvm.loop");
- Instruction *TI = BB->getTerminator();
- if (TI)
- if (MDNode *LoopMD = TI->getMetadata(LoopMDKind))
+ if (Instruction *TI = BB->getTerminator())
+ if (MDNode *LoopMD = TI->getMetadata(LLVMContext::MD_loop))
for (BasicBlock *Pred : predecessors(BB))
- Pred->getTerminator()->setMetadata(LoopMDKind, LoopMD);
+ Pred->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopMD);
// Everything that jumped to BB now goes to Succ.
BB->replaceAllUsesWith(Succ);
@@ -1423,6 +1413,12 @@ static Align tryEnforceAlignment(Value *V, Align PrefAlign,
if (!GO->canIncreaseAlignment())
return CurrentAlign;
+ if (GO->isThreadLocal()) {
+ unsigned MaxTLSAlign = GO->getParent()->getMaxTLSAlignment() / CHAR_BIT;
+ if (MaxTLSAlign && PrefAlign > Align(MaxTLSAlign))
+ PrefAlign = Align(MaxTLSAlign);
+ }
+
GO->setAlignment(PrefAlign);
return PrefAlign;
}
@@ -1480,19 +1476,16 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar,
/// (or fragment of the variable) described by \p DII.
///
/// This is primarily intended as a helper for the different
-/// ConvertDebugDeclareToDebugValue functions. The dbg.declare/dbg.addr that is
-/// converted describes an alloca'd variable, so we need to use the
-/// alloc size of the value when doing the comparison. E.g. an i1 value will be
-/// identified as covering an n-bit fragment, if the store size of i1 is at
-/// least n bits.
+/// ConvertDebugDeclareToDebugValue functions. The dbg.declare that is converted
+/// describes an alloca'd variable, so we need to use the alloc size of the
+/// value when doing the comparison. E.g. an i1 value will be identified as
+/// covering an n-bit fragment, if the store size of i1 is at least n bits.
static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) {
const DataLayout &DL = DII->getModule()->getDataLayout();
TypeSize ValueSize = DL.getTypeAllocSizeInBits(ValTy);
- if (std::optional<uint64_t> FragmentSize = DII->getFragmentSizeInBits()) {
- assert(!ValueSize.isScalable() &&
- "Fragments don't work on scalable types.");
- return ValueSize.getFixedValue() >= *FragmentSize;
- }
+ if (std::optional<uint64_t> FragmentSize = DII->getFragmentSizeInBits())
+ return TypeSize::isKnownGE(ValueSize, TypeSize::getFixed(*FragmentSize));
+
// We can't always calculate the size of the DI variable (e.g. if it is a
// VLA). Try to use the size of the alloca that the dbg intrinsic describes
  // instead.
@@ -1513,7 +1506,7 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) {
}
/// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value
-/// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic.
+/// that has an associated llvm.dbg.declare intrinsic.
void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
StoreInst *SI, DIBuilder &Builder) {
assert(DII->isAddressOfVariable() || isa<DbgAssignIntrinsic>(DII));
@@ -1524,24 +1517,39 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
DebugLoc NewLoc = getDebugValueLoc(DII);
- if (!valueCoversEntireFragment(DV->getType(), DII)) {
- // FIXME: If storing to a part of the variable described by the dbg.declare,
- // then we want to insert a dbg.value for the corresponding fragment.
- LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: "
- << *DII << '\n');
- // For now, when there is a store to parts of the variable (but we do not
- // know which part) we insert an dbg.value intrinsic to indicate that we
- // know nothing about the variable's content.
- DV = UndefValue::get(DV->getType());
+ // If the alloca describes the variable itself, i.e. the expression in the
+ // dbg.declare doesn't start with a dereference, we can perform the
+ // conversion if the value covers the entire fragment of DII.
+ // If the alloca describes the *address* of DIVar, i.e. DIExpr is
+ // *just* a DW_OP_deref, we use DV as is for the dbg.value.
+ // We conservatively ignore other dereferences, because the following two are
+ // not equivalent:
+ // dbg.declare(alloca, ..., !Expr(deref, plus_uconstant, 2))
+ // dbg.value(DV, ..., !Expr(deref, plus_uconstant, 2))
+ // The former is adding 2 to the address of the variable, whereas the latter
+ // is adding 2 to the value of the variable. As such, we insist on just a
+ // deref expression.
+ bool CanConvert =
+ DIExpr->isDeref() || (!DIExpr->startsWithDeref() &&
+ valueCoversEntireFragment(DV->getType(), DII));
+ if (CanConvert) {
Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI);
return;
}
+ // FIXME: If storing to a part of the variable described by the dbg.declare,
+ // then we want to insert a dbg.value for the corresponding fragment.
+ LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " << *DII
+ << '\n');
+ // For now, when there is a store to parts of the variable (but we do not
+  // know which part) we insert a dbg.value intrinsic to indicate that we
+ // know nothing about the variable's content.
+ DV = UndefValue::get(DV->getType());
Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI);
}
/// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value
-/// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic.
+/// that has an associated llvm.dbg.declare intrinsic.
void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
LoadInst *LI, DIBuilder &Builder) {
auto *DIVar = DII->getVariable();
@@ -1569,7 +1577,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
}
/// Inserts a llvm.dbg.value intrinsic after a phi that has an associated
-/// llvm.dbg.declare or llvm.dbg.addr intrinsic.
+/// llvm.dbg.declare intrinsic.
void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
PHINode *APN, DIBuilder &Builder) {
auto *DIVar = DII->getVariable();
@@ -1752,8 +1760,8 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB,
bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress,
DIBuilder &Builder, uint8_t DIExprFlags,
int Offset) {
- auto DbgAddrs = FindDbgAddrUses(Address);
- for (DbgVariableIntrinsic *DII : DbgAddrs) {
+ auto DbgDeclares = FindDbgDeclareUses(Address);
+ for (DbgVariableIntrinsic *DII : DbgDeclares) {
const DebugLoc &Loc = DII->getDebugLoc();
auto *DIVar = DII->getVariable();
auto *DIExpr = DII->getExpression();
@@ -1764,7 +1772,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress,
Builder.insertDeclare(NewAddress, DIVar, DIExpr, Loc, DII);
DII->eraseFromParent();
}
- return !DbgAddrs.empty();
+ return !DbgDeclares.empty();
}
static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress,
@@ -1860,9 +1868,8 @@ void llvm::salvageDebugInfoForDbgValues(
continue;
}
- // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they
- // are implicitly pointing out the value as a DWARF memory location
- // description.
+  // Do not add DW_OP_stack_value for DbgDeclare, because it implicitly
+  // points out the value as a DWARF memory location description.
bool StackValue = isa<DbgValueInst>(DII);
auto DIILocation = DII->location_ops();
assert(
@@ -1896,17 +1903,14 @@ void llvm::salvageDebugInfoForDbgValues(
bool IsValidSalvageExpr = SalvagedExpr->getNumElements() <= MaxExpressionSize;
if (AdditionalValues.empty() && IsValidSalvageExpr) {
DII->setExpression(SalvagedExpr);
- } else if (isa<DbgValueInst>(DII) && !isa<DbgAssignIntrinsic>(DII) &&
- IsValidSalvageExpr &&
+ } else if (isa<DbgValueInst>(DII) && IsValidSalvageExpr &&
DII->getNumVariableLocationOps() + AdditionalValues.size() <=
MaxDebugArgs) {
DII->addVariableLocationOps(AdditionalValues, SalvagedExpr);
} else {
- // Do not salvage using DIArgList for dbg.addr/dbg.declare, as it is
- // not currently supported in those instructions. Do not salvage using
- // DIArgList for dbg.assign yet. FIXME: support this.
- // Also do not salvage if the resulting DIArgList would contain an
- // unreasonably large number of values.
+ // Do not salvage using DIArgList for dbg.declare, as it is not currently
+  // supported in that instruction. Also do not salvage if the resulting
+ // DIArgList would contain an unreasonably large number of values.
DII->setKillLocation();
}
LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n');
@@ -1934,7 +1938,7 @@ Value *getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL,
Opcodes.insert(Opcodes.begin(), {dwarf::DW_OP_LLVM_arg, 0});
CurrentLocOps = 1;
}
- for (auto Offset : VariableOffsets) {
+ for (const auto &Offset : VariableOffsets) {
AdditionalValues.push_back(Offset.first);
assert(Offset.second.isStrictlyPositive() &&
"Expected strictly positive multiplier for offset.");
@@ -1976,6 +1980,18 @@ uint64_t getDwarfOpForBinOp(Instruction::BinaryOps Opcode) {
}
}
+static void handleSSAValueOperands(uint64_t CurrentLocOps,
+ SmallVectorImpl<uint64_t> &Opcodes,
+ SmallVectorImpl<Value *> &AdditionalValues,
+ Instruction *I) {
+ if (!CurrentLocOps) {
+ Opcodes.append({dwarf::DW_OP_LLVM_arg, 0});
+ CurrentLocOps = 1;
+ }
+ Opcodes.append({dwarf::DW_OP_LLVM_arg, CurrentLocOps});
+ AdditionalValues.push_back(I->getOperand(1));
+}
+
Value *getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps,
SmallVectorImpl<uint64_t> &Opcodes,
SmallVectorImpl<Value *> &AdditionalValues) {
@@ -1998,12 +2014,7 @@ Value *getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps,
}
Opcodes.append({dwarf::DW_OP_constu, Val});
} else {
- if (!CurrentLocOps) {
- Opcodes.append({dwarf::DW_OP_LLVM_arg, 0});
- CurrentLocOps = 1;
- }
- Opcodes.append({dwarf::DW_OP_LLVM_arg, CurrentLocOps});
- AdditionalValues.push_back(BI->getOperand(1));
+ handleSSAValueOperands(CurrentLocOps, Opcodes, AdditionalValues, BI);
}
// Add salvaged binary operator to expression stack, if it has a valid
@@ -2015,6 +2026,60 @@ Value *getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps,
return BI->getOperand(0);
}
+uint64_t getDwarfOpForIcmpPred(CmpInst::Predicate Pred) {
+ // The signedness of the operation is implicit in the typed stack, signed and
+ // unsigned instructions map to the same DWARF opcode.
+ switch (Pred) {
+ case CmpInst::ICMP_EQ:
+ return dwarf::DW_OP_eq;
+ case CmpInst::ICMP_NE:
+ return dwarf::DW_OP_ne;
+ case CmpInst::ICMP_UGT:
+ case CmpInst::ICMP_SGT:
+ return dwarf::DW_OP_gt;
+ case CmpInst::ICMP_UGE:
+ case CmpInst::ICMP_SGE:
+ return dwarf::DW_OP_ge;
+ case CmpInst::ICMP_ULT:
+ case CmpInst::ICMP_SLT:
+ return dwarf::DW_OP_lt;
+ case CmpInst::ICMP_ULE:
+ case CmpInst::ICMP_SLE:
+ return dwarf::DW_OP_le;
+ default:
+ return 0;
+ }
+}
+
+Value *getSalvageOpsForIcmpOp(ICmpInst *Icmp, uint64_t CurrentLocOps,
+ SmallVectorImpl<uint64_t> &Opcodes,
+ SmallVectorImpl<Value *> &AdditionalValues) {
+ // Handle icmp operations with constant integer operands as a special case.
+ auto *ConstInt = dyn_cast<ConstantInt>(Icmp->getOperand(1));
+ // Values wider than 64 bits cannot be represented within a DIExpression.
+ if (ConstInt && ConstInt->getBitWidth() > 64)
+ return nullptr;
+ // Push any Constant Int operand onto the expression stack.
+ if (ConstInt) {
+ if (Icmp->isSigned())
+ Opcodes.push_back(dwarf::DW_OP_consts);
+ else
+ Opcodes.push_back(dwarf::DW_OP_constu);
+ uint64_t Val = ConstInt->getSExtValue();
+ Opcodes.push_back(Val);
+ } else {
+ handleSSAValueOperands(CurrentLocOps, Opcodes, AdditionalValues, Icmp);
+ }
+
+ // Add salvaged binary operator to expression stack, if it has a valid
+ // representation in a DIExpression.
+ uint64_t DwarfIcmpOp = getDwarfOpForIcmpPred(Icmp->getPredicate());
+ if (!DwarfIcmpOp)
+ return nullptr;
+ Opcodes.push_back(DwarfIcmpOp);
+ return Icmp->getOperand(0);
+}
+
Value *llvm::salvageDebugInfoImpl(Instruction &I, uint64_t CurrentLocOps,
SmallVectorImpl<uint64_t> &Ops,
SmallVectorImpl<Value *> &AdditionalValues) {
@@ -2054,6 +2119,8 @@ Value *llvm::salvageDebugInfoImpl(Instruction &I, uint64_t CurrentLocOps,
return getSalvageOpsForGEP(GEP, DL, CurrentLocOps, Ops, AdditionalValues);
if (auto *BI = dyn_cast<BinaryOperator>(&I))
return getSalvageOpsForBinOp(BI, CurrentLocOps, Ops, AdditionalValues);
+ if (auto *IC = dyn_cast<ICmpInst>(&I))
+ return getSalvageOpsForIcmpOp(IC, CurrentLocOps, Ops, AdditionalValues);
// *Not* to do: we should not attempt to salvage load instructions,
// because the validity and lifetime of a dbg.value containing
@@ -2661,43 +2728,52 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
intersectAccessGroups(K, J));
break;
case LLVMContext::MD_range:
-
- // If K does move, use most generic range. Otherwise keep the range of
- // K.
- if (DoesKMove)
- // FIXME: If K does move, we should drop the range info and nonnull.
- // Currently this function is used with DoesKMove in passes
- // doing hoisting/sinking and the current behavior of using the
- // most generic range is correct in those cases.
+ if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef))
K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD));
break;
case LLVMContext::MD_fpmath:
K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD));
break;
case LLVMContext::MD_invariant_load:
- // Only set the !invariant.load if it is present in both instructions.
- K->setMetadata(Kind, JMD);
+ // If K moves, only set the !invariant.load if it is present in both
+ // instructions.
+ if (DoesKMove)
+ K->setMetadata(Kind, JMD);
break;
case LLVMContext::MD_nonnull:
- // If K does move, keep nonull if it is present in both instructions.
- if (DoesKMove)
+ if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef))
K->setMetadata(Kind, JMD);
break;
case LLVMContext::MD_invariant_group:
// Preserve !invariant.group in K.
break;
case LLVMContext::MD_align:
- K->setMetadata(Kind,
- MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
+ if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef))
+ K->setMetadata(
+ Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
break;
case LLVMContext::MD_dereferenceable:
case LLVMContext::MD_dereferenceable_or_null:
- K->setMetadata(Kind,
- MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
+ if (DoesKMove)
+ K->setMetadata(Kind,
+ MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
break;
case LLVMContext::MD_preserve_access_index:
// Preserve !preserve.access.index in K.
break;
+ case LLVMContext::MD_noundef:
+ // If K does move, keep noundef if it is present in both instructions.
+ if (DoesKMove)
+ K->setMetadata(Kind, JMD);
+ break;
+ case LLVMContext::MD_nontemporal:
+ // Preserve !nontemporal if it is present on both instructions.
+ K->setMetadata(Kind, JMD);
+ break;
+ case LLVMContext::MD_prof:
+ if (DoesKMove)
+ K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
+ break;
}
}
// Set !invariant.group from J if J has it. If both instructions have it
@@ -2713,14 +2789,22 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J,
bool KDominatesJ) {
- unsigned KnownIDs[] = {
- LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
- LLVMContext::MD_noalias, LLVMContext::MD_range,
- LLVMContext::MD_invariant_load, LLVMContext::MD_nonnull,
- LLVMContext::MD_invariant_group, LLVMContext::MD_align,
- LLVMContext::MD_dereferenceable,
- LLVMContext::MD_dereferenceable_or_null,
- LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index};
+ unsigned KnownIDs[] = {LLVMContext::MD_tbaa,
+ LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias,
+ LLVMContext::MD_range,
+ LLVMContext::MD_fpmath,
+ LLVMContext::MD_invariant_load,
+ LLVMContext::MD_nonnull,
+ LLVMContext::MD_invariant_group,
+ LLVMContext::MD_align,
+ LLVMContext::MD_dereferenceable,
+ LLVMContext::MD_dereferenceable_or_null,
+ LLVMContext::MD_access_group,
+ LLVMContext::MD_preserve_access_index,
+ LLVMContext::MD_prof,
+ LLVMContext::MD_nontemporal,
+ LLVMContext::MD_noundef};
combineMetadata(K, J, KnownIDs, KDominatesJ);
}
@@ -2799,13 +2883,7 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
// In general, GVN unifies expressions over different control-flow
// regions, and so we need a conservative combination of the noalias
// scopes.
- static const unsigned KnownIDs[] = {
- LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
- LLVMContext::MD_noalias, LLVMContext::MD_range,
- LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load,
- LLVMContext::MD_invariant_group, LLVMContext::MD_nonnull,
- LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index};
- combineMetadata(ReplInst, I, KnownIDs, false);
+ combineMetadataForCSE(ReplInst, I, false);
}
template <typename RootType, typename DominatesFn>
@@ -2930,7 +3008,8 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI,
return;
unsigned BitWidth = DL.getPointerTypeSizeInBits(NewTy);
- if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) {
+ if (BitWidth == OldLI.getType()->getScalarSizeInBits() &&
+ !getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) {
MDNode *NN = MDNode::get(OldLI.getContext(), std::nullopt);
NewLI.setMetadata(LLVMContext::MD_nonnull, NN);
}
@@ -2969,7 +3048,7 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt,
for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
Instruction *I = &*II;
- I->dropUndefImplyingAttrsAndUnknownMetadata();
+ I->dropUBImplyingAttrsAndMetadata();
if (I->isUsedByMetadata())
dropDebugUsers(*I);
if (I->isDebugOrPseudoInst()) {
@@ -3125,7 +3204,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// Check that the mask allows a multiple of 8 bits for a bswap, for an
// early exit.
- unsigned NumMaskedBits = AndMask.countPopulation();
+ unsigned NumMaskedBits = AndMask.popcount();
if (!MatchBitReversals && (NumMaskedBits % 8) != 0)
return Result;
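
The new getSalvageOpsForIcmpOp above lets a dbg.value survive deletion of a compare by re-expressing it as a DWARF stack program: for something like `icmp ult i32 %x, 10`, the location becomes %x with `DW_OP_constu 10, DW_OP_lt` appended (signed and unsigned predicates share one opcode because signedness lives in the typed stack). A toy C++ evaluator of that stack shape, for illustration only - the opcodes are stand-ins, not LLVM's encoding machinery:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  uint64_t X = 7;                         // salvaged operand (DW_OP_LLVM_arg 0)
  std::vector<uint64_t> Stack{X};
  Stack.push_back(10);                    // DW_OP_constu 10
  uint64_t B = Stack.back(); Stack.pop_back();
  uint64_t A = Stack.back(); Stack.pop_back();
  Stack.push_back(A < B ? 1 : 0);         // DW_OP_lt
  std::printf("recovered icmp result: %llu\n",
              (unsigned long long)Stack.back());
}
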
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2acbe9002309..d701cf110154 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -345,20 +345,20 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
assert(L.isLoopSimplifyForm() && "Loop needs to be in loop simplify form");
unsigned DesiredPeelCount = 0;
- for (auto *BB : L.blocks()) {
- auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
- if (!BI || BI->isUnconditional())
- continue;
-
- // Ignore loop exit condition.
- if (L.getLoopLatch() == BB)
- continue;
+ // Do not peel the entire loop.
+ const SCEV *BE = SE.getConstantMaxBackedgeTakenCount(&L);
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(BE))
+ MaxPeelCount =
+ std::min((unsigned)SC->getAPInt().getLimitedValue() - 1, MaxPeelCount);
+
+ auto ComputePeelCount = [&](Value *Condition) -> void {
+ if (!Condition->getType()->isIntegerTy())
+ return;
- Value *Condition = BI->getCondition();
Value *LeftVal, *RightVal;
CmpInst::Predicate Pred;
if (!match(Condition, m_ICmp(Pred, m_Value(LeftVal), m_Value(RightVal))))
- continue;
+ return;
const SCEV *LeftSCEV = SE.getSCEV(LeftVal);
const SCEV *RightSCEV = SE.getSCEV(RightVal);
@@ -366,7 +366,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
// Do not consider predicates that are known to be true or false
// independently of the loop iteration.
if (SE.evaluatePredicate(Pred, LeftSCEV, RightSCEV))
- continue;
+ return;
// Check if we have a condition with one AddRec and one non AddRec
// expression. Normalize LeftSCEV to be the AddRec.
@@ -375,7 +375,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
std::swap(LeftSCEV, RightSCEV);
Pred = ICmpInst::getSwappedPredicate(Pred);
} else
- continue;
+ return;
}
const SCEVAddRecExpr *LeftAR = cast<SCEVAddRecExpr>(LeftSCEV);
@@ -383,10 +383,10 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
// Avoid huge SCEV computations in the loop below, make sure we only
// consider AddRecs of the loop we are trying to peel.
if (!LeftAR->isAffine() || LeftAR->getLoop() != &L)
- continue;
+ return;
if (!(ICmpInst::isEquality(Pred) && LeftAR->hasNoSelfWrap()) &&
!SE.getMonotonicPredicateType(LeftAR, Pred))
- continue;
+ return;
// Check if extending the current DesiredPeelCount lets us evaluate Pred
// or !Pred in the loop body statically.
@@ -422,7 +422,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
// first iteration of the loop body after peeling?
if (!SE.isKnownPredicate(ICmpInst::getInversePredicate(Pred), IterVal,
RightSCEV))
- continue; // If not, give up.
+ return; // If not, give up.
// However, for equality comparisons, that isn't always sufficient to
  // eliminate the comparison in the loop body, we may need to peel one more
@@ -433,11 +433,28 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
!SE.isKnownPredicate(Pred, IterVal, RightSCEV) &&
SE.isKnownPredicate(Pred, NextIterVal, RightSCEV)) {
if (!CanPeelOneMoreIteration())
- continue; // Need to peel one more iteration, but can't. Give up.
+ return; // Need to peel one more iteration, but can't. Give up.
PeelOneMoreIteration(); // Great!
}
DesiredPeelCount = std::max(DesiredPeelCount, NewPeelCount);
+ };
+
+ for (BasicBlock *BB : L.blocks()) {
+ for (Instruction &I : *BB) {
+ if (SelectInst *SI = dyn_cast<SelectInst>(&I))
+ ComputePeelCount(SI->getCondition());
+ }
+
+ auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!BI || BI->isUnconditional())
+ continue;
+
+ // Ignore loop exit condition.
+ if (L.getLoopLatch() == BB)
+ continue;
+
+ ComputePeelCount(BI->getCondition());
}
return DesiredPeelCount;
@@ -1025,6 +1042,7 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
// We modified the loop, update SE.
SE->forgetTopmostLoop(L);
+ SE->forgetBlockAndLoopDispositions();
#ifdef EXPENSIVE_CHECKS
  // Finally DomTree must be correct.
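
The reworked countToEliminateCompares above now feeds select conditions, not just conditional branches, into ComputePeelCount. A plain C++ illustration of why that pays off - peeling one iteration folds the select in both copies (the pass reasons over SCEV and IR, not C source; `sum` is a hypothetical example loop):

#include <cstdio>

int sum(const int *a, int n) {
  int acc = 0;
  for (int i = 0; i < n; ++i)
    acc = (i == 0 ? 100 : acc) + a[i];  // select condition only varies on i == 0
  return acc;
}

int sum_peeled(const int *a, int n) {
  if (n <= 0) return 0;
  int acc = 100 + a[0];                 // peeled first iteration, select folded
  for (int i = 1; i < n; ++i)
    acc = acc + a[i];                   // select gone from the loop body
  return acc;
}

int main() {
  int a[4] = {1, 2, 3, 4};
  std::printf("%d %d\n", sum(a, 4), sum_peeled(a, 4));
}
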
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 1a9eaf242190..d81db5647c60 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -435,6 +435,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// Otherwise, create a duplicate of the instruction.
Instruction *C = Inst->clone();
+ C->insertBefore(LoopEntryBranch);
+
++NumInstrsDuplicated;
// Eagerly remap the operands of the instruction.
@@ -444,7 +446,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// Avoid inserting the same intrinsic twice.
if (auto *DII = dyn_cast<DbgVariableIntrinsic>(C))
if (DbgIntrinsics.count(makeHash(DII))) {
- C->deleteValue();
+ C->eraseFromParent();
continue;
}
@@ -457,7 +459,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// in the map.
InsertNewValueIntoMap(ValueMap, Inst, V);
if (!C->mayHaveSideEffects()) {
- C->deleteValue();
+ C->eraseFromParent();
C = nullptr;
}
} else {
@@ -466,7 +468,6 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
if (C) {
// Otherwise, stick the new instruction into the new block!
C->setName(Inst->getName());
- C->insertBefore(LoopEntryBranch);
if (auto *II = dyn_cast<AssumeInst>(C))
AC->registerAssumption(II);
diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
index 87a0e54e2704..3e604fdf2e11 100644
--- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp
+++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
@@ -448,16 +448,15 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
// backedge blocks to jump to the BEBlock instead of the header.
// If one of the backedges has llvm.loop metadata attached, we remove
// it from the backedge and add it to BEBlock.
- unsigned LoopMDKind = BEBlock->getContext().getMDKindID("llvm.loop");
MDNode *LoopMD = nullptr;
for (BasicBlock *BB : BackedgeBlocks) {
Instruction *TI = BB->getTerminator();
if (!LoopMD)
- LoopMD = TI->getMetadata(LoopMDKind);
- TI->setMetadata(LoopMDKind, nullptr);
+ LoopMD = TI->getMetadata(LLVMContext::MD_loop);
+ TI->setMetadata(LLVMContext::MD_loop, nullptr);
TI->replaceSuccessorWith(Header, BEBlock);
}
- BEBlock->getTerminator()->setMetadata(LoopMDKind, LoopMD);
+ BEBlock->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopMD);
//===--- Update all analyses which we must preserve now -----------------===//
@@ -693,12 +692,6 @@ ReprocessLoop:
}
}
- // Changing exit conditions for blocks may affect exit counts of this loop and
- // any of its paretns, so we must invalidate the entire subtree if we've made
- // any changes.
- if (Changed && SE)
- SE->forgetTopmostLoop(L);
-
if (MSSAU && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -737,6 +730,13 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI,
Changed |= simplifyOneLoop(Worklist.pop_back_val(), Worklist, DT, LI, SE,
AC, MSSAU, PreserveLCSSA);
+ // Changing exit conditions for blocks may affect exit counts of this loop and
+ // any of its parents, so we must invalidate the entire subtree if we've made
+ // any changes. Do this here rather than in simplifyOneLoop() as the top-most
+ // loop is going to be the same for all child loops.
+ if (Changed && SE)
+ SE->forgetTopmostLoop(L);
+
return Changed;
}
diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
index e8f585b4a94d..511dd61308f9 100644
--- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
@@ -45,6 +45,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/ValueHandle.h"
@@ -216,6 +217,8 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI,
ScalarEvolution *SE, DominatorTree *DT,
AssumptionCache *AC,
const TargetTransformInfo *TTI) {
+ using namespace llvm::PatternMatch;
+
// Simplify any new induction variables in the partially unrolled loop.
if (SE && SimplifyIVs) {
SmallVector<WeakTrackingVH, 16> DeadInsts;
@@ -241,6 +244,30 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI,
Inst.replaceAllUsesWith(V);
if (isInstructionTriviallyDead(&Inst))
DeadInsts.emplace_back(&Inst);
+
+ // Fold ((add X, C1), C2) to (add X, C1+C2). This is very common in
+ // unrolled loops, and handling this early allows following code to
+ // identify the IV as a "simple recurrence" without first folding away
+ // a long chain of adds.
+ {
+ Value *X;
+ const APInt *C1, *C2;
+ if (match(&Inst, m_Add(m_Add(m_Value(X), m_APInt(C1)), m_APInt(C2)))) {
+ auto *InnerI = dyn_cast<Instruction>(Inst.getOperand(0));
+ auto *InnerOBO = cast<OverflowingBinaryOperator>(Inst.getOperand(0));
+ bool SignedOverflow;
+ APInt NewC = C1->sadd_ov(*C2, SignedOverflow);
+ Inst.setOperand(0, X);
+ Inst.setOperand(1, ConstantInt::get(Inst.getType(), NewC));
+ Inst.setHasNoUnsignedWrap(Inst.hasNoUnsignedWrap() &&
+ InnerOBO->hasNoUnsignedWrap());
+ Inst.setHasNoSignedWrap(Inst.hasNoSignedWrap() &&
+ InnerOBO->hasNoSignedWrap() &&
+ !SignedOverflow);
+ if (InnerI && isInstructionTriviallyDead(InnerI))
+ DeadInsts.emplace_back(InnerI);
+ }
+ }
}
// We can't do recursive deletion until we're done iterating, as we might
// have a phi which (potentially indirectly) uses instructions later in
@@ -310,6 +337,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
const unsigned MaxTripCount = SE->getSmallConstantMaxTripCount(L);
const bool MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
+ unsigned EstimatedLoopInvocationWeight = 0;
+ std::optional<unsigned> OriginalTripCount =
+ llvm::getLoopEstimatedTripCount(L, &EstimatedLoopInvocationWeight);
// Effectively "DCE" unrolled iterations that are beyond the max tripcount
// and will never be executed.
@@ -513,7 +543,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
!EnableFSDiscriminator)
for (BasicBlock *BB : L->getBlocks())
for (Instruction &I : *BB)
- if (!isa<DbgInfoIntrinsic>(&I))
+ if (!I.isDebugOrPseudoInst())
if (const DILocation *DIL = I.getDebugLoc()) {
auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(ULO.Count);
if (NewDIL)
@@ -830,8 +860,16 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
Loop *OuterL = L->getParentLoop();
// Update LoopInfo if the loop is completely removed.
- if (CompletelyUnroll)
+ if (CompletelyUnroll) {
LI->erase(L);
+ // We shouldn't try to use `L` anymore.
+ L = nullptr;
+ } else if (OriginalTripCount) {
+  // Update the trip count. Note that the remainder loop already has logic
+  // computing it in `UnrollRuntimeLoopRemainder`.
+ setLoopEstimatedTripCount(L, *OriginalTripCount / ULO.Count,
+ EstimatedLoopInvocationWeight);
+ }
// LoopInfo should not be valid, confirm that.
if (UnrollVerifyLoopInfo)
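
The new constant-folding block in simplifyLoopAfterUnroll collapses chains like `(add (add X, C1), C2)` into `add X, C1+C2` and keeps the wrap flags only when that is sound. A small standalone sketch of the flag rule, using the GCC/Clang `__builtin_add_overflow` builtin as a stand-in for APInt::sadd_ov:

#include <cstdint>
#include <cstdio>

int main() {
  // nuw/nsw survive only if both original adds carried the flag, and for nsw
  // the constant addition C1+C2 itself must not overflow (the SignedOverflow
  // bit produced by sadd_ov in the patch).
  int64_t C1 = 7, C2 = 9, NewC = 0;
  bool SignedOverflow = __builtin_add_overflow(C1, C2, &NewC);
  bool InnerNSW = true, OuterNSW = true;  // hypothetical flags on the two adds
  bool FoldedNSW = InnerNSW && OuterNSW && !SignedOverflow;
  std::printf("NewC=%lld nsw=%d\n", (long long)NewC, (int)FoldedNSW);
}
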
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp
index b125e952ec94..31b8cd34eb24 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp
@@ -347,7 +347,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
!EnableFSDiscriminator)
for (BasicBlock *BB : L->getBlocks())
for (Instruction &I : *BB)
- if (!isa<DbgInfoIntrinsic>(&I))
+ if (!I.isDebugOrPseudoInst())
if (const DILocation *DIL = I.getDebugLoc()) {
auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
if (NewDIL)
@@ -757,11 +757,11 @@ checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
DependenceInfo &DI, LoopInfo &LI) {
SmallVector<BasicBlockSet, 8> AllBlocks;
for (Loop *L : Root.getLoopsInPreorder())
- if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
+ if (ForeBlocksMap.contains(L))
AllBlocks.push_back(ForeBlocksMap.lookup(L));
AllBlocks.push_back(SubLoopBlocks);
for (Loop *L : Root.getLoopsInPreorder())
- if (AftBlocksMap.find(L) != AftBlocksMap.end())
+ if (AftBlocksMap.contains(L))
AllBlocks.push_back(AftBlocksMap.lookup(L));
unsigned LoopDepth = Root.getLoopDepth();
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index b19156bcb420..1e22eca30d2d 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -457,7 +457,7 @@ static bool canProfitablyUnrollMultiExitLoop(
// call.
return (OtherExits.size() == 1 &&
(UnrollRuntimeOtherExitPredictable ||
- OtherExits[0]->getTerminatingDeoptimizeCall()));
+ OtherExits[0]->getPostdominatingDeoptimizeCall()));
// TODO: These can be fine-tuned further to consider code size or deopt states
// that are captured by the deoptimize exit block.
// Also, we can extend this to support more cases, if we actually
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 7df8651ede15..7d6662c44f07 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -466,6 +466,19 @@ llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) {
return Worklist;
}
+bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) {
+ int LatchIdx = PN->getBasicBlockIndex(LatchBlock);
+ Value *IncV = PN->getIncomingValue(LatchIdx);
+
+ for (User *U : PN->users())
+ if (U != Cond && U != IncV) return false;
+
+ for (User *U : IncV->users())
+ if (U != Cond && U != PN) return false;
+ return true;
+}
+
+
void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
LoopInfo *LI, MemorySSA *MSSA) {
assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
@@ -628,18 +641,17 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
}
// After the loop has been deleted all the values defined and modified
- // inside the loop are going to be unavailable.
- // Since debug values in the loop have been deleted, inserting an undef
- // dbg.value truncates the range of any dbg.value before the loop where the
- // loop used to be. This is particularly important for constant values.
+ // inside the loop are going to be unavailable. Values computed in the
+ // loop will have been deleted, automatically causing their debug uses
+  // to be replaced with undef. Loop invariant values will still be available.
+  // Move dbg.values out of the loop so that earlier location ranges are still
+ // terminated and loop invariant assignments are preserved.
Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI();
assert(InsertDbgValueBefore &&
"There should be a non-PHI instruction in exit block, else these "
"instructions will have no parent.");
- for (auto *DVI : DeadDebugInst) {
- DVI->setKillLocation();
+ for (auto *DVI : DeadDebugInst)
DVI->moveBefore(InsertDbgValueBefore);
- }
}
// Remove the block from the reference counting scheme, so that we can
@@ -880,6 +892,29 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
return true;
}
+Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
+ switch (RK) {
+ default:
+ llvm_unreachable("Unknown min/max recurrence kind");
+ case RecurKind::UMin:
+ return Intrinsic::umin;
+ case RecurKind::UMax:
+ return Intrinsic::umax;
+ case RecurKind::SMin:
+ return Intrinsic::smin;
+ case RecurKind::SMax:
+ return Intrinsic::smax;
+ case RecurKind::FMin:
+ return Intrinsic::minnum;
+ case RecurKind::FMax:
+ return Intrinsic::maxnum;
+ case RecurKind::FMinimum:
+ return Intrinsic::minimum;
+ case RecurKind::FMaximum:
+ return Intrinsic::maximum;
+ }
+}
+
CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
switch (RK) {
default:
@@ -896,6 +931,9 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
return CmpInst::FCMP_OLT;
case RecurKind::FMax:
return CmpInst::FCMP_OGT;
+ // We do not add FMinimum/FMaximum recurrence kind here since there is no
+ // equivalent predicate which compares signed zeroes according to the
+ // semantics of the intrinsics (llvm.minimum/maximum).
}
}
@@ -910,6 +948,14 @@ Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal,
Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
Value *Right) {
+ Type *Ty = Left->getType();
+ if (Ty->isIntOrIntVectorTy() ||
+ (RK == RecurKind::FMinimum || RK == RecurKind::FMaximum)) {
+ // TODO: Add float minnum/maxnum support when FMF nnan is set.
+ Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RK);
+ return Builder.CreateIntrinsic(Ty, Id, {Left, Right}, nullptr,
+ "rdx.minmax");
+ }
CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp");
Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select");
@@ -1055,6 +1101,10 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
return Builder.CreateFPMaxReduce(Src);
case RecurKind::FMin:
return Builder.CreateFPMinReduce(Src);
+ case RecurKind::FMinimum:
+ return Builder.CreateFPMinimumReduce(Src);
+ case RecurKind::FMaximum:
+ return Builder.CreateFPMaximumReduce(Src);
default:
llvm_unreachable("Unhandled opcode");
}
@@ -1123,6 +1173,20 @@ bool llvm::isKnownNonNegativeInLoop(const SCEV *S, const Loop *L,
SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGE, S, Zero);
}
+bool llvm::isKnownPositiveInLoop(const SCEV *S, const Loop *L,
+ ScalarEvolution &SE) {
+ const SCEV *Zero = SE.getZero(S->getType());
+ return SE.isAvailableAtLoopEntry(S, L) &&
+ SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGT, S, Zero);
+}
+
+bool llvm::isKnownNonPositiveInLoop(const SCEV *S, const Loop *L,
+ ScalarEvolution &SE) {
+ const SCEV *Zero = SE.getZero(S->getType());
+ return SE.isAvailableAtLoopEntry(S, L) &&
+ SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLE, S, Zero);
+}
+
bool llvm::cannotBeMinInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
bool Signed) {
unsigned BitWidth = cast<IntegerType>(S->getType())->getBitWidth();
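
getMinMaxReductionPredicate above deliberately has no FMinimum/FMaximum cases: llvm.minimum/llvm.maximum propagate NaN and order -0.0 below +0.0, which no single fcmp-plus-select pattern reproduces, so createMinMaxOp now emits the intrinsic directly for those kinds. A hand-rolled reference for the minimum semantics next to fmin (minnum-style), for comparison only - this is not what the intrinsic lowers to:

#include <cmath>
#include <cstdio>

static double refMinimum(double A, double B) {
  if (std::isnan(A) || std::isnan(B)) return NAN;              // NaN propagates
  if (A == 0.0 && B == 0.0)
    return (std::signbit(A) || std::signbit(B)) ? -0.0 : 0.0;  // -0.0 < +0.0
  return A < B ? A : B;
}

int main() {
  // fmin may return either zero here; minimum must return -0.0.
  std::printf("fmin(-0.0,+0.0)=%g  minimum(-0.0,+0.0)=%g\n",
              std::fmin(-0.0, 0.0), refMinimum(-0.0, 0.0));
  // fmin returns the non-NaN operand; minimum returns NaN.
  std::printf("fmin(NaN,1)=%g  minimum(NaN,1)=%g\n",
              std::fmin(NAN, 1.0), refMinimum(NAN, 1.0));
}
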
diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
index 17e71cf5a6c4..78ebe75c121b 100644
--- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
@@ -23,7 +23,6 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
@@ -31,6 +30,8 @@
using namespace llvm;
+#define DEBUG_TYPE "loop-versioning"
+
static cl::opt<bool>
AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
cl::Hidden,
@@ -208,7 +209,7 @@ void LoopVersioning::prepareNoAliasMetadata() {
// Finally, transform the above to actually map to scope list which is what
// the metadata uses.
- for (auto Pair : GroupToNonAliasingScopes)
+ for (const auto &Pair : GroupToNonAliasingScopes)
GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
}
@@ -290,56 +291,6 @@ bool runImpl(LoopInfo *LI, LoopAccessInfoManager &LAIs, DominatorTree *DT,
return Changed;
}
-
-/// Also expose this is a pass. Currently this is only used for
-/// unit-testing. It adds all memchecks necessary to remove all may-aliasing
-/// array accesses from the loop.
-class LoopVersioningLegacyPass : public FunctionPass {
-public:
- LoopVersioningLegacyPass() : FunctionPass(ID) {
- initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs();
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
-
- return runImpl(LI, LAIs, DT, SE);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addRequired<LoopAccessLegacyAnalysis>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
- }
-
- static char ID;
-};
-}
-
-#define LVER_OPTION "loop-versioning"
-#define DEBUG_TYPE LVER_OPTION
-
-char LoopVersioningLegacyPass::ID;
-static const char LVer_name[] = "Loop Versioning";
-
-INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
- false)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
- false)
-
-namespace llvm {
-FunctionPass *createLoopVersioningLegacyPass() {
- return new LoopVersioningLegacyPass();
}
PreservedAnalyses LoopVersioningPass::run(Function &F,
@@ -353,4 +304,3 @@ PreservedAnalyses LoopVersioningPass::run(Function &F,
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
-} // namespace llvm
diff --git a/llvm/lib/Transforms/Utils/LowerAtomic.cpp b/llvm/lib/Transforms/Utils/LowerAtomic.cpp
index b6f40de0daa6..b203970ef9c5 100644
--- a/llvm/lib/Transforms/Utils/LowerAtomic.cpp
+++ b/llvm/lib/Transforms/Utils/LowerAtomic.cpp
@@ -14,8 +14,7 @@
#include "llvm/Transforms/Utils/LowerAtomic.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
+
using namespace llvm;
#define DEBUG_TYPE "loweratomic"
@@ -102,6 +101,9 @@ Value *llvm::buildAtomicRMWValue(AtomicRMWInst::BinOp Op,
bool llvm::lowerAtomicRMWInst(AtomicRMWInst *RMWI) {
IRBuilder<> Builder(RMWI);
+ Builder.setIsFPConstrained(
+ RMWI->getFunction()->hasFnAttribute(Attribute::StrictFP));
+
Value *Ptr = RMWI->getPointerOperand();
Value *Val = RMWI->getValOperand();
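
The two-line addition above (and the matching one in LibCallsShrinkWrap::createCond) follows one pattern: when the enclosing function is strictfp, put the IRBuilder into constrained mode before emitting new FP operations so they come out as llvm.experimental.constrained.* calls. A self-contained sketch of that pattern against the C++ API; the module and function names are throwaway illustrations:

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("demo", Ctx);
  Type *FloatTy = Type::getFloatTy(Ctx);
  Function *F = Function::Create(FunctionType::get(FloatTy, {FloatTy}, false),
                                 Function::ExternalLinkage, "f", M);
  F->addFnAttr(Attribute::StrictFP);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));
  // The pattern from the patch: mirror the function's strictfp attribute.
  if (F->hasFnAttribute(Attribute::StrictFP))
    B.setIsFPConstrained(true);
  // Emits llvm.experimental.constrained.fadd instead of a plain fadd.
  Value *Sum = B.CreateFAdd(F->getArg(0), ConstantFP::get(FloatTy, 1.0));
  B.CreateRet(Sum);
  M.print(outs(), nullptr);
}
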
diff --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
index 165740b55298..906eb71fc2d9 100644
--- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
+++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
@@ -12,9 +12,12 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MDBuilder.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <optional>
+#define DEBUG_TYPE "lower-mem-intrinsics"
+
using namespace llvm;
void llvm::createMemCpyLoopKnownSize(
@@ -376,19 +379,14 @@ void llvm::createMemCpyLoopUnknownSize(
static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr,
Value *DstAddr, Value *CopyLen, Align SrcAlign,
Align DstAlign, bool SrcIsVolatile,
- bool DstIsVolatile) {
+ bool DstIsVolatile,
+ const TargetTransformInfo &TTI) {
Type *TypeOfCopyLen = CopyLen->getType();
BasicBlock *OrigBB = InsertBefore->getParent();
Function *F = OrigBB->getParent();
const DataLayout &DL = F->getParent()->getDataLayout();
-
// TODO: Use different element type if possible?
- IRBuilder<> CastBuilder(InsertBefore);
- Type *EltTy = CastBuilder.getInt8Ty();
- Type *PtrTy =
- CastBuilder.getInt8PtrTy(SrcAddr->getType()->getPointerAddressSpace());
- SrcAddr = CastBuilder.CreateBitCast(SrcAddr, PtrTy);
- DstAddr = CastBuilder.CreateBitCast(DstAddr, PtrTy);
+ Type *EltTy = Type::getInt8Ty(F->getContext());
  // Create a comparison of src and dst, based on which we jump to either
// the forward-copy part of the function (if src >= dst) or the backwards-copy
@@ -428,6 +426,7 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr,
BasicBlock *LoopBB =
BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB);
IRBuilder<> LoopBuilder(LoopBB);
+
PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
Value *IndexPtr = LoopBuilder.CreateSub(
LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
@@ -552,15 +551,57 @@ void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy,
}
}
-void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) {
- createMemMoveLoop(/* InsertBefore */ Memmove,
- /* SrcAddr */ Memmove->getRawSource(),
- /* DstAddr */ Memmove->getRawDest(),
- /* CopyLen */ Memmove->getLength(),
- /* SrcAlign */ Memmove->getSourceAlign().valueOrOne(),
- /* DestAlign */ Memmove->getDestAlign().valueOrOne(),
- /* SrcIsVolatile */ Memmove->isVolatile(),
- /* DstIsVolatile */ Memmove->isVolatile());
+bool llvm::expandMemMoveAsLoop(MemMoveInst *Memmove,
+ const TargetTransformInfo &TTI) {
+ Value *CopyLen = Memmove->getLength();
+ Value *SrcAddr = Memmove->getRawSource();
+ Value *DstAddr = Memmove->getRawDest();
+ Align SrcAlign = Memmove->getSourceAlign().valueOrOne();
+ Align DstAlign = Memmove->getDestAlign().valueOrOne();
+ bool SrcIsVolatile = Memmove->isVolatile();
+ bool DstIsVolatile = SrcIsVolatile;
+ IRBuilder<> CastBuilder(Memmove);
+
+ unsigned SrcAS = SrcAddr->getType()->getPointerAddressSpace();
+ unsigned DstAS = DstAddr->getType()->getPointerAddressSpace();
+ if (SrcAS != DstAS) {
+ if (!TTI.addrspacesMayAlias(SrcAS, DstAS)) {
+ // We may not be able to emit a pointer comparison, but we don't have
+ // to. Expand as memcpy.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(CopyLen)) {
+ createMemCpyLoopKnownSize(/*InsertBefore=*/Memmove, SrcAddr, DstAddr,
+ CI, SrcAlign, DstAlign, SrcIsVolatile,
+ DstIsVolatile,
+ /*CanOverlap=*/false, TTI);
+ } else {
+ createMemCpyLoopUnknownSize(/*InsertBefore=*/Memmove, SrcAddr, DstAddr,
+ CopyLen, SrcAlign, DstAlign, SrcIsVolatile,
+ DstIsVolatile,
+ /*CanOverlap=*/false, TTI);
+ }
+
+ return true;
+ }
+
+ if (TTI.isValidAddrSpaceCast(DstAS, SrcAS))
+ DstAddr = CastBuilder.CreateAddrSpaceCast(DstAddr, SrcAddr->getType());
+ else if (TTI.isValidAddrSpaceCast(SrcAS, DstAS))
+ SrcAddr = CastBuilder.CreateAddrSpaceCast(SrcAddr, DstAddr->getType());
+ else {
+ // We don't know generically if it's legal to introduce an
+ // addrspacecast. We need to know either if it's legal to insert an
+ // addrspacecast, or if the address spaces cannot alias.
+ LLVM_DEBUG(
+ dbgs() << "Do not know how to expand memmove between different "
+ "address spaces\n");
+ return false;
+ }
+ }
+
+ createMemMoveLoop(
+ /*InsertBefore=*/Memmove, SrcAddr, DstAddr, CopyLen, SrcAlign, DstAlign,
+ SrcIsVolatile, DstIsVolatile, TTI);
+ return true;
}
void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
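
expandMemMoveAsLoop now returns bool because the cross-address-space cases can fail. Its decision order, condensed into plain C++ with booleans standing in for the TTI queries (addrspacesMayAlias, isValidAddrSpaceCast) - a sketch of the control flow only, not the real API:

#include <cstdio>

enum class Lowering { CompareAndLoop, ExpandAsMemcpy, GiveUp };

// Same address space: the usual src/dst comparison with forward/backward copy
// loops. Different address spaces that provably cannot alias: a plain
// memcpy-style loop is enough. Otherwise one pointer must be addrspacecast;
// if neither direction is valid, give up (the caller sees `false`).
Lowering pick(bool SameAS, bool MayAlias, bool CanCastDstToSrc,
              bool CanCastSrcToDst) {
  if (SameAS)
    return Lowering::CompareAndLoop;
  if (!MayAlias)
    return Lowering::ExpandAsMemcpy;
  if (CanCastDstToSrc || CanCastSrcToDst)
    return Lowering::CompareAndLoop;  // compare after casting one pointer
  return Lowering::GiveUp;
}

int main() {
  std::printf("%d %d %d\n", (int)pick(true, true, false, false),
              (int)pick(false, false, false, false),
              (int)pick(false, true, false, false));
}
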
diff --git a/llvm/lib/Transforms/Utils/Mem2Reg.cpp b/llvm/lib/Transforms/Utils/Mem2Reg.cpp
index 5ad7aeb463ec..fbc6dd7613de 100644
--- a/llvm/lib/Transforms/Utils/Mem2Reg.cpp
+++ b/llvm/lib/Transforms/Utils/Mem2Reg.cpp
@@ -74,15 +74,19 @@ namespace {
struct PromoteLegacyPass : public FunctionPass {
// Pass identification, replacement for typeid
static char ID;
+ bool ForcePass; /// If true, forces pass to execute, instead of skipping.
- PromoteLegacyPass() : FunctionPass(ID) {
+ PromoteLegacyPass() : FunctionPass(ID), ForcePass(false) {
+ initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+ PromoteLegacyPass(bool IsForced) : FunctionPass(ID), ForcePass(IsForced) {
initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry());
}
// runOnFunction - To run this pass, first we calculate the alloca
// instructions that are safe for promotion, then we promote each one.
bool runOnFunction(Function &F) override {
- if (skipFunction(F))
+ if (!ForcePass && skipFunction(F))
return false;
DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
@@ -111,6 +115,6 @@ INITIALIZE_PASS_END(PromoteLegacyPass, "mem2reg", "Promote Memory to Register",
false, false)
// createPromoteMemoryToRegister - Provide an entry point to create this pass.
-FunctionPass *llvm::createPromoteMemoryToRegisterPass() {
- return new PromoteLegacyPass();
+FunctionPass *llvm::createPromoteMemoryToRegisterPass(bool IsForced) {
+ return new PromoteLegacyPass(IsForced);
}
diff --git a/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp b/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp
index 899928c085c6..531b0a624daf 100644
--- a/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp
+++ b/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/MemoryOpRemark.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DebugInfo.h"
@@ -321,7 +322,7 @@ void MemoryOpRemark::visitVariable(const Value *V,
// Try to get an llvm.dbg.declare, which has a DILocalVariable giving us the
// real debug info name and size of the variable.
for (const DbgVariableIntrinsic *DVI :
- FindDbgAddrUses(const_cast<Value *>(V))) {
+ FindDbgDeclareUses(const_cast<Value *>(V))) {
if (DILocalVariable *DILV = DVI->getVariable()) {
std::optional<uint64_t> DISize = getSizeInBytes(DILV->getSizeInBits());
VariableInfo Var{DILV->getName(), DISize};
@@ -387,7 +388,8 @@ bool AutoInitRemark::canHandle(const Instruction *I) {
return false;
return any_of(I->getMetadata(LLVMContext::MD_annotation)->operands(),
[](const MDOperand &Op) {
- return cast<MDString>(Op.get())->getString() == "auto-init";
+ return isa<MDString>(Op.get()) &&
+ cast<MDString>(Op.get())->getString() == "auto-init";
});
}
diff --git a/llvm/lib/Transforms/Utils/MetaRenamer.cpp b/llvm/lib/Transforms/Utils/MetaRenamer.cpp
index 0ea210671b93..44ac65f265f0 100644
--- a/llvm/lib/Transforms/Utils/MetaRenamer.cpp
+++ b/llvm/lib/Transforms/Utils/MetaRenamer.cpp
@@ -26,14 +26,12 @@
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/TypeFinder.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Transforms/Utils.h"
using namespace llvm;
@@ -62,6 +60,11 @@ static cl::opt<std::string> RenameExcludeStructPrefixes(
"by a comma"),
cl::Hidden);
+static cl::opt<bool>
+ RenameOnlyInst("rename-only-inst", cl::init(false),
+ cl::desc("only rename the instructions in the function"),
+ cl::Hidden);
+
static const char *const metaNames[] = {
// See http://en.wikipedia.org/wiki/Metasyntactic_variable
"foo", "bar", "baz", "quux", "barney", "snork", "zot", "blam", "hoge",
@@ -105,6 +108,12 @@ parseExcludedPrefixes(StringRef PrefixesStr,
}
}
+void MetaRenameOnlyInstructions(Function &F) {
+ for (auto &I : instructions(F))
+ if (!I.getType()->isVoidTy() && I.getName().empty())
+ I.setName(I.getOpcodeName());
+}
+
void MetaRename(Function &F) {
for (Argument &Arg : F.args())
if (!Arg.getType()->isVoidTy())
@@ -115,7 +124,7 @@ void MetaRename(Function &F) {
for (auto &I : BB)
if (!I.getType()->isVoidTy())
- I.setName("tmp");
+ I.setName(I.getOpcodeName());
}
}
@@ -145,6 +154,26 @@ void MetaRename(Module &M,
[&Name](auto &Prefix) { return Name.startswith(Prefix); });
};
+ // Leave library functions alone because their presence or absence could
+ // affect the behavior of other passes.
+ auto ExcludeLibFuncs = [&](Function &F) {
+ LibFunc Tmp;
+ StringRef Name = F.getName();
+ return Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) ||
+ GetTLI(F).getLibFunc(F, Tmp) ||
+ IsNameExcluded(Name, ExcludedFuncPrefixes);
+ };
+
+ if (RenameOnlyInst) {
+ // Rename only the instructions in each non-excluded function; leave
+ // symbol names alone.
+ for (auto &F : M) {
+ if (ExcludeLibFuncs(F))
+ continue;
+ MetaRenameOnlyInstructions(F);
+ }
+ return;
+ }
+
// Rename all aliases
for (GlobalAlias &GA : M.aliases()) {
StringRef Name = GA.getName();
@@ -181,64 +210,20 @@ void MetaRename(Module &M,
// Rename all functions
for (auto &F : M) {
- StringRef Name = F.getName();
- LibFunc Tmp;
- // Leave library functions alone because their presence or absence could
- // affect the behavior of other passes.
- if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) ||
- GetTLI(F).getLibFunc(F, Tmp) ||
- IsNameExcluded(Name, ExcludedFuncPrefixes))
+ if (ExcludeLibFuncs(F))
continue;
// Leave @main alone. The output of -metarenamer might be passed to
// lli for execution and the latter needs a main entry point.
- if (Name != "main")
+ if (F.getName() != "main")
F.setName(renamer.newName());
MetaRename(F);
}
}
-struct MetaRenamer : public ModulePass {
- // Pass identification, replacement for typeid
- static char ID;
-
- MetaRenamer() : ModulePass(ID) {
- initializeMetaRenamerPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.setPreservesAll();
- }
-
- bool runOnModule(Module &M) override {
- auto GetTLI = [this](Function &F) -> TargetLibraryInfo & {
- return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- };
- MetaRename(M, GetTLI);
- return true;
- }
-};
-
} // end anonymous namespace
-char MetaRenamer::ID = 0;
-
-INITIALIZE_PASS_BEGIN(MetaRenamer, "metarenamer",
- "Assign new names to everything", false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(MetaRenamer, "metarenamer",
- "Assign new names to everything", false, false)
-
-//===----------------------------------------------------------------------===//
-//
-// MetaRenamer - Rename everything with metasyntactic names.
-//
-ModulePass *llvm::createMetaRenamerPass() {
- return new MetaRenamer();
-}
-
PreservedAnalyses MetaRenamerPass::run(Module &M, ModuleAnalysisManager &AM) {
FunctionAnalysisManager &FAM =
AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
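
The new rename-only-inst mode leaves symbol names alone and only gives unnamed instructions their opcode name, skipping excluded functions. A rough standalone analogue of that flow with toy data structures (ToyFunc, ToyInst, and isExcluded are made up for illustration; the real pass also consults TargetLibraryInfo and the excluded-prefix options):

#include <iostream>
#include <string>
#include <vector>

struct ToyInst { std::string Name, Opcode; };
struct ToyFunc { std::string Name; std::vector<ToyInst> Insts; };

// Simplified stand-in for the "llvm." / excluded-prefix / libfunc filter.
static bool isExcluded(const std::string &Name) {
  return Name.rfind("llvm.", 0) == 0;
}

// Keep every symbol name; only unnamed instructions get their opcode name.
static void renameOnlyInstructions(std::vector<ToyFunc> &Funcs) {
  for (ToyFunc &F : Funcs) {
    if (isExcluded(F.Name))
      continue;
    for (ToyInst &I : F.Insts)
      if (I.Name.empty())
        I.Name = I.Opcode;
  }
}

int main() {
  std::vector<ToyFunc> Funcs{{"foo", {{"", "add"}, {"x", "mul"}}},
                             {"llvm.memcpy", {{"", "store"}}}};
  renameOnlyInstructions(Funcs);
  std::cout << Funcs[0].Insts[0].Name << ' '   // add
            << Funcs[0].Insts[1].Name << ' '   // x (already named)
            << Funcs[1].Insts[0].Name << '\n'; // still empty (excluded)
}
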
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index 6d17a466957e..1e243ef74df7 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -12,6 +12,7 @@
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -19,6 +20,7 @@
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/xxhash.h"
+
using namespace llvm;
#define DEBUG_TYPE "moduleutils"
@@ -31,11 +33,9 @@ static void appendToGlobalArray(StringRef ArrayName, Module &M, Function *F,
// Get the current set of static global constructors and add the new ctor
// to the list.
SmallVector<Constant *, 16> CurrentCtors;
- StructType *EltTy = StructType::get(
- IRB.getInt32Ty(), PointerType::get(FnTy, F->getAddressSpace()),
- IRB.getInt8PtrTy());
-
+ StructType *EltTy;
if (GlobalVariable *GVCtor = M.getNamedGlobal(ArrayName)) {
+ EltTy = cast<StructType>(GVCtor->getValueType()->getArrayElementType());
if (Constant *Init = GVCtor->getInitializer()) {
unsigned n = Init->getNumOperands();
CurrentCtors.reserve(n + 1);
@@ -43,6 +43,10 @@ static void appendToGlobalArray(StringRef ArrayName, Module &M, Function *F,
CurrentCtors.push_back(cast<Constant>(Init->getOperand(i)));
}
GVCtor->eraseFromParent();
+ } else {
+ EltTy = StructType::get(
+ IRB.getInt32Ty(), PointerType::get(FnTy, F->getAddressSpace()),
+ IRB.getInt8PtrTy());
}
// Build a 3 field global_ctor entry. We don't take a comdat key.
@@ -390,9 +394,7 @@ bool llvm::lowerGlobalIFuncUsersAsGlobalCtor(
const DataLayout &DL = M.getDataLayout();
PointerType *TableEntryTy =
- Ctx.supportsTypedPointers()
- ? PointerType::get(Type::getInt8Ty(Ctx), DL.getProgramAddressSpace())
- : PointerType::get(Ctx, DL.getProgramAddressSpace());
+ PointerType::get(Ctx, DL.getProgramAddressSpace());
ArrayType *FuncPtrTableTy =
ArrayType::get(TableEntryTy, IFuncsToLower.size());
@@ -462,9 +464,7 @@ bool llvm::lowerGlobalIFuncUsersAsGlobalCtor(
InitBuilder.CreateRetVoid();
- PointerType *ConstantDataTy = Ctx.supportsTypedPointers()
- ? PointerType::get(Type::getInt8Ty(Ctx), 0)
- : PointerType::get(Ctx, 0);
+ PointerType *ConstantDataTy = PointerType::get(Ctx, 0);
// TODO: Is this the right priority? Probably should be before any other
// constructors?
diff --git a/llvm/lib/Transforms/Utils/MoveAutoInit.cpp b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp
new file mode 100644
index 000000000000..b0ca0b15c08e
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp
@@ -0,0 +1,231 @@
+//===-- MoveAutoInit.cpp - Move auto-init insts closer to their use sites -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass moves instructions marked as auto-init closer to the basic blocks
+// that use them, eventually removing them from some control paths of the
+// function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/MoveAutoInit.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "move-auto-init"
+
+STATISTIC(NumMoved, "Number of instructions moved");
+
+static cl::opt<unsigned> MoveAutoInitThreshold(
+ "move-auto-init-threshold", cl::Hidden, cl::init(128),
+ cl::desc("Maximum instructions to analyze per moved initialization"));
+
+static bool hasAutoInitMetadata(const Instruction &I) {
+ return I.hasMetadata(LLVMContext::MD_annotation) &&
+ any_of(I.getMetadata(LLVMContext::MD_annotation)->operands(),
+ [](const MDOperand &Op) { return Op.equalsStr("auto-init"); });
+}
+
+static std::optional<MemoryLocation> writeToAlloca(const Instruction &I) {
+ MemoryLocation ML;
+ if (auto *MI = dyn_cast<MemIntrinsic>(&I))
+ ML = MemoryLocation::getForDest(MI);
+ else if (auto *SI = dyn_cast<StoreInst>(&I))
+ ML = MemoryLocation::get(SI);
+ else
+ assert(false && "memory location set");
+
+ if (isa<AllocaInst>(getUnderlyingObject(ML.Ptr)))
+ return ML;
+ else
+ return {};
+}
+
+/// Finds a BasicBlock in the CFG where instruction `I` can be moved to while
+/// not changing the Memory SSA ordering and being guarded by at least one
+/// condition.
+static BasicBlock *usersDominator(const MemoryLocation &ML, Instruction *I,
+ DominatorTree &DT, MemorySSA &MSSA) {
+ BasicBlock *CurrentDominator = nullptr;
+ MemoryUseOrDef &IMA = *MSSA.getMemoryAccess(I);
+ BatchAAResults AA(MSSA.getAA());
+
+ SmallPtrSet<MemoryAccess *, 8> Visited;
+
+ auto AsMemoryAccess = [](User *U) { return cast<MemoryAccess>(U); };
+ SmallVector<MemoryAccess *> WorkList(map_range(IMA.users(), AsMemoryAccess));
+
+ while (!WorkList.empty()) {
+ MemoryAccess *MA = WorkList.pop_back_val();
+ if (!Visited.insert(MA).second)
+ continue;
+
+ if (Visited.size() > MoveAutoInitThreshold)
+ return nullptr;
+
+ bool FoundClobberingUser = false;
+ if (auto *M = dyn_cast<MemoryUseOrDef>(MA)) {
+ Instruction *MI = M->getMemoryInst();
+
+ // If this memory instruction may not clobber `I`, we can skip it.
+ // LifetimeEnd is a valid user, but we do not want it in the user
+ // dominator.
+ if (AA.getModRefInfo(MI, ML) != ModRefInfo::NoModRef &&
+ !MI->isLifetimeStartOrEnd() && MI != I) {
+ FoundClobberingUser = true;
+ CurrentDominator = CurrentDominator
+ ? DT.findNearestCommonDominator(CurrentDominator,
+ MI->getParent())
+ : MI->getParent();
+ }
+ }
+ if (!FoundClobberingUser) {
+ auto UsersAsMemoryAccesses = map_range(MA->users(), AsMemoryAccess);
+ append_range(WorkList, UsersAsMemoryAccesses);
+ }
+ }
+ return CurrentDominator;
+}
+
+static bool runMoveAutoInit(Function &F, DominatorTree &DT, MemorySSA &MSSA) {
+ BasicBlock &EntryBB = F.getEntryBlock();
+ SmallVector<std::pair<Instruction *, BasicBlock *>> JobList;
+
+ //
+ // Compute movable instructions.
+ //
+ for (Instruction &I : EntryBB) {
+ if (!hasAutoInitMetadata(I))
+ continue;
+
+ std::optional<MemoryLocation> ML = writeToAlloca(I);
+ if (!ML)
+ continue;
+
+ if (I.isVolatile())
+ continue;
+
+ BasicBlock *UsersDominator = usersDominator(ML.value(), &I, DT, MSSA);
+ if (!UsersDominator)
+ continue;
+
+ if (UsersDominator == &EntryBB)
+ continue;
+
+ // Traverse the CFG to detect cycles `UsersDominator` would be part of.
+ SmallPtrSet<BasicBlock *, 8> TransitiveSuccessors;
+ SmallVector<BasicBlock *> WorkList(successors(UsersDominator));
+ bool HasCycle = false;
+ while (!WorkList.empty()) {
+ BasicBlock *CurrBB = WorkList.pop_back_val();
+ if (CurrBB == UsersDominator)
+ // No early exit because we want to compute the full set of transitive
+ // successors.
+ HasCycle = true;
+ for (BasicBlock *Successor : successors(CurrBB)) {
+ if (!TransitiveSuccessors.insert(Successor).second)
+ continue;
+ WorkList.push_back(Successor);
+ }
+ }
+
+ // Don't insert if that could create multiple executions of I; instead,
+ // insert into the non-back-edge predecessor, if one exists.
+ if (HasCycle) {
+ BasicBlock *UsersDominatorHead = UsersDominator;
+ while (BasicBlock *UniquePredecessor =
+ UsersDominatorHead->getUniquePredecessor())
+ UsersDominatorHead = UniquePredecessor;
+
+ if (UsersDominatorHead == &EntryBB)
+ continue;
+
+ BasicBlock *DominatingPredecessor = nullptr;
+ for (BasicBlock *Pred : predecessors(UsersDominatorHead)) {
+ // If one of the predecessors of the dominator is also transitively a
+ // successor, moving to the dominator would be the inverse of loop
+ // hoisting, and we don't want that.
+ if (TransitiveSuccessors.count(Pred))
+ continue;
+
+ DominatingPredecessor =
+ DominatingPredecessor
+ ? DT.findNearestCommonDominator(DominatingPredecessor, Pred)
+ : Pred;
+ }
+
+ if (!DominatingPredecessor || DominatingPredecessor == &EntryBB)
+ continue;
+
+ UsersDominator = DominatingPredecessor;
+ }
+
+ // CatchSwitchInst blocks can only have one instruction, so they are not
+ // good candidates for insertion.
+ while (isa<CatchSwitchInst>(UsersDominator->getFirstInsertionPt())) {
+ for (BasicBlock *Pred : predecessors(UsersDominator))
+ UsersDominator = DT.findNearestCommonDominator(UsersDominator, Pred);
+ }
+
+ // We finally found a place where I can be moved while not introducing extra
+ // execution, and guarded by at least one condition.
+ if (UsersDominator != &EntryBB)
+ JobList.emplace_back(&I, UsersDominator);
+ }
+
+ //
+ // Perform the actual substitution.
+ //
+ if (JobList.empty())
+ return false;
+
+ MemorySSAUpdater MSSAU(&MSSA);
+
+ // Reverse insertion to respect relative order between instructions:
+ // if two instructions are moved from the same BB to the same BB, we insert
+ // the second one in the front, then the first on top of it.
+ for (auto &Job : reverse(JobList)) {
+ Job.first->moveBefore(&*Job.second->getFirstInsertionPt());
+ MSSAU.moveToPlace(MSSA.getMemoryAccess(Job.first), Job.first->getParent(),
+ MemorySSA::InsertionPlace::Beginning);
+ }
+
+ if (VerifyMemorySSA)
+ MSSA.verifyMemorySSA();
+
+ NumMoved += JobList.size();
+
+ return true;
+}
+
+PreservedAnalyses MoveAutoInitPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+ if (!runMoveAutoInit(F, DT, MSSA))
+ return PreservedAnalyses::all();
+
+ PreservedAnalyses PA;
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<MemorySSAAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
+}
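
The final loop of runMoveAutoInit iterates JobList in reverse and always inserts at the front of the destination block, so two instructions moved into the same block keep their original relative order. A small standalone sketch of that ordering argument with plain containers (no MemorySSA involved):

#include <iostream>
#include <list>
#include <string>
#include <vector>

int main() {
  // Two "instructions" taken from the entry block, in program order.
  std::vector<std::string> Jobs{"init.a", "init.b"};
  // The destination "block", which already holds one instruction.
  std::list<std::string> Dest{"use"};

  // Walk the jobs in reverse and always insert at the front: init.b goes in
  // first, then init.a lands on top of it, so the block ends up as
  // init.a, init.b, use - the original relative order is preserved.
  for (auto It = Jobs.rbegin(); It != Jobs.rend(); ++It)
    Dest.push_front(*It);

  for (const std::string &I : Dest)
    std::cout << I << '\n';
}
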
diff --git a/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp b/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp
index d4ab4504064f..f41a14cdfbec 100644
--- a/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp
+++ b/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp
@@ -14,8 +14,6 @@
#include "llvm/Transforms/Utils/NameAnonGlobals.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/MD5.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
index 75ea9dc5dfc0..2e5f40d39912 100644
--- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
+++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
@@ -118,19 +118,28 @@ public:
/// Update assignment tracking debug info given for the to-be-deleted store
/// \p ToDelete that stores to this alloca.
- void updateForDeletedStore(StoreInst *ToDelete, DIBuilder &DIB) const {
+ void updateForDeletedStore(
+ StoreInst *ToDelete, DIBuilder &DIB,
+ SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete) const {
// There's nothing to do if the alloca doesn't have any variables using
// assignment tracking.
- if (DbgAssigns.empty()) {
- assert(at::getAssignmentMarkers(ToDelete).empty());
+ if (DbgAssigns.empty())
return;
- }
- // Just leave dbg.assign intrinsics in place and remember that we've seen
- // one for each variable fragment.
- SmallSet<DebugVariable, 2> VarHasDbgAssignForStore;
- for (DbgAssignIntrinsic *DAI : at::getAssignmentMarkers(ToDelete))
- VarHasDbgAssignForStore.insert(DebugVariable(DAI));
+ // Insert a dbg.value where the linked dbg.assign is and remember to delete
+ // the dbg.assign later. Demoting to dbg.value isn't necessary for
+ // correctness but does reduce compile time and memory usage by reducing
+ // unnecessary function-local metadata. Remember that we've seen a
+ // dbg.assign for each variable fragment for the untracked store handling
+ // (after this loop).
+ SmallSet<DebugVariableAggregate, 2> VarHasDbgAssignForStore;
+ for (DbgAssignIntrinsic *DAI : at::getAssignmentMarkers(ToDelete)) {
+ VarHasDbgAssignForStore.insert(DebugVariableAggregate(DAI));
+ DbgAssignsToDelete->insert(DAI);
+ DIB.insertDbgValueIntrinsic(DAI->getValue(), DAI->getVariable(),
+ DAI->getExpression(), DAI->getDebugLoc(),
+ DAI);
+ }
// It's possible for variables using assignment tracking to have no
// dbg.assign linked to this store. These are variables in DbgAssigns that
@@ -141,7 +150,7 @@ public:
// size) or one that is trackable but has had its DIAssignID attachment
// dropped accidentally.
for (auto *DAI : DbgAssigns) {
- if (VarHasDbgAssignForStore.contains(DebugVariable(DAI)))
+ if (VarHasDbgAssignForStore.contains(DebugVariableAggregate(DAI)))
continue;
ConvertDebugDeclareToDebugValue(DAI, ToDelete, DIB);
}
@@ -324,6 +333,9 @@ struct PromoteMem2Reg {
/// For each alloca, keep an instance of a helper class that gives us an easy
/// way to update assignment tracking debug info if the alloca is promoted.
SmallVector<AssignmentTrackingInfo, 8> AllocaATInfo;
+ /// A set of dbg.assigns to delete because they've been demoted to
+ /// dbg.values. Call cleanUpDbgAssigns to delete them.
+ SmallSet<DbgAssignIntrinsic *, 8> DbgAssignsToDelete;
/// The set of basic blocks the renamer has already visited.
SmallPtrSet<BasicBlock *, 16> Visited;
@@ -367,6 +379,13 @@ private:
RenamePassData::LocationVector &IncLocs,
std::vector<RenamePassData> &Worklist);
bool QueuePhiNode(BasicBlock *BB, unsigned AllocaIdx, unsigned &Version);
+
+ /// Delete dbg.assigns that have been demoted to dbg.values.
+ void cleanUpDbgAssigns() {
+ for (auto *DAI : DbgAssignsToDelete)
+ DAI->eraseFromParent();
+ DbgAssignsToDelete.clear();
+ }
};
} // end anonymous namespace
@@ -438,9 +457,10 @@ static void removeIntrinsicUsers(AllocaInst *AI) {
/// false there were some loads which were not dominated by the single store
/// and thus must be phi-ed with undef. We fall back to the standard alloca
/// promotion algorithm in that case.
-static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,
- LargeBlockInfo &LBI, const DataLayout &DL,
- DominatorTree &DT, AssumptionCache *AC) {
+static bool rewriteSingleStoreAlloca(
+ AllocaInst *AI, AllocaInfo &Info, LargeBlockInfo &LBI, const DataLayout &DL,
+ DominatorTree &DT, AssumptionCache *AC,
+ SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete) {
StoreInst *OnlyStore = Info.OnlyStore;
bool StoringGlobalVal = !isa<Instruction>(OnlyStore->getOperand(0));
BasicBlock *StoreBB = OnlyStore->getParent();
@@ -500,7 +520,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,
DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false);
// Update assignment tracking info for the store we're going to delete.
- Info.AssignmentTracking.updateForDeletedStore(Info.OnlyStore, DIB);
+ Info.AssignmentTracking.updateForDeletedStore(Info.OnlyStore, DIB,
+ DbgAssignsToDelete);
// Record debuginfo for the store and remove the declaration's
// debuginfo.
@@ -540,11 +561,10 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info,
/// use(t);
/// *A = 42;
/// }
-static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,
- LargeBlockInfo &LBI,
- const DataLayout &DL,
- DominatorTree &DT,
- AssumptionCache *AC) {
+static bool promoteSingleBlockAlloca(
+ AllocaInst *AI, const AllocaInfo &Info, LargeBlockInfo &LBI,
+ const DataLayout &DL, DominatorTree &DT, AssumptionCache *AC,
+ SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete) {
// The trickiest case to handle is when we have large blocks. Because of this,
// this code is optimized assuming that large blocks happen. This does not
// significantly pessimize the small block case. This uses LargeBlockInfo to
@@ -608,7 +628,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info,
while (!AI->use_empty()) {
StoreInst *SI = cast<StoreInst>(AI->user_back());
// Update assignment tracking info for the store we're going to delete.
- Info.AssignmentTracking.updateForDeletedStore(SI, DIB);
+ Info.AssignmentTracking.updateForDeletedStore(SI, DIB, DbgAssignsToDelete);
// Record debuginfo for the store before removing it.
for (DbgVariableIntrinsic *DII : Info.DbgUsers) {
if (DII->isAddressOfVariable()) {
@@ -668,7 +688,8 @@ void PromoteMem2Reg::run() {
// If there is only a single store to this value, replace any loads of
// it that are directly dominated by the definition with the value stored.
if (Info.DefiningBlocks.size() == 1) {
- if (rewriteSingleStoreAlloca(AI, Info, LBI, SQ.DL, DT, AC)) {
+ if (rewriteSingleStoreAlloca(AI, Info, LBI, SQ.DL, DT, AC,
+ &DbgAssignsToDelete)) {
// The alloca has been processed, move on.
RemoveFromAllocasList(AllocaNum);
++NumSingleStore;
@@ -679,7 +700,8 @@ void PromoteMem2Reg::run() {
// If the alloca is only read and written in one basic block, just perform a
// linear sweep over the block to eliminate it.
if (Info.OnlyUsedInOneBlock &&
- promoteSingleBlockAlloca(AI, Info, LBI, SQ.DL, DT, AC)) {
+ promoteSingleBlockAlloca(AI, Info, LBI, SQ.DL, DT, AC,
+ &DbgAssignsToDelete)) {
// The alloca has been processed, move on.
RemoveFromAllocasList(AllocaNum);
continue;
@@ -728,9 +750,10 @@ void PromoteMem2Reg::run() {
QueuePhiNode(BB, AllocaNum, CurrentVersion);
}
- if (Allocas.empty())
+ if (Allocas.empty()) {
+ cleanUpDbgAssigns();
return; // All of the allocas must have been trivial!
-
+ }
LBI.clear();
// Set the incoming values for the basic block to be null values for all of
@@ -812,7 +835,7 @@ void PromoteMem2Reg::run() {
// code. Unfortunately, there may be unreachable blocks which the renamer
// hasn't traversed. If this is the case, the PHI nodes may not
// have incoming values for all predecessors. Loop over all PHI nodes we have
- // created, inserting undef values if they are missing any incoming values.
+ // created, inserting poison values if they are missing any incoming values.
for (DenseMap<std::pair<unsigned, unsigned>, PHINode *>::iterator
I = NewPhiNodes.begin(),
E = NewPhiNodes.end();
@@ -862,13 +885,14 @@ void PromoteMem2Reg::run() {
BasicBlock::iterator BBI = BB->begin();
while ((SomePHI = dyn_cast<PHINode>(BBI++)) &&
SomePHI->getNumIncomingValues() == NumBadPreds) {
- Value *UndefVal = UndefValue::get(SomePHI->getType());
+ Value *PoisonVal = PoisonValue::get(SomePHI->getType());
for (BasicBlock *Pred : Preds)
- SomePHI->addIncoming(UndefVal, Pred);
+ SomePHI->addIncoming(PoisonVal, Pred);
}
}
NewPhiNodes.clear();
+ cleanUpDbgAssigns();
}
/// Determine which blocks the value is live in.
@@ -1072,7 +1096,8 @@ NextIteration:
// Record debuginfo for the store before removing it.
IncomingLocs[AllocaNo] = SI->getDebugLoc();
- AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB);
+ AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB,
+ &DbgAssignsToDelete);
for (DbgVariableIntrinsic *DII : AllocaDbgUsers[ai->second])
if (DII->isAddressOfVariable())
ConvertDebugDeclareToDebugValue(DII, SI, DIB);
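
updateForDeletedStore now only records the demoted dbg.assign intrinsics in DbgAssignsToDelete, and cleanUpDbgAssigns erases them once promotion is finished, rather than deleting while the users are still being walked. A standalone sketch of that collect-then-erase pattern (plain containers and illustrative names, not the mem2reg code itself):

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> Intrinsics{"dbg.assign.1", "dbg.value.1",
                                      "dbg.assign.2"};
  std::set<std::string> ToDelete;

  // Phase 1: the traversal only records what has to go away.
  for (const std::string &I : Intrinsics)
    if (I.rfind("dbg.assign", 0) == 0)
      ToDelete.insert(I);

  // Phase 2: erase after the traversal, so nothing is invalidated mid-walk.
  Intrinsics.erase(std::remove_if(Intrinsics.begin(), Intrinsics.end(),
                                  [&](const std::string &I) {
                                    return ToDelete.count(I) != 0;
                                  }),
                   Intrinsics.end());

  for (const std::string &I : Intrinsics)
    std::cout << I << '\n'; // only dbg.value.1 remains
}
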
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index 8d03a0d8a2c4..de3626a24212 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -17,6 +17,7 @@
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@@ -41,6 +42,14 @@ static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() {
MaxNumRangeExtensions);
}
+static ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty,
+ bool UndefAllowed = true) {
+ assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector");
+ if (LV.isConstantRange(UndefAllowed))
+ return LV.getConstantRange();
+ return ConstantRange::getFull(Ty->getScalarSizeInBits());
+}
+
namespace llvm {
bool SCCPSolver::isConstant(const ValueLatticeElement &LV) {
@@ -65,30 +74,9 @@ static bool canRemoveInstruction(Instruction *I) {
}
bool SCCPSolver::tryToReplaceWithConstant(Value *V) {
- Constant *Const = nullptr;
- if (V->getType()->isStructTy()) {
- std::vector<ValueLatticeElement> IVs = getStructLatticeValueFor(V);
- if (llvm::any_of(IVs, isOverdefined))
- return false;
- std::vector<Constant *> ConstVals;
- auto *ST = cast<StructType>(V->getType());
- for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
- ValueLatticeElement V = IVs[i];
- ConstVals.push_back(SCCPSolver::isConstant(V)
- ? getConstant(V)
- : UndefValue::get(ST->getElementType(i)));
- }
- Const = ConstantStruct::get(ST, ConstVals);
- } else {
- const ValueLatticeElement &IV = getLatticeValueFor(V);
- if (isOverdefined(IV))
- return false;
-
- Const = SCCPSolver::isConstant(IV) ? getConstant(IV)
- : UndefValue::get(V->getType());
- }
- assert(Const && "Constant is nullptr here!");
-
+ Constant *Const = getConstantOrNull(V);
+ if (!Const)
+ return false;
// Replacing `musttail` instructions with constant breaks `musttail` invariant
// unless the call itself can be removed.
// Calls with "clang.arc.attachedcall" implicitly use the return value and
@@ -115,6 +103,47 @@ bool SCCPSolver::tryToReplaceWithConstant(Value *V) {
return true;
}
+/// Try to use \p Inst's operand value ranges from \p Solver to infer the NUW
+/// and NSW flags.
+static bool refineInstruction(SCCPSolver &Solver,
+ const SmallPtrSetImpl<Value *> &InsertedValues,
+ Instruction &Inst) {
+ if (!isa<OverflowingBinaryOperator>(Inst))
+ return false;
+
+ auto GetRange = [&Solver, &InsertedValues](Value *Op) {
+ if (auto *Const = dyn_cast<ConstantInt>(Op))
+ return ConstantRange(Const->getValue());
+ if (isa<Constant>(Op) || InsertedValues.contains(Op)) {
+ unsigned Bitwidth = Op->getType()->getScalarSizeInBits();
+ return ConstantRange::getFull(Bitwidth);
+ }
+ return getConstantRange(Solver.getLatticeValueFor(Op), Op->getType(),
+ /*UndefAllowed=*/false);
+ };
+ auto RangeA = GetRange(Inst.getOperand(0));
+ auto RangeB = GetRange(Inst.getOperand(1));
+ bool Changed = false;
+ if (!Inst.hasNoUnsignedWrap()) {
+ auto NUWRange = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::BinaryOps(Inst.getOpcode()), RangeB,
+ OverflowingBinaryOperator::NoUnsignedWrap);
+ if (NUWRange.contains(RangeA)) {
+ Inst.setHasNoUnsignedWrap();
+ Changed = true;
+ }
+ }
+ if (!Inst.hasNoSignedWrap()) {
+ auto NSWRange = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::BinaryOps(Inst.getOpcode()), RangeB,
+ OverflowingBinaryOperator::NoSignedWrap);
+ if (NSWRange.contains(RangeA)) {
+ Inst.setHasNoSignedWrap();
+ Changed = true;
+ }
+ }
+
+ return Changed;
+}
+
/// Try to replace signed instructions with their unsigned equivalent.
static bool replaceSignedInst(SCCPSolver &Solver,
SmallPtrSetImpl<Value *> &InsertedValues,
@@ -195,6 +224,8 @@ bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB,
} else if (replaceSignedInst(*this, InsertedValues, Inst)) {
MadeChanges = true;
++InstReplacedStat;
+ } else if (refineInstruction(*this, InsertedValues, Inst)) {
+ MadeChanges = true;
}
}
return MadeChanges;
@@ -322,6 +353,10 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
MapVector<std::pair<Function *, unsigned>, ValueLatticeElement>
TrackedMultipleRetVals;
+ /// The set of values whose lattice has been invalidated.
+ /// Populated by resetLatticeValueFor(), cleared after resolving undefs.
+ DenseSet<Value *> Invalidated;
+
/// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is
/// represented here for efficient lookup.
SmallPtrSet<Function *, 16> MRVFunctionsTracked;
@@ -352,14 +387,15 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
using Edge = std::pair<BasicBlock *, BasicBlock *>;
DenseSet<Edge> KnownFeasibleEdges;
- DenseMap<Function *, AnalysisResultsForFn> AnalysisResults;
+ DenseMap<Function *, std::unique_ptr<PredicateInfo>> FnPredicateInfo;
+
DenseMap<Value *, SmallPtrSet<User *, 2>> AdditionalUsers;
LLVMContext &Ctx;
private:
- ConstantInt *getConstantInt(const ValueLatticeElement &IV) const {
- return dyn_cast_or_null<ConstantInt>(getConstant(IV));
+ ConstantInt *getConstantInt(const ValueLatticeElement &IV, Type *Ty) const {
+ return dyn_cast_or_null<ConstantInt>(getConstant(IV, Ty));
}
// pushToWorkList - Helper for markConstant/markOverdefined
@@ -447,6 +483,64 @@ private:
return LV;
}
+ /// Traverse the def-use chain of \p Call, marking it and its users as
+ /// "unknown" on the way.
+ void invalidate(CallBase *Call) {
+ SmallVector<Instruction *, 64> ToInvalidate;
+ ToInvalidate.push_back(Call);
+
+ while (!ToInvalidate.empty()) {
+ Instruction *Inst = ToInvalidate.pop_back_val();
+
+ if (!Invalidated.insert(Inst).second)
+ continue;
+
+ if (!BBExecutable.count(Inst->getParent()))
+ continue;
+
+ Value *V = nullptr;
+ // For return instructions we need to invalidate the tracked returns map.
+ // Anything else has its lattice in the value map.
+ if (auto *RetInst = dyn_cast<ReturnInst>(Inst)) {
+ Function *F = RetInst->getParent()->getParent();
+ if (auto It = TrackedRetVals.find(F); It != TrackedRetVals.end()) {
+ It->second = ValueLatticeElement();
+ V = F;
+ } else if (MRVFunctionsTracked.count(F)) {
+ auto *STy = cast<StructType>(F->getReturnType());
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I)
+ TrackedMultipleRetVals[{F, I}] = ValueLatticeElement();
+ V = F;
+ }
+ } else if (auto *STy = dyn_cast<StructType>(Inst->getType())) {
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+ if (auto It = StructValueState.find({Inst, I});
+ It != StructValueState.end()) {
+ It->second = ValueLatticeElement();
+ V = Inst;
+ }
+ }
+ } else if (auto It = ValueState.find(Inst); It != ValueState.end()) {
+ It->second = ValueLatticeElement();
+ V = Inst;
+ }
+
+ if (V) {
+ LLVM_DEBUG(dbgs() << "Invalidated lattice for " << *V << "\n");
+
+ for (User *U : V->users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ ToInvalidate.push_back(UI);
+
+ auto It = AdditionalUsers.find(V);
+ if (It != AdditionalUsers.end())
+ for (User *U : It->second)
+ if (auto *UI = dyn_cast<Instruction>(U))
+ ToInvalidate.push_back(UI);
+ }
+ }
+ }
+
/// markEdgeExecutable - Mark a basic block as executable, adding it to the BB
/// work list if it is not already executable.
bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest);
@@ -520,6 +614,7 @@ private:
void visitCastInst(CastInst &I);
void visitSelectInst(SelectInst &I);
void visitUnaryOperator(Instruction &I);
+ void visitFreezeInst(FreezeInst &I);
void visitBinaryOperator(Instruction &I);
void visitCmpInst(CmpInst &I);
void visitExtractValueInst(ExtractValueInst &EVI);
@@ -557,8 +652,8 @@ private:
void visitInstruction(Instruction &I);
public:
- void addAnalysis(Function &F, AnalysisResultsForFn A) {
- AnalysisResults.insert({&F, std::move(A)});
+ void addPredicateInfo(Function &F, DominatorTree &DT, AssumptionCache &AC) {
+ FnPredicateInfo.insert({&F, std::make_unique<PredicateInfo>(F, DT, AC)});
}
void visitCallInst(CallInst &I) { visitCallBase(I); }
@@ -566,23 +661,10 @@ public:
bool markBlockExecutable(BasicBlock *BB);
const PredicateBase *getPredicateInfoFor(Instruction *I) {
- auto A = AnalysisResults.find(I->getParent()->getParent());
- if (A == AnalysisResults.end())
+ auto It = FnPredicateInfo.find(I->getParent()->getParent());
+ if (It == FnPredicateInfo.end())
return nullptr;
- return A->second.PredInfo->getPredicateInfoFor(I);
- }
-
- const LoopInfo &getLoopInfo(Function &F) {
- auto A = AnalysisResults.find(&F);
- assert(A != AnalysisResults.end() && A->second.LI &&
- "Need LoopInfo analysis results for function.");
- return *A->second.LI;
- }
-
- DomTreeUpdater getDTU(Function &F) {
- auto A = AnalysisResults.find(&F);
- assert(A != AnalysisResults.end() && "Need analysis results for function.");
- return {A->second.DT, A->second.PDT, DomTreeUpdater::UpdateStrategy::Lazy};
+ return It->second->getPredicateInfoFor(I);
}
SCCPInstVisitor(const DataLayout &DL,
@@ -627,6 +709,8 @@ public:
void solve();
+ bool resolvedUndef(Instruction &I);
+
bool resolvedUndefsIn(Function &F);
bool isBlockExecutable(BasicBlock *BB) const {
@@ -649,6 +733,19 @@ public:
void removeLatticeValueFor(Value *V) { ValueState.erase(V); }
+ /// Invalidate the lattice value of \p Call and its users after specializing
+ /// the call, then recompute it.
+ void resetLatticeValueFor(CallBase *Call) {
+ // Calls to void returning functions do not need invalidation.
+ Function *F = Call->getCalledFunction();
+ (void)F;
+ assert(!F->getReturnType()->isVoidTy() &&
+ (TrackedRetVals.count(F) || MRVFunctionsTracked.count(F)) &&
+ "All non void specializations should be tracked");
+ invalidate(Call);
+ handleCallResult(*Call);
+ }
+
const ValueLatticeElement &getLatticeValueFor(Value *V) const {
assert(!V->getType()->isStructTy() &&
"Should use getStructLatticeValueFor");
@@ -681,15 +778,16 @@ public:
bool isStructLatticeConstant(Function *F, StructType *STy);
- Constant *getConstant(const ValueLatticeElement &LV) const;
- ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty) const;
+ Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const;
+
+ Constant *getConstantOrNull(Value *V) const;
SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions() {
return TrackingIncomingArguments;
}
- void markArgInFuncSpecialization(Function *F,
- const SmallVectorImpl<ArgInfo> &Args);
+ void setLatticeValueForSpecializationArguments(Function *F,
+ const SmallVectorImpl<ArgInfo> &Args);
void markFunctionUnreachable(Function *F) {
for (auto &BB : *F)
@@ -715,6 +813,18 @@ public:
ResolvedUndefs |= resolvedUndefsIn(*F);
}
}
+
+ void solveWhileResolvedUndefs() {
+ bool ResolvedUndefs = true;
+ while (ResolvedUndefs) {
+ solve();
+ ResolvedUndefs = false;
+ for (Value *V : Invalidated)
+ if (auto *I = dyn_cast<Instruction>(V))
+ ResolvedUndefs |= resolvedUndef(*I);
+ }
+ Invalidated.clear();
+ }
};
} // namespace llvm
@@ -728,9 +838,13 @@ bool SCCPInstVisitor::markBlockExecutable(BasicBlock *BB) {
}
void SCCPInstVisitor::pushToWorkList(ValueLatticeElement &IV, Value *V) {
- if (IV.isOverdefined())
- return OverdefinedInstWorkList.push_back(V);
- InstWorkList.push_back(V);
+ if (IV.isOverdefined()) {
+ if (OverdefinedInstWorkList.empty() || OverdefinedInstWorkList.back() != V)
+ OverdefinedInstWorkList.push_back(V);
+ return;
+ }
+ if (InstWorkList.empty() || InstWorkList.back() != V)
+ InstWorkList.push_back(V);
}
void SCCPInstVisitor::pushToWorkListMsg(ValueLatticeElement &IV, Value *V) {
@@ -771,57 +885,84 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) {
return true;
}
-Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const {
- if (LV.isConstant())
- return LV.getConstant();
+Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV,
+ Type *Ty) const {
+ if (LV.isConstant()) {
+ Constant *C = LV.getConstant();
+ assert(C->getType() == Ty && "Type mismatch");
+ return C;
+ }
if (LV.isConstantRange()) {
const auto &CR = LV.getConstantRange();
if (CR.getSingleElement())
- return ConstantInt::get(Ctx, *CR.getSingleElement());
+ return ConstantInt::get(Ty, *CR.getSingleElement());
}
return nullptr;
}
-ConstantRange
-SCCPInstVisitor::getConstantRange(const ValueLatticeElement &LV,
- Type *Ty) const {
- assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector");
- if (LV.isConstantRange())
- return LV.getConstantRange();
- return ConstantRange::getFull(Ty->getScalarSizeInBits());
+Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
+ Constant *Const = nullptr;
+ if (V->getType()->isStructTy()) {
+ std::vector<ValueLatticeElement> LVs = getStructLatticeValueFor(V);
+ if (any_of(LVs, SCCPSolver::isOverdefined))
+ return nullptr;
+ std::vector<Constant *> ConstVals;
+ auto *ST = cast<StructType>(V->getType());
+ for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) {
+ ValueLatticeElement LV = LVs[I];
+ ConstVals.push_back(SCCPSolver::isConstant(LV)
+ ? getConstant(LV, ST->getElementType(I))
+ : UndefValue::get(ST->getElementType(I)));
+ }
+ Const = ConstantStruct::get(ST, ConstVals);
+ } else {
+ const ValueLatticeElement &LV = getLatticeValueFor(V);
+ if (SCCPSolver::isOverdefined(LV))
+ return nullptr;
+ Const = SCCPSolver::isConstant(LV) ? getConstant(LV, V->getType())
+ : UndefValue::get(V->getType());
+ }
+ assert(Const && "Constant is nullptr here!");
+ return Const;
}
-void SCCPInstVisitor::markArgInFuncSpecialization(
- Function *F, const SmallVectorImpl<ArgInfo> &Args) {
+void SCCPInstVisitor::setLatticeValueForSpecializationArguments(Function *F,
+ const SmallVectorImpl<ArgInfo> &Args) {
assert(!Args.empty() && "Specialization without arguments");
assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() &&
"Functions should have the same number of arguments");
auto Iter = Args.begin();
- Argument *NewArg = F->arg_begin();
- Argument *OldArg = Args[0].Formal->getParent()->arg_begin();
+ Function::arg_iterator NewArg = F->arg_begin();
+ Function::arg_iterator OldArg = Args[0].Formal->getParent()->arg_begin();
for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) {
LLVM_DEBUG(dbgs() << "SCCP: Marking argument "
<< NewArg->getNameOrAsOperand() << "\n");
- if (Iter != Args.end() && OldArg == Iter->Formal) {
- // Mark the argument constants in the new function.
- markConstant(NewArg, Iter->Actual);
+ // Mark the argument constants in the new function
+ // or copy the lattice state over from the old function.
+ if (Iter != Args.end() && Iter->Formal == &*OldArg) {
+ if (auto *STy = dyn_cast<StructType>(NewArg->getType())) {
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+ ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}];
+ NewValue.markConstant(Iter->Actual->getAggregateElement(I));
+ }
+ } else {
+ ValueState[&*NewArg].markConstant(Iter->Actual);
+ }
++Iter;
- } else if (ValueState.count(OldArg)) {
- // For the remaining arguments in the new function, copy the lattice state
- // over from the old function.
- //
- // Note: This previously looked like this:
- // ValueState[NewArg] = ValueState[OldArg];
- // This is incorrect because the DenseMap class may resize the underlying
- // memory when inserting `NewArg`, which will invalidate the reference to
- // `OldArg`. Instead, we make sure `NewArg` exists before setting it.
- auto &NewValue = ValueState[NewArg];
- NewValue = ValueState[OldArg];
- pushToWorkList(NewValue, NewArg);
+ } else {
+ if (auto *STy = dyn_cast<StructType>(NewArg->getType())) {
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+ ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}];
+ NewValue = StructValueState[{&*OldArg, I}];
+ }
+ } else {
+ ValueLatticeElement &NewValue = ValueState[&*NewArg];
+ NewValue = ValueState[&*OldArg];
+ }
}
}
}
@@ -874,7 +1015,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
}
ValueLatticeElement BCValue = getValueState(BI->getCondition());
- ConstantInt *CI = getConstantInt(BCValue);
+ ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType());
if (!CI) {
// Overdefined condition variables, and branches on unfoldable constant
// conditions, mean the branch could go either way.
@@ -900,7 +1041,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
return;
}
const ValueLatticeElement &SCValue = getValueState(SI->getCondition());
- if (ConstantInt *CI = getConstantInt(SCValue)) {
+ if (ConstantInt *CI =
+ getConstantInt(SCValue, SI->getCondition()->getType())) {
Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true;
return;
}
@@ -931,7 +1073,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) {
// Casts are folded by visitCastInst.
ValueLatticeElement IBRValue = getValueState(IBR->getAddress());
- BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue));
+ BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(
+ getConstant(IBRValue, IBR->getAddress()->getType()));
if (!Addr) { // Overdefined or unknown condition?
// All destinations are executable!
if (!IBRValue.isUnknownOrUndef())
@@ -1086,7 +1229,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) {
if (OpSt.isUnknownOrUndef())
return;
- if (Constant *OpC = getConstant(OpSt)) {
+ if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) {
// Fold the constant as we build.
Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL);
markConstant(&I, C);
@@ -1221,7 +1364,8 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) {
if (CondValue.isUnknownOrUndef())
return;
- if (ConstantInt *CondCB = getConstantInt(CondValue)) {
+ if (ConstantInt *CondCB =
+ getConstantInt(CondValue, I.getCondition()->getType())) {
Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue();
mergeInValue(&I, getValueState(OpVal));
return;
@@ -1254,13 +1398,37 @@ void SCCPInstVisitor::visitUnaryOperator(Instruction &I) {
return;
if (SCCPSolver::isConstant(V0State))
- if (Constant *C = ConstantFoldUnaryOpOperand(I.getOpcode(),
- getConstant(V0State), DL))
+ if (Constant *C = ConstantFoldUnaryOpOperand(
+ I.getOpcode(), getConstant(V0State, I.getType()), DL))
return (void)markConstant(IV, &I, C);
markOverdefined(&I);
}
+void SCCPInstVisitor::visitFreezeInst(FreezeInst &I) {
+ // If this freeze returns a struct, just mark the result overdefined.
+ // TODO: We could do a lot better than this.
+ if (I.getType()->isStructTy())
+ return (void)markOverdefined(&I);
+
+ ValueLatticeElement V0State = getValueState(I.getOperand(0));
+ ValueLatticeElement &IV = ValueState[&I];
+ // resolvedUndefsIn might mark I as overdefined. Bail out, even if we would
+ // discover a concrete value later.
+ if (SCCPSolver::isOverdefined(IV))
+ return (void)markOverdefined(&I);
+
+ // If something is unknown/undef, wait for it to resolve.
+ if (V0State.isUnknownOrUndef())
+ return;
+
+ if (SCCPSolver::isConstant(V0State) &&
+ isGuaranteedNotToBeUndefOrPoison(getConstant(V0State, I.getType())))
+ return (void)markConstant(IV, &I, getConstant(V0State, I.getType()));
+
+ markOverdefined(&I);
+}
+
// Handle Binary Operators.
void SCCPInstVisitor::visitBinaryOperator(Instruction &I) {
ValueLatticeElement V1State = getValueState(I.getOperand(0));
@@ -1280,10 +1448,12 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) {
// If either of the operands is a constant, try to fold it to a constant.
// TODO: Use information from notconstant better.
if ((V1State.isConstant() || V2State.isConstant())) {
- Value *V1 = SCCPSolver::isConstant(V1State) ? getConstant(V1State)
- : I.getOperand(0);
- Value *V2 = SCCPSolver::isConstant(V2State) ? getConstant(V2State)
- : I.getOperand(1);
+ Value *V1 = SCCPSolver::isConstant(V1State)
+ ? getConstant(V1State, I.getOperand(0)->getType())
+ : I.getOperand(0);
+ Value *V2 = SCCPSolver::isConstant(V2State)
+ ? getConstant(V2State, I.getOperand(1)->getType())
+ : I.getOperand(1);
Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL));
auto *C = dyn_cast_or_null<Constant>(R);
if (C) {
@@ -1361,7 +1531,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
if (SCCPSolver::isOverdefined(State))
return (void)markOverdefined(&I);
- if (Constant *C = getConstant(State)) {
+ if (Constant *C = getConstant(State, I.getOperand(i)->getType())) {
Operands.push_back(C);
continue;
}
@@ -1427,7 +1597,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) {
ValueLatticeElement &IV = ValueState[&I];
if (SCCPSolver::isConstant(PtrVal)) {
- Constant *Ptr = getConstant(PtrVal);
+ Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType());
// load null is undefined.
if (isa<ConstantPointerNull>(Ptr)) {
@@ -1490,7 +1660,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) {
if (SCCPSolver::isOverdefined(State))
return (void)markOverdefined(&CB);
assert(SCCPSolver::isConstant(State) && "Unknown state!");
- Operands.push_back(getConstant(State));
+ Operands.push_back(getConstant(State, A->getType()));
}
if (SCCPSolver::isOverdefined(getValueState(&CB)))
@@ -1622,6 +1792,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) {
SmallVector<ConstantRange, 2> OpRanges;
for (Value *Op : II->args()) {
const ValueLatticeElement &State = getValueState(Op);
+ if (State.isUnknownOrUndef())
+ return;
OpRanges.push_back(getConstantRange(State, Op->getType()));
}
@@ -1666,6 +1838,7 @@ void SCCPInstVisitor::solve() {
// things to overdefined more quickly.
while (!OverdefinedInstWorkList.empty()) {
Value *I = OverdefinedInstWorkList.pop_back_val();
+ Invalidated.erase(I);
LLVM_DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n');
@@ -1682,6 +1855,7 @@ void SCCPInstVisitor::solve() {
// Process the instruction work list.
while (!InstWorkList.empty()) {
Value *I = InstWorkList.pop_back_val();
+ Invalidated.erase(I);
LLVM_DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n');
@@ -1709,6 +1883,61 @@ void SCCPInstVisitor::solve() {
}
}
+bool SCCPInstVisitor::resolvedUndef(Instruction &I) {
+ // Look for instructions which produce undef values.
+ if (I.getType()->isVoidTy())
+ return false;
+
+ if (auto *STy = dyn_cast<StructType>(I.getType())) {
+ // Only a few things that can be structs matter for undef.
+
+ // Tracked calls must never be marked overdefined in resolvedUndefsIn.
+ if (auto *CB = dyn_cast<CallBase>(&I))
+ if (Function *F = CB->getCalledFunction())
+ if (MRVFunctionsTracked.count(F))
+ return false;
+
+ // extractvalue and insertvalue don't need to be marked; they are
+ // tracked as precisely as their operands.
+ if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))
+ return false;
+ // Send the results of everything else to overdefined. We could be
+ // more precise than this but it isn't worth bothering.
+ for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+ ValueLatticeElement &LV = getStructValueState(&I, i);
+ if (LV.isUnknown()) {
+ markOverdefined(LV, &I);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ ValueLatticeElement &LV = getValueState(&I);
+ if (!LV.isUnknown())
+ return false;
+
+ // There are two reasons a call can have an undef result
+ // 1. It could be tracked.
+ // 2. It could be constant-foldable.
+ // Because of the way we solve return values, tracked calls must
+ // never be marked overdefined in resolvedUndefsIn.
+ if (auto *CB = dyn_cast<CallBase>(&I))
+ if (Function *F = CB->getCalledFunction())
+ if (TrackedRetVals.count(F))
+ return false;
+
+ if (isa<LoadInst>(I)) {
+ // A load here means one of two things: a load of undef from a global,
+ // a load from an unknown pointer. Either way, having it return undef
+ // is okay.
+ return false;
+ }
+
+ markOverdefined(&I);
+ return true;
+}
+
/// While solving the dataflow for a function, we don't compute a result for
/// operations with an undef operand, to allow undef to be lowered to a
/// constant later. For example, constant folding of "zext i8 undef to i16"
@@ -1728,60 +1957,8 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) {
if (!BBExecutable.count(&BB))
continue;
- for (Instruction &I : BB) {
- // Look for instructions which produce undef values.
- if (I.getType()->isVoidTy())
- continue;
-
- if (auto *STy = dyn_cast<StructType>(I.getType())) {
- // Only a few things that can be structs matter for undef.
-
- // Tracked calls must never be marked overdefined in resolvedUndefsIn.
- if (auto *CB = dyn_cast<CallBase>(&I))
- if (Function *F = CB->getCalledFunction())
- if (MRVFunctionsTracked.count(F))
- continue;
-
- // extractvalue and insertvalue don't need to be marked; they are
- // tracked as precisely as their operands.
- if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))
- continue;
- // Send the results of everything else to overdefined. We could be
- // more precise than this but it isn't worth bothering.
- for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
- ValueLatticeElement &LV = getStructValueState(&I, i);
- if (LV.isUnknown()) {
- markOverdefined(LV, &I);
- MadeChange = true;
- }
- }
- continue;
- }
-
- ValueLatticeElement &LV = getValueState(&I);
- if (!LV.isUnknown())
- continue;
-
- // There are two reasons a call can have an undef result
- // 1. It could be tracked.
- // 2. It could be constant-foldable.
- // Because of the way we solve return values, tracked calls must
- // never be marked overdefined in resolvedUndefsIn.
- if (auto *CB = dyn_cast<CallBase>(&I))
- if (Function *F = CB->getCalledFunction())
- if (TrackedRetVals.count(F))
- continue;
-
- if (isa<LoadInst>(I)) {
- // A load here means one of two things: a load of undef from a global,
- // a load from an unknown pointer. Either way, having it return undef
- // is okay.
- continue;
- }
-
- markOverdefined(&I);
- MadeChange = true;
- }
+ for (Instruction &I : BB)
+ MadeChange |= resolvedUndef(I);
}
LLVM_DEBUG(if (MadeChange) dbgs()
@@ -1802,8 +1979,9 @@ SCCPSolver::SCCPSolver(
SCCPSolver::~SCCPSolver() = default;
-void SCCPSolver::addAnalysis(Function &F, AnalysisResultsForFn A) {
- return Visitor->addAnalysis(F, std::move(A));
+void SCCPSolver::addPredicateInfo(Function &F, DominatorTree &DT,
+ AssumptionCache &AC) {
+ Visitor->addPredicateInfo(F, DT, AC);
}
bool SCCPSolver::markBlockExecutable(BasicBlock *BB) {
@@ -1814,12 +1992,6 @@ const PredicateBase *SCCPSolver::getPredicateInfoFor(Instruction *I) {
return Visitor->getPredicateInfoFor(I);
}
-const LoopInfo &SCCPSolver::getLoopInfo(Function &F) {
- return Visitor->getLoopInfo(F);
-}
-
-DomTreeUpdater SCCPSolver::getDTU(Function &F) { return Visitor->getDTU(F); }
-
void SCCPSolver::trackValueOfGlobalVariable(GlobalVariable *GV) {
Visitor->trackValueOfGlobalVariable(GV);
}
@@ -1859,6 +2031,10 @@ SCCPSolver::solveWhileResolvedUndefsIn(SmallVectorImpl<Function *> &WorkList) {
Visitor->solveWhileResolvedUndefsIn(WorkList);
}
+void SCCPSolver::solveWhileResolvedUndefs() {
+ Visitor->solveWhileResolvedUndefs();
+}
+
bool SCCPSolver::isBlockExecutable(BasicBlock *BB) const {
return Visitor->isBlockExecutable(BB);
}
@@ -1876,6 +2052,10 @@ void SCCPSolver::removeLatticeValueFor(Value *V) {
return Visitor->removeLatticeValueFor(V);
}
+void SCCPSolver::resetLatticeValueFor(CallBase *Call) {
+ Visitor->resetLatticeValueFor(Call);
+}
+
const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const {
return Visitor->getLatticeValueFor(V);
}
@@ -1900,17 +2080,22 @@ bool SCCPSolver::isStructLatticeConstant(Function *F, StructType *STy) {
return Visitor->isStructLatticeConstant(F, STy);
}
-Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const {
- return Visitor->getConstant(LV);
+Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV,
+ Type *Ty) const {
+ return Visitor->getConstant(LV, Ty);
+}
+
+Constant *SCCPSolver::getConstantOrNull(Value *V) const {
+ return Visitor->getConstantOrNull(V);
}
SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() {
return Visitor->getArgumentTrackedFunctions();
}
-void SCCPSolver::markArgInFuncSpecialization(
- Function *F, const SmallVectorImpl<ArgInfo> &Args) {
- Visitor->markArgInFuncSpecialization(F, Args);
+void SCCPSolver::setLatticeValueForSpecializationArguments(Function *F,
+ const SmallVectorImpl<ArgInfo> &Args) {
+ Visitor->setLatticeValueForSpecializationArguments(F, Args);
}
void SCCPSolver::markFunctionUnreachable(Function *F) {
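
refineInstruction above asks ConstantRange::makeGuaranteedNoWrapRegion whether the known range of one operand rules out overflow for the entire range of the other. For an unsigned 8-bit add the underlying arithmetic reduces to checking that the two maxima still fit in the type; a self-contained sketch of just that check (not the LLVM API):

#include <cstdint>
#include <iostream>

// Inclusive, non-wrapping unsigned 8-bit range [Lo, Hi].
struct URange { uint8_t Lo, Hi; };

// An 8-bit unsigned add over [A.Lo,A.Hi] + [B.Lo,B.Hi] cannot wrap exactly
// when the largest possible sum still fits in 8 bits.
static bool addCannotWrapUnsigned(URange A, URange B) {
  return unsigned(A.Hi) + unsigned(B.Hi) <= 0xFF;
}

int main() {
  std::cout << std::boolalpha
            << addCannotWrapUnsigned({0, 100}, {0, 100}) << '\n'  // true
            << addCannotWrapUnsigned({0, 200}, {0, 100}) << '\n'; // false
}
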
diff --git a/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/llvm/lib/Transforms/Utils/SSAUpdater.cpp
index 2520aa5d9db0..ebe9cb27f5ab 100644
--- a/llvm/lib/Transforms/Utils/SSAUpdater.cpp
+++ b/llvm/lib/Transforms/Utils/SSAUpdater.cpp
@@ -19,6 +19,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
@@ -195,6 +196,33 @@ void SSAUpdater::RewriteUse(Use &U) {
U.set(V);
}
+void SSAUpdater::UpdateDebugValues(Instruction *I) {
+ SmallVector<DbgValueInst *, 4> DbgValues;
+ llvm::findDbgValues(DbgValues, I);
+ for (auto &DbgValue : DbgValues) {
+ if (DbgValue->getParent() == I->getParent())
+ continue;
+ UpdateDebugValue(I, DbgValue);
+ }
+}
+
+void SSAUpdater::UpdateDebugValues(Instruction *I,
+ SmallVectorImpl<DbgValueInst *> &DbgValues) {
+ for (auto &DbgValue : DbgValues) {
+ UpdateDebugValue(I, DbgValue);
+ }
+}
+
+void SSAUpdater::UpdateDebugValue(Instruction *I, DbgValueInst *DbgValue) {
+ BasicBlock *UserBB = DbgValue->getParent();
+ if (HasValueForBlock(UserBB)) {
+ Value *NewVal = GetValueAtEndOfBlock(UserBB);
+ DbgValue->replaceVariableLocationOp(I, NewVal);
+ } else
+ DbgValue->setKillLocation();
+}
+
void SSAUpdater::RewriteUseAfterInsertions(Use &U) {
Instruction *User = cast<Instruction>(U.getUser());
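
UpdateDebugValue rewrites a debug record to the value reaching its block when the updater has one, and otherwise kills the location. A toy standalone version of that decision, with a map standing in for the per-block values (ToyDbgValue and updateDebugValue are illustrative names, not the real API):

#include <iostream>
#include <map>
#include <string>

struct ToyDbgValue {
  int Block;            // block the debug record lives in
  std::string Location; // what the record currently points at
};

static void updateDebugValue(ToyDbgValue &DV,
                             const std::map<int, std::string> &ValueAtEnd) {
  auto It = ValueAtEnd.find(DV.Block);
  if (It != ValueAtEnd.end())
    DV.Location = It->second; // rewrite to the value reaching this block
  else
    DV.Location = "<killed>"; // nothing available: kill the location
}

int main() {
  std::map<int, std::string> ValueAtEnd{{1, "%phi.1"}};
  ToyDbgValue A{1, "%old"}, B{2, "%old"};
  updateDebugValue(A, ValueAtEnd);
  updateDebugValue(B, ValueAtEnd);
  std::cout << A.Location << ' ' << B.Location << '\n'; // %phi.1 <killed>
}
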
diff --git a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp
index 691ee00bd831..31d62fbf0618 100644
--- a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp
+++ b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp
@@ -20,6 +20,7 @@
#include <queue>
#include <set>
#include <stack>
+#include <unordered_set>
using namespace llvm;
#define DEBUG_TYPE "sample-profile-inference"
@@ -1218,10 +1219,23 @@ void extractWeights(const ProfiParams &Params, MinCostMaxFlow &Network,
#ifndef NDEBUG
/// Verify that the provided block/jump weights are as expected.
void verifyInput(const FlowFunction &Func) {
- // Verify the entry block
+ // Verify entry and exit blocks
assert(Func.Entry == 0 && Func.Blocks[0].isEntry());
+ size_t NumExitBlocks = 0;
for (size_t I = 1; I < Func.Blocks.size(); I++) {
assert(!Func.Blocks[I].isEntry() && "multiple entry blocks");
+ if (Func.Blocks[I].isExit())
+ NumExitBlocks++;
+ }
+ assert(NumExitBlocks > 0 && "cannot find exit blocks");
+
+ // Verify that there are no parallel edges
+ for (auto &Block : Func.Blocks) {
+ std::unordered_set<uint64_t> UniqueSuccs;
+ for (auto &Jump : Block.SuccJumps) {
+ auto It = UniqueSuccs.insert(Jump->Target);
+ assert(It.second && "input CFG contains parallel edges");
+ }
}
// Verify CFG jumps
for (auto &Block : Func.Blocks) {
@@ -1304,8 +1318,26 @@ void verifyOutput(const FlowFunction &Func) {
} // end of anonymous namespace
-/// Apply the profile inference algorithm for a given function
+/// Apply the profile inference algorithm for a given function and provided
+/// profi options
void llvm::applyFlowInference(const ProfiParams &Params, FlowFunction &Func) {
+ // Check if the function has samples and assign initial flow values
+ bool HasSamples = false;
+ for (FlowBlock &Block : Func.Blocks) {
+ if (Block.Weight > 0)
+ HasSamples = true;
+ Block.Flow = Block.Weight;
+ }
+ for (FlowJump &Jump : Func.Jumps) {
+ if (Jump.Weight > 0)
+ HasSamples = true;
+ Jump.Flow = Jump.Weight;
+ }
+
+ // Quit early for functions with a single block or ones w/o samples
+ if (Func.Blocks.size() <= 1 || !HasSamples)
+ return;
+
#ifndef NDEBUG
// Verify the input data
verifyInput(Func);
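
The extended verifyInput also rejects CFGs with parallel edges by inserting every successor target into an unordered_set and asserting that each insertion is new. The same check, standalone, with successor lists represented as plain vectors of target ids:

#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

// Returns true if some block has two jumps to the same target, i.e. the
// CFG contains parallel edges.
static bool hasParallelEdges(
    const std::vector<std::vector<uint64_t>> &SuccTargets) {
  for (const auto &Targets : SuccTargets) {
    std::unordered_set<uint64_t> Unique;
    for (uint64_t T : Targets)
      if (!Unique.insert(T).second)
        return true;
  }
  return false;
}

int main() {
  std::cout << std::boolalpha
            << hasParallelEdges({{1, 2}, {2}}) << '\n'  // false
            << hasParallelEdges({{1, 1}}) << '\n';      // true
}
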
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 24f1966edd37..20844271b943 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -163,7 +163,7 @@ Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) {
"InsertNoopCastOfTo cannot change sizes!");
// inttoptr only works for integral pointers. For non-integral pointers, we
- // can create a GEP on i8* null with the integral value as index. Note that
+ // can create a GEP on null with the integral value as index. Note that
// it is safe to use GEP of null instead of inttoptr here, because only
// expressions already based on a GEP of null should be converted to pointers
// during expansion.
@@ -173,9 +173,8 @@ Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) {
auto *Int8PtrTy = Builder.getInt8PtrTy(PtrTy->getAddressSpace());
assert(DL.getTypeAllocSize(Builder.getInt8Ty()) == 1 &&
"alloc size of i8 must by 1 byte for the GEP to be correct");
- auto *GEP = Builder.CreateGEP(
- Builder.getInt8Ty(), Constant::getNullValue(Int8PtrTy), V, "uglygep");
- return Builder.CreateBitCast(GEP, Ty);
+ return Builder.CreateGEP(
+ Builder.getInt8Ty(), Constant::getNullValue(Int8PtrTy), V, "scevgep");
}
}
// Short-circuit unnecessary bitcasts.
@@ -287,142 +286,6 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
return BO;
}
-/// FactorOutConstant - Test if S is divisible by Factor, using signed
-/// division. If so, update S with Factor divided out and return true.
-/// S need not be evenly divisible if a reasonable remainder can be
-/// computed.
-static bool FactorOutConstant(const SCEV *&S, const SCEV *&Remainder,
- const SCEV *Factor, ScalarEvolution &SE,
- const DataLayout &DL) {
- // Everything is divisible by one.
- if (Factor->isOne())
- return true;
-
- // x/x == 1.
- if (S == Factor) {
- S = SE.getConstant(S->getType(), 1);
- return true;
- }
-
- // For a Constant, check for a multiple of the given factor.
- if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- // 0/x == 0.
- if (C->isZero())
- return true;
- // Check for divisibility.
- if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor)) {
- ConstantInt *CI =
- ConstantInt::get(SE.getContext(), C->getAPInt().sdiv(FC->getAPInt()));
- // If the quotient is zero and the remainder is non-zero, reject
- // the value at this scale. It will be considered for subsequent
- // smaller scales.
- if (!CI->isZero()) {
- const SCEV *Div = SE.getConstant(CI);
- S = Div;
- Remainder = SE.getAddExpr(
- Remainder, SE.getConstant(C->getAPInt().srem(FC->getAPInt())));
- return true;
- }
- }
- }
-
- // In a Mul, check if there is a constant operand which is a multiple
- // of the given factor.
- if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
- // Size is known, check if there is a constant operand which is a multiple
- // of the given factor. If so, we can factor it.
- if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor))
- if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
- if (!C->getAPInt().srem(FC->getAPInt())) {
- SmallVector<const SCEV *, 4> NewMulOps(M->operands());
- NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt()));
- S = SE.getMulExpr(NewMulOps);
- return true;
- }
- }
-
- // In an AddRec, check if both start and step are divisible.
- if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
- const SCEV *Step = A->getStepRecurrence(SE);
- const SCEV *StepRem = SE.getConstant(Step->getType(), 0);
- if (!FactorOutConstant(Step, StepRem, Factor, SE, DL))
- return false;
- if (!StepRem->isZero())
- return false;
- const SCEV *Start = A->getStart();
- if (!FactorOutConstant(Start, Remainder, Factor, SE, DL))
- return false;
- S = SE.getAddRecExpr(Start, Step, A->getLoop(),
- A->getNoWrapFlags(SCEV::FlagNW));
- return true;
- }
-
- return false;
-}
-
-/// SimplifyAddOperands - Sort and simplify a list of add operands. NumAddRecs
-/// is the number of SCEVAddRecExprs present, which are kept at the end of
-/// the list.
-///
-static void SimplifyAddOperands(SmallVectorImpl<const SCEV *> &Ops,
- Type *Ty,
- ScalarEvolution &SE) {
- unsigned NumAddRecs = 0;
- for (unsigned i = Ops.size(); i > 0 && isa<SCEVAddRecExpr>(Ops[i-1]); --i)
- ++NumAddRecs;
- // Group Ops into non-addrecs and addrecs.
- SmallVector<const SCEV *, 8> NoAddRecs(Ops.begin(), Ops.end() - NumAddRecs);
- SmallVector<const SCEV *, 8> AddRecs(Ops.end() - NumAddRecs, Ops.end());
- // Let ScalarEvolution sort and simplify the non-addrecs list.
- const SCEV *Sum = NoAddRecs.empty() ?
- SE.getConstant(Ty, 0) :
- SE.getAddExpr(NoAddRecs);
- // If it returned an add, use the operands. Otherwise it simplified
- // the sum into a single value, so just use that.
- Ops.clear();
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Sum))
- append_range(Ops, Add->operands());
- else if (!Sum->isZero())
- Ops.push_back(Sum);
- // Then append the addrecs.
- Ops.append(AddRecs.begin(), AddRecs.end());
-}
-
-/// SplitAddRecs - Flatten a list of add operands, moving addrec start values
-/// out to the top level. For example, convert {a + b,+,c} to a, b, {0,+,d}.
-/// This helps expose more opportunities for folding parts of the expressions
-/// into GEP indices.
-///
-static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops,
- Type *Ty,
- ScalarEvolution &SE) {
- // Find the addrecs.
- SmallVector<const SCEV *, 8> AddRecs;
- for (unsigned i = 0, e = Ops.size(); i != e; ++i)
- while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Ops[i])) {
- const SCEV *Start = A->getStart();
- if (Start->isZero()) break;
- const SCEV *Zero = SE.getConstant(Ty, 0);
- AddRecs.push_back(SE.getAddRecExpr(Zero,
- A->getStepRecurrence(SE),
- A->getLoop(),
- A->getNoWrapFlags(SCEV::FlagNW)));
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Start)) {
- Ops[i] = Zero;
- append_range(Ops, Add->operands());
- e += Add->getNumOperands();
- } else {
- Ops[i] = Start;
- }
- }
- if (!AddRecs.empty()) {
- // Add the addrecs onto the end of the list.
- Ops.append(AddRecs.begin(), AddRecs.end());
- // Resort the operand list, moving any constants to the front.
- SimplifyAddOperands(Ops, Ty, SE);
- }
-}
-
/// expandAddToGEP - Expand an addition expression with a pointer type into
/// a GEP instead of using ptrtoint+arithmetic+inttoptr. This helps
/// BasicAliasAnalysis and other passes analyze the result. See the rules
@@ -450,210 +313,53 @@ static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops,
/// loop-invariant portions of expressions, after considering what
/// can be folded using target addressing modes.
///
-Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin,
- const SCEV *const *op_end,
- PointerType *PTy,
- Type *Ty,
- Value *V) {
- SmallVector<Value *, 4> GepIndices;
- SmallVector<const SCEV *, 8> Ops(op_begin, op_end);
- bool AnyNonZeroIndices = false;
-
- // Split AddRecs up into parts as either of the parts may be usable
- // without the other.
- SplitAddRecs(Ops, Ty, SE);
-
- Type *IntIdxTy = DL.getIndexType(PTy);
-
- // For opaque pointers, always generate i8 GEP.
- if (!PTy->isOpaque()) {
- // Descend down the pointer's type and attempt to convert the other
- // operands into GEP indices, at each level. The first index in a GEP
- // indexes into the array implied by the pointer operand; the rest of
- // the indices index into the element or field type selected by the
- // preceding index.
- Type *ElTy = PTy->getNonOpaquePointerElementType();
- for (;;) {
- // If the scale size is not 0, attempt to factor out a scale for
- // array indexing.
- SmallVector<const SCEV *, 8> ScaledOps;
- if (ElTy->isSized()) {
- const SCEV *ElSize = SE.getSizeOfExpr(IntIdxTy, ElTy);
- if (!ElSize->isZero()) {
- SmallVector<const SCEV *, 8> NewOps;
- for (const SCEV *Op : Ops) {
- const SCEV *Remainder = SE.getConstant(Ty, 0);
- if (FactorOutConstant(Op, Remainder, ElSize, SE, DL)) {
- // Op now has ElSize factored out.
- ScaledOps.push_back(Op);
- if (!Remainder->isZero())
- NewOps.push_back(Remainder);
- AnyNonZeroIndices = true;
- } else {
- // The operand was not divisible, so add it to the list of
- // operands we'll scan next iteration.
- NewOps.push_back(Op);
- }
- }
- // If we made any changes, update Ops.
- if (!ScaledOps.empty()) {
- Ops = NewOps;
- SimplifyAddOperands(Ops, Ty, SE);
- }
- }
- }
+Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Type *Ty, Value *V) {
+ assert(!isa<Instruction>(V) ||
+ SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));
- // Record the scaled array index for this level of the type. If
- // we didn't find any operands that could be factored, tentatively
- // assume that element zero was selected (since the zero offset
- // would obviously be folded away).
- Value *Scaled =
- ScaledOps.empty()
- ? Constant::getNullValue(Ty)
- : expandCodeForImpl(SE.getAddExpr(ScaledOps), Ty);
- GepIndices.push_back(Scaled);
-
- // Collect struct field index operands.
- while (StructType *STy = dyn_cast<StructType>(ElTy)) {
- bool FoundFieldNo = false;
- // An empty struct has no fields.
- if (STy->getNumElements() == 0) break;
- // Field offsets are known. See if a constant offset falls within any of
- // the struct fields.
- if (Ops.empty())
- break;
- if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[0]))
- if (SE.getTypeSizeInBits(C->getType()) <= 64) {
- const StructLayout &SL = *DL.getStructLayout(STy);
- uint64_t FullOffset = C->getValue()->getZExtValue();
- if (FullOffset < SL.getSizeInBytes()) {
- unsigned ElIdx = SL.getElementContainingOffset(FullOffset);
- GepIndices.push_back(
- ConstantInt::get(Type::getInt32Ty(Ty->getContext()), ElIdx));
- ElTy = STy->getTypeAtIndex(ElIdx);
- Ops[0] =
- SE.getConstant(Ty, FullOffset - SL.getElementOffset(ElIdx));
- AnyNonZeroIndices = true;
- FoundFieldNo = true;
- }
- }
- // If no struct field offsets were found, tentatively assume that
- // field zero was selected (since the zero offset would obviously
- // be folded away).
- if (!FoundFieldNo) {
- ElTy = STy->getTypeAtIndex(0u);
- GepIndices.push_back(
- Constant::getNullValue(Type::getInt32Ty(Ty->getContext())));
- }
- }
+ Value *Idx = expandCodeForImpl(Offset, Ty);
- if (ArrayType *ATy = dyn_cast<ArrayType>(ElTy))
- ElTy = ATy->getElementType();
- else
- // FIXME: Handle VectorType.
- // E.g., If ElTy is scalable vector, then ElSize is not a compile-time
- // constant, therefore can not be factored out. The generated IR is less
- // ideal with base 'V' cast to i8* and do ugly getelementptr over that.
- break;
- }
- }
-
- // If none of the operands were convertible to proper GEP indices, cast
- // the base to i8* and do an ugly getelementptr with that. It's still
- // better than ptrtoint+arithmetic+inttoptr at least.
- if (!AnyNonZeroIndices) {
- // Cast the base to i8*.
- if (!PTy->isOpaque())
- V = InsertNoopCastOfTo(V,
- Type::getInt8PtrTy(Ty->getContext(), PTy->getAddressSpace()));
-
- assert(!isa<Instruction>(V) ||
- SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));
-
- // Expand the operands for a plain byte offset.
- Value *Idx = expandCodeForImpl(SE.getAddExpr(Ops), Ty);
-
- // Fold a GEP with constant operands.
- if (Constant *CLHS = dyn_cast<Constant>(V))
- if (Constant *CRHS = dyn_cast<Constant>(Idx))
- return Builder.CreateGEP(Builder.getInt8Ty(), CLHS, CRHS);
-
- // Do a quick scan to see if we have this GEP nearby. If so, reuse it.
- unsigned ScanLimit = 6;
- BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin();
- // Scanning starts from the last instruction before the insertion point.
- BasicBlock::iterator IP = Builder.GetInsertPoint();
- if (IP != BlockBegin) {
- --IP;
- for (; ScanLimit; --IP, --ScanLimit) {
- // Don't count dbg.value against the ScanLimit, to avoid perturbing the
- // generated code.
- if (isa<DbgInfoIntrinsic>(IP))
- ScanLimit++;
- if (IP->getOpcode() == Instruction::GetElementPtr &&
- IP->getOperand(0) == V && IP->getOperand(1) == Idx &&
- cast<GEPOperator>(&*IP)->getSourceElementType() ==
- Type::getInt8Ty(Ty->getContext()))
- return &*IP;
- if (IP == BlockBegin) break;
- }
- }
+ // Fold a GEP with constant operands.
+ if (Constant *CLHS = dyn_cast<Constant>(V))
+ if (Constant *CRHS = dyn_cast<Constant>(Idx))
+ return Builder.CreateGEP(Builder.getInt8Ty(), CLHS, CRHS);
- // Save the original insertion point so we can restore it when we're done.
- SCEVInsertPointGuard Guard(Builder, this);
-
- // Move the insertion point out of as many loops as we can.
- while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
- if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break;
- BasicBlock *Preheader = L->getLoopPreheader();
- if (!Preheader) break;
-
- // Ok, move up a level.
- Builder.SetInsertPoint(Preheader->getTerminator());
+ // Do a quick scan to see if we have this GEP nearby. If so, reuse it.
+ unsigned ScanLimit = 6;
+ BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin();
+ // Scanning starts from the last instruction before the insertion point.
+ BasicBlock::iterator IP = Builder.GetInsertPoint();
+ if (IP != BlockBegin) {
+ --IP;
+ for (; ScanLimit; --IP, --ScanLimit) {
+ // Don't count dbg.value against the ScanLimit, to avoid perturbing the
+ // generated code.
+ if (isa<DbgInfoIntrinsic>(IP))
+ ScanLimit++;
+ if (IP->getOpcode() == Instruction::GetElementPtr &&
+ IP->getOperand(0) == V && IP->getOperand(1) == Idx &&
+ cast<GEPOperator>(&*IP)->getSourceElementType() ==
+ Type::getInt8Ty(Ty->getContext()))
+ return &*IP;
+ if (IP == BlockBegin) break;
}
-
- // Emit a GEP.
- return Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "uglygep");
}
- {
- SCEVInsertPointGuard Guard(Builder, this);
-
- // Move the insertion point out of as many loops as we can.
- while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
- if (!L->isLoopInvariant(V)) break;
-
- bool AnyIndexNotLoopInvariant = any_of(
- GepIndices, [L](Value *Op) { return !L->isLoopInvariant(Op); });
-
- if (AnyIndexNotLoopInvariant)
- break;
+ // Save the original insertion point so we can restore it when we're done.
+ SCEVInsertPointGuard Guard(Builder, this);
- BasicBlock *Preheader = L->getLoopPreheader();
- if (!Preheader) break;
+ // Move the insertion point out of as many loops as we can.
+ while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
+ if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break;
+ BasicBlock *Preheader = L->getLoopPreheader();
+ if (!Preheader) break;
- // Ok, move up a level.
- Builder.SetInsertPoint(Preheader->getTerminator());
- }
-
- // Insert a pretty getelementptr. Note that this GEP is not marked inbounds,
- // because ScalarEvolution may have changed the address arithmetic to
- // compute a value which is beyond the end of the allocated object.
- Value *Casted = V;
- if (V->getType() != PTy)
- Casted = InsertNoopCastOfTo(Casted, PTy);
- Value *GEP = Builder.CreateGEP(PTy->getNonOpaquePointerElementType(),
- Casted, GepIndices, "scevgep");
- Ops.push_back(SE.getUnknown(GEP));
+ // Ok, move up a level.
+ Builder.SetInsertPoint(Preheader->getTerminator());
}
- return expand(SE.getAddExpr(Ops));
-}
-
-Value *SCEVExpander::expandAddToGEP(const SCEV *Op, PointerType *PTy, Type *Ty,
- Value *V) {
- const SCEV *const Ops[1] = {Op};
- return expandAddToGEP(Ops, Ops + 1, PTy, Ty, V);
+ // Emit a GEP.
+ return Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "scevgep");
}
/// PickMostRelevantLoop - Given two loops pick the one that's most relevant for
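
With the typed-pointer paths removed, every address SCEVExpander materializes is now a plain byte offset over the base pointer. A minimal IRBuilder sketch of the shape it emits (the helper name is an assumption for illustration; the "scevgep" name matches the hunk above):

    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // Base + Offset is expressed as a single i8 GEP; with opaque pointers no
    // typed GEP indices and no trailing bitcast are needed.
    static Value *emitByteOffsetGEP(IRBuilder<> &B, Value *Base, Value *Offset) {
      return B.CreateGEP(B.getInt8Ty(), Base, Offset, "scevgep");
    }
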
@@ -680,6 +386,7 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) {
switch (S->getSCEVType()) {
case scConstant:
+ case scVScale:
return nullptr; // A constant has no relevant loops.
case scTruncate:
case scZeroExtend:
@@ -778,7 +485,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
}
assert(!Op->getType()->isPointerTy() && "Only first op can be pointer");
- if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) {
+ if (isa<PointerType>(Sum->getType())) {
// The running sum expression is a pointer. Try to form a getelementptr
// at this level with that as the base.
SmallVector<const SCEV *, 4> NewOps;
@@ -791,7 +498,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
X = SE.getSCEV(U->getValue());
NewOps.push_back(X);
}
- Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum);
+ Sum = expandAddToGEP(SE.getAddExpr(NewOps), Ty, Sum);
} else if (Op->isNonConstantNegative()) {
// Instead of doing a negate and add, just do a subtract.
Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty);
@@ -995,15 +702,8 @@ Instruction *SCEVExpander::getIVIncOperand(Instruction *IncV,
// allow any kind of GEP as long as it can be hoisted.
continue;
}
- // This must be a pointer addition of constants (pretty), which is already
- // handled, or some number of address-size elements (ugly). Ugly geps
- // have 2 operands. i1* is used by the expander to represent an
- // address-size element.
- if (IncV->getNumOperands() != 2)
- return nullptr;
- unsigned AS = cast<PointerType>(IncV->getType())->getAddressSpace();
- if (IncV->getType() != Type::getInt1PtrTy(SE.getContext(), AS)
- && IncV->getType() != Type::getInt8PtrTy(SE.getContext(), AS))
+ // GEPs produced by SCEVExpander use i8 element type.
+ if (!cast<GEPOperator>(IncV)->getSourceElementType()->isIntegerTy(8))
return nullptr;
break;
}
@@ -1108,15 +808,7 @@ Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L,
Value *IncV;
// If the PHI is a pointer, use a GEP, otherwise use an add or sub.
if (ExpandTy->isPointerTy()) {
- PointerType *GEPPtrTy = cast<PointerType>(ExpandTy);
- // If the step isn't constant, don't use an implicitly scaled GEP, because
- // that would require a multiply inside the loop.
- if (!isa<ConstantInt>(StepV))
- GEPPtrTy = PointerType::get(Type::getInt1Ty(SE.getContext()),
- GEPPtrTy->getAddressSpace());
- IncV = expandAddToGEP(SE.getSCEV(StepV), GEPPtrTy, IntTy, PN);
- if (IncV->getType() != PN->getType())
- IncV = Builder.CreateBitCast(IncV, PN->getType());
+ IncV = expandAddToGEP(SE.getSCEV(StepV), IntTy, PN);
} else {
IncV = useSubtract ?
Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") :
@@ -1388,7 +1080,8 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
if (PostIncLoops.count(L)) {
PostIncLoopSet Loops;
Loops.insert(L);
- Normalized = cast<SCEVAddRecExpr>(normalizeForPostIncUse(S, Loops, SE));
+ Normalized = cast<SCEVAddRecExpr>(
+ normalizeForPostIncUse(S, Loops, SE, /*CheckInvertible=*/false));
}
// Strip off any non-loop-dominating component from the addrec start.
@@ -1515,12 +1208,12 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
// Re-apply any non-loop-dominating offset.
if (PostLoopOffset) {
- if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) {
+ if (isa<PointerType>(ExpandTy)) {
if (Result->getType()->isIntegerTy()) {
Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy);
- Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base);
+ Result = expandAddToGEP(SE.getUnknown(Result), IntTy, Base);
} else {
- Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result);
+ Result = expandAddToGEP(PostLoopOffset, IntTy, Result);
}
} else {
Result = InsertNoopCastOfTo(Result, IntTy);
@@ -1574,10 +1267,9 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
// {X,+,F} --> X + {0,+,F}
if (!S->getStart()->isZero()) {
- if (PointerType *PTy = dyn_cast<PointerType>(S->getType())) {
+ if (isa<PointerType>(S->getType())) {
Value *StartV = expand(SE.getPointerBase(S));
- assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!");
- return expandAddToGEP(SE.removePointerBase(S), PTy, Ty, StartV);
+ return expandAddToGEP(SE.removePointerBase(S), Ty, StartV);
}
SmallVector<const SCEV *, 4> NewOps(S->operands());
@@ -1744,6 +1436,10 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true);
}
+Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
+ return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
+}
+
Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
Instruction *IP) {
setInsertPoint(IP);
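
The new visitVScale handler lowers a SCEVVScale node to the llvm.vscale intrinsic. A tiny sketch of the equivalent IRBuilder call (helper name assumed; Ty is expected to be the integer type of the SCEV node):

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // llvm.vscale.<Ty>() scaled by 1, i.e. the raw vscale value.
    static Value *emitVScale(IRBuilder<> &B, Type *Ty) {
      return B.CreateVScale(ConstantInt::get(Ty, 1));
    }
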
@@ -1956,11 +1652,17 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT,
OrigPhiRef = Phi;
if (Phi->getType()->isIntegerTy() && TTI &&
TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) {
- // This phi can be freely truncated to the narrowest phi type. Map the
- // truncated expression to it so it will be reused for narrow types.
- const SCEV *TruncExpr =
- SE.getTruncateExpr(SE.getSCEV(Phi), Phis.back()->getType());
- ExprToIVMap[TruncExpr] = Phi;
+ // Make sure we only rewrite using simple induction variables;
+ // otherwise, we can make the trip count of a loop unanalyzable
+ // to SCEV.
+ const SCEV *PhiExpr = SE.getSCEV(Phi);
+ if (isa<SCEVAddRecExpr>(PhiExpr)) {
+ // This phi can be freely truncated to the narrowest phi type. Map the
+ // truncated expression to it so it will be reused for narrow types.
+ const SCEV *TruncExpr =
+ SE.getTruncateExpr(PhiExpr, Phis.back()->getType());
+ ExprToIVMap[TruncExpr] = Phi;
+ }
}
continue;
}
@@ -2124,6 +1826,7 @@ template<typename T> static InstructionCost costAndCollectOperands(
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
case scUnknown:
case scConstant:
+ case scVScale:
return 0;
case scPtrToInt:
Cost = CastCost(Instruction::PtrToInt);
@@ -2260,6 +1963,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
case scUnknown:
+ case scVScale:
// Assume to be zero-cost.
return false;
case scConstant: {
@@ -2551,7 +2255,11 @@ Value *SCEVExpander::fixupLCSSAFormFor(Value *V) {
SmallVector<Instruction *, 1> ToUpdate;
ToUpdate.push_back(DefI);
SmallVector<PHINode *, 16> PHIsToRemove;
- formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE, Builder, &PHIsToRemove);
+ SmallVector<PHINode *, 16> InsertedPHIs;
+ formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE, &PHIsToRemove,
+ &InsertedPHIs);
+ for (PHINode *PN : InsertedPHIs)
+ rememberInstruction(PN);
for (PHINode *PN : PHIsToRemove) {
if (!PN->use_empty())
continue;
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 9e0483966d3e..d3a9a41aef15 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -271,10 +271,8 @@ class SimplifyCFGOpt {
bool tryToSimplifyUncondBranchWithICmpInIt(ICmpInst *ICI,
IRBuilder<> &Builder);
- bool HoistThenElseCodeToIf(BranchInst *BI, const TargetTransformInfo &TTI,
- bool EqTermsOnly);
- bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
- const TargetTransformInfo &TTI);
+ bool HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly);
+ bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB);
bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond,
BasicBlock *TrueBB, BasicBlock *FalseBB,
uint32_t TrueWeight, uint32_t FalseWeight);
@@ -1086,7 +1084,7 @@ static void GetBranchWeights(Instruction *TI,
static void FitWeights(MutableArrayRef<uint64_t> Weights) {
uint64_t Max = *std::max_element(Weights.begin(), Weights.end());
if (Max > UINT_MAX) {
- unsigned Offset = 32 - countLeadingZeros(Max);
+ unsigned Offset = 32 - llvm::countl_zero(Max);
for (uint64_t &I : Weights)
I >>= Offset;
}
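
FitWeights only swapped its bit-counting helper, but the logic is worth spelling out: when the largest branch weight exceeds 32 bits, every weight is shifted right so the maximum fits in a u32 while the relative ratios are roughly preserved. A standalone C++20 sketch of the same computation:

    #include <algorithm>
    #include <bit>
    #include <climits>
    #include <cstdint>
    #include <vector>

    void fitWeightsInto32Bits(std::vector<uint64_t> &Weights) {
      uint64_t Max = *std::max_element(Weights.begin(), Weights.end());
      if (Max > UINT_MAX) {
        // Max has (64 - countl_zero(Max)) significant bits; shifting right by
        // the excess over 32 bits brings it back into u32 range.
        unsigned Offset = 32 - std::countl_zero(Max);
        for (uint64_t &W : Weights)
          W >>= Offset;
      }
    }
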
@@ -1117,16 +1115,12 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
VMap[&BonusInst] = NewBonusInst;
- // If we moved a load, we cannot any longer claim any knowledge about
- // its potential value. The previous information might have been valid
+ // If we speculated an instruction, we need to drop any metadata that may
+ // result in undefined behavior, as the metadata might have been valid
// only given the branch precondition.
- // For an analogous reason, we must also drop all the metadata whose
- // semantics we don't understand. We *can* preserve !annotation, because
- // it is tied to the instruction itself, not the value or position.
// Similarly strip attributes on call parameters that may cause UB in
// location the call is moved to.
- NewBonusInst->dropUndefImplyingAttrsAndUnknownMetadata(
- LLVMContext::MD_annotation);
+ NewBonusInst->dropUBImplyingAttrsAndMetadata();
NewBonusInst->insertInto(PredBlock, PTI->getIterator());
NewBonusInst->takeName(&BonusInst);
@@ -1462,7 +1456,7 @@ static bool isSafeToHoistInstr(Instruction *I, unsigned Flags) {
// If we have seen an instruction with side effects, it's unsafe to reorder an
// instruction which reads memory or itself has side effects.
if ((Flags & SkipSideEffect) &&
- (I->mayReadFromMemory() || I->mayHaveSideEffects()))
+ (I->mayReadFromMemory() || I->mayHaveSideEffects() || isa<AllocaInst>(I)))
return false;
// Reordering across an instruction which does not necessarily transfer
@@ -1490,14 +1484,43 @@ static bool isSafeToHoistInstr(Instruction *I, unsigned Flags) {
static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValueMayBeModified = false);
+/// Helper function for HoistThenElseCodeToIf. Return true if identical
+/// instructions \p I1 and \p I2 can and should be hoisted.
+static bool shouldHoistCommonInstructions(Instruction *I1, Instruction *I2,
+ const TargetTransformInfo &TTI) {
+ // If we're going to hoist a call, make sure that the two instructions
+ // we're commoning/hoisting are both marked with musttail, or neither of
+ // them is marked as such. Otherwise, we might end up in a situation where
+ // we hoist from a block where the terminator is a `ret` to a block where
+ // the terminator is a `br`, and `musttail` calls expect to be followed by
+ // a return.
+ auto *C1 = dyn_cast<CallInst>(I1);
+ auto *C2 = dyn_cast<CallInst>(I2);
+ if (C1 && C2)
+ if (C1->isMustTailCall() != C2->isMustTailCall())
+ return false;
+
+ if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2))
+ return false;
+
+ // If any of the two call sites has nomerge or convergent attribute, stop
+ // hoisting.
+ if (const auto *CB1 = dyn_cast<CallBase>(I1))
+ if (CB1->cannotMerge() || CB1->isConvergent())
+ return false;
+ if (const auto *CB2 = dyn_cast<CallBase>(I2))
+ if (CB2->cannotMerge() || CB2->isConvergent())
+ return false;
+
+ return true;
+}
+
/// Given a conditional branch that goes to BB1 and BB2, hoist any common code
/// in the two blocks up into the branch block. The caller of this function
/// guarantees that BI's block dominates BB1 and BB2. If EqTermsOnly is given,
/// only perform hoisting in case both blocks only contain a terminator. In that
/// case, only the original BI will be replaced and selects for PHIs are added.
-bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI,
- const TargetTransformInfo &TTI,
- bool EqTermsOnly) {
+bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly) {
// This does very trivial matching, with limited scanning, to find identical
// instructions in the two blocks. In particular, we don't want to get into
// O(M*N) situations here where M and N are the sizes of BB1 and BB2. As
@@ -1572,37 +1595,13 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI,
goto HoistTerminator;
}
- if (I1->isIdenticalToWhenDefined(I2)) {
- // Even if the instructions are identical, it may not be safe to hoist
- // them if we have skipped over instructions with side effects or their
- // operands weren't hoisted.
- if (!isSafeToHoistInstr(I1, SkipFlagsBB1) ||
- !isSafeToHoistInstr(I2, SkipFlagsBB2))
- return Changed;
-
- // If we're going to hoist a call, make sure that the two instructions
- // we're commoning/hoisting are both marked with musttail, or neither of
- // them is marked as such. Otherwise, we might end up in a situation where
- // we hoist from a block where the terminator is a `ret` to a block where
- // the terminator is a `br`, and `musttail` calls expect to be followed by
- // a return.
- auto *C1 = dyn_cast<CallInst>(I1);
- auto *C2 = dyn_cast<CallInst>(I2);
- if (C1 && C2)
- if (C1->isMustTailCall() != C2->isMustTailCall())
- return Changed;
-
- if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2))
- return Changed;
-
- // If any of the two call sites has nomerge attribute, stop hoisting.
- if (const auto *CB1 = dyn_cast<CallBase>(I1))
- if (CB1->cannotMerge())
- return Changed;
- if (const auto *CB2 = dyn_cast<CallBase>(I2))
- if (CB2->cannotMerge())
- return Changed;
-
+ if (I1->isIdenticalToWhenDefined(I2) &&
+ // Even if the instructions are identical, it may not be safe to hoist
+ // them if we have skipped over instructions with side effects or their
+ // operands weren't hoisted.
+ isSafeToHoistInstr(I1, SkipFlagsBB1) &&
+ isSafeToHoistInstr(I2, SkipFlagsBB2) &&
+ shouldHoistCommonInstructions(I1, I2, TTI)) {
if (isa<DbgInfoIntrinsic>(I1) || isa<DbgInfoIntrinsic>(I2)) {
assert(isa<DbgInfoIntrinsic>(I1) && isa<DbgInfoIntrinsic>(I2));
// The debug location is an integral part of a debug info intrinsic
@@ -1618,19 +1617,7 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI,
if (!I2->use_empty())
I2->replaceAllUsesWith(I1);
I1->andIRFlags(I2);
- unsigned KnownIDs[] = {LLVMContext::MD_tbaa,
- LLVMContext::MD_range,
- LLVMContext::MD_fpmath,
- LLVMContext::MD_invariant_load,
- LLVMContext::MD_nonnull,
- LLVMContext::MD_invariant_group,
- LLVMContext::MD_align,
- LLVMContext::MD_dereferenceable,
- LLVMContext::MD_dereferenceable_or_null,
- LLVMContext::MD_mem_parallel_loop_access,
- LLVMContext::MD_access_group,
- LLVMContext::MD_preserve_access_index};
- combineMetadata(I1, I2, KnownIDs, true);
+ combineMetadataForCSE(I1, I2, true);
// I1 and I2 are being combined into a single instruction. Its debug
// location is the merged locations of the original instructions.
@@ -1808,9 +1795,9 @@ static bool canSinkInstructions(
// Conservatively return false if I is an inline-asm instruction. Sinking
// and merging inline-asm instructions can potentially create arguments
// that cannot satisfy the inline-asm constraints.
- // If the instruction has nomerge attribute, return false.
+ // If the instruction has nomerge or convergent attribute, return false.
if (const auto *C = dyn_cast<CallBase>(I))
- if (C->isInlineAsm() || C->cannotMerge())
+ if (C->isInlineAsm() || C->cannotMerge() || C->isConvergent())
return false;
// Each instruction must have zero or one use.
@@ -2455,9 +2442,13 @@ bool CompatibleSets::shouldBelongToSameSet(ArrayRef<InvokeInst *> Invokes) {
// Can we theoretically form the data operands for the merged `invoke`?
auto IsIllegalToMergeArguments = [](auto Ops) {
- Type *Ty = std::get<0>(Ops)->getType();
- assert(Ty == std::get<1>(Ops)->getType() && "Incompatible types?");
- return Ty->isTokenTy() && std::get<0>(Ops) != std::get<1>(Ops);
+ Use &U0 = std::get<0>(Ops);
+ Use &U1 = std::get<1>(Ops);
+ if (U0 == U1)
+ return false;
+ return U0->getType()->isTokenTy() ||
+ !canReplaceOperandWithVariable(cast<Instruction>(U0.getUser()),
+ U0.getOperandNo());
};
assert(Invokes.size() == 2 && "Always called with exactly two candidates.");
if (any_of(zip(Invokes[0]->data_ops(), Invokes[1]->data_ops()),
@@ -2571,7 +2562,7 @@ static void MergeCompatibleInvokesImpl(ArrayRef<InvokeInst *> Invokes,
// And finally, replace the original `invoke`s with an unconditional branch
// to the block with the merged `invoke`. Also, give that merged `invoke`
// the merged debugloc of all the original `invoke`s.
- const DILocation *MergedDebugLoc = nullptr;
+ DILocation *MergedDebugLoc = nullptr;
for (InvokeInst *II : Invokes) {
// Compute the debug location common to all the original `invoke`s.
if (!MergedDebugLoc)
@@ -2849,8 +2840,11 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
/// \endcode
///
/// \returns true if the conditional block is removed.
-bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
- const TargetTransformInfo &TTI) {
+bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI,
+ BasicBlock *ThenBB) {
+ if (!Options.SpeculateBlocks)
+ return false;
+
// Be conservative for now. FP select instruction can often be expensive.
Value *BrCond = BI->getCondition();
if (isa<FCmpInst>(BrCond))
@@ -3021,7 +3015,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
}
// Metadata can be dependent on the condition we are hoisting above.
- // Conservatively strip all metadata on the instruction. Drop the debug loc
+ // Strip all UB-implying metadata on the instruction. Drop the debug loc
// to avoid making it appear as if the condition is a constant, which would
// be misleading while debugging.
 // Similarly strip attributes that may be dependent on the condition we are
@@ -3032,7 +3026,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
if (!isa<DbgAssignIntrinsic>(&I))
I.setDebugLoc(DebugLoc());
}
- I.dropUndefImplyingAttrsAndUnknownMetadata();
+ I.dropUBImplyingAttrsAndMetadata();
// Drop ephemeral values.
if (EphTracker.contains(&I)) {
@@ -3220,6 +3214,9 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
}
// Clone the instruction.
Instruction *N = BBI->clone();
+ // Insert the new instruction into its new home.
+ N->insertInto(EdgeBB, InsertPt);
+
if (BBI->hasName())
N->setName(BBI->getName() + ".c");
@@ -3235,7 +3232,8 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
if (!BBI->use_empty())
TranslateMap[&*BBI] = V;
if (!N->mayHaveSideEffects()) {
- N->deleteValue(); // Instruction folded away, don't need actual inst
+ N->eraseFromParent(); // Instruction folded away, don't need actual
+ // inst
N = nullptr;
}
} else {
@@ -3243,9 +3241,6 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
TranslateMap[&*BBI] = N;
}
if (N) {
- // Insert the new instruction into its new home.
- N->insertInto(EdgeBB, InsertPt);
-
// Register the new instruction with the assumption cache if necessary.
if (auto *Assume = dyn_cast<AssumeInst>(N))
if (AC)
@@ -3591,17 +3586,7 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
// If we need to invert the condition in the pred block to match, do so now.
if (InvertPredCond) {
- Value *NewCond = PBI->getCondition();
- if (NewCond->hasOneUse() && isa<CmpInst>(NewCond)) {
- CmpInst *CI = cast<CmpInst>(NewCond);
- CI->setPredicate(CI->getInversePredicate());
- } else {
- NewCond =
- Builder.CreateNot(NewCond, PBI->getCondition()->getName() + ".not");
- }
-
- PBI->setCondition(NewCond);
- PBI->swapSuccessors();
+ InvertBranch(PBI, Builder);
}
BasicBlock *UniqueSucc =
@@ -3887,7 +3872,7 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB,
for (BasicBlock *PredBB : predecessors(Succ))
if (PredBB != BB)
PHI->addIncoming(
- AlternativeV ? AlternativeV : UndefValue::get(V->getType()), PredBB);
+ AlternativeV ? AlternativeV : PoisonValue::get(V->getType()), PredBB);
return PHI;
}
@@ -5150,14 +5135,18 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
Value* Cond = BI->getCondition();
assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
"The destinations are guaranteed to be different here.");
+ CallInst *Assumption;
if (BI->getSuccessor(0) == BB) {
- Builder.CreateAssumption(Builder.CreateNot(Cond));
+ Assumption = Builder.CreateAssumption(Builder.CreateNot(Cond));
Builder.CreateBr(BI->getSuccessor(1));
} else {
assert(BI->getSuccessor(1) == BB && "Incorrect CFG");
- Builder.CreateAssumption(Cond);
+ Assumption = Builder.CreateAssumption(Cond);
Builder.CreateBr(BI->getSuccessor(0));
}
+ if (Options.AC)
+ Options.AC->registerAssumption(cast<AssumeInst>(Assumption));
+
EraseTerminatorAndDCECond(BI);
Changed = true;
}
@@ -5453,7 +5442,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
}
const APInt &CaseVal = Case.getCaseValue()->getValue();
if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) ||
- (CaseVal.getMinSignedBits() > MaxSignificantBitsInCond)) {
+ (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) {
DeadCases.push_back(Case.getCaseValue());
if (DTU)
--NumPerSuccessorCases[Successor];
@@ -5469,7 +5458,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
bool HasDefault =
!isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg());
const unsigned NumUnknownBits =
- Known.getBitWidth() - (Known.Zero | Known.One).countPopulation();
+ Known.getBitWidth() - (Known.Zero | Known.One).popcount();
assert(NumUnknownBits <= Known.getBitWidth());
if (HasDefault && DeadCases.empty() &&
NumUnknownBits < 64 /* avoid overflow */ &&
@@ -5860,7 +5849,7 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
// Check if cases with the same result can cover all number
// in touched bits.
- if (BitMask.countPopulation() == Log2_32(CaseCount)) {
+ if (BitMask.popcount() == Log2_32(CaseCount)) {
if (!MinCaseVal->isNullValue())
Condition = Builder.CreateSub(Condition, MinCaseVal);
Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and");
@@ -6001,6 +5990,7 @@ private:
// For LinearMapKind, these are the constants used to derive the value.
ConstantInt *LinearOffset = nullptr;
ConstantInt *LinearMultiplier = nullptr;
+ bool LinearMapValWrapped = false;
// For ArrayKind, this is the array.
GlobalVariable *Array = nullptr;
@@ -6061,6 +6051,8 @@ SwitchLookupTable::SwitchLookupTable(
bool LinearMappingPossible = true;
APInt PrevVal;
APInt DistToPrev;
+ // When the linear map is monotonic, we can attach nsw.
+ bool Wrapped = false;
assert(TableSize >= 2 && "Should be a SingleValue table.");
// Check if there is the same distance between two consecutive values.
for (uint64_t I = 0; I < TableSize; ++I) {
@@ -6080,12 +6072,15 @@ SwitchLookupTable::SwitchLookupTable(
LinearMappingPossible = false;
break;
}
+ Wrapped |=
+ Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
}
PrevVal = Val;
}
if (LinearMappingPossible) {
LinearOffset = cast<ConstantInt>(TableContents[0]);
LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev);
+ LinearMapValWrapped = Wrapped;
Kind = LinearMapKind;
++NumLinearMaps;
return;
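
The new LinearMapValWrapped flag records whether offset + idx*multiplier stays monotonic over the whole table; only then may the later multiply/add carry nsw. A standalone sketch of the same check over a plain value table (assuming 64-bit signed values in place of APInt, and that the table already passed the linear-mapping test, i.e. consecutive entries share one constant stride):

    #include <cstdint>
    #include <vector>

    // Returns true if the mapping wraps (is not monotonic in the direction of
    // its stride); in that case nsw must not be attached.
    bool linearMapWraps(const std::vector<int64_t> &Table) {
      if (Table.size() < 2)
        return false;
      const int64_t Dist = Table[1] - Table[0]; // common stride
      bool Wrapped = false;
      for (size_t I = 1; I < Table.size(); ++I)
        Wrapped |= Dist > 0 ? Table[I] <= Table[I - 1]
                            : Table[I] > Table[I - 1];
      return Wrapped;
    }
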
@@ -6134,9 +6129,14 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
Value *Result = Builder.CreateIntCast(Index, LinearMultiplier->getType(),
false, "switch.idx.cast");
if (!LinearMultiplier->isOne())
- Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult");
+ Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult",
+ /*HasNUW = */ false,
+ /*HasNSW = */ !LinearMapValWrapped);
+
if (!LinearOffset->isZero())
- Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset");
+ Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset",
+ /*HasNUW = */ false,
+ /*HasNSW = */ !LinearMapValWrapped);
return Result;
}
case BitMapKind: {
@@ -6148,10 +6148,12 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
// truncating it to the width of the bitmask is safe.
Value *ShiftAmt = Builder.CreateZExtOrTrunc(Index, MapTy, "switch.cast");
- // Multiply the shift amount by the element width.
+ // Multiply the shift amount by the element width. NUW/NSW can always be
+ // set, because WouldFitInRegister guarantees Index * ElementWidth is in
+ // BitMap's bit width.
ShiftAmt = Builder.CreateMul(
ShiftAmt, ConstantInt::get(MapTy, BitMapElementTy->getBitWidth()),
- "switch.shiftamt");
+ "switch.shiftamt",/*HasNUW =*/true,/*HasNSW =*/true);
// Shift down.
Value *DownShifted =
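
For BitMapKind tables the whole table lives in a single integer; a lookup is a shift by Index * ElementWidth followed by a mask or truncate. A standalone sketch of that bit-field decoding, which is why the multiply can safely carry nuw/nsw once the product is known to fit in the map's bit width:

    #include <cstdint>

    // Decode entry Index from a table packed into a 64-bit map, ElementWidth
    // bits per entry. Index * ElementWidth is assumed to fit in the map,
    // mirroring the WouldFitInRegister guarantee mentioned above.
    uint64_t bitmapLookup(uint64_t BitMap, unsigned Index, unsigned ElementWidth) {
      uint64_t ShiftAmt = uint64_t{Index} * ElementWidth; // no overflow here
      uint64_t Mask = (ElementWidth == 64) ? ~uint64_t{0}
                                           : ((uint64_t{1} << ElementWidth) - 1);
      return (BitMap >> ShiftAmt) & Mask;
    }
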
@@ -6490,6 +6492,21 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
std::vector<DominatorTree::UpdateType> Updates;
+ // Compute the maximum table size representable by the integer type we are
+ // switching upon.
+ unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits();
+ uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize;
+ assert(MaxTableSize >= TableSize &&
+ "It is impossible for a switch to have more entries than the max "
+ "representable value of its input integer type's size.");
+
+ // If the default destination is unreachable, or if the lookup table covers
+ // all values of the conditional variable, branch directly to the lookup table
+ // BB. Otherwise, check that the condition is within the case range.
+ const bool DefaultIsReachable =
+ !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg());
+ const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize);
+
// Create the BB that does the lookups.
Module &Mod = *CommonDest->getParent()->getParent();
BasicBlock *LookupBB = BasicBlock::Create(
@@ -6504,24 +6521,19 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
TableIndex = SI->getCondition();
} else {
TableIndexOffset = MinCaseVal;
- TableIndex =
- Builder.CreateSub(SI->getCondition(), TableIndexOffset, "switch.tableidx");
- }
+ // If the default is unreachable, all case values are s>= MinCaseVal. Then
+ // we can try to attach nsw.
+ bool MayWrap = true;
+ if (!DefaultIsReachable) {
+ APInt Res = MaxCaseVal->getValue().ssub_ov(MinCaseVal->getValue(), MayWrap);
+ (void)Res;
+ }
- // Compute the maximum table size representable by the integer type we are
- // switching upon.
- unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits();
- uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize;
- assert(MaxTableSize >= TableSize &&
- "It is impossible for a switch to have more entries than the max "
- "representable value of its input integer type's size.");
+ TableIndex = Builder.CreateSub(SI->getCondition(), TableIndexOffset,
+ "switch.tableidx", /*HasNUW =*/false,
+ /*HasNSW =*/!MayWrap);
+ }
- // If the default destination is unreachable, or if the lookup table covers
- // all values of the conditional variable, branch directly to the lookup table
- // BB. Otherwise, check that the condition is within the case range.
- const bool DefaultIsReachable =
- !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg());
- const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize);
BranchInst *RangeCheckBranch = nullptr;
if (!DefaultIsReachable || GeneratingCoveredLookupTable) {
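
The hoisted covered-table computation boils down to a simple bound on the condition's bit width. A standalone sketch of that test (function name is illustrative):

    #include <cstdint>

    // An N-bit switch condition admits at most 2^N distinct values, so a table
    // of exactly that size needs no range check in front of the lookup.
    bool coversAllConditionValues(unsigned CaseBitWidth, uint64_t TableSize) {
      uint64_t MaxTableSize =
          CaseBitWidth > 63 ? UINT64_MAX : (uint64_t{1} << CaseBitWidth);
      return TableSize == MaxTableSize;
    }
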
@@ -6694,7 +6706,7 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
// less than 64.
unsigned Shift = 64;
for (auto &V : Values)
- Shift = std::min(Shift, countTrailingZeros((uint64_t)V));
+ Shift = std::min(Shift, (unsigned)llvm::countr_zero((uint64_t)V));
assert(Shift < 64);
if (Shift > 0)
for (auto &V : Values)
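
The helper rename (countTrailingZeros to llvm::countr_zero) does not change the idea: the common shift is the minimum trailing-zero count over all case values. A standalone C++20 sketch:

    #include <algorithm>
    #include <bit>
    #include <cstdint>
    #include <vector>

    // Shifting every case value (and the switch condition) right by the
    // minimum trailing-zero count preserves all equalities.
    unsigned commonCaseShift(const std::vector<uint64_t> &Values) {
      unsigned Shift = 64;
      for (uint64_t V : Values)
        Shift = std::min(Shift, static_cast<unsigned>(std::countr_zero(V)));
      return Shift == 64 ? 0 : Shift; // 64 only if every value is zero
    }
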
@@ -6990,7 +7002,8 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
"Tautological conditional branch should have been eliminated already.");
BasicBlock *BB = BI->getParent();
- if (!Options.SimplifyCondBranch)
+ if (!Options.SimplifyCondBranch ||
+ BI->getFunction()->hasFnAttribute(Attribute::OptForFuzzing))
return false;
// Conditional branch
@@ -7045,8 +7058,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
// can hoist it up to the branching block.
if (BI->getSuccessor(0)->getSinglePredecessor()) {
if (BI->getSuccessor(1)->getSinglePredecessor()) {
- if (HoistCommon &&
- HoistThenElseCodeToIf(BI, TTI, !Options.HoistCommonInsts))
+ if (HoistCommon && HoistThenElseCodeToIf(BI, !Options.HoistCommonInsts))
return requestResimplify();
} else {
// If Successor #1 has multiple preds, we may be able to conditionally
@@ -7054,7 +7066,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
Instruction *Succ0TI = BI->getSuccessor(0)->getTerminator();
if (Succ0TI->getNumSuccessors() == 1 &&
Succ0TI->getSuccessor(0) == BI->getSuccessor(1))
- if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), TTI))
+ if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0)))
return requestResimplify();
}
} else if (BI->getSuccessor(1)->getSinglePredecessor()) {
@@ -7063,7 +7075,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
Instruction *Succ1TI = BI->getSuccessor(1)->getTerminator();
if (Succ1TI->getNumSuccessors() == 1 &&
Succ1TI->getSuccessor(0) == BI->getSuccessor(0))
- if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), TTI))
+ if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1)))
return requestResimplify();
}
@@ -7179,7 +7191,8 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu
/// If BB has an incoming value that will always trigger undefined behavior
/// (eg. null pointer dereference), remove the branch leading here.
static bool removeUndefIntroducingPredecessor(BasicBlock *BB,
- DomTreeUpdater *DTU) {
+ DomTreeUpdater *DTU,
+ AssumptionCache *AC) {
for (PHINode &PHI : BB->phis())
for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i)
if (passingValueIsAlwaysUndefined(PHI.getIncomingValue(i), &PHI)) {
@@ -7196,10 +7209,13 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB,
// Preserve guarding condition in assume, because it might not be
// inferrable from any dominating condition.
Value *Cond = BI->getCondition();
+ CallInst *Assumption;
if (BI->getSuccessor(0) == BB)
- Builder.CreateAssumption(Builder.CreateNot(Cond));
+ Assumption = Builder.CreateAssumption(Builder.CreateNot(Cond));
else
- Builder.CreateAssumption(Cond);
+ Assumption = Builder.CreateAssumption(Cond);
+ if (AC)
+ AC->registerAssumption(cast<AssumeInst>(Assumption));
Builder.CreateBr(BI->getSuccessor(0) == BB ? BI->getSuccessor(1)
: BI->getSuccessor(0));
}
@@ -7260,7 +7276,7 @@ bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) {
Changed |= EliminateDuplicatePHINodes(BB);
// Check for and remove branches that will always cause undefined behavior.
- if (removeUndefIntroducingPredecessor(BB, DTU))
+ if (removeUndefIntroducingPredecessor(BB, DTU, Options.AC))
return requestResimplify();
// Merge basic blocks into their predecessor if there is only one distinct
@@ -7282,7 +7298,8 @@ bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) {
IRBuilder<> Builder(BB);
- if (Options.FoldTwoEntryPHINode) {
+ if (Options.SpeculateBlocks &&
+ !BB->getParent()->hasFnAttribute(Attribute::OptForFuzzing)) {
// If there is a trivial two-entry PHI node in this basic block, and we can
// eliminate it, do so now.
if (auto *PN = dyn_cast<PHINode>(BB->begin()))
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index 4e83d2f6e3c6..a28916bc9baf 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -93,6 +93,7 @@ namespace {
void replaceRemWithNumeratorOrZero(BinaryOperator *Rem);
void replaceSRemWithURem(BinaryOperator *Rem);
bool eliminateSDiv(BinaryOperator *SDiv);
+ bool strengthenBinaryOp(BinaryOperator *BO, Instruction *IVOperand);
bool strengthenOverflowingOperation(BinaryOperator *OBO,
Instruction *IVOperand);
bool strengthenRightShift(BinaryOperator *BO, Instruction *IVOperand);
@@ -216,8 +217,10 @@ bool SimplifyIndvar::makeIVComparisonInvariant(ICmpInst *ICmp,
// Do not generate something ridiculous.
auto *PHTerm = Preheader->getTerminator();
- if (Rewriter.isHighCostExpansion({ InvariantLHS, InvariantRHS }, L,
- 2 * SCEVCheapExpansionBudget, TTI, PHTerm))
+ if (Rewriter.isHighCostExpansion({InvariantLHS, InvariantRHS}, L,
+ 2 * SCEVCheapExpansionBudget, TTI, PHTerm) ||
+ !Rewriter.isSafeToExpandAt(InvariantLHS, PHTerm) ||
+ !Rewriter.isSafeToExpandAt(InvariantRHS, PHTerm))
return false;
auto *NewLHS =
Rewriter.expandCodeFor(InvariantLHS, IVOperand->getType(), PHTerm);
@@ -747,6 +750,13 @@ bool SimplifyIndvar::eliminateIdentitySCEV(Instruction *UseInst,
return true;
}
+bool SimplifyIndvar::strengthenBinaryOp(BinaryOperator *BO,
+ Instruction *IVOperand) {
+ return (isa<OverflowingBinaryOperator>(BO) &&
+ strengthenOverflowingOperation(BO, IVOperand)) ||
+ (isa<ShlOperator>(BO) && strengthenRightShift(BO, IVOperand));
+}
+
/// Annotate BO with nsw / nuw if it provably does not signed-overflow /
/// unsigned-overflow. Returns true if anything changed, false otherwise.
bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO,
@@ -898,6 +908,14 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) {
if (replaceIVUserWithLoopInvariant(UseInst))
continue;
+ // Go further for the cast 'ptrtoint ptr to i64'
+ if (isa<PtrToIntInst>(UseInst))
+ for (Use &U : UseInst->uses()) {
+ Instruction *User = cast<Instruction>(U.getUser());
+ if (replaceIVUserWithLoopInvariant(User))
+ break; // done replacing
+ }
+
Instruction *IVOperand = UseOper.second;
for (unsigned N = 0; IVOperand; ++N) {
assert(N <= Simplified.size() && "runaway iteration");
@@ -917,9 +935,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) {
}
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(UseInst)) {
- if ((isa<OverflowingBinaryOperator>(BO) &&
- strengthenOverflowingOperation(BO, IVOperand)) ||
- (isa<ShlOperator>(BO) && strengthenRightShift(BO, IVOperand))) {
+ if (strengthenBinaryOp(BO, IVOperand)) {
// re-queue uses of the now modified binary operator and fall
// through to the checks that remain.
pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 20f18322d43c..5b0951252c07 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -14,11 +14,12 @@
#include "llvm/Transforms/Utils/SimplifyLibCalls.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/Triple.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -29,6 +30,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
@@ -44,6 +46,45 @@ static cl::opt<bool>
cl::desc("Enable unsafe double to float "
"shrinking for math lib calls"));
+// Enable conversion of operator new calls with a MemProf hot or cold hint
+// to an operator new call that takes a hot/cold hint. Off by default since
+// not all allocators currently support this extension.
+static cl::opt<bool>
+ OptimizeHotColdNew("optimize-hot-cold-new", cl::Hidden, cl::init(false),
+ cl::desc("Enable hot/cold operator new library calls"));
+
+namespace {
+
+// Specialized parser to ensure the hint is an 8 bit value (we can't specify
+// uint8_t to opt<> as that is interpreted to mean that we are passing a char
+// option with a specific set of values).
+struct HotColdHintParser : public cl::parser<unsigned> {
+ HotColdHintParser(cl::Option &O) : cl::parser<unsigned>(O) {}
+
+ bool parse(cl::Option &O, StringRef ArgName, StringRef Arg, unsigned &Value) {
+ if (Arg.getAsInteger(0, Value))
+ return O.error("'" + Arg + "' value invalid for uint argument!");
+
+ if (Value > 255)
+ return O.error("'" + Arg + "' value must be in the range [0, 255]!");
+
+ return false;
+ }
+};
+
+} // end anonymous namespace
+
+// Hot/cold operator new takes an 8 bit hotness hint, where 0 is the coldest
+// and 255 is the hottest. Default to 1 value away from the coldest and hottest
+// hints, so that the compiler hinted allocations are slightly less strong than
+// manually inserted hints at the two extremes.
+static cl::opt<unsigned, false, HotColdHintParser> ColdNewHintValue(
+ "cold-new-hint-value", cl::Hidden, cl::init(1),
+ cl::desc("Value to pass to hot/cold operator new for cold allocation"));
+static cl::opt<unsigned, false, HotColdHintParser> HotNewHintValue(
+ "hot-new-hint-value", cl::Hidden, cl::init(254),
+ cl::desc("Value to pass to hot/cold operator new for hot allocation"));
+
//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//
@@ -186,21 +227,9 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr,
return ConstantInt::get(RetTy, Result);
}
-static bool isOnlyUsedInComparisonWithZero(Value *V) {
- for (User *U : V->users()) {
- if (ICmpInst *IC = dyn_cast<ICmpInst>(U))
- if (Constant *C = dyn_cast<Constant>(IC->getOperand(1)))
- if (C->isNullValue())
- continue;
- // Unknown instruction.
- return false;
- }
- return true;
-}
-
static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len,
const DataLayout &DL) {
- if (!isOnlyUsedInComparisonWithZero(CI))
+ if (!isOnlyUsedInZeroComparison(CI))
return false;
if (!isDereferenceableAndAlignedPointer(Str, Align(1), APInt(64, Len), DL))
@@ -1358,6 +1387,10 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
return nullptr;
}
+ bool OptForSize = CI->getFunction()->hasOptSize() ||
+ llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI,
+ PGSOQueryType::IRPass);
+
// If the char is variable but the input str and length are not we can turn
// this memchr call into a simple bit field test. Of course this only works
// when the return value is only checked against null.
@@ -1368,7 +1401,7 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
// memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n')))
// != 0
// after bounds check.
- if (Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI))
+ if (OptForSize || Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI))
return nullptr;
unsigned char Max =
@@ -1380,8 +1413,34 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
// FIXME: On a 64 bit architecture this prevents us from using the
// interesting range of alpha ascii chars. We could do better by emitting
// two bitfields or shifting the range by 64 if no lower chars are used.
- if (!DL.fitsInLegalInteger(Max + 1))
- return nullptr;
+ if (!DL.fitsInLegalInteger(Max + 1)) {
+ // Build chain of ORs
+ // Transform:
+ // memchr("abcd", C, 4) != nullptr
+ // to:
+ // (C == 'a' || C == 'b' || C == 'c' || C == 'd') != 0
+ std::string SortedStr = Str.str();
+ llvm::sort(SortedStr);
+ // Compute the number of non-contiguous ranges.
+ unsigned NonContRanges = 1;
+ for (size_t i = 1; i < SortedStr.size(); ++i) {
+ if (SortedStr[i] > SortedStr[i - 1] + 1) {
+ NonContRanges++;
+ }
+ }
+
+ // Restrict this optimization to profitable cases with one or two range
+ // checks.
+ if (NonContRanges > 2)
+ return nullptr;
+
+ SmallVector<Value *> CharCompares;
+ for (unsigned char C : SortedStr)
+ CharCompares.push_back(
+ B.CreateICmpEQ(CharVal, ConstantInt::get(CharVal->getType(), C)));
+
+ return B.CreateIntToPtr(B.CreateOr(CharCompares), CI->getType());
+ }
// For the bit field use a power-of-2 type with at least 8 bits to avoid
// creating unnecessary illegal types.
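
When the constant haystack is too wide for the bit-field trick, the new code falls back to a short chain of equality compares, but only if the sorted characters form at most two contiguous ranges. A standalone sketch of that profitability count, plus the predicate the emitted IR ends up computing (both function names are illustrative):

    #include <algorithm>
    #include <string>

    // Number of maximal runs of consecutive characters in Str; the transform
    // is only applied when this is 1 or 2.
    unsigned countNonContiguousRanges(std::string Str) {
      std::sort(Str.begin(), Str.end());
      unsigned Ranges = Str.empty() ? 0 : 1;
      for (size_t I = 1; I < Str.size(); ++I)
        if (Str[I] > Str[I - 1] + 1)
          ++Ranges;
      return Ranges;
    }

    // What memchr("abcd", C, 4) != nullptr becomes after the fold.
    bool memchrAsCharCompares(unsigned char C) {
      return C == 'a' || C == 'b' || C == 'c' || C == 'd';
    }
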
@@ -1481,30 +1540,21 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS,
// First, see if we can fold either argument to a constant.
Value *LHSV = nullptr;
- if (auto *LHSC = dyn_cast<Constant>(LHS)) {
- LHSC = ConstantExpr::getBitCast(LHSC, IntType->getPointerTo());
+ if (auto *LHSC = dyn_cast<Constant>(LHS))
LHSV = ConstantFoldLoadFromConstPtr(LHSC, IntType, DL);
- }
+
Value *RHSV = nullptr;
- if (auto *RHSC = dyn_cast<Constant>(RHS)) {
- RHSC = ConstantExpr::getBitCast(RHSC, IntType->getPointerTo());
+ if (auto *RHSC = dyn_cast<Constant>(RHS))
RHSV = ConstantFoldLoadFromConstPtr(RHSC, IntType, DL);
- }
// Don't generate unaligned loads. If either source is constant data,
// alignment doesn't matter for that source because there is no load.
if ((LHSV || getKnownAlignment(LHS, DL, CI) >= PrefAlignment) &&
(RHSV || getKnownAlignment(RHS, DL, CI) >= PrefAlignment)) {
- if (!LHSV) {
- Type *LHSPtrTy =
- IntType->getPointerTo(LHS->getType()->getPointerAddressSpace());
- LHSV = B.CreateLoad(IntType, B.CreateBitCast(LHS, LHSPtrTy), "lhsv");
- }
- if (!RHSV) {
- Type *RHSPtrTy =
- IntType->getPointerTo(RHS->getType()->getPointerAddressSpace());
- RHSV = B.CreateLoad(IntType, B.CreateBitCast(RHS, RHSPtrTy), "rhsv");
- }
+ if (!LHSV)
+ LHSV = B.CreateLoad(IntType, LHS, "lhsv");
+ if (!RHSV)
+ RHSV = B.CreateLoad(IntType, RHS, "rhsv");
return B.CreateZExt(B.CreateICmpNE(LHSV, RHSV), CI->getType(), "memcmp");
}
}
@@ -1653,6 +1703,59 @@ Value *LibCallSimplifier::optimizeRealloc(CallInst *CI, IRBuilderBase &B) {
return nullptr;
}
+// When enabled, replace operator new() calls marked with a hot or cold memprof
+// attribute with an operator new() call that takes a __hot_cold_t parameter.
+// Currently this is supported by the open source version of tcmalloc, see:
+// https://github.com/google/tcmalloc/blob/master/tcmalloc/new_extension.h
+Value *LibCallSimplifier::optimizeNew(CallInst *CI, IRBuilderBase &B,
+ LibFunc &Func) {
+ if (!OptimizeHotColdNew)
+ return nullptr;
+
+ uint8_t HotCold;
+ if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "cold")
+ HotCold = ColdNewHintValue;
+ else if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "hot")
+ HotCold = HotNewHintValue;
+ else
+ return nullptr;
+
+ switch (Func) {
+ case LibFunc_Znwm:
+ return emitHotColdNew(CI->getArgOperand(0), B, TLI,
+ LibFunc_Znwm12__hot_cold_t, HotCold);
+ case LibFunc_Znam:
+ return emitHotColdNew(CI->getArgOperand(0), B, TLI,
+ LibFunc_Znam12__hot_cold_t, HotCold);
+ case LibFunc_ZnwmRKSt9nothrow_t:
+ return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B,
+ TLI, LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t,
+ HotCold);
+ case LibFunc_ZnamRKSt9nothrow_t:
+ return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B,
+ TLI, LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t,
+ HotCold);
+ case LibFunc_ZnwmSt11align_val_t:
+ return emitHotColdNewAligned(CI->getArgOperand(0), CI->getArgOperand(1), B,
+ TLI, LibFunc_ZnwmSt11align_val_t12__hot_cold_t,
+ HotCold);
+ case LibFunc_ZnamSt11align_val_t:
+ return emitHotColdNewAligned(CI->getArgOperand(0), CI->getArgOperand(1), B,
+ TLI, LibFunc_ZnamSt11align_val_t12__hot_cold_t,
+ HotCold);
+ case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
+ return emitHotColdNewAlignedNoThrow(
+ CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B,
+ TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold);
+ case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
+ return emitHotColdNewAlignedNoThrow(
+ CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B,
+ TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold);
+ default:
+ return nullptr;
+ }
+}
+
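
optimizeNew above only decides which hint byte to pass before dispatching to the matching __hot_cold_t emitter. A standalone sketch of that hint selection, using the default values of the -cold-new-hint-value and -hot-new-hint-value options shown earlier (helper name is illustrative):

    #include <cstdint>
    #include <optional>
    #include <string_view>

    // "cold"/"hot" come from the call's memprof attribute; any other value
    // means the call is left untouched.
    std::optional<uint8_t> selectHotColdHint(std::string_view MemProfAttr) {
      if (MemProfAttr == "cold")
        return uint8_t{1};   // one step above the coldest hint (0)
      if (MemProfAttr == "hot")
        return uint8_t{254}; // one step below the hottest hint (255)
      return std::nullopt;
    }
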
//===----------------------------------------------------------------------===//
// Math Library Optimizations
//===----------------------------------------------------------------------===//
@@ -1939,7 +2042,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) {
AttributeList NoAttrs; // Attributes are only meaningful on the original call
// pow(2.0, itofp(x)) -> ldexp(1.0, x)
- if (match(Base, m_SpecificFP(2.0)) &&
+ // TODO: This does not work for vectors because there is no ldexp intrinsic.
+ if (!Ty->isVectorTy() && match(Base, m_SpecificFP(2.0)) &&
(isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) &&
hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) {
if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize()))
@@ -2056,7 +2160,7 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) {
// pow(-Inf, 0.5) is optionally required to have a result of +Inf (not setting
// errno), but sqrt(-Inf) is required by various standards to set errno.
if (!Pow->doesNotAccessMemory() && !Pow->hasNoInfs() &&
- !isKnownNeverInfinity(Base, TLI))
+ !isKnownNeverInfinity(Base, DL, TLI, 0, AC, Pow))
return nullptr;
Sqrt = getSqrtCall(Base, AttributeList(), Pow->doesNotAccessMemory(), Mod, B,
@@ -2217,17 +2321,25 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) {
hasFloatVersion(M, Name))
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
+ // Bail out for vectors because the code below only expects scalars.
+ // TODO: This could be allowed if we had a ldexp intrinsic (D14327).
Type *Ty = CI->getType();
- Value *Op = CI->getArgOperand(0);
+ if (Ty->isVectorTy())
+ return Ret;
// exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= IntSize
// exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < IntSize
+ Value *Op = CI->getArgOperand(0);
if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) &&
hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) {
- if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize()))
- return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI,
- LibFunc_ldexp, LibFunc_ldexpf,
- LibFunc_ldexpl, B, AttributeList());
+ if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) {
+ IRBuilderBase::FastMathFlagGuard Guard(B);
+ B.setFastMathFlags(CI->getFastMathFlags());
+ return copyFlags(
+ *CI, emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI,
+ LibFunc_ldexp, LibFunc_ldexpf,
+ LibFunc_ldexpl, B, AttributeList()));
+ }
}
return Ret;
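Editorial aside: the hunks above rewrite exp2(itofp(x)) (and, earlier, pow(2.0, itofp(x))) into an ldexp libcall. A minimal standalone check of the underlying identity, assuming an ordinary libm; it samples a few comfortably in-range integers and says nothing about overflow, NaNs, or the vector case that is bailed on.

    // Sanity check: for integral N in range, exp2(N) and ldexp(1.0, N) yield
    // the same exact power of two on common libms.
    #include <cmath>
    #include <cstdio>

    int main() {
      for (int N : {-10, -1, 0, 1, 7, 30}) {
        double ViaExp2 = std::exp2(static_cast<double>(N));
        double ViaLdexp = std::ldexp(1.0, N);
        std::printf("exp2(%d) = %g, ldexp(1.0, %d) = %g\n", N, ViaExp2, N,
                    ViaLdexp);
      }
      return 0;
    }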
@@ -2579,7 +2691,7 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
return true;
}
-Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin,
+                                           IRBuilderBase &B) {
// Make sure the prototype is as expected, otherwise the rest of the
// function is probably invalid and likely to abort.
if (!isTrigLibCall(CI))
@@ -2618,7 +2730,7 @@ Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) {
replaceTrigInsts(CosCalls, Cos);
replaceTrigInsts(SinCosCalls, SinCos);
- return nullptr;
+ return IsSin ? Sin : Cos;
}
void LibCallSimplifier::classifyArgUse(
@@ -3439,6 +3551,15 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,
return optimizeWcslen(CI, Builder);
case LibFunc_bcopy:
return optimizeBCopy(CI, Builder);
+ case LibFunc_Znwm:
+ case LibFunc_ZnwmRKSt9nothrow_t:
+ case LibFunc_ZnwmSt11align_val_t:
+ case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
+ case LibFunc_Znam:
+ case LibFunc_ZnamRKSt9nothrow_t:
+ case LibFunc_ZnamSt11align_val_t:
+ case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
+ return optimizeNew(CI, Builder, Func);
default:
break;
}
@@ -3461,9 +3582,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
switch (Func) {
case LibFunc_sinpif:
case LibFunc_sinpi:
+ return optimizeSinCosPi(CI, /*IsSin*/true, Builder);
case LibFunc_cospif:
case LibFunc_cospi:
- return optimizeSinCosPi(CI, Builder);
+ return optimizeSinCosPi(CI, /*IsSin*/false, Builder);
case LibFunc_powf:
case LibFunc_pow:
case LibFunc_powl:
@@ -3696,13 +3818,13 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) {
}
LibCallSimplifier::LibCallSimplifier(
- const DataLayout &DL, const TargetLibraryInfo *TLI,
- OptimizationRemarkEmitter &ORE,
- BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI,
+ const DataLayout &DL, const TargetLibraryInfo *TLI, AssumptionCache *AC,
+ OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
+ ProfileSummaryInfo *PSI,
function_ref<void(Instruction *, Value *)> Replacer,
function_ref<void(Instruction *)> Eraser)
- : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), BFI(BFI), PSI(PSI),
- Replacer(Replacer), Eraser(Eraser) {}
+ : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), AC(AC), ORE(ORE), BFI(BFI),
+ PSI(PSI), Replacer(Replacer), Eraser(Eraser) {}
void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {
// Indirect through the replacer used in this instance.
diff --git a/llvm/lib/Transforms/Utils/SizeOpts.cpp b/llvm/lib/Transforms/Utils/SizeOpts.cpp
index 1242380f73c1..1ca2e0e6ebb9 100644
--- a/llvm/lib/Transforms/Utils/SizeOpts.cpp
+++ b/llvm/lib/Transforms/Utils/SizeOpts.cpp
@@ -98,14 +98,12 @@ struct BasicBlockBFIAdapter {
bool llvm::shouldOptimizeForSize(const Function *F, ProfileSummaryInfo *PSI,
BlockFrequencyInfo *BFI,
PGSOQueryType QueryType) {
- return shouldFuncOptimizeForSizeImpl<BasicBlockBFIAdapter>(F, PSI, BFI,
- QueryType);
+ return shouldFuncOptimizeForSizeImpl(F, PSI, BFI, QueryType);
}
bool llvm::shouldOptimizeForSize(const BasicBlock *BB, ProfileSummaryInfo *PSI,
BlockFrequencyInfo *BFI,
PGSOQueryType QueryType) {
assert(BB);
- return shouldOptimizeForSizeImpl<BasicBlockBFIAdapter>(BB, PSI, BFI,
- QueryType);
+ return shouldOptimizeForSizeImpl(BB, PSI, BFI, QueryType);
}
diff --git a/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp b/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp
index 10fda4df51ba..618c6bab3a8f 100644
--- a/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp
+++ b/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp
@@ -8,44 +8,13 @@
#include "llvm/Transforms/Utils/StripNonLineTableDebugInfo.h"
#include "llvm/IR/DebugInfo.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/Transforms/Utils.h"
-using namespace llvm;
-
-namespace {
-
-/// This pass strips all debug info that is not related line tables.
-/// The result will be the same as if the program where compiled with
-/// -gline-tables-only.
-struct StripNonLineTableDebugLegacyPass : public ModulePass {
- static char ID; // Pass identification, replacement for typeid
- StripNonLineTableDebugLegacyPass() : ModulePass(ID) {
- initializeStripNonLineTableDebugLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
-
- bool runOnModule(Module &M) override {
- return llvm::stripNonLineTableDebugInfo(M);
- }
-};
-}
-
-char StripNonLineTableDebugLegacyPass::ID = 0;
-INITIALIZE_PASS(StripNonLineTableDebugLegacyPass,
- "strip-nonlinetable-debuginfo",
- "Strip all debug info except linetables", false, false)
-
-ModulePass *llvm::createStripNonLineTableDebugLegacyPass() {
- return new StripNonLineTableDebugLegacyPass();
-}
+using namespace llvm;
PreservedAnalyses
StripNonLineTableDebugInfoPass::run(Module &M, ModuleAnalysisManager &AM) {
llvm::stripNonLineTableDebugInfo(M);
- return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserveSet<CFGAnalyses>();
+ return PA;
}
diff --git a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp
index 4ad16d622e8d..c3ae43e567b0 100644
--- a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp
+++ b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp
@@ -517,37 +517,6 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
return true;
}
-namespace {
-
-class RewriteSymbolsLegacyPass : public ModulePass {
-public:
- static char ID; // Pass identification, replacement for typeid
-
- RewriteSymbolsLegacyPass();
- RewriteSymbolsLegacyPass(SymbolRewriter::RewriteDescriptorList &DL);
-
- bool runOnModule(Module &M) override;
-
-private:
- RewriteSymbolPass Impl;
-};
-
-} // end anonymous namespace
-
-char RewriteSymbolsLegacyPass::ID = 0;
-
-RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID) {
- initializeRewriteSymbolsLegacyPassPass(*PassRegistry::getPassRegistry());
-}
-
-RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass(
- SymbolRewriter::RewriteDescriptorList &DL)
- : ModulePass(ID), Impl(DL) {}
-
-bool RewriteSymbolsLegacyPass::runOnModule(Module &M) {
- return Impl.runImpl(M);
-}
-
PreservedAnalyses RewriteSymbolPass::run(Module &M, ModuleAnalysisManager &AM) {
if (!runImpl(M))
return PreservedAnalyses::all();
@@ -572,15 +541,3 @@ void RewriteSymbolPass::loadAndParseMapFiles() {
for (const auto &MapFile : MapFiles)
Parser.parse(MapFile, &Descriptors);
}
-
-INITIALIZE_PASS(RewriteSymbolsLegacyPass, "rewrite-symbols", "Rewrite Symbols",
- false, false)
-
-ModulePass *llvm::createRewriteSymbolsPass() {
- return new RewriteSymbolsLegacyPass();
-}
-
-ModulePass *
-llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
- return new RewriteSymbolsLegacyPass(DL);
-}
diff --git a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp
index 3be96ebc93a2..8c781f59ff5a 100644
--- a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp
+++ b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp
@@ -113,7 +113,7 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L,
}
}
- for (auto II : ExternalUsers) {
+ for (const auto &II : ExternalUsers) {
// For each Def used outside the loop, create NewPhi in
// LoopExitBlock. NewPhi receives Def only along exiting blocks that
// dominate it, while the remaining values are undefined since those paths
@@ -130,7 +130,7 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L,
NewPhi->addIncoming(Def, In);
} else {
LLVM_DEBUG(dbgs() << "not dominated\n");
- NewPhi->addIncoming(UndefValue::get(Def->getType()), In);
+ NewPhi->addIncoming(PoisonValue::get(Def->getType()), In);
}
}
diff --git a/llvm/lib/Transforms/Utils/Utils.cpp b/llvm/lib/Transforms/Utils/Utils.cpp
index d002922cfd30..91c743f17764 100644
--- a/llvm/lib/Transforms/Utils/Utils.cpp
+++ b/llvm/lib/Transforms/Utils/Utils.cpp
@@ -12,9 +12,6 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils.h"
-#include "llvm-c/Initialization.h"
-#include "llvm-c/Transforms/Utils.h"
-#include "llvm/IR/LegacyPassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
@@ -24,42 +21,18 @@ using namespace llvm;
/// initializeTransformUtils - Initialize all passes in the TransformUtils
/// library.
void llvm::initializeTransformUtils(PassRegistry &Registry) {
- initializeAddDiscriminatorsLegacyPassPass(Registry);
- initializeAssumeSimplifyPassLegacyPassPass(Registry);
initializeAssumeBuilderPassLegacyPassPass(Registry);
initializeBreakCriticalEdgesPass(Registry);
initializeCanonicalizeFreezeInLoopsPass(Registry);
- initializeInstNamerPass(Registry);
initializeLCSSAWrapperPassPass(Registry);
- initializeLibCallsShrinkWrapLegacyPassPass(Registry);
initializeLoopSimplifyPass(Registry);
initializeLowerGlobalDtorsLegacyPassPass(Registry);
initializeLowerInvokeLegacyPassPass(Registry);
initializeLowerSwitchLegacyPassPass(Registry);
initializePromoteLegacyPassPass(Registry);
- initializeStripNonLineTableDebugLegacyPassPass(Registry);
initializeUnifyFunctionExitNodesLegacyPassPass(Registry);
- initializeMetaRenamerPass(Registry);
initializeStripGCRelocatesLegacyPass(Registry);
initializePredicateInfoPrinterLegacyPassPass(Registry);
- initializeInjectTLIMappingsLegacyPass(Registry);
initializeFixIrreduciblePass(Registry);
initializeUnifyLoopExitsLegacyPassPass(Registry);
}
-
-/// LLVMInitializeTransformUtils - C binding for initializeTransformUtilsPasses.
-void LLVMInitializeTransformUtils(LLVMPassRegistryRef R) {
- initializeTransformUtils(*unwrap(R));
-}
-
-void LLVMAddLowerSwitchPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLowerSwitchPass());
-}
-
-void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createPromoteMemoryToRegisterPass());
-}
-
-void LLVMAddAddDiscriminatorsPass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createAddDiscriminatorsPass());
-}
diff --git a/llvm/lib/Transforms/Utils/VNCoercion.cpp b/llvm/lib/Transforms/Utils/VNCoercion.cpp
index f295a7e312b6..7a597da2bc51 100644
--- a/llvm/lib/Transforms/Utils/VNCoercion.cpp
+++ b/llvm/lib/Transforms/Utils/VNCoercion.cpp
@@ -226,91 +226,6 @@ int analyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr,
DL);
}
-/// Looks at a memory location for a load (specified by MemLocBase, Offs, and
-/// Size) and compares it against a load.
-///
-/// If the specified load could be safely widened to a larger integer load
-/// that is 1) still efficient, 2) safe for the target, and 3) would provide
-/// the specified memory location value, then this function returns the size
-/// in bytes of the load width to use. If not, this returns zero.
-static unsigned getLoadLoadClobberFullWidthSize(const Value *MemLocBase,
- int64_t MemLocOffs,
- unsigned MemLocSize,
- const LoadInst *LI) {
- // We can only extend simple integer loads.
- if (!isa<IntegerType>(LI->getType()) || !LI->isSimple())
- return 0;
-
- // Load widening is hostile to ThreadSanitizer: it may cause false positives
- // or make the reports more cryptic (access sizes are wrong).
- if (LI->getParent()->getParent()->hasFnAttribute(Attribute::SanitizeThread))
- return 0;
-
- const DataLayout &DL = LI->getModule()->getDataLayout();
-
- // Get the base of this load.
- int64_t LIOffs = 0;
- const Value *LIBase =
- GetPointerBaseWithConstantOffset(LI->getPointerOperand(), LIOffs, DL);
-
- // If the two pointers are not based on the same pointer, we can't tell that
- // they are related.
- if (LIBase != MemLocBase)
- return 0;
-
- // Okay, the two values are based on the same pointer, but returned as
- // no-alias. This happens when we have things like two byte loads at "P+1"
- // and "P+3". Check to see if increasing the size of the "LI" load up to its
- // alignment (or the largest native integer type) will allow us to load all
- // the bits required by MemLoc.
-
- // If MemLoc is before LI, then no widening of LI will help us out.
- if (MemLocOffs < LIOffs)
- return 0;
-
- // Get the alignment of the load in bytes. We assume that it is safe to load
- // any legal integer up to this size without a problem. For example, if we're
- // looking at an i8 load on x86-32 that is known 1024 byte aligned, we can
- // widen it up to an i32 load. If it is known 2-byte aligned, we can widen it
- // to i16.
- unsigned LoadAlign = LI->getAlign().value();
-
- int64_t MemLocEnd = MemLocOffs + MemLocSize;
-
- // If no amount of rounding up will let MemLoc fit into LI, then bail out.
- if (LIOffs + LoadAlign < MemLocEnd)
- return 0;
-
- // This is the size of the load to try. Start with the next larger power of
- // two.
- unsigned NewLoadByteSize = LI->getType()->getPrimitiveSizeInBits() / 8U;
- NewLoadByteSize = NextPowerOf2(NewLoadByteSize);
-
- while (true) {
- // If this load size is bigger than our known alignment or would not fit
- // into a native integer register, then we fail.
- if (NewLoadByteSize > LoadAlign ||
- !DL.fitsInLegalInteger(NewLoadByteSize * 8))
- return 0;
-
- if (LIOffs + NewLoadByteSize > MemLocEnd &&
- (LI->getParent()->getParent()->hasFnAttribute(
- Attribute::SanitizeAddress) ||
- LI->getParent()->getParent()->hasFnAttribute(
- Attribute::SanitizeHWAddress)))
- // We will be reading past the location accessed by the original program.
- // While this is safe in a regular build, Address Safety analysis tools
- // may start reporting false warnings. So, don't do widening.
- return 0;
-
- // If a load of this width would include all of MemLoc, then we succeed.
- if (LIOffs + NewLoadByteSize >= MemLocEnd)
- return NewLoadByteSize;
-
- NewLoadByteSize <<= 1;
- }
-}
-
/// This function is called when we have a
/// memdep query of a load that ends up being clobbered by another load. See if
/// the other load can feed into the second load.
@@ -325,28 +240,7 @@ int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI,
Value *DepPtr = DepLI->getPointerOperand();
uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()).getFixedValue();
- int R = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL);
- if (R != -1)
- return R;
-
- // If we have a load/load clobber an DepLI can be widened to cover this load,
- // then we should widen it!
- int64_t LoadOffs = 0;
- const Value *LoadBase =
- GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL);
- unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue();
-
- unsigned Size =
- getLoadLoadClobberFullWidthSize(LoadBase, LoadOffs, LoadSize, DepLI);
- if (Size == 0)
- return -1;
-
- // Check non-obvious conditions enforced by MDA which we rely on for being
- // able to materialize this potentially available value
- assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!");
- assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load");
-
- return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size * 8, DL);
+ return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL);
}
int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr,
@@ -438,83 +332,27 @@ static Value *getStoreValueForLoadHelper(Value *SrcVal, unsigned Offset,
return SrcVal;
}
-/// This function is called when we have a memdep query of a load that ends up
-/// being a clobbering store. This means that the store provides bits used by
-/// the load but the pointers don't must-alias. Check this case to see if
-/// there is anything more we can do before we give up.
-Value *getStoreValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy,
- Instruction *InsertPt, const DataLayout &DL) {
+Value *getValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy,
+ Instruction *InsertPt, const DataLayout &DL) {
+#ifndef NDEBUG
+ unsigned SrcValSize = DL.getTypeStoreSize(SrcVal->getType()).getFixedValue();
+ unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue();
+ assert(Offset + LoadSize <= SrcValSize);
+#endif
IRBuilder<> Builder(InsertPt);
SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, Builder, DL);
return coerceAvailableValueToLoadType(SrcVal, LoadTy, Builder, DL);
}
-Constant *getConstantStoreValueForLoad(Constant *SrcVal, unsigned Offset,
- Type *LoadTy, const DataLayout &DL) {
- return ConstantFoldLoadFromConst(SrcVal, LoadTy, APInt(32, Offset), DL);
-}
-
-/// This function is called when we have a memdep query of a load that ends up
-/// being a clobbering load. This means that the load *may* provide bits used
-/// by the load but we can't be sure because the pointers don't must-alias.
-/// Check this case to see if there is anything more we can do before we give
-/// up.
-Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy,
- Instruction *InsertPt, const DataLayout &DL) {
- // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to
- // widen SrcVal out to a larger load.
- unsigned SrcValStoreSize =
- DL.getTypeStoreSize(SrcVal->getType()).getFixedValue();
+Constant *getConstantValueForLoad(Constant *SrcVal, unsigned Offset,
+ Type *LoadTy, const DataLayout &DL) {
+#ifndef NDEBUG
+ unsigned SrcValSize = DL.getTypeStoreSize(SrcVal->getType()).getFixedValue();
unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue();
- if (Offset + LoadSize > SrcValStoreSize) {
- assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!");
- assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load");
- // If we have a load/load clobber an DepLI can be widened to cover this
- // load, then we should widen it to the next power of 2 size big enough!
- unsigned NewLoadSize = Offset + LoadSize;
- if (!isPowerOf2_32(NewLoadSize))
- NewLoadSize = NextPowerOf2(NewLoadSize);
-
- Value *PtrVal = SrcVal->getPointerOperand();
- // Insert the new load after the old load. This ensures that subsequent
- // memdep queries will find the new load. We can't easily remove the old
- // load completely because it is already in the value numbering table.
- IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal));
- Type *DestTy = IntegerType::get(LoadTy->getContext(), NewLoadSize * 8);
- Type *DestPTy =
- PointerType::get(DestTy, PtrVal->getType()->getPointerAddressSpace());
- Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc());
- PtrVal = Builder.CreateBitCast(PtrVal, DestPTy);
- LoadInst *NewLoad = Builder.CreateLoad(DestTy, PtrVal);
- NewLoad->takeName(SrcVal);
- NewLoad->setAlignment(SrcVal->getAlign());
-
- LLVM_DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n");
- LLVM_DEBUG(dbgs() << "TO: " << *NewLoad << "\n");
-
- // Replace uses of the original load with the wider load. On a big endian
- // system, we need to shift down to get the relevant bits.
- Value *RV = NewLoad;
- if (DL.isBigEndian())
- RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8);
- RV = Builder.CreateTrunc(RV, SrcVal->getType());
- SrcVal->replaceAllUsesWith(RV);
-
- SrcVal = NewLoad;
- }
-
- return getStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL);
-}
-
-Constant *getConstantLoadValueForLoad(Constant *SrcVal, unsigned Offset,
- Type *LoadTy, const DataLayout &DL) {
- unsigned SrcValStoreSize =
- DL.getTypeStoreSize(SrcVal->getType()).getFixedValue();
- unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue();
- if (Offset + LoadSize > SrcValStoreSize)
- return nullptr;
- return getConstantStoreValueForLoad(SrcVal, Offset, LoadTy, DL);
+ assert(Offset + LoadSize <= SrcValSize);
+#endif
+ return ConstantFoldLoadFromConst(SrcVal, LoadTy, APInt(32, Offset), DL);
}
/// This function is called when we have a
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index a5edbb2acc6d..3446e31cc2ef 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -523,10 +523,14 @@ Value *Mapper::mapValue(const Value *V) {
if (isa<ConstantVector>(C))
return getVM()[V] = ConstantVector::get(Ops);
// If this is a no-operand constant, it must be because the type was remapped.
+ if (isa<PoisonValue>(C))
+ return getVM()[V] = PoisonValue::get(NewTy);
if (isa<UndefValue>(C))
return getVM()[V] = UndefValue::get(NewTy);
if (isa<ConstantAggregateZero>(C))
return getVM()[V] = ConstantAggregateZero::get(NewTy);
+ if (isa<ConstantTargetNone>(C))
+ return getVM()[V] = Constant::getNullValue(NewTy);
assert(isa<ConstantPointerNull>(C));
return getVM()[V] = ConstantPointerNull::get(cast<PointerType>(NewTy));
}
@@ -1030,7 +1034,7 @@ void Mapper::mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix,
if (IsOldCtorDtor) {
// FIXME: This upgrade is done during linking to support the C API. See
// also IRLinker::linkAppendingVarProto() in IRMover.cpp.
- VoidPtrTy = Type::getInt8Ty(GV.getContext())->getPointerTo();
+ VoidPtrTy = PointerType::getUnqual(GV.getContext());
auto &ST = *cast<StructType>(NewMembers.front()->getType());
Type *Tys[3] = {ST.getElementType(0), ST.getElementType(1), VoidPtrTy};
EltTy = StructType::get(GV.getContext(), Tys, false);
@@ -1179,6 +1183,10 @@ void ValueMapper::remapFunction(Function &F) {
FlushingMapper(pImpl)->remapFunction(F);
}
+void ValueMapper::remapGlobalObjectMetadata(GlobalObject &GO) {
+ FlushingMapper(pImpl)->remapGlobalObjectMetadata(GO);
+}
+
void ValueMapper::scheduleMapGlobalInitializer(GlobalVariable &GV,
Constant &Init,
unsigned MCID) {
diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index 0b7fc853dc1b..260d7889906b 100644
--- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -37,13 +37,34 @@
// multiple scalar registers, similar to a GPU vectorized load. In theory ARM
// could use this pass (with some modifications), but currently it implements
// its own pass to do something similar to what we do here.
+//
+// Overview of the algorithm and terminology in this pass:
+//
+// - Break up each basic block into pseudo-BBs, composed of instructions which
+// are guaranteed to transfer control to their successors.
+// - Within a single pseudo-BB, find all loads, and group them into
+// "equivalence classes" according to getUnderlyingObject() and loaded
+// element size. Do the same for stores.
+// - For each equivalence class, greedily build "chains". Each chain has a
+// leader instruction, and every other member of the chain has a known
+// constant offset from the first instr in the chain.
+// - Break up chains so that they contain only contiguous accesses of legal
+// size with no intervening may-alias instrs.
+// - Convert each chain to vector instructions.
+//
+// The O(n^2) behavior of this pass comes from initially building the chains.
+// In the worst case we have to compare each new instruction to all of those
+// that came before. To limit this, we only calculate the offset to the leaders
+// of the N most recently-used chains.
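Editorial illustration of the access pattern the algorithm above targets (ordinary C++, not part of the pass): the four field loads below share the same underlying object and element size, so they fall into one equivalence class, form a single contiguous chain, and may be emitted as one <4 x i32> load where alignment and TTI allow it.

    // Toy input pattern: four adjacent i32 loads from the same underlying
    // object, with offsets 0, 4, 8, 12 bytes from the chain leader.
    struct Packet {
      int A, B, C, D;
    };

    int SumPacket(const Packet *P) {
      // One equivalence class, one contiguous chain -> candidate for a single
      // 16-byte vector load.
      return P->A + P->B + P->C + P->D;
    }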
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
@@ -57,6 +78,7 @@
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
@@ -67,23 +89,33 @@
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/ModRef.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"
#include <algorithm>
#include <cassert>
+#include <cstdint>
#include <cstdlib>
+#include <iterator>
+#include <limits>
+#include <numeric>
+#include <optional>
#include <tuple>
+#include <type_traits>
#include <utility>
+#include <vector>
using namespace llvm;
@@ -92,21 +124,115 @@ using namespace llvm;
STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");
+namespace {
+
+// Equivalence class key, the initial tuple by which we group loads/stores.
+// Loads/stores with different EqClassKeys are never merged.
+//
+// (We could in theory remove element-size from this tuple. We'd just need
+// to fix up the vector packing/unpacking code.)
+using EqClassKey =
+ std::tuple<const Value * /* result of getUnderlyingObject() */,
+ unsigned /* AddrSpace */,
+ unsigned /* Load/Store element size bits */,
+ char /* IsLoad; char b/c bool can't be a DenseMap key */
+ >;
+[[maybe_unused]] llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
+ const EqClassKey &K) {
+ const auto &[UnderlyingObject, AddrSpace, ElementSize, IsLoad] = K;
+ return OS << (IsLoad ? "load" : "store") << " of " << *UnderlyingObject
+ << " of element size " << ElementSize << " bits in addrspace "
+ << AddrSpace;
+}
+
+// A Chain is a set of instructions such that:
+// - All instructions have the same equivalence class, so in particular all are
+// loads, or all are stores.
+// - We know the address accessed by the i'th chain elem relative to the
+// chain's leader instruction, which is the first instr of the chain in BB
+// order.
+//
+// Chains have two canonical orderings:
+// - BB order, sorted by Instr->comesBefore.
+// - Offset order, sorted by OffsetFromLeader.
+// This pass switches back and forth between these orders.
+struct ChainElem {
+ Instruction *Inst;
+ APInt OffsetFromLeader;
+};
+using Chain = SmallVector<ChainElem, 1>;
+
+void sortChainInBBOrder(Chain &C) {
+ sort(C, [](auto &A, auto &B) { return A.Inst->comesBefore(B.Inst); });
+}
+
+void sortChainInOffsetOrder(Chain &C) {
+ sort(C, [](const auto &A, const auto &B) {
+ if (A.OffsetFromLeader != B.OffsetFromLeader)
+ return A.OffsetFromLeader.slt(B.OffsetFromLeader);
+ return A.Inst->comesBefore(B.Inst); // stable tiebreaker
+ });
+}
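Editorial aside: a minimal standalone analogue of the two canonical orderings described above, with a simplified stand-in for ChainElem (the real one holds an Instruction pointer and an APInt; BBIndex here plays the role of Instruction::comesBefore).

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct SimpleElem {
      int BBIndex;               // position within the basic block
      int64_t OffsetFromLeader;  // byte offset from the chain's first instr
    };

    void sortInBBOrder(std::vector<SimpleElem> &C) {
      std::sort(C.begin(), C.end(),
                [](const SimpleElem &A, const SimpleElem &B) {
                  return A.BBIndex < B.BBIndex;
                });
    }

    void sortInOffsetOrder(std::vector<SimpleElem> &C) {
      // Sort by offset, with BB position as a stable tiebreaker, mirroring
      // sortChainInOffsetOrder above.
      std::sort(C.begin(), C.end(),
                [](const SimpleElem &A, const SimpleElem &B) {
                  if (A.OffsetFromLeader != B.OffsetFromLeader)
                    return A.OffsetFromLeader < B.OffsetFromLeader;
                  return A.BBIndex < B.BBIndex;
                });
    }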
+
+[[maybe_unused]] void dumpChain(ArrayRef<ChainElem> C) {
+ for (const auto &E : C) {
+ dbgs() << " " << *E.Inst << " (offset " << E.OffsetFromLeader << ")\n";
+ }
+}
+
+using EquivalenceClassMap =
+ MapVector<EqClassKey, SmallVector<Instruction *, 8>>;
+
// FIXME: Assuming stack alignment of 4 is always good enough
-static const unsigned StackAdjustedAlignment = 4;
+constexpr unsigned StackAdjustedAlignment = 4;
-namespace {
+Instruction *propagateMetadata(Instruction *I, const Chain &C) {
+ SmallVector<Value *, 8> Values;
+ for (const ChainElem &E : C)
+ Values.push_back(E.Inst);
+ return propagateMetadata(I, Values);
+}
-/// ChainID is an arbitrary token that is allowed to be different only for the
-/// accesses that are guaranteed to be considered non-consecutive by
-/// Vectorizer::isConsecutiveAccess. It's used for grouping instructions
-/// together and reducing the number of instructions the main search operates on
-/// at a time, i.e. this is to reduce compile time and nothing else as the main
-/// search has O(n^2) time complexity. The underlying type of ChainID should not
-/// be relied upon.
-using ChainID = const Value *;
-using InstrList = SmallVector<Instruction *, 8>;
-using InstrListMap = MapVector<ChainID, InstrList>;
+bool isInvariantLoad(const Instruction *I) {
+ const LoadInst *LI = dyn_cast<LoadInst>(I);
+ return LI != nullptr && LI->hasMetadata(LLVMContext::MD_invariant_load);
+}
+
+/// Reorders the instructions that I depends on (the instructions defining its
+/// operands), to ensure they dominate I.
+void reorder(Instruction *I) {
+ SmallPtrSet<Instruction *, 16> InstructionsToMove;
+ SmallVector<Instruction *, 16> Worklist;
+
+ Worklist.push_back(I);
+ while (!Worklist.empty()) {
+ Instruction *IW = Worklist.pop_back_val();
+ int NumOperands = IW->getNumOperands();
+ for (int i = 0; i < NumOperands; i++) {
+ Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i));
+ if (!IM || IM->getOpcode() == Instruction::PHI)
+ continue;
+
+ // If IM is in another BB, no need to move it, because this pass only
+ // vectorizes instructions within one BB.
+ if (IM->getParent() != I->getParent())
+ continue;
+
+ if (!IM->comesBefore(I)) {
+ InstructionsToMove.insert(IM);
+ Worklist.push_back(IM);
+ }
+ }
+ }
+
+ // All instructions to move should follow I. Start from I, not from begin().
+ for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;) {
+ Instruction *IM = &*(BBI++);
+ if (!InstructionsToMove.count(IM))
+ continue;
+ IM->moveBefore(I);
+ }
+}
class Vectorizer {
Function &F;
@@ -118,6 +244,12 @@ class Vectorizer {
const DataLayout &DL;
IRBuilder<> Builder;
+ // We could erase instrs right after vectorizing them, but that can mess up
+ // our BB iterators, and also can make the equivalence class keys point to
+ // freed memory. This is fixable, but it's simpler just to wait until we're
+ // done with the BB and erase all at once.
+ SmallVector<Instruction *, 128> ToErase;
+
public:
Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC,
DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI)
@@ -127,70 +259,83 @@ public:
bool run();
private:
- unsigned getPointerAddressSpace(Value *I);
-
static const unsigned MaxDepth = 3;
- bool isConsecutiveAccess(Value *A, Value *B);
- bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta,
- unsigned Depth = 0) const;
- bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta,
- unsigned Depth) const;
- bool lookThroughSelects(Value *PtrA, Value *PtrB, const APInt &PtrDelta,
- unsigned Depth) const;
-
- /// After vectorization, reorder the instructions that I depends on
- /// (the instructions defining its operands), to ensure they dominate I.
- void reorder(Instruction *I);
-
- /// Returns the first and the last instructions in Chain.
- std::pair<BasicBlock::iterator, BasicBlock::iterator>
- getBoundaryInstrs(ArrayRef<Instruction *> Chain);
-
- /// Erases the original instructions after vectorizing.
- void eraseInstructions(ArrayRef<Instruction *> Chain);
-
- /// "Legalize" the vector type that would be produced by combining \p
- /// ElementSizeBits elements in \p Chain. Break into two pieces such that the
- /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is
- /// expected to have more than 4 elements.
- std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
- splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBits);
-
- /// Finds the largest prefix of Chain that's vectorizable, checking for
- /// intervening instructions which may affect the memory accessed by the
- /// instructions within Chain.
+ /// Runs the vectorizer on a "pseudo basic block", which is a range of
+ /// instructions [Begin, End) within one BB all of which have
+ /// isGuaranteedToTransferExecutionToSuccessor(I) == true.
+ bool runOnPseudoBB(BasicBlock::iterator Begin, BasicBlock::iterator End);
+
+ /// Runs the vectorizer on one equivalence class, i.e. one set of loads/stores
+ /// in the same BB with the same value for getUnderlyingObject() etc.
+ bool runOnEquivalenceClass(const EqClassKey &EqClassKey,
+ ArrayRef<Instruction *> EqClass);
+
+ /// Runs the vectorizer on one chain, i.e. a subset of an equivalence class
+ /// where all instructions access a known, constant offset from the first
+ /// instruction.
+ bool runOnChain(Chain &C);
+
+ /// Splits the chain into subchains of instructions which read/write a
+ /// contiguous block of memory. Discards any length-1 subchains (because
+ /// there's nothing to vectorize in there).
+ std::vector<Chain> splitChainByContiguity(Chain &C);
+
+  /// Splits the chain into subchains where it's safe to hoist loads up to the
+  /// beginning of the sub-chain and safe to sink stores down to the end of the
+  /// sub-chain. Discards any length-1 subchains.
+ std::vector<Chain> splitChainByMayAliasInstrs(Chain &C);
+
+ /// Splits the chain into subchains that make legal, aligned accesses.
+ /// Discards any length-1 subchains.
+ std::vector<Chain> splitChainByAlignment(Chain &C);
+
+ /// Converts the instrs in the chain into a single vectorized load or store.
+ /// Adds the old scalar loads/stores to ToErase.
+ bool vectorizeChain(Chain &C);
+
+ /// Tries to compute the offset in bytes PtrB - PtrA.
+ std::optional<APInt> getConstantOffset(Value *PtrA, Value *PtrB,
+ Instruction *ContextInst,
+ unsigned Depth = 0);
+ std::optional<APInt> getConstantOffsetComplexAddrs(Value *PtrA, Value *PtrB,
+ Instruction *ContextInst,
+ unsigned Depth);
+ std::optional<APInt> getConstantOffsetSelects(Value *PtrA, Value *PtrB,
+ Instruction *ContextInst,
+ unsigned Depth);
+
+ /// Gets the element type of the vector that the chain will load or store.
+ /// This is nontrivial because the chain may contain elements of different
+ /// types; e.g. it's legal to have a chain that contains both i32 and float.
+ Type *getChainElemTy(const Chain &C);
+
+ /// Determines whether ChainElem can be moved up (if IsLoad) or down (if
+ /// !IsLoad) to ChainBegin -- i.e. there are no intervening may-alias
+ /// instructions.
+ ///
+ /// The map ChainElemOffsets must contain all of the elements in
+ /// [ChainBegin, ChainElem] and their offsets from some arbitrary base
+ /// address. It's ok if it contains additional entries.
+ template <bool IsLoadChain>
+ bool isSafeToMove(
+ Instruction *ChainElem, Instruction *ChainBegin,
+ const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
+
+ /// Collects loads and stores grouped by "equivalence class", where:
+ /// - all elements in an eq class are a load or all are a store,
+ /// - they all load/store the same element size (it's OK to have e.g. i8 and
+ /// <4 x i8> in the same class, but not i32 and <4 x i8>), and
+ /// - they all have the same value for getUnderlyingObject().
+ EquivalenceClassMap collectEquivalenceClasses(BasicBlock::iterator Begin,
+ BasicBlock::iterator End);
+
+ /// Partitions Instrs into "chains" where every instruction has a known
+ /// constant offset from the first instr in the chain.
///
- /// The elements of \p Chain must be all loads or all stores and must be in
- /// address order.
- ArrayRef<Instruction *> getVectorizablePrefix(ArrayRef<Instruction *> Chain);
-
- /// Collects load and store instructions to vectorize.
- std::pair<InstrListMap, InstrListMap> collectInstructions(BasicBlock *BB);
-
- /// Processes the collected instructions, the \p Map. The values of \p Map
- /// should be all loads or all stores.
- bool vectorizeChains(InstrListMap &Map);
-
- /// Finds the load/stores to consecutive memory addresses and vectorizes them.
- bool vectorizeInstructions(ArrayRef<Instruction *> Instrs);
-
- /// Vectorizes the load instructions in Chain.
- bool
- vectorizeLoadChain(ArrayRef<Instruction *> Chain,
- SmallPtrSet<Instruction *, 16> *InstructionsProcessed);
-
- /// Vectorizes the store instructions in Chain.
- bool
- vectorizeStoreChain(ArrayRef<Instruction *> Chain,
- SmallPtrSet<Instruction *, 16> *InstructionsProcessed);
-
- /// Check if this load/store access is misaligned accesses.
- /// Returns a \p RelativeSpeed of an operation if allowed suitable to
- /// compare to another result for the same \p AddressSpace and potentially
- /// different \p Alignment and \p SzInBytes.
- bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
- Align Alignment, unsigned &RelativeSpeed);
+  /// Postcondition: For all i, ret[i][0].OffsetFromLeader == 0, because the
+  /// first instr in the chain is the leader, and an instr is at distance 0
+  /// from itself.
+ std::vector<Chain> gatherChains(ArrayRef<Instruction *> Instrs);
};
class LoadStoreVectorizerLegacyPass : public FunctionPass {
@@ -198,7 +343,8 @@ public:
static char ID;
LoadStoreVectorizerLegacyPass() : FunctionPass(ID) {
- initializeLoadStoreVectorizerLegacyPassPass(*PassRegistry::getPassRegistry());
+ initializeLoadStoreVectorizerLegacyPassPass(
+ *PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override;
@@ -250,11 +396,11 @@ bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) {
AssumptionCache &AC =
getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- Vectorizer V(F, AA, AC, DT, SE, TTI);
- return V.run();
+ return Vectorizer(F, AA, AC, DT, SE, TTI).run();
}
-PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) {
+PreservedAnalyses LoadStoreVectorizerPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
// Don't vectorize when the attribute NoImplicitFloat is used.
if (F.hasFnAttribute(Attribute::NoImplicitFloat))
return PreservedAnalyses::all();
@@ -265,125 +411,681 @@ PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, FunctionAnalysisMana
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
- Vectorizer V(F, AA, AC, DT, SE, TTI);
- bool Changed = V.run();
+ bool Changed = Vectorizer(F, AA, AC, DT, SE, TTI).run();
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
return Changed ? PA : PreservedAnalyses::all();
}
-// The real propagateMetadata expects a SmallVector<Value*>, but we deal in
-// vectors of Instructions.
-static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) {
- SmallVector<Value *, 8> VL(IL.begin(), IL.end());
- propagateMetadata(I, VL);
-}
-
-// Vectorizer Implementation
bool Vectorizer::run() {
bool Changed = false;
-
- // Scan the blocks in the function in post order.
+ // Break up the BB if there are any instrs which aren't guaranteed to transfer
+ // execution to their successor.
+ //
+ // Consider, for example:
+ //
+ // def assert_arr_len(int n) { if (n < 2) exit(); }
+ //
+ // load arr[0]
+  //   call assert_arr_len(arr.length)
+ // load arr[1]
+ //
+ // Even though assert_arr_len does not read or write any memory, we can't
+ // speculate the second load before the call. More info at
+ // https://github.com/llvm/llvm-project/issues/52950.
for (BasicBlock *BB : post_order(&F)) {
- InstrListMap LoadRefs, StoreRefs;
- std::tie(LoadRefs, StoreRefs) = collectInstructions(BB);
- Changed |= vectorizeChains(LoadRefs);
- Changed |= vectorizeChains(StoreRefs);
+ // BB must at least have a terminator.
+ assert(!BB->empty());
+
+ SmallVector<BasicBlock::iterator, 8> Barriers;
+ Barriers.push_back(BB->begin());
+ for (Instruction &I : *BB)
+ if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+ Barriers.push_back(I.getIterator());
+ Barriers.push_back(BB->end());
+
+ for (auto It = Barriers.begin(), End = std::prev(Barriers.end()); It != End;
+ ++It)
+ Changed |= runOnPseudoBB(*It, *std::next(It));
+
+ for (Instruction *I : ToErase) {
+ auto *PtrOperand = getLoadStorePointerOperand(I);
+ if (I->use_empty())
+ I->eraseFromParent();
+ RecursivelyDeleteTriviallyDeadInstructions(PtrOperand);
+ }
+ ToErase.clear();
}
return Changed;
}
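Editorial sketch: the barrier handling above can be modeled independently of LLVM. Mark every instruction that might not transfer execution to its successor, then process the maximal ranges between consecutive marks. ToyInst is a stand-in; its MayNotTransfer flag plays the role of !isGuaranteedToTransferExecutionToSuccessor().

    #include <cstddef>
    #include <utility>
    #include <vector>

    struct ToyInst {
      bool MayNotTransfer;  // e.g. a call that could exit() or throw
    };

    // Returns half-open index ranges whose interiors contain no barrier; a
    // barrier instruction starts a new range, just like the Barriers vector
    // built in Vectorizer::run() above.
    std::vector<std::pair<std::size_t, std::size_t>>
    splitAtBarriers(const std::vector<ToyInst> &BB) {
      std::vector<std::size_t> Barriers;
      Barriers.push_back(0);
      for (std::size_t I = 0; I != BB.size(); ++I)
        if (BB[I].MayNotTransfer)
          Barriers.push_back(I);
      Barriers.push_back(BB.size());

      std::vector<std::pair<std::size_t, std::size_t>> Ranges;
      for (std::size_t I = 0; I + 1 < Barriers.size(); ++I)
        Ranges.push_back({Barriers[I], Barriers[I + 1]});
      return Ranges;
    }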
-unsigned Vectorizer::getPointerAddressSpace(Value *I) {
- if (LoadInst *L = dyn_cast<LoadInst>(I))
- return L->getPointerAddressSpace();
- if (StoreInst *S = dyn_cast<StoreInst>(I))
- return S->getPointerAddressSpace();
- return -1;
+bool Vectorizer::runOnPseudoBB(BasicBlock::iterator Begin,
+ BasicBlock::iterator End) {
+ LLVM_DEBUG({
+ dbgs() << "LSV: Running on pseudo-BB [" << *Begin << " ... ";
+ if (End != Begin->getParent()->end())
+ dbgs() << *End;
+ else
+ dbgs() << "<BB end>";
+ dbgs() << ")\n";
+ });
+
+ bool Changed = false;
+ for (const auto &[EqClassKey, EqClass] :
+ collectEquivalenceClasses(Begin, End))
+ Changed |= runOnEquivalenceClass(EqClassKey, EqClass);
+
+ return Changed;
}
-// FIXME: Merge with llvm::isConsecutiveAccess
-bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
- Value *PtrA = getLoadStorePointerOperand(A);
- Value *PtrB = getLoadStorePointerOperand(B);
- unsigned ASA = getPointerAddressSpace(A);
- unsigned ASB = getPointerAddressSpace(B);
+bool Vectorizer::runOnEquivalenceClass(const EqClassKey &EqClassKey,
+ ArrayRef<Instruction *> EqClass) {
+ bool Changed = false;
- // Check that the address spaces match and that the pointers are valid.
- if (!PtrA || !PtrB || (ASA != ASB))
- return false;
+ LLVM_DEBUG({
+ dbgs() << "LSV: Running on equivalence class of size " << EqClass.size()
+ << " keyed on " << EqClassKey << ":\n";
+ for (Instruction *I : EqClass)
+ dbgs() << " " << *I << "\n";
+ });
- // Make sure that A and B are different pointers of the same size type.
- Type *PtrATy = getLoadStoreType(A);
- Type *PtrBTy = getLoadStoreType(B);
- if (PtrA == PtrB ||
- PtrATy->isVectorTy() != PtrBTy->isVectorTy() ||
- DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) ||
- DL.getTypeStoreSize(PtrATy->getScalarType()) !=
- DL.getTypeStoreSize(PtrBTy->getScalarType()))
- return false;
+ std::vector<Chain> Chains = gatherChains(EqClass);
+ LLVM_DEBUG(dbgs() << "LSV: Got " << Chains.size()
+ << " nontrivial chains.\n";);
+ for (Chain &C : Chains)
+ Changed |= runOnChain(C);
+ return Changed;
+}
- unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA);
- APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy));
+bool Vectorizer::runOnChain(Chain &C) {
+ LLVM_DEBUG({
+ dbgs() << "LSV: Running on chain with " << C.size() << " instructions:\n";
+ dumpChain(C);
+ });
- return areConsecutivePointers(PtrA, PtrB, Size);
+ // Split up the chain into increasingly smaller chains, until we can finally
+ // vectorize the chains.
+ //
+ // (Don't be scared by the depth of the loop nest here. These operations are
+ // all at worst O(n lg n) in the number of instructions, and splitting chains
+ // doesn't change the number of instrs. So the whole loop nest is O(n lg n).)
+ bool Changed = false;
+ for (auto &C : splitChainByMayAliasInstrs(C))
+ for (auto &C : splitChainByContiguity(C))
+ for (auto &C : splitChainByAlignment(C))
+ Changed |= vectorizeChain(C);
+ return Changed;
}
-bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB,
- APInt PtrDelta, unsigned Depth) const {
- unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType());
- APInt OffsetA(PtrBitWidth, 0);
- APInt OffsetB(PtrBitWidth, 0);
- PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
- PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
+std::vector<Chain> Vectorizer::splitChainByMayAliasInstrs(Chain &C) {
+ if (C.empty())
+ return {};
- unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
+ sortChainInBBOrder(C);
- if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
+ LLVM_DEBUG({
+ dbgs() << "LSV: splitChainByMayAliasInstrs considering chain:\n";
+ dumpChain(C);
+ });
+
+  // We know that elements in the chain with non-overlapping offsets can't
+ // alias, but AA may not be smart enough to figure this out. Use a
+ // hashtable.
+ DenseMap<Instruction *, APInt /*OffsetFromLeader*/> ChainOffsets;
+ for (const auto &E : C)
+ ChainOffsets.insert({&*E.Inst, E.OffsetFromLeader});
+
+ // Loads get hoisted up to the first load in the chain. Stores get sunk
+ // down to the last store in the chain. Our algorithm for loads is:
+ //
+ // - Take the first element of the chain. This is the start of a new chain.
+ // - Take the next element of `Chain` and check for may-alias instructions
+ // up to the start of NewChain. If no may-alias instrs, add it to
+ // NewChain. Otherwise, start a new NewChain.
+ //
+ // For stores it's the same except in the reverse direction.
+ //
+ // We expect IsLoad to be an std::bool_constant.
+ auto Impl = [&](auto IsLoad) {
+ // MSVC is unhappy if IsLoad is a capture, so pass it as an arg.
+ auto [ChainBegin, ChainEnd] = [&](auto IsLoad) {
+ if constexpr (IsLoad())
+ return std::make_pair(C.begin(), C.end());
+ else
+ return std::make_pair(C.rbegin(), C.rend());
+ }(IsLoad);
+ assert(ChainBegin != ChainEnd);
+
+ std::vector<Chain> Chains;
+ SmallVector<ChainElem, 1> NewChain;
+ NewChain.push_back(*ChainBegin);
+ for (auto ChainIt = std::next(ChainBegin); ChainIt != ChainEnd; ++ChainIt) {
+ if (isSafeToMove<IsLoad>(ChainIt->Inst, NewChain.front().Inst,
+ ChainOffsets)) {
+ LLVM_DEBUG(dbgs() << "LSV: No intervening may-alias instrs; can merge "
+ << *ChainIt->Inst << " into " << *ChainBegin->Inst
+ << "\n");
+ NewChain.push_back(*ChainIt);
+ } else {
+ LLVM_DEBUG(
+ dbgs() << "LSV: Found intervening may-alias instrs; cannot merge "
+ << *ChainIt->Inst << " into " << *ChainBegin->Inst << "\n");
+ if (NewChain.size() > 1) {
+ LLVM_DEBUG({
+ dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
+ dumpChain(NewChain);
+ });
+ Chains.push_back(std::move(NewChain));
+ }
+
+ // Start a new chain.
+ NewChain = SmallVector<ChainElem, 1>({*ChainIt});
+ }
+ }
+ if (NewChain.size() > 1) {
+ LLVM_DEBUG({
+ dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
+ dumpChain(NewChain);
+ });
+ Chains.push_back(std::move(NewChain));
+ }
+ return Chains;
+ };
+
+ if (isa<LoadInst>(C[0].Inst))
+ return Impl(/*IsLoad=*/std::bool_constant<true>());
+
+ assert(isa<StoreInst>(C[0].Inst));
+ return Impl(/*IsLoad=*/std::bool_constant<false>());
+}
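Editorial illustration of the hoisting rule above (ordinary C++, not part of the pass): the store through Q may alias P[1], so the second load cannot be hoisted up to the first one, and the chain splits into two length-1 pieces that are then discarded.

    // Toy pattern with a may-alias barrier between two otherwise mergeable
    // loads. Without knowing that Q does not alias P + 1, the pass cannot
    // hoist the load of P[1] above the store.
    int LoadAroundStore(int *P, int *Q) {
      int A = P[0];
      *Q = 42;       // potential may-alias write between the two loads
      int B = P[1];
      return A + B;
    }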
+
+std::vector<Chain> Vectorizer::splitChainByContiguity(Chain &C) {
+ if (C.empty())
+ return {};
+
+ sortChainInOffsetOrder(C);
+
+ LLVM_DEBUG({
+ dbgs() << "LSV: splitChainByContiguity considering chain:\n";
+ dumpChain(C);
+ });
+
+ std::vector<Chain> Ret;
+ Ret.push_back({C.front()});
+
+ for (auto It = std::next(C.begin()), End = C.end(); It != End; ++It) {
+    // `Prev` accesses offsets [Prev.OffsetFromLeader, PrevReadEnd).
+ auto &CurChain = Ret.back();
+ const ChainElem &Prev = CurChain.back();
+ unsigned SzBits = DL.getTypeSizeInBits(getLoadStoreType(&*Prev.Inst));
+ assert(SzBits % 8 == 0 && "Non-byte sizes should have been filtered out by "
+                              "collectEquivalenceClasses");
+ APInt PrevReadEnd = Prev.OffsetFromLeader + SzBits / 8;
+
+ // Add this instruction to the end of the current chain, or start a new one.
+ bool AreContiguous = It->OffsetFromLeader == PrevReadEnd;
+ LLVM_DEBUG(dbgs() << "LSV: Instructions are "
+ << (AreContiguous ? "" : "not ") << "contiguous: "
+ << *Prev.Inst << " (ends at offset " << PrevReadEnd
+ << ") -> " << *It->Inst << " (starts at offset "
+ << It->OffsetFromLeader << ")\n");
+ if (AreContiguous)
+ CurChain.push_back(*It);
+ else
+ Ret.push_back({*It});
+ }
+
+ // Filter out length-1 chains, these are uninteresting.
+ llvm::erase_if(Ret, [](const auto &Chain) { return Chain.size() <= 1; });
+ return Ret;
+}
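Similarly for the contiguity split above, an editorial toy pattern (ordinary C++) where the accesses leave a gap, so no subchain of length greater than one survives.

    // The two loads are 8 bytes apart but each reads only 4 bytes, so the
    // second does not start where the first ends. The chain splits into two
    // length-1 chains, which are then discarded.
    int LoadWithGap(const int *P) {
      int A = P[0];  // bytes [0, 4)
      int B = P[2];  // bytes [8, 12) -- not contiguous with [0, 4)
      return A + B;
    }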
+
+Type *Vectorizer::getChainElemTy(const Chain &C) {
+ assert(!C.empty());
+ // The rules are:
+ // - If there are any pointer types in the chain, use an integer type.
+ // - Prefer an integer type if it appears in the chain.
+ // - Otherwise, use the first type in the chain.
+ //
+ // The rule about pointer types is a simplification when we merge e.g. a load
+ // of a ptr and a double. There's no direct conversion from a ptr to a
+ // double; it requires a ptrtoint followed by a bitcast.
+ //
+ // It's unclear to me if the other rules have any practical effect, but we do
+ // it to match this pass's previous behavior.
+ if (any_of(C, [](const ChainElem &E) {
+ return getLoadStoreType(E.Inst)->getScalarType()->isPointerTy();
+ })) {
+ return Type::getIntNTy(
+ F.getContext(),
+ DL.getTypeSizeInBits(getLoadStoreType(C[0].Inst)->getScalarType()));
+ }
+
+ for (const ChainElem &E : C)
+ if (Type *T = getLoadStoreType(E.Inst)->getScalarType(); T->isIntegerTy())
+ return T;
+ return getLoadStoreType(C[0].Inst)->getScalarType();
+}
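Editorial restatement of the three type-selection rules above on a simplified type tag, purely to make the precedence explicit. ToyTy is a stand-in for llvm::Type; only the properties the rules consult are modeled, and the chain is assumed non-empty, mirroring the assert in getChainElemTy.

    #include <vector>

    enum class ToyTyKind { Int, Float, Ptr };

    struct ToyTy {
      ToyTyKind Kind;
      unsigned Bits;
    };

    ToyTy chooseChainElemTy(const std::vector<ToyTy> &Chain) {
      // Rule 1: any pointer in the chain -> use an integer of the first
      // element's width (there is no direct bitcast from ptr to e.g. double).
      for (const ToyTy &T : Chain)
        if (T.Kind == ToyTyKind::Ptr)
          return {ToyTyKind::Int, Chain.front().Bits};
      // Rule 2: otherwise prefer an integer type if one appears in the chain.
      for (const ToyTy &T : Chain)
        if (T.Kind == ToyTyKind::Int)
          return T;
      // Rule 3: otherwise fall back to the first element's type.
      return Chain.front();
    }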
+
+std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
+ // We use a simple greedy algorithm.
+ // - Given a chain of length N, find all prefixes that
+ // (a) are not longer than the max register length, and
+ // (b) are a power of 2.
+ // - Starting from the longest prefix, try to create a vector of that length.
+ // - If one of them works, great. Repeat the algorithm on any remaining
+ // elements in the chain.
+ // - If none of them work, discard the first element and repeat on a chain
+ // of length N-1.
+ if (C.empty())
+ return {};
+
+ sortChainInOffsetOrder(C);
+
+ LLVM_DEBUG({
+ dbgs() << "LSV: splitChainByAlignment considering chain:\n";
+ dumpChain(C);
+ });
+
+ bool IsLoadChain = isa<LoadInst>(C[0].Inst);
+ auto getVectorFactor = [&](unsigned VF, unsigned LoadStoreSize,
+ unsigned ChainSizeBytes, VectorType *VecTy) {
+ return IsLoadChain ? TTI.getLoadVectorFactor(VF, LoadStoreSize,
+ ChainSizeBytes, VecTy)
+ : TTI.getStoreVectorFactor(VF, LoadStoreSize,
+ ChainSizeBytes, VecTy);
+ };
+
+#ifndef NDEBUG
+ for (const auto &E : C) {
+ Type *Ty = getLoadStoreType(E.Inst)->getScalarType();
+ assert(isPowerOf2_32(DL.getTypeSizeInBits(Ty)) &&
+ "Should have filtered out non-power-of-two elements in "
+ "collectEquivalenceClasses.");
+ }
+#endif
+
+ unsigned AS = getLoadStoreAddressSpace(C[0].Inst);
+ unsigned VecRegBytes = TTI.getLoadStoreVecRegBitWidth(AS) / 8;
+
+ std::vector<Chain> Ret;
+ for (unsigned CBegin = 0; CBegin < C.size(); ++CBegin) {
+ // Find candidate chains of size not greater than the largest vector reg.
+ // These chains are over the closed interval [CBegin, CEnd].
+ SmallVector<std::pair<unsigned /*CEnd*/, unsigned /*SizeBytes*/>, 8>
+ CandidateChains;
+ for (unsigned CEnd = CBegin + 1, Size = C.size(); CEnd < Size; ++CEnd) {
+ APInt Sz = C[CEnd].OffsetFromLeader +
+ DL.getTypeStoreSize(getLoadStoreType(C[CEnd].Inst)) -
+ C[CBegin].OffsetFromLeader;
+ if (Sz.sgt(VecRegBytes))
+ break;
+ CandidateChains.push_back(
+ {CEnd, static_cast<unsigned>(Sz.getLimitedValue())});
+ }
+
+ // Consider the longest chain first.
+ for (auto It = CandidateChains.rbegin(), End = CandidateChains.rend();
+ It != End; ++It) {
+ auto [CEnd, SizeBytes] = *It;
+ LLVM_DEBUG(
+ dbgs() << "LSV: splitChainByAlignment considering candidate chain ["
+ << *C[CBegin].Inst << " ... " << *C[CEnd].Inst << "]\n");
+
+ Type *VecElemTy = getChainElemTy(C);
+      // Note: the size of VecElemTy is a power of 2, but it might be less than
+      // one byte. For example, we can vectorize 2 x <2 x i4> to <4 x i4>, and
+      // in this case VecElemTy would be i4.
+ unsigned VecElemBits = DL.getTypeSizeInBits(VecElemTy);
+
+ // SizeBytes and VecElemBits are powers of 2, so they divide evenly.
+ assert((8 * SizeBytes) % VecElemBits == 0);
+ unsigned NumVecElems = 8 * SizeBytes / VecElemBits;
+ FixedVectorType *VecTy = FixedVectorType::get(VecElemTy, NumVecElems);
+ unsigned VF = 8 * VecRegBytes / VecElemBits;
+
+ // Check that TTI is happy with this vectorization factor.
+ unsigned TargetVF = getVectorFactor(VF, VecElemBits,
+ VecElemBits * NumVecElems / 8, VecTy);
+ if (TargetVF != VF && TargetVF < NumVecElems) {
+ LLVM_DEBUG(
+ dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
+ "because TargetVF="
+ << TargetVF << " != VF=" << VF
+ << " and TargetVF < NumVecElems=" << NumVecElems << "\n");
+ continue;
+ }
+
+ // Is a load/store with this alignment allowed by TTI and at least as fast
+ // as an unvectorized load/store?
+ //
+      // TTI and F are passed as explicit captures to work around an apparent
+      // MSVC misparse.
+ auto IsAllowedAndFast = [&, SizeBytes = SizeBytes, &TTI = TTI,
+ &F = F](Align Alignment) {
+ if (Alignment.value() % SizeBytes == 0)
+ return true;
+ unsigned VectorizedSpeed = 0;
+ bool AllowsMisaligned = TTI.allowsMisalignedMemoryAccesses(
+ F.getContext(), SizeBytes * 8, AS, Alignment, &VectorizedSpeed);
+ if (!AllowsMisaligned) {
+ LLVM_DEBUG(dbgs()
+ << "LSV: Access of " << SizeBytes << "B in addrspace "
+ << AS << " with alignment " << Alignment.value()
+ << " is misaligned, and therefore can't be vectorized.\n");
+ return false;
+ }
+
+ unsigned ElementwiseSpeed = 0;
+ (TTI).allowsMisalignedMemoryAccesses((F).getContext(), VecElemBits, AS,
+ Alignment, &ElementwiseSpeed);
+ if (VectorizedSpeed < ElementwiseSpeed) {
+ LLVM_DEBUG(dbgs()
+ << "LSV: Access of " << SizeBytes << "B in addrspace "
+ << AS << " with alignment " << Alignment.value()
+ << " has relative speed " << VectorizedSpeed
+ << ", which is lower than the elementwise speed of "
+ << ElementwiseSpeed
+ << ". Therefore this access won't be vectorized.\n");
+ return false;
+ }
+ return true;
+ };
+
+ // If we're loading/storing from an alloca, align it if possible.
+ //
+ // FIXME: We eagerly upgrade the alignment, regardless of whether TTI
+ // tells us this is beneficial. This feels a bit odd, but it matches
+ // existing tests. This isn't *so* bad, because at most we align to 4
+ // bytes (current value of StackAdjustedAlignment).
+ //
+ // FIXME: We will upgrade the alignment of the alloca even if it turns out
+ // we can't vectorize for some other reason.
+ Value *PtrOperand = getLoadStorePointerOperand(C[CBegin].Inst);
+ bool IsAllocaAccess = AS == DL.getAllocaAddrSpace() &&
+ isa<AllocaInst>(PtrOperand->stripPointerCasts());
+ Align Alignment = getLoadStoreAlignment(C[CBegin].Inst);
+ Align PrefAlign = Align(StackAdjustedAlignment);
+ if (IsAllocaAccess && Alignment.value() % SizeBytes != 0 &&
+ IsAllowedAndFast(PrefAlign)) {
+ Align NewAlign = getOrEnforceKnownAlignment(
+ PtrOperand, PrefAlign, DL, C[CBegin].Inst, nullptr, &DT);
+ if (NewAlign >= Alignment) {
+ LLVM_DEBUG(dbgs()
+ << "LSV: splitByChain upgrading alloca alignment from "
+ << Alignment.value() << " to " << NewAlign.value()
+ << "\n");
+ Alignment = NewAlign;
+ }
+ }
+
+ if (!IsAllowedAndFast(Alignment)) {
+ LLVM_DEBUG(
+ dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
+ "because its alignment is not AllowedAndFast: "
+ << Alignment.value() << "\n");
+ continue;
+ }
+
+ if ((IsLoadChain &&
+ !TTI.isLegalToVectorizeLoadChain(SizeBytes, Alignment, AS)) ||
+ (!IsLoadChain &&
+ !TTI.isLegalToVectorizeStoreChain(SizeBytes, Alignment, AS))) {
+ LLVM_DEBUG(
+ dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
+ "because !isLegalToVectorizeLoad/StoreChain.");
+ continue;
+ }
+
+ // Hooray, we can vectorize this chain!
+ Chain &NewChain = Ret.emplace_back();
+ for (unsigned I = CBegin; I <= CEnd; ++I)
+ NewChain.push_back(C[I]);
+ CBegin = CEnd; // Skip over the instructions we've added to the chain.
+ break;
+ }
+ }
+ return Ret;
+}
+
+bool Vectorizer::vectorizeChain(Chain &C) {
+ if (C.size() < 2)
return false;
- // In case if we have to shrink the pointer
- // stripAndAccumulateInBoundsConstantOffsets should properly handle a
- // possible overflow and the value should fit into a smallest data type
- // used in the cast/gep chain.
- assert(OffsetA.getMinSignedBits() <= NewPtrBitWidth &&
- OffsetB.getMinSignedBits() <= NewPtrBitWidth);
+ sortChainInOffsetOrder(C);
- OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
- OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
- PtrDelta = PtrDelta.sextOrTrunc(NewPtrBitWidth);
+ LLVM_DEBUG({
+ dbgs() << "LSV: Vectorizing chain of " << C.size() << " instructions:\n";
+ dumpChain(C);
+ });
- APInt OffsetDelta = OffsetB - OffsetA;
+ Type *VecElemTy = getChainElemTy(C);
+ bool IsLoadChain = isa<LoadInst>(C[0].Inst);
+ unsigned AS = getLoadStoreAddressSpace(C[0].Inst);
+ unsigned ChainBytes = std::accumulate(
+ C.begin(), C.end(), 0u, [&](unsigned Bytes, const ChainElem &E) {
+ return Bytes + DL.getTypeStoreSize(getLoadStoreType(E.Inst));
+ });
+ assert(ChainBytes % DL.getTypeStoreSize(VecElemTy) == 0);
+ // VecTy is a power of 2 and 1 byte at smallest, but VecElemTy may be smaller
+ // than 1 byte (e.g. VecTy == <32 x i1>).
+ Type *VecTy = FixedVectorType::get(
+ VecElemTy, 8 * ChainBytes / DL.getTypeSizeInBits(VecElemTy));
+
+ Align Alignment = getLoadStoreAlignment(C[0].Inst);
+ // If this is a load/store of an alloca, we might have upgraded the alloca's
+ // alignment earlier. Get the new alignment.
+ if (AS == DL.getAllocaAddrSpace()) {
+ Alignment = std::max(
+ Alignment,
+ getOrEnforceKnownAlignment(getLoadStorePointerOperand(C[0].Inst),
+ MaybeAlign(), DL, C[0].Inst, nullptr, &DT));
+ }
- // Check if they are based on the same pointer. That makes the offsets
- // sufficient.
- if (PtrA == PtrB)
- return OffsetDelta == PtrDelta;
-
- // Compute the necessary base pointer delta to have the necessary final delta
- // equal to the pointer delta requested.
- APInt BaseDelta = PtrDelta - OffsetDelta;
-
- // Compute the distance with SCEV between the base pointers.
- const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
- const SCEV *PtrSCEVB = SE.getSCEV(PtrB);
- const SCEV *C = SE.getConstant(BaseDelta);
- const SCEV *X = SE.getAddExpr(PtrSCEVA, C);
- if (X == PtrSCEVB)
+ // All elements of the chain must have the same scalar-type size.
+#ifndef NDEBUG
+ for (const ChainElem &E : C)
+ assert(DL.getTypeStoreSize(getLoadStoreType(E.Inst)->getScalarType()) ==
+ DL.getTypeStoreSize(VecElemTy));
+#endif
+
+ Instruction *VecInst;
+ if (IsLoadChain) {
+ // Loads get hoisted to the location of the first load in the chain. We may
+ // also need to hoist the (transitive) operands of the loads.
+ Builder.SetInsertPoint(
+ std::min_element(C.begin(), C.end(), [](const auto &A, const auto &B) {
+ return A.Inst->comesBefore(B.Inst);
+ })->Inst);
+
+ // Chain is in offset order, so C[0] is the instr with the lowest offset,
+ // i.e. the root of the vector.
+ Value *Bitcast = Builder.CreateBitCast(
+ getLoadStorePointerOperand(C[0].Inst), VecTy->getPointerTo(AS));
+ VecInst = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment);
+
+ unsigned VecIdx = 0;
+ for (const ChainElem &E : C) {
+ Instruction *I = E.Inst;
+ Value *V;
+ Type *T = getLoadStoreType(I);
+ if (auto *VT = dyn_cast<FixedVectorType>(T)) {
+ auto Mask = llvm::to_vector<8>(
+ llvm::seq<int>(VecIdx, VecIdx + VT->getNumElements()));
+ V = Builder.CreateShuffleVector(VecInst, Mask, I->getName());
+ VecIdx += VT->getNumElements();
+ } else {
+ V = Builder.CreateExtractElement(VecInst, Builder.getInt32(VecIdx),
+ I->getName());
+ ++VecIdx;
+ }
+ if (V->getType() != I->getType())
+ V = Builder.CreateBitOrPointerCast(V, I->getType());
+ I->replaceAllUsesWith(V);
+ }
+
+ // Finally, we need to reorder the instrs in the BB so that the (transitive)
+ // operands of VecInst appear before it. To see why, suppose we have
+ // vectorized the following code:
+ //
+ // ptr1 = gep a, 1
+ // load1 = load i32 ptr1
+ // ptr0 = gep a, 0
+ // load0 = load i32 ptr0
+ //
+ // We will put the vectorized load at the location of the earliest load in
+ // the BB, i.e. load1. We get:
+ //
+ // ptr1 = gep a, 1
+ // loadv = load <2 x i32> ptr0
+ // load0 = extractelement loadv, 0
+ // load1 = extractelement loadv, 1
+ // ptr0 = gep a, 0
+ //
+ // Notice that loadv uses ptr0, which is defined *after* it!
+ reorder(VecInst);
+ } else {
+ // Stores get sunk to the location of the last store in the chain.
+ Builder.SetInsertPoint(
+ std::max_element(C.begin(), C.end(), [](auto &A, auto &B) {
+ return A.Inst->comesBefore(B.Inst);
+ })->Inst);
+
+ // Build the vector to store.
+ Value *Vec = PoisonValue::get(VecTy);
+ unsigned VecIdx = 0;
+ auto InsertElem = [&](Value *V) {
+ if (V->getType() != VecElemTy)
+ V = Builder.CreateBitOrPointerCast(V, VecElemTy);
+ Vec = Builder.CreateInsertElement(Vec, V, Builder.getInt32(VecIdx++));
+ };
+ for (const ChainElem &E : C) {
+ auto I = cast<StoreInst>(E.Inst);
+ if (FixedVectorType *VT =
+ dyn_cast<FixedVectorType>(getLoadStoreType(I))) {
+ for (int J = 0, JE = VT->getNumElements(); J < JE; ++J) {
+ InsertElem(Builder.CreateExtractElement(I->getValueOperand(),
+ Builder.getInt32(J)));
+ }
+ } else {
+ InsertElem(I->getValueOperand());
+ }
+ }
+
+ // Chain is in offset order, so C[0] is the instr with the lowest offset,
+ // i.e. the root of the vector.
+ VecInst = Builder.CreateAlignedStore(
+ Vec,
+ Builder.CreateBitCast(getLoadStorePointerOperand(C[0].Inst),
+ VecTy->getPointerTo(AS)),
+ Alignment);
+ }
+
+ propagateMetadata(VecInst, C);
+
+ for (const ChainElem &E : C)
+ ToErase.push_back(E.Inst);
+
+ ++NumVectorInstructions;
+ NumScalarsVectorized += C.size();
+ return true;
+}
+
+template <bool IsLoadChain>
+bool Vectorizer::isSafeToMove(
+ Instruction *ChainElem, Instruction *ChainBegin,
+ const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets) {
+ LLVM_DEBUG(dbgs() << "LSV: isSafeToMove(" << *ChainElem << " -> "
+ << *ChainBegin << ")\n");
+
+ assert(isa<LoadInst>(ChainElem) == IsLoadChain);
+ if (ChainElem == ChainBegin)
return true;
- // The above check will not catch the cases where one of the pointers is
- // factorized but the other one is not, such as (C + (S * (A + B))) vs
- // (AS + BS). Get the minus SCEV. That will allow re-combining the expressions
- // and getting the simplified difference.
- const SCEV *Dist = SE.getMinusSCEV(PtrSCEVB, PtrSCEVA);
- if (C == Dist)
+ // Invariant loads can always be reordered; by definition they are not
+ // clobbered by stores.
+ if (isInvariantLoad(ChainElem))
return true;
- // Sometimes even this doesn't work, because SCEV can't always see through
- // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
- // things the hard way.
- return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
+ auto BBIt = std::next([&] {
+ if constexpr (IsLoadChain)
+ return BasicBlock::reverse_iterator(ChainElem);
+ else
+ return BasicBlock::iterator(ChainElem);
+ }());
+ auto BBItEnd = std::next([&] {
+ if constexpr (IsLoadChain)
+ return BasicBlock::reverse_iterator(ChainBegin);
+ else
+ return BasicBlock::iterator(ChainBegin);
+ }());
+
+ const APInt &ChainElemOffset = ChainOffsets.at(ChainElem);
+ const unsigned ChainElemSize =
+ DL.getTypeStoreSize(getLoadStoreType(ChainElem));
+
+ for (; BBIt != BBItEnd; ++BBIt) {
+ Instruction *I = &*BBIt;
+
+ if (!I->mayReadOrWriteMemory())
+ continue;
+
+ // Loads can be reordered with other loads.
+ if (IsLoadChain && isa<LoadInst>(I))
+ continue;
+
+ // Stores can be sunk below invariant loads.
+ if (!IsLoadChain && isInvariantLoad(I))
+ continue;
+
+ // If I is in the chain, we can tell whether it aliases ChainElem by checking
+ // what offset I accesses. This may be better than AA is able to do.
+ //
+ // We should really only have duplicate offsets for stores (the duplicate
+ // loads should be CSE'ed), but in case we have a duplicate load, we'll
+ // split the chain so we don't have to handle this case specially.
+ if (auto OffsetIt = ChainOffsets.find(I); OffsetIt != ChainOffsets.end()) {
+ // I and ChainElem overlap if:
+ // - I and ChainElem have the same offset, OR
+ // - I's offset is less than ChainElem's, but I touches past the
+ // beginning of ChainElem, OR
+ // - ChainElem's offset is less than I's, but ChainElem touches past the
+ // beginning of I.
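+ //
+ // For example (illustrative values only): with ChainElemOffset == 4 and
+ // ChainElemSize == 4, ChainElem covers bytes [4, 8). An I with IOffset == 2
+ // and IElemSize == 4 covers bytes [2, 6) and therefore overlaps, while an I
+ // at offset 8 does not.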
+ const APInt &IOffset = OffsetIt->second;
+ unsigned IElemSize = DL.getTypeStoreSize(getLoadStoreType(I));
+ if (IOffset == ChainElemOffset ||
+ (IOffset.sle(ChainElemOffset) &&
+ (IOffset + IElemSize).sgt(ChainElemOffset)) ||
+ (ChainElemOffset.sle(IOffset) &&
+ (ChainElemOffset + ChainElemSize).sgt(IOffset))) {
+ LLVM_DEBUG({
+ // Double check that AA also sees this alias. If not, we probably
+ // have a bug.
+ ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem));
+ assert(IsLoadChain ? isModSet(MR) : isModOrRefSet(MR));
+ dbgs() << "LSV: Found alias in chain: " << *I << "\n";
+ });
+ return false; // We found an aliasing instruction; bail.
+ }
+
+ continue; // We're confident there's no alias.
+ }
+
+ LLVM_DEBUG(dbgs() << "LSV: Querying AA for " << *I << "\n");
+ ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem));
+ if (IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)) {
+ LLVM_DEBUG(dbgs() << "LSV: Found alias in chain:\n"
+ << " Aliasing instruction:\n"
+ << " " << *I << '\n'
+ << " Aliased instruction and pointer:\n"
+ << " " << *ChainElem << '\n'
+ << " " << *getLoadStorePointerOperand(ChainElem)
+ << '\n');
+
+ return false;
+ }
+ }
+ return true;
}
static bool checkNoWrapFlags(Instruction *I, bool Signed) {
@@ -395,10 +1097,14 @@ static bool checkNoWrapFlags(Instruction *I, bool Signed) {
static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
unsigned MatchingOpIdxA, Instruction *AddOpB,
unsigned MatchingOpIdxB, bool Signed) {
- // If both OpA and OpB is an add with NSW/NUW and with
- // one of the operands being the same, we can guarantee that the
- // transformation is safe if we can prove that OpA won't overflow when
- // IdxDiff added to the other operand of OpA.
+ LLVM_DEBUG(dbgs() << "LSV: checkIfSafeAddSequence IdxDiff=" << IdxDiff
+ << ", AddOpA=" << *AddOpA << ", MatchingOpIdxA="
+ << MatchingOpIdxA << ", AddOpB=" << *AddOpB
+ << ", MatchingOpIdxB=" << MatchingOpIdxB
+ << ", Signed=" << Signed << "\n");
+ // If both OpA and OpB are adds with NSW/NUW and with one of the operands
+ // being the same, we can guarantee that the transformation is safe if we can
+ // prove that OpA won't overflow when IdxDiff is added to the other operand
+ // of OpA.
// For example:
// %tmp7 = add nsw i32 %tmp2, %v0
// %tmp8 = sext i32 %tmp7 to i64
@@ -407,10 +1113,9 @@ static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
// %tmp12 = add nsw i32 %tmp2, %tmp11
// %tmp13 = sext i32 %tmp12 to i64
//
- // Both %tmp7 and %tmp2 has the nsw flag and the first operand
- // is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
- // because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 has the
- // nsw flag.
+ // Both %tmp7 and %tmp12 have the nsw flag and the first operand is %tmp2.
+ // It's guaranteed that adding 1 to %tmp7 won't overflow because %tmp11 adds
+ // 1 to %v0 and both %tmp11 and %tmp12 have the nsw flag.
assert(AddOpA->getOpcode() == Instruction::Add &&
AddOpB->getOpcode() == Instruction::Add &&
checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
@@ -461,24 +1166,26 @@ static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
return false;
}
-bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
- APInt PtrDelta,
- unsigned Depth) const {
+std::optional<APInt> Vectorizer::getConstantOffsetComplexAddrs(
+ Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) {
+ LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs PtrA=" << *PtrA
+ << " PtrB=" << *PtrB << " ContextInst=" << *ContextInst
+ << " Depth=" << Depth << "\n");
auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
if (!GEPA || !GEPB)
- return lookThroughSelects(PtrA, PtrB, PtrDelta, Depth);
+ return getConstantOffsetSelects(PtrA, PtrB, ContextInst, Depth);
// Look through GEPs after checking they're the same except for the last
// index.
if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
GEPA->getPointerOperand() != GEPB->getPointerOperand())
- return false;
+ return std::nullopt;
gep_type_iterator GTIA = gep_type_begin(GEPA);
gep_type_iterator GTIB = gep_type_begin(GEPB);
for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) {
if (GTIA.getOperand() != GTIB.getOperand())
- return false;
+ return std::nullopt;
++GTIA;
++GTIB;
}
@@ -487,23 +1194,13 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand());
if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
OpA->getType() != OpB->getType())
- return false;
+ return std::nullopt;
- if (PtrDelta.isNegative()) {
- if (PtrDelta.isMinSignedValue())
- return false;
- PtrDelta.negate();
- std::swap(OpA, OpB);
- }
uint64_t Stride = DL.getTypeAllocSize(GTIA.getIndexedType());
- if (PtrDelta.urem(Stride) != 0)
- return false;
- unsigned IdxBitWidth = OpA->getType()->getScalarSizeInBits();
- APInt IdxDiff = PtrDelta.udiv(Stride).zext(IdxBitWidth);
// Only look through a ZExt/SExt.
if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
- return false;
+ return std::nullopt;
bool Signed = isa<SExtInst>(OpA);
@@ -511,7 +1208,21 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
Value *ValA = OpA->getOperand(0);
OpB = dyn_cast<Instruction>(OpB->getOperand(0));
if (!OpB || ValA->getType() != OpB->getType())
- return false;
+ return std::nullopt;
+
+ const SCEV *OffsetSCEVA = SE.getSCEV(ValA);
+ const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
+ const SCEV *IdxDiffSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA);
+ if (IdxDiffSCEV == SE.getCouldNotCompute())
+ return std::nullopt;
+
+ ConstantRange IdxDiffRange = SE.getSignedRange(IdxDiffSCEV);
+ if (!IdxDiffRange.isSingleElement())
+ return std::nullopt;
+ APInt IdxDiff = *IdxDiffRange.getSingleElement();
+
+ LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs IdxDiff=" << IdxDiff
+ << "\n");
// Now we need to prove that adding IdxDiff to ValA won't overflow.
bool Safe = false;
@@ -530,10 +1241,9 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
checkNoWrapFlags(OpB, Signed)) {
- // In the checks below a matching operand in OpA and OpB is
- // an operand which is the same in those two instructions.
- // Below we account for possible orders of the operands of
- // these add instructions.
+ // In the checks below a matching operand in OpA and OpB is an operand which
+ // is the same in those two instructions. Below we account for possible
+ // orders of the operands of these add instructions.
for (unsigned MatchingOpIdxA : {0, 1})
for (unsigned MatchingOpIdxB : {0, 1})
if (!Safe)
@@ -544,802 +1254,267 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
// Third attempt:
- // If all set bits of IdxDiff or any higher order bit other than the sign bit
- // are known to be zero in ValA, we can add Diff to it while guaranteeing no
- // overflow of any sort.
+ //
+ // Assuming IdxDiff is positive: If all set bits of IdxDiff or any higher
+ // order bit other than the sign bit are known to be zero in ValA, we can add
+ // IdxDiff to it while guaranteeing no overflow of any sort.
+ //
+ // If IdxDiff is negative, do the same, but swap ValA and ValB.
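+ //
+ // For example (a sketch, not taken from a test case): if IdxDiff is positive
+ // and ValA is known to have its low 4 bits zero (i.e. it is a multiple of
+ // 16), then Known.Zero covers bits 0..3, BitsAllowedToBeSet is at least 15,
+ // and any IdxDiff <= 15 can be added to ValA without carrying into the
+ // unknown higher bits, so Safe is set to true below.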
if (!Safe) {
+ // When computing known bits, use the GEPs as context instructions, since
+ // they likely are in the same BB as the load/store.
KnownBits Known(BitWidth);
- computeKnownBits(ValA, Known, DL, 0, &AC, OpB, &DT);
+ computeKnownBits((IdxDiff.sge(0) ? ValA : OpB), Known, DL, 0, &AC,
+ ContextInst, &DT);
APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth());
if (Signed)
BitsAllowedToBeSet.clearBit(BitWidth - 1);
- if (BitsAllowedToBeSet.ult(IdxDiff))
- return false;
+ if (BitsAllowedToBeSet.ult(IdxDiff.abs()))
+ return std::nullopt;
+ Safe = true;
}
- const SCEV *OffsetSCEVA = SE.getSCEV(ValA);
- const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
- const SCEV *C = SE.getConstant(IdxDiff.trunc(BitWidth));
- const SCEV *X = SE.getAddExpr(OffsetSCEVA, C);
- return X == OffsetSCEVB;
+ if (Safe)
+ return IdxDiff * Stride;
+ return std::nullopt;
}
-bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB,
- const APInt &PtrDelta,
- unsigned Depth) const {
+std::optional<APInt> Vectorizer::getConstantOffsetSelects(
+ Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) {
if (Depth++ == MaxDepth)
- return false;
+ return std::nullopt;
if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
- return SelectA->getCondition() == SelectB->getCondition() &&
- areConsecutivePointers(SelectA->getTrueValue(),
- SelectB->getTrueValue(), PtrDelta, Depth) &&
- areConsecutivePointers(SelectA->getFalseValue(),
- SelectB->getFalseValue(), PtrDelta, Depth);
+ if (SelectA->getCondition() != SelectB->getCondition())
+ return std::nullopt;
+ LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetSelects, PtrA=" << *PtrA
+ << ", PtrB=" << *PtrB << ", ContextInst="
+ << *ContextInst << ", Depth=" << Depth << "\n");
+ std::optional<APInt> TrueDiff = getConstantOffset(
+ SelectA->getTrueValue(), SelectB->getTrueValue(), ContextInst, Depth);
+ if (!TrueDiff.has_value())
+ return std::nullopt;
+ std::optional<APInt> FalseDiff =
+ getConstantOffset(SelectA->getFalseValue(), SelectB->getFalseValue(),
+ ContextInst, Depth);
+ if (TrueDiff == FalseDiff)
+ return TrueDiff;
}
}
- return false;
+ return std::nullopt;
}
-void Vectorizer::reorder(Instruction *I) {
- SmallPtrSet<Instruction *, 16> InstructionsToMove;
- SmallVector<Instruction *, 16> Worklist;
-
- Worklist.push_back(I);
- while (!Worklist.empty()) {
- Instruction *IW = Worklist.pop_back_val();
- int NumOperands = IW->getNumOperands();
- for (int i = 0; i < NumOperands; i++) {
- Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i));
- if (!IM || IM->getOpcode() == Instruction::PHI)
- continue;
-
- // If IM is in another BB, no need to move it, because this pass only
- // vectorizes instructions within one BB.
- if (IM->getParent() != I->getParent())
- continue;
-
- if (!IM->comesBefore(I)) {
- InstructionsToMove.insert(IM);
- Worklist.push_back(IM);
- }
+EquivalenceClassMap
+Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
+ BasicBlock::iterator End) {
+ EquivalenceClassMap Ret;
+
+ auto getUnderlyingObject = [](const Value *Ptr) -> const Value * {
+ const Value *ObjPtr = llvm::getUnderlyingObject(Ptr);
+ if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
+ // The selects themselves are distinct instructions even if they share
+ // the same condition and evaluate to consecutive pointers for the true
+ // and false values of the condition. Therefore using the selects
+ // themselves for grouping instructions would put consecutive accesses
+ // into different lists; they wouldn't even be checked for being
+ // consecutive, and wouldn't be vectorized.
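+ //
+ // For example (hypothetical IR), two loads through
+ //   %p = select i1 %c, ptr %a0, ptr %b0
+ //   %q = select i1 %c, ptr %a1, ptr %b1
+ // where %a1/%b1 are one element past %a0/%b0, are consecutive on either arm
+ // of the select, so we group both loads under the shared condition %c rather
+ // than under the distinct selects %p and %q.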
+ return Sel->getCondition();
}
- }
+ return ObjPtr;
+ };
- // All instructions to move should follow I. Start from I, not from begin().
- for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;
- ++BBI) {
- if (!InstructionsToMove.count(&*BBI))
+ for (Instruction &I : make_range(Begin, End)) {
+ auto *LI = dyn_cast<LoadInst>(&I);
+ auto *SI = dyn_cast<StoreInst>(&I);
+ if (!LI && !SI)
continue;
- Instruction *IM = &*BBI;
- --BBI;
- IM->removeFromParent();
- IM->insertBefore(I);
- }
-}
-
-std::pair<BasicBlock::iterator, BasicBlock::iterator>
-Vectorizer::getBoundaryInstrs(ArrayRef<Instruction *> Chain) {
- Instruction *C0 = Chain[0];
- BasicBlock::iterator FirstInstr = C0->getIterator();
- BasicBlock::iterator LastInstr = C0->getIterator();
- BasicBlock *BB = C0->getParent();
- unsigned NumFound = 0;
- for (Instruction &I : *BB) {
- if (!is_contained(Chain, &I))
+ if ((LI && !LI->isSimple()) || (SI && !SI->isSimple()))
continue;
- ++NumFound;
- if (NumFound == 1) {
- FirstInstr = I.getIterator();
- }
- if (NumFound == Chain.size()) {
- LastInstr = I.getIterator();
- break;
- }
- }
-
- // Range is [first, last).
- return std::make_pair(FirstInstr, ++LastInstr);
-}
-
-void Vectorizer::eraseInstructions(ArrayRef<Instruction *> Chain) {
- SmallVector<Instruction *, 16> Instrs;
- for (Instruction *I : Chain) {
- Value *PtrOperand = getLoadStorePointerOperand(I);
- assert(PtrOperand && "Instruction must have a pointer operand.");
- Instrs.push_back(I);
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand))
- Instrs.push_back(GEP);
- }
-
- // Erase instructions.
- for (Instruction *I : Instrs)
- if (I->use_empty())
- I->eraseFromParent();
-}
-
-std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
-Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain,
- unsigned ElementSizeBits) {
- unsigned ElementSizeBytes = ElementSizeBits / 8;
- unsigned SizeBytes = ElementSizeBytes * Chain.size();
- unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes;
- if (NumLeft == Chain.size()) {
- if ((NumLeft & 1) == 0)
- NumLeft /= 2; // Split even in half
- else
- --NumLeft; // Split off last element
- } else if (NumLeft == 0)
- NumLeft = 1;
- return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));
-}
-
-ArrayRef<Instruction *>
-Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
- // These are in BB order, unlike Chain, which is in address order.
- SmallVector<Instruction *, 16> MemoryInstrs;
- SmallVector<Instruction *, 16> ChainInstrs;
-
- bool IsLoadChain = isa<LoadInst>(Chain[0]);
- LLVM_DEBUG({
- for (Instruction *I : Chain) {
- if (IsLoadChain)
- assert(isa<LoadInst>(I) &&
- "All elements of Chain must be loads, or all must be stores.");
- else
- assert(isa<StoreInst>(I) &&
- "All elements of Chain must be loads, or all must be stores.");
- }
- });
-
- for (Instruction &I : make_range(getBoundaryInstrs(Chain))) {
- if ((isa<LoadInst>(I) || isa<StoreInst>(I)) && is_contained(Chain, &I)) {
- ChainInstrs.push_back(&I);
+ if ((LI && !TTI.isLegalToVectorizeLoad(LI)) ||
+ (SI && !TTI.isLegalToVectorizeStore(SI)))
continue;
- }
- if (!isGuaranteedToTransferExecutionToSuccessor(&I)) {
- LLVM_DEBUG(dbgs() << "LSV: Found instruction may not transfer execution: "
- << I << '\n');
- break;
- }
- if (I.mayReadOrWriteMemory())
- MemoryInstrs.push_back(&I);
- }
-
- // Loop until we find an instruction in ChainInstrs that we can't vectorize.
- unsigned ChainInstrIdx = 0;
- Instruction *BarrierMemoryInstr = nullptr;
-
- for (unsigned E = ChainInstrs.size(); ChainInstrIdx < E; ++ChainInstrIdx) {
- Instruction *ChainInstr = ChainInstrs[ChainInstrIdx];
-
- // If a barrier memory instruction was found, chain instructions that follow
- // will not be added to the valid prefix.
- if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(ChainInstr))
- break;
- // Check (in BB order) if any instruction prevents ChainInstr from being
- // vectorized. Find and store the first such "conflicting" instruction.
- for (Instruction *MemInstr : MemoryInstrs) {
- // If a barrier memory instruction was found, do not check past it.
- if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(MemInstr))
- break;
-
- auto *MemLoad = dyn_cast<LoadInst>(MemInstr);
- auto *ChainLoad = dyn_cast<LoadInst>(ChainInstr);
- if (MemLoad && ChainLoad)
- continue;
-
- // We can ignore the alias if we have a load/store pair and the load
- // is known to be invariant. The load cannot be clobbered by the store.
- auto IsInvariantLoad = [](const LoadInst *LI) -> bool {
- return LI->hasMetadata(LLVMContext::MD_invariant_load);
- };
-
- if (IsLoadChain) {
- // We can ignore the alias as long as the load comes before the store,
- // because that means we won't be moving the load past the store to
- // vectorize it (the vectorized load is inserted at the location of the
- // first load in the chain).
- if (ChainInstr->comesBefore(MemInstr) ||
- (ChainLoad && IsInvariantLoad(ChainLoad)))
- continue;
- } else {
- // Same case, but in reverse.
- if (MemInstr->comesBefore(ChainInstr) ||
- (MemLoad && IsInvariantLoad(MemLoad)))
- continue;
- }
-
- ModRefInfo MR =
- AA.getModRefInfo(MemInstr, MemoryLocation::get(ChainInstr));
- if (IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)) {
- LLVM_DEBUG({
- dbgs() << "LSV: Found alias:\n"
- " Aliasing instruction:\n"
- << " " << *MemInstr << '\n'
- << " Aliased instruction and pointer:\n"
- << " " << *ChainInstr << '\n'
- << " " << *getLoadStorePointerOperand(ChainInstr) << '\n';
- });
- // Save this aliasing memory instruction as a barrier, but allow other
- // instructions that precede the barrier to be vectorized with this one.
- BarrierMemoryInstr = MemInstr;
- break;
- }
- }
- // Continue the search only for store chains, since vectorizing stores that
- // precede an aliasing load is valid. Conversely, vectorizing loads is valid
- // up to an aliasing store, but should not pull loads from further down in
- // the basic block.
- if (IsLoadChain && BarrierMemoryInstr) {
- // The BarrierMemoryInstr is a store that precedes ChainInstr.
- assert(BarrierMemoryInstr->comesBefore(ChainInstr));
- break;
- }
- }
-
- // Find the largest prefix of Chain whose elements are all in
- // ChainInstrs[0, ChainInstrIdx). This is the largest vectorizable prefix of
- // Chain. (Recall that Chain is in address order, but ChainInstrs is in BB
- // order.)
- SmallPtrSet<Instruction *, 8> VectorizableChainInstrs(
- ChainInstrs.begin(), ChainInstrs.begin() + ChainInstrIdx);
- unsigned ChainIdx = 0;
- for (unsigned ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) {
- if (!VectorizableChainInstrs.count(Chain[ChainIdx]))
- break;
- }
- return Chain.slice(0, ChainIdx);
-}
-
-static ChainID getChainID(const Value *Ptr) {
- const Value *ObjPtr = getUnderlyingObject(Ptr);
- if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
- // The select's themselves are distinct instructions even if they share the
- // same condition and evaluate to consecutive pointers for true and false
- // values of the condition. Therefore using the select's themselves for
- // grouping instructions would put consecutive accesses into different lists
- // and they won't be even checked for being consecutive, and won't be
- // vectorized.
- return Sel->getCondition();
- }
- return ObjPtr;
-}
-
-std::pair<InstrListMap, InstrListMap>
-Vectorizer::collectInstructions(BasicBlock *BB) {
- InstrListMap LoadRefs;
- InstrListMap StoreRefs;
-
- for (Instruction &I : *BB) {
- if (!I.mayReadOrWriteMemory())
+ Type *Ty = getLoadStoreType(&I);
+ if (!VectorType::isValidElementType(Ty->getScalarType()))
continue;
- if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
- if (!LI->isSimple())
- continue;
-
- // Skip if it's not legal.
- if (!TTI.isLegalToVectorizeLoad(LI))
- continue;
-
- Type *Ty = LI->getType();
- if (!VectorType::isValidElementType(Ty->getScalarType()))
- continue;
-
- // Skip weird non-byte sizes. They probably aren't worth the effort of
- // handling correctly.
- unsigned TySize = DL.getTypeSizeInBits(Ty);
- if ((TySize % 8) != 0)
- continue;
-
- // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
- // functions are currently using an integer type for the vectorized
- // load/store, and does not support casting between the integer type and a
- // vector of pointers (e.g. i64 to <2 x i16*>)
- if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
- continue;
-
- Value *Ptr = LI->getPointerOperand();
- unsigned AS = Ptr->getType()->getPointerAddressSpace();
- unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
-
- unsigned VF = VecRegSize / TySize;
- VectorType *VecTy = dyn_cast<VectorType>(Ty);
-
- // No point in looking at these if they're too big to vectorize.
- if (TySize > VecRegSize / 2 ||
- (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
- continue;
-
- // Save the load locations.
- const ChainID ID = getChainID(Ptr);
- LoadRefs[ID].push_back(LI);
- } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
- if (!SI->isSimple())
- continue;
-
- // Skip if it's not legal.
- if (!TTI.isLegalToVectorizeStore(SI))
- continue;
-
- Type *Ty = SI->getValueOperand()->getType();
- if (!VectorType::isValidElementType(Ty->getScalarType()))
- continue;
-
- // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
- // functions are currently using an integer type for the vectorized
- // load/store, and does not support casting between the integer type and a
- // vector of pointers (e.g. i64 to <2 x i16*>)
- if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
- continue;
-
- // Skip weird non-byte sizes. They probably aren't worth the effort of
- // handling correctly.
- unsigned TySize = DL.getTypeSizeInBits(Ty);
- if ((TySize % 8) != 0)
- continue;
-
- Value *Ptr = SI->getPointerOperand();
- unsigned AS = Ptr->getType()->getPointerAddressSpace();
- unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
-
- unsigned VF = VecRegSize / TySize;
- VectorType *VecTy = dyn_cast<VectorType>(Ty);
-
- // No point in looking at these if they're too big to vectorize.
- if (TySize > VecRegSize / 2 ||
- (VecTy && TTI.getStoreVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
- continue;
-
- // Save store location.
- const ChainID ID = getChainID(Ptr);
- StoreRefs[ID].push_back(SI);
- }
- }
-
- return {LoadRefs, StoreRefs};
-}
-
-bool Vectorizer::vectorizeChains(InstrListMap &Map) {
- bool Changed = false;
-
- for (const std::pair<ChainID, InstrList> &Chain : Map) {
- unsigned Size = Chain.second.size();
- if (Size < 2)
+ // Skip weird non-byte sizes. They probably aren't worth the effort of
+ // handling correctly.
+ unsigned TySize = DL.getTypeSizeInBits(Ty);
+ if ((TySize % 8) != 0)
continue;
- LLVM_DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n");
-
- // Process the stores in chunks of 64.
- for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) {
- unsigned Len = std::min<unsigned>(CE - CI, 64);
- ArrayRef<Instruction *> Chunk(&Chain.second[CI], Len);
- Changed |= vectorizeInstructions(Chunk);
- }
- }
-
- return Changed;
-}
-
-bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) {
- LLVM_DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size()
- << " instructions.\n");
- SmallVector<int, 16> Heads, Tails;
- int ConsecutiveChain[64];
-
- // Do a quadratic search on all of the given loads/stores and find all of the
- // pairs of loads/stores that follow each other.
- for (int i = 0, e = Instrs.size(); i < e; ++i) {
- ConsecutiveChain[i] = -1;
- for (int j = e - 1; j >= 0; --j) {
- if (i == j)
- continue;
-
- if (isConsecutiveAccess(Instrs[i], Instrs[j])) {
- if (ConsecutiveChain[i] != -1) {
- int CurDistance = std::abs(ConsecutiveChain[i] - i);
- int NewDistance = std::abs(ConsecutiveChain[i] - j);
- if (j < i || NewDistance > CurDistance)
- continue; // Should not insert.
- }
+ // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
+ // functions are currently using an integer type for the vectorized
+ // load/store, and does not support casting between the integer type and a
+ // vector of pointers (e.g. i64 to <2 x i16*>)
+ if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
+ continue;
- Tails.push_back(j);
- Heads.push_back(i);
- ConsecutiveChain[i] = j;
- }
- }
- }
+ Value *Ptr = getLoadStorePointerOperand(&I);
+ unsigned AS = Ptr->getType()->getPointerAddressSpace();
+ unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
- bool Changed = false;
- SmallPtrSet<Instruction *, 16> InstructionsProcessed;
+ unsigned VF = VecRegSize / TySize;
+ VectorType *VecTy = dyn_cast<VectorType>(Ty);
- for (int Head : Heads) {
- if (InstructionsProcessed.count(Instrs[Head]))
+ // Only handle power-of-two sized elements.
+ if ((!VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(Ty))) ||
+ (VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(VecTy->getScalarType()))))
continue;
- bool LongerChainExists = false;
- for (unsigned TIt = 0; TIt < Tails.size(); TIt++)
- if (Head == Tails[TIt] &&
- !InstructionsProcessed.count(Instrs[Heads[TIt]])) {
- LongerChainExists = true;
- break;
- }
- if (LongerChainExists)
- continue;
-
- // We found an instr that starts a chain. Now follow the chain and try to
- // vectorize it.
- SmallVector<Instruction *, 16> Operands;
- int I = Head;
- while (I != -1 && (is_contained(Tails, I) || is_contained(Heads, I))) {
- if (InstructionsProcessed.count(Instrs[I]))
- break;
-
- Operands.push_back(Instrs[I]);
- I = ConsecutiveChain[I];
- }
- bool Vectorized = false;
- if (isa<LoadInst>(*Operands.begin()))
- Vectorized = vectorizeLoadChain(Operands, &InstructionsProcessed);
- else
- Vectorized = vectorizeStoreChain(Operands, &InstructionsProcessed);
+ // No point in looking at these if they're too big to vectorize.
+ if (TySize > VecRegSize / 2 ||
+ (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
+ continue;
- Changed |= Vectorized;
+ Ret[{getUnderlyingObject(Ptr), AS,
+ DL.getTypeSizeInBits(getLoadStoreType(&I)->getScalarType()),
+ /*IsLoad=*/LI != nullptr}]
+ .push_back(&I);
}
- return Changed;
+ return Ret;
}
-bool Vectorizer::vectorizeStoreChain(
- ArrayRef<Instruction *> Chain,
- SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
- StoreInst *S0 = cast<StoreInst>(Chain[0]);
-
- // If the vector has an int element, default to int for the whole store.
- Type *StoreTy = nullptr;
- for (Instruction *I : Chain) {
- StoreTy = cast<StoreInst>(I)->getValueOperand()->getType();
- if (StoreTy->isIntOrIntVectorTy())
- break;
-
- if (StoreTy->isPtrOrPtrVectorTy()) {
- StoreTy = Type::getIntNTy(F.getParent()->getContext(),
- DL.getTypeSizeInBits(StoreTy));
- break;
- }
- }
- assert(StoreTy && "Failed to find store type");
+std::vector<Chain> Vectorizer::gatherChains(ArrayRef<Instruction *> Instrs) {
+ if (Instrs.empty())
+ return {};
- unsigned Sz = DL.getTypeSizeInBits(StoreTy);
- unsigned AS = S0->getPointerAddressSpace();
- unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
- unsigned VF = VecRegSize / Sz;
- unsigned ChainSize = Chain.size();
- Align Alignment = S0->getAlign();
+ unsigned AS = getLoadStoreAddressSpace(Instrs[0]);
+ unsigned ASPtrBits = DL.getIndexSizeInBits(AS);
- if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
- InstructionsProcessed->insert(Chain.begin(), Chain.end());
- return false;
+#ifndef NDEBUG
+ // Check that Instrs is in BB order and all have the same addr space.
+ for (size_t I = 1; I < Instrs.size(); ++I) {
+ assert(Instrs[I - 1]->comesBefore(Instrs[I]));
+ assert(getLoadStoreAddressSpace(Instrs[I]) == AS);
}
+#endif
- ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
- if (NewChain.empty()) {
- // No vectorization possible.
- InstructionsProcessed->insert(Chain.begin(), Chain.end());
- return false;
- }
- if (NewChain.size() == 1) {
- // Failed after the first instruction. Discard it and try the smaller chain.
- InstructionsProcessed->insert(NewChain.front());
- return false;
- }
-
- // Update Chain to the valid vectorizable subchain.
- Chain = NewChain;
- ChainSize = Chain.size();
-
- // Check if it's legal to vectorize this chain. If not, split the chain and
- // try again.
- unsigned EltSzInBytes = Sz / 8;
- unsigned SzInBytes = EltSzInBytes * ChainSize;
-
- FixedVectorType *VecTy;
- auto *VecStoreTy = dyn_cast<FixedVectorType>(StoreTy);
- if (VecStoreTy)
- VecTy = FixedVectorType::get(StoreTy->getScalarType(),
- Chain.size() * VecStoreTy->getNumElements());
- else
- VecTy = FixedVectorType::get(StoreTy, Chain.size());
-
- // If it's more than the max vector size or the target has a better
- // vector factor, break it into two pieces.
- unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy);
- if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
- LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
- " Creating two separate arrays.\n");
- bool Vectorized = false;
- Vectorized |=
- vectorizeStoreChain(Chain.slice(0, TargetVF), InstructionsProcessed);
- Vectorized |=
- vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed);
- return Vectorized;
- }
-
- LLVM_DEBUG({
- dbgs() << "LSV: Stores to vectorize:\n";
- for (Instruction *I : Chain)
- dbgs() << " " << *I << "\n";
- });
-
- // We won't try again to vectorize the elements of the chain, regardless of
- // whether we succeed below.
- InstructionsProcessed->insert(Chain.begin(), Chain.end());
-
- // If the store is going to be misaligned, don't vectorize it.
- unsigned RelativeSpeed;
- if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) {
- if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
- unsigned SpeedBefore;
- accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore);
- if (SpeedBefore > RelativeSpeed)
- return false;
-
- auto Chains = splitOddVectorElts(Chain, Sz);
- bool Vectorized = false;
- Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed);
- Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed);
- return Vectorized;
+ // Machinery to build an MRU-hashtable of Chains.
+ //
+ // (Ideally this could be done with MapVector, but as currently implemented,
+ // moving an element to the front of a MapVector is O(n).)
+ struct InstrListElem : ilist_node<InstrListElem>,
+ std::pair<Instruction *, Chain> {
+ explicit InstrListElem(Instruction *I)
+ : std::pair<Instruction *, Chain>(I, {}) {}
+ };
+ struct InstrListElemDenseMapInfo {
+ using PtrInfo = DenseMapInfo<InstrListElem *>;
+ using IInfo = DenseMapInfo<Instruction *>;
+ static InstrListElem *getEmptyKey() { return PtrInfo::getEmptyKey(); }
+ static InstrListElem *getTombstoneKey() {
+ return PtrInfo::getTombstoneKey();
}
-
- Align NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
- Align(StackAdjustedAlignment),
- DL, S0, nullptr, &DT);
- if (NewAlign >= Alignment)
- Alignment = NewAlign;
- else
- return false;
- }
-
- if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) {
- auto Chains = splitOddVectorElts(Chain, Sz);
- bool Vectorized = false;
- Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed);
- Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed);
- return Vectorized;
- }
-
- BasicBlock::iterator First, Last;
- std::tie(First, Last) = getBoundaryInstrs(Chain);
- Builder.SetInsertPoint(&*Last);
-
- Value *Vec = PoisonValue::get(VecTy);
-
- if (VecStoreTy) {
- unsigned VecWidth = VecStoreTy->getNumElements();
- for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
- StoreInst *Store = cast<StoreInst>(Chain[I]);
- for (unsigned J = 0, NE = VecStoreTy->getNumElements(); J != NE; ++J) {
- unsigned NewIdx = J + I * VecWidth;
- Value *Extract = Builder.CreateExtractElement(Store->getValueOperand(),
- Builder.getInt32(J));
- if (Extract->getType() != StoreTy->getScalarType())
- Extract = Builder.CreateBitCast(Extract, StoreTy->getScalarType());
-
- Value *Insert =
- Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(NewIdx));
- Vec = Insert;
- }
+ static unsigned getHashValue(const InstrListElem *E) {
+ return IInfo::getHashValue(E->first);
}
- } else {
- for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
- StoreInst *Store = cast<StoreInst>(Chain[I]);
- Value *Extract = Store->getValueOperand();
- if (Extract->getType() != StoreTy->getScalarType())
- Extract =
- Builder.CreateBitOrPointerCast(Extract, StoreTy->getScalarType());
-
- Value *Insert =
- Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(I));
- Vec = Insert;
+ static bool isEqual(const InstrListElem *A, const InstrListElem *B) {
+ if (A == getEmptyKey() || B == getEmptyKey())
+ return A == getEmptyKey() && B == getEmptyKey();
+ if (A == getTombstoneKey() || B == getTombstoneKey())
+ return A == getTombstoneKey() && B == getTombstoneKey();
+ return IInfo::isEqual(A->first, B->first);
}
- }
-
- StoreInst *SI = Builder.CreateAlignedStore(
- Vec,
- Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)),
- Alignment);
- propagateMetadata(SI, Chain);
-
- eraseInstructions(Chain);
- ++NumVectorInstructions;
- NumScalarsVectorized += Chain.size();
- return true;
-}
-
-bool Vectorizer::vectorizeLoadChain(
- ArrayRef<Instruction *> Chain,
- SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
- LoadInst *L0 = cast<LoadInst>(Chain[0]);
-
- // If the vector has an int element, default to int for the whole load.
- Type *LoadTy = nullptr;
- for (const auto &V : Chain) {
- LoadTy = cast<LoadInst>(V)->getType();
- if (LoadTy->isIntOrIntVectorTy())
- break;
-
- if (LoadTy->isPtrOrPtrVectorTy()) {
- LoadTy = Type::getIntNTy(F.getParent()->getContext(),
- DL.getTypeSizeInBits(LoadTy));
- break;
+ };
+ SpecificBumpPtrAllocator<InstrListElem> Allocator;
+ simple_ilist<InstrListElem> MRU;
+ DenseSet<InstrListElem *, InstrListElemDenseMapInfo> Chains;
+
+ // Compare each instruction in Instrs to the leaders of the N most
+ // recently-used chains. This limits the O(n^2) behavior of this pass while
+ // still allowing us to build arbitrarily long chains.
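+ //
+ // For instance, with MaxChainsToTry == 64 (declared below), each instruction
+ // is compared against at most 64 chain leaders, so the number of
+ // getConstantOffset queries grows linearly in |Instrs| rather than
+ // quadratically.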
+ for (Instruction *I : Instrs) {
+ constexpr int MaxChainsToTry = 64;
+
+ bool MatchFound = false;
+ auto ChainIter = MRU.begin();
+ for (size_t J = 0; J < MaxChainsToTry && ChainIter != MRU.end();
+ ++J, ++ChainIter) {
+ std::optional<APInt> Offset = getConstantOffset(
+ getLoadStorePointerOperand(ChainIter->first),
+ getLoadStorePointerOperand(I),
+ /*ContextInst=*/
+ (ChainIter->first->comesBefore(I) ? I : ChainIter->first));
+ if (Offset.has_value()) {
+ // `Offset` might not have the expected number of bits, if e.g. AS has a
+ // different number of bits than opaque pointers.
+ ChainIter->second.push_back(ChainElem{I, Offset.value()});
+ // Move ChainIter to the front of the MRU list.
+ MRU.remove(*ChainIter);
+ MRU.push_front(*ChainIter);
+ MatchFound = true;
+ break;
+ }
}
- }
- assert(LoadTy && "Can't determine LoadInst type from chain");
-
- unsigned Sz = DL.getTypeSizeInBits(LoadTy);
- unsigned AS = L0->getPointerAddressSpace();
- unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
- unsigned VF = VecRegSize / Sz;
- unsigned ChainSize = Chain.size();
- Align Alignment = L0->getAlign();
-
- if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
- InstructionsProcessed->insert(Chain.begin(), Chain.end());
- return false;
- }
-
- ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
- if (NewChain.empty()) {
- // No vectorization possible.
- InstructionsProcessed->insert(Chain.begin(), Chain.end());
- return false;
- }
- if (NewChain.size() == 1) {
- // Failed after the first instruction. Discard it and try the smaller chain.
- InstructionsProcessed->insert(NewChain.front());
- return false;
- }
- // Update Chain to the valid vectorizable subchain.
- Chain = NewChain;
- ChainSize = Chain.size();
-
- // Check if it's legal to vectorize this chain. If not, split the chain and
- // try again.
- unsigned EltSzInBytes = Sz / 8;
- unsigned SzInBytes = EltSzInBytes * ChainSize;
- VectorType *VecTy;
- auto *VecLoadTy = dyn_cast<FixedVectorType>(LoadTy);
- if (VecLoadTy)
- VecTy = FixedVectorType::get(LoadTy->getScalarType(),
- Chain.size() * VecLoadTy->getNumElements());
- else
- VecTy = FixedVectorType::get(LoadTy, Chain.size());
-
- // If it's more than the max vector size or the target has a better
- // vector factor, break it into two pieces.
- unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy);
- if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
- LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
- " Creating two separate arrays.\n");
- bool Vectorized = false;
- Vectorized |=
- vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed);
- Vectorized |=
- vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed);
- return Vectorized;
- }
-
- // We won't try again to vectorize the elements of the chain, regardless of
- // whether we succeed below.
- InstructionsProcessed->insert(Chain.begin(), Chain.end());
-
- // If the load is going to be misaligned, don't vectorize it.
- unsigned RelativeSpeed;
- if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) {
- if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
- unsigned SpeedBefore;
- accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore);
- if (SpeedBefore > RelativeSpeed)
- return false;
-
- auto Chains = splitOddVectorElts(Chain, Sz);
- bool Vectorized = false;
- Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed);
- Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed);
- return Vectorized;
+ if (!MatchFound) {
+ APInt ZeroOffset(ASPtrBits, 0);
+ InstrListElem *E = new (Allocator.Allocate()) InstrListElem(I);
+ E->second.push_back(ChainElem{I, ZeroOffset});
+ MRU.push_front(*E);
+ Chains.insert(E);
}
-
- Align NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(),
- Align(StackAdjustedAlignment),
- DL, L0, nullptr, &DT);
- if (NewAlign >= Alignment)
- Alignment = NewAlign;
- else
- return false;
}
- if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) {
- auto Chains = splitOddVectorElts(Chain, Sz);
- bool Vectorized = false;
- Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed);
- Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed);
- return Vectorized;
- }
+ std::vector<Chain> Ret;
+ Ret.reserve(Chains.size());
+ // Iterate over MRU rather than Chains so the order is deterministic.
+ for (auto &E : MRU)
+ if (E.second.size() > 1)
+ Ret.push_back(std::move(E.second));
+ return Ret;
+}
- LLVM_DEBUG({
- dbgs() << "LSV: Loads to vectorize:\n";
- for (Instruction *I : Chain)
- I->dump();
- });
+std::optional<APInt> Vectorizer::getConstantOffset(Value *PtrA, Value *PtrB,
+ Instruction *ContextInst,
+ unsigned Depth) {
+ LLVM_DEBUG(dbgs() << "LSV: getConstantOffset, PtrA=" << *PtrA
+ << ", PtrB=" << *PtrB << ", ContextInst= " << *ContextInst
+ << ", Depth=" << Depth << "\n");
+ // We'll ultimately return a value of this bit width, even if computations
+ // happen in a different width.
+ unsigned OrigBitWidth = DL.getIndexTypeSizeInBits(PtrA->getType());
+ APInt OffsetA(OrigBitWidth, 0);
+ APInt OffsetB(OrigBitWidth, 0);
+ PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
+ PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
+ unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
+ if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
+ return std::nullopt;
- // getVectorizablePrefix already computed getBoundaryInstrs. The value of
- // Last may have changed since then, but the value of First won't have. If it
- // matters, we could compute getBoundaryInstrs only once and reuse it here.
- BasicBlock::iterator First, Last;
- std::tie(First, Last) = getBoundaryInstrs(Chain);
- Builder.SetInsertPoint(&*First);
-
- Value *Bitcast =
- Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS));
- LoadInst *LI =
- Builder.CreateAlignedLoad(VecTy, Bitcast, MaybeAlign(Alignment));
- propagateMetadata(LI, Chain);
-
- for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
- Value *CV = Chain[I];
- Value *V;
- if (VecLoadTy) {
- // Extract a subvector using shufflevector.
- unsigned VecWidth = VecLoadTy->getNumElements();
- auto Mask =
- llvm::to_vector<8>(llvm::seq<int>(I * VecWidth, (I + 1) * VecWidth));
- V = Builder.CreateShuffleVector(LI, Mask, CV->getName());
- } else {
- V = Builder.CreateExtractElement(LI, Builder.getInt32(I), CV->getName());
- }
+ // If we have to shrink the pointer, stripAndAccumulateInBoundsConstantOffsets
+ // should properly handle a possible overflow and the value should fit into
+ // the smallest data type used in the cast/gep chain.
+ assert(OffsetA.getSignificantBits() <= NewPtrBitWidth &&
+ OffsetB.getSignificantBits() <= NewPtrBitWidth);
- if (V->getType() != CV->getType()) {
- V = Builder.CreateBitOrPointerCast(V, CV->getType());
+ OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
+ OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
+ if (PtrA == PtrB)
+ return (OffsetB - OffsetA).sextOrTrunc(OrigBitWidth);
+
+ // Try to compute B - A.
+ const SCEV *DistScev = SE.getMinusSCEV(SE.getSCEV(PtrB), SE.getSCEV(PtrA));
+ if (DistScev != SE.getCouldNotCompute()) {
+ LLVM_DEBUG(dbgs() << "LSV: SCEV PtrB - PtrA =" << *DistScev << "\n");
+ ConstantRange DistRange = SE.getSignedRange(DistScev);
+ if (DistRange.isSingleElement()) {
+ // Handle index width (the width of Dist) != pointer width (the width of
+ // the Offset*s at this point).
+ APInt Dist = DistRange.getSingleElement()->sextOrTrunc(NewPtrBitWidth);
+ return (OffsetB - OffsetA + Dist).sextOrTrunc(OrigBitWidth);
}
-
- // Replace the old instruction.
- CV->replaceAllUsesWith(V);
}
-
- // Since we might have opaque pointers we might end up using the pointer
- // operand of the first load (wrt. memory loaded) for the vector load. Since
- // this first load might not be the first in the block we potentially need to
- // reorder the pointer operand (and its operands). If we have a bitcast though
- // it might be before the load and should be the reorder start instruction.
- // "Might" because for opaque pointers the "bitcast" is just the first loads
- // pointer operand, as opposed to something we inserted at the right position
- // ourselves.
- Instruction *BCInst = dyn_cast<Instruction>(Bitcast);
- reorder((BCInst && BCInst != L0->getPointerOperand()) ? BCInst : LI);
-
- eraseInstructions(Chain);
-
- ++NumVectorInstructions;
- NumScalarsVectorized += Chain.size();
- return true;
-}
-
-bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
- Align Alignment, unsigned &RelativeSpeed) {
- RelativeSpeed = 0;
- if (Alignment.value() % SzInBytes == 0)
- return false;
-
- bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(),
- SzInBytes * 8, AddressSpace,
- Alignment, &RelativeSpeed);
- LLVM_DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows
- << " with relative speed = " << RelativeSpeed << '\n';);
- return !Allows || !RelativeSpeed;
+ std::optional<APInt> Diff =
+ getConstantOffsetComplexAddrs(PtrA, PtrB, ContextInst, Depth);
+ if (Diff.has_value())
+ return (OffsetB - OffsetA + Diff->sext(OffsetB.getBitWidth()))
+ .sextOrTrunc(OrigBitWidth);
+ return std::nullopt;
}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index cd48c0d57eb3..f923f0be6621 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -37,6 +37,11 @@ static cl::opt<bool>
EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden,
cl::desc("Enable if-conversion during vectorization."));
+static cl::opt<bool>
+AllowStridedPointerIVs("lv-strided-pointer-ivs", cl::init(false), cl::Hidden,
+ cl::desc("Enable recognition of non-constant strided "
+ "pointer induction variables."));
+
namespace llvm {
cl::opt<bool>
HintsAllowReordering("hints-allow-reordering", cl::init(true), cl::Hidden,
@@ -447,8 +452,12 @@ static bool storeToSameAddress(ScalarEvolution *SE, StoreInst *A,
int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy,
Value *Ptr) const {
- const ValueToValueMap &Strides =
- getSymbolicStrides() ? *getSymbolicStrides() : ValueToValueMap();
+ // FIXME: Currently, the set of symbolic strides is sometimes queried before
+ // it's collected. This happens from canVectorizeWithIfConvert, when the
+ // pointer is checked to reference consecutive elements suitable for a
+ // masked access.
+ const auto &Strides =
+ LAI ? LAI->getSymbolicStrides() : DenseMap<Value *, const SCEV *>();
Function *F = TheLoop->getHeader()->getParent();
bool OptForSize = F->hasOptSize() ||
@@ -462,11 +471,135 @@ int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy,
return 0;
}
-bool LoopVectorizationLegality::isUniform(Value *V) const {
- return LAI->isUniform(V);
+bool LoopVectorizationLegality::isInvariant(Value *V) const {
+ return LAI->isInvariant(V);
+}
+
+namespace {
+/// A rewriter to build the SCEVs for each of the VF lanes in the expected
+/// vectorized loop, which can then be compared to detect their uniformity. This
+/// is done by replacing the AddRec SCEVs of the original scalar loop (TheLoop)
+/// with new AddRecs where the step is multiplied by StepMultiplier and Offset *
+/// Step is added. Also checks if all sub-expressions are analyzable w.r.t.
+/// uniformity.
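+///
+/// For example (sketch): with StepMultiplier == VF and Offset == Lane, an
+/// AddRec {Start,+,Step} of TheLoop is rewritten to
+/// {Start + Lane * Step,+,VF * Step}, i.e. the SCEV of the value that lane
+/// Lane would compute in the vectorized loop.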
+class SCEVAddRecForUniformityRewriter
+ : public SCEVRewriteVisitor<SCEVAddRecForUniformityRewriter> {
+ /// Multiplier to be applied to the step of AddRecs in TheLoop.
+ unsigned StepMultiplier;
+
+ /// Offset to be added to the AddRecs in TheLoop.
+ unsigned Offset;
+
+ /// Loop for which to rewrite AddRecs.
+ Loop *TheLoop;
+
+ /// Are any sub-expressions not analyzable w.r.t. uniformity?
+ bool CannotAnalyze = false;
+
+ bool canAnalyze() const { return !CannotAnalyze; }
+
+public:
+ SCEVAddRecForUniformityRewriter(ScalarEvolution &SE, unsigned StepMultiplier,
+ unsigned Offset, Loop *TheLoop)
+ : SCEVRewriteVisitor(SE), StepMultiplier(StepMultiplier), Offset(Offset),
+ TheLoop(TheLoop) {}
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ assert(Expr->getLoop() == TheLoop &&
+ "addrec outside of TheLoop must be invariant and should have been "
+ "handled earlier");
+ // Build a new AddRec by multiplying the step by StepMultiplier and
+ // incrementing the start by Offset * step.
+ Type *Ty = Expr->getType();
+ auto *Step = Expr->getStepRecurrence(SE);
+ if (!SE.isLoopInvariant(Step, TheLoop)) {
+ CannotAnalyze = true;
+ return Expr;
+ }
+ auto *NewStep = SE.getMulExpr(Step, SE.getConstant(Ty, StepMultiplier));
+ auto *ScaledOffset = SE.getMulExpr(Step, SE.getConstant(Ty, Offset));
+ auto *NewStart = SE.getAddExpr(Expr->getStart(), ScaledOffset);
+ return SE.getAddRecExpr(NewStart, NewStep, TheLoop, SCEV::FlagAnyWrap);
+ }
+
+ const SCEV *visit(const SCEV *S) {
+ if (CannotAnalyze || SE.isLoopInvariant(S, TheLoop))
+ return S;
+ return SCEVRewriteVisitor<SCEVAddRecForUniformityRewriter>::visit(S);
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *S) {
+ if (SE.isLoopInvariant(S, TheLoop))
+ return S;
+ // The value could vary across iterations.
+ CannotAnalyze = true;
+ return S;
+ }
+
+ const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *S) {
+ // Could not analyze the expression.
+ CannotAnalyze = true;
+ return S;
+ }
+
+ static const SCEV *rewrite(const SCEV *S, ScalarEvolution &SE,
+ unsigned StepMultiplier, unsigned Offset,
+ Loop *TheLoop) {
+ /// Bail out if the expression does not contain a UDiv expression. Uniform
+ /// values which are not loop invariant require operations to strip out the
+ /// lowest bits. For now just look for UDivs and use that to avoid re-writing
+ /// UDiv-free expressions for other lanes, to limit compile time.
+ if (!SCEVExprContains(S,
+ [](const SCEV *S) { return isa<SCEVUDivExpr>(S); }))
+ return SE.getCouldNotCompute();
+
+ SCEVAddRecForUniformityRewriter Rewriter(SE, StepMultiplier, Offset,
+ TheLoop);
+ const SCEV *Result = Rewriter.visit(S);
+
+ if (Rewriter.canAnalyze())
+ return Result;
+ return SE.getCouldNotCompute();
+ }
+};
+
+} // namespace
+
+bool LoopVectorizationLegality::isUniform(Value *V, ElementCount VF) const {
+ if (isInvariant(V))
+ return true;
+ if (VF.isScalable())
+ return false;
+ if (VF.isScalar())
+ return true;
+
+ // Since we rely on SCEV for uniformity, if the type is not SCEVable, it is
+ // never considered uniform.
+ auto *SE = PSE.getSE();
+ if (!SE->isSCEVable(V->getType()))
+ return false;
+ const SCEV *S = SE->getSCEV(V);
+
+ // Rewrite AddRecs in TheLoop to step by VF and check if the expression for
+ // lane 0 matches the expressions for all other lanes.
+ unsigned FixedVF = VF.getKnownMinValue();
+ const SCEV *FirstLaneExpr =
+ SCEVAddRecForUniformityRewriter::rewrite(S, *SE, FixedVF, 0, TheLoop);
+ if (isa<SCEVCouldNotCompute>(FirstLaneExpr))
+ return false;
+
+ // Make sure the expressions for lanes FixedVF-1..1 match the expression for
+ // lane 0. We check lanes in reverse order to save compile time, as checking
+ // the last lane first is frequently sufficient to rule out uniformity.
+ return all_of(reverse(seq<unsigned>(1, FixedVF)), [&](unsigned I) {
+ const SCEV *IthLaneExpr =
+ SCEVAddRecForUniformityRewriter::rewrite(S, *SE, FixedVF, I, TheLoop);
+ return FirstLaneExpr == IthLaneExpr;
+ });
}
-bool LoopVectorizationLegality::isUniformMemOp(Instruction &I) const {
+bool LoopVectorizationLegality::isUniformMemOp(Instruction &I,
+ ElementCount VF) const {
Value *Ptr = getLoadStorePointerOperand(&I);
if (!Ptr)
return false;
@@ -474,7 +607,7 @@ bool LoopVectorizationLegality::isUniformMemOp(Instruction &I) const {
// stores from being uniform. The current lowering simply doesn't handle
// it; in particular, the cost model distinguishes scatter/gather from
// scalar w/predication, and we currently rely on the scalar path.
- return isUniform(Ptr) && !blockNeedsPredication(I.getParent());
+ return isUniform(Ptr, VF) && !blockNeedsPredication(I.getParent());
}
bool LoopVectorizationLegality::canVectorizeOuterLoop() {
@@ -700,6 +833,18 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
continue;
}
+ // We prevent matching non-constant strided pointer IVs to preserve
+ // historical vectorizer behavior after a generalization of the
+ // IVDescriptor code. The intent is to remove this check, but we
+ // have to fix issues around code quality for such loops first.
+ auto isDisallowedStridedPointerInduction =
+ [](const InductionDescriptor &ID) {
+ if (AllowStridedPointerIVs)
+ return false;
+ return ID.getKind() == InductionDescriptor::IK_PtrInduction &&
+ ID.getConstIntStepValue() == nullptr;
+ };
+
// TODO: Instead of recording the AllowedExit, it would be good to
// record the complementary set: NotAllowedExit. These include (but may
// not be limited to):
@@ -715,14 +860,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// By recording these, we can then reason about ways to vectorize each
// of these NotAllowedExit.
InductionDescriptor ID;
- if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) {
+ if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID) &&
+ !isDisallowedStridedPointerInduction(ID)) {
addInductionPhi(Phi, ID, AllowedExit);
Requirements->addExactFPMathInst(ID.getExactFPMathInst());
continue;
}
- if (RecurrenceDescriptor::isFixedOrderRecurrence(Phi, TheLoop,
- SinkAfter, DT)) {
+ if (RecurrenceDescriptor::isFixedOrderRecurrence(Phi, TheLoop, DT)) {
AllowedExit.insert(Phi);
FixedOrderRecurrences.insert(Phi);
continue;
@@ -730,7 +875,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// As a last resort, coerce the PHI to an AddRec expression
// and re-try classifying it as an induction PHI.
- if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true)) {
+ if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true) &&
+ !isDisallowedStridedPointerInduction(ID)) {
addInductionPhi(Phi, ID, AllowedExit);
continue;
}
@@ -894,18 +1040,6 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
}
}
- // For fixed order recurrences, we use the previous value (incoming value from
- // the latch) to check if it dominates all users of the recurrence. Bail out
- // if we have to sink such an instruction for another recurrence, as the
- // dominance requirement may not hold after sinking.
- BasicBlock *LoopLatch = TheLoop->getLoopLatch();
- if (any_of(FixedOrderRecurrences, [LoopLatch, this](const PHINode *Phi) {
- Instruction *V =
- cast<Instruction>(Phi->getIncomingValueForBlock(LoopLatch));
- return SinkAfter.find(V) != SinkAfter.end();
- }))
- return false;
-
// Now we know the widest induction type, check if our found induction
// is the same size. If it's not, unset it here and InnerLoopVectorizer
// will create another.
@@ -1124,6 +1258,16 @@ bool LoopVectorizationLegality::blockCanBePredicated(
if (isa<NoAliasScopeDeclInst>(&I))
continue;
+ // We can allow masked calls if there's at least one vector variant, even
+ // if we end up scalarizing due to the cost model calculations.
+ // TODO: Allow other calls if they have appropriate attributes... readonly
+ // and argmemonly?
+ if (CallInst *CI = dyn_cast<CallInst>(&I))
+ if (VFDatabase::hasMaskedVariant(*CI)) {
+ MaskedOp.insert(CI);
+ continue;
+ }
+
// Loads are handled via masking (or speculated if safe to do so.)
if (auto *LI = dyn_cast<LoadInst>(&I)) {
if (!SafePtrs.count(LI->getPointerOperand()))
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 8990a65afdb4..13357cb06c55 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -25,6 +25,7 @@
#define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONPLANNER_H
#include "VPlan.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/InstructionCost.h"
namespace llvm {
@@ -217,6 +218,16 @@ struct VectorizationFactor {
}
};
+/// ElementCountComparator creates a total ordering for ElementCount
+/// for the purposes of using it in a set structure.
+struct ElementCountComparator {
+ bool operator()(const ElementCount &LHS, const ElementCount &RHS) const {
+ return std::make_tuple(LHS.isScalable(), LHS.getKnownMinValue()) <
+ std::make_tuple(RHS.isScalable(), RHS.getKnownMinValue());
+ }
+};
+using ElementCountSet = SmallSet<ElementCount, 16, ElementCountComparator>;
+
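For illustration only (not part of the patch): a standalone analogue of ElementCountComparator above, using a minimal stand-in struct rather than llvm::ElementCount. The ordering places all fixed-width factors before scalable ones and breaks ties on the known minimum element count, which is what allows ElementCount to be used as a set key.

// Minimal stand-in type; only the ordering is being demonstrated.
#include <cassert>
#include <set>
#include <tuple>

struct EC { bool Scalable; unsigned MinVal; };

struct ECComparator {
  bool operator()(const EC &LHS, const EC &RHS) const {
    return std::make_tuple(LHS.Scalable, LHS.MinVal) <
           std::make_tuple(RHS.Scalable, RHS.MinVal);
  }
};

int main() {
  std::set<EC, ECComparator> VFs = {{false, 8}, {true, 4}, {false, 4}};
  // Iteration order: fixed 4, fixed 8, then scalable 4 (i.e. vscale x 4).
  auto It = VFs.begin();
  assert(!It->Scalable && It->MinVal == 4); ++It;
  assert(!It->Scalable && It->MinVal == 8); ++It;
  assert(It->Scalable && It->MinVal == 4);
  return 0;
}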
/// A class that represents two vectorization factors (initialized with 0 by
/// default). One for fixed-width vectorization and one for scalable
/// vectorization. This can be used by the vectorizer to choose from a range of
@@ -261,7 +272,7 @@ class LoopVectorizationPlanner {
const TargetLibraryInfo *TLI;
/// Target Transform Info.
- const TargetTransformInfo *TTI;
+ const TargetTransformInfo &TTI;
/// The legality analysis.
LoopVectorizationLegality *Legal;
@@ -280,12 +291,15 @@ class LoopVectorizationPlanner {
SmallVector<VPlanPtr, 4> VPlans;
+ /// Profitable vector factors.
+ SmallVector<VectorizationFactor, 8> ProfitableVFs;
+
/// A builder used to construct the current plan.
VPBuilder Builder;
public:
LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI,
- const TargetTransformInfo *TTI,
+ const TargetTransformInfo &TTI,
LoopVectorizationLegality *Legal,
LoopVectorizationCostModel &CM,
InterleavedAccessInfo &IAI,
@@ -311,16 +325,22 @@ public:
/// TODO: \p IsEpilogueVectorization is needed to avoid issues due to epilogue
/// vectorization re-using plans for both the main and epilogue vector loops.
/// It should be removed once the re-use issue has been fixed.
- void executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan,
- InnerLoopVectorizer &LB, DominatorTree *DT,
- bool IsEpilogueVectorization);
+ /// \p ExpandedSCEVs is passed during execution of the plan for the epilogue
+ /// loop to re-use expansion results generated during main plan execution.
+ /// Returns a mapping of SCEVs to their expanded IR values. Note that this is
+ /// a temporary workaround needed due to the current epilogue handling.
+ DenseMap<const SCEV *, Value *>
+ executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan,
+ InnerLoopVectorizer &LB, DominatorTree *DT,
+ bool IsEpilogueVectorization,
+ DenseMap<const SCEV *, Value *> *ExpandedSCEVs = nullptr);
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void printPlans(raw_ostream &O);
#endif
- /// Look through the existing plans and return true if we have one with all
- /// the vectorization factors in question.
+ /// Look through the existing plans and return true if we have one with
+ /// vectorization factor \p VF.
bool hasPlanWithVF(ElementCount VF) const {
return any_of(VPlans,
[&](const VPlanPtr &Plan) { return Plan->hasVF(VF); });
@@ -333,8 +353,11 @@ public:
getDecisionAndClampRange(const std::function<bool(ElementCount)> &Predicate,
VFRange &Range);
- /// Check if the number of runtime checks exceeds the threshold.
- bool requiresTooManyRuntimeChecks() const;
+ /// \return The most profitable vectorization factor and the cost of that VF
+ /// for vectorizing the epilogue. Returns VectorizationFactor::Disabled if
+ /// epilogue vectorization is not supported for the loop.
+ VectorizationFactor
+ selectEpilogueVectorizationFactor(const ElementCount MaxVF, unsigned IC);
protected:
/// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
@@ -350,9 +373,12 @@ private:
/// Build a VPlan using VPRecipes according to the information gathered by
/// Legal. This method is only used for the legacy inner loop vectorizer.
- VPlanPtr buildVPlanWithVPRecipes(
- VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions,
- const MapVector<Instruction *, Instruction *> &SinkAfter);
+ /// \p Range's largest included VF is restricted to the maximum VF the
+ /// returned VPlan is valid for. If no VPlan can be built for the input range,
+ /// set the largest included VF to the maximum VF for which no plan could be
+ /// built.
+ std::optional<VPlanPtr> tryToBuildVPlanWithVPRecipes(
+ VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions);
/// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
/// according to the information gathered by Legal when it checked if it is
@@ -367,6 +393,20 @@ private:
void adjustRecipesForReductions(VPBasicBlock *LatchVPBB, VPlanPtr &Plan,
VPRecipeBuilder &RecipeBuilder,
ElementCount MinVF);
+
+ /// \return The most profitable vectorization factor and the cost of that VF.
+ /// This method checks every VF in \p CandidateVFs.
+ VectorizationFactor
+ selectVectorizationFactor(const ElementCountSet &CandidateVFs);
+
+ /// Returns true if the per-lane cost of VectorizationFactor A is lower than
+ /// that of B.
+ bool isMoreProfitable(const VectorizationFactor &A,
+ const VectorizationFactor &B) const;
+
+ /// Determines if we have the infrastructure to vectorize the loop and its
+ /// epilogue, assuming the main loop is vectorized by \p VF.
+ bool isCandidateForEpilogueVectorization(const ElementCount VF) const;
};
} // namespace llvm
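isMoreProfitable above is documented to compare per-lane cost; as a hedged, standalone sketch (the function name and numbers below are invented for illustration, not taken from the patch), comparing CostA * WidthB against CostB * WidthA answers CostA / WidthA < CostB / WidthB without lossy integer division, assuming the products fit in 64 bits.

#include <cassert>
#include <cstdint>

// Is A (CostA over WidthA lanes) cheaper per lane than B (CostB over WidthB)?
static bool cheaperPerLane(uint64_t CostA, uint64_t WidthA,
                           uint64_t CostB, uint64_t WidthB) {
  return CostA * WidthB < CostB * WidthA;
}

int main() {
  // VF=8 at cost 36 (4.5 per lane) beats VF=4 at cost 20 (5 per lane).
  assert(cheaperPerLane(/*CostA=*/36, /*WidthA=*/8, /*CostB=*/20, /*WidthB=*/4));
  assert(!cheaperPerLane(20, 4, 36, 8));
  return 0;
}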
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a28099d8ba7d..d7e40e8ef978 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -98,6 +98,7 @@
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
@@ -120,8 +121,6 @@
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/IR/Verifier.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
@@ -231,6 +230,25 @@ static cl::opt<PreferPredicateTy::Option> PreferPredicateOverEpilogue(
"prefers tail-folding, don't attempt vectorization if "
"tail-folding fails.")));
+static cl::opt<TailFoldingStyle> ForceTailFoldingStyle(
+ "force-tail-folding-style", cl::desc("Force the tail folding style"),
+ cl::init(TailFoldingStyle::None),
+ cl::values(
+ clEnumValN(TailFoldingStyle::None, "none", "Disable tail folding"),
+ clEnumValN(
+ TailFoldingStyle::Data, "data",
+ "Create lane mask for data only, using active.lane.mask intrinsic"),
+ clEnumValN(TailFoldingStyle::DataWithoutLaneMask,
+ "data-without-lane-mask",
+ "Create lane mask with compare/stepvector"),
+ clEnumValN(TailFoldingStyle::DataAndControlFlow, "data-and-control",
+ "Create lane mask using active.lane.mask intrinsic, and use "
+ "it for both data and control flow"),
+ clEnumValN(
+ TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck,
+ "data-and-control-without-rt-check",
+ "Similar to data-and-control, but remove the runtime check")));
+
static cl::opt<bool> MaximizeBandwidth(
"vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden,
cl::desc("Maximize bandwidth when selecting vectorization factor which "
@@ -338,10 +356,12 @@ static cl::opt<bool> PreferPredicatedReductionSelect(
cl::desc(
"Prefer predicating a reduction operation over an after loop select."));
+namespace llvm {
cl::opt<bool> EnableVPlanNativePath(
- "enable-vplan-native-path", cl::init(false), cl::Hidden,
+ "enable-vplan-native-path", cl::Hidden,
cl::desc("Enable VPlan-native vectorization path with "
"support for outer loop vectorization."));
+}
// This flag enables the stress testing of the VPlan H-CFG construction in the
// VPlan-native vectorization path. It must be used in conjunction with
@@ -419,9 +439,42 @@ static std::optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE,
return std::nullopt;
}
+/// Return a vector containing interleaved elements from multiple
+/// smaller input vectors.
+static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
+ const Twine &Name) {
+ unsigned Factor = Vals.size();
+ assert(Factor > 1 && "Tried to interleave invalid number of vectors");
+
+ VectorType *VecTy = cast<VectorType>(Vals[0]->getType());
+#ifndef NDEBUG
+ for (Value *Val : Vals)
+ assert(Val->getType() == VecTy && "Tried to interleave mismatched types");
+#endif
+
+ // Scalable vectors cannot use arbitrary shufflevectors (only splats), so
+ // must use intrinsics to interleave.
+ if (VecTy->isScalableTy()) {
+ VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy);
+ return Builder.CreateIntrinsic(
+ WideVecTy, Intrinsic::experimental_vector_interleave2, Vals,
+ /*FMFSource=*/nullptr, Name);
+ }
+
+ // Fixed length. Start by concatenating all vectors into a wide vector.
+ Value *WideVec = concatenateVectors(Builder, Vals);
+
+ // Interleave the elements into the wide vector.
+ const unsigned NumElts = VecTy->getElementCount().getFixedValue();
+ return Builder.CreateShuffleVector(
+ WideVec, createInterleaveMask(NumElts, Factor), Name);
+}
+
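For the fixed-width fallback in interleaveVectors above, the wide concatenated vector is reordered by an interleave shuffle mask; below is a standalone sketch of what such a mask looks like (the mask is computed locally here rather than by calling the LLVM helper): element I of input vector J ends up at result position I * Factor + J.

// With NumElts = 4 and Factor = 2, inputs concatenated as <A0..A3, B0..B3>
// and mask <0, 4, 1, 5, 2, 6, 3, 7> give <A0, B0, A1, B1, A2, B2, A3, B3>.
#include <cassert>
#include <vector>

static std::vector<unsigned> makeInterleaveMask(unsigned NumElts,
                                                unsigned Factor) {
  std::vector<unsigned> Mask;
  for (unsigned I = 0; I < NumElts; ++I)
    for (unsigned J = 0; J < Factor; ++J)
      Mask.push_back(J * NumElts + I); // take element I of input vector J
  return Mask;
}

int main() {
  const std::vector<unsigned> Expected = {0, 4, 1, 5, 2, 6, 3, 7};
  assert(makeInterleaveMask(4, 2) == Expected);
  return 0;
}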
namespace {
// Forward declare GeneratedRTChecks.
class GeneratedRTChecks;
+
+using SCEV2ValueTy = DenseMap<const SCEV *, Value *>;
} // namespace
namespace llvm {
@@ -477,8 +530,10 @@ public:
/// loop and the start value for the canonical induction, if it is != 0. The
/// latter is the case when vectorizing the epilogue loop. In the case of
/// epilogue vectorization, this function is overridden to handle the more
- /// complex control flow around the loops.
- virtual std::pair<BasicBlock *, Value *> createVectorizedLoopSkeleton();
+ /// complex control flow around the loops. \p ExpandedSCEVs is used to
+ /// look up SCEV expansions for expressions needed during skeleton creation.
+ virtual std::pair<BasicBlock *, Value *>
+ createVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs);
/// Fix the vectorized code, taking care of header phi's, live-outs, and more.
void fixVectorizedLoop(VPTransformState &State, VPlan &Plan);
@@ -498,7 +553,7 @@ public:
/// Instr's operands.
void scalarizeInstruction(const Instruction *Instr,
VPReplicateRecipe *RepRecipe,
- const VPIteration &Instance, bool IfPredicateInstr,
+ const VPIteration &Instance,
VPTransformState &State);
/// Construct the vector value of a scalarized value \p V one lane at a time.
@@ -513,7 +568,7 @@ public:
ArrayRef<VPValue *> VPDefs,
VPTransformState &State, VPValue *Addr,
ArrayRef<VPValue *> StoredValues,
- VPValue *BlockInMask = nullptr);
+ VPValue *BlockInMask, bool NeedsMaskForGaps);
/// Fix the non-induction PHIs in \p Plan.
void fixNonInductionPHIs(VPlan &Plan, VPTransformState &State);
@@ -522,28 +577,30 @@ public:
/// able to vectorize with strict in-order reductions for the given RdxDesc.
bool useOrderedReductions(const RecurrenceDescriptor &RdxDesc);
- /// Create a broadcast instruction. This method generates a broadcast
- /// instruction (shuffle) for loop invariant values and for the induction
- /// value. If this is the induction variable then we extend it to N, N+1, ...
- /// this is needed because each iteration in the loop corresponds to a SIMD
- /// element.
- virtual Value *getBroadcastInstrs(Value *V);
-
// Returns the resume value (bc.merge.rdx) for a reduction as
// generated by fixReduction.
PHINode *getReductionResumeValue(const RecurrenceDescriptor &RdxDesc);
/// Create a new phi node for the induction variable \p OrigPhi to resume
/// iteration count in the scalar epilogue, from where the vectorized loop
- /// left off. In cases where the loop skeleton is more complicated (eg.
- /// epilogue vectorization) and the resume values can come from an additional
- /// bypass block, the \p AdditionalBypass pair provides information about the
- /// bypass block and the end value on the edge from bypass to this loop.
+ /// left off. \p Step is the SCEV-expanded induction step to use. In cases
+ /// where the loop skeleton is more complicated (i.e., epilogue vectorization)
+ /// and the resume values can come from an additional bypass block, the \p
+ /// AdditionalBypass pair provides information about the bypass block and the
+ /// end value on the edge from bypass to this loop.
PHINode *createInductionResumeValue(
- PHINode *OrigPhi, const InductionDescriptor &ID,
+ PHINode *OrigPhi, const InductionDescriptor &ID, Value *Step,
ArrayRef<BasicBlock *> BypassBlocks,
std::pair<BasicBlock *, Value *> AdditionalBypass = {nullptr, nullptr});
+ /// Returns the original loop trip count.
+ Value *getTripCount() const { return TripCount; }
+
+ /// Used to set the trip count after ILV's construction and after the
+ /// preheader block has been executed. Note that this always holds the trip
+ /// count of the original loop for both main loop and epilogue vectorization.
+ void setTripCount(Value *TC) { TripCount = TC; }
+
protected:
friend class LoopVectorizationPlanner;
@@ -560,7 +617,7 @@ protected:
void fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II,
Value *VectorTripCount, Value *EndValue,
BasicBlock *MiddleBlock, BasicBlock *VectorHeader,
- VPlan &Plan);
+ VPlan &Plan, VPTransformState &State);
/// Handle all cross-iteration phis in the header.
void fixCrossIterationPHIs(VPTransformState &State);
@@ -573,10 +630,6 @@ protected:
/// Create code for the loop exit value of the reduction.
void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State);
- /// Clear NSW/NUW flags from reduction instructions if necessary.
- void clearReductionWrapFlags(VPReductionPHIRecipe *PhiR,
- VPTransformState &State);
-
/// Iteratively sink the scalarized operands of a predicated instruction into
/// the block that was created for it.
void sinkScalarOperands(Instruction *PredInst);
@@ -585,9 +638,6 @@ protected:
/// represented as.
void truncateToMinimalBitwidths(VPTransformState &State);
- /// Returns (and creates if needed) the original loop trip count.
- Value *getOrCreateTripCount(BasicBlock *InsertBlock);
-
/// Returns (and creates if needed) the trip count of the widened loop.
Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock);
@@ -621,6 +671,7 @@ protected:
/// block, the \p AdditionalBypass pair provides information about the bypass
/// block and the end value on the edge from bypass to this loop.
void createInductionResumeValues(
+ const SCEV2ValueTy &ExpandedSCEVs,
std::pair<BasicBlock *, Value *> AdditionalBypass = {nullptr, nullptr});
/// Complete the loop skeleton by adding debug MDs, creating appropriate
@@ -758,9 +809,6 @@ public:
ElementCount::getFixed(1),
ElementCount::getFixed(1), UnrollFactor, LVL, CM,
BFI, PSI, Check) {}
-
-private:
- Value *getBroadcastInstrs(Value *V) override;
};
/// Encapsulate information regarding vectorization of a loop and its epilogue.
@@ -810,15 +858,16 @@ public:
// Override this function to handle the more complex control flow around the
// three loops.
- std::pair<BasicBlock *, Value *> createVectorizedLoopSkeleton() final {
- return createEpilogueVectorizedLoopSkeleton();
+ std::pair<BasicBlock *, Value *> createVectorizedLoopSkeleton(
+ const SCEV2ValueTy &ExpandedSCEVs) final {
+ return createEpilogueVectorizedLoopSkeleton(ExpandedSCEVs);
}
/// The interface for creating a vectorized skeleton using one of two
/// different strategies, each corresponding to one execution of the vplan
/// as described above.
virtual std::pair<BasicBlock *, Value *>
- createEpilogueVectorizedLoopSkeleton() = 0;
+ createEpilogueVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs) = 0;
/// Holds and updates state information required to vectorize the main loop
/// and its epilogue in two separate passes. This setup helps us avoid
@@ -846,7 +895,8 @@ public:
EPI, LVL, CM, BFI, PSI, Check) {}
/// Implements the interface for creating a vectorized skeleton using the
/// *main loop* strategy (i.e. the first pass of VPlan execution).
- std::pair<BasicBlock *, Value *> createEpilogueVectorizedLoopSkeleton() final;
+ std::pair<BasicBlock *, Value *>
+ createEpilogueVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs) final;
protected:
/// Emits an iteration count bypass check once for the main loop (when \p
@@ -876,7 +926,8 @@ public:
}
/// Implements the interface for creating a vectorized skeleton using the
/// *epilogue loop* strategy (i.e. the second pass of VPlan execution).
- std::pair<BasicBlock *, Value *> createEpilogueVectorizedLoopSkeleton() final;
+ std::pair<BasicBlock *, Value *>
+ createEpilogueVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs) final;
protected:
/// Emits an iteration count bypass check after the main vector loop has
@@ -953,35 +1004,21 @@ namespace llvm {
Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF,
int64_t Step) {
assert(Ty->isIntegerTy() && "Expected an integer step");
- Constant *StepVal = ConstantInt::get(Ty, Step * VF.getKnownMinValue());
- return VF.isScalable() ? B.CreateVScale(StepVal) : StepVal;
+ return B.CreateElementCount(Ty, VF.multiplyCoefficientBy(Step));
}
/// Return the runtime value for VF.
Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) {
- Constant *EC = ConstantInt::get(Ty, VF.getKnownMinValue());
- return VF.isScalable() ? B.CreateVScale(EC) : EC;
+ return B.CreateElementCount(Ty, VF);
}
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE) {
+const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
+ Loop *OrigLoop) {
const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
ScalarEvolution &SE = *PSE.getSE();
-
- // The exit count might have the type of i64 while the phi is i32. This can
- // happen if we have an induction variable that is sign extended before the
- // compare. The only way that we get a backedge taken count is that the
- // induction variable was signed and as such will not overflow. In such a case
- // truncation is legal.
- if (SE.getTypeSizeInBits(BackedgeTakenCount->getType()) >
- IdxTy->getPrimitiveSizeInBits())
- BackedgeTakenCount = SE.getTruncateOrNoop(BackedgeTakenCount, IdxTy);
- BackedgeTakenCount = SE.getNoopOrZeroExtend(BackedgeTakenCount, IdxTy);
-
- // Get the total trip count from the count by adding 1.
- return SE.getAddExpr(BackedgeTakenCount,
- SE.getOne(BackedgeTakenCount->getType()));
+ return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
}
static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy,
@@ -1062,11 +1099,17 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes(
continue;
// This recipe contributes to the address computation of a widen
- // load/store. Collect recipe if its underlying instruction has
- // poison-generating flags.
- Instruction *Instr = CurRec->getUnderlyingInstr();
- if (Instr && Instr->hasPoisonGeneratingFlags())
- State.MayGeneratePoisonRecipes.insert(CurRec);
+ // load/store. If the underlying instruction has poison-generating flags,
+ // drop them directly.
+ if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) {
+ RecWithFlags->dropPoisonGeneratingFlags();
+ } else {
+ Instruction *Instr = CurRec->getUnderlyingInstr();
+ (void)Instr;
+ assert((!Instr || !Instr->hasPoisonGeneratingFlags()) &&
+ "found instruction with poison generating flags not covered by "
+ "VPRecipeWithIRFlags");
+ }
// Add new definitions to the worklist.
for (VPValue *operand : CurRec->operands())
@@ -1143,15 +1186,7 @@ enum ScalarEpilogueLowering {
CM_ScalarEpilogueNotAllowedUsePredicate
};
-/// ElementCountComparator creates a total ordering for ElementCount
-/// for the purposes of using it in a set structure.
-struct ElementCountComparator {
- bool operator()(const ElementCount &LHS, const ElementCount &RHS) const {
- return std::make_tuple(LHS.isScalable(), LHS.getKnownMinValue()) <
- std::make_tuple(RHS.isScalable(), RHS.getKnownMinValue());
- }
-};
-using ElementCountSet = SmallSet<ElementCount, 16, ElementCountComparator>;
+using InstructionVFPair = std::pair<Instruction *, ElementCount>;
/// LoopVectorizationCostModel - estimates the expected speedups due to
/// vectorization.
@@ -1184,17 +1219,6 @@ public:
/// otherwise.
bool runtimeChecksRequired();
- /// \return The most profitable vectorization factor and the cost of that VF.
- /// This method checks every VF in \p CandidateVFs. If UserVF is not ZERO
- /// then this vectorization factor will be selected if vectorization is
- /// possible.
- VectorizationFactor
- selectVectorizationFactor(const ElementCountSet &CandidateVFs);
-
- VectorizationFactor
- selectEpilogueVectorizationFactor(const ElementCount MaxVF,
- const LoopVectorizationPlanner &LVP);
-
/// Setup cost-based decisions for user vectorization factor.
/// \return true if the UserVF is a feasible VF to be chosen.
bool selectUserVectorizationFactor(ElementCount UserVF) {
@@ -1278,11 +1302,17 @@ public:
auto Scalars = InstsToScalarize.find(VF);
assert(Scalars != InstsToScalarize.end() &&
"VF not yet analyzed for scalarization profitability");
- return Scalars->second.find(I) != Scalars->second.end();
+ return Scalars->second.contains(I);
}
/// Returns true if \p I is known to be uniform after vectorization.
bool isUniformAfterVectorization(Instruction *I, ElementCount VF) const {
+ // Pseudo probe needs to be duplicated for each unrolled iteration and
+ // A pseudo probe needs to be duplicated for each unrolled iteration and
+ // vector lane so that the profiled loop trip count can be accumulated
+ // accurately instead of being undercounted.
+ return false;
+
if (VF.isScalar())
return true;
@@ -1316,7 +1346,7 @@ public:
/// \returns True if instruction \p I can be truncated to a smaller bitwidth
/// for vectorization factor \p VF.
bool canTruncateToMinimalBitwidth(Instruction *I, ElementCount VF) const {
- return VF.isVector() && MinBWs.find(I) != MinBWs.end() &&
+ return VF.isVector() && MinBWs.contains(I) &&
!isProfitableToScalarize(I, VF) &&
!isScalarAfterVectorization(I, VF);
}
@@ -1379,7 +1409,7 @@ public:
InstructionCost getWideningCost(Instruction *I, ElementCount VF) {
assert(VF.isVector() && "Expected VF >=2");
std::pair<Instruction *, ElementCount> InstOnVF = std::make_pair(I, VF);
- assert(WideningDecisions.find(InstOnVF) != WideningDecisions.end() &&
+ assert(WideningDecisions.contains(InstOnVF) &&
"The cost is not calculated");
return WideningDecisions[InstOnVF].second;
}
@@ -1419,7 +1449,7 @@ public:
/// that may be vectorized as interleave, gather-scatter or scalarized.
void collectUniformsAndScalars(ElementCount VF) {
// Do the analysis once.
- if (VF.isScalar() || Uniforms.find(VF) != Uniforms.end())
+ if (VF.isScalar() || Uniforms.contains(VF))
return;
setCostBasedWideningDecision(VF);
collectLoopUniforms(VF);
@@ -1442,8 +1472,7 @@ public:
/// Returns true if the target machine can represent \p V as a masked gather
/// or scatter operation.
- bool isLegalGatherOrScatter(Value *V,
- ElementCount VF = ElementCount::getFixed(1)) {
+ bool isLegalGatherOrScatter(Value *V, ElementCount VF) {
bool LI = isa<LoadInst>(V);
bool SI = isa<StoreInst>(V);
if (!LI && !SI)
@@ -1522,14 +1551,29 @@ public:
/// Returns true if we're required to use a scalar epilogue for at least
/// the final iteration of the original loop.
- bool requiresScalarEpilogue(ElementCount VF) const {
+ bool requiresScalarEpilogue(bool IsVectorizing) const {
if (!isScalarEpilogueAllowed())
return false;
// If we might exit from anywhere but the latch, must run the exiting
// iteration in scalar form.
if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch())
return true;
- return VF.isVector() && InterleaveInfo.requiresScalarEpilogue();
+ return IsVectorizing && InterleaveInfo.requiresScalarEpilogue();
+ }
+
+ /// Returns true if we're required to use a scalar epilogue for at least
+ /// the final iteration of the original loop for all VFs in \p Range.
+ /// A scalar epilogue must either be required for all VFs in \p Range or for
+ /// none.
+ bool requiresScalarEpilogue(VFRange Range) const {
+ auto RequiresScalarEpilogue = [this](ElementCount VF) {
+ return requiresScalarEpilogue(VF.isVector());
+ };
+ bool IsRequired = all_of(Range, RequiresScalarEpilogue);
+ assert(
+ (IsRequired || none_of(Range, RequiresScalarEpilogue)) &&
+ "all VFs in range must agree on whether a scalar epilogue is required");
+ return IsRequired;
}
/// Returns true if a scalar epilogue is not allowed due to optsize or a
@@ -1538,14 +1582,21 @@ public:
return ScalarEpilogueStatus == CM_ScalarEpilogueAllowed;
}
- /// Returns true if all loop blocks should be masked to fold tail loop.
- bool foldTailByMasking() const { return FoldTailByMasking; }
+ /// Returns the TailFoldingStyle that is best for the current loop.
+ TailFoldingStyle
+ getTailFoldingStyle(bool IVUpdateMayOverflow = true) const {
+ if (!CanFoldTailByMasking)
+ return TailFoldingStyle::None;
+
+ if (ForceTailFoldingStyle.getNumOccurrences())
+ return ForceTailFoldingStyle;
+
+ return TTI.getPreferredTailFoldingStyle(IVUpdateMayOverflow);
+ }
- /// Returns true if were tail-folding and want to use the active lane mask
- /// for vector loop control flow.
- bool useActiveLaneMaskForControlFlow() const {
- return FoldTailByMasking &&
- TTI.emitGetActiveLaneMask() == PredicationStyle::DataAndControlFlow;
+ /// Returns true if all loop blocks should be masked to fold tail loop.
+ bool foldTailByMasking() const {
+ return getTailFoldingStyle() != TailFoldingStyle::None;
}
/// Returns true if the instructions in this block requires predication
@@ -1582,12 +1633,8 @@ public:
/// scalarized -
/// i.e. either vector version isn't available, or is too expensive.
InstructionCost getVectorCallCost(CallInst *CI, ElementCount VF,
- bool &NeedToScalarize) const;
-
- /// Returns true if the per-lane cost of VectorizationFactor A is lower than
- /// that of B.
- bool isMoreProfitable(const VectorizationFactor &A,
- const VectorizationFactor &B) const;
+ Function **Variant,
+ bool *NeedsMask = nullptr) const;
/// Invalidates decisions already taken by the cost model.
void invalidateCostModelingDecisions() {
@@ -1596,10 +1643,29 @@ public:
Scalars.clear();
}
- /// Convenience function that returns the value of vscale_range iff
- /// vscale_range.min == vscale_range.max or otherwise returns the value
- /// returned by the corresponding TLI method.
- std::optional<unsigned> getVScaleForTuning() const;
+ /// The vectorization cost is a combination of the cost itself and a boolean
+ /// indicating whether any of the contributing operations will actually
+ /// operate on vector values after type legalization in the backend. If this
+ /// latter value is false, then all operations will be scalarized (i.e. no
+ /// vectorization has actually taken place).
+ using VectorizationCostTy = std::pair<InstructionCost, bool>;
+
+ /// Returns the expected execution cost. The unit of the cost does
+ /// not matter because we use the 'cost' units to compare different
+ /// vector widths. The cost that is returned is *not* normalized by
+ /// the factor width. If \p Invalid is not nullptr, this function
+ /// will add a pair(Instruction*, ElementCount) to \p Invalid for
+ /// each instruction that has an Invalid cost for the given VF.
+ VectorizationCostTy
+ expectedCost(ElementCount VF,
+ SmallVectorImpl<InstructionVFPair> *Invalid = nullptr);
+
+ bool hasPredStores() const { return NumPredStores > 0; }
+
+ /// Returns true if epilogue vectorization is considered profitable, and
+ /// false otherwise.
+ /// \p VF is the vectorization factor chosen for the original loop.
+ bool isEpilogueVectorizationProfitable(const ElementCount VF) const;
private:
unsigned NumPredStores = 0;
@@ -1626,24 +1692,6 @@ private:
/// of elements.
ElementCount getMaxLegalScalableVF(unsigned MaxSafeElements);
- /// The vectorization cost is a combination of the cost itself and a boolean
- /// indicating whether any of the contributing operations will actually
- /// operate on vector values after type legalization in the backend. If this
- /// latter value is false, then all operations will be scalarized (i.e. no
- /// vectorization has actually taken place).
- using VectorizationCostTy = std::pair<InstructionCost, bool>;
-
- /// Returns the expected execution cost. The unit of the cost does
- /// not matter because we use the 'cost' units to compare different
- /// vector widths. The cost that is returned is *not* normalized by
- /// the factor width. If \p Invalid is not nullptr, this function
- /// will add a pair(Instruction*, ElementCount) to \p Invalid for
- /// each instruction that has an Invalid cost for the given VF.
- using InstructionVFPair = std::pair<Instruction *, ElementCount>;
- VectorizationCostTy
- expectedCost(ElementCount VF,
- SmallVectorImpl<InstructionVFPair> *Invalid = nullptr);
-
/// Returns the execution time cost of an instruction for a given vector
/// width. Vector width of one means scalar.
VectorizationCostTy getInstructionCost(Instruction *I, ElementCount VF);
@@ -1715,7 +1763,7 @@ private:
ScalarEpilogueLowering ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
/// All blocks of the loop are to be masked to fold the tail of the scalar iterations.
- bool FoldTailByMasking = false;
+ bool CanFoldTailByMasking = false;
/// A map holding scalar costs for different vectorization factors. The
/// presence of a cost for an instruction in the mapping indicates that the
@@ -1796,8 +1844,7 @@ private:
// the scalars are collected. That should be a safe assumption in most
// cases, because we check if the operands have vectorizable types
// beforehand in LoopVectorizationLegality.
- return Scalars.find(VF) == Scalars.end() ||
- !isScalarAfterVectorization(I, VF);
+ return !Scalars.contains(VF) || !isScalarAfterVectorization(I, VF);
};
/// Returns a range containing only operands needing to be extracted.
@@ -1807,16 +1854,6 @@ private:
Ops, [this, VF](Value *V) { return this->needsExtract(V, VF); }));
}
- /// Determines if we have the infrastructure to vectorize loop \p L and its
- /// epilogue, assuming the main loop is vectorized by \p VF.
- bool isCandidateForEpilogueVectorization(const Loop &L,
- const ElementCount VF) const;
-
- /// Returns true if epilogue vectorization is considered profitable, and
- /// false otherwise.
- /// \p VF is the vectorization factor chosen for the original loop.
- bool isEpilogueVectorizationProfitable(const ElementCount VF) const;
-
public:
/// The loop that we evaluate.
Loop *TheLoop;
@@ -1862,9 +1899,6 @@ public:
/// All element types found in the loop.
SmallPtrSet<Type *, 16> ElementTypesInLoop;
-
- /// Profitable vector factors.
- SmallVector<VectorizationFactor, 8> ProfitableVFs;
};
} // end namespace llvm
@@ -2135,6 +2169,17 @@ public:
};
} // namespace
+static bool useActiveLaneMask(TailFoldingStyle Style) {
+ return Style == TailFoldingStyle::Data ||
+ Style == TailFoldingStyle::DataAndControlFlow ||
+ Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
+}
+
+static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
+ return Style == TailFoldingStyle::DataAndControlFlow ||
+ Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
+}
+
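For illustration only (not part of the patch): a standalone analogue of the two predicates above, with a local enum standing in for llvm::TailFoldingStyle. Every style other than None and DataWithoutLaneMask uses the active.lane.mask intrinsic, and only the two DataAndControlFlow variants also drive the loop's control flow with it.

#include <cassert>

enum class Style {
  None,
  Data,
  DataWithoutLaneMask,
  DataAndControlFlow,
  DataAndControlFlowWithoutRuntimeCheck
};

static bool usesLaneMask(Style S) {
  return S == Style::Data || S == Style::DataAndControlFlow ||
         S == Style::DataAndControlFlowWithoutRuntimeCheck;
}

static bool usesLaneMaskForControlFlow(Style S) {
  return S == Style::DataAndControlFlow ||
         S == Style::DataAndControlFlowWithoutRuntimeCheck;
}

int main() {
  assert(usesLaneMask(Style::Data) && !usesLaneMaskForControlFlow(Style::Data));
  assert(!usesLaneMask(Style::DataWithoutLaneMask));
  assert(usesLaneMaskForControlFlow(Style::DataAndControlFlowWithoutRuntimeCheck));
  return 0;
}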
// Return true if \p OuterLp is an outer loop annotated with hints for explicit
// vectorization. The loop needs to be annotated with #pragma omp simd
// simdlen(#) or #pragma clang loop vectorize(enable) vectorize_width(#). If the
@@ -2202,97 +2247,11 @@ static void collectSupportedLoops(Loop &L, LoopInfo *LI,
collectSupportedLoops(*InnerL, LI, ORE, V);
}
-namespace {
-
-/// The LoopVectorize Pass.
-struct LoopVectorize : public FunctionPass {
- /// Pass identification, replacement for typeid
- static char ID;
-
- LoopVectorizePass Impl;
-
- explicit LoopVectorize(bool InterleaveOnlyWhenForced = false,
- bool VectorizeOnlyWhenForced = false)
- : FunctionPass(ID),
- Impl({InterleaveOnlyWhenForced, VectorizeOnlyWhenForced}) {
- initializeLoopVectorizePass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI();
- auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- auto *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
- auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs();
- auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits();
- auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
-
- return Impl
- .runImpl(F, *SE, *LI, *TTI, *DT, *BFI, TLI, *DB, *AC, LAIs, *ORE, PSI)
- .MadeAnyChange;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<BlockFrequencyInfoWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<LoopAccessLegacyAnalysis>();
- AU.addRequired<DemandedBitsWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<InjectTLIMappingsLegacy>();
-
- // We currently do not preserve loopinfo/dominator analyses with outer loop
- // vectorization. Until this is addressed, mark these analyses as preserved
- // only for non-VPlan-native path.
- // TODO: Preserve Loop and Dominator analyses for VPlan-native path.
- if (!EnableVPlanNativePath) {
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- }
-
- AU.addPreserved<BasicAAWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addRequired<ProfileSummaryInfoWrapperPass>();
- }
-};
-
-} // end anonymous namespace
-
//===----------------------------------------------------------------------===//
// Implementation of LoopVectorizationLegality, InnerLoopVectorizer and
// LoopVectorizationCostModel and LoopVectorizationPlanner.
//===----------------------------------------------------------------------===//
-Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) {
- // We need to place the broadcast of invariant variables outside the loop,
- // but only if it's proven safe to do so. Else, broadcast will be inside
- // vector loop body.
- Instruction *Instr = dyn_cast<Instruction>(V);
- bool SafeToHoist = OrigLoop->isLoopInvariant(V) &&
- (!Instr ||
- DT->dominates(Instr->getParent(), LoopVectorPreHeader));
- // Place the code for broadcasting invariant variables in the new preheader.
- IRBuilder<>::InsertPointGuard Guard(Builder);
- if (SafeToHoist)
- Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator());
-
- // Broadcast the scalar into all locations in the vector.
- Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast");
-
- return Shuf;
-}
-
/// This function adds
/// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...)
/// to each vector element of Val. The sequence starts at StartIdx.
@@ -2435,21 +2394,6 @@ static void buildScalarSteps(Value *ScalarIV, Value *Step,
}
}
-// Generate code for the induction step. Note that induction steps are
-// required to be loop-invariant
-static Value *CreateStepValue(const SCEV *Step, ScalarEvolution &SE,
- Instruction *InsertBefore,
- Loop *OrigLoop = nullptr) {
- const DataLayout &DL = SE.getDataLayout();
- assert((!OrigLoop || SE.isLoopInvariant(Step, OrigLoop)) &&
- "Induction step should be loop invariant");
- if (auto *E = dyn_cast<SCEVUnknown>(Step))
- return E->getValue();
-
- SCEVExpander Exp(SE, DL, "induction");
- return Exp.expandCodeFor(Step, Step->getType(), InsertBefore);
-}
-
/// Compute the transformed value of Index at offset StartValue using step
/// StepValue.
/// For integer induction, returns StartValue + Index * StepValue.
@@ -2514,9 +2458,7 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
return CreateAdd(StartValue, Offset);
}
case InductionDescriptor::IK_PtrInduction: {
- assert(isa<Constant>(Step) &&
- "Expected constant step for pointer induction");
- return B.CreateGEP(ID.getElementType(), StartValue, CreateMul(Index, Step));
+ return B.CreateGEP(B.getInt8Ty(), StartValue, CreateMul(Index, Step));
}
case InductionDescriptor::IK_FpInduction: {
assert(!isa<VectorType>(Index->getType()) &&
@@ -2538,6 +2480,50 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
llvm_unreachable("invalid enum");
}
+std::optional<unsigned> getMaxVScale(const Function &F,
+ const TargetTransformInfo &TTI) {
+ if (std::optional<unsigned> MaxVScale = TTI.getMaxVScale())
+ return MaxVScale;
+
+ if (F.hasFnAttribute(Attribute::VScaleRange))
+ return F.getFnAttribute(Attribute::VScaleRange).getVScaleRangeMax();
+
+ return std::nullopt;
+}
+
+/// For the given VF and UF and maximum trip count computed for the loop,
+/// return true if the induction variable is known not to overflow in the
+/// vectorized loop. In that case the runtime overflow check always evaluates
+/// to false and can be removed.
+static bool isIndvarOverflowCheckKnownFalse(
+ const LoopVectorizationCostModel *Cost,
+ ElementCount VF, std::optional<unsigned> UF = std::nullopt) {
+ // Always be conservative if we don't know the exact unroll factor.
+ unsigned MaxUF = UF ? *UF : Cost->TTI.getMaxInterleaveFactor(VF);
+
+ Type *IdxTy = Cost->Legal->getWidestInductionType();
+ APInt MaxUIntTripCount = cast<IntegerType>(IdxTy)->getMask();
+
+ // The runtime overflow check is known to be false iff the (max) trip-count
+ // is known and (max) trip-count + (VF * UF) does not overflow in the type of
+ // the vector loop induction variable.
+ if (unsigned TC =
+ Cost->PSE.getSE()->getSmallConstantMaxTripCount(Cost->TheLoop)) {
+ uint64_t MaxVF = VF.getKnownMinValue();
+ if (VF.isScalable()) {
+ std::optional<unsigned> MaxVScale =
+ getMaxVScale(*Cost->TheFunction, Cost->TTI);
+ if (!MaxVScale)
+ return false;
+ MaxVF *= *MaxVScale;
+ }
+
+ return (MaxUIntTripCount - TC).ugt(MaxVF * MaxUF);
+ }
+
+ return false;
+}
+
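The condition in isIndvarOverflowCheckKnownFalse above can be exercised with small concrete numbers; below is a standalone sketch using an 8-bit induction type as a stand-in, with trip counts invented for illustration. The vector induction variable, which steps by VF * UF, cannot wrap as long as the known maximum trip count leaves more than VF * UF of headroom below the type's maximum value.

// (MaxUIntTripCount - TC) > VF * UF with an 8-bit index type (MaxUInt = 255):
// TC = 200 leaves 55 > 16 of headroom, so the check is redundant; TC = 250
// leaves only 5, so the runtime check must stay.
#include <cassert>
#include <cstdint>

static bool overflowCheckKnownFalse(uint64_t MaxUIntTripCount, uint64_t TC,
                                    uint64_t VFTimesUF) {
  return (MaxUIntTripCount - TC) > VFTimesUF;
}

int main() {
  const uint64_t MaxUInt8 = 255; // mask of an i8 induction type
  assert(overflowCheckKnownFalse(MaxUInt8, /*TC=*/200, /*VFxUF=*/16));
  assert(!overflowCheckKnownFalse(MaxUInt8, /*TC=*/250, /*VFxUF=*/16));
  return 0;
}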
void InnerLoopVectorizer::packScalarIntoVectorValue(VPValue *Def,
const VPIteration &Instance,
VPTransformState &State) {
@@ -2591,14 +2577,13 @@ static bool useMaskedInterleavedAccesses(const TargetTransformInfo &TTI) {
void InnerLoopVectorizer::vectorizeInterleaveGroup(
const InterleaveGroup<Instruction> *Group, ArrayRef<VPValue *> VPDefs,
VPTransformState &State, VPValue *Addr, ArrayRef<VPValue *> StoredValues,
- VPValue *BlockInMask) {
+ VPValue *BlockInMask, bool NeedsMaskForGaps) {
Instruction *Instr = Group->getInsertPos();
const DataLayout &DL = Instr->getModule()->getDataLayout();
// Prepare for the vector type of the interleaved load/store.
Type *ScalarTy = getLoadStoreType(Instr);
unsigned InterleaveFactor = Group->getFactor();
- assert(!VF.isScalable() && "scalable vectors not yet supported.");
auto *VecTy = VectorType::get(ScalarTy, VF * InterleaveFactor);
// Prepare for the new pointers.
@@ -2609,14 +2594,21 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
assert((!BlockInMask || !Group->isReverse()) &&
"Reversed masked interleave-group not supported.");
+ Value *Idx;
// If the group is reverse, adjust the index to refer to the last vector lane
// instead of the first. We adjust the index from the first vector lane,
// rather than directly getting the pointer for lane VF - 1, because the
// pointer operand of the interleaved access is supposed to be uniform. For
// uniform instructions, we're only required to generate a value for the
// first vector lane in each unroll iteration.
- if (Group->isReverse())
- Index += (VF.getKnownMinValue() - 1) * Group->getFactor();
+ if (Group->isReverse()) {
+ Value *RuntimeVF = getRuntimeVF(Builder, Builder.getInt32Ty(), VF);
+ Idx = Builder.CreateSub(RuntimeVF, Builder.getInt32(1));
+ Idx = Builder.CreateMul(Idx, Builder.getInt32(Group->getFactor()));
+ Idx = Builder.CreateAdd(Idx, Builder.getInt32(Index));
+ Idx = Builder.CreateNeg(Idx);
+ } else
+ Idx = Builder.getInt32(-Index);
for (unsigned Part = 0; Part < UF; Part++) {
Value *AddrPart = State.get(Addr, VPIteration(Part, 0));
@@ -2637,8 +2629,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
bool InBounds = false;
if (auto *gep = dyn_cast<GetElementPtrInst>(AddrPart->stripPointerCasts()))
InBounds = gep->isInBounds();
- AddrPart = Builder.CreateGEP(ScalarTy, AddrPart, Builder.getInt32(-Index));
- cast<GetElementPtrInst>(AddrPart)->setIsInBounds(InBounds);
+ AddrPart = Builder.CreateGEP(ScalarTy, AddrPart, Idx, "", InBounds);
// Cast to the vector pointer type.
unsigned AddressSpace = AddrPart->getType()->getPointerAddressSpace();
@@ -2649,14 +2640,43 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
State.setDebugLocFromInst(Instr);
Value *PoisonVec = PoisonValue::get(VecTy);
- Value *MaskForGaps = nullptr;
- if (Group->requiresScalarEpilogue() && !Cost->isScalarEpilogueAllowed()) {
- MaskForGaps = createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group);
- assert(MaskForGaps && "Mask for Gaps is required but it is null");
- }
+ auto CreateGroupMask = [this, &BlockInMask, &State, &InterleaveFactor](
+ unsigned Part, Value *MaskForGaps) -> Value * {
+ if (VF.isScalable()) {
+ assert(!MaskForGaps && "Interleaved groups with gaps are not supported.");
+ assert(InterleaveFactor == 2 &&
+ "Unsupported deinterleave factor for scalable vectors");
+ auto *BlockInMaskPart = State.get(BlockInMask, Part);
+ SmallVector<Value *, 2> Ops = {BlockInMaskPart, BlockInMaskPart};
+ auto *MaskTy =
+ VectorType::get(Builder.getInt1Ty(), VF.getKnownMinValue() * 2, true);
+ return Builder.CreateIntrinsic(
+ MaskTy, Intrinsic::experimental_vector_interleave2, Ops,
+ /*FMFSource=*/nullptr, "interleaved.mask");
+ }
+
+ if (!BlockInMask)
+ return MaskForGaps;
+
+ Value *BlockInMaskPart = State.get(BlockInMask, Part);
+ Value *ShuffledMask = Builder.CreateShuffleVector(
+ BlockInMaskPart,
+ createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()),
+ "interleaved.mask");
+ return MaskForGaps ? Builder.CreateBinOp(Instruction::And, ShuffledMask,
+ MaskForGaps)
+ : ShuffledMask;
+ };
// Vectorize the interleaved load group.
if (isa<LoadInst>(Instr)) {
+ Value *MaskForGaps = nullptr;
+ if (NeedsMaskForGaps) {
+ MaskForGaps =
+ createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group);
+ assert(MaskForGaps && "Mask for Gaps is required but it is null");
+ }
+
// For each unroll part, create a wide load for the group.
SmallVector<Value *, 2> NewLoads;
for (unsigned Part = 0; Part < UF; Part++) {
@@ -2664,18 +2684,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
if (BlockInMask || MaskForGaps) {
assert(useMaskedInterleavedAccesses(*TTI) &&
"masked interleaved groups are not allowed.");
- Value *GroupMask = MaskForGaps;
- if (BlockInMask) {
- Value *BlockInMaskPart = State.get(BlockInMask, Part);
- Value *ShuffledMask = Builder.CreateShuffleVector(
- BlockInMaskPart,
- createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()),
- "interleaved.mask");
- GroupMask = MaskForGaps
- ? Builder.CreateBinOp(Instruction::And, ShuffledMask,
- MaskForGaps)
- : ShuffledMask;
- }
+ Value *GroupMask = CreateGroupMask(Part, MaskForGaps);
NewLoad =
Builder.CreateMaskedLoad(VecTy, AddrParts[Part], Group->getAlign(),
GroupMask, PoisonVec, "wide.masked.vec");
@@ -2687,6 +2696,41 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
NewLoads.push_back(NewLoad);
}
+ if (VecTy->isScalableTy()) {
+ assert(InterleaveFactor == 2 &&
+ "Unsupported deinterleave factor for scalable vectors");
+
+ for (unsigned Part = 0; Part < UF; ++Part) {
+ // Scalable vectors cannot use arbitrary shufflevectors (only splats),
+ // so must use intrinsics to deinterleave.
+ Value *DI = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vector_deinterleave2, VecTy, NewLoads[Part],
+ /*FMFSource=*/nullptr, "strided.vec");
+ unsigned J = 0;
+ for (unsigned I = 0; I < InterleaveFactor; ++I) {
+ Instruction *Member = Group->getMember(I);
+
+ if (!Member)
+ continue;
+
+ Value *StridedVec = Builder.CreateExtractValue(DI, I);
+ // If this member has a different type, cast the result to that type.
+ if (Member->getType() != ScalarTy) {
+ VectorType *OtherVTy = VectorType::get(Member->getType(), VF);
+ StridedVec = createBitOrPointerCast(StridedVec, OtherVTy, DL);
+ }
+
+ if (Group->isReverse())
+ StridedVec = Builder.CreateVectorReverse(StridedVec, "reverse");
+
+ State.set(VPDefs[J], StridedVec, Part);
+ ++J;
+ }
+ }
+
+ return;
+ }
+
// For each member in the group, shuffle out the appropriate data from the
// wide loads.
unsigned J = 0;
@@ -2724,7 +2768,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
auto *SubVT = VectorType::get(ScalarTy, VF);
// Vectorize the interleaved store group.
- MaskForGaps = createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group);
+ Value *MaskForGaps =
+ createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group);
assert((!MaskForGaps || useMaskedInterleavedAccesses(*TTI)) &&
"masked interleaved groups are not allowed.");
assert((!MaskForGaps || !VF.isScalable()) &&
@@ -2759,27 +2804,11 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
StoredVecs.push_back(StoredVec);
}
- // Concatenate all vectors into a wide vector.
- Value *WideVec = concatenateVectors(Builder, StoredVecs);
-
- // Interleave the elements in the wide vector.
- Value *IVec = Builder.CreateShuffleVector(
- WideVec, createInterleaveMask(VF.getKnownMinValue(), InterleaveFactor),
- "interleaved.vec");
-
+ // Interleave all the smaller vectors into one wider vector.
+ Value *IVec = interleaveVectors(Builder, StoredVecs, "interleaved.vec");
Instruction *NewStoreInstr;
if (BlockInMask || MaskForGaps) {
- Value *GroupMask = MaskForGaps;
- if (BlockInMask) {
- Value *BlockInMaskPart = State.get(BlockInMask, Part);
- Value *ShuffledMask = Builder.CreateShuffleVector(
- BlockInMaskPart,
- createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()),
- "interleaved.mask");
- GroupMask = MaskForGaps ? Builder.CreateBinOp(Instruction::And,
- ShuffledMask, MaskForGaps)
- : ShuffledMask;
- }
+ Value *GroupMask = CreateGroupMask(Part, MaskForGaps);
NewStoreInstr = Builder.CreateMaskedStore(IVec, AddrParts[Part],
Group->getAlign(), GroupMask);
} else
@@ -2793,7 +2822,6 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
VPReplicateRecipe *RepRecipe,
const VPIteration &Instance,
- bool IfPredicateInstr,
VPTransformState &State) {
assert(!Instr->getType()->isAggregateType() && "Can't handle vectors");
@@ -2810,14 +2838,7 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
if (!IsVoidRetTy)
Cloned->setName(Instr->getName() + ".cloned");
- // If the scalarized instruction contributes to the address computation of a
- // widen masked load/store which was in a basic block that needed predication
- // and is not predicated after vectorization, we can't propagate
- // poison-generating flags (nuw/nsw, exact, inbounds, etc.). The scalarized
- // instruction could feed a poison value to the base address of the widen
- // load/store.
- if (State.MayGeneratePoisonRecipes.contains(RepRecipe))
- Cloned->dropPoisonGeneratingFlags();
+ RepRecipe->setFlags(Cloned);
if (Instr->getDebugLoc())
State.setDebugLocFromInst(Instr);
@@ -2843,45 +2864,17 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
AC->registerAssumption(II);
// End if-block.
+ bool IfPredicateInstr = RepRecipe->getParent()->getParent()->isReplicator();
if (IfPredicateInstr)
PredicatedInstructions.push_back(Cloned);
}
-Value *InnerLoopVectorizer::getOrCreateTripCount(BasicBlock *InsertBlock) {
- if (TripCount)
- return TripCount;
-
- assert(InsertBlock);
- IRBuilder<> Builder(InsertBlock->getTerminator());
- // Find the loop boundaries.
- Type *IdxTy = Legal->getWidestInductionType();
- assert(IdxTy && "No type for induction");
- const SCEV *ExitCount = createTripCountSCEV(IdxTy, PSE);
-
- const DataLayout &DL = InsertBlock->getModule()->getDataLayout();
-
- // Expand the trip count and place the new instructions in the preheader.
- // Notice that the pre-header does not change, only the loop body.
- SCEVExpander Exp(*PSE.getSE(), DL, "induction");
-
- // Count holds the overall loop count (N).
- TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(),
- InsertBlock->getTerminator());
-
- if (TripCount->getType()->isPointerTy())
- TripCount =
- CastInst::CreatePointerCast(TripCount, IdxTy, "exitcount.ptrcnt.to.int",
- InsertBlock->getTerminator());
-
- return TripCount;
-}
-
Value *
InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) {
if (VectorTripCount)
return VectorTripCount;
- Value *TC = getOrCreateTripCount(InsertBlock);
+ Value *TC = getTripCount();
IRBuilder<> Builder(InsertBlock->getTerminator());
Type *Ty = TC->getType();
@@ -2917,7 +2910,7 @@ InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) {
// the step does not evenly divide the trip count, no adjustment is necessary
// since there will already be scalar iterations. Note that the minimum
// iterations check ensures that N >= Step.
- if (Cost->requiresScalarEpilogue(VF)) {
+ if (Cost->requiresScalarEpilogue(VF.isVector())) {
auto *IsZero = Builder.CreateICmpEQ(R, ConstantInt::get(R->getType(), 0));
R = Builder.CreateSelect(IsZero, Step, R);
}
@@ -2930,10 +2923,10 @@ InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) {
Value *InnerLoopVectorizer::createBitOrPointerCast(Value *V, VectorType *DstVTy,
const DataLayout &DL) {
// Verify that V is a vector type with same number of elements as DstVTy.
- auto *DstFVTy = cast<FixedVectorType>(DstVTy);
- unsigned VF = DstFVTy->getNumElements();
- auto *SrcVecTy = cast<FixedVectorType>(V->getType());
- assert((VF == SrcVecTy->getNumElements()) && "Vector dimensions do not match");
+ auto *DstFVTy = cast<VectorType>(DstVTy);
+ auto VF = DstFVTy->getElementCount();
+ auto *SrcVecTy = cast<VectorType>(V->getType());
+ assert(VF == SrcVecTy->getElementCount() && "Vector dimensions do not match");
Type *SrcElemTy = SrcVecTy->getElementType();
Type *DstElemTy = DstFVTy->getElementType();
assert((DL.getTypeSizeInBits(SrcElemTy) == DL.getTypeSizeInBits(DstElemTy)) &&
@@ -2953,13 +2946,13 @@ Value *InnerLoopVectorizer::createBitOrPointerCast(Value *V, VectorType *DstVTy,
"Only one type should be a floating point type");
Type *IntTy =
IntegerType::getIntNTy(V->getContext(), DL.getTypeSizeInBits(SrcElemTy));
- auto *VecIntTy = FixedVectorType::get(IntTy, VF);
+ auto *VecIntTy = VectorType::get(IntTy, VF);
Value *CastVal = Builder.CreateBitOrPointerCast(V, VecIntTy);
return Builder.CreateBitOrPointerCast(CastVal, DstFVTy);
}
void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
- Value *Count = getOrCreateTripCount(LoopVectorPreHeader);
+ Value *Count = getTripCount();
// Reuse existing vector loop preheader for TC checks.
// Note that new preheader block is generated for vector loop.
BasicBlock *const TCCheckBlock = LoopVectorPreHeader;
@@ -2970,8 +2963,8 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
// vector trip count is zero. This check also covers the case where adding one
// to the backedge-taken count overflowed leading to an incorrect trip count
// of zero. In this case we will also jump to the scalar loop.
- auto P = Cost->requiresScalarEpilogue(VF) ? ICmpInst::ICMP_ULE
- : ICmpInst::ICMP_ULT;
+ auto P = Cost->requiresScalarEpilogue(VF.isVector()) ? ICmpInst::ICMP_ULE
+ : ICmpInst::ICMP_ULT;
// If tail is to be folded, vector loop takes care of all iterations.
Type *CountTy = Count->getType();
@@ -2989,10 +2982,13 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
Intrinsic::umax, MinProfTC, createStepForVF(Builder, CountTy, VF, UF));
};
- if (!Cost->foldTailByMasking())
+ TailFoldingStyle Style = Cost->getTailFoldingStyle();
+ if (Style == TailFoldingStyle::None)
CheckMinIters =
Builder.CreateICmp(P, Count, CreateStep(), "min.iters.check");
- else if (VF.isScalable()) {
+ else if (VF.isScalable() &&
+ !isIndvarOverflowCheckKnownFalse(Cost, VF, UF) &&
+ Style != TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) {
// vscale is not necessarily a power-of-2, which means we cannot guarantee
// an overflow to zero when updating induction variables and so an
// additional overflow check is required before entering the vector loop.
@@ -3017,7 +3013,7 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
// Update dominator for Bypass & LoopExit (if needed).
DT->changeImmediateDominator(Bypass, TCCheckBlock);
- if (!Cost->requiresScalarEpilogue(VF))
+ if (!Cost->requiresScalarEpilogue(VF.isVector()))
// If there is an epilogue which must run, there's no edge from the
// middle block to exit blocks and thus no need to update the immediate
// dominator of the exit blocks.
@@ -3044,7 +3040,7 @@ BasicBlock *InnerLoopVectorizer::emitSCEVChecks(BasicBlock *Bypass) {
// Update dominator only if this is first RT check.
if (LoopBypassBlocks.empty()) {
DT->changeImmediateDominator(Bypass, SCEVCheckBlock);
- if (!Cost->requiresScalarEpilogue(VF))
+ if (!Cost->requiresScalarEpilogue(VF.isVector()))
// If there is an epilogue which must run, there's no edge from the
// middle block to exit blocks and thus no need to update the immediate
// dominator of the exit blocks.
@@ -3097,7 +3093,7 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
LoopVectorPreHeader = OrigLoop->getLoopPreheader();
assert(LoopVectorPreHeader && "Invalid loop structure");
LoopExitBlock = OrigLoop->getUniqueExitBlock(); // may be nullptr
- assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF)) &&
+ assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
"multiple exit loop without required epilogue?");
LoopMiddleBlock =
@@ -3117,17 +3113,18 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
// branch from the middle block to the loop scalar preheader, and the
// exit block. completeLoopSkeleton will update the condition to use an
// iteration check, if required to decide whether to execute the remainder.
- BranchInst *BrInst = Cost->requiresScalarEpilogue(VF) ?
- BranchInst::Create(LoopScalarPreHeader) :
- BranchInst::Create(LoopExitBlock, LoopScalarPreHeader,
- Builder.getTrue());
+ BranchInst *BrInst =
+ Cost->requiresScalarEpilogue(VF.isVector())
+ ? BranchInst::Create(LoopScalarPreHeader)
+ : BranchInst::Create(LoopExitBlock, LoopScalarPreHeader,
+ Builder.getTrue());
BrInst->setDebugLoc(ScalarLatchTerm->getDebugLoc());
ReplaceInstWithInst(LoopMiddleBlock->getTerminator(), BrInst);
// Update dominator for loop exit. During skeleton creation, only the vector
// pre-header and the middle block are created. The vector loop is entirely
// created during VPlan execution.
- if (!Cost->requiresScalarEpilogue(VF))
+ if (!Cost->requiresScalarEpilogue(VF.isVector()))
// If there is an epilogue which must run, there's no edge from the
// middle block to exit blocks and thus no need to update the immediate
// dominator of the exit blocks.
@@ -3135,7 +3132,7 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
}
PHINode *InnerLoopVectorizer::createInductionResumeValue(
- PHINode *OrigPhi, const InductionDescriptor &II,
+ PHINode *OrigPhi, const InductionDescriptor &II, Value *Step,
ArrayRef<BasicBlock *> BypassBlocks,
std::pair<BasicBlock *, Value *> AdditionalBypass) {
Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader);
@@ -3154,8 +3151,6 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue(
if (II.getInductionBinOp() && isa<FPMathOperator>(II.getInductionBinOp()))
B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags());
- Value *Step =
- CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint());
EndValue =
emitTransformedIndex(B, VectorTripCount, II.getStartValue(), Step, II);
EndValue->setName("ind.end");
@@ -3163,8 +3158,6 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue(
// Compute the end value for the additional bypass (if applicable).
if (AdditionalBypass.first) {
B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt()));
- Value *Step =
- CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint());
EndValueFromAdditionalBypass = emitTransformedIndex(
B, AdditionalBypass.second, II.getStartValue(), Step, II);
EndValueFromAdditionalBypass->setName("ind.end");
@@ -3193,7 +3186,22 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue(
return BCResumeVal;
}
+/// Return the expanded step for \p ID using \p ExpandedSCEVs to look up SCEV
+/// expansion results.
+static Value *getExpandedStep(const InductionDescriptor &ID,
+ const SCEV2ValueTy &ExpandedSCEVs) {
+ const SCEV *Step = ID.getStep();
+ if (auto *C = dyn_cast<SCEVConstant>(Step))
+ return C->getValue();
+ if (auto *U = dyn_cast<SCEVUnknown>(Step))
+ return U->getValue();
+ auto I = ExpandedSCEVs.find(Step);
+ assert(I != ExpandedSCEVs.end() && "SCEV must be expanded at this point");
+ return I->second;
+}
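// For illustration, with assumed step SCEVs: a constant step of 4
// (SCEVConstant) and a step of %n (SCEVUnknown) are returned directly as IR
// values, so they never need an ExpandedSCEVs entry; only a composite step
// such as (4 * %n) must already have been expanded, which is what the assert
// above enforces.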
+
void InnerLoopVectorizer::createInductionResumeValues(
+ const SCEV2ValueTy &ExpandedSCEVs,
std::pair<BasicBlock *, Value *> AdditionalBypass) {
assert(((AdditionalBypass.first && AdditionalBypass.second) ||
(!AdditionalBypass.first && !AdditionalBypass.second)) &&
@@ -3209,14 +3217,15 @@ void InnerLoopVectorizer::createInductionResumeValues(
PHINode *OrigPhi = InductionEntry.first;
const InductionDescriptor &II = InductionEntry.second;
PHINode *BCResumeVal = createInductionResumeValue(
- OrigPhi, II, LoopBypassBlocks, AdditionalBypass);
+ OrigPhi, II, getExpandedStep(II, ExpandedSCEVs), LoopBypassBlocks,
+ AdditionalBypass);
OrigPhi->setIncomingValueForBlock(LoopScalarPreHeader, BCResumeVal);
}
}
BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
// The trip counts should be cached by now.
- Value *Count = getOrCreateTripCount(LoopVectorPreHeader);
+ Value *Count = getTripCount();
Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader);
auto *ScalarLatchTerm = OrigLoop->getLoopLatch()->getTerminator();
@@ -3229,7 +3238,8 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
// Thus if tail is to be folded, we know we don't need to run the
// remainder and we can use the previous value for the condition (true).
// 3) Otherwise, construct a runtime check.
- if (!Cost->requiresScalarEpilogue(VF) && !Cost->foldTailByMasking()) {
+ if (!Cost->requiresScalarEpilogue(VF.isVector()) &&
+ !Cost->foldTailByMasking()) {
Instruction *CmpN = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
Count, VectorTripCount, "cmp.n",
LoopMiddleBlock->getTerminator());
@@ -3250,14 +3260,16 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
}
std::pair<BasicBlock *, Value *>
-InnerLoopVectorizer::createVectorizedLoopSkeleton() {
+InnerLoopVectorizer::createVectorizedLoopSkeleton(
+ const SCEV2ValueTy &ExpandedSCEVs) {
/*
In this function we generate a new loop. The new loop will contain
the vectorized instructions while the old loop will continue to run the
scalar remainder.
- [ ] <-- loop iteration number check.
- / |
+ [ ] <-- old preheader - loop iteration number check and SCEVs in Plan's
+ / | preheader are expanded here. Eventually all required SCEV
+ / | expansion should happen here.
/ v
| [ ] <-- vector loop bypass (may consist of multiple blocks).
| / |
@@ -3304,7 +3316,7 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton() {
emitMemRuntimeChecks(LoopScalarPreHeader);
// Emit phis for the new starting index of the scalar loop.
- createInductionResumeValues();
+ createInductionResumeValues(ExpandedSCEVs);
return {completeLoopSkeleton(), nullptr};
}
@@ -3317,7 +3329,8 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
const InductionDescriptor &II,
Value *VectorTripCount, Value *EndValue,
BasicBlock *MiddleBlock,
- BasicBlock *VectorHeader, VPlan &Plan) {
+ BasicBlock *VectorHeader, VPlan &Plan,
+ VPTransformState &State) {
// There are two kinds of external IV usages - those that use the value
// computed in the last iteration (the PHI) and those that use the penultimate
// value (the value that feeds into the phi from the loop latch).
@@ -3345,7 +3358,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
auto *UI = cast<Instruction>(U);
if (!OrigLoop->contains(UI)) {
assert(isa<PHINode>(UI) && "Expected LCSSA form");
-
IRBuilder<> B(MiddleBlock->getTerminator());
// Fast-math-flags propagate from the original induction instruction.
@@ -3355,8 +3367,11 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
Value *CountMinusOne = B.CreateSub(
VectorTripCount, ConstantInt::get(VectorTripCount->getType(), 1));
CountMinusOne->setName("cmo");
- Value *Step = CreateStepValue(II.getStep(), *PSE.getSE(),
- VectorHeader->getTerminator());
+
+ VPValue *StepVPV = Plan.getSCEVExpansion(II.getStep());
+ assert(StepVPV && "step must have been expanded during VPlan execution");
+ Value *Step = StepVPV->isLiveIn() ? StepVPV->getLiveInIRValue()
+ : State.get(StepVPV, {0, 0});
Value *Escape =
emitTransformedIndex(B, CountMinusOne, II.getStartValue(), Step, II);
Escape->setName("ind.escape");
@@ -3430,12 +3445,12 @@ static void cse(BasicBlock *BB) {
}
}
-InstructionCost
-LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF,
- bool &NeedToScalarize) const {
+InstructionCost LoopVectorizationCostModel::getVectorCallCost(
+ CallInst *CI, ElementCount VF, Function **Variant, bool *NeedsMask) const {
Function *F = CI->getCalledFunction();
Type *ScalarRetTy = CI->getType();
SmallVector<Type *, 4> Tys, ScalarTys;
+ bool MaskRequired = Legal->isMaskRequired(CI);
for (auto &ArgOp : CI->args())
ScalarTys.push_back(ArgOp->getType());
@@ -3464,18 +3479,39 @@ LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF,
// If we can't emit a vector call for this function, then the currently found
// cost is the cost we need to return.
- NeedToScalarize = true;
- VFShape Shape = VFShape::get(*CI, VF, false /*HasGlobalPred*/);
+ InstructionCost MaskCost = 0;
+ VFShape Shape = VFShape::get(*CI, VF, MaskRequired);
+ if (NeedsMask)
+ *NeedsMask = MaskRequired;
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
+ // If we want an unmasked vector function but can't find one matching the VF,
+ // maybe we can find vector function that does use a mask and synthesize
+ // an all-true mask.
+ if (!VecFunc && !MaskRequired) {
+ Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true);
+ VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
+ // If we found one, add in the cost of creating a mask
+ if (VecFunc) {
+ if (NeedsMask)
+ *NeedsMask = true;
+ MaskCost = TTI.getShuffleCost(
+ TargetTransformInfo::SK_Broadcast,
+ VectorType::get(
+ IntegerType::getInt1Ty(VecFunc->getFunctionType()->getContext()),
+ VF));
+ }
+ }
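  // For illustration: when only a masked mapping exists for the call, the
  // retry above with HasGlobalPred=true finds it, and MaskCost is the
  // SK_Broadcast cost of splatting an all-true i1 across a <VF x i1> vector;
  // that cost is folded into VectorCallCost further down.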
+ // We don't support masked function calls yet, but we can scalarize a
+ // masked call with branches (unless VF is scalable).
if (!TLI || CI->isNoBuiltin() || !VecFunc)
- return Cost;
+ return VF.isScalable() ? InstructionCost::getInvalid() : Cost;
// If the corresponding vector cost is cheaper, return its cost.
InstructionCost VectorCallCost =
- TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind);
+ TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost;
if (VectorCallCost < Cost) {
- NeedToScalarize = false;
+ *Variant = VecFunc;
Cost = VectorCallCost;
}
return Cost;
@@ -3675,14 +3711,25 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
// Forget the original basic block.
PSE.getSE()->forgetLoop(OrigLoop);
+ // After vectorization, the exit blocks of the original loop will have
+ // additional predecessors. Invalidate SCEVs for the exit phis in case SE
+ // looked through single-entry phis.
+ SmallVector<BasicBlock *> ExitBlocks;
+ OrigLoop->getExitBlocks(ExitBlocks);
+ for (BasicBlock *Exit : ExitBlocks)
+ for (PHINode &PN : Exit->phis())
+ PSE.getSE()->forgetValue(&PN);
+
VPBasicBlock *LatchVPBB = Plan.getVectorLoopRegion()->getExitingBasicBlock();
Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]);
- if (Cost->requiresScalarEpilogue(VF)) {
+ if (Cost->requiresScalarEpilogue(VF.isVector())) {
// No edge from the middle block to the unique exit block has been inserted
// and there is nothing to fix from vector loop; phis should have incoming
// from scalar loop only.
- Plan.clearLiveOuts();
} else {
+ // TODO: Check VPLiveOuts to see if IV users need fixing instead of checking
+ // the cost model.
+
// If we inserted an edge from the middle block to the unique exit block,
// update uses outside the loop (phis) to account for the newly inserted
// edge.
@@ -3692,7 +3739,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
fixupIVUsers(Entry.first, Entry.second,
getOrCreateVectorTripCount(VectorLoop->getLoopPreheader()),
IVEndValues[Entry.first], LoopMiddleBlock,
- VectorLoop->getHeader(), Plan);
+ VectorLoop->getHeader(), Plan, State);
}
// Fix LCSSA phis not already fixed earlier. Extracts may need to be generated
@@ -3799,31 +3846,53 @@ void InnerLoopVectorizer::fixFixedOrderRecurrence(
Value *Incoming = State.get(PreviousDef, UF - 1);
auto *ExtractForScalar = Incoming;
auto *IdxTy = Builder.getInt32Ty();
+ Value *RuntimeVF = nullptr;
if (VF.isVector()) {
auto *One = ConstantInt::get(IdxTy, 1);
Builder.SetInsertPoint(LoopMiddleBlock->getTerminator());
- auto *RuntimeVF = getRuntimeVF(Builder, IdxTy, VF);
+ RuntimeVF = getRuntimeVF(Builder, IdxTy, VF);
auto *LastIdx = Builder.CreateSub(RuntimeVF, One);
- ExtractForScalar = Builder.CreateExtractElement(ExtractForScalar, LastIdx,
- "vector.recur.extract");
- }
- // Extract the second last element in the middle block if the
- // Phi is used outside the loop. We need to extract the phi itself
- // and not the last element (the phi update in the current iteration). This
- // will be the value when jumping to the exit block from the LoopMiddleBlock,
- // when the scalar loop is not run at all.
- Value *ExtractForPhiUsedOutsideLoop = nullptr;
- if (VF.isVector()) {
- auto *RuntimeVF = getRuntimeVF(Builder, IdxTy, VF);
- auto *Idx = Builder.CreateSub(RuntimeVF, ConstantInt::get(IdxTy, 2));
- ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement(
- Incoming, Idx, "vector.recur.extract.for.phi");
- } else if (UF > 1)
- // When loop is unrolled without vectorizing, initialize
- // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled value
- // of `Incoming`. This is analogous to the vectorized case above: extracting
- // the second last element when VF > 1.
- ExtractForPhiUsedOutsideLoop = State.get(PreviousDef, UF - 2);
+ ExtractForScalar =
+ Builder.CreateExtractElement(Incoming, LastIdx, "vector.recur.extract");
+ }
+
+ auto RecurSplice = cast<VPInstruction>(*PhiR->user_begin());
+ assert(PhiR->getNumUsers() == 1 &&
+ RecurSplice->getOpcode() ==
+ VPInstruction::FirstOrderRecurrenceSplice &&
+ "recurrence phi must have a single user: FirstOrderRecurrenceSplice");
+ SmallVector<VPLiveOut *> LiveOuts;
+ for (VPUser *U : RecurSplice->users())
+ if (auto *LiveOut = dyn_cast<VPLiveOut>(U))
+ LiveOuts.push_back(LiveOut);
+
+ if (!LiveOuts.empty()) {
+ // Extract the second last element in the middle block if the
+ // Phi is used outside the loop. We need to extract the phi itself
+ // and not the last element (the phi update in the current iteration). This
+ // will be the value when jumping to the exit block from the
+ // LoopMiddleBlock, when the scalar loop is not run at all.
+ Value *ExtractForPhiUsedOutsideLoop = nullptr;
+ if (VF.isVector()) {
+ auto *Idx = Builder.CreateSub(RuntimeVF, ConstantInt::get(IdxTy, 2));
+ ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement(
+ Incoming, Idx, "vector.recur.extract.for.phi");
+ } else {
+ assert(UF > 1 && "VF and UF cannot both be 1");
+ // When loop is unrolled without vectorizing, initialize
+ // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled
+ // value of `Incoming`. This is analogous to the vectorized case above:
+ // extracting the second last element when VF > 1.
+ ExtractForPhiUsedOutsideLoop = State.get(PreviousDef, UF - 2);
+ }
+
+ for (VPLiveOut *LiveOut : LiveOuts) {
+ assert(!Cost->requiresScalarEpilogue(VF.isVector()));
+ PHINode *LCSSAPhi = LiveOut->getPhi();
+ LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock);
+ State.Plan->removeLiveOut(LCSSAPhi);
+ }
+ }
// Fix the initial value of the original recurrence in the scalar loop.
Builder.SetInsertPoint(&*LoopScalarPreHeader->begin());
@@ -3837,22 +3906,6 @@ void InnerLoopVectorizer::fixFixedOrderRecurrence(
Phi->setIncomingValueForBlock(LoopScalarPreHeader, Start);
Phi->setName("scalar.recur");
-
- // Finally, fix users of the recurrence outside the loop. The users will need
- // either the last value of the scalar recurrence or the last value of the
- // vector recurrence we extracted in the middle block. Since the loop is in
- // LCSSA form, we just need to find all the phi nodes for the original scalar
- // recurrence in the exit block, and then add an edge for the middle block.
- // Note that LCSSA does not imply single entry when the original scalar loop
- // had multiple exiting edges (as we always run the last iteration in the
- // scalar epilogue); in that case, there is no edge from middle to exit and
- // and thus no phis which needed updated.
- if (!Cost->requiresScalarEpilogue(VF))
- for (PHINode &LCSSAPhi : LoopExitBlock->phis())
- if (llvm::is_contained(LCSSAPhi.incoming_values(), Phi)) {
- LCSSAPhi.addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock);
- State.Plan->removeLiveOut(&LCSSAPhi);
- }
}
void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
@@ -3872,9 +3925,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
// This is the vector-clone of the value that leaves the loop.
Type *VecTy = State.get(LoopExitInstDef, 0)->getType();
- // Wrap flags are in general invalid after vectorization, clear them.
- clearReductionWrapFlags(PhiR, State);
-
// Before each round, move the insertion point right between
// the PHIs and the values we are going to write.
// This allows us to write both PHINodes and the extractelement
@@ -4036,7 +4086,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
// We know that the loop is in LCSSA form. We need to update the PHI nodes
// in the exit blocks. See comment on analogous loop in
// fixFixedOrderRecurrence for a more complete explanation of the logic.
- if (!Cost->requiresScalarEpilogue(VF))
+ if (!Cost->requiresScalarEpilogue(VF.isVector()))
for (PHINode &LCSSAPhi : LoopExitBlock->phis())
if (llvm::is_contained(LCSSAPhi.incoming_values(), LoopExitInst)) {
LCSSAPhi.addIncoming(ReducedPartRdx, LoopMiddleBlock);
@@ -4054,38 +4104,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
OrigPhi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst);
}
-void InnerLoopVectorizer::clearReductionWrapFlags(VPReductionPHIRecipe *PhiR,
- VPTransformState &State) {
- const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
- RecurKind RK = RdxDesc.getRecurrenceKind();
- if (RK != RecurKind::Add && RK != RecurKind::Mul)
- return;
-
- SmallVector<VPValue *, 8> Worklist;
- SmallPtrSet<VPValue *, 8> Visited;
- Worklist.push_back(PhiR);
- Visited.insert(PhiR);
-
- while (!Worklist.empty()) {
- VPValue *Cur = Worklist.pop_back_val();
- for (unsigned Part = 0; Part < UF; ++Part) {
- Value *V = State.get(Cur, Part);
- if (!isa<OverflowingBinaryOperator>(V))
- break;
- cast<Instruction>(V)->dropPoisonGeneratingFlags();
- }
-
- for (VPUser *U : Cur->users()) {
- auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
- if (!UserRecipe)
- continue;
- for (VPValue *V : UserRecipe->definedValues())
- if (Visited.insert(V).second)
- Worklist.push_back(V);
- }
- }
-}
-
void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) {
// The basic block and loop containing the predicated instruction.
auto *PredBB = PredInst->getParent();
@@ -4125,10 +4143,11 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) {
auto *I = dyn_cast<Instruction>(Worklist.pop_back_val());
// We can't sink an instruction if it is a phi node, is not in the loop,
- // or may have side effects.
+ // may have side effects or may read from memory.
+ // TODO: Could do more granular checking to allow sinking a load past
+ // non-store instructions.
if (!I || isa<PHINode>(I) || !VectorLoop->contains(I) ||
- I->mayHaveSideEffects())
- continue;
+ I->mayHaveSideEffects() || I->mayReadFromMemory())
+ continue;
// If the instruction is already in PredBB, check if we can sink its
// operands. In that case, VPlan's sinkScalarOperands() succeeded in
@@ -4189,7 +4208,7 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) {
// We should not collect Scalars more than once per VF. Right now, this
// function is called from collectUniformsAndScalars(), which already does
// this check. Collecting Scalars for VF=1 does not make any sense.
- assert(VF.isVector() && Scalars.find(VF) == Scalars.end() &&
+ assert(VF.isVector() && !Scalars.contains(VF) &&
"This function should not be visited twice for the same VF");
// This avoids any chances of creating a REPLICATE recipe during planning
@@ -4382,6 +4401,8 @@ bool LoopVectorizationCostModel::isScalarWithPredication(
switch(I->getOpcode()) {
default:
return true;
+ case Instruction::Call:
+ return !VFDatabase::hasMaskedVariant(*(cast<CallInst>(I)), VF);
case Instruction::Load:
case Instruction::Store: {
auto *Ptr = getLoadStorePointerOperand(I);
@@ -4430,10 +4451,10 @@ bool LoopVectorizationCostModel::isPredicatedInst(Instruction *I) const {
// both speculation safety (which follows from the same argument as loads),
// but also must prove the value being stored is correct. The easiest
// form of the later is to require that all values stored are the same.
- if (Legal->isUniformMemOp(*I) &&
- (isa<LoadInst>(I) ||
- (isa<StoreInst>(I) &&
- TheLoop->isLoopInvariant(cast<StoreInst>(I)->getValueOperand()))) &&
+ if (Legal->isInvariant(getLoadStorePointerOperand(I)) &&
+ (isa<LoadInst>(I) ||
+ (isa<StoreInst>(I) &&
+ TheLoop->isLoopInvariant(cast<StoreInst>(I)->getValueOperand()))) &&
!Legal->blockNeedsPredication(I->getParent()))
return false;
return true;
@@ -4445,6 +4466,8 @@ bool LoopVectorizationCostModel::isPredicatedInst(Instruction *I) const {
// TODO: We can use the loop-preheader as context point here and get
// context sensitive reasoning
return !isSafeToSpeculativelyExecute(I);
+ case Instruction::Call:
+ return Legal->isMaskRequired(I);
}
}
@@ -4502,7 +4525,8 @@ LoopVectorizationCostModel::getDivRemSpeculationCost(Instruction *I,
// second vector operand. One example of this are shifts on x86.
Value *Op2 = I->getOperand(1);
auto Op2Info = TTI.getOperandInfo(Op2);
- if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2))
+ if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue &&
+ Legal->isInvariant(Op2))
Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
SmallVector<const Value *, 4> Operands(I->operand_values());
@@ -4614,7 +4638,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
// already does this check. Collecting Uniforms for VF=1 does not make any
// sense.
- assert(VF.isVector() && Uniforms.find(VF) == Uniforms.end() &&
+ assert(VF.isVector() && !Uniforms.contains(VF) &&
"This function should not be visited twice for the same VF");
// Visit the list of Uniforms. If we'll not find any uniform value, we'll
@@ -4663,10 +4687,18 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse())
addToWorklistIfAllowed(Cmp);
+ auto PrevVF = VF.divideCoefficientBy(2);
// Return true if all lanes perform the same memory operation, and we can
// thus chose to execute only one.
auto isUniformMemOpUse = [&](Instruction *I) {
- if (!Legal->isUniformMemOp(*I))
+ // If the value was already known to not be uniform for the previous
+ // (smaller VF), it cannot be uniform for the larger VF.
+ if (PrevVF.isVector()) {
+ auto Iter = Uniforms.find(PrevVF);
+ if (Iter != Uniforms.end() && !Iter->second.contains(I))
+ return false;
+ }
+ if (!Legal->isUniformMemOp(*I, VF))
return false;
if (isa<LoadInst>(I))
// Loading the same address always produces the same result - at least
@@ -4689,11 +4721,14 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
WideningDecision == CM_Interleave);
};
-
// Returns true if Ptr is the pointer operand of a memory access instruction
- // I, and I is known to not require scalarization.
+ // I, I is known to not require scalarization, and the pointer is not also
+ // stored.
auto isVectorizedMemAccessUse = [&](Instruction *I, Value *Ptr) -> bool {
- return getLoadStorePointerOperand(I) == Ptr && isUniformDecision(I, VF);
+ if (isa<StoreInst>(I) && I->getOperand(0) == Ptr)
+ return false;
+ return getLoadStorePointerOperand(I) == Ptr &&
+ (isUniformDecision(I, VF) || Legal->isInvariant(Ptr));
};
// Holds a list of values which are known to have at least one uniform use.
@@ -4739,10 +4774,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
if (isUniformMemOpUse(&I))
addToWorklistIfAllowed(&I);
- if (isUniformDecision(&I, VF)) {
- assert(isVectorizedMemAccessUse(&I, Ptr) && "consistency check");
+ if (isVectorizedMemAccessUse(&I, Ptr))
HasUniformUse.insert(Ptr);
- }
}
// Add to the worklist any operands which have *only* uniform (e.g. lane 0
@@ -4906,12 +4939,11 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
return MaxScalableVF;
// Limit MaxScalableVF by the maximum safe dependence distance.
- std::optional<unsigned> MaxVScale = TTI.getMaxVScale();
- if (!MaxVScale && TheFunction->hasFnAttribute(Attribute::VScaleRange))
- MaxVScale =
- TheFunction->getFnAttribute(Attribute::VScaleRange).getVScaleRangeMax();
- MaxScalableVF =
- ElementCount::getScalable(MaxVScale ? (MaxSafeElements / *MaxVScale) : 0);
+ if (std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI))
+ MaxScalableVF = ElementCount::getScalable(MaxSafeElements / *MaxVScale);
+ else
+ MaxScalableVF = ElementCount::getScalable(0);
+
if (!MaxScalableVF)
reportVectorizationInfo(
"Max legal vector width too small, scalable vectorization "
@@ -4932,7 +4964,7 @@ FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
// the memory accesses that is most restrictive (involved in the smallest
// dependence distance).
unsigned MaxSafeElements =
- PowerOf2Floor(Legal->getMaxSafeVectorWidthInBits() / WidestType);
+ llvm::bit_floor(Legal->getMaxSafeVectorWidthInBits() / WidestType);
auto MaxSafeFixedVF = ElementCount::getFixed(MaxSafeElements);
auto MaxSafeScalableVF = getMaxLegalScalableVF(MaxSafeElements);
@@ -5105,16 +5137,26 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
}
FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(TC, UserVF, true);
+
// Avoid tail folding if the trip count is known to be a multiple of any VF
- // we chose.
- // FIXME: The condition below pessimises the case for fixed-width vectors,
- // when scalable VFs are also candidates for vectorization.
- if (MaxFactors.FixedVF.isVector() && !MaxFactors.ScalableVF) {
- ElementCount MaxFixedVF = MaxFactors.FixedVF;
- assert((UserVF.isNonZero() || isPowerOf2_32(MaxFixedVF.getFixedValue())) &&
+ // we choose.
+ std::optional<unsigned> MaxPowerOf2RuntimeVF =
+ MaxFactors.FixedVF.getFixedValue();
+ if (MaxFactors.ScalableVF) {
+ std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI);
+ if (MaxVScale && TTI.isVScaleKnownToBeAPowerOfTwo()) {
+ MaxPowerOf2RuntimeVF = std::max<unsigned>(
+ *MaxPowerOf2RuntimeVF,
+ *MaxVScale * MaxFactors.ScalableVF.getKnownMinValue());
+ } else
+ MaxPowerOf2RuntimeVF = std::nullopt; // Stick with tail-folding for now.
+ }
+
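  // For illustration, with assumed values: MaxFactors.FixedVF = 16,
  // MaxFactors.ScalableVF = vscale x 4, a known maximum vscale of 16 and
  // vscale known to be a power of two give MaxPowerOf2RuntimeVF =
  // max(16, 16 * 4) = 64, so tail folding below can only be skipped when the
  // trip count is known to be a multiple of 64 (times UserIC, if set).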
+ if (MaxPowerOf2RuntimeVF && *MaxPowerOf2RuntimeVF > 0) {
+ assert((UserVF.isNonZero() || isPowerOf2_32(*MaxPowerOf2RuntimeVF)) &&
"MaxFixedVF must be a power of 2");
- unsigned MaxVFtimesIC = UserIC ? MaxFixedVF.getFixedValue() * UserIC
- : MaxFixedVF.getFixedValue();
+ unsigned MaxVFtimesIC =
+ UserIC ? *MaxPowerOf2RuntimeVF * UserIC : *MaxPowerOf2RuntimeVF;
ScalarEvolution *SE = PSE.getSE();
const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
const SCEV *ExitCount = SE->getAddExpr(
@@ -5134,7 +5176,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
// by masking.
// FIXME: look for a smaller MaxVF that does divide TC rather than masking.
if (Legal->prepareToFoldTailByMasking()) {
- FoldTailByMasking = true;
+ CanFoldTailByMasking = true;
return MaxFactors;
}
@@ -5187,7 +5229,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
// Ensure MaxVF is a power of 2; the dependence distance bound may not be.
// Note that both WidestRegister and WidestType may not be powers of 2.
auto MaxVectorElementCount = ElementCount::get(
- PowerOf2Floor(WidestRegister.getKnownMinValue() / WidestType),
+ llvm::bit_floor(WidestRegister.getKnownMinValue() / WidestType),
ComputeScalableMaxVF);
MaxVectorElementCount = MinVF(MaxVectorElementCount, MaxSafeVF);
LLVM_DEBUG(dbgs() << "LV: The Widest register safe to use is: "
@@ -5207,6 +5249,13 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
auto Min = Attr.getVScaleRangeMin();
WidestRegisterMinEC *= Min;
}
+
+ // When a scalar epilogue is required, at least one iteration of the scalar
+ // loop has to execute. Adjust ConstTripCount accordingly to avoid picking a
+ // max VF that results in a dead vector loop.
+ if (ConstTripCount > 0 && requiresScalarEpilogue(true))
+ ConstTripCount -= 1;
+
if (ConstTripCount && ConstTripCount <= WidestRegisterMinEC &&
(!FoldTailByMasking || isPowerOf2_32(ConstTripCount))) {
// If loop trip count (TC) is known at compile time there is no point in
@@ -5214,7 +5263,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
// power of two which doesn't exceed TC.
// If MaxVectorElementCount is scalable, we only fall back on a fixed VF
// when the TC is less than or equal to the known number of lanes.
- auto ClampedConstTripCount = PowerOf2Floor(ConstTripCount);
+ auto ClampedConstTripCount = llvm::bit_floor(ConstTripCount);
LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not "
"exceeding the constant trip count: "
<< ClampedConstTripCount << "\n");
@@ -5228,7 +5277,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
if (MaximizeBandwidth || (MaximizeBandwidth.getNumOccurrences() == 0 &&
TTI.shouldMaximizeVectorBandwidth(RegKind))) {
auto MaxVectorElementCountMaxBW = ElementCount::get(
- PowerOf2Floor(WidestRegister.getKnownMinValue() / SmallestType),
+ llvm::bit_floor(WidestRegister.getKnownMinValue() / SmallestType),
ComputeScalableMaxVF);
MaxVectorElementCountMaxBW = MinVF(MaxVectorElementCountMaxBW, MaxSafeVF);
@@ -5273,9 +5322,14 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
return MaxVF;
}
-std::optional<unsigned> LoopVectorizationCostModel::getVScaleForTuning() const {
- if (TheFunction->hasFnAttribute(Attribute::VScaleRange)) {
- auto Attr = TheFunction->getFnAttribute(Attribute::VScaleRange);
+/// Convenience function that returns the value of vscale_range iff
+/// vscale_range.min == vscale_range.max or otherwise returns the value
+/// returned by the corresponding TTI method.
+static std::optional<unsigned>
+getVScaleForTuning(const Loop *L, const TargetTransformInfo &TTI) {
+ const Function *Fn = L->getHeader()->getParent();
+ if (Fn->hasFnAttribute(Attribute::VScaleRange)) {
+ auto Attr = Fn->getFnAttribute(Attribute::VScaleRange);
auto Min = Attr.getVScaleRangeMin();
auto Max = Attr.getVScaleRangeMax();
if (Max && Min == Max)
@@ -5285,31 +5339,39 @@ std::optional<unsigned> LoopVectorizationCostModel::getVScaleForTuning() const {
return TTI.getVScaleForTuning();
}
-bool LoopVectorizationCostModel::isMoreProfitable(
+bool LoopVectorizationPlanner::isMoreProfitable(
const VectorizationFactor &A, const VectorizationFactor &B) const {
InstructionCost CostA = A.Cost;
InstructionCost CostB = B.Cost;
- unsigned MaxTripCount = PSE.getSE()->getSmallConstantMaxTripCount(TheLoop);
-
- if (!A.Width.isScalable() && !B.Width.isScalable() && FoldTailByMasking &&
- MaxTripCount) {
- // If we are folding the tail and the trip count is a known (possibly small)
- // constant, the trip count will be rounded up to an integer number of
- // iterations. The total cost will be PerIterationCost*ceil(TripCount/VF),
- // which we compare directly. When not folding the tail, the total cost will
- // be PerIterationCost*floor(TC/VF) + Scalar remainder cost, and so is
- // approximated with the per-lane cost below instead of using the tripcount
- // as here.
- auto RTCostA = CostA * divideCeil(MaxTripCount, A.Width.getFixedValue());
- auto RTCostB = CostB * divideCeil(MaxTripCount, B.Width.getFixedValue());
+ unsigned MaxTripCount = PSE.getSE()->getSmallConstantMaxTripCount(OrigLoop);
+
+ if (!A.Width.isScalable() && !B.Width.isScalable() && MaxTripCount) {
+ // If the trip count is a known (possibly small) constant, the trip count
+ // will be rounded up to an integer number of iterations under
+ // FoldTailByMasking. The total cost in that case will be
+ // VecCost*ceil(TripCount/VF). When not folding the tail, the total
+ // cost will be VecCost*floor(TC/VF) + ScalarCost*(TC%VF). There will be
+ // some extra overheads, but for the purpose of comparing the costs of
+ // different VFs we can use this to compare the total loop-body cost
+ // expected after vectorization.
+ auto GetCostForTC = [MaxTripCount, this](unsigned VF,
+ InstructionCost VectorCost,
+ InstructionCost ScalarCost) {
+ return CM.foldTailByMasking() ? VectorCost * divideCeil(MaxTripCount, VF)
+ : VectorCost * (MaxTripCount / VF) +
+ ScalarCost * (MaxTripCount % VF);
+ };
+ auto RTCostA = GetCostForTC(A.Width.getFixedValue(), CostA, A.ScalarCost);
+ auto RTCostB = GetCostForTC(B.Width.getFixedValue(), CostB, B.ScalarCost);
+
return RTCostA < RTCostB;
}
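  // For illustration, with assumed costs: MaxTripCount = 10, a width of 4,
  // a vector cost of 20 and a scalar cost of 4 give 20 * ceil(10 / 4) = 60
  // under tail folding, but 20 * floor(10 / 4) + 4 * (10 % 4) = 48 when a
  // scalar remainder runs, so the comparison depends on the tail-folding
  // decision as well as on the raw per-iteration costs.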
// Improve estimate for the vector width if it is scalable.
unsigned EstimatedWidthA = A.Width.getKnownMinValue();
unsigned EstimatedWidthB = B.Width.getKnownMinValue();
- if (std::optional<unsigned> VScale = getVScaleForTuning()) {
+ if (std::optional<unsigned> VScale = getVScaleForTuning(OrigLoop, TTI)) {
if (A.Width.isScalable())
EstimatedWidthA *= *VScale;
if (B.Width.isScalable())
@@ -5328,9 +5390,74 @@ bool LoopVectorizationCostModel::isMoreProfitable(
return (CostA * EstimatedWidthB) < (CostB * EstimatedWidthA);
}
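// For illustration, with assumed values: comparing A = {vscale x 4, cost 8}
// against B = {fixed 4, cost 6} with a tuning vscale of 2 gives estimated
// widths of 8 and 4; the cross-multiplied check 8 * 4 < 6 * 8 holds (per-lane
// cost 1.0 versus 1.5), so the scalable factor is treated as more profitable.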
-VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(
+static void emitInvalidCostRemarks(SmallVector<InstructionVFPair> InvalidCosts,
+ OptimizationRemarkEmitter *ORE,
+ Loop *TheLoop) {
+ if (InvalidCosts.empty())
+ return;
+
+ // Emit a report of VFs with invalid costs in the loop.
+
+ // Group the remarks per instruction, keeping the instruction order from
+ // InvalidCosts.
+ std::map<Instruction *, unsigned> Numbering;
+ unsigned I = 0;
+ for (auto &Pair : InvalidCosts)
+ if (!Numbering.count(Pair.first))
+ Numbering[Pair.first] = I++;
+
+ // Sort the list, first on instruction(number) then on VF.
+ sort(InvalidCosts, [&Numbering](InstructionVFPair &A, InstructionVFPair &B) {
+ if (Numbering[A.first] != Numbering[B.first])
+ return Numbering[A.first] < Numbering[B.first];
+ ElementCountComparator ECC;
+ return ECC(A.second, B.second);
+ });
+
+ // For a list of ordered instruction-vf pairs:
+ // [(load, vf1), (load, vf2), (store, vf1)]
+ // Group the instructions together to emit separate remarks for:
+ // load (vf1, vf2)
+ // store (vf1)
+ auto Tail = ArrayRef<InstructionVFPair>(InvalidCosts);
+ auto Subset = ArrayRef<InstructionVFPair>();
+ do {
+ if (Subset.empty())
+ Subset = Tail.take_front(1);
+
+ Instruction *I = Subset.front().first;
+
+ // If the next instruction is different, or if there are no other pairs,
+ // emit a remark for the collated subset. e.g.
+ // [(load, vf1), (load, vf2)]
+ // to emit:
+ // remark: invalid costs for 'load' at VF=(vf1, vf2)
+ if (Subset == Tail || Tail[Subset.size()].first != I) {
+ std::string OutString;
+ raw_string_ostream OS(OutString);
+ assert(!Subset.empty() && "Unexpected empty range");
+ OS << "Instruction with invalid costs prevented vectorization at VF=(";
+ for (const auto &Pair : Subset)
+ OS << (Pair.second == Subset.front().second ? "" : ", ") << Pair.second;
+ OS << "):";
+ if (auto *CI = dyn_cast<CallInst>(I))
+ OS << " call to " << CI->getCalledFunction()->getName();
+ else
+ OS << " " << I->getOpcodeName();
+ OS.flush();
+ reportVectorizationInfo(OutString, "InvalidCost", ORE, TheLoop, I);
+ Tail = Tail.drop_front(Subset.size());
+ Subset = {};
+ } else
+ // Grow the subset by one element
+ Subset = Tail.take_front(Subset.size() + 1);
+ } while (!Tail.empty());
+}
+
+VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor(
const ElementCountSet &VFCandidates) {
- InstructionCost ExpectedCost = expectedCost(ElementCount::getFixed(1)).first;
+ InstructionCost ExpectedCost =
+ CM.expectedCost(ElementCount::getFixed(1)).first;
LLVM_DEBUG(dbgs() << "LV: Scalar loop costs: " << ExpectedCost << ".\n");
assert(ExpectedCost.isValid() && "Unexpected invalid cost for scalar loop");
assert(VFCandidates.count(ElementCount::getFixed(1)) &&
@@ -5340,7 +5467,7 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(
ExpectedCost);
VectorizationFactor ChosenFactor = ScalarCost;
- bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled;
+ bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled;
if (ForceVectorization && VFCandidates.size() > 1) {
// Ignore scalar width, because the user explicitly wants vectorization.
// Initialize cost to max so that VF = 2 is, at least, chosen during cost
@@ -5354,12 +5481,13 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(
if (i.isScalar())
continue;
- VectorizationCostTy C = expectedCost(i, &InvalidCosts);
+ LoopVectorizationCostModel::VectorizationCostTy C =
+ CM.expectedCost(i, &InvalidCosts);
VectorizationFactor Candidate(i, C.first, ScalarCost.ScalarCost);
#ifndef NDEBUG
unsigned AssumedMinimumVscale = 1;
- if (std::optional<unsigned> VScale = getVScaleForTuning())
+ if (std::optional<unsigned> VScale = getVScaleForTuning(OrigLoop, TTI))
AssumedMinimumVscale = *VScale;
unsigned Width =
Candidate.Width.isScalable()
@@ -5388,70 +5516,13 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(
ChosenFactor = Candidate;
}
- // Emit a report of VFs with invalid costs in the loop.
- if (!InvalidCosts.empty()) {
- // Group the remarks per instruction, keeping the instruction order from
- // InvalidCosts.
- std::map<Instruction *, unsigned> Numbering;
- unsigned I = 0;
- for (auto &Pair : InvalidCosts)
- if (!Numbering.count(Pair.first))
- Numbering[Pair.first] = I++;
-
- // Sort the list, first on instruction(number) then on VF.
- llvm::sort(InvalidCosts,
- [&Numbering](InstructionVFPair &A, InstructionVFPair &B) {
- if (Numbering[A.first] != Numbering[B.first])
- return Numbering[A.first] < Numbering[B.first];
- ElementCountComparator ECC;
- return ECC(A.second, B.second);
- });
-
- // For a list of ordered instruction-vf pairs:
- // [(load, vf1), (load, vf2), (store, vf1)]
- // Group the instructions together to emit separate remarks for:
- // load (vf1, vf2)
- // store (vf1)
- auto Tail = ArrayRef<InstructionVFPair>(InvalidCosts);
- auto Subset = ArrayRef<InstructionVFPair>();
- do {
- if (Subset.empty())
- Subset = Tail.take_front(1);
-
- Instruction *I = Subset.front().first;
-
- // If the next instruction is different, or if there are no other pairs,
- // emit a remark for the collated subset. e.g.
- // [(load, vf1), (load, vf2))]
- // to emit:
- // remark: invalid costs for 'load' at VF=(vf, vf2)
- if (Subset == Tail || Tail[Subset.size()].first != I) {
- std::string OutString;
- raw_string_ostream OS(OutString);
- assert(!Subset.empty() && "Unexpected empty range");
- OS << "Instruction with invalid costs prevented vectorization at VF=(";
- for (const auto &Pair : Subset)
- OS << (Pair.second == Subset.front().second ? "" : ", ")
- << Pair.second;
- OS << "):";
- if (auto *CI = dyn_cast<CallInst>(I))
- OS << " call to " << CI->getCalledFunction()->getName();
- else
- OS << " " << I->getOpcodeName();
- OS.flush();
- reportVectorizationInfo(OutString, "InvalidCost", ORE, TheLoop, I);
- Tail = Tail.drop_front(Subset.size());
- Subset = {};
- } else
- // Grow the subset by one element
- Subset = Tail.take_front(Subset.size() + 1);
- } while (!Tail.empty());
- }
+ emitInvalidCostRemarks(InvalidCosts, ORE, OrigLoop);
- if (!EnableCondStoresVectorization && NumPredStores) {
- reportVectorizationFailure("There are conditional stores.",
+ if (!EnableCondStoresVectorization && CM.hasPredStores()) {
+ reportVectorizationFailure(
+ "There are conditional stores.",
"store that is conditionally executed prevents vectorization",
- "ConditionalStore", ORE, TheLoop);
+ "ConditionalStore", ORE, OrigLoop);
ChosenFactor = ScalarCost;
}
@@ -5463,11 +5534,11 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(
return ChosenFactor;
}
-bool LoopVectorizationCostModel::isCandidateForEpilogueVectorization(
- const Loop &L, ElementCount VF) const {
+bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
+ ElementCount VF) const {
// Cross iteration phis such as reductions need special handling and are
// currently unsupported.
- if (any_of(L.getHeader()->phis(),
+ if (any_of(OrigLoop->getHeader()->phis(),
[&](PHINode &Phi) { return Legal->isFixedOrderRecurrence(&Phi); }))
return false;
@@ -5475,20 +5546,21 @@ bool LoopVectorizationCostModel::isCandidateForEpilogueVectorization(
// currently unsupported.
for (const auto &Entry : Legal->getInductionVars()) {
// Look for uses of the value of the induction at the last iteration.
- Value *PostInc = Entry.first->getIncomingValueForBlock(L.getLoopLatch());
+ Value *PostInc =
+ Entry.first->getIncomingValueForBlock(OrigLoop->getLoopLatch());
for (User *U : PostInc->users())
- if (!L.contains(cast<Instruction>(U)))
+ if (!OrigLoop->contains(cast<Instruction>(U)))
return false;
// Look for uses of penultimate value of the induction.
for (User *U : Entry.first->users())
- if (!L.contains(cast<Instruction>(U)))
+ if (!OrigLoop->contains(cast<Instruction>(U)))
return false;
}
// Epilogue vectorization code has not been audited to ensure it handles
// non-latch exits properly. It may be fine, but it needs to be audited and
// tested.
- if (L.getExitingBlock() != L.getLoopLatch())
+ if (OrigLoop->getExitingBlock() != OrigLoop->getLoopLatch())
return false;
return true;
@@ -5507,62 +5579,59 @@ bool LoopVectorizationCostModel::isEpilogueVectorizationProfitable(
// We also consider epilogue vectorization unprofitable for targets that don't
// consider interleaving beneficial (eg. MVE).
- if (TTI.getMaxInterleaveFactor(VF.getKnownMinValue()) <= 1)
+ if (TTI.getMaxInterleaveFactor(VF) <= 1)
return false;
- // FIXME: We should consider changing the threshold for scalable
- // vectors to take VScaleForTuning into account.
- if (VF.getKnownMinValue() >= EpilogueVectorizationMinVF)
+
+ unsigned Multiplier = 1;
+ if (VF.isScalable())
+ Multiplier = getVScaleForTuning(TheLoop, TTI).value_or(1);
+ if ((Multiplier * VF.getKnownMinValue()) >= EpilogueVectorizationMinVF)
return true;
return false;
}
-VectorizationFactor
-LoopVectorizationCostModel::selectEpilogueVectorizationFactor(
- const ElementCount MainLoopVF, const LoopVectorizationPlanner &LVP) {
+VectorizationFactor LoopVectorizationPlanner::selectEpilogueVectorizationFactor(
+ const ElementCount MainLoopVF, unsigned IC) {
VectorizationFactor Result = VectorizationFactor::Disabled();
if (!EnableEpilogueVectorization) {
- LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization is disabled.\n";);
+ LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization is disabled.\n");
return Result;
}
- if (!isScalarEpilogueAllowed()) {
- LLVM_DEBUG(
- dbgs() << "LEV: Unable to vectorize epilogue because no epilogue is "
- "allowed.\n";);
+ if (!CM.isScalarEpilogueAllowed()) {
+ LLVM_DEBUG(dbgs() << "LEV: Unable to vectorize epilogue because no "
+ "epilogue is allowed.\n");
return Result;
}
// Not really a cost consideration, but check for unsupported cases here to
// simplify the logic.
- if (!isCandidateForEpilogueVectorization(*TheLoop, MainLoopVF)) {
- LLVM_DEBUG(
- dbgs() << "LEV: Unable to vectorize epilogue because the loop is "
- "not a supported candidate.\n";);
+ if (!isCandidateForEpilogueVectorization(MainLoopVF)) {
+ LLVM_DEBUG(dbgs() << "LEV: Unable to vectorize epilogue because the loop "
+ "is not a supported candidate.\n");
return Result;
}
if (EpilogueVectorizationForceVF > 1) {
- LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization factor is forced.\n";);
+ LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization factor is forced.\n");
ElementCount ForcedEC = ElementCount::getFixed(EpilogueVectorizationForceVF);
- if (LVP.hasPlanWithVF(ForcedEC))
+ if (hasPlanWithVF(ForcedEC))
return {ForcedEC, 0, 0};
else {
- LLVM_DEBUG(
- dbgs()
- << "LEV: Epilogue vectorization forced factor is not viable.\n";);
+ LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization forced factor is not "
+ "viable.\n");
return Result;
}
}
- if (TheLoop->getHeader()->getParent()->hasOptSize() ||
- TheLoop->getHeader()->getParent()->hasMinSize()) {
+ if (OrigLoop->getHeader()->getParent()->hasOptSize() ||
+ OrigLoop->getHeader()->getParent()->hasMinSize()) {
LLVM_DEBUG(
- dbgs()
- << "LEV: Epilogue vectorization skipped due to opt for size.\n";);
+ dbgs() << "LEV: Epilogue vectorization skipped due to opt for size.\n");
return Result;
}
- if (!isEpilogueVectorizationProfitable(MainLoopVF)) {
+ if (!CM.isEpilogueVectorizationProfitable(MainLoopVF)) {
LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization is not profitable for "
"this loop\n");
return Result;
@@ -5574,21 +5643,48 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor(
ElementCount EstimatedRuntimeVF = MainLoopVF;
if (MainLoopVF.isScalable()) {
EstimatedRuntimeVF = ElementCount::getFixed(MainLoopVF.getKnownMinValue());
- if (std::optional<unsigned> VScale = getVScaleForTuning())
+ if (std::optional<unsigned> VScale = getVScaleForTuning(OrigLoop, TTI))
EstimatedRuntimeVF *= *VScale;
}
- for (auto &NextVF : ProfitableVFs)
- if (((!NextVF.Width.isScalable() && MainLoopVF.isScalable() &&
- ElementCount::isKnownLT(NextVF.Width, EstimatedRuntimeVF)) ||
- ElementCount::isKnownLT(NextVF.Width, MainLoopVF)) &&
- (Result.Width.isScalar() || isMoreProfitable(NextVF, Result)) &&
- LVP.hasPlanWithVF(NextVF.Width))
+ ScalarEvolution &SE = *PSE.getSE();
+ Type *TCType = Legal->getWidestInductionType();
+ const SCEV *RemainingIterations = nullptr;
+ for (auto &NextVF : ProfitableVFs) {
+ // Skip candidate VFs without a corresponding VPlan.
+ if (!hasPlanWithVF(NextVF.Width))
+ continue;
+
+ // Skip candidate VFs with widths >= the estimated runtime VF (scalable
+ // vectors) or the VF of the main loop (fixed vectors).
+ if ((!NextVF.Width.isScalable() && MainLoopVF.isScalable() &&
+ ElementCount::isKnownGE(NextVF.Width, EstimatedRuntimeVF)) ||
+ ElementCount::isKnownGE(NextVF.Width, MainLoopVF))
+ continue;
+
+ // If NextVF is greater than the number of remaining iterations, the
+ // epilogue loop would be dead. Skip such factors.
+ if (!MainLoopVF.isScalable() && !NextVF.Width.isScalable()) {
+ // TODO: extend to support scalable VFs.
+ if (!RemainingIterations) {
+ const SCEV *TC = createTripCountSCEV(TCType, PSE, OrigLoop);
+ RemainingIterations = SE.getURemExpr(
+ TC, SE.getConstant(TCType, MainLoopVF.getKnownMinValue() * IC));
+ }
+ if (SE.isKnownPredicate(
+ CmpInst::ICMP_UGT,
+ SE.getConstant(TCType, NextVF.Width.getKnownMinValue()),
+ RemainingIterations))
+ continue;
+ }
+
+ if (Result.Width.isScalar() || isMoreProfitable(NextVF, Result))
Result = NextVF;
+ }
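  // For illustration, with assumed values: a fixed MainLoopVF of 16, IC = 2
  // and a constant trip count of 100 give RemainingIterations = 100 urem 32
  // = 4, so a candidate epilogue VF of 8 is provably too wide and is skipped,
  // while VF = 4 remains in the running.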
if (Result != VectorizationFactor::Disabled())
LLVM_DEBUG(dbgs() << "LEV: Vectorizing epilogue loop with VF = "
- << Result.Width << "\n";);
+ << Result.Width << "\n");
return Result;
}
@@ -5688,7 +5784,7 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
return 1;
// We used the distance for the interleave count.
- if (Legal->getMaxSafeDepDistBytes() != -1U)
+ if (!Legal->isSafeForAnyVectorWidth())
return 1;
auto BestKnownTC = getSmallBestKnownTC(*PSE.getSE(), TheLoop);
@@ -5750,20 +5846,19 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
if (R.LoopInvariantRegs.find(pair.first) != R.LoopInvariantRegs.end())
LoopInvariantRegs = R.LoopInvariantRegs[pair.first];
- unsigned TmpIC = PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs) / MaxLocalUsers);
+ unsigned TmpIC = llvm::bit_floor((TargetNumRegisters - LoopInvariantRegs) /
+ MaxLocalUsers);
// Don't count the induction variable as interleaved.
if (EnableIndVarRegisterHeur) {
- TmpIC =
- PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs - 1) /
- std::max(1U, (MaxLocalUsers - 1)));
+ TmpIC = llvm::bit_floor((TargetNumRegisters - LoopInvariantRegs - 1) /
+ std::max(1U, (MaxLocalUsers - 1)));
}
IC = std::min(IC, TmpIC);
}
// Clamp the interleave ranges to reasonable counts.
- unsigned MaxInterleaveCount =
- TTI.getMaxInterleaveFactor(VF.getKnownMinValue());
+ unsigned MaxInterleaveCount = TTI.getMaxInterleaveFactor(VF);
// Check if the user has overridden the max.
if (VF.isScalar()) {
@@ -5834,8 +5929,8 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
// We assume that the cost overhead is 1 and we use the cost model
// to estimate the cost of the loop and interleave until the cost of the
// loop overhead is about 5% of the cost of the loop.
- unsigned SmallIC = std::min(
- IC, (unsigned)PowerOf2Floor(SmallLoopCost / *LoopCost.getValue()));
+ unsigned SmallIC = std::min(IC, (unsigned)llvm::bit_floor<uint64_t>(
+ SmallLoopCost / *LoopCost.getValue()));
// Interleave until store/load ports (estimated by max interleave count) are
// saturated.
@@ -5953,7 +6048,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
// Saves the list of values that are used in the loop but are defined outside
// the loop (not including non-instruction values such as arguments and
// constants).
- SmallPtrSet<Value *, 8> LoopInvariants;
+ SmallSetVector<Instruction *, 8> LoopInvariants;
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
for (Instruction &I : BB->instructionsWithoutDebug()) {
@@ -6079,11 +6174,16 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
for (auto *Inst : LoopInvariants) {
// FIXME: The target might use more than one register for the type
// even in the scalar case.
- unsigned Usage =
- VFs[i].isScalar() ? 1 : GetRegUsage(Inst->getType(), VFs[i]);
+ bool IsScalar = all_of(Inst->users(), [&](User *U) {
+ auto *I = cast<Instruction>(U);
+ return TheLoop != LI->getLoopFor(I->getParent()) ||
+ isScalarAfterVectorization(I, VFs[i]);
+ });
+
+ ElementCount VF = IsScalar ? ElementCount::getFixed(1) : VFs[i];
unsigned ClassID =
- TTI.getRegisterClassForType(VFs[i].isVector(), Inst->getType());
- Invariant[ClassID] += Usage;
+ TTI.getRegisterClassForType(VF.isVector(), Inst->getType());
+ Invariant[ClassID] += GetRegUsage(Inst->getType(), VF);
}
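    // For illustration: a loop-invariant value whose in-loop users all stay
    // scalar after vectorization (e.g. a base pointer feeding only uniform
    // address computations) is charged at its scalar register usage above,
    // while an invariant feeding widened instructions is charged at the
    // vector register usage for the candidate VF.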
LLVM_DEBUG({
@@ -6134,8 +6234,7 @@ void LoopVectorizationCostModel::collectInstsToScalarize(ElementCount VF) {
// instructions to scalarize, there's nothing to do. Collection may already
// have occurred if we have a user-selected VF and are now computing the
// expected cost for interleaving.
- if (VF.isScalar() || VF.isZero() ||
- InstsToScalarize.find(VF) != InstsToScalarize.end())
+ if (VF.isScalar() || VF.isZero() || InstsToScalarize.contains(VF))
return;
// Initialize a mapping for VF in InstsToScalalarize. If we find that it's
@@ -6224,7 +6323,7 @@ InstructionCost LoopVectorizationCostModel::computePredInstDiscount(
Instruction *I = Worklist.pop_back_val();
// If we've already analyzed the instruction, there's nothing to do.
- if (ScalarCosts.find(I) != ScalarCosts.end())
+ if (ScalarCosts.contains(I))
continue;
// Compute the cost of the vector instruction. Note that this cost already
@@ -6362,11 +6461,6 @@ static const SCEV *getAddressAccessSCEV(
return PSE.getSCEV(Ptr);
}
-static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) {
- return Legal->hasStride(I->getOperand(0)) ||
- Legal->hasStride(I->getOperand(1));
-}
-
InstructionCost
LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I,
ElementCount VF) {
@@ -6460,7 +6554,7 @@ LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I,
InstructionCost
LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I,
ElementCount VF) {
- assert(Legal->isUniformMemOp(*I));
+ assert(Legal->isUniformMemOp(*I, VF));
Type *ValTy = getLoadStoreType(I);
auto *VectorTy = cast<VectorType>(ToVectorTy(ValTy, VF));
@@ -6475,7 +6569,7 @@ LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I,
}
StoreInst *SI = cast<StoreInst>(I);
- bool isLoopInvariantStoreValue = Legal->isUniform(SI->getValueOperand());
+ bool isLoopInvariantStoreValue = Legal->isInvariant(SI->getValueOperand());
return TTI.getAddressComputationCost(ValTy) +
TTI.getMemoryOpCost(Instruction::Store, ValTy, Alignment, AS,
CostKind) +
@@ -6502,11 +6596,6 @@ LoopVectorizationCostModel::getGatherScatterCost(Instruction *I,
InstructionCost
LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
ElementCount VF) {
- // TODO: Once we have support for interleaving with scalable vectors
- // we can calculate the cost properly here.
- if (VF.isScalable())
- return InstructionCost::getInvalid();
-
Type *ValTy = getLoadStoreType(I);
auto *VectorTy = cast<VectorType>(ToVectorTy(ValTy, VF));
unsigned AS = getLoadStoreAddressSpace(I);
@@ -6836,7 +6925,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) {
if (isa<StoreInst>(&I) && isScalarWithPredication(&I, VF))
NumPredStores++;
- if (Legal->isUniformMemOp(I)) {
+ if (Legal->isUniformMemOp(I, VF)) {
auto isLegalToScalarize = [&]() {
if (!VF.isScalable())
// Scalarization of fixed length vectors "just works".
@@ -7134,8 +7223,12 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
case Instruction::And:
case Instruction::Or:
case Instruction::Xor: {
- // Since we will replace the stride by 1 the multiplication should go away.
- if (I->getOpcode() == Instruction::Mul && isStrideMul(I, Legal))
+ // If we're speculating on the stride being 1, the multiplication may
+ // fold away. We can generalize this for all operations using the notion
+ // of neutral elements. (TODO)
+ if (I->getOpcode() == Instruction::Mul &&
+ (PSE.getSCEV(I->getOperand(0))->isOne() ||
+ PSE.getSCEV(I->getOperand(1))->isOne()))
return 0;
// Detect reduction patterns
@@ -7146,7 +7239,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
// second vector operand. One example of this are shifts on x86.
Value *Op2 = I->getOperand(1);
auto Op2Info = TTI.getOperandInfo(Op2);
- if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2))
+ if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue &&
+ Legal->isInvariant(Op2))
Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
SmallVector<const Value *, 4> Operands(I->operand_values());
@@ -7304,7 +7398,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
VectorTy =
largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
} else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) {
- SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy);
+ // Leave SrcVecTy unchanged - we only shrink the destination element
+ // type.
VectorTy =
smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
}
@@ -7316,9 +7411,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
if (RecurrenceDescriptor::isFMulAddIntrinsic(I))
if (auto RedCost = getReductionPatternCost(I, VF, VectorTy, CostKind))
return *RedCost;
- bool NeedToScalarize;
+ Function *Variant;
CallInst *CI = cast<CallInst>(I);
- InstructionCost CallCost = getVectorCallCost(CI, VF, NeedToScalarize);
+ InstructionCost CallCost = getVectorCallCost(CI, VF, &Variant);
if (getVectorIntrinsicIDForCall(CI, TLI)) {
InstructionCost IntrinsicCost = getVectorIntrinsicCost(CI, VF);
return std::min(CallCost, IntrinsicCost);
@@ -7339,37 +7434,6 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
} // end of switch.
}
-char LoopVectorize::ID = 0;
-
-static const char lv_name[] = "Loop Vectorization";
-
-INITIALIZE_PASS_BEGIN(LoopVectorize, LV_NAME, lv_name, false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
-INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(InjectTLIMappingsLegacy)
-INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false)
-
-namespace llvm {
-
-Pass *createLoopVectorizePass() { return new LoopVectorize(); }
-
-Pass *createLoopVectorizePass(bool InterleaveOnlyWhenForced,
- bool VectorizeOnlyWhenForced) {
- return new LoopVectorize(InterleaveOnlyWhenForced, VectorizeOnlyWhenForced);
-}
-
-} // end namespace llvm
-
void LoopVectorizationCostModel::collectValuesToIgnore() {
// Ignore ephemeral values.
CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore);
@@ -7462,7 +7526,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
// reasonable one.
if (UserVF.isZero()) {
VF = ElementCount::getFixed(determineVPlanVF(
- TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+ TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
.getFixedValue(),
CM));
LLVM_DEBUG(dbgs() << "LV: VPlan computed VF " << VF << ".\n");
@@ -7497,13 +7561,16 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
std::optional<VectorizationFactor>
LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
assert(OrigLoop->isInnermost() && "Inner loop expected.");
+ CM.collectValuesToIgnore();
+ CM.collectElementTypesForWidening();
+
FixedScalableVFPair MaxFactors = CM.computeMaxVF(UserVF, UserIC);
if (!MaxFactors) // Cases that should not to be vectorized nor interleaved.
return std::nullopt;
// Invalidate interleave groups if all blocks of loop will be predicated.
if (CM.blockNeedsPredicationForAnyReason(OrigLoop->getHeader()) &&
- !useMaskedInterleavedAccesses(*TTI)) {
+ !useMaskedInterleavedAccesses(TTI)) {
LLVM_DEBUG(
dbgs()
<< "LV: Invalidate all interleaved groups due to fold-tail by masking "
@@ -7527,6 +7594,12 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n");
CM.collectInLoopReductions();
buildVPlansWithVPRecipes(UserVF, UserVF);
+ if (!hasPlanWithVF(UserVF)) {
+ LLVM_DEBUG(dbgs() << "LV: No VPlan could be built for " << UserVF
+ << ".\n");
+ return std::nullopt;
+ }
+
LLVM_DEBUG(printPlans(dbgs()));
return {{UserVF, 0, 0}};
} else
@@ -7562,8 +7635,13 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
return VectorizationFactor::Disabled();
// Select the optimal vectorization factor.
- VectorizationFactor VF = CM.selectVectorizationFactor(VFCandidates);
+ VectorizationFactor VF = selectVectorizationFactor(VFCandidates);
assert((VF.Width.isScalar() || VF.ScalarCost > 0) && "when vectorizing, the scalar cost must be non-zero.");
+ if (!hasPlanWithVF(VF.Width)) {
+ LLVM_DEBUG(dbgs() << "LV: No VPlan could be built for " << VF.Width
+ << ".\n");
+ return std::nullopt;
+ }
return VF;
}
@@ -7614,43 +7692,51 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) {
}
}
-void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF,
- VPlan &BestVPlan,
- InnerLoopVectorizer &ILV,
- DominatorTree *DT,
- bool IsEpilogueVectorization) {
+SCEV2ValueTy LoopVectorizationPlanner::executePlan(
+ ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
+ InnerLoopVectorizer &ILV, DominatorTree *DT, bool IsEpilogueVectorization,
+ DenseMap<const SCEV *, Value *> *ExpandedSCEVs) {
assert(BestVPlan.hasVF(BestVF) &&
"Trying to execute plan with unsupported VF");
assert(BestVPlan.hasUF(BestUF) &&
"Trying to execute plan with unsupported UF");
+ assert(
+ (IsEpilogueVectorization || !ExpandedSCEVs) &&
+ "expanded SCEVs to reuse can only be used during epilogue vectorization");
LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << BestVF << ", UF=" << BestUF
<< '\n');
- // Workaround! Compute the trip count of the original loop and cache it
- // before we start modifying the CFG. This code has a systemic problem
- // wherein it tries to run analysis over partially constructed IR; this is
- // wrong, and not simply for SCEV. The trip count of the original loop
- // simply happens to be prone to hitting this in practice. In theory, we
- // can hit the same issue for any SCEV, or ValueTracking query done during
- // mutation. See PR49900.
- ILV.getOrCreateTripCount(OrigLoop->getLoopPreheader());
-
if (!IsEpilogueVectorization)
VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
// Perform the actual loop transformation.
+ VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan};
+
+ // 0. Generate SCEV-dependent code into the preheader, including TripCount,
+ // before making any changes to the CFG.
+ if (!BestVPlan.getPreheader()->empty()) {
+ State.CFG.PrevBB = OrigLoop->getLoopPreheader();
+ State.Builder.SetInsertPoint(OrigLoop->getLoopPreheader()->getTerminator());
+ BestVPlan.getPreheader()->execute(&State);
+ }
+ if (!ILV.getTripCount())
+ ILV.setTripCount(State.get(BestVPlan.getTripCount(), {0, 0}));
+ else
+ assert(IsEpilogueVectorization && "should only re-use the existing trip "
+ "count during epilogue vectorization");
// 1. Set up the skeleton for vectorization, including vector pre-header and
// middle block. The vector loop is created during VPlan execution.
- VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan};
Value *CanonicalIVStartValue;
std::tie(State.CFG.PrevBB, CanonicalIVStartValue) =
- ILV.createVectorizedLoopSkeleton();
+ ILV.createVectorizedLoopSkeleton(ExpandedSCEVs ? *ExpandedSCEVs
+ : State.ExpandedSCEVs);
// Only use noalias metadata when using memory checks guaranteeing no overlap
// across all iterations.
const LoopAccessInfo *LAI = ILV.Legal->getLAI();
+ std::unique_ptr<LoopVersioning> LVer = nullptr;
if (LAI && !LAI->getRuntimePointerChecking()->getChecks().empty() &&
!LAI->getRuntimePointerChecking()->getDiffChecks()) {
@@ -7658,9 +7744,10 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF,
// still use it to add the noalias metadata.
// TODO: Find a better way to re-use LoopVersioning functionality to add
// metadata.
- State.LVer = std::make_unique<LoopVersioning>(
+ LVer = std::make_unique<LoopVersioning>(
*LAI, LAI->getRuntimePointerChecking()->getChecks(), OrigLoop, LI, DT,
PSE.getSE());
+ State.LVer = &*LVer;
State.LVer->prepareNoAliasMetadata();
}
@@ -7677,10 +7764,9 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF,
//===------------------------------------------------===//
// 2. Copy and widen instructions from the old loop into the new loop.
- BestVPlan.prepareToExecute(ILV.getOrCreateTripCount(nullptr),
- ILV.getOrCreateVectorTripCount(nullptr),
- CanonicalIVStartValue, State,
- IsEpilogueVectorization);
+ BestVPlan.prepareToExecute(
+ ILV.getTripCount(), ILV.getOrCreateVectorTripCount(nullptr),
+ CanonicalIVStartValue, State, IsEpilogueVectorization);
BestVPlan.execute(&State);
@@ -7706,13 +7792,18 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF,
LoopVectorizeHints Hints(L, true, *ORE);
Hints.setAlreadyVectorized();
}
- AddRuntimeUnrollDisableMetaData(L);
+ TargetTransformInfo::UnrollingPreferences UP;
+ TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE);
+ if (!UP.UnrollVectorizedLoop || CanonicalIVStartValue)
+ AddRuntimeUnrollDisableMetaData(L);
// 3. Fix the vectorized code: take care of header phi's, live-outs,
// predication, updating analyses.
ILV.fixVectorizedLoop(State, BestVPlan);
ILV.printDebugTracesAtEnd();
+
+ return State.ExpandedSCEVs;
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -7725,8 +7816,6 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
}
#endif
-Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; }
-
//===--------------------------------------------------------------------===//
// EpilogueVectorizerMainLoop
//===--------------------------------------------------------------------===//
@@ -7734,7 +7823,8 @@ Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; }
/// This function is partially responsible for generating the control flow
/// depicted in https://llvm.org/docs/Vectorizers.html#epilogue-vectorization.
std::pair<BasicBlock *, Value *>
-EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() {
+EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton(
+ const SCEV2ValueTy &ExpandedSCEVs) {
createVectorLoopSkeleton("");
// Generate the code to check the minimum iteration count of the vector
@@ -7795,7 +7885,7 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
assert(Bypass && "Expected valid bypass basic block.");
ElementCount VFactor = ForEpilogue ? EPI.EpilogueVF : VF;
unsigned UFactor = ForEpilogue ? EPI.EpilogueUF : UF;
- Value *Count = getOrCreateTripCount(LoopVectorPreHeader);
+ Value *Count = getTripCount();
// Reuse existing vector loop preheader for TC checks.
// Note that new preheader block is generated for vector loop.
BasicBlock *const TCCheckBlock = LoopVectorPreHeader;
@@ -7803,8 +7893,10 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
// Generate code to check if the loop's trip count is less than VF * UF of the
// main vector loop.
- auto P = Cost->requiresScalarEpilogue(ForEpilogue ? EPI.EpilogueVF : VF) ?
- ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT;
+ auto P = Cost->requiresScalarEpilogue(ForEpilogue ? EPI.EpilogueVF.isVector()
+ : VF.isVector())
+ ? ICmpInst::ICMP_ULE
+ : ICmpInst::ICMP_ULT;
Value *CheckMinIters = Builder.CreateICmp(
P, Count, createStepForVF(Builder, Count->getType(), VFactor, UFactor),
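A minimal standalone sketch of this guard (hypothetical names): when a scalar epilogue must run, at least one iteration has to be left over, so the comparison is non-strict.

#include <cstdint>

// Branch to the scalar loop if there are not enough iterations left for one
// full vector step of VF * UF.
bool skipVectorLoop(uint64_t tripCount, uint64_t vf, uint64_t uf,
                    bool requiresScalarEpilogue) {
  uint64_t step = vf * uf;
  return requiresScalarEpilogue ? tripCount <= step : tripCount < step;
}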
@@ -7824,7 +7916,7 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
// Update dominator for Bypass & LoopExit.
DT->changeImmediateDominator(Bypass, TCCheckBlock);
- if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF))
+ if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF.isVector()))
// For loops with multiple exits, there's no edge from the middle block
// to exit blocks (as the epilogue must run) and thus no need to update
// the immediate dominator of the exit blocks.
@@ -7852,7 +7944,8 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
/// This function is partially responsible for generating the control flow
/// depicted in https://llvm.org/docs/Vectorizers.html#epilogue-vectorization.
std::pair<BasicBlock *, Value *>
-EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() {
+EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton(
+ const SCEV2ValueTy &ExpandedSCEVs) {
createVectorLoopSkeleton("vec.epilog.");
// Now, compare the remaining count and if there aren't enough iterations to
@@ -7891,7 +7984,7 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() {
DT->changeImmediateDominator(LoopScalarPreHeader,
EPI.EpilogueIterationCountCheck);
- if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF))
+ if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF.isVector()))
// If there is an epilogue which must run, there's no edge from the
// middle block to exit blocks and thus no need to update the immediate
// dominator of the exit blocks.
@@ -7950,7 +8043,8 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() {
// check, then the resume value for the induction variable comes from
// the trip count of the main vector loop, hence passing the AdditionalBypass
// argument.
- createInductionResumeValues({VecEpilogueIterationCountCheck,
+ createInductionResumeValues(ExpandedSCEVs,
+ {VecEpilogueIterationCountCheck,
EPI.VectorTripCount} /* AdditionalBypass */);
return {completeLoopSkeleton(), EPResumeVal};
@@ -7972,8 +8066,9 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
// Generate code to check if the loop's trip count is less than VF * UF of the
// vector epilogue loop.
- auto P = Cost->requiresScalarEpilogue(EPI.EpilogueVF) ?
- ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT;
+ auto P = Cost->requiresScalarEpilogue(EPI.EpilogueVF.isVector())
+ ? ICmpInst::ICMP_ULE
+ : ICmpInst::ICMP_ULT;
Value *CheckMinIters =
Builder.CreateICmp(P, Count,
@@ -8008,8 +8103,7 @@ bool LoopVectorizationPlanner::getDecisionAndClampRange(
assert(!Range.isEmpty() && "Trying to test an empty VF range.");
bool PredicateAtRangeStart = Predicate(Range.Start);
- for (ElementCount TmpVF = Range.Start * 2;
- ElementCount::isKnownLT(TmpVF, Range.End); TmpVF *= 2)
+ for (ElementCount TmpVF : VFRange(Range.Start * 2, Range.End))
if (Predicate(TmpVF) != PredicateAtRangeStart) {
Range.End = TmpVF;
break;
@@ -8025,16 +8119,16 @@ bool LoopVectorizationPlanner::getDecisionAndClampRange(
/// buildVPlan().
void LoopVectorizationPlanner::buildVPlans(ElementCount MinVF,
ElementCount MaxVF) {
- auto MaxVFPlusOne = MaxVF.getWithIncrement(1);
- for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFPlusOne);) {
- VFRange SubRange = {VF, MaxVFPlusOne};
+ auto MaxVFTimes2 = MaxVF * 2;
+ for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) {
+ VFRange SubRange = {VF, MaxVFTimes2};
VPlans.push_back(buildVPlan(SubRange));
VF = SubRange.End;
}
}
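The candidate VFs covered by these subranges are the powers of two from MinVF up to and including MaxVF; a trivial sketch of the enumeration (the half-open bound MaxVF * 2 is what keeps MaxVF itself in range):

#include <cstdio>

void enumerateVFs(unsigned minVF, unsigned maxVF) {
  // Walk powers of two in [minVF, maxVF * 2), i.e. [minVF, maxVF] inclusive.
  for (unsigned vf = minVF; vf < maxVF * 2; vf *= 2)
    std::printf("candidate VF = %u\n", vf);
}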
VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst,
- VPlanPtr &Plan) {
+ VPlan &Plan) {
assert(is_contained(predecessors(Dst), Src) && "Invalid edge");
// Look for cached value.
@@ -8058,7 +8152,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst,
if (OrigLoop->isLoopExiting(Src))
return EdgeMaskCache[Edge] = SrcMask;
- VPValue *EdgeMask = Plan->getOrAddVPValue(BI->getCondition());
+ VPValue *EdgeMask = Plan.getVPValueOrAddLiveIn(BI->getCondition());
assert(EdgeMask && "No Edge Mask found for condition");
if (BI->getSuccessor(0) != Dst)
@@ -8069,7 +8163,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst,
// 'select i1 SrcMask, i1 EdgeMask, i1 false'.
// The select version does not introduce new UB if SrcMask is false and
// EdgeMask is poison. Using 'and' here introduces undefined behavior.
- VPValue *False = Plan->getOrAddVPValue(
+ VPValue *False = Plan.getVPValueOrAddLiveIn(
ConstantInt::getFalse(BI->getCondition()->getType()));
EdgeMask =
Builder.createSelect(SrcMask, EdgeMask, False, BI->getDebugLoc());
@@ -8078,7 +8172,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst,
return EdgeMaskCache[Edge] = EdgeMask;
}
-VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) {
+VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) {
assert(OrigLoop->contains(BB) && "Block is not a part of a loop");
// Look for cached value.
@@ -8098,29 +8192,28 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) {
// If we're using the active lane mask for control flow, then we get the
// mask from the active lane mask PHI that is cached in the VPlan.
- PredicationStyle EmitGetActiveLaneMask = CM.TTI.emitGetActiveLaneMask();
- if (EmitGetActiveLaneMask == PredicationStyle::DataAndControlFlow)
- return BlockMaskCache[BB] = Plan->getActiveLaneMaskPhi();
+ TailFoldingStyle TFStyle = CM.getTailFoldingStyle();
+ if (useActiveLaneMaskForControlFlow(TFStyle))
+ return BlockMaskCache[BB] = Plan.getActiveLaneMaskPhi();
// Introduce the early-exit compare IV <= BTC to form header block mask.
// This is used instead of IV < TC because TC may wrap, unlike BTC. Start by
// constructing the desired canonical IV in the header block as its first
// non-phi instructions.
- VPBasicBlock *HeaderVPBB =
- Plan->getVectorLoopRegion()->getEntryBasicBlock();
+ VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi();
- auto *IV = new VPWidenCanonicalIVRecipe(Plan->getCanonicalIV());
+ auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
HeaderVPBB->insert(IV, HeaderVPBB->getFirstNonPhi());
VPBuilder::InsertPointGuard Guard(Builder);
Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint);
- if (EmitGetActiveLaneMask != PredicationStyle::None) {
- VPValue *TC = Plan->getOrCreateTripCount();
+ if (useActiveLaneMask(TFStyle)) {
+ VPValue *TC = Plan.getTripCount();
BlockMask = Builder.createNaryOp(VPInstruction::ActiveLaneMask, {IV, TC},
nullptr, "active.lane.mask");
} else {
- VPValue *BTC = Plan->getOrCreateBackedgeTakenCount();
+ VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC});
}
return BlockMaskCache[BB] = BlockMask;
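A scalar emulation of that header mask (hypothetical helper, not LLVM API): comparing each lane's induction value against the backedge-taken count stays correct even when the trip count itself wraps to zero.

#include <cstdint>

// A lane is active iff its induction value does not run past the last
// iteration. BTC = TC - 1 never wraps, unlike TC (e.g. a loop of 2^32
// iterations with a 32-bit counter has TC == 0 but BTC == UINT32_MAX).
bool laneActive(uint32_t laneIV, uint32_t backedgeTakenCount) {
  return laneIV <= backedgeTakenCount;
}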
@@ -8168,7 +8261,7 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I,
VPValue *Mask = nullptr;
if (Legal->isMaskRequired(I))
- Mask = createBlockInMask(I->getParent(), Plan);
+ Mask = createBlockInMask(I->getParent(), *Plan);
// Determine if the pointer operand of the access is either consecutive or
// reverse consecutive.
@@ -8189,22 +8282,11 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I,
/// Creates a VPWidenIntOrFpInductionRecipe for \p Phi. If needed, it will also
/// insert a recipe to expand the step for the induction recipe.
-static VPWidenIntOrFpInductionRecipe *createWidenInductionRecipes(
- PHINode *Phi, Instruction *PhiOrTrunc, VPValue *Start,
- const InductionDescriptor &IndDesc, LoopVectorizationCostModel &CM,
- VPlan &Plan, ScalarEvolution &SE, Loop &OrigLoop, VFRange &Range) {
- // Returns true if an instruction \p I should be scalarized instead of
- // vectorized for the chosen vectorization factor.
- auto ShouldScalarizeInstruction = [&CM](Instruction *I, ElementCount VF) {
- return CM.isScalarAfterVectorization(I, VF) ||
- CM.isProfitableToScalarize(I, VF);
- };
-
- bool NeedsScalarIVOnly = LoopVectorizationPlanner::getDecisionAndClampRange(
- [&](ElementCount VF) {
- return ShouldScalarizeInstruction(PhiOrTrunc, VF);
- },
- Range);
+static VPWidenIntOrFpInductionRecipe *
+createWidenInductionRecipes(PHINode *Phi, Instruction *PhiOrTrunc,
+ VPValue *Start, const InductionDescriptor &IndDesc,
+ VPlan &Plan, ScalarEvolution &SE, Loop &OrigLoop,
+ VFRange &Range) {
assert(IndDesc.getStartValue() ==
Phi->getIncomingValueForBlock(OrigLoop.getLoopPreheader()));
assert(SE.isLoopInvariant(IndDesc.getStep(), &OrigLoop) &&
@@ -8213,12 +8295,10 @@ static VPWidenIntOrFpInductionRecipe *createWidenInductionRecipes(
VPValue *Step =
vputils::getOrCreateVPValueForSCEVExpr(Plan, IndDesc.getStep(), SE);
if (auto *TruncI = dyn_cast<TruncInst>(PhiOrTrunc)) {
- return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc, TruncI,
- !NeedsScalarIVOnly);
+ return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc, TruncI);
}
assert(isa<PHINode>(PhiOrTrunc) && "must be a phi node here");
- return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc,
- !NeedsScalarIVOnly);
+ return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc);
}
VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI(
@@ -8227,14 +8307,13 @@ VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI(
// Check if this is an integer or fp induction. If so, build the recipe that
// produces its scalar and vector values.
if (auto *II = Legal->getIntOrFpInductionDescriptor(Phi))
- return createWidenInductionRecipes(Phi, Phi, Operands[0], *II, CM, Plan,
+ return createWidenInductionRecipes(Phi, Phi, Operands[0], *II, Plan,
*PSE.getSE(), *OrigLoop, Range);
// Check if this is pointer induction. If so, build the recipe for it.
if (auto *II = Legal->getPointerInductionDescriptor(Phi)) {
VPValue *Step = vputils::getOrCreateVPValueForSCEVExpr(Plan, II->getStep(),
*PSE.getSE());
- assert(isa<SCEVConstant>(II->getStep()));
return new VPWidenPointerInductionRecipe(
Phi, Operands[0], Step, *II,
LoopVectorizationPlanner::getDecisionAndClampRange(
@@ -8267,9 +8346,9 @@ VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionTruncate(
auto *Phi = cast<PHINode>(I->getOperand(0));
const InductionDescriptor &II = *Legal->getIntOrFpInductionDescriptor(Phi);
- VPValue *Start = Plan.getOrAddVPValue(II.getStartValue());
- return createWidenInductionRecipes(Phi, I, Start, II, CM, Plan,
- *PSE.getSE(), *OrigLoop, Range);
+ VPValue *Start = Plan.getVPValueOrAddLiveIn(II.getStartValue());
+ return createWidenInductionRecipes(Phi, I, Start, II, Plan, *PSE.getSE(),
+ *OrigLoop, Range);
}
return nullptr;
}
@@ -8309,7 +8388,7 @@ VPRecipeOrVPValueTy VPRecipeBuilder::tryToBlend(PHINode *Phi,
for (unsigned In = 0; In < NumIncoming; In++) {
VPValue *EdgeMask =
- createEdgeMask(Phi->getIncomingBlock(In), Phi->getParent(), Plan);
+ createEdgeMask(Phi->getIncomingBlock(In), Phi->getParent(), *Plan);
assert((EdgeMask || NumIncoming == 1) &&
"Multiple predecessors with one having a full mask");
OperandsWithMask.push_back(Operands[In]);
@@ -8321,8 +8400,8 @@ VPRecipeOrVPValueTy VPRecipeBuilder::tryToBlend(PHINode *Phi,
VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
ArrayRef<VPValue *> Operands,
- VFRange &Range) const {
-
+ VFRange &Range,
+ VPlanPtr &Plan) {
bool IsPredicated = LoopVectorizationPlanner::getDecisionAndClampRange(
[this, CI](ElementCount VF) {
return CM.isScalarWithPredication(CI, VF);
@@ -8339,17 +8418,17 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
ID == Intrinsic::experimental_noalias_scope_decl))
return nullptr;
- ArrayRef<VPValue *> Ops = Operands.take_front(CI->arg_size());
+ SmallVector<VPValue *, 4> Ops(Operands.take_front(CI->arg_size()));
// Is it beneficial to perform an intrinsic call compared to a lib call?
bool ShouldUseVectorIntrinsic =
ID && LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) -> bool {
- bool NeedToScalarize = false;
+ Function *Variant;
// Is it beneficial to perform an intrinsic call compared to a lib
// call?
InstructionCost CallCost =
- CM.getVectorCallCost(CI, VF, NeedToScalarize);
+ CM.getVectorCallCost(CI, VF, &Variant);
InstructionCost IntrinsicCost =
CM.getVectorIntrinsicCost(CI, VF);
return IntrinsicCost <= CallCost;
@@ -8358,6 +8437,9 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
if (ShouldUseVectorIntrinsic)
return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), ID);
+ Function *Variant = nullptr;
+ ElementCount VariantVF;
+ bool NeedsMask = false;
// Is it better to call a vectorized version of the function than to
// scalarize the call?
auto ShouldUseVectorCall = LoopVectorizationPlanner::getDecisionAndClampRange(
@@ -8365,14 +8447,57 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
// The following case may be scalarized depending on the VF.
// The flag shows whether we can use a usual Call for a vectorized
// version of the instruction.
- bool NeedToScalarize = false;
- CM.getVectorCallCost(CI, VF, NeedToScalarize);
- return !NeedToScalarize;
+
+ // If we've found a variant at a previous VF, then stop looking. A
+ // vectorized variant of a function expects input in a certain shape
+ // -- basically the number of input registers, the number of lanes
+ // per register, and whether there's a mask required.
+ // We store a pointer to the variant in the VPWidenCallRecipe, so
+ // once we have an appropriate variant it's only valid for that VF.
+ // This will force a different VPlan to be generated for each VF that
+ // finds a valid variant.
+ if (Variant)
+ return false;
+ CM.getVectorCallCost(CI, VF, &Variant, &NeedsMask);
+ // If we found a valid vector variant at this VF, then store the VF
+ // in case we need to generate a mask.
+ if (Variant)
+ VariantVF = VF;
+ return Variant != nullptr;
},
Range);
- if (ShouldUseVectorCall)
+ if (ShouldUseVectorCall) {
+ if (NeedsMask) {
+ // We have 2 cases that would require a mask:
+ // 1) The block needs to be predicated, either due to a conditional
+ // in the scalar loop or use of an active lane mask with
+ // tail-folding, and we use the appropriate mask for the block.
+ // 2) No mask is required for the block, but the only available
+ // vector variant at this VF requires a mask, so we synthesize an
+ // all-true mask.
+ VPValue *Mask = nullptr;
+ if (Legal->isMaskRequired(CI))
+ Mask = createBlockInMask(CI->getParent(), *Plan);
+ else
+ Mask = Plan->getVPValueOrAddLiveIn(ConstantInt::getTrue(
+ IntegerType::getInt1Ty(Variant->getFunctionType()->getContext())));
+
+ VFShape Shape = VFShape::get(*CI, VariantVF, /*HasGlobalPred=*/true);
+ unsigned MaskPos = 0;
+
+ for (const VFInfo &Info : VFDatabase::getMappings(*CI))
+ if (Info.Shape == Shape) {
+ assert(Info.isMasked() && "Vector function info shape mismatch");
+ MaskPos = Info.getParamIndexForOptionalMask().value();
+ break;
+ }
+
+ Ops.insert(Ops.begin() + MaskPos, Mask);
+ }
+
return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()),
- Intrinsic::not_intrinsic);
+ Intrinsic::not_intrinsic, Variant);
+ }
return nullptr;
}
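A standalone sketch of the mask handling above, using a made-up masked variant: when the only vector mapping at the chosen VF is masked but the block itself needs no predication, an all-true mask is synthesized and passed at the variant's mask position.

#include <array>
#include <cmath>
#include <cstddef>

constexpr std::size_t VF = 4;

// Hypothetical masked vector variant: inactive lanes pass through unchanged.
std::array<float, VF> vsqrt_masked(std::array<float, VF> x,
                                   const std::array<bool, VF> &mask) {
  for (std::size_t l = 0; l < VF; ++l)
    if (mask[l])
      x[l] = std::sqrt(x[l]);
  return x;
}

void widenCall(float *a, std::size_t n) {
  std::array<bool, VF> allTrue;
  allTrue.fill(true); // unpredicated block, so synthesize an all-true mask
  for (std::size_t i = 0; i + VF <= n; i += VF) {
    std::array<float, VF> v;
    for (std::size_t l = 0; l < VF; ++l) v[l] = a[i + l];
    v = vsqrt_masked(v, allTrue);
    for (std::size_t l = 0; l < VF; ++l) a[i + l] = v[l];
  }
  // Remainder iterations would be handled by the scalar/epilogue loop.
}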
@@ -8405,9 +8530,9 @@ VPRecipeBase *VPRecipeBuilder::tryToWiden(Instruction *I,
// div/rem operation itself. Otherwise fall through to general handling below.
if (CM.isPredicatedInst(I)) {
SmallVector<VPValue *> Ops(Operands.begin(), Operands.end());
- VPValue *Mask = createBlockInMask(I->getParent(), Plan);
- VPValue *One =
- Plan->getOrAddExternalDef(ConstantInt::get(I->getType(), 1u, false));
+ VPValue *Mask = createBlockInMask(I->getParent(), *Plan);
+ VPValue *One = Plan->getVPValueOrAddLiveIn(
+ ConstantInt::get(I->getType(), 1u, false));
auto *SafeRHS =
new VPInstruction(Instruction::Select, {Mask, Ops[1], One},
I->getDebugLoc());
@@ -8415,38 +8540,26 @@ VPRecipeBase *VPRecipeBuilder::tryToWiden(Instruction *I,
Ops[1] = SafeRHS;
return new VPWidenRecipe(*I, make_range(Ops.begin(), Ops.end()));
}
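A lane-wise sketch of the safe-divisor trick (hypothetical names): masked-off lanes divide by 1, so the widened division can execute unconditionally without trapping.

#include <cstdint>

uint32_t predicatedDivLane(uint32_t lhs, uint32_t rhs, bool laneActive,
                           uint32_t passthrough) {
  uint32_t safeRHS = laneActive ? rhs : 1u; // select(mask, rhs, 1)
  uint32_t quotient = lhs / safeRHS;        // never divides by zero
  return laneActive ? quotient : passthrough;
}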
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Instruction::Add:
case Instruction::And:
case Instruction::AShr:
- case Instruction::BitCast:
case Instruction::FAdd:
case Instruction::FCmp:
case Instruction::FDiv:
case Instruction::FMul:
case Instruction::FNeg:
- case Instruction::FPExt:
- case Instruction::FPToSI:
- case Instruction::FPToUI:
- case Instruction::FPTrunc:
case Instruction::FRem:
case Instruction::FSub:
case Instruction::ICmp:
- case Instruction::IntToPtr:
case Instruction::LShr:
case Instruction::Mul:
case Instruction::Or:
- case Instruction::PtrToInt:
case Instruction::Select:
- case Instruction::SExt:
case Instruction::Shl:
- case Instruction::SIToFP:
case Instruction::Sub:
- case Instruction::Trunc:
- case Instruction::UIToFP:
case Instruction::Xor:
- case Instruction::ZExt:
case Instruction::Freeze:
return new VPWidenRecipe(*I, make_range(Operands.begin(), Operands.end()));
};
@@ -8462,9 +8575,9 @@ void VPRecipeBuilder::fixHeaderPhis() {
}
}
-VPBasicBlock *VPRecipeBuilder::handleReplication(
- Instruction *I, VFRange &Range, VPBasicBlock *VPBB,
- VPlanPtr &Plan) {
+VPRecipeOrVPValueTy VPRecipeBuilder::handleReplication(Instruction *I,
+ VFRange &Range,
+ VPlan &Plan) {
bool IsUniform = LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) { return CM.isUniformAfterVectorization(I, VF); },
Range);
@@ -8501,83 +8614,22 @@ VPBasicBlock *VPRecipeBuilder::handleReplication(
break;
}
}
-
- auto *Recipe = new VPReplicateRecipe(I, Plan->mapToVPValues(I->operands()),
- IsUniform, IsPredicated);
-
- // Find if I uses a predicated instruction. If so, it will use its scalar
- // value. Avoid hoisting the insert-element which packs the scalar value into
- // a vector value, as that happens iff all users use the vector value.
- for (VPValue *Op : Recipe->operands()) {
- auto *PredR =
- dyn_cast_or_null<VPPredInstPHIRecipe>(Op->getDefiningRecipe());
- if (!PredR)
- continue;
- auto *RepR = cast<VPReplicateRecipe>(
- PredR->getOperand(0)->getDefiningRecipe());
- assert(RepR->isPredicated() &&
- "expected Replicate recipe to be predicated");
- RepR->setAlsoPack(false);
- }
-
- // Finalize the recipe for Instr, first if it is not predicated.
+ VPValue *BlockInMask = nullptr;
if (!IsPredicated) {
+ // Finalize the recipe for Instr, first if it is not predicated.
LLVM_DEBUG(dbgs() << "LV: Scalarizing:" << *I << "\n");
- setRecipe(I, Recipe);
- Plan->addVPValue(I, Recipe);
- VPBB->appendRecipe(Recipe);
- return VPBB;
- }
- LLVM_DEBUG(dbgs() << "LV: Scalarizing and predicating:" << *I << "\n");
-
- VPBlockBase *SingleSucc = VPBB->getSingleSuccessor();
- assert(SingleSucc && "VPBB must have a single successor when handling "
- "predicated replication.");
- VPBlockUtils::disconnectBlocks(VPBB, SingleSucc);
- // Record predicated instructions for above packing optimizations.
- VPBlockBase *Region = createReplicateRegion(Recipe, Plan);
- VPBlockUtils::insertBlockAfter(Region, VPBB);
- auto *RegSucc = new VPBasicBlock();
- VPBlockUtils::insertBlockAfter(RegSucc, Region);
- VPBlockUtils::connectBlocks(RegSucc, SingleSucc);
- return RegSucc;
-}
-
-VPRegionBlock *
-VPRecipeBuilder::createReplicateRegion(VPReplicateRecipe *PredRecipe,
- VPlanPtr &Plan) {
- Instruction *Instr = PredRecipe->getUnderlyingInstr();
- // Instructions marked for predication are replicated and placed under an
- // if-then construct to prevent side-effects.
- // Generate recipes to compute the block mask for this region.
- VPValue *BlockInMask = createBlockInMask(Instr->getParent(), Plan);
-
- // Build the triangular if-then region.
- std::string RegionName = (Twine("pred.") + Instr->getOpcodeName()).str();
- assert(Instr->getParent() && "Predicated instruction not in any basic block");
- auto *BOMRecipe = new VPBranchOnMaskRecipe(BlockInMask);
- auto *Entry = new VPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe);
- auto *PHIRecipe = Instr->getType()->isVoidTy()
- ? nullptr
- : new VPPredInstPHIRecipe(PredRecipe);
- if (PHIRecipe) {
- setRecipe(Instr, PHIRecipe);
- Plan->addVPValue(Instr, PHIRecipe);
} else {
- setRecipe(Instr, PredRecipe);
- Plan->addVPValue(Instr, PredRecipe);
+ LLVM_DEBUG(dbgs() << "LV: Scalarizing and predicating:" << *I << "\n");
+ // Instructions marked for predication are replicated and a mask operand is
+ // added initially. Masked replicate recipes will later be placed under an
+ // if-then construct to prevent side-effects. Generate recipes to compute
+ // the block mask for this region.
+ BlockInMask = createBlockInMask(I->getParent(), Plan);
}
- auto *Exiting = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe);
- auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", PredRecipe);
- VPRegionBlock *Region = new VPRegionBlock(Entry, Exiting, RegionName, true);
-
- // Note: first set Entry as region entry and then connect successors starting
- // from it in order, to propagate the "parent" of each VPBasicBlock.
- VPBlockUtils::insertTwoBlocksAfter(Pred, Exiting, Entry);
- VPBlockUtils::connectBlocks(Pred, Exiting);
-
- return Region;
+ auto *Recipe = new VPReplicateRecipe(I, Plan.mapToVPValues(I->operands()),
+ IsUniform, BlockInMask);
+ return toVPRecipeResult(Recipe);
}
VPRecipeOrVPValueTy
@@ -8643,7 +8695,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
return nullptr;
if (auto *CI = dyn_cast<CallInst>(Instr))
- return toVPRecipeResult(tryToWidenCall(CI, Operands, Range));
+ return toVPRecipeResult(tryToWidenCall(CI, Operands, Range, Plan));
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
return toVPRecipeResult(tryToWidenMemory(Instr, Operands, Range, Plan));
@@ -8653,13 +8705,16 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
if (auto GEP = dyn_cast<GetElementPtrInst>(Instr))
return toVPRecipeResult(new VPWidenGEPRecipe(
- GEP, make_range(Operands.begin(), Operands.end()), OrigLoop));
+ GEP, make_range(Operands.begin(), Operands.end())));
if (auto *SI = dyn_cast<SelectInst>(Instr)) {
- bool InvariantCond =
- PSE.getSE()->isLoopInvariant(PSE.getSCEV(SI->getOperand(0)), OrigLoop);
return toVPRecipeResult(new VPWidenSelectRecipe(
- *SI, make_range(Operands.begin(), Operands.end()), InvariantCond));
+ *SI, make_range(Operands.begin(), Operands.end())));
+ }
+
+ if (auto *CI = dyn_cast<CastInst>(Instr)) {
+ return toVPRecipeResult(
+ new VPWidenCastRecipe(CI->getOpcode(), Operands[0], CI->getType(), CI));
}
return toVPRecipeResult(tryToWiden(Instr, Operands, VPBB, Plan));
@@ -8677,34 +8732,11 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
auto &ConditionalAssumes = Legal->getConditionalAssumes();
DeadInstructions.insert(ConditionalAssumes.begin(), ConditionalAssumes.end());
- MapVector<Instruction *, Instruction *> &SinkAfter = Legal->getSinkAfter();
- // Dead instructions do not need sinking. Remove them from SinkAfter.
- for (Instruction *I : DeadInstructions)
- SinkAfter.erase(I);
-
- // Cannot sink instructions after dead instructions (there won't be any
- // recipes for them). Instead, find the first non-dead previous instruction.
- for (auto &P : Legal->getSinkAfter()) {
- Instruction *SinkTarget = P.second;
- Instruction *FirstInst = &*SinkTarget->getParent()->begin();
- (void)FirstInst;
- while (DeadInstructions.contains(SinkTarget)) {
- assert(
- SinkTarget != FirstInst &&
- "Must find a live instruction (at least the one feeding the "
- "fixed-order recurrence PHI) before reaching beginning of the block");
- SinkTarget = SinkTarget->getPrevNode();
- assert(SinkTarget != P.first &&
- "sink source equals target, no sinking required");
- }
- P.second = SinkTarget;
- }
-
- auto MaxVFPlusOne = MaxVF.getWithIncrement(1);
- for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFPlusOne);) {
- VFRange SubRange = {VF, MaxVFPlusOne};
- VPlans.push_back(
- buildVPlanWithVPRecipes(SubRange, DeadInstructions, SinkAfter));
+ auto MaxVFTimes2 = MaxVF * 2;
+ for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) {
+ VFRange SubRange = {VF, MaxVFTimes2};
+ if (auto Plan = tryToBuildVPlanWithVPRecipes(SubRange, DeadInstructions))
+ VPlans.push_back(std::move(*Plan));
VF = SubRange.End;
}
}
@@ -8712,10 +8744,9 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
// Add the necessary canonical IV and branch recipes required to control the
// loop.
static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
- bool HasNUW,
- bool UseLaneMaskForLoopControlFlow) {
+ TailFoldingStyle Style) {
Value *StartIdx = ConstantInt::get(IdxTy, 0);
- auto *StartV = Plan.getOrAddVPValue(StartIdx);
+ auto *StartV = Plan.getVPValueOrAddLiveIn(StartIdx);
// Add a VPCanonicalIVPHIRecipe starting at 0 to the header.
auto *CanonicalIVPHI = new VPCanonicalIVPHIRecipe(StartV, DL);
@@ -8725,6 +8756,7 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
// Add a CanonicalIVIncrement{NUW} VPInstruction to increment the scalar
// IV by VF * UF.
+ bool HasNUW = Style == TailFoldingStyle::None;
auto *CanonicalIVIncrement =
new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementNUW
: VPInstruction::CanonicalIVIncrement,
@@ -8732,11 +8764,10 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
CanonicalIVPHI->addOperand(CanonicalIVIncrement);
VPBasicBlock *EB = TopRegion->getExitingBasicBlock();
- EB->appendRecipe(CanonicalIVIncrement);
-
- if (UseLaneMaskForLoopControlFlow) {
+ if (useActiveLaneMaskForControlFlow(Style)) {
// Create the active lane mask instruction in the vplan preheader.
- VPBasicBlock *Preheader = Plan.getEntry()->getEntryBasicBlock();
+ VPBasicBlock *VecPreheader =
+ cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSinglePredecessor());
// We can't use StartV directly in the ActiveLaneMask VPInstruction, since
// we have to take unrolling into account. Each part needs to start at
@@ -8745,14 +8776,34 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementForPartNUW
: VPInstruction::CanonicalIVIncrementForPart,
{StartV}, DL, "index.part.next");
- Preheader->appendRecipe(CanonicalIVIncrementParts);
+ VecPreheader->appendRecipe(CanonicalIVIncrementParts);
// Create the ActiveLaneMask instruction using the correct start values.
- VPValue *TC = Plan.getOrCreateTripCount();
+ VPValue *TC = Plan.getTripCount();
+
+ VPValue *TripCount, *IncrementValue;
+ if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) {
+ // When avoiding a runtime check, the active.lane.mask inside the loop
+ // uses a modified trip count and the induction variable increment is
+ // done after the active.lane.mask intrinsic is called.
+ auto *TCMinusVF =
+ new VPInstruction(VPInstruction::CalculateTripCountMinusVF, {TC}, DL);
+ VecPreheader->appendRecipe(TCMinusVF);
+ IncrementValue = CanonicalIVPHI;
+ TripCount = TCMinusVF;
+ } else {
+ // When the loop is guarded by a runtime overflow check for the loop
+ // induction variable increment by VF, we can increment the value before
+ // the get.active.lane.mask and use the unmodified trip count.
+ EB->appendRecipe(CanonicalIVIncrement);
+ IncrementValue = CanonicalIVIncrement;
+ TripCount = TC;
+ }
+
auto *EntryALM = new VPInstruction(VPInstruction::ActiveLaneMask,
{CanonicalIVIncrementParts, TC}, DL,
"active.lane.mask.entry");
- Preheader->appendRecipe(EntryALM);
+ VecPreheader->appendRecipe(EntryALM);
// Now create the ActiveLaneMaskPhi recipe in the main loop using the
// preheader ActiveLaneMask instruction.
@@ -8763,15 +8814,21 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
CanonicalIVIncrementParts =
new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementForPartNUW
: VPInstruction::CanonicalIVIncrementForPart,
- {CanonicalIVIncrement}, DL);
+ {IncrementValue}, DL);
EB->appendRecipe(CanonicalIVIncrementParts);
auto *ALM = new VPInstruction(VPInstruction::ActiveLaneMask,
- {CanonicalIVIncrementParts, TC}, DL,
+ {CanonicalIVIncrementParts, TripCount}, DL,
"active.lane.mask.next");
EB->appendRecipe(ALM);
LaneMaskPhi->addOperand(ALM);
+ if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) {
+ // Do the increment of the canonical IV after the active.lane.mask, because
+ // that value is still based on %CanonicalIVPHI.
+ EB->appendRecipe(CanonicalIVIncrement);
+ }
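A scalar emulation of the DataAndControlFlowWithoutRuntimeCheck shape (a simplification with UF == 1 and hypothetical names): the mask for the next iteration is formed from the not-yet-incremented IV against a trip count with VF subtracted up front, and the IV increment is issued only afterwards.

#include <cstdint>

void processTailFolded(uint8_t *dst, const uint8_t *src, uint64_t tc,
                       uint64_t vf) {
  uint64_t tcMinusVF = tc > vf ? tc - vf : 0; // CalculateTripCountMinusVF
  uint64_t iv = 0;
  bool anyLaneActive = tc != 0;               // entry active.lane.mask
  while (anyLaneActive) {
    for (uint64_t l = 0; l < vf; ++l)
      if (iv + l < tc)                        // per-lane predicate
        dst[iv + l] = src[iv + l];
    anyLaneActive = iv < tcMinusVF;           // next mask uses the old IV
    iv += vf;                                 // increment after the mask
  }
}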
+
// We have to invert the mask here because a true condition means jumping
// to the exit block.
auto *NotMask = new VPInstruction(VPInstruction::Not, ALM, DL);
@@ -8781,6 +8838,8 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
new VPInstruction(VPInstruction::BranchOnCond, {NotMask}, DL);
EB->appendRecipe(BranchBack);
} else {
+ EB->appendRecipe(CanonicalIVIncrement);
+
// Add the BranchOnCount VPInstruction to the latch.
VPInstruction *BranchBack = new VPInstruction(
VPInstruction::BranchOnCount,
@@ -8804,14 +8863,13 @@ static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB,
for (PHINode &ExitPhi : ExitBB->phis()) {
Value *IncomingValue =
ExitPhi.getIncomingValueForBlock(ExitingBB);
- VPValue *V = Plan.getOrAddVPValue(IncomingValue, true);
+ VPValue *V = Plan.getVPValueOrAddLiveIn(IncomingValue);
Plan.addLiveOut(&ExitPhi, V);
}
}
-VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
- VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions,
- const MapVector<Instruction *, Instruction *> &SinkAfter) {
+std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
+ VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions) {
SmallPtrSet<const InterleaveGroup<Instruction> *, 1> InterleaveGroups;
@@ -8822,12 +8880,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
// process after constructing the initial VPlan.
// ---------------------------------------------------------------------------
- // Mark instructions we'll need to sink later and their targets as
- // ingredients whose recipe we'll need to record.
- for (const auto &Entry : SinkAfter) {
- RecipeBuilder.recordRecipeOf(Entry.first);
- RecipeBuilder.recordRecipeOf(Entry.second);
- }
for (const auto &Reduction : CM.getInLoopReductionChains()) {
PHINode *Phi = Reduction.first;
RecurKind Kind =
@@ -8852,9 +8904,15 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
// single VPInterleaveRecipe.
for (InterleaveGroup<Instruction> *IG : IAI.getInterleaveGroups()) {
auto applyIG = [IG, this](ElementCount VF) -> bool {
- return (VF.isVector() && // Query is illegal for VF == 1
- CM.getWideningDecision(IG->getInsertPos(), VF) ==
- LoopVectorizationCostModel::CM_Interleave);
+ bool Result = (VF.isVector() && // Query is illegal for VF == 1
+ CM.getWideningDecision(IG->getInsertPos(), VF) ==
+ LoopVectorizationCostModel::CM_Interleave);
+ // For scalable vectors, the only interleave factor currently supported
+ // is 2 since we require the (de)interleave2 intrinsics instead of
+ // shufflevectors.
+ assert((!Result || !VF.isScalable() || IG->getFactor() == 2) &&
+ "Unsupported interleave factor for scalable vectors");
+ return Result;
};
if (!getDecisionAndClampRange(applyIG, Range))
continue;
@@ -8869,26 +8927,34 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
// visit each basic block after having visited its predecessor basic blocks.
// ---------------------------------------------------------------------------
- // Create initial VPlan skeleton, starting with a block for the pre-header,
- // followed by a region for the vector loop, followed by the middle block. The
- // skeleton vector loop region contains a header and latch block.
- VPBasicBlock *Preheader = new VPBasicBlock("vector.ph");
- auto Plan = std::make_unique<VPlan>(Preheader);
-
+ // Create initial VPlan skeleton, having a basic block for the pre-header
+ // which contains SCEV expansions that need to happen before the CFG is
+ // modified; a basic block for the vector pre-header, followed by a region for
+ // the vector loop, followed by the middle basic block. The skeleton vector
+ // loop region contains a header and latch basic blocks.
+ VPlanPtr Plan = VPlan::createInitialVPlan(
+ createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop),
+ *PSE.getSE());
VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body");
VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch");
VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB);
auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop");
- VPBlockUtils::insertBlockAfter(TopRegion, Preheader);
+ VPBlockUtils::insertBlockAfter(TopRegion, Plan->getEntry());
VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);
+ // Don't use getDecisionAndClampRange here, because we don't know the UF,
+ // so it is better to be conservative here rather than to split the range
+ // up into different VPlans.
+ bool IVUpdateMayOverflow = false;
+ for (ElementCount VF : Range)
+ IVUpdateMayOverflow |= !isIndvarOverflowCheckKnownFalse(&CM, VF);
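A rough sketch of the question being asked per VF (hypothetical names; the real check also accounts for the maximum trip count computed by SCEV): if one full vector step could wrap the induction variable, the tail-folding style without the runtime overflow check has to be chosen.

#include <cstdint>
#include <limits>

bool ivUpdateMayOverflow(uint64_t maxTripCount, uint64_t vf, uint64_t uf) {
  uint64_t step = vf * uf;
  // Wraps if the trip count leaves no headroom for a final increment by step.
  return maxTripCount > std::numeric_limits<uint64_t>::max() - step;
}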
+
Instruction *DLInst =
getDebugLocFromInstOrOperands(Legal->getPrimaryInduction());
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(),
DLInst ? DLInst->getDebugLoc() : DebugLoc(),
- !CM.foldTailByMasking(),
- CM.useActiveLaneMaskForControlFlow());
+ CM.getTailFoldingStyle(IVUpdateMayOverflow));
// Scan the body of the loop in a topological order to visit each basic block
// after having visited its predecessor basic blocks.
@@ -8896,18 +8962,16 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
DFS.perform(LI);
VPBasicBlock *VPBB = HeaderVPBB;
- SmallVector<VPWidenIntOrFpInductionRecipe *> InductionsToMove;
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
// Relevant instructions from basic block BB will be grouped into VPRecipe
// ingredients and fill a new VPBasicBlock.
- unsigned VPBBsForBB = 0;
if (VPBB != HeaderVPBB)
VPBB->setName(BB->getName());
Builder.setInsertPoint(VPBB);
// Introduce each ingredient into VPlan.
// TODO: Model and preserve debug intrinsics in VPlan.
- for (Instruction &I : BB->instructionsWithoutDebug()) {
+ for (Instruction &I : BB->instructionsWithoutDebug(false)) {
Instruction *Instr = &I;
// First filter out irrelevant instructions, to ensure no recipes are
@@ -8918,7 +8982,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
SmallVector<VPValue *, 4> Operands;
auto *Phi = dyn_cast<PHINode>(Instr);
if (Phi && Phi->getParent() == OrigLoop->getHeader()) {
- Operands.push_back(Plan->getOrAddVPValue(
+ Operands.push_back(Plan->getVPValueOrAddLiveIn(
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader())));
} else {
auto OpRange = Plan->mapToVPValues(Instr->operands());
@@ -8932,50 +8996,36 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
Legal->isInvariantAddressOfReduction(SI->getPointerOperand()))
continue;
- if (auto RecipeOrValue = RecipeBuilder.tryToCreateWidenRecipe(
- Instr, Operands, Range, VPBB, Plan)) {
- // If Instr can be simplified to an existing VPValue, use it.
- if (RecipeOrValue.is<VPValue *>()) {
- auto *VPV = RecipeOrValue.get<VPValue *>();
- Plan->addVPValue(Instr, VPV);
- // If the re-used value is a recipe, register the recipe for the
- // instruction, in case the recipe for Instr needs to be recorded.
- if (VPRecipeBase *R = VPV->getDefiningRecipe())
- RecipeBuilder.setRecipe(Instr, R);
- continue;
- }
- // Otherwise, add the new recipe.
- VPRecipeBase *Recipe = RecipeOrValue.get<VPRecipeBase *>();
- for (auto *Def : Recipe->definedValues()) {
- auto *UV = Def->getUnderlyingValue();
- Plan->addVPValue(UV, Def);
- }
-
- if (isa<VPWidenIntOrFpInductionRecipe>(Recipe) &&
- HeaderVPBB->getFirstNonPhi() != VPBB->end()) {
- // Keep track of VPWidenIntOrFpInductionRecipes not in the phi section
- // of the header block. That can happen for truncates of induction
- // variables. Those recipes are moved to the phi section of the header
- // block after applying SinkAfter, which relies on the original
- // position of the trunc.
- assert(isa<TruncInst>(Instr));
- InductionsToMove.push_back(
- cast<VPWidenIntOrFpInductionRecipe>(Recipe));
- }
- RecipeBuilder.setRecipe(Instr, Recipe);
- VPBB->appendRecipe(Recipe);
+ auto RecipeOrValue = RecipeBuilder.tryToCreateWidenRecipe(
+ Instr, Operands, Range, VPBB, Plan);
+ if (!RecipeOrValue)
+ RecipeOrValue = RecipeBuilder.handleReplication(Instr, Range, *Plan);
+ // If Instr can be simplified to an existing VPValue, use it.
+ if (isa<VPValue *>(RecipeOrValue)) {
+ auto *VPV = cast<VPValue *>(RecipeOrValue);
+ Plan->addVPValue(Instr, VPV);
+ // If the re-used value is a recipe, register the recipe for the
+ // instruction, in case the recipe for Instr needs to be recorded.
+ if (VPRecipeBase *R = VPV->getDefiningRecipe())
+ RecipeBuilder.setRecipe(Instr, R);
continue;
}
-
- // Otherwise, if all widening options failed, Instruction is to be
- // replicated. This may create a successor for VPBB.
- VPBasicBlock *NextVPBB =
- RecipeBuilder.handleReplication(Instr, Range, VPBB, Plan);
- if (NextVPBB != VPBB) {
- VPBB = NextVPBB;
- VPBB->setName(BB->hasName() ? BB->getName() + "." + Twine(VPBBsForBB++)
- : "");
+ // Otherwise, add the new recipe.
+ VPRecipeBase *Recipe = cast<VPRecipeBase *>(RecipeOrValue);
+ for (auto *Def : Recipe->definedValues()) {
+ auto *UV = Def->getUnderlyingValue();
+ Plan->addVPValue(UV, Def);
}
+
+ RecipeBuilder.setRecipe(Instr, Recipe);
+ if (isa<VPWidenIntOrFpInductionRecipe>(Recipe) &&
+ HeaderVPBB->getFirstNonPhi() != VPBB->end()) {
+ // Move VPWidenIntOrFpInductionRecipes for optimized truncates to the
+ // phi section of HeaderVPBB.
+ assert(isa<TruncInst>(Instr));
+ Recipe->insertBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi());
+ } else
+ VPBB->appendRecipe(Recipe);
}
VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
@@ -8985,7 +9035,12 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
// After here, VPBB should not be used.
VPBB = nullptr;
- addUsersInExitBlock(HeaderVPBB, MiddleVPBB, OrigLoop, *Plan);
+ if (CM.requiresScalarEpilogue(Range)) {
+ // No edge from the middle block to the unique exit block has been inserted
+ // and there is nothing to fix from vector loop; phis should have incoming
+ // from scalar loop only.
+ } else
+ addUsersInExitBlock(HeaderVPBB, MiddleVPBB, OrigLoop, *Plan);
assert(isa<VPRegionBlock>(Plan->getVectorLoopRegion()) &&
!Plan->getVectorLoopRegion()->getEntryBasicBlock()->empty() &&
@@ -8998,116 +9053,10 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
// bring the VPlan to its final state.
// ---------------------------------------------------------------------------
- // Apply Sink-After legal constraints.
- auto GetReplicateRegion = [](VPRecipeBase *R) -> VPRegionBlock * {
- auto *Region = dyn_cast_or_null<VPRegionBlock>(R->getParent()->getParent());
- if (Region && Region->isReplicator()) {
- assert(Region->getNumSuccessors() == 1 &&
- Region->getNumPredecessors() == 1 && "Expected SESE region!");
- assert(R->getParent()->size() == 1 &&
- "A recipe in an original replicator region must be the only "
- "recipe in its block");
- return Region;
- }
- return nullptr;
- };
- for (const auto &Entry : SinkAfter) {
- VPRecipeBase *Sink = RecipeBuilder.getRecipe(Entry.first);
- VPRecipeBase *Target = RecipeBuilder.getRecipe(Entry.second);
-
- auto *TargetRegion = GetReplicateRegion(Target);
- auto *SinkRegion = GetReplicateRegion(Sink);
- if (!SinkRegion) {
- // If the sink source is not a replicate region, sink the recipe directly.
- if (TargetRegion) {
- // The target is in a replication region, make sure to move Sink to
- // the block after it, not into the replication region itself.
- VPBasicBlock *NextBlock =
- cast<VPBasicBlock>(TargetRegion->getSuccessors().front());
- Sink->moveBefore(*NextBlock, NextBlock->getFirstNonPhi());
- } else
- Sink->moveAfter(Target);
- continue;
- }
-
- // The sink source is in a replicate region. Unhook the region from the CFG.
- auto *SinkPred = SinkRegion->getSinglePredecessor();
- auto *SinkSucc = SinkRegion->getSingleSuccessor();
- VPBlockUtils::disconnectBlocks(SinkPred, SinkRegion);
- VPBlockUtils::disconnectBlocks(SinkRegion, SinkSucc);
- VPBlockUtils::connectBlocks(SinkPred, SinkSucc);
-
- if (TargetRegion) {
- // The target recipe is also in a replicate region, move the sink region
- // after the target region.
- auto *TargetSucc = TargetRegion->getSingleSuccessor();
- VPBlockUtils::disconnectBlocks(TargetRegion, TargetSucc);
- VPBlockUtils::connectBlocks(TargetRegion, SinkRegion);
- VPBlockUtils::connectBlocks(SinkRegion, TargetSucc);
- } else {
- // The sink source is in a replicate region, we need to move the whole
- // replicate region, which should only contain a single recipe in the
- // main block.
- auto *SplitBlock =
- Target->getParent()->splitAt(std::next(Target->getIterator()));
-
- auto *SplitPred = SplitBlock->getSinglePredecessor();
-
- VPBlockUtils::disconnectBlocks(SplitPred, SplitBlock);
- VPBlockUtils::connectBlocks(SplitPred, SinkRegion);
- VPBlockUtils::connectBlocks(SinkRegion, SplitBlock);
- }
- }
-
- VPlanTransforms::removeRedundantCanonicalIVs(*Plan);
- VPlanTransforms::removeRedundantInductionCasts(*Plan);
-
- // Now that sink-after is done, move induction recipes for optimized truncates
- // to the phi section of the header block.
- for (VPWidenIntOrFpInductionRecipe *Ind : InductionsToMove)
- Ind->moveBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi());
-
// Adjust the recipes for any inloop reductions.
adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExiting()), Plan,
RecipeBuilder, Range.Start);
- // Introduce a recipe to combine the incoming and previous values of a
- // fixed-order recurrence.
- for (VPRecipeBase &R :
- Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
- auto *RecurPhi = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R);
- if (!RecurPhi)
- continue;
-
- VPRecipeBase *PrevRecipe = &RecurPhi->getBackedgeRecipe();
- // Fixed-order recurrences do not contain cycles, so this loop is guaranteed
- // to terminate.
- while (auto *PrevPhi =
- dyn_cast<VPFirstOrderRecurrencePHIRecipe>(PrevRecipe))
- PrevRecipe = &PrevPhi->getBackedgeRecipe();
- VPBasicBlock *InsertBlock = PrevRecipe->getParent();
- auto *Region = GetReplicateRegion(PrevRecipe);
- if (Region)
- InsertBlock = dyn_cast<VPBasicBlock>(Region->getSingleSuccessor());
- if (!InsertBlock) {
- InsertBlock = new VPBasicBlock(Region->getName() + ".succ");
- VPBlockUtils::insertBlockAfter(InsertBlock, Region);
- }
- if (Region || PrevRecipe->isPhi())
- Builder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi());
- else
- Builder.setInsertPoint(InsertBlock, std::next(PrevRecipe->getIterator()));
-
- auto *RecurSplice = cast<VPInstruction>(
- Builder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice,
- {RecurPhi, RecurPhi->getBackedgeValue()}));
-
- RecurPhi->replaceAllUsesWith(RecurSplice);
- // Set the first operand of RecurSplice to RecurPhi again, after replacing
- // all users.
- RecurSplice->setOperand(0, RecurPhi);
- }
-
// Interleave memory: for each Interleave Group we marked earlier as relevant
// for this VPlan, replace the Recipes widening its memory instructions with a
// single VPInterleaveRecipe at its insertion point.
@@ -9122,48 +9071,66 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
StoredValues.push_back(StoreR->getStoredValue());
}
+ bool NeedsMaskForGaps =
+ IG->requiresScalarEpilogue() && !CM.isScalarEpilogueAllowed();
auto *VPIG = new VPInterleaveRecipe(IG, Recipe->getAddr(), StoredValues,
- Recipe->getMask());
+ Recipe->getMask(), NeedsMaskForGaps);
VPIG->insertBefore(Recipe);
unsigned J = 0;
for (unsigned i = 0; i < IG->getFactor(); ++i)
if (Instruction *Member = IG->getMember(i)) {
+ VPRecipeBase *MemberR = RecipeBuilder.getRecipe(Member);
if (!Member->getType()->isVoidTy()) {
- VPValue *OriginalV = Plan->getVPValue(Member);
- Plan->removeVPValueFor(Member);
- Plan->addVPValue(Member, VPIG->getVPValue(J));
+ VPValue *OriginalV = MemberR->getVPSingleValue();
OriginalV->replaceAllUsesWith(VPIG->getVPValue(J));
J++;
}
- RecipeBuilder.getRecipe(Member)->eraseFromParent();
+ MemberR->eraseFromParent();
}
}
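Conceptually, folding the member recipes into one VPInterleaveRecipe turns several strided accesses into a single wide access plus a (de)interleave; a standalone sketch for an interleave factor of 2:

#include <cstddef>

// Instead of two strided loads per lane (a[2*(i+l)] and a[2*(i+l)+1]),
// one contiguous block of 2 * VF elements is loaded and split into lanes.
template <std::size_t VF>
void deinterleave2(const float *a, std::size_t i, float (&even)[VF],
                   float (&odd)[VF]) {
  const float *block = a + 2 * i; // single wide, contiguous access
  for (std::size_t l = 0; l < VF; ++l) {
    even[l] = block[2 * l];
    odd[l] = block[2 * l + 1];
  }
}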
- for (ElementCount VF = Range.Start; ElementCount::isKnownLT(VF, Range.End);
- VF *= 2)
+ for (ElementCount VF : Range)
Plan->addVF(VF);
Plan->setName("Initial VPlan");
+ // Replace VPValues for known constant strides guaranteed by predicated
+ // scalar evolution.
+ for (auto [_, Stride] : Legal->getLAI()->getSymbolicStrides()) {
+ auto *StrideV = cast<SCEVUnknown>(Stride)->getValue();
+ auto *ScevStride = dyn_cast<SCEVConstant>(PSE.getSCEV(StrideV));
+ // Only handle constant strides for now.
+ if (!ScevStride)
+ continue;
+ Constant *CI = ConstantInt::get(Stride->getType(), ScevStride->getAPInt());
+
+ auto *ConstVPV = Plan->getVPValueOrAddLiveIn(CI);
+ // The versioned value may not be used in the loop directly, so just add a
+ // new live-in in those cases.
+ Plan->getVPValueOrAddLiveIn(StrideV)->replaceAllUsesWith(ConstVPV);
+ }
+
// From this point onwards, VPlan-to-VPlan transformations may change the plan
// in ways that accessing values using original IR values is incorrect.
Plan->disableValue2VPValue();
+ // Sink users of fixed-order recurrence past the recipe defining the previous
+ // value and introduce FirstOrderRecurrenceSplice VPInstructions.
+ if (!VPlanTransforms::adjustFixedOrderRecurrences(*Plan, Builder))
+ return std::nullopt;
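The splice this transform introduces concatenates the last lane of the previous vector iteration with the first VF - 1 lanes of the current one; a scalar emulation (hypothetical names) for a recurrence like out[i] = in[i] + in[i - 1]:

#include <cstddef>

template <std::size_t VF>
void firstOrderRecurrenceSplice(const float (&cur)[VF], float &prevTail,
                                float (&spliced)[VF]) {
  spliced[0] = prevTail;               // lane 0 comes from the previous part
  for (std::size_t l = 1; l < VF; ++l) // remaining lanes shift down by one
    spliced[l] = cur[l - 1];
  prevTail = cur[VF - 1];              // carried into the next iteration
}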
+
+ VPlanTransforms::removeRedundantCanonicalIVs(*Plan);
+ VPlanTransforms::removeRedundantInductionCasts(*Plan);
+
VPlanTransforms::optimizeInductions(*Plan, *PSE.getSE());
VPlanTransforms::removeDeadRecipes(*Plan);
- bool ShouldSimplify = true;
- while (ShouldSimplify) {
- ShouldSimplify = VPlanTransforms::sinkScalarOperands(*Plan);
- ShouldSimplify |=
- VPlanTransforms::mergeReplicateRegionsIntoSuccessors(*Plan);
- ShouldSimplify |= VPlanTransforms::mergeBlocksIntoPredecessors(*Plan);
- }
+ VPlanTransforms::createAndOptimizeReplicateRegions(*Plan);
VPlanTransforms::removeRedundantExpandSCEVRecipes(*Plan);
VPlanTransforms::mergeBlocksIntoPredecessors(*Plan);
assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid");
- return Plan;
+ return std::make_optional(std::move(Plan));
}
VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
@@ -9175,21 +9142,21 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
assert(EnableVPlanNativePath && "VPlan-native path is not enabled.");
// Create new empty VPlan
- auto Plan = std::make_unique<VPlan>();
+ auto Plan = VPlan::createInitialVPlan(
+ createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop),
+ *PSE.getSE());
// Build hierarchical CFG
VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan);
HCFGBuilder.buildHierarchicalCFG();
- for (ElementCount VF = Range.Start; ElementCount::isKnownLT(VF, Range.End);
- VF *= 2)
+ for (ElementCount VF : Range)
Plan->addVF(VF);
- SmallPtrSet<Instruction *, 1> DeadInstructions;
VPlanTransforms::VPInstructionsToVPRecipes(
- OrigLoop, Plan,
+ Plan,
[this](PHINode *P) { return Legal->getIntOrFpInductionDescriptor(P); },
- DeadInstructions, *PSE.getSE(), *TLI);
+ *PSE.getSE(), *TLI);
// Remove the existing terminator of the exiting block of the top-most region.
// A BranchOnCount will be added instead when adding the canonical IV recipes.
@@ -9198,7 +9165,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
Term->eraseFromParent();
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), DebugLoc(),
- true, CM.useActiveLaneMaskForControlFlow());
+ CM.getTailFoldingStyle());
return Plan;
}
@@ -9255,7 +9222,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
VPBuilder::InsertPointGuard Guard(Builder);
Builder.setInsertPoint(WidenRecipe->getParent(),
WidenRecipe->getIterator());
- CondOp = RecipeBuilder.createBlockInMask(R->getParent(), Plan);
+ CondOp = RecipeBuilder.createBlockInMask(R->getParent(), *Plan);
}
if (IsFMulAdd) {
@@ -9270,7 +9237,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
VecOp = FMulRecipe;
}
VPReductionRecipe *RedRecipe =
- new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, TTI);
+ new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, &TTI);
WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
Plan->removeVPValueFor(R);
Plan->addVPValue(R, RedRecipe);
@@ -9304,13 +9271,15 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (!PhiR || PhiR->isInLoop())
continue;
VPValue *Cond =
- RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan);
+ RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), *Plan);
VPValue *Red = PhiR->getBackedgeValue();
assert(Red->getDefiningRecipe()->getParent() != LatchVPBB &&
"reduction recipe must be defined before latch");
Builder.createNaryOp(Instruction::Select, {Cond, Red, PhiR});
}
}
+
+ VPlanTransforms::clearReductionWrapFlags(*Plan);
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -9475,7 +9444,7 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) {
PartStart, ConstantInt::get(PtrInd->getType(), Lane));
Value *GlobalIdx = State.Builder.CreateAdd(PtrInd, Idx);
- Value *Step = State.get(getOperand(1), VPIteration(0, Part));
+ Value *Step = State.get(getOperand(1), VPIteration(Part, Lane));
Value *SclrGep = emitTransformedIndex(
State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, IndDesc);
SclrGep->setName("next.gep");
@@ -9485,8 +9454,6 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) {
return;
}
- assert(isa<SCEVConstant>(IndDesc.getStep()) &&
- "Induction step not a SCEV constant!");
Type *PhiType = IndDesc.getStep()->getType();
// Build a pointer phi
@@ -9506,7 +9473,7 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) {
Value *NumUnrolledElems =
State.Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, State.UF));
Value *InductionGEP = GetElementPtrInst::Create(
- IndDesc.getElementType(), NewPointerPhi,
+ State.Builder.getInt8Ty(), NewPointerPhi,
State.Builder.CreateMul(ScalarStepValue, NumUnrolledElems), "ptr.ind",
InductionLoc);
// Add induction update using an incorrect block temporarily. The phi node
@@ -9529,10 +9496,10 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) {
StartOffset = State.Builder.CreateAdd(
StartOffset, State.Builder.CreateStepVector(VecPhiType));
- assert(ScalarStepValue == State.get(getOperand(1), VPIteration(0, Part)) &&
+ assert(ScalarStepValue == State.get(getOperand(1), VPIteration(Part, 0)) &&
"scalar step must be the same across all parts");
Value *GEP = State.Builder.CreateGEP(
- IndDesc.getElementType(), NewPointerPhi,
+ State.Builder.getInt8Ty(), NewPointerPhi,
State.Builder.CreateMul(
StartOffset,
State.Builder.CreateVectorSplat(State.VF, ScalarStepValue),
@@ -9584,7 +9551,8 @@ void VPScalarIVStepsRecipe::execute(VPTransformState &State) {
void VPInterleaveRecipe::execute(VPTransformState &State) {
assert(!State.Instance && "Interleave group being replicated.");
State.ILV->vectorizeInterleaveGroup(IG, definedValues(), State, getAddr(),
- getStoredValues(), getMask());
+ getStoredValues(), getMask(),
+ NeedsMaskForGaps);
}
void VPReductionRecipe::execute(VPTransformState &State) {
@@ -9640,10 +9608,9 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
Instruction *UI = getUnderlyingInstr();
if (State.Instance) { // Generate a single instance.
assert(!State.VF.isScalable() && "Can't scalarize a scalable vector");
- State.ILV->scalarizeInstruction(UI, this, *State.Instance,
- IsPredicated, State);
+ State.ILV->scalarizeInstruction(UI, this, *State.Instance, State);
// Insert scalar instance packing it into a vector.
- if (AlsoPack && State.VF.isVector()) {
+ if (State.VF.isVector() && shouldPack()) {
// If we're constructing lane 0, initialize to start from poison.
if (State.Instance->Lane.isFirstLane()) {
assert(!State.VF.isScalable() && "VF is assumed to be non scalable.");
@@ -9663,8 +9630,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
all_of(operands(), [](VPValue *Op) {
return Op->isDefinedOutsideVectorRegions();
})) {
- State.ILV->scalarizeInstruction(UI, this, VPIteration(0, 0), IsPredicated,
- State);
+ State.ILV->scalarizeInstruction(UI, this, VPIteration(0, 0), State);
if (user_begin() != user_end()) {
for (unsigned Part = 1; Part < State.UF; ++Part)
State.set(this, State.get(this, VPIteration(0, 0)),
@@ -9676,16 +9642,16 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
// Uniform within VL means we need to generate lane 0 only for each
// unrolled copy.
for (unsigned Part = 0; Part < State.UF; ++Part)
- State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, 0),
- IsPredicated, State);
+ State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, 0), State);
return;
}
- // A store of a loop varying value to a loop invariant address only
- // needs only the last copy of the store.
- if (isa<StoreInst>(UI) && !getOperand(1)->hasDefiningRecipe()) {
+ // A store of a loop varying value to a uniform address only needs the last
+ // copy of the store.
+ if (isa<StoreInst>(UI) &&
+ vputils::isUniformAfterVectorization(getOperand(1))) {
auto Lane = VPLane::getLastLaneForVF(State.VF);
- State.ILV->scalarizeInstruction(UI, this, VPIteration(State.UF - 1, Lane), IsPredicated,
+ State.ILV->scalarizeInstruction(UI, this, VPIteration(State.UF - 1, Lane),
State);
return;
}
@@ -9695,8 +9661,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
const unsigned EndLane = State.VF.getKnownMinValue();
for (unsigned Part = 0; Part < State.UF; ++Part)
for (unsigned Lane = 0; Lane < EndLane; ++Lane)
- State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, Lane),
- IsPredicated, State);
+ State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, Lane), State);
}
void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
@@ -9714,7 +9679,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
auto *DataTy = VectorType::get(ScalarDataTy, State.VF);
const Align Alignment = getLoadStoreAlignment(&Ingredient);
- bool CreateGatherScatter = !Consecutive;
+ bool CreateGatherScatter = !isConsecutive();
auto &Builder = State.Builder;
InnerLoopVectorizer::VectorParts BlockInMaskParts(State.UF);
@@ -9725,36 +9690,39 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
const auto CreateVecPtr = [&](unsigned Part, Value *Ptr) -> Value * {
// Calculate the pointer for the specific unroll-part.
- GetElementPtrInst *PartPtr = nullptr;
-
+ Value *PartPtr = nullptr;
+
+ // Use i32 for the gep index type when the value is constant,
+ // or query DataLayout for a more suitable index type otherwise.
+ const DataLayout &DL =
+ Builder.GetInsertBlock()->getModule()->getDataLayout();
+ Type *IndexTy = State.VF.isScalable() && (isReverse() || Part > 0)
+ ? DL.getIndexType(ScalarDataTy->getPointerTo())
+ : Builder.getInt32Ty();
bool InBounds = false;
if (auto *gep = dyn_cast<GetElementPtrInst>(Ptr->stripPointerCasts()))
InBounds = gep->isInBounds();
- if (Reverse) {
+ if (isReverse()) {
// If the address is consecutive but reversed, then the
// wide store needs to start at the last vector element.
// RunTimeVF = VScale * VF.getKnownMinValue()
// For fixed-width VScale is 1, then RunTimeVF = VF.getKnownMinValue()
- Value *RunTimeVF = getRuntimeVF(Builder, Builder.getInt32Ty(), State.VF);
+ Value *RunTimeVF = getRuntimeVF(Builder, IndexTy, State.VF);
// NumElt = -Part * RunTimeVF
- Value *NumElt = Builder.CreateMul(Builder.getInt32(-Part), RunTimeVF);
+ Value *NumElt =
+ Builder.CreateMul(ConstantInt::get(IndexTy, -(int64_t)Part), RunTimeVF);
// LastLane = 1 - RunTimeVF
- Value *LastLane = Builder.CreateSub(Builder.getInt32(1), RunTimeVF);
+ Value *LastLane =
+ Builder.CreateSub(ConstantInt::get(IndexTy, 1), RunTimeVF);
+ PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, NumElt, "", InBounds);
PartPtr =
- cast<GetElementPtrInst>(Builder.CreateGEP(ScalarDataTy, Ptr, NumElt));
- PartPtr->setIsInBounds(InBounds);
- PartPtr = cast<GetElementPtrInst>(
- Builder.CreateGEP(ScalarDataTy, PartPtr, LastLane));
- PartPtr->setIsInBounds(InBounds);
+ Builder.CreateGEP(ScalarDataTy, PartPtr, LastLane, "", InBounds);
if (isMaskRequired) // Reverse of a null all-one mask is a null mask.
BlockInMaskParts[Part] =
Builder.CreateVectorReverse(BlockInMaskParts[Part], "reverse");
} else {
- Value *Increment =
- createStepForVF(Builder, Builder.getInt32Ty(), State.VF, Part);
- PartPtr = cast<GetElementPtrInst>(
- Builder.CreateGEP(ScalarDataTy, Ptr, Increment));
- PartPtr->setIsInBounds(InBounds);
+ Value *Increment = createStepForVF(Builder, IndexTy, State.VF, Part);
+ PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, Increment, "", InBounds);
}
unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace();
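// A minimal standalone sketch of the reverse-access arithmetic above, assuming
// a fixed-width VF (VScale == 1); the helper name is illustrative, not LLVM's.
// The per-part base pointer is Ptr + (-Part * RunTimeVF) + (1 - RunTimeVF), so
// for VF = 4 the offsets are -3, -7, -11, ...: each unrolled part addresses the
// VF elements that end at the current lane-0 pointer; the loaded or stored
// vector is then reversed.
long reversePartOffset(long Part, long VF) {
  long RunTimeVF = VF;             // fixed-width case: VScale is 1
  long NumElt = -Part * RunTimeVF; // start offset of this unrolled part
  long LastLane = 1 - RunTimeVF;   // step back to the lowest address
  return NumElt + LastLane;        // Part = 0, VF = 4  ->  -3
}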
@@ -9774,7 +9742,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment,
MaskPart);
} else {
- if (Reverse) {
+ if (isReverse()) {
// If we store to reverse consecutive memory locations, then we need
// to reverse the order of elements in the stored value.
StoredVal = Builder.CreateVectorReverse(StoredVal, "reverse");
@@ -9833,7 +9801,6 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
static ScalarEpilogueLowering getScalarEpilogueLowering(
Function *F, Loop *L, LoopVectorizeHints &Hints, ProfileSummaryInfo *PSI,
BlockFrequencyInfo *BFI, TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
- AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
LoopVectorizationLegality &LVL, InterleavedAccessInfo *IAI) {
// 1) OptSize takes precedence over all other options, i.e. if this is set,
// don't look at hints or options, and don't request a scalar epilogue.
@@ -9869,7 +9836,8 @@ static ScalarEpilogueLowering getScalarEpilogueLowering(
};
// 4) if the TTI hook indicates this is profitable, request predication.
- if (TTI->preferPredicateOverEpilogue(L, LI, *SE, *AC, TLI, DT, &LVL, IAI))
+ TailFoldingInfo TFI(TLI, &LVL, IAI);
+ if (TTI->preferPredicateOverEpilogue(&TFI))
return CM_ScalarEpilogueNotNeededUsePredicate;
return CM_ScalarEpilogueAllowed;
@@ -9880,9 +9848,29 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) {
if (hasVectorValue(Def, Part))
return Data.PerPartOutput[Def][Part];
+ auto GetBroadcastInstrs = [this, Def](Value *V) {
+ bool SafeToHoist = Def->isDefinedOutsideVectorRegions();
+ if (VF.isScalar())
+ return V;
+ // Place the code for broadcasting invariant variables in the new preheader.
+ IRBuilder<>::InsertPointGuard Guard(Builder);
+ if (SafeToHoist) {
+ BasicBlock *LoopVectorPreHeader = CFG.VPBB2IRBB[cast<VPBasicBlock>(
+ Plan->getVectorLoopRegion()->getSinglePredecessor())];
+ if (LoopVectorPreHeader)
+ Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator());
+ }
+
+ // Place the code for broadcasting invariant variables in the new preheader.
+ // Broadcast the scalar into all locations in the vector.
+ Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast");
+
+ return Shuf;
+ };
+
if (!hasScalarValue(Def, {Part, 0})) {
Value *IRV = Def->getLiveInIRValue();
- Value *B = ILV->getBroadcastInstrs(IRV);
+ Value *B = GetBroadcastInstrs(IRV);
set(Def, B, Part);
return B;
}
@@ -9900,9 +9888,11 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) {
unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1;
// Check if there is a scalar value for the selected lane.
if (!hasScalarValue(Def, {Part, LastLane})) {
- // At the moment, VPWidenIntOrFpInductionRecipes and VPScalarIVStepsRecipes can also be uniform.
+ // At the moment, VPWidenIntOrFpInductionRecipes, VPScalarIVStepsRecipes and
+ // VPExpandSCEVRecipes can also be uniform.
assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) ||
- isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe())) &&
+ isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe()) ||
+ isa<VPExpandSCEVRecipe>(Def->getDefiningRecipe())) &&
"unexpected recipe found to be invariant");
IsUniform = true;
LastLane = 0;
@@ -9927,7 +9917,7 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) {
// State, we will only generate the insertelements once.
Value *VectorValue = nullptr;
if (IsUniform) {
- VectorValue = ILV->getBroadcastInstrs(ScalarValue);
+ VectorValue = GetBroadcastInstrs(ScalarValue);
set(Def, VectorValue, Part);
} else {
// Initialize packing with insertelements to start from undef.
@@ -9962,15 +9952,15 @@ static bool processLoopInVPlanNativePath(
Function *F = L->getHeader()->getParent();
InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL->getLAI());
- ScalarEpilogueLowering SEL = getScalarEpilogueLowering(
- F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, *LVL, &IAI);
+ ScalarEpilogueLowering SEL =
+ getScalarEpilogueLowering(F, L, Hints, PSI, BFI, TTI, TLI, *LVL, &IAI);
LoopVectorizationCostModel CM(SEL, L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F,
&Hints, IAI);
// Use the planner for outer loop vectorization.
// TODO: CM is not used at this point inside the planner. Turn CM into an
// optional argument if we don't need it in the future.
- LoopVectorizationPlanner LVP(L, LI, TLI, TTI, LVL, CM, IAI, PSE, Hints, ORE);
+ LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, LVL, CM, IAI, PSE, Hints, ORE);
// Get user vectorization factor.
ElementCount UserVF = Hints.getWidth();
@@ -10231,8 +10221,8 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// Check the function attributes and profiles to find out if this function
// should be optimized for size.
- ScalarEpilogueLowering SEL = getScalarEpilogueLowering(
- F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, LVL, &IAI);
+ ScalarEpilogueLowering SEL =
+ getScalarEpilogueLowering(F, L, Hints, PSI, BFI, TTI, TLI, LVL, &IAI);
// Check the loop for a trip count threshold: vectorize loops with a tiny trip
// count by optimizing for size, to minimize overheads.
@@ -10309,11 +10299,9 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// Use the cost model.
LoopVectorizationCostModel CM(SEL, L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE,
F, &Hints, IAI);
- CM.collectValuesToIgnore();
- CM.collectElementTypesForWidening();
-
// Use the planner for vectorization.
- LoopVectorizationPlanner LVP(L, LI, TLI, TTI, &LVL, CM, IAI, PSE, Hints, ORE);
+ LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, &LVL, CM, IAI, PSE, Hints,
+ ORE);
// Get user vectorization factor and interleave count.
ElementCount UserVF = Hints.getWidth();
@@ -10342,7 +10330,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
bool ForceVectorization =
Hints.getForce() == LoopVectorizeHints::FK_Enabled;
if (!ForceVectorization &&
- !areRuntimeChecksProfitable(Checks, VF, CM.getVScaleForTuning(), L,
+ !areRuntimeChecksProfitable(Checks, VF, getVScaleForTuning(L, *TTI), L,
*PSE.getSE())) {
ORE->emit([&]() {
return OptimizationRemarkAnalysisAliasing(
@@ -10464,7 +10452,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// Consider vectorizing the epilogue too if it's profitable.
VectorizationFactor EpilogueVF =
- CM.selectEpilogueVectorizationFactor(VF.Width, LVP);
+ LVP.selectEpilogueVectorizationFactor(VF.Width, IC);
if (EpilogueVF.Width.isVector()) {
// The first pass vectorizes the main loop and creates a scalar epilogue
@@ -10475,8 +10463,8 @@ bool LoopVectorizePass::processLoop(Loop *L) {
EPI, &LVL, &CM, BFI, PSI, Checks);
VPlan &BestMainPlan = LVP.getBestPlanFor(EPI.MainLoopVF);
- LVP.executePlan(EPI.MainLoopVF, EPI.MainLoopUF, BestMainPlan, MainILV,
- DT, true);
+ auto ExpandedSCEVs = LVP.executePlan(EPI.MainLoopVF, EPI.MainLoopUF,
+ BestMainPlan, MainILV, DT, true);
++LoopsVectorized;
// Second pass vectorizes the epilogue and adjusts the control flow
@@ -10492,6 +10480,21 @@ bool LoopVectorizePass::processLoop(Loop *L) {
VPBasicBlock *Header = VectorLoop->getEntryBasicBlock();
Header->setName("vec.epilog.vector.body");
+ // Re-use the trip count and steps expanded for the main loop, as
+ // skeleton creation needs it as a value that dominates both the scalar
+ // and vector epilogue loops
+ // TODO: This is a workaround needed for epilogue vectorization and it
+ // should be removed once induction resume value creation is done
+ // directly in VPlan.
+ EpilogILV.setTripCount(MainILV.getTripCount());
+ for (auto &R : make_early_inc_range(*BestEpiPlan.getPreheader())) {
+ auto *ExpandR = cast<VPExpandSCEVRecipe>(&R);
+ auto *ExpandedVal = BestEpiPlan.getVPValueOrAddLiveIn(
+ ExpandedSCEVs.find(ExpandR->getSCEV())->second);
+ ExpandR->replaceAllUsesWith(ExpandedVal);
+ ExpandR->eraseFromParent();
+ }
+
// Ensure that the start values for any VPWidenIntOrFpInductionRecipe,
// VPWidenPointerInductionRecipe and VPReductionPHIRecipes are updated
// before vectorizing the epilogue loop.
@@ -10520,15 +10523,16 @@ bool LoopVectorizePass::processLoop(Loop *L) {
}
ResumeV = MainILV.createInductionResumeValue(
- IndPhi, *ID, {EPI.MainLoopIterationCountCheck});
+ IndPhi, *ID, getExpandedStep(*ID, ExpandedSCEVs),
+ {EPI.MainLoopIterationCountCheck});
}
assert(ResumeV && "Must have a resume value");
- VPValue *StartVal = BestEpiPlan.getOrAddExternalDef(ResumeV);
+ VPValue *StartVal = BestEpiPlan.getVPValueOrAddLiveIn(ResumeV);
cast<VPHeaderPHIRecipe>(&R)->setStartValue(StartVal);
}
LVP.executePlan(EPI.EpilogueVF, EPI.EpilogueUF, BestEpiPlan, EpilogILV,
- DT, true);
+ DT, true, &ExpandedSCEVs);
++LoopsEpilogueVectorized;
if (!MainILV.areSafetyChecksAdded())
@@ -10581,14 +10585,14 @@ bool LoopVectorizePass::processLoop(Loop *L) {
LoopVectorizeResult LoopVectorizePass::runImpl(
Function &F, ScalarEvolution &SE_, LoopInfo &LI_, TargetTransformInfo &TTI_,
- DominatorTree &DT_, BlockFrequencyInfo &BFI_, TargetLibraryInfo *TLI_,
+ DominatorTree &DT_, BlockFrequencyInfo *BFI_, TargetLibraryInfo *TLI_,
DemandedBits &DB_, AssumptionCache &AC_, LoopAccessInfoManager &LAIs_,
OptimizationRemarkEmitter &ORE_, ProfileSummaryInfo *PSI_) {
SE = &SE_;
LI = &LI_;
TTI = &TTI_;
DT = &DT_;
- BFI = &BFI_;
+ BFI = BFI_;
TLI = TLI_;
AC = &AC_;
LAIs = &LAIs_;
@@ -10604,7 +10608,7 @@ LoopVectorizeResult LoopVectorizePass::runImpl(
// vector registers, loop vectorization may still enable scalar
// interleaving.
if (!TTI->getNumberOfRegisters(TTI->getRegisterClassForType(true)) &&
- TTI->getMaxInterleaveFactor(1) < 2)
+ TTI->getMaxInterleaveFactor(ElementCount::getFixed(1)) < 2)
return LoopVectorizeResult(false, false);
bool Changed = false, CFGChanged = false;
@@ -10656,7 +10660,6 @@ PreservedAnalyses LoopVectorizePass::run(Function &F,
auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
- auto &BFI = AM.getResult<BlockFrequencyAnalysis>(F);
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
auto &DB = AM.getResult<DemandedBitsAnalysis>(F);
@@ -10666,12 +10669,20 @@ PreservedAnalyses LoopVectorizePass::run(Function &F,
auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
ProfileSummaryInfo *PSI =
MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+ BlockFrequencyInfo *BFI = nullptr;
+ if (PSI && PSI->hasProfileSummary())
+ BFI = &AM.getResult<BlockFrequencyAnalysis>(F);
LoopVectorizeResult Result =
runImpl(F, SE, LI, TTI, DT, BFI, &TLI, DB, AC, LAIs, ORE, PSI);
if (!Result.MadeAnyChange)
return PreservedAnalyses::all();
PreservedAnalyses PA;
+ if (isAssignmentTrackingEnabled(*F.getParent())) {
+ for (auto &BB : F)
+ RemoveRedundantDbgInstrs(&BB);
+ }
+
// We currently do not preserve loopinfo/dominator analyses with outer loop
// vectorization. Until this is addressed, mark these analyses as preserved
// only for non-VPlan-native path.
@@ -10679,6 +10690,11 @@ PreservedAnalyses LoopVectorizePass::run(Function &F,
if (!EnableVPlanNativePath) {
PA.preserve<LoopAnalysis>();
PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<ScalarEvolutionAnalysis>();
+
+#ifdef EXPENSIVE_CHECKS
+ SE.verify();
+#endif
}
if (Result.MadeCFGChange) {
@@ -10699,8 +10715,8 @@ void LoopVectorizePass::printPipeline(
static_cast<PassInfoMixin<LoopVectorizePass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
OS << (InterleaveOnlyWhenForced ? "" : "no-") << "interleave-forced-only;";
OS << (VectorizeOnlyWhenForced ? "" : "no-") << "vectorize-forced-only;";
- OS << ">";
+ OS << '>';
}
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e3eb6b1804e7..821a3fa22a85 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -87,7 +87,6 @@
#include "llvm/Transforms/Utils/InjectTLIMappings.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
-#include "llvm/Transforms/Vectorize.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
@@ -126,6 +125,13 @@ static cl::opt<bool> ShouldStartVectorizeHorAtStore(
cl::desc(
"Attempt to vectorize horizontal reductions feeding into a store"));
+// NOTE: If AllowHorRdxIdenityOptimization is true, the optimization will run
+// even if we match a reduction but do not vectorize in the end.
+static cl::opt<bool> AllowHorRdxIdenityOptimization(
+ "slp-optimize-identity-hor-reduction-ops", cl::init(true), cl::Hidden,
+ cl::desc("Allow optimization of original scalar identity operations on "
+ "matched horizontal reductions."));
+
static cl::opt<int>
MaxVectorRegSizeOption("slp-max-reg-size", cl::init(128), cl::Hidden,
cl::desc("Attempt to vectorize for this register size in bits"));
@@ -287,7 +293,7 @@ static bool isCommutative(Instruction *I) {
/// \returns inserting index of InsertElement or InsertValue instruction,
/// using Offset as base offset for index.
static std::optional<unsigned> getInsertIndex(const Value *InsertInst,
- unsigned Offset = 0) {
+ unsigned Offset = 0) {
int Index = Offset;
if (const auto *IE = dyn_cast<InsertElementInst>(InsertInst)) {
const auto *VT = dyn_cast<FixedVectorType>(IE->getType());
@@ -342,16 +348,16 @@ enum class UseMask {
static SmallBitVector buildUseMask(int VF, ArrayRef<int> Mask,
UseMask MaskArg) {
SmallBitVector UseMask(VF, true);
- for (auto P : enumerate(Mask)) {
- if (P.value() == UndefMaskElem) {
+ for (auto [Idx, Value] : enumerate(Mask)) {
+ if (Value == PoisonMaskElem) {
if (MaskArg == UseMask::UndefsAsMask)
- UseMask.reset(P.index());
+ UseMask.reset(Idx);
continue;
}
- if (MaskArg == UseMask::FirstArg && P.value() < VF)
- UseMask.reset(P.value());
- else if (MaskArg == UseMask::SecondArg && P.value() >= VF)
- UseMask.reset(P.value() - VF);
+ if (MaskArg == UseMask::FirstArg && Value < VF)
+ UseMask.reset(Value);
+ else if (MaskArg == UseMask::SecondArg && Value >= VF)
+ UseMask.reset(Value - VF);
}
return UseMask;
}
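// A standalone restatement of the FirstArg case of buildUseMask above; the
// helper name and the -1 poison sentinel are illustrative, not LLVM's.
// Starting from an all-ones bitvector of size VF, the bit of every element of
// the first shuffle operand referenced by the mask is cleared. For VF = 4 and
// Mask = {1, 6, poison, 3}, bits 1 and 3 are cleared, leaving {1, 0, 1, 0}.
#include <vector>
std::vector<bool> buildFirstArgUseMask(int VF, const std::vector<int> &Mask) {
  constexpr int kPoisonElem = -1; // stand-in for PoisonMaskElem
  std::vector<bool> Use(VF, true);
  for (int M : Mask)
    if (M != kPoisonElem && M < VF) // lane taken from the first operand
      Use[M] = false;
  return Use;
}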
@@ -374,9 +380,9 @@ static SmallBitVector isUndefVector(const Value *V,
if (!UseMask.empty()) {
const Value *Base = V;
while (auto *II = dyn_cast<InsertElementInst>(Base)) {
+ Base = II->getOperand(0);
if (isa<T>(II->getOperand(1)))
continue;
- Base = II->getOperand(0);
std::optional<unsigned> Idx = getInsertIndex(II);
if (!Idx)
continue;
@@ -461,7 +467,7 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
Value *Vec2 = nullptr;
enum ShuffleMode { Unknown, Select, Permute };
ShuffleMode CommonShuffleMode = Unknown;
- Mask.assign(VL.size(), UndefMaskElem);
+ Mask.assign(VL.size(), PoisonMaskElem);
for (unsigned I = 0, E = VL.size(); I < E; ++I) {
// Undef can be represented as an undef element in a vector.
if (isa<UndefValue>(VL[I]))
@@ -533,6 +539,117 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) {
return *EI->idx_begin();
}
+/// Tries to find extractelement instructions with constant indices from fixed
+/// vector type and gather such instructions into a bunch, which can likely
+/// be detected as a shuffle of 1 or 2 input vectors. If this attempt was
+/// successful, the matched scalars are replaced by poison values in \p VL for
+/// future analysis.
+static std::optional<TTI::ShuffleKind>
+tryToGatherExtractElements(SmallVectorImpl<Value *> &VL,
+ SmallVectorImpl<int> &Mask) {
+ // Scan list of gathered scalars for extractelements that can be represented
+ // as shuffles.
+ MapVector<Value *, SmallVector<int>> VectorOpToIdx;
+ SmallVector<int> UndefVectorExtracts;
+ for (int I = 0, E = VL.size(); I < E; ++I) {
+ auto *EI = dyn_cast<ExtractElementInst>(VL[I]);
+ if (!EI) {
+ if (isa<UndefValue>(VL[I]))
+ UndefVectorExtracts.push_back(I);
+ continue;
+ }
+ auto *VecTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType());
+ if (!VecTy || !isa<ConstantInt, UndefValue>(EI->getIndexOperand()))
+ continue;
+ std::optional<unsigned> Idx = getExtractIndex(EI);
+ // Undefined index.
+ if (!Idx) {
+ UndefVectorExtracts.push_back(I);
+ continue;
+ }
+ SmallBitVector ExtractMask(VecTy->getNumElements(), true);
+ ExtractMask.reset(*Idx);
+ if (isUndefVector(EI->getVectorOperand(), ExtractMask).all()) {
+ UndefVectorExtracts.push_back(I);
+ continue;
+ }
+ VectorOpToIdx[EI->getVectorOperand()].push_back(I);
+ }
+ // Sort the vector operands by the maximum number of uses in extractelements.
+ MapVector<unsigned, SmallVector<Value *>> VFToVector;
+ for (const auto &Data : VectorOpToIdx)
+ VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()]
+ .push_back(Data.first);
+ for (auto &Data : VFToVector) {
+ stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) {
+ return VectorOpToIdx.find(V1)->second.size() >
+ VectorOpToIdx.find(V2)->second.size();
+ });
+ }
+ // Find the best pair of the vectors with the same number of elements or a
+ // single vector.
+ const int UndefSz = UndefVectorExtracts.size();
+ unsigned SingleMax = 0;
+ Value *SingleVec = nullptr;
+ unsigned PairMax = 0;
+ std::pair<Value *, Value *> PairVec(nullptr, nullptr);
+ for (auto &Data : VFToVector) {
+ Value *V1 = Data.second.front();
+ if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) {
+ SingleMax = VectorOpToIdx[V1].size() + UndefSz;
+ SingleVec = V1;
+ }
+ Value *V2 = nullptr;
+ if (Data.second.size() > 1)
+ V2 = *std::next(Data.second.begin());
+ if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() +
+ UndefSz) {
+ PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz;
+ PairVec = std::make_pair(V1, V2);
+ }
+ }
+ if (SingleMax == 0 && PairMax == 0 && UndefSz == 0)
+ return std::nullopt;
+ // Check if better to perform a shuffle of 2 vectors or just of a single
+ // vector.
+ SmallVector<Value *> SavedVL(VL.begin(), VL.end());
+ SmallVector<Value *> GatheredExtracts(
+ VL.size(), PoisonValue::get(VL.front()->getType()));
+ if (SingleMax >= PairMax && SingleMax) {
+ for (int Idx : VectorOpToIdx[SingleVec])
+ std::swap(GatheredExtracts[Idx], VL[Idx]);
+ } else {
+ for (Value *V : {PairVec.first, PairVec.second})
+ for (int Idx : VectorOpToIdx[V])
+ std::swap(GatheredExtracts[Idx], VL[Idx]);
+ }
+ // Add extracts from undefs too.
+ for (int Idx : UndefVectorExtracts)
+ std::swap(GatheredExtracts[Idx], VL[Idx]);
+  // Check that the gather of extractelements can be represented as just a
+  // shuffle of one or two vectors from which the scalars are extracted.
+ std::optional<TTI::ShuffleKind> Res =
+ isFixedVectorShuffle(GatheredExtracts, Mask);
+ if (!Res) {
+ // TODO: try to check other subsets if possible.
+ // Restore the original VL if attempt was not successful.
+ VL.swap(SavedVL);
+ return std::nullopt;
+ }
+ // Restore unused scalars from mask, if some of the extractelements were not
+ // selected for shuffle.
+ for (int I = 0, E = GatheredExtracts.size(); I < E; ++I) {
+ auto *EI = dyn_cast<ExtractElementInst>(VL[I]);
+ if (!EI || !isa<FixedVectorType>(EI->getVectorOperandType()) ||
+ !isa<ConstantInt, UndefValue>(EI->getIndexOperand()) ||
+ is_contained(UndefVectorExtracts, I))
+ continue;
+ if (Mask[I] == PoisonMaskElem && !isa<PoisonValue>(GatheredExtracts[I]))
+ std::swap(VL[I], GatheredExtracts[I]);
+ }
+ return Res;
+}
+
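// A simplified, single-source sketch of the matching done above; the types and
// helper name are illustrative, not LLVM's. If the gathered scalars are
// VL = { v[1], v[0], x, v[3] }, the extracts from the same fixed vector v form
// the shuffle mask {1, 0, poison, 3}; those slots of VL are then replaced by
// poison so later gather analysis only has to handle `x`.
#include <optional>
#include <vector>
struct ScalarDesc { // one gathered scalar
  int SourceVec;    // -1 if the scalar is not an extractelement
  int ExtractIdx;   // constant extract index when SourceVec != -1
};
std::optional<std::vector<int>>
singleSourceExtractMask(const std::vector<ScalarDesc> &VL, int SourceVec) {
  constexpr int kPoison = -1;
  std::vector<int> Mask(VL.size(), kPoison);
  bool MatchedAny = false;
  for (size_t I = 0; I < VL.size(); ++I)
    if (VL[I].SourceVec == SourceVec) { // extract from the chosen vector
      Mask[I] = VL[I].ExtractIdx;
      MatchedAny = true;
    }
  if (!MatchedAny)
    return std::nullopt; // nothing to express as a shuffle of this vector
  return Mask;
}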
namespace {
/// Main data required for vectorization of instructions.
@@ -829,18 +946,29 @@ static bool isSimple(Instruction *I) {
}
/// Shuffles \p Mask in accordance with the given \p SubMask.
-static void addMask(SmallVectorImpl<int> &Mask, ArrayRef<int> SubMask) {
+/// \param ExtendingManyInputs Supports reshuffling of the mask with not only
+/// one but two input vectors.
+static void addMask(SmallVectorImpl<int> &Mask, ArrayRef<int> SubMask,
+ bool ExtendingManyInputs = false) {
if (SubMask.empty())
return;
+ assert(
+ (!ExtendingManyInputs || SubMask.size() > Mask.size() ||
+ // Check if input scalars were extended to match the size of other node.
+ (SubMask.size() == Mask.size() &&
+ std::all_of(std::next(Mask.begin(), Mask.size() / 2), Mask.end(),
+ [](int Idx) { return Idx == PoisonMaskElem; }))) &&
+ "SubMask with many inputs support must be larger than the mask.");
if (Mask.empty()) {
Mask.append(SubMask.begin(), SubMask.end());
return;
}
- SmallVector<int> NewMask(SubMask.size(), UndefMaskElem);
+ SmallVector<int> NewMask(SubMask.size(), PoisonMaskElem);
int TermValue = std::min(Mask.size(), SubMask.size());
for (int I = 0, E = SubMask.size(); I < E; ++I) {
- if (SubMask[I] >= TermValue || SubMask[I] == UndefMaskElem ||
- Mask[SubMask[I]] >= TermValue)
+ if (SubMask[I] == PoisonMaskElem ||
+ (!ExtendingManyInputs &&
+ (SubMask[I] >= TermValue || Mask[SubMask[I]] >= TermValue)))
continue;
NewMask[I] = Mask[SubMask[I]];
}
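// A standalone sketch of the default (ExtendingManyInputs == false) path of
// addMask above; the helper name and -1 poison sentinel are illustrative, not
// LLVM's. The composition is NewMask[I] = Mask[SubMask[I]]: for the accumulated
// Mask = {1, 2, 0} and SubMask = {0, 0, 1, 2} it yields {1, 1, 2, 0}, while
// poison and out-of-range entries stay poison.
#include <algorithm>
#include <vector>
std::vector<int> composeMasks(const std::vector<int> &Mask,
                              const std::vector<int> &SubMask) {
  constexpr int kPoison = -1; // stand-in for PoisonMaskElem
  if (Mask.empty())
    return SubMask;
  int TermValue = static_cast<int>(std::min(Mask.size(), SubMask.size()));
  std::vector<int> NewMask(SubMask.size(), kPoison);
  for (size_t I = 0, E = SubMask.size(); I < E; ++I)
    if (SubMask[I] != kPoison && SubMask[I] < TermValue &&
        Mask[SubMask[I]] < TermValue)
      NewMask[I] = Mask[SubMask[I]];
  return NewMask;
}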
@@ -887,7 +1015,7 @@ static void inversePermutation(ArrayRef<unsigned> Indices,
SmallVectorImpl<int> &Mask) {
Mask.clear();
const unsigned E = Indices.size();
- Mask.resize(E, UndefMaskElem);
+ Mask.resize(E, PoisonMaskElem);
for (unsigned I = 0; I < E; ++I)
Mask[Indices[I]] = I;
}
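// A standalone restatement of inversePermutation above; the helper name is
// illustrative, not LLVM's. The mask is built as Mask[Indices[I]] = I, so for
// Indices = {2, 0, 1} the result is Mask = {1, 2, 0}.
#include <vector>
std::vector<int> inversePermutationOf(const std::vector<unsigned> &Indices) {
  std::vector<int> Mask(Indices.size(), -1); // -1 as a poison placeholder
  for (unsigned I = 0, E = Indices.size(); I < E; ++I)
    Mask[Indices[I]] = static_cast<int>(I);
  return Mask;
}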
@@ -900,7 +1028,7 @@ static void reorderScalars(SmallVectorImpl<Value *> &Scalars,
UndefValue::get(Scalars.front()->getType()));
Prev.swap(Scalars);
for (unsigned I = 0, E = Prev.size(); I < E; ++I)
- if (Mask[I] != UndefMaskElem)
+ if (Mask[I] != PoisonMaskElem)
Scalars[Mask[I]] = Prev[I];
}
@@ -962,6 +1090,7 @@ namespace slpvectorizer {
class BoUpSLP {
struct TreeEntry;
struct ScheduleData;
+ class ShuffleCostEstimator;
class ShuffleInstructionBuilder;
public:
@@ -1006,8 +1135,12 @@ public:
/// Vectorize the tree but with the list of externally used values \p
/// ExternallyUsedValues. Values in this MapVector can be replaced but the
/// generated extractvalue instructions.
- Value *vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
- Instruction *ReductionRoot = nullptr);
+  /// \param ReplacedExternals contains the list of replaced external values
+ /// {scalar, replace} after emitting extractelement for external uses.
+ Value *
+ vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
+ SmallVectorImpl<std::pair<Value *, Value *>> &ReplacedExternals,
+ Instruction *ReductionRoot = nullptr);
/// \returns the cost incurred by unwanted spills and fills, caused by
/// holding live values over call sites.
@@ -1025,24 +1158,18 @@ public:
/// Construct a vectorizable tree that starts at \p Roots.
void buildTree(ArrayRef<Value *> Roots);
- /// Checks if the very first tree node is going to be vectorized.
- bool isVectorizedFirstNode() const {
- return !VectorizableTree.empty() &&
- VectorizableTree.front()->State == TreeEntry::Vectorize;
- }
-
- /// Returns the main instruction for the very first node.
- Instruction *getFirstNodeMainOp() const {
- assert(!VectorizableTree.empty() && "No tree to get the first node from");
- return VectorizableTree.front()->getMainOp();
- }
-
/// Returns whether the root node has in-tree uses.
bool doesRootHaveInTreeUses() const {
return !VectorizableTree.empty() &&
!VectorizableTree.front()->UserTreeIndices.empty();
}
+ /// Return the scalars of the root node.
+ ArrayRef<Value *> getRootNodeScalars() const {
+ assert(!VectorizableTree.empty() && "No graph to get the first node from");
+ return VectorizableTree.front()->Scalars;
+ }
+
/// Builds external uses of the vectorized scalars, i.e. the list of
/// vectorized scalars to be extracted, their lanes and their scalar users. \p
/// ExternallyUsedValues contains additional list of external uses to handle
@@ -1064,6 +1191,8 @@ public:
MinBWs.clear();
InstrElementSize.clear();
UserIgnoreList = nullptr;
+ PostponedGathers.clear();
+ ValueToGatherNodes.clear();
}
unsigned getTreeSize() const { return VectorizableTree.size(); }
@@ -1083,9 +1212,12 @@ public:
/// Gets reordering data for the given tree entry. If the entry is vectorized
/// - just return ReorderIndices, otherwise check if the scalars can be
/// reordered and return the most optimal order.
+ /// \return std::nullopt if ordering is not important, empty order, if
+ /// identity order is important, or the actual order.
/// \param TopToBottom If true, include the order of vectorized stores and
/// insertelement nodes, otherwise skip them.
- std::optional<OrdersType> getReorderingData(const TreeEntry &TE, bool TopToBottom);
+ std::optional<OrdersType> getReorderingData(const TreeEntry &TE,
+ bool TopToBottom);
/// Reorders the current graph to the most profitable order starting from the
/// root node to the leaf nodes. The best order is chosen only from the nodes
@@ -1328,8 +1460,14 @@ public:
ConstantInt *Ex1Idx;
if (match(V1, m_ExtractElt(m_Value(EV1), m_ConstantInt(Ex1Idx)))) {
// Undefs are always profitable for extractelements.
+ // Compiler can easily combine poison and extractelement <non-poison> or
+ // undef and extractelement <poison>. But combining undef +
+ // extractelement <non-poison-but-may-produce-poison> requires some
+ // extra operations.
if (isa<UndefValue>(V2))
- return LookAheadHeuristics::ScoreConsecutiveExtracts;
+ return (isa<PoisonValue>(V2) || isUndefVector(EV1).all())
+ ? LookAheadHeuristics::ScoreConsecutiveExtracts
+ : LookAheadHeuristics::ScoreSameOpcode;
Value *EV2 = nullptr;
ConstantInt *Ex2Idx = nullptr;
if (match(V2,
@@ -1683,9 +1821,10 @@ public:
// Search all operands in Ops[*][Lane] for the one that matches best
// Ops[OpIdx][LastLane] and return its opreand index.
// If no good match can be found, return std::nullopt.
- std::optional<unsigned> getBestOperand(unsigned OpIdx, int Lane, int LastLane,
- ArrayRef<ReorderingMode> ReorderingModes,
- ArrayRef<Value *> MainAltOps) {
+ std::optional<unsigned>
+ getBestOperand(unsigned OpIdx, int Lane, int LastLane,
+ ArrayRef<ReorderingMode> ReorderingModes,
+ ArrayRef<Value *> MainAltOps) {
unsigned NumOperands = getNumOperands();
// The operand of the previous lane at OpIdx.
@@ -2299,7 +2438,8 @@ private:
/// \returns the cost of the vectorizable entry.
InstructionCost getEntryCost(const TreeEntry *E,
- ArrayRef<Value *> VectorizedVals);
+ ArrayRef<Value *> VectorizedVals,
+ SmallPtrSetImpl<Value *> &CheckedExtracts);
/// This is the recursive part of buildTree.
void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth,
@@ -2323,15 +2463,13 @@ private:
/// Create a new vector from a list of scalar values. Produces a sequence
/// which exploits values reused across lanes, and arranges the inserts
/// for ease of later optimization.
- Value *createBuildVector(const TreeEntry *E);
+ template <typename BVTy, typename ResTy, typename... Args>
+ ResTy processBuildVector(const TreeEntry *E, Args &...Params);
- /// \returns the scalarization cost for this type. Scalarization in this
- /// context means the creation of vectors from a group of scalars. If \p
- /// NeedToShuffle is true, need to add a cost of reshuffling some of the
- /// vector elements.
- InstructionCost getGatherCost(FixedVectorType *Ty,
- const APInt &ShuffledIndices,
- bool NeedToShuffle) const;
+ /// Create a new vector from a list of scalar values. Produces a sequence
+ /// which exploits values reused across lanes, and arranges the inserts
+ /// for ease of later optimization.
+ Value *createBuildVector(const TreeEntry *E);
/// Returns the instruction in the bundle, which can be used as a base point
/// for scheduling. Usually it is the last instruction in the bundle, except
@@ -2354,14 +2492,16 @@ private:
/// \returns the scalarization cost for this list of values. Assuming that
/// this subtree gets vectorized, we may need to extract the values from the
/// roots. This method calculates the cost of extracting the values.
- InstructionCost getGatherCost(ArrayRef<Value *> VL) const;
+ /// \param ForPoisonSrc true if initial vector is poison, false otherwise.
+ InstructionCost getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc) const;
/// Set the Builder insert point to one after the last instruction in
/// the bundle
void setInsertPointAfterBundle(const TreeEntry *E);
- /// \returns a vector from a collection of scalars in \p VL.
- Value *gather(ArrayRef<Value *> VL);
+ /// \returns a vector from a collection of scalars in \p VL. if \p Root is not
+ /// specified, the starting vector value is poison.
+ Value *gather(ArrayRef<Value *> VL, Value *Root);
/// \returns whether the VectorizableTree is fully vectorizable and will
/// be beneficial even the tree height is tiny.
@@ -2400,6 +2540,14 @@ private:
using VecTreeTy = SmallVector<std::unique_ptr<TreeEntry>, 8>;
TreeEntry(VecTreeTy &Container) : Container(Container) {}
+ /// \returns Common mask for reorder indices and reused scalars.
+ SmallVector<int> getCommonMask() const {
+ SmallVector<int> Mask;
+ inversePermutation(ReorderIndices, Mask);
+ ::addMask(Mask, ReuseShuffleIndices);
+ return Mask;
+ }
+
/// \returns true if the scalars in VL are equal to this entry.
bool isSame(ArrayRef<Value *> VL) const {
auto &&IsSame = [VL](ArrayRef<Value *> Scalars, ArrayRef<int> Mask) {
@@ -2409,8 +2557,8 @@ private:
std::equal(VL.begin(), VL.end(), Mask.begin(),
[Scalars](Value *V, int Idx) {
return (isa<UndefValue>(V) &&
- Idx == UndefMaskElem) ||
- (Idx != UndefMaskElem && V == Scalars[Idx]);
+ Idx == PoisonMaskElem) ||
+ (Idx != PoisonMaskElem && V == Scalars[Idx]);
});
};
if (!ReorderIndices.empty()) {
@@ -2471,7 +2619,7 @@ private:
ValueList Scalars;
/// The Scalars are vectorized into this value. It is initialized to Null.
- Value *VectorizedValue = nullptr;
+ WeakTrackingVH VectorizedValue = nullptr;
/// Do we need to gather this sequence or vectorize it
/// (either with vector instruction or with scatter/gather
@@ -2684,20 +2832,22 @@ private:
#ifndef NDEBUG
void dumpTreeCosts(const TreeEntry *E, InstructionCost ReuseShuffleCost,
- InstructionCost VecCost,
- InstructionCost ScalarCost) const {
- dbgs() << "SLP: Calculated costs for Tree:\n"; E->dump();
+ InstructionCost VecCost, InstructionCost ScalarCost,
+ StringRef Banner) const {
+ dbgs() << "SLP: " << Banner << ":\n";
+ E->dump();
dbgs() << "SLP: Costs:\n";
dbgs() << "SLP: ReuseShuffleCost = " << ReuseShuffleCost << "\n";
dbgs() << "SLP: VectorCost = " << VecCost << "\n";
dbgs() << "SLP: ScalarCost = " << ScalarCost << "\n";
- dbgs() << "SLP: ReuseShuffleCost + VecCost - ScalarCost = " <<
- ReuseShuffleCost + VecCost - ScalarCost << "\n";
+ dbgs() << "SLP: ReuseShuffleCost + VecCost - ScalarCost = "
+ << ReuseShuffleCost + VecCost - ScalarCost << "\n";
}
#endif
/// Create a new VectorizableTree entry.
- TreeEntry *newTreeEntry(ArrayRef<Value *> VL, std::optional<ScheduleData *> Bundle,
+ TreeEntry *newTreeEntry(ArrayRef<Value *> VL,
+ std::optional<ScheduleData *> Bundle,
const InstructionsState &S,
const EdgeInfo &UserTreeIdx,
ArrayRef<int> ReuseShuffleIndices = std::nullopt,
@@ -2791,8 +2941,14 @@ private:
return ScalarToTreeEntry.lookup(V);
}
+ /// Checks if the specified list of the instructions/values can be vectorized
+ /// and fills required data before actual scheduling of the instructions.
+ TreeEntry::EntryState getScalarsVectorizationState(
+ InstructionsState &S, ArrayRef<Value *> VL, bool IsScatterVectorizeUserTE,
+ OrdersType &CurrentOrder, SmallVectorImpl<Value *> &PointerOps) const;
+
/// Maps a specific scalar to its tree entry.
- SmallDenseMap<Value*, TreeEntry *> ScalarToTreeEntry;
+ SmallDenseMap<Value *, TreeEntry *> ScalarToTreeEntry;
/// Maps a value to the proposed vectorizable size.
SmallDenseMap<Value *, unsigned> InstrElementSize;
@@ -2808,6 +2964,15 @@ private:
/// pre-gather them before.
DenseMap<const TreeEntry *, Instruction *> EntryToLastInstruction;
+ /// List of gather nodes, depending on other gather/vector nodes, which should
+ /// be emitted after the vector instruction emission process to correctly
+ /// handle order of the vector instructions and shuffles.
+ SetVector<const TreeEntry *> PostponedGathers;
+
+ using ValueToGatherNodesMap =
+ DenseMap<Value *, SmallPtrSet<const TreeEntry *, 4>>;
+ ValueToGatherNodesMap ValueToGatherNodes;
+
/// This POD struct describes one external user in the vectorized tree.
struct ExternalUser {
ExternalUser(Value *S, llvm::User *U, int L)
@@ -3235,7 +3400,6 @@ private:
<< "SLP: gets ready (ctl): " << *DepBundle << "\n");
}
}
-
}
}
@@ -3579,7 +3743,7 @@ static void reorderReuses(SmallVectorImpl<int> &Reuses, ArrayRef<int> Mask) {
SmallVector<int> Prev(Reuses.begin(), Reuses.end());
Prev.swap(Reuses);
for (unsigned I = 0, E = Prev.size(); I < E; ++I)
- if (Mask[I] != UndefMaskElem)
+ if (Mask[I] != PoisonMaskElem)
Reuses[Mask[I]] = Prev[I];
}
@@ -3603,7 +3767,7 @@ static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) {
}
Order.assign(Mask.size(), Mask.size());
for (unsigned I = 0, E = Mask.size(); I < E; ++I)
- if (MaskOrder[I] != UndefMaskElem)
+ if (MaskOrder[I] != PoisonMaskElem)
Order[MaskOrder[I]] = I;
fixupOrderingIndices(Order);
}
@@ -3653,10 +3817,8 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
return false;
return true;
};
- if (IsIdentityOrder(CurrentOrder)) {
- CurrentOrder.clear();
- return CurrentOrder;
- }
+ if (IsIdentityOrder(CurrentOrder))
+ return OrdersType();
auto *It = CurrentOrder.begin();
for (unsigned I = 0; I < NumScalars;) {
if (UsedPositions.test(I)) {
@@ -3669,7 +3831,7 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
}
++It;
}
- return CurrentOrder;
+ return std::move(CurrentOrder);
}
return std::nullopt;
}
@@ -3779,9 +3941,9 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
return LoadsState::Gather;
}
-bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy,
- const DataLayout &DL, ScalarEvolution &SE,
- SmallVectorImpl<unsigned> &SortedIndices) {
+static bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy,
+ const DataLayout &DL, ScalarEvolution &SE,
+ SmallVectorImpl<unsigned> &SortedIndices) {
assert(llvm::all_of(
VL, [](const Value *V) { return V->getType()->isPointerTy(); }) &&
"Expected list of pointer operands.");
@@ -3825,7 +3987,7 @@ bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy,
return std::get<1>(X) < std::get<1>(Y);
});
int InitialOffset = std::get<1>(Vec[0]);
- AnyConsecutive |= all_of(enumerate(Vec), [InitialOffset](auto &P) {
+ AnyConsecutive |= all_of(enumerate(Vec), [InitialOffset](const auto &P) {
return std::get<1>(P.value()) == int(P.index()) + InitialOffset;
});
}
@@ -3862,7 +4024,7 @@ BoUpSLP::findPartiallyOrderedLoads(const BoUpSLP::TreeEntry &TE) {
BoUpSLP::OrdersType Order;
if (clusterSortPtrAccesses(Ptrs, ScalarTy, *DL, *SE, Order))
- return Order;
+ return std::move(Order);
return std::nullopt;
}
@@ -3888,31 +4050,35 @@ static bool areTwoInsertFromSameBuildVector(
// Go through the vector operand of insertelement instructions trying to find
// either VU as the original vector for IE2 or V as the original vector for
// IE1.
+ SmallSet<int, 8> ReusedIdx;
+ bool IsReusedIdx = false;
do {
- if (IE2 == VU)
+ if (IE2 == VU && !IE1)
return VU->hasOneUse();
- if (IE1 == V)
+ if (IE1 == V && !IE2)
return V->hasOneUse();
- if (IE1) {
- if ((IE1 != VU && !IE1->hasOneUse()) ||
- getInsertIndex(IE1).value_or(*Idx2) == *Idx2)
+ if (IE1 && IE1 != V) {
+ IsReusedIdx |=
+ !ReusedIdx.insert(getInsertIndex(IE1).value_or(*Idx2)).second;
+ if ((IE1 != VU && !IE1->hasOneUse()) || IsReusedIdx)
IE1 = nullptr;
else
IE1 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE1));
}
- if (IE2) {
- if ((IE2 != V && !IE2->hasOneUse()) ||
- getInsertIndex(IE2).value_or(*Idx1) == *Idx1)
+ if (IE2 && IE2 != VU) {
+ IsReusedIdx |=
+ !ReusedIdx.insert(getInsertIndex(IE2).value_or(*Idx1)).second;
+ if ((IE2 != V && !IE2->hasOneUse()) || IsReusedIdx)
IE2 = nullptr;
else
IE2 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE2));
}
- } while (IE1 || IE2);
+ } while (!IsReusedIdx && (IE1 || IE2));
return false;
}
-std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE,
- bool TopToBottom) {
+std::optional<BoUpSLP::OrdersType>
+BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
// No need to reorder if need to shuffle reuses, still need to shuffle the
// node.
if (!TE.ReuseShuffleIndices.empty()) {
@@ -3936,14 +4102,14 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T
std::optional<unsigned> Idx = getExtractIndex(cast<Instruction>(V));
return Idx && *Idx < Sz;
})) {
- SmallVector<int> ReorderMask(Sz, UndefMaskElem);
+ SmallVector<int> ReorderMask(Sz, PoisonMaskElem);
if (TE.ReorderIndices.empty())
std::iota(ReorderMask.begin(), ReorderMask.end(), 0);
else
inversePermutation(TE.ReorderIndices, ReorderMask);
for (unsigned I = 0; I < VF; ++I) {
int &Idx = ReusedMask[I];
- if (Idx == UndefMaskElem)
+ if (Idx == PoisonMaskElem)
continue;
Value *V = TE.Scalars[ReorderMask[Idx]];
std::optional<unsigned> EI = getExtractIndex(cast<Instruction>(V));
@@ -3958,7 +4124,7 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T
for (unsigned K = 0; K < VF; K += Sz) {
OrdersType CurrentOrder(TE.ReorderIndices);
SmallVector<int> SubMask{ArrayRef(ReusedMask).slice(K, Sz)};
- if (SubMask.front() == UndefMaskElem)
+ if (SubMask.front() == PoisonMaskElem)
std::iota(SubMask.begin(), SubMask.end(), 0);
reorderOrder(CurrentOrder, SubMask);
transform(CurrentOrder, It, [K](unsigned Pos) { return Pos + K; });
@@ -3966,8 +4132,8 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T
}
if (all_of(enumerate(ResOrder),
[](const auto &Data) { return Data.index() == Data.value(); }))
- return {}; // Use identity order.
- return ResOrder;
+ return std::nullopt; // No need to reorder.
+ return std::move(ResOrder);
}
if (TE.State == TreeEntry::Vectorize &&
(isa<LoadInst, ExtractElementInst, ExtractValueInst>(TE.getMainOp()) ||
@@ -3976,6 +4142,8 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T
return TE.ReorderIndices;
if (TE.State == TreeEntry::Vectorize && TE.getOpcode() == Instruction::PHI) {
auto PHICompare = [](llvm::Value *V1, llvm::Value *V2) {
+ if (V1 == V2)
+ return false;
if (!V1->hasOneUse() || !V2->hasOneUse())
return false;
auto *FirstUserOfPhi1 = cast<Instruction>(*V1->user_begin());
@@ -4023,8 +4191,8 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T
for (unsigned Id = 0, Sz = Phis.size(); Id < Sz; ++Id)
ResOrder[Id] = PhiToId[Phis[Id]];
if (IsIdentityOrder(ResOrder))
- return {};
- return ResOrder;
+ return std::nullopt; // No need to reorder.
+ return std::move(ResOrder);
}
if (TE.State == TreeEntry::NeedToGather) {
// TODO: add analysis of other gather nodes with extractelement
@@ -4050,7 +4218,42 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T
if (Reuse || !CurrentOrder.empty()) {
if (!CurrentOrder.empty())
fixupOrderingIndices(CurrentOrder);
- return CurrentOrder;
+ return std::move(CurrentOrder);
+ }
+ }
+ // If the gather node is <undef, v, .., poison> and
+ // insertelement poison, v, 0 [+ permute]
+ // is cheaper than
+ // insertelement poison, v, n - try to reorder.
+ // If rotating the whole graph, exclude the permute cost, the whole graph
+ // might be transformed.
+ int Sz = TE.Scalars.size();
+ if (isSplat(TE.Scalars) && !allConstant(TE.Scalars) &&
+ count_if(TE.Scalars, UndefValue::classof) == Sz - 1) {
+ const auto *It =
+ find_if(TE.Scalars, [](Value *V) { return !isConstant(V); });
+ if (It == TE.Scalars.begin())
+ return OrdersType();
+ auto *Ty = FixedVectorType::get(TE.Scalars.front()->getType(), Sz);
+ if (It != TE.Scalars.end()) {
+ OrdersType Order(Sz, Sz);
+ unsigned Idx = std::distance(TE.Scalars.begin(), It);
+ Order[Idx] = 0;
+ fixupOrderingIndices(Order);
+ SmallVector<int> Mask;
+ inversePermutation(Order, Mask);
+ InstructionCost PermuteCost =
+ TopToBottom
+ ? 0
+ : TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, Mask);
+ InstructionCost InsertFirstCost = TTI->getVectorInstrCost(
+ Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, 0,
+ PoisonValue::get(Ty), *It);
+ InstructionCost InsertIdxCost = TTI->getVectorInstrCost(
+ Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, Idx,
+ PoisonValue::get(Ty), *It);
+ if (InsertFirstCost + PermuteCost < InsertIdxCost)
+ return std::move(Order);
}
}
if (std::optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE))
@@ -4260,7 +4463,7 @@ void BoUpSLP::reorderTopToBottom() {
unsigned E = Order.size();
OrdersType CurrentOrder(E, E);
transform(Mask, CurrentOrder.begin(), [E](int Idx) {
- return Idx == UndefMaskElem ? E : static_cast<unsigned>(Idx);
+ return Idx == PoisonMaskElem ? E : static_cast<unsigned>(Idx);
});
fixupOrderingIndices(CurrentOrder);
++OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second;
@@ -4285,10 +4488,10 @@ void BoUpSLP::reorderTopToBottom() {
continue;
SmallVector<int> Mask;
inversePermutation(BestOrder, Mask);
- SmallVector<int> MaskOrder(BestOrder.size(), UndefMaskElem);
+ SmallVector<int> MaskOrder(BestOrder.size(), PoisonMaskElem);
unsigned E = BestOrder.size();
transform(BestOrder, MaskOrder.begin(), [E](unsigned I) {
- return I < E ? static_cast<int>(I) : UndefMaskElem;
+ return I < E ? static_cast<int>(I) : PoisonMaskElem;
});
// Do an actual reordering, if profitable.
for (std::unique_ptr<TreeEntry> &TE : VectorizableTree) {
@@ -4384,7 +4587,7 @@ bool BoUpSLP::canReorderOperands(
}
return false;
}) > 1 &&
- !all_of(UserTE->getOperand(I), isConstant))
+ !allConstant(UserTE->getOperand(I)))
return false;
if (Gather)
GatherOps.push_back(Gather);
@@ -4499,7 +4702,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
unsigned E = Order.size();
OrdersType CurrentOrder(E, E);
transform(Mask, CurrentOrder.begin(), [E](int Idx) {
- return Idx == UndefMaskElem ? E : static_cast<unsigned>(Idx);
+ return Idx == PoisonMaskElem ? E : static_cast<unsigned>(Idx);
});
fixupOrderingIndices(CurrentOrder);
OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second +=
@@ -4578,10 +4781,10 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
VisitedOps.clear();
SmallVector<int> Mask;
inversePermutation(BestOrder, Mask);
- SmallVector<int> MaskOrder(BestOrder.size(), UndefMaskElem);
+ SmallVector<int> MaskOrder(BestOrder.size(), PoisonMaskElem);
unsigned E = BestOrder.size();
transform(BestOrder, MaskOrder.begin(), [E](unsigned I) {
- return I < E ? static_cast<int>(I) : UndefMaskElem;
+ return I < E ? static_cast<int>(I) : PoisonMaskElem;
});
for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) {
TreeEntry *TE = Op.second;
@@ -4779,7 +4982,7 @@ bool BoUpSLP::canFormVector(const SmallVector<StoreInst *, 4> &StoresVec,
// Check if the stores are consecutive by checking if their difference is 1.
for (unsigned Idx : seq<unsigned>(1, StoreOffsetVec.size()))
- if (StoreOffsetVec[Idx].second != StoreOffsetVec[Idx-1].second + 1)
+ if (StoreOffsetVec[Idx].second != StoreOffsetVec[Idx - 1].second + 1)
return false;
// Calculate the shuffle indices according to their offset against the sorted
@@ -4976,6 +5179,309 @@ static bool isAlternateInstruction(const Instruction *I,
const Instruction *AltOp,
const TargetLibraryInfo &TLI);
+BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
+ InstructionsState &S, ArrayRef<Value *> VL, bool IsScatterVectorizeUserTE,
+ OrdersType &CurrentOrder, SmallVectorImpl<Value *> &PointerOps) const {
+ assert(S.MainOp && "Expected instructions with same/alternate opcodes only.");
+
+ unsigned ShuffleOrOp =
+ S.isAltShuffle() ? (unsigned)Instruction::ShuffleVector : S.getOpcode();
+ auto *VL0 = cast<Instruction>(S.OpValue);
+ switch (ShuffleOrOp) {
+ case Instruction::PHI: {
+ // Check for terminator values (e.g. invoke).
+ for (Value *V : VL)
+ for (Value *Incoming : cast<PHINode>(V)->incoming_values()) {
+ Instruction *Term = dyn_cast<Instruction>(Incoming);
+ if (Term && Term->isTerminator()) {
+ LLVM_DEBUG(dbgs()
+ << "SLP: Need to swizzle PHINodes (terminator use).\n");
+ return TreeEntry::NeedToGather;
+ }
+ }
+
+ return TreeEntry::Vectorize;
+ }
+ case Instruction::ExtractValue:
+ case Instruction::ExtractElement: {
+ bool Reuse = canReuseExtract(VL, VL0, CurrentOrder);
+ if (Reuse || !CurrentOrder.empty())
+ return TreeEntry::Vectorize;
+ LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n");
+ return TreeEntry::NeedToGather;
+ }
+ case Instruction::InsertElement: {
+ // Check that we have a buildvector and not a shuffle of 2 or more
+ // different vectors.
+ ValueSet SourceVectors;
+ for (Value *V : VL) {
+ SourceVectors.insert(cast<Instruction>(V)->getOperand(0));
+ assert(getInsertIndex(V) != std::nullopt &&
+ "Non-constant or undef index?");
+ }
+
+ if (count_if(VL, [&SourceVectors](Value *V) {
+ return !SourceVectors.contains(V);
+ }) >= 2) {
+ // Found 2nd source vector - cancel.
+ LLVM_DEBUG(dbgs() << "SLP: Gather of insertelement vectors with "
+ "different source vectors.\n");
+ return TreeEntry::NeedToGather;
+ }
+
+ return TreeEntry::Vectorize;
+ }
+ case Instruction::Load: {
+ // Check that a vectorized load would load the same memory as a scalar
+ // load. For example, we don't want to vectorize loads that are smaller
+ // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
+ // treats loading/storing it as an i8 struct. If we vectorize loads/stores
+ // from such a struct, we read/write packed bits disagreeing with the
+ // unvectorized version.
+ switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, CurrentOrder,
+ PointerOps)) {
+ case LoadsState::Vectorize:
+ return TreeEntry::Vectorize;
+ case LoadsState::ScatterVectorize:
+ return TreeEntry::ScatterVectorize;
+ case LoadsState::Gather:
+#ifndef NDEBUG
+ Type *ScalarTy = VL0->getType();
+ if (DL->getTypeSizeInBits(ScalarTy) !=
+ DL->getTypeAllocSizeInBits(ScalarTy))
+ LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n");
+ else if (any_of(VL,
+ [](Value *V) { return !cast<LoadInst>(V)->isSimple(); }))
+ LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n");
+ else
+ LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n");
+#endif // NDEBUG
+ return TreeEntry::NeedToGather;
+ }
+ llvm_unreachable("Unexpected state of loads");
+ }
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::FPExt:
+ case Instruction::PtrToInt:
+ case Instruction::IntToPtr:
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ case Instruction::Trunc:
+ case Instruction::FPTrunc:
+ case Instruction::BitCast: {
+ Type *SrcTy = VL0->getOperand(0)->getType();
+ for (Value *V : VL) {
+ Type *Ty = cast<Instruction>(V)->getOperand(0)->getType();
+ if (Ty != SrcTy || !isValidElementType(Ty)) {
+ LLVM_DEBUG(
+ dbgs() << "SLP: Gathering casts with different src types.\n");
+ return TreeEntry::NeedToGather;
+ }
+ }
+ return TreeEntry::Vectorize;
+ }
+ case Instruction::ICmp:
+ case Instruction::FCmp: {
+ // Check that all of the compares have the same predicate.
+ CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate();
+ CmpInst::Predicate SwapP0 = CmpInst::getSwappedPredicate(P0);
+ Type *ComparedTy = VL0->getOperand(0)->getType();
+ for (Value *V : VL) {
+ CmpInst *Cmp = cast<CmpInst>(V);
+ if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) ||
+ Cmp->getOperand(0)->getType() != ComparedTy) {
+ LLVM_DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n");
+ return TreeEntry::NeedToGather;
+ }
+ }
+ return TreeEntry::Vectorize;
+ }
+ case Instruction::Select:
+ case Instruction::FNeg:
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::FDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ return TreeEntry::Vectorize;
+ case Instruction::GetElementPtr: {
+ // We don't combine GEPs with complicated (nested) indexing.
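+    // Only GEPs with a single index (two operands overall) are kept, e.g.
+    //   getelementptr inbounds i32, ptr %p, i64 %i
+    // while a struct access such as "getelementptr %struct.S, ptr %p, i64 0,
+    // i32 1" has nested indexing and forces a gather.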
+ for (Value *V : VL) {
+ auto *I = dyn_cast<GetElementPtrInst>(V);
+ if (!I)
+ continue;
+ if (I->getNumOperands() != 2) {
+ LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n");
+ return TreeEntry::NeedToGather;
+ }
+ }
+
+ // We can't combine several GEPs into one vector if they operate on
+ // different types.
+ Type *Ty0 = cast<GEPOperator>(VL0)->getSourceElementType();
+ for (Value *V : VL) {
+ auto *GEP = dyn_cast<GEPOperator>(V);
+ if (!GEP)
+ continue;
+ Type *CurTy = GEP->getSourceElementType();
+ if (Ty0 != CurTy) {
+ LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n");
+ return TreeEntry::NeedToGather;
+ }
+ }
+
+ // We don't combine GEPs with non-constant indexes.
+ Type *Ty1 = VL0->getOperand(1)->getType();
+ for (Value *V : VL) {
+ auto *I = dyn_cast<GetElementPtrInst>(V);
+ if (!I)
+ continue;
+ auto *Op = I->getOperand(1);
+ if ((!IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) ||
+ (Op->getType() != Ty1 &&
+ ((IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) ||
+ Op->getType()->getScalarSizeInBits() >
+ DL->getIndexSizeInBits(
+ V->getType()->getPointerAddressSpace())))) {
+ LLVM_DEBUG(
+ dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n");
+ return TreeEntry::NeedToGather;
+ }
+ }
+
+ return TreeEntry::Vectorize;
+ }
+ case Instruction::Store: {
+ // Check if the stores are consecutive or if we need to swizzle them.
+ llvm::Type *ScalarTy = cast<StoreInst>(VL0)->getValueOperand()->getType();
+ // Avoid types that are padded when being allocated as scalars, while
+ // being packed together in a vector (such as i1).
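+    // E.g. for i1 the type size is 1 bit but the alloc size is 8 bits: scalar
+    // stores write whole bytes, while a vector of i1 would be bit-packed, so
+    // the scalar and vector forms would not write the same memory.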
+ if (DL->getTypeSizeInBits(ScalarTy) !=
+ DL->getTypeAllocSizeInBits(ScalarTy)) {
+ LLVM_DEBUG(dbgs() << "SLP: Gathering stores of non-packed type.\n");
+ return TreeEntry::NeedToGather;
+ }
+ // Make sure all stores in the bundle are simple - we can't vectorize
+ // atomic or volatile stores.
+ for (Value *V : VL) {
+ auto *SI = cast<StoreInst>(V);
+ if (!SI->isSimple()) {
+ LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple stores.\n");
+ return TreeEntry::NeedToGather;
+ }
+ PointerOps.push_back(SI->getPointerOperand());
+ }
+
+ // Check the order of pointer operands.
+ if (llvm::sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, CurrentOrder)) {
+ Value *Ptr0;
+ Value *PtrN;
+ if (CurrentOrder.empty()) {
+ Ptr0 = PointerOps.front();
+ PtrN = PointerOps.back();
+ } else {
+ Ptr0 = PointerOps[CurrentOrder.front()];
+ PtrN = PointerOps[CurrentOrder.back()];
+ }
+ std::optional<int> Dist =
+ getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE);
+ // Check that the sorted pointer operands are consecutive.
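+      // E.g. a 4-wide bundle is consecutive only if the (sorted) pointers are
+      // p, p+1, p+2, p+3 in units of the scalar type, i.e. the distance between
+      // the first and the last pointer is exactly VL.size() - 1.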
+ if (static_cast<unsigned>(*Dist) == VL.size() - 1)
+ return TreeEntry::Vectorize;
+ }
+
+ LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n");
+ return TreeEntry::NeedToGather;
+ }
+ case Instruction::Call: {
+ // Check if the calls are all to the same vectorizable intrinsic or
+ // library function.
+ CallInst *CI = cast<CallInst>(VL0);
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+
+ VFShape Shape = VFShape::get(
+ *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
+ false /*HasGlobalPred*/);
+ Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
+
+ if (!VecFunc && !isTriviallyVectorizable(ID)) {
+ LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
+ return TreeEntry::NeedToGather;
+ }
+ Function *F = CI->getCalledFunction();
+ unsigned NumArgs = CI->arg_size();
+ SmallVector<Value *, 4> ScalarArgs(NumArgs, nullptr);
+ for (unsigned J = 0; J != NumArgs; ++J)
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, J))
+ ScalarArgs[J] = CI->getArgOperand(J);
+ for (Value *V : VL) {
+ CallInst *CI2 = dyn_cast<CallInst>(V);
+ if (!CI2 || CI2->getCalledFunction() != F ||
+ getVectorIntrinsicIDForCall(CI2, TLI) != ID ||
+ (VecFunc &&
+ VecFunc != VFDatabase(*CI2).getVectorizedFunction(Shape)) ||
+ !CI->hasIdenticalOperandBundleSchema(*CI2)) {
+ LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *V
+ << "\n");
+ return TreeEntry::NeedToGather;
+ }
+      // Some intrinsics have scalar arguments, and these must be the same
+      // across all calls for the bundle to be vectorized.
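+      // E.g. the integer exponent of llvm.powi.* is such a scalar operand and
+      // must match across the whole bundle.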
+ for (unsigned J = 0; J != NumArgs; ++J) {
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
+ Value *A1J = CI2->getArgOperand(J);
+ if (ScalarArgs[J] != A1J) {
+ LLVM_DEBUG(dbgs()
+ << "SLP: mismatched arguments in call:" << *CI
+ << " argument " << ScalarArgs[J] << "!=" << A1J << "\n");
+ return TreeEntry::NeedToGather;
+ }
+ }
+ }
+ // Verify that the bundle operands are identical between the two calls.
+ if (CI->hasOperandBundles() &&
+ !std::equal(CI->op_begin() + CI->getBundleOperandsStartIndex(),
+ CI->op_begin() + CI->getBundleOperandsEndIndex(),
+ CI2->op_begin() + CI2->getBundleOperandsStartIndex())) {
+ LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI
+ << "!=" << *V << '\n');
+ return TreeEntry::NeedToGather;
+ }
+ }
+
+ return TreeEntry::Vectorize;
+ }
+ case Instruction::ShuffleVector: {
+ // If this is not an alternate sequence of opcode like add-sub
+ // then do not vectorize this instruction.
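+    // E.g. a bundle {a0 + b0, a1 - b1, a2 + b2, a3 - b3} uses the two opcodes
+    // add/sub and is emitted as two vector instructions blended by a
+    // shufflevector; any other mix of opcodes is gathered.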
+ if (!S.isAltShuffle()) {
+ LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n");
+ return TreeEntry::NeedToGather;
+ }
+ return TreeEntry::Vectorize;
+ }
+ default:
+ LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n");
+ return TreeEntry::NeedToGather;
+ }
+}
+
void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
const EdgeInfo &UserTreeIdx) {
assert((allConstant(VL) || allSameType(VL)) && "Invalid types!");
@@ -4990,7 +5496,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
for (Value *V : VL) {
if (isConstant(V)) {
ReuseShuffleIndicies.emplace_back(
- isa<UndefValue>(V) ? UndefMaskElem : UniqueValues.size());
+ isa<UndefValue>(V) ? PoisonMaskElem : UniqueValues.size());
UniqueValues.emplace_back(V);
continue;
}
@@ -5010,7 +5516,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
return isa<UndefValue>(V) ||
!isConstant(V);
})) ||
- !llvm::isPowerOf2_32(NumUniqueScalarValues)) {
+ !llvm::has_single_bit<uint32_t>(NumUniqueScalarValues)) {
LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n");
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
return false;
@@ -5257,6 +5763,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (!TryToFindDuplicates(S))
return;
+ // Perform specific checks for each particular instruction kind.
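+  // getScalarsVectorizationState() hoists the per-opcode legality checks out
+  // of the large switch below; if it reports NeedToGather, the bundle is
+  // recorded as a gather node before any scheduling is attempted.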
+ OrdersType CurrentOrder;
+ SmallVector<Value *> PointerOps;
+ TreeEntry::EntryState State = getScalarsVectorizationState(
+ S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps);
+ if (State == TreeEntry::NeedToGather) {
+ newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
+ return;
+ }
+
auto &BSRef = BlocksSchedules[BB];
if (!BSRef)
BSRef = std::make_unique<BlockScheduling>(BB);
@@ -5285,20 +5802,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
case Instruction::PHI: {
auto *PH = cast<PHINode>(VL0);
- // Check for terminator values (e.g. invoke).
- for (Value *V : VL)
- for (Value *Incoming : cast<PHINode>(V)->incoming_values()) {
- Instruction *Term = dyn_cast<Instruction>(Incoming);
- if (Term && Term->isTerminator()) {
- LLVM_DEBUG(dbgs()
- << "SLP: Need to swizzle PHINodes (terminator use).\n");
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- return;
- }
- }
-
TreeEntry *TE =
newTreeEntry(VL, Bundle, S, UserTreeIdx, ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n");
@@ -5326,9 +5829,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
case Instruction::ExtractValue:
case Instruction::ExtractElement: {
- OrdersType CurrentOrder;
- bool Reuse = canReuseExtract(VL, VL0, CurrentOrder);
- if (Reuse) {
+ if (CurrentOrder.empty()) {
LLVM_DEBUG(dbgs() << "SLP: Reusing or shuffling extract sequence.\n");
newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
@@ -5339,55 +5840,28 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
VectorizableTree.back()->setOperand(0, Op0);
return;
}
- if (!CurrentOrder.empty()) {
- LLVM_DEBUG({
- dbgs() << "SLP: Reusing or shuffling of reordered extract sequence "
- "with order";
- for (unsigned Idx : CurrentOrder)
- dbgs() << " " << Idx;
- dbgs() << "\n";
- });
- fixupOrderingIndices(CurrentOrder);
- // Insert new order with initial value 0, if it does not exist,
- // otherwise return the iterator to the existing one.
- newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies, CurrentOrder);
- // This is a special case, as it does not gather, but at the same time
- // we are not extending buildTree_rec() towards the operands.
- ValueList Op0;
- Op0.assign(VL.size(), VL0->getOperand(0));
- VectorizableTree.back()->setOperand(0, Op0);
- return;
- }
- LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n");
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- BS.cancelScheduling(VL, VL0);
+ LLVM_DEBUG({
+ dbgs() << "SLP: Reusing or shuffling of reordered extract sequence "
+ "with order";
+ for (unsigned Idx : CurrentOrder)
+ dbgs() << " " << Idx;
+ dbgs() << "\n";
+ });
+ fixupOrderingIndices(CurrentOrder);
+ // Insert new order with initial value 0, if it does not exist,
+ // otherwise return the iterator to the existing one.
+ newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies, CurrentOrder);
+ // This is a special case, as it does not gather, but at the same time
+ // we are not extending buildTree_rec() towards the operands.
+ ValueList Op0;
+ Op0.assign(VL.size(), VL0->getOperand(0));
+ VectorizableTree.back()->setOperand(0, Op0);
return;
}
case Instruction::InsertElement: {
assert(ReuseShuffleIndicies.empty() && "All inserts should be unique");
- // Check that we have a buildvector and not a shuffle of 2 or more
- // different vectors.
- ValueSet SourceVectors;
- for (Value *V : VL) {
- SourceVectors.insert(cast<Instruction>(V)->getOperand(0));
- assert(getInsertIndex(V) != std::nullopt &&
- "Non-constant or undef index?");
- }
-
- if (count_if(VL, [&SourceVectors](Value *V) {
- return !SourceVectors.contains(V);
- }) >= 2) {
- // Found 2nd source vector - cancel.
- LLVM_DEBUG(dbgs() << "SLP: Gather of insertelement vectors with "
- "different source vectors.\n");
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
- BS.cancelScheduling(VL, VL0);
- return;
- }
-
auto OrdCompare = [](const std::pair<int, int> &P1,
const std::pair<int, int> &P2) {
return P1.first > P2.first;
@@ -5430,12 +5904,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// treats loading/storing it as an i8 struct. If we vectorize loads/stores
// from such a struct, we read/write packed bits disagreeing with the
// unvectorized version.
- SmallVector<Value *> PointerOps;
- OrdersType CurrentOrder;
TreeEntry *TE = nullptr;
- switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI,
- CurrentOrder, PointerOps)) {
- case LoadsState::Vectorize:
+ switch (State) {
+ case TreeEntry::Vectorize:
if (CurrentOrder.empty()) {
          // Original loads are consecutive and do not require reordering.
TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
@@ -5450,7 +5921,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
TE->setOperandsInOrder();
break;
- case LoadsState::ScatterVectorize:
+ case TreeEntry::ScatterVectorize:
// Vectorizing non-consecutive loads with `llvm.masked.gather`.
TE = newTreeEntry(VL, TreeEntry::ScatterVectorize, Bundle, S,
UserTreeIdx, ReuseShuffleIndicies);
@@ -5458,23 +5929,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
buildTree_rec(PointerOps, Depth + 1, {TE, 0});
LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n");
break;
- case LoadsState::Gather:
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
-#ifndef NDEBUG
- Type *ScalarTy = VL0->getType();
- if (DL->getTypeSizeInBits(ScalarTy) !=
- DL->getTypeAllocSizeInBits(ScalarTy))
- LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n");
- else if (any_of(VL, [](Value *V) {
- return !cast<LoadInst>(V)->isSimple();
- }))
- LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n");
- else
- LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n");
-#endif // NDEBUG
- break;
+ case TreeEntry::NeedToGather:
+ llvm_unreachable("Unexpected loads state.");
}
return;
}
@@ -5490,18 +5946,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast: {
- Type *SrcTy = VL0->getOperand(0)->getType();
- for (Value *V : VL) {
- Type *Ty = cast<Instruction>(V)->getOperand(0)->getType();
- if (Ty != SrcTy || !isValidElementType(Ty)) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs()
- << "SLP: Gathering casts with different src types.\n");
- return;
- }
- }
TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n");
@@ -5521,21 +5965,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
case Instruction::FCmp: {
// Check that all of the compares have the same predicate.
CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate();
- CmpInst::Predicate SwapP0 = CmpInst::getSwappedPredicate(P0);
- Type *ComparedTy = VL0->getOperand(0)->getType();
- for (Value *V : VL) {
- CmpInst *Cmp = cast<CmpInst>(V);
- if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) ||
- Cmp->getOperand(0)->getType() != ComparedTy) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs()
- << "SLP: Gathering cmp with different predicate.\n");
- return;
- }
- }
-
TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n");
@@ -5544,7 +5973,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (cast<CmpInst>(VL0)->isCommutative()) {
// Commutative predicate - collect + sort operands of the instructions
// so that each side is more likely to have the same opcode.
- assert(P0 == SwapP0 && "Commutative Predicate mismatch");
+ assert(P0 == CmpInst::getSwappedPredicate(P0) &&
+ "Commutative Predicate mismatch");
reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, *this);
} else {
// Collect operands - commute if it uses the swapped predicate.
@@ -5612,60 +6042,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
return;
}
case Instruction::GetElementPtr: {
- // We don't combine GEPs with complicated (nested) indexing.
- for (Value *V : VL) {
- auto *I = dyn_cast<GetElementPtrInst>(V);
- if (!I)
- continue;
- if (I->getNumOperands() != 2) {
- LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n");
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- return;
- }
- }
-
- // We can't combine several GEPs into one vector if they operate on
- // different types.
- Type *Ty0 = cast<GEPOperator>(VL0)->getSourceElementType();
- for (Value *V : VL) {
- auto *GEP = dyn_cast<GEPOperator>(V);
- if (!GEP)
- continue;
- Type *CurTy = GEP->getSourceElementType();
- if (Ty0 != CurTy) {
- LLVM_DEBUG(dbgs()
- << "SLP: not-vectorizable GEP (different types).\n");
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- return;
- }
- }
-
- // We don't combine GEPs with non-constant indexes.
- Type *Ty1 = VL0->getOperand(1)->getType();
- for (Value *V : VL) {
- auto *I = dyn_cast<GetElementPtrInst>(V);
- if (!I)
- continue;
- auto *Op = I->getOperand(1);
- if ((!IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) ||
- (Op->getType() != Ty1 &&
- ((IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) ||
- Op->getType()->getScalarSizeInBits() >
- DL->getIndexSizeInBits(
- V->getType()->getPointerAddressSpace())))) {
- LLVM_DEBUG(dbgs()
- << "SLP: not-vectorizable GEP (non-constant indexes).\n");
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- return;
- }
- }
-
TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n");
@@ -5722,78 +6098,29 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
case Instruction::Store: {
// Check if the stores are consecutive or if we need to swizzle them.
- llvm::Type *ScalarTy = cast<StoreInst>(VL0)->getValueOperand()->getType();
- // Avoid types that are padded when being allocated as scalars, while
- // being packed together in a vector (such as i1).
- if (DL->getTypeSizeInBits(ScalarTy) !=
- DL->getTypeAllocSizeInBits(ScalarTy)) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: Gathering stores of non-packed type.\n");
- return;
- }
- // Make sure all stores in the bundle are simple - we can't vectorize
- // atomic or volatile stores.
- SmallVector<Value *, 4> PointerOps(VL.size());
ValueList Operands(VL.size());
- auto POIter = PointerOps.begin();
- auto OIter = Operands.begin();
+ auto *OIter = Operands.begin();
for (Value *V : VL) {
auto *SI = cast<StoreInst>(V);
- if (!SI->isSimple()) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple stores.\n");
- return;
- }
- *POIter = SI->getPointerOperand();
*OIter = SI->getValueOperand();
- ++POIter;
++OIter;
}
-
- OrdersType CurrentOrder;
- // Check the order of pointer operands.
- if (llvm::sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, CurrentOrder)) {
- Value *Ptr0;
- Value *PtrN;
- if (CurrentOrder.empty()) {
- Ptr0 = PointerOps.front();
- PtrN = PointerOps.back();
- } else {
- Ptr0 = PointerOps[CurrentOrder.front()];
- PtrN = PointerOps[CurrentOrder.back()];
- }
- std::optional<int> Dist =
- getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE);
- // Check that the sorted pointer operands are consecutive.
- if (static_cast<unsigned>(*Dist) == VL.size() - 1) {
- if (CurrentOrder.empty()) {
- // Original stores are consecutive and does not require reordering.
- TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S,
- UserTreeIdx, ReuseShuffleIndicies);
- TE->setOperandsInOrder();
- buildTree_rec(Operands, Depth + 1, {TE, 0});
- LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n");
- } else {
- fixupOrderingIndices(CurrentOrder);
- TreeEntry *TE =
- newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies, CurrentOrder);
- TE->setOperandsInOrder();
- buildTree_rec(Operands, Depth + 1, {TE, 0});
- LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled stores.\n");
- }
- return;
- }
+ // Check that the sorted pointer operands are consecutive.
+ if (CurrentOrder.empty()) {
+      // Original stores are consecutive and do not require reordering.
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
+ TE->setOperandsInOrder();
+ buildTree_rec(Operands, Depth + 1, {TE, 0});
+ LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n");
+ } else {
+ fixupOrderingIndices(CurrentOrder);
+ TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies, CurrentOrder);
+ TE->setOperandsInOrder();
+ buildTree_rec(Operands, Depth + 1, {TE, 0});
+ LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled stores.\n");
}
-
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n");
return;
}
case Instruction::Call: {
@@ -5802,68 +6129,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
CallInst *CI = cast<CallInst>(VL0);
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
- VFShape Shape = VFShape::get(
- *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
- false /*HasGlobalPred*/);
- Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
-
- if (!VecFunc && !isTriviallyVectorizable(ID)) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
- return;
- }
- Function *F = CI->getCalledFunction();
- unsigned NumArgs = CI->arg_size();
- SmallVector<Value*, 4> ScalarArgs(NumArgs, nullptr);
- for (unsigned j = 0; j != NumArgs; ++j)
- if (isVectorIntrinsicWithScalarOpAtArg(ID, j))
- ScalarArgs[j] = CI->getArgOperand(j);
- for (Value *V : VL) {
- CallInst *CI2 = dyn_cast<CallInst>(V);
- if (!CI2 || CI2->getCalledFunction() != F ||
- getVectorIntrinsicIDForCall(CI2, TLI) != ID ||
- (VecFunc &&
- VecFunc != VFDatabase(*CI2).getVectorizedFunction(Shape)) ||
- !CI->hasIdenticalOperandBundleSchema(*CI2)) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *V
- << "\n");
- return;
- }
- // Some intrinsics have scalar arguments and should be same in order for
- // them to be vectorized.
- for (unsigned j = 0; j != NumArgs; ++j) {
- if (isVectorIntrinsicWithScalarOpAtArg(ID, j)) {
- Value *A1J = CI2->getArgOperand(j);
- if (ScalarArgs[j] != A1J) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI
- << " argument " << ScalarArgs[j] << "!=" << A1J
- << "\n");
- return;
- }
- }
- }
- // Verify that the bundle operands are identical between the two calls.
- if (CI->hasOperandBundles() &&
- !std::equal(CI->op_begin() + CI->getBundleOperandsStartIndex(),
- CI->op_begin() + CI->getBundleOperandsEndIndex(),
- CI2->op_begin() + CI2->getBundleOperandsStartIndex())) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:"
- << *CI << "!=" << *V << '\n');
- return;
- }
- }
-
TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
TE->setOperandsInOrder();
@@ -5883,15 +6148,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
return;
}
case Instruction::ShuffleVector: {
- // If this is not an alternate sequence of opcode like add-sub
- // then do not vectorize this instruction.
- if (!S.isAltShuffle()) {
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n");
- return;
- }
TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n");
@@ -5949,19 +6205,16 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
return;
}
default:
- BS.cancelScheduling(VL, VL0);
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
- LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n");
- return;
+ break;
}
+ llvm_unreachable("Unexpected vectorization of the instructions.");
}
unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const {
unsigned N = 1;
Type *EltTy = T;
- while (isa<StructType, ArrayType, VectorType>(EltTy)) {
+ while (isa<StructType, ArrayType, FixedVectorType>(EltTy)) {
if (auto *ST = dyn_cast<StructType>(EltTy)) {
// Check that struct is homogeneous.
for (const auto *Ty : ST->elements())
@@ -5982,7 +6235,8 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const {
if (!isValidElementType(EltTy))
return 0;
uint64_t VTSize = DL.getTypeStoreSizeInBits(FixedVectorType::get(EltTy, N));
- if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize || VTSize != DL.getTypeStoreSizeInBits(T))
+ if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize ||
+ VTSize != DL.getTypeStoreSizeInBits(T))
return 0;
return N;
}
@@ -6111,68 +6365,6 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
return {IntrinsicCost, LibCost};
}
-/// Compute the cost of creating a vector of type \p VecTy containing the
-/// extracted values from \p VL.
-static InstructionCost
-computeExtractCost(ArrayRef<Value *> VL, FixedVectorType *VecTy,
- TargetTransformInfo::ShuffleKind ShuffleKind,
- ArrayRef<int> Mask, TargetTransformInfo &TTI) {
- unsigned NumOfParts = TTI.getNumberOfParts(VecTy);
-
- if (ShuffleKind != TargetTransformInfo::SK_PermuteSingleSrc || !NumOfParts ||
- VecTy->getNumElements() < NumOfParts)
- return TTI.getShuffleCost(ShuffleKind, VecTy, Mask);
-
- bool AllConsecutive = true;
- unsigned EltsPerVector = VecTy->getNumElements() / NumOfParts;
- unsigned Idx = -1;
- InstructionCost Cost = 0;
-
- // Process extracts in blocks of EltsPerVector to check if the source vector
- // operand can be re-used directly. If not, add the cost of creating a shuffle
- // to extract the values into a vector register.
- SmallVector<int> RegMask(EltsPerVector, UndefMaskElem);
- for (auto *V : VL) {
- ++Idx;
-
- // Reached the start of a new vector registers.
- if (Idx % EltsPerVector == 0) {
- RegMask.assign(EltsPerVector, UndefMaskElem);
- AllConsecutive = true;
- continue;
- }
-
- // Need to exclude undefs from analysis.
- if (isa<UndefValue>(V) || Mask[Idx] == UndefMaskElem)
- continue;
-
- // Check all extracts for a vector register on the target directly
- // extract values in order.
- unsigned CurrentIdx = *getExtractIndex(cast<Instruction>(V));
- if (!isa<UndefValue>(VL[Idx - 1]) && Mask[Idx - 1] != UndefMaskElem) {
- unsigned PrevIdx = *getExtractIndex(cast<Instruction>(VL[Idx - 1]));
- AllConsecutive &= PrevIdx + 1 == CurrentIdx &&
- CurrentIdx % EltsPerVector == Idx % EltsPerVector;
- RegMask[Idx % EltsPerVector] = CurrentIdx % EltsPerVector;
- }
-
- if (AllConsecutive)
- continue;
-
- // Skip all indices, except for the last index per vector block.
- if ((Idx + 1) % EltsPerVector != 0 && Idx + 1 != VL.size())
- continue;
-
- // If we have a series of extracts which are not consecutive and hence
- // cannot re-use the source vector register directly, compute the shuffle
- // cost to extract the vector with EltsPerVector elements.
- Cost += TTI.getShuffleCost(
- TargetTransformInfo::SK_PermuteSingleSrc,
- FixedVectorType::get(VecTy->getElementType(), EltsPerVector), RegMask);
- }
- return Cost;
-}
-
/// Build shuffle mask for shuffle graph entries and lists of main and alternate
/// operations operands.
static void
@@ -6183,7 +6375,7 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices,
SmallVectorImpl<Value *> *OpScalars = nullptr,
SmallVectorImpl<Value *> *AltScalars = nullptr) {
unsigned Sz = VL.size();
- Mask.assign(Sz, UndefMaskElem);
+ Mask.assign(Sz, PoisonMaskElem);
SmallVector<int> OrderMask;
if (!ReorderIndices.empty())
inversePermutation(ReorderIndices, OrderMask);
@@ -6203,9 +6395,9 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices,
}
}
if (!ReusesIndices.empty()) {
- SmallVector<int> NewMask(ReusesIndices.size(), UndefMaskElem);
+ SmallVector<int> NewMask(ReusesIndices.size(), PoisonMaskElem);
transform(ReusesIndices, NewMask.begin(), [&Mask](int Idx) {
- return Idx != UndefMaskElem ? Mask[Idx] : UndefMaskElem;
+ return Idx != PoisonMaskElem ? Mask[Idx] : PoisonMaskElem;
});
Mask.swap(NewMask);
}
@@ -6325,13 +6517,13 @@ protected:
static void combineMasks(unsigned LocalVF, SmallVectorImpl<int> &Mask,
ArrayRef<int> ExtMask) {
unsigned VF = Mask.size();
- SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem);
+ SmallVector<int> NewMask(ExtMask.size(), PoisonMaskElem);
for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) {
- if (ExtMask[I] == UndefMaskElem)
+ if (ExtMask[I] == PoisonMaskElem)
continue;
int MaskedIdx = Mask[ExtMask[I] % VF];
NewMask[I] =
- MaskedIdx == UndefMaskElem ? UndefMaskElem : MaskedIdx % LocalVF;
+ MaskedIdx == PoisonMaskElem ? PoisonMaskElem : MaskedIdx % LocalVF;
}
Mask.swap(NewMask);
}
@@ -6418,11 +6610,12 @@ protected:
if (auto *SVOpTy =
dyn_cast<FixedVectorType>(SV->getOperand(0)->getType()))
LocalVF = SVOpTy->getNumElements();
- SmallVector<int> ExtMask(Mask.size(), UndefMaskElem);
+ SmallVector<int> ExtMask(Mask.size(), PoisonMaskElem);
for (auto [Idx, I] : enumerate(Mask)) {
- if (I == UndefMaskElem)
- continue;
- ExtMask[Idx] = SV->getMaskValue(I);
+ if (I == PoisonMaskElem ||
+ static_cast<unsigned>(I) >= SV->getShuffleMask().size())
+ continue;
+ ExtMask[Idx] = SV->getMaskValue(I);
}
bool IsOp1Undef =
isUndefVector(SV->getOperand(0),
@@ -6435,11 +6628,11 @@ protected:
if (!IsOp1Undef && !IsOp2Undef) {
// Update mask and mark undef elems.
for (int &I : Mask) {
- if (I == UndefMaskElem)
+ if (I == PoisonMaskElem)
continue;
if (SV->getMaskValue(I % SV->getShuffleMask().size()) ==
- UndefMaskElem)
- I = UndefMaskElem;
+ PoisonMaskElem)
+ I = PoisonMaskElem;
}
break;
}
@@ -6453,15 +6646,16 @@ protected:
Op = SV->getOperand(1);
}
if (auto *OpTy = dyn_cast<FixedVectorType>(Op->getType());
- !OpTy || !isIdentityMask(Mask, OpTy, SinglePermute)) {
+ !OpTy || !isIdentityMask(Mask, OpTy, SinglePermute) ||
+ ShuffleVectorInst::isZeroEltSplatMask(Mask)) {
if (IdentityOp) {
V = IdentityOp;
assert(Mask.size() == IdentityMask.size() &&
"Expected masks of same sizes.");
// Clear known poison elements.
for (auto [I, Idx] : enumerate(Mask))
- if (Idx == UndefMaskElem)
- IdentityMask[I] = UndefMaskElem;
+ if (Idx == PoisonMaskElem)
+ IdentityMask[I] = PoisonMaskElem;
Mask.swap(IdentityMask);
auto *Shuffle = dyn_cast<ShuffleVectorInst>(V);
return SinglePermute &&
@@ -6481,10 +6675,12 @@ protected:
/// Smart shuffle instruction emission, walks through shuffles trees and
/// tries to find the best matching vector for the actual shuffle
/// instruction.
- template <typename ShuffleBuilderTy>
- static Value *createShuffle(Value *V1, Value *V2, ArrayRef<int> Mask,
- ShuffleBuilderTy &Builder) {
+ template <typename T, typename ShuffleBuilderTy>
+ static T createShuffle(Value *V1, Value *V2, ArrayRef<int> Mask,
+ ShuffleBuilderTy &Builder) {
assert(V1 && "Expected at least one vector value.");
+ if (V2)
+ Builder.resizeToMatch(V1, V2);
int VF = Mask.size();
if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType()))
VF = FTy->getNumElements();
@@ -6495,8 +6691,8 @@ protected:
Value *Op2 = V2;
int VF =
cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue();
- SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem);
- SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem);
+ SmallVector<int> CombinedMask1(Mask.size(), PoisonMaskElem);
+ SmallVector<int> CombinedMask2(Mask.size(), PoisonMaskElem);
for (int I = 0, E = Mask.size(); I < E; ++I) {
if (Mask[I] < VF)
CombinedMask1[I] = Mask[I];
@@ -6514,9 +6710,9 @@ protected:
// again.
if (auto *SV1 = dyn_cast<ShuffleVectorInst>(Op1))
if (auto *SV2 = dyn_cast<ShuffleVectorInst>(Op2)) {
- SmallVector<int> ExtMask1(Mask.size(), UndefMaskElem);
+ SmallVector<int> ExtMask1(Mask.size(), PoisonMaskElem);
for (auto [Idx, I] : enumerate(CombinedMask1)) {
- if (I == UndefMaskElem)
+ if (I == PoisonMaskElem)
continue;
ExtMask1[Idx] = SV1->getMaskValue(I);
}
@@ -6524,9 +6720,9 @@ protected:
cast<FixedVectorType>(SV1->getOperand(1)->getType())
->getNumElements(),
ExtMask1, UseMask::SecondArg);
- SmallVector<int> ExtMask2(CombinedMask2.size(), UndefMaskElem);
+ SmallVector<int> ExtMask2(CombinedMask2.size(), PoisonMaskElem);
for (auto [Idx, I] : enumerate(CombinedMask2)) {
- if (I == UndefMaskElem)
+ if (I == PoisonMaskElem)
continue;
ExtMask2[Idx] = SV2->getMaskValue(I);
}
@@ -6566,64 +6762,360 @@ protected:
->getElementCount()
.getKnownMinValue());
for (int I = 0, E = Mask.size(); I < E; ++I) {
- if (CombinedMask2[I] != UndefMaskElem) {
- assert(CombinedMask1[I] == UndefMaskElem &&
+ if (CombinedMask2[I] != PoisonMaskElem) {
+ assert(CombinedMask1[I] == PoisonMaskElem &&
"Expected undefined mask element");
CombinedMask1[I] = CombinedMask2[I] + (Op1 == Op2 ? 0 : VF);
}
}
+ const int Limit = CombinedMask1.size() * 2;
+ if (Op1 == Op2 && Limit == 2 * VF &&
+ all_of(CombinedMask1, [=](int Idx) { return Idx < Limit; }) &&
+ (ShuffleVectorInst::isIdentityMask(CombinedMask1) ||
+ (ShuffleVectorInst::isZeroEltSplatMask(CombinedMask1) &&
+ isa<ShuffleVectorInst>(Op1) &&
+ cast<ShuffleVectorInst>(Op1)->getShuffleMask() ==
+ ArrayRef(CombinedMask1))))
+ return Builder.createIdentity(Op1);
return Builder.createShuffleVector(
Op1, Op1 == Op2 ? PoisonValue::get(Op1->getType()) : Op2,
CombinedMask1);
}
if (isa<PoisonValue>(V1))
- return PoisonValue::get(FixedVectorType::get(
- cast<VectorType>(V1->getType())->getElementType(), Mask.size()));
+ return Builder.createPoison(
+ cast<VectorType>(V1->getType())->getElementType(), Mask.size());
SmallVector<int> NewMask(Mask.begin(), Mask.end());
bool IsIdentity = peekThroughShuffles(V1, NewMask, /*SinglePermute=*/true);
assert(V1 && "Expected non-null value after looking through shuffles.");
if (!IsIdentity)
return Builder.createShuffleVector(V1, NewMask);
- return V1;
+ return Builder.createIdentity(V1);
}
};
} // namespace
-InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
- ArrayRef<Value *> VectorizedVals) {
- ArrayRef<Value *> VL = E->Scalars;
+/// Merges shuffle masks and emits the final shuffle instruction, if required.
+/// It supports shuffling of 2 input vectors. It implements lazy shuffle
+/// emission: the actual shuffle instruction is generated only when it is
+/// really required; otherwise its emission is delayed until the end of the
+/// process, to reduce the number of emitted instructions and to simplify
+/// further analysis/transformations.
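+/// This estimator mirrors the IR-emitting shuffle builder, but its builder
+/// callbacks return TTI costs instead of creating instructions, so the
+/// gather-vs-shuffle decision can be priced before any IR is generated.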
+class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
+ bool IsFinalized = false;
+ SmallVector<int> CommonMask;
+ SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors;
+ const TargetTransformInfo &TTI;
+ InstructionCost Cost = 0;
+ ArrayRef<Value *> VectorizedVals;
+ BoUpSLP &R;
+ SmallPtrSetImpl<Value *> &CheckedExtracts;
+ constexpr static TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+ InstructionCost getBuildVectorCost(ArrayRef<Value *> VL, Value *Root) {
+ if ((!Root && allConstant(VL)) || all_of(VL, UndefValue::classof))
+ return TTI::TCC_Free;
+ auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size());
+ InstructionCost GatherCost = 0;
+ SmallVector<Value *> Gathers(VL.begin(), VL.end());
+ // Improve gather cost for gather of loads, if we can group some of the
+ // loads into vector loads.
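+    // E.g. for 8 gathered loads, slices of width 4 and then 2 are tried; each
+    // slice accepted by canVectorizeLoads() is costed as a vector load (or a
+    // masked gather) and the cost of its scalar loads is subtracted below.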
+ InstructionsState S = getSameOpcode(VL, *R.TLI);
+ if (VL.size() > 2 && S.getOpcode() == Instruction::Load &&
+ !S.isAltShuffle() &&
+ !all_of(Gathers, [&](Value *V) { return R.getTreeEntry(V); }) &&
+ !isSplat(Gathers)) {
+ BoUpSLP::ValueSet VectorizedLoads;
+ unsigned StartIdx = 0;
+ unsigned VF = VL.size() / 2;
+ unsigned VectorizedCnt = 0;
+ unsigned ScatterVectorizeCnt = 0;
+ const unsigned Sz = R.DL->getTypeSizeInBits(S.MainOp->getType());
+ for (unsigned MinVF = R.getMinVF(2 * Sz); VF >= MinVF; VF /= 2) {
+ for (unsigned Cnt = StartIdx, End = VL.size(); Cnt + VF <= End;
+ Cnt += VF) {
+ ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
+ if (!VectorizedLoads.count(Slice.front()) &&
+ !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) {
+ SmallVector<Value *> PointerOps;
+ OrdersType CurrentOrder;
+ LoadsState LS =
+ canVectorizeLoads(Slice, Slice.front(), TTI, *R.DL, *R.SE,
+ *R.LI, *R.TLI, CurrentOrder, PointerOps);
+ switch (LS) {
+ case LoadsState::Vectorize:
+ case LoadsState::ScatterVectorize:
+ // Mark the vectorized loads so that we don't vectorize them
+ // again.
+ if (LS == LoadsState::Vectorize)
+ ++VectorizedCnt;
+ else
+ ++ScatterVectorizeCnt;
+ VectorizedLoads.insert(Slice.begin(), Slice.end());
+ // If we vectorized initial block, no need to try to vectorize
+ // it again.
+ if (Cnt == StartIdx)
+ StartIdx += VF;
+ break;
+ case LoadsState::Gather:
+ break;
+ }
+ }
+ }
+ // Check if the whole array was vectorized already - exit.
+ if (StartIdx >= VL.size())
+ break;
+ // Found vectorizable parts - exit.
+ if (!VectorizedLoads.empty())
+ break;
+ }
+ if (!VectorizedLoads.empty()) {
+ unsigned NumParts = TTI.getNumberOfParts(VecTy);
+ bool NeedInsertSubvectorAnalysis =
+ !NumParts || (VL.size() / VF) > NumParts;
+ // Get the cost for gathered loads.
+ for (unsigned I = 0, End = VL.size(); I < End; I += VF) {
+ if (VectorizedLoads.contains(VL[I]))
+ continue;
+ GatherCost += getBuildVectorCost(VL.slice(I, VF), Root);
+ }
+ // Exclude potentially vectorized loads from list of gathered
+ // scalars.
+ auto *LI = cast<LoadInst>(S.MainOp);
+ Gathers.assign(Gathers.size(), PoisonValue::get(LI->getType()));
+ // The cost for vectorized loads.
+ InstructionCost ScalarsCost = 0;
+ for (Value *V : VectorizedLoads) {
+ auto *LI = cast<LoadInst>(V);
+ ScalarsCost +=
+ TTI.getMemoryOpCost(Instruction::Load, LI->getType(),
+ LI->getAlign(), LI->getPointerAddressSpace(),
+ CostKind, TTI::OperandValueInfo(), LI);
+ }
+ auto *LoadTy = FixedVectorType::get(LI->getType(), VF);
+ Align Alignment = LI->getAlign();
+ GatherCost +=
+ VectorizedCnt *
+ TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment,
+ LI->getPointerAddressSpace(), CostKind,
+ TTI::OperandValueInfo(), LI);
+ GatherCost += ScatterVectorizeCnt *
+ TTI.getGatherScatterOpCost(
+ Instruction::Load, LoadTy, LI->getPointerOperand(),
+ /*VariableMask=*/false, Alignment, CostKind, LI);
+ if (NeedInsertSubvectorAnalysis) {
+ // Add the cost for the subvectors insert.
+ for (int I = VF, E = VL.size(); I < E; I += VF)
+ GatherCost += TTI.getShuffleCost(TTI::SK_InsertSubvector, VecTy,
+ std::nullopt, CostKind, I, LoadTy);
+ }
+ GatherCost -= ScalarsCost;
+ }
+ } else if (!Root && isSplat(VL)) {
+ // Found the broadcasting of the single scalar, calculate the cost as
+ // the broadcast.
+ const auto *It =
+ find_if(VL, [](Value *V) { return !isa<UndefValue>(V); });
+ assert(It != VL.end() && "Expected at least one non-undef value.");
+ // Add broadcast for non-identity shuffle only.
+ bool NeedShuffle =
+ count(VL, *It) > 1 &&
+ (VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof));
+ InstructionCost InsertCost = TTI.getVectorInstrCost(
+ Instruction::InsertElement, VecTy, CostKind,
+ NeedShuffle ? 0 : std::distance(VL.begin(), It),
+ PoisonValue::get(VecTy), *It);
+ return InsertCost +
+ (NeedShuffle ? TTI.getShuffleCost(
+ TargetTransformInfo::SK_Broadcast, VecTy,
+ /*Mask=*/std::nullopt, CostKind, /*Index=*/0,
+ /*SubTp=*/nullptr, /*Args=*/*It)
+ : TTI::TCC_Free);
+ }
+ return GatherCost +
+ (all_of(Gathers, UndefValue::classof)
+ ? TTI::TCC_Free
+ : R.getGatherCost(Gathers, !Root && VL.equals(Gathers)));
+ };
- Type *ScalarTy = VL[0]->getType();
- if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
- ScalarTy = SI->getValueOperand()->getType();
- else if (CmpInst *CI = dyn_cast<CmpInst>(VL[0]))
- ScalarTy = CI->getOperand(0)->getType();
- else if (auto *IE = dyn_cast<InsertElementInst>(VL[0]))
- ScalarTy = IE->getOperand(1)->getType();
- auto *VecTy = FixedVectorType::get(ScalarTy, VL.size());
- TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ /// Compute the cost of creating a vector of type \p VecTy containing the
+ /// extracted values from \p VL.
+ InstructionCost computeExtractCost(ArrayRef<Value *> VL, ArrayRef<int> Mask,
+ TTI::ShuffleKind ShuffleKind) {
+ auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size());
+ unsigned NumOfParts = TTI.getNumberOfParts(VecTy);
- // If we have computed a smaller type for the expression, update VecTy so
- // that the costs will be accurate.
- if (MinBWs.count(VL[0]))
- VecTy = FixedVectorType::get(
- IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size());
- unsigned EntryVF = E->getVectorFactor();
- auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF);
+ if (ShuffleKind != TargetTransformInfo::SK_PermuteSingleSrc ||
+ !NumOfParts || VecTy->getNumElements() < NumOfParts)
+ return TTI.getShuffleCost(ShuffleKind, VecTy, Mask);
- bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty();
- // FIXME: it tries to fix a problem with MSVC buildbots.
- TargetTransformInfo *TTI = this->TTI;
- auto AdjustExtractsCost = [=](InstructionCost &Cost) {
+ bool AllConsecutive = true;
+ unsigned EltsPerVector = VecTy->getNumElements() / NumOfParts;
+ unsigned Idx = -1;
+ InstructionCost Cost = 0;
+
+ // Process extracts in blocks of EltsPerVector to check if the source vector
+ // operand can be re-used directly. If not, add the cost of creating a
+ // shuffle to extract the values into a vector register.
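+    // E.g. with 8 extracts and a target that splits the vector into 2
+    // registers, indices 0..3 and 4..7 are analyzed as separate blocks; a block
+    // whose extracts are already in order reuses its source register for free,
+    // otherwise a per-block permute cost is added.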
+ SmallVector<int> RegMask(EltsPerVector, PoisonMaskElem);
+ for (auto *V : VL) {
+ ++Idx;
+
+      // Reached the start of a new vector register.
+ if (Idx % EltsPerVector == 0) {
+ RegMask.assign(EltsPerVector, PoisonMaskElem);
+ AllConsecutive = true;
+ continue;
+ }
+
+ // Need to exclude undefs from analysis.
+ if (isa<UndefValue>(V) || Mask[Idx] == PoisonMaskElem)
+ continue;
+
+ // Check all extracts for a vector register on the target directly
+ // extract values in order.
+ unsigned CurrentIdx = *getExtractIndex(cast<Instruction>(V));
+ if (!isa<UndefValue>(VL[Idx - 1]) && Mask[Idx - 1] != PoisonMaskElem) {
+ unsigned PrevIdx = *getExtractIndex(cast<Instruction>(VL[Idx - 1]));
+ AllConsecutive &= PrevIdx + 1 == CurrentIdx &&
+ CurrentIdx % EltsPerVector == Idx % EltsPerVector;
+ RegMask[Idx % EltsPerVector] = CurrentIdx % EltsPerVector;
+ }
+
+ if (AllConsecutive)
+ continue;
+
+ // Skip all indices, except for the last index per vector block.
+ if ((Idx + 1) % EltsPerVector != 0 && Idx + 1 != VL.size())
+ continue;
+
+ // If we have a series of extracts which are not consecutive and hence
+ // cannot re-use the source vector register directly, compute the shuffle
+ // cost to extract the vector with EltsPerVector elements.
+ Cost += TTI.getShuffleCost(
+ TargetTransformInfo::SK_PermuteSingleSrc,
+ FixedVectorType::get(VecTy->getElementType(), EltsPerVector),
+ RegMask);
+ }
+ return Cost;
+ }
+
+ class ShuffleCostBuilder {
+ const TargetTransformInfo &TTI;
+
+ static bool isEmptyOrIdentity(ArrayRef<int> Mask, unsigned VF) {
+ int Limit = 2 * VF;
+ return Mask.empty() ||
+ (VF == Mask.size() &&
+ all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) &&
+ ShuffleVectorInst::isIdentityMask(Mask));
+ }
+
+ public:
+ ShuffleCostBuilder(const TargetTransformInfo &TTI) : TTI(TTI) {}
+ ~ShuffleCostBuilder() = default;
+ InstructionCost createShuffleVector(Value *V1, Value *,
+ ArrayRef<int> Mask) const {
+ // Empty mask or identity mask are free.
+ unsigned VF =
+ cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue();
+ if (isEmptyOrIdentity(Mask, VF))
+ return TTI::TCC_Free;
+ return TTI.getShuffleCost(
+ TTI::SK_PermuteTwoSrc,
+ FixedVectorType::get(
+ cast<VectorType>(V1->getType())->getElementType(), Mask.size()),
+ Mask);
+ }
+ InstructionCost createShuffleVector(Value *V1, ArrayRef<int> Mask) const {
+ // Empty mask or identity mask are free.
+ if (isEmptyOrIdentity(Mask, Mask.size()))
+ return TTI::TCC_Free;
+ return TTI.getShuffleCost(
+ TTI::SK_PermuteSingleSrc,
+ FixedVectorType::get(
+ cast<VectorType>(V1->getType())->getElementType(), Mask.size()),
+ Mask);
+ }
+ InstructionCost createIdentity(Value *) const { return TTI::TCC_Free; }
+ InstructionCost createPoison(Type *Ty, unsigned VF) const {
+ return TTI::TCC_Free;
+ }
+ void resizeToMatch(Value *&, Value *&) const {}
+ };
+
+ /// Smart shuffle instruction emission, walks through shuffles trees and
+ /// tries to find the best matching vector for the actual shuffle
+ /// instruction.
+ InstructionCost
+ createShuffle(const PointerUnion<Value *, const TreeEntry *> &P1,
+ const PointerUnion<Value *, const TreeEntry *> &P2,
+ ArrayRef<int> Mask) {
+ ShuffleCostBuilder Builder(TTI);
+ Value *V1 = P1.dyn_cast<Value *>(), *V2 = P2.dyn_cast<Value *>();
+ unsigned CommonVF = 0;
+ if (!V1) {
+ const TreeEntry *E = P1.get<const TreeEntry *>();
+ unsigned VF = E->getVectorFactor();
+ if (V2) {
+ unsigned V2VF = cast<FixedVectorType>(V2->getType())->getNumElements();
+ if (V2VF != VF && V2VF == E->Scalars.size())
+ VF = E->Scalars.size();
+ } else if (!P2.isNull()) {
+ const TreeEntry *E2 = P2.get<const TreeEntry *>();
+ if (E->Scalars.size() == E2->Scalars.size())
+ CommonVF = VF = E->Scalars.size();
+ } else {
+        // P2 is empty, check that we have the same node + reshuffle (if any).
+ if (E->Scalars.size() == Mask.size() && VF != Mask.size()) {
+ VF = E->Scalars.size();
+ SmallVector<int> CommonMask(Mask.begin(), Mask.end());
+ ::addMask(CommonMask, E->getCommonMask());
+ V1 = Constant::getNullValue(
+ FixedVectorType::get(E->Scalars.front()->getType(), VF));
+ return BaseShuffleAnalysis::createShuffle<InstructionCost>(
+ V1, nullptr, CommonMask, Builder);
+ }
+ }
+ V1 = Constant::getNullValue(
+ FixedVectorType::get(E->Scalars.front()->getType(), VF));
+ }
+ if (!V2 && !P2.isNull()) {
+ const TreeEntry *E = P2.get<const TreeEntry *>();
+ unsigned VF = E->getVectorFactor();
+ unsigned V1VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ if (!CommonVF && V1VF == E->Scalars.size())
+ CommonVF = E->Scalars.size();
+ if (CommonVF)
+ VF = CommonVF;
+ V2 = Constant::getNullValue(
+ FixedVectorType::get(E->Scalars.front()->getType(), VF));
+ }
+ return BaseShuffleAnalysis::createShuffle<InstructionCost>(V1, V2, Mask,
+ Builder);
+ }
+
+public:
+ ShuffleCostEstimator(TargetTransformInfo &TTI,
+ ArrayRef<Value *> VectorizedVals, BoUpSLP &R,
+ SmallPtrSetImpl<Value *> &CheckedExtracts)
+ : TTI(TTI), VectorizedVals(VectorizedVals), R(R),
+ CheckedExtracts(CheckedExtracts) {}
+ Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask,
+ TTI::ShuffleKind ShuffleKind) {
+ if (Mask.empty())
+ return nullptr;
+ Value *VecBase = nullptr;
+ ArrayRef<Value *> VL = E->Scalars;
+ auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size());
// If the resulting type is scalarized, do not adjust the cost.
- unsigned VecNumParts = TTI->getNumberOfParts(VecTy);
+ unsigned VecNumParts = TTI.getNumberOfParts(VecTy);
if (VecNumParts == VecTy->getNumElements())
- return;
+ return nullptr;
DenseMap<Value *, int> ExtractVectorsTys;
- SmallPtrSet<Value *, 4> CheckedExtracts;
- for (auto *V : VL) {
- if (isa<UndefValue>(V))
+ for (auto [I, V] : enumerate(VL)) {
+ // Ignore non-extractelement scalars.
+ if (isa<UndefValue>(V) || (!Mask.empty() && Mask[I] == PoisonMaskElem))
continue;
// If all users of instruction are going to be vectorized and this
// instruction itself is not going to be vectorized, consider this
@@ -6631,17 +7123,18 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
// vectorized tree.
// Also, avoid adjusting the cost for extractelements with multiple uses
// in different graph entries.
- const TreeEntry *VE = getTreeEntry(V);
+ const TreeEntry *VE = R.getTreeEntry(V);
if (!CheckedExtracts.insert(V).second ||
- !areAllUsersVectorized(cast<Instruction>(V), VectorizedVals) ||
+ !R.areAllUsersVectorized(cast<Instruction>(V), VectorizedVals) ||
(VE && VE != E))
continue;
auto *EE = cast<ExtractElementInst>(V);
+ VecBase = EE->getVectorOperand();
std::optional<unsigned> EEIdx = getExtractIndex(EE);
if (!EEIdx)
continue;
unsigned Idx = *EEIdx;
- if (VecNumParts != TTI->getNumberOfParts(EE->getVectorOperandType())) {
+ if (VecNumParts != TTI.getNumberOfParts(EE->getVectorOperandType())) {
auto It =
ExtractVectorsTys.try_emplace(EE->getVectorOperand(), Idx).first;
It->getSecond() = std::min<int>(It->second, Idx);
@@ -6654,18 +7147,17 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
})) {
// Use getExtractWithExtendCost() to calculate the cost of
// extractelement/ext pair.
- Cost -=
- TTI->getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(),
- EE->getVectorOperandType(), Idx);
+ Cost -= TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(),
+ EE->getVectorOperandType(), Idx);
// Add back the cost of s|zext which is subtracted separately.
- Cost += TTI->getCastInstrCost(
+ Cost += TTI.getCastInstrCost(
Ext->getOpcode(), Ext->getType(), EE->getType(),
TTI::getCastContextHint(Ext), CostKind, Ext);
continue;
}
}
- Cost -= TTI->getVectorInstrCost(*EE, EE->getVectorOperandType(), CostKind,
- Idx);
+ Cost -= TTI.getVectorInstrCost(*EE, EE->getVectorOperandType(), CostKind,
+ Idx);
}
// Add a cost for subvector extracts/inserts if required.
for (const auto &Data : ExtractVectorsTys) {
@@ -6673,34 +7165,148 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
unsigned NumElts = VecTy->getNumElements();
if (Data.second % NumElts == 0)
continue;
- if (TTI->getNumberOfParts(EEVTy) > VecNumParts) {
+ if (TTI.getNumberOfParts(EEVTy) > VecNumParts) {
unsigned Idx = (Data.second / NumElts) * NumElts;
unsigned EENumElts = EEVTy->getNumElements();
+ if (Idx % NumElts == 0)
+ continue;
if (Idx + NumElts <= EENumElts) {
- Cost +=
- TTI->getShuffleCost(TargetTransformInfo::SK_ExtractSubvector,
- EEVTy, std::nullopt, CostKind, Idx, VecTy);
+ Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector,
+ EEVTy, std::nullopt, CostKind, Idx, VecTy);
} else {
// Need to round up the subvector type vectorization factor to avoid a
// crash in cost model functions. Make SubVT so that Idx + VF of SubVT
// <= EENumElts.
auto *SubVT =
FixedVectorType::get(VecTy->getElementType(), EENumElts - Idx);
- Cost +=
- TTI->getShuffleCost(TargetTransformInfo::SK_ExtractSubvector,
- EEVTy, std::nullopt, CostKind, Idx, SubVT);
+ Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector,
+ EEVTy, std::nullopt, CostKind, Idx, SubVT);
}
} else {
- Cost += TTI->getShuffleCost(TargetTransformInfo::SK_InsertSubvector,
- VecTy, std::nullopt, CostKind, 0, EEVTy);
+ Cost += TTI.getShuffleCost(TargetTransformInfo::SK_InsertSubvector,
+ VecTy, std::nullopt, CostKind, 0, EEVTy);
}
}
- };
+ // Check that gather of extractelements can be represented as just a
+ // shuffle of a single/two vectors the scalars are extracted from.
+ // Found the bunch of extractelement instructions that must be gathered
+    // into a vector and can be represented as a permutation of elements in a
+ // single input vector or of 2 input vectors.
+ Cost += computeExtractCost(VL, Mask, ShuffleKind);
+ return VecBase;
+ }
+ void add(const TreeEntry *E1, const TreeEntry *E2, ArrayRef<int> Mask) {
+ CommonMask.assign(Mask.begin(), Mask.end());
+ InVectors.assign({E1, E2});
+ }
+ void add(const TreeEntry *E1, ArrayRef<int> Mask) {
+ CommonMask.assign(Mask.begin(), Mask.end());
+ InVectors.assign(1, E1);
+ }
+ /// Adds another one input vector and the mask for the shuffling.
+ void add(Value *V1, ArrayRef<int> Mask) {
+ assert(CommonMask.empty() && InVectors.empty() &&
+ "Expected empty input mask/vectors.");
+ CommonMask.assign(Mask.begin(), Mask.end());
+ InVectors.assign(1, V1);
+ }
+ Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) {
+ Cost += getBuildVectorCost(VL, Root);
+ if (!Root) {
+ assert(InVectors.empty() && "Unexpected input vectors for buildvector.");
+ // FIXME: Need to find a way to avoid use of getNullValue here.
+ SmallVector<Constant *> Vals;
+ for (Value *V : VL) {
+ if (isa<UndefValue>(V)) {
+ Vals.push_back(cast<Constant>(V));
+ continue;
+ }
+ Vals.push_back(Constant::getNullValue(V->getType()));
+ }
+ return ConstantVector::get(Vals);
+ }
+ return ConstantVector::getSplat(
+ ElementCount::getFixed(VL.size()),
+ Constant::getNullValue(VL.front()->getType()));
+ }
+ /// Finalize emission of the shuffles.
+ InstructionCost
+ finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
+ function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
+ IsFinalized = true;
+ if (Action) {
+ const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front();
+ if (InVectors.size() == 2) {
+ Cost += createShuffle(Vec, InVectors.back(), CommonMask);
+ InVectors.pop_back();
+ } else {
+ Cost += createShuffle(Vec, nullptr, CommonMask);
+ }
+ for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
+ if (CommonMask[Idx] != PoisonMaskElem)
+ CommonMask[Idx] = Idx;
+ assert(VF > 0 &&
+ "Expected vector length for the final value before action.");
+ Value *V = Vec.dyn_cast<Value *>();
+ if (!Vec.isNull() && !V)
+ V = Constant::getNullValue(FixedVectorType::get(
+ Vec.get<const TreeEntry *>()->Scalars.front()->getType(),
+ CommonMask.size()));
+ Action(V, CommonMask);
+ }
+ ::addMask(CommonMask, ExtMask, /*ExtendingManyInputs=*/true);
+ if (CommonMask.empty())
+ return Cost;
+ int Limit = CommonMask.size() * 2;
+ if (all_of(CommonMask, [=](int Idx) { return Idx < Limit; }) &&
+ ShuffleVectorInst::isIdentityMask(CommonMask))
+ return Cost;
+ return Cost +
+ createShuffle(InVectors.front(),
+ InVectors.size() == 2 ? InVectors.back() : nullptr,
+ CommonMask);
+ }
+
+ ~ShuffleCostEstimator() {
+ assert((IsFinalized || CommonMask.empty()) &&
+ "Shuffle construction must be finalized.");
+ }
+};
+
+InstructionCost
+BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
+ SmallPtrSetImpl<Value *> &CheckedExtracts) {
+ ArrayRef<Value *> VL = E->Scalars;
+
+ Type *ScalarTy = VL[0]->getType();
+ if (auto *SI = dyn_cast<StoreInst>(VL[0]))
+ ScalarTy = SI->getValueOperand()->getType();
+ else if (auto *CI = dyn_cast<CmpInst>(VL[0]))
+ ScalarTy = CI->getOperand(0)->getType();
+ else if (auto *IE = dyn_cast<InsertElementInst>(VL[0]))
+ ScalarTy = IE->getOperand(1)->getType();
+ auto *VecTy = FixedVectorType::get(ScalarTy, VL.size());
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+ // If we have computed a smaller type for the expression, update VecTy so
+ // that the costs will be accurate.
+ if (MinBWs.count(VL[0]))
+ VecTy = FixedVectorType::get(
+ IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size());
+ unsigned EntryVF = E->getVectorFactor();
+ auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF);
+
+ bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty();
if (E->State == TreeEntry::NeedToGather) {
if (allConstant(VL))
return 0;
if (isa<InsertElementInst>(VL[0]))
return InstructionCost::getInvalid();
+ ShuffleCostEstimator Estimator(*TTI, VectorizedVals, *this,
+ CheckedExtracts);
+ unsigned VF = E->getVectorFactor();
+ SmallVector<int> ReuseShuffleIndicies(E->ReuseShuffleIndices.begin(),
+ E->ReuseShuffleIndices.end());
SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end());
// Build a mask out of the reorder indices and reorder scalars per this
// mask.
@@ -6709,195 +7315,104 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
if (!ReorderMask.empty())
reorderScalars(GatheredScalars, ReorderMask);
SmallVector<int> Mask;
+ SmallVector<int> ExtractMask;
+ std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle;
std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle;
SmallVector<const TreeEntry *> Entries;
+ Type *ScalarTy = GatheredScalars.front()->getType();
+ // Check for gathered extracts.
+ ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask);
+ SmallVector<Value *> IgnoredVals;
+ if (UserIgnoreList)
+ IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end());
+
+ bool Resized = false;
+ if (Value *VecBase = Estimator.adjustExtracts(
+ E, ExtractMask, ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc)))
+ if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType()))
+ if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) {
+ Resized = true;
+ GatheredScalars.append(VF - GatheredScalars.size(),
+ PoisonValue::get(ScalarTy));
+ }
+
// Do not try to look for reshuffled loads for gathered loads (they will be
// handled later), for vectorized scalars, and cases, which are definitely
// not profitable (splats and small gather nodes.)
- if (E->getOpcode() != Instruction::Load || E->isAltShuffle() ||
+ if (ExtractShuffle || E->getOpcode() != Instruction::Load ||
+ E->isAltShuffle() ||
all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) ||
isSplat(E->Scalars) ||
(E->Scalars != GatheredScalars && GatheredScalars.size() <= 2))
GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries);
if (GatherShuffle) {
- // Remove shuffled elements from list of gathers.
- for (int I = 0, Sz = Mask.size(); I < Sz; ++I) {
- if (Mask[I] != UndefMaskElem)
- GatheredScalars[I] = PoisonValue::get(ScalarTy);
- }
assert((Entries.size() == 1 || Entries.size() == 2) &&
"Expected shuffle of 1 or 2 entries.");
- InstructionCost GatherCost = 0;
- int Limit = Mask.size() * 2;
- if (all_of(Mask, [=](int Idx) { return Idx < Limit; }) &&
- ShuffleVectorInst::isIdentityMask(Mask)) {
+ if (*GatherShuffle == TTI::SK_PermuteSingleSrc &&
+ Entries.front()->isSame(E->Scalars)) {
// Perfect match in the graph, will reuse the previously vectorized
// node. Cost is 0.
LLVM_DEBUG(
dbgs()
<< "SLP: perfect diamond match for gather bundle that starts with "
<< *VL.front() << ".\n");
- if (NeedToShuffleReuses)
- GatherCost =
- TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
- FinalVecTy, E->ReuseShuffleIndices);
- } else {
- LLVM_DEBUG(dbgs() << "SLP: shuffled " << Entries.size()
- << " entries for bundle that starts with "
- << *VL.front() << ".\n");
- // Detected that instead of gather we can emit a shuffle of single/two
- // previously vectorized nodes. Add the cost of the permutation rather
- // than gather.
- ::addMask(Mask, E->ReuseShuffleIndices);
- GatherCost = TTI->getShuffleCost(*GatherShuffle, FinalVecTy, Mask);
- }
- if (!all_of(GatheredScalars, UndefValue::classof))
- GatherCost += getGatherCost(GatheredScalars);
- return GatherCost;
- }
- if ((E->getOpcode() == Instruction::ExtractElement ||
- all_of(E->Scalars,
- [](Value *V) {
- return isa<ExtractElementInst, UndefValue>(V);
- })) &&
- allSameType(VL)) {
- // Check that gather of extractelements can be represented as just a
- // shuffle of a single/two vectors the scalars are extracted from.
- SmallVector<int> Mask;
- std::optional<TargetTransformInfo::ShuffleKind> ShuffleKind =
- isFixedVectorShuffle(VL, Mask);
- if (ShuffleKind) {
- // Found the bunch of extractelement instructions that must be gathered
- // into a vector and can be represented as a permutation elements in a
- // single input vector or of 2 input vectors.
- InstructionCost Cost =
- computeExtractCost(VL, VecTy, *ShuffleKind, Mask, *TTI);
- AdjustExtractsCost(Cost);
- if (NeedToShuffleReuses)
- Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
- FinalVecTy, E->ReuseShuffleIndices);
- return Cost;
- }
- }
- if (isSplat(VL)) {
- // Found the broadcasting of the single scalar, calculate the cost as the
- // broadcast.
- assert(VecTy == FinalVecTy &&
- "No reused scalars expected for broadcast.");
- const auto *It =
- find_if(VL, [](Value *V) { return !isa<UndefValue>(V); });
- // If all values are undefs - consider cost free.
- if (It == VL.end())
- return TTI::TCC_Free;
- // Add broadcast for non-identity shuffle only.
- bool NeedShuffle =
- VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof);
- InstructionCost InsertCost =
- TTI->getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind,
- /*Index=*/0, PoisonValue::get(VecTy), *It);
- return InsertCost + (NeedShuffle
- ? TTI->getShuffleCost(
- TargetTransformInfo::SK_Broadcast, VecTy,
- /*Mask=*/std::nullopt, CostKind,
- /*Index=*/0,
- /*SubTp=*/nullptr, /*Args=*/VL[0])
- : TTI::TCC_Free);
- }
- InstructionCost ReuseShuffleCost = 0;
- if (NeedToShuffleReuses)
- ReuseShuffleCost = TTI->getShuffleCost(
- TTI::SK_PermuteSingleSrc, FinalVecTy, E->ReuseShuffleIndices);
- // Improve gather cost for gather of loads, if we can group some of the
- // loads into vector loads.
- if (VL.size() > 2 && E->getOpcode() == Instruction::Load &&
- !E->isAltShuffle()) {
- BoUpSLP::ValueSet VectorizedLoads;
- unsigned StartIdx = 0;
- unsigned VF = VL.size() / 2;
- unsigned VectorizedCnt = 0;
- unsigned ScatterVectorizeCnt = 0;
- const unsigned Sz = DL->getTypeSizeInBits(E->getMainOp()->getType());
- for (unsigned MinVF = getMinVF(2 * Sz); VF >= MinVF; VF /= 2) {
- for (unsigned Cnt = StartIdx, End = VL.size(); Cnt + VF <= End;
- Cnt += VF) {
- ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
- if (!VectorizedLoads.count(Slice.front()) &&
- !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) {
- SmallVector<Value *> PointerOps;
- OrdersType CurrentOrder;
- LoadsState LS =
- canVectorizeLoads(Slice, Slice.front(), *TTI, *DL, *SE, *LI,
- *TLI, CurrentOrder, PointerOps);
- switch (LS) {
- case LoadsState::Vectorize:
- case LoadsState::ScatterVectorize:
- // Mark the vectorized loads so that we don't vectorize them
- // again.
- if (LS == LoadsState::Vectorize)
- ++VectorizedCnt;
- else
- ++ScatterVectorizeCnt;
- VectorizedLoads.insert(Slice.begin(), Slice.end());
- // If we vectorized initial block, no need to try to vectorize it
- // again.
- if (Cnt == StartIdx)
- StartIdx += VF;
- break;
- case LoadsState::Gather:
- break;
- }
+ // Restore the mask for previous partially matched values.
+ for (auto [I, V] : enumerate(E->Scalars)) {
+ if (isa<PoisonValue>(V)) {
+ Mask[I] = PoisonMaskElem;
+ continue;
}
+ if (Mask[I] == PoisonMaskElem)
+ Mask[I] = Entries.front()->findLaneForValue(V);
}
- // Check if the whole array was vectorized already - exit.
- if (StartIdx >= VL.size())
- break;
- // Found vectorizable parts - exit.
- if (!VectorizedLoads.empty())
- break;
+ Estimator.add(Entries.front(), Mask);
+ return Estimator.finalize(E->ReuseShuffleIndices);
}
- if (!VectorizedLoads.empty()) {
- InstructionCost GatherCost = 0;
- unsigned NumParts = TTI->getNumberOfParts(VecTy);
- bool NeedInsertSubvectorAnalysis =
- !NumParts || (VL.size() / VF) > NumParts;
- // Get the cost for gathered loads.
- for (unsigned I = 0, End = VL.size(); I < End; I += VF) {
- if (VectorizedLoads.contains(VL[I]))
- continue;
- GatherCost += getGatherCost(VL.slice(I, VF));
- }
- // The cost for vectorized loads.
- InstructionCost ScalarsCost = 0;
- for (Value *V : VectorizedLoads) {
- auto *LI = cast<LoadInst>(V);
- ScalarsCost +=
- TTI->getMemoryOpCost(Instruction::Load, LI->getType(),
- LI->getAlign(), LI->getPointerAddressSpace(),
- CostKind, TTI::OperandValueInfo(), LI);
- }
- auto *LI = cast<LoadInst>(E->getMainOp());
- auto *LoadTy = FixedVectorType::get(LI->getType(), VF);
- Align Alignment = LI->getAlign();
- GatherCost +=
- VectorizedCnt *
- TTI->getMemoryOpCost(Instruction::Load, LoadTy, Alignment,
- LI->getPointerAddressSpace(), CostKind,
- TTI::OperandValueInfo(), LI);
- GatherCost += ScatterVectorizeCnt *
- TTI->getGatherScatterOpCost(
- Instruction::Load, LoadTy, LI->getPointerOperand(),
- /*VariableMask=*/false, Alignment, CostKind, LI);
- if (NeedInsertSubvectorAnalysis) {
- // Add the cost for the subvectors insert.
- for (int I = VF, E = VL.size(); I < E; I += VF)
- GatherCost +=
- TTI->getShuffleCost(TTI::SK_InsertSubvector, VecTy,
- std::nullopt, CostKind, I, LoadTy);
- }
- return ReuseShuffleCost + GatherCost - ScalarsCost;
+ if (!Resized) {
+ unsigned VF1 = Entries.front()->getVectorFactor();
+ unsigned VF2 = Entries.back()->getVectorFactor();
+ if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF)
+ GatheredScalars.append(VF - GatheredScalars.size(),
+ PoisonValue::get(ScalarTy));
}
+ // Remove shuffled elements from list of gathers.
+ for (int I = 0, Sz = Mask.size(); I < Sz; ++I) {
+ if (Mask[I] != PoisonMaskElem)
+ GatheredScalars[I] = PoisonValue::get(ScalarTy);
+ }
+ LLVM_DEBUG(dbgs() << "SLP: shuffled " << Entries.size()
+ << " entries for bundle that starts with "
+ << *VL.front() << ".\n";);
+ if (Entries.size() == 1)
+ Estimator.add(Entries.front(), Mask);
+ else
+ Estimator.add(Entries.front(), Entries.back(), Mask);
+ if (all_of(GatheredScalars, PoisonValue::classof))
+ return Estimator.finalize(E->ReuseShuffleIndices);
+ return Estimator.finalize(
+ E->ReuseShuffleIndices, E->Scalars.size(),
+ [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
+ Vec = Estimator.gather(GatheredScalars,
+ Constant::getNullValue(FixedVectorType::get(
+ GatheredScalars.front()->getType(),
+ GatheredScalars.size())));
+ });
}
- return ReuseShuffleCost + getGatherCost(VL);
+ if (!all_of(GatheredScalars, PoisonValue::classof)) {
+ auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size());
+ bool SameGathers = VL.equals(Gathers);
+ Value *BV = Estimator.gather(
+ Gathers, SameGathers ? nullptr
+ : Constant::getNullValue(FixedVectorType::get(
+ GatheredScalars.front()->getType(),
+ GatheredScalars.size())));
+ SmallVector<int> ReuseMask(Gathers.size(), PoisonMaskElem);
+ std::iota(ReuseMask.begin(), ReuseMask.end(), 0);
+ Estimator.add(BV, ReuseMask);
+ }
+ if (ExtractShuffle)
+ Estimator.add(E, std::nullopt);
+ return Estimator.finalize(E->ReuseShuffleIndices);
}
InstructionCost CommonCost = 0;
SmallVector<int> Mask;
@@ -6945,48 +7460,89 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
}
InstructionCost VecCost = VectorCost(CommonCost);
- LLVM_DEBUG(
- dumpTreeCosts(E, CommonCost, VecCost - CommonCost, ScalarCost));
- // Disable warnings for `this` and `E` are unused. Required for
- // `dumpTreeCosts`.
- (void)this;
- (void)E;
+ LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost - CommonCost,
+ ScalarCost, "Calculated costs for Tree"));
return VecCost - ScalarCost;
};
// Calculate cost difference from vectorizing set of GEPs.
// Negative value means vectorizing is profitable.
auto GetGEPCostDiff = [=](ArrayRef<Value *> Ptrs, Value *BasePtr) {
- InstructionCost CostSavings = 0;
- for (Value *V : Ptrs) {
- if (V == BasePtr)
- continue;
- auto *Ptr = dyn_cast<GetElementPtrInst>(V);
- // GEPs may contain just addresses without instructions, considered free.
- // GEPs with all constant indices also considered to have zero cost.
- if (!Ptr || Ptr->hasAllConstantIndices())
- continue;
-
- // Here we differentiate two cases: when GEPs represent a regular
- // vectorization tree node (and hence vectorized) and when the set is
- // arguments of a set of loads or stores being vectorized. In the former
- // case all the scalar GEPs will be removed as a result of vectorization.
+ InstructionCost ScalarCost = 0;
+ InstructionCost VecCost = 0;
+ // Here we differentiate two cases: (1) when Ptrs represent a regular
+ // vectorization tree node (as they are pointer arguments of scattered
+ // loads) or (2) when Ptrs are the arguments of loads or stores being
+ // vectorized as a plain wide unit-stride load/store since all the
+ // loads/stores are known to be from/to adjacent locations.
+ assert(E->State == TreeEntry::Vectorize &&
+ "Entry state expected to be Vectorize here.");
+ if (isa<LoadInst, StoreInst>(VL0)) {
+ // Case 2: estimate costs for pointer related costs when vectorizing to
+ // a wide load/store.
+ // The scalar cost is estimated as a set of pointers with a known
+ // relationship between them.
+ // For vector code we will use BasePtr as argument for the wide load/store
+ // but we also need to account for all the instructions which are going to
+ // stay in vectorized code due to uses outside of these scalar
+ // loads/stores.
+ ScalarCost = TTI->getPointersChainCost(
+ Ptrs, BasePtr, TTI::PointersChainInfo::getUnitStride(), ScalarTy,
+ CostKind);
+
+ SmallVector<const Value *> PtrsRetainedInVecCode;
+ for (Value *V : Ptrs) {
+ if (V == BasePtr) {
+ PtrsRetainedInVecCode.push_back(V);
+ continue;
+ }
+ auto *Ptr = dyn_cast<GetElementPtrInst>(V);
+ // For simplicity, assume Ptr stays in the vectorized code if it's not a
+ // GEP instruction. We don't care since its cost is considered free.
+ // TODO: We should check for any uses outside of vectorizable tree
+ // rather than just single use.
+ if (!Ptr || !Ptr->hasOneUse())
+ PtrsRetainedInVecCode.push_back(V);
+ }
+
+ if (PtrsRetainedInVecCode.size() == Ptrs.size()) {
+ // If all pointers stay in the vectorized code then we don't get
+ // any savings from that.
+ LLVM_DEBUG(dumpTreeCosts(E, 0, ScalarCost, ScalarCost,
+ "Calculated GEPs cost for Tree"));
+ return InstructionCost{TTI::TCC_Free};
+ }
+ VecCost = TTI->getPointersChainCost(
+ PtrsRetainedInVecCode, BasePtr,
+ TTI::PointersChainInfo::getKnownStride(), VecTy, CostKind);
+ } else {
+ // Case 1: Ptrs are the arguments of loads that we are going to transform
+ // into a masked gather load intrinsic.
+ // All the scalar GEPs will be removed as a result of vectorization.
// For any external uses of some lanes extract element instructions will
- // be generated (which cost is estimated separately). For the latter case
- // since the set of GEPs itself is not vectorized those used more than
- // once will remain staying in vectorized code as well. So we should not
- // count them as savings.
- if (!Ptr->hasOneUse() && isa<LoadInst, StoreInst>(VL0))
- continue;
-
- // TODO: it is target dependent, so need to implement and then use a TTI
- // interface.
- CostSavings += TTI->getArithmeticInstrCost(Instruction::Add,
- Ptr->getType(), CostKind);
- }
- LLVM_DEBUG(dbgs() << "SLP: Calculated GEPs cost savings or Tree:\n";
- E->dump());
- LLVM_DEBUG(dbgs() << "SLP: GEP cost saving = " << CostSavings << "\n");
- return InstructionCost() - CostSavings;
+ // be generated (whose cost is estimated separately).
+ TTI::PointersChainInfo PtrsInfo =
+ all_of(Ptrs,
+ [](const Value *V) {
+ auto *Ptr = dyn_cast<GetElementPtrInst>(V);
+ return Ptr && !Ptr->hasAllConstantIndices();
+ })
+ ? TTI::PointersChainInfo::getUnknownStride()
+ : TTI::PointersChainInfo::getKnownStride();
+
+ ScalarCost = TTI->getPointersChainCost(Ptrs, BasePtr, PtrsInfo, ScalarTy,
+ CostKind);
+ if (auto *BaseGEP = dyn_cast<GEPOperator>(BasePtr)) {
+ SmallVector<const Value *> Indices(BaseGEP->indices());
+ VecCost = TTI->getGEPCost(BaseGEP->getSourceElementType(),
+ BaseGEP->getPointerOperand(), Indices, VecTy,
+ CostKind);
+ }
+ }
+
+ LLVM_DEBUG(dumpTreeCosts(E, 0, VecCost, ScalarCost,
+ "Calculated GEPs cost for Tree"));
+
+ return VecCost - ScalarCost;
};
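
GetGEPCostDiff above returns VecCost - ScalarCost, so a negative result means the per-lane address computations collapse profitably into a single wide access. A toy model of that convention follows, with made-up unit costs rather than the TTI-based numbers the patch actually uses:

#include <cstdio>

// Toy cost model: every scalar GEP costs one unit; the vector form pays for
// one base-address computation plus any GEPs that must stay around for other
// (non-vectorized) users.
static int gepCostDiffSketch(int NumScalarGEPs, int NumGEPsKeptInVecCode) {
  int ScalarCost = NumScalarGEPs;
  int VecCost = 1 + NumGEPsKeptInVecCode;
  return VecCost - ScalarCost; // negative => vectorizing the addresses pays off
}

int main() {
  // Four adjacent loads whose GEPs all fold into one base pointer: saves 3.
  std::printf("%d\n", gepCostDiffSketch(4, 0)); // prints -3
  return 0;
}
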
switch (ShuffleOrOp) {
@@ -7062,7 +7618,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy);
- SmallVector<int> InsertMask(NumElts, UndefMaskElem);
+ SmallVector<int> InsertMask(NumElts, PoisonMaskElem);
unsigned OffsetBeg = *getInsertIndex(VL.front());
unsigned OffsetEnd = OffsetBeg;
InsertMask[OffsetBeg] = 0;
@@ -7099,13 +7655,13 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
SmallVector<int> Mask;
if (!E->ReorderIndices.empty()) {
inversePermutation(E->ReorderIndices, Mask);
- Mask.append(InsertVecSz - Mask.size(), UndefMaskElem);
+ Mask.append(InsertVecSz - Mask.size(), PoisonMaskElem);
} else {
- Mask.assign(VecSz, UndefMaskElem);
+ Mask.assign(VecSz, PoisonMaskElem);
std::iota(Mask.begin(), std::next(Mask.begin(), InsertVecSz), 0);
}
bool IsIdentity = true;
- SmallVector<int> PrevMask(InsertVecSz, UndefMaskElem);
+ SmallVector<int> PrevMask(InsertVecSz, PoisonMaskElem);
Mask.swap(PrevMask);
for (unsigned I = 0; I < NumScalars; ++I) {
unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]);
@@ -7148,14 +7704,14 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
InsertVecTy);
} else {
for (unsigned I = 0, End = OffsetBeg - Offset; I < End; ++I)
- Mask[I] = InMask.test(I) ? UndefMaskElem : I;
+ Mask[I] = InMask.test(I) ? PoisonMaskElem : I;
for (unsigned I = OffsetBeg - Offset, End = OffsetEnd - Offset;
I <= End; ++I)
- if (Mask[I] != UndefMaskElem)
+ if (Mask[I] != PoisonMaskElem)
Mask[I] = I + VecSz;
for (unsigned I = OffsetEnd + 1 - Offset; I < VecSz; ++I)
Mask[I] =
- ((I >= InMask.size()) || InMask.test(I)) ? UndefMaskElem : I;
+ ((I >= InMask.size()) || InMask.test(I)) ? PoisonMaskElem : I;
Cost += TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, InsertVecTy, Mask);
}
}
@@ -7422,11 +7978,11 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
VecCost +=
TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy, CostKind);
} else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) {
- VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy,
- Builder.getInt1Ty(),
+ auto *MaskTy = FixedVectorType::get(Builder.getInt1Ty(), VL.size());
+ VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy,
CI0->getPredicate(), CostKind, VL0);
VecCost += TTI->getCmpSelInstrCost(
- E->getOpcode(), ScalarTy, Builder.getInt1Ty(),
+ E->getOpcode(), VecTy, MaskTy,
cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind,
E->getAltOp());
} else {
@@ -7615,7 +8171,7 @@ InstructionCost BoUpSLP::getSpillCost() const {
unsigned BundleWidth = VectorizableTree.front()->Scalars.size();
InstructionCost Cost = 0;
- SmallPtrSet<Instruction*, 4> LiveValues;
+ SmallPtrSet<Instruction *, 4> LiveValues;
Instruction *PrevInst = nullptr;
// The entries in VectorizableTree are not necessarily ordered by their
@@ -7626,6 +8182,8 @@ InstructionCost BoUpSLP::getSpillCost() const {
// are grouped together. Using dominance ensures a deterministic order.
SmallVector<Instruction *, 16> OrderedScalars;
for (const auto &TEPtr : VectorizableTree) {
+ if (TEPtr->State != TreeEntry::Vectorize)
+ continue;
Instruction *Inst = dyn_cast<Instruction>(TEPtr->Scalars[0]);
if (!Inst)
continue;
@@ -7639,7 +8197,7 @@ InstructionCost BoUpSLP::getSpillCost() const {
assert((NodeA == NodeB) == (NodeA->getDFSNumIn() == NodeB->getDFSNumIn()) &&
"Different nodes should have different DFS numbers");
if (NodeA != NodeB)
- return NodeA->getDFSNumIn() < NodeB->getDFSNumIn();
+ return NodeA->getDFSNumIn() > NodeB->getDFSNumIn();
return B->comesBefore(A);
});
@@ -7698,7 +8256,7 @@ InstructionCost BoUpSLP::getSpillCost() const {
};
// Debug information does not impact spill cost.
- if (isa<CallInst>(&*PrevInstIt) && !NoCallIntrinsic(&*PrevInstIt) &&
+ if (isa<CallBase>(&*PrevInstIt) && !NoCallIntrinsic(&*PrevInstIt) &&
&*PrevInstIt != PrevInst)
NumCalls++;
@@ -7706,7 +8264,7 @@ InstructionCost BoUpSLP::getSpillCost() const {
}
if (NumCalls) {
- SmallVector<Type*, 4> V;
+ SmallVector<Type *, 4> V;
for (auto *II : LiveValues) {
auto *ScalarTy = II->getType();
if (auto *VectorTy = dyn_cast<FixedVectorType>(ScalarTy))
@@ -7797,8 +8355,8 @@ static T *performExtractsShuffleAction(
ResizeAction(ShuffleMask.begin()->first, Mask, /*ForSingleMask=*/false);
SmallBitVector IsBasePoison = isUndefVector<true>(Base, UseMask);
for (unsigned Idx = 0, VF = Mask.size(); Idx < VF; ++Idx) {
- if (Mask[Idx] == UndefMaskElem)
- Mask[Idx] = IsBasePoison.test(Idx) ? UndefMaskElem : Idx;
+ if (Mask[Idx] == PoisonMaskElem)
+ Mask[Idx] = IsBasePoison.test(Idx) ? PoisonMaskElem : Idx;
else
Mask[Idx] = (Res.second ? Idx : Mask[Idx]) + VF;
}
@@ -7827,8 +8385,8 @@ static T *performExtractsShuffleAction(
// can shuffle them directly.
ArrayRef<int> SecMask = VMIt->second;
for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) {
- if (SecMask[I] != UndefMaskElem) {
- assert(Mask[I] == UndefMaskElem && "Multiple uses of scalars.");
+ if (SecMask[I] != PoisonMaskElem) {
+ assert(Mask[I] == PoisonMaskElem && "Multiple uses of scalars.");
Mask[I] = SecMask[I] + Vec1VF;
}
}
@@ -7841,12 +8399,12 @@ static T *performExtractsShuffleAction(
ResizeAction(VMIt->first, VMIt->second, /*ForSingleMask=*/false);
ArrayRef<int> SecMask = VMIt->second;
for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) {
- if (Mask[I] != UndefMaskElem) {
- assert(SecMask[I] == UndefMaskElem && "Multiple uses of scalars.");
+ if (Mask[I] != PoisonMaskElem) {
+ assert(SecMask[I] == PoisonMaskElem && "Multiple uses of scalars.");
if (Res1.second)
Mask[I] = I;
- } else if (SecMask[I] != UndefMaskElem) {
- assert(Mask[I] == UndefMaskElem && "Multiple uses of scalars.");
+ } else if (SecMask[I] != PoisonMaskElem) {
+ assert(Mask[I] == PoisonMaskElem && "Multiple uses of scalars.");
Mask[I] = (Res2.second ? I : SecMask[I]) + VF;
}
}
@@ -7863,11 +8421,11 @@ static T *performExtractsShuffleAction(
ResizeAction(VMIt->first, VMIt->second, /*ForSingleMask=*/false);
ArrayRef<int> SecMask = VMIt->second;
for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) {
- if (SecMask[I] != UndefMaskElem) {
- assert((Mask[I] == UndefMaskElem || IsBaseNotUndef) &&
+ if (SecMask[I] != PoisonMaskElem) {
+ assert((Mask[I] == PoisonMaskElem || IsBaseNotUndef) &&
"Multiple uses of scalars.");
Mask[I] = (Res.second ? I : SecMask[I]) + VF;
- } else if (Mask[I] != UndefMaskElem) {
+ } else if (Mask[I] != PoisonMaskElem) {
Mask[I] = I;
}
}
@@ -7877,12 +8435,23 @@ static T *performExtractsShuffleAction(
}
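
The two-source mask merging done by performExtractsShuffleAction above follows a simple rule: a lane defined by the incoming per-vector mask is redirected to the second source (its index offset by that vector's width), lanes already defined keep reading the first source, and no lane may be claimed by both. A standalone sketch of that merge step, with kPoisonElem standing in for PoisonMaskElem:

#include <cassert>
#include <vector>

constexpr int kPoisonElem = -1; // mirrors LLVM's PoisonMaskElem

// Fold SecMask (which selects from a second source of width VF) into Mask,
// which already selects from the first source.
static void mergeSecondSource(std::vector<int> &Mask,
                              const std::vector<int> &SecMask, int VF) {
  for (size_t I = 0, E = Mask.size(); I < E; ++I) {
    if (SecMask[I] == kPoisonElem)
      continue;
    assert(Mask[I] == kPoisonElem && "Multiple uses of scalars.");
    Mask[I] = SecMask[I] + VF; // lane now reads from the second source
  }
}
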
InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
+ // Build a map for gathered scalars to the nodes where they are used.
+ ValueToGatherNodes.clear();
+ for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) {
+ if (EntryPtr->State != TreeEntry::NeedToGather)
+ continue;
+ for (Value *V : EntryPtr->Scalars)
+ if (!isConstant(V))
+ ValueToGatherNodes.try_emplace(V).first->getSecond().insert(
+ EntryPtr.get());
+ }
InstructionCost Cost = 0;
LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size "
<< VectorizableTree.size() << ".\n");
unsigned BundleWidth = VectorizableTree[0]->Scalars.size();
+ SmallPtrSet<Value *, 4> CheckedExtracts;
for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) {
TreeEntry &TE = *VectorizableTree[I];
if (TE.State == TreeEntry::NeedToGather) {
@@ -7898,7 +8467,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
}
}
- InstructionCost C = getEntryCost(&TE, VectorizedVals);
+ InstructionCost C = getEntryCost(&TE, VectorizedVals, CheckedExtracts);
Cost += C;
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C
<< " for bundle that starts with " << *TE.Scalars[0]
@@ -7951,7 +8520,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
(void)ShuffleMasks.emplace_back();
SmallVectorImpl<int> &Mask = ShuffleMasks.back()[ScalarTE];
if (Mask.empty())
- Mask.assign(FTy->getNumElements(), UndefMaskElem);
+ Mask.assign(FTy->getNumElements(), PoisonMaskElem);
// Find the insertvector, vectorized in tree, if any.
Value *Base = VU;
while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) {
@@ -7965,7 +8534,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
do {
IEBase = cast<InsertElementInst>(Base);
int Idx = *getInsertIndex(IEBase);
- assert(Mask[Idx] == UndefMaskElem &&
+ assert(Mask[Idx] == PoisonMaskElem &&
"InsertElementInstruction used already.");
Mask[Idx] = Idx;
Base = IEBase->getOperand(0);
@@ -7985,7 +8554,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
int InIdx = *InsertIdx;
SmallVectorImpl<int> &Mask = ShuffleMasks[VecId][ScalarTE];
if (Mask.empty())
- Mask.assign(FTy->getNumElements(), UndefMaskElem);
+ Mask.assign(FTy->getNumElements(), PoisonMaskElem);
Mask[InIdx] = EU.Lane;
DemandedElts[VecId].setBit(InIdx);
continue;
@@ -8024,7 +8593,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
(all_of(Mask,
[VF](int Idx) { return Idx < 2 * static_cast<int>(VF); }) &&
!ShuffleVectorInst::isIdentityMask(Mask)))) {
- SmallVector<int> OrigMask(VecVF, UndefMaskElem);
+ SmallVector<int> OrigMask(VecVF, PoisonMaskElem);
std::copy(Mask.begin(), std::next(Mask.begin(), std::min(VF, VecVF)),
OrigMask.begin());
C = TTI->getShuffleCost(
@@ -8110,17 +8679,23 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
// No need to check for the topmost gather node.
if (TE == VectorizableTree.front().get())
return std::nullopt;
- Mask.assign(VL.size(), UndefMaskElem);
+ Mask.assign(VL.size(), PoisonMaskElem);
assert(TE->UserTreeIndices.size() == 1 &&
"Expected only single user of the gather node.");
// TODO: currently checking only for Scalars in the tree entry, need to count
// reused elements too for better cost estimation.
Instruction &UserInst =
getLastInstructionInBundle(TE->UserTreeIndices.front().UserTE);
- auto *PHI = dyn_cast<PHINode>(&UserInst);
- auto *NodeUI = DT->getNode(
- PHI ? PHI->getIncomingBlock(TE->UserTreeIndices.front().EdgeIdx)
- : UserInst.getParent());
+ BasicBlock *ParentBB = nullptr;
+ // Main node of PHI entries keeps the correct order of operands/incoming
+ // blocks.
+ if (auto *PHI =
+ dyn_cast<PHINode>(TE->UserTreeIndices.front().UserTE->getMainOp())) {
+ ParentBB = PHI->getIncomingBlock(TE->UserTreeIndices.front().EdgeIdx);
+ } else {
+ ParentBB = UserInst.getParent();
+ }
+ auto *NodeUI = DT->getNode(ParentBB);
assert(NodeUI && "Should only process reachable instructions");
SmallPtrSet<Value *, 4> GatheredScalars(VL.begin(), VL.end());
auto CheckOrdering = [&](Instruction *LastEI) {
@@ -8147,45 +8722,6 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
return false;
return true;
};
- // Build a lists of values to tree entries.
- DenseMap<Value *, SmallPtrSet<const TreeEntry *, 4>> ValueToTEs;
- for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) {
- if (EntryPtr.get() == TE)
- continue;
- if (EntryPtr->State != TreeEntry::NeedToGather)
- continue;
- if (!any_of(EntryPtr->Scalars, [&GatheredScalars](Value *V) {
- return GatheredScalars.contains(V);
- }))
- continue;
- assert(EntryPtr->UserTreeIndices.size() == 1 &&
- "Expected only single user of the gather node.");
- Instruction &EntryUserInst =
- getLastInstructionInBundle(EntryPtr->UserTreeIndices.front().UserTE);
- if (&UserInst == &EntryUserInst) {
- // If 2 gathers are operands of the same entry, compare operands indices,
- // use the earlier one as the base.
- if (TE->UserTreeIndices.front().UserTE ==
- EntryPtr->UserTreeIndices.front().UserTE &&
- TE->UserTreeIndices.front().EdgeIdx <
- EntryPtr->UserTreeIndices.front().EdgeIdx)
- continue;
- }
- // Check if the user node of the TE comes after user node of EntryPtr,
- // otherwise EntryPtr depends on TE.
- auto *EntryPHI = dyn_cast<PHINode>(&EntryUserInst);
- auto *EntryI =
- EntryPHI
- ? EntryPHI
- ->getIncomingBlock(EntryPtr->UserTreeIndices.front().EdgeIdx)
- ->getTerminator()
- : &EntryUserInst;
- if (!CheckOrdering(EntryI))
- continue;
- for (Value *V : EntryPtr->Scalars)
- if (!isConstant(V))
- ValueToTEs.try_emplace(V).first->getSecond().insert(EntryPtr.get());
- }
// Find all tree entries used by the gathered values. If no common entries
// found - not a shuffle.
// Here we build a set of tree nodes for each gathered value and trying to
@@ -8195,16 +8731,58 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
// have a permutation of 2 input vectors.
SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs;
DenseMap<Value *, int> UsedValuesEntry;
- for (Value *V : TE->Scalars) {
+ for (Value *V : VL) {
if (isConstant(V))
continue;
// Build a list of tree entries where V is used.
SmallPtrSet<const TreeEntry *, 4> VToTEs;
- auto It = ValueToTEs.find(V);
- if (It != ValueToTEs.end())
- VToTEs = It->second;
- if (const TreeEntry *VTE = getTreeEntry(V))
+ for (const TreeEntry *TEPtr : ValueToGatherNodes.find(V)->second) {
+ if (TEPtr == TE)
+ continue;
+ assert(any_of(TEPtr->Scalars,
+ [&](Value *V) { return GatheredScalars.contains(V); }) &&
+ "Must contain at least single gathered value.");
+ assert(TEPtr->UserTreeIndices.size() == 1 &&
+ "Expected only single user of the gather node.");
+ PHINode *EntryPHI =
+ dyn_cast<PHINode>(TEPtr->UserTreeIndices.front().UserTE->getMainOp());
+ Instruction *EntryUserInst =
+ EntryPHI ? nullptr
+ : &getLastInstructionInBundle(
+ TEPtr->UserTreeIndices.front().UserTE);
+ if (&UserInst == EntryUserInst) {
+ assert(!EntryPHI && "Unexpected phi node entry.");
+ // If 2 gathers are operands of the same entry, compare operands
+ // indices, use the earlier one as the base.
+ if (TE->UserTreeIndices.front().UserTE ==
+ TEPtr->UserTreeIndices.front().UserTE &&
+ TE->UserTreeIndices.front().EdgeIdx <
+ TEPtr->UserTreeIndices.front().EdgeIdx)
+ continue;
+ }
+ // Check if the user node of the TE comes after user node of EntryPtr,
+ // otherwise EntryPtr depends on TE.
+ auto *EntryI =
+ EntryPHI
+ ? EntryPHI
+ ->getIncomingBlock(TEPtr->UserTreeIndices.front().EdgeIdx)
+ ->getTerminator()
+ : EntryUserInst;
+ if ((ParentBB != EntryI->getParent() ||
+ TE->UserTreeIndices.front().EdgeIdx <
+ TEPtr->UserTreeIndices.front().EdgeIdx ||
+ TE->UserTreeIndices.front().UserTE !=
+ TEPtr->UserTreeIndices.front().UserTE) &&
+ !CheckOrdering(EntryI))
+ continue;
+ VToTEs.insert(TEPtr);
+ }
+ if (const TreeEntry *VTE = getTreeEntry(V)) {
+ Instruction &EntryUserInst = getLastInstructionInBundle(VTE);
+ if (&EntryUserInst == &UserInst || !CheckOrdering(&EntryUserInst))
+ continue;
VToTEs.insert(VTE);
+ }
if (VToTEs.empty())
continue;
if (UsedTEs.empty()) {
@@ -8260,13 +8838,13 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
auto *It = find_if(FirstEntries, [=](const TreeEntry *EntryPtr) {
return EntryPtr->isSame(VL) || EntryPtr->isSame(TE->Scalars);
});
- if (It != FirstEntries.end()) {
+ if (It != FirstEntries.end() && (*It)->getVectorFactor() == VL.size()) {
Entries.push_back(*It);
std::iota(Mask.begin(), Mask.end(), 0);
// Clear undef scalars.
for (int I = 0, Sz = VL.size(); I < Sz; ++I)
- if (isa<PoisonValue>(TE->Scalars[I]))
- Mask[I] = UndefMaskElem;
+ if (isa<PoisonValue>(VL[I]))
+ Mask[I] = PoisonMaskElem;
return TargetTransformInfo::SK_PermuteSingleSrc;
}
// No perfect match, just shuffle, so choose the first tree node from the
@@ -8302,10 +8880,18 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
break;
}
}
- // No 2 source vectors with the same vector factor - give up and do regular
- // gather.
- if (Entries.empty())
- return std::nullopt;
+ // No 2 source vectors with the same vector factor - just choose 2 with max
+ // index.
+ if (Entries.empty()) {
+ Entries.push_back(
+ *std::max_element(UsedTEs.front().begin(), UsedTEs.front().end(),
+ [](const TreeEntry *TE1, const TreeEntry *TE2) {
+ return TE1->Idx < TE2->Idx;
+ }));
+ Entries.push_back(SecondEntries.front());
+ VF = std::max(Entries.front()->getVectorFactor(),
+ Entries.back()->getVectorFactor());
+ }
}
bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, UndefValue::classof);
@@ -8427,19 +9013,8 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
return std::nullopt;
}
-InstructionCost BoUpSLP::getGatherCost(FixedVectorType *Ty,
- const APInt &ShuffledIndices,
- bool NeedToShuffle) const {
- TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
- InstructionCost Cost =
- TTI->getScalarizationOverhead(Ty, ~ShuffledIndices, /*Insert*/ true,
- /*Extract*/ false, CostKind);
- if (NeedToShuffle)
- Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty);
- return Cost;
-}
-
-InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const {
+InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL,
+ bool ForPoisonSrc) const {
// Find the type of the operands in VL.
Type *ScalarTy = VL[0]->getType();
if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
@@ -8451,20 +9026,36 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const {
// shuffle candidates.
APInt ShuffledElements = APInt::getZero(VL.size());
DenseSet<Value *> UniqueElements;
- // Iterate in reverse order to consider insert elements with the high cost.
- for (unsigned I = VL.size(); I > 0; --I) {
- unsigned Idx = I - 1;
+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost Cost;
+ auto EstimateInsertCost = [&](unsigned I, Value *V) {
+ if (!ForPoisonSrc)
+ Cost +=
+ TTI->getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind,
+ I, Constant::getNullValue(VecTy), V);
+ };
+ for (unsigned I = 0, E = VL.size(); I < E; ++I) {
+ Value *V = VL[I];
// No need to shuffle duplicates for constants.
- if (isConstant(VL[Idx])) {
- ShuffledElements.setBit(Idx);
+ if ((ForPoisonSrc && isConstant(V)) || isa<UndefValue>(V)) {
+ ShuffledElements.setBit(I);
continue;
}
- if (!UniqueElements.insert(VL[Idx]).second) {
+ if (!UniqueElements.insert(V).second) {
DuplicateNonConst = true;
- ShuffledElements.setBit(Idx);
+ ShuffledElements.setBit(I);
+ continue;
}
+ EstimateInsertCost(I, V);
}
- return getGatherCost(VecTy, ShuffledElements, DuplicateNonConst);
+ if (ForPoisonSrc)
+ Cost =
+ TTI->getScalarizationOverhead(VecTy, ~ShuffledElements, /*Insert*/ true,
+ /*Extract*/ false, CostKind);
+ if (DuplicateNonConst)
+ Cost +=
+ TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
+ return Cost;
}
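
getGatherCost above charges an insert for every unique non-constant scalar and a single extra permute when duplicates have to be re-broadcast within the vector. The counting scheme can be sketched with a flat per-insert cost; the real implementation gets its numbers from TTI, so everything below is illustrative only:

#include <string>
#include <unordered_set>
#include <vector>

struct GatherEstimate {
  int Inserts = 0;           // one per unique non-constant scalar
  bool NeedsShuffle = false; // set when duplicates force an extra permute
};

static GatherEstimate
estimateGatherSketch(const std::vector<std::string> &Scalars,
                     const std::unordered_set<std::string> &Constants) {
  GatherEstimate Est;
  std::unordered_set<std::string> Seen;
  for (const std::string &S : Scalars) {
    if (Constants.count(S))
      continue;                // constants fold into the initial build vector
    if (!Seen.insert(S).second) {
      Est.NeedsShuffle = true; // repeated value: reuse an existing lane
      continue;
    }
    ++Est.Inserts;
  }
  return Est;
}
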
// Perform operand reordering on the instructions in VL and return the reordered
@@ -8483,6 +9074,9 @@ void BoUpSLP::reorderInputsAccordingToOpcode(
}
Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
+ auto &Res = EntryToLastInstruction.FindAndConstruct(E);
+ if (Res.second)
+ return *Res.second;
// Get the basic block this bundle is in. All instructions in the bundle
// should be in this block (except for extractelement-like instructions with
// constant indices).
@@ -8497,7 +9091,7 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
isVectorLikeInstWithConstOps(I);
}));
- auto &&FindLastInst = [E, Front, this, &BB]() {
+ auto FindLastInst = [&]() {
Instruction *LastInst = Front;
for (Value *V : E->Scalars) {
auto *I = dyn_cast<Instruction>(V);
@@ -8508,9 +9102,11 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
LastInst = I;
continue;
}
- assert(isVectorLikeInstWithConstOps(LastInst) &&
- isVectorLikeInstWithConstOps(I) &&
- "Expected vector-like insts only.");
+ assert(((E->getOpcode() == Instruction::GetElementPtr &&
+ !isa<GetElementPtrInst>(I)) ||
+ (isVectorLikeInstWithConstOps(LastInst) &&
+ isVectorLikeInstWithConstOps(I))) &&
+ "Expected vector-like or non-GEP in GEP node insts only.");
if (!DT->isReachableFromEntry(LastInst->getParent())) {
LastInst = I;
continue;
@@ -8531,7 +9127,7 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
return LastInst;
};
- auto &&FindFirstInst = [E, Front, this]() {
+ auto FindFirstInst = [&]() {
Instruction *FirstInst = Front;
for (Value *V : E->Scalars) {
auto *I = dyn_cast<Instruction>(V);
@@ -8542,9 +9138,11 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
FirstInst = I;
continue;
}
- assert(isVectorLikeInstWithConstOps(FirstInst) &&
- isVectorLikeInstWithConstOps(I) &&
- "Expected vector-like insts only.");
+ assert(((E->getOpcode() == Instruction::GetElementPtr &&
+ !isa<GetElementPtrInst>(I)) ||
+ (isVectorLikeInstWithConstOps(FirstInst) &&
+ isVectorLikeInstWithConstOps(I))) &&
+ "Expected vector-like or non-GEP in GEP node insts only.");
if (!DT->isReachableFromEntry(FirstInst->getParent())) {
FirstInst = I;
continue;
@@ -8566,22 +9164,23 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
// Set the insert point to the beginning of the basic block if the entry
// should not be scheduled.
- if (E->State != TreeEntry::NeedToGather &&
- (doesNotNeedToSchedule(E->Scalars) ||
+ if (doesNotNeedToSchedule(E->Scalars) ||
+ (E->State != TreeEntry::NeedToGather &&
all_of(E->Scalars, isVectorLikeInstWithConstOps))) {
- Instruction *InsertInst;
- if (all_of(E->Scalars, [](Value *V) {
+ if ((E->getOpcode() == Instruction::GetElementPtr &&
+ any_of(E->Scalars,
+ [](Value *V) {
+ return !isa<GetElementPtrInst>(V) && isa<Instruction>(V);
+ })) ||
+ all_of(E->Scalars, [](Value *V) {
return !isVectorLikeInstWithConstOps(V) && isUsedOutsideBlock(V);
}))
- InsertInst = FindLastInst();
+ Res.second = FindLastInst();
else
- InsertInst = FindFirstInst();
- return *InsertInst;
+ Res.second = FindFirstInst();
+ return *Res.second;
}
- // The last instruction in the bundle in program order.
- Instruction *LastInst = nullptr;
-
// Find the last instruction. The common case should be that BB has been
// scheduled, and the last instruction is VL.back(). So we start with
// VL.back() and iterate over schedule data until we reach the end of the
@@ -8594,7 +9193,7 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
if (Bundle && Bundle->isPartOfBundle())
for (; Bundle; Bundle = Bundle->NextInBundle)
if (Bundle->OpValue == Bundle->Inst)
- LastInst = Bundle->Inst;
+ Res.second = Bundle->Inst;
}
// LastInst can still be null at this point if there's either not an entry
@@ -8615,15 +9214,15 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) {
// not ideal. However, this should be exceedingly rare since it requires that
// we both exit early from buildTree_rec and that the bundle be out-of-order
// (causing us to iterate all the way to the end of the block).
- if (!LastInst)
- LastInst = FindLastInst();
- assert(LastInst && "Failed to find last instruction in bundle");
- return *LastInst;
+ if (!Res.second)
+ Res.second = FindLastInst();
+ assert(Res.second && "Failed to find last instruction in bundle");
+ return *Res.second;
}
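
getLastInstructionInBundle is now memoized: the first call for a tree entry computes and caches the instruction, and later calls return the cached pointer, which is what the Res.second pattern above implements on top of DenseMap::FindAndConstruct. A minimal cache of the same shape using std::unordered_map (the type names here are placeholders, not the patch's types):

#include <unordered_map>

struct EntrySketch;       // placeholder for BoUpSLP::TreeEntry
struct InstructionSketch; // placeholder for llvm::Instruction

class LastInstCacheSketch {
  std::unordered_map<const EntrySketch *, InstructionSketch *> Cache;

public:
  template <typename ComputeFn>
  InstructionSketch &get(const EntrySketch *E, ComputeFn Compute) {
    InstructionSketch *&Slot = Cache[E]; // nullptr on first lookup
    if (!Slot)
      Slot = Compute(E); // the expensive scan runs at most once per entry
    return *Slot;
  }
};
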
void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) {
auto *Front = E->getMainOp();
- Instruction *LastInst = EntryToLastInstruction.lookup(E);
+ Instruction *LastInst = &getLastInstructionInBundle(E);
assert(LastInst && "Failed to find last instruction in bundle");
// If the instruction is PHI, set the insert point after all the PHIs.
bool IsPHI = isa<PHINode>(LastInst);
@@ -8641,7 +9240,7 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) {
Builder.SetCurrentDebugLocation(Front->getDebugLoc());
}
-Value *BoUpSLP::gather(ArrayRef<Value *> VL) {
+Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) {
// List of instructions/lanes from current block and/or the blocks which are
// part of the current loop. These instructions will be inserted at the end to
// make it possible to optimize loops and hoist invariant instructions out of
@@ -8658,7 +9257,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) {
for (int I = 0, E = VL.size(); I < E; ++I) {
if (auto *Inst = dyn_cast<Instruction>(VL[I]))
if ((CheckPredecessor(Inst->getParent(), Builder.GetInsertBlock()) ||
- getTreeEntry(Inst) || (L && (L->contains(Inst)))) &&
+ getTreeEntry(Inst) ||
+ (L && (!Root || L->isLoopInvariant(Root)) && L->contains(Inst))) &&
PostponedIndices.insert(I).second)
PostponedInsts.emplace_back(Inst, I);
}
@@ -8681,7 +9281,7 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) {
Value *Val0 =
isa<StoreInst>(VL[0]) ? cast<StoreInst>(VL[0])->getValueOperand() : VL[0];
FixedVectorType *VecTy = FixedVectorType::get(Val0->getType(), VL.size());
- Value *Vec = PoisonValue::get(VecTy);
+ Value *Vec = Root ? Root : PoisonValue::get(VecTy);
SmallVector<int> NonConsts;
// Insert constant values at first.
for (int I = 0, E = VL.size(); I < E; ++I) {
@@ -8691,6 +9291,18 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) {
NonConsts.push_back(I);
continue;
}
+ if (Root) {
+ if (!isa<UndefValue>(VL[I])) {
+ NonConsts.push_back(I);
+ continue;
+ }
+ if (isa<PoisonValue>(VL[I]))
+ continue;
+ if (auto *SV = dyn_cast<ShuffleVectorInst>(Root)) {
+ if (SV->getMaskValue(I) == PoisonMaskElem)
+ continue;
+ }
+ }
Vec = CreateInsertElement(Vec, VL[I], I);
}
// Insert non-constant values.
@@ -8789,6 +9401,10 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
}
return Vec;
}
+ Value *createIdentity(Value *V) { return V; }
+ Value *createPoison(Type *Ty, unsigned VF) {
+ return PoisonValue::get(FixedVectorType::get(Ty, VF));
+ }
/// Resizes 2 input vectors to match their sizes, if they are not equal
/// yet. The smallest vector is resized to the size of the larger vector.
void resizeToMatch(Value *&V1, Value *&V2) {
@@ -8798,7 +9414,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
int V2VF = cast<FixedVectorType>(V2->getType())->getNumElements();
int VF = std::max(V1VF, V2VF);
int MinVF = std::min(V1VF, V2VF);
- SmallVector<int> IdentityMask(VF, UndefMaskElem);
+ SmallVector<int> IdentityMask(VF, PoisonMaskElem);
std::iota(IdentityMask.begin(), std::next(IdentityMask.begin(), MinVF),
0);
Value *&Op = MinVF == V1VF ? V1 : V2;
@@ -8821,7 +9437,8 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
assert(V1 && "Expected at least one vector value.");
ShuffleIRBuilder ShuffleBuilder(Builder, R.GatherShuffleExtractSeq,
R.CSEBlocks);
- return BaseShuffleAnalysis::createShuffle(V1, V2, Mask, ShuffleBuilder);
+ return BaseShuffleAnalysis::createShuffle<Value *>(V1, V2, Mask,
+ ShuffleBuilder);
}
/// Transforms mask \p CommonMask per given \p Mask to make proper set after
@@ -8829,7 +9446,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
static void transformMaskAfterShuffle(MutableArrayRef<int> CommonMask,
ArrayRef<int> Mask) {
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
- if (Mask[Idx] != UndefMaskElem)
+ if (Mask[Idx] != PoisonMaskElem)
CommonMask[Idx] = Idx;
}
@@ -8837,6 +9454,39 @@ public:
ShuffleInstructionBuilder(IRBuilderBase &Builder, BoUpSLP &R)
: Builder(Builder), R(R) {}
+ /// Adjusts extractelements after reusing them.
+ Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask) {
+ Value *VecBase = nullptr;
+ for (int I = 0, Sz = Mask.size(); I < Sz; ++I) {
+ int Idx = Mask[I];
+ if (Idx == PoisonMaskElem)
+ continue;
+ auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
+ VecBase = EI->getVectorOperand();
+ // If its only use is vectorized, the extractelement itself can be
+ // deleted.
+ if (!EI->hasOneUse() || any_of(EI->users(), [&](User *U) {
+ return !R.ScalarToTreeEntry.count(U);
+ }))
+ continue;
+ R.eraseInstruction(EI);
+ }
+ return VecBase;
+ }
+ /// Checks if the specified entry \p E needs to be delayed because of its
+ /// dependency nodes.
+ Value *needToDelay(const TreeEntry *E, ArrayRef<const TreeEntry *> Deps) {
+ // No need to delay emission if all deps are ready.
+ if (all_of(Deps, [](const TreeEntry *TE) { return TE->VectorizedValue; }))
+ return nullptr;
+ // Postpone gather emission; it will be emitted after the end of the
+ // process to keep the correct order.
+ auto *VecTy = FixedVectorType::get(E->Scalars.front()->getType(),
+ E->getVectorFactor());
+ return Builder.CreateAlignedLoad(
+ VecTy, PoisonValue::get(PointerType::getUnqual(VecTy->getContext())),
+ MaybeAlign());
+ }
/// Adds 2 input vectors and the mask for their shuffling.
void add(Value *V1, Value *V2, ArrayRef<int> Mask) {
assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors.");
@@ -8849,15 +9499,15 @@ public:
Value *Vec = InVectors.front();
if (InVectors.size() == 2) {
Vec = createShuffle(Vec, InVectors.back(), CommonMask);
- transformMaskAfterShuffle(CommonMask, Mask);
+ transformMaskAfterShuffle(CommonMask, CommonMask);
} else if (cast<FixedVectorType>(Vec->getType())->getNumElements() !=
Mask.size()) {
Vec = createShuffle(Vec, nullptr, CommonMask);
- transformMaskAfterShuffle(CommonMask, Mask);
+ transformMaskAfterShuffle(CommonMask, CommonMask);
}
V1 = createShuffle(V1, V2, Mask);
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
- if (Mask[Idx] != UndefMaskElem)
+ if (Mask[Idx] != PoisonMaskElem)
CommonMask[Idx] = Idx + Sz;
InVectors.front() = Vec;
if (InVectors.size() == 2)
@@ -8870,7 +9520,7 @@ public:
if (InVectors.empty()) {
if (!isa<FixedVectorType>(V1->getType())) {
V1 = createShuffle(V1, nullptr, CommonMask);
- CommonMask.assign(Mask.size(), UndefMaskElem);
+ CommonMask.assign(Mask.size(), PoisonMaskElem);
transformMaskAfterShuffle(CommonMask, Mask);
}
InVectors.push_back(V1);
@@ -8892,7 +9542,7 @@ public:
transformMaskAfterShuffle(CommonMask, CommonMask);
}
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
- if (CommonMask[Idx] == UndefMaskElem && Mask[Idx] != UndefMaskElem)
+ if (CommonMask[Idx] == PoisonMaskElem && Mask[Idx] != PoisonMaskElem)
CommonMask[Idx] =
V->getType() != V1->getType()
? Idx + Sz
@@ -8910,7 +9560,7 @@ public:
// Check if second vector is required if the used elements are already
// used from the first one.
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
- if (Mask[Idx] != UndefMaskElem && CommonMask[Idx] == UndefMaskElem) {
+ if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem) {
InVectors.push_back(V1);
break;
}
@@ -8919,7 +9569,7 @@ public:
if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType()))
VF = FTy->getNumElements();
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
- if (Mask[Idx] != UndefMaskElem && CommonMask[Idx] == UndefMaskElem)
+ if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
}
/// Adds another one input vector and the mask for the shuffling.
@@ -8928,17 +9578,46 @@ public:
inversePermutation(Order, NewMask);
add(V1, NewMask);
}
+ Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) {
+ return R.gather(VL, Root);
+ }
+ Value *createFreeze(Value *V) { return Builder.CreateFreeze(V); }
/// Finalize emission of the shuffles.
+ /// \param Action the action (if any) to be performed before the final
+ /// application of the \p ExtMask mask.
Value *
- finalize(ArrayRef<int> ExtMask = std::nullopt) {
+ finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
+ function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
IsFinalized = true;
+ if (Action) {
+ Value *Vec = InVectors.front();
+ if (InVectors.size() == 2) {
+ Vec = createShuffle(Vec, InVectors.back(), CommonMask);
+ InVectors.pop_back();
+ } else {
+ Vec = createShuffle(Vec, nullptr, CommonMask);
+ }
+ for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
+ if (CommonMask[Idx] != PoisonMaskElem)
+ CommonMask[Idx] = Idx;
+ assert(VF > 0 &&
+ "Expected vector length for the final value before action.");
+ unsigned VecVF = cast<FixedVectorType>(Vec->getType())->getNumElements();
+ if (VecVF < VF) {
+ SmallVector<int> ResizeMask(VF, PoisonMaskElem);
+ std::iota(ResizeMask.begin(), std::next(ResizeMask.begin(), VecVF), 0);
+ Vec = createShuffle(Vec, nullptr, ResizeMask);
+ }
+ Action(Vec, CommonMask);
+ InVectors.front() = Vec;
+ }
if (!ExtMask.empty()) {
if (CommonMask.empty()) {
CommonMask.assign(ExtMask.begin(), ExtMask.end());
} else {
- SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem);
+ SmallVector<int> NewMask(ExtMask.size(), PoisonMaskElem);
for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) {
- if (ExtMask[I] == UndefMaskElem)
+ if (ExtMask[I] == PoisonMaskElem)
continue;
NewMask[I] = CommonMask[ExtMask[I]];
}
@@ -9009,18 +9688,18 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) {
// ... (use %2)
// %shuffle = shuffle <2 x> %2, poison, <2 x> {2, 0}
// br %block
- SmallVector<int> UniqueIdxs(VF, UndefMaskElem);
+ SmallVector<int> UniqueIdxs(VF, PoisonMaskElem);
SmallSet<int, 4> UsedIdxs;
int Pos = 0;
for (int Idx : VE->ReuseShuffleIndices) {
- if (Idx != static_cast<int>(VF) && Idx != UndefMaskElem &&
+ if (Idx != static_cast<int>(VF) && Idx != PoisonMaskElem &&
UsedIdxs.insert(Idx).second)
UniqueIdxs[Idx] = Pos;
++Pos;
}
assert(VF >= UsedIdxs.size() && "Expected vectorization factor "
"less than original vector size.");
- UniqueIdxs.append(VF - UsedIdxs.size(), UndefMaskElem);
+ UniqueIdxs.append(VF - UsedIdxs.size(), PoisonMaskElem);
V = FinalShuffle(V, UniqueIdxs);
} else {
assert(VF < cast<FixedVectorType>(V->getType())->getNumElements() &&
@@ -9031,6 +9710,21 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) {
V = FinalShuffle(V, UniformMask);
}
}
+ // Need to update the operand gather node if the operand is actually not a
+ // vectorized node but a buildvector/gather node that matches one of the
+ // vectorized nodes.
+ if (find_if(VE->UserTreeIndices, [&](const EdgeInfo &EI) {
+ return EI.UserTE == E && EI.EdgeIdx == NodeIdx;
+ }) == VE->UserTreeIndices.end()) {
+ auto *It = find_if(
+ VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) {
+ return TE->State == TreeEntry::NeedToGather &&
+ TE->UserTreeIndices.front().UserTE == E &&
+ TE->UserTreeIndices.front().EdgeIdx == NodeIdx;
+ });
+ assert(It != VectorizableTree.end() && "Expected gather node operand.");
+ (*It)->VectorizedValue = V;
+ }
return V;
}
}
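
The UniqueIdxs construction in vectorizeOperand above builds a shuffle mask that records only the first position at which each reused source lane appears, before the value is re-expanded to the requested factor. A standalone sketch of that de-duplication, with kPoisonElem again standing in for PoisonMaskElem:

#include <unordered_set>
#include <vector>

constexpr int kPoisonElem = -1; // mirrors LLVM's PoisonMaskElem

// Given a node's reuse indices (which source lane each output lane repeats)
// and its vector factor VF, build a mask that keeps only the first
// occurrence of each source lane.
static std::vector<int> buildUniqueLaneMask(const std::vector<int> &ReuseIdx,
                                            int VF) {
  std::vector<int> UniqueIdxs(VF, kPoisonElem);
  std::unordered_set<int> UsedIdxs;
  int Pos = 0;
  for (int Idx : ReuseIdx) {
    if (Idx != VF && Idx != kPoisonElem && UsedIdxs.insert(Idx).second)
      UniqueIdxs[Idx] = Pos; // remember where this source lane first appears
    ++Pos;
  }
  return UniqueIdxs;
}
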
@@ -9049,108 +9743,370 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) {
IRBuilder<>::InsertPointGuard Guard(Builder);
if (E->getOpcode() != Instruction::InsertElement &&
E->getOpcode() != Instruction::PHI) {
- Instruction *LastInst = EntryToLastInstruction.lookup(E);
+ Instruction *LastInst = &getLastInstructionInBundle(E);
assert(LastInst && "Failed to find last instruction in bundle");
Builder.SetInsertPoint(LastInst);
}
return vectorizeTree(I->get());
}
-Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
+template <typename BVTy, typename ResTy, typename... Args>
+ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
assert(E->State == TreeEntry::NeedToGather && "Expected gather node.");
unsigned VF = E->getVectorFactor();
- ShuffleInstructionBuilder ShuffleBuilder(Builder, *this);
- SmallVector<Value *> Gathered(
- VF, PoisonValue::get(E->Scalars.front()->getType()));
bool NeedFreeze = false;
- SmallVector<Value *> VL(E->Scalars.begin(), E->Scalars.end());
- // Build a mask out of the redorder indices and reorder scalars per this mask.
+ SmallVector<int> ReuseShuffleIndicies(E->ReuseShuffleIndices.begin(),
+ E->ReuseShuffleIndices.end());
+ SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end());
+ // Build a mask out of the reorder indices and reorder scalars per this
+ // mask.
SmallVector<int> ReorderMask;
inversePermutation(E->ReorderIndices, ReorderMask);
if (!ReorderMask.empty())
- reorderScalars(VL, ReorderMask);
- SmallVector<int> ReuseMask(VF, UndefMaskElem);
- if (!allConstant(VL)) {
+ reorderScalars(GatheredScalars, ReorderMask);
+ auto FindReusedSplat = [&](SmallVectorImpl<int> &Mask) {
+ if (!isSplat(E->Scalars) || none_of(E->Scalars, [](Value *V) {
+ return isa<UndefValue>(V) && !isa<PoisonValue>(V);
+ }))
+ return false;
+ TreeEntry *UserTE = E->UserTreeIndices.back().UserTE;
+ unsigned EdgeIdx = E->UserTreeIndices.back().EdgeIdx;
+ if (UserTE->getNumOperands() != 2)
+ return false;
+ auto *It =
+ find_if(VectorizableTree, [=](const std::unique_ptr<TreeEntry> &TE) {
+ return find_if(TE->UserTreeIndices, [=](const EdgeInfo &EI) {
+ return EI.UserTE == UserTE && EI.EdgeIdx != EdgeIdx;
+ }) != TE->UserTreeIndices.end();
+ });
+ if (It == VectorizableTree.end())
+ return false;
+ unsigned I =
+ *find_if_not(Mask, [](int Idx) { return Idx == PoisonMaskElem; });
+ int Sz = Mask.size();
+ if (all_of(Mask, [Sz](int Idx) { return Idx < 2 * Sz; }) &&
+ ShuffleVectorInst::isIdentityMask(Mask))
+ std::iota(Mask.begin(), Mask.end(), 0);
+ else
+ std::fill(Mask.begin(), Mask.end(), I);
+ return true;
+ };
+ BVTy ShuffleBuilder(Params...);
+ ResTy Res = ResTy();
+ SmallVector<int> Mask;
+ SmallVector<int> ExtractMask;
+ std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle;
+ std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle;
+ SmallVector<const TreeEntry *> Entries;
+ Type *ScalarTy = GatheredScalars.front()->getType();
+ if (!all_of(GatheredScalars, UndefValue::classof)) {
+ // Check for gathered extracts.
+ ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask);
+ SmallVector<Value *> IgnoredVals;
+ if (UserIgnoreList)
+ IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end());
+ bool Resized = false;
+ if (Value *VecBase = ShuffleBuilder.adjustExtracts(E, ExtractMask))
+ if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType()))
+ if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) {
+ Resized = true;
+ GatheredScalars.append(VF - GatheredScalars.size(),
+ PoisonValue::get(ScalarTy));
+ }
+ // Gather extracts only after we have checked for fully matched gathers.
+ if (ExtractShuffle || E->getOpcode() != Instruction::Load ||
+ E->isAltShuffle() ||
+ all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) ||
+ isSplat(E->Scalars) ||
+ (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) {
+ GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries);
+ }
+ if (GatherShuffle) {
+ if (Value *Delayed = ShuffleBuilder.needToDelay(E, Entries)) {
+ // Delay emission of gathers which are not ready yet.
+ PostponedGathers.insert(E);
+ // Postpone gather emission; it will be emitted after the end of the
+ // process to keep the correct order.
+ return Delayed;
+ }
+ assert((Entries.size() == 1 || Entries.size() == 2) &&
+ "Expected shuffle of 1 or 2 entries.");
+ if (*GatherShuffle == TTI::SK_PermuteSingleSrc &&
+ Entries.front()->isSame(E->Scalars)) {
+ // Perfect match in the graph, will reuse the previously vectorized
+ // node. Cost is 0.
+ LLVM_DEBUG(
+ dbgs()
+ << "SLP: perfect diamond match for gather bundle that starts with "
+ << *E->Scalars.front() << ".\n");
+ // Restore the mask for previous partially matched values.
+ if (Entries.front()->ReorderIndices.empty() &&
+ ((Entries.front()->ReuseShuffleIndices.empty() &&
+ E->Scalars.size() == Entries.front()->Scalars.size()) ||
+ (E->Scalars.size() ==
+ Entries.front()->ReuseShuffleIndices.size()))) {
+ std::iota(Mask.begin(), Mask.end(), 0);
+ } else {
+ for (auto [I, V] : enumerate(E->Scalars)) {
+ if (isa<PoisonValue>(V)) {
+ Mask[I] = PoisonMaskElem;
+ continue;
+ }
+ Mask[I] = Entries.front()->findLaneForValue(V);
+ }
+ }
+ ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask);
+ Res = ShuffleBuilder.finalize(E->getCommonMask());
+ return Res;
+ }
+ if (!Resized) {
+ unsigned VF1 = Entries.front()->getVectorFactor();
+ unsigned VF2 = Entries.back()->getVectorFactor();
+ if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF)
+ GatheredScalars.append(VF - GatheredScalars.size(),
+ PoisonValue::get(ScalarTy));
+ }
+ // Remove shuffled elements from list of gathers.
+ for (int I = 0, Sz = Mask.size(); I < Sz; ++I) {
+ if (Mask[I] != PoisonMaskElem)
+ GatheredScalars[I] = PoisonValue::get(ScalarTy);
+ }
+ }
+ }
+ auto TryPackScalars = [&](SmallVectorImpl<Value *> &Scalars,
+ SmallVectorImpl<int> &ReuseMask,
+ bool IsRootPoison) {
// For splats we can emit broadcasts instead of gathers, so try to find
// such sequences.
- bool IsSplat = isSplat(VL) && (VL.size() > 2 || VL.front() == VL.back());
+ bool IsSplat = IsRootPoison && isSplat(Scalars) &&
+ (Scalars.size() > 2 || Scalars.front() == Scalars.back());
+ Scalars.append(VF - Scalars.size(), PoisonValue::get(ScalarTy));
SmallVector<int> UndefPos;
DenseMap<Value *, unsigned> UniquePositions;
// Gather unique non-const values and all constant values.
// For repeated values, just shuffle them.
- for (auto [I, V] : enumerate(VL)) {
+ int NumNonConsts = 0;
+ int SinglePos = 0;
+ for (auto [I, V] : enumerate(Scalars)) {
if (isa<UndefValue>(V)) {
if (!isa<PoisonValue>(V)) {
- Gathered[I] = V;
ReuseMask[I] = I;
UndefPos.push_back(I);
}
continue;
}
if (isConstant(V)) {
- Gathered[I] = V;
ReuseMask[I] = I;
continue;
}
+ ++NumNonConsts;
+ SinglePos = I;
+ Value *OrigV = V;
+ Scalars[I] = PoisonValue::get(ScalarTy);
if (IsSplat) {
- Gathered.front() = V;
+ Scalars.front() = OrigV;
ReuseMask[I] = 0;
} else {
- const auto Res = UniquePositions.try_emplace(V, I);
- Gathered[Res.first->second] = V;
+ const auto Res = UniquePositions.try_emplace(OrigV, I);
+ Scalars[Res.first->second] = OrigV;
ReuseMask[I] = Res.first->second;
}
}
- if (!UndefPos.empty() && IsSplat) {
+ if (NumNonConsts == 1) {
+ // Restore single insert element.
+ if (IsSplat) {
+ ReuseMask.assign(VF, PoisonMaskElem);
+ std::swap(Scalars.front(), Scalars[SinglePos]);
+ if (!UndefPos.empty() && UndefPos.front() == 0)
+ Scalars.front() = UndefValue::get(ScalarTy);
+ }
+ ReuseMask[SinglePos] = SinglePos;
+ } else if (!UndefPos.empty() && IsSplat) {
// For undef values, try to replace them with the simple broadcast.
// We can do it if the broadcasted value is guaranteed to be
// non-poisonous, or by freezing the incoming scalar value first.
- auto *It = find_if(Gathered, [this, E](Value *V) {
+ auto *It = find_if(Scalars, [this, E](Value *V) {
return !isa<UndefValue>(V) &&
(getTreeEntry(V) || isGuaranteedNotToBePoison(V) ||
- any_of(V->uses(), [E](const Use &U) {
- // Check if the value already used in the same operation in
- // one of the nodes already.
- return E->UserTreeIndices.size() == 1 &&
- is_contained(
- E->UserTreeIndices.front().UserTE->Scalars,
- U.getUser()) &&
- E->UserTreeIndices.front().EdgeIdx != U.getOperandNo();
- }));
+ (E->UserTreeIndices.size() == 1 &&
+ any_of(V->uses(), [E](const Use &U) {
+ // Check if the value already used in the same operation in
+ // one of the nodes already.
+ return E->UserTreeIndices.front().EdgeIdx !=
+ U.getOperandNo() &&
+ is_contained(
+ E->UserTreeIndices.front().UserTE->Scalars,
+ U.getUser());
+ })));
});
- if (It != Gathered.end()) {
+ if (It != Scalars.end()) {
// Replace undefs by the non-poisoned scalars and emit broadcast.
- int Pos = std::distance(Gathered.begin(), It);
+ int Pos = std::distance(Scalars.begin(), It);
for_each(UndefPos, [&](int I) {
// Set the undef position to the non-poisoned scalar.
ReuseMask[I] = Pos;
- // Replace the undef by the poison, in the mask it is replaced by non-poisoned scalar already.
+ // Replace the undef by the poison, in the mask it is replaced by
+ // non-poisoned scalar already.
if (I != Pos)
- Gathered[I] = PoisonValue::get(Gathered[I]->getType());
+ Scalars[I] = PoisonValue::get(ScalarTy);
});
} else {
// Replace undefs by the poisons, emit broadcast and then emit
// freeze.
for_each(UndefPos, [&](int I) {
- ReuseMask[I] = UndefMaskElem;
- if (isa<UndefValue>(Gathered[I]))
- Gathered[I] = PoisonValue::get(Gathered[I]->getType());
+ ReuseMask[I] = PoisonMaskElem;
+ if (isa<UndefValue>(Scalars[I]))
+ Scalars[I] = PoisonValue::get(ScalarTy);
});
NeedFreeze = true;
}
}
+ };
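A minimal standalone sketch of the idea behind the packing lambda above, using plain C++ containers rather than the LLVM API (the function name and types here are assumptions for illustration only): repeated scalars are mapped to the position of their first occurrence, so duplicates can be recreated by a shuffle instead of extra insertelement instructions.

#include <map>
#include <string>
#include <vector>

// Build a reuse mask: every repeated value points at its first occurrence.
std::vector<int> buildReuseMask(const std::vector<std::string> &Scalars) {
  std::map<std::string, int> FirstPos;
  std::vector<int> Mask(Scalars.size());
  for (int I = 0, E = (int)Scalars.size(); I < E; ++I) {
    auto It = FirstPos.emplace(Scalars[I], I).first;
    Mask[I] = It->second; // a first occurrence keeps its own index
  }
  return Mask;
}
// buildReuseMask({"a", "b", "a", "c"}) == {0, 1, 0, 3}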
+ if (ExtractShuffle || GatherShuffle) {
+ bool IsNonPoisoned = true;
+ bool IsUsedInExpr = false;
+ Value *Vec1 = nullptr;
+ if (ExtractShuffle) {
+ // Gather of extractelements can be represented as just a shuffle of
+ // a single vector or two vectors from which the scalars are extracted.
+ // Find input vectors.
+ Value *Vec2 = nullptr;
+ for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) {
+ if (ExtractMask[I] == PoisonMaskElem ||
+ (!Mask.empty() && Mask[I] != PoisonMaskElem)) {
+ ExtractMask[I] = PoisonMaskElem;
+ continue;
+ }
+ if (isa<UndefValue>(E->Scalars[I]))
+ continue;
+ auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
+ if (!Vec1) {
+ Vec1 = EI->getVectorOperand();
+ } else if (Vec1 != EI->getVectorOperand()) {
+ assert((!Vec2 || Vec2 == EI->getVectorOperand()) &&
+ "Expected only 1 or 2 vectors shuffle.");
+ Vec2 = EI->getVectorOperand();
+ }
+ }
+ if (Vec2) {
+ IsNonPoisoned &=
+ isGuaranteedNotToBePoison(Vec1) && isGuaranteedNotToBePoison(Vec2);
+ ShuffleBuilder.add(Vec1, Vec2, ExtractMask);
+ } else if (Vec1) {
+ IsUsedInExpr = FindReusedSplat(ExtractMask);
+ ShuffleBuilder.add(Vec1, ExtractMask);
+ IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1);
+ } else {
+ ShuffleBuilder.add(PoisonValue::get(FixedVectorType::get(
+ ScalarTy, GatheredScalars.size())),
+ ExtractMask);
+ }
+ }
+ if (GatherShuffle) {
+ if (Entries.size() == 1) {
+ IsUsedInExpr = FindReusedSplat(Mask);
+ ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask);
+ IsNonPoisoned &=
+ isGuaranteedNotToBePoison(Entries.front()->VectorizedValue);
+ } else {
+ ShuffleBuilder.add(Entries.front()->VectorizedValue,
+ Entries.back()->VectorizedValue, Mask);
+ IsNonPoisoned &=
+ isGuaranteedNotToBePoison(Entries.front()->VectorizedValue) &&
+ isGuaranteedNotToBePoison(Entries.back()->VectorizedValue);
+ }
+ }
+ // Try to figure out the best way to combine values: build a shuffle and insert
+ // elements or just build several shuffles.
+ // Insert non-constant scalars.
+ SmallVector<Value *> NonConstants(GatheredScalars);
+ int EMSz = ExtractMask.size();
+ int MSz = Mask.size();
+ // Try to build a constant vector and shuffle with it only if we currently
+ // have a single permutation and more than one scalar constant.
+ bool IsSingleShuffle = !ExtractShuffle || !GatherShuffle;
+ bool IsIdentityShuffle =
+ (ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc) ==
+ TTI::SK_PermuteSingleSrc &&
+ none_of(ExtractMask, [&](int I) { return I >= EMSz; }) &&
+ ShuffleVectorInst::isIdentityMask(ExtractMask)) ||
+ (GatherShuffle.value_or(TTI::SK_PermuteTwoSrc) ==
+ TTI::SK_PermuteSingleSrc &&
+ none_of(Mask, [&](int I) { return I >= MSz; }) &&
+ ShuffleVectorInst::isIdentityMask(Mask));
+ bool EnoughConstsForShuffle =
+ IsSingleShuffle &&
+ (none_of(GatheredScalars,
+ [](Value *V) {
+ return isa<UndefValue>(V) && !isa<PoisonValue>(V);
+ }) ||
+ any_of(GatheredScalars,
+ [](Value *V) {
+ return isa<Constant>(V) && !isa<UndefValue>(V);
+ })) &&
+ (!IsIdentityShuffle ||
+ (GatheredScalars.size() == 2 &&
+ any_of(GatheredScalars,
+ [](Value *V) { return !isa<UndefValue>(V); })) ||
+ count_if(GatheredScalars, [](Value *V) {
+ return isa<Constant>(V) && !isa<PoisonValue>(V);
+ }) > 1);
+ // The NonConstants array contains just the non-constant values; GatheredScalars
+ // contains only constants, used to build the final vector and then shuffle.
+ for (int I = 0, Sz = GatheredScalars.size(); I < Sz; ++I) {
+ if (EnoughConstsForShuffle && isa<Constant>(GatheredScalars[I]))
+ NonConstants[I] = PoisonValue::get(ScalarTy);
+ else
+ GatheredScalars[I] = PoisonValue::get(ScalarTy);
+ }
+ // Generate constants for final shuffle and build a mask for them.
+ if (!all_of(GatheredScalars, PoisonValue::classof)) {
+ SmallVector<int> BVMask(GatheredScalars.size(), PoisonMaskElem);
+ TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true);
+ Value *BV = ShuffleBuilder.gather(GatheredScalars);
+ ShuffleBuilder.add(BV, BVMask);
+ }
+ if (all_of(NonConstants, [=](Value *V) {
+ return isa<PoisonValue>(V) ||
+ (IsSingleShuffle &&
+ ((IsIdentityShuffle && IsNonPoisoned) || IsUsedInExpr) &&
+ isa<UndefValue>(V));
+ }))
+ Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices);
+ else
+ Res = ShuffleBuilder.finalize(
+ E->ReuseShuffleIndices, E->Scalars.size(),
+ [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
+ TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false);
+ Vec = ShuffleBuilder.gather(NonConstants, Vec);
+ });
+ } else if (!allConstant(GatheredScalars)) {
+ // Gather unique scalars and all constants.
+ SmallVector<int> ReuseMask(GatheredScalars.size(), PoisonMaskElem);
+ TryPackScalars(GatheredScalars, ReuseMask, /*IsRootPoison=*/true);
+ Value *BV = ShuffleBuilder.gather(GatheredScalars);
+ ShuffleBuilder.add(BV, ReuseMask);
+ Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices);
} else {
- ReuseMask.clear();
- copy(VL, Gathered.begin());
+ // Gather all constants.
+ SmallVector<int> Mask(E->Scalars.size(), PoisonMaskElem);
+ for (auto [I, V] : enumerate(E->Scalars)) {
+ if (!isa<PoisonValue>(V))
+ Mask[I] = I;
+ }
+ Value *BV = ShuffleBuilder.gather(E->Scalars);
+ ShuffleBuilder.add(BV, Mask);
+ Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices);
}
- // Gather unique scalars and all constants.
- Value *Vec = gather(Gathered);
- ShuffleBuilder.add(Vec, ReuseMask);
- Vec = ShuffleBuilder.finalize(E->ReuseShuffleIndices);
+
if (NeedFreeze)
- Vec = Builder.CreateFreeze(Vec);
- return Vec;
+ Res = ShuffleBuilder.createFreeze(Res);
+ return Res;
+}
+
+Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
+ return processBuildVector<ShuffleInstructionBuilder, Value *>(E, Builder,
+ *this);
}
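A small sketch, in plain C++ with assumed names, of the constant/non-constant split performed in processBuildVector above: constants stay in one list and everything else moves to a second list with a poison-style placeholder left behind, so each list can be materialized by its own build vector and combined with one shuffle.

#include <string>
#include <vector>

struct Scalar { std::string Name; bool IsConstant; };
static const Scalar Poison{"poison", true};

// Returns the non-constant part; Gathered keeps only the constants.
std::vector<Scalar> splitForTwoBuildVectors(std::vector<Scalar> &Gathered) {
  std::vector<Scalar> NonConstants(Gathered.size(), Poison);
  for (size_t I = 0; I < Gathered.size(); ++I) {
    if (!Gathered[I].IsConstant) {
      NonConstants[I] = Gathered[I];
      Gathered[I] = Poison; // placeholder, filled in by the second shuffle
    }
  }
  return NonConstants;
}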
Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
@@ -9161,10 +10117,17 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return E->VectorizedValue;
}
+ if (E->State == TreeEntry::NeedToGather) {
+ if (E->getMainOp() && E->Idx == 0)
+ setInsertPointAfterBundle(E);
+ Value *Vec = createBuildVector(E);
+ E->VectorizedValue = Vec;
+ return Vec;
+ }
+
auto FinalShuffle = [&](Value *V, const TreeEntry *E) {
ShuffleInstructionBuilder ShuffleBuilder(Builder, *this);
- if (E->State != TreeEntry::NeedToGather &&
- E->getOpcode() == Instruction::Store) {
+ if (E->getOpcode() == Instruction::Store) {
ArrayRef<int> Mask =
ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()),
E->ReorderIndices.size());
@@ -9175,45 +10138,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return ShuffleBuilder.finalize(E->ReuseShuffleIndices);
};
- if (E->State == TreeEntry::NeedToGather) {
- if (E->Idx > 0) {
- // We are in the middle of a vectorizable chain. We need to gather the
- // scalars from the users.
- Value *Vec = createBuildVector(E);
- E->VectorizedValue = Vec;
- return Vec;
- }
- if (E->getMainOp())
- setInsertPointAfterBundle(E);
- SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end());
- // Build a mask out of the reorder indices and reorder scalars per this
- // mask.
- SmallVector<int> ReorderMask;
- inversePermutation(E->ReorderIndices, ReorderMask);
- if (!ReorderMask.empty())
- reorderScalars(GatheredScalars, ReorderMask);
- Value *Vec;
- SmallVector<int> Mask;
- SmallVector<const TreeEntry *> Entries;
- std::optional<TargetTransformInfo::ShuffleKind> Shuffle =
- isGatherShuffledEntry(E, GatheredScalars, Mask, Entries);
- if (Shuffle) {
- assert((Entries.size() == 1 || Entries.size() == 2) &&
- "Expected shuffle of 1 or 2 entries.");
- Vec = Builder.CreateShuffleVector(Entries.front()->VectorizedValue,
- Entries.back()->VectorizedValue, Mask);
- if (auto *I = dyn_cast<Instruction>(Vec)) {
- GatherShuffleExtractSeq.insert(I);
- CSEBlocks.insert(I->getParent());
- }
- } else {
- Vec = gather(E->Scalars);
- }
- Vec = FinalShuffle(Vec, E);
- E->VectorizedValue = Vec;
- return Vec;
- }
-
assert((E->State == TreeEntry::Vectorize ||
E->State == TreeEntry::ScatterVectorize) &&
"Unhandled state");
@@ -9248,7 +10172,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
// PHINodes may have multiple entries from the same block. We want to
// visit every block once.
- SmallPtrSet<BasicBlock*, 4> VisitedBBs;
+ SmallPtrSet<BasicBlock *, 4> VisitedBBs;
for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
ValueList Operands;
@@ -9314,14 +10238,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
SmallVector<int> Mask;
if (!E->ReorderIndices.empty()) {
inversePermutation(E->ReorderIndices, Mask);
- Mask.append(NumElts - NumScalars, UndefMaskElem);
+ Mask.append(NumElts - NumScalars, PoisonMaskElem);
} else {
- Mask.assign(NumElts, UndefMaskElem);
+ Mask.assign(NumElts, PoisonMaskElem);
std::iota(Mask.begin(), std::next(Mask.begin(), NumScalars), 0);
}
// Create InsertVector shuffle if necessary
bool IsIdentity = true;
- SmallVector<int> PrevMask(NumElts, UndefMaskElem);
+ SmallVector<int> PrevMask(NumElts, PoisonMaskElem);
Mask.swap(PrevMask);
for (unsigned I = 0; I < NumScalars; ++I) {
Value *Scalar = E->Scalars[PrevMask[I]];
@@ -9337,9 +10261,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
}
- SmallVector<int> InsertMask(NumElts, UndefMaskElem);
+ SmallVector<int> InsertMask(NumElts, PoisonMaskElem);
for (unsigned I = 0; I < NumElts; I++) {
- if (Mask[I] != UndefMaskElem)
+ if (Mask[I] != PoisonMaskElem)
InsertMask[Offset + I] = I;
}
SmallBitVector UseMask =
@@ -9354,10 +10278,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
isUndefVector<true>(FirstInsert->getOperand(0), UseMask);
if (!IsFirstPoison.all()) {
for (unsigned I = 0; I < NumElts; I++) {
- if (InsertMask[I] == UndefMaskElem && !IsFirstPoison.test(I))
+ if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I))
InsertMask[I] = I + NumElts;
}
- }
+ }
V = Builder.CreateShuffleVector(
V,
IsFirstPoison.all() ? PoisonValue::get(V->getType())
@@ -9372,8 +10296,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
SmallBitVector IsFirstPoison =
isUndefVector<true>(FirstInsert->getOperand(0), UseMask);
for (unsigned I = 0; I < NumElts; I++) {
- if (InsertMask[I] == UndefMaskElem)
- InsertMask[I] = IsFirstPoison.test(I) ? UndefMaskElem : I;
+ if (InsertMask[I] == PoisonMaskElem)
+ InsertMask[I] = IsFirstPoison.test(I) ? PoisonMaskElem : I;
else
InsertMask[I] += NumElts;
}
@@ -9544,20 +10468,17 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
LoadInst *LI = cast<LoadInst>(VL0);
Instruction *NewLI;
- unsigned AS = LI->getPointerAddressSpace();
Value *PO = LI->getPointerOperand();
if (E->State == TreeEntry::Vectorize) {
- Value *VecPtr = Builder.CreateBitCast(PO, VecTy->getPointerTo(AS));
- NewLI = Builder.CreateAlignedLoad(VecTy, VecPtr, LI->getAlign());
+ NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign());
- // The pointer operand uses an in-tree scalar so we add the new BitCast
- // or LoadInst to ExternalUses list to make sure that an extract will
+ // The pointer operand uses an in-tree scalar so we add the new
+ // LoadInst to ExternalUses list to make sure that an extract will
// be generated in the future.
if (TreeEntry *Entry = getTreeEntry(PO)) {
// Find which lane we need to extract.
unsigned FoundLane = Entry->findLaneForValue(PO);
- ExternalUses.emplace_back(
- PO, PO != VecPtr ? cast<User>(VecPtr) : NewLI, FoundLane);
+ ExternalUses.emplace_back(PO, NewLI, FoundLane);
}
} else {
assert(E->State == TreeEntry::ScatterVectorize && "Unhandled state");
@@ -9653,7 +10574,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
CallInst *CI = cast<CallInst>(VL0);
setInsertPointAfterBundle(E);
- Intrinsic::ID IID = Intrinsic::not_intrinsic;
+ Intrinsic::ID IID = Intrinsic::not_intrinsic;
if (Function *FI = CI->getCalledFunction())
IID = FI->getIntrinsicID();
@@ -9665,8 +10586,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
Value *ScalarArg = nullptr;
std::vector<Value *> OpVecs;
- SmallVector<Type *, 2> TysForDecl =
- {FixedVectorType::get(CI->getType(), E->Scalars.size())};
+ SmallVector<Type *, 2> TysForDecl;
+ // Add return type if intrinsic is overloaded on it.
+ if (isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+ TysForDecl.push_back(
+ FixedVectorType::get(CI->getType(), E->Scalars.size()));
for (int j = 0, e = CI->arg_size(); j < e; ++j) {
ValueList OpVL;
// Some intrinsics have scalar arguments. This argument should not be
@@ -9808,14 +10732,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
return V;
}
default:
- llvm_unreachable("unknown inst");
+ llvm_unreachable("unknown inst");
}
return nullptr;
}
Value *BoUpSLP::vectorizeTree() {
ExtraValueToDebugLocsMap ExternallyUsedValues;
- return vectorizeTree(ExternallyUsedValues);
+ SmallVector<std::pair<Value *, Value *>> ReplacedExternals;
+ return vectorizeTree(ExternallyUsedValues, ReplacedExternals);
}
namespace {
@@ -9829,28 +10754,51 @@ struct ShuffledInsertData {
};
} // namespace
-Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
- Instruction *ReductionRoot) {
+Value *BoUpSLP::vectorizeTree(
+ const ExtraValueToDebugLocsMap &ExternallyUsedValues,
+ SmallVectorImpl<std::pair<Value *, Value *>> &ReplacedExternals,
+ Instruction *ReductionRoot) {
// All blocks must be scheduled before any instructions are inserted.
for (auto &BSIter : BlocksSchedules) {
scheduleBlock(BSIter.second.get());
}
-
- // Pre-gather last instructions.
- for (const std::unique_ptr<TreeEntry> &E : VectorizableTree) {
- if ((E->State == TreeEntry::NeedToGather &&
- (!E->getMainOp() || E->Idx > 0)) ||
- (E->State != TreeEntry::NeedToGather &&
- E->getOpcode() == Instruction::ExtractValue) ||
- E->getOpcode() == Instruction::InsertElement)
- continue;
- Instruction *LastInst = &getLastInstructionInBundle(E.get());
- EntryToLastInstruction.try_emplace(E.get(), LastInst);
- }
+ // Clear the Entry-to-LastInstruction table. It can be invalidated by
+ // scheduling, so it needs to be rebuilt.
+ EntryToLastInstruction.clear();
Builder.SetInsertPoint(ReductionRoot ? ReductionRoot
: &F->getEntryBlock().front());
auto *VectorRoot = vectorizeTree(VectorizableTree[0].get());
+ // Run through the list of postponed gathers and emit them, replacing the temp
+ // emitted allocas with actual vector instructions.
+ ArrayRef<const TreeEntry *> PostponedNodes = PostponedGathers.getArrayRef();
+ DenseMap<Value *, SmallVector<TreeEntry *>> PostponedValues;
+ for (const TreeEntry *E : PostponedNodes) {
+ auto *TE = const_cast<TreeEntry *>(E);
+ if (auto *VecTE = getTreeEntry(TE->Scalars.front()))
+ if (VecTE->isSame(TE->UserTreeIndices.front().UserTE->getOperand(
+ TE->UserTreeIndices.front().EdgeIdx)))
+ // Found a gather node which is exactly the same as one of the
+ // vectorized nodes. This may happen after reordering.
+ continue;
+ auto *PrevVec = cast<Instruction>(TE->VectorizedValue);
+ TE->VectorizedValue = nullptr;
+ auto *UserI =
+ cast<Instruction>(TE->UserTreeIndices.front().UserTE->VectorizedValue);
+ Builder.SetInsertPoint(PrevVec);
+ Builder.SetCurrentDebugLocation(UserI->getDebugLoc());
+ Value *Vec = vectorizeTree(TE);
+ PrevVec->replaceAllUsesWith(Vec);
+ PostponedValues.try_emplace(Vec).first->second.push_back(TE);
+ // Replace the stub vector node, if it was used before for one of the
+ // buildvector nodes already.
+ auto It = PostponedValues.find(PrevVec);
+ if (It != PostponedValues.end()) {
+ for (TreeEntry *VTE : It->getSecond())
+ VTE->VectorizedValue = Vec;
+ }
+ eraseInstruction(PrevVec);
+ }
// If the vectorized tree can be rewritten in a smaller type, we truncate the
// vectorized root. InstCombine will then rewrite the entire expression. We
@@ -9968,14 +10916,9 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
Builder.SetInsertPoint(&F->getEntryBlock().front());
}
Value *NewInst = ExtractAndExtendIfNeeded(Vec);
- auto &NewInstLocs = ExternallyUsedValues[NewInst];
- auto It = ExternallyUsedValues.find(Scalar);
- assert(It != ExternallyUsedValues.end() &&
- "Externally used scalar is not found in ExternallyUsedValues");
- NewInstLocs.append(It->second);
- ExternallyUsedValues.erase(Scalar);
// Required to update internally referenced instructions.
Scalar->replaceAllUsesWith(NewInst);
+ ReplacedExternals.emplace_back(Scalar, NewInst);
continue;
}
@@ -10004,7 +10947,7 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
ShuffledInserts.size() - 1);
SmallVectorImpl<int> &Mask = It->ValueMasks[Vec];
if (Mask.empty())
- Mask.assign(FTy->getNumElements(), UndefMaskElem);
+ Mask.assign(FTy->getNumElements(), PoisonMaskElem);
// Find the insertvector, vectorized in tree, if any.
Value *Base = VU;
while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) {
@@ -10017,7 +10960,7 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
do {
IEBase = cast<InsertElementInst>(Base);
int IEIdx = *getInsertIndex(IEBase);
- assert(Mask[Idx] == UndefMaskElem &&
+ assert(Mask[Idx] == PoisonMaskElem &&
"InsertElementInstruction used already.");
Mask[IEIdx] = IEIdx;
Base = IEBase->getOperand(0);
@@ -10035,7 +10978,7 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
}
SmallVectorImpl<int> &Mask = It->ValueMasks[Vec];
if (Mask.empty())
- Mask.assign(FTy->getNumElements(), UndefMaskElem);
+ Mask.assign(FTy->getNumElements(), PoisonMaskElem);
Mask[Idx] = ExternalUse.Lane;
It->InsertElements.push_back(cast<InsertElementInst>(User));
continue;
@@ -10077,8 +11020,8 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
}
auto CreateShuffle = [&](Value *V1, Value *V2, ArrayRef<int> Mask) {
- SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem);
- SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem);
+ SmallVector<int> CombinedMask1(Mask.size(), PoisonMaskElem);
+ SmallVector<int> CombinedMask2(Mask.size(), PoisonMaskElem);
int VF = cast<FixedVectorType>(V1->getType())->getNumElements();
for (int I = 0, E = Mask.size(); I < E; ++I) {
if (Mask[I] < VF)
@@ -10103,9 +11046,9 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues,
return std::make_pair(Vec, true);
}
if (!ForSingleMask) {
- SmallVector<int> ResizeMask(VF, UndefMaskElem);
+ SmallVector<int> ResizeMask(VF, PoisonMaskElem);
for (unsigned I = 0; I < VF; ++I) {
- if (Mask[I] != UndefMaskElem)
+ if (Mask[I] != PoisonMaskElem)
ResizeMask[Mask[I]] = Mask[I];
}
Vec = CreateShuffle(Vec, nullptr, ResizeMask);
@@ -10308,14 +11251,14 @@ void BoUpSLP::optimizeGatherSequence() {
// registers.
unsigned LastUndefsCnt = 0;
for (int I = 0, E = NewMask.size(); I < E; ++I) {
- if (SM1[I] == UndefMaskElem)
+ if (SM1[I] == PoisonMaskElem)
++LastUndefsCnt;
else
LastUndefsCnt = 0;
- if (NewMask[I] != UndefMaskElem && SM1[I] != UndefMaskElem &&
+ if (NewMask[I] != PoisonMaskElem && SM1[I] != PoisonMaskElem &&
NewMask[I] != SM1[I])
return false;
- if (NewMask[I] == UndefMaskElem)
+ if (NewMask[I] == PoisonMaskElem)
NewMask[I] = SM1[I];
}
// Check if the last undefs actually change the final number of used vector
@@ -10590,11 +11533,20 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V,
}
// Search up and down at the same time, because we don't know if the new
// instruction is above or below the existing scheduling region.
+ // Ignore debug info (and other "AssumeLike" intrinsics) so that they are not
+ // counted against the budget. Otherwise debug info could affect codegen.
BasicBlock::reverse_iterator UpIter =
++ScheduleStart->getIterator().getReverse();
BasicBlock::reverse_iterator UpperEnd = BB->rend();
BasicBlock::iterator DownIter = ScheduleEnd->getIterator();
BasicBlock::iterator LowerEnd = BB->end();
+ auto IsAssumeLikeIntr = [](const Instruction &I) {
+ if (auto *II = dyn_cast<IntrinsicInst>(&I))
+ return II->isAssumeLikeIntrinsic();
+ return false;
+ };
+ UpIter = std::find_if_not(UpIter, UpperEnd, IsAssumeLikeIntr);
+ DownIter = std::find_if_not(DownIter, LowerEnd, IsAssumeLikeIntr);
while (UpIter != UpperEnd && DownIter != LowerEnd && &*UpIter != I &&
&*DownIter != I) {
if (++ScheduleRegionSize > ScheduleRegionSizeLimit) {
@@ -10604,6 +11556,9 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V,
++UpIter;
++DownIter;
+
+ UpIter = std::find_if_not(UpIter, UpperEnd, IsAssumeLikeIntr);
+ DownIter = std::find_if_not(DownIter, LowerEnd, IsAssumeLikeIntr);
}
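The two std::find_if_not calls above step over assume-like intrinsics on both ends of the search so that debug intrinsics never count toward the scheduling-region budget. A tiny sketch of the same pattern on a plain container (names assumed for illustration):

#include <algorithm>
#include <vector>

// Index of the first element that does count toward a budget, treating
// zeros as "free" the way assume-like intrinsics are treated above.
int firstBudgetedIndex(const std::vector<int> &Items) {
  auto It = std::find_if_not(Items.begin(), Items.end(),
                             [](int X) { return X == 0; });
  return (int)(It - Items.begin()); // Items.size() if everything is free
}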
if (DownIter == LowerEnd || (UpIter != UpperEnd && &*UpIter == I)) {
assert(I->getParent() == ScheduleStart->getParent() &&
@@ -10804,7 +11759,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
unsigned numAliased = 0;
unsigned DistToSrc = 1;
- for ( ; DepDest; DepDest = DepDest->NextLoadStore) {
+ for (; DepDest; DepDest = DepDest->NextLoadStore) {
assert(isInSchedulingRegion(DepDest));
// We have two limits to reduce the complexity:
@@ -11163,8 +12118,8 @@ void BoUpSLP::computeMinimumValueSizes() {
// we can truncate the roots to this narrower type.
for (auto *Root : TreeRoot) {
auto Mask = DB->getDemandedBits(cast<Instruction>(Root));
- MaxBitWidth = std::max<unsigned>(
- Mask.getBitWidth() - Mask.countLeadingZeros(), MaxBitWidth);
+ MaxBitWidth = std::max<unsigned>(Mask.getBitWidth() - Mask.countl_zero(),
+ MaxBitWidth);
}
// True if the roots can be zero-extended back to their original type, rather
@@ -11223,8 +12178,7 @@ void BoUpSLP::computeMinimumValueSizes() {
}
// Round MaxBitWidth up to the next power-of-two.
- if (!isPowerOf2_64(MaxBitWidth))
- MaxBitWidth = NextPowerOf2(MaxBitWidth);
+ MaxBitWidth = llvm::bit_ceil(MaxBitWidth);
// If the maximum bit width we compute is less than the width of the roots'
// type, we can proceed with the narrowing. Otherwise, do nothing.
@@ -11242,60 +12196,6 @@ void BoUpSLP::computeMinimumValueSizes() {
MinBWs[Scalar] = std::make_pair(MaxBitWidth, !IsKnownPositive);
}
-namespace {
-
-/// The SLPVectorizer Pass.
-struct SLPVectorizer : public FunctionPass {
- SLPVectorizerPass Impl;
-
- /// Pass identification, replacement for typeid
- static char ID;
-
- explicit SLPVectorizer() : FunctionPass(ID) {
- initializeSLPVectorizerPass(*PassRegistry::getPassRegistry());
- }
-
- bool doInitialization(Module &M) override { return false; }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
-
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
- auto *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
- auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits();
- auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
-
- return Impl.runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, ORE);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- FunctionPass::getAnalysisUsage(AU);
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<ScalarEvolutionWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<DemandedBitsWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<InjectTLIMappingsLegacy>();
- AU.addPreserved<LoopInfoWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<AAResultsWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.setPreservesCFG();
- }
-};
-
-} // end anonymous namespace
-
PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) {
auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
@@ -11536,7 +12436,7 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores,
unsigned MaxVecRegSize = R.getMaxVecRegSize();
unsigned EltSize = R.getVectorElementSize(Operands[0]);
- unsigned MaxElts = llvm::PowerOf2Floor(MaxVecRegSize / EltSize);
+ unsigned MaxElts = llvm::bit_floor(MaxVecRegSize / EltSize);
unsigned MaxVF = std::min(R.getMaximumVF(EltSize, Instruction::Store),
MaxElts);
@@ -11618,17 +12518,8 @@ void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) {
}
}
-bool SLPVectorizerPass::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) {
- if (!A || !B)
- return false;
- if (isa<InsertElementInst>(A) || isa<InsertElementInst>(B))
- return false;
- Value *VL[] = {A, B};
- return tryToVectorizeList(VL, R);
-}
-
bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
- bool LimitForRegisterSize) {
+ bool MaxVFOnly) {
if (VL.size() < 2)
return false;
@@ -11663,7 +12554,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
unsigned Sz = R.getVectorElementSize(I0);
unsigned MinVF = R.getMinVF(Sz);
- unsigned MaxVF = std::max<unsigned>(PowerOf2Floor(VL.size()), MinVF);
+ unsigned MaxVF = std::max<unsigned>(llvm::bit_floor(VL.size()), MinVF);
MaxVF = std::min(R.getMaximumVF(Sz, S.getOpcode()), MaxVF);
if (MaxVF < 2) {
R.getORE()->emit([&]() {
@@ -11690,21 +12581,17 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
if (TTI->getNumberOfParts(VecTy) == VF)
continue;
for (unsigned I = NextInst; I < MaxInst; ++I) {
- unsigned OpsWidth = 0;
+ unsigned ActualVF = std::min(MaxInst - I, VF);
- if (I + VF > MaxInst)
- OpsWidth = MaxInst - I;
- else
- OpsWidth = VF;
-
- if (!isPowerOf2_32(OpsWidth))
+ if (!isPowerOf2_32(ActualVF))
continue;
- if ((LimitForRegisterSize && OpsWidth < MaxVF) ||
- (VF > MinVF && OpsWidth <= VF / 2) || (VF == MinVF && OpsWidth < 2))
+ if (MaxVFOnly && ActualVF < MaxVF)
+ break;
+ if ((VF > MinVF && ActualVF <= VF / 2) || (VF == MinVF && ActualVF < 2))
break;
- ArrayRef<Value *> Ops = VL.slice(I, OpsWidth);
+ ArrayRef<Value *> Ops = VL.slice(I, ActualVF);
// Check that a previous iteration of this loop did not delete the Value.
if (llvm::any_of(Ops, [&R](Value *V) {
auto *I = dyn_cast<Instruction>(V);
@@ -11712,7 +12599,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
}))
continue;
- LLVM_DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations "
+ LLVM_DEBUG(dbgs() << "SLP: Analyzing " << ActualVF << " operations "
<< "\n");
R.buildTree(Ops);
@@ -11730,7 +12617,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
MinCost = std::min(MinCost, Cost);
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
- << " for VF=" << OpsWidth << "\n");
+ << " for VF=" << ActualVF << "\n");
if (Cost < -SLPCostThreshold) {
LLVM_DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n");
R.getORE()->emit(OptimizationRemark(SV_NAME, "VectorizedList",
@@ -11806,14 +12693,14 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
}
if (Candidates.size() == 1)
- return tryToVectorizePair(Op0, Op1, R);
+ return tryToVectorizeList({Op0, Op1}, R);
// We have multiple options. Try to pick the single best.
std::optional<int> BestCandidate = R.findBestRootPair(Candidates);
if (!BestCandidate)
return false;
- return tryToVectorizePair(Candidates[*BestCandidate].first,
- Candidates[*BestCandidate].second, R);
+ return tryToVectorizeList(
+ {Candidates[*BestCandidate].first, Candidates[*BestCandidate].second}, R);
}
namespace {
@@ -11857,6 +12744,9 @@ class HorizontalReduction {
WeakTrackingVH ReductionRoot;
/// The type of reduction operation.
RecurKind RdxKind;
+ /// Checks if the optimization of original scalar identity operations on
+ /// matched horizontal reductions is enabled and allowed.
+ bool IsSupportedHorRdxIdentityOp = false;
static bool isCmpSelMinMax(Instruction *I) {
return match(I, m_Select(m_Cmp(), m_Value(), m_Value())) &&
@@ -11888,6 +12778,9 @@ class HorizontalReduction {
return I->getFastMathFlags().noNaNs();
}
+ if (Kind == RecurKind::FMaximum || Kind == RecurKind::FMinimum)
+ return true;
+
return I->isAssociative();
}
@@ -11905,6 +12798,7 @@ class HorizontalReduction {
static Value *createOp(IRBuilder<> &Builder, RecurKind Kind, Value *LHS,
Value *RHS, const Twine &Name, bool UseSelect) {
unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+ bool IsConstant = isConstant(LHS) && isConstant(RHS);
switch (Kind) {
case RecurKind::Or:
if (UseSelect &&
@@ -11926,29 +12820,49 @@ class HorizontalReduction {
return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
Name);
case RecurKind::FMax:
+ if (IsConstant)
+ return ConstantFP::get(LHS->getType(),
+ maxnum(cast<ConstantFP>(LHS)->getValueAPF(),
+ cast<ConstantFP>(RHS)->getValueAPF()));
return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS);
case RecurKind::FMin:
+ if (IsConstant)
+ return ConstantFP::get(LHS->getType(),
+ minnum(cast<ConstantFP>(LHS)->getValueAPF(),
+ cast<ConstantFP>(RHS)->getValueAPF()));
return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
+ case RecurKind::FMaximum:
+ if (IsConstant)
+ return ConstantFP::get(LHS->getType(),
+ maximum(cast<ConstantFP>(LHS)->getValueAPF(),
+ cast<ConstantFP>(RHS)->getValueAPF()));
+ return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS);
+ case RecurKind::FMinimum:
+ if (IsConstant)
+ return ConstantFP::get(LHS->getType(),
+ minimum(cast<ConstantFP>(LHS)->getValueAPF(),
+ cast<ConstantFP>(RHS)->getValueAPF()));
+ return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS);
case RecurKind::SMax:
- if (UseSelect) {
+ if (IsConstant || UseSelect) {
Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name);
return Builder.CreateSelect(Cmp, LHS, RHS, Name);
}
return Builder.CreateBinaryIntrinsic(Intrinsic::smax, LHS, RHS);
case RecurKind::SMin:
- if (UseSelect) {
+ if (IsConstant || UseSelect) {
Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name);
return Builder.CreateSelect(Cmp, LHS, RHS, Name);
}
return Builder.CreateBinaryIntrinsic(Intrinsic::smin, LHS, RHS);
case RecurKind::UMax:
- if (UseSelect) {
+ if (IsConstant || UseSelect) {
Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name);
return Builder.CreateSelect(Cmp, LHS, RHS, Name);
}
return Builder.CreateBinaryIntrinsic(Intrinsic::umax, LHS, RHS);
case RecurKind::UMin:
- if (UseSelect) {
+ if (IsConstant || UseSelect) {
Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name);
return Builder.CreateSelect(Cmp, LHS, RHS, Name);
}
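The new FMaximum/FMinimum cases exist because llvm.maximum/llvm.minimum propagate NaN, unlike llvm.maxnum/llvm.minnum, which return the non-NaN operand. A short self-contained illustration of that difference using the C library's fmax (which follows the maxnum-style rule):

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  double NaN = std::numeric_limits<double>::quiet_NaN();
  // fmax/maxnum ignore a NaN operand and return the other value.
  std::printf("fmax(NaN, 1.0) = %f\n", std::fmax(NaN, 1.0)); // prints 1.000000
  // maximum-style semantics would instead propagate the NaN here, which is
  // why FMaximum/FMinimum need their own constant-folding cases above.
  return 0;
}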
@@ -11984,6 +12898,7 @@ class HorizontalReduction {
return Op;
}
+public:
static RecurKind getRdxKind(Value *V) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
@@ -12010,6 +12925,10 @@ class HorizontalReduction {
if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value())))
return RecurKind::FMin;
+ if (match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(), m_Value())))
+ return RecurKind::FMaximum;
+ if (match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(), m_Value())))
+ return RecurKind::FMinimum;
// This matches either cmp+select or intrinsics. SLP is expected to handle
// either form.
// TODO: If we are canonicalizing to intrinsics, we can remove several
@@ -12086,6 +13005,7 @@ class HorizontalReduction {
return isCmpSelMinMax(I) ? 1 : 0;
}
+private:
/// Total number of operands in the reduction operation.
static unsigned getNumberOfOperands(Instruction *I) {
return isCmpSelMinMax(I) ? 3 : 2;
@@ -12134,17 +13054,6 @@ class HorizontalReduction {
}
}
- static Value *getLHS(RecurKind Kind, Instruction *I) {
- if (Kind == RecurKind::None)
- return nullptr;
- return I->getOperand(getFirstOperandIndex(I));
- }
- static Value *getRHS(RecurKind Kind, Instruction *I) {
- if (Kind == RecurKind::None)
- return nullptr;
- return I->getOperand(getFirstOperandIndex(I) + 1);
- }
-
static bool isGoodForReduction(ArrayRef<Value *> Data) {
int Sz = Data.size();
auto *I = dyn_cast<Instruction>(Data.front());
@@ -12156,65 +13065,39 @@ public:
HorizontalReduction() = default;
/// Try to find a reduction tree.
- bool matchAssociativeReduction(PHINode *Phi, Instruction *Inst,
+ bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
ScalarEvolution &SE, const DataLayout &DL,
const TargetLibraryInfo &TLI) {
- assert((!Phi || is_contained(Phi->operands(), Inst)) &&
- "Phi needs to use the binary operator");
- assert((isa<BinaryOperator>(Inst) || isa<SelectInst>(Inst) ||
- isa<IntrinsicInst>(Inst)) &&
- "Expected binop, select, or intrinsic for reduction matching");
- RdxKind = getRdxKind(Inst);
-
- // We could have a initial reductions that is not an add.
- // r *= v1 + v2 + v3 + v4
- // In such a case start looking for a tree rooted in the first '+'.
- if (Phi) {
- if (getLHS(RdxKind, Inst) == Phi) {
- Phi = nullptr;
- Inst = dyn_cast<Instruction>(getRHS(RdxKind, Inst));
- if (!Inst)
- return false;
- RdxKind = getRdxKind(Inst);
- } else if (getRHS(RdxKind, Inst) == Phi) {
- Phi = nullptr;
- Inst = dyn_cast<Instruction>(getLHS(RdxKind, Inst));
- if (!Inst)
- return false;
- RdxKind = getRdxKind(Inst);
- }
- }
-
- if (!isVectorizable(RdxKind, Inst))
+ RdxKind = HorizontalReduction::getRdxKind(Root);
+ if (!isVectorizable(RdxKind, Root))
return false;
// Analyze "regular" integer/FP types for reductions - no target-specific
// types or pointers.
- Type *Ty = Inst->getType();
+ Type *Ty = Root->getType();
if (!isValidElementType(Ty) || Ty->isPointerTy())
return false;
// Though the ultimate reduction may have multiple uses, its condition must
// have only a single use.
- if (auto *Sel = dyn_cast<SelectInst>(Inst))
+ if (auto *Sel = dyn_cast<SelectInst>(Root))
if (!Sel->getCondition()->hasOneUse())
return false;
- ReductionRoot = Inst;
+ ReductionRoot = Root;
// Iterate through all the operands of the possible reduction tree and
// gather all the reduced values, sorting them by their value id.
- BasicBlock *BB = Inst->getParent();
- bool IsCmpSelMinMax = isCmpSelMinMax(Inst);
- SmallVector<Instruction *> Worklist(1, Inst);
+ BasicBlock *BB = Root->getParent();
+ bool IsCmpSelMinMax = isCmpSelMinMax(Root);
+ SmallVector<Instruction *> Worklist(1, Root);
// Checks if the operands of the \p TreeN instruction are also reduction
// operations or should be treated as reduced values or an extra argument,
// which is not part of the reduction.
- auto &&CheckOperands = [this, IsCmpSelMinMax,
- BB](Instruction *TreeN,
- SmallVectorImpl<Value *> &ExtraArgs,
- SmallVectorImpl<Value *> &PossibleReducedVals,
- SmallVectorImpl<Instruction *> &ReductionOps) {
+ auto CheckOperands = [&](Instruction *TreeN,
+ SmallVectorImpl<Value *> &ExtraArgs,
+ SmallVectorImpl<Value *> &PossibleReducedVals,
+ SmallVectorImpl<Instruction *> &ReductionOps) {
for (int I = getFirstOperandIndex(TreeN),
End = getNumberOfOperands(TreeN);
I < End; ++I) {
@@ -12229,10 +13112,14 @@ public:
}
// If the edge is not an instruction, or it is different from the main
// reduction opcode or has too many uses - possible reduced value.
+ // Also, do not try to reduce constant values if the operation is not
+ // foldable.
if (!EdgeInst || getRdxKind(EdgeInst) != RdxKind ||
IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
!hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
- !isVectorizable(getRdxKind(EdgeInst), EdgeInst)) {
+ !isVectorizable(RdxKind, EdgeInst) ||
+ (R.isAnalyzedReductionRoot(EdgeInst) &&
+ all_of(EdgeInst->operands(), Constant::classof))) {
PossibleReducedVals.push_back(EdgeVal);
continue;
}
@@ -12246,10 +13133,43 @@ public:
// instructions (grouping them by the predicate).
MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>>
PossibleReducedVals;
- initReductionOps(Inst);
+ initReductionOps(Root);
DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap;
SmallSet<size_t, 2> LoadKeyUsed;
SmallPtrSet<Value *, 4> DoNotReverseVals;
+
+ auto GenerateLoadsSubkey = [&](size_t Key, LoadInst *LI) {
+ Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
+ if (LoadKeyUsed.contains(Key)) {
+ auto LIt = LoadsMap.find(Ptr);
+ if (LIt != LoadsMap.end()) {
+ for (LoadInst *RLI : LIt->second) {
+ if (getPointersDiff(RLI->getType(), RLI->getPointerOperand(),
+ LI->getType(), LI->getPointerOperand(), DL, SE,
+ /*StrictCheck=*/true))
+ return hash_value(RLI->getPointerOperand());
+ }
+ for (LoadInst *RLI : LIt->second) {
+ if (arePointersCompatible(RLI->getPointerOperand(),
+ LI->getPointerOperand(), TLI)) {
+ hash_code SubKey = hash_value(RLI->getPointerOperand());
+ DoNotReverseVals.insert(RLI);
+ return SubKey;
+ }
+ }
+ if (LIt->second.size() > 2) {
+ hash_code SubKey =
+ hash_value(LIt->second.back()->getPointerOperand());
+ DoNotReverseVals.insert(LIt->second.back());
+ return SubKey;
+ }
+ }
+ }
+ LoadKeyUsed.insert(Key);
+ LoadsMap.try_emplace(Ptr).first->second.push_back(LI);
+ return hash_value(LI->getPointerOperand());
+ };
+
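The GenerateLoadsSubkey lambda above keys loads by their pointer so that loads which could form one wide vector load end up under the same subkey. A rough model of that grouping with plain pointers (names are assumptions; the real code also checks pointer distances and compatibility):

#include <map>
#include <vector>

struct Load { const int *Ptr; };

// Group loads by base pointer; each group is a candidate for one wide load.
std::map<const int *, std::vector<Load>>
groupLoadsByPointer(const std::vector<Load> &Loads) {
  std::map<const int *, std::vector<Load>> Groups;
  for (const Load &L : Loads)
    Groups[L.Ptr].push_back(L);
  return Groups;
}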
while (!Worklist.empty()) {
Instruction *TreeN = Worklist.pop_back_val();
SmallVector<Value *> Args;
@@ -12269,41 +13189,8 @@ public:
// results.
for (Value *V : PossibleRedVals) {
size_t Key, Idx;
- std::tie(Key, Idx) = generateKeySubkey(
- V, &TLI,
- [&](size_t Key, LoadInst *LI) {
- Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
- if (LoadKeyUsed.contains(Key)) {
- auto LIt = LoadsMap.find(Ptr);
- if (LIt != LoadsMap.end()) {
- for (LoadInst *RLI: LIt->second) {
- if (getPointersDiff(
- RLI->getType(), RLI->getPointerOperand(),
- LI->getType(), LI->getPointerOperand(), DL, SE,
- /*StrictCheck=*/true))
- return hash_value(RLI->getPointerOperand());
- }
- for (LoadInst *RLI : LIt->second) {
- if (arePointersCompatible(RLI->getPointerOperand(),
- LI->getPointerOperand(), TLI)) {
- hash_code SubKey = hash_value(RLI->getPointerOperand());
- DoNotReverseVals.insert(RLI);
- return SubKey;
- }
- }
- if (LIt->second.size() > 2) {
- hash_code SubKey =
- hash_value(LIt->second.back()->getPointerOperand());
- DoNotReverseVals.insert(LIt->second.back());
- return SubKey;
- }
- }
- }
- LoadKeyUsed.insert(Key);
- LoadsMap.try_emplace(Ptr).first->second.push_back(LI);
- return hash_value(LI->getPointerOperand());
- },
- /*AllowAlternate=*/false);
+ std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
+ /*AllowAlternate=*/false);
++PossibleReducedVals[Key][Idx]
.insert(std::make_pair(V, 0))
.first->second;
@@ -12312,40 +13199,8 @@ public:
PossibleReductionOps.rend());
} else {
size_t Key, Idx;
- std::tie(Key, Idx) = generateKeySubkey(
- TreeN, &TLI,
- [&](size_t Key, LoadInst *LI) {
- Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
- if (LoadKeyUsed.contains(Key)) {
- auto LIt = LoadsMap.find(Ptr);
- if (LIt != LoadsMap.end()) {
- for (LoadInst *RLI: LIt->second) {
- if (getPointersDiff(RLI->getType(),
- RLI->getPointerOperand(), LI->getType(),
- LI->getPointerOperand(), DL, SE,
- /*StrictCheck=*/true))
- return hash_value(RLI->getPointerOperand());
- }
- for (LoadInst *RLI : LIt->second) {
- if (arePointersCompatible(RLI->getPointerOperand(),
- LI->getPointerOperand(), TLI)) {
- hash_code SubKey = hash_value(RLI->getPointerOperand());
- DoNotReverseVals.insert(RLI);
- return SubKey;
- }
- }
- if (LIt->second.size() > 2) {
- hash_code SubKey = hash_value(LIt->second.back()->getPointerOperand());
- DoNotReverseVals.insert(LIt->second.back());
- return SubKey;
- }
- }
- }
- LoadKeyUsed.insert(Key);
- LoadsMap.try_emplace(Ptr).first->second.push_back(LI);
- return hash_value(LI->getPointerOperand());
- },
- /*AllowAlternate=*/false);
+ std::tie(Key, Idx) = generateKeySubkey(TreeN, &TLI, GenerateLoadsSubkey,
+ /*AllowAlternate=*/false);
++PossibleReducedVals[Key][Idx]
.insert(std::make_pair(TreeN, 0))
.first->second;
@@ -12407,14 +13262,18 @@ public:
// If there are a sufficient number of reduction values, reduce
// to a nearby power-of-2. We can safely generate oversized
// vectors and rely on the backend to split them to legal sizes.
- size_t NumReducedVals =
+ unsigned NumReducedVals =
std::accumulate(ReducedVals.begin(), ReducedVals.end(), 0,
- [](size_t Num, ArrayRef<Value *> Vals) {
+ [](unsigned Num, ArrayRef<Value *> Vals) -> unsigned {
if (!isGoodForReduction(Vals))
return Num;
return Num + Vals.size();
});
- if (NumReducedVals < ReductionLimit) {
+ if (NumReducedVals < ReductionLimit &&
+ (!AllowHorRdxIdenityOptimization ||
+ all_of(ReducedVals, [](ArrayRef<Value *> RedV) {
+ return RedV.size() < 2 || !allConstant(RedV) || !isSplat(RedV);
+ }))) {
for (ReductionOpsType &RdxOps : ReductionOps)
for (Value *RdxOp : RdxOps)
V.analyzedReductionRoot(cast<Instruction>(RdxOp));
@@ -12428,6 +13287,7 @@ public:
DenseMap<Value *, WeakTrackingVH> TrackedVals(
ReducedVals.size() * ReducedVals.front().size() + ExtraArgs.size());
BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
+ SmallVector<std::pair<Value *, Value *>> ReplacedExternals;
ExternallyUsedValues.reserve(ExtraArgs.size() + 1);
// The same extra argument may be used several times, so log each attempt
// to use it.
@@ -12448,6 +13308,18 @@ public:
return cast<Instruction>(ScalarCond);
};
+ // Return new VectorizedTree, based on previous value.
+ auto GetNewVectorizedTree = [&](Value *VectorizedTree, Value *Res) {
+ if (VectorizedTree) {
+ // Update the final value in the reduction.
+ Builder.SetCurrentDebugLocation(
+ cast<Instruction>(ReductionOps.front().front())->getDebugLoc());
+ return createOp(Builder, RdxKind, VectorizedTree, Res, "op.rdx",
+ ReductionOps);
+ }
+ // Initialize the final value in the reduction.
+ return Res;
+ };
// The reduction root is used as the insertion point for new instructions,
// so set it as externally used to prevent it from being deleted.
ExternallyUsedValues[ReductionRoot];
@@ -12459,6 +13331,12 @@ public:
continue;
IgnoreList.insert(RdxOp);
}
+ // Intersect the fast-math-flags from all reduction operations.
+ FastMathFlags RdxFMF;
+ RdxFMF.set();
+ for (Value *U : IgnoreList)
+ if (auto *FPMO = dyn_cast<FPMathOperator>(U))
+ RdxFMF &= FPMO->getFastMathFlags();
bool IsCmpSelMinMax = isCmpSelMinMax(cast<Instruction>(ReductionRoot));
// Need to track reduced vals, they may be changed during vectorization of
@@ -12519,16 +13397,82 @@ public:
}
}
}
+
+ // Emit code for constant values.
+ if (AllowHorRdxIdenityOptimization && Candidates.size() > 1 &&
+ allConstant(Candidates)) {
+ Value *Res = Candidates.front();
+ ++VectorizedVals.try_emplace(Candidates.front(), 0).first->getSecond();
+ for (Value *VC : ArrayRef(Candidates).drop_front()) {
+ Res = createOp(Builder, RdxKind, Res, VC, "const.rdx", ReductionOps);
+ ++VectorizedVals.try_emplace(VC, 0).first->getSecond();
+ if (auto *ResI = dyn_cast<Instruction>(Res))
+ V.analyzedReductionRoot(ResI);
+ }
+ VectorizedTree = GetNewVectorizedTree(VectorizedTree, Res);
+ continue;
+ }
+
unsigned NumReducedVals = Candidates.size();
- if (NumReducedVals < ReductionLimit)
+ if (NumReducedVals < ReductionLimit &&
+ (NumReducedVals < 2 || !AllowHorRdxIdenityOptimization ||
+ !isSplat(Candidates)))
continue;
+ // Check if we support processing of repeated scalar values (optimization of
+ // original scalar identity operations on matched horizontal reductions).
+ IsSupportedHorRdxIdentityOp =
+ AllowHorRdxIdenityOptimization && RdxKind != RecurKind::Mul &&
+ RdxKind != RecurKind::FMul && RdxKind != RecurKind::FMulAdd;
+ // Gather same values.
+ MapVector<Value *, unsigned> SameValuesCounter;
+ if (IsSupportedHorRdxIdentityOp)
+ for (Value *V : Candidates)
+ ++SameValuesCounter.insert(std::make_pair(V, 0)).first->second;
+ // Used to check if the reduced values are used the same number of times. In
+ // this case the compiler may produce better code. E.g. if the reduced values
+ // are aabbccdd (8 x values), then the first node of the tree will have a node
+ // for 4 x abcd + shuffle <4 x abcd>, <0, 0, 1, 1, 2, 2, 3, 3>.
+ // Plus, the final reduction will be performed on <8 x aabbccdd>.
+ // Instead, the compiler may build the <4 x abcd> tree immediately and then
+ // compute reduction(4 x abcd) * 2.
+ // Currently it only handles add/fadd/xor; and/or/min/max do not require
+ // this analysis, and other operations may require an extra estimation of
+ // the profitability.
+ bool SameScaleFactor = false;
+ bool OptReusedScalars = IsSupportedHorRdxIdentityOp &&
+ SameValuesCounter.size() != Candidates.size();
+ if (OptReusedScalars) {
+ SameScaleFactor =
+ (RdxKind == RecurKind::Add || RdxKind == RecurKind::FAdd ||
+ RdxKind == RecurKind::Xor) &&
+ all_of(drop_begin(SameValuesCounter),
+ [&SameValuesCounter](const std::pair<Value *, unsigned> &P) {
+ return P.second == SameValuesCounter.front().second;
+ });
+ Candidates.resize(SameValuesCounter.size());
+ transform(SameValuesCounter, Candidates.begin(),
+ [](const auto &P) { return P.first; });
+ NumReducedVals = Candidates.size();
+ // Have a reduction of the same element.
+ if (NumReducedVals == 1) {
+ Value *OrigV = TrackedToOrig.find(Candidates.front())->second;
+ unsigned Cnt = SameValuesCounter.lookup(OrigV);
+ Value *RedVal =
+ emitScaleForReusedOps(Candidates.front(), Builder, Cnt);
+ VectorizedTree = GetNewVectorizedTree(VectorizedTree, RedVal);
+ VectorizedVals.try_emplace(OrigV, Cnt);
+ continue;
+ }
+ }
+
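A quick numeric check of the "same scale factor" observation exploited above: if every reduced value repeats the same number of times, an add reduction over the wide list equals the reduction of the unique values multiplied by that repeat count (plain C++, values chosen arbitrarily):

#include <cassert>

int main() {
  int Uniq[] = {1, 2, 3, 4};      // abcd
  int Wide = 0, Narrow = 0;
  for (int V : Uniq) {
    Wide += V + V;                // aabbccdd, each value twice
    Narrow += V;
  }
  assert(Wide == 2 * Narrow);     // reduce(aabbccdd) == 2 * reduce(abcd)
  return 0;
}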
unsigned MaxVecRegSize = V.getMaxVecRegSize();
unsigned EltSize = V.getVectorElementSize(Candidates[0]);
- unsigned MaxElts = RegMaxNumber * PowerOf2Floor(MaxVecRegSize / EltSize);
+ unsigned MaxElts =
+ RegMaxNumber * llvm::bit_floor(MaxVecRegSize / EltSize);
unsigned ReduxWidth = std::min<unsigned>(
- PowerOf2Floor(NumReducedVals), std::max(RedValsMaxNumber, MaxElts));
+ llvm::bit_floor(NumReducedVals), std::max(RedValsMaxNumber, MaxElts));
unsigned Start = 0;
unsigned Pos = Start;
// Restarts vectorization attempt with lower vector factor.
@@ -12551,6 +13495,7 @@ public:
ReduxWidth /= 2;
return IsAnyRedOpGathered;
};
+ bool AnyVectorized = false;
while (Pos < NumReducedVals - ReduxWidth + 1 &&
ReduxWidth >= ReductionLimit) {
// Dependency in tree of the reduction ops - drop this attempt, try
@@ -12603,34 +13548,24 @@ public:
LocalExternallyUsedValues[TrackedVals[V]];
});
}
- // Number of uses of the candidates in the vector of values.
- SmallDenseMap<Value *, unsigned> NumUses(Candidates.size());
- for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) {
- Value *V = Candidates[Cnt];
- ++NumUses.try_emplace(V, 0).first->getSecond();
- }
- for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) {
- Value *V = Candidates[Cnt];
- ++NumUses.try_emplace(V, 0).first->getSecond();
+ if (!IsSupportedHorRdxIdentityOp) {
+ // Number of uses of the candidates in the vector of values.
+ assert(SameValuesCounter.empty() &&
+ "Reused values counter map is not empty");
+ for (unsigned Cnt = 0; Cnt < NumReducedVals; ++Cnt) {
+ if (Cnt >= Pos && Cnt < Pos + ReduxWidth)
+ continue;
+ Value *V = Candidates[Cnt];
+ Value *OrigV = TrackedToOrig.find(V)->second;
+ ++SameValuesCounter[OrigV];
+ }
}
SmallPtrSet<Value *, 4> VLScalars(VL.begin(), VL.end());
// Gather externally used values.
SmallPtrSet<Value *, 4> Visited;
- for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) {
- Value *RdxVal = Candidates[Cnt];
- if (!Visited.insert(RdxVal).second)
+ for (unsigned Cnt = 0; Cnt < NumReducedVals; ++Cnt) {
+ if (Cnt >= Pos && Cnt < Pos + ReduxWidth)
continue;
- // Check if the scalar was vectorized as part of the vectorization
- // tree but not the top node.
- if (!VLScalars.contains(RdxVal) && V.isVectorized(RdxVal)) {
- LocalExternallyUsedValues[RdxVal];
- continue;
- }
- unsigned NumOps = VectorizedVals.lookup(RdxVal) + NumUses[RdxVal];
- if (NumOps != ReducedValsToOps.find(RdxVal)->second.size())
- LocalExternallyUsedValues[RdxVal];
- }
- for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) {
Value *RdxVal = Candidates[Cnt];
if (!Visited.insert(RdxVal).second)
continue;
@@ -12640,42 +13575,34 @@ public:
LocalExternallyUsedValues[RdxVal];
continue;
}
- unsigned NumOps = VectorizedVals.lookup(RdxVal) + NumUses[RdxVal];
- if (NumOps != ReducedValsToOps.find(RdxVal)->second.size())
+ Value *OrigV = TrackedToOrig.find(RdxVal)->second;
+ unsigned NumOps =
+ VectorizedVals.lookup(RdxVal) + SameValuesCounter[OrigV];
+ if (NumOps != ReducedValsToOps.find(OrigV)->second.size())
LocalExternallyUsedValues[RdxVal];
}
+ // Do not need the list of reused scalars in regular mode anymore.
+ if (!IsSupportedHorRdxIdentityOp)
+ SameValuesCounter.clear();
for (Value *RdxVal : VL)
if (RequiredExtract.contains(RdxVal))
LocalExternallyUsedValues[RdxVal];
+ // Update LocalExternallyUsedValues for the scalars replaced by
+ // extractelement instructions.
+ for (const std::pair<Value *, Value *> &Pair : ReplacedExternals) {
+ auto It = ExternallyUsedValues.find(Pair.first);
+ if (It == ExternallyUsedValues.end())
+ continue;
+ LocalExternallyUsedValues[Pair.second].append(It->second);
+ }
V.buildExternalUses(LocalExternallyUsedValues);
V.computeMinimumValueSizes();
- // Intersect the fast-math-flags from all reduction operations.
- FastMathFlags RdxFMF;
- RdxFMF.set();
- for (Value *U : IgnoreList)
- if (auto *FPMO = dyn_cast<FPMathOperator>(U))
- RdxFMF &= FPMO->getFastMathFlags();
// Estimate cost.
InstructionCost TreeCost = V.getTreeCost(VL);
InstructionCost ReductionCost =
- getReductionCost(TTI, VL, ReduxWidth, RdxFMF);
- if (V.isVectorizedFirstNode() && isa<LoadInst>(VL.front())) {
- Instruction *MainOp = V.getFirstNodeMainOp();
- for (Value *V : VL) {
- auto *VI = dyn_cast<LoadInst>(V);
- // Add the costs of scalar GEP pointers, to be removed from the
- // code.
- if (!VI || VI == MainOp)
- continue;
- auto *Ptr = dyn_cast<GetElementPtrInst>(VI->getPointerOperand());
- if (!Ptr || !Ptr->hasOneUse() || Ptr->hasAllConstantIndices())
- continue;
- TreeCost -= TTI->getArithmeticInstrCost(
- Instruction::Add, Ptr->getType(), TTI::TCK_RecipThroughput);
- }
- }
+ getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
InstructionCost Cost = TreeCost + ReductionCost;
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n");
if (!Cost.isValid())
@@ -12716,8 +13643,8 @@ public:
InsertPt = GetCmpForMinMaxReduction(RdxRootInst);
// Vectorize a tree.
- Value *VectorizedRoot =
- V.vectorizeTree(LocalExternallyUsedValues, InsertPt);
+ Value *VectorizedRoot = V.vectorizeTree(LocalExternallyUsedValues,
+ ReplacedExternals, InsertPt);
Builder.SetInsertPoint(InsertPt);
@@ -12727,29 +13654,48 @@ public:
if (isBoolLogicOp(RdxRootInst))
VectorizedRoot = Builder.CreateFreeze(VectorizedRoot);
+ // Emit code to correctly handle reused reduced values, if required.
+ if (OptReusedScalars && !SameScaleFactor) {
+ VectorizedRoot =
+ emitReusedOps(VectorizedRoot, Builder, V.getRootNodeScalars(),
+ SameValuesCounter, TrackedToOrig);
+ }
+
Value *ReducedSubTree =
emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
- if (!VectorizedTree) {
- // Initialize the final value in the reduction.
- VectorizedTree = ReducedSubTree;
- } else {
- // Update the final value in the reduction.
- Builder.SetCurrentDebugLocation(
- cast<Instruction>(ReductionOps.front().front())->getDebugLoc());
- VectorizedTree = createOp(Builder, RdxKind, VectorizedTree,
- ReducedSubTree, "op.rdx", ReductionOps);
- }
+ // Improved analysis for add/fadd/xor reductions with the same scale factor
+ // for all operands of the reduction. We can emit scalar ops for them
+ // instead.
+ if (OptReusedScalars && SameScaleFactor)
+ ReducedSubTree = emitScaleForReusedOps(
+ ReducedSubTree, Builder, SameValuesCounter.front().second);
+
+ VectorizedTree = GetNewVectorizedTree(VectorizedTree, ReducedSubTree);
// Count vectorized reduced values to exclude them from final reduction.
for (Value *RdxVal : VL) {
- ++VectorizedVals.try_emplace(TrackedToOrig.find(RdxVal)->second, 0)
- .first->getSecond();
+ Value *OrigV = TrackedToOrig.find(RdxVal)->second;
+ if (IsSupportedHorRdxIdentityOp) {
+ VectorizedVals.try_emplace(OrigV, SameValuesCounter[RdxVal]);
+ continue;
+ }
+ ++VectorizedVals.try_emplace(OrigV, 0).first->getSecond();
if (!V.isVectorized(RdxVal))
RequiredExtract.insert(RdxVal);
}
Pos += ReduxWidth;
Start = Pos;
- ReduxWidth = PowerOf2Floor(NumReducedVals - Pos);
+ ReduxWidth = llvm::bit_floor(NumReducedVals - Pos);
+ AnyVectorized = true;
+ }
+ if (OptReusedScalars && !AnyVectorized) {
+ for (const std::pair<Value *, unsigned> &P : SameValuesCounter) {
+ Value *RedVal = emitScaleForReusedOps(P.first, Builder, P.second);
+ VectorizedTree = GetNewVectorizedTree(VectorizedTree, RedVal);
+ Value *OrigV = TrackedToOrig.find(P.first)->second;
+ VectorizedVals.try_emplace(OrigV, P.second);
+ }
+ continue;
}
}
if (VectorizedTree) {
@@ -12757,7 +13703,7 @@ public:
// possible problem with poison propagation. If not possible to reorder
// (both operands are originally RHS), emit an extra freeze instruction
// for the LHS operand.
- //I.e., if we have original code like this:
+ // I.e., if we have original code like this:
// RedOp1 = select i1 ?, i1 LHS, i1 false
// RedOp2 = select i1 RHS, i1 ?, i1 false
@@ -12892,7 +13838,8 @@ private:
/// Calculate the cost of a reduction.
InstructionCost getReductionCost(TargetTransformInfo *TTI,
ArrayRef<Value *> ReducedVals,
- unsigned ReduxWidth, FastMathFlags FMF) {
+ bool IsCmpSelMinMax, unsigned ReduxWidth,
+ FastMathFlags FMF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Value *FirstReducedVal = ReducedVals.front();
Type *ScalarTy = FirstReducedVal->getType();
@@ -12900,7 +13847,36 @@ private:
InstructionCost VectorCost = 0, ScalarCost;
// If all of the reduced values are constant, the vector cost is 0, since
// the reduction value can be calculated at the compile time.
- bool AllConsts = all_of(ReducedVals, isConstant);
+ bool AllConsts = allConstant(ReducedVals);
+ auto EvaluateScalarCost = [&](function_ref<InstructionCost()> GenCostFn) {
+ InstructionCost Cost = 0;
+ // Scalar cost is repeated for N-1 elements.
+ int Cnt = ReducedVals.size();
+ for (Value *RdxVal : ReducedVals) {
+ if (Cnt == 1)
+ break;
+ --Cnt;
+ if (RdxVal->hasNUsesOrMore(IsCmpSelMinMax ? 3 : 2)) {
+ Cost += GenCostFn();
+ continue;
+ }
+ InstructionCost ScalarCost = 0;
+ for (User *U : RdxVal->users()) {
+ auto *RdxOp = cast<Instruction>(U);
+ if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) {
+ ScalarCost += TTI->getInstructionCost(RdxOp, CostKind);
+ continue;
+ }
+ ScalarCost = InstructionCost::getInvalid();
+ break;
+ }
+ if (ScalarCost.isValid())
+ Cost += ScalarCost;
+ else
+ Cost += GenCostFn();
+ }
+ return Cost;
+ };
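EvaluateScalarCost above charges the scalar reduction cost for all but one of the reduced values, since reducing N scalars takes N-1 scalar operations (substituting the real user costs where those can be inspected). The basic counting rule, as a trivial sketch with assumed names:

// N reduced scalars need N-1 scalar reduction ops, each at PerOpCost.
int plainScalarReductionCost(int NumReducedVals, int PerOpCost) {
  return NumReducedVals > 1 ? (NumReducedVals - 1) * PerOpCost : 0;
}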
switch (RdxKind) {
case RecurKind::Add:
case RecurKind::Mul:
@@ -12913,52 +13889,32 @@ private:
if (!AllConsts)
VectorCost =
TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind);
- ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
+ ScalarCost = EvaluateScalarCost([&]() {
+ return TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
+ });
break;
}
case RecurKind::FMax:
- case RecurKind::FMin: {
- auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy);
- if (!AllConsts) {
- auto *VecCondTy =
- cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
- VectorCost =
- TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
- /*IsUnsigned=*/false, CostKind);
- }
- CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
- ScalarCost = TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy,
- SclCondTy, RdxPred, CostKind) +
- TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
- SclCondTy, RdxPred, CostKind);
- break;
- }
+ case RecurKind::FMin:
+ case RecurKind::FMaximum:
+ case RecurKind::FMinimum:
case RecurKind::SMax:
case RecurKind::SMin:
case RecurKind::UMax:
case RecurKind::UMin: {
- auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy);
- if (!AllConsts) {
- auto *VecCondTy =
- cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
- bool IsUnsigned =
- RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin;
- VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
- IsUnsigned, CostKind);
- }
- CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
- ScalarCost = TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy,
- SclCondTy, RdxPred, CostKind) +
- TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
- SclCondTy, RdxPred, CostKind);
+ Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+ if (!AllConsts)
+ VectorCost = TTI->getMinMaxReductionCost(Id, VectorTy, FMF, CostKind);
+ ScalarCost = EvaluateScalarCost([&]() {
+ IntrinsicCostAttributes ICA(Id, ScalarTy, {ScalarTy, ScalarTy}, FMF);
+ return TTI->getIntrinsicInstrCost(ICA, CostKind);
+ });
break;
}
default:
llvm_unreachable("Expected arithmetic or min/max reduction operation");
}
- // Scalar cost is repeated for N-1 elements.
- ScalarCost *= (ReduxWidth - 1);
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
<< " for reduction that starts with " << *FirstReducedVal
<< " (It is a splitting reduction)\n");
@@ -12977,8 +13933,148 @@ private:
++NumVectorInstructions;
return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind);
}
-};
+ /// Emits optimized code for a unique scalar value reused \p Cnt times.
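+ /// E.g., for RecurKind::Add the reduction of \p Cnt copies of a value is
+ /// emitted as a single multiplication by \p Cnt.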
+ Value *emitScaleForReusedOps(Value *VectorizedValue, IRBuilderBase &Builder,
+ unsigned Cnt) {
+ assert(IsSupportedHorRdxIdentityOp &&
+ "The optimization of matched scalar identity horizontal reductions "
+ "must be supported.");
+ switch (RdxKind) {
+ case RecurKind::Add: {
+ // res = mul vv, n
+ Value *Scale = ConstantInt::get(VectorizedValue->getType(), Cnt);
+ LLVM_DEBUG(dbgs() << "SLP: Add (to-mul) " << Cnt << "of "
+ << VectorizedValue << ". (HorRdx)\n");
+ return Builder.CreateMul(VectorizedValue, Scale);
+ }
+ case RecurKind::Xor: {
+ // res = n % 2 ? vv : 0
+ LLVM_DEBUG(dbgs() << "SLP: Xor " << Cnt << " of " << VectorizedValue
+ << ". (HorRdx)\n");
+ if (Cnt % 2 == 0)
+ return Constant::getNullValue(VectorizedValue->getType());
+ return VectorizedValue;
+ }
+ case RecurKind::FAdd: {
+ // res = fmul v, n
+ Value *Scale = ConstantFP::get(VectorizedValue->getType(), Cnt);
+ LLVM_DEBUG(dbgs() << "SLP: FAdd (to-fmul) " << Cnt << "of "
+ << VectorizedValue << ". (HorRdx)\n");
+ return Builder.CreateFMul(VectorizedValue, Scale);
+ }
+ case RecurKind::And:
+ case RecurKind::Or:
+ case RecurKind::SMax:
+ case RecurKind::SMin:
+ case RecurKind::UMax:
+ case RecurKind::UMin:
+ case RecurKind::FMax:
+ case RecurKind::FMin:
+ case RecurKind::FMaximum:
+ case RecurKind::FMinimum:
+ // res = vv
+ return VectorizedValue;
+ case RecurKind::Mul:
+ case RecurKind::FMul:
+ case RecurKind::FMulAdd:
+ case RecurKind::SelectICmp:
+ case RecurKind::SelectFCmp:
+ case RecurKind::None:
+ llvm_unreachable("Unexpected reduction kind for repeated scalar.");
+ }
+ return nullptr;
+ }
+
+ /// Emits the actual operation for the scalar identity values found during
+ /// horizontal reduction analysis.
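+ /// The vectorized value is adjusted lane-wise using \p SameValuesCounter,
+ /// e.g. scaled by the repeat count for add/fadd or shuffled with zero for
+ /// xor, before the final reduction is emitted.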
+ Value *emitReusedOps(Value *VectorizedValue, IRBuilderBase &Builder,
+ ArrayRef<Value *> VL,
+ const MapVector<Value *, unsigned> &SameValuesCounter,
+ const DenseMap<Value *, Value *> &TrackedToOrig) {
+ assert(IsSupportedHorRdxIdentityOp &&
+ "The optimization of matched scalar identity horizontal reductions "
+ "must be supported.");
+ switch (RdxKind) {
+ case RecurKind::Add: {
+ // root = mul prev_root, <1, 1, n, 1>
+ SmallVector<Constant *> Vals;
+ for (Value *V : VL) {
+ unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second);
+ Vals.push_back(ConstantInt::get(V->getType(), Cnt, /*IsSigned=*/false));
+ }
+ auto *Scale = ConstantVector::get(Vals);
+ LLVM_DEBUG(dbgs() << "SLP: Add (to-mul) " << Scale << "of "
+ << VectorizedValue << ". (HorRdx)\n");
+ return Builder.CreateMul(VectorizedValue, Scale);
+ }
+ case RecurKind::And:
+ case RecurKind::Or:
+ // No need for multiple or/and(s).
+ LLVM_DEBUG(dbgs() << "SLP: And/or of same " << VectorizedValue
+ << ". (HorRdx)\n");
+ return VectorizedValue;
+ case RecurKind::SMax:
+ case RecurKind::SMin:
+ case RecurKind::UMax:
+ case RecurKind::UMin:
+ case RecurKind::FMax:
+ case RecurKind::FMin:
+ case RecurKind::FMaximum:
+ case RecurKind::FMinimum:
+ // No need for multiple min/max(s) of the same value.
+ LLVM_DEBUG(dbgs() << "SLP: Max/min of same " << VectorizedValue
+ << ". (HorRdx)\n");
+ return VectorizedValue;
+ case RecurKind::Xor: {
+ // Replace values with an even number of repeats with 0, since
+ // x xor x = 0.
+ // root = shuffle prev_root, zeroinitializer, <0, 1, 2, vf, 4, vf, 5, 6,
+ // 7>, if the 4th and 6th elements have an even number of repeats.
+ SmallVector<int> Mask(
+ cast<FixedVectorType>(VectorizedValue->getType())->getNumElements(),
+ PoisonMaskElem);
+ std::iota(Mask.begin(), Mask.end(), 0);
+ bool NeedShuffle = false;
+ for (unsigned I = 0, VF = VL.size(); I < VF; ++I) {
+ Value *V = VL[I];
+ unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second);
+ if (Cnt % 2 == 0) {
+ Mask[I] = VF;
+ NeedShuffle = true;
+ }
+ }
+ LLVM_DEBUG(dbgs() << "SLP: Xor <"; for (int I
+ : Mask) dbgs()
+ << I << " ";
+ dbgs() << "> of " << VectorizedValue << ". (HorRdx)\n");
+ if (NeedShuffle)
+ VectorizedValue = Builder.CreateShuffleVector(
+ VectorizedValue,
+ ConstantVector::getNullValue(VectorizedValue->getType()), Mask);
+ return VectorizedValue;
+ }
+ case RecurKind::FAdd: {
+ // root = fmul prev_root, <1.0, 1.0, n.0, 1.0>
+ SmallVector<Constant *> Vals;
+ for (Value *V : VL) {
+ unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second);
+ Vals.push_back(ConstantFP::get(V->getType(), Cnt));
+ }
+ auto *Scale = ConstantVector::get(Vals);
+ return Builder.CreateFMul(VectorizedValue, Scale);
+ }
+ case RecurKind::Mul:
+ case RecurKind::FMul:
+ case RecurKind::FMulAdd:
+ case RecurKind::SelectICmp:
+ case RecurKind::SelectFCmp:
+ case RecurKind::None:
+ llvm_unreachable("Unexpected reduction kind for reused scalars.");
+ }
+ return nullptr;
+ }
+};
} // end anonymous namespace
static std::optional<unsigned> getAggregateSize(Instruction *InsertInst) {
@@ -13075,15 +14171,15 @@ static bool findBuildAggregate(Instruction *LastInsertInst,
return false;
}
-/// Try and get a reduction value from a phi node.
+/// Try to get a reduction instruction from a phi node.
///
/// Given a phi node \p P in a block \p ParentBB, consider possible reductions
/// if they come from either \p ParentBB or a containing loop latch.
///
/// \returns A candidate reduction value if possible, or \code nullptr \endcode
/// if not possible.
-static Value *getReductionValue(const DominatorTree *DT, PHINode *P,
- BasicBlock *ParentBB, LoopInfo *LI) {
+static Instruction *getReductionInstr(const DominatorTree *DT, PHINode *P,
+ BasicBlock *ParentBB, LoopInfo *LI) {
// There are situations where the reduction value is not dominated by the
// reduction phi. Vectorizing such cases has been reported to cause
// miscompiles. See PR25787.
@@ -13092,13 +14188,13 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P,
DT->dominates(P->getParent(), cast<Instruction>(R)->getParent());
};
- Value *Rdx = nullptr;
+ Instruction *Rdx = nullptr;
// Return the incoming value if it comes from the same BB as the phi node.
if (P->getIncomingBlock(0) == ParentBB) {
- Rdx = P->getIncomingValue(0);
+ Rdx = dyn_cast<Instruction>(P->getIncomingValue(0));
} else if (P->getIncomingBlock(1) == ParentBB) {
- Rdx = P->getIncomingValue(1);
+ Rdx = dyn_cast<Instruction>(P->getIncomingValue(1));
}
if (Rdx && DominatedReduxValue(Rdx))
@@ -13115,9 +14211,9 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P,
// There is a loop latch, return the incoming value if it comes from
// that. This reduction pattern occasionally turns up.
if (P->getIncomingBlock(0) == BBLatch) {
- Rdx = P->getIncomingValue(0);
+ Rdx = dyn_cast<Instruction>(P->getIncomingValue(0));
} else if (P->getIncomingBlock(1) == BBLatch) {
- Rdx = P->getIncomingValue(1);
+ Rdx = dyn_cast<Instruction>(P->getIncomingValue(1));
}
if (Rdx && DominatedReduxValue(Rdx))
@@ -13133,6 +14229,10 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) {
return true;
if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(V0), m_Value(V1))))
return true;
+ if (match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(V0), m_Value(V1))))
+ return true;
+ if (match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(V0), m_Value(V1))))
+ return true;
if (match(I, m_Intrinsic<Intrinsic::smax>(m_Value(V0), m_Value(V1))))
return true;
if (match(I, m_Intrinsic<Intrinsic::smin>(m_Value(V0), m_Value(V1))))
@@ -13144,21 +14244,63 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) {
return false;
}
+/// We could have an initial reduction that is not an add.
+/// r *= v1 + v2 + v3 + v4
+/// In such a case, start looking for a tree rooted at the first '+'.
+/// \Returns the new root if found, which may be nullptr if not an instruction.
+static Instruction *tryGetSecondaryReductionRoot(PHINode *Phi,
+ Instruction *Root) {
+ assert((isa<BinaryOperator>(Root) || isa<SelectInst>(Root) ||
+ isa<IntrinsicInst>(Root)) &&
+ "Expected binop, select, or intrinsic for reduction matching");
+ Value *LHS =
+ Root->getOperand(HorizontalReduction::getFirstOperandIndex(Root));
+ Value *RHS =
+ Root->getOperand(HorizontalReduction::getFirstOperandIndex(Root) + 1);
+ if (LHS == Phi)
+ return dyn_cast<Instruction>(RHS);
+ if (RHS == Phi)
+ return dyn_cast<Instruction>(LHS);
+ return nullptr;
+}
+
+/// \Returns the first operand of \p I that does not match \p Phi. If the
+/// operand is not an instruction it returns nullptr.
+static Instruction *getNonPhiOperand(Instruction *I, PHINode *Phi) {
+ Value *Op0 = nullptr;
+ Value *Op1 = nullptr;
+ if (!matchRdxBop(I, Op0, Op1))
+ return nullptr;
+ return dyn_cast<Instruction>(Op0 == Phi ? Op1 : Op0);
+}
+
+/// \Returns true if \p I is a candidate instruction for reduction vectorization.
+static bool isReductionCandidate(Instruction *I) {
+ bool IsSelect = match(I, m_Select(m_Value(), m_Value(), m_Value()));
+ Value *B0 = nullptr, *B1 = nullptr;
+ bool IsBinop = matchRdxBop(I, B0, B1);
+ return IsBinop || IsSelect;
+}
+
bool SLPVectorizerPass::vectorizeHorReduction(
- PHINode *P, Value *V, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI,
+ PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI,
SmallVectorImpl<WeakTrackingVH> &PostponedInsts) {
if (!ShouldVectorizeHor)
return false;
+ bool TryOperandsAsNewSeeds = P && isa<BinaryOperator>(Root);
- auto *Root = dyn_cast_or_null<Instruction>(V);
- if (!Root)
+ if (Root->getParent() != BB || isa<PHINode>(Root))
return false;
- if (!isa<BinaryOperator>(Root))
- P = nullptr;
+ // If we can find a secondary reduction root, use that instead.
+ auto SelectRoot = [&]() {
+ if (TryOperandsAsNewSeeds && isReductionCandidate(Root) &&
+ HorizontalReduction::getRdxKind(Root) != RecurKind::None)
+ if (Instruction *NewRoot = tryGetSecondaryReductionRoot(P, Root))
+ return NewRoot;
+ return Root;
+ };
- if (Root->getParent() != BB || isa<PHINode>(Root))
- return false;
// Start analysis starting from Root instruction. If horizontal reduction is
// found, try to vectorize it. If it is not a horizontal reduction or
// vectorization is not possible or not effective, and currently analyzed
@@ -13171,22 +14313,32 @@ bool SLPVectorizerPass::vectorizeHorReduction(
// If a horizontal reduction was not matched or vectorized, we collect
// instructions for possible later attempts for vectorization.
std::queue<std::pair<Instruction *, unsigned>> Stack;
- Stack.emplace(Root, 0);
+ Stack.emplace(SelectRoot(), 0);
SmallPtrSet<Value *, 8> VisitedInstrs;
bool Res = false;
- auto &&TryToReduce = [this, TTI, &P, &R](Instruction *Inst, Value *&B0,
- Value *&B1) -> Value * {
+ auto &&TryToReduce = [this, TTI, &R](Instruction *Inst) -> Value * {
if (R.isAnalyzedReductionRoot(Inst))
return nullptr;
- bool IsBinop = matchRdxBop(Inst, B0, B1);
- bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value()));
- if (IsBinop || IsSelect) {
- HorizontalReduction HorRdx;
- if (HorRdx.matchAssociativeReduction(P, Inst, *SE, *DL, *TLI))
- return HorRdx.tryToReduce(R, TTI, *TLI);
+ if (!isReductionCandidate(Inst))
+ return nullptr;
+ HorizontalReduction HorRdx;
+ if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
+ return nullptr;
+ return HorRdx.tryToReduce(R, TTI, *TLI);
+ };
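+ // Remember an unvectorized root (or, when seeded from a phi, its non-phi
+ // operand) for a later vectorization attempt, unless it is a compare or an
+ // insertelement/insertvalue, which are handled separately.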
+ auto TryAppendToPostponedInsts = [&](Instruction *FutureSeed) {
+ if (TryOperandsAsNewSeeds && FutureSeed == Root) {
+ FutureSeed = getNonPhiOperand(Root, P);
+ if (!FutureSeed)
+ return false;
}
- return nullptr;
+ // Do not collect CmpInst or InsertElementInst/InsertValueInst as their
+ // analysis is done separately.
+ if (!isa<CmpInst, InsertElementInst, InsertValueInst>(FutureSeed))
+ PostponedInsts.push_back(FutureSeed);
+ return true;
};
+
while (!Stack.empty()) {
Instruction *Inst;
unsigned Level;
@@ -13197,37 +14349,19 @@ bool SLPVectorizerPass::vectorizeHorReduction(
// iteration while stack was populated before that happened.
if (R.isDeleted(Inst))
continue;
- Value *B0 = nullptr, *B1 = nullptr;
- if (Value *V = TryToReduce(Inst, B0, B1)) {
+ if (Value *VectorizedV = TryToReduce(Inst)) {
Res = true;
- // Set P to nullptr to avoid re-analysis of phi node in
- // matchAssociativeReduction function unless this is the root node.
- P = nullptr;
- if (auto *I = dyn_cast<Instruction>(V)) {
+ if (auto *I = dyn_cast<Instruction>(VectorizedV)) {
// Try to find another reduction.
Stack.emplace(I, Level);
continue;
}
} else {
- bool IsBinop = B0 && B1;
- if (P && IsBinop) {
- Inst = dyn_cast<Instruction>(B0);
- if (Inst == P)
- Inst = dyn_cast<Instruction>(B1);
- if (!Inst) {
- // Set P to nullptr to avoid re-analysis of phi node in
- // matchAssociativeReduction function unless this is the root node.
- P = nullptr;
- continue;
- }
+ // We could not vectorize `Inst` so try to use it as a future seed.
+ if (!TryAppendToPostponedInsts(Inst)) {
+ assert(Stack.empty() && "Expected empty stack");
+ break;
}
- // Set P to nullptr to avoid re-analysis of phi node in
- // matchAssociativeReduction function unless this is the root node.
- P = nullptr;
- // Do not collect CmpInst or InsertElementInst/InsertValueInst as their
- // analysis is done separately.
- if (!isa<CmpInst, InsertElementInst, InsertValueInst>(Inst))
- PostponedInsts.push_back(Inst);
}
// Try to vectorize operands.
@@ -13246,11 +14380,11 @@ bool SLPVectorizerPass::vectorizeHorReduction(
return Res;
}
-bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V,
+bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Instruction *Root,
BasicBlock *BB, BoUpSLP &R,
TargetTransformInfo *TTI) {
SmallVector<WeakTrackingVH> PostponedInsts;
- bool Res = vectorizeHorReduction(P, V, BB, R, TTI, PostponedInsts);
+ bool Res = vectorizeHorReduction(P, Root, BB, R, TTI, PostponedInsts);
Res |= tryToVectorize(PostponedInsts, R);
return Res;
}
@@ -13297,13 +14431,11 @@ bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI,
}
template <typename T>
-static bool
-tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming,
- function_ref<unsigned(T *)> Limit,
- function_ref<bool(T *, T *)> Comparator,
- function_ref<bool(T *, T *)> AreCompatible,
- function_ref<bool(ArrayRef<T *>, bool)> TryToVectorizeHelper,
- bool LimitForRegisterSize) {
+static bool tryToVectorizeSequence(
+ SmallVectorImpl<T *> &Incoming, function_ref<bool(T *, T *)> Comparator,
+ function_ref<bool(T *, T *)> AreCompatible,
+ function_ref<bool(ArrayRef<T *>, bool)> TryToVectorizeHelper,
+ bool MaxVFOnly, BoUpSLP &R) {
bool Changed = false;
// Sort by type, parent, operands.
stable_sort(Incoming, Comparator);
@@ -13331,21 +14463,29 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming,
// same/alternate ops only, this may result in some extra final
// vectorization.
if (NumElts > 1 &&
- TryToVectorizeHelper(ArrayRef(IncIt, NumElts), LimitForRegisterSize)) {
+ TryToVectorizeHelper(ArrayRef(IncIt, NumElts), MaxVFOnly)) {
// Success start over because instructions might have been changed.
Changed = true;
- } else if (NumElts < Limit(*IncIt) &&
- (Candidates.empty() ||
- Candidates.front()->getType() == (*IncIt)->getType())) {
- Candidates.append(IncIt, std::next(IncIt, NumElts));
+ } else {
+ /// \Returns the minimum number of elements that we will attempt to
+ /// vectorize.
+ auto GetMinNumElements = [&R](Value *V) {
+ unsigned EltSize = R.getVectorElementSize(V);
+ return std::max(2U, R.getMaxVecRegSize() / EltSize);
+ };
+ if (NumElts < GetMinNumElements(*IncIt) &&
+ (Candidates.empty() ||
+ Candidates.front()->getType() == (*IncIt)->getType())) {
+ Candidates.append(IncIt, std::next(IncIt, NumElts));
+ }
}
// Final attempt to vectorize instructions with the same types.
if (Candidates.size() > 1 &&
(SameTypeIt == E || (*SameTypeIt)->getType() != (*IncIt)->getType())) {
- if (TryToVectorizeHelper(Candidates, /*LimitForRegisterSize=*/false)) {
+ if (TryToVectorizeHelper(Candidates, /*MaxVFOnly=*/false)) {
// Success start over because instructions might have been changed.
Changed = true;
- } else if (LimitForRegisterSize) {
+ } else if (MaxVFOnly) {
// Try to vectorize using small vectors.
for (auto *It = Candidates.begin(), *End = Candidates.end();
It != End;) {
@@ -13353,9 +14493,8 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming,
while (SameTypeIt != End && AreCompatible(*SameTypeIt, *It))
++SameTypeIt;
unsigned NumElts = (SameTypeIt - It);
- if (NumElts > 1 &&
- TryToVectorizeHelper(ArrayRef(It, NumElts),
- /*LimitForRegisterSize=*/false))
+ if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(It, NumElts),
+ /*MaxVFOnly=*/false))
Changed = true;
It = SameTypeIt;
}
@@ -13378,11 +14517,12 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming,
/// of the second cmp instruction.
template <bool IsCompatibility>
static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI,
- function_ref<bool(Instruction *)> IsDeleted) {
+ const DominatorTree &DT) {
+ assert(isValidElementType(V->getType()) &&
+ isValidElementType(V2->getType()) &&
+ "Expected valid element types only.");
auto *CI1 = cast<CmpInst>(V);
auto *CI2 = cast<CmpInst>(V2);
- if (IsDeleted(CI2) || !isValidElementType(CI2->getType()))
- return false;
if (CI1->getOperand(0)->getType()->getTypeID() <
CI2->getOperand(0)->getType()->getTypeID())
return !IsCompatibility;
@@ -13411,31 +14551,102 @@ static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI,
return false;
if (auto *I1 = dyn_cast<Instruction>(Op1))
if (auto *I2 = dyn_cast<Instruction>(Op2)) {
- if (I1->getParent() != I2->getParent())
- return false;
+ if (IsCompatibility) {
+ if (I1->getParent() != I2->getParent())
+ return false;
+ } else {
+ // Try to compare nodes with same parent.
+ DomTreeNodeBase<BasicBlock> *NodeI1 = DT.getNode(I1->getParent());
+ DomTreeNodeBase<BasicBlock> *NodeI2 = DT.getNode(I2->getParent());
+ if (!NodeI1)
+ return NodeI2 != nullptr;
+ if (!NodeI2)
+ return false;
+ assert((NodeI1 == NodeI2) ==
+ (NodeI1->getDFSNumIn() == NodeI2->getDFSNumIn()) &&
+ "Different nodes should have different DFS numbers");
+ if (NodeI1 != NodeI2)
+ return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn();
+ }
InstructionsState S = getSameOpcode({I1, I2}, TLI);
- if (S.getOpcode())
+ if (S.getOpcode() && (IsCompatibility || !S.isAltShuffle()))
continue;
- return false;
+ return !IsCompatibility && I1->getOpcode() < I2->getOpcode();
}
}
return IsCompatibility;
}
-bool SLPVectorizerPass::vectorizeSimpleInstructions(InstSetVector &Instructions,
- BasicBlock *BB, BoUpSLP &R,
- bool AtTerminator) {
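+/// Try to vectorize the compares in \p CmpInsts: first as roots of horizontal
+/// reductions fed by their operands, then their operands as vector bundles,
+/// and finally as lists of compatible compares.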
+template <typename ItT>
+bool SLPVectorizerPass::vectorizeCmpInsts(iterator_range<ItT> CmpInsts,
+ BasicBlock *BB, BoUpSLP &R) {
+ bool Changed = false;
+ // Try to find reductions first.
+ for (CmpInst *I : CmpInsts) {
+ if (R.isDeleted(I))
+ continue;
+ for (Value *Op : I->operands())
+ if (auto *RootOp = dyn_cast<Instruction>(Op))
+ Changed |= vectorizeRootInstruction(nullptr, RootOp, BB, R, TTI);
+ }
+ // Try to vectorize operands as vector bundles.
+ for (CmpInst *I : CmpInsts) {
+ if (R.isDeleted(I))
+ continue;
+ Changed |= tryToVectorize(I, R);
+ }
+ // Try to vectorize list of compares.
+ // Sort by type, compare predicate, etc.
+ auto CompareSorter = [&](Value *V, Value *V2) {
+ if (V == V2)
+ return false;
+ return compareCmp<false>(V, V2, *TLI, *DT);
+ };
+
+ auto AreCompatibleCompares = [&](Value *V1, Value *V2) {
+ if (V1 == V2)
+ return true;
+ return compareCmp<true>(V1, V2, *TLI, *DT);
+ };
+
+ SmallVector<Value *> Vals;
+ for (Instruction *V : CmpInsts)
+ if (!R.isDeleted(V) && isValidElementType(V->getType()))
+ Vals.push_back(V);
+ if (Vals.size() <= 1)
+ return Changed;
+ Changed |= tryToVectorizeSequence<Value>(
+ Vals, CompareSorter, AreCompatibleCompares,
+ [this, &R](ArrayRef<Value *> Candidates, bool MaxVFOnly) {
+ // Exclude possible reductions from other blocks.
+ bool ArePossiblyReducedInOtherBlock = any_of(Candidates, [](Value *V) {
+ return any_of(V->users(), [V](User *U) {
+ auto *Select = dyn_cast<SelectInst>(U);
+ return Select &&
+ Select->getParent() != cast<Instruction>(V)->getParent();
+ });
+ });
+ if (ArePossiblyReducedInOtherBlock)
+ return false;
+ return tryToVectorizeList(Candidates, R, MaxVFOnly);
+ },
+ /*MaxVFOnly=*/true, R);
+ return Changed;
+}
+
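+/// Try to vectorize the insertelement/insertvalue instructions in
+/// \p Instructions, both as roots of horizontal reductions and as parts of
+/// buildvector-like sequences.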
+bool SLPVectorizerPass::vectorizeInserts(InstSetVector &Instructions,
+ BasicBlock *BB, BoUpSLP &R) {
+ assert(all_of(Instructions,
+ [](auto *I) {
+ return isa<InsertElementInst, InsertValueInst>(I);
+ }) &&
+ "This function only accepts Insert instructions");
bool OpsChanged = false;
- SmallVector<Instruction *, 4> PostponedCmps;
SmallVector<WeakTrackingVH> PostponedInsts;
// pass1 - try to vectorize reductions only
for (auto *I : reverse(Instructions)) {
if (R.isDeleted(I))
continue;
- if (isa<CmpInst>(I)) {
- PostponedCmps.push_back(I);
- continue;
- }
OpsChanged |= vectorizeHorReduction(nullptr, I, BB, R, TTI, PostponedInsts);
}
// pass2 - try to match and vectorize a buildvector sequence.
@@ -13451,63 +14662,7 @@ bool SLPVectorizerPass::vectorizeSimpleInstructions(InstSetVector &Instructions,
// Now try to vectorize postponed instructions.
OpsChanged |= tryToVectorize(PostponedInsts, R);
- if (AtTerminator) {
- // Try to find reductions first.
- for (Instruction *I : PostponedCmps) {
- if (R.isDeleted(I))
- continue;
- for (Value *Op : I->operands())
- OpsChanged |= vectorizeRootInstruction(nullptr, Op, BB, R, TTI);
- }
- // Try to vectorize operands as vector bundles.
- for (Instruction *I : PostponedCmps) {
- if (R.isDeleted(I))
- continue;
- OpsChanged |= tryToVectorize(I, R);
- }
- // Try to vectorize list of compares.
- // Sort by type, compare predicate, etc.
- auto CompareSorter = [&](Value *V, Value *V2) {
- return compareCmp<false>(V, V2, *TLI,
- [&R](Instruction *I) { return R.isDeleted(I); });
- };
-
- auto AreCompatibleCompares = [&](Value *V1, Value *V2) {
- if (V1 == V2)
- return true;
- return compareCmp<true>(V1, V2, *TLI,
- [&R](Instruction *I) { return R.isDeleted(I); });
- };
- auto Limit = [&R](Value *V) {
- unsigned EltSize = R.getVectorElementSize(V);
- return std::max(2U, R.getMaxVecRegSize() / EltSize);
- };
-
- SmallVector<Value *> Vals(PostponedCmps.begin(), PostponedCmps.end());
- OpsChanged |= tryToVectorizeSequence<Value>(
- Vals, Limit, CompareSorter, AreCompatibleCompares,
- [this, &R](ArrayRef<Value *> Candidates, bool LimitForRegisterSize) {
- // Exclude possible reductions from other blocks.
- bool ArePossiblyReducedInOtherBlock =
- any_of(Candidates, [](Value *V) {
- return any_of(V->users(), [V](User *U) {
- return isa<SelectInst>(U) &&
- cast<SelectInst>(U)->getParent() !=
- cast<Instruction>(V)->getParent();
- });
- });
- if (ArePossiblyReducedInOtherBlock)
- return false;
- return tryToVectorizeList(Candidates, R, LimitForRegisterSize);
- },
- /*LimitForRegisterSize=*/true);
- Instructions.clear();
- } else {
- Instructions.clear();
- // Insert in reverse order since the PostponedCmps vector was filled in
- // reverse order.
- Instructions.insert(PostponedCmps.rbegin(), PostponedCmps.rend());
- }
+ Instructions.clear();
return OpsChanged;
}
@@ -13603,10 +14758,6 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
}
return true;
};
- auto Limit = [&R](Value *V) {
- unsigned EltSize = R.getVectorElementSize(V);
- return std::max(2U, R.getMaxVecRegSize() / EltSize);
- };
bool HaveVectorizedPhiNodes = false;
do {
@@ -13648,19 +14799,44 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
}
HaveVectorizedPhiNodes = tryToVectorizeSequence<Value>(
- Incoming, Limit, PHICompare, AreCompatiblePHIs,
- [this, &R](ArrayRef<Value *> Candidates, bool LimitForRegisterSize) {
- return tryToVectorizeList(Candidates, R, LimitForRegisterSize);
+ Incoming, PHICompare, AreCompatiblePHIs,
+ [this, &R](ArrayRef<Value *> Candidates, bool MaxVFOnly) {
+ return tryToVectorizeList(Candidates, R, MaxVFOnly);
},
- /*LimitForRegisterSize=*/true);
+ /*MaxVFOnly=*/true, R);
Changed |= HaveVectorizedPhiNodes;
VisitedInstrs.insert(Incoming.begin(), Incoming.end());
} while (HaveVectorizedPhiNodes);
VisitedInstrs.clear();
- InstSetVector PostProcessInstructions;
- SmallDenseSet<Instruction *, 4> KeyNodes;
+ InstSetVector PostProcessInserts;
+ SmallSetVector<CmpInst *, 8> PostProcessCmps;
+ // Vectorizes Inserts in `PostProcessInserts` and, if `VectorizeCmps` is true,
+ // also vectorizes `PostProcessCmps`.
+ auto VectorizeInsertsAndCmps = [&](bool VectorizeCmps) {
+ bool Changed = vectorizeInserts(PostProcessInserts, BB, R);
+ if (VectorizeCmps) {
+ Changed |= vectorizeCmpInsts(reverse(PostProcessCmps), BB, R);
+ PostProcessCmps.clear();
+ }
+ PostProcessInserts.clear();
+ return Changed;
+ };
+ // Returns true if `I` is in `PostProcessInserts` or `PostProcessCmps`.
+ auto IsInPostProcessInstrs = [&](Instruction *I) {
+ if (auto *Cmp = dyn_cast<CmpInst>(I))
+ return PostProcessCmps.contains(Cmp);
+ return isa<InsertElementInst, InsertValueInst>(I) &&
+ PostProcessInserts.contains(I);
+ };
+ // Returns true if `I` is an instruction without users, like a terminator or
+ // a store, or a function call/invoke with an ignored return value. The check
+ // is based on the instruction type, except for CallInst and InvokeInst.
+ auto HasNoUsers = [](Instruction *I) {
+ return I->use_empty() &&
+ (I->getType()->isVoidTy() || isa<CallInst, InvokeInst>(I));
+ };
for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
// Skip instructions with scalable type. The num of elements is unknown at
// compile-time for scalable type.
@@ -13672,9 +14848,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
continue;
// We may go through BB multiple times so skip the one we have checked.
if (!VisitedInstrs.insert(&*it).second) {
- if (it->use_empty() && KeyNodes.contains(&*it) &&
- vectorizeSimpleInstructions(PostProcessInstructions, BB, R,
- it->isTerminator())) {
+ if (HasNoUsers(&*it) &&
+ VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator())) {
// We would like to start over since some instructions are deleted
// and the iterator may become invalid value.
Changed = true;
@@ -13692,8 +14867,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
// Check that the PHI is a reduction PHI.
if (P->getNumIncomingValues() == 2) {
// Try to match and vectorize a horizontal reduction.
- if (vectorizeRootInstruction(P, getReductionValue(DT, P, BB, LI), BB, R,
- TTI)) {
+ Instruction *Root = getReductionInstr(DT, P, BB, LI);
+ if (Root && vectorizeRootInstruction(P, Root, BB, R, TTI)) {
Changed = true;
it = BB->begin();
e = BB->end();
@@ -13714,19 +14889,14 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
// Postponed instructions should not be vectorized here, delay their
// vectorization.
if (auto *PI = dyn_cast<Instruction>(P->getIncomingValue(I));
- PI && !PostProcessInstructions.contains(PI))
- Changed |= vectorizeRootInstruction(nullptr, P->getIncomingValue(I),
+ PI && !IsInPostProcessInstrs(PI))
+ Changed |= vectorizeRootInstruction(nullptr, PI,
P->getIncomingBlock(I), R, TTI);
}
continue;
}
- // Ran into an instruction without users, like terminator, or function call
- // with ignored return value, store. Ignore unused instructions (basing on
- // instruction type, except for CallInst and InvokeInst).
- if (it->use_empty() &&
- (it->getType()->isVoidTy() || isa<CallInst, InvokeInst>(it))) {
- KeyNodes.insert(&*it);
+ if (HasNoUsers(&*it)) {
bool OpsChanged = false;
auto *SI = dyn_cast<StoreInst>(it);
bool TryToVectorizeRoot = ShouldStartVectorizeHorAtStore || !SI;
@@ -13746,16 +14916,16 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
// Postponed instructions should not be vectorized here, delay their
// vectorization.
if (auto *VI = dyn_cast<Instruction>(V);
- VI && !PostProcessInstructions.contains(VI))
+ VI && !IsInPostProcessInstrs(VI))
// Try to match and vectorize a horizontal reduction.
- OpsChanged |= vectorizeRootInstruction(nullptr, V, BB, R, TTI);
+ OpsChanged |= vectorizeRootInstruction(nullptr, VI, BB, R, TTI);
}
}
// Start vectorization of post-process list of instructions from the
// top-tree instructions to try to vectorize as many instructions as
// possible.
- OpsChanged |= vectorizeSimpleInstructions(PostProcessInstructions, BB, R,
- it->isTerminator());
+ OpsChanged |=
+ VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator());
if (OpsChanged) {
// We would like to start over since some instructions are deleted
// and the iterator may become invalid value.
@@ -13766,8 +14936,10 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
}
}
- if (isa<CmpInst, InsertElementInst, InsertValueInst>(it))
- PostProcessInstructions.insert(&*it);
+ if (isa<InsertElementInst, InsertValueInst>(it))
+ PostProcessInserts.insert(&*it);
+ else if (isa<CmpInst>(it))
+ PostProcessCmps.insert(cast<CmpInst>(&*it));
}
return Changed;
@@ -13928,10 +15100,6 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) {
return V1->getValueOperand()->getValueID() ==
V2->getValueOperand()->getValueID();
};
- auto Limit = [&R, this](StoreInst *SI) {
- unsigned EltSize = DL->getTypeSizeInBits(SI->getValueOperand()->getType());
- return R.getMinVF(EltSize);
- };
// Attempt to sort and vectorize each of the store-groups.
for (auto &Pair : Stores) {
@@ -13945,28 +15113,11 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) {
continue;
Changed |= tryToVectorizeSequence<StoreInst>(
- Pair.second, Limit, StoreSorter, AreCompatibleStores,
+ Pair.second, StoreSorter, AreCompatibleStores,
[this, &R](ArrayRef<StoreInst *> Candidates, bool) {
return vectorizeStores(Candidates, R);
},
- /*LimitForRegisterSize=*/false);
+ /*MaxVFOnly=*/false, R);
}
return Changed;
}
-
-char SLPVectorizer::ID = 0;
-
-static const char lv_name[] = "SLP Vectorizer";
-
-INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
-INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(InjectTLIMappingsLegacy)
-INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
-
-Pass *llvm::createSLPVectorizerPass() { return new SLPVectorizer(); }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 733d2e1c667b..1271d1424c03 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -95,7 +95,7 @@ class VPRecipeBuilder {
/// return a new VPWidenCallRecipe. Range.End may be decreased to ensure same
/// decision from \p Range.Start to \p Range.End.
VPWidenCallRecipe *tryToWidenCall(CallInst *CI, ArrayRef<VPValue *> Operands,
- VFRange &Range) const;
+ VFRange &Range, VPlanPtr &Plan);
/// Check if \p I has an opcode that can be widened and return a VPWidenRecipe
/// if it can. The function should only be called if the cost-model indicates
@@ -136,11 +136,11 @@ public:
/// A helper function that computes the predicate of the block BB, assuming
/// that the header block of the loop is set to True. It returns the *entry*
/// mask for the block BB.
- VPValue *createBlockInMask(BasicBlock *BB, VPlanPtr &Plan);
+ VPValue *createBlockInMask(BasicBlock *BB, VPlan &Plan);
/// A helper function that computes the predicate of the edge between SRC
/// and DST.
- VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlanPtr &Plan);
+ VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlan &Plan);
/// Mark given ingredient for recording its recipe once one is created for
/// it.
@@ -159,19 +159,11 @@ public:
return Ingredient2Recipe[I];
}
- /// Create a replicating region for \p PredRecipe.
- VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe,
- VPlanPtr &Plan);
-
- /// Build a VPReplicationRecipe for \p I and enclose it within a Region if it
- /// is predicated. \return \p VPBB augmented with this new recipe if \p I is
- /// not predicated, otherwise \return a new VPBasicBlock that succeeds the new
- /// Region. Update the packing decision of predicated instructions if they
- /// feed \p I. Range.End may be decreased to ensure same recipe behavior from
- /// \p Range.Start to \p Range.End.
- VPBasicBlock *handleReplication(
- Instruction *I, VFRange &Range, VPBasicBlock *VPBB,
- VPlanPtr &Plan);
+ /// Build a VPReplicationRecipe for \p I. If it is predicated, add the mask
+ /// as the last operand. Range.End may be decreased to ensure same recipe behavior
+ /// from \p Range.Start to \p Range.End.
+ VPRecipeOrVPValueTy handleReplication(Instruction *I, VFRange &Range,
+ VPlan &Plan);
/// Add the incoming values from the backedge to reduction & first-order
/// recurrence cross-iteration phis.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index d554f438c804..e81b88fd8099 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -23,6 +23,7 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/BasicBlock.h"
@@ -46,7 +47,10 @@
#include <vector>
using namespace llvm;
+
+namespace llvm {
extern cl::opt<bool> EnableVPlanNativePath;
+}
#define DEBUG_TYPE "vplan"
@@ -160,8 +164,9 @@ VPBasicBlock *VPBlockBase::getEntryBasicBlock() {
}
void VPBlockBase::setPlan(VPlan *ParentPlan) {
- assert(ParentPlan->getEntry() == this &&
- "Can only set plan on its entry block.");
+ assert(
+ (ParentPlan->getEntry() == this || ParentPlan->getPreheader() == this) &&
+ "Can only set plan on its entry or preheader block.");
Plan = ParentPlan;
}
@@ -209,7 +214,7 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
}
Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) {
- if (!Def->hasDefiningRecipe())
+ if (Def->isLiveIn())
return Def->getLiveInIRValue();
if (hasScalarValue(Def, Instance)) {
@@ -243,11 +248,19 @@ void VPTransformState::addNewMetadata(Instruction *To,
}
void VPTransformState::addMetadata(Instruction *To, Instruction *From) {
+ // No source instruction to transfer metadata from?
+ if (!From)
+ return;
+
propagateMetadata(To, From);
addNewMetadata(To, From);
}
void VPTransformState::addMetadata(ArrayRef<Value *> To, Instruction *From) {
+ // No source instruction to transfer metadata from?
+ if (!From)
+ return;
+
for (Value *V : To) {
if (Instruction *I = dyn_cast<Instruction>(V))
addMetadata(I, From);
@@ -265,7 +278,7 @@ void VPTransformState::setDebugLocFromInst(const Value *V) {
// When a FSDiscriminator is enabled, we don't need to add the multiply
// factors to the discriminators.
if (DIL && Inst->getFunction()->shouldEmitDebugInfoForProfiling() &&
- !isa<DbgInfoIntrinsic>(Inst) && !EnableFSDiscriminator) {
+ !Inst->isDebugOrPseudoInst() && !EnableFSDiscriminator) {
// FIXME: For scalable vectors, assume vscale=1.
auto NewDIL =
DIL->cloneByMultiplyingDuplicationFactor(UF * VF.getKnownMinValue());
@@ -577,7 +590,9 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent,
#endif
VPlan::~VPlan() {
- clearLiveOuts();
+ for (auto &KV : LiveOuts)
+ delete KV.second;
+ LiveOuts.clear();
if (Entry) {
VPValue DummyValue;
@@ -585,15 +600,23 @@ VPlan::~VPlan() {
Block->dropAllReferences(&DummyValue);
VPBlockBase::deleteCFG(Entry);
+
+ Preheader->dropAllReferences(&DummyValue);
+ delete Preheader;
}
- for (VPValue *VPV : VPValuesToFree)
+ for (VPValue *VPV : VPLiveInsToFree)
delete VPV;
- if (TripCount)
- delete TripCount;
if (BackedgeTakenCount)
delete BackedgeTakenCount;
- for (auto &P : VPExternalDefs)
- delete P.second;
+}
+
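+// Create a skeleton VPlan consisting of a plain preheader block ("ph") and a
+// vector preheader ("vector.ph"), with the original trip count modeled as a
+// VPValue created from \p TripCount.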
+VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE) {
+ VPBasicBlock *Preheader = new VPBasicBlock("ph");
+ VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph");
+ auto Plan = std::make_unique<VPlan>(Preheader, VecPreheader);
+ Plan->TripCount =
+ vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE);
+ return Plan;
}
VPActiveLaneMaskPHIRecipe *VPlan::getActiveLaneMaskPhi() {
@@ -609,13 +632,6 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
Value *CanonicalIVStartValue,
VPTransformState &State,
bool IsEpilogueVectorization) {
-
- // Check if the trip count is needed, and if so build it.
- if (TripCount && TripCount->getNumUsers()) {
- for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part)
- State.set(TripCount, TripCountV, Part);
- }
-
// Check if the backedge taken count is needed, and if so build it.
if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) {
IRBuilder<> Builder(State.CFG.PrevBB->getTerminator());
@@ -636,7 +652,7 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
// needs to be changed from zero to the value after the main vector loop.
// FIXME: Improve modeling for canonical IV start values in the epilogue loop.
if (CanonicalIVStartValue) {
- VPValue *VPV = getOrAddExternalDef(CanonicalIVStartValue);
+ VPValue *VPV = getVPValueOrAddLiveIn(CanonicalIVStartValue);
auto *IV = getCanonicalIV();
assert(all_of(IV->users(),
[](const VPUser *U) {
@@ -650,8 +666,7 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
VPInstruction::CanonicalIVIncrementNUW;
}) &&
"the canonical IV should only be used by its increments or "
- "ScalarIVSteps when "
- "resetting the start value");
+ "ScalarIVSteps when resetting the start value");
IV->setOperand(0, VPV);
}
}
@@ -748,13 +763,25 @@ void VPlan::print(raw_ostream &O) const {
if (VectorTripCount.getNumUsers() > 0) {
O << "\nLive-in ";
VectorTripCount.printAsOperand(O, SlotTracker);
- O << " = vector-trip-count\n";
+ O << " = vector-trip-count";
}
if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) {
O << "\nLive-in ";
BackedgeTakenCount->printAsOperand(O, SlotTracker);
- O << " = backedge-taken count\n";
+ O << " = backedge-taken count";
+ }
+
+ O << "\n";
+ if (TripCount->isLiveIn())
+ O << "Live-in ";
+ TripCount->printAsOperand(O, SlotTracker);
+ O << " = original trip-count";
+ O << "\n";
+
+ if (!getPreheader()->empty()) {
+ O << "\n";
+ getPreheader()->print(O, "", SlotTracker);
}
for (const VPBlockBase *Block : vp_depth_first_shallow(getEntry())) {
@@ -765,11 +792,7 @@ void VPlan::print(raw_ostream &O) const {
if (!LiveOuts.empty())
O << "\n";
for (const auto &KV : LiveOuts) {
- O << "Live-out ";
- KV.second->getPhi()->printAsOperand(O);
- O << " = ";
- KV.second->getOperand(0)->printAsOperand(O, SlotTracker);
- O << "\n";
+ KV.second->print(O, SlotTracker);
}
O << "}\n";
@@ -882,6 +905,8 @@ void VPlanPrinter::dump() {
OS << "edge [fontname=Courier, fontsize=30]\n";
OS << "compound=true\n";
+ dumpBlock(Plan.getPreheader());
+
for (const VPBlockBase *Block : vp_depth_first_shallow(Plan.getEntry()))
dumpBlock(Block);
@@ -1086,26 +1111,27 @@ VPInterleavedAccessInfo::VPInterleavedAccessInfo(VPlan &Plan,
}
void VPSlotTracker::assignSlot(const VPValue *V) {
- assert(Slots.find(V) == Slots.end() && "VPValue already has a slot!");
+ assert(!Slots.contains(V) && "VPValue already has a slot!");
Slots[V] = NextSlot++;
}
void VPSlotTracker::assignSlots(const VPlan &Plan) {
-
- for (const auto &P : Plan.VPExternalDefs)
- assignSlot(P.second);
-
assignSlot(&Plan.VectorTripCount);
if (Plan.BackedgeTakenCount)
assignSlot(Plan.BackedgeTakenCount);
+ assignSlots(Plan.getPreheader());
ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<const VPBlockBase *>>
RPOT(VPBlockDeepTraversalWrapper<const VPBlockBase *>(Plan.getEntry()));
for (const VPBasicBlock *VPBB :
VPBlockUtils::blocksOnly<const VPBasicBlock>(RPOT))
- for (const VPRecipeBase &Recipe : *VPBB)
- for (VPValue *Def : Recipe.definedValues())
- assignSlot(Def);
+ assignSlots(VPBB);
+}
+
+void VPSlotTracker::assignSlots(const VPBasicBlock *VPBB) {
+ for (const VPRecipeBase &Recipe : *VPBB)
+ for (VPValue *Def : Recipe.definedValues())
+ assignSlot(Def);
}
bool vputils::onlyFirstLaneUsed(VPValue *Def) {
@@ -1115,13 +1141,17 @@ bool vputils::onlyFirstLaneUsed(VPValue *Def) {
VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr,
ScalarEvolution &SE) {
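+ // Reuse a previously created expansion of this SCEV, if one exists.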
+ if (auto *Expanded = Plan.getSCEVExpansion(Expr))
+ return Expanded;
+ VPValue *Expanded = nullptr;
if (auto *E = dyn_cast<SCEVConstant>(Expr))
- return Plan.getOrAddExternalDef(E->getValue());
- if (auto *E = dyn_cast<SCEVUnknown>(Expr))
- return Plan.getOrAddExternalDef(E->getValue());
-
- VPBasicBlock *Preheader = Plan.getEntry()->getEntryBasicBlock();
- VPExpandSCEVRecipe *Step = new VPExpandSCEVRecipe(Expr, SE);
- Preheader->appendRecipe(Step);
- return Step;
+ Expanded = Plan.getVPValueOrAddLiveIn(E->getValue());
+ else if (auto *E = dyn_cast<SCEVUnknown>(Expr))
+ Expanded = Plan.getVPValueOrAddLiveIn(E->getValue());
+ else {
+ Expanded = new VPExpandSCEVRecipe(Expr, SE);
+ Plan.getPreheader()->appendRecipe(Expanded->getDefiningRecipe());
+ }
+ Plan.addSCEVExpansion(Expr, Expanded);
+ return Expanded;
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 986faaf99664..73313465adea 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -25,7 +25,6 @@
#include "VPlanValue.h"
#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -33,11 +32,12 @@
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"
+#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/FMF.h"
-#include "llvm/Transforms/Utils/LoopVersioning.h"
+#include "llvm/IR/Operator.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
@@ -47,11 +47,9 @@ namespace llvm {
class BasicBlock;
class DominatorTree;
-class InductionDescriptor;
class InnerLoopVectorizer;
class IRBuilderBase;
class LoopInfo;
-class PredicateScalarEvolution;
class raw_ostream;
class RecurrenceDescriptor;
class SCEV;
@@ -62,6 +60,7 @@ class VPlan;
class VPReplicateRecipe;
class VPlanSlp;
class Value;
+class LoopVersioning;
namespace Intrinsic {
typedef unsigned ID;
@@ -76,16 +75,17 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF);
Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF,
int64_t Step);
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE);
+const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
+ Loop *CurLoop = nullptr);
/// A range of powers-of-2 vectorization factors with fixed start and
/// adjustable end. The range includes start and excludes end, e.g.,:
-/// [1, 9) = {1, 2, 4, 8}
+/// [1, 16) = {1, 2, 4, 8}
struct VFRange {
// A power of 2.
const ElementCount Start;
- // Need not be a power of 2. If End <= Start range is empty.
+ // A power of 2. If End <= Start range is empty.
ElementCount End;
bool isEmpty() const {
@@ -98,6 +98,33 @@ struct VFRange {
"Both Start and End should have the same scalable flag");
assert(isPowerOf2_32(Start.getKnownMinValue()) &&
"Expected Start to be a power of 2");
+ assert(isPowerOf2_32(End.getKnownMinValue()) &&
+ "Expected End to be a power of 2");
+ }
+
+ /// Iterator to iterate over vectorization factors in a VFRange.
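+ /// Visits Start, 2 * Start, 4 * Start, ..., up to but excluding End.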
+ class iterator
+ : public iterator_facade_base<iterator, std::forward_iterator_tag,
+ ElementCount> {
+ ElementCount VF;
+
+ public:
+ iterator(ElementCount VF) : VF(VF) {}
+
+ bool operator==(const iterator &Other) const { return VF == Other.VF; }
+
+ ElementCount operator*() const { return VF; }
+
+ iterator &operator++() {
+ VF *= 2;
+ return *this;
+ }
+ };
+
+ iterator begin() { return iterator(Start); }
+ iterator end() {
+ assert(isPowerOf2_32(End.getKnownMinValue()));
+ return iterator(End);
}
};
@@ -248,7 +275,7 @@ struct VPTransformState {
}
bool hasAnyVectorValue(VPValue *Def) const {
- return Data.PerPartOutput.find(Def) != Data.PerPartOutput.end();
+ return Data.PerPartOutput.contains(Def);
}
bool hasScalarValue(VPValue *Def, VPIteration Instance) {
@@ -370,10 +397,6 @@ struct VPTransformState {
/// Pointer to the VPlan code is generated for.
VPlan *Plan;
- /// Holds recipes that may generate a poison value that is used after
- /// vectorization, even when their operands are not poison.
- SmallPtrSet<VPRecipeBase *, 16> MayGeneratePoisonRecipes;
-
/// The loop object for the current parent region, or nullptr.
Loop *CurrentVectorLoop = nullptr;
@@ -382,7 +405,11 @@ struct VPTransformState {
///
/// This is currently only used to add no-alias metadata based on the
/// memchecks. The actually versioning is performed manually.
- std::unique_ptr<LoopVersioning> LVer;
+ LoopVersioning *LVer = nullptr;
+
+ /// Map SCEVs to their expanded values. Populated when executing
+ /// VPExpandSCEVRecipes.
+ DenseMap<const SCEV *, Value *> ExpandedSCEVs;
};
/// VPBlockBase is the building block of the Hierarchical Control-Flow Graph.
@@ -639,6 +666,10 @@ public:
VPLiveOut(PHINode *Phi, VPValue *Op)
: VPUser({Op}, VPUser::VPUserID::LiveOut), Phi(Phi) {}
+ static inline bool classof(const VPUser *U) {
+ return U->getVPUserID() == VPUser::VPUserID::LiveOut;
+ }
+
/// Fixup the wrapped LCSSA phi node in the unique exit block. This simply
/// means we need to add the appropriate incoming value from the middle
/// block as exiting edges from the scalar epilogue loop (if present) are
@@ -654,6 +685,11 @@ public:
}
PHINode *getPhi() const { return Phi; }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// Print the VPLiveOut to \p O.
+ void print(raw_ostream &O, VPSlotTracker &SlotTracker) const;
+#endif
};
/// VPRecipeBase is a base class modeling a sequence of one or more output IR
@@ -790,6 +826,7 @@ public:
SLPLoad,
SLPStore,
ActiveLaneMask,
+ CalculateTripCountMinusVF,
CanonicalIVIncrement,
CanonicalIVIncrementNUW,
// The next two are similar to the above, but instead increment the
@@ -810,8 +847,10 @@ private:
const std::string Name;
/// Utility method serving execute(): generates a single instance of the
- /// modeled instruction.
- void generateInstruction(VPTransformState &State, unsigned Part);
+ /// modeled instruction. \returns the generated value for \p Part.
+ /// In some cases an existing value is returned rather than a generated
+ /// one.
+ Value *generateInstruction(VPTransformState &State, unsigned Part);
protected:
void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); }
@@ -892,6 +931,7 @@ public:
default:
return false;
case VPInstruction::ActiveLaneMask:
+ case VPInstruction::CalculateTripCountMinusVF:
case VPInstruction::CanonicalIVIncrement:
case VPInstruction::CanonicalIVIncrementNUW:
case VPInstruction::CanonicalIVIncrementForPart:
@@ -903,14 +943,169 @@ public:
}
};
+/// Class to record the LLVM IR flags of a recipe, along with the recipe itself.
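+/// The wrapping, exact, inbounds, and fast-math flags are captured from the
+/// underlying instruction and can be dropped or re-applied to generated
+/// instructions via dropPoisonGeneratingFlags() and setFlags().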
+class VPRecipeWithIRFlags : public VPRecipeBase {
+ enum class OperationType : unsigned char {
+ OverflowingBinOp,
+ PossiblyExactOp,
+ GEPOp,
+ FPMathOp,
+ Other
+ };
+ struct WrapFlagsTy {
+ char HasNUW : 1;
+ char HasNSW : 1;
+ };
+ struct ExactFlagsTy {
+ char IsExact : 1;
+ };
+ struct GEPFlagsTy {
+ char IsInBounds : 1;
+ };
+ struct FastMathFlagsTy {
+ char AllowReassoc : 1;
+ char NoNaNs : 1;
+ char NoInfs : 1;
+ char NoSignedZeros : 1;
+ char AllowReciprocal : 1;
+ char AllowContract : 1;
+ char ApproxFunc : 1;
+ };
+
+ OperationType OpType;
+
+ union {
+ WrapFlagsTy WrapFlags;
+ ExactFlagsTy ExactFlags;
+ GEPFlagsTy GEPFlags;
+ FastMathFlagsTy FMFs;
+ unsigned char AllFlags;
+ };
+
+public:
+ template <typename IterT>
+ VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands)
+ : VPRecipeBase(SC, Operands) {
+ OpType = OperationType::Other;
+ AllFlags = 0;
+ }
+
+ template <typename IterT>
+ VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands,
+ Instruction &I)
+ : VPRecipeWithIRFlags(SC, Operands) {
+ if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
+ OpType = OperationType::OverflowingBinOp;
+ WrapFlags.HasNUW = Op->hasNoUnsignedWrap();
+ WrapFlags.HasNSW = Op->hasNoSignedWrap();
+ } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
+ OpType = OperationType::PossiblyExactOp;
+ ExactFlags.IsExact = Op->isExact();
+ } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
+ OpType = OperationType::GEPOp;
+ GEPFlags.IsInBounds = GEP->isInBounds();
+ } else if (auto *Op = dyn_cast<FPMathOperator>(&I)) {
+ OpType = OperationType::FPMathOp;
+ FastMathFlags FMF = Op->getFastMathFlags();
+ FMFs.AllowReassoc = FMF.allowReassoc();
+ FMFs.NoNaNs = FMF.noNaNs();
+ FMFs.NoInfs = FMF.noInfs();
+ FMFs.NoSignedZeros = FMF.noSignedZeros();
+ FMFs.AllowReciprocal = FMF.allowReciprocal();
+ FMFs.AllowContract = FMF.allowContract();
+ FMFs.ApproxFunc = FMF.approxFunc();
+ }
+ }
+
+ static inline bool classof(const VPRecipeBase *R) {
+ return R->getVPDefID() == VPRecipeBase::VPWidenSC ||
+ R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
+ R->getVPDefID() == VPRecipeBase::VPReplicateSC;
+ }
+
+ /// Drop all poison-generating flags.
+ void dropPoisonGeneratingFlags() {
+ // NOTE: This needs to be kept in-sync with
+ // Instruction::dropPoisonGeneratingFlags.
+ switch (OpType) {
+ case OperationType::OverflowingBinOp:
+ WrapFlags.HasNUW = false;
+ WrapFlags.HasNSW = false;
+ break;
+ case OperationType::PossiblyExactOp:
+ ExactFlags.IsExact = false;
+ break;
+ case OperationType::GEPOp:
+ GEPFlags.IsInBounds = false;
+ break;
+ case OperationType::FPMathOp:
+ FMFs.NoNaNs = false;
+ FMFs.NoInfs = false;
+ break;
+ case OperationType::Other:
+ break;
+ }
+ }
+
+ /// Set the IR flags for \p I.
+ void setFlags(Instruction *I) const {
+ switch (OpType) {
+ case OperationType::OverflowingBinOp:
+ I->setHasNoUnsignedWrap(WrapFlags.HasNUW);
+ I->setHasNoSignedWrap(WrapFlags.HasNSW);
+ break;
+ case OperationType::PossiblyExactOp:
+ I->setIsExact(ExactFlags.IsExact);
+ break;
+ case OperationType::GEPOp:
+ cast<GetElementPtrInst>(I)->setIsInBounds(GEPFlags.IsInBounds);
+ break;
+ case OperationType::FPMathOp:
+ I->setHasAllowReassoc(FMFs.AllowReassoc);
+ I->setHasNoNaNs(FMFs.NoNaNs);
+ I->setHasNoInfs(FMFs.NoInfs);
+ I->setHasNoSignedZeros(FMFs.NoSignedZeros);
+ I->setHasAllowReciprocal(FMFs.AllowReciprocal);
+ I->setHasAllowContract(FMFs.AllowContract);
+ I->setHasApproxFunc(FMFs.ApproxFunc);
+ break;
+ case OperationType::Other:
+ break;
+ }
+ }
+
+ bool isInBounds() const {
+ assert(OpType == OperationType::GEPOp &&
+ "recipe doesn't have inbounds flag");
+ return GEPFlags.IsInBounds;
+ }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ FastMathFlags getFastMathFlags() const {
+ FastMathFlags Res;
+ Res.setAllowReassoc(FMFs.AllowReassoc);
+ Res.setNoNaNs(FMFs.NoNaNs);
+ Res.setNoInfs(FMFs.NoInfs);
+ Res.setNoSignedZeros(FMFs.NoSignedZeros);
+ Res.setAllowReciprocal(FMFs.AllowReciprocal);
+ Res.setAllowContract(FMFs.AllowContract);
+ Res.setApproxFunc(FMFs.ApproxFunc);
+ return Res;
+ }
+
+ void printFlags(raw_ostream &O) const;
+#endif
+};
+
/// VPWidenRecipe is a recipe for producing a copy of vector type its
/// ingredient. This recipe covers most of the traditional vectorization cases
/// where each ingredient transforms into a vectorized version of itself.
-class VPWidenRecipe : public VPRecipeBase, public VPValue {
+class VPWidenRecipe : public VPRecipeWithIRFlags, public VPValue {
+
public:
template <typename IterT>
VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
- : VPRecipeBase(VPDef::VPWidenSC, Operands), VPValue(this, &I) {}
+ : VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPValue(this, &I) {}
~VPWidenRecipe() override = default;
@@ -926,18 +1121,62 @@ public:
#endif
};
+/// VPWidenCastRecipe is a recipe to create vector cast instructions.
+class VPWidenCastRecipe : public VPRecipeBase, public VPValue {
+ /// Cast instruction opcode.
+ Instruction::CastOps Opcode;
+
+ /// Result type for the cast.
+ Type *ResultTy;
+
+public:
+ VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+ CastInst *UI = nullptr)
+ : VPRecipeBase(VPDef::VPWidenCastSC, Op), VPValue(this, UI),
+ Opcode(Opcode), ResultTy(ResultTy) {
+ assert((!UI || UI->getOpcode() == Opcode) &&
+ "opcode of underlying cast doesn't match");
+ assert((!UI || UI->getType() == ResultTy) &&
+ "result type of underlying cast doesn't match");
+ }
+
+ ~VPWidenCastRecipe() override = default;
+
+ VP_CLASSOF_IMPL(VPDef::VPWidenCastSC)
+
+ /// Produce widened copies of the cast.
+ void execute(VPTransformState &State) override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// Print the recipe.
+ void print(raw_ostream &O, const Twine &Indent,
+ VPSlotTracker &SlotTracker) const override;
+#endif
+
+ Instruction::CastOps getOpcode() const { return Opcode; }
+
+ /// Returns the result type of the cast.
+ Type *getResultType() const { return ResultTy; }
+};
+
/// A recipe for widening Call instructions.
class VPWidenCallRecipe : public VPRecipeBase, public VPValue {
/// ID of the vector intrinsic to call when widening the call. If set the
/// Intrinsic::not_intrinsic, a library call will be used instead.
Intrinsic::ID VectorIntrinsicID;
+ /// If this recipe represents a library call, Variant stores a pointer to
+ /// the chosen function. There is a 1:1 mapping between a given VF and the
+ /// chosen vectorized variant, so there will be a different vplan for each
+ /// VF with a valid variant.
+ Function *Variant;
public:
template <typename IterT>
VPWidenCallRecipe(CallInst &I, iterator_range<IterT> CallArguments,
- Intrinsic::ID VectorIntrinsicID)
+ Intrinsic::ID VectorIntrinsicID,
+ Function *Variant = nullptr)
: VPRecipeBase(VPDef::VPWidenCallSC, CallArguments), VPValue(this, &I),
- VectorIntrinsicID(VectorIntrinsicID) {}
+ VectorIntrinsicID(VectorIntrinsicID), Variant(Variant) {}
~VPWidenCallRecipe() override = default;
@@ -954,17 +1193,10 @@ public:
};
/// A recipe for widening select instructions.
-class VPWidenSelectRecipe : public VPRecipeBase, public VPValue {
-
- /// Is the condition of the select loop invariant?
- bool InvariantCond;
-
-public:
+struct VPWidenSelectRecipe : public VPRecipeBase, public VPValue {
template <typename IterT>
- VPWidenSelectRecipe(SelectInst &I, iterator_range<IterT> Operands,
- bool InvariantCond)
- : VPRecipeBase(VPDef::VPWidenSelectSC, Operands), VPValue(this, &I),
- InvariantCond(InvariantCond) {}
+ VPWidenSelectRecipe(SelectInst &I, iterator_range<IterT> Operands)
+ : VPRecipeBase(VPDef::VPWidenSelectSC, Operands), VPValue(this, &I) {}
~VPWidenSelectRecipe() override = default;
@@ -978,29 +1210,38 @@ public:
void print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const override;
#endif
+
+ VPValue *getCond() const {
+ return getOperand(0);
+ }
+
+ bool isInvariantCond() const {
+ return getCond()->isDefinedOutsideVectorRegions();
+ }
};
/// A recipe for handling GEP instructions.
-class VPWidenGEPRecipe : public VPRecipeBase, public VPValue {
- bool IsPtrLoopInvariant;
- SmallBitVector IsIndexLoopInvariant;
+class VPWidenGEPRecipe : public VPRecipeWithIRFlags, public VPValue {
+ bool isPointerLoopInvariant() const {
+ return getOperand(0)->isDefinedOutsideVectorRegions();
+ }
+
+ bool isIndexLoopInvariant(unsigned I) const {
+ return getOperand(I + 1)->isDefinedOutsideVectorRegions();
+ }
+
+ bool areAllOperandsInvariant() const {
+ return all_of(operands(), [](VPValue *Op) {
+ return Op->isDefinedOutsideVectorRegions();
+ });
+ }
public:
template <typename IterT>
VPWidenGEPRecipe(GetElementPtrInst *GEP, iterator_range<IterT> Operands)
- : VPRecipeBase(VPDef::VPWidenGEPSC, Operands), VPValue(this, GEP),
- IsIndexLoopInvariant(GEP->getNumIndices(), false) {}
+ : VPRecipeWithIRFlags(VPDef::VPWidenGEPSC, Operands, *GEP),
+ VPValue(this, GEP) {}
- template <typename IterT>
- VPWidenGEPRecipe(GetElementPtrInst *GEP, iterator_range<IterT> Operands,
- Loop *OrigLoop)
- : VPRecipeBase(VPDef::VPWidenGEPSC, Operands), VPValue(this, GEP),
- IsIndexLoopInvariant(GEP->getNumIndices(), false) {
- IsPtrLoopInvariant = OrigLoop->isLoopInvariant(GEP->getPointerOperand());
- for (auto Index : enumerate(GEP->indices()))
- IsIndexLoopInvariant[Index.index()] =
- OrigLoop->isLoopInvariant(Index.value().get());
- }
~VPWidenGEPRecipe() override = default;
VP_CLASSOF_IMPL(VPDef::VPWidenGEPSC)
@@ -1015,78 +1256,6 @@ public:
#endif
};
-/// A recipe for handling phi nodes of integer and floating-point inductions,
-/// producing their vector values.
-class VPWidenIntOrFpInductionRecipe : public VPRecipeBase, public VPValue {
- PHINode *IV;
- const InductionDescriptor &IndDesc;
- bool NeedsVectorIV;
-
-public:
- VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step,
- const InductionDescriptor &IndDesc,
- bool NeedsVectorIV)
- : VPRecipeBase(VPDef::VPWidenIntOrFpInductionSC, {Start, Step}),
- VPValue(this, IV), IV(IV), IndDesc(IndDesc),
- NeedsVectorIV(NeedsVectorIV) {}
-
- VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step,
- const InductionDescriptor &IndDesc,
- TruncInst *Trunc, bool NeedsVectorIV)
- : VPRecipeBase(VPDef::VPWidenIntOrFpInductionSC, {Start, Step}),
- VPValue(this, Trunc), IV(IV), IndDesc(IndDesc),
- NeedsVectorIV(NeedsVectorIV) {}
-
- ~VPWidenIntOrFpInductionRecipe() override = default;
-
- VP_CLASSOF_IMPL(VPDef::VPWidenIntOrFpInductionSC)
-
- /// Generate the vectorized and scalarized versions of the phi node as
- /// needed by their users.
- void execute(VPTransformState &State) override;
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
- /// Print the recipe.
- void print(raw_ostream &O, const Twine &Indent,
- VPSlotTracker &SlotTracker) const override;
-#endif
-
- /// Returns the start value of the induction.
- VPValue *getStartValue() { return getOperand(0); }
- const VPValue *getStartValue() const { return getOperand(0); }
-
- /// Returns the step value of the induction.
- VPValue *getStepValue() { return getOperand(1); }
- const VPValue *getStepValue() const { return getOperand(1); }
-
- /// Returns the first defined value as TruncInst, if it is one or nullptr
- /// otherwise.
- TruncInst *getTruncInst() {
- return dyn_cast_or_null<TruncInst>(getVPValue(0)->getUnderlyingValue());
- }
- const TruncInst *getTruncInst() const {
- return dyn_cast_or_null<TruncInst>(getVPValue(0)->getUnderlyingValue());
- }
-
- PHINode *getPHINode() { return IV; }
-
- /// Returns the induction descriptor for the recipe.
- const InductionDescriptor &getInductionDescriptor() const { return IndDesc; }
-
- /// Returns true if the induction is canonical, i.e. starting at 0 and
- /// incremented by UF * VF (= the original IV is incremented by 1).
- bool isCanonical() const;
-
- /// Returns the scalar type of the induction.
- const Type *getScalarType() const {
- const TruncInst *TruncI = getTruncInst();
- return TruncI ? TruncI->getType() : IV->getType();
- }
-
- /// Returns true if a vector phi needs to be created for the induction.
- bool needsVectorIV() const { return NeedsVectorIV; }
-};
-
/// A pure virtual base class for all recipes modeling header phis, including
/// phis for first order recurrences, pointer inductions and reductions. The
/// start value is the first operand of the recipe and the incoming value from
@@ -1112,9 +1281,9 @@ public:
/// per-lane based on the canonical induction.
class VPHeaderPHIRecipe : public VPRecipeBase, public VPValue {
protected:
- VPHeaderPHIRecipe(unsigned char VPDefID, PHINode *Phi,
+ VPHeaderPHIRecipe(unsigned char VPDefID, Instruction *UnderlyingInstr,
VPValue *Start = nullptr)
- : VPRecipeBase(VPDefID, {}), VPValue(this, Phi) {
+ : VPRecipeBase(VPDefID, {}), VPValue(this, UnderlyingInstr) {
if (Start)
addOperand(Start);
}
@@ -1125,12 +1294,12 @@ public:
/// Method to support type inquiry through isa, cast, and dyn_cast.
static inline bool classof(const VPRecipeBase *B) {
return B->getVPDefID() >= VPDef::VPFirstHeaderPHISC &&
- B->getVPDefID() <= VPDef::VPLastPHISC;
+ B->getVPDefID() <= VPDef::VPLastHeaderPHISC;
}
static inline bool classof(const VPValue *V) {
auto *B = V->getDefiningRecipe();
return B && B->getVPDefID() >= VPRecipeBase::VPFirstHeaderPHISC &&
- B->getVPDefID() <= VPRecipeBase::VPLastPHISC;
+ B->getVPDefID() <= VPRecipeBase::VPLastHeaderPHISC;
}
/// Generate the phi nodes.
@@ -1154,17 +1323,92 @@ public:
void setStartValue(VPValue *V) { setOperand(0, V); }
/// Returns the incoming value from the loop backedge.
- VPValue *getBackedgeValue() {
+ virtual VPValue *getBackedgeValue() {
return getOperand(1);
}
/// Returns the backedge value as a recipe. The backedge value is guaranteed
/// to be a recipe.
- VPRecipeBase &getBackedgeRecipe() {
+ virtual VPRecipeBase &getBackedgeRecipe() {
return *getBackedgeValue()->getDefiningRecipe();
}
};
+/// A recipe for handling phi nodes of integer and floating-point inductions,
+/// producing their vector values.
+class VPWidenIntOrFpInductionRecipe : public VPHeaderPHIRecipe {
+ PHINode *IV;
+ TruncInst *Trunc;
+ const InductionDescriptor &IndDesc;
+
+public:
+ VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step,
+ const InductionDescriptor &IndDesc)
+ : VPHeaderPHIRecipe(VPDef::VPWidenIntOrFpInductionSC, IV, Start), IV(IV),
+ Trunc(nullptr), IndDesc(IndDesc) {
+ addOperand(Step);
+ }
+
+ VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step,
+ const InductionDescriptor &IndDesc,
+ TruncInst *Trunc)
+ : VPHeaderPHIRecipe(VPDef::VPWidenIntOrFpInductionSC, Trunc, Start),
+ IV(IV), Trunc(Trunc), IndDesc(IndDesc) {
+ addOperand(Step);
+ }
+
+ ~VPWidenIntOrFpInductionRecipe() override = default;
+
+ VP_CLASSOF_IMPL(VPDef::VPWidenIntOrFpInductionSC)
+
+ /// Generate the vectorized and scalarized versions of the phi node as
+ /// needed by their users.
+ void execute(VPTransformState &State) override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// Print the recipe.
+ void print(raw_ostream &O, const Twine &Indent,
+ VPSlotTracker &SlotTracker) const override;
+#endif
+
+ VPValue *getBackedgeValue() override {
+ // TODO: All operands of base recipe must exist and be at same index in
+ // derived recipe.
+ llvm_unreachable(
+ "VPWidenIntOrFpInductionRecipe generates its own backedge value");
+ }
+
+ VPRecipeBase &getBackedgeRecipe() override {
+ // TODO: All operands of base recipe must exist and be at same index in
+ // derived recipe.
+ llvm_unreachable(
+ "VPWidenIntOrFpInductionRecipe generates its own backedge value");
+ }
+
+ /// Returns the step value of the induction.
+ VPValue *getStepValue() { return getOperand(1); }
+ const VPValue *getStepValue() const { return getOperand(1); }
+
+ /// Returns the first defined value as TruncInst, if it is one or nullptr
+ /// otherwise.
+ TruncInst *getTruncInst() { return Trunc; }
+ const TruncInst *getTruncInst() const { return Trunc; }
+
+ PHINode *getPHINode() { return IV; }
+
+ /// Returns the induction descriptor for the recipe.
+ const InductionDescriptor &getInductionDescriptor() const { return IndDesc; }
+
+ /// Returns true if the induction is canonical, i.e. starting at 0 and
+ /// incremented by UF * VF (= the original IV is incremented by 1).
+ bool isCanonical() const;
+
+ /// Returns the scalar type of the induction.
+ const Type *getScalarType() const {
+ return Trunc ? Trunc->getType() : IV->getType();
+ }
+};
+
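Editorial sketch, not part of the patch: the recipe now inherits the start value as operand 0 from VPHeaderPHIRecipe and appends the step as operand 1. Phi, Start, Step and IndDesc are hypothetical values supplied by the caller.

  auto *IVR = new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc);
  assert(IVR->getStartValue() == Start && "start value is operand 0");
  assert(IVR->getStepValue() == Step && "step value is operand 1");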
class VPWidenPointerInductionRecipe : public VPHeaderPHIRecipe {
const InductionDescriptor &IndDesc;
@@ -1374,12 +1618,20 @@ public:
class VPInterleaveRecipe : public VPRecipeBase {
const InterleaveGroup<Instruction> *IG;
+ /// Indicates if the interleave group is in a conditional block and requires a
+ /// mask.
bool HasMask = false;
+ /// Indicates if gaps between members of the group need to be masked out or if
+ /// unused gaps can be loaded speculatively.
+ bool NeedsMaskForGaps = false;
+
public:
VPInterleaveRecipe(const InterleaveGroup<Instruction> *IG, VPValue *Addr,
- ArrayRef<VPValue *> StoredValues, VPValue *Mask)
- : VPRecipeBase(VPDef::VPInterleaveSC, {Addr}), IG(IG) {
+ ArrayRef<VPValue *> StoredValues, VPValue *Mask,
+ bool NeedsMaskForGaps)
+ : VPRecipeBase(VPDef::VPInterleaveSC, {Addr}), IG(IG),
+ NeedsMaskForGaps(NeedsMaskForGaps) {
for (unsigned i = 0; i < IG->getFactor(); ++i)
if (Instruction *I = IG->getMember(i)) {
if (I->getType()->isVoidTy())
@@ -1490,28 +1742,21 @@ public:
/// copies of the original scalar type, one per lane, instead of producing a
/// single copy of widened type for all lanes. If the instruction is known to be
/// uniform only one copy, per lane zero, will be generated.
-class VPReplicateRecipe : public VPRecipeBase, public VPValue {
+class VPReplicateRecipe : public VPRecipeWithIRFlags, public VPValue {
/// Indicator if only a single replica per lane is needed.
bool IsUniform;
/// Indicator if the replicas are also predicated.
bool IsPredicated;
- /// Indicator if the scalar values should also be packed into a vector.
- bool AlsoPack;
-
public:
template <typename IterT>
VPReplicateRecipe(Instruction *I, iterator_range<IterT> Operands,
- bool IsUniform, bool IsPredicated = false)
- : VPRecipeBase(VPDef::VPReplicateSC, Operands), VPValue(this, I),
- IsUniform(IsUniform), IsPredicated(IsPredicated) {
- // Retain the previous behavior of predicateInstructions(), where an
- // insert-element of a predicated instruction got hoisted into the
- // predicated basic block iff it was its only user. This is achieved by
- // having predicated instructions also pack their values into a vector by
- // default unless they have a replicated user which uses their scalar value.
- AlsoPack = IsPredicated && !I->use_empty();
+ bool IsUniform, VPValue *Mask = nullptr)
+ : VPRecipeWithIRFlags(VPDef::VPReplicateSC, Operands, *I),
+ VPValue(this, I), IsUniform(IsUniform), IsPredicated(Mask) {
+ if (Mask)
+ addOperand(Mask);
}
~VPReplicateRecipe() override = default;
@@ -1523,8 +1768,6 @@ public:
/// the \p State.
void execute(VPTransformState &State) override;
- void setAlsoPack(bool Pack) { AlsoPack = Pack; }
-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,
@@ -1533,8 +1776,6 @@ public:
bool isUniform() const { return IsUniform; }
- bool isPacked() const { return AlsoPack; }
-
bool isPredicated() const { return IsPredicated; }
/// Returns true if the recipe only uses the first lane of operand \p Op.
@@ -1550,6 +1791,17 @@ public:
"Op must be an operand of the recipe");
return true;
}
+
+ /// Returns true if the recipe is used by a widened recipe via an intervening
+ /// VPPredInstPHIRecipe. In this case, the scalar values should also be packed
+ /// in a vector.
+ bool shouldPack() const;
+
+ /// Return the mask of a predicated VPReplicateRecipe.
+ VPValue *getMask() {
+    assert(isPredicated() && "Trying to get the mask of an unpredicated recipe");
+ return getOperand(getNumOperands() - 1);
+ }
};
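Editorial sketch, not part of the patch: predication is now expressed by appending the mask as the trailing operand instead of keeping a separate packing flag. I, Operands and Mask are hypothetical.

  auto *RepR = new VPReplicateRecipe(I, Operands, /*IsUniform=*/false, Mask);
  assert(RepR->isPredicated() && RepR->getMask() == Mask &&
         "mask is stored as the last operand");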
/// A recipe for generating conditional branches on the bits of a mask.
@@ -1791,9 +2043,11 @@ public:
return true;
}
- /// Check if the induction described by \p ID is canonical, i.e. has the same
- /// start, step (of 1), and type as the canonical IV.
- bool isCanonical(const InductionDescriptor &ID, Type *Ty) const;
+  /// Check if the induction described by \p Kind, \p Start and \p Step is
+ /// canonical, i.e. has the same start, step (of 1), and type as the
+ /// canonical IV.
+ bool isCanonical(InductionDescriptor::InductionKind Kind, VPValue *Start,
+ VPValue *Step, Type *Ty) const;
};
/// A recipe for generating the active lane mask for the vector loop that is
@@ -2156,13 +2410,19 @@ public:
/// to produce efficient output IR, including which branches, basic-blocks and
/// output IR instructions to generate, and their cost. VPlan holds a
/// Hierarchical-CFG of VPBasicBlocks and VPRegionBlocks rooted at an Entry
-/// VPBlock.
+/// VPBasicBlock.
class VPlan {
friend class VPlanPrinter;
friend class VPSlotTracker;
- /// Hold the single entry to the Hierarchical CFG of the VPlan.
- VPBlockBase *Entry;
+ /// Hold the single entry to the Hierarchical CFG of the VPlan, i.e. the
+ /// preheader of the vector loop.
+ VPBasicBlock *Entry;
+
+ /// VPBasicBlock corresponding to the original preheader. Used to place
+ /// VPExpandSCEV recipes for expressions used during skeleton creation and the
+ /// rest of VPlan execution.
+ VPBasicBlock *Preheader;
/// Holds the VFs applicable to this VPlan.
SmallSetVector<ElementCount, 2> VFs;
@@ -2174,10 +2434,6 @@ class VPlan {
/// Holds the name of the VPlan, for printing.
std::string Name;
- /// Holds all the external definitions created for this VPlan. External
- /// definitions must be immutable and hold a pointer to their underlying IR.
- DenseMap<Value *, VPValue *> VPExternalDefs;
-
/// Represents the trip count of the original loop, for folding
/// the tail.
VPValue *TripCount = nullptr;
@@ -2193,9 +2449,9 @@ class VPlan {
/// VPlan.
Value2VPValueTy Value2VPValue;
- /// Contains all VPValues that been allocated by addVPValue directly and need
- /// to be free when the plan's destructor is called.
- SmallVector<VPValue *, 16> VPValuesToFree;
+ /// Contains all the external definitions created for this VPlan. External
+ /// definitions are VPValues that hold a pointer to their underlying IR.
+ SmallVector<VPValue *, 16> VPLiveInsToFree;
/// Indicates whether it is safe to use the Value2VPValue mapping or if the
/// mapping cannot be used any longer, because it is stale.
@@ -2204,14 +2460,41 @@ class VPlan {
/// Values used outside the plan.
MapVector<PHINode *, VPLiveOut *> LiveOuts;
+ /// Mapping from SCEVs to the VPValues representing their expansions.
+ /// NOTE: This mapping is temporary and will be removed once all users have
+ /// been modeled in VPlan directly.
+ DenseMap<const SCEV *, VPValue *> SCEVToExpansion;
+
public:
- VPlan(VPBlockBase *Entry = nullptr) : Entry(Entry) {
- if (Entry)
- Entry->setPlan(this);
+ /// Construct a VPlan with original preheader \p Preheader, trip count \p TC
+ /// and \p Entry to the plan. At the moment, \p Preheader and \p Entry need to
+ /// be disconnected, as the bypass blocks between them are not yet modeled in
+ /// VPlan.
+ VPlan(VPBasicBlock *Preheader, VPValue *TC, VPBasicBlock *Entry)
+ : VPlan(Preheader, Entry) {
+ TripCount = TC;
+ }
+
+ /// Construct a VPlan with original preheader \p Preheader and \p Entry to
+ /// the plan. At the moment, \p Preheader and \p Entry need to be
+ /// disconnected, as the bypass blocks between them are not yet modeled in
+ /// VPlan.
+ VPlan(VPBasicBlock *Preheader, VPBasicBlock *Entry)
+ : Entry(Entry), Preheader(Preheader) {
+ Entry->setPlan(this);
+ Preheader->setPlan(this);
+ assert(Preheader->getNumSuccessors() == 0 &&
+ Preheader->getNumPredecessors() == 0 &&
+ "preheader must be disconnected");
}
~VPlan();
+ /// Create an initial VPlan with preheader and entry blocks. Creates a
+ /// VPExpandSCEVRecipe for \p TripCount and uses it as plan's trip count.
+ static VPlanPtr createInitialVPlan(const SCEV *TripCount,
+ ScalarEvolution &PSE);
+
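Editorial sketch, not part of the patch: a plan skeleton can be built from two fresh, disconnected blocks; createInitialVPlan additionally places a VPExpandSCEVRecipe for the trip count into the preheader.

  auto *Preheader = new VPBasicBlock("ph");
  auto *Entry = new VPBasicBlock("vector.ph");
  // Both blocks start without successors/predecessors, satisfying the assert
  // in the constructor; bypass blocks between them are added later.
  auto Plan = std::make_unique<VPlan>(Preheader, Entry);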
/// Prepare the plan for execution, setting up the required live-in values.
void prepareToExecute(Value *TripCount, Value *VectorTripCount,
Value *CanonicalIVStartValue, VPTransformState &State,
@@ -2220,19 +2503,12 @@ public:
/// Generate the IR code for this VPlan.
void execute(VPTransformState *State);
- VPBlockBase *getEntry() { return Entry; }
- const VPBlockBase *getEntry() const { return Entry; }
-
- VPBlockBase *setEntry(VPBlockBase *Block) {
- Entry = Block;
- Block->setPlan(this);
- return Entry;
- }
+ VPBasicBlock *getEntry() { return Entry; }
+ const VPBasicBlock *getEntry() const { return Entry; }
/// The trip count of the original loop.
- VPValue *getOrCreateTripCount() {
- if (!TripCount)
- TripCount = new VPValue();
+ VPValue *getTripCount() const {
+ assert(TripCount && "trip count needs to be set before accessing it");
return TripCount;
}
@@ -2275,50 +2551,35 @@ public:
void setName(const Twine &newName) { Name = newName.str(); }
- /// Get the existing or add a new external definition for \p V.
- VPValue *getOrAddExternalDef(Value *V) {
- auto I = VPExternalDefs.insert({V, nullptr});
- if (I.second)
- I.first->second = new VPValue(V);
- return I.first->second;
- }
-
- void addVPValue(Value *V) {
- assert(Value2VPValueEnabled &&
- "IR value to VPValue mapping may be out of date!");
- assert(V && "Trying to add a null Value to VPlan");
- assert(!Value2VPValue.count(V) && "Value already exists in VPlan");
- VPValue *VPV = new VPValue(V);
- Value2VPValue[V] = VPV;
- VPValuesToFree.push_back(VPV);
- }
-
void addVPValue(Value *V, VPValue *VPV) {
- assert(Value2VPValueEnabled && "Value2VPValue mapping may be out of date!");
+ assert((Value2VPValueEnabled || VPV->isLiveIn()) &&
+ "Value2VPValue mapping may be out of date!");
assert(V && "Trying to add a null Value to VPlan");
assert(!Value2VPValue.count(V) && "Value already exists in VPlan");
Value2VPValue[V] = VPV;
}
/// Returns the VPValue for \p V. \p OverrideAllowed can be used to disable
- /// checking whether it is safe to query VPValues using IR Values.
+  /// checking whether it is safe to query VPValues using IR Values.
VPValue *getVPValue(Value *V, bool OverrideAllowed = false) {
- assert((OverrideAllowed || isa<Constant>(V) || Value2VPValueEnabled) &&
- "Value2VPValue mapping may be out of date!");
assert(V && "Trying to get the VPValue of a null Value");
assert(Value2VPValue.count(V) && "Value does not exist in VPlan");
+ assert((Value2VPValueEnabled || OverrideAllowed ||
+ Value2VPValue[V]->isLiveIn()) &&
+ "Value2VPValue mapping may be out of date!");
return Value2VPValue[V];
}
- /// Gets the VPValue or adds a new one (if none exists yet) for \p V. \p
- /// OverrideAllowed can be used to disable checking whether it is safe to
- /// query VPValues using IR Values.
- VPValue *getOrAddVPValue(Value *V, bool OverrideAllowed = false) {
- assert((OverrideAllowed || isa<Constant>(V) || Value2VPValueEnabled) &&
- "Value2VPValue mapping may be out of date!");
+ /// Gets the VPValue for \p V or adds a new live-in (if none exists yet) for
+ /// \p V.
+ VPValue *getVPValueOrAddLiveIn(Value *V) {
assert(V && "Trying to get or add the VPValue of a null Value");
- if (!Value2VPValue.count(V))
- addVPValue(V);
+ if (!Value2VPValue.count(V)) {
+ VPValue *VPV = new VPValue(V);
+ VPLiveInsToFree.push_back(VPV);
+ addVPValue(V, VPV);
+ }
+
return getVPValue(V);
}
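Editorial sketch, not part of the patch: repeated queries for the same IR value return one shared live-in VPValue, which the plan owns and frees in its destructor. IRVal is a hypothetical llvm::Value pointer.

  VPValue *A = Plan.getVPValueOrAddLiveIn(IRVal);
  VPValue *B = Plan.getVPValueOrAddLiveIn(IRVal);
  assert(A == B && A->isLiveIn() && "one live-in per IR value");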
@@ -2344,7 +2605,7 @@ public:
iterator_range<mapped_iterator<Use *, std::function<VPValue *(Value *)>>>
mapToVPValues(User::op_range Operands) {
std::function<VPValue *(Value *)> Fn = [this](Value *Op) {
- return getOrAddVPValue(Op);
+ return getVPValueOrAddLiveIn(Op);
};
return map_range(Operands, Fn);
}
@@ -2373,12 +2634,6 @@ public:
void addLiveOut(PHINode *PN, VPValue *V);
- void clearLiveOuts() {
- for (auto &KV : LiveOuts)
- delete KV.second;
- LiveOuts.clear();
- }
-
void removeLiveOut(PHINode *PN) {
delete LiveOuts[PN];
LiveOuts.erase(PN);
@@ -2388,6 +2643,19 @@ public:
return LiveOuts;
}
+ VPValue *getSCEVExpansion(const SCEV *S) const {
+ return SCEVToExpansion.lookup(S);
+ }
+
+ void addSCEVExpansion(const SCEV *S, VPValue *V) {
+ assert(!SCEVToExpansion.contains(S) && "SCEV already expanded");
+ SCEVToExpansion[S] = V;
+ }
+
+ /// \return The block corresponding to the original preheader.
+ VPBasicBlock *getPreheader() { return Preheader; }
+ const VPBasicBlock *getPreheader() const { return Preheader; }
+
private:
/// Add to the given dominator tree the header block and every new basic block
/// that was created between it and the latch block, inclusive.
@@ -2709,6 +2977,8 @@ inline bool isUniformAfterVectorization(VPValue *VPV) {
assert(Def && "Must have definition for value defined inside vector region");
if (auto Rep = dyn_cast<VPReplicateRecipe>(Def))
return Rep->isUniform();
+ if (auto *GEP = dyn_cast<VPWidenGEPRecipe>(Def))
+ return all_of(GEP->operands(), isUniformAfterVectorization);
return false;
}
} // end namespace vputils
diff --git a/llvm/lib/Transforms/Vectorize/VPlanCFG.h b/llvm/lib/Transforms/Vectorize/VPlanCFG.h
index f790f7e73e11..89e2e7514dac 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanCFG.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanCFG.h
@@ -13,6 +13,7 @@
#define LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H
#include "VPlan.h"
+#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/SmallVector.h"
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
index 952ce72e36c1..f6e3a2a16db8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
@@ -73,9 +73,8 @@ public:
PlainCFGBuilder(Loop *Lp, LoopInfo *LI, VPlan &P)
: TheLoop(Lp), LI(LI), Plan(P) {}
- /// Build plain CFG for TheLoop. Return the pre-header VPBasicBlock connected
- /// to a new VPRegionBlock (TopRegion) enclosing the plain CFG.
- VPBasicBlock *buildPlainCFG();
+ /// Build plain CFG for TheLoop and connects it to Plan's entry.
+ void buildPlainCFG();
};
} // anonymous namespace
@@ -196,7 +195,7 @@ VPValue *PlainCFGBuilder::getOrCreateVPOperand(Value *IRVal) {
// A and B: Create VPValue and add it to the pool of external definitions and
// to the Value->VPValue map.
- VPValue *NewVPVal = Plan.getOrAddExternalDef(IRVal);
+ VPValue *NewVPVal = Plan.getVPValueOrAddLiveIn(IRVal);
IRDef2VPValue[IRVal] = NewVPVal;
return NewVPVal;
}
@@ -254,7 +253,7 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
}
// Main interface to build the plain CFG.
-VPBasicBlock *PlainCFGBuilder::buildPlainCFG() {
+void PlainCFGBuilder::buildPlainCFG() {
// 1. Scan the body of the loop in a topological order to visit each basic
// block after having visited its predecessor basic blocks. Create a VPBB for
// each BB and link it to its successor and predecessor VPBBs. Note that
@@ -267,12 +266,13 @@ VPBasicBlock *PlainCFGBuilder::buildPlainCFG() {
BasicBlock *ThePreheaderBB = TheLoop->getLoopPreheader();
assert((ThePreheaderBB->getTerminator()->getNumSuccessors() == 1) &&
"Unexpected loop preheader");
- VPBasicBlock *ThePreheaderVPBB = getOrCreateVPBB(ThePreheaderBB);
+ VPBasicBlock *ThePreheaderVPBB = Plan.getEntry();
+ BB2VPBB[ThePreheaderBB] = ThePreheaderVPBB;
ThePreheaderVPBB->setName("vector.ph");
for (auto &I : *ThePreheaderBB) {
if (I.getType()->isVoidTy())
continue;
- IRDef2VPValue[&I] = Plan.getOrAddExternalDef(&I);
+ IRDef2VPValue[&I] = Plan.getVPValueOrAddLiveIn(&I);
}
// Create empty VPBB for Loop H so that we can link PH->H.
VPBlockBase *HeaderVPBB = getOrCreateVPBB(TheLoop->getHeader());
@@ -371,20 +371,17 @@ VPBasicBlock *PlainCFGBuilder::buildPlainCFG() {
// have a VPlan counterpart. Fix VPlan phi nodes by adding their corresponding
// VPlan operands.
fixPhiNodes();
-
- return ThePreheaderVPBB;
}
-VPBasicBlock *VPlanHCFGBuilder::buildPlainCFG() {
+void VPlanHCFGBuilder::buildPlainCFG() {
PlainCFGBuilder PCFGBuilder(TheLoop, LI, Plan);
- return PCFGBuilder.buildPlainCFG();
+ PCFGBuilder.buildPlainCFG();
}
// Public interface to build a H-CFG.
void VPlanHCFGBuilder::buildHierarchicalCFG() {
- // Build Top Region enclosing the plain CFG and set it as VPlan entry.
- VPBasicBlock *EntryVPBB = buildPlainCFG();
- Plan.setEntry(EntryVPBB);
+ // Build Top Region enclosing the plain CFG.
+ buildPlainCFG();
LLVM_DEBUG(Plan.setName("HCFGBuilder: Plain CFG\n"); dbgs() << Plan);
VPRegionBlock *TopRegion = Plan.getVectorLoopRegion();
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
index 2d52990af268..299ae36155cb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
@@ -57,9 +57,8 @@ private:
// are introduced.
VPDominatorTree VPDomTree;
- /// Build plain CFG for TheLoop. Return the pre-header VPBasicBlock connected
- /// to a new VPRegionBlock (TopRegion) enclosing the plain CFG.
- VPBasicBlock *buildPlainCFG();
+ /// Build plain CFG for TheLoop and connects it to Plan's entry.
+ void buildPlainCFG();
public:
VPlanHCFGBuilder(Loop *Lp, LoopInfo *LI, VPlan &P)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 4e9be35001ad..26c309eed800 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -34,7 +34,9 @@ using namespace llvm;
using VectorParts = SmallVector<Value *, 2>;
+namespace llvm {
extern cl::opt<bool> EnableVPlanNativePath;
+}
#define LV_NAME "loop-vectorize"
#define DEBUG_TYPE LV_NAME
@@ -50,14 +52,16 @@ bool VPRecipeBase::mayWriteToMemory() const {
->mayWriteToMemory();
case VPBranchOnMaskSC:
case VPScalarIVStepsSC:
+ case VPPredInstPHISC:
return false;
- case VPWidenIntOrFpInductionSC:
+ case VPBlendSC:
+ case VPReductionSC:
case VPWidenCanonicalIVSC:
+ case VPWidenCastSC:
+ case VPWidenGEPSC:
+ case VPWidenIntOrFpInductionSC:
case VPWidenPHISC:
- case VPBlendSC:
case VPWidenSC:
- case VPWidenGEPSC:
- case VPReductionSC:
case VPWidenSelectSC: {
const Instruction *I =
dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue());
@@ -82,14 +86,16 @@ bool VPRecipeBase::mayReadFromMemory() const {
->mayReadFromMemory();
case VPBranchOnMaskSC:
case VPScalarIVStepsSC:
+ case VPPredInstPHISC:
return false;
- case VPWidenIntOrFpInductionSC:
+ case VPBlendSC:
+ case VPReductionSC:
case VPWidenCanonicalIVSC:
+ case VPWidenCastSC:
+ case VPWidenGEPSC:
+ case VPWidenIntOrFpInductionSC:
case VPWidenPHISC:
- case VPBlendSC:
case VPWidenSC:
- case VPWidenGEPSC:
- case VPReductionSC:
case VPWidenSelectSC: {
const Instruction *I =
dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue());
@@ -108,16 +114,20 @@ bool VPRecipeBase::mayHaveSideEffects() const {
case VPDerivedIVSC:
case VPPredInstPHISC:
return false;
- case VPWidenIntOrFpInductionSC:
- case VPWidenPointerInductionSC:
+ case VPWidenCallSC:
+ return cast<Instruction>(getVPSingleValue()->getUnderlyingValue())
+ ->mayHaveSideEffects();
+ case VPBlendSC:
+ case VPReductionSC:
+ case VPScalarIVStepsSC:
case VPWidenCanonicalIVSC:
+ case VPWidenCastSC:
+ case VPWidenGEPSC:
+ case VPWidenIntOrFpInductionSC:
case VPWidenPHISC:
- case VPBlendSC:
+ case VPWidenPointerInductionSC:
case VPWidenSC:
- case VPWidenGEPSC:
- case VPReductionSC:
- case VPWidenSelectSC:
- case VPScalarIVStepsSC: {
+ case VPWidenSelectSC: {
const Instruction *I =
dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue());
(void)I;
@@ -125,6 +135,13 @@ bool VPRecipeBase::mayHaveSideEffects() const {
"underlying instruction has side-effects");
return false;
}
+ case VPWidenMemoryInstructionSC:
+ assert(cast<VPWidenMemoryInstructionRecipe>(this)
+ ->getIngredient()
+ .mayHaveSideEffects() == mayWriteToMemory() &&
+         "mayHaveSideEffects result for ingredient differs from this "
+ "implementation");
+ return mayWriteToMemory();
case VPReplicateSC: {
auto *R = cast<VPReplicateRecipe>(this);
return R->getUnderlyingInstr()->mayHaveSideEffects();
@@ -143,6 +160,16 @@ void VPLiveOut::fixPhi(VPlan &Plan, VPTransformState &State) {
State.Builder.GetInsertBlock());
}
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPLiveOut::print(raw_ostream &O, VPSlotTracker &SlotTracker) const {
+ O << "Live-out ";
+ getPhi()->printAsOperand(O);
+ O << " = ";
+ getOperand(0)->printAsOperand(O, SlotTracker);
+ O << "\n";
+}
+#endif
+
void VPRecipeBase::insertBefore(VPRecipeBase *InsertPos) {
assert(!Parent && "Recipe already in some VPBasicBlock");
assert(InsertPos->getParent() &&
@@ -189,55 +216,44 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
insertBefore(BB, I);
}
-void VPInstruction::generateInstruction(VPTransformState &State,
- unsigned Part) {
+Value *VPInstruction::generateInstruction(VPTransformState &State,
+ unsigned Part) {
IRBuilderBase &Builder = State.Builder;
Builder.SetCurrentDebugLocation(DL);
if (Instruction::isBinaryOp(getOpcode())) {
Value *A = State.get(getOperand(0), Part);
Value *B = State.get(getOperand(1), Part);
- Value *V =
- Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
- State.set(this, V, Part);
- return;
+ return Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
}
switch (getOpcode()) {
case VPInstruction::Not: {
Value *A = State.get(getOperand(0), Part);
- Value *V = Builder.CreateNot(A, Name);
- State.set(this, V, Part);
- break;
+ return Builder.CreateNot(A, Name);
}
case VPInstruction::ICmpULE: {
Value *IV = State.get(getOperand(0), Part);
Value *TC = State.get(getOperand(1), Part);
- Value *V = Builder.CreateICmpULE(IV, TC, Name);
- State.set(this, V, Part);
- break;
+ return Builder.CreateICmpULE(IV, TC, Name);
}
case Instruction::Select: {
Value *Cond = State.get(getOperand(0), Part);
Value *Op1 = State.get(getOperand(1), Part);
Value *Op2 = State.get(getOperand(2), Part);
- Value *V = Builder.CreateSelect(Cond, Op1, Op2, Name);
- State.set(this, V, Part);
- break;
+ return Builder.CreateSelect(Cond, Op1, Op2, Name);
}
case VPInstruction::ActiveLaneMask: {
// Get first lane of vector induction variable.
Value *VIVElem0 = State.get(getOperand(0), VPIteration(Part, 0));
// Get the original loop tripcount.
- Value *ScalarTC = State.get(getOperand(1), Part);
+ Value *ScalarTC = State.get(getOperand(1), VPIteration(Part, 0));
auto *Int1Ty = Type::getInt1Ty(Builder.getContext());
auto *PredTy = VectorType::get(Int1Ty, State.VF);
- Instruction *Call = Builder.CreateIntrinsic(
- Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()},
- {VIVElem0, ScalarTC}, nullptr, Name);
- State.set(this, Call, Part);
- break;
+ return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
+ {PredTy, ScalarTC->getType()},
+ {VIVElem0, ScalarTC}, nullptr, Name);
}
case VPInstruction::FirstOrderRecurrenceSplice: {
// Generate code to combine the previous and current values in vector v3.
@@ -255,18 +271,22 @@ void VPInstruction::generateInstruction(VPTransformState &State,
// For the first part, use the recurrence phi (v1), otherwise v2.
auto *V1 = State.get(getOperand(0), 0);
Value *PartMinus1 = Part == 0 ? V1 : State.get(getOperand(1), Part - 1);
- if (!PartMinus1->getType()->isVectorTy()) {
- State.set(this, PartMinus1, Part);
- } else {
- Value *V2 = State.get(getOperand(1), Part);
- State.set(this, Builder.CreateVectorSplice(PartMinus1, V2, -1, Name),
- Part);
- }
- break;
+ if (!PartMinus1->getType()->isVectorTy())
+ return PartMinus1;
+ Value *V2 = State.get(getOperand(1), Part);
+ return Builder.CreateVectorSplice(PartMinus1, V2, -1, Name);
+ }
+ case VPInstruction::CalculateTripCountMinusVF: {
+ Value *ScalarTC = State.get(getOperand(0), {0, 0});
+ Value *Step =
+ createStepForVF(Builder, ScalarTC->getType(), State.VF, State.UF);
+ Value *Sub = Builder.CreateSub(ScalarTC, Step);
+ Value *Cmp = Builder.CreateICmp(CmpInst::Predicate::ICMP_UGT, ScalarTC, Step);
+ Value *Zero = ConstantInt::get(ScalarTC->getType(), 0);
+ return Builder.CreateSelect(Cmp, Sub, Zero);
}
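  // Editorial note, not part of the patch: e.g. with a scalar trip count of
  // 100 and VF * UF == 8 this selects 100 > 8 ? 100 - 8 : 0 == 92; with a
  // trip count of 5 it yields 0. This matches the "TC > VF ? TC - VF : 0"
  // label printed for this opcode in VPInstruction::print.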
case VPInstruction::CanonicalIVIncrement:
case VPInstruction::CanonicalIVIncrementNUW: {
- Value *Next = nullptr;
if (Part == 0) {
bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW;
auto *Phi = State.get(getOperand(0), 0);
@@ -274,34 +294,26 @@ void VPInstruction::generateInstruction(VPTransformState &State,
// elements) times the unroll factor (num of SIMD instructions).
Value *Step =
createStepForVF(Builder, Phi->getType(), State.VF, State.UF);
- Next = Builder.CreateAdd(Phi, Step, Name, IsNUW, false);
- } else {
- Next = State.get(this, 0);
+ return Builder.CreateAdd(Phi, Step, Name, IsNUW, false);
}
-
- State.set(this, Next, Part);
- break;
+ return State.get(this, 0);
}
case VPInstruction::CanonicalIVIncrementForPart:
case VPInstruction::CanonicalIVIncrementForPartNUW: {
bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementForPartNUW;
auto *IV = State.get(getOperand(0), VPIteration(0, 0));
- if (Part == 0) {
- State.set(this, IV, Part);
- break;
- }
+ if (Part == 0)
+ return IV;
// The canonical IV is incremented by the vectorization factor (num of SIMD
// elements) times the unroll part.
Value *Step = createStepForVF(Builder, IV->getType(), State.VF, Part);
- Value *Next = Builder.CreateAdd(IV, Step, Name, IsNUW, false);
- State.set(this, Next, Part);
- break;
+ return Builder.CreateAdd(IV, Step, Name, IsNUW, false);
}
case VPInstruction::BranchOnCond: {
if (Part != 0)
- break;
+ return nullptr;
Value *Cond = State.get(getOperand(0), VPIteration(Part, 0));
VPRegionBlock *ParentRegion = getParent()->getParent();
@@ -318,11 +330,11 @@ void VPInstruction::generateInstruction(VPTransformState &State,
CondBr->setSuccessor(0, nullptr);
Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
- break;
+ return CondBr;
}
case VPInstruction::BranchOnCount: {
if (Part != 0)
- break;
+ return nullptr;
// First create the compare.
Value *IV = State.get(getOperand(0), Part);
Value *TC = State.get(getOperand(1), Part);
@@ -342,7 +354,7 @@ void VPInstruction::generateInstruction(VPTransformState &State,
State.CFG.VPBB2IRBB[Header]);
CondBr->setSuccessor(0, nullptr);
Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
- break;
+ return CondBr;
}
default:
llvm_unreachable("Unsupported opcode for instruction");
@@ -353,8 +365,13 @@ void VPInstruction::execute(VPTransformState &State) {
assert(!State.Instance && "VPInstruction executing an Instance");
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
State.Builder.setFastMathFlags(FMF);
- for (unsigned Part = 0; Part < State.UF; ++Part)
- generateInstruction(State, Part);
+ for (unsigned Part = 0; Part < State.UF; ++Part) {
+ Value *GeneratedValue = generateInstruction(State, Part);
+ if (!hasResult())
+ continue;
+ assert(GeneratedValue && "generateInstruction must produce a value");
+ State.set(this, GeneratedValue, Part);
+ }
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -400,6 +417,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::BranchOnCond:
O << "branch-on-cond";
break;
+ case VPInstruction::CalculateTripCountMinusVF:
+ O << "TC > VF ? TC - VF : 0";
+ break;
case VPInstruction::CanonicalIVIncrementForPart:
O << "VF * Part + ";
break;
@@ -438,18 +458,19 @@ void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) {
}
void VPWidenCallRecipe::execute(VPTransformState &State) {
+ assert(State.VF.isVector() && "not widening");
auto &CI = *cast<CallInst>(getUnderlyingInstr());
assert(!isa<DbgInfoIntrinsic>(CI) &&
"DbgInfoIntrinsic should have been dropped during VPlan construction");
State.setDebugLocFromInst(&CI);
- SmallVector<Type *, 4> Tys;
- for (Value *ArgOperand : CI.args())
- Tys.push_back(
- ToVectorTy(ArgOperand->getType(), State.VF.getKnownMinValue()));
-
for (unsigned Part = 0; Part < State.UF; ++Part) {
- SmallVector<Type *, 2> TysForDecl = {CI.getType()};
+ SmallVector<Type *, 2> TysForDecl;
+ // Add return type if intrinsic is overloaded on it.
+ if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1)) {
+ TysForDecl.push_back(
+ VectorType::get(CI.getType()->getScalarType(), State.VF));
+ }
SmallVector<Value *, 4> Args;
for (const auto &I : enumerate(operands())) {
// Some intrinsics have a scalar argument - don't replace it with a
@@ -468,21 +489,16 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
Function *VectorF;
if (VectorIntrinsicID != Intrinsic::not_intrinsic) {
// Use vector version of the intrinsic.
- if (State.VF.isVector())
- TysForDecl[0] =
- VectorType::get(CI.getType()->getScalarType(), State.VF);
Module *M = State.Builder.GetInsertBlock()->getModule();
VectorF = Intrinsic::getDeclaration(M, VectorIntrinsicID, TysForDecl);
assert(VectorF && "Can't retrieve vector intrinsic.");
} else {
- // Use vector version of the function call.
- const VFShape Shape = VFShape::get(CI, State.VF, false /*HasGlobalPred*/);
#ifndef NDEBUG
- assert(VFDatabase(CI).getVectorizedFunction(Shape) != nullptr &&
- "Can't create vector function.");
+ assert(Variant != nullptr && "Can't create vector function.");
#endif
- VectorF = VFDatabase(CI).getVectorizedFunction(Shape);
+ VectorF = Variant;
}
+
SmallVector<OperandBundleDef, 1> OpBundles;
CI.getOperandBundlesAsDefs(OpBundles);
CallInst *V = State.Builder.CreateCall(VectorF, Args, OpBundles);
@@ -514,8 +530,12 @@ void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent,
if (VectorIntrinsicID)
O << " (using vector intrinsic)";
- else
- O << " (using library function)";
+ else {
+ O << " (using library function";
+ if (Variant->hasName())
+ O << ": " << Variant->getName();
+ O << ")";
+ }
}
void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent,
@@ -528,7 +548,7 @@ void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent,
getOperand(1)->printAsOperand(O, SlotTracker);
O << ", ";
getOperand(2)->printAsOperand(O, SlotTracker);
- O << (InvariantCond ? " (condition is loop invariant)" : "");
+ O << (isInvariantCond() ? " (condition is loop invariant)" : "");
}
#endif
@@ -541,10 +561,10 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
// We have to take the 'vectorized' value and pick the first lane.
// Instcombine will make this a no-op.
auto *InvarCond =
- InvariantCond ? State.get(getOperand(0), VPIteration(0, 0)) : nullptr;
+ isInvariantCond() ? State.get(getCond(), VPIteration(0, 0)) : nullptr;
for (unsigned Part = 0; Part < State.UF; ++Part) {
- Value *Cond = InvarCond ? InvarCond : State.get(getOperand(0), Part);
+ Value *Cond = InvarCond ? InvarCond : State.get(getCond(), Part);
Value *Op0 = State.get(getOperand(1), Part);
Value *Op1 = State.get(getOperand(2), Part);
Value *Sel = State.Builder.CreateSelect(Cond, Op0, Op1);
@@ -553,6 +573,33 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
}
}
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
+ switch (OpType) {
+ case OperationType::PossiblyExactOp:
+ if (ExactFlags.IsExact)
+ O << " exact";
+ break;
+ case OperationType::OverflowingBinOp:
+ if (WrapFlags.HasNUW)
+ O << " nuw";
+ if (WrapFlags.HasNSW)
+ O << " nsw";
+ break;
+ case OperationType::FPMathOp:
+ getFastMathFlags().print(O);
+ break;
+ case OperationType::GEPOp:
+ if (GEPFlags.IsInBounds)
+ O << " inbounds";
+ break;
+ case OperationType::Other:
+ break;
+ }
+ O << " ";
+}
+#endif
+
void VPWidenRecipe::execute(VPTransformState &State) {
auto &I = *cast<Instruction>(getUnderlyingValue());
auto &Builder = State.Builder;
@@ -592,17 +639,8 @@ void VPWidenRecipe::execute(VPTransformState &State) {
Value *V = Builder.CreateNAryOp(I.getOpcode(), Ops);
- if (auto *VecOp = dyn_cast<Instruction>(V)) {
- VecOp->copyIRFlags(&I);
-
- // If the instruction is vectorized and was in a basic block that needed
- // predication, we can't propagate poison-generating flags (nuw/nsw,
- // exact, etc.). The control flow has been linearized and the
- // instruction is no longer guarded by the predicate, which could make
- // the flag properties to no longer hold.
- if (State.MayGeneratePoisonRecipes.contains(this))
- VecOp->dropPoisonGeneratingFlags();
- }
+ if (auto *VecOp = dyn_cast<Instruction>(V))
+ setFlags(VecOp);
// Use this vector value for all users of the original instruction.
State.set(this, V, Part);
@@ -646,35 +684,6 @@ void VPWidenRecipe::execute(VPTransformState &State) {
break;
}
-
- case Instruction::ZExt:
- case Instruction::SExt:
- case Instruction::FPToUI:
- case Instruction::FPToSI:
- case Instruction::FPExt:
- case Instruction::PtrToInt:
- case Instruction::IntToPtr:
- case Instruction::SIToFP:
- case Instruction::UIToFP:
- case Instruction::Trunc:
- case Instruction::FPTrunc:
- case Instruction::BitCast: {
- auto *CI = cast<CastInst>(&I);
- State.setDebugLocFromInst(CI);
-
- /// Vectorize casts.
- Type *DestTy = (State.VF.isScalar())
- ? CI->getType()
- : VectorType::get(CI->getType(), State.VF);
-
- for (unsigned Part = 0; Part < State.UF; ++Part) {
- Value *A = State.get(getOperand(0), Part);
- Value *Cast = Builder.CreateCast(CI->getOpcode(), A, DestTy);
- State.set(this, Cast, Part);
- State.addMetadata(Cast, &I);
- }
- break;
- }
default:
// This instruction is not vectorized by simple widening.
LLVM_DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I);
@@ -687,10 +696,39 @@ void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
O << Indent << "WIDEN ";
printAsOperand(O, SlotTracker);
const Instruction *UI = getUnderlyingInstr();
- O << " = " << UI->getOpcodeName() << " ";
+ O << " = " << UI->getOpcodeName();
+ printFlags(O);
if (auto *Cmp = dyn_cast<CmpInst>(UI))
- O << CmpInst::getPredicateName(Cmp->getPredicate()) << " ";
+ O << Cmp->getPredicate() << " ";
+ printOperands(O, SlotTracker);
+}
+#endif
+
+void VPWidenCastRecipe::execute(VPTransformState &State) {
+ auto *I = cast_or_null<Instruction>(getUnderlyingValue());
+ if (I)
+ State.setDebugLocFromInst(I);
+ auto &Builder = State.Builder;
+ /// Vectorize casts.
+ assert(State.VF.isVector() && "Not vectorizing?");
+ Type *DestTy = VectorType::get(getResultType(), State.VF);
+
+ for (unsigned Part = 0; Part < State.UF; ++Part) {
+ Value *A = State.get(getOperand(0), Part);
+ Value *Cast = Builder.CreateCast(Instruction::CastOps(Opcode), A, DestTy);
+ State.set(this, Cast, Part);
+ State.addMetadata(Cast, I);
+ }
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
+ VPSlotTracker &SlotTracker) const {
+ O << Indent << "WIDEN-CAST ";
+ printAsOperand(O, SlotTracker);
+ O << " = " << Instruction::getOpcodeName(Opcode) << " ";
printOperands(O, SlotTracker);
+ O << " to " << *getResultType();
}
void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent,
@@ -710,8 +748,13 @@ void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent,
#endif
bool VPWidenIntOrFpInductionRecipe::isCanonical() const {
+ // The step may be defined by a recipe in the preheader (e.g. if it requires
+ // SCEV expansion), but for the canonical induction the step is required to be
+ // 1, which is represented as live-in.
+ if (getStepValue()->getDefiningRecipe())
+ return false;
+ auto *StepC = dyn_cast<ConstantInt>(getStepValue()->getLiveInIRValue());
auto *StartC = dyn_cast<ConstantInt>(getStartValue()->getLiveInIRValue());
- auto *StepC = dyn_cast<SCEVConstant>(getInductionDescriptor().getStep());
return StartC && StartC->isZero() && StepC && StepC->isOne();
}
@@ -743,6 +786,7 @@ void VPScalarIVStepsRecipe::print(raw_ostream &O, const Twine &Indent,
#endif
void VPWidenGEPRecipe::execute(VPTransformState &State) {
+ assert(State.VF.isVector() && "not widening");
auto *GEP = cast<GetElementPtrInst>(getUnderlyingInstr());
// Construct a vector GEP by widening the operands of the scalar GEP as
// necessary. We mark the vector GEP 'inbounds' if appropriate. A GEP
@@ -750,7 +794,7 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) {
// is vector-typed. Thus, to keep the representation compact, we only use
// vector-typed operands for loop-varying values.
- if (State.VF.isVector() && IsPtrLoopInvariant && IsIndexLoopInvariant.all()) {
+ if (areAllOperandsInvariant()) {
// If we are vectorizing, but the GEP has only loop-invariant operands,
// the GEP we build (by only using vector-typed operands for
// loop-varying values) would be a scalar pointer. Thus, to ensure we
@@ -763,9 +807,15 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) {
// required. We would add the scalarization decision to
// collectLoopScalars() and teach getVectorValue() to broadcast
// the lane-zero scalar value.
- auto *Clone = State.Builder.Insert(GEP->clone());
+ SmallVector<Value *> Ops;
+ for (unsigned I = 0, E = getNumOperands(); I != E; I++)
+ Ops.push_back(State.get(getOperand(I), VPIteration(0, 0)));
+
+ auto *NewGEP =
+ State.Builder.CreateGEP(GEP->getSourceElementType(), Ops[0],
+ ArrayRef(Ops).drop_front(), "", isInBounds());
for (unsigned Part = 0; Part < State.UF; ++Part) {
- Value *EntryPart = State.Builder.CreateVectorSplat(State.VF, Clone);
+ Value *EntryPart = State.Builder.CreateVectorSplat(State.VF, NewGEP);
State.set(this, EntryPart, Part);
State.addMetadata(EntryPart, GEP);
}
@@ -780,7 +830,7 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) {
for (unsigned Part = 0; Part < State.UF; ++Part) {
// The pointer operand of the new GEP. If it's loop-invariant, we
// won't broadcast it.
- auto *Ptr = IsPtrLoopInvariant
+ auto *Ptr = isPointerLoopInvariant()
? State.get(getOperand(0), VPIteration(0, 0))
: State.get(getOperand(0), Part);
@@ -789,24 +839,16 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) {
SmallVector<Value *, 4> Indices;
for (unsigned I = 1, E = getNumOperands(); I < E; I++) {
VPValue *Operand = getOperand(I);
- if (IsIndexLoopInvariant[I - 1])
+ if (isIndexLoopInvariant(I - 1))
Indices.push_back(State.get(Operand, VPIteration(0, 0)));
else
Indices.push_back(State.get(Operand, Part));
}
- // If the GEP instruction is vectorized and was in a basic block that
- // needed predication, we can't propagate the poison-generating 'inbounds'
- // flag. The control flow has been linearized and the GEP is no longer
- // guarded by the predicate, which could make the 'inbounds' properties to
- // no longer hold.
- bool IsInBounds =
- GEP->isInBounds() && State.MayGeneratePoisonRecipes.count(this) == 0;
-
// Create the new GEP. Note that this GEP may be a scalar if VF == 1,
// but it should be a vector, otherwise.
auto *NewGEP = State.Builder.CreateGEP(GEP->getSourceElementType(), Ptr,
- Indices, "", IsInBounds);
+ Indices, "", isInBounds());
assert((State.VF.isScalar() || NewGEP->getType()->isVectorTy()) &&
"NewGEP is not a pointer vector");
State.set(this, NewGEP, Part);
@@ -819,14 +861,14 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) {
void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "WIDEN-GEP ";
- O << (IsPtrLoopInvariant ? "Inv" : "Var");
- size_t IndicesNumber = IsIndexLoopInvariant.size();
- for (size_t I = 0; I < IndicesNumber; ++I)
- O << "[" << (IsIndexLoopInvariant[I] ? "Inv" : "Var") << "]";
+ O << (isPointerLoopInvariant() ? "Inv" : "Var");
+ for (size_t I = 0; I < getNumOperands() - 1; ++I)
+ O << "[" << (isIndexLoopInvariant(I) ? "Inv" : "Var") << "]";
O << " ";
printAsOperand(O, SlotTracker);
- O << " = getelementptr ";
+ O << " = getelementptr";
+ printFlags(O);
printOperands(O, SlotTracker);
}
#endif
@@ -911,7 +953,21 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
}
+#endif
+
+bool VPReplicateRecipe::shouldPack() const {
+ // Find if the recipe is used by a widened recipe via an intervening
+ // VPPredInstPHIRecipe. In this case, also pack the scalar values in a vector.
+ return any_of(users(), [](const VPUser *U) {
+ if (auto *PredR = dyn_cast<VPPredInstPHIRecipe>(U))
+ return any_of(PredR->users(), [PredR](const VPUser *U) {
+ return !U->usesScalars(PredR);
+ });
+ return false;
+ });
+}
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << (IsUniform ? "CLONE " : "REPLICATE ");
@@ -921,18 +977,21 @@ void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
}
if (auto *CB = dyn_cast<CallBase>(getUnderlyingInstr())) {
- O << "call @" << CB->getCalledFunction()->getName() << "(";
+ O << "call";
+ printFlags(O);
+ O << "@" << CB->getCalledFunction()->getName() << "(";
interleaveComma(make_range(op_begin(), op_begin() + (getNumOperands() - 1)),
O, [&O, &SlotTracker](VPValue *Op) {
Op->printAsOperand(O, SlotTracker);
});
O << ")";
} else {
- O << Instruction::getOpcodeName(getUnderlyingInstr()->getOpcode()) << " ";
+ O << Instruction::getOpcodeName(getUnderlyingInstr()->getOpcode());
+ printFlags(O);
printOperands(O, SlotTracker);
}
- if (AlsoPack)
+ if (shouldPack())
O << " (S->V)";
}
#endif
@@ -1053,20 +1112,22 @@ void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent,
}
#endif
-bool VPCanonicalIVPHIRecipe::isCanonical(const InductionDescriptor &ID,
- Type *Ty) const {
- if (Ty != getScalarType())
+bool VPCanonicalIVPHIRecipe::isCanonical(
+ InductionDescriptor::InductionKind Kind, VPValue *Start, VPValue *Step,
+ Type *Ty) const {
+ // The types must match and it must be an integer induction.
+ if (Ty != getScalarType() || Kind != InductionDescriptor::IK_IntInduction)
return false;
- // The start value of ID must match the start value of this canonical
- // induction.
- if (getStartValue()->getLiveInIRValue() != ID.getStartValue())
+ // Start must match the start value of this canonical induction.
+ if (Start != getStartValue())
return false;
- ConstantInt *Step = ID.getConstIntStepValue();
- // ID must also be incremented by one. IK_IntInduction always increment the
- // induction by Step, but the binary op may not be set.
- return ID.getKind() == InductionDescriptor::IK_IntInduction && Step &&
- Step->isOne();
+ // If the step is defined by a recipe, it is not a ConstantInt.
+ if (Step->getDefiningRecipe())
+ return false;
+
+ ConstantInt *StepC = dyn_cast<ConstantInt>(Step->getLiveInIRValue());
+ return StepC && StepC->isOne();
}
bool VPWidenPointerInductionRecipe::onlyScalarsGenerated(ElementCount VF) {
@@ -1092,9 +1153,11 @@ void VPExpandSCEVRecipe::execute(VPTransformState &State) {
Value *Res = Exp.expandCodeFor(Expr, Expr->getType(),
&*State.Builder.GetInsertPoint());
-
+ assert(!State.ExpandedSCEVs.contains(Expr) &&
+ "Same SCEV expanded multiple times");
+ State.ExpandedSCEVs[Expr] = Res;
for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part)
- State.set(this, Res, Part);
+ State.set(this, Res, {Part, 0});
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index cbf111b00e3d..83bfdfd09d19 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -12,6 +12,8 @@
//===----------------------------------------------------------------------===//
#include "VPlanTransforms.h"
+#include "VPlanDominatorTree.h"
+#include "VPRecipeBuilder.h"
#include "VPlanCFG.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
@@ -22,11 +24,10 @@
using namespace llvm;
void VPlanTransforms::VPInstructionsToVPRecipes(
- Loop *OrigLoop, VPlanPtr &Plan,
+ VPlanPtr &Plan,
function_ref<const InductionDescriptor *(PHINode *)>
GetIntOrFpInductionDescriptor,
- SmallPtrSetImpl<Instruction *> &DeadInstructions, ScalarEvolution &SE,
- const TargetLibraryInfo &TLI) {
+ ScalarEvolution &SE, const TargetLibraryInfo &TLI) {
ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
Plan->getEntry());
@@ -39,22 +40,15 @@ void VPlanTransforms::VPInstructionsToVPRecipes(
VPValue *VPV = Ingredient.getVPSingleValue();
Instruction *Inst = cast<Instruction>(VPV->getUnderlyingValue());
- if (DeadInstructions.count(Inst)) {
- VPValue DummyValue;
- VPV->replaceAllUsesWith(&DummyValue);
- Ingredient.eraseFromParent();
- continue;
- }
VPRecipeBase *NewRecipe = nullptr;
if (auto *VPPhi = dyn_cast<VPWidenPHIRecipe>(&Ingredient)) {
auto *Phi = cast<PHINode>(VPPhi->getUnderlyingValue());
if (const auto *II = GetIntOrFpInductionDescriptor(Phi)) {
- VPValue *Start = Plan->getOrAddVPValue(II->getStartValue());
+ VPValue *Start = Plan->getVPValueOrAddLiveIn(II->getStartValue());
VPValue *Step =
vputils::getOrCreateVPValueForSCEVExpr(*Plan, II->getStep(), SE);
- NewRecipe =
- new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, *II, true);
+ NewRecipe = new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, *II);
} else {
Plan->addVPValue(Phi, VPPhi);
continue;
@@ -66,28 +60,25 @@ void VPlanTransforms::VPInstructionsToVPRecipes(
// Create VPWidenMemoryInstructionRecipe for loads and stores.
if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) {
NewRecipe = new VPWidenMemoryInstructionRecipe(
- *Load, Plan->getOrAddVPValue(getLoadStorePointerOperand(Inst)),
- nullptr /*Mask*/, false /*Consecutive*/, false /*Reverse*/);
+ *Load, Ingredient.getOperand(0), nullptr /*Mask*/,
+ false /*Consecutive*/, false /*Reverse*/);
} else if (StoreInst *Store = dyn_cast<StoreInst>(Inst)) {
NewRecipe = new VPWidenMemoryInstructionRecipe(
- *Store, Plan->getOrAddVPValue(getLoadStorePointerOperand(Inst)),
- Plan->getOrAddVPValue(Store->getValueOperand()), nullptr /*Mask*/,
- false /*Consecutive*/, false /*Reverse*/);
+ *Store, Ingredient.getOperand(1), Ingredient.getOperand(0),
+ nullptr /*Mask*/, false /*Consecutive*/, false /*Reverse*/);
} else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) {
- NewRecipe = new VPWidenGEPRecipe(
- GEP, Plan->mapToVPValues(GEP->operands()), OrigLoop);
+ NewRecipe = new VPWidenGEPRecipe(GEP, Ingredient.operands());
} else if (CallInst *CI = dyn_cast<CallInst>(Inst)) {
NewRecipe =
- new VPWidenCallRecipe(*CI, Plan->mapToVPValues(CI->args()),
+ new VPWidenCallRecipe(*CI, drop_end(Ingredient.operands()),
getVectorIntrinsicIDForCall(CI, &TLI));
} else if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) {
- bool InvariantCond =
- SE.isLoopInvariant(SE.getSCEV(SI->getOperand(0)), OrigLoop);
- NewRecipe = new VPWidenSelectRecipe(
- *SI, Plan->mapToVPValues(SI->operands()), InvariantCond);
+ NewRecipe = new VPWidenSelectRecipe(*SI, Ingredient.operands());
+ } else if (auto *CI = dyn_cast<CastInst>(Inst)) {
+ NewRecipe = new VPWidenCastRecipe(
+ CI->getOpcode(), Ingredient.getOperand(0), CI->getType(), CI);
} else {
- NewRecipe =
- new VPWidenRecipe(*Inst, Plan->mapToVPValues(Inst->operands()));
+ NewRecipe = new VPWidenRecipe(*Inst, Ingredient.operands());
}
}
@@ -98,15 +89,11 @@ void VPlanTransforms::VPInstructionsToVPRecipes(
assert(NewRecipe->getNumDefinedValues() == 0 &&
"Only recpies with zero or one defined values expected");
Ingredient.eraseFromParent();
- Plan->removeVPValueFor(Inst);
- for (auto *Def : NewRecipe->definedValues()) {
- Plan->addVPValue(Inst, Def);
- }
}
}
}
-bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) {
+static bool sinkScalarOperands(VPlan &Plan) {
auto Iter = vp_depth_first_deep(Plan.getEntry());
bool Changed = false;
// First, collect the operands of all recipes in replicate blocks as seeds for
@@ -167,8 +154,7 @@ bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) {
continue;
Instruction *I = cast<Instruction>(
cast<VPReplicateRecipe>(SinkCandidate)->getUnderlyingValue());
- auto *Clone =
- new VPReplicateRecipe(I, SinkCandidate->operands(), true, false);
+ auto *Clone = new VPReplicateRecipe(I, SinkCandidate->operands(), true);
// TODO: add ".cloned" suffix to name of Clone's VPValue.
Clone->insertBefore(SinkCandidate);
@@ -224,7 +210,10 @@ static VPBasicBlock *getPredicatedThenBlock(VPRegionBlock *R) {
return nullptr;
}
-bool VPlanTransforms::mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
+// Merge replicate regions into their successor regions, if a replicate region
+// is connected to a successor replicate region with the same predicate by a
+// single, empty VPBasicBlock.
+static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
SetVector<VPRegionBlock *> DeletedRegions;
// Collect replicate regions followed by an empty block, followed by another
@@ -312,6 +301,81 @@ bool VPlanTransforms::mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
return !DeletedRegions.empty();
}
+static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe,
+ VPlan &Plan) {
+ Instruction *Instr = PredRecipe->getUnderlyingInstr();
+ // Build the triangular if-then region.
+ std::string RegionName = (Twine("pred.") + Instr->getOpcodeName()).str();
+ assert(Instr->getParent() && "Predicated instruction not in any basic block");
+ auto *BlockInMask = PredRecipe->getMask();
+ auto *BOMRecipe = new VPBranchOnMaskRecipe(BlockInMask);
+ auto *Entry = new VPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe);
+
+ // Replace predicated replicate recipe with a replicate recipe without a
+ // mask but in the replicate region.
+ auto *RecipeWithoutMask = new VPReplicateRecipe(
+ PredRecipe->getUnderlyingInstr(),
+ make_range(PredRecipe->op_begin(), std::prev(PredRecipe->op_end())),
+ PredRecipe->isUniform());
+ auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", RecipeWithoutMask);
+
+ VPPredInstPHIRecipe *PHIRecipe = nullptr;
+ if (PredRecipe->getNumUsers() != 0) {
+ PHIRecipe = new VPPredInstPHIRecipe(RecipeWithoutMask);
+ PredRecipe->replaceAllUsesWith(PHIRecipe);
+ PHIRecipe->setOperand(0, RecipeWithoutMask);
+ }
+ PredRecipe->eraseFromParent();
+ auto *Exiting = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe);
+ VPRegionBlock *Region = new VPRegionBlock(Entry, Exiting, RegionName, true);
+
+ // Note: first set Entry as region entry and then connect successors starting
+ // from it in order, to propagate the "parent" of each VPBasicBlock.
+ VPBlockUtils::insertTwoBlocksAfter(Pred, Exiting, Entry);
+ VPBlockUtils::connectBlocks(Pred, Exiting);
+
+ return Region;
+}
+
+static void addReplicateRegions(VPlan &Plan) {
+ SmallVector<VPReplicateRecipe *> WorkList;
+ for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+ vp_depth_first_deep(Plan.getEntry()))) {
+ for (VPRecipeBase &R : *VPBB)
+ if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R)) {
+ if (RepR->isPredicated())
+ WorkList.push_back(RepR);
+ }
+ }
+
+ unsigned BBNum = 0;
+ for (VPReplicateRecipe *RepR : WorkList) {
+ VPBasicBlock *CurrentBlock = RepR->getParent();
+ VPBasicBlock *SplitBlock = CurrentBlock->splitAt(RepR->getIterator());
+
+ BasicBlock *OrigBB = RepR->getUnderlyingInstr()->getParent();
+ SplitBlock->setName(
+ OrigBB->hasName() ? OrigBB->getName() + "." + Twine(BBNum++) : "");
+ // Record predicated instructions for above packing optimizations.
+ VPBlockBase *Region = createReplicateRegion(RepR, Plan);
+ Region->setParent(CurrentBlock->getParent());
+ VPBlockUtils::disconnectBlocks(CurrentBlock, SplitBlock);
+ VPBlockUtils::connectBlocks(CurrentBlock, Region);
+ VPBlockUtils::connectBlocks(Region, SplitBlock);
+ }
+}
+
+void VPlanTransforms::createAndOptimizeReplicateRegions(VPlan &Plan) {
+ // Convert masked VPReplicateRecipes to if-then region blocks.
+ addReplicateRegions(Plan);
+
+ bool ShouldSimplify = true;
+ while (ShouldSimplify) {
+ ShouldSimplify = sinkScalarOperands(Plan);
+ ShouldSimplify |= mergeReplicateRegionsIntoSuccessors(Plan);
+ ShouldSimplify |= VPlanTransforms::mergeBlocksIntoPredecessors(Plan);
+ }
+}
bool VPlanTransforms::mergeBlocksIntoPredecessors(VPlan &Plan) {
SmallVector<VPBasicBlock *> WorkList;
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
@@ -395,7 +459,10 @@ void VPlanTransforms::removeRedundantCanonicalIVs(VPlan &Plan) {
// everything WidenNewIV's users need. That is, WidenOriginalIV will
// generate a vector phi or all users of WidenNewIV demand the first lane
// only.
- if (WidenOriginalIV->needsVectorIV() ||
+ if (any_of(WidenOriginalIV->users(),
+ [WidenOriginalIV](VPUser *U) {
+ return !U->usesScalars(WidenOriginalIV);
+ }) ||
vputils::onlyFirstLaneUsed(WidenNewIV)) {
WidenNewIV->replaceAllUsesWith(WidenOriginalIV);
WidenNewIV->eraseFromParent();
@@ -440,10 +507,10 @@ void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) {
if (Instruction *TruncI = WideIV->getTruncInst())
ResultTy = TruncI->getType();
const InductionDescriptor &ID = WideIV->getInductionDescriptor();
- VPValue *Step =
- vputils::getOrCreateVPValueForSCEVExpr(Plan, ID.getStep(), SE);
+ VPValue *Step = WideIV->getStepValue();
VPValue *BaseIV = CanonicalIV;
- if (!CanonicalIV->isCanonical(ID, ResultTy)) {
+ if (!CanonicalIV->isCanonical(ID.getKind(), WideIV->getStartValue(), Step,
+ ResultTy)) {
BaseIV = new VPDerivedIVRecipe(ID, WideIV->getStartValue(), CanonicalIV,
Step, ResultTy);
HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP);
@@ -522,9 +589,9 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
return;
LLVMContext &Ctx = SE.getContext();
- auto *BOC =
- new VPInstruction(VPInstruction::BranchOnCond,
- {Plan.getOrAddExternalDef(ConstantInt::getTrue(Ctx))});
+ auto *BOC = new VPInstruction(
+ VPInstruction::BranchOnCond,
+ {Plan.getVPValueOrAddLiveIn(ConstantInt::getTrue(Ctx))});
Term->eraseFromParent();
ExitingVPBB->appendRecipe(BOC);
Plan.setVF(BestVF);
@@ -533,3 +600,181 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
// 1. Replace inductions with constants.
// 2. Replace vector loop region with VPBasicBlock.
}
+
+#ifndef NDEBUG
+static VPRegionBlock *GetReplicateRegion(VPRecipeBase *R) {
+ auto *Region = dyn_cast_or_null<VPRegionBlock>(R->getParent()->getParent());
+ if (Region && Region->isReplicator()) {
+ assert(Region->getNumSuccessors() == 1 &&
+ Region->getNumPredecessors() == 1 && "Expected SESE region!");
+ assert(R->getParent()->size() == 1 &&
+ "A recipe in an original replicator region must be the only "
+ "recipe in its block");
+ return Region;
+ }
+ return nullptr;
+}
+#endif
+
+static bool properlyDominates(const VPRecipeBase *A, const VPRecipeBase *B,
+ VPDominatorTree &VPDT) {
+ if (A == B)
+ return false;
+
+ auto LocalComesBefore = [](const VPRecipeBase *A, const VPRecipeBase *B) {
+ for (auto &R : *A->getParent()) {
+ if (&R == A)
+ return true;
+ if (&R == B)
+ return false;
+ }
+ llvm_unreachable("recipe not found");
+ };
+ const VPBlockBase *ParentA = A->getParent();
+ const VPBlockBase *ParentB = B->getParent();
+ if (ParentA == ParentB)
+ return LocalComesBefore(A, B);
+
+ assert(!GetReplicateRegion(const_cast<VPRecipeBase *>(A)) &&
+ "No replicate regions expected at this point");
+ assert(!GetReplicateRegion(const_cast<VPRecipeBase *>(B)) &&
+ "No replicate regions expected at this point");
+ return VPDT.properlyDominates(ParentA, ParentB);
+}
+
+/// Sink users of \p FOR after the recipe defining the previous value \p
+/// Previous of the recurrence. \returns true if all users of \p FOR could be
+/// re-arranged as needed or false if it is not possible.
+static bool
+sinkRecurrenceUsersAfterPrevious(VPFirstOrderRecurrencePHIRecipe *FOR,
+ VPRecipeBase *Previous,
+ VPDominatorTree &VPDT) {
+ // Collect recipes that need sinking.
+ SmallVector<VPRecipeBase *> WorkList;
+ SmallPtrSet<VPRecipeBase *, 8> Seen;
+ Seen.insert(Previous);
+ auto TryToPushSinkCandidate = [&](VPRecipeBase *SinkCandidate) {
+ // The previous value must not depend on the users of the recurrence phi;
+ // if it did, FOR would not be a fixed-order recurrence.
+ if (SinkCandidate == Previous)
+ return false;
+
+ if (isa<VPHeaderPHIRecipe>(SinkCandidate) ||
+ !Seen.insert(SinkCandidate).second ||
+ properlyDominates(Previous, SinkCandidate, VPDT))
+ return true;
+
+ if (SinkCandidate->mayHaveSideEffects())
+ return false;
+
+ WorkList.push_back(SinkCandidate);
+ return true;
+ };
+
+ // Recursively sink users of FOR after Previous.
+ WorkList.push_back(FOR);
+ for (unsigned I = 0; I != WorkList.size(); ++I) {
+ VPRecipeBase *Current = WorkList[I];
+ assert(Current->getNumDefinedValues() == 1 &&
+ "only recipes with a single defined value expected");
+
+ for (VPUser *User : Current->getVPSingleValue()->users()) {
+ if (auto *R = dyn_cast<VPRecipeBase>(User))
+ if (!TryToPushSinkCandidate(R))
+ return false;
+ }
+ }
+
+ // Keep recipes to sink ordered by dominance so earlier instructions are
+ // processed first.
+ sort(WorkList, [&VPDT](const VPRecipeBase *A, const VPRecipeBase *B) {
+ return properlyDominates(A, B, VPDT);
+ });
+
+ for (VPRecipeBase *SinkCandidate : WorkList) {
+ if (SinkCandidate == FOR)
+ continue;
+
+ SinkCandidate->moveAfter(Previous);
+ Previous = SinkCandidate;
+ }
+ return true;
+}
+
+bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
+ VPBuilder &Builder) {
+ VPDominatorTree VPDT;
+ VPDT.recalculate(Plan);
+
+ SmallVector<VPFirstOrderRecurrencePHIRecipe *> RecurrencePhis;
+ for (VPRecipeBase &R :
+ Plan.getVectorLoopRegion()->getEntry()->getEntryBasicBlock()->phis())
+ if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R))
+ RecurrencePhis.push_back(FOR);
+
+ for (VPFirstOrderRecurrencePHIRecipe *FOR : RecurrencePhis) {
+ SmallPtrSet<VPFirstOrderRecurrencePHIRecipe *, 4> SeenPhis;
+ VPRecipeBase *Previous = FOR->getBackedgeValue()->getDefiningRecipe();
+ // Fixed-order recurrences do not contain cycles, so this loop is guaranteed
+ // to terminate.
+ while (auto *PrevPhi =
+ dyn_cast_or_null<VPFirstOrderRecurrencePHIRecipe>(Previous)) {
+ assert(PrevPhi->getParent() == FOR->getParent());
+ assert(SeenPhis.insert(PrevPhi).second);
+ Previous = PrevPhi->getBackedgeValue()->getDefiningRecipe();
+ }
+
+ if (!sinkRecurrenceUsersAfterPrevious(FOR, Previous, VPDT))
+ return false;
+
+ // Introduce a recipe to combine the incoming and previous values of a
+ // fixed-order recurrence.
+ VPBasicBlock *InsertBlock = Previous->getParent();
+ if (isa<VPHeaderPHIRecipe>(Previous))
+ Builder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi());
+ else
+ Builder.setInsertPoint(InsertBlock, std::next(Previous->getIterator()));
+
+ auto *RecurSplice = cast<VPInstruction>(
+ Builder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice,
+ {FOR, FOR->getBackedgeValue()}));
+
+ FOR->replaceAllUsesWith(RecurSplice);
+ // Set the first operand of RecurSplice to FOR again, after replacing
+ // all users.
+ RecurSplice->setOperand(0, FOR);
+ }
+ return true;
+}
+
+void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
+ for (VPRecipeBase &R :
+ Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
+ auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+ if (!PhiR)
+ continue;
+ const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+ RecurKind RK = RdxDesc.getRecurrenceKind();
+ if (RK != RecurKind::Add && RK != RecurKind::Mul)
+ continue;
+
+ SmallSetVector<VPValue *, 8> Worklist;
+ Worklist.insert(PhiR);
+
+ for (unsigned I = 0; I != Worklist.size(); ++I) {
+ VPValue *Cur = Worklist[I];
+ if (auto *RecWithFlags =
+ dyn_cast<VPRecipeWithIRFlags>(Cur->getDefiningRecipe())) {
+ RecWithFlags->dropPoisonGeneratingFlags();
+ }
+
+ for (VPUser *U : Cur->users()) {
+ auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
+ if (!UserRecipe)
+ continue;
+ for (VPValue *V : UserRecipe->definedValues())
+ Worklist.insert(V);
+ }
+ }
+ }
+}
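
createAndOptimizeReplicateRegions above is a fixed-point driver: it keeps re-running its simplifications while any of them still reports a change. A minimal standalone sketch of that pattern follows; the Plan struct and the two stand-in transforms are hypothetical and only illustrate the loop shape, not the VPlan APIs.

#include <algorithm>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a plan holding recipes; the real code operates on
// a VPlan, which is not reproduced here.
struct Plan {
  std::vector<int> Recipes;
};

// Each simplification returns true iff it changed the plan, like
// sinkScalarOperands and mergeReplicateRegionsIntoSuccessors above.
static bool dropZeroRecipes(Plan &P) {
  auto It = std::remove(P.Recipes.begin(), P.Recipes.end(), 0);
  bool Changed = It != P.Recipes.end();
  P.Recipes.erase(It, P.Recipes.end());
  return Changed;
}

static bool mergeAdjacentDuplicates(Plan &P) {
  auto It = std::unique(P.Recipes.begin(), P.Recipes.end());
  bool Changed = It != P.Recipes.end();
  P.Recipes.erase(It, P.Recipes.end());
  return Changed;
}

// Fixed-point driver: keep re-running the transforms while any of them still
// reports a change, mirroring the ShouldSimplify loop above.
static void simplifyToFixedPoint(Plan &P) {
  bool ShouldSimplify = true;
  while (ShouldSimplify) {
    ShouldSimplify = dropZeroRecipes(P);
    ShouldSimplify |= mergeAdjacentDuplicates(P);
  }
}

int main() {
  Plan P{{1, 1, 0, 2, 0, 2, 2}};
  simplifyToFixedPoint(P);
  for (int R : P.Recipes)
    std::cout << R << ' '; // prints: 1 2
  std::cout << '\n';
}
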
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index be0d8e76d809..3eccf6e9600d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -25,23 +25,23 @@ class ScalarEvolution;
class Loop;
class PredicatedScalarEvolution;
class TargetLibraryInfo;
+class VPBuilder;
+class VPRecipeBuilder;
struct VPlanTransforms {
/// Replaces the VPInstructions in \p Plan with corresponding
/// widen recipes.
static void
- VPInstructionsToVPRecipes(Loop *OrigLoop, VPlanPtr &Plan,
+ VPInstructionsToVPRecipes(VPlanPtr &Plan,
function_ref<const InductionDescriptor *(PHINode *)>
GetIntOrFpInductionDescriptor,
- SmallPtrSetImpl<Instruction *> &DeadInstructions,
ScalarEvolution &SE, const TargetLibraryInfo &TLI);
- static bool sinkScalarOperands(VPlan &Plan);
-
- /// Merge replicate regions in their successor region, if a replicate region
- /// is connected to a successor replicate region with the same predicate by a
- /// single, empty VPBasicBlock.
- static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan);
+ /// Wrap predicated VPReplicateRecipes with a mask operand in an if-then
+ /// region block and remove the mask operand. Optimize the created regions by
+ /// iteratively sinking scalar operands into the region, followed by merging
+ /// regions until no further improvements remain.
+ static void createAndOptimizeReplicateRegions(VPlan &Plan);
/// Remove redundant VPBasicBlocks by merging them into their predecessor if
/// the predecessor has a single successor.
@@ -71,6 +71,19 @@ struct VPlanTransforms {
/// them with already existing recipes expanding the same SCEV expression.
static void removeRedundantExpandSCEVRecipes(VPlan &Plan);
+ /// Sink users of fixed-order recurrences after the recipe defining their
+ /// previous value. Then introduce FirstOrderRecurrenceSplice VPInstructions
+ /// to combine the value from the recurrence phis and previous values. The
+ /// current implementation assumes all users can be sunk after the previous
+ /// value, which is enforced by earlier legality checks.
+ /// \returns true if all users of fixed-order recurrences could be re-arranged
+ /// as needed or false if it is not possible. In the latter case, \p Plan is
+ /// not valid.
+ static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder);
+
+ /// Clear NSW/NUW flags from reduction instructions if necessary.
+ static void clearReductionWrapFlags(VPlan &Plan);
+
/// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the
/// resulting plan to \p BestVF and \p BestUF.
static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
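
For the adjustFixedOrderRecurrences entry point declared above, the object being transformed is a fixed-order recurrence: a loop-carried value whose use in one iteration reads the value produced in the previous one. A minimal scalar illustration follows; the function and array names are hypothetical and only show the shape of such a recurrence, not the vectorized form the transform produces.

#include <cassert>
#include <cstddef>
#include <vector>

// Out[I] combines In[I] with the previous iteration's In[I - 1]; the value
// carried across the backedge ("Prev") is what the recurrence phi models, and
// "Out[I] = Prev + In[I]" is a user that must end up after Prev's definition.
static std::vector<int> recurrenceExample(const std::vector<int> &In, int Init) {
  std::vector<int> Out(In.size());
  int Prev = Init;               // start value of the recurrence phi
  for (std::size_t I = 0; I < In.size(); ++I) {
    Out[I] = Prev + In[I];       // user of the previous value
    Prev = In[I];                // value fed back across the backedge
  }
  return Out;
}

int main() {
  assert(recurrenceExample({1, 2, 3, 4}, 0) == std::vector<int>({1, 3, 5, 7}));
}
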
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 62ec65cbfe5d..ac110bb3b0ef 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -171,16 +171,19 @@ public:
/// Returns true if this VPValue is defined by a recipe.
bool hasDefiningRecipe() const { return getDefiningRecipe(); }
+ /// Returns true if this VPValue is a live-in, i.e. defined outside the VPlan.
+ bool isLiveIn() const { return !hasDefiningRecipe(); }
+
/// Returns the underlying IR value, if this VPValue is defined outside the
/// scope of VPlan. Returns nullptr if the VPValue is defined by a VPDef
/// inside a VPlan.
Value *getLiveInIRValue() {
- assert(!hasDefiningRecipe() &&
+ assert(isLiveIn() &&
"VPValue is not a live-in; it is defined by a VPDef inside a VPlan");
return getUnderlyingValue();
}
const Value *getLiveInIRValue() const {
- assert(!hasDefiningRecipe() &&
+ assert(isLiveIn() &&
"VPValue is not a live-in; it is defined by a VPDef inside a VPlan");
return getUnderlyingValue();
}
@@ -342,15 +345,16 @@ public:
VPScalarIVStepsSC,
VPWidenCallSC,
VPWidenCanonicalIVSC,
+ VPWidenCastSC,
VPWidenGEPSC,
VPWidenMemoryInstructionSC,
VPWidenSC,
VPWidenSelectSC,
-
- // Phi-like recipes. Need to be kept together.
+ // START: Phi-like recipes. Need to be kept together.
VPBlendSC,
VPPredInstPHISC,
- // Header-phi recipes. Need to be kept together.
+ // START: SubclassID for recipes that inherit VPHeaderPHIRecipe.
+ // VPHeaderPHIRecipes need to be kept together.
VPCanonicalIVPHISC,
VPActiveLaneMaskPHISC,
VPFirstOrderRecurrencePHISC,
@@ -358,8 +362,11 @@ public:
VPWidenIntOrFpInductionSC,
VPWidenPointerInductionSC,
VPReductionPHISC,
+ // END: SubclassID for recipes that inherit VPHeaderPHIRecipe
+ // END: Phi-like recipes
VPFirstPHISC = VPBlendSC,
VPFirstHeaderPHISC = VPCanonicalIVPHISC,
+ VPLastHeaderPHISC = VPReductionPHISC,
VPLastPHISC = VPReductionPHISC,
};
@@ -434,6 +441,7 @@ class VPSlotTracker {
void assignSlot(const VPValue *V);
void assignSlots(const VPlan &Plan);
+ void assignSlots(const VPBasicBlock *VPBB);
public:
VPSlotTracker(const VPlan *Plan = nullptr) {
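
The SubclassID changes above keep the phi-like and header-phi IDs contiguous and bracket them with VPFirstPHISC/VPFirstHeaderPHISC/VPLastHeaderPHISC/VPLastPHISC, so membership tests reduce to range compares. A standalone sketch of that sentinel idiom, using a hypothetical enum rather than the VPDef IDs themselves:

#include <cassert>

// Hypothetical recipe IDs (not the VPDef enum itself). The phi-like entries
// are kept contiguous so sentinel aliases can bracket them, which is what the
// START/END markers and the VPFirst*/VPLast* members above maintain.
enum RecipeID {
  WidenSC,
  BlendSC,                  // first phi-like recipe
  PredInstPHISC,
  CanonicalIVPHISC,         // first header phi
  FirstOrderRecurrencePHISC,
  ReductionPHISC,           // last header phi and last phi-like recipe
  FirstPHISC = BlendSC,
  FirstHeaderPHISC = CanonicalIVPHISC,
  LastHeaderPHISC = ReductionPHISC,
  LastPHISC = ReductionPHISC,
};

static bool isPhiLike(RecipeID ID) {
  return ID >= FirstPHISC && ID <= LastPHISC;
}

static bool isHeaderPhi(RecipeID ID) {
  return ID >= FirstHeaderPHISC && ID <= LastHeaderPHISC;
}

int main() {
  assert(isPhiLike(PredInstPHISC) && !isHeaderPhi(PredInstPHISC));
  assert(isHeaderPhi(FirstOrderRecurrencePHISC));
  assert(!isPhiLike(WidenSC));
}
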
diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
index 18125cebed33..d6b81543dbc9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
@@ -15,6 +15,7 @@
#include "VPlanVerifier.h"
#include "VPlan.h"
#include "VPlanCFG.h"
+#include "VPlanDominatorTree.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/Support/CommandLine.h"
@@ -189,9 +190,8 @@ static bool verifyPhiRecipes(const VPBasicBlock *VPBB) {
return true;
}
-static bool
-verifyVPBasicBlock(const VPBasicBlock *VPBB,
- DenseMap<const VPBlockBase *, unsigned> &BlockNumbering) {
+static bool verifyVPBasicBlock(const VPBasicBlock *VPBB,
+ VPDominatorTree &VPDT) {
if (!verifyPhiRecipes(VPBB))
return false;
@@ -206,7 +206,8 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB,
for (const VPValue *V : R.definedValues()) {
for (const VPUser *U : V->users()) {
auto *UI = dyn_cast<VPRecipeBase>(U);
- if (!UI || isa<VPHeaderPHIRecipe>(UI))
+ // TODO: check dominance of incoming values for phis properly.
+ if (!UI || isa<VPHeaderPHIRecipe>(UI) || isa<VPPredInstPHIRecipe>(UI))
continue;
// If the user is in the same block, check it comes after R in the
@@ -219,27 +220,7 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB,
continue;
}
- // Skip blocks outside any region for now and blocks outside
- // replicate-regions.
- auto *ParentR = VPBB->getParent();
- if (!ParentR || !ParentR->isReplicator())
- continue;
-
- // For replicators, verify that VPPRedInstPHIRecipe defs are only used
- // in subsequent blocks.
- if (isa<VPPredInstPHIRecipe>(&R)) {
- auto I = BlockNumbering.find(UI->getParent());
- unsigned BlockNumber = I == BlockNumbering.end() ? std::numeric_limits<unsigned>::max() : I->second;
- if (BlockNumber < BlockNumbering[ParentR]) {
- errs() << "Use before def!\n";
- return false;
- }
- continue;
- }
-
- // All non-VPPredInstPHIRecipe recipes in the block must be used in
- // the replicate region only.
- if (UI->getParent()->getParent() != ParentR) {
+ if (!VPDT.dominates(VPBB, UI->getParent())) {
errs() << "Use before def!\n";
return false;
}
@@ -250,15 +231,13 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB,
}
bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) {
- DenseMap<const VPBlockBase *, unsigned> BlockNumbering;
- unsigned Cnt = 0;
+ VPDominatorTree VPDT;
+ VPDT.recalculate(const_cast<VPlan &>(Plan));
+
auto Iter = vp_depth_first_deep(Plan.getEntry());
- for (const VPBlockBase *VPB : Iter) {
- BlockNumbering[VPB] = Cnt++;
- auto *VPBB = dyn_cast<VPBasicBlock>(VPB);
- if (!VPBB)
- continue;
- if (!verifyVPBasicBlock(VPBB, BlockNumbering))
+ for (const VPBasicBlock *VPBB :
+ VPBlockUtils::blocksOnly<const VPBasicBlock>(Iter)) {
+ if (!verifyVPBasicBlock(VPBB, VPDT))
return false;
}
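
The verifier now answers its use-before-def question with a dominator tree instead of a depth-first block numbering. As a reminder of what that query computes, here is a minimal standalone dominance check over an explicit immediate-dominator chain; the Block type is a hypothetical stand-in, not the VPlan block classes.

#include <cassert>

// Each block records its immediate dominator; the entry block has none.
struct Block {
  const Block *IDom = nullptr;
};

// A properly dominates B iff A occurs on B's immediate-dominator chain.
static bool properlyDominates(const Block *A, const Block *B) {
  for (const Block *X = B->IDom; X; X = X->IDom)
    if (X == A)
      return true;
  return false;
}

int main() {
  Block Entry, Then, Else, Merge;
  Then.IDom = &Entry;
  Else.IDom = &Entry;
  Merge.IDom = &Entry; // neither branch of the diamond dominates the merge
  assert(properlyDominates(&Entry, &Merge));
  assert(!properlyDominates(&Then, &Merge));
  assert(!properlyDominates(&Entry, &Entry)); // "properly" excludes A == A
}
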
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 2e489757ebc1..13464c9d3496 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -25,11 +25,8 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Vectorize.h"
#include <numeric>
#define DEBUG_TYPE "vector-combine"
@@ -247,7 +244,7 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
// still need a shuffle to change the vector size.
auto *Ty = cast<FixedVectorType>(I.getType());
unsigned OutputNumElts = Ty->getNumElements();
- SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem);
+ SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
Mask[0] = OffsetEltIndex;
if (OffsetEltIndex)
@@ -460,9 +457,9 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
// If we are extracting from 2 different indexes, then one operand must be
// shuffled before performing the vector operation. The shuffle mask is
- // undefined except for 1 lane that is being translated to the remaining
+ // poison except for 1 lane that is being translated to the remaining
// extraction lane. Therefore, it is a splat shuffle. Ex:
- // ShufMask = { undef, undef, 0, undef }
+ // ShufMask = { poison, poison, 0, poison }
// TODO: The cost model has an option for a "broadcast" shuffle
// (splat-from-element-0), but no option for a more general splat.
NewCost +=
@@ -479,11 +476,11 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
unsigned NewIndex, IRBuilder<> &Builder) {
- // The shuffle mask is undefined except for 1 lane that is being translated
+ // The shuffle mask is poison except for 1 lane that is being translated
// to the new element index. Example for OldIndex == 2 and NewIndex == 0:
- // ShufMask = { 2, undef, undef, undef }
+ // ShufMask = { 2, poison, poison, poison }
auto *VecTy = cast<FixedVectorType>(Vec->getType());
- SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
+ SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
ShufMask[NewIndex] = OldIndex;
return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}
@@ -917,7 +914,7 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
InstructionCost NewCost = TTI.getCmpSelInstrCost(
CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
- SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
+ SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
ShufMask[CheapIndex] = ExpensiveIndex;
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
ShufMask);
@@ -932,7 +929,7 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
// Create a vector constant from the 2 scalar constants.
SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
- UndefValue::get(VecTy->getElementType()));
+ PoisonValue::get(VecTy->getElementType()));
CmpC[Index0] = C0;
CmpC[Index1] = C1;
Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
@@ -1565,7 +1562,7 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
// Calculate our ReconstructMasks from the OrigReconstructMasks and the
// modified order of the input shuffles.
SmallVector<SmallVector<int>> ReconstructMasks;
- for (auto Mask : OrigReconstructMasks) {
+ for (const auto &Mask : OrigReconstructMasks) {
SmallVector<int> ReconstructMask;
for (int M : Mask) {
auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
@@ -1596,12 +1593,12 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
}
while (V1A.size() < NumElts) {
- V1A.push_back(UndefMaskElem);
- V1B.push_back(UndefMaskElem);
+ V1A.push_back(PoisonMaskElem);
+ V1B.push_back(PoisonMaskElem);
}
while (V2A.size() < NumElts) {
- V2A.push_back(UndefMaskElem);
- V2B.push_back(UndefMaskElem);
+ V2A.push_back(PoisonMaskElem);
+ V2B.push_back(PoisonMaskElem);
}
auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
@@ -1660,16 +1657,16 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
return SSV->getOperand(Op);
return SV->getOperand(Op);
};
- Builder.SetInsertPoint(SVI0A->getNextNode());
+ Builder.SetInsertPoint(SVI0A->getInsertionPointAfterDef());
Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
GetShuffleOperand(SVI0A, 1), V1A);
- Builder.SetInsertPoint(SVI0B->getNextNode());
+ Builder.SetInsertPoint(SVI0B->getInsertionPointAfterDef());
Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
GetShuffleOperand(SVI0B, 1), V1B);
- Builder.SetInsertPoint(SVI1A->getNextNode());
+ Builder.SetInsertPoint(SVI1A->getInsertionPointAfterDef());
Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
GetShuffleOperand(SVI1A, 1), V2A);
- Builder.SetInsertPoint(SVI1B->getNextNode());
+ Builder.SetInsertPoint(SVI1B->getInsertionPointAfterDef());
Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
GetShuffleOperand(SVI1B, 1), V2B);
Builder.SetInsertPoint(Op0);
@@ -1811,54 +1808,6 @@ bool VectorCombine::run() {
return MadeChange;
}
-// Pass manager boilerplate below here.
-
-namespace {
-class VectorCombineLegacyPass : public FunctionPass {
-public:
- static char ID;
- VectorCombineLegacyPass() : FunctionPass(ID) {
- initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.setPreservesCFG();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addPreserved<AAResultsWrapperPass>();
- AU.addPreserved<BasicAAWrapperPass>();
- FunctionPass::getAnalysisUsage(AU);
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- VectorCombine Combiner(F, TTI, DT, AA, AC, false);
- return Combiner.run();
- }
-};
-} // namespace
-
-char VectorCombineLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
- "Optimize scalar/vector ops", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
- "Optimize scalar/vector ops", false, false)
-Pass *llvm::createVectorCombinePass() {
- return new VectorCombineLegacyPass();
-}
-
PreservedAnalyses VectorCombinePass::run(Function &F,
FunctionAnalysisManager &FAM) {
auto &AC = FAM.getResult<AssumptionAnalysis>(F);
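
The VectorCombine hunks above swap UndefMaskElem/UndefValue for PoisonMaskElem/PoisonValue, but the mask-building pattern is unchanged: fill every lane with the don't-care sentinel, then overwrite the one lane being translated (ShufMask = { 2, poison, poison, poison } in the createShiftShuffle comment). A standalone sketch of that construction, with plain ints and -1 standing in for a poison lane:

#include <cassert>
#include <vector>

// Build a shift/splat mask: every lane is a "don't care" lane (-1 here, the
// role PoisonMaskElem plays above) except NewIndex, which reads lane OldIndex
// of the source vector.
static std::vector<int> makeShiftMask(unsigned NumElts, unsigned OldIndex,
                                      unsigned NewIndex) {
  std::vector<int> Mask(NumElts, -1);
  Mask[NewIndex] = static_cast<int>(OldIndex);
  return Mask;
}

int main() {
  // Matches the createShiftShuffle comment for OldIndex == 2, NewIndex == 0:
  // ShufMask = { 2, poison, poison, poison }.
  assert(makeShiftMask(4, 2, 0) == std::vector<int>({2, -1, -1, -1}));
}
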
diff --git a/llvm/lib/Transforms/Vectorize/Vectorize.cpp b/llvm/lib/Transforms/Vectorize/Vectorize.cpp
index 208e5eeea864..2f5048d2a664 100644
--- a/llvm/lib/Transforms/Vectorize/Vectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/Vectorize.cpp
@@ -12,10 +12,6 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/Vectorize.h"
-#include "llvm-c/Initialization.h"
-#include "llvm-c/Transforms/Vectorize.h"
-#include "llvm/IR/LegacyPassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/PassRegistry.h"
@@ -23,20 +19,5 @@ using namespace llvm;
/// Initialize all passes linked into the Vectorization library.
void llvm::initializeVectorization(PassRegistry &Registry) {
- initializeLoopVectorizePass(Registry);
- initializeSLPVectorizerPass(Registry);
initializeLoadStoreVectorizerLegacyPassPass(Registry);
- initializeVectorCombineLegacyPassPass(Registry);
-}
-
-void LLVMInitializeVectorization(LLVMPassRegistryRef R) {
- initializeVectorization(*unwrap(R));
-}
-
-void LLVMAddLoopVectorizePass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createLoopVectorizePass());
-}
-
-void LLVMAddSLPVectorizePass(LLVMPassManagerRef PM) {
- unwrap(PM)->add(createSLPVectorizerPass());
}