aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2023-12-09 13:28:42 +0000
committerDimitry Andric <dim@FreeBSD.org>2023-12-09 13:28:42 +0000
commitb1c73532ee8997fe5dfbeb7d223027bdf99758a0 (patch)
tree7d6e51c294ab6719475d660217aa0c0ad0526292 /llvm/lib/Transforms
parent7fa27ce4a07f19b07799a767fc29416f3b625afb (diff)
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp242
-rw-r--r--llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp2
-rw-r--r--llvm/lib/Transforms/CFGuard/CFGuard.cpp13
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroCleanup.cpp8
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroElide.cpp85
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroFrame.cpp148
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroInstr.h48
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroInternal.h5
-rw-r--r--llvm/lib/Transforms/Coroutines/CoroSplit.cpp114
-rw-r--r--llvm/lib/Transforms/Coroutines/Coroutines.cpp22
-rw-r--r--llvm/lib/Transforms/HipStdPar/HipStdPar.cpp312
-rw-r--r--llvm/lib/Transforms/IPO/ArgumentPromotion.cpp23
-rw-r--r--llvm/lib/Transforms/IPO/Attributor.cpp396
-rw-r--r--llvm/lib/Transforms/IPO/AttributorAttributes.cpp1675
-rw-r--r--llvm/lib/Transforms/IPO/CrossDSOCFI.cpp8
-rw-r--r--llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp4
-rw-r--r--llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp16
-rw-r--r--llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp105
-rw-r--r--llvm/lib/Transforms/IPO/FunctionAttrs.cpp154
-rw-r--r--llvm/lib/Transforms/IPO/FunctionImport.cpp132
-rw-r--r--llvm/lib/Transforms/IPO/FunctionSpecialization.cpp423
-rw-r--r--llvm/lib/Transforms/IPO/GlobalOpt.cpp90
-rw-r--r--llvm/lib/Transforms/IPO/HotColdSplitting.cpp76
-rw-r--r--llvm/lib/Transforms/IPO/IROutliner.cpp10
-rw-r--r--llvm/lib/Transforms/IPO/Inliner.cpp2
-rw-r--r--llvm/lib/Transforms/IPO/LowerTypeTests.cpp112
-rw-r--r--llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp46
-rw-r--r--llvm/lib/Transforms/IPO/MergeFunctions.cpp62
-rw-r--r--llvm/lib/Transforms/IPO/OpenMPOpt.cpp924
-rw-r--r--llvm/lib/Transforms/IPO/PartialInlining.cpp15
-rw-r--r--llvm/lib/Transforms/IPO/SCCP.cpp19
-rw-r--r--llvm/lib/Transforms/IPO/SampleContextTracker.cpp72
-rw-r--r--llvm/lib/Transforms/IPO/SampleProfile.cpp572
-rw-r--r--llvm/lib/Transforms/IPO/SampleProfileProbe.cpp23
-rw-r--r--llvm/lib/Transforms/IPO/StripSymbols.cpp17
-rw-r--r--llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp2
-rw-r--r--llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp6
-rw-r--r--llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp129
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp274
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp593
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp236
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp140
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp1087
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineInternal.h90
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp189
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp454
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp77
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp133
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp354
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp163
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp246
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp26
-rw-r--r--llvm/lib/Transforms/InstCombine/InstructionCombining.cpp619
-rw-r--r--llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp212
-rw-r--r--llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp25
-rw-r--r--llvm/lib/Transforms/Instrumentation/CGProfile.cpp2
-rw-r--r--llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp14
-rw-r--r--llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp110
-rw-r--r--llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp57
-rw-r--r--llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp326
-rw-r--r--llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp6
-rw-r--r--llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp432
-rw-r--r--llvm/lib/Transforms/Instrumentation/Instrumentation.cpp7
-rw-r--r--llvm/lib/Transforms/Instrumentation/MemProfiler.cpp85
-rw-r--r--llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp625
-rw-r--r--llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp82
-rw-r--r--llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp2
-rw-r--r--llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp25
-rw-r--r--llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp99
-rw-r--r--llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp86
-rw-r--r--llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h1
-rw-r--r--llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp6
-rw-r--r--llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/ADCE.cpp10
-rw-r--r--llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp25
-rw-r--r--llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp12
-rw-r--r--llvm/lib/Transforms/Scalar/ConstantHoisting.cpp4
-rw-r--r--llvm/lib/Transforms/Scalar/ConstraintElimination.cpp502
-rw-r--r--llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp128
-rw-r--r--llvm/lib/Transforms/Scalar/DCE.cpp33
-rw-r--r--llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp30
-rw-r--r--llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp76
-rw-r--r--llvm/lib/Transforms/Scalar/EarlyCSE.cpp189
-rw-r--r--llvm/lib/Transforms/Scalar/GVN.cpp78
-rw-r--r--llvm/lib/Transforms/Scalar/GVNSink.cpp5
-rw-r--r--llvm/lib/Transforms/Scalar/GuardWidening.cpp397
-rw-r--r--llvm/lib/Transforms/Scalar/IndVarSimplify.cpp20
-rw-r--r--llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp1138
-rw-r--r--llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp83
-rw-r--r--llvm/lib/Transforms/Scalar/InferAlignment.cpp91
-rw-r--r--llvm/lib/Transforms/Scalar/JumpThreading.cpp144
-rw-r--r--llvm/lib/Transforms/Scalar/LICM.cpp156
-rw-r--r--llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp3
-rw-r--r--llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/LoopDistribute.cpp6
-rw-r--r--llvm/lib/Transforms/Scalar/LoopFlatten.cpp5
-rw-r--r--llvm/lib/Transforms/Scalar/LoopFuse.cpp25
-rw-r--r--llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp40
-rw-r--r--llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp57
-rw-r--r--llvm/lib/Transforms/Scalar/LoopInterchange.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp7
-rw-r--r--llvm/lib/Transforms/Scalar/LoopPassManager.cpp10
-rw-r--r--llvm/lib/Transforms/Scalar/LoopPredication.cpp173
-rw-r--r--llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp51
-rw-r--r--llvm/lib/Transforms/Scalar/LoopSink.cpp67
-rw-r--r--llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp59
-rw-r--r--llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp45
-rw-r--r--llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp93
-rw-r--r--llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp36
-rw-r--r--llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp28
-rw-r--r--llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp46
-rw-r--r--llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp27
-rw-r--r--llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp20
-rw-r--r--llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp282
-rw-r--r--llvm/lib/Transforms/Scalar/MergeICmps.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp53
-rw-r--r--llvm/lib/Transforms/Scalar/NaryReassociate.cpp8
-rw-r--r--llvm/lib/Transforms/Scalar/NewGVN.cpp198
-rw-r--r--llvm/lib/Transforms/Scalar/Reassociate.cpp54
-rw-r--r--llvm/lib/Transforms/Scalar/Reg2Mem.cpp35
-rw-r--r--llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp133
-rw-r--r--llvm/lib/Transforms/Scalar/SCCP.cpp4
-rw-r--r--llvm/lib/Transforms/Scalar/SROA.cpp354
-rw-r--r--llvm/lib/Transforms/Scalar/Scalar.cpp15
-rw-r--r--llvm/lib/Transforms/Scalar/Scalarizer.cpp63
-rw-r--r--llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp23
-rw-r--r--llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp306
-rw-r--r--llvm/lib/Transforms/Scalar/Sink.cpp12
-rw-r--r--llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp12
-rw-r--r--llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/StructurizeCFG.cpp5
-rw-r--r--llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp1
-rw-r--r--llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp35
-rw-r--r--llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp12
-rw-r--r--llvm/lib/Transforms/Utils/AddDiscriminators.cpp3
-rw-r--r--llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp35
-rw-r--r--llvm/lib/Transforms/Utils/BasicBlockUtils.cpp143
-rw-r--r--llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp10
-rw-r--r--llvm/lib/Transforms/Utils/BuildLibCalls.cpp171
-rw-r--r--llvm/lib/Transforms/Utils/CallPromotionUtils.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp1
-rw-r--r--llvm/lib/Transforms/Utils/CloneFunction.cpp66
-rw-r--r--llvm/lib/Transforms/Utils/CloneModule.cpp6
-rw-r--r--llvm/lib/Transforms/Utils/CodeExtractor.cpp122
-rw-r--r--llvm/lib/Transforms/Utils/CodeLayout.cpp878
-rw-r--r--llvm/lib/Transforms/Utils/CodeMoverUtils.cpp4
-rw-r--r--llvm/lib/Transforms/Utils/CtorUtils.cpp9
-rw-r--r--llvm/lib/Transforms/Utils/DXILUpgrade.cpp36
-rw-r--r--llvm/lib/Transforms/Utils/Debugify.cpp73
-rw-r--r--llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp8
-rw-r--r--llvm/lib/Transforms/Utils/EscapeEnumerator.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/FixIrreducible.cpp5
-rw-r--r--llvm/lib/Transforms/Utils/FunctionComparator.cpp94
-rw-r--r--llvm/lib/Transforms/Utils/InjectTLIMappings.cpp12
-rw-r--r--llvm/lib/Transforms/Utils/InlineFunction.cpp303
-rw-r--r--llvm/lib/Transforms/Utils/LCSSA.cpp21
-rw-r--r--llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/Local.cpp743
-rw-r--r--llvm/lib/Transforms/Utils/LoopConstrainer.cpp904
-rw-r--r--llvm/lib/Transforms/Utils/LoopPeel.cpp47
-rw-r--r--llvm/lib/Transforms/Utils/LoopRotationUtils.cpp229
-rw-r--r--llvm/lib/Transforms/Utils/LoopSimplify.cpp4
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnroll.cpp3
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp1
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp125
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp210
-rw-r--r--llvm/lib/Transforms/Utils/LoopVersioning.cpp4
-rw-r--r--llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp13
-rw-r--r--llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp52
-rw-r--r--llvm/lib/Transforms/Utils/MetaRenamer.cpp8
-rw-r--r--llvm/lib/Transforms/Utils/ModuleUtils.cpp17
-rw-r--r--llvm/lib/Transforms/Utils/MoveAutoInit.cpp5
-rw-r--r--llvm/lib/Transforms/Utils/PredicateInfo.cpp33
-rw-r--r--llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp92
-rw-r--r--llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp7
-rw-r--r--llvm/lib/Transforms/Utils/SCCPSolver.cpp78
-rw-r--r--llvm/lib/Transforms/Utils/SSAUpdater.cpp37
-rw-r--r--llvm/lib/Transforms/Utils/SampleProfileInference.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/SanitizerStats.cpp29
-rw-r--r--llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp390
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyCFG.cpp627
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyIndVar.cpp187
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp55
-rw-r--r--llvm/lib/Transforms/Utils/StripGCRelocates.cpp20
-rw-r--r--llvm/lib/Transforms/Utils/SymbolRewriter.cpp2
-rw-r--r--llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp34
-rw-r--r--llvm/lib/Transforms/Utils/UnifyLoopExits.cpp5
-rw-r--r--llvm/lib/Transforms/Utils/Utils.cpp4
-rw-r--r--llvm/lib/Transforms/Utils/ValueMapper.cpp46
-rw-r--r--llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp10
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp38
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h62
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorize.cpp2030
-rw-r--r--llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp4216
-rw-r--r--llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h7
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.cpp217
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.h585
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp237
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanAnalysis.h64
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp257
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp571
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp479
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.h62
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanValue.h17
-rw-r--r--llvm/lib/Transforms/Vectorize/VectorCombine.cpp270
206 files changed, 20790 insertions, 12882 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 34c8a380448e..d09ac1c099c1 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -19,7 +19,6 @@
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
-#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -29,7 +28,6 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -373,7 +371,7 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
InstructionCost SatCost = TTI.getIntrinsicInstrCost(
IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
TTI::TCK_RecipThroughput);
- SatCost += TTI.getCastInstrCost(Instruction::SExt, SatTy, IntTy,
+ SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy,
TTI::CastContextHint::None,
TTI::TCK_RecipThroughput);
@@ -398,6 +396,54 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
return true;
}
+/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
+/// pessimistic codegen that has to account for setting errno and can enable
+/// vectorization.
+static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI,
+ TargetLibraryInfo &TLI, AssumptionCache &AC,
+ DominatorTree &DT) {
+ // Match a call to sqrt mathlib function.
+ auto *Call = dyn_cast<CallInst>(&I);
+ if (!Call)
+ return false;
+
+ Module *M = Call->getModule();
+ LibFunc Func;
+ if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
+ return false;
+
+ if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
+ return false;
+
+  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
+  // (because NNAN or the operand arg must not be less than -0.0) and (3) we
+  // would not end up lowering to a libcall anyway (which could change the value
+  // of errno), then:
+ // (1) errno won't be set.
+ // (2) it is safe to convert this to an intrinsic call.
+ Type *Ty = Call->getType();
+ Value *Arg = Call->getArgOperand(0);
+ if (TTI.haveFastSqrt(Ty) &&
+ (Call->hasNoNaNs() ||
+ cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, &I,
+ &DT))) {
+ IRBuilder<> Builder(&I);
+ IRBuilderBase::FastMathFlagGuard Guard(Builder);
+ Builder.setFastMathFlags(Call->getFastMathFlags());
+
+ Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
+ Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
+ I.replaceAllUsesWith(NewSqrt);
+
+ // Explicitly erase the old call because a call with side effects is not
+ // trivially dead.
+ I.eraseFromParent();
+ return true;
+ }
+
+ return false;
+}
+
// Check if this array of constants represents a cttz table.
// Iterate over the elements from \p Table by trying to find/match all
// the numbers from 0 to \p InputBits that should represent cttz results.
@@ -447,7 +493,8 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
// %shr = lshr i32 %mul, 27
// %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
-// i64 %idxprom %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
+// i64 %idxprom
+// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// CASE 2:
// %sub = sub i32 0, %x
@@ -455,8 +502,9 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
// %mul = mul i32 %and, 72416175
// %shr = lshr i32 %mul, 26
// %idxprom = zext i32 %shr to i64
-// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table, i64
-// 0, i64 %idxprom %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
+// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
+// i64 0, i64 %idxprom
+// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
//
// CASE 3:
// %sub = sub i32 0, %x
@@ -464,16 +512,18 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
// %mul = mul i32 %and, 81224991
// %shr = lshr i32 %mul, 27
// %idxprom = zext i32 %shr to i64
-// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table, i64
-// 0, i64 %idxprom %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
+// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
+// i64 0, i64 %idxprom
+// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
//
// CASE 4:
// %sub = sub i64 0, %x
// %and = and i64 %sub, %x
// %mul = mul i64 %and, 283881067100198605
// %shr = lshr i64 %mul, 58
-// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0, i64
-// %shr %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
+// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
+// i64 %shr
+// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// All this can be lowered to @llvm.cttz.i32/64 intrinsic.
static bool tryToRecognizeTableBasedCttz(Instruction &I) {
@@ -656,7 +706,10 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
make_range(Start->getIterator(), End->getIterator())) {
if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
return false;
- if (++NumScanned > MaxInstrsToScan)
+
+    // Ignore debug info so that it is not counted against MaxInstrsToScan.
+    // Otherwise debug info could affect codegen.
+ if (!isa<DbgInfoIntrinsic>(Inst) && ++NumScanned > MaxInstrsToScan)
return false;
}
@@ -869,159 +922,13 @@ static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
return true;
}
-/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
-/// pessimistic codegen that has to account for setting errno and can enable
-/// vectorization.
-static bool foldSqrt(CallInst *Call, TargetTransformInfo &TTI,
- TargetLibraryInfo &TLI, AssumptionCache &AC,
- DominatorTree &DT) {
- Module *M = Call->getModule();
-
- // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
- // (because NNAN or the operand arg must not be less than -0.0) and (2) we
- // would not end up lowering to a libcall anyway (which could change the value
- // of errno), then:
- // (1) errno won't be set.
- // (2) it is safe to convert this to an intrinsic call.
- Type *Ty = Call->getType();
- Value *Arg = Call->getArgOperand(0);
- if (TTI.haveFastSqrt(Ty) &&
- (Call->hasNoNaNs() ||
- cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, Call,
- &DT))) {
- IRBuilder<> Builder(Call);
- IRBuilderBase::FastMathFlagGuard Guard(Builder);
- Builder.setFastMathFlags(Call->getFastMathFlags());
-
- Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
- Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
- Call->replaceAllUsesWith(NewSqrt);
-
- // Explicitly erase the old call because a call with side effects is not
- // trivially dead.
- Call->eraseFromParent();
- return true;
- }
-
- return false;
-}
-
-/// Try to expand strcmp(P, "x") calls.
-static bool expandStrcmp(CallInst *CI, DominatorTree &DT, bool &MadeCFGChange) {
- Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1);
-
- // Trivial cases are optimized during inst combine
- if (Str1P == Str2P)
- return false;
-
- StringRef Str1, Str2;
- bool HasStr1 = getConstantStringInfo(Str1P, Str1);
- bool HasStr2 = getConstantStringInfo(Str2P, Str2);
-
- Value *NonConstantP = nullptr;
- StringRef ConstantStr;
-
- if (!HasStr1 && HasStr2 && Str2.size() == 1) {
- NonConstantP = Str1P;
- ConstantStr = Str2;
- } else if (!HasStr2 && HasStr1 && Str1.size() == 1) {
- NonConstantP = Str2P;
- ConstantStr = Str1;
- } else {
- return false;
- }
-
- // Check if strcmp result is only used in a comparison with zero
- if (!isOnlyUsedInZeroComparison(CI))
- return false;
-
- // For strcmp(P, "x") do the following transformation:
- //
- // (before)
- // dst = strcmp(P, "x")
- //
- // (after)
- // v0 = P[0] - 'x'
- // [if v0 == 0]
- // v1 = P[1]
- // dst = phi(v0, v1)
- //
-
- IRBuilder<> B(CI->getParent());
- DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
-
- Type *RetType = CI->getType();
-
- B.SetInsertPoint(CI);
- BasicBlock *InitialBB = B.GetInsertBlock();
- Value *Str1FirstCharacterValue =
- B.CreateZExt(B.CreateLoad(B.getInt8Ty(), NonConstantP), RetType);
- Value *Str2FirstCharacterValue =
- ConstantInt::get(RetType, static_cast<unsigned char>(ConstantStr[0]));
- Value *FirstCharacterSub =
- B.CreateNSWSub(Str1FirstCharacterValue, Str2FirstCharacterValue);
- Value *IsFirstCharacterSubZero =
- B.CreateICmpEQ(FirstCharacterSub, ConstantInt::get(RetType, 0));
- Instruction *IsFirstCharacterSubZeroBBTerminator = SplitBlockAndInsertIfThen(
- IsFirstCharacterSubZero, CI, /*Unreachable*/ false,
- /*BranchWeights*/ nullptr, &DTU);
-
- B.SetInsertPoint(IsFirstCharacterSubZeroBBTerminator);
- B.GetInsertBlock()->setName("strcmp_expand_sub_is_zero");
- BasicBlock *IsFirstCharacterSubZeroBB = B.GetInsertBlock();
- Value *Str1SecondCharacterValue = B.CreateZExt(
- B.CreateLoad(B.getInt8Ty(), B.CreateConstInBoundsGEP1_64(
- B.getInt8Ty(), NonConstantP, 1)),
- RetType);
-
- B.SetInsertPoint(CI);
- B.GetInsertBlock()->setName("strcmp_expand_sub_join");
-
- PHINode *Result = B.CreatePHI(RetType, 2);
- Result->addIncoming(FirstCharacterSub, InitialBB);
- Result->addIncoming(Str1SecondCharacterValue, IsFirstCharacterSubZeroBB);
-
- CI->replaceAllUsesWith(Result);
- CI->eraseFromParent();
-
- MadeCFGChange = true;
-
- return true;
-}
-
-static bool foldLibraryCalls(Instruction &I, TargetTransformInfo &TTI,
- TargetLibraryInfo &TLI, DominatorTree &DT,
- AssumptionCache &AC, bool &MadeCFGChange) {
- CallInst *CI = dyn_cast<CallInst>(&I);
- if (!CI)
- return false;
-
- LibFunc Func;
- Module *M = I.getModule();
- if (!TLI.getLibFunc(*CI, Func) || !isLibFuncEmittable(M, &TLI, Func))
- return false;
-
- switch (Func) {
- case LibFunc_sqrt:
- case LibFunc_sqrtf:
- case LibFunc_sqrtl:
- return foldSqrt(CI, TTI, TLI, AC, DT);
- case LibFunc_strcmp:
- return expandStrcmp(CI, DT, MadeCFGChange);
- default:
- break;
- }
-
- return false;
-}
-
/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
TargetTransformInfo &TTI,
TargetLibraryInfo &TLI, AliasAnalysis &AA,
- AssumptionCache &AC, bool &MadeCFGChange) {
+ AssumptionCache &AC) {
bool MadeChange = false;
for (BasicBlock &BB : F) {
// Ignore unreachable basic blocks.
@@ -1046,7 +953,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
// NOTE: This function introduces erasing of the instruction `I`, so it
// needs to be called at the end of this sequence, otherwise we may make
// bugs.
- MadeChange |= foldLibraryCalls(I, TTI, TLI, DT, AC, MadeCFGChange);
+ MadeChange |= foldSqrt(I, TTI, TLI, AC, DT);
}
}
@@ -1062,12 +969,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
/// handled in the callers of this function.
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
TargetLibraryInfo &TLI, DominatorTree &DT,
- AliasAnalysis &AA, bool &ChangedCFG) {
+ AliasAnalysis &AA) {
bool MadeChange = false;
const DataLayout &DL = F.getParent()->getDataLayout();
TruncInstCombine TIC(AC, TLI, DL, DT);
MadeChange |= TIC.run(F);
- MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, ChangedCFG);
+ MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC);
return MadeChange;
}
@@ -1078,21 +985,12 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
auto &AA = AM.getResult<AAManager>(F);
-
- bool MadeCFGChange = false;
-
- if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
+ if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
// No changes, all analyses are preserved.
return PreservedAnalyses::all();
}
-
// Mark all the analyses that instcombine updates as preserved.
PreservedAnalyses PA;
-
- if (MadeCFGChange)
- PA.preserve<DominatorTreeAnalysis>();
- else
- PA.preserveSet<CFGAnalyses>();
-
+ PA.preserveSet<CFGAnalyses>();
return PA;
}
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
index 6c62e84077ac..4d9050be5c55 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -366,7 +366,7 @@ static Type *getReducedType(Value *V, Type *Ty) {
Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
Type *Ty = getReducedType(V, SclTy);
if (auto *C = dyn_cast<Constant>(V)) {
- C = ConstantExpr::getIntegerCast(C, Ty, false);
+ C = ConstantExpr::getTrunc(C, Ty);
// If we got a constantexpr back, try to simplify it with DL info.
return ConstantFoldConstant(C, DL, &TLI);
}
diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
index bf823ac55497..387734358775 100644
--- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp
+++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
@@ -177,8 +177,7 @@ void CFGuard::insertCFGuardCheck(CallBase *CB) {
// Create new call instruction. The CFGuard check should always be a call,
// even if the original CallBase is an Invoke or CallBr instruction.
CallInst *GuardCheck =
- B.CreateCall(GuardFnType, GuardCheckLoad,
- {B.CreateBitCast(CalledOperand, B.getInt8PtrTy())}, Bundles);
+ B.CreateCall(GuardFnType, GuardCheckLoad, {CalledOperand}, Bundles);
// Ensure that the first argument is passed in the correct register
// (e.g. ECX on 32-bit X86 targets).
@@ -196,11 +195,6 @@ void CFGuard::insertCFGuardDispatch(CallBase *CB) {
Value *CalledOperand = CB->getCalledOperand();
Type *CalledOperandType = CalledOperand->getType();
- // Cast the guard dispatch global to the type of the called operand.
- PointerType *PTy = PointerType::get(CalledOperandType, 0);
- if (GuardFnGlobal->getType() != PTy)
- GuardFnGlobal = ConstantExpr::getBitCast(GuardFnGlobal, PTy);
-
// Load the global as a pointer to a function of the same type.
LoadInst *GuardDispatchLoad = B.CreateLoad(CalledOperandType, GuardFnGlobal);
@@ -236,8 +230,9 @@ bool CFGuard::doInitialization(Module &M) {
return false;
// Set up prototypes for the guard check and dispatch functions.
- GuardFnType = FunctionType::get(Type::getVoidTy(M.getContext()),
- {Type::getInt8PtrTy(M.getContext())}, false);
+ GuardFnType =
+ FunctionType::get(Type::getVoidTy(M.getContext()),
+ {PointerType::getUnqual(M.getContext())}, false);
GuardFnPtrType = PointerType::get(GuardFnType, 0);
// Get or insert the guard check or dispatch global symbols.
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index 29978bef661c..3e3825fcd50e 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -29,15 +29,13 @@ struct Lowerer : coro::LowererBase {
static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) {
Builder.SetInsertPoint(SubFn);
- Value *FrameRaw = SubFn->getFrame();
+ Value *FramePtr = SubFn->getFrame();
int Index = SubFn->getIndex();
- auto *FrameTy = StructType::get(
- SubFn->getContext(), {Builder.getInt8PtrTy(), Builder.getInt8PtrTy()});
- PointerType *FramePtrTy = FrameTy->getPointerTo();
+ auto *FrameTy = StructType::get(SubFn->getContext(),
+ {Builder.getPtrTy(), Builder.getPtrTy()});
Builder.SetInsertPoint(SubFn);
- auto *FramePtr = Builder.CreateBitCast(FrameRaw, FramePtrTy);
auto *Gep = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index);
auto *Load = Builder.CreateLoad(FrameTy->getElementType(Index), Gep);
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index d78ab1c1ea28..2f4083028ae0 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -165,7 +165,7 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
Frame->setAlignment(FrameAlign);
auto *FrameVoidPtr =
- new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);
+ new BitCastInst(Frame, PointerType::getUnqual(C), "vFrame", InsertPt);
for (auto *CB : CoroBegins) {
CB->replaceAllUsesWith(FrameVoidPtr);
@@ -194,12 +194,49 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
for (auto *DA : It->second)
Visited.insert(DA->getParent());
+ SmallPtrSet<const BasicBlock *, 32> EscapingBBs;
+ for (auto *U : CB->users()) {
+ // The use from coroutine intrinsics are not a problem.
+ if (isa<CoroFreeInst, CoroSubFnInst, CoroSaveInst>(U))
+ continue;
+
+ // Think all other usages may be an escaping candidate conservatively.
+ //
+ // Note that the major user of switch ABI coroutine (the C++) will store
+ // resume.fn, destroy.fn and the index to the coroutine frame immediately.
+ // So the parent of the coro.begin in C++ will be always escaping.
+ // Then we can't get any performance benefits for C++ by improving the
+ // precision of the method.
+ //
+ // The reason why we still judge it is we want to make LLVM Coroutine in
+ // switch ABIs to be self contained as much as possible instead of a
+ // by-product of C++20 Coroutines.
+ EscapingBBs.insert(cast<Instruction>(U)->getParent());
+ }
+
+ bool PotentiallyEscaped = false;
+
do {
const auto *BB = Worklist.pop_back_val();
if (!Visited.insert(BB).second)
continue;
- if (TIs.count(BB))
- return true;
+
+    // A path-insensitive marker to test whether the coro.begin escapes.
+    // It is intentionally path-insensitive, so it may not be precise, since
+    // we don't want the process to be too slow.
+ PotentiallyEscaped |= EscapingBBs.count(BB);
+
+ if (TIs.count(BB)) {
+ if (isa<ReturnInst>(BB->getTerminator()) || PotentiallyEscaped)
+ return true;
+
+ // If the function ends with the exceptional terminator, the memory used
+ // by the coroutine frame can be released by stack unwinding
+ // automatically. So we can think the coro.begin doesn't escape if it
+ // exits the function by exceptional terminator.
+
+ continue;
+ }
// Conservatively say that there is potentially a path.
if (!--Limit)
@@ -236,36 +273,36 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
// memory location storing that value and not the virtual register.
SmallPtrSet<BasicBlock *, 8> Terminators;
- // First gather all of the non-exceptional terminators for the function.
+ // First gather all of the terminators for the function.
// Consider the final coro.suspend as the real terminator when the current
// function is a coroutine.
- for (BasicBlock &B : *F) {
- auto *TI = B.getTerminator();
- if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
- !isa<UnreachableInst>(TI))
- Terminators.insert(&B);
- }
+ for (BasicBlock &B : *F) {
+ auto *TI = B.getTerminator();
+
+ if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI))
+ continue;
+
+ Terminators.insert(&B);
+ }
// Filter out the coro.destroy that lie along exceptional paths.
SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
for (const auto &It : DestroyAddr) {
- // If there is any coro.destroy dominates all of the terminators for the
- // coro.begin, we could know the corresponding coro.begin wouldn't escape.
- for (Instruction *DA : It.second) {
- if (llvm::all_of(Terminators, [&](auto *TI) {
- return DT.dominates(DA, TI->getTerminator());
- })) {
- ReferencedCoroBegins.insert(It.first);
- break;
- }
- }
-
- // Whether there is any paths from coro.begin to Terminators which not pass
- // through any of the coro.destroys.
+    // If every terminator is dominated by a coro.destroy, we know the
+    // corresponding coro.begin can't escape.
+ //
+    // Otherwise hasEscapePath decides whether there is any path from
+    // coro.begin to the Terminators that does not pass through any of the
+    // coro.destroys.
//
// hasEscapePath is relatively slow, so we avoid to run it as much as
// possible.
- if (!ReferencedCoroBegins.count(It.first) &&
+ if (llvm::all_of(Terminators,
+ [&](auto *TI) {
+ return llvm::any_of(It.second, [&](auto *DA) {
+ return DT.dominates(DA, TI->getTerminator());
+ });
+ }) ||
!hasEscapePath(It.first, Terminators))
ReferencedCoroBegins.insert(It.first);
}
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index 1f373270f951..1134b20880f1 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -63,7 +63,7 @@ public:
llvm::sort(V);
}
- size_t blockToIndex(BasicBlock *BB) const {
+ size_t blockToIndex(BasicBlock const *BB) const {
auto *I = llvm::lower_bound(V, BB);
assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block");
return I - V.begin();
@@ -112,10 +112,11 @@ class SuspendCrossingInfo {
}
/// Compute the BlockData for the current function in one iteration.
- /// Returns whether the BlockData changes in this iteration.
/// Initialize - Whether this is the first iteration, we can optimize
/// the initial case a little bit by manual loop switch.
- template <bool Initialize = false> bool computeBlockData();
+ /// Returns whether the BlockData changes in this iteration.
+ template <bool Initialize = false>
+ bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT);
public:
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -223,12 +224,14 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
}
#endif
-template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() {
- const size_t N = Mapping.size();
+template <bool Initialize>
+bool SuspendCrossingInfo::computeBlockData(
+ const ReversePostOrderTraversal<Function *> &RPOT) {
bool Changed = false;
- for (size_t I = 0; I < N; ++I) {
- auto &B = Block[I];
+ for (const BasicBlock *BB : RPOT) {
+ auto BBNo = Mapping.blockToIndex(BB);
+ auto &B = Block[BBNo];
// We don't need to count the predecessors when initialization.
if constexpr (!Initialize)
@@ -261,7 +264,7 @@ template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() {
}
if (B.Suspend) {
- // If block S is a suspend block, it should kill all of the blocks it
+ // If block B is a suspend block, it should kill all of the blocks it
// consumes.
B.Kills |= B.Consumes;
} else if (B.End) {
@@ -273,8 +276,8 @@ template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() {
} else {
// This is reached when B block it not Suspend nor coro.end and it
// need to make sure that it is not in the kill set.
- B.KillLoop |= B.Kills[I];
- B.Kills.reset(I);
+ B.KillLoop |= B.Kills[BBNo];
+ B.Kills.reset(BBNo);
}
if constexpr (!Initialize) {
@@ -283,9 +286,6 @@ template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() {
}
}
- if constexpr (Initialize)
- return true;
-
return Changed;
}
@@ -325,9 +325,11 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
markSuspendBlock(Save);
}
- computeBlockData</*Initialize=*/true>();
-
- while (computeBlockData())
+  // RPO traversal is generally considered faster for forward dataflow
+  // analysis.
+ ReversePostOrderTraversal<Function *> RPOT(&F);
+ computeBlockData</*Initialize=*/true>(RPOT);
+ while (computeBlockData</*Initialize*/ false>(RPOT))
;
LLVM_DEBUG(dump());
@@ -1073,7 +1075,7 @@ static DIType *solveDIType(DIBuilder &Builder, Type *Ty,
RetType = CharSizeType;
else {
if (Size % 8 != 0)
- Size = TypeSize::Fixed(Size + 8 - (Size % 8));
+ Size = TypeSize::getFixed(Size + 8 - (Size % 8));
RetType = Builder.createArrayType(
Size, Layout.getPrefTypeAlign(Ty).value(), CharSizeType,
@@ -1290,10 +1292,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape,
std::optional<FieldIDType> SwitchIndexFieldId;
if (Shape.ABI == coro::ABI::Switch) {
- auto *FramePtrTy = FrameTy->getPointerTo();
- auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
- /*IsVarArg=*/false);
- auto *FnPtrTy = FnTy->getPointerTo();
+ auto *FnPtrTy = PointerType::getUnqual(C);
// Add header fields for the resume and destroy functions.
// We can rely on these being perfectly packed.
@@ -1680,15 +1679,6 @@ static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) {
return CleanupRet;
}
-static void createFramePtr(coro::Shape &Shape) {
- auto *CB = Shape.CoroBegin;
- IRBuilder<> Builder(CB->getNextNode());
- StructType *FrameTy = Shape.FrameTy;
- PointerType *FramePtrTy = FrameTy->getPointerTo();
- Shape.FramePtr =
- cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr"));
-}
-
// Replace all alloca and SSA values that are accessed across suspend points
// with GetElementPointer from coroutine frame + loads and stores. Create an
// AllocaSpillBB that will become the new entry block for the resume parts of
@@ -1700,7 +1690,6 @@ static void createFramePtr(coro::Shape &Shape) {
// becomes:
//
// %hdl = coro.begin(...)
-// %FramePtr = bitcast i8* hdl to %f.frame*
// br label %AllocaSpillBB
//
// AllocaSpillBB:
@@ -1764,8 +1753,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
// Note: If we change the strategy dealing with alignment, we need to refine
// this casting.
if (GEP->getType() != Orig->getType())
- return Builder.CreateBitCast(GEP, Orig->getType(),
- Orig->getName() + Twine(".cast"));
+ return Builder.CreateAddrSpaceCast(GEP, Orig->getType(),
+ Orig->getName() + Twine(".cast"));
}
return GEP;
};
@@ -1775,13 +1764,12 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
auto SpillAlignment = Align(FrameData.getAlign(Def));
// Create a store instruction storing the value into the
// coroutine frame.
- Instruction *InsertPt = nullptr;
+ BasicBlock::iterator InsertPt;
Type *ByValTy = nullptr;
if (auto *Arg = dyn_cast<Argument>(Def)) {
// For arguments, we will place the store instruction right after
- // the coroutine frame pointer instruction, i.e. bitcast of
- // coro.begin from i8* to %f.frame*.
- InsertPt = Shape.getInsertPtAfterFramePtr();
+ // the coroutine frame pointer instruction, i.e. coro.begin.
+ InsertPt = Shape.getInsertPtAfterFramePtr()->getIterator();
// If we're spilling an Argument, make sure we clear 'nocapture'
// from the coroutine function.
@@ -1792,35 +1780,35 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
} else if (auto *CSI = dyn_cast<AnyCoroSuspendInst>(Def)) {
// Don't spill immediately after a suspend; splitting assumes
// that the suspend will be followed by a branch.
- InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHI();
+ InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHIIt();
} else {
auto *I = cast<Instruction>(Def);
if (!DT.dominates(CB, I)) {
// If it is not dominated by CoroBegin, then spill should be
// inserted immediately after CoroFrame is computed.
- InsertPt = Shape.getInsertPtAfterFramePtr();
+ InsertPt = Shape.getInsertPtAfterFramePtr()->getIterator();
} else if (auto *II = dyn_cast<InvokeInst>(I)) {
// If we are spilling the result of the invoke instruction, split
// the normal edge and insert the spill in the new block.
auto *NewBB = SplitEdge(II->getParent(), II->getNormalDest());
- InsertPt = NewBB->getTerminator();
+ InsertPt = NewBB->getTerminator()->getIterator();
} else if (isa<PHINode>(I)) {
// Skip the PHINodes and EH pads instructions.
BasicBlock *DefBlock = I->getParent();
if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator()))
- InsertPt = splitBeforeCatchSwitch(CSI);
+ InsertPt = splitBeforeCatchSwitch(CSI)->getIterator();
else
- InsertPt = &*DefBlock->getFirstInsertionPt();
+ InsertPt = DefBlock->getFirstInsertionPt();
} else {
assert(!I->isTerminator() && "unexpected terminator");
// For all other values, the spill is placed immediately after
// the definition.
- InsertPt = I->getNextNode();
+ InsertPt = I->getNextNode()->getIterator();
}
}
auto Index = FrameData.getFieldIndex(Def);
- Builder.SetInsertPoint(InsertPt);
+ Builder.SetInsertPoint(InsertPt->getParent(), InsertPt);
auto *G = Builder.CreateConstInBoundsGEP2_32(
FrameTy, FramePtr, 0, Index, Def->getName() + Twine(".spill.addr"));
if (ByValTy) {
@@ -1840,7 +1828,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
// reference provided with the frame GEP.
if (CurrentBlock != U->getParent()) {
CurrentBlock = U->getParent();
- Builder.SetInsertPoint(&*CurrentBlock->getFirstInsertionPt());
+ Builder.SetInsertPoint(CurrentBlock,
+ CurrentBlock->getFirstInsertionPt());
auto *GEP = GetFramePointer(E.first);
GEP->setName(E.first->getName() + Twine(".reload.addr"));
@@ -1863,6 +1852,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
if (LdInst->getPointerOperandType() != LdInst->getType())
break;
CurDef = LdInst->getPointerOperand();
+ if (!isa<AllocaInst, LoadInst>(CurDef))
+ break;
DIs = FindDbgDeclareUses(CurDef);
}
}
@@ -1878,7 +1869,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
&*Builder.GetInsertPoint());
// This dbg.declare is for the main function entry point. It
// will be deleted in all coro-split functions.
- coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame);
+ coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame,
+ false /*UseEntryValue*/);
}
}
@@ -1911,7 +1903,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
if (Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce ||
Shape.ABI == coro::ABI::Async) {
// If we found any allocas, replace all of their remaining uses with Geps.
- Builder.SetInsertPoint(&SpillBlock->front());
+ Builder.SetInsertPoint(SpillBlock, SpillBlock->begin());
for (const auto &P : FrameData.Allocas) {
AllocaInst *Alloca = P.Alloca;
auto *G = GetFramePointer(Alloca);
@@ -1930,7 +1922,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
// dbg.declares and dbg.values with the reload from the frame.
// Note: We cannot replace the alloca with GEP instructions indiscriminately,
// as some of the uses may not be dominated by CoroBegin.
- Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front());
+ Builder.SetInsertPoint(Shape.AllocaSpillBlock,
+ Shape.AllocaSpillBlock->begin());
SmallVector<Instruction *, 4> UsersToUpdate;
for (const auto &A : FrameData.Allocas) {
AllocaInst *Alloca = A.Alloca;
@@ -1980,16 +1973,12 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
// to the pointer in the frame.
for (const auto &Alias : A.Aliases) {
auto *FramePtr = GetFramePointer(Alloca);
- auto *FramePtrRaw =
- Builder.CreateBitCast(FramePtr, Type::getInt8PtrTy(C));
auto &Value = *Alias.second;
auto ITy = IntegerType::get(C, Value.getBitWidth());
- auto *AliasPtr = Builder.CreateGEP(Type::getInt8Ty(C), FramePtrRaw,
+ auto *AliasPtr = Builder.CreateGEP(Type::getInt8Ty(C), FramePtr,
ConstantInt::get(ITy, Value));
- auto *AliasPtrTyped =
- Builder.CreateBitCast(AliasPtr, Alias.first->getType());
Alias.first->replaceUsesWithIf(
- AliasPtrTyped, [&](Use &U) { return DT.dominates(CB, U); });
+ AliasPtr, [&](Use &U) { return DT.dominates(CB, U); });
}
}
@@ -2046,8 +2035,8 @@ static void movePHIValuesToInsertedBlock(BasicBlock *SuccBB,
int Index = PN->getBasicBlockIndex(InsertedBB);
Value *V = PN->getIncomingValue(Index);
PHINode *InputV = PHINode::Create(
- V->getType(), 1, V->getName() + Twine(".") + SuccBB->getName(),
- &InsertedBB->front());
+ V->getType(), 1, V->getName() + Twine(".") + SuccBB->getName());
+ InputV->insertBefore(InsertedBB->begin());
InputV->addIncoming(V, PredBB);
PN->setIncomingValue(Index, InputV);
PN = dyn_cast<PHINode>(PN->getNextNode());
@@ -2193,7 +2182,8 @@ static void rewritePHIs(BasicBlock &BB) {
// ehAwareSplitEdge will clone the LandingPad in all the edge blocks.
// We replace the original landing pad with a PHINode that will collect the
// results from all of them.
- ReplPHI = PHINode::Create(LandingPad->getType(), 1, "", LandingPad);
+ ReplPHI = PHINode::Create(LandingPad->getType(), 1, "");
+ ReplPHI->insertBefore(LandingPad->getIterator());
ReplPHI->takeName(LandingPad);
LandingPad->replaceAllUsesWith(ReplPHI);
// We will erase the original landing pad at the end of this function after
@@ -2428,15 +2418,13 @@ static bool localAllocaNeedsStackSave(CoroAllocaAllocInst *AI) {
static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas,
SmallVectorImpl<Instruction*> &DeadInsts) {
for (auto *AI : LocalAllocas) {
- auto M = AI->getModule();
IRBuilder<> Builder(AI);
// Save the stack depth. Try to avoid doing this if the stackrestore
// is going to immediately precede a return or something.
Value *StackSave = nullptr;
if (localAllocaNeedsStackSave(AI))
- StackSave = Builder.CreateCall(
- Intrinsic::getDeclaration(M, Intrinsic::stacksave));
+ StackSave = Builder.CreateStackSave();
// Allocate memory.
auto Alloca = Builder.CreateAlloca(Builder.getInt8Ty(), AI->getSize());
@@ -2454,9 +2442,7 @@ static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas,
auto FI = cast<CoroAllocaFreeInst>(U);
if (StackSave) {
Builder.SetInsertPoint(FI);
- Builder.CreateCall(
- Intrinsic::getDeclaration(M, Intrinsic::stackrestore),
- StackSave);
+ Builder.CreateStackRestore(StackSave);
}
}
DeadInsts.push_back(cast<Instruction>(U));
@@ -2498,7 +2484,7 @@ static Value *emitGetSwiftErrorValue(IRBuilder<> &Builder, Type *ValueTy,
coro::Shape &Shape) {
// Make a fake function pointer as a sort of intrinsic.
auto FnTy = FunctionType::get(ValueTy, {}, false);
- auto Fn = ConstantPointerNull::get(FnTy->getPointerTo());
+ auto Fn = ConstantPointerNull::get(Builder.getPtrTy());
auto Call = Builder.CreateCall(FnTy, Fn, {});
Shape.SwiftErrorOps.push_back(Call);
@@ -2512,9 +2498,9 @@ static Value *emitGetSwiftErrorValue(IRBuilder<> &Builder, Type *ValueTy,
static Value *emitSetSwiftErrorValue(IRBuilder<> &Builder, Value *V,
coro::Shape &Shape) {
// Make a fake function pointer as a sort of intrinsic.
- auto FnTy = FunctionType::get(V->getType()->getPointerTo(),
+ auto FnTy = FunctionType::get(Builder.getPtrTy(),
{V->getType()}, false);
- auto Fn = ConstantPointerNull::get(FnTy->getPointerTo());
+ auto Fn = ConstantPointerNull::get(Builder.getPtrTy());
auto Call = Builder.CreateCall(FnTy, Fn, { V });
Shape.SwiftErrorOps.push_back(Call);
@@ -2765,17 +2751,8 @@ static void sinkLifetimeStartMarkers(Function &F, coro::Shape &Shape,
// Sink lifetime.start markers to dominate block when they are
// only used outside the region.
if (Valid && Lifetimes.size() != 0) {
- // May be AI itself, when the type of AI is i8*
- auto *NewBitCast = [&](AllocaInst *AI) -> Value* {
- if (isa<AllocaInst>(Lifetimes[0]->getOperand(1)))
- return AI;
- auto *Int8PtrTy = Type::getInt8PtrTy(F.getContext());
- return CastInst::Create(Instruction::BitCast, AI, Int8PtrTy, "",
- DomBB->getTerminator());
- }(AI);
-
auto *NewLifetime = Lifetimes[0]->clone();
- NewLifetime->replaceUsesOfWith(NewLifetime->getOperand(1), NewBitCast);
+ NewLifetime->replaceUsesOfWith(NewLifetime->getOperand(1), AI);
NewLifetime->insertBefore(DomBB->getTerminator());
// All the outsided lifetime.start markers are no longer necessary.
@@ -2800,6 +2777,11 @@ static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape,
if (AI == Shape.SwitchLowering.PromiseAlloca)
return;
+  // The __coro_gro alloca should outlive the promise, so make sure it
+  // is kept outside the frame.
+ if (AI->hasMetadata(LLVMContext::MD_coro_outside_frame))
+ return;
+
// The code that uses lifetime.start intrinsic does not work for functions
// with loops without exit. Disable it on ABIs we know to generate such
// code.
@@ -2818,7 +2800,7 @@ static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape,
void coro::salvageDebugInfo(
SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap,
- DbgVariableIntrinsic *DVI, bool OptimizeFrame) {
+ DbgVariableIntrinsic *DVI, bool OptimizeFrame, bool UseEntryValue) {
Function *F = DVI->getFunction();
IRBuilder<> Builder(F->getContext());
auto InsertPt = F->getEntryBlock().getFirstInsertionPt();
@@ -2870,7 +2852,9 @@ void coro::salvageDebugInfo(
// Swift async arguments are described by an entry value of the ABI-defined
// register containing the coroutine context.
- if (IsSwiftAsyncArg && !Expr->isEntryValue())
+ // Entry values in variadic expressions are not supported.
+ if (IsSwiftAsyncArg && UseEntryValue && !Expr->isEntryValue() &&
+ Expr->isSingleLocationExpression())
Expr = DIExpression::prepend(Expr, DIExpression::EntryValue);
// If the coroutine frame is an Argument, store it in an alloca to improve
@@ -2902,13 +2886,13 @@ void coro::salvageDebugInfo(
// dbg.value since it does not have the same function wide guarantees that
// dbg.declare does.
if (isa<DbgDeclareInst>(DVI)) {
- Instruction *InsertPt = nullptr;
+ std::optional<BasicBlock::iterator> InsertPt;
if (auto *I = dyn_cast<Instruction>(Storage))
InsertPt = I->getInsertionPointAfterDef();
else if (isa<Argument>(Storage))
- InsertPt = &*F->getEntryBlock().begin();
+ InsertPt = F->getEntryBlock().begin();
if (InsertPt)
- DVI->moveBefore(InsertPt);
+ DVI->moveBefore(*(*InsertPt)->getParent(), *InsertPt);
}
}
@@ -3110,7 +3094,7 @@ void coro::buildCoroutineFrame(
Shape.ABI == coro::ABI::Async)
sinkSpillUsesAfterCoroBegin(F, FrameData, Shape.CoroBegin);
Shape.FrameTy = buildFrameType(F, Shape, FrameData);
- createFramePtr(Shape);
+ Shape.FramePtr = Shape.CoroBegin;
// For now, this works for C++ programs only.
buildFrameDebugInfo(F, Shape, FrameData);
insertSpills(FrameData, Shape);
diff --git a/llvm/lib/Transforms/Coroutines/CoroInstr.h b/llvm/lib/Transforms/Coroutines/CoroInstr.h
index 014938c15a0a..f01aa58eb899 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInstr.h
@@ -123,8 +123,8 @@ public:
void clearPromise() {
Value *Arg = getArgOperand(PromiseArg);
- setArgOperand(PromiseArg,
- ConstantPointerNull::get(Type::getInt8PtrTy(getContext())));
+ setArgOperand(PromiseArg, ConstantPointerNull::get(
+ PointerType::getUnqual(getContext())));
if (isa<AllocaInst>(Arg))
return;
assert((isa<BitCastInst>(Arg) || isa<GetElementPtrInst>(Arg)) &&
@@ -185,9 +185,7 @@ public:
void setCoroutineSelf() {
assert(isa<ConstantPointerNull>(getArgOperand(CoroutineArg)) &&
"Coroutine argument is already assigned");
- auto *const Int8PtrTy = Type::getInt8PtrTy(getContext());
- setArgOperand(CoroutineArg,
- ConstantExpr::getBitCast(getFunction(), Int8PtrTy));
+ setArgOperand(CoroutineArg, getFunction());
}
// Methods to support type inquiry through isa, cast, and dyn_cast:
@@ -611,8 +609,37 @@ public:
}
};
+/// This represents the llvm.end.results instruction.
+class LLVM_LIBRARY_VISIBILITY CoroEndResults : public IntrinsicInst {
+public:
+ op_iterator retval_begin() { return arg_begin(); }
+ const_op_iterator retval_begin() const { return arg_begin(); }
+
+ op_iterator retval_end() { return arg_end(); }
+ const_op_iterator retval_end() const { return arg_end(); }
+
+ iterator_range<op_iterator> return_values() {
+ return make_range(retval_begin(), retval_end());
+ }
+ iterator_range<const_op_iterator> return_values() const {
+ return make_range(retval_begin(), retval_end());
+ }
+
+ unsigned numReturns() const {
+ return std::distance(retval_begin(), retval_end());
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_end_results;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
class LLVM_LIBRARY_VISIBILITY AnyCoroEndInst : public IntrinsicInst {
- enum { FrameArg, UnwindArg };
+ enum { FrameArg, UnwindArg, TokenArg };
public:
bool isFallthrough() const { return !isUnwind(); }
@@ -620,6 +647,15 @@ public:
return cast<Constant>(getArgOperand(UnwindArg))->isOneValue();
}
+ bool hasResults() const {
+ return !isa<ConstantTokenNone>(getArgOperand(TokenArg));
+ }
+
+ CoroEndResults *getResults() const {
+ assert(hasResults());
+ return cast<CoroEndResults>(getArgOperand(TokenArg));
+ }
+
// Methods to support type inquiry through isa, cast, and dyn_cast:
static bool classof(const IntrinsicInst *I) {
auto ID = I->getIntrinsicID();
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index 067fb6bba47e..0856c4925cc5 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -32,7 +32,7 @@ void replaceCoroFree(CoroIdInst *CoroId, bool Elide);
/// OptimizeFrame is false.
void salvageDebugInfo(
SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap,
- DbgVariableIntrinsic *DVI, bool OptimizeFrame);
+ DbgVariableIntrinsic *DVI, bool OptimizeFrame, bool IsEntryPoint);
// Keeps data and helper functions for lowering coroutine intrinsics.
struct LowererBase {
@@ -185,7 +185,8 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
switch (ABI) {
case coro::ABI::Switch:
return FunctionType::get(Type::getVoidTy(FrameTy->getContext()),
- FrameTy->getPointerTo(), /*IsVarArg*/false);
+ PointerType::getUnqual(FrameTy->getContext()),
+ /*IsVarArg=*/false);
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
return RetconLowering.ResumePrototype->getFunctionType();
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 39e909bf3316..244580f503d5 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -234,6 +234,8 @@ static void replaceFallthroughCoroEnd(AnyCoroEndInst *End,
switch (Shape.ABI) {
// The cloned functions in switch-lowering always return void.
case coro::ABI::Switch:
+ assert(!cast<CoroEndInst>(End)->hasResults() &&
+ "switch coroutine should not return any values");
// coro.end doesn't immediately end the coroutine in the main function
// in this lowering, because we need to deallocate the coroutine.
if (!InResume)
@@ -251,14 +253,45 @@ static void replaceFallthroughCoroEnd(AnyCoroEndInst *End,
// In unique continuation lowering, the continuations always return void.
// But we may have implicitly allocated storage.
- case coro::ABI::RetconOnce:
+ case coro::ABI::RetconOnce: {
maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
- Builder.CreateRetVoid();
+ auto *CoroEnd = cast<CoroEndInst>(End);
+ auto *RetTy = Shape.getResumeFunctionType()->getReturnType();
+
+ if (!CoroEnd->hasResults()) {
+ assert(RetTy->isVoidTy());
+ Builder.CreateRetVoid();
+ break;
+ }
+
+ auto *CoroResults = CoroEnd->getResults();
+ unsigned NumReturns = CoroResults->numReturns();
+
+ if (auto *RetStructTy = dyn_cast<StructType>(RetTy)) {
+ assert(RetStructTy->getNumElements() == NumReturns &&
+             "number of returns should match resume function signature");
+ Value *ReturnValue = UndefValue::get(RetStructTy);
+ unsigned Idx = 0;
+ for (Value *RetValEl : CoroResults->return_values())
+ ReturnValue = Builder.CreateInsertValue(ReturnValue, RetValEl, Idx++);
+ Builder.CreateRet(ReturnValue);
+ } else if (NumReturns == 0) {
+ assert(RetTy->isVoidTy());
+ Builder.CreateRetVoid();
+ } else {
+ assert(NumReturns == 1);
+ Builder.CreateRet(*CoroResults->retval_begin());
+ }
+ CoroResults->replaceAllUsesWith(ConstantTokenNone::get(CoroResults->getContext()));
+ CoroResults->eraseFromParent();
break;
+ }
// In non-unique continuation lowering, we signal completion by returning
// a null continuation.
case coro::ABI::Retcon: {
+ assert(!cast<CoroEndInst>(End)->hasResults() &&
+ "retcon coroutine should not return any values");
maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
auto RetTy = Shape.getResumeFunctionType()->getReturnType();
auto RetStructTy = dyn_cast<StructType>(RetTy);
@@ -457,7 +490,8 @@ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
Switch->addCase(IndexVal, ResumeBB);
cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
- auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
+ auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "");
+ PN->insertBefore(LandingBB->begin());
S->replaceAllUsesWith(PN);
PN->addIncoming(Builder.getInt8(-1), SuspendBB);
PN->addIncoming(S, ResumeBB);
@@ -495,13 +529,20 @@ void CoroCloner::handleFinalSuspend() {
BasicBlock *OldSwitchBB = Switch->getParent();
auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
Builder.SetInsertPoint(OldSwitchBB->getTerminator());
- auto *GepIndex = Builder.CreateStructGEP(Shape.FrameTy, NewFramePtr,
- coro::Shape::SwitchFieldIndex::Resume,
- "ResumeFn.addr");
- auto *Load = Builder.CreateLoad(Shape.getSwitchResumePointerType(),
- GepIndex);
- auto *Cond = Builder.CreateIsNull(Load);
- Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
+
+ if (NewF->isCoroOnlyDestroyWhenComplete()) {
+ // When the coroutine can only be destroyed when complete, we don't need
+ // to generate code for other cases.
+ Builder.CreateBr(ResumeBB);
+ } else {
+ auto *GepIndex = Builder.CreateStructGEP(
+ Shape.FrameTy, NewFramePtr, coro::Shape::SwitchFieldIndex::Resume,
+ "ResumeFn.addr");
+ auto *Load =
+ Builder.CreateLoad(Shape.getSwitchResumePointerType(), GepIndex);
+ auto *Cond = Builder.CreateIsNull(Load);
+ Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
+ }
OldSwitchBB->getTerminator()->eraseFromParent();
}
}
@@ -701,8 +742,13 @@ void CoroCloner::salvageDebugInfo() {
SmallVector<DbgVariableIntrinsic *, 8> Worklist =
collectDbgVariableIntrinsics(*NewF);
SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap;
+
+ // Only 64-bit ABIs have a register we can refer to with the entry value.
+ bool UseEntryValue =
+ llvm::Triple(OrigF.getParent()->getTargetTriple()).isArch64Bit();
for (DbgVariableIntrinsic *DVI : Worklist)
- coro::salvageDebugInfo(ArgToAllocaMap, DVI, Shape.OptimizeFrame);
+ coro::salvageDebugInfo(ArgToAllocaMap, DVI, Shape.OptimizeFrame,
+ UseEntryValue);
// Remove all salvaged dbg.declare intrinsics that became
// either unreachable or stale due to the CoroSplit transformation.
@@ -811,7 +857,6 @@ Value *CoroCloner::deriveNewFramePointer() {
auto *ActiveAsyncSuspend = cast<CoroSuspendAsyncInst>(ActiveSuspend);
auto ContextIdx = ActiveAsyncSuspend->getStorageArgumentIndex() & 0xff;
auto *CalleeContext = NewF->getArg(ContextIdx);
- auto *FramePtrTy = Shape.FrameTy->getPointerTo();
auto *ProjectionFunc =
ActiveAsyncSuspend->getAsyncContextProjectionFunction();
auto DbgLoc =
@@ -831,22 +876,20 @@ Value *CoroCloner::deriveNewFramePointer() {
auto InlineRes = InlineFunction(*CallerContext, InlineInfo);
assert(InlineRes.isSuccess());
(void)InlineRes;
- return Builder.CreateBitCast(FramePtrAddr, FramePtrTy);
+ return FramePtrAddr;
}
// In continuation-lowering, the argument is the opaque storage.
case coro::ABI::Retcon:
case coro::ABI::RetconOnce: {
Argument *NewStorage = &*NewF->arg_begin();
- auto FramePtrTy = Shape.FrameTy->getPointerTo();
+ auto FramePtrTy = PointerType::getUnqual(Shape.FrameTy->getContext());
// If the storage is inline, just bitcast to the storage to the frame type.
if (Shape.RetconLowering.IsFrameInlineInStorage)
- return Builder.CreateBitCast(NewStorage, FramePtrTy);
+ return NewStorage;
// Otherwise, load the real frame from the opaque storage.
- auto FramePtrPtr =
- Builder.CreateBitCast(NewStorage, FramePtrTy->getPointerTo());
- return Builder.CreateLoad(FramePtrTy, FramePtrPtr);
+ return Builder.CreateLoad(FramePtrTy, NewStorage);
}
}
llvm_unreachable("bad ABI");
@@ -940,9 +983,22 @@ void CoroCloner::create() {
// abstract specification, since the DWARF backend expects the
// abstract specification to contain the linkage name and asserts
// that they are identical.
- if (!SP->getDeclaration() && SP->getUnit() &&
- SP->getUnit()->getSourceLanguage() == dwarf::DW_LANG_Swift)
+ if (SP->getUnit() &&
+ SP->getUnit()->getSourceLanguage() == dwarf::DW_LANG_Swift) {
SP->replaceLinkageName(MDString::get(Context, NewF->getName()));
+ if (auto *Decl = SP->getDeclaration()) {
+ auto *NewDecl = DISubprogram::get(
+ Decl->getContext(), Decl->getScope(), Decl->getName(),
+ NewF->getName(), Decl->getFile(), Decl->getLine(), Decl->getType(),
+ Decl->getScopeLine(), Decl->getContainingType(),
+ Decl->getVirtualIndex(), Decl->getThisAdjustment(),
+ Decl->getFlags(), Decl->getSPFlags(), Decl->getUnit(),
+ Decl->getTemplateParams(), nullptr, Decl->getRetainedNodes(),
+ Decl->getThrownTypes(), Decl->getAnnotations(),
+ Decl->getTargetFuncName());
+ SP->replaceDeclaration(NewDecl);
+ }
+ }
}
NewF->setLinkage(savedLinkage);
@@ -1047,7 +1103,7 @@ void CoroCloner::create() {
// Remap vFrame pointer.
auto *NewVFrame = Builder.CreateBitCast(
- NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
+ NewFramePtr, PointerType::getUnqual(Builder.getContext()), "vFrame");
Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
if (OldVFrame != NewVFrame)
OldVFrame->replaceAllUsesWith(NewVFrame);
@@ -1178,7 +1234,7 @@ static void setCoroInfo(Function &F, coro::Shape &Shape,
// Update coro.begin instruction to refer to this constant.
LLVMContext &C = F.getContext();
- auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
+ auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C));
Shape.getSwitchCoroId()->setInfo(BC);
}
@@ -1425,10 +1481,9 @@ static void handleNoSuspendCoroutine(coro::Shape &Shape) {
IRBuilder<> Builder(AllocInst);
auto *Frame = Builder.CreateAlloca(Shape.FrameTy);
Frame->setAlignment(Shape.FrameAlign);
- auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
AllocInst->replaceAllUsesWith(Builder.getFalse());
AllocInst->eraseFromParent();
- CoroBegin->replaceAllUsesWith(VFrame);
+ CoroBegin->replaceAllUsesWith(Frame);
} else {
CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
}
@@ -1658,7 +1713,7 @@ static void replaceAsyncResumeFunction(CoroSuspendAsyncInst *Suspend,
Value *Continuation) {
auto *ResumeIntrinsic = Suspend->getResumeFunction();
auto &Context = Suspend->getParent()->getParent()->getContext();
- auto *Int8PtrTy = Type::getInt8PtrTy(Context);
+ auto *Int8PtrTy = PointerType::getUnqual(Context);
IRBuilder<> Builder(ResumeIntrinsic);
auto *Val = Builder.CreateBitOrPointerCast(Continuation, Int8PtrTy);
@@ -1711,7 +1766,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
F.removeRetAttr(Attribute::NonNull);
auto &Context = F.getContext();
- auto *Int8PtrTy = Type::getInt8PtrTy(Context);
+ auto *Int8PtrTy = PointerType::getUnqual(Context);
auto *Id = cast<CoroIdAsyncInst>(Shape.CoroBegin->getId());
IRBuilder<> Builder(Id);
@@ -1829,9 +1884,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType());
// Stash the allocated frame pointer in the continuation storage.
- auto Dest = Builder.CreateBitCast(Id->getStorage(),
- RawFramePtr->getType()->getPointerTo());
- Builder.CreateStore(RawFramePtr, Dest);
+ Builder.CreateStore(RawFramePtr, Id->getStorage());
}
// Map all uses of llvm.coro.begin to the allocated frame pointer.
@@ -1987,7 +2040,8 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
// coroutine funclets.
SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap;
for (auto *DDI : collectDbgVariableIntrinsics(F))
- coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame);
+ coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame,
+ false /*UseEntryValue*/);
return Shape;
}
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index cde74c5e693b..eef5543bae24 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -37,16 +37,15 @@ using namespace llvm;
// Construct the lowerer base class and initialize its members.
coro::LowererBase::LowererBase(Module &M)
: TheModule(M), Context(M.getContext()),
- Int8Ptr(Type::getInt8PtrTy(Context)),
+ Int8Ptr(PointerType::get(Context, 0)),
ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
/*isVarArg=*/false)),
NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
-// Creates a sequence of instructions to obtain a resume function address using
-// llvm.coro.subfn.addr. It generates the following sequence:
+// Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
+// It generates the following:
//
-// call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
-// bitcast i8* %2 to void(i8*)*
+// call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
Instruction *InsertPt) {
@@ -56,11 +55,7 @@ Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
assert(Index >= CoroSubFnInst::IndexFirst &&
Index < CoroSubFnInst::IndexLast &&
"makeSubFnCall: Index value out of range");
- auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
-
- auto *Bitcast =
- new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
- return Bitcast;
+ return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
}
// NOTE: Must be sorted!
@@ -137,8 +132,9 @@ void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
return;
Value *Replacement =
- Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext()))
- : CoroFrees.front()->getFrame();
+ Elide
+ ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
+ : CoroFrees.front()->getFrame();
for (CoroFreeInst *CF : CoroFrees) {
CF->replaceAllUsesWith(Replacement);
@@ -267,7 +263,7 @@ void coro::Shape::buildFrom(Function &F) {
if (!CoroBegin) {
// Replace coro.frame which are supposed to be lowered to the result of
// coro.begin with undef.
- auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
+ auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0));
for (CoroFrameInst *CF : CoroFrames) {
CF->replaceAllUsesWith(Undef);
CF->eraseFromParent();
diff --git a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
new file mode 100644
index 000000000000..fb7cba9edbdb
--- /dev/null
+++ b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
@@ -0,0 +1,312 @@
+//===----- HipStdPar.cpp - HIP C++ Standard Parallelism Support Passes ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file implements two passes that enable HIP C++ Standard Parallelism
+// Support:
+//
+// 1. AcceleratorCodeSelection (required): Given that only algorithms are
+// accelerated, and that the accelerated implementation exists in the form of
+// a compute kernel, we assume that only the kernel, and all functions
+// reachable from it, constitute code that the user expects the accelerator
+// to execute. Thus, we identify the set of all functions reachable from
+// kernels, and then remove all unreachable ones. This last part is necessary
+// because it is possible for code that the user did not expect to execute on
+// an accelerator to contain constructs that cannot be handled by the target
+// BE, which cannot be provably demonstrated to be dead code in general, and
+// thus can lead to mis-compilation. The degenerate case of this is when a
+// Module contains no kernels (the parent TU had no algorithm invocations fit
+// for acceleration), which we handle by completely emptying said module.
+// **NOTE**: The above does not handle indirectly reachable functions i.e.
+// it is possible to obtain a case where the target of an indirect
+// call is otherwise unreachable and thus is removed; this
+// restriction is aligned with the current `-hipstdpar` limitations
+// and will be relaxed in the future.
+//
+// 2. AllocationInterposition (required only when on-demand paging is
+// unsupported): Some accelerators or operating systems might not support
+// transparent on-demand paging. Thus, they would only be able to access
+// memory that is allocated by an accelerator-aware mechanism. For such cases
+// the user can opt into enabling allocation / deallocation interposition,
+// whereby we replace calls to known allocation / deallocation functions with
+// calls to runtime implemented equivalents that forward the requests to
+// accelerator-aware interfaces. We also support freeing system allocated
+// memory that ends up in one of the runtime equivalents, since this can
+// happen if e.g. a library that was compiled without interposition returns
+// an allocation that can be validly passed to `free`.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/HipStdPar/HipStdPar.h"
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#include <cassert>
+#include <string>
+#include <utility>
+
+using namespace llvm;
+
+template<typename T>
+static inline void eraseFromModule(T &ToErase) {
+ ToErase.replaceAllUsesWith(PoisonValue::get(ToErase.getType()));
+ ToErase.eraseFromParent();
+}
+
+static inline bool checkIfSupported(GlobalVariable &G) {
+ if (!G.isThreadLocal())
+ return true;
+
+ G.dropDroppableUses();
+
+ if (!G.isConstantUsed())
+ return true;
+
+ std::string W;
+ raw_string_ostream OS(W);
+
+ OS << "Accelerator does not support the thread_local variable "
+ << G.getName();
+
+ Instruction *I = nullptr;
+ SmallVector<User *> Tmp(G.user_begin(), G.user_end());
+ SmallPtrSet<User *, 5> Visited;
+ do {
+ auto U = std::move(Tmp.back());
+ Tmp.pop_back();
+
+ if (Visited.contains(U))
+ continue;
+
+ if (isa<Instruction>(U))
+ I = cast<Instruction>(U);
+ else
+ Tmp.insert(Tmp.end(), U->user_begin(), U->user_end());
+
+ Visited.insert(U);
+ } while (!I && !Tmp.empty());
+
+ assert(I && "thread_local global should have at least one non-constant use.");
+
+ G.getContext().diagnose(
+ DiagnosticInfoUnsupported(*I->getParent()->getParent(), W,
+ I->getDebugLoc(), DS_Error));
+
+ return false;
+}
+
+static inline void clearModule(Module &M) { // TODO: simplify.
+ while (!M.functions().empty())
+ eraseFromModule(*M.begin());
+ while (!M.globals().empty())
+ eraseFromModule(*M.globals().begin());
+ while (!M.aliases().empty())
+ eraseFromModule(*M.aliases().begin());
+ while (!M.ifuncs().empty())
+ eraseFromModule(*M.ifuncs().begin());
+}
+
+static inline void maybeHandleGlobals(Module &M) {
+ unsigned GlobAS = M.getDataLayout().getDefaultGlobalsAddressSpace();
+ for (auto &&G : M.globals()) { // TODO: should we handle these in the FE?
+ if (!checkIfSupported(G))
+ return clearModule(M);
+
+ if (G.isThreadLocal())
+ continue;
+ if (G.isConstant())
+ continue;
+ if (G.getAddressSpace() != GlobAS)
+ continue;
+ if (G.getLinkage() != GlobalVariable::ExternalLinkage)
+ continue;
+
+ G.setLinkage(GlobalVariable::ExternalWeakLinkage);
+ G.setExternallyInitialized(true);
+ }
+}
+
+template<unsigned N>
+static inline void removeUnreachableFunctions(
+ const SmallPtrSet<const Function *, N>& Reachable, Module &M) {
+ removeFromUsedLists(M, [&](Constant *C) {
+ if (auto F = dyn_cast<Function>(C))
+ return !Reachable.contains(F);
+
+ return false;
+ });
+
+ SmallVector<std::reference_wrapper<Function>> ToRemove;
+ copy_if(M, std::back_inserter(ToRemove), [&](auto &&F) {
+ return !F.isIntrinsic() && !Reachable.contains(&F);
+ });
+
+ for_each(ToRemove, eraseFromModule<Function>);
+}
+
+static inline bool isAcceleratorExecutionRoot(const Function *F) {
+ if (!F)
+ return false;
+
+ return F->getCallingConv() == CallingConv::AMDGPU_KERNEL;
+}
+
+static inline bool checkIfSupported(const Function *F, const CallBase *CB) {
+ const auto Dx = F->getName().rfind("__hipstdpar_unsupported");
+
+ if (Dx == StringRef::npos)
+ return true;
+
+ const auto N = F->getName().substr(0, Dx);
+
+ std::string W;
+ raw_string_ostream OS(W);
+
+ if (N == "__ASM")
+ OS << "Accelerator does not support the ASM block:\n"
+ << cast<ConstantDataArray>(CB->getArgOperand(0))->getAsCString();
+ else
+ OS << "Accelerator does not support the " << N << " function.";
+
+ auto Caller = CB->getParent()->getParent();
+
+ Caller->getContext().diagnose(
+ DiagnosticInfoUnsupported(*Caller, W, CB->getDebugLoc(), DS_Error));
+
+ return false;
+}
+
+PreservedAnalyses
+ HipStdParAcceleratorCodeSelectionPass::run(Module &M,
+ ModuleAnalysisManager &MAM) {
+ auto &CGA = MAM.getResult<CallGraphAnalysis>(M);
+
+ SmallPtrSet<const Function *, 32> Reachable;
+ for (auto &&CGN : CGA) {
+ if (!isAcceleratorExecutionRoot(CGN.first))
+ continue;
+
+ Reachable.insert(CGN.first);
+
+ SmallVector<const Function *> Tmp({CGN.first});
+ do {
+ auto F = std::move(Tmp.back());
+ Tmp.pop_back();
+
+ for (auto &&N : *CGA[F]) {
+ if (!N.second)
+ continue;
+ if (!N.second->getFunction())
+ continue;
+ if (Reachable.contains(N.second->getFunction()))
+ continue;
+
+ if (!checkIfSupported(N.second->getFunction(),
+ dyn_cast<CallBase>(*N.first)))
+ return PreservedAnalyses::none();
+
+ Reachable.insert(N.second->getFunction());
+ Tmp.push_back(N.second->getFunction());
+ }
+ } while (!std::empty(Tmp));
+ }
+
+ if (std::empty(Reachable))
+ clearModule(M);
+ else
+ removeUnreachableFunctions(Reachable, M);
+
+ maybeHandleGlobals(M);
+
+ return PreservedAnalyses::none();
+}
+
+static constexpr std::pair<StringLiteral, StringLiteral> ReplaceMap[]{
+ {"aligned_alloc", "__hipstdpar_aligned_alloc"},
+ {"calloc", "__hipstdpar_calloc"},
+ {"free", "__hipstdpar_free"},
+ {"malloc", "__hipstdpar_malloc"},
+ {"memalign", "__hipstdpar_aligned_alloc"},
+ {"posix_memalign", "__hipstdpar_posix_aligned_alloc"},
+ {"realloc", "__hipstdpar_realloc"},
+ {"reallocarray", "__hipstdpar_realloc_array"},
+ {"_ZdaPv", "__hipstdpar_operator_delete"},
+ {"_ZdaPvm", "__hipstdpar_operator_delete_sized"},
+ {"_ZdaPvSt11align_val_t", "__hipstdpar_operator_delete_aligned"},
+ {"_ZdaPvmSt11align_val_t", "__hipstdpar_operator_delete_aligned_sized"},
+ {"_ZdlPv", "__hipstdpar_operator_delete"},
+ {"_ZdlPvm", "__hipstdpar_operator_delete_sized"},
+ {"_ZdlPvSt11align_val_t", "__hipstdpar_operator_delete_aligned"},
+ {"_ZdlPvmSt11align_val_t", "__hipstdpar_operator_delete_aligned_sized"},
+ {"_Znam", "__hipstdpar_operator_new"},
+ {"_ZnamRKSt9nothrow_t", "__hipstdpar_operator_new_nothrow"},
+ {"_ZnamSt11align_val_t", "__hipstdpar_operator_new_aligned"},
+ {"_ZnamSt11align_val_tRKSt9nothrow_t",
+ "__hipstdpar_operator_new_aligned_nothrow"},
+
+ {"_Znwm", "__hipstdpar_operator_new"},
+ {"_ZnwmRKSt9nothrow_t", "__hipstdpar_operator_new_nothrow"},
+ {"_ZnwmSt11align_val_t", "__hipstdpar_operator_new_aligned"},
+ {"_ZnwmSt11align_val_tRKSt9nothrow_t",
+ "__hipstdpar_operator_new_aligned_nothrow"},
+ {"__builtin_calloc", "__hipstdpar_calloc"},
+ {"__builtin_free", "__hipstdpar_free"},
+ {"__builtin_malloc", "__hipstdpar_malloc"},
+ {"__builtin_operator_delete", "__hipstdpar_operator_delete"},
+ {"__builtin_operator_new", "__hipstdpar_operator_new"},
+ {"__builtin_realloc", "__hipstdpar_realloc"},
+ {"__libc_calloc", "__hipstdpar_calloc"},
+ {"__libc_free", "__hipstdpar_free"},
+ {"__libc_malloc", "__hipstdpar_malloc"},
+ {"__libc_memalign", "__hipstdpar_aligned_alloc"},
+ {"__libc_realloc", "__hipstdpar_realloc"}
+};
+
+PreservedAnalyses
+HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
+ SmallDenseMap<StringRef, StringRef> AllocReplacements(std::cbegin(ReplaceMap),
+ std::cend(ReplaceMap));
+
+ for (auto &&F : M) {
+ if (!F.hasName())
+ continue;
+ if (!AllocReplacements.contains(F.getName()))
+ continue;
+
+ if (auto R = M.getFunction(AllocReplacements[F.getName()])) {
+ F.replaceAllUsesWith(R);
+ } else {
+ std::string W;
+ raw_string_ostream OS(W);
+
+ OS << "cannot be interposed, missing: " << AllocReplacements[F.getName()]
+ << ". Tried to run the allocation interposition pass without the "
+ << "replacement functions available.";
+
+ F.getContext().diagnose(DiagnosticInfoUnsupported(F, W,
+ F.getSubprogram(),
+ DS_Warning));
+ }
+ }
+
+ if (auto F = M.getFunction("__hipstdpar_hidden_free")) {
+ auto LibcFree = M.getOrInsertFunction("__libc_free", F->getFunctionType(),
+ F->getAttributes());
+ F->replaceAllUsesWith(LibcFree.getCallee());
+
+ eraseFromModule(*F);
+ }
+
+ return PreservedAnalyses::none();
+}
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 824da6395f2e..fb3fa8d23daa 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -121,19 +121,24 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
// that we are *not* promoting. For the ones that we do promote, the parameter
// attributes are lost
SmallVector<AttributeSet, 8> ArgAttrVec;
+ // Mapping from old to new argument indices. -1 for promoted or removed
+ // arguments.
+ SmallVector<unsigned> NewArgIndices;
AttributeList PAL = F->getAttributes();
// First, determine the new argument list
- unsigned ArgNo = 0;
+ unsigned ArgNo = 0, NewArgNo = 0;
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
++I, ++ArgNo) {
if (!ArgsToPromote.count(&*I)) {
// Unchanged argument
Params.push_back(I->getType());
ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
+ NewArgIndices.push_back(NewArgNo++);
} else if (I->use_empty()) {
// Dead argument (which are always marked as promotable)
++NumArgumentsDead;
+ NewArgIndices.push_back((unsigned)-1);
} else {
const auto &ArgParts = ArgsToPromote.find(&*I)->second;
for (const auto &Pair : ArgParts) {
@@ -141,6 +146,8 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
ArgAttrVec.push_back(AttributeSet());
}
++NumArgumentsPromoted;
+ NewArgIndices.push_back((unsigned)-1);
+ NewArgNo += ArgParts.size();
}
}
@@ -154,6 +161,7 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
F->getName());
NF->copyAttributesFrom(F);
NF->copyMetadata(F, 0);
+ NF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat);
// The new function will have the !dbg metadata copied from the original
// function. The original function may not be deleted, and dbg metadata need
@@ -173,6 +181,19 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
// the function.
NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
PAL.getRetAttrs(), ArgAttrVec));
+
+ // Remap argument indices in allocsize attribute.
+ if (auto AllocSize = NF->getAttributes().getFnAttrs().getAllocSizeArgs()) {
+ unsigned Arg1 = NewArgIndices[AllocSize->first];
+ assert(Arg1 != (unsigned)-1 && "allocsize cannot be promoted argument");
+ std::optional<unsigned> Arg2;
+ if (AllocSize->second) {
+ Arg2 = NewArgIndices[*AllocSize->second];
+ assert(Arg2 != (unsigned)-1 && "allocsize cannot be promoted argument");
+ }
+ NF->addFnAttr(Attribute::getWithAllocSizeArgs(F->getContext(), Arg1, Arg2));
+ }
+
AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth);
ArgAttrVec.clear();
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 847d07a49dee..d8e290cbc8a4 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
@@ -50,6 +51,7 @@
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cstdint>
+#include <memory>
#ifdef EXPENSIVE_CHECKS
#include "llvm/IR/Verifier.h"
@@ -93,6 +95,13 @@ static cl::opt<unsigned>
cl::desc("Maximal number of fixpoint iterations."),
cl::init(32));
+static cl::opt<unsigned>
+ MaxSpecializationPerCB("attributor-max-specializations-per-call-base",
+ cl::Hidden,
+ cl::desc("Maximal number of callees specialized for "
+ "a call base"),
+ cl::init(UINT32_MAX));
+
static cl::opt<unsigned, true> MaxInitializationChainLengthX(
"attributor-max-initialization-chain-length", cl::Hidden,
cl::desc(
@@ -166,6 +175,10 @@ static cl::opt<bool> SimplifyAllLoads("attributor-simplify-all-loads",
cl::desc("Try to simplify all loads."),
cl::init(true));
+static cl::opt<bool> CloseWorldAssumption(
+ "attributor-assume-closed-world", cl::Hidden,
+ cl::desc("Should a closed world be assumed, or not. Default if not set."));
+
/// Logic operators for the change status enum class.
///
///{
@@ -226,10 +239,10 @@ bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA,
return InstanceInfoAA && InstanceInfoAA->isAssumedUniqueForAnalysis();
}
-Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty,
- const TargetLibraryInfo *TLI,
- const DataLayout &DL,
- AA::RangeTy *RangePtr) {
+Constant *
+AA::getInitialValueForObj(Attributor &A, const AbstractAttribute &QueryingAA,
+ Value &Obj, Type &Ty, const TargetLibraryInfo *TLI,
+ const DataLayout &DL, AA::RangeTy *RangePtr) {
if (isa<AllocaInst>(Obj))
return UndefValue::get(&Ty);
if (Constant *Init = getInitialValueOfAllocation(&Obj, TLI, &Ty))
@@ -242,12 +255,13 @@ Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty,
Constant *Initializer = nullptr;
if (A.hasGlobalVariableSimplificationCallback(*GV)) {
auto AssumedGV = A.getAssumedInitializerFromCallBack(
- *GV, /* const AbstractAttribute *AA */ nullptr, UsedAssumedInformation);
+ *GV, &QueryingAA, UsedAssumedInformation);
Initializer = *AssumedGV;
if (!Initializer)
return nullptr;
} else {
- if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer()))
+ if (!GV->hasLocalLinkage() &&
+ (GV->isInterposable() || !(GV->isConstant() && GV->hasInitializer())))
return nullptr;
if (!GV->hasInitializer())
return UndefValue::get(&Ty);
@@ -316,7 +330,7 @@ Value *AA::getWithType(Value &V, Type &Ty) {
if (C->getType()->isIntegerTy() && Ty.isIntegerTy())
return ConstantExpr::getTrunc(C, &Ty, /* OnlyIfReduced */ true);
if (C->getType()->isFloatingPointTy() && Ty.isFloatingPointTy())
- return ConstantExpr::getFPTrunc(C, &Ty, /* OnlyIfReduced */ true);
+ return ConstantFoldCastInstruction(Instruction::FPTrunc, C, &Ty);
}
}
return nullptr;
@@ -350,7 +364,7 @@ AA::combineOptionalValuesInAAValueLatice(const std::optional<Value *> &A,
template <bool IsLoad, typename Ty>
static bool getPotentialCopiesOfMemoryValue(
Attributor &A, Ty &I, SmallSetVector<Value *, 4> &PotentialCopies,
- SmallSetVector<Instruction *, 4> &PotentialValueOrigins,
+ SmallSetVector<Instruction *, 4> *PotentialValueOrigins,
const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation,
bool OnlyExact) {
LLVM_DEBUG(dbgs() << "Trying to determine the potential copies of " << I
@@ -361,8 +375,8 @@ static bool getPotentialCopiesOfMemoryValue(
// sure that we can find all of them. If we abort we want to avoid spurious
// dependences and potential copies in the provided container.
SmallVector<const AAPointerInfo *> PIs;
- SmallVector<Value *> NewCopies;
- SmallVector<Instruction *> NewCopyOrigins;
+ SmallSetVector<Value *, 8> NewCopies;
+ SmallSetVector<Instruction *, 8> NewCopyOrigins;
const auto *TLI =
A.getInfoCache().getTargetLibraryInfoForFunction(*I.getFunction());
@@ -425,6 +439,30 @@ static bool getPotentialCopiesOfMemoryValue(
return AdjV;
};
+ auto SkipCB = [&](const AAPointerInfo::Access &Acc) {
+ if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead()))
+ return true;
+ if (IsLoad) {
+ if (Acc.isWrittenValueYetUndetermined())
+ return true;
+ if (PotentialValueOrigins && !isa<AssumeInst>(Acc.getRemoteInst()))
+ return false;
+ if (!Acc.isWrittenValueUnknown())
+ if (Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue()))
+ if (NewCopies.count(V)) {
+ NewCopyOrigins.insert(Acc.getRemoteInst());
+ return true;
+ }
+ if (auto *SI = dyn_cast<StoreInst>(Acc.getRemoteInst()))
+ if (Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand()))
+ if (NewCopies.count(V)) {
+ NewCopyOrigins.insert(Acc.getRemoteInst());
+ return true;
+ }
+ }
+ return false;
+ };
+
auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) {
if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead()))
return true;
@@ -449,8 +487,9 @@ static bool getPotentialCopiesOfMemoryValue(
Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue());
if (!V)
return false;
- NewCopies.push_back(V);
- NewCopyOrigins.push_back(Acc.getRemoteInst());
+ NewCopies.insert(V);
+ if (PotentialValueOrigins)
+ NewCopyOrigins.insert(Acc.getRemoteInst());
return true;
}
auto *SI = dyn_cast<StoreInst>(Acc.getRemoteInst());
@@ -463,8 +502,9 @@ static bool getPotentialCopiesOfMemoryValue(
Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand());
if (!V)
return false;
- NewCopies.push_back(V);
- NewCopyOrigins.push_back(SI);
+ NewCopies.insert(V);
+ if (PotentialValueOrigins)
+ NewCopyOrigins.insert(SI);
} else {
assert(isa<StoreInst>(I) && "Expected load or store instruction only!");
auto *LI = dyn_cast<LoadInst>(Acc.getRemoteInst());
@@ -474,7 +514,7 @@ static bool getPotentialCopiesOfMemoryValue(
<< *Acc.getRemoteInst() << "\n";);
return false;
}
- NewCopies.push_back(Acc.getRemoteInst());
+ NewCopies.insert(Acc.getRemoteInst());
}
return true;
};
@@ -486,11 +526,11 @@ static bool getPotentialCopiesOfMemoryValue(
AA::RangeTy Range;
auto *PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj),
DepClassTy::NONE);
- if (!PI ||
- !PI->forallInterferingAccesses(A, QueryingAA, I,
- /* FindInterferingWrites */ IsLoad,
- /* FindInterferingReads */ !IsLoad,
- CheckAccess, HasBeenWrittenTo, Range)) {
+ if (!PI || !PI->forallInterferingAccesses(
+ A, QueryingAA, I,
+ /* FindInterferingWrites */ IsLoad,
+ /* FindInterferingReads */ !IsLoad, CheckAccess,
+ HasBeenWrittenTo, Range, SkipCB)) {
LLVM_DEBUG(
dbgs()
<< "Failed to verify all interfering accesses for underlying object: "
@@ -500,8 +540,8 @@ static bool getPotentialCopiesOfMemoryValue(
if (IsLoad && !HasBeenWrittenTo && !Range.isUnassigned()) {
const DataLayout &DL = A.getDataLayout();
- Value *InitialValue =
- AA::getInitialValueForObj(A, Obj, *I.getType(), TLI, DL, &Range);
+ Value *InitialValue = AA::getInitialValueForObj(
+ A, QueryingAA, Obj, *I.getType(), TLI, DL, &Range);
if (!InitialValue) {
LLVM_DEBUG(dbgs() << "Could not determine required initial value of "
"underlying object, abort!\n");
@@ -514,8 +554,9 @@ static bool getPotentialCopiesOfMemoryValue(
return false;
}
- NewCopies.push_back(InitialValue);
- NewCopyOrigins.push_back(nullptr);
+ NewCopies.insert(InitialValue);
+ if (PotentialValueOrigins)
+ NewCopyOrigins.insert(nullptr);
}
PIs.push_back(PI);
@@ -540,7 +581,8 @@ static bool getPotentialCopiesOfMemoryValue(
A.recordDependence(*PI, QueryingAA, DepClassTy::OPTIONAL);
}
PotentialCopies.insert(NewCopies.begin(), NewCopies.end());
- PotentialValueOrigins.insert(NewCopyOrigins.begin(), NewCopyOrigins.end());
+ if (PotentialValueOrigins)
+ PotentialValueOrigins->insert(NewCopyOrigins.begin(), NewCopyOrigins.end());
return true;
}
@@ -551,7 +593,7 @@ bool AA::getPotentiallyLoadedValues(
const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation,
bool OnlyExact) {
return getPotentialCopiesOfMemoryValue</* IsLoad */ true>(
- A, LI, PotentialValues, PotentialValueOrigins, QueryingAA,
+ A, LI, PotentialValues, &PotentialValueOrigins, QueryingAA,
UsedAssumedInformation, OnlyExact);
}
@@ -559,10 +601,9 @@ bool AA::getPotentialCopiesOfStoredValue(
Attributor &A, StoreInst &SI, SmallSetVector<Value *, 4> &PotentialCopies,
const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation,
bool OnlyExact) {
- SmallSetVector<Instruction *, 4> PotentialValueOrigins;
return getPotentialCopiesOfMemoryValue</* IsLoad */ false>(
- A, SI, PotentialCopies, PotentialValueOrigins, QueryingAA,
- UsedAssumedInformation, OnlyExact);
+ A, SI, PotentialCopies, nullptr, QueryingAA, UsedAssumedInformation,
+ OnlyExact);
}
static bool isAssumedReadOnlyOrReadNone(Attributor &A, const IRPosition &IRP,
@@ -723,7 +764,7 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
// Check if we can reach returns.
bool UsedAssumedInformation = false;
- if (A.checkForAllInstructions(ReturnInstCB, FromFn, QueryingAA,
+ if (A.checkForAllInstructions(ReturnInstCB, FromFn, &QueryingAA,
{Instruction::Ret}, UsedAssumedInformation)) {
LLVM_DEBUG(dbgs() << "[AA] No return is reachable, done\n");
continue;
@@ -1021,6 +1062,23 @@ ChangeStatus AbstractAttribute::update(Attributor &A) {
return HasChanged;
}
+Attributor::Attributor(SetVector<Function *> &Functions,
+ InformationCache &InfoCache,
+ AttributorConfig Configuration)
+ : Allocator(InfoCache.Allocator), Functions(Functions),
+ InfoCache(InfoCache), Configuration(Configuration) {
+ if (!isClosedWorldModule())
+ return;
+ for (Function *Fn : Functions)
+ if (Fn->hasAddressTaken(/*PutOffender=*/nullptr,
+ /*IgnoreCallbackUses=*/false,
+ /*IgnoreAssumeLikeCalls=*/true,
+ /*IgnoreLLVMUsed=*/true,
+ /*IgnoreARCAttachedCall=*/false,
+ /*IgnoreCastedDirectCall=*/true))
+ InfoCache.IndirectlyCallableFunctions.push_back(Fn);
+}
+
bool Attributor::getAttrsFromAssumes(const IRPosition &IRP,
Attribute::AttrKind AK,
SmallVectorImpl<Attribute> &Attrs) {
@@ -1053,8 +1111,7 @@ bool Attributor::getAttrsFromAssumes(const IRPosition &IRP,
template <typename DescTy>
ChangeStatus
-Attributor::updateAttrMap(const IRPosition &IRP,
- const ArrayRef<DescTy> &AttrDescs,
+Attributor::updateAttrMap(const IRPosition &IRP, ArrayRef<DescTy> AttrDescs,
function_ref<bool(const DescTy &, AttributeSet,
AttributeMask &, AttrBuilder &)>
CB) {
@@ -1161,9 +1218,8 @@ void Attributor::getAttrs(const IRPosition &IRP,
getAttrsFromAssumes(IRP, AK, Attrs);
}
-ChangeStatus
-Attributor::removeAttrs(const IRPosition &IRP,
- const ArrayRef<Attribute::AttrKind> &AttrKinds) {
+ChangeStatus Attributor::removeAttrs(const IRPosition &IRP,
+ ArrayRef<Attribute::AttrKind> AttrKinds) {
auto RemoveAttrCB = [&](const Attribute::AttrKind &Kind, AttributeSet AttrSet,
AttributeMask &AM, AttrBuilder &) {
if (!AttrSet.hasAttribute(Kind))
@@ -1174,8 +1230,21 @@ Attributor::removeAttrs(const IRPosition &IRP,
return updateAttrMap<Attribute::AttrKind>(IRP, AttrKinds, RemoveAttrCB);
}
+ChangeStatus Attributor::removeAttrs(const IRPosition &IRP,
+ ArrayRef<StringRef> Attrs) {
+ auto RemoveAttrCB = [&](StringRef Attr, AttributeSet AttrSet,
+ AttributeMask &AM, AttrBuilder &) -> bool {
+ if (!AttrSet.hasAttribute(Attr))
+ return false;
+ AM.addAttribute(Attr);
+ return true;
+ };
+
+ return updateAttrMap<StringRef>(IRP, Attrs, RemoveAttrCB);
+}
+
ChangeStatus Attributor::manifestAttrs(const IRPosition &IRP,
- const ArrayRef<Attribute> &Attrs,
+ ArrayRef<Attribute> Attrs,
bool ForceReplace) {
LLVMContext &Ctx = IRP.getAnchorValue().getContext();
auto AddAttrCB = [&](const Attribute &Attr, AttributeSet AttrSet,
@@ -1665,6 +1734,21 @@ bool Attributor::isAssumedDead(const BasicBlock &BB,
return false;
}
+bool Attributor::checkForAllCallees(
+ function_ref<bool(ArrayRef<const Function *>)> Pred,
+ const AbstractAttribute &QueryingAA, const CallBase &CB) {
+ if (const Function *Callee = dyn_cast<Function>(CB.getCalledOperand()))
+ return Pred(Callee);
+
+ const auto *CallEdgesAA = getAAFor<AACallEdges>(
+ QueryingAA, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
+ if (!CallEdgesAA || CallEdgesAA->hasUnknownCallee())
+ return false;
+
+ const auto &Callees = CallEdgesAA->getOptimisticEdges();
+ return Pred(Callees.getArrayRef());
+}
+
bool Attributor::checkForAllUses(
function_ref<bool(const Use &, bool &)> Pred,
const AbstractAttribute &QueryingAA, const Value &V,
@@ -1938,7 +2022,7 @@ bool Attributor::checkForAllReturnedValues(function_ref<bool(Value &)> Pred,
static bool checkForAllInstructionsImpl(
Attributor *A, InformationCache::OpcodeInstMapTy &OpcodeInstMap,
function_ref<bool(Instruction &)> Pred, const AbstractAttribute *QueryingAA,
- const AAIsDead *LivenessAA, const ArrayRef<unsigned> &Opcodes,
+ const AAIsDead *LivenessAA, ArrayRef<unsigned> Opcodes,
bool &UsedAssumedInformation, bool CheckBBLivenessOnly = false,
bool CheckPotentiallyDead = false) {
for (unsigned Opcode : Opcodes) {
@@ -1967,8 +2051,8 @@ static bool checkForAllInstructionsImpl(
bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
const Function *Fn,
- const AbstractAttribute &QueryingAA,
- const ArrayRef<unsigned> &Opcodes,
+ const AbstractAttribute *QueryingAA,
+ ArrayRef<unsigned> Opcodes,
bool &UsedAssumedInformation,
bool CheckBBLivenessOnly,
bool CheckPotentiallyDead) {
@@ -1978,12 +2062,12 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
const IRPosition &QueryIRP = IRPosition::function(*Fn);
const auto *LivenessAA =
- CheckPotentiallyDead
- ? nullptr
- : (getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE));
+ CheckPotentiallyDead && QueryingAA
+ ? (getAAFor<AAIsDead>(*QueryingAA, QueryIRP, DepClassTy::NONE))
+ : nullptr;
auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*Fn);
- if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, &QueryingAA,
+ if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, QueryingAA,
LivenessAA, Opcodes, UsedAssumedInformation,
CheckBBLivenessOnly, CheckPotentiallyDead))
return false;
@@ -1993,13 +2077,13 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
const AbstractAttribute &QueryingAA,
- const ArrayRef<unsigned> &Opcodes,
+ ArrayRef<unsigned> Opcodes,
bool &UsedAssumedInformation,
bool CheckBBLivenessOnly,
bool CheckPotentiallyDead) {
const IRPosition &IRP = QueryingAA.getIRPosition();
const Function *AssociatedFunction = IRP.getAssociatedFunction();
- return checkForAllInstructions(Pred, AssociatedFunction, QueryingAA, Opcodes,
+ return checkForAllInstructions(Pred, AssociatedFunction, &QueryingAA, Opcodes,
UsedAssumedInformation, CheckBBLivenessOnly,
CheckPotentiallyDead);
}
@@ -2964,6 +3048,18 @@ ChangeStatus Attributor::rewriteFunctionSignatures(
NewArgumentAttributes));
AttributeFuncs::updateMinLegalVectorWidthAttr(*NewFn, LargestVectorWidth);
+ // Remove argmem from the memory effects if we have no more pointer
+ // arguments, or they are readnone.
+ MemoryEffects ME = NewFn->getMemoryEffects();
+ int ArgNo = -1;
+ if (ME.doesAccessArgPointees() && all_of(NewArgumentTypes, [&](Type *T) {
+ ++ArgNo;
+ return !T->isPtrOrPtrVectorTy() ||
+ NewFn->hasParamAttribute(ArgNo, Attribute::ReadNone);
+ })) {
+ NewFn->setMemoryEffects(ME - MemoryEffects::argMemOnly());
+ }
+
// Since we have now created the new function, splice the body of the old
// function right into the new function, leaving the old rotting hulk of the
// function empty.
@@ -3203,6 +3299,12 @@ InformationCache::FunctionInfo::~FunctionInfo() {
It.getSecond()->~InstructionVectorTy();
}
+const ArrayRef<Function *>
+InformationCache::getIndirectlyCallableFunctions(Attributor &A) const {
+ assert(A.isClosedWorldModule() && "Cannot see all indirect callees!");
+ return IndirectlyCallableFunctions;
+}
+
void Attributor::recordDependence(const AbstractAttribute &FromAA,
const AbstractAttribute &ToAA,
DepClassTy DepClass) {
@@ -3236,9 +3338,10 @@ void Attributor::checkAndQueryIRAttr(const IRPosition &IRP,
AttributeSet Attrs) {
bool IsKnown;
if (!Attrs.hasAttribute(AK))
- if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE,
- IsKnown))
- getOrCreateAAFor<AAType>(IRP);
+ if (!Configuration.Allowed || Configuration.Allowed->count(&AAType::ID))
+ if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE,
+ IsKnown))
+ getOrCreateAAFor<AAType>(IRP);
}
void Attributor::identifyDefaultAbstractAttributes(Function &F) {
@@ -3285,6 +3388,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
// Every function might be "will-return".
checkAndQueryIRAttr<Attribute::WillReturn, AAWillReturn>(FPos, FnAttrs);
+ // Every function might be marked "nosync"
+ checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs);
+
// Everything that is visible from the outside (=function, argument, return
// positions), cannot be changed if the function is not IPO amendable. We can
// however analyse the code inside.
@@ -3293,9 +3399,6 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
// Every function can be nounwind.
checkAndQueryIRAttr<Attribute::NoUnwind, AANoUnwind>(FPos, FnAttrs);
- // Every function might be marked "nosync"
- checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs);
-
// Every function might be "no-return".
checkAndQueryIRAttr<Attribute::NoReturn, AANoReturn>(FPos, FnAttrs);
@@ -3315,6 +3418,14 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
// Every function can track active assumptions.
getOrCreateAAFor<AAAssumptionInfo>(FPos);
+ // If we're not using a dynamic mode for float, there's nothing worthwhile
+ // to infer. This misses the edge case denormal-fp-math="dynamic" and
+ // denormal-fp-math-f32=something, but that likely has no real world use.
+ DenormalMode Mode = F.getDenormalMode(APFloat::IEEEsingle());
+ if (Mode.Input == DenormalMode::Dynamic ||
+ Mode.Output == DenormalMode::Dynamic)
+ getOrCreateAAFor<AADenormalFPMath>(FPos);
+
// Return attributes are only appropriate if the return type is non void.
Type *ReturnType = F.getReturnType();
if (!ReturnType->isVoidTy()) {
@@ -3420,8 +3531,10 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
Function *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
// TODO: Even if the callee is not known now we might be able to simplify
// the call/callee.
- if (!Callee)
+ if (!Callee) {
+ getOrCreateAAFor<AAIndirectCallInfo>(CBFnPos);
return true;
+ }
// Every call site can track active assumptions.
getOrCreateAAFor<AAAssumptionInfo>(CBFnPos);
@@ -3498,14 +3611,13 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
};
auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
- bool Success;
+ [[maybe_unused]] bool Success;
bool UsedAssumedInformation = false;
Success = checkForAllInstructionsImpl(
nullptr, OpcodeInstMap, CallSitePred, nullptr, nullptr,
{(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
(unsigned)Instruction::Call},
UsedAssumedInformation);
- (void)Success;
assert(Success && "Expected the check call to be successful!");
auto LoadStorePred = [&](Instruction &I) -> bool {
@@ -3531,10 +3643,26 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
nullptr, OpcodeInstMap, LoadStorePred, nullptr, nullptr,
{(unsigned)Instruction::Load, (unsigned)Instruction::Store},
UsedAssumedInformation);
- (void)Success;
+ assert(Success && "Expected the check call to be successful!");
+
+ // AllocaInstPredicate
+ auto AAAllocationInfoPred = [&](Instruction &I) -> bool {
+ getOrCreateAAFor<AAAllocationInfo>(IRPosition::value(I));
+ return true;
+ };
+
+ Success = checkForAllInstructionsImpl(
+ nullptr, OpcodeInstMap, AAAllocationInfoPred, nullptr, nullptr,
+ {(unsigned)Instruction::Alloca}, UsedAssumedInformation);
assert(Success && "Expected the check call to be successful!");
}
+bool Attributor::isClosedWorldModule() const {
+ if (CloseWorldAssumption.getNumOccurrences())
+ return CloseWorldAssumption;
+ return isModulePass() && Configuration.IsClosedWorldModule;
+}
+
/// Helpers to ease debugging through output streams and print calls.
///
///{
@@ -3696,6 +3824,26 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache,
AttributorConfig AC(CGUpdater);
AC.IsModulePass = IsModulePass;
AC.DeleteFns = DeleteFns;
+
+ /// Tracking callback for specialization of indirect calls.
+ DenseMap<CallBase *, std::unique_ptr<SmallPtrSet<Function *, 8>>>
+ IndirectCalleeTrackingMap;
+ if (MaxSpecializationPerCB.getNumOccurrences()) {
+ AC.IndirectCalleeSpecializationCallback =
+ [&](Attributor &, const AbstractAttribute &AA, CallBase &CB,
+ Function &Callee) {
+ if (MaxSpecializationPerCB == 0)
+ return false;
+ auto &Set = IndirectCalleeTrackingMap[&CB];
+ if (!Set)
+ Set = std::make_unique<SmallPtrSet<Function *, 8>>();
+ if (Set->size() >= MaxSpecializationPerCB)
+ return Set->contains(&Callee);
+ Set->insert(&Callee);
+ return true;
+ };
+ }
+
Attributor A(Functions, InfoCache, AC);
// Create shallow wrappers for all functions that are not IPO amendable
@@ -3759,6 +3907,88 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache,
return Changed == ChangeStatus::CHANGED;
}
+static bool runAttributorLightOnFunctions(InformationCache &InfoCache,
+ SetVector<Function *> &Functions,
+ AnalysisGetter &AG,
+ CallGraphUpdater &CGUpdater,
+ FunctionAnalysisManager &FAM,
+ bool IsModulePass) {
+ if (Functions.empty())
+ return false;
+
+ LLVM_DEBUG({
+ dbgs() << "[AttributorLight] Run on module with " << Functions.size()
+ << " functions:\n";
+ for (Function *Fn : Functions)
+ dbgs() << " - " << Fn->getName() << "\n";
+ });
+
+ // Create an Attributor and initially empty information cache that is filled
+ // while we identify default attribute opportunities.
+ AttributorConfig AC(CGUpdater);
+ AC.IsModulePass = IsModulePass;
+ AC.DeleteFns = false;
+ DenseSet<const char *> Allowed(
+ {&AAWillReturn::ID, &AANoUnwind::ID, &AANoRecurse::ID, &AANoSync::ID,
+ &AANoFree::ID, &AANoReturn::ID, &AAMemoryLocation::ID,
+ &AAMemoryBehavior::ID, &AAUnderlyingObjects::ID, &AANoCapture::ID,
+ &AAInterFnReachability::ID, &AAIntraFnReachability::ID, &AACallEdges::ID,
+ &AANoFPClass::ID, &AAMustProgress::ID, &AANonNull::ID});
+ AC.Allowed = &Allowed;
+ AC.UseLiveness = false;
+
+ Attributor A(Functions, InfoCache, AC);
+
+ for (Function *F : Functions) {
+ if (F->hasExactDefinition())
+ NumFnWithExactDefinition++;
+ else
+ NumFnWithoutExactDefinition++;
+
+ // We look at internal functions only on-demand but if any use is not a
+ // direct call or outside the current set of analyzed functions, we have
+ // to do it eagerly.
+ if (F->hasLocalLinkage()) {
+ if (llvm::all_of(F->uses(), [&Functions](const Use &U) {
+ const auto *CB = dyn_cast<CallBase>(U.getUser());
+ return CB && CB->isCallee(&U) &&
+ Functions.count(const_cast<Function *>(CB->getCaller()));
+ }))
+ continue;
+ }
+
+ // Populate the Attributor with abstract attribute opportunities in the
+ // function and the information cache with IR information.
+ A.identifyDefaultAbstractAttributes(*F);
+ }
+
+ ChangeStatus Changed = A.run();
+
+ if (Changed == ChangeStatus::CHANGED) {
+ // Invalidate analyses for modified functions so that we don't have to
+ // invalidate all analyses for all functions in this SCC.
+ PreservedAnalyses FuncPA;
+ // We haven't changed the CFG for modified functions.
+ FuncPA.preserveSet<CFGAnalyses>();
+ for (Function *Changed : A.getModifiedFunctions()) {
+ FAM.invalidate(*Changed, FuncPA);
+ // Also invalidate any direct callers of changed functions since analyses
+ // may care about attributes of direct callees. For example, MemorySSA
+ // cares about whether or not a call's callee modifies memory and queries
+ // that through function attributes.
+ for (auto *U : Changed->users()) {
+ if (auto *Call = dyn_cast<CallBase>(U)) {
+ if (Call->getCalledFunction() == Changed)
+ FAM.invalidate(*Call->getFunction(), FuncPA);
+ }
+ }
+ }
+ }
+ LLVM_DEBUG(dbgs() << "[Attributor] Done with " << Functions.size()
+ << " functions, result: " << Changed << ".\n");
+ return Changed == ChangeStatus::CHANGED;
+}
+
void AADepGraph::viewGraph() { llvm::ViewGraph(this, "Dependency Graph"); }
void AADepGraph::dumpGraph() {
@@ -3839,6 +4069,62 @@ PreservedAnalyses AttributorCGSCCPass::run(LazyCallGraph::SCC &C,
return PreservedAnalyses::all();
}
+PreservedAnalyses AttributorLightPass::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ AnalysisGetter AG(FAM, /* CachedOnly */ true);
+
+ SetVector<Function *> Functions;
+ for (Function &F : M)
+ Functions.insert(&F);
+
+ CallGraphUpdater CGUpdater;
+ BumpPtrAllocator Allocator;
+ InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr);
+ if (runAttributorLightOnFunctions(InfoCache, Functions, AG, CGUpdater, FAM,
+ /* IsModulePass */ true)) {
+ PreservedAnalyses PA;
+ // We have not added or removed functions.
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ // We already invalidated all relevant function analyses above.
+ PA.preserveSet<AllAnalysesOn<Function>>();
+ return PA;
+ }
+ return PreservedAnalyses::all();
+}
+
+PreservedAnalyses AttributorLightCGSCCPass::run(LazyCallGraph::SCC &C,
+ CGSCCAnalysisManager &AM,
+ LazyCallGraph &CG,
+ CGSCCUpdateResult &UR) {
+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
+ AnalysisGetter AG(FAM);
+
+ SetVector<Function *> Functions;
+ for (LazyCallGraph::Node &N : C)
+ Functions.insert(&N.getFunction());
+
+ if (Functions.empty())
+ return PreservedAnalyses::all();
+
+ Module &M = *Functions.back()->getParent();
+ CallGraphUpdater CGUpdater;
+ CGUpdater.initialize(CG, C, AM, UR);
+ BumpPtrAllocator Allocator;
+ InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions);
+ if (runAttributorLightOnFunctions(InfoCache, Functions, AG, CGUpdater, FAM,
+ /* IsModulePass */ false)) {
+ PreservedAnalyses PA;
+ // We have not added or removed functions.
+ PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+ // We already invalidated all relevant function analyses above.
+ PA.preserveSet<AllAnalysesOn<Function>>();
+ return PA;
+ }
+ return PreservedAnalyses::all();
+}
namespace llvm {
template <> struct GraphTraits<AADepGraphNode *> {
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 3a9a89d61355..889ebd7438bd 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -55,6 +55,7 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
@@ -64,12 +65,16 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <numeric>
#include <optional>
+#include <string>
using namespace llvm;
@@ -188,6 +193,10 @@ PIPE_OPERATOR(AAPointerInfo)
PIPE_OPERATOR(AAAssumptionInfo)
PIPE_OPERATOR(AAUnderlyingObjects)
PIPE_OPERATOR(AAAddressSpace)
+PIPE_OPERATOR(AAAllocationInfo)
+PIPE_OPERATOR(AAIndirectCallInfo)
+PIPE_OPERATOR(AAGlobalValueInfo)
+PIPE_OPERATOR(AADenormalFPMath)
#undef PIPE_OPERATOR
@@ -313,7 +322,6 @@ static Value *constructPointer(Type *ResTy, Type *PtrElemTy, Value *Ptr,
// If an offset is left we use byte-wise adjustment.
if (IntOffset != 0) {
- Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy());
Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(IntOffset),
GEPName + ".b" + Twine(IntOffset.getZExtValue()));
}
@@ -377,7 +385,7 @@ getMinimalBaseOfPointer(Attributor &A, const AbstractAttribute &QueryingAA,
/// Clamp the information known for all returned values of a function
/// (identified by \p QueryingAA) into \p S.
template <typename AAType, typename StateType = typename AAType::StateType,
- Attribute::AttrKind IRAttributeKind = Attribute::None,
+ Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind,
bool RecurseForSelectAndPHI = true>
static void clampReturnedValueStates(
Attributor &A, const AAType &QueryingAA, StateType &S,
@@ -400,7 +408,7 @@ static void clampReturnedValueStates(
auto CheckReturnValue = [&](Value &RV) -> bool {
const IRPosition &RVPos = IRPosition::value(RV, CBContext);
// If possible, use the hasAssumedIRAttr interface.
- if (IRAttributeKind != Attribute::None) {
+ if (Attribute::isEnumAttrKind(IRAttributeKind)) {
bool IsKnown;
return AA::hasAssumedIRAttr<IRAttributeKind>(
A, &QueryingAA, RVPos, DepClassTy::REQUIRED, IsKnown);
@@ -434,7 +442,7 @@ namespace {
template <typename AAType, typename BaseType,
typename StateType = typename BaseType::StateType,
bool PropagateCallBaseContext = false,
- Attribute::AttrKind IRAttributeKind = Attribute::None,
+ Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind,
bool RecurseForSelectAndPHI = true>
struct AAReturnedFromReturnedValues : public BaseType {
AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A)
@@ -455,7 +463,7 @@ struct AAReturnedFromReturnedValues : public BaseType {
/// Clamp the information known at all call sites for a given argument
/// (identified by \p QueryingAA) into \p S.
template <typename AAType, typename StateType = typename AAType::StateType,
- Attribute::AttrKind IRAttributeKind = Attribute::None>
+ Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
StateType &S) {
LLVM_DEBUG(dbgs() << "[Attributor] Clamp call site argument states for "
@@ -480,7 +488,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
return false;
// If possible, use the hasAssumedIRAttr interface.
- if (IRAttributeKind != Attribute::None) {
+ if (Attribute::isEnumAttrKind(IRAttributeKind)) {
bool IsKnown;
return AA::hasAssumedIRAttr<IRAttributeKind>(
A, &QueryingAA, ACSArgPos, DepClassTy::REQUIRED, IsKnown);
@@ -514,7 +522,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
/// context.
template <typename AAType, typename BaseType,
typename StateType = typename AAType::StateType,
- Attribute::AttrKind IRAttributeKind = Attribute::None>
+ Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
bool getArgumentStateFromCallBaseContext(Attributor &A,
BaseType &QueryingAttribute,
IRPosition &Pos, StateType &State) {
@@ -529,7 +537,7 @@ bool getArgumentStateFromCallBaseContext(Attributor &A,
const IRPosition CBArgPos = IRPosition::callsite_argument(*CBContext, ArgNo);
// If possible, use the hasAssumedIRAttr interface.
- if (IRAttributeKind != Attribute::None) {
+ if (Attribute::isEnumAttrKind(IRAttributeKind)) {
bool IsKnown;
return AA::hasAssumedIRAttr<IRAttributeKind>(
A, &QueryingAttribute, CBArgPos, DepClassTy::REQUIRED, IsKnown);
@@ -555,7 +563,7 @@ bool getArgumentStateFromCallBaseContext(Attributor &A,
template <typename AAType, typename BaseType,
typename StateType = typename AAType::StateType,
bool BridgeCallBaseContext = false,
- Attribute::AttrKind IRAttributeKind = Attribute::None>
+ Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
struct AAArgumentFromCallSiteArguments : public BaseType {
AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A)
: BaseType(IRP, A) {}
@@ -585,45 +593,55 @@ struct AAArgumentFromCallSiteArguments : public BaseType {
template <typename AAType, typename BaseType,
typename StateType = typename BaseType::StateType,
bool IntroduceCallBaseContext = false,
- Attribute::AttrKind IRAttributeKind = Attribute::None>
-struct AACallSiteReturnedFromReturned : public BaseType {
- AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A)
- : BaseType(IRP, A) {}
+ Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
+struct AACalleeToCallSite : public BaseType {
+ AACalleeToCallSite(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {}
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
- assert(this->getIRPosition().getPositionKind() ==
- IRPosition::IRP_CALL_SITE_RETURNED &&
- "Can only wrap function returned positions for call site returned "
- "positions!");
+ auto IRPKind = this->getIRPosition().getPositionKind();
+ assert((IRPKind == IRPosition::IRP_CALL_SITE_RETURNED ||
+ IRPKind == IRPosition::IRP_CALL_SITE) &&
+ "Can only wrap function returned positions for call site "
+ "returned positions!");
auto &S = this->getState();
- const Function *AssociatedFunction =
- this->getIRPosition().getAssociatedFunction();
- if (!AssociatedFunction)
- return S.indicatePessimisticFixpoint();
-
- CallBase &CBContext = cast<CallBase>(this->getAnchorValue());
+ CallBase &CB = cast<CallBase>(this->getAnchorValue());
if (IntroduceCallBaseContext)
- LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:"
- << CBContext << "\n");
-
- IRPosition FnPos = IRPosition::returned(
- *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr);
+ LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" << CB
+ << "\n");
- // If possible, use the hasAssumedIRAttr interface.
- if (IRAttributeKind != Attribute::None) {
- bool IsKnown;
- if (!AA::hasAssumedIRAttr<IRAttributeKind>(A, this, FnPos,
- DepClassTy::REQUIRED, IsKnown))
- return S.indicatePessimisticFixpoint();
- return ChangeStatus::UNCHANGED;
- }
+ ChangeStatus Changed = ChangeStatus::UNCHANGED;
+ auto CalleePred = [&](ArrayRef<const Function *> Callees) {
+ for (const Function *Callee : Callees) {
+ IRPosition FnPos =
+ IRPKind == llvm::IRPosition::IRP_CALL_SITE_RETURNED
+ ? IRPosition::returned(*Callee,
+ IntroduceCallBaseContext ? &CB : nullptr)
+ : IRPosition::function(
+ *Callee, IntroduceCallBaseContext ? &CB : nullptr);
+ // If possible, use the hasAssumedIRAttr interface.
+ if (Attribute::isEnumAttrKind(IRAttributeKind)) {
+ bool IsKnown;
+ if (!AA::hasAssumedIRAttr<IRAttributeKind>(
+ A, this, FnPos, DepClassTy::REQUIRED, IsKnown))
+ return false;
+ continue;
+ }
- const AAType *AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
- if (!AA)
+ const AAType *AA =
+ A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
+ if (!AA)
+ return false;
+ Changed |= clampStateAndIndicateChange(S, AA->getState());
+ if (S.isAtFixpoint())
+ return S.isValidState();
+ }
+ return true;
+ };
+ if (!A.checkForAllCallees(CalleePred, *this, CB))
return S.indicatePessimisticFixpoint();
- return clampStateAndIndicateChange(S, AA->getState());
+ return Changed;
}
};
@@ -865,11 +883,9 @@ struct AA::PointerInfo::State : public AbstractState {
AAPointerInfo::AccessKind Kind, Type *Ty,
Instruction *RemoteI = nullptr);
- using OffsetBinsTy = DenseMap<RangeTy, SmallSet<unsigned, 4>>;
-
- using const_bin_iterator = OffsetBinsTy::const_iterator;
- const_bin_iterator begin() const { return OffsetBins.begin(); }
- const_bin_iterator end() const { return OffsetBins.end(); }
+ AAPointerInfo::const_bin_iterator begin() const { return OffsetBins.begin(); }
+ AAPointerInfo::const_bin_iterator end() const { return OffsetBins.end(); }
+ int64_t numOffsetBins() const { return OffsetBins.size(); }
const AAPointerInfo::Access &getAccess(unsigned Index) const {
return AccessList[Index];
@@ -889,7 +905,7 @@ protected:
// are all combined into a single Access object. This may result in loss of
// information in RangeTy in the Access object.
SmallVector<AAPointerInfo::Access> AccessList;
- OffsetBinsTy OffsetBins;
+ AAPointerInfo::OffsetBinsTy OffsetBins;
DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap;
/// See AAPointerInfo::forallInterferingAccesses.
@@ -1093,6 +1109,12 @@ struct AAPointerInfoImpl
return AAPointerInfo::manifest(A);
}
+ virtual const_bin_iterator begin() const override { return State::begin(); }
+ virtual const_bin_iterator end() const override { return State::end(); }
+ virtual int64_t numOffsetBins() const override {
+ return State::numOffsetBins();
+ }
+
bool forallInterferingAccesses(
AA::RangeTy Range,
function_ref<bool(const AAPointerInfo::Access &, bool)> CB)
@@ -1104,7 +1126,8 @@ struct AAPointerInfoImpl
Attributor &A, const AbstractAttribute &QueryingAA, Instruction &I,
bool FindInterferingWrites, bool FindInterferingReads,
function_ref<bool(const Access &, bool)> UserCB, bool &HasBeenWrittenTo,
- AA::RangeTy &Range) const override {
+ AA::RangeTy &Range,
+ function_ref<bool(const Access &)> SkipCB) const override {
HasBeenWrittenTo = false;
SmallPtrSet<const Access *, 8> DominatingWrites;
@@ -1183,6 +1206,11 @@ struct AAPointerInfoImpl
A, this, IRPosition::function(Scope), DepClassTy::OPTIONAL,
IsKnownNoRecurse);
+ // TODO: Use reaching kernels from AAKernelInfo (or move it to
+ // AAExecutionDomain) such that we allow scopes other than kernels as long
+ // as the reaching kernels are disjoint.
+ bool InstInKernel = Scope.hasFnAttribute("kernel");
+ bool ObjHasKernelLifetime = false;
const bool UseDominanceReasoning =
FindInterferingWrites && IsKnownNoRecurse;
const DominatorTree *DT =
@@ -1215,6 +1243,7 @@ struct AAPointerInfoImpl
// If the alloca containing function is not recursive the alloca
// must be dead in the callee.
const Function *AIFn = AI->getFunction();
+ ObjHasKernelLifetime = AIFn->hasFnAttribute("kernel");
bool IsKnownNoRecurse;
if (AA::hasAssumedIRAttr<Attribute::NoRecurse>(
A, this, IRPosition::function(*AIFn), DepClassTy::OPTIONAL,
@@ -1224,7 +1253,8 @@ struct AAPointerInfoImpl
} else if (auto *GV = dyn_cast<GlobalValue>(&getAssociatedValue())) {
// If the global has kernel lifetime we can stop if we reach a kernel
// as it is "dead" in the (unknown) callees.
- if (HasKernelLifetime(GV, *GV->getParent()))
+ ObjHasKernelLifetime = HasKernelLifetime(GV, *GV->getParent());
+ if (ObjHasKernelLifetime)
IsLiveInCalleeCB = [](const Function &Fn) {
return !Fn.hasFnAttribute("kernel");
};
@@ -1235,6 +1265,15 @@ struct AAPointerInfoImpl
AA::InstExclusionSetTy ExclusionSet;
auto AccessCB = [&](const Access &Acc, bool Exact) {
+ Function *AccScope = Acc.getRemoteInst()->getFunction();
+ bool AccInSameScope = AccScope == &Scope;
+
+ // If the object has kernel lifetime we can ignore accesses only reachable
+ // by other kernels. For now we only skip accesses *in* other kernels.
+ if (InstInKernel && ObjHasKernelLifetime && !AccInSameScope &&
+ AccScope->hasFnAttribute("kernel"))
+ return true;
+
if (Exact && Acc.isMustAccess() && Acc.getRemoteInst() != &I) {
if (Acc.isWrite() || (isa<LoadInst>(I) && Acc.isWriteOrAssumption()))
ExclusionSet.insert(Acc.getRemoteInst());
@@ -1245,8 +1284,7 @@ struct AAPointerInfoImpl
return true;
bool Dominates = FindInterferingWrites && DT && Exact &&
- Acc.isMustAccess() &&
- (Acc.getRemoteInst()->getFunction() == &Scope) &&
+ Acc.isMustAccess() && AccInSameScope &&
DT->dominates(Acc.getRemoteInst(), &I);
if (Dominates)
DominatingWrites.insert(&Acc);
@@ -1276,6 +1314,8 @@ struct AAPointerInfoImpl
// Helper to determine if we can skip a specific write access.
auto CanSkipAccess = [&](const Access &Acc, bool Exact) {
+ if (SkipCB && SkipCB(Acc))
+ return true;
if (!CanIgnoreThreading(Acc))
return false;
@@ -1817,9 +1857,14 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
LLVM_DEBUG(dbgs() << "[AAPointerInfo] Assumption found "
<< *Assumption.second << ": " << *LoadI
<< " == " << *Assumption.first << "\n");
-
+ bool UsedAssumedInformation = false;
+ std::optional<Value *> Content = nullptr;
+ if (Assumption.first)
+ Content =
+ A.getAssumedSimplified(*Assumption.first, *this,
+ UsedAssumedInformation, AA::Interprocedural);
return handleAccess(
- A, *Assumption.second, Assumption.first, AccessKind::AK_ASSUMPTION,
+ A, *Assumption.second, Content, AccessKind::AK_ASSUMPTION,
OffsetInfoMap[CurPtr].Offsets, Changed, *LoadI->getType());
}
@@ -2083,24 +2128,10 @@ struct AANoUnwindFunction final : public AANoUnwindImpl {
};
/// NoUnwind attribute deduction for a call sites.
-struct AANoUnwindCallSite final : AANoUnwindImpl {
+struct AANoUnwindCallSite final
+ : AACalleeToCallSite<AANoUnwind, AANoUnwindImpl> {
AANoUnwindCallSite(const IRPosition &IRP, Attributor &A)
- : AANoUnwindImpl(IRP, A) {}
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::function(*F);
- bool IsKnownNoUnwind;
- if (AA::hasAssumedIRAttr<Attribute::NoUnwind>(
- A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoUnwind))
- return ChangeStatus::UNCHANGED;
- return indicatePessimisticFixpoint();
- }
+ : AACalleeToCallSite<AANoUnwind, AANoUnwindImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); }
@@ -2200,8 +2231,15 @@ ChangeStatus AANoSyncImpl::updateImpl(Attributor &A) {
if (I.mayReadOrWriteMemory())
return true;
+ bool IsKnown;
+ CallBase &CB = cast<CallBase>(I);
+ if (AA::hasAssumedIRAttr<Attribute::NoSync>(
+ A, this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL,
+ IsKnown))
+ return true;
+
// non-convergent and readnone imply nosync.
- return !cast<CallBase>(I).isConvergent();
+ return !CB.isConvergent();
};
bool UsedAssumedInformation = false;
@@ -2223,24 +2261,9 @@ struct AANoSyncFunction final : public AANoSyncImpl {
};
/// NoSync attribute deduction for a call sites.
-struct AANoSyncCallSite final : AANoSyncImpl {
+struct AANoSyncCallSite final : AACalleeToCallSite<AANoSync, AANoSyncImpl> {
AANoSyncCallSite(const IRPosition &IRP, Attributor &A)
- : AANoSyncImpl(IRP, A) {}
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::function(*F);
- bool IsKnownNoSycn;
- if (AA::hasAssumedIRAttr<Attribute::NoSync>(
- A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoSycn))
- return ChangeStatus::UNCHANGED;
- return indicatePessimisticFixpoint();
- }
+ : AACalleeToCallSite<AANoSync, AANoSyncImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nosync); }
@@ -2292,24 +2315,9 @@ struct AANoFreeFunction final : public AANoFreeImpl {
};
/// NoFree attribute deduction for a call sites.
-struct AANoFreeCallSite final : AANoFreeImpl {
+struct AANoFreeCallSite final : AACalleeToCallSite<AANoFree, AANoFreeImpl> {
AANoFreeCallSite(const IRPosition &IRP, Attributor &A)
- : AANoFreeImpl(IRP, A) {}
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::function(*F);
- bool IsKnown;
- if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, FnPos,
- DepClassTy::REQUIRED, IsKnown))
- return ChangeStatus::UNCHANGED;
- return indicatePessimisticFixpoint();
- }
+ : AACalleeToCallSite<AANoFree, AANoFreeImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nofree); }
@@ -2450,9 +2458,6 @@ bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP,
if (A.hasAttr(IRP, AttrKinds, IgnoreSubsumingPositions, Attribute::NonNull))
return true;
- if (IRP.getPositionKind() == IRP_RETURNED)
- return false;
-
DominatorTree *DT = nullptr;
AssumptionCache *AC = nullptr;
InformationCache &InfoCache = A.getInfoCache();
@@ -2463,9 +2468,27 @@ bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP,
}
}
- if (!isKnownNonZero(&IRP.getAssociatedValue(), A.getDataLayout(), 0, AC,
- IRP.getCtxI(), DT))
+ SmallVector<AA::ValueAndContext> Worklist;
+ if (IRP.getPositionKind() != IRP_RETURNED) {
+ Worklist.push_back({IRP.getAssociatedValue(), IRP.getCtxI()});
+ } else {
+ bool UsedAssumedInformation = false;
+ if (!A.checkForAllInstructions(
+ [&](Instruction &I) {
+ Worklist.push_back({*cast<ReturnInst>(I).getReturnValue(), &I});
+ return true;
+ },
+ IRP.getAssociatedFunction(), nullptr, {Instruction::Ret},
+ UsedAssumedInformation))
+ return false;
+ }
+
+ if (llvm::any_of(Worklist, [&](AA::ValueAndContext VAC) {
+ return !isKnownNonZero(VAC.getValue(), A.getDataLayout(), 0, AC,
+ VAC.getCtxI(), DT);
+ }))
return false;
+
A.manifestAttrs(IRP, {Attribute::get(IRP.getAnchorValue().getContext(),
Attribute::NonNull)});
return true;
@@ -2529,7 +2552,8 @@ static int64_t getKnownNonNullAndDerefBytesForUse(
}
std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I);
- if (!Loc || Loc->Ptr != UseV || !Loc->Size.isPrecise() || I->isVolatile())
+ if (!Loc || Loc->Ptr != UseV || !Loc->Size.isPrecise() ||
+ Loc->Size.isScalable() || I->isVolatile())
return 0;
int64_t Offset;
@@ -2610,6 +2634,23 @@ struct AANonNullFloating : public AANonNullImpl {
Values.size() != 1 || Values.front().getValue() != AssociatedValue;
if (!Stripped) {
+ bool IsKnown;
+ if (auto *PHI = dyn_cast<PHINode>(AssociatedValue))
+ if (llvm::all_of(PHI->incoming_values(), [&](Value *Op) {
+ return AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, IRPosition::value(*Op), DepClassTy::OPTIONAL,
+ IsKnown);
+ }))
+ return ChangeStatus::UNCHANGED;
+ if (auto *Select = dyn_cast<SelectInst>(AssociatedValue))
+ if (AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, IRPosition::value(*Select->getFalseValue()),
+ DepClassTy::OPTIONAL, IsKnown) &&
+ AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, IRPosition::value(*Select->getTrueValue()),
+ DepClassTy::OPTIONAL, IsKnown))
+ return ChangeStatus::UNCHANGED;
+
// If we haven't stripped anything we might still be able to use a
// different AA, but only if the IRP changes. Effectively when we
// interpret this not as a call site value but as a floating/argument
@@ -2634,10 +2675,11 @@ struct AANonNullFloating : public AANonNullImpl {
/// NonNull attribute for function return value.
struct AANonNullReturned final
: AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType,
- false, AANonNull::IRAttributeKind> {
+ false, AANonNull::IRAttributeKind, false> {
AANonNullReturned(const IRPosition &IRP, Attributor &A)
: AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType,
- false, Attribute::NonNull>(IRP, A) {}
+ false, Attribute::NonNull, false>(IRP, A) {
+ }
/// See AbstractAttribute::getAsStr().
const std::string getAsStr(Attributor *A) const override {
@@ -2650,13 +2692,9 @@ struct AANonNullReturned final
/// NonNull attribute for function argument.
struct AANonNullArgument final
- : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl,
- AANonNull::StateType, false,
- AANonNull::IRAttributeKind> {
+ : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl> {
AANonNullArgument(const IRPosition &IRP, Attributor &A)
- : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl,
- AANonNull::StateType, false,
- AANonNull::IRAttributeKind>(IRP, A) {}
+ : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nonnull) }
@@ -2672,13 +2710,9 @@ struct AANonNullCallSiteArgument final : AANonNullFloating {
/// NonNull attribute for a call site return position.
struct AANonNullCallSiteReturned final
- : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl,
- AANonNull::StateType, false,
- AANonNull::IRAttributeKind> {
+ : AACalleeToCallSite<AANonNull, AANonNullImpl> {
AANonNullCallSiteReturned(const IRPosition &IRP, Attributor &A)
- : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl,
- AANonNull::StateType, false,
- AANonNull::IRAttributeKind>(IRP, A) {}
+ : AACalleeToCallSite<AANonNull, AANonNullImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) }
@@ -2830,24 +2864,10 @@ struct AANoRecurseFunction final : AANoRecurseImpl {
};
/// NoRecurse attribute deduction for a call sites.
-struct AANoRecurseCallSite final : AANoRecurseImpl {
+struct AANoRecurseCallSite final
+ : AACalleeToCallSite<AANoRecurse, AANoRecurseImpl> {
AANoRecurseCallSite(const IRPosition &IRP, Attributor &A)
- : AANoRecurseImpl(IRP, A) {}
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::function(*F);
- bool IsKnownNoRecurse;
- if (!AA::hasAssumedIRAttr<Attribute::NoRecurse>(
- A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoRecurse))
- return indicatePessimisticFixpoint();
- return ChangeStatus::UNCHANGED;
- }
+ : AACalleeToCallSite<AANoRecurse, AANoRecurseImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(norecurse); }
@@ -3355,26 +3375,17 @@ struct AAWillReturnFunction final : AAWillReturnImpl {
};
/// WillReturn attribute deduction for a call sites.
-struct AAWillReturnCallSite final : AAWillReturnImpl {
+struct AAWillReturnCallSite final
+ : AACalleeToCallSite<AAWillReturn, AAWillReturnImpl> {
AAWillReturnCallSite(const IRPosition &IRP, Attributor &A)
- : AAWillReturnImpl(IRP, A) {}
+ : AACalleeToCallSite<AAWillReturn, AAWillReturnImpl>(IRP, A) {}
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ false))
return ChangeStatus::UNCHANGED;
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::function(*F);
- bool IsKnown;
- if (AA::hasAssumedIRAttr<Attribute::WillReturn>(
- A, this, FnPos, DepClassTy::REQUIRED, IsKnown))
- return ChangeStatus::UNCHANGED;
- return indicatePessimisticFixpoint();
+ return AACalleeToCallSite::updateImpl(A);
}
/// See AbstractAttribute::trackStatistics()
@@ -3402,6 +3413,18 @@ template <typename ToTy> struct ReachabilityQueryInfo {
/// and remember if it worked:
Reachable Result = Reachable::No;
+ /// Precomputed hash for this RQI.
+ unsigned Hash = 0;
+
+ unsigned computeHashValue() const {
+ assert(Hash == 0 && "Computed hash twice!");
+ using InstSetDMI = DenseMapInfo<const AA::InstExclusionSetTy *>;
+ using PairDMI = DenseMapInfo<std::pair<const Instruction *, const ToTy *>>;
+ return const_cast<ReachabilityQueryInfo<ToTy> *>(this)->Hash =
+ detail::combineHashValue(PairDMI ::getHashValue({From, To}),
+ InstSetDMI::getHashValue(ExclusionSet));
+ }
+
ReachabilityQueryInfo(const Instruction *From, const ToTy *To)
: From(From), To(To) {}
@@ -3435,9 +3458,7 @@ template <typename ToTy> struct DenseMapInfo<ReachabilityQueryInfo<ToTy> *> {
return &TombstoneKey;
}
static unsigned getHashValue(const ReachabilityQueryInfo<ToTy> *RQI) {
- unsigned H = PairDMI ::getHashValue({RQI->From, RQI->To});
- H += InstSetDMI::getHashValue(RQI->ExclusionSet);
- return H;
+ return RQI->Hash ? RQI->Hash : RQI->computeHashValue();
}
static bool isEqual(const ReachabilityQueryInfo<ToTy> *LHS,
const ReachabilityQueryInfo<ToTy> *RHS) {
@@ -3480,24 +3501,24 @@ struct CachedReachabilityAA : public BaseTy {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED;
- InUpdate = true;
for (unsigned u = 0, e = QueryVector.size(); u < e; ++u) {
RQITy *RQI = QueryVector[u];
- if (RQI->Result == RQITy::Reachable::No && isReachableImpl(A, *RQI))
+ if (RQI->Result == RQITy::Reachable::No &&
+ isReachableImpl(A, *RQI, /*IsTemporaryRQI=*/false))
Changed = ChangeStatus::CHANGED;
}
- InUpdate = false;
return Changed;
}
- virtual bool isReachableImpl(Attributor &A, RQITy &RQI) = 0;
+ virtual bool isReachableImpl(Attributor &A, RQITy &RQI,
+ bool IsTemporaryRQI) = 0;
bool rememberResult(Attributor &A, typename RQITy::Reachable Result,
- RQITy &RQI, bool UsedExclusionSet) {
+ RQITy &RQI, bool UsedExclusionSet, bool IsTemporaryRQI) {
RQI.Result = Result;
// Remove the temporary RQI from the cache.
- if (!InUpdate)
+ if (IsTemporaryRQI)
QueryCache.erase(&RQI);
// Insert a plain RQI (w/o exclusion set) if that makes sense. Two options:
@@ -3515,7 +3536,7 @@ struct CachedReachabilityAA : public BaseTy {
}
// Check if we need to insert a new permanent RQI with the exclusion set.
- if (!InUpdate && Result != RQITy::Reachable::Yes && UsedExclusionSet) {
+ if (IsTemporaryRQI && Result != RQITy::Reachable::Yes && UsedExclusionSet) {
assert((!RQI.ExclusionSet || !RQI.ExclusionSet->empty()) &&
"Did not expect empty set!");
RQITy *RQIPtr = new (A.Allocator)
@@ -3527,7 +3548,7 @@ struct CachedReachabilityAA : public BaseTy {
QueryCache.insert(RQIPtr);
}
- if (Result == RQITy::Reachable::No && !InUpdate)
+ if (Result == RQITy::Reachable::No && IsTemporaryRQI)
A.registerForUpdate(*this);
return Result == RQITy::Reachable::Yes;
}
@@ -3568,7 +3589,6 @@ struct CachedReachabilityAA : public BaseTy {
}
private:
- bool InUpdate = false;
SmallVector<RQITy *> QueryVector;
DenseSet<RQITy *> QueryCache;
};
@@ -3577,7 +3597,10 @@ struct AAIntraFnReachabilityFunction final
: public CachedReachabilityAA<AAIntraFnReachability, Instruction> {
using Base = CachedReachabilityAA<AAIntraFnReachability, Instruction>;
AAIntraFnReachabilityFunction(const IRPosition &IRP, Attributor &A)
- : Base(IRP, A) {}
+ : Base(IRP, A) {
+ DT = A.getInfoCache().getAnalysisResultForFunction<DominatorTreeAnalysis>(
+ *IRP.getAssociatedFunction());
+ }
bool isAssumedReachable(
Attributor &A, const Instruction &From, const Instruction &To,
@@ -3589,7 +3612,8 @@ struct AAIntraFnReachabilityFunction final
RQITy StackRQI(A, From, To, ExclusionSet, false);
typename RQITy::Reachable Result;
if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
- return NonConstThis->isReachableImpl(A, StackRQI);
+ return NonConstThis->isReachableImpl(A, StackRQI,
+ /*IsTemporaryRQI=*/true);
return Result == RQITy::Reachable::Yes;
}
@@ -3598,16 +3622,24 @@ struct AAIntraFnReachabilityFunction final
// of them changed.
auto *LivenessAA =
A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
- if (LivenessAA && llvm::all_of(DeadEdges, [&](const auto &DeadEdge) {
- return LivenessAA->isEdgeDead(DeadEdge.first, DeadEdge.second);
+ if (LivenessAA &&
+ llvm::all_of(DeadEdges,
+ [&](const auto &DeadEdge) {
+ return LivenessAA->isEdgeDead(DeadEdge.first,
+ DeadEdge.second);
+ }) &&
+ llvm::all_of(DeadBlocks, [&](const BasicBlock *BB) {
+ return LivenessAA->isAssumedDead(BB);
})) {
return ChangeStatus::UNCHANGED;
}
DeadEdges.clear();
+ DeadBlocks.clear();
return Base::updateImpl(A);
}
- bool isReachableImpl(Attributor &A, RQITy &RQI) override {
+ bool isReachableImpl(Attributor &A, RQITy &RQI,
+ bool IsTemporaryRQI) override {
const Instruction *Origin = RQI.From;
bool UsedExclusionSet = false;
@@ -3633,31 +3665,41 @@ struct AAIntraFnReachabilityFunction final
// possible.
if (FromBB == ToBB &&
WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet))
- return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
+ return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
// Check if reaching the ToBB block is sufficient or if even that would not
// ensure reaching the target. In the latter case we are done.
if (!WillReachInBlock(ToBB->front(), *RQI.To, RQI.ExclusionSet))
- return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
+ const Function *Fn = FromBB->getParent();
SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks;
if (RQI.ExclusionSet)
for (auto *I : *RQI.ExclusionSet)
- ExclusionBlocks.insert(I->getParent());
+ if (I->getFunction() == Fn)
+ ExclusionBlocks.insert(I->getParent());
// Check if we make it out of the FromBB block at all.
if (ExclusionBlocks.count(FromBB) &&
!WillReachInBlock(*RQI.From, *FromBB->getTerminator(),
RQI.ExclusionSet))
- return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+ return rememberResult(A, RQITy::Reachable::No, RQI, true, IsTemporaryRQI);
+
+ auto *LivenessAA =
+ A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (LivenessAA && LivenessAA->isAssumedDead(ToBB)) {
+ DeadBlocks.insert(ToBB);
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
+ }
SmallPtrSet<const BasicBlock *, 16> Visited;
SmallVector<const BasicBlock *, 16> Worklist;
Worklist.push_back(FromBB);
DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> LocalDeadEdges;
- auto *LivenessAA =
- A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
while (!Worklist.empty()) {
const BasicBlock *BB = Worklist.pop_back_val();
if (!Visited.insert(BB).second)
@@ -3669,8 +3711,12 @@ struct AAIntraFnReachabilityFunction final
}
// We checked before if we just need to reach the ToBB block.
if (SuccBB == ToBB)
- return rememberResult(A, RQITy::Reachable::Yes, RQI,
- UsedExclusionSet);
+ return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
+ if (DT && ExclusionBlocks.empty() && DT->dominates(BB, ToBB))
+ return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
+
if (ExclusionBlocks.count(SuccBB)) {
UsedExclusionSet = true;
continue;
@@ -3680,16 +3726,24 @@ struct AAIntraFnReachabilityFunction final
}
DeadEdges.insert(LocalDeadEdges.begin(), LocalDeadEdges.end());
- return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {}
private:
+ // Set of assumed dead blocks we used in the last query. If any changes we
+ // update the state.
+ DenseSet<const BasicBlock *> DeadBlocks;
+
// Set of assumed dead edges we used in the last query. If any changes we
// update the state.
DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> DeadEdges;
+
+ /// The dominator tree of the function to short-circuit reasoning.
+ const DominatorTree *DT = nullptr;
};
} // namespace
@@ -3754,12 +3808,8 @@ struct AANoAliasFloating final : AANoAliasImpl {
/// NoAlias attribute for an argument.
struct AANoAliasArgument final
- : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl,
- AANoAlias::StateType, false,
- Attribute::NoAlias> {
- using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl,
- AANoAlias::StateType, false,
- Attribute::NoAlias>;
+ : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl> {
+ using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl>;
AANoAliasArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {}
/// See AbstractAttribute::update(...).
@@ -4027,24 +4077,10 @@ struct AANoAliasReturned final : AANoAliasImpl {
};
/// NoAlias attribute deduction for a call site return value.
-struct AANoAliasCallSiteReturned final : AANoAliasImpl {
+struct AANoAliasCallSiteReturned final
+ : AACalleeToCallSite<AANoAlias, AANoAliasImpl> {
AANoAliasCallSiteReturned(const IRPosition &IRP, Attributor &A)
- : AANoAliasImpl(IRP, A) {}
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::returned(*F);
- bool IsKnownNoAlias;
- if (!AA::hasAssumedIRAttr<Attribute::NoAlias>(
- A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoAlias))
- return indicatePessimisticFixpoint();
- return ChangeStatus::UNCHANGED;
- }
+ : AACalleeToCallSite<AANoAlias, AANoAliasImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noalias); }
@@ -4696,23 +4732,53 @@ identifyAliveSuccessors(Attributor &A, const SwitchInst &SI,
AbstractAttribute &AA,
SmallVectorImpl<const Instruction *> &AliveSuccessors) {
bool UsedAssumedInformation = false;
- std::optional<Constant *> C =
- A.getAssumedConstant(*SI.getCondition(), AA, UsedAssumedInformation);
- if (!C || isa_and_nonnull<UndefValue>(*C)) {
- // No value yet, assume all edges are dead.
- } else if (isa_and_nonnull<ConstantInt>(*C)) {
- for (const auto &CaseIt : SI.cases()) {
- if (CaseIt.getCaseValue() == *C) {
- AliveSuccessors.push_back(&CaseIt.getCaseSuccessor()->front());
- return UsedAssumedInformation;
- }
- }
- AliveSuccessors.push_back(&SI.getDefaultDest()->front());
+ SmallVector<AA::ValueAndContext> Values;
+ if (!A.getAssumedSimplifiedValues(IRPosition::value(*SI.getCondition()), &AA,
+ Values, AA::AnyScope,
+ UsedAssumedInformation)) {
+ // Something went wrong, assume all successors are live.
+ for (const BasicBlock *SuccBB : successors(SI.getParent()))
+ AliveSuccessors.push_back(&SuccBB->front());
+ return false;
+ }
+
+ if (Values.empty() ||
+ (Values.size() == 1 &&
+ isa_and_nonnull<UndefValue>(Values.front().getValue()))) {
+ // No valid value yet, assume all edges are dead.
return UsedAssumedInformation;
- } else {
+ }
+
+ Type &Ty = *SI.getCondition()->getType();
+ SmallPtrSet<ConstantInt *, 8> Constants;
+ auto CheckForConstantInt = [&](Value *V) {
+ if (auto *CI = dyn_cast_if_present<ConstantInt>(AA::getWithType(*V, Ty))) {
+ Constants.insert(CI);
+ return true;
+ }
+ return false;
+ };
+
+ if (!all_of(Values, [&](AA::ValueAndContext &VAC) {
+ return CheckForConstantInt(VAC.getValue());
+ })) {
for (const BasicBlock *SuccBB : successors(SI.getParent()))
AliveSuccessors.push_back(&SuccBB->front());
+ return UsedAssumedInformation;
+ }
+
+ unsigned MatchedCases = 0;
+ for (const auto &CaseIt : SI.cases()) {
+ if (Constants.count(CaseIt.getCaseValue())) {
+ ++MatchedCases;
+ AliveSuccessors.push_back(&CaseIt.getCaseSuccessor()->front());
+ }
}
+
+ // If all potential values have been matched, we will not visit the default
+ // case.
+ if (MatchedCases < Constants.size())
+ AliveSuccessors.push_back(&SI.getDefaultDest()->front());
return UsedAssumedInformation;
}
@@ -5103,9 +5169,8 @@ struct AADereferenceableCallSiteArgument final : AADereferenceableFloating {
/// Dereferenceable attribute deduction for a call site return value.
struct AADereferenceableCallSiteReturned final
- : AACallSiteReturnedFromReturned<AADereferenceable, AADereferenceableImpl> {
- using Base =
- AACallSiteReturnedFromReturned<AADereferenceable, AADereferenceableImpl>;
+ : AACalleeToCallSite<AADereferenceable, AADereferenceableImpl> {
+ using Base = AACalleeToCallSite<AADereferenceable, AADereferenceableImpl>;
AADereferenceableCallSiteReturned(const IRPosition &IRP, Attributor &A)
: Base(IRP, A) {}
@@ -5400,8 +5465,8 @@ struct AAAlignCallSiteArgument final : AAAlignFloating {
/// Align attribute deduction for a call site return value.
struct AAAlignCallSiteReturned final
- : AACallSiteReturnedFromReturned<AAAlign, AAAlignImpl> {
- using Base = AACallSiteReturnedFromReturned<AAAlign, AAAlignImpl>;
+ : AACalleeToCallSite<AAAlign, AAAlignImpl> {
+ using Base = AACalleeToCallSite<AAAlign, AAAlignImpl>;
AAAlignCallSiteReturned(const IRPosition &IRP, Attributor &A)
: Base(IRP, A) {}
@@ -5449,24 +5514,10 @@ struct AANoReturnFunction final : AANoReturnImpl {
};
/// NoReturn attribute deduction for a call sites.
-struct AANoReturnCallSite final : AANoReturnImpl {
+struct AANoReturnCallSite final
+ : AACalleeToCallSite<AANoReturn, AANoReturnImpl> {
AANoReturnCallSite(const IRPosition &IRP, Attributor &A)
- : AANoReturnImpl(IRP, A) {}
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::function(*F);
- bool IsKnownNoReturn;
- if (!AA::hasAssumedIRAttr<Attribute::NoReturn>(
- A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoReturn))
- return indicatePessimisticFixpoint();
- return ChangeStatus::UNCHANGED;
- }
+ : AACalleeToCallSite<AANoReturn, AANoReturnImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(noreturn); }
@@ -5805,8 +5856,8 @@ struct AANoCaptureImpl : public AANoCapture {
// For stores we already checked if we can follow them, if they make it
// here we give up.
if (isa<StoreInst>(UInst))
- return isCapturedIn(State, /* Memory */ true, /* Integer */ false,
- /* Return */ false);
+ return isCapturedIn(State, /* Memory */ true, /* Integer */ true,
+ /* Return */ true);
// Explicitly catch return instructions.
if (isa<ReturnInst>(UInst)) {
@@ -6476,7 +6527,7 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl {
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
- return indicatePessimisticFixpoint();
+ return indicatePessimisticFixpoint();
}
void trackStatistics() const override {
@@ -6937,13 +6988,17 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) {
<< **DI->PotentialAllocationCalls.begin() << "\n");
return false;
}
- Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode();
- if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) {
- LLVM_DEBUG(
- dbgs()
- << "[H2S] unique free call might not be executed with the allocation "
- << *UniqueFree << "\n");
- return false;
+
+ // __kmpc_alloc_shared and __kmpc_alloc_free are by construction matched.
+ if (AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared) {
+ Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode();
+ if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) {
+ LLVM_DEBUG(
+ dbgs()
+ << "[H2S] unique free call might not be executed with the allocation "
+ << *UniqueFree << "\n");
+ return false;
+ }
}
return true;
};
@@ -7796,6 +7851,9 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior {
// Clear existing attributes.
A.removeAttrs(IRP, AttrKinds);
+ // Clear conflicting writable attribute.
+ if (isAssumedReadOnly())
+ A.removeAttrs(IRP, Attribute::Writable);
// Use the generic manifest method.
return IRAttribute::manifest(A);
@@ -7983,6 +8041,10 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl {
ME = MemoryEffects::writeOnly();
A.removeAttrs(getIRPosition(), AttrKinds);
+ // Clear conflicting writable attribute.
+ if (ME.onlyReadsMemory())
+ for (Argument &Arg : F.args())
+ A.removeAttrs(IRPosition::argument(Arg), Attribute::Writable);
return A.manifestAttrs(getIRPosition(),
Attribute::getWithMemoryEffects(F.getContext(), ME));
}
@@ -7999,24 +8061,10 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl {
};
/// AAMemoryBehavior attribute for call sites.
-struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl {
+struct AAMemoryBehaviorCallSite final
+ : AACalleeToCallSite<AAMemoryBehavior, AAMemoryBehaviorImpl> {
AAMemoryBehaviorCallSite(const IRPosition &IRP, Attributor &A)
- : AAMemoryBehaviorImpl(IRP, A) {}
-
- /// See AbstractAttribute::updateImpl(...).
- ChangeStatus updateImpl(Attributor &A) override {
- // TODO: Once we have call site specific value information we can provide
- // call site specific liveness liveness information and then it makes
- // sense to specialize attributes for call sites arguments instead of
- // redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
- const IRPosition &FnPos = IRPosition::function(*F);
- auto *FnAA =
- A.getAAFor<AAMemoryBehavior>(*this, FnPos, DepClassTy::REQUIRED);
- if (!FnAA)
- return indicatePessimisticFixpoint();
- return clampStateAndIndicateChange(getState(), FnAA->getState());
- }
+ : AACalleeToCallSite<AAMemoryBehavior, AAMemoryBehaviorImpl>(IRP, A) {}
/// See AbstractAttribute::manifest(...).
ChangeStatus manifest(Attributor &A) override {
@@ -8031,6 +8079,11 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl {
ME = MemoryEffects::writeOnly();
A.removeAttrs(getIRPosition(), AttrKinds);
+ // Clear conflicting writable attribute.
+ if (ME.onlyReadsMemory())
+ for (Use &U : CB.args())
+ A.removeAttrs(IRPosition::callsite_argument(CB, U.getOperandNo()),
+ Attribute::Writable);
return A.manifestAttrs(
getIRPosition(), Attribute::getWithMemoryEffects(CB.getContext(), ME));
}
@@ -8821,6 +8874,108 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl {
};
} // namespace
+/// ------------------ denormal-fp-math Attribute -------------------------
+
+namespace {
+struct AADenormalFPMathImpl : public AADenormalFPMath {
+ AADenormalFPMathImpl(const IRPosition &IRP, Attributor &A)
+ : AADenormalFPMath(IRP, A) {}
+
+ const std::string getAsStr(Attributor *A) const override {
+ std::string Str("AADenormalFPMath[");
+ raw_string_ostream OS(Str);
+
+ DenormalState Known = getKnown();
+ if (Known.Mode.isValid())
+ OS << "denormal-fp-math=" << Known.Mode;
+ else
+ OS << "invalid";
+
+ if (Known.ModeF32.isValid())
+ OS << " denormal-fp-math-f32=" << Known.ModeF32;
+ OS << ']';
+ return OS.str();
+ }
+};
+
+struct AADenormalFPMathFunction final : AADenormalFPMathImpl {
+ AADenormalFPMathFunction(const IRPosition &IRP, Attributor &A)
+ : AADenormalFPMathImpl(IRP, A) {}
+
+ void initialize(Attributor &A) override {
+ const Function *F = getAnchorScope();
+ DenormalMode Mode = F->getDenormalModeRaw();
+ DenormalMode ModeF32 = F->getDenormalModeF32Raw();
+
+ // TODO: Handling this here prevents handling the case where a callee has a
+ // fixed denormal-fp-math with dynamic denormal-fp-math-f32, but called from
+ // a function with a fully fixed mode.
+ if (ModeF32 == DenormalMode::getInvalid())
+ ModeF32 = Mode;
+ Known = DenormalState{Mode, ModeF32};
+ if (isModeFixed())
+ indicateFixpoint();
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+ ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+ auto CheckCallSite = [=, &Change, &A](AbstractCallSite CS) {
+ Function *Caller = CS.getInstruction()->getFunction();
+ LLVM_DEBUG(dbgs() << "[AADenormalFPMath] Call " << Caller->getName()
+ << "->" << getAssociatedFunction()->getName() << '\n');
+
+ const auto *CallerInfo = A.getAAFor<AADenormalFPMath>(
+ *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
+ if (!CallerInfo)
+ return false;
+
+ Change = Change | clampStateAndIndicateChange(this->getState(),
+ CallerInfo->getState());
+ return true;
+ };
+
+ bool AllCallSitesKnown = true;
+ if (!A.checkForAllCallSites(CheckCallSite, *this, true, AllCallSitesKnown))
+ return indicatePessimisticFixpoint();
+
+ if (Change == ChangeStatus::CHANGED && isModeFixed())
+ indicateFixpoint();
+ return Change;
+ }
+
+ ChangeStatus manifest(Attributor &A) override {
+ LLVMContext &Ctx = getAssociatedFunction()->getContext();
+
+ SmallVector<Attribute, 2> AttrToAdd;
+ SmallVector<StringRef, 2> AttrToRemove;
+ if (Known.Mode == DenormalMode::getDefault()) {
+ AttrToRemove.push_back("denormal-fp-math");
+ } else {
+ AttrToAdd.push_back(
+ Attribute::get(Ctx, "denormal-fp-math", Known.Mode.str()));
+ }
+
+ if (Known.ModeF32 != Known.Mode) {
+ AttrToAdd.push_back(
+ Attribute::get(Ctx, "denormal-fp-math-f32", Known.ModeF32.str()));
+ } else {
+ AttrToRemove.push_back("denormal-fp-math-f32");
+ }
+
+ auto &IRP = getIRPosition();
+
+ // TODO: There should be a combined add and remove API.
+ return A.removeAttrs(IRP, AttrToRemove) |
+ A.manifestAttrs(IRP, AttrToAdd, /*ForceReplace=*/true);
+ }
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FN_ATTR(denormal_fp_math)
+ }
+};
+} // namespace
+
/// ------------------ Value Constant Range Attribute -------------------------
namespace {
@@ -9427,17 +9582,13 @@ struct AAValueConstantRangeCallSite : AAValueConstantRangeFunction {
};
struct AAValueConstantRangeCallSiteReturned
- : AACallSiteReturnedFromReturned<AAValueConstantRange,
- AAValueConstantRangeImpl,
- AAValueConstantRangeImpl::StateType,
- /* IntroduceCallBaseContext */ true> {
+ : AACalleeToCallSite<AAValueConstantRange, AAValueConstantRangeImpl,
+ AAValueConstantRangeImpl::StateType,
+ /* IntroduceCallBaseContext */ true> {
AAValueConstantRangeCallSiteReturned(const IRPosition &IRP, Attributor &A)
- : AACallSiteReturnedFromReturned<AAValueConstantRange,
- AAValueConstantRangeImpl,
- AAValueConstantRangeImpl::StateType,
- /* IntroduceCallBaseContext */ true>(IRP,
- A) {
- }
+ : AACalleeToCallSite<AAValueConstantRange, AAValueConstantRangeImpl,
+ AAValueConstantRangeImpl::StateType,
+ /* IntroduceCallBaseContext */ true>(IRP, A) {}
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
@@ -9956,12 +10107,12 @@ struct AAPotentialConstantValuesCallSite : AAPotentialConstantValuesFunction {
};
struct AAPotentialConstantValuesCallSiteReturned
- : AACallSiteReturnedFromReturned<AAPotentialConstantValues,
- AAPotentialConstantValuesImpl> {
+ : AACalleeToCallSite<AAPotentialConstantValues,
+ AAPotentialConstantValuesImpl> {
AAPotentialConstantValuesCallSiteReturned(const IRPosition &IRP,
Attributor &A)
- : AACallSiteReturnedFromReturned<AAPotentialConstantValues,
- AAPotentialConstantValuesImpl>(IRP, A) {}
+ : AACalleeToCallSite<AAPotentialConstantValues,
+ AAPotentialConstantValuesImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
@@ -10101,7 +10252,8 @@ struct AANoUndefFloating : public AANoUndefImpl {
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
AANoUndefImpl::initialize(A);
- if (!getState().isAtFixpoint())
+ if (!getState().isAtFixpoint() && getAnchorScope() &&
+ !getAnchorScope()->isDeclaration())
if (Instruction *CtxI = getCtxI())
followUsesInMBEC(*this, A, getState(), *CtxI);
}
@@ -10148,26 +10300,18 @@ struct AANoUndefFloating : public AANoUndefImpl {
};
struct AANoUndefReturned final
- : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl,
- AANoUndef::StateType, false,
- Attribute::NoUndef> {
+ : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl> {
AANoUndefReturned(const IRPosition &IRP, Attributor &A)
- : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl,
- AANoUndef::StateType, false,
- Attribute::NoUndef>(IRP, A) {}
+ : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noundef) }
};
struct AANoUndefArgument final
- : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl,
- AANoUndef::StateType, false,
- Attribute::NoUndef> {
+ : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl> {
AANoUndefArgument(const IRPosition &IRP, Attributor &A)
- : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl,
- AANoUndef::StateType, false,
- Attribute::NoUndef>(IRP, A) {}
+ : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noundef) }
@@ -10182,13 +10326,9 @@ struct AANoUndefCallSiteArgument final : AANoUndefFloating {
};
struct AANoUndefCallSiteReturned final
- : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl,
- AANoUndef::StateType, false,
- Attribute::NoUndef> {
+ : AACalleeToCallSite<AANoUndef, AANoUndefImpl> {
AANoUndefCallSiteReturned(const IRPosition &IRP, Attributor &A)
- : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl,
- AANoUndef::StateType, false,
- Attribute::NoUndef>(IRP, A) {}
+ : AACalleeToCallSite<AANoUndef, AANoUndefImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noundef) }
@@ -10212,7 +10352,6 @@ struct AANoFPClassImpl : AANoFPClass {
A.getAttrs(getIRPosition(), {Attribute::NoFPClass}, Attrs, false);
for (const auto &Attr : Attrs) {
addKnownBits(Attr.getNoFPClass());
- return;
}
const DataLayout &DL = A.getDataLayout();
@@ -10248,8 +10387,22 @@ struct AANoFPClassImpl : AANoFPClass {
/*Depth=*/0, TLI, AC, I, DT);
State.addKnownBits(~KnownFPClass.KnownFPClasses);
- bool TrackUse = false;
- return TrackUse;
+ if (auto *CI = dyn_cast<CallInst>(UseV)) {
+ // Special case FP intrinsic with struct return type.
+ switch (CI->getIntrinsicID()) {
+ case Intrinsic::frexp:
+ return true;
+ case Intrinsic::not_intrinsic:
+ // TODO: Could recognize math libcalls
+ return false;
+ default:
+ break;
+ }
+ }
+
+ if (!UseV->getType()->isFPOrFPVectorTy())
+ return false;
+ return !isa<LoadInst, AtomicRMWInst>(UseV);
}
const std::string getAsStr(Attributor *A) const override {
@@ -10339,9 +10492,9 @@ struct AANoFPClassCallSiteArgument final : AANoFPClassFloating {
};
struct AANoFPClassCallSiteReturned final
- : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl> {
+ : AACalleeToCallSite<AANoFPClass, AANoFPClassImpl> {
AANoFPClassCallSiteReturned(const IRPosition &IRP, Attributor &A)
- : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl>(IRP, A) {}
+ : AACalleeToCallSite<AANoFPClass, AANoFPClassImpl>(IRP, A) {}
/// See AbstractAttribute::trackStatistics()
void trackStatistics() const override {
@@ -10446,15 +10599,12 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
return Change;
}
- // Process callee metadata if available.
- if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) {
- for (const auto &Op : MD->operands()) {
- Function *Callee = mdconst::dyn_extract_or_null<Function>(Op);
- if (Callee)
- addCalledFunction(Callee, Change);
- }
- return Change;
- }
+ if (CB->isIndirectCall())
+ if (auto *IndirectCallAA = A.getAAFor<AAIndirectCallInfo>(
+ *this, getIRPosition(), DepClassTy::OPTIONAL))
+ if (IndirectCallAA->foreachCallee(
+ [&](Function *Fn) { return VisitValue(*Fn, CB); }))
+ return Change;
// The most simple case.
ProcessCalledOperand(CB->getCalledOperand(), CB);
@@ -10519,28 +10669,26 @@ struct AAInterFnReachabilityFunction
bool instructionCanReach(
Attributor &A, const Instruction &From, const Function &To,
- const AA::InstExclusionSetTy *ExclusionSet,
- SmallPtrSet<const Function *, 16> *Visited) const override {
+ const AA::InstExclusionSetTy *ExclusionSet) const override {
assert(From.getFunction() == getAnchorScope() && "Queried the wrong AA!");
auto *NonConstThis = const_cast<AAInterFnReachabilityFunction *>(this);
RQITy StackRQI(A, From, To, ExclusionSet, false);
typename RQITy::Reachable Result;
if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
- return NonConstThis->isReachableImpl(A, StackRQI);
+ return NonConstThis->isReachableImpl(A, StackRQI,
+ /*IsTemporaryRQI=*/true);
return Result == RQITy::Reachable::Yes;
}
- bool isReachableImpl(Attributor &A, RQITy &RQI) override {
- return isReachableImpl(A, RQI, nullptr);
- }
-
bool isReachableImpl(Attributor &A, RQITy &RQI,
- SmallPtrSet<const Function *, 16> *Visited) {
-
- SmallPtrSet<const Function *, 16> LocalVisited;
- if (!Visited)
- Visited = &LocalVisited;
+ bool IsTemporaryRQI) override {
+ const Instruction *EntryI =
+ &RQI.From->getFunction()->getEntryBlock().front();
+ if (EntryI != RQI.From &&
+ !instructionCanReach(A, *EntryI, *RQI.To, nullptr))
+ return rememberResult(A, RQITy::Reachable::No, RQI, false,
+ IsTemporaryRQI);
auto CheckReachableCallBase = [&](CallBase *CB) {
auto *CBEdges = A.getAAFor<AACallEdges>(
@@ -10554,8 +10702,7 @@ struct AAInterFnReachabilityFunction
for (Function *Fn : CBEdges->getOptimisticEdges()) {
if (Fn == RQI.To)
return false;
- if (!Visited->insert(Fn).second)
- continue;
+
if (Fn->isDeclaration()) {
if (Fn->hasFnAttribute(Attribute::NoCallback))
continue;
@@ -10563,15 +10710,20 @@ struct AAInterFnReachabilityFunction
return false;
}
- const AAInterFnReachability *InterFnReachability = this;
- if (Fn != getAnchorScope())
- InterFnReachability = A.getAAFor<AAInterFnReachability>(
- *this, IRPosition::function(*Fn), DepClassTy::OPTIONAL);
+ if (Fn == getAnchorScope()) {
+ if (EntryI == RQI.From)
+ continue;
+ return false;
+ }
+
+ const AAInterFnReachability *InterFnReachability =
+ A.getAAFor<AAInterFnReachability>(*this, IRPosition::function(*Fn),
+ DepClassTy::OPTIONAL);
const Instruction &FnFirstInst = Fn->getEntryBlock().front();
if (!InterFnReachability ||
InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To,
- RQI.ExclusionSet, Visited))
+ RQI.ExclusionSet))
return false;
}
return true;
@@ -10583,10 +10735,12 @@ struct AAInterFnReachabilityFunction
// Determine call like instructions that we can reach from the inst.
auto CheckCallBase = [&](Instruction &CBInst) {
- if (!IntraFnReachability || !IntraFnReachability->isAssumedReachable(
- A, *RQI.From, CBInst, RQI.ExclusionSet))
+ // There are usually less nodes in the call graph, check inter function
+ // reachability first.
+ if (CheckReachableCallBase(cast<CallBase>(&CBInst)))
return true;
- return CheckReachableCallBase(cast<CallBase>(&CBInst));
+ return IntraFnReachability && !IntraFnReachability->isAssumedReachable(
+ A, *RQI.From, CBInst, RQI.ExclusionSet);
};
bool UsedExclusionSet = /* conservative */ true;
@@ -10594,16 +10748,14 @@ struct AAInterFnReachabilityFunction
if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this,
UsedAssumedInformation,
/* CheckBBLivenessOnly */ true))
- return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
+ return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
- return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+ return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+ IsTemporaryRQI);
}
void trackStatistics() const override {}
-
-private:
- SmallVector<RQITy *> QueryVector;
- DenseSet<RQITy *> QueryCache;
};
} // namespace
@@ -10880,64 +11032,104 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
// Simplify the operands first.
bool UsedAssumedInformation = false;
- const auto &SimplifiedLHS = A.getAssumedSimplified(
- IRPosition::value(*LHS, getCallBaseContext()), *this,
- UsedAssumedInformation, AA::Intraprocedural);
- if (!SimplifiedLHS.has_value())
+ SmallVector<AA::ValueAndContext> LHSValues, RHSValues;
+ auto GetSimplifiedValues = [&](Value &V,
+ SmallVector<AA::ValueAndContext> &Values) {
+ if (!A.getAssumedSimplifiedValues(
+ IRPosition::value(V, getCallBaseContext()), this, Values,
+ AA::Intraprocedural, UsedAssumedInformation)) {
+ Values.clear();
+ Values.push_back(AA::ValueAndContext{V, II.I.getCtxI()});
+ }
+ return Values.empty();
+ };
+ if (GetSimplifiedValues(*LHS, LHSValues))
return true;
- if (!*SimplifiedLHS)
- return false;
- LHS = *SimplifiedLHS;
-
- const auto &SimplifiedRHS = A.getAssumedSimplified(
- IRPosition::value(*RHS, getCallBaseContext()), *this,
- UsedAssumedInformation, AA::Intraprocedural);
- if (!SimplifiedRHS.has_value())
+ if (GetSimplifiedValues(*RHS, RHSValues))
return true;
- if (!*SimplifiedRHS)
- return false;
- RHS = *SimplifiedRHS;
LLVMContext &Ctx = LHS->getContext();
- // Handle the trivial case first in which we don't even need to think about
- // null or non-null.
- if (LHS == RHS &&
- (CmpInst::isTrueWhenEqual(Pred) || CmpInst::isFalseWhenEqual(Pred))) {
- Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx),
- CmpInst::isTrueWhenEqual(Pred));
- addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S,
- getAnchorScope());
- return true;
- }
- // From now on we only handle equalities (==, !=).
- if (!CmpInst::isEquality(Pred))
- return false;
+ InformationCache &InfoCache = A.getInfoCache();
+ Instruction *CmpI = dyn_cast<Instruction>(&Cmp);
+ Function *F = CmpI ? CmpI->getFunction() : nullptr;
+ const auto *DT =
+ F ? InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F)
+ : nullptr;
+ const auto *TLI =
+ F ? A.getInfoCache().getTargetLibraryInfoForFunction(*F) : nullptr;
+ auto *AC =
+ F ? InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F)
+ : nullptr;
- bool LHSIsNull = isa<ConstantPointerNull>(LHS);
- bool RHSIsNull = isa<ConstantPointerNull>(RHS);
- if (!LHSIsNull && !RHSIsNull)
- return false;
+ const DataLayout &DL = A.getDataLayout();
+ SimplifyQuery Q(DL, TLI, DT, AC, CmpI);
- // Left is the nullptr ==/!= non-nullptr case. We'll use AANonNull on the
- // non-nullptr operand and if we assume it's non-null we can conclude the
- // result of the comparison.
- assert((LHSIsNull || RHSIsNull) &&
- "Expected nullptr versus non-nullptr comparison at this point");
+ auto CheckPair = [&](Value &LHSV, Value &RHSV) {
+ if (isa<UndefValue>(LHSV) || isa<UndefValue>(RHSV)) {
+ addValue(A, getState(), *UndefValue::get(Cmp.getType()),
+ /* CtxI */ nullptr, II.S, getAnchorScope());
+ return true;
+ }
- // The index is the operand that we assume is not null.
- unsigned PtrIdx = LHSIsNull;
- bool IsKnownNonNull;
- bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>(
- A, this, IRPosition::value(*(PtrIdx ? RHS : LHS)), DepClassTy::REQUIRED,
- IsKnownNonNull);
- if (!IsAssumedNonNull)
- return false;
+ // Handle the trivial case first in which we don't even need to think
+ // about null or non-null.
+ if (&LHSV == &RHSV &&
+ (CmpInst::isTrueWhenEqual(Pred) || CmpInst::isFalseWhenEqual(Pred))) {
+ Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx),
+ CmpInst::isTrueWhenEqual(Pred));
+ addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S,
+ getAnchorScope());
+ return true;
+ }
+
+ auto *TypedLHS = AA::getWithType(LHSV, *LHS->getType());
+ auto *TypedRHS = AA::getWithType(RHSV, *RHS->getType());
+ if (TypedLHS && TypedRHS) {
+ Value *NewV = simplifyCmpInst(Pred, TypedLHS, TypedRHS, Q);
+ if (NewV && NewV != &Cmp) {
+ addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S,
+ getAnchorScope());
+ return true;
+ }
+ }
+
+ // From now on we only handle equalities (==, !=).
+ if (!CmpInst::isEquality(Pred))
+ return false;
+
+ bool LHSIsNull = isa<ConstantPointerNull>(LHSV);
+ bool RHSIsNull = isa<ConstantPointerNull>(RHSV);
+ if (!LHSIsNull && !RHSIsNull)
+ return false;
+
+ // Left is the nullptr ==/!= non-nullptr case. We'll use AANonNull on the
+ // non-nullptr operand and if we assume it's non-null we can conclude the
+ // result of the comparison.
+ assert((LHSIsNull || RHSIsNull) &&
+ "Expected nullptr versus non-nullptr comparison at this point");
+
+ // The index is the operand that we assume is not null.
+ unsigned PtrIdx = LHSIsNull;
+ bool IsKnownNonNull;
+ bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>(
+ A, this, IRPosition::value(*(PtrIdx ? &RHSV : &LHSV)),
+ DepClassTy::REQUIRED, IsKnownNonNull);
+ if (!IsAssumedNonNull)
+ return false;
+
+ // The new value depends on the predicate, true for != and false for ==.
+ Constant *NewV =
+ ConstantInt::get(Type::getInt1Ty(Ctx), Pred == CmpInst::ICMP_NE);
+ addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S,
+ getAnchorScope());
+ return true;
+ };
- // The new value depends on the predicate, true for != and false for ==.
- Constant *NewV =
- ConstantInt::get(Type::getInt1Ty(Ctx), Pred == CmpInst::ICMP_NE);
- addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, getAnchorScope());
+ for (auto &LHSValue : LHSValues)
+ for (auto &RHSValue : RHSValues)
+ if (!CheckPair(*LHSValue.getValue(), *RHSValue.getValue()))
+ return false;
return true;
}
@@ -11152,9 +11344,8 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
SmallVectorImpl<ItemInfo> &Worklist,
SmallMapVector<const Function *, LivenessInfo, 4> &LivenessAAs) {
if (auto *CI = dyn_cast<CmpInst>(&I))
- if (handleCmp(A, *CI, CI->getOperand(0), CI->getOperand(1),
- CI->getPredicate(), II, Worklist))
- return true;
+ return handleCmp(A, *CI, CI->getOperand(0), CI->getOperand(1),
+ CI->getPredicate(), II, Worklist);
switch (I.getOpcode()) {
case Instruction::Select:
@@ -11272,12 +11463,12 @@ struct AAPotentialValuesArgument final : AAPotentialValuesImpl {
ChangeStatus updateImpl(Attributor &A) override {
auto AssumedBefore = getAssumed();
- unsigned CSArgNo = getCallSiteArgNo();
+ unsigned ArgNo = getCalleeArgNo();
bool UsedAssumedInformation = false;
SmallVector<AA::ValueAndContext> Values;
auto CallSitePred = [&](AbstractCallSite ACS) {
- const auto CSArgIRP = IRPosition::callsite_argument(ACS, CSArgNo);
+ const auto CSArgIRP = IRPosition::callsite_argument(ACS, ArgNo);
if (CSArgIRP.getPositionKind() == IRP_INVALID)
return false;
@@ -11889,6 +12080,455 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl {
};
} // namespace
+/// ------------------------ Global Value Info -------------------------------
+namespace {
+struct AAGlobalValueInfoFloating : public AAGlobalValueInfo {
+ AAGlobalValueInfoFloating(const IRPosition &IRP, Attributor &A)
+ : AAGlobalValueInfo(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {}
+
+ bool checkUse(Attributor &A, const Use &U, bool &Follow,
+ SmallVectorImpl<const Value *> &Worklist) {
+ Instruction *UInst = dyn_cast<Instruction>(U.getUser());
+ if (!UInst) {
+ Follow = true;
+ return true;
+ }
+
+ LLVM_DEBUG(dbgs() << "[AAGlobalValueInfo] Check use: " << *U.get() << " in "
+ << *UInst << "\n");
+
+ if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
+ int Idx = &Cmp->getOperandUse(0) == &U;
+ if (isa<Constant>(Cmp->getOperand(Idx)))
+ return true;
+ return U == &getAnchorValue();
+ }
+
+ // Explicitly catch return instructions.
+ if (isa<ReturnInst>(UInst)) {
+ auto CallSitePred = [&](AbstractCallSite ACS) {
+ Worklist.push_back(ACS.getInstruction());
+ return true;
+ };
+ bool UsedAssumedInformation = false;
+ // TODO: We should traverse the uses or add a "non-call-site" CB.
+ if (!A.checkForAllCallSites(CallSitePred, *UInst->getFunction(),
+ /*RequireAllCallSites=*/true, this,
+ UsedAssumedInformation))
+ return false;
+ return true;
+ }
+
+ // For now we only use special logic for call sites. However, the tracker
+ // itself knows about a lot of other non-capturing cases already.
+ auto *CB = dyn_cast<CallBase>(UInst);
+ if (!CB)
+ return false;
+ // Direct calls are OK uses.
+ if (CB->isCallee(&U))
+ return true;
+ // Non-argument uses are scary.
+ if (!CB->isArgOperand(&U))
+ return false;
+ // TODO: Iterate callees.
+ auto *Fn = dyn_cast<Function>(CB->getCalledOperand());
+ if (!Fn || !A.isFunctionIPOAmendable(*Fn))
+ return false;
+
+ unsigned ArgNo = CB->getArgOperandNo(&U);
+ Worklist.push_back(Fn->getArg(ArgNo));
+ return true;
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+ unsigned NumUsesBefore = Uses.size();
+
+ SmallPtrSet<const Value *, 8> Visited;
+ SmallVector<const Value *> Worklist;
+ Worklist.push_back(&getAnchorValue());
+
+ auto UsePred = [&](const Use &U, bool &Follow) -> bool {
+ Uses.insert(&U);
+ switch (DetermineUseCaptureKind(U, nullptr)) {
+ case UseCaptureKind::NO_CAPTURE:
+ return checkUse(A, U, Follow, Worklist);
+ case UseCaptureKind::MAY_CAPTURE:
+ return checkUse(A, U, Follow, Worklist);
+ case UseCaptureKind::PASSTHROUGH:
+ Follow = true;
+ return true;
+ }
+ return true;
+ };
+ auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) {
+ Uses.insert(&OldU);
+ return true;
+ };
+
+ while (!Worklist.empty()) {
+ const Value *V = Worklist.pop_back_val();
+ if (!Visited.insert(V).second)
+ continue;
+ if (!A.checkForAllUses(UsePred, *this, *V,
+ /* CheckBBLivenessOnly */ true,
+ DepClassTy::OPTIONAL,
+ /* IgnoreDroppableUses */ true, EquivalentUseCB)) {
+ return indicatePessimisticFixpoint();
+ }
+ }
+
+ return Uses.size() == NumUsesBefore ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
+
+ bool isPotentialUse(const Use &U) const override {
+ return !isValidState() || Uses.contains(&U);
+ }
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr(Attributor *A) const override {
+ return "[" + std::to_string(Uses.size()) + " uses]";
+ }
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(GlobalValuesTracked);
+ }
+
+private:
+ /// Set of (transitive) uses of this GlobalValue.
+ SmallPtrSet<const Use *, 8> Uses;
+};
+} // namespace
+
+/// ------------------------ Indirect Call Info -------------------------------
+namespace {
+struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo {
+ AAIndirectCallInfoCallSite(const IRPosition &IRP, Attributor &A)
+ : AAIndirectCallInfo(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees);
+ if (!MD && !A.isClosedWorldModule())
+ return;
+
+ if (MD) {
+ for (const auto &Op : MD->operands())
+ if (Function *Callee = mdconst::dyn_extract_or_null<Function>(Op))
+ PotentialCallees.insert(Callee);
+ } else if (A.isClosedWorldModule()) {
+ ArrayRef<Function *> IndirectlyCallableFunctions =
+ A.getInfoCache().getIndirectlyCallableFunctions(A);
+ PotentialCallees.insert(IndirectlyCallableFunctions.begin(),
+ IndirectlyCallableFunctions.end());
+ }
+
+ if (PotentialCallees.empty())
+ indicateOptimisticFixpoint();
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+ CallBase *CB = cast<CallBase>(getCtxI());
+ const Use &CalleeUse = CB->getCalledOperandUse();
+ Value *FP = CB->getCalledOperand();
+
+ SmallSetVector<Function *, 4> AssumedCalleesNow;
+ bool AllCalleesKnownNow = AllCalleesKnown;
+
+ auto CheckPotentialCalleeUse = [&](Function &PotentialCallee,
+ bool &UsedAssumedInformation) {
+ const auto *GIAA = A.getAAFor<AAGlobalValueInfo>(
+ *this, IRPosition::value(PotentialCallee), DepClassTy::OPTIONAL);
+ if (!GIAA || GIAA->isPotentialUse(CalleeUse))
+ return true;
+ UsedAssumedInformation = !GIAA->isAtFixpoint();
+ return false;
+ };
+
+ auto AddPotentialCallees = [&]() {
+ for (auto *PotentialCallee : PotentialCallees) {
+ bool UsedAssumedInformation = false;
+ if (CheckPotentialCalleeUse(*PotentialCallee, UsedAssumedInformation))
+ AssumedCalleesNow.insert(PotentialCallee);
+ }
+ };
+
+ // Use simplification to find potential callees, if !callees was present,
+ // fallback to that set if necessary.
+ bool UsedAssumedInformation = false;
+ SmallVector<AA::ValueAndContext> Values;
+ if (!A.getAssumedSimplifiedValues(IRPosition::value(*FP), this, Values,
+ AA::ValueScope::AnyScope,
+ UsedAssumedInformation)) {
+ if (PotentialCallees.empty())
+ return indicatePessimisticFixpoint();
+ AddPotentialCallees();
+ }
+
+ // Try to find a reason for \p Fn not to be a potential callee. If none was
+ // found, add it to the assumed callees set.
+ auto CheckPotentialCallee = [&](Function &Fn) {
+ if (!PotentialCallees.empty() && !PotentialCallees.count(&Fn))
+ return false;
+
+ auto &CachedResult = FilterResults[&Fn];
+ if (CachedResult.has_value())
+ return CachedResult.value();
+
+ bool UsedAssumedInformation = false;
+ if (!CheckPotentialCalleeUse(Fn, UsedAssumedInformation)) {
+ if (!UsedAssumedInformation)
+ CachedResult = false;
+ return false;
+ }
+
+ int NumFnArgs = Fn.arg_size();
+ int NumCBArgs = CB->arg_size();
+
+ // Check if any excess argument (which we fill up with poison) is known to
+ // be UB on undef.
+ for (int I = NumCBArgs; I < NumFnArgs; ++I) {
+ bool IsKnown = false;
+ if (AA::hasAssumedIRAttr<Attribute::NoUndef>(
+ A, this, IRPosition::argument(*Fn.getArg(I)),
+ DepClassTy::OPTIONAL, IsKnown)) {
+ if (IsKnown)
+ CachedResult = false;
+ return false;
+ }
+ }
+
+ CachedResult = true;
+ return true;
+ };
+
+ // Check simplification result, prune known UB callees, also restrict it to
+ // the !callees set, if present.
+ for (auto &VAC : Values) {
+ if (isa<UndefValue>(VAC.getValue()))
+ continue;
+ if (isa<ConstantPointerNull>(VAC.getValue()) &&
+ VAC.getValue()->getType()->getPointerAddressSpace() == 0)
+ continue;
+ // TODO: Check for known UB, e.g., poison + noundef.
+ if (auto *VACFn = dyn_cast<Function>(VAC.getValue())) {
+ if (CheckPotentialCallee(*VACFn))
+ AssumedCalleesNow.insert(VACFn);
+ continue;
+ }
+ if (!PotentialCallees.empty()) {
+ AddPotentialCallees();
+ break;
+ }
+ AllCalleesKnownNow = false;
+ }
+
+ if (AssumedCalleesNow == AssumedCallees &&
+ AllCalleesKnown == AllCalleesKnownNow)
+ return ChangeStatus::UNCHANGED;
+
+ std::swap(AssumedCallees, AssumedCalleesNow);
+ AllCalleesKnown = AllCalleesKnownNow;
+ return ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+ // If we can't specialize at all, give up now.
+ if (!AllCalleesKnown && AssumedCallees.empty())
+ return ChangeStatus::UNCHANGED;
+
+ CallBase *CB = cast<CallBase>(getCtxI());
+ bool UsedAssumedInformation = false;
+ if (A.isAssumedDead(*CB, this, /*LivenessAA=*/nullptr,
+ UsedAssumedInformation))
+ return ChangeStatus::UNCHANGED;
+
+ ChangeStatus Changed = ChangeStatus::UNCHANGED;
+ Value *FP = CB->getCalledOperand();
+ if (FP->getType()->getPointerAddressSpace())
+ FP = new AddrSpaceCastInst(FP, PointerType::get(FP->getType(), 0),
+ FP->getName() + ".as0", CB);
+
+ bool CBIsVoid = CB->getType()->isVoidTy();
+ Instruction *IP = CB;
+ FunctionType *CSFT = CB->getFunctionType();
+ SmallVector<Value *> CSArgs(CB->arg_begin(), CB->arg_end());
+
+ // If we know all callees and there are none, the call site is (effectively)
+ // dead (or UB).
+ if (AssumedCallees.empty()) {
+ assert(AllCalleesKnown &&
+ "Expected all callees to be known if there are none.");
+ A.changeToUnreachableAfterManifest(CB);
+ return ChangeStatus::CHANGED;
+ }
+
+ // Special handling for the single callee case.
+ if (AllCalleesKnown && AssumedCallees.size() == 1) {
+ auto *NewCallee = AssumedCallees.front();
+ if (isLegalToPromote(*CB, NewCallee)) {
+ promoteCall(*CB, NewCallee, nullptr);
+ return ChangeStatus::CHANGED;
+ }
+ Instruction *NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee),
+ CSArgs, CB->getName(), CB);
+ if (!CBIsVoid)
+ A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewCall);
+ A.deleteAfterManifest(*CB);
+ return ChangeStatus::CHANGED;
+ }
+
+ // For each potential value we create a conditional
+ //
+ // ```
+ // if (ptr == value) value(args);
+ // else ...
+ // ```
+ //
+ bool SpecializedForAnyCallees = false;
+ bool SpecializedForAllCallees = AllCalleesKnown;
+ ICmpInst *LastCmp = nullptr;
+ SmallVector<Function *, 8> SkippedAssumedCallees;
+ SmallVector<std::pair<CallInst *, Instruction *>> NewCalls;
+ for (Function *NewCallee : AssumedCallees) {
+ if (!A.shouldSpecializeCallSiteForCallee(*this, *CB, *NewCallee)) {
+ SkippedAssumedCallees.push_back(NewCallee);
+ SpecializedForAllCallees = false;
+ continue;
+ }
+ SpecializedForAnyCallees = true;
+
+ LastCmp = new ICmpInst(IP, llvm::CmpInst::ICMP_EQ, FP, NewCallee);
+ Instruction *ThenTI =
+ SplitBlockAndInsertIfThen(LastCmp, IP, /* Unreachable */ false);
+ BasicBlock *CBBB = CB->getParent();
+ A.registerManifestAddedBasicBlock(*ThenTI->getParent());
+ A.registerManifestAddedBasicBlock(*CBBB);
+ auto *SplitTI = cast<BranchInst>(LastCmp->getNextNode());
+ BasicBlock *ElseBB;
+ if (IP == CB) {
+ ElseBB = BasicBlock::Create(ThenTI->getContext(), "",
+ ThenTI->getFunction(), CBBB);
+ A.registerManifestAddedBasicBlock(*ElseBB);
+ IP = BranchInst::Create(CBBB, ElseBB);
+ SplitTI->replaceUsesOfWith(CBBB, ElseBB);
+ } else {
+ ElseBB = IP->getParent();
+ ThenTI->replaceUsesOfWith(ElseBB, CBBB);
+ }
+ CastInst *RetBC = nullptr;
+ CallInst *NewCall = nullptr;
+ if (isLegalToPromote(*CB, NewCallee)) {
+ auto *CBClone = cast<CallBase>(CB->clone());
+ CBClone->insertBefore(ThenTI);
+ NewCall = &cast<CallInst>(promoteCall(*CBClone, NewCallee, &RetBC));
+ } else {
+ NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs,
+ CB->getName(), ThenTI);
+ }
+ NewCalls.push_back({NewCall, RetBC});
+ }
+
+ auto AttachCalleeMetadata = [&](CallBase &IndirectCB) {
+ if (!AllCalleesKnown)
+ return ChangeStatus::UNCHANGED;
+ MDBuilder MDB(IndirectCB.getContext());
+ MDNode *Callees = MDB.createCallees(SkippedAssumedCallees);
+ IndirectCB.setMetadata(LLVMContext::MD_callees, Callees);
+ return ChangeStatus::CHANGED;
+ };
+
+ if (!SpecializedForAnyCallees)
+ return AttachCalleeMetadata(*CB);
+
+ // Check if we need the fallback indirect call still.
+ if (SpecializedForAllCallees) {
+ LastCmp->replaceAllUsesWith(ConstantInt::getTrue(LastCmp->getContext()));
+ LastCmp->eraseFromParent();
+ new UnreachableInst(IP->getContext(), IP);
+ IP->eraseFromParent();
+ } else {
+ auto *CBClone = cast<CallInst>(CB->clone());
+ CBClone->setName(CB->getName());
+ CBClone->insertBefore(IP);
+ NewCalls.push_back({CBClone, nullptr});
+ AttachCalleeMetadata(*CBClone);
+ }
+
+ // Check if we need a PHI to merge the results.
+ if (!CBIsVoid) {
+ auto *PHI = PHINode::Create(CB->getType(), NewCalls.size(),
+ CB->getName() + ".phi",
+ &*CB->getParent()->getFirstInsertionPt());
+ for (auto &It : NewCalls) {
+ CallBase *NewCall = It.first;
+ Instruction *CallRet = It.second ? It.second : It.first;
+ if (CallRet->getType() == CB->getType())
+ PHI->addIncoming(CallRet, CallRet->getParent());
+ else if (NewCall->getType()->isVoidTy())
+ PHI->addIncoming(PoisonValue::get(CB->getType()),
+ NewCall->getParent());
+ else
+ llvm_unreachable("Call return should match or be void!");
+ }
+ A.changeAfterManifest(IRPosition::callsite_returned(*CB), *PHI);
+ }
+
+ A.deleteAfterManifest(*CB);
+ Changed = ChangeStatus::CHANGED;
+
+ return Changed;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr(Attributor *A) const override {
+ return std::string(AllCalleesKnown ? "eliminate" : "specialize") +
+ " indirect call site with " + std::to_string(AssumedCallees.size()) +
+ " functions";
+ }
+
+ void trackStatistics() const override {
+ if (AllCalleesKnown) {
+ STATS_DECLTRACK(
+ Eliminated, CallSites,
+ "Number of indirect call sites eliminated via specialization")
+ } else {
+ STATS_DECLTRACK(Specialized, CallSites,
+ "Number of indirect call sites specialized")
+ }
+ }
+
+ bool foreachCallee(function_ref<bool(Function *)> CB) const override {
+ return isValidState() && AllCalleesKnown && all_of(AssumedCallees, CB);
+ }
+
+private:
+ /// Map to remember filter results.
+ DenseMap<Function *, std::optional<bool>> FilterResults;
+
+ /// If the !callee metadata was present, this set will contain all potential
+ /// callees (superset).
+ SmallSetVector<Function *, 4> PotentialCallees;
+
+ /// This set contains all currently assumed calllees, which might grow over
+ /// time.
+ SmallSetVector<Function *, 4> AssumedCallees;
+
+ /// Flag to indicate if all possible callees are in the AssumedCallees set or
+ /// if there could be others.
+ bool AllCalleesKnown = true;
+};
+} // namespace
+
/// ------------------------ Address Space ------------------------------------
namespace {
struct AAAddressSpaceImpl : public AAAddressSpace {
@@ -11961,8 +12601,13 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
// CGSCC if the AA is run on CGSCC instead of the entire module.
if (!A.isRunOn(Inst->getFunction()))
return true;
- if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst))
+ if (isa<LoadInst>(Inst))
MakeChange(Inst, const_cast<Use &>(U));
+ if (isa<StoreInst>(Inst)) {
+ // We only make changes if the use is the pointer operand.
+ if (U.getOperandNo() == 1)
+ MakeChange(Inst, const_cast<Use &>(U));
+ }
return true;
};
@@ -12064,6 +12709,224 @@ struct AAAddressSpaceCallSiteArgument final : AAAddressSpaceImpl {
};
} // namespace
+/// ----------- Allocation Info ----------
+namespace {
+struct AAAllocationInfoImpl : public AAAllocationInfo {
+ AAAllocationInfoImpl(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfo(IRP, A) {}
+
+ std::optional<TypeSize> getAllocatedSize() const override {
+ assert(isValidState() && "the AA is invalid");
+ return AssumedAllocatedSize;
+ }
+
+ std::optional<TypeSize> findInitialAllocationSize(Instruction *I,
+ const DataLayout &DL) {
+
+ // TODO: implement case for malloc like instructions
+ switch (I->getOpcode()) {
+ case Instruction::Alloca: {
+ AllocaInst *AI = cast<AllocaInst>(I);
+ return AI->getAllocationSize(DL);
+ }
+ default:
+ return std::nullopt;
+ }
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+
+ const IRPosition &IRP = getIRPosition();
+ Instruction *I = IRP.getCtxI();
+
+ // TODO: update check for malloc like calls
+ if (!isa<AllocaInst>(I))
+ return indicatePessimisticFixpoint();
+
+ bool IsKnownNoCapture;
+ if (!AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, IRP, DepClassTy::OPTIONAL, IsKnownNoCapture))
+ return indicatePessimisticFixpoint();
+
+ const AAPointerInfo *PI =
+ A.getOrCreateAAFor<AAPointerInfo>(IRP, *this, DepClassTy::REQUIRED);
+
+ if (!PI)
+ return indicatePessimisticFixpoint();
+
+ if (!PI->getState().isValidState())
+ return indicatePessimisticFixpoint();
+
+ const DataLayout &DL = A.getDataLayout();
+ const auto AllocationSize = findInitialAllocationSize(I, DL);
+
+ // If allocation size is nullopt, we give up.
+ if (!AllocationSize)
+ return indicatePessimisticFixpoint();
+
+ // For zero sized allocations, we give up.
+ // Since we can't reduce further
+ if (*AllocationSize == 0)
+ return indicatePessimisticFixpoint();
+
+ int64_t BinSize = PI->numOffsetBins();
+
+ // TODO: implement for multiple bins
+ if (BinSize > 1)
+ return indicatePessimisticFixpoint();
+
+ if (BinSize == 0) {
+ auto NewAllocationSize = std::optional<TypeSize>(TypeSize(0, false));
+ if (!changeAllocationSize(NewAllocationSize))
+ return ChangeStatus::UNCHANGED;
+ return ChangeStatus::CHANGED;
+ }
+
+ // TODO: refactor this to be part of multiple bin case
+ const auto &It = PI->begin();
+
+ // TODO: handle if Offset is not zero
+ if (It->first.Offset != 0)
+ return indicatePessimisticFixpoint();
+
+ uint64_t SizeOfBin = It->first.Offset + It->first.Size;
+
+ if (SizeOfBin >= *AllocationSize)
+ return indicatePessimisticFixpoint();
+
+ auto NewAllocationSize =
+ std::optional<TypeSize>(TypeSize(SizeOfBin * 8, false));
+
+ if (!changeAllocationSize(NewAllocationSize))
+ return ChangeStatus::UNCHANGED;
+
+ return ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+
+ assert(isValidState() &&
+ "Manifest should only be called if the state is valid.");
+
+ Instruction *I = getIRPosition().getCtxI();
+
+ auto FixedAllocatedSizeInBits = getAllocatedSize()->getFixedValue();
+
+ unsigned long NumBytesToAllocate = (FixedAllocatedSizeInBits + 7) / 8;
+
+ switch (I->getOpcode()) {
+ // TODO: add case for malloc like calls
+ case Instruction::Alloca: {
+
+ AllocaInst *AI = cast<AllocaInst>(I);
+
+ Type *CharType = Type::getInt8Ty(I->getContext());
+
+ auto *NumBytesToValue =
+ ConstantInt::get(I->getContext(), APInt(32, NumBytesToAllocate));
+
+ AllocaInst *NewAllocaInst =
+ new AllocaInst(CharType, AI->getAddressSpace(), NumBytesToValue,
+ AI->getAlign(), AI->getName(), AI->getNextNode());
+
+ if (A.changeAfterManifest(IRPosition::inst(*AI), *NewAllocaInst))
+ return ChangeStatus::CHANGED;
+
+ break;
+ }
+ default:
+ break;
+ }
+
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr(Attributor *A) const override {
+ if (!isValidState())
+ return "allocationinfo(<invalid>)";
+ return "allocationinfo(" +
+ (AssumedAllocatedSize == HasNoAllocationSize
+ ? "none"
+ : std::to_string(AssumedAllocatedSize->getFixedValue())) +
+ ")";
+ }
+
+private:
+ std::optional<TypeSize> AssumedAllocatedSize = HasNoAllocationSize;
+
+ // Maintain the computed allocation size of the object.
+ // Returns (bool) weather the size of the allocation was modified or not.
+ bool changeAllocationSize(std::optional<TypeSize> Size) {
+ if (AssumedAllocatedSize == HasNoAllocationSize ||
+ AssumedAllocatedSize != Size) {
+ AssumedAllocatedSize = Size;
+ return true;
+ }
+ return false;
+ }
+};
+
+struct AAAllocationInfoFloating : AAAllocationInfoImpl {
+ AAAllocationInfoFloating(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfoImpl(IRP, A) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(allocationinfo);
+ }
+};
+
+struct AAAllocationInfoReturned : AAAllocationInfoImpl {
+ AAAllocationInfoReturned(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfoImpl(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+ // TODO: we don't rewrite function argument for now because it will need to
+ // rewrite the function signature and all call sites
+ (void)indicatePessimisticFixpoint();
+ }
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FNRET_ATTR(allocationinfo);
+ }
+};
+
+struct AAAllocationInfoCallSiteReturned : AAAllocationInfoImpl {
+ AAAllocationInfoCallSiteReturned(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfoImpl(IRP, A) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSRET_ATTR(allocationinfo);
+ }
+};
+
+struct AAAllocationInfoArgument : AAAllocationInfoImpl {
+ AAAllocationInfoArgument(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfoImpl(IRP, A) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_ARG_ATTR(allocationinfo);
+ }
+};
+
+struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl {
+ AAAllocationInfoCallSiteArgument(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfoImpl(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override {
+
+ (void)indicatePessimisticFixpoint();
+ }
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_CSARG_ATTR(allocationinfo);
+ }
+};
+} // namespace
+
const char AANoUnwind::ID = 0;
const char AANoSync::ID = 0;
const char AANoFree::ID = 0;
@@ -12097,6 +12960,10 @@ const char AAPointerInfo::ID = 0;
const char AAAssumptionInfo::ID = 0;
const char AAUnderlyingObjects::ID = 0;
const char AAAddressSpace::ID = 0;
+const char AAAllocationInfo::ID = 0;
+const char AAIndirectCallInfo::ID = 0;
+const char AAGlobalValueInfo::ID = 0;
+const char AADenormalFPMath::ID = 0;
// Macro magic to create the static generator function for attributes that
// follow the naming scheme.
@@ -12143,6 +13010,18 @@ const char AAAddressSpace::ID = 0;
return *AA; \
}
+#define CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(POS, SUFFIX, CLASS) \
+ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
+ CLASS *AA = nullptr; \
+ switch (IRP.getPositionKind()) { \
+ SWITCH_PK_CREATE(CLASS, IRP, POS, SUFFIX) \
+ default: \
+      llvm_unreachable("Cannot create " #CLASS " for position other than " #POS \
+ " position!"); \
+ } \
+ return *AA; \
+ }
+
#define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
CLASS *AA = nullptr; \
@@ -12215,17 +13094,24 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUndef)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFPClass)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPointerInfo)
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAddressSpace)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAllocationInfo)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects)
+CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_CALL_SITE, CallSite,
+ AAIndirectCallInfo)
+CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_FLOAT, Floating,
+ AAGlobalValueInfo)
+
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIntraFnReachability)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAInterFnReachability)
+CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADenormalFPMath)
CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
@@ -12234,5 +13120,6 @@ CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
#undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION
#undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION
#undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION
+#undef CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION
#undef SWITCH_PK_CREATE
#undef SWITCH_PK_INV
diff --git a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
index 93d15f59a036..5cc8258a495a 100644
--- a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
+++ b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
@@ -85,7 +85,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
LLVMContext &Ctx = M.getContext();
FunctionCallee C = M.getOrInsertFunction(
"__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),
- Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));
+ PointerType::getUnqual(Ctx), PointerType::getUnqual(Ctx));
Function *F = cast<Function>(C.getCallee());
// Take over the existing function. The frontend emits a weak stub so that the
// linker knows about the symbol; this pass replaces the function body.
@@ -110,9 +110,9 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F);
IRBuilder<> IRBFail(TrapBB);
- FunctionCallee CFICheckFailFn =
- M.getOrInsertFunction("__cfi_check_fail", Type::getVoidTy(Ctx),
- Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));
+ FunctionCallee CFICheckFailFn = M.getOrInsertFunction(
+ "__cfi_check_fail", Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx),
+ PointerType::getUnqual(Ctx));
IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr});
IRBFail.CreateBr(ExitBB);
diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
index 01834015f3fd..4f65748c19e6 100644
--- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -174,6 +174,7 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) {
NF->setComdat(F.getComdat());
F.getParent()->getFunctionList().insert(F.getIterator(), NF);
NF->takeName(&F);
+ NF->IsNewDbgInfoFormat = F.IsNewDbgInfoFormat;
// Loop over all the callers of the function, transforming the call sites
// to pass in a smaller number of arguments into the new function.
@@ -248,7 +249,7 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) {
NF->addMetadata(KindID, *Node);
// Fix up any BlockAddresses that refer to the function.
- F.replaceAllUsesWith(ConstantExpr::getBitCast(NF, F.getType()));
+ F.replaceAllUsesWith(NF);
// Delete the bitcast that we just created, so that NF does not
// appear to be address-taken.
NF->removeDeadConstantUsers();
@@ -877,6 +878,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) {
// it again.
F->getParent()->getFunctionList().insert(F->getIterator(), NF);
NF->takeName(F);
+ NF->IsNewDbgInfoFormat = F->IsNewDbgInfoFormat;
// Loop over all the callers of the function, transforming the call sites to
// pass in a smaller number of arguments into the new function.
diff --git a/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp
index fa56a5b564ae..48ef0772e800 100644
--- a/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp
+++ b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp
@@ -7,8 +7,6 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/EmbedBitcodePass.h"
-#include "llvm/Bitcode/BitcodeWriter.h"
-#include "llvm/Bitcode/BitcodeWriterPass.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
@@ -16,10 +14,8 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h"
-#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
-#include <memory>
#include <string>
using namespace llvm;
@@ -34,19 +30,9 @@ PreservedAnalyses EmbedBitcodePass::run(Module &M, ModuleAnalysisManager &AM) {
report_fatal_error(
"EmbedBitcode pass currently only supports ELF object format",
/*gen_crash_diag=*/false);
-
- std::unique_ptr<Module> NewModule = CloneModule(M);
- MPM.run(*NewModule, AM);
-
std::string Data;
raw_string_ostream OS(Data);
- if (IsThinLTO)
- ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(*NewModule, AM);
- else
- BitcodeWriterPass(OS, /*ShouldPreserveUseListOrder=*/false, EmitLTOSummary)
- .run(*NewModule, AM);
-
+ ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(M, AM);
embedBufferInModule(M, MemoryBufferRef(Data, "ModuleData"), ".llvm.lto");
-
return PreservedAnalyses::all();
}
diff --git a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp
index 74931e1032d1..9cf4e448c9b6 100644
--- a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp
@@ -11,38 +11,57 @@
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
#define DEBUG_TYPE "forceattrs"
-static cl::list<std::string>
- ForceAttributes("force-attribute", cl::Hidden,
- cl::desc("Add an attribute to a function. This should be a "
- "pair of 'function-name:attribute-name', for "
- "example -force-attribute=foo:noinline. This "
- "option can be specified multiple times."));
+static cl::list<std::string> ForceAttributes(
+ "force-attribute", cl::Hidden,
+ cl::desc(
+ "Add an attribute to a function. This can be a "
+ "pair of 'function-name:attribute-name', to apply an attribute to a "
+ "specific function. For "
+ "example -force-attribute=foo:noinline. Specifying only an attribute "
+ "will apply the attribute to every function in the module. This "
+ "option can be specified multiple times."));
static cl::list<std::string> ForceRemoveAttributes(
"force-remove-attribute", cl::Hidden,
- cl::desc("Remove an attribute from a function. This should be a "
- "pair of 'function-name:attribute-name', for "
- "example -force-remove-attribute=foo:noinline. This "
+ cl::desc("Remove an attribute from a function. This can be a "
+ "pair of 'function-name:attribute-name' to remove an attribute "
+ "from a specific function. For "
+ "example -force-remove-attribute=foo:noinline. Specifying only an "
+ "attribute will remove the attribute from all functions in the "
+ "module. This "
"option can be specified multiple times."));
+static cl::opt<std::string> CSVFilePath(
+ "forceattrs-csv-path", cl::Hidden,
+ cl::desc(
+ "Path to CSV file containing lines of function names and attributes to "
+ "add to them in the form of `f1,attr1` or `f2,attr2=str`."));
+
/// If F has any forced attributes given on the command line, add them.
/// If F has any forced remove attributes given on the command line, remove
/// them. When both force and force-remove are given to a function, the latter
/// takes precedence.
static void forceAttributes(Function &F) {
auto ParseFunctionAndAttr = [&](StringRef S) {
- auto Kind = Attribute::None;
- auto KV = StringRef(S).split(':');
- if (KV.first != F.getName())
- return Kind;
- Kind = Attribute::getAttrKindFromName(KV.second);
+ StringRef AttributeText;
+ if (S.contains(':')) {
+ auto KV = StringRef(S).split(':');
+ if (KV.first != F.getName())
+ return Attribute::None;
+ AttributeText = KV.second;
+ } else {
+ AttributeText = S;
+ }
+ auto Kind = Attribute::getAttrKindFromName(AttributeText);
if (Kind == Attribute::None || !Attribute::canUseAsFnAttr(Kind)) {
- LLVM_DEBUG(dbgs() << "ForcedAttribute: " << KV.second
+ LLVM_DEBUG(dbgs() << "ForcedAttribute: " << AttributeText
<< " unknown or not a function attribute!\n");
}
return Kind;
@@ -69,12 +88,52 @@ static bool hasForceAttributes() {
PreservedAnalyses ForceFunctionAttrsPass::run(Module &M,
ModuleAnalysisManager &) {
- if (!hasForceAttributes())
- return PreservedAnalyses::all();
-
- for (Function &F : M.functions())
- forceAttributes(F);
-
- // Just conservatively invalidate analyses, this isn't likely to be important.
- return PreservedAnalyses::none();
+ bool Changed = false;
+ if (!CSVFilePath.empty()) {
+ auto BufferOrError = MemoryBuffer::getFileOrSTDIN(CSVFilePath);
+ if (!BufferOrError)
+ report_fatal_error("Cannot open CSV file.");
+ StringRef Buffer = BufferOrError.get()->getBuffer();
+ auto MemoryBuffer = MemoryBuffer::getMemBuffer(Buffer);
+ line_iterator It(*MemoryBuffer);
+ for (; !It.is_at_end(); ++It) {
+ auto SplitPair = It->split(',');
+ if (SplitPair.second.empty())
+ continue;
+ Function *Func = M.getFunction(SplitPair.first);
+ if (Func) {
+ if (Func->isDeclaration())
+ continue;
+ auto SecondSplitPair = SplitPair.second.split('=');
+ if (!SecondSplitPair.second.empty()) {
+ Func->addFnAttr(SecondSplitPair.first, SecondSplitPair.second);
+ Changed = true;
+ } else {
+ auto AttrKind = Attribute::getAttrKindFromName(SplitPair.second);
+ if (AttrKind != Attribute::None &&
+ Attribute::canUseAsFnAttr(AttrKind)) {
+ // TODO: There could be string attributes without a value, we should
+ // support those, too.
+ Func->addFnAttr(AttrKind);
+ Changed = true;
+ } else
+ errs() << "Cannot add " << SplitPair.second
+ << " as an attribute name.\n";
+ }
+ } else {
+ errs() << "Function in CSV file at line " << It.line_number()
+ << " does not exist.\n";
+        // TODO: `report_fatal_error` at end of pass for missing functions.
+ continue;
+ }
+ }
+ }
+ if (hasForceAttributes()) {
+ for (Function &F : M.functions())
+ forceAttributes(F);
+ Changed = true;
+ }
+ // Just conservatively invalidate analyses if we've made any changes, this
+ // isn't likely to be important.
+ return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index 34299f9dbb23..7c277518b21d 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -110,6 +110,39 @@ using SCCNodeSet = SmallSetVector<Function *, 8>;
} // end anonymous namespace
+static void addLocAccess(MemoryEffects &ME, const MemoryLocation &Loc,
+ ModRefInfo MR, AAResults &AAR) {
+ // Ignore accesses to known-invariant or local memory.
+ MR &= AAR.getModRefInfoMask(Loc, /*IgnoreLocal=*/true);
+ if (isNoModRef(MR))
+ return;
+
+ const Value *UO = getUnderlyingObject(Loc.Ptr);
+ assert(!isa<AllocaInst>(UO) &&
+ "Should have been handled by getModRefInfoMask()");
+ if (isa<Argument>(UO)) {
+ ME |= MemoryEffects::argMemOnly(MR);
+ return;
+ }
+
+ // If it's not an identified object, it might be an argument.
+ if (!isIdentifiedObject(UO))
+ ME |= MemoryEffects::argMemOnly(MR);
+ ME |= MemoryEffects(IRMemLocation::Other, MR);
+}
+
+static void addArgLocs(MemoryEffects &ME, const CallBase *Call,
+ ModRefInfo ArgMR, AAResults &AAR) {
+ for (const Value *Arg : Call->args()) {
+ if (!Arg->getType()->isPtrOrPtrVectorTy())
+ continue;
+
+ addLocAccess(ME,
+ MemoryLocation::getBeforeOrAfter(Arg, Call->getAAMetadata()),
+ ArgMR, AAR);
+ }
+}
+
/// Returns the memory access attribute for function F using AAR for AA results,
/// where SCCNodes is the current SCC.
///
@@ -118,54 +151,48 @@ using SCCNodeSet = SmallSetVector<Function *, 8>;
/// result will be based only on AA results for the function declaration; it
/// will be assumed that some other (perhaps less optimized) version of the
/// function may be selected at link time.
-static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody,
- AAResults &AAR,
- const SCCNodeSet &SCCNodes) {
+///
+/// The return value is split into two parts: Memory effects that always apply,
+/// and additional memory effects that apply if any of the functions in the SCC
+/// can access argmem.
+static std::pair<MemoryEffects, MemoryEffects>
+checkFunctionMemoryAccess(Function &F, bool ThisBody, AAResults &AAR,
+ const SCCNodeSet &SCCNodes) {
MemoryEffects OrigME = AAR.getMemoryEffects(&F);
if (OrigME.doesNotAccessMemory())
// Already perfect!
- return OrigME;
+ return {OrigME, MemoryEffects::none()};
if (!ThisBody)
- return OrigME;
+ return {OrigME, MemoryEffects::none()};
MemoryEffects ME = MemoryEffects::none();
+ // Additional locations accessed if the SCC accesses argmem.
+ MemoryEffects RecursiveArgME = MemoryEffects::none();
+
// Inalloca and preallocated arguments are always clobbered by the call.
if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) ||
F.getAttributes().hasAttrSomewhere(Attribute::Preallocated))
ME |= MemoryEffects::argMemOnly(ModRefInfo::ModRef);
- auto AddLocAccess = [&](const MemoryLocation &Loc, ModRefInfo MR) {
- // Ignore accesses to known-invariant or local memory.
- MR &= AAR.getModRefInfoMask(Loc, /*IgnoreLocal=*/true);
- if (isNoModRef(MR))
- return;
-
- const Value *UO = getUnderlyingObject(Loc.Ptr);
- assert(!isa<AllocaInst>(UO) &&
- "Should have been handled by getModRefInfoMask()");
- if (isa<Argument>(UO)) {
- ME |= MemoryEffects::argMemOnly(MR);
- return;
- }
-
- // If it's not an identified object, it might be an argument.
- if (!isIdentifiedObject(UO))
- ME |= MemoryEffects::argMemOnly(MR);
- ME |= MemoryEffects(IRMemLocation::Other, MR);
- };
// Scan the function body for instructions that may read or write memory.
for (Instruction &I : instructions(F)) {
// Some instructions can be ignored even if they read or write memory.
// Detect these now, skipping to the next instruction if one is found.
if (auto *Call = dyn_cast<CallBase>(&I)) {
- // Ignore calls to functions in the same SCC, as long as the call sites
- // don't have operand bundles. Calls with operand bundles are allowed to
- // have memory effects not described by the memory effects of the call
- // target.
+ // We can optimistically ignore calls to functions in the same SCC, with
+ // two caveats:
+ // * Calls with operand bundles may have additional effects.
+ // * Argument memory accesses may imply additional effects depending on
+ // what the argument location is.
if (!Call->hasOperandBundles() && Call->getCalledFunction() &&
- SCCNodes.count(Call->getCalledFunction()))
+ SCCNodes.count(Call->getCalledFunction())) {
+ // Keep track of which additional locations are accessed if the SCC
+ // turns out to access argmem.
+ addArgLocs(RecursiveArgME, Call, ModRefInfo::ModRef, AAR);
continue;
+ }
+
MemoryEffects CallME = AAR.getMemoryEffects(Call);
// If the call doesn't access memory, we're done.
@@ -190,15 +217,8 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody,
// Check whether all pointer arguments point to local memory, and
// ignore calls that only access local memory.
ModRefInfo ArgMR = CallME.getModRef(IRMemLocation::ArgMem);
- if (ArgMR != ModRefInfo::NoModRef) {
- for (const Use &U : Call->args()) {
- const Value *Arg = U;
- if (!Arg->getType()->isPtrOrPtrVectorTy())
- continue;
-
- AddLocAccess(MemoryLocation::getBeforeOrAfter(Arg, I.getAAMetadata()), ArgMR);
- }
- }
+ if (ArgMR != ModRefInfo::NoModRef)
+ addArgLocs(ME, Call, ArgMR, AAR);
continue;
}
@@ -222,15 +242,15 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody,
if (I.isVolatile())
ME |= MemoryEffects::inaccessibleMemOnly(MR);
- AddLocAccess(*Loc, MR);
+ addLocAccess(ME, *Loc, MR, AAR);
}
- return OrigME & ME;
+ return {OrigME & ME, RecursiveArgME};
}
MemoryEffects llvm::computeFunctionBodyMemoryAccess(Function &F,
AAResults &AAR) {
- return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {});
+ return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}).first;
}
/// Deduce readonly/readnone/writeonly attributes for the SCC.
@@ -238,24 +258,37 @@ template <typename AARGetterT>
static void addMemoryAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter,
SmallSet<Function *, 8> &Changed) {
MemoryEffects ME = MemoryEffects::none();
+ MemoryEffects RecursiveArgME = MemoryEffects::none();
for (Function *F : SCCNodes) {
// Call the callable parameter to look up AA results for this function.
AAResults &AAR = AARGetter(*F);
// Non-exact function definitions may not be selected at link time, and an
// alternative version that writes to memory may be selected. See the
// comment on GlobalValue::isDefinitionExact for more details.
- ME |= checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes);
+ auto [FnME, FnRecursiveArgME] =
+ checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes);
+ ME |= FnME;
+ RecursiveArgME |= FnRecursiveArgME;
// Reached bottom of the lattice, we will not be able to improve the result.
if (ME == MemoryEffects::unknown())
return;
}
+ // If the SCC accesses argmem, add recursive accesses resulting from that.
+ ModRefInfo ArgMR = ME.getModRef(IRMemLocation::ArgMem);
+ if (ArgMR != ModRefInfo::NoModRef)
+ ME |= RecursiveArgME & MemoryEffects(ArgMR);
+
for (Function *F : SCCNodes) {
MemoryEffects OldME = F->getMemoryEffects();
MemoryEffects NewME = ME & OldME;
if (NewME != OldME) {
++NumMemoryAttr;
F->setMemoryEffects(NewME);
+ // Remove conflicting writable attributes.
+ if (!isModSet(NewME.getModRef(IRMemLocation::ArgMem)))
+ for (Argument &A : F->args())
+ A.removeAttr(Attribute::Writable);
Changed.insert(F);
}
}
@@ -625,7 +658,15 @@ determinePointerAccessAttrs(Argument *A,
// must be a data operand (e.g. argument or operand bundle)
const unsigned UseIndex = CB.getDataOperandNo(U);
- if (!CB.doesNotCapture(UseIndex)) {
+ // Some intrinsics (for instance ptrmask) do not capture their results,
+ // but return results thas alias their pointer argument, and thus should
+ // be handled like GEP or addrspacecast above.
+ if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
+ &CB, /*MustPreserveNullness=*/false)) {
+ for (Use &UU : CB.uses())
+ if (Visited.insert(&UU).second)
+ Worklist.push_back(&UU);
+ } else if (!CB.doesNotCapture(UseIndex)) {
if (!CB.onlyReadsMemory())
// If the callee can save a copy into other memory, then simply
// scanning uses of the call is insufficient. We have no way
@@ -639,7 +680,8 @@ determinePointerAccessAttrs(Argument *A,
Worklist.push_back(&UU);
}
- if (CB.doesNotAccessMemory())
+ ModRefInfo ArgMR = CB.getMemoryEffects().getModRef(IRMemLocation::ArgMem);
+ if (isNoModRef(ArgMR))
continue;
if (Function *F = CB.getCalledFunction())
@@ -654,9 +696,9 @@ determinePointerAccessAttrs(Argument *A,
// invokes with operand bundles.
if (CB.doesNotAccessMemory(UseIndex)) {
/* nop */
- } else if (CB.onlyReadsMemory() || CB.onlyReadsMemory(UseIndex)) {
+ } else if (!isModSet(ArgMR) || CB.onlyReadsMemory(UseIndex)) {
IsRead = true;
- } else if (CB.hasFnAttr(Attribute::WriteOnly) ||
+ } else if (!isRefSet(ArgMR) ||
CB.dataOperandHasImpliedAttr(UseIndex, Attribute::WriteOnly)) {
IsWrite = true;
} else {
@@ -810,6 +852,9 @@ static bool addAccessAttr(Argument *A, Attribute::AttrKind R) {
A->removeAttr(Attribute::WriteOnly);
A->removeAttr(Attribute::ReadOnly);
A->removeAttr(Attribute::ReadNone);
+ // Remove conflicting writable attribute.
+ if (R == Attribute::ReadNone || R == Attribute::ReadOnly)
+ A->removeAttr(Attribute::Writable);
A->addAttr(R);
if (R == Attribute::ReadOnly)
++NumReadOnlyArg;
@@ -1720,7 +1765,8 @@ static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) {
template <typename AARGetterT>
static SmallSet<Function *, 8>
-deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) {
+deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter,
+ bool ArgAttrsOnly) {
SCCNodesResult Nodes = createSCCNodeSet(Functions);
// Bail if the SCC only contains optnone functions.
@@ -1728,6 +1774,10 @@ deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) {
return {};
SmallSet<Function *, 8> Changed;
+ if (ArgAttrsOnly) {
+ addArgumentAttrs(Nodes.SCCNodes, Changed);
+ return Changed;
+ }
addArgumentReturnedAttrs(Nodes.SCCNodes, Changed);
addMemoryAttrs(Nodes.SCCNodes, AARGetter, Changed);
@@ -1762,10 +1812,13 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C,
LazyCallGraph &CG,
CGSCCUpdateResult &) {
// Skip non-recursive functions if requested.
+ // Only infer argument attributes for non-recursive functions, because
+ // it can affect optimization behavior in conjunction with noalias.
+ bool ArgAttrsOnly = false;
if (C.size() == 1 && SkipNonRecursive) {
LazyCallGraph::Node &N = *C.begin();
if (!N->lookup(N))
- return PreservedAnalyses::all();
+ ArgAttrsOnly = true;
}
FunctionAnalysisManager &FAM =
@@ -1782,7 +1835,8 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C,
Functions.push_back(&N.getFunction());
}
- auto ChangedFunctions = deriveAttrsInPostOrder(Functions, AARGetter);
+ auto ChangedFunctions =
+ deriveAttrsInPostOrder(Functions, AARGetter, ArgAttrsOnly);
if (ChangedFunctions.empty())
return PreservedAnalyses::all();
@@ -1818,7 +1872,7 @@ void PostOrderFunctionAttrsPass::printPipeline(
static_cast<PassInfoMixin<PostOrderFunctionAttrsPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
if (SkipNonRecursive)
- OS << "<skip-non-recursive>";
+ OS << "<skip-non-recursive-function-attrs>";
}
template <typename AARGetterT>
diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp
index f635b14cd2a9..9c546b531dff 100644
--- a/llvm/lib/Transforms/IPO/FunctionImport.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp
@@ -16,7 +16,6 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/IR/AutoUpgrade.h"
@@ -272,7 +271,7 @@ class GlobalsImporter final {
function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
IsPrevailing;
FunctionImporter::ImportMapTy &ImportList;
- StringMap<FunctionImporter::ExportSetTy> *const ExportLists;
+ DenseMap<StringRef, FunctionImporter::ExportSetTy> *const ExportLists;
bool shouldImportGlobal(const ValueInfo &VI) {
const auto &GVS = DefinedGVSummaries.find(VI.getGUID());
@@ -357,7 +356,7 @@ public:
function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
IsPrevailing,
FunctionImporter::ImportMapTy &ImportList,
- StringMap<FunctionImporter::ExportSetTy> *ExportLists)
+ DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists)
: Index(Index), DefinedGVSummaries(DefinedGVSummaries),
IsPrevailing(IsPrevailing), ImportList(ImportList),
ExportLists(ExportLists) {}
@@ -370,6 +369,29 @@ public:
}
};
+/// Determine the list of imports and exports for each module.
+class ModuleImportsManager final {
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ IsPrevailing;
+ const ModuleSummaryIndex &Index;
+ DenseMap<StringRef, FunctionImporter::ExportSetTy> *const ExportLists;
+
+public:
+ ModuleImportsManager(
+ function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+ IsPrevailing,
+ const ModuleSummaryIndex &Index,
+ DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists = nullptr)
+ : IsPrevailing(IsPrevailing), Index(Index), ExportLists(ExportLists) {}
+
+ /// Given the list of globals defined in a module, compute the list of imports
+ /// as well as the list of "exports", i.e. the list of symbols referenced from
+ /// another module (that may require promotion).
+ void computeImportForModule(const GVSummaryMapTy &DefinedGVSummaries,
+ StringRef ModName,
+ FunctionImporter::ImportMapTy &ImportList);
+};
+
static const char *
getFailureName(FunctionImporter::ImportFailureReason Reason) {
switch (Reason) {
@@ -403,7 +425,7 @@ static void computeImportForFunction(
isPrevailing,
SmallVectorImpl<EdgeInfo> &Worklist, GlobalsImporter &GVImporter,
FunctionImporter::ImportMapTy &ImportList,
- StringMap<FunctionImporter::ExportSetTy> *ExportLists,
+ DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists,
FunctionImporter::ImportThresholdsTy &ImportThresholds) {
GVImporter.onImportingSummary(Summary);
static int ImportCount = 0;
@@ -482,7 +504,7 @@ static void computeImportForFunction(
continue;
}
- FunctionImporter::ImportFailureReason Reason;
+ FunctionImporter::ImportFailureReason Reason{};
CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold,
Summary.modulePath(), Reason);
if (!CalleeSummary) {
@@ -567,20 +589,13 @@ static void computeImportForFunction(
}
}
-/// Given the list of globals defined in a module, compute the list of imports
-/// as well as the list of "exports", i.e. the list of symbols referenced from
-/// another module (that may require promotion).
-static void ComputeImportForModule(
- const GVSummaryMapTy &DefinedGVSummaries,
- function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
- isPrevailing,
- const ModuleSummaryIndex &Index, StringRef ModName,
- FunctionImporter::ImportMapTy &ImportList,
- StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) {
+void ModuleImportsManager::computeImportForModule(
+ const GVSummaryMapTy &DefinedGVSummaries, StringRef ModName,
+ FunctionImporter::ImportMapTy &ImportList) {
// Worklist contains the list of function imported in this module, for which
// we will analyse the callees and may import further down the callgraph.
SmallVector<EdgeInfo, 128> Worklist;
- GlobalsImporter GVI(Index, DefinedGVSummaries, isPrevailing, ImportList,
+ GlobalsImporter GVI(Index, DefinedGVSummaries, IsPrevailing, ImportList,
ExportLists);
FunctionImporter::ImportThresholdsTy ImportThresholds;
@@ -603,7 +618,7 @@ static void ComputeImportForModule(
continue;
LLVM_DEBUG(dbgs() << "Initialize import for " << VI << "\n");
computeImportForFunction(*FuncSummary, Index, ImportInstrLimit,
- DefinedGVSummaries, isPrevailing, Worklist, GVI,
+ DefinedGVSummaries, IsPrevailing, Worklist, GVI,
ImportList, ExportLists, ImportThresholds);
}
@@ -615,7 +630,7 @@ static void ComputeImportForModule(
if (auto *FS = dyn_cast<FunctionSummary>(Summary))
computeImportForFunction(*FS, Index, Threshold, DefinedGVSummaries,
- isPrevailing, Worklist, GVI, ImportList,
+ IsPrevailing, Worklist, GVI, ImportList,
ExportLists, ImportThresholds);
}
@@ -671,10 +686,10 @@ static unsigned numGlobalVarSummaries(const ModuleSummaryIndex &Index,
#endif
#ifndef NDEBUG
-static bool
-checkVariableImport(const ModuleSummaryIndex &Index,
- StringMap<FunctionImporter::ImportMapTy> &ImportLists,
- StringMap<FunctionImporter::ExportSetTy> &ExportLists) {
+static bool checkVariableImport(
+ const ModuleSummaryIndex &Index,
+ DenseMap<StringRef, FunctionImporter::ImportMapTy> &ImportLists,
+ DenseMap<StringRef, FunctionImporter::ExportSetTy> &ExportLists) {
DenseSet<GlobalValue::GUID> FlattenedImports;
@@ -702,7 +717,7 @@ checkVariableImport(const ModuleSummaryIndex &Index,
for (auto &ExportPerModule : ExportLists)
for (auto &VI : ExportPerModule.second)
if (!FlattenedImports.count(VI.getGUID()) &&
- IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first(), VI))
+ IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first, VI))
return false;
return true;
@@ -712,19 +727,19 @@ checkVariableImport(const ModuleSummaryIndex &Index,
/// Compute all the import and export for every module using the Index.
void llvm::ComputeCrossModuleImport(
const ModuleSummaryIndex &Index,
- const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries,
+ const DenseMap<StringRef, GVSummaryMapTy> &ModuleToDefinedGVSummaries,
function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
isPrevailing,
- StringMap<FunctionImporter::ImportMapTy> &ImportLists,
- StringMap<FunctionImporter::ExportSetTy> &ExportLists) {
+ DenseMap<StringRef, FunctionImporter::ImportMapTy> &ImportLists,
+ DenseMap<StringRef, FunctionImporter::ExportSetTy> &ExportLists) {
+ ModuleImportsManager MIS(isPrevailing, Index, &ExportLists);
// For each module that has function defined, compute the import/export lists.
for (const auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) {
- auto &ImportList = ImportLists[DefinedGVSummaries.first()];
+ auto &ImportList = ImportLists[DefinedGVSummaries.first];
LLVM_DEBUG(dbgs() << "Computing import for Module '"
- << DefinedGVSummaries.first() << "'\n");
- ComputeImportForModule(DefinedGVSummaries.second, isPrevailing, Index,
- DefinedGVSummaries.first(), ImportList,
- &ExportLists);
+ << DefinedGVSummaries.first << "'\n");
+ MIS.computeImportForModule(DefinedGVSummaries.second,
+ DefinedGVSummaries.first, ImportList);
}
// When computing imports we only added the variables and functions being
@@ -735,7 +750,7 @@ void llvm::ComputeCrossModuleImport(
for (auto &ELI : ExportLists) {
FunctionImporter::ExportSetTy NewExports;
const auto &DefinedGVSummaries =
- ModuleToDefinedGVSummaries.lookup(ELI.first());
+ ModuleToDefinedGVSummaries.lookup(ELI.first);
for (auto &EI : ELI.second) {
// Find the copy defined in the exporting module so that we can mark the
// values it references in that specific definition as exported.
@@ -783,7 +798,7 @@ void llvm::ComputeCrossModuleImport(
LLVM_DEBUG(dbgs() << "Import/Export lists for " << ImportLists.size()
<< " modules:\n");
for (auto &ModuleImports : ImportLists) {
- auto ModName = ModuleImports.first();
+ auto ModName = ModuleImports.first;
auto &Exports = ExportLists[ModName];
unsigned NumGVS = numGlobalVarSummaries(Index, Exports);
LLVM_DEBUG(dbgs() << "* Module " << ModName << " exports "
@@ -791,7 +806,7 @@ void llvm::ComputeCrossModuleImport(
<< " vars. Imports from " << ModuleImports.second.size()
<< " modules.\n");
for (auto &Src : ModuleImports.second) {
- auto SrcModName = Src.first();
+ auto SrcModName = Src.first;
unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second);
LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod
<< " functions imported from " << SrcModName << "\n");
@@ -809,7 +824,7 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index,
LLVM_DEBUG(dbgs() << "* Module " << ModulePath << " imports from "
<< ImportList.size() << " modules.\n");
for (auto &Src : ImportList) {
- auto SrcModName = Src.first();
+ auto SrcModName = Src.first;
unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second);
LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod
<< " functions imported from " << SrcModName << "\n");
@@ -819,8 +834,15 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index,
}
#endif
-/// Compute all the imports for the given module in the Index.
-void llvm::ComputeCrossModuleImportForModule(
+/// Compute all the imports for the given module using the Index.
+///
+/// \p isPrevailing is a callback that will be called with a global value's GUID
+/// and summary and should return whether the module corresponding to the
+/// summary contains the linker-prevailing copy of that value.
+///
+/// \p ImportList will be populated with a map that can be passed to
+/// FunctionImporter::importFunctions() above (see description there).
+static void ComputeCrossModuleImportForModuleForTest(
StringRef ModulePath,
function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
isPrevailing,
@@ -833,17 +855,20 @@ void llvm::ComputeCrossModuleImportForModule(
// Compute the import list for this module.
LLVM_DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n");
- ComputeImportForModule(FunctionSummaryMap, isPrevailing, Index, ModulePath,
- ImportList);
+ ModuleImportsManager MIS(isPrevailing, Index);
+ MIS.computeImportForModule(FunctionSummaryMap, ModulePath, ImportList);
#ifndef NDEBUG
dumpImportListForModule(Index, ModulePath, ImportList);
#endif
}
-// Mark all external summaries in Index for import into the given module.
-// Used for distributed builds using a distributed index.
-void llvm::ComputeCrossModuleImportForModuleFromIndex(
+/// Mark all external summaries in \p Index for import into the given module.
+/// Used for testing the case of distributed builds using a distributed index.
+///
+/// \p ImportList will be populated with a map that can be passed to
+/// FunctionImporter::importFunctions() above (see description there).
+static void ComputeCrossModuleImportForModuleFromIndexForTest(
StringRef ModulePath, const ModuleSummaryIndex &Index,
FunctionImporter::ImportMapTy &ImportList) {
for (const auto &GlobalList : Index) {
@@ -1041,7 +1066,7 @@ void llvm::computeDeadSymbolsWithConstProp(
/// \p ModulePath.
void llvm::gatherImportedSummariesForModule(
StringRef ModulePath,
- const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries,
+ const DenseMap<StringRef, GVSummaryMapTy> &ModuleToDefinedGVSummaries,
const FunctionImporter::ImportMapTy &ImportList,
std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) {
// Include all summaries from the importing module.
@@ -1049,10 +1074,9 @@ void llvm::gatherImportedSummariesForModule(
ModuleToDefinedGVSummaries.lookup(ModulePath);
// Include summaries for imports.
for (const auto &ILI : ImportList) {
- auto &SummariesForIndex =
- ModuleToSummariesForIndex[std::string(ILI.first())];
+ auto &SummariesForIndex = ModuleToSummariesForIndex[std::string(ILI.first)];
const auto &DefinedGVSummaries =
- ModuleToDefinedGVSummaries.lookup(ILI.first());
+ ModuleToDefinedGVSummaries.lookup(ILI.first);
for (const auto &GI : ILI.second) {
const auto &DS = DefinedGVSummaries.find(GI);
assert(DS != DefinedGVSummaries.end() &&
@@ -1298,7 +1322,7 @@ static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) {
// ensure all uses of alias instead use the new clone (casted if necessary).
NewFn->setLinkage(GA->getLinkage());
NewFn->setVisibility(GA->getVisibility());
- GA->replaceAllUsesWith(ConstantExpr::getBitCast(NewFn, GA->getType()));
+ GA->replaceAllUsesWith(NewFn);
NewFn->takeName(GA);
return NewFn;
}
@@ -1327,7 +1351,7 @@ Expected<bool> FunctionImporter::importFunctions(
// Do the actual import of functions now, one Module at a time
std::set<StringRef> ModuleNameOrderedList;
for (const auto &FunctionsToImportPerModule : ImportList) {
- ModuleNameOrderedList.insert(FunctionsToImportPerModule.first());
+ ModuleNameOrderedList.insert(FunctionsToImportPerModule.first);
}
for (const auto &Name : ModuleNameOrderedList) {
// Get the module for the import
@@ -1461,7 +1485,7 @@ Expected<bool> FunctionImporter::importFunctions(
return ImportedCount;
}
-static bool doImportingForModule(
+static bool doImportingForModuleForTest(
Module &M, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
isPrevailing) {
if (SummaryFile.empty())
@@ -1481,11 +1505,11 @@ static bool doImportingForModule(
// when testing distributed backend handling via the opt tool, when
// we have distributed indexes containing exactly the summaries to import.
if (ImportAllIndex)
- ComputeCrossModuleImportForModuleFromIndex(M.getModuleIdentifier(), *Index,
- ImportList);
+ ComputeCrossModuleImportForModuleFromIndexForTest(M.getModuleIdentifier(),
+ *Index, ImportList);
else
- ComputeCrossModuleImportForModule(M.getModuleIdentifier(), isPrevailing,
- *Index, ImportList);
+ ComputeCrossModuleImportForModuleForTest(M.getModuleIdentifier(),
+ isPrevailing, *Index, ImportList);
// Conservatively mark all internal values as promoted. This interface is
// only used when doing importing via the function importing pass. The pass
@@ -1533,7 +1557,7 @@ PreservedAnalyses FunctionImportPass::run(Module &M,
auto isPrevailing = [](GlobalValue::GUID, const GlobalValueSummary *) {
return true;
};
- if (!doImportingForModule(M, isPrevailing))
+ if (!doImportingForModuleForTest(M, isPrevailing))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 3d6c501e4596..a4c12006ee24 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -5,45 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-//
-// This specialises functions with constant parameters. Constant parameters
-// like function pointers and constant globals are propagated to the callee by
-// specializing the function. The main benefit of this pass at the moment is
-// that indirect calls are transformed into direct calls, which provides inline
-// opportunities that the inliner would not have been able to achieve. That's
-// why function specialisation is run before the inliner in the optimisation
-// pipeline; that is by design. Otherwise, we would only benefit from constant
-// passing, which is a valid use-case too, but hasn't been explored much in
-// terms of performance uplifts, cost-model and compile-time impact.
-//
-// Current limitations:
-// - It does not yet handle integer ranges. We do support "literal constants",
-// but that's off by default under an option.
-// - The cost-model could be further looked into (it mainly focuses on inlining
-// benefits),
-//
-// Ideas:
-// - With a function specialization attribute for arguments, we could have
-// a direct way to steer function specialization, avoiding the cost-model,
-// and thus control compile-times / code-size.
-//
-// Todos:
-// - Specializing recursive functions relies on running the transformation a
-// number of times, which is controlled by option
-// `func-specialization-max-iters`. Thus, increasing this value and the
-// number of iterations, will linearly increase the number of times recursive
-// functions get specialized, see also the discussion in
-// https://reviews.llvm.org/D106426 for details. Perhaps there is a
-// compile-time friendlier way to control/limit the number of specialisations
-// for recursive functions.
-// - Don't transform the function if function specialization does not trigger;
-// the SCCPSolver may make IR changes.
-//
-// References:
-// - 2021 LLVM Dev Mtg “Introducing function specialisation, and can we enable
-// it by default?”, https://www.youtube.com/watch?v=zJiCjeXgV5Q
-//
-//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/ADT/Statistic.h"
@@ -78,16 +39,47 @@ static cl::opt<unsigned> MaxClones(
"The maximum number of clones allowed for a single function "
"specialization"));
+static cl::opt<unsigned>
+ MaxDiscoveryIterations("funcspec-max-discovery-iterations", cl::init(100),
+ cl::Hidden,
+ cl::desc("The maximum number of iterations allowed "
+ "when searching for transitive "
+ "phis"));
+
static cl::opt<unsigned> MaxIncomingPhiValues(
- "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
- "The maximum number of incoming values a PHI node can have to be "
- "considered during the specialization bonus estimation"));
+ "funcspec-max-incoming-phi-values", cl::init(8), cl::Hidden,
+ cl::desc("The maximum number of incoming values a PHI node can have to be "
+ "considered during the specialization bonus estimation"));
+
+static cl::opt<unsigned> MaxBlockPredecessors(
+ "funcspec-max-block-predecessors", cl::init(2), cl::Hidden, cl::desc(
+ "The maximum number of predecessors a basic block can have to be "
+ "considered during the estimation of dead code"));
static cl::opt<unsigned> MinFunctionSize(
- "funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
+ "funcspec-min-function-size", cl::init(300), cl::Hidden, cl::desc(
"Don't specialize functions that have less than this number of "
"instructions"));
+static cl::opt<unsigned> MaxCodeSizeGrowth(
+ "funcspec-max-codesize-growth", cl::init(3), cl::Hidden, cl::desc(
+ "Maximum codesize growth allowed per function"));
+
+static cl::opt<unsigned> MinCodeSizeSavings(
+ "funcspec-min-codesize-savings", cl::init(20), cl::Hidden, cl::desc(
+ "Reject specializations whose codesize savings are less than this "
+ "much percent of the original function size"));
+
+static cl::opt<unsigned> MinLatencySavings(
+ "funcspec-min-latency-savings", cl::init(40), cl::Hidden,
+ cl::desc("Reject specializations whose latency savings are less than this "
+ "much percent of the original function size"));
+
+static cl::opt<unsigned> MinInliningBonus(
+ "funcspec-min-inlining-bonus", cl::init(300), cl::Hidden, cl::desc(
+ "Reject specializations whose inlining bonus is less than this "
+ "much percent of the original function size"));
+
static cl::opt<bool> SpecializeOnAddress(
"funcspec-on-address", cl::init(false), cl::Hidden, cl::desc(
"Enable function specialization on the address of global values"));
@@ -101,32 +93,32 @@ static cl::opt<bool> SpecializeLiteralConstant(
"Enable specialization of functions that take a literal constant as an "
"argument"));
-// Estimates the instruction cost of all the basic blocks in \p WorkList.
-// The successors of such blocks are added to the list as long as they are
-// executable and they have a unique predecessor. \p WorkList represents
-// the basic blocks of a specialization which become dead once we replace
-// instructions that are known to be constants. The aim here is to estimate
-// the combination of size and latency savings in comparison to the non
-// specialized version of the function.
-static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
- DenseSet<BasicBlock *> &DeadBlocks,
- ConstMap &KnownConstants, SCCPSolver &Solver,
- BlockFrequencyInfo &BFI,
- TargetTransformInfo &TTI) {
- Cost Bonus = 0;
+bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB, BasicBlock *Succ,
+ DenseSet<BasicBlock *> &DeadBlocks) {
+ unsigned I = 0;
+ return all_of(predecessors(Succ),
+ [&I, BB, Succ, &DeadBlocks] (BasicBlock *Pred) {
+ return I++ < MaxBlockPredecessors &&
+ (Pred == BB || Pred == Succ || DeadBlocks.contains(Pred));
+ });
+}
+// Estimates the codesize savings due to dead code after constant propagation.
+// \p WorkList represents the basic blocks of a specialization which will
+// eventually become dead once we replace instructions that are known to be
+// constants. The successors of such blocks are added to the list as long as
+// the \p Solver found they were executable prior to specialization, and only
+// if all their predecessors are dead.
+Cost InstCostVisitor::estimateBasicBlocks(
+ SmallVectorImpl<BasicBlock *> &WorkList) {
+ Cost CodeSize = 0;
// Accumulate the instruction cost of each basic block weighted by frequency.
while (!WorkList.empty()) {
BasicBlock *BB = WorkList.pop_back_val();
- uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() /
- BFI.getEntryFreq();
- if (!Weight)
- continue;
-
- // These blocks are considered dead as far as the InstCostVisitor is
- // concerned. They haven't been proven dead yet by the Solver, but
- // may become if we propagate the constant specialization arguments.
+ // These blocks are considered dead as far as the InstCostVisitor
+ // is concerned. They haven't been proven dead yet by the Solver,
+ // but may become if we propagate the specialization arguments.
if (!DeadBlocks.insert(BB).second)
continue;
@@ -139,74 +131,100 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
if (KnownConstants.contains(&I))
continue;
- Bonus += Weight *
- TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+ Cost C = TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
- LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
- << " after user " << I << "\n");
+ LLVM_DEBUG(dbgs() << "FnSpecialization: CodeSize " << C
+ << " for user " << I << "\n");
+ CodeSize += C;
}
// Keep adding dead successors to the list as long as they are
- // executable and they have a unique predecessor.
+ // executable and only reachable from dead blocks.
for (BasicBlock *SuccBB : successors(BB))
- if (Solver.isBlockExecutable(SuccBB) &&
- SuccBB->getUniquePredecessor() == BB)
+ if (isBlockExecutable(SuccBB) &&
+ canEliminateSuccessor(BB, SuccBB, DeadBlocks))
WorkList.push_back(SuccBB);
}
- return Bonus;
+ return CodeSize;
}
static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
if (auto *C = dyn_cast<Constant>(V))
return C;
- if (auto It = KnownConstants.find(V); It != KnownConstants.end())
- return It->second;
- return nullptr;
+ return KnownConstants.lookup(V);
}
-Cost InstCostVisitor::getBonusFromPendingPHIs() {
- Cost Bonus = 0;
+Bonus InstCostVisitor::getBonusFromPendingPHIs() {
+ Bonus B;
while (!PendingPHIs.empty()) {
Instruction *Phi = PendingPHIs.pop_back_val();
- Bonus += getUserBonus(Phi);
+ // The pending PHIs could have been proven dead by now.
+ if (isBlockExecutable(Phi->getParent()))
+ B += getUserBonus(Phi);
}
- return Bonus;
+ return B;
+}
+
+/// Compute a bonus for replacing argument \p A with constant \p C.
+Bonus InstCostVisitor::getSpecializationBonus(Argument *A, Constant *C) {
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
+ << C->getNameOrAsOperand() << "\n");
+ Bonus B;
+ for (auto *U : A->users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ if (isBlockExecutable(UI->getParent()))
+ B += getUserBonus(UI, A, C);
+
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated bonus {CodeSize = "
+ << B.CodeSize << ", Latency = " << B.Latency
+ << "} for argument " << *A << "\n");
+ return B;
}
-Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
+Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
+ // We have already propagated a constant for this user.
+ if (KnownConstants.contains(User))
+ return {0, 0};
+
// Cache the iterator before visiting.
LastVisited = Use ? KnownConstants.insert({Use, C}).first
: KnownConstants.end();
- if (auto *I = dyn_cast<SwitchInst>(User))
- return estimateSwitchInst(*I);
-
- if (auto *I = dyn_cast<BranchInst>(User))
- return estimateBranchInst(*I);
-
- C = visit(*User);
- if (!C)
- return 0;
+ Cost CodeSize = 0;
+ if (auto *I = dyn_cast<SwitchInst>(User)) {
+ CodeSize = estimateSwitchInst(*I);
+ } else if (auto *I = dyn_cast<BranchInst>(User)) {
+ CodeSize = estimateBranchInst(*I);
+ } else {
+ C = visit(*User);
+ if (!C)
+ return {0, 0};
+ }
+ // Even though it doesn't make sense to bind switch and branch instructions
+ // with a constant, unlike any other instruction type, it prevents estimating
+ // their bonus multiple times.
KnownConstants.insert({User, C});
+ CodeSize += TTI.getInstructionCost(User, TargetTransformInfo::TCK_CodeSize);
+
uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() /
- BFI.getEntryFreq();
- if (!Weight)
- return 0;
+ BFI.getEntryFreq().getFrequency();
- Cost Bonus = Weight *
- TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency);
+ Cost Latency = Weight *
+ TTI.getInstructionCost(User, TargetTransformInfo::TCK_Latency);
- LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
- << " for user " << *User << "\n");
+ LLVM_DEBUG(dbgs() << "FnSpecialization: {CodeSize = " << CodeSize
+ << ", Latency = " << Latency << "} for user "
+ << *User << "\n");
+ Bonus B(CodeSize, Latency);
for (auto *U : User->users())
if (auto *UI = dyn_cast<Instruction>(U))
- if (UI != User && Solver.isBlockExecutable(UI->getParent()))
- Bonus += getUserBonus(UI, User, C);
+ if (UI != User && isBlockExecutable(UI->getParent()))
+ B += getUserBonus(UI, User, C);
- return Bonus;
+ return B;
}
Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
@@ -226,14 +244,12 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
SmallVector<BasicBlock *> WorkList;
for (const auto &Case : I.cases()) {
BasicBlock *BB = Case.getCaseSuccessor();
- if (BB == Succ || !Solver.isBlockExecutable(BB) ||
- BB->getUniquePredecessor() != I.getParent())
- continue;
- WorkList.push_back(BB);
+ if (BB != Succ && isBlockExecutable(BB) &&
+ canEliminateSuccessor(I.getParent(), BB, DeadBlocks))
+ WorkList.push_back(BB);
}
- return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
- TTI);
+ return estimateBasicBlocks(WorkList);
}
Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
@@ -246,12 +262,55 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
// Initialize the worklist with the dead successor as long as
// it is executable and has a unique predecessor.
SmallVector<BasicBlock *> WorkList;
- if (Solver.isBlockExecutable(Succ) &&
- Succ->getUniquePredecessor() == I.getParent())
+ if (isBlockExecutable(Succ) &&
+ canEliminateSuccessor(I.getParent(), Succ, DeadBlocks))
WorkList.push_back(Succ);
- return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
- TTI);
+ return estimateBasicBlocks(WorkList);
+}
+
+bool InstCostVisitor::discoverTransitivelyIncomingValues(
+ Constant *Const, PHINode *Root, DenseSet<PHINode *> &TransitivePHIs) {
+
+ SmallVector<PHINode *, 64> WorkList;
+ WorkList.push_back(Root);
+ unsigned Iter = 0;
+
+ while (!WorkList.empty()) {
+ PHINode *PN = WorkList.pop_back_val();
+
+ if (++Iter > MaxDiscoveryIterations ||
+ PN->getNumIncomingValues() > MaxIncomingPhiValues)
+ return false;
+
+ if (!TransitivePHIs.insert(PN).second)
+ continue;
+
+ for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
+ Value *V = PN->getIncomingValue(I);
+
+ // Disregard self-references and dead incoming values.
+ if (auto *Inst = dyn_cast<Instruction>(V))
+ if (Inst == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
+ continue;
+
+ if (Constant *C = findConstantFor(V, KnownConstants)) {
+ // Not all incoming values are the same constant. Bail immediately.
+ if (C != Const)
+ return false;
+ continue;
+ }
+
+ if (auto *Phi = dyn_cast<PHINode>(V)) {
+ WorkList.push_back(Phi);
+ continue;
+ }
+
+ // We can't reason about anything else.
+ return false;
+ }
+ }
+ return true;
}
Constant *InstCostVisitor::visitPHINode(PHINode &I) {
@@ -260,23 +319,52 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
bool Inserted = VisitedPHIs.insert(&I).second;
Constant *Const = nullptr;
+ bool HaveSeenIncomingPHI = false;
for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
Value *V = I.getIncomingValue(Idx);
+
+ // Disregard self-references and dead incoming values.
if (auto *Inst = dyn_cast<Instruction>(V))
if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
continue;
- Constant *C = findConstantFor(V, KnownConstants);
- if (!C) {
- if (Inserted)
- PendingPHIs.push_back(&I);
- return nullptr;
+
+ if (Constant *C = findConstantFor(V, KnownConstants)) {
+ if (!Const)
+ Const = C;
+ // Not all incoming values are the same constant. Bail immediately.
+ if (C != Const)
+ return nullptr;
+ continue;
}
- if (!Const)
- Const = C;
- else if (C != Const)
+
+ if (Inserted) {
+ // First time we are seeing this phi. We will retry later, after
+ // all the constant arguments have been propagated. Bail for now.
+ PendingPHIs.push_back(&I);
return nullptr;
+ }
+
+ if (isa<PHINode>(V)) {
+ // Perhaps it is a Transitive Phi. We will confirm later.
+ HaveSeenIncomingPHI = true;
+ continue;
+ }
+
+ // We can't reason about anything else.
+ return nullptr;
}
+
+ if (!Const)
+ return nullptr;
+
+ if (!HaveSeenIncomingPHI)
+ return Const;
+
+ DenseSet<PHINode *> TransitivePHIs;
+ if (!discoverTransitivelyIncomingValues(Const, &I, TransitivePHIs))
+ return nullptr;
+
return Const;
}
@@ -479,10 +567,7 @@ void FunctionSpecializer::promoteConstantStackValues(Function *F) {
Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
GlobalValue::InternalLinkage, ConstVal,
- "funcspec.arg");
- if (ArgOpType != ConstVal->getType())
- GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
-
+ "specialized.arg." + Twine(++NGlobals));
Call->setArgOperand(Idx, GV);
}
}
@@ -572,13 +657,18 @@ bool FunctionSpecializer::run() {
if (!Inserted && !Metrics.isRecursive && !SpecializeLiteralConstant)
continue;
+ int64_t Sz = *Metrics.NumInsts.getValue();
+ assert(Sz > 0 && "CodeSize should be positive");
+ // It is safe to down cast from int64_t, NumInsts is always positive.
+ unsigned FuncSize = static_cast<unsigned>(Sz);
+
LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
- << F.getName() << " is " << Metrics.NumInsts << "\n");
+ << F.getName() << " is " << FuncSize << "\n");
if (Inserted && Metrics.isRecursive)
promoteConstantStackValues(&F);
- if (!findSpecializations(&F, Metrics.NumInsts, AllSpecs, SM)) {
+ if (!findSpecializations(&F, FuncSize, AllSpecs, SM)) {
LLVM_DEBUG(
dbgs() << "FnSpecialization: No possible specializations found for "
<< F.getName() << "\n");
@@ -706,14 +796,15 @@ void FunctionSpecializer::removeDeadFunctions() {
/// Clone the function \p F and remove the ssa_copy intrinsics added by
/// the SCCPSolver in the cloned version.
-static Function *cloneCandidateFunction(Function *F) {
+static Function *cloneCandidateFunction(Function *F, unsigned NSpecs) {
ValueToValueMapTy Mappings;
Function *Clone = CloneFunction(F, Mappings);
+ Clone->setName(F->getName() + ".specialized." + Twine(NSpecs));
removeSSACopy(*Clone);
return Clone;
}
-bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
+bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
SmallVectorImpl<Spec> &AllSpecs,
SpecMap &SM) {
// A mapping from a specialisation signature to the index of the respective
@@ -779,20 +870,48 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
AllSpecs[Index].CallSites.push_back(&CS);
} else {
// Calculate the specialisation gain.
- Cost Score = 0;
+ Bonus B;
+ unsigned Score = 0;
InstCostVisitor Visitor = getInstCostVisitorFor(F);
- for (ArgInfo &A : S.Args)
- Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
- Score += Visitor.getBonusFromPendingPHIs();
+ for (ArgInfo &A : S.Args) {
+ B += Visitor.getSpecializationBonus(A.Formal, A.Actual);
+ Score += getInliningBonus(A.Formal, A.Actual);
+ }
+ B += Visitor.getBonusFromPendingPHIs();
+
- LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
- << Score << "\n");
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization bonus {CodeSize = "
+ << B.CodeSize << ", Latency = " << B.Latency
+ << ", Inlining = " << Score << "}\n");
+
+ FunctionGrowth[F] += FuncSize - B.CodeSize;
+
+ auto IsProfitable = [](Bonus &B, unsigned Score, unsigned FuncSize,
+ unsigned FuncGrowth) -> bool {
+ // No check required.
+ if (ForceSpecialization)
+ return true;
+ // Minimum inlining bonus.
+ if (Score > MinInliningBonus * FuncSize / 100)
+ return true;
+ // Minimum codesize savings.
+ if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100)
+ return false;
+ // Minimum latency savings.
+ if (B.Latency < MinLatencySavings * FuncSize / 100)
+ return false;
+ // Maximum codesize growth.
+ if (FuncGrowth / FuncSize > MaxCodeSizeGrowth)
+ return false;
+ return true;
+ };
// Discard unprofitable specialisations.
- if (!ForceSpecialization && Score <= SpecCost)
+ if (!IsProfitable(B, Score, FuncSize, FunctionGrowth[F]))
continue;
// Create a new specialisation entry.
+ Score += std::max(B.CodeSize, B.Latency);
auto &Spec = AllSpecs.emplace_back(F, S, Score);
if (CS.getFunction() != F)
Spec.CallSites.push_back(&CS);
@@ -838,7 +957,7 @@ bool FunctionSpecializer::isCandidateFunction(Function *F) {
Function *FunctionSpecializer::createSpecialization(Function *F,
const SpecSig &S) {
- Function *Clone = cloneCandidateFunction(F);
+ Function *Clone = cloneCandidateFunction(F, Specializations.size() + 1);
// The original function does not neccessarily have internal linkage, but the
// clone must.
@@ -859,30 +978,14 @@ Function *FunctionSpecializer::createSpecialization(Function *F,
return Clone;
}
-/// Compute a bonus for replacing argument \p A with constant \p C.
-Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
- InstCostVisitor &Visitor) {
- LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
- << C->getNameOrAsOperand() << "\n");
-
- Cost TotalCost = 0;
- for (auto *U : A->users())
- if (auto *UI = dyn_cast<Instruction>(U))
- if (Solver.isBlockExecutable(UI->getParent()))
- TotalCost += Visitor.getUserBonus(UI, A, C);
-
- LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus "
- << TotalCost << " for argument " << *A << "\n");
-
- // The below heuristic is only concerned with exposing inlining
- // opportunities via indirect call promotion. If the argument is not a
- // (potentially casted) function pointer, give up.
- //
- // TODO: Perhaps we should consider checking such inlining opportunities
- // while traversing the users of the specialization arguments ?
+/// Compute the inlining bonus for replacing argument \p A with constant \p C.
+/// The below heuristic is only concerned with exposing inlining
+/// opportunities via indirect call promotion. If the argument is not a
+/// (potentially casted) function pointer, give up.
+unsigned FunctionSpecializer::getInliningBonus(Argument *A, Constant *C) {
Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts());
if (!CalledFunction)
- return TotalCost;
+ return 0;
// Get TTI for the called function (used for the inline cost).
auto &CalleeTTI = (GetTTI)(*CalledFunction);
@@ -892,7 +995,7 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
// calls to be promoted to direct calls. If the indirect call promotion
// would likely enable the called function to be inlined, specializing is a
// good idea.
- int Bonus = 0;
+ int InliningBonus = 0;
for (User *U : A->users()) {
if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
continue;
@@ -919,15 +1022,15 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
// We clamp the bonus for this call to be between zero and the default
// threshold.
if (IC.isAlways())
- Bonus += Params.DefaultThreshold;
+ InliningBonus += Params.DefaultThreshold;
else if (IC.isVariable() && IC.getCostDelta() > 0)
- Bonus += IC.getCostDelta();
+ InliningBonus += IC.getCostDelta();
- LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << InliningBonus
<< " for user " << *U << "\n");
}
- return TotalCost + Bonus;
+ return InliningBonus > 0 ? static_cast<unsigned>(InliningBonus) : 0;
}
/// Determine if it is possible to specialise the function for constant values
diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
index 1ccc523ead8a..951372adcfa9 100644
--- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
@@ -17,7 +17,6 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
@@ -390,7 +389,7 @@ static bool collectSRATypes(DenseMap<uint64_t, GlobalPart> &Parts,
}
// Scalable types not currently supported.
- if (isa<ScalableVectorType>(Ty))
+ if (Ty->isScalableTy())
return false;
auto IsStored = [](Value *V, Constant *Initializer) {
@@ -930,25 +929,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI,
}
// Update users of the allocation to use the new global instead.
- BitCastInst *TheBC = nullptr;
- while (!CI->use_empty()) {
- Instruction *User = cast<Instruction>(CI->user_back());
- if (BitCastInst *BCI = dyn_cast<BitCastInst>(User)) {
- if (BCI->getType() == NewGV->getType()) {
- BCI->replaceAllUsesWith(NewGV);
- BCI->eraseFromParent();
- } else {
- BCI->setOperand(0, NewGV);
- }
- } else {
- if (!TheBC)
- TheBC = new BitCastInst(NewGV, CI->getType(), "newgv", CI);
- User->replaceUsesOfWith(CI, TheBC);
- }
- }
-
- SmallSetVector<Constant *, 1> RepValues;
- RepValues.insert(NewGV);
+ CI->replaceAllUsesWith(NewGV);
// If there is a comparison against null, we will insert a global bool to
// keep track of whether the global was initialized yet or not.
@@ -980,9 +961,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI,
Use &LoadUse = *LI->use_begin();
ICmpInst *ICI = dyn_cast<ICmpInst>(LoadUse.getUser());
if (!ICI) {
- auto *CE = ConstantExpr::getBitCast(NewGV, LI->getType());
- RepValues.insert(CE);
- LoadUse.set(CE);
+ LoadUse.set(NewGV);
continue;
}
@@ -1028,8 +1007,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI,
// To further other optimizations, loop over all users of NewGV and try to
// constant prop them. This will promote GEP instructions with constant
// indices into GEP constant-exprs, which will allow global-opt to hack on it.
- for (auto *CE : RepValues)
- ConstantPropUsersOf(CE, DL, TLI);
+ ConstantPropUsersOf(NewGV, DL, TLI);
return NewGV;
}
@@ -1474,7 +1452,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS,
if (!GS.HasMultipleAccessingFunctions &&
GS.AccessingFunction &&
GV->getValueType()->isSingleValueType() &&
- GV->getType()->getAddressSpace() == 0 &&
+ GV->getType()->getAddressSpace() == DL.getAllocaAddrSpace() &&
!GV->isExternallyInitialized() &&
GS.AccessingFunction->doesNotRecurse() &&
isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV,
@@ -1584,7 +1562,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS,
GV->getAddressSpace());
NGV->takeName(GV);
NGV->copyAttributesFrom(GV);
- GV->replaceAllUsesWith(ConstantExpr::getBitCast(NGV, GV->getType()));
+ GV->replaceAllUsesWith(NGV);
GV->eraseFromParent();
GV = NGV;
}
@@ -1635,7 +1613,7 @@ processGlobal(GlobalValue &GV,
function_ref<TargetTransformInfo &(Function &)> GetTTI,
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
function_ref<DominatorTree &(Function &)> LookupDomTree) {
- if (GV.getName().startswith("llvm."))
+ if (GV.getName().starts_with("llvm."))
return false;
GlobalStatus GS;
@@ -1701,13 +1679,16 @@ static void RemoveAttribute(Function *F, Attribute::AttrKind A) {
/// idea here is that we don't want to mess with the convention if the user
/// explicitly requested something with performance implications like coldcc,
/// GHC, or anyregcc.
-static bool hasChangeableCC(Function *F) {
+static bool hasChangeableCCImpl(Function *F) {
CallingConv::ID CC = F->getCallingConv();
// FIXME: Is it worth transforming x86_stdcallcc and x86_fastcallcc?
if (CC != CallingConv::C && CC != CallingConv::X86_ThisCall)
return false;
+ if (F->isVarArg())
+ return false;
+
// FIXME: Change CC for the whole chain of musttail calls when possible.
//
// Can't change CC of the function that either has musttail calls, or is a
@@ -1727,7 +1708,16 @@ static bool hasChangeableCC(Function *F) {
if (BB.getTerminatingMustTailCall())
return false;
- return true;
+ return !F->hasAddressTaken();
+}
+
+using ChangeableCCCacheTy = SmallDenseMap<Function *, bool, 8>;
+static bool hasChangeableCC(Function *F,
+ ChangeableCCCacheTy &ChangeableCCCache) {
+ auto Res = ChangeableCCCache.try_emplace(F, false);
+ if (Res.second)
+ Res.first->second = hasChangeableCCImpl(F);
+ return Res.first->second;
}
/// Return true if the block containing the call site has a BlockFrequency of
@@ -1781,7 +1771,8 @@ static void changeCallSitesToColdCC(Function *F) {
// coldcc calling convention.
static bool
hasOnlyColdCalls(Function &F,
- function_ref<BlockFrequencyInfo &(Function &)> GetBFI) {
+ function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
+ ChangeableCCCacheTy &ChangeableCCCache) {
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
if (CallInst *CI = dyn_cast<CallInst>(&I)) {
@@ -1800,8 +1791,7 @@ hasOnlyColdCalls(Function &F,
if (!CalledFn->hasLocalLinkage())
return false;
// Check if it's valid to use coldcc calling convention.
- if (!hasChangeableCC(CalledFn) || CalledFn->isVarArg() ||
- CalledFn->hasAddressTaken())
+ if (!hasChangeableCC(CalledFn, ChangeableCCCache))
return false;
BlockFrequencyInfo &CallerBFI = GetBFI(F);
if (!isColdCallSite(*CI, CallerBFI))
@@ -1873,12 +1863,9 @@ static void RemovePreallocated(Function *F) {
CB->eraseFromParent();
Builder.SetInsertPoint(PreallocatedSetup);
- auto *StackSave =
- Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::stacksave));
-
+ auto *StackSave = Builder.CreateStackSave();
Builder.SetInsertPoint(NewCB->getNextNonDebugInstruction());
- Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::stackrestore),
- StackSave);
+ Builder.CreateStackRestore(StackSave);
// Replace @llvm.call.preallocated.arg() with alloca.
// Cannot modify users() while iterating over it, so make a copy.
@@ -1905,10 +1892,8 @@ static void RemovePreallocated(Function *F) {
Builder.SetInsertPoint(InsertBefore);
auto *Alloca =
Builder.CreateAlloca(ArgType, AddressSpace, nullptr, "paarg");
- auto *BitCast = Builder.CreateBitCast(
- Alloca, Type::getInt8PtrTy(M->getContext()), UseCall->getName());
- ArgAllocas[AllocArgIndex] = BitCast;
- AllocaReplacement = BitCast;
+ ArgAllocas[AllocArgIndex] = Alloca;
+ AllocaReplacement = Alloca;
}
UseCall->replaceAllUsesWith(AllocaReplacement);
@@ -1931,9 +1916,10 @@ OptimizeFunctions(Module &M,
bool Changed = false;
+ ChangeableCCCacheTy ChangeableCCCache;
std::vector<Function *> AllCallsCold;
for (Function &F : llvm::make_early_inc_range(M))
- if (hasOnlyColdCalls(F, GetBFI))
+ if (hasOnlyColdCalls(F, GetBFI, ChangeableCCCache))
AllCallsCold.push_back(&F);
// Optimize functions.
@@ -1995,7 +1981,7 @@ OptimizeFunctions(Module &M,
continue;
}
- if (hasChangeableCC(&F) && !F.isVarArg() && !F.hasAddressTaken()) {
+ if (hasChangeableCC(&F, ChangeableCCCache)) {
NumInternalFunc++;
TargetTransformInfo &TTI = GetTTI(F);
// Change the calling convention to coldcc if either stress testing is
@@ -2005,6 +1991,7 @@ OptimizeFunctions(Module &M,
if (EnableColdCCStressTest ||
(TTI.useColdCCForColdCall(F) &&
isValidCandidateForColdCC(F, GetBFI, AllCallsCold))) {
+ ChangeableCCCache.erase(&F);
F.setCallingConv(CallingConv::Cold);
changeCallSitesToColdCC(&F);
Changed = true;
@@ -2012,7 +1999,7 @@ OptimizeFunctions(Module &M,
}
}
- if (hasChangeableCC(&F) && !F.isVarArg() && !F.hasAddressTaken()) {
+ if (hasChangeableCC(&F, ChangeableCCCache)) {
// If this function has a calling convention worth changing, is not a
// varargs function, and is only called directly, promote it to use the
// Fast calling convention.
@@ -2117,19 +2104,18 @@ static void setUsedInitializer(GlobalVariable &V,
const auto *VEPT = cast<PointerType>(VAT->getArrayElementType());
// Type of pointer to the array of pointers.
- PointerType *Int8PtrTy =
- Type::getInt8PtrTy(V.getContext(), VEPT->getAddressSpace());
+ PointerType *PtrTy =
+ PointerType::get(V.getContext(), VEPT->getAddressSpace());
SmallVector<Constant *, 8> UsedArray;
for (GlobalValue *GV : Init) {
- Constant *Cast =
- ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy);
+ Constant *Cast = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, PtrTy);
UsedArray.push_back(Cast);
}
// Sort to get deterministic order.
array_pod_sort(UsedArray.begin(), UsedArray.end(), compareNames);
- ArrayType *ATy = ArrayType::get(Int8PtrTy, UsedArray.size());
+ ArrayType *ATy = ArrayType::get(PtrTy, UsedArray.size());
Module *M = V.getParent();
V.removeFromParent();
@@ -2299,7 +2285,7 @@ OptimizeGlobalAliases(Module &M,
if (!hasUsesToReplace(J, Used, RenameTarget))
continue;
- J.replaceAllUsesWith(ConstantExpr::getBitCast(Aliasee, J.getType()));
+ J.replaceAllUsesWith(Aliasee);
++NumAliasesResolved;
Changed = true;
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 599ace9ca79f..fabb3c5fb921 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -44,6 +44,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/CommandLine.h"
@@ -86,6 +87,11 @@ static cl::opt<int> MaxParametersForSplit(
"hotcoldsplit-max-params", cl::init(4), cl::Hidden,
cl::desc("Maximum number of parameters for a split function"));
+static cl::opt<int> ColdBranchProbDenom(
+ "hotcoldsplit-cold-probability-denom", cl::init(100), cl::Hidden,
+ cl::desc("Divisor of cold branch probability."
+ "BranchProbability = 1/ColdBranchProbDenom"));
+
namespace {
// Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify
// this function unless you modify the MBB version as well.
@@ -102,6 +108,32 @@ bool blockEndsInUnreachable(const BasicBlock &BB) {
return !(isa<ReturnInst>(I) || isa<IndirectBrInst>(I));
}
+void analyzeProfMetadata(BasicBlock *BB,
+ BranchProbability ColdProbThresh,
+ SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks) {
+ // TODO: Handle branches with > 2 successors.
+ BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
+ if (!CondBr)
+ return;
+
+ uint64_t TrueWt, FalseWt;
+ if (!extractBranchWeights(*CondBr, TrueWt, FalseWt))
+ return;
+
+ auto SumWt = TrueWt + FalseWt;
+ if (SumWt == 0)
+ return;
+
+ auto TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt);
+ auto FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt);
+
+ if (TrueProb <= ColdProbThresh)
+ AnnotatedColdBlocks.insert(CondBr->getSuccessor(0));
+
+ if (FalseProb <= ColdProbThresh)
+ AnnotatedColdBlocks.insert(CondBr->getSuccessor(1));
+}
+
bool unlikelyExecuted(BasicBlock &BB) {
// Exception handling blocks are unlikely executed.
if (BB.isEHPad() || isa<ResumeInst>(BB.getTerminator()))
@@ -183,6 +215,34 @@ bool HotColdSplitting::isFunctionCold(const Function &F) const {
return false;
}
+bool HotColdSplitting::isBasicBlockCold(BasicBlock *BB,
+ BranchProbability ColdProbThresh,
+ SmallPtrSetImpl<BasicBlock *> &ColdBlocks,
+ SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks,
+ BlockFrequencyInfo *BFI) const {
+ // This block is already part of some outlining region.
+ if (ColdBlocks.count(BB))
+ return true;
+
+ if (BFI) {
+ if (PSI->isColdBlock(BB, BFI))
+ return true;
+ } else {
+ // Find cold blocks of successors of BB during a reverse postorder traversal.
+ analyzeProfMetadata(BB, ColdProbThresh, AnnotatedColdBlocks);
+
+ // A statically cold BB would be known before it is visited
+ // because the prof-data of incoming edges are 'analyzed' as part of RPOT.
+ if (AnnotatedColdBlocks.count(BB))
+ return true;
+ }
+
+ if (EnableStaticAnalysis && unlikelyExecuted(*BB))
+ return true;
+
+ return false;
+}
+
// Returns false if the function should not be considered for hot-cold split
// optimization.
bool HotColdSplitting::shouldOutlineFrom(const Function &F) const {
@@ -565,6 +625,9 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
// The set of cold blocks.
SmallPtrSet<BasicBlock *, 4> ColdBlocks;
+ // Set of cold blocks obtained with RPOT.
+ SmallPtrSet<BasicBlock *, 4> AnnotatedColdBlocks;
+
// The worklist of non-intersecting regions left to outline.
SmallVector<OutliningRegion, 2> OutliningWorklist;
@@ -587,16 +650,15 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
TargetTransformInfo &TTI = GetTTI(F);
OptimizationRemarkEmitter &ORE = (*GetORE)(F);
AssumptionCache *AC = LookupAC(F);
+ auto ColdProbThresh = TTI.getPredictableBranchThreshold().getCompl();
+
+ if (ColdBranchProbDenom.getNumOccurrences())
+ ColdProbThresh = BranchProbability(1, ColdBranchProbDenom.getValue());
// Find all cold regions.
for (BasicBlock *BB : RPOT) {
- // This block is already part of some outlining region.
- if (ColdBlocks.count(BB))
- continue;
-
- bool Cold = (BFI && PSI->isColdBlock(BB, BFI)) ||
- (EnableStaticAnalysis && unlikelyExecuted(*BB));
- if (!Cold)
+ if (!isBasicBlockCold(BB, ColdProbThresh, ColdBlocks, AnnotatedColdBlocks,
+ BFI))
continue;
LLVM_DEBUG({
diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp
index e258299c6a4c..a6e19df7c5f1 100644
--- a/llvm/lib/Transforms/IPO/IROutliner.cpp
+++ b/llvm/lib/Transforms/IPO/IROutliner.cpp
@@ -155,7 +155,7 @@ struct OutlinableGroup {
/// \param TargetBB - the BasicBlock to put Instruction into.
static void moveBBContents(BasicBlock &SourceBB, BasicBlock &TargetBB) {
for (Instruction &I : llvm::make_early_inc_range(SourceBB))
- I.moveBefore(TargetBB, TargetBB.end());
+ I.moveBeforePreserving(TargetBB, TargetBB.end());
}
/// A function to sort the keys of \p Map, which must be a mapping of constant
@@ -198,7 +198,7 @@ Value *OutlinableRegion::findCorrespondingValueIn(const OutlinableRegion &Other,
BasicBlock *
OutlinableRegion::findCorrespondingBlockIn(const OutlinableRegion &Other,
BasicBlock *BB) {
- Instruction *FirstNonPHI = BB->getFirstNonPHI();
+ Instruction *FirstNonPHI = BB->getFirstNonPHIOrDbg();
assert(FirstNonPHI && "block is empty?");
Value *CorrespondingVal = findCorrespondingValueIn(Other, FirstNonPHI);
if (!CorrespondingVal)
@@ -557,7 +557,7 @@ collectRegionsConstants(OutlinableRegion &Region,
// Iterate over the operands in an instruction. If the global value number,
// assigned by the IRSimilarityCandidate, has been seen before, we check if
- // the the number has been found to be not the same value in each instance.
+ // the number has been found to be not the same value in each instance.
for (Value *V : ID.OperVals) {
std::optional<unsigned> GVNOpt = C.getGVN(V);
assert(GVNOpt && "Expected a GVN for operand?");
@@ -766,7 +766,7 @@ static void moveFunctionData(Function &Old, Function &New,
}
}
-/// Find the the constants that will need to be lifted into arguments
+/// Find the constants that will need to be lifted into arguments
/// as they are not the same in each instance of the region.
///
/// \param [in] C - The IRSimilarityCandidate containing the region we are
@@ -1346,7 +1346,7 @@ findExtractedOutputToOverallOutputMapping(Module &M, OutlinableRegion &Region,
// the output, so we add a pointer type to the argument types of the overall
// function to handle this output and create a mapping to it.
if (!TypeFound) {
- Group.ArgumentTypes.push_back(Output->getType()->getPointerTo(
+ Group.ArgumentTypes.push_back(PointerType::get(Output->getContext(),
M.getDataLayout().getAllocaAddrSpace()));
// Mark the new pointer type as the last value in the aggregate argument
// list.
diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp
index 3e00aebce372..a9747aebf67b 100644
--- a/llvm/lib/Transforms/IPO/Inliner.cpp
+++ b/llvm/lib/Transforms/IPO/Inliner.cpp
@@ -13,7 +13,6 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/Inliner.h"
-#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
@@ -63,7 +62,6 @@
#include <cassert>
#include <functional>
#include <utility>
-#include <vector>
using namespace llvm;
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index 9b4b3efd7283..733f290b1bc9 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -381,8 +381,7 @@ struct ScopedSaveAliaseesAndUsed {
appendToCompilerUsed(M, CompilerUsed);
for (auto P : FunctionAliases)
- P.first->setAliasee(
- ConstantExpr::getBitCast(P.second, P.first->getType()));
+ P.first->setAliasee(P.second);
for (auto P : ResolverIFuncs) {
// This does not preserve pointer casts that may have been stripped by the
@@ -411,16 +410,19 @@ class LowerTypeTestsModule {
// selectJumpTableArmEncoding may decide to use Thumb in either case.
bool CanUseArmJumpTable = false, CanUseThumbBWJumpTable = false;
+ // Cache variable used by hasBranchTargetEnforcement().
+ int HasBranchTargetEnforcement = -1;
+
// The jump table type we ended up deciding on. (Usually the same as
// Arch, except that 'arm' and 'thumb' are often interchangeable.)
Triple::ArchType JumpTableArch = Triple::UnknownArch;
IntegerType *Int1Ty = Type::getInt1Ty(M.getContext());
IntegerType *Int8Ty = Type::getInt8Ty(M.getContext());
- PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
+ PointerType *Int8PtrTy = PointerType::getUnqual(M.getContext());
ArrayType *Int8Arr0Ty = ArrayType::get(Type::getInt8Ty(M.getContext()), 0);
IntegerType *Int32Ty = Type::getInt32Ty(M.getContext());
- PointerType *Int32PtrTy = PointerType::getUnqual(Int32Ty);
+ PointerType *Int32PtrTy = PointerType::getUnqual(M.getContext());
IntegerType *Int64Ty = Type::getInt64Ty(M.getContext());
IntegerType *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext(), 0);
@@ -492,6 +494,7 @@ class LowerTypeTestsModule {
ArrayRef<GlobalTypeMember *> Globals);
Triple::ArchType
selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions);
+ bool hasBranchTargetEnforcement();
unsigned getJumpTableEntrySize();
Type *getJumpTableEntryType();
void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS,
@@ -755,9 +758,9 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
// also conveniently gives us a bit offset to use during the load from
// the bitset.
Value *OffsetSHR =
- B.CreateLShr(PtrOffset, ConstantExpr::getZExt(TIL.AlignLog2, IntPtrTy));
+ B.CreateLShr(PtrOffset, B.CreateZExt(TIL.AlignLog2, IntPtrTy));
Value *OffsetSHL = B.CreateShl(
- PtrOffset, ConstantExpr::getZExt(
+ PtrOffset, B.CreateZExt(
ConstantExpr::getSub(
ConstantInt::get(Int8Ty, DL.getPointerSizeInBits(0)),
TIL.AlignLog2),
@@ -962,7 +965,6 @@ LowerTypeTestsModule::importTypeId(StringRef TypeId) {
Int8Arr0Ty);
if (auto *GV = dyn_cast<GlobalVariable>(C))
GV->setVisibility(GlobalValue::HiddenVisibility);
- C = ConstantExpr::getBitCast(C, Int8PtrTy);
return C;
};
@@ -1100,15 +1102,13 @@ void LowerTypeTestsModule::importFunction(
replaceCfiUses(F, FDecl, isJumpTableCanonical);
// Set visibility late because it's used in replaceCfiUses() to determine
- // whether uses need to to be replaced.
+ // whether uses need to be replaced.
F->setVisibility(Visibility);
}
void LowerTypeTestsModule::lowerTypeTestCalls(
ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
- CombinedGlobalAddr = ConstantExpr::getBitCast(CombinedGlobalAddr, Int8PtrTy);
-
// For each type identifier in this disjoint set...
for (Metadata *TypeId : TypeIds) {
// Build the bitset.
@@ -1196,6 +1196,20 @@ static const unsigned kARMJumpTableEntrySize = 4;
static const unsigned kARMBTIJumpTableEntrySize = 8;
static const unsigned kARMv6MJumpTableEntrySize = 16;
static const unsigned kRISCVJumpTableEntrySize = 8;
+static const unsigned kLOONGARCH64JumpTableEntrySize = 8;
+
+bool LowerTypeTestsModule::hasBranchTargetEnforcement() {
+ if (HasBranchTargetEnforcement == -1) {
+ // First time this query has been called. Find out the answer by checking
+ // the module flags.
+ if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
+ M.getModuleFlag("branch-target-enforcement")))
+ HasBranchTargetEnforcement = (BTE->getZExtValue() != 0);
+ else
+ HasBranchTargetEnforcement = 0;
+ }
+ return HasBranchTargetEnforcement;
+}
unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
switch (JumpTableArch) {
@@ -1209,19 +1223,22 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
case Triple::arm:
return kARMJumpTableEntrySize;
case Triple::thumb:
- if (CanUseThumbBWJumpTable)
+ if (CanUseThumbBWJumpTable) {
+ if (hasBranchTargetEnforcement())
+ return kARMBTIJumpTableEntrySize;
return kARMJumpTableEntrySize;
- else
+ } else {
return kARMv6MJumpTableEntrySize;
+ }
case Triple::aarch64:
- if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
- M.getModuleFlag("branch-target-enforcement")))
- if (BTE->getZExtValue())
- return kARMBTIJumpTableEntrySize;
+ if (hasBranchTargetEnforcement())
+ return kARMBTIJumpTableEntrySize;
return kARMJumpTableEntrySize;
case Triple::riscv32:
case Triple::riscv64:
return kRISCVJumpTableEntrySize;
+ case Triple::loongarch64:
+ return kLOONGARCH64JumpTableEntrySize;
default:
report_fatal_error("Unsupported architecture for jump tables");
}
@@ -1251,10 +1268,8 @@ void LowerTypeTestsModule::createJumpTableEntry(
} else if (JumpTableArch == Triple::arm) {
AsmOS << "b $" << ArgIndex << "\n";
} else if (JumpTableArch == Triple::aarch64) {
- if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
- Dest->getParent()->getModuleFlag("branch-target-enforcement")))
- if (BTE->getZExtValue())
- AsmOS << "bti c\n";
+ if (hasBranchTargetEnforcement())
+ AsmOS << "bti c\n";
AsmOS << "b $" << ArgIndex << "\n";
} else if (JumpTableArch == Triple::thumb) {
if (!CanUseThumbBWJumpTable) {
@@ -1281,11 +1296,16 @@ void LowerTypeTestsModule::createJumpTableEntry(
<< ".balign 4\n"
<< "1: .word $" << ArgIndex << " - (0b + 4)\n";
} else {
+ if (hasBranchTargetEnforcement())
+ AsmOS << "bti\n";
AsmOS << "b.w $" << ArgIndex << "\n";
}
} else if (JumpTableArch == Triple::riscv32 ||
JumpTableArch == Triple::riscv64) {
AsmOS << "tail $" << ArgIndex << "@plt\n";
+ } else if (JumpTableArch == Triple::loongarch64) {
+ AsmOS << "pcalau12i $$t0, %pc_hi20($" << ArgIndex << ")\n"
+ << "jirl $$r0, $$t0, %pc_lo12($" << ArgIndex << ")\n";
} else {
report_fatal_error("Unsupported architecture for jump tables");
}
@@ -1304,7 +1324,8 @@ void LowerTypeTestsModule::buildBitSetsFromFunctions(
ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm ||
Arch == Triple::thumb || Arch == Triple::aarch64 ||
- Arch == Triple::riscv32 || Arch == Triple::riscv64)
+ Arch == Triple::riscv32 || Arch == Triple::riscv64 ||
+ Arch == Triple::loongarch64)
buildBitSetsFromFunctionsNative(TypeIds, Functions);
else if (Arch == Triple::wasm32 || Arch == Triple::wasm64)
buildBitSetsFromFunctionsWASM(TypeIds, Functions);
@@ -1446,9 +1467,19 @@ void LowerTypeTestsModule::createJumpTable(
SmallVector<Value *, 16> AsmArgs;
AsmArgs.reserve(Functions.size() * 2);
- for (GlobalTypeMember *GTM : Functions)
+ // Check if all entries have the NoUnwind attribute.
+ // If all entries have it, we can safely mark the
+ // cfi.jumptable as NoUnwind, otherwise, direct calls
+ // to the jump table will not handle exceptions properly
+ bool areAllEntriesNounwind = true;
+ for (GlobalTypeMember *GTM : Functions) {
+ if (!llvm::cast<llvm::Function>(GTM->getGlobal())
+ ->hasFnAttribute(llvm::Attribute::NoUnwind)) {
+ areAllEntriesNounwind = false;
+ }
createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs,
cast<Function>(GTM->getGlobal()));
+ }
// Align the whole table by entry size.
F->setAlignment(Align(getJumpTableEntrySize()));
@@ -1461,17 +1492,23 @@ void LowerTypeTestsModule::createJumpTable(
if (JumpTableArch == Triple::arm)
F->addFnAttr("target-features", "-thumb-mode");
if (JumpTableArch == Triple::thumb) {
- F->addFnAttr("target-features", "+thumb-mode");
- if (CanUseThumbBWJumpTable) {
- // Thumb jump table assembly needs Thumb2. The following attribute is
- // added by Clang for -march=armv7.
- F->addFnAttr("target-cpu", "cortex-a8");
+ if (hasBranchTargetEnforcement()) {
+ // If we're generating a Thumb jump table with BTI, add a target-features
+ // setting to ensure BTI can be assembled.
+ F->addFnAttr("target-features", "+thumb-mode,+pacbti");
+ } else {
+ F->addFnAttr("target-features", "+thumb-mode");
+ if (CanUseThumbBWJumpTable) {
+ // Thumb jump table assembly needs Thumb2. The following attribute is
+ // added by Clang for -march=armv7.
+ F->addFnAttr("target-cpu", "cortex-a8");
+ }
}
}
// When -mbranch-protection= is used, the inline asm adds a BTI. Suppress BTI
// for the function to avoid double BTI. This is a no-op without
// -mbranch-protection=.
- if (JumpTableArch == Triple::aarch64) {
+ if (JumpTableArch == Triple::aarch64 || JumpTableArch == Triple::thumb) {
F->addFnAttr("branch-target-enforcement", "false");
F->addFnAttr("sign-return-address", "none");
}
@@ -1485,8 +1522,13 @@ void LowerTypeTestsModule::createJumpTable(
// -fcf-protection=.
if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64)
F->addFnAttr(Attribute::NoCfCheck);
- // Make sure we don't emit .eh_frame for this function.
- F->addFnAttr(Attribute::NoUnwind);
+
+ // Make sure we don't emit .eh_frame for this function if it isn't needed.
+ if (areAllEntriesNounwind)
+ F->addFnAttr(Attribute::NoUnwind);
+
+ // Make sure we do not inline any calls to the cfi.jumptable.
+ F->addFnAttr(Attribute::NoInline);
BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F);
IRBuilder<> IRB(BB);
@@ -1618,12 +1660,10 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
Function *F = cast<Function>(Functions[I]->getGlobal());
bool IsJumpTableCanonical = Functions[I]->isJumpTableCanonical();
- Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast(
- ConstantExpr::getInBoundsGetElementPtr(
- JumpTableType, JumpTable,
- ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
- ConstantInt::get(IntPtrTy, I)}),
- F->getType());
+ Constant *CombinedGlobalElemPtr = ConstantExpr::getInBoundsGetElementPtr(
+ JumpTableType, JumpTable,
+ ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
+ ConstantInt::get(IntPtrTy, I)});
const bool IsExported = Functions[I]->isExported();
if (!IsJumpTableCanonical) {
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
index f835fb26fcb8..70a3f3067d9d 100644
--- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -104,11 +104,13 @@ static cl::opt<std::string> MemProfImportSummary(
cl::desc("Import summary to use for testing the ThinLTO backend via opt"),
cl::Hidden);
+namespace llvm {
// Indicate we are linking with an allocator that supports hot/cold operator
// new interfaces.
cl::opt<bool> SupportsHotColdNew(
"supports-hot-cold-new", cl::init(false), cl::Hidden,
cl::desc("Linking with hot/cold operator new interfaces"));
+} // namespace llvm
namespace {
/// CRTP base for graphs built from either IR or ThinLTO summary index.
@@ -791,11 +793,10 @@ CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
template <typename DerivedCCG, typename FuncTy, typename CallTy>
void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
eraseCalleeEdge(const ContextEdge *Edge) {
- auto EI =
- std::find_if(CalleeEdges.begin(), CalleeEdges.end(),
- [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) {
- return CalleeEdge.get() == Edge;
- });
+ auto EI = llvm::find_if(
+ CalleeEdges, [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) {
+ return CalleeEdge.get() == Edge;
+ });
assert(EI != CalleeEdges.end());
CalleeEdges.erase(EI);
}
@@ -803,11 +804,10 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
template <typename DerivedCCG, typename FuncTy, typename CallTy>
void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::
eraseCallerEdge(const ContextEdge *Edge) {
- auto EI =
- std::find_if(CallerEdges.begin(), CallerEdges.end(),
- [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) {
- return CallerEdge.get() == Edge;
- });
+ auto EI = llvm::find_if(
+ CallerEdges, [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) {
+ return CallerEdge.get() == Edge;
+ });
assert(EI != CallerEdges.end());
CallerEdges.erase(EI);
}
@@ -2093,8 +2093,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones(
for (auto &Edge : CallerEdges) {
// Skip any that have been removed by an earlier recursive call.
if (Edge->Callee == nullptr && Edge->Caller == nullptr) {
- assert(!std::count(Node->CallerEdges.begin(), Node->CallerEdges.end(),
- Edge));
+ assert(!llvm::count(Node->CallerEdges, Edge));
continue;
}
// Ignore any caller we previously visited via another edge.
@@ -2985,6 +2984,21 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
if (!mayHaveMemprofSummary(CB))
continue;
+ auto *CalledValue = CB->getCalledOperand();
+ auto *CalledFunction = CB->getCalledFunction();
+ if (CalledValue && !CalledFunction) {
+ CalledValue = CalledValue->stripPointerCasts();
+ // Stripping pointer casts can reveal a called function.
+ CalledFunction = dyn_cast<Function>(CalledValue);
+ }
+ // Check if this is an alias to a function. If so, get the
+ // called aliasee for the checks below.
+ if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) {
+ assert(!CalledFunction &&
+ "Expected null called function in callsite for alias");
+ CalledFunction = dyn_cast<Function>(GA->getAliaseeObject());
+ }
+
CallStack<MDNode, MDNode::op_iterator> CallsiteContext(
I.getMetadata(LLVMContext::MD_callsite));
auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof);
@@ -3116,13 +3130,13 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
CloneFuncIfNeeded(/*NumClones=*/StackNode.Clones.size());
// Should have skipped indirect calls via mayHaveMemprofSummary.
- assert(CB->getCalledFunction());
- assert(!IsMemProfClone(*CB->getCalledFunction()));
+ assert(CalledFunction);
+ assert(!IsMemProfClone(*CalledFunction));
// Update the calls per the summary info.
// Save orig name since it gets updated in the first iteration
// below.
- auto CalleeOrigName = CB->getCalledFunction()->getName();
+ auto CalleeOrigName = CalledFunction->getName();
for (unsigned J = 0; J < StackNode.Clones.size(); J++) {
// Do nothing if this version calls the original version of its
// callee.
@@ -3130,7 +3144,7 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
continue;
auto NewF = M.getOrInsertFunction(
getMemProfFuncName(CalleeOrigName, StackNode.Clones[J]),
- CB->getCalledFunction()->getFunctionType());
+ CalledFunction->getFunctionType());
CallBase *CBClone;
// Copy 0 is the original function.
if (!J)
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index feda5d6459cb..c8c011d94e4a 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -107,6 +107,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/StructuralHash.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
@@ -171,15 +172,14 @@ namespace {
class FunctionNode {
mutable AssertingVH<Function> F;
- FunctionComparator::FunctionHash Hash;
+ IRHash Hash;
public:
// Note the hash is recalculated potentially multiple times, but it is cheap.
- FunctionNode(Function *F)
- : F(F), Hash(FunctionComparator::functionHash(*F)) {}
+ FunctionNode(Function *F) : F(F), Hash(StructuralHash(*F)) {}
Function *getFunc() const { return F; }
- FunctionComparator::FunctionHash getHash() const { return Hash; }
+ IRHash getHash() const { return Hash; }
/// Replace the reference to the function F by the function G, assuming their
/// implementations are equal.
@@ -375,9 +375,32 @@ bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
}
#endif
+/// Check whether \p F has an intrinsic which references
+/// distinct metadata as an operand. The most common
+/// instance of this would be CFI checks for function-local types.
+static bool hasDistinctMetadataIntrinsic(const Function &F) {
+ for (const BasicBlock &BB : F) {
+ for (const Instruction &I : BB.instructionsWithoutDebug()) {
+ if (!isa<IntrinsicInst>(&I))
+ continue;
+
+ for (Value *Op : I.operands()) {
+ auto *MDL = dyn_cast<MetadataAsValue>(Op);
+ if (!MDL)
+ continue;
+ if (MDNode *N = dyn_cast<MDNode>(MDL->getMetadata()))
+ if (N->isDistinct())
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
/// Check whether \p F is eligible for function merging.
static bool isEligibleForMerging(Function &F) {
- return !F.isDeclaration() && !F.hasAvailableExternallyLinkage();
+ return !F.isDeclaration() && !F.hasAvailableExternallyLinkage() &&
+ !hasDistinctMetadataIntrinsic(F);
}
bool MergeFunctions::runOnModule(Module &M) {
@@ -390,11 +413,10 @@ bool MergeFunctions::runOnModule(Module &M) {
// All functions in the module, ordered by hash. Functions with a unique
// hash value are easily eliminated.
- std::vector<std::pair<FunctionComparator::FunctionHash, Function *>>
- HashedFuncs;
+ std::vector<std::pair<IRHash, Function *>> HashedFuncs;
for (Function &Func : M) {
if (isEligibleForMerging(Func)) {
- HashedFuncs.push_back({FunctionComparator::functionHash(Func), &Func});
+ HashedFuncs.push_back({StructuralHash(Func), &Func});
}
}
@@ -441,7 +463,6 @@ bool MergeFunctions::runOnModule(Module &M) {
// Replace direct callers of Old with New.
void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
- Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType());
for (Use &U : llvm::make_early_inc_range(Old->uses())) {
CallBase *CB = dyn_cast<CallBase>(U.getUser());
if (CB && CB->isCallee(&U)) {
@@ -450,7 +471,7 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
// type congruences in byval(), in which case we need to keep the byval
// type of the call-site, not the callee function.
remove(CB->getFunction());
- U.set(BitcastNew);
+ U.set(New);
}
}
}
@@ -632,7 +653,7 @@ static bool canCreateThunkFor(Function *F) {
// Don't merge tiny functions using a thunk, since it can just end up
// making the function larger.
if (F->size() == 1) {
- if (F->front().size() <= 2) {
+ if (F->front().sizeWithoutDebug() < 2) {
LLVM_DEBUG(dbgs() << "canCreateThunkFor: " << F->getName()
<< " is too small to bother creating a thunk for\n");
return false;
@@ -641,6 +662,13 @@ static bool canCreateThunkFor(Function *F) {
return true;
}
+/// Copy metadata from one function to another.
+static void copyMetadataIfPresent(Function *From, Function *To, StringRef Key) {
+ if (MDNode *MD = From->getMetadata(Key)) {
+ To->setMetadata(Key, MD);
+ }
+}
+
// Replace G with a simple tail call to bitcast(F). Also (unless
// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F),
// delete G. Under MergeFunctionsPDI, we use G itself for creating
@@ -719,6 +747,9 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
} else {
NewG->copyAttributesFrom(G);
NewG->takeName(G);
+ // Ensure CFI type metadata is propagated to the new function.
+ copyMetadataIfPresent(G, NewG, "type");
+ copyMetadataIfPresent(G, NewG, "kcfi_type");
removeUsers(G);
G->replaceAllUsesWith(NewG);
G->eraseFromParent();
@@ -741,10 +772,9 @@ static bool canCreateAliasFor(Function *F) {
// Replace G with an alias to F (deleting function G)
void MergeFunctions::writeAlias(Function *F, Function *G) {
- Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
PointerType *PtrType = G->getType();
auto *GA = GlobalAlias::create(G->getValueType(), PtrType->getAddressSpace(),
- G->getLinkage(), "", BitcastF, G->getParent());
+ G->getLinkage(), "", F, G->getParent());
const MaybeAlign FAlign = F->getAlign();
const MaybeAlign GAlign = G->getAlign();
@@ -795,6 +825,9 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
F->getAddressSpace(), "", F->getParent());
NewF->copyAttributesFrom(F);
NewF->takeName(F);
+ // Ensure CFI type metadata is propagated to the new function.
+ copyMetadataIfPresent(F, NewF, "type");
+ copyMetadataIfPresent(F, NewF, "kcfi_type");
removeUsers(F);
F->replaceAllUsesWith(NewF);
@@ -825,9 +858,8 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
// to replace a key in ValueMap<GlobalValue *> with a non-global.
GlobalNumbers.erase(G);
// If G's address is not significant, replace it entirely.
- Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
removeUsers(G);
- G->replaceAllUsesWith(BitcastF);
+ G->replaceAllUsesWith(F);
} else {
// Redirect direct callers of G to F. (See note on MergeFunctionsPDI
// above).
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 588f3901e3cb..b2665161c090 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -33,6 +33,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Assumptions.h"
#include "llvm/IR/BasicBlock.h"
@@ -42,6 +43,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
@@ -156,6 +158,8 @@ STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
"Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
"Number of OpenMP target region entry points (=kernels) identified");
+STATISTIC(NumNonOpenMPTargetRegionKernels,
+ "Number of non-OpenMP target region kernels identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
"Number of OpenMP target region entry points (=kernels) executed in "
"SPMD-mode instead of generic-mode");
@@ -181,6 +185,92 @@ STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif
+namespace KernelInfo {
+
+// struct ConfigurationEnvironmentTy {
+// uint8_t UseGenericStateMachine;
+// uint8_t MayUseNestedParallelism;
+// llvm::omp::OMPTgtExecModeFlags ExecMode;
+// int32_t MinThreads;
+// int32_t MaxThreads;
+// int32_t MinTeams;
+// int32_t MaxTeams;
+// };
+
+// struct DynamicEnvironmentTy {
+// uint16_t DebugIndentionLevel;
+// };
+
+// struct KernelEnvironmentTy {
+// ConfigurationEnvironmentTy Configuration;
+// IdentTy *Ident;
+// DynamicEnvironmentTy *DynamicEnv;
+// };
+
+#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
+ constexpr const unsigned MEMBER##Idx = IDX;
+
+KERNEL_ENVIRONMENT_IDX(Configuration, 0)
+KERNEL_ENVIRONMENT_IDX(Ident, 1)
+
+#undef KERNEL_ENVIRONMENT_IDX
+
+#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
+ constexpr const unsigned MEMBER##Idx = IDX;
+
+KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
+KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
+KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
+KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
+KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
+KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
+KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
+
+#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
+
+#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
+ RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
+ return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
+ }
+
+KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
+KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
+
+#undef KERNEL_ENVIRONMENT_GETTER
+
+#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
+ ConstantInt *get##MEMBER##FromKernelEnvironment( \
+ ConstantStruct *KernelEnvC) { \
+ ConstantStruct *ConfigC = \
+ getConfigurationFromKernelEnvironment(KernelEnvC); \
+ return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
+ }
+
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
+
+#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
+
+GlobalVariable *
+getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
+ constexpr const int InitKernelEnvironmentArgNo = 0;
+ return cast<GlobalVariable>(
+ KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
+ ->stripPointerCasts());
+}
+
+ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
+ GlobalVariable *KernelEnvGV =
+ getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
+ return cast<ConstantStruct>(KernelEnvGV->getInitializer());
+}
+} // namespace KernelInfo
+
namespace {
struct AAHeapToShared;
@@ -196,6 +286,7 @@ struct OMPInformationCache : public InformationCache {
: InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
OpenMPPostLink(OpenMPPostLink) {
+ OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
OMPBuilder.initialize();
initializeRuntimeFunctions(M);
initializeInternalControlVars();
@@ -531,7 +622,7 @@ struct OMPInformationCache : public InformationCache {
for (Function &F : M) {
for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
if (F.hasFnAttribute(Attribute::NoInline) &&
- F.getName().startswith(Prefix) &&
+ F.getName().starts_with(Prefix) &&
!F.hasFnAttribute(Attribute::OptimizeNone))
F.removeFnAttr(Attribute::NoInline);
}
@@ -595,7 +686,7 @@ struct KernelInfoState : AbstractState {
/// The parallel regions (identified by the outlined parallel functions) that
/// can be reached from the associated function.
- BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
+ BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
ReachedKnownParallelRegions;
/// State to track what parallel region we might reach.
@@ -610,6 +701,10 @@ struct KernelInfoState : AbstractState {
/// one we abort as the kernel is malformed.
CallBase *KernelInitCB = nullptr;
+ /// The constant kernel environement as taken from and passed to
+ /// __kmpc_target_init.
+ ConstantStruct *KernelEnvC = nullptr;
+
/// The __kmpc_target_deinit call in this kernel, if any. If we find more than
/// one we abort as the kernel is malformed.
CallBase *KernelDeinitCB = nullptr;
@@ -651,6 +746,7 @@ struct KernelInfoState : AbstractState {
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
ReachedKnownParallelRegions.indicatePessimisticFixpoint();
ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+ NestedParallelism = true;
return ChangeStatus::CHANGED;
}
@@ -680,6 +776,8 @@ struct KernelInfoState : AbstractState {
return false;
if (ParallelLevels != RHS.ParallelLevels)
return false;
+ if (NestedParallelism != RHS.NestedParallelism)
+ return false;
return true;
}
@@ -714,6 +812,12 @@ struct KernelInfoState : AbstractState {
"assumptions.");
KernelDeinitCB = KIS.KernelDeinitCB;
}
+ if (KIS.KernelEnvC) {
+ if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
+ llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+ "assumptions.");
+ KernelEnvC = KIS.KernelEnvC;
+ }
SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
@@ -875,6 +979,9 @@ struct OpenMPOpt {
}
}
+ if (OMPInfoCache.OpenMPPostLink)
+ Changed |= removeRuntimeSymbols();
+
return Changed;
}
@@ -903,7 +1010,7 @@ struct OpenMPOpt {
/// Print OpenMP GPU kernels for testing.
void printKernels() const {
for (Function *F : SCC) {
- if (!omp::isKernel(*F))
+ if (!omp::isOpenMPKernel(*F))
continue;
auto Remark = [&](OptimizationRemarkAnalysis ORA) {
@@ -1404,6 +1511,37 @@ private:
return Changed;
}
+ /// Tries to remove known runtime symbols that are optional from the module.
+ bool removeRuntimeSymbols() {
+ // The RPC client symbol is defined in `libc` and indicates that something
+ // required an RPC server. If its users were all optimized out then we can
+ // safely remove it.
+ // TODO: This should be somewhere more common in the future.
+ if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
+ if (!GV->getType()->isPointerTy())
+ return false;
+
+ Constant *C = GV->getInitializer();
+ if (!C)
+ return false;
+
+ // Check to see if the only user of the RPC client is the external handle.
+ GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
+ if (!Client || Client->getNumUses() > 1 ||
+ Client->user_back() != GV->getInitializer())
+ return false;
+
+ Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
+ Client->eraseFromParent();
+
+ GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
+ GV->eraseFromParent();
+
+ return true;
+ }
+ return false;
+ }
+
/// Tries to hide the latency of runtime calls that involve host to
/// device memory transfers by splitting them into their "issue" and "wait"
/// versions. The "issue" is moved upwards as much as possible. The "wait" is
@@ -1858,7 +1996,7 @@ private:
Function *F = I->getParent()->getParent();
auto &ORE = OREGetter(F);
- if (RemarkName.startswith("OMP"))
+ if (RemarkName.starts_with("OMP"))
ORE.emit([&]() {
return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
<< " [" << RemarkName << "]";
@@ -1874,7 +2012,7 @@ private:
RemarkCallBack &&RemarkCB) const {
auto &ORE = OREGetter(F);
- if (RemarkName.startswith("OMP"))
+ if (RemarkName.starts_with("OMP"))
ORE.emit([&]() {
return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
<< " [" << RemarkName << "]";
@@ -1944,7 +2082,7 @@ Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
// TODO: We should use an AA to create an (optimistic and callback
// call-aware) call graph. For now we stick to simple patterns that
// are less powerful, basically the worst fixpoint.
- if (isKernel(F)) {
+ if (isOpenMPKernel(F)) {
CachedKernel = Kernel(&F);
return *CachedKernel;
}
@@ -2535,6 +2673,17 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker {
}
};
+/// Determines if \p BB exits the function unconditionally itself or reaches a
+/// block that does through only unique successors.
+static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
+ if (succ_empty(BB))
+ return true;
+ const BasicBlock *const Successor = BB->getUniqueSuccessor();
+ if (!Successor)
+ return false;
+ return hasFunctionEndAsUniqueSuccessor(Successor);
+}
+
struct AAExecutionDomainFunction : public AAExecutionDomain {
AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
: AAExecutionDomain(IRP, A) {}
@@ -2587,18 +2736,22 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
if (!ED.IsReachedFromAlignedBarrierOnly ||
ED.EncounteredNonLocalSideEffect)
return;
+ if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
+ return;
- // We can remove this barrier, if it is one, or all aligned barriers
- // reaching the kernel end. In the latter case we can transitively work
- // our way back until we find a barrier that guards a side-effect if we
- // are dealing with the kernel end here.
+ // We can remove this barrier, if it is one, or aligned barriers reaching
+ // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
+ // end should only be removed if the kernel end is their unique successor;
+ // otherwise, they may have side-effects that aren't accounted for in the
+ // kernel end in their other successors. If those barriers have other
+ // barriers reaching them, those can be transitively removed as well as
+ // long as the kernel end is also their unique successor.
if (CB) {
DeletedBarriers.insert(CB);
A.deleteAfterManifest(*CB);
++NumBarriersEliminated;
Changed = ChangeStatus::CHANGED;
} else if (!ED.AlignedBarriers.empty()) {
- NumBarriersEliminated += ED.AlignedBarriers.size();
Changed = ChangeStatus::CHANGED;
SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
ED.AlignedBarriers.end());
@@ -2609,7 +2762,10 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
continue;
if (LastCB->getFunction() != getAnchorScope())
continue;
+ if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
+ continue;
if (!DeletedBarriers.count(LastCB)) {
+ ++NumBarriersEliminated;
A.deleteAfterManifest(*LastCB);
continue;
}
@@ -2633,7 +2789,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
HandleAlignedBarrier(CB);
// Handle the "kernel end barrier" for kernels too.
- if (omp::isKernel(*getAnchorScope()))
+ if (omp::isOpenMPKernel(*getAnchorScope()))
HandleAlignedBarrier(nullptr);
return Changed;
@@ -2779,9 +2935,11 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
if (!CB)
return false;
- const int InitModeArgNo = 1;
- auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
- return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
+ ConstantStruct *KernelEnvC =
+ KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
+ ConstantInt *ExecModeC =
+ KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
+ return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
}
if (C->isZero()) {
@@ -2884,11 +3042,11 @@ bool AAExecutionDomainFunction::handleCallees(Attributor &A,
} else {
// We could not find all predecessors, so this is either a kernel or a
// function with external linkage (or with some other weird uses).
- if (omp::isKernel(*getAnchorScope())) {
+ if (omp::isOpenMPKernel(*getAnchorScope())) {
EntryBBED.IsExecutedByInitialThreadOnly = false;
EntryBBED.IsReachedFromAlignedBarrierOnly = true;
EntryBBED.EncounteredNonLocalSideEffect = false;
- ExitED.IsReachingAlignedBarrierOnly = true;
+ ExitED.IsReachingAlignedBarrierOnly = false;
} else {
EntryBBED.IsExecutedByInitialThreadOnly = false;
EntryBBED.IsReachedFromAlignedBarrierOnly = false;
@@ -2938,7 +3096,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
Function *F = getAnchorScope();
BasicBlock &EntryBB = F->getEntryBlock();
- bool IsKernel = omp::isKernel(*F);
+ bool IsKernel = omp::isOpenMPKernel(*F);
SmallVector<Instruction *> SyncInstWorklist;
for (auto &RIt : *RPOT) {
@@ -3063,7 +3221,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
if (EDAA && EDAA->getState().isValidState()) {
const auto &CalleeED = EDAA->getFunctionExecutionDomain();
ED.IsReachedFromAlignedBarrierOnly =
- CalleeED.IsReachedFromAlignedBarrierOnly;
+ CalleeED.IsReachedFromAlignedBarrierOnly;
AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
ED.EncounteredNonLocalSideEffect |=
@@ -3442,6 +3600,10 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+ /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
+ /// unknown callees.
+ static bool requiresCalleeForCallBase() { return false; }
+
/// Statistics are tracked as part of manifest for now.
void trackStatistics() const override {}
@@ -3468,7 +3630,8 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
", #ParLevels: " +
(ParallelLevels.isValidState()
? std::to_string(ParallelLevels.size())
- : "<invalid>");
+ : "<invalid>") +
+ ", NestedPar: " + (NestedParallelism ? "yes" : "no");
}
/// Create an abstract attribute biew for the position \p IRP.
@@ -3500,6 +3663,33 @@ struct AAKernelInfoFunction : AAKernelInfo {
return GuardedInstructions;
}
+ void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
+ Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
+ KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
+ assert(NewKernelEnvC && "Failed to create new kernel environment");
+ KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
+ }
+
+#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
+ void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
+ ConstantStruct *ConfigC = \
+ KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
+ Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
+ ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
+ assert(NewConfigC && "Failed to create new configuration environment"); \
+ setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
+ }
+
+ KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
+ KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
+ KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
+ KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
+ KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
+ KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
+ KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
+
+#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
+
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
// This is a high-level transform that might change the constant arguments
@@ -3548,61 +3738,73 @@ struct AAKernelInfoFunction : AAKernelInfo {
ReachingKernelEntries.insert(Fn);
IsKernelEntry = true;
- // For kernels we might need to initialize/finalize the IsSPMD state and
- // we need to register a simplification callback so that the Attributor
- // knows the constant arguments to __kmpc_target_init and
- // __kmpc_target_deinit might actually change.
-
- Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
- [&](const IRPosition &IRP, const AbstractAttribute *AA,
- bool &UsedAssumedInformation) -> std::optional<Value *> {
- return nullptr;
- };
+ KernelEnvC =
+ KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
+ GlobalVariable *KernelEnvGV =
+ KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
- Attributor::SimplifictionCallbackTy ModeSimplifyCB =
- [&](const IRPosition &IRP, const AbstractAttribute *AA,
- bool &UsedAssumedInformation) -> std::optional<Value *> {
- // IRP represents the "SPMDCompatibilityTracker" argument of an
- // __kmpc_target_init or
- // __kmpc_target_deinit call. We will answer this one with the internal
- // state.
- if (!SPMDCompatibilityTracker.isValidState())
- return nullptr;
- if (!SPMDCompatibilityTracker.isAtFixpoint()) {
- if (AA)
- A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
+ Attributor::GlobalVariableSimplifictionCallbackTy
+ KernelConfigurationSimplifyCB =
+ [&](const GlobalVariable &GV, const AbstractAttribute *AA,
+ bool &UsedAssumedInformation) -> std::optional<Constant *> {
+ if (!isAtFixpoint()) {
+ if (!AA)
+ return nullptr;
UsedAssumedInformation = true;
- } else {
- UsedAssumedInformation = false;
+ A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
}
- auto *Val = ConstantInt::getSigned(
- IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
- SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
- : OMP_TGT_EXEC_MODE_GENERIC);
- return Val;
+ return KernelEnvC;
};
- constexpr const int InitModeArgNo = 1;
- constexpr const int DeinitModeArgNo = 1;
- constexpr const int InitUseStateMachineArgNo = 2;
- A.registerSimplificationCallback(
- IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
- StateMachineSimplifyCB);
- A.registerSimplificationCallback(
- IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
- ModeSimplifyCB);
- A.registerSimplificationCallback(
- IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
- ModeSimplifyCB);
+ A.registerGlobalVariableSimplificationCallback(
+ *KernelEnvGV, KernelConfigurationSimplifyCB);
// Check if we know we are in SPMD-mode already.
- ConstantInt *ModeArg =
- dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
- if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
+ ConstantInt *ExecModeC =
+ KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
+ ConstantInt *AssumedExecModeC = ConstantInt::get(
+ ExecModeC->getType(),
+ ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
+ if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
SPMDCompatibilityTracker.indicateOptimisticFixpoint();
- // This is a generic region but SPMDization is disabled so stop tracking.
else if (DisableOpenMPOptSPMDization)
+ // This is a generic region but SPMDization is disabled so stop
+ // tracking.
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ else
+ setExecModeOfKernelEnvironment(AssumedExecModeC);
+
+ const Triple T(Fn->getParent()->getTargetTriple());
+ auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
+ auto [MinThreads, MaxThreads] =
+ OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
+ if (MinThreads)
+ setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
+ if (MaxThreads)
+ setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
+ auto [MinTeams, MaxTeams] =
+ OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
+ if (MinTeams)
+ setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
+ if (MaxTeams)
+ setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
+
+ ConstantInt *MayUseNestedParallelismC =
+ KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
+ ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
+ MayUseNestedParallelismC->getType(), NestedParallelism);
+ setMayUseNestedParallelismOfKernelEnvironment(
+ AssumedMayUseNestedParallelismC);
+
+ if (!DisableOpenMPOptStateMachineRewrite) {
+ ConstantInt *UseGenericStateMachineC =
+ KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
+ KernelEnvC);
+ ConstantInt *AssumedUseGenericStateMachineC =
+ ConstantInt::get(UseGenericStateMachineC->getType(), false);
+ setUseGenericStateMachineOfKernelEnvironment(
+ AssumedUseGenericStateMachineC);
+ }
// Register virtual uses of functions we might need to preserve.
auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
@@ -3703,22 +3905,32 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (!KernelInitCB || !KernelDeinitCB)
return ChangeStatus::UNCHANGED;
- /// Insert nested Parallelism global variable
- Function *Kernel = getAnchorScope();
- Module &M = *Kernel->getParent();
- Type *Int8Ty = Type::getInt8Ty(M.getContext());
- auto *GV = new GlobalVariable(
- M, Int8Ty, /* isConstant */ true, GlobalValue::WeakAnyLinkage,
- ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0),
- Kernel->getName() + "_nested_parallelism");
- GV->setVisibility(GlobalValue::HiddenVisibility);
-
- // If we can we change the execution mode to SPMD-mode otherwise we build a
- // custom state machine.
ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+ bool HasBuiltStateMachine = true;
if (!changeToSPMDMode(A, Changed)) {
if (!KernelInitCB->getCalledFunction()->isDeclaration())
- return buildCustomStateMachine(A);
+ HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
+ else
+ HasBuiltStateMachine = false;
+ }
+
+ // We need to reset KernelEnvC if specific rewriting is not done.
+ ConstantStruct *ExistingKernelEnvC =
+ KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
+ ConstantInt *OldUseGenericStateMachineVal =
+ KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
+ ExistingKernelEnvC);
+ if (!HasBuiltStateMachine)
+ setUseGenericStateMachineOfKernelEnvironment(
+ OldUseGenericStateMachineVal);
+
+ // At last, update the KernelEnvc
+ GlobalVariable *KernelEnvGV =
+ KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
+ if (KernelEnvGV->getInitializer() != KernelEnvC) {
+ KernelEnvGV->setInitializer(KernelEnvC);
+ Changed = ChangeStatus::CHANGED;
}
return Changed;
@@ -3788,14 +4000,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
// Find escaping outputs from the guarded region to outside users and
// broadcast their values to them.
for (Instruction &I : *RegionStartBB) {
- SmallPtrSet<Instruction *, 4> OutsideUsers;
- for (User *Usr : I.users()) {
- Instruction &UsrI = *cast<Instruction>(Usr);
+ SmallVector<Use *, 4> OutsideUses;
+ for (Use &U : I.uses()) {
+ Instruction &UsrI = *cast<Instruction>(U.getUser());
if (UsrI.getParent() != RegionStartBB)
- OutsideUsers.insert(&UsrI);
+ OutsideUses.push_back(&U);
}
- if (OutsideUsers.empty())
+ if (OutsideUses.empty())
continue;
HasBroadcastValues = true;
@@ -3818,8 +4030,8 @@ struct AAKernelInfoFunction : AAKernelInfo {
RegionBarrierBB->getTerminator());
// Emit a load instruction and replace uses of the output value.
- for (Instruction *UsrI : OutsideUsers)
- UsrI->replaceUsesOfWith(&I, LoadI);
+ for (Use *U : OutsideUses)
+ A.changeUseAfterManifest(*U, *LoadI);
}
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
@@ -4043,19 +4255,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
auto *CB = cast<CallBase>(Kernel->user_back());
Kernel = CB->getCaller();
}
- assert(omp::isKernel(*Kernel) && "Expected kernel function!");
+ assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
// Check if the kernel is already in SPMD mode, if so, return success.
- GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
- (Kernel->getName() + "_exec_mode").str());
- assert(ExecMode && "Kernel without exec mode?");
- assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
-
- // Set the global exec mode flag to indicate SPMD-Generic mode.
- assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
- "ExecMode is not an integer!");
- const int8_t ExecModeVal =
- cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
+ ConstantStruct *ExistingKernelEnvC =
+ KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
+ auto *ExecModeC =
+ KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
+ const int8_t ExecModeVal = ExecModeC->getSExtValue();
if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
return true;
@@ -4073,27 +4280,8 @@ struct AAKernelInfoFunction : AAKernelInfo {
// kernel is executed in.
assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
"Initially non-SPMD kernel has SPMD exec mode!");
- ExecMode->setInitializer(
- ConstantInt::get(ExecMode->getInitializer()->getType(),
- ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
-
- // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
- const int InitModeArgNo = 1;
- const int DeinitModeArgNo = 1;
- const int InitUseStateMachineArgNo = 2;
-
- auto &Ctx = getAnchorValue().getContext();
- A.changeUseAfterManifest(
- KernelInitCB->getArgOperandUse(InitModeArgNo),
- *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
- OMP_TGT_EXEC_MODE_SPMD));
- A.changeUseAfterManifest(
- KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
- *ConstantInt::getBool(Ctx, false));
- A.changeUseAfterManifest(
- KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
- *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
- OMP_TGT_EXEC_MODE_SPMD));
+ setExecModeOfKernelEnvironment(ConstantInt::get(
+ ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
++NumOpenMPTargetRegionKernelsSPMD;
@@ -4104,46 +4292,47 @@ struct AAKernelInfoFunction : AAKernelInfo {
return true;
};
- ChangeStatus buildCustomStateMachine(Attributor &A) {
+ bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
// If we have disabled state machine rewrites, don't make a custom one
if (DisableOpenMPOptStateMachineRewrite)
- return ChangeStatus::UNCHANGED;
+ return false;
// Don't rewrite the state machine if we are not in a valid state.
if (!ReachedKnownParallelRegions.isValidState())
- return ChangeStatus::UNCHANGED;
+ return false;
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
if (!OMPInfoCache.runtimeFnsAvailable(
{OMPRTL___kmpc_get_hardware_num_threads_in_block,
OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
- return ChangeStatus::UNCHANGED;
+ return false;
- const int InitModeArgNo = 1;
- const int InitUseStateMachineArgNo = 2;
+ ConstantStruct *ExistingKernelEnvC =
+ KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
// Check if the current configuration is non-SPMD and generic state machine.
// If we already have SPMD mode or a custom state machine we do not need to
// go any further. If it is anything but a constant something is weird and
// we give up.
- ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
- KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
- ConstantInt *Mode =
- dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
+ ConstantInt *UseStateMachineC =
+ KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
+ ExistingKernelEnvC);
+ ConstantInt *ModeC =
+ KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
// If we are stuck with generic mode, try to create a custom device (=GPU)
// state machine which is specialized for the parallel regions that are
// reachable by the kernel.
- if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
- (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
- return ChangeStatus::UNCHANGED;
+ if (UseStateMachineC->isZero() ||
+ (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
+ return false;
+
+ Changed = ChangeStatus::CHANGED;
// If not SPMD mode, indicate we use a custom state machine now.
- auto &Ctx = getAnchorValue().getContext();
- auto *FalseVal = ConstantInt::getBool(Ctx, false);
- A.changeUseAfterManifest(
- KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
+ setUseGenericStateMachineOfKernelEnvironment(
+ ConstantInt::get(UseStateMachineC->getType(), false));
// If we don't actually need a state machine we are done here. This can
// happen if there simply are no parallel regions. In the resulting kernel
@@ -4157,7 +4346,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
};
A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
- return ChangeStatus::CHANGED;
+ return true;
}
// Keep track in the statistics of our new shiny custom state machine.
@@ -4222,6 +4411,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
// UserCodeEntryBB: // user code
// __kmpc_target_deinit(...)
//
+ auto &Ctx = getAnchorValue().getContext();
Function *Kernel = getAssociatedFunction();
assert(Kernel && "Expected an associated function!");
@@ -4292,7 +4482,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
// Create local storage for the work function pointer.
const DataLayout &DL = M.getDataLayout();
- Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
+ Type *VoidPtrTy = PointerType::getUnqual(Ctx);
Instruction *WorkFnAI =
new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
"worker.work_fn.addr", &Kernel->getEntryBlock().front());
@@ -4304,7 +4494,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
StateMachineBeginBB->end()),
DLoc));
- Value *Ident = KernelInitCB->getArgOperand(0);
+ Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
Value *GTid = KernelInitCB;
FunctionCallee BarrierFn =
@@ -4337,9 +4527,6 @@ struct AAKernelInfoFunction : AAKernelInfo {
FunctionType *ParallelRegionFnTy = FunctionType::get(
Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
false);
- Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
- WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
- StateMachineBeginBB);
Instruction *IsDone =
ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
@@ -4358,11 +4545,15 @@ struct AAKernelInfoFunction : AAKernelInfo {
Value *ZeroArg =
Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
+ const unsigned int WrapperFunctionArgNo = 6;
+
// Now that we have most of the CFG skeleton it is time for the if-cascade
// that checks the function pointer we got from the runtime against the
// parallel regions we expect, if there are any.
for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
- auto *ParallelRegion = ReachedKnownParallelRegions[I];
+ auto *CB = ReachedKnownParallelRegions[I];
+ auto *ParallelRegion = dyn_cast<Function>(
+ CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
BasicBlock *PRExecuteBB = BasicBlock::Create(
Ctx, "worker_state_machine.parallel_region.execute", Kernel,
StateMachineEndParallelBB);
@@ -4374,13 +4565,15 @@ struct AAKernelInfoFunction : AAKernelInfo {
BasicBlock *PRNextBB =
BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
Kernel, StateMachineEndParallelBB);
+ A.registerManifestAddedBasicBlock(*PRExecuteBB);
+ A.registerManifestAddedBasicBlock(*PRNextBB);
// Check if we need to compare the pointer at all or if we can just
// call the parallel region function.
Value *IsPR;
if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
Instruction *CmpI = ICmpInst::Create(
- ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
+ ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
"worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
CmpI->setDebugLoc(DLoc);
IsPR = CmpI;
@@ -4400,7 +4593,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (!ReachedUnknownParallelRegions.empty()) {
StateMachineIfCascadeCurrentBB->setName(
"worker_state_machine.parallel_region.fallback.execute");
- CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
+ CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
StateMachineIfCascadeCurrentBB)
->setDebugLoc(DLoc);
}
@@ -4423,7 +4616,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
->setDebugLoc(DLoc);
- return ChangeStatus::CHANGED;
+ return true;
}
/// Fixpoint iteration update function. Will be called every time a dependence
@@ -4431,6 +4624,46 @@ struct AAKernelInfoFunction : AAKernelInfo {
ChangeStatus updateImpl(Attributor &A) override {
KernelInfoState StateBefore = getState();
+ // When we leave this function this RAII will make sure the member
+ // KernelEnvC is updated properly depending on the state. That member is
+ // used for simplification of values and needs to be up to date at all
+ // times.
+ struct UpdateKernelEnvCRAII {
+ AAKernelInfoFunction &AA;
+
+ UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
+
+ ~UpdateKernelEnvCRAII() {
+ if (!AA.KernelEnvC)
+ return;
+
+ ConstantStruct *ExistingKernelEnvC =
+ KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
+
+ if (!AA.isValidState()) {
+ AA.KernelEnvC = ExistingKernelEnvC;
+ return;
+ }
+
+ if (!AA.ReachedKnownParallelRegions.isValidState())
+ AA.setUseGenericStateMachineOfKernelEnvironment(
+ KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
+ ExistingKernelEnvC));
+
+ if (!AA.SPMDCompatibilityTracker.isValidState())
+ AA.setExecModeOfKernelEnvironment(
+ KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
+
+ ConstantInt *MayUseNestedParallelismC =
+ KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
+ AA.KernelEnvC);
+ ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
+ MayUseNestedParallelismC->getType(), AA.NestedParallelism);
+ AA.setMayUseNestedParallelismOfKernelEnvironment(
+ NewMayUseNestedParallelismC);
+ }
+ } RAII(*this);
+
// Callback to check a read/write instruction.
auto CheckRWInst = [&](Instruction &I) {
// We handle calls later.
@@ -4634,15 +4867,13 @@ struct AAKernelInfoCallSite : AAKernelInfo {
AAKernelInfo::initialize(A);
CallBase &CB = cast<CallBase>(getAssociatedValue());
- Function *Callee = getAssociatedFunction();
-
auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
*this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
// Check for SPMD-mode assumptions.
if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
- SPMDCompatibilityTracker.indicateOptimisticFixpoint();
indicateOptimisticFixpoint();
+ return;
}
// First weed out calls we do not care about, that is readonly/readnone
@@ -4657,124 +4888,156 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// we will handle them explicitly in the switch below. If it is not, we
// will use an AAKernelInfo object on the callee to gather information and
// merge that into the current state. The latter happens in the updateImpl.
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
- const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
- if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
- // Unknown caller or declarations are not analyzable, we give up.
- if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
+ auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
+ if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+ // Unknown caller or declarations are not analyzable, we give up.
+ if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
- // Unknown callees might contain parallel regions, except if they have
- // an appropriate assumption attached.
- if (!AssumptionAA ||
- !(AssumptionAA->hasAssumption("omp_no_openmp") ||
- AssumptionAA->hasAssumption("omp_no_parallelism")))
- ReachedUnknownParallelRegions.insert(&CB);
+ // Unknown callees might contain parallel regions, except if they have
+ // an appropriate assumption attached.
+ if (!AssumptionAA ||
+ !(AssumptionAA->hasAssumption("omp_no_openmp") ||
+ AssumptionAA->hasAssumption("omp_no_parallelism")))
+ ReachedUnknownParallelRegions.insert(&CB);
- // If SPMDCompatibilityTracker is not fixed, we need to give up on the
- // idea we can run something unknown in SPMD-mode.
- if (!SPMDCompatibilityTracker.isAtFixpoint()) {
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- }
+ // If SPMDCompatibilityTracker is not fixed, we need to give up on the
+ // idea we can run something unknown in SPMD-mode.
+ if (!SPMDCompatibilityTracker.isAtFixpoint()) {
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ }
- // We have updated the state for this unknown call properly, there won't
- // be any change so we indicate a fixpoint.
- indicateOptimisticFixpoint();
+ // We have updated the state for this unknown call properly, there
+ // won't be any change so we indicate a fixpoint.
+ indicateOptimisticFixpoint();
+ }
+ // If the callee is known and can be used in IPO, we will update the
+ // state based on the callee state in updateImpl.
+ return;
+ }
+ if (NumCallees > 1) {
+ indicatePessimisticFixpoint();
+ return;
}
- // If the callee is known and can be used in IPO, we will update the state
- // based on the callee state in updateImpl.
- return;
- }
- const unsigned int WrapperFunctionArgNo = 6;
- RuntimeFunction RF = It->getSecond();
- switch (RF) {
- // All the functions we know are compatible with SPMD mode.
- case OMPRTL___kmpc_is_spmd_exec_mode:
- case OMPRTL___kmpc_distribute_static_fini:
- case OMPRTL___kmpc_for_static_fini:
- case OMPRTL___kmpc_global_thread_num:
- case OMPRTL___kmpc_get_hardware_num_threads_in_block:
- case OMPRTL___kmpc_get_hardware_num_blocks:
- case OMPRTL___kmpc_single:
- case OMPRTL___kmpc_end_single:
- case OMPRTL___kmpc_master:
- case OMPRTL___kmpc_end_master:
- case OMPRTL___kmpc_barrier:
- case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
- case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
- case OMPRTL___kmpc_nvptx_end_reduce_nowait:
- break;
- case OMPRTL___kmpc_distribute_static_init_4:
- case OMPRTL___kmpc_distribute_static_init_4u:
- case OMPRTL___kmpc_distribute_static_init_8:
- case OMPRTL___kmpc_distribute_static_init_8u:
- case OMPRTL___kmpc_for_static_init_4:
- case OMPRTL___kmpc_for_static_init_4u:
- case OMPRTL___kmpc_for_static_init_8:
- case OMPRTL___kmpc_for_static_init_8u: {
- // Check the schedule and allow static schedule in SPMD mode.
- unsigned ScheduleArgOpNo = 2;
- auto *ScheduleTypeCI =
- dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
- unsigned ScheduleTypeVal =
- ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
- switch (OMPScheduleType(ScheduleTypeVal)) {
- case OMPScheduleType::UnorderedStatic:
- case OMPScheduleType::UnorderedStaticChunked:
- case OMPScheduleType::OrderedDistribute:
- case OMPScheduleType::OrderedDistributeChunked:
+ RuntimeFunction RF = It->getSecond();
+ switch (RF) {
+ // All the functions we know are compatible with SPMD mode.
+ case OMPRTL___kmpc_is_spmd_exec_mode:
+ case OMPRTL___kmpc_distribute_static_fini:
+ case OMPRTL___kmpc_for_static_fini:
+ case OMPRTL___kmpc_global_thread_num:
+ case OMPRTL___kmpc_get_hardware_num_threads_in_block:
+ case OMPRTL___kmpc_get_hardware_num_blocks:
+ case OMPRTL___kmpc_single:
+ case OMPRTL___kmpc_end_single:
+ case OMPRTL___kmpc_master:
+ case OMPRTL___kmpc_end_master:
+ case OMPRTL___kmpc_barrier:
+ case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
+ case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
+ case OMPRTL___kmpc_error:
+ case OMPRTL___kmpc_flush:
+ case OMPRTL___kmpc_get_hardware_thread_id_in_block:
+ case OMPRTL___kmpc_get_warp_size:
+ case OMPRTL_omp_get_thread_num:
+ case OMPRTL_omp_get_num_threads:
+ case OMPRTL_omp_get_max_threads:
+ case OMPRTL_omp_in_parallel:
+ case OMPRTL_omp_get_dynamic:
+ case OMPRTL_omp_get_cancellation:
+ case OMPRTL_omp_get_nested:
+ case OMPRTL_omp_get_schedule:
+ case OMPRTL_omp_get_thread_limit:
+ case OMPRTL_omp_get_supported_active_levels:
+ case OMPRTL_omp_get_max_active_levels:
+ case OMPRTL_omp_get_level:
+ case OMPRTL_omp_get_ancestor_thread_num:
+ case OMPRTL_omp_get_team_size:
+ case OMPRTL_omp_get_active_level:
+ case OMPRTL_omp_in_final:
+ case OMPRTL_omp_get_proc_bind:
+ case OMPRTL_omp_get_num_places:
+ case OMPRTL_omp_get_num_procs:
+ case OMPRTL_omp_get_place_proc_ids:
+ case OMPRTL_omp_get_place_num:
+ case OMPRTL_omp_get_partition_num_places:
+ case OMPRTL_omp_get_partition_place_nums:
+ case OMPRTL_omp_get_wtime:
break;
- default:
+ case OMPRTL___kmpc_distribute_static_init_4:
+ case OMPRTL___kmpc_distribute_static_init_4u:
+ case OMPRTL___kmpc_distribute_static_init_8:
+ case OMPRTL___kmpc_distribute_static_init_8u:
+ case OMPRTL___kmpc_for_static_init_4:
+ case OMPRTL___kmpc_for_static_init_4u:
+ case OMPRTL___kmpc_for_static_init_8:
+ case OMPRTL___kmpc_for_static_init_8u: {
+ // Check the schedule and allow static schedule in SPMD mode.
+ unsigned ScheduleArgOpNo = 2;
+ auto *ScheduleTypeCI =
+ dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
+ unsigned ScheduleTypeVal =
+ ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
+ switch (OMPScheduleType(ScheduleTypeVal)) {
+ case OMPScheduleType::UnorderedStatic:
+ case OMPScheduleType::UnorderedStaticChunked:
+ case OMPScheduleType::OrderedDistribute:
+ case OMPScheduleType::OrderedDistributeChunked:
+ break;
+ default:
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ };
+ } break;
+ case OMPRTL___kmpc_target_init:
+ KernelInitCB = &CB;
+ break;
+ case OMPRTL___kmpc_target_deinit:
+ KernelDeinitCB = &CB;
+ break;
+ case OMPRTL___kmpc_parallel_51:
+ if (!handleParallel51(A, CB))
+ indicatePessimisticFixpoint();
+ return;
+ case OMPRTL___kmpc_omp_task:
+ // We do not look into tasks right now, just give up.
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
SPMDCompatibilityTracker.insert(&CB);
+ ReachedUnknownParallelRegions.insert(&CB);
break;
- };
- } break;
- case OMPRTL___kmpc_target_init:
- KernelInitCB = &CB;
- break;
- case OMPRTL___kmpc_target_deinit:
- KernelDeinitCB = &CB;
- break;
- case OMPRTL___kmpc_parallel_51:
- if (auto *ParallelRegion = dyn_cast<Function>(
- CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
- ReachedKnownParallelRegions.insert(ParallelRegion);
- /// Check nested parallelism
- auto *FnAA = A.getAAFor<AAKernelInfo>(
- *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
- NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
- !FnAA->ReachedKnownParallelRegions.empty() ||
- !FnAA->ReachedUnknownParallelRegions.empty();
+ case OMPRTL___kmpc_alloc_shared:
+ case OMPRTL___kmpc_free_shared:
+ // Return without setting a fixpoint, to be resolved in updateImpl.
+ return;
+ default:
+ // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
+ // generally. However, they do not hide parallel regions.
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
break;
}
- // The condition above should usually get the parallel region function
- // pointer and record it. In the off chance it doesn't we assume the
- // worst.
- ReachedUnknownParallelRegions.insert(&CB);
- break;
- case OMPRTL___kmpc_omp_task:
- // We do not look into tasks right now, just give up.
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- ReachedUnknownParallelRegions.insert(&CB);
- break;
- case OMPRTL___kmpc_alloc_shared:
- case OMPRTL___kmpc_free_shared:
- // Return without setting a fixpoint, to be resolved in updateImpl.
+ // All other OpenMP runtime calls will not reach parallel regions so they
+ // can be safely ignored for now. Since it is a known OpenMP runtime call
+ // we have now modeled all effects and there is no need for any update.
+ indicateOptimisticFixpoint();
+ };
+
+ const auto *AACE =
+ A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+ CheckCallee(getAssociatedFunction(), 1);
return;
- default:
- // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
- // generally. However, they do not hide parallel regions.
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- break;
}
- // All other OpenMP runtime calls will not reach parallel regions so they
- // can be safely ignored for now. Since it is a known OpenMP runtime call we
- // have now modeled all effects and there is no need for any update.
- indicateOptimisticFixpoint();
+ const auto &OptimisticEdges = AACE->getOptimisticEdges();
+ for (auto *Callee : OptimisticEdges) {
+ CheckCallee(Callee, OptimisticEdges.size());
+ if (isAtFixpoint())
+ break;
+ }
}
ChangeStatus updateImpl(Attributor &A) override {
@@ -4782,62 +5045,115 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// call site specific liveness information and then it makes
// sense to specialize attributes for call sites arguments instead of
// redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
-
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
- const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
+ KernelInfoState StateBefore = getState();
- // If F is not a runtime function, propagate the AAKernelInfo of the callee.
- if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
- const IRPosition &FnPos = IRPosition::function(*F);
- auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
- if (!FnAA)
+ auto CheckCallee = [&](Function *F, int NumCallees) {
+ const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
+
+ // If F is not a runtime function, propagate the AAKernelInfo of the
+ // callee.
+ if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto *FnAA =
+ A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+ if (!FnAA)
+ return indicatePessimisticFixpoint();
+ if (getState() == FnAA->getState())
+ return ChangeStatus::UNCHANGED;
+ getState() = FnAA->getState();
+ return ChangeStatus::CHANGED;
+ }
+ if (NumCallees > 1)
return indicatePessimisticFixpoint();
- if (getState() == FnAA->getState())
- return ChangeStatus::UNCHANGED;
- getState() = FnAA->getState();
- return ChangeStatus::CHANGED;
- }
- // F is a runtime function that allocates or frees memory, check
- // AAHeapToStack and AAHeapToShared.
- KernelInfoState StateBefore = getState();
- assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
- It->getSecond() == OMPRTL___kmpc_free_shared) &&
- "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
+ CallBase &CB = cast<CallBase>(getAssociatedValue());
+ if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
+ if (!handleParallel51(A, CB))
+ return indicatePessimisticFixpoint();
+ return StateBefore == getState() ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
- CallBase &CB = cast<CallBase>(getAssociatedValue());
+ // F is a runtime function that allocates or frees memory, check
+ // AAHeapToStack and AAHeapToShared.
+ assert(
+ (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
+ It->getSecond() == OMPRTL___kmpc_free_shared) &&
+ "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
- auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
- *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
- auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
- *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+ auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
+ *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+ auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
+ *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
- RuntimeFunction RF = It->getSecond();
+ RuntimeFunction RF = It->getSecond();
- switch (RF) {
- // If neither HeapToStack nor HeapToShared assume the call is removed,
- // assume SPMD incompatibility.
- case OMPRTL___kmpc_alloc_shared:
- if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
- (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
- SPMDCompatibilityTracker.insert(&CB);
- break;
- case OMPRTL___kmpc_free_shared:
- if ((!HeapToStackAA ||
- !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
- (!HeapToSharedAA ||
- !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+ switch (RF) {
+ // If neither HeapToStack nor HeapToShared assume the call is removed,
+ // assume SPMD incompatibility.
+ case OMPRTL___kmpc_alloc_shared:
+ if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
+ (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ case OMPRTL___kmpc_free_shared:
+ if ((!HeapToStackAA ||
+ !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
+ (!HeapToSharedAA ||
+ !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ default:
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
SPMDCompatibilityTracker.insert(&CB);
- break;
- default:
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
+ }
+ return ChangeStatus::CHANGED;
+ };
+
+ const auto *AACE =
+ A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+ if (Function *F = getAssociatedFunction())
+ CheckCallee(F, /*NumCallees=*/1);
+ } else {
+ const auto &OptimisticEdges = AACE->getOptimisticEdges();
+ for (auto *Callee : OptimisticEdges) {
+ CheckCallee(Callee, OptimisticEdges.size());
+ if (isAtFixpoint())
+ break;
+ }
}
return StateBefore == getState() ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}
+
+  /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
+  /// handled; if a problem occurred, false is returned.
+ bool handleParallel51(Attributor &A, CallBase &CB) {
+ const unsigned int NonWrapperFunctionArgNo = 5;
+ const unsigned int WrapperFunctionArgNo = 6;
+ auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
+ ? NonWrapperFunctionArgNo
+ : WrapperFunctionArgNo;
+
+ auto *ParallelRegion = dyn_cast<Function>(
+ CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
+ if (!ParallelRegion)
+ return false;
+
+ ReachedKnownParallelRegions.insert(&CB);
+ /// Check nested parallelism
+ auto *FnAA = A.getAAFor<AAKernelInfo>(
+ *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
+ NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
+ !FnAA->ReachedKnownParallelRegions.empty() ||
+ !FnAA->ReachedKnownParallelRegions.isValidState() ||
+ !FnAA->ReachedUnknownParallelRegions.isValidState() ||
+ !FnAA->ReachedUnknownParallelRegions.empty();
+ return true;
+ }
};
struct AAFoldRuntimeCall
@@ -5251,6 +5567,11 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
UsedAssumedInformation, AA::Interprocedural);
continue;
}
+ if (auto *CI = dyn_cast<CallBase>(&I)) {
+ if (CI->isIndirectCall())
+ A.getOrCreateAAFor<AAIndirectCallInfo>(
+ IRPosition::callsite_function(*CI));
+ }
if (auto *SI = dyn_cast<StoreInst>(&I)) {
A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
continue;
@@ -5569,7 +5890,9 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
return PreservedAnalyses::all();
}
-bool llvm::omp::isKernel(Function &Fn) { return Fn.hasFnAttribute("kernel"); }
+bool llvm::omp::isOpenMPKernel(Function &Fn) {
+ return Fn.hasFnAttribute("kernel");
+}
KernelSet llvm::omp::getDeviceKernels(Module &M) {
// TODO: Create a more cross-platform way of determining device kernels.
@@ -5591,10 +5914,13 @@ KernelSet llvm::omp::getDeviceKernels(Module &M) {
if (!KernelFn)
continue;
- assert(isKernel(*KernelFn) && "Inconsistent kernel function annotation");
- ++NumOpenMPTargetRegionKernels;
-
- Kernels.insert(KernelFn);
+ // We are only interested in OpenMP target regions. Others, such as kernels
+ // generated by CUDA but linked together, are not interesting to this pass.
+ if (isOpenMPKernel(*KernelFn)) {
+ ++NumOpenMPTargetRegionKernels;
+ Kernels.insert(KernelFn);
+ } else
+ ++NumNonOpenMPTargetRegionKernels;
}
return Kernels;
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index b88ba2dec24b..aa4f205ec5bd 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -161,7 +161,7 @@ struct FunctionOutliningInfo {
// The dominating block of the region to be outlined.
BasicBlock *NonReturnBlock = nullptr;
- // The set of blocks in Entries that that are predecessors to ReturnBlock
+ // The set of blocks in Entries that are predecessors to ReturnBlock
SmallVector<BasicBlock *, 4> ReturnBlockPreds;
};
@@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline(
const DataLayout &DL = Caller->getParent()->getDataLayout();
// The savings of eliminating the call:
- int NonWeightedSavings = getCallsiteCost(CB, DL);
+ int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL);
BlockFrequency NormWeightedSavings(NonWeightedSavings);
// Weighted saving is smaller than weighted cost, return false
@@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB,
}
if (CallInst *CI = dyn_cast<CallInst>(&I)) {
- InlineCost += getCallsiteCost(*CI, DL);
+ InlineCost += getCallsiteCost(*TTI, *CI, DL);
continue;
}
if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) {
- InlineCost += getCallsiteCost(*II, DL);
+ InlineCost += getCallsiteCost(*TTI, *II, DL);
continue;
}
@@ -1042,7 +1042,7 @@ void PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const {
ClonedOI->ReturnBlock = ClonedOI->ReturnBlock->splitBasicBlock(
ClonedOI->ReturnBlock->getFirstNonPHI()->getIterator());
BasicBlock::iterator I = PreReturn->begin();
- Instruction *Ins = &ClonedOI->ReturnBlock->front();
+ BasicBlock::iterator Ins = ClonedOI->ReturnBlock->begin();
SmallVector<Instruction *, 4> DeadPhis;
while (I != PreReturn->end()) {
PHINode *OldPhi = dyn_cast<PHINode>(I);
@@ -1050,9 +1050,10 @@ void PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const {
break;
PHINode *RetPhi =
- PHINode::Create(OldPhi->getType(), NumPredsFromEntries + 1, "", Ins);
+ PHINode::Create(OldPhi->getType(), NumPredsFromEntries + 1, "");
+ RetPhi->insertBefore(Ins);
OldPhi->replaceAllUsesWith(RetPhi);
- Ins = ClonedOI->ReturnBlock->getFirstNonPHI();
+ Ins = ClonedOI->ReturnBlock->getFirstNonPHIIt();
RetPhi->addIncoming(&*I, PreReturn);
for (BasicBlock *E : ClonedOI->ReturnBlockPreds) {
diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp
index e2e6364df906..b1f9b827dcba 100644
--- a/llvm/lib/Transforms/IPO/SCCP.cpp
+++ b/llvm/lib/Transforms/IPO/SCCP.cpp
@@ -22,6 +22,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constants.h"
+#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ModRef.h"
@@ -43,7 +44,7 @@ STATISTIC(NumInstReplaced,
"Number of instructions replaced with (simpler) instruction");
static cl::opt<unsigned> FuncSpecMaxIters(
- "funcspec-max-iters", cl::init(1), cl::Hidden, cl::desc(
+ "funcspec-max-iters", cl::init(10), cl::Hidden, cl::desc(
"The maximum number of iterations function specialization is run"));
static void findReturnsToZap(Function &F,
@@ -235,11 +236,11 @@ static bool runIPSCCP(
// nodes in executable blocks we found values for. The function's entry
// block is not part of BlocksToErase, so we have to handle it separately.
for (BasicBlock *BB : BlocksToErase) {
- NumInstRemoved += changeToUnreachable(BB->getFirstNonPHI(),
+ NumInstRemoved += changeToUnreachable(BB->getFirstNonPHIOrDbg(),
/*PreserveLCSSA=*/false, &DTU);
}
if (!Solver.isBlockExecutable(&F.front()))
- NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(),
+ NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHIOrDbg(),
/*PreserveLCSSA=*/false, &DTU);
BasicBlock *NewUnreachableBB = nullptr;
@@ -371,6 +372,18 @@ static bool runIPSCCP(
StoreInst *SI = cast<StoreInst>(GV->user_back());
SI->eraseFromParent();
}
+
+ // Try to create a debug constant expression for the global variable
+ // initializer value.
+ SmallVector<DIGlobalVariableExpression *, 1> GVEs;
+ GV->getDebugInfo(GVEs);
+ if (GVEs.size() == 1) {
+ DIBuilder DIB(M);
+ if (DIExpression *InitExpr = getExpressionForConstant(
+ DIB, *GV->getInitializer(), *GV->getValueType()))
+ GVEs[0]->replaceOperandWith(1, InitExpr);
+ }
+
MadeChanges = true;
M.eraseGlobalVariable(GV);
++NumGlobalConst;
diff --git a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
index 3ddf5fe20edb..f7a54d428f20 100644
--- a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
+++ b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
@@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/SampleContextTracker.h"
-#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/InstrTypes.h"
@@ -29,7 +28,7 @@ using namespace sampleprof;
namespace llvm {
ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite,
- StringRef CalleeName) {
+ FunctionId CalleeName) {
if (CalleeName.empty())
return getHottestChildContext(CallSite);
@@ -104,7 +103,7 @@ SampleContextTracker::moveContextSamples(ContextTrieNode &ToNodeParent,
}
void ContextTrieNode::removeChildContext(const LineLocation &CallSite,
- StringRef CalleeName) {
+ FunctionId CalleeName) {
uint64_t Hash = FunctionSamples::getCallSiteHash(CalleeName, CallSite);
// Note this essentially calls dtor and destroys that child context
AllChildContext.erase(Hash);
@@ -114,7 +113,7 @@ std::map<uint64_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() {
return AllChildContext;
}
-StringRef ContextTrieNode::getFuncName() const { return FuncName; }
+FunctionId ContextTrieNode::getFuncName() const { return FuncName; }
FunctionSamples *ContextTrieNode::getFunctionSamples() const {
return FuncSamples;
@@ -178,7 +177,7 @@ void ContextTrieNode::dumpTree() {
}
ContextTrieNode *ContextTrieNode::getOrCreateChildContext(
- const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) {
+ const LineLocation &CallSite, FunctionId CalleeName, bool AllowCreate) {
uint64_t Hash = FunctionSamples::getCallSiteHash(CalleeName, CallSite);
auto It = AllChildContext.find(Hash);
if (It != AllChildContext.end()) {
@@ -201,7 +200,7 @@ SampleContextTracker::SampleContextTracker(
: GUIDToFuncNameMap(GUIDToFuncNameMap) {
for (auto &FuncSample : Profiles) {
FunctionSamples *FSamples = &FuncSample.second;
- SampleContext Context = FuncSample.first;
+ SampleContext Context = FuncSample.second.getContext();
LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context.toString()
<< "\n");
ContextTrieNode *NewNode = getOrCreateContextPath(Context, true);
@@ -232,14 +231,12 @@ SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst,
return nullptr;
CalleeName = FunctionSamples::getCanonicalFnName(CalleeName);
- // Convert real function names to MD5 names, if the input profile is
- // MD5-based.
- std::string FGUID;
- CalleeName = getRepInFormat(CalleeName, FunctionSamples::UseMD5, FGUID);
+
+ FunctionId FName = getRepInFormat(CalleeName);
// For indirect call, CalleeName will be empty, in which case the context
// profile for callee with largest total samples will be returned.
- ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName);
+ ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, FName);
if (CalleeContext) {
FunctionSamples *FSamples = CalleeContext->getFunctionSamples();
LLVM_DEBUG(if (FSamples) {
@@ -305,27 +302,23 @@ SampleContextTracker::getContextSamplesFor(const SampleContext &Context) {
SampleContextTracker::ContextSamplesTy &
SampleContextTracker::getAllContextSamplesFor(const Function &Func) {
StringRef CanonName = FunctionSamples::getCanonicalFnName(Func);
- return FuncToCtxtProfiles[CanonName];
+ return FuncToCtxtProfiles[getRepInFormat(CanonName)];
}
SampleContextTracker::ContextSamplesTy &
SampleContextTracker::getAllContextSamplesFor(StringRef Name) {
- return FuncToCtxtProfiles[Name];
+ return FuncToCtxtProfiles[getRepInFormat(Name)];
}
FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func,
bool MergeContext) {
StringRef CanonName = FunctionSamples::getCanonicalFnName(Func);
- return getBaseSamplesFor(CanonName, MergeContext);
+ return getBaseSamplesFor(getRepInFormat(CanonName), MergeContext);
}
-FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name,
+FunctionSamples *SampleContextTracker::getBaseSamplesFor(FunctionId Name,
bool MergeContext) {
LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n");
- // Convert real function names to MD5 names, if the input profile is
- // MD5-based.
- std::string FGUID;
- Name = getRepInFormat(Name, FunctionSamples::UseMD5, FGUID);
// Base profile is top-level node (child of root node), so try to retrieve
// existing top-level node for given function first. If it exists, it could be
@@ -373,7 +366,7 @@ void SampleContextTracker::markContextSamplesInlined(
ContextTrieNode &SampleContextTracker::getRootContext() { return RootContext; }
void SampleContextTracker::promoteMergeContextSamplesTree(
- const Instruction &Inst, StringRef CalleeName) {
+ const Instruction &Inst, FunctionId CalleeName) {
LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n"
<< Inst << "\n");
// Get the caller context for the call instruction, we don't use callee
@@ -458,9 +451,9 @@ void SampleContextTracker::dump() { RootContext.dumpTree(); }
StringRef SampleContextTracker::getFuncNameFor(ContextTrieNode *Node) const {
if (!FunctionSamples::UseMD5)
- return Node->getFuncName();
+ return Node->getFuncName().stringRef();
assert(GUIDToFuncNameMap && "GUIDToFuncNameMap needs to be populated first");
- return GUIDToFuncNameMap->lookup(std::stoull(Node->getFuncName().data()));
+ return GUIDToFuncNameMap->lookup(Node->getFuncName().getHashCode());
}
ContextTrieNode *
@@ -470,7 +463,7 @@ SampleContextTracker::getContextFor(const SampleContext &Context) {
ContextTrieNode *
SampleContextTracker::getCalleeContextFor(const DILocation *DIL,
- StringRef CalleeName) {
+ FunctionId CalleeName) {
assert(DIL && "Expect non-null location");
ContextTrieNode *CallContext = getContextFor(DIL);
@@ -485,7 +478,7 @@ SampleContextTracker::getCalleeContextFor(const DILocation *DIL,
ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) {
assert(DIL && "Expect non-null location");
- SmallVector<std::pair<LineLocation, StringRef>, 10> S;
+ SmallVector<std::pair<LineLocation, FunctionId>, 10> S;
// Use C++ linkage name if possible.
const DILocation *PrevDIL = DIL;
@@ -494,7 +487,8 @@ ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) {
if (Name.empty())
Name = PrevDIL->getScope()->getSubprogram()->getName();
S.push_back(
- std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), Name));
+ std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL),
+ getRepInFormat(Name)));
PrevDIL = DIL;
}
@@ -503,24 +497,14 @@ ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) {
StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName();
if (RootName.empty())
RootName = PrevDIL->getScope()->getSubprogram()->getName();
- S.push_back(std::make_pair(LineLocation(0, 0), RootName));
-
- // Convert real function names to MD5 names, if the input profile is
- // MD5-based.
- std::list<std::string> MD5Names;
- if (FunctionSamples::UseMD5) {
- for (auto &Location : S) {
- MD5Names.emplace_back();
- getRepInFormat(Location.second, FunctionSamples::UseMD5, MD5Names.back());
- Location.second = MD5Names.back();
- }
- }
+ S.push_back(std::make_pair(LineLocation(0, 0),
+ getRepInFormat(RootName)));
ContextTrieNode *ContextNode = &RootContext;
int I = S.size();
while (--I >= 0 && ContextNode) {
LineLocation &CallSite = S[I].first;
- StringRef CalleeName = S[I].second;
+ FunctionId CalleeName = S[I].second;
ContextNode = ContextNode->getChildContext(CallSite, CalleeName);
}
@@ -540,10 +524,10 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext &Context,
// Create child node at parent line/disc location
if (AllowCreate) {
ContextNode =
- ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.FuncName);
+ ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.Func);
} else {
ContextNode =
- ContextNode->getChildContext(CallSiteLoc, Callsite.FuncName);
+ ContextNode->getChildContext(CallSiteLoc, Callsite.Func);
}
CallSiteLoc = Callsite.Location;
}
@@ -553,12 +537,14 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext &Context,
return ContextNode;
}
-ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) {
+ContextTrieNode *
+SampleContextTracker::getTopLevelContextNode(FunctionId FName) {
assert(!FName.empty() && "Top level node query must provide valid name");
return RootContext.getChildContext(LineLocation(0, 0), FName);
}
-ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) {
+ContextTrieNode &
+SampleContextTracker::addTopLevelContextNode(FunctionId FName) {
assert(!getTopLevelContextNode(FName) && "Node to add must not exist");
return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName);
}
@@ -638,7 +624,7 @@ void SampleContextTracker::createContextLessProfileMap(
FunctionSamples *FProfile = Node->getFunctionSamples();
// Profile's context can be empty, use ContextNode's func name.
if (FProfile)
- ContextLessProfiles[Node->getFuncName()].merge(*FProfile);
+ ContextLessProfiles.Create(Node->getFuncName()).merge(*FProfile);
}
}
} // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index a53baecd4776..6c6f0a0eca72 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/PseudoProbe.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/ProfileData/InstrProf.h"
@@ -142,11 +143,6 @@ static cl::opt<bool> PersistProfileStaleness(
cl::desc("Compute stale profile statistical metrics and write it into the "
"native object file(.llvm_stats section)."));
-static cl::opt<bool> FlattenProfileForMatching(
- "flatten-profile-for-matching", cl::Hidden, cl::init(true),
- cl::desc(
- "Use flattened profile for stale profile detection and matching."));
-
static cl::opt<bool> ProfileSampleAccurate(
"profile-sample-accurate", cl::Hidden, cl::init(false),
cl::desc("If the sample profile is accurate, we will mark all un-sampled "
@@ -429,7 +425,7 @@ struct CandidateComparer {
return LCS->getBodySamples().size() > RCS->getBodySamples().size();
// Tie breaker using GUID so we have stable/deterministic inlining order
- return LCS->getGUID(LCS->getName()) < RCS->getGUID(RCS->getName());
+ return LCS->getGUID() < RCS->getGUID();
}
};
@@ -458,32 +454,44 @@ class SampleProfileMatcher {
uint64_t MismatchedFuncHashSamples = 0;
uint64_t TotalFuncHashSamples = 0;
+ // A dummy name for unknown indirect callee, used to differentiate from a
+ // non-call instruction that also has an empty callee name.
+ static constexpr const char *UnknownIndirectCallee =
+ "unknown.indirect.callee";
+
public:
SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
const PseudoProbeManager *ProbeManager)
- : M(M), Reader(Reader), ProbeManager(ProbeManager) {
- if (FlattenProfileForMatching) {
- ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
- FunctionSamples::ProfileIsCS);
- }
- }
+ : M(M), Reader(Reader), ProbeManager(ProbeManager){};
void runOnModule();
private:
FunctionSamples *getFlattenedSamplesFor(const Function &F) {
StringRef CanonFName = FunctionSamples::getCanonicalFnName(F);
- auto It = FlattenedProfiles.find(CanonFName);
+ auto It = FlattenedProfiles.find(FunctionId(CanonFName));
if (It != FlattenedProfiles.end())
return &It->second;
return nullptr;
}
- void runOnFunction(const Function &F, const FunctionSamples &FS);
+ void runOnFunction(const Function &F);
+ void findIRAnchors(const Function &F,
+ std::map<LineLocation, StringRef> &IRAnchors);
+ void findProfileAnchors(
+ const FunctionSamples &FS,
+ std::map<LineLocation, std::unordered_set<FunctionId>>
+ &ProfileAnchors);
+ void countMismatchedSamples(const FunctionSamples &FS);
void countProfileMismatches(
+ const Function &F, const FunctionSamples &FS,
+ const std::map<LineLocation, StringRef> &IRAnchors,
+ const std::map<LineLocation, std::unordered_set<FunctionId>>
+ &ProfileAnchors);
+ void countProfileCallsiteMismatches(
const FunctionSamples &FS,
- const std::unordered_set<LineLocation, LineLocationHash>
- &MatchedCallsiteLocs,
+ const std::map<LineLocation, StringRef> &IRAnchors,
+ const std::map<LineLocation, std::unordered_set<FunctionId>>
+ &ProfileAnchors,
uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites);
-
LocToLocMap &getIRToProfileLocationMap(const Function &F) {
auto Ret = FuncMappings.try_emplace(
FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
@@ -491,12 +499,10 @@ private:
}
void distributeIRToProfileLocationMap();
void distributeIRToProfileLocationMap(FunctionSamples &FS);
- void populateProfileCallsites(
- const FunctionSamples &FS,
- StringMap<std::set<LineLocation>> &CalleeToCallsitesMap);
void runStaleProfileMatching(
- const std::map<LineLocation, StringRef> &IRLocations,
- StringMap<std::set<LineLocation>> &CalleeToCallsitesMap,
+ const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+ const std::map<LineLocation, std::unordered_set<FunctionId>>
+ &ProfileAnchors,
LocToLocMap &IRToProfileLocationMap);
};
@@ -538,7 +544,6 @@ protected:
findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const;
void findExternalInlineCandidate(CallBase *CB, const FunctionSamples *Samples,
DenseSet<GlobalValue::GUID> &InlinedGUIDs,
- const StringMap<Function *> &SymbolMap,
uint64_t Threshold);
// Attempt to promote indirect call and also inline the promoted call
bool tryPromoteAndInlineCandidate(
@@ -573,7 +578,7 @@ protected:
/// the function name. If the function name contains suffix, additional
/// entry is added to map from the stripped name to the function if there
/// is one-to-one mapping.
- StringMap<Function *> SymbolMap;
+ HashKeyMap<std::unordered_map, FunctionId, Function *> SymbolMap;
std::function<AssumptionCache &(Function &)> GetAC;
std::function<TargetTransformInfo &(Function &)> GetTTI;
@@ -615,6 +620,11 @@ protected:
// All the Names used in FunctionSamples including outline function
// names, inline instance names and call target names.
StringSet<> NamesInProfile;
+ // MD5 version of NamesInProfile. Either NamesInProfile or GUIDsInProfile is
+ // populated, depends on whether the profile uses MD5. Because the name table
+ // generally contains several magnitude more entries than the number of
+ // functions, we do not want to convert all names from one form to another.
+ llvm::DenseSet<uint64_t> GUIDsInProfile;
// For symbol in profile symbol list, whether to regard their profiles
// to be accurate. It is mainly decided by existance of profile symbol
@@ -759,8 +769,7 @@ SampleProfileLoader::findIndirectCallFunctionSamples(
assert(L && R && "Expect non-null FunctionSamples");
if (L->getHeadSamplesEstimate() != R->getHeadSamplesEstimate())
return L->getHeadSamplesEstimate() > R->getHeadSamplesEstimate();
- return FunctionSamples::getGUID(L->getName()) <
- FunctionSamples::getGUID(R->getName());
+ return L->getGUID() < R->getGUID();
};
if (FunctionSamples::ProfileIsCS) {
@@ -970,13 +979,13 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate(
// This prevents allocating an array of zero length in callees below.
if (MaxNumPromotions == 0)
return false;
- auto CalleeFunctionName = Candidate.CalleeSamples->getFuncName();
+ auto CalleeFunctionName = Candidate.CalleeSamples->getFunction();
auto R = SymbolMap.find(CalleeFunctionName);
- if (R == SymbolMap.end() || !R->getValue())
+ if (R == SymbolMap.end() || !R->second)
return false;
auto &CI = *Candidate.CallInstr;
- if (!doesHistoryAllowICP(CI, R->getValue()->getName()))
+ if (!doesHistoryAllowICP(CI, R->second->getName()))
return false;
const char *Reason = "Callee function not available";
@@ -986,17 +995,17 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate(
// clone the caller first, and inline the cloned caller if it is
// recursive. As llvm does not inline recursive calls, we will
// simply ignore it instead of handling it explicitly.
- if (!R->getValue()->isDeclaration() && R->getValue()->getSubprogram() &&
- R->getValue()->hasFnAttribute("use-sample-profile") &&
- R->getValue() != &F && isLegalToPromote(CI, R->getValue(), &Reason)) {
+ if (!R->second->isDeclaration() && R->second->getSubprogram() &&
+ R->second->hasFnAttribute("use-sample-profile") &&
+ R->second != &F && isLegalToPromote(CI, R->second, &Reason)) {
// For promoted target, set its value with NOMORE_ICP_MAGICNUM count
// in the value profile metadata so the target won't be promoted again.
SmallVector<InstrProfValueData, 1> SortedCallTargets = {InstrProfValueData{
- Function::getGUID(R->getValue()->getName()), NOMORE_ICP_MAGICNUM}};
+ Function::getGUID(R->second->getName()), NOMORE_ICP_MAGICNUM}};
updateIDTMetaData(CI, SortedCallTargets, 0);
auto *DI = &pgo::promoteIndirectCall(
- CI, R->getValue(), Candidate.CallsiteCount, Sum, false, ORE);
+ CI, R->second, Candidate.CallsiteCount, Sum, false, ORE);
if (DI) {
Sum -= Candidate.CallsiteCount;
// Do not prorate the indirect callsite distribution since the original
@@ -1025,7 +1034,8 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate(
}
} else {
LLVM_DEBUG(dbgs() << "\nFailed to promote indirect call to "
- << Candidate.CalleeSamples->getFuncName() << " because "
+ << FunctionSamples::getCanonicalFnName(
+ Candidate.CallInstr->getName())<< " because "
<< Reason << "\n");
}
return false;
@@ -1070,8 +1080,7 @@ void SampleProfileLoader::emitOptimizationRemarksForInlineCandidates(
void SampleProfileLoader::findExternalInlineCandidate(
CallBase *CB, const FunctionSamples *Samples,
- DenseSet<GlobalValue::GUID> &InlinedGUIDs,
- const StringMap<Function *> &SymbolMap, uint64_t Threshold) {
+ DenseSet<GlobalValue::GUID> &InlinedGUIDs, uint64_t Threshold) {
// If ExternalInlineAdvisor(ReplayInlineAdvisor) wants to inline an external
// function make sure it's imported
@@ -1080,7 +1089,7 @@ void SampleProfileLoader::findExternalInlineCandidate(
// just add the direct GUID and move on
if (!Samples) {
InlinedGUIDs.insert(
- FunctionSamples::getGUID(CB->getCalledFunction()->getName()));
+ Function::getGUID(CB->getCalledFunction()->getName()));
return;
}
// Otherwise, drop the threshold to import everything that we can
@@ -1121,22 +1130,20 @@ void SampleProfileLoader::findExternalInlineCandidate(
CalleeSample->getContext().hasAttribute(ContextShouldBeInlined);
if (!PreInline && CalleeSample->getHeadSamplesEstimate() < Threshold)
continue;
-
- StringRef Name = CalleeSample->getFuncName();
- Function *Func = SymbolMap.lookup(Name);
+
+ Function *Func = SymbolMap.lookup(CalleeSample->getFunction());
// Add to the import list only when it's defined out of module.
if (!Func || Func->isDeclaration())
- InlinedGUIDs.insert(FunctionSamples::getGUID(CalleeSample->getName()));
+ InlinedGUIDs.insert(CalleeSample->getGUID());
// Import hot CallTargets, which may not be available in IR because full
// profile annotation cannot be done until backend compilation in ThinLTO.
for (const auto &BS : CalleeSample->getBodySamples())
for (const auto &TS : BS.second.getCallTargets())
- if (TS.getValue() > Threshold) {
- StringRef CalleeName = CalleeSample->getFuncName(TS.getKey());
- const Function *Callee = SymbolMap.lookup(CalleeName);
+ if (TS.second > Threshold) {
+ const Function *Callee = SymbolMap.lookup(TS.first);
if (!Callee || Callee->isDeclaration())
- InlinedGUIDs.insert(FunctionSamples::getGUID(TS.getKey()));
+ InlinedGUIDs.insert(TS.first.getHashCode());
}
// Import hot child context profile associted with callees. Note that this
@@ -1234,7 +1241,7 @@ bool SampleProfileLoader::inlineHotFunctions(
for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) {
uint64_t SumOrigin = Sum;
if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) {
- findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap,
+ findExternalInlineCandidate(I, FS, InlinedGUIDs,
PSI->getOrCompHotCountThreshold());
continue;
}
@@ -1255,7 +1262,7 @@ bool SampleProfileLoader::inlineHotFunctions(
}
} else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) {
findExternalInlineCandidate(I, findCalleeFunctionSamples(*I),
- InlinedGUIDs, SymbolMap,
+ InlinedGUIDs,
PSI->getOrCompHotCountThreshold());
}
}
@@ -1504,7 +1511,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority(
for (const auto *FS : CalleeSamples) {
// TODO: Consider disable pre-lTO ICP for MonoLTO as well
if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) {
- findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap,
+ findExternalInlineCandidate(I, FS, InlinedGUIDs,
PSI->getOrCompHotCountThreshold());
continue;
}
@@ -1557,7 +1564,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority(
}
} else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) {
findExternalInlineCandidate(I, findCalleeFunctionSamples(*I),
- InlinedGUIDs, SymbolMap,
+ InlinedGUIDs,
PSI->getOrCompHotCountThreshold());
}
}
@@ -1619,7 +1626,12 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples(
// Note that we have to do the merge right after processing function.
// This allows OutlineFS's profile to be used for annotation during
// top-down processing of functions' annotation.
- FunctionSamples *OutlineFS = Reader->getOrCreateSamplesFor(*Callee);
+ FunctionSamples *OutlineFS = Reader->getSamplesFor(*Callee);
+ // If outlined function does not exist in the profile, add it to a
+ // separate map so that it does not rehash the original profile.
+ if (!OutlineFS)
+ OutlineFS = &OutlineFunctionSamples[
+ FunctionId(FunctionSamples::getCanonicalFnName(Callee->getName()))];
OutlineFS->merge(*FS, 1);
// Set outlined profile to be synthetic to not bias the inliner.
OutlineFS->SetContextSynthetic();
@@ -1638,7 +1650,7 @@ GetSortedValueDataFromCallTargets(const SampleRecord::CallTargetMap &M) {
SmallVector<InstrProfValueData, 2> R;
for (const auto &I : SampleRecord::SortCallTargets(M)) {
R.emplace_back(
- InstrProfValueData{FunctionSamples::getGUID(I.first), I.second});
+ InstrProfValueData{I.first.getHashCode(), I.second});
}
return R;
}
@@ -1699,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
- I.setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(
- {static_cast<uint32_t>(BlockWeights[BB])}));
+ setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1709,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
// clear it for cold code.
for (auto &I : *BB) {
if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
- if (cast<CallBase>(I).isIndirectCall())
+ if (cast<CallBase>(I).isIndirectCall()) {
I.setMetadata(LLVMContext::MD_prof, nullptr);
- else
- I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+ } else {
+ setBranchWeights(I, {uint32_t(0)});
+ }
}
}
}
@@ -1792,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
if (MaxWeight > 0 &&
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
- TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
ORE->emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
<< "most popular destination for conditional branches at "
@@ -1865,7 +1876,8 @@ SampleProfileLoader::buildProfiledCallGraph(Module &M) {
for (Function &F : M) {
if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
continue;
- ProfiledCG->addProfiledFunction(FunctionSamples::getCanonicalFnName(F));
+ ProfiledCG->addProfiledFunction(
+ getRepInFormat(FunctionSamples::getCanonicalFnName(F)));
}
return ProfiledCG;
@@ -1913,7 +1925,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
// on the profile to favor more inlining. This is only a problem with CS
// profile.
// 3. Transitive indirect call edges due to inlining. When a callee function
- // (say B) is inlined into into a caller function (say A) in LTO prelink,
+ // (say B) is inlined into a caller function (say A) in LTO prelink,
// every call edge originated from the callee B will be transferred to
// the caller A. If any transferred edge (say A->C) is indirect, the
// original profiled indirect edge B->C, even if considered, would not
@@ -2016,8 +2028,16 @@ bool SampleProfileLoader::doInitialization(Module &M,
ProfileAccurateForSymsInList && PSL && !ProfileSampleAccurate;
if (ProfAccForSymsInList) {
NamesInProfile.clear();
- if (auto NameTable = Reader->getNameTable())
- NamesInProfile.insert(NameTable->begin(), NameTable->end());
+ GUIDsInProfile.clear();
+ if (auto NameTable = Reader->getNameTable()) {
+ if (FunctionSamples::UseMD5) {
+ for (auto Name : *NameTable)
+ GUIDsInProfile.insert(Name.getHashCode());
+ } else {
+ for (auto Name : *NameTable)
+ NamesInProfile.insert(Name.stringRef());
+ }
+ }
CoverageTracker.setProfAccForSymsInList(true);
}
@@ -2103,77 +2123,200 @@ bool SampleProfileLoader::doInitialization(Module &M,
return true;
}
-void SampleProfileMatcher::countProfileMismatches(
- const FunctionSamples &FS,
- const std::unordered_set<LineLocation, LineLocationHash>
- &MatchedCallsiteLocs,
- uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) {
+void SampleProfileMatcher::findIRAnchors(
+ const Function &F, std::map<LineLocation, StringRef> &IRAnchors) {
+ // For inlined code, recover the original callsite and callee by finding the
+ // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the
+ // top-level frame is "main:1", the callsite is "1" and the callee is "foo".
+ auto FindTopLevelInlinedCallsite = [](const DILocation *DIL) {
+ assert((DIL && DIL->getInlinedAt()) && "No inlined callsite");
+ const DILocation *PrevDIL = nullptr;
+ do {
+ PrevDIL = DIL;
+ DIL = DIL->getInlinedAt();
+ } while (DIL->getInlinedAt());
- auto isInvalidLineOffset = [](uint32_t LineOffset) {
- return LineOffset & 0x8000;
+ LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+ StringRef CalleeName = PrevDIL->getSubprogramLinkageName();
+ return std::make_pair(Callsite, CalleeName);
};
- // Check if there are any callsites in the profile that does not match to any
- // IR callsites, those callsite samples will be discarded.
- for (auto &I : FS.getBodySamples()) {
- const LineLocation &Loc = I.first;
- if (isInvalidLineOffset(Loc.LineOffset))
- continue;
+ auto GetCanonicalCalleeName = [](const CallBase *CB) {
+ StringRef CalleeName = UnknownIndirectCallee;
+ if (Function *Callee = CB->getCalledFunction())
+ CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName());
+ return CalleeName;
+ };
+
+ // Extract profile matching anchors in the IR.
+ for (auto &BB : F) {
+ for (auto &I : BB) {
+ DILocation *DIL = I.getDebugLoc();
+ if (!DIL)
+ continue;
+
+ if (FunctionSamples::ProfileIsProbeBased) {
+ if (auto Probe = extractProbe(I)) {
+ // Flatten inlined IR for the matching.
+ if (DIL->getInlinedAt()) {
+ IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL));
+ } else {
+ // Use empty StringRef for basic block probe.
+ StringRef CalleeName;
+ if (const auto *CB = dyn_cast<CallBase>(&I)) {
+ // Skip the probe inst whose callee name is "llvm.pseudoprobe".
+ if (!isa<IntrinsicInst>(&I))
+ CalleeName = GetCanonicalCalleeName(CB);
+ }
+ IRAnchors.emplace(LineLocation(Probe->Id, 0), CalleeName);
+ }
+ }
+ } else {
+ // TODO: For line-number based profile(AutoFDO), currently only support
+ // find callsite anchors. In future, we need to parse all the non-call
+ // instructions to extract the line locations for profile matching.
+ if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
+ continue;
- uint64_t Count = I.second.getSamples();
- if (!I.second.getCallTargets().empty()) {
- TotalCallsiteSamples += Count;
- FuncProfiledCallsites++;
- if (!MatchedCallsiteLocs.count(Loc)) {
- MismatchedCallsiteSamples += Count;
- FuncMismatchedCallsites++;
+ if (DIL->getInlinedAt()) {
+ IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL));
+ } else {
+ LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+ StringRef CalleeName = GetCanonicalCalleeName(dyn_cast<CallBase>(&I));
+ IRAnchors.emplace(Callsite, CalleeName);
+ }
}
}
}
+}
- for (auto &I : FS.getCallsiteSamples()) {
- const LineLocation &Loc = I.first;
- if (isInvalidLineOffset(Loc.LineOffset))
- continue;
+void SampleProfileMatcher::countMismatchedSamples(const FunctionSamples &FS) {
+ const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
+ // Skip the function that is external or renamed.
+ if (!FuncDesc)
+ return;
+
+ if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+ MismatchedFuncHashSamples += FS.getTotalSamples();
+ return;
+ }
+ for (const auto &I : FS.getCallsiteSamples())
+ for (const auto &CS : I.second)
+ countMismatchedSamples(CS.second);
+}
+
+void SampleProfileMatcher::countProfileMismatches(
+ const Function &F, const FunctionSamples &FS,
+ const std::map<LineLocation, StringRef> &IRAnchors,
+ const std::map<LineLocation, std::unordered_set<FunctionId>>
+ &ProfileAnchors) {
+ [[maybe_unused]] bool IsFuncHashMismatch = false;
+ if (FunctionSamples::ProfileIsProbeBased) {
+ TotalFuncHashSamples += FS.getTotalSamples();
+ TotalProfiledFunc++;
+ const auto *FuncDesc = ProbeManager->getDesc(F);
+ if (FuncDesc) {
+ if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+ NumMismatchedFuncHash++;
+ IsFuncHashMismatch = true;
+ }
+ countMismatchedSamples(FS);
+ }
+ }
+
+ uint64_t FuncMismatchedCallsites = 0;
+ uint64_t FuncProfiledCallsites = 0;
+ countProfileCallsiteMismatches(FS, IRAnchors, ProfileAnchors,
+ FuncMismatchedCallsites,
+ FuncProfiledCallsites);
+ TotalProfiledCallsites += FuncProfiledCallsites;
+ NumMismatchedCallsites += FuncMismatchedCallsites;
+ LLVM_DEBUG({
+ if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
+ FuncMismatchedCallsites)
+ dbgs() << "Function checksum is matched but there are "
+ << FuncMismatchedCallsites << "/" << FuncProfiledCallsites
+ << " mismatched callsites.\n";
+ });
+}
+
+void SampleProfileMatcher::countProfileCallsiteMismatches(
+ const FunctionSamples &FS,
+ const std::map<LineLocation, StringRef> &IRAnchors,
+ const std::map<LineLocation, std::unordered_set<FunctionId>>
+ &ProfileAnchors,
+ uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) {
+
+ // Check if there are any callsites in the profile that does not match to any
+ // IR callsites, those callsite samples will be discarded.
+ for (const auto &I : ProfileAnchors) {
+ const auto &Loc = I.first;
+ const auto &Callees = I.second;
+ assert(!Callees.empty() && "Callees should not be empty");
+
+ StringRef IRCalleeName;
+ const auto &IR = IRAnchors.find(Loc);
+ if (IR != IRAnchors.end())
+ IRCalleeName = IR->second;
- uint64_t Count = 0;
- for (auto &FM : I.second) {
- Count += FM.second.getHeadSamplesEstimate();
+ // Compute number of samples in the original profile.
+ uint64_t CallsiteSamples = 0;
+ auto CTM = FS.findCallTargetMapAt(Loc);
+ if (CTM) {
+ for (const auto &I : CTM.get())
+ CallsiteSamples += I.second;
}
- TotalCallsiteSamples += Count;
+ const auto *FSMap = FS.findFunctionSamplesMapAt(Loc);
+ if (FSMap) {
+ for (const auto &I : *FSMap)
+ CallsiteSamples += I.second.getTotalSamples();
+ }
+
+ bool CallsiteIsMatched = false;
+ // Since indirect call does not have CalleeName, check conservatively if
+ // callsite in the profile is a callsite location. This is to reduce num of
+ // false positive since otherwise all the indirect call samples will be
+ // reported as mismatching.
+ if (IRCalleeName == UnknownIndirectCallee)
+ CallsiteIsMatched = true;
+ else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
+ CallsiteIsMatched = true;
+
FuncProfiledCallsites++;
- if (!MatchedCallsiteLocs.count(Loc)) {
- MismatchedCallsiteSamples += Count;
+ TotalCallsiteSamples += CallsiteSamples;
+ if (!CallsiteIsMatched) {
FuncMismatchedCallsites++;
+ MismatchedCallsiteSamples += CallsiteSamples;
}
}
}
-// Populate the anchors(direct callee name) from profile.
-void SampleProfileMatcher::populateProfileCallsites(
- const FunctionSamples &FS,
- StringMap<std::set<LineLocation>> &CalleeToCallsitesMap) {
+void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS,
+ std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
+ auto isInvalidLineOffset = [](uint32_t LineOffset) {
+ return LineOffset & 0x8000;
+ };
+
for (const auto &I : FS.getBodySamples()) {
- const auto &Loc = I.first;
- const auto &CTM = I.second.getCallTargets();
- // Filter out possible indirect calls, use direct callee name as anchor.
- if (CTM.size() == 1) {
- StringRef CalleeName = CTM.begin()->first();
- const auto &Candidates = CalleeToCallsitesMap.try_emplace(
- CalleeName, std::set<LineLocation>());
- Candidates.first->second.insert(Loc);
+ const LineLocation &Loc = I.first;
+ if (isInvalidLineOffset(Loc.LineOffset))
+ continue;
+ for (const auto &I : I.second.getCallTargets()) {
+ auto Ret = ProfileAnchors.try_emplace(Loc,
+ std::unordered_set<FunctionId>());
+ Ret.first->second.insert(I.first);
}
}
for (const auto &I : FS.getCallsiteSamples()) {
const LineLocation &Loc = I.first;
+ if (isInvalidLineOffset(Loc.LineOffset))
+ continue;
const auto &CalleeMap = I.second;
- // Filter out possible indirect calls, use direct callee name as anchor.
- if (CalleeMap.size() == 1) {
- StringRef CalleeName = CalleeMap.begin()->first;
- const auto &Candidates = CalleeToCallsitesMap.try_emplace(
- CalleeName, std::set<LineLocation>());
- Candidates.first->second.insert(Loc);
+ for (const auto &I : CalleeMap) {
+ auto Ret = ProfileAnchors.try_emplace(Loc,
+ std::unordered_set<FunctionId>());
+ Ret.first->second.insert(I.first);
}
}
}
@@ -2196,12 +2339,30 @@ void SampleProfileMatcher::populateProfileCallsites(
// [1, 2, 3(foo), 4, 7, 8(bar), 9]
// The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
void SampleProfileMatcher::runStaleProfileMatching(
- const std::map<LineLocation, StringRef> &IRLocations,
- StringMap<std::set<LineLocation>> &CalleeToCallsitesMap,
+ const Function &F,
+ const std::map<LineLocation, StringRef> &IRAnchors,
+ const std::map<LineLocation, std::unordered_set<FunctionId>>
+ &ProfileAnchors,
LocToLocMap &IRToProfileLocationMap) {
+ LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
+ << "\n");
assert(IRToProfileLocationMap.empty() &&
"Run stale profile matching only once per function");
+ std::unordered_map<FunctionId, std::set<LineLocation>>
+ CalleeToCallsitesMap;
+ for (const auto &I : ProfileAnchors) {
+ const auto &Loc = I.first;
+ const auto &Callees = I.second;
+ // Filter out possible indirect calls, use direct callee name as anchor.
+ if (Callees.size() == 1) {
+ FunctionId CalleeName = *Callees.begin();
+ const auto &Candidates = CalleeToCallsitesMap.try_emplace(
+ CalleeName, std::set<LineLocation>());
+ Candidates.first->second.insert(Loc);
+ }
+ }
+
auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) {
// Skip the unchanged location mapping to save memory.
if (From != To)
@@ -2212,18 +2373,19 @@ void SampleProfileMatcher::runStaleProfileMatching(
int32_t LocationDelta = 0;
SmallVector<LineLocation> LastMatchedNonAnchors;
- for (const auto &IR : IRLocations) {
+ for (const auto &IR : IRAnchors) {
const auto &Loc = IR.first;
- StringRef CalleeName = IR.second;
+ auto CalleeName = IR.second;
bool IsMatchedAnchor = false;
// Match the anchor location in lexical order.
if (!CalleeName.empty()) {
- auto ProfileAnchors = CalleeToCallsitesMap.find(CalleeName);
- if (ProfileAnchors != CalleeToCallsitesMap.end() &&
- !ProfileAnchors->second.empty()) {
- auto CI = ProfileAnchors->second.begin();
+ auto CandidateAnchors = CalleeToCallsitesMap.find(
+ getRepInFormat(CalleeName));
+ if (CandidateAnchors != CalleeToCallsitesMap.end() &&
+ !CandidateAnchors->second.empty()) {
+ auto CI = CandidateAnchors->second.begin();
const auto Candidate = *CI;
- ProfileAnchors->second.erase(CI);
+ CandidateAnchors->second.erase(CI);
InsertMatching(Loc, Candidate);
LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName
<< " is matched from " << Loc << " to " << Candidate
@@ -2261,122 +2423,56 @@ void SampleProfileMatcher::runStaleProfileMatching(
}
}
-void SampleProfileMatcher::runOnFunction(const Function &F,
- const FunctionSamples &FS) {
- bool IsFuncHashMismatch = false;
- if (FunctionSamples::ProfileIsProbeBased) {
- uint64_t Count = FS.getTotalSamples();
- TotalFuncHashSamples += Count;
- TotalProfiledFunc++;
- if (!ProbeManager->profileIsValid(F, FS)) {
- MismatchedFuncHashSamples += Count;
- NumMismatchedFuncHash++;
- IsFuncHashMismatch = true;
- }
- }
-
- std::unordered_set<LineLocation, LineLocationHash> MatchedCallsiteLocs;
- // The value of the map is the name of direct callsite and use empty StringRef
- // for non-direct-call site.
- std::map<LineLocation, StringRef> IRLocations;
-
- // Extract profile matching anchors and profile mismatch metrics in the IR.
- for (auto &BB : F) {
- for (auto &I : BB) {
- // TODO: Support line-number based location(AutoFDO).
- if (FunctionSamples::ProfileIsProbeBased && isa<PseudoProbeInst>(&I)) {
- if (std::optional<PseudoProbe> Probe = extractProbe(I))
- IRLocations.emplace(LineLocation(Probe->Id, 0), StringRef());
- }
-
- if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
- continue;
-
- const auto *CB = dyn_cast<CallBase>(&I);
- if (auto &DLoc = I.getDebugLoc()) {
- LineLocation IRCallsite = FunctionSamples::getCallSiteIdentifier(DLoc);
-
- StringRef CalleeName;
- if (Function *Callee = CB->getCalledFunction())
- CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName());
-
- // Force to overwrite the callee name in case any non-call location was
- // written before.
- auto R = IRLocations.emplace(IRCallsite, CalleeName);
- R.first->second = CalleeName;
- assert((!FunctionSamples::ProfileIsProbeBased || R.second ||
- R.first->second == CalleeName) &&
- "Overwrite non-call or different callee name location for "
- "pseudo probe callsite");
+void SampleProfileMatcher::runOnFunction(const Function &F) {
+ // We need to use flattened function samples for matching.
+ // Unlike IR, which includes all callsites from the source code, the callsites
+  // in profile only show up when they are hit by samples, i.e., the profile
+ // callsites in one context may differ from those in another context. To get
+ // the maximum number of callsites, we merge the function profiles from all
+ // contexts, aka, the flattened profile to find profile anchors.
+ const auto *FSFlattened = getFlattenedSamplesFor(F);
+ if (!FSFlattened)
+ return;
- // Go through all the callsites on the IR and flag the callsite if the
- // target name is the same as the one in the profile.
- const auto CTM = FS.findCallTargetMapAt(IRCallsite);
- const auto CallsiteFS = FS.findFunctionSamplesMapAt(IRCallsite);
-
- // Indirect call case.
- if (CalleeName.empty()) {
- // Since indirect call does not have the CalleeName, check
- // conservatively if callsite in the profile is a callsite location.
- // This is to avoid nums of false positive since otherwise all the
- // indirect call samples will be reported as mismatching.
- if ((CTM && !CTM->empty()) || (CallsiteFS && !CallsiteFS->empty()))
- MatchedCallsiteLocs.insert(IRCallsite);
- } else {
- // Check if the call target name is matched for direct call case.
- if ((CTM && CTM->count(CalleeName)) ||
- (CallsiteFS && CallsiteFS->count(CalleeName)))
- MatchedCallsiteLocs.insert(IRCallsite);
- }
- }
- }
- }
+  // Anchors for IR. It's a map from IR location to callee name, callee name is
+  // empty for non-call instruction and use a dummy name (UnknownIndirectCallee)
+  // for unknown indirect callee name.
+ std::map<LineLocation, StringRef> IRAnchors;
+ findIRAnchors(F, IRAnchors);
+ // Anchors for profile. It's a map from callsite location to a set of callee
+  // names.
+ std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
+ findProfileAnchors(*FSFlattened, ProfileAnchors);
// Detect profile mismatch for profile staleness metrics report.
- if (ReportProfileStaleness || PersistProfileStaleness) {
- uint64_t FuncMismatchedCallsites = 0;
- uint64_t FuncProfiledCallsites = 0;
- countProfileMismatches(FS, MatchedCallsiteLocs, FuncMismatchedCallsites,
- FuncProfiledCallsites);
- TotalProfiledCallsites += FuncProfiledCallsites;
- NumMismatchedCallsites += FuncMismatchedCallsites;
- LLVM_DEBUG({
- if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
- FuncMismatchedCallsites)
- dbgs() << "Function checksum is matched but there are "
- << FuncMismatchedCallsites << "/" << FuncProfiledCallsites
- << " mismatched callsites.\n";
- });
+ // Skip reporting the metrics for imported functions.
+ if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) &&
+ (ReportProfileStaleness || PersistProfileStaleness)) {
+ // Use top-level nested FS for counting profile mismatch metrics since
+ // currently once a callsite is mismatched, all its children profiles are
+ // dropped.
+ if (const auto *FS = Reader.getSamplesFor(F))
+ countProfileMismatches(F, *FS, IRAnchors, ProfileAnchors);
}
- if (IsFuncHashMismatch && SalvageStaleProfile) {
- LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
- << "\n");
-
- StringMap<std::set<LineLocation>> CalleeToCallsitesMap;
- populateProfileCallsites(FS, CalleeToCallsitesMap);
-
+  // Run profile matching for checksum mismatched profile, currently only
+  // supported for pseudo-probe.
+ if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased &&
+ !ProbeManager->profileIsValid(F, *FSFlattened)) {
// The matching result will be saved to IRToProfileLocationMap, create a new
// map for each function.
- auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
-
- runStaleProfileMatching(IRLocations, CalleeToCallsitesMap,
- IRToProfileLocationMap);
+ runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
+ getIRToProfileLocationMap(F));
}
}
void SampleProfileMatcher::runOnModule() {
+ ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
+ FunctionSamples::ProfileIsCS);
for (auto &F : M) {
if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
continue;
- FunctionSamples *FS = nullptr;
- if (FlattenProfileForMatching)
- FS = getFlattenedSamplesFor(F);
- else
- FS = Reader.getSamplesFor(F);
- if (!FS)
- continue;
- runOnFunction(F, *FS);
+ runOnFunction(F);
}
if (SalvageStaleProfile)
distributeIRToProfileLocationMap();
@@ -2424,7 +2520,7 @@ void SampleProfileMatcher::runOnModule() {
void SampleProfileMatcher::distributeIRToProfileLocationMap(
FunctionSamples &FS) {
- const auto ProfileMappings = FuncMappings.find(FS.getName());
+ const auto ProfileMappings = FuncMappings.find(FS.getFuncName());
if (ProfileMappings != FuncMappings.end()) {
FS.setIRToProfileLocationMap(&(ProfileMappings->second));
}
@@ -2466,10 +2562,10 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
Function *F = dyn_cast<Function>(N_F.getValue());
if (F == nullptr || OrigName.empty())
continue;
- SymbolMap[OrigName] = F;
+ SymbolMap[FunctionId(OrigName)] = F;
StringRef NewName = FunctionSamples::getCanonicalFnName(*F);
if (OrigName != NewName && !NewName.empty()) {
- auto r = SymbolMap.insert(std::make_pair(NewName, F));
+ auto r = SymbolMap.emplace(FunctionId(NewName), F);
// Failiing to insert means there is already an entry in SymbolMap,
// thus there are multiple functions that are mapped to the same
// stripped name. In this case of name conflicting, set the value
@@ -2482,11 +2578,11 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
if (Remapper) {
if (auto MapName = Remapper->lookUpNameInProfile(OrigName)) {
if (*MapName != OrigName && !MapName->empty())
- SymbolMap.insert(std::make_pair(*MapName, F));
+ SymbolMap.emplace(FunctionId(*MapName), F);
}
}
}
- assert(SymbolMap.count(StringRef()) == 0 &&
+ assert(SymbolMap.count(FunctionId()) == 0 &&
"No empty StringRef should be added in SymbolMap");
if (ReportProfileStaleness || PersistProfileStaleness ||
@@ -2550,7 +2646,9 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM)
// but not cold accumulatively...), so the outline function showing up as
// cold in sampled binary will actually not be cold after current build.
StringRef CanonName = FunctionSamples::getCanonicalFnName(F);
- if (NamesInProfile.count(CanonName))
+ if ((FunctionSamples::UseMD5 &&
+ GUIDsInProfile.count(Function::getGUID(CanonName))) ||
+ (!FunctionSamples::UseMD5 && NamesInProfile.count(CanonName)))
initialEntryCount = -1;
}
@@ -2571,8 +2669,24 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM)
if (FunctionSamples::ProfileIsCS)
Samples = ContextTracker->getBaseSamplesFor(F);
- else
+ else {
Samples = Reader->getSamplesFor(F);
+ // Try search in previously inlined functions that were split or duplicated
+ // into base.
+ if (!Samples) {
+ StringRef CanonName = FunctionSamples::getCanonicalFnName(F);
+ auto It = OutlineFunctionSamples.find(FunctionId(CanonName));
+ if (It != OutlineFunctionSamples.end()) {
+ Samples = &It->second;
+ } else if (auto Remapper = Reader->getRemapper()) {
+ if (auto RemppedName = Remapper->lookUpNameInProfile(CanonName)) {
+ It = OutlineFunctionSamples.find(FunctionId(*RemppedName));
+ if (It != OutlineFunctionSamples.end())
+ Samples = &It->second;
+ }
+ }
+ }
+ }
if (Samples && !Samples->empty())
return emitAnnotations(F);
diff --git a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp
index 0a42de7224b4..8f0b12d0cfed 100644
--- a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp
@@ -18,6 +18,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
@@ -95,13 +96,13 @@ void PseudoProbeVerifier::runAfterPass(StringRef PassID, Any IR) {
std::string Banner =
"\n*** Pseudo Probe Verification After " + PassID.str() + " ***\n";
dbgs() << Banner;
- if (const auto **M = any_cast<const Module *>(&IR))
+ if (const auto **M = llvm::any_cast<const Module *>(&IR))
runAfterPass(*M);
- else if (const auto **F = any_cast<const Function *>(&IR))
+ else if (const auto **F = llvm::any_cast<const Function *>(&IR))
runAfterPass(*F);
- else if (const auto **C = any_cast<const LazyCallGraph::SCC *>(&IR))
+ else if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
runAfterPass(*C);
- else if (const auto **L = any_cast<const Loop *>(&IR))
+ else if (const auto **L = llvm::any_cast<const Loop *>(&IR))
runAfterPass(*L);
else
llvm_unreachable("Unknown IR unit");
@@ -221,12 +222,26 @@ void SampleProfileProber::computeProbeIdForBlocks() {
}
void SampleProfileProber::computeProbeIdForCallsites() {
+ LLVMContext &Ctx = F->getContext();
+ Module *M = F->getParent();
+
for (auto &BB : *F) {
for (auto &I : BB) {
if (!isa<CallBase>(I))
continue;
if (isa<IntrinsicInst>(&I))
continue;
+
+ // The current implementation uses the lower 16 bits of the discriminator
+ // so anything larger than 0xFFFF will be ignored.
+ if (LastProbeId >= 0xFFFF) {
+ std::string Msg = "Pseudo instrumentation incomplete for " +
+ std::string(F->getName()) + " because it's too large";
+ Ctx.diagnose(
+ DiagnosticInfoSampleProfile(M->getName().data(), Msg, DS_Warning));
+ return;
+ }
+
CallProbeIds[&I] = ++LastProbeId;
}
}
diff --git a/llvm/lib/Transforms/IPO/StripSymbols.cpp b/llvm/lib/Transforms/IPO/StripSymbols.cpp
index 147513452789..28d7d4ba6b01 100644
--- a/llvm/lib/Transforms/IPO/StripSymbols.cpp
+++ b/llvm/lib/Transforms/IPO/StripSymbols.cpp
@@ -30,12 +30,18 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/TypeFinder.h"
#include "llvm/IR/ValueSymbolTable.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/StripSymbols.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
+static cl::opt<bool>
+ StripGlobalConstants("strip-global-constants", cl::init(false), cl::Hidden,
+ cl::desc("Removes debug compile units which reference "
+ "to non-existing global constants"));
+
/// OnlyUsedBy - Return true if V is only used by Usr.
static bool OnlyUsedBy(Value *V, Value *Usr) {
for (User *U : V->users())
@@ -73,7 +79,7 @@ static void StripSymtab(ValueSymbolTable &ST, bool PreserveDbgInfo) {
Value *V = VI->getValue();
++VI;
if (!isa<GlobalValue>(V) || cast<GlobalValue>(V)->hasLocalLinkage()) {
- if (!PreserveDbgInfo || !V->getName().startswith("llvm.dbg"))
+ if (!PreserveDbgInfo || !V->getName().starts_with("llvm.dbg"))
// Set name to "", removing from symbol table!
V->setName("");
}
@@ -88,7 +94,7 @@ static void StripTypeNames(Module &M, bool PreserveDbgInfo) {
for (StructType *STy : StructTypes) {
if (STy->isLiteral() || STy->getName().empty()) continue;
- if (PreserveDbgInfo && STy->getName().startswith("llvm.dbg"))
+ if (PreserveDbgInfo && STy->getName().starts_with("llvm.dbg"))
continue;
STy->setName("");
@@ -118,13 +124,13 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) {
for (GlobalVariable &GV : M.globals()) {
if (GV.hasLocalLinkage() && !llvmUsedValues.contains(&GV))
- if (!PreserveDbgInfo || !GV.getName().startswith("llvm.dbg"))
+ if (!PreserveDbgInfo || !GV.getName().starts_with("llvm.dbg"))
GV.setName(""); // Internal symbols can't participate in linkage
}
for (Function &I : M) {
if (I.hasLocalLinkage() && !llvmUsedValues.contains(&I))
- if (!PreserveDbgInfo || !I.getName().startswith("llvm.dbg"))
+ if (!PreserveDbgInfo || !I.getName().starts_with("llvm.dbg"))
I.setName(""); // Internal symbols can't participate in linkage
if (auto *Symtab = I.getValueSymbolTable())
StripSymtab(*Symtab, PreserveDbgInfo);
@@ -216,7 +222,8 @@ static bool stripDeadDebugInfoImpl(Module &M) {
// Create our live global variable list.
bool GlobalVariableChange = false;
for (auto *DIG : DIC->getGlobalVariables()) {
- if (DIG->getExpression() && DIG->getExpression()->isConstant())
+ if (DIG->getExpression() && DIG->getExpression()->isConstant() &&
+ !StripGlobalConstants)
LiveGVs.insert(DIG);
// Make sure we only visit each global variable only once.
diff --git a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp
index d46f9a6c6757..f6f895676084 100644
--- a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp
+++ b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp
@@ -111,7 +111,7 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M,
// Now compute the callsite count from relative frequency and
// entry count:
BasicBlock *CSBB = CB.getParent();
- Scaled64 EntryFreq(BFI.getEntryFreq(), 0);
+ Scaled64 EntryFreq(BFI.getEntryFreq().getFrequency(), 0);
Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0);
BBCount /= EntryFreq;
BBCount *= Counts[Caller];
diff --git a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
index fc1e70b1b3d3..e5f9fa1dda88 100644
--- a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
+++ b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
@@ -186,7 +186,7 @@ void simplifyExternals(Module &M) {
if (!F.isDeclaration() || F.getFunctionType() == EmptyFT ||
// Changing the type of an intrinsic may invalidate the IR.
- F.getName().startswith("llvm."))
+ F.getName().starts_with("llvm."))
continue;
Function *NewF =
@@ -198,7 +198,7 @@ void simplifyExternals(Module &M) {
AttributeList::FunctionIndex,
F.getAttributes().getFnAttrs()));
NewF->takeName(&F);
- F.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, F.getType()));
+ F.replaceAllUsesWith(NewF);
F.eraseFromParent();
}
@@ -329,7 +329,7 @@ void splitAndWriteThinLTOBitcode(
// comdat in MergedM to keep the comdat together.
DenseSet<const Comdat *> MergedMComdats;
for (GlobalVariable &GV : M.globals())
- if (HasTypeMetadata(&GV)) {
+ if (!GV.isDeclaration() && HasTypeMetadata(&GV)) {
if (const auto *C = GV.getComdat())
MergedMComdats.insert(C);
forEachVirtualFunction(GV.getInitializer(), [&](Function *F) {
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index d33258642365..85afc020dbf8 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -58,7 +58,6 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
@@ -369,8 +368,6 @@ template <> struct DenseMapInfo<VTableSlotSummary> {
} // end namespace llvm
-namespace {
-
// Returns true if the function must be unreachable based on ValueInfo.
//
// In particular, identifies a function as unreachable in the following
@@ -378,7 +375,7 @@ namespace {
// 1) All summaries are live.
// 2) All function summaries indicate it's unreachable
// 3) There is no non-function with the same GUID (which is rare)
-bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
+static bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
// Returns false if ValueInfo is absent, or the summary list is empty
// (e.g., function declarations).
@@ -403,6 +400,7 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
return true;
}
+namespace {
// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
@@ -590,7 +588,7 @@ struct DevirtModule {
: M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
ExportSummary(ExportSummary), ImportSummary(ImportSummary),
Int8Ty(Type::getInt8Ty(M.getContext())),
- Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
+ Int8PtrTy(PointerType::getUnqual(M.getContext())),
Int32Ty(Type::getInt32Ty(M.getContext())),
Int64Ty(Type::getInt64Ty(M.getContext())),
IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
@@ -776,20 +774,59 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
return PreservedAnalyses::none();
}
-namespace llvm {
// Enable whole program visibility if enabled by client (e.g. linker) or
// internal option, and not force disabled.
-bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
+bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
!DisableWholeProgramVisibility;
}
+static bool
+typeIDVisibleToRegularObj(StringRef TypeID,
+ function_ref<bool(StringRef)> IsVisibleToRegularObj) {
+ // TypeID for member function pointer type is an internal construct
+ // and won't exist in IsVisibleToRegularObj. The full TypeID
+ // will be present and participate in invalidation.
+ if (TypeID.ends_with(".virtual"))
+ return false;
+
+ // TypeID that doesn't start with Itanium mangling (_ZTS) will be
+ // non-externally visible types which cannot interact with
+ // external native files. See CodeGenModule::CreateMetadataIdentifierImpl.
+ if (!TypeID.consume_front("_ZTS"))
+ return false;
+
+ // TypeID is keyed off the type name symbol (_ZTS). However, the native
+ // object may not contain this symbol if it does not contain a key
+ // function for the base type and thus only contains a reference to the
+ // type info (_ZTI). To catch this case we query using the type info
+ // symbol corresponding to the TypeID.
+ std::string typeInfo = ("_ZTI" + TypeID).str();
+ return IsVisibleToRegularObj(typeInfo);
+}
+
+static bool
+skipUpdateDueToValidation(GlobalVariable &GV,
+ function_ref<bool(StringRef)> IsVisibleToRegularObj) {
+ SmallVector<MDNode *, 2> Types;
+ GV.getMetadata(LLVMContext::MD_type, Types);
+
+ for (auto Type : Types)
+ if (auto *TypeID = dyn_cast<MDString>(Type->getOperand(1).get()))
+ return typeIDVisibleToRegularObj(TypeID->getString(),
+ IsVisibleToRegularObj);
+
+ return false;
+}
+
/// If whole program visibility asserted, then upgrade all public vcall
/// visibility metadata on vtable definitions to linkage unit visibility in
/// Module IR (for regular or hybrid LTO).
-void updateVCallVisibilityInModule(
+void llvm::updateVCallVisibilityInModule(
Module &M, bool WholeProgramVisibilityEnabledInLTO,
- const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
+ const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
+ bool ValidateAllVtablesHaveTypeInfos,
+ function_ref<bool(StringRef)> IsVisibleToRegularObj) {
if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
return;
for (GlobalVariable &GV : M.globals()) {
@@ -800,13 +837,19 @@ void updateVCallVisibilityInModule(
GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
// Don't upgrade the visibility for symbols exported to the dynamic
// linker, as we have no information on their eventual use.
- !DynamicExportSymbols.count(GV.getGUID()))
+ !DynamicExportSymbols.count(GV.getGUID()) &&
+ // With validation enabled, we want to exclude symbols visible to
+ // regular objects. Local symbols will be in this group due to the
+ // current implementation but those with VCallVisibilityTranslationUnit
+ // will have already been marked in clang so are unaffected.
+ !(ValidateAllVtablesHaveTypeInfos &&
+ skipUpdateDueToValidation(GV, IsVisibleToRegularObj)))
GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
}
}
-void updatePublicTypeTestCalls(Module &M,
- bool WholeProgramVisibilityEnabledInLTO) {
+void llvm::updatePublicTypeTestCalls(Module &M,
+ bool WholeProgramVisibilityEnabledInLTO) {
Function *PublicTypeTestFunc =
M.getFunction(Intrinsic::getName(Intrinsic::public_type_test));
if (!PublicTypeTestFunc)
@@ -832,12 +875,26 @@ void updatePublicTypeTestCalls(Module &M,
}
}
+/// Based on typeID string, get all associated vtable GUIDS that are
+/// visible to regular objects.
+void llvm::getVisibleToRegularObjVtableGUIDs(
+ ModuleSummaryIndex &Index,
+ DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols,
+ function_ref<bool(StringRef)> IsVisibleToRegularObj) {
+ for (const auto &typeID : Index.typeIdCompatibleVtableMap()) {
+ if (typeIDVisibleToRegularObj(typeID.first, IsVisibleToRegularObj))
+ for (const TypeIdOffsetVtableInfo &P : typeID.second)
+ VisibleToRegularObjSymbols.insert(P.VTableVI.getGUID());
+ }
+}
+
/// If whole program visibility asserted, then upgrade all public vcall
/// visibility metadata on vtable definition summaries to linkage unit
/// visibility in Module summary index (for ThinLTO).
-void updateVCallVisibilityInIndex(
+void llvm::updateVCallVisibilityInIndex(
ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
- const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
+ const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
+ const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) {
if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
return;
for (auto &P : Index) {
@@ -850,18 +907,24 @@ void updateVCallVisibilityInIndex(
if (!GVar ||
GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
continue;
+ // With validation enabled, we want to exclude symbols visible to regular
+ // objects. Local symbols will be in this group due to the current
+ // implementation but those with VCallVisibilityTranslationUnit will have
+ // already been marked in clang so are unaffected.
+ if (VisibleToRegularObjSymbols.count(P.first))
+ continue;
GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
}
}
}
-void runWholeProgramDevirtOnIndex(
+void llvm::runWholeProgramDevirtOnIndex(
ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
}
-void updateIndexWPDForExports(
+void llvm::updateIndexWPDForExports(
ModuleSummaryIndex &Summary,
function_ref<bool(StringRef, ValueInfo)> isExported,
std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
@@ -887,8 +950,6 @@ void updateIndexWPDForExports(
}
}
-} // end namespace llvm
-
static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
// Check that summary index contains regular LTO module when performing
// export to prevent occasional use of index from pure ThinLTO compilation
@@ -942,7 +1003,7 @@ bool DevirtModule::runForTesting(
ExitOnError ExitOnErr(
"-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
std::error_code EC;
- if (StringRef(ClWriteSummary).endswith(".bc")) {
+ if (StringRef(ClWriteSummary).ends_with(".bc")) {
raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
ExitOnErr(errorCodeToError(EC));
writeIndexToFile(*Summary, OS);
@@ -1045,8 +1106,8 @@ bool DevirtModule::tryFindVirtualCallTargets(
}
bool DevirtIndex::tryFindVirtualCallTargets(
- std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo,
- uint64_t ByteOffset) {
+ std::vector<ValueInfo> &TargetsForSlot,
+ const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) {
for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
// Find a representative copy of the vtable initializer.
// We can have multiple available_externally, linkonce_odr and weak_odr
@@ -1203,7 +1264,8 @@ static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
// to better ensure we have the opportunity to inline them.
bool IsExported = false;
auto &S = Callee.getSummaryList()[0];
- CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
+ CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false,
+ /* RelBF = */ 0);
auto AddCalls = [&](CallSiteInfo &CSInfo) {
for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
FS->addCall({Callee, CI});
@@ -1437,7 +1499,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
IRBuilder<> IRB(&CB);
std::vector<Value *> Args;
- Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
+ Args.push_back(VCallSite.VTable);
llvm::append_range(Args, CB.args());
CallBase *NewCS = nullptr;
@@ -1471,10 +1533,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
// llvm.type.test and therefore require an llvm.type.test resolution for the
// type identifier.
- std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) {
- CBs.first->replaceAllUsesWith(CBs.second);
- CBs.first->eraseFromParent();
- });
+ for (auto &[Old, New] : CallBases) {
+ Old->replaceAllUsesWith(New);
+ Old->eraseFromParent();
+ }
};
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
@@ -1648,8 +1710,7 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
}
Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
- Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
- return ConstantExpr::getGetElementPtr(Int8Ty, C,
+ return ConstantExpr::getGetElementPtr(Int8Ty, M->Bits->GV,
ConstantInt::get(Int64Ty, M->Offset));
}
@@ -1708,8 +1769,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
continue;
auto *RetType = cast<IntegerType>(Call.CB.getType());
IRBuilder<> B(&Call.CB);
- Value *Addr =
- B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
+ Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte);
if (RetType->getBitWidth() == 1) {
Value *Bits = B.CreateLoad(Int8Ty, Addr);
Value *BitsAndBit = B.CreateAnd(Bits, Bit);
@@ -2007,17 +2067,14 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
if (TypeCheckedLoadFunc->getIntrinsicID() ==
Intrinsic::type_checked_load_relative) {
Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
- Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty));
- LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr);
+ LoadedValue = LoadB.CreateLoad(Int32Ty, GEP);
LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy);
GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy);
LoadedValue = LoadB.CreateAdd(GEP, LoadedValue);
LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy);
} else {
Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
- Value *GEPPtr =
- LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
- LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
+ LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEP);
}
for (Instruction *LoadedPtr : LoadedPtrs) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 91ca44e0f11e..719a2678fc18 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -830,15 +830,15 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
// (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C)
Constant *NarrowC;
if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) {
- Constant *WideC = ConstantExpr::getSExt(NarrowC, Ty);
- Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
+ Value *WideC = Builder.CreateSExt(NarrowC, Ty);
+ Value *NewC = Builder.CreateAdd(WideC, Op1C);
Value *WideX = Builder.CreateSExt(X, Ty);
return BinaryOperator::CreateAdd(WideX, NewC);
}
// (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C)
if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) {
- Constant *WideC = ConstantExpr::getZExt(NarrowC, Ty);
- Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
+ Value *WideC = Builder.CreateZExt(NarrowC, Ty);
+ Value *NewC = Builder.CreateAdd(WideC, Op1C);
Value *WideX = Builder.CreateZExt(X, Ty);
return BinaryOperator::CreateAdd(WideX, NewC);
}
@@ -903,8 +903,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
// (X | Op01C) + Op1C --> X + (Op01C + Op1C) iff the `or` is actually an `add`
Constant *Op01C;
- if (match(Op0, m_Or(m_Value(X), m_ImmConstant(Op01C))) &&
- haveNoCommonBitsSet(X, Op01C, DL, &AC, &Add, &DT))
+ if (match(Op0, m_DisjointOr(m_Value(X), m_ImmConstant(Op01C))))
return BinaryOperator::CreateAdd(X, ConstantExpr::getAdd(Op01C, Op1C));
// (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C)
@@ -995,6 +994,69 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
return nullptr;
}
+// match variations of a^2 + 2*a*b + b^2
+//
+// to reuse the code between the FP and Int versions, the instruction OpCodes
+// and constant types have been turned into template parameters.
+//
+// Mul2Rhs: The constant to perform the multiplicative equivalent of X*2 with;
+// should be `m_SpecificFP(2.0)` for FP and `m_SpecificInt(1)` for Int
+// (we're matching `X<<1` instead of `X*2` for Int)
+template <bool FP, typename Mul2Rhs>
+static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A,
+ Value *&B) {
+ constexpr unsigned MulOp = FP ? Instruction::FMul : Instruction::Mul;
+ constexpr unsigned AddOp = FP ? Instruction::FAdd : Instruction::Add;
+ constexpr unsigned Mul2Op = FP ? Instruction::FMul : Instruction::Shl;
+
+ // (a * a) + (((a * 2) + b) * b)
+ if (match(&I, m_c_BinOp(
+ AddOp, m_OneUse(m_BinOp(MulOp, m_Value(A), m_Deferred(A))),
+ m_OneUse(m_BinOp(
+ MulOp,
+ m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs),
+ m_Value(B)),
+ m_Deferred(B))))))
+ return true;
+
+ // ((a * b) * 2) or ((a * 2) * b)
+ // +
+ // (a * a + b * b) or (b * b + a * a)
+ return match(
+ &I,
+ m_c_BinOp(AddOp,
+ m_CombineOr(
+ m_OneUse(m_BinOp(
+ Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)),
+ m_OneUse(m_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs),
+ m_Value(B)))),
+ m_OneUse(m_c_BinOp(
+ AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)),
+ m_BinOp(MulOp, m_Deferred(B), m_Deferred(B))))));
+}
+
+// Fold integer variations of a^2 + 2*a*b + b^2 -> (a + b)^2
+Instruction *InstCombinerImpl::foldSquareSumInt(BinaryOperator &I) {
+ Value *A, *B;
+ if (matchesSquareSum</*FP*/ false>(I, m_SpecificInt(1), A, B)) {
+ Value *AB = Builder.CreateAdd(A, B);
+ return BinaryOperator::CreateMul(AB, AB);
+ }
+ return nullptr;
+}
+
+// Fold floating point variations of a^2 + 2*a*b + b^2 -> (a + b)^2
+// Requires `nsz` and `reassoc`.
+Instruction *InstCombinerImpl::foldSquareSumFP(BinaryOperator &I) {
+ assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch");
+ Value *A, *B;
+ if (matchesSquareSum</*FP*/ true>(I, m_SpecificFP(2.0), A, B)) {
+ Value *AB = Builder.CreateFAddFMF(A, B, &I);
+ return BinaryOperator::CreateFMulFMF(AB, AB, &I);
+ }
+ return nullptr;
+}
+
// Matches multiplication expression Op * C where C is a constant. Returns the
// constant value in C and the other operand in Op. Returns true if such a
// match is found.
@@ -1146,6 +1208,21 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
return nullptr;
}
+// Transform:
+// (add A, (shl (neg B), Y))
+// -> (sub A, (shl B, Y))
+//
+// Since (-B) << Y == -(B << Y) in two's complement, the add of the shifted
+// negation is the same value as subtracting the plain shift, and folding
+// this way removes the explicit negation.
+static Instruction *combineAddSubWithShlAddSub(InstCombiner::BuilderTy &Builder,
+ const BinaryOperator &I) {
+ Value *A, *B, *Cnt;
+ // One-use on both the shl and the neg so neither instruction is
+ // duplicated; m_c_Add accepts the shift on either side of the add.
+ if (match(&I,
+ m_c_Add(m_OneUse(m_Shl(m_OneUse(m_Neg(m_Value(B))), m_Value(Cnt))),
+ m_Value(A)))) {
+ // The rebuilt shift carries no wrap flags.
+ Value *NewShl = Builder.CreateShl(B, Cnt);
+ return BinaryOperator::CreateSub(A, NewShl);
+ }
+ return nullptr;
+}
+
/// Try to reduce signed division by power-of-2 to an arithmetic shift right.
static Instruction *foldAddToAshr(BinaryOperator &Add) {
// Division must be by power-of-2, but not the minimum signed value.
@@ -1156,18 +1233,28 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
return nullptr;
// Rounding is done by adding -1 if the dividend (X) is negative and has any
- // low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
- // sext (icmp ugt (X & (DivC - 1)), SMIN)
- const APInt *MaskC;
+ // low bits set. It recognizes two canonical patterns:
+ // 1. For an 'ugt' cmp with the signed minimum value (SMIN), the
+ // pattern is: sext (icmp ugt (X & (DivC - 1)), SMIN).
+ // 2. For an 'eq' cmp, the pattern's: sext (icmp eq X & (SMIN + 1), SMIN + 1).
+ // Note that, by the time we end up here, if possible, ugt has been
+ // canonicalized into eq.
+ const APInt *MaskC, *MaskCCmp;
ICmpInst::Predicate Pred;
if (!match(Add.getOperand(1),
m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
- m_SignMask()))) ||
- Pred != ICmpInst::ICMP_UGT)
+ m_APInt(MaskCCmp)))))
+ return nullptr;
+
+ if ((Pred != ICmpInst::ICMP_UGT || !MaskCCmp->isSignMask()) &&
+ (Pred != ICmpInst::ICMP_EQ || *MaskCCmp != *MaskC))
return nullptr;
APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
- if (*MaskC != (SMin | (*DivC - 1)))
+ bool IsMaskValid = Pred == ICmpInst::ICMP_UGT
+ ? (*MaskC == (SMin | (*DivC - 1)))
+ : (*DivC == 2 && *MaskC == SMin + 1);
+ if (!IsMaskValid)
return nullptr;
// (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC)
@@ -1327,8 +1414,10 @@ static Instruction *foldBoxMultiply(BinaryOperator &I) {
// ResLo = (CrossSum << HalfBits) + (YLo * XLo)
Value *XLo, *YLo;
Value *CrossSum;
+ // Require one-use on the multiply to avoid increasing the number of
+ // multiplications.
if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)),
- m_Mul(m_Value(YLo), m_Value(XLo)))))
+ m_OneUse(m_Mul(m_Value(YLo), m_Value(XLo))))))
return nullptr;
// XLo = X & HalfMask
@@ -1386,6 +1475,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *R = foldBinOpShiftWithShift(I))
return R;
+ if (Instruction *R = combineAddSubWithShlAddSub(Builder, I))
+ return R;
+
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1))
@@ -1406,7 +1498,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B));
// -A + B --> B - A
- return BinaryOperator::CreateSub(RHS, A);
+ auto *Sub = BinaryOperator::CreateSub(RHS, A);
+ auto *OB0 = cast<OverflowingBinaryOperator>(LHS);
+ Sub->setHasNoSignedWrap(I.hasNoSignedWrap() && OB0->hasNoSignedWrap());
+
+ return Sub;
}
// A + -B --> A - B
@@ -1485,8 +1581,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
// A+B --> A|B iff A and B have no bits set in common.
- if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
- return BinaryOperator::CreateOr(LHS, RHS);
+ WithCache<const Value *> LHSCache(LHS), RHSCache(RHS);
+ if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I)))
+ return BinaryOperator::CreateDisjointOr(LHS, RHS);
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
@@ -1576,15 +1673,33 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
m_c_UMin(m_Deferred(A), m_Deferred(B))))))
return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
+ // (~X) + (~Y) --> -2 - (X + Y)
+ {
+ // To ensure we can save instructions we need to ensure that we consume both
+ // LHS/RHS (i.e they have a `not`).
+ bool ConsumesLHS, ConsumesRHS;
+ if (isFreeToInvert(LHS, LHS->hasOneUse(), ConsumesLHS) && ConsumesLHS &&
+ isFreeToInvert(RHS, RHS->hasOneUse(), ConsumesRHS) && ConsumesRHS) {
+ Value *NotLHS = getFreelyInverted(LHS, LHS->hasOneUse(), &Builder);
+ Value *NotRHS = getFreelyInverted(RHS, RHS->hasOneUse(), &Builder);
+ assert(NotLHS != nullptr && NotRHS != nullptr &&
+ "isFreeToInvert desynced with getFreelyInverted");
+ Value *LHSPlusRHS = Builder.CreateAdd(NotLHS, NotRHS);
+ return BinaryOperator::CreateSub(ConstantInt::get(RHS->getType(), -2),
+ LHSPlusRHS);
+ }
+ }
+
// TODO(jingyue): Consider willNotOverflowSignedAdd and
// willNotOverflowUnsignedAdd to reduce the number of invocations of
// computeKnownBits.
bool Changed = false;
- if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
+ if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) {
Changed = true;
I.setHasNoSignedWrap(true);
}
- if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
+ if (!I.hasNoUnsignedWrap() &&
+ willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) {
Changed = true;
I.setHasNoUnsignedWrap(true);
}
@@ -1610,11 +1725,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// ctpop(A) + ctpop(B) => ctpop(A | B) if A and B have no bits set in common.
if (match(LHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(A)))) &&
match(RHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(B)))) &&
- haveNoCommonBitsSet(A, B, DL, &AC, &I, &DT))
+ haveNoCommonBitsSet(A, B, SQ.getWithInstruction(&I)))
return replaceInstUsesWith(
I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
{Builder.CreateOr(A, B)}));
+ if (Instruction *Res = foldSquareSumInt(I))
+ return Res;
+
if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
return Res;
@@ -1755,10 +1873,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
// instcombined.
if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
- Constant *CI =
- ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
+ Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP,
+ LHSIntVal->getType(), DL);
if (LHSConv->hasOneUse() &&
- ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
+ ConstantFoldCastOperand(Instruction::SIToFP, CI, I.getType(), DL) ==
+ CFP &&
willNotOverflowSignedAdd(LHSIntVal, CI, I)) {
// Insert the new integer add.
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv");
@@ -1794,6 +1913,9 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
if (Instruction *F = factorizeFAddFSub(I, Builder))
return F;
+ if (Instruction *F = foldSquareSumFP(I))
+ return F;
+
// Try to fold fadd into start value of reduction intrinsic.
if (match(&I, m_c_FAdd(m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>(
m_AnyZeroFP(), m_Value(X))),
@@ -2017,14 +2139,16 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// C-(X+C2) --> (C-C2)-X
if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) {
- // C-C2 never overflow, and C-(X+C2), (X+C2) has NSW
- // => (C-C2)-X can have NSW
+ // C-C2 never overflow, and C-(X+C2), (X+C2) has NSW/NUW
+ // => (C-C2)-X can have NSW/NUW
bool WillNotSOV = willNotOverflowSignedSub(C, C2, I);
BinaryOperator *Res =
BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() &&
WillNotSOV);
+ Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap() &&
+ OBO1->hasNoUnsignedWrap());
return Res;
}
}
@@ -2058,7 +2182,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) ||
match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1)));
})) {
- if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this))
+ if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
+ I.hasNoSignedWrap(),
+ Op1, *this))
return BinaryOperator::CreateAdd(NegOp1, Op0);
}
if (IsNegation)
@@ -2093,19 +2219,50 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// ((X - Y) - Op1) --> X - (Y + Op1)
if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) {
- Value *Add = Builder.CreateAdd(Y, Op1);
- return BinaryOperator::CreateSub(X, Add);
+ OverflowingBinaryOperator *LHSSub = cast<OverflowingBinaryOperator>(Op0);
+ bool HasNUW = I.hasNoUnsignedWrap() && LHSSub->hasNoUnsignedWrap();
+ bool HasNSW = HasNUW && I.hasNoSignedWrap() && LHSSub->hasNoSignedWrap();
+ Value *Add = Builder.CreateAdd(Y, Op1, "", /* HasNUW */ HasNUW,
+ /* HasNSW */ HasNSW);
+ BinaryOperator *Sub = BinaryOperator::CreateSub(X, Add);
+ Sub->setHasNoUnsignedWrap(HasNUW);
+ Sub->setHasNoSignedWrap(HasNSW);
+ return Sub;
+ }
+
+ {
+ // (X + Z) - (Y + Z) --> (X - Y)
+ // This is done in other passes, but we want to be able to consume this
+ // pattern in InstCombine so we can generate it without creating infinite
+ // loops.
+ if (match(Op0, m_Add(m_Value(X), m_Value(Z))) &&
+ match(Op1, m_c_Add(m_Value(Y), m_Specific(Z))))
+ return BinaryOperator::CreateSub(X, Y);
+
+ // (X + C0) - (Y + C1) --> (X - Y) + (C0 - C1)
+ Constant *CX, *CY;
+ if (match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(CX)))) &&
+ match(Op1, m_OneUse(m_Add(m_Value(Y), m_ImmConstant(CY))))) {
+ Value *OpsSub = Builder.CreateSub(X, Y);
+ Constant *ConstsSub = ConstantExpr::getSub(CX, CY);
+ return BinaryOperator::CreateAdd(OpsSub, ConstsSub);
+ }
}
// (~X) - (~Y) --> Y - X
- // This is placed after the other reassociations and explicitly excludes a
- // sub-of-sub pattern to avoid infinite looping.
- if (isFreeToInvert(Op0, Op0->hasOneUse()) &&
- isFreeToInvert(Op1, Op1->hasOneUse()) &&
- !match(Op0, m_Sub(m_ImmConstant(), m_Value()))) {
- Value *NotOp0 = Builder.CreateNot(Op0);
- Value *NotOp1 = Builder.CreateNot(Op1);
- return BinaryOperator::CreateSub(NotOp1, NotOp0);
+ {
+ // Need to ensure we can consume at least one of the `not` instructions,
+ // otherwise this can inf loop.
+ bool ConsumesOp0, ConsumesOp1;
+ if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) &&
+ isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) &&
+ (ConsumesOp0 || ConsumesOp1)) {
+ Value *NotOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder);
+ Value *NotOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder);
+ assert(NotOp0 != nullptr && NotOp1 != nullptr &&
+ "isFreeToInvert desynced with getFreelyInverted");
+ return BinaryOperator::CreateSub(NotOp1, NotOp0);
+ }
}
auto m_AddRdx = [](Value *&Vec) {
@@ -2520,18 +2677,33 @@ static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) {
return nullptr;
}
-static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
- InstCombiner::BuilderTy &Builder) {
- Value *FNeg;
- if (!match(&I, m_FNeg(m_Value(FNeg))))
- return nullptr;
-
+Instruction *InstCombinerImpl::hoistFNegAboveFMulFDiv(Value *FNegOp,
+ Instruction &FMFSource) {
Value *X, *Y;
- if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))
- return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+ if (match(FNegOp, m_FMul(m_Value(X), m_Value(Y)))) {
+ return cast<Instruction>(Builder.CreateFMulFMF(
+ Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource));
+ }
+
+ if (match(FNegOp, m_FDiv(m_Value(X), m_Value(Y)))) {
+ return cast<Instruction>(Builder.CreateFDivFMF(
+ Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource));
+ }
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(FNegOp)) {
+ // Make sure to preserve flags and metadata on the call.
+ if (II->getIntrinsicID() == Intrinsic::ldexp) {
+ FastMathFlags FMF = FMFSource.getFastMathFlags() | II->getFastMathFlags();
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(FMF);
- if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))))
- return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+ CallInst *New = Builder.CreateCall(
+ II->getCalledFunction(),
+ {Builder.CreateFNeg(II->getArgOperand(0)), II->getArgOperand(1)});
+ New->copyMetadata(*II);
+ return New;
+ }
+ }
return nullptr;
}
@@ -2553,13 +2725,13 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y)))))
return BinaryOperator::CreateFSubFMF(Y, X, &I);
- if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
- return R;
-
Value *OneUse;
if (!match(Op, m_OneUse(m_Value(OneUse))))
return nullptr;
+ if (Instruction *R = hoistFNegAboveFMulFDiv(OneUse, I))
+ return replaceInstUsesWith(I, R);
+
// Try to eliminate fneg if at least 1 arm of the select is negated.
Value *Cond;
if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) {
@@ -2569,8 +2741,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) {
S->copyFastMathFlags(&I);
if (auto *OldSel = dyn_cast<SelectInst>(Op)) {
- FastMathFlags FMF = I.getFastMathFlags();
- FMF |= OldSel->getFastMathFlags();
+ FastMathFlags FMF = I.getFastMathFlags() | OldSel->getFastMathFlags();
S->setFastMathFlags(FMF);
if (!OldSel->hasNoSignedZeros() && !CommonOperand &&
!isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition()))
@@ -2638,9 +2809,6 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;
- if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
- return R;
-
Value *X, *Y;
Constant *C;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8a1fb6b7f17e..6002f599ca71 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1099,39 +1099,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
return Builder.CreateICmpUGE(Builder.CreateNeg(B), A);
}
- Value *Base, *Offset;
- if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset))))
- return nullptr;
-
- if (!match(UnsignedICmp,
- m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) ||
- !ICmpInst::isUnsigned(UnsignedPred))
- return nullptr;
-
- // Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset
- // (no overflow and not null)
- if ((UnsignedPred == ICmpInst::ICMP_UGE ||
- UnsignedPred == ICmpInst::ICMP_UGT) &&
- EqPred == ICmpInst::ICMP_NE && IsAnd)
- return Builder.CreateICmpUGT(Base, Offset);
-
- // Base <=/< Offset || (Base - Offset) == 0 <--> Base <= Offset
- // (overflow or null)
- if ((UnsignedPred == ICmpInst::ICMP_ULE ||
- UnsignedPred == ICmpInst::ICMP_ULT) &&
- EqPred == ICmpInst::ICMP_EQ && !IsAnd)
- return Builder.CreateICmpULE(Base, Offset);
-
- // Base <= Offset && (Base - Offset) != 0 --> Base < Offset
- if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
- IsAnd)
- return Builder.CreateICmpULT(Base, Offset);
-
- // Base > Offset || (Base - Offset) == 0 --> Base >= Offset
- if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
- !IsAnd)
- return Builder.CreateICmpUGE(Base, Offset);
-
return nullptr;
}
@@ -1179,13 +1146,40 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
return nullptr;
CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
- if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
- return nullptr;
+ auto GetMatchPart = [&](ICmpInst *Cmp,
+ unsigned OpNo) -> std::optional<IntPart> {
+ if (Pred == Cmp->getPredicate())
+ return matchIntPart(Cmp->getOperand(OpNo));
+
+ const APInt *C;
+ // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
+ // (icmp ult (xor x, y), 1 << C) so also look for that.
+ if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT) {
+ if (!match(Cmp->getOperand(1), m_Power2(C)) ||
+ !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+ return std::nullopt;
+ }
- std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
- std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
- std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
- std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+ // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
+ // (icmp ugt (xor x, y), (1 << C) - 1) so also look for that.
+ else if (Pred == CmpInst::ICMP_NE &&
+ Cmp->getPredicate() == CmpInst::ICMP_UGT) {
+ if (!match(Cmp->getOperand(1), m_LowBitMask(C)) ||
+ !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+ return std::nullopt;
+ } else {
+ return std::nullopt;
+ }
+
+ unsigned From = Pred == CmpInst::ICMP_NE ? C->popcount() : C->countr_zero();
+ Instruction *I = cast<Instruction>(Cmp->getOperand(0));
+ return {{I->getOperand(OpNo), From, C->getBitWidth() - From}};
+ };
+
+ std::optional<IntPart> L0 = GetMatchPart(Cmp0, 0);
+ std::optional<IntPart> R0 = GetMatchPart(Cmp0, 1);
+ std::optional<IntPart> L1 = GetMatchPart(Cmp1, 0);
+ std::optional<IntPart> R1 = GetMatchPart(Cmp1, 1);
if (!L0 || !R0 || !L1 || !R1)
return nullptr;
@@ -1616,7 +1610,7 @@ static Instruction *reassociateFCmps(BinaryOperator &BO,
/// (~A & ~B) == (~(A | B))
/// (~A | ~B) == (~(A & B))
static Instruction *matchDeMorgansLaws(BinaryOperator &I,
- InstCombiner::BuilderTy &Builder) {
+ InstCombiner &IC) {
const Instruction::BinaryOps Opcode = I.getOpcode();
assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
"Trying to match De Morgan's Laws with something other than and/or");
@@ -1629,10 +1623,10 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
Value *A, *B;
if (match(Op0, m_OneUse(m_Not(m_Value(A)))) &&
match(Op1, m_OneUse(m_Not(m_Value(B)))) &&
- !InstCombiner::isFreeToInvert(A, A->hasOneUse()) &&
- !InstCombiner::isFreeToInvert(B, B->hasOneUse())) {
+ !IC.isFreeToInvert(A, A->hasOneUse()) &&
+ !IC.isFreeToInvert(B, B->hasOneUse())) {
Value *AndOr =
- Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan");
+ IC.Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan");
return BinaryOperator::CreateNot(AndOr);
}
@@ -1644,8 +1638,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
Value *C;
if (match(Op0, m_OneUse(m_c_BinOp(Opcode, m_Value(A), m_Not(m_Value(B))))) &&
match(Op1, m_Not(m_Value(C)))) {
- Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C);
- return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO));
+ Value *FlippedBO = IC.Builder.CreateBinOp(FlippedOpcode, B, C);
+ return BinaryOperator::Create(Opcode, A, IC.Builder.CreateNot(FlippedBO));
}
return nullptr;
@@ -1669,7 +1663,7 @@ bool InstCombinerImpl::shouldOptimizeCast(CastInst *CI) {
/// Fold {and,or,xor} (cast X), C.
static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
- InstCombiner::BuilderTy &Builder) {
+ InstCombinerImpl &IC) {
Constant *C = dyn_cast<Constant>(Logic.getOperand(1));
if (!C)
return nullptr;
@@ -1684,21 +1678,17 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
// instruction may be cheaper (particularly in the case of vectors).
Value *X;
if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) {
- Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy);
- Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy);
- if (ZextTruncC == C) {
+ if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) {
// LogicOpc (zext X), C --> zext (LogicOpc X, C)
- Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC);
+ Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new ZExtInst(NewOp, DestTy);
}
}
if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) {
- Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy);
- Constant *SextTruncC = ConstantExpr::getSExt(TruncC, DestTy);
- if (SextTruncC == C) {
+ if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) {
// LogicOpc (sext X), C --> sext (LogicOpc X, C)
- Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC);
+ Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new SExtInst(NewOp, DestTy);
}
}
@@ -1756,7 +1746,7 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
if (!SrcTy->isIntOrIntVectorTy())
return nullptr;
- if (Instruction *Ret = foldLogicCastConstant(I, Cast0, Builder))
+ if (Instruction *Ret = foldLogicCastConstant(I, Cast0, *this))
return Ret;
CastInst *Cast1 = dyn_cast<CastInst>(Op1);
@@ -1802,29 +1792,6 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
return CastInst::Create(CastOpcode, NewOp, DestTy);
}
- // For now, only 'and'/'or' have optimizations after this.
- if (LogicOpc == Instruction::Xor)
- return nullptr;
-
- // If this is logic(cast(icmp), cast(icmp)), try to fold this even if the
- // cast is otherwise not optimizable. This happens for vector sexts.
- ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src);
- ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src);
- if (ICmp0 && ICmp1) {
- if (Value *Res =
- foldAndOrOfICmps(ICmp0, ICmp1, I, LogicOpc == Instruction::And))
- return CastInst::Create(CastOpcode, Res, DestTy);
- return nullptr;
- }
-
- // If this is logic(cast(fcmp), cast(fcmp)), try to fold this even if the
- // cast is otherwise not optimizable. This happens for vector sexts.
- FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src);
- FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src);
- if (FCmp0 && FCmp1)
- if (Value *R = foldLogicOfFCmps(FCmp0, FCmp1, LogicOpc == Instruction::And))
- return CastInst::Create(CastOpcode, R, DestTy);
-
return nullptr;
}
@@ -2160,10 +2127,10 @@ Instruction *InstCombinerImpl::foldBinOpOfDisplacedShifts(BinaryOperator &I) {
Constant *ShiftedC1, *ShiftedC2, *AddC;
Type *Ty = I.getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
- if (!match(&I,
- m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)),
- m_Shift(m_ImmConstant(ShiftedC2),
- m_Add(m_Deferred(ShAmt), m_ImmConstant(AddC))))))
+ if (!match(&I, m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)),
+ m_Shift(m_ImmConstant(ShiftedC2),
+ m_AddLike(m_Deferred(ShAmt),
+ m_ImmConstant(AddC))))))
return nullptr;
// Make sure the add constant is a valid shift amount.
@@ -2254,6 +2221,14 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y);
}
+ // Canonicalize:
+ // (X +/- Y) & Y --> ~X & Y when Y is a power of 2.
+ if (match(&I, m_c_And(m_Value(Y), m_OneUse(m_CombineOr(
+ m_c_Add(m_Value(X), m_Deferred(Y)),
+ m_Sub(m_Value(X), m_Deferred(Y)))))) &&
+ isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, /*Depth*/ 0, &I))
+ return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y);
+
const APInt *C;
if (match(Op1, m_APInt(C))) {
const APInt *XorC;
@@ -2300,13 +2275,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
const APInt *AddC;
if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) {
- // If we add zeros to every bit below a mask, the add has no effect:
- // (X + AddC) & LowMaskC --> X & LowMaskC
- unsigned Ctlz = C->countl_zero();
- APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz));
- if ((*AddC & LowMask).isZero())
- return BinaryOperator::CreateAnd(X, Op1);
-
// If we are masking the result of the add down to exactly one bit and
// the constant we are adding has no bits set below that bit, then the
// add is flipping a single bit. Example:
@@ -2455,6 +2423,28 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
}
}
+ // If we are clearing the sign bit of a floating-point value, convert this to
+ // fabs, then cast back to integer.
+ //
+ // This is a generous interpretation for noimplicitfloat, this is not a true
+ // floating-point operation.
+ //
+ // Assumes any IEEE-represented type has the sign bit in the high bit.
+ // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
+ Value *CastOp;
+ if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+ match(Op1, m_MaxSignedValue()) &&
+ !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
+ Attribute::NoImplicitFloat)) {
+ Type *EltTy = CastOp->getType()->getScalarType();
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
+ EltTy->getPrimitiveSizeInBits() ==
+ I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
+ return new BitCastInst(FAbs, I.getType());
+ }
+ }
+
if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))),
m_SignMask())) &&
match(Y, m_SpecificInt_ICMP(
@@ -2479,21 +2469,21 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (I.getType()->isIntOrIntVectorTy(1)) {
if (auto *SI0 = dyn_cast<SelectInst>(Op0)) {
- if (auto *I =
+ if (auto *R =
foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ true))
- return I;
+ return R;
}
if (auto *SI1 = dyn_cast<SelectInst>(Op1)) {
- if (auto *I =
+ if (auto *R =
foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ true))
- return I;
+ return R;
}
}
if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I))
return FoldedLogic;
- if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder))
+ if (Instruction *DeMorgan = matchDeMorgansLaws(I, *this))
return DeMorgan;
{
@@ -2513,16 +2503,24 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
return BinaryOperator::CreateAnd(Op1, B);
// (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C
- if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
- if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
- if (Op1->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
- return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C));
+ if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+ match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) {
+ Value *NotC = Op1->hasOneUse()
+ ? Builder.CreateNot(C)
+ : getFreelyInverted(C, C->hasOneUse(), &Builder);
+ if (NotC != nullptr)
+ return BinaryOperator::CreateAnd(Op0, NotC);
+ }
// ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C
- if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))))
- if (match(Op1, m_Xor(m_Specific(B), m_Specific(A))))
- if (Op0->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
- return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C));
+ if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))) &&
+ match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) {
+ Value *NotC = Op0->hasOneUse()
+ ? Builder.CreateNot(C)
+ : getFreelyInverted(C, C->hasOneUse(), &Builder);
+ if (NotC != nullptr)
+ return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C));
+ }
// (A | B) & (~A ^ B) -> A & B
// (A | B) & (B ^ ~A) -> A & B
@@ -2621,23 +2619,34 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
// with binop identity constant. But creating a select with non-constant
// arm may not be reversible due to poison semantics. Is that a good
// canonicalization?
- Value *A;
- if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) &&
- A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, Op1, Constant::getNullValue(Ty));
- if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) &&
+ Value *A, *B;
+ if (match(&I, m_c_And(m_OneUse(m_SExt(m_Value(A))), m_Value(B))) &&
A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, Op0, Constant::getNullValue(Ty));
+ return SelectInst::Create(A, B, Constant::getNullValue(Ty));
// Similarly, a 'not' of the bool translates to a swap of the select arms:
- // ~sext(A) & Op1 --> A ? 0 : Op1
- // Op0 & ~sext(A) --> A ? 0 : Op0
- if (match(Op0, m_Not(m_SExt(m_Value(A)))) &&
+ // ~sext(A) & B / B & ~sext(A) --> A ? 0 : B
+ if (match(&I, m_c_And(m_Not(m_SExt(m_Value(A))), m_Value(B))) &&
A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, Constant::getNullValue(Ty), Op1);
- if (match(Op1, m_Not(m_SExt(m_Value(A)))) &&
+ return SelectInst::Create(A, Constant::getNullValue(Ty), B);
+
+ // and(zext(A), B) -> A ? (B & 1) : 0
+ if (match(&I, m_c_And(m_OneUse(m_ZExt(m_Value(A))), m_Value(B))) &&
A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, Constant::getNullValue(Ty), Op0);
+ return SelectInst::Create(A, Builder.CreateAnd(B, ConstantInt::get(Ty, 1)),
+ Constant::getNullValue(Ty));
+
+ // (-1 + A) & B --> A ? 0 : B where A is 0/1.
+ if (match(&I, m_c_And(m_OneUse(m_Add(m_ZExtOrSelf(m_Value(A)), m_AllOnes())),
+ m_Value(B)))) {
+ if (A->getType()->isIntOrIntVectorTy(1))
+ return SelectInst::Create(A, Constant::getNullValue(Ty), B);
+ if (computeKnownBits(A, /* Depth */ 0, &I).countMaxActiveBits() <= 1) {
+ return SelectInst::Create(
+ Builder.CreateICmpEQ(A, Constant::getNullValue(A->getType())), B,
+ Constant::getNullValue(Ty));
+ }
+ }
// (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 -- with optional sext
if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf(
@@ -2698,105 +2707,178 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I,
}
/// Match UB-safe variants of the funnel shift intrinsic.
-static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
+ const DominatorTree &DT) {
// TODO: Can we reduce the code duplication between this and the related
// rotate matching code under visitSelect and visitTrunc?
unsigned Width = Or.getType()->getScalarSizeInBits();
+ Instruction *Or0, *Or1;
+ if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
+ !match(Or.getOperand(1), m_Instruction(Or1)))
+ return nullptr;
+
+ bool IsFshl = true; // Sub on LSHR.
+ SmallVector<Value *, 3> FShiftArgs;
+
// First, find an or'd pair of opposite shifts:
// or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
- BinaryOperator *Or0, *Or1;
- if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
- !match(Or.getOperand(1), m_BinOp(Or1)))
- return nullptr;
+ if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
+ Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+ if (!match(Or0,
+ m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+ !match(Or1,
+ m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+ Or0->getOpcode() == Or1->getOpcode())
+ return nullptr;
- Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
- if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
- !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
- Or0->getOpcode() == Or1->getOpcode())
- return nullptr;
+ // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+ if (Or0->getOpcode() == BinaryOperator::LShr) {
+ std::swap(Or0, Or1);
+ std::swap(ShVal0, ShVal1);
+ std::swap(ShAmt0, ShAmt1);
+ }
+ assert(Or0->getOpcode() == BinaryOperator::Shl &&
+ Or1->getOpcode() == BinaryOperator::LShr &&
+ "Illegal or(shift,shift) pair");
- // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
- if (Or0->getOpcode() == BinaryOperator::LShr) {
- std::swap(Or0, Or1);
- std::swap(ShVal0, ShVal1);
- std::swap(ShAmt0, ShAmt1);
- }
- assert(Or0->getOpcode() == BinaryOperator::Shl &&
- Or1->getOpcode() == BinaryOperator::LShr &&
- "Illegal or(shift,shift) pair");
+ // Match the shift amount operands for a funnel shift pattern. This always
+ // matches a subtraction on the R operand.
+ auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
+ // Check for constant shift amounts that sum to the bitwidth.
+ const APInt *LI, *RI;
+ if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
+ if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
+ return ConstantInt::get(L->getType(), *LI);
+
+ Constant *LC, *RC;
+ if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
+ match(L,
+ m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+ match(R,
+ m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+ match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
+ return ConstantExpr::mergeUndefsWith(LC, RC);
+
+ // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
+ // We limit this to X < Width in case the backend re-expands the
+ // intrinsic, and has to reintroduce a shift modulo operation (InstCombine
+ // might remove it after this fold). This still doesn't guarantee that the
+ // final codegen will match this original pattern.
+ if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
+ KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+ return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+ }
- // Match the shift amount operands for a funnel shift pattern. This always
- // matches a subtraction on the R operand.
- auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
- // Check for constant shift amounts that sum to the bitwidth.
- const APInt *LI, *RI;
- if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
- if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
- return ConstantInt::get(L->getType(), *LI);
+ // For non-constant cases, the following patterns currently only work for
+ // rotation patterns.
+ // TODO: Add general funnel-shift compatible patterns.
+ if (ShVal0 != ShVal1)
+ return nullptr;
- Constant *LC, *RC;
- if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
- match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
- match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
- match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
- return ConstantExpr::mergeUndefsWith(LC, RC);
+ // For non-constant cases we don't support non-pow2 shift masks.
+ // TODO: Is it worth matching urem as well?
+ if (!isPowerOf2_32(Width))
+ return nullptr;
- // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
- // We limit this to X < Width in case the backend re-expands the intrinsic,
- // and has to reintroduce a shift modulo operation (InstCombine might remove
- // it after this fold). This still doesn't guarantee that the final codegen
- // will match this original pattern.
- if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
- KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
- return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+ // The shift amount may be masked with negation:
+ // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+ Value *X;
+ unsigned Mask = Width - 1;
+ if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
+ match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
+ return X;
+
+ // Similar to above, but the shift amount may be extended after masking,
+ // so return the extended value as the parameter for the intrinsic.
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+ match(R,
+ m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
+ m_SpecificInt(Mask))))
+ return L;
+
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+ match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
+ return L;
+
+ return nullptr;
+ };
+
+ Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
+ if (!ShAmt) {
+ ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
+ IsFshl = false; // Sub on SHL.
}
+ if (!ShAmt)
+ return nullptr;
+
+ FShiftArgs = {ShVal0, ShVal1, ShAmt};
+ } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) {
+ // If there are two 'or' instructions concat variables in opposite order:
+ //
+ // Slot1 and Slot2 are all zero bits.
+ // | Slot1 | Low | Slot2 | High |
+ // LowHigh = or (shl (zext Low), ZextLowShlAmt), (zext High)
+ // | Slot2 | High | Slot1 | Low |
+ // HighLow = or (shl (zext High), ZextHighShlAmt), (zext Low)
+ //
+ // the latter 'or' can be safely convert to
+ // -> HighLow = fshl LowHigh, LowHigh, ZextHighShlAmt
+ // if ZextLowShlAmt + ZextHighShlAmt == Width.
+ if (!isa<ZExtInst>(Or1))
+ std::swap(Or0, Or1);
- // For non-constant cases, the following patterns currently only work for
- // rotation patterns.
- // TODO: Add general funnel-shift compatible patterns.
- if (ShVal0 != ShVal1)
+ Value *High, *ZextHigh, *Low;
+ const APInt *ZextHighShlAmt;
+ if (!match(Or0,
+ m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt)))))
return nullptr;
- // For non-constant cases we don't support non-pow2 shift masks.
- // TODO: Is it worth matching urem as well?
- if (!isPowerOf2_32(Width))
+ if (!match(Or1, m_ZExt(m_Value(Low))) ||
+ !match(ZextHigh, m_ZExt(m_Value(High))))
return nullptr;
- // The shift amount may be masked with negation:
- // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
- Value *X;
- unsigned Mask = Width - 1;
- if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
- match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
- return X;
+ unsigned HighSize = High->getType()->getScalarSizeInBits();
+ unsigned LowSize = Low->getType()->getScalarSizeInBits();
+ // Make sure High does not overlap with Low and most significant bits of
+ // High aren't shifted out.
+ if (ZextHighShlAmt->ult(LowSize) || ZextHighShlAmt->ugt(Width - HighSize))
+ return nullptr;
- // Similar to above, but the shift amount may be extended after masking,
- // so return the extended value as the parameter for the intrinsic.
- if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
- match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
- m_SpecificInt(Mask))))
- return L;
+ for (User *U : ZextHigh->users()) {
+ Value *X, *Y;
+ if (!match(U, m_Or(m_Value(X), m_Value(Y))))
+ continue;
- if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
- match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
- return L;
+ if (!isa<ZExtInst>(Y))
+ std::swap(X, Y);
- return nullptr;
- };
+ const APInt *ZextLowShlAmt;
+ if (!match(X, m_Shl(m_Specific(Or1), m_APInt(ZextLowShlAmt))) ||
+ !match(Y, m_Specific(ZextHigh)) || !DT.dominates(U, &Or))
+ continue;
- Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
- bool IsFshl = true; // Sub on LSHR.
- if (!ShAmt) {
- ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
- IsFshl = false; // Sub on SHL.
+ // HighLow is good concat. If sum of two shifts amount equals to Width,
+ // LowHigh must also be a good concat.
+ if (*ZextLowShlAmt + *ZextHighShlAmt != Width)
+ continue;
+
+ // Low must not overlap with High and most significant bits of Low must
+ // not be shifted out.
+ assert(ZextLowShlAmt->uge(HighSize) &&
+ ZextLowShlAmt->ule(Width - LowSize) && "Invalid concat");
+
+ FShiftArgs = {U, U, ConstantInt::get(Or0->getType(), *ZextHighShlAmt)};
+ break;
+ }
}
- if (!ShAmt)
+
+ if (FShiftArgs.empty())
return nullptr;
Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
- return CallInst::Create(F, {ShVal0, ShVal1, ShAmt});
+ return CallInst::Create(F, FShiftArgs);
}
/// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
@@ -3272,14 +3354,14 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1)) {
if (auto *SI0 = dyn_cast<SelectInst>(Op0)) {
- if (auto *I =
+ if (auto *R =
foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ false))
- return I;
+ return R;
}
if (auto *SI1 = dyn_cast<SelectInst>(Op1)) {
- if (auto *I =
+ if (auto *R =
foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ false))
- return I;
+ return R;
}
}
@@ -3290,7 +3372,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*MatchBitReversals*/ true))
return BitOp;
- if (Instruction *Funnel = matchFunnelShift(I, *this))
+ if (Instruction *Funnel = matchFunnelShift(I, *this, DT))
return Funnel;
if (Instruction *Concat = matchOrConcat(I, Builder))
@@ -3311,9 +3393,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
// If the operands have no common bits set:
// or (mul X, Y), X --> add (mul X, Y), X --> mul X, (Y + 1)
- if (match(&I,
- m_c_Or(m_OneUse(m_Mul(m_Value(X), m_Value(Y))), m_Deferred(X))) &&
- haveNoCommonBitsSet(Op0, Op1, DL)) {
+ if (match(&I, m_c_DisjointOr(m_OneUse(m_Mul(m_Value(X), m_Value(Y))),
+ m_Deferred(X)))) {
Value *IncrementY = Builder.CreateAdd(Y, ConstantInt::get(Ty, 1));
return BinaryOperator::CreateMul(X, IncrementY);
}
@@ -3435,7 +3516,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A))))
return BinaryOperator::CreateOr(Op1, Builder.CreateAnd(A, C));
- if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder))
+ if (Instruction *DeMorgan = matchDeMorgansLaws(I, *this))
return DeMorgan;
// Canonicalize xor to the RHS.
@@ -3581,12 +3662,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
// with binop identity constant. But creating a select with non-constant
// arm may not be reversible due to poison semantics. Is that a good
// canonicalization?
- if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) &&
+ if (match(&I, m_c_Or(m_OneUse(m_SExt(m_Value(A))), m_Value(B))) &&
A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op1);
- if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) &&
- A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op0);
+ return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), B);
// Note: If we've gotten to the point of visiting the outer OR, then the
// inner one couldn't be simplified. If it was a constant, then it won't
@@ -3628,6 +3706,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
}
}
+ {
+ // ((A & B) ^ A) | ((A & B) ^ B) -> A ^ B
+ // (A ^ (A & B)) | (B ^ (A & B)) -> A ^ B
+ // ((A & B) ^ B) | ((A & B) ^ A) -> A ^ B
+ // (B ^ (A & B)) | (A ^ (A & B)) -> A ^ B
+ const auto TryXorOpt = [&](Value *Lhs, Value *Rhs) -> Instruction * {
+ if (match(Lhs, m_c_Xor(m_And(m_Value(A), m_Value(B)), m_Deferred(A))) &&
+ match(Rhs,
+ m_c_Xor(m_And(m_Specific(A), m_Specific(B)), m_Deferred(B)))) {
+ return BinaryOperator::CreateXor(A, B);
+ }
+ return nullptr;
+ };
+
+ if (Instruction *Result = TryXorOpt(Op0, Op1))
+ return Result;
+ if (Instruction *Result = TryXorOpt(Op1, Op0))
+ return Result;
+ }
+
if (Instruction *V =
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
return V;
@@ -3720,6 +3818,31 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
return Res;
+ // If we are setting the sign bit of a floating-point value, convert
+ // this to fneg(fabs), then cast back to integer.
+ //
+ // If the result isn't immediately cast back to a float, this will increase
+ // the number of instructions. This is still probably a better canonical form
+ // as it enables FP value tracking.
+ //
+ // Assumes any IEEE-represented type has the sign bit in the high bit.
+ //
+ // This is generous interpretation of noimplicitfloat, this is not a true
+ // floating-point operation.
+ Value *CastOp;
+ if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
+ Attribute::NoImplicitFloat)) {
+ Type *EltTy = CastOp->getType()->getScalarType();
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
+ EltTy->getPrimitiveSizeInBits() ==
+ I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
+ Value *FNegFAbs = Builder.CreateFNeg(FAbs);
+ return new BitCastInst(FNegFAbs, I.getType());
+ }
+ }
+
return nullptr;
}
@@ -3931,26 +4054,6 @@ static Instruction *visitMaskedMerge(BinaryOperator &I,
return nullptr;
}
-// Transform
-// ~(x ^ y)
-// into:
-// (~x) ^ y
-// or into
-// x ^ (~y)
-static Instruction *sinkNotIntoXor(BinaryOperator &I, Value *X, Value *Y,
- InstCombiner::BuilderTy &Builder) {
- // We only want to do the transform if it is free to do.
- if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) {
- // Ok, good.
- } else if (InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) {
- std::swap(X, Y);
- } else
- return nullptr;
-
- Value *NotX = Builder.CreateNot(X, X->getName() + ".not");
- return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan");
-}
-
static Instruction *foldNotXor(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
Value *X, *Y;
@@ -3959,9 +4062,6 @@ static Instruction *foldNotXor(BinaryOperator &I,
if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y))))))
return nullptr;
- if (Instruction *NewXor = sinkNotIntoXor(I, X, Y, Builder))
- return NewXor;
-
auto hasCommonOperand = [](Value *A, Value *B, Value *C, Value *D) {
return A == C || A == D || B == C || B == D;
};
@@ -4023,13 +4123,13 @@ static bool canFreelyInvert(InstCombiner &IC, Value *Op,
Instruction *IgnoredUser) {
auto *I = dyn_cast<Instruction>(Op);
return I && IC.isFreeToInvert(I, /*WillInvertAllUses=*/true) &&
- InstCombiner::canFreelyInvertAllUsersOf(I, IgnoredUser);
+ IC.canFreelyInvertAllUsersOf(I, IgnoredUser);
}
static Value *freelyInvert(InstCombinerImpl &IC, Value *Op,
Instruction *IgnoredUser) {
auto *I = cast<Instruction>(Op);
- IC.Builder.SetInsertPoint(&*I->getInsertionPointAfterDef());
+ IC.Builder.SetInsertPoint(*I->getInsertionPointAfterDef());
Value *NotOp = IC.Builder.CreateNot(Op, Op->getName() + ".not");
Op->replaceUsesWithIf(NotOp,
[NotOp](Use &U) { return U.getUser() != NotOp; });
@@ -4067,7 +4167,7 @@ bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) {
Op0 = freelyInvert(*this, Op0, &I);
Op1 = freelyInvert(*this, Op1, &I);
- Builder.SetInsertPoint(I.getInsertionPointAfterDef());
+ Builder.SetInsertPoint(*I.getInsertionPointAfterDef());
Value *NewLogicOp;
if (IsBinaryOp)
NewLogicOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not");
@@ -4115,7 +4215,7 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) {
*OpToInvert = freelyInvert(*this, *OpToInvert, &I);
- Builder.SetInsertPoint(&*I.getInsertionPointAfterDef());
+ Builder.SetInsertPoint(*I.getInsertionPointAfterDef());
Value *NewBinOp;
if (IsBinaryOp)
NewBinOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not");
@@ -4259,15 +4359,6 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
// ~max(~X, Y) --> min(X, ~Y)
auto *II = dyn_cast<IntrinsicInst>(NotOp);
if (II && II->hasOneUse()) {
- if (match(NotOp, m_MaxOrMin(m_Value(X), m_Value(Y))) &&
- isFreeToInvert(X, X->hasOneUse()) &&
- isFreeToInvert(Y, Y->hasOneUse())) {
- Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
- Value *NotX = Builder.CreateNot(X);
- Value *NotY = Builder.CreateNot(Y);
- Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, NotX, NotY);
- return replaceInstUsesWith(I, InvMaxMin);
- }
if (match(NotOp, m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y)))) {
Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
Value *NotY = Builder.CreateNot(Y);
@@ -4317,6 +4408,11 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
if (Instruction *NewXor = foldNotXor(I, Builder))
return NewXor;
+ // TODO: Could handle multi-use better by checking if all uses of NotOp (other
+ // than I) can be inverted.
+ if (Value *R = getFreelyInverted(NotOp, NotOp->hasOneUse(), &Builder))
+ return replaceInstUsesWith(I, R);
+
return nullptr;
}
@@ -4366,7 +4462,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
Value *M;
if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()),
m_c_And(m_Deferred(M), m_Value()))))
- return BinaryOperator::CreateOr(Op0, Op1);
+ return BinaryOperator::CreateDisjointOr(Op0, Op1);
if (Instruction *Xor = visitMaskedMerge(I, Builder))
return Xor;
@@ -4466,6 +4562,27 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
// a 'not' op and moving it before the shift. Doing that requires
// preventing the inverse fold in canShiftBinOpWithConstantRHS().
}
+
+ // If we are XORing the sign bit of a floating-point value, convert
+ // this to fneg, then cast back to integer.
+ //
+ // This is generous interpretation of noimplicitfloat, this is not a true
+ // floating-point operation.
+ //
+ // Assumes any IEEE-represented type has the sign bit in the high bit.
+ // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
+ Value *CastOp;
+ if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
+ Attribute::NoImplicitFloat)) {
+ Type *EltTy = CastOp->getType()->getScalarType();
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
+ EltTy->getPrimitiveSizeInBits() ==
+ I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ Value *FNeg = Builder.CreateFNeg(CastOp);
+ return new BitCastInst(FNeg, I.getType());
+ }
+ }
}
// FIXME: This should not be limited to scalar (pull into APInt match above).
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d3ec6a7aa667..255ce6973a16 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -89,12 +89,6 @@ static cl::opt<unsigned> GuardWideningWindow(
cl::desc("How wide an instruction window to bypass looking for "
"another guard"));
-namespace llvm {
-/// enable preservation of attributes in assume like:
-/// call void @llvm.assume(i1 true) [ "nonnull"(i32* %PTR) ]
-extern cl::opt<bool> EnableKnowledgeRetention;
-} // namespace llvm
-
/// Return the specified type promoted as it would be to pass though a va_arg
/// area.
static Type *getPromotedType(Type *Ty) {
@@ -174,14 +168,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
return nullptr;
// Use an integer load+store unless we can find something better.
- unsigned SrcAddrSp =
- cast<PointerType>(MI->getArgOperand(1)->getType())->getAddressSpace();
- unsigned DstAddrSp =
- cast<PointerType>(MI->getArgOperand(0)->getType())->getAddressSpace();
-
IntegerType* IntType = IntegerType::get(MI->getContext(), Size<<3);
- Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);
- Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp);
// If the memcpy has metadata describing the members, see if we can get the
// TBAA tag describing our copy.
@@ -200,8 +187,8 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
CopyMD = cast<MDNode>(M->getOperand(2));
}
- Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy);
- Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy);
+ Value *Src = MI->getArgOperand(1);
+ Value *Dest = MI->getArgOperand(0);
LoadInst *L = Builder.CreateLoad(IntType, Src);
// Alignment from the mem intrinsic will be better, so use it.
L->setAlignment(*CopySrcAlign);
@@ -291,9 +278,6 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
Type *ITy = IntegerType::get(MI->getContext(), Len*8); // n=1 -> i8.
Value *Dest = MI->getDest();
- unsigned DstAddrSp = cast<PointerType>(Dest->getType())->getAddressSpace();
- Type *NewDstPtrTy = PointerType::get(ITy, DstAddrSp);
- Dest = Builder.CreateBitCast(Dest, NewDstPtrTy);
// Extract the fill value and store.
const uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL;
@@ -301,7 +285,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
StoreInst *S = Builder.CreateStore(FillVal, Dest, MI->isVolatile());
S->copyMetadata(*MI, LLVMContext::MD_DIAssignID);
for (auto *DAI : at::getAssignmentMarkers(S)) {
- if (any_of(DAI->location_ops(), [&](Value *V) { return V == FillC; }))
+ if (llvm::is_contained(DAI->location_ops(), FillC))
DAI->replaceVariableLocationOp(FillC, FillVal);
}
@@ -500,8 +484,6 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II,
if (Result->getType()->getPointerAddressSpace() !=
II.getType()->getPointerAddressSpace())
Result = IC.Builder.CreateAddrSpaceCast(Result, II.getType());
- if (Result->getType() != II.getType())
- Result = IC.Builder.CreateBitCast(Result, II.getType());
return cast<Instruction>(Result);
}
@@ -532,6 +514,8 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType()));
}
+ Constant *C;
+
if (IsTZ) {
// cttz(-x) -> cttz(x)
if (match(Op0, m_Neg(m_Value(X))))
@@ -567,6 +551,38 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X))))
return IC.replaceOperand(II, 0, X);
+
+ // cttz(shl(%const, %val), 1) --> add(cttz(%const, 1), %val)
+ if (match(Op0, m_Shl(m_ImmConstant(C), m_Value(X))) &&
+ match(Op1, m_One())) {
+ Value *ConstCttz =
+ IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, C, Op1);
+ return BinaryOperator::CreateAdd(ConstCttz, X);
+ }
+
+ // cttz(lshr exact (%const, %val), 1) --> sub(cttz(%const, 1), %val)
+ if (match(Op0, m_Exact(m_LShr(m_ImmConstant(C), m_Value(X)))) &&
+ match(Op1, m_One())) {
+ Value *ConstCttz =
+ IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, C, Op1);
+ return BinaryOperator::CreateSub(ConstCttz, X);
+ }
+ } else {
+ // ctlz(lshr(%const, %val), 1) --> add(ctlz(%const, 1), %val)
+ if (match(Op0, m_LShr(m_ImmConstant(C), m_Value(X))) &&
+ match(Op1, m_One())) {
+ Value *ConstCtlz =
+ IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1);
+ return BinaryOperator::CreateAdd(ConstCtlz, X);
+ }
+
+ // ctlz(shl nuw (%const, %val), 1) --> sub(ctlz(%const, 1), %val)
+ if (match(Op0, m_NUWShl(m_ImmConstant(C), m_Value(X))) &&
+ match(Op1, m_One())) {
+ Value *ConstCtlz =
+ IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1);
+ return BinaryOperator::CreateSub(ConstCtlz, X);
+ }
}
KnownBits Known = IC.computeKnownBits(Op0, 0, &II);
@@ -911,11 +927,27 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) {
Value *FAbsSrc;
if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) {
- II.setArgOperand(1, ConstantInt::get(Src1->getType(), fabs(Mask)));
+ II.setArgOperand(1, ConstantInt::get(Src1->getType(), inverse_fabs(Mask)));
return replaceOperand(II, 0, FAbsSrc);
}
- // TODO: is.fpclass(x, fcInf) -> fabs(x) == inf
+ if ((OrderedMask == fcInf || OrderedInvertedMask == fcInf) &&
+ (IsOrdered || IsUnordered) && !IsStrict) {
+ // is.fpclass(x, fcInf) -> fcmp oeq fabs(x), +inf
+ // is.fpclass(x, ~fcInf) -> fcmp one fabs(x), +inf
+ // is.fpclass(x, fcInf|fcNan) -> fcmp ueq fabs(x), +inf
+ // is.fpclass(x, ~(fcInf|fcNan)) -> fcmp une fabs(x), +inf
+ Constant *Inf = ConstantFP::getInfinity(Src0->getType());
+ FCmpInst::Predicate Pred =
+ IsUnordered ? FCmpInst::FCMP_UEQ : FCmpInst::FCMP_OEQ;
+ if (OrderedInvertedMask == fcInf)
+ Pred = IsUnordered ? FCmpInst::FCMP_UNE : FCmpInst::FCMP_ONE;
+
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Src0);
+ Value *CmpInf = Builder.CreateFCmp(Pred, Fabs, Inf);
+ CmpInf->takeName(&II);
+ return replaceInstUsesWith(II, CmpInf);
+ }
if ((OrderedMask == fcPosInf || OrderedMask == fcNegInf) &&
(IsOrdered || IsUnordered) && !IsStrict) {
@@ -992,8 +1024,7 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) {
return replaceInstUsesWith(II, FCmp);
}
- KnownFPClass Known = computeKnownFPClass(
- Src0, DL, Mask, 0, &getTargetLibraryInfo(), &AC, &II, &DT);
+ KnownFPClass Known = computeKnownFPClass(Src0, Mask, &II);
// Clear test bits we know must be false from the source value.
// fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other
@@ -1030,6 +1061,20 @@ static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL);
}
+static std::optional<bool> getKnownSignOrZero(Value *Op, Instruction *CxtI,
+ const DataLayout &DL,
+ AssumptionCache *AC,
+ DominatorTree *DT) {
+ if (std::optional<bool> Sign = getKnownSign(Op, CxtI, DL, AC, DT))
+ return Sign;
+
+ Value *X, *Y;
+ if (match(Op, m_NSWSub(m_Value(X), m_Value(Y))))
+ return isImpliedByDomCondition(ICmpInst::ICMP_SLE, X, Y, CxtI, DL);
+
+ return std::nullopt;
+}
+
/// Return true if two values \p Op0 and \p Op1 are known to have the same sign.
static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI,
const DataLayout &DL, AssumptionCache *AC,
@@ -1530,12 +1575,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
return replaceOperand(*II, 0, X);
- if (std::optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) {
- // abs(x) -> x if x >= 0
- if (!*Sign)
+ if (std::optional<bool> Known =
+ getKnownSignOrZero(IIOperand, II, DL, &AC, &DT)) {
+ // abs(x) -> x if x >= 0 (include abs(x-y) --> x - y where x >= y)
+ // abs(x) -> x if x > 0 (include abs(x-y) --> x - y where x > y)
+ if (!*Known)
return replaceInstUsesWith(*II, IIOperand);
// abs(x) -> -x if x < 0
+ // abs(x) -> -x if x < = 0 (include abs(x-y) --> y - x where x <= y)
if (IntMinIsPoison)
return BinaryOperator::CreateNSWNeg(IIOperand);
return BinaryOperator::CreateNeg(IIOperand);
@@ -1580,8 +1628,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Constant *C;
if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) &&
I0->hasOneUse()) {
- Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType());
- if (ConstantExpr::getZExt(NarrowC, II->getType()) == C) {
+ if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType())) {
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
}
@@ -1603,13 +1650,26 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Constant *C;
if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) &&
I0->hasOneUse()) {
- Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType());
- if (ConstantExpr::getSExt(NarrowC, II->getType()) == C) {
+ if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType())) {
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType());
}
}
+ // umin(i1 X, i1 Y) -> and i1 X, Y
+ // smax(i1 X, i1 Y) -> and i1 X, Y
+ if ((IID == Intrinsic::umin || IID == Intrinsic::smax) &&
+ II->getType()->isIntOrIntVectorTy(1)) {
+ return BinaryOperator::CreateAnd(I0, I1);
+ }
+
+ // umax(i1 X, i1 Y) -> or i1 X, Y
+ // smin(i1 X, i1 Y) -> or i1 X, Y
+ if ((IID == Intrinsic::umax || IID == Intrinsic::smin) &&
+ II->getType()->isIntOrIntVectorTy(1)) {
+ return BinaryOperator::CreateOr(I0, I1);
+ }
+
if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
// smax (neg nsw X), (neg nsw Y) --> neg nsw (smin X, Y)
// smin (neg nsw X), (neg nsw Y) --> neg nsw (smax X, Y)
@@ -1672,12 +1732,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * {
Value *A;
if (match(X, m_OneUse(m_Not(m_Value(A)))) &&
- !isFreeToInvert(A, A->hasOneUse()) &&
- isFreeToInvert(Y, Y->hasOneUse())) {
- Value *NotY = Builder.CreateNot(Y);
- Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID);
- Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY);
- return BinaryOperator::CreateNot(InvMaxMin);
+ !isFreeToInvert(A, A->hasOneUse())) {
+ if (Value *NotY = getFreelyInverted(Y, Y->hasOneUse(), &Builder)) {
+ Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID);
+ Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY);
+ return BinaryOperator::CreateNot(InvMaxMin);
+ }
}
return nullptr;
};
@@ -1929,6 +1989,52 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return &CI;
break;
}
+ case Intrinsic::ptrmask: {
+ unsigned BitWidth = DL.getPointerTypeSizeInBits(II->getType());
+ KnownBits Known(BitWidth);
+ if (SimplifyDemandedInstructionBits(*II, Known))
+ return II;
+
+ Value *InnerPtr, *InnerMask;
+ bool Changed = false;
+ // Combine:
+ // (ptrmask (ptrmask p, A), B)
+ // -> (ptrmask p, (and A, B))
+ if (match(II->getArgOperand(0),
+ m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr),
+ m_Value(InnerMask))))) {
+ assert(II->getArgOperand(1)->getType() == InnerMask->getType() &&
+ "Mask types must match");
+ // TODO: If InnerMask == Op1, we could copy attributes from inner
+ // callsite -> outer callsite.
+ Value *NewMask = Builder.CreateAnd(II->getArgOperand(1), InnerMask);
+ replaceOperand(CI, 0, InnerPtr);
+ replaceOperand(CI, 1, NewMask);
+ Changed = true;
+ }
+
+ // See if we can deduce non-null.
+ if (!CI.hasRetAttr(Attribute::NonNull) &&
+ (Known.isNonZero() ||
+ isKnownNonZero(II, DL, /*Depth*/ 0, &AC, II, &DT))) {
+ CI.addRetAttr(Attribute::NonNull);
+ Changed = true;
+ }
+
+ unsigned NewAlignmentLog =
+ std::min(Value::MaxAlignmentExponent,
+ std::min(BitWidth - 1, Known.countMinTrailingZeros()));
+ // Known bits will capture if we had alignment information associated with
+ // the pointer argument.
+ if (NewAlignmentLog > Log2(CI.getRetAlign().valueOrOne())) {
+ CI.addRetAttr(Attribute::getWithAlignment(
+ CI.getContext(), Align(uint64_t(1) << NewAlignmentLog)));
+ Changed = true;
+ }
+ if (Changed)
+ return &CI;
+ break;
+ }
case Intrinsic::uadd_with_overflow:
case Intrinsic::sadd_with_overflow: {
if (Instruction *I = foldIntrinsicWithOverflowCommon(II))
@@ -2493,10 +2599,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
VectorType *NewVT = cast<VectorType>(II->getType());
if (Constant *CV0 = dyn_cast<Constant>(Arg0)) {
if (Constant *CV1 = dyn_cast<Constant>(Arg1)) {
- CV0 = ConstantExpr::getIntegerCast(CV0, NewVT, /*isSigned=*/!Zext);
- CV1 = ConstantExpr::getIntegerCast(CV1, NewVT, /*isSigned=*/!Zext);
-
- return replaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1));
+ Value *V0 = Builder.CreateIntCast(CV0, NewVT, /*isSigned=*/!Zext);
+ Value *V1 = Builder.CreateIntCast(CV1, NewVT, /*isSigned=*/!Zext);
+ return replaceInstUsesWith(CI, Builder.CreateMul(V0, V1));
}
// Couldn't simplify - canonicalize constant to the RHS.
@@ -2950,24 +3055,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return replaceOperand(CI, 0, InsertTuple);
}
- auto *DstTy = dyn_cast<FixedVectorType>(ReturnType);
- auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+ auto *DstTy = dyn_cast<VectorType>(ReturnType);
+ auto *VecTy = dyn_cast<VectorType>(Vec->getType());
- // Only canonicalize if the the destination vector and Vec are fixed
- // vectors.
if (DstTy && VecTy) {
- unsigned DstNumElts = DstTy->getNumElements();
- unsigned VecNumElts = VecTy->getNumElements();
+ auto DstEltCnt = DstTy->getElementCount();
+ auto VecEltCnt = VecTy->getElementCount();
unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue();
// Extracting the entirety of Vec is a nop.
- if (VecNumElts == DstNumElts) {
+ if (DstEltCnt == VecTy->getElementCount()) {
replaceInstUsesWith(CI, Vec);
return eraseInstFromFunction(CI);
}
+ // Only canonicalize to shufflevector if the destination vector and
+ // Vec are fixed vectors.
+ if (VecEltCnt.isScalable() || DstEltCnt.isScalable())
+ break;
+
SmallVector<int, 8> Mask;
- for (unsigned i = 0; i != DstNumElts; ++i)
+ for (unsigned i = 0; i != DstEltCnt.getKnownMinValue(); ++i)
Mask.push_back(IdxN + i);
Value *Shuffle = Builder.CreateShuffleVector(Vec, Mask);
@@ -3943,9 +4051,9 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy);
NC->setDebugLoc(Caller->getDebugLoc());
- Instruction *InsertPt = NewCall->getInsertionPointAfterDef();
- assert(InsertPt && "No place to insert cast");
- InsertNewInstBefore(NC, *InsertPt);
+ auto OptInsertPt = NewCall->getInsertionPointAfterDef();
+ assert(OptInsertPt && "No place to insert cast");
+ InsertNewInstBefore(NC, *OptInsertPt);
Worklist.pushUsersToWorkList(*Caller);
} else {
NV = PoisonValue::get(Caller->getType());
@@ -3972,8 +4080,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
Instruction *
InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
IntrinsicInst &Tramp) {
- Value *Callee = Call.getCalledOperand();
- Type *CalleeTy = Callee->getType();
FunctionType *FTy = Call.getFunctionType();
AttributeList Attrs = Call.getAttributes();
@@ -4070,12 +4176,8 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
// Replace the trampoline call with a direct call. Let the generic
// code sort out any function type mismatches.
- FunctionType *NewFTy = FunctionType::get(FTy->getReturnType(), NewTypes,
- FTy->isVarArg());
- Constant *NewCallee =
- NestF->getType() == PointerType::getUnqual(NewFTy) ?
- NestF : ConstantExpr::getBitCast(NestF,
- PointerType::getUnqual(NewFTy));
+ FunctionType *NewFTy =
+ FunctionType::get(FTy->getReturnType(), NewTypes, FTy->isVarArg());
AttributeList NewPAL =
AttributeList::get(FTy->getContext(), Attrs.getFnAttrs(),
Attrs.getRetAttrs(), NewArgAttrs);
@@ -4085,19 +4187,18 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
Instruction *NewCaller;
if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) {
- NewCaller = InvokeInst::Create(NewFTy, NewCallee,
- II->getNormalDest(), II->getUnwindDest(),
- NewArgs, OpBundles);
+ NewCaller = InvokeInst::Create(NewFTy, NestF, II->getNormalDest(),
+ II->getUnwindDest(), NewArgs, OpBundles);
cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv());
cast<InvokeInst>(NewCaller)->setAttributes(NewPAL);
} else if (CallBrInst *CBI = dyn_cast<CallBrInst>(&Call)) {
NewCaller =
- CallBrInst::Create(NewFTy, NewCallee, CBI->getDefaultDest(),
+ CallBrInst::Create(NewFTy, NestF, CBI->getDefaultDest(),
CBI->getIndirectDests(), NewArgs, OpBundles);
cast<CallBrInst>(NewCaller)->setCallingConv(CBI->getCallingConv());
cast<CallBrInst>(NewCaller)->setAttributes(NewPAL);
} else {
- NewCaller = CallInst::Create(NewFTy, NewCallee, NewArgs, OpBundles);
+ NewCaller = CallInst::Create(NewFTy, NestF, NewArgs, OpBundles);
cast<CallInst>(NewCaller)->setTailCallKind(
cast<CallInst>(Call).getTailCallKind());
cast<CallInst>(NewCaller)->setCallingConv(
@@ -4113,7 +4214,6 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
// Replace the trampoline call with a direct call. Since there is no 'nest'
// parameter, there is no need to adjust the argument list. Let the generic
// code sort out any function type mismatches.
- Constant *NewCallee = ConstantExpr::getBitCast(NestF, CalleeTy);
- Call.setCalledFunction(FTy, NewCallee);
+ Call.setCalledFunction(FTy, NestF);
return &Call;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 5c84f666616d..6629ca840a67 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -29,11 +29,8 @@ using namespace PatternMatch;
/// true for, actually insert the code to evaluate the expression.
Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
bool isSigned) {
- if (Constant *C = dyn_cast<Constant>(V)) {
- C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/);
- // If we got a constantexpr back, try to simplify it with DL info.
- return ConstantFoldConstant(C, DL, &TLI);
- }
+ if (Constant *C = dyn_cast<Constant>(V))
+ return ConstantFoldIntegerCast(C, Ty, isSigned, DL);
// Otherwise, it must be an instruction.
Instruction *I = cast<Instruction>(V);
@@ -112,7 +109,7 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
}
Res->takeName(I);
- return InsertNewInstWith(Res, *I);
+ return InsertNewInstWith(Res, I->getIterator());
}
Instruction::CastOps
@@ -217,7 +214,8 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
/// free to be evaluated in that type. This is a helper for canEvaluate*.
static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
if (isa<Constant>(V))
- return true;
+ return match(V, m_ImmConstant());
+
Value *X;
if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) &&
X->getType() == Ty)
@@ -229,7 +227,6 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
/// Filter out values that we can not evaluate in the destination type for free.
/// This is a helper for canEvaluate*.
static bool canNotEvaluateInType(Value *V, Type *Ty) {
- assert(!isa<Constant>(V) && "Constant should already be handled.");
if (!isa<Instruction>(V))
return true;
// We don't extend or shrink something that has multiple uses -- doing so
@@ -505,11 +502,13 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
return nullptr;
- // We have an unnecessarily wide rotate!
- // trunc (or (shl ShVal0, ShAmt), (lshr ShVal1, BitWidth - ShAmt))
- // Narrow the inputs and convert to funnel shift intrinsic:
- // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt))
- Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
+ // Adjust the width of ShAmt for narrowed funnel shift operation:
+ // - Zero-extend if ShAmt is narrower than the destination type.
+ // - Truncate if ShAmt is wider, discarding non-significant high-order bits.
+ // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal),
+ // zext/trunc(ShAmt)).
+ Value *NarrowShAmt = Builder.CreateZExtOrTrunc(ShAmt, DestTy);
+
Value *X, *Y;
X = Y = Builder.CreateTrunc(ShVal0, DestTy);
if (ShVal0 != ShVal1)
@@ -582,13 +581,15 @@ Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) {
APInt(SrcWidth, MaxShiftAmt)))) {
auto *OldShift = cast<Instruction>(Trunc.getOperand(0));
bool IsExact = OldShift->isExact();
- auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
- ShAmt = Constant::mergeUndefsWith(ShAmt, C);
- Value *Shift =
- OldShift->getOpcode() == Instruction::AShr
- ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
- : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
- return CastInst::CreateTruncOrBitCast(Shift, DestTy);
+ if (Constant *ShAmt = ConstantFoldIntegerCast(C, A->getType(),
+ /*IsSigned*/ true, DL)) {
+ ShAmt = Constant::mergeUndefsWith(ShAmt, C);
+ Value *Shift =
+ OldShift->getOpcode() == Instruction::AShr
+ ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
+ : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
+ return CastInst::CreateTruncOrBitCast(Shift, DestTy);
+ }
}
}
break;
@@ -904,19 +905,18 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp,
// zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
// zext (X != 0) to i32 --> X iff X has only the low bit set.
// zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set.
- if (Op1CV->isZero() && Cmp->isEquality() &&
- (Cmp->getOperand(0)->getType() == Zext.getType() ||
- Cmp->getPredicate() == ICmpInst::ICMP_NE)) {
- // If Op1C some other power of two, convert:
- KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
+ if (Op1CV->isZero() && Cmp->isEquality()) {
// Exactly 1 possible 1? But not the high-bit because that is
// canonicalized to this form.
+ KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
APInt KnownZeroMask(~Known.Zero);
- if (KnownZeroMask.isPowerOf2() &&
- (Zext.getType()->getScalarSizeInBits() !=
- KnownZeroMask.logBase2() + 1)) {
- uint32_t ShAmt = KnownZeroMask.logBase2();
+ uint32_t ShAmt = KnownZeroMask.logBase2();
+ bool IsExpectShAmt = KnownZeroMask.isPowerOf2() &&
+ (Zext.getType()->getScalarSizeInBits() != ShAmt + 1);
+ if (IsExpectShAmt &&
+ (Cmp->getOperand(0)->getType() == Zext.getType() ||
+ Cmp->getPredicate() == ICmpInst::ICMP_NE || ShAmt == 0)) {
Value *In = Cmp->getOperand(0);
if (ShAmt) {
// Perform a logical shr by shiftamt.
@@ -1184,14 +1184,14 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
Value *X;
if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) &&
X->getType() == DestTy)
- return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, DestTy));
+ return BinaryOperator::CreateAnd(X, Builder.CreateZExt(C, DestTy));
// zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)).
Value *And;
if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) &&
match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) &&
X->getType() == DestTy) {
- Constant *ZC = ConstantExpr::getZExt(C, DestTy);
+ Value *ZC = Builder.CreateZExt(C, DestTy);
return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC);
}
@@ -1202,7 +1202,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
// zext (and (trunc X), C) --> and X, (zext C)
if (match(Src, m_And(m_Trunc(m_Value(X)), m_Constant(C))) &&
X->getType() == DestTy) {
- Constant *ZextC = ConstantExpr::getZExt(C, DestTy);
+ Value *ZextC = Builder.CreateZExt(C, DestTy);
return BinaryOperator::CreateAnd(X, ZextC);
}
@@ -1221,6 +1221,22 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
}
}
+ if (!Zext.hasNonNeg()) {
+ // If this zero extend is only used by a shift, add nneg flag.
+ if (Zext.hasOneUse() &&
+ SrcTy->getScalarSizeInBits() >
+ Log2_64_Ceil(DestTy->getScalarSizeInBits()) &&
+ match(Zext.user_back(), m_Shift(m_Value(), m_Specific(&Zext)))) {
+ Zext.setNonNeg();
+ return &Zext;
+ }
+
+ if (isKnownNonNegative(Src, SQ.getWithInstruction(&Zext))) {
+ Zext.setNonNeg();
+ return &Zext;
+ }
+ }
+
return nullptr;
}
@@ -1373,8 +1389,11 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
unsigned DestBitSize = DestTy->getScalarSizeInBits();
// If the value being extended is zero or positive, use a zext instead.
- if (isKnownNonNegative(Src, DL, 0, &AC, &Sext, &DT))
- return CastInst::Create(Instruction::ZExt, Src, DestTy);
+ if (isKnownNonNegative(Src, SQ.getWithInstruction(&Sext))) {
+ auto CI = CastInst::Create(Instruction::ZExt, Src, DestTy);
+ CI->setNonNeg(true);
+ return CI;
+ }
// Try to extend the entire expression tree to the wide destination type.
if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) {
@@ -1445,9 +1464,11 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
// TODO: Eventually this could be subsumed by EvaluateInDifferentType.
Constant *BA = nullptr, *CA = nullptr;
if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)),
- m_Constant(CA))) &&
+ m_ImmConstant(CA))) &&
BA->isElementWiseEqual(CA) && A->getType() == DestTy) {
- Constant *WideCurrShAmt = ConstantExpr::getSExt(CA, DestTy);
+ Constant *WideCurrShAmt =
+ ConstantFoldCastOperand(Instruction::SExt, CA, DestTy, DL);
+ assert(WideCurrShAmt && "Constant folding of ImmConstant cannot fail");
Constant *NumLowbitsLeft = ConstantExpr::getSub(
ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt);
Constant *NewShAmt = ConstantExpr::getSub(
@@ -1915,29 +1936,6 @@ Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) {
return nullptr;
}
-/// Implement the transforms for cast of pointer (bitcast/ptrtoint)
-Instruction *InstCombinerImpl::commonPointerCastTransforms(CastInst &CI) {
- Value *Src = CI.getOperand(0);
-
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) {
- // If casting the result of a getelementptr instruction with no offset, turn
- // this into a cast of the original pointer!
- if (GEP->hasAllZeroIndices() &&
- // If CI is an addrspacecast and GEP changes the poiner type, merging
- // GEP into CI would undo canonicalizing addrspacecast with different
- // pointer types, causing infinite loops.
- (!isa<AddrSpaceCastInst>(CI) ||
- GEP->getType() == GEP->getPointerOperandType())) {
- // Changing the cast operand is usually not a good idea but it is safe
- // here because the pointer operand is being replaced with another
- // pointer operand so the opcode doesn't need to change.
- return replaceOperand(CI, 0, GEP->getOperand(0));
- }
- }
-
- return commonCastTransforms(CI);
-}
-
Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
// If the destination integer type is not the intptr_t type for this target,
// do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast
@@ -1955,6 +1953,15 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
}
+ // (ptrtoint (ptrmask P, M))
+ // -> (and (ptrtoint P), M)
+ // This is generally beneficial as `and` is better supported than `ptrmask`.
+ Value *Ptr, *Mask;
+ if (match(SrcOp, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(Ptr),
+ m_Value(Mask)))) &&
+ Mask->getType() == Ty)
+ return BinaryOperator::CreateAnd(Builder.CreatePtrToInt(Ptr, Ty), Mask);
+
if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) {
// Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use.
// While this can increase the number of instructions it doesn't actually
@@ -1979,7 +1986,7 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
return InsertElementInst::Create(Vec, NewCast, Index);
}
- return commonPointerCastTransforms(CI);
+ return commonCastTransforms(CI);
}
/// This input value (which is known to have vector type) is being zero extended
@@ -2136,9 +2143,12 @@ static bool collectInsertionElements(Value *V, unsigned Shift,
Type *ElementIntTy = IntegerType::get(C->getContext(), ElementSize);
for (unsigned i = 0; i != NumElts; ++i) {
- unsigned ShiftI = Shift+i*ElementSize;
- Constant *Piece = ConstantExpr::getLShr(C, ConstantInt::get(C->getType(),
- ShiftI));
+ unsigned ShiftI = Shift + i * ElementSize;
+ Constant *Piece = ConstantFoldBinaryInstruction(
+ Instruction::LShr, C, ConstantInt::get(C->getType(), ShiftI));
+ if (!Piece)
+ return false;
+
Piece = ConstantExpr::getTrunc(Piece, ElementIntTy);
if (!collectInsertionElements(Piece, ShiftI, Elements, VecEltTy,
isBigEndian))
@@ -2701,11 +2711,9 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
if (Instruction *I = foldBitCastSelect(CI, Builder))
return I;
- if (SrcTy->isPointerTy())
- return commonPointerCastTransforms(CI);
return commonCastTransforms(CI);
}
Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
- return commonPointerCastTransforms(CI);
+ return commonCastTransforms(CI);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 656f04370e17..e42e011bd436 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -12,12 +12,14 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Utils/Local.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
@@ -26,6 +28,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <bitset>
using namespace llvm;
using namespace PatternMatch;
@@ -412,7 +415,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
/// Returns true if we can rewrite Start as a GEP with pointer Base
/// and some integer offset. The nodes that need to be re-written
/// for this transformation will be added to Explored.
-static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
+static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
const DataLayout &DL,
SetVector<Value *> &Explored) {
SmallVector<Value *, 16> WorkList(1, Start);
@@ -440,27 +443,15 @@ static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
continue;
}
- if (!isa<IntToPtrInst>(V) && !isa<PtrToIntInst>(V) &&
- !isa<GetElementPtrInst>(V) && !isa<PHINode>(V))
+ if (!isa<GetElementPtrInst>(V) && !isa<PHINode>(V))
// We've found some value that we can't explore which is different from
// the base. Therefore we can't do this transformation.
return false;
- if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) {
- auto *CI = cast<CastInst>(V);
- if (!CI->isNoopCast(DL))
- return false;
-
- if (!Explored.contains(CI->getOperand(0)))
- WorkList.push_back(CI->getOperand(0));
- }
-
if (auto *GEP = dyn_cast<GEPOperator>(V)) {
- // We're limiting the GEP to having one index. This will preserve
- // the original pointer type. We could handle more cases in the
- // future.
- if (GEP->getNumIndices() != 1 || !GEP->isInBounds() ||
- GEP->getSourceElementType() != ElemTy)
+ // Only allow inbounds GEPs with at most one variable offset.
+ auto IsNonConst = [](Value *V) { return !isa<ConstantInt>(V); };
+ if (!GEP->isInBounds() || count_if(GEP->indices(), IsNonConst) > 1)
return false;
if (!Explored.contains(GEP->getOperand(0)))
@@ -514,7 +505,8 @@ static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
static void setInsertionPoint(IRBuilder<> &Builder, Value *V,
bool Before = true) {
if (auto *PHI = dyn_cast<PHINode>(V)) {
- Builder.SetInsertPoint(&*PHI->getParent()->getFirstInsertionPt());
+ BasicBlock *Parent = PHI->getParent();
+ Builder.SetInsertPoint(Parent, Parent->getFirstInsertionPt());
return;
}
if (auto *I = dyn_cast<Instruction>(V)) {
@@ -526,7 +518,7 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V,
if (auto *A = dyn_cast<Argument>(V)) {
// Set the insertion point in the entry block.
BasicBlock &Entry = A->getParent()->getEntryBlock();
- Builder.SetInsertPoint(&*Entry.getFirstInsertionPt());
+ Builder.SetInsertPoint(&Entry, Entry.getFirstInsertionPt());
return;
}
// Otherwise, this is a constant and we don't need to set a new
@@ -536,7 +528,7 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V,
/// Returns a re-written value of Start as an indexed GEP using Base as a
/// pointer.
-static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
+static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
const DataLayout &DL,
SetVector<Value *> &Explored,
InstCombiner &IC) {
@@ -567,36 +559,18 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
// Create all the other instructions.
for (Value *Val : Explored) {
-
if (NewInsts.contains(Val))
continue;
- if (auto *CI = dyn_cast<CastInst>(Val)) {
- // Don't get rid of the intermediate variable here; the store can grow
- // the map which will invalidate the reference to the input value.
- Value *V = NewInsts[CI->getOperand(0)];
- NewInsts[CI] = V;
- continue;
- }
if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
- Value *Index = NewInsts[GEP->getOperand(1)] ? NewInsts[GEP->getOperand(1)]
- : GEP->getOperand(1);
setInsertionPoint(Builder, GEP);
- // Indices might need to be sign extended. GEPs will magically do
- // this, but we need to do it ourselves here.
- if (Index->getType()->getScalarSizeInBits() !=
- NewInsts[GEP->getOperand(0)]->getType()->getScalarSizeInBits()) {
- Index = Builder.CreateSExtOrTrunc(
- Index, NewInsts[GEP->getOperand(0)]->getType(),
- GEP->getOperand(0)->getName() + ".sext");
- }
-
- auto *Op = NewInsts[GEP->getOperand(0)];
+ Value *Op = NewInsts[GEP->getOperand(0)];
+ Value *OffsetV = emitGEPOffset(&Builder, DL, GEP);
if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
- NewInsts[GEP] = Index;
+ NewInsts[GEP] = OffsetV;
else
NewInsts[GEP] = Builder.CreateNSWAdd(
- Op, Index, GEP->getOperand(0)->getName() + ".add");
+ Op, OffsetV, GEP->getOperand(0)->getName() + ".add");
continue;
}
if (isa<PHINode>(Val))
@@ -624,23 +598,14 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
}
}
- PointerType *PtrTy =
- ElemTy->getPointerTo(Start->getType()->getPointerAddressSpace());
for (Value *Val : Explored) {
if (Val == Base)
continue;
- // Depending on the type, for external users we have to emit
- // a GEP or a GEP + ptrtoint.
setInsertionPoint(Builder, Val, false);
-
- // Cast base to the expected type.
- Value *NewVal = Builder.CreateBitOrPointerCast(
- Base, PtrTy, Start->getName() + "to.ptr");
- NewVal = Builder.CreateInBoundsGEP(ElemTy, NewVal, ArrayRef(NewInsts[Val]),
- Val->getName() + ".ptr");
- NewVal = Builder.CreateBitOrPointerCast(
- NewVal, Val->getType(), Val->getName() + ".conv");
+ // Create GEP for external users.
+ Value *NewVal = Builder.CreateInBoundsGEP(
+ Builder.getInt8Ty(), Base, NewInsts[Val], Val->getName() + ".ptr");
IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal);
// Add old instruction to worklist for DCE. We don't directly remove it
// here because the original compare is one of the users.
@@ -650,48 +615,6 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
return NewInsts[Start];
}
-/// Looks through GEPs, IntToPtrInsts and PtrToIntInsts in order to express
-/// the input Value as a constant indexed GEP. Returns a pair containing
-/// the GEPs Pointer and Index.
-static std::pair<Value *, Value *>
-getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) {
- Type *IndexType = IntegerType::get(V->getContext(),
- DL.getIndexTypeSizeInBits(V->getType()));
-
- Constant *Index = ConstantInt::getNullValue(IndexType);
- while (true) {
- if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
- // We accept only inbouds GEPs here to exclude the possibility of
- // overflow.
- if (!GEP->isInBounds())
- break;
- if (GEP->hasAllConstantIndices() && GEP->getNumIndices() == 1 &&
- GEP->getSourceElementType() == ElemTy) {
- V = GEP->getOperand(0);
- Constant *GEPIndex = static_cast<Constant *>(GEP->getOperand(1));
- Index = ConstantExpr::getAdd(
- Index, ConstantExpr::getSExtOrTrunc(GEPIndex, IndexType));
- continue;
- }
- break;
- }
- if (auto *CI = dyn_cast<IntToPtrInst>(V)) {
- if (!CI->isNoopCast(DL))
- break;
- V = CI->getOperand(0);
- continue;
- }
- if (auto *CI = dyn_cast<PtrToIntInst>(V)) {
- if (!CI->isNoopCast(DL))
- break;
- V = CI->getOperand(0);
- continue;
- }
- break;
- }
- return {V, Index};
-}
-
/// Converts (CMP GEPLHS, RHS) if this change would make RHS a constant.
/// We can look through PHIs, GEPs and casts in order to determine a common base
/// between GEPLHS and RHS.
@@ -706,14 +629,19 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
if (!GEPLHS->hasAllConstantIndices())
return nullptr;
- Type *ElemTy = GEPLHS->getSourceElementType();
- Value *PtrBase, *Index;
- std::tie(PtrBase, Index) = getAsConstantIndexedAddress(ElemTy, GEPLHS, DL);
+ APInt Offset(DL.getIndexTypeSizeInBits(GEPLHS->getType()), 0);
+ Value *PtrBase =
+ GEPLHS->stripAndAccumulateConstantOffsets(DL, Offset,
+ /*AllowNonInbounds*/ false);
+
+ // Bail if we looked through addrspacecast.
+ if (PtrBase->getType() != GEPLHS->getType())
+ return nullptr;
// The set of nodes that will take part in this transformation.
SetVector<Value *> Nodes;
- if (!canRewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes))
+ if (!canRewriteGEPAsOffset(RHS, PtrBase, DL, Nodes))
return nullptr;
// We know we can re-write this as
@@ -722,13 +650,14 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
// can't have overflow on either side. We can therefore re-write
// this as:
// OFFSET1 cmp OFFSET2
- Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes, IC);
+ Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, DL, Nodes, IC);
// RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written
// GEP having PtrBase as the pointer base, and has returned in NewRHS the
// offset. Since Index is the offset of LHS to the base pointer, we will now
// compare the offsets instead of comparing the pointers.
- return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Index, NewRHS);
+ return new ICmpInst(ICmpInst::getSignedPredicate(Cond),
+ IC.Builder.getInt(Offset), NewRHS);
}
/// Fold comparisons between a GEP instruction and something else. At this point
@@ -844,17 +773,6 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this);
}
- // If one of the GEPs has all zero indices, recurse.
- // FIXME: Handle vector of pointers.
- if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices())
- return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
- ICmpInst::getSwappedPredicate(Cond), I);
-
- // If the other GEP has all zero indices, recurse.
- // FIXME: Handle vector of pointers.
- if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices())
- return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I);
-
bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() &&
GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) {
@@ -894,8 +812,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// Only lower this if the icmp is the only user of the GEP or if we expect
// the result to fold to a constant!
if ((GEPsInBounds || CmpInst::isEquality(Cond)) &&
- (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) &&
- (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) {
+ (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
+ (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse())) {
// ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2)
Value *L = EmitGEPOffset(GEPLHS);
Value *R = EmitGEPOffset(GEPRHS);
@@ -1285,9 +1203,9 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
if (Pred == ICmpInst::ICMP_SGT) {
Value *A, *B;
if (match(Cmp.getOperand(0), m_SMin(m_Value(A), m_Value(B)))) {
- if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT))
+ if (isKnownPositive(A, SQ.getWithInstruction(&Cmp)))
return new ICmpInst(Pred, B, Cmp.getOperand(1));
- if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT))
+ if (isKnownPositive(B, SQ.getWithInstruction(&Cmp)))
return new ICmpInst(Pred, A, Cmp.getOperand(1));
}
}
@@ -1554,6 +1472,61 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
return nullptr;
}
+/// Fold icmp (trunc X), (trunc Y).
+/// Fold icmp (trunc X), (zext Y).
+Instruction *
+InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
+ const SimplifyQuery &Q) {
+ if (Cmp.isSigned())
+ return nullptr;
+
+ Value *X, *Y;
+ ICmpInst::Predicate Pred;
+ bool YIsZext = false;
+ // Try to match icmp (trunc X), (trunc Y)
+ if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) {
+ if (X->getType() != Y->getType() &&
+ (!Cmp.getOperand(0)->hasOneUse() || !Cmp.getOperand(1)->hasOneUse()))
+ return nullptr;
+ if (!isDesirableIntType(X->getType()->getScalarSizeInBits()) &&
+ isDesirableIntType(Y->getType()->getScalarSizeInBits())) {
+ std::swap(X, Y);
+ Pred = Cmp.getSwappedPredicate(Pred);
+ }
+ }
+ // Try to match icmp (trunc X), (zext Y)
+ else if (match(&Cmp, m_c_ICmp(Pred, m_Trunc(m_Value(X)),
+ m_OneUse(m_ZExt(m_Value(Y))))))
+
+ YIsZext = true;
+ else
+ return nullptr;
+
+ Type *TruncTy = Cmp.getOperand(0)->getType();
+ unsigned TruncBits = TruncTy->getScalarSizeInBits();
+
+ // If this transform will end up changing from desirable types -> undesirable
+ // types skip it.
+ if (isDesirableIntType(TruncBits) &&
+ !isDesirableIntType(X->getType()->getScalarSizeInBits()))
+ return nullptr;
+
+ // Check if the trunc is unneeded.
+ KnownBits KnownX = llvm::computeKnownBits(X, /*Depth*/ 0, Q);
+ if (KnownX.countMaxActiveBits() > TruncBits)
+ return nullptr;
+
+ if (!YIsZext) {
+ // If Y is also a trunc, make sure it is unneeded.
+ KnownBits KnownY = llvm::computeKnownBits(Y, /*Depth*/ 0, Q);
+ if (KnownY.countMaxActiveBits() > TruncBits)
+ return nullptr;
+ }
+
+ Value *NewY = Builder.CreateZExtOrTrunc(Y, X->getType());
+ return new ICmpInst(Pred, X, NewY);
+}
+
/// Fold icmp (xor X, Y), C.
Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp,
BinaryOperator *Xor,
@@ -1944,19 +1917,18 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
return nullptr;
}
-/// Fold icmp eq/ne (or (xor (X1, X2), xor(X3, X4))), 0.
-static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or,
- InstCombiner::BuilderTy &Builder) {
- // Are we using xors to bitwise check for a pair or pairs of (in)equalities?
- // Convert to a shorter form that has more potential to be folded even
- // further.
- // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
- // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
- // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) == 0 -->
+/// Fold icmp eq/ne (or (xor/sub (X1, X2), xor/sub (X3, X4))), 0.
+static Value *foldICmpOrXorSubChain(ICmpInst &Cmp, BinaryOperator *Or,
+ InstCombiner::BuilderTy &Builder) {
+ // Are we using xors or subs to bitwise check for a pair or pairs of
+ // (in)equalities? Convert to a shorter form that has more potential to be
+ // folded even further.
+ // ((X1 ^/- X2) || (X3 ^/- X4)) == 0 --> (X1 == X2) && (X3 == X4)
+ // ((X1 ^/- X2) || (X3 ^/- X4)) != 0 --> (X1 != X2) || (X3 != X4)
+ // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) == 0 -->
// (X1 == X2) && (X3 == X4) && (X5 == X6)
- // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) != 0 -->
+ // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) != 0 -->
// (X1 != X2) || (X3 != X4) || (X5 != X6)
- // TODO: Implement for sub
SmallVector<std::pair<Value *, Value *>, 2> CmpValues;
SmallVector<Value *, 16> WorkList(1, Or);
@@ -1967,9 +1939,16 @@ static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or,
if (match(OrOperatorArgument,
m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) {
CmpValues.emplace_back(Lhs, Rhs);
- } else {
- WorkList.push_back(OrOperatorArgument);
+ return;
}
+
+ if (match(OrOperatorArgument,
+ m_OneUse(m_Sub(m_Value(Lhs), m_Value(Rhs))))) {
+ CmpValues.emplace_back(Lhs, Rhs);
+ return;
+ }
+
+ WorkList.push_back(OrOperatorArgument);
};
Value *CurrentValue = WorkList.pop_back_val();
@@ -2082,7 +2061,7 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
return BinaryOperator::Create(BOpc, CmpP, CmpQ);
}
- if (Value *V = foldICmpOrXorChain(Cmp, Or, Builder))
+ if (Value *V = foldICmpOrXorSubChain(Cmp, Or, Builder))
return replaceInstUsesWith(Cmp, V);
return nullptr;
@@ -2443,7 +2422,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
// constant-value-based preconditions in the folds below, then we could assert
// those conditions rather than checking them. This is difficult because of
// undef/poison (PR34838).
- if (IsAShr) {
+ if (IsAShr && Shr->hasOneUse()) {
if (IsExact || Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) {
// When ShAmtC can be shifted losslessly:
// icmp PRED (ashr exact X, ShAmtC), C --> icmp PRED X, (C << ShAmtC)
@@ -2483,7 +2462,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
ConstantInt::getAllOnesValue(ShrTy));
}
}
- } else {
+ } else if (!IsAShr) {
if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) {
// icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC)
// icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC)
@@ -2888,19 +2867,97 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C));
}
+static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0,
+ Value *Op1, IRBuilderBase &Builder,
+ bool HasOneUse) {
+ auto FoldConstant = [&](bool Val) {
+ Constant *Res = Val ? Builder.getTrue() : Builder.getFalse();
+ if (Op0->getType()->isVectorTy())
+ Res = ConstantVector::getSplat(
+ cast<VectorType>(Op0->getType())->getElementCount(), Res);
+ return Res;
+ };
+
+ switch (Table.to_ulong()) {
+ case 0: // 0 0 0 0
+ return FoldConstant(false);
+ case 1: // 0 0 0 1
+ return HasOneUse ? Builder.CreateNot(Builder.CreateOr(Op0, Op1)) : nullptr;
+ case 2: // 0 0 1 0
+ return HasOneUse ? Builder.CreateAnd(Builder.CreateNot(Op0), Op1) : nullptr;
+ case 3: // 0 0 1 1
+ return Builder.CreateNot(Op0);
+ case 4: // 0 1 0 0
+ return HasOneUse ? Builder.CreateAnd(Op0, Builder.CreateNot(Op1)) : nullptr;
+ case 5: // 0 1 0 1
+ return Builder.CreateNot(Op1);
+ case 6: // 0 1 1 0
+ return Builder.CreateXor(Op0, Op1);
+ case 7: // 0 1 1 1
+ return HasOneUse ? Builder.CreateNot(Builder.CreateAnd(Op0, Op1)) : nullptr;
+ case 8: // 1 0 0 0
+ return Builder.CreateAnd(Op0, Op1);
+ case 9: // 1 0 0 1
+ return HasOneUse ? Builder.CreateNot(Builder.CreateXor(Op0, Op1)) : nullptr;
+ case 10: // 1 0 1 0
+ return Op1;
+ case 11: // 1 0 1 1
+ return HasOneUse ? Builder.CreateOr(Builder.CreateNot(Op0), Op1) : nullptr;
+ case 12: // 1 1 0 0
+ return Op0;
+ case 13: // 1 1 0 1
+ return HasOneUse ? Builder.CreateOr(Op0, Builder.CreateNot(Op1)) : nullptr;
+ case 14: // 1 1 1 0
+ return Builder.CreateOr(Op0, Op1);
+ case 15: // 1 1 1 1
+ return FoldConstant(true);
+ default:
+ llvm_unreachable("Invalid Operation");
+ }
+ return nullptr;
+}
+
/// Fold icmp (add X, Y), C.
Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
BinaryOperator *Add,
const APInt &C) {
Value *Y = Add->getOperand(1);
+ Value *X = Add->getOperand(0);
+
+ Value *Op0, *Op1;
+ Instruction *Ext0, *Ext1;
+ const CmpInst::Predicate Pred = Cmp.getPredicate();
+ if (match(Add,
+ m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))),
+ m_CombineAnd(m_Instruction(Ext1),
+ m_ZExtOrSExt(m_Value(Op1))))) &&
+ Op0->getType()->isIntOrIntVectorTy(1) &&
+ Op1->getType()->isIntOrIntVectorTy(1)) {
+ unsigned BW = C.getBitWidth();
+ std::bitset<4> Table;
+ auto ComputeTable = [&](bool Op0Val, bool Op1Val) {
+ int Res = 0;
+ if (Op0Val)
+ Res += isa<ZExtInst>(Ext0) ? 1 : -1;
+ if (Op1Val)
+ Res += isa<ZExtInst>(Ext1) ? 1 : -1;
+ return ICmpInst::compare(APInt(BW, Res, true), C, Pred);
+ };
+
+ Table[0] = ComputeTable(false, false);
+ Table[1] = ComputeTable(false, true);
+ Table[2] = ComputeTable(true, false);
+ Table[3] = ComputeTable(true, true);
+ if (auto *Cond =
+ createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse()))
+ return replaceInstUsesWith(Cmp, Cond);
+ }
const APInt *C2;
if (Cmp.isEquality() || !match(Y, m_APInt(C2)))
return nullptr;
// Fold icmp pred (add X, C2), C.
- Value *X = Add->getOperand(0);
Type *Ty = Add->getType();
- const CmpInst::Predicate Pred = Cmp.getPredicate();
// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
@@ -3172,18 +3229,6 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
}
}
- // Test to see if the operands of the icmp are casted versions of other
- // values. If the ptr->ptr cast can be stripped off both arguments, do so.
- if (DstType->isPointerTy() && (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
- // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast
- // so eliminate it as well.
- if (auto *BC2 = dyn_cast<BitCastInst>(Op1))
- Op1 = BC2->getOperand(0);
-
- Op1 = Builder.CreateBitCast(Op1, SrcType);
- return new ICmpInst(Pred, BCSrcOp, Op1);
- }
-
const APInt *C;
if (!match(Cmp.getOperand(1), m_APInt(C)) || !DstType->isIntegerTy() ||
!SrcType->isIntOrIntVectorTy())
@@ -3196,10 +3241,12 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
// icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0
// Example: are all elements equal? --> are zero elements not equal?
// TODO: Try harder to reduce compare of 2 freely invertible operands?
- if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() &&
- isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) {
- Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), DstType);
- return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType));
+ if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse()) {
+ if (Value *NotBCSrcOp =
+ getFreelyInverted(BCSrcOp, BCSrcOp->hasOneUse(), &Builder)) {
+ Value *Cast = Builder.CreateBitCast(NotBCSrcOp, DstType);
+ return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType));
+ }
}
// If this is checking if all elements of an extended vector are clear or not,
@@ -3878,21 +3925,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
return nullptr;
switch (LHSI->getOpcode()) {
- case Instruction::GetElementPtr:
- // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null
- if (RHSC->isNullValue() &&
- cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices())
- return new ICmpInst(
- I.getPredicate(), LHSI->getOperand(0),
- Constant::getNullValue(LHSI->getOperand(0)->getType()));
- break;
case Instruction::PHI:
- // Only fold icmp into the PHI if the phi and icmp are in the same
- // block. If in the same block, we're encouraging jump threading. If
- // not, we are just pessimizing the code by making an i1 phi.
- if (LHSI->getParent() == I.getParent())
- if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
- return NV;
+ if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
+ return NV;
break;
case Instruction::IntToPtr:
// icmp pred inttoptr(X), null -> icmp pred X, 0
@@ -4243,7 +4278,12 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
/*isNUW=*/false, SQ.getWithInstruction(&I)));
if (!NewShAmt)
return nullptr;
- NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy);
+ if (NewShAmt->getType() != WidestTy) {
+ NewShAmt =
+ ConstantFoldCastOperand(Instruction::ZExt, NewShAmt, WidestTy, SQ.DL);
+ if (!NewShAmt)
+ return nullptr;
+ }
unsigned WidestBitWidth = WidestTy->getScalarSizeInBits();
// Is the new shift amount smaller than the bit width?
@@ -4424,6 +4464,65 @@ static Instruction *foldICmpXNegX(ICmpInst &I,
return nullptr;
}
+static Instruction *foldICmpAndXX(ICmpInst &I, const SimplifyQuery &Q,
+ InstCombinerImpl &IC) {
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;
+ // Normalize and operand as operand 0.
+ CmpInst::Predicate Pred = I.getPredicate();
+ if (match(Op1, m_c_And(m_Specific(Op0), m_Value()))) {
+ std::swap(Op0, Op1);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ if (!match(Op0, m_c_And(m_Specific(Op1), m_Value(A))))
+ return nullptr;
+
+ // (icmp (X & Y) u< X --> (X & Y) != X
+ if (Pred == ICmpInst::ICMP_ULT)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+
+ // (icmp (X & Y) u>= X --> (X & Y) == X
+ if (Pred == ICmpInst::ICMP_UGE)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
+
+ return nullptr;
+}
+
+static Instruction *foldICmpOrXX(ICmpInst &I, const SimplifyQuery &Q,
+ InstCombinerImpl &IC) {
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;
+
+ // Normalize or operand as operand 0.
+ CmpInst::Predicate Pred = I.getPredicate();
+ if (match(Op1, m_c_Or(m_Specific(Op0), m_Value(A)))) {
+ std::swap(Op0, Op1);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ } else if (!match(Op0, m_c_Or(m_Specific(Op1), m_Value(A)))) {
+ return nullptr;
+ }
+
+ // icmp (X | Y) u<= X --> (X | Y) == X
+ if (Pred == ICmpInst::ICMP_ULE)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
+
+ // icmp (X | Y) u> X --> (X | Y) != X
+ if (Pred == ICmpInst::ICMP_UGT)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+
+ if (ICmpInst::isEquality(Pred) && Op0->hasOneUse()) {
+ // icmp (X | Y) eq/ne Y --> (X & ~Y) eq/ne 0 if Y is freely invertible
+ if (Value *NotOp1 =
+ IC.getFreelyInverted(Op1, Op1->hasOneUse(), &IC.Builder))
+ return new ICmpInst(Pred, IC.Builder.CreateAnd(A, NotOp1),
+ Constant::getNullValue(Op1->getType()));
+ // icmp (X | Y) eq/ne Y --> (~X | Y) eq/ne -1 if X is freely invertible.
+ if (Value *NotA = IC.getFreelyInverted(A, A->hasOneUse(), &IC.Builder))
+ return new ICmpInst(Pred, IC.Builder.CreateOr(Op1, NotA),
+ Constant::getAllOnesValue(Op1->getType()));
+ }
+ return nullptr;
+}
+
static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
InstCombinerImpl &IC) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;
@@ -4746,6 +4845,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
if (Instruction * R = foldICmpXorXX(I, Q, *this))
return R;
+ if (Instruction *R = foldICmpOrXX(I, Q, *this))
+ return R;
{
// Try to remove shared multiplier from comparison:
@@ -4915,6 +5016,9 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
return replaceInstUsesWith(I, V);
+ if (Instruction *R = foldICmpAndXX(I, Q, *this))
+ return R;
+
if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder))
return replaceInstUsesWith(I, V);
@@ -4924,88 +5028,153 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
return nullptr;
}
-/// Fold icmp Pred min|max(X, Y), X.
-static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) {
- ICmpInst::Predicate Pred = Cmp.getPredicate();
- Value *Op0 = Cmp.getOperand(0);
- Value *X = Cmp.getOperand(1);
-
- // Canonicalize minimum or maximum operand to LHS of the icmp.
- if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) ||
- match(X, m_c_SMax(m_Specific(Op0), m_Value())) ||
- match(X, m_c_UMin(m_Specific(Op0), m_Value())) ||
- match(X, m_c_UMax(m_Specific(Op0), m_Value()))) {
- std::swap(Op0, X);
- Pred = Cmp.getSwappedPredicate();
- }
-
- Value *Y;
- if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) {
- // smin(X, Y) == X --> X s<= Y
- // smin(X, Y) s>= X --> X s<= Y
- if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE)
- return new ICmpInst(ICmpInst::ICMP_SLE, X, Y);
-
- // smin(X, Y) != X --> X s> Y
- // smin(X, Y) s< X --> X s> Y
- if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT)
- return new ICmpInst(ICmpInst::ICMP_SGT, X, Y);
-
- // These cases should be handled in InstSimplify:
- // smin(X, Y) s<= X --> true
- // smin(X, Y) s> X --> false
+/// Fold icmp Pred min|max(X, Y), Z.
+Instruction *
+InstCombinerImpl::foldICmpWithMinMaxImpl(Instruction &I,
+ MinMaxIntrinsic *MinMax, Value *Z,
+ ICmpInst::Predicate Pred) {
+ Value *X = MinMax->getLHS();
+ Value *Y = MinMax->getRHS();
+ if (ICmpInst::isSigned(Pred) && !MinMax->isSigned())
return nullptr;
- }
-
- if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) {
- // smax(X, Y) == X --> X s>= Y
- // smax(X, Y) s<= X --> X s>= Y
- if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE)
- return new ICmpInst(ICmpInst::ICMP_SGE, X, Y);
-
- // smax(X, Y) != X --> X s< Y
- // smax(X, Y) s> X --> X s< Y
- if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT)
- return new ICmpInst(ICmpInst::ICMP_SLT, X, Y);
-
- // These cases should be handled in InstSimplify:
- // smax(X, Y) s>= X --> true
- // smax(X, Y) s< X --> false
+ if (ICmpInst::isUnsigned(Pred) && MinMax->isSigned())
return nullptr;
+ SimplifyQuery Q = SQ.getWithInstruction(&I);
+ auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> {
+ if (!Val)
+ return std::nullopt;
+ if (match(Val, m_One()))
+ return true;
+ if (match(Val, m_Zero()))
+ return false;
+ return std::nullopt;
+ };
+ auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(Pred, X, Z, Q));
+ auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(Pred, Y, Z, Q));
+ if (!CmpXZ.has_value() && !CmpYZ.has_value())
+ return nullptr;
+ if (!CmpXZ.has_value()) {
+ std::swap(X, Y);
+ std::swap(CmpXZ, CmpYZ);
}
- if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) {
- // umin(X, Y) == X --> X u<= Y
- // umin(X, Y) u>= X --> X u<= Y
- if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE)
- return new ICmpInst(ICmpInst::ICMP_ULE, X, Y);
-
- // umin(X, Y) != X --> X u> Y
- // umin(X, Y) u< X --> X u> Y
- if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT)
- return new ICmpInst(ICmpInst::ICMP_UGT, X, Y);
+ auto FoldIntoCmpYZ = [&]() -> Instruction * {
+ if (CmpYZ.has_value())
+ return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *CmpYZ));
+ return ICmpInst::Create(Instruction::ICmp, Pred, Y, Z);
+ };
- // These cases should be handled in InstSimplify:
- // umin(X, Y) u<= X --> true
- // umin(X, Y) u> X --> false
- return nullptr;
+ switch (Pred) {
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE: {
+ // If X == Z:
+ // Expr Result
+ // min(X, Y) == Z X <= Y
+ // max(X, Y) == Z X >= Y
+ // min(X, Y) != Z X > Y
+ // max(X, Y) != Z X < Y
+ if ((Pred == ICmpInst::ICMP_EQ) == *CmpXZ) {
+ ICmpInst::Predicate NewPred =
+ ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
+ if (Pred == ICmpInst::ICMP_NE)
+ NewPred = ICmpInst::getInversePredicate(NewPred);
+ return ICmpInst::Create(Instruction::ICmp, NewPred, X, Y);
+ }
+ // Otherwise (X != Z):
+ ICmpInst::Predicate NewPred = MinMax->getPredicate();
+ auto MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q));
+ if (!MinMaxCmpXZ.has_value()) {
+ std::swap(X, Y);
+ std::swap(CmpXZ, CmpYZ);
+ // Re-check pre-condition X != Z
+ if (!CmpXZ.has_value() || (Pred == ICmpInst::ICMP_EQ) == *CmpXZ)
+ break;
+ MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q));
+ }
+ if (!MinMaxCmpXZ.has_value())
+ break;
+ if (*MinMaxCmpXZ) {
+ // Expr Fact Result
+ // min(X, Y) == Z X < Z false
+ // max(X, Y) == Z X > Z false
+ // min(X, Y) != Z X < Z true
+ // max(X, Y) != Z X > Z true
+ return replaceInstUsesWith(
+ I, ConstantInt::getBool(I.getType(), Pred == ICmpInst::ICMP_NE));
+ } else {
+ // Expr Fact Result
+ // min(X, Y) == Z X > Z Y == Z
+ // max(X, Y) == Z X < Z Y == Z
+ // min(X, Y) != Z X > Z Y != Z
+ // max(X, Y) != Z X < Z Y != Z
+ return FoldIntoCmpYZ();
+ }
+ break;
+ }
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_SLE:
+ case ICmpInst::ICMP_ULE:
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_SGE:
+ case ICmpInst::ICMP_UGE: {
+ bool IsSame = MinMax->getPredicate() == ICmpInst::getStrictPredicate(Pred);
+ if (*CmpXZ) {
+ if (IsSame) {
+ // Expr Fact Result
+ // min(X, Y) < Z X < Z true
+ // min(X, Y) <= Z X <= Z true
+ // max(X, Y) > Z X > Z true
+ // max(X, Y) >= Z X >= Z true
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ } else {
+ // Expr Fact Result
+ // max(X, Y) < Z X < Z Y < Z
+ // max(X, Y) <= Z X <= Z Y <= Z
+ // min(X, Y) > Z X > Z Y > Z
+ // min(X, Y) >= Z X >= Z Y >= Z
+ return FoldIntoCmpYZ();
+ }
+ } else {
+ if (IsSame) {
+ // Expr Fact Result
+ // min(X, Y) < Z X >= Z Y < Z
+ // min(X, Y) <= Z X > Z Y <= Z
+ // max(X, Y) > Z X <= Z Y > Z
+ // max(X, Y) >= Z X < Z Y >= Z
+ return FoldIntoCmpYZ();
+ } else {
+ // Expr Fact Result
+ // max(X, Y) < Z X >= Z false
+ // max(X, Y) <= Z X > Z false
+ // min(X, Y) > Z X <= Z false
+ // min(X, Y) >= Z X < Z false
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ }
+ }
+ break;
+ }
+ default:
+ break;
}
- if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) {
- // umax(X, Y) == X --> X u>= Y
- // umax(X, Y) u<= X --> X u>= Y
- if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE)
- return new ICmpInst(ICmpInst::ICMP_UGE, X, Y);
+ return nullptr;
+}
+Instruction *InstCombinerImpl::foldICmpWithMinMax(ICmpInst &Cmp) {
+ ICmpInst::Predicate Pred = Cmp.getPredicate();
+ Value *Lhs = Cmp.getOperand(0);
+ Value *Rhs = Cmp.getOperand(1);
- // umax(X, Y) != X --> X u< Y
- // umax(X, Y) u> X --> X u< Y
- if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT)
- return new ICmpInst(ICmpInst::ICMP_ULT, X, Y);
+ if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Lhs)) {
+ if (Instruction *Res = foldICmpWithMinMaxImpl(Cmp, MinMax, Rhs, Pred))
+ return Res;
+ }
- // These cases should be handled in InstSimplify:
- // umax(X, Y) u>= X --> true
- // umax(X, Y) u< X --> false
- return nullptr;
+ if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Rhs)) {
+ if (Instruction *Res = foldICmpWithMinMaxImpl(
+ Cmp, MinMax, Lhs, ICmpInst::getSwappedPredicate(Pred)))
+ return Res;
}
return nullptr;
@@ -5173,35 +5342,6 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
}
- // Test if 2 values have different or same signbits:
- // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
- // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
- // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0
- // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1
- Instruction *ExtI;
- if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) &&
- (Op0->hasOneUse() || Op1->hasOneUse())) {
- unsigned OpWidth = Op0->getType()->getScalarSizeInBits();
- Instruction *ShiftI;
- Value *X, *Y;
- ICmpInst::Predicate Pred2;
- if (match(Op0, m_CombineAnd(m_Instruction(ShiftI),
- m_Shr(m_Value(X),
- m_SpecificIntAllowUndef(OpWidth - 1)))) &&
- match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) &&
- Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) {
- unsigned ExtOpc = ExtI->getOpcode();
- unsigned ShiftOpc = ShiftI->getOpcode();
- if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) ||
- (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) {
- Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
- Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor)
- : Builder.CreateIsNotNeg(Xor);
- return replaceInstUsesWith(I, R);
- }
- }
- }
-
// (A >> C) == (B >> C) --> (A^B) u< (1 << C)
// For lshr and ashr pairs.
const APInt *AP1, *AP2;
@@ -5307,6 +5447,40 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
Pred, A,
Builder.CreateIntrinsic(Op0->getType(), Intrinsic::fshl, {A, A, B}));
+ // Canonicalize:
+ // icmp eq/ne OneUse(A ^ Cst), B --> icmp eq/ne (A ^ B), Cst
+ Constant *Cst;
+ if (match(&I, m_c_ICmp(PredUnused,
+ m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))),
+ m_CombineAnd(m_Value(B), m_Unless(m_ImmConstant())))))
+ return new ICmpInst(Pred, Builder.CreateXor(A, B), Cst);
+
+ {
+ // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2)
+ auto m_Matcher =
+ m_CombineOr(m_CombineOr(m_c_Add(m_Value(B), m_Deferred(A)),
+ m_c_Xor(m_Value(B), m_Deferred(A))),
+ m_Sub(m_Value(B), m_Deferred(A)));
+ std::optional<bool> IsZero = std::nullopt;
+ if (match(&I, m_c_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)),
+ m_Deferred(A))))
+ IsZero = false;
+ // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0)
+ else if (match(&I,
+ m_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)),
+ m_Zero())))
+ IsZero = true;
+
+ if (IsZero && isKnownToBeAPowerOfTwo(A, /* OrZero */ true, /*Depth*/ 0, &I))
+ // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2)
+ // -> (icmp eq/ne (and X, P2), 0)
+ // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0)
+ // -> (icmp eq/ne (and X, P2), P2)
+ return new ICmpInst(Pred, Builder.CreateAnd(B, A),
+ *IsZero ? A
+ : ConstantInt::getNullValue(A->getType()));
+ }
+
return nullptr;
}
@@ -5383,8 +5557,8 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
// icmp Pred (ext X), (ext Y)
Value *Y;
if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) {
- bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0));
- bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1));
+ bool IsZext0 = isa<ZExtInst>(ICmp.getOperand(0));
+ bool IsZext1 = isa<ZExtInst>(ICmp.getOperand(1));
if (IsZext0 != IsZext1) {
// If X and Y and both i1
@@ -5396,11 +5570,16 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
return new ICmpInst(ICmp.getPredicate(), Builder.CreateOr(X, Y),
Constant::getNullValue(X->getType()));
- // If we have mismatched casts, treat the zext of a non-negative source as
- // a sext to simulate matching casts. Otherwise, we are done.
- // TODO: Can we handle some predicates (equality) without non-negative?
- if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) ||
- (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT)))
+ // If we have mismatched casts and zext has the nneg flag, we can
+ // treat the "zext nneg" as "sext". Otherwise, we cannot fold and quit.
+
+ auto *NonNegInst0 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(0));
+ auto *NonNegInst1 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(1));
+
+ bool IsNonNeg0 = NonNegInst0 && NonNegInst0->hasNonNeg();
+ bool IsNonNeg1 = NonNegInst1 && NonNegInst1->hasNonNeg();
+
+ if ((IsZext0 && IsNonNeg0) || (IsZext1 && IsNonNeg1))
IsSignedExt = true;
else
return nullptr;
@@ -5442,25 +5621,20 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
if (!C)
return nullptr;
- // Compute the constant that would happen if we truncated to SrcTy then
- // re-extended to DestTy.
+ // If a lossless truncate is possible...
Type *SrcTy = CastOp0->getSrcTy();
- Type *DestTy = CastOp0->getDestTy();
- Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy);
- Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy);
-
- // If the re-extended constant didn't change...
- if (Res2 == C) {
+ Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode());
+ if (Res) {
if (ICmp.isEquality())
- return new ICmpInst(ICmp.getPredicate(), X, Res1);
+ return new ICmpInst(ICmp.getPredicate(), X, Res);
// A signed comparison of sign extended values simplifies into a
// signed comparison.
if (IsSignedExt && IsSignedCmp)
- return new ICmpInst(ICmp.getPredicate(), X, Res1);
+ return new ICmpInst(ICmp.getPredicate(), X, Res);
// The other three cases all fold into an unsigned comparison.
- return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1);
+ return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res);
}
// The re-extended constant changed, partly changed (in the case of a vector),
@@ -5518,13 +5692,8 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
Value *NewOp1 = nullptr;
if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) {
Value *PtrSrc = PtrToIntOp1->getOperand(0);
- if (PtrSrc->getType()->getPointerAddressSpace() ==
- Op0Src->getType()->getPointerAddressSpace()) {
+ if (PtrSrc->getType() == Op0Src->getType())
NewOp1 = PtrToIntOp1->getOperand(0);
- // If the pointer types don't match, insert a bitcast.
- if (Op0Src->getType() != NewOp1->getType())
- NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType());
- }
} else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) {
NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy);
}
@@ -5641,22 +5810,20 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
/// \returns Instruction which must replace the compare instruction, NULL if no
/// replacement required.
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
- Value *OtherVal,
+ const APInt *OtherVal,
InstCombinerImpl &IC) {
// Don't bother doing this transformation for pointers, don't do it for
// vectors.
if (!isa<IntegerType>(MulVal->getType()))
return nullptr;
- assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
- assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
auto *MulInstr = dyn_cast<Instruction>(MulVal);
if (!MulInstr)
return nullptr;
assert(MulInstr->getOpcode() == Instruction::Mul);
- auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)),
- *RHS = cast<ZExtOperator>(MulInstr->getOperand(1));
+ auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)),
+ *RHS = cast<ZExtInst>(MulInstr->getOperand(1));
assert(LHS->getOpcode() == Instruction::ZExt);
assert(RHS->getOpcode() == Instruction::ZExt);
Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
@@ -5709,70 +5876,26 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
// Recognize patterns
switch (I.getPredicate()) {
- case ICmpInst::ICMP_EQ:
- case ICmpInst::ICMP_NE:
- // Recognize pattern:
- // mulval = mul(zext A, zext B)
- // cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
- ConstantInt *CI;
- Value *ValToMask;
- if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
- if (ValToMask != MulVal)
- return nullptr;
- const APInt &CVal = CI->getValue() + 1;
- if (CVal.isPowerOf2()) {
- unsigned MaskWidth = CVal.logBase2();
- if (MaskWidth == MulWidth)
- break; // Recognized
- }
- }
- return nullptr;
-
- case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGT: {
// Recognize pattern:
// mulval = mul(zext A, zext B)
// cmp ugt mulval, max
- if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
- APInt MaxVal = APInt::getMaxValue(MulWidth);
- MaxVal = MaxVal.zext(CI->getBitWidth());
- if (MaxVal.eq(CI->getValue()))
- break; // Recognized
- }
- return nullptr;
-
- case ICmpInst::ICMP_UGE:
- // Recognize pattern:
- // mulval = mul(zext A, zext B)
- // cmp uge mulval, max+1
- if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
- APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
- if (MaxVal.eq(CI->getValue()))
- break; // Recognized
- }
- return nullptr;
-
- case ICmpInst::ICMP_ULE:
- // Recognize pattern:
- // mulval = mul(zext A, zext B)
- // cmp ule mulval, max
- if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
- APInt MaxVal = APInt::getMaxValue(MulWidth);
- MaxVal = MaxVal.zext(CI->getBitWidth());
- if (MaxVal.eq(CI->getValue()))
- break; // Recognized
- }
+ APInt MaxVal = APInt::getMaxValue(MulWidth);
+ MaxVal = MaxVal.zext(OtherVal->getBitWidth());
+ if (MaxVal.eq(*OtherVal))
+ break; // Recognized
return nullptr;
+ }
- case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULT: {
// Recognize pattern:
// mulval = mul(zext A, zext B)
// cmp ule mulval, max + 1
- if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
- APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
- if (MaxVal.eq(CI->getValue()))
- break; // Recognized
- }
+ APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
+ if (MaxVal.eq(*OtherVal))
+ break; // Recognized
return nullptr;
+ }
default:
return nullptr;
@@ -5798,7 +5921,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
if (MulVal->hasNUsesOrMore(2)) {
Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
for (User *U : make_early_inc_range(MulVal->users())) {
- if (U == &I || U == OtherVal)
+ if (U == &I)
continue;
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
@@ -5819,34 +5942,10 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
IC.addToWorklist(cast<Instruction>(U));
}
}
- if (isa<Instruction>(OtherVal))
- IC.addToWorklist(cast<Instruction>(OtherVal));
// The original icmp gets replaced with the overflow value, maybe inverted
// depending on predicate.
- bool Inverse = false;
- switch (I.getPredicate()) {
- case ICmpInst::ICMP_NE:
- break;
- case ICmpInst::ICMP_EQ:
- Inverse = true;
- break;
- case ICmpInst::ICMP_UGT:
- case ICmpInst::ICMP_UGE:
- if (I.getOperand(0) == MulVal)
- break;
- Inverse = true;
- break;
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_ULE:
- if (I.getOperand(1) == MulVal)
- break;
- Inverse = true;
- break;
- default:
- llvm_unreachable("Unexpected predicate");
- }
- if (Inverse) {
+ if (I.getPredicate() == ICmpInst::ICMP_ULT) {
Value *Res = Builder.CreateExtractValue(Call, 1);
return BinaryOperator::CreateNot(Res);
}
@@ -6015,13 +6114,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
KnownBits Op0Known(BitWidth);
KnownBits Op1Known(BitWidth);
- if (SimplifyDemandedBits(&I, 0,
- getDemandedBitsLHSMask(I, BitWidth),
- Op0Known, 0))
- return &I;
+ {
+ // Don't use dominating conditions when folding icmp using known bits. This
+ // may convert signed into unsigned predicates in ways that other passes
+ // (especially IndVarSimplify) may not be able to reliably undo.
+ SQ.DC = nullptr;
+ auto _ = make_scope_exit([&]() { SQ.DC = &DC; });
+ if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth),
+ Op0Known, 0))
+ return &I;
- if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
- return &I;
+ if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
+ return &I;
+ }
// Given the known and unknown bits, compute a range that the LHS could be
// in. Compute the Min, Max and RHS values based on the known bits. For the
@@ -6269,57 +6374,70 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE)
return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);
+ // icmp eq/ne X, (zext/sext (icmp eq/ne X, C))
+ ICmpInst::Predicate Pred1, Pred2;
const APInt *C;
- if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
- match(I.getOperand(1), m_APInt(C)) &&
- X->getType()->isIntOrIntVectorTy(1) &&
- Y->getType()->isIntOrIntVectorTy(1)) {
- unsigned BitWidth = C->getBitWidth();
- Pred = I.getPredicate();
- APInt Zero = APInt::getZero(BitWidth);
- APInt MinusOne = APInt::getAllOnes(BitWidth);
- APInt One(BitWidth, 1);
- if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) ||
- (C->slt(Zero) && Pred == ICmpInst::ICMP_SLT))
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) ||
- (C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT))
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-
- if (I.getOperand(0)->hasOneUse()) {
- APInt NewC = *C;
- // canonicalize predicate to eq/ne
- if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) ||
- (*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) {
- // x s< 0 in [-1, 1] --> x == -1
- // x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1
- NewC = MinusOne;
- Pred = ICmpInst::ICMP_EQ;
- } else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) ||
- (*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) {
- // x s> -1 in [-1, 1] --> x != -1
- // x u< -1 in [-1, 1] --> x != -1
- Pred = ICmpInst::ICMP_NE;
- } else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) {
- // x s> 0 in [-1, 1] --> x == 1
- NewC = One;
- Pred = ICmpInst::ICMP_EQ;
- } else if (*C == One && Pred == ICmpInst::ICMP_SLT) {
- // x s< 1 in [-1, 1] --> x != 1
- Pred = ICmpInst::ICMP_NE;
+ Instruction *ExtI;
+ if (match(&I, m_c_ICmp(Pred1, m_Value(X),
+ m_CombineAnd(m_Instruction(ExtI),
+ m_ZExtOrSExt(m_ICmp(Pred2, m_Deferred(X),
+ m_APInt(C)))))) &&
+ ICmpInst::isEquality(Pred1) && ICmpInst::isEquality(Pred2)) {
+ bool IsSExt = ExtI->getOpcode() == Instruction::SExt;
+ bool HasOneUse = ExtI->hasOneUse() && ExtI->getOperand(0)->hasOneUse();
+ auto CreateRangeCheck = [&] {
+ Value *CmpV1 =
+ Builder.CreateICmp(Pred1, X, Constant::getNullValue(X->getType()));
+ Value *CmpV2 = Builder.CreateICmp(
+ Pred1, X, ConstantInt::getSigned(X->getType(), IsSExt ? -1 : 1));
+ return BinaryOperator::Create(
+ Pred1 == ICmpInst::ICMP_EQ ? Instruction::Or : Instruction::And,
+ CmpV1, CmpV2);
+ };
+ if (C->isZero()) {
+ if (Pred2 == ICmpInst::ICMP_EQ) {
+ // icmp eq X, (zext/sext (icmp eq X, 0)) --> false
+ // icmp ne X, (zext/sext (icmp eq X, 0)) --> true
+ return replaceInstUsesWith(
+ I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE));
+ } else if (!IsSExt || HasOneUse) {
+ // icmp eq X, (zext (icmp ne X, 0)) --> X == 0 || X == 1
+ // icmp ne X, (zext (icmp ne X, 0)) --> X != 0 && X != 1
+ // icmp eq X, (sext (icmp ne X, 0)) --> X == 0 || X == -1
+ // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X == -1
+ return CreateRangeCheck();
}
-
- if (NewC == MinusOne) {
- if (Pred == ICmpInst::ICMP_EQ)
- return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y);
- if (Pred == ICmpInst::ICMP_NE)
- return BinaryOperator::CreateOr(X, Builder.CreateNot(Y));
- } else if (NewC == One) {
- if (Pred == ICmpInst::ICMP_EQ)
- return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y));
- if (Pred == ICmpInst::ICMP_NE)
- return BinaryOperator::CreateOr(Builder.CreateNot(X), Y);
+ } else if (IsSExt ? C->isAllOnes() : C->isOne()) {
+ if (Pred2 == ICmpInst::ICMP_NE) {
+ // icmp eq X, (zext (icmp ne X, 1)) --> false
+ // icmp ne X, (zext (icmp ne X, 1)) --> true
+ // icmp eq X, (sext (icmp ne X, -1)) --> false
+ // icmp ne X, (sext (icmp ne X, -1)) --> true
+ return replaceInstUsesWith(
+ I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE));
+ } else if (!IsSExt || HasOneUse) {
+ // icmp eq X, (zext (icmp eq X, 1)) --> X == 0 || X == 1
+ // icmp ne X, (zext (icmp eq X, 1)) --> X != 0 && X != 1
+ // icmp eq X, (sext (icmp eq X, -1)) --> X == 0 || X == -1
+ // icmp ne X, (sext (icmp eq X, -1)) --> X != 0 && X == -1
+ return CreateRangeCheck();
}
+ } else {
+ // when C != 0 && C != 1:
+ // icmp eq X, (zext (icmp eq X, C)) --> icmp eq X, 0
+ // icmp eq X, (zext (icmp ne X, C)) --> icmp eq X, 1
+ // icmp ne X, (zext (icmp eq X, C)) --> icmp ne X, 0
+ // icmp ne X, (zext (icmp ne X, C)) --> icmp ne X, 1
+ // when C != 0 && C != -1:
+ // icmp eq X, (sext (icmp eq X, C)) --> icmp eq X, 0
+ // icmp eq X, (sext (icmp ne X, C)) --> icmp eq X, -1
+ // icmp ne X, (sext (icmp eq X, C)) --> icmp ne X, 0
+ // icmp ne X, (sext (icmp ne X, C)) --> icmp ne X, -1
+ return ICmpInst::Create(
+ Instruction::ICmp, Pred1, X,
+ ConstantInt::getSigned(X->getType(), Pred2 == ICmpInst::ICMP_NE
+ ? (IsSExt ? -1 : 1)
+ : 0));
}
}
@@ -6783,6 +6901,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpUsingKnownBits(I))
return Res;
+ if (Instruction *Res = foldICmpTruncWithTruncOrExt(I, Q))
+ return Res;
+
// Test if the ICmpInst instruction is used exclusively by a select as
// part of a minimum or maximum operation. If so, refrain from doing
// any other folding. This helps out other analyses which understand
@@ -6913,38 +7034,40 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
return Res;
{
- Value *A, *B;
- // Transform (A & ~B) == 0 --> (A & B) != 0
- // and (A & ~B) != 0 --> (A & B) == 0
+ Value *X, *Y;
+ // Transform (X & ~Y) == 0 --> (X & Y) != 0
+ // and (X & ~Y) != 0 --> (X & Y) == 0
// if A is a power of 2.
- if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
- match(Op1, m_Zero()) &&
- isKnownToBeAPowerOfTwo(A, false, 0, &I) && I.isEquality())
- return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(A, B),
+ if (match(Op0, m_And(m_Value(X), m_Not(m_Value(Y)))) &&
+ match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(X, false, 0, &I) &&
+ I.isEquality())
+ return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(X, Y),
Op1);
- // ~X < ~Y --> Y < X
- // ~X < C --> X > ~C
- if (match(Op0, m_Not(m_Value(A)))) {
- if (match(Op1, m_Not(m_Value(B))))
- return new ICmpInst(I.getPredicate(), B, A);
-
- const APInt *C;
- if (match(Op1, m_APInt(C)))
- return new ICmpInst(I.getSwappedPredicate(), A,
- ConstantInt::get(Op1->getType(), ~(*C)));
+ // Op0 pred Op1 -> ~Op1 pred ~Op0, if this allows us to drop an instruction.
+ if (Op0->getType()->isIntOrIntVectorTy()) {
+ bool ConsumesOp0, ConsumesOp1;
+ if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) &&
+ isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) &&
+ (ConsumesOp0 || ConsumesOp1)) {
+ Value *InvOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder);
+ Value *InvOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder);
+ assert(InvOp0 && InvOp1 &&
+ "Mismatch between isFreeToInvert and getFreelyInverted");
+ return new ICmpInst(I.getSwappedPredicate(), InvOp0, InvOp1);
+ }
}
Instruction *AddI = nullptr;
- if (match(&I, m_UAddWithOverflow(m_Value(A), m_Value(B),
+ if (match(&I, m_UAddWithOverflow(m_Value(X), m_Value(Y),
m_Instruction(AddI))) &&
- isa<IntegerType>(A->getType())) {
+ isa<IntegerType>(X->getType())) {
Value *Result;
Constant *Overflow;
// m_UAddWithOverflow can match patterns that do not include an explicit
// "add" instruction, so check the opcode of the matched op.
if (AddI->getOpcode() == Instruction::Add &&
- OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, A, B, *AddI,
+ OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, X, Y, *AddI,
Result, Overflow)) {
replaceInstUsesWith(*AddI, Result);
eraseInstFromFunction(*AddI);
@@ -6952,14 +7075,37 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
}
}
- // (zext a) * (zext b) --> llvm.umul.with.overflow.
- if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
- if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this))
+ // (zext X) * (zext Y) --> llvm.umul.with.overflow.
+ if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
+ match(Op1, m_APInt(C))) {
+ if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this))
return R;
}
- if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
- if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this))
- return R;
+
+ // Signbit test folds
+ // Fold (X u>> BitWidth - 1 Pred ZExt(i1)) --> X s< 0 Pred i1
+ // Fold (X s>> BitWidth - 1 Pred SExt(i1)) --> X s< 0 Pred i1
+ Instruction *ExtI;
+ if ((I.isUnsigned() || I.isEquality()) &&
+ match(Op1,
+ m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(Y)))) &&
+ Y->getType()->getScalarSizeInBits() == 1 &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ unsigned OpWidth = Op0->getType()->getScalarSizeInBits();
+ Instruction *ShiftI;
+ if (match(Op0, m_CombineAnd(m_Instruction(ShiftI),
+ m_Shr(m_Value(X), m_SpecificIntAllowUndef(
+ OpWidth - 1))))) {
+ unsigned ExtOpc = ExtI->getOpcode();
+ unsigned ShiftOpc = ShiftI->getOpcode();
+ if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) ||
+ (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) {
+ Value *SLTZero =
+ Builder.CreateICmpSLT(X, Constant::getNullValue(X->getType()));
+ Value *Cmp = Builder.CreateICmp(Pred, SLTZero, Y, I.getName());
+ return replaceInstUsesWith(I, Cmp);
+ }
+ }
}
}
@@ -7177,17 +7323,14 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
}
// Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
- // [0, UMAX], but it may still be fractional. See if it is fractional by
- // casting the FP value to the integer value and back, checking for equality.
+ // [0, UMAX], but it may still be fractional. Check whether this is the case
+ // using the IsExact flag.
// Don't do this for zero, because -0.0 is not fractional.
- Constant *RHSInt = LHSUnsigned
- ? ConstantExpr::getFPToUI(RHSC, IntTy)
- : ConstantExpr::getFPToSI(RHSC, IntTy);
+ APSInt RHSInt(IntWidth, LHSUnsigned);
+ bool IsExact;
+ RHS.convertToInteger(RHSInt, APFloat::rmTowardZero, &IsExact);
if (!RHS.isZero()) {
- bool Equal = LHSUnsigned
- ? ConstantExpr::getUIToFP(RHSInt, RHSC->getType()) == RHSC
- : ConstantExpr::getSIToFP(RHSInt, RHSC->getType()) == RHSC;
- if (!Equal) {
+ if (!IsExact) {
// If we had a comparison against a fractional value, we have to adjust
// the compare predicate and sometimes the value. RHSC is rounded towards
// zero at this point.
@@ -7253,7 +7396,7 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
// Lower this FP comparison into an appropriate integer version of the
// comparison.
- return new ICmpInst(Pred, LHSI->getOperand(0), RHSInt);
+ return new ICmpInst(Pred, LHSI->getOperand(0), Builder.getInt(RHSInt));
}
/// Fold (C / X) < 0.0 --> X < 0.0 if possible. Swap predicate if necessary.
@@ -7532,12 +7675,8 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) {
switch (LHSI->getOpcode()) {
case Instruction::PHI:
- // Only fold fcmp into the PHI if the phi and fcmp are in the same
- // block. If in the same block, we're encouraging jump threading. If
- // not, we are just pessimizing the code by making an i1 phi.
- if (LHSI->getParent() == I.getParent())
- if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
- return NV;
+ if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
+ return NV;
break;
case Instruction::SIToFP:
case Instruction::UIToFP:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 701579e1de48..bb620ad8d41c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -16,6 +16,7 @@
#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H
#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -73,6 +74,10 @@ public:
virtual ~InstCombinerImpl() = default;
+ /// Perform early cleanup and prepare the InstCombine worklist.
+ bool prepareWorklist(Function &F,
+ ReversePostOrderTraversal<BasicBlock *> &RPOT);
+
/// Run the combiner over the entire worklist until it is empty.
///
/// \returns true if the IR is changed.
@@ -93,6 +98,7 @@ public:
Instruction *visitSub(BinaryOperator &I);
Instruction *visitFSub(BinaryOperator &I);
Instruction *visitMul(BinaryOperator &I);
+ Instruction *foldFMulReassoc(BinaryOperator &I);
Instruction *visitFMul(BinaryOperator &I);
Instruction *visitURem(BinaryOperator &I);
Instruction *visitSRem(BinaryOperator &I);
@@ -126,7 +132,6 @@ public:
Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I);
Instruction *commonCastTransforms(CastInst &CI);
- Instruction *commonPointerCastTransforms(CastInst &CI);
Instruction *visitTrunc(TruncInst &CI);
Instruction *visitZExt(ZExtInst &Zext);
Instruction *visitSExt(SExtInst &Sext);
@@ -193,6 +198,44 @@ public:
LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy,
const Twine &Suffix = "");
+ KnownFPClass computeKnownFPClass(Value *Val, FastMathFlags FMF,
+ FPClassTest Interested = fcAllFlags,
+ const Instruction *CtxI = nullptr,
+ unsigned Depth = 0) const {
+ return llvm::computeKnownFPClass(Val, FMF, DL, Interested, Depth, &TLI, &AC,
+ CtxI, &DT);
+ }
+
+ KnownFPClass computeKnownFPClass(Value *Val,
+ FPClassTest Interested = fcAllFlags,
+ const Instruction *CtxI = nullptr,
+ unsigned Depth = 0) const {
+ return llvm::computeKnownFPClass(Val, DL, Interested, Depth, &TLI, &AC,
+ CtxI, &DT);
+ }
+
+ /// Check if fmul \p MulVal, +0.0 will yield +0.0 (or signed zero is
+ /// ignorable).
+ bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
+ const Instruction *CtxI) const;
+
+ Constant *getLosslessTrunc(Constant *C, Type *TruncTy, unsigned ExtOp) {
+ Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy);
+ Constant *ExtTruncC =
+ ConstantFoldCastOperand(ExtOp, TruncC, C->getType(), DL);
+ if (ExtTruncC && ExtTruncC == C)
+ return TruncC;
+ return nullptr;
+ }
+
+ Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) {
+ return getLosslessTrunc(C, TruncTy, Instruction::ZExt);
+ }
+
+ Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) {
+ return getLosslessTrunc(C, TruncTy, Instruction::SExt);
+ }
+
private:
bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
bool isDesirableIntType(unsigned BitWidth) const;
@@ -252,13 +295,15 @@ private:
Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext);
- bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS,
+ bool willNotOverflowSignedAdd(const WithCache<const Value *> &LHS,
+ const WithCache<const Value *> &RHS,
const Instruction &CxtI) const {
return computeOverflowForSignedAdd(LHS, RHS, &CxtI) ==
OverflowResult::NeverOverflows;
}
- bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS,
+ bool willNotOverflowUnsignedAdd(const WithCache<const Value *> &LHS,
+ const WithCache<const Value *> &RHS,
const Instruction &CxtI) const {
return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) ==
OverflowResult::NeverOverflows;
@@ -387,15 +432,17 @@ private:
Instruction *foldAndOrOfSelectUsingImpliedCond(Value *Op, SelectInst &SI,
bool IsAnd);
+ Instruction *hoistFNegAboveFMulFDiv(Value *FNegOp, Instruction &FMFSource);
+
public:
/// Create and insert the idiom we use to indicate a block is unreachable
/// without having to rewrite the CFG from within InstCombine.
void CreateNonTerminatorUnreachable(Instruction *InsertAt) {
auto &Ctx = InsertAt->getContext();
auto *SI = new StoreInst(ConstantInt::getTrue(Ctx),
- PoisonValue::get(Type::getInt1PtrTy(Ctx)),
+ PoisonValue::get(PointerType::getUnqual(Ctx)),
/*isVolatile*/ false, Align(1));
- InsertNewInstBefore(SI, *InsertAt);
+ InsertNewInstBefore(SI, InsertAt->getIterator());
}
/// Combiner aware instruction erasure.
@@ -412,6 +459,7 @@ public:
// use counts.
SmallVector<Value *> Ops(I.operands());
Worklist.remove(&I);
+ DC.removeValue(&I);
I.eraseFromParent();
for (Value *Op : Ops)
Worklist.handleUseCountDecrement(Op);
@@ -498,6 +546,7 @@ public:
/// Tries to simplify operands to an integer instruction based on its
/// demanded bits.
bool SimplifyDemandedInstructionBits(Instruction &Inst);
+ bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known);
Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
APInt &UndefElts, unsigned Depth = 0,
@@ -535,6 +584,9 @@ public:
Instruction *foldAddWithConstant(BinaryOperator &Add);
+ Instruction *foldSquareSumInt(BinaryOperator &I);
+ Instruction *foldSquareSumFP(BinaryOperator &I);
+
/// Try to rotate an operation below a PHI node, using PHI nodes for
/// its operands.
Instruction *foldPHIArgOpIntoPHI(PHINode &PN);
@@ -580,6 +632,9 @@ public:
Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
const APInt &C);
Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
+ Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax,
+ Value *Z, ICmpInst::Predicate Pred);
+ Instruction *foldICmpWithMinMax(ICmpInst &Cmp);
Instruction *foldICmpEquality(ICmpInst &Cmp);
Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
Instruction *foldSignBitTest(ICmpInst &I);
@@ -593,6 +648,8 @@ public:
ConstantInt *C);
Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc,
const APInt &C);
+ Instruction *foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
+ const SimplifyQuery &Q);
Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And,
const APInt &C);
Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor,
@@ -667,8 +724,12 @@ public:
bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock);
bool removeInstructionsBeforeUnreachable(Instruction &I);
- bool handleUnreachableFrom(Instruction *I);
- bool handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
+ void addDeadEdge(BasicBlock *From, BasicBlock *To,
+ SmallVectorImpl<BasicBlock *> &Worklist);
+ void handleUnreachableFrom(Instruction *I,
+ SmallVectorImpl<BasicBlock *> &Worklist);
+ void handlePotentiallyDeadBlocks(SmallVectorImpl<BasicBlock *> &Worklist);
+ void handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr);
};
@@ -679,16 +740,11 @@ class Negator final {
using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>;
BuilderTy Builder;
- const DataLayout &DL;
- AssumptionCache &AC;
- const DominatorTree &DT;
-
const bool IsTrulyNegation;
SmallDenseMap<Value *, Value *> NegationsCache;
- Negator(LLVMContext &C, const DataLayout &DL, AssumptionCache &AC,
- const DominatorTree &DT, bool IsTrulyNegation);
+ Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation);
#if LLVM_ENABLE_STATS
unsigned NumValuesVisitedInThisNegator = 0;
@@ -700,13 +756,13 @@ class Negator final {
std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I);
- [[nodiscard]] Value *visitImpl(Value *V, unsigned Depth);
+ [[nodiscard]] Value *visitImpl(Value *V, bool IsNSW, unsigned Depth);
- [[nodiscard]] Value *negate(Value *V, unsigned Depth);
+ [[nodiscard]] Value *negate(Value *V, bool IsNSW, unsigned Depth);
/// Recurse depth-first and attempt to sink the negation.
/// FIXME: use worklist?
- [[nodiscard]] std::optional<Result> run(Value *Root);
+ [[nodiscard]] std::optional<Result> run(Value *Root, bool IsNSW);
Negator(const Negator &) = delete;
Negator(Negator &&) = delete;
@@ -716,7 +772,7 @@ class Negator final {
public:
/// Attempt to negate \p Root. Retuns nullptr if negation can't be performed,
/// otherwise returns negated value.
- [[nodiscard]] static Value *Negate(bool LHSIsZero, Value *Root,
+ [[nodiscard]] static Value *Negate(bool LHSIsZero, bool IsNSW, Value *Root,
InstCombinerImpl &IC);
};
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 6aa20ee26b9a..b72b68c68d98 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -36,6 +36,13 @@ static cl::opt<unsigned> MaxCopiedFromConstantUsers(
cl::desc("Maximum users to visit in copy from constant transform"),
cl::Hidden);
+namespace llvm {
+cl::opt<bool> EnableInferAlignmentPass(
+ "enable-infer-alignment-pass", cl::init(true), cl::Hidden, cl::ZeroOrMore,
+ cl::desc("Enable the InferAlignment pass, disabling alignment inference in "
+ "InstCombine"));
+}
+
/// isOnlyCopiedFromConstantMemory - Recursively walk the uses of a (derived)
/// pointer to an alloca. Ignore any reads of the pointer, return false if we
/// see any stores or other unknown uses. If we see pointer arithmetic, keep
@@ -224,7 +231,7 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC,
Value *Idx[2] = {NullIdx, NullIdx};
Instruction *GEP = GetElementPtrInst::CreateInBounds(
NewTy, New, Idx, New->getName() + ".sub");
- IC.InsertNewInstBefore(GEP, *It);
+ IC.InsertNewInstBefore(GEP, It);
// Now make everything use the getelementptr instead of the original
// allocation.
@@ -380,7 +387,7 @@ void PointerReplacer::replace(Instruction *I) {
NewI->takeName(LT);
copyMetadataForLoad(*NewI, *LT);
- IC.InsertNewInstWith(NewI, *LT);
+ IC.InsertNewInstWith(NewI, LT->getIterator());
IC.replaceInstUsesWith(*LT, NewI);
WorkMap[LT] = NewI;
} else if (auto *PHI = dyn_cast<PHINode>(I)) {
@@ -398,7 +405,7 @@ void PointerReplacer::replace(Instruction *I) {
Indices.append(GEP->idx_begin(), GEP->idx_end());
auto *NewI =
GetElementPtrInst::Create(GEP->getSourceElementType(), V, Indices);
- IC.InsertNewInstWith(NewI, *GEP);
+ IC.InsertNewInstWith(NewI, GEP->getIterator());
NewI->takeName(GEP);
WorkMap[GEP] = NewI;
} else if (auto *BC = dyn_cast<BitCastInst>(I)) {
@@ -407,14 +414,14 @@ void PointerReplacer::replace(Instruction *I) {
auto *NewT = PointerType::get(BC->getType()->getContext(),
V->getType()->getPointerAddressSpace());
auto *NewI = new BitCastInst(V, NewT);
- IC.InsertNewInstWith(NewI, *BC);
+ IC.InsertNewInstWith(NewI, BC->getIterator());
NewI->takeName(BC);
WorkMap[BC] = NewI;
} else if (auto *SI = dyn_cast<SelectInst>(I)) {
auto *NewSI = SelectInst::Create(
SI->getCondition(), getReplacement(SI->getTrueValue()),
getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI);
- IC.InsertNewInstWith(NewSI, *SI);
+ IC.InsertNewInstWith(NewSI, SI->getIterator());
NewSI->takeName(SI);
WorkMap[SI] = NewSI;
} else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) {
@@ -449,7 +456,7 @@ void PointerReplacer::replace(Instruction *I) {
ASC->getType()->getPointerAddressSpace()) {
auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), "");
NewI->takeName(ASC);
- IC.InsertNewInstWith(NewI, *ASC);
+ IC.InsertNewInstWith(NewI, ASC->getIterator());
NewV = NewI;
}
IC.replaceInstUsesWith(*ASC, NewV);
@@ -507,8 +514,6 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
// types.
const Align MaxAlign = std::max(EntryAI->getAlign(), AI.getAlign());
EntryAI->setAlignment(MaxAlign);
- if (AI.getType() != EntryAI->getType())
- return new BitCastInst(EntryAI, AI.getType());
return replaceInstUsesWith(AI, EntryAI);
}
}
@@ -534,13 +539,11 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n');
LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n');
unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace();
- auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace);
if (AI.getAddressSpace() == SrcAddrSpace) {
for (Instruction *Delete : ToDelete)
eraseInstFromFunction(*Delete);
- Value *Cast = Builder.CreateBitCast(TheSrc, DestTy);
- Instruction *NewI = replaceInstUsesWith(AI, Cast);
+ Instruction *NewI = replaceInstUsesWith(AI, TheSrc);
eraseInstFromFunction(*Copy);
++NumGlobalCopies;
return NewI;
@@ -551,8 +554,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
for (Instruction *Delete : ToDelete)
eraseInstFromFunction(*Delete);
- Value *Cast = Builder.CreateBitCast(TheSrc, DestTy);
- PtrReplacer.replacePointer(Cast);
+ PtrReplacer.replacePointer(TheSrc);
++NumGlobalCopies;
}
}
@@ -582,16 +584,9 @@ LoadInst *InstCombinerImpl::combineLoadToNewType(LoadInst &LI, Type *NewTy,
assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) &&
"can't fold an atomic load to requested type");
- Value *Ptr = LI.getPointerOperand();
- unsigned AS = LI.getPointerAddressSpace();
- Type *NewPtrTy = NewTy->getPointerTo(AS);
- Value *NewPtr = nullptr;
- if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) &&
- NewPtr->getType() == NewPtrTy))
- NewPtr = Builder.CreateBitCast(Ptr, NewPtrTy);
-
- LoadInst *NewLoad = Builder.CreateAlignedLoad(
- NewTy, NewPtr, LI.getAlign(), LI.isVolatile(), LI.getName() + Suffix);
+ LoadInst *NewLoad =
+ Builder.CreateAlignedLoad(NewTy, LI.getPointerOperand(), LI.getAlign(),
+ LI.isVolatile(), LI.getName() + Suffix);
NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
copyMetadataForLoad(*NewLoad, LI);
return NewLoad;
@@ -606,13 +601,11 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI,
"can't fold an atomic store of requested type");
Value *Ptr = SI.getPointerOperand();
- unsigned AS = SI.getPointerAddressSpace();
SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
SI.getAllMetadata(MD);
- StoreInst *NewStore = IC.Builder.CreateAlignedStore(
- V, IC.Builder.CreateBitCast(Ptr, V->getType()->getPointerTo(AS)),
- SI.getAlign(), SI.isVolatile());
+ StoreInst *NewStore =
+ IC.Builder.CreateAlignedStore(V, Ptr, SI.getAlign(), SI.isVolatile());
NewStore->setAtomic(SI.getOrdering(), SI.getSyncScopeID());
for (const auto &MDPair : MD) {
unsigned ID = MDPair.first;
@@ -655,29 +648,6 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI,
return NewStore;
}
-/// Returns true if instruction represent minmax pattern like:
-/// select ((cmp load V1, load V2), V1, V2).
-static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) {
- assert(V->getType()->isPointerTy() && "Expected pointer type.");
- // Ignore possible ty* to ixx* bitcast.
- V = InstCombiner::peekThroughBitcast(V);
- // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax
- // pattern.
- CmpInst::Predicate Pred;
- Instruction *L1;
- Instruction *L2;
- Value *LHS;
- Value *RHS;
- if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)),
- m_Value(LHS), m_Value(RHS))))
- return false;
- LoadTy = L1->getType();
- return (match(L1, m_Load(m_Specific(LHS))) &&
- match(L2, m_Load(m_Specific(RHS)))) ||
- (match(L1, m_Load(m_Specific(RHS))) &&
- match(L2, m_Load(m_Specific(LHS))));
-}
-
/// Combine loads to match the type of their uses' value after looking
/// through intervening bitcasts.
///
@@ -818,7 +788,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
return nullptr;
const DataLayout &DL = IC.getDataLayout();
- auto EltSize = DL.getTypeAllocSize(ET);
+ TypeSize EltSize = DL.getTypeAllocSize(ET);
const auto Align = LI.getAlign();
auto *Addr = LI.getPointerOperand();
@@ -826,7 +796,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
auto *Zero = ConstantInt::get(IdxType, 0);
Value *V = PoisonValue::get(T);
- uint64_t Offset = 0;
+ TypeSize Offset = TypeSize::get(0, ET->isScalableTy());
for (uint64_t i = 0; i < NumElements; i++) {
Value *Indices[2] = {
Zero,
@@ -834,9 +804,9 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
};
auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices),
Name + ".elt");
+ auto EltAlign = commonAlignment(Align, Offset.getKnownMinValue());
auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr,
- commonAlignment(Align, Offset),
- Name + ".unpack");
+ EltAlign, Name + ".unpack");
L->setAAMetadata(LI.getAAMetadata());
V = IC.Builder.CreateInsertValue(V, L, i);
Offset += EltSize;
@@ -971,7 +941,7 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC,
Type *SourceElementType = GEPI->getSourceElementType();
// Size information about scalable vectors is not available, so we cannot
// deduce whether indexing at n is undefined behaviour or not. Bail out.
- if (isa<ScalableVectorType>(SourceElementType))
+ if (SourceElementType->isScalableTy())
return false;
Type *AllocTy = GetElementPtrInst::getIndexedType(SourceElementType, Ops);
@@ -1020,7 +990,7 @@ static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr,
Instruction *NewGEPI = GEPI->clone();
NewGEPI->setOperand(Idx,
ConstantInt::get(GEPI->getOperand(Idx)->getType(), 0));
- IC.InsertNewInstBefore(NewGEPI, *GEPI);
+ IC.InsertNewInstBefore(NewGEPI, GEPI->getIterator());
return NewGEPI;
}
}
@@ -1062,11 +1032,13 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
if (Instruction *Res = combineLoadToOperationType(*this, LI))
return Res;
- // Attempt to improve the alignment.
- Align KnownAlign = getOrEnforceKnownAlignment(
- Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT);
- if (KnownAlign > LI.getAlign())
- LI.setAlignment(KnownAlign);
+ if (!EnableInferAlignmentPass) {
+ // Attempt to improve the alignment.
+ Align KnownAlign = getOrEnforceKnownAlignment(
+ Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT);
+ if (KnownAlign > LI.getAlign())
+ LI.setAlignment(KnownAlign);
+ }
// Replace GEP indices if possible.
if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI))
@@ -1337,7 +1309,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) {
return false;
const DataLayout &DL = IC.getDataLayout();
- auto EltSize = DL.getTypeAllocSize(AT->getElementType());
+ TypeSize EltSize = DL.getTypeAllocSize(AT->getElementType());
const auto Align = SI.getAlign();
SmallString<16> EltName = V->getName();
@@ -1349,7 +1321,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) {
auto *IdxType = Type::getInt64Ty(T->getContext());
auto *Zero = ConstantInt::get(IdxType, 0);
- uint64_t Offset = 0;
+ TypeSize Offset = TypeSize::get(0, AT->getElementType()->isScalableTy());
for (uint64_t i = 0; i < NumElements; i++) {
Value *Indices[2] = {
Zero,
@@ -1358,7 +1330,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) {
auto *Ptr =
IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), AddrName);
auto *Val = IC.Builder.CreateExtractValue(V, i, EltName);
- auto EltAlign = commonAlignment(Align, Offset);
+ auto EltAlign = commonAlignment(Align, Offset.getKnownMinValue());
Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign);
NS->setAAMetadata(SI.getAAMetadata());
Offset += EltSize;
@@ -1399,58 +1371,6 @@ static bool equivalentAddressValues(Value *A, Value *B) {
return false;
}
-/// Converts store (bitcast (load (bitcast (select ...)))) to
-/// store (load (select ...)), where select is minmax:
-/// select ((cmp load V1, load V2), V1, V2).
-static bool removeBitcastsFromLoadStoreOnMinMax(InstCombinerImpl &IC,
- StoreInst &SI) {
- // bitcast?
- if (!match(SI.getPointerOperand(), m_BitCast(m_Value())))
- return false;
- // load? integer?
- Value *LoadAddr;
- if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr)))))
- return false;
- auto *LI = cast<LoadInst>(SI.getValueOperand());
- if (!LI->getType()->isIntegerTy())
- return false;
- Type *CmpLoadTy;
- if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy))
- return false;
-
- // Make sure the type would actually change.
- // This condition can be hit with chains of bitcasts.
- if (LI->getType() == CmpLoadTy)
- return false;
-
- // Make sure we're not changing the size of the load/store.
- const auto &DL = IC.getDataLayout();
- if (DL.getTypeStoreSizeInBits(LI->getType()) !=
- DL.getTypeStoreSizeInBits(CmpLoadTy))
- return false;
-
- if (!all_of(LI->users(), [LI, LoadAddr](User *U) {
- auto *SI = dyn_cast<StoreInst>(U);
- return SI && SI->getPointerOperand() != LI &&
- InstCombiner::peekThroughBitcast(SI->getPointerOperand()) !=
- LoadAddr &&
- !SI->getPointerOperand()->isSwiftError();
- }))
- return false;
-
- IC.Builder.SetInsertPoint(LI);
- LoadInst *NewLI = IC.combineLoadToNewType(*LI, CmpLoadTy);
- // Replace all the stores with stores of the newly loaded value.
- for (auto *UI : LI->users()) {
- auto *USI = cast<StoreInst>(UI);
- IC.Builder.SetInsertPoint(USI);
- combineStoreToNewValue(IC, *USI, NewLI);
- }
- IC.replaceInstUsesWith(*LI, PoisonValue::get(LI->getType()));
- IC.eraseInstFromFunction(*LI);
- return true;
-}
-
Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
Value *Val = SI.getOperand(0);
Value *Ptr = SI.getOperand(1);
@@ -1459,19 +1379,18 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
if (combineStoreToValueType(*this, SI))
return eraseInstFromFunction(SI);
- // Attempt to improve the alignment.
- const Align KnownAlign = getOrEnforceKnownAlignment(
- Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT);
- if (KnownAlign > SI.getAlign())
- SI.setAlignment(KnownAlign);
+ if (!EnableInferAlignmentPass) {
+ // Attempt to improve the alignment.
+ const Align KnownAlign = getOrEnforceKnownAlignment(
+ Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT);
+ if (KnownAlign > SI.getAlign())
+ SI.setAlignment(KnownAlign);
+ }
// Try to canonicalize the stored type.
if (unpackStoreToAggregate(*this, SI))
return eraseInstFromFunction(SI);
- if (removeBitcastsFromLoadStoreOnMinMax(*this, SI))
- return eraseInstFromFunction(SI);
-
// Replace GEP indices if possible.
if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI))
return replaceOperand(SI, 1, NewGEPI);
@@ -1508,8 +1427,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
--BBI;
// Don't count debug info directives, lest they affect codegen,
// and we skip pointer-to-pointer bitcasts, which are NOPs.
- if (BBI->isDebugOrPseudoInst() ||
- (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) {
+ if (BBI->isDebugOrPseudoInst()) {
ScanInsts++;
continue;
}
@@ -1560,11 +1478,15 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
// This is a non-terminator unreachable marker. Don't remove it.
if (isa<UndefValue>(Ptr)) {
- // Remove all instructions after the marker and guaranteed-to-transfer
- // instructions before the marker.
- if (handleUnreachableFrom(SI.getNextNode()) ||
- removeInstructionsBeforeUnreachable(SI))
+ // Remove guaranteed-to-transfer instructions before the marker.
+ if (removeInstructionsBeforeUnreachable(SI))
return &SI;
+
+ // Remove all instructions after the marker and handle dead blocks this
+ // implies.
+ SmallVector<BasicBlock *> Worklist;
+ handleUnreachableFrom(SI.getNextNode(), Worklist);
+ handlePotentiallyDeadBlocks(Worklist);
return nullptr;
}
@@ -1626,8 +1548,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
if (OtherBr->isUnconditional()) {
--BBI;
// Skip over debugging info and pseudo probes.
- while (BBI->isDebugOrPseudoInst() ||
- (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) {
+ while (BBI->isDebugOrPseudoInst()) {
if (BBI==OtherBB->begin())
return false;
--BBI;
@@ -1681,7 +1602,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
Builder.SetInsertPoint(OtherStore);
PN->addIncoming(Builder.CreateBitOrPointerCast(MergedVal, PN->getType()),
OtherBB);
- MergedVal = InsertNewInstBefore(PN, DestBB->front());
+ MergedVal = InsertNewInstBefore(PN, DestBB->begin());
PN->setDebugLoc(MergedLoc);
}
@@ -1690,7 +1611,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
StoreInst *NewSI =
new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), SI.getAlign(),
SI.getOrdering(), SI.getSyncScopeID());
- InsertNewInstBefore(NewSI, *BBI);
+ InsertNewInstBefore(NewSI, BBI);
NewSI->setDebugLoc(MergedLoc);
NewSI->mergeDIAssignID({&SI, OtherStore});
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 50458e2773e6..8d5866e98a8e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -258,9 +258,14 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Op0->hasOneUse() && match(Op1, m_NegatedPower2())) {
// Interpret X * (-1<<C) as (-X) * (1<<C) and try to sink the negation.
// The "* (1<<C)" thus becomes a potential shifting opportunity.
- if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this))
- return BinaryOperator::CreateMul(
- NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName());
+ if (Value *NegOp0 =
+ Negator::Negate(/*IsNegation*/ true, HasNSW, Op0, *this)) {
+ auto *Op1C = cast<Constant>(Op1);
+ return replaceInstUsesWith(
+ I, Builder.CreateMul(NegOp0, ConstantExpr::getNeg(Op1C), "",
+ /* HasNUW */ false,
+ HasNSW && Op1C->isNotMinSignedValue()));
+ }
// Try to convert multiply of extended operand to narrow negate and shift
// for better analysis.
@@ -295,9 +300,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
// Canonicalize (X|C1)*MulC -> X*MulC+C1*MulC.
Value *X;
Constant *C1;
- if ((match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(C1))))) ||
- (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C1)))) &&
- haveNoCommonBitsSet(X, C1, DL, &AC, &I, &DT))) {
+ if (match(Op0, m_OneUse(m_AddLike(m_Value(X), m_ImmConstant(C1))))) {
// C1*MulC simplifies to a tidier constant.
Value *NewC = Builder.CreateMul(C1, MulC);
auto *BOp0 = cast<BinaryOperator>(Op0);
@@ -555,6 +558,180 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) {
return nullptr;
}
+Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
+ Value *Op0 = I.getOperand(0);
+ Value *Op1 = I.getOperand(1);
+ Value *X, *Y;
+ Constant *C;
+
+ // Reassociate constant RHS with another constant to form constant
+ // expression.
+ if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
+ Constant *C1;
+ if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
+ // (C1 / X) * C --> (C * C1) / X
+ Constant *CC1 =
+ ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
+ if (CC1 && CC1->isNormalFP())
+ return BinaryOperator::CreateFDivFMF(CC1, X, &I);
+ }
+ if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
+ // (X / C1) * C --> X * (C / C1)
+ Constant *CDivC1 =
+ ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
+ if (CDivC1 && CDivC1->isNormalFP())
+ return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
+
+ // If the constant was a denormal, try reassociating differently.
+ // (X / C1) * C --> X / (C1 / C)
+ Constant *C1DivC =
+ ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
+ if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
+ return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
+ }
+
+ // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
+ // canonicalized to 'fadd X, C'. Distributing the multiply may allow
+ // further folds and (X * C) + C2 is 'fma'.
+ if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
+ // (X + C1) * C --> (X * C) + (C * C1)
+ if (Constant *CC1 =
+ ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+ Value *XC = Builder.CreateFMulFMF(X, C, &I);
+ return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
+ }
+ }
+ if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
+ // (C1 - X) * C --> (C * C1) - (X * C)
+ if (Constant *CC1 =
+ ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+ Value *XC = Builder.CreateFMulFMF(X, C, &I);
+ return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
+ }
+ }
+ }
+
+ Value *Z;
+ if (match(&I,
+ m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) {
+ // Sink division: (X / Y) * Z --> (X * Z) / Y
+ Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
+ return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
+ }
+
+ // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
+ // nnan disallows the possibility of returning a number if both operands are
+ // negative (in that case, we should return NaN).
+ if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
+ match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
+ Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+ Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
+ return replaceInstUsesWith(I, Sqrt);
+ }
+
+ // The following transforms are done irrespective of the number of uses
+ // for the expression "1.0/sqrt(X)".
+ // 1) 1.0/sqrt(X) * X -> X/sqrt(X)
+ // 2) X * 1.0/sqrt(X) -> X/sqrt(X)
+ // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
+ // has the necessary (reassoc) fast-math-flags.
+ if (I.hasNoSignedZeros() &&
+ match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+ match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
+ return BinaryOperator::CreateFDivFMF(X, Y, &I);
+ if (I.hasNoSignedZeros() &&
+ match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+ match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
+ return BinaryOperator::CreateFDivFMF(X, Y, &I);
+
+ // Like the similar transform in instsimplify, this requires 'nsz' because
+ // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
+ if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
+ // Peek through fdiv to find squaring of square root:
+ // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
+ if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
+ Value *XX = Builder.CreateFMulFMF(X, X, &I);
+ return BinaryOperator::CreateFDivFMF(XX, Y, &I);
+ }
+ // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
+ if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
+ Value *XX = Builder.CreateFMulFMF(X, X, &I);
+ return BinaryOperator::CreateFDivFMF(Y, XX, &I);
+ }
+ }
+
+ // pow(X, Y) * X --> pow(X, Y+1)
+ // X * pow(X, Y) --> pow(X, Y+1)
+ if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
+ m_Value(Y))),
+ m_Deferred(X)))) {
+ Value *Y1 = Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
+ Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
+ return replaceInstUsesWith(I, Pow);
+ }
+
+ if (I.isOnlyUserOfAnyOperand()) {
+ // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
+ if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+ match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
+ auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
+ auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
+ return replaceInstUsesWith(I, NewPow);
+ }
+ // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
+ if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+ match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
+ auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
+ auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
+ return replaceInstUsesWith(I, NewPow);
+ }
+
+ // powi(x, y) * powi(x, z) -> powi(x, y + z)
+ if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
+ match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
+ Y->getType() == Z->getType()) {
+ auto *YZ = Builder.CreateAdd(Y, Z);
+ auto *NewPow = Builder.CreateIntrinsic(
+ Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
+ return replaceInstUsesWith(I, NewPow);
+ }
+
+ // exp(X) * exp(Y) -> exp(X + Y)
+ if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
+ match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
+ Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+ Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
+ return replaceInstUsesWith(I, Exp);
+ }
+
+ // exp2(X) * exp2(Y) -> exp2(X + Y)
+ if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
+ match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
+ Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+ Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
+ return replaceInstUsesWith(I, Exp2);
+ }
+ }
+
+ // (X*Y) * X => (X*X) * Y where Y != X
+ // The purpose is two-fold:
+ // 1) to form a power expression (of X).
+ // 2) potentially shorten the critical path: After transformation, the
+ // latency of the instruction Y is amortized by the expression of X*X,
+ // and therefore Y is in a "less critical" position compared to what it
+ // was before the transformation.
+ if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && Op1 != Y) {
+ Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
+ return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+ }
+ if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && Op0 != Y) {
+ Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
+ return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1),
I.getFastMathFlags(),
@@ -602,176 +779,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);
- if (I.hasAllowReassoc()) {
- // Reassociate constant RHS with another constant to form constant
- // expression.
- if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
- Constant *C1;
- if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
- // (C1 / X) * C --> (C * C1) / X
- Constant *CC1 =
- ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
- if (CC1 && CC1->isNormalFP())
- return BinaryOperator::CreateFDivFMF(CC1, X, &I);
- }
- if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
- // (X / C1) * C --> X * (C / C1)
- Constant *CDivC1 =
- ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
- if (CDivC1 && CDivC1->isNormalFP())
- return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
-
- // If the constant was a denormal, try reassociating differently.
- // (X / C1) * C --> X / (C1 / C)
- Constant *C1DivC =
- ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
- if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
- return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
- }
-
- // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
- // canonicalized to 'fadd X, C'. Distributing the multiply may allow
- // further folds and (X * C) + C2 is 'fma'.
- if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
- // (X + C1) * C --> (X * C) + (C * C1)
- if (Constant *CC1 = ConstantFoldBinaryOpOperands(
- Instruction::FMul, C, C1, DL)) {
- Value *XC = Builder.CreateFMulFMF(X, C, &I);
- return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
- }
- }
- if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
- // (C1 - X) * C --> (C * C1) - (X * C)
- if (Constant *CC1 = ConstantFoldBinaryOpOperands(
- Instruction::FMul, C, C1, DL)) {
- Value *XC = Builder.CreateFMulFMF(X, C, &I);
- return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
- }
- }
- }
-
- Value *Z;
- if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))),
- m_Value(Z)))) {
- // Sink division: (X / Y) * Z --> (X * Z) / Y
- Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
- return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
- }
-
- // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
- // nnan disallows the possibility of returning a number if both operands are
- // negative (in that case, we should return NaN).
- if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
- match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
- Value *XY = Builder.CreateFMulFMF(X, Y, &I);
- Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
- return replaceInstUsesWith(I, Sqrt);
- }
-
- // The following transforms are done irrespective of the number of uses
- // for the expression "1.0/sqrt(X)".
- // 1) 1.0/sqrt(X) * X -> X/sqrt(X)
- // 2) X * 1.0/sqrt(X) -> X/sqrt(X)
- // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
- // has the necessary (reassoc) fast-math-flags.
- if (I.hasNoSignedZeros() &&
- match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
- match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
- return BinaryOperator::CreateFDivFMF(X, Y, &I);
- if (I.hasNoSignedZeros() &&
- match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
- match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
- return BinaryOperator::CreateFDivFMF(X, Y, &I);
-
- // Like the similar transform in instsimplify, this requires 'nsz' because
- // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
- if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
- Op0->hasNUses(2)) {
- // Peek through fdiv to find squaring of square root:
- // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
- if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
- Value *XX = Builder.CreateFMulFMF(X, X, &I);
- return BinaryOperator::CreateFDivFMF(XX, Y, &I);
- }
- // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
- if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
- Value *XX = Builder.CreateFMulFMF(X, X, &I);
- return BinaryOperator::CreateFDivFMF(Y, XX, &I);
- }
- }
-
- // pow(X, Y) * X --> pow(X, Y+1)
- // X * pow(X, Y) --> pow(X, Y+1)
- if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
- m_Value(Y))),
- m_Deferred(X)))) {
- Value *Y1 =
- Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
- Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
- return replaceInstUsesWith(I, Pow);
- }
-
- if (I.isOnlyUserOfAnyOperand()) {
- // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
- if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
- match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
- auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
- auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
- return replaceInstUsesWith(I, NewPow);
- }
- // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
- if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
- match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
- auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
- auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
- return replaceInstUsesWith(I, NewPow);
- }
-
- // powi(x, y) * powi(x, z) -> powi(x, y + z)
- if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
- match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
- Y->getType() == Z->getType()) {
- auto *YZ = Builder.CreateAdd(Y, Z);
- auto *NewPow = Builder.CreateIntrinsic(
- Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
- return replaceInstUsesWith(I, NewPow);
- }
-
- // exp(X) * exp(Y) -> exp(X + Y)
- if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
- match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
- Value *XY = Builder.CreateFAddFMF(X, Y, &I);
- Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
- return replaceInstUsesWith(I, Exp);
- }
-
- // exp2(X) * exp2(Y) -> exp2(X + Y)
- if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
- match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
- Value *XY = Builder.CreateFAddFMF(X, Y, &I);
- Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
- return replaceInstUsesWith(I, Exp2);
- }
- }
-
- // (X*Y) * X => (X*X) * Y where Y != X
- // The purpose is two-fold:
- // 1) to form a power expression (of X).
- // 2) potentially shorten the critical path: After transformation, the
- // latency of the instruction Y is amortized by the expression of X*X,
- // and therefore Y is in a "less critical" position compared to what it
- // was before the transformation.
- if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) &&
- Op1 != Y) {
- Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
- return BinaryOperator::CreateFMulFMF(XX, Y, &I);
- }
- if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) &&
- Op0 != Y) {
- Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
- return BinaryOperator::CreateFMulFMF(XX, Y, &I);
- }
- }
+ if (I.hasAllowReassoc())
+ if (Instruction *FoldedMul = foldFMulReassoc(I))
+ return FoldedMul;
// log2(X * 0.5) * Y = log2(X) * Y - Y
if (I.isFast()) {
@@ -802,7 +812,7 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
I.hasNoSignedZeros() && match(Start, m_Zero()))
return replaceInstUsesWith(I, Start);
- // minimun(X, Y) * maximum(X, Y) => X * Y.
+ // minimum(X, Y) * maximum(X, Y) => X * Y.
if (match(&I,
m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
@@ -918,8 +928,7 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient,
return Remainder.isMinValue();
}
-static Instruction *foldIDivShl(BinaryOperator &I,
- InstCombiner::BuilderTy &Builder) {
+static Value *foldIDivShl(BinaryOperator &I, InstCombiner::BuilderTy &Builder) {
assert((I.getOpcode() == Instruction::SDiv ||
I.getOpcode() == Instruction::UDiv) &&
"Expected integer divide");
@@ -928,7 +937,6 @@ static Instruction *foldIDivShl(BinaryOperator &I,
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
- Instruction *Ret = nullptr;
Value *X, *Y, *Z;
// With appropriate no-wrap constraints, remove a common factor in the
@@ -943,12 +951,12 @@ static Instruction *foldIDivShl(BinaryOperator &I,
// (X * Y) u/ (X << Z) --> Y u>> Z
if (!IsSigned && HasNUW)
- Ret = BinaryOperator::CreateLShr(Y, Z);
+ return Builder.CreateLShr(Y, Z, "", I.isExact());
// (X * Y) s/ (X << Z) --> Y s/ (1 << Z)
if (IsSigned && HasNSW && (Op0->hasOneUse() || Op1->hasOneUse())) {
Value *Shl = Builder.CreateShl(ConstantInt::get(Ty, 1), Z);
- Ret = BinaryOperator::CreateSDiv(Y, Shl);
+ return Builder.CreateSDiv(Y, Shl, "", I.isExact());
}
}
@@ -966,20 +974,38 @@ static Instruction *foldIDivShl(BinaryOperator &I,
((Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap()) ||
(Shl0->hasNoUnsignedWrap() && Shl0->hasNoSignedWrap() &&
Shl1->hasNoSignedWrap())))
- Ret = BinaryOperator::CreateUDiv(X, Y);
+ return Builder.CreateUDiv(X, Y, "", I.isExact());
// For signed div, we need 'nsw' on both shifts + 'nuw' on the divisor.
// (X << Z) / (Y << Z) --> X / Y
if (IsSigned && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap() &&
Shl1->hasNoUnsignedWrap())
- Ret = BinaryOperator::CreateSDiv(X, Y);
+ return Builder.CreateSDiv(X, Y, "", I.isExact());
}
- if (!Ret)
- return nullptr;
+ // If X << Y and X << Z does not overflow, then:
+ // (X << Y) / (X << Z) -> (1 << Y) / (1 << Z) -> 1 << Y >> Z
+ if (match(Op0, m_Shl(m_Value(X), m_Value(Y))) &&
+ match(Op1, m_Shl(m_Specific(X), m_Value(Z)))) {
+ auto *Shl0 = cast<OverflowingBinaryOperator>(Op0);
+ auto *Shl1 = cast<OverflowingBinaryOperator>(Op1);
- Ret->setIsExact(I.isExact());
- return Ret;
+ if (IsSigned ? (Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap())
+ : (Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap())) {
+ Constant *One = ConstantInt::get(X->getType(), 1);
+ // Only preserve the nsw flag if dividend has nsw
+ // or divisor has nsw and operator is sdiv.
+ Value *Dividend = Builder.CreateShl(
+ One, Y, "shl.dividend",
+ /*HasNUW*/ true,
+ /*HasNSW*/
+ IsSigned ? (Shl0->hasNoUnsignedWrap() || Shl1->hasNoUnsignedWrap())
+ : Shl0->hasNoSignedWrap());
+ return Builder.CreateLShr(Dividend, Z, "", I.isExact());
+ }
+ }
+
+ return nullptr;
}
/// This function implements the transforms common to both integer division
@@ -1156,8 +1182,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
return NewDiv;
}
- if (Instruction *R = foldIDivShl(I, Builder))
- return R;
+ if (Value *R = foldIDivShl(I, Builder))
+ return replaceInstUsesWith(I, R);
// With the appropriate no-wrap constraint, remove a multiply by the divisor
// after peeking through another divide:
@@ -1263,7 +1289,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
/// If we have zero-extended operands of an unsigned div or rem, we may be able
/// to narrow the operation (sink the zext below the math).
static Instruction *narrowUDivURem(BinaryOperator &I,
- InstCombiner::BuilderTy &Builder) {
+ InstCombinerImpl &IC) {
Instruction::BinaryOps Opcode = I.getOpcode();
Value *N = I.getOperand(0);
Value *D = I.getOperand(1);
@@ -1273,7 +1299,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) {
// udiv (zext X), (zext Y) --> zext (udiv X, Y)
// urem (zext X), (zext Y) --> zext (urem X, Y)
- Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y);
+ Value *NarrowOp = IC.Builder.CreateBinOp(Opcode, X, Y);
return new ZExtInst(NarrowOp, Ty);
}
@@ -1281,24 +1307,24 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) &&
match(D, m_Constant(C))) {
// If the constant is the same in the smaller type, use the narrow version.
- Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
- if (ConstantExpr::getZExt(TruncC, Ty) != C)
+ Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+ if (!TruncC)
return nullptr;
// udiv (zext X), C --> zext (udiv X, C')
// urem (zext X), C --> zext (urem X, C')
- return new ZExtInst(Builder.CreateBinOp(Opcode, X, TruncC), Ty);
+ return new ZExtInst(IC.Builder.CreateBinOp(Opcode, X, TruncC), Ty);
}
if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) &&
match(N, m_Constant(C))) {
// If the constant is the same in the smaller type, use the narrow version.
- Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
- if (ConstantExpr::getZExt(TruncC, Ty) != C)
+ Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+ if (!TruncC)
return nullptr;
// udiv C, (zext X) --> zext (udiv C', X)
// urem C, (zext X) --> zext (urem C', X)
- return new ZExtInst(Builder.CreateBinOp(Opcode, TruncC, X), Ty);
+ return new ZExtInst(IC.Builder.CreateBinOp(Opcode, TruncC, X), Ty);
}
return nullptr;
@@ -1346,7 +1372,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
return CastInst::CreateZExtOrBitCast(Cmp, Ty);
}
- if (Instruction *NarrowDiv = narrowUDivURem(I, Builder))
+ if (Instruction *NarrowDiv = narrowUDivURem(I, *this))
return NarrowDiv;
// If the udiv operands are non-overflowing multiplies with a common operand,
@@ -1405,7 +1431,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
// sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined)
if (match(Op1, m_AllOnes()) ||
(match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
- return BinaryOperator::CreateNeg(Op0);
+ return BinaryOperator::CreateNSWNeg(Op0);
// X / INT_MIN --> X == INT_MIN
if (match(Op1, m_SignMask()))
@@ -1428,7 +1454,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
Constant *NegPow2C = ConstantExpr::getNeg(cast<Constant>(Op1));
Constant *C = ConstantExpr::getExactLogBase2(NegPow2C);
Value *Ashr = Builder.CreateAShr(Op0, C, I.getName() + ".neg", true);
- return BinaryOperator::CreateNeg(Ashr);
+ return BinaryOperator::CreateNSWNeg(Ashr);
}
}
@@ -1490,7 +1516,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
if (KnownDividend.isNonNegative()) {
// If both operands are unsigned, turn this into a udiv.
- if (isKnownNonNegative(Op1, DL, 0, &AC, &I, &DT)) {
+ if (isKnownNonNegative(Op1, SQ.getWithInstruction(&I))) {
auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
BO->setIsExact(I.isExact());
return BO;
@@ -1516,6 +1542,13 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
}
}
+ // -X / X --> X == INT_MIN ? 1 : -1
+ if (isKnownNegation(Op0, Op1)) {
+ APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
+ Value *Cond = Builder.CreateICmpEQ(Op0, ConstantInt::get(Ty, MinVal));
+ return SelectInst::Create(Cond, ConstantInt::get(Ty, 1),
+ ConstantInt::getAllOnesValue(Ty));
+ }
return nullptr;
}
@@ -1759,6 +1792,21 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return replaceInstUsesWith(I, Pow);
}
+ // powi(X, Y) / X --> powi(X, Y-1)
+ // This is legal when (Y - 1) can't wraparound, in which case reassoc and nnan
+ // are required.
+ // TODO: Multi-use may be also better off creating Powi(x,y-1)
+ if (I.hasAllowReassoc() && I.hasNoNaNs() &&
+ match(Op0, m_OneUse(m_Intrinsic<Intrinsic::powi>(m_Specific(Op1),
+ m_Value(Y)))) &&
+ willNotOverflowSignedSub(Y, ConstantInt::get(Y->getType(), 1), I)) {
+ Constant *NegOne = ConstantInt::getAllOnesValue(Y->getType());
+ Value *Y1 = Builder.CreateAdd(Y, NegOne);
+ Type *Types[] = {Op1->getType(), Y1->getType()};
+ Value *Pow = Builder.CreateIntrinsic(Intrinsic::powi, Types, {Op1, Y1}, &I);
+ return replaceInstUsesWith(I, Pow);
+ }
+
return nullptr;
}
@@ -1936,7 +1984,7 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
if (Instruction *common = commonIRemTransforms(I))
return common;
- if (Instruction *NarrowRem = narrowUDivURem(I, Builder))
+ if (Instruction *NarrowRem = narrowUDivURem(I, *this))
return NarrowRem;
// X urem Y -> X and Y-1, where Y is a power of 2,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index e24abc48424d..513b185c83a4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -20,7 +20,6 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
-#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
@@ -98,14 +97,13 @@ static cl::opt<unsigned>
cl::desc("What is the maximal lookup depth when trying to "
"check for viability of negation sinking."));
-Negator::Negator(LLVMContext &C, const DataLayout &DL_, AssumptionCache &AC_,
- const DominatorTree &DT_, bool IsTrulyNegation_)
- : Builder(C, TargetFolder(DL_),
+Negator::Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation_)
+ : Builder(C, TargetFolder(DL),
IRBuilderCallbackInserter([&](Instruction *I) {
++NegatorNumInstructionsCreatedTotal;
NewInstructions.push_back(I);
})),
- DL(DL_), AC(AC_), DT(DT_), IsTrulyNegation(IsTrulyNegation_) {}
+ IsTrulyNegation(IsTrulyNegation_) {}
#if LLVM_ENABLE_STATS
Negator::~Negator() {
@@ -128,7 +126,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
// FIXME: can this be reworked into a worklist-based algorithm while preserving
// the depth-first, early bailout traversal?
-[[nodiscard]] Value *Negator::visitImpl(Value *V, unsigned Depth) {
+[[nodiscard]] Value *Negator::visitImpl(Value *V, bool IsNSW, unsigned Depth) {
// -(undef) -> undef.
if (match(V, m_Undef()))
return V;
@@ -237,7 +235,8 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
// However, only do this either if the old `sub` doesn't stick around, or
// it was subtracting from a constant. Otherwise, this isn't profitable.
return Builder.CreateSub(I->getOperand(1), I->getOperand(0),
- I->getName() + ".neg");
+ I->getName() + ".neg", /* HasNUW */ false,
+ IsNSW && I->hasNoSignedWrap());
}
// Some other cases, while still don't require recursion,
@@ -302,7 +301,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
switch (I->getOpcode()) {
case Instruction::Freeze: {
// `freeze` is negatible if its operand is negatible.
- Value *NegOp = negate(I->getOperand(0), Depth + 1);
+ Value *NegOp = negate(I->getOperand(0), IsNSW, Depth + 1);
if (!NegOp) // Early return.
return nullptr;
return Builder.CreateFreeze(NegOp, I->getName() + ".neg");
@@ -313,7 +312,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
SmallVector<Value *, 4> NegatedIncomingValues(PHI->getNumOperands());
for (auto I : zip(PHI->incoming_values(), NegatedIncomingValues)) {
if (!(std::get<1>(I) =
- negate(std::get<0>(I), Depth + 1))) // Early return.
+ negate(std::get<0>(I), IsNSW, Depth + 1))) // Early return.
return nullptr;
}
// All incoming values are indeed negatible. Create negated PHI node.
@@ -336,10 +335,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
return NewSelect;
}
// `select` is negatible if both hands of `select` are negatible.
- Value *NegOp1 = negate(I->getOperand(1), Depth + 1);
+ Value *NegOp1 = negate(I->getOperand(1), IsNSW, Depth + 1);
if (!NegOp1) // Early return.
return nullptr;
- Value *NegOp2 = negate(I->getOperand(2), Depth + 1);
+ Value *NegOp2 = negate(I->getOperand(2), IsNSW, Depth + 1);
if (!NegOp2)
return nullptr;
// Do preserve the metadata!
@@ -349,10 +348,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
case Instruction::ShuffleVector: {
// `shufflevector` is negatible if both operands are negatible.
auto *Shuf = cast<ShuffleVectorInst>(I);
- Value *NegOp0 = negate(I->getOperand(0), Depth + 1);
+ Value *NegOp0 = negate(I->getOperand(0), IsNSW, Depth + 1);
if (!NegOp0) // Early return.
return nullptr;
- Value *NegOp1 = negate(I->getOperand(1), Depth + 1);
+ Value *NegOp1 = negate(I->getOperand(1), IsNSW, Depth + 1);
if (!NegOp1)
return nullptr;
return Builder.CreateShuffleVector(NegOp0, NegOp1, Shuf->getShuffleMask(),
@@ -361,7 +360,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
case Instruction::ExtractElement: {
// `extractelement` is negatible if source operand is negatible.
auto *EEI = cast<ExtractElementInst>(I);
- Value *NegVector = negate(EEI->getVectorOperand(), Depth + 1);
+ Value *NegVector = negate(EEI->getVectorOperand(), IsNSW, Depth + 1);
if (!NegVector) // Early return.
return nullptr;
return Builder.CreateExtractElement(NegVector, EEI->getIndexOperand(),
@@ -371,10 +370,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
// `insertelement` is negatible if both the source vector and
// element-to-be-inserted are negatible.
auto *IEI = cast<InsertElementInst>(I);
- Value *NegVector = negate(IEI->getOperand(0), Depth + 1);
+ Value *NegVector = negate(IEI->getOperand(0), IsNSW, Depth + 1);
if (!NegVector) // Early return.
return nullptr;
- Value *NegNewElt = negate(IEI->getOperand(1), Depth + 1);
+ Value *NegNewElt = negate(IEI->getOperand(1), IsNSW, Depth + 1);
if (!NegNewElt) // Early return.
return nullptr;
return Builder.CreateInsertElement(NegVector, NegNewElt, IEI->getOperand(2),
@@ -382,15 +381,17 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
}
case Instruction::Trunc: {
// `trunc` is negatible if its operand is negatible.
- Value *NegOp = negate(I->getOperand(0), Depth + 1);
+ Value *NegOp = negate(I->getOperand(0), /* IsNSW */ false, Depth + 1);
if (!NegOp) // Early return.
return nullptr;
return Builder.CreateTrunc(NegOp, I->getType(), I->getName() + ".neg");
}
case Instruction::Shl: {
// `shl` is negatible if the first operand is negatible.
- if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1))
- return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg");
+ IsNSW &= I->hasNoSignedWrap();
+ if (Value *NegOp0 = negate(I->getOperand(0), IsNSW, Depth + 1))
+ return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg",
+ /* HasNUW */ false, IsNSW);
// Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`.
auto *Op1C = dyn_cast<Constant>(I->getOperand(1));
if (!Op1C || !IsTrulyNegation)
@@ -398,11 +399,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
return Builder.CreateMul(
I->getOperand(0),
ConstantExpr::getShl(Constant::getAllOnesValue(Op1C->getType()), Op1C),
- I->getName() + ".neg");
+ I->getName() + ".neg", /* HasNUW */ false, IsNSW);
}
case Instruction::Or: {
- if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I,
- &DT))
+ if (!cast<PossiblyDisjointInst>(I)->isDisjoint())
return nullptr; // Don't know how to handle `or` in general.
std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I);
// `or`/`add` are interchangeable when operands have no common bits set.
@@ -417,7 +417,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
SmallVector<Value *, 2> NegatedOps, NonNegatedOps;
for (Value *Op : I->operands()) {
// Can we sink the negation into this operand?
- if (Value *NegOp = negate(Op, Depth + 1)) {
+ if (Value *NegOp = negate(Op, /* IsNSW */ false, Depth + 1)) {
NegatedOps.emplace_back(NegOp); // Successfully negated operand!
continue;
}
@@ -446,9 +446,11 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
// `xor` is negatible if one of its operands is invertible.
// FIXME: InstCombineInverter? But how to connect Inverter and Negator?
if (auto *C = dyn_cast<Constant>(Ops[1])) {
- Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C));
- return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1),
- I->getName() + ".neg");
+ if (IsTrulyNegation) {
+ Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C));
+ return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1),
+ I->getName() + ".neg");
+ }
}
return nullptr;
}
@@ -458,16 +460,17 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
Value *NegatedOp, *OtherOp;
// First try the second operand, in case it's a constant it will be best to
// just invert it instead of sinking the `neg` deeper.
- if (Value *NegOp1 = negate(Ops[1], Depth + 1)) {
+ if (Value *NegOp1 = negate(Ops[1], /* IsNSW */ false, Depth + 1)) {
NegatedOp = NegOp1;
OtherOp = Ops[0];
- } else if (Value *NegOp0 = negate(Ops[0], Depth + 1)) {
+ } else if (Value *NegOp0 = negate(Ops[0], /* IsNSW */ false, Depth + 1)) {
NegatedOp = NegOp0;
OtherOp = Ops[1];
} else
// Can't negate either of them.
return nullptr;
- return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg");
+ return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg",
+ /* HasNUW */ false, IsNSW && I->hasNoSignedWrap());
}
default:
return nullptr; // Don't know, likely not negatible for free.
@@ -476,7 +479,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
llvm_unreachable("Can't get here. We always return from switch.");
}
-[[nodiscard]] Value *Negator::negate(Value *V, unsigned Depth) {
+[[nodiscard]] Value *Negator::negate(Value *V, bool IsNSW, unsigned Depth) {
NegatorMaxDepthVisited.updateMax(Depth);
++NegatorNumValuesVisited;
@@ -506,15 +509,16 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
#endif
// No luck. Try negating it for real.
- Value *NegatedV = visitImpl(V, Depth);
+ Value *NegatedV = visitImpl(V, IsNSW, Depth);
// And cache the (real) result for the future.
NegationsCache[V] = NegatedV;
return NegatedV;
}
-[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root) {
- Value *Negated = negate(Root, /*Depth=*/0);
+[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root,
+ bool IsNSW) {
+ Value *Negated = negate(Root, IsNSW, /*Depth=*/0);
if (!Negated) {
// We must cleanup newly-inserted instructions, to avoid any potential
// endless combine looping.
@@ -525,7 +529,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated);
}
-[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, Value *Root,
+[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, bool IsNSW, Value *Root,
InstCombinerImpl &IC) {
++NegatorTotalNegationsAttempted;
LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root
@@ -534,9 +538,8 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
if (!NegatorEnabled || !DebugCounter::shouldExecute(NegatorCounter))
return nullptr;
- Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(),
- IC.getDominatorTree(), LHSIsZero);
- std::optional<Result> Res = N.run(Root);
+ Negator N(Root->getContext(), IC.getDataLayout(), LHSIsZero);
+ std::optional<Result> Res = N.run(Root, IsNSW);
if (!Res) { // Negation failed.
LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root
<< "\n");
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 2f6aa85062a5..20b34c1379d5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -248,7 +248,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
PHINode *NewPtrPHI = PHINode::Create(
IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr");
- InsertNewInstBefore(NewPtrPHI, PN);
+ InsertNewInstBefore(NewPtrPHI, PN.getIterator());
SmallDenseMap<Value *, Instruction *> Casts;
for (auto Incoming : zip(PN.blocks(), AvailablePtrVals)) {
auto *IncomingBB = std::get<0>(Incoming);
@@ -285,10 +285,10 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
if (isa<PHINode>(IncomingI))
InsertPos = BB->getFirstInsertionPt();
assert(InsertPos != BB->end() && "should have checked above");
- InsertNewInstBefore(CI, *InsertPos);
+ InsertNewInstBefore(CI, InsertPos);
} else {
auto *InsertBB = &IncomingBB->getParent()->getEntryBlock();
- InsertNewInstBefore(CI, *InsertBB->getFirstInsertionPt());
+ InsertNewInstBefore(CI, InsertBB->getFirstInsertionPt());
}
}
NewPtrPHI->addIncoming(CI, IncomingBB);
@@ -353,7 +353,7 @@ InstCombinerImpl::foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN) {
NewOperand->addIncoming(
cast<InsertValueInst>(std::get<1>(Incoming))->getOperand(OpIdx),
std::get<0>(Incoming));
- InsertNewInstBefore(NewOperand, PN);
+ InsertNewInstBefore(NewOperand, PN.getIterator());
}
// And finally, create `insertvalue` over the newly-formed PHI nodes.
@@ -391,7 +391,7 @@ InstCombinerImpl::foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN) {
NewAggregateOperand->addIncoming(
cast<ExtractValueInst>(std::get<1>(Incoming))->getAggregateOperand(),
std::get<0>(Incoming));
- InsertNewInstBefore(NewAggregateOperand, PN);
+ InsertNewInstBefore(NewAggregateOperand, PN.getIterator());
// And finally, create `extractvalue` over the newly-formed PHI nodes.
auto *NewEVI = ExtractValueInst::Create(NewAggregateOperand,
@@ -450,7 +450,7 @@ Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) {
NewLHS = PHINode::Create(LHSType, PN.getNumIncomingValues(),
FirstInst->getOperand(0)->getName() + ".pn");
NewLHS->addIncoming(InLHS, PN.getIncomingBlock(0));
- InsertNewInstBefore(NewLHS, PN);
+ InsertNewInstBefore(NewLHS, PN.getIterator());
LHSVal = NewLHS;
}
@@ -458,7 +458,7 @@ Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) {
NewRHS = PHINode::Create(RHSType, PN.getNumIncomingValues(),
FirstInst->getOperand(1)->getName() + ".pn");
NewRHS->addIncoming(InRHS, PN.getIncomingBlock(0));
- InsertNewInstBefore(NewRHS, PN);
+ InsertNewInstBefore(NewRHS, PN.getIterator());
RHSVal = NewRHS;
}
@@ -581,7 +581,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) {
Value *FirstOp = FirstInst->getOperand(I);
PHINode *NewPN =
PHINode::Create(FirstOp->getType(), E, FirstOp->getName() + ".pn");
- InsertNewInstBefore(NewPN, PN);
+ InsertNewInstBefore(NewPN, PN.getIterator());
NewPN->addIncoming(FirstOp, PN.getIncomingBlock(0));
OperandPhis[I] = NewPN;
@@ -769,7 +769,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) {
NewLI->setOperand(0, InVal);
delete NewPN;
} else {
- InsertNewInstBefore(NewPN, PN);
+ InsertNewInstBefore(NewPN, PN.getIterator());
}
// If this was a volatile load that we are merging, make sure to loop through
@@ -825,8 +825,8 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) {
NumZexts++;
} else if (auto *C = dyn_cast<Constant>(V)) {
// Make sure that constants can fit in the new type.
- Constant *Trunc = ConstantExpr::getTrunc(C, NarrowType);
- if (ConstantExpr::getZExt(Trunc, C->getType()) != C)
+ Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType);
+ if (!Trunc)
return nullptr;
NewIncoming.push_back(Trunc);
NumConsts++;
@@ -853,7 +853,7 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) {
for (unsigned I = 0; I != NumIncomingValues; ++I)
NewPhi->addIncoming(NewIncoming[I], Phi.getIncomingBlock(I));
- InsertNewInstBefore(NewPhi, Phi);
+ InsertNewInstBefore(NewPhi, Phi.getIterator());
return CastInst::CreateZExtOrBitCast(NewPhi, Phi.getType());
}
@@ -943,7 +943,7 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
PhiVal = InVal;
delete NewPN;
} else {
- InsertNewInstBefore(NewPN, PN);
+ InsertNewInstBefore(NewPN, PN.getIterator());
PhiVal = NewPN;
}
@@ -996,8 +996,8 @@ static bool isDeadPHICycle(PHINode *PN,
/// Return true if this phi node is always equal to NonPhiInVal.
/// This happens with mutually cyclic phi nodes like:
/// z = some value; x = phi (y, z); y = phi (x, z)
-static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal,
- SmallPtrSetImpl<PHINode*> &ValueEqualPHIs) {
+static bool PHIsEqualValue(PHINode *PN, Value *&NonPhiInVal,
+ SmallPtrSetImpl<PHINode *> &ValueEqualPHIs) {
// See if we already saw this PHI node.
if (!ValueEqualPHIs.insert(PN).second)
return true;
@@ -1010,8 +1010,11 @@ static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal,
// the value.
for (Value *Op : PN->incoming_values()) {
if (PHINode *OpPN = dyn_cast<PHINode>(Op)) {
- if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs))
- return false;
+ if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs)) {
+ if (NonPhiInVal)
+ return false;
+ NonPhiInVal = OpPN;
+ }
} else if (Op != NonPhiInVal)
return false;
}
@@ -1368,7 +1371,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
// sinking.
auto InsertPt = BB->getFirstInsertionPt();
if (InsertPt != BB->end()) {
- Self.Builder.SetInsertPoint(&*InsertPt);
+ Self.Builder.SetInsertPoint(&*BB, InsertPt);
return Self.Builder.CreateNot(Cond);
}
@@ -1437,22 +1440,45 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
// are induction variable analysis (sometimes) and ADCE, which is only run
// late.
if (PHIUser->hasOneUse() &&
- (isa<BinaryOperator>(PHIUser) || isa<GetElementPtrInst>(PHIUser)) &&
+ (isa<BinaryOperator>(PHIUser) || isa<UnaryOperator>(PHIUser) ||
+ isa<GetElementPtrInst>(PHIUser)) &&
PHIUser->user_back() == &PN) {
return replaceInstUsesWith(PN, PoisonValue::get(PN.getType()));
}
- // When a PHI is used only to be compared with zero, it is safe to replace
- // an incoming value proved as known nonzero with any non-zero constant.
- // For example, in the code below, the incoming value %v can be replaced
- // with any non-zero constant based on the fact that the PHI is only used to
- // be compared with zero and %v is a known non-zero value:
- // %v = select %cond, 1, 2
- // %p = phi [%v, BB] ...
- // icmp eq, %p, 0
- auto *CmpInst = dyn_cast<ICmpInst>(PHIUser);
- // FIXME: To be simple, handle only integer type for now.
- if (CmpInst && isa<IntegerType>(PN.getType()) && CmpInst->isEquality() &&
- match(CmpInst->getOperand(1), m_Zero())) {
+ }
+
+ // When a PHI is used only to be compared with zero, it is safe to replace
+ // an incoming value proved as known nonzero with any non-zero constant.
+ // For example, in the code below, the incoming value %v can be replaced
+ // with any non-zero constant based on the fact that the PHI is only used to
+ // be compared with zero and %v is a known non-zero value:
+ // %v = select %cond, 1, 2
+ // %p = phi [%v, BB] ...
+ // icmp eq, %p, 0
+ // FIXME: To be simple, handle only integer type for now.
+ // This handles a small number of uses to keep the complexity down, and an
+ // icmp(or(phi)) can equally be replaced with any non-zero constant as the
+ // "or" will only add bits.
+ if (!PN.hasNUsesOrMore(3)) {
+ SmallVector<Instruction *> DropPoisonFlags;
+ bool AllUsesOfPhiEndsInCmp = all_of(PN.users(), [&](User *U) {
+ auto *CmpInst = dyn_cast<ICmpInst>(U);
+ if (!CmpInst) {
+      // This is always correct as OR only adds bits and we are checking
+ // against 0.
+ if (U->hasOneUse() && match(U, m_c_Or(m_Specific(&PN), m_Value()))) {
+ DropPoisonFlags.push_back(cast<Instruction>(U));
+ CmpInst = dyn_cast<ICmpInst>(U->user_back());
+ }
+ }
+ if (!CmpInst || !isa<IntegerType>(PN.getType()) ||
+ !CmpInst->isEquality() || !match(CmpInst->getOperand(1), m_Zero())) {
+ return false;
+ }
+ return true;
+ });
+ // All uses of PHI results in a compare with zero.
+ if (AllUsesOfPhiEndsInCmp) {
ConstantInt *NonZeroConst = nullptr;
bool MadeChange = false;
for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
@@ -1461,9 +1487,11 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) {
if (!NonZeroConst)
NonZeroConst = getAnyNonZeroConstInt(PN);
-
if (NonZeroConst != VA) {
replaceOperand(PN, I, NonZeroConst);
+ // The "disjoint" flag may no longer hold after the transform.
+ for (Instruction *I : DropPoisonFlags)
+ I->dropPoisonGeneratingFlags();
MadeChange = true;
}
}
@@ -1478,7 +1506,9 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
// z = some value; x = phi (y, z); y = phi (x, z)
// where the phi nodes don't necessarily need to be in the same block. Do a
// quick check to see if the PHI node only contains a single non-phi value, if
- // so, scan to see if the phi cycle is actually equal to that value.
+ // so, scan to see if the phi cycle is actually equal to that value. If the
+ // phi has no non-phi values then allow the "NonPhiInVal" to be set later if
+ // one of the phis itself does not have a single input.
{
unsigned InValNo = 0, NumIncomingVals = PN.getNumIncomingValues();
// Scan for the first non-phi operand.
@@ -1486,25 +1516,25 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
isa<PHINode>(PN.getIncomingValue(InValNo)))
++InValNo;
- if (InValNo != NumIncomingVals) {
- Value *NonPhiInVal = PN.getIncomingValue(InValNo);
+ Value *NonPhiInVal =
+ InValNo != NumIncomingVals ? PN.getIncomingValue(InValNo) : nullptr;
- // Scan the rest of the operands to see if there are any conflicts, if so
- // there is no need to recursively scan other phis.
+ // Scan the rest of the operands to see if there are any conflicts, if so
+ // there is no need to recursively scan other phis.
+ if (NonPhiInVal)
for (++InValNo; InValNo != NumIncomingVals; ++InValNo) {
Value *OpVal = PN.getIncomingValue(InValNo);
if (OpVal != NonPhiInVal && !isa<PHINode>(OpVal))
break;
}
- // If we scanned over all operands, then we have one unique value plus
- // phi values. Scan PHI nodes to see if they all merge in each other or
- // the value.
- if (InValNo == NumIncomingVals) {
- SmallPtrSet<PHINode*, 16> ValueEqualPHIs;
- if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs))
- return replaceInstUsesWith(PN, NonPhiInVal);
- }
+ // If we scanned over all operands, then we have one unique value plus
+ // phi values. Scan PHI nodes to see if they all merge in each other or
+ // the value.
+ if (InValNo == NumIncomingVals) {
+ SmallPtrSet<PHINode *, 16> ValueEqualPHIs;
+ if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs))
+ return replaceInstUsesWith(PN, NonPhiInVal);
}
}
@@ -1512,11 +1542,12 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
// the blocks in the same order. This will help identical PHIs be eliminated
// by other passes. Other passes shouldn't depend on this for correctness
// however.
- PHINode *FirstPN = cast<PHINode>(PN.getParent()->begin());
- if (&PN != FirstPN)
- for (unsigned I = 0, E = FirstPN->getNumIncomingValues(); I != E; ++I) {
+ auto Res = PredOrder.try_emplace(PN.getParent());
+ if (!Res.second) {
+ const auto &Preds = Res.first->second;
+ for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
BasicBlock *BBA = PN.getIncomingBlock(I);
- BasicBlock *BBB = FirstPN->getIncomingBlock(I);
+ BasicBlock *BBB = Preds[I];
if (BBA != BBB) {
Value *VA = PN.getIncomingValue(I);
unsigned J = PN.getBasicBlockIndex(BBB);
@@ -1531,6 +1562,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
// this in this case.
}
}
+ } else {
+ // Remember the block order of the first encountered phi node.
+ append_range(Res.first->second, PN.blocks());
+ }
// Is there an identical PHI node in this basic block?
for (PHINode &IdenticalPN : PN.getParent()->phis()) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 661c50062223..2dda46986f0f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -689,34 +689,40 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
}
/// We want to turn:
-/// (select (icmp eq (and X, C1), 0), Y, (or Y, C2))
+/// (select (icmp eq (and X, C1), 0), Y, (BinOp Y, C2))
/// into:
-/// (or (shl (and X, C1), C3), Y)
+/// IF C2 u>= C1
+/// (BinOp Y, (shl (and X, C1), C3))
+/// ELSE
+/// (BinOp Y, (lshr (and X, C1), C3))
/// iff:
+/// 0 on the RHS is the identity value (e.g. add, xor, shl, etc...)
/// C1 and C2 are both powers of 2
/// where:
-/// C3 = Log(C2) - Log(C1)
+/// IF C2 u>= C1
+/// C3 = Log(C2) - Log(C1)
+/// ELSE
+/// C3 = Log(C1) - Log(C2)
///
/// This transform handles cases where:
/// 1. The icmp predicate is inverted
/// 2. The select operands are reversed
/// 3. The magnitude of C2 and C1 are flipped
-static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
+static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
Value *FalseVal,
InstCombiner::BuilderTy &Builder) {
// Only handle integer compares. Also, if this is a vector select, we need a
// vector compare.
if (!TrueVal->getType()->isIntOrIntVectorTy() ||
- TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
+ TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
return nullptr;
Value *CmpLHS = IC->getOperand(0);
Value *CmpRHS = IC->getOperand(1);
- Value *V;
unsigned C1Log;
- bool IsEqualZero;
bool NeedAnd = false;
+ CmpInst::Predicate Pred = IC->getPredicate();
if (IC->isEquality()) {
if (!match(CmpRHS, m_Zero()))
return nullptr;
@@ -725,49 +731,49 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
return nullptr;
- V = CmpLHS;
C1Log = C1->logBase2();
- IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ;
- } else if (IC->getPredicate() == ICmpInst::ICMP_SLT ||
- IC->getPredicate() == ICmpInst::ICMP_SGT) {
- // We also need to recognize (icmp slt (trunc (X)), 0) and
- // (icmp sgt (trunc (X)), -1).
- IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT;
- if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) ||
- (!IsEqualZero && !match(CmpRHS, m_Zero())))
- return nullptr;
-
- if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V)))))
+ } else {
+ APInt C1;
+ if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) ||
+ !C1.isPowerOf2())
return nullptr;
- C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1;
+ C1Log = C1.logBase2();
NeedAnd = true;
- } else {
- return nullptr;
}
+ Value *Y, *V = CmpLHS;
+ BinaryOperator *BinOp;
const APInt *C2;
- bool OrOnTrueVal = false;
- bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)));
- if (!OrOnFalseVal)
- OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)));
-
- if (!OrOnFalseVal && !OrOnTrueVal)
+ bool NeedXor;
+ if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) {
+ Y = TrueVal;
+ BinOp = cast<BinaryOperator>(FalseVal);
+ NeedXor = Pred == ICmpInst::ICMP_NE;
+ } else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) {
+ Y = FalseVal;
+ BinOp = cast<BinaryOperator>(TrueVal);
+ NeedXor = Pred == ICmpInst::ICMP_EQ;
+ } else {
return nullptr;
+ }
- Value *Y = OrOnFalseVal ? TrueVal : FalseVal;
+ // Check that 0 on RHS is identity value for this binop.
+ auto *IdentityC =
+ ConstantExpr::getBinOpIdentity(BinOp->getOpcode(), BinOp->getType(),
+ /*AllowRHSConstant*/ true);
+ if (IdentityC == nullptr || !IdentityC->isNullValue())
+ return nullptr;
unsigned C2Log = C2->logBase2();
- bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);
bool NeedShift = C1Log != C2Log;
bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() !=
V->getType()->getScalarSizeInBits();
// Make sure we don't create more instructions than we save.
- Value *Or = OrOnFalseVal ? FalseVal : TrueVal;
- if ((NeedShift + NeedXor + NeedZExtTrunc) >
- (IC->hasOneUse() + Or->hasOneUse()))
+ if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
+ (IC->hasOneUse() + BinOp->hasOneUse()))
return nullptr;
if (NeedAnd) {
@@ -788,7 +794,7 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
if (NeedXor)
V = Builder.CreateXor(V, *C2);
- return Builder.CreateOr(V, Y);
+ return Builder.CreateBinOp(BinOp->getOpcode(), Y, V);
}
/// Canonicalize a set or clear of a masked set of constant bits to
@@ -870,7 +876,7 @@ static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
auto *FalseValI = cast<Instruction>(FalseVal);
auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"),
- *FalseValI);
+ FalseValI->getIterator());
IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY);
return IC.replaceInstUsesWith(SI, FalseValI);
}
@@ -1303,45 +1309,28 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
return nullptr;
// InstSimplify already performed this fold if it was possible subject to
- // current poison-generating flags. Try the transform again with
- // poison-generating flags temporarily dropped.
- bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false;
- if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
- WasNUW = OBO->hasNoUnsignedWrap();
- WasNSW = OBO->hasNoSignedWrap();
- FalseInst->setHasNoUnsignedWrap(false);
- FalseInst->setHasNoSignedWrap(false);
- }
- if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
- WasExact = PEO->isExact();
- FalseInst->setIsExact(false);
- }
- if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) {
- WasInBounds = GEP->isInBounds();
- GEP->setIsInBounds(false);
- }
+ // current poison-generating flags. Check whether dropping poison-generating
+ // flags enables the transform.
// Try each equivalence substitution possibility.
// We have an 'EQ' comparison, so the select's false value will propagate.
// Example:
// (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
+ SmallVector<Instruction *> DropFlags;
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
- /* AllowRefinement */ false) == TrueVal ||
+ /* AllowRefinement */ false,
+ &DropFlags) == TrueVal ||
simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
- /* AllowRefinement */ false) == TrueVal) {
+ /* AllowRefinement */ false,
+ &DropFlags) == TrueVal) {
+ for (Instruction *I : DropFlags) {
+ I->dropPoisonGeneratingFlagsAndMetadata();
+ Worklist.add(I);
+ }
+
return replaceInstUsesWith(Sel, FalseVal);
}
- // Restore poison-generating flags if the transform did not apply.
- if (WasNUW)
- FalseInst->setHasNoUnsignedWrap();
- if (WasNSW)
- FalseInst->setHasNoSignedWrap();
- if (WasExact)
- FalseInst->setIsExact();
- if (WasInBounds)
- cast<GetElementPtrInst>(FalseInst)->setIsInBounds();
-
return nullptr;
}
@@ -1506,8 +1495,13 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
if (!match(ReplacementLow, m_ImmConstant(LowC)) ||
!match(ReplacementHigh, m_ImmConstant(HighC)))
return nullptr;
- ReplacementLow = ConstantExpr::getSExt(LowC, X->getType());
- ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType());
+ const DataLayout &DL = Sel0.getModule()->getDataLayout();
+ ReplacementLow =
+ ConstantFoldCastOperand(Instruction::SExt, LowC, X->getType(), DL);
+ ReplacementHigh =
+ ConstantFoldCastOperand(Instruction::SExt, HighC, X->getType(), DL);
+ assert(ReplacementLow && ReplacementHigh &&
+ "Constant folding of ImmConstant cannot fail");
}
// All good, finally emit the new pattern.
@@ -1797,7 +1791,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder))
return V;
- if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))
+ if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
@@ -2094,9 +2088,8 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
// If the constant is the same after truncation to the smaller type and
// extension to the original type, we can narrow the select.
Type *SelType = Sel.getType();
- Constant *TruncC = ConstantExpr::getTrunc(C, SmallType);
- Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType);
- if (ExtC == C && ExtInst->hasOneUse()) {
+ Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode);
+ if (TruncC && ExtInst->hasOneUse()) {
Value *TruncCVal = cast<Value>(TruncC);
if (ExtInst == Sel.getFalseValue())
std::swap(X, TruncCVal);
@@ -2107,23 +2100,6 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType);
}
- // If one arm of the select is the extend of the condition, replace that arm
- // with the extension of the appropriate known bool value.
- if (Cond == X) {
- if (ExtInst == Sel.getTrueValue()) {
- // select X, (sext X), C --> select X, -1, C
- // select X, (zext X), C --> select X, 1, C
- Constant *One = ConstantInt::getTrue(SmallType);
- Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType);
- return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel);
- } else {
- // select X, C, (sext X) --> select X, C, 0
- // select X, C, (zext X) --> select X, C, 0
- Constant *Zero = ConstantInt::getNullValue(SelType);
- return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel);
- }
- }
-
return nullptr;
}
@@ -2561,7 +2537,7 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB,
return nullptr;
}
- Builder.SetInsertPoint(&*BB->begin());
+ Builder.SetInsertPoint(BB, BB->begin());
auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size());
for (auto *Pred : predecessors(BB))
PN->addIncoming(Inputs[Pred], Pred);
@@ -2584,6 +2560,61 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
return nullptr;
}
+/// Tries to reduce a pattern that arises when calculating the remainder of the
+/// Euclidean division. When the divisor is a power of two and is guaranteed not
+/// to be negative, a signed remainder can be folded with a bitwise and.
+///
+/// (x % n) < 0 ? (x % n) + n : (x % n)
+/// -> x & (n - 1)
+static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
+ IRBuilderBase &Builder) {
+ Value *CondVal = SI.getCondition();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+
+ ICmpInst::Predicate Pred;
+ Value *Op, *RemRes, *Remainder;
+ const APInt *C;
+ bool TrueIfSigned = false;
+
+ if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) &&
+ IC.isSignBitCheck(Pred, *C, TrueIfSigned)))
+ return nullptr;
+
+ // If the sign bit is not set, we have a SGE/SGT comparison, and the operands
+ // of the select are inverted.
+ if (!TrueIfSigned)
+ std::swap(TrueVal, FalseVal);
+
+ auto FoldToBitwiseAnd = [&](Value *Remainder) -> Instruction * {
+ Value *Add = Builder.CreateAdd(
+ Remainder, Constant::getAllOnesValue(RemRes->getType()));
+ return BinaryOperator::CreateAnd(Op, Add);
+ };
+
+ // Match the general case:
+ // %rem = srem i32 %x, %n
+ // %cnd = icmp slt i32 %rem, 0
+ // %add = add i32 %rem, %n
+ // %sel = select i1 %cnd, i32 %add, i32 %rem
+ if (match(TrueVal, m_Add(m_Value(RemRes), m_Value(Remainder))) &&
+ match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) &&
+ IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) &&
+ FalseVal == RemRes)
+ return FoldToBitwiseAnd(Remainder);
+
+ // Match the case where the one arm has been replaced by constant 1:
+ // %rem = srem i32 %n, 2
+ // %cnd = icmp slt i32 %rem, 0
+ // %sel = select i1 %cnd, i32 1, i32 %rem
+ if (match(TrueVal, m_One()) &&
+ match(RemRes, m_SRem(m_Value(Op), m_SpecificInt(2))) &&
+ FalseVal == RemRes)
+ return FoldToBitwiseAnd(ConstantInt::get(RemRes->getType(), 2));
+
+ return nullptr;
+}
+
static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
if (!FI)
@@ -2860,8 +2891,15 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
std::swap(InnerSel.TrueVal, InnerSel.FalseVal);
Value *AltCond = nullptr;
- auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) {
- return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond)));
+ auto matchOuterCond = [OuterSel, IsAndVariant, &AltCond](auto m_InnerCond) {
+ // An unsimplified select condition can match both LogicalAnd and LogicalOr
+ // (select true, true, false). Since below we assume that LogicalAnd implies
+ // InnerSel match the FVal and vice versa for LogicalOr, we can't match the
+ // alternative pattern here.
+ return IsAndVariant ? match(OuterSel.Cond,
+ m_c_LogicalAnd(m_InnerCond, m_Value(AltCond)))
+ : match(OuterSel.Cond,
+ m_c_LogicalOr(m_InnerCond, m_Value(AltCond)));
};
// Finally, match the condition that was driving the outermost `select`,
@@ -3024,31 +3062,37 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
return replaceOperand(SI, 0, A);
- // select a, (select ~a, true, b), false -> select a, b, false
- if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) &&
- match(FalseVal, m_Zero()))
- return replaceOperand(SI, 1, B);
- // select a, true, (select ~a, b, false) -> select a, true, b
- if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) &&
- match(TrueVal, m_One()))
- return replaceOperand(SI, 2, B);
// ~(A & B) & (A | B) --> A ^ B
if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))),
m_c_LogicalOr(m_Deferred(A), m_Deferred(B)))))
return BinaryOperator::CreateXor(A, B);
- // select (~a | c), a, b -> and a, (or c, freeze(b))
- if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) &&
- CondVal->hasOneUse()) {
- FalseVal = Builder.CreateFreeze(FalseVal);
- return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal));
+ // select (~a | c), a, b -> select a, (select c, true, b), false
+ if (match(CondVal,
+ m_OneUse(m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))))) {
+ Value *OrV = Builder.CreateSelect(C, One, FalseVal);
+ return SelectInst::Create(TrueVal, OrV, Zero);
+ }
+ // select (c & b), a, b -> select b, (select ~c, true, a), false
+ if (match(CondVal, m_OneUse(m_c_And(m_Value(C), m_Specific(FalseVal))))) {
+ if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) {
+ Value *OrV = Builder.CreateSelect(NotC, One, TrueVal);
+ return SelectInst::Create(FalseVal, OrV, Zero);
+ }
+ }
+ // select (a | c), a, b -> select a, true, (select ~c, b, false)
+ if (match(CondVal, m_OneUse(m_c_Or(m_Specific(TrueVal), m_Value(C))))) {
+ if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) {
+ Value *AndV = Builder.CreateSelect(NotC, FalseVal, Zero);
+ return SelectInst::Create(TrueVal, One, AndV);
+ }
}
- // select (~c & b), a, b -> and b, (or freeze(a), c)
- if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) &&
- CondVal->hasOneUse()) {
- TrueVal = Builder.CreateFreeze(TrueVal);
- return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal));
+ // select (c & ~b), a, b -> select b, true, (select c, a, false)
+ if (match(CondVal,
+ m_OneUse(m_c_And(m_Value(C), m_Not(m_Specific(FalseVal)))))) {
+ Value *AndV = Builder.CreateSelect(C, TrueVal, Zero);
+ return SelectInst::Create(FalseVal, One, AndV);
}
if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
@@ -3057,7 +3101,7 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
Value *Op1 = IsAnd ? TrueVal : FalseVal;
if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) {
auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr");
- InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser()));
+ InsertNewInstBefore(FI, cast<Instruction>(Y->getUser())->getIterator());
replaceUse(*Y, FI);
return replaceInstUsesWith(SI, Op1);
}
@@ -3272,6 +3316,31 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
Masked);
}
+/// Return true if multiplying \p MulVal by zero is guaranteed to yield zero
+/// under \p FMF: \p MulVal must be neither NaN nor infinity (0 * NaN/inf is
+/// NaN), and unless nsz is set its sign bit must additionally be known clear
+/// so the product cannot be -0.0.
+bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
+ const Instruction *CtxI) const {
+ // Only the fcNegative classes are interesting; they cover both the -0.0
+ // concern and the negative-infinity case.
+ KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI);
+
+ return Known.isKnownNeverNaN() && Known.isKnownNeverInfinity() &&
+ (FMF.noSignedZeros() || Known.signBitIsZeroOrNaN());
+}
+
+/// Match the "scale x only when x == 0 would be unchanged" arm pattern:
+/// \p Cmp1 must be +0.0 and \p TrueVal an fmul having \p Cmp0 as one operand.
+/// Returns true if the fmul's other operand is known to keep a zero product
+/// equal to zero (see fmulByZeroIsZero), so the caller may fold the whole
+/// select to \p Cmp0.
+static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0,
+ Value *Cmp1, Value *TrueVal,
+ Value *FalseVal, Instruction &CtxI,
+ bool SelectIsNSZ) {
+ // NOTE(review): FalseVal is currently unused here; the caller passes the
+ // non-fmul select arm for symmetry.
+ Value *MulRHS;
+ if (match(Cmp1, m_PosZeroFP()) &&
+ match(TrueVal, m_c_FMul(m_Specific(Cmp0), m_Value(MulRHS)))) {
+ FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags();
+ // nsz must be on the select, it must be ignored on the multiply. We
+ // need nnan and ninf on the multiply for the other value.
+ FMF.setNoSignedZeros(SelectIsNSZ);
+ return IC.fmulByZeroIsZero(MulRHS, FMF, &CtxI);
+ }
+
+ return false;
+}
+
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3303,28 +3372,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
ConstantInt::getFalse(CondType), SQ,
/* AllowRefinement */ true))
return replaceOperand(SI, 2, S);
-
- // Handle patterns involving sext/zext + not explicitly,
- // as simplifyWithOpReplaced() only looks past one instruction.
- Value *NotCond;
-
- // select a, sext(!a), b -> select !a, b, 0
- // select a, zext(!a), b -> select !a, b, 0
- if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond),
- m_Not(m_Specific(CondVal))))))
- return SelectInst::Create(NotCond, FalseVal,
- Constant::getNullValue(SelType));
-
- // select a, b, zext(!a) -> select !a, 1, b
- if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond),
- m_Not(m_Specific(CondVal))))))
- return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal);
-
- // select a, b, sext(!a) -> select !a, -1, b
- if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond),
- m_Not(m_Specific(CondVal))))))
- return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType),
- TrueVal);
}
if (Instruction *R = foldSelectOfBools(SI))
@@ -3362,7 +3409,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
+ auto *SIFPOp = dyn_cast<FPMathOperator>(&SI);
+
if (auto *FCmp = dyn_cast<FCmpInst>(CondVal)) {
+ FCmpInst::Predicate Pred = FCmp->getPredicate();
Value *Cmp0 = FCmp->getOperand(0), *Cmp1 = FCmp->getOperand(1);
// Are we selecting a value based on a comparison of the two values?
if ((Cmp0 == TrueVal && Cmp1 == FalseVal) ||
@@ -3372,7 +3422,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
//
// e.g.
// (X ugt Y) ? X : Y -> (X ole Y) ? Y : X
- if (FCmp->hasOneUse() && FCmpInst::isUnordered(FCmp->getPredicate())) {
+ if (FCmp->hasOneUse() && FCmpInst::isUnordered(Pred)) {
FCmpInst::Predicate InvPred = FCmp->getInversePredicate();
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
// FIXME: The FMF should propagate from the select, not the fcmp.
@@ -3383,14 +3433,47 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
return replaceInstUsesWith(SI, NewSel);
}
}
+
+ if (SIFPOp) {
+ // Fold out scale-if-equals-zero pattern.
+ //
+ // This pattern appears in code with denormal range checks after it's
+ // assumed denormals are treated as zero. This drops a canonicalization.
+
+ // TODO: Could relax the signed zero logic. We just need to know the sign
+ // of the result matches (fmul x, y has the same sign as x).
+ //
+ // TODO: Handle always-canonicalizing variant that selects some value or 1
+ // scaling factor in the fmul visitor.
+
+ // TODO: Handle ldexp too
+
+ Value *MatchCmp0 = nullptr;
+ Value *MatchCmp1 = nullptr;
+
+ // (select (fcmp [ou]eq x, 0.0), (fmul x, K), x => x
+ // (select (fcmp [ou]ne x, 0.0), x, (fmul x, K) => x
+ if (Pred == CmpInst::FCMP_OEQ || Pred == CmpInst::FCMP_UEQ) {
+ MatchCmp0 = FalseVal;
+ MatchCmp1 = TrueVal;
+ } else if (Pred == CmpInst::FCMP_ONE || Pred == CmpInst::FCMP_UNE) {
+ MatchCmp0 = TrueVal;
+ MatchCmp1 = FalseVal;
+ }
+
+ if (Cmp0 == MatchCmp0 &&
+ matchFMulByZeroIfResultEqZero(*this, Cmp0, Cmp1, MatchCmp1, MatchCmp0,
+ SI, SIFPOp->hasNoSignedZeros()))
+ return replaceInstUsesWith(SI, Cmp0);
+ }
}
- if (isa<FPMathOperator>(SI)) {
+ if (SIFPOp) {
// TODO: Try to forward-propagate FMF from select arms to the select.
// Canonicalize select of FP values where NaN and -0.0 are not valid as
// minnum/maxnum intrinsics.
- if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) {
+ if (SIFPOp->hasNoNaNs() && SIFPOp->hasNoSignedZeros()) {
Value *X, *Y;
if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y))))
return replaceInstUsesWith(
@@ -3430,6 +3513,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *I = foldSelectExtConst(SI))
return I;
+ if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
+ return I;
+
// Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
// Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 89dad455f015..b7958978c450 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -136,9 +136,14 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
assert(IdenticalShOpcodes && "Should not get here with different shifts.");
- // All good, we can do this fold.
- NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
+ if (NewShAmt->getType() != X->getType()) {
+ NewShAmt = ConstantFoldCastOperand(Instruction::ZExt, NewShAmt,
+ X->getType(), SQ.DL);
+ if (!NewShAmt)
+ return nullptr;
+ }
+ // All good, we can do this fold.
BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
// The flags can only be propagated if there wasn't a trunc.
@@ -245,7 +250,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
SumOfShAmts = Constant::replaceUndefsWith(
SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
ExtendedTy->getScalarSizeInBits()));
- auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+ auto *ExtendedSumOfShAmts = ConstantFoldCastOperand(
+ Instruction::ZExt, SumOfShAmts, ExtendedTy, Q.DL);
+ if (!ExtendedSumOfShAmts)
+ return nullptr;
+
// And compute the mask as usual: ~(-1 << (SumOfShAmts))
auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
auto *ExtendedInvertedMask =
@@ -278,16 +287,22 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
ShAmtsDiff = Constant::replaceUndefsWith(
ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
-WidestTyBitWidth));
- auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+ auto *ExtendedNumHighBitsToClear = ConstantFoldCastOperand(
+ Instruction::ZExt,
ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
WidestTyBitWidth,
/*isSigned=*/false),
ShAmtsDiff),
- ExtendedTy);
+ ExtendedTy, Q.DL);
+ if (!ExtendedNumHighBitsToClear)
+ return nullptr;
+
// And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
- NewMask =
- ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+ NewMask = ConstantFoldBinaryOpOperands(Instruction::LShr, ExtendedAllOnes,
+ ExtendedNumHighBitsToClear, Q.DL);
+ if (!NewMask)
+ return nullptr;
} else
return nullptr; // Don't know anything about this pattern.
@@ -545,8 +560,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
/// this succeeds, getShiftedValue() will be called to produce the value.
static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
InstCombinerImpl &IC, Instruction *CxtI) {
- // We can always evaluate constants shifted.
- if (isa<Constant>(V))
+ // We can always evaluate immediate constants.
+ if (match(V, m_ImmConstant()))
return true;
Instruction *I = dyn_cast<Instruction>(V);
@@ -709,13 +724,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::Mul: {
assert(!isLeftShift && "Unexpected shift direction!");
auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0));
- IC.InsertNewInstWith(Neg, *I);
+ IC.InsertNewInstWith(Neg, I->getIterator());
unsigned TypeWidth = I->getType()->getScalarSizeInBits();
APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits);
auto *And = BinaryOperator::CreateAnd(Neg,
ConstantInt::get(I->getType(), Mask));
And->takeName(I);
- return IC.InsertNewInstWith(And, *I);
+ return IC.InsertNewInstWith(And, I->getIterator());
}
}
}
@@ -745,7 +760,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
// (C2 >> X) >> C1 --> (C2 >> C1) >> X
Constant *C2;
Value *X;
- if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X))))
+ if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X))))
return BinaryOperator::Create(
I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X);
@@ -928,6 +943,60 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}
+// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
+// Returns true iff a flag was newly set, so the caller can report a change.
+static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
+ assert(I.isShift() && "Expected a shift as input");
+ // We already have all the flags.
+ if (I.getOpcode() == Instruction::Shl) {
+ if (I.hasNoUnsignedWrap() && I.hasNoSignedWrap())
+ return false;
+ } else {
+ if (I.isExact())
+ return false;
+
+ // shr (shl X, Y), Y -- the shl zero-filled the low Y bits, so shifting
+ // them back out cannot lose any set bits: the shr is exact.
+ if (match(I.getOperand(0), m_Shl(m_Value(), m_Specific(I.getOperand(1))))) {
+ I.setIsExact();
+ return true;
+ }
+ }
+
+ // Compute what we know about shift count.
+ KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
+ unsigned BitWidth = KnownCnt.getBitWidth();
+ // Since shift produces a poison value if RHS is equal to or larger than the
+ // bit width, we can safely assume that RHS is less than the bit width.
+ uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
+
+ KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
+ bool Changed = false;
+
+ if (I.getOpcode() == Instruction::Shl) {
+ // If we have as many leading zeros as the maximum shift cnt we have nuw.
+ if (!I.hasNoUnsignedWrap() && MaxCnt <= KnownAmt.countMinLeadingZeros()) {
+ I.setHasNoUnsignedWrap();
+ Changed = true;
+ }
+ // If we have more sign bits than maximum shift cnt we have nsw.
+ if (!I.hasNoSignedWrap()) {
+ if (MaxCnt < KnownAmt.countMinSignBits() ||
+ MaxCnt < ComputeNumSignBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC,
+ Q.CxtI, Q.DT)) {
+ I.setHasNoSignedWrap();
+ Changed = true;
+ }
+ }
+ return Changed;
+ }
+
+ // If we have at least as many trailing zeros as maximum count then we have
+ // exact. setIsExact(false) is a no-op here: isExact() was false on entry
+ // (we returned early above otherwise).
+ Changed = MaxCnt <= KnownAmt.countMinTrailingZeros();
+ I.setIsExact(Changed);
+
+ return Changed;
+}
+
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -976,7 +1045,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
// If C1 < C: (X >>?,exact C1) << C --> X << (C - C1)
Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoUnsignedWrap(
+ I.hasNoUnsignedWrap() ||
+ (ShrAmt &&
+ cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
+ I.hasNoSignedWrap()));
NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
return NewShl;
}
@@ -997,7 +1070,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
// If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C)
Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoUnsignedWrap(
+ I.hasNoUnsignedWrap() ||
+ (ShrAmt &&
+ cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
+ I.hasNoSignedWrap()));
NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
Builder.Insert(NewShl);
APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
@@ -1108,22 +1185,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
Value *NewShift = Builder.CreateShl(X, Op1);
return BinaryOperator::CreateSub(NewLHS, NewShift);
}
-
- // If the shifted-out value is known-zero, then this is a NUW shift.
- if (!I.hasNoUnsignedWrap() &&
- MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0,
- &I)) {
- I.setHasNoUnsignedWrap();
- return &I;
- }
-
- // If the shifted-out value is all signbits, then this is a NSW shift.
- if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) {
- I.setHasNoSignedWrap();
- return &I;
- }
}
+ if (setShiftFlags(I, Q))
+ return &I;
+
// Transform (x >> y) << y to x & (-1 << y)
// Valid for any type of right-shift.
Value *X;
@@ -1161,15 +1227,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
Value *NegX = Builder.CreateNeg(X, "neg");
return BinaryOperator::CreateAnd(NegX, X);
}
-
- // The only way to shift out the 1 is with an over-shift, so that would
- // be poison with or without "nuw". Undef is excluded because (undef << X)
- // is not undef (it is zero).
- Constant *ConstantOne = cast<Constant>(Op0);
- if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) {
- I.setHasNoUnsignedWrap();
- return &I;
- }
}
return nullptr;
@@ -1235,9 +1292,10 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
unsigned ShlAmtC = C1->getZExtValue();
Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC);
if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
- // (X <<nuw C1) >>u C --> X <<nuw (C1 - C)
+ // (X <<nuw C1) >>u C --> X <<nuw/nsw (C1 - C)
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
NewShl->setHasNoUnsignedWrap(true);
+ NewShl->setHasNoSignedWrap(ShAmtC > 0);
return NewShl;
}
if (Op0->hasOneUse()) {
@@ -1370,12 +1428,13 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
if (Op0->hasOneUse()) {
APInt NewMulC = MulC->lshr(ShAmtC);
// if c is divisible by (1 << ShAmtC):
- // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC)
+ // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
if (MulC->eq(NewMulC.shl(ShAmtC))) {
auto *NewMul =
BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
- BinaryOperator *OrigMul = cast<BinaryOperator>(Op0);
- NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap());
+ assert(ShAmtC != 0 &&
+ "lshr X, 0 should be handled by simplifyLShrInst.");
+ NewMul->setHasNoSignedWrap(true);
return NewMul;
}
}
@@ -1414,15 +1473,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
Value *And = Builder.CreateAnd(BoolX, BoolY);
return new ZExtInst(And, Ty);
}
-
- // If the shifted-out value is known-zero, then this is an exact shift.
- if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) {
- I.setIsExact();
- return &I;
- }
}
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
+ if (setShiftFlags(I, Q))
+ return &I;
+
// Transform (x << y) >> y to x & (-1 >> y)
if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
@@ -1581,15 +1637,12 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
}
-
- // If the shifted-out value is known-zero, then this is an exact shift.
- if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
- I.setIsExact();
- return &I;
- }
}
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
+ if (setShiftFlags(I, Q))
+ return &I;
+
// Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)`
// as the pattern to splat the lowest bit.
// FIXME: iff X is already masked, we don't need the one-use check.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 00eece9534b0..046ce9d1207e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -24,6 +24,12 @@ using namespace llvm::PatternMatch;
#define DEBUG_TYPE "instcombine"
+static cl::opt<bool>
+ VerifyKnownBits("instcombine-verify-known-bits",
+ cl::desc("Verify that computeKnownBits() and "
+ "SimplifyDemandedBits() are consistent"),
+ cl::Hidden, cl::init(false));
+
/// Check to see if the specified operand of the specified instruction is a
/// constant integer. If so, check to see if there are any bits set in the
/// constant that are not demanded. If so, shrink the constant and return true.
@@ -48,15 +54,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
return true;
}
+/// Returns the bitwidth of the given scalar or pointer type. For vector types,
+/// returns the element type's bitwidth.
+static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
+ if (unsigned BitWidth = Ty->getScalarSizeInBits())
+ return BitWidth;
+ // getScalarSizeInBits() returns 0 for pointer types; ask the DataLayout
+ // for the pointer's size instead.
+ return DL.getPointerTypeSizeInBits(Ty);
+}
/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
-bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
- unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
- KnownBits Known(BitWidth);
- APInt DemandedMask(APInt::getAllOnes(BitWidth));
-
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
+ KnownBits &Known) {
+ APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
0, &Inst);
if (!V) return false;
@@ -65,6 +76,13 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
return true;
}
+/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
+/// the instruction has any properties that allow us to simplify its operands.
+/// Convenience overload for callers that do not need the computed known bits:
+/// allocates a scratch KnownBits sized to the instruction's scalar (or
+/// pointer) width and forwards to the two-argument form.
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
+ KnownBits Known(getBitWidth(Inst.getType(), DL));
+ return SimplifyDemandedInstructionBits(Inst, Known);
+}
+
/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
@@ -95,8 +113,8 @@ bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
/// expression.
/// Known.One and Known.Zero always follow the invariant that:
/// Known.One & Known.Zero == 0.
-/// That is, a bit can't be both 1 and 0. Note that the bits in Known.One and
-/// Known.Zero may only be accurate for those bits set in DemandedMask. Note
+/// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero
+/// are accurate even for bits not in DemandedMask. Note
/// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must all
/// be the same.
///
@@ -143,7 +161,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI);
KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);
-
// If this is the root being simplified, allow it to have multiple uses,
// just set the DemandedMask to all bits so that we can try to simplify the
// operands. This allows visitTruncInst (for example) to simplify the
@@ -196,7 +213,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
- Depth, DL, &AC, CxtI, &DT);
+ Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -220,13 +237,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If either the LHS or the RHS are One, the result is One.
if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
- Depth + 1))
+ Depth + 1)) {
+ // Disjoint flag may no longer hold.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
- Depth, DL, &AC, CxtI, &DT);
+ Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -244,6 +264,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (ShrinkDemandedConstant(I, 1, DemandedMask))
return I;
+ // Infer disjoint flag if no common bits are set.
+ if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) {
+ WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown),
+ RHSCache(I->getOperand(1), RHSKnown);
+ if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(I))) {
+ cast<PossiblyDisjointInst>(I)->setIsDisjoint(true);
+ return I;
+ }
+ }
+
break;
}
case Instruction::Xor: {
@@ -265,7 +295,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
- Depth, DL, &AC, CxtI, &DT);
+ Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -284,9 +314,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) {
Instruction *Or =
- BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
- I->getName());
- return InsertNewInstWith(Or, *I);
+ BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1));
+ if (DemandedMask.isAllOnes())
+ cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true);
+ Or->takeName(I);
+ return InsertNewInstWith(Or, I->getIterator());
}
// If all of the demanded bits on one side are known, and all of the set
@@ -298,7 +330,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Constant *AndC = Constant::getIntegerValue(VTy,
~RHSKnown.One & DemandedMask);
Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
- return InsertNewInstWith(And, *I);
+ return InsertNewInstWith(And, I->getIterator());
}
// If the RHS is a constant, see if we can change it. Don't alter a -1
@@ -330,11 +362,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue());
Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
- InsertNewInstWith(NewAnd, *I);
+ InsertNewInstWith(NewAnd, I->getIterator());
Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue());
Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
- return InsertNewInstWith(NewXor, *I);
+ return InsertNewInstWith(NewXor, I->getIterator());
}
}
break;
@@ -411,36 +443,21 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
KnownBits InputKnown(SrcBitWidth);
- if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) {
+ // For zext nneg, we may have dropped the instruction which made the
+ // input non-negative.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?");
+ if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() &&
+ !InputKnown.isNegative())
+ InputKnown.makeNonNegative();
Known = InputKnown.zextOrTrunc(BitWidth);
- assert(!Known.hasConflict() && "Bits known to be one AND zero?");
- break;
- }
- case Instruction::BitCast:
- if (!I->getOperand(0)->getType()->isIntOrIntVectorTy())
- return nullptr; // vector->int or fp->int?
-
- if (auto *DstVTy = dyn_cast<VectorType>(VTy)) {
- if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) {
- if (isa<ScalableVectorType>(DstVTy) ||
- isa<ScalableVectorType>(SrcVTy) ||
- cast<FixedVectorType>(DstVTy)->getNumElements() !=
- cast<FixedVectorType>(SrcVTy)->getNumElements())
- // Don't touch a bitcast between vectors of different element counts.
- return nullptr;
- } else
- // Don't touch a scalar-to-vector bitcast.
- return nullptr;
- } else if (I->getOperand(0)->getType()->isVectorTy())
- // Don't touch a vector-to-scalar bitcast.
- return nullptr;
- if (SimplifyDemandedBits(I, 0, DemandedMask, Known, Depth + 1))
- return I;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
break;
+ }
case Instruction::SExt: {
// Compute the bits in the result that are not present in the input.
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
@@ -461,8 +478,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (InputKnown.isNonNegative() ||
DemandedMask.getActiveBits() <= SrcBitWidth) {
// Convert to ZExt cast.
- CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName());
- return InsertNewInstWith(NewCast, *I);
+ CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy);
+ NewCast->takeName(I);
+ return InsertNewInstWith(NewCast, I->getIterator());
}
// If the sign bit of the input is known set or clear, then we know the
@@ -586,7 +604,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) {
Constant *ShiftC = ConstantInt::get(VTy, CTZ);
Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
- return InsertNewInstWith(Shl, *I);
+ return InsertNewInstWith(Shl, I->getIterator());
}
}
// For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
@@ -595,7 +613,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
Constant *One = ConstantInt::get(VTy, 1);
Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
- return InsertNewInstWith(And1, *I);
+ return InsertNewInstWith(And1, I->getIterator());
}
computeKnownBits(I, Known, Depth, CxtI);
@@ -624,10 +642,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (DemandedMask.countr_zero() >= ShiftAmt &&
match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
- Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC);
- if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) {
+ Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
+ LeftShiftAmtC, DL);
+ if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC,
+ DL) == C) {
Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
- return InsertNewInstWith(Lshr, *I);
+ return InsertNewInstWith(Lshr, I->getIterator());
}
}
@@ -688,24 +708,23 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Constant *C;
if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) {
Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
- Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC);
- if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) {
+ Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C,
+ RightShiftAmtC, DL);
+ if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC,
+ RightShiftAmtC, DL) == C) {
Instruction *Shl = BinaryOperator::CreateShl(NewC, X);
- return InsertNewInstWith(Shl, *I);
+ return InsertNewInstWith(Shl, I->getIterator());
}
}
}
// Unsigned shift right.
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
-
- // If the shift is exact, then it does demand the low bits (and knows that
- // they are zero).
- if (cast<LShrOperator>(I)->isExact())
- DemandedMaskIn.setLowBits(ShiftAmt);
-
- if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) {
+ // exact flag may no longer hold.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
Known.Zero.lshrInPlace(ShiftAmt);
Known.One.lshrInPlace(ShiftAmt);
@@ -733,7 +752,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Perform the logical shift right.
Instruction *NewVal = BinaryOperator::CreateLShr(
I->getOperand(0), I->getOperand(1), I->getName());
- return InsertNewInstWith(NewVal, *I);
+ return InsertNewInstWith(NewVal, I->getIterator());
}
const APInt *SA;
@@ -747,13 +766,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (DemandedMask.countl_zero() <= ShiftAmt)
DemandedMaskIn.setSignBit();
- // If the shift is exact, then it does demand the low bits (and knows that
- // they are zero).
- if (cast<AShrOperator>(I)->isExact())
- DemandedMaskIn.setLowBits(ShiftAmt);
-
- if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) {
+ // exact flag may no longer hold.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
// Compute the new bits that are at the top now plus sign bits.
@@ -770,7 +787,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0),
I->getOperand(1));
LShr->setIsExact(cast<BinaryOperator>(I)->isExact());
- return InsertNewInstWith(LShr, *I);
+ LShr->takeName(I);
+ return InsertNewInstWith(LShr, I->getIterator());
} else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one.
Known.One |= HighBits;
}
@@ -867,7 +885,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
match(II->getArgOperand(0), m_Not(m_Value(X)))) {
Function *Ctpop = Intrinsic::getDeclaration(
II->getModule(), Intrinsic::ctpop, VTy);
- return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I);
+ return InsertNewInstWith(CallInst::Create(Ctpop, {X}), I->getIterator());
}
break;
}
@@ -894,10 +912,52 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
NewVal = BinaryOperator::CreateShl(
II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ));
NewVal->takeName(I);
- return InsertNewInstWith(NewVal, *I);
+ return InsertNewInstWith(NewVal, I->getIterator());
}
break;
}
+ case Intrinsic::ptrmask: {
+ unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
+ RHSKnown = KnownBits(MaskWidth);
+ // If either the LHS or the RHS are Zero, the result is zero.
+ if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) ||
+ SimplifyDemandedBits(
+ I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
+ RHSKnown, Depth + 1))
+ return I;
+
+ // TODO: Should be 1-extend
+ RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);
+ assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
+ assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
+
+ Known = LHSKnown & RHSKnown;
+ KnownBitsComputed = true;
+
+ // If the client is only demanding bits we know to be zero, return
+ // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
+ // provenance, but making the mask zero will be easily optimizable in
+ // the backend.
+ if (DemandedMask.isSubsetOf(Known.Zero) &&
+ !match(I->getOperand(1), m_Zero()))
+ return replaceOperand(
+ *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));
+
+ // Mask in demanded space does nothing.
+ // NOTE: We may have attributes associated with the return value of the
+ // llvm.ptrmask intrinsic that will be lost when we just return the
+ // operand. We should try to preserve them.
+ if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
+ return I->getOperand(0);
+
+ // If the RHS is a constant, see if we can simplify it.
+ if (ShrinkDemandedConstant(
+ I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
+ return I;
+
+ break;
+ }
+
case Intrinsic::fshr:
case Intrinsic::fshl: {
const APInt *SA;
@@ -918,7 +978,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
return I;
} else { // fshl is a rotate
- // Avoid converting rotate into funnel shift.
+ // Avoid converting rotate into funnel shift.
// Only simplify if one operand is constant.
LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I);
if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) &&
@@ -982,10 +1042,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
}
+ if (V->getType()->isPointerTy()) {
+ Align Alignment = V->getPointerAlignment(DL);
+ Known.Zero.setLowBits(Log2(Alignment));
+ }
+
// If the client is only demanding bits that we know, return the known
- // constant.
- if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
+ // constant. We can't directly simplify pointers as a constant because of
+ // pointer provenance.
+ // TODO: We could return `(inttoptr const)` for pointers.
+ if (!V->getType()->isPointerTy() && DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(VTy, Known.One);
+
+ if (VerifyKnownBits) {
+ KnownBits ReferenceKnown = computeKnownBits(V, Depth, CxtI);
+ if (Known != ReferenceKnown) {
+ errs() << "Mismatched known bits for " << *V << " in "
+ << I->getFunction()->getName() << "\n";
+ errs() << "computeKnownBits(): " << ReferenceKnown << "\n";
+ errs() << "SimplifyDemandedBits(): " << Known << "\n";
+ std::abort();
+ }
+ }
+
return nullptr;
}
@@ -1009,8 +1088,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
case Instruction::And: {
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
- Known = LHSKnown & RHSKnown;
- computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, SQ.getWithInstruction(CxtI));
+ computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1029,8 +1109,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
case Instruction::Or: {
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
- Known = LHSKnown | RHSKnown;
- computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, SQ.getWithInstruction(CxtI));
+ computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1051,8 +1132,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
case Instruction::Xor: {
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
- Known = LHSKnown ^ RHSKnown;
- computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
+ Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
+ Depth, SQ.getWithInstruction(CxtI));
+ computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -1085,7 +1167,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown);
- computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
+ computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
break;
}
case Instruction::Sub: {
@@ -1101,7 +1183,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown);
- computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI));
+ computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
break;
}
case Instruction::AShr: {
@@ -1219,7 +1301,7 @@ Value *InstCombinerImpl::simplifyShrShlDemandedBits(
New->setIsExact(true);
}
- return InsertNewInstWith(New, *Shl);
+ return InsertNewInstWith(New, Shl->getIterator());
}
return nullptr;
@@ -1549,7 +1631,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
Instruction *New = InsertElementInst::Create(
Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx),
Shuffle->getName());
- InsertNewInstWith(New, *Shuffle);
+ InsertNewInstWith(New, Shuffle->getIterator());
return New;
}
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 4a5ffef2b08e..c8b58c51d4e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -132,7 +132,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
// Create a scalar PHI node that will replace the vector PHI node
// just before the current PHI node.
PHINode *scalarPHI = cast<PHINode>(InsertNewInstWith(
- PHINode::Create(EI.getType(), PN->getNumIncomingValues(), ""), *PN));
+ PHINode::Create(EI.getType(), PN->getNumIncomingValues(), ""), PN->getIterator()));
// Scalarize each PHI operand.
for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
Value *PHIInVal = PN->getIncomingValue(i);
@@ -148,10 +148,10 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
Value *Op = InsertNewInstWith(
ExtractElementInst::Create(B0->getOperand(opId), Elt,
B0->getOperand(opId)->getName() + ".Elt"),
- *B0);
+ B0->getIterator());
Value *newPHIUser = InsertNewInstWith(
BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(),
- scalarPHI, Op, B0), *B0);
+ scalarPHI, Op, B0), B0->getIterator());
scalarPHI->addIncoming(newPHIUser, inBB);
} else {
// Scalarize PHI input:
@@ -165,7 +165,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
InsertPos = inBB->getFirstInsertionPt();
}
- InsertNewInstWith(newEI, *InsertPos);
+ InsertNewInstWith(newEI, InsertPos);
scalarPHI->addIncoming(newEI, inBB);
}
@@ -441,7 +441,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
if (IndexC->getValue().getActiveBits() <= BitWidth)
Idx = ConstantInt::get(Ty, IndexC->getValue().zextOrTrunc(BitWidth));
else
- Idx = UndefValue::get(Ty);
+ Idx = PoisonValue::get(Ty);
return replaceInstUsesWith(EI, Idx);
}
}
@@ -742,7 +742,7 @@ static bool replaceExtractElements(InsertElementInst *InsElt,
if (ExtVecOpInst && !isa<PHINode>(ExtVecOpInst))
WideVec->insertAfter(ExtVecOpInst);
else
- IC.InsertNewInstWith(WideVec, *ExtElt->getParent()->getFirstInsertionPt());
+ IC.InsertNewInstWith(WideVec, ExtElt->getParent()->getFirstInsertionPt());
// Replace extracts from the original narrow vector with extracts from the new
// wide vector.
@@ -751,7 +751,7 @@ static bool replaceExtractElements(InsertElementInst *InsElt,
if (!OldExt || OldExt->getParent() != WideVec->getParent())
continue;
auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1));
- IC.InsertNewInstWith(NewExt, *OldExt);
+ IC.InsertNewInstWith(NewExt, OldExt->getIterator());
IC.replaceInstUsesWith(*OldExt, NewExt);
// Add the old extracts to the worklist for DCE. We can't remove the
// extracts directly, because they may still be used by the calling code.
@@ -1121,7 +1121,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// Note that the same block can be a predecessor more than once,
// and we need to preserve that invariant for the PHI node.
BuilderTy::InsertPointGuard Guard(Builder);
- Builder.SetInsertPoint(UseBB->getFirstNonPHI());
+ Builder.SetInsertPoint(UseBB, UseBB->getFirstNonPHIIt());
auto *PHI =
Builder.CreatePHI(AggTy, Preds.size(), OrigIVI.getName() + ".merged");
for (BasicBlock *Pred : Preds)
@@ -2122,8 +2122,8 @@ static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) {
NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i];
// A select mask with undef elements might look like an identity mask.
- assert((ShuffleVectorInst::isSelectMask(NewMask) ||
- ShuffleVectorInst::isIdentityMask(NewMask)) &&
+ assert((ShuffleVectorInst::isSelectMask(NewMask, NumElts) ||
+ ShuffleVectorInst::isIdentityMask(NewMask, NumElts)) &&
"Unexpected shuffle mask");
return new ShuffleVectorInst(X, Y, NewMask);
}
@@ -2197,9 +2197,9 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf,
!match(Op1, m_Undef()) || match(Mask, m_ZeroMask()) || IndexC == 0)
return nullptr;
- // Insert into element 0 of an undef vector.
- UndefValue *UndefVec = UndefValue::get(Shuf.getType());
- Value *NewIns = Builder.CreateInsertElement(UndefVec, X, (uint64_t)0);
+ // Insert into element 0 of a poison vector.
+ PoisonValue *PoisonVec = PoisonValue::get(Shuf.getType());
+ Value *NewIns = Builder.CreateInsertElement(PoisonVec, X, (uint64_t)0);
// Splat from element 0. Any mask element that is undefined remains undefined.
// For example:
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index afd6e034f46d..f072f5cec309 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -130,13 +130,6 @@ STATISTIC(NumReassoc , "Number of reassociations");
DEBUG_COUNTER(VisitCounter, "instcombine-visit",
"Controls which instructions are visited");
-// FIXME: these limits eventually should be as low as 2.
-#ifndef NDEBUG
-static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100;
-#else
-static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000;
-#endif
-
static cl::opt<bool>
EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"),
cl::init(true));
@@ -145,12 +138,6 @@ static cl::opt<unsigned> MaxSinkNumUsers(
"instcombine-max-sink-users", cl::init(32),
cl::desc("Maximum number of undroppable users for instruction sinking"));
-static cl::opt<unsigned> InfiniteLoopDetectionThreshold(
- "instcombine-infinite-loop-threshold",
- cl::desc("Number of instruction combining iterations considered an "
- "infinite loop"),
- cl::init(InstCombineDefaultInfiniteLoopThreshold), cl::Hidden);
-
static cl::opt<unsigned>
MaxArraySize("instcombine-maxarray-size", cl::init(1024),
cl::desc("Maximum array size considered when doing a combine"));
@@ -358,15 +345,19 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
// Fold the constants together in the destination type:
// (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC)
+ const DataLayout &DL = IC.getDataLayout();
Type *DestTy = C1->getType();
- Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy);
- Constant *FoldedC =
- ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, IC.getDataLayout());
+ Constant *CastC2 = ConstantFoldCastOperand(CastOpcode, C2, DestTy, DL);
+ if (!CastC2)
+ return false;
+ Constant *FoldedC = ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, DL);
if (!FoldedC)
return false;
IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0));
IC.replaceOperand(*BinOp1, 1, FoldedC);
+ BinOp1->dropPoisonGeneratingFlags();
+ Cast->dropPoisonGeneratingFlags();
return true;
}
@@ -542,12 +533,12 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
BinaryOperator::Create(Opcode, A, B);
if (isa<FPMathOperator>(NewBO)) {
- FastMathFlags Flags = I.getFastMathFlags();
- Flags &= Op0->getFastMathFlags();
- Flags &= Op1->getFastMathFlags();
- NewBO->setFastMathFlags(Flags);
+ FastMathFlags Flags = I.getFastMathFlags() &
+ Op0->getFastMathFlags() &
+ Op1->getFastMathFlags();
+ NewBO->setFastMathFlags(Flags);
}
- InsertNewInstWith(NewBO, I);
+ InsertNewInstWith(NewBO, I.getIterator());
NewBO->takeName(Op1);
replaceOperand(I, 0, NewBO);
replaceOperand(I, 1, CRes);
@@ -749,7 +740,16 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
// 2) BinOp1 == BinOp2 (if BinOp == `add`, then also requires `shl`).
//
// -> (BinOp (logic_shift (BinOp X, Y)), Mask)
+//
+// (Binop1 (Binop2 (arithmetic_shift X, Amt), Mask), (arithmetic_shift Y, Amt))
+// IFF
+// 1) Binop1 is bitwise logical operator `and`, `or` or `xor`
+// 2) Binop2 is `not`
+//
+// -> (arithmetic_shift Binop1((not X), Y), Amt)
+
Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
+ const DataLayout &DL = I.getModule()->getDataLayout();
auto IsValidBinOpc = [](unsigned Opc) {
switch (Opc) {
default:
@@ -768,11 +768,13 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
// constraints.
auto IsCompletelyDistributable = [](unsigned BinOpc1, unsigned BinOpc2,
unsigned ShOpc) {
+ assert(ShOpc != Instruction::AShr);
return (BinOpc1 != Instruction::Add && BinOpc2 != Instruction::Add) ||
ShOpc == Instruction::Shl;
};
auto GetInvShift = [](unsigned ShOpc) {
+ assert(ShOpc != Instruction::AShr);
return ShOpc == Instruction::LShr ? Instruction::Shl : Instruction::LShr;
};
@@ -796,23 +798,23 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
// Otherwise, need mask that meets the below requirement.
// (logic_shift (inv_logic_shift Mask, ShAmt), ShAmt) == Mask
- return ConstantExpr::get(
- ShOpc, ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift),
- CShift) == CMask;
+ Constant *MaskInvShift =
+ ConstantFoldBinaryOpOperands(GetInvShift(ShOpc), CMask, CShift, DL);
+ return ConstantFoldBinaryOpOperands(ShOpc, MaskInvShift, CShift, DL) ==
+ CMask;
};
auto MatchBinOp = [&](unsigned ShOpnum) -> Instruction * {
Constant *CMask, *CShift;
Value *X, *Y, *ShiftedX, *Mask, *Shift;
if (!match(I.getOperand(ShOpnum),
- m_OneUse(m_LogicalShift(m_Value(Y), m_Value(Shift)))))
+ m_OneUse(m_Shift(m_Value(Y), m_Value(Shift)))))
return nullptr;
if (!match(I.getOperand(1 - ShOpnum),
m_BinOp(m_Value(ShiftedX), m_Value(Mask))))
return nullptr;
- if (!match(ShiftedX,
- m_OneUse(m_LogicalShift(m_Value(X), m_Specific(Shift)))))
+ if (!match(ShiftedX, m_OneUse(m_Shift(m_Value(X), m_Specific(Shift)))))
return nullptr;
// Make sure we are matching instruction shifts and not ConstantExpr
@@ -836,6 +838,18 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
if (!IsValidBinOpc(I.getOpcode()) || !IsValidBinOpc(BinOpc))
return nullptr;
+ if (ShOpc == Instruction::AShr) {
+ if (Instruction::isBitwiseLogicOp(I.getOpcode()) &&
+ BinOpc == Instruction::Xor && match(Mask, m_AllOnes())) {
+ Value *NotX = Builder.CreateNot(X);
+ Value *NewBinOp = Builder.CreateBinOp(I.getOpcode(), Y, NotX);
+ return BinaryOperator::Create(
+ static_cast<Instruction::BinaryOps>(ShOpc), NewBinOp, Shift);
+ }
+
+ return nullptr;
+ }
+
// If BinOp1 == BinOp2 and it's bitwise or shl with add, then just
// distribute to drop the shift irrelevant of constants.
if (BinOpc == I.getOpcode() &&
@@ -857,7 +871,8 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
if (!CanDistributeBinops(I.getOpcode(), BinOpc, ShOpc, CMask, CShift))
return nullptr;
- Constant *NewCMask = ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift);
+ Constant *NewCMask =
+ ConstantFoldBinaryOpOperands(GetInvShift(ShOpc), CMask, CShift, DL);
Value *NewBinOp2 = Builder.CreateBinOp(
static_cast<Instruction::BinaryOps>(BinOpc), X, NewCMask);
Value *NewBinOp1 = Builder.CreateBinOp(I.getOpcode(), Y, NewBinOp2);
@@ -924,13 +939,17 @@ InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) {
// If the value used in the zext/sext is the select condition, or the negated
// of the select condition, the binop can be simplified.
- if (CondVal == A)
- return SelectInst::Create(CondVal, NewFoldedConst(false, TrueVal),
+ if (CondVal == A) {
+ Value *NewTrueVal = NewFoldedConst(false, TrueVal);
+ return SelectInst::Create(CondVal, NewTrueVal,
NewFoldedConst(true, FalseVal));
+ }
- if (match(A, m_Not(m_Specific(CondVal))))
- return SelectInst::Create(CondVal, NewFoldedConst(true, TrueVal),
+ if (match(A, m_Not(m_Specific(CondVal)))) {
+ Value *NewTrueVal = NewFoldedConst(true, TrueVal);
+ return SelectInst::Create(CondVal, NewTrueVal,
NewFoldedConst(false, FalseVal));
+ }
return nullptr;
}
@@ -1167,6 +1186,8 @@ void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) {
break;
case Instruction::Xor:
replaceInstUsesWith(cast<Instruction>(*U), I);
+ // Add to worklist for DCE.
+ addToWorklist(cast<Instruction>(U));
break;
default:
llvm_unreachable("Got unexpected user - out of sync with "
@@ -1268,7 +1289,7 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
Value *NewOp, InstCombiner &IC) {
Instruction *Clone = I.clone();
Clone->replaceUsesOfWith(SI, NewOp);
- IC.InsertNewInstBefore(Clone, *SI);
+ IC.InsertNewInstBefore(Clone, SI->getIterator());
return Clone;
}
@@ -1302,6 +1323,21 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
return nullptr;
}
+ // Test if a FCmpInst instruction is used exclusively by a select as
+ // part of a minimum or maximum operation. If so, refrain from doing
+ // any other folding. This helps out other analyses which understand
+ // non-obfuscated minimum and maximum idioms. And in this case, at
+ // least one of the comparison operands has at least one user besides
+ // the compare (the select), which would often largely negate the
+ // benefit of folding anyway.
+ if (auto *CI = dyn_cast<FCmpInst>(SI->getCondition())) {
+ if (CI->hasOneUse()) {
+ Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1);
+ if ((TV == Op0 && FV == Op1) || (FV == Op0 && TV == Op1))
+ return nullptr;
+ }
+ }
+
// Make sure that one of the select arms constant folds successfully.
Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ true);
Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ false);
@@ -1316,6 +1352,47 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
}
+static Value *simplifyInstructionWithPHI(Instruction &I, PHINode *PN,
+ Value *InValue, BasicBlock *InBB,
+ const DataLayout &DL,
+ const SimplifyQuery SQ) {
+ // NB: It is a precondition of this transform that the operands be
+ // phi translatable! This is usually trivially satisfied by limiting it
+ // to constant ops, and for selects we do a more sophisticated check.
+ SmallVector<Value *> Ops;
+ for (Value *Op : I.operands()) {
+ if (Op == PN)
+ Ops.push_back(InValue);
+ else
+ Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB));
+ }
+
+ // Don't consider the simplification successful if we get back a constant
+ // expression. That's just an instruction in hiding.
+ // Also reject the case where we simplify back to the phi node. We wouldn't
+ // be able to remove it in that case.
+ Value *NewVal = simplifyInstructionWithOperands(
+ &I, Ops, SQ.getWithInstruction(InBB->getTerminator()));
+ if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr()))
+ return NewVal;
+
+ // Check if incoming PHI value can be replaced with constant
+ // based on implied condition.
+ BranchInst *TerminatorBI = dyn_cast<BranchInst>(InBB->getTerminator());
+ const ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
+ if (TerminatorBI && TerminatorBI->isConditional() &&
+ TerminatorBI->getSuccessor(0) != TerminatorBI->getSuccessor(1) && ICmp) {
+ bool LHSIsTrue = TerminatorBI->getSuccessor(0) == PN->getParent();
+ std::optional<bool> ImpliedCond =
+ isImpliedCondition(TerminatorBI->getCondition(), ICmp->getPredicate(),
+ Ops[0], Ops[1], DL, LHSIsTrue);
+ if (ImpliedCond)
+ return ConstantInt::getBool(I.getType(), ImpliedCond.value());
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
unsigned NumPHIValues = PN->getNumIncomingValues();
if (NumPHIValues == 0)
@@ -1344,29 +1421,11 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
Value *InVal = PN->getIncomingValue(i);
BasicBlock *InBB = PN->getIncomingBlock(i);
- // NB: It is a precondition of this transform that the operands be
- // phi translatable! This is usually trivially satisfied by limiting it
- // to constant ops, and for selects we do a more sophisticated check.
- SmallVector<Value *> Ops;
- for (Value *Op : I.operands()) {
- if (Op == PN)
- Ops.push_back(InVal);
- else
- Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB));
- }
-
- // Don't consider the simplification successful if we get back a constant
- // expression. That's just an instruction in hiding.
- // Also reject the case where we simplify back to the phi node. We wouldn't
- // be able to remove it in that case.
- Value *NewVal = simplifyInstructionWithOperands(
- &I, Ops, SQ.getWithInstruction(InBB->getTerminator()));
- if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) {
+ if (auto *NewVal = simplifyInstructionWithPHI(I, PN, InVal, InBB, DL, SQ)) {
NewPhiValues.push_back(NewVal);
continue;
}
- if (isa<PHINode>(InVal)) return nullptr; // Itself a phi.
if (NonSimplifiedBB) return nullptr; // More than one non-simplified value.
NonSimplifiedBB = InBB;
@@ -1402,7 +1461,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
// Okay, we can do the transformation: create the new PHI node.
PHINode *NewPN = PHINode::Create(I.getType(), PN->getNumIncomingValues());
- InsertNewInstBefore(NewPN, *PN);
+ InsertNewInstBefore(NewPN, PN->getIterator());
NewPN->takeName(PN);
NewPN->setDebugLoc(PN->getDebugLoc());
@@ -1417,7 +1476,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
else
U = U->DoPHITranslation(PN->getParent(), NonSimplifiedBB);
}
- InsertNewInstBefore(Clone, *NonSimplifiedBB->getTerminator());
+ InsertNewInstBefore(Clone, NonSimplifiedBB->getTerminator()->getIterator());
}
for (unsigned i = 0; i != NumPHIValues; ++i) {
@@ -1848,8 +1907,8 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) {
Constant *WideC;
if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC)))
return nullptr;
- Constant *NarrowC = ConstantExpr::getTrunc(WideC, X->getType());
- if (ConstantExpr::getCast(CastOpc, NarrowC, BO.getType()) != WideC)
+ Constant *NarrowC = getLosslessTrunc(WideC, X->getType(), CastOpc);
+ if (!NarrowC)
return nullptr;
Y = NarrowC;
}
@@ -1940,7 +1999,7 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), 0);
if (NumVarIndices != Src->getNumIndices()) {
// FIXME: getIndexedOffsetInType() does not handled scalable vectors.
- if (isa<ScalableVectorType>(BaseType))
+ if (BaseType->isScalableTy())
return nullptr;
SmallVector<Value *> ConstantIndices;
@@ -2048,12 +2107,126 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
return nullptr;
}
+Value *InstCombiner::getFreelyInvertedImpl(Value *V, bool WillInvertAllUses,
+ BuilderTy *Builder,
+ bool &DoesConsume, unsigned Depth) {
+ static Value *const NonNull = reinterpret_cast<Value *>(uintptr_t(1));
+ // ~(~(X)) -> X.
+ Value *A, *B;
+ if (match(V, m_Not(m_Value(A)))) {
+ DoesConsume = true;
+ return A;
+ }
+
+ Constant *C;
+ // Constants can be considered to be not'ed values.
+ if (match(V, m_ImmConstant(C)))
+ return ConstantExpr::getNot(C);
+
+ if (Depth++ >= MaxAnalysisRecursionDepth)
+ return nullptr;
+
+ // The rest of the cases require that we invert all uses so don't bother
+ // doing the analysis if we know we can't use the result.
+ if (!WillInvertAllUses)
+ return nullptr;
+
+ // Compares can be inverted if all of their uses are being modified to use
+ // the ~V.
+ if (auto *I = dyn_cast<CmpInst>(V)) {
+ if (Builder != nullptr)
+ return Builder->CreateCmp(I->getInversePredicate(), I->getOperand(0),
+ I->getOperand(1));
+ return NonNull;
+ }
+
+ // If `V` is of the form `A + B` then `-1 - V` can be folded into
+ // `(-1 - B) - A` if we are willing to invert all of the uses.
+ if (match(V, m_Add(m_Value(A), m_Value(B)))) {
+ if (auto *BV = getFreelyInvertedImpl(B, B->hasOneUse(), Builder,
+ DoesConsume, Depth))
+ return Builder ? Builder->CreateSub(BV, A) : NonNull;
+ if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder,
+ DoesConsume, Depth))
+ return Builder ? Builder->CreateSub(AV, B) : NonNull;
+ return nullptr;
+ }
+
+ // If `V` is of the form `A ^ ~B` then `~(A ^ ~B)` can be folded
+ // into `A ^ B` if we are willing to invert all of the uses.
+ if (match(V, m_Xor(m_Value(A), m_Value(B)))) {
+ if (auto *BV = getFreelyInvertedImpl(B, B->hasOneUse(), Builder,
+ DoesConsume, Depth))
+ return Builder ? Builder->CreateXor(A, BV) : NonNull;
+ if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder,
+ DoesConsume, Depth))
+ return Builder ? Builder->CreateXor(AV, B) : NonNull;
+ return nullptr;
+ }
+
+ // If `V` is of the form `B - A` then `-1 - V` can be folded into
+ // `A + (-1 - B)` if we are willing to invert all of the uses.
+ if (match(V, m_Sub(m_Value(A), m_Value(B)))) {
+ if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder,
+ DoesConsume, Depth))
+ return Builder ? Builder->CreateAdd(AV, B) : NonNull;
+ return nullptr;
+ }
+
+ // If `V` is of the form `(~A) s>> B` then `~((~A) s>> B)` can be folded
+ // into `A s>> B` if we are willing to invert all of the uses.
+ if (match(V, m_AShr(m_Value(A), m_Value(B)))) {
+ if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder,
+ DoesConsume, Depth))
+ return Builder ? Builder->CreateAShr(AV, B) : NonNull;
+ return nullptr;
+ }
+
+ // Treat lshr with non-negative operand as ashr.
+ if (match(V, m_LShr(m_Value(A), m_Value(B))) &&
+ isKnownNonNegative(A, SQ.getWithInstruction(cast<Instruction>(V)),
+ Depth)) {
+ if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder,
+ DoesConsume, Depth))
+ return Builder ? Builder->CreateAShr(AV, B) : NonNull;
+ return nullptr;
+ }
+
+ Value *Cond;
+ // LogicOps are special in that we canonicalize them at the cost of an
+ // instruction.
+ bool IsSelect = match(V, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) &&
+ !shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(V));
+ // Selects/min/max with invertible operands are freely invertible
+ if (IsSelect || match(V, m_MaxOrMin(m_Value(A), m_Value(B)))) {
+ if (!getFreelyInvertedImpl(B, B->hasOneUse(), /*Builder*/ nullptr,
+ DoesConsume, Depth))
+ return nullptr;
+ if (Value *NotA = getFreelyInvertedImpl(A, A->hasOneUse(), Builder,
+ DoesConsume, Depth)) {
+ if (Builder != nullptr) {
+ Value *NotB = getFreelyInvertedImpl(B, B->hasOneUse(), Builder,
+ DoesConsume, Depth);
+ assert(NotB != nullptr &&
+ "Unable to build inverted value for known freely invertable op");
+ if (auto *II = dyn_cast<IntrinsicInst>(V))
+ return Builder->CreateBinaryIntrinsic(
+ getInverseMinMaxIntrinsic(II->getIntrinsicID()), NotA, NotB);
+ return Builder->CreateSelect(Cond, NotA, NotB);
+ }
+ return NonNull;
+ }
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
Value *PtrOp = GEP.getOperand(0);
SmallVector<Value *, 8> Indices(GEP.indices());
Type *GEPType = GEP.getType();
Type *GEPEltType = GEP.getSourceElementType();
- bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType);
+ bool IsGEPSrcEleScalable = GEPEltType->isScalableTy();
if (Value *V = simplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.isInBounds(),
SQ.getWithInstruction(&GEP)))
return replaceInstUsesWith(GEP, V);
@@ -2221,7 +2394,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
NewGEP->setOperand(DI, NewPN);
}
- NewGEP->insertInto(GEP.getParent(), GEP.getParent()->getFirstInsertionPt());
+ NewGEP->insertBefore(*GEP.getParent(), GEP.getParent()->getFirstInsertionPt());
return replaceOperand(GEP, 0, NewGEP);
}
@@ -2264,11 +2437,43 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, GEPType);
}
}
-
// We do not handle pointer-vector geps here.
if (GEPType->isVectorTy())
return nullptr;
+ if (GEP.getNumIndices() == 1) {
+ // Try to replace ADD + GEP with GEP + GEP.
+ Value *Idx1, *Idx2;
+ if (match(GEP.getOperand(1),
+ m_OneUse(m_Add(m_Value(Idx1), m_Value(Idx2))))) {
+ // %idx = add i64 %idx1, %idx2
+ // %gep = getelementptr i32, ptr %ptr, i64 %idx
+ // as:
+ // %newptr = getelementptr i32, ptr %ptr, i64 %idx1
+ // %newgep = getelementptr i32, ptr %newptr, i64 %idx2
+ auto *NewPtr = Builder.CreateGEP(GEP.getResultElementType(),
+ GEP.getPointerOperand(), Idx1);
+ return GetElementPtrInst::Create(GEP.getResultElementType(), NewPtr,
+ Idx2);
+ }
+ ConstantInt *C;
+ if (match(GEP.getOperand(1), m_OneUse(m_SExt(m_OneUse(m_NSWAdd(
+ m_Value(Idx1), m_ConstantInt(C))))))) {
+ // %add = add nsw i32 %idx1, idx2
+ // %sidx = sext i32 %add to i64
+ // %gep = getelementptr i32, ptr %ptr, i64 %sidx
+ // as:
+ // %newptr = getelementptr i32, ptr %ptr, i32 %idx1
+ // %newgep = getelementptr i32, ptr %newptr, i32 idx2
+ auto *NewPtr = Builder.CreateGEP(
+ GEP.getResultElementType(), GEP.getPointerOperand(),
+ Builder.CreateSExt(Idx1, GEP.getOperand(1)->getType()));
+ return GetElementPtrInst::Create(
+ GEP.getResultElementType(), NewPtr,
+ Builder.CreateSExt(C, GEP.getOperand(1)->getType()));
+ }
+ }
+
if (!GEP.isInBounds()) {
unsigned IdxWidth =
DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace());
@@ -2362,6 +2567,26 @@ static bool isAllocSiteRemovable(Instruction *AI,
unsigned OtherIndex = (ICI->getOperand(0) == PI) ? 1 : 0;
if (!isNeverEqualToUnescapedAlloc(ICI->getOperand(OtherIndex), TLI, AI))
return false;
+
+ // Do not fold compares to aligned_alloc calls, as they may have to
+ // return null in case the required alignment cannot be satisfied,
+ // unless we can prove that both alignment and size are valid.
+ auto AlignmentAndSizeKnownValid = [](CallBase *CB) {
+ // Check if alignment and size of a call to aligned_alloc is valid,
+ // that is alignment is a power-of-2 and the size is a multiple of the
+ // alignment.
+ const APInt *Alignment;
+ const APInt *Size;
+ return match(CB->getArgOperand(0), m_APInt(Alignment)) &&
+ match(CB->getArgOperand(1), m_APInt(Size)) &&
+ Alignment->isPowerOf2() && Size->urem(*Alignment).isZero();
+ };
+ auto *CB = dyn_cast<CallBase>(AI);
+ LibFunc TheLibFunc;
+ if (CB && TLI.getLibFunc(*CB->getCalledFunction(), TheLibFunc) &&
+ TLI.has(TheLibFunc) && TheLibFunc == LibFunc_aligned_alloc &&
+ !AlignmentAndSizeKnownValid(CB))
+ return false;
Users.emplace_back(I);
continue;
}
@@ -2451,9 +2676,10 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
// If we are removing an alloca with a dbg.declare, insert dbg.value calls
// before each store.
SmallVector<DbgVariableIntrinsic *, 8> DVIs;
+ SmallVector<DPValue *, 8> DPVs;
std::unique_ptr<DIBuilder> DIB;
if (isa<AllocaInst>(MI)) {
- findDbgUsers(DVIs, &MI);
+ findDbgUsers(DVIs, &MI, &DPVs);
DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false));
}
@@ -2493,6 +2719,9 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
for (auto *DVI : DVIs)
if (DVI->isAddressOfVariable())
ConvertDebugDeclareToDebugValue(DVI, SI, *DIB);
+ for (auto *DPV : DPVs)
+ if (DPV->isAddressOfVariable())
+ ConvertDebugDeclareToDebugValue(DPV, SI, *DIB);
} else {
// Casts, GEP, or anything else: we're about to delete this instruction,
// so it can not have any valid uses.
@@ -2531,9 +2760,15 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
// If there is a dead store to `%a` in @trivially_inlinable_no_op, the
// "arg0" dbg.value may be stale after the call. However, failing to remove
// the DW_OP_deref dbg.value causes large gaps in location coverage.
+ //
+ // FIXME: the Assignment Tracking project has now likely made this
+ // redundant (and it's sometimes harmful).
for (auto *DVI : DVIs)
if (DVI->isAddressOfVariable() || DVI->getExpression()->startsWithDeref())
DVI->eraseFromParent();
+ for (auto *DPV : DPVs)
+ if (DPV->isAddressOfVariable() || DPV->getExpression()->startsWithDeref())
+ DPV->eraseFromParent();
return eraseInstFromFunction(MI);
}
@@ -2612,7 +2847,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI,
for (Instruction &Instr : llvm::make_early_inc_range(*FreeInstrBB)) {
if (&Instr == FreeInstrBBTerminator)
break;
- Instr.moveBefore(TI);
+ Instr.moveBeforePreserving(TI);
}
assert(FreeInstrBB->size() == 1 &&
"Only the branch instruction should remain");
@@ -2746,55 +2981,77 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) {
return nullptr;
}
+void InstCombinerImpl::addDeadEdge(BasicBlock *From, BasicBlock *To,
+ SmallVectorImpl<BasicBlock *> &Worklist) {
+ if (!DeadEdges.insert({From, To}).second)
+ return;
+
+ // Replace phi node operands in successor with poison.
+ for (PHINode &PN : To->phis())
+ for (Use &U : PN.incoming_values())
+ if (PN.getIncomingBlock(U) == From && !isa<PoisonValue>(U)) {
+ replaceUse(U, PoisonValue::get(PN.getType()));
+ addToWorklist(&PN);
+ MadeIRChange = true;
+ }
+
+ Worklist.push_back(To);
+}
+
// Under the assumption that I is unreachable, remove it and following
-// instructions.
-bool InstCombinerImpl::handleUnreachableFrom(Instruction *I) {
- bool Changed = false;
+// instructions. Changes are reported directly to MadeIRChange.
+void InstCombinerImpl::handleUnreachableFrom(
+ Instruction *I, SmallVectorImpl<BasicBlock *> &Worklist) {
BasicBlock *BB = I->getParent();
for (Instruction &Inst : make_early_inc_range(
make_range(std::next(BB->getTerminator()->getReverseIterator()),
std::next(I->getReverseIterator())))) {
if (!Inst.use_empty() && !Inst.getType()->isTokenTy()) {
replaceInstUsesWith(Inst, PoisonValue::get(Inst.getType()));
- Changed = true;
+ MadeIRChange = true;
}
if (Inst.isEHPad() || Inst.getType()->isTokenTy())
continue;
+ // RemoveDIs: erase debug-info on this instruction manually.
+ Inst.dropDbgValues();
eraseInstFromFunction(Inst);
- Changed = true;
+ MadeIRChange = true;
}
- // Replace phi node operands in successor blocks with poison.
+ // RemoveDIs: to match behaviour in dbg.value mode, drop debug-info on
+ // terminator too.
+ BB->getTerminator()->dropDbgValues();
+
+ // Handle potentially dead successors.
for (BasicBlock *Succ : successors(BB))
- for (PHINode &PN : Succ->phis())
- for (Use &U : PN.incoming_values())
- if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) {
- replaceUse(U, PoisonValue::get(PN.getType()));
- addToWorklist(&PN);
- Changed = true;
- }
+ addDeadEdge(BB, Succ, Worklist);
+}
- // TODO: Successor blocks may also be dead.
- return Changed;
+void InstCombinerImpl::handlePotentiallyDeadBlocks(
+ SmallVectorImpl<BasicBlock *> &Worklist) {
+ while (!Worklist.empty()) {
+ BasicBlock *BB = Worklist.pop_back_val();
+ if (!all_of(predecessors(BB), [&](BasicBlock *Pred) {
+ return DeadEdges.contains({Pred, BB}) || DT.dominates(BB, Pred);
+ }))
+ continue;
+
+ handleUnreachableFrom(&BB->front(), Worklist);
+ }
}
-bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
+void InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
BasicBlock *LiveSucc) {
- bool Changed = false;
+ SmallVector<BasicBlock *> Worklist;
for (BasicBlock *Succ : successors(BB)) {
// The live successor isn't dead.
if (Succ == LiveSucc)
continue;
- if (!all_of(predecessors(Succ), [&](BasicBlock *Pred) {
- return DT.dominates(BasicBlockEdge(BB, Succ),
- BasicBlockEdge(Pred, Succ));
- }))
- continue;
-
- Changed |= handleUnreachableFrom(&Succ->front());
+ addDeadEdge(BB, Succ, Worklist);
}
- return Changed;
+
+ handlePotentiallyDeadBlocks(Worklist);
}
Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
@@ -2840,14 +3097,17 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return &BI;
}
- if (isa<UndefValue>(Cond) &&
- handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr))
- return &BI;
- if (auto *CI = dyn_cast<ConstantInt>(Cond))
- if (handlePotentiallyDeadSuccessors(BI.getParent(),
- BI.getSuccessor(!CI->getZExtValue())))
- return &BI;
+ if (isa<UndefValue>(Cond)) {
+ handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr);
+ return nullptr;
+ }
+ if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
+ handlePotentiallyDeadSuccessors(BI.getParent(),
+ BI.getSuccessor(!CI->getZExtValue()));
+ return nullptr;
+ }
+ DC.registerBranch(&BI);
return nullptr;
}
@@ -2866,14 +3126,6 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
return replaceOperand(SI, 0, Op0);
}
- if (isa<UndefValue>(Cond) &&
- handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr))
- return &SI;
- if (auto *CI = dyn_cast<ConstantInt>(Cond))
- if (handlePotentiallyDeadSuccessors(
- SI.getParent(), SI.findCaseValue(CI)->getCaseSuccessor()))
- return &SI;
-
KnownBits Known = computeKnownBits(Cond, 0, &SI);
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
@@ -2906,6 +3158,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
return replaceOperand(SI, 0, NewCond);
}
+ if (isa<UndefValue>(Cond)) {
+ handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr);
+ return nullptr;
+ }
+ if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
+ handlePotentiallyDeadSuccessors(SI.getParent(),
+ SI.findCaseValue(CI)->getCaseSuccessor());
+ return nullptr;
+ }
+
return nullptr;
}
@@ -3532,7 +3794,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI,
Value *StartV = StartU->get();
BasicBlock *StartBB = PN->getIncomingBlock(*StartU);
bool StartNeedsFreeze = !isGuaranteedNotToBeUndefOrPoison(StartV);
- // We can't insert freeze if the the start value is the result of the
+ // We can't insert freeze if the start value is the result of the
// terminator (e.g. an invoke).
if (StartNeedsFreeze && StartBB->getTerminator() == StartV)
return nullptr;
@@ -3583,19 +3845,27 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) {
// *all* uses if the operand is an invoke/callbr and the use is in a phi on
// the normal/default destination. This is why the domination check in the
// replacement below is still necessary.
- Instruction *MoveBefore;
+ BasicBlock::iterator MoveBefore;
if (isa<Argument>(Op)) {
MoveBefore =
- &*FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca();
+ FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca();
} else {
- MoveBefore = cast<Instruction>(Op)->getInsertionPointAfterDef();
- if (!MoveBefore)
+ auto MoveBeforeOpt = cast<Instruction>(Op)->getInsertionPointAfterDef();
+ if (!MoveBeforeOpt)
return false;
+ MoveBefore = *MoveBeforeOpt;
}
+ // Don't move to the position of a debug intrinsic.
+ if (isa<DbgInfoIntrinsic>(MoveBefore))
+ MoveBefore = MoveBefore->getNextNonDebugInstruction()->getIterator();
+ // Re-point iterator to come after any debug-info records, if we're
+ // running in "RemoveDIs" mode
+ MoveBefore.setHeadBit(false);
+
bool Changed = false;
- if (&FI != MoveBefore) {
- FI.moveBefore(MoveBefore);
+ if (&FI != &*MoveBefore) {
+ FI.moveBefore(*MoveBefore->getParent(), MoveBefore);
Changed = true;
}
@@ -3798,7 +4068,7 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
/// the new position.
BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt();
- I->moveBefore(&*InsertPos);
+ I->moveBefore(*DestBlock, InsertPos);
++NumSunkInst;
// Also sink all related debug uses from the source basic block. Otherwise we
@@ -3808,10 +4078,19 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
// here, but that computation has been sunk.
SmallVector<DbgVariableIntrinsic *, 2> DbgUsers;
findDbgUsers(DbgUsers, I);
- // Process the sinking DbgUsers in reverse order, as we only want to clone the
- // last appearing debug intrinsic for each given variable.
+
+ // For all debug values in the destination block, the sunk instruction
+ // will still be available, so they do not need to be dropped.
+ SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSalvage;
+ SmallVector<DPValue *, 2> DPValuesToSalvage;
+ for (auto &DbgUser : DbgUsers)
+ if (DbgUser->getParent() != DestBlock)
+ DbgUsersToSalvage.push_back(DbgUser);
+
+ // Process the sinking DbgUsersToSalvage in reverse order, as we only want
+ // to clone the last appearing debug intrinsic for each given variable.
SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSink;
- for (DbgVariableIntrinsic *DVI : DbgUsers)
+ for (DbgVariableIntrinsic *DVI : DbgUsersToSalvage)
if (DVI->getParent() == SrcBlock)
DbgUsersToSink.push_back(DVI);
llvm::sort(DbgUsersToSink,
@@ -3847,7 +4126,10 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
// Perform salvaging without the clones, then sink the clones.
if (!DIIClones.empty()) {
- salvageDebugInfoForDbgValues(*I, DbgUsers);
+ // RemoveDIs: pass in empty vector of DPValues until we get to instrumenting
+ // this pass.
+ SmallVector<DPValue *, 1> DummyDPValues;
+ salvageDebugInfoForDbgValues(*I, DbgUsersToSalvage, DummyDPValues);
// The clones are in reverse order of original appearance, reverse again to
// maintain the original order.
for (auto &DIIClone : llvm::reverse(DIIClones)) {
@@ -4093,43 +4375,52 @@ public:
}
};
-/// Populate the IC worklist from a function, by walking it in depth-first
-/// order and adding all reachable code to the worklist.
+/// Populate the IC worklist from a function, by walking it in reverse
+/// post-order and adding all reachable code to the worklist.
///
/// This has a couple of tricks to make the code faster and more powerful. In
/// particular, we constant fold and DCE instructions as we go, to avoid adding
/// them to the worklist (this significantly speeds up instcombine on code where
/// many instructions are dead or constant). Additionally, if we find a branch
/// whose condition is a known constant, we only visit the reachable successors.
-static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
- const TargetLibraryInfo *TLI,
- InstructionWorklist &ICWorklist) {
+bool InstCombinerImpl::prepareWorklist(
+ Function &F, ReversePostOrderTraversal<BasicBlock *> &RPOT) {
bool MadeIRChange = false;
- SmallPtrSet<BasicBlock *, 32> Visited;
- SmallVector<BasicBlock*, 256> Worklist;
- Worklist.push_back(&F.front());
-
+ SmallPtrSet<BasicBlock *, 32> LiveBlocks;
SmallVector<Instruction *, 128> InstrsForInstructionWorklist;
DenseMap<Constant *, Constant *> FoldedConstants;
AliasScopeTracker SeenAliasScopes;
- do {
- BasicBlock *BB = Worklist.pop_back_val();
+ auto HandleOnlyLiveSuccessor = [&](BasicBlock *BB, BasicBlock *LiveSucc) {
+ for (BasicBlock *Succ : successors(BB))
+ if (Succ != LiveSucc && DeadEdges.insert({BB, Succ}).second)
+ for (PHINode &PN : Succ->phis())
+ for (Use &U : PN.incoming_values())
+ if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) {
+ U.set(PoisonValue::get(PN.getType()));
+ MadeIRChange = true;
+ }
+ };
- // We have now visited this block! If we've already been here, ignore it.
- if (!Visited.insert(BB).second)
+ for (BasicBlock *BB : RPOT) {
+ if (!BB->isEntryBlock() && all_of(predecessors(BB), [&](BasicBlock *Pred) {
+ return DeadEdges.contains({Pred, BB}) || DT.dominates(BB, Pred);
+ })) {
+ HandleOnlyLiveSuccessor(BB, nullptr);
continue;
+ }
+ LiveBlocks.insert(BB);
for (Instruction &Inst : llvm::make_early_inc_range(*BB)) {
// ConstantProp instruction if trivially constant.
if (!Inst.use_empty() &&
(Inst.getNumOperands() == 0 || isa<Constant>(Inst.getOperand(0))))
- if (Constant *C = ConstantFoldInstruction(&Inst, DL, TLI)) {
+ if (Constant *C = ConstantFoldInstruction(&Inst, DL, &TLI)) {
LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << Inst
<< '\n');
Inst.replaceAllUsesWith(C);
++NumConstProp;
- if (isInstructionTriviallyDead(&Inst, TLI))
+ if (isInstructionTriviallyDead(&Inst, &TLI))
Inst.eraseFromParent();
MadeIRChange = true;
continue;
@@ -4143,7 +4434,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
auto *C = cast<Constant>(U);
Constant *&FoldRes = FoldedConstants[C];
if (!FoldRes)
- FoldRes = ConstantFoldConstant(C, DL, TLI);
+ FoldRes = ConstantFoldConstant(C, DL, &TLI);
if (FoldRes != C) {
LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << Inst
@@ -4163,37 +4454,39 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
}
}
- // Recursively visit successors. If this is a branch or switch on a
- // constant, only visit the reachable successor.
+ // If this is a branch or switch on a constant, mark only the single
+ // live successor. Otherwise assume all successors are live.
Instruction *TI = BB->getTerminator();
if (BranchInst *BI = dyn_cast<BranchInst>(TI); BI && BI->isConditional()) {
- if (isa<UndefValue>(BI->getCondition()))
+ if (isa<UndefValue>(BI->getCondition())) {
// Branch on undef is UB.
+ HandleOnlyLiveSuccessor(BB, nullptr);
continue;
+ }
if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition())) {
bool CondVal = Cond->getZExtValue();
- BasicBlock *ReachableBB = BI->getSuccessor(!CondVal);
- Worklist.push_back(ReachableBB);
+ HandleOnlyLiveSuccessor(BB, BI->getSuccessor(!CondVal));
continue;
}
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
- if (isa<UndefValue>(SI->getCondition()))
+ if (isa<UndefValue>(SI->getCondition())) {
// Switch on undef is UB.
+ HandleOnlyLiveSuccessor(BB, nullptr);
continue;
+ }
if (auto *Cond = dyn_cast<ConstantInt>(SI->getCondition())) {
- Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor());
+ HandleOnlyLiveSuccessor(BB,
+ SI->findCaseValue(Cond)->getCaseSuccessor());
continue;
}
}
-
- append_range(Worklist, successors(TI));
- } while (!Worklist.empty());
+ }
// Remove instructions inside unreachable blocks. This prevents the
// instcombine code from having to deal with some bad special cases, and
// reduces use counts of instructions.
for (BasicBlock &BB : F) {
- if (Visited.count(&BB))
+ if (LiveBlocks.count(&BB))
continue;
unsigned NumDeadInstInBB;
@@ -4210,11 +4503,11 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
// of the function down. This jives well with the way that it adds all uses
// of instructions to the worklist after doing a transformation, thus avoiding
// some N^2 behavior in pathological cases.
- ICWorklist.reserve(InstrsForInstructionWorklist.size());
+ Worklist.reserve(InstrsForInstructionWorklist.size());
for (Instruction *Inst : reverse(InstrsForInstructionWorklist)) {
// DCE instruction if trivially dead. As we iterate in reverse program
// order here, we will clean up whole chains of dead instructions.
- if (isInstructionTriviallyDead(Inst, TLI) ||
+ if (isInstructionTriviallyDead(Inst, &TLI) ||
SeenAliasScopes.isNoAliasScopeDeclDead(Inst)) {
++NumDeadInst;
LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n');
@@ -4224,7 +4517,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
continue;
}
- ICWorklist.push(Inst);
+ Worklist.push(Inst);
}
return MadeIRChange;
@@ -4234,7 +4527,7 @@ static bool combineInstructionsOverFunction(
Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA,
AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
- ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) {
+ ProfileSummaryInfo *PSI, LoopInfo *LI, const InstCombineOptions &Opts) {
auto &DL = F.getParent()->getDataLayout();
/// Builder - This is an IRBuilder that automatically inserts new
@@ -4247,6 +4540,8 @@ static bool combineInstructionsOverFunction(
AC.registerAssumption(Assume);
}));
+ ReversePostOrderTraversal<BasicBlock *> RPOT(&F.front());
+
// Lower dbg.declare intrinsics otherwise their value may be clobbered
// by instcombiner.
bool MadeIRChange = false;
@@ -4256,35 +4551,33 @@ static bool combineInstructionsOverFunction(
// Iterate while there is work to do.
unsigned Iteration = 0;
while (true) {
- ++NumWorklistIterations;
++Iteration;
- if (Iteration > InfiniteLoopDetectionThreshold) {
- report_fatal_error(
- "Instruction Combining seems stuck in an infinite loop after " +
- Twine(InfiniteLoopDetectionThreshold) + " iterations.");
- }
-
- if (Iteration > MaxIterations) {
- LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << MaxIterations
+ if (Iteration > Opts.MaxIterations && !Opts.VerifyFixpoint) {
+ LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << Opts.MaxIterations
<< " on " << F.getName()
- << " reached; stopping before reaching a fixpoint\n");
+ << " reached; stopping without verifying fixpoint\n");
break;
}
+ ++NumWorklistIterations;
LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on "
<< F.getName() << "\n");
- MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist);
-
InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT,
ORE, BFI, PSI, DL, LI);
IC.MaxArraySizeForCombine = MaxArraySize;
-
- if (!IC.run())
+ bool MadeChangeInThisIteration = IC.prepareWorklist(F, RPOT);
+ MadeChangeInThisIteration |= IC.run();
+ if (!MadeChangeInThisIteration)
break;
MadeIRChange = true;
+ if (Iteration > Opts.MaxIterations) {
+ report_fatal_error(
+ "Instruction Combining did not reach a fixpoint after " +
+ Twine(Opts.MaxIterations) + " iterations");
+ }
}
if (Iteration == 1)
@@ -4307,7 +4600,8 @@ void InstCombinePass::printPipeline(
OS, MapClassName2PassName);
OS << '<';
OS << "max-iterations=" << Options.MaxIterations << ";";
- OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info";
+ OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info;";
+ OS << (Options.VerifyFixpoint ? "" : "no-") << "verify-fixpoint";
OS << '>';
}
@@ -4333,7 +4627,7 @@ PreservedAnalyses InstCombinePass::run(Function &F,
&AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
- BFI, PSI, Options.MaxIterations, LI))
+ BFI, PSI, LI, Options))
// No changes, all analyses are preserved.
return PreservedAnalyses::all();
@@ -4382,8 +4676,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
nullptr;
return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
- BFI, PSI,
- InstCombineDefaultMaxIterations, LI);
+ BFI, PSI, LI, InstCombineOptions());
}
char InstructionCombiningPass::ID = 0;
diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
index bde5fba20f3b..b175e6f93f3e 100644
--- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -201,8 +201,8 @@ static cl::opt<bool> ClRecover(
static cl::opt<bool> ClInsertVersionCheck(
"asan-guard-against-version-mismatch",
- cl::desc("Guard against compiler/runtime version mismatch."),
- cl::Hidden, cl::init(true));
+ cl::desc("Guard against compiler/runtime version mismatch."), cl::Hidden,
+ cl::init(true));
// This flag may need to be replaced with -f[no-]asan-reads.
static cl::opt<bool> ClInstrumentReads("asan-instrument-reads",
@@ -323,10 +323,9 @@ static cl::opt<unsigned> ClRealignStack(
static cl::opt<int> ClInstrumentationWithCallsThreshold(
"asan-instrumentation-with-call-threshold",
- cl::desc(
- "If the function being instrumented contains more than "
- "this number of memory accesses, use callbacks instead of "
- "inline checks (-1 means never use callbacks)."),
+ cl::desc("If the function being instrumented contains more than "
+ "this number of memory accesses, use callbacks instead of "
+ "inline checks (-1 means never use callbacks)."),
cl::Hidden, cl::init(7000));
static cl::opt<std::string> ClMemoryAccessCallbackPrefix(
@@ -491,7 +490,8 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize,
bool IsMIPS32 = TargetTriple.isMIPS32();
bool IsMIPS64 = TargetTriple.isMIPS64();
bool IsArmOrThumb = TargetTriple.isARM() || TargetTriple.isThumb();
- bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64;
+ bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64 ||
+ TargetTriple.getArch() == Triple::aarch64_be;
bool IsLoongArch64 = TargetTriple.isLoongArch64();
bool IsRISCV64 = TargetTriple.getArch() == Triple::riscv64;
bool IsWindows = TargetTriple.isOSWindows();
@@ -644,8 +644,9 @@ namespace {
/// AddressSanitizer: instrument the code in module to find memory bugs.
struct AddressSanitizer {
AddressSanitizer(Module &M, const StackSafetyGlobalInfo *SSGI,
- bool CompileKernel = false, bool Recover = false,
- bool UseAfterScope = false,
+ int InstrumentationWithCallsThreshold,
+ uint32_t MaxInlinePoisoningSize, bool CompileKernel = false,
+ bool Recover = false, bool UseAfterScope = false,
AsanDetectStackUseAfterReturnMode UseAfterReturn =
AsanDetectStackUseAfterReturnMode::Runtime)
: CompileKernel(ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan
@@ -654,12 +655,19 @@ struct AddressSanitizer {
UseAfterScope(UseAfterScope || ClUseAfterScope),
UseAfterReturn(ClUseAfterReturn.getNumOccurrences() ? ClUseAfterReturn
: UseAfterReturn),
- SSGI(SSGI) {
+ SSGI(SSGI),
+ InstrumentationWithCallsThreshold(
+ ClInstrumentationWithCallsThreshold.getNumOccurrences() > 0
+ ? ClInstrumentationWithCallsThreshold
+ : InstrumentationWithCallsThreshold),
+ MaxInlinePoisoningSize(ClMaxInlinePoisoningSize.getNumOccurrences() > 0
+ ? ClMaxInlinePoisoningSize
+ : MaxInlinePoisoningSize) {
C = &(M.getContext());
DL = &M.getDataLayout();
LongSize = M.getDataLayout().getPointerSizeInBits();
IntptrTy = Type::getIntNTy(*C, LongSize);
- Int8PtrTy = Type::getInt8PtrTy(*C);
+ PtrTy = PointerType::getUnqual(*C);
Int32Ty = Type::getInt32Ty(*C);
TargetTriple = Triple(M.getTargetTriple());
@@ -751,8 +759,8 @@ private:
bool UseAfterScope;
AsanDetectStackUseAfterReturnMode UseAfterReturn;
Type *IntptrTy;
- Type *Int8PtrTy;
Type *Int32Ty;
+ PointerType *PtrTy;
ShadowMapping Mapping;
FunctionCallee AsanHandleNoReturnFunc;
FunctionCallee AsanPtrCmpFunction, AsanPtrSubFunction;
@@ -773,17 +781,22 @@ private:
FunctionCallee AMDGPUAddressShared;
FunctionCallee AMDGPUAddressPrivate;
+ int InstrumentationWithCallsThreshold;
+ uint32_t MaxInlinePoisoningSize;
};
class ModuleAddressSanitizer {
public:
- ModuleAddressSanitizer(Module &M, bool CompileKernel = false,
- bool Recover = false, bool UseGlobalsGC = true,
- bool UseOdrIndicator = true,
+ ModuleAddressSanitizer(Module &M, bool InsertVersionCheck,
+ bool CompileKernel = false, bool Recover = false,
+ bool UseGlobalsGC = true, bool UseOdrIndicator = true,
AsanDtorKind DestructorKind = AsanDtorKind::Global,
AsanCtorKind ConstructorKind = AsanCtorKind::Global)
: CompileKernel(ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan
: CompileKernel),
+ InsertVersionCheck(ClInsertVersionCheck.getNumOccurrences() > 0
+ ? ClInsertVersionCheck
+ : InsertVersionCheck),
Recover(ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover),
UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC && !this->CompileKernel),
// Enable aliases as they should have no downside with ODR indicators.
@@ -802,10 +815,13 @@ public:
// do globals-gc.
UseCtorComdat(UseGlobalsGC && ClWithComdat && !this->CompileKernel),
DestructorKind(DestructorKind),
- ConstructorKind(ConstructorKind) {
+ ConstructorKind(ClConstructorKind.getNumOccurrences() > 0
+ ? ClConstructorKind
+ : ConstructorKind) {
C = &(M.getContext());
int LongSize = M.getDataLayout().getPointerSizeInBits();
IntptrTy = Type::getIntNTy(*C, LongSize);
+ PtrTy = PointerType::getUnqual(*C);
TargetTriple = Triple(M.getTargetTriple());
Mapping = getShadowMapping(TargetTriple, LongSize, this->CompileKernel);
@@ -819,11 +835,11 @@ public:
private:
void initializeCallbacks(Module &M);
- bool InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat);
+ void instrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat);
void InstrumentGlobalsCOFF(IRBuilder<> &IRB, Module &M,
ArrayRef<GlobalVariable *> ExtendedGlobals,
ArrayRef<Constant *> MetadataInitializers);
- void InstrumentGlobalsELF(IRBuilder<> &IRB, Module &M,
+ void instrumentGlobalsELF(IRBuilder<> &IRB, Module &M,
ArrayRef<GlobalVariable *> ExtendedGlobals,
ArrayRef<Constant *> MetadataInitializers,
const std::string &UniqueModuleId);
@@ -854,6 +870,7 @@ private:
int GetAsanVersion(const Module &M) const;
bool CompileKernel;
+ bool InsertVersionCheck;
bool Recover;
bool UseGlobalsGC;
bool UsePrivateAlias;
@@ -862,6 +879,7 @@ private:
AsanDtorKind DestructorKind;
AsanCtorKind ConstructorKind;
Type *IntptrTy;
+ PointerType *PtrTy;
LLVMContext *C;
Triple TargetTriple;
ShadowMapping Mapping;
@@ -1148,22 +1166,22 @@ AddressSanitizerPass::AddressSanitizerPass(
AsanCtorKind ConstructorKind)
: Options(Options), UseGlobalGC(UseGlobalGC),
UseOdrIndicator(UseOdrIndicator), DestructorKind(DestructorKind),
- ConstructorKind(ClConstructorKind) {}
+ ConstructorKind(ConstructorKind) {}
PreservedAnalyses AddressSanitizerPass::run(Module &M,
ModuleAnalysisManager &MAM) {
- ModuleAddressSanitizer ModuleSanitizer(M, Options.CompileKernel,
- Options.Recover, UseGlobalGC,
- UseOdrIndicator, DestructorKind,
- ConstructorKind);
+ ModuleAddressSanitizer ModuleSanitizer(
+ M, Options.InsertVersionCheck, Options.CompileKernel, Options.Recover,
+ UseGlobalGC, UseOdrIndicator, DestructorKind, ConstructorKind);
bool Modified = false;
auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
const StackSafetyGlobalInfo *const SSGI =
ClUseStackSafety ? &MAM.getResult<StackSafetyGlobalAnalysis>(M) : nullptr;
for (Function &F : M) {
- AddressSanitizer FunctionSanitizer(M, SSGI, Options.CompileKernel,
- Options.Recover, Options.UseAfterScope,
- Options.UseAfterReturn);
+ AddressSanitizer FunctionSanitizer(
+ M, SSGI, Options.InstrumentationWithCallsThreshold,
+ Options.MaxInlinePoisoningSize, Options.CompileKernel, Options.Recover,
+ Options.UseAfterScope, Options.UseAfterReturn);
const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
Modified |= FunctionSanitizer.instrumentFunction(F, &TLI);
}
@@ -1188,17 +1206,17 @@ static size_t TypeStoreSizeToSizeIndex(uint32_t TypeSize) {
/// Check if \p G has been created by a trusted compiler pass.
static bool GlobalWasGeneratedByCompiler(GlobalVariable *G) {
// Do not instrument @llvm.global_ctors, @llvm.used, etc.
- if (G->getName().startswith("llvm.") ||
+ if (G->getName().starts_with("llvm.") ||
// Do not instrument gcov counter arrays.
- G->getName().startswith("__llvm_gcov_ctr") ||
+ G->getName().starts_with("__llvm_gcov_ctr") ||
// Do not instrument rtti proxy symbols for function sanitizer.
- G->getName().startswith("__llvm_rtti_proxy"))
+ G->getName().starts_with("__llvm_rtti_proxy"))
return true;
// Do not instrument asan globals.
- if (G->getName().startswith(kAsanGenPrefix) ||
- G->getName().startswith(kSanCovGenPrefix) ||
- G->getName().startswith(kODRGenPrefix))
+ if (G->getName().starts_with(kAsanGenPrefix) ||
+ G->getName().starts_with(kSanCovGenPrefix) ||
+ G->getName().starts_with(kODRGenPrefix))
return true;
return false;
@@ -1232,15 +1250,13 @@ Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) {
void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) {
InstrumentationIRBuilder IRB(MI);
if (isa<MemTransferInst>(MI)) {
- IRB.CreateCall(
- isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+ IRB.CreateCall(isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy,
+ {MI->getOperand(0), MI->getOperand(1),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
} else if (isa<MemSetInst>(MI)) {
IRB.CreateCall(
AsanMemset,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+ {MI->getOperand(0),
IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
}
@@ -1570,7 +1586,7 @@ void AddressSanitizer::instrumentMaskedLoadOrStore(
InstrumentedAddress = IRB.CreateExtractElement(Addr, Index);
} else if (Stride) {
Index = IRB.CreateMul(Index, Stride);
- Addr = IRB.CreateBitCast(Addr, Type::getInt8PtrTy(*C));
+ Addr = IRB.CreateBitCast(Addr, PointerType::getUnqual(*C));
InstrumentedAddress = IRB.CreateGEP(Type::getInt8Ty(*C), Addr, {Index});
} else {
InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index});
@@ -1695,9 +1711,8 @@ Instruction *AddressSanitizer::instrumentAMDGPUAddress(
return InsertBefore;
// Instrument generic addresses in supported addressspaces.
IRBuilder<> IRB(InsertBefore);
- Value *AddrLong = IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy());
- Value *IsShared = IRB.CreateCall(AMDGPUAddressShared, {AddrLong});
- Value *IsPrivate = IRB.CreateCall(AMDGPUAddressPrivate, {AddrLong});
+ Value *IsShared = IRB.CreateCall(AMDGPUAddressShared, {Addr});
+ Value *IsPrivate = IRB.CreateCall(AMDGPUAddressPrivate, {Addr});
Value *IsSharedOrPrivate = IRB.CreateOr(IsShared, IsPrivate);
Value *Cmp = IRB.CreateNot(IsSharedOrPrivate);
Value *AddrSpaceZeroLanding =
@@ -1728,7 +1743,7 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns,
Module *M = IRB.GetInsertBlock()->getParent()->getParent();
IRB.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::asan_check_memaccess),
- {IRB.CreatePointerCast(Addr, Int8PtrTy),
+ {IRB.CreatePointerCast(Addr, PtrTy),
ConstantInt::get(Int32Ty, AccessInfo.Packed)});
return;
}
@@ -1869,7 +1884,7 @@ ModuleAddressSanitizer::getExcludedAliasedGlobal(const GlobalAlias &GA) const {
// When compiling the kernel, globals that are aliased by symbols prefixed
// by "__" are special and cannot be padded with a redzone.
- if (GA.getName().startswith("__"))
+ if (GA.getName().starts_with("__"))
return dyn_cast<GlobalVariable>(C->stripPointerCastsAndAliases());
return nullptr;
@@ -1939,9 +1954,9 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const {
// Do not instrument function pointers to initialization and termination
// routines: dynamic linker will not properly handle redzones.
- if (Section.startswith(".preinit_array") ||
- Section.startswith(".init_array") ||
- Section.startswith(".fini_array")) {
+ if (Section.starts_with(".preinit_array") ||
+ Section.starts_with(".init_array") ||
+ Section.starts_with(".fini_array")) {
return false;
}
@@ -1978,7 +1993,7 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const {
// those conform to /usr/lib/objc/runtime.h, so we can't add redzones to
// them.
if (ParsedSegment == "__OBJC" ||
- (ParsedSegment == "__DATA" && ParsedSection.startswith("__objc_"))) {
+ (ParsedSegment == "__DATA" && ParsedSection.starts_with("__objc_"))) {
LLVM_DEBUG(dbgs() << "Ignoring ObjC runtime global: " << *G << "\n");
return false;
}
@@ -2006,7 +2021,7 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const {
if (CompileKernel) {
// Globals that prefixed by "__" are special and cannot be padded with a
// redzone.
- if (G->getName().startswith("__"))
+ if (G->getName().starts_with("__"))
return false;
}
@@ -2129,6 +2144,9 @@ ModuleAddressSanitizer::CreateMetadataGlobal(Module &M, Constant *Initializer,
M, Initializer->getType(), false, Linkage, Initializer,
Twine("__asan_global_") + GlobalValue::dropLLVMManglingEscape(OriginalName));
Metadata->setSection(getGlobalMetadataSection());
+ // Place metadata in a large section for x86-64 ELF binaries to mitigate
+ // relocation pressure.
+ setGlobalVariableLargeSection(TargetTriple, *Metadata);
return Metadata;
}
@@ -2177,7 +2195,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsCOFF(
appendToCompilerUsed(M, MetadataGlobals);
}
-void ModuleAddressSanitizer::InstrumentGlobalsELF(
+void ModuleAddressSanitizer::instrumentGlobalsELF(
IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals,
ArrayRef<Constant *> MetadataInitializers,
const std::string &UniqueModuleId) {
@@ -2187,7 +2205,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsELF(
// false negative odr violations at link time. If odr indicators are used, we
// keep the comdat sections, as link time odr violations will be dectected on
// the odr indicator symbols.
- bool UseComdatForGlobalsGC = UseOdrIndicator;
+ bool UseComdatForGlobalsGC = UseOdrIndicator && !UniqueModuleId.empty();
SmallVector<GlobalValue *, 16> MetadataGlobals(ExtendedGlobals.size());
for (size_t i = 0; i < ExtendedGlobals.size(); i++) {
@@ -2237,7 +2255,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsELF(
// We also need to unregister globals at the end, e.g., when a shared library
// gets closed.
- if (DestructorKind != AsanDtorKind::None) {
+ if (DestructorKind != AsanDtorKind::None && !MetadataGlobals.empty()) {
IRBuilder<> IrbDtor(CreateAsanModuleDtor(M));
IrbDtor.CreateCall(AsanUnregisterElfGlobals,
{IRB.CreatePointerCast(RegisteredFlag, IntptrTy),
@@ -2343,10 +2361,8 @@ void ModuleAddressSanitizer::InstrumentGlobalsWithMetadataArray(
// redzones and inserts this function into llvm.global_ctors.
// Sets *CtorComdat to true if the global registration code emitted into the
// asan constructor is comdat-compatible.
-bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
+void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB, Module &M,
bool *CtorComdat) {
- *CtorComdat = false;
-
// Build set of globals that are aliased by some GA, where
// getExcludedAliasedGlobal(GA) returns the relevant GlobalVariable.
SmallPtrSet<const GlobalVariable *, 16> AliasedGlobalExclusions;
@@ -2364,11 +2380,6 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
}
size_t n = GlobalsToChange.size();
- if (n == 0) {
- *CtorComdat = true;
- return false;
- }
-
auto &DL = M.getDataLayout();
// A global is described by a structure
@@ -2391,8 +2402,11 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
// We shouldn't merge same module names, as this string serves as unique
// module ID in runtime.
- GlobalVariable *ModuleName = createPrivateGlobalForString(
- M, M.getModuleIdentifier(), /*AllowMerging*/ false, kAsanGenPrefix);
+ GlobalVariable *ModuleName =
+ n != 0
+ ? createPrivateGlobalForString(M, M.getModuleIdentifier(),
+ /*AllowMerging*/ false, kAsanGenPrefix)
+ : nullptr;
for (size_t i = 0; i < n; i++) {
GlobalVariable *G = GlobalsToChange[i];
@@ -2455,7 +2469,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
G->eraseFromParent();
NewGlobals[i] = NewGlobal;
- Constant *ODRIndicator = ConstantExpr::getNullValue(IRB.getInt8PtrTy());
+ Constant *ODRIndicator = ConstantPointerNull::get(PtrTy);
GlobalValue *InstrumentedGlobal = NewGlobal;
bool CanUsePrivateAliases =
@@ -2470,8 +2484,8 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
// ODR should not happen for local linkage.
if (NewGlobal->hasLocalLinkage()) {
- ODRIndicator = ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, -1),
- IRB.getInt8PtrTy());
+ ODRIndicator =
+ ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, -1), PtrTy);
} else if (UseOdrIndicator) {
// With local aliases, we need to provide another externally visible
// symbol __odr_asan_XXX to detect ODR violation.
@@ -2517,19 +2531,27 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
}
appendToCompilerUsed(M, ArrayRef<GlobalValue *>(GlobalsToAddToUsedList));
- std::string ELFUniqueModuleId =
- (UseGlobalsGC && TargetTriple.isOSBinFormatELF()) ? getUniqueModuleId(&M)
- : "";
-
- if (!ELFUniqueModuleId.empty()) {
- InstrumentGlobalsELF(IRB, M, NewGlobals, Initializers, ELFUniqueModuleId);
+ if (UseGlobalsGC && TargetTriple.isOSBinFormatELF()) {
+ // Use COMDAT and register globals even if n == 0 to ensure that (a) the
+ // linkage unit will only have one module constructor, and (b) the register
+ // function will be called. The module destructor is not created when n ==
+ // 0.
*CtorComdat = true;
- } else if (UseGlobalsGC && TargetTriple.isOSBinFormatCOFF()) {
- InstrumentGlobalsCOFF(IRB, M, NewGlobals, Initializers);
- } else if (UseGlobalsGC && ShouldUseMachOGlobalsSection()) {
- InstrumentGlobalsMachO(IRB, M, NewGlobals, Initializers);
+ instrumentGlobalsELF(IRB, M, NewGlobals, Initializers,
+ getUniqueModuleId(&M));
+ } else if (n == 0) {
+ // When UseGlobalsGC is false, COMDAT can still be used if n == 0, because
+ // all compile units will have identical module constructor/destructor.
+ *CtorComdat = TargetTriple.isOSBinFormatELF();
} else {
- InstrumentGlobalsWithMetadataArray(IRB, M, NewGlobals, Initializers);
+ *CtorComdat = false;
+ if (UseGlobalsGC && TargetTriple.isOSBinFormatCOFF()) {
+ InstrumentGlobalsCOFF(IRB, M, NewGlobals, Initializers);
+ } else if (UseGlobalsGC && ShouldUseMachOGlobalsSection()) {
+ InstrumentGlobalsMachO(IRB, M, NewGlobals, Initializers);
+ } else {
+ InstrumentGlobalsWithMetadataArray(IRB, M, NewGlobals, Initializers);
+ }
}
// Create calls for poisoning before initializers run and unpoisoning after.
@@ -2537,7 +2559,6 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M,
createInitializerPoisonCalls(M, ModuleName);
LLVM_DEBUG(dbgs() << M);
- return true;
}
uint64_t
@@ -2588,7 +2609,7 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) {
} else {
std::string AsanVersion = std::to_string(GetAsanVersion(M));
std::string VersionCheckName =
- ClInsertVersionCheck ? (kAsanVersionCheckNamePrefix + AsanVersion) : "";
+ InsertVersionCheck ? (kAsanVersionCheckNamePrefix + AsanVersion) : "";
std::tie(AsanCtorFunction, std::ignore) =
createSanitizerCtorAndInitFunctions(M, kAsanModuleCtorName,
kAsanInitName, /*InitArgTypes=*/{},
@@ -2601,10 +2622,10 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) {
assert(AsanCtorFunction || ConstructorKind == AsanCtorKind::None);
if (AsanCtorFunction) {
IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator());
- InstrumentGlobals(IRB, M, &CtorComdat);
+ instrumentGlobals(IRB, M, &CtorComdat);
} else {
IRBuilder<> IRB(*C);
- InstrumentGlobals(IRB, M, &CtorComdat);
+ instrumentGlobals(IRB, M, &CtorComdat);
}
}
@@ -2684,15 +2705,12 @@ void AddressSanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo *T
? std::string("")
: ClMemoryAccessCallbackPrefix;
AsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove",
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy);
- AsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy",
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy);
+ PtrTy, PtrTy, PtrTy, IntptrTy);
+ AsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", PtrTy,
+ PtrTy, PtrTy, IntptrTy);
AsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset",
TLI->getAttrList(C, {1}, /*Signed=*/false),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt32Ty(), IntptrTy);
+ PtrTy, PtrTy, IRB.getInt32Ty(), IntptrTy);
AsanHandleNoReturnFunc =
M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy());
@@ -2705,10 +2723,10 @@ void AddressSanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo *T
AsanShadowGlobal = M.getOrInsertGlobal("__asan_shadow",
ArrayType::get(IRB.getInt8Ty(), 0));
- AMDGPUAddressShared = M.getOrInsertFunction(
- kAMDGPUAddressSharedName, IRB.getInt1Ty(), IRB.getInt8PtrTy());
- AMDGPUAddressPrivate = M.getOrInsertFunction(
- kAMDGPUAddressPrivateName, IRB.getInt1Ty(), IRB.getInt8PtrTy());
+ AMDGPUAddressShared =
+ M.getOrInsertFunction(kAMDGPUAddressSharedName, IRB.getInt1Ty(), PtrTy);
+ AMDGPUAddressPrivate =
+ M.getOrInsertFunction(kAMDGPUAddressPrivateName, IRB.getInt1Ty(), PtrTy);
}
bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) {
@@ -2799,7 +2817,7 @@ bool AddressSanitizer::instrumentFunction(Function &F,
return false;
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false;
if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false;
- if (F.getName().startswith("__asan_")) return false;
+ if (F.getName().starts_with("__asan_")) return false;
bool FunctionModified = false;
@@ -2890,9 +2908,9 @@ bool AddressSanitizer::instrumentFunction(Function &F,
}
}
- bool UseCalls = (ClInstrumentationWithCallsThreshold >= 0 &&
+ bool UseCalls = (InstrumentationWithCallsThreshold >= 0 &&
OperandsToInstrument.size() + IntrinToInstrument.size() >
- (unsigned)ClInstrumentationWithCallsThreshold);
+ (unsigned)InstrumentationWithCallsThreshold);
const DataLayout &DL = F.getParent()->getDataLayout();
ObjectSizeOpts ObjSizeOpts;
ObjSizeOpts.RoundToAlign = true;
@@ -3034,7 +3052,7 @@ void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask,
Value *Ptr = IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i));
Value *Poison = IRB.getIntN(StoreSizeInBytes * 8, Val);
IRB.CreateAlignedStore(
- Poison, IRB.CreateIntToPtr(Ptr, Poison->getType()->getPointerTo()),
+ Poison, IRB.CreateIntToPtr(Ptr, PointerType::getUnqual(Poison->getContext())),
Align(1));
i += StoreSizeInBytes;
@@ -3066,7 +3084,7 @@ void FunctionStackPoisoner::copyToShadow(ArrayRef<uint8_t> ShadowMask,
for (; j < End && ShadowMask[j] && Val == ShadowBytes[j]; ++j) {
}
- if (j - i >= ClMaxInlinePoisoningSize) {
+ if (j - i >= ASan.MaxInlinePoisoningSize) {
copyToShadowInline(ShadowMask, ShadowBytes, Done, i, IRB, ShadowBase);
IRB.CreateCall(AsanSetShadowFunc[Val],
{IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)),
@@ -3500,7 +3518,7 @@ void FunctionStackPoisoner::processStaticAllocas() {
IntptrTy, IRBPoison.CreateIntToPtr(SavedFlagPtrPtr, IntptrPtrTy));
IRBPoison.CreateStore(
Constant::getNullValue(IRBPoison.getInt8Ty()),
- IRBPoison.CreateIntToPtr(SavedFlagPtr, IRBPoison.getInt8PtrTy()));
+ IRBPoison.CreateIntToPtr(SavedFlagPtr, IRBPoison.getPtrTy()));
} else {
// For larger frames call __asan_stack_free_*.
IRBPoison.CreateCall(
diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
index 709095184af5..ee5b81960417 100644
--- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
+++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
@@ -37,6 +37,9 @@ using namespace llvm;
static cl::opt<bool> SingleTrapBB("bounds-checking-single-trap",
cl::desc("Use one trap block per function"));
+static cl::opt<bool> DebugTrapBB("bounds-checking-unique-traps",
+ cl::desc("Always use one trap per check"));
+
STATISTIC(ChecksAdded, "Bounds checks added");
STATISTIC(ChecksSkipped, "Bounds checks skipped");
STATISTIC(ChecksUnable, "Bounds checks unable to add");
@@ -180,19 +183,27 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI,
// will create a fresh block every time it is called.
BasicBlock *TrapBB = nullptr;
auto GetTrapBB = [&TrapBB](BuilderTy &IRB) {
- if (TrapBB && SingleTrapBB)
- return TrapBB;
-
Function *Fn = IRB.GetInsertBlock()->getParent();
- // FIXME: This debug location doesn't make a lot of sense in the
- // `SingleTrapBB` case.
auto DebugLoc = IRB.getCurrentDebugLocation();
IRBuilder<>::InsertPointGuard Guard(IRB);
+
+ if (TrapBB && SingleTrapBB && !DebugTrapBB)
+ return TrapBB;
+
TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
IRB.SetInsertPoint(TrapBB);
- auto *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap);
- CallInst *TrapCall = IRB.CreateCall(F, {});
+ Intrinsic::ID IntrID = DebugTrapBB ? Intrinsic::ubsantrap : Intrinsic::trap;
+ auto *F = Intrinsic::getDeclaration(Fn->getParent(), IntrID);
+
+ CallInst *TrapCall;
+ if (DebugTrapBB) {
+ TrapCall =
+ IRB.CreateCall(F, ConstantInt::get(IRB.getInt8Ty(), Fn->size()));
+ } else {
+ TrapCall = IRB.CreateCall(F, {});
+ }
+
TrapCall->setDoesNotReturn();
TrapCall->setDoesNotThrow();
TrapCall->setDebugLoc(DebugLoc);
diff --git a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp
index d53e12ad1ff5..e2e5f21b376b 100644
--- a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp
+++ b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp
@@ -66,7 +66,7 @@ static bool runCGProfilePass(
if (F.isDeclaration() || !F.getEntryCount())
continue;
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
- if (BFI.getEntryFreq() == 0)
+ if (BFI.getEntryFreq() == BlockFrequency(0))
continue;
TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
for (auto &BB : F) {
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 3e3be536defc..0a3d8d6000cf 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1593,8 +1593,8 @@ static void insertTrivialPHIs(CHRScope *Scope,
// Insert a trivial phi for I (phi [&I, P0], [&I, P1], ...) at
// ExitBlock. Replace I with the new phi in UI unless UI is another
// phi at ExitBlock.
- PHINode *PN = PHINode::Create(I.getType(), pred_size(ExitBlock), "",
- &ExitBlock->front());
+ PHINode *PN = PHINode::Create(I.getType(), pred_size(ExitBlock), "");
+ PN->insertBefore(ExitBlock->begin());
for (BasicBlock *Pred : predecessors(ExitBlock)) {
PN->addIncoming(&I, Pred);
}
@@ -1777,6 +1777,13 @@ void CHR::cloneScopeBlocks(CHRScope *Scope,
BasicBlock *NewBB = CloneBasicBlock(BB, VMap, ".nonchr", &F);
NewBlocks.push_back(NewBB);
VMap[BB] = NewBB;
+
+ // Unreachable predecessors will not be cloned and will not have an edge
+ // to the cloned block. As such, also remove them from any phi nodes.
+ for (PHINode &PN : make_early_inc_range(NewBB->phis()))
+ PN.removeIncomingValueIf([&](unsigned Idx) {
+ return !DT.isReachableFromEntry(PN.getIncomingBlock(Idx));
+ });
}
// Place the cloned blocks right after the original blocks (right before the
@@ -1871,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
};
- MDBuilder MDB(F.getContext());
- MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*MergedBR, Weights);
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
<< "\n");
}
diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index 8caee5bed8ed..2ba127bba6f6 100644
--- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -564,7 +564,7 @@ class DataFlowSanitizer {
/// getShadowTy([n x T]) = [n x getShadowTy(T)]
/// getShadowTy(other type) = i16
Type *getShadowTy(Type *OrigTy);
- /// Returns the shadow type of of V's type.
+ /// Returns the shadow type of V's type.
Type *getShadowTy(Value *V);
const uint64_t NumOfElementsInArgOrgTLS = ArgTLSSize / OriginWidthBytes;
@@ -1145,7 +1145,7 @@ bool DataFlowSanitizer::initializeModule(Module &M) {
Mod = &M;
Ctx = &M.getContext();
- Int8Ptr = Type::getInt8PtrTy(*Ctx);
+ Int8Ptr = PointerType::getUnqual(*Ctx);
OriginTy = IntegerType::get(*Ctx, OriginWidthBits);
OriginPtrTy = PointerType::getUnqual(OriginTy);
PrimitiveShadowTy = IntegerType::get(*Ctx, ShadowWidthBits);
@@ -1162,19 +1162,19 @@ bool DataFlowSanitizer::initializeModule(Module &M) {
FunctionType::get(IntegerType::get(*Ctx, 64), DFSanLoadLabelAndOriginArgs,
/*isVarArg=*/false);
DFSanUnimplementedFnTy = FunctionType::get(
- Type::getVoidTy(*Ctx), Type::getInt8PtrTy(*Ctx), /*isVarArg=*/false);
+ Type::getVoidTy(*Ctx), PointerType::getUnqual(*Ctx), /*isVarArg=*/false);
Type *DFSanWrapperExternWeakNullArgs[2] = {Int8Ptr, Int8Ptr};
DFSanWrapperExternWeakNullFnTy =
FunctionType::get(Type::getVoidTy(*Ctx), DFSanWrapperExternWeakNullArgs,
/*isVarArg=*/false);
Type *DFSanSetLabelArgs[4] = {PrimitiveShadowTy, OriginTy,
- Type::getInt8PtrTy(*Ctx), IntptrTy};
+ PointerType::getUnqual(*Ctx), IntptrTy};
DFSanSetLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx),
DFSanSetLabelArgs, /*isVarArg=*/false);
DFSanNonzeroLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx), std::nullopt,
/*isVarArg=*/false);
DFSanVarargWrapperFnTy = FunctionType::get(
- Type::getVoidTy(*Ctx), Type::getInt8PtrTy(*Ctx), /*isVarArg=*/false);
+ Type::getVoidTy(*Ctx), PointerType::getUnqual(*Ctx), /*isVarArg=*/false);
DFSanConditionalCallbackFnTy =
FunctionType::get(Type::getVoidTy(*Ctx), PrimitiveShadowTy,
/*isVarArg=*/false);
@@ -1288,7 +1288,7 @@ void DataFlowSanitizer::buildExternWeakCheckIfNeeded(IRBuilder<> &IRB,
// for a extern weak function, add a check here to help identify the issue.
if (GlobalValue::isExternalWeakLinkage(F->getLinkage())) {
std::vector<Value *> Args;
- Args.push_back(IRB.CreatePointerCast(F, IRB.getInt8PtrTy()));
+ Args.push_back(F);
Args.push_back(IRB.CreateGlobalStringPtr(F->getName()));
IRB.CreateCall(DFSanWrapperExternWeakNullFn, Args);
}
@@ -1553,7 +1553,7 @@ bool DataFlowSanitizer::runImpl(
assert(isa<Function>(C) && "Personality routine is not a function!");
Function *F = cast<Function>(C);
if (!isInstrumented(F))
- llvm::erase_value(FnsToInstrument, F);
+ llvm::erase(FnsToInstrument, F);
}
}
@@ -1575,7 +1575,7 @@ bool DataFlowSanitizer::runImpl(
// below will take care of instrumenting it.
Function *NewF =
buildWrapperFunction(F, "", GA.getLinkage(), F->getFunctionType());
- GA.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, GA.getType()));
+ GA.replaceAllUsesWith(NewF);
NewF->takeName(&GA);
GA.eraseFromParent();
FnsToInstrument.push_back(NewF);
@@ -1622,9 +1622,6 @@ bool DataFlowSanitizer::runImpl(
WrapperLinkage, FT);
NewF->removeFnAttrs(ReadOnlyNoneAttrs);
- Value *WrappedFnCst =
- ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT));
-
// Extern weak functions can sometimes be null at execution time.
// Code will sometimes check if an extern weak function is null.
// This could look something like:
@@ -1657,9 +1654,9 @@ bool DataFlowSanitizer::runImpl(
}
return true;
};
- F.replaceUsesWithIf(WrappedFnCst, IsNotCmpUse);
+ F.replaceUsesWithIf(NewF, IsNotCmpUse);
- UnwrappedFnMap[WrappedFnCst] = &F;
+ UnwrappedFnMap[NewF] = &F;
*FI = NewF;
if (!F.isDeclaration()) {
@@ -2273,8 +2270,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking(
IRBuilder<> IRB(Pos);
CallInst *Call =
IRB.CreateCall(DFS.DFSanLoadLabelAndOriginFn,
- {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
- ConstantInt::get(DFS.IntptrTy, Size)});
+ {Addr, ConstantInt::get(DFS.IntptrTy, Size)});
Call->addRetAttr(Attribute::ZExt);
return {IRB.CreateTrunc(IRB.CreateLShr(Call, DFS.OriginWidthBits),
DFS.PrimitiveShadowTy),
@@ -2436,9 +2432,9 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) {
if (ClEventCallbacks) {
IRBuilder<> IRB(Pos);
- Value *Addr8 = IRB.CreateBitCast(LI.getPointerOperand(), DFSF.DFS.Int8Ptr);
+ Value *Addr = LI.getPointerOperand();
CallInst *CI =
- IRB.CreateCall(DFSF.DFS.DFSanLoadCallbackFn, {PrimitiveShadow, Addr8});
+ IRB.CreateCall(DFSF.DFS.DFSanLoadCallbackFn, {PrimitiveShadow, Addr});
CI->addParamAttr(0, Attribute::ZExt);
}
@@ -2530,10 +2526,9 @@ void DFSanFunction::storeOrigin(Instruction *Pos, Value *Addr, uint64_t Size,
}
if (shouldInstrumentWithCall()) {
- IRB.CreateCall(DFS.DFSanMaybeStoreOriginFn,
- {CollapsedShadow,
- IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
- ConstantInt::get(DFS.IntptrTy, Size), Origin});
+ IRB.CreateCall(
+ DFS.DFSanMaybeStoreOriginFn,
+ {CollapsedShadow, Addr, ConstantInt::get(DFS.IntptrTy, Size), Origin});
} else {
Value *Cmp = convertToBool(CollapsedShadow, IRB, "_dfscmp");
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
@@ -2554,9 +2549,7 @@ void DFSanFunction::storeZeroPrimitiveShadow(Value *Addr, uint64_t Size,
IntegerType::get(*DFS.Ctx, Size * DFS.ShadowWidthBits);
Value *ExtZeroShadow = ConstantInt::get(ShadowTy, 0);
Value *ShadowAddr = DFS.getShadowAddress(Addr, Pos);
- Value *ExtShadowAddr =
- IRB.CreateBitCast(ShadowAddr, PointerType::getUnqual(ShadowTy));
- IRB.CreateAlignedStore(ExtZeroShadow, ExtShadowAddr, ShadowAlign);
+ IRB.CreateAlignedStore(ExtZeroShadow, ShadowAddr, ShadowAlign);
// Do not write origins for 0 shadows because we do not trace origins for
// untainted sinks.
}
@@ -2611,11 +2604,9 @@ void DFSanFunction::storePrimitiveShadowOrigin(Value *Addr, uint64_t Size,
ShadowVec, PrimitiveShadow,
ConstantInt::get(Type::getInt32Ty(*DFS.Ctx), I));
}
- Value *ShadowVecAddr =
- IRB.CreateBitCast(ShadowAddr, PointerType::getUnqual(ShadowVecTy));
do {
Value *CurShadowVecAddr =
- IRB.CreateConstGEP1_32(ShadowVecTy, ShadowVecAddr, Offset);
+ IRB.CreateConstGEP1_32(ShadowVecTy, ShadowAddr, Offset);
IRB.CreateAlignedStore(ShadowVec, CurShadowVecAddr, ShadowAlign);
LeftSize -= ShadowVecSize;
++Offset;
@@ -2699,9 +2690,9 @@ void DFSanVisitor::visitStoreInst(StoreInst &SI) {
PrimitiveShadow, Origin, &SI);
if (ClEventCallbacks) {
IRBuilder<> IRB(&SI);
- Value *Addr8 = IRB.CreateBitCast(SI.getPointerOperand(), DFSF.DFS.Int8Ptr);
+ Value *Addr = SI.getPointerOperand();
CallInst *CI =
- IRB.CreateCall(DFSF.DFS.DFSanStoreCallbackFn, {PrimitiveShadow, Addr8});
+ IRB.CreateCall(DFSF.DFS.DFSanStoreCallbackFn, {PrimitiveShadow, Addr});
CI->addParamAttr(0, Attribute::ZExt);
}
}
@@ -2918,11 +2909,9 @@ void DFSanVisitor::visitMemSetInst(MemSetInst &I) {
Value *ValOrigin = DFSF.DFS.shouldTrackOrigins()
? DFSF.getOrigin(I.getValue())
: DFSF.DFS.ZeroOrigin;
- IRB.CreateCall(
- DFSF.DFS.DFSanSetLabelFn,
- {ValShadow, ValOrigin,
- IRB.CreateBitCast(I.getDest(), Type::getInt8PtrTy(*DFSF.DFS.Ctx)),
- IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)});
+ IRB.CreateCall(DFSF.DFS.DFSanSetLabelFn,
+ {ValShadow, ValOrigin, I.getDest(),
+ IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)});
}
void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) {
@@ -2933,28 +2922,24 @@ void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) {
if (DFSF.DFS.shouldTrackOrigins()) {
IRB.CreateCall(
DFSF.DFS.DFSanMemOriginTransferFn,
- {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(I.getArgOperand(1), IRB.getInt8PtrTy()),
+ {I.getArgOperand(0), I.getArgOperand(1),
IRB.CreateIntCast(I.getArgOperand(2), DFSF.DFS.IntptrTy, false)});
}
- Value *RawDestShadow = DFSF.DFS.getShadowAddress(I.getDest(), &I);
+ Value *DestShadow = DFSF.DFS.getShadowAddress(I.getDest(), &I);
Value *SrcShadow = DFSF.DFS.getShadowAddress(I.getSource(), &I);
Value *LenShadow =
IRB.CreateMul(I.getLength(), ConstantInt::get(I.getLength()->getType(),
DFSF.DFS.ShadowWidthBytes));
- Type *Int8Ptr = Type::getInt8PtrTy(*DFSF.DFS.Ctx);
- Value *DestShadow = IRB.CreateBitCast(RawDestShadow, Int8Ptr);
- SrcShadow = IRB.CreateBitCast(SrcShadow, Int8Ptr);
auto *MTI = cast<MemTransferInst>(
IRB.CreateCall(I.getFunctionType(), I.getCalledOperand(),
{DestShadow, SrcShadow, LenShadow, I.getVolatileCst()}));
MTI->setDestAlignment(DFSF.getShadowAlign(I.getDestAlign().valueOrOne()));
MTI->setSourceAlignment(DFSF.getShadowAlign(I.getSourceAlign().valueOrOne()));
if (ClEventCallbacks) {
- IRB.CreateCall(DFSF.DFS.DFSanMemTransferCallbackFn,
- {RawDestShadow,
- IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)});
+ IRB.CreateCall(
+ DFSF.DFS.DFSanMemTransferCallbackFn,
+ {DestShadow, IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)});
}
}
@@ -3225,10 +3210,9 @@ void DFSanVisitor::visitLibAtomicLoad(CallBase &CB) {
// TODO: Support ClCombinePointerLabelsOnLoad
// TODO: Support ClEventCallbacks
- NextIRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn,
- {NextIRB.CreatePointerCast(DstPtr, NextIRB.getInt8PtrTy()),
- NextIRB.CreatePointerCast(SrcPtr, NextIRB.getInt8PtrTy()),
- NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
+ NextIRB.CreateCall(
+ DFSF.DFS.DFSanMemShadowOriginTransferFn,
+ {DstPtr, SrcPtr, NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
}
Value *DFSanVisitor::makeAddReleaseOrderingTable(IRBuilder<> &IRB) {
@@ -3264,10 +3248,9 @@ void DFSanVisitor::visitLibAtomicStore(CallBase &CB) {
// TODO: Support ClCombinePointerLabelsOnStore
// TODO: Support ClEventCallbacks
- IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn,
- {IRB.CreatePointerCast(DstPtr, IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(SrcPtr, IRB.getInt8PtrTy()),
- IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
+ IRB.CreateCall(
+ DFSF.DFS.DFSanMemShadowOriginTransferFn,
+ {DstPtr, SrcPtr, IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
}
void DFSanVisitor::visitLibAtomicExchange(CallBase &CB) {
@@ -3285,16 +3268,14 @@ void DFSanVisitor::visitLibAtomicExchange(CallBase &CB) {
// the additional complexity to address this is not warrented.
// Current Target to Dest
- IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn,
- {IRB.CreatePointerCast(DstPtr, IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(TargetPtr, IRB.getInt8PtrTy()),
- IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
+ IRB.CreateCall(
+ DFSF.DFS.DFSanMemShadowOriginTransferFn,
+ {DstPtr, TargetPtr, IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
// Current Src to Target (overriding)
- IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn,
- {IRB.CreatePointerCast(TargetPtr, IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(SrcPtr, IRB.getInt8PtrTy()),
- IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
+ IRB.CreateCall(
+ DFSF.DFS.DFSanMemShadowOriginTransferFn,
+ {TargetPtr, SrcPtr, IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
}
void DFSanVisitor::visitLibAtomicCompareExchange(CallBase &CB) {
@@ -3317,13 +3298,10 @@ void DFSanVisitor::visitLibAtomicCompareExchange(CallBase &CB) {
// If original call returned true, copy Desired to Target.
// If original call returned false, copy Target to Expected.
- NextIRB.CreateCall(
- DFSF.DFS.DFSanMemShadowOriginConditionalExchangeFn,
- {NextIRB.CreateIntCast(&CB, NextIRB.getInt8Ty(), false),
- NextIRB.CreatePointerCast(TargetPtr, NextIRB.getInt8PtrTy()),
- NextIRB.CreatePointerCast(ExpectedPtr, NextIRB.getInt8PtrTy()),
- NextIRB.CreatePointerCast(DesiredPtr, NextIRB.getInt8PtrTy()),
- NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
+ NextIRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginConditionalExchangeFn,
+ {NextIRB.CreateIntCast(&CB, NextIRB.getInt8Ty(), false),
+ TargetPtr, ExpectedPtr, DesiredPtr,
+ NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)});
}
void DFSanVisitor::visitCallBase(CallBase &CB) {
diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
index 21f0b1a92293..1ff0a34bae24 100644
--- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
@@ -148,7 +148,7 @@ private:
std::string mangleName(const DICompileUnit *CU, GCovFileType FileType);
GCOVOptions Options;
- support::endianness Endian;
+ llvm::endianness Endian;
raw_ostream *os;
// Checksum, produced by hash of EdgeDestinations
@@ -750,7 +750,7 @@ static BasicBlock *getInstrBB(CFGMST<Edge, BBInfo> &MST, Edge &E,
#ifndef NDEBUG
static void dumpEdges(CFGMST<Edge, BBInfo> &MST, GCOVFunction &GF) {
size_t ID = 0;
- for (auto &E : make_pointee_range(MST.AllEdges)) {
+ for (const auto &E : make_pointee_range(MST.allEdges())) {
GCOVBlock &Src = E.SrcBB ? GF.getBlock(E.SrcBB) : GF.getEntryBlock();
GCOVBlock &Dst = E.DestBB ? GF.getBlock(E.DestBB) : GF.getReturnBlock();
dbgs() << " Edge " << ID++ << ": " << Src.Number << "->" << Dst.Number
@@ -788,8 +788,8 @@ bool GCOVProfiler::emitProfileNotes(
std::vector<uint8_t> EdgeDestinations;
SmallVector<std::pair<GlobalVariable *, MDNode *>, 8> CountersBySP;
- Endian = M->getDataLayout().isLittleEndian() ? support::endianness::little
- : support::endianness::big;
+ Endian = M->getDataLayout().isLittleEndian() ? llvm::endianness::little
+ : llvm::endianness::big;
unsigned FunctionIdent = 0;
for (auto &F : M->functions()) {
DISubprogram *SP = F.getSubprogram();
@@ -820,8 +820,8 @@ bool GCOVProfiler::emitProfileNotes(
CFGMST<Edge, BBInfo> MST(F, /*InstrumentFuncEntry_=*/false, BPI, BFI);
// getInstrBB can split basic blocks and push elements to AllEdges.
- for (size_t I : llvm::seq<size_t>(0, MST.AllEdges.size())) {
- auto &E = *MST.AllEdges[I];
+ for (size_t I : llvm::seq<size_t>(0, MST.numEdges())) {
+ auto &E = *MST.allEdges()[I];
// For now, disable spanning tree optimization when fork or exec* is
// used.
if (HasExecOrFork)
@@ -836,16 +836,16 @@ bool GCOVProfiler::emitProfileNotes(
// Some non-tree edges are IndirectBr which cannot be split. Ignore them
// as well.
- llvm::erase_if(MST.AllEdges, [](std::unique_ptr<Edge> &E) {
+ llvm::erase_if(MST.allEdges(), [](std::unique_ptr<Edge> &E) {
return E->Removed || (!E->InMST && !E->Place);
});
const size_t Measured =
std::stable_partition(
- MST.AllEdges.begin(), MST.AllEdges.end(),
+ MST.allEdges().begin(), MST.allEdges().end(),
[](std::unique_ptr<Edge> &E) { return E->Place; }) -
- MST.AllEdges.begin();
+ MST.allEdges().begin();
for (size_t I : llvm::seq<size_t>(0, Measured)) {
- Edge &E = *MST.AllEdges[I];
+ Edge &E = *MST.allEdges()[I];
GCOVBlock &Src =
E.SrcBB ? Func.getBlock(E.SrcBB) : Func.getEntryBlock();
GCOVBlock &Dst =
@@ -854,13 +854,13 @@ bool GCOVProfiler::emitProfileNotes(
E.DstNumber = Dst.Number;
}
std::stable_sort(
- MST.AllEdges.begin(), MST.AllEdges.begin() + Measured,
+ MST.allEdges().begin(), MST.allEdges().begin() + Measured,
[](const std::unique_ptr<Edge> &L, const std::unique_ptr<Edge> &R) {
return L->SrcNumber != R->SrcNumber ? L->SrcNumber < R->SrcNumber
: L->DstNumber < R->DstNumber;
});
- for (const Edge &E : make_pointee_range(MST.AllEdges)) {
+ for (const Edge &E : make_pointee_range(MST.allEdges())) {
GCOVBlock &Src =
E.SrcBB ? Func.getBlock(E.SrcBB) : Func.getEntryBlock();
GCOVBlock &Dst =
@@ -898,7 +898,9 @@ bool GCOVProfiler::emitProfileNotes(
if (Line == Loc.getLine()) continue;
Line = Loc.getLine();
- if (SP != getDISubprogram(Loc.getScope()))
+ MDNode *Scope = Loc.getScope();
+ // TODO: Handle blocks from another file due to #line, #include, etc.
+ if (isa<DILexicalBlockFile>(Scope) || SP != getDISubprogram(Scope))
continue;
GCOVLines &Lines = Block.getFile(Filename);
@@ -915,7 +917,7 @@ bool GCOVProfiler::emitProfileNotes(
CountersBySP.emplace_back(Counters, SP);
for (size_t I : llvm::seq<size_t>(0, Measured)) {
- const Edge &E = *MST.AllEdges[I];
+ const Edge &E = *MST.allEdges()[I];
IRBuilder<> Builder(E.Place, E.Place->getFirstInsertionPt());
Value *V = Builder.CreateConstInBoundsGEP2_64(
Counters->getValueType(), Counters, 0, I);
@@ -955,7 +957,7 @@ bool GCOVProfiler::emitProfileNotes(
continue;
}
os = &out;
- if (Endian == support::endianness::big) {
+ if (Endian == llvm::endianness::big) {
out.write("gcno", 4);
out.write(Options.Version, 4);
} else {
@@ -1029,9 +1031,9 @@ void GCOVProfiler::emitGlobalConstructor(
FunctionCallee GCOVProfiler::getStartFileFunc(const TargetLibraryInfo *TLI) {
Type *Args[] = {
- Type::getInt8PtrTy(*Ctx), // const char *orig_filename
- Type::getInt32Ty(*Ctx), // uint32_t version
- Type::getInt32Ty(*Ctx), // uint32_t checksum
+ PointerType::getUnqual(*Ctx), // const char *orig_filename
+ Type::getInt32Ty(*Ctx), // uint32_t version
+ Type::getInt32Ty(*Ctx), // uint32_t checksum
};
FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false);
return M->getOrInsertFunction("llvm_gcda_start_file", FTy,
@@ -1051,8 +1053,8 @@ FunctionCallee GCOVProfiler::getEmitFunctionFunc(const TargetLibraryInfo *TLI) {
FunctionCallee GCOVProfiler::getEmitArcsFunc(const TargetLibraryInfo *TLI) {
Type *Args[] = {
- Type::getInt32Ty(*Ctx), // uint32_t num_counters
- Type::getInt64PtrTy(*Ctx), // uint64_t *counters
+ Type::getInt32Ty(*Ctx), // uint32_t num_counters
+ PointerType::getUnqual(*Ctx), // uint64_t *counters
};
FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false);
return M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy,
@@ -1098,19 +1100,16 @@ Function *GCOVProfiler::insertCounterWriteout(
// Collect the relevant data into a large constant data structure that we can
// walk to write out everything.
StructType *StartFileCallArgsTy = StructType::create(
- {Builder.getInt8PtrTy(), Builder.getInt32Ty(), Builder.getInt32Ty()},
+ {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getInt32Ty()},
"start_file_args_ty");
StructType *EmitFunctionCallArgsTy = StructType::create(
{Builder.getInt32Ty(), Builder.getInt32Ty(), Builder.getInt32Ty()},
"emit_function_args_ty");
- StructType *EmitArcsCallArgsTy = StructType::create(
- {Builder.getInt32Ty(), Builder.getInt64Ty()->getPointerTo()},
- "emit_arcs_args_ty");
- StructType *FileInfoTy =
- StructType::create({StartFileCallArgsTy, Builder.getInt32Ty(),
- EmitFunctionCallArgsTy->getPointerTo(),
- EmitArcsCallArgsTy->getPointerTo()},
- "file_info");
+ auto *PtrTy = Builder.getPtrTy();
+ StructType *EmitArcsCallArgsTy =
+ StructType::create({Builder.getInt32Ty(), PtrTy}, "emit_arcs_args_ty");
+ StructType *FileInfoTy = StructType::create(
+ {StartFileCallArgsTy, Builder.getInt32Ty(), PtrTy, PtrTy}, "file_info");
Constant *Zero32 = Builder.getInt32(0);
// Build an explicit array of two zeros for use in ConstantExpr GEP building.
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
index 28db47a19092..f7f8fed643e9 100644
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -17,9 +17,11 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/StackSafetyAnalysis.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/BinaryFormat/ELF.h"
@@ -42,7 +44,6 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
-#include "llvm/IR/NoFolder.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
@@ -52,6 +53,7 @@
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/MemoryTaggingSupport.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
@@ -134,7 +136,7 @@ static cl::opt<size_t> ClMaxLifetimes(
static cl::opt<bool>
ClUseAfterScope("hwasan-use-after-scope",
cl::desc("detect use after scope within function"),
- cl::Hidden, cl::init(false));
+ cl::Hidden, cl::init(true));
static cl::opt<bool> ClGenerateTagsWithCalls(
"hwasan-generate-tags-with-calls",
@@ -223,6 +225,10 @@ static cl::opt<bool> ClInlineAllChecks("hwasan-inline-all-checks",
cl::desc("inline all checks"),
cl::Hidden, cl::init(false));
+static cl::opt<bool> ClInlineFastPathChecks("hwasan-inline-fast-path-checks",
+ cl::desc("inline all checks"),
+ cl::Hidden, cl::init(false));
+
// Enabled from clang by "-fsanitize-hwaddress-experimental-aliasing".
static cl::opt<bool> ClUsePageAliases("hwasan-experimental-use-page-aliases",
cl::desc("Use page aliasing in HWASan"),
@@ -274,9 +280,18 @@ public:
initializeModule();
}
+ void sanitizeFunction(Function &F, FunctionAnalysisManager &FAM);
+
+private:
+ struct ShadowTagCheckInfo {
+ Instruction *TagMismatchTerm = nullptr;
+ Value *PtrLong = nullptr;
+ Value *AddrLong = nullptr;
+ Value *PtrTag = nullptr;
+ Value *MemTag = nullptr;
+ };
void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; }
- void sanitizeFunction(Function &F, FunctionAnalysisManager &FAM);
void initializeModule();
void createHwasanCtorComdat();
@@ -291,18 +306,24 @@ public:
Value *memToShadow(Value *Shadow, IRBuilder<> &IRB);
int64_t getAccessInfo(bool IsWrite, unsigned AccessSizeIndex);
+ ShadowTagCheckInfo insertShadowTagCheck(Value *Ptr, Instruction *InsertBefore,
+ DomTreeUpdater &DTU, LoopInfo *LI);
void instrumentMemAccessOutline(Value *Ptr, bool IsWrite,
unsigned AccessSizeIndex,
- Instruction *InsertBefore);
+ Instruction *InsertBefore,
+ DomTreeUpdater &DTU, LoopInfo *LI);
void instrumentMemAccessInline(Value *Ptr, bool IsWrite,
unsigned AccessSizeIndex,
- Instruction *InsertBefore);
+ Instruction *InsertBefore, DomTreeUpdater &DTU,
+ LoopInfo *LI);
bool ignoreMemIntrinsic(MemIntrinsic *MI);
void instrumentMemIntrinsic(MemIntrinsic *MI);
- bool instrumentMemAccess(InterestingMemoryOperand &O);
+ bool instrumentMemAccess(InterestingMemoryOperand &O, DomTreeUpdater &DTU,
+ LoopInfo *LI);
bool ignoreAccess(Instruction *Inst, Value *Ptr);
void getInterestingMemoryOperands(
- Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting);
+ Instruction *I, const TargetLibraryInfo &TLI,
+ SmallVectorImpl<InterestingMemoryOperand> &Interesting);
void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size);
Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag);
@@ -332,7 +353,6 @@ public:
void instrumentPersonalityFunctions();
-private:
LLVMContext *C;
Module &M;
const StackSafetyGlobalInfo *SSI;
@@ -364,7 +384,7 @@ private:
Type *VoidTy = Type::getVoidTy(M.getContext());
Type *IntptrTy;
- Type *Int8PtrTy;
+ PointerType *PtrTy;
Type *Int8Ty;
Type *Int32Ty;
Type *Int64Ty = Type::getInt64Ty(M.getContext());
@@ -372,6 +392,7 @@ private:
bool CompileKernel;
bool Recover;
bool OutlinedChecks;
+ bool InlineFastPath;
bool UseShortGranules;
bool InstrumentLandingPads;
bool InstrumentWithCalls;
@@ -420,6 +441,12 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M,
HWASan.sanitizeFunction(F, FAM);
PreservedAnalyses PA = PreservedAnalyses::none();
+ // DominatorTreeAnalysis, PostDominatorTreeAnalysis, and LoopAnalysis
+ // are incrementally updated throughout this pass whenever
+ // SplitBlockAndInsertIfThen is called.
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<PostDominatorTreeAnalysis>();
+ PA.preserve<LoopAnalysis>();
// GlobalsAA is considered stateless and does not get invalidated unless
// explicitly invalidated; PreservedAnalyses::none() is not enough. Sanitizers
// make changes that require GlobalsAA to be invalidated.
@@ -560,7 +587,7 @@ void HWAddressSanitizer::initializeModule() {
C = &(M.getContext());
IRBuilder<> IRB(*C);
IntptrTy = IRB.getIntPtrTy(DL);
- Int8PtrTy = IRB.getInt8PtrTy();
+ PtrTy = IRB.getPtrTy();
Int8Ty = IRB.getInt8Ty();
Int32Ty = IRB.getInt32Ty();
@@ -579,6 +606,13 @@ void HWAddressSanitizer::initializeModule() {
TargetTriple.isOSBinFormatELF() &&
(ClInlineAllChecks.getNumOccurrences() ? !ClInlineAllChecks : !Recover);
+ InlineFastPath =
+ (ClInlineFastPathChecks.getNumOccurrences()
+ ? ClInlineFastPathChecks
+ : !(TargetTriple.isAndroid() ||
+ TargetTriple.isOSFuchsia())); // These platforms may prefer less
+ // inlining to reduce binary size.
+
if (ClMatchAllTag.getNumOccurrences()) {
if (ClMatchAllTag != -1) {
MatchAllTag = ClMatchAllTag & 0xFF;
@@ -633,19 +667,19 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) {
FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false);
HwasanMemoryAccessCallbackFnTy =
FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false);
- HwasanMemTransferFnTy = FunctionType::get(
- Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy, Int8Ty}, false);
- HwasanMemsetFnTy = FunctionType::get(
- Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}, false);
+ HwasanMemTransferFnTy =
+ FunctionType::get(PtrTy, {PtrTy, PtrTy, IntptrTy, Int8Ty}, false);
+ HwasanMemsetFnTy =
+ FunctionType::get(PtrTy, {PtrTy, Int32Ty, IntptrTy, Int8Ty}, false);
} else {
HwasanMemoryAccessCallbackSizedFnTy =
FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false);
HwasanMemoryAccessCallbackFnTy =
FunctionType::get(VoidTy, {IntptrTy}, false);
HwasanMemTransferFnTy =
- FunctionType::get(Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy}, false);
+ FunctionType::get(PtrTy, {PtrTy, PtrTy, IntptrTy}, false);
HwasanMemsetFnTy =
- FunctionType::get(Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy}, false);
+ FunctionType::get(PtrTy, {PtrTy, Int32Ty, IntptrTy}, false);
}
for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) {
@@ -679,7 +713,7 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) {
MemIntrinCallbackPrefix + "memset" + MatchAllStr, HwasanMemsetFnTy);
HwasanTagMemoryFunc = M.getOrInsertFunction("__hwasan_tag_memory", VoidTy,
- Int8PtrTy, Int8Ty, IntptrTy);
+ PtrTy, Int8Ty, IntptrTy);
HwasanGenerateTagFunc =
M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty);
@@ -699,7 +733,7 @@ Value *HWAddressSanitizer::getOpaqueNoopCast(IRBuilder<> &IRB, Value *Val) {
// This prevents code bloat as a result of rematerializing trivial definitions
// such as constants or global addresses at every load and store.
InlineAsm *Asm =
- InlineAsm::get(FunctionType::get(Int8PtrTy, {Val->getType()}, false),
+ InlineAsm::get(FunctionType::get(PtrTy, {Val->getType()}, false),
StringRef(""), StringRef("=r,0"),
/*hasSideEffects=*/false);
return IRB.CreateCall(Asm, {Val}, ".hwasan.shadow");
@@ -713,15 +747,15 @@ Value *HWAddressSanitizer::getShadowNonTls(IRBuilder<> &IRB) {
if (Mapping.Offset != kDynamicShadowSentinel)
return getOpaqueNoopCast(
IRB, ConstantExpr::getIntToPtr(
- ConstantInt::get(IntptrTy, Mapping.Offset), Int8PtrTy));
+ ConstantInt::get(IntptrTy, Mapping.Offset), PtrTy));
if (Mapping.InGlobal)
return getDynamicShadowIfunc(IRB);
Value *GlobalDynamicAddress =
IRB.GetInsertBlock()->getParent()->getParent()->getOrInsertGlobal(
- kHwasanShadowMemoryDynamicAddress, Int8PtrTy);
- return IRB.CreateLoad(Int8PtrTy, GlobalDynamicAddress);
+ kHwasanShadowMemoryDynamicAddress, PtrTy);
+ return IRB.CreateLoad(PtrTy, GlobalDynamicAddress);
}
bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) {
@@ -748,7 +782,8 @@ bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) {
}
void HWAddressSanitizer::getInterestingMemoryOperands(
- Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting) {
+ Instruction *I, const TargetLibraryInfo &TLI,
+ SmallVectorImpl<InterestingMemoryOperand> &Interesting) {
// Skip memory accesses inserted by another instrumentation.
if (I->hasMetadata(LLVMContext::MD_nosanitize))
return;
@@ -786,6 +821,7 @@ void HWAddressSanitizer::getInterestingMemoryOperands(
Type *Ty = CI->getParamByValType(ArgNo);
Interesting.emplace_back(I, ArgNo, false, Ty, Align(1));
}
+ maybeMarkSanitizerLibraryCallNoBuiltin(CI, &TLI);
}
}
@@ -824,7 +860,7 @@ Value *HWAddressSanitizer::memToShadow(Value *Mem, IRBuilder<> &IRB) {
// Mem >> Scale
Value *Shadow = IRB.CreateLShr(Mem, Mapping.Scale);
if (Mapping.Offset == 0)
- return IRB.CreateIntToPtr(Shadow, Int8PtrTy);
+ return IRB.CreateIntToPtr(Shadow, PtrTy);
// (Mem >> Scale) + Offset
return IRB.CreateGEP(Int8Ty, ShadowBase, Shadow);
}
@@ -839,14 +875,48 @@ int64_t HWAddressSanitizer::getAccessInfo(bool IsWrite,
(AccessSizeIndex << HWASanAccessInfo::AccessSizeShift);
}
+HWAddressSanitizer::ShadowTagCheckInfo
+HWAddressSanitizer::insertShadowTagCheck(Value *Ptr, Instruction *InsertBefore,
+ DomTreeUpdater &DTU, LoopInfo *LI) {
+ ShadowTagCheckInfo R;
+
+ IRBuilder<> IRB(InsertBefore);
+
+ R.PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy);
+ R.PtrTag =
+ IRB.CreateTrunc(IRB.CreateLShr(R.PtrLong, PointerTagShift), Int8Ty);
+ R.AddrLong = untagPointer(IRB, R.PtrLong);
+ Value *Shadow = memToShadow(R.AddrLong, IRB);
+ R.MemTag = IRB.CreateLoad(Int8Ty, Shadow);
+ Value *TagMismatch = IRB.CreateICmpNE(R.PtrTag, R.MemTag);
+
+ if (MatchAllTag.has_value()) {
+ Value *TagNotIgnored = IRB.CreateICmpNE(
+ R.PtrTag, ConstantInt::get(R.PtrTag->getType(), *MatchAllTag));
+ TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored);
+ }
+
+ R.TagMismatchTerm = SplitBlockAndInsertIfThen(
+ TagMismatch, InsertBefore, false,
+ MDBuilder(*C).createBranchWeights(1, 100000), &DTU, LI);
+
+ return R;
+}
+
void HWAddressSanitizer::instrumentMemAccessOutline(Value *Ptr, bool IsWrite,
unsigned AccessSizeIndex,
- Instruction *InsertBefore) {
+ Instruction *InsertBefore,
+ DomTreeUpdater &DTU,
+ LoopInfo *LI) {
assert(!UsePageAliases);
const int64_t AccessInfo = getAccessInfo(IsWrite, AccessSizeIndex);
+
+ if (InlineFastPath)
+ InsertBefore =
+ insertShadowTagCheck(Ptr, InsertBefore, DTU, LI).TagMismatchTerm;
+
IRBuilder<> IRB(InsertBefore);
Module *M = IRB.GetInsertBlock()->getParent()->getParent();
- Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy);
IRB.CreateCall(Intrinsic::getDeclaration(
M, UseShortGranules
? Intrinsic::hwasan_check_memaccess_shortgranules
@@ -856,55 +926,38 @@ void HWAddressSanitizer::instrumentMemAccessOutline(Value *Ptr, bool IsWrite,
void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
unsigned AccessSizeIndex,
- Instruction *InsertBefore) {
+ Instruction *InsertBefore,
+ DomTreeUpdater &DTU,
+ LoopInfo *LI) {
assert(!UsePageAliases);
const int64_t AccessInfo = getAccessInfo(IsWrite, AccessSizeIndex);
- IRBuilder<> IRB(InsertBefore);
-
- Value *PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy);
- Value *PtrTag =
- IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift), Int8Ty);
- Value *AddrLong = untagPointer(IRB, PtrLong);
- Value *Shadow = memToShadow(AddrLong, IRB);
- Value *MemTag = IRB.CreateLoad(Int8Ty, Shadow);
- Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag);
-
- if (MatchAllTag.has_value()) {
- Value *TagNotIgnored = IRB.CreateICmpNE(
- PtrTag, ConstantInt::get(PtrTag->getType(), *MatchAllTag));
- TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored);
- }
- Instruction *CheckTerm =
- SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, false,
- MDBuilder(*C).createBranchWeights(1, 100000));
+ ShadowTagCheckInfo TCI = insertShadowTagCheck(Ptr, InsertBefore, DTU, LI);
- IRB.SetInsertPoint(CheckTerm);
+ IRBuilder<> IRB(TCI.TagMismatchTerm);
Value *OutOfShortGranuleTagRange =
- IRB.CreateICmpUGT(MemTag, ConstantInt::get(Int8Ty, 15));
- Instruction *CheckFailTerm =
- SplitBlockAndInsertIfThen(OutOfShortGranuleTagRange, CheckTerm, !Recover,
- MDBuilder(*C).createBranchWeights(1, 100000));
+ IRB.CreateICmpUGT(TCI.MemTag, ConstantInt::get(Int8Ty, 15));
+ Instruction *CheckFailTerm = SplitBlockAndInsertIfThen(
+ OutOfShortGranuleTagRange, TCI.TagMismatchTerm, !Recover,
+ MDBuilder(*C).createBranchWeights(1, 100000), &DTU, LI);
- IRB.SetInsertPoint(CheckTerm);
- Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(PtrLong, 15), Int8Ty);
+ IRB.SetInsertPoint(TCI.TagMismatchTerm);
+ Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(TCI.PtrLong, 15), Int8Ty);
PtrLowBits = IRB.CreateAdd(
PtrLowBits, ConstantInt::get(Int8Ty, (1 << AccessSizeIndex) - 1));
- Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, MemTag);
- SplitBlockAndInsertIfThen(PtrLowBitsOOB, CheckTerm, false,
- MDBuilder(*C).createBranchWeights(1, 100000),
- (DomTreeUpdater *)nullptr, nullptr,
- CheckFailTerm->getParent());
+ Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, TCI.MemTag);
+ SplitBlockAndInsertIfThen(PtrLowBitsOOB, TCI.TagMismatchTerm, false,
+ MDBuilder(*C).createBranchWeights(1, 100000), &DTU,
+ LI, CheckFailTerm->getParent());
- IRB.SetInsertPoint(CheckTerm);
- Value *InlineTagAddr = IRB.CreateOr(AddrLong, 15);
- InlineTagAddr = IRB.CreateIntToPtr(InlineTagAddr, Int8PtrTy);
+ IRB.SetInsertPoint(TCI.TagMismatchTerm);
+ Value *InlineTagAddr = IRB.CreateOr(TCI.AddrLong, 15);
+ InlineTagAddr = IRB.CreateIntToPtr(InlineTagAddr, PtrTy);
Value *InlineTag = IRB.CreateLoad(Int8Ty, InlineTagAddr);
- Value *InlineTagMismatch = IRB.CreateICmpNE(PtrTag, InlineTag);
- SplitBlockAndInsertIfThen(InlineTagMismatch, CheckTerm, false,
- MDBuilder(*C).createBranchWeights(1, 100000),
- (DomTreeUpdater *)nullptr, nullptr,
- CheckFailTerm->getParent());
+ Value *InlineTagMismatch = IRB.CreateICmpNE(TCI.PtrTag, InlineTag);
+ SplitBlockAndInsertIfThen(InlineTagMismatch, TCI.TagMismatchTerm, false,
+ MDBuilder(*C).createBranchWeights(1, 100000), &DTU,
+ LI, CheckFailTerm->getParent());
IRB.SetInsertPoint(CheckFailTerm);
InlineAsm *Asm;
@@ -912,7 +965,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
case Triple::x86_64:
// The signal handler will find the data address in rdi.
Asm = InlineAsm::get(
- FunctionType::get(VoidTy, {PtrLong->getType()}, false),
+ FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false),
"int3\nnopl " +
itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)) +
"(%rax)",
@@ -923,7 +976,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
case Triple::aarch64_be:
// The signal handler will find the data address in x0.
Asm = InlineAsm::get(
- FunctionType::get(VoidTy, {PtrLong->getType()}, false),
+ FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false),
"brk #" + itostr(0x900 + (AccessInfo & HWASanAccessInfo::RuntimeMask)),
"{x0}",
/*hasSideEffects=*/true);
@@ -931,7 +984,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
case Triple::riscv64:
// The signal handler will find the data address in x10.
Asm = InlineAsm::get(
- FunctionType::get(VoidTy, {PtrLong->getType()}, false),
+ FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false),
"ebreak\naddiw x0, x11, " +
itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)),
"{x10}",
@@ -940,9 +993,10 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite,
default:
report_fatal_error("unsupported architecture");
}
- IRB.CreateCall(Asm, PtrLong);
+ IRB.CreateCall(Asm, TCI.PtrLong);
if (Recover)
- cast<BranchInst>(CheckFailTerm)->setSuccessor(0, CheckTerm->getParent());
+ cast<BranchInst>(CheckFailTerm)
+ ->setSuccessor(0, TCI.TagMismatchTerm->getParent());
}
bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) {
@@ -958,40 +1012,28 @@ bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) {
void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) {
IRBuilder<> IRB(MI);
if (isa<MemTransferInst>(MI)) {
- if (UseMatchAllCallback) {
- IRB.CreateCall(
- isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false),
- ConstantInt::get(Int8Ty, *MatchAllTag)});
- } else {
- IRB.CreateCall(
- isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
- }
+ SmallVector<Value *, 4> Args{
+ MI->getOperand(0), MI->getOperand(1),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)};
+
+ if (UseMatchAllCallback)
+ Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag));
+ IRB.CreateCall(isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy, Args);
} else if (isa<MemSetInst>(MI)) {
- if (UseMatchAllCallback) {
- IRB.CreateCall(
- HwasanMemset,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false),
- ConstantInt::get(Int8Ty, *MatchAllTag)});
- } else {
- IRB.CreateCall(
- HwasanMemset,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
- }
+ SmallVector<Value *, 4> Args{
+ MI->getOperand(0),
+ IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)};
+ if (UseMatchAllCallback)
+ Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag));
+ IRB.CreateCall(HwasanMemset, Args);
}
MI->eraseFromParent();
}
-bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) {
+bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O,
+ DomTreeUpdater &DTU,
+ LoopInfo *LI) {
Value *Addr = O.getPtr();
LLVM_DEBUG(dbgs() << "Instrumenting: " << O.getInsn() << "\n");
@@ -1006,34 +1048,26 @@ bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) {
*O.Alignment >= O.TypeStoreSize / 8)) {
size_t AccessSizeIndex = TypeSizeToSizeIndex(O.TypeStoreSize);
if (InstrumentWithCalls) {
- if (UseMatchAllCallback) {
- IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex],
- {IRB.CreatePointerCast(Addr, IntptrTy),
- ConstantInt::get(Int8Ty, *MatchAllTag)});
- } else {
- IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex],
- IRB.CreatePointerCast(Addr, IntptrTy));
- }
+ SmallVector<Value *, 2> Args{IRB.CreatePointerCast(Addr, IntptrTy)};
+ if (UseMatchAllCallback)
+ Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag));
+ IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex],
+ Args);
} else if (OutlinedChecks) {
- instrumentMemAccessOutline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn());
+ instrumentMemAccessOutline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn(),
+ DTU, LI);
} else {
- instrumentMemAccessInline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn());
+ instrumentMemAccessInline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn(),
+ DTU, LI);
}
} else {
- if (UseMatchAllCallback) {
- IRB.CreateCall(
- HwasanMemoryAccessCallbackSized[O.IsWrite],
- {IRB.CreatePointerCast(Addr, IntptrTy),
- IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize),
- ConstantInt::get(IntptrTy, 8)),
- ConstantInt::get(Int8Ty, *MatchAllTag)});
- } else {
- IRB.CreateCall(
- HwasanMemoryAccessCallbackSized[O.IsWrite],
- {IRB.CreatePointerCast(Addr, IntptrTy),
- IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize),
- ConstantInt::get(IntptrTy, 8))});
- }
+ SmallVector<Value *, 3> Args{
+ IRB.CreatePointerCast(Addr, IntptrTy),
+ IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize),
+ ConstantInt::get(IntptrTy, 8))};
+ if (UseMatchAllCallback)
+ Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag));
+ IRB.CreateCall(HwasanMemoryAccessCallbackSized[O.IsWrite], Args);
}
untagPointerOperand(O.getInsn(), Addr);
@@ -1049,7 +1083,7 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag,
Tag = IRB.CreateTrunc(Tag, Int8Ty);
if (InstrumentWithCalls) {
IRB.CreateCall(HwasanTagMemoryFunc,
- {IRB.CreatePointerCast(AI, Int8PtrTy), Tag,
+ {IRB.CreatePointerCast(AI, PtrTy), Tag,
ConstantInt::get(IntptrTy, AlignedSize)});
} else {
size_t ShadowSize = Size >> Mapping.Scale;
@@ -1067,9 +1101,9 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag,
const uint8_t SizeRemainder = Size % Mapping.getObjectAlignment().value();
IRB.CreateStore(ConstantInt::get(Int8Ty, SizeRemainder),
IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize));
- IRB.CreateStore(Tag, IRB.CreateConstGEP1_32(
- Int8Ty, IRB.CreatePointerCast(AI, Int8PtrTy),
- AlignedSize - 1));
+ IRB.CreateStore(
+ Tag, IRB.CreateConstGEP1_32(Int8Ty, IRB.CreatePointerCast(AI, PtrTy),
+ AlignedSize - 1));
}
}
}
@@ -1183,10 +1217,8 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) {
// in Bionic's libc/private/bionic_tls.h.
Function *ThreadPointerFunc =
Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
- Value *SlotPtr = IRB.CreatePointerCast(
- IRB.CreateConstGEP1_32(Int8Ty, IRB.CreateCall(ThreadPointerFunc), 0x30),
- Ty->getPointerTo(0));
- return SlotPtr;
+ return IRB.CreateConstGEP1_32(Int8Ty, IRB.CreateCall(ThreadPointerFunc),
+ 0x30);
}
if (ThreadPtrGlobal)
return ThreadPtrGlobal;
@@ -1208,7 +1240,7 @@ Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) {
Module *M = F->getParent();
auto *GetStackPointerFn = Intrinsic::getDeclaration(
M, Intrinsic::frameaddress,
- IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace()));
+ IRB.getPtrTy(M->getDataLayout().getAllocaAddrSpace()));
CachedSP = IRB.CreatePtrToInt(
IRB.CreateCall(GetStackPointerFn, {Constant::getNullValue(Int32Ty)}),
IntptrTy);
@@ -1271,8 +1303,8 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) {
// Store data to ring buffer.
Value *FrameRecordInfo = getFrameRecordInfo(IRB);
- Value *RecordPtr = IRB.CreateIntToPtr(ThreadLongMaybeUntagged,
- IntptrTy->getPointerTo(0));
+ Value *RecordPtr =
+ IRB.CreateIntToPtr(ThreadLongMaybeUntagged, IRB.getPtrTy(0));
IRB.CreateStore(FrameRecordInfo, RecordPtr);
// Update the ring buffer. Top byte of ThreadLong defines the size of the
@@ -1309,7 +1341,7 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) {
ThreadLongMaybeUntagged,
ConstantInt::get(IntptrTy, (1ULL << kShadowBaseAlignment) - 1)),
ConstantInt::get(IntptrTy, 1), "hwasan.shadow");
- ShadowBase = IRB.CreateIntToPtr(ShadowBase, Int8PtrTy);
+ ShadowBase = IRB.CreateIntToPtr(ShadowBase, PtrTy);
}
}
@@ -1369,7 +1401,7 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
size_t Size = memtag::getAllocaSizeInBytes(*AI);
size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment());
- Value *AICast = IRB.CreatePointerCast(AI, Int8PtrTy);
+ Value *AICast = IRB.CreatePointerCast(AI, PtrTy);
auto HandleLifetime = [&](IntrinsicInst *II) {
// Set the lifetime intrinsic to cover the whole alloca. This reduces the
@@ -1462,6 +1494,7 @@ void HWAddressSanitizer::sanitizeFunction(Function &F,
SmallVector<InterestingMemoryOperand, 16> OperandsToInstrument;
SmallVector<MemIntrinsic *, 16> IntrinToInstrument;
SmallVector<Instruction *, 8> LandingPadVec;
+ const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
memtag::StackInfoBuilder SIB(SSI);
for (auto &Inst : instructions(F)) {
@@ -1472,7 +1505,7 @@ void HWAddressSanitizer::sanitizeFunction(Function &F,
if (InstrumentLandingPads && isa<LandingPadInst>(Inst))
LandingPadVec.push_back(&Inst);
- getInterestingMemoryOperands(&Inst, OperandsToInstrument);
+ getInterestingMemoryOperands(&Inst, TLI, OperandsToInstrument);
if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&Inst))
if (!ignoreMemIntrinsic(MI))
@@ -1528,8 +1561,13 @@ void HWAddressSanitizer::sanitizeFunction(Function &F,
}
}
+ DominatorTree *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);
+ PostDominatorTree *PDT = FAM.getCachedResult<PostDominatorTreeAnalysis>(F);
+ LoopInfo *LI = FAM.getCachedResult<LoopAnalysis>(F);
+ DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
for (auto &Operand : OperandsToInstrument)
- instrumentMemAccess(Operand);
+ instrumentMemAccess(Operand, DTU, LI);
+ DTU.flush();
if (ClInstrumentMemIntrinsics && !IntrinToInstrument.empty()) {
for (auto *Inst : IntrinToInstrument)
@@ -1624,7 +1662,7 @@ void HWAddressSanitizer::instrumentGlobals() {
if (GV.hasSanitizerMetadata() && GV.getSanitizerMetadata().NoHWAddress)
continue;
- if (GV.isDeclarationForLinker() || GV.getName().startswith("llvm.") ||
+ if (GV.isDeclarationForLinker() || GV.getName().starts_with("llvm.") ||
GV.isThreadLocal())
continue;
@@ -1682,8 +1720,8 @@ void HWAddressSanitizer::instrumentPersonalityFunctions() {
return;
FunctionCallee HwasanPersonalityWrapper = M.getOrInsertFunction(
- "__hwasan_personality_wrapper", Int32Ty, Int32Ty, Int32Ty, Int64Ty,
- Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy);
+ "__hwasan_personality_wrapper", Int32Ty, Int32Ty, Int32Ty, Int64Ty, PtrTy,
+ PtrTy, PtrTy, PtrTy, PtrTy);
FunctionCallee UnwindGetGR = M.getOrInsertFunction("_Unwind_GetGR", VoidTy);
FunctionCallee UnwindGetCFA = M.getOrInsertFunction("_Unwind_GetCFA", VoidTy);
@@ -1692,7 +1730,7 @@ void HWAddressSanitizer::instrumentPersonalityFunctions() {
if (P.first)
ThunkName += ("." + P.first->getName()).str();
FunctionType *ThunkFnTy = FunctionType::get(
- Int32Ty, {Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int8PtrTy}, false);
+ Int32Ty, {Int32Ty, Int32Ty, Int64Ty, PtrTy, PtrTy}, false);
bool IsLocal = P.first && (!isa<GlobalValue>(P.first) ||
cast<GlobalValue>(P.first)->hasLocalLinkage());
auto *ThunkFn = Function::Create(ThunkFnTy,
@@ -1710,10 +1748,8 @@ void HWAddressSanitizer::instrumentPersonalityFunctions() {
HwasanPersonalityWrapper,
{ThunkFn->getArg(0), ThunkFn->getArg(1), ThunkFn->getArg(2),
ThunkFn->getArg(3), ThunkFn->getArg(4),
- P.first ? IRB.CreateBitCast(P.first, Int8PtrTy)
- : Constant::getNullValue(Int8PtrTy),
- IRB.CreateBitCast(UnwindGetGR.getCallee(), Int8PtrTy),
- IRB.CreateBitCast(UnwindGetCFA.getCallee(), Int8PtrTy)});
+ P.first ? P.first : Constant::getNullValue(PtrTy),
+ UnwindGetGR.getCallee(), UnwindGetCFA.getCallee()});
WrapperCall->setTailCall();
IRB.CreateRet(WrapperCall);
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017..7344fea17517 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Value.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
if (AttachProfToDirectCall) {
- MDBuilder MDB(NewInst.getContext());
- NewInst.setMetadata(
- LLVMContext::MD_prof,
- MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+ setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
}
using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index a7b1953ce81c..d3282779d9f5 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This pass lowers instrprof_* intrinsics emitted by a frontend for profiling.
+// This pass lowers instrprof_* intrinsics emitted by an instrumentor.
// It also builds the data structures and initialization code needed for
// updating execution counts and emitting the profile at runtime.
//
@@ -14,6 +14,7 @@
#include "llvm/Transforms/Instrumentation/InstrProfiling.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
@@ -23,6 +24,7 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
@@ -47,6 +49,9 @@
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/TargetParser/Triple.h"
+#include "llvm/Transforms/Instrumentation.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include <algorithm>
@@ -190,7 +195,8 @@ public:
auto *OrigBiasInst = dyn_cast<BinaryOperator>(AddrInst->getOperand(0));
assert(OrigBiasInst->getOpcode() == Instruction::BinaryOps::Add);
Value *BiasInst = Builder.Insert(OrigBiasInst->clone());
- Addr = Builder.CreateIntToPtr(BiasInst, Ty->getPointerTo());
+ Addr = Builder.CreateIntToPtr(BiasInst,
+ PointerType::getUnqual(Ty->getContext()));
}
if (AtomicCounterUpdatePromoted)
// automic update currently can only be promoted across the current
@@ -241,7 +247,10 @@ public:
return;
for (BasicBlock *ExitBlock : LoopExitBlocks) {
- if (BlockSet.insert(ExitBlock).second) {
+ if (BlockSet.insert(ExitBlock).second &&
+ llvm::none_of(predecessors(ExitBlock), [&](const BasicBlock *Pred) {
+ return llvm::isPresplitCoroSuspendExitEdge(*Pred, *ExitBlock);
+ })) {
ExitBlocks.push_back(ExitBlock);
InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
}
@@ -430,6 +439,15 @@ bool InstrProfiling::lowerIntrinsics(Function *F) {
} else if (auto *IPVP = dyn_cast<InstrProfValueProfileInst>(&Instr)) {
lowerValueProfileInst(IPVP);
MadeChange = true;
+ } else if (auto *IPMP = dyn_cast<InstrProfMCDCBitmapParameters>(&Instr)) {
+ IPMP->eraseFromParent();
+ MadeChange = true;
+ } else if (auto *IPBU = dyn_cast<InstrProfMCDCTVBitmapUpdate>(&Instr)) {
+ lowerMCDCTestVectorBitmapUpdate(IPBU);
+ MadeChange = true;
+ } else if (auto *IPTU = dyn_cast<InstrProfMCDCCondBitmapUpdate>(&Instr)) {
+ lowerMCDCCondBitmapUpdate(IPTU);
+ MadeChange = true;
}
}
}
@@ -544,19 +562,27 @@ bool InstrProfiling::run(
// the instrumented function. This is counting the number of instrumented
// target value sites to enter it as field in the profile data variable.
for (Function &F : M) {
- InstrProfInstBase *FirstProfInst = nullptr;
- for (BasicBlock &BB : F)
- for (auto I = BB.begin(), E = BB.end(); I != E; I++)
+ InstrProfCntrInstBase *FirstProfInst = nullptr;
+ for (BasicBlock &BB : F) {
+ for (auto I = BB.begin(), E = BB.end(); I != E; I++) {
if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(I))
computeNumValueSiteCounts(Ind);
- else if (FirstProfInst == nullptr &&
- (isa<InstrProfIncrementInst>(I) || isa<InstrProfCoverInst>(I)))
- FirstProfInst = dyn_cast<InstrProfInstBase>(I);
+ else {
+ if (FirstProfInst == nullptr &&
+ (isa<InstrProfIncrementInst>(I) || isa<InstrProfCoverInst>(I)))
+ FirstProfInst = dyn_cast<InstrProfCntrInstBase>(I);
+ // If the MCDCBitmapParameters intrinsic seen, create the bitmaps.
+ if (const auto &Params = dyn_cast<InstrProfMCDCBitmapParameters>(I))
+ static_cast<void>(getOrCreateRegionBitmaps(Params));
+ }
+ }
+ }
- // Value profiling intrinsic lowering requires per-function profile data
- // variable to be created first.
- if (FirstProfInst != nullptr)
+ // Use a profile intrinsic to create the region counters and data variable.
+ // Also create the data variable based on the MCDCParams.
+ if (FirstProfInst != nullptr) {
static_cast<void>(getOrCreateRegionCounters(FirstProfInst));
+ }
}
for (Function &F : M)
@@ -651,15 +677,11 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) {
SmallVector<OperandBundleDef, 1> OpBundles;
Ind->getOperandBundlesAsDefs(OpBundles);
if (!IsMemOpSize) {
- Value *Args[3] = {Ind->getTargetValue(),
- Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()),
- Builder.getInt32(Index)};
+ Value *Args[3] = {Ind->getTargetValue(), DataVar, Builder.getInt32(Index)};
Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), Args,
OpBundles);
} else {
- Value *Args[3] = {Ind->getTargetValue(),
- Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()),
- Builder.getInt32(Index)};
+ Value *Args[3] = {Ind->getTargetValue(), DataVar, Builder.getInt32(Index)};
Call = Builder.CreateCall(
getOrInsertValueProfilingCall(*M, *TLI, ValueProfilingCallType::MemOp),
Args, OpBundles);
@@ -670,7 +692,7 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) {
Ind->eraseFromParent();
}
-Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) {
+Value *InstrProfiling::getCounterAddress(InstrProfCntrInstBase *I) {
auto *Counters = getOrCreateRegionCounters(I);
IRBuilder<> Builder(I);
@@ -710,6 +732,25 @@ Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) {
return Builder.CreateIntToPtr(Add, Addr->getType());
}
+Value *InstrProfiling::getBitmapAddress(InstrProfMCDCTVBitmapUpdate *I) {
+ auto *Bitmaps = getOrCreateRegionBitmaps(I);
+ IRBuilder<> Builder(I);
+
+ auto *Addr = Builder.CreateConstInBoundsGEP2_32(
+ Bitmaps->getValueType(), Bitmaps, 0, I->getBitmapIndex()->getZExtValue());
+
+ if (isRuntimeCounterRelocationEnabled()) {
+ LLVMContext &Ctx = M->getContext();
+ Ctx.diagnose(DiagnosticInfoPGOProfile(
+ M->getName().data(),
+ Twine("Runtime counter relocation is presently not supported for MC/DC "
+ "bitmaps."),
+ DS_Warning));
+ }
+
+ return Addr;
+}
+
void InstrProfiling::lowerCover(InstrProfCoverInst *CoverInstruction) {
auto *Addr = getCounterAddress(CoverInstruction);
IRBuilder<> Builder(CoverInstruction);
@@ -769,6 +810,86 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) {
CoverageNamesVar->eraseFromParent();
}
+void InstrProfiling::lowerMCDCTestVectorBitmapUpdate(
+ InstrProfMCDCTVBitmapUpdate *Update) {
+ IRBuilder<> Builder(Update);
+ auto *Int8Ty = Type::getInt8Ty(M->getContext());
+ auto *Int8PtrTy = PointerType::getUnqual(M->getContext());
+ auto *Int32Ty = Type::getInt32Ty(M->getContext());
+ auto *Int64Ty = Type::getInt64Ty(M->getContext());
+ auto *MCDCCondBitmapAddr = Update->getMCDCCondBitmapAddr();
+ auto *BitmapAddr = getBitmapAddress(Update);
+
+ // Load Temp Val.
+ // %mcdc.temp = load i32, ptr %mcdc.addr, align 4
+ auto *Temp = Builder.CreateLoad(Int32Ty, MCDCCondBitmapAddr, "mcdc.temp");
+
+ // Calculate byte offset using div8.
+ // %1 = lshr i32 %mcdc.temp, 3
+ auto *BitmapByteOffset = Builder.CreateLShr(Temp, 0x3);
+
+ // Add byte offset to section base byte address.
+ // %2 = zext i32 %1 to i64
+ // %3 = add i64 ptrtoint (ptr @__profbm_test to i64), %2
+ auto *BitmapByteAddr =
+ Builder.CreateAdd(Builder.CreatePtrToInt(BitmapAddr, Int64Ty),
+ Builder.CreateZExtOrBitCast(BitmapByteOffset, Int64Ty));
+
+ // Convert to a pointer.
+ // %4 = inttoptr i32 %3 to ptr
+ BitmapByteAddr = Builder.CreateIntToPtr(BitmapByteAddr, Int8PtrTy);
+
+ // Calculate bit offset into bitmap byte by using div8 remainder (AND ~8)
+ // %5 = and i32 %mcdc.temp, 7
+ // %6 = trunc i32 %5 to i8
+ auto *BitToSet = Builder.CreateTrunc(Builder.CreateAnd(Temp, 0x7), Int8Ty);
+
+ // Shift bit offset left to form a bitmap.
+ // %7 = shl i8 1, %6
+ auto *ShiftedVal = Builder.CreateShl(Builder.getInt8(0x1), BitToSet);
+
+ // Load profile bitmap byte.
+ // %mcdc.bits = load i8, ptr %4, align 1
+ auto *Bitmap = Builder.CreateLoad(Int8Ty, BitmapByteAddr, "mcdc.bits");
+
+ // Perform logical OR of profile bitmap byte and shifted bit offset.
+ // %8 = or i8 %mcdc.bits, %7
+ auto *Result = Builder.CreateOr(Bitmap, ShiftedVal);
+
+ // Store the updated profile bitmap byte.
+ // store i8 %8, ptr %3, align 1
+ Builder.CreateStore(Result, BitmapByteAddr);
+ Update->eraseFromParent();
+}
+
+void InstrProfiling::lowerMCDCCondBitmapUpdate(
+ InstrProfMCDCCondBitmapUpdate *Update) {
+ IRBuilder<> Builder(Update);
+ auto *Int32Ty = Type::getInt32Ty(M->getContext());
+ auto *MCDCCondBitmapAddr = Update->getMCDCCondBitmapAddr();
+
+ // Load the MCDC temporary value from the stack.
+ // %mcdc.temp = load i32, ptr %mcdc.addr, align 4
+ auto *Temp = Builder.CreateLoad(Int32Ty, MCDCCondBitmapAddr, "mcdc.temp");
+
+ // Zero-extend the evaluated condition boolean value (0 or 1) by 32bits.
+ // %1 = zext i1 %tobool to i32
+ auto *CondV_32 = Builder.CreateZExt(Update->getCondBool(), Int32Ty);
+
+ // Shift the boolean value left (by the condition's ID) to form a bitmap.
+ // %2 = shl i32 %1, <Update->getCondID()>
+ auto *ShiftedVal = Builder.CreateShl(CondV_32, Update->getCondID());
+
+ // Perform logical OR of the bitmap against the loaded MCDC temporary value.
+ // %3 = or i32 %mcdc.temp, %2
+ auto *Result = Builder.CreateOr(Temp, ShiftedVal);
+
+ // Store the updated temporary value back to the stack.
+ // store i32 %3, ptr %mcdc.addr, align 4
+ Builder.CreateStore(Result, MCDCCondBitmapAddr);
+ Update->eraseFromParent();
+}
+
/// Get the name of a profiling variable for a particular function.
static std::string getVarName(InstrProfInstBase *Inc, StringRef Prefix,
bool &Renamed) {
@@ -784,7 +905,7 @@ static std::string getVarName(InstrProfInstBase *Inc, StringRef Prefix,
Renamed = true;
uint64_t FuncHash = Inc->getHash()->getZExtValue();
SmallVector<char, 24> HashPostfix;
- if (Name.endswith((Twine(".") + Twine(FuncHash)).toStringRef(HashPostfix)))
+ if (Name.ends_with((Twine(".") + Twine(FuncHash)).toStringRef(HashPostfix)))
return (Prefix + Name).str();
return (Prefix + Name + "." + Twine(FuncHash)).str();
}
@@ -878,7 +999,7 @@ static inline bool shouldUsePublicSymbol(Function *Fn) {
}
static inline Constant *getFuncAddrForProfData(Function *Fn) {
- auto *Int8PtrTy = Type::getInt8PtrTy(Fn->getContext());
+ auto *Int8PtrTy = PointerType::getUnqual(Fn->getContext());
// Store a nullptr in __llvm_profd, if we shouldn't use a real address
if (!shouldRecordFunctionAddr(Fn))
return ConstantPointerNull::get(Int8PtrTy);
@@ -886,7 +1007,7 @@ static inline Constant *getFuncAddrForProfData(Function *Fn) {
// If we can't use an alias, we must use the public symbol, even though this
// may require a symbolic relocation.
if (shouldUsePublicSymbol(Fn))
- return ConstantExpr::getBitCast(Fn, Int8PtrTy);
+ return Fn;
// When possible use a private alias to avoid symbolic relocations.
auto *GA = GlobalAlias::create(GlobalValue::LinkageTypes::PrivateLinkage,
@@ -909,7 +1030,7 @@ static inline Constant *getFuncAddrForProfData(Function *Fn) {
// appendToCompilerUsed(*Fn->getParent(), {GA});
- return ConstantExpr::getBitCast(GA, Int8PtrTy);
+ return GA;
}
static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) {
@@ -924,37 +1045,31 @@ static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) {
return true;
}
-GlobalVariable *
-InstrProfiling::createRegionCounters(InstrProfInstBase *Inc, StringRef Name,
- GlobalValue::LinkageTypes Linkage) {
- uint64_t NumCounters = Inc->getNumCounters()->getZExtValue();
- auto &Ctx = M->getContext();
- GlobalVariable *GV;
- if (isa<InstrProfCoverInst>(Inc)) {
- auto *CounterTy = Type::getInt8Ty(Ctx);
- auto *CounterArrTy = ArrayType::get(CounterTy, NumCounters);
- // TODO: `Constant::getAllOnesValue()` does not yet accept an array type.
- std::vector<Constant *> InitialValues(NumCounters,
- Constant::getAllOnesValue(CounterTy));
- GV = new GlobalVariable(*M, CounterArrTy, false, Linkage,
- ConstantArray::get(CounterArrTy, InitialValues),
- Name);
- GV->setAlignment(Align(1));
- } else {
- auto *CounterTy = ArrayType::get(Type::getInt64Ty(Ctx), NumCounters);
- GV = new GlobalVariable(*M, CounterTy, false, Linkage,
- Constant::getNullValue(CounterTy), Name);
- GV->setAlignment(Align(8));
- }
- return GV;
+void InstrProfiling::maybeSetComdat(GlobalVariable *GV, Function *Fn,
+ StringRef VarName) {
+ bool DataReferencedByCode = profDataReferencedByCode(*M);
+ bool NeedComdat = needsComdatForCounter(*Fn, *M);
+ bool UseComdat = (NeedComdat || TT.isOSBinFormatELF());
+
+ if (!UseComdat)
+ return;
+
+ StringRef GroupName =
+ TT.isOSBinFormatCOFF() && DataReferencedByCode ? GV->getName() : VarName;
+ Comdat *C = M->getOrInsertComdat(GroupName);
+ if (!NeedComdat)
+ C->setSelectionKind(Comdat::NoDeduplicate);
+ GV->setComdat(C);
+ // COFF doesn't allow the comdat group leader to have private linkage, so
+ // upgrade private linkage to internal linkage to produce a symbol table
+ // entry.
+ if (TT.isOSBinFormatCOFF() && GV->hasPrivateLinkage())
+ GV->setLinkage(GlobalValue::InternalLinkage);
}
-GlobalVariable *
-InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
+GlobalVariable *InstrProfiling::setupProfileSection(InstrProfInstBase *Inc,
+ InstrProfSectKind IPSK) {
GlobalVariable *NamePtr = Inc->getName();
- auto &PD = ProfileDataMap[NamePtr];
- if (PD.RegionCounters)
- return PD.RegionCounters;
// Match the linkage and visibility of the name global.
Function *Fn = Inc->getParent()->getParent();
@@ -993,42 +1108,101 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
// nodeduplicate COMDAT which is lowered to a zero-flag section group. This
// allows -z start-stop-gc to discard the entire group when the function is
// discarded.
- bool DataReferencedByCode = profDataReferencedByCode(*M);
- bool NeedComdat = needsComdatForCounter(*Fn, *M);
bool Renamed;
- std::string CntsVarName =
- getVarName(Inc, getInstrProfCountersVarPrefix(), Renamed);
- std::string DataVarName =
- getVarName(Inc, getInstrProfDataVarPrefix(), Renamed);
- auto MaybeSetComdat = [&](GlobalVariable *GV) {
- bool UseComdat = (NeedComdat || TT.isOSBinFormatELF());
- if (UseComdat) {
- StringRef GroupName = TT.isOSBinFormatCOFF() && DataReferencedByCode
- ? GV->getName()
- : CntsVarName;
- Comdat *C = M->getOrInsertComdat(GroupName);
- if (!NeedComdat)
- C->setSelectionKind(Comdat::NoDeduplicate);
- GV->setComdat(C);
- // COFF doesn't allow the comdat group leader to have private linkage, so
- // upgrade private linkage to internal linkage to produce a symbol table
- // entry.
- if (TT.isOSBinFormatCOFF() && GV->hasPrivateLinkage())
- GV->setLinkage(GlobalValue::InternalLinkage);
- }
- };
+ GlobalVariable *Ptr;
+ StringRef VarPrefix;
+ std::string VarName;
+ if (IPSK == IPSK_cnts) {
+ VarPrefix = getInstrProfCountersVarPrefix();
+ VarName = getVarName(Inc, VarPrefix, Renamed);
+ InstrProfCntrInstBase *CntrIncrement = dyn_cast<InstrProfCntrInstBase>(Inc);
+ Ptr = createRegionCounters(CntrIncrement, VarName, Linkage);
+ } else if (IPSK == IPSK_bitmap) {
+ VarPrefix = getInstrProfBitmapVarPrefix();
+ VarName = getVarName(Inc, VarPrefix, Renamed);
+ InstrProfMCDCBitmapInstBase *BitmapUpdate =
+ dyn_cast<InstrProfMCDCBitmapInstBase>(Inc);
+ Ptr = createRegionBitmaps(BitmapUpdate, VarName, Linkage);
+ } else {
+ llvm_unreachable("Profile Section must be for Counters or Bitmaps");
+ }
+
+ Ptr->setVisibility(Visibility);
+ // Put the counters and bitmaps in their own sections so linkers can
+ // remove unneeded sections.
+ Ptr->setSection(getInstrProfSectionName(IPSK, TT.getObjectFormat()));
+ Ptr->setLinkage(Linkage);
+ maybeSetComdat(Ptr, Fn, VarName);
+ return Ptr;
+}
+
+GlobalVariable *
+InstrProfiling::createRegionBitmaps(InstrProfMCDCBitmapInstBase *Inc,
+ StringRef Name,
+ GlobalValue::LinkageTypes Linkage) {
+ uint64_t NumBytes = Inc->getNumBitmapBytes()->getZExtValue();
+ auto *BitmapTy = ArrayType::get(Type::getInt8Ty(M->getContext()), NumBytes);
+ auto GV = new GlobalVariable(*M, BitmapTy, false, Linkage,
+ Constant::getNullValue(BitmapTy), Name);
+ GV->setAlignment(Align(1));
+ return GV;
+}
+
+GlobalVariable *
+InstrProfiling::getOrCreateRegionBitmaps(InstrProfMCDCBitmapInstBase *Inc) {
+ GlobalVariable *NamePtr = Inc->getName();
+ auto &PD = ProfileDataMap[NamePtr];
+ if (PD.RegionBitmaps)
+ return PD.RegionBitmaps;
+
+ // If RegionBitmaps doesn't already exist, create it by first setting up
+ // the corresponding profile section.
+ auto *BitmapPtr = setupProfileSection(Inc, IPSK_bitmap);
+ PD.RegionBitmaps = BitmapPtr;
+ PD.NumBitmapBytes = Inc->getNumBitmapBytes()->getZExtValue();
+ return PD.RegionBitmaps;
+}
+GlobalVariable *
+InstrProfiling::createRegionCounters(InstrProfCntrInstBase *Inc, StringRef Name,
+ GlobalValue::LinkageTypes Linkage) {
uint64_t NumCounters = Inc->getNumCounters()->getZExtValue();
- LLVMContext &Ctx = M->getContext();
+ auto &Ctx = M->getContext();
+ GlobalVariable *GV;
+ if (isa<InstrProfCoverInst>(Inc)) {
+ auto *CounterTy = Type::getInt8Ty(Ctx);
+ auto *CounterArrTy = ArrayType::get(CounterTy, NumCounters);
+ // TODO: `Constant::getAllOnesValue()` does not yet accept an array type.
+ std::vector<Constant *> InitialValues(NumCounters,
+ Constant::getAllOnesValue(CounterTy));
+ GV = new GlobalVariable(*M, CounterArrTy, false, Linkage,
+ ConstantArray::get(CounterArrTy, InitialValues),
+ Name);
+ GV->setAlignment(Align(1));
+ } else {
+ auto *CounterTy = ArrayType::get(Type::getInt64Ty(Ctx), NumCounters);
+ GV = new GlobalVariable(*M, CounterTy, false, Linkage,
+ Constant::getNullValue(CounterTy), Name);
+ GV->setAlignment(Align(8));
+ }
+ return GV;
+}
+
+GlobalVariable *
+InstrProfiling::getOrCreateRegionCounters(InstrProfCntrInstBase *Inc) {
+ GlobalVariable *NamePtr = Inc->getName();
+ auto &PD = ProfileDataMap[NamePtr];
+ if (PD.RegionCounters)
+ return PD.RegionCounters;
- auto *CounterPtr = createRegionCounters(Inc, CntsVarName, Linkage);
- CounterPtr->setVisibility(Visibility);
- CounterPtr->setSection(
- getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat()));
- CounterPtr->setLinkage(Linkage);
- MaybeSetComdat(CounterPtr);
+ // If RegionCounters doesn't already exist, create it by first setting up
+ // the corresponding profile section.
+ auto *CounterPtr = setupProfileSection(Inc, IPSK_cnts);
PD.RegionCounters = CounterPtr;
+
if (DebugInfoCorrelate) {
+ LLVMContext &Ctx = M->getContext();
+ Function *Fn = Inc->getParent()->getParent();
if (auto *SP = Fn->getSubprogram()) {
DIBuilder DB(*M, true, SP->getUnit());
Metadata *FunctionNameAnnotation[] = {
@@ -1056,16 +1230,58 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
Annotations);
CounterPtr->addDebugInfo(DICounter);
DB.finalize();
- } else {
- std::string Msg = ("Missing debug info for function " + Fn->getName() +
- "; required for profile correlation.")
- .str();
- Ctx.diagnose(
- DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
}
+
+ // Mark the counter variable as used so that it isn't optimized out.
+ CompilerUsedVars.push_back(PD.RegionCounters);
}
- auto *Int8PtrTy = Type::getInt8PtrTy(Ctx);
+ // Create the data variable (if it doesn't already exist).
+ createDataVariable(Inc);
+
+ return PD.RegionCounters;
+}
+
+void InstrProfiling::createDataVariable(InstrProfCntrInstBase *Inc) {
+ // When debug information is correlated to profile data, a data variable
+ // is not needed.
+ if (DebugInfoCorrelate)
+ return;
+
+ GlobalVariable *NamePtr = Inc->getName();
+ auto &PD = ProfileDataMap[NamePtr];
+
+ // Return if data variable was already created.
+ if (PD.DataVar)
+ return;
+
+ LLVMContext &Ctx = M->getContext();
+
+ Function *Fn = Inc->getParent()->getParent();
+ GlobalValue::LinkageTypes Linkage = NamePtr->getLinkage();
+ GlobalValue::VisibilityTypes Visibility = NamePtr->getVisibility();
+
+ // Due to the limitation of binder as of 2021/09/28, the duplicate weak
+ // symbols in the same csect won't be discarded. When there are duplicate weak
+ // symbols, we can NOT guarantee that the relocations get resolved to the
+ // intended weak symbol, so we can not ensure the correctness of the relative
+ // CounterPtr, so we have to use private linkage for counter and data symbols.
+ if (TT.isOSBinFormatXCOFF()) {
+ Linkage = GlobalValue::PrivateLinkage;
+ Visibility = GlobalValue::DefaultVisibility;
+ }
+
+ bool DataReferencedByCode = profDataReferencedByCode(*M);
+ bool NeedComdat = needsComdatForCounter(*Fn, *M);
+ bool Renamed;
+
+ // The Data Variable section is anchored to profile counters.
+ std::string CntsVarName =
+ getVarName(Inc, getInstrProfCountersVarPrefix(), Renamed);
+ std::string DataVarName =
+ getVarName(Inc, getInstrProfDataVarPrefix(), Renamed);
+
+ auto *Int8PtrTy = PointerType::getUnqual(Ctx);
// Allocate statically the array of pointers to value profile nodes for
// the current function.
Constant *ValuesPtrExpr = ConstantPointerNull::get(Int8PtrTy);
@@ -1079,19 +1295,18 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
*M, ValuesTy, false, Linkage, Constant::getNullValue(ValuesTy),
getVarName(Inc, getInstrProfValuesVarPrefix(), Renamed));
ValuesVar->setVisibility(Visibility);
+ setGlobalVariableLargeSection(TT, *ValuesVar);
ValuesVar->setSection(
getInstrProfSectionName(IPSK_vals, TT.getObjectFormat()));
ValuesVar->setAlignment(Align(8));
- MaybeSetComdat(ValuesVar);
- ValuesPtrExpr =
- ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx));
+ maybeSetComdat(ValuesVar, Fn, CntsVarName);
+ ValuesPtrExpr = ValuesVar;
}
- if (DebugInfoCorrelate) {
- // Mark the counter variable as used so that it isn't optimized out.
- CompilerUsedVars.push_back(PD.RegionCounters);
- return PD.RegionCounters;
- }
+ uint64_t NumCounters = Inc->getNumCounters()->getZExtValue();
+ auto *CounterPtr = PD.RegionCounters;
+
+ uint64_t NumBitmapBytes = PD.NumBitmapBytes;
// Create data variable.
auto *IntPtrTy = M->getDataLayout().getIntPtrType(M->getContext());
@@ -1134,6 +1349,16 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
ConstantExpr::getSub(ConstantExpr::getPtrToInt(CounterPtr, IntPtrTy),
ConstantExpr::getPtrToInt(Data, IntPtrTy));
+ // Bitmaps are relative to the same data variable as profile counters.
+ GlobalVariable *BitmapPtr = PD.RegionBitmaps;
+ Constant *RelativeBitmapPtr = ConstantInt::get(IntPtrTy, 0);
+
+ if (BitmapPtr != nullptr) {
+ RelativeBitmapPtr =
+ ConstantExpr::getSub(ConstantExpr::getPtrToInt(BitmapPtr, IntPtrTy),
+ ConstantExpr::getPtrToInt(Data, IntPtrTy));
+ }
+
Constant *DataVals[] = {
#define INSTR_PROF_DATA(Type, LLVMType, Name, Init) Init,
#include "llvm/ProfileData/InstrProfData.inc"
@@ -1143,7 +1368,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
Data->setVisibility(Visibility);
Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat()));
Data->setAlignment(Align(INSTR_PROF_DATA_ALIGNMENT));
- MaybeSetComdat(Data);
+ maybeSetComdat(Data, Fn, CntsVarName);
PD.DataVar = Data;
@@ -1155,8 +1380,6 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) {
NamePtr->setLinkage(GlobalValue::PrivateLinkage);
// Collect the referenced names to be used by emitNameData.
ReferencedNames.push_back(NamePtr);
-
- return PD.RegionCounters;
}
void InstrProfiling::emitVNodes() {
@@ -1201,6 +1424,7 @@ void InstrProfiling::emitVNodes() {
auto *VNodesVar = new GlobalVariable(
*M, VNodesTy, false, GlobalValue::PrivateLinkage,
Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName());
+ setGlobalVariableLargeSection(TT, *VNodesVar);
VNodesVar->setSection(
getInstrProfSectionName(IPSK_vnodes, TT.getObjectFormat()));
VNodesVar->setAlignment(M->getDataLayout().getABITypeAlign(VNodesTy));
@@ -1228,6 +1452,7 @@ void InstrProfiling::emitNameData() {
GlobalValue::PrivateLinkage, NamesVal,
getInstrProfNamesVarName());
NamesSize = CompressedNameStr.size();
+ setGlobalVariableLargeSection(TT, *NamesVar);
NamesVar->setSection(
getInstrProfSectionName(IPSK_name, TT.getObjectFormat()));
// On COFF, it's important to reduce the alignment down to 1 to prevent the
@@ -1248,7 +1473,7 @@ void InstrProfiling::emitRegistration() {
// Construct the function.
auto *VoidTy = Type::getVoidTy(M->getContext());
- auto *VoidPtrTy = Type::getInt8PtrTy(M->getContext());
+ auto *VoidPtrTy = PointerType::getUnqual(M->getContext());
auto *Int64Ty = Type::getInt64Ty(M->getContext());
auto *RegisterFTy = FunctionType::get(VoidTy, false);
auto *RegisterF = Function::Create(RegisterFTy, GlobalValue::InternalLinkage,
@@ -1265,10 +1490,10 @@ void InstrProfiling::emitRegistration() {
IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", RegisterF));
for (Value *Data : CompilerUsedVars)
if (!isa<Function>(Data))
- IRB.CreateCall(RuntimeRegisterF, IRB.CreateBitCast(Data, VoidPtrTy));
+ IRB.CreateCall(RuntimeRegisterF, Data);
for (Value *Data : UsedVars)
if (Data != NamesVar && !isa<Function>(Data))
- IRB.CreateCall(RuntimeRegisterF, IRB.CreateBitCast(Data, VoidPtrTy));
+ IRB.CreateCall(RuntimeRegisterF, Data);
if (NamesVar) {
Type *ParamTypes[] = {VoidPtrTy, Int64Ty};
@@ -1277,8 +1502,7 @@ void InstrProfiling::emitRegistration() {
auto *NamesRegisterF =
Function::Create(NamesRegisterTy, GlobalVariable::ExternalLinkage,
getInstrProfNamesRegFuncName(), M);
- IRB.CreateCall(NamesRegisterF, {IRB.CreateBitCast(NamesVar, VoidPtrTy),
- IRB.getInt64(NamesSize)});
+ IRB.CreateCall(NamesRegisterF, {NamesVar, IRB.getInt64(NamesSize)});
}
IRB.CreateRetVoid();
diff --git a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
index 806afc8fcdf7..199afbe966dd 100644
--- a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
@@ -85,3 +85,10 @@ Comdat *llvm::getOrCreateFunctionComdat(Function &F, Triple &T) {
return C;
}
+void llvm::setGlobalVariableLargeSection(Triple &TargetTriple,
+ GlobalVariable &GV) {
+ if (TargetTriple.getArch() == Triple::x86_64 &&
+ TargetTriple.getObjectFormat() == Triple::ELF) {
+ GV.setCodeModel(CodeModel::Large);
+ }
+}
diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
index 789ed005d03d..539b7441d24b 100644
--- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
@@ -182,6 +182,7 @@ public:
C = &(M.getContext());
LongSize = M.getDataLayout().getPointerSizeInBits();
IntptrTy = Type::getIntNTy(*C, LongSize);
+ PtrTy = PointerType::getUnqual(*C);
}
/// If it is an interesting memory access, populate information
@@ -209,6 +210,7 @@ private:
LLVMContext *C;
int LongSize;
Type *IntptrTy;
+ PointerType *PtrTy;
ShadowMapping Mapping;
// These arrays is indexed by AccessIsWrite
@@ -267,15 +269,13 @@ Value *MemProfiler::memToShadow(Value *Shadow, IRBuilder<> &IRB) {
void MemProfiler::instrumentMemIntrinsic(MemIntrinsic *MI) {
IRBuilder<> IRB(MI);
if (isa<MemTransferInst>(MI)) {
- IRB.CreateCall(
- isa<MemMoveInst>(MI) ? MemProfMemmove : MemProfMemcpy,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+ IRB.CreateCall(isa<MemMoveInst>(MI) ? MemProfMemmove : MemProfMemcpy,
+ {MI->getOperand(0), MI->getOperand(1),
+ IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
} else if (isa<MemSetInst>(MI)) {
IRB.CreateCall(
MemProfMemset,
- {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+ {MI->getOperand(0),
IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
}
@@ -364,13 +364,13 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const {
StringRef SectionName = GV->getSection();
// Check if the global is in the PGO counters section.
auto OF = Triple(I->getModule()->getTargetTriple()).getObjectFormat();
- if (SectionName.endswith(
+ if (SectionName.ends_with(
getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
return std::nullopt;
}
// Do not instrument accesses to LLVM internal variables.
- if (GV->getName().startswith("__llvm"))
+ if (GV->getName().starts_with("__llvm"))
return std::nullopt;
}
@@ -519,14 +519,12 @@ void MemProfiler::initializeCallbacks(Module &M) {
FunctionType::get(IRB.getVoidTy(), Args1, false));
}
MemProfMemmove = M.getOrInsertFunction(
- ClMemoryAccessCallbackPrefix + "memmove", IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy);
+ ClMemoryAccessCallbackPrefix + "memmove", PtrTy, PtrTy, PtrTy, IntptrTy);
MemProfMemcpy = M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "memcpy",
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IntptrTy);
- MemProfMemset = M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "memset",
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(),
- IRB.getInt32Ty(), IntptrTy);
+ PtrTy, PtrTy, PtrTy, IntptrTy);
+ MemProfMemset =
+ M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "memset", PtrTy,
+ PtrTy, IRB.getInt32Ty(), IntptrTy);
}
bool MemProfiler::maybeInsertMemProfInitAtFunctionEntry(Function &F) {
@@ -562,7 +560,7 @@ bool MemProfiler::instrumentFunction(Function &F) {
return false;
if (ClDebugFunc == F.getName())
return false;
- if (F.getName().startswith("__memprof_"))
+ if (F.getName().starts_with("__memprof_"))
return false;
bool FunctionModified = false;
@@ -628,7 +626,7 @@ static void addCallsiteMetadata(Instruction &I,
static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset,
uint32_t Column) {
- llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::support::endianness::little>
+ llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::endianness::little>
HashBuilder;
HashBuilder.add(Function, LineOffset, Column);
llvm::BLAKE3Result<8> Hash = HashBuilder.final();
@@ -678,13 +676,19 @@ static void readMemprof(Module &M, Function &F,
IndexedInstrProfReader *MemProfReader,
const TargetLibraryInfo &TLI) {
auto &Ctx = M.getContext();
-
- auto FuncName = getPGOFuncName(F);
+ // Previously we used getIRPGOFuncName() here. If F is local linkage,
+ // getIRPGOFuncName() returns FuncName with prefix 'FileName;'. But
+ // llvm-profdata uses FuncName in dwarf to create GUID which doesn't
+ // contain FileName's prefix. It caused local linkage function can't
+ // find MemProfRecord. So we use getName() now.
+ // 'unique-internal-linkage-names' can make MemProf work better for local
+ // linkage function.
+ auto FuncName = F.getName();
auto FuncGUID = Function::getGUID(FuncName);
- Expected<memprof::MemProfRecord> MemProfResult =
- MemProfReader->getMemProfRecord(FuncGUID);
- if (Error E = MemProfResult.takeError()) {
- handleAllErrors(std::move(E), [&](const InstrProfError &IPE) {
+ std::optional<memprof::MemProfRecord> MemProfRec;
+ auto Err = MemProfReader->getMemProfRecord(FuncGUID).moveInto(MemProfRec);
+ if (Err) {
+ handleAllErrors(std::move(Err), [&](const InstrProfError &IPE) {
auto Err = IPE.get();
bool SkipWarning = false;
LLVM_DEBUG(dbgs() << "Error in reading profile for Func " << FuncName
@@ -715,6 +719,12 @@ static void readMemprof(Module &M, Function &F,
return;
}
+ // Detect if there are non-zero column numbers in the profile. If not,
+ // treat all column numbers as 0 when matching (i.e. ignore any non-zero
+ // columns in the IR). The profiled binary might have been built with
+ // column numbers disabled, for example.
+ bool ProfileHasColumns = false;
+
// Build maps of the location hash to all profile data with that leaf location
// (allocation info and the callsites).
std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo;
@@ -722,21 +732,22 @@ static void readMemprof(Module &M, Function &F,
// the frame array (see comments below where the map entries are added).
std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>>
LocHashToCallSites;
- const auto MemProfRec = std::move(MemProfResult.get());
- for (auto &AI : MemProfRec.AllocSites) {
+ for (auto &AI : MemProfRec->AllocSites) {
// Associate the allocation info with the leaf frame. The later matching
// code will match any inlined call sequences in the IR with a longer prefix
// of call stack frames.
uint64_t StackId = computeStackId(AI.CallStack[0]);
LocHashToAllocInfo[StackId].insert(&AI);
+ ProfileHasColumns |= AI.CallStack[0].Column;
}
- for (auto &CS : MemProfRec.CallSites) {
+ for (auto &CS : MemProfRec->CallSites) {
// Need to record all frames from leaf up to and including this function,
// as any of these may or may not have been inlined at this point.
unsigned Idx = 0;
for (auto &StackFrame : CS) {
uint64_t StackId = computeStackId(StackFrame);
LocHashToCallSites[StackId].insert(std::make_pair(&CS, Idx++));
+ ProfileHasColumns |= StackFrame.Column;
// Once we find this function, we can stop recording.
if (StackFrame.Function == FuncGUID)
break;
@@ -785,21 +796,21 @@ static void readMemprof(Module &M, Function &F,
if (Name.empty())
Name = DIL->getScope()->getSubprogram()->getName();
auto CalleeGUID = Function::getGUID(Name);
- auto StackId =
- computeStackId(CalleeGUID, GetOffset(DIL), DIL->getColumn());
- // LeafFound will only be false on the first iteration, since we either
- // set it true or break out of the loop below.
+ auto StackId = computeStackId(CalleeGUID, GetOffset(DIL),
+ ProfileHasColumns ? DIL->getColumn() : 0);
+ // Check if we have found the profile's leaf frame. If yes, collect
+ // the rest of the call's inlined context starting here. If not, see if
+ // we find a match further up the inlined context (in case the profile
+ // was missing debug frames at the leaf).
if (!LeafFound) {
AllocInfoIter = LocHashToAllocInfo.find(StackId);
CallSitesIter = LocHashToCallSites.find(StackId);
- // Check if the leaf is in one of the maps. If not, no need to look
- // further at this call.
- if (AllocInfoIter == LocHashToAllocInfo.end() &&
- CallSitesIter == LocHashToCallSites.end())
- break;
- LeafFound = true;
+ if (AllocInfoIter != LocHashToAllocInfo.end() ||
+ CallSitesIter != LocHashToCallSites.end())
+ LeafFound = true;
}
- InlinedCallStack.push_back(StackId);
+ if (LeafFound)
+ InlinedCallStack.push_back(StackId);
}
// If leaf not in either of the maps, skip inst.
if (!LeafFound)
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 83d90049abc3..94af63da38c8 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -152,7 +152,6 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -550,6 +549,7 @@ public:
private:
friend struct MemorySanitizerVisitor;
+ friend struct VarArgHelperBase;
friend struct VarArgAMD64Helper;
friend struct VarArgMIPS64Helper;
friend struct VarArgAArch64Helper;
@@ -574,8 +574,9 @@ private:
Triple TargetTriple;
LLVMContext *C;
- Type *IntptrTy;
+ Type *IntptrTy; ///< Integer type with the size of a ptr in default AS.
Type *OriginTy;
+ PointerType *PtrTy; ///< Integer type with the size of a ptr in default AS.
// XxxTLS variables represent the per-thread state in MSan and per-task state
// in KMSAN.
@@ -595,16 +596,13 @@ private:
/// Thread-local origin storage for function return value.
Value *RetvalOriginTLS;
- /// Thread-local shadow storage for in-register va_arg function
- /// parameters (x86_64-specific).
+ /// Thread-local shadow storage for in-register va_arg function.
Value *VAArgTLS;
- /// Thread-local shadow storage for in-register va_arg function
- /// parameters (x86_64-specific).
+ /// Thread-local shadow storage for in-register va_arg function.
Value *VAArgOriginTLS;
- /// Thread-local shadow storage for va_arg overflow area
- /// (x86_64-specific).
+ /// Thread-local shadow storage for va_arg overflow area.
Value *VAArgOverflowSizeTLS;
/// Are the instrumentation callbacks set up?
@@ -823,11 +821,10 @@ void MemorySanitizer::createKernelApi(Module &M, const TargetLibraryInfo &TLI) {
PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty());
// Functions for poisoning and unpoisoning memory.
- MsanPoisonAllocaFn =
- M.getOrInsertFunction("__msan_poison_alloca", IRB.getVoidTy(),
- IRB.getInt8PtrTy(), IntptrTy, IRB.getInt8PtrTy());
+ MsanPoisonAllocaFn = M.getOrInsertFunction(
+ "__msan_poison_alloca", IRB.getVoidTy(), PtrTy, IntptrTy, PtrTy);
MsanUnpoisonAllocaFn = M.getOrInsertFunction(
- "__msan_unpoison_alloca", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy);
+ "__msan_unpoison_alloca", IRB.getVoidTy(), PtrTy, IntptrTy);
}
static Constant *getOrInsertGlobal(Module &M, StringRef Name, Type *Ty) {
@@ -894,18 +891,18 @@ void MemorySanitizer::createUserspaceApi(Module &M, const TargetLibraryInfo &TLI
FunctionName = "__msan_maybe_store_origin_" + itostr(AccessSize);
MaybeStoreOriginFn[AccessSizeIndex] = M.getOrInsertFunction(
FunctionName, TLI.getAttrList(C, {0, 2}, /*Signed=*/false),
- IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), IRB.getInt8PtrTy(),
+ IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), PtrTy,
IRB.getInt32Ty());
}
- MsanSetAllocaOriginWithDescriptionFn = M.getOrInsertFunction(
- "__msan_set_alloca_origin_with_descr", IRB.getVoidTy(),
- IRB.getInt8PtrTy(), IntptrTy, IRB.getInt8PtrTy(), IRB.getInt8PtrTy());
- MsanSetAllocaOriginNoDescriptionFn = M.getOrInsertFunction(
- "__msan_set_alloca_origin_no_descr", IRB.getVoidTy(), IRB.getInt8PtrTy(),
- IntptrTy, IRB.getInt8PtrTy());
- MsanPoisonStackFn = M.getOrInsertFunction(
- "__msan_poison_stack", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy);
+ MsanSetAllocaOriginWithDescriptionFn =
+ M.getOrInsertFunction("__msan_set_alloca_origin_with_descr",
+ IRB.getVoidTy(), PtrTy, IntptrTy, PtrTy, PtrTy);
+ MsanSetAllocaOriginNoDescriptionFn =
+ M.getOrInsertFunction("__msan_set_alloca_origin_no_descr",
+ IRB.getVoidTy(), PtrTy, IntptrTy, PtrTy);
+ MsanPoisonStackFn = M.getOrInsertFunction("__msan_poison_stack",
+ IRB.getVoidTy(), PtrTy, IntptrTy);
}
/// Insert extern declaration of runtime-provided functions and globals.
@@ -923,16 +920,14 @@ void MemorySanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo &TL
IRB.getInt32Ty());
MsanSetOriginFn = M.getOrInsertFunction(
"__msan_set_origin", TLI.getAttrList(C, {2}, /*Signed=*/false),
- IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, IRB.getInt32Ty());
+ IRB.getVoidTy(), PtrTy, IntptrTy, IRB.getInt32Ty());
MemmoveFn =
- M.getOrInsertFunction("__msan_memmove", IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy);
+ M.getOrInsertFunction("__msan_memmove", PtrTy, PtrTy, PtrTy, IntptrTy);
MemcpyFn =
- M.getOrInsertFunction("__msan_memcpy", IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy);
- MemsetFn = M.getOrInsertFunction(
- "__msan_memset", TLI.getAttrList(C, {1}, /*Signed=*/true),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy);
+ M.getOrInsertFunction("__msan_memcpy", PtrTy, PtrTy, PtrTy, IntptrTy);
+ MemsetFn = M.getOrInsertFunction("__msan_memset",
+ TLI.getAttrList(C, {1}, /*Signed=*/true),
+ PtrTy, PtrTy, IRB.getInt32Ty(), IntptrTy);
MsanInstrumentAsmStoreFn =
M.getOrInsertFunction("__msan_instrument_asm_store", IRB.getVoidTy(),
@@ -1046,6 +1041,7 @@ void MemorySanitizer::initializeModule(Module &M) {
IRBuilder<> IRB(*C);
IntptrTy = IRB.getIntPtrTy(DL);
OriginTy = IRB.getInt32Ty();
+ PtrTy = IRB.getPtrTy();
ColdCallWeights = MDBuilder(*C).createBranchWeights(1, 1000);
OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000);
@@ -1304,9 +1300,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
FunctionCallee Fn = MS.MaybeStoreOriginFn[SizeIndex];
Value *ConvertedShadow2 =
IRB.CreateZExt(ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex)));
- CallBase *CB = IRB.CreateCall(
- Fn, {ConvertedShadow2,
- IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), Origin});
+ CallBase *CB = IRB.CreateCall(Fn, {ConvertedShadow2, Addr, Origin});
CB->addParamAttr(0, Attribute::ZExt);
CB->addParamAttr(2, Attribute::ZExt);
} else {
@@ -1676,7 +1670,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
VectTy->getElementCount());
}
assert(IntPtrTy == MS.IntptrTy);
- return ShadowTy->getPointerTo();
+ return PointerType::get(*MS.C, 0);
}
Constant *constToIntPtr(Type *IntPtrTy, uint64_t C) const {
@@ -1718,6 +1712,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
std::pair<Value *, Value *>
getShadowOriginPtrUserspace(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy,
MaybeAlign Alignment) {
+ VectorType *VectTy = dyn_cast<VectorType>(Addr->getType());
+ if (!VectTy) {
+ assert(Addr->getType()->isPointerTy());
+ } else {
+ assert(VectTy->getElementType()->isPointerTy());
+ }
Type *IntptrTy = ptrToIntPtrType(Addr->getType());
Value *ShadowOffset = getShadowPtrOffset(Addr, IRB);
Value *ShadowLong = ShadowOffset;
@@ -1800,11 +1800,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// TODO: Support callbacs with vectors of addresses.
unsigned NumElements = cast<FixedVectorType>(VectTy)->getNumElements();
Value *ShadowPtrs = ConstantInt::getNullValue(
- FixedVectorType::get(ShadowTy->getPointerTo(), NumElements));
+ FixedVectorType::get(IRB.getPtrTy(), NumElements));
Value *OriginPtrs = nullptr;
if (MS.TrackOrigins)
OriginPtrs = ConstantInt::getNullValue(
- FixedVectorType::get(MS.OriginTy->getPointerTo(), NumElements));
+ FixedVectorType::get(IRB.getPtrTy(), NumElements));
for (unsigned i = 0; i < NumElements; ++i) {
Value *OneAddr =
IRB.CreateExtractElement(Addr, ConstantInt::get(IRB.getInt32Ty(), i));
@@ -1832,33 +1832,30 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
/// Compute the shadow address for a given function argument.
///
/// Shadow = ParamTLS+ArgOffset.
- Value *getShadowPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) {
+ Value *getShadowPtrForArgument(IRBuilder<> &IRB, int ArgOffset) {
Value *Base = IRB.CreatePointerCast(MS.ParamTLS, MS.IntptrTy);
if (ArgOffset)
Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(getShadowTy(A), 0),
- "_msarg");
+ return IRB.CreateIntToPtr(Base, IRB.getPtrTy(0), "_msarg");
}
/// Compute the origin address for a given function argument.
- Value *getOriginPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) {
+ Value *getOriginPtrForArgument(IRBuilder<> &IRB, int ArgOffset) {
if (!MS.TrackOrigins)
return nullptr;
Value *Base = IRB.CreatePointerCast(MS.ParamOriginTLS, MS.IntptrTy);
if (ArgOffset)
Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0),
- "_msarg_o");
+ return IRB.CreateIntToPtr(Base, IRB.getPtrTy(0), "_msarg_o");
}
/// Compute the shadow address for a retval.
- Value *getShadowPtrForRetval(Value *A, IRBuilder<> &IRB) {
- return IRB.CreatePointerCast(MS.RetvalTLS,
- PointerType::get(getShadowTy(A), 0), "_msret");
+ Value *getShadowPtrForRetval(IRBuilder<> &IRB) {
+ return IRB.CreatePointerCast(MS.RetvalTLS, IRB.getPtrTy(0), "_msret");
}
/// Compute the origin address for a retval.
- Value *getOriginPtrForRetval(IRBuilder<> &IRB) {
+ Value *getOriginPtrForRetval() {
// We keep a single origin for the entire retval. Might be too optimistic.
return MS.RetvalOriginTLS;
}
@@ -1982,7 +1979,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
CpShadowPtr, Constant::getNullValue(EntryIRB.getInt8Ty()),
Size, ArgAlign);
} else {
- Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset);
+ Value *Base = getShadowPtrForArgument(EntryIRB, ArgOffset);
const Align CopyAlign = std::min(ArgAlign, kShadowTLSAlignment);
Value *Cpy = EntryIRB.CreateMemCpy(CpShadowPtr, CopyAlign, Base,
CopyAlign, Size);
@@ -1991,7 +1988,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
if (MS.TrackOrigins) {
Value *OriginPtr =
- getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset);
+ getOriginPtrForArgument(EntryIRB, ArgOffset);
// FIXME: OriginSize should be:
// alignTo(V % kMinOriginAlignment + Size, kMinOriginAlignment)
unsigned OriginSize = alignTo(Size, kMinOriginAlignment);
@@ -2010,12 +2007,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOrigin(A, getCleanOrigin());
} else {
// Shadow over TLS
- Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset);
+ Value *Base = getShadowPtrForArgument(EntryIRB, ArgOffset);
ShadowPtr = EntryIRB.CreateAlignedLoad(getShadowTy(&FArg), Base,
kShadowTLSAlignment);
if (MS.TrackOrigins) {
Value *OriginPtr =
- getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset);
+ getOriginPtrForArgument(EntryIRB, ArgOffset);
setOrigin(A, EntryIRB.CreateLoad(MS.OriginTy, OriginPtr));
}
}
@@ -2838,11 +2835,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
void visitMemMoveInst(MemMoveInst &I) {
getShadow(I.getArgOperand(1)); // Ensure shadow initialized
IRBuilder<> IRB(&I);
- IRB.CreateCall(
- MS.MemmoveFn,
- {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(I.getArgOperand(1), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)});
+ IRB.CreateCall(MS.MemmoveFn,
+ {I.getArgOperand(0), I.getArgOperand(1),
+ IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)});
I.eraseFromParent();
}
@@ -2863,11 +2858,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
void visitMemCpyInst(MemCpyInst &I) {
getShadow(I.getArgOperand(1)); // Ensure shadow initialized
IRBuilder<> IRB(&I);
- IRB.CreateCall(
- MS.MemcpyFn,
- {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(I.getArgOperand(1), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)});
+ IRB.CreateCall(MS.MemcpyFn,
+ {I.getArgOperand(0), I.getArgOperand(1),
+ IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)});
I.eraseFromParent();
}
@@ -2876,7 +2869,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
IRBuilder<> IRB(&I);
IRB.CreateCall(
MS.MemsetFn,
- {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()),
+ {I.getArgOperand(0),
IRB.CreateIntCast(I.getArgOperand(1), IRB.getInt32Ty(), false),
IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)});
I.eraseFromParent();
@@ -3385,8 +3378,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Value *ShadowPtr =
getShadowOriginPtr(Addr, IRB, Ty, Align(1), /*isStore*/ true).first;
- IRB.CreateStore(getCleanShadow(Ty),
- IRB.CreatePointerCast(ShadowPtr, Ty->getPointerTo()));
+ IRB.CreateStore(getCleanShadow(Ty), ShadowPtr);
if (ClCheckAccessAddress)
insertShadowCheck(Addr, &I);
@@ -4162,7 +4154,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
if (Function *Func = CB.getCalledFunction()) {
// __sanitizer_unaligned_{load,store} functions may be called by users
// and always expects shadows in the TLS. So don't check them.
- MayCheckCall &= !Func->getName().startswith("__sanitizer_unaligned_");
+ MayCheckCall &= !Func->getName().starts_with("__sanitizer_unaligned_");
}
unsigned ArgOffset = 0;
@@ -4188,7 +4180,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// in that case getShadow() will copy the actual arg shadow to
// __msan_param_tls.
Value *ArgShadow = getShadow(A);
- Value *ArgShadowBase = getShadowPtrForArgument(A, IRB, ArgOffset);
+ Value *ArgShadowBase = getShadowPtrForArgument(IRB, ArgOffset);
LLVM_DEBUG(dbgs() << " Arg#" << i << ": " << *A
<< " Shadow: " << *ArgShadow << "\n");
if (ByVal) {
@@ -4215,7 +4207,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Store = IRB.CreateMemCpy(ArgShadowBase, Alignment, AShadowPtr,
Alignment, Size);
if (MS.TrackOrigins) {
- Value *ArgOriginBase = getOriginPtrForArgument(A, IRB, ArgOffset);
+ Value *ArgOriginBase = getOriginPtrForArgument(IRB, ArgOffset);
// FIXME: OriginSize should be:
// alignTo(A % kMinOriginAlignment + Size, kMinOriginAlignment)
unsigned OriginSize = alignTo(Size, kMinOriginAlignment);
@@ -4237,7 +4229,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Constant *Cst = dyn_cast<Constant>(ArgShadow);
if (MS.TrackOrigins && !(Cst && Cst->isNullValue())) {
IRB.CreateStore(getOrigin(A),
- getOriginPtrForArgument(A, IRB, ArgOffset));
+ getOriginPtrForArgument(IRB, ArgOffset));
}
}
(void)Store;
@@ -4269,7 +4261,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
IRBuilder<> IRBBefore(&CB);
// Until we have full dynamic coverage, make sure the retval shadow is 0.
- Value *Base = getShadowPtrForRetval(&CB, IRBBefore);
+ Value *Base = getShadowPtrForRetval(IRBBefore);
IRBBefore.CreateAlignedStore(getCleanShadow(&CB), Base,
kShadowTLSAlignment);
BasicBlock::iterator NextInsn;
@@ -4294,12 +4286,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
IRBuilder<> IRBAfter(&*NextInsn);
Value *RetvalShadow = IRBAfter.CreateAlignedLoad(
- getShadowTy(&CB), getShadowPtrForRetval(&CB, IRBAfter),
+ getShadowTy(&CB), getShadowPtrForRetval(IRBAfter),
kShadowTLSAlignment, "_msret");
setShadow(&CB, RetvalShadow);
if (MS.TrackOrigins)
setOrigin(&CB, IRBAfter.CreateLoad(MS.OriginTy,
- getOriginPtrForRetval(IRBAfter)));
+ getOriginPtrForRetval()));
}
bool isAMustTailRetVal(Value *RetVal) {
@@ -4320,7 +4312,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// Don't emit the epilogue for musttail call returns.
if (isAMustTailRetVal(RetVal))
return;
- Value *ShadowPtr = getShadowPtrForRetval(RetVal, IRB);
+ Value *ShadowPtr = getShadowPtrForRetval(IRB);
bool HasNoUndef = F.hasRetAttribute(Attribute::NoUndef);
bool StoreShadow = !(MS.EagerChecks && HasNoUndef);
// FIXME: Consider using SpecialCaseList to specify a list of functions that
@@ -4340,7 +4332,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
if (StoreShadow) {
IRB.CreateAlignedStore(Shadow, ShadowPtr, kShadowTLSAlignment);
if (MS.TrackOrigins && StoreOrigin)
- IRB.CreateStore(getOrigin(RetVal), getOriginPtrForRetval(IRB));
+ IRB.CreateStore(getOrigin(RetVal), getOriginPtrForRetval());
}
}
@@ -4374,8 +4366,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
void poisonAllocaUserspace(AllocaInst &I, IRBuilder<> &IRB, Value *Len) {
if (PoisonStack && ClPoisonStackWithCall) {
- IRB.CreateCall(MS.MsanPoisonStackFn,
- {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len});
+ IRB.CreateCall(MS.MsanPoisonStackFn, {&I, Len});
} else {
Value *ShadowBase, *OriginBase;
std::tie(ShadowBase, OriginBase) = getShadowOriginPtr(
@@ -4390,13 +4381,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
if (ClPrintStackNames) {
Value *Descr = getLocalVarDescription(I);
IRB.CreateCall(MS.MsanSetAllocaOriginWithDescriptionFn,
- {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len,
- IRB.CreatePointerCast(Idptr, IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy())});
+ {&I, Len, Idptr, Descr});
} else {
- IRB.CreateCall(MS.MsanSetAllocaOriginNoDescriptionFn,
- {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len,
- IRB.CreatePointerCast(Idptr, IRB.getInt8PtrTy())});
+ IRB.CreateCall(MS.MsanSetAllocaOriginNoDescriptionFn, {&I, Len, Idptr});
}
}
}
@@ -4404,12 +4391,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
void poisonAllocaKmsan(AllocaInst &I, IRBuilder<> &IRB, Value *Len) {
Value *Descr = getLocalVarDescription(I);
if (PoisonStack) {
- IRB.CreateCall(MS.MsanPoisonAllocaFn,
- {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len,
- IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy())});
+ IRB.CreateCall(MS.MsanPoisonAllocaFn, {&I, Len, Descr});
} else {
- IRB.CreateCall(MS.MsanUnpoisonAllocaFn,
- {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len});
+ IRB.CreateCall(MS.MsanUnpoisonAllocaFn, {&I, Len});
}
}
@@ -4571,10 +4555,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
if (!ElemTy->isSized())
return;
- Value *Ptr = IRB.CreatePointerCast(Operand, IRB.getInt8PtrTy());
Value *SizeVal =
IRB.CreateTypeSize(MS.IntptrTy, DL.getTypeStoreSize(ElemTy));
- IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Ptr, SizeVal});
+ IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Operand, SizeVal});
}
/// Get the number of output arguments returned by pointers.
@@ -4668,8 +4651,91 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
};
+struct VarArgHelperBase : public VarArgHelper {
+ Function &F;
+ MemorySanitizer &MS;
+ MemorySanitizerVisitor &MSV;
+ SmallVector<CallInst *, 16> VAStartInstrumentationList;
+ const unsigned VAListTagSize;
+
+ VarArgHelperBase(Function &F, MemorySanitizer &MS,
+ MemorySanitizerVisitor &MSV, unsigned VAListTagSize)
+ : F(F), MS(MS), MSV(MSV), VAListTagSize(VAListTagSize) {}
+
+ Value *getShadowAddrForVAArgument(IRBuilder<> &IRB, unsigned ArgOffset) {
+ Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy);
+ return IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
+ }
+
+ /// Compute the shadow address for a given va_arg.
+ Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB,
+ unsigned ArgOffset) {
+ Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy);
+ Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
+ return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0),
+ "_msarg_va_s");
+ }
+
+ /// Compute the shadow address for a given va_arg.
+ Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB,
+ unsigned ArgOffset, unsigned ArgSize) {
+ // Make sure we don't overflow __msan_va_arg_tls.
+ if (ArgOffset + ArgSize > kParamTLSSize)
+ return nullptr;
+ return getShadowPtrForVAArgument(Ty, IRB, ArgOffset);
+ }
+
+ /// Compute the origin address for a given va_arg.
+ Value *getOriginPtrForVAArgument(IRBuilder<> &IRB, int ArgOffset) {
+ Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy);
+ // getOriginPtrForVAArgument() is always called after
+ // getShadowPtrForVAArgument(), so __msan_va_arg_origin_tls can never
+ // overflow.
+ Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
+ return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0),
+ "_msarg_va_o");
+ }
+
+ void CleanUnusedTLS(IRBuilder<> &IRB, Value *ShadowBase,
+ unsigned BaseOffset) {
+ // The tails of __msan_va_arg_tls is not large enough to fit full
+ // value shadow, but it will be copied to backup anyway. Make it
+ // clean.
+ if (BaseOffset >= kParamTLSSize)
+ return;
+ Value *TailSize =
+ ConstantInt::getSigned(IRB.getInt32Ty(), kParamTLSSize - BaseOffset);
+ IRB.CreateMemSet(ShadowBase, ConstantInt::getNullValue(IRB.getInt8Ty()),
+ TailSize, Align(8));
+ }
+
+ void unpoisonVAListTagForInst(IntrinsicInst &I) {
+ IRBuilder<> IRB(&I);
+ Value *VAListTag = I.getArgOperand(0);
+ const Align Alignment = Align(8);
+ auto [ShadowPtr, OriginPtr] = MSV.getShadowOriginPtr(
+ VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true);
+ // Unpoison the whole __va_list_tag.
+ IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
+ VAListTagSize, Alignment, false);
+ }
+
+ void visitVAStartInst(VAStartInst &I) override {
+ if (F.getCallingConv() == CallingConv::Win64)
+ return;
+ VAStartInstrumentationList.push_back(&I);
+ unpoisonVAListTagForInst(I);
+ }
+
+ void visitVACopyInst(VACopyInst &I) override {
+ if (F.getCallingConv() == CallingConv::Win64)
+ return;
+ unpoisonVAListTagForInst(I);
+ }
+};
+
/// AMD64-specific implementation of VarArgHelper.
-struct VarArgAMD64Helper : public VarArgHelper {
+struct VarArgAMD64Helper : public VarArgHelperBase {
// An unfortunate workaround for asymmetric lowering of va_arg stuff.
// See a comment in visitCallBase for more details.
static const unsigned AMD64GpEndOffset = 48; // AMD64 ABI Draft 0.99.6 p3.5.7
@@ -4678,20 +4744,15 @@ struct VarArgAMD64Helper : public VarArgHelper {
static const unsigned AMD64FpEndOffsetNoSSE = AMD64GpEndOffset;
unsigned AMD64FpEndOffset;
- Function &F;
- MemorySanitizer &MS;
- MemorySanitizerVisitor &MSV;
AllocaInst *VAArgTLSCopy = nullptr;
AllocaInst *VAArgTLSOriginCopy = nullptr;
Value *VAArgOverflowSize = nullptr;
- SmallVector<CallInst *, 16> VAStartInstrumentationList;
-
enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory };
VarArgAMD64Helper(Function &F, MemorySanitizer &MS,
MemorySanitizerVisitor &MSV)
- : F(F), MS(MS), MSV(MSV) {
+ : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/24) {
AMD64FpEndOffset = AMD64FpEndOffsetSSE;
for (const auto &Attr : F.getAttributes().getFnAttrs()) {
if (Attr.isStringAttribute() &&
@@ -4706,6 +4767,8 @@ struct VarArgAMD64Helper : public VarArgHelper {
ArgKind classifyArgument(Value *arg) {
// A very rough approximation of X86_64 argument classification rules.
Type *T = arg->getType();
+ if (T->isX86_FP80Ty())
+ return AK_Memory;
if (T->isFPOrFPVectorTy() || T->isX86_MMXTy())
return AK_FloatingPoint;
if (T->isIntegerTy() && T->getPrimitiveSizeInBits() <= 64)
@@ -4728,6 +4791,7 @@ struct VarArgAMD64Helper : public VarArgHelper {
unsigned FpOffset = AMD64GpEndOffset;
unsigned OverflowOffset = AMD64FpEndOffset;
const DataLayout &DL = F.getParent()->getDataLayout();
+
for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) {
bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams();
bool IsByVal = CB.paramHasAttr(ArgNo, Attribute::ByVal);
@@ -4740,19 +4804,24 @@ struct VarArgAMD64Helper : public VarArgHelper {
assert(A->getType()->isPointerTy());
Type *RealTy = CB.getParamByValType(ArgNo);
uint64_t ArgSize = DL.getTypeAllocSize(RealTy);
- Value *ShadowBase = getShadowPtrForVAArgument(
- RealTy, IRB, OverflowOffset, alignTo(ArgSize, 8));
+ uint64_t AlignedSize = alignTo(ArgSize, 8);
+ unsigned BaseOffset = OverflowOffset;
+ Value *ShadowBase =
+ getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset);
Value *OriginBase = nullptr;
if (MS.TrackOrigins)
- OriginBase = getOriginPtrForVAArgument(RealTy, IRB, OverflowOffset);
- OverflowOffset += alignTo(ArgSize, 8);
- if (!ShadowBase)
- continue;
+ OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset);
+ OverflowOffset += AlignedSize;
+
+ if (OverflowOffset > kParamTLSSize) {
+ CleanUnusedTLS(IRB, ShadowBase, BaseOffset);
+ continue; // We have no space to copy shadow there.
+ }
+
Value *ShadowPtr, *OriginPtr;
std::tie(ShadowPtr, OriginPtr) =
MSV.getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment,
/*isStore*/ false);
-
IRB.CreateMemCpy(ShadowBase, kShadowTLSAlignment, ShadowPtr,
kShadowTLSAlignment, ArgSize);
if (MS.TrackOrigins)
@@ -4767,37 +4836,42 @@ struct VarArgAMD64Helper : public VarArgHelper {
Value *ShadowBase, *OriginBase = nullptr;
switch (AK) {
case AK_GeneralPurpose:
- ShadowBase =
- getShadowPtrForVAArgument(A->getType(), IRB, GpOffset, 8);
+ ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, GpOffset);
if (MS.TrackOrigins)
- OriginBase = getOriginPtrForVAArgument(A->getType(), IRB, GpOffset);
+ OriginBase = getOriginPtrForVAArgument(IRB, GpOffset);
GpOffset += 8;
+ assert(GpOffset <= kParamTLSSize);
break;
case AK_FloatingPoint:
- ShadowBase =
- getShadowPtrForVAArgument(A->getType(), IRB, FpOffset, 16);
+ ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, FpOffset);
if (MS.TrackOrigins)
- OriginBase = getOriginPtrForVAArgument(A->getType(), IRB, FpOffset);
+ OriginBase = getOriginPtrForVAArgument(IRB, FpOffset);
FpOffset += 16;
+ assert(FpOffset <= kParamTLSSize);
break;
case AK_Memory:
if (IsFixed)
continue;
uint64_t ArgSize = DL.getTypeAllocSize(A->getType());
+ uint64_t AlignedSize = alignTo(ArgSize, 8);
+ unsigned BaseOffset = OverflowOffset;
ShadowBase =
- getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, 8);
- if (MS.TrackOrigins)
- OriginBase =
- getOriginPtrForVAArgument(A->getType(), IRB, OverflowOffset);
- OverflowOffset += alignTo(ArgSize, 8);
+ getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset);
+ if (MS.TrackOrigins) {
+ OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset);
+ }
+ OverflowOffset += AlignedSize;
+ if (OverflowOffset > kParamTLSSize) {
+ // We have no space to copy shadow there.
+ CleanUnusedTLS(IRB, ShadowBase, BaseOffset);
+ continue;
+ }
}
// Take fixed arguments into account for GpOffset and FpOffset,
// but don't actually store shadows for them.
// TODO(glider): don't call get*PtrForVAArgument() for them.
if (IsFixed)
continue;
- if (!ShadowBase)
- continue;
Value *Shadow = MSV.getShadow(A);
IRB.CreateAlignedStore(Shadow, ShadowBase, kShadowTLSAlignment);
if (MS.TrackOrigins) {
@@ -4813,59 +4887,6 @@ struct VarArgAMD64Helper : public VarArgHelper {
IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS);
}
- /// Compute the shadow address for a given va_arg.
- Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB,
- unsigned ArgOffset, unsigned ArgSize) {
- // Make sure we don't overflow __msan_va_arg_tls.
- if (ArgOffset + ArgSize > kParamTLSSize)
- return nullptr;
- Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy);
- Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0),
- "_msarg_va_s");
- }
-
- /// Compute the origin address for a given va_arg.
- Value *getOriginPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, int ArgOffset) {
- Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy);
- // getOriginPtrForVAArgument() is always called after
- // getShadowPtrForVAArgument(), so __msan_va_arg_origin_tls can never
- // overflow.
- Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0),
- "_msarg_va_o");
- }
-
- void unpoisonVAListTagForInst(IntrinsicInst &I) {
- IRBuilder<> IRB(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) =
- MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment,
- /*isStore*/ true);
-
- // Unpoison the whole __va_list_tag.
- // FIXME: magic ABI constants.
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- /* size */ 24, Alignment, false);
- // We shouldn't need to zero out the origins, as they're only checked for
- // nonzero shadow.
- }
-
- void visitVAStartInst(VAStartInst &I) override {
- if (F.getCallingConv() == CallingConv::Win64)
- return;
- VAStartInstrumentationList.push_back(&I);
- unpoisonVAListTagForInst(I);
- }
-
- void visitVACopyInst(VACopyInst &I) override {
- if (F.getCallingConv() == CallingConv::Win64)
- return;
- unpoisonVAListTagForInst(I);
- }
-
void finalizeInstrumentation() override {
assert(!VAArgOverflowSize && !VAArgTLSCopy &&
"finalizeInstrumentation called twice");
@@ -4902,7 +4923,7 @@ struct VarArgAMD64Helper : public VarArgHelper {
NextNodeIRBuilder IRB(OrigInst);
Value *VAListTag = OrigInst->getArgOperand(0);
- Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C);
+ Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64*
Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr(
IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
ConstantInt::get(MS.IntptrTy, 16)),
@@ -4919,7 +4940,7 @@ struct VarArgAMD64Helper : public VarArgHelper {
if (MS.TrackOrigins)
IRB.CreateMemCpy(RegSaveAreaOriginPtr, Alignment, VAArgTLSOriginCopy,
Alignment, AMD64FpEndOffset);
- Type *OverflowArgAreaPtrTy = Type::getInt64PtrTy(*MS.C);
+ Type *OverflowArgAreaPtrTy = PointerType::getUnqual(*MS.C); // i64*
Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr(
IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
ConstantInt::get(MS.IntptrTy, 8)),
@@ -4945,18 +4966,14 @@ struct VarArgAMD64Helper : public VarArgHelper {
};
/// MIPS64-specific implementation of VarArgHelper.
-struct VarArgMIPS64Helper : public VarArgHelper {
- Function &F;
- MemorySanitizer &MS;
- MemorySanitizerVisitor &MSV;
+/// NOTE: This is also used for LoongArch64.
+struct VarArgMIPS64Helper : public VarArgHelperBase {
AllocaInst *VAArgTLSCopy = nullptr;
Value *VAArgSize = nullptr;
- SmallVector<CallInst *, 16> VAStartInstrumentationList;
-
VarArgMIPS64Helper(Function &F, MemorySanitizer &MS,
MemorySanitizerVisitor &MSV)
- : F(F), MS(MS), MSV(MSV) {}
+ : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/8) {}
void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override {
unsigned VAArgOffset = 0;
@@ -4986,42 +5003,6 @@ struct VarArgMIPS64Helper : public VarArgHelper {
IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS);
}
- /// Compute the shadow address for a given va_arg.
- Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB,
- unsigned ArgOffset, unsigned ArgSize) {
- // Make sure we don't overflow __msan_va_arg_tls.
- if (ArgOffset + ArgSize > kParamTLSSize)
- return nullptr;
- Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy);
- Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0),
- "_msarg");
- }
-
- void visitVAStartInst(VAStartInst &I) override {
- IRBuilder<> IRB(&I);
- VAStartInstrumentationList.push_back(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(
- VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true);
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- /* size */ 8, Alignment, false);
- }
-
- void visitVACopyInst(VACopyInst &I) override {
- IRBuilder<> IRB(&I);
- VAStartInstrumentationList.push_back(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(
- VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true);
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- /* size */ 8, Alignment, false);
- }
-
void finalizeInstrumentation() override {
assert(!VAArgSize && !VAArgTLSCopy &&
"finalizeInstrumentation called twice");
@@ -5051,7 +5032,7 @@ struct VarArgMIPS64Helper : public VarArgHelper {
CallInst *OrigInst = VAStartInstrumentationList[i];
NextNodeIRBuilder IRB(OrigInst);
Value *VAListTag = OrigInst->getArgOperand(0);
- Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C);
+ Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64*
Value *RegSaveAreaPtrPtr =
IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
PointerType::get(RegSaveAreaPtrTy, 0));
@@ -5069,7 +5050,7 @@ struct VarArgMIPS64Helper : public VarArgHelper {
};
/// AArch64-specific implementation of VarArgHelper.
-struct VarArgAArch64Helper : public VarArgHelper {
+struct VarArgAArch64Helper : public VarArgHelperBase {
static const unsigned kAArch64GrArgSize = 64;
static const unsigned kAArch64VrArgSize = 128;
@@ -5081,28 +5062,36 @@ struct VarArgAArch64Helper : public VarArgHelper {
AArch64VrBegOffset + kAArch64VrArgSize;
static const unsigned AArch64VAEndOffset = AArch64VrEndOffset;
- Function &F;
- MemorySanitizer &MS;
- MemorySanitizerVisitor &MSV;
AllocaInst *VAArgTLSCopy = nullptr;
Value *VAArgOverflowSize = nullptr;
- SmallVector<CallInst *, 16> VAStartInstrumentationList;
-
enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory };
VarArgAArch64Helper(Function &F, MemorySanitizer &MS,
MemorySanitizerVisitor &MSV)
- : F(F), MS(MS), MSV(MSV) {}
+ : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/32) {}
- ArgKind classifyArgument(Value *arg) {
- Type *T = arg->getType();
- if (T->isFPOrFPVectorTy())
- return AK_FloatingPoint;
- if ((T->isIntegerTy() && T->getPrimitiveSizeInBits() <= 64) ||
- (T->isPointerTy()))
- return AK_GeneralPurpose;
- return AK_Memory;
+ // A very rough approximation of aarch64 argument classification rules.
+ std::pair<ArgKind, uint64_t> classifyArgument(Type *T) {
+ if (T->isIntOrPtrTy() && T->getPrimitiveSizeInBits() <= 64)
+ return {AK_GeneralPurpose, 1};
+ if (T->isFloatingPointTy() && T->getPrimitiveSizeInBits() <= 128)
+ return {AK_FloatingPoint, 1};
+
+ if (T->isArrayTy()) {
+ auto R = classifyArgument(T->getArrayElementType());
+ R.second *= T->getScalarType()->getArrayNumElements();
+ return R;
+ }
+
+ if (const FixedVectorType *FV = dyn_cast<FixedVectorType>(T)) {
+ auto R = classifyArgument(FV->getScalarType());
+ R.second *= FV->getNumElements();
+ return R;
+ }
+
+ LLVM_DEBUG(errs() << "Unknown vararg type: " << *T << "\n");
+ return {AK_Memory, 0};
}
// The instrumentation stores the argument shadow in a non ABI-specific
@@ -5110,7 +5099,7 @@ struct VarArgAArch64Helper : public VarArgHelper {
// like x86_64 case, lowers the va_args in the frontend and this pass only
// sees the low level code that deals with va_list internals).
// The first seven GR registers are saved in the first 56 bytes of the
- // va_arg tls arra, followers by the first 8 FP/SIMD registers, and then
+ // va_arg tls arra, followed by the first 8 FP/SIMD registers, and then
// the remaining arguments.
// Using constant offset within the va_arg TLS array allows fast copy
// in the finalize instrumentation.
@@ -5122,20 +5111,22 @@ struct VarArgAArch64Helper : public VarArgHelper {
const DataLayout &DL = F.getParent()->getDataLayout();
for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) {
bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams();
- ArgKind AK = classifyArgument(A);
- if (AK == AK_GeneralPurpose && GrOffset >= AArch64GrEndOffset)
+ auto [AK, RegNum] = classifyArgument(A->getType());
+ if (AK == AK_GeneralPurpose &&
+ (GrOffset + RegNum * 8) > AArch64GrEndOffset)
AK = AK_Memory;
- if (AK == AK_FloatingPoint && VrOffset >= AArch64VrEndOffset)
+ if (AK == AK_FloatingPoint &&
+ (VrOffset + RegNum * 16) > AArch64VrEndOffset)
AK = AK_Memory;
Value *Base;
switch (AK) {
case AK_GeneralPurpose:
- Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset, 8);
- GrOffset += 8;
+ Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset);
+ GrOffset += 8 * RegNum;
break;
case AK_FloatingPoint:
- Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset, 8);
- VrOffset += 16;
+ Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset);
+ VrOffset += 16 * RegNum;
break;
case AK_Memory:
// Don't count fixed arguments in the overflow area - va_start will
@@ -5143,17 +5134,21 @@ struct VarArgAArch64Helper : public VarArgHelper {
if (IsFixed)
continue;
uint64_t ArgSize = DL.getTypeAllocSize(A->getType());
- Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset,
- alignTo(ArgSize, 8));
- OverflowOffset += alignTo(ArgSize, 8);
+ uint64_t AlignedSize = alignTo(ArgSize, 8);
+ unsigned BaseOffset = OverflowOffset;
+ Base = getShadowPtrForVAArgument(A->getType(), IRB, BaseOffset);
+ OverflowOffset += AlignedSize;
+ if (OverflowOffset > kParamTLSSize) {
+ // We have no space to copy shadow there.
+ CleanUnusedTLS(IRB, Base, BaseOffset);
+ continue;
+ }
break;
}
// Count Gp/Vr fixed arguments to their respective offsets, but don't
// bother to actually store a shadow.
if (IsFixed)
continue;
- if (!Base)
- continue;
IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment);
}
Constant *OverflowSize =
@@ -5161,48 +5156,12 @@ struct VarArgAArch64Helper : public VarArgHelper {
IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS);
}
- /// Compute the shadow address for a given va_arg.
- Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB,
- unsigned ArgOffset, unsigned ArgSize) {
- // Make sure we don't overflow __msan_va_arg_tls.
- if (ArgOffset + ArgSize > kParamTLSSize)
- return nullptr;
- Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy);
- Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0),
- "_msarg");
- }
-
- void visitVAStartInst(VAStartInst &I) override {
- IRBuilder<> IRB(&I);
- VAStartInstrumentationList.push_back(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(
- VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true);
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- /* size */ 32, Alignment, false);
- }
-
- void visitVACopyInst(VACopyInst &I) override {
- IRBuilder<> IRB(&I);
- VAStartInstrumentationList.push_back(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(
- VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true);
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- /* size */ 32, Alignment, false);
- }
-
// Retrieve a va_list field of 'void*' size.
Value *getVAField64(IRBuilder<> &IRB, Value *VAListTag, int offset) {
Value *SaveAreaPtrPtr = IRB.CreateIntToPtr(
IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
ConstantInt::get(MS.IntptrTy, offset)),
- Type::getInt64PtrTy(*MS.C));
+ PointerType::get(*MS.C, 0));
return IRB.CreateLoad(Type::getInt64Ty(*MS.C), SaveAreaPtrPtr);
}
@@ -5211,7 +5170,7 @@ struct VarArgAArch64Helper : public VarArgHelper {
Value *SaveAreaPtr = IRB.CreateIntToPtr(
IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
ConstantInt::get(MS.IntptrTy, offset)),
- Type::getInt32PtrTy(*MS.C));
+ PointerType::get(*MS.C, 0));
Value *SaveArea32 = IRB.CreateLoad(IRB.getInt32Ty(), SaveAreaPtr);
return IRB.CreateSExt(SaveArea32, MS.IntptrTy);
}
@@ -5262,21 +5221,25 @@ struct VarArgAArch64Helper : public VarArgHelper {
// we need to adjust the offset for both GR and VR fields based on
// the __{gr,vr}_offs value (since they are stores based on incoming
// named arguments).
+ Type *RegSaveAreaPtrTy = IRB.getPtrTy();
// Read the stack pointer from the va_list.
- Value *StackSaveAreaPtr = getVAField64(IRB, VAListTag, 0);
+ Value *StackSaveAreaPtr =
+ IRB.CreateIntToPtr(getVAField64(IRB, VAListTag, 0), RegSaveAreaPtrTy);
// Read both the __gr_top and __gr_off and add them up.
Value *GrTopSaveAreaPtr = getVAField64(IRB, VAListTag, 8);
Value *GrOffSaveArea = getVAField32(IRB, VAListTag, 24);
- Value *GrRegSaveAreaPtr = IRB.CreateAdd(GrTopSaveAreaPtr, GrOffSaveArea);
+ Value *GrRegSaveAreaPtr = IRB.CreateIntToPtr(
+ IRB.CreateAdd(GrTopSaveAreaPtr, GrOffSaveArea), RegSaveAreaPtrTy);
// Read both the __vr_top and __vr_off and add them up.
Value *VrTopSaveAreaPtr = getVAField64(IRB, VAListTag, 16);
Value *VrOffSaveArea = getVAField32(IRB, VAListTag, 28);
- Value *VrRegSaveAreaPtr = IRB.CreateAdd(VrTopSaveAreaPtr, VrOffSaveArea);
+ Value *VrRegSaveAreaPtr = IRB.CreateIntToPtr(
+ IRB.CreateAdd(VrTopSaveAreaPtr, VrOffSaveArea), RegSaveAreaPtrTy);
// It does not know how many named arguments is being used and, on the
// callsite all the arguments were saved. Since __gr_off is defined as
@@ -5332,18 +5295,13 @@ struct VarArgAArch64Helper : public VarArgHelper {
};
/// PowerPC64-specific implementation of VarArgHelper.
-struct VarArgPowerPC64Helper : public VarArgHelper {
- Function &F;
- MemorySanitizer &MS;
- MemorySanitizerVisitor &MSV;
+struct VarArgPowerPC64Helper : public VarArgHelperBase {
AllocaInst *VAArgTLSCopy = nullptr;
Value *VAArgSize = nullptr;
- SmallVector<CallInst *, 16> VAStartInstrumentationList;
-
VarArgPowerPC64Helper(Function &F, MemorySanitizer &MS,
MemorySanitizerVisitor &MSV)
- : F(F), MS(MS), MSV(MSV) {}
+ : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/8) {}
void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override {
// For PowerPC, we need to deal with alignment of stack arguments -
@@ -5431,43 +5389,6 @@ struct VarArgPowerPC64Helper : public VarArgHelper {
IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS);
}
- /// Compute the shadow address for a given va_arg.
- Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB,
- unsigned ArgOffset, unsigned ArgSize) {
- // Make sure we don't overflow __msan_va_arg_tls.
- if (ArgOffset + ArgSize > kParamTLSSize)
- return nullptr;
- Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy);
- Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0),
- "_msarg");
- }
-
- void visitVAStartInst(VAStartInst &I) override {
- IRBuilder<> IRB(&I);
- VAStartInstrumentationList.push_back(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(
- VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true);
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- /* size */ 8, Alignment, false);
- }
-
- void visitVACopyInst(VACopyInst &I) override {
- IRBuilder<> IRB(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(
- VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true);
- // Unpoison the whole __va_list_tag.
- // FIXME: magic ABI constants.
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- /* size */ 8, Alignment, false);
- }
-
void finalizeInstrumentation() override {
assert(!VAArgSize && !VAArgTLSCopy &&
"finalizeInstrumentation called twice");
@@ -5498,7 +5419,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper {
CallInst *OrigInst = VAStartInstrumentationList[i];
NextNodeIRBuilder IRB(OrigInst);
Value *VAListTag = OrigInst->getArgOperand(0);
- Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C);
+ Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64*
Value *RegSaveAreaPtrPtr =
IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
PointerType::get(RegSaveAreaPtrTy, 0));
@@ -5516,7 +5437,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper {
};
/// SystemZ-specific implementation of VarArgHelper.
-struct VarArgSystemZHelper : public VarArgHelper {
+struct VarArgSystemZHelper : public VarArgHelperBase {
static const unsigned SystemZGpOffset = 16;
static const unsigned SystemZGpEndOffset = 56;
static const unsigned SystemZFpOffset = 128;
@@ -5528,16 +5449,11 @@ struct VarArgSystemZHelper : public VarArgHelper {
static const unsigned SystemZOverflowArgAreaPtrOffset = 16;
static const unsigned SystemZRegSaveAreaPtrOffset = 24;
- Function &F;
- MemorySanitizer &MS;
- MemorySanitizerVisitor &MSV;
bool IsSoftFloatABI;
AllocaInst *VAArgTLSCopy = nullptr;
AllocaInst *VAArgTLSOriginCopy = nullptr;
Value *VAArgOverflowSize = nullptr;
- SmallVector<CallInst *, 16> VAStartInstrumentationList;
-
enum class ArgKind {
GeneralPurpose,
FloatingPoint,
@@ -5550,7 +5466,7 @@ struct VarArgSystemZHelper : public VarArgHelper {
VarArgSystemZHelper(Function &F, MemorySanitizer &MS,
MemorySanitizerVisitor &MSV)
- : F(F), MS(MS), MSV(MSV),
+ : VarArgHelperBase(F, MS, MSV, SystemZVAListTagSize),
IsSoftFloatABI(F.getFnAttribute("use-soft-float").getValueAsBool()) {}
ArgKind classifyArgument(Type *T) {
@@ -5711,39 +5627,8 @@ struct VarArgSystemZHelper : public VarArgHelper {
IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS);
}
- Value *getShadowAddrForVAArgument(IRBuilder<> &IRB, unsigned ArgOffset) {
- Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy);
- return IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- }
-
- Value *getOriginPtrForVAArgument(IRBuilder<> &IRB, int ArgOffset) {
- Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy);
- Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset));
- return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0),
- "_msarg_va_o");
- }
-
- void unpoisonVAListTagForInst(IntrinsicInst &I) {
- IRBuilder<> IRB(&I);
- Value *VAListTag = I.getArgOperand(0);
- Value *ShadowPtr, *OriginPtr;
- const Align Alignment = Align(8);
- std::tie(ShadowPtr, OriginPtr) =
- MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment,
- /*isStore*/ true);
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()),
- SystemZVAListTagSize, Alignment, false);
- }
-
- void visitVAStartInst(VAStartInst &I) override {
- VAStartInstrumentationList.push_back(&I);
- unpoisonVAListTagForInst(I);
- }
-
- void visitVACopyInst(VACopyInst &I) override { unpoisonVAListTagForInst(I); }
-
void copyRegSaveArea(IRBuilder<> &IRB, Value *VAListTag) {
- Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C);
+ Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64*
Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr(
IRB.CreateAdd(
IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
@@ -5767,8 +5652,10 @@ struct VarArgSystemZHelper : public VarArgHelper {
Alignment, RegSaveAreaSize);
}
+ // FIXME: This implementation limits OverflowOffset to kParamTLSSize, so we
+ // don't know real overflow size and can't clear shadow beyond kParamTLSSize.
void copyOverflowArea(IRBuilder<> &IRB, Value *VAListTag) {
- Type *OverflowArgAreaPtrTy = Type::getInt64PtrTy(*MS.C);
+ Type *OverflowArgAreaPtrTy = PointerType::getUnqual(*MS.C); // i64*
Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr(
IRB.CreateAdd(
IRB.CreatePtrToInt(VAListTag, MS.IntptrTy),
@@ -5836,6 +5723,10 @@ struct VarArgSystemZHelper : public VarArgHelper {
}
};
+// Loongarch64 is not a MIPS, but the current vargs calling convention matches
+// the MIPS.
+using VarArgLoongArch64Helper = VarArgMIPS64Helper;
+
/// A no-op implementation of VarArgHelper.
struct VarArgNoOpHelper : public VarArgHelper {
VarArgNoOpHelper(Function &F, MemorySanitizer &MS,
@@ -5868,6 +5759,8 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan,
return new VarArgPowerPC64Helper(Func, Msan, Visitor);
else if (TargetTriple.getArch() == Triple::systemz)
return new VarArgSystemZHelper(Func, Msan, Visitor);
+ else if (TargetTriple.isLoongArch64())
+ return new VarArgLoongArch64Helper(Func, Msan, Visitor);
else
return new VarArgNoOpHelper(Func, Msan, Visitor);
}
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 3c8f25d73c62..4a5a0b25bebb 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -327,7 +327,6 @@ extern cl::opt<PGOViewCountsType> PGOViewCounts;
// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
extern cl::opt<std::string> ViewBlockFreqFuncName;
-extern cl::opt<bool> DebugInfoCorrelate;
} // namespace llvm
static cl::opt<bool>
@@ -525,6 +524,7 @@ public:
std::vector<std::vector<VPCandidateInfo>> ValueSites;
SelectInstVisitor SIVisitor;
std::string FuncName;
+ std::string DeprecatedFuncName;
GlobalVariable *FuncNameVar;
// CFG hash value for this function.
@@ -582,21 +582,22 @@ public:
if (!IsCS) {
NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
NumOfPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size();
- NumOfPGOBB += MST.BBInfos.size();
+ NumOfPGOBB += MST.bbInfoSize();
ValueSites[IPVK_IndirectCallTarget] = VPC.get(IPVK_IndirectCallTarget);
} else {
NumOfCSPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
NumOfCSPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size();
- NumOfCSPGOBB += MST.BBInfos.size();
+ NumOfCSPGOBB += MST.bbInfoSize();
}
- FuncName = getPGOFuncName(F);
+ FuncName = getIRPGOFuncName(F);
+ DeprecatedFuncName = getPGOFuncName(F);
computeCFGHash();
if (!ComdatMembers.empty())
renameComdatFunction();
LLVM_DEBUG(dumpInfo("after CFGMST"));
- for (auto &E : MST.AllEdges) {
+ for (const auto &E : MST.allEdges()) {
if (E->Removed)
continue;
IsCS ? NumOfCSPGOEdge++ : NumOfPGOEdge++;
@@ -639,7 +640,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 |
(uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 |
//(uint64_t)ValueSites[IPVK_MemOPSize].size() << 40 |
- (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();
+ (uint64_t)MST.numEdges() << 32 | JC.getCRC();
} else {
// The higher 32 bits.
auto updateJCH = [&JCH](uint64_t Num) {
@@ -653,7 +654,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
if (BCI) {
updateJCH(BCI->getInstrumentedBlocksHash());
} else {
- updateJCH((uint64_t)MST.AllEdges.size());
+ updateJCH((uint64_t)MST.numEdges());
}
// Hash format for context sensitive profile. Reserve 4 bits for other
@@ -668,7 +669,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
LLVM_DEBUG(dbgs() << "Function Hash Computation for " << F.getName() << ":\n"
<< " CRC = " << JC.getCRC()
<< ", Selects = " << SIVisitor.getNumOfSelectInsts()
- << ", Edges = " << MST.AllEdges.size() << ", ICSites = "
+ << ", Edges = " << MST.numEdges() << ", ICSites = "
<< ValueSites[IPVK_IndirectCallTarget].size());
if (!PGOOldCFGHashing) {
LLVM_DEBUG(dbgs() << ", Memops = " << ValueSites[IPVK_MemOPSize].size()
@@ -756,8 +757,8 @@ void FuncPGOInstrumentation<Edge, BBInfo>::getInstrumentBBs(
// Use a worklist as we will update the vector during the iteration.
std::vector<Edge *> EdgeList;
- EdgeList.reserve(MST.AllEdges.size());
- for (auto &E : MST.AllEdges)
+ EdgeList.reserve(MST.numEdges());
+ for (const auto &E : MST.allEdges())
EdgeList.push_back(E.get());
for (auto &E : EdgeList) {
@@ -874,8 +875,7 @@ static void instrumentOneFunc(
F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry,
PGOBlockCoverage);
- Type *I8PtrTy = Type::getInt8PtrTy(M->getContext());
- auto Name = ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy);
+ auto Name = FuncInfo.FuncNameVar;
auto CFGHash = ConstantInt::get(Type::getInt64Ty(M->getContext()),
FuncInfo.FunctionHash);
if (PGOFunctionEntryCoverage) {
@@ -964,9 +964,8 @@ static void instrumentOneFunc(
populateEHOperandBundle(Cand, BlockColors, OpBundles);
Builder.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
- {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
- Builder.getInt64(FuncInfo.FunctionHash), ToProfile,
- Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)},
+ {FuncInfo.FuncNameVar, Builder.getInt64(FuncInfo.FunctionHash),
+ ToProfile, Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)},
OpBundles);
}
} // IPVK_First <= Kind <= IPVK_Last
@@ -1164,12 +1163,12 @@ private:
} // end anonymous namespace
/// Set up InEdges/OutEdges for all BBs in the MST.
-static void
-setupBBInfoEdges(FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> &FuncInfo) {
+static void setupBBInfoEdges(
+ const FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> &FuncInfo) {
// This is not required when there is block coverage inference.
if (FuncInfo.BCI)
return;
- for (auto &E : FuncInfo.MST.AllEdges) {
+ for (const auto &E : FuncInfo.MST.allEdges()) {
if (E->Removed)
continue;
const BasicBlock *SrcBB = E->SrcBB;
@@ -1225,7 +1224,7 @@ bool PGOUseFunc::setInstrumentedCounts(
// Set the profile count the Instrumented edges. There are BBs that not in
// MST but not instrumented. Need to set the edge count value so that we can
// populate the profile counts later.
- for (auto &E : FuncInfo.MST.AllEdges) {
+ for (const auto &E : FuncInfo.MST.allEdges()) {
if (E->Removed || E->InMST)
continue;
const BasicBlock *SrcBB = E->SrcBB;
@@ -1336,7 +1335,8 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros,
auto &Ctx = M->getContext();
uint64_t MismatchedFuncSum = 0;
Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord(
- FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum);
+ FuncInfo.FuncName, FuncInfo.FunctionHash, FuncInfo.DeprecatedFuncName,
+ &MismatchedFuncSum);
if (Error E = Result.takeError()) {
handleInstrProfError(std::move(E), MismatchedFuncSum);
return false;
@@ -1381,7 +1381,8 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros,
void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
uint64_t MismatchedFuncSum = 0;
Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord(
- FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum);
+ FuncInfo.FuncName, FuncInfo.FunctionHash, FuncInfo.DeprecatedFuncName,
+ &MismatchedFuncSum);
if (auto Err = Result.takeError()) {
handleInstrProfError(std::move(Err), MismatchedFuncSum);
return;
@@ -1436,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
// If A is uncovered, set weight=1.
// This setup will allow BFI to give nonzero profile counts to only covered
// blocks.
- SmallVector<unsigned, 4> Weights;
+ SmallVector<uint32_t, 4> Weights;
for (auto *Succ : successors(&BB))
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
if (Weights.size() >= 2)
- BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Weights));
+ llvm::setBranchWeights(*BB.getTerminator(), Weights);
}
unsigned NumCorruptCoverage = 0;
@@ -1647,12 +1647,10 @@ void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) {
Module *M = F.getParent();
IRBuilder<> Builder(&SI);
Type *Int64Ty = Builder.getInt64Ty();
- Type *I8PtrTy = Builder.getInt8PtrTy();
auto *Step = Builder.CreateZExt(SI.getCondition(), Int64Ty);
Builder.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step),
- {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
- Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs),
+ {FuncNameVar, Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs),
Builder.getInt32(*CurCtrIdx), Step});
++(*CurCtrIdx);
}
@@ -1757,17 +1755,10 @@ static void collectComdatMembers(
ComdatMembers.insert(std::make_pair(C, &GA));
}
-// Don't perform PGO instrumeatnion / profile-use.
-static bool skipPGO(const Function &F) {
+// Return true if we should not find instrumentation data for this function
+static bool skipPGOUse(const Function &F) {
if (F.isDeclaration())
return true;
- if (F.hasFnAttribute(llvm::Attribute::NoProfile))
- return true;
- if (F.hasFnAttribute(llvm::Attribute::SkipProfile))
- return true;
- if (F.getInstructionCount() < PGOFunctionSizeThreshold)
- return true;
-
// If there are too many critical edges, PGO might cause
// compiler time problem. Skip PGO if the number of
// critical edges execeed the threshold.
@@ -1785,7 +1776,19 @@ static bool skipPGO(const Function &F) {
<< " exceed the threshold. Skip PGO.\n");
return true;
}
+ return false;
+}
+// Return true if we should not instrument this function
+static bool skipPGOGen(const Function &F) {
+ if (skipPGOUse(F))
+ return true;
+ if (F.hasFnAttribute(llvm::Attribute::NoProfile))
+ return true;
+ if (F.hasFnAttribute(llvm::Attribute::SkipProfile))
+ return true;
+ if (F.getInstructionCount() < PGOFunctionSizeThreshold)
+ return true;
return false;
}
@@ -1801,7 +1804,7 @@ static bool InstrumentAllFunctions(
collectComdatMembers(M, ComdatMembers);
for (auto &F : M) {
- if (skipPGO(F))
+ if (skipPGOGen(F))
continue;
auto &TLI = LookupTLI(F);
auto *BPI = LookupBPI(F);
@@ -2028,7 +2031,7 @@ static bool annotateAllFunctions(
InstrumentFuncEntry = PGOInstrumentEntry;
bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
for (auto &F : M) {
- if (skipPGO(F))
+ if (skipPGOUse(F))
continue;
auto &TLI = LookupTLI(F);
auto *BPI = LookupBPI(F);
@@ -2201,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
void llvm::setProfMetadata(Module *M, Instruction *TI,
ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
- MDBuilder MDB(M->getContext());
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
SmallVector<unsigned, 4> Weights;
@@ -2215,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
- TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
if (EmitBranchProbability) {
std::string BrCondStr = getBranchCondString(TI);
if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
index 2906fe190984..fd0f69eca96e 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp
@@ -378,7 +378,7 @@ bool MemOPSizeOpt::perform(MemOp MO) {
assert(It != DefaultBB->end());
BasicBlock *MergeBB = SplitBlock(DefaultBB, &(*It), DT);
MergeBB->setName("MemOP.Merge");
- BFI.setBlockFreq(MergeBB, OrigBBFreq.getFrequency());
+ BFI.setBlockFreq(MergeBB, OrigBBFreq);
DefaultBB->setName("MemOP.Default");
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp
index d83a3a991c89..230bb8b0a5dc 100644
--- a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp
+++ b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp
@@ -198,17 +198,16 @@ bool SanitizerBinaryMetadata::run() {
// metadata features.
//
- auto *Int8PtrTy = IRB.getInt8PtrTy();
- auto *Int8PtrPtrTy = PointerType::getUnqual(Int8PtrTy);
+ auto *PtrTy = IRB.getPtrTy();
auto *Int32Ty = IRB.getInt32Ty();
- const std::array<Type *, 3> InitTypes = {Int32Ty, Int8PtrPtrTy, Int8PtrPtrTy};
+ const std::array<Type *, 3> InitTypes = {Int32Ty, PtrTy, PtrTy};
auto *Version = ConstantInt::get(Int32Ty, getVersion());
for (const MetadataInfo *MI : MIS) {
const std::array<Value *, InitTypes.size()> InitArgs = {
Version,
- getSectionMarker(getSectionStart(MI->SectionSuffix), Int8PtrTy),
- getSectionMarker(getSectionEnd(MI->SectionSuffix), Int8PtrTy),
+ getSectionMarker(getSectionStart(MI->SectionSuffix), PtrTy),
+ getSectionMarker(getSectionEnd(MI->SectionSuffix), PtrTy),
};
// We declare the _add and _del functions as weak, and only call them if
// there is a valid symbol linked. This allows building binaries with
@@ -306,11 +305,11 @@ bool isUARSafeCall(CallInst *CI) {
// It's safe to both pass pointers to local variables to them
// and to tail-call them.
return F && (F->isIntrinsic() || F->doesNotReturn() ||
- F->getName().startswith("__asan_") ||
- F->getName().startswith("__hwsan_") ||
- F->getName().startswith("__ubsan_") ||
- F->getName().startswith("__msan_") ||
- F->getName().startswith("__tsan_"));
+ F->getName().starts_with("__asan_") ||
+ F->getName().starts_with("__hwsan_") ||
+ F->getName().starts_with("__ubsan_") ||
+ F->getName().starts_with("__msan_") ||
+ F->getName().starts_with("__tsan_"));
}
bool hasUseAfterReturnUnsafeUses(Value &V) {
@@ -368,11 +367,11 @@ bool SanitizerBinaryMetadata::pretendAtomicAccess(const Value *Addr) {
const auto OF = Triple(Mod.getTargetTriple()).getObjectFormat();
const auto ProfSec =
getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false);
- if (GV->getSection().endswith(ProfSec))
+ if (GV->getSection().ends_with(ProfSec))
return true;
}
- if (GV->getName().startswith("__llvm_gcov") ||
- GV->getName().startswith("__llvm_gcda"))
+ if (GV->getName().starts_with("__llvm_gcov") ||
+ GV->getName().starts_with("__llvm_gcda"))
return true;
return false;
diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
index f22918141f6e..906687663519 100644
--- a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
+++ b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp
@@ -261,9 +261,7 @@ private:
FunctionCallee SanCovTraceGepFunction;
FunctionCallee SanCovTraceSwitchFunction;
GlobalVariable *SanCovLowestStack;
- Type *Int128PtrTy, *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty,
- *Int32PtrTy, *Int16PtrTy, *Int16Ty, *Int8Ty, *Int8PtrTy, *Int1Ty,
- *Int1PtrTy;
+ Type *PtrTy, *IntptrTy, *Int64Ty, *Int32Ty, *Int16Ty, *Int8Ty, *Int1Ty;
Module *CurModule;
std::string CurModuleUniqueId;
Triple TargetTriple;
@@ -331,11 +329,10 @@ ModuleSanitizerCoverage::CreateSecStartEnd(Module &M, const char *Section,
// Account for the fact that on windows-msvc __start_* symbols actually
// point to a uint64_t before the start of the array.
- auto SecStartI8Ptr = IRB.CreatePointerCast(SecStart, Int8PtrTy);
+ auto SecStartI8Ptr = IRB.CreatePointerCast(SecStart, PtrTy);
auto GEP = IRB.CreateGEP(Int8Ty, SecStartI8Ptr,
ConstantInt::get(IntptrTy, sizeof(uint64_t)));
- return std::make_pair(IRB.CreatePointerCast(GEP, PointerType::getUnqual(Ty)),
- SecEnd);
+ return std::make_pair(GEP, SecEnd);
}
Function *ModuleSanitizerCoverage::CreateInitCallsForSections(
@@ -345,7 +342,6 @@ Function *ModuleSanitizerCoverage::CreateInitCallsForSections(
auto SecStart = SecStartEnd.first;
auto SecEnd = SecStartEnd.second;
Function *CtorFunc;
- Type *PtrTy = PointerType::getUnqual(Ty);
std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions(
M, CtorName, InitFunctionName, {PtrTy, PtrTy}, {SecStart, SecEnd});
assert(CtorFunc->getName() == CtorName);
@@ -391,15 +387,9 @@ bool ModuleSanitizerCoverage::instrumentModule(
FunctionPCsArray = nullptr;
FunctionCFsArray = nullptr;
IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits());
- IntptrPtrTy = PointerType::getUnqual(IntptrTy);
+ PtrTy = PointerType::getUnqual(*C);
Type *VoidTy = Type::getVoidTy(*C);
IRBuilder<> IRB(*C);
- Int128PtrTy = PointerType::getUnqual(IRB.getInt128Ty());
- Int64PtrTy = PointerType::getUnqual(IRB.getInt64Ty());
- Int16PtrTy = PointerType::getUnqual(IRB.getInt16Ty());
- Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty());
- Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty());
- Int1PtrTy = PointerType::getUnqual(IRB.getInt1Ty());
Int64Ty = IRB.getInt64Ty();
Int32Ty = IRB.getInt32Ty();
Int16Ty = IRB.getInt16Ty();
@@ -438,26 +428,26 @@ bool ModuleSanitizerCoverage::instrumentModule(
M.getOrInsertFunction(SanCovTraceConstCmp8, VoidTy, Int64Ty, Int64Ty);
// Loads.
- SanCovLoadFunction[0] = M.getOrInsertFunction(SanCovLoad1, VoidTy, Int8PtrTy);
+ SanCovLoadFunction[0] = M.getOrInsertFunction(SanCovLoad1, VoidTy, PtrTy);
SanCovLoadFunction[1] =
- M.getOrInsertFunction(SanCovLoad2, VoidTy, Int16PtrTy);
+ M.getOrInsertFunction(SanCovLoad2, VoidTy, PtrTy);
SanCovLoadFunction[2] =
- M.getOrInsertFunction(SanCovLoad4, VoidTy, Int32PtrTy);
+ M.getOrInsertFunction(SanCovLoad4, VoidTy, PtrTy);
SanCovLoadFunction[3] =
- M.getOrInsertFunction(SanCovLoad8, VoidTy, Int64PtrTy);
+ M.getOrInsertFunction(SanCovLoad8, VoidTy, PtrTy);
SanCovLoadFunction[4] =
- M.getOrInsertFunction(SanCovLoad16, VoidTy, Int128PtrTy);
+ M.getOrInsertFunction(SanCovLoad16, VoidTy, PtrTy);
// Stores.
SanCovStoreFunction[0] =
- M.getOrInsertFunction(SanCovStore1, VoidTy, Int8PtrTy);
+ M.getOrInsertFunction(SanCovStore1, VoidTy, PtrTy);
SanCovStoreFunction[1] =
- M.getOrInsertFunction(SanCovStore2, VoidTy, Int16PtrTy);
+ M.getOrInsertFunction(SanCovStore2, VoidTy, PtrTy);
SanCovStoreFunction[2] =
- M.getOrInsertFunction(SanCovStore4, VoidTy, Int32PtrTy);
+ M.getOrInsertFunction(SanCovStore4, VoidTy, PtrTy);
SanCovStoreFunction[3] =
- M.getOrInsertFunction(SanCovStore8, VoidTy, Int64PtrTy);
+ M.getOrInsertFunction(SanCovStore8, VoidTy, PtrTy);
SanCovStoreFunction[4] =
- M.getOrInsertFunction(SanCovStore16, VoidTy, Int128PtrTy);
+ M.getOrInsertFunction(SanCovStore16, VoidTy, PtrTy);
{
AttributeList AL;
@@ -470,7 +460,7 @@ bool ModuleSanitizerCoverage::instrumentModule(
SanCovTraceGepFunction =
M.getOrInsertFunction(SanCovTraceGep, VoidTy, IntptrTy);
SanCovTraceSwitchFunction =
- M.getOrInsertFunction(SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy);
+ M.getOrInsertFunction(SanCovTraceSwitchName, VoidTy, Int64Ty, PtrTy);
Constant *SanCovLowestStackConstant =
M.getOrInsertGlobal(SanCovLowestStackName, IntptrTy);
@@ -487,7 +477,7 @@ bool ModuleSanitizerCoverage::instrumentModule(
SanCovTracePC = M.getOrInsertFunction(SanCovTracePCName, VoidTy);
SanCovTracePCGuard =
- M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, Int32PtrTy);
+ M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, PtrTy);
for (auto &F : M)
instrumentFunction(F, DTCallback, PDTCallback);
@@ -510,7 +500,7 @@ bool ModuleSanitizerCoverage::instrumentModule(
if (Ctor && Options.PCTable) {
auto SecStartEnd = CreateSecStartEnd(M, SanCovPCsSectionName, IntptrTy);
FunctionCallee InitFunction = declareSanitizerInitFunction(
- M, SanCovPCsInitName, {IntptrPtrTy, IntptrPtrTy});
+ M, SanCovPCsInitName, {PtrTy, PtrTy});
IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator());
IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second});
}
@@ -518,7 +508,7 @@ bool ModuleSanitizerCoverage::instrumentModule(
if (Ctor && Options.CollectControlFlow) {
auto SecStartEnd = CreateSecStartEnd(M, SanCovCFsSectionName, IntptrTy);
FunctionCallee InitFunction = declareSanitizerInitFunction(
- M, SanCovCFsInitName, {IntptrPtrTy, IntptrPtrTy});
+ M, SanCovCFsInitName, {PtrTy, PtrTy});
IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator());
IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second});
}
@@ -616,7 +606,7 @@ void ModuleSanitizerCoverage::instrumentFunction(
return;
if (F.getName().find(".module_ctor") != std::string::npos)
return; // Should not instrument sanitizer init functions.
- if (F.getName().startswith("__sanitizer_"))
+ if (F.getName().starts_with("__sanitizer_"))
return; // Don't instrument __sanitizer_* callbacks.
// Don't touch available_externally functions, their actual body is elewhere.
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage)
@@ -744,19 +734,19 @@ ModuleSanitizerCoverage::CreatePCArray(Function &F,
IRBuilder<> IRB(&*F.getEntryBlock().getFirstInsertionPt());
for (size_t i = 0; i < N; i++) {
if (&F.getEntryBlock() == AllBlocks[i]) {
- PCs.push_back((Constant *)IRB.CreatePointerCast(&F, IntptrPtrTy));
+ PCs.push_back((Constant *)IRB.CreatePointerCast(&F, PtrTy));
PCs.push_back((Constant *)IRB.CreateIntToPtr(
- ConstantInt::get(IntptrTy, 1), IntptrPtrTy));
+ ConstantInt::get(IntptrTy, 1), PtrTy));
} else {
PCs.push_back((Constant *)IRB.CreatePointerCast(
- BlockAddress::get(AllBlocks[i]), IntptrPtrTy));
- PCs.push_back(Constant::getNullValue(IntptrPtrTy));
+ BlockAddress::get(AllBlocks[i]), PtrTy));
+ PCs.push_back(Constant::getNullValue(PtrTy));
}
}
- auto *PCArray = CreateFunctionLocalArrayInSection(N * 2, F, IntptrPtrTy,
+ auto *PCArray = CreateFunctionLocalArrayInSection(N * 2, F, PtrTy,
SanCovPCsSectionName);
PCArray->setInitializer(
- ConstantArray::get(ArrayType::get(IntptrPtrTy, N * 2), PCs));
+ ConstantArray::get(ArrayType::get(PtrTy, N * 2), PCs));
PCArray->setConstant(true);
return PCArray;
@@ -833,10 +823,9 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch(
Int64Ty->getScalarSizeInBits())
Cond = IRB.CreateIntCast(Cond, Int64Ty, false);
for (auto It : SI->cases()) {
- Constant *C = It.getCaseValue();
- if (C->getType()->getScalarSizeInBits() <
- Int64Ty->getScalarSizeInBits())
- C = ConstantExpr::getCast(CastInst::ZExt, It.getCaseValue(), Int64Ty);
+ ConstantInt *C = It.getCaseValue();
+ if (C->getType()->getScalarSizeInBits() < 64)
+ C = ConstantInt::get(C->getContext(), C->getValue().zext(64));
Initializers.push_back(C);
}
llvm::sort(drop_begin(Initializers, 2),
@@ -850,7 +839,7 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch(
ConstantArray::get(ArrayOfInt64Ty, Initializers),
"__sancov_gen_cov_switch_values");
IRB.CreateCall(SanCovTraceSwitchFunction,
- {Cond, IRB.CreatePointerCast(GV, Int64PtrTy)});
+ {Cond, IRB.CreatePointerCast(GV, PtrTy)});
}
}
}
@@ -895,16 +884,13 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores(
: TypeSize == 128 ? 4
: -1;
};
- Type *PointerType[5] = {Int8PtrTy, Int16PtrTy, Int32PtrTy, Int64PtrTy,
- Int128PtrTy};
for (auto *LI : Loads) {
InstrumentationIRBuilder IRB(LI);
auto Ptr = LI->getPointerOperand();
int Idx = CallbackIdx(LI->getType());
if (Idx < 0)
continue;
- IRB.CreateCall(SanCovLoadFunction[Idx],
- IRB.CreatePointerCast(Ptr, PointerType[Idx]));
+ IRB.CreateCall(SanCovLoadFunction[Idx], Ptr);
}
for (auto *SI : Stores) {
InstrumentationIRBuilder IRB(SI);
@@ -912,8 +898,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores(
int Idx = CallbackIdx(SI->getValueOperand()->getType());
if (Idx < 0)
continue;
- IRB.CreateCall(SanCovStoreFunction[Idx],
- IRB.CreatePointerCast(Ptr, PointerType[Idx]));
+ IRB.CreateCall(SanCovStoreFunction[Idx], Ptr);
}
}
@@ -978,7 +963,7 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
auto GuardPtr = IRB.CreateIntToPtr(
IRB.CreateAdd(IRB.CreatePointerCast(FunctionGuardArray, IntptrTy),
ConstantInt::get(IntptrTy, Idx * 4)),
- Int32PtrTy);
+ PtrTy);
IRB.CreateCall(SanCovTracePCGuard, GuardPtr)->setCannotMerge();
}
if (Options.Inline8bitCounters) {
@@ -1008,7 +993,7 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB,
Module *M = F.getParent();
Function *GetFrameAddr = Intrinsic::getDeclaration(
M, Intrinsic::frameaddress,
- IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace()));
+ IRB.getPtrTy(M->getDataLayout().getAllocaAddrSpace()));
auto FrameAddrPtr =
IRB.CreateCall(GetFrameAddr, {Constant::getNullValue(Int32Ty)});
auto FrameAddrInt = IRB.CreatePtrToInt(FrameAddrPtr, IntptrTy);
@@ -1059,40 +1044,40 @@ void ModuleSanitizerCoverage::createFunctionControlFlow(Function &F) {
for (auto &BB : F) {
// blockaddress can not be used on function's entry block.
if (&BB == &F.getEntryBlock())
- CFs.push_back((Constant *)IRB.CreatePointerCast(&F, IntptrPtrTy));
+ CFs.push_back((Constant *)IRB.CreatePointerCast(&F, PtrTy));
else
CFs.push_back((Constant *)IRB.CreatePointerCast(BlockAddress::get(&BB),
- IntptrPtrTy));
+ PtrTy));
for (auto SuccBB : successors(&BB)) {
assert(SuccBB != &F.getEntryBlock());
CFs.push_back((Constant *)IRB.CreatePointerCast(BlockAddress::get(SuccBB),
- IntptrPtrTy));
+ PtrTy));
}
- CFs.push_back((Constant *)Constant::getNullValue(IntptrPtrTy));
+ CFs.push_back((Constant *)Constant::getNullValue(PtrTy));
for (auto &Inst : BB) {
if (CallBase *CB = dyn_cast<CallBase>(&Inst)) {
if (CB->isIndirectCall()) {
// TODO(navidem): handle indirect calls, for now mark its existence.
CFs.push_back((Constant *)IRB.CreateIntToPtr(
- ConstantInt::get(IntptrTy, -1), IntptrPtrTy));
+ ConstantInt::get(IntptrTy, -1), PtrTy));
} else {
auto CalledF = CB->getCalledFunction();
if (CalledF && !CalledF->isIntrinsic())
CFs.push_back(
- (Constant *)IRB.CreatePointerCast(CalledF, IntptrPtrTy));
+ (Constant *)IRB.CreatePointerCast(CalledF, PtrTy));
}
}
}
- CFs.push_back((Constant *)Constant::getNullValue(IntptrPtrTy));
+ CFs.push_back((Constant *)Constant::getNullValue(PtrTy));
}
FunctionCFsArray = CreateFunctionLocalArrayInSection(
- CFs.size(), F, IntptrPtrTy, SanCovCFsSectionName);
+ CFs.size(), F, PtrTy, SanCovCFsSectionName);
FunctionCFsArray->setInitializer(
- ConstantArray::get(ArrayType::get(IntptrPtrTy, CFs.size()), CFs));
+ ConstantArray::get(ArrayType::get(PtrTy, CFs.size()), CFs));
FunctionCFsArray->setConstant(true);
}
diff --git a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
index ce35eefb63fa..8ee0bca7e354 100644
--- a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp
@@ -205,7 +205,7 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) {
Attr = Attr.addFnAttribute(Ctx, Attribute::NoUnwind);
// Initialize the callbacks.
TsanFuncEntry = M.getOrInsertFunction("__tsan_func_entry", Attr,
- IRB.getVoidTy(), IRB.getInt8PtrTy());
+ IRB.getVoidTy(), IRB.getPtrTy());
TsanFuncExit =
M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy());
TsanIgnoreBegin = M.getOrInsertFunction("__tsan_ignore_thread_begin", Attr,
@@ -220,49 +220,49 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) {
std::string BitSizeStr = utostr(BitSize);
SmallString<32> ReadName("__tsan_read" + ByteSizeStr);
TsanRead[i] = M.getOrInsertFunction(ReadName, Attr, IRB.getVoidTy(),
- IRB.getInt8PtrTy());
+ IRB.getPtrTy());
SmallString<32> WriteName("__tsan_write" + ByteSizeStr);
TsanWrite[i] = M.getOrInsertFunction(WriteName, Attr, IRB.getVoidTy(),
- IRB.getInt8PtrTy());
+ IRB.getPtrTy());
SmallString<64> UnalignedReadName("__tsan_unaligned_read" + ByteSizeStr);
TsanUnalignedRead[i] = M.getOrInsertFunction(
- UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
SmallString<64> UnalignedWriteName("__tsan_unaligned_write" + ByteSizeStr);
TsanUnalignedWrite[i] = M.getOrInsertFunction(
- UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
SmallString<64> VolatileReadName("__tsan_volatile_read" + ByteSizeStr);
TsanVolatileRead[i] = M.getOrInsertFunction(
- VolatileReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ VolatileReadName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
SmallString<64> VolatileWriteName("__tsan_volatile_write" + ByteSizeStr);
TsanVolatileWrite[i] = M.getOrInsertFunction(
- VolatileWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ VolatileWriteName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
SmallString<64> UnalignedVolatileReadName("__tsan_unaligned_volatile_read" +
ByteSizeStr);
TsanUnalignedVolatileRead[i] = M.getOrInsertFunction(
- UnalignedVolatileReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ UnalignedVolatileReadName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
SmallString<64> UnalignedVolatileWriteName(
"__tsan_unaligned_volatile_write" + ByteSizeStr);
TsanUnalignedVolatileWrite[i] = M.getOrInsertFunction(
- UnalignedVolatileWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ UnalignedVolatileWriteName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
SmallString<64> CompoundRWName("__tsan_read_write" + ByteSizeStr);
TsanCompoundRW[i] = M.getOrInsertFunction(
- CompoundRWName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ CompoundRWName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
SmallString<64> UnalignedCompoundRWName("__tsan_unaligned_read_write" +
ByteSizeStr);
TsanUnalignedCompoundRW[i] = M.getOrInsertFunction(
- UnalignedCompoundRWName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy());
+ UnalignedCompoundRWName, Attr, IRB.getVoidTy(), IRB.getPtrTy());
Type *Ty = Type::getIntNTy(Ctx, BitSize);
- Type *PtrTy = Ty->getPointerTo();
+ Type *PtrTy = PointerType::get(Ctx, 0);
SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load");
TsanAtomicLoad[i] =
M.getOrInsertFunction(AtomicLoadName,
@@ -318,9 +318,9 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) {
}
TsanVptrUpdate =
M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy());
+ IRB.getPtrTy(), IRB.getPtrTy());
TsanVptrLoad = M.getOrInsertFunction("__tsan_vptr_read", Attr,
- IRB.getVoidTy(), IRB.getInt8PtrTy());
+ IRB.getVoidTy(), IRB.getPtrTy());
TsanAtomicThreadFence = M.getOrInsertFunction(
"__tsan_atomic_thread_fence",
TLI.getAttrList(&Ctx, {0}, /*Signed=*/true, /*Ret=*/false, Attr),
@@ -332,15 +332,15 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) {
IRB.getVoidTy(), OrdTy);
MemmoveFn =
- M.getOrInsertFunction("__tsan_memmove", Attr, IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy);
+ M.getOrInsertFunction("__tsan_memmove", Attr, IRB.getPtrTy(),
+ IRB.getPtrTy(), IRB.getPtrTy(), IntptrTy);
MemcpyFn =
- M.getOrInsertFunction("__tsan_memcpy", Attr, IRB.getInt8PtrTy(),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy);
+ M.getOrInsertFunction("__tsan_memcpy", Attr, IRB.getPtrTy(),
+ IRB.getPtrTy(), IRB.getPtrTy(), IntptrTy);
MemsetFn = M.getOrInsertFunction(
"__tsan_memset",
TLI.getAttrList(&Ctx, {1}, /*Signed=*/true, /*Ret=*/false, Attr),
- IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy);
+ IRB.getPtrTy(), IRB.getPtrTy(), IRB.getInt32Ty(), IntptrTy);
}
static bool isVtableAccess(Instruction *I) {
@@ -360,15 +360,10 @@ static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
StringRef SectionName = GV->getSection();
// Check if the global is in the PGO counters section.
auto OF = Triple(M->getTargetTriple()).getObjectFormat();
- if (SectionName.endswith(
+ if (SectionName.ends_with(
getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
return false;
}
-
- // Check if the global is private gcov data.
- if (GV->getName().startswith("__llvm_gcov") ||
- GV->getName().startswith("__llvm_gcda"))
- return false;
}
// Do not instrument accesses from different address spaces; we cannot deal
@@ -522,6 +517,9 @@ bool ThreadSanitizer::sanitizeFunction(Function &F,
// Traverse all instructions, collect loads/stores/returns, check for calls.
for (auto &BB : F) {
for (auto &Inst : BB) {
+ // Skip instructions inserted by another instrumentation.
+ if (Inst.hasMetadata(LLVMContext::MD_nosanitize))
+ continue;
if (isTsanAtomic(&Inst))
AtomicAccesses.push_back(&Inst);
else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst))
@@ -613,17 +611,14 @@ bool ThreadSanitizer::instrumentLoadOrStore(const InstructionInfo &II,
StoredValue = IRB.CreateExtractElement(
StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
if (StoredValue->getType()->isIntegerTy())
- StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
+ StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getPtrTy());
// Call TsanVptrUpdate.
- IRB.CreateCall(TsanVptrUpdate,
- {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
+ IRB.CreateCall(TsanVptrUpdate, {Addr, StoredValue});
NumInstrumentedVtableWrites++;
return true;
}
if (!IsWrite && isVtableAccess(II.Inst)) {
- IRB.CreateCall(TsanVptrLoad,
- IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
+ IRB.CreateCall(TsanVptrLoad, Addr);
NumInstrumentedVtableReads++;
return true;
}
@@ -655,7 +650,7 @@ bool ThreadSanitizer::instrumentLoadOrStore(const InstructionInfo &II,
else
OnAccessFunc = IsWrite ? TsanUnalignedWrite[Idx] : TsanUnalignedRead[Idx];
}
- IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
+ IRB.CreateCall(OnAccessFunc, Addr);
if (IsCompoundRW || IsWrite)
NumInstrumentedWrites++;
if (IsCompoundRW || !IsWrite)
@@ -691,17 +686,19 @@ static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) {
bool ThreadSanitizer::instrumentMemIntrinsic(Instruction *I) {
InstrumentationIRBuilder IRB(I);
if (MemSetInst *M = dyn_cast<MemSetInst>(I)) {
+ Value *Cast1 = IRB.CreateIntCast(M->getArgOperand(1), IRB.getInt32Ty(), false);
+ Value *Cast2 = IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false);
IRB.CreateCall(
MemsetFn,
- {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()),
- IRB.CreateIntCast(M->getArgOperand(1), IRB.getInt32Ty(), false),
- IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false)});
+ {M->getArgOperand(0),
+ Cast1,
+ Cast2});
I->eraseFromParent();
} else if (MemTransferInst *M = dyn_cast<MemTransferInst>(I)) {
IRB.CreateCall(
isa<MemCpyInst>(M) ? MemcpyFn : MemmoveFn,
- {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()),
- IRB.CreatePointerCast(M->getArgOperand(1), IRB.getInt8PtrTy()),
+ {M->getArgOperand(0),
+ M->getArgOperand(1),
IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false)});
I->eraseFromParent();
}
@@ -724,11 +721,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) {
int Idx = getMemoryAccessFuncIndex(OrigTy, Addr, DL);
if (Idx < 0)
return false;
- const unsigned ByteSize = 1U << Idx;
- const unsigned BitSize = ByteSize * 8;
- Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
- Type *PtrTy = Ty->getPointerTo();
- Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
+ Value *Args[] = {Addr,
createOrdering(&IRB, LI->getOrdering())};
Value *C = IRB.CreateCall(TsanAtomicLoad[Idx], Args);
Value *Cast = IRB.CreateBitOrPointerCast(C, OrigTy);
@@ -742,8 +735,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) {
const unsigned ByteSize = 1U << Idx;
const unsigned BitSize = ByteSize * 8;
Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
- Type *PtrTy = Ty->getPointerTo();
- Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
+ Value *Args[] = {Addr,
IRB.CreateBitOrPointerCast(SI->getValueOperand(), Ty),
createOrdering(&IRB, SI->getOrdering())};
CallInst *C = CallInst::Create(TsanAtomicStore[Idx], Args);
@@ -760,8 +752,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) {
const unsigned ByteSize = 1U << Idx;
const unsigned BitSize = ByteSize * 8;
Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
- Type *PtrTy = Ty->getPointerTo();
- Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
+ Value *Args[] = {Addr,
IRB.CreateIntCast(RMWI->getValOperand(), Ty, false),
createOrdering(&IRB, RMWI->getOrdering())};
CallInst *C = CallInst::Create(F, Args);
@@ -775,12 +766,11 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) {
const unsigned ByteSize = 1U << Idx;
const unsigned BitSize = ByteSize * 8;
Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
- Type *PtrTy = Ty->getPointerTo();
Value *CmpOperand =
IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
Value *NewOperand =
IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
- Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
+ Value *Args[] = {Addr,
CmpOperand,
NewOperand,
createOrdering(&IRB, CASI->getSuccessOrdering()),
diff --git a/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h b/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h
index dd6a1c3f9795..7732eeb4b9c8 100644
--- a/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h
+++ b/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h
@@ -22,7 +22,6 @@
#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_DEPENDENCYANALYSIS_H
#define LLVM_LIB_TRANSFORMS_OBJCARC_DEPENDENCYANALYSIS_H
-#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/ObjCARCInstKind.h"
namespace llvm {
diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
index adf86526ebf1..b51e4d46bffe 100644
--- a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
+++ b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp
@@ -933,7 +933,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst,
if (IsNullOrUndef(CI->getArgOperand(0))) {
Changed = true;
new StoreInst(ConstantInt::getTrue(CI->getContext()),
- PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI);
+ PoisonValue::get(PointerType::getUnqual(CI->getContext())),
+ CI);
Value *NewValue = PoisonValue::get(CI->getType());
LLVM_DEBUG(
dbgs() << "A null pointer-to-weak-pointer is undefined behavior."
@@ -952,7 +953,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst,
IsNullOrUndef(CI->getArgOperand(1))) {
Changed = true;
new StoreInst(ConstantInt::getTrue(CI->getContext()),
- PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI);
+ PoisonValue::get(PointerType::getUnqual(CI->getContext())),
+ CI);
Value *NewValue = PoisonValue::get(CI->getType());
LLVM_DEBUG(
diff --git a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp
index 9f15772f2fa1..e563ecfb1622 100644
--- a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp
+++ b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp
@@ -19,7 +19,7 @@ using namespace llvm::objcarc;
static StringRef getName(Value *V) {
StringRef Name = V->getName();
- if (Name.startswith("\1"))
+ if (Name.starts_with("\1"))
return Name.substr(1);
return Name;
}
diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp
index 24354211341f..9af275a9f4e2 100644
--- a/llvm/lib/Transforms/Scalar/ADCE.cpp
+++ b/llvm/lib/Transforms/Scalar/ADCE.cpp
@@ -544,6 +544,16 @@ ADCEChanged AggressiveDeadCodeElimination::removeDeadInstructions() {
// value of the function, and may therefore be deleted safely.
// NOTE: We reuse the Worklist vector here for memory efficiency.
for (Instruction &I : llvm::reverse(instructions(F))) {
+ // With "RemoveDIs" debug-info stored in DPValue objects, debug-info
+ // attached to this instruction, and drop any for scopes that aren't alive,
+ // like the rest of this loop does. Extending support to assignment tracking
+ // is future work.
+ for (DPValue &DPV : make_early_inc_range(I.getDbgValueRange())) {
+ if (AliveScopes.count(DPV.getDebugLoc()->getScope()))
+ continue;
+ I.dropOneDbgValue(&DPV);
+ }
+
// Check if the instruction is alive.
if (isLive(&I))
continue;
diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
index b259c76fc3a5..f3422a705dca 100644
--- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -83,11 +83,7 @@ static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
const SCEV *OffSCEV, Value *Ptr,
ScalarEvolution *SE) {
const SCEV *PtrSCEV = SE->getSCEV(Ptr);
- // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
- // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
- // may disagree. Trunc/extend so they agree.
- PtrSCEV = SE->getTruncateOrZeroExtend(
- PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
+
const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
if (isa<SCEVCouldNotCompute>(DiffSCEV))
return Align(1);
@@ -179,6 +175,9 @@ bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
// Added to suppress a crash because consumer doesn't expect non-constant
// alignments in the assume bundle. TODO: Consider generalizing caller.
return false;
+ if (!cast<SCEVConstant>(AlignSCEV)->getAPInt().isPowerOf2())
+ // Only power of two alignments are supported.
+ return false;
if (AlignOB.Inputs.size() == 3)
OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
else
@@ -264,11 +263,17 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
// Now that we've updated that use of the pointer, look for other uses of
// the pointer to update.
Visited.insert(J);
- for (User *UJ : J->users()) {
- Instruction *K = cast<Instruction>(UJ);
- if (!Visited.count(K))
- WorkList.push_back(K);
- }
+ if (isa<GetElementPtrInst>(J) || isa<PHINode>(J))
+ for (auto &U : J->uses()) {
+ if (U->getType()->isPointerTy()) {
+ Instruction *K = cast<Instruction>(U.getUser());
+ StoreInst *SI = dyn_cast<StoreInst>(K);
+ if (SI && SI->getPointerOperandIndex() != U.getOperandNo())
+ continue;
+ if (!Visited.count(K))
+ WorkList.push_back(K);
+ }
+ }
}
return true;
diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
index aeb7c5d461f0..47f663fa0cf0 100644
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -62,10 +62,8 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -374,10 +372,10 @@ static void splitCallSite(CallBase &CB,
return;
}
- auto *OriginalBegin = &*TailBB->begin();
+ BasicBlock::iterator OriginalBegin = TailBB->begin();
// Replace users of the original call with a PHI mering call-sites split.
if (CallPN) {
- CallPN->insertBefore(OriginalBegin);
+ CallPN->insertBefore(*TailBB, OriginalBegin);
CB.replaceAllUsesWith(CallPN);
}
@@ -389,6 +387,7 @@ static void splitCallSite(CallBase &CB,
// do not introduce unnecessary PHI nodes for def-use chains from the call
// instruction to the beginning of the block.
auto I = CB.getReverseIterator();
+ Instruction *OriginalBeginInst = &*OriginalBegin;
while (I != TailBB->rend()) {
Instruction *CurrentI = &*I++;
if (!CurrentI->use_empty()) {
@@ -401,12 +400,13 @@ static void splitCallSite(CallBase &CB,
for (auto &Mapping : ValueToValueMaps)
NewPN->addIncoming(Mapping[CurrentI],
cast<Instruction>(Mapping[CurrentI])->getParent());
- NewPN->insertBefore(&*TailBB->begin());
+ NewPN->insertBefore(*TailBB, TailBB->begin());
CurrentI->replaceAllUsesWith(NewPN);
}
+ CurrentI->dropDbgValues();
CurrentI->eraseFromParent();
// We are done once we handled the first original instruction in TailBB.
- if (CurrentI == OriginalBegin)
+ if (CurrentI == OriginalBeginInst)
break;
}
}
diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
index 611e64bd0976..3e5d979f11cc 100644
--- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -761,11 +761,9 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base,
if (Adj->Offset) {
if (Adj->Ty) {
// Constant being rebased is a ConstantExpr.
- PointerType *Int8PtrTy = Type::getInt8PtrTy(
- *Ctx, cast<PointerType>(Adj->Ty)->getAddressSpace());
- Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", Adj->MatInsertPt);
Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Adj->Offset,
"mat_gep", Adj->MatInsertPt);
+ // Hide it behind a bitcast.
Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", Adj->MatInsertPt);
} else
// Constant being rebased is a ConstantInt.
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 15628d32280d..a6fbddca5cba 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -18,13 +18,17 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstraintSystem.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Verifier.h"
@@ -83,32 +87,69 @@ static Instruction *getContextInstForUse(Use &U) {
}
namespace {
+/// Struct to express a condition of the form %Op0 Pred %Op1.
+struct ConditionTy {
+ CmpInst::Predicate Pred;
+ Value *Op0;
+ Value *Op1;
+
+ ConditionTy()
+ : Pred(CmpInst::BAD_ICMP_PREDICATE), Op0(nullptr), Op1(nullptr) {}
+ ConditionTy(CmpInst::Predicate Pred, Value *Op0, Value *Op1)
+ : Pred(Pred), Op0(Op0), Op1(Op1) {}
+};
+
/// Represents either
-/// * a condition that holds on entry to a block (=conditional fact)
+/// * a condition that holds on entry to a block (=condition fact)
/// * an assume (=assume fact)
/// * a use of a compare instruction to simplify.
/// It also tracks the Dominator DFS in and out numbers for each entry.
struct FactOrCheck {
+ enum class EntryTy {
+ ConditionFact, /// A condition that holds on entry to a block.
+ InstFact, /// A fact that holds after Inst executed (e.g. an assume or
+ /// min/max intrinsic).
+ InstCheck, /// An instruction to simplify (e.g. an overflow math
+ /// intrinsic).
+ UseCheck /// A use of a compare instruction to simplify.
+ };
+
union {
Instruction *Inst;
Use *U;
+ ConditionTy Cond;
};
+
+ /// A pre-condition that must hold for the current fact to be added to the
+ /// system.
+ ConditionTy DoesHold;
+
unsigned NumIn;
unsigned NumOut;
- bool HasInst;
- bool Not;
+ EntryTy Ty;
- FactOrCheck(DomTreeNode *DTN, Instruction *Inst, bool Not)
+ FactOrCheck(EntryTy Ty, DomTreeNode *DTN, Instruction *Inst)
: Inst(Inst), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()),
- HasInst(true), Not(Not) {}
+ Ty(Ty) {}
FactOrCheck(DomTreeNode *DTN, Use *U)
- : U(U), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()),
- HasInst(false), Not(false) {}
+ : U(U), DoesHold(CmpInst::BAD_ICMP_PREDICATE, nullptr, nullptr),
+ NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()),
+ Ty(EntryTy::UseCheck) {}
+
+ FactOrCheck(DomTreeNode *DTN, CmpInst::Predicate Pred, Value *Op0, Value *Op1,
+ ConditionTy Precond = ConditionTy())
+ : Cond(Pred, Op0, Op1), DoesHold(Precond), NumIn(DTN->getDFSNumIn()),
+ NumOut(DTN->getDFSNumOut()), Ty(EntryTy::ConditionFact) {}
- static FactOrCheck getFact(DomTreeNode *DTN, Instruction *Inst,
- bool Not = false) {
- return FactOrCheck(DTN, Inst, Not);
+ static FactOrCheck getConditionFact(DomTreeNode *DTN, CmpInst::Predicate Pred,
+ Value *Op0, Value *Op1,
+ ConditionTy Precond = ConditionTy()) {
+ return FactOrCheck(DTN, Pred, Op0, Op1, Precond);
+ }
+
+ static FactOrCheck getInstFact(DomTreeNode *DTN, Instruction *Inst) {
+ return FactOrCheck(EntryTy::InstFact, DTN, Inst);
}
static FactOrCheck getCheck(DomTreeNode *DTN, Use *U) {
@@ -116,39 +157,47 @@ struct FactOrCheck {
}
static FactOrCheck getCheck(DomTreeNode *DTN, CallInst *CI) {
- return FactOrCheck(DTN, CI, false);
+ return FactOrCheck(EntryTy::InstCheck, DTN, CI);
}
bool isCheck() const {
- return !HasInst ||
- match(Inst, m_Intrinsic<Intrinsic::ssub_with_overflow>());
+ return Ty == EntryTy::InstCheck || Ty == EntryTy::UseCheck;
}
Instruction *getContextInst() const {
- if (HasInst)
- return Inst;
- return getContextInstForUse(*U);
+ if (Ty == EntryTy::UseCheck)
+ return getContextInstForUse(*U);
+ return Inst;
}
+
Instruction *getInstructionToSimplify() const {
assert(isCheck());
- if (HasInst)
+ if (Ty == EntryTy::InstCheck)
return Inst;
// The use may have been simplified to a constant already.
return dyn_cast<Instruction>(*U);
}
- bool isConditionFact() const { return !isCheck() && isa<CmpInst>(Inst); }
+
+ bool isConditionFact() const { return Ty == EntryTy::ConditionFact; }
};
/// Keep state required to build worklist.
struct State {
DominatorTree &DT;
+ LoopInfo &LI;
+ ScalarEvolution &SE;
SmallVector<FactOrCheck, 64> WorkList;
- State(DominatorTree &DT) : DT(DT) {}
+ State(DominatorTree &DT, LoopInfo &LI, ScalarEvolution &SE)
+ : DT(DT), LI(LI), SE(SE) {}
/// Process block \p BB and add known facts to work-list.
void addInfoFor(BasicBlock &BB);
+ /// Try to add facts for loop inductions (AddRecs) in EQ/NE compares
+ /// controlling the loop header.
+ void addInfoForInductions(BasicBlock &BB);
+
/// Returns true if we can add a known condition from BB to its successor
/// block Succ.
bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const {
@@ -172,19 +221,9 @@ struct StackEntry {
ValuesToRelease(ValuesToRelease) {}
};
-/// Struct to express a pre-condition of the form %Op0 Pred %Op1.
-struct PreconditionTy {
- CmpInst::Predicate Pred;
- Value *Op0;
- Value *Op1;
-
- PreconditionTy(CmpInst::Predicate Pred, Value *Op0, Value *Op1)
- : Pred(Pred), Op0(Op0), Op1(Op1) {}
-};
-
struct ConstraintTy {
SmallVector<int64_t, 8> Coefficients;
- SmallVector<PreconditionTy, 2> Preconditions;
+ SmallVector<ConditionTy, 2> Preconditions;
SmallVector<SmallVector<int64_t, 8>> ExtraInfo;
@@ -327,10 +366,57 @@ struct Decomposition {
}
};
+// Variable and constant offsets for a chain of GEPs, with base pointer BasePtr.
+struct OffsetResult {
+ Value *BasePtr;
+ APInt ConstantOffset;
+ MapVector<Value *, APInt> VariableOffsets;
+ bool AllInbounds;
+
+ OffsetResult() : BasePtr(nullptr), ConstantOffset(0, uint64_t(0)) {}
+
+ OffsetResult(GEPOperator &GEP, const DataLayout &DL)
+ : BasePtr(GEP.getPointerOperand()), AllInbounds(GEP.isInBounds()) {
+ ConstantOffset = APInt(DL.getIndexTypeSizeInBits(BasePtr->getType()), 0);
+ }
+};
} // namespace
+// Try to collect variable and constant offsets for \p GEP, partly traversing
+// nested GEPs. Returns an OffsetResult with nullptr as BasePtr if collecting
+// the offset fails.
+static OffsetResult collectOffsets(GEPOperator &GEP, const DataLayout &DL) {
+ OffsetResult Result(GEP, DL);
+ unsigned BitWidth = Result.ConstantOffset.getBitWidth();
+ if (!GEP.collectOffset(DL, BitWidth, Result.VariableOffsets,
+ Result.ConstantOffset))
+ return {};
+
+ // If we have a nested GEP, check if we can combine the constant offset of the
+ // inner GEP with the outer GEP.
+ if (auto *InnerGEP = dyn_cast<GetElementPtrInst>(Result.BasePtr)) {
+ MapVector<Value *, APInt> VariableOffsets2;
+ APInt ConstantOffset2(BitWidth, 0);
+ bool CanCollectInner = InnerGEP->collectOffset(
+ DL, BitWidth, VariableOffsets2, ConstantOffset2);
+ // TODO: Support cases with more than 1 variable offset.
+ if (!CanCollectInner || Result.VariableOffsets.size() > 1 ||
+ VariableOffsets2.size() > 1 ||
+ (Result.VariableOffsets.size() >= 1 && VariableOffsets2.size() >= 1)) {
+ // More than 1 variable index, use outer result.
+ return Result;
+ }
+ Result.BasePtr = InnerGEP->getPointerOperand();
+ Result.ConstantOffset += ConstantOffset2;
+ if (Result.VariableOffsets.size() == 0 && VariableOffsets2.size() == 1)
+ Result.VariableOffsets = VariableOffsets2;
+ Result.AllInbounds &= InnerGEP->isInBounds();
+ }
+ return Result;
+}
+
static Decomposition decompose(Value *V,
- SmallVectorImpl<PreconditionTy> &Preconditions,
+ SmallVectorImpl<ConditionTy> &Preconditions,
bool IsSigned, const DataLayout &DL);
static bool canUseSExt(ConstantInt *CI) {
@@ -338,51 +424,22 @@ static bool canUseSExt(ConstantInt *CI) {
return Val.sgt(MinSignedConstraintValue) && Val.slt(MaxConstraintValue);
}
-static Decomposition
-decomposeGEP(GEPOperator &GEP, SmallVectorImpl<PreconditionTy> &Preconditions,
- bool IsSigned, const DataLayout &DL) {
+static Decomposition decomposeGEP(GEPOperator &GEP,
+ SmallVectorImpl<ConditionTy> &Preconditions,
+ bool IsSigned, const DataLayout &DL) {
// Do not reason about pointers where the index size is larger than 64 bits,
// as the coefficients used to encode constraints are 64 bit integers.
if (DL.getIndexTypeSizeInBits(GEP.getPointerOperand()->getType()) > 64)
return &GEP;
- if (!GEP.isInBounds())
- return &GEP;
-
assert(!IsSigned && "The logic below only supports decomposition for "
"unsinged predicates at the moment.");
- Type *PtrTy = GEP.getType()->getScalarType();
- unsigned BitWidth = DL.getIndexTypeSizeInBits(PtrTy);
- MapVector<Value *, APInt> VariableOffsets;
- APInt ConstantOffset(BitWidth, 0);
- if (!GEP.collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
+ const auto &[BasePtr, ConstantOffset, VariableOffsets, AllInbounds] =
+ collectOffsets(GEP, DL);
+ if (!BasePtr || !AllInbounds)
return &GEP;
- // Handle the (gep (gep ....), C) case by incrementing the constant
- // coefficient of the inner GEP, if C is a constant.
- auto *InnerGEP = dyn_cast<GEPOperator>(GEP.getPointerOperand());
- if (VariableOffsets.empty() && InnerGEP && InnerGEP->getNumOperands() == 2) {
- auto Result = decompose(InnerGEP, Preconditions, IsSigned, DL);
- Result.add(ConstantOffset.getSExtValue());
-
- if (ConstantOffset.isNegative()) {
- unsigned Scale = DL.getTypeAllocSize(InnerGEP->getResultElementType());
- int64_t ConstantOffsetI = ConstantOffset.getSExtValue();
- if (ConstantOffsetI % Scale != 0)
- return &GEP;
- // Add pre-condition ensuring the GEP is increasing monotonically and
- // can be de-composed.
- // Both sides are normalized by being divided by Scale.
- Preconditions.emplace_back(
- CmpInst::ICMP_SGE, InnerGEP->getOperand(1),
- ConstantInt::get(InnerGEP->getOperand(1)->getType(),
- -1 * (ConstantOffsetI / Scale)));
- }
- return Result;
- }
-
- Decomposition Result(ConstantOffset.getSExtValue(),
- DecompEntry(1, GEP.getPointerOperand()));
+ Decomposition Result(ConstantOffset.getSExtValue(), DecompEntry(1, BasePtr));
for (auto [Index, Scale] : VariableOffsets) {
auto IdxResult = decompose(Index, Preconditions, IsSigned, DL);
IdxResult.mul(Scale.getSExtValue());
@@ -401,7 +458,7 @@ decomposeGEP(GEPOperator &GEP, SmallVectorImpl<PreconditionTy> &Preconditions,
// Variable } where Coefficient * Variable. The sum of the constant offset and
// pairs equals \p V.
static Decomposition decompose(Value *V,
- SmallVectorImpl<PreconditionTy> &Preconditions,
+ SmallVectorImpl<ConditionTy> &Preconditions,
bool IsSigned, const DataLayout &DL) {
auto MergeResults = [&Preconditions, IsSigned, &DL](Value *A, Value *B,
@@ -412,6 +469,22 @@ static Decomposition decompose(Value *V,
return ResA;
};
+ Type *Ty = V->getType()->getScalarType();
+ if (Ty->isPointerTy() && !IsSigned) {
+ if (auto *GEP = dyn_cast<GEPOperator>(V))
+ return decomposeGEP(*GEP, Preconditions, IsSigned, DL);
+ if (isa<ConstantPointerNull>(V))
+ return int64_t(0);
+
+ return V;
+ }
+
+ // Don't handle integers > 64 bit. Our coefficients are 64-bit large, so
+ // coefficient add/mul may wrap, while the operation in the full bit width
+ // would not.
+ if (!Ty->isIntegerTy() || Ty->getIntegerBitWidth() > 64)
+ return V;
+
// Decompose \p V used with a signed predicate.
if (IsSigned) {
if (auto *CI = dyn_cast<ConstantInt>(V)) {
@@ -424,7 +497,7 @@ static Decomposition decompose(Value *V,
return MergeResults(Op0, Op1, IsSigned);
ConstantInt *CI;
- if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI)))) {
+ if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI))) && canUseSExt(CI)) {
auto Result = decompose(Op0, Preconditions, IsSigned, DL);
Result.mul(CI->getSExtValue());
return Result;
@@ -439,9 +512,6 @@ static Decomposition decompose(Value *V,
return int64_t(CI->getZExtValue());
}
- if (auto *GEP = dyn_cast<GEPOperator>(V))
- return decomposeGEP(*GEP, Preconditions, IsSigned, DL);
-
Value *Op0;
bool IsKnownNonNegative = false;
if (match(V, m_ZExt(m_Value(Op0)))) {
@@ -474,10 +544,8 @@ static Decomposition decompose(Value *V,
}
// Decompose or as an add if there are no common bits between the operands.
- if (match(V, m_Or(m_Value(Op0), m_ConstantInt(CI))) &&
- haveNoCommonBitsSet(Op0, CI, DL)) {
+ if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI))))
return MergeResults(Op0, CI, IsSigned);
- }
if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) {
if (CI->getSExtValue() < 0 || CI->getSExtValue() >= 64)
@@ -544,7 +612,7 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
Pred != CmpInst::ICMP_SLE && Pred != CmpInst::ICMP_SLT)
return {};
- SmallVector<PreconditionTy, 4> Preconditions;
+ SmallVector<ConditionTy, 4> Preconditions;
bool IsSigned = CmpInst::isSigned(Pred);
auto &Value2Index = getValue2Index(IsSigned);
auto ADec = decompose(Op0->stripPointerCastsSameRepresentation(),
@@ -637,6 +705,17 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
ConstraintTy ConstraintInfo::getConstraintForSolving(CmpInst::Predicate Pred,
Value *Op0,
Value *Op1) const {
+ Constant *NullC = Constant::getNullValue(Op0->getType());
+ // Handle trivially true compares directly to avoid adding V UGE 0 constraints
+ // for all variables in the unsigned system.
+ if ((Pred == CmpInst::ICMP_ULE && Op0 == NullC) ||
+ (Pred == CmpInst::ICMP_UGE && Op1 == NullC)) {
+ auto &Value2Index = getValue2Index(false);
+ // Return constraint that's trivially true.
+ return ConstraintTy(SmallVector<int64_t, 8>(Value2Index.size(), 0), false,
+ false, false);
+ }
+
// If both operands are known to be non-negative, change signed predicates to
// unsigned ones. This increases the reasoning effectiveness in combination
// with the signed <-> unsigned transfer logic.
@@ -654,7 +733,7 @@ ConstraintTy ConstraintInfo::getConstraintForSolving(CmpInst::Predicate Pred,
bool ConstraintTy::isValid(const ConstraintInfo &Info) const {
return Coefficients.size() > 0 &&
- all_of(Preconditions, [&Info](const PreconditionTy &C) {
+ all_of(Preconditions, [&Info](const ConditionTy &C) {
return Info.doesHold(C.Pred, C.Op0, C.Op1);
});
}
@@ -713,6 +792,10 @@ bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A,
void ConstraintInfo::transferToOtherSystem(
CmpInst::Predicate Pred, Value *A, Value *B, unsigned NumIn,
unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack) {
+ auto IsKnownNonNegative = [this](Value *V) {
+ return doesHold(CmpInst::ICMP_SGE, V, ConstantInt::get(V->getType(), 0)) ||
+ isKnownNonNegative(V, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1);
+ };
// Check if we can combine facts from the signed and unsigned systems to
// derive additional facts.
if (!A->getType()->isIntegerTy())
@@ -724,30 +807,41 @@ void ConstraintInfo::transferToOtherSystem(
default:
break;
case CmpInst::ICMP_ULT:
- // If B is a signed positive constant, A >=s 0 and A <s B.
- if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) {
+ case CmpInst::ICMP_ULE:
+ // If B is a signed positive constant, then A >=s 0 and A <s (or <=s) B.
+ if (IsKnownNonNegative(B)) {
addFact(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0), NumIn,
NumOut, DFSInStack);
- addFact(CmpInst::ICMP_SLT, A, B, NumIn, NumOut, DFSInStack);
+ addFact(CmpInst::getSignedPredicate(Pred), A, B, NumIn, NumOut,
+ DFSInStack);
+ }
+ break;
+ case CmpInst::ICMP_UGE:
+ case CmpInst::ICMP_UGT:
+ // If A is a signed positive constant, then B >=s 0 and A >s (or >=s) B.
+ if (IsKnownNonNegative(A)) {
+ addFact(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0), NumIn,
+ NumOut, DFSInStack);
+ addFact(CmpInst::getSignedPredicate(Pred), A, B, NumIn, NumOut,
+ DFSInStack);
}
break;
case CmpInst::ICMP_SLT:
- if (doesHold(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0)))
+ if (IsKnownNonNegative(A))
addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack);
break;
case CmpInst::ICMP_SGT: {
if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1)))
addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), NumIn,
NumOut, DFSInStack);
- if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0)))
+ if (IsKnownNonNegative(B))
addFact(CmpInst::ICMP_UGT, A, B, NumIn, NumOut, DFSInStack);
break;
}
case CmpInst::ICMP_SGE:
- if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) {
+ if (IsKnownNonNegative(B))
addFact(CmpInst::ICMP_UGE, A, B, NumIn, NumOut, DFSInStack);
- }
break;
}
}
@@ -762,7 +856,138 @@ static void dumpConstraint(ArrayRef<int64_t> C,
}
#endif
+void State::addInfoForInductions(BasicBlock &BB) {
+ auto *L = LI.getLoopFor(&BB);
+ if (!L || L->getHeader() != &BB)
+ return;
+
+ Value *A;
+ Value *B;
+ CmpInst::Predicate Pred;
+
+ if (!match(BB.getTerminator(),
+ m_Br(m_ICmp(Pred, m_Value(A), m_Value(B)), m_Value(), m_Value())))
+ return;
+ PHINode *PN = dyn_cast<PHINode>(A);
+ if (!PN) {
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ std::swap(A, B);
+ PN = dyn_cast<PHINode>(A);
+ }
+
+ if (!PN || PN->getParent() != &BB || PN->getNumIncomingValues() != 2 ||
+ !SE.isSCEVable(PN->getType()))
+ return;
+
+ BasicBlock *InLoopSucc = nullptr;
+ if (Pred == CmpInst::ICMP_NE)
+ InLoopSucc = cast<BranchInst>(BB.getTerminator())->getSuccessor(0);
+ else if (Pred == CmpInst::ICMP_EQ)
+ InLoopSucc = cast<BranchInst>(BB.getTerminator())->getSuccessor(1);
+ else
+ return;
+
+ if (!L->contains(InLoopSucc) || !L->isLoopExiting(&BB) || InLoopSucc == &BB)
+ return;
+
+ auto *AR = dyn_cast_or_null<SCEVAddRecExpr>(SE.getSCEV(PN));
+ BasicBlock *LoopPred = L->getLoopPredecessor();
+ if (!AR || AR->getLoop() != L || !LoopPred)
+ return;
+
+ const SCEV *StartSCEV = AR->getStart();
+ Value *StartValue = nullptr;
+ if (auto *C = dyn_cast<SCEVConstant>(StartSCEV)) {
+ StartValue = C->getValue();
+ } else {
+ StartValue = PN->getIncomingValueForBlock(LoopPred);
+ assert(SE.getSCEV(StartValue) == StartSCEV && "inconsistent start value");
+ }
+
+ DomTreeNode *DTN = DT.getNode(InLoopSucc);
+ auto Inc = SE.getMonotonicPredicateType(AR, CmpInst::ICMP_UGT);
+ bool MonotonicallyIncreasing =
+ Inc && *Inc == ScalarEvolution::MonotonicallyIncreasing;
+ if (MonotonicallyIncreasing) {
+ // SCEV guarantees that AR does not wrap, so PN >= StartValue can be added
+ // unconditionally.
+ WorkList.push_back(
+ FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_UGE, PN, StartValue));
+ }
+
+ APInt StepOffset;
+ if (auto *C = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+ StepOffset = C->getAPInt();
+ else
+ return;
+
+ // Make sure the bound B is loop-invariant.
+ if (!L->isLoopInvariant(B))
+ return;
+
+ // Handle negative steps.
+ if (StepOffset.isNegative()) {
+ // TODO: Extend to allow steps > -1.
+ if (!(-StepOffset).isOne())
+ return;
+
+ // AR may wrap.
+ // Add StartValue >= PN conditional on B <= StartValue which guarantees that
+ // the loop exits before wrapping with a step of -1.
+ WorkList.push_back(FactOrCheck::getConditionFact(
+ DTN, CmpInst::ICMP_UGE, StartValue, PN,
+ ConditionTy(CmpInst::ICMP_ULE, B, StartValue)));
+ // Add PN > B conditional on B <= StartValue which guarantees that the loop
+ // exits when reaching B with a step of -1.
+ WorkList.push_back(FactOrCheck::getConditionFact(
+ DTN, CmpInst::ICMP_UGT, PN, B,
+ ConditionTy(CmpInst::ICMP_ULE, B, StartValue)));
+ return;
+ }
+
+ // Make sure AR either steps by 1 or that the value we compare against is a
+ // GEP based on the same start value and all offsets are a multiple of the
+ // step size, to guarantee that the induction will reach the value.
+ if (StepOffset.isZero() || StepOffset.isNegative())
+ return;
+
+ if (!StepOffset.isOne()) {
+ auto *UpperGEP = dyn_cast<GetElementPtrInst>(B);
+ if (!UpperGEP || UpperGEP->getPointerOperand() != StartValue ||
+ !UpperGEP->isInBounds())
+ return;
+
+ MapVector<Value *, APInt> UpperVariableOffsets;
+ APInt UpperConstantOffset(StepOffset.getBitWidth(), 0);
+ const DataLayout &DL = BB.getModule()->getDataLayout();
+ if (!UpperGEP->collectOffset(DL, StepOffset.getBitWidth(),
+ UpperVariableOffsets, UpperConstantOffset))
+ return;
+ // All variable offsets and the constant offset have to be a multiple of the
+ // step.
+ if (!UpperConstantOffset.urem(StepOffset).isZero() ||
+ any_of(UpperVariableOffsets, [&StepOffset](const auto &P) {
+ return !P.second.urem(StepOffset).isZero();
+ }))
+ return;
+ }
+
+ // AR may wrap. Add PN >= StartValue conditional on StartValue <= B which
+ // guarantees that the loop exits before wrapping in combination with the
+ // restrictions on B and the step above.
+ if (!MonotonicallyIncreasing) {
+ WorkList.push_back(FactOrCheck::getConditionFact(
+ DTN, CmpInst::ICMP_UGE, PN, StartValue,
+ ConditionTy(CmpInst::ICMP_ULE, StartValue, B)));
+ }
+ WorkList.push_back(FactOrCheck::getConditionFact(
+ DTN, CmpInst::ICMP_ULT, PN, B,
+ ConditionTy(CmpInst::ICMP_ULE, StartValue, B)));
+}
+
void State::addInfoFor(BasicBlock &BB) {
+ addInfoForInductions(BB);
+
// True as long as the current instruction is guaranteed to execute.
bool GuaranteedToExecute = true;
// Queue conditions and assumes.
@@ -785,27 +1010,40 @@ void State::addInfoFor(BasicBlock &BB) {
}
if (isa<MinMaxIntrinsic>(&I)) {
- WorkList.push_back(FactOrCheck::getFact(DT.getNode(&BB), &I));
+ WorkList.push_back(FactOrCheck::getInstFact(DT.getNode(&BB), &I));
continue;
}
- Value *Cond;
+ Value *A, *B;
+ CmpInst::Predicate Pred;
// For now, just handle assumes with a single compare as condition.
- if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) &&
- isa<ICmpInst>(Cond)) {
+ if (match(&I, m_Intrinsic<Intrinsic::assume>(
+ m_ICmp(Pred, m_Value(A), m_Value(B))))) {
if (GuaranteedToExecute) {
// The assume is guaranteed to execute when BB is entered, hence Cond
// holds on entry to BB.
- WorkList.emplace_back(FactOrCheck::getFact(DT.getNode(I.getParent()),
- cast<Instruction>(Cond)));
+ WorkList.emplace_back(FactOrCheck::getConditionFact(
+ DT.getNode(I.getParent()), Pred, A, B));
} else {
WorkList.emplace_back(
- FactOrCheck::getFact(DT.getNode(I.getParent()), &I));
+ FactOrCheck::getInstFact(DT.getNode(I.getParent()), &I));
}
}
GuaranteedToExecute &= isGuaranteedToTransferExecutionToSuccessor(&I);
}
+ if (auto *Switch = dyn_cast<SwitchInst>(BB.getTerminator())) {
+ for (auto &Case : Switch->cases()) {
+ BasicBlock *Succ = Case.getCaseSuccessor();
+ Value *V = Case.getCaseValue();
+ if (!canAddSuccessor(BB, Succ))
+ continue;
+ WorkList.emplace_back(FactOrCheck::getConditionFact(
+ DT.getNode(Succ), CmpInst::ICMP_EQ, Switch->getCondition(), V));
+ }
+ return;
+ }
+
auto *Br = dyn_cast<BranchInst>(BB.getTerminator());
if (!Br || !Br->isConditional())
return;
@@ -837,8 +1075,11 @@ void State::addInfoFor(BasicBlock &BB) {
while (!CondWorkList.empty()) {
Value *Cur = CondWorkList.pop_back_val();
if (auto *Cmp = dyn_cast<ICmpInst>(Cur)) {
- WorkList.emplace_back(
- FactOrCheck::getFact(DT.getNode(Successor), Cmp, IsOr));
+ WorkList.emplace_back(FactOrCheck::getConditionFact(
+ DT.getNode(Successor),
+ IsOr ? CmpInst::getInversePredicate(Cmp->getPredicate())
+ : Cmp->getPredicate(),
+ Cmp->getOperand(0), Cmp->getOperand(1)));
continue;
}
if (IsOr && match(Cur, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
@@ -860,11 +1101,14 @@ void State::addInfoFor(BasicBlock &BB) {
if (!CmpI)
return;
if (canAddSuccessor(BB, Br->getSuccessor(0)))
- WorkList.emplace_back(
- FactOrCheck::getFact(DT.getNode(Br->getSuccessor(0)), CmpI));
+ WorkList.emplace_back(FactOrCheck::getConditionFact(
+ DT.getNode(Br->getSuccessor(0)), CmpI->getPredicate(),
+ CmpI->getOperand(0), CmpI->getOperand(1)));
if (canAddSuccessor(BB, Br->getSuccessor(1)))
- WorkList.emplace_back(
- FactOrCheck::getFact(DT.getNode(Br->getSuccessor(1)), CmpI, true));
+ WorkList.emplace_back(FactOrCheck::getConditionFact(
+ DT.getNode(Br->getSuccessor(1)),
+ CmpInst::getInversePredicate(CmpI->getPredicate()), CmpI->getOperand(0),
+ CmpI->getOperand(1)));
}
namespace {
@@ -1069,7 +1313,8 @@ static std::optional<bool> checkCondition(CmpInst *Cmp, ConstraintInfo &Info,
static bool checkAndReplaceCondition(
CmpInst *Cmp, ConstraintInfo &Info, unsigned NumIn, unsigned NumOut,
Instruction *ContextInst, Module *ReproducerModule,
- ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT) {
+ ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT,
+ SmallVectorImpl<Instruction *> &ToRemove) {
auto ReplaceCmpWithConstant = [&](CmpInst *Cmp, bool IsTrue) {
generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
Constant *ConstantC = ConstantInt::getBool(
@@ -1090,6 +1335,8 @@ static bool checkAndReplaceCondition(
return !II || II->getIntrinsicID() != Intrinsic::assume;
});
NumCondsRemoved++;
+ if (Cmp->use_empty())
+ ToRemove.push_back(Cmp);
return true;
};
@@ -1120,6 +1367,7 @@ static bool checkAndSecondOpImpliedByFirst(
FactOrCheck &CB, ConstraintInfo &Info, Module *ReproducerModule,
SmallVectorImpl<ReproducerEntry> &ReproducerCondStack,
SmallVectorImpl<StackEntry> &DFSInStack) {
+
CmpInst::Predicate Pred;
Value *A, *B;
Instruction *And = CB.getContextInst();
@@ -1263,7 +1511,8 @@ tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info,
return Changed;
}
-static bool eliminateConstraints(Function &F, DominatorTree &DT,
+static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
+ ScalarEvolution &SE,
OptimizationRemarkEmitter &ORE) {
bool Changed = false;
DT.updateDFSNumbers();
@@ -1271,7 +1520,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
for (Value &Arg : F.args())
FunctionArgs.push_back(&Arg);
ConstraintInfo Info(F.getParent()->getDataLayout(), FunctionArgs);
- State S(DT);
+ State S(DT, LI, SE);
std::unique_ptr<Module> ReproducerModule(
DumpReproducers ? new Module(F.getName(), F.getContext()) : nullptr);
@@ -1293,8 +1542,9 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
// transfer logic.
stable_sort(S.WorkList, [](const FactOrCheck &A, const FactOrCheck &B) {
auto HasNoConstOp = [](const FactOrCheck &B) {
- return !isa<ConstantInt>(B.Inst->getOperand(0)) &&
- !isa<ConstantInt>(B.Inst->getOperand(1));
+ Value *V0 = B.isConditionFact() ? B.Cond.Op0 : B.Inst->getOperand(0);
+ Value *V1 = B.isConditionFact() ? B.Cond.Op1 : B.Inst->getOperand(1);
+ return !isa<ConstantInt>(V0) && !isa<ConstantInt>(V1);
};
// If both entries have the same In numbers, conditional facts come first.
// Otherwise use the relative order in the basic block.
@@ -1355,7 +1605,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
} else if (auto *Cmp = dyn_cast<ICmpInst>(Inst)) {
bool Simplified = checkAndReplaceCondition(
Cmp, Info, CB.NumIn, CB.NumOut, CB.getContextInst(),
- ReproducerModule.get(), ReproducerCondStack, S.DT);
+ ReproducerModule.get(), ReproducerCondStack, S.DT, ToRemove);
if (!Simplified && match(CB.getContextInst(),
m_LogicalAnd(m_Value(), m_Specific(Inst)))) {
Simplified =
@@ -1367,8 +1617,11 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
continue;
}
- LLVM_DEBUG(dbgs() << "fact to add to the system: " << *CB.Inst << "\n");
auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) {
+ LLVM_DEBUG(dbgs() << "fact to add to the system: "
+ << CmpInst::getPredicateName(Pred) << " ";
+ A->printAsOperand(dbgs()); dbgs() << ", ";
+ B->printAsOperand(dbgs(), false); dbgs() << "\n");
if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) {
LLVM_DEBUG(
dbgs()
@@ -1394,23 +1647,30 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
};
ICmpInst::Predicate Pred;
- if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
- Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
- AddFact(Pred, MinMax, MinMax->getLHS());
- AddFact(Pred, MinMax, MinMax->getRHS());
- continue;
+ if (!CB.isConditionFact()) {
+ if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
+ Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
+ AddFact(Pred, MinMax, MinMax->getLHS());
+ AddFact(Pred, MinMax, MinMax->getRHS());
+ continue;
+ }
}
- Value *A, *B;
- Value *Cmp = CB.Inst;
- match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
- if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
- // Use the inverse predicate if required.
- if (CB.Not)
- Pred = CmpInst::getInversePredicate(Pred);
-
- AddFact(Pred, A, B);
+ Value *A = nullptr, *B = nullptr;
+ if (CB.isConditionFact()) {
+ Pred = CB.Cond.Pred;
+ A = CB.Cond.Op0;
+ B = CB.Cond.Op1;
+ if (CB.DoesHold.Pred != CmpInst::BAD_ICMP_PREDICATE &&
+ !Info.doesHold(CB.DoesHold.Pred, CB.DoesHold.Op0, CB.DoesHold.Op1))
+ continue;
+ } else {
+ bool Matched = match(CB.Inst, m_Intrinsic<Intrinsic::assume>(
+ m_ICmp(Pred, m_Value(A), m_Value(B))));
+ (void)Matched;
+ assert(Matched && "Must have an assume intrinsic with a icmp operand");
}
+ AddFact(Pred, A, B);
}
if (ReproducerModule && !ReproducerModule->functions().empty()) {
@@ -1440,12 +1700,16 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
PreservedAnalyses ConstraintEliminationPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &LI = AM.getResult<LoopAnalysis>(F);
+ auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
- if (!eliminateConstraints(F, DT, ORE))
+ if (!eliminateConstraints(F, DT, LI, SE, ORE))
return PreservedAnalyses::all();
PreservedAnalyses PA;
PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<LoopAnalysis>();
+ PA.preserve<ScalarEvolutionAnalysis>();
PA.preserveSet<CFGAnalyses>();
return PA;
}
diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 48b27a1ea0a2..a5cf875ef354 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -55,7 +55,6 @@ static cl::opt<bool> CanonicalizeICmpPredicatesToUnsigned(
STATISTIC(NumPhis, "Number of phis propagated");
STATISTIC(NumPhiCommon, "Number of phis deleted via common incoming value");
STATISTIC(NumSelects, "Number of selects propagated");
-STATISTIC(NumMemAccess, "Number of memory access targets propagated");
STATISTIC(NumCmps, "Number of comparisons propagated");
STATISTIC(NumReturns, "Number of return values propagated");
STATISTIC(NumDeadCases, "Number of switch cases removed");
@@ -93,6 +92,7 @@ STATISTIC(NumNonNull, "Number of function pointer arguments marked non-null");
STATISTIC(NumMinMax, "Number of llvm.[us]{min,max} intrinsics removed");
STATISTIC(NumUDivURemsNarrowedExpanded,
"Number of bound udiv's/urem's expanded");
+STATISTIC(NumZExt, "Number of non-negative deductions");
static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
if (S->getType()->isVectorTy() || isa<Constant>(S->getCondition()))
@@ -263,23 +263,6 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT,
return Changed;
}
-static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) {
- Value *Pointer = nullptr;
- if (LoadInst *L = dyn_cast<LoadInst>(I))
- Pointer = L->getPointerOperand();
- else
- Pointer = cast<StoreInst>(I)->getPointerOperand();
-
- if (isa<Constant>(Pointer)) return false;
-
- Constant *C = LVI->getConstant(Pointer, I);
- if (!C) return false;
-
- ++NumMemAccess;
- I->replaceUsesOfWith(Pointer, C);
- return true;
-}
-
static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) {
if (!CanonicalizeICmpPredicatesToUnsigned)
return false;
@@ -294,8 +277,9 @@ static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) {
ICmpInst::Predicate UnsignedPred =
ConstantRange::getEquivalentPredWithFlippedSignedness(
- Cmp->getPredicate(), LVI->getConstantRange(Cmp->getOperand(0), Cmp),
- LVI->getConstantRange(Cmp->getOperand(1), Cmp));
+ Cmp->getPredicate(),
+ LVI->getConstantRangeAtUse(Cmp->getOperandUse(0)),
+ LVI->getConstantRangeAtUse(Cmp->getOperandUse(1)));
if (UnsignedPred == ICmpInst::Predicate::BAD_ICMP_PREDICATE)
return false;
@@ -470,17 +454,17 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI);
// because it is negation-invariant.
static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) {
Value *X = II->getArgOperand(0);
- bool IsIntMinPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne();
-
Type *Ty = X->getType();
- Constant *IntMin =
- ConstantInt::get(Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits()));
- LazyValueInfo::Tristate Result;
+ if (!Ty->isIntegerTy())
+ return false;
+
+ bool IsIntMinPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne();
+ APInt IntMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
+ ConstantRange Range = LVI->getConstantRangeAtUse(
+ II->getOperandUse(0), /*UndefAllowed*/ IsIntMinPoison);
// Is X in [0, IntMin]? NOTE: INT_MIN is fine!
- Result = LVI->getPredicateAt(CmpInst::Predicate::ICMP_ULE, X, IntMin, II,
- /*UseBlockValue=*/true);
- if (Result == LazyValueInfo::True) {
+ if (Range.icmp(CmpInst::ICMP_ULE, IntMin)) {
++NumAbs;
II->replaceAllUsesWith(X);
II->eraseFromParent();
@@ -488,40 +472,30 @@ static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) {
}
// Is X in [IntMin, 0]? NOTE: INT_MIN is fine!
- Constant *Zero = ConstantInt::getNullValue(Ty);
- Result = LVI->getPredicateAt(CmpInst::Predicate::ICMP_SLE, X, Zero, II,
- /*UseBlockValue=*/true);
- assert(Result != LazyValueInfo::False && "Should have been handled already.");
-
- if (Result == LazyValueInfo::Unknown) {
- // Argument's range crosses zero.
- bool Changed = false;
- if (!IsIntMinPoison) {
- // Can we at least tell that the argument is never INT_MIN?
- Result = LVI->getPredicateAt(CmpInst::Predicate::ICMP_NE, X, IntMin, II,
- /*UseBlockValue=*/true);
- if (Result == LazyValueInfo::True) {
- ++NumNSW;
- ++NumSubNSW;
- II->setArgOperand(1, ConstantInt::getTrue(II->getContext()));
- Changed = true;
- }
- }
- return Changed;
- }
+ if (Range.getSignedMax().isNonPositive()) {
+ IRBuilder<> B(II);
+ Value *NegX = B.CreateNeg(X, II->getName(), /*HasNUW=*/false,
+ /*HasNSW=*/IsIntMinPoison);
+ ++NumAbs;
+ II->replaceAllUsesWith(NegX);
+ II->eraseFromParent();
- IRBuilder<> B(II);
- Value *NegX = B.CreateNeg(X, II->getName(), /*HasNUW=*/false,
- /*HasNSW=*/IsIntMinPoison);
- ++NumAbs;
- II->replaceAllUsesWith(NegX);
- II->eraseFromParent();
+ // See if we can infer some no-wrap flags.
+ if (auto *BO = dyn_cast<BinaryOperator>(NegX))
+ processBinOp(BO, LVI);
- // See if we can infer some no-wrap flags.
- if (auto *BO = dyn_cast<BinaryOperator>(NegX))
- processBinOp(BO, LVI);
+ return true;
+ }
- return true;
+ // Argument's range crosses zero.
+ // Can we at least tell that the argument is never INT_MIN?
+ if (!IsIntMinPoison && !Range.contains(IntMin)) {
+ ++NumNSW;
+ ++NumSubNSW;
+ II->setArgOperand(1, ConstantInt::getTrue(II->getContext()));
+ return true;
+ }
+ return false;
}
// See if this min/max intrinsic always picks it's one specific operand.
@@ -783,7 +757,7 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
// NOTE: this transformation introduces two uses of X,
// but it may be undef so we must freeze it first.
Value *FrozenX = X;
- if (!isGuaranteedNotToBeUndefOrPoison(X))
+ if (!isGuaranteedNotToBeUndef(X))
FrozenX = B.CreateFreeze(X, X->getName() + ".frozen");
auto *AdjX = B.CreateNUWSub(FrozenX, Y, Instr->getName() + ".urem");
auto *Cmp =
@@ -919,6 +893,14 @@ static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR,
assert(SDI->getOpcode() == Instruction::SDiv);
assert(!SDI->getType()->isVectorTy());
+ // Check whether the division folds to a constant.
+ ConstantRange DivCR = LCR.sdiv(RCR);
+ if (const APInt *Elem = DivCR.getSingleElement()) {
+ SDI->replaceAllUsesWith(ConstantInt::get(SDI->getType(), *Elem));
+ SDI->eraseFromParent();
+ return true;
+ }
+
struct Operand {
Value *V;
Domain D;
@@ -1026,12 +1008,31 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
auto *ZExt = CastInst::CreateZExtOrBitCast(Base, SDI->getType(), "", SDI);
ZExt->takeName(SDI);
ZExt->setDebugLoc(SDI->getDebugLoc());
+ ZExt->setNonNeg();
SDI->replaceAllUsesWith(ZExt);
SDI->eraseFromParent();
return true;
}
+static bool processZExt(ZExtInst *ZExt, LazyValueInfo *LVI) {
+ if (ZExt->getType()->isVectorTy())
+ return false;
+
+ if (ZExt->hasNonNeg())
+ return false;
+
+ const Use &Base = ZExt->getOperandUse(0);
+ if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false)
+ .isAllNonNegative())
+ return false;
+
+ ++NumZExt;
+ ZExt->setNonNeg();
+
+ return true;
+}
+
static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
using OBO = OverflowingBinaryOperator;
@@ -1140,10 +1141,6 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
case Instruction::FCmp:
BBChanged |= processCmp(cast<CmpInst>(&II), LVI);
break;
- case Instruction::Load:
- case Instruction::Store:
- BBChanged |= processMemAccess(&II, LVI);
- break;
case Instruction::Call:
case Instruction::Invoke:
BBChanged |= processCallSite(cast<CallBase>(II), LVI);
@@ -1162,6 +1159,9 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
case Instruction::SExt:
BBChanged |= processSExt(cast<SExtInst>(&II), LVI);
break;
+ case Instruction::ZExt:
+ BBChanged |= processZExt(cast<ZExtInst>(&II), LVI);
+ break;
case Instruction::Add:
case Instruction::Sub:
case Instruction::Mul:
diff --git a/llvm/lib/Transforms/Scalar/DCE.cpp b/llvm/lib/Transforms/Scalar/DCE.cpp
index d309799d95f0..2ad46130dc94 100644
--- a/llvm/lib/Transforms/Scalar/DCE.cpp
+++ b/llvm/lib/Transforms/Scalar/DCE.cpp
@@ -36,39 +36,6 @@ STATISTIC(DCEEliminated, "Number of insts removed");
DEBUG_COUNTER(DCECounter, "dce-transform",
"Controls which instructions are eliminated");
-//===--------------------------------------------------------------------===//
-// RedundantDbgInstElimination pass implementation
-//
-
-namespace {
-struct RedundantDbgInstElimination : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- RedundantDbgInstElimination() : FunctionPass(ID) {
- initializeRedundantDbgInstEliminationPass(*PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- bool Changed = false;
- for (auto &BB : F)
- Changed |= RemoveRedundantDbgInstrs(&BB);
- return Changed;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- }
-};
-}
-
-char RedundantDbgInstElimination::ID = 0;
-INITIALIZE_PASS(RedundantDbgInstElimination, "redundant-dbg-inst-elim",
- "Redundant Dbg Instruction Elimination", false, false)
-
-Pass *llvm::createRedundantDbgInstEliminationPass() {
- return new RedundantDbgInstElimination();
-}
-
PreservedAnalyses
RedundantDbgInstEliminationPass::run(Function &F, FunctionAnalysisManager &AM) {
bool Changed = false;
diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
index f2efe60bdf88..edfeb36f3422 100644
--- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
@@ -100,10 +100,10 @@ static cl::opt<unsigned> MaxPathLength(
cl::desc("Max number of blocks searched to find a threading path"),
cl::Hidden, cl::init(20));
-static cl::opt<unsigned> MaxNumPaths(
- "dfa-max-num-paths",
- cl::desc("Max number of paths enumerated around a switch"),
- cl::Hidden, cl::init(200));
+static cl::opt<unsigned>
+ MaxNumPaths("dfa-max-num-paths",
+ cl::desc("Max number of paths enumerated around a switch"),
+ cl::Hidden, cl::init(200));
static cl::opt<unsigned>
CostThreshold("dfa-cost-threshold",
@@ -249,16 +249,20 @@ void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold,
FT = FalseBlock;
// Update the phi node of SI.
- SIUse->removeIncomingValue(StartBlock, /* DeletePHIIfEmpty = */ false);
SIUse->addIncoming(SI->getTrueValue(), TrueBlock);
SIUse->addIncoming(SI->getFalseValue(), FalseBlock);
// Update any other PHI nodes in EndBlock.
for (PHINode &Phi : EndBlock->phis()) {
if (&Phi != SIUse) {
- Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), TrueBlock);
- Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), FalseBlock);
+ Value *OrigValue = Phi.getIncomingValueForBlock(StartBlock);
+ Phi.addIncoming(OrigValue, TrueBlock);
+ Phi.addIncoming(OrigValue, FalseBlock);
}
+
+ // Remove incoming place of original StartBlock, which comes in a indirect
+ // way (through TrueBlock and FalseBlock) now.
+ Phi.removeIncomingValue(StartBlock, /* DeletePHIIfEmpty = */ false);
}
} else {
BasicBlock *NewBlock = nullptr;
@@ -297,6 +301,7 @@ void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold,
{DominatorTree::Insert, StartBlock, FT}});
// The select is now dead.
+ assert(SI->use_empty() && "Select must be dead now");
SI->eraseFromParent();
}
@@ -466,8 +471,9 @@ private:
if (!SITerm || !SITerm->isUnconditional())
return false;
- if (isa<PHINode>(SIUse) &&
- SIBB->getSingleSuccessor() != cast<Instruction>(SIUse)->getParent())
+ // Only fold the select coming from directly where it is defined.
+ PHINode *PHIUser = dyn_cast<PHINode>(SIUse);
+ if (PHIUser && PHIUser->getIncomingBlock(*SI->use_begin()) != SIBB)
return false;
// If select will not be sunk during unfolding, and it is in the same basic
@@ -728,6 +734,10 @@ private:
CodeMetrics Metrics;
SwitchInst *Switch = SwitchPaths->getSwitchInst();
+ // Don't thread switch without multiple successors.
+ if (Switch->getNumSuccessors() <= 1)
+ return false;
+
// Note that DuplicateBlockMap is not being used as intended here. It is
// just being used to ensure (BB, State) pairs are only counted once.
DuplicateBlockMap DuplicateMap;
@@ -805,6 +815,8 @@ private:
// using binary search, hence the LogBase2().
unsigned CondBranches =
APInt(32, Switch->getNumSuccessors()).ceilLogBase2();
+ assert(CondBranches > 0 &&
+ "The threaded switch must have multiple branches");
DuplicationCost = Metrics.NumInsts / CondBranches;
} else {
// Compared with jump tables, the DFA optimizer removes an indirect branch
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index d3fbe49439a8..dd0a290252da 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -38,9 +38,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CaptureTracking.h"
-#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
@@ -205,16 +203,17 @@ static bool isShortenableAtTheBeginning(Instruction *I) {
return isa<AnyMemSetInst>(I);
}
-static uint64_t getPointerSize(const Value *V, const DataLayout &DL,
- const TargetLibraryInfo &TLI,
- const Function *F) {
+static std::optional<TypeSize> getPointerSize(const Value *V,
+ const DataLayout &DL,
+ const TargetLibraryInfo &TLI,
+ const Function *F) {
uint64_t Size;
ObjectSizeOpts Opts;
Opts.NullIsUnknownSize = NullPointerIsDefined(F);
if (getObjectSize(V, Size, DL, &TLI, Opts))
- return Size;
- return MemoryLocation::UnknownSize;
+ return TypeSize::getFixed(Size);
+ return std::nullopt;
}
namespace {
@@ -629,20 +628,11 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart,
Value *OrigDest = DeadIntrinsic->getRawDest();
if (!IsOverwriteEnd) {
- Type *Int8PtrTy =
- Type::getInt8PtrTy(DeadIntrinsic->getContext(),
- OrigDest->getType()->getPointerAddressSpace());
- Value *Dest = OrigDest;
- if (OrigDest->getType() != Int8PtrTy)
- Dest = CastInst::CreatePointerCast(OrigDest, Int8PtrTy, "", DeadI);
Value *Indices[1] = {
ConstantInt::get(DeadWriteLength->getType(), ToRemoveSize)};
Instruction *NewDestGEP = GetElementPtrInst::CreateInBounds(
- Type::getInt8Ty(DeadIntrinsic->getContext()), Dest, Indices, "", DeadI);
+ Type::getInt8Ty(DeadIntrinsic->getContext()), OrigDest, Indices, "", DeadI);
NewDestGEP->setDebugLoc(DeadIntrinsic->getDebugLoc());
- if (NewDestGEP->getType() != OrigDest->getType())
- NewDestGEP = CastInst::CreatePointerCast(NewDestGEP, OrigDest->getType(),
- "", DeadI);
DeadIntrinsic->setDest(NewDestGEP);
}
@@ -850,9 +840,6 @@ struct DSEState {
// Post-order numbers for each basic block. Used to figure out if memory
// accesses are executed before another access.
DenseMap<BasicBlock *, unsigned> PostOrderNumbers;
- // Values that are only used with assumes. Used to refine pointer escape
- // analysis.
- SmallPtrSet<const Value *, 32> EphValues;
/// Keep track of instructions (partly) overlapping with killing MemoryDefs per
/// basic block.
@@ -872,10 +859,10 @@ struct DSEState {
DSEState &operator=(const DSEState &) = delete;
DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT,
- PostDominatorTree &PDT, AssumptionCache &AC,
- const TargetLibraryInfo &TLI, const LoopInfo &LI)
- : F(F), AA(AA), EI(DT, LI, EphValues), BatchAA(AA, &EI), MSSA(MSSA),
- DT(DT), PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) {
+ PostDominatorTree &PDT, const TargetLibraryInfo &TLI,
+ const LoopInfo &LI)
+ : F(F), AA(AA), EI(DT, &LI), BatchAA(AA, &EI), MSSA(MSSA), DT(DT),
+ PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) {
// Collect blocks with throwing instructions not modeled in MemorySSA and
// alloc-like objects.
unsigned PO = 0;
@@ -905,8 +892,6 @@ struct DSEState {
AnyUnreachableExit = any_of(PDT.roots(), [](const BasicBlock *E) {
return isa<UnreachableInst>(E->getTerminator());
});
-
- CodeMetrics::collectEphemeralValues(&F, &AC, EphValues);
}
LocationSize strengthenLocationSize(const Instruction *I,
@@ -958,10 +943,11 @@ struct DSEState {
// Check whether the killing store overwrites the whole object, in which
// case the size/offset of the dead store does not matter.
- if (DeadUndObj == KillingUndObj && KillingLocSize.isPrecise()) {
- uint64_t KillingUndObjSize = getPointerSize(KillingUndObj, DL, TLI, &F);
- if (KillingUndObjSize != MemoryLocation::UnknownSize &&
- KillingUndObjSize == KillingLocSize.getValue())
+ if (DeadUndObj == KillingUndObj && KillingLocSize.isPrecise() &&
+ isIdentifiedObject(KillingUndObj)) {
+ std::optional<TypeSize> KillingUndObjSize =
+ getPointerSize(KillingUndObj, DL, TLI, &F);
+ if (KillingUndObjSize && *KillingUndObjSize == KillingLocSize.getValue())
return OW_Complete;
}
@@ -984,9 +970,15 @@ struct DSEState {
return isMaskedStoreOverwrite(KillingI, DeadI, BatchAA);
}
- const uint64_t KillingSize = KillingLocSize.getValue();
- const uint64_t DeadSize = DeadLoc.Size.getValue();
+ const TypeSize KillingSize = KillingLocSize.getValue();
+ const TypeSize DeadSize = DeadLoc.Size.getValue();
+ // Bail on doing Size comparison which depends on AA for now
+ // TODO: Remove AnyScalable once Alias Analysis deal with scalable vectors
+ const bool AnyScalable =
+ DeadSize.isScalable() || KillingLocSize.isScalable();
+ if (AnyScalable)
+ return OW_Unknown;
// Query the alias information
AliasResult AAR = BatchAA.alias(KillingLoc, DeadLoc);
@@ -1076,7 +1068,7 @@ struct DSEState {
if (!isInvisibleToCallerOnUnwind(V)) {
I.first->second = false;
} else if (isNoAliasCall(V)) {
- I.first->second = !PointerMayBeCaptured(V, true, false, EphValues);
+ I.first->second = !PointerMayBeCaptured(V, true, false);
}
}
return I.first->second;
@@ -1095,7 +1087,7 @@ struct DSEState {
// with the killing MemoryDef. But we refrain from doing so for now to
// limit compile-time and this does not cause any changes to the number
// of stores removed on a large test set in practice.
- I.first->second = PointerMayBeCaptured(V, false, true, EphValues);
+ I.first->second = PointerMayBeCaptured(V, false, true);
return !I.first->second;
}
@@ -1861,6 +1853,10 @@ struct DSEState {
if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) ||
Func != LibFunc_malloc)
return false;
+ // Gracefully handle malloc with unexpected memory attributes.
+ auto *MallocDef = dyn_cast_or_null<MemoryDef>(MSSA.getMemoryAccess(Malloc));
+ if (!MallocDef)
+ return false;
auto shouldCreateCalloc = [](CallInst *Malloc, CallInst *Memset) {
// Check for br(icmp ptr, null), truebb, falsebb) pattern at the end
@@ -1894,11 +1890,9 @@ struct DSEState {
if (!Calloc)
return false;
MemorySSAUpdater Updater(&MSSA);
- auto *LastDef =
- cast<MemoryDef>(Updater.getMemorySSA()->getMemoryAccess(Malloc));
auto *NewAccess =
- Updater.createMemoryAccessAfter(cast<Instruction>(Calloc), LastDef,
- LastDef);
+ Updater.createMemoryAccessAfter(cast<Instruction>(Calloc), nullptr,
+ MallocDef);
auto *NewAccessMD = cast<MemoryDef>(NewAccess);
Updater.insertDef(NewAccessMD, /*RenameUses=*/true);
Updater.removeMemoryAccess(Malloc);
@@ -2064,12 +2058,11 @@ struct DSEState {
static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
DominatorTree &DT, PostDominatorTree &PDT,
- AssumptionCache &AC,
const TargetLibraryInfo &TLI,
const LoopInfo &LI) {
bool MadeChange = false;
- DSEState State(F, AA, MSSA, DT, PDT, AC, TLI, LI);
+ DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
// For each store:
for (unsigned I = 0; I < State.MemDefs.size(); I++) {
MemoryDef *KillingDef = State.MemDefs[I];
@@ -2250,10 +2243,9 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {
DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
- AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
- bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, AC, TLI, LI);
+ bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI);
#ifdef LLVM_ENABLE_STATS
if (AreStatisticsEnabled())
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index 67e8e82e408f..f736d429cb63 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -67,6 +67,7 @@ STATISTIC(NumCSE, "Number of instructions CSE'd");
STATISTIC(NumCSECVP, "Number of compare instructions CVP'd");
STATISTIC(NumCSELoad, "Number of load instructions CSE'd");
STATISTIC(NumCSECall, "Number of call instructions CSE'd");
+STATISTIC(NumCSEGEP, "Number of GEP instructions CSE'd");
STATISTIC(NumDSE, "Number of trivial dead stores removed");
DEBUG_COUNTER(CSECounter, "early-cse",
@@ -143,11 +144,11 @@ struct SimpleValue {
!CI->getFunction()->isPresplitCoroutine();
}
return isa<CastInst>(Inst) || isa<UnaryOperator>(Inst) ||
- isa<BinaryOperator>(Inst) || isa<GetElementPtrInst>(Inst) ||
- isa<CmpInst>(Inst) || isa<SelectInst>(Inst) ||
- isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
- isa<ShuffleVectorInst>(Inst) || isa<ExtractValueInst>(Inst) ||
- isa<InsertValueInst>(Inst) || isa<FreezeInst>(Inst);
+ isa<BinaryOperator>(Inst) || isa<CmpInst>(Inst) ||
+ isa<SelectInst>(Inst) || isa<ExtractElementInst>(Inst) ||
+ isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) ||
+ isa<ExtractValueInst>(Inst) || isa<InsertValueInst>(Inst) ||
+ isa<FreezeInst>(Inst);
}
};
@@ -307,21 +308,20 @@ static unsigned getHashValueImpl(SimpleValue Val) {
IVI->getOperand(1),
hash_combine_range(IVI->idx_begin(), IVI->idx_end()));
- assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) ||
- isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
- isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst) ||
- isa<FreezeInst>(Inst)) &&
+ assert((isa<CallInst>(Inst) || isa<ExtractElementInst>(Inst) ||
+ isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) ||
+ isa<UnaryOperator>(Inst) || isa<FreezeInst>(Inst)) &&
"Invalid/unknown instruction");
// Handle intrinsics with commutative operands.
- // TODO: Extend this to handle intrinsics with >2 operands where the 1st
- // 2 operands are commutative.
auto *II = dyn_cast<IntrinsicInst>(Inst);
- if (II && II->isCommutative() && II->arg_size() == 2) {
+ if (II && II->isCommutative() && II->arg_size() >= 2) {
Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
if (LHS > RHS)
std::swap(LHS, RHS);
- return hash_combine(II->getOpcode(), LHS, RHS);
+ return hash_combine(
+ II->getOpcode(), LHS, RHS,
+ hash_combine_range(II->value_op_begin() + 2, II->value_op_end()));
}
// gc.relocate is 'special' call: its second and third operands are
@@ -396,13 +396,14 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate();
}
- // TODO: Extend this for >2 args by matching the trailing N-2 args.
auto *LII = dyn_cast<IntrinsicInst>(LHSI);
auto *RII = dyn_cast<IntrinsicInst>(RHSI);
if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() &&
- LII->isCommutative() && LII->arg_size() == 2) {
+ LII->isCommutative() && LII->arg_size() >= 2) {
return LII->getArgOperand(0) == RII->getArgOperand(1) &&
- LII->getArgOperand(1) == RII->getArgOperand(0);
+ LII->getArgOperand(1) == RII->getArgOperand(0) &&
+ std::equal(LII->arg_begin() + 2, LII->arg_end(),
+ RII->arg_begin() + 2, RII->arg_end());
}
// See comment above in `getHashValue()`.
@@ -548,12 +549,82 @@ bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) {
// currently executing, so conservatively return false if they are in
// different basic blocks.
if (LHSI->isConvergent() && LHSI->getParent() != RHSI->getParent())
- return false;
+ return false;
return LHSI->isIdenticalTo(RHSI);
}
//===----------------------------------------------------------------------===//
+// GEPValue
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct GEPValue {
+ Instruction *Inst;
+ std::optional<int64_t> ConstantOffset;
+
+ GEPValue(Instruction *I) : Inst(I) {
+ assert((isSentinel() || canHandle(I)) && "Inst can't be handled!");
+ }
+
+ GEPValue(Instruction *I, std::optional<int64_t> ConstantOffset)
+ : Inst(I), ConstantOffset(ConstantOffset) {
+ assert((isSentinel() || canHandle(I)) && "Inst can't be handled!");
+ }
+
+ bool isSentinel() const {
+ return Inst == DenseMapInfo<Instruction *>::getEmptyKey() ||
+ Inst == DenseMapInfo<Instruction *>::getTombstoneKey();
+ }
+
+ static bool canHandle(Instruction *Inst) {
+ return isa<GetElementPtrInst>(Inst);
+ }
+};
+
+} // namespace
+
+namespace llvm {
+
+template <> struct DenseMapInfo<GEPValue> {
+ static inline GEPValue getEmptyKey() {
+ return DenseMapInfo<Instruction *>::getEmptyKey();
+ }
+
+ static inline GEPValue getTombstoneKey() {
+ return DenseMapInfo<Instruction *>::getTombstoneKey();
+ }
+
+ static unsigned getHashValue(const GEPValue &Val);
+ static bool isEqual(const GEPValue &LHS, const GEPValue &RHS);
+};
+
+} // end namespace llvm
+
+unsigned DenseMapInfo<GEPValue>::getHashValue(const GEPValue &Val) {
+ auto *GEP = cast<GetElementPtrInst>(Val.Inst);
+ if (Val.ConstantOffset.has_value())
+ return hash_combine(GEP->getOpcode(), GEP->getPointerOperand(),
+ Val.ConstantOffset.value());
+ return hash_combine(
+ GEP->getOpcode(),
+ hash_combine_range(GEP->value_op_begin(), GEP->value_op_end()));
+}
+
+bool DenseMapInfo<GEPValue>::isEqual(const GEPValue &LHS, const GEPValue &RHS) {
+ if (LHS.isSentinel() || RHS.isSentinel())
+ return LHS.Inst == RHS.Inst;
+ auto *LGEP = cast<GetElementPtrInst>(LHS.Inst);
+ auto *RGEP = cast<GetElementPtrInst>(RHS.Inst);
+ if (LGEP->getPointerOperand() != RGEP->getPointerOperand())
+ return false;
+ if (LHS.ConstantOffset.has_value() && RHS.ConstantOffset.has_value())
+ return LHS.ConstantOffset.value() == RHS.ConstantOffset.value();
+ return LGEP->isIdenticalToWhenDefined(RGEP);
+}
+
+//===----------------------------------------------------------------------===//
// EarlyCSE implementation
//===----------------------------------------------------------------------===//
@@ -647,6 +718,13 @@ public:
ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>>;
CallHTType AvailableCalls;
+ using GEPMapAllocatorTy =
+ RecyclingAllocator<BumpPtrAllocator,
+ ScopedHashTableVal<GEPValue, Value *>>;
+ using GEPHTType = ScopedHashTable<GEPValue, Value *, DenseMapInfo<GEPValue>,
+ GEPMapAllocatorTy>;
+ GEPHTType AvailableGEPs;
+
/// This is the current generation of the memory value.
unsigned CurrentGeneration = 0;
@@ -667,9 +745,11 @@ private:
class NodeScope {
public:
NodeScope(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
- InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls)
- : Scope(AvailableValues), LoadScope(AvailableLoads),
- InvariantScope(AvailableInvariants), CallScope(AvailableCalls) {}
+ InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls,
+ GEPHTType &AvailableGEPs)
+ : Scope(AvailableValues), LoadScope(AvailableLoads),
+ InvariantScope(AvailableInvariants), CallScope(AvailableCalls),
+ GEPScope(AvailableGEPs) {}
NodeScope(const NodeScope &) = delete;
NodeScope &operator=(const NodeScope &) = delete;
@@ -678,6 +758,7 @@ private:
LoadHTType::ScopeTy LoadScope;
InvariantHTType::ScopeTy InvariantScope;
CallHTType::ScopeTy CallScope;
+ GEPHTType::ScopeTy GEPScope;
};
// Contains all the needed information to create a stack for doing a depth
@@ -688,13 +769,13 @@ private:
public:
StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls,
- unsigned cg, DomTreeNode *n, DomTreeNode::const_iterator child,
+ GEPHTType &AvailableGEPs, unsigned cg, DomTreeNode *n,
+ DomTreeNode::const_iterator child,
DomTreeNode::const_iterator end)
: CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child),
EndIter(end),
Scopes(AvailableValues, AvailableLoads, AvailableInvariants,
- AvailableCalls)
- {}
+ AvailableCalls, AvailableGEPs) {}
StackNode(const StackNode &) = delete;
StackNode &operator=(const StackNode &) = delete;
@@ -1214,6 +1295,20 @@ Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
return Result;
}
+static void combineIRFlags(Instruction &From, Value *To) {
+ if (auto *I = dyn_cast<Instruction>(To)) {
+ // If I being poison triggers UB, there is no need to drop those
+ // flags. Otherwise, only retain flags present on both I and Inst.
+ // TODO: Currently some fast-math flags are not treated as
+ // poison-generating even though they should. Until this is fixed,
+ // always retain flags present on both I and Inst for floating point
+ // instructions.
+ if (isa<FPMathOperator>(I) ||
+ (I->hasPoisonGeneratingFlags() && !programUndefinedIfPoison(I)))
+ I->andIRFlags(&From);
+ }
+}
+
bool EarlyCSE::overridingStores(const ParseMemoryInst &Earlier,
const ParseMemoryInst &Later) {
// Can we remove Earlier store because of Later store?
@@ -1424,7 +1519,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
// If this is a simple instruction that we can value number, process it.
if (SimpleValue::canHandle(&Inst)) {
- if (auto *CI = dyn_cast<ConstrainedFPIntrinsic>(&Inst)) {
+ if ([[maybe_unused]] auto *CI = dyn_cast<ConstrainedFPIntrinsic>(&Inst)) {
assert(CI->getExceptionBehavior() != fp::ebStrict &&
"Unexpected ebStrict from SimpleValue::canHandle()");
assert((!CI->getRoundingMode() ||
@@ -1439,16 +1534,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
- if (auto *I = dyn_cast<Instruction>(V)) {
- // If I being poison triggers UB, there is no need to drop those
- // flags. Otherwise, only retain flags present on both I and Inst.
- // TODO: Currently some fast-math flags are not treated as
- // poison-generating even though they should. Until this is fixed,
- // always retain flags present on both I and Inst for floating point
- // instructions.
- if (isa<FPMathOperator>(I) || (I->hasPoisonGeneratingFlags() && !programUndefinedIfPoison(I)))
- I->andIRFlags(&Inst);
- }
+ combineIRFlags(Inst, V);
Inst.replaceAllUsesWith(V);
salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
@@ -1561,6 +1647,31 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
continue;
}
+ // Compare GEP instructions based on offset.
+ if (GEPValue::canHandle(&Inst)) {
+ auto *GEP = cast<GetElementPtrInst>(&Inst);
+ APInt Offset = APInt(SQ.DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+ GEPValue GEPVal(GEP, GEP->accumulateConstantOffset(SQ.DL, Offset)
+ ? Offset.trySExtValue()
+ : std::nullopt);
+ if (Value *V = AvailableGEPs.lookup(GEPVal)) {
+ LLVM_DEBUG(dbgs() << "EarlyCSE CSE GEP: " << Inst << " to: " << *V
+ << '\n');
+ combineIRFlags(Inst, V);
+ Inst.replaceAllUsesWith(V);
+ salvageKnowledge(&Inst, &AC);
+ removeMSSA(Inst);
+ Inst.eraseFromParent();
+ Changed = true;
+ ++NumCSEGEP;
+ continue;
+ }
+
+ // Otherwise, just remember that we have this GEP.
+ AvailableGEPs.insert(GEPVal, &Inst);
+ continue;
+ }
+
// A release fence requires that all stores complete before it, but does
// not prevent the reordering of following loads 'before' the fence. As a
// result, we don't need to consider it as writing to memory and don't need
@@ -1675,7 +1786,7 @@ bool EarlyCSE::run() {
// Process the root node.
nodesToProcess.push_back(new StackNode(
AvailableValues, AvailableLoads, AvailableInvariants, AvailableCalls,
- CurrentGeneration, DT.getRootNode(),
+ AvailableGEPs, CurrentGeneration, DT.getRootNode(),
DT.getRootNode()->begin(), DT.getRootNode()->end()));
assert(!CurrentGeneration && "Create a new EarlyCSE instance to rerun it.");
@@ -1698,10 +1809,10 @@ bool EarlyCSE::run() {
} else if (NodeToProcess->childIter() != NodeToProcess->end()) {
// Push the next child onto the stack.
DomTreeNode *child = NodeToProcess->nextChild();
- nodesToProcess.push_back(
- new StackNode(AvailableValues, AvailableLoads, AvailableInvariants,
- AvailableCalls, NodeToProcess->childGeneration(),
- child, child->begin(), child->end()));
+ nodesToProcess.push_back(new StackNode(
+ AvailableValues, AvailableLoads, AvailableInvariants, AvailableCalls,
+ AvailableGEPs, NodeToProcess->childGeneration(), child,
+ child->begin(), child->end()));
} else {
// It has been processed, and there are no more children to process,
// so delete it and pop it off the stack.
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 03e8a2507b45..5e58af0edc15 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -760,7 +760,7 @@ PreservedAnalyses GVNPass::run(Function &F, FunctionAnalysisManager &AM) {
auto &AA = AM.getResult<AAManager>(F);
auto *MemDep =
isMemDepEnabled() ? &AM.getResult<MemoryDependenceAnalysis>(F) : nullptr;
- auto *LI = AM.getCachedResult<LoopAnalysis>(F);
+ auto &LI = AM.getResult<LoopAnalysis>(F);
auto *MSSA = AM.getCachedResult<MemorySSAAnalysis>(F);
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
bool Changed = runImpl(F, AC, DT, TLI, AA, MemDep, LI, &ORE,
@@ -772,8 +772,7 @@ PreservedAnalyses GVNPass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserve<TargetLibraryAnalysis>();
if (MSSA)
PA.preserve<MemorySSAAnalysis>();
- if (LI)
- PA.preserve<LoopAnalysis>();
+ PA.preserve<LoopAnalysis>();
return PA;
}
@@ -946,9 +945,14 @@ static void replaceValuesPerBlockEntry(
SmallVectorImpl<AvailableValueInBlock> &ValuesPerBlock, Value *OldValue,
Value *NewValue) {
for (AvailableValueInBlock &V : ValuesPerBlock) {
- if ((V.AV.isSimpleValue() && V.AV.getSimpleValue() == OldValue) ||
- (V.AV.isCoercedLoadValue() && V.AV.getCoercedLoadValue() == OldValue))
- V = AvailableValueInBlock::get(V.BB, NewValue);
+ if (V.AV.Val == OldValue)
+ V.AV.Val = NewValue;
+ if (V.AV.isSelectValue()) {
+ if (V.AV.V1 == OldValue)
+ V.AV.V1 = NewValue;
+ if (V.AV.V2 == OldValue)
+ V.AV.V2 = NewValue;
+ }
}
}
@@ -1147,13 +1151,11 @@ static Value *findDominatingValue(const MemoryLocation &Loc, Type *LoadTy,
BasicBlock *FromBB = From->getParent();
BatchAAResults BatchAA(*AA);
for (BasicBlock *BB = FromBB; BB; BB = BB->getSinglePredecessor())
- for (auto I = BB == FromBB ? From->getReverseIterator() : BB->rbegin(),
- E = BB->rend();
- I != E; ++I) {
+ for (auto *Inst = BB == FromBB ? From : BB->getTerminator();
+ Inst != nullptr; Inst = Inst->getPrevNonDebugInstruction()) {
// Stop the search if limit is reached.
if (++NumVisitedInsts > MaxNumVisitedInsts)
return nullptr;
- Instruction *Inst = &*I;
if (isModSet(BatchAA.getModRefInfo(Inst, Loc)))
return nullptr;
if (auto *LI = dyn_cast<LoadInst>(Inst))
@@ -1368,7 +1370,7 @@ LoadInst *GVNPass::findLoadToHoistIntoPred(BasicBlock *Pred, BasicBlock *LoadBB,
LoadInst *Load) {
// For simplicity we handle a Pred has 2 successors only.
auto *Term = Pred->getTerminator();
- if (Term->getNumSuccessors() != 2 || Term->isExceptionalTerminator())
+ if (Term->getNumSuccessors() != 2 || Term->isSpecialTerminator())
return nullptr;
auto *SuccBB = Term->getSuccessor(0);
if (SuccBB == LoadBB)
@@ -1416,16 +1418,8 @@ void GVNPass::eliminatePartiallyRedundantLoad(
Load->getSyncScopeID(), UnavailableBlock->getTerminator());
NewLoad->setDebugLoc(Load->getDebugLoc());
if (MSSAU) {
- auto *MSSA = MSSAU->getMemorySSA();
- // Get the defining access of the original load or use the load if it is a
- // MemoryDef (e.g. because it is volatile). The inserted loads are
- // guaranteed to load from the same definition.
- auto *LoadAcc = MSSA->getMemoryAccess(Load);
- auto *DefiningAcc =
- isa<MemoryDef>(LoadAcc) ? LoadAcc : LoadAcc->getDefiningAccess();
auto *NewAccess = MSSAU->createMemoryAccessInBB(
- NewLoad, DefiningAcc, NewLoad->getParent(),
- MemorySSA::BeforeTerminator);
+ NewLoad, nullptr, NewLoad->getParent(), MemorySSA::BeforeTerminator);
if (auto *NewDef = dyn_cast<MemoryDef>(NewAccess))
MSSAU->insertDef(NewDef, /*RenameUses=*/true);
else
@@ -1444,8 +1438,7 @@ void GVNPass::eliminatePartiallyRedundantLoad(
if (auto *RangeMD = Load->getMetadata(LLVMContext::MD_range))
NewLoad->setMetadata(LLVMContext::MD_range, RangeMD);
if (auto *AccessMD = Load->getMetadata(LLVMContext::MD_access_group))
- if (LI &&
- LI->getLoopFor(Load->getParent()) == LI->getLoopFor(UnavailableBlock))
+ if (LI->getLoopFor(Load->getParent()) == LI->getLoopFor(UnavailableBlock))
NewLoad->setMetadata(LLVMContext::MD_access_group, AccessMD);
// We do not propagate the old load's debug location, because the new
@@ -1482,6 +1475,7 @@ void GVNPass::eliminatePartiallyRedundantLoad(
// Perform PHI construction.
Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this);
// ConstructSSAForLoadSet is responsible for combining metadata.
+ ICF->removeUsersOf(Load);
Load->replaceAllUsesWith(V);
if (isa<PHINode>(V))
V->takeName(Load);
@@ -1752,9 +1746,6 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock,
bool GVNPass::performLoopLoadPRE(LoadInst *Load,
AvailValInBlkVect &ValuesPerBlock,
UnavailBlkVect &UnavailableBlocks) {
- if (!LI)
- return false;
-
const Loop *L = LI->getLoopFor(Load->getParent());
// TODO: Generalize to other loop blocks that dominate the latch.
if (!L || L->getHeader() != Load->getParent())
@@ -1901,6 +1892,7 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
// Perform PHI construction.
Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this);
// ConstructSSAForLoadSet is responsible for combining metadata.
+ ICF->removeUsersOf(Load);
Load->replaceAllUsesWith(V);
if (isa<PHINode>(V))
@@ -1922,7 +1914,7 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
// Step 4: Eliminate partial redundancy.
if (!isPREEnabled() || !isLoadPREEnabled())
return Changed;
- if (!isLoadInLoopPREEnabled() && LI && LI->getLoopFor(Load->getParent()))
+ if (!isLoadInLoopPREEnabled() && LI->getLoopFor(Load->getParent()))
return Changed;
if (performLoopLoadPRE(Load, ValuesPerBlock, UnavailableBlocks) ||
@@ -1998,12 +1990,12 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
if (Cond->isZero()) {
Type *Int8Ty = Type::getInt8Ty(V->getContext());
+ Type *PtrTy = PointerType::get(V->getContext(), 0);
// Insert a new store to null instruction before the load to indicate that
// this code is not reachable. FIXME: We could insert unreachable
// instruction directly because we can modify the CFG.
auto *NewS = new StoreInst(PoisonValue::get(Int8Ty),
- Constant::getNullValue(Int8Ty->getPointerTo()),
- IntrinsicI);
+ Constant::getNullValue(PtrTy), IntrinsicI);
if (MSSAU) {
const MemoryUseOrDef *FirstNonDom = nullptr;
const auto *AL =
@@ -2023,14 +2015,12 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
}
}
- // This added store is to null, so it will never executed and we can
- // just use the LiveOnEntry def as defining access.
auto *NewDef =
FirstNonDom ? MSSAU->createMemoryAccessBefore(
- NewS, MSSAU->getMemorySSA()->getLiveOnEntryDef(),
+ NewS, nullptr,
const_cast<MemoryUseOrDef *>(FirstNonDom))
: MSSAU->createMemoryAccessInBB(
- NewS, MSSAU->getMemorySSA()->getLiveOnEntryDef(),
+ NewS, nullptr,
NewS->getParent(), MemorySSA::BeforeTerminator);
MSSAU->insertDef(cast<MemoryDef>(NewDef), /*RenameUses=*/false);
@@ -2177,6 +2167,7 @@ bool GVNPass::processLoad(LoadInst *L) {
Value *AvailableValue = AV->MaterializeAdjustedValue(L, L, *this);
// MaterializeAdjustedValue is responsible for combining metadata.
+ ICF->removeUsersOf(L);
L->replaceAllUsesWith(AvailableValue);
markInstructionForDeletion(L);
if (MSSAU)
@@ -2695,7 +2686,7 @@ bool GVNPass::processInstruction(Instruction *I) {
/// runOnFunction - This is the main transformation entry point for a function.
bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
const TargetLibraryInfo &RunTLI, AAResults &RunAA,
- MemoryDependenceResults *RunMD, LoopInfo *LI,
+ MemoryDependenceResults *RunMD, LoopInfo &LI,
OptimizationRemarkEmitter *RunORE, MemorySSA *MSSA) {
AC = &RunAC;
DT = &RunDT;
@@ -2705,7 +2696,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
MD = RunMD;
ImplicitControlFlowTracking ImplicitCFT;
ICF = &ImplicitCFT;
- this->LI = LI;
+ this->LI = &LI;
VN.setMemDep(MD);
ORE = RunORE;
InvalidBlockRPONumbers = true;
@@ -2719,7 +2710,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
// Merge unconditional branches, allowing PRE to catch more
// optimization opportunities.
for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
- bool removedBlock = MergeBlockIntoPredecessor(&BB, &DTU, LI, MSSAU, MD);
+ bool removedBlock = MergeBlockIntoPredecessor(&BB, &DTU, &LI, MSSAU, MD);
if (removedBlock)
++NumGVNBlocks;
@@ -2778,7 +2769,12 @@ bool GVNPass::processBlock(BasicBlock *BB) {
// use our normal hash approach for phis. Instead, simply look for
// obvious duplicates. The first pass of GVN will tend to create
// identical phis, and the second or later passes can eliminate them.
- ChangedFunction |= EliminateDuplicatePHINodes(BB);
+ SmallPtrSet<PHINode *, 8> PHINodesToRemove;
+ ChangedFunction |= EliminateDuplicatePHINodes(BB, PHINodesToRemove);
+ for (PHINode *PN : PHINodesToRemove) {
+ VN.erase(PN);
+ removeInstruction(PN);
+ }
for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
BI != BE;) {
@@ -2997,9 +2993,9 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) {
++NumGVNPRE;
// Create a PHI to make the value available in this block.
- PHINode *Phi =
- PHINode::Create(CurInst->getType(), predMap.size(),
- CurInst->getName() + ".pre-phi", &CurrentBlock->front());
+ PHINode *Phi = PHINode::Create(CurInst->getType(), predMap.size(),
+ CurInst->getName() + ".pre-phi");
+ Phi->insertBefore(CurrentBlock->begin());
for (unsigned i = 0, e = predMap.size(); i != e; ++i) {
if (Value *V = predMap[i].first) {
// If we use an existing value in this phi, we have to patch the original
@@ -3290,8 +3286,6 @@ public:
if (skipFunction(F))
return false;
- auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
-
auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
return Impl.runImpl(
F, getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
@@ -3301,7 +3295,7 @@ public:
Impl.isMemDepEnabled()
? &getAnalysis<MemoryDependenceWrapperPass>().getMemDep()
: nullptr,
- LIWP ? &LIWP->getLoopInfo() : nullptr,
+ getAnalysis<LoopInfoWrapperPass>().getLoopInfo(),
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(),
MSSAWP ? &MSSAWP->getMSSA() : nullptr);
}
diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp
index 26a6978656e6..2b38831139a5 100644
--- a/llvm/lib/Transforms/Scalar/GVNSink.cpp
+++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp
@@ -850,8 +850,9 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks,
// Create a new PHI in the successor block and populate it.
auto *Op = I0->getOperand(O);
assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!");
- auto *PN = PHINode::Create(Op->getType(), Insts.size(),
- Op->getName() + ".sink", &BBEnd->front());
+ auto *PN =
+ PHINode::Create(Op->getType(), Insts.size(), Op->getName() + ".sink");
+ PN->insertBefore(BBEnd->begin());
for (auto *I : Insts)
PN->addIncoming(I->getOperand(O), I->getParent());
NewOperands.push_back(PN);
diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
index 62b40a23e38c..3bbf6642a90c 100644
--- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp
+++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
@@ -45,16 +45,14 @@
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
@@ -123,12 +121,12 @@ static void eliminateGuard(Instruction *GuardInst, MemorySSAUpdater *MSSAU) {
/// condition should stay invariant. Otherwise there can be a miscompile, like
/// the one described at https://github.com/llvm/llvm-project/issues/60234. The
/// safest way to do it is to expand the new condition at WC's block.
-static Instruction *findInsertionPointForWideCondition(Instruction *Guard) {
- Value *Condition, *WC;
- BasicBlock *IfTrue, *IfFalse;
- if (parseWidenableBranch(Guard, Condition, WC, IfTrue, IfFalse))
+static Instruction *findInsertionPointForWideCondition(Instruction *WCOrGuard) {
+ if (isGuard(WCOrGuard))
+ return WCOrGuard;
+ if (auto WC = extractWidenableCondition(WCOrGuard))
return cast<Instruction>(WC);
- return Guard;
+ return nullptr;
}
class GuardWideningImpl {
@@ -157,8 +155,8 @@ class GuardWideningImpl {
/// maps BasicBlocks to the set of guards seen in that block.
bool eliminateInstrViaWidening(
Instruction *Instr, const df_iterator<DomTreeNode *> &DFSI,
- const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> &
- GuardsPerBlock, bool InvertCondition = false);
+ const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>>
+ &GuardsPerBlock);
/// Used to keep track of which widening potential is more effective.
enum WideningScore {
@@ -181,11 +179,12 @@ class GuardWideningImpl {
static StringRef scoreTypeToString(WideningScore WS);
/// Compute the score for widening the condition in \p DominatedInstr
- /// into \p DominatingGuard. If \p InvertCond is set, then we widen the
- /// inverted condition of the dominating guard.
+ /// into \p WideningPoint.
WideningScore computeWideningScore(Instruction *DominatedInstr,
- Instruction *DominatingGuard,
- bool InvertCond);
+ Instruction *ToWiden,
+ Instruction *WideningPoint,
+ SmallVectorImpl<Value *> &ChecksToHoist,
+ SmallVectorImpl<Value *> &ChecksToWiden);
/// Helper to check if \p V can be hoisted to \p InsertPos.
bool canBeHoistedTo(const Value *V, const Instruction *InsertPos) const {
@@ -196,19 +195,36 @@ class GuardWideningImpl {
bool canBeHoistedTo(const Value *V, const Instruction *InsertPos,
SmallPtrSetImpl<const Instruction *> &Visited) const;
+ bool canBeHoistedTo(const SmallVectorImpl<Value *> &Checks,
+ const Instruction *InsertPos) const {
+ return all_of(Checks,
+ [&](const Value *V) { return canBeHoistedTo(V, InsertPos); });
+ }
/// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c
/// canBeHoistedTo returned true.
void makeAvailableAt(Value *V, Instruction *InsertPos) const;
+ void makeAvailableAt(const SmallVectorImpl<Value *> &Checks,
+ Instruction *InsertPos) const {
+ for (Value *V : Checks)
+ makeAvailableAt(V, InsertPos);
+ }
+
/// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try
- /// to generate an expression computing the logical AND of \p Cond0 and (\p
- /// Cond1 XOR \p InvertCondition).
- /// Return true if the expression computing the AND is only as
- /// expensive as computing one of the two. If \p InsertPt is true then
- /// actually generate the resulting expression, make it available at \p
- /// InsertPt and return it in \p Result (else no change to the IR is made).
- bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt,
- Value *&Result, bool InvertCondition);
+ /// to generate an expression computing the logical AND of \p ChecksToHoist
+ /// and \p ChecksToWiden. Return true if the expression computing the AND is
+ /// only as expensive as computing one of the set of expressions. If \p
+ /// InsertPt is true then actually generate the resulting expression, make it
+ /// available at \p InsertPt and return it in \p Result (else no change to the
+ /// IR is made).
+ std::optional<Value *> mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist,
+ SmallVectorImpl<Value *> &ChecksToWiden,
+ Instruction *InsertPt);
+
+ /// Generate the logical AND of \p ChecksToHoist and \p OldCondition and make
+ /// it available at InsertPt
+ Value *hoistChecks(SmallVectorImpl<Value *> &ChecksToHoist,
+ Value *OldCondition, Instruction *InsertPt);
/// Adds freeze to Orig and push it as far as possible very aggressively.
/// Also replaces all uses of frozen instruction with frozen version.
@@ -253,16 +269,19 @@ class GuardWideningImpl {
}
};
- /// Parse \p CheckCond into a conjunction (logical-and) of range checks; and
+ /// Parse \p ToParse into a conjunction (logical-and) of range checks; and
/// append them to \p Checks. Returns true on success, may clobber \c Checks
/// on failure.
- bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks) {
- SmallPtrSet<const Value *, 8> Visited;
- return parseRangeChecks(CheckCond, Checks, Visited);
+ bool parseRangeChecks(SmallVectorImpl<Value *> &ToParse,
+ SmallVectorImpl<RangeCheck> &Checks) {
+ for (auto CheckCond : ToParse) {
+ if (!parseRangeChecks(CheckCond, Checks))
+ return false;
+ }
+ return true;
}
- bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks,
- SmallPtrSetImpl<const Value *> &Visited);
+ bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks);
/// Combine the checks in \p Checks into a smaller set of checks and append
/// them into \p CombinedChecks. Return true on success (i.e. all of checks
@@ -271,23 +290,24 @@ class GuardWideningImpl {
bool combineRangeChecks(SmallVectorImpl<RangeCheck> &Checks,
SmallVectorImpl<RangeCheck> &CombinedChecks) const;
- /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of
- /// computing only one of the two expressions?
- bool isWideningCondProfitable(Value *Cond0, Value *Cond1, bool InvertCond) {
- Value *ResultUnused;
- return widenCondCommon(Cond0, Cond1, /*InsertPt=*/nullptr, ResultUnused,
- InvertCond);
+ /// Can we compute the logical AND of \p ChecksToHoist and \p ChecksToWiden
+ /// for the price of computing only one of the set of expressions?
+ bool isWideningCondProfitable(SmallVectorImpl<Value *> &ChecksToHoist,
+ SmallVectorImpl<Value *> &ChecksToWiden) {
+ return mergeChecks(ChecksToHoist, ChecksToWiden, /*InsertPt=*/nullptr)
+ .has_value();
}
- /// If \p InvertCondition is false, Widen \p ToWiden to fail if
- /// \p NewCondition is false, otherwise make it fail if \p NewCondition is
- /// true (in addition to whatever it is already checking).
- void widenGuard(Instruction *ToWiden, Value *NewCondition,
- bool InvertCondition) {
- Value *Result;
+ /// Widen \p ChecksToWiden to fail if any of \p ChecksToHoist is false
+ void widenGuard(SmallVectorImpl<Value *> &ChecksToHoist,
+ SmallVectorImpl<Value *> &ChecksToWiden,
+ Instruction *ToWiden) {
Instruction *InsertPt = findInsertionPointForWideCondition(ToWiden);
- widenCondCommon(getCondition(ToWiden), NewCondition, InsertPt, Result,
- InvertCondition);
+ auto MergedCheck = mergeChecks(ChecksToHoist, ChecksToWiden, InsertPt);
+ Value *Result = MergedCheck ? *MergedCheck
+ : hoistChecks(ChecksToHoist,
+ getCondition(ToWiden), InsertPt);
+
if (isGuardAsWidenableBranch(ToWiden)) {
setWidenableBranchCond(cast<BranchInst>(ToWiden), Result);
return;
@@ -353,12 +373,15 @@ bool GuardWideningImpl::run() {
bool GuardWideningImpl::eliminateInstrViaWidening(
Instruction *Instr, const df_iterator<DomTreeNode *> &DFSI,
- const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> &
- GuardsInBlock, bool InvertCondition) {
+ const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>>
+ &GuardsInBlock) {
+ SmallVector<Value *> ChecksToHoist;
+ parseWidenableGuard(Instr, ChecksToHoist);
// Ignore trivial true or false conditions. These instructions will be
// trivially eliminated by any cleanup pass. Do not erase them because other
// guards can possibly be widened into them.
- if (isa<ConstantInt>(getCondition(Instr)))
+ if (ChecksToHoist.empty() ||
+ (ChecksToHoist.size() == 1 && isa<ConstantInt>(ChecksToHoist.front())))
return false;
Instruction *BestSoFar = nullptr;
@@ -394,10 +417,15 @@ bool GuardWideningImpl::eliminateInstrViaWidening(
assert((i == (e - 1)) == (Instr->getParent() == CurBB) && "Bad DFS?");
for (auto *Candidate : make_range(I, E)) {
- auto Score = computeWideningScore(Instr, Candidate, InvertCondition);
- LLVM_DEBUG(dbgs() << "Score between " << *getCondition(Instr)
- << " and " << *getCondition(Candidate) << " is "
- << scoreTypeToString(Score) << "\n");
+ auto *WideningPoint = findInsertionPointForWideCondition(Candidate);
+ if (!WideningPoint)
+ continue;
+ SmallVector<Value *> CandidateChecks;
+ parseWidenableGuard(Candidate, CandidateChecks);
+ auto Score = computeWideningScore(Instr, Candidate, WideningPoint,
+ ChecksToHoist, CandidateChecks);
+ LLVM_DEBUG(dbgs() << "Score between " << *Instr << " and " << *Candidate
+ << " is " << scoreTypeToString(Score) << "\n");
if (Score > BestScoreSoFar) {
BestScoreSoFar = Score;
BestSoFar = Candidate;
@@ -416,22 +444,22 @@ bool GuardWideningImpl::eliminateInstrViaWidening(
LLVM_DEBUG(dbgs() << "Widening " << *Instr << " into " << *BestSoFar
<< " with score " << scoreTypeToString(BestScoreSoFar)
<< "\n");
- widenGuard(BestSoFar, getCondition(Instr), InvertCondition);
- auto NewGuardCondition = InvertCondition
- ? ConstantInt::getFalse(Instr->getContext())
- : ConstantInt::getTrue(Instr->getContext());
+ SmallVector<Value *> ChecksToWiden;
+ parseWidenableGuard(BestSoFar, ChecksToWiden);
+ widenGuard(ChecksToHoist, ChecksToWiden, BestSoFar);
+ auto NewGuardCondition = ConstantInt::getTrue(Instr->getContext());
setCondition(Instr, NewGuardCondition);
EliminatedGuardsAndBranches.push_back(Instr);
WidenedGuards.insert(BestSoFar);
return true;
}
-GuardWideningImpl::WideningScore
-GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr,
- Instruction *DominatingGuard,
- bool InvertCond) {
+GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore(
+ Instruction *DominatedInstr, Instruction *ToWiden,
+ Instruction *WideningPoint, SmallVectorImpl<Value *> &ChecksToHoist,
+ SmallVectorImpl<Value *> &ChecksToWiden) {
Loop *DominatedInstrLoop = LI.getLoopFor(DominatedInstr->getParent());
- Loop *DominatingGuardLoop = LI.getLoopFor(DominatingGuard->getParent());
+ Loop *DominatingGuardLoop = LI.getLoopFor(WideningPoint->getParent());
bool HoistingOutOfLoop = false;
if (DominatingGuardLoop != DominatedInstrLoop) {
@@ -444,10 +472,12 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr,
HoistingOutOfLoop = true;
}
- auto *WideningPoint = findInsertionPointForWideCondition(DominatingGuard);
- if (!canBeHoistedTo(getCondition(DominatedInstr), WideningPoint))
+ if (!canBeHoistedTo(ChecksToHoist, WideningPoint))
return WS_IllegalOrNegative;
- if (!canBeHoistedTo(getCondition(DominatingGuard), WideningPoint))
+ // Further in the GuardWideningImpl::hoistChecks the entire condition might be
+ // widened, not the parsed list of checks. So we need to check the possibility
+ // of that condition hoisting.
+ if (!canBeHoistedTo(getCondition(ToWiden), WideningPoint))
return WS_IllegalOrNegative;
// If the guard was conditional executed, it may never be reached
@@ -458,8 +488,7 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr,
// here. TODO: evaluate cost model for spurious deopt
// NOTE: As written, this also lets us hoist right over another guard which
// is essentially just another spelling for control flow.
- if (isWideningCondProfitable(getCondition(DominatedInstr),
- getCondition(DominatingGuard), InvertCond))
+ if (isWideningCondProfitable(ChecksToHoist, ChecksToWiden))
return HoistingOutOfLoop ? WS_VeryPositive : WS_Positive;
if (HoistingOutOfLoop)
@@ -495,7 +524,7 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr,
// control flow (guards, calls which throw, etc...). That choice appears
// arbitrary (we assume that implicit control flow exits are all rare).
auto MaybeHoistingToHotterBlock = [&]() {
- const auto *DominatingBlock = DominatingGuard->getParent();
+ const auto *DominatingBlock = WideningPoint->getParent();
const auto *DominatedBlock = DominatedInstr->getParent();
// Descend as low as we can, always taking the likely successor.
@@ -521,7 +550,8 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr,
if (!DT.dominates(DominatingBlock, DominatedBlock))
return true;
// TODO: diamond, triangle cases
- if (!PDT) return true;
+ if (!PDT)
+ return true;
return !PDT->dominates(DominatedBlock, DominatingBlock);
};
@@ -566,35 +596,47 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const {
}
// Return Instruction before which we can insert freeze for the value V as close
-// to def as possible. If there is no place to add freeze, return nullptr.
-static Instruction *getFreezeInsertPt(Value *V, const DominatorTree &DT) {
+// to def as possible. If there is no place to add freeze, return empty.
+static std::optional<BasicBlock::iterator>
+getFreezeInsertPt(Value *V, const DominatorTree &DT) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
- return &*DT.getRoot()->getFirstNonPHIOrDbgOrAlloca();
+ return DT.getRoot()->getFirstNonPHIOrDbgOrAlloca()->getIterator();
- auto *Res = I->getInsertionPointAfterDef();
+ std::optional<BasicBlock::iterator> Res = I->getInsertionPointAfterDef();
// If there is no place to add freeze - return nullptr.
- if (!Res || !DT.dominates(I, Res))
- return nullptr;
+ if (!Res || !DT.dominates(I, &**Res))
+ return std::nullopt;
+
+ Instruction *ResInst = &**Res;
// If there is a User dominated by original I, then it should be dominated
// by Freeze instruction as well.
if (any_of(I->users(), [&](User *U) {
Instruction *User = cast<Instruction>(U);
- return Res != User && DT.dominates(I, User) && !DT.dominates(Res, User);
+ return ResInst != User && DT.dominates(I, User) &&
+ !DT.dominates(ResInst, User);
}))
- return nullptr;
+ return std::nullopt;
return Res;
}
Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) {
if (isGuaranteedNotToBePoison(Orig, nullptr, InsertPt, &DT))
return Orig;
- Instruction *InsertPtAtDef = getFreezeInsertPt(Orig, DT);
- if (!InsertPtAtDef)
- return new FreezeInst(Orig, "gw.freeze", InsertPt);
- if (isa<Constant>(Orig) || isa<GlobalValue>(Orig))
- return new FreezeInst(Orig, "gw.freeze", InsertPtAtDef);
+ std::optional<BasicBlock::iterator> InsertPtAtDef =
+ getFreezeInsertPt(Orig, DT);
+ if (!InsertPtAtDef) {
+ FreezeInst *FI = new FreezeInst(Orig, "gw.freeze");
+ FI->insertBefore(InsertPt);
+ return FI;
+ }
+ if (isa<Constant>(Orig) || isa<GlobalValue>(Orig)) {
+ BasicBlock::iterator InsertPt = *InsertPtAtDef;
+ FreezeInst *FI = new FreezeInst(Orig, "gw.freeze");
+ FI->insertBefore(*InsertPt->getParent(), InsertPt);
+ return FI;
+ }
SmallSet<Value *, 16> Visited;
SmallVector<Value *, 16> Worklist;
@@ -613,8 +655,10 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) {
if (Visited.insert(Def).second) {
if (isGuaranteedNotToBePoison(Def, nullptr, InsertPt, &DT))
return true;
- CacheOfFreezes[Def] = new FreezeInst(Def, Def->getName() + ".gw.fr",
- getFreezeInsertPt(Def, DT));
+ BasicBlock::iterator InsertPt = *getFreezeInsertPt(Def, DT);
+ FreezeInst *FI = new FreezeInst(Def, Def->getName() + ".gw.fr");
+ FI->insertBefore(*InsertPt->getParent(), InsertPt);
+ CacheOfFreezes[Def] = FI;
}
if (CacheOfFreezes.count(Def))
@@ -655,8 +699,9 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) {
Value *Result = Orig;
for (Value *V : NeedFreeze) {
- auto *FreezeInsertPt = getFreezeInsertPt(V, DT);
- FreezeInst *FI = new FreezeInst(V, V->getName() + ".gw.fr", FreezeInsertPt);
+ BasicBlock::iterator FreezeInsertPt = *getFreezeInsertPt(V, DT);
+ FreezeInst *FI = new FreezeInst(V, V->getName() + ".gw.fr");
+ FI->insertBefore(*FreezeInsertPt->getParent(), FreezeInsertPt);
++FreezeAdded;
if (V == Orig)
Result = FI;
@@ -667,20 +712,25 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) {
return Result;
}
-bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
- Instruction *InsertPt, Value *&Result,
- bool InvertCondition) {
+std::optional<Value *>
+GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist,
+ SmallVectorImpl<Value *> &ChecksToWiden,
+ Instruction *InsertPt) {
using namespace llvm::PatternMatch;
+ Value *Result = nullptr;
{
// L >u C0 && L >u C1 -> L >u max(C0, C1)
ConstantInt *RHS0, *RHS1;
Value *LHS;
ICmpInst::Predicate Pred0, Pred1;
- if (match(Cond0, m_ICmp(Pred0, m_Value(LHS), m_ConstantInt(RHS0))) &&
- match(Cond1, m_ICmp(Pred1, m_Specific(LHS), m_ConstantInt(RHS1)))) {
- if (InvertCondition)
- Pred1 = ICmpInst::getInversePredicate(Pred1);
+ // TODO: Support searching for pairs to merge from both whole lists of
+ // ChecksToHoist and ChecksToWiden.
+ if (ChecksToWiden.size() == 1 && ChecksToHoist.size() == 1 &&
+ match(ChecksToWiden.front(),
+ m_ICmp(Pred0, m_Value(LHS), m_ConstantInt(RHS0))) &&
+ match(ChecksToHoist.front(),
+ m_ICmp(Pred1, m_Specific(LHS), m_ConstantInt(RHS1)))) {
ConstantRange CR0 =
ConstantRange::makeExactICmpRegion(Pred0, RHS0->getValue());
@@ -697,12 +747,12 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
if (Intersect->getEquivalentICmp(Pred, NewRHSAP)) {
if (InsertPt) {
ConstantInt *NewRHS =
- ConstantInt::get(Cond0->getContext(), NewRHSAP);
+ ConstantInt::get(InsertPt->getContext(), NewRHSAP);
assert(canBeHoistedTo(LHS, InsertPt) && "must be");
makeAvailableAt(LHS, InsertPt);
Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk");
}
- return true;
+ return Result;
}
}
}
@@ -710,12 +760,10 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
{
SmallVector<GuardWideningImpl::RangeCheck, 4> Checks, CombinedChecks;
- // TODO: Support InvertCondition case?
- if (!InvertCondition &&
- parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) &&
+ if (parseRangeChecks(ChecksToWiden, Checks) &&
+ parseRangeChecks(ChecksToHoist, Checks) &&
combineRangeChecks(Checks, CombinedChecks)) {
if (InsertPt) {
- Result = nullptr;
for (auto &RC : CombinedChecks) {
makeAvailableAt(RC.getCheckInst(), InsertPt);
if (Result)
@@ -728,40 +776,32 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1,
Result->setName("wide.chk");
Result = freezeAndPush(Result, InsertPt);
}
- return true;
+ return Result;
}
}
+ // We were not able to compute ChecksToHoist AND ChecksToWiden for the price
+ // of one.
+ return std::nullopt;
+}
- // Base case -- just logical-and the two conditions together.
-
- if (InsertPt) {
- makeAvailableAt(Cond0, InsertPt);
- makeAvailableAt(Cond1, InsertPt);
- if (InvertCondition)
- Cond1 = BinaryOperator::CreateNot(Cond1, "inverted", InsertPt);
- Cond1 = freezeAndPush(Cond1, InsertPt);
- Result = BinaryOperator::CreateAnd(Cond0, Cond1, "wide.chk", InsertPt);
- }
-
- // We were not able to compute Cond0 AND Cond1 for the price of one.
- return false;
+Value *GuardWideningImpl::hoistChecks(SmallVectorImpl<Value *> &ChecksToHoist,
+ Value *OldCondition,
+ Instruction *InsertPt) {
+ assert(!ChecksToHoist.empty());
+ IRBuilder<> Builder(InsertPt);
+ makeAvailableAt(ChecksToHoist, InsertPt);
+ makeAvailableAt(OldCondition, InsertPt);
+ Value *Result = Builder.CreateAnd(ChecksToHoist);
+ Result = freezeAndPush(Result, InsertPt);
+ Result = Builder.CreateAnd(OldCondition, Result);
+ Result->setName("wide.chk");
+ return Result;
}
bool GuardWideningImpl::parseRangeChecks(
- Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks,
- SmallPtrSetImpl<const Value *> &Visited) {
- if (!Visited.insert(CheckCond).second)
- return true;
-
+ Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks) {
using namespace llvm::PatternMatch;
- {
- Value *AndLHS, *AndRHS;
- if (match(CheckCond, m_And(m_Value(AndLHS), m_Value(AndRHS))))
- return parseRangeChecks(AndLHS, Checks) &&
- parseRangeChecks(AndRHS, Checks);
- }
-
auto *IC = dyn_cast<ICmpInst>(CheckCond);
if (!IC || !IC->getOperand(0)->getType()->isIntegerTy() ||
(IC->getPredicate() != ICmpInst::ICMP_ULT &&
@@ -934,6 +974,15 @@ StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) {
PreservedAnalyses GuardWideningPass::run(Function &F,
FunctionAnalysisManager &AM) {
+ // Avoid requesting analyses if there are no guards or widenable conditions.
+ auto *GuardDecl = F.getParent()->getFunction(
+ Intrinsic::getName(Intrinsic::experimental_guard));
+ bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
+ auto *WCDecl = F.getParent()->getFunction(
+ Intrinsic::getName(Intrinsic::experimental_widenable_condition));
+ bool HasWidenableConditions = WCDecl && !WCDecl->use_empty();
+ if (!HasIntrinsicGuards && !HasWidenableConditions)
+ return PreservedAnalyses::all();
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &LI = AM.getResult<LoopAnalysis>(F);
auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
@@ -976,109 +1025,3 @@ PreservedAnalyses GuardWideningPass::run(Loop &L, LoopAnalysisManager &AM,
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-namespace {
-struct GuardWideningLegacyPass : public FunctionPass {
- static char ID;
-
- GuardWideningLegacyPass() : FunctionPass(ID) {
- initializeGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
- auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
- std::unique_ptr<MemorySSAUpdater> MSSAU;
- if (MSSAWP)
- MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
- return GuardWideningImpl(DT, &PDT, LI, AC, MSSAU ? MSSAU.get() : nullptr,
- DT.getRootNode(),
- [](BasicBlock *) { return true; })
- .run();
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<PostDominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- }
-};
-
-/// Same as above, but restricted to a single loop at a time. Can be
-/// scheduled with other loop passes w/o breaking out of LPM
-struct LoopGuardWideningLegacyPass : public LoopPass {
- static char ID;
-
- LoopGuardWideningLegacyPass() : LoopPass(ID) {
- initializeLoopGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
- *L->getHeader()->getParent());
- auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
- auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr;
- auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
- std::unique_ptr<MemorySSAUpdater> MSSAU;
- if (MSSAWP)
- MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
-
- BasicBlock *RootBB = L->getLoopPredecessor();
- if (!RootBB)
- RootBB = L->getHeader();
- auto BlockFilter = [&](BasicBlock *BB) {
- return BB == RootBB || L->contains(BB);
- };
- return GuardWideningImpl(DT, PDT, LI, AC, MSSAU ? MSSAU.get() : nullptr,
- DT.getNode(RootBB), BlockFilter)
- .run();
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- getLoopAnalysisUsage(AU);
- AU.addPreserved<PostDominatorTreeWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- }
-};
-}
-
-char GuardWideningLegacyPass::ID = 0;
-char LoopGuardWideningLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(GuardWideningLegacyPass, "guard-widening", "Widen guards",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(GuardWideningLegacyPass, "guard-widening", "Widen guards",
- false, false)
-
-INITIALIZE_PASS_BEGIN(LoopGuardWideningLegacyPass, "loop-guard-widening",
- "Widen guards (within a single loop, as a loop pass)",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(LoopGuardWideningLegacyPass, "loop-guard-widening",
- "Widen guards (within a single loop, as a loop pass)",
- false, false)
-
-FunctionPass *llvm::createGuardWideningPass() {
- return new GuardWideningLegacyPass();
-}
-
-Pass *llvm::createLoopGuardWideningPass() {
- return new LoopGuardWideningLegacyPass();
-}
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 40475d9563b2..41c4d6236173 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1997,20 +1997,12 @@ bool IndVarSimplify::run(Loop *L) {
TTI, PreHeader->getTerminator()))
continue;
- // Check preconditions for proper SCEVExpander operation. SCEV does not
- // express SCEVExpander's dependencies, such as LoopSimplify. Instead
- // any pass that uses the SCEVExpander must do it. This does not work
- // well for loop passes because SCEVExpander makes assumptions about
- // all loops, while LoopPassManager only forces the current loop to be
- // simplified.
- //
- // FIXME: SCEV expansion has no way to bail out, so the caller must
- // explicitly check any assumptions made by SCEV. Brittle.
- const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ExitCount);
- if (!AR || AR->getLoop()->getLoopPreheader())
- Changed |= linearFunctionTestReplace(L, ExitingBB,
- ExitCount, IndVar,
- Rewriter);
+ if (!Rewriter.isSafeToExpand(ExitCount))
+ continue;
+
+ Changed |= linearFunctionTestReplace(L, ExitingBB,
+ ExitCount, IndVar,
+ Rewriter);
}
}
// Clear the rewriter cache, because values that are in the rewriter's cache
diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
index b52589baeee7..5f82af1ca46d 100644
--- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -81,6 +81,7 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopConstrainer.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
@@ -91,7 +92,6 @@
#include <limits>
#include <optional>
#include <utility>
-#include <vector>
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -129,8 +129,6 @@ static cl::opt<bool>
PrintScaledBoundaryRangeChecks("irce-print-scaled-boundary-range-checks",
cl::Hidden, cl::init(false));
-static const char *ClonedLoopTag = "irce.loop.clone";
-
#define DEBUG_TYPE "irce"
namespace {
@@ -241,8 +239,6 @@ public:
SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed);
};
-struct LoopStructure;
-
class InductiveRangeCheckElimination {
ScalarEvolution &SE;
BranchProbabilityInfo *BPI;
@@ -554,649 +550,6 @@ void InductiveRangeCheck::extractRangeChecksFromBranch(
Checks, Visited);
}
-// Add metadata to the loop L to disable loop optimizations. Callers need to
-// confirm that optimizing loop L is not beneficial.
-static void DisableAllLoopOptsOnLoop(Loop &L) {
- // We do not care about any existing loopID related metadata for L, since we
- // are setting all loop metadata to false.
- LLVMContext &Context = L.getHeader()->getContext();
- // Reserve first location for self reference to the LoopID metadata node.
- MDNode *Dummy = MDNode::get(Context, {});
- MDNode *DisableUnroll = MDNode::get(
- Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
- Metadata *FalseVal =
- ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
- MDNode *DisableVectorize = MDNode::get(
- Context,
- {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
- MDNode *DisableLICMVersioning = MDNode::get(
- Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
- MDNode *DisableDistribution= MDNode::get(
- Context,
- {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
- MDNode *NewLoopID =
- MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
- DisableLICMVersioning, DisableDistribution});
- // Set operand 0 to refer to the loop id itself.
- NewLoopID->replaceOperandWith(0, NewLoopID);
- L.setLoopID(NewLoopID);
-}
-
-namespace {
-
-// Keeps track of the structure of a loop. This is similar to llvm::Loop,
-// except that it is more lightweight and can track the state of a loop through
-// changing and potentially invalid IR. This structure also formalizes the
-// kinds of loops we can deal with -- ones that have a single latch that is also
-// an exiting block *and* have a canonical induction variable.
-struct LoopStructure {
- const char *Tag = "";
-
- BasicBlock *Header = nullptr;
- BasicBlock *Latch = nullptr;
-
- // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th
- // successor is `LatchExit', the exit block of the loop.
- BranchInst *LatchBr = nullptr;
- BasicBlock *LatchExit = nullptr;
- unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max();
-
- // The loop represented by this instance of LoopStructure is semantically
- // equivalent to:
- //
- // intN_ty inc = IndVarIncreasing ? 1 : -1;
- // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT;
- //
- // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase)
- // ... body ...
-
- Value *IndVarBase = nullptr;
- Value *IndVarStart = nullptr;
- Value *IndVarStep = nullptr;
- Value *LoopExitAt = nullptr;
- bool IndVarIncreasing = false;
- bool IsSignedPredicate = true;
-
- LoopStructure() = default;
-
- template <typename M> LoopStructure map(M Map) const {
- LoopStructure Result;
- Result.Tag = Tag;
- Result.Header = cast<BasicBlock>(Map(Header));
- Result.Latch = cast<BasicBlock>(Map(Latch));
- Result.LatchBr = cast<BranchInst>(Map(LatchBr));
- Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
- Result.LatchBrExitIdx = LatchBrExitIdx;
- Result.IndVarBase = Map(IndVarBase);
- Result.IndVarStart = Map(IndVarStart);
- Result.IndVarStep = Map(IndVarStep);
- Result.LoopExitAt = Map(LoopExitAt);
- Result.IndVarIncreasing = IndVarIncreasing;
- Result.IsSignedPredicate = IsSignedPredicate;
- return Result;
- }
-
- static std::optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
- Loop &, const char *&);
-};
-
-/// This class is used to constrain loops to run within a given iteration space.
-/// The algorithm this class implements is given a Loop and a range [Begin,
-/// End). The algorithm then tries to break out a "main loop" out of the loop
-/// it is given in a way that the "main loop" runs with the induction variable
-/// in a subset of [Begin, End). The algorithm emits appropriate pre and post
-/// loops to run any remaining iterations. The pre loop runs any iterations in
-/// which the induction variable is < Begin, and the post loop runs any
-/// iterations in which the induction variable is >= End.
-class LoopConstrainer {
- // The representation of a clone of the original loop we started out with.
- struct ClonedLoop {
- // The cloned blocks
- std::vector<BasicBlock *> Blocks;
-
- // `Map` maps values in the clonee into values in the cloned version
- ValueToValueMapTy Map;
-
- // An instance of `LoopStructure` for the cloned loop
- LoopStructure Structure;
- };
-
- // Result of rewriting the range of a loop. See changeIterationSpaceEnd for
- // more details on what these fields mean.
- struct RewrittenRangeInfo {
- BasicBlock *PseudoExit = nullptr;
- BasicBlock *ExitSelector = nullptr;
- std::vector<PHINode *> PHIValuesAtPseudoExit;
- PHINode *IndVarEnd = nullptr;
-
- RewrittenRangeInfo() = default;
- };
-
- // Calculated subranges we restrict the iteration space of the main loop to.
- // See the implementation of `calculateSubRanges' for more details on how
- // these fields are computed. `LowLimit` is std::nullopt if there is no
- // restriction on low end of the restricted iteration space of the main loop.
- // `HighLimit` is std::nullopt if there is no restriction on high end of the
- // restricted iteration space of the main loop.
-
- struct SubRanges {
- std::optional<const SCEV *> LowLimit;
- std::optional<const SCEV *> HighLimit;
- };
-
- // Compute a safe set of limits for the main loop to run in -- effectively the
- // intersection of `Range' and the iteration space of the original loop.
- // Return std::nullopt if unable to compute the set of subranges.
- std::optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const;
-
- // Clone `OriginalLoop' and return the result in CLResult. The IR after
- // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
- // the PHI nodes say that there is an incoming edge from `OriginalPreheader`
- // but there is no such edge.
- void cloneLoop(ClonedLoop &CLResult, const char *Tag) const;
-
- // Create the appropriate loop structure needed to describe a cloned copy of
- // `Original`. The clone is described by `VM`.
- Loop *createClonedLoopStructure(Loop *Original, Loop *Parent,
- ValueToValueMapTy &VM, bool IsSubloop);
-
- // Rewrite the iteration space of the loop denoted by (LS, Preheader). The
- // iteration space of the rewritten loop ends at ExitLoopAt. The start of the
- // iteration space is not changed. `ExitLoopAt' is assumed to be slt
- // `OriginalHeaderCount'.
- //
- // If there are iterations left to execute, control is made to jump to
- // `ContinuationBlock', otherwise they take the normal loop exit. The
- // returned `RewrittenRangeInfo' object is populated as follows:
- //
- // .PseudoExit is a basic block that unconditionally branches to
- // `ContinuationBlock'.
- //
- // .ExitSelector is a basic block that decides, on exit from the loop,
- // whether to branch to the "true" exit or to `PseudoExit'.
- //
- // .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value
- // for each PHINode in the loop header on taking the pseudo exit.
- //
- // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate
- // preheader because it is made to branch to the loop header only
- // conditionally.
- RewrittenRangeInfo
- changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader,
- Value *ExitLoopAt,
- BasicBlock *ContinuationBlock) const;
-
- // The loop denoted by `LS' has `OldPreheader' as its preheader. This
- // function creates a new preheader for `LS' and returns it.
- BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader,
- const char *Tag) const;
-
- // `ContinuationBlockAndPreheader' was the continuation block for some call to
- // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
- // This function rewrites the PHI nodes in `LS.Header' to start with the
- // correct value.
- void rewriteIncomingValuesForPHIs(
- LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader,
- const LoopConstrainer::RewrittenRangeInfo &RRI) const;
-
- // Even though we do not preserve any passes at this time, we at least need to
- // keep the parent loop structure consistent. The `LPPassManager' seems to
- // verify this after running a loop pass. This function adds the list of
- // blocks denoted by BBs to this loops parent loop if required.
- void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs);
-
- // Some global state.
- Function &F;
- LLVMContext &Ctx;
- ScalarEvolution &SE;
- DominatorTree &DT;
- LoopInfo &LI;
- function_ref<void(Loop *, bool)> LPMAddNewLoop;
-
- // Information about the original loop we started out with.
- Loop &OriginalLoop;
-
- const IntegerType *ExitCountTy = nullptr;
- BasicBlock *OriginalPreheader = nullptr;
-
- // The preheader of the main loop. This may or may not be different from
- // `OriginalPreheader'.
- BasicBlock *MainLoopPreheader = nullptr;
-
- // The range we need to run the main loop in.
- InductiveRangeCheck::Range Range;
-
- // The structure of the main loop (see comment at the beginning of this class
- // for a definition)
- LoopStructure MainLoopStructure;
-
-public:
- LoopConstrainer(Loop &L, LoopInfo &LI,
- function_ref<void(Loop *, bool)> LPMAddNewLoop,
- const LoopStructure &LS, ScalarEvolution &SE,
- DominatorTree &DT, InductiveRangeCheck::Range R)
- : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()),
- SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L),
- Range(R), MainLoopStructure(LS) {}
-
- // Entry point for the algorithm. Returns true on success.
- bool run();
-};
-
-} // end anonymous namespace
-
-/// Given a loop with an deccreasing induction variable, is it possible to
-/// safely calculate the bounds of a new loop using the given Predicate.
-static bool isSafeDecreasingBound(const SCEV *Start,
- const SCEV *BoundSCEV, const SCEV *Step,
- ICmpInst::Predicate Pred,
- unsigned LatchBrExitIdx,
- Loop *L, ScalarEvolution &SE) {
- if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
- Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
- return false;
-
- if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
- return false;
-
- assert(SE.isKnownNegative(Step) && "expecting negative step");
-
- LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n");
- LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
- LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
- LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
- LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n");
- LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
-
- bool IsSigned = ICmpInst::isSigned(Pred);
- // The predicate that we need to check that the induction variable lies
- // within bounds.
- ICmpInst::Predicate BoundPred =
- IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
-
- if (LatchBrExitIdx == 1)
- return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
-
- assert(LatchBrExitIdx == 0 &&
- "LatchBrExitIdx should be either 0 or 1");
-
- const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
- unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
- APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) :
- APInt::getMinValue(BitWidth);
- const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
-
- const SCEV *MinusOne =
- SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
-
- return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
- SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
-
-}
-
-/// Given a loop with an increasing induction variable, is it possible to
-/// safely calculate the bounds of a new loop using the given Predicate.
-static bool isSafeIncreasingBound(const SCEV *Start,
- const SCEV *BoundSCEV, const SCEV *Step,
- ICmpInst::Predicate Pred,
- unsigned LatchBrExitIdx,
- Loop *L, ScalarEvolution &SE) {
- if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
- Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
- return false;
-
- if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
- return false;
-
- LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n");
- LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
- LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
- LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
- LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n");
- LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
-
- bool IsSigned = ICmpInst::isSigned(Pred);
- // The predicate that we need to check that the induction variable lies
- // within bounds.
- ICmpInst::Predicate BoundPred =
- IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
-
- if (LatchBrExitIdx == 1)
- return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
-
- assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
-
- const SCEV *StepMinusOne =
- SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
- unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
- APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) :
- APInt::getMaxValue(BitWidth);
- const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
-
- return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
- SE.getAddExpr(BoundSCEV, Step)) &&
- SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
-}
-
-/// Returns estimate for max latch taken count of the loop of the narrowest
-/// available type. If the latch block has such estimate, it is returned.
-/// Otherwise, we use max exit count of whole loop (that is potentially of wider
-/// type than latch check itself), which is still better than no estimate.
-static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
- const Loop &L) {
- const SCEV *FromBlock =
- SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
- if (isa<SCEVCouldNotCompute>(FromBlock))
- return SE.getSymbolicMaxBackedgeTakenCount(&L);
- return FromBlock;
-}
-
-std::optional<LoopStructure>
-LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
- const char *&FailureReason) {
- if (!L.isLoopSimplifyForm()) {
- FailureReason = "loop not in LoopSimplify form";
- return std::nullopt;
- }
-
- BasicBlock *Latch = L.getLoopLatch();
- assert(Latch && "Simplified loops only have one latch!");
-
- if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
- FailureReason = "loop has already been cloned";
- return std::nullopt;
- }
-
- if (!L.isLoopExiting(Latch)) {
- FailureReason = "no loop latch";
- return std::nullopt;
- }
-
- BasicBlock *Header = L.getHeader();
- BasicBlock *Preheader = L.getLoopPreheader();
- if (!Preheader) {
- FailureReason = "no preheader";
- return std::nullopt;
- }
-
- BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
- if (!LatchBr || LatchBr->isUnconditional()) {
- FailureReason = "latch terminator not conditional branch";
- return std::nullopt;
- }
-
- unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
-
- ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
- if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
- FailureReason = "latch terminator branch not conditional on integral icmp";
- return std::nullopt;
- }
-
- const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
- if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
- FailureReason = "could not compute latch count";
- return std::nullopt;
- }
- assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
- ScalarEvolution::LoopInvariant &&
- "loop variant exit count doesn't make sense!");
-
- ICmpInst::Predicate Pred = ICI->getPredicate();
- Value *LeftValue = ICI->getOperand(0);
- const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
- IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
-
- Value *RightValue = ICI->getOperand(1);
- const SCEV *RightSCEV = SE.getSCEV(RightValue);
-
- // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
- if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
- if (isa<SCEVAddRecExpr>(RightSCEV)) {
- std::swap(LeftSCEV, RightSCEV);
- std::swap(LeftValue, RightValue);
- Pred = ICmpInst::getSwappedPredicate(Pred);
- } else {
- FailureReason = "no add recurrences in the icmp";
- return std::nullopt;
- }
- }
-
- auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
- if (AR->getNoWrapFlags(SCEV::FlagNSW))
- return true;
-
- IntegerType *Ty = cast<IntegerType>(AR->getType());
- IntegerType *WideTy =
- IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
-
- const SCEVAddRecExpr *ExtendAfterOp =
- dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
- if (ExtendAfterOp) {
- const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
- const SCEV *ExtendedStep =
- SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
-
- bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
- ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
-
- if (NoSignedWrap)
- return true;
- }
-
- // We may have proved this when computing the sign extension above.
- return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
- };
-
- // `ICI` is interpreted as taking the backedge if the *next* value of the
- // induction variable satisfies some constraint.
-
- const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
- if (IndVarBase->getLoop() != &L) {
- FailureReason = "LHS in cmp is not an AddRec for this loop";
- return std::nullopt;
- }
- if (!IndVarBase->isAffine()) {
- FailureReason = "LHS in icmp not induction variable";
- return std::nullopt;
- }
- const SCEV* StepRec = IndVarBase->getStepRecurrence(SE);
- if (!isa<SCEVConstant>(StepRec)) {
- FailureReason = "LHS in icmp not induction variable";
- return std::nullopt;
- }
- ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
-
- if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
- FailureReason = "LHS in icmp needs nsw for equality predicates";
- return std::nullopt;
- }
-
- assert(!StepCI->isZero() && "Zero step?");
- bool IsIncreasing = !StepCI->isNegative();
- bool IsSignedPredicate;
- const SCEV *StartNext = IndVarBase->getStart();
- const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
- const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
- const SCEV *Step = SE.getSCEV(StepCI);
-
- const SCEV *FixedRightSCEV = nullptr;
-
- // If RightValue resides within loop (but still being loop invariant),
- // regenerate it as preheader.
- if (auto *I = dyn_cast<Instruction>(RightValue))
- if (L.contains(I->getParent()))
- FixedRightSCEV = RightSCEV;
-
- if (IsIncreasing) {
- bool DecreasedRightValueByOne = false;
- if (StepCI->isOne()) {
- // Try to turn eq/ne predicates to those we can work with.
- if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
- // while (++i != len) { while (++i < len) {
- // ... ---> ...
- // } }
- // If both parts are known non-negative, it is profitable to use
- // unsigned comparison in increasing loop. This allows us to make the
- // comparison check against "RightSCEV + 1" more optimistic.
- if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
- isKnownNonNegativeInLoop(RightSCEV, &L, SE))
- Pred = ICmpInst::ICMP_ULT;
- else
- Pred = ICmpInst::ICMP_SLT;
- else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
- // while (true) { while (true) {
- // if (++i == len) ---> if (++i > len - 1)
- // break; break;
- // ... ...
- // } }
- if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
- cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) {
- Pred = ICmpInst::ICMP_UGT;
- RightSCEV = SE.getMinusSCEV(RightSCEV,
- SE.getOne(RightSCEV->getType()));
- DecreasedRightValueByOne = true;
- } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) {
- Pred = ICmpInst::ICMP_SGT;
- RightSCEV = SE.getMinusSCEV(RightSCEV,
- SE.getOne(RightSCEV->getType()));
- DecreasedRightValueByOne = true;
- }
- }
- }
-
- bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
- bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
- bool FoundExpectedPred =
- (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
-
- if (!FoundExpectedPred) {
- FailureReason = "expected icmp slt semantically, found something else";
- return std::nullopt;
- }
-
- IsSignedPredicate = ICmpInst::isSigned(Pred);
- if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
- FailureReason = "unsigned latch conditions are explicitly prohibited";
- return std::nullopt;
- }
-
- if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
- LatchBrExitIdx, &L, SE)) {
- FailureReason = "Unsafe loop bounds";
- return std::nullopt;
- }
- if (LatchBrExitIdx == 0) {
- // We need to increase the right value unless we have already decreased
- // it virtually when we replaced EQ with SGT.
- if (!DecreasedRightValueByOne)
- FixedRightSCEV =
- SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
- } else {
- assert(!DecreasedRightValueByOne &&
- "Right value can be decreased only for LatchBrExitIdx == 0!");
- }
- } else {
- bool IncreasedRightValueByOne = false;
- if (StepCI->isMinusOne()) {
- // Try to turn eq/ne predicates to those we can work with.
- if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
- // while (--i != len) { while (--i > len) {
- // ... ---> ...
- // } }
- // We intentionally don't turn the predicate into UGT even if we know
- // that both operands are non-negative, because it will only pessimize
- // our check against "RightSCEV - 1".
- Pred = ICmpInst::ICMP_SGT;
- else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
- // while (true) { while (true) {
- // if (--i == len) ---> if (--i < len + 1)
- // break; break;
- // ... ...
- // } }
- if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
- cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
- Pred = ICmpInst::ICMP_ULT;
- RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
- IncreasedRightValueByOne = true;
- } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
- Pred = ICmpInst::ICMP_SLT;
- RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
- IncreasedRightValueByOne = true;
- }
- }
- }
-
- bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
- bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
-
- bool FoundExpectedPred =
- (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
-
- if (!FoundExpectedPred) {
- FailureReason = "expected icmp sgt semantically, found something else";
- return std::nullopt;
- }
-
- IsSignedPredicate =
- Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
-
- if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
- FailureReason = "unsigned latch conditions are explicitly prohibited";
- return std::nullopt;
- }
-
- if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
- LatchBrExitIdx, &L, SE)) {
- FailureReason = "Unsafe bounds";
- return std::nullopt;
- }
-
- if (LatchBrExitIdx == 0) {
- // We need to decrease the right value unless we have already increased
- // it virtually when we replaced EQ with SLT.
- if (!IncreasedRightValueByOne)
- FixedRightSCEV =
- SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
- } else {
- assert(!IncreasedRightValueByOne &&
- "Right value can be increased only for LatchBrExitIdx == 0!");
- }
- }
- BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
-
- assert(!L.contains(LatchExit) && "expected an exit block!");
- const DataLayout &DL = Preheader->getModule()->getDataLayout();
- SCEVExpander Expander(SE, DL, "irce");
- Instruction *Ins = Preheader->getTerminator();
-
- if (FixedRightSCEV)
- RightValue =
- Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
-
- Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
- IndVarStartV->setName("indvar.start");
-
- LoopStructure Result;
-
- Result.Tag = "main";
- Result.Header = Header;
- Result.Latch = Latch;
- Result.LatchBr = LatchBr;
- Result.LatchExit = LatchExit;
- Result.LatchBrExitIdx = LatchBrExitIdx;
- Result.IndVarStart = IndVarStartV;
- Result.IndVarStep = StepCI;
- Result.IndVarBase = LeftValue;
- Result.IndVarIncreasing = IsIncreasing;
- Result.LoopExitAt = RightValue;
- Result.IsSignedPredicate = IsSignedPredicate;
-
- FailureReason = nullptr;
-
- return Result;
-}
-
/// If the type of \p S matches with \p Ty, return \p S. Otherwise, return
/// signed or unsigned extension of \p S to type \p Ty.
static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE,
@@ -1204,17 +557,23 @@ static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE,
return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty);
}
-std::optional<LoopConstrainer::SubRanges>
-LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
+// Compute a safe set of limits for the main loop to run in -- effectively the
+// intersection of `Range' and the iteration space of the original loop.
+// Return std::nullopt if unable to compute the set of subranges.
+static std::optional<LoopConstrainer::SubRanges>
+calculateSubRanges(ScalarEvolution &SE, const Loop &L,
+ InductiveRangeCheck::Range &Range,
+ const LoopStructure &MainLoopStructure) {
auto *RTy = cast<IntegerType>(Range.getType());
// We only support wide range checks and narrow latches.
- if (!AllowNarrowLatchCondition && RTy != ExitCountTy)
+ if (!AllowNarrowLatchCondition && RTy != MainLoopStructure.ExitCountTy)
return std::nullopt;
- if (RTy->getBitWidth() < ExitCountTy->getBitWidth())
+ if (RTy->getBitWidth() < MainLoopStructure.ExitCountTy->getBitWidth())
return std::nullopt;
LoopConstrainer::SubRanges Result;
+ bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
// I think we can be more aggressive here and make this nuw / nsw if the
// addition that feeds into the icmp for the latch's terminating branch is nuw
// / nsw. In any case, a wrapping 2's complement addition is safe.
@@ -1245,7 +604,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
// `End`, decrementing by one every time.
//
// * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
- // induction variable is decreasing we know that that the smallest value
+ // induction variable is decreasing we know that the smallest value
// the loop body is actually executed with is `INT_SMIN` == `Smallest`.
//
// * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. In
@@ -1258,7 +617,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
GreatestSeen = Start;
}
- auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) {
+ auto Clamp = [&SE, Smallest, Greatest, IsSignedPredicate](const SCEV *S) {
return IsSignedPredicate
? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S))
: SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S));
@@ -1283,464 +642,6 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
return Result;
}
-void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
- const char *Tag) const {
- for (BasicBlock *BB : OriginalLoop.getBlocks()) {
- BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
- Result.Blocks.push_back(Clone);
- Result.Map[BB] = Clone;
- }
-
- auto GetClonedValue = [&Result](Value *V) {
- assert(V && "null values not in domain!");
- auto It = Result.Map.find(V);
- if (It == Result.Map.end())
- return V;
- return static_cast<Value *>(It->second);
- };
-
- auto *ClonedLatch =
- cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
- ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
- MDNode::get(Ctx, {}));
-
- Result.Structure = MainLoopStructure.map(GetClonedValue);
- Result.Structure.Tag = Tag;
-
- for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
- BasicBlock *ClonedBB = Result.Blocks[i];
- BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
-
- assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
-
- for (Instruction &I : *ClonedBB)
- RemapInstruction(&I, Result.Map,
- RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
-
- // Exit blocks will now have one more predecessor and their PHI nodes need
- // to be edited to reflect that. No phi nodes need to be introduced because
- // the loop is in LCSSA.
-
- for (auto *SBB : successors(OriginalBB)) {
- if (OriginalLoop.contains(SBB))
- continue; // not an exit block
-
- for (PHINode &PN : SBB->phis()) {
- Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
- PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
- SE.forgetValue(&PN);
- }
- }
- }
-}
-
-LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
- const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
- BasicBlock *ContinuationBlock) const {
- // We start with a loop with a single latch:
- //
- // +--------------------+
- // | |
- // | preheader |
- // | |
- // +--------+-----------+
- // | ----------------\
- // | / |
- // +--------v----v------+ |
- // | | |
- // | header | |
- // | | |
- // +--------------------+ |
- // |
- // ..... |
- // |
- // +--------------------+ |
- // | | |
- // | latch >----------/
- // | |
- // +-------v------------+
- // |
- // |
- // | +--------------------+
- // | | |
- // +---> original exit |
- // | |
- // +--------------------+
- //
- // We change the control flow to look like
- //
- //
- // +--------------------+
- // | |
- // | preheader >-------------------------+
- // | | |
- // +--------v-----------+ |
- // | /-------------+ |
- // | / | |
- // +--------v--v--------+ | |
- // | | | |
- // | header | | +--------+ |
- // | | | | | |
- // +--------------------+ | | +-----v-----v-----------+
- // | | | |
- // | | | .pseudo.exit |
- // | | | |
- // | | +-----------v-----------+
- // | | |
- // ..... | | |
- // | | +--------v-------------+
- // +--------------------+ | | | |
- // | | | | | ContinuationBlock |
- // | latch >------+ | | |
- // | | | +----------------------+
- // +---------v----------+ |
- // | |
- // | |
- // | +---------------^-----+
- // | | |
- // +-----> .exit.selector |
- // | |
- // +----------v----------+
- // |
- // +--------------------+ |
- // | | |
- // | original exit <----+
- // | |
- // +--------------------+
-
- RewrittenRangeInfo RRI;
-
- BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
- RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
- &F, BBInsertLocation);
- RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
- BBInsertLocation);
-
- BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
- bool Increasing = LS.IndVarIncreasing;
- bool IsSignedPredicate = LS.IsSignedPredicate;
-
- IRBuilder<> B(PreheaderJump);
- auto *RangeTy = Range.getBegin()->getType();
- auto NoopOrExt = [&](Value *V) {
- if (V->getType() == RangeTy)
- return V;
- return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
- : B.CreateZExt(V, RangeTy, "wide." + V->getName());
- };
-
- // EnterLoopCond - is it okay to start executing this `LS'?
- Value *EnterLoopCond = nullptr;
- auto Pred =
- Increasing
- ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
- : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
- Value *IndVarStart = NoopOrExt(LS.IndVarStart);
- EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
-
- B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
- PreheaderJump->eraseFromParent();
-
- LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
- B.SetInsertPoint(LS.LatchBr);
- Value *IndVarBase = NoopOrExt(LS.IndVarBase);
- Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
-
- Value *CondForBranch = LS.LatchBrExitIdx == 1
- ? TakeBackedgeLoopCond
- : B.CreateNot(TakeBackedgeLoopCond);
-
- LS.LatchBr->setCondition(CondForBranch);
-
- B.SetInsertPoint(RRI.ExitSelector);
-
- // IterationsLeft - are there any more iterations left, given the original
- // upper bound on the induction variable? If not, we branch to the "real"
- // exit.
- Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
- Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
- B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
-
- BranchInst *BranchToContinuation =
- BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
-
- // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
- // each of the PHI nodes in the loop header. This feeds into the initial
- // value of the same PHI nodes if/when we continue execution.
- for (PHINode &PN : LS.Header->phis()) {
- PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
- BranchToContinuation);
-
- NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
- NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
- RRI.ExitSelector);
- RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
- }
-
- RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
- BranchToContinuation);
- RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
- RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
-
- // The latch exit now has a branch from `RRI.ExitSelector' instead of
- // `LS.Latch'. The PHI nodes need to be updated to reflect that.
- LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
-
- return RRI;
-}
-
-void LoopConstrainer::rewriteIncomingValuesForPHIs(
- LoopStructure &LS, BasicBlock *ContinuationBlock,
- const LoopConstrainer::RewrittenRangeInfo &RRI) const {
- unsigned PHIIndex = 0;
- for (PHINode &PN : LS.Header->phis())
- PN.setIncomingValueForBlock(ContinuationBlock,
- RRI.PHIValuesAtPseudoExit[PHIIndex++]);
-
- LS.IndVarStart = RRI.IndVarEnd;
-}
-
-BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
- BasicBlock *OldPreheader,
- const char *Tag) const {
- BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
- BranchInst::Create(LS.Header, Preheader);
-
- LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
-
- return Preheader;
-}
-
-void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
- Loop *ParentLoop = OriginalLoop.getParentLoop();
- if (!ParentLoop)
- return;
-
- for (BasicBlock *BB : BBs)
- ParentLoop->addBasicBlockToLoop(BB, LI);
-}
-
-Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
- ValueToValueMapTy &VM,
- bool IsSubloop) {
- Loop &New = *LI.AllocateLoop();
- if (Parent)
- Parent->addChildLoop(&New);
- else
- LI.addTopLevelLoop(&New);
- LPMAddNewLoop(&New, IsSubloop);
-
- // Add all of the blocks in Original to the new loop.
- for (auto *BB : Original->blocks())
- if (LI.getLoopFor(BB) == Original)
- New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
-
- // Add all of the subloops to the new loop.
- for (Loop *SubLoop : *Original)
- createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
-
- return &New;
-}
-
-bool LoopConstrainer::run() {
- BasicBlock *Preheader = nullptr;
- const SCEV *MaxBETakenCount =
- getNarrowestLatchMaxTakenCountEstimate(SE, OriginalLoop);
- Preheader = OriginalLoop.getLoopPreheader();
- assert(!isa<SCEVCouldNotCompute>(MaxBETakenCount) && Preheader != nullptr &&
- "preconditions!");
- ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
-
- OriginalPreheader = Preheader;
- MainLoopPreheader = Preheader;
-
- bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
- std::optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate);
- if (!MaybeSR) {
- LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n");
- return false;
- }
-
- SubRanges SR = *MaybeSR;
- bool Increasing = MainLoopStructure.IndVarIncreasing;
- IntegerType *IVTy =
- cast<IntegerType>(Range.getBegin()->getType());
-
- SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");
- Instruction *InsertPt = OriginalPreheader->getTerminator();
-
- // It would have been better to make `PreLoop' and `PostLoop'
- // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
- // constructor.
- ClonedLoop PreLoop, PostLoop;
- bool NeedsPreLoop =
- Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
- bool NeedsPostLoop =
- Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
-
- Value *ExitPreLoopAt = nullptr;
- Value *ExitMainLoopAt = nullptr;
- const SCEVConstant *MinusOneS =
- cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
-
- if (NeedsPreLoop) {
- const SCEV *ExitPreLoopAtSCEV = nullptr;
-
- if (Increasing)
- ExitPreLoopAtSCEV = *SR.LowLimit;
- else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
- IsSignedPredicate))
- ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
- else {
- LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
- << "preloop exit limit. HighLimit = "
- << *(*SR.HighLimit) << "\n");
- return false;
- }
-
- if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
- LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
- << " preloop exit limit " << *ExitPreLoopAtSCEV
- << " at block " << InsertPt->getParent()->getName()
- << "\n");
- return false;
- }
-
- ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
- ExitPreLoopAt->setName("exit.preloop.at");
- }
-
- if (NeedsPostLoop) {
- const SCEV *ExitMainLoopAtSCEV = nullptr;
-
- if (Increasing)
- ExitMainLoopAtSCEV = *SR.HighLimit;
- else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
- IsSignedPredicate))
- ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
- else {
- LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
- << "mainloop exit limit. LowLimit = "
- << *(*SR.LowLimit) << "\n");
- return false;
- }
-
- if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
- LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
- << " main loop exit limit " << *ExitMainLoopAtSCEV
- << " at block " << InsertPt->getParent()->getName()
- << "\n");
- return false;
- }
-
- ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
- ExitMainLoopAt->setName("exit.mainloop.at");
- }
-
- // We clone these ahead of time so that we don't have to deal with changing
- // and temporarily invalid IR as we transform the loops.
- if (NeedsPreLoop)
- cloneLoop(PreLoop, "preloop");
- if (NeedsPostLoop)
- cloneLoop(PostLoop, "postloop");
-
- RewrittenRangeInfo PreLoopRRI;
-
- if (NeedsPreLoop) {
- Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
- PreLoop.Structure.Header);
-
- MainLoopPreheader =
- createPreheader(MainLoopStructure, Preheader, "mainloop");
- PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
- ExitPreLoopAt, MainLoopPreheader);
- rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
- PreLoopRRI);
- }
-
- BasicBlock *PostLoopPreheader = nullptr;
- RewrittenRangeInfo PostLoopRRI;
-
- if (NeedsPostLoop) {
- PostLoopPreheader =
- createPreheader(PostLoop.Structure, Preheader, "postloop");
- PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
- ExitMainLoopAt, PostLoopPreheader);
- rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
- PostLoopRRI);
- }
-
- BasicBlock *NewMainLoopPreheader =
- MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
- BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit,
- PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit,
- PostLoopRRI.ExitSelector, NewMainLoopPreheader};
-
- // Some of the above may be nullptr, filter them out before passing to
- // addToParentLoopIfNeeded.
- auto NewBlocksEnd =
- std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
-
- addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
-
- DT.recalculate(F);
-
- // We need to first add all the pre and post loop blocks into the loop
- // structures (as part of createClonedLoopStructure), and then update the
- // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
- // LI when LoopSimplifyForm is generated.
- Loop *PreL = nullptr, *PostL = nullptr;
- if (!PreLoop.Blocks.empty()) {
- PreL = createClonedLoopStructure(&OriginalLoop,
- OriginalLoop.getParentLoop(), PreLoop.Map,
- /* IsSubLoop */ false);
- }
-
- if (!PostLoop.Blocks.empty()) {
- PostL =
- createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
- PostLoop.Map, /* IsSubLoop */ false);
- }
-
- // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
- auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) {
- formLCSSARecursively(*L, DT, &LI, &SE);
- simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
- // Pre/post loops are slow paths, we do not need to perform any loop
- // optimizations on them.
- if (!IsOriginalLoop)
- DisableAllLoopOptsOnLoop(*L);
- };
- if (PreL)
- CanonicalizeLoop(PreL, false);
- if (PostL)
- CanonicalizeLoop(PostL, false);
- CanonicalizeLoop(&OriginalLoop, true);
-
- /// At this point:
- /// - We've broken a "main loop" out of the loop in a way that the "main loop"
- /// runs with the induction variable in a subset of [Begin, End).
- /// - There is no overflow when computing "main loop" exit limit.
- /// - Max latch taken count of the loop is limited.
- /// It guarantees that induction variable will not overflow iterating in the
- /// "main loop".
- if (auto BO = dyn_cast<BinaryOperator>(MainLoopStructure.IndVarBase))
- if (IsSignedPredicate)
- BO->setHasNoSignedWrap(true);
- /// TODO: support unsigned predicate.
- /// To add NUW flag we need to prove that both operands of BO are
- /// non-negative. E.g:
- /// ...
- /// %iv.next = add nsw i32 %iv, -1
- /// %cmp = icmp ult i32 %iv.next, %n
- /// br i1 %cmp, label %loopexit, label %loop
- ///
- /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
- /// overflow, therefore NUW flag is not legal here.
-
- return true;
-}
-
/// Computes and returns a range of values for the induction variable (IndVar)
/// in which the range check can be safely elided. If it cannot compute such a
/// range, returns std::nullopt.
@@ -2108,7 +1009,8 @@ bool InductiveRangeCheckElimination::run(
const char *FailureReason = nullptr;
std::optional<LoopStructure> MaybeLoopStructure =
- LoopStructure::parseLoopStructure(SE, *L, FailureReason);
+ LoopStructure::parseLoopStructure(SE, *L, AllowUnsignedLatchCondition,
+ FailureReason);
if (!MaybeLoopStructure) {
LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: "
<< FailureReason << "\n";);
@@ -2147,7 +1049,15 @@ bool InductiveRangeCheckElimination::run(
if (!SafeIterRange)
return Changed;
- LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, *SafeIterRange);
+ std::optional<LoopConstrainer::SubRanges> MaybeSR =
+ calculateSubRanges(SE, *L, *SafeIterRange, LS);
+ if (!MaybeSR) {
+ LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n");
+ return false;
+ }
+
+ LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT,
+ SafeIterRange->getBegin()->getType(), *MaybeSR);
if (LC.run()) {
Changed = true;
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index c2b5a12fd63f..1bf50d79e533 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -164,9 +164,13 @@ class InferAddressSpaces : public FunctionPass {
public:
static char ID;
- InferAddressSpaces() :
- FunctionPass(ID), FlatAddrSpace(UninitializedAddressSpace) {}
- InferAddressSpaces(unsigned AS) : FunctionPass(ID), FlatAddrSpace(AS) {}
+ InferAddressSpaces()
+ : FunctionPass(ID), FlatAddrSpace(UninitializedAddressSpace) {
+ initializeInferAddressSpacesPass(*PassRegistry::getPassRegistry());
+ }
+ InferAddressSpaces(unsigned AS) : FunctionPass(ID), FlatAddrSpace(AS) {
+ initializeInferAddressSpacesPass(*PassRegistry::getPassRegistry());
+ }
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
@@ -221,8 +225,8 @@ class InferAddressSpacesImpl {
Value *V, PostorderStackTy &PostorderStack,
DenseSet<Value *> &Visited) const;
- bool rewriteIntrinsicOperands(IntrinsicInst *II,
- Value *OldV, Value *NewV) const;
+ bool rewriteIntrinsicOperands(IntrinsicInst *II, Value *OldV,
+ Value *NewV) const;
void collectRewritableIntrinsicOperands(IntrinsicInst *II,
PostorderStackTy &PostorderStack,
DenseSet<Value *> &Visited) const;
@@ -473,7 +477,7 @@ void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack(
}
// Returns all flat address expressions in function F. The elements are ordered
-// ordered in postorder.
+// in postorder.
std::vector<WeakTrackingVH>
InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
// This function implements a non-recursive postorder traversal of a partial
@@ -483,8 +487,7 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
DenseSet<Value *> Visited;
auto PushPtrOperand = [&](Value *Ptr) {
- appendsFlatAddressExpressionToPostorderStack(Ptr, PostorderStack,
- Visited);
+ appendsFlatAddressExpressionToPostorderStack(Ptr, PostorderStack, Visited);
};
// Look at operations that may be interesting accelerate by moving to a known
@@ -519,8 +522,11 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
PushPtrOperand(ASC->getPointerOperand());
} else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
- PushPtrOperand(
- cast<Operator>(I2P->getOperand(0))->getOperand(0));
+ PushPtrOperand(cast<Operator>(I2P->getOperand(0))->getOperand(0));
+ } else if (auto *RI = dyn_cast<ReturnInst>(&I)) {
+ if (auto *RV = RI->getReturnValue();
+ RV && RV->getType()->isPtrOrPtrVectorTy())
+ PushPtrOperand(RV);
}
}
@@ -923,12 +929,14 @@ bool InferAddressSpacesImpl::updateAddressSpace(
Value *Src1 = Op.getOperand(2);
auto I = InferredAddrSpace.find(Src0);
- unsigned Src0AS = (I != InferredAddrSpace.end()) ?
- I->second : Src0->getType()->getPointerAddressSpace();
+ unsigned Src0AS = (I != InferredAddrSpace.end())
+ ? I->second
+ : Src0->getType()->getPointerAddressSpace();
auto J = InferredAddrSpace.find(Src1);
- unsigned Src1AS = (J != InferredAddrSpace.end()) ?
- J->second : Src1->getType()->getPointerAddressSpace();
+ unsigned Src1AS = (J != InferredAddrSpace.end())
+ ? J->second
+ : Src1->getType()->getPointerAddressSpace();
auto *C0 = dyn_cast<Constant>(Src0);
auto *C1 = dyn_cast<Constant>(Src1);
@@ -1097,7 +1105,8 @@ bool InferAddressSpacesImpl::isSafeToCastConstAddrSpace(Constant *C,
// If we already have a constant addrspacecast, it should be safe to cast it
// off.
if (Op->getOpcode() == Instruction::AddrSpaceCast)
- return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS);
+ return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)),
+ NewAS);
if (Op->getOpcode() == Instruction::IntToPtr &&
Op->getType()->getPointerAddressSpace() == FlatAddrSpace)
@@ -1128,7 +1137,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
// construction.
ValueToValueMapTy ValueWithNewAddrSpace;
SmallVector<const Use *, 32> PoisonUsesToFix;
- for (Value* V : Postorder) {
+ for (Value *V : Postorder) {
unsigned NewAddrSpace = InferredAddrSpace.lookup(V);
// In some degenerate cases (e.g. invalid IR in unreachable code), we may
@@ -1161,6 +1170,8 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
}
SmallVector<Instruction *, 16> DeadInstructions;
+ ValueToValueMapTy VMap;
+ ValueMapper VMapper(VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
// Replaces the uses of the old address expressions with the new ones.
for (const WeakTrackingVH &WVH : Postorder) {
@@ -1174,18 +1185,41 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
<< *NewV << '\n');
if (Constant *C = dyn_cast<Constant>(V)) {
- Constant *Replace = ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
- C->getType());
+ Constant *Replace =
+ ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), C->getType());
if (C != Replace) {
LLVM_DEBUG(dbgs() << "Inserting replacement const cast: " << Replace
<< ": " << *Replace << '\n');
- C->replaceAllUsesWith(Replace);
+ SmallVector<User *, 16> WorkList;
+ for (User *U : make_early_inc_range(C->users())) {
+ if (auto *I = dyn_cast<Instruction>(U)) {
+ if (I->getFunction() == F)
+ I->replaceUsesOfWith(C, Replace);
+ } else {
+ WorkList.append(U->user_begin(), U->user_end());
+ }
+ }
+ if (!WorkList.empty()) {
+ VMap[C] = Replace;
+ DenseSet<User *> Visited{WorkList.begin(), WorkList.end()};
+ while (!WorkList.empty()) {
+ User *U = WorkList.pop_back_val();
+ if (auto *I = dyn_cast<Instruction>(U)) {
+ if (I->getFunction() == F)
+ VMapper.remapInstruction(*I);
+ continue;
+ }
+ for (User *U2 : U->users())
+ if (Visited.insert(U2).second)
+ WorkList.push_back(U2);
+ }
+ }
V = Replace;
}
}
Value::use_iterator I, E, Next;
- for (I = V->use_begin(), E = V->use_end(); I != E; ) {
+ for (I = V->use_begin(), E = V->use_end(); I != E;) {
Use &U = *I;
// Some users may see the same pointer operand in multiple operands. Skip
@@ -1205,6 +1239,11 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
// Skip if the current user is the new value itself.
if (CurUser == NewV)
continue;
+
+ if (auto *CurUserI = dyn_cast<Instruction>(CurUser);
+ CurUserI && CurUserI->getFunction() != F)
+ continue;
+
// Handle more complex cases like intrinsic that need to be remangled.
if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
@@ -1241,8 +1280,8 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
Cmp->setOperand(SrcIdx, NewV);
- Cmp->setOperand(OtherIdx,
- ConstantExpr::getAddrSpaceCast(KOtherSrc, NewV->getType()));
+ Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
+ KOtherSrc, NewV->getType()));
continue;
}
}
diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
new file mode 100644
index 000000000000..b75b8d486fbb
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
@@ -0,0 +1,91 @@
+//===- InferAlignment.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Infer alignment for load, stores and other memory operations based on
+// trailing zero known bits information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/InferAlignment.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+using namespace llvm;
+
+static bool tryToImproveAlign(
+ const DataLayout &DL, Instruction *I,
+ function_ref<Align(Value *PtrOp, Align OldAlign, Align PrefAlign)> Fn) {
+ if (auto *LI = dyn_cast<LoadInst>(I)) {
+ Value *PtrOp = LI->getPointerOperand();
+ Align OldAlign = LI->getAlign();
+ Align NewAlign = Fn(PtrOp, OldAlign, DL.getPrefTypeAlign(LI->getType()));
+ if (NewAlign > OldAlign) {
+ LI->setAlignment(NewAlign);
+ return true;
+ }
+ } else if (auto *SI = dyn_cast<StoreInst>(I)) {
+ Value *PtrOp = SI->getPointerOperand();
+ Value *ValOp = SI->getValueOperand();
+ Align OldAlign = SI->getAlign();
+ Align NewAlign = Fn(PtrOp, OldAlign, DL.getPrefTypeAlign(ValOp->getType()));
+ if (NewAlign > OldAlign) {
+ SI->setAlignment(NewAlign);
+ return true;
+ }
+ }
+ // TODO: Also handle memory intrinsics.
+ return false;
+}
+
+bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ bool Changed = false;
+
+ // Enforce preferred type alignment if possible. We do this as a separate
+ // pass first, because it may improve the alignments we infer below.
+ for (BasicBlock &BB : F) {
+ for (Instruction &I : BB) {
+ Changed |= tryToImproveAlign(
+ DL, &I, [&](Value *PtrOp, Align OldAlign, Align PrefAlign) {
+ if (PrefAlign > OldAlign)
+ return std::max(OldAlign,
+ tryEnforceAlignment(PtrOp, PrefAlign, DL));
+ return OldAlign;
+ });
+ }
+ }
+
+ // Compute alignment from known bits.
+ for (BasicBlock &BB : F) {
+ for (Instruction &I : BB) {
+ Changed |= tryToImproveAlign(
+ DL, &I, [&](Value *PtrOp, Align OldAlign, Align PrefAlign) {
+ KnownBits Known = computeKnownBits(PtrOp, DL, 0, &AC, &I, &DT);
+ unsigned TrailZ = std::min(Known.countMinTrailingZeros(),
+ +Value::MaxAlignmentExponent);
+ return Align(1ull << std::min(Known.getBitWidth() - 1, TrailZ));
+ });
+ }
+ }
+
+ return Changed;
+}
+
+PreservedAnalyses InferAlignmentPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
+ DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ inferAlignment(F, AC, DT);
+ // Changes to alignment shouldn't invalidated analyses.
+ return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 24390f1b54f6..8603c5cf9c02 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -102,11 +102,6 @@ static cl::opt<unsigned> PhiDuplicateThreshold(
cl::desc("Max PHIs in BB to duplicate for jump threading"), cl::init(76),
cl::Hidden);
-static cl::opt<bool> PrintLVIAfterJumpThreading(
- "print-lvi-after-jump-threading",
- cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false),
- cl::Hidden);
-
static cl::opt<bool> ThreadAcrossLoopHeaders(
"jump-threading-across-loop-headers",
cl::desc("Allow JumpThreading to thread across loop headers, for testing"),
@@ -228,17 +223,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
if (BP >= BranchProbability(50, 100))
continue;
- SmallVector<uint32_t, 2> Weights;
+ uint32_t Weights[2];
if (PredBr->getSuccessor(0) == PredOutEdge.second) {
- Weights.push_back(BP.getNumerator());
- Weights.push_back(BP.getCompl().getNumerator());
+ Weights[0] = BP.getNumerator();
+ Weights[1] = BP.getCompl().getNumerator();
} else {
- Weights.push_back(BP.getCompl().getNumerator());
- Weights.push_back(BP.getNumerator());
+ Weights[0] = BP.getCompl().getNumerator();
+ Weights[1] = BP.getNumerator();
}
- PredBr->setMetadata(LLVMContext::MD_prof,
- MDBuilder(PredBr->getParent()->getContext())
- .createBranchWeights(Weights));
+ setBranchWeights(*PredBr, Weights);
}
}
@@ -259,11 +252,6 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,
&DT, nullptr, DomTreeUpdater::UpdateStrategy::Lazy),
std::nullopt, std::nullopt);
- if (PrintLVIAfterJumpThreading) {
- dbgs() << "LVI for function '" << F.getName() << "':\n";
- LVI.printLVI(F, getDomTreeUpdater()->getDomTree(), dbgs());
- }
-
if (!Changed)
return PreservedAnalyses::all();
@@ -412,6 +400,10 @@ static bool replaceFoldableUses(Instruction *Cond, Value *ToVal,
if (Cond->getParent() == KnownAtEndOfBB)
Changed |= replaceNonLocalUsesWith(Cond, ToVal);
for (Instruction &I : reverse(*KnownAtEndOfBB)) {
+ // Replace any debug-info record users of Cond with ToVal.
+ for (DPValue &DPV : I.getDbgValueRange())
+ DPV.replaceVariableLocationOp(Cond, ToVal, true);
+
// Reached the Cond whose uses we are trying to replace, so there are no
// more uses.
if (&I == Cond)
@@ -568,6 +560,8 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
Value *V, BasicBlock *BB, PredValueInfo &Result,
ConstantPreference Preference, DenseSet<Value *> &RecursionSet,
Instruction *CxtI) {
+ const DataLayout &DL = BB->getModule()->getDataLayout();
+
// This method walks up use-def chains recursively. Because of this, we could
// get into an infinite loop going around loops in the use-def chain. To
// prevent this, keep track of what (value, block) pairs we've already visited
@@ -635,16 +629,19 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
// Handle Cast instructions.
if (CastInst *CI = dyn_cast<CastInst>(I)) {
Value *Source = CI->getOperand(0);
- computeValueKnownInPredecessorsImpl(Source, BB, Result, Preference,
+ PredValueInfoTy Vals;
+ computeValueKnownInPredecessorsImpl(Source, BB, Vals, Preference,
RecursionSet, CxtI);
- if (Result.empty())
+ if (Vals.empty())
return false;
// Convert the known values.
- for (auto &R : Result)
- R.first = ConstantExpr::getCast(CI->getOpcode(), R.first, CI->getType());
+ for (auto &Val : Vals)
+ if (Constant *Folded = ConstantFoldCastOperand(CI->getOpcode(), Val.first,
+ CI->getType(), DL))
+ Result.emplace_back(Folded, Val.second);
- return true;
+ return !Result.empty();
}
if (FreezeInst *FI = dyn_cast<FreezeInst>(I)) {
@@ -726,7 +723,6 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
if (Preference != WantInteger)
return false;
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
- const DataLayout &DL = BO->getModule()->getDataLayout();
PredValueInfoTy LHSVals;
computeValueKnownInPredecessorsImpl(BO->getOperand(0), BB, LHSVals,
WantInteger, RecursionSet, CxtI);
@@ -757,7 +753,10 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
PHINode *PN = dyn_cast<PHINode>(CmpLHS);
if (!PN)
PN = dyn_cast<PHINode>(CmpRHS);
- if (PN && PN->getParent() == BB) {
+ // Do not perform phi translation across a loop header phi, because this
+ // may result in comparison of values from two different loop iterations.
+ // FIXME: This check is broken if LoopHeaders is not populated.
+ if (PN && PN->getParent() == BB && !LoopHeaders.contains(BB)) {
const DataLayout &DL = PN->getModule()->getDataLayout();
// We can do this simplification if any comparisons fold to true or false.
// See if any do.
@@ -1269,6 +1268,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) {
if (IsLoadCSE) {
LoadInst *NLoadI = cast<LoadInst>(AvailableVal);
combineMetadataForCSE(NLoadI, LoadI, false);
+ LVI->forgetValue(NLoadI);
};
// If the returned value is the load itself, replace with poison. This can
@@ -1432,8 +1432,8 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) {
// Create a PHI node at the start of the block for the PRE'd load value.
pred_iterator PB = pred_begin(LoadBB), PE = pred_end(LoadBB);
- PHINode *PN = PHINode::Create(LoadI->getType(), std::distance(PB, PE), "",
- &LoadBB->front());
+ PHINode *PN = PHINode::Create(LoadI->getType(), std::distance(PB, PE), "");
+ PN->insertBefore(LoadBB->begin());
PN->takeName(LoadI);
PN->setDebugLoc(LoadI->getDebugLoc());
@@ -1461,6 +1461,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) {
for (LoadInst *PredLoadI : CSELoads) {
combineMetadataForCSE(PredLoadI, LoadI, true);
+ LVI->forgetValue(PredLoadI);
}
LoadI->replaceAllUsesWith(PN);
@@ -1899,7 +1900,7 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) {
return false;
const Instruction *TI = SinglePred->getTerminator();
- if (TI->isExceptionalTerminator() || TI->getNumSuccessors() != 1 ||
+ if (TI->isSpecialTerminator() || TI->getNumSuccessors() != 1 ||
SinglePred == BB || hasAddressTakenAndUsed(BB))
return false;
@@ -1954,6 +1955,7 @@ void JumpThreadingPass::updateSSA(
SSAUpdater SSAUpdate;
SmallVector<Use *, 16> UsesToRename;
SmallVector<DbgValueInst *, 4> DbgValues;
+ SmallVector<DPValue *, 4> DPValues;
for (Instruction &I : *BB) {
// Scan all uses of this instruction to see if it is used outside of its
@@ -1970,15 +1972,16 @@ void JumpThreadingPass::updateSSA(
}
// Find debug values outside of the block
- findDbgValues(DbgValues, &I);
- DbgValues.erase(remove_if(DbgValues,
- [&](const DbgValueInst *DbgVal) {
- return DbgVal->getParent() == BB;
- }),
- DbgValues.end());
+ findDbgValues(DbgValues, &I, &DPValues);
+ llvm::erase_if(DbgValues, [&](const DbgValueInst *DbgVal) {
+ return DbgVal->getParent() == BB;
+ });
+ llvm::erase_if(DPValues, [&](const DPValue *DPVal) {
+ return DPVal->getParent() == BB;
+ });
// If there are no uses outside the block, we're done with this instruction.
- if (UsesToRename.empty() && DbgValues.empty())
+ if (UsesToRename.empty() && DbgValues.empty() && DPValues.empty())
continue;
LLVM_DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n");
@@ -1991,9 +1994,11 @@ void JumpThreadingPass::updateSSA(
while (!UsesToRename.empty())
SSAUpdate.RewriteUse(*UsesToRename.pop_back_val());
- if (!DbgValues.empty()) {
+ if (!DbgValues.empty() || !DPValues.empty()) {
SSAUpdate.UpdateDebugValues(&I, DbgValues);
+ SSAUpdate.UpdateDebugValues(&I, DPValues);
DbgValues.clear();
+ DPValues.clear();
}
LLVM_DEBUG(dbgs() << "\n");
@@ -2036,6 +2041,26 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI,
return true;
};
+ // Duplicate implementation of the above dbg.value code, using DPValues
+ // instead.
+ auto RetargetDPValueIfPossible = [&](DPValue *DPV) {
+ SmallSet<std::pair<Value *, Value *>, 16> OperandsToRemap;
+ for (auto *Op : DPV->location_ops()) {
+ Instruction *OpInst = dyn_cast<Instruction>(Op);
+ if (!OpInst)
+ continue;
+
+ auto I = ValueMapping.find(OpInst);
+ if (I != ValueMapping.end())
+ OperandsToRemap.insert({OpInst, I->second});
+ }
+
+ for (auto &[OldOp, MappedOp] : OperandsToRemap)
+ DPV->replaceVariableLocationOp(OldOp, MappedOp);
+ };
+
+ BasicBlock *RangeBB = BI->getParent();
+
// Clone the phi nodes of the source basic block into NewBB. The resulting
// phi nodes are trivial since NewBB only has one predecessor, but SSAUpdater
// might need to rewrite the operand of the cloned phi.
@@ -2054,6 +2079,12 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI,
identifyNoAliasScopesToClone(BI, BE, NoAliasScopes);
cloneNoAliasScopes(NoAliasScopes, ClonedScopes, "thread", Context);
+ auto CloneAndRemapDbgInfo = [&](Instruction *NewInst, Instruction *From) {
+ auto DPVRange = NewInst->cloneDebugInfoFrom(From);
+ for (DPValue &DPV : DPVRange)
+ RetargetDPValueIfPossible(&DPV);
+ };
+
// Clone the non-phi instructions of the source basic block into NewBB,
// keeping track of the mapping and using it to remap operands in the cloned
// instructions.
@@ -2064,6 +2095,8 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI,
ValueMapping[&*BI] = New;
adaptNoAliasScopes(New, ClonedScopes, Context);
+ CloneAndRemapDbgInfo(New, &*BI);
+
if (RetargetDbgValueIfPossible(New))
continue;
@@ -2076,6 +2109,17 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI,
}
}
+ // There may be DPValues on the terminator, clone directly from marker
+ // to marker as there isn't an instruction there.
+ if (BE != RangeBB->end() && BE->hasDbgValues()) {
+ // Dump them at the end.
+ DPMarker *Marker = RangeBB->getMarker(BE);
+ DPMarker *EndMarker = NewBB->createMarker(NewBB->end());
+ auto DPVRange = EndMarker->cloneDebugInfoFrom(Marker, std::nullopt);
+ for (DPValue &DPV : DPVRange)
+ RetargetDPValueIfPossible(&DPV);
+ }
+
return ValueMapping;
}
@@ -2245,7 +2289,7 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB,
assert(BPI && "It's expected BPI to exist along with BFI");
auto NewBBFreq = BFI->getBlockFreq(PredPredBB) *
BPI->getEdgeProbability(PredPredBB, PredBB);
- BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+ BFI->setBlockFreq(NewBB, NewBBFreq);
}
// We are going to have to map operands from the original BB block to the new
@@ -2371,7 +2415,7 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB,
assert(BPI && "It's expected BPI to exist along with BFI");
auto NewBBFreq =
BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB);
- BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+ BFI->setBlockFreq(NewBB, NewBBFreq);
}
// Copy all the instructions from BB to NewBB except the terminator.
@@ -2456,7 +2500,7 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB,
NewBBFreq += FreqMap.lookup(Pred);
}
if (BFI) // Apply the summed frequency to NewBB.
- BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+ BFI->setBlockFreq(NewBB, NewBBFreq);
}
DTU->applyUpdatesPermissive(Updates);
@@ -2496,7 +2540,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
auto NewBBFreq = BFI->getBlockFreq(NewBB);
auto BB2SuccBBFreq = BBOrigFreq * BPI->getEdgeProbability(BB, SuccBB);
auto BBNewFreq = BBOrigFreq - NewBBFreq;
- BFI->setBlockFreq(BB, BBNewFreq.getFrequency());
+ BFI->setBlockFreq(BB, BBNewFreq);
// Collect updated outgoing edges' frequencies from BB and use them to update
// edge probabilities.
@@ -2567,9 +2611,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
Weights.push_back(Prob.getNumerator());
auto TI = BB->getTerminator();
- TI->setMetadata(
- LLVMContext::MD_prof,
- MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
}
}
@@ -2663,6 +2705,9 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
if (!New->mayHaveSideEffects()) {
New->eraseFromParent();
New = nullptr;
+ // Clone debug-info on the elided instruction to the destination
+ // position.
+ OldPredBranch->cloneDebugInfoFrom(&*BI, std::nullopt, true);
}
} else {
ValueMapping[&*BI] = New;
@@ -2670,6 +2715,8 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
if (New) {
// Otherwise, insert the new instruction into the block.
New->setName(BI->getName());
+ // Clone across any debug-info attached to the old instruction.
+ New->cloneDebugInfoFrom(&*BI);
// Update Dominance from simplified New instruction operands.
for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i)
if (BasicBlock *SuccBB = dyn_cast<BasicBlock>(New->getOperand(i)))
@@ -2754,7 +2801,7 @@ void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
BranchProbability PredToNewBBProb = BranchProbability::getBranchProbability(
TrueWeight, TrueWeight + FalseWeight);
auto NewBBFreq = BFI->getBlockFreq(Pred) * PredToNewBBProb;
- BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+ BFI->setBlockFreq(NewBB, NewBBFreq);
}
// The select is now dead.
@@ -2924,7 +2971,9 @@ bool JumpThreadingPass::tryToUnfoldSelectInCurrBB(BasicBlock *BB) {
Value *Cond = SI->getCondition();
if (!isGuaranteedNotToBeUndefOrPoison(Cond, nullptr, SI))
Cond = new FreezeInst(Cond, "cond.fr", SI);
- Instruction *Term = SplitBlockAndInsertIfThen(Cond, SI, false);
+ MDNode *BranchWeights = getBranchWeightMDNode(*SI);
+ Instruction *Term =
+ SplitBlockAndInsertIfThen(Cond, SI, false, BranchWeights);
BasicBlock *SplitBB = SI->getParent();
BasicBlock *NewBB = Term->getParent();
PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI);
@@ -3059,8 +3108,8 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard,
if (!isa<PHINode>(&*BI))
ToRemove.push_back(&*BI);
- Instruction *InsertionPoint = &*BB->getFirstInsertionPt();
- assert(InsertionPoint && "Empty block?");
+ BasicBlock::iterator InsertionPoint = BB->getFirstInsertionPt();
+ assert(InsertionPoint != BB->end() && "Empty block?");
// Substitute with Phis & remove.
for (auto *Inst : reverse(ToRemove)) {
if (!Inst->use_empty()) {
@@ -3070,6 +3119,7 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard,
NewPN->insertBefore(InsertionPoint);
Inst->replaceAllUsesWith(NewPN);
}
+ Inst->dropDbgValues();
Inst->eraseFromParent();
}
return true;
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index f8fab03f151d..d0afe09ce41d 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -108,6 +108,8 @@ STATISTIC(NumGEPsHoisted,
"Number of geps reassociated and hoisted out of the loop");
STATISTIC(NumAddSubHoisted, "Number of add/subtract expressions reassociated "
"and hoisted out of the loop");
+STATISTIC(NumFPAssociationsHoisted, "Number of invariant FP expressions "
+ "reassociated and hoisted out of the loop");
/// Memory promotion is enabled by default.
static cl::opt<bool>
@@ -127,6 +129,12 @@ static cl::opt<uint32_t> MaxNumUsesTraversed(
cl::desc("Max num uses visited for identifying load "
"invariance in loop using invariant start (default = 8)"));
+static cl::opt<unsigned> FPAssociationUpperLimit(
+ "licm-max-num-fp-reassociations", cl::init(5U), cl::Hidden,
+ cl::desc(
+ "Set upper limit for the number of transformations performed "
+ "during a single round of hoisting the reassociated expressions."));
+
// Experimental option to allow imprecision in LICM in pathological cases, in
// exchange for faster compile. This is to be removed if MemorySSA starts to
// address the same issue. LICM calls MemorySSAWalker's
@@ -473,12 +481,12 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,
});
if (!HasCatchSwitch) {
- SmallVector<Instruction *, 8> InsertPts;
+ SmallVector<BasicBlock::iterator, 8> InsertPts;
SmallVector<MemoryAccess *, 8> MSSAInsertPts;
InsertPts.reserve(ExitBlocks.size());
MSSAInsertPts.reserve(ExitBlocks.size());
for (BasicBlock *ExitBlock : ExitBlocks) {
- InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
+ InsertPts.push_back(ExitBlock->getFirstInsertionPt());
MSSAInsertPts.push_back(nullptr);
}
@@ -985,7 +993,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
// loop invariant). If so make them unconditional by moving them to their
// immediate dominator. We iterate through the instructions in reverse order
// which ensures that when we rehoist an instruction we rehoist its operands,
- // and also keep track of where in the block we are rehoisting to to make sure
+ // and also keep track of where in the block we are rehoisting to make sure
// that we rehoist instructions before the instructions that use them.
Instruction *HoistPoint = nullptr;
if (ControlFlowHoisting) {
@@ -1031,7 +1039,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
// invariant.start has no uses.
static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT,
Loop *CurLoop) {
- Value *Addr = LI->getOperand(0);
+ Value *Addr = LI->getPointerOperand();
const DataLayout &DL = LI->getModule()->getDataLayout();
const TypeSize LocSizeInBits = DL.getTypeSizeInBits(LI->getType());
@@ -1047,20 +1055,6 @@ static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT,
if (LocSizeInBits.isScalable())
return false;
- // if the type is i8 addrspace(x)*, we know this is the type of
- // llvm.invariant.start operand
- auto *PtrInt8Ty = PointerType::get(Type::getInt8Ty(LI->getContext()),
- LI->getPointerAddressSpace());
- unsigned BitcastsVisited = 0;
- // Look through bitcasts until we reach the i8* type (this is invariant.start
- // operand type).
- while (Addr->getType() != PtrInt8Ty) {
- auto *BC = dyn_cast<BitCastInst>(Addr);
- // Avoid traversing high number of bitcast uses.
- if (++BitcastsVisited > MaxNumUsesTraversed || !BC)
- return false;
- Addr = BC->getOperand(0);
- }
// If we've ended up at a global/constant, bail. We shouldn't be looking at
// uselists for non-local Values in a loop pass.
if (isa<Constant>(Addr))
@@ -1480,8 +1474,9 @@ static Instruction *cloneInstructionInExitBlock(
if (LI->wouldBeOutOfLoopUseRequiringLCSSA(Op.get(), PN.getParent())) {
auto *OInst = cast<Instruction>(Op.get());
PHINode *OpPN =
- PHINode::Create(OInst->getType(), PN.getNumIncomingValues(),
- OInst->getName() + ".lcssa", &ExitBlock.front());
+ PHINode::Create(OInst->getType(), PN.getNumIncomingValues(),
+ OInst->getName() + ".lcssa");
+ OpPN->insertBefore(ExitBlock.begin());
for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i)
OpPN->addIncoming(OInst, PN.getIncomingBlock(i));
Op = OpPN;
@@ -1799,7 +1794,7 @@ namespace {
class LoopPromoter : public LoadAndStorePromoter {
Value *SomePtr; // Designated pointer to store to.
SmallVectorImpl<BasicBlock *> &LoopExitBlocks;
- SmallVectorImpl<Instruction *> &LoopInsertPts;
+ SmallVectorImpl<BasicBlock::iterator> &LoopInsertPts;
SmallVectorImpl<MemoryAccess *> &MSSAInsertPts;
PredIteratorCache &PredCache;
MemorySSAUpdater &MSSAU;
@@ -1823,7 +1818,8 @@ class LoopPromoter : public LoadAndStorePromoter {
// We need to create an LCSSA PHI node for the incoming value and
// store that.
PHINode *PN = PHINode::Create(I->getType(), PredCache.size(BB),
- I->getName() + ".lcssa", &BB->front());
+ I->getName() + ".lcssa");
+ PN->insertBefore(BB->begin());
for (BasicBlock *Pred : PredCache.get(BB))
PN->addIncoming(I, Pred);
return PN;
@@ -1832,7 +1828,7 @@ class LoopPromoter : public LoadAndStorePromoter {
public:
LoopPromoter(Value *SP, ArrayRef<const Instruction *> Insts, SSAUpdater &S,
SmallVectorImpl<BasicBlock *> &LEB,
- SmallVectorImpl<Instruction *> &LIP,
+ SmallVectorImpl<BasicBlock::iterator> &LIP,
SmallVectorImpl<MemoryAccess *> &MSSAIP, PredIteratorCache &PIC,
MemorySSAUpdater &MSSAU, LoopInfo &li, DebugLoc dl,
Align Alignment, bool UnorderedAtomic, const AAMDNodes &AATags,
@@ -1855,7 +1851,7 @@ public:
Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock);
LiveInValue = maybeInsertLCSSAPHI(LiveInValue, ExitBlock);
Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock);
- Instruction *InsertPos = LoopInsertPts[i];
+ BasicBlock::iterator InsertPos = LoopInsertPts[i];
StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos);
if (UnorderedAtomic)
NewSI->setOrdering(AtomicOrdering::Unordered);
@@ -1934,23 +1930,6 @@ bool isNotVisibleOnUnwindInLoop(const Value *Object, const Loop *L,
isNotCapturedBeforeOrInLoop(Object, L, DT);
}
-// We don't consider globals as writable: While the physical memory is writable,
-// we may not have provenance to perform the write.
-bool isWritableObject(const Value *Object) {
- // TODO: Alloca might not be writable after its lifetime ends.
- // See https://github.com/llvm/llvm-project/issues/51838.
- if (isa<AllocaInst>(Object))
- return true;
-
- // TODO: Also handle sret.
- if (auto *A = dyn_cast<Argument>(Object))
- return A->hasByValAttr();
-
- // TODO: Noalias has nothing to do with writability, this should check for
- // an allocator function.
- return isNoAliasCall(Object);
-}
-
bool isThreadLocalObject(const Value *Object, const Loop *L, DominatorTree *DT,
TargetTransformInfo *TTI) {
// The object must be function-local to start with, and then not captured
@@ -1970,7 +1949,7 @@ bool isThreadLocalObject(const Value *Object, const Loop *L, DominatorTree *DT,
bool llvm::promoteLoopAccessesToScalars(
const SmallSetVector<Value *, 8> &PointerMustAliases,
SmallVectorImpl<BasicBlock *> &ExitBlocks,
- SmallVectorImpl<Instruction *> &InsertPts,
+ SmallVectorImpl<BasicBlock::iterator> &InsertPts,
SmallVectorImpl<MemoryAccess *> &MSSAInsertPts, PredIteratorCache &PIC,
LoopInfo *LI, DominatorTree *DT, AssumptionCache *AC,
const TargetLibraryInfo *TLI, TargetTransformInfo *TTI, Loop *CurLoop,
@@ -2192,7 +2171,10 @@ bool llvm::promoteLoopAccessesToScalars(
// violating the memory model.
if (StoreSafety == StoreSafetyUnknown) {
Value *Object = getUnderlyingObject(SomePtr);
- if (isWritableObject(Object) &&
+ bool ExplicitlyDereferenceableOnly;
+ if (isWritableObject(Object, ExplicitlyDereferenceableOnly) &&
+ (!ExplicitlyDereferenceableOnly ||
+ isDereferenceablePointer(SomePtr, AccessTy, MDL)) &&
isThreadLocalObject(Object, CurLoop, DT, TTI))
StoreSafety = StoreSafe;
}
@@ -2511,7 +2493,7 @@ static bool hoistGEP(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
// handle both offsets being non-negative.
const DataLayout &DL = GEP->getModule()->getDataLayout();
auto NonNegative = [&](Value *V) {
- return isKnownNonNegative(V, DL, 0, AC, GEP, DT);
+ return isKnownNonNegative(V, SimplifyQuery(DL, DT, AC, GEP));
};
bool IsInBounds = Src->isInBounds() && GEP->isInBounds() &&
all_of(Src->indices(), NonNegative) &&
@@ -2561,8 +2543,9 @@ static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS,
// we want to avoid this.
auto &DL = L.getHeader()->getModule()->getDataLayout();
bool ProvedNoOverflowAfterReassociate =
- computeOverflowForSignedSub(InvariantRHS, InvariantOp, DL, AC, &ICmp,
- DT) == llvm::OverflowResult::NeverOverflows;
+ computeOverflowForSignedSub(InvariantRHS, InvariantOp,
+ SimplifyQuery(DL, DT, AC, &ICmp)) ==
+ llvm::OverflowResult::NeverOverflows;
if (!ProvedNoOverflowAfterReassociate)
return false;
auto *Preheader = L.getLoopPreheader();
@@ -2612,15 +2595,16 @@ static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS,
// we want to avoid this. Likewise, for "C1 - LV < C2" we need to prove that
// "C1 - C2" does not overflow.
auto &DL = L.getHeader()->getModule()->getDataLayout();
+ SimplifyQuery SQ(DL, DT, AC, &ICmp);
if (VariantSubtracted) {
// C1 - LV < C2 --> LV > C1 - C2
- if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, DL, AC, &ICmp,
- DT) != llvm::OverflowResult::NeverOverflows)
+ if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, SQ) !=
+ llvm::OverflowResult::NeverOverflows)
return false;
} else {
// LV - C1 < C2 --> LV < C1 + C2
- if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, DL, AC, &ICmp,
- DT) != llvm::OverflowResult::NeverOverflows)
+ if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, SQ) !=
+ llvm::OverflowResult::NeverOverflows)
return false;
}
auto *Preheader = L.getLoopPreheader();
@@ -2674,6 +2658,72 @@ static bool hoistAddSub(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
return false;
}
+/// Try to reassociate expressions like ((A1 * B1) + (A2 * B2) + ...) * C where
+/// A1, A2, ... and C are loop invariants into expressions like
+/// ((A1 * C * B1) + (A2 * C * B2) + ...) and hoist the (A1 * C), (A2 * C), ...
+/// invariant expressions. This functions returns true only if any hoisting has
+/// actually occured.
+static bool hoistFPAssociation(Instruction &I, Loop &L,
+ ICFLoopSafetyInfo &SafetyInfo,
+ MemorySSAUpdater &MSSAU, AssumptionCache *AC,
+ DominatorTree *DT) {
+ using namespace PatternMatch;
+ Value *VariantOp = nullptr, *InvariantOp = nullptr;
+
+ if (!match(&I, m_FMul(m_Value(VariantOp), m_Value(InvariantOp))) ||
+ !I.hasAllowReassoc() || !I.hasNoSignedZeros())
+ return false;
+ if (L.isLoopInvariant(VariantOp))
+ std::swap(VariantOp, InvariantOp);
+ if (L.isLoopInvariant(VariantOp) || !L.isLoopInvariant(InvariantOp))
+ return false;
+ Value *Factor = InvariantOp;
+
+ // First, we need to make sure we should do the transformation.
+ SmallVector<Use *> Changes;
+ SmallVector<BinaryOperator *> Worklist;
+ if (BinaryOperator *VariantBinOp = dyn_cast<BinaryOperator>(VariantOp))
+ Worklist.push_back(VariantBinOp);
+ while (!Worklist.empty()) {
+ BinaryOperator *BO = Worklist.pop_back_val();
+ if (!BO->hasOneUse() || !BO->hasAllowReassoc() || !BO->hasNoSignedZeros())
+ return false;
+ BinaryOperator *Op0, *Op1;
+ if (match(BO, m_FAdd(m_BinOp(Op0), m_BinOp(Op1)))) {
+ Worklist.push_back(Op0);
+ Worklist.push_back(Op1);
+ continue;
+ }
+ if (BO->getOpcode() != Instruction::FMul || L.isLoopInvariant(BO))
+ return false;
+ Use &U0 = BO->getOperandUse(0);
+ Use &U1 = BO->getOperandUse(1);
+ if (L.isLoopInvariant(U0))
+ Changes.push_back(&U0);
+ else if (L.isLoopInvariant(U1))
+ Changes.push_back(&U1);
+ else
+ return false;
+ if (Changes.size() > FPAssociationUpperLimit)
+ return false;
+ }
+ if (Changes.empty())
+ return false;
+
+ // We know we should do it so let's do the transformation.
+ auto *Preheader = L.getLoopPreheader();
+ assert(Preheader && "Loop is not in simplify form?");
+ IRBuilder<> Builder(Preheader->getTerminator());
+ for (auto *U : Changes) {
+ assert(L.isLoopInvariant(U->get()));
+ Instruction *Ins = cast<Instruction>(U->getUser());
+ U->set(Builder.CreateFMulFMF(U->get(), Factor, Ins, "factor.op.fmul"));
+ }
+ I.replaceAllUsesWith(VariantOp);
+ eraseInstruction(I, SafetyInfo, MSSAU);
+ return true;
+}
+
static bool hoistArithmetics(Instruction &I, Loop &L,
ICFLoopSafetyInfo &SafetyInfo,
MemorySSAUpdater &MSSAU, AssumptionCache *AC,
@@ -2701,6 +2751,12 @@ static bool hoistArithmetics(Instruction &I, Loop &L,
return true;
}
+ if (hoistFPAssociation(I, L, SafetyInfo, MSSAU, AC, DT)) {
+ ++NumHoisted;
+ ++NumFPAssociationsHoisted;
+ return true;
+ }
+
return false;
}
diff --git a/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp b/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp
index 9ae55b9018da..3d3f22d686e3 100644
--- a/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp
@@ -20,7 +20,8 @@ PreservedAnalyses LoopAccessInfoPrinterPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &LAIs = AM.getResult<LoopAccessAnalysis>(F);
auto &LI = AM.getResult<LoopAnalysis>(F);
- OS << "Loop access info in function '" << F.getName() << "':\n";
+ OS << "Printing analysis 'Loop Access Analysis' for function '" << F.getName()
+ << "':\n";
SmallPriorityWorklist<Loop *, 4> Worklist;
appendLoopsToWorklist(LI, Worklist);
diff --git a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
index 2b9800f11912..9a27a08c86eb 100644
--- a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
@@ -430,7 +430,7 @@ static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
// Update phi node in exit block of post-loop.
- Builder.SetInsertPoint(&PostLoopPreHeader->front());
+ Builder.SetInsertPoint(PostLoopPreHeader, PostLoopPreHeader->begin());
for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
for (auto i : seq<int>(0, PN.getNumOperands())) {
// Check incoming block is pre-loop's exiting block.
diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
index 7c2770979a90..cc1f56014eee 100644
--- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
@@ -399,7 +399,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {
continue;
unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace();
- Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
+ Type *I8Ptr = PointerType::get(BB->getContext(), PtrAddrSpace);
Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);
IRBuilder<> Builder(P.InsertPt);
diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
index 27196e46ca56..626888c74bad 100644
--- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
@@ -104,9 +104,9 @@ static cl::opt<unsigned> DistributeSCEVCheckThreshold(
static cl::opt<unsigned> PragmaDistributeSCEVCheckThreshold(
"loop-distribute-scev-check-threshold-with-pragma", cl::init(128),
cl::Hidden,
- cl::desc(
- "The maximum number of SCEV checks allowed for Loop "
- "Distribution for loop marked with #pragma loop distribute(enable)"));
+ cl::desc("The maximum number of SCEV checks allowed for Loop "
+ "Distribution for loop marked with #pragma clang loop "
+ "distribute(enable)"));
static cl::opt<bool> EnableLoopDistribute(
"enable-loop-distribute", cl::Hidden,
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index edc8a4956dd1..b1add3c42976 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -641,8 +641,9 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
// Check if the multiply could not overflow due to known ranges of the
// input values.
OverflowResult OR = computeOverflowForUnsignedMul(
- FI.InnerTripCount, FI.OuterTripCount, DL, AC,
- FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
+ FI.InnerTripCount, FI.OuterTripCount,
+ SimplifyQuery(DL, DT, AC,
+ FI.OuterLoop->getLoopPreheader()->getTerminator()));
if (OR != OverflowResult::MayOverflow)
return OR;
diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
index d35b562be0aa..e0b224d5ef73 100644
--- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
@@ -1411,7 +1411,7 @@ private:
}
// Walk through all uses in FC1. For each use, find the reaching def. If the
- // def is located in FC0 then it is is not safe to fuse.
+ // def is located in FC0 then it is not safe to fuse.
for (BasicBlock *BB : FC1.L->blocks())
for (Instruction &I : *BB)
for (auto &Op : I.operands())
@@ -1473,12 +1473,13 @@ private:
for (Instruction *I : HoistInsts) {
assert(I->getParent() == FC1.Preheader);
- I->moveBefore(FC0.Preheader->getTerminator());
+ I->moveBefore(*FC0.Preheader,
+ FC0.Preheader->getTerminator()->getIterator());
}
// insert instructions in reverse order to maintain dominance relationship
for (Instruction *I : reverse(SinkInsts)) {
assert(I->getParent() == FC1.Preheader);
- I->moveBefore(&*FC1.ExitBlock->getFirstInsertionPt());
+ I->moveBefore(*FC1.ExitBlock, FC1.ExitBlock->getFirstInsertionPt());
}
}
@@ -1491,7 +1492,7 @@ private:
/// 2. The successors of the guard have the same flow into/around the loop.
/// If the compare instructions are identical, then the first successor of the
/// guard must go to the same place (either the preheader of the loop or the
- /// NonLoopBlock). In other words, the the first successor of both loops must
+ /// NonLoopBlock). In other words, the first successor of both loops must
/// both go into the loop (i.e., the preheader) or go around the loop (i.e.,
/// the NonLoopBlock). The same must be true for the second successor.
bool haveIdenticalGuards(const FusionCandidate &FC0,
@@ -1624,7 +1625,7 @@ private:
// first, or undef otherwise. This is sound as exiting the first implies the
// second will exit too, __without__ taking the back-edge. [Their
// trip-counts are equal after all.
- // KB: Would this sequence be simpler to just just make FC0.ExitingBlock go
+ // KB: Would this sequence be simpler to just make FC0.ExitingBlock go
// to FC1.Header? I think this is basically what the three sequences are
// trying to accomplish; however, doing this directly in the CFG may mean
// the DT/PDT becomes invalid
@@ -1671,7 +1672,7 @@ private:
// exiting the first and jumping to the header of the second does not break
// the SSA property of the phis originally in the first loop. See also the
// comment above.
- Instruction *L1HeaderIP = &FC1.Header->front();
+ BasicBlock::iterator L1HeaderIP = FC1.Header->begin();
for (PHINode *LCPHI : OriginalFC0PHIs) {
int L1LatchBBIdx = LCPHI->getBasicBlockIndex(FC1.Latch);
assert(L1LatchBBIdx >= 0 &&
@@ -1679,8 +1680,9 @@ private:
Value *LCV = LCPHI->getIncomingValue(L1LatchBBIdx);
- PHINode *L1HeaderPHI = PHINode::Create(
- LCV->getType(), 2, LCPHI->getName() + ".afterFC0", L1HeaderIP);
+ PHINode *L1HeaderPHI =
+ PHINode::Create(LCV->getType(), 2, LCPHI->getName() + ".afterFC0");
+ L1HeaderPHI->insertBefore(L1HeaderIP);
L1HeaderPHI->addIncoming(LCV, FC0.Latch);
L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()),
FC0.ExitingBlock);
@@ -1953,7 +1955,7 @@ private:
// exiting the first and jumping to the header of the second does not break
// the SSA property of the phis originally in the first loop. See also the
// comment above.
- Instruction *L1HeaderIP = &FC1.Header->front();
+ BasicBlock::iterator L1HeaderIP = FC1.Header->begin();
for (PHINode *LCPHI : OriginalFC0PHIs) {
int L1LatchBBIdx = LCPHI->getBasicBlockIndex(FC1.Latch);
assert(L1LatchBBIdx >= 0 &&
@@ -1961,8 +1963,9 @@ private:
Value *LCV = LCPHI->getIncomingValue(L1LatchBBIdx);
- PHINode *L1HeaderPHI = PHINode::Create(
- LCV->getType(), 2, LCPHI->getName() + ".afterFC0", L1HeaderIP);
+ PHINode *L1HeaderPHI =
+ PHINode::Create(LCV->getType(), 2, LCPHI->getName() + ".afterFC0");
+ L1HeaderPHI->insertBefore(L1HeaderIP);
L1HeaderPHI->addIncoming(LCV, FC0.Latch);
L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()),
FC0.ExitingBlock);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 8572a442e784..3721564890dd 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -24,12 +24,6 @@
// memcmp, strlen, etc.
// Future floating point idioms to recognize in -ffast-math mode:
// fpowi
-// Future integer operation idioms to recognize:
-// ctpop
-//
-// Beware that isel's default lowering for ctpop is highly inefficient for
-// i64 and larger types when i64 is legal and the value has few bits set. It
-// would be good to enhance isel to emit a loop for ctpop in this case.
//
// This could recognize common matrix multiplies and dot product idioms and
// replace them with calls to BLAS (if linked in??).
@@ -948,9 +942,13 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
- if (BECst && ConstSize)
- AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
- ConstSize->getValue()->getZExtValue());
+ if (BECst && ConstSize) {
+ std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
+ std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
+ // FIXME: Should this check for overflow?
+ if (BEInt && SizeInt)
+ AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
+ }
// TODO: For this to be really effective, we have to dive into the pointer
// operand in the store. Store to &A[i] of 100 will always return may alias
@@ -1023,7 +1021,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
SCEVExpander Expander(*SE, *DL, "loop-idiom");
SCEVExpanderCleaner ExpCleaner(Expander);
- Type *DestInt8PtrTy = Builder.getInt8PtrTy(DestAS);
+ Type *DestInt8PtrTy = Builder.getPtrTy(DestAS);
Type *IntIdxTy = DL->getIndexType(DestPtr->getType());
bool Changed = false;
@@ -1107,7 +1105,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
PatternValue, ".memset_pattern");
GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these.
GV->setAlignment(Align(16));
- Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy);
+ Value *PatternPtr = GV;
NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
// Set the TBAA info if present.
@@ -1284,7 +1282,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
// feeds the stores. Check for an alias by generating the base address and
// checking everything.
Value *StoreBasePtr = Expander.expandCodeFor(
- StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator());
+ StrStart, Builder.getPtrTy(StrAS), Preheader->getTerminator());
// From here on out, conservatively report to the pass manager that we've
// changed the IR, even if we later clean up these added instructions. There
@@ -1336,8 +1334,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
// For a memcpy, we have to make sure that the input array is not being
// mutated by the loop.
- Value *LoadBasePtr = Expander.expandCodeFor(
- LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator());
+ Value *LoadBasePtr = Expander.expandCodeFor(LdStart, Builder.getPtrTy(LdAS),
+ Preheader->getTerminator());
// If the store is a memcpy instruction, we must check if it will write to
// the load memory locations. So remove it from the ignored stores.
@@ -2026,7 +2024,8 @@ void LoopIdiomRecognize::transformLoopToCountable(
auto *LbBr = cast<BranchInst>(Body->getTerminator());
ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition());
- PHINode *TcPhi = PHINode::Create(CountTy, 2, "tcphi", &Body->front());
+ PHINode *TcPhi = PHINode::Create(CountTy, 2, "tcphi");
+ TcPhi->insertBefore(Body->begin());
Builder.SetInsertPoint(LbCond);
Instruction *TcDec = cast<Instruction>(Builder.CreateSub(
@@ -2132,7 +2131,8 @@ void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB,
ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition());
Type *Ty = TripCnt->getType();
- PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi", &Body->front());
+ PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi");
+ TcPhi->insertBefore(Body->begin());
Builder.SetInsertPoint(LbCond);
Instruction *TcDec = cast<Instruction>(
@@ -2411,7 +2411,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() {
// it's use count.
Instruction *InsertPt = nullptr;
if (auto *BitPosI = dyn_cast<Instruction>(BitPos))
- InsertPt = BitPosI->getInsertionPointAfterDef();
+ InsertPt = &**BitPosI->getInsertionPointAfterDef();
else
InsertPt = &*DT->getRoot()->getFirstNonPHIOrDbgOrAlloca();
if (!InsertPt)
@@ -2493,7 +2493,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() {
// Step 4: Rewrite the loop into a countable form, with canonical IV.
// The new canonical induction variable.
- Builder.SetInsertPoint(&LoopHeaderBB->front());
+ Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->begin());
auto *IV = Builder.CreatePHI(Ty, 2, CurLoop->getName() + ".iv");
// The induction itself.
@@ -2817,11 +2817,11 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
// Step 3: Rewrite the loop into a countable form, with canonical IV.
// The new canonical induction variable.
- Builder.SetInsertPoint(&LoopHeaderBB->front());
+ Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->begin());
auto *CIV = Builder.CreatePHI(Ty, 2, CurLoop->getName() + ".iv");
// The induction itself.
- Builder.SetInsertPoint(LoopHeaderBB->getFirstNonPHI());
+ Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->getFirstNonPHIIt());
auto *CIVNext =
Builder.CreateAdd(CIV, ConstantInt::get(Ty, 1), CIV->getName() + ".next",
/*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
diff --git a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
index c9798a80978d..cfe069d00bce 100644
--- a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
@@ -29,8 +29,6 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -172,46 +170,6 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI,
return Changed;
}
-namespace {
-
-class LoopInstSimplifyLegacyPass : public LoopPass {
-public:
- static char ID; // Pass ID, replacement for typeid
-
- LoopInstSimplifyLegacyPass() : LoopPass(ID) {
- initializeLoopInstSimplifyLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
- DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- AssumptionCache &AC =
- getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
- *L->getHeader()->getParent());
- const TargetLibraryInfo &TLI =
- getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
- *L->getHeader()->getParent());
- MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
- MemorySSAUpdater MSSAU(MSSA);
-
- return simplifyLoopInst(*L, DT, LI, AC, TLI, &MSSAU);
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.setPreservesCFG();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
-};
-
-} // end anonymous namespace
-
PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &) {
@@ -231,18 +189,3 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-char LoopInstSimplifyLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(LoopInstSimplifyLegacyPass, "loop-instsimplify",
- "Simplify instructions in loops", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(LoopInstSimplifyLegacyPass, "loop-instsimplify",
- "Simplify instructions in loops", false, false)
-
-Pass *llvm::createLoopInstSimplifyPass() {
- return new LoopInstSimplifyLegacyPass();
-}
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
index 91286ebcea33..277f530ee25f 100644
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -1374,7 +1374,7 @@ bool LoopInterchangeTransform::transform() {
for (Instruction &I :
make_early_inc_range(make_range(InnerLoopPreHeader->begin(),
std::prev(InnerLoopPreHeader->end()))))
- I.moveBefore(OuterLoopHeader->getTerminator());
+ I.moveBeforePreserving(OuterLoopHeader->getTerminator());
}
Transformed |= adjustLoopLinks();
diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
index 179ccde8d035..5ec387300aac 100644
--- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -195,7 +195,8 @@ public:
Instruction *Source = Dep.getSource(LAI);
Instruction *Destination = Dep.getDestination(LAI);
- if (Dep.Type == MemoryDepChecker::Dependence::Unknown) {
+ if (Dep.Type == MemoryDepChecker::Dependence::Unknown ||
+ Dep.Type == MemoryDepChecker::Dependence::IndirectUnsafe) {
if (isa<LoadInst>(Source))
LoadsWithUnknownDepedence.insert(Source);
if (isa<LoadInst>(Destination))
@@ -443,8 +444,8 @@ public:
Cand.Load->getType(), InitialPtr, "load_initial",
/* isVolatile */ false, Cand.Load->getAlign(), PH->getTerminator());
- PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded",
- &L->getHeader()->front());
+ PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded");
+ PHI->insertBefore(L->getHeader()->begin());
PHI->addIncoming(Initial, PH);
Type *LoadType = Initial->getType();
diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
index 2c8a3351281b..a4f2dbf9a582 100644
--- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
@@ -269,11 +269,12 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F,
PI.pushBeforeNonSkippedPassCallback([&LAR, &LI](StringRef PassID, Any IR) {
if (isSpecialPass(PassID, {"PassManager"}))
return;
- assert(any_cast<const Loop *>(&IR) || any_cast<const LoopNest *>(&IR));
- const Loop **LPtr = any_cast<const Loop *>(&IR);
+ assert(llvm::any_cast<const Loop *>(&IR) ||
+ llvm::any_cast<const LoopNest *>(&IR));
+ const Loop **LPtr = llvm::any_cast<const Loop *>(&IR);
const Loop *L = LPtr ? *LPtr : nullptr;
if (!L)
- L = &any_cast<const LoopNest *>(IR)->getOutermostLoop();
+ L = &llvm::any_cast<const LoopNest *>(IR)->getOutermostLoop();
assert(L && "Loop should be valid for printing");
// Verify the loop structure and LCSSA form before visiting the loop.
@@ -312,7 +313,8 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F,
if (LAR.MSSA && !PassPA.getChecker<MemorySSAAnalysis>().preserved())
report_fatal_error("Loop pass manager using MemorySSA contains a pass "
- "that does not preserve MemorySSA");
+ "that does not preserve MemorySSA",
+ /*gen_crash_diag*/ false);
#ifndef NDEBUG
// LoopAnalysisResults should always be valid.
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index 12852ae5c460..027dbb9c0f71 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -282,7 +282,7 @@ class LoopPredication {
Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
/// Same as above, *except* that this uses the SCEV definition of invariant
/// which is that an expression *can be made* invariant via SCEVExpander.
- /// Thus, this version is only suitable for finding an insert point to be be
+ /// Thus, this version is only suitable for finding an insert point to be
/// passed to SCEVExpander!
Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
ArrayRef<const SCEV *> Ops);
@@ -307,8 +307,9 @@ class LoopPredication {
widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
SCEVExpander &Expander,
Instruction *Guard);
- unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition,
- SCEVExpander &Expander, Instruction *Guard);
+ void widenChecks(SmallVectorImpl<Value *> &Checks,
+ SmallVectorImpl<Value *> &WidenedChecks,
+ SCEVExpander &Expander, Instruction *Guard);
bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
// If the loop always exits through another block in the loop, we should not
@@ -326,49 +327,8 @@ public:
bool runOnLoop(Loop *L);
};
-class LoopPredicationLegacyPass : public LoopPass {
-public:
- static char ID;
- LoopPredicationLegacyPass() : LoopPass(ID) {
- initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<BranchProbabilityInfoWrapperPass>();
- getLoopAnalysisUsage(AU);
- AU.addPreserved<MemorySSAWrapperPass>();
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
- auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
- std::unique_ptr<MemorySSAUpdater> MSSAU;
- if (MSSAWP)
- MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
- auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr);
- return LP.runOnLoop(L);
- }
-};
-
-char LoopPredicationLegacyPass::ID = 0;
} // end namespace
-INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
- "Loop predication", false, false)
-INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
- "Loop predication", false, false)
-
-Pass *llvm::createLoopPredicationPass() {
- return new LoopPredicationLegacyPass();
-}
-
PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
@@ -754,58 +714,15 @@ LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
}
}
-unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,
- Value *Condition,
- SCEVExpander &Expander,
- Instruction *Guard) {
- unsigned NumWidened = 0;
- // The guard condition is expected to be in form of:
- // cond1 && cond2 && cond3 ...
- // Iterate over subconditions looking for icmp conditions which can be
- // widened across loop iterations. Widening these conditions remember the
- // resulting list of subconditions in Checks vector.
- SmallVector<Value *, 4> Worklist(1, Condition);
- SmallPtrSet<Value *, 4> Visited;
- Visited.insert(Condition);
- Value *WideableCond = nullptr;
- do {
- Value *Condition = Worklist.pop_back_val();
- Value *LHS, *RHS;
- using namespace llvm::PatternMatch;
- if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
- if (Visited.insert(LHS).second)
- Worklist.push_back(LHS);
- if (Visited.insert(RHS).second)
- Worklist.push_back(RHS);
- continue;
- }
-
- if (match(Condition,
- m_Intrinsic<Intrinsic::experimental_widenable_condition>())) {
- // Pick any, we don't care which
- WideableCond = Condition;
- continue;
- }
-
- if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
- if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander,
- Guard)) {
- Checks.push_back(*NewRangeCheck);
- NumWidened++;
- continue;
+void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks,
+ SmallVectorImpl<Value *> &WidenedChecks,
+ SCEVExpander &Expander, Instruction *Guard) {
+ for (auto &Check : Checks)
+ if (ICmpInst *ICI = dyn_cast<ICmpInst>(Check))
+ if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) {
+ WidenedChecks.push_back(Check);
+ Check = *NewRangeCheck;
}
- }
-
- // Save the condition as is if we can't widen it
- Checks.push_back(Condition);
- } while (!Worklist.empty());
- // At the moment, our matching logic for wideable conditions implicitly
- // assumes we preserve the form: (br (and Cond, WC())). FIXME
- // Note that if there were multiple calls to wideable condition in the
- // traversal, we only need to keep one, and which one is arbitrary.
- if (WideableCond)
- Checks.push_back(WideableCond);
- return NumWidened;
}
bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
@@ -815,12 +732,13 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
TotalConsidered++;
SmallVector<Value *, 4> Checks;
- unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander,
- Guard);
- if (NumWidened == 0)
+ SmallVector<Value *> WidenedChecks;
+ parseWidenableGuard(Guard, Checks);
+ widenChecks(Checks, WidenedChecks, Expander, Guard);
+ if (WidenedChecks.empty())
return false;
- TotalWidened += NumWidened;
+ TotalWidened += WidenedChecks.size();
// Emit the new guard condition
IRBuilder<> Builder(findInsertPt(Guard, Checks));
@@ -833,7 +751,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
}
RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
- LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
+ LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
return true;
}
@@ -843,20 +761,19 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
LLVM_DEBUG(dbgs() << "Processing guard:\n");
LLVM_DEBUG(BI->dump());
- Value *Cond, *WC;
- BasicBlock *IfTrueBB, *IfFalseBB;
- bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB);
- assert(Parsed && "Must be able to parse widenable branch");
- (void)Parsed;
-
TotalConsidered++;
SmallVector<Value *, 4> Checks;
- unsigned NumWidened = collectChecks(Checks, BI->getCondition(),
- Expander, BI);
- if (NumWidened == 0)
+ SmallVector<Value *> WidenedChecks;
+ parseWidenableGuard(BI, Checks);
+ // At the moment, our matching logic for wideable conditions implicitly
+ // assumes we preserve the form: (br (and Cond, WC())). FIXME
+ auto WC = extractWidenableCondition(BI);
+ Checks.push_back(WC);
+ widenChecks(Checks, WidenedChecks, Expander, BI);
+ if (WidenedChecks.empty())
return false;
- TotalWidened += NumWidened;
+ TotalWidened += WidenedChecks.size();
// Emit the new guard condition
IRBuilder<> Builder(findInsertPt(BI, Checks));
@@ -864,17 +781,18 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
auto *OldCond = BI->getCondition();
BI->setCondition(AllChecks);
if (InsertAssumesOfPredicatedGuardsConditions) {
+ BasicBlock *IfTrueBB = BI->getSuccessor(0);
Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
// If this block has other predecessors, we might not be able to use Cond.
// In this case, create a Phi where every other input is `true` and input
// from guard block is Cond.
- Value *AssumeCond = Cond;
+ Value *AssumeCond = Builder.CreateAnd(WidenedChecks);
if (!IfTrueBB->getUniquePredecessor()) {
auto *GuardBB = BI->getParent();
- auto *PN = Builder.CreatePHI(Cond->getType(), pred_size(IfTrueBB),
+ auto *PN = Builder.CreatePHI(AssumeCond->getType(), pred_size(IfTrueBB),
"assume.cond");
for (auto *Pred : predecessors(IfTrueBB))
- PN->addIncoming(Pred == GuardBB ? Cond : Builder.getTrue(), Pred);
+ PN->addIncoming(Pred == GuardBB ? AssumeCond : Builder.getTrue(), Pred);
AssumeCond = PN;
}
Builder.CreateAssumption(AssumeCond);
@@ -883,7 +801,7 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
assert(isGuardAsWidenableBranch(BI) &&
"Stopped being a guard after transform?");
- LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
+ LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
return true;
}
@@ -1008,6 +926,9 @@ bool LoopPredication::isLoopProfitableToPredicate() {
Numerator += Weight;
Denominator += Weight;
}
+ // If all weights are zero act as if there was no profile data
+ if (Denominator == 0)
+ return BranchProbability::getBranchProbability(1, NumSucc);
return BranchProbability::getBranchProbability(Numerator, Denominator);
} else {
assert(LatchBlock != ExitingBlock &&
@@ -1070,13 +991,9 @@ static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
} while (true);
if (BasicBlock *Pred = BB->getSinglePredecessor()) {
- auto *Term = Pred->getTerminator();
-
- Value *Cond, *WC;
- BasicBlock *IfTrueBB, *IfFalseBB;
- if (parseWidenableBranch(Term, Cond, WC, IfTrueBB, IfFalseBB) &&
- IfTrueBB == BB)
- return cast<BranchInst>(Term);
+ if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
+ if (BI->getSuccessor(0) == BB && isWidenableBranch(BI))
+ return BI;
}
return nullptr;
}
@@ -1164,13 +1081,13 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
if (!BI)
continue;
- Use *Cond, *WC;
- BasicBlock *IfTrueBB, *IfFalseBB;
- if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) &&
- L->contains(IfTrueBB)) {
- WC->set(ConstantInt::getTrue(IfTrueBB->getContext()));
- ChangedLoop = true;
- }
+ if (auto WC = extractWidenableCondition(BI))
+ if (L->contains(BI->getSuccessor(0))) {
+ assert(WC->hasOneUse() && "Not appropriate widenable branch!");
+ WC->user_back()->replaceUsesOfWith(
+ WC, ConstantInt::getTrue(BI->getContext()));
+ ChangedLoop = true;
+ }
}
if (ChangedLoop)
SE->forgetLoop(L);
diff --git a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
index 8d59fdff9236..028a487ecdbc 100644
--- a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp
@@ -20,13 +20,11 @@
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopIterator.h"
-#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
@@ -734,52 +732,3 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM,
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-namespace {
-class LoopSimplifyCFGLegacyPass : public LoopPass {
-public:
- static char ID; // Pass ID, replacement for typeid
- LoopSimplifyCFGLegacyPass() : LoopPass(ID) {
- initializeLoopSimplifyCFGLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
-
- DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>();
- std::optional<MemorySSAUpdater> MSSAU;
- if (MSSAA)
- MSSAU = MemorySSAUpdater(&MSSAA->getMSSA());
- if (MSSAA && VerifyMemorySSA)
- MSSAU->getMemorySSA()->verifyMemorySSA();
- bool DeleteCurrentLoop = false;
- bool Changed = simplifyLoopCFG(*L, DT, LI, SE, MSSAU ? &*MSSAU : nullptr,
- DeleteCurrentLoop);
- if (DeleteCurrentLoop)
- LPM.markLoopAsDeleted(*L);
- return Changed;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addPreserved<MemorySSAWrapperPass>();
- AU.addPreserved<DependenceAnalysisWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
-};
-} // end namespace
-
-char LoopSimplifyCFGLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg",
- "Simplify loop CFG", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_END(LoopSimplifyCFGLegacyPass, "loop-simplifycfg",
- "Simplify loop CFG", false, false)
-
-Pass *llvm::createLoopSimplifyCFGPass() {
- return new LoopSimplifyCFGLegacyPass();
-}
diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp
index 597c159682c5..6eedf95e7575 100644
--- a/llvm/lib/Transforms/Scalar/LoopSink.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp
@@ -36,13 +36,11 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Scalar.h"
@@ -79,7 +77,7 @@ static cl::opt<unsigned> MaxNumberOfUseBBsForSinking(
/// AdjustedFreq(BBs) = 99 / SinkFrequencyPercentThreshold%
static BlockFrequency adjustedSumFreq(SmallPtrSetImpl<BasicBlock *> &BBs,
BlockFrequencyInfo &BFI) {
- BlockFrequency T = 0;
+ BlockFrequency T(0);
for (BasicBlock *B : BBs)
T += BFI.getBlockFreq(B);
if (BBs.size() > 1)
@@ -222,9 +220,11 @@ static bool sinkInstruction(
// order. No need to stable sort as the block numbers are a total ordering.
SmallVector<BasicBlock *, 2> SortedBBsToSinkInto;
llvm::append_range(SortedBBsToSinkInto, BBsToSinkInto);
- llvm::sort(SortedBBsToSinkInto, [&](BasicBlock *A, BasicBlock *B) {
- return LoopBlockNumber.find(A)->second < LoopBlockNumber.find(B)->second;
- });
+ if (SortedBBsToSinkInto.size() > 1) {
+ llvm::sort(SortedBBsToSinkInto, [&](BasicBlock *A, BasicBlock *B) {
+ return LoopBlockNumber.find(A)->second < LoopBlockNumber.find(B)->second;
+ });
+ }
BasicBlock *MoveBB = *SortedBBsToSinkInto.begin();
// FIXME: Optimize the efficiency for cloned value replacement. The current
@@ -388,58 +388,3 @@ PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) {
return PA;
}
-
-namespace {
-struct LegacyLoopSinkPass : public LoopPass {
- static char ID;
- LegacyLoopSinkPass() : LoopPass(ID) {
- initializeLegacyLoopSinkPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override {
- if (skipLoop(L))
- return false;
-
- BasicBlock *Preheader = L->getLoopPreheader();
- if (!Preheader)
- return false;
-
- // Enable LoopSink only when runtime profile is available.
- // With static profile, the sinking decision may be sub-optimal.
- if (!Preheader->getParent()->hasProfileData())
- return false;
-
- AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
- auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
- bool Changed = sinkLoopInvariantInstructions(
- *L, AA, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(),
- getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
- getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(),
- MSSA, SE ? &SE->getSE() : nullptr);
-
- if (VerifyMemorySSA)
- MSSA.verifyMemorySSA();
-
- return Changed;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<BlockFrequencyInfoWrapperPass>();
- getLoopAnalysisUsage(AU);
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- }
-};
-}
-
-char LegacyLoopSinkPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false,
- false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_END(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false, false)
-
-Pass *llvm::createLoopSinkPass() { return new LegacyLoopSinkPass(); }
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index a4369b83e732..39607464dd00 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -67,6 +67,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/IVUsers.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
@@ -188,8 +189,8 @@ static cl::opt<unsigned> SetupCostDepthLimit(
"lsr-setupcost-depth-limit", cl::Hidden, cl::init(7),
cl::desc("The limit on recursion depth for LSRs setup cost"));
-static cl::opt<bool> AllowTerminatingConditionFoldingAfterLSR(
- "lsr-term-fold", cl::Hidden, cl::init(false),
+static cl::opt<cl::boolOrDefault> AllowTerminatingConditionFoldingAfterLSR(
+ "lsr-term-fold", cl::Hidden,
cl::desc("Attempt to replace primary IV with other IV."));
static cl::opt<bool> AllowDropSolutionIfLessProfitable(
@@ -943,12 +944,6 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
}
}
- // All pointers have the same requirements, so canonicalize them to an
- // arbitrary pointer type to minimize variation.
- if (PointerType *PTy = dyn_cast<PointerType>(AccessTy.MemTy))
- AccessTy.MemTy = PointerType::get(IntegerType::get(PTy->getContext(), 1),
- PTy->getAddressSpace());
-
return AccessTy;
}
@@ -2794,18 +2789,6 @@ static Value *getWideOperand(Value *Oper) {
return Oper;
}
-/// Return true if we allow an IV chain to include both types.
-static bool isCompatibleIVType(Value *LVal, Value *RVal) {
- Type *LType = LVal->getType();
- Type *RType = RVal->getType();
- return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy() &&
- // Different address spaces means (possibly)
- // different types of the pointer implementation,
- // e.g. i16 vs i32 so disallow that.
- (LType->getPointerAddressSpace() ==
- RType->getPointerAddressSpace()));
-}
-
/// Return an approximation of this SCEV expression's "base", or NULL for any
/// constant. Returning the expression itself is conservative. Returning a
/// deeper subexpression is more precise and valid as long as it isn't less
@@ -2985,7 +2968,7 @@ void LSRInstance::ChainInstruction(Instruction *UserInst, Instruction *IVOper,
continue;
Value *PrevIV = getWideOperand(Chain.Incs.back().IVOperand);
- if (!isCompatibleIVType(PrevIV, NextIV))
+ if (PrevIV->getType() != NextIV->getType())
continue;
// A phi node terminates a chain.
@@ -3279,7 +3262,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// do this if we also found a wide value for the head of the chain.
if (isa<PHINode>(Chain.tailUserInst())) {
for (PHINode &Phi : L->getHeader()->phis()) {
- if (!isCompatibleIVType(&Phi, IVSrc))
+ if (Phi.getType() != IVSrc->getType())
continue;
Instruction *PostIncV = dyn_cast<Instruction>(
Phi.getIncomingValueForBlock(L->getLoopLatch()));
@@ -3488,6 +3471,11 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() {
SmallVector<const SCEV *, 8> Worklist(RegUses.begin(), RegUses.end());
SmallPtrSet<const SCEV *, 32> Visited;
+ // Don't collect outside uses if we are favoring postinc - the instructions in
+ // the loop are more important than the ones outside of it.
+ if (AMK == TTI::AMK_PostIndexed)
+ return;
+
while (!Worklist.empty()) {
const SCEV *S = Worklist.pop_back_val();
@@ -5559,10 +5547,12 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF,
"a scale at the same time!");
Constant *C = ConstantInt::getSigned(SE.getEffectiveSCEVType(OpTy),
-(uint64_t)Offset);
- if (C->getType() != OpTy)
- C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
- OpTy, false),
- C, OpTy);
+ if (C->getType() != OpTy) {
+ C = ConstantFoldCastOperand(
+ CastInst::getCastOpcode(C, false, OpTy, false), C, OpTy,
+ CI->getModule()->getDataLayout());
+ assert(C && "Cast of ConstantInt should have folded");
+ }
CI->setOperand(1, C);
}
@@ -5610,7 +5600,8 @@ void LSRInstance::RewriteForPHI(
.setKeepOneInputPHIs());
} else {
SmallVector<BasicBlock*, 2> NewBBs;
- SplitLandingPadPredecessors(Parent, BB, "", "", NewBBs, &DT, &LI);
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+ SplitLandingPadPredecessors(Parent, BB, "", "", NewBBs, &DTU, &LI);
NewBB = NewBBs[0];
}
// If NewBB==NULL, then SplitCriticalEdge refused to split because all
@@ -6949,7 +6940,19 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
}
}
- if (AllowTerminatingConditionFoldingAfterLSR) {
+ const bool EnableFormTerm = [&] {
+ switch (AllowTerminatingConditionFoldingAfterLSR) {
+ case cl::BOU_TRUE:
+ return true;
+ case cl::BOU_FALSE:
+ return false;
+ case cl::BOU_UNSET:
+ return TTI.shouldFoldTerminatingConditionAfterLSR();
+ }
+ llvm_unreachable("Unhandled cl::boolOrDefault enum");
+ }();
+
+ if (EnableFormTerm) {
if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) {
auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
index 9c6e4ebf62a9..7b4c54370e48 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
@@ -111,7 +111,7 @@ static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
if (!S)
continue;
- if (S->getString().startswith(Prefix))
+ if (S->getString().starts_with(Prefix))
return true;
}
}
@@ -153,9 +153,11 @@ static bool computeUnrollAndJamCount(
LoopInfo *LI, AssumptionCache *AC, ScalarEvolution &SE,
const SmallPtrSetImpl<const Value *> &EphValues,
OptimizationRemarkEmitter *ORE, unsigned OuterTripCount,
- unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount,
- unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP,
+ unsigned OuterTripMultiple, const UnrollCostEstimator &OuterUCE,
+ unsigned InnerTripCount, unsigned InnerLoopSize,
+ TargetTransformInfo::UnrollingPreferences &UP,
TargetTransformInfo::PeelingPreferences &PP) {
+ unsigned OuterLoopSize = OuterUCE.getRolledLoopSize();
// First up use computeUnrollCount from the loop unroller to get a count
// for unrolling the outer loop, plus any loops requiring explicit
// unrolling we leave to the unroller. This uses UP.Threshold /
@@ -165,7 +167,7 @@ static bool computeUnrollAndJamCount(
bool UseUpperBound = false;
bool ExplicitUnroll = computeUnrollCount(
L, TTI, DT, LI, AC, SE, EphValues, ORE, OuterTripCount, MaxTripCount,
- /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, PP,
+ /*MaxOrZero*/ false, OuterTripMultiple, OuterUCE, UP, PP,
UseUpperBound);
if (ExplicitUnroll || UseUpperBound) {
// If the user explicitly set the loop as unrolled, dont UnJ it. Leave it
@@ -318,39 +320,28 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
}
// Approximate the loop size and collect useful info
- unsigned NumInlineCandidates;
- bool NotDuplicatable;
- bool Convergent;
SmallPtrSet<const Value *, 32> EphValues;
CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
Loop *SubLoop = L->getSubLoops()[0];
- InstructionCost InnerLoopSizeIC =
- ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable,
- Convergent, TTI, EphValues, UP.BEInsns);
- InstructionCost OuterLoopSizeIC =
- ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent,
- TTI, EphValues, UP.BEInsns);
- LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterLoopSizeIC << "\n");
- LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSizeIC << "\n");
+ UnrollCostEstimator InnerUCE(SubLoop, TTI, EphValues, UP.BEInsns);
+ UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns);
- if (!InnerLoopSizeIC.isValid() || !OuterLoopSizeIC.isValid()) {
+ if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) {
LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions"
- << " with invalid cost.\n");
+ << " which cannot be duplicated or have invalid cost.\n");
return LoopUnrollResult::Unmodified;
}
- unsigned InnerLoopSize = *InnerLoopSizeIC.getValue();
- unsigned OuterLoopSize = *OuterLoopSizeIC.getValue();
- if (NotDuplicatable) {
- LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable "
- "instructions.\n");
- return LoopUnrollResult::Unmodified;
- }
- if (NumInlineCandidates != 0) {
+ unsigned InnerLoopSize = InnerUCE.getRolledLoopSize();
+ LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterUCE.getRolledLoopSize()
+ << "\n");
+ LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSize << "\n");
+
+ if (InnerUCE.NumInlineCandidates != 0 || OuterUCE.NumInlineCandidates != 0) {
LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n");
return LoopUnrollResult::Unmodified;
}
- if (Convergent) {
+ if (InnerUCE.Convergent || OuterUCE.Convergent) {
LLVM_DEBUG(
dbgs() << " Not unrolling loop with convergent instructions.\n");
return LoopUnrollResult::Unmodified;
@@ -379,7 +370,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
// Decide if, and by how much, to unroll
bool IsCountSetExplicitly = computeUnrollAndJamCount(
L, SubLoop, TTI, DT, LI, &AC, SE, EphValues, &ORE, OuterTripCount,
- OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP, PP);
+ OuterTripMultiple, OuterUCE, InnerTripCount, InnerLoopSize, UP, PP);
if (UP.Count <= 1)
return LoopUnrollResult::Unmodified;
// Unroll factor (Count) must be less or equal to TripCount.
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
index 335b489d3cb2..f14541a1a037 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -662,19 +662,16 @@ static std::optional<EstimatedUnrollCost> analyzeLoopUnrollCost(
unsigned(*RolledDynamicCost.getValue())}};
}
-/// ApproximateLoopSize - Approximate the size of the loop.
-InstructionCost llvm::ApproximateLoopSize(
- const Loop *L, unsigned &NumCalls, bool &NotDuplicatable, bool &Convergent,
- const TargetTransformInfo &TTI,
+UnrollCostEstimator::UnrollCostEstimator(
+ const Loop *L, const TargetTransformInfo &TTI,
const SmallPtrSetImpl<const Value *> &EphValues, unsigned BEInsns) {
CodeMetrics Metrics;
for (BasicBlock *BB : L->blocks())
Metrics.analyzeBasicBlock(BB, TTI, EphValues);
- NumCalls = Metrics.NumInlineCandidates;
+ NumInlineCandidates = Metrics.NumInlineCandidates;
NotDuplicatable = Metrics.notDuplicatable;
Convergent = Metrics.convergent;
-
- InstructionCost LoopSize = Metrics.NumInsts;
+ LoopSize = Metrics.NumInsts;
// Don't allow an estimate of size zero. This would allows unrolling of loops
// with huge iteration counts, which is a compile time problem even if it's
@@ -685,8 +682,17 @@ InstructionCost llvm::ApproximateLoopSize(
if (LoopSize.isValid() && LoopSize < BEInsns + 1)
// This is an open coded max() on InstructionCost
LoopSize = BEInsns + 1;
+}
- return LoopSize;
+uint64_t UnrollCostEstimator::getUnrolledLoopSize(
+ const TargetTransformInfo::UnrollingPreferences &UP,
+ unsigned CountOverwrite) const {
+ unsigned LS = *LoopSize.getValue();
+ assert(LS >= UP.BEInsns && "LoopSize should not be less than BEInsns!");
+ if (CountOverwrite)
+ return static_cast<uint64_t>(LS - UP.BEInsns) * CountOverwrite + UP.BEInsns;
+ else
+ return static_cast<uint64_t>(LS - UP.BEInsns) * UP.Count + UP.BEInsns;
}
// Returns the loop hint metadata node with the given name (for example,
@@ -746,36 +752,10 @@ static unsigned getFullUnrollBoostingFactor(const EstimatedUnrollCost &Cost,
return MaxPercentThresholdBoost;
}
-// Produce an estimate of the unrolled cost of the specified loop. This
-// is used to a) produce a cost estimate for partial unrolling and b) to
-// cheaply estimate cost for full unrolling when we don't want to symbolically
-// evaluate all iterations.
-class UnrollCostEstimator {
- const unsigned LoopSize;
-
-public:
- UnrollCostEstimator(Loop &L, unsigned LoopSize) : LoopSize(LoopSize) {}
-
- // Returns loop size estimation for unrolled loop, given the unrolling
- // configuration specified by UP.
- uint64_t
- getUnrolledLoopSize(const TargetTransformInfo::UnrollingPreferences &UP,
- const unsigned CountOverwrite = 0) const {
- assert(LoopSize >= UP.BEInsns &&
- "LoopSize should not be less than BEInsns!");
- if (CountOverwrite)
- return static_cast<uint64_t>(LoopSize - UP.BEInsns) * CountOverwrite +
- UP.BEInsns;
- else
- return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count +
- UP.BEInsns;
- }
-};
-
static std::optional<unsigned>
shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo,
const unsigned TripMultiple, const unsigned TripCount,
- const UnrollCostEstimator UCE,
+ unsigned MaxTripCount, const UnrollCostEstimator UCE,
const TargetTransformInfo::UnrollingPreferences &UP) {
// Using unroll pragma
@@ -796,6 +776,10 @@ shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo,
if (PInfo.PragmaFullUnroll && TripCount != 0)
return TripCount;
+ if (PInfo.PragmaEnableUnroll && !TripCount && MaxTripCount &&
+ MaxTripCount <= UnrollMaxUpperBound)
+ return MaxTripCount;
+
// if didn't return until here, should continue to other priorties
return std::nullopt;
}
@@ -888,14 +872,14 @@ shouldPartialUnroll(const unsigned LoopSize, const unsigned TripCount,
// refactored into it own function.
bool llvm::computeUnrollCount(
Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI,
- AssumptionCache *AC,
- ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues,
+ AssumptionCache *AC, ScalarEvolution &SE,
+ const SmallPtrSetImpl<const Value *> &EphValues,
OptimizationRemarkEmitter *ORE, unsigned TripCount, unsigned MaxTripCount,
- bool MaxOrZero, unsigned TripMultiple, unsigned LoopSize,
+ bool MaxOrZero, unsigned TripMultiple, const UnrollCostEstimator &UCE,
TargetTransformInfo::UnrollingPreferences &UP,
TargetTransformInfo::PeelingPreferences &PP, bool &UseUpperBound) {
- UnrollCostEstimator UCE(*L, LoopSize);
+ unsigned LoopSize = UCE.getRolledLoopSize();
const bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0;
const bool PragmaFullUnroll = hasUnrollFullPragma(L);
@@ -922,7 +906,7 @@ bool llvm::computeUnrollCount(
// 1st priority is unroll count set by "unroll-count" option.
// 2nd priority is unroll count set by pragma.
if (auto UnrollFactor = shouldPragmaUnroll(L, PInfo, TripMultiple, TripCount,
- UCE, UP)) {
+ MaxTripCount, UCE, UP)) {
UP.Count = *UnrollFactor;
if (UserUnrollCount || (PragmaCount > 0)) {
@@ -1177,9 +1161,6 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
return LoopUnrollResult::Unmodified;
bool OptForSize = L->getHeader()->getParent()->hasOptSize();
- unsigned NumInlineCandidates;
- bool NotDuplicatable;
- bool Convergent;
TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
L, SE, TTI, BFI, PSI, ORE, OptLevel, ProvidedThreshold, ProvidedCount,
ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound,
@@ -1196,30 +1177,22 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
SmallPtrSet<const Value *, 32> EphValues;
CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
- InstructionCost LoopSizeIC =
- ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent,
- TTI, EphValues, UP.BEInsns);
- LLVM_DEBUG(dbgs() << " Loop Size = " << LoopSizeIC << "\n");
-
- if (!LoopSizeIC.isValid()) {
+ UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
+ if (!UCE.canUnroll()) {
LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions"
- << " with invalid cost.\n");
+ << " which cannot be duplicated or have invalid cost.\n");
return LoopUnrollResult::Unmodified;
}
- unsigned LoopSize = *LoopSizeIC.getValue();
- if (NotDuplicatable) {
- LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable"
- << " instructions.\n");
- return LoopUnrollResult::Unmodified;
- }
+ unsigned LoopSize = UCE.getRolledLoopSize();
+ LLVM_DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n");
// When optimizing for size, use LoopSize + 1 as threshold (we use < Threshold
// later), to (fully) unroll loops, if it does not increase code size.
if (OptForSize)
UP.Threshold = std::max(UP.Threshold, LoopSize + 1);
- if (NumInlineCandidates != 0) {
+ if (UCE.NumInlineCandidates != 0) {
LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n");
return LoopUnrollResult::Unmodified;
}
@@ -1261,7 +1234,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
// Assuming n is the same on all threads, any kind of unrolling is
// safe. But currently llvm's notion of convergence isn't powerful
// enough to express this.
- if (Convergent)
+ if (UCE.Convergent)
UP.AllowRemainder = false;
// Try to find the trip count upper bound if we cannot find the exact trip
@@ -1277,8 +1250,8 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
// fully unroll the loop.
bool UseUpperBound = false;
bool IsCountSetExplicitly = computeUnrollCount(
- L, TTI, DT, LI, &AC, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero,
- TripMultiple, LoopSize, UP, PP, UseUpperBound);
+ L, TTI, DT, LI, &AC, SE, EphValues, &ORE, TripCount, MaxTripCount,
+ MaxOrZero, TripMultiple, UCE, UP, PP, UseUpperBound);
if (!UP.Count)
return LoopUnrollResult::Unmodified;
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index 454aa56be531..6f87e4d91d2c 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -13,7 +13,6 @@
#include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
@@ -21,10 +20,8 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/MisExpect.h"
#include <cmath>
@@ -105,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
SI.setCondition(ArgValue);
-
- SI.setMetadata(LLVMContext::MD_prof,
- MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+ setBranchWeights(SI, Weights);
return true;
}
@@ -416,29 +410,3 @@ PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F,
return PreservedAnalyses::all();
}
-
-namespace {
-/// Legacy pass for lowering expect intrinsics out of the IR.
-///
-/// When this pass is run over a function it uses expect intrinsics which feed
-/// branches and switches to provide branch weight metadata for those
-/// terminators. It then removes the expect intrinsics from the IR so the rest
-/// of the optimizer can ignore them.
-class LowerExpectIntrinsic : public FunctionPass {
-public:
- static char ID;
- LowerExpectIntrinsic() : FunctionPass(ID) {
- initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); }
-};
-} // namespace
-
-char LowerExpectIntrinsic::ID = 0;
-INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect",
- "Lower 'expect' Intrinsics", false, false)
-
-FunctionPass *llvm::createLowerExpectIntrinsicPass() {
- return new LowerExpectIntrinsic();
-}
diff --git a/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp
index 8dc037b10cc8..a59ecdda1746 100644
--- a/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp
@@ -20,25 +20,10 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/GuardUtils.h"
using namespace llvm;
-namespace {
-struct LowerGuardIntrinsicLegacyPass : public FunctionPass {
- static char ID;
- LowerGuardIntrinsicLegacyPass() : FunctionPass(ID) {
- initializeLowerGuardIntrinsicLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-};
-}
-
static bool lowerGuardIntrinsic(Function &F) {
// Check if we can cheaply rule out the possibility of not having any work to
// do.
@@ -71,19 +56,6 @@ static bool lowerGuardIntrinsic(Function &F) {
return true;
}
-bool LowerGuardIntrinsicLegacyPass::runOnFunction(Function &F) {
- return lowerGuardIntrinsic(F);
-}
-
-char LowerGuardIntrinsicLegacyPass::ID = 0;
-INITIALIZE_PASS(LowerGuardIntrinsicLegacyPass, "lower-guard-intrinsic",
- "Lower the guard intrinsic to normal control flow", false,
- false)
-
-Pass *llvm::createLowerGuardIntrinsicPass() {
- return new LowerGuardIntrinsicLegacyPass();
-}
-
PreservedAnalyses LowerGuardIntrinsicPass::run(Function &F,
FunctionAnalysisManager &AM) {
if (lowerGuardIntrinsic(F))
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index f46ea6a20afa..72b9db1e73d7 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -19,6 +19,7 @@
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
@@ -36,12 +37,9 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"
@@ -180,7 +178,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
assert((!isa<ConstantInt>(Stride) ||
cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
"Stride must be >= the number of elements in the result vector.");
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
// Compute the start of the vector with index VecIdx as VecIdx * Stride.
Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
@@ -192,11 +189,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
else
VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
- // Cast elementwise vector start pointer to a pointer to a vector
- // (EltType x NumElements)*.
- auto *VecType = FixedVectorType::get(EltType, NumElements);
- Type *VecPtrType = PointerType::get(VecType, AS);
- return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
+ return VecStart;
}
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
@@ -1063,13 +1056,6 @@ public:
return Changed;
}
- /// Turns \p BasePtr into an elementwise pointer to \p EltType.
- Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
- Type *EltPtrType = PointerType::get(EltType, AS);
- return Builder.CreatePointerCast(BasePtr, EltPtrType);
- }
-
/// Replace intrinsic calls
bool VisitCallInst(CallInst *Inst) {
if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
@@ -1121,7 +1107,7 @@ public:
auto *VType = cast<VectorType>(Ty);
Type *EltTy = VType->getElementType();
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
- Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
+ Value *EltPtr = Ptr;
MatrixTy Result;
for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
Value *GEP = computeVectorAddr(
@@ -1147,17 +1133,11 @@ public:
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
ResultShape.NumColumns);
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
- return loadMatrix(TileTy, TilePtr, Align,
+ return loadMatrix(TileTy, TileStart, Align,
Builder.getInt64(MatrixShape.getStride()), IsVolatile,
ResultShape, Builder);
}
@@ -1193,17 +1173,11 @@ public:
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
StoreVal.getNumColumns());
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
- storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
+ storeMatrix(TileTy, StoreVal, TileStart, MAlign,
Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
}
@@ -1213,7 +1187,7 @@ public:
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
IRBuilder<> &Builder) {
auto VType = cast<VectorType>(Ty);
- Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
+ Value *EltPtr = Ptr;
for (auto Vec : enumerate(StoreVal.vectors())) {
Value *GEP = computeVectorAddr(
EltPtr,
@@ -2180,7 +2154,7 @@ public:
/// Returns true if \p V is a matrix value in the given subprogram.
bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
- /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to
+ /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
/// \p SS.
void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
auto M = Inst2Matrix.find(V);
@@ -2201,7 +2175,7 @@ public:
write("<no called fn>");
else {
StringRef Name = CI->getCalledFunction()->getName();
- if (!Name.startswith("llvm.matrix")) {
+ if (!Name.starts_with("llvm.matrix")) {
write(Name);
return;
}
diff --git a/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp b/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp
index e2de322933bc..3c977b816a05 100644
--- a/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp
@@ -19,24 +19,10 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
using namespace llvm;
-namespace {
-struct LowerWidenableConditionLegacyPass : public FunctionPass {
- static char ID;
- LowerWidenableConditionLegacyPass() : FunctionPass(ID) {
- initializeLowerWidenableConditionLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-};
-}
-
static bool lowerWidenableCondition(Function &F) {
// Check if we can cheaply rule out the possibility of not having any work to
// do.
@@ -65,19 +51,6 @@ static bool lowerWidenableCondition(Function &F) {
return true;
}
-bool LowerWidenableConditionLegacyPass::runOnFunction(Function &F) {
- return lowerWidenableCondition(F);
-}
-
-char LowerWidenableConditionLegacyPass::ID = 0;
-INITIALIZE_PASS(LowerWidenableConditionLegacyPass, "lower-widenable-condition",
- "Lower the widenable condition to default true value", false,
- false)
-
-Pass *llvm::createLowerWidenableConditionPass() {
- return new LowerWidenableConditionLegacyPass();
-}
-
PreservedAnalyses LowerWidenableConditionPass::run(Function &F,
FunctionAnalysisManager &AM) {
if (lowerWidenableCondition(F))
diff --git a/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp b/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp
index a3f09a5a33c3..78e474f925b5 100644
--- a/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp
+++ b/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp
@@ -42,17 +42,6 @@
using namespace llvm;
-namespace {
-struct MakeGuardsExplicitLegacyPass : public FunctionPass {
- static char ID;
- MakeGuardsExplicitLegacyPass() : FunctionPass(ID) {
- initializeMakeGuardsExplicitLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-};
-}
-
static void turnToExplicitForm(CallInst *Guard, Function *DeoptIntrinsic) {
// Replace the guard with an explicit branch (just like in GuardWidening).
BasicBlock *OriginalBB = Guard->getParent();
@@ -89,15 +78,6 @@ static bool explicifyGuards(Function &F) {
return true;
}
-bool MakeGuardsExplicitLegacyPass::runOnFunction(Function &F) {
- return explicifyGuards(F);
-}
-
-char MakeGuardsExplicitLegacyPass::ID = 0;
-INITIALIZE_PASS(MakeGuardsExplicitLegacyPass, "make-guards-explicit",
- "Lower the guard intrinsic to explicit control flow form",
- false, false)
-
PreservedAnalyses MakeGuardsExplicitPass::run(Function &F,
FunctionAnalysisManager &) {
if (explicifyGuards(F))
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 68642a01b37c..0e55249d63a8 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -19,12 +19,15 @@
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
@@ -66,9 +69,9 @@ static cl::opt<bool> EnableMemCpyOptWithoutLibcalls(
STATISTIC(NumMemCpyInstr, "Number of memcpy instructions deleted");
STATISTIC(NumMemSetInfer, "Number of memsets inferred");
-STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy");
-STATISTIC(NumCpyToSet, "Number of memcpys converted to memset");
-STATISTIC(NumCallSlot, "Number of call slot optimizations performed");
+STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy");
+STATISTIC(NumCpyToSet, "Number of memcpys converted to memset");
+STATISTIC(NumCallSlot, "Number of call slot optimizations performed");
STATISTIC(NumStackMove, "Number of stack-move optimizations performed");
namespace {
@@ -333,6 +336,17 @@ static bool writtenBetween(MemorySSA *MSSA, BatchAAResults &AA,
return !MSSA->dominates(Clobber, Start);
}
+// Update AA metadata
+static void combineAAMetadata(Instruction *ReplInst, Instruction *I) {
+ // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be
+ // handled here, but combineMetadata doesn't support them yet
+ unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias,
+ LLVMContext::MD_invariant_group,
+ LLVMContext::MD_access_group};
+ combineMetadata(ReplInst, I, KnownIDs, true);
+}
+
/// When scanning forward over instructions, we look for some other patterns to
/// fold away. In particular, this looks for stores to neighboring locations of
/// memory. If it sees enough consecutive ones, it attempts to merge them
@@ -357,21 +371,13 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
// Keeps track of the last memory use or def before the insertion point for
// the new memset. The new MemoryDef for the inserted memsets will be inserted
- // after MemInsertPoint. It points to either LastMemDef or to the last user
- // before the insertion point of the memset, if there are any such users.
+ // after MemInsertPoint.
MemoryUseOrDef *MemInsertPoint = nullptr;
- // Keeps track of the last MemoryDef between StartInst and the insertion point
- // for the new memset. This will become the defining access of the inserted
- // memsets.
- MemoryDef *LastMemDef = nullptr;
for (++BI; !BI->isTerminator(); ++BI) {
auto *CurrentAcc = cast_or_null<MemoryUseOrDef>(
MSSAU->getMemorySSA()->getMemoryAccess(&*BI));
- if (CurrentAcc) {
+ if (CurrentAcc)
MemInsertPoint = CurrentAcc;
- if (auto *CurrentDef = dyn_cast<MemoryDef>(CurrentAcc))
- LastMemDef = CurrentDef;
- }
// Calls that only access inaccessible memory do not block merging
// accessible stores.
@@ -475,16 +481,13 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
if (!Range.TheStores.empty())
AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc());
- assert(LastMemDef && MemInsertPoint &&
- "Both LastMemDef and MemInsertPoint need to be set");
auto *NewDef =
cast<MemoryDef>(MemInsertPoint->getMemoryInst() == &*BI
? MSSAU->createMemoryAccessBefore(
- AMemSet, LastMemDef, MemInsertPoint)
+ AMemSet, nullptr, MemInsertPoint)
: MSSAU->createMemoryAccessAfter(
- AMemSet, LastMemDef, MemInsertPoint));
+ AMemSet, nullptr, MemInsertPoint));
MSSAU->insertDef(NewDef, /*RenameUses=*/true);
- LastMemDef = NewDef;
MemInsertPoint = NewDef;
// Zap all the stores.
@@ -693,7 +696,7 @@ bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI,
auto *LastDef =
cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI));
- auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef);
+ auto *NewAccess = MSSAU->createMemoryAccessAfter(M, nullptr, LastDef);
MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
eraseInstruction(SI);
@@ -814,7 +817,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
// store, so we do not need to rename uses.
auto *StoreDef = cast<MemoryDef>(MSSA->getMemoryAccess(SI));
auto *NewAccess = MSSAU->createMemoryAccessBefore(
- M, StoreDef->getDefiningAccess(), StoreDef);
+ M, nullptr, StoreDef);
MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/false);
eraseInstruction(SI);
@@ -922,10 +925,12 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
return false;
}
- // Check that accessing the first srcSize bytes of dest will not cause a
- // trap. Otherwise the transform is invalid since it might cause a trap
- // to occur earlier than it otherwise would.
- if (!isDereferenceableAndAlignedPointer(cpyDest, Align(1), APInt(64, cpySize),
+ // Check that storing to the first srcSize bytes of dest will not cause a
+ // trap or data race.
+ bool ExplicitlyDereferenceableOnly;
+ if (!isWritableObject(getUnderlyingObject(cpyDest),
+ ExplicitlyDereferenceableOnly) ||
+ !isDereferenceableAndAlignedPointer(cpyDest, Align(1), APInt(64, cpySize),
DL, C, AC, DT)) {
LLVM_DEBUG(dbgs() << "Call Slot: Dest pointer not dereferenceable\n");
return false;
@@ -1040,12 +1045,13 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
// Since we're changing the parameter to the callsite, we need to make sure
// that what would be the new parameter dominates the callsite.
+ bool NeedMoveGEP = false;
if (!DT->dominates(cpyDest, C)) {
// Support moving a constant index GEP before the call.
auto *GEP = dyn_cast<GetElementPtrInst>(cpyDest);
if (GEP && GEP->hasAllConstantIndices() &&
DT->dominates(GEP->getPointerOperand(), C))
- GEP->moveBefore(C);
+ NeedMoveGEP = true;
else
return false;
}
@@ -1064,29 +1070,19 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
// We can't create address space casts here because we don't know if they're
// safe for the target.
- if (cpySrc->getType()->getPointerAddressSpace() !=
- cpyDest->getType()->getPointerAddressSpace())
+ if (cpySrc->getType() != cpyDest->getType())
return false;
for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI)
if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc &&
- cpySrc->getType()->getPointerAddressSpace() !=
- C->getArgOperand(ArgI)->getType()->getPointerAddressSpace())
+ cpySrc->getType() != C->getArgOperand(ArgI)->getType())
return false;
// All the checks have passed, so do the transformation.
bool changedArgument = false;
for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI)
if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc) {
- Value *Dest = cpySrc->getType() == cpyDest->getType() ? cpyDest
- : CastInst::CreatePointerCast(cpyDest, cpySrc->getType(),
- cpyDest->getName(), C);
changedArgument = true;
- if (C->getArgOperand(ArgI)->getType() == Dest->getType())
- C->setArgOperand(ArgI, Dest);
- else
- C->setArgOperand(ArgI, CastInst::CreatePointerCast(
- Dest, C->getArgOperand(ArgI)->getType(),
- Dest->getName(), C));
+ C->setArgOperand(ArgI, cpyDest);
}
if (!changedArgument)
@@ -1098,22 +1094,20 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
cast<AllocaInst>(cpyDest)->setAlignment(srcAlign);
}
+ if (NeedMoveGEP) {
+ auto *GEP = dyn_cast<GetElementPtrInst>(cpyDest);
+ GEP->moveBefore(C);
+ }
+
if (SkippedLifetimeStart) {
SkippedLifetimeStart->moveBefore(C);
MSSAU->moveBefore(MSSA->getMemoryAccess(SkippedLifetimeStart),
MSSA->getMemoryAccess(C));
}
- // Update AA metadata
- // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be
- // handled here, but combineMetadata doesn't support them yet
- unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
- LLVMContext::MD_noalias,
- LLVMContext::MD_invariant_group,
- LLVMContext::MD_access_group};
- combineMetadata(C, cpyLoad, KnownIDs, true);
+ combineAAMetadata(C, cpyLoad);
if (cpyLoad != cpyStore)
- combineMetadata(C, cpyStore, KnownIDs, true);
+ combineAAMetadata(C, cpyStore);
++NumCallSlot;
return true;
@@ -1203,7 +1197,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)));
auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M));
- auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef);
+ auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, nullptr, LastDef);
MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
// Remove the instruction we're replacing.
@@ -1300,12 +1294,8 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize);
Value *MemsetLen = Builder.CreateSelect(
Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff);
- unsigned DestAS = Dest->getType()->getPointerAddressSpace();
Instruction *NewMemSet = Builder.CreateMemSet(
- Builder.CreateGEP(
- Builder.getInt8Ty(),
- Builder.CreatePointerCast(Dest, Builder.getInt8PtrTy(DestAS)),
- SrcSize),
+ Builder.CreateGEP(Builder.getInt8Ty(), Dest, SrcSize),
MemSet->getOperand(1), MemsetLen, Alignment);
assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) &&
@@ -1315,7 +1305,7 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
auto *LastDef =
cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy));
auto *NewAccess = MSSAU->createMemoryAccessBefore(
- NewMemSet, LastDef->getDefiningAccess(), LastDef);
+ NewMemSet, nullptr, LastDef);
MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
eraseInstruction(MemSet);
@@ -1420,7 +1410,7 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
CopySize, MemCpy->getDestAlign());
auto *LastDef =
cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy));
- auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef);
+ auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, nullptr, LastDef);
MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
return true;
@@ -1440,7 +1430,7 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
// allocas that aren't captured.
bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
AllocaInst *DestAlloca,
- AllocaInst *SrcAlloca, uint64_t Size,
+ AllocaInst *SrcAlloca, TypeSize Size,
BatchAAResults &BAA) {
LLVM_DEBUG(dbgs() << "Stack Move: Attempting to optimize:\n"
<< *Store << "\n");
@@ -1451,35 +1441,30 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
return false;
}
- // 1. Check that copy is full. Calculate the static size of the allocas to be
- // merged, bail out if we can't.
+ // Check that copy is full with static size.
const DataLayout &DL = DestAlloca->getModule()->getDataLayout();
std::optional<TypeSize> SrcSize = SrcAlloca->getAllocationSize(DL);
- if (!SrcSize || SrcSize->isScalable() || Size != SrcSize->getFixedValue()) {
+ if (!SrcSize || Size != *SrcSize) {
LLVM_DEBUG(dbgs() << "Stack Move: Source alloca size mismatch\n");
return false;
}
std::optional<TypeSize> DestSize = DestAlloca->getAllocationSize(DL);
- if (!DestSize || DestSize->isScalable() ||
- Size != DestSize->getFixedValue()) {
+ if (!DestSize || Size != *DestSize) {
LLVM_DEBUG(dbgs() << "Stack Move: Destination alloca size mismatch\n");
return false;
}
- // 2-1. Check that src and dest are static allocas, which are not affected by
- // stacksave/stackrestore.
- if (!SrcAlloca->isStaticAlloca() || !DestAlloca->isStaticAlloca() ||
- SrcAlloca->getParent() != Load->getParent() ||
- SrcAlloca->getParent() != Store->getParent())
+ if (!SrcAlloca->isStaticAlloca() || !DestAlloca->isStaticAlloca())
return false;
- // 2-2. Check that src and dest are never captured, unescaped allocas. Also
- // collect lifetime markers first/last users in order to shrink wrap the
- // lifetimes, and instructions with noalias metadata to remove them.
+ // Check that src and dest are never captured, unescaped allocas. Also
+ // find the nearest common dominator and postdominator for all users in
+ // order to shrink wrap the lifetimes, and instructions with noalias metadata
+ // to remove them.
SmallVector<Instruction *, 4> LifetimeMarkers;
- Instruction *FirstUser = nullptr, *LastUser = nullptr;
SmallSet<Instruction *, 4> NoAliasInstrs;
+ bool SrcNotDom = false;
// Recursively track the user and check whether modified alias exist.
auto IsDereferenceableOrNull = [](Value *V, const DataLayout &DL) -> bool {
@@ -1499,6 +1484,12 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
Instruction *I = Worklist.back();
Worklist.pop_back();
for (const Use &U : I->uses()) {
+ auto *UI = cast<Instruction>(U.getUser());
+ // If any use that isn't dominated by SrcAlloca exists, we move src
+ // alloca to the entry before the transformation.
+ if (!DT->dominates(SrcAlloca, UI))
+ SrcNotDom = true;
+
if (Visited.size() >= MaxUsesToExplore) {
LLVM_DEBUG(
dbgs()
@@ -1512,22 +1503,15 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
return false;
case UseCaptureKind::PASSTHROUGH:
// Instructions cannot have non-instruction users.
- Worklist.push_back(cast<Instruction>(U.getUser()));
+ Worklist.push_back(UI);
continue;
case UseCaptureKind::NO_CAPTURE: {
- auto *UI = cast<Instruction>(U.getUser());
- if (DestAlloca->getParent() != UI->getParent())
- return false;
- if (!FirstUser || UI->comesBefore(FirstUser))
- FirstUser = UI;
- if (!LastUser || LastUser->comesBefore(UI))
- LastUser = UI;
if (UI->isLifetimeStartOrEnd()) {
// We note the locations of these intrinsic calls so that we can
// delete them later if the optimization succeeds, this is safe
// since both llvm.lifetime.start and llvm.lifetime.end intrinsics
- // conceptually fill all the bytes of the alloca with an undefined
- // value.
+ // practically fill all the bytes of the alloca with an undefined
+ // value, although conceptually marked as alive/dead.
int64_t Size = cast<ConstantInt>(UI->getOperand(0))->getSExtValue();
if (Size < 0 || Size == DestSize) {
LifetimeMarkers.push_back(UI);
@@ -1545,37 +1529,64 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
return true;
};
- // 3. Check that dest has no Mod/Ref, except full size lifetime intrinsics,
- // from the alloca to the Store.
+ // Check that dest has no Mod/Ref, from the alloca to the Store, except full
+ // size lifetime intrinsics. And collect modref inst for the reachability
+ // check.
ModRefInfo DestModRef = ModRefInfo::NoModRef;
MemoryLocation DestLoc(DestAlloca, LocationSize::precise(Size));
+ SmallVector<BasicBlock *, 8> ReachabilityWorklist;
auto DestModRefCallback = [&](Instruction *UI) -> bool {
// We don't care about the store itself.
if (UI == Store)
return true;
ModRefInfo Res = BAA.getModRefInfo(UI, DestLoc);
- // FIXME: For multi-BB cases, we need to see reachability from it to
- // store.
- // Bailout if Dest may have any ModRef before Store.
- if (UI->comesBefore(Store) && isModOrRefSet(Res))
- return false;
- DestModRef |= BAA.getModRefInfo(UI, DestLoc);
+ DestModRef |= Res;
+ if (isModOrRefSet(Res)) {
+ // Instructions reachability checks.
+ // FIXME: adding the Instruction version isPotentiallyReachableFromMany on
+ // lib/Analysis/CFG.cpp (currently only for BasicBlocks) might be helpful.
+ if (UI->getParent() == Store->getParent()) {
+ // The same block case is special because it's the only time we're
+ // looking within a single block to see which instruction comes first.
+ // Once we start looking at multiple blocks, the first instruction of
+ // the block is reachable, so we only need to determine reachability
+ // between whole blocks.
+ BasicBlock *BB = UI->getParent();
+ // If A comes before B, then B is definitively reachable from A.
+ if (UI->comesBefore(Store))
+ return false;
+
+ // If the user's parent block is entry, no predecessor exists.
+ if (BB->isEntryBlock())
+ return true;
+
+ // Otherwise, continue doing the normal per-BB CFG walk.
+ ReachabilityWorklist.append(succ_begin(BB), succ_end(BB));
+ } else {
+ ReachabilityWorklist.push_back(UI->getParent());
+ }
+ }
return true;
};
if (!CaptureTrackingWithModRef(DestAlloca, DestModRefCallback))
return false;
+ // Bailout if Dest may have any ModRef before Store.
+ if (!ReachabilityWorklist.empty() &&
+ isPotentiallyReachableFromMany(ReachabilityWorklist, Store->getParent(),
+ nullptr, DT, nullptr))
+ return false;
- // 3. Check that, from after the Load to the end of the BB,
- // 3-1. if the dest has any Mod, src has no Ref, and
- // 3-2. if the dest has any Ref, src has no Mod except full-sized lifetimes.
+ // Check that, from after the Load to the end of the BB,
+ // - if the dest has any Mod, src has no Ref, and
+ // - if the dest has any Ref, src has no Mod except full-sized lifetimes.
MemoryLocation SrcLoc(SrcAlloca, LocationSize::precise(Size));
auto SrcModRefCallback = [&](Instruction *UI) -> bool {
- // Any ModRef before Load doesn't matter, also Load and Store can be
- // ignored.
- if (UI->comesBefore(Load) || UI == Load || UI == Store)
+ // Any ModRef post-dominated by Load doesn't matter, also Load and Store
+ // themselves can be ignored.
+ if (PDT->dominates(Load, UI) || UI == Load || UI == Store)
return true;
ModRefInfo Res = BAA.getModRefInfo(UI, SrcLoc);
if ((isModSet(DestModRef) && isRefSet(Res)) ||
@@ -1588,7 +1599,12 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
if (!CaptureTrackingWithModRef(SrcAlloca, SrcModRefCallback))
return false;
- // We can do the transformation. First, align the allocas appropriately.
+ // We can do the transformation. First, move the SrcAlloca to the start of the
+ // BB.
+ if (SrcNotDom)
+ SrcAlloca->moveBefore(*SrcAlloca->getParent(),
+ SrcAlloca->getParent()->getFirstInsertionPt());
+ // Align the allocas appropriately.
SrcAlloca->setAlignment(
std::max(SrcAlloca->getAlign(), DestAlloca->getAlign()));
@@ -1599,28 +1615,10 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
// Drop metadata on the source alloca.
SrcAlloca->dropUnknownNonDebugMetadata();
- // Do "shrink wrap" the lifetimes, if the original lifetime intrinsics exists.
+ // TODO: Reconstruct merged lifetime markers.
+ // Remove all other lifetime markers. if the original lifetime intrinsics
+ // exists.
if (!LifetimeMarkers.empty()) {
- LLVMContext &C = SrcAlloca->getContext();
- IRBuilder<> Builder(C);
-
- ConstantInt *AllocaSize = ConstantInt::get(Type::getInt64Ty(C), Size);
- // Create a new lifetime start marker before the first user of src or alloca
- // users.
- Builder.SetInsertPoint(FirstUser->getParent(), FirstUser->getIterator());
- Builder.CreateLifetimeStart(SrcAlloca, AllocaSize);
-
- // Create a new lifetime end marker after the last user of src or alloca
- // users.
- // FIXME: If the last user is the terminator for the bb, we can insert
- // lifetime.end marker to the immidiate post-dominator, but currently do
- // nothing.
- if (!LastUser->isTerminator()) {
- Builder.SetInsertPoint(LastUser->getParent(), ++LastUser->getIterator());
- Builder.CreateLifetimeEnd(SrcAlloca, AllocaSize);
- }
-
- // Remove all other lifetime markers.
for (Instruction *I : LifetimeMarkers)
eraseInstruction(I);
}
@@ -1637,6 +1635,16 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
return true;
}
+static bool isZeroSize(Value *Size) {
+ if (auto *I = dyn_cast<Instruction>(Size))
+ if (auto *Res = simplifyInstruction(I, I->getModule()->getDataLayout()))
+ Size = Res;
+ // Treat undef/poison size like zero.
+ if (auto *C = dyn_cast<Constant>(Size))
+ return isa<UndefValue>(C) || C->isNullValue();
+ return false;
+}
+
/// Perform simplification of memcpy's. If we have memcpy A
/// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite
/// B to be a memcpy from X to Z (or potentially a memmove, depending on
@@ -1653,6 +1661,19 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
return true;
}
+ // If the size is zero, remove the memcpy. This also prevents infinite loops
+ // in processMemSetMemCpyDependence, which is a no-op for zero-length memcpys.
+ if (isZeroSize(M->getLength())) {
+ ++BBI;
+ eraseInstruction(M);
+ return true;
+ }
+
+ MemoryUseOrDef *MA = MSSA->getMemoryAccess(M);
+ if (!MA)
+ // Degenerate case: memcpy marked as not accessing memory.
+ return false;
+
// If copying from a constant, try to turn the memcpy into a memset.
if (auto *GV = dyn_cast<GlobalVariable>(M->getSource()))
if (GV->isConstant() && GV->hasDefinitiveInitializer())
@@ -1661,10 +1682,9 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
IRBuilder<> Builder(M);
Instruction *NewM = Builder.CreateMemSet(
M->getRawDest(), ByteVal, M->getLength(), M->getDestAlign(), false);
- auto *LastDef =
- cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M));
+ auto *LastDef = cast<MemoryDef>(MA);
auto *NewAccess =
- MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef);
+ MSSAU->createMemoryAccessAfter(NewM, nullptr, LastDef);
MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
eraseInstruction(M);
@@ -1673,7 +1693,6 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
}
BatchAAResults BAA(*AA);
- MemoryUseOrDef *MA = MSSA->getMemoryAccess(M);
// FIXME: Not using getClobberingMemoryAccess() here due to PR54682.
MemoryAccess *AnyClobber = MA->getDefiningAccess();
MemoryLocation DestLoc = MemoryLocation::getForDest(M);
@@ -1751,8 +1770,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
ConstantInt *Len = dyn_cast<ConstantInt>(M->getLength());
if (Len == nullptr)
return false;
- if (performStackMoveOptzn(M, M, DestAlloca, SrcAlloca, Len->getZExtValue(),
- BAA)) {
+ if (performStackMoveOptzn(M, M, DestAlloca, SrcAlloca,
+ TypeSize::getFixed(Len->getZExtValue()), BAA)) {
// Avoid invalidating the iterator.
BBI = M->getNextNonDebugInstruction()->getIterator();
eraseInstruction(M);
@@ -1831,9 +1850,8 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
DT) < *ByValAlign)
return false;
- // The address space of the memcpy source must match the byval argument
- if (MDep->getSource()->getType()->getPointerAddressSpace() !=
- ByValArg->getType()->getPointerAddressSpace())
+ // The type of the memcpy source must match the byval argument
+ if (MDep->getSource()->getType() != ByValArg->getType())
return false;
// Verify that the copied-from memory doesn't change in between the memcpy and
@@ -1851,6 +1869,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
<< " " << CB << "\n");
// Otherwise we're good! Update the byval argument.
+ combineAAMetadata(&CB, MDep);
CB.setArgOperand(ArgNo, MDep->getSource());
++NumMemCpyInstr;
return true;
@@ -1907,9 +1926,8 @@ bool MemCpyOptPass::processImmutArgument(CallBase &CB, unsigned ArgNo) {
if (!MDep || MDep->isVolatile() || AI != MDep->getDest())
return false;
- // The address space of the memcpy source must match the immut argument
- if (MDep->getSource()->getType()->getPointerAddressSpace() !=
- ImmutArg->getType()->getPointerAddressSpace())
+ // The type of the memcpy source must match the immut argument
+ if (MDep->getSource()->getType() != ImmutArg->getType())
return false;
// 2-1. The length of the memcpy must be equal to the size of the alloca.
@@ -1946,6 +1964,7 @@ bool MemCpyOptPass::processImmutArgument(CallBase &CB, unsigned ArgNo) {
<< " " << CB << "\n");
// Otherwise we're good! Update the immut argument.
+ combineAAMetadata(&CB, MDep);
CB.setArgOperand(ArgNo, MDep->getSource());
++NumMemCpyInstr;
return true;
@@ -2004,9 +2023,10 @@ PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) {
auto *AA = &AM.getResult<AAManager>(F);
auto *AC = &AM.getResult<AssumptionAnalysis>(F);
auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
+ auto *PDT = &AM.getResult<PostDominatorTreeAnalysis>(F);
auto *MSSA = &AM.getResult<MemorySSAAnalysis>(F);
- bool MadeChange = runImpl(F, &TLI, AA, AC, DT, &MSSA->getMSSA());
+ bool MadeChange = runImpl(F, &TLI, AA, AC, DT, PDT, &MSSA->getMSSA());
if (!MadeChange)
return PreservedAnalyses::all();
@@ -2018,12 +2038,14 @@ PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) {
bool MemCpyOptPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
AliasAnalysis *AA_, AssumptionCache *AC_,
- DominatorTree *DT_, MemorySSA *MSSA_) {
+ DominatorTree *DT_, PostDominatorTree *PDT_,
+ MemorySSA *MSSA_) {
bool MadeChange = false;
TLI = TLI_;
AA = AA_;
AC = AC_;
DT = DT_;
+ PDT = PDT_;
MSSA = MSSA_;
MemorySSAUpdater MSSAU_(MSSA_);
MSSAU = &MSSAU_;
diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
index 311a6435ba7c..1e0906717549 100644
--- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
@@ -275,7 +275,7 @@ void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const {
// Do the actual spliting.
for (Instruction *Inst : reverse(OtherInsts))
- Inst->moveBefore(*NewParent, NewParent->begin());
+ Inst->moveBeforePreserving(*NewParent, NewParent->begin());
}
bool BCECmpBlock::canSplit(AliasAnalysis &AA) const {
diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
index 6c5453831ade..d65054a6ff9d 100644
--- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
+++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
@@ -80,7 +80,6 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
@@ -217,8 +216,8 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0,
if (Opd1 == Opd2)
return nullptr;
- auto *NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink",
- &BB->front());
+ auto *NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink");
+ NewPN->insertBefore(BB->begin());
NewPN->applyMergedLocation(S0->getDebugLoc(), S1->getDebugLoc());
NewPN->addIncoming(Opd1, S0->getParent());
NewPN->addIncoming(Opd2, S1->getParent());
@@ -269,7 +268,7 @@ void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0,
// Create the new store to be inserted at the join point.
StoreInst *SNew = cast<StoreInst>(S0->clone());
- SNew->insertBefore(&*InsertPt);
+ SNew->insertBefore(InsertPt);
// New PHI operand? Use it.
if (PHINode *NewPN = getPHIOperand(BB, S0, S1))
SNew->setOperand(0, NewPN);
@@ -378,52 +377,6 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) {
return Changed;
}
-namespace {
-class MergedLoadStoreMotionLegacyPass : public FunctionPass {
- const bool SplitFooterBB;
-public:
- static char ID; // Pass identification, replacement for typeid
- MergedLoadStoreMotionLegacyPass(bool SplitFooterBB = false)
- : FunctionPass(ID), SplitFooterBB(SplitFooterBB) {
- initializeMergedLoadStoreMotionLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- ///
- /// Run the transformation for each function
- ///
- bool runOnFunction(Function &F) override {
- if (skipFunction(F))
- return false;
- MergedLoadStoreMotion Impl(SplitFooterBB);
- return Impl.run(F, getAnalysis<AAResultsWrapperPass>().getAAResults());
- }
-
-private:
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- if (!SplitFooterBB)
- AU.setPreservesCFG();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- }
-};
-
-char MergedLoadStoreMotionLegacyPass::ID = 0;
-} // anonymous namespace
-
-///
-/// createMergedLoadStoreMotionPass - The public interface to this file.
-///
-FunctionPass *llvm::createMergedLoadStoreMotionPass(bool SplitFooterBB) {
- return new MergedLoadStoreMotionLegacyPass(SplitFooterBB);
-}
-
-INITIALIZE_PASS_BEGIN(MergedLoadStoreMotionLegacyPass, "mldst-motion",
- "MergedLoadStoreMotion", false, false)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion",
- "MergedLoadStoreMotion", false, false)
-
PreservedAnalyses
MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) {
MergedLoadStoreMotion Impl(Options.SplitFooterBB);
diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
index 9c3e9a2fd018..7fe1a222021e 100644
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -359,12 +359,13 @@ bool NaryReassociatePass::requiresSignExtension(Value *Index,
GetElementPtrInst *
NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
unsigned I, Type *IndexedType) {
+ SimplifyQuery SQ(*DL, DT, AC, GEP);
Value *IndexToSplit = GEP->getOperand(I + 1);
if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) {
IndexToSplit = SExt->getOperand(0);
} else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
// zext can be treated as sext if the source is non-negative.
- if (isKnownNonNegative(ZExt->getOperand(0), *DL, 0, AC, GEP, DT))
+ if (isKnownNonNegative(ZExt->getOperand(0), SQ))
IndexToSplit = ZExt->getOperand(0);
}
@@ -373,8 +374,7 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
// nsw, we cannot split the add because
// sext(LHS + RHS) != sext(LHS) + sext(RHS).
if (requiresSignExtension(IndexToSplit, GEP) &&
- computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) !=
- OverflowResult::NeverOverflows)
+ computeOverflowForSignedAdd(AO, SQ) != OverflowResult::NeverOverflows)
return nullptr;
Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
@@ -402,7 +402,7 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
IndexExprs.push_back(SE->getSCEV(Index));
// Replace the I-th index with LHS.
IndexExprs[I] = SE->getSCEV(LHS);
- if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) &&
+ if (isKnownNonNegative(LHS, SimplifyQuery(*DL, DT, AC, GEP)) &&
DL->getTypeSizeInBits(LHS->getType()).getFixedValue() <
DL->getTypeSizeInBits(GEP->getOperand(I)->getType())
.getFixedValue()) {
diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 1af40e2c4e62..19ac9526b5f8 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -774,7 +774,7 @@ private:
// Symbolic evaluation.
ExprResult checkExprResults(Expression *, Instruction *, Value *) const;
- ExprResult performSymbolicEvaluation(Value *,
+ ExprResult performSymbolicEvaluation(Instruction *,
SmallPtrSetImpl<Value *> &) const;
const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *,
Instruction *,
@@ -1904,7 +1904,7 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
LastPredInfo = PI;
// In phi of ops cases, we may have predicate info that we are evaluating
// in a different context.
- if (!DT->dominates(PBranch->To, getBlockForValue(I)))
+ if (!DT->dominates(PBranch->To, I->getParent()))
continue;
// TODO: Along the false edge, we may know more things too, like
// icmp of
@@ -1961,95 +1961,88 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
return createExpression(I);
}
-// Substitute and symbolize the value before value numbering.
+// Substitute and symbolize the instruction before value numbering.
NewGVN::ExprResult
-NewGVN::performSymbolicEvaluation(Value *V,
+NewGVN::performSymbolicEvaluation(Instruction *I,
SmallPtrSetImpl<Value *> &Visited) const {
const Expression *E = nullptr;
- if (auto *C = dyn_cast<Constant>(V))
- E = createConstantExpression(C);
- else if (isa<Argument>(V) || isa<GlobalVariable>(V)) {
- E = createVariableExpression(V);
- } else {
- // TODO: memory intrinsics.
- // TODO: Some day, we should do the forward propagation and reassociation
- // parts of the algorithm.
- auto *I = cast<Instruction>(V);
- switch (I->getOpcode()) {
- case Instruction::ExtractValue:
- case Instruction::InsertValue:
- E = performSymbolicAggrValueEvaluation(I);
- break;
- case Instruction::PHI: {
- SmallVector<ValPair, 3> Ops;
- auto *PN = cast<PHINode>(I);
- for (unsigned i = 0; i < PN->getNumOperands(); ++i)
- Ops.push_back({PN->getIncomingValue(i), PN->getIncomingBlock(i)});
- // Sort to ensure the invariant createPHIExpression requires is met.
- sortPHIOps(Ops);
- E = performSymbolicPHIEvaluation(Ops, I, getBlockForValue(I));
- } break;
- case Instruction::Call:
- return performSymbolicCallEvaluation(I);
- break;
- case Instruction::Store:
- E = performSymbolicStoreEvaluation(I);
- break;
- case Instruction::Load:
- E = performSymbolicLoadEvaluation(I);
- break;
- case Instruction::BitCast:
- case Instruction::AddrSpaceCast:
- case Instruction::Freeze:
- return createExpression(I);
- break;
- case Instruction::ICmp:
- case Instruction::FCmp:
- return performSymbolicCmpEvaluation(I);
- break;
- case Instruction::FNeg:
- case Instruction::Add:
- case Instruction::FAdd:
- case Instruction::Sub:
- case Instruction::FSub:
- case Instruction::Mul:
- case Instruction::FMul:
- case Instruction::UDiv:
- case Instruction::SDiv:
- case Instruction::FDiv:
- case Instruction::URem:
- case Instruction::SRem:
- case Instruction::FRem:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- case Instruction::And:
- case Instruction::Or:
- case Instruction::Xor:
- case Instruction::Trunc:
- case Instruction::ZExt:
- case Instruction::SExt:
- case Instruction::FPToUI:
- case Instruction::FPToSI:
- case Instruction::UIToFP:
- case Instruction::SIToFP:
- case Instruction::FPTrunc:
- case Instruction::FPExt:
- case Instruction::PtrToInt:
- case Instruction::IntToPtr:
- case Instruction::Select:
- case Instruction::ExtractElement:
- case Instruction::InsertElement:
- case Instruction::GetElementPtr:
- return createExpression(I);
- break;
- case Instruction::ShuffleVector:
- // FIXME: Add support for shufflevector to createExpression.
- return ExprResult::none();
- default:
- return ExprResult::none();
- }
+ // TODO: memory intrinsics.
+ // TODO: Some day, we should do the forward propagation and reassociation
+ // parts of the algorithm.
+ switch (I->getOpcode()) {
+ case Instruction::ExtractValue:
+ case Instruction::InsertValue:
+ E = performSymbolicAggrValueEvaluation(I);
+ break;
+ case Instruction::PHI: {
+ SmallVector<ValPair, 3> Ops;
+ auto *PN = cast<PHINode>(I);
+ for (unsigned i = 0; i < PN->getNumOperands(); ++i)
+ Ops.push_back({PN->getIncomingValue(i), PN->getIncomingBlock(i)});
+ // Sort to ensure the invariant createPHIExpression requires is met.
+ sortPHIOps(Ops);
+ E = performSymbolicPHIEvaluation(Ops, I, getBlockForValue(I));
+ } break;
+ case Instruction::Call:
+ return performSymbolicCallEvaluation(I);
+ break;
+ case Instruction::Store:
+ E = performSymbolicStoreEvaluation(I);
+ break;
+ case Instruction::Load:
+ E = performSymbolicLoadEvaluation(I);
+ break;
+ case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
+ case Instruction::Freeze:
+ return createExpression(I);
+ break;
+ case Instruction::ICmp:
+ case Instruction::FCmp:
+ return performSymbolicCmpEvaluation(I);
+ break;
+ case Instruction::FNeg:
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::FDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ case Instruction::Trunc:
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::UIToFP:
+ case Instruction::SIToFP:
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::PtrToInt:
+ case Instruction::IntToPtr:
+ case Instruction::Select:
+ case Instruction::ExtractElement:
+ case Instruction::InsertElement:
+ case Instruction::GetElementPtr:
+ return createExpression(I);
+ break;
+ case Instruction::ShuffleVector:
+ // FIXME: Add support for shufflevector to createExpression.
+ return ExprResult::none();
+ default:
+ return ExprResult::none();
}
return ExprResult::some(E);
}
@@ -2772,6 +2765,9 @@ NewGVN::makePossiblePHIOfOps(Instruction *I,
// Clone the instruction, create an expression from it that is
// translated back into the predecessor, and see if we have a leader.
Instruction *ValueOp = I->clone();
+ // Emit the temporal instruction in the predecessor basic block where the
+ // corresponding value is defined.
+ ValueOp->insertBefore(PredBB->getTerminator());
if (MemAccess)
TempToMemory.insert({ValueOp, MemAccess});
bool SafeForPHIOfOps = true;
@@ -2801,7 +2797,7 @@ NewGVN::makePossiblePHIOfOps(Instruction *I,
FoundVal = !SafeForPHIOfOps ? nullptr
: findLeaderForInst(ValueOp, Visited,
MemAccess, I, PredBB);
- ValueOp->deleteValue();
+ ValueOp->eraseFromParent();
if (!FoundVal) {
// We failed to find a leader for the current ValueOp, but this might
// change in case of the translated operands change.
@@ -3542,7 +3538,7 @@ struct NewGVN::ValueDFS {
// the second. We only want it to be less than if the DFS orders are equal.
//
// Each LLVM instruction only produces one value, and thus the lowest-level
- // differentiator that really matters for the stack (and what we use as as a
+ // differentiator that really matters for the stack (and what we use as a
// replacement) is the local dfs number.
// Everything else in the structure is instruction level, and only affects
// the order in which we will replace operands of a given instruction.
@@ -4034,9 +4030,18 @@ bool NewGVN::eliminateInstructions(Function &F) {
// because stores are put in terms of the stored value, we skip
// stored values here. If the stored value is really dead, it will
// still be marked for deletion when we process it in its own class.
- if (!EliminationStack.empty() && Def != EliminationStack.back() &&
- isa<Instruction>(Def) && !FromStore)
- markInstructionForDeletion(cast<Instruction>(Def));
+ auto *DefI = dyn_cast<Instruction>(Def);
+ if (!EliminationStack.empty() && DefI && !FromStore) {
+ Value *DominatingLeader = EliminationStack.back();
+ if (DominatingLeader != Def) {
+ // Even if the instruction is removed, we still need to update
+ // flags/metadata due to downstreams users of the leader.
+ if (!match(DefI, m_Intrinsic<Intrinsic::ssa_copy>()))
+ patchReplacementInstruction(DefI, DominatingLeader);
+
+ markInstructionForDeletion(DefI);
+ }
+ }
continue;
}
// At this point, we know it is a Use we are trying to possibly
@@ -4095,9 +4100,12 @@ bool NewGVN::eliminateInstructions(Function &F) {
// For copy instructions, we use their operand as a leader,
// which means we remove a user of the copy and it may become dead.
if (isSSACopy) {
- unsigned &IIUseCount = UseCounts[II];
- if (--IIUseCount == 0)
- ProbablyDead.insert(II);
+ auto It = UseCounts.find(II);
+ if (It != UseCounts.end()) {
+ unsigned &IIUseCount = It->second;
+ if (--IIUseCount == 0)
+ ProbablyDead.insert(II);
+ }
}
++LeaderUseCount;
AnythingReplaced = true;
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 40c84e249523..818c7b40d489 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -466,7 +466,8 @@ using RepeatedValue = std::pair<Value*, APInt>;
/// type and thus make the expression bigger.
static bool LinearizeExprTree(Instruction *I,
SmallVectorImpl<RepeatedValue> &Ops,
- ReassociatePass::OrderedSet &ToRedo) {
+ ReassociatePass::OrderedSet &ToRedo,
+ bool &HasNUW) {
assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
"Expected a UnaryOperator or BinaryOperator!");
LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -515,6 +516,9 @@ static bool LinearizeExprTree(Instruction *I,
std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
I = P.first; // We examine the operands of this binary operator.
+ if (isa<OverflowingBinaryOperator>(I))
+ HasNUW &= I->hasNoUnsignedWrap();
+
for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
Value *Op = I->getOperand(OpIdx);
APInt Weight = P.second; // Number of paths to this operand.
@@ -657,7 +661,8 @@ static bool LinearizeExprTree(Instruction *I,
/// Now that the operands for this expression tree are
/// linearized and optimized, emit them in-order.
void ReassociatePass::RewriteExprTree(BinaryOperator *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+ SmallVectorImpl<ValueEntry> &Ops,
+ bool HasNUW) {
assert(Ops.size() > 1 && "Single values should be used directly!");
// Since our optimizations should never increase the number of operations, the
@@ -814,14 +819,20 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
if (ExpressionChangedStart) {
bool ClearFlags = true;
do {
- // Preserve FastMathFlags.
+ // Preserve flags.
if (ClearFlags) {
if (isa<FPMathOperator>(I)) {
FastMathFlags Flags = I->getFastMathFlags();
ExpressionChangedStart->clearSubclassOptionalData();
ExpressionChangedStart->setFastMathFlags(Flags);
- } else
+ } else {
ExpressionChangedStart->clearSubclassOptionalData();
+ // Note that it doesn't hold for mul if one of the operands is zero.
+ // TODO: We can preserve NUW flag if we prove that all mul operands
+ // are non-zero.
+ if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add)
+ ExpressionChangedStart->setHasNoUnsignedWrap();
+ }
}
if (ExpressionChangedStart == ExpressionChangedEnd)
@@ -921,16 +932,20 @@ static Value *NegateValue(Value *V, Instruction *BI,
TheNeg->getParent()->getParent() != BI->getParent()->getParent())
continue;
- Instruction *InsertPt;
+ BasicBlock::iterator InsertPt;
if (Instruction *InstInput = dyn_cast<Instruction>(V)) {
- InsertPt = InstInput->getInsertionPointAfterDef();
- if (!InsertPt)
+ auto InsertPtOpt = InstInput->getInsertionPointAfterDef();
+ if (!InsertPtOpt)
continue;
+ InsertPt = *InsertPtOpt;
} else {
- InsertPt = &*TheNeg->getFunction()->getEntryBlock().begin();
+ InsertPt = TheNeg->getFunction()
+ ->getEntryBlock()
+ .getFirstNonPHIOrDbg()
+ ->getIterator();
}
- TheNeg->moveBefore(InsertPt);
+ TheNeg->moveBefore(*InsertPt->getParent(), InsertPt);
if (TheNeg->getOpcode() == Instruction::Sub) {
TheNeg->setHasNoUnsignedWrap(false);
TheNeg->setHasNoSignedWrap(false);
@@ -1171,7 +1186,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
return nullptr;
SmallVector<RepeatedValue, 8> Tree;
- MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts);
+ bool HasNUW = true;
+ MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
SmallVector<ValueEntry, 8> Factors;
Factors.reserve(Tree.size());
for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1213,7 +1229,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
if (!FoundFactor) {
// Make sure to restore the operands to the expression tree.
- RewriteExprTree(BO, Factors);
+ RewriteExprTree(BO, Factors, HasNUW);
return nullptr;
}
@@ -1225,7 +1241,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
RedoInsts.insert(BO);
V = Factors[0].Op;
} else {
- RewriteExprTree(BO, Factors);
+ RewriteExprTree(BO, Factors, HasNUW);
V = BO;
}
@@ -2252,9 +2268,10 @@ void ReassociatePass::OptimizeInst(Instruction *I) {
// with no common bits set, convert it to X+Y.
if (I->getOpcode() == Instruction::Or &&
shouldConvertOrWithNoCommonBitsToAdd(I) && !isLoadCombineCandidate(I) &&
- haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1),
- I->getModule()->getDataLayout(), /*AC=*/nullptr, I,
- /*DT=*/nullptr)) {
+ (cast<PossiblyDisjointInst>(I)->isDisjoint() ||
+ haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1),
+ SimplifyQuery(I->getModule()->getDataLayout(),
+ /*DT=*/nullptr, /*AC=*/nullptr, I)))) {
Instruction *NI = convertOrWithNoCommonBitsToAdd(I);
RedoInsts.insert(I);
MadeChange = true;
@@ -2349,7 +2366,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
// First, walk the expression tree, linearizing the tree, collecting the
// operand information.
SmallVector<RepeatedValue, 8> Tree;
- MadeChange |= LinearizeExprTree(I, Tree, RedoInsts);
+ bool HasNUW = true;
+ MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
SmallVector<ValueEntry, 8> Ops;
Ops.reserve(Tree.size());
for (const RepeatedValue &E : Tree)
@@ -2542,7 +2560,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
dbgs() << '\n');
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
- RewriteExprTree(I, Ops);
+ RewriteExprTree(I, Ops, HasNUW);
}
void
@@ -2550,7 +2568,7 @@ ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
// Make a "pairmap" of how often each operand pair occurs.
for (BasicBlock *BI : RPOT) {
for (Instruction &I : *BI) {
- if (!I.isAssociative())
+ if (!I.isAssociative() || !I.isBinaryOp())
continue;
// Ignore nodes that aren't at the root of trees.
diff --git a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp
index db7a1f24660c..6c2b3e9bd4a7 100644
--- a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp
+++ b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp
@@ -25,8 +25,6 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -107,36 +105,3 @@ PreservedAnalyses RegToMemPass::run(Function &F, FunctionAnalysisManager &AM) {
PA.preserve<LoopAnalysis>();
return PA;
}
-
-namespace {
-struct RegToMemLegacy : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- RegToMemLegacy() : FunctionPass(ID) {
- initializeRegToMemLegacyPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequiredID(BreakCriticalEdgesID);
- AU.addPreservedID(BreakCriticalEdgesID);
- }
-
- bool runOnFunction(Function &F) override {
- if (F.isDeclaration() || skipFunction(F))
- return false;
- return runPass(F);
- }
-};
-} // namespace
-
-char RegToMemLegacy::ID = 0;
-INITIALIZE_PASS_BEGIN(RegToMemLegacy, "reg2mem",
- "Demote all values to stack slots", false, false)
-INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
-INITIALIZE_PASS_END(RegToMemLegacy, "reg2mem",
- "Demote all values to stack slots", false, false)
-
-// createDemoteRegisterToMemory - Provide an entry point to create this pass.
-char &llvm::DemoteRegisterToMemoryID = RegToMemLegacy::ID;
-FunctionPass *llvm::createDemoteRegisterToMemoryPass() {
- return new RegToMemLegacy();
-}
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index 908bda5709a0..40b4ea92e1ff 100644
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -54,15 +55,12 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
@@ -995,7 +993,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache,
NewState.meet(OpState);
});
- BDVState OldState = States[BDV];
+ BDVState OldState = Pair.second;
if (OldState != NewState) {
Progress = true;
States[BDV] = NewState;
@@ -1014,8 +1012,44 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache,
}
#endif
- // Handle all instructions that have a vector BDV, but the instruction itself
- // is of scalar type.
+ // Even though we have identified a concrete base (or a conflict) for all live
+ // pointers at this point, there are cases where the base is of an
+ // incompatible type compared to the original instruction. We conservatively
+ // mark those as conflicts to ensure that corresponding BDVs will be generated
+ // in the next steps.
+
+ // This is a rather explicit check for all cases where we should mark the
+ // state as a conflict to force the latter stages of the algorithm to emit
+ // the BDVs.
+ // TODO: in many cases the instructions emitted for the conflicting states
+ // will be identical to the I itself (if the I's operate on their BDVs
+ // themselves). We should exploit this, but can't do it here since it would
+ // break the invariant about the BDVs not being known to be a base.
+ // TODO: the code also does not handle constants at all - the algorithm relies
+ // on all constants having the same BDV and therefore constant-only insns
+ // will never be in conflict, but this check is ignored here. If the
+ // constant conflicts will be to BDVs themselves, they will be identical
+ // instructions and will get optimized away (as in the above TODO)
+ auto MarkConflict = [&](Instruction *I, Value *BaseValue) {
+ // II and EE mix vector & scalar so they are always a conflict
+ if (isa<InsertElementInst>(I) || isa<ExtractElementInst>(I))
+ return true;
+ // Shuffle vector is always a conflict as it creates new vector from
+ // existing ones.
+ if (isa<ShuffleVectorInst>(I))
+ return true;
+ // Any instructions where the computed base type differs from the
+ // instruction type. An example is where an extract instruction is used by a
+ // select. Here the select's BDV is a vector (because of extract's BDV),
+ // while the select itself is a scalar type. Note that the IE and EE
+ // instruction check is not fully subsumed by the vector<->scalar check at
+ // the end, this is due to the BDV algorithm being ignorant of BDV types at
+ // this junction.
+ if (!areBothVectorOrScalar(BaseValue, I))
+ return true;
+ return false;
+ };
+
for (auto Pair : States) {
Instruction *I = cast<Instruction>(Pair.first);
BDVState State = Pair.second;
@@ -1028,30 +1062,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache,
"why did it get added?");
assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
- if (!State.isBase() || !isa<VectorType>(BaseValue->getType()))
+ // Since we only mark vec-scalar insns as conflicts in the pass, our work is
+ // done if the instruction already conflicts.
+ if (State.isConflict())
continue;
- // extractelement instructions are a bit special in that we may need to
- // insert an extract even when we know an exact base for the instruction.
- // The problem is that we need to convert from a vector base to a scalar
- // base for the particular indice we're interested in.
- if (isa<ExtractElementInst>(I)) {
- auto *EE = cast<ExtractElementInst>(I);
- // TODO: In many cases, the new instruction is just EE itself. We should
- // exploit this, but can't do it here since it would break the invariant
- // about the BDV not being known to be a base.
- auto *BaseInst = ExtractElementInst::Create(
- State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
- BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
- States[I] = BDVState(I, BDVState::Base, BaseInst);
- setKnownBase(BaseInst, /* IsKnownBase */true, KnownBases);
- } else if (!isa<VectorType>(I->getType())) {
- // We need to handle cases that have a vector base but the instruction is
- // a scalar type (these could be phis or selects or any instruction that
- // are of scalar type, but the base can be a vector type). We
- // conservatively set this as conflict. Setting the base value for these
- // conflicts is handled in the next loop which traverses States.
+
+ if (MarkConflict(I, BaseValue))
States[I] = BDVState(I, BDVState::Conflict);
- }
}
#ifndef NDEBUG
@@ -1234,6 +1251,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache,
VerifyStates();
#endif
+ // Get the data layout to compare the sizes of base/derived pointer values.
+ [[maybe_unused]] auto &DL =
+ cast<llvm::Instruction>(Def)->getModule()->getDataLayout();
// Cache all of our results so we can cheaply reuse them
// NOTE: This is actually two caches: one of the base defining value
// relation and one of the base pointer relation! FIXME
@@ -1241,6 +1261,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache,
auto *BDV = Pair.first;
Value *Base = Pair.second.getBaseValue();
assert(BDV && Base);
+ // Whenever we have derived ptr(s), their base
+ // ptr(s) must be of the same size, not necessarily the same type
+ assert(DL.getTypeAllocSize(BDV->getType()) ==
+ DL.getTypeAllocSize(Base->getType()) &&
+ "Derived and base values should have same size");
// Only values that do not have known bases or those that have differing
// type (scalar versus vector) from a possible known base should be in the
// lattice.
@@ -1425,14 +1450,15 @@ static constexpr Attribute::AttrKind FnAttrsToStrip[] =
{Attribute::Memory, Attribute::NoSync, Attribute::NoFree};
// Create new attribute set containing only attributes which can be transferred
-// from original call to the safepoint.
-static AttributeList legalizeCallAttributes(LLVMContext &Ctx,
- AttributeList OrigAL,
+// from the original call to the safepoint.
+static AttributeList legalizeCallAttributes(CallBase *Call, bool IsMemIntrinsic,
AttributeList StatepointAL) {
+ AttributeList OrigAL = Call->getAttributes();
if (OrigAL.isEmpty())
return StatepointAL;
// Remove the readonly, readnone, and statepoint function attributes.
+ LLVMContext &Ctx = Call->getContext();
AttrBuilder FnAttrs(Ctx, OrigAL.getFnAttrs());
for (auto Attr : FnAttrsToStrip)
FnAttrs.removeAttribute(Attr);
@@ -1442,8 +1468,24 @@ static AttributeList legalizeCallAttributes(LLVMContext &Ctx,
FnAttrs.removeAttribute(A);
}
- // Just skip parameter and return attributes for now
- return StatepointAL.addFnAttributes(Ctx, FnAttrs);
+ StatepointAL = StatepointAL.addFnAttributes(Ctx, FnAttrs);
+
+ // The memory intrinsics do not have a 1:1 correspondence of the original
+ // call arguments to the produced statepoint. Do not transfer the argument
+ // attributes to avoid putting them on incorrect arguments.
+ if (IsMemIntrinsic)
+ return StatepointAL;
+
+ // Attach the argument attributes from the original call at the corresponding
+ // arguments in the statepoint. Note that any argument attributes that are
+ // invalid after lowering are stripped in stripNonValidDataFromBody.
+ for (unsigned I : llvm::seq(Call->arg_size()))
+ StatepointAL = StatepointAL.addParamAttributes(
+ Ctx, GCStatepointInst::CallArgsBeginPos + I,
+ AttrBuilder(Ctx, OrigAL.getParamAttrs(I)));
+
+ // Return attributes are later attached to the gc.result intrinsic.
+ return StatepointAL;
}
/// Helper function to place all gc relocates necessary for the given
@@ -1480,7 +1522,7 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
auto getGCRelocateDecl = [&](Type *Ty) {
assert(isHandledGCPointerType(Ty, GC));
auto AS = Ty->getScalarType()->getPointerAddressSpace();
- Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS);
+ Type *NewTy = PointerType::get(M->getContext(), AS);
if (auto *VT = dyn_cast<VectorType>(Ty))
NewTy = FixedVectorType::get(NewTy,
cast<FixedVectorType>(VT)->getNumElements());
@@ -1633,6 +1675,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
// with a return value, we lower then as never returning calls to
// __llvm_deoptimize that are followed by unreachable to get better codegen.
bool IsDeoptimize = false;
+ bool IsMemIntrinsic = false;
StatepointDirectives SD =
parseStatepointDirectivesFromAttrs(Call->getAttributes());
@@ -1673,6 +1716,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
IsDeoptimize = true;
} else if (IID == Intrinsic::memcpy_element_unordered_atomic ||
IID == Intrinsic::memmove_element_unordered_atomic) {
+ IsMemIntrinsic = true;
+
// Unordered atomic memcpy and memmove intrinsics which are not explicitly
// marked as "gc-leaf-function" should be lowered in a GC parseable way.
// Specifically, these calls should be lowered to the
@@ -1788,12 +1833,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
SPCall->setTailCallKind(CI->getTailCallKind());
SPCall->setCallingConv(CI->getCallingConv());
- // Currently we will fail on parameter attributes and on certain
- // function attributes. In case if we can handle this set of attributes -
- // set up function attrs directly on statepoint and return attrs later for
+ // Set up function attrs directly on statepoint and return attrs later for
// gc_result intrinsic.
- SPCall->setAttributes(legalizeCallAttributes(
- CI->getContext(), CI->getAttributes(), SPCall->getAttributes()));
+ SPCall->setAttributes(
+ legalizeCallAttributes(CI, IsMemIntrinsic, SPCall->getAttributes()));
Token = cast<GCStatepointInst>(SPCall);
@@ -1815,12 +1858,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
SPInvoke->setCallingConv(II->getCallingConv());
- // Currently we will fail on parameter attributes and on certain
- // function attributes. In case if we can handle this set of attributes -
- // set up function attrs directly on statepoint and return attrs later for
+ // Set up function attrs directly on statepoint and return attrs later for
// gc_result intrinsic.
- SPInvoke->setAttributes(legalizeCallAttributes(
- II->getContext(), II->getAttributes(), SPInvoke->getAttributes()));
+ SPInvoke->setAttributes(
+ legalizeCallAttributes(II, IsMemIntrinsic, SPInvoke->getAttributes()));
Token = cast<GCStatepointInst>(SPInvoke);
@@ -1830,7 +1871,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
UnwindBlock->getUniquePredecessor() &&
"can't safely insert in this block!");
- Builder.SetInsertPoint(&*UnwindBlock->getFirstInsertionPt());
+ Builder.SetInsertPoint(UnwindBlock, UnwindBlock->getFirstInsertionPt());
Builder.SetCurrentDebugLocation(II->getDebugLoc());
// Attach exceptional gc relocates to the landingpad.
@@ -1845,7 +1886,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
NormalDest->getUniquePredecessor() &&
"can't safely insert in this block!");
- Builder.SetInsertPoint(&*NormalDest->getFirstInsertionPt());
+ Builder.SetInsertPoint(NormalDest, NormalDest->getFirstInsertionPt());
// gc relocates will be generated later as if it were regular call
// statepoint
diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp
index fcdc503c54a4..69679b608f8d 100644
--- a/llvm/lib/Transforms/Scalar/SCCP.cpp
+++ b/llvm/lib/Transforms/Scalar/SCCP.cpp
@@ -17,10 +17,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/SCCP.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
@@ -51,7 +48,6 @@
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include <cassert>
#include <utility>
-#include <vector>
using namespace llvm;
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 983a75e1d708..f578762d2b49 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -70,6 +71,7 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
@@ -91,10 +93,10 @@
#include <string>
#include <tuple>
#include <utility>
+#include <variant>
#include <vector>
using namespace llvm;
-using namespace llvm::sroa;
#define DEBUG_TYPE "sroa"
@@ -123,6 +125,138 @@ static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false),
cl::Hidden);
namespace {
+class AllocaSliceRewriter;
+class AllocaSlices;
+class Partition;
+
+class SelectHandSpeculativity {
+ unsigned char Storage = 0; // None are speculatable by default.
+ using TrueVal = Bitfield::Element<bool, 0, 1>; // Low 0'th bit.
+ using FalseVal = Bitfield::Element<bool, 1, 1>; // Low 1'th bit.
+public:
+ SelectHandSpeculativity() = default;
+ SelectHandSpeculativity &setAsSpeculatable(bool isTrueVal);
+ bool isSpeculatable(bool isTrueVal) const;
+ bool areAllSpeculatable() const;
+ bool areAnySpeculatable() const;
+ bool areNoneSpeculatable() const;
+ // For interop as int half of PointerIntPair.
+ explicit operator intptr_t() const { return static_cast<intptr_t>(Storage); }
+ explicit SelectHandSpeculativity(intptr_t Storage_) : Storage(Storage_) {}
+};
+static_assert(sizeof(SelectHandSpeculativity) == sizeof(unsigned char));
+
+using PossiblySpeculatableLoad =
+ PointerIntPair<LoadInst *, 2, SelectHandSpeculativity>;
+using UnspeculatableStore = StoreInst *;
+using RewriteableMemOp =
+ std::variant<PossiblySpeculatableLoad, UnspeculatableStore>;
+using RewriteableMemOps = SmallVector<RewriteableMemOp, 2>;
+
+/// An optimization pass providing Scalar Replacement of Aggregates.
+///
+/// This pass takes allocations which can be completely analyzed (that is, they
+/// don't escape) and tries to turn them into scalar SSA values. There are
+/// a few steps to this process.
+///
+/// 1) It takes allocations of aggregates and analyzes the ways in which they
+/// are used to try to split them into smaller allocations, ideally of
+/// a single scalar data type. It will split up memcpy and memset accesses
+/// as necessary and try to isolate individual scalar accesses.
+/// 2) It will transform accesses into forms which are suitable for SSA value
+/// promotion. This can be replacing a memset with a scalar store of an
+/// integer value, or it can involve speculating operations on a PHI or
+/// select to be a PHI or select of the results.
+/// 3) Finally, this will try to detect a pattern of accesses which map cleanly
+/// onto insert and extract operations on a vector value, and convert them to
+/// this form. By doing so, it will enable promotion of vector aggregates to
+/// SSA vector values.
+class SROA {
+ LLVMContext *const C;
+ DomTreeUpdater *const DTU;
+ AssumptionCache *const AC;
+ const bool PreserveCFG;
+
+ /// Worklist of alloca instructions to simplify.
+ ///
+ /// Each alloca in the function is added to this. Each new alloca formed gets
+ /// added to it as well to recursively simplify unless that alloca can be
+ /// directly promoted. Finally, each time we rewrite a use of an alloca other
+ /// the one being actively rewritten, we add it back onto the list if not
+ /// already present to ensure it is re-visited.
+ SmallSetVector<AllocaInst *, 16> Worklist;
+
+ /// A collection of instructions to delete.
+ /// We try to batch deletions to simplify code and make things a bit more
+ /// efficient. We also make sure there is no dangling pointers.
+ SmallVector<WeakVH, 8> DeadInsts;
+
+ /// Post-promotion worklist.
+ ///
+ /// Sometimes we discover an alloca which has a high probability of becoming
+ /// viable for SROA after a round of promotion takes place. In those cases,
+ /// the alloca is enqueued here for re-processing.
+ ///
+ /// Note that we have to be very careful to clear allocas out of this list in
+ /// the event they are deleted.
+ SmallSetVector<AllocaInst *, 16> PostPromotionWorklist;
+
+ /// A collection of alloca instructions we can directly promote.
+ std::vector<AllocaInst *> PromotableAllocas;
+
+ /// A worklist of PHIs to speculate prior to promoting allocas.
+ ///
+ /// All of these PHIs have been checked for the safety of speculation and by
+ /// being speculated will allow promoting allocas currently in the promotable
+ /// queue.
+ SmallSetVector<PHINode *, 8> SpeculatablePHIs;
+
+ /// A worklist of select instructions to rewrite prior to promoting
+ /// allocas.
+ SmallMapVector<SelectInst *, RewriteableMemOps, 8> SelectsToRewrite;
+
+ /// Select instructions that use an alloca and are subsequently loaded can be
+ /// rewritten to load both input pointers and then select between the result,
+ /// allowing the load of the alloca to be promoted.
+ /// From this:
+ /// %P2 = select i1 %cond, ptr %Alloca, ptr %Other
+ /// %V = load <type>, ptr %P2
+ /// to:
+ /// %V1 = load <type>, ptr %Alloca -> will be mem2reg'd
+ /// %V2 = load <type>, ptr %Other
+ /// %V = select i1 %cond, <type> %V1, <type> %V2
+ ///
+ /// We can do this to a select if its only uses are loads
+ /// and if either the operand to the select can be loaded unconditionally,
+ /// or if we are allowed to perform CFG modifications.
+ /// If found an intervening bitcast with a single use of the load,
+ /// allow the promotion.
+ static std::optional<RewriteableMemOps>
+ isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG);
+
+public:
+ SROA(LLVMContext *C, DomTreeUpdater *DTU, AssumptionCache *AC,
+ SROAOptions PreserveCFG_)
+ : C(C), DTU(DTU), AC(AC),
+ PreserveCFG(PreserveCFG_ == SROAOptions::PreserveCFG) {}
+
+ /// Main run method used by both the SROAPass and by the legacy pass.
+ std::pair<bool /*Changed*/, bool /*CFGChanged*/> runSROA(Function &F);
+
+private:
+ friend class AllocaSliceRewriter;
+
+ bool presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS);
+ AllocaInst *rewritePartition(AllocaInst &AI, AllocaSlices &AS, Partition &P);
+ bool splitAlloca(AllocaInst &AI, AllocaSlices &AS);
+ std::pair<bool /*Changed*/, bool /*CFGChanged*/> runOnAlloca(AllocaInst &AI);
+ void clobberUse(Use &U);
+ bool deleteDeadInstructions(SmallPtrSetImpl<AllocaInst *> &DeletedAllocas);
+ bool promoteAllocas(Function &F);
+};
+
+} // end anonymous namespace
+
/// Calculate the fragment of a variable to use when slicing a store
/// based on the slice dimensions, existing fragment, and base storage
/// fragment.
@@ -131,7 +265,9 @@ namespace {
/// UseNoFrag - The new slice already covers the whole variable.
/// Skip - The new alloca slice doesn't include this variable.
/// FIXME: Can we use calculateFragmentIntersect instead?
+namespace {
enum FragCalcResult { UseFrag, UseNoFrag, Skip };
+}
static FragCalcResult
calculateFragment(DILocalVariable *Variable,
uint64_t NewStorageSliceOffsetInBits,
@@ -330,6 +466,8 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit,
}
}
+namespace {
+
/// A custom IRBuilder inserter which prefixes all names, but only in
/// Assert builds.
class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter {
@@ -422,8 +560,6 @@ public:
bool operator!=(const Slice &RHS) const { return !operator==(RHS); }
};
-} // end anonymous namespace
-
/// Representation of the alloca slices.
///
/// This class represents the slices of an alloca which are formed by its
@@ -431,7 +567,7 @@ public:
/// for the slices used and we reflect that in this structure. The uses are
/// stored, sorted by increasing beginning offset and with unsplittable slices
/// starting at a particular offset before splittable slices.
-class llvm::sroa::AllocaSlices {
+class AllocaSlices {
public:
/// Construct the slices of a particular alloca.
AllocaSlices(const DataLayout &DL, AllocaInst &AI);
@@ -563,7 +699,7 @@ private:
///
/// Objects of this type are produced by traversing the alloca's slices, but
/// are only ephemeral and not persistent.
-class llvm::sroa::Partition {
+class Partition {
private:
friend class AllocaSlices;
friend class AllocaSlices::partition_iterator;
@@ -628,6 +764,8 @@ public:
ArrayRef<Slice *> splitSliceTails() const { return SplitTails; }
};
+} // end anonymous namespace
+
/// An iterator over partitions of the alloca's slices.
///
/// This iterator implements the core algorithm for partitioning the alloca's
@@ -1144,6 +1282,7 @@ private:
}
if (II.isLaunderOrStripInvariantGroup()) {
+ insertUse(II, Offset, AllocSize, true);
enqueueUsers(II);
return;
}
@@ -1169,16 +1308,24 @@ private:
std::tie(UsedI, I) = Uses.pop_back_val();
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
- Size =
- std::max(Size, DL.getTypeStoreSize(LI->getType()).getFixedValue());
+ TypeSize LoadSize = DL.getTypeStoreSize(LI->getType());
+ if (LoadSize.isScalable()) {
+ PI.setAborted(LI);
+ return nullptr;
+ }
+ Size = std::max(Size, LoadSize.getFixedValue());
continue;
}
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
Value *Op = SI->getOperand(0);
if (Op == UsedI)
return SI;
- Size =
- std::max(Size, DL.getTypeStoreSize(Op->getType()).getFixedValue());
+ TypeSize StoreSize = DL.getTypeStoreSize(Op->getType());
+ if (StoreSize.isScalable()) {
+ PI.setAborted(SI);
+ return nullptr;
+ }
+ Size = std::max(Size, StoreSize.getFixedValue());
continue;
}
@@ -1525,38 +1672,37 @@ static void speculatePHINodeLoads(IRBuilderTy &IRB, PHINode &PN) {
PN.eraseFromParent();
}
-sroa::SelectHandSpeculativity &
-sroa::SelectHandSpeculativity::setAsSpeculatable(bool isTrueVal) {
+SelectHandSpeculativity &
+SelectHandSpeculativity::setAsSpeculatable(bool isTrueVal) {
if (isTrueVal)
- Bitfield::set<sroa::SelectHandSpeculativity::TrueVal>(Storage, true);
+ Bitfield::set<SelectHandSpeculativity::TrueVal>(Storage, true);
else
- Bitfield::set<sroa::SelectHandSpeculativity::FalseVal>(Storage, true);
+ Bitfield::set<SelectHandSpeculativity::FalseVal>(Storage, true);
return *this;
}
-bool sroa::SelectHandSpeculativity::isSpeculatable(bool isTrueVal) const {
- return isTrueVal
- ? Bitfield::get<sroa::SelectHandSpeculativity::TrueVal>(Storage)
- : Bitfield::get<sroa::SelectHandSpeculativity::FalseVal>(Storage);
+bool SelectHandSpeculativity::isSpeculatable(bool isTrueVal) const {
+ return isTrueVal ? Bitfield::get<SelectHandSpeculativity::TrueVal>(Storage)
+ : Bitfield::get<SelectHandSpeculativity::FalseVal>(Storage);
}
-bool sroa::SelectHandSpeculativity::areAllSpeculatable() const {
+bool SelectHandSpeculativity::areAllSpeculatable() const {
return isSpeculatable(/*isTrueVal=*/true) &&
isSpeculatable(/*isTrueVal=*/false);
}
-bool sroa::SelectHandSpeculativity::areAnySpeculatable() const {
+bool SelectHandSpeculativity::areAnySpeculatable() const {
return isSpeculatable(/*isTrueVal=*/true) ||
isSpeculatable(/*isTrueVal=*/false);
}
-bool sroa::SelectHandSpeculativity::areNoneSpeculatable() const {
+bool SelectHandSpeculativity::areNoneSpeculatable() const {
return !areAnySpeculatable();
}
-static sroa::SelectHandSpeculativity
+static SelectHandSpeculativity
isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) {
assert(LI.isSimple() && "Only for simple loads");
- sroa::SelectHandSpeculativity Spec;
+ SelectHandSpeculativity Spec;
const DataLayout &DL = SI.getModule()->getDataLayout();
for (Value *Value : {SI.getTrueValue(), SI.getFalseValue()})
@@ -1569,8 +1715,8 @@ isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) {
return Spec;
}
-std::optional<sroa::RewriteableMemOps>
-SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) {
+std::optional<RewriteableMemOps>
+SROA::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) {
RewriteableMemOps Ops;
for (User *U : SI.users()) {
@@ -1604,7 +1750,7 @@ SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) {
continue;
}
- sroa::SelectHandSpeculativity Spec =
+ SelectHandSpeculativity Spec =
isSafeLoadOfSelectToSpeculate(*LI, SI, PreserveCFG);
if (PreserveCFG && !Spec.areAllSpeculatable())
return {}; // Give up on this `select`.
@@ -1655,7 +1801,7 @@ static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI,
template <typename T>
static void rewriteMemOpOfSelect(SelectInst &SI, T &I,
- sroa::SelectHandSpeculativity Spec,
+ SelectHandSpeculativity Spec,
DomTreeUpdater &DTU) {
assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Only for load and store!");
LLVM_DEBUG(dbgs() << " original mem op: " << I << "\n");
@@ -1711,7 +1857,7 @@ static void rewriteMemOpOfSelect(SelectInst &SI, T &I,
}
static void rewriteMemOpOfSelect(SelectInst &SelInst, Instruction &I,
- sroa::SelectHandSpeculativity Spec,
+ SelectHandSpeculativity Spec,
DomTreeUpdater &DTU) {
if (auto *LI = dyn_cast<LoadInst>(&I))
rewriteMemOpOfSelect(SelInst, *LI, Spec, DTU);
@@ -1722,13 +1868,13 @@ static void rewriteMemOpOfSelect(SelectInst &SelInst, Instruction &I,
}
static bool rewriteSelectInstMemOps(SelectInst &SI,
- const sroa::RewriteableMemOps &Ops,
+ const RewriteableMemOps &Ops,
IRBuilderTy &IRB, DomTreeUpdater *DTU) {
bool CFGChanged = false;
LLVM_DEBUG(dbgs() << " original select: " << SI << "\n");
for (const RewriteableMemOp &Op : Ops) {
- sroa::SelectHandSpeculativity Spec;
+ SelectHandSpeculativity Spec;
Instruction *I;
if (auto *const *US = std::get_if<UnspeculatableStore>(&Op)) {
I = *US;
@@ -2421,14 +2567,15 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
return V;
}
+namespace {
+
/// Visitor to rewrite instructions using p particular slice of an alloca
/// to use a new alloca.
///
/// Also implements the rewriting to vector-based accesses when the partition
/// passes the isVectorPromotionViable predicate. Most of the rewriting logic
/// lives here.
-class llvm::sroa::AllocaSliceRewriter
- : public InstVisitor<AllocaSliceRewriter, bool> {
+class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
// Befriend the base class so it can delegate to private visit methods.
friend class InstVisitor<AllocaSliceRewriter, bool>;
@@ -2436,7 +2583,7 @@ class llvm::sroa::AllocaSliceRewriter
const DataLayout &DL;
AllocaSlices &AS;
- SROAPass &Pass;
+ SROA &Pass;
AllocaInst &OldAI, &NewAI;
const uint64_t NewAllocaBeginOffset, NewAllocaEndOffset;
Type *NewAllocaTy;
@@ -2489,12 +2636,12 @@ class llvm::sroa::AllocaSliceRewriter
if (!IsVolatile || AddrSpace == NewAI.getType()->getPointerAddressSpace())
return &NewAI;
- Type *AccessTy = NewAI.getAllocatedType()->getPointerTo(AddrSpace);
+ Type *AccessTy = IRB.getPtrTy(AddrSpace);
return IRB.CreateAddrSpaceCast(&NewAI, AccessTy);
}
public:
- AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROAPass &Pass,
+ AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROA &Pass,
AllocaInst &OldAI, AllocaInst &NewAI,
uint64_t NewAllocaBeginOffset,
uint64_t NewAllocaEndOffset, bool IsIntegerPromotable,
@@ -2697,7 +2844,7 @@ private:
NewEndOffset == NewAllocaEndOffset &&
(canConvertValue(DL, NewAllocaTy, TargetTy) ||
(IsLoadPastEnd && NewAllocaTy->isIntegerTy() &&
- TargetTy->isIntegerTy()))) {
+ TargetTy->isIntegerTy() && !LI.isVolatile()))) {
Value *NewPtr =
getPtrToNewAI(LI.getPointerAddressSpace(), LI.isVolatile());
LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), NewPtr,
@@ -2732,7 +2879,7 @@ private:
"endian_shift");
}
} else {
- Type *LTy = TargetTy->getPointerTo(AS);
+ Type *LTy = IRB.getPtrTy(AS);
LoadInst *NewLI =
IRB.CreateAlignedLoad(TargetTy, getNewAllocaSlicePtr(IRB, LTy),
getSliceAlign(), LI.isVolatile(), LI.getName());
@@ -2762,9 +2909,9 @@ private:
// basis for the new value. This allows us to replace the uses of LI with
// the computed value, and then replace the placeholder with LI, leaving
// LI only used for this computation.
- Value *Placeholder = new LoadInst(
- LI.getType(), PoisonValue::get(LI.getType()->getPointerTo(AS)), "",
- false, Align(1));
+ Value *Placeholder =
+ new LoadInst(LI.getType(), PoisonValue::get(IRB.getPtrTy(AS)), "",
+ false, Align(1));
V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset,
"insert");
LI.replaceAllUsesWith(V);
@@ -2875,26 +3022,10 @@ private:
if (IntTy && V->getType()->isIntegerTy())
return rewriteIntegerStore(V, SI, AATags);
- const bool IsStorePastEnd =
- DL.getTypeStoreSize(V->getType()).getFixedValue() > SliceSize;
StoreInst *NewSI;
if (NewBeginOffset == NewAllocaBeginOffset &&
NewEndOffset == NewAllocaEndOffset &&
- (canConvertValue(DL, V->getType(), NewAllocaTy) ||
- (IsStorePastEnd && NewAllocaTy->isIntegerTy() &&
- V->getType()->isIntegerTy()))) {
- // If this is an integer store past the end of slice (and thus the bytes
- // past that point are irrelevant or this is unreachable), truncate the
- // value prior to storing.
- if (auto *VITy = dyn_cast<IntegerType>(V->getType()))
- if (auto *AITy = dyn_cast<IntegerType>(NewAllocaTy))
- if (VITy->getBitWidth() > AITy->getBitWidth()) {
- if (DL.isBigEndian())
- V = IRB.CreateLShr(V, VITy->getBitWidth() - AITy->getBitWidth(),
- "endian_shift");
- V = IRB.CreateTrunc(V, AITy, "load.trunc");
- }
-
+ canConvertValue(DL, V->getType(), NewAllocaTy)) {
V = convertValue(DL, IRB, V, NewAllocaTy);
Value *NewPtr =
getPtrToNewAI(SI.getPointerAddressSpace(), SI.isVolatile());
@@ -2903,7 +3034,7 @@ private:
IRB.CreateAlignedStore(V, NewPtr, NewAI.getAlign(), SI.isVolatile());
} else {
unsigned AS = SI.getPointerAddressSpace();
- Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS));
+ Value *NewPtr = getNewAllocaSlicePtr(IRB, IRB.getPtrTy(AS));
NewSI =
IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(), SI.isVolatile());
}
@@ -3126,8 +3257,7 @@ private:
if (IsDest) {
// Update the address component of linked dbg.assigns.
for (auto *DAI : at::getAssignmentMarkers(&II)) {
- if (any_of(DAI->location_ops(),
- [&](Value *V) { return V == II.getDest(); }) ||
+ if (llvm::is_contained(DAI->location_ops(), II.getDest()) ||
DAI->getAddress() == II.getDest())
DAI->replaceVariableLocationOp(II.getDest(), AdjustedPtr);
}
@@ -3259,7 +3389,6 @@ private:
} else {
OtherTy = NewAllocaTy;
}
- OtherPtrTy = OtherTy->getPointerTo(OtherAS);
Value *AdjPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy,
OtherPtr->getName() + ".");
@@ -3337,7 +3466,8 @@ private:
}
bool visitIntrinsicInst(IntrinsicInst &II) {
- assert((II.isLifetimeStartOrEnd() || II.isDroppable()) &&
+ assert((II.isLifetimeStartOrEnd() || II.isLaunderOrStripInvariantGroup() ||
+ II.isDroppable()) &&
"Unexpected intrinsic!");
LLVM_DEBUG(dbgs() << " original: " << II << "\n");
@@ -3351,6 +3481,9 @@ private:
return true;
}
+ if (II.isLaunderOrStripInvariantGroup())
+ return true;
+
assert(II.getArgOperand(1) == OldPtr);
// Lifetime intrinsics are only promotable if they cover the whole alloca.
// Therefore, we drop lifetime intrinsics which don't cover the whole
@@ -3368,7 +3501,7 @@ private:
NewEndOffset - NewBeginOffset);
// Lifetime intrinsics always expect an i8* so directly get such a pointer
// for the new alloca slice.
- Type *PointerTy = IRB.getInt8PtrTy(OldPtr->getType()->getPointerAddressSpace());
+ Type *PointerTy = IRB.getPtrTy(OldPtr->getType()->getPointerAddressSpace());
Value *Ptr = getNewAllocaSlicePtr(IRB, PointerTy);
Value *New;
if (II.getIntrinsicID() == Intrinsic::lifetime_start)
@@ -3422,7 +3555,8 @@ private:
// dominate the PHI.
IRBuilderBase::InsertPointGuard Guard(IRB);
if (isa<PHINode>(OldPtr))
- IRB.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt());
+ IRB.SetInsertPoint(OldPtr->getParent(),
+ OldPtr->getParent()->getFirstInsertionPt());
else
IRB.SetInsertPoint(OldPtr);
IRB.SetCurrentDebugLocation(OldPtr->getDebugLoc());
@@ -3472,8 +3606,6 @@ private:
}
};
-namespace {
-
/// Visitor to rewrite aggregate loads and stores as scalar.
///
/// This pass aggressively rewrites all aggregate loads and stores on
@@ -3811,7 +3943,7 @@ private:
SmallVector<Value *, 4> Index(GEPI.indices());
bool IsInBounds = GEPI.isInBounds();
- IRB.SetInsertPoint(GEPI.getParent()->getFirstNonPHI());
+ IRB.SetInsertPoint(GEPI.getParent(), GEPI.getParent()->getFirstNonPHIIt());
PHINode *NewPN = IRB.CreatePHI(GEPI.getType(), PHI->getNumIncomingValues(),
PHI->getName() + ".sroa.phi");
for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
@@ -4046,7 +4178,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
/// there all along.
///
/// \returns true if any changes are made.
-bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
+bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
LLVM_DEBUG(dbgs() << "Pre-splitting loads and stores\n");
// Track the loads and stores which are candidates for pre-splitting here, in
@@ -4268,7 +4400,7 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
for (;;) {
auto *PartTy = Type::getIntNTy(LI->getContext(), PartSize * 8);
auto AS = LI->getPointerAddressSpace();
- auto *PartPtrTy = PartTy->getPointerTo(AS);
+ auto *PartPtrTy = LI->getPointerOperandType();
LoadInst *PLoad = IRB.CreateAlignedLoad(
PartTy,
getAdjustedPtr(IRB, DL, BasePtr,
@@ -4323,8 +4455,7 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
for (int Idx = 0, Size = SplitLoads.size(); Idx < Size; ++Idx) {
LoadInst *PLoad = SplitLoads[Idx];
uint64_t PartOffset = Idx == 0 ? 0 : Offsets.Splits[Idx - 1];
- auto *PartPtrTy =
- PLoad->getType()->getPointerTo(SI->getPointerAddressSpace());
+ auto *PartPtrTy = SI->getPointerOperandType();
auto AS = SI->getPointerAddressSpace();
StoreInst *PStore = IRB.CreateAlignedStore(
@@ -4404,8 +4535,8 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
int Idx = 0, Size = Offsets.Splits.size();
for (;;) {
auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8);
- auto *LoadPartPtrTy = PartTy->getPointerTo(LI->getPointerAddressSpace());
- auto *StorePartPtrTy = PartTy->getPointerTo(SI->getPointerAddressSpace());
+ auto *LoadPartPtrTy = LI->getPointerOperandType();
+ auto *StorePartPtrTy = SI->getPointerOperandType();
// Either lookup a split load or create one.
LoadInst *PLoad;
@@ -4526,8 +4657,8 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
/// appropriate new offsets. It also evaluates how successful the rewrite was
/// at enabling promotion and if it was successful queues the alloca to be
/// promoted.
-AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
- Partition &P) {
+AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
+ Partition &P) {
// Try to compute a friendly type for this partition of the alloca. This
// won't always succeed, in which case we fall back to a legal integer type
// or an i8 array of an appropriate size.
@@ -4709,7 +4840,7 @@ AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
/// Walks the slices of an alloca and form partitions based on them,
/// rewriting each of their uses.
-bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
+bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
if (AS.begin() == AS.end())
return false;
@@ -4900,7 +5031,7 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
}
/// Clobber a use with poison, deleting the used value if it becomes dead.
-void SROAPass::clobberUse(Use &U) {
+void SROA::clobberUse(Use &U) {
Value *OldV = U;
// Replace the use with an poison value.
U = PoisonValue::get(OldV->getType());
@@ -4920,7 +5051,7 @@ void SROAPass::clobberUse(Use &U) {
/// the slices of the alloca, and then hands it off to be split and
/// rewritten as needed.
std::pair<bool /*Changed*/, bool /*CFGChanged*/>
-SROAPass::runOnAlloca(AllocaInst &AI) {
+SROA::runOnAlloca(AllocaInst &AI) {
bool Changed = false;
bool CFGChanged = false;
@@ -5002,7 +5133,7 @@ SROAPass::runOnAlloca(AllocaInst &AI) {
///
/// We also record the alloca instructions deleted here so that they aren't
/// subsequently handed to mem2reg to promote.
-bool SROAPass::deleteDeadInstructions(
+bool SROA::deleteDeadInstructions(
SmallPtrSetImpl<AllocaInst *> &DeletedAllocas) {
bool Changed = false;
while (!DeadInsts.empty()) {
@@ -5043,7 +5174,7 @@ bool SROAPass::deleteDeadInstructions(
/// This attempts to promote whatever allocas have been identified as viable in
/// the PromotableAllocas list. If that list is empty, there is nothing to do.
/// This function returns whether any promotion occurred.
-bool SROAPass::promoteAllocas(Function &F) {
+bool SROA::promoteAllocas(Function &F) {
if (PromotableAllocas.empty())
return false;
@@ -5060,12 +5191,8 @@ bool SROAPass::promoteAllocas(Function &F) {
return true;
}
-PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU,
- AssumptionCache &RunAC) {
+std::pair<bool /*Changed*/, bool /*CFGChanged*/> SROA::runSROA(Function &F) {
LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n");
- C = &F.getContext();
- DTU = &RunDTU;
- AC = &RunAC;
const DataLayout &DL = F.getParent()->getDataLayout();
BasicBlock &EntryBB = F.getEntryBlock();
@@ -5116,56 +5243,50 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU,
assert((!CFGChanged || !PreserveCFG) &&
"Should not have modified the CFG when told to preserve it.");
- if (!Changed)
- return PreservedAnalyses::all();
-
- if (isAssignmentTrackingEnabled(*F.getParent())) {
+ if (Changed && isAssignmentTrackingEnabled(*F.getParent())) {
for (auto &BB : F)
RemoveRedundantDbgInstrs(&BB);
}
- PreservedAnalyses PA;
- if (!CFGChanged)
- PA.preserveSet<CFGAnalyses>();
- PA.preserve<DominatorTreeAnalysis>();
- return PA;
-}
-
-PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT,
- AssumptionCache &RunAC) {
- DomTreeUpdater DTU(RunDT, DomTreeUpdater::UpdateStrategy::Lazy);
- return runImpl(F, DTU, RunAC);
+ return {Changed, CFGChanged};
}
PreservedAnalyses SROAPass::run(Function &F, FunctionAnalysisManager &AM) {
DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
- return runImpl(F, DT, AC);
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+ auto [Changed, CFGChanged] =
+ SROA(&F.getContext(), &DTU, &AC, PreserveCFG).runSROA(F);
+ if (!Changed)
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ if (!CFGChanged)
+ PA.preserveSet<CFGAnalyses>();
+ PA.preserve<DominatorTreeAnalysis>();
+ return PA;
}
void SROAPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<SROAPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << (PreserveCFG ? "<preserve-cfg>" : "<modify-cfg>");
+ OS << (PreserveCFG == SROAOptions::PreserveCFG ? "<preserve-cfg>"
+ : "<modify-cfg>");
}
-SROAPass::SROAPass(SROAOptions PreserveCFG_)
- : PreserveCFG(PreserveCFG_ == SROAOptions::PreserveCFG) {}
+SROAPass::SROAPass(SROAOptions PreserveCFG) : PreserveCFG(PreserveCFG) {}
+
+namespace {
/// A legacy pass for the legacy pass manager that wraps the \c SROA pass.
-///
-/// This is in the llvm namespace purely to allow it to be a friend of the \c
-/// SROA pass.
-class llvm::sroa::SROALegacyPass : public FunctionPass {
- /// The SROA implementation.
- SROAPass Impl;
+class SROALegacyPass : public FunctionPass {
+ SROAOptions PreserveCFG;
public:
static char ID;
SROALegacyPass(SROAOptions PreserveCFG = SROAOptions::PreserveCFG)
- : FunctionPass(ID), Impl(PreserveCFG) {
+ : FunctionPass(ID), PreserveCFG(PreserveCFG) {
initializeSROALegacyPassPass(*PassRegistry::getPassRegistry());
}
@@ -5173,10 +5294,13 @@ public:
if (skipFunction(F))
return false;
- auto PA = Impl.runImpl(
- F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
- getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F));
- return !PA.areAllPreserved();
+ DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ AssumptionCache &AC =
+ getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+ auto [Changed, _] =
+ SROA(&F.getContext(), &DTU, &AC, PreserveCFG).runSROA(F);
+ return Changed;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -5189,6 +5313,8 @@ public:
StringRef getPassName() const override { return "SROA"; }
};
+} // end anonymous namespace
+
char SROALegacyPass::ID = 0;
FunctionPass *llvm::createSROAPass(bool PreserveCFG) {
diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
index 37b032e4d7c7..4ce6ce93be33 100644
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -21,41 +21,27 @@ using namespace llvm;
void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeConstantHoistingLegacyPassPass(Registry);
initializeDCELegacyPassPass(Registry);
- initializeScalarizerLegacyPassPass(Registry);
- initializeGuardWideningLegacyPassPass(Registry);
- initializeLoopGuardWideningLegacyPassPass(Registry);
initializeGVNLegacyPassPass(Registry);
initializeEarlyCSELegacyPassPass(Registry);
initializeEarlyCSEMemSSALegacyPassPass(Registry);
- initializeMakeGuardsExplicitLegacyPassPass(Registry);
initializeFlattenCFGLegacyPassPass(Registry);
initializeInferAddressSpacesPass(Registry);
initializeInstSimplifyLegacyPassPass(Registry);
initializeLegacyLICMPassPass(Registry);
- initializeLegacyLoopSinkPassPass(Registry);
initializeLoopDataPrefetchLegacyPassPass(Registry);
- initializeLoopInstSimplifyLegacyPassPass(Registry);
- initializeLoopPredicationLegacyPassPass(Registry);
initializeLoopRotateLegacyPassPass(Registry);
initializeLoopStrengthReducePass(Registry);
initializeLoopUnrollPass(Registry);
initializeLowerAtomicLegacyPassPass(Registry);
initializeLowerConstantIntrinsicsPass(Registry);
- initializeLowerExpectIntrinsicPass(Registry);
- initializeLowerGuardIntrinsicLegacyPassPass(Registry);
- initializeLowerWidenableConditionLegacyPassPass(Registry);
initializeMergeICmpsLegacyPassPass(Registry);
- initializeMergedLoadStoreMotionLegacyPassPass(Registry);
initializeNaryReassociateLegacyPassPass(Registry);
initializePartiallyInlineLibCallsLegacyPassPass(Registry);
initializeReassociateLegacyPassPass(Registry);
- initializeRedundantDbgInstEliminationPass(Registry);
- initializeRegToMemLegacyPass(Registry);
initializeScalarizeMaskedMemIntrinLegacyPassPass(Registry);
initializeSROALegacyPassPass(Registry);
initializeCFGSimplifyPassPass(Registry);
initializeStructurizeCFGLegacyPassPass(Registry);
- initializeSimpleLoopUnswitchLegacyPassPass(Registry);
initializeSinkingLegacyPassPass(Registry);
initializeTailCallElimPass(Registry);
initializeTLSVariableHoistLegacyPassPass(Registry);
@@ -63,5 +49,4 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
initializeSpeculativeExecutionLegacyPassPass(Registry);
initializeStraightLineStrengthReduceLegacyPassPass(Registry);
initializePlaceBackedgeSafepointsLegacyPassPass(Registry);
- initializeLoopSimplifyCFGLegacyPassPass(Registry);
}
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 86b55dfd304a..3eca9ac7c267 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -36,8 +36,6 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -282,12 +280,10 @@ T getWithDefaultOverride(const cl::opt<T> &ClOption,
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
- ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT,
- ScalarizerPassOptions Options)
- : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT),
- ScalarizeVariableInsertExtract(
- getWithDefaultOverride(ClScalarizeVariableInsertExtract,
- Options.ScalarizeVariableInsertExtract)),
+ ScalarizerVisitor(DominatorTree *DT, ScalarizerPassOptions Options)
+ : DT(DT), ScalarizeVariableInsertExtract(getWithDefaultOverride(
+ ClScalarizeVariableInsertExtract,
+ Options.ScalarizeVariableInsertExtract)),
ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore,
Options.ScalarizeLoadStore)),
ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits,
@@ -337,8 +333,6 @@ private:
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
- unsigned ParallelLoopAccessMDKind;
-
DominatorTree *DT;
const bool ScalarizeVariableInsertExtract;
@@ -346,31 +340,8 @@ private:
const unsigned ScalarizeMinBits;
};
-class ScalarizerLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- ScalarizerLegacyPass() : FunctionPass(ID) {
- initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override;
-
- void getAnalysisUsage(AnalysisUsage& AU) const override {
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- }
-};
-
} // end anonymous namespace
-char ScalarizerLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
- "Scalarize vector operations", false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
- "Scalarize vector operations", false, false)
-
Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
const VectorSplit &VS, ValueVector *cachePtr)
: BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
@@ -443,22 +414,6 @@ Value *Scatterer::operator[](unsigned Frag) {
return CV[Frag];
}
-bool ScalarizerLegacyPass::runOnFunction(Function &F) {
- if (skipFunction(F))
- return false;
-
- Module &M = *F.getParent();
- unsigned ParallelLoopAccessMDKind =
- M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
- DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, ScalarizerPassOptions());
- return Impl.visit(F);
-}
-
-FunctionPass *llvm::createScalarizerPass() {
- return new ScalarizerLegacyPass();
-}
-
bool ScalarizerVisitor::visit(Function &F) {
assert(Gathered.empty() && Scattered.empty());
@@ -558,7 +513,7 @@ bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
|| Tag == LLVMContext::MD_invariant_load
|| Tag == LLVMContext::MD_alias_scope
|| Tag == LLVMContext::MD_noalias
- || Tag == ParallelLoopAccessMDKind
+ || Tag == LLVMContext::MD_mem_parallel_loop_access
|| Tag == LLVMContext::MD_access_group);
}
@@ -730,7 +685,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
// vector type, which is true for all current intrinsics.
for (unsigned I = 0; I != NumArgs; ++I) {
Value *OpI = CI.getOperand(I);
- if (auto *OpVecTy = dyn_cast<FixedVectorType>(OpI->getType())) {
+ if ([[maybe_unused]] auto *OpVecTy =
+ dyn_cast<FixedVectorType>(OpI->getType())) {
assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements());
std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType());
if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
@@ -1253,11 +1209,8 @@ bool ScalarizerVisitor::finish() {
}
PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
- Module &M = *F.getParent();
- unsigned ParallelLoopAccessMDKind =
- M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
- ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, Options);
+ ScalarizerVisitor Impl(DT, Options);
bool Changed = Impl.visit(F);
PreservedAnalyses PA;
PA.preserve<DominatorTreeAnalysis>();
diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index 89d0b7c33e0d..b8c9d9d100f1 100644
--- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -524,7 +524,7 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
// FIXME: this does not appear to be covered by any tests
// (with x86/aarch64 backends at least)
if (BO->getOpcode() == Instruction::Or &&
- !haveNoCommonBitsSet(LHS, RHS, DL, nullptr, BO, DT))
+ !haveNoCommonBitsSet(LHS, RHS, SimplifyQuery(DL, DT, /*AC*/ nullptr, BO)))
return false;
// FIXME: We don't currently support constants from the RHS of subs,
@@ -661,15 +661,16 @@ Value *ConstantOffsetExtractor::applyExts(Value *V) {
// in the reversed order.
for (CastInst *I : llvm::reverse(ExtInsts)) {
if (Constant *C = dyn_cast<Constant>(Current)) {
- // If Current is a constant, apply s/zext using ConstantExpr::getCast.
- // ConstantExpr::getCast emits a ConstantInt if C is a ConstantInt.
- Current = ConstantExpr::getCast(I->getOpcode(), C, I->getType());
- } else {
- Instruction *Ext = I->clone();
- Ext->setOperand(0, Current);
- Ext->insertBefore(IP);
- Current = Ext;
+ // Try to constant fold the cast.
+ Current = ConstantFoldCastOperand(I->getOpcode(), C, I->getType(), DL);
+ if (Current)
+ continue;
}
+
+ Instruction *Ext = I->clone();
+ Ext->setOperand(0, Current);
+ Ext->insertBefore(IP);
+ Current = Ext;
}
return Current;
}
@@ -830,7 +831,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP,
for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
if (GTI.isSequential()) {
// Constant offsets of scalable types are not really constant.
- if (isa<ScalableVectorType>(GTI.getIndexedType()))
+ if (GTI.getIndexedType()->isScalableTy())
continue;
// Tries to extract a constant offset from this GEP index.
@@ -1019,7 +1020,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
if (GTI.isSequential()) {
// Constant offsets of scalable types are not really constant.
- if (isa<ScalableVectorType>(GTI.getIndexedType()))
+ if (GTI.getIndexedType()->isScalableTy())
continue;
// Splits this GEP index into a variadic part and a constant offset, and
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index ad7d34b61470..7eb0ba1c2c17 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -24,7 +24,6 @@
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopIterator.h"
-#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/MustExecute.h"
@@ -46,8 +45,6 @@
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -368,10 +365,11 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB,
bool FullUnswitch) {
assert(&ExitBB != &UnswitchedBB &&
"Must have different loop exit and unswitched blocks!");
- Instruction *InsertPt = &*UnswitchedBB.begin();
+ BasicBlock::iterator InsertPt = UnswitchedBB.begin();
for (PHINode &PN : ExitBB.phis()) {
auto *NewPN = PHINode::Create(PN.getType(), /*NumReservedValues*/ 2,
- PN.getName() + ".split", InsertPt);
+ PN.getName() + ".split");
+ NewPN->insertBefore(InsertPt);
// Walk backwards over the old PHI node's inputs to minimize the cost of
// removing each one. We have to do this weird loop manually so that we
@@ -609,7 +607,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
UnswitchedBB = LoopExitBB;
} else {
UnswitchedBB =
- SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI, MSSAU);
+ SplitBlock(LoopExitBB, LoopExitBB->begin(), &DT, &LI, MSSAU, "", false);
}
if (MSSAU && VerifyMemorySSA)
@@ -623,7 +621,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
// If fully unswitching, we can use the existing branch instruction.
// Splice it into the old PH to gate reaching the new preheader and re-point
// its successors.
- OldPH->splice(OldPH->end(), BI.getParent(), BI.getIterator());
+ BI.moveBefore(*OldPH, OldPH->end());
BI.setCondition(Cond);
if (MSSAU) {
// Temporarily clone the terminator, to make MSSA update cheaper by
@@ -882,7 +880,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH);
} else {
auto *SplitBB =
- SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI, MSSAU);
+ SplitBlock(DefaultExitBB, DefaultExitBB->begin(), &DT, &LI, MSSAU);
rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB,
*ParentBB, *OldPH,
/*FullUnswitch*/ true);
@@ -909,7 +907,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
BasicBlock *&SplitExitBB = SplitExitBBMap[ExitBB];
if (!SplitExitBB) {
// If this is the first time we see this, do the split and remember it.
- SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSAU);
+ SplitExitBB = SplitBlock(ExitBB, ExitBB->begin(), &DT, &LI, MSSAU);
rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB,
*ParentBB, *OldPH,
/*FullUnswitch*/ true);
@@ -1210,7 +1208,7 @@ static BasicBlock *buildClonedLoopBlocks(
// place to merge the CFG, so split the exit first. This is always safe to
// do because there cannot be any non-loop predecessors of a loop exit in
// loop simplified form.
- auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSAU);
+ auto *MergeBB = SplitBlock(ExitBB, ExitBB->begin(), &DT, &LI, MSSAU);
// Rearrange the names to make it easier to write test cases by having the
// exit block carry the suffix rather than the merge block carrying the
@@ -1246,8 +1244,8 @@ static BasicBlock *buildClonedLoopBlocks(
SE->forgetValue(&I);
auto *MergePN =
- PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi",
- &*MergeBB->getFirstInsertionPt());
+ PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi");
+ MergePN->insertBefore(MergeBB->getFirstInsertionPt());
I.replaceAllUsesWith(MergePN);
MergePN->addIncoming(&I, ExitBB);
MergePN->addIncoming(&ClonedI, ClonedExitBB);
@@ -1259,8 +1257,11 @@ static BasicBlock *buildClonedLoopBlocks(
// everything available. Also, we have inserted new instructions which may
// include assume intrinsics, so we update the assumption cache while
// processing this.
+ Module *M = ClonedPH->getParent()->getParent();
for (auto *ClonedBB : NewBlocks)
for (Instruction &I : *ClonedBB) {
+ RemapDPValueRange(M, I.getDbgValueRange(), VMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
RemapInstruction(&I, VMap,
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
if (auto *II = dyn_cast<AssumeInst>(&I))
@@ -1684,13 +1685,12 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
BB->eraseFromParent();
}
-static void
-deleteDeadBlocksFromLoop(Loop &L,
- SmallVectorImpl<BasicBlock *> &ExitBlocks,
- DominatorTree &DT, LoopInfo &LI,
- MemorySSAUpdater *MSSAU,
- ScalarEvolution *SE,
- function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static void deleteDeadBlocksFromLoop(Loop &L,
+ SmallVectorImpl<BasicBlock *> &ExitBlocks,
+ DominatorTree &DT, LoopInfo &LI,
+ MemorySSAUpdater *MSSAU,
+ ScalarEvolution *SE,
+ LPMUpdater &LoopUpdater) {
// Find all the dead blocks tied to this loop, and remove them from their
// successors.
SmallSetVector<BasicBlock *, 8> DeadBlockSet;
@@ -1740,7 +1740,7 @@ deleteDeadBlocksFromLoop(Loop &L,
}) &&
"If the child loop header is dead all blocks in the child loop must "
"be dead as well!");
- DestroyLoopCB(*ChildL, ChildL->getName());
+ LoopUpdater.markLoopAsDeleted(*ChildL, ChildL->getName());
if (SE)
SE->forgetBlockAndLoopDispositions();
LI.destroy(ChildL);
@@ -2084,8 +2084,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
ParentL->removeChildLoop(llvm::find(*ParentL, &L));
else
LI.removeLoop(llvm::find(LI, &L));
- // markLoopAsDeleted for L should be triggered by the caller (it is typically
- // done by using the UnswitchCB callback).
+ // markLoopAsDeleted for L should be triggered by the caller (it is
+ // typically done within postUnswitch).
if (SE)
SE->forgetBlockAndLoopDispositions();
LI.destroy(&L);
@@ -2122,17 +2122,56 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
} while (!DomWorklist.empty());
}
+void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName,
+ bool CurrentLoopValid, bool PartiallyInvariant,
+ bool InjectedCondition, ArrayRef<Loop *> NewLoops) {
+ // If we did a non-trivial unswitch, we have added new (cloned) loops.
+ if (!NewLoops.empty())
+ U.addSiblingLoops(NewLoops);
+
+ // If the current loop remains valid, we should revisit it to catch any
+ // other unswitch opportunities. Otherwise, we need to mark it as deleted.
+ if (CurrentLoopValid) {
+ if (PartiallyInvariant) {
+ // Mark the new loop as partially unswitched, to avoid unswitching on
+ // the same condition again.
+ auto &Context = L.getHeader()->getContext();
+ MDNode *DisableUnswitchMD = MDNode::get(
+ Context,
+ MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
+ MDNode *NewLoopID = makePostTransformationMetadata(
+ Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
+ {DisableUnswitchMD});
+ L.setLoopID(NewLoopID);
+ } else if (InjectedCondition) {
+ // Do the same for injection of invariant conditions.
+ auto &Context = L.getHeader()->getContext();
+ MDNode *DisableUnswitchMD = MDNode::get(
+ Context,
+ MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
+ MDNode *NewLoopID = makePostTransformationMetadata(
+ Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
+ {DisableUnswitchMD});
+ L.setLoopID(NewLoopID);
+ } else
+ U.revisitCurrentLoop();
+ } else
+ U.markLoopAsDeleted(L, LoopName);
+}
+
static void unswitchNontrivialInvariants(
Loop &L, Instruction &TI, ArrayRef<Value *> Invariants,
IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI,
- AssumptionCache &AC,
- function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
- ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
- function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze) {
+ AssumptionCache &AC, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
+ LPMUpdater &LoopUpdater, bool InsertFreeze, bool InjectedCondition) {
auto *ParentBB = TI.getParent();
BranchInst *BI = dyn_cast<BranchInst>(&TI);
SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
+ // Save the current loop name in a variable so that we can report it even
+ // after it has been deleted.
+ std::string LoopName(L.getName());
+
// We can only unswitch switches, conditional branches with an invariant
// condition, or combining invariant conditions with an instruction or
// partially invariant instructions.
@@ -2295,7 +2334,7 @@ static void unswitchNontrivialInvariants(
if (FullUnswitch) {
// Splice the terminator from the original loop and rewrite its
// successors.
- SplitBB->splice(SplitBB->end(), ParentBB, TI.getIterator());
+ TI.moveBefore(*SplitBB, SplitBB->end());
// Keep a clone of the terminator for MSSA updates.
Instruction *NewTI = TI.clone();
@@ -2445,7 +2484,7 @@ static void unswitchNontrivialInvariants(
// Now that our cloned loops have been built, we can update the original loop.
// First we delete the dead blocks from it and then we rebuild the loop
// structure taking these deletions into account.
- deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB);
+ deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE, LoopUpdater);
if (MSSAU && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -2581,7 +2620,8 @@ static void unswitchNontrivialInvariants(
for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
if (UpdatedL->getParentLoop() == ParentL)
SibLoops.push_back(UpdatedL);
- UnswitchCB(IsStillLoop, PartiallyInvariant, SibLoops);
+ postUnswitch(L, LoopUpdater, LoopName, IsStillLoop, PartiallyInvariant,
+ InjectedCondition, SibLoops);
if (MSSAU && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -2979,13 +3019,6 @@ static bool shouldTryInjectInvariantCondition(
/// the metadata.
bool shouldTryInjectBasingOnMetadata(const BranchInst *BI,
const BasicBlock *TakenSucc) {
- // Skip branches that have already been unswithed this way. After successful
- // unswitching of injected condition, we will still have a copy of this loop
- // which looks exactly the same as original one. To prevent the 2nd attempt
- // of unswitching it in the same pass, mark this branch as "nothing to do
- // here".
- if (BI->hasMetadata("llvm.invariant.condition.injection.disabled"))
- return false;
SmallVector<uint32_t> Weights;
if (!extractBranchWeights(*BI, Weights))
return false;
@@ -3060,7 +3093,6 @@ injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L,
auto *InjectedCond =
ICmpInst::Create(Instruction::ICmp, Pred, LHS, RHS, "injected.cond",
Preheader->getTerminator());
- auto *OldCond = TI->getCondition();
BasicBlock *CheckBlock = BasicBlock::Create(Ctx, BB->getName() + ".check",
BB->getParent(), InLoopSucc);
@@ -3069,12 +3101,9 @@ injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L,
Builder.CreateCondBr(InjectedCond, InLoopSucc, CheckBlock);
Builder.SetInsertPoint(CheckBlock);
- auto *NewTerm = Builder.CreateCondBr(OldCond, InLoopSucc, OutOfLoopSucc);
-
+ Builder.CreateCondBr(TI->getCondition(), TI->getSuccessor(0),
+ TI->getSuccessor(1));
TI->eraseFromParent();
- // Prevent infinite unswitching.
- NewTerm->setMetadata("llvm.invariant.condition.injection.disabled",
- MDNode::get(BB->getContext(), {}));
// Fixup phis.
for (auto &I : *InLoopSucc) {
@@ -3439,12 +3468,11 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT,
Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT);
}
-static bool unswitchBestCondition(
- Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
- AAResults &AA, TargetTransformInfo &TTI,
- function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
- ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
- function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
+ AssumptionCache &AC, AAResults &AA,
+ TargetTransformInfo &TTI, ScalarEvolution *SE,
+ MemorySSAUpdater *MSSAU,
+ LPMUpdater &LoopUpdater) {
// Collect all invariant conditions within this loop (as opposed to an inner
// loop which would be handled when visiting that inner loop).
SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates;
@@ -3452,9 +3480,10 @@ static bool unswitchBestCondition(
Instruction *PartialIVCondBranch = nullptr;
collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
PartialIVCondBranch, L, LI, AA, MSSAU);
- collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo,
- PartialIVCondBranch, L, DT, LI, AA,
- MSSAU);
+ if (!findOptionMDForLoop(&L, "llvm.loop.unswitch.injection.disable"))
+ collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo,
+ PartialIVCondBranch, L, DT, LI, AA,
+ MSSAU);
// If we didn't find any candidates, we're done.
if (UnswitchCandidates.empty())
return false;
@@ -3475,8 +3504,11 @@ static bool unswitchBestCondition(
return false;
}
- if (Best.hasPendingInjection())
+ bool InjectedCondition = false;
+ if (Best.hasPendingInjection()) {
Best = injectPendingInvariantConditions(Best, L, DT, LI, AC, MSSAU);
+ InjectedCondition = true;
+ }
assert(!Best.hasPendingInjection() &&
"All injections should have been done by now!");
@@ -3503,8 +3535,8 @@ static bool unswitchBestCondition(
LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << Best.Cost
<< ") terminator: " << *Best.TI << "\n");
unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT,
- LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB,
- InsertFreeze);
+ LI, AC, SE, MSSAU, LoopUpdater, InsertFreeze,
+ InjectedCondition);
return true;
}
@@ -3523,20 +3555,18 @@ static bool unswitchBestCondition(
/// true, we will attempt to do non-trivial unswitching as well as trivial
/// unswitching.
///
-/// The `UnswitchCB` callback provided will be run after unswitching is
-/// complete, with the first parameter set to `true` if the provided loop
-/// remains a loop, and a list of new sibling loops created.
+/// The `postUnswitch` function will be run after unswitching is complete
+/// with information on whether or not the provided loop remains a loop and
+/// a list of new sibling loops created.
///
/// If `SE` is non-null, we will update that analysis based on the unswitching
/// done.
-static bool
-unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
- AAResults &AA, TargetTransformInfo &TTI, bool Trivial,
- bool NonTrivial,
- function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
- ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
- ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
- function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
+ AssumptionCache &AC, AAResults &AA,
+ TargetTransformInfo &TTI, bool Trivial,
+ bool NonTrivial, ScalarEvolution *SE,
+ MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI,
+ BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) {
assert(L.isRecursivelyLCSSAForm(DT, LI) &&
"Loops must be in LCSSA form before unswitching.");
@@ -3548,7 +3578,9 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) {
// If we unswitched successfully we will want to clean up the loop before
// processing it further so just mark it as unswitched and return.
- UnswitchCB(/*CurrentLoopValid*/ true, false, {});
+ postUnswitch(L, LoopUpdater, L.getName(),
+ /*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false,
+ /*InjectedCondition*/ false, {});
return true;
}
@@ -3617,8 +3649,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
// Try to unswitch the best invariant condition. We prefer this full unswitch to
// a partial unswitch when possible below the threshold.
- if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU,
- DestroyLoopCB))
+ if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, SE, MSSAU, LoopUpdater))
return true;
// No other opportunities to unswitch.
@@ -3638,41 +3669,6 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L
<< "\n");
- // Save the current loop name in a variable so that we can report it even
- // after it has been deleted.
- std::string LoopName = std::string(L.getName());
-
- auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,
- bool PartiallyInvariant,
- ArrayRef<Loop *> NewLoops) {
- // If we did a non-trivial unswitch, we have added new (cloned) loops.
- if (!NewLoops.empty())
- U.addSiblingLoops(NewLoops);
-
- // If the current loop remains valid, we should revisit it to catch any
- // other unswitch opportunities. Otherwise, we need to mark it as deleted.
- if (CurrentLoopValid) {
- if (PartiallyInvariant) {
- // Mark the new loop as partially unswitched, to avoid unswitching on
- // the same condition again.
- auto &Context = L.getHeader()->getContext();
- MDNode *DisableUnswitchMD = MDNode::get(
- Context,
- MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
- MDNode *NewLoopID = makePostTransformationMetadata(
- Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
- {DisableUnswitchMD});
- L.setLoopID(NewLoopID);
- } else
- U.revisitCurrentLoop();
- } else
- U.markLoopAsDeleted(L, LoopName);
- };
-
- auto DestroyLoopCB = [&U](Loop &L, StringRef Name) {
- U.markLoopAsDeleted(L, Name);
- };
-
std::optional<MemorySSAUpdater> MSSAU;
if (AR.MSSA) {
MSSAU = MemorySSAUpdater(AR.MSSA);
@@ -3680,8 +3676,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
AR.MSSA->verifyMemorySSA();
}
if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial,
- UnswitchCB, &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI,
- DestroyLoopCB))
+ &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, U))
return PreservedAnalyses::all();
if (AR.MSSA && VerifyMemorySSA)
@@ -3707,104 +3702,3 @@ void SimpleLoopUnswitchPass::printPipeline(
OS << (Trivial ? "" : "no-") << "trivial";
OS << '>';
}
-
-namespace {
-
-class SimpleLoopUnswitchLegacyPass : public LoopPass {
- bool NonTrivial;
-
-public:
- static char ID; // Pass ID, replacement for typeid
-
- explicit SimpleLoopUnswitchLegacyPass(bool NonTrivial = false)
- : LoopPass(ID), NonTrivial(NonTrivial) {
- initializeSimpleLoopUnswitchLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnLoop(Loop *L, LPPassManager &LPM) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<MemorySSAWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- getLoopAnalysisUsage(AU);
- }
-};
-
-} // end anonymous namespace
-
-bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
- if (skipLoop(L))
- return false;
-
- Function &F = *L->getHeader()->getParent();
-
- LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L
- << "\n");
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
- MemorySSAUpdater MSSAU(MSSA);
-
- auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
- auto *SE = SEWP ? &SEWP->getSE() : nullptr;
-
- auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, bool PartiallyInvariant,
- ArrayRef<Loop *> NewLoops) {
- // If we did a non-trivial unswitch, we have added new (cloned) loops.
- for (auto *NewL : NewLoops)
- LPM.addLoop(*NewL);
-
- // If the current loop remains valid, re-add it to the queue. This is
- // a little wasteful as we'll finish processing the current loop as well,
- // but it is the best we can do in the old PM.
- if (CurrentLoopValid) {
- // If the current loop has been unswitched using a partially invariant
- // condition, we should not re-add the current loop to avoid unswitching
- // on the same condition again.
- if (!PartiallyInvariant)
- LPM.addLoop(*L);
- } else
- LPM.markLoopAsDeleted(*L);
- };
-
- auto DestroyLoopCB = [&LPM](Loop &L, StringRef /* Name */) {
- LPM.markLoopAsDeleted(L);
- };
-
- if (VerifyMemorySSA)
- MSSA->verifyMemorySSA();
- bool Changed =
- unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial, UnswitchCB, SE,
- &MSSAU, nullptr, nullptr, DestroyLoopCB);
-
- if (VerifyMemorySSA)
- MSSA->verifyMemorySSA();
-
- // Historically this pass has had issues with the dominator tree so verify it
- // in asserts builds.
- assert(DT.verify(DominatorTree::VerificationLevel::Fast));
-
- return Changed;
-}
-
-char SimpleLoopUnswitchLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch",
- "Simple unswitch loops", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch",
- "Simple unswitch loops", false, false)
-
-Pass *llvm::createSimpleLoopUnswitchLegacyPass(bool NonTrivial) {
- return new SimpleLoopUnswitchLegacyPass(NonTrivial);
-}
diff --git a/llvm/lib/Transforms/Scalar/Sink.cpp b/llvm/lib/Transforms/Scalar/Sink.cpp
index 8b99f73b850b..46bcfd6b41ce 100644
--- a/llvm/lib/Transforms/Scalar/Sink.cpp
+++ b/llvm/lib/Transforms/Scalar/Sink.cpp
@@ -67,9 +67,8 @@ static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo,
assert(Inst && "Instruction to be sunk is null");
assert(SuccToSinkTo && "Candidate sink target is null");
- // It's never legal to sink an instruction into a block which terminates in an
- // EH-pad.
- if (SuccToSinkTo->getTerminator()->isExceptionalTerminator())
+ // It's never legal to sink an instruction into an EH-pad block.
+ if (SuccToSinkTo->isEHPad())
return false;
// If the block has multiple predecessors, this would introduce computation
@@ -131,15 +130,16 @@ static bool SinkInstruction(Instruction *Inst,
for (Use &U : Inst->uses()) {
Instruction *UseInst = cast<Instruction>(U.getUser());
BasicBlock *UseBlock = UseInst->getParent();
- // Don't worry about dead users.
- if (!DT.isReachableFromEntry(UseBlock))
- continue;
if (PHINode *PN = dyn_cast<PHINode>(UseInst)) {
// PHI nodes use the operand in the predecessor block, not the block with
// the PHI.
unsigned Num = PHINode::getIncomingValueNumForOperand(U.getOperandNo());
UseBlock = PN->getIncomingBlock(Num);
}
+ // Don't worry about dead users.
+ if (!DT.isReachableFromEntry(UseBlock))
+ continue;
+
if (SuccToSinkTo)
SuccToSinkTo = DT.findNearestCommonDominator(SuccToSinkTo, UseBlock);
else
diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
index e866fe681127..7a5318d4404c 100644
--- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
+++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
@@ -316,7 +316,7 @@ bool SpeculativeExecutionPass::considerHoistingFromTo(
auto Current = I;
++I;
if (!NotHoisted.count(&*Current)) {
- Current->moveBefore(ToBlock.getTerminator());
+ Current->moveBeforePreserving(ToBlock.getTerminator());
}
}
return true;
@@ -346,4 +346,14 @@ PreservedAnalyses SpeculativeExecutionPass::run(Function &F,
PA.preserveSet<CFGAnalyses>();
return PA;
}
+
+void SpeculativeExecutionPass::printPipeline(
+ raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
+ static_cast<PassInfoMixin<SpeculativeExecutionPass> *>(this)->printPipeline(
+ OS, MapClassName2PassName);
+ OS << '<';
+ if (OnlyIfDivergentTarget)
+ OS << "only-if-divergent-target";
+ OS << '>';
+}
} // namespace llvm
diff --git a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
index fdb41cb415df..543469d62fe7 100644
--- a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
@@ -680,7 +680,7 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis(
if (BumpWithUglyGEP) {
// C = (char *)Basis + Bump
unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
- Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS);
+ Type *CharTy = PointerType::get(Basis.Ins->getContext(), AS);
Reduced = Builder.CreateBitCast(Basis.Ins, CharTy);
Reduced =
Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump, "", InBounds);
diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
index fac5695c7bea..7d96a3478858 100644
--- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
+++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
@@ -42,6 +42,7 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include <algorithm>
@@ -353,7 +354,6 @@ public:
void getAnalysisUsage(AnalysisUsage &AU) const override {
if (SkipUniformRegions)
AU.addRequired<UniformityInfoWrapperPass>();
- AU.addRequiredID(LowerSwitchID);
AU.addRequired<DominatorTreeWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
@@ -368,7 +368,6 @@ char StructurizeCFGLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(StructurizeCFGLegacyPass, "structurizecfg",
"Structurize the CFG", false, false)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(RegionInfoPass)
INITIALIZE_PASS_END(StructurizeCFGLegacyPass, "structurizecfg",
@@ -1173,6 +1172,8 @@ bool StructurizeCFG::run(Region *R, DominatorTree *DT) {
this->DT = DT;
Func = R->getEntry()->getParent();
+ assert(hasOnlySimpleTerminator(*Func) && "Unsupported block terminator.");
+
ParentRegion = R;
orderNodes();
diff --git a/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp b/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp
index 4ec7181ad859..58ea5b68d548 100644
--- a/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp
+++ b/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp
@@ -32,7 +32,6 @@
#include <cassert>
#include <cstdint>
#include <iterator>
-#include <tuple>
#include <utility>
using namespace llvm;
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 4f1350e4ebb9..c6e8505d5ab4 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -369,8 +369,14 @@ static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
if (!I->isAssociative() || !I->isCommutative())
return false;
- assert(I->getNumOperands() == 2 &&
- "Associative/commutative operations should have 2 args!");
+ assert(I->getNumOperands() >= 2 &&
+ "Associative/commutative operations should have at least 2 args!");
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ // Accumulators must have an identity.
+ if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(), I->getType()))
+ return false;
+ }
// Exactly one operand should be the result of the call instruction.
if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
@@ -518,10 +524,10 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) {
// block, insert a PHI node for each argument of the function.
// For now, we initialize each PHI to only have the real arguments
// which are passed in.
- Instruction *InsertPos = &HeaderBB->front();
+ BasicBlock::iterator InsertPos = HeaderBB->begin();
for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) {
- PHINode *PN =
- PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos);
+ PHINode *PN = PHINode::Create(I->getType(), 2, I->getName() + ".tr");
+ PN->insertBefore(InsertPos);
I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
PN->addIncoming(&*I, NewEntry);
ArgumentPHIs.push_back(PN);
@@ -534,8 +540,10 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) {
Type *RetType = F.getReturnType();
if (!RetType->isVoidTy()) {
Type *BoolType = Type::getInt1Ty(F.getContext());
- RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos);
- RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos);
+ RetPN = PHINode::Create(RetType, 2, "ret.tr");
+ RetPN->insertBefore(InsertPos);
+ RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr");
+ RetKnownPN->insertBefore(InsertPos);
RetPN->addIncoming(PoisonValue::get(RetType), NewEntry);
RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry);
@@ -555,7 +563,8 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
// Start by inserting a new PHI node for the accumulator.
pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB);
AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1,
- "accumulator.tr", &HeaderBB->front());
+ "accumulator.tr");
+ AccPN->insertBefore(HeaderBB->begin());
// Loop over all of the predecessors of the tail recursion block. For the
// real entry into the function we seed the PHI with the identity constant for
@@ -566,8 +575,8 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
for (pred_iterator PI = PB; PI != PE; ++PI) {
BasicBlock *P = *PI;
if (P == &F.getEntryBlock()) {
- Constant *Identity = ConstantExpr::getBinOpIdentity(
- AccRecInstr->getOpcode(), AccRecInstr->getType());
+ Constant *Identity =
+ ConstantExpr::getIdentity(AccRecInstr, AccRecInstr->getType());
AccPN->addIncoming(Identity, P);
} else {
AccPN->addIncoming(AccPN, P);
@@ -675,6 +684,12 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
for (unsigned I = 0, E = CI->arg_size(); I != E; ++I) {
if (CI->isByValArgument(I)) {
copyLocalTempOfByValueOperandIntoArguments(CI, I);
+ // When eliminating a tail call, we modify the values of the arguments.
+ // Therefore, if the byval parameter has a readonly attribute, we have to
+ // remove it. It is safe because, from the perspective of a caller, the
+ // byval parameter is always treated as "readonly," even if the readonly
+ // attribute is removed.
+ F.removeParamAttr(I, Attribute::ReadOnly);
ArgumentPHIs[I]->addIncoming(F.getArg(I), BB);
} else
ArgumentPHIs[I]->addIncoming(CI->getArgOperand(I), BB);
diff --git a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp
index 2195406c144c..6ca737df49b9 100644
--- a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp
+++ b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp
@@ -153,19 +153,17 @@ static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
Value *Length, bool isLast) {
auto Int64Ty = Builder.getInt64Ty();
- auto CharPtrTy = Builder.getInt8PtrTy();
+ auto PtrTy = Builder.getPtrTy();
auto Int32Ty = Builder.getInt32Ty();
auto M = Builder.GetInsertBlock()->getModule();
auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
- Int64Ty, CharPtrTy, Int64Ty, Int32Ty);
+ Int64Ty, PtrTy, Int64Ty, Int32Ty);
auto IsLastInt32 = Builder.getInt32(isLast);
return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
}
static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
bool IsLast) {
- Arg = Builder.CreateBitCast(
- Arg, Builder.getInt8PtrTy(Arg->getType()->getPointerAddressSpace()));
auto Length = getStrlenWithNull(Builder, Arg);
return callAppendStringN(Builder, Desc, Arg, Length, IsLast);
}
@@ -299,9 +297,9 @@ static Value *callBufferedPrintfStart(
Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind);
Type *Tys_alloc[1] = {Builder.getInt32Ty()};
- Type *I8Ptr =
- Builder.getInt8PtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace());
- FunctionType *FTy_alloc = FunctionType::get(I8Ptr, Tys_alloc, false);
+ Type *PtrTy =
+ Builder.getPtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace());
+ FunctionType *FTy_alloc = FunctionType::get(PtrTy, Tys_alloc, false);
auto PrintfAllocFn =
M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr);
diff --git a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp
index 7d127400651e..f95d5e23c9c8 100644
--- a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp
+++ b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp
@@ -63,13 +63,10 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
#include <utility>
diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
index 45cf98e65a5a..efa8e874b955 100644
--- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
+++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
@@ -19,7 +19,6 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -587,37 +586,3 @@ PreservedAnalyses AssumeBuilderPass::run(Function &F,
PA.preserveSet<CFGAnalyses>();
return PA;
}
-
-namespace {
-class AssumeBuilderPassLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- AssumeBuilderPassLegacyPass() : FunctionPass(ID) {
- initializeAssumeBuilderPassLegacyPassPass(*PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override {
- AssumptionCache &AC =
- getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- DominatorTreeWrapperPass *DTWP =
- getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- for (Instruction &I : instructions(F))
- salvageKnowledge(&I, &AC, DTWP ? &DTWP->getDomTree() : nullptr);
- return true;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
-
- AU.setPreservesAll();
- }
-};
-} // namespace
-
-char AssumeBuilderPassLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(AssumeBuilderPassLegacyPass, "assume-builder",
- "Assume Builder", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_END(AssumeBuilderPassLegacyPass, "assume-builder",
- "Assume Builder", false, false)
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index f06ea89cc61d..b700edf8ea6c 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -194,7 +194,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
// Don't break unwinding instructions or terminators with other side-effects.
Instruction *PTI = PredBB->getTerminator();
- if (PTI->isExceptionalTerminator() || PTI->mayHaveSideEffects())
+ if (PTI->isSpecialTerminator() || PTI->mayHaveSideEffects())
return false;
// Can't merge if there are multiple distinct successors.
@@ -300,7 +300,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
PredBB->back().eraseFromParent();
// Move terminator instruction.
- PredBB->splice(PredBB->end(), BB);
+ BB->back().moveBeforePreserving(*PredBB, PredBB->end());
// Terminator may be a memory accessing instruction too.
if (MSSAU)
@@ -382,7 +382,39 @@ bool llvm::MergeBlockSuccessorsIntoGivenBlocks(
/// - Check fully overlapping fragments and not only identical fragments.
/// - Support dbg.declare. dbg.label, and possibly other meta instructions being
/// part of the sequence of consecutive instructions.
+static bool DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) {
+ SmallVector<DPValue *, 8> ToBeRemoved;
+ SmallDenseSet<DebugVariable> VariableSet;
+ for (auto &I : reverse(*BB)) {
+ for (DPValue &DPV : reverse(I.getDbgValueRange())) {
+ DebugVariable Key(DPV.getVariable(), DPV.getExpression(),
+ DPV.getDebugLoc()->getInlinedAt());
+ auto R = VariableSet.insert(Key);
+ // If the same variable fragment is described more than once it is enough
+ // to keep the last one (i.e. the first found since we for reverse
+ // iteration).
+ // FIXME: add assignment tracking support (see parallel implementation
+ // below).
+ if (!R.second)
+ ToBeRemoved.push_back(&DPV);
+ continue;
+ }
+ // Sequence with consecutive dbg.value instrs ended. Clear the map to
+ // restart identifying redundant instructions if case we find another
+ // dbg.value sequence.
+ VariableSet.clear();
+ }
+
+ for (auto &DPV : ToBeRemoved)
+ DPV->eraseFromParent();
+
+ return !ToBeRemoved.empty();
+}
+
static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) {
+ if (BB->IsNewDbgInfoFormat)
+ return DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BB);
+
SmallVector<DbgValueInst *, 8> ToBeRemoved;
SmallDenseSet<DebugVariable> VariableSet;
for (auto &I : reverse(*BB)) {
@@ -440,7 +472,38 @@ static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) {
///
/// Possible improvements:
/// - Keep track of non-overlapping fragments.
+static bool DPValuesRemoveRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) {
+ SmallVector<DPValue *, 8> ToBeRemoved;
+ DenseMap<DebugVariable, std::pair<SmallVector<Value *, 4>, DIExpression *>>
+ VariableMap;
+ for (auto &I : *BB) {
+ for (DPValue &DPV : I.getDbgValueRange()) {
+ DebugVariable Key(DPV.getVariable(), std::nullopt,
+ DPV.getDebugLoc()->getInlinedAt());
+ auto VMI = VariableMap.find(Key);
+ // Update the map if we found a new value/expression describing the
+ // variable, or if the variable wasn't mapped already.
+ SmallVector<Value *, 4> Values(DPV.location_ops());
+ if (VMI == VariableMap.end() || VMI->second.first != Values ||
+ VMI->second.second != DPV.getExpression()) {
+ VariableMap[Key] = {Values, DPV.getExpression()};
+ continue;
+ }
+ // Found an identical mapping. Remember the instruction for later removal.
+ ToBeRemoved.push_back(&DPV);
+ }
+ }
+
+ for (auto *DPV : ToBeRemoved)
+ DPV->eraseFromParent();
+
+ return !ToBeRemoved.empty();
+}
+
static bool removeRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) {
+ if (BB->IsNewDbgInfoFormat)
+ return DPValuesRemoveRedundantDbgInstrsUsingForwardScan(BB);
+
SmallVector<DbgValueInst *, 8> ToBeRemoved;
DenseMap<DebugVariable, std::pair<SmallVector<Value *, 4>, DIExpression *>>
VariableMap;
@@ -852,9 +915,11 @@ void llvm::createPHIsForSplitLoopExit(ArrayRef<BasicBlock *> Preds,
continue;
// Otherwise a new PHI is needed. Create one and populate it.
- PHINode *NewPN = PHINode::Create(
- PN.getType(), Preds.size(), "split",
- SplitBB->isLandingPad() ? &SplitBB->front() : SplitBB->getTerminator());
+ PHINode *NewPN = PHINode::Create(PN.getType(), Preds.size(), "split");
+ BasicBlock::iterator InsertPos =
+ SplitBB->isLandingPad() ? SplitBB->begin()
+ : SplitBB->getTerminator()->getIterator();
+ NewPN->insertBefore(InsertPos);
for (BasicBlock *BB : Preds)
NewPN->addIncoming(V, BB);
@@ -877,7 +942,7 @@ llvm::SplitAllCriticalEdges(Function &F,
return NumBroken;
}
-static BasicBlock *SplitBlockImpl(BasicBlock *Old, Instruction *SplitPt,
+static BasicBlock *SplitBlockImpl(BasicBlock *Old, BasicBlock::iterator SplitPt,
DomTreeUpdater *DTU, DominatorTree *DT,
LoopInfo *LI, MemorySSAUpdater *MSSAU,
const Twine &BBName, bool Before) {
@@ -887,7 +952,7 @@ static BasicBlock *SplitBlockImpl(BasicBlock *Old, Instruction *SplitPt,
DTU ? DTU : (DT ? &LocalDTU : nullptr), LI, MSSAU,
BBName);
}
- BasicBlock::iterator SplitIt = SplitPt->getIterator();
+ BasicBlock::iterator SplitIt = SplitPt;
while (isa<PHINode>(SplitIt) || SplitIt->isEHPad()) {
++SplitIt;
assert(SplitIt != SplitPt->getParent()->end());
@@ -933,14 +998,14 @@ static BasicBlock *SplitBlockImpl(BasicBlock *Old, Instruction *SplitPt,
return New;
}
-BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt,
+BasicBlock *llvm::SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt,
DominatorTree *DT, LoopInfo *LI,
MemorySSAUpdater *MSSAU, const Twine &BBName,
bool Before) {
return SplitBlockImpl(Old, SplitPt, /*DTU=*/nullptr, DT, LI, MSSAU, BBName,
Before);
}
-BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt,
+BasicBlock *llvm::SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt,
DomTreeUpdater *DTU, LoopInfo *LI,
MemorySSAUpdater *MSSAU, const Twine &BBName,
bool Before) {
@@ -948,12 +1013,12 @@ BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt,
Before);
}
-BasicBlock *llvm::splitBlockBefore(BasicBlock *Old, Instruction *SplitPt,
+BasicBlock *llvm::splitBlockBefore(BasicBlock *Old, BasicBlock::iterator SplitPt,
DomTreeUpdater *DTU, LoopInfo *LI,
MemorySSAUpdater *MSSAU,
const Twine &BBName) {
- BasicBlock::iterator SplitIt = SplitPt->getIterator();
+ BasicBlock::iterator SplitIt = SplitPt;
while (isa<PHINode>(SplitIt) || SplitIt->isEHPad())
++SplitIt;
std::string Name = BBName.str();
@@ -1137,14 +1202,11 @@ static void UpdatePHINodes(BasicBlock *OrigBB, BasicBlock *NewBB,
// If all incoming values for the new PHI would be the same, just don't
// make a new PHI. Instead, just remove the incoming values from the old
// PHI.
-
- // NOTE! This loop walks backwards for a reason! First off, this minimizes
- // the cost of removal if we end up removing a large number of values, and
- // second off, this ensures that the indices for the incoming values
- // aren't invalidated when we remove one.
- for (int64_t i = PN->getNumIncomingValues() - 1; i >= 0; --i)
- if (PredSet.count(PN->getIncomingBlock(i)))
- PN->removeIncomingValue(i, false);
+ PN->removeIncomingValueIf(
+ [&](unsigned Idx) {
+ return PredSet.contains(PN->getIncomingBlock(Idx));
+ },
+ /* DeletePHIIfEmpty */ false);
// Add an incoming value to the PHI node in the loop for the preheader
// edge.
@@ -1394,17 +1456,6 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB,
ArrayRef<BasicBlock *> Preds,
const char *Suffix1, const char *Suffix2,
SmallVectorImpl<BasicBlock *> &NewBBs,
- DominatorTree *DT, LoopInfo *LI,
- MemorySSAUpdater *MSSAU,
- bool PreserveLCSSA) {
- return SplitLandingPadPredecessorsImpl(
- OrigBB, Preds, Suffix1, Suffix2, NewBBs,
- /*DTU=*/nullptr, DT, LI, MSSAU, PreserveLCSSA);
-}
-void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB,
- ArrayRef<BasicBlock *> Preds,
- const char *Suffix1, const char *Suffix2,
- SmallVectorImpl<BasicBlock *> &NewBBs,
DomTreeUpdater *DTU, LoopInfo *LI,
MemorySSAUpdater *MSSAU,
bool PreserveLCSSA) {
@@ -1472,7 +1523,7 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB,
}
Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond,
- Instruction *SplitBefore,
+ BasicBlock::iterator SplitBefore,
bool Unreachable,
MDNode *BranchWeights,
DomTreeUpdater *DTU, LoopInfo *LI,
@@ -1485,7 +1536,7 @@ Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond,
}
Instruction *llvm::SplitBlockAndInsertIfElse(Value *Cond,
- Instruction *SplitBefore,
+ BasicBlock::iterator SplitBefore,
bool Unreachable,
MDNode *BranchWeights,
DomTreeUpdater *DTU, LoopInfo *LI,
@@ -1497,7 +1548,7 @@ Instruction *llvm::SplitBlockAndInsertIfElse(Value *Cond,
return ElseBlock->getTerminator();
}
-void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore,
+void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, BasicBlock::iterator SplitBefore,
Instruction **ThenTerm,
Instruction **ElseTerm,
MDNode *BranchWeights,
@@ -1513,7 +1564,7 @@ void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore,
}
void llvm::SplitBlockAndInsertIfThenElse(
- Value *Cond, Instruction *SplitBefore, BasicBlock **ThenBlock,
+ Value *Cond, BasicBlock::iterator SplitBefore, BasicBlock **ThenBlock,
BasicBlock **ElseBlock, bool UnreachableThen, bool UnreachableElse,
MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI) {
assert((ThenBlock || ElseBlock) &&
@@ -1530,7 +1581,7 @@ void llvm::SplitBlockAndInsertIfThenElse(
}
LLVMContext &C = Head->getContext();
- BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator());
+ BasicBlock *Tail = Head->splitBasicBlock(SplitBefore);
BasicBlock *TrueBlock = Tail;
BasicBlock *FalseBlock = Tail;
bool ThenToTailEdge = false;
@@ -2077,3 +2128,25 @@ void llvm::InvertBranch(BranchInst *PBI, IRBuilderBase &Builder) {
PBI->setCondition(NewCond);
PBI->swapSuccessors();
}
+
+bool llvm::hasOnlySimpleTerminator(const Function &F) {
+ for (auto &BB : F) {
+ auto *Term = BB.getTerminator();
+ if (!(isa<ReturnInst>(Term) || isa<UnreachableInst>(Term) ||
+ isa<BranchInst>(Term)))
+ return false;
+ }
+ return true;
+}
+
+bool llvm::isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
+ const BasicBlock &Dest) {
+ assert(Src.getParent() == Dest.getParent());
+ if (!Src.getParent()->isPresplitCoroutine())
+ return false;
+ if (auto *SW = dyn_cast<SwitchInst>(Src.getTerminator()))
+ if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
+ return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
+ SW->getDefaultDest() == &Dest;
+ return false;
+}
diff --git a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
index ddb35756030f..5fb796cc3db6 100644
--- a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
+++ b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
@@ -387,7 +387,7 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F,
if (ShouldUpdateAnalysis) {
// Copy the BFI/BPI from Target to BodyBlock.
BPI->setEdgeProbability(BodyBlock, EdgeProbabilities);
- BFI->setBlockFreq(BodyBlock, BFI->getBlockFreq(Target).getFrequency());
+ BFI->setBlockFreq(BodyBlock, BFI->getBlockFreq(Target));
}
// It's possible Target was its own successor through an indirectbr.
// In this case, the indirectbr now comes from BodyBlock.
@@ -411,10 +411,10 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F,
BPI->getEdgeProbability(Src, DirectSucc);
}
if (ShouldUpdateAnalysis) {
- BFI->setBlockFreq(DirectSucc, BlockFreqForDirectSucc.getFrequency());
+ BFI->setBlockFreq(DirectSucc, BlockFreqForDirectSucc);
BlockFrequency NewBlockFreqForTarget =
BFI->getBlockFreq(Target) - BlockFreqForDirectSucc;
- BFI->setBlockFreq(Target, NewBlockFreqForTarget.getFrequency());
+ BFI->setBlockFreq(Target, NewBlockFreqForTarget);
}
// Ok, now fix up the PHIs. We know the two blocks only have PHIs, and that
@@ -449,8 +449,8 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F,
// Create a PHI in the body block, to merge the direct and indirect
// predecessors.
- PHINode *MergePHI =
- PHINode::Create(IndPHI->getType(), 2, "merge", &*MergeInsert);
+ PHINode *MergePHI = PHINode::Create(IndPHI->getType(), 2, "merge");
+ MergePHI->insertBefore(MergeInsert);
MergePHI->addIncoming(NewIndPHI, Target);
MergePHI->addIncoming(DirPHI, DirectSucc);
diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
index 5de8ff84de77..12741dc5af5a 100644
--- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
@@ -1425,11 +1425,6 @@ StringRef llvm::getFloatFn(const Module *M, const TargetLibraryInfo *TLI,
//- Emit LibCalls ------------------------------------------------------------//
-Value *llvm::castToCStr(Value *V, IRBuilderBase &B) {
- unsigned AS = V->getType()->getPointerAddressSpace();
- return B.CreateBitCast(V, B.getInt8PtrTy(AS), "cstr");
-}
-
static IntegerType *getIntTy(IRBuilderBase &B, const TargetLibraryInfo *TLI) {
return B.getIntNTy(TLI->getIntSize());
}
@@ -1461,63 +1456,64 @@ static Value *emitLibCall(LibFunc TheLibFunc, Type *ReturnType,
Value *llvm::emitStrLen(Value *Ptr, IRBuilderBase &B, const DataLayout &DL,
const TargetLibraryInfo *TLI) {
+ Type *CharPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_strlen, SizeTTy,
- B.getInt8PtrTy(), castToCStr(Ptr, B), B, TLI);
+ return emitLibCall(LibFunc_strlen, SizeTTy, CharPtrTy, Ptr, B, TLI);
}
Value *llvm::emitStrDup(Value *Ptr, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- return emitLibCall(LibFunc_strdup, B.getInt8PtrTy(), B.getInt8PtrTy(),
- castToCStr(Ptr, B), B, TLI);
+ Type *CharPtrTy = B.getPtrTy();
+ return emitLibCall(LibFunc_strdup, CharPtrTy, CharPtrTy, Ptr, B, TLI);
}
Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
- return emitLibCall(LibFunc_strchr, I8Ptr, {I8Ptr, IntTy},
- {castToCStr(Ptr, B), ConstantInt::get(IntTy, C)}, B, TLI);
+ return emitLibCall(LibFunc_strchr, CharPtrTy, {CharPtrTy, IntTy},
+ {Ptr, ConstantInt::get(IntTy, C)}, B, TLI);
}
Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
+ Type *CharPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
return emitLibCall(
LibFunc_strncmp, IntTy,
- {B.getInt8PtrTy(), B.getInt8PtrTy(), SizeTTy},
- {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI);
+ {CharPtrTy, CharPtrTy, SizeTTy},
+ {Ptr1, Ptr2, Len}, B, TLI);
}
Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = Dst->getType();
- return emitLibCall(LibFunc_strcpy, I8Ptr, {I8Ptr, I8Ptr},
- {castToCStr(Dst, B), castToCStr(Src, B)}, B, TLI);
+ Type *CharPtrTy = Dst->getType();
+ return emitLibCall(LibFunc_strcpy, CharPtrTy, {CharPtrTy, CharPtrTy},
+ {Dst, Src}, B, TLI);
}
Value *llvm::emitStpCpy(Value *Dst, Value *Src, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
- return emitLibCall(LibFunc_stpcpy, I8Ptr, {I8Ptr, I8Ptr},
- {castToCStr(Dst, B), castToCStr(Src, B)}, B, TLI);
+ Type *CharPtrTy = B.getPtrTy();
+ return emitLibCall(LibFunc_stpcpy, CharPtrTy, {CharPtrTy, CharPtrTy},
+ {Dst, Src}, B, TLI);
}
Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_strncpy, I8Ptr, {I8Ptr, I8Ptr, SizeTTy},
- {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI);
+ return emitLibCall(LibFunc_strncpy, CharPtrTy, {CharPtrTy, CharPtrTy, SizeTTy},
+ {Dst, Src, Len}, B, TLI);
}
Value *llvm::emitStpNCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_stpncpy, I8Ptr, {I8Ptr, I8Ptr, SizeTTy},
- {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI);
+ return emitLibCall(LibFunc_stpncpy, CharPtrTy, {CharPtrTy, CharPtrTy, SizeTTy},
+ {Dst, Src, Len}, B, TLI);
}
Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,
@@ -1530,13 +1526,11 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,
AttributeList AS;
AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex,
Attribute::NoUnwind);
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *VoidPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
FunctionCallee MemCpy = getOrInsertLibFunc(M, *TLI, LibFunc_memcpy_chk,
- AttributeList::get(M->getContext(), AS), I8Ptr,
- I8Ptr, I8Ptr, SizeTTy, SizeTTy);
- Dst = castToCStr(Dst, B);
- Src = castToCStr(Src, B);
+ AttributeList::get(M->getContext(), AS), VoidPtrTy,
+ VoidPtrTy, VoidPtrTy, SizeTTy, SizeTTy);
CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize});
if (const Function *F =
dyn_cast<Function>(MemCpy.getCallee()->stripPointerCasts()))
@@ -1546,140 +1540,141 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,
Value *llvm::emitMemPCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *VoidPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_mempcpy, I8Ptr,
- {I8Ptr, I8Ptr, SizeTTy},
+ return emitLibCall(LibFunc_mempcpy, VoidPtrTy,
+ {VoidPtrTy, VoidPtrTy, SizeTTy},
{Dst, Src, Len}, B, TLI);
}
Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *VoidPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_memchr, I8Ptr,
- {I8Ptr, IntTy, SizeTTy},
- {castToCStr(Ptr, B), Val, Len}, B, TLI);
+ return emitLibCall(LibFunc_memchr, VoidPtrTy,
+ {VoidPtrTy, IntTy, SizeTTy},
+ {Ptr, Val, Len}, B, TLI);
}
Value *llvm::emitMemRChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *VoidPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_memrchr, I8Ptr,
- {I8Ptr, IntTy, SizeTTy},
- {castToCStr(Ptr, B), Val, Len}, B, TLI);
+ return emitLibCall(LibFunc_memrchr, VoidPtrTy,
+ {VoidPtrTy, IntTy, SizeTTy},
+ {Ptr, Val, Len}, B, TLI);
}
Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *VoidPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
return emitLibCall(LibFunc_memcmp, IntTy,
- {I8Ptr, I8Ptr, SizeTTy},
- {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI);
+ {VoidPtrTy, VoidPtrTy, SizeTTy},
+ {Ptr1, Ptr2, Len}, B, TLI);
}
Value *llvm::emitBCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B,
const DataLayout &DL, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *VoidPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
return emitLibCall(LibFunc_bcmp, IntTy,
- {I8Ptr, I8Ptr, SizeTTy},
- {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI);
+ {VoidPtrTy, VoidPtrTy, SizeTTy},
+ {Ptr1, Ptr2, Len}, B, TLI);
}
Value *llvm::emitMemCCpy(Value *Ptr1, Value *Ptr2, Value *Val, Value *Len,
IRBuilderBase &B, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *VoidPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_memccpy, I8Ptr,
- {I8Ptr, I8Ptr, IntTy, SizeTTy},
+ return emitLibCall(LibFunc_memccpy, VoidPtrTy,
+ {VoidPtrTy, VoidPtrTy, IntTy, SizeTTy},
{Ptr1, Ptr2, Val, Len}, B, TLI);
}
Value *llvm::emitSNPrintf(Value *Dest, Value *Size, Value *Fmt,
ArrayRef<Value *> VariadicArgs, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
- SmallVector<Value *, 8> Args{castToCStr(Dest, B), Size, castToCStr(Fmt, B)};
+ SmallVector<Value *, 8> Args{Dest, Size, Fmt};
llvm::append_range(Args, VariadicArgs);
return emitLibCall(LibFunc_snprintf, IntTy,
- {I8Ptr, SizeTTy, I8Ptr},
+ {CharPtrTy, SizeTTy, CharPtrTy},
Args, B, TLI, /*IsVaArgs=*/true);
}
Value *llvm::emitSPrintf(Value *Dest, Value *Fmt,
ArrayRef<Value *> VariadicArgs, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
- SmallVector<Value *, 8> Args{castToCStr(Dest, B), castToCStr(Fmt, B)};
+ SmallVector<Value *, 8> Args{Dest, Fmt};
llvm::append_range(Args, VariadicArgs);
return emitLibCall(LibFunc_sprintf, IntTy,
- {I8Ptr, I8Ptr}, Args, B, TLI,
+ {CharPtrTy, CharPtrTy}, Args, B, TLI,
/*IsVaArgs=*/true);
}
Value *llvm::emitStrCat(Value *Dest, Value *Src, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- return emitLibCall(LibFunc_strcat, B.getInt8PtrTy(),
- {B.getInt8PtrTy(), B.getInt8PtrTy()},
- {castToCStr(Dest, B), castToCStr(Src, B)}, B, TLI);
+ Type *CharPtrTy = B.getPtrTy();
+ return emitLibCall(LibFunc_strcat, CharPtrTy,
+ {CharPtrTy, CharPtrTy},
+ {Dest, Src}, B, TLI);
}
Value *llvm::emitStrLCpy(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
return emitLibCall(LibFunc_strlcpy, SizeTTy,
- {I8Ptr, I8Ptr, SizeTTy},
- {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI);
+ {CharPtrTy, CharPtrTy, SizeTTy},
+ {Dest, Src, Size}, B, TLI);
}
Value *llvm::emitStrLCat(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
return emitLibCall(LibFunc_strlcat, SizeTTy,
- {I8Ptr, I8Ptr, SizeTTy},
- {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI);
+ {CharPtrTy, CharPtrTy, SizeTTy},
+ {Dest, Src, Size}, B, TLI);
}
Value *llvm::emitStrNCat(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B,
const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *SizeTTy = getSizeTTy(B, TLI);
- return emitLibCall(LibFunc_strncat, I8Ptr,
- {I8Ptr, I8Ptr, SizeTTy},
- {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI);
+ return emitLibCall(LibFunc_strncat, CharPtrTy,
+ {CharPtrTy, CharPtrTy, SizeTTy},
+ {Dest, Src, Size}, B, TLI);
}
Value *llvm::emitVSNPrintf(Value *Dest, Value *Size, Value *Fmt, Value *VAList,
IRBuilderBase &B, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
Type *SizeTTy = getSizeTTy(B, TLI);
return emitLibCall(
LibFunc_vsnprintf, IntTy,
- {I8Ptr, SizeTTy, I8Ptr, VAList->getType()},
- {castToCStr(Dest, B), Size, castToCStr(Fmt, B), VAList}, B, TLI);
+ {CharPtrTy, SizeTTy, CharPtrTy, VAList->getType()},
+ {Dest, Size, Fmt, VAList}, B, TLI);
}
Value *llvm::emitVSPrintf(Value *Dest, Value *Fmt, Value *VAList,
IRBuilderBase &B, const TargetLibraryInfo *TLI) {
- Type *I8Ptr = B.getInt8PtrTy();
+ Type *CharPtrTy = B.getPtrTy();
Type *IntTy = getIntTy(B, TLI);
return emitLibCall(LibFunc_vsprintf, IntTy,
- {I8Ptr, I8Ptr, VAList->getType()},
- {castToCStr(Dest, B), castToCStr(Fmt, B), VAList}, B, TLI);
+ {CharPtrTy, CharPtrTy, VAList->getType()},
+ {Dest, Fmt, VAList}, B, TLI);
}
/// Append a suffix to the function name according to the type of 'Op'.
@@ -1829,9 +1824,9 @@ Value *llvm::emitPutS(Value *Str, IRBuilderBase &B,
Type *IntTy = getIntTy(B, TLI);
StringRef PutsName = TLI->getName(LibFunc_puts);
FunctionCallee PutS = getOrInsertLibFunc(M, *TLI, LibFunc_puts, IntTy,
- B.getInt8PtrTy());
+ B.getPtrTy());
inferNonMandatoryLibFuncAttrs(M, PutsName, *TLI);
- CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), PutsName);
+ CallInst *CI = B.CreateCall(PutS, Str, PutsName);
if (const Function *F =
dyn_cast<Function>(PutS.getCallee()->stripPointerCasts()))
CI->setCallingConv(F->getCallingConv());
@@ -1867,10 +1862,10 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilderBase &B,
Type *IntTy = getIntTy(B, TLI);
StringRef FPutsName = TLI->getName(LibFunc_fputs);
FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputs, IntTy,
- B.getInt8PtrTy(), File->getType());
+ B.getPtrTy(), File->getType());
if (File->getType()->isPointerTy())
inferNonMandatoryLibFuncAttrs(M, FPutsName, *TLI);
- CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, FPutsName);
+ CallInst *CI = B.CreateCall(F, {Str, File}, FPutsName);
if (const Function *Fn =
dyn_cast<Function>(F.getCallee()->stripPointerCasts()))
@@ -1887,13 +1882,13 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilderBase &B,
Type *SizeTTy = getSizeTTy(B, TLI);
StringRef FWriteName = TLI->getName(LibFunc_fwrite);
FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fwrite,
- SizeTTy, B.getInt8PtrTy(), SizeTTy,
+ SizeTTy, B.getPtrTy(), SizeTTy,
SizeTTy, File->getType());
if (File->getType()->isPointerTy())
inferNonMandatoryLibFuncAttrs(M, FWriteName, *TLI);
CallInst *CI =
- B.CreateCall(F, {castToCStr(Ptr, B), Size,
+ B.CreateCall(F, {Ptr, Size,
ConstantInt::get(SizeTTy, 1), File});
if (const Function *Fn =
@@ -1911,7 +1906,7 @@ Value *llvm::emitMalloc(Value *Num, IRBuilderBase &B, const DataLayout &DL,
StringRef MallocName = TLI->getName(LibFunc_malloc);
Type *SizeTTy = getSizeTTy(B, TLI);
FunctionCallee Malloc = getOrInsertLibFunc(M, *TLI, LibFunc_malloc,
- B.getInt8PtrTy(), SizeTTy);
+ B.getPtrTy(), SizeTTy);
inferNonMandatoryLibFuncAttrs(M, MallocName, *TLI);
CallInst *CI = B.CreateCall(Malloc, Num, MallocName);
@@ -1931,7 +1926,7 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, IRBuilderBase &B,
StringRef CallocName = TLI.getName(LibFunc_calloc);
Type *SizeTTy = getSizeTTy(B, &TLI);
FunctionCallee Calloc = getOrInsertLibFunc(M, TLI, LibFunc_calloc,
- B.getInt8PtrTy(), SizeTTy, SizeTTy);
+ B.getPtrTy(), SizeTTy, SizeTTy);
inferNonMandatoryLibFuncAttrs(M, CallocName, TLI);
CallInst *CI = B.CreateCall(Calloc, {Num, Size}, CallocName);
@@ -1950,7 +1945,7 @@ Value *llvm::emitHotColdNew(Value *Num, IRBuilderBase &B,
return nullptr;
StringRef Name = TLI->getName(NewFunc);
- FunctionCallee Func = M->getOrInsertFunction(Name, B.getInt8PtrTy(),
+ FunctionCallee Func = M->getOrInsertFunction(Name, B.getPtrTy(),
Num->getType(), B.getInt8Ty());
inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
CallInst *CI = B.CreateCall(Func, {Num, B.getInt8(HotCold)}, Name);
@@ -1971,7 +1966,7 @@ Value *llvm::emitHotColdNewNoThrow(Value *Num, Value *NoThrow, IRBuilderBase &B,
StringRef Name = TLI->getName(NewFunc);
FunctionCallee Func =
- M->getOrInsertFunction(Name, B.getInt8PtrTy(), Num->getType(),
+ M->getOrInsertFunction(Name, B.getPtrTy(), Num->getType(),
NoThrow->getType(), B.getInt8Ty());
inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
CallInst *CI = B.CreateCall(Func, {Num, NoThrow, B.getInt8(HotCold)}, Name);
@@ -1992,7 +1987,7 @@ Value *llvm::emitHotColdNewAligned(Value *Num, Value *Align, IRBuilderBase &B,
StringRef Name = TLI->getName(NewFunc);
FunctionCallee Func = M->getOrInsertFunction(
- Name, B.getInt8PtrTy(), Num->getType(), Align->getType(), B.getInt8Ty());
+ Name, B.getPtrTy(), Num->getType(), Align->getType(), B.getInt8Ty());
inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
CallInst *CI = B.CreateCall(Func, {Num, Align, B.getInt8(HotCold)}, Name);
@@ -2013,7 +2008,7 @@ Value *llvm::emitHotColdNewAlignedNoThrow(Value *Num, Value *Align,
StringRef Name = TLI->getName(NewFunc);
FunctionCallee Func = M->getOrInsertFunction(
- Name, B.getInt8PtrTy(), Num->getType(), Align->getType(),
+ Name, B.getPtrTy(), Num->getType(), Align->getType(),
NoThrow->getType(), B.getInt8Ty());
inferNonMandatoryLibFuncAttrs(M, Name, *TLI);
CallInst *CI =
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index b488e3bb0cbd..e42cdab64446 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -111,7 +111,7 @@ static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst,
if (OrigInst->getType()->isVoidTy() || OrigInst->use_empty())
return;
- Builder.SetInsertPoint(&MergeBlock->front());
+ Builder.SetInsertPoint(MergeBlock, MergeBlock->begin());
PHINode *Phi = Builder.CreatePHI(OrigInst->getType(), 0);
SmallVector<User *, 16> UsersToUpdate(OrigInst->users());
for (User *U : UsersToUpdate)
diff --git a/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp b/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp
index a1ee3df907ec..fb4d82885377 100644
--- a/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp
+++ b/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp
@@ -30,6 +30,7 @@
#include "llvm/Transforms/Utils/CanonicalizeFreezeInLoops.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp
index d55208602b71..c0f333364fa5 100644
--- a/llvm/lib/Transforms/Utils/CloneFunction.cpp
+++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp
@@ -44,6 +44,7 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap,
ClonedCodeInfo *CodeInfo,
DebugInfoFinder *DIFinder) {
BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "", F);
+ NewBB->IsNewDbgInfoFormat = BB->IsNewDbgInfoFormat;
if (BB->hasName())
NewBB->setName(BB->getName() + NameSuffix);
@@ -58,7 +59,10 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap,
Instruction *NewInst = I.clone();
if (I.hasName())
NewInst->setName(I.getName() + NameSuffix);
- NewInst->insertInto(NewBB, NewBB->end());
+
+ NewInst->insertBefore(*NewBB, NewBB->end());
+ NewInst->cloneDebugInfoFrom(&I);
+
VMap[&I] = NewInst; // Add instruction map to value.
if (isa<CallInst>(I) && !I.isDebugOrPseudoInst()) {
@@ -90,6 +94,7 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc,
const char *NameSuffix, ClonedCodeInfo *CodeInfo,
ValueMapTypeRemapper *TypeMapper,
ValueMaterializer *Materializer) {
+ NewFunc->setIsNewDbgInfoFormat(OldFunc->IsNewDbgInfoFormat);
assert(NameSuffix && "NameSuffix cannot be null!");
#ifndef NDEBUG
@@ -267,9 +272,13 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc,
BB = cast<BasicBlock>(VMap[&OldFunc->front()])->getIterator(),
BE = NewFunc->end();
BB != BE; ++BB)
- // Loop over all instructions, fixing each one as we find it...
- for (Instruction &II : *BB)
+ // Loop over all instructions, fixing each one as we find it, and any
+ // attached debug-info records.
+ for (Instruction &II : *BB) {
RemapInstruction(&II, VMap, RemapFlag, TypeMapper, Materializer);
+ RemapDPValueRange(II.getModule(), II.getDbgValueRange(), VMap, RemapFlag,
+ TypeMapper, Materializer);
+ }
// Only update !llvm.dbg.cu for DifferentModule (not CloneModule). In the
// same module, the compile unit will already be listed (or not). When
@@ -327,6 +336,7 @@ Function *llvm::CloneFunction(Function *F, ValueToValueMapTy &VMap,
// Create the new function...
Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
F->getName(), F->getParent());
+ NewF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat);
// Loop over the arguments, copying the names of the mapped arguments over...
Function::arg_iterator DestI = NewF->arg_begin();
@@ -472,6 +482,7 @@ void PruningFunctionCloner::CloneBlock(
BasicBlock *NewBB;
Twine NewName(BB->hasName() ? Twine(BB->getName()) + NameSuffix : "");
BBEntry = NewBB = BasicBlock::Create(BB->getContext(), NewName, NewFunc);
+ NewBB->IsNewDbgInfoFormat = BB->IsNewDbgInfoFormat;
// It is only legal to clone a function if a block address within that
// function is never referenced outside of the function. Given that, we
@@ -491,6 +502,22 @@ void PruningFunctionCloner::CloneBlock(
bool hasCalls = false, hasDynamicAllocas = false, hasStaticAllocas = false;
bool hasMemProfMetadata = false;
+ // Keep a cursor pointing at the last place we cloned debug-info records from.
+ BasicBlock::const_iterator DbgCursor = StartingInst;
+ auto CloneDbgRecordsToHere =
+ [NewBB, &DbgCursor](Instruction *NewInst, BasicBlock::const_iterator II) {
+ if (!NewBB->IsNewDbgInfoFormat)
+ return;
+
+ // Clone debug-info records onto this instruction. Iterate through any
+ // source-instructions we've cloned and then subsequently optimised
+ // away, so that their debug-info doesn't go missing.
+ for (; DbgCursor != II; ++DbgCursor)
+ NewInst->cloneDebugInfoFrom(&*DbgCursor, std::nullopt, false);
+ NewInst->cloneDebugInfoFrom(&*II);
+ DbgCursor = std::next(II);
+ };
+
// Loop over all instructions, and copy them over, DCE'ing as we go. This
// loop doesn't include the terminator.
for (BasicBlock::const_iterator II = StartingInst, IE = --BB->end(); II != IE;
@@ -540,6 +567,8 @@ void PruningFunctionCloner::CloneBlock(
hasMemProfMetadata |= II->hasMetadata(LLVMContext::MD_memprof);
}
+ CloneDbgRecordsToHere(NewInst, II);
+
if (CodeInfo) {
CodeInfo->OrigVMap[&*II] = NewInst;
if (auto *CB = dyn_cast<CallBase>(&*II))
@@ -597,6 +626,9 @@ void PruningFunctionCloner::CloneBlock(
if (OldTI->hasName())
NewInst->setName(OldTI->getName() + NameSuffix);
NewInst->insertInto(NewBB, NewBB->end());
+
+ CloneDbgRecordsToHere(NewInst, OldTI->getIterator());
+
VMap[OldTI] = NewInst; // Add instruction map to value.
if (CodeInfo) {
@@ -608,6 +640,13 @@ void PruningFunctionCloner::CloneBlock(
// Recursively clone any reachable successor blocks.
append_range(ToClone, successors(BB->getTerminator()));
+ } else {
+ // If we didn't create a new terminator, clone DPValues from the old
+ // terminator onto the new terminator.
+ Instruction *NewInst = NewBB->getTerminator();
+ assert(NewInst);
+
+ CloneDbgRecordsToHere(NewInst, OldTI->getIterator());
}
if (CodeInfo) {
@@ -845,12 +884,22 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc,
TypeMapper, Materializer);
}
+ // Do the same for DPValues, touching all the instructions in the cloned
+ // range of blocks.
+ Function::iterator Begin = cast<BasicBlock>(VMap[StartingBB])->getIterator();
+ for (BasicBlock &BB : make_range(Begin, NewFunc->end())) {
+ for (Instruction &I : BB) {
+ RemapDPValueRange(I.getModule(), I.getDbgValueRange(), VMap,
+ ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges,
+ TypeMapper, Materializer);
+ }
+ }
+
// Simplify conditional branches and switches with a constant operand. We try
// to prune these out when cloning, but if the simplification required
// looking through PHI nodes, those are only available after forming the full
// basic block. That may leave some here, and we still want to prune the dead
// code as early as possible.
- Function::iterator Begin = cast<BasicBlock>(VMap[StartingBB])->getIterator();
for (BasicBlock &BB : make_range(Begin, NewFunc->end()))
ConstantFoldTerminator(&BB);
@@ -939,10 +988,14 @@ void llvm::CloneAndPruneFunctionInto(
void llvm::remapInstructionsInBlocks(ArrayRef<BasicBlock *> Blocks,
ValueToValueMapTy &VMap) {
// Rewrite the code to refer to itself.
- for (auto *BB : Blocks)
- for (auto &Inst : *BB)
+ for (auto *BB : Blocks) {
+ for (auto &Inst : *BB) {
+ RemapDPValueRange(Inst.getModule(), Inst.getDbgValueRange(), VMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
RemapInstruction(&Inst, VMap,
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+ }
+ }
}
/// Clones a loop \p OrigLoop. Returns the loop and the blocks in \p
@@ -1066,6 +1119,7 @@ BasicBlock *llvm::DuplicateInstructionsInSplitBetween(
Instruction *New = BI->clone();
New->setName(BI->getName());
New->insertBefore(NewTerm);
+ New->cloneDebugInfoFrom(&*BI);
ValueMapping[&*BI] = New;
// Remap operands to patch up intra-block references.
diff --git a/llvm/lib/Transforms/Utils/CloneModule.cpp b/llvm/lib/Transforms/Utils/CloneModule.cpp
index 55e051298a9a..00e40fe73d90 100644
--- a/llvm/lib/Transforms/Utils/CloneModule.cpp
+++ b/llvm/lib/Transforms/Utils/CloneModule.cpp
@@ -34,6 +34,8 @@ static void copyComdat(GlobalObject *Dst, const GlobalObject *Src) {
/// copies of global variables and functions, and making their (initializers and
/// references, respectively) refer to the right globals.
///
+/// Cloning un-materialized modules is not currently supported, so any
+/// modules initialized via lazy loading should be materialized before cloning.
std::unique_ptr<Module> llvm::CloneModule(const Module &M) {
// Create the value map that maps things from the old module over to the new
// module.
@@ -49,6 +51,9 @@ std::unique_ptr<Module> llvm::CloneModule(const Module &M,
std::unique_ptr<Module> llvm::CloneModule(
const Module &M, ValueToValueMapTy &VMap,
function_ref<bool(const GlobalValue *)> ShouldCloneDefinition) {
+
+ assert(M.isMaterialized() && "Module must be materialized before cloning!");
+
// First off, we need to create the new module.
std::unique_ptr<Module> New =
std::make_unique<Module>(M.getModuleIdentifier(), M.getContext());
@@ -56,6 +61,7 @@ std::unique_ptr<Module> llvm::CloneModule(
New->setDataLayout(M.getDataLayout());
New->setTargetTriple(M.getTargetTriple());
New->setModuleInlineAsm(M.getModuleInlineAsm());
+ New->IsNewDbgInfoFormat = M.IsNewDbgInfoFormat;
// Loop over all of the global variables, making corresponding globals in the
// new module. Here we add them to the VMap and to the new Module. We
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index c390af351a69..9c1186232e02 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -245,12 +245,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
bool AggregateArgs, BlockFrequencyInfo *BFI,
BranchProbabilityInfo *BPI, AssumptionCache *AC,
bool AllowVarArgs, bool AllowAlloca,
- BasicBlock *AllocationBlock, std::string Suffix)
+ BasicBlock *AllocationBlock, std::string Suffix,
+ bool ArgsInZeroAddressSpace)
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
AllowVarArgs(AllowVarArgs),
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
- Suffix(Suffix) {}
+ Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
BlockFrequencyInfo *BFI,
@@ -567,7 +568,7 @@ void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC,
for (Instruction *I : LifetimeBitcastUsers) {
Module *M = AIFunc->getParent();
LLVMContext &Ctx = M->getContext();
- auto *Int8PtrTy = Type::getInt8PtrTy(Ctx);
+ auto *Int8PtrTy = PointerType::getUnqual(Ctx);
CastInst *CastI =
CastInst::CreatePointerCast(AI, Int8PtrTy, "lt.cast", I);
I->replaceUsesOfWith(I->getOperand(1), CastI);
@@ -721,7 +722,8 @@ void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
// Create a new PHI node in the new region, which has an incoming value
// from OldPred of PN.
PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
- PN->getName() + ".ce", &NewBB->front());
+ PN->getName() + ".ce");
+ NewPN->insertBefore(NewBB->begin());
PN->replaceAllUsesWith(NewPN);
NewPN->addIncoming(PN, OldPred);
@@ -766,6 +768,7 @@ void CodeExtractor::severSplitPHINodesOfExits(
NewBB = BasicBlock::Create(ExitBB->getContext(),
ExitBB->getName() + ".split",
ExitBB->getParent(), ExitBB);
+ NewBB->IsNewDbgInfoFormat = ExitBB->IsNewDbgInfoFormat;
SmallVector<BasicBlock *, 4> Preds(predecessors(ExitBB));
for (BasicBlock *PredBB : Preds)
if (Blocks.count(PredBB))
@@ -775,9 +778,9 @@ void CodeExtractor::severSplitPHINodesOfExits(
}
// Split this PHI.
- PHINode *NewPN =
- PHINode::Create(PN.getType(), IncomingVals.size(),
- PN.getName() + ".ce", NewBB->getFirstNonPHI());
+ PHINode *NewPN = PHINode::Create(PN.getType(), IncomingVals.size(),
+ PN.getName() + ".ce");
+ NewPN->insertBefore(NewBB->getFirstNonPHIIt());
for (unsigned i : IncomingVals)
NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
for (unsigned i : reverse(IncomingVals))
@@ -865,7 +868,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
StructType *StructTy = nullptr;
if (AggregateArgs && !AggParamTy.empty()) {
StructTy = StructType::get(M->getContext(), AggParamTy);
- ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
+ ParamTy.push_back(PointerType::get(
+ StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
}
LLVM_DEBUG({
@@ -886,6 +890,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
Function *newFunction = Function::Create(
funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
oldFunction->getName() + "." + SuffixToUse, M);
+ newFunction->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
// Inherit all of the target dependent attributes and white-listed
// target independent attributes.
@@ -919,6 +924,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::PresplitCoroutine:
case Attribute::Memory:
case Attribute::NoFPClass:
+ case Attribute::CoroDestroyOnlyWhenComplete:
continue;
// Those attributes should be safe to propagate to the extracted function.
case Attribute::AlwaysInline:
@@ -940,6 +946,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::NoSanitizeBounds:
case Attribute::NoSanitizeCoverage:
case Attribute::NullPointerIsValid:
+ case Attribute::OptimizeForDebugging:
case Attribute::OptForFuzzing:
case Attribute::OptimizeNone:
case Attribute::OptimizeForSize:
@@ -990,6 +997,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::ImmArg:
case Attribute::ByRef:
case Attribute::WriteOnly:
+ case Attribute::Writable:
// These are not really attributes.
case Attribute::None:
case Attribute::EndAttrKinds:
@@ -1185,8 +1193,15 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
: &codeReplacer->getParent()->front().front());
- params.push_back(Struct);
+ if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+ auto *StructSpaceCast = new AddrSpaceCastInst(
+ Struct, PointerType ::get(Context, 0), "structArg.ascast");
+ StructSpaceCast->insertAfter(Struct);
+ params.push_back(StructSpaceCast);
+ } else {
+ params.push_back(Struct);
+ }
// Store aggregated inputs in the struct.
for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
if (inputs.contains(StructValues[i])) {
@@ -1492,10 +1507,14 @@ void CodeExtractor::calculateNewCallTerminatorWeights(
static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) {
for (Instruction &I : instructions(F)) {
SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
- findDbgUsers(DbgUsers, &I);
+ SmallVector<DPValue *, 4> DPValues;
+ findDbgUsers(DbgUsers, &I, &DPValues);
for (DbgVariableIntrinsic *DVI : DbgUsers)
if (DVI->getFunction() != &F)
DVI->eraseFromParent();
+ for (DPValue *DPV : DPValues)
+ if (DPV->getFunction() != &F)
+ DPV->eraseFromParent();
}
}
@@ -1531,6 +1550,16 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc,
/*LineNo=*/0, SPType, /*ScopeLine=*/0, DINode::FlagZero, SPFlags);
NewFunc.setSubprogram(NewSP);
+ auto IsInvalidLocation = [&NewFunc](Value *Location) {
+ // Location is invalid if it isn't a constant or an instruction, or is an
+ // instruction but isn't in the new function.
+ if (!Location ||
+ (!isa<Constant>(Location) && !isa<Instruction>(Location)))
+ return true;
+ Instruction *LocationInst = dyn_cast<Instruction>(Location);
+ return LocationInst && LocationInst->getFunction() != &NewFunc;
+ };
+
// Debug intrinsics in the new function need to be updated in one of two
// ways:
// 1) They need to be deleted, because they describe a value in the old
@@ -1539,8 +1568,41 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc,
// point to a variable in the wrong scope.
SmallDenseMap<DINode *, DINode *> RemappedMetadata;
SmallVector<Instruction *, 4> DebugIntrinsicsToDelete;
+ SmallVector<DPValue *, 4> DPVsToDelete;
DenseMap<const MDNode *, MDNode *> Cache;
+
+ auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar) {
+ DINode *&NewVar = RemappedMetadata[OldVar];
+ if (!NewVar) {
+ DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
+ *OldVar->getScope(), *NewSP, Ctx, Cache);
+ NewVar = DIB.createAutoVariable(
+ NewScope, OldVar->getName(), OldVar->getFile(), OldVar->getLine(),
+ OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero,
+ OldVar->getAlignInBits());
+ }
+ return cast<DILocalVariable>(NewVar);
+ };
+
+ auto UpdateDPValuesOnInst = [&](Instruction &I) -> void {
+ for (auto &DPV : I.getDbgValueRange()) {
+ // Apply the two updates that dbg.values get: invalid operands, and
+ // variable metadata fixup.
+ // FIXME: support dbg.assign form of DPValues.
+ if (any_of(DPV.location_ops(), IsInvalidLocation)) {
+ DPVsToDelete.push_back(&DPV);
+ continue;
+ }
+ if (!DPV.getDebugLoc().getInlinedAt())
+ DPV.setVariable(GetUpdatedDIVariable(DPV.getVariable()));
+ DPV.setDebugLoc(DebugLoc::replaceInlinedAtSubprogram(DPV.getDebugLoc(),
+ *NewSP, Ctx, Cache));
+ }
+ };
+
for (Instruction &I : instructions(NewFunc)) {
+ UpdateDPValuesOnInst(I);
+
auto *DII = dyn_cast<DbgInfoIntrinsic>(&I);
if (!DII)
continue;
@@ -1562,41 +1624,28 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc,
continue;
}
- auto IsInvalidLocation = [&NewFunc](Value *Location) {
- // Location is invalid if it isn't a constant or an instruction, or is an
- // instruction but isn't in the new function.
- if (!Location ||
- (!isa<Constant>(Location) && !isa<Instruction>(Location)))
- return true;
- Instruction *LocationInst = dyn_cast<Instruction>(Location);
- return LocationInst && LocationInst->getFunction() != &NewFunc;
- };
-
auto *DVI = cast<DbgVariableIntrinsic>(DII);
// If any of the used locations are invalid, delete the intrinsic.
if (any_of(DVI->location_ops(), IsInvalidLocation)) {
DebugIntrinsicsToDelete.push_back(DVI);
continue;
}
+ // DbgAssign intrinsics have an extra Value argument:
+ if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI);
+ DAI && IsInvalidLocation(DAI->getAddress())) {
+ DebugIntrinsicsToDelete.push_back(DVI);
+ continue;
+ }
// If the variable was in the scope of the old function, i.e. it was not
// inlined, point the intrinsic to a fresh variable within the new function.
- if (!DVI->getDebugLoc().getInlinedAt()) {
- DILocalVariable *OldVar = DVI->getVariable();
- DINode *&NewVar = RemappedMetadata[OldVar];
- if (!NewVar) {
- DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
- *OldVar->getScope(), *NewSP, Ctx, Cache);
- NewVar = DIB.createAutoVariable(
- NewScope, OldVar->getName(), OldVar->getFile(), OldVar->getLine(),
- OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero,
- OldVar->getAlignInBits());
- }
- DVI->setVariable(cast<DILocalVariable>(NewVar));
- }
+ if (!DVI->getDebugLoc().getInlinedAt())
+ DVI->setVariable(GetUpdatedDIVariable(DVI->getVariable()));
}
for (auto *DII : DebugIntrinsicsToDelete)
DII->eraseFromParent();
+ for (auto *DPV : DPVsToDelete)
+ DPV->getMarker()->MarkedInstr->dropOneDbgValue(DPV);
DIB.finalizeSubprogram(NewSP);
// Fix up the scope information attached to the line locations in the new
@@ -1702,11 +1751,14 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
"codeRepl", oldFunction,
header);
+ codeReplacer->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
// The new function needs a root node because other nodes can branch to the
// head of the region, but the entry node of a function cannot have preds.
BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
"newFuncRoot");
+ newFuncRoot->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
+
auto *BranchI = BranchInst::Create(header);
// If the original function has debug info, we have to add a debug location
// to the new branch instruction from the artificial entry block.
@@ -1772,11 +1824,11 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
// Update the entry count of the function.
if (BFI) {
- auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+ auto Count = BFI->getProfileCountFromFreq(EntryFreq);
if (Count)
newFunction->setEntryCount(
ProfileCount(*Count, Function::PCT_Real)); // FIXME
- BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+ BFI->setBlockFreq(codeReplacer, EntryFreq);
}
CallInst *TheCall =
diff --git a/llvm/lib/Transforms/Utils/CodeLayout.cpp b/llvm/lib/Transforms/Utils/CodeLayout.cpp
index ac74a1c116cc..95edd27c675d 100644
--- a/llvm/lib/Transforms/Utils/CodeLayout.cpp
+++ b/llvm/lib/Transforms/Utils/CodeLayout.cpp
@@ -45,8 +45,11 @@
#include "llvm/Support/Debug.h"
#include <cmath>
+#include <set>
using namespace llvm;
+using namespace llvm::codelayout;
+
#define DEBUG_TYPE "code-layout"
namespace llvm {
@@ -61,8 +64,8 @@ cl::opt<bool> ApplyExtTspWithoutProfile(
cl::init(true), cl::Hidden);
} // namespace llvm
-// Algorithm-specific params. The values are tuned for the best performance
-// of large-scale front-end bound binaries.
+// Algorithm-specific params for Ext-TSP. The values are tuned for the best
+// performance of large-scale front-end bound binaries.
static cl::opt<double> ForwardWeightCond(
"ext-tsp-forward-weight-cond", cl::ReallyHidden, cl::init(0.1),
cl::desc("The weight of conditional forward jumps for ExtTSP value"));
@@ -96,10 +99,10 @@ static cl::opt<unsigned> BackwardDistance(
cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP"));
// The maximum size of a chain created by the algorithm. The size is bounded
-// so that the algorithm can efficiently process extremely large instance.
+// so that the algorithm can efficiently process extremely large instances.
static cl::opt<unsigned>
- MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(4096),
- cl::desc("The maximum size of a chain to create."));
+ MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(512),
+ cl::desc("The maximum size of a chain to create"));
// The maximum size of a chain for splitting. Larger values of the threshold
// may yield better quality at the cost of worsen run-time.
@@ -107,11 +110,29 @@ static cl::opt<unsigned> ChainSplitThreshold(
"ext-tsp-chain-split-threshold", cl::ReallyHidden, cl::init(128),
cl::desc("The maximum size of a chain to apply splitting"));
-// The option enables splitting (large) chains along in-coming and out-going
-// jumps. This typically results in a better quality.
-static cl::opt<bool> EnableChainSplitAlongJumps(
- "ext-tsp-enable-chain-split-along-jumps", cl::ReallyHidden, cl::init(true),
- cl::desc("The maximum size of a chain to apply splitting"));
+// The maximum ratio between densities of two chains for merging.
+static cl::opt<double> MaxMergeDensityRatio(
+ "ext-tsp-max-merge-density-ratio", cl::ReallyHidden, cl::init(100),
+ cl::desc("The maximum ratio between densities of two chains for merging"));
+
+// Algorithm-specific options for CDSort.
+static cl::opt<unsigned> CacheEntries("cdsort-cache-entries", cl::ReallyHidden,
+ cl::desc("The size of the cache"));
+
+static cl::opt<unsigned> CacheSize("cdsort-cache-size", cl::ReallyHidden,
+ cl::desc("The size of a line in the cache"));
+
+static cl::opt<unsigned>
+ CDMaxChainSize("cdsort-max-chain-size", cl::ReallyHidden,
+ cl::desc("The maximum size of a chain to create"));
+
+static cl::opt<double> DistancePower(
+ "cdsort-distance-power", cl::ReallyHidden,
+ cl::desc("The power exponent for the distance-based locality"));
+
+static cl::opt<double> FrequencyScale(
+ "cdsort-frequency-scale", cl::ReallyHidden,
+ cl::desc("The scale factor for the frequency-based locality"));
namespace {
@@ -199,11 +220,14 @@ struct NodeT {
NodeT &operator=(const NodeT &) = delete;
NodeT &operator=(NodeT &&) = default;
- explicit NodeT(size_t Index, uint64_t Size, uint64_t EC)
- : Index(Index), Size(Size), ExecutionCount(EC) {}
+ explicit NodeT(size_t Index, uint64_t Size, uint64_t Count)
+ : Index(Index), Size(Size), ExecutionCount(Count) {}
bool isEntry() const { return Index == 0; }
+ // Check if Other is a successor of the node.
+ bool isSuccessor(const NodeT *Other) const;
+
// The total execution count of outgoing jumps.
uint64_t outCount() const;
@@ -267,7 +291,7 @@ struct ChainT {
size_t numBlocks() const { return Nodes.size(); }
- double density() const { return static_cast<double>(ExecutionCount) / Size; }
+ double density() const { return ExecutionCount / Size; }
bool isEntry() const { return Nodes[0]->Index == 0; }
@@ -280,9 +304,9 @@ struct ChainT {
}
ChainEdge *getEdge(ChainT *Other) const {
- for (auto It : Edges) {
- if (It.first == Other)
- return It.second;
+ for (const auto &[Chain, ChainEdge] : Edges) {
+ if (Chain == Other)
+ return ChainEdge;
}
return nullptr;
}
@@ -302,13 +326,13 @@ struct ChainT {
Edges.push_back(std::make_pair(Other, Edge));
}
- void merge(ChainT *Other, const std::vector<NodeT *> &MergedBlocks) {
- Nodes = MergedBlocks;
- // Update the chain's data
+ void merge(ChainT *Other, std::vector<NodeT *> MergedBlocks) {
+ Nodes = std::move(MergedBlocks);
+ // Update the chain's data.
ExecutionCount += Other->ExecutionCount;
Size += Other->Size;
Id = Nodes[0]->Index;
- // Update the node's data
+ // Update the node's data.
for (size_t Idx = 0; Idx < Nodes.size(); Idx++) {
Nodes[Idx]->CurChain = this;
Nodes[Idx]->CurIndex = Idx;
@@ -328,8 +352,9 @@ struct ChainT {
uint64_t Id;
// Cached ext-tsp score for the chain.
double Score{0};
- // The total execution count of the chain.
- uint64_t ExecutionCount{0};
+ // The total execution count of the chain. Since the execution count of
+ // a basic block is uint64_t, we use doubles here to avoid overflow.
+ double ExecutionCount{0};
// The total size of the chain.
uint64_t Size{0};
// Nodes of the chain.
@@ -340,7 +365,7 @@ struct ChainT {
/// An edge in the graph representing jumps between two chains.
/// When nodes are merged into chains, the edges are combined too so that
-/// there is always at most one edge between a pair of chains
+/// there is always at most one edge between a pair of chains.
struct ChainEdge {
ChainEdge(const ChainEdge &) = delete;
ChainEdge(ChainEdge &&) = default;
@@ -424,53 +449,57 @@ private:
bool CacheValidBackward{false};
};
+bool NodeT::isSuccessor(const NodeT *Other) const {
+ for (JumpT *Jump : OutJumps)
+ if (Jump->Target == Other)
+ return true;
+ return false;
+}
+
uint64_t NodeT::outCount() const {
uint64_t Count = 0;
- for (JumpT *Jump : OutJumps) {
+ for (JumpT *Jump : OutJumps)
Count += Jump->ExecutionCount;
- }
return Count;
}
uint64_t NodeT::inCount() const {
uint64_t Count = 0;
- for (JumpT *Jump : InJumps) {
+ for (JumpT *Jump : InJumps)
Count += Jump->ExecutionCount;
- }
return Count;
}
void ChainT::mergeEdges(ChainT *Other) {
- // Update edges adjacent to chain Other
- for (auto EdgeIt : Other->Edges) {
- ChainT *DstChain = EdgeIt.first;
- ChainEdge *DstEdge = EdgeIt.second;
+ // Update edges adjacent to chain Other.
+ for (const auto &[DstChain, DstEdge] : Other->Edges) {
ChainT *TargetChain = DstChain == Other ? this : DstChain;
ChainEdge *CurEdge = getEdge(TargetChain);
if (CurEdge == nullptr) {
DstEdge->changeEndpoint(Other, this);
this->addEdge(TargetChain, DstEdge);
- if (DstChain != this && DstChain != Other) {
+ if (DstChain != this && DstChain != Other)
DstChain->addEdge(this, DstEdge);
- }
} else {
CurEdge->moveJumps(DstEdge);
}
- // Cleanup leftover edge
- if (DstChain != Other) {
+ // Cleanup leftover edge.
+ if (DstChain != Other)
DstChain->removeEdge(Other);
- }
}
}
using NodeIter = std::vector<NodeT *>::const_iterator;
+static std::vector<NodeT *> EmptyList;
-/// A wrapper around three chains of nodes; it is used to avoid extra
-/// instantiation of the vectors.
-struct MergedChain {
- MergedChain(NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(),
- NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(),
- NodeIter End3 = NodeIter())
+/// A wrapper around three concatenated vectors (chains) of nodes; it is used
+/// to avoid extra instantiation of the vectors.
+struct MergedNodesT {
+ MergedNodesT(NodeIter Begin1, NodeIter End1,
+ NodeIter Begin2 = EmptyList.begin(),
+ NodeIter End2 = EmptyList.end(),
+ NodeIter Begin3 = EmptyList.begin(),
+ NodeIter End3 = EmptyList.end())
: Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
End3(End3) {}
@@ -504,15 +533,35 @@ private:
NodeIter End3;
};
+/// A wrapper around two concatenated vectors (chains) of jumps.
+struct MergedJumpsT {
+ MergedJumpsT(const std::vector<JumpT *> *Jumps1,
+ const std::vector<JumpT *> *Jumps2 = nullptr) {
+ assert(!Jumps1->empty() && "cannot merge empty jump list");
+ JumpArray[0] = Jumps1;
+ JumpArray[1] = Jumps2;
+ }
+
+ template <typename F> void forEach(const F &Func) const {
+ for (auto Jumps : JumpArray)
+ if (Jumps != nullptr)
+ for (JumpT *Jump : *Jumps)
+ Func(Jump);
+ }
+
+private:
+ std::array<const std::vector<JumpT *> *, 2> JumpArray{nullptr, nullptr};
+};
+
/// Merge two chains of nodes respecting a given 'type' and 'offset'.
///
/// If MergeType == 0, then the result is a concatenation of two chains.
/// Otherwise, the first chain is cut into two sub-chains at the offset,
/// and merged using all possible ways of concatenating three chains.
-MergedChain mergeNodes(const std::vector<NodeT *> &X,
- const std::vector<NodeT *> &Y, size_t MergeOffset,
- MergeTypeT MergeType) {
- // Split the first chain, X, into X1 and X2
+MergedNodesT mergeNodes(const std::vector<NodeT *> &X,
+ const std::vector<NodeT *> &Y, size_t MergeOffset,
+ MergeTypeT MergeType) {
+ // Split the first chain, X, into X1 and X2.
NodeIter BeginX1 = X.begin();
NodeIter EndX1 = X.begin() + MergeOffset;
NodeIter BeginX2 = X.begin() + MergeOffset;
@@ -520,18 +569,18 @@ MergedChain mergeNodes(const std::vector<NodeT *> &X,
NodeIter BeginY = Y.begin();
NodeIter EndY = Y.end();
- // Construct a new chain from the three existing ones
+ // Construct a new chain from the three existing ones.
switch (MergeType) {
case MergeTypeT::X_Y:
- return MergedChain(BeginX1, EndX2, BeginY, EndY);
+ return MergedNodesT(BeginX1, EndX2, BeginY, EndY);
case MergeTypeT::Y_X:
- return MergedChain(BeginY, EndY, BeginX1, EndX2);
+ return MergedNodesT(BeginY, EndY, BeginX1, EndX2);
case MergeTypeT::X1_Y_X2:
- return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
+ return MergedNodesT(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
case MergeTypeT::Y_X2_X1:
- return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
+ return MergedNodesT(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
case MergeTypeT::X2_X1_Y:
- return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
+ return MergedNodesT(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
}
llvm_unreachable("unexpected chain merge type");
}
@@ -539,15 +588,14 @@ MergedChain mergeNodes(const std::vector<NodeT *> &X,
/// The implementation of the ExtTSP algorithm.
class ExtTSPImpl {
public:
- ExtTSPImpl(const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<EdgeCountT> &EdgeCounts)
+ ExtTSPImpl(ArrayRef<uint64_t> NodeSizes, ArrayRef<uint64_t> NodeCounts,
+ ArrayRef<EdgeCount> EdgeCounts)
: NumNodes(NodeSizes.size()) {
initialize(NodeSizes, NodeCounts, EdgeCounts);
}
/// Run the algorithm and return an optimized ordering of nodes.
- void run(std::vector<uint64_t> &Result) {
+ std::vector<uint64_t> run() {
// Pass 1: Merge nodes with their mutually forced successors
mergeForcedPairs();
@@ -558,78 +606,80 @@ public:
mergeColdChains();
// Collect nodes from all chains
- concatChains(Result);
+ return concatChains();
}
private:
/// Initialize the algorithm's data structures.
- void initialize(const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<EdgeCountT> &EdgeCounts) {
- // Initialize nodes
+ void initialize(const ArrayRef<uint64_t> &NodeSizes,
+ const ArrayRef<uint64_t> &NodeCounts,
+ const ArrayRef<EdgeCount> &EdgeCounts) {
+ // Initialize nodes.
AllNodes.reserve(NumNodes);
for (uint64_t Idx = 0; Idx < NumNodes; Idx++) {
uint64_t Size = std::max<uint64_t>(NodeSizes[Idx], 1ULL);
uint64_t ExecutionCount = NodeCounts[Idx];
- // The execution count of the entry node is set to at least one
+ // The execution count of the entry node is set to at least one.
if (Idx == 0 && ExecutionCount == 0)
ExecutionCount = 1;
AllNodes.emplace_back(Idx, Size, ExecutionCount);
}
- // Initialize jumps between nodes
+ // Initialize jumps between the nodes.
SuccNodes.resize(NumNodes);
PredNodes.resize(NumNodes);
std::vector<uint64_t> OutDegree(NumNodes, 0);
AllJumps.reserve(EdgeCounts.size());
- for (auto It : EdgeCounts) {
- uint64_t Pred = It.first.first;
- uint64_t Succ = It.first.second;
- OutDegree[Pred]++;
- // Ignore self-edges
- if (Pred == Succ)
+ for (auto Edge : EdgeCounts) {
+ ++OutDegree[Edge.src];
+ // Ignore self-edges.
+ if (Edge.src == Edge.dst)
continue;
- SuccNodes[Pred].push_back(Succ);
- PredNodes[Succ].push_back(Pred);
- uint64_t ExecutionCount = It.second;
- if (ExecutionCount > 0) {
- NodeT &PredNode = AllNodes[Pred];
- NodeT &SuccNode = AllNodes[Succ];
- AllJumps.emplace_back(&PredNode, &SuccNode, ExecutionCount);
+ SuccNodes[Edge.src].push_back(Edge.dst);
+ PredNodes[Edge.dst].push_back(Edge.src);
+ if (Edge.count > 0) {
+ NodeT &PredNode = AllNodes[Edge.src];
+ NodeT &SuccNode = AllNodes[Edge.dst];
+ AllJumps.emplace_back(&PredNode, &SuccNode, Edge.count);
SuccNode.InJumps.push_back(&AllJumps.back());
PredNode.OutJumps.push_back(&AllJumps.back());
+ // Adjust execution counts.
+ PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Edge.count);
+ SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Edge.count);
}
}
for (JumpT &Jump : AllJumps) {
- assert(OutDegree[Jump.Source->Index] > 0);
+ assert(OutDegree[Jump.Source->Index] > 0 &&
+ "incorrectly computed out-degree of the block");
Jump.IsConditional = OutDegree[Jump.Source->Index] > 1;
}
- // Initialize chains
+ // Initialize chains.
AllChains.reserve(NumNodes);
HotChains.reserve(NumNodes);
for (NodeT &Node : AllNodes) {
+ // Create a chain.
AllChains.emplace_back(Node.Index, &Node);
Node.CurChain = &AllChains.back();
- if (Node.ExecutionCount > 0) {
+ if (Node.ExecutionCount > 0)
HotChains.push_back(&AllChains.back());
- }
}
- // Initialize chain edges
+ // Initialize chain edges.
AllEdges.reserve(AllJumps.size());
for (NodeT &PredNode : AllNodes) {
for (JumpT *Jump : PredNode.OutJumps) {
+ assert(Jump->ExecutionCount > 0 && "incorrectly initialized jump");
NodeT *SuccNode = Jump->Target;
ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
- // this edge is already present in the graph
+ // This edge is already present in the graph.
if (CurEdge != nullptr) {
assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
CurEdge->appendJump(Jump);
continue;
}
- // this is a new edge
+ // This is a new edge.
AllEdges.emplace_back(Jump);
PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
@@ -642,7 +692,7 @@ private:
/// to B are from A. Such nodes should be adjacent in the optimal ordering;
/// the method finds and merges such pairs of nodes.
void mergeForcedPairs() {
- // Find fallthroughs based on edge weights
+ // Find forced pairs of blocks.
for (NodeT &Node : AllNodes) {
if (SuccNodes[Node.Index].size() == 1 &&
PredNodes[SuccNodes[Node.Index][0]].size() == 1 &&
@@ -669,12 +719,12 @@ private:
}
if (SuccNode == nullptr)
continue;
- // Break the cycle
+ // Break the cycle.
AllNodes[Node.ForcedPred->Index].ForcedSucc = nullptr;
Node.ForcedPred = nullptr;
}
- // Merge nodes with their fallthrough successors
+ // Merge nodes with their fallthrough successors.
for (NodeT &Node : AllNodes) {
if (Node.ForcedPred == nullptr && Node.ForcedSucc != nullptr) {
const NodeT *CurBlock = &Node;
@@ -689,33 +739,42 @@ private:
/// Merge pairs of chains while improving the ExtTSP objective.
void mergeChainPairs() {
- /// Deterministically compare pairs of chains
+ /// Deterministically compare pairs of chains.
auto compareChainPairs = [](const ChainT *A1, const ChainT *B1,
const ChainT *A2, const ChainT *B2) {
- if (A1 != A2)
- return A1->Id < A2->Id;
- return B1->Id < B2->Id;
+ return std::make_tuple(A1->Id, B1->Id) < std::make_tuple(A2->Id, B2->Id);
};
while (HotChains.size() > 1) {
ChainT *BestChainPred = nullptr;
ChainT *BestChainSucc = nullptr;
MergeGainT BestGain;
- // Iterate over all pairs of chains
+ // Iterate over all pairs of chains.
for (ChainT *ChainPred : HotChains) {
- // Get candidates for merging with the current chain
- for (auto EdgeIt : ChainPred->Edges) {
- ChainT *ChainSucc = EdgeIt.first;
- ChainEdge *Edge = EdgeIt.second;
- // Ignore loop edges
- if (ChainPred == ChainSucc)
+ // Get candidates for merging with the current chain.
+ for (const auto &[ChainSucc, Edge] : ChainPred->Edges) {
+ // Ignore loop edges.
+ if (Edge->isSelfEdge())
continue;
-
- // Stop early if the combined chain violates the maximum allowed size
+ // Skip the merge if the combined chain violates the maximum specified
+ // size.
if (ChainPred->numBlocks() + ChainSucc->numBlocks() >= MaxChainSize)
continue;
+ // Don't merge the chains if they have vastly different densities.
+ // Skip the merge if the ratio between the densities exceeds
+ // MaxMergeDensityRatio. Smaller values of the option result in fewer
+ // merges, and hence, more chains.
+ const double ChainPredDensity = ChainPred->density();
+ const double ChainSuccDensity = ChainSucc->density();
+ assert(ChainPredDensity > 0.0 && ChainSuccDensity > 0.0 &&
+ "incorrectly computed chain densities");
+ auto [MinDensity, MaxDensity] =
+ std::minmax(ChainPredDensity, ChainSuccDensity);
+ const double Ratio = MaxDensity / MinDensity;
+ if (Ratio > MaxMergeDensityRatio)
+ continue;
- // Compute the gain of merging the two chains
+ // Compute the gain of merging the two chains.
MergeGainT CurGain = getBestMergeGain(ChainPred, ChainSucc, Edge);
if (CurGain.score() <= EPS)
continue;
@@ -731,11 +790,11 @@ private:
}
}
- // Stop merging when there is no improvement
+ // Stop merging when there is no improvement.
if (BestGain.score() <= EPS)
break;
- // Merge the best pair of chains
+ // Merge the best pair of chains.
mergeChains(BestChainPred, BestChainSucc, BestGain.mergeOffset(),
BestGain.mergeType());
}
@@ -743,7 +802,7 @@ private:
/// Merge remaining nodes into chains w/o taking jump counts into
/// consideration. This allows to maintain the original node order in the
- /// absence of profile data
+ /// absence of profile data.
void mergeColdChains() {
for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) {
// Iterating in reverse order to make sure original fallthrough jumps are
@@ -764,24 +823,22 @@ private:
}
/// Compute the Ext-TSP score for a given node order and a list of jumps.
- double extTSPScore(const MergedChain &MergedBlocks,
- const std::vector<JumpT *> &Jumps) const {
- if (Jumps.empty())
- return 0.0;
+ double extTSPScore(const MergedNodesT &Nodes,
+ const MergedJumpsT &Jumps) const {
uint64_t CurAddr = 0;
- MergedBlocks.forEach([&](const NodeT *Node) {
+ Nodes.forEach([&](const NodeT *Node) {
Node->EstimatedAddr = CurAddr;
CurAddr += Node->Size;
});
double Score = 0;
- for (JumpT *Jump : Jumps) {
+ Jumps.forEach([&](const JumpT *Jump) {
const NodeT *SrcBlock = Jump->Source;
const NodeT *DstBlock = Jump->Target;
Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size,
DstBlock->EstimatedAddr, Jump->ExecutionCount,
Jump->IsConditional);
- }
+ });
return Score;
}
@@ -793,74 +850,76 @@ private:
/// element being the corresponding merging type.
MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
ChainEdge *Edge) const {
- if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) {
+ if (Edge->hasCachedMergeGain(ChainPred, ChainSucc))
return Edge->getCachedMergeGain(ChainPred, ChainSucc);
- }
- // Precompute jumps between ChainPred and ChainSucc
- auto Jumps = Edge->jumps();
+ assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
+ // Precompute jumps between ChainPred and ChainSucc.
ChainEdge *EdgePP = ChainPred->getEdge(ChainPred);
- if (EdgePP != nullptr) {
- Jumps.insert(Jumps.end(), EdgePP->jumps().begin(), EdgePP->jumps().end());
- }
- assert(!Jumps.empty() && "trying to merge chains w/o jumps");
+ MergedJumpsT Jumps(&Edge->jumps(), EdgePP ? &EdgePP->jumps() : nullptr);
- // The object holds the best currently chosen gain of merging the two chains
+ // This object holds the best chosen gain of merging two chains.
MergeGainT Gain = MergeGainT();
/// Given a merge offset and a list of merge types, try to merge two chains
- /// and update Gain with a better alternative
+ /// and update Gain with a better alternative.
auto tryChainMerging = [&](size_t Offset,
const std::vector<MergeTypeT> &MergeTypes) {
- // Skip merging corresponding to concatenation w/o splitting
+ // Skip merging corresponding to concatenation w/o splitting.
if (Offset == 0 || Offset == ChainPred->Nodes.size())
return;
- // Skip merging if it breaks Forced successors
+ // Skip merging if it breaks Forced successors.
NodeT *Node = ChainPred->Nodes[Offset - 1];
if (Node->ForcedSucc != nullptr)
return;
// Apply the merge, compute the corresponding gain, and update the best
- // value, if the merge is beneficial
+ // value, if the merge is beneficial.
for (const MergeTypeT &MergeType : MergeTypes) {
Gain.updateIfLessThan(
computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType));
}
};
- // Try to concatenate two chains w/o splitting
+ // Try to concatenate two chains w/o splitting.
Gain.updateIfLessThan(
computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeT::X_Y));
- if (EnableChainSplitAlongJumps) {
- // Attach (a part of) ChainPred before the first node of ChainSucc
- for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) {
- const NodeT *SrcBlock = Jump->Source;
- if (SrcBlock->CurChain != ChainPred)
- continue;
- size_t Offset = SrcBlock->CurIndex + 1;
- tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y});
- }
+ // Attach (a part of) ChainPred before the first node of ChainSucc.
+ for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) {
+ const NodeT *SrcBlock = Jump->Source;
+ if (SrcBlock->CurChain != ChainPred)
+ continue;
+ size_t Offset = SrcBlock->CurIndex + 1;
+ tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y});
+ }
- // Attach (a part of) ChainPred after the last node of ChainSucc
- for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) {
- const NodeT *DstBlock = Jump->Source;
- if (DstBlock->CurChain != ChainPred)
- continue;
- size_t Offset = DstBlock->CurIndex;
- tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1});
- }
+ // Attach (a part of) ChainPred after the last node of ChainSucc.
+ for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) {
+ const NodeT *DstBlock = Jump->Target;
+ if (DstBlock->CurChain != ChainPred)
+ continue;
+ size_t Offset = DstBlock->CurIndex;
+ tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1});
}
- // Try to break ChainPred in various ways and concatenate with ChainSucc
+ // Try to break ChainPred in various ways and concatenate with ChainSucc.
if (ChainPred->Nodes.size() <= ChainSplitThreshold) {
for (size_t Offset = 1; Offset < ChainPred->Nodes.size(); Offset++) {
- // Try to split the chain in different ways. In practice, applying
- // X2_Y_X1 merging is almost never provides benefits; thus, we exclude
- // it from consideration to reduce the search space
+ // Do not split the chain along a fall-through jump. One of the two
+ // loops above may still "break" such a jump whenever it results in a
+ // new fall-through.
+ const NodeT *BB = ChainPred->Nodes[Offset - 1];
+ const NodeT *BB2 = ChainPred->Nodes[Offset];
+ if (BB->isSuccessor(BB2))
+ continue;
+
+ // In practice, applying X2_Y_X1 merging almost never provides benefits;
+ // thus, we exclude it from consideration to reduce the search space.
tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1,
MergeTypeT::X2_X1_Y});
}
}
+
Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain);
return Gain;
}
@@ -870,19 +929,20 @@ private:
///
/// The two chains are not modified in the method.
MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc,
- const std::vector<JumpT *> &Jumps,
- size_t MergeOffset, MergeTypeT MergeType) const {
- auto MergedBlocks =
+ const MergedJumpsT &Jumps, size_t MergeOffset,
+ MergeTypeT MergeType) const {
+ MergedNodesT MergedNodes =
mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
- // Do not allow a merge that does not preserve the original entry point
+ // Do not allow a merge that does not preserve the original entry point.
if ((ChainPred->isEntry() || ChainSucc->isEntry()) &&
- !MergedBlocks.getFirstNode()->isEntry())
+ !MergedNodes.getFirstNode()->isEntry())
return MergeGainT();
- // The gain for the new chain
- auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->Score;
- return MergeGainT(NewGainScore, MergeOffset, MergeType);
+ // The gain for the new chain.
+ double NewScore = extTSPScore(MergedNodes, Jumps);
+ double CurScore = ChainPred->Score;
+ return MergeGainT(NewScore - CurScore, MergeOffset, MergeType);
}
/// Merge chain From into chain Into, update the list of active chains,
@@ -891,39 +951,398 @@ private:
MergeTypeT MergeType) {
assert(Into != From && "a chain cannot be merged with itself");
- // Merge the nodes
- MergedChain MergedNodes =
+ // Merge the nodes.
+ MergedNodesT MergedNodes =
mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
Into->merge(From, MergedNodes.getNodes());
- // Merge the edges
+ // Merge the edges.
Into->mergeEdges(From);
From->clear();
- // Update cached ext-tsp score for the new chain
+ // Update cached ext-tsp score for the new chain.
ChainEdge *SelfEdge = Into->getEdge(Into);
if (SelfEdge != nullptr) {
- MergedNodes = MergedChain(Into->Nodes.begin(), Into->Nodes.end());
- Into->Score = extTSPScore(MergedNodes, SelfEdge->jumps());
+ MergedNodes = MergedNodesT(Into->Nodes.begin(), Into->Nodes.end());
+ MergedJumpsT MergedJumps(&SelfEdge->jumps());
+ Into->Score = extTSPScore(MergedNodes, MergedJumps);
}
- // Remove the chain from the list of active chains
- llvm::erase_value(HotChains, From);
+ // Remove the chain from the list of active chains.
+ llvm::erase(HotChains, From);
- // Invalidate caches
+ // Invalidate caches.
for (auto EdgeIt : Into->Edges)
EdgeIt.second->invalidateCache();
}
/// Concatenate all chains into the final order.
- void concatChains(std::vector<uint64_t> &Order) {
- // Collect chains and calculate density stats for their sorting
+ std::vector<uint64_t> concatChains() {
+ // Collect non-empty chains.
+ std::vector<const ChainT *> SortedChains;
+ for (ChainT &Chain : AllChains) {
+ if (!Chain.Nodes.empty())
+ SortedChains.push_back(&Chain);
+ }
+
+ // Sorting chains by density in the decreasing order.
+ std::sort(SortedChains.begin(), SortedChains.end(),
+ [&](const ChainT *L, const ChainT *R) {
+ // Place the entry point at the beginning of the order.
+ if (L->isEntry() != R->isEntry())
+ return L->isEntry();
+
+ // Compare by density and break ties by chain identifiers.
+ return std::make_tuple(-L->density(), L->Id) <
+ std::make_tuple(-R->density(), R->Id);
+ });
+
+ // Collect the nodes in the order specified by their chains.
+ std::vector<uint64_t> Order;
+ Order.reserve(NumNodes);
+ for (const ChainT *Chain : SortedChains)
+ for (NodeT *Node : Chain->Nodes)
+ Order.push_back(Node->Index);
+ return Order;
+ }
+
+private:
+ /// The number of nodes in the graph.
+ const size_t NumNodes;
+
+ /// Successors of each node.
+ std::vector<std::vector<uint64_t>> SuccNodes;
+
+ /// Predecessors of each node.
+ std::vector<std::vector<uint64_t>> PredNodes;
+
+ /// All nodes (basic blocks) in the graph.
+ std::vector<NodeT> AllNodes;
+
+ /// All jumps between the nodes.
+ std::vector<JumpT> AllJumps;
+
+ /// All chains of nodes.
+ std::vector<ChainT> AllChains;
+
+ /// All edges between the chains.
+ std::vector<ChainEdge> AllEdges;
+
+ /// Active chains. The vector gets updated at runtime when chains are merged.
+ std::vector<ChainT *> HotChains;
+};
+
+/// The implementation of the Cache-Directed Sort (CDSort) algorithm for
+/// ordering functions represented by a call graph.
+class CDSortImpl {
+public:
+ CDSortImpl(const CDSortConfig &Config, ArrayRef<uint64_t> NodeSizes,
+ ArrayRef<uint64_t> NodeCounts, ArrayRef<EdgeCount> EdgeCounts,
+ ArrayRef<uint64_t> EdgeOffsets)
+ : Config(Config), NumNodes(NodeSizes.size()) {
+ initialize(NodeSizes, NodeCounts, EdgeCounts, EdgeOffsets);
+ }
+
+ /// Run the algorithm and return an ordered set of function clusters.
+ std::vector<uint64_t> run() {
+ // Merge pairs of chains while improving the objective.
+ mergeChainPairs();
+
+ // Collect nodes from all the chains.
+ return concatChains();
+ }
+
+private:
+ /// Initialize the algorithm's data structures.
+ void initialize(const ArrayRef<uint64_t> &NodeSizes,
+ const ArrayRef<uint64_t> &NodeCounts,
+ const ArrayRef<EdgeCount> &EdgeCounts,
+ const ArrayRef<uint64_t> &EdgeOffsets) {
+ // Initialize nodes.
+ AllNodes.reserve(NumNodes);
+ for (uint64_t Node = 0; Node < NumNodes; Node++) {
+ uint64_t Size = std::max<uint64_t>(NodeSizes[Node], 1ULL);
+ uint64_t ExecutionCount = NodeCounts[Node];
+ AllNodes.emplace_back(Node, Size, ExecutionCount);
+ TotalSamples += ExecutionCount;
+ if (ExecutionCount > 0)
+ TotalSize += Size;
+ }
+
+ // Initialize jumps between the nodes.
+ SuccNodes.resize(NumNodes);
+ PredNodes.resize(NumNodes);
+ AllJumps.reserve(EdgeCounts.size());
+ for (size_t I = 0; I < EdgeCounts.size(); I++) {
+ auto [Pred, Succ, Count] = EdgeCounts[I];
+ // Ignore recursive calls.
+ if (Pred == Succ)
+ continue;
+
+ SuccNodes[Pred].push_back(Succ);
+ PredNodes[Succ].push_back(Pred);
+ if (Count > 0) {
+ NodeT &PredNode = AllNodes[Pred];
+ NodeT &SuccNode = AllNodes[Succ];
+ AllJumps.emplace_back(&PredNode, &SuccNode, Count);
+ AllJumps.back().Offset = EdgeOffsets[I];
+ SuccNode.InJumps.push_back(&AllJumps.back());
+ PredNode.OutJumps.push_back(&AllJumps.back());
+ // Adjust execution counts.
+ PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Count);
+ SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Count);
+ }
+ }
+
+ // Initialize chains.
+ AllChains.reserve(NumNodes);
+ for (NodeT &Node : AllNodes) {
+ // Adjust execution counts.
+ Node.ExecutionCount = std::max(Node.ExecutionCount, Node.inCount());
+ Node.ExecutionCount = std::max(Node.ExecutionCount, Node.outCount());
+ // Create chain.
+ AllChains.emplace_back(Node.Index, &Node);
+ Node.CurChain = &AllChains.back();
+ }
+
+ // Initialize chain edges.
+ AllEdges.reserve(AllJumps.size());
+ for (NodeT &PredNode : AllNodes) {
+ for (JumpT *Jump : PredNode.OutJumps) {
+ NodeT *SuccNode = Jump->Target;
+ ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
+ // This edge is already present in the graph.
+ if (CurEdge != nullptr) {
+ assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
+ CurEdge->appendJump(Jump);
+ continue;
+ }
+ // This is a new edge.
+ AllEdges.emplace_back(Jump);
+ PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
+ SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
+ }
+ }
+ }
+
+ /// Merge pairs of chains while there is an improvement in the objective.
+ void mergeChainPairs() {
+ // Create a priority queue containing all edges ordered by the merge gain.
+ auto GainComparator = [](ChainEdge *L, ChainEdge *R) {
+ return std::make_tuple(-L->gain(), L->srcChain()->Id, L->dstChain()->Id) <
+ std::make_tuple(-R->gain(), R->srcChain()->Id, R->dstChain()->Id);
+ };
+ std::set<ChainEdge *, decltype(GainComparator)> Queue(GainComparator);
+
+ // Insert the edges into the queue.
+ [[maybe_unused]] size_t NumActiveChains = 0;
+ for (NodeT &Node : AllNodes) {
+ if (Node.ExecutionCount == 0)
+ continue;
+ ++NumActiveChains;
+ for (const auto &[_, Edge] : Node.CurChain->Edges) {
+ // Ignore self-edges.
+ if (Edge->isSelfEdge())
+ continue;
+ // Ignore already processed edges.
+ if (Edge->gain() != -1.0)
+ continue;
+
+ // Compute the gain of merging the two chains.
+ MergeGainT Gain = getBestMergeGain(Edge);
+ Edge->setMergeGain(Gain);
+
+ if (Edge->gain() > EPS)
+ Queue.insert(Edge);
+ }
+ }
+
+ // Merge the chains while the gain of merging is positive.
+ while (!Queue.empty()) {
+ // Extract the best (top) edge for merging.
+ ChainEdge *BestEdge = *Queue.begin();
+ Queue.erase(Queue.begin());
+ ChainT *BestSrcChain = BestEdge->srcChain();
+ ChainT *BestDstChain = BestEdge->dstChain();
+
+ // Remove outdated edges from the queue.
+ for (const auto &[_, ChainEdge] : BestSrcChain->Edges)
+ Queue.erase(ChainEdge);
+ for (const auto &[_, ChainEdge] : BestDstChain->Edges)
+ Queue.erase(ChainEdge);
+
+ // Merge the best pair of chains.
+ MergeGainT BestGain = BestEdge->getMergeGain();
+ mergeChains(BestSrcChain, BestDstChain, BestGain.mergeOffset(),
+ BestGain.mergeType());
+ --NumActiveChains;
+
+ // Insert newly created edges into the queue.
+ for (const auto &[_, Edge] : BestSrcChain->Edges) {
+ // Ignore loop edges.
+ if (Edge->isSelfEdge())
+ continue;
+ if (Edge->srcChain()->numBlocks() + Edge->dstChain()->numBlocks() >
+ Config.MaxChainSize)
+ continue;
+
+ // Compute the gain of merging the two chains.
+ MergeGainT Gain = getBestMergeGain(Edge);
+ Edge->setMergeGain(Gain);
+
+ if (Edge->gain() > EPS)
+ Queue.insert(Edge);
+ }
+ }
+
+ LLVM_DEBUG(dbgs() << "Cache-directed function sorting reduced the number"
+ << " of chains from " << NumNodes << " to "
+ << NumActiveChains << "\n");
+ }
+
+ /// Compute the gain of merging two chains.
+ ///
+ /// The function considers all possible ways of merging two chains and
+ /// computes the one having the largest increase in ExtTSP objective. The
+ /// result is a pair with the first element being the gain and the second
+ /// element being the corresponding merging type.
+ MergeGainT getBestMergeGain(ChainEdge *Edge) const {
+ assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
+ // Precompute jumps between ChainPred and ChainSucc.
+ MergedJumpsT Jumps(&Edge->jumps());
+ ChainT *SrcChain = Edge->srcChain();
+ ChainT *DstChain = Edge->dstChain();
+
+ // This object holds the best currently chosen gain of merging two chains.
+ MergeGainT Gain = MergeGainT();
+
+ /// Given a list of merge types, try to merge two chains and update Gain
+ /// with a better alternative.
+ auto tryChainMerging = [&](const std::vector<MergeTypeT> &MergeTypes) {
+ // Apply the merge, compute the corresponding gain, and update the best
+ // value, if the merge is beneficial.
+ for (const MergeTypeT &MergeType : MergeTypes) {
+ MergeGainT NewGain =
+ computeMergeGain(SrcChain, DstChain, Jumps, MergeType);
+
+ // When forward and backward gains are the same, prioritize merging that
+ // preserves the original order of the functions in the binary.
+ if (std::abs(Gain.score() - NewGain.score()) < EPS) {
+ if ((MergeType == MergeTypeT::X_Y && SrcChain->Id < DstChain->Id) ||
+ (MergeType == MergeTypeT::Y_X && SrcChain->Id > DstChain->Id)) {
+ Gain = NewGain;
+ }
+ } else if (NewGain.score() > Gain.score() + EPS) {
+ Gain = NewGain;
+ }
+ }
+ };
+
+ // Try to concatenate two chains w/o splitting.
+ tryChainMerging({MergeTypeT::X_Y, MergeTypeT::Y_X});
+
+ return Gain;
+ }
+
+ /// Compute the score gain of merging two chains, respecting a given type.
+ ///
+ /// The two chains are not modified in the method.
+ MergeGainT computeMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
+ const MergedJumpsT &Jumps,
+ MergeTypeT MergeType) const {
+ // This doesn't depend on the ordering of the nodes
+ double FreqGain = freqBasedLocalityGain(ChainPred, ChainSucc);
+
+ // Merge offset is always 0, as the chains are not split.
+ size_t MergeOffset = 0;
+ auto MergedBlocks =
+ mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
+ double DistGain = distBasedLocalityGain(MergedBlocks, Jumps);
+
+ double GainScore = DistGain + Config.FrequencyScale * FreqGain;
+ // Scale the result to increase the importance of merging short chains.
+ if (GainScore >= 0.0)
+ GainScore /= std::min(ChainPred->Size, ChainSucc->Size);
+
+ return MergeGainT(GainScore, MergeOffset, MergeType);
+ }
+
+ /// Compute the change of the frequency locality after merging the chains.
+ double freqBasedLocalityGain(ChainT *ChainPred, ChainT *ChainSucc) const {
+ auto missProbability = [&](double ChainDensity) {
+ double PageSamples = ChainDensity * Config.CacheSize;
+ if (PageSamples >= TotalSamples)
+ return 0.0;
+ double P = PageSamples / TotalSamples;
+ return pow(1.0 - P, static_cast<double>(Config.CacheEntries));
+ };
+
+ // Cache misses on the chains before merging.
+ double CurScore =
+ ChainPred->ExecutionCount * missProbability(ChainPred->density()) +
+ ChainSucc->ExecutionCount * missProbability(ChainSucc->density());
+
+ // Cache misses on the merged chain
+ double MergedCounts = ChainPred->ExecutionCount + ChainSucc->ExecutionCount;
+ double MergedSize = ChainPred->Size + ChainSucc->Size;
+ double MergedDensity = static_cast<double>(MergedCounts) / MergedSize;
+ double NewScore = MergedCounts * missProbability(MergedDensity);
+
+ return CurScore - NewScore;
+ }
+
+ /// Compute the distance locality for a jump / call.
+ double distScore(uint64_t SrcAddr, uint64_t DstAddr, uint64_t Count) const {
+ uint64_t Dist = SrcAddr <= DstAddr ? DstAddr - SrcAddr : SrcAddr - DstAddr;
+ double D = Dist == 0 ? 0.1 : static_cast<double>(Dist);
+ return static_cast<double>(Count) * std::pow(D, -Config.DistancePower);
+ }
+
+ /// Compute the change of the distance locality after merging the chains.
+ double distBasedLocalityGain(const MergedNodesT &Nodes,
+ const MergedJumpsT &Jumps) const {
+ uint64_t CurAddr = 0;
+ Nodes.forEach([&](const NodeT *Node) {
+ Node->EstimatedAddr = CurAddr;
+ CurAddr += Node->Size;
+ });
+
+ double CurScore = 0;
+ double NewScore = 0;
+ Jumps.forEach([&](const JumpT *Jump) {
+ uint64_t SrcAddr = Jump->Source->EstimatedAddr + Jump->Offset;
+ uint64_t DstAddr = Jump->Target->EstimatedAddr;
+ NewScore += distScore(SrcAddr, DstAddr, Jump->ExecutionCount);
+ CurScore += distScore(0, TotalSize, Jump->ExecutionCount);
+ });
+ return NewScore - CurScore;
+ }
+
+ /// Merge chain From into chain Into, update the list of active chains,
+ /// adjacency information, and the corresponding cached values.
+ void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset,
+ MergeTypeT MergeType) {
+ assert(Into != From && "a chain cannot be merged with itself");
+
+ // Merge the nodes.
+ MergedNodesT MergedNodes =
+ mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
+ Into->merge(From, MergedNodes.getNodes());
+
+ // Merge the edges.
+ Into->mergeEdges(From);
+ From->clear();
+ }
+
+ /// Concatenate all chains into the final order.
+ std::vector<uint64_t> concatChains() {
+ // Collect chains and calculate density stats for their sorting.
std::vector<const ChainT *> SortedChains;
DenseMap<const ChainT *, double> ChainDensity;
for (ChainT &Chain : AllChains) {
if (!Chain.Nodes.empty()) {
SortedChains.push_back(&Chain);
- // Using doubles to avoid overflow of ExecutionCounts
+ // Using doubles to avoid overflow of ExecutionCounts.
double Size = 0;
double ExecutionCount = 0;
for (NodeT *Node : Chain.Nodes) {
@@ -935,30 +1354,29 @@ private:
}
}
- // Sorting chains by density in the decreasing order
- std::stable_sort(SortedChains.begin(), SortedChains.end(),
- [&](const ChainT *L, const ChainT *R) {
- // Make sure the original entry point is at the
- // beginning of the order
- if (L->isEntry() != R->isEntry())
- return L->isEntry();
-
- const double DL = ChainDensity[L];
- const double DR = ChainDensity[R];
- // Compare by density and break ties by chain identifiers
- return (DL != DR) ? (DL > DR) : (L->Id < R->Id);
- });
+ // Sort chains by density in the decreasing order.
+ std::sort(SortedChains.begin(), SortedChains.end(),
+ [&](const ChainT *L, const ChainT *R) {
+ const double DL = ChainDensity[L];
+ const double DR = ChainDensity[R];
+ // Compare by density and break ties by chain identifiers.
+ return std::make_tuple(-DL, L->Id) <
+ std::make_tuple(-DR, R->Id);
+ });
- // Collect the nodes in the order specified by their chains
+ // Collect the nodes in the order specified by their chains.
+ std::vector<uint64_t> Order;
Order.reserve(NumNodes);
- for (const ChainT *Chain : SortedChains) {
- for (NodeT *Node : Chain->Nodes) {
+ for (const ChainT *Chain : SortedChains)
+ for (NodeT *Node : Chain->Nodes)
Order.push_back(Node->Index);
- }
- }
+ return Order;
}
private:
+ /// Config for the algorithm.
+ const CDSortConfig Config;
+
/// The number of nodes in the graph.
const size_t NumNodes;
@@ -968,10 +1386,10 @@ private:
/// Predecessors of each node.
std::vector<std::vector<uint64_t>> PredNodes;
- /// All nodes (basic blocks) in the graph.
+ /// All nodes (functions) in the graph.
std::vector<NodeT> AllNodes;
- /// All jumps between the nodes.
+ /// All jumps (function calls) between the nodes.
std::vector<JumpT> AllJumps;
/// All chains of nodes.
@@ -980,65 +1398,95 @@ private:
/// All edges between the chains.
std::vector<ChainEdge> AllEdges;
- /// Active chains. The vector gets updated at runtime when chains are merged.
- std::vector<ChainT *> HotChains;
+ /// The total number of samples in the graph.
+ uint64_t TotalSamples{0};
+
+ /// The total size of the nodes in the graph.
+ uint64_t TotalSize{0};
};
} // end of anonymous namespace
std::vector<uint64_t>
-llvm::applyExtTspLayout(const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<EdgeCountT> &EdgeCounts) {
- // Verify correctness of the input data
+codelayout::computeExtTspLayout(ArrayRef<uint64_t> NodeSizes,
+ ArrayRef<uint64_t> NodeCounts,
+ ArrayRef<EdgeCount> EdgeCounts) {
+ // Verify correctness of the input data.
assert(NodeCounts.size() == NodeSizes.size() && "Incorrect input");
assert(NodeSizes.size() > 2 && "Incorrect input");
- // Apply the reordering algorithm
+ // Apply the reordering algorithm.
ExtTSPImpl Alg(NodeSizes, NodeCounts, EdgeCounts);
- std::vector<uint64_t> Result;
- Alg.run(Result);
+ std::vector<uint64_t> Result = Alg.run();
- // Verify correctness of the output
+ // Verify correctness of the output.
assert(Result.front() == 0 && "Original entry point is not preserved");
assert(Result.size() == NodeSizes.size() && "Incorrect size of layout");
return Result;
}
-double llvm::calcExtTspScore(const std::vector<uint64_t> &Order,
- const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<EdgeCountT> &EdgeCounts) {
- // Estimate addresses of the blocks in memory
+double codelayout::calcExtTspScore(ArrayRef<uint64_t> Order,
+ ArrayRef<uint64_t> NodeSizes,
+ ArrayRef<uint64_t> NodeCounts,
+ ArrayRef<EdgeCount> EdgeCounts) {
+ // Estimate addresses of the blocks in memory.
std::vector<uint64_t> Addr(NodeSizes.size(), 0);
for (size_t Idx = 1; Idx < Order.size(); Idx++) {
Addr[Order[Idx]] = Addr[Order[Idx - 1]] + NodeSizes[Order[Idx - 1]];
}
std::vector<uint64_t> OutDegree(NodeSizes.size(), 0);
- for (auto It : EdgeCounts) {
- uint64_t Pred = It.first.first;
- OutDegree[Pred]++;
- }
+ for (auto Edge : EdgeCounts)
+ ++OutDegree[Edge.src];
- // Increase the score for each jump
+ // Increase the score for each jump.
double Score = 0;
- for (auto It : EdgeCounts) {
- uint64_t Pred = It.first.first;
- uint64_t Succ = It.first.second;
- uint64_t Count = It.second;
- bool IsConditional = OutDegree[Pred] > 1;
- Score += ::extTSPScore(Addr[Pred], NodeSizes[Pred], Addr[Succ], Count,
- IsConditional);
+ for (auto Edge : EdgeCounts) {
+ bool IsConditional = OutDegree[Edge.src] > 1;
+ Score += ::extTSPScore(Addr[Edge.src], NodeSizes[Edge.src], Addr[Edge.dst],
+ Edge.count, IsConditional);
}
return Score;
}
-double llvm::calcExtTspScore(const std::vector<uint64_t> &NodeSizes,
- const std::vector<uint64_t> &NodeCounts,
- const std::vector<EdgeCountT> &EdgeCounts) {
+double codelayout::calcExtTspScore(ArrayRef<uint64_t> NodeSizes,
+ ArrayRef<uint64_t> NodeCounts,
+ ArrayRef<EdgeCount> EdgeCounts) {
std::vector<uint64_t> Order(NodeSizes.size());
for (size_t Idx = 0; Idx < NodeSizes.size(); Idx++) {
Order[Idx] = Idx;
}
return calcExtTspScore(Order, NodeSizes, NodeCounts, EdgeCounts);
}
+
+std::vector<uint64_t> codelayout::computeCacheDirectedLayout(
+ const CDSortConfig &Config, ArrayRef<uint64_t> FuncSizes,
+ ArrayRef<uint64_t> FuncCounts, ArrayRef<EdgeCount> CallCounts,
+ ArrayRef<uint64_t> CallOffsets) {
+ // Verify correctness of the input data.
+ assert(FuncCounts.size() == FuncSizes.size() && "Incorrect input");
+
+ // Apply the reordering algorithm.
+ CDSortImpl Alg(Config, FuncSizes, FuncCounts, CallCounts, CallOffsets);
+ std::vector<uint64_t> Result = Alg.run();
+ assert(Result.size() == FuncSizes.size() && "Incorrect size of layout");
+ return Result;
+}
+
+std::vector<uint64_t> codelayout::computeCacheDirectedLayout(
+ ArrayRef<uint64_t> FuncSizes, ArrayRef<uint64_t> FuncCounts,
+ ArrayRef<EdgeCount> CallCounts, ArrayRef<uint64_t> CallOffsets) {
+ CDSortConfig Config;
+ // Populate the config from the command-line options.
+ if (CacheEntries.getNumOccurrences() > 0)
+ Config.CacheEntries = CacheEntries;
+ if (CacheSize.getNumOccurrences() > 0)
+ Config.CacheSize = CacheSize;
+ if (CDMaxChainSize.getNumOccurrences() > 0)
+ Config.MaxChainSize = CDMaxChainSize;
+ if (DistancePower.getNumOccurrences() > 0)
+ Config.DistancePower = DistancePower;
+ if (FrequencyScale.getNumOccurrences() > 0)
+ Config.FrequencyScale = FrequencyScale;
+ return computeCacheDirectedLayout(Config, FuncSizes, FuncCounts, CallCounts,
+ CallOffsets);
+}
diff --git a/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp b/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp
index 4a6719741719..6a2dae5bab68 100644
--- a/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp
@@ -417,7 +417,7 @@ void llvm::moveInstructionsToTheBeginning(BasicBlock &FromBB, BasicBlock &ToBB,
Instruction *MovePos = ToBB.getFirstNonPHIOrDbg();
if (isSafeToMoveBefore(I, *MovePos, DT, &PDT, &DI))
- I.moveBefore(MovePos);
+ I.moveBeforePreserving(MovePos);
}
}
@@ -429,7 +429,7 @@ void llvm::moveInstructionsToTheEnd(BasicBlock &FromBB, BasicBlock &ToBB,
while (FromBB.size() > 1) {
Instruction &I = FromBB.front();
if (isSafeToMoveBefore(I, *MovePos, DT, &PDT, &DI))
- I.moveBefore(MovePos);
+ I.moveBeforePreserving(MovePos);
}
}
diff --git a/llvm/lib/Transforms/Utils/CtorUtils.cpp b/llvm/lib/Transforms/Utils/CtorUtils.cpp
index e07c92df2265..507729bc5ebc 100644
--- a/llvm/lib/Transforms/Utils/CtorUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CtorUtils.cpp
@@ -52,12 +52,9 @@ static void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemov
NGV->takeName(GCL);
// Nuke the old list, replacing any uses with the new one.
- if (!GCL->use_empty()) {
- Constant *V = NGV;
- if (V->getType() != GCL->getType())
- V = ConstantExpr::getBitCast(V, GCL->getType());
- GCL->replaceAllUsesWith(V);
- }
+ if (!GCL->use_empty())
+ GCL->replaceAllUsesWith(NGV);
+
GCL->eraseFromParent();
}
diff --git a/llvm/lib/Transforms/Utils/DXILUpgrade.cpp b/llvm/lib/Transforms/Utils/DXILUpgrade.cpp
new file mode 100644
index 000000000000..735686ddce38
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/DXILUpgrade.cpp
@@ -0,0 +1,36 @@
+//===- DXILUpgrade.cpp - Upgrade DXIL metadata to LLVM constructs ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/DXILUpgrade.h"
+
+using namespace llvm;
+
+static bool handleValVerMetadata(Module &M) {
+ NamedMDNode *ValVer = M.getNamedMetadata("dx.valver");
+ if (!ValVer)
+ return false;
+
+ // We don't need the validation version internally, so we drop it.
+ ValVer->dropAllReferences();
+ ValVer->eraseFromParent();
+ return true;
+}
+
+PreservedAnalyses DXILUpgradePass::run(Module &M, ModuleAnalysisManager &AM) {
+ PreservedAnalyses PA;
+ // We never add, remove, or change functions here.
+ PA.preserve<FunctionAnalysisManagerModuleProxy>();
+ PA.preserveSet<AllAnalysesOn<Function>>();
+
+ bool Changed = false;
+ Changed |= handleValVerMetadata(M);
+
+ if (!Changed)
+ return PreservedAnalyses::all();
+ return PA;
+}
diff --git a/llvm/lib/Transforms/Utils/Debugify.cpp b/llvm/lib/Transforms/Utils/Debugify.cpp
index 93cad0888a56..d0cc603426d2 100644
--- a/llvm/lib/Transforms/Utils/Debugify.cpp
+++ b/llvm/lib/Transforms/Utils/Debugify.cpp
@@ -801,7 +801,15 @@ bool checkDebugifyMetadata(Module &M,
/// legacy module pass manager.
struct DebugifyModulePass : public ModulePass {
bool runOnModule(Module &M) override {
- return applyDebugify(M, Mode, DebugInfoBeforePass, NameOfWrappedPass);
+ bool NewDebugMode = M.IsNewDbgInfoFormat;
+ if (NewDebugMode)
+ M.convertFromNewDbgValues();
+
+ bool Result = applyDebugify(M, Mode, DebugInfoBeforePass, NameOfWrappedPass);
+
+ if (NewDebugMode)
+ M.convertToNewDbgValues();
+ return Result;
}
DebugifyModulePass(enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo,
@@ -826,7 +834,15 @@ private:
/// single function, used with the legacy module pass manager.
struct DebugifyFunctionPass : public FunctionPass {
bool runOnFunction(Function &F) override {
- return applyDebugify(F, Mode, DebugInfoBeforePass, NameOfWrappedPass);
+ bool NewDebugMode = F.IsNewDbgInfoFormat;
+ if (NewDebugMode)
+ F.convertFromNewDbgValues();
+
+ bool Result = applyDebugify(F, Mode, DebugInfoBeforePass, NameOfWrappedPass);
+
+ if (NewDebugMode)
+ F.convertToNewDbgValues();
+ return Result;
}
DebugifyFunctionPass(
@@ -852,13 +868,24 @@ private:
/// legacy module pass manager.
struct CheckDebugifyModulePass : public ModulePass {
bool runOnModule(Module &M) override {
+ bool NewDebugMode = M.IsNewDbgInfoFormat;
+ if (NewDebugMode)
+ M.convertFromNewDbgValues();
+
+ bool Result;
if (Mode == DebugifyMode::SyntheticDebugInfo)
- return checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass,
+ Result = checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass,
"CheckModuleDebugify", Strip, StatsMap);
- return checkDebugInfoMetadata(
+ else
+ Result = checkDebugInfoMetadata(
M, M.functions(), *DebugInfoBeforePass,
"CheckModuleDebugify (original debuginfo)", NameOfWrappedPass,
OrigDIVerifyBugsReportFilePath);
+
+ if (NewDebugMode)
+ M.convertToNewDbgValues();
+
+ return Result;
}
CheckDebugifyModulePass(
@@ -891,16 +918,26 @@ private:
/// with the legacy module pass manager.
struct CheckDebugifyFunctionPass : public FunctionPass {
bool runOnFunction(Function &F) override {
+ bool NewDebugMode = F.IsNewDbgInfoFormat;
+ if (NewDebugMode)
+ F.convertFromNewDbgValues();
+
Module &M = *F.getParent();
auto FuncIt = F.getIterator();
+ bool Result;
if (Mode == DebugifyMode::SyntheticDebugInfo)
- return checkDebugifyMetadata(M, make_range(FuncIt, std::next(FuncIt)),
+ Result = checkDebugifyMetadata(M, make_range(FuncIt, std::next(FuncIt)),
NameOfWrappedPass, "CheckFunctionDebugify",
Strip, StatsMap);
- return checkDebugInfoMetadata(
+ else
+ Result = checkDebugInfoMetadata(
M, make_range(FuncIt, std::next(FuncIt)), *DebugInfoBeforePass,
"CheckFunctionDebugify (original debuginfo)", NameOfWrappedPass,
OrigDIVerifyBugsReportFilePath);
+
+ if (NewDebugMode)
+ F.convertToNewDbgValues();
+ return Result;
}
CheckDebugifyFunctionPass(
@@ -972,6 +1009,10 @@ createDebugifyFunctionPass(enum DebugifyMode Mode,
}
PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) {
+ bool NewDebugMode = M.IsNewDbgInfoFormat;
+ if (NewDebugMode)
+ M.convertFromNewDbgValues();
+
if (Mode == DebugifyMode::SyntheticDebugInfo)
applyDebugifyMetadata(M, M.functions(),
"ModuleDebugify: ", /*ApplyToMF*/ nullptr);
@@ -979,6 +1020,10 @@ PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) {
collectDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass,
"ModuleDebugify (original debuginfo)",
NameOfWrappedPass);
+
+ if (NewDebugMode)
+ M.convertToNewDbgValues();
+
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
return PA;
@@ -1010,6 +1055,10 @@ FunctionPass *createCheckDebugifyFunctionPass(
PreservedAnalyses NewPMCheckDebugifyPass::run(Module &M,
ModuleAnalysisManager &) {
+ bool NewDebugMode = M.IsNewDbgInfoFormat;
+ if (NewDebugMode)
+ M.convertFromNewDbgValues();
+
if (Mode == DebugifyMode::SyntheticDebugInfo)
checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass,
"CheckModuleDebugify", Strip, StatsMap);
@@ -1018,6 +1067,10 @@ PreservedAnalyses NewPMCheckDebugifyPass::run(Module &M,
M, M.functions(), *DebugInfoBeforePass,
"CheckModuleDebugify (original debuginfo)", NameOfWrappedPass,
OrigDIVerifyBugsReportFilePath);
+
+ if (NewDebugMode)
+ M.convertToNewDbgValues();
+
return PreservedAnalyses::all();
}
@@ -1035,13 +1088,13 @@ void DebugifyEachInstrumentation::registerCallbacks(
return;
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
- if (const auto **CF = any_cast<const Function *>(&IR)) {
+ if (const auto **CF = llvm::any_cast<const Function *>(&IR)) {
Function &F = *const_cast<Function *>(*CF);
applyDebugify(F, Mode, DebugInfoBeforePass, P);
MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent())
.getManager()
.invalidate(F, PA);
- } else if (const auto **CM = any_cast<const Module *>(&IR)) {
+ } else if (const auto **CM = llvm::any_cast<const Module *>(&IR)) {
Module &M = *const_cast<Module *>(*CM);
applyDebugify(M, Mode, DebugInfoBeforePass, P);
MAM.invalidate(M, PA);
@@ -1053,7 +1106,7 @@ void DebugifyEachInstrumentation::registerCallbacks(
return;
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
- if (const auto **CF = any_cast<const Function *>(&IR)) {
+ if (const auto **CF = llvm::any_cast<const Function *>(&IR)) {
auto &F = *const_cast<Function *>(*CF);
Module &M = *F.getParent();
auto It = F.getIterator();
@@ -1069,7 +1122,7 @@ void DebugifyEachInstrumentation::registerCallbacks(
MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent())
.getManager()
.invalidate(F, PA);
- } else if (const auto **CM = any_cast<const Module *>(&IR)) {
+ } else if (const auto **CM = llvm::any_cast<const Module *>(&IR)) {
Module &M = *const_cast<Module *>(*CM);
if (Mode == DebugifyMode::SyntheticDebugInfo)
checkDebugifyMetadata(M, M.functions(), P, "CheckModuleDebugify",
diff --git a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
index d424ebbef99d..092f1799755d 100644
--- a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
+++ b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
@@ -35,7 +35,7 @@ static void insertCall(Function &CurFn, StringRef Func,
Triple TargetTriple(M.getTargetTriple());
if (TargetTriple.isOSAIX() && Func == "__mcount") {
Type *SizeTy = M.getDataLayout().getIntPtrType(C);
- Type *SizePtrTy = SizeTy->getPointerTo();
+ Type *SizePtrTy = PointerType::getUnqual(C);
GlobalVariable *GV = new GlobalVariable(M, SizeTy, /*isConstant=*/false,
GlobalValue::InternalLinkage,
ConstantInt::get(SizeTy, 0));
@@ -54,7 +54,7 @@ static void insertCall(Function &CurFn, StringRef Func,
}
if (Func == "__cyg_profile_func_enter" || Func == "__cyg_profile_func_exit") {
- Type *ArgTypes[] = {Type::getInt8PtrTy(C), Type::getInt8PtrTy(C)};
+ Type *ArgTypes[] = {PointerType::getUnqual(C), PointerType::getUnqual(C)};
FunctionCallee Fn = M.getOrInsertFunction(
Func, FunctionType::get(Type::getVoidTy(C), ArgTypes, false));
@@ -65,9 +65,7 @@ static void insertCall(Function &CurFn, StringRef Func,
InsertionPt);
RetAddr->setDebugLoc(DL);
- Value *Args[] = {ConstantExpr::getBitCast(&CurFn, Type::getInt8PtrTy(C)),
- RetAddr};
-
+ Value *Args[] = {&CurFn, RetAddr};
CallInst *Call =
CallInst::Create(Fn, ArrayRef<Value *>(Args), "", InsertionPt);
Call->setDebugLoc(DL);
diff --git a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
index 88c838685bca..cc00106fcbfe 100644
--- a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
+++ b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
@@ -70,7 +70,7 @@ IRBuilder<> *EscapeEnumerator::Next() {
// Create a cleanup block.
LLVMContext &C = F.getContext();
BasicBlock *CleanupBB = BasicBlock::Create(C, CleanupBBName, &F);
- Type *ExnTy = StructType::get(Type::getInt8PtrTy(C), Type::getInt32Ty(C));
+ Type *ExnTy = StructType::get(PointerType::getUnqual(C), Type::getInt32Ty(C));
if (!F.hasPersonalityFn()) {
FunctionCallee PersFn = getDefaultPersonalityFn(F.getParent());
F.setPersonalityFn(cast<Constant>(PersFn.getCallee()));
diff --git a/llvm/lib/Transforms/Utils/FixIrreducible.cpp b/llvm/lib/Transforms/Utils/FixIrreducible.cpp
index dda236167363..11e24d0585be 100644
--- a/llvm/lib/Transforms/Utils/FixIrreducible.cpp
+++ b/llvm/lib/Transforms/Utils/FixIrreducible.cpp
@@ -87,10 +87,8 @@ struct FixIrreducible : public FunctionPass {
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequiredID(LowerSwitchID);
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreservedID(LowerSwitchID);
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<LoopInfoWrapperPass>();
}
@@ -106,7 +104,6 @@ FunctionPass *llvm::createFixIrreduciblePass() { return new FixIrreducible(); }
INITIALIZE_PASS_BEGIN(FixIrreducible, "fix-irreducible",
"Convert irreducible control-flow into natural loops",
false /* Only looks at CFG */, false /* Analysis Pass */)
-INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(FixIrreducible, "fix-irreducible",
@@ -317,6 +314,8 @@ static bool FixIrreducibleImpl(Function &F, LoopInfo &LI, DominatorTree &DT) {
LLVM_DEBUG(dbgs() << "===== Fix irreducible control-flow in function: "
<< F.getName() << "\n");
+ assert(hasOnlySimpleTerminator(F) && "Unsupported block terminator.");
+
bool Changed = false;
SmallVector<Loop *, 8> WorkList;
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
index 8daeb92130ba..79ca99d1566c 100644
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -160,10 +160,23 @@ int FunctionComparator::cmpAttrs(const AttributeList L,
int FunctionComparator::cmpMetadata(const Metadata *L,
const Metadata *R) const {
// TODO: the following routine coerce the metadata contents into constants
- // before comparison.
+ // or MDStrings before comparison.
// It ignores any other cases, so that the metadata nodes are considered
// equal even though this is not correct.
// We should structurally compare the metadata nodes to be perfect here.
+
+ auto *MDStringL = dyn_cast<MDString>(L);
+ auto *MDStringR = dyn_cast<MDString>(R);
+ if (MDStringL && MDStringR) {
+ if (MDStringL == MDStringR)
+ return 0;
+ return MDStringL->getString().compare(MDStringR->getString());
+ }
+ if (MDStringR)
+ return -1;
+ if (MDStringL)
+ return 1;
+
auto *CL = dyn_cast<ConstantAsMetadata>(L);
auto *CR = dyn_cast<ConstantAsMetadata>(R);
if (CL == CR)
@@ -820,6 +833,21 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) const {
if (ConstR)
return -1;
+ const MetadataAsValue *MetadataValueL = dyn_cast<MetadataAsValue>(L);
+ const MetadataAsValue *MetadataValueR = dyn_cast<MetadataAsValue>(R);
+ if (MetadataValueL && MetadataValueR) {
+ if (MetadataValueL == MetadataValueR)
+ return 0;
+
+ return cmpMetadata(MetadataValueL->getMetadata(),
+ MetadataValueR->getMetadata());
+ }
+
+ if (MetadataValueL)
+ return 1;
+ if (MetadataValueR)
+ return -1;
+
const InlineAsm *InlineAsmL = dyn_cast<InlineAsm>(L);
const InlineAsm *InlineAsmR = dyn_cast<InlineAsm>(R);
@@ -958,67 +986,3 @@ int FunctionComparator::compare() {
}
return 0;
}
-
-namespace {
-
-// Accumulate the hash of a sequence of 64-bit integers. This is similar to a
-// hash of a sequence of 64bit ints, but the entire input does not need to be
-// available at once. This interface is necessary for functionHash because it
-// needs to accumulate the hash as the structure of the function is traversed
-// without saving these values to an intermediate buffer. This form of hashing
-// is not often needed, as usually the object to hash is just read from a
-// buffer.
-class HashAccumulator64 {
- uint64_t Hash;
-
-public:
- // Initialize to random constant, so the state isn't zero.
- HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; }
-
- void add(uint64_t V) { Hash = hashing::detail::hash_16_bytes(Hash, V); }
-
- // No finishing is required, because the entire hash value is used.
- uint64_t getHash() { return Hash; }
-};
-
-} // end anonymous namespace
-
-// A function hash is calculated by considering only the number of arguments and
-// whether a function is varargs, the order of basic blocks (given by the
-// successors of each basic block in depth first order), and the order of
-// opcodes of each instruction within each of these basic blocks. This mirrors
-// the strategy compare() uses to compare functions by walking the BBs in depth
-// first order and comparing each instruction in sequence. Because this hash
-// does not look at the operands, it is insensitive to things such as the
-// target of calls and the constants used in the function, which makes it useful
-// when possibly merging functions which are the same modulo constants and call
-// targets.
-FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) {
- HashAccumulator64 H;
- H.add(F.isVarArg());
- H.add(F.arg_size());
-
- SmallVector<const BasicBlock *, 8> BBs;
- SmallPtrSet<const BasicBlock *, 16> VisitedBBs;
-
- // Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(),
- // accumulating the hash of the function "structure." (BB and opcode sequence)
- BBs.push_back(&F.getEntryBlock());
- VisitedBBs.insert(BBs[0]);
- while (!BBs.empty()) {
- const BasicBlock *BB = BBs.pop_back_val();
- // This random value acts as a block header, as otherwise the partition of
- // opcodes into BBs wouldn't affect the hash, only the order of the opcodes
- H.add(45798);
- for (const auto &Inst : *BB) {
- H.add(Inst.getOpcode());
- }
- const Instruction *Term = BB->getTerminator();
- for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) {
- if (!VisitedBBs.insert(Term->getSuccessor(i)).second)
- continue;
- BBs.push_back(Term->getSuccessor(i));
- }
- }
- return H.getHash();
-}
diff --git a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp
index dab0be3a9fde..0990c750af55 100644
--- a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp
+++ b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp
@@ -91,18 +91,16 @@ static void addMappingsFromTLI(const TargetLibraryInfo &TLI, CallInst &CI) {
Mappings.end());
auto AddVariantDecl = [&](const ElementCount &VF, bool Predicate) {
- const std::string TLIName =
- std::string(TLI.getVectorizedFunction(ScalarName, VF, Predicate));
- if (!TLIName.empty()) {
- std::string MangledName = VFABI::mangleTLIVectorName(
- TLIName, ScalarName, CI.arg_size(), VF, Predicate);
+ const VecDesc *VD = TLI.getVectorMappingInfo(ScalarName, VF, Predicate);
+ if (VD && !VD->getVectorFnName().empty()) {
+ std::string MangledName = VD->getVectorFunctionABIVariantString();
if (!OriginalSetOfMappings.count(MangledName)) {
Mappings.push_back(MangledName);
++NumCallInjected;
}
- Function *VariantF = M->getFunction(TLIName);
+ Function *VariantF = M->getFunction(VD->getVectorFnName());
if (!VariantF)
- addVariantDeclaration(CI, VF, Predicate, TLIName);
+ addVariantDeclaration(CI, VF, Predicate, VD->getVectorFnName());
}
};
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index f7b93fc8fd06..39d5f6e53c1d 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -30,6 +30,7 @@
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
@@ -189,20 +190,21 @@ BasicBlock *LandingPadInliningInfo::getInnerResumeDest() {
const unsigned PHICapacity = 2;
// Create corresponding new PHIs for all the PHIs in the outer landing pad.
- Instruction *InsertPoint = &InnerResumeDest->front();
+ BasicBlock::iterator InsertPoint = InnerResumeDest->begin();
BasicBlock::iterator I = OuterResumeDest->begin();
for (unsigned i = 0, e = UnwindDestPHIValues.size(); i != e; ++i, ++I) {
PHINode *OuterPHI = cast<PHINode>(I);
PHINode *InnerPHI = PHINode::Create(OuterPHI->getType(), PHICapacity,
- OuterPHI->getName() + ".lpad-body",
- InsertPoint);
+ OuterPHI->getName() + ".lpad-body");
+ InnerPHI->insertBefore(InsertPoint);
OuterPHI->replaceAllUsesWith(InnerPHI);
InnerPHI->addIncoming(OuterPHI, OuterResumeDest);
}
// Create a PHI for the exception values.
- InnerEHValuesPHI = PHINode::Create(CallerLPad->getType(), PHICapacity,
- "eh.lpad-body", InsertPoint);
+ InnerEHValuesPHI =
+ PHINode::Create(CallerLPad->getType(), PHICapacity, "eh.lpad-body");
+ InnerEHValuesPHI->insertBefore(InsertPoint);
CallerLPad->replaceAllUsesWith(InnerEHValuesPHI);
InnerEHValuesPHI->addIncoming(CallerLPad, OuterResumeDest);
@@ -1331,38 +1333,51 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap,
}
}
-static bool MayContainThrowingOrExitingCall(Instruction *Begin,
- Instruction *End) {
+static bool MayContainThrowingOrExitingCallAfterCB(CallBase *Begin,
+ ReturnInst *End) {
assert(Begin->getParent() == End->getParent() &&
"Expected to be in same basic block!");
+ auto BeginIt = Begin->getIterator();
+ assert(BeginIt != End->getIterator() && "Non-empty BB has empty iterator");
return !llvm::isGuaranteedToTransferExecutionToSuccessor(
- Begin->getIterator(), End->getIterator(), InlinerAttributeWindow + 1);
+ ++BeginIt, End->getIterator(), InlinerAttributeWindow + 1);
}
-static AttrBuilder IdentifyValidAttributes(CallBase &CB) {
+// Only allow these white listed attributes to be propagated back to the
+// callee. This is because other attributes may only be valid on the call
+// itself, i.e. attributes such as signext and zeroext.
- AttrBuilder AB(CB.getContext(), CB.getAttributes().getRetAttrs());
- if (!AB.hasAttributes())
- return AB;
+// Attributes that are always okay to propagate as if they are violated its
+// immediate UB.
+static AttrBuilder IdentifyValidUBGeneratingAttributes(CallBase &CB) {
AttrBuilder Valid(CB.getContext());
- // Only allow these white listed attributes to be propagated back to the
- // callee. This is because other attributes may only be valid on the call
- // itself, i.e. attributes such as signext and zeroext.
- if (auto DerefBytes = AB.getDereferenceableBytes())
+ if (auto DerefBytes = CB.getRetDereferenceableBytes())
Valid.addDereferenceableAttr(DerefBytes);
- if (auto DerefOrNullBytes = AB.getDereferenceableOrNullBytes())
+ if (auto DerefOrNullBytes = CB.getRetDereferenceableOrNullBytes())
Valid.addDereferenceableOrNullAttr(DerefOrNullBytes);
- if (AB.contains(Attribute::NoAlias))
+ if (CB.hasRetAttr(Attribute::NoAlias))
Valid.addAttribute(Attribute::NoAlias);
- if (AB.contains(Attribute::NonNull))
+ if (CB.hasRetAttr(Attribute::NoUndef))
+ Valid.addAttribute(Attribute::NoUndef);
+ return Valid;
+}
+
+// Attributes that need additional checks as propagating them may change
+// behavior or cause new UB.
+static AttrBuilder IdentifyValidPoisonGeneratingAttributes(CallBase &CB) {
+ AttrBuilder Valid(CB.getContext());
+ if (CB.hasRetAttr(Attribute::NonNull))
Valid.addAttribute(Attribute::NonNull);
+ if (CB.hasRetAttr(Attribute::Alignment))
+ Valid.addAlignmentAttr(CB.getRetAlign());
return Valid;
}
static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) {
- AttrBuilder Valid = IdentifyValidAttributes(CB);
- if (!Valid.hasAttributes())
+ AttrBuilder ValidUB = IdentifyValidUBGeneratingAttributes(CB);
+ AttrBuilder ValidPG = IdentifyValidPoisonGeneratingAttributes(CB);
+ if (!ValidUB.hasAttributes() && !ValidPG.hasAttributes())
return;
auto *CalledFunction = CB.getCalledFunction();
auto &Context = CalledFunction->getContext();
@@ -1397,7 +1412,7 @@ static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) {
// limit the check to both RetVal and RI are in the same basic block and
// there are no throwing/exiting instructions between these instructions.
if (RI->getParent() != RetVal->getParent() ||
- MayContainThrowingOrExitingCall(RetVal, RI))
+ MayContainThrowingOrExitingCallAfterCB(RetVal, RI))
continue;
// Add to the existing attributes of NewRetVal, i.e. the cloned call
// instruction.
@@ -1406,7 +1421,62 @@ static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) {
// existing attribute value (i.e. attributes such as dereferenceable,
// dereferenceable_or_null etc). See AttrBuilder::merge for more details.
AttributeList AL = NewRetVal->getAttributes();
- AttributeList NewAL = AL.addRetAttributes(Context, Valid);
+ if (ValidUB.getDereferenceableBytes() < AL.getRetDereferenceableBytes())
+ ValidUB.removeAttribute(Attribute::Dereferenceable);
+ if (ValidUB.getDereferenceableOrNullBytes() <
+ AL.getRetDereferenceableOrNullBytes())
+ ValidUB.removeAttribute(Attribute::DereferenceableOrNull);
+ AttributeList NewAL = AL.addRetAttributes(Context, ValidUB);
+ // Attributes that may generate poison returns are a bit tricky. If we
+ // propagate them, other uses of the callsite might have their behavior
+ // change or cause UB (if they have noundef) b.c of the new potential
+ // poison.
+ // Take the following three cases:
+ //
+ // 1)
+ // define nonnull ptr @foo() {
+ // %p = call ptr @bar()
+ // call void @use(ptr %p) willreturn nounwind
+ // ret ptr %p
+ // }
+ //
+ // 2)
+ // define noundef nonnull ptr @foo() {
+ // %p = call ptr @bar()
+ // call void @use(ptr %p) willreturn nounwind
+ // ret ptr %p
+ // }
+ //
+ // 3)
+ // define nonnull ptr @foo() {
+ // %p = call noundef ptr @bar()
+ // ret ptr %p
+ // }
+ //
+ // In case 1, we can't propagate nonnull because poison value in @use may
+ // change behavior or trigger UB.
+ // In case 2, we don't need to be concerned about propagating nonnull, as
+ // any new poison at @use will trigger UB anyways.
+ // In case 3, we can never propagate nonnull because it may create UB due to
+ // the noundef on @bar.
+ if (ValidPG.getAlignment().valueOrOne() < AL.getRetAlignment().valueOrOne())
+ ValidPG.removeAttribute(Attribute::Alignment);
+ if (ValidPG.hasAttributes()) {
+ // Three checks.
+ // If the callsite has `noundef`, then a poison due to violating the
+ // return attribute will create UB anyways so we can always propagate.
+ // Otherwise, if the return value (callee to be inlined) has `noundef`, we
+ // can't propagate as a new poison return will cause UB.
+ // Finally, check if the return value has no uses whose behavior may
+ // change/may cause UB if we potentially return poison. At the moment this
+ // is implemented overly conservatively with a single-use check.
+ // TODO: Update the single-use check to iterate through uses and only bail
+ // if we have a potentially dangerous use.
+
+ if (CB.hasRetAttr(Attribute::NoUndef) ||
+ (RetVal->hasOneUse() && !RetVal->hasRetAttr(Attribute::NoUndef)))
+ NewAL = NewAL.addRetAttributes(Context, ValidPG);
+ }
NewRetVal->setAttributes(NewAL);
}
}
@@ -1515,10 +1585,10 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg,
if (ByValAlignment)
Alignment = std::max(Alignment, *ByValAlignment);
- Value *NewAlloca =
- new AllocaInst(ByValType, DL.getAllocaAddrSpace(), nullptr, Alignment,
- Arg->getName(), &*Caller->begin()->begin());
- IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca));
+ AllocaInst *NewAlloca = new AllocaInst(ByValType, DL.getAllocaAddrSpace(),
+ nullptr, Alignment, Arg->getName());
+ NewAlloca->insertBefore(Caller->begin()->begin());
+ IFI.StaticAllocas.push_back(NewAlloca);
// Uses of the argument in the function should use our new alloca
// instead.
@@ -1538,8 +1608,8 @@ static bool isUsedByLifetimeMarker(Value *V) {
// lifetime.start or lifetime.end intrinsics.
static bool hasLifetimeMarkers(AllocaInst *AI) {
Type *Ty = AI->getType();
- Type *Int8PtrTy = Type::getInt8PtrTy(Ty->getContext(),
- Ty->getPointerAddressSpace());
+ Type *Int8PtrTy =
+ PointerType::get(Ty->getContext(), Ty->getPointerAddressSpace());
if (Ty == Int8PtrTy)
return isUsedByLifetimeMarker(AI);
@@ -1596,48 +1666,71 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI,
// the call site location instead.
bool NoInlineLineTables = Fn->hasFnAttribute("no-inline-line-tables");
- for (; FI != Fn->end(); ++FI) {
- for (BasicBlock::iterator BI = FI->begin(), BE = FI->end();
- BI != BE; ++BI) {
- // Loop metadata needs to be updated so that the start and end locs
- // reference inlined-at locations.
- auto updateLoopInfoLoc = [&Ctx, &InlinedAtNode,
- &IANodes](Metadata *MD) -> Metadata * {
- if (auto *Loc = dyn_cast_or_null<DILocation>(MD))
- return inlineDebugLoc(Loc, InlinedAtNode, Ctx, IANodes).get();
- return MD;
- };
- updateLoopMetadataDebugLocations(*BI, updateLoopInfoLoc);
+ // Helper-util for updating the metadata attached to an instruction.
+ auto UpdateInst = [&](Instruction &I) {
+ // Loop metadata needs to be updated so that the start and end locs
+ // reference inlined-at locations.
+ auto updateLoopInfoLoc = [&Ctx, &InlinedAtNode,
+ &IANodes](Metadata *MD) -> Metadata * {
+ if (auto *Loc = dyn_cast_or_null<DILocation>(MD))
+ return inlineDebugLoc(Loc, InlinedAtNode, Ctx, IANodes).get();
+ return MD;
+ };
+ updateLoopMetadataDebugLocations(I, updateLoopInfoLoc);
- if (!NoInlineLineTables)
- if (DebugLoc DL = BI->getDebugLoc()) {
- DebugLoc IDL =
- inlineDebugLoc(DL, InlinedAtNode, BI->getContext(), IANodes);
- BI->setDebugLoc(IDL);
- continue;
- }
+ if (!NoInlineLineTables)
+ if (DebugLoc DL = I.getDebugLoc()) {
+ DebugLoc IDL =
+ inlineDebugLoc(DL, InlinedAtNode, I.getContext(), IANodes);
+ I.setDebugLoc(IDL);
+ return;
+ }
- if (CalleeHasDebugInfo && !NoInlineLineTables)
- continue;
+ if (CalleeHasDebugInfo && !NoInlineLineTables)
+ return;
- // If the inlined instruction has no line number, or if inline info
- // is not being generated, make it look as if it originates from the call
- // location. This is important for ((__always_inline, __nodebug__))
- // functions which must use caller location for all instructions in their
- // function body.
+ // If the inlined instruction has no line number, or if inline info
+ // is not being generated, make it look as if it originates from the call
+ // location. This is important for ((__always_inline, __nodebug__))
+ // functions which must use caller location for all instructions in their
+ // function body.
- // Don't update static allocas, as they may get moved later.
- if (auto *AI = dyn_cast<AllocaInst>(BI))
- if (allocaWouldBeStaticInEntry(AI))
- continue;
+ // Don't update static allocas, as they may get moved later.
+ if (auto *AI = dyn_cast<AllocaInst>(&I))
+ if (allocaWouldBeStaticInEntry(AI))
+ return;
- // Do not force a debug loc for pseudo probes, since they do not need to
- // be debuggable, and also they are expected to have a zero/null dwarf
- // discriminator at this point which could be violated otherwise.
- if (isa<PseudoProbeInst>(BI))
- continue;
+ // Do not force a debug loc for pseudo probes, since they do not need to
+ // be debuggable, and also they are expected to have a zero/null dwarf
+ // discriminator at this point which could be violated otherwise.
+ if (isa<PseudoProbeInst>(I))
+ return;
- BI->setDebugLoc(TheCallDL);
+ I.setDebugLoc(TheCallDL);
+ };
+
+ // Helper-util for updating debug-info records attached to instructions.
+ auto UpdateDPV = [&](DPValue *DPV) {
+ assert(DPV->getDebugLoc() && "Debug Value must have debug loc");
+ if (NoInlineLineTables) {
+ DPV->setDebugLoc(TheCallDL);
+ return;
+ }
+ DebugLoc DL = DPV->getDebugLoc();
+ DebugLoc IDL =
+ inlineDebugLoc(DL, InlinedAtNode,
+ DPV->getMarker()->getParent()->getContext(), IANodes);
+ DPV->setDebugLoc(IDL);
+ };
+
+ // Iterate over all instructions, updating metadata and debug-info records.
+ for (; FI != Fn->end(); ++FI) {
+ for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE;
+ ++BI) {
+ UpdateInst(*BI);
+ for (DPValue &DPV : BI->getDbgValueRange()) {
+ UpdateDPV(&DPV);
+ }
}
// Remove debug info intrinsics if we're not keeping inline info.
@@ -1647,11 +1740,12 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI,
if (isa<DbgInfoIntrinsic>(BI)) {
BI = BI->eraseFromParent();
continue;
+ } else {
+ BI->dropDbgValues();
}
++BI;
}
}
-
}
}
@@ -1760,12 +1854,12 @@ static void updateCallerBFI(BasicBlock *CallSiteBlock,
continue;
auto *OrigBB = cast<BasicBlock>(Entry.first);
auto *ClonedBB = cast<BasicBlock>(Entry.second);
- uint64_t Freq = CalleeBFI->getBlockFreq(OrigBB).getFrequency();
+ BlockFrequency Freq = CalleeBFI->getBlockFreq(OrigBB);
if (!ClonedBBs.insert(ClonedBB).second) {
// Multiple blocks in the callee might get mapped to one cloned block in
// the caller since we prune the callee as we clone it. When that happens,
// we want to use the maximum among the original blocks' frequencies.
- uint64_t NewFreq = CallerBFI->getBlockFreq(ClonedBB).getFrequency();
+ BlockFrequency NewFreq = CallerBFI->getBlockFreq(ClonedBB);
if (NewFreq > Freq)
Freq = NewFreq;
}
@@ -1773,8 +1867,7 @@ static void updateCallerBFI(BasicBlock *CallSiteBlock,
}
BasicBlock *EntryClone = cast<BasicBlock>(VMap.lookup(&CalleeEntryBlock));
CallerBFI->setBlockFreqAndScale(
- EntryClone, CallerBFI->getBlockFreq(CallSiteBlock).getFrequency(),
- ClonedBBs);
+ EntryClone, CallerBFI->getBlockFreq(CallSiteBlock), ClonedBBs);
}
/// Update the branch metadata for cloned call instructions.
@@ -1882,8 +1975,7 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind,
Builder.SetInsertPoint(II);
Function *IFn =
Intrinsic::getDeclaration(Mod, Intrinsic::objc_release);
- Value *BC = Builder.CreateBitCast(RetOpnd, IFn->getArg(0)->getType());
- Builder.CreateCall(IFn, BC, "");
+ Builder.CreateCall(IFn, RetOpnd, "");
}
II->eraseFromParent();
InsertRetainCall = false;
@@ -1918,8 +2010,7 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind,
// to objc_retain.
Builder.SetInsertPoint(RI);
Function *IFn = Intrinsic::getDeclaration(Mod, Intrinsic::objc_retain);
- Value *BC = Builder.CreateBitCast(RetOpnd, IFn->getArg(0)->getType());
- Builder.CreateCall(IFn, BC, "");
+ Builder.CreateCall(IFn, RetOpnd, "");
}
}
}
@@ -1953,9 +2044,11 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// The inliner does not know how to inline through calls with operand bundles
// in general ...
+ Value *ConvergenceControlToken = nullptr;
if (CB.hasOperandBundles()) {
for (int i = 0, e = CB.getNumOperandBundles(); i != e; ++i) {
- uint32_t Tag = CB.getOperandBundleAt(i).getTagID();
+ auto OBUse = CB.getOperandBundleAt(i);
+ uint32_t Tag = OBUse.getTagID();
// ... but it knows how to inline through "deopt" operand bundles ...
if (Tag == LLVMContext::OB_deopt)
continue;
@@ -1966,11 +2059,37 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
continue;
if (Tag == LLVMContext::OB_kcfi)
continue;
+ if (Tag == LLVMContext::OB_convergencectrl) {
+ ConvergenceControlToken = OBUse.Inputs[0].get();
+ continue;
+ }
return InlineResult::failure("unsupported operand bundle");
}
}
+ // FIXME: The check below is redundant and incomplete. According to spec, if a
+ // convergent call is missing a token, then the caller is using uncontrolled
+ // convergence. If the callee has an entry intrinsic, then the callee is using
+ // controlled convergence, and the call cannot be inlined. A proper
+ // implemenation of this check requires a whole new analysis that identifies
+ // convergence in every function. For now, we skip that and just do this one
+ // cursory check. The underlying assumption is that in a compiler flow that
+ // fully implements convergence control tokens, there is no mixing of
+ // controlled and uncontrolled convergent operations in the whole program.
+ if (CB.isConvergent()) {
+ auto *I = CalledFunc->getEntryBlock().getFirstNonPHI();
+ if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
+ if (IntrinsicCall->getIntrinsicID() ==
+ Intrinsic::experimental_convergence_entry) {
+ if (!ConvergenceControlToken) {
+ return InlineResult::failure(
+ "convergent call needs convergencectrl operand");
+ }
+ }
+ }
+ }
+
// If the call to the callee cannot throw, set the 'nounwind' flag on any
// calls that we inline.
bool MarkNoUnwind = CB.doesNotThrow();
@@ -2260,6 +2379,17 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
IFI.GetAssumptionCache(*Caller).registerAssumption(II);
}
+ if (ConvergenceControlToken) {
+ auto *I = FirstNewBlock->getFirstNonPHI();
+ if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
+ if (IntrinsicCall->getIntrinsicID() ==
+ Intrinsic::experimental_convergence_entry) {
+ IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken);
+ IntrinsicCall->eraseFromParent();
+ }
+ }
+ }
+
// If there are any alloca instructions in the block that used to be the entry
// block for the callee, move them to the entry block of the caller. First
// calculate which instruction they should be inserted before. We insert the
@@ -2296,6 +2426,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// Transfer all of the allocas over in a block. Using splice means
// that the instructions aren't removed from the symbol table, then
// reinserted.
+ I.setTailBit(true);
Caller->getEntryBlock().splice(InsertPoint, &*FirstNewBlock,
AI->getIterator(), I);
}
@@ -2400,7 +2531,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// `Caller->isPresplitCoroutine()` would affect AlwaysInliner at O0 only.
if ((InsertLifetime || Caller->isPresplitCoroutine()) &&
!IFI.StaticAllocas.empty()) {
- IRBuilder<> builder(&FirstNewBlock->front());
+ IRBuilder<> builder(&*FirstNewBlock, FirstNewBlock->begin());
for (unsigned ai = 0, ae = IFI.StaticAllocas.size(); ai != ae; ++ai) {
AllocaInst *AI = IFI.StaticAllocas[ai];
// Don't mark swifterror allocas. They can't have bitcast uses.
@@ -2454,14 +2585,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// If the inlined code contained dynamic alloca instructions, wrap the inlined
// code with llvm.stacksave/llvm.stackrestore intrinsics.
if (InlinedFunctionInfo.ContainsDynamicAllocas) {
- Module *M = Caller->getParent();
- // Get the two intrinsics we care about.
- Function *StackSave = Intrinsic::getDeclaration(M, Intrinsic::stacksave);
- Function *StackRestore=Intrinsic::getDeclaration(M,Intrinsic::stackrestore);
-
// Insert the llvm.stacksave.
CallInst *SavedPtr = IRBuilder<>(&*FirstNewBlock, FirstNewBlock->begin())
- .CreateCall(StackSave, {}, "savedstack");
+ .CreateStackSave("savedstack");
// Insert a call to llvm.stackrestore before any return instructions in the
// inlined function.
@@ -2472,7 +2598,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
continue;
if (InlinedDeoptimizeCalls && RI->getParent()->getTerminatingDeoptimizeCall())
continue;
- IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
+ IRBuilder<>(RI).CreateStackRestore(SavedPtr);
}
}
@@ -2574,6 +2700,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
Builder.CreateRetVoid();
else
Builder.CreateRet(NewDeoptCall);
+ // Since the ret type is changed, remove the incompatible attributes.
+ NewDeoptCall->removeRetAttrs(
+ AttributeFuncs::typeIncompatible(NewDeoptCall->getType()));
}
// Leave behind the normal returns so we can merge control flow.
@@ -2704,8 +2833,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
if (IFI.CallerBFI) {
// Copy original BB's block frequency to AfterCallBB
- IFI.CallerBFI->setBlockFreq(
- AfterCallBB, IFI.CallerBFI->getBlockFreq(OrigBB).getFrequency());
+ IFI.CallerBFI->setBlockFreq(AfterCallBB,
+ IFI.CallerBFI->getBlockFreq(OrigBB));
}
// Change the branch that used to go to AfterCallBB to branch to the first
@@ -2731,8 +2860,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
// The PHI node should go at the front of the new basic block to merge all
// possible incoming values.
if (!CB.use_empty()) {
- PHI = PHINode::Create(RTy, Returns.size(), CB.getName(),
- &AfterCallBB->front());
+ PHI = PHINode::Create(RTy, Returns.size(), CB.getName());
+ PHI->insertBefore(AfterCallBB->begin());
// Anything that used the result of the function call should now use the
// PHI node as their operand.
CB.replaceAllUsesWith(PHI);
diff --git a/llvm/lib/Transforms/Utils/LCSSA.cpp b/llvm/lib/Transforms/Utils/LCSSA.cpp
index c36b0533580b..5e0c312fe149 100644
--- a/llvm/lib/Transforms/Utils/LCSSA.cpp
+++ b/llvm/lib/Transforms/Utils/LCSSA.cpp
@@ -160,7 +160,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
if (SSAUpdate.HasValueForBlock(ExitBB))
continue;
PHINode *PN = PHINode::Create(I->getType(), PredCache.size(ExitBB),
- I->getName() + ".lcssa", &ExitBB->front());
+ I->getName() + ".lcssa");
+ PN->insertBefore(ExitBB->begin());
if (InsertedPHIs)
InsertedPHIs->push_back(PN);
// Get the debug location from the original instruction.
@@ -241,7 +242,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
}
SmallVector<DbgValueInst *, 4> DbgValues;
- llvm::findDbgValues(DbgValues, I);
+ SmallVector<DPValue *, 4> DPValues;
+ llvm::findDbgValues(DbgValues, I, &DPValues);
// Update pre-existing debug value uses that reside outside the loop.
for (auto *DVI : DbgValues) {
@@ -257,6 +259,21 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
DVI->replaceVariableLocationOp(I, V);
}
+ // RemoveDIs: copy-paste of block above, using non-instruction debug-info
+ // records.
+ for (DPValue *DPV : DPValues) {
+ BasicBlock *UserBB = DPV->getMarker()->getParent();
+ if (InstBB == UserBB || L->contains(UserBB))
+ continue;
+ // We currently only handle debug values residing in blocks that were
+ // traversed while rewriting the uses. If we inserted just a single PHI,
+ // we will handle all relevant debug values.
+ Value *V = AddedPHIs.size() == 1 ? AddedPHIs[0]
+ : SSAUpdate.FindValueForBlock(UserBB);
+ if (V)
+ DPV->replaceVariableLocationOp(I, V);
+ }
+
// SSAUpdater might have inserted phi-nodes inside other loops. We'll need
// to post-process them to keep LCSSA form.
for (PHINode *InsertedPN : LocalInsertedPHIs) {
diff --git a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
index cdcfb5050bff..6220f8509309 100644
--- a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
+++ b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp
@@ -101,7 +101,7 @@ private:
float Val) {
Constant *V = ConstantFP::get(BBBuilder.getContext(), APFloat(Val));
if (!Arg->getType()->isFloatTy())
- V = ConstantExpr::getFPExtend(V, Arg->getType());
+ V = ConstantFoldCastInstruction(Instruction::FPExt, V, Arg->getType());
if (BBBuilder.GetInsertBlock()->getParent()->hasFnAttribute(Attribute::StrictFP))
BBBuilder.setIsFPConstrained(true);
return BBBuilder.CreateFCmp(Cmp, Arg, V);
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index f153ace5d3fc..51f39e0ba0cc 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -69,6 +69,7 @@
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
@@ -86,6 +87,8 @@
using namespace llvm;
using namespace llvm::PatternMatch;
+extern cl::opt<bool> UseNewDbgInfoFormat;
+
#define DEBUG_TYPE "local"
STATISTIC(NumRemoved, "Number of unreachable basic blocks removed");
@@ -227,9 +230,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Remove weight for this case.
std::swap(Weights[Idx + 1], Weights.back());
Weights.pop_back();
- SI->setMetadata(LLVMContext::MD_prof,
- MDBuilder(BB->getContext()).
- createBranchWeights(Weights));
+ setBranchWeights(*SI, Weights);
}
// Remove this entry.
BasicBlock *ParentBB = SI->getParent();
@@ -414,7 +415,7 @@ bool llvm::wouldInstructionBeTriviallyDeadOnUnusedPaths(
return wouldInstructionBeTriviallyDead(I, TLI);
}
-bool llvm::wouldInstructionBeTriviallyDead(Instruction *I,
+bool llvm::wouldInstructionBeTriviallyDead(const Instruction *I,
const TargetLibraryInfo *TLI) {
if (I->isTerminator())
return false;
@@ -428,7 +429,7 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I,
if (isa<DbgVariableIntrinsic>(I))
return false;
- if (DbgLabelInst *DLI = dyn_cast<DbgLabelInst>(I)) {
+ if (const DbgLabelInst *DLI = dyn_cast<DbgLabelInst>(I)) {
if (DLI->getLabel())
return false;
return true;
@@ -443,9 +444,16 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I,
if (!II)
return false;
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::experimental_guard: {
+ // Guards on true are operationally no-ops. In the future we can
+ // consider more sophisticated tradeoffs for guards considering potential
+ // for check widening, but for now we keep things simple.
+ auto *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0));
+ return Cond && Cond->isOne();
+ }
// TODO: These intrinsics are not safe to remove, because this may remove
// a well-defined trap.
- switch (II->getIntrinsicID()) {
case Intrinsic::wasm_trunc_signed:
case Intrinsic::wasm_trunc_unsigned:
case Intrinsic::ptrauth_auth:
@@ -461,7 +469,7 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I,
// Special case intrinsics that "may have side effects" but can be deleted
// when dead.
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
// Safe to delete llvm.stacksave and launder.invariant.group if dead.
if (II->getIntrinsicID() == Intrinsic::stacksave ||
II->getIntrinsicID() == Intrinsic::launder_invariant_group)
@@ -484,13 +492,9 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I,
return false;
}
- // Assumptions are dead if their condition is trivially true. Guards on
- // true are operationally no-ops. In the future we can consider more
- // sophisticated tradeoffs for guards considering potential for check
- // widening, but for now we keep things simple.
- if ((II->getIntrinsicID() == Intrinsic::assume &&
- isAssumeWithEmptyBundle(cast<AssumeInst>(*II))) ||
- II->getIntrinsicID() == Intrinsic::experimental_guard) {
+ // Assumptions are dead if their condition is trivially true.
+ if (II->getIntrinsicID() == Intrinsic::assume &&
+ isAssumeWithEmptyBundle(cast<AssumeInst>(*II))) {
if (ConstantInt *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0)))
return !Cond->isZero();
@@ -605,10 +609,13 @@ void llvm::RecursivelyDeleteTriviallyDeadInstructions(
bool llvm::replaceDbgUsesWithUndef(Instruction *I) {
SmallVector<DbgVariableIntrinsic *, 1> DbgUsers;
- findDbgUsers(DbgUsers, I);
+ SmallVector<DPValue *, 1> DPUsers;
+ findDbgUsers(DbgUsers, I, &DPUsers);
for (auto *DII : DbgUsers)
DII->setKillLocation();
- return !DbgUsers.empty();
+ for (auto *DPV : DPUsers)
+ DPV->setKillLocation();
+ return !DbgUsers.empty() || !DPUsers.empty();
}
/// areAllUsesEqual - Check whether the uses of a value are all the same.
@@ -847,17 +854,17 @@ static bool CanMergeValues(Value *First, Value *Second) {
/// branch to Succ, into Succ.
///
/// Assumption: Succ is the single successor for BB.
-static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) {
+static bool
+CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ,
+ const SmallPtrSetImpl<BasicBlock *> &BBPreds) {
assert(*succ_begin(BB) == Succ && "Succ is not successor of BB!");
LLVM_DEBUG(dbgs() << "Looking to fold " << BB->getName() << " into "
<< Succ->getName() << "\n");
// Shortcut, if there is only a single predecessor it must be BB and merging
// is always safe
- if (Succ->getSinglePredecessor()) return true;
-
- // Make a list of the predecessors of BB
- SmallPtrSet<BasicBlock*, 16> BBPreds(pred_begin(BB), pred_end(BB));
+ if (Succ->getSinglePredecessor())
+ return true;
// Look at all the phi nodes in Succ, to see if they present a conflict when
// merging these blocks
@@ -997,6 +1004,35 @@ static void replaceUndefValuesInPhi(PHINode *PN,
}
}
+// Only when they share a single common predecessor, return true.
+// Only handles cases when BB can't be merged while its predecessors can be
+// redirected.
+static bool
+CanRedirectPredsOfEmptyBBToSucc(BasicBlock *BB, BasicBlock *Succ,
+ const SmallPtrSetImpl<BasicBlock *> &BBPreds,
+ const SmallPtrSetImpl<BasicBlock *> &SuccPreds,
+ BasicBlock *&CommonPred) {
+
+ // There must be phis in BB, otherwise BB will be merged into Succ directly
+ if (BB->phis().empty() || Succ->phis().empty())
+ return false;
+
+ // BB must have predecessors not shared that can be redirected to Succ
+ if (!BB->hasNPredecessorsOrMore(2))
+ return false;
+
+ // Get single common predecessors of both BB and Succ
+ for (BasicBlock *SuccPred : SuccPreds) {
+ if (BBPreds.count(SuccPred)) {
+ if (CommonPred)
+ return false;
+ CommonPred = SuccPred;
+ }
+ }
+
+ return true;
+}
+
/// Replace a value flowing from a block to a phi with
/// potentially multiple instances of that value flowing from the
/// block's predecessors to the phi.
@@ -1004,9 +1040,11 @@ static void replaceUndefValuesInPhi(PHINode *PN,
/// \param BB The block with the value flowing into the phi.
/// \param BBPreds The predecessors of BB.
/// \param PN The phi that we are updating.
+/// \param CommonPred The common predecessor of BB and PN's BasicBlock
static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB,
const PredBlockVector &BBPreds,
- PHINode *PN) {
+ PHINode *PN,
+ BasicBlock *CommonPred) {
Value *OldVal = PN->removeIncomingValue(BB, false);
assert(OldVal && "No entry in PHI for Pred BB!");
@@ -1034,26 +1072,39 @@ static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB,
// will trigger asserts if we try to clean it up now, without also
// simplifying the corresponding conditional branch).
BasicBlock *PredBB = OldValPN->getIncomingBlock(i);
+
+ if (PredBB == CommonPred)
+ continue;
+
Value *PredVal = OldValPN->getIncomingValue(i);
- Value *Selected = selectIncomingValueForBlock(PredVal, PredBB,
- IncomingValues);
+ Value *Selected =
+ selectIncomingValueForBlock(PredVal, PredBB, IncomingValues);
// And add a new incoming value for this predecessor for the
// newly retargeted branch.
PN->addIncoming(Selected, PredBB);
}
+ if (CommonPred)
+ PN->addIncoming(OldValPN->getIncomingValueForBlock(CommonPred), BB);
+
} else {
for (unsigned i = 0, e = BBPreds.size(); i != e; ++i) {
// Update existing incoming values in PN for this
// predecessor of BB.
BasicBlock *PredBB = BBPreds[i];
- Value *Selected = selectIncomingValueForBlock(OldVal, PredBB,
- IncomingValues);
+
+ if (PredBB == CommonPred)
+ continue;
+
+ Value *Selected =
+ selectIncomingValueForBlock(OldVal, PredBB, IncomingValues);
// And add a new incoming value for this predecessor for the
// newly retargeted branch.
PN->addIncoming(Selected, PredBB);
}
+ if (CommonPred)
+ PN->addIncoming(OldVal, BB);
}
replaceUndefValuesInPhi(PN, IncomingValues);
@@ -1064,13 +1115,30 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
assert(BB != &BB->getParent()->getEntryBlock() &&
"TryToSimplifyUncondBranchFromEmptyBlock called on entry block!");
- // We can't eliminate infinite loops.
+ // We can't simplify infinite loops.
BasicBlock *Succ = cast<BranchInst>(BB->getTerminator())->getSuccessor(0);
- if (BB == Succ) return false;
+ if (BB == Succ)
+ return false;
+
+ SmallPtrSet<BasicBlock *, 16> BBPreds(pred_begin(BB), pred_end(BB));
+ SmallPtrSet<BasicBlock *, 16> SuccPreds(pred_begin(Succ), pred_end(Succ));
- // Check to see if merging these blocks would cause conflicts for any of the
- // phi nodes in BB or Succ. If not, we can safely merge.
- if (!CanPropagatePredecessorsForPHIs(BB, Succ)) return false;
+ // The single common predecessor of BB and Succ when BB cannot be killed
+ BasicBlock *CommonPred = nullptr;
+
+ bool BBKillable = CanPropagatePredecessorsForPHIs(BB, Succ, BBPreds);
+
+ // Even if we cannot fold BB into Succ, we may be able to redirect the
+ // predecessors of BB to Succ.
+ bool BBPhisMergeable =
+ BBKillable ||
+ CanRedirectPredsOfEmptyBBToSucc(BB, Succ, BBPreds, SuccPreds, CommonPred);
+
+ if (!BBKillable && !BBPhisMergeable)
+ return false;
+
+ // Check to see if merging these blocks/phis would cause conflicts for any of
+ // the phi nodes in BB or Succ. If not, we can safely merge.
// Check for cases where Succ has multiple predecessors and a PHI node in BB
// has uses which will not disappear when the PHI nodes are merged. It is
@@ -1099,6 +1167,11 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
}
}
+ if (BBPhisMergeable && CommonPred)
+ LLVM_DEBUG(dbgs() << "Found Common Predecessor between: " << BB->getName()
+ << " and " << Succ->getName() << " : "
+ << CommonPred->getName() << "\n");
+
// 'BB' and 'BB->Pred' are loop latches, bail out to presrve inner loop
// metadata.
//
@@ -1171,25 +1244,37 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
if (PredTI->hasMetadata(LLVMContext::MD_loop))
return false;
- LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB);
+ if (BBKillable)
+ LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB);
+ else if (BBPhisMergeable)
+ LLVM_DEBUG(dbgs() << "Merge Phis in Trivial BB: \n" << *BB);
SmallVector<DominatorTree::UpdateType, 32> Updates;
+
if (DTU) {
// To avoid processing the same predecessor more than once.
SmallPtrSet<BasicBlock *, 8> SeenPreds;
- // All predecessors of BB will be moved to Succ.
- SmallPtrSet<BasicBlock *, 8> PredsOfSucc(pred_begin(Succ), pred_end(Succ));
+ // All predecessors of BB (except the common predecessor) will be moved to
+ // Succ.
Updates.reserve(Updates.size() + 2 * pred_size(BB) + 1);
- for (auto *PredOfBB : predecessors(BB))
- // This predecessor of BB may already have Succ as a successor.
- if (!PredsOfSucc.contains(PredOfBB))
+
+ for (auto *PredOfBB : predecessors(BB)) {
+ // Do not modify those common predecessors of BB and Succ
+ if (!SuccPreds.contains(PredOfBB))
if (SeenPreds.insert(PredOfBB).second)
Updates.push_back({DominatorTree::Insert, PredOfBB, Succ});
+ }
+
SeenPreds.clear();
+
for (auto *PredOfBB : predecessors(BB))
- if (SeenPreds.insert(PredOfBB).second)
+ // When BB cannot be killed, do not remove the edge between BB and
+ // CommonPred.
+ if (SeenPreds.insert(PredOfBB).second && PredOfBB != CommonPred)
Updates.push_back({DominatorTree::Delete, PredOfBB, BB});
- Updates.push_back({DominatorTree::Delete, BB, Succ});
+
+ if (BBKillable)
+ Updates.push_back({DominatorTree::Delete, BB, Succ});
}
if (isa<PHINode>(Succ->begin())) {
@@ -1201,21 +1286,19 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
// Loop over all of the PHI nodes in the successor of BB.
for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
PHINode *PN = cast<PHINode>(I);
-
- redirectValuesFromPredecessorsToPhi(BB, BBPreds, PN);
+ redirectValuesFromPredecessorsToPhi(BB, BBPreds, PN, CommonPred);
}
}
if (Succ->getSinglePredecessor()) {
// BB is the only predecessor of Succ, so Succ will end up with exactly
// the same predecessors BB had.
-
// Copy over any phi, debug or lifetime instruction.
BB->getTerminator()->eraseFromParent();
- Succ->splice(Succ->getFirstNonPHI()->getIterator(), BB);
+ Succ->splice(Succ->getFirstNonPHIIt(), BB);
} else {
while (PHINode *PN = dyn_cast<PHINode>(&BB->front())) {
- // We explicitly check for such uses in CanPropagatePredecessorsForPHIs.
+ // We explicitly check for such uses for merging phis.
assert(PN->use_empty() && "There shouldn't be any uses here!");
PN->eraseFromParent();
}
@@ -1228,26 +1311,42 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
for (BasicBlock *Pred : predecessors(BB))
Pred->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopMD);
- // Everything that jumped to BB now goes to Succ.
- BB->replaceAllUsesWith(Succ);
- if (!Succ->hasName()) Succ->takeName(BB);
+ if (BBKillable) {
+ // Everything that jumped to BB now goes to Succ.
+ BB->replaceAllUsesWith(Succ);
- // Clear the successor list of BB to match updates applying to DTU later.
- if (BB->getTerminator())
- BB->back().eraseFromParent();
- new UnreachableInst(BB->getContext(), BB);
- assert(succ_empty(BB) && "The successor list of BB isn't empty before "
- "applying corresponding DTU updates.");
+ if (!Succ->hasName())
+ Succ->takeName(BB);
+
+ // Clear the successor list of BB to match updates applying to DTU later.
+ if (BB->getTerminator())
+ BB->back().eraseFromParent();
+
+ new UnreachableInst(BB->getContext(), BB);
+ assert(succ_empty(BB) && "The successor list of BB isn't empty before "
+ "applying corresponding DTU updates.");
+ } else if (BBPhisMergeable) {
+ // Everything except CommonPred that jumped to BB now goes to Succ.
+ BB->replaceUsesWithIf(Succ, [BBPreds, CommonPred](Use &U) -> bool {
+ if (Instruction *UseInst = dyn_cast<Instruction>(U.getUser()))
+ return UseInst->getParent() != CommonPred &&
+ BBPreds.contains(UseInst->getParent());
+ return false;
+ });
+ }
if (DTU)
DTU->applyUpdates(Updates);
- DeleteDeadBlock(BB, DTU);
+ if (BBKillable)
+ DeleteDeadBlock(BB, DTU);
return true;
}
-static bool EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB) {
+static bool
+EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB,
+ SmallPtrSetImpl<PHINode *> &ToRemove) {
// This implementation doesn't currently consider undef operands
// specially. Theoretically, two phis which are identical except for
// one having an undef where the other doesn't could be collapsed.
@@ -1263,12 +1362,14 @@ static bool EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB) {
// Note that we only look in the upper square's triangle,
// we already checked that the lower triangle PHI's aren't identical.
for (auto J = I; PHINode *DuplicatePN = dyn_cast<PHINode>(J); ++J) {
+ if (ToRemove.contains(DuplicatePN))
+ continue;
if (!DuplicatePN->isIdenticalToWhenDefined(PN))
continue;
// A duplicate. Replace this PHI with the base PHI.
++NumPHICSEs;
DuplicatePN->replaceAllUsesWith(PN);
- DuplicatePN->eraseFromParent();
+ ToRemove.insert(DuplicatePN);
Changed = true;
// The RAUW can change PHIs that we already visited.
@@ -1279,7 +1380,9 @@ static bool EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB) {
return Changed;
}
-static bool EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB) {
+static bool
+EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB,
+ SmallPtrSetImpl<PHINode *> &ToRemove) {
// This implementation doesn't currently consider undef operands
// specially. Theoretically, two phis which are identical except for
// one having an undef where the other doesn't could be collapsed.
@@ -1343,12 +1446,14 @@ static bool EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB) {
// Examine each PHI.
bool Changed = false;
for (auto I = BB->begin(); PHINode *PN = dyn_cast<PHINode>(I++);) {
+ if (ToRemove.contains(PN))
+ continue;
auto Inserted = PHISet.insert(PN);
if (!Inserted.second) {
// A duplicate. Replace this PHI with its duplicate.
++NumPHICSEs;
PN->replaceAllUsesWith(*Inserted.first);
- PN->eraseFromParent();
+ ToRemove.insert(PN);
Changed = true;
// The RAUW can change PHIs that we already visited. Start over from the
@@ -1361,25 +1466,27 @@ static bool EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB) {
return Changed;
}
-bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) {
+bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB,
+ SmallPtrSetImpl<PHINode *> &ToRemove) {
if (
#ifndef NDEBUG
!PHICSEDebugHash &&
#endif
hasNItemsOrLess(BB->phis(), PHICSENumPHISmallSize))
- return EliminateDuplicatePHINodesNaiveImpl(BB);
- return EliminateDuplicatePHINodesSetBasedImpl(BB);
+ return EliminateDuplicatePHINodesNaiveImpl(BB, ToRemove);
+ return EliminateDuplicatePHINodesSetBasedImpl(BB, ToRemove);
}
-/// If the specified pointer points to an object that we control, try to modify
-/// the object's alignment to PrefAlign. Returns a minimum known alignment of
-/// the value after the operation, which may be lower than PrefAlign.
-///
-/// Increating value alignment isn't often possible though. If alignment is
-/// important, a more reliable approach is to simply align all global variables
-/// and allocation instructions to their preferred alignment from the beginning.
-static Align tryEnforceAlignment(Value *V, Align PrefAlign,
- const DataLayout &DL) {
+bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) {
+ SmallPtrSet<PHINode *, 8> ToRemove;
+ bool Changed = EliminateDuplicatePHINodes(BB, ToRemove);
+ for (PHINode *PN : ToRemove)
+ PN->eraseFromParent();
+ return Changed;
+}
+
+Align llvm::tryEnforceAlignment(Value *V, Align PrefAlign,
+ const DataLayout &DL) {
V = V->stripPointerCasts();
if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) {
@@ -1463,12 +1570,18 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar,
// is removed by LowerDbgDeclare(), we need to make sure that we are
// not inserting the same dbg.value intrinsic over and over.
SmallVector<DbgValueInst *, 1> DbgValues;
- findDbgValues(DbgValues, APN);
+ SmallVector<DPValue *, 1> DPValues;
+ findDbgValues(DbgValues, APN, &DPValues);
for (auto *DVI : DbgValues) {
assert(is_contained(DVI->getValues(), APN));
if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr))
return true;
}
+ for (auto *DPV : DPValues) {
+ assert(is_contained(DPV->location_ops(), APN));
+ if ((DPV->getVariable() == DIVar) && (DPV->getExpression() == DIExpr))
+ return true;
+ }
return false;
}
@@ -1504,6 +1617,67 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) {
// Could not determine size of variable. Conservatively return false.
return false;
}
+// RemoveDIs: duplicate implementation of the above, using DPValues, the
+// replacement for dbg.values.
+static bool valueCoversEntireFragment(Type *ValTy, DPValue *DPV) {
+ const DataLayout &DL = DPV->getModule()->getDataLayout();
+ TypeSize ValueSize = DL.getTypeAllocSizeInBits(ValTy);
+ if (std::optional<uint64_t> FragmentSize = DPV->getFragmentSizeInBits())
+ return TypeSize::isKnownGE(ValueSize, TypeSize::getFixed(*FragmentSize));
+
+ // We can't always calculate the size of the DI variable (e.g. if it is a
+ // VLA). Try to use the size of the alloca that the dbg intrinsic describes
+ // instead.
+ if (DPV->isAddressOfVariable()) {
+ // DPV should have exactly 1 location when it is an address.
+ assert(DPV->getNumVariableLocationOps() == 1 &&
+ "address of variable must have exactly 1 location operand.");
+ if (auto *AI =
+ dyn_cast_or_null<AllocaInst>(DPV->getVariableLocationOp(0))) {
+ if (std::optional<TypeSize> FragmentSize = AI->getAllocationSizeInBits(DL)) {
+ return TypeSize::isKnownGE(ValueSize, *FragmentSize);
+ }
+ }
+ }
+ // Could not determine size of variable. Conservatively return false.
+ return false;
+}
+
+static void insertDbgValueOrDPValue(DIBuilder &Builder, Value *DV,
+ DILocalVariable *DIVar,
+ DIExpression *DIExpr,
+ const DebugLoc &NewLoc,
+ BasicBlock::iterator Instr) {
+ if (!UseNewDbgInfoFormat) {
+ auto *DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc,
+ (Instruction *)nullptr);
+ DbgVal->insertBefore(Instr);
+ } else {
+ // RemoveDIs: if we're using the new debug-info format, allocate a
+ // DPValue directly instead of a dbg.value intrinsic.
+ ValueAsMetadata *DVAM = ValueAsMetadata::get(DV);
+ DPValue *DV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get());
+ Instr->getParent()->insertDPValueBefore(DV, Instr);
+ }
+}
+
+static void insertDbgValueOrDPValueAfter(DIBuilder &Builder, Value *DV,
+ DILocalVariable *DIVar,
+ DIExpression *DIExpr,
+ const DebugLoc &NewLoc,
+ BasicBlock::iterator Instr) {
+ if (!UseNewDbgInfoFormat) {
+ auto *DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc,
+ (Instruction *)nullptr);
+ DbgVal->insertAfter(&*Instr);
+ } else {
+ // RemoveDIs: if we're using the new debug-info format, allocate a
+ // DPValue directly instead of a dbg.value intrinsic.
+ ValueAsMetadata *DVAM = ValueAsMetadata::get(DV);
+ DPValue *DV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get());
+ Instr->getParent()->insertDPValueAfter(DV, &*Instr);
+ }
+}
/// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value
/// that has an associated llvm.dbg.declare intrinsic.
@@ -1533,7 +1707,8 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
DIExpr->isDeref() || (!DIExpr->startsWithDeref() &&
valueCoversEntireFragment(DV->getType(), DII));
if (CanConvert) {
- Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI);
+ insertDbgValueOrDPValue(Builder, DV, DIVar, DIExpr, NewLoc,
+ SI->getIterator());
return;
}
@@ -1545,7 +1720,19 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
// know which part) we insert an dbg.value intrinsic to indicate that we
// know nothing about the variable's content.
DV = UndefValue::get(DV->getType());
- Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI);
+ insertDbgValueOrDPValue(Builder, DV, DIVar, DIExpr, NewLoc,
+ SI->getIterator());
+}
+
+// RemoveDIs: duplicate the getDebugValueLoc method using DPValues instead of
+// dbg.value intrinsics.
+static DebugLoc getDebugValueLocDPV(DPValue *DPV) {
+ // Original dbg.declare must have a location.
+ const DebugLoc &DeclareLoc = DPV->getDebugLoc();
+ MDNode *Scope = DeclareLoc.getScope();
+ DILocation *InlinedAt = DeclareLoc.getInlinedAt();
+ // Produce an unknown location with the correct scope / inlinedAt fields.
+ return DILocation::get(DPV->getContext(), 0, 0, Scope, InlinedAt);
}
/// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value
@@ -1571,9 +1758,40 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
// future if multi-location support is added to the IR, it might be
// preferable to keep tracking both the loaded value and the original
// address in case the alloca can not be elided.
- Instruction *DbgValue = Builder.insertDbgValueIntrinsic(
- LI, DIVar, DIExpr, NewLoc, (Instruction *)nullptr);
- DbgValue->insertAfter(LI);
+ insertDbgValueOrDPValueAfter(Builder, LI, DIVar, DIExpr, NewLoc,
+ LI->getIterator());
+}
+
+void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, StoreInst *SI,
+ DIBuilder &Builder) {
+ assert(DPV->isAddressOfVariable());
+ auto *DIVar = DPV->getVariable();
+ assert(DIVar && "Missing variable");
+ auto *DIExpr = DPV->getExpression();
+ Value *DV = SI->getValueOperand();
+
+ DebugLoc NewLoc = getDebugValueLocDPV(DPV);
+
+ if (!valueCoversEntireFragment(DV->getType(), DPV)) {
+ // FIXME: If storing to a part of the variable described by the dbg.declare,
+ // then we want to insert a DPValue.value for the corresponding fragment.
+ LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV
+ << '\n');
+ // For now, when there is a store to parts of the variable (but we do not
+ // know which part) we insert an DPValue record to indicate that we know
+ // nothing about the variable's content.
+ DV = UndefValue::get(DV->getType());
+ ValueAsMetadata *DVAM = ValueAsMetadata::get(DV);
+ DPValue *NewDPV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get());
+ SI->getParent()->insertDPValueBefore(NewDPV, SI->getIterator());
+ return;
+ }
+
+ assert(UseNewDbgInfoFormat);
+ // Create a DPValue directly and insert.
+ ValueAsMetadata *DVAM = ValueAsMetadata::get(DV);
+ DPValue *NewDPV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get());
+ SI->getParent()->insertDPValueBefore(NewDPV, SI->getIterator());
}
/// Inserts a llvm.dbg.value intrinsic after a phi that has an associated
@@ -1604,8 +1822,38 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII,
// The block may be a catchswitch block, which does not have a valid
// insertion point.
// FIXME: Insert dbg.value markers in the successors when appropriate.
- if (InsertionPt != BB->end())
- Builder.insertDbgValueIntrinsic(APN, DIVar, DIExpr, NewLoc, &*InsertionPt);
+ if (InsertionPt != BB->end()) {
+ insertDbgValueOrDPValue(Builder, APN, DIVar, DIExpr, NewLoc, InsertionPt);
+ }
+}
+
+void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, LoadInst *LI,
+ DIBuilder &Builder) {
+ auto *DIVar = DPV->getVariable();
+ auto *DIExpr = DPV->getExpression();
+ assert(DIVar && "Missing variable");
+
+ if (!valueCoversEntireFragment(LI->getType(), DPV)) {
+ // FIXME: If only referring to a part of the variable described by the
+ // dbg.declare, then we want to insert a DPValue for the corresponding
+ // fragment.
+ LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV
+ << '\n');
+ return;
+ }
+
+ DebugLoc NewLoc = getDebugValueLocDPV(DPV);
+
+ // We are now tracking the loaded value instead of the address. In the
+ // future if multi-location support is added to the IR, it might be
+ // preferable to keep tracking both the loaded value and the original
+ // address in case the alloca can not be elided.
+ assert(UseNewDbgInfoFormat);
+
+ // Create a DPValue directly and insert.
+ ValueAsMetadata *LIVAM = ValueAsMetadata::get(LI);
+ DPValue *DV = new DPValue(LIVAM, DIVar, DIExpr, NewLoc.get());
+ LI->getParent()->insertDPValueAfter(DV, LI);
}
/// Determine whether this alloca is either a VLA or an array.
@@ -1618,6 +1866,36 @@ static bool isArray(AllocaInst *AI) {
static bool isStructure(AllocaInst *AI) {
return AI->getAllocatedType() && AI->getAllocatedType()->isStructTy();
}
+void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, PHINode *APN,
+ DIBuilder &Builder) {
+ auto *DIVar = DPV->getVariable();
+ auto *DIExpr = DPV->getExpression();
+ assert(DIVar && "Missing variable");
+
+ if (PhiHasDebugValue(DIVar, DIExpr, APN))
+ return;
+
+ if (!valueCoversEntireFragment(APN->getType(), DPV)) {
+ // FIXME: If only referring to a part of the variable described by the
+ // dbg.declare, then we want to insert a DPValue for the corresponding
+ // fragment.
+ LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV
+ << '\n');
+ return;
+ }
+
+ BasicBlock *BB = APN->getParent();
+ auto InsertionPt = BB->getFirstInsertionPt();
+
+ DebugLoc NewLoc = getDebugValueLocDPV(DPV);
+
+ // The block may be a catchswitch block, which does not have a valid
+ // insertion point.
+ // FIXME: Insert DPValue markers in the successors when appropriate.
+ if (InsertionPt != BB->end()) {
+ insertDbgValueOrDPValue(Builder, APN, DIVar, DIExpr, NewLoc, InsertionPt);
+ }
+}
/// LowerDbgDeclare - Lowers llvm.dbg.declare intrinsics into appropriate set
/// of llvm.dbg.value intrinsics.
@@ -1674,8 +1952,8 @@ bool llvm::LowerDbgDeclare(Function &F) {
DebugLoc NewLoc = getDebugValueLoc(DDI);
auto *DerefExpr =
DIExpression::append(DDI->getExpression(), dwarf::DW_OP_deref);
- DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), DerefExpr,
- NewLoc, CI);
+ insertDbgValueOrDPValue(DIB, AI, DDI->getVariable(), DerefExpr,
+ NewLoc, CI->getIterator());
}
} else if (BitCastInst *BI = dyn_cast<BitCastInst>(U)) {
if (BI->getType()->isPointerTy())
@@ -1694,6 +1972,69 @@ bool llvm::LowerDbgDeclare(Function &F) {
return Changed;
}
+// RemoveDIs: re-implementation of insertDebugValuesForPHIs, but which pulls the
+// debug-info out of the block's DPValues rather than dbg.value intrinsics.
+static void insertDPValuesForPHIs(BasicBlock *BB,
+ SmallVectorImpl<PHINode *> &InsertedPHIs) {
+ assert(BB && "No BasicBlock to clone DPValue(s) from.");
+ if (InsertedPHIs.size() == 0)
+ return;
+
+ // Map existing PHI nodes to their DPValues.
+ DenseMap<Value *, DPValue *> DbgValueMap;
+ for (auto &I : *BB) {
+ for (auto &DPV : I.getDbgValueRange()) {
+ for (Value *V : DPV.location_ops())
+ if (auto *Loc = dyn_cast_or_null<PHINode>(V))
+ DbgValueMap.insert({Loc, &DPV});
+ }
+ }
+ if (DbgValueMap.size() == 0)
+ return;
+
+ // Map a pair of the destination BB and old DPValue to the new DPValue,
+ // so that if a DPValue is being rewritten to use more than one of the
+ // inserted PHIs in the same destination BB, we can update the same DPValue
+ // with all the new PHIs instead of creating one copy for each.
+ MapVector<std::pair<BasicBlock *, DPValue *>, DPValue *> NewDbgValueMap;
+ // Then iterate through the new PHIs and look to see if they use one of the
+ // previously mapped PHIs. If so, create a new DPValue that will propagate
+ // the info through the new PHI. If we use more than one new PHI in a single
+ // destination BB with the same old dbg.value, merge the updates so that we
+ // get a single new DPValue with all the new PHIs.
+ for (auto PHI : InsertedPHIs) {
+ BasicBlock *Parent = PHI->getParent();
+ // Avoid inserting a debug-info record into an EH block.
+ if (Parent->getFirstNonPHI()->isEHPad())
+ continue;
+ for (auto VI : PHI->operand_values()) {
+ auto V = DbgValueMap.find(VI);
+ if (V != DbgValueMap.end()) {
+ DPValue *DbgII = cast<DPValue>(V->second);
+ auto NewDI = NewDbgValueMap.find({Parent, DbgII});
+ if (NewDI == NewDbgValueMap.end()) {
+ DPValue *NewDbgII = DbgII->clone();
+ NewDI = NewDbgValueMap.insert({{Parent, DbgII}, NewDbgII}).first;
+ }
+ DPValue *NewDbgII = NewDI->second;
+ // If PHI contains VI as an operand more than once, we may
+ // replaced it in NewDbgII; confirm that it is present.
+ if (is_contained(NewDbgII->location_ops(), VI))
+ NewDbgII->replaceVariableLocationOp(VI, PHI);
+ }
+ }
+ }
+ // Insert the new DPValues into their destination blocks.
+ for (auto DI : NewDbgValueMap) {
+ BasicBlock *Parent = DI.first.first;
+ DPValue *NewDbgII = DI.second;
+ auto InsertionPt = Parent->getFirstInsertionPt();
+ assert(InsertionPt != Parent->end() && "Ill-formed basic block");
+
+ InsertionPt->DbgMarker->insertDPValue(NewDbgII, true);
+ }
+}
+
/// Propagate dbg.value intrinsics through the newly inserted PHIs.
void llvm::insertDebugValuesForPHIs(BasicBlock *BB,
SmallVectorImpl<PHINode *> &InsertedPHIs) {
@@ -1701,6 +2042,8 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB,
if (InsertedPHIs.size() == 0)
return;
+ insertDPValuesForPHIs(BB, InsertedPHIs);
+
// Map existing PHI nodes to their dbg.values.
ValueToValueMapTy DbgValueMap;
for (auto &I : *BB) {
@@ -1775,44 +2118,60 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress,
return !DbgDeclares.empty();
}
-static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress,
- DIBuilder &Builder, int Offset) {
- const DebugLoc &Loc = DVI->getDebugLoc();
- auto *DIVar = DVI->getVariable();
- auto *DIExpr = DVI->getExpression();
+static void updateOneDbgValueForAlloca(const DebugLoc &Loc,
+ DILocalVariable *DIVar,
+ DIExpression *DIExpr, Value *NewAddress,
+ DbgValueInst *DVI, DPValue *DPV,
+ DIBuilder &Builder, int Offset) {
assert(DIVar && "Missing variable");
- // This is an alloca-based llvm.dbg.value. The first thing it should do with
- // the alloca pointer is dereference it. Otherwise we don't know how to handle
- // it and give up.
+ // This is an alloca-based dbg.value/DPValue. The first thing it should do
+ // with the alloca pointer is dereference it. Otherwise we don't know how to
+ // handle it and give up.
if (!DIExpr || DIExpr->getNumElements() < 1 ||
DIExpr->getElement(0) != dwarf::DW_OP_deref)
return;
// Insert the offset before the first deref.
- // We could just change the offset argument of dbg.value, but it's unsigned...
if (Offset)
DIExpr = DIExpression::prepend(DIExpr, 0, Offset);
- Builder.insertDbgValueIntrinsic(NewAddress, DIVar, DIExpr, Loc, DVI);
- DVI->eraseFromParent();
+ if (DVI) {
+ DVI->setExpression(DIExpr);
+ DVI->replaceVariableLocationOp(0u, NewAddress);
+ } else {
+ assert(DPV);
+ DPV->setExpression(DIExpr);
+ DPV->replaceVariableLocationOp(0u, NewAddress);
+ }
}
void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress,
DIBuilder &Builder, int Offset) {
- if (auto *L = LocalAsMetadata::getIfExists(AI))
- if (auto *MDV = MetadataAsValue::getIfExists(AI->getContext(), L))
- for (Use &U : llvm::make_early_inc_range(MDV->uses()))
- if (auto *DVI = dyn_cast<DbgValueInst>(U.getUser()))
- replaceOneDbgValueForAlloca(DVI, NewAllocaAddress, Builder, Offset);
+ SmallVector<DbgValueInst *, 1> DbgUsers;
+ SmallVector<DPValue *, 1> DPUsers;
+ findDbgValues(DbgUsers, AI, &DPUsers);
+
+ // Attempt to replace dbg.values that use this alloca.
+ for (auto *DVI : DbgUsers)
+ updateOneDbgValueForAlloca(DVI->getDebugLoc(), DVI->getVariable(),
+ DVI->getExpression(), NewAllocaAddress, DVI,
+ nullptr, Builder, Offset);
+
+ // Replace any DPValues that use this alloca.
+ for (DPValue *DPV : DPUsers)
+ updateOneDbgValueForAlloca(DPV->getDebugLoc(), DPV->getVariable(),
+ DPV->getExpression(), NewAllocaAddress, nullptr,
+ DPV, Builder, Offset);
}
/// Where possible to salvage debug information for \p I do so.
/// If not possible mark undef.
void llvm::salvageDebugInfo(Instruction &I) {
SmallVector<DbgVariableIntrinsic *, 1> DbgUsers;
- findDbgUsers(DbgUsers, &I);
- salvageDebugInfoForDbgValues(I, DbgUsers);
+ SmallVector<DPValue *, 1> DPUsers;
+ findDbgUsers(DbgUsers, &I, &DPUsers);
+ salvageDebugInfoForDbgValues(I, DbgUsers, DPUsers);
}
/// Salvage the address component of \p DAI.
@@ -1850,7 +2209,8 @@ static void salvageDbgAssignAddress(DbgAssignIntrinsic *DAI) {
}
void llvm::salvageDebugInfoForDbgValues(
- Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers) {
+ Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers,
+ ArrayRef<DPValue *> DPUsers) {
// These are arbitrary chosen limits on the maximum number of values and the
// maximum size of a debug expression we can salvage up to, used for
// performance reasons.
@@ -1916,12 +2276,70 @@ void llvm::salvageDebugInfoForDbgValues(
LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n');
Salvaged = true;
}
+ // Duplicate of above block for DPValues.
+ for (auto *DPV : DPUsers) {
+ // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they
+ // are implicitly pointing out the value as a DWARF memory location
+ // description.
+ bool StackValue = DPV->getType() == DPValue::LocationType::Value;
+ auto DPVLocation = DPV->location_ops();
+ assert(
+ is_contained(DPVLocation, &I) &&
+ "DbgVariableIntrinsic must use salvaged instruction as its location");
+ SmallVector<Value *, 4> AdditionalValues;
+ // 'I' may appear more than once in DPV's location ops, and each use of 'I'
+ // must be updated in the DIExpression and potentially have additional
+ // values added; thus we call salvageDebugInfoImpl for each 'I' instance in
+ // DPVLocation.
+ Value *Op0 = nullptr;
+ DIExpression *SalvagedExpr = DPV->getExpression();
+ auto LocItr = find(DPVLocation, &I);
+ while (SalvagedExpr && LocItr != DPVLocation.end()) {
+ SmallVector<uint64_t, 16> Ops;
+ unsigned LocNo = std::distance(DPVLocation.begin(), LocItr);
+ uint64_t CurrentLocOps = SalvagedExpr->getNumLocationOperands();
+ Op0 = salvageDebugInfoImpl(I, CurrentLocOps, Ops, AdditionalValues);
+ if (!Op0)
+ break;
+ SalvagedExpr =
+ DIExpression::appendOpsToArg(SalvagedExpr, Ops, LocNo, StackValue);
+ LocItr = std::find(++LocItr, DPVLocation.end(), &I);
+ }
+ // salvageDebugInfoImpl should fail on examining the first element of
+ // DbgUsers, or none of them.
+ if (!Op0)
+ break;
+
+ DPV->replaceVariableLocationOp(&I, Op0);
+ bool IsValidSalvageExpr =
+ SalvagedExpr->getNumElements() <= MaxExpressionSize;
+ if (AdditionalValues.empty() && IsValidSalvageExpr) {
+ DPV->setExpression(SalvagedExpr);
+ } else if (DPV->getType() == DPValue::LocationType::Value &&
+ IsValidSalvageExpr &&
+ DPV->getNumVariableLocationOps() + AdditionalValues.size() <=
+ MaxDebugArgs) {
+ DPV->addVariableLocationOps(AdditionalValues, SalvagedExpr);
+ } else {
+ // Do not salvage using DIArgList for dbg.addr/dbg.declare, as it is
+ // currently only valid for stack value expressions.
+ // Also do not salvage if the resulting DIArgList would contain an
+ // unreasonably large number of values.
+ Value *Undef = UndefValue::get(I.getOperand(0)->getType());
+ DPV->replaceVariableLocationOp(I.getOperand(0), Undef);
+ }
+ LLVM_DEBUG(dbgs() << "SALVAGE: " << DPV << '\n');
+ Salvaged = true;
+ }
if (Salvaged)
return;
for (auto *DII : DbgUsers)
DII->setKillLocation();
+
+ for (auto *DPV : DPUsers)
+ DPV->setKillLocation();
}
Value *getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL,
@@ -2136,16 +2554,20 @@ using DbgValReplacement = std::optional<DIExpression *>;
/// changes are made.
static bool rewriteDebugUsers(
Instruction &From, Value &To, Instruction &DomPoint, DominatorTree &DT,
- function_ref<DbgValReplacement(DbgVariableIntrinsic &DII)> RewriteExpr) {
+ function_ref<DbgValReplacement(DbgVariableIntrinsic &DII)> RewriteExpr,
+ function_ref<DbgValReplacement(DPValue &DPV)> RewriteDPVExpr) {
// Find debug users of From.
SmallVector<DbgVariableIntrinsic *, 1> Users;
- findDbgUsers(Users, &From);
- if (Users.empty())
+ SmallVector<DPValue *, 1> DPUsers;
+ findDbgUsers(Users, &From, &DPUsers);
+ if (Users.empty() && DPUsers.empty())
return false;
// Prevent use-before-def of To.
bool Changed = false;
+
SmallPtrSet<DbgVariableIntrinsic *, 1> UndefOrSalvage;
+ SmallPtrSet<DPValue *, 1> UndefOrSalvageDPV;
if (isa<Instruction>(&To)) {
bool DomPointAfterFrom = From.getNextNonDebugInstruction() == &DomPoint;
@@ -2163,6 +2585,25 @@ static bool rewriteDebugUsers(
UndefOrSalvage.insert(DII);
}
}
+
+ // DPValue implementation of the above.
+ for (auto *DPV : DPUsers) {
+ Instruction *MarkedInstr = DPV->getMarker()->MarkedInstr;
+ Instruction *NextNonDebug = MarkedInstr;
+ // The next instruction might still be a dbg.declare, skip over it.
+ if (isa<DbgVariableIntrinsic>(NextNonDebug))
+ NextNonDebug = NextNonDebug->getNextNonDebugInstruction();
+
+ if (DomPointAfterFrom && NextNonDebug == &DomPoint) {
+ LLVM_DEBUG(dbgs() << "MOVE: " << *DPV << '\n');
+ DPV->removeFromParent();
+ // Ensure there's a marker.
+ DomPoint.getParent()->insertDPValueAfter(DPV, &DomPoint);
+ Changed = true;
+ } else if (!DT.dominates(&DomPoint, MarkedInstr)) {
+ UndefOrSalvageDPV.insert(DPV);
+ }
+ }
}
// Update debug users without use-before-def risk.
@@ -2179,8 +2620,21 @@ static bool rewriteDebugUsers(
LLVM_DEBUG(dbgs() << "REWRITE: " << *DII << '\n');
Changed = true;
}
+ for (auto *DPV : DPUsers) {
+ if (UndefOrSalvageDPV.count(DPV))
+ continue;
- if (!UndefOrSalvage.empty()) {
+ DbgValReplacement DVR = RewriteDPVExpr(*DPV);
+ if (!DVR)
+ continue;
+
+ DPV->replaceVariableLocationOp(&From, &To);
+ DPV->setExpression(*DVR);
+ LLVM_DEBUG(dbgs() << "REWRITE: " << DPV << '\n');
+ Changed = true;
+ }
+
+ if (!UndefOrSalvage.empty() || !UndefOrSalvageDPV.empty()) {
// Try to salvage the remaining debug users.
salvageDebugInfo(From);
Changed = true;
@@ -2228,12 +2682,15 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To,
auto Identity = [&](DbgVariableIntrinsic &DII) -> DbgValReplacement {
return DII.getExpression();
};
+ auto IdentityDPV = [&](DPValue &DPV) -> DbgValReplacement {
+ return DPV.getExpression();
+ };
// Handle no-op conversions.
Module &M = *From.getModule();
const DataLayout &DL = M.getDataLayout();
if (isBitCastSemanticsPreserving(DL, FromTy, ToTy))
- return rewriteDebugUsers(From, To, DomPoint, DT, Identity);
+ return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDPV);
// Handle integer-to-integer widening and narrowing.
// FIXME: Use DW_OP_convert when it's available everywhere.
@@ -2245,7 +2702,7 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To,
// When the width of the result grows, assume that a debugger will only
// access the low `FromBits` bits when inspecting the source variable.
if (FromBits < ToBits)
- return rewriteDebugUsers(From, To, DomPoint, DT, Identity);
+ return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDPV);
// The width of the result has shrunk. Use sign/zero extension to describe
// the source variable's high bits.
@@ -2261,7 +2718,22 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To,
return DIExpression::appendExt(DII.getExpression(), ToBits, FromBits,
Signed);
};
- return rewriteDebugUsers(From, To, DomPoint, DT, SignOrZeroExt);
+ // RemoveDIs: duplicate implementation working on DPValues rather than on
+ // dbg.value intrinsics.
+ auto SignOrZeroExtDPV = [&](DPValue &DPV) -> DbgValReplacement {
+ DILocalVariable *Var = DPV.getVariable();
+
+ // Without knowing signedness, sign/zero extension isn't possible.
+ auto Signedness = Var->getSignedness();
+ if (!Signedness)
+ return std::nullopt;
+
+ bool Signed = *Signedness == DIBasicType::Signedness::Signed;
+ return DIExpression::appendExt(DPV.getExpression(), ToBits, FromBits,
+ Signed);
+ };
+ return rewriteDebugUsers(From, To, DomPoint, DT, SignOrZeroExt,
+ SignOrZeroExtDPV);
}
// TODO: Floating-point conversions, vectors.
@@ -2275,12 +2747,17 @@ llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) {
// Delete the instructions backwards, as it has a reduced likelihood of
// having to update as many def-use and use-def chains.
Instruction *EndInst = BB->getTerminator(); // Last not to be deleted.
+ // RemoveDIs: erasing debug-info must be done manually.
+ EndInst->dropDbgValues();
while (EndInst != &BB->front()) {
// Delete the next to last instruction.
Instruction *Inst = &*--EndInst->getIterator();
if (!Inst->use_empty() && !Inst->getType()->isTokenTy())
Inst->replaceAllUsesWith(PoisonValue::get(Inst->getType()));
if (Inst->isEHPad() || Inst->getType()->isTokenTy()) {
+ // EHPads can't have DPValues attached to them, but it might be possible
+ // for things with token type.
+ Inst->dropDbgValues();
EndInst = Inst;
continue;
}
@@ -2288,6 +2765,8 @@ llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) {
++NumDeadDbgInst;
else
++NumDeadInst;
+ // RemoveDIs: erasing debug-info must be done manually.
+ Inst->dropDbgValues();
Inst->eraseFromParent();
}
return {NumDeadInst, NumDeadDbgInst};
@@ -2329,6 +2808,7 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool PreserveLCSSA,
Updates.push_back({DominatorTree::Delete, BB, UniqueSuccessor});
DTU->applyUpdates(Updates);
}
+ BB->flushTerminatorDbgValues();
return NumInstrsRemoved;
}
@@ -2482,9 +2962,9 @@ static bool markAliveBlocks(Function &F,
// If we found a call to a no-return function, insert an unreachable
// instruction after it. Make sure there isn't *already* one there
// though.
- if (!isa<UnreachableInst>(CI->getNextNode())) {
+ if (!isa<UnreachableInst>(CI->getNextNonDebugInstruction())) {
// Don't insert a call to llvm.trap right before the unreachable.
- changeToUnreachable(CI->getNextNode(), false, DTU);
+ changeToUnreachable(CI->getNextNonDebugInstruction(), false, DTU);
Changed = true;
}
break;
@@ -2896,9 +3376,10 @@ static unsigned replaceDominatedUsesWith(Value *From, Value *To,
for (Use &U : llvm::make_early_inc_range(From->uses())) {
if (!Dominates(Root, U))
continue;
+ LLVM_DEBUG(dbgs() << "Replace dominated use of '";
+ From->printAsOperand(dbgs());
+ dbgs() << "' with " << *To << " in " << *U.getUser() << "\n");
U.set(To);
- LLVM_DEBUG(dbgs() << "Replace dominated use of '" << From->getName()
- << "' as " << *To << " in " << *U << "\n");
++Count;
}
return Count;
@@ -3017,9 +3498,12 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI,
void llvm::dropDebugUsers(Instruction &I) {
SmallVector<DbgVariableIntrinsic *, 1> DbgUsers;
- findDbgUsers(DbgUsers, &I);
+ SmallVector<DPValue *, 1> DPUsers;
+ findDbgUsers(DbgUsers, &I, &DPUsers);
for (auto *DII : DbgUsers)
DII->eraseFromParent();
+ for (auto *DPV : DPUsers)
+ DPV->eraseFromParent();
}
void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt,
@@ -3051,6 +3535,8 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt,
I->dropUBImplyingAttrsAndMetadata();
if (I->isUsedByMetadata())
dropDebugUsers(*I);
+ // RemoveDIs: drop debug-info too as the following code does.
+ I->dropDbgValues();
if (I->isDebugOrPseudoInst()) {
// Remove DbgInfo and pseudo probe Intrinsics.
II = I->eraseFromParent();
@@ -3063,6 +3549,41 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt,
BB->getTerminator()->getIterator());
}
+DIExpression *llvm::getExpressionForConstant(DIBuilder &DIB, const Constant &C,
+ Type &Ty) {
+ // Create integer constant expression.
+ auto createIntegerExpression = [&DIB](const Constant &CV) -> DIExpression * {
+ const APInt &API = cast<ConstantInt>(&CV)->getValue();
+ std::optional<int64_t> InitIntOpt = API.trySExtValue();
+ return InitIntOpt ? DIB.createConstantValueExpression(
+ static_cast<uint64_t>(*InitIntOpt))
+ : nullptr;
+ };
+
+ if (isa<ConstantInt>(C))
+ return createIntegerExpression(C);
+
+ if (Ty.isFloatTy() || Ty.isDoubleTy()) {
+ const APFloat &APF = cast<ConstantFP>(&C)->getValueAPF();
+ return DIB.createConstantValueExpression(
+ APF.bitcastToAPInt().getZExtValue());
+ }
+
+ if (!Ty.isPointerTy())
+ return nullptr;
+
+ if (isa<ConstantPointerNull>(C))
+ return DIB.createConstantValueExpression(0);
+
+ if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(&C))
+ if (CE->getOpcode() == Instruction::IntToPtr) {
+ const Value *V = CE->getOperand(0);
+ if (auto CI = dyn_cast_or_null<ConstantInt>(V))
+ return createIntegerExpression(*CI);
+ }
+ return nullptr;
+}
+
namespace {
/// A potential constituent of a bitreverse or bswap expression. See
diff --git a/llvm/lib/Transforms/Utils/LoopConstrainer.cpp b/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
new file mode 100644
index 000000000000..ea6d952cfa7d
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
@@ -0,0 +1,904 @@
+#include "llvm/Transforms/Utils/LoopConstrainer.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+
+using namespace llvm;
+
+static const char *ClonedLoopTag = "loop_constrainer.loop.clone";
+
+#define DEBUG_TYPE "loop-constrainer"
+
+/// Given a loop with an deccreasing induction variable, is it possible to
+/// safely calculate the bounds of a new loop using the given Predicate.
+static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
+ const SCEV *Step, ICmpInst::Predicate Pred,
+ unsigned LatchBrExitIdx, Loop *L,
+ ScalarEvolution &SE) {
+ if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
+ Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
+ return false;
+
+ if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
+ return false;
+
+ assert(SE.isKnownNegative(Step) && "expecting negative step");
+
+ LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
+ LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
+ LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
+ LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
+ LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
+ LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
+
+ bool IsSigned = ICmpInst::isSigned(Pred);
+ // The predicate that we need to check that the induction variable lies
+ // within bounds.
+ ICmpInst::Predicate BoundPred =
+ IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
+
+ if (LatchBrExitIdx == 1)
+ return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
+
+ assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
+
+ const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
+ unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
+ APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
+ : APInt::getMinValue(BitWidth);
+ const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
+
+ const SCEV *MinusOne =
+ SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
+
+ return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
+ SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
+}
+
+/// Given a loop with an increasing induction variable, is it possible to
+/// safely calculate the bounds of a new loop using the given Predicate.
+static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
+ const SCEV *Step, ICmpInst::Predicate Pred,
+ unsigned LatchBrExitIdx, Loop *L,
+ ScalarEvolution &SE) {
+ if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
+ Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
+ return false;
+
+ if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
+ return false;
+
+ LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
+ LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
+ LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
+ LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
+ LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
+ LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
+
+ bool IsSigned = ICmpInst::isSigned(Pred);
+ // The predicate that we need to check that the induction variable lies
+ // within bounds.
+ ICmpInst::Predicate BoundPred =
+ IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
+
+ if (LatchBrExitIdx == 1)
+ return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
+
+ assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
+
+ const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
+ unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
+ APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
+ : APInt::getMaxValue(BitWidth);
+ const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
+
+ return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
+ SE.getAddExpr(BoundSCEV, Step)) &&
+ SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
+}
+
+/// Returns estimate for max latch taken count of the loop of the narrowest
+/// available type. If the latch block has such estimate, it is returned.
+/// Otherwise, we use max exit count of whole loop (that is potentially of wider
+/// type than latch check itself), which is still better than no estimate.
+static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
+ const Loop &L) {
+ const SCEV *FromBlock =
+ SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
+ if (isa<SCEVCouldNotCompute>(FromBlock))
+ return SE.getSymbolicMaxBackedgeTakenCount(&L);
+ return FromBlock;
+}
+
+std::optional<LoopStructure>
+LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
+ bool AllowUnsignedLatchCond,
+ const char *&FailureReason) {
+ if (!L.isLoopSimplifyForm()) {
+ FailureReason = "loop not in LoopSimplify form";
+ return std::nullopt;
+ }
+
+ BasicBlock *Latch = L.getLoopLatch();
+ assert(Latch && "Simplified loops only have one latch!");
+
+ if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
+ FailureReason = "loop has already been cloned";
+ return std::nullopt;
+ }
+
+ if (!L.isLoopExiting(Latch)) {
+ FailureReason = "no loop latch";
+ return std::nullopt;
+ }
+
+ BasicBlock *Header = L.getHeader();
+ BasicBlock *Preheader = L.getLoopPreheader();
+ if (!Preheader) {
+ FailureReason = "no preheader";
+ return std::nullopt;
+ }
+
+ BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
+ if (!LatchBr || LatchBr->isUnconditional()) {
+ FailureReason = "latch terminator not conditional branch";
+ return std::nullopt;
+ }
+
+ unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
+
+ ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
+ if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
+ FailureReason = "latch terminator branch not conditional on integral icmp";
+ return std::nullopt;
+ }
+
+ const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
+ if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
+ FailureReason = "could not compute latch count";
+ return std::nullopt;
+ }
+ assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
+ ScalarEvolution::LoopInvariant &&
+ "loop variant exit count doesn't make sense!");
+
+ ICmpInst::Predicate Pred = ICI->getPredicate();
+ Value *LeftValue = ICI->getOperand(0);
+ const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
+ IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
+
+ Value *RightValue = ICI->getOperand(1);
+ const SCEV *RightSCEV = SE.getSCEV(RightValue);
+
+ // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
+ if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
+ if (isa<SCEVAddRecExpr>(RightSCEV)) {
+ std::swap(LeftSCEV, RightSCEV);
+ std::swap(LeftValue, RightValue);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ } else {
+ FailureReason = "no add recurrences in the icmp";
+ return std::nullopt;
+ }
+ }
+
+ auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
+ if (AR->getNoWrapFlags(SCEV::FlagNSW))
+ return true;
+
+ IntegerType *Ty = cast<IntegerType>(AR->getType());
+ IntegerType *WideTy =
+ IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
+
+ const SCEVAddRecExpr *ExtendAfterOp =
+ dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
+ if (ExtendAfterOp) {
+ const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
+ const SCEV *ExtendedStep =
+ SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
+
+ bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
+ ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
+
+ if (NoSignedWrap)
+ return true;
+ }
+
+ // We may have proved this when computing the sign extension above.
+ return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
+ };
+
+ // `ICI` is interpreted as taking the backedge if the *next* value of the
+ // induction variable satisfies some constraint.
+
+ const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
+ if (IndVarBase->getLoop() != &L) {
+ FailureReason = "LHS in cmp is not an AddRec for this loop";
+ return std::nullopt;
+ }
+ if (!IndVarBase->isAffine()) {
+ FailureReason = "LHS in icmp not induction variable";
+ return std::nullopt;
+ }
+ const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
+ if (!isa<SCEVConstant>(StepRec)) {
+ FailureReason = "LHS in icmp not induction variable";
+ return std::nullopt;
+ }
+ ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
+
+ if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
+ FailureReason = "LHS in icmp needs nsw for equality predicates";
+ return std::nullopt;
+ }
+
+ assert(!StepCI->isZero() && "Zero step?");
+ bool IsIncreasing = !StepCI->isNegative();
+ bool IsSignedPredicate;
+ const SCEV *StartNext = IndVarBase->getStart();
+ const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
+ const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
+ const SCEV *Step = SE.getSCEV(StepCI);
+
+ const SCEV *FixedRightSCEV = nullptr;
+
+ // If RightValue resides within loop (but still being loop invariant),
+ // regenerate it as preheader.
+ if (auto *I = dyn_cast<Instruction>(RightValue))
+ if (L.contains(I->getParent()))
+ FixedRightSCEV = RightSCEV;
+
+ if (IsIncreasing) {
+ bool DecreasedRightValueByOne = false;
+ if (StepCI->isOne()) {
+ // Try to turn eq/ne predicates to those we can work with.
+ if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
+ // while (++i != len) { while (++i < len) {
+ // ... ---> ...
+ // } }
+ // If both parts are known non-negative, it is profitable to use
+ // unsigned comparison in increasing loop. This allows us to make the
+ // comparison check against "RightSCEV + 1" more optimistic.
+ if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
+ isKnownNonNegativeInLoop(RightSCEV, &L, SE))
+ Pred = ICmpInst::ICMP_ULT;
+ else
+ Pred = ICmpInst::ICMP_SLT;
+ else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
+ // while (true) { while (true) {
+ // if (++i == len) ---> if (++i > len - 1)
+ // break; break;
+ // ... ...
+ // } }
+ if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
+ cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
+ Pred = ICmpInst::ICMP_UGT;
+ RightSCEV =
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
+ DecreasedRightValueByOne = true;
+ } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
+ Pred = ICmpInst::ICMP_SGT;
+ RightSCEV =
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
+ DecreasedRightValueByOne = true;
+ }
+ }
+ }
+
+ bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
+ bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
+ bool FoundExpectedPred =
+ (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
+
+ if (!FoundExpectedPred) {
+ FailureReason = "expected icmp slt semantically, found something else";
+ return std::nullopt;
+ }
+
+ IsSignedPredicate = ICmpInst::isSigned(Pred);
+ if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
+ FailureReason = "unsigned latch conditions are explicitly prohibited";
+ return std::nullopt;
+ }
+
+ if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
+ LatchBrExitIdx, &L, SE)) {
+ FailureReason = "Unsafe loop bounds";
+ return std::nullopt;
+ }
+ if (LatchBrExitIdx == 0) {
+ // We need to increase the right value unless we have already decreased
+ // it virtually when we replaced EQ with SGT.
+ if (!DecreasedRightValueByOne)
+ FixedRightSCEV =
+ SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
+ } else {
+ assert(!DecreasedRightValueByOne &&
+ "Right value can be decreased only for LatchBrExitIdx == 0!");
+ }
+ } else {
+ bool IncreasedRightValueByOne = false;
+ if (StepCI->isMinusOne()) {
+ // Try to turn eq/ne predicates to those we can work with.
+ if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
+ // while (--i != len) { while (--i > len) {
+ // ... ---> ...
+ // } }
+ // We intentionally don't turn the predicate into UGT even if we know
+ // that both operands are non-negative, because it will only pessimize
+ // our check against "RightSCEV - 1".
+ Pred = ICmpInst::ICMP_SGT;
+ else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
+ // while (true) { while (true) {
+ // if (--i == len) ---> if (--i < len + 1)
+ // break; break;
+ // ... ...
+ // } }
+ if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
+ cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
+ Pred = ICmpInst::ICMP_ULT;
+ RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
+ IncreasedRightValueByOne = true;
+ } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
+ Pred = ICmpInst::ICMP_SLT;
+ RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
+ IncreasedRightValueByOne = true;
+ }
+ }
+ }
+
+ bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
+ bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
+
+ bool FoundExpectedPred =
+ (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
+
+ if (!FoundExpectedPred) {
+ FailureReason = "expected icmp sgt semantically, found something else";
+ return std::nullopt;
+ }
+
+ IsSignedPredicate =
+ Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
+
+ if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
+ FailureReason = "unsigned latch conditions are explicitly prohibited";
+ return std::nullopt;
+ }
+
+ if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
+ LatchBrExitIdx, &L, SE)) {
+ FailureReason = "Unsafe bounds";
+ return std::nullopt;
+ }
+
+ if (LatchBrExitIdx == 0) {
+ // We need to decrease the right value unless we have already increased
+ // it virtually when we replaced EQ with SLT.
+ if (!IncreasedRightValueByOne)
+ FixedRightSCEV =
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
+ } else {
+ assert(!IncreasedRightValueByOne &&
+ "Right value can be increased only for LatchBrExitIdx == 0!");
+ }
+ }
+ BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
+
+ assert(!L.contains(LatchExit) && "expected an exit block!");
+ const DataLayout &DL = Preheader->getModule()->getDataLayout();
+ SCEVExpander Expander(SE, DL, "loop-constrainer");
+ Instruction *Ins = Preheader->getTerminator();
+
+ if (FixedRightSCEV)
+ RightValue =
+ Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
+
+ Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
+ IndVarStartV->setName("indvar.start");
+
+ LoopStructure Result;
+
+ Result.Tag = "main";
+ Result.Header = Header;
+ Result.Latch = Latch;
+ Result.LatchBr = LatchBr;
+ Result.LatchExit = LatchExit;
+ Result.LatchBrExitIdx = LatchBrExitIdx;
+ Result.IndVarStart = IndVarStartV;
+ Result.IndVarStep = StepCI;
+ Result.IndVarBase = LeftValue;
+ Result.IndVarIncreasing = IsIncreasing;
+ Result.LoopExitAt = RightValue;
+ Result.IsSignedPredicate = IsSignedPredicate;
+ Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
+
+ FailureReason = nullptr;
+
+ return Result;
+}
+
+// Add metadata to the loop L to disable loop optimizations. Callers need to
+// confirm that optimizing loop L is not beneficial.
+static void DisableAllLoopOptsOnLoop(Loop &L) {
+ // We do not care about any existing loopID related metadata for L, since we
+ // are setting all loop metadata to false.
+ LLVMContext &Context = L.getHeader()->getContext();
+ // Reserve first location for self reference to the LoopID metadata node.
+ MDNode *Dummy = MDNode::get(Context, {});
+ MDNode *DisableUnroll = MDNode::get(
+ Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
+ Metadata *FalseVal =
+ ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
+ MDNode *DisableVectorize = MDNode::get(
+ Context,
+ {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
+ MDNode *DisableLICMVersioning = MDNode::get(
+ Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
+ MDNode *DisableDistribution = MDNode::get(
+ Context,
+ {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
+ MDNode *NewLoopID =
+ MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
+ DisableLICMVersioning, DisableDistribution});
+ // Set operand 0 to refer to the loop id itself.
+ NewLoopID->replaceOperandWith(0, NewLoopID);
+ L.setLoopID(NewLoopID);
+}
+
+LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
+ function_ref<void(Loop *, bool)> LPMAddNewLoop,
+ const LoopStructure &LS, ScalarEvolution &SE,
+ DominatorTree &DT, Type *T, SubRanges SR)
+ : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
+ DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
+ MainLoopStructure(LS), SR(SR) {}
+
+void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
+ const char *Tag) const {
+ for (BasicBlock *BB : OriginalLoop.getBlocks()) {
+ BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
+ Result.Blocks.push_back(Clone);
+ Result.Map[BB] = Clone;
+ }
+
+ auto GetClonedValue = [&Result](Value *V) {
+ assert(V && "null values not in domain!");
+ auto It = Result.Map.find(V);
+ if (It == Result.Map.end())
+ return V;
+ return static_cast<Value *>(It->second);
+ };
+
+ auto *ClonedLatch =
+ cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
+ ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
+ MDNode::get(Ctx, {}));
+
+ Result.Structure = MainLoopStructure.map(GetClonedValue);
+ Result.Structure.Tag = Tag;
+
+ for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
+ BasicBlock *ClonedBB = Result.Blocks[i];
+ BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
+
+ assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
+
+ for (Instruction &I : *ClonedBB)
+ RemapInstruction(&I, Result.Map,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+
+ // Exit blocks will now have one more predecessor and their PHI nodes need
+ // to be edited to reflect that. No phi nodes need to be introduced because
+ // the loop is in LCSSA.
+
+ for (auto *SBB : successors(OriginalBB)) {
+ if (OriginalLoop.contains(SBB))
+ continue; // not an exit block
+
+ for (PHINode &PN : SBB->phis()) {
+ Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
+ PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
+ SE.forgetValue(&PN);
+ }
+ }
+ }
+}
+
+LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
+ const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
+ BasicBlock *ContinuationBlock) const {
+ // We start with a loop with a single latch:
+ //
+ // +--------------------+
+ // | |
+ // | preheader |
+ // | |
+ // +--------+-----------+
+ // | ----------------\
+ // | / |
+ // +--------v----v------+ |
+ // | | |
+ // | header | |
+ // | | |
+ // +--------------------+ |
+ // |
+ // ..... |
+ // |
+ // +--------------------+ |
+ // | | |
+ // | latch >----------/
+ // | |
+ // +-------v------------+
+ // |
+ // |
+ // | +--------------------+
+ // | | |
+ // +---> original exit |
+ // | |
+ // +--------------------+
+ //
+ // We change the control flow to look like
+ //
+ //
+ // +--------------------+
+ // | |
+ // | preheader >-------------------------+
+ // | | |
+ // +--------v-----------+ |
+ // | /-------------+ |
+ // | / | |
+ // +--------v--v--------+ | |
+ // | | | |
+ // | header | | +--------+ |
+ // | | | | | |
+ // +--------------------+ | | +-----v-----v-----------+
+ // | | | |
+ // | | | .pseudo.exit |
+ // | | | |
+ // | | +-----------v-----------+
+ // | | |
+ // ..... | | |
+ // | | +--------v-------------+
+ // +--------------------+ | | | |
+ // | | | | | ContinuationBlock |
+ // | latch >------+ | | |
+ // | | | +----------------------+
+ // +---------v----------+ |
+ // | |
+ // | |
+ // | +---------------^-----+
+ // | | |
+ // +-----> .exit.selector |
+ // | |
+ // +----------v----------+
+ // |
+ // +--------------------+ |
+ // | | |
+ // | original exit <----+
+ // | |
+ // +--------------------+
+
+ RewrittenRangeInfo RRI;
+
+ BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
+ RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
+ &F, BBInsertLocation);
+ RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
+ BBInsertLocation);
+
+ BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
+ bool Increasing = LS.IndVarIncreasing;
+ bool IsSignedPredicate = LS.IsSignedPredicate;
+
+ IRBuilder<> B(PreheaderJump);
+ auto NoopOrExt = [&](Value *V) {
+ if (V->getType() == RangeTy)
+ return V;
+ return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
+ : B.CreateZExt(V, RangeTy, "wide." + V->getName());
+ };
+
+ // EnterLoopCond - is it okay to start executing this `LS'?
+ Value *EnterLoopCond = nullptr;
+ auto Pred =
+ Increasing
+ ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
+ : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
+ Value *IndVarStart = NoopOrExt(LS.IndVarStart);
+ EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
+
+ B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
+ PreheaderJump->eraseFromParent();
+
+ LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
+ B.SetInsertPoint(LS.LatchBr);
+ Value *IndVarBase = NoopOrExt(LS.IndVarBase);
+ Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
+
+ Value *CondForBranch = LS.LatchBrExitIdx == 1
+ ? TakeBackedgeLoopCond
+ : B.CreateNot(TakeBackedgeLoopCond);
+
+ LS.LatchBr->setCondition(CondForBranch);
+
+ B.SetInsertPoint(RRI.ExitSelector);
+
+ // IterationsLeft - are there any more iterations left, given the original
+ // upper bound on the induction variable? If not, we branch to the "real"
+ // exit.
+ Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
+ Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
+ B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
+
+ BranchInst *BranchToContinuation =
+ BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
+
+ // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
+ // each of the PHI nodes in the loop header. This feeds into the initial
+ // value of the same PHI nodes if/when we continue execution.
+ for (PHINode &PN : LS.Header->phis()) {
+ PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
+ BranchToContinuation);
+
+ NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
+ NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
+ RRI.ExitSelector);
+ RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
+ }
+
+ RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
+ BranchToContinuation);
+ RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
+ RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
+
+ // The latch exit now has a branch from `RRI.ExitSelector' instead of
+ // `LS.Latch'. The PHI nodes need to be updated to reflect that.
+ LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
+
+ return RRI;
+}
+
+void LoopConstrainer::rewriteIncomingValuesForPHIs(
+ LoopStructure &LS, BasicBlock *ContinuationBlock,
+ const LoopConstrainer::RewrittenRangeInfo &RRI) const {
+ unsigned PHIIndex = 0;
+ for (PHINode &PN : LS.Header->phis())
+ PN.setIncomingValueForBlock(ContinuationBlock,
+ RRI.PHIValuesAtPseudoExit[PHIIndex++]);
+
+ LS.IndVarStart = RRI.IndVarEnd;
+}
+
+BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
+ BasicBlock *OldPreheader,
+ const char *Tag) const {
+ BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
+ BranchInst::Create(LS.Header, Preheader);
+
+ LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
+
+ return Preheader;
+}
+
+void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
+ Loop *ParentLoop = OriginalLoop.getParentLoop();
+ if (!ParentLoop)
+ return;
+
+ for (BasicBlock *BB : BBs)
+ ParentLoop->addBasicBlockToLoop(BB, LI);
+}
+
+Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
+ ValueToValueMapTy &VM,
+ bool IsSubloop) {
+ Loop &New = *LI.AllocateLoop();
+ if (Parent)
+ Parent->addChildLoop(&New);
+ else
+ LI.addTopLevelLoop(&New);
+ LPMAddNewLoop(&New, IsSubloop);
+
+ // Add all of the blocks in Original to the new loop.
+ for (auto *BB : Original->blocks())
+ if (LI.getLoopFor(BB) == Original)
+ New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
+
+ // Add all of the subloops to the new loop.
+ for (Loop *SubLoop : *Original)
+ createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
+
+ return &New;
+}
+
+bool LoopConstrainer::run() {
+ BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
+ assert(Preheader != nullptr && "precondition!");
+
+ OriginalPreheader = Preheader;
+ MainLoopPreheader = Preheader;
+ bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
+ bool Increasing = MainLoopStructure.IndVarIncreasing;
+ IntegerType *IVTy = cast<IntegerType>(RangeTy);
+
+ SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer");
+ Instruction *InsertPt = OriginalPreheader->getTerminator();
+
+ // It would have been better to make `PreLoop' and `PostLoop'
+ // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
+ // constructor.
+ ClonedLoop PreLoop, PostLoop;
+ bool NeedsPreLoop =
+ Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
+ bool NeedsPostLoop =
+ Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
+
+ Value *ExitPreLoopAt = nullptr;
+ Value *ExitMainLoopAt = nullptr;
+ const SCEVConstant *MinusOneS =
+ cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
+
+ if (NeedsPreLoop) {
+ const SCEV *ExitPreLoopAtSCEV = nullptr;
+
+ if (Increasing)
+ ExitPreLoopAtSCEV = *SR.LowLimit;
+ else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
+ IsSignedPredicate))
+ ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
+ else {
+ LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
+ << "preloop exit limit. HighLimit = "
+ << *(*SR.HighLimit) << "\n");
+ return false;
+ }
+
+ if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
+ LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
+ << " preloop exit limit " << *ExitPreLoopAtSCEV
+ << " at block " << InsertPt->getParent()->getName()
+ << "\n");
+ return false;
+ }
+
+ ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
+ ExitPreLoopAt->setName("exit.preloop.at");
+ }
+
+ if (NeedsPostLoop) {
+ const SCEV *ExitMainLoopAtSCEV = nullptr;
+
+ if (Increasing)
+ ExitMainLoopAtSCEV = *SR.HighLimit;
+ else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
+ IsSignedPredicate))
+ ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
+ else {
+ LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
+ << "mainloop exit limit. LowLimit = "
+ << *(*SR.LowLimit) << "\n");
+ return false;
+ }
+
+ if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
+ LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
+ << " main loop exit limit " << *ExitMainLoopAtSCEV
+ << " at block " << InsertPt->getParent()->getName()
+ << "\n");
+ return false;
+ }
+
+ ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
+ ExitMainLoopAt->setName("exit.mainloop.at");
+ }
+
+ // We clone these ahead of time so that we don't have to deal with changing
+ // and temporarily invalid IR as we transform the loops.
+ if (NeedsPreLoop)
+ cloneLoop(PreLoop, "preloop");
+ if (NeedsPostLoop)
+ cloneLoop(PostLoop, "postloop");
+
+ RewrittenRangeInfo PreLoopRRI;
+
+ if (NeedsPreLoop) {
+ Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
+ PreLoop.Structure.Header);
+
+ MainLoopPreheader =
+ createPreheader(MainLoopStructure, Preheader, "mainloop");
+ PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
+ ExitPreLoopAt, MainLoopPreheader);
+ rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
+ PreLoopRRI);
+ }
+
+ BasicBlock *PostLoopPreheader = nullptr;
+ RewrittenRangeInfo PostLoopRRI;
+
+ if (NeedsPostLoop) {
+ PostLoopPreheader =
+ createPreheader(PostLoop.Structure, Preheader, "postloop");
+ PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
+ ExitMainLoopAt, PostLoopPreheader);
+ rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
+ PostLoopRRI);
+ }
+
+ BasicBlock *NewMainLoopPreheader =
+ MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
+ BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit,
+ PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit,
+ PostLoopRRI.ExitSelector, NewMainLoopPreheader};
+
+ // Some of the above may be nullptr, filter them out before passing to
+ // addToParentLoopIfNeeded.
+ auto NewBlocksEnd =
+ std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
+
+ addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
+
+ DT.recalculate(F);
+
+ // We need to first add all the pre and post loop blocks into the loop
+ // structures (as part of createClonedLoopStructure), and then update the
+ // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
+ // LI when LoopSimplifyForm is generated.
+ Loop *PreL = nullptr, *PostL = nullptr;
+ if (!PreLoop.Blocks.empty()) {
+ PreL = createClonedLoopStructure(&OriginalLoop,
+ OriginalLoop.getParentLoop(), PreLoop.Map,
+ /* IsSubLoop */ false);
+ }
+
+ if (!PostLoop.Blocks.empty()) {
+ PostL =
+ createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
+ PostLoop.Map, /* IsSubLoop */ false);
+ }
+
+ // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
+ auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
+ formLCSSARecursively(*L, DT, &LI, &SE);
+ simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
+ // Pre/post loops are slow paths, we do not need to perform any loop
+ // optimizations on them.
+ if (!IsOriginalLoop)
+ DisableAllLoopOptsOnLoop(*L);
+ };
+ if (PreL)
+ CanonicalizeLoop(PreL, false);
+ if (PostL)
+ CanonicalizeLoop(PostL, false);
+ CanonicalizeLoop(&OriginalLoop, true);
+
+ /// At this point:
+ /// - We've broken a "main loop" out of the loop in a way that the "main loop"
+ /// runs with the induction variable in a subset of [Begin, End).
+ /// - There is no overflow when computing "main loop" exit limit.
+ /// - Max latch taken count of the loop is limited.
+ /// It guarantees that induction variable will not overflow iterating in the
+ /// "main loop".
+ if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
+ if (IsSignedPredicate)
+ cast<BinaryOperator>(MainLoopStructure.IndVarBase)
+ ->setHasNoSignedWrap(true);
+ /// TODO: support unsigned predicate.
+ /// To add NUW flag we need to prove that both operands of BO are
+ /// non-negative. E.g:
+ /// ...
+ /// %iv.next = add nsw i32 %iv, -1
+ /// %cmp = icmp ult i32 %iv.next, %n
+ /// br i1 %cmp, label %loopexit, label %loop
+ ///
+ /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
+ /// overflow, therefore NUW flag is not legal here.
+
+ return true;
+}
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index d701cf110154..f76fa3bb6c61 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -351,11 +351,20 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
MaxPeelCount =
std::min((unsigned)SC->getAPInt().getLimitedValue() - 1, MaxPeelCount);
- auto ComputePeelCount = [&](Value *Condition) -> void {
- if (!Condition->getType()->isIntegerTy())
+ const unsigned MaxDepth = 4;
+ std::function<void(Value *, unsigned)> ComputePeelCount =
+ [&](Value *Condition, unsigned Depth) -> void {
+ if (!Condition->getType()->isIntegerTy() || Depth >= MaxDepth)
return;
Value *LeftVal, *RightVal;
+ if (match(Condition, m_And(m_Value(LeftVal), m_Value(RightVal))) ||
+ match(Condition, m_Or(m_Value(LeftVal), m_Value(RightVal)))) {
+ ComputePeelCount(LeftVal, Depth + 1);
+ ComputePeelCount(RightVal, Depth + 1);
+ return;
+ }
+
CmpInst::Predicate Pred;
if (!match(Condition, m_ICmp(Pred, m_Value(LeftVal), m_Value(RightVal))))
return;
@@ -443,7 +452,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
for (BasicBlock *BB : L.blocks()) {
for (Instruction &I : *BB) {
if (SelectInst *SI = dyn_cast<SelectInst>(&I))
- ComputePeelCount(SI->getCondition());
+ ComputePeelCount(SI->getCondition(), 0);
}
auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
@@ -454,7 +463,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
if (L.getLoopLatch() == BB)
continue;
- ComputePeelCount(BI->getCondition());
+ ComputePeelCount(BI->getCondition(), 0);
}
return DesiredPeelCount;
@@ -624,21 +633,24 @@ struct WeightInfo {
/// F/(F+E) is a probability to go to loop and E/(F+E) is a probability to
/// go to exit.
/// Then, Estimated ExitCount = F / E.
-/// For I-th (counting from 0) peeled off iteration we set the the weights for
+/// For I-th (counting from 0) peeled off iteration we set the weights for
/// the peeled exit as (EC - I, 1). It gives us reasonable distribution,
/// The probability to go to exit 1/(EC-I) increases. At the same time
/// the estimated exit count in the remainder loop reduces by I.
/// To avoid dealing with division rounding we can just multiple both part
/// of weights to E and use weight as (F - I * E, E).
static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
- MDBuilder MDB(Term->getContext());
- Term->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Info.Weights));
+ setBranchWeights(*Term, Info.Weights);
for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
if (SubWeight != 0)
- Info.Weights[Idx] = Info.Weights[Idx] > SubWeight
- ? Info.Weights[Idx] - SubWeight
- : 1;
+ // Don't set the probability of taking the edge from latch to loop header
+ // to less than 1:1 ratio (meaning Weight should not be lower than
+ // SubWeight), as this could significantly reduce the loop's hotness,
+ // which would be incorrect in the case of underestimating the trip count.
+ Info.Weights[Idx] =
+ Info.Weights[Idx] > SubWeight
+ ? std::max(Info.Weights[Idx] - SubWeight, SubWeight)
+ : SubWeight;
}
/// Initialize the weights for all exiting blocks.
@@ -685,14 +697,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
}
}
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
- MDBuilder MDB(Term->getContext());
- Term->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Info.Weights));
-}
-
/// Clones the body of the loop L, putting it between \p InsertTop and \p
/// InsertBot.
/// \param IterNumber The serial number of the iteration currently being
@@ -1028,8 +1032,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
}
- for (const auto &[Term, Info] : Weights)
- fixupBranchWeights(Term, Info);
+ for (const auto &[Term, Info] : Weights) {
+ setBranchWeights(*Term, Info.Weights);
+ }
// Update Metadata for count of peeled off iterations.
unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index d81db5647c60..76280ed492b3 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -25,6 +25,8 @@
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -50,6 +52,9 @@ static cl::opt<bool>
cl::desc("Allow loop rotation multiple times in order to reach "
"a better latch exit"));
+// Probability that a rotated loop has zero trip count / is never entered.
+static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};
+
namespace {
/// A simple loop rotation transformation.
class LoopRotate {
@@ -154,7 +159,8 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,
// Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug
// intrinsics.
SmallVector<DbgValueInst *, 1> DbgValues;
- llvm::findDbgValues(DbgValues, OrigHeaderVal);
+ SmallVector<DPValue *, 1> DPValues;
+ llvm::findDbgValues(DbgValues, OrigHeaderVal, &DPValues);
for (auto &DbgValue : DbgValues) {
// The original users in the OrigHeader are already using the original
// definitions.
@@ -175,6 +181,29 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,
NewVal = UndefValue::get(OrigHeaderVal->getType());
DbgValue->replaceVariableLocationOp(OrigHeaderVal, NewVal);
}
+
+ // RemoveDIs: duplicate implementation for non-instruction debug-info
+ // storage in DPValues.
+ for (DPValue *DPV : DPValues) {
+ // The original users in the OrigHeader are already using the original
+ // definitions.
+ BasicBlock *UserBB = DPV->getMarker()->getParent();
+ if (UserBB == OrigHeader)
+ continue;
+
+ // Users in the OrigPreHeader need to use the value to which the
+ // original definitions are mapped and anything else can be handled by
+ // the SSAUpdater. To avoid adding PHINodes, check if the value is
+ // available in UserBB, if not substitute undef.
+ Value *NewVal;
+ if (UserBB == OrigPreheader)
+ NewVal = OrigPreHeaderVal;
+ else if (SSA.HasValueForBlock(UserBB))
+ NewVal = SSA.GetValueInMiddleOfBlock(UserBB);
+ else
+ NewVal = UndefValue::get(OrigHeaderVal->getType());
+ DPV->replaceVariableLocationOp(OrigHeaderVal, NewVal);
+ }
}
}
@@ -244,6 +273,123 @@ static bool canRotateDeoptimizingLatchExit(Loop *L) {
return false;
}
+static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
+ bool HasConditionalPreHeader,
+ bool SuccsSwapped) {
+ MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
+ if (WeightMD == nullptr)
+ return;
+
+ // LoopBI should currently be a clone of PreHeaderBI with the same
+ // metadata. But we double check to make sure we don't have a degenerate case
+ // where instsimplify changed the instructions.
+ if (WeightMD != getBranchWeightMDNode(LoopBI))
+ return;
+
+ SmallVector<uint32_t, 2> Weights;
+ extractFromBranchWeightMD(WeightMD, Weights);
+ if (Weights.size() != 2)
+ return;
+ uint32_t OrigLoopExitWeight = Weights[0];
+ uint32_t OrigLoopBackedgeWeight = Weights[1];
+
+ if (SuccsSwapped)
+ std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
+
+ // Update branch weights. Consider the following edge-counts:
+ //
+ // | |-------- |
+ // V V | V
+ // Br i1 ... | Br i1 ...
+ // | | | | |
+ // x| y| | becomes: | y0| |-----
+ // V V | | V V |
+ // Exit Loop | | Loop |
+ // | | | Br i1 ... |
+ // ----- | | | |
+ // x0| x1| y1 | |
+ // V V ----
+ // Exit
+ //
+ // The following must hold:
+ // - x == x0 + x1 # counts to "exit" must stay the same.
+ // - y0 == x - x0 == x1 # how often loop was entered at all.
+ // - y1 == y - y0 # How often loop was repeated (after first iter.).
+ //
+ // We cannot generally deduce how often we had a zero-trip count loop so we
+ // have to make a guess for how to distribute x among the new x0 and x1.
+
+ uint32_t ExitWeight0; // aka x0
+ uint32_t ExitWeight1; // aka x1
+ uint32_t EnterWeight; // aka y0
+ uint32_t LoopBackWeight; // aka y1
+ if (OrigLoopExitWeight > 0 && OrigLoopBackedgeWeight > 0) {
+ ExitWeight0 = 0;
+ if (HasConditionalPreHeader) {
+ // Here we cannot know how many 0-trip count loops we have, so we guess:
+ if (OrigLoopBackedgeWeight >= OrigLoopExitWeight) {
+ // If the loop count is bigger than the exit count then we set
+ // probabilities as if 0-trip count nearly never happens.
+ ExitWeight0 = ZeroTripCountWeights[0];
+ // Scale up counts if necessary so we can match `ZeroTripCountWeights`
+ // for the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1`) ratio.
+ while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
+ // ... but don't overflow.
+ uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
+ if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
+ (OrigLoopExitWeight & HighBit) != 0)
+ break;
+ OrigLoopBackedgeWeight <<= 1;
+ OrigLoopExitWeight <<= 1;
+ }
+ } else {
+ // If there's a higher exit-count than backedge-count then we set
+ // probabilities as if there are only 0-trip and 1-trip cases.
+ ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
+ }
+ }
+ ExitWeight1 = OrigLoopExitWeight - ExitWeight0;
+ EnterWeight = ExitWeight1;
+ LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight;
+ } else if (OrigLoopExitWeight == 0) {
+ if (OrigLoopBackedgeWeight == 0) {
+ // degenerate case... keep everything zero...
+ ExitWeight0 = 0;
+ ExitWeight1 = 0;
+ EnterWeight = 0;
+ LoopBackWeight = 0;
+ } else {
+ // Special case "LoopExitWeight == 0" weights, which behaves like an
+ // endless loop where we don't want loop-entry (y0) to be the same as
+ // loop-exit (x1).
+ ExitWeight0 = 0;
+ ExitWeight1 = 0;
+ EnterWeight = 1;
+ LoopBackWeight = OrigLoopBackedgeWeight;
+ }
+ } else {
+ // loop is never entered.
+ assert(OrigLoopBackedgeWeight == 0 && "remaining case is backedge zero");
+ ExitWeight0 = 1;
+ ExitWeight1 = 1;
+ EnterWeight = 0;
+ LoopBackWeight = 0;
+ }
+
+ const uint32_t LoopBIWeights[] = {
+ SuccsSwapped ? LoopBackWeight : ExitWeight1,
+ SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+ };
+ setBranchWeights(LoopBI, LoopBIWeights);
+ if (HasConditionalPreHeader) {
+ const uint32_t PreHeaderBIWeights[] = {
+ SuccsSwapped ? EnterWeight : ExitWeight0,
+ SuccsSwapped ? ExitWeight0 : EnterWeight,
+ };
+ setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
+ }
+}
+
/// Rotate loop LP. Return true if the loop is rotated.
///
/// \param SimplifiedLatch is true if the latch was just folded into the final
@@ -363,7 +509,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// loop. Otherwise loop is not suitable for rotation.
BasicBlock *Exit = BI->getSuccessor(0);
BasicBlock *NewHeader = BI->getSuccessor(1);
- if (L->contains(Exit))
+ bool BISuccsSwapped = L->contains(Exit);
+ if (BISuccsSwapped)
std::swap(Exit, NewHeader);
assert(NewHeader && "Unable to determine new loop header");
assert(L->contains(NewHeader) && !L->contains(Exit) &&
@@ -394,20 +541,32 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// duplication.
using DbgIntrinsicHash =
std::pair<std::pair<hash_code, DILocalVariable *>, DIExpression *>;
- auto makeHash = [](DbgVariableIntrinsic *D) -> DbgIntrinsicHash {
+ auto makeHash = [](auto *D) -> DbgIntrinsicHash {
auto VarLocOps = D->location_ops();
return {{hash_combine_range(VarLocOps.begin(), VarLocOps.end()),
D->getVariable()},
D->getExpression()};
};
+
SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics;
for (Instruction &I : llvm::drop_begin(llvm::reverse(*OrigPreheader))) {
- if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&I))
+ if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&I)) {
DbgIntrinsics.insert(makeHash(DII));
- else
+ // Until RemoveDIs supports dbg.declares in DPValue format, we'll need
+ // to collect DPValues attached to any other debug intrinsics.
+ for (const DPValue &DPV : DII->getDbgValueRange())
+ DbgIntrinsics.insert(makeHash(&DPV));
+ } else {
break;
+ }
}
+ // Build DPValue hashes for DPValues attached to the terminator, which isn't
+ // considered in the loop above.
+ for (const DPValue &DPV :
+ OrigPreheader->getTerminator()->getDbgValueRange())
+ DbgIntrinsics.insert(makeHash(&DPV));
+
// Remember the local noalias scope declarations in the header. After the
// rotation, they must be duplicated and the scope must be cloned. This
// avoids unwanted interaction across iterations.
@@ -416,6 +575,29 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
if (auto *Decl = dyn_cast<NoAliasScopeDeclInst>(&I))
NoAliasDeclInstructions.push_back(Decl);
+ Module *M = OrigHeader->getModule();
+
+ // Track the next DPValue to clone. If we have a sequence where an
+ // instruction is hoisted instead of being cloned:
+ // DPValue blah
+ // %foo = add i32 0, 0
+ // DPValue xyzzy
+ // %bar = call i32 @foobar()
+ // where %foo is hoisted, then the DPValue "blah" will be seen twice, once
+ // attached to %foo, then when %foo is hoisted it will "fall down" onto the
+ // function call:
+ // DPValue blah
+ // DPValue xyzzy
+ // %bar = call i32 @foobar()
+ // causing it to appear attached to the call too.
+ //
+ // To avoid this, cloneDebugInfoFrom takes an optional "start cloning from
+ // here" position to account for this behaviour. We point it at any DPValues
+ // on the next instruction, here labelled xyzzy, before we hoist %foo.
+ // Later, we only clone DPValues from that position (xyzzy) onwards,
+ // which avoids cloning DPValue "blah" multiple times.
+ std::optional<DPValue::self_iterator> NextDbgInst = std::nullopt;
+
while (I != E) {
Instruction *Inst = &*I++;
@@ -428,7 +610,21 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() &&
!Inst->mayWriteToMemory() && !Inst->isTerminator() &&
!isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) {
+
+ if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat) {
+ auto DbgValueRange =
+ LoopEntryBranch->cloneDebugInfoFrom(Inst, NextDbgInst);
+ RemapDPValueRange(M, DbgValueRange, ValueMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+ // Erase anything we've seen before.
+ for (DPValue &DPV : make_early_inc_range(DbgValueRange))
+ if (DbgIntrinsics.count(makeHash(&DPV)))
+ DPV.eraseFromParent();
+ }
+
+ NextDbgInst = I->getDbgValueRange().begin();
Inst->moveBefore(LoopEntryBranch);
+
++NumInstrsHoisted;
continue;
}
@@ -439,6 +635,17 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
++NumInstrsDuplicated;
+ if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat) {
+ auto Range = C->cloneDebugInfoFrom(Inst, NextDbgInst);
+ RemapDPValueRange(M, Range, ValueMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+ NextDbgInst = std::nullopt;
+ // Erase anything we've seen before.
+ for (DPValue &DPV : make_early_inc_range(Range))
+ if (DbgIntrinsics.count(makeHash(&DPV)))
+ DPV.eraseFromParent();
+ }
+
// Eagerly remap the operands of the instruction.
RemapInstruction(C, ValueMap,
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
@@ -553,6 +760,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// OrigPreHeader's old terminator (the original branch into the loop), and
// remove the corresponding incoming values from the PHI nodes in OrigHeader.
LoopEntryBranch->eraseFromParent();
+ OrigPreheader->flushTerminatorDbgValues();
// Update MemorySSA before the rewrite call below changes the 1:1
// instruction:cloned_instruction_or_value mapping.
@@ -605,9 +813,14 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// to split as many edges.
BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator());
assert(PHBI->isConditional() && "Should be clone of BI condbr!");
- if (!isa<ConstantInt>(PHBI->getCondition()) ||
- PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) !=
- NewHeader) {
+ const Value *Cond = PHBI->getCondition();
+ const bool HasConditionalPreHeader =
+ !isa<ConstantInt>(Cond) ||
+ PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;
+
+ updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
+
+ if (HasConditionalPreHeader) {
// The conditional branch can't be folded, handle the general case.
// Split edges as necessary to preserve LoopSimplify form.
diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
index 3e604fdf2e11..07e622b1577f 100644
--- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp
+++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
@@ -429,8 +429,8 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
PN->setIncomingBlock(0, PN->getIncomingBlock(PreheaderIdx));
}
// Nuke all entries except the zero'th.
- for (unsigned i = 0, e = PN->getNumIncomingValues()-1; i != e; ++i)
- PN->removeIncomingValue(e-i, false);
+ PN->removeIncomingValueIf([](unsigned Idx) { return Idx != 0; },
+ /* DeletePHIIfEmpty */ false);
// Finally, add the newly constructed PHI node as the entry for the BEBlock.
PN->addIncoming(NewPN, BEBlock);
diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
index 511dd61308f9..ee6f7b35750a 100644
--- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
@@ -24,7 +24,6 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/ilist_iterator.h"
-#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -838,7 +837,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
DTUToUse ? nullptr : DT)) {
// Dest has been folded into Fold. Update our worklists accordingly.
std::replace(Latches.begin(), Latches.end(), Dest, Fold);
- llvm::erase_value(UnrolledLoopBlocks, Dest);
+ llvm::erase(UnrolledLoopBlocks, Dest);
}
}
}
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp
index 31b8cd34eb24..3c06a6e47a30 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp
@@ -19,7 +19,6 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
-#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 1e22eca30d2d..612f69970881 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -56,6 +56,17 @@ static cl::opt<bool> UnrollRuntimeOtherExitPredictable(
"unroll-runtime-other-exit-predictable", cl::init(false), cl::Hidden,
cl::desc("Assume the non latch exit block to be predictable"));
+// Probability that the loop trip count is so small that after the prolog
+// we do not enter the unrolled loop at all.
+// It is unlikely that the loop trip count is smaller than the unroll factor;
+// other than that, the choice of constant is not tuned yet.
+static const uint32_t UnrolledLoopHeaderWeights[] = {1, 127};
+// Probability that the loop trip count is so small that we skip the unrolled
+// loop completely and immediately enter the epilogue loop.
+// It is unlikely that the loop trip count is smaller than the unroll factor;
+// other than that, the choice of constant is not tuned yet.
+static const uint32_t EpilogHeaderWeights[] = {1, 127};
+
/// Connect the unrolling prolog code to the original loop.
/// The unrolling prolog code contains code to execute the
/// 'extra' iterations if the run-time trip count modulo the
@@ -105,8 +116,8 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
// PrologLatch. When supporting multiple-exiting block loops, we can have
// two or more blocks that have the LatchExit as the target in the
// original loop.
- PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr",
- PrologExit->getFirstNonPHI());
+ PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr");
+ NewPN->insertBefore(PrologExit->getFirstNonPHIIt());
// Adding a value to the new PHI node from the original loop preheader.
// This is the value that skips all the prolog code.
if (L->contains(&PN)) {
@@ -169,7 +180,14 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
SplitBlockPredecessors(OriginalLoopLatchExit, Preds, ".unr-lcssa", DT, LI,
nullptr, PreserveLCSSA);
// Add the branch to the exit block (around the unrolled loop)
- B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*Latch->getTerminator())) {
+ // Assume loop is nearly always entered.
+ MDBuilder MDB(B.getContext());
+ BranchWeights = MDB.createBranchWeights(UnrolledLoopHeaderWeights);
+ }
+ B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader,
+ BranchWeights);
InsertPt->eraseFromParent();
if (DT) {
auto *NewDom = DT->findNearestCommonDominator(OriginalLoopLatchExit,
@@ -194,8 +212,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
BasicBlock *Exit, BasicBlock *PreHeader,
BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
ValueToValueMapTy &VMap, DominatorTree *DT,
- LoopInfo *LI, bool PreserveLCSSA,
- ScalarEvolution &SE) {
+ LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
+ unsigned Count) {
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Loop must have a latch");
BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -269,8 +287,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
for (PHINode &PN : Succ->phis()) {
// Add new PHI nodes to the loop exit block and update epilog
// PHIs with the new PHI values.
- PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr",
- NewExit->getFirstNonPHI());
+ PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr");
+ NewPN->insertBefore(NewExit->getFirstNonPHIIt());
// Adding a value to the new PHI node from the unrolling loop preheader.
NewPN->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader);
// Adding a value to the new PHI node from the unrolling loop latch.
@@ -292,7 +310,13 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, nullptr,
PreserveLCSSA);
// Add the branch to the exit block (around the unrolling loop)
- B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*Latch->getTerminator())) {
+ // Assume equal distribution in interval [0, Count).
+ MDBuilder MDB(B.getContext());
+ BranchWeights = MDB.createBranchWeights(1, Count - 1);
+ }
+ B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
InsertPt->eraseFromParent();
if (DT) {
auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit);
@@ -316,8 +340,9 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
const bool UnrollRemainder,
BasicBlock *InsertTop,
BasicBlock *InsertBot, BasicBlock *Preheader,
- std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks,
- ValueToValueMapTy &VMap, DominatorTree *DT, LoopInfo *LI) {
+ std::vector<BasicBlock *> &NewBlocks,
+ LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
+ DominatorTree *DT, LoopInfo *LI, unsigned Count) {
StringRef suffix = UseEpilogRemainder ? "epil" : "prol";
BasicBlock *Header = L->getHeader();
BasicBlock *Latch = L->getLoopLatch();
@@ -363,14 +388,34 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]);
BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator());
IRBuilder<> Builder(LatchBR);
- PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2,
- suffix + ".iter",
- FirstLoopBB->getFirstNonPHI());
+ PHINode *NewIdx =
+ PHINode::Create(NewIter->getType(), 2, suffix + ".iter");
+ NewIdx->insertBefore(FirstLoopBB->getFirstNonPHIIt());
auto *Zero = ConstantInt::get(NewIdx->getType(), 0);
auto *One = ConstantInt::get(NewIdx->getType(), 1);
- Value *IdxNext = Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next");
+ Value *IdxNext =
+ Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next");
Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp");
- Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*LatchBR)) {
+ uint32_t ExitWeight;
+ uint32_t BackEdgeWeight;
+ if (Count >= 3) {
+ // Note: We do not enter this loop for zero-remainders. The check
+ // is at the end of the loop. We assume equal distribution between
+ // possible remainders in [1, Count).
+ ExitWeight = 1;
+ BackEdgeWeight = (Count - 2) / 2;
+ } else {
+ // Unnecessary backedge, should never be taken. The conditional
+ // jump should be optimized away later.
+ ExitWeight = 1;
+ BackEdgeWeight = 0;
+ }
+ MDBuilder MDB(Builder.getContext());
+ BranchWeights = MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
+ }
+ Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
NewIdx->addIncoming(Zero, InsertTop);
NewIdx->addIncoming(IdxNext, NewBB);
LatchBR->eraseFromParent();
@@ -464,32 +509,6 @@ static bool canProfitablyUnrollMultiExitLoop(
// know of kinds of multiexit loops that would benefit from unrolling.
}
-// Assign the maximum possible trip count as the back edge weight for the
-// remainder loop if the original loop comes with a branch weight.
-static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop,
- Loop *RemainderLoop,
- uint64_t UnrollFactor) {
- uint64_t TrueWeight, FalseWeight;
- BranchInst *LatchBR =
- cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator());
- if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
- return;
- uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader()
- ? FalseWeight
- : TrueWeight;
- assert(UnrollFactor > 1);
- uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight;
- BasicBlock *Header = RemainderLoop->getHeader();
- BasicBlock *Latch = RemainderLoop->getLoopLatch();
- auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator());
- unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1);
- MDBuilder MDB(RemainderLatchBR->getContext());
- MDNode *WeightNode =
- HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight)
- : MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
- RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
-}
-
/// Calculate ModVal = (BECount + 1) % Count on the abstract integer domain
/// accounting for the possibility of unsigned overflow in the 2s complement
/// domain. Preconditions:
@@ -775,7 +794,13 @@ bool llvm::UnrollRuntimeLoopRemainder(
BasicBlock *RemainderLoop = UseEpilogRemainder ? NewExit : PrologPreHeader;
BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
// Branch to either remainder (extra iterations) loop or unrolling loop.
- B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*Latch->getTerminator())) {
+ // Assume loop is nearly always entered.
+ MDBuilder MDB(B.getContext());
+ BranchWeights = MDB.createBranchWeights(EpilogHeaderWeights);
+ }
+ B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
PreHeaderBR->eraseFromParent();
if (DT) {
if (UseEpilogRemainder)
@@ -804,12 +829,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
Loop *remainderLoop = CloneLoopBlocks(
L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
- NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI);
-
- // Assign the maximum possible trip count as the back edge weight for the
- // remainder loop if the original loop comes with a branch weight.
- if (remainderLoop && !UnrollRemainder)
- updateLatchBranchWeightsForRemainderLoop(L, remainderLoop, Count);
+ NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
// Insert the cloned blocks into the function.
F->splice(InsertBot->getIterator(), F, NewBlocks[0]->getIterator(), F->end());
@@ -893,9 +913,12 @@ bool llvm::UnrollRuntimeLoopRemainder(
// Rewrite the cloned instruction operands to use the values created when the
// clone is created.
for (BasicBlock *BB : NewBlocks) {
+ Module *M = BB->getModule();
for (Instruction &I : *BB) {
RemapInstruction(&I, VMap,
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+ RemapDPValueRange(M, I.getDbgValueRange(), VMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
}
}
@@ -903,7 +926,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
// Connect the epilog code to the original loop and update the
// PHI functions.
ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
- NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE);
+ NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count);
// Update counter in loop for unrolling.
// Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.
@@ -912,8 +935,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
IRBuilder<> B2(NewPreHeader->getTerminator());
Value *TestVal = B2.CreateSub(TripCount, ModVal, "unroll_iter");
BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator());
- PHINode *NewIdx = PHINode::Create(TestVal->getType(), 2, "niter",
- Header->getFirstNonPHI());
+ PHINode *NewIdx = PHINode::Create(TestVal->getType(), 2, "niter");
+ NewIdx->insertBefore(Header->getFirstNonPHIIt());
B2.SetInsertPoint(LatchBR);
auto *Zero = ConstantInt::get(NewIdx->getType(), 0);
auto *One = ConstantInt::get(NewIdx->getType(), 1);
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 7d6662c44f07..59485126b280 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -296,7 +296,7 @@ std::optional<MDNode *> llvm::makeFollowupLoopID(
StringRef AttrName = cast<MDString>(NameMD)->getString();
// Do not inherit excluded attributes.
- return !AttrName.startswith(InheritOptionsExceptPrefix);
+ return !AttrName.starts_with(InheritOptionsExceptPrefix);
};
if (InheritThisAttribute(Op))
@@ -556,12 +556,8 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
// Removes all incoming values from all other exiting blocks (including
// duplicate values from an exiting block).
// Nuke all entries except the zero'th entry which is the preheader entry.
- // NOTE! We need to remove Incoming Values in the reverse order as done
- // below, to keep the indices valid for deletion (removeIncomingValues
- // updates getNumIncomingValues and shifts all values down into the
- // operand being deleted).
- for (unsigned i = 0, e = P.getNumIncomingValues() - 1; i != e; ++i)
- P.removeIncomingValue(e - i, false);
+ P.removeIncomingValueIf([](unsigned Idx) { return Idx != 0; },
+ /* DeletePHIIfEmpty */ false);
assert((P.getNumIncomingValues() == 1 &&
P.getIncomingBlock(PredIndex) == Preheader) &&
@@ -608,6 +604,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
// Use a map to unique and a vector to guarantee deterministic ordering.
llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet;
llvm::SmallVector<DbgVariableIntrinsic *, 4> DeadDebugInst;
+ llvm::SmallVector<DPValue *, 4> DeadDPValues;
if (ExitBlock) {
// Given LCSSA form is satisfied, we should not have users of instructions
@@ -632,6 +629,24 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
"Unexpected user in reachable block");
U.set(Poison);
}
+
+ // RemoveDIs: do the same as below for DPValues.
+ if (Block->IsNewDbgInfoFormat) {
+ for (DPValue &DPV :
+ llvm::make_early_inc_range(I.getDbgValueRange())) {
+ DebugVariable Key(DPV.getVariable(), DPV.getExpression(),
+ DPV.getDebugLoc().get());
+ if (!DeadDebugSet.insert(Key).second)
+ continue;
+ // Unlinks the DPV from its container, for later insertion.
+ DPV.removeFromParent();
+ DeadDPValues.push_back(&DPV);
+ }
+ }
+
+ // For one of each variable encountered, preserve a debug intrinsic (set
+ // to Poison) and transfer it to the loop exit. This terminates any
+ // variable locations that were set during the loop.
auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I);
if (!DVI)
continue;
@@ -646,12 +661,22 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
// be be replaced with undef. Loop invariant values will still be available.
// Move dbg.values out the loop so that earlier location ranges are still
// terminated and loop invariant assignments are preserved.
- Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI();
- assert(InsertDbgValueBefore &&
+ DIBuilder DIB(*ExitBlock->getModule());
+ BasicBlock::iterator InsertDbgValueBefore =
+ ExitBlock->getFirstInsertionPt();
+ assert(InsertDbgValueBefore != ExitBlock->end() &&
"There should be a non-PHI instruction in exit block, else these "
"instructions will have no parent.");
+
for (auto *DVI : DeadDebugInst)
- DVI->moveBefore(InsertDbgValueBefore);
+ DVI->moveBefore(*ExitBlock, InsertDbgValueBefore);
+
+ // Due to the "head" bit in BasicBlock::iterator, we're going to insert
+ // each DPValue right at the start of the block, whereas dbg.values would be
+ // repeatedly inserted before the first instruction. To replicate this
+ // behaviour, do it backwards.
+ for (DPValue *DPV : llvm::reverse(DeadDPValues))
+ ExitBlock->insertDPValueBefore(DPV, InsertDbgValueBefore);
}
// Remove the block from the reference counting scheme, so that we can
@@ -937,8 +962,8 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
}
}
-Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal,
- RecurKind RK, Value *Left, Value *Right) {
+Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal,
+ RecurKind RK, Value *Left, Value *Right) {
if (auto VTy = dyn_cast<VectorType>(Left->getType()))
StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal);
Value *Cmp =
@@ -1028,14 +1053,12 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
}
-Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder,
- const TargetTransformInfo *TTI,
- Value *Src,
- const RecurrenceDescriptor &Desc,
- PHINode *OrigPhi) {
- assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind(
- Desc.getRecurrenceKind()) &&
- "Unexpected reduction kind");
+Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src,
+ const RecurrenceDescriptor &Desc,
+ PHINode *OrigPhi) {
+ assert(
+ RecurrenceDescriptor::isAnyOfRecurrenceKind(Desc.getRecurrenceKind()) &&
+ "Unexpected reduction kind");
Value *InitVal = Desc.getRecurrenceStartValue();
Value *NewVal = nullptr;
@@ -1068,9 +1091,8 @@ Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder,
return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
}
-Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
- const TargetTransformInfo *TTI,
- Value *Src, RecurKind RdxKind) {
+Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
+ RecurKind RdxKind) {
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
switch (RdxKind) {
case RecurKind::Add:
@@ -1111,7 +1133,6 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
}
Value *llvm::createTargetReduction(IRBuilderBase &B,
- const TargetTransformInfo *TTI,
const RecurrenceDescriptor &Desc, Value *Src,
PHINode *OrigPhi) {
// TODO: Support in-order reductions based on the recurrence descriptor.
@@ -1121,10 +1142,10 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
B.setFastMathFlags(Desc.getFastMathFlags());
RecurKind RK = Desc.getRecurrenceKind();
- if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK))
- return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi);
+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+ return createAnyOfTargetReduction(B, Src, Desc, OrigPhi);
- return createSimpleTargetReduction(B, TTI, Src, RK);
+ return createSimpleTargetReduction(B, Src, RK);
}
Value *llvm::createOrderedReduction(IRBuilderBase &B,
@@ -1453,7 +1474,7 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
// Note that we must not perform expansions until after
// we query *all* the costs, because if we perform temporary expansion
// inbetween, one that we might not intend to keep, said expansion
- // *may* affect cost calculation of the the next SCEV's we'll query,
+ // *may* affect cost calculation of the next SCEV's we'll query,
// and next SCEV may errneously get smaller cost.
// Collect all the candidate PHINodes to be rewritten.
@@ -1632,42 +1653,92 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
struct PointerBounds {
TrackingVH<Value> Start;
TrackingVH<Value> End;
+ Value *StrideToCheck;
};
/// Expand code for the lower and upper bound of the pointer group \p CG
/// in \p TheLoop. \return the values for the bounds.
static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
Loop *TheLoop, Instruction *Loc,
- SCEVExpander &Exp) {
+ SCEVExpander &Exp, bool HoistRuntimeChecks) {
LLVMContext &Ctx = Loc->getContext();
- Type *PtrArithTy = Type::getInt8PtrTy(Ctx, CG->AddressSpace);
+ Type *PtrArithTy = PointerType::get(Ctx, CG->AddressSpace);
Value *Start = nullptr, *End = nullptr;
LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
- Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
- End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+ const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr;
+
+ // If the Low and High values are themselves loop-variant, then we may want
+ // to expand the range to include those covered by the outer loop as well.
+ // There is a trade-off here with the advantage being that creating checks
+ // using the expanded range permits the runtime memory checks to be hoisted
+ // out of the outer loop. This reduces the cost of entering the inner loop,
+ // which can be significant for low trip counts. The disadvantage is that
+ // there is a chance we may now never enter the vectorized inner loop,
+ // whereas using a restricted range check could have allowed us to enter at
+ // least once. This is why the behaviour is not currently the default and is
+ // controlled by the parameter 'HoistRuntimeChecks'.
+ if (HoistRuntimeChecks && TheLoop->getParentLoop() &&
+ isa<SCEVAddRecExpr>(High) && isa<SCEVAddRecExpr>(Low)) {
+ auto *HighAR = cast<SCEVAddRecExpr>(High);
+ auto *LowAR = cast<SCEVAddRecExpr>(Low);
+ const Loop *OuterLoop = TheLoop->getParentLoop();
+ const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE());
+ if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) &&
+ HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
+ BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
+ const SCEV *OuterExitCount =
+ Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch);
+ if (!isa<SCEVCouldNotCompute>(OuterExitCount) &&
+ OuterExitCount->getType()->isIntegerTy()) {
+ const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration(
+ OuterExitCount, *Exp.getSE());
+ if (!isa<SCEVCouldNotCompute>(NewHigh)) {
+ LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
+ "outer loop in order to permit hoisting\n");
+ High = NewHigh;
+ Low = cast<SCEVAddRecExpr>(Low)->getStart();
+ // If there is a possibility that the stride is negative then we have
+ // to generate extra checks to ensure the stride is positive.
+ if (!Exp.getSE()->isKnownNonNegative(Recur)) {
+ Stride = Recur;
+ LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
+ "positive: "
+ << *Stride << '\n');
+ }
+ }
+ }
+ }
+ }
+
+ Start = Exp.expandCodeFor(Low, PtrArithTy, Loc);
+ End = Exp.expandCodeFor(High, PtrArithTy, Loc);
if (CG->NeedsFreeze) {
IRBuilder<> Builder(Loc);
Start = Builder.CreateFreeze(Start, Start->getName() + ".fr");
End = Builder.CreateFreeze(End, End->getName() + ".fr");
}
- LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n");
- return {Start, End};
+ Value *StrideVal =
+ Stride ? Exp.expandCodeFor(Stride, Stride->getType(), Loc) : nullptr;
+ LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n");
+ return {Start, End, StrideVal};
}
/// Turns a collection of checks into a collection of expanded upper and
/// lower bounds for both pointers in the check.
static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
- Instruction *Loc, SCEVExpander &Exp) {
+ Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) {
SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
// Here we're relying on the SCEV Expander's cache to only emit code for the
// same bounds once.
transform(PointerChecks, std::back_inserter(ChecksWithBounds),
[&](const RuntimePointerCheck &Check) {
- PointerBounds First = expandBounds(Check.first, L, Loc, Exp),
- Second = expandBounds(Check.second, L, Loc, Exp);
+ PointerBounds First = expandBounds(Check.first, L, Loc, Exp,
+ HoistRuntimeChecks),
+ Second = expandBounds(Check.second, L, Loc, Exp,
+ HoistRuntimeChecks);
return std::make_pair(First, Second);
});
@@ -1677,10 +1748,11 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
Value *llvm::addRuntimeChecks(
Instruction *Loc, Loop *TheLoop,
const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
- SCEVExpander &Exp) {
+ SCEVExpander &Exp, bool HoistRuntimeChecks) {
// TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
// TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
- auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp);
+ auto ExpandedChecks =
+ expandBounds(PointerChecks, TheLoop, Loc, Exp, HoistRuntimeChecks);
LLVMContext &Ctx = Loc->getContext();
IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx,
@@ -1693,21 +1765,13 @@ Value *llvm::addRuntimeChecks(
const PointerBounds &A = Check.first, &B = Check.second;
// Check if two pointers (A and B) conflict where conflict is computed as:
// start(A) <= end(B) && start(B) <= end(A)
- unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
- unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
- assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
- (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+ assert((A.Start->getType()->getPointerAddressSpace() ==
+ B.End->getType()->getPointerAddressSpace()) &&
+ (B.Start->getType()->getPointerAddressSpace() ==
+ A.End->getType()->getPointerAddressSpace()) &&
"Trying to bounds check pointers with different address spaces");
- Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
- Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
-
- Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
- Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
- Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
- Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
-
// [A|B].Start points to the first accessed byte under base [A|B].
// [A|B].End points to the last accessed byte, plus one.
// There is no conflict when the intervals are disjoint:
@@ -1716,9 +1780,21 @@ Value *llvm::addRuntimeChecks(
// bound0 = (B.Start < A.End)
// bound1 = (A.Start < B.End)
// IsConflict = bound0 & bound1
- Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
- Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+ Value *Cmp0 = ChkBuilder.CreateICmpULT(A.Start, B.End, "bound0");
+ Value *Cmp1 = ChkBuilder.CreateICmpULT(B.Start, A.End, "bound1");
Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+ if (A.StrideToCheck) {
+ Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
+ A.StrideToCheck, ConstantInt::get(A.StrideToCheck->getType(), 0),
+ "stride.check");
+ IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride);
+ }
+ if (B.StrideToCheck) {
+ Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
+ B.StrideToCheck, ConstantInt::get(B.StrideToCheck->getType(), 0),
+ "stride.check");
+ IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride);
+ }
if (MemoryRuntimeCheck) {
IsConflict =
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
@@ -1740,23 +1816,31 @@ Value *llvm::addDiffRuntimeChecks(
// Our instructions might fold to a constant.
Value *MemoryRuntimeCheck = nullptr;
+ auto &SE = *Expander.getSE();
+ // Map to keep track of created compares, The key is the pair of operands for
+ // the compare, to allow detecting and re-using redundant compares.
+ DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
for (const auto &C : Checks) {
Type *Ty = C.SinkStart->getType();
// Compute VF * IC * AccessSize.
auto *VFTimesUFTimesSize =
ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
ConstantInt::get(Ty, IC * C.AccessSize));
- Value *Sink = Expander.expandCodeFor(C.SinkStart, Ty, Loc);
- Value *Src = Expander.expandCodeFor(C.SrcStart, Ty, Loc);
- if (C.NeedsFreeze) {
- IRBuilder<> Builder(Loc);
- Sink = Builder.CreateFreeze(Sink, Sink->getName() + ".fr");
- Src = Builder.CreateFreeze(Src, Src->getName() + ".fr");
- }
- Value *Diff = ChkBuilder.CreateSub(Sink, Src);
- Value *IsConflict =
- ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check");
+ Value *Diff = Expander.expandCodeFor(
+ SE.getMinusSCEV(C.SinkStart, C.SrcStart), Ty, Loc);
+
+ // Check if the same compare has already been created earlier. In that case,
+ // there is no need to check it again.
+ Value *IsConflict = SeenCompares.lookup({Diff, VFTimesUFTimesSize});
+ if (IsConflict)
+ continue;
+ IsConflict =
+ ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check");
+ SeenCompares.insert({{Diff, VFTimesUFTimesSize}, IsConflict});
+ if (C.NeedsFreeze)
+ IsConflict =
+ ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr");
if (MemoryRuntimeCheck) {
IsConflict =
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
index 78ebe75c121b..548b0f3c55f0 100644
--- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
@@ -145,8 +145,8 @@ void LoopVersioning::addPHINodes(
}
// If not create it.
if (!PN) {
- PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
- &PHIBlock->front());
+ PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver");
+ PN->insertBefore(PHIBlock->begin());
SmallVector<User*, 8> UsersToUpdate;
for (User *U : Inst->users())
if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
diff --git a/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp b/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp
index 195c274ff18e..4908535cba54 100644
--- a/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp
+++ b/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp
@@ -128,7 +128,7 @@ static bool runImpl(Module &M) {
// extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
LLVMContext &C = M.getContext();
- PointerType *VoidStar = Type::getInt8PtrTy(C);
+ PointerType *VoidStar = PointerType::getUnqual(C);
Type *AtExitFuncArgs[] = {VoidStar};
FunctionType *AtExitFuncTy =
FunctionType::get(Type::getVoidTy(C), AtExitFuncArgs,
@@ -140,6 +140,17 @@ static bool runImpl(Module &M) {
{PointerType::get(AtExitFuncTy, 0), VoidStar, VoidStar},
/*isVarArg=*/false));
+ // If __cxa_atexit is defined (e.g. in the case of LTO) and arg0 is not
+ // actually used (i.e. it's dummy/stub function as used in emscripten when
+ // the program never exits) we can simply return early and clear out
+ // @llvm.global_dtors.
+ if (auto F = dyn_cast<Function>(AtExit.getCallee())) {
+ if (F && F->hasExactDefinition() && F->getArg(0)->getNumUses() == 0) {
+ GV->eraseFromParent();
+ return true;
+ }
+ }
+
// Declare __dso_local.
Type *DsoHandleTy = Type::getInt8Ty(C);
Constant *DsoHandle = M.getOrInsertGlobal("__dso_handle", DsoHandleTy, [&] {
diff --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
index 906eb71fc2d9..c75de8687879 100644
--- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
+++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
@@ -64,17 +64,6 @@ void llvm::createMemCpyLoopKnownSize(
IRBuilder<> PLBuilder(PreLoopBB->getTerminator());
- // Cast the Src and Dst pointers to pointers to the loop operand type (if
- // needed).
- PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS);
- PointerType *DstOpType = PointerType::get(LoopOpType, DstAS);
- if (SrcAddr->getType() != SrcOpType) {
- SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType);
- }
- if (DstAddr->getType() != DstOpType) {
- DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType);
- }
-
Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize));
Align PartSrcAlign(commonAlignment(SrcAlign, LoopOpSize));
@@ -137,13 +126,9 @@ void llvm::createMemCpyLoopKnownSize(
uint64_t GepIndex = BytesCopied / OperandSize;
assert(GepIndex * OperandSize == BytesCopied &&
"Division should have no Remainder!");
- // Cast source to operand type and load
- PointerType *SrcPtrType = PointerType::get(OpTy, SrcAS);
- Value *CastedSrc = SrcAddr->getType() == SrcPtrType
- ? SrcAddr
- : RBuilder.CreateBitCast(SrcAddr, SrcPtrType);
+
Value *SrcGEP = RBuilder.CreateInBoundsGEP(
- OpTy, CastedSrc, ConstantInt::get(TypeOfCopyLen, GepIndex));
+ OpTy, SrcAddr, ConstantInt::get(TypeOfCopyLen, GepIndex));
LoadInst *Load =
RBuilder.CreateAlignedLoad(OpTy, SrcGEP, PartSrcAlign, SrcIsVolatile);
if (!CanOverlap) {
@@ -151,13 +136,8 @@ void llvm::createMemCpyLoopKnownSize(
Load->setMetadata(LLVMContext::MD_alias_scope,
MDNode::get(Ctx, NewScope));
}
- // Cast destination to operand type and store.
- PointerType *DstPtrType = PointerType::get(OpTy, DstAS);
- Value *CastedDst = DstAddr->getType() == DstPtrType
- ? DstAddr
- : RBuilder.CreateBitCast(DstAddr, DstPtrType);
Value *DstGEP = RBuilder.CreateInBoundsGEP(
- OpTy, CastedDst, ConstantInt::get(TypeOfCopyLen, GepIndex));
+ OpTy, DstAddr, ConstantInt::get(TypeOfCopyLen, GepIndex));
StoreInst *Store = RBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign,
DstIsVolatile);
if (!CanOverlap) {
@@ -206,15 +186,6 @@ void llvm::createMemCpyLoopUnknownSize(
IRBuilder<> PLBuilder(PreLoopBB->getTerminator());
- PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS);
- PointerType *DstOpType = PointerType::get(LoopOpType, DstAS);
- if (SrcAddr->getType() != SrcOpType) {
- SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType);
- }
- if (DstAddr->getType() != DstOpType) {
- DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType);
- }
-
// Calculate the loop trip count, and remaining bytes to copy after the loop.
Type *CopyLenType = CopyLen->getType();
IntegerType *ILengthType = dyn_cast<IntegerType>(CopyLenType);
@@ -305,13 +276,9 @@ void llvm::createMemCpyLoopUnknownSize(
ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index");
ResidualIndex->addIncoming(Zero, ResHeaderBB);
- Value *SrcAsResLoopOpType = ResBuilder.CreateBitCast(
- SrcAddr, PointerType::get(ResLoopOpType, SrcAS));
- Value *DstAsResLoopOpType = ResBuilder.CreateBitCast(
- DstAddr, PointerType::get(ResLoopOpType, DstAS));
Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex);
- Value *SrcGEP = ResBuilder.CreateInBoundsGEP(
- ResLoopOpType, SrcAsResLoopOpType, FullOffset);
+ Value *SrcGEP =
+ ResBuilder.CreateInBoundsGEP(ResLoopOpType, SrcAddr, FullOffset);
LoadInst *Load = ResBuilder.CreateAlignedLoad(ResLoopOpType, SrcGEP,
PartSrcAlign, SrcIsVolatile);
if (!CanOverlap) {
@@ -319,8 +286,8 @@ void llvm::createMemCpyLoopUnknownSize(
Load->setMetadata(LLVMContext::MD_alias_scope,
MDNode::get(Ctx, NewScope));
}
- Value *DstGEP = ResBuilder.CreateInBoundsGEP(
- ResLoopOpType, DstAsResLoopOpType, FullOffset);
+ Value *DstGEP =
+ ResBuilder.CreateInBoundsGEP(ResLoopOpType, DstAddr, FullOffset);
StoreInst *Store = ResBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign,
DstIsVolatile);
if (!CanOverlap) {
@@ -479,11 +446,6 @@ static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr,
IRBuilder<> Builder(OrigBB->getTerminator());
- // Cast pointer to the type of value getting stored
- unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
- DstAddr = Builder.CreateBitCast(DstAddr,
- PointerType::get(SetValue->getType(), dstAS));
-
Builder.CreateCondBr(
Builder.CreateICmpEQ(ConstantInt::get(TypeOfCopyLen, 0), CopyLen), NewBB,
LoopBB);
diff --git a/llvm/lib/Transforms/Utils/MetaRenamer.cpp b/llvm/lib/Transforms/Utils/MetaRenamer.cpp
index 44ac65f265f0..fd0112ae529c 100644
--- a/llvm/lib/Transforms/Utils/MetaRenamer.cpp
+++ b/llvm/lib/Transforms/Utils/MetaRenamer.cpp
@@ -151,7 +151,7 @@ void MetaRename(Module &M,
auto IsNameExcluded = [](StringRef &Name,
SmallVectorImpl<StringRef> &ExcludedPrefixes) {
return any_of(ExcludedPrefixes,
- [&Name](auto &Prefix) { return Name.startswith(Prefix); });
+ [&Name](auto &Prefix) { return Name.starts_with(Prefix); });
};
// Leave library functions alone because their presence or absence could
@@ -159,7 +159,7 @@ void MetaRename(Module &M,
auto ExcludeLibFuncs = [&](Function &F) {
LibFunc Tmp;
StringRef Name = F.getName();
- return Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) ||
+ return Name.starts_with("llvm.") || (!Name.empty() && Name[0] == 1) ||
GetTLI(F).getLibFunc(F, Tmp) ||
IsNameExcluded(Name, ExcludedFuncPrefixes);
};
@@ -177,7 +177,7 @@ void MetaRename(Module &M,
// Rename all aliases
for (GlobalAlias &GA : M.aliases()) {
StringRef Name = GA.getName();
- if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) ||
+ if (Name.starts_with("llvm.") || (!Name.empty() && Name[0] == 1) ||
IsNameExcluded(Name, ExcludedAliasesPrefixes))
continue;
@@ -187,7 +187,7 @@ void MetaRename(Module &M,
// Rename all global variables
for (GlobalVariable &GV : M.globals()) {
StringRef Name = GV.getName();
- if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) ||
+ if (Name.starts_with("llvm.") || (!Name.empty() && Name[0] == 1) ||
IsNameExcluded(Name, ExcludedGlobalsPrefixes))
continue;
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index 1e243ef74df7..7de0959ca57e 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -44,17 +44,17 @@ static void appendToGlobalArray(StringRef ArrayName, Module &M, Function *F,
}
GVCtor->eraseFromParent();
} else {
- EltTy = StructType::get(
- IRB.getInt32Ty(), PointerType::get(FnTy, F->getAddressSpace()),
- IRB.getInt8PtrTy());
+ EltTy = StructType::get(IRB.getInt32Ty(),
+ PointerType::get(FnTy, F->getAddressSpace()),
+ IRB.getPtrTy());
}
// Build a 3 field global_ctor entry. We don't take a comdat key.
Constant *CSVals[3];
CSVals[0] = IRB.getInt32(Priority);
CSVals[1] = F;
- CSVals[2] = Data ? ConstantExpr::getPointerCast(Data, IRB.getInt8PtrTy())
- : Constant::getNullValue(IRB.getInt8PtrTy());
+ CSVals[2] = Data ? ConstantExpr::getPointerCast(Data, IRB.getPtrTy())
+ : Constant::getNullValue(IRB.getPtrTy());
Constant *RuntimeCtorInit =
ConstantStruct::get(EltTy, ArrayRef(CSVals, EltTy->getNumElements()));
@@ -96,7 +96,7 @@ static void appendToUsedList(Module &M, StringRef Name, ArrayRef<GlobalValue *>
if (GV)
GV->eraseFromParent();
- Type *ArrayEltTy = llvm::Type::getInt8PtrTy(M.getContext());
+ Type *ArrayEltTy = llvm::PointerType::getUnqual(M.getContext());
for (auto *V : Values)
Init.insert(ConstantExpr::getPointerBitCastOrAddrSpaceCast(V, ArrayEltTy));
@@ -301,7 +301,7 @@ std::string llvm::getUniqueModuleId(Module *M) {
MD5 Md5;
bool ExportsSymbols = false;
auto AddGlobal = [&](GlobalValue &GV) {
- if (GV.isDeclaration() || GV.getName().startswith("llvm.") ||
+ if (GV.isDeclaration() || GV.getName().starts_with("llvm.") ||
!GV.hasExternalLinkage() || GV.hasComdat())
return;
ExportsSymbols = true;
@@ -346,7 +346,8 @@ void VFABI::setVectorVariantNames(CallInst *CI,
#ifndef NDEBUG
for (const std::string &VariantMapping : VariantMappings) {
LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n");
- std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *M);
+ std::optional<VFInfo> VI =
+ VFABI::tryDemangleForVFABI(VariantMapping, CI->getFunctionType());
assert(VI && "Cannot add an invalid VFABI name.");
assert(M->getNamedValue(VI->VectorName) &&
"Cannot add variant to attribute: "
diff --git a/llvm/lib/Transforms/Utils/MoveAutoInit.cpp b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp
index b0ca0b15c08e..a977ad87b79f 100644
--- a/llvm/lib/Transforms/Utils/MoveAutoInit.cpp
+++ b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp
@@ -14,7 +14,6 @@
#include "llvm/Transforms/Utils/MoveAutoInit.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/StringSet.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -50,7 +49,7 @@ static std::optional<MemoryLocation> writeToAlloca(const Instruction &I) {
else if (auto *SI = dyn_cast<StoreInst>(&I))
ML = MemoryLocation::get(SI);
else
- assert(false && "memory location set");
+ return std::nullopt;
if (isa<AllocaInst>(getUnderlyingObject(ML.Ptr)))
return ML;
@@ -202,7 +201,7 @@ static bool runMoveAutoInit(Function &F, DominatorTree &DT, MemorySSA &MSSA) {
// if two instructions are moved from the same BB to the same BB, we insert
// the second one in the front, then the first on top of it.
for (auto &Job : reverse(JobList)) {
- Job.first->moveBefore(&*Job.second->getFirstInsertionPt());
+ Job.first->moveBefore(*Job.second, Job.second->getFirstInsertionPt());
MSSAU.moveToPlace(MSSA.getMemoryAccess(Job.first), Job.first->getParent(),
MemorySSA::InsertionPlace::Beginning);
}
diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
index 1f16ba78bdb0..902977b08d15 100644
--- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -23,7 +23,6 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
@@ -33,12 +32,6 @@
using namespace llvm;
using namespace PatternMatch;
-INITIALIZE_PASS_BEGIN(PredicateInfoPrinterLegacyPass, "print-predicateinfo",
- "PredicateInfo Printer", false, false)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_END(PredicateInfoPrinterLegacyPass, "print-predicateinfo",
- "PredicateInfo Printer", false, false)
static cl::opt<bool> VerifyPredicateInfo(
"verify-predicateinfo", cl::init(false), cl::Hidden,
cl::desc("Verify PredicateInfo in legacy printer pass."));
@@ -835,20 +828,6 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
void PredicateInfo::verifyPredicateInfo() const {}
-char PredicateInfoPrinterLegacyPass::ID = 0;
-
-PredicateInfoPrinterLegacyPass::PredicateInfoPrinterLegacyPass()
- : FunctionPass(ID) {
- initializePredicateInfoPrinterLegacyPassPass(
- *PassRegistry::getPassRegistry());
-}
-
-void PredicateInfoPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.setPreservesAll();
- AU.addRequiredTransitive<DominatorTreeWrapperPass>();
- AU.addRequired<AssumptionCacheTracker>();
-}
-
// Replace ssa_copy calls created by PredicateInfo with their operand.
static void replaceCreatedSSACopys(PredicateInfo &PredInfo, Function &F) {
for (Instruction &Inst : llvm::make_early_inc_range(instructions(F))) {
@@ -862,18 +841,6 @@ static void replaceCreatedSSACopys(PredicateInfo &PredInfo, Function &F) {
}
}
-bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) {
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto PredInfo = std::make_unique<PredicateInfo>(F, DT, AC);
- PredInfo->print(dbgs());
- if (VerifyPredicateInfo)
- PredInfo->verifyPredicateInfo();
-
- replaceCreatedSSACopys(*PredInfo, F);
- return false;
-}
-
PreservedAnalyses PredicateInfoPrinterPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
index 2e5f40d39912..717b6d301c8c 100644
--- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
+++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
@@ -31,6 +31,7 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/DebugProgramInstruction.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
@@ -172,6 +173,7 @@ public:
struct AllocaInfo {
using DbgUserVec = SmallVector<DbgVariableIntrinsic *, 1>;
+ using DPUserVec = SmallVector<DPValue *, 1>;
SmallVector<BasicBlock *, 32> DefiningBlocks;
SmallVector<BasicBlock *, 32> UsingBlocks;
@@ -182,6 +184,7 @@ struct AllocaInfo {
/// Debug users of the alloca - does not include dbg.assign intrinsics.
DbgUserVec DbgUsers;
+ DPUserVec DPUsers;
/// Helper to update assignment tracking debug info.
AssignmentTrackingInfo AssignmentTracking;
@@ -192,6 +195,7 @@ struct AllocaInfo {
OnlyBlock = nullptr;
OnlyUsedInOneBlock = true;
DbgUsers.clear();
+ DPUsers.clear();
AssignmentTracking.clear();
}
@@ -225,7 +229,7 @@ struct AllocaInfo {
}
}
DbgUserVec AllDbgUsers;
- findDbgUsers(AllDbgUsers, AI);
+ findDbgUsers(AllDbgUsers, AI, &DPUsers);
std::copy_if(AllDbgUsers.begin(), AllDbgUsers.end(),
std::back_inserter(DbgUsers), [](DbgVariableIntrinsic *DII) {
return !isa<DbgAssignIntrinsic>(DII);
@@ -329,6 +333,7 @@ struct PromoteMem2Reg {
/// describes it, if any, so that we can convert it to a dbg.value
/// intrinsic if the alloca gets promoted.
SmallVector<AllocaInfo::DbgUserVec, 8> AllocaDbgUsers;
+ SmallVector<AllocaInfo::DPUserVec, 8> AllocaDPUsers;
/// For each alloca, keep an instance of a helper class that gives us an easy
/// way to update assignment tracking debug info if the alloca is promoted.
@@ -525,14 +530,18 @@ static bool rewriteSingleStoreAlloca(
// Record debuginfo for the store and remove the declaration's
// debuginfo.
- for (DbgVariableIntrinsic *DII : Info.DbgUsers) {
- if (DII->isAddressOfVariable()) {
- ConvertDebugDeclareToDebugValue(DII, Info.OnlyStore, DIB);
- DII->eraseFromParent();
- } else if (DII->getExpression()->startsWithDeref()) {
- DII->eraseFromParent();
+ auto ConvertDebugInfoForStore = [&](auto &Container) {
+ for (auto *DbgItem : Container) {
+ if (DbgItem->isAddressOfVariable()) {
+ ConvertDebugDeclareToDebugValue(DbgItem, Info.OnlyStore, DIB);
+ DbgItem->eraseFromParent();
+ } else if (DbgItem->getExpression()->startsWithDeref()) {
+ DbgItem->eraseFromParent();
+ }
}
- }
+ };
+ ConvertDebugInfoForStore(Info.DbgUsers);
+ ConvertDebugInfoForStore(Info.DPUsers);
// Remove dbg.assigns linked to the alloca as these are now redundant.
at::deleteAssignmentMarkers(AI);
@@ -629,12 +638,18 @@ static bool promoteSingleBlockAlloca(
StoreInst *SI = cast<StoreInst>(AI->user_back());
// Update assignment tracking info for the store we're going to delete.
Info.AssignmentTracking.updateForDeletedStore(SI, DIB, DbgAssignsToDelete);
+
// Record debuginfo for the store before removing it.
- for (DbgVariableIntrinsic *DII : Info.DbgUsers) {
- if (DII->isAddressOfVariable()) {
- ConvertDebugDeclareToDebugValue(DII, SI, DIB);
+ auto DbgUpdateForStore = [&](auto &Container) {
+ for (auto *DbgItem : Container) {
+ if (DbgItem->isAddressOfVariable()) {
+ ConvertDebugDeclareToDebugValue(DbgItem, SI, DIB);
+ }
}
- }
+ };
+ DbgUpdateForStore(Info.DbgUsers);
+ DbgUpdateForStore(Info.DPUsers);
+
SI->eraseFromParent();
LBI.deleteValue(SI);
}
@@ -644,9 +659,14 @@ static bool promoteSingleBlockAlloca(
AI->eraseFromParent();
// The alloca's debuginfo can be removed as well.
- for (DbgVariableIntrinsic *DII : Info.DbgUsers)
- if (DII->isAddressOfVariable() || DII->getExpression()->startsWithDeref())
- DII->eraseFromParent();
+ auto DbgUpdateForAlloca = [&](auto &Container) {
+ for (auto *DbgItem : Container)
+ if (DbgItem->isAddressOfVariable() ||
+ DbgItem->getExpression()->startsWithDeref())
+ DbgItem->eraseFromParent();
+ };
+ DbgUpdateForAlloca(Info.DbgUsers);
+ DbgUpdateForAlloca(Info.DPUsers);
++NumLocalPromoted;
return true;
@@ -657,6 +677,7 @@ void PromoteMem2Reg::run() {
AllocaDbgUsers.resize(Allocas.size());
AllocaATInfo.resize(Allocas.size());
+ AllocaDPUsers.resize(Allocas.size());
AllocaInfo Info;
LargeBlockInfo LBI;
@@ -720,6 +741,8 @@ void PromoteMem2Reg::run() {
AllocaDbgUsers[AllocaNum] = Info.DbgUsers;
if (!Info.AssignmentTracking.empty())
AllocaATInfo[AllocaNum] = Info.AssignmentTracking;
+ if (!Info.DPUsers.empty())
+ AllocaDPUsers[AllocaNum] = Info.DPUsers;
// Keep the reverse mapping of the 'Allocas' array for the rename pass.
AllocaLookup[Allocas[AllocaNum]] = AllocaNum;
@@ -795,11 +818,16 @@ void PromoteMem2Reg::run() {
}
// Remove alloca's dbg.declare intrinsics from the function.
- for (auto &DbgUsers : AllocaDbgUsers) {
- for (auto *DII : DbgUsers)
- if (DII->isAddressOfVariable() || DII->getExpression()->startsWithDeref())
- DII->eraseFromParent();
- }
+ auto RemoveDbgDeclares = [&](auto &Container) {
+ for (auto &DbgUsers : Container) {
+ for (auto *DbgItem : DbgUsers)
+ if (DbgItem->isAddressOfVariable() ||
+ DbgItem->getExpression()->startsWithDeref())
+ DbgItem->eraseFromParent();
+ }
+ };
+ RemoveDbgDeclares(AllocaDbgUsers);
+ RemoveDbgDeclares(AllocaDPUsers);
// Loop over all of the PHI nodes and see if there are any that we can get
// rid of because they merge all of the same incoming values. This can
@@ -981,8 +1009,8 @@ bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo,
// Create a PhiNode using the dereferenced type... and add the phi-node to the
// BasicBlock.
PN = PHINode::Create(Allocas[AllocaNo]->getAllocatedType(), getNumPreds(BB),
- Allocas[AllocaNo]->getName() + "." + Twine(Version++),
- &BB->front());
+ Allocas[AllocaNo]->getName() + "." + Twine(Version++));
+ PN->insertBefore(BB->begin());
++NumPHIInsert;
PhiToAllocaMap[PN] = AllocaNo;
return true;
@@ -1041,9 +1069,13 @@ NextIteration:
// The currently active variable for this block is now the PHI.
IncomingVals[AllocaNo] = APN;
AllocaATInfo[AllocaNo].updateForNewPhi(APN, DIB);
- for (DbgVariableIntrinsic *DII : AllocaDbgUsers[AllocaNo])
- if (DII->isAddressOfVariable())
- ConvertDebugDeclareToDebugValue(DII, APN, DIB);
+ auto ConvertDbgDeclares = [&](auto &Container) {
+ for (auto *DbgItem : Container)
+ if (DbgItem->isAddressOfVariable())
+ ConvertDebugDeclareToDebugValue(DbgItem, APN, DIB);
+ };
+ ConvertDbgDeclares(AllocaDbgUsers[AllocaNo]);
+ ConvertDbgDeclares(AllocaDPUsers[AllocaNo]);
// Get the next phi node.
++PNI;
@@ -1098,9 +1130,13 @@ NextIteration:
IncomingLocs[AllocaNo] = SI->getDebugLoc();
AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB,
&DbgAssignsToDelete);
- for (DbgVariableIntrinsic *DII : AllocaDbgUsers[ai->second])
- if (DII->isAddressOfVariable())
- ConvertDebugDeclareToDebugValue(DII, SI, DIB);
+ auto ConvertDbgDeclares = [&](auto &Container) {
+ for (auto *DbgItem : Container)
+ if (DbgItem->isAddressOfVariable())
+ ConvertDebugDeclareToDebugValue(DbgItem, SI, DIB);
+ };
+ ConvertDbgDeclares(AllocaDbgUsers[ai->second]);
+ ConvertDbgDeclares(AllocaDPUsers[ai->second]);
SI->eraseFromParent();
}
}
diff --git a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp
index c9ff94dc9744..ea628d7c3d7d 100644
--- a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp
+++ b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp
@@ -153,17 +153,12 @@ static void convertToRelLookupTable(GlobalVariable &LookupTable) {
Builder.SetInsertPoint(Load);
Function *LoadRelIntrinsic = llvm::Intrinsic::getDeclaration(
&M, Intrinsic::load_relative, {Index->getType()});
- Value *Base = Builder.CreateBitCast(RelLookupTable, Builder.getInt8PtrTy());
// Create a call to load.relative intrinsic that computes the target address
// by adding base address (lookup table address) and relative offset.
- Value *Result = Builder.CreateCall(LoadRelIntrinsic, {Base, Offset},
+ Value *Result = Builder.CreateCall(LoadRelIntrinsic, {RelLookupTable, Offset},
"reltable.intrinsic");
- // Create a bitcast instruction if necessary.
- if (Load->getType() != Builder.getInt8PtrTy())
- Result = Builder.CreateBitCast(Result, Load->getType(), "reltable.bitcast");
-
// Replace load instruction with the new generated instruction sequence.
Load->replaceAllUsesWith(Result);
// Remove Load and GEP instructions.
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index de3626a24212..ab95698abc43 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -107,9 +107,7 @@ bool SCCPSolver::tryToReplaceWithConstant(Value *V) {
static bool refineInstruction(SCCPSolver &Solver,
const SmallPtrSetImpl<Value *> &InsertedValues,
Instruction &Inst) {
- if (!isa<OverflowingBinaryOperator>(Inst))
- return false;
-
+ bool Changed = false;
auto GetRange = [&Solver, &InsertedValues](Value *Op) {
if (auto *Const = dyn_cast<ConstantInt>(Op))
return ConstantRange(Const->getValue());
@@ -120,23 +118,32 @@ static bool refineInstruction(SCCPSolver &Solver,
return getConstantRange(Solver.getLatticeValueFor(Op), Op->getType(),
/*UndefAllowed=*/false);
};
- auto RangeA = GetRange(Inst.getOperand(0));
- auto RangeB = GetRange(Inst.getOperand(1));
- bool Changed = false;
- if (!Inst.hasNoUnsignedWrap()) {
- auto NUWRange = ConstantRange::makeGuaranteedNoWrapRegion(
- Instruction::BinaryOps(Inst.getOpcode()), RangeB,
- OverflowingBinaryOperator::NoUnsignedWrap);
- if (NUWRange.contains(RangeA)) {
- Inst.setHasNoUnsignedWrap();
- Changed = true;
+
+ if (isa<OverflowingBinaryOperator>(Inst)) {
+ auto RangeA = GetRange(Inst.getOperand(0));
+ auto RangeB = GetRange(Inst.getOperand(1));
+ if (!Inst.hasNoUnsignedWrap()) {
+ auto NUWRange = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::BinaryOps(Inst.getOpcode()), RangeB,
+ OverflowingBinaryOperator::NoUnsignedWrap);
+ if (NUWRange.contains(RangeA)) {
+ Inst.setHasNoUnsignedWrap();
+ Changed = true;
+ }
}
- }
- if (!Inst.hasNoSignedWrap()) {
- auto NSWRange = ConstantRange::makeGuaranteedNoWrapRegion(
- Instruction::BinaryOps(Inst.getOpcode()), RangeB, OverflowingBinaryOperator::NoSignedWrap);
- if (NSWRange.contains(RangeA)) {
- Inst.setHasNoSignedWrap();
+ if (!Inst.hasNoSignedWrap()) {
+ auto NSWRange = ConstantRange::makeGuaranteedNoWrapRegion(
+ Instruction::BinaryOps(Inst.getOpcode()), RangeB,
+ OverflowingBinaryOperator::NoSignedWrap);
+ if (NSWRange.contains(RangeA)) {
+ Inst.setHasNoSignedWrap();
+ Changed = true;
+ }
+ }
+ } else if (isa<ZExtInst>(Inst) && !Inst.hasNonNeg()) {
+ auto Range = GetRange(Inst.getOperand(0));
+ if (Range.isAllNonNegative()) {
+ Inst.setNonNeg();
Changed = true;
}
}
@@ -171,6 +178,7 @@ static bool replaceSignedInst(SCCPSolver &Solver,
if (InsertedValues.count(Op0) || !isNonNegative(Op0))
return false;
NewInst = new ZExtInst(Op0, Inst.getType(), "", &Inst);
+ NewInst->setNonNeg();
break;
}
case Instruction::AShr: {
@@ -179,6 +187,7 @@ static bool replaceSignedInst(SCCPSolver &Solver,
if (InsertedValues.count(Op0) || !isNonNegative(Op0))
return false;
NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", &Inst);
+ NewInst->setIsExact(Inst.isExact());
break;
}
case Instruction::SDiv:
@@ -191,6 +200,8 @@ static bool replaceSignedInst(SCCPSolver &Solver,
auto NewOpcode = Inst.getOpcode() == Instruction::SDiv ? Instruction::UDiv
: Instruction::URem;
NewInst = BinaryOperator::Create(NewOpcode, Op0, Op1, "", &Inst);
+ if (Inst.getOpcode() == Instruction::SDiv)
+ NewInst->setIsExact(Inst.isExact());
break;
}
default:
@@ -1029,8 +1040,9 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
return;
}
- // Unwinding instructions successors are always executable.
- if (TI.isExceptionalTerminator()) {
+ // We cannot analyze special terminators, so consider all successors
+ // executable.
+ if (TI.isSpecialTerminator()) {
Succs.assign(TI.getNumSuccessors(), true);
return;
}
@@ -1098,13 +1110,6 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
return;
}
- // In case of callbr, we pessimistically assume that all successors are
- // feasible.
- if (isa<CallBrInst>(&TI)) {
- Succs.assign(TI.getNumSuccessors(), true);
- return;
- }
-
LLVM_DEBUG(dbgs() << "Unknown terminator instruction: " << TI << '\n');
llvm_unreachable("SCCP: Don't know how to handle this terminator!");
}
@@ -1231,10 +1236,12 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) {
if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) {
// Fold the constant as we build.
- Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL);
- markConstant(&I, C);
- } else if (I.getDestTy()->isIntegerTy() &&
- I.getSrcTy()->isIntOrIntVectorTy()) {
+ if (Constant *C =
+ ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL))
+ return (void)markConstant(&I, C);
+ }
+
+ if (I.getDestTy()->isIntegerTy() && I.getSrcTy()->isIntOrIntVectorTy()) {
auto &LV = getValueState(&I);
ConstantRange OpRange = getConstantRange(OpSt, I.getSrcTy());
@@ -1539,11 +1546,8 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
return (void)markOverdefined(&I);
}
- Constant *Ptr = Operands[0];
- auto Indices = ArrayRef(Operands.begin() + 1, Operands.end());
- Constant *C =
- ConstantExpr::getGetElementPtr(I.getSourceElementType(), Ptr, Indices);
- markConstant(&I, C);
+ if (Constant *C = ConstantFoldInstOperands(&I, Operands, DL))
+ markConstant(&I, C);
}
void SCCPInstVisitor::visitStoreInst(StoreInst &SI) {
diff --git a/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/llvm/lib/Transforms/Utils/SSAUpdater.cpp
index ebe9cb27f5ab..fc21fb552137 100644
--- a/llvm/lib/Transforms/Utils/SSAUpdater.cpp
+++ b/llvm/lib/Transforms/Utils/SSAUpdater.cpp
@@ -156,8 +156,9 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) {
}
// Ok, we have no way out, insert a new one now.
- PHINode *InsertedPHI = PHINode::Create(ProtoType, PredValues.size(),
- ProtoName, &BB->front());
+ PHINode *InsertedPHI =
+ PHINode::Create(ProtoType, PredValues.size(), ProtoName);
+ InsertedPHI->insertBefore(BB->begin());
// Fill in all the predecessors of the PHI.
for (const auto &PredValue : PredValues)
@@ -198,12 +199,18 @@ void SSAUpdater::RewriteUse(Use &U) {
void SSAUpdater::UpdateDebugValues(Instruction *I) {
SmallVector<DbgValueInst *, 4> DbgValues;
- llvm::findDbgValues(DbgValues, I);
+ SmallVector<DPValue *, 4> DPValues;
+ llvm::findDbgValues(DbgValues, I, &DPValues);
for (auto &DbgValue : DbgValues) {
if (DbgValue->getParent() == I->getParent())
continue;
UpdateDebugValue(I, DbgValue);
}
+ for (auto &DPV : DPValues) {
+ if (DPV->getParent() == I->getParent())
+ continue;
+ UpdateDebugValue(I, DPV);
+ }
}
void SSAUpdater::UpdateDebugValues(Instruction *I,
@@ -213,16 +220,31 @@ void SSAUpdater::UpdateDebugValues(Instruction *I,
}
}
+void SSAUpdater::UpdateDebugValues(Instruction *I,
+ SmallVectorImpl<DPValue *> &DPValues) {
+ for (auto &DPV : DPValues) {
+ UpdateDebugValue(I, DPV);
+ }
+}
+
void SSAUpdater::UpdateDebugValue(Instruction *I, DbgValueInst *DbgValue) {
BasicBlock *UserBB = DbgValue->getParent();
if (HasValueForBlock(UserBB)) {
Value *NewVal = GetValueAtEndOfBlock(UserBB);
DbgValue->replaceVariableLocationOp(I, NewVal);
- }
- else
+ } else
DbgValue->setKillLocation();
}
+void SSAUpdater::UpdateDebugValue(Instruction *I, DPValue *DPV) {
+ BasicBlock *UserBB = DPV->getParent();
+ if (HasValueForBlock(UserBB)) {
+ Value *NewVal = GetValueAtEndOfBlock(UserBB);
+ DPV->replaceVariableLocationOp(I, NewVal);
+ } else
+ DPV->setKillLocation();
+}
+
void SSAUpdater::RewriteUseAfterInsertions(Use &U) {
Instruction *User = cast<Instruction>(U.getUser());
@@ -295,8 +317,9 @@ public:
/// Reserve space for the operands but do not fill them in yet.
static Value *CreateEmptyPHI(BasicBlock *BB, unsigned NumPreds,
SSAUpdater *Updater) {
- PHINode *PHI = PHINode::Create(Updater->ProtoType, NumPreds,
- Updater->ProtoName, &BB->front());
+ PHINode *PHI =
+ PHINode::Create(Updater->ProtoType, NumPreds, Updater->ProtoName);
+ PHI->insertBefore(BB->begin());
return PHI;
}
diff --git a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp
index 31d62fbf0618..101b70d8def4 100644
--- a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp
+++ b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp
@@ -159,7 +159,7 @@ public:
/// Get the total flow from a given source node.
/// Returns a list of pairs (target node, amount of flow to the target).
- const std::vector<std::pair<uint64_t, int64_t>> getFlow(uint64_t Src) const {
+ std::vector<std::pair<uint64_t, int64_t>> getFlow(uint64_t Src) const {
std::vector<std::pair<uint64_t, int64_t>> Flow;
for (const auto &Edge : Edges[Src]) {
if (Edge.Flow > 0)
diff --git a/llvm/lib/Transforms/Utils/SanitizerStats.cpp b/llvm/lib/Transforms/Utils/SanitizerStats.cpp
index fd21ee4cc408..b80c5a6f9d68 100644
--- a/llvm/lib/Transforms/Utils/SanitizerStats.cpp
+++ b/llvm/lib/Transforms/Utils/SanitizerStats.cpp
@@ -21,7 +21,7 @@
using namespace llvm;
SanitizerStatReport::SanitizerStatReport(Module *M) : M(M) {
- StatTy = ArrayType::get(Type::getInt8PtrTy(M->getContext()), 2);
+ StatTy = ArrayType::get(PointerType::getUnqual(M->getContext()), 2);
EmptyModuleStatsTy = makeModuleStatsTy();
ModuleStatsGV = new GlobalVariable(*M, EmptyModuleStatsTy, false,
@@ -33,28 +33,28 @@ ArrayType *SanitizerStatReport::makeModuleStatsArrayTy() {
}
StructType *SanitizerStatReport::makeModuleStatsTy() {
- return StructType::get(M->getContext(), {Type::getInt8PtrTy(M->getContext()),
- Type::getInt32Ty(M->getContext()),
- makeModuleStatsArrayTy()});
+ return StructType::get(M->getContext(),
+ {PointerType::getUnqual(M->getContext()),
+ Type::getInt32Ty(M->getContext()),
+ makeModuleStatsArrayTy()});
}
void SanitizerStatReport::create(IRBuilder<> &B, SanitizerStatKind SK) {
Function *F = B.GetInsertBlock()->getParent();
Module *M = F->getParent();
- PointerType *Int8PtrTy = B.getInt8PtrTy();
+ PointerType *PtrTy = B.getPtrTy();
IntegerType *IntPtrTy = B.getIntPtrTy(M->getDataLayout());
- ArrayType *StatTy = ArrayType::get(Int8PtrTy, 2);
+ ArrayType *StatTy = ArrayType::get(PtrTy, 2);
Inits.push_back(ConstantArray::get(
StatTy,
- {Constant::getNullValue(Int8PtrTy),
+ {Constant::getNullValue(PtrTy),
ConstantExpr::getIntToPtr(
ConstantInt::get(IntPtrTy, uint64_t(SK) << (IntPtrTy->getBitWidth() -
kSanitizerStatKindBits)),
- Int8PtrTy)}));
+ PtrTy)}));
- FunctionType *StatReportTy =
- FunctionType::get(B.getVoidTy(), Int8PtrTy, false);
+ FunctionType *StatReportTy = FunctionType::get(B.getVoidTy(), PtrTy, false);
FunctionCallee StatReport =
M->getOrInsertFunction("__sanitizer_stat_report", StatReportTy);
@@ -64,7 +64,7 @@ void SanitizerStatReport::create(IRBuilder<> &B, SanitizerStatKind SK) {
ConstantInt::get(IntPtrTy, 0), ConstantInt::get(B.getInt32Ty(), 2),
ConstantInt::get(IntPtrTy, Inits.size() - 1),
});
- B.CreateCall(StatReport, ConstantExpr::getBitCast(InitAddr, Int8PtrTy));
+ B.CreateCall(StatReport, InitAddr);
}
void SanitizerStatReport::finish() {
@@ -73,7 +73,7 @@ void SanitizerStatReport::finish() {
return;
}
- PointerType *Int8PtrTy = Type::getInt8PtrTy(M->getContext());
+ PointerType *Int8PtrTy = PointerType::getUnqual(M->getContext());
IntegerType *Int32Ty = Type::getInt32Ty(M->getContext());
Type *VoidTy = Type::getVoidTy(M->getContext());
@@ -85,8 +85,7 @@ void SanitizerStatReport::finish() {
{Constant::getNullValue(Int8PtrTy),
ConstantInt::get(Int32Ty, Inits.size()),
ConstantArray::get(makeModuleStatsArrayTy(), Inits)}));
- ModuleStatsGV->replaceAllUsesWith(
- ConstantExpr::getBitCast(NewModuleStatsGV, ModuleStatsGV->getType()));
+ ModuleStatsGV->replaceAllUsesWith(NewModuleStatsGV);
ModuleStatsGV->eraseFromParent();
// Create a global constructor to register NewModuleStatsGV.
@@ -99,7 +98,7 @@ void SanitizerStatReport::finish() {
FunctionCallee StatInit =
M->getOrInsertFunction("__sanitizer_stat_init", StatInitTy);
- B.CreateCall(StatInit, ConstantExpr::getBitCast(NewModuleStatsGV, Int8PtrTy));
+ B.CreateCall(StatInit, NewModuleStatsGV);
B.CreateRetVoid();
appendToGlobalCtors(*M, F, 0);
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 20844271b943..cd3ac317cd23 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -170,11 +170,10 @@ Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) {
if (Op == Instruction::IntToPtr) {
auto *PtrTy = cast<PointerType>(Ty);
if (DL.isNonIntegralPointerType(PtrTy)) {
- auto *Int8PtrTy = Builder.getInt8PtrTy(PtrTy->getAddressSpace());
assert(DL.getTypeAllocSize(Builder.getInt8Ty()) == 1 &&
"alloc size of i8 must by 1 byte for the GEP to be correct");
return Builder.CreateGEP(
- Builder.getInt8Ty(), Constant::getNullValue(Int8PtrTy), V, "scevgep");
+ Builder.getInt8Ty(), Constant::getNullValue(PtrTy), V, "scevgep");
}
}
// Short-circuit unnecessary bitcasts.
@@ -313,11 +312,11 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
/// loop-invariant portions of expressions, after considering what
/// can be folded using target addressing modes.
///
-Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Type *Ty, Value *V) {
+Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Value *V) {
assert(!isa<Instruction>(V) ||
SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));
- Value *Idx = expandCodeForImpl(Offset, Ty);
+ Value *Idx = expand(Offset);
// Fold a GEP with constant operands.
if (Constant *CLHS = dyn_cast<Constant>(V))
@@ -339,7 +338,7 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Type *Ty, Value *V) {
if (IP->getOpcode() == Instruction::GetElementPtr &&
IP->getOperand(0) == V && IP->getOperand(1) == Idx &&
cast<GEPOperator>(&*IP)->getSourceElementType() ==
- Type::getInt8Ty(Ty->getContext()))
+ Builder.getInt8Ty())
return &*IP;
if (IP == BlockBegin) break;
}
@@ -457,8 +456,6 @@ public:
}
Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
- Type *Ty = SE.getEffectiveSCEVType(S->getType());
-
// Collect all the add operands in a loop, along with their associated loops.
// Iterate in reverse so that constants are emitted last, all else equal, and
// so that pointer operands are inserted first, which the code below relies on
@@ -498,20 +495,19 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
X = SE.getSCEV(U->getValue());
NewOps.push_back(X);
}
- Sum = expandAddToGEP(SE.getAddExpr(NewOps), Ty, Sum);
+ Sum = expandAddToGEP(SE.getAddExpr(NewOps), Sum);
} else if (Op->isNonConstantNegative()) {
// Instead of doing a negate and add, just do a subtract.
- Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty);
- Sum = InsertNoopCastOfTo(Sum, Ty);
+ Value *W = expand(SE.getNegativeSCEV(Op));
Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap,
/*IsSafeToHoist*/ true);
++I;
} else {
// A simple add.
- Value *W = expandCodeForImpl(Op, Ty);
- Sum = InsertNoopCastOfTo(Sum, Ty);
+ Value *W = expand(Op);
// Canonicalize a constant to the RHS.
- if (isa<Constant>(Sum)) std::swap(Sum, W);
+ if (isa<Constant>(Sum))
+ std::swap(Sum, W);
Sum = InsertBinop(Instruction::Add, Sum, W, S->getNoWrapFlags(),
/*IsSafeToHoist*/ true);
++I;
@@ -522,7 +518,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
}
Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
- Type *Ty = SE.getEffectiveSCEVType(S->getType());
+ Type *Ty = S->getType();
// Collect all the mul operands in a loop, along with their associated loops.
// Iterate in reverse so that constants are emitted last, all else equal.
@@ -541,7 +537,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
// Expand the calculation of X pow N in the following manner:
// Let N = P1 + P2 + ... + PK, where all P are powers of 2. Then:
// X pow N = (X pow P1) * (X pow P2) * ... * (X pow PK).
- const auto ExpandOpBinPowN = [this, &I, &OpsAndLoops, &Ty]() {
+ const auto ExpandOpBinPowN = [this, &I, &OpsAndLoops]() {
auto E = I;
// Calculate how many times the same operand from the same loop is included
// into this power.
@@ -559,7 +555,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
// Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them
// that are needed into the result.
- Value *P = expandCodeForImpl(I->second, Ty);
+ Value *P = expand(I->second);
Value *Result = nullptr;
if (Exponent & 1)
Result = P;
@@ -584,14 +580,12 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
Prod = ExpandOpBinPowN();
} else if (I->second->isAllOnesValue()) {
// Instead of doing a multiply by negative one, just do a negate.
- Prod = InsertNoopCastOfTo(Prod, Ty);
Prod = InsertBinop(Instruction::Sub, Constant::getNullValue(Ty), Prod,
SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
++I;
} else {
// A simple mul.
Value *W = ExpandOpBinPowN();
- Prod = InsertNoopCastOfTo(Prod, Ty);
// Canonicalize a constant to the RHS.
if (isa<Constant>(Prod)) std::swap(Prod, W);
const APInt *RHS;
@@ -616,18 +610,16 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
}
Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
- Type *Ty = SE.getEffectiveSCEVType(S->getType());
-
- Value *LHS = expandCodeForImpl(S->getLHS(), Ty);
+ Value *LHS = expand(S->getLHS());
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
const APInt &RHS = SC->getAPInt();
if (RHS.isPowerOf2())
return InsertBinop(Instruction::LShr, LHS,
- ConstantInt::get(Ty, RHS.logBase2()),
+ ConstantInt::get(SC->getType(), RHS.logBase2()),
SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
}
- Value *RHS = expandCodeForImpl(S->getRHS(), Ty);
+ Value *RHS = expand(S->getRHS());
return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
}
@@ -803,12 +795,11 @@ bool SCEVExpander::isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV,
/// Typically this is the LatchBlock terminator or IVIncInsertPos, but we may
/// need to materialize IV increments elsewhere to handle difficult situations.
Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L,
- Type *ExpandTy, Type *IntTy,
bool useSubtract) {
Value *IncV;
// If the PHI is a pointer, use a GEP, otherwise use an add or sub.
- if (ExpandTy->isPointerTy()) {
- IncV = expandAddToGEP(SE.getSCEV(StepV), IntTy, PN);
+ if (PN->getType()->isPointerTy()) {
+ IncV = expandAddToGEP(SE.getSCEV(StepV), PN);
} else {
IncV = useSubtract ?
Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") :
@@ -824,12 +815,11 @@ static bool canBeCheaplyTransformed(ScalarEvolution &SE,
const SCEVAddRecExpr *Requested,
bool &InvertStep) {
// We can't transform to match a pointer PHI.
- if (Phi->getType()->isPointerTy())
+ Type *PhiTy = Phi->getType();
+ Type *RequestedTy = Requested->getType();
+ if (PhiTy->isPointerTy() || RequestedTy->isPointerTy())
return false;
- Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType());
- Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType());
-
if (RequestedTy->getIntegerBitWidth() > PhiTy->getIntegerBitWidth())
return false;
@@ -886,12 +876,10 @@ static bool IsIncrementNUW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) {
/// values, and return the PHI.
PHINode *
SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
- const Loop *L,
- Type *ExpandTy,
- Type *IntTy,
- Type *&TruncTy,
+ const Loop *L, Type *&TruncTy,
bool &InvertStep) {
- assert((!IVIncInsertLoop||IVIncInsertPos) && "Uninitialized insert position");
+ assert((!IVIncInsertLoop || IVIncInsertPos) &&
+ "Uninitialized insert position");
// Reuse a previously-inserted PHI, if present.
BasicBlock *LatchBlock = L->getLoopLatch();
@@ -962,7 +950,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
// later.
AddRecPhiMatch = &PN;
IncV = TempIncV;
- TruncTy = SE.getEffectiveSCEVType(Normalized->getType());
+ TruncTy = Normalized->getType();
}
}
@@ -996,8 +984,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
assert(L->getLoopPreheader() &&
"Can't expand add recurrences without a loop preheader!");
Value *StartV =
- expandCodeForImpl(Normalized->getStart(), ExpandTy,
- L->getLoopPreheader()->getTerminator());
+ expand(Normalized->getStart(), L->getLoopPreheader()->getTerminator());
// StartV must have been be inserted into L's preheader to dominate the new
// phi.
@@ -1008,6 +995,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
// Expand code for the step value. Do this before creating the PHI so that PHI
// reuse code doesn't see an incomplete PHI.
const SCEV *Step = Normalized->getStepRecurrence(SE);
+ Type *ExpandTy = Normalized->getType();
// If the stride is negative, insert a sub instead of an add for the increment
// (unless it's a constant, because subtracts of constants are canonicalized
// to adds).
@@ -1015,8 +1003,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
if (useSubtract)
Step = SE.getNegativeSCEV(Step);
// Expand the step somewhere that dominates the loop header.
- Value *StepV = expandCodeForImpl(
- Step, IntTy, &*L->getHeader()->getFirstInsertionPt());
+ Value *StepV = expand(Step, L->getHeader()->getFirstInsertionPt());
// The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if
// we actually do emit an addition. It does not apply if we emit a
@@ -1047,7 +1034,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
Instruction *InsertPos = L == IVIncInsertLoop ?
IVIncInsertPos : Pred->getTerminator();
Builder.SetInsertPoint(InsertPos);
- Value *IncV = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract);
+ Value *IncV = expandIVInc(PN, StepV, L, useSubtract);
if (isa<OverflowingBinaryOperator>(IncV)) {
if (IncrementIsNUW)
@@ -1070,8 +1057,6 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
}
Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
- Type *STy = S->getType();
- Type *IntTy = SE.getEffectiveSCEVType(STy);
const Loop *L = S->getLoop();
// Determine a normalized form of this expression, which is the expression
@@ -1084,51 +1069,17 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
normalizeForPostIncUse(S, Loops, SE, /*CheckInvertible=*/false));
}
- // Strip off any non-loop-dominating component from the addrec start.
- const SCEV *Start = Normalized->getStart();
- const SCEV *PostLoopOffset = nullptr;
- if (!SE.properlyDominates(Start, L->getHeader())) {
- PostLoopOffset = Start;
- Start = SE.getConstant(Normalized->getType(), 0);
- Normalized = cast<SCEVAddRecExpr>(
- SE.getAddRecExpr(Start, Normalized->getStepRecurrence(SE),
- Normalized->getLoop(),
- Normalized->getNoWrapFlags(SCEV::FlagNW)));
- }
-
- // Strip off any non-loop-dominating component from the addrec step.
+ [[maybe_unused]] const SCEV *Start = Normalized->getStart();
const SCEV *Step = Normalized->getStepRecurrence(SE);
- const SCEV *PostLoopScale = nullptr;
- if (!SE.dominates(Step, L->getHeader())) {
- PostLoopScale = Step;
- Step = SE.getConstant(Normalized->getType(), 1);
- if (!Start->isZero()) {
- // The normalization below assumes that Start is constant zero, so if
- // it isn't re-associate Start to PostLoopOffset.
- assert(!PostLoopOffset && "Start not-null but PostLoopOffset set?");
- PostLoopOffset = Start;
- Start = SE.getConstant(Normalized->getType(), 0);
- }
- Normalized =
- cast<SCEVAddRecExpr>(SE.getAddRecExpr(
- Start, Step, Normalized->getLoop(),
- Normalized->getNoWrapFlags(SCEV::FlagNW)));
- }
-
- // Expand the core addrec. If we need post-loop scaling, force it to
- // expand to an integer type to avoid the need for additional casting.
- Type *ExpandTy = PostLoopScale ? IntTy : STy;
- // We can't use a pointer type for the addrec if the pointer type is
- // non-integral.
- Type *AddRecPHIExpandTy =
- DL.isNonIntegralPointerType(STy) ? Normalized->getType() : ExpandTy;
+ assert(SE.properlyDominates(Start, L->getHeader()) &&
+ "Start does not properly dominate loop header");
+ assert(SE.dominates(Step, L->getHeader()) && "Step not dominate loop header");
// In some cases, we decide to reuse an existing phi node but need to truncate
// it and/or invert the step.
Type *TruncTy = nullptr;
bool InvertStep = false;
- PHINode *PN = getAddRecExprPHILiterally(Normalized, L, AddRecPHIExpandTy,
- IntTy, TruncTy, InvertStep);
+ PHINode *PN = getAddRecExprPHILiterally(Normalized, L, TruncTy, InvertStep);
// Accommodate post-inc mode, if necessary.
Value *Result;
@@ -1167,59 +1118,29 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
// inserting an extra IV increment. StepV might fold into PostLoopOffset,
// but hopefully expandCodeFor handles that.
bool useSubtract =
- !ExpandTy->isPointerTy() && Step->isNonConstantNegative();
+ !S->getType()->isPointerTy() && Step->isNonConstantNegative();
if (useSubtract)
Step = SE.getNegativeSCEV(Step);
Value *StepV;
{
// Expand the step somewhere that dominates the loop header.
SCEVInsertPointGuard Guard(Builder, this);
- StepV = expandCodeForImpl(
- Step, IntTy, &*L->getHeader()->getFirstInsertionPt());
+ StepV = expand(Step, L->getHeader()->getFirstInsertionPt());
}
- Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract);
+ Result = expandIVInc(PN, StepV, L, useSubtract);
}
}
// We have decided to reuse an induction variable of a dominating loop. Apply
// truncation and/or inversion of the step.
if (TruncTy) {
- Type *ResTy = Result->getType();
- // Normalize the result type.
- if (ResTy != SE.getEffectiveSCEVType(ResTy))
- Result = InsertNoopCastOfTo(Result, SE.getEffectiveSCEVType(ResTy));
// Truncate the result.
if (TruncTy != Result->getType())
Result = Builder.CreateTrunc(Result, TruncTy);
// Invert the result.
if (InvertStep)
- Result = Builder.CreateSub(
- expandCodeForImpl(Normalized->getStart(), TruncTy), Result);
- }
-
- // Re-apply any non-loop-dominating scale.
- if (PostLoopScale) {
- assert(S->isAffine() && "Can't linearly scale non-affine recurrences.");
- Result = InsertNoopCastOfTo(Result, IntTy);
- Result = Builder.CreateMul(Result,
- expandCodeForImpl(PostLoopScale, IntTy));
- }
-
- // Re-apply any non-loop-dominating offset.
- if (PostLoopOffset) {
- if (isa<PointerType>(ExpandTy)) {
- if (Result->getType()->isIntegerTy()) {
- Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy);
- Result = expandAddToGEP(SE.getUnknown(Result), IntTy, Base);
- } else {
- Result = expandAddToGEP(PostLoopOffset, IntTy, Result);
- }
- } else {
- Result = InsertNoopCastOfTo(Result, IntTy);
- Result = Builder.CreateAdd(
- Result, expandCodeForImpl(PostLoopOffset, IntTy));
- }
+ Result = Builder.CreateSub(expand(Normalized->getStart()), Result);
}
return Result;
@@ -1260,8 +1181,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
S->getNoWrapFlags(SCEV::FlagNW)));
BasicBlock::iterator NewInsertPt =
findInsertPointAfter(cast<Instruction>(V), &*Builder.GetInsertPoint());
- V = expandCodeForImpl(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
- &*NewInsertPt);
+ V = expand(SE.getTruncateExpr(SE.getUnknown(V), Ty), NewInsertPt);
return V;
}
@@ -1269,7 +1189,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
if (!S->getStart()->isZero()) {
if (isa<PointerType>(S->getType())) {
Value *StartV = expand(SE.getPointerBase(S));
- return expandAddToGEP(SE.removePointerBase(S), Ty, StartV);
+ return expandAddToGEP(SE.removePointerBase(S), StartV);
}
SmallVector<const SCEV *, 4> NewOps(S->operands());
@@ -1292,8 +1212,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
// specified loop.
BasicBlock *Header = L->getHeader();
pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header);
- CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar",
- &Header->front());
+ CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar");
+ CanonicalIV->insertBefore(Header->begin());
rememberInstruction(CanonicalIV);
SmallSet<BasicBlock *, 4> PredSeen;
@@ -1361,34 +1281,25 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
}
Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) {
- Value *V =
- expandCodeForImpl(S->getOperand(), S->getOperand()->getType());
+ Value *V = expand(S->getOperand());
return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt,
GetOptimalInsertionPointForCastOf(V));
}
Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) {
- Type *Ty = SE.getEffectiveSCEVType(S->getType());
- Value *V = expandCodeForImpl(
- S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())
- );
- return Builder.CreateTrunc(V, Ty);
+ Value *V = expand(S->getOperand());
+ return Builder.CreateTrunc(V, S->getType());
}
Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) {
- Type *Ty = SE.getEffectiveSCEVType(S->getType());
- Value *V = expandCodeForImpl(
- S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())
- );
- return Builder.CreateZExt(V, Ty);
+ Value *V = expand(S->getOperand());
+ return Builder.CreateZExt(V, S->getType(), "",
+ SE.isKnownNonNegative(S->getOperand()));
}
Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
- Type *Ty = SE.getEffectiveSCEVType(S->getType());
- Value *V = expandCodeForImpl(
- S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())
- );
- return Builder.CreateSExt(V, Ty);
+ Value *V = expand(S->getOperand());
+ return Builder.CreateSExt(V, S->getType());
}
Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
@@ -1399,7 +1310,7 @@ Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
if (IsSequential)
LHS = Builder.CreateFreeze(LHS);
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
- Value *RHS = expandCodeForImpl(S->getOperand(i), Ty);
+ Value *RHS = expand(S->getOperand(i));
if (IsSequential && i != 0)
RHS = Builder.CreateFreeze(RHS);
Value *Sel;
@@ -1440,14 +1351,14 @@ Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
}
-Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
- Instruction *IP) {
+Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
+ BasicBlock::iterator IP) {
setInsertPoint(IP);
- Value *V = expandCodeForImpl(SH, Ty);
+ Value *V = expandCodeFor(SH, Ty);
return V;
}
-Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty) {
+Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
// Expand the code for this SCEV.
Value *V = expand(SH);
@@ -1459,8 +1370,64 @@ Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty) {
return V;
}
-Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S,
- const Instruction *InsertPt) {
+static bool
+canReuseInstruction(ScalarEvolution &SE, const SCEV *S, Instruction *I,
+ SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
+ // If the instruction cannot be poison, it's always safe to reuse.
+ if (programUndefinedIfPoison(I))
+ return true;
+
+ // Otherwise, it is possible that I is more poisonous that S. Collect the
+ // poison-contributors of S, and then check whether I has any additional
+ // poison-contributors. Poison that is contributed through poison-generating
+ // flags is handled by dropping those flags instead.
+ SmallPtrSet<const Value *, 8> PoisonVals;
+ SE.getPoisonGeneratingValues(PoisonVals, S);
+
+ SmallVector<Value *> Worklist;
+ SmallPtrSet<Value *, 8> Visited;
+ Worklist.push_back(I);
+ while (!Worklist.empty()) {
+ Value *V = Worklist.pop_back_val();
+ if (!Visited.insert(V).second)
+ continue;
+
+ // Avoid walking large instruction graphs.
+ if (Visited.size() > 16)
+ return false;
+
+ // Either the value can't be poison, or the S would also be poison if it
+ // is.
+ if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V))
+ continue;
+
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return false;
+
+ // FIXME: Ignore vscale, even though it technically could be poison. Do this
+ // because SCEV currently assumes it can't be poison. Remove this special
+ // case once we proper model when vscale can be poison.
+ if (auto *II = dyn_cast<IntrinsicInst>(I);
+ II && II->getIntrinsicID() == Intrinsic::vscale)
+ continue;
+
+ if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
+ return false;
+
+ // If the instruction can't create poison, we can recurse to its operands.
+ if (I->hasPoisonGeneratingFlagsOrMetadata())
+ DropPoisonGeneratingInsts.push_back(I);
+
+ for (Value *Op : I->operands())
+ Worklist.push_back(Op);
+ }
+ return true;
+}
+
+Value *SCEVExpander::FindValueInExprValueMap(
+ const SCEV *S, const Instruction *InsertPt,
+ SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
// If the expansion is not in CanonicalMode, and the SCEV contains any
// sub scAddRecExpr type SCEV, it is required to expand the SCEV literally.
if (!CanonicalMode && SE.containsAddRecurrence(S))
@@ -1470,20 +1437,24 @@ Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S,
if (isa<SCEVConstant>(S))
return nullptr;
- // Choose a Value from the set which dominates the InsertPt.
- // InsertPt should be inside the Value's parent loop so as not to break
- // the LCSSA form.
for (Value *V : SE.getSCEVValues(S)) {
Instruction *EntInst = dyn_cast<Instruction>(V);
if (!EntInst)
continue;
+ // Choose a Value from the set which dominates the InsertPt.
+ // InsertPt should be inside the Value's parent loop so as not to break
+ // the LCSSA form.
assert(EntInst->getFunction() == InsertPt->getFunction());
- if (S->getType() == V->getType() &&
- SE.DT.dominates(EntInst, InsertPt) &&
- (SE.LI.getLoopFor(EntInst->getParent()) == nullptr ||
- SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt)))
+ if (S->getType() != V->getType() || !SE.DT.dominates(EntInst, InsertPt) ||
+ !(SE.LI.getLoopFor(EntInst->getParent()) == nullptr ||
+ SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt)))
+ continue;
+
+ // Make sure reusing the instruction is poison-safe.
+ if (canReuseInstruction(SE, S, EntInst, DropPoisonGeneratingInsts))
return V;
+ DropPoisonGeneratingInsts.clear();
}
return nullptr;
}
@@ -1497,7 +1468,7 @@ Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S,
Value *SCEVExpander::expand(const SCEV *S) {
// Compute an insertion point for this SCEV object. Hoist the instructions
// as far out in the loop nest as possible.
- Instruction *InsertPt = &*Builder.GetInsertPoint();
+ BasicBlock::iterator InsertPt = Builder.GetInsertPoint();
// We can move insertion point only if there is no div or rem operations
// otherwise we are risky to move it over the check for zero denominator.
@@ -1521,24 +1492,25 @@ Value *SCEVExpander::expand(const SCEV *S) {
L = L->getParentLoop()) {
if (SE.isLoopInvariant(S, L)) {
if (!L) break;
- if (BasicBlock *Preheader = L->getLoopPreheader())
- InsertPt = Preheader->getTerminator();
- else
+ if (BasicBlock *Preheader = L->getLoopPreheader()) {
+ InsertPt = Preheader->getTerminator()->getIterator();
+ } else {
// LSR sets the insertion point for AddRec start/step values to the
// block start to simplify value reuse, even though it's an invalid
// position. SCEVExpander must correct for this in all cases.
- InsertPt = &*L->getHeader()->getFirstInsertionPt();
+ InsertPt = L->getHeader()->getFirstInsertionPt();
+ }
} else {
// If the SCEV is computable at this level, insert it into the header
// after the PHIs (and after any other instructions that we've inserted
// there) so that it is guaranteed to dominate any user inside the loop.
if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L))
- InsertPt = &*L->getHeader()->getFirstInsertionPt();
+ InsertPt = L->getHeader()->getFirstInsertionPt();
- while (InsertPt->getIterator() != Builder.GetInsertPoint() &&
- (isInsertedInstruction(InsertPt) ||
- isa<DbgInfoIntrinsic>(InsertPt))) {
- InsertPt = &*std::next(InsertPt->getIterator());
+ while (InsertPt != Builder.GetInsertPoint() &&
+ (isInsertedInstruction(&*InsertPt) ||
+ isa<DbgInfoIntrinsic>(&*InsertPt))) {
+ InsertPt = std::next(InsertPt);
}
break;
}
@@ -1546,26 +1518,40 @@ Value *SCEVExpander::expand(const SCEV *S) {
}
// Check to see if we already expanded this here.
- auto I = InsertedExpressions.find(std::make_pair(S, InsertPt));
+ auto I = InsertedExpressions.find(std::make_pair(S, &*InsertPt));
if (I != InsertedExpressions.end())
return I->second;
SCEVInsertPointGuard Guard(Builder, this);
- Builder.SetInsertPoint(InsertPt);
+ Builder.SetInsertPoint(InsertPt->getParent(), InsertPt);
// Expand the expression into instructions.
- Value *V = FindValueInExprValueMap(S, InsertPt);
+ SmallVector<Instruction *> DropPoisonGeneratingInsts;
+ Value *V = FindValueInExprValueMap(S, &*InsertPt, DropPoisonGeneratingInsts);
if (!V) {
V = visit(S);
V = fixupLCSSAFormFor(V);
} else {
- // If we're reusing an existing instruction, we are effectively CSEing two
- // copies of the instruction (with potentially different flags). As such,
- // we need to drop any poison generating flags unless we can prove that
- // said flags must be valid for all new users.
- if (auto *I = dyn_cast<Instruction>(V))
- if (I->hasPoisonGeneratingFlags() && !programUndefinedIfPoison(I))
- I->dropPoisonGeneratingFlags();
+ for (Instruction *I : DropPoisonGeneratingInsts) {
+ I->dropPoisonGeneratingFlagsAndMetadata();
+ // See if we can re-infer from first principles any of the flags we just
+ // dropped.
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I))
+ if (auto Flags = SE.getStrengthenedNoWrapFlagsFromBinOp(OBO)) {
+ auto *BO = cast<BinaryOperator>(I);
+ BO->setHasNoUnsignedWrap(
+ ScalarEvolution::maskFlags(*Flags, SCEV::FlagNUW) == SCEV::FlagNUW);
+ BO->setHasNoSignedWrap(
+ ScalarEvolution::maskFlags(*Flags, SCEV::FlagNSW) == SCEV::FlagNSW);
+ }
+ if (auto *NNI = dyn_cast<PossiblyNonNegInst>(I)) {
+ auto *Src = NNI->getOperand(0);
+ if (isImpliedByDomCondition(ICmpInst::ICMP_SGE, Src,
+ Constant::getNullValue(Src->getType()), I,
+ DL).value_or(false))
+ NNI->setNonNeg(true);
+ }
+ }
}
// Remember the expanded value for this SCEV at this location.
//
@@ -1573,7 +1559,7 @@ Value *SCEVExpander::expand(const SCEV *S) {
// the expression at this insertion point. If the mapped value happened to be
// a postinc expansion, it could be reused by a non-postinc user, but only if
// its insertion point was already at the head of the loop.
- InsertedExpressions[std::make_pair(S, InsertPt)] = V;
+ InsertedExpressions[std::make_pair(S, &*InsertPt)] = V;
return V;
}
@@ -1710,13 +1696,13 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT,
<< *IsomorphicInc << '\n');
Value *NewInc = OrigInc;
if (OrigInc->getType() != IsomorphicInc->getType()) {
- Instruction *IP = nullptr;
+ BasicBlock::iterator IP;
if (PHINode *PN = dyn_cast<PHINode>(OrigInc))
- IP = &*PN->getParent()->getFirstInsertionPt();
+ IP = PN->getParent()->getFirstInsertionPt();
else
- IP = OrigInc->getNextNode();
+ IP = OrigInc->getNextNonDebugInstruction()->getIterator();
- IRBuilder<> Builder(IP);
+ IRBuilder<> Builder(IP->getParent(), IP);
Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc());
NewInc = Builder.CreateTruncOrBitCast(
OrigInc, IsomorphicInc->getType(), IVName);
@@ -1734,7 +1720,8 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT,
++NumElim;
Value *NewIV = OrigPhiRef;
if (OrigPhiRef->getType() != Phi->getType()) {
- IRBuilder<> Builder(&*L->getHeader()->getFirstInsertionPt());
+ IRBuilder<> Builder(L->getHeader(),
+ L->getHeader()->getFirstInsertionPt());
Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
NewIV = Builder.CreateTruncOrBitCast(OrigPhiRef, Phi->getType(), IVName);
}
@@ -1744,9 +1731,9 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT,
return NumElim;
}
-Value *SCEVExpander::getRelatedExistingExpansion(const SCEV *S,
- const Instruction *At,
- Loop *L) {
+bool SCEVExpander::hasRelatedExistingExpansion(const SCEV *S,
+ const Instruction *At,
+ Loop *L) {
using namespace llvm::PatternMatch;
SmallVector<BasicBlock *, 4> ExitingBlocks;
@@ -1763,17 +1750,18 @@ Value *SCEVExpander::getRelatedExistingExpansion(const SCEV *S,
continue;
if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At))
- return LHS;
+ return true;
if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At))
- return RHS;
+ return true;
}
// Use expand's logic which is used for reusing a previous Value in
// ExprValueMap. Note that we don't currently model the cost of
// needing to drop poison generating flags on the instruction if we
// want to reuse it. We effectively assume that has zero cost.
- return FindValueInExprValueMap(S, At);
+ SmallVector<Instruction *> DropPoisonGeneratingInsts;
+ return FindValueInExprValueMap(S, At, DropPoisonGeneratingInsts) != nullptr;
}
template<typename T> static InstructionCost costAndCollectOperands(
@@ -1951,7 +1939,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
// If we can find an existing value for this scev available at the point "At"
// then consider the expression cheap.
- if (getRelatedExistingExpansion(S, &At, L))
+ if (hasRelatedExistingExpansion(S, &At, L))
return false; // Consider the expression to be free.
TargetTransformInfo::TargetCostKind CostKind =
@@ -1993,7 +1981,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
// At the beginning of this function we already tried to find existing
// value for plain 'S'. Now try to lookup 'S + 1' since it is common
// pattern involving division. This is just a simple search heuristic.
- if (getRelatedExistingExpansion(
+ if (hasRelatedExistingExpansion(
SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), &At, L))
return false; // Consider it to be free.
@@ -2045,10 +2033,8 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred,
Instruction *IP) {
- Value *Expr0 =
- expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP);
- Value *Expr1 =
- expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP);
+ Value *Expr0 = expand(Pred->getLHS(), IP);
+ Value *Expr1 = expand(Pred->getRHS(), IP);
Builder.SetInsertPoint(IP);
auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate());
@@ -2080,17 +2066,15 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
// Step >= 0, Start + |Step| * Backedge > Start
// and |Step| * Backedge doesn't unsigned overflow.
- IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
Builder.SetInsertPoint(Loc);
- Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc);
+ Value *TripCountVal = expand(ExitCount, Loc);
IntegerType *Ty =
IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy));
- Value *StepValue = expandCodeForImpl(Step, Ty, Loc);
- Value *NegStepValue =
- expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc);
- Value *StartValue = expandCodeForImpl(Start, ARTy, Loc);
+ Value *StepValue = expand(Step, Loc);
+ Value *NegStepValue = expand(SE.getNegativeSCEV(Step), Loc);
+ Value *StartValue = expand(Start, Loc);
ConstantInt *Zero =
ConstantInt::get(Loc->getContext(), APInt::getZero(DstBits));
@@ -2136,9 +2120,7 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
bool NeedPosCheck = !SE.isKnownNegative(Step);
bool NeedNegCheck = !SE.isKnownPositive(Step);
- if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARTy)) {
- StartValue = InsertNoopCastOfTo(
- StartValue, Builder.getInt8PtrTy(ARPtrTy->getAddressSpace()));
+ if (isa<PointerType>(ARTy)) {
Value *NegMulV = Builder.CreateNeg(MulV);
if (NeedPosCheck)
Add = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, MulV);
@@ -2171,7 +2153,7 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
// If the backedge taken count type is larger than the AR type,
// check that we don't drop any bits by truncating it. If we are
// dropping bits, then we have overflow (unless the step is zero).
- if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) {
+ if (SrcBits > DstBits) {
auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits);
auto *BackedgeCheck =
Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal,
@@ -2244,7 +2226,7 @@ Value *SCEVExpander::fixupLCSSAFormFor(Value *V) {
// instruction.
Type *ToTy;
if (DefI->getType()->isIntegerTy())
- ToTy = DefI->getType()->getPointerTo();
+ ToTy = PointerType::get(DefI->getContext(), 0);
else
ToTy = Type::getInt32Ty(DefI->getContext());
Instruction *User =
@@ -2306,12 +2288,6 @@ struct SCEVFindUnsafe {
}
}
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
- const SCEV *Step = AR->getStepRecurrence(SE);
- if (!AR->isAffine() && !SE.dominates(Step, AR->getLoop()->getHeader())) {
- IsUnsafe = true;
- return false;
- }
-
// For non-affine addrecs or in non-canonical mode we need a preheader
// to insert into.
if (!AR->getLoop()->getLoopPreheader() &&
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index d3a9a41aef15..c09cf9c2325c 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -271,7 +271,10 @@ class SimplifyCFGOpt {
bool tryToSimplifyUncondBranchWithICmpInIt(ICmpInst *ICI,
IRBuilder<> &Builder);
- bool HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly);
+ bool hoistCommonCodeFromSuccessors(BasicBlock *BB, bool EqTermsOnly);
+ bool hoistSuccIdenticalTerminatorToSwitchOrIf(
+ Instruction *TI, Instruction *I1,
+ SmallVectorImpl<Instruction *> &OtherSuccTIs);
bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB);
bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond,
BasicBlock *TrueBB, BasicBlock *FalseBB,
@@ -499,7 +502,7 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout &DL) {
return CI;
else
return cast<ConstantInt>(
- ConstantExpr::getIntegerCast(CI, PtrTy, /*isSigned=*/false));
+ ConstantFoldIntegerCast(CI, PtrTy, /*isSigned=*/false, DL));
}
return nullptr;
}
@@ -819,7 +822,7 @@ BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases(
static void
EliminateBlockCases(BasicBlock *BB,
std::vector<ValueEqualityComparisonCase> &Cases) {
- llvm::erase_value(Cases, BB);
+ llvm::erase(Cases, BB);
}
/// Return true if there are any keys in C1 that exist in C2 as well.
@@ -1098,12 +1101,13 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
// Note that there may be multiple predecessor blocks, so we cannot move
// bonus instructions to a predecessor block.
for (Instruction &BonusInst : *BB) {
- if (isa<DbgInfoIntrinsic>(BonusInst) || BonusInst.isTerminator())
+ if (BonusInst.isTerminator())
continue;
Instruction *NewBonusInst = BonusInst.clone();
- if (PTI->getDebugLoc() != NewBonusInst->getDebugLoc()) {
+ if (!isa<DbgInfoIntrinsic>(BonusInst) &&
+ PTI->getDebugLoc() != NewBonusInst->getDebugLoc()) {
// Unless the instruction has the same !dbg location as the original
// branch, drop it. When we fold the bonus instructions we want to make
// sure we reset their debug locations in order to avoid stepping on
@@ -1113,7 +1117,6 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
RemapInstruction(NewBonusInst, VMap,
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
- VMap[&BonusInst] = NewBonusInst;
// If we speculated an instruction, we need to drop any metadata that may
// result in undefined behavior, as the metadata might have been valid
@@ -1123,8 +1126,16 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
NewBonusInst->dropUBImplyingAttrsAndMetadata();
NewBonusInst->insertInto(PredBlock, PTI->getIterator());
+ auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
+ RemapDPValueRange(NewBonusInst->getModule(), Range, VMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+
+ if (isa<DbgInfoIntrinsic>(BonusInst))
+ continue;
+
NewBonusInst->takeName(&BonusInst);
BonusInst.setName(NewBonusInst->getName() + ".old");
+ VMap[&BonusInst] = NewBonusInst;
// Update (liveout) uses of bonus instructions,
// now that the bonus instruction has been cloned into predecessor.
@@ -1303,7 +1314,7 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
}
for (const std::pair<BasicBlock *, int /*Num*/> &NewSuccessor :
NewSuccessors) {
- for (auto I : seq(0, NewSuccessor.second)) {
+ for (auto I : seq(NewSuccessor.second)) {
(void)I;
AddPredecessorToBlock(NewSuccessor.first, Pred, BB);
}
@@ -1408,8 +1419,9 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(Instruction *TI,
}
// If we would need to insert a select that uses the value of this invoke
-// (comments in HoistThenElseCodeToIf explain why we would need to do this), we
-// can't hoist the invoke, as there is nowhere to put the select in this case.
+// (comments in hoistSuccIdenticalTerminatorToSwitchOrIf explain why we would
+// need to do this), we can't hoist the invoke, as there is nowhere to put the
+// select in this case.
static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2,
Instruction *I1, Instruction *I2) {
for (BasicBlock *Succ : successors(BB1)) {
@@ -1424,9 +1436,9 @@ static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2,
return true;
}
-// Get interesting characteristics of instructions that `HoistThenElseCodeToIf`
-// didn't hoist. They restrict what kind of instructions can be reordered
-// across.
+// Get interesting characteristics of instructions that
+// `hoistCommonCodeFromSuccessors` didn't hoist. They restrict what kind of
+// instructions can be reordered across.
enum SkipFlags {
SkipReadMem = 1,
SkipSideEffect = 2,
@@ -1484,7 +1496,7 @@ static bool isSafeToHoistInstr(Instruction *I, unsigned Flags) {
static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValueMayBeModified = false);
-/// Helper function for HoistThenElseCodeToIf. Return true if identical
+/// Helper function for hoistCommonCodeFromSuccessors. Return true if identical
/// instructions \p I1 and \p I2 can and should be hoisted.
static bool shouldHoistCommonInstructions(Instruction *I1, Instruction *I2,
const TargetTransformInfo &TTI) {
@@ -1515,62 +1527,51 @@ static bool shouldHoistCommonInstructions(Instruction *I1, Instruction *I2,
return true;
}
-/// Given a conditional branch that goes to BB1 and BB2, hoist any common code
-/// in the two blocks up into the branch block. The caller of this function
-/// guarantees that BI's block dominates BB1 and BB2. If EqTermsOnly is given,
-/// only perform hoisting in case both blocks only contain a terminator. In that
-/// case, only the original BI will be replaced and selects for PHIs are added.
-bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly) {
+/// Hoist any common code in the successor blocks up into the block. This
+/// function guarantees that BB dominates all successors. If EqTermsOnly is
+/// given, only perform hoisting in case both blocks only contain a terminator.
+/// In that case, only the original BI will be replaced and selects for PHIs are
+/// added.
+bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB,
+ bool EqTermsOnly) {
// This does very trivial matching, with limited scanning, to find identical
- // instructions in the two blocks. In particular, we don't want to get into
- // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As
+ // instructions in the two blocks. In particular, we don't want to get into
+ // O(N1*N2*...) situations here where Ni are the sizes of these successors. As
// such, we currently just scan for obviously identical instructions in an
// identical order, possibly separated by the same number of non-identical
// instructions.
- BasicBlock *BB1 = BI->getSuccessor(0); // The true destination.
- BasicBlock *BB2 = BI->getSuccessor(1); // The false destination
+ unsigned int SuccSize = succ_size(BB);
+ if (SuccSize < 2)
+ return false;
// If either of the blocks has it's address taken, then we can't do this fold,
// because the code we'd hoist would no longer run when we jump into the block
// by it's address.
- if (BB1->hasAddressTaken() || BB2->hasAddressTaken())
- return false;
+ for (auto *Succ : successors(BB))
+ if (Succ->hasAddressTaken() || !Succ->getSinglePredecessor())
+ return false;
- BasicBlock::iterator BB1_Itr = BB1->begin();
- BasicBlock::iterator BB2_Itr = BB2->begin();
+ auto *TI = BB->getTerminator();
- Instruction *I1 = &*BB1_Itr++, *I2 = &*BB2_Itr++;
- // Skip debug info if it is not identical.
- DbgInfoIntrinsic *DBI1 = dyn_cast<DbgInfoIntrinsic>(I1);
- DbgInfoIntrinsic *DBI2 = dyn_cast<DbgInfoIntrinsic>(I2);
- if (!DBI1 || !DBI2 || !DBI1->isIdenticalToWhenDefined(DBI2)) {
- while (isa<DbgInfoIntrinsic>(I1))
- I1 = &*BB1_Itr++;
- while (isa<DbgInfoIntrinsic>(I2))
- I2 = &*BB2_Itr++;
+ // The second of pair is a SkipFlags bitmask.
+ using SuccIterPair = std::pair<BasicBlock::iterator, unsigned>;
+ SmallVector<SuccIterPair, 8> SuccIterPairs;
+ for (auto *Succ : successors(BB)) {
+ BasicBlock::iterator SuccItr = Succ->begin();
+ if (isa<PHINode>(*SuccItr))
+ return false;
+ SuccIterPairs.push_back(SuccIterPair(SuccItr, 0));
}
- if (isa<PHINode>(I1))
- return false;
-
- BasicBlock *BIParent = BI->getParent();
-
- bool Changed = false;
-
- auto _ = make_scope_exit([&]() {
- if (Changed)
- ++NumHoistCommonCode;
- });
// Check if only hoisting terminators is allowed. This does not add new
// instructions to the hoist location.
if (EqTermsOnly) {
// Skip any debug intrinsics, as they are free to hoist.
- auto *I1NonDbg = &*skipDebugIntrinsics(I1->getIterator());
- auto *I2NonDbg = &*skipDebugIntrinsics(I2->getIterator());
- if (!I1NonDbg->isIdenticalToWhenDefined(I2NonDbg))
- return false;
- if (!I1NonDbg->isTerminator())
- return false;
+ for (auto &SuccIter : make_first_range(SuccIterPairs)) {
+ auto *INonDbg = &*skipDebugIntrinsics(SuccIter);
+ if (!INonDbg->isTerminator())
+ return false;
+ }
// Now we know that we only need to hoist debug intrinsics and the
// terminator. Let the loop below handle those 2 cases.
}
@@ -1579,153 +1580,235 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly) {
// many instructions we skip, serving as a compilation time control as well as
// preventing excessive increase of life ranges.
unsigned NumSkipped = 0;
+ // If we find an unreachable instruction at the beginning of a basic block, we
+ // can still hoist instructions from the rest of the basic blocks.
+ if (SuccIterPairs.size() > 2) {
+ erase_if(SuccIterPairs,
+ [](const auto &Pair) { return isa<UnreachableInst>(Pair.first); });
+ if (SuccIterPairs.size() < 2)
+ return false;
+ }
- // Record any skipped instuctions that may read memory, write memory or have
- // side effects, or have implicit control flow.
- unsigned SkipFlagsBB1 = 0;
- unsigned SkipFlagsBB2 = 0;
+ bool Changed = false;
for (;;) {
+ auto *SuccIterPairBegin = SuccIterPairs.begin();
+ auto &BB1ItrPair = *SuccIterPairBegin++;
+ auto OtherSuccIterPairRange =
+ iterator_range(SuccIterPairBegin, SuccIterPairs.end());
+ auto OtherSuccIterRange = make_first_range(OtherSuccIterPairRange);
+
+ Instruction *I1 = &*BB1ItrPair.first;
+ auto *BB1 = I1->getParent();
+
+ // Skip debug info if it is not identical.
+ bool AllDbgInstsAreIdentical = all_of(OtherSuccIterRange, [I1](auto &Iter) {
+ Instruction *I2 = &*Iter;
+ return I1->isIdenticalToWhenDefined(I2);
+ });
+ if (!AllDbgInstsAreIdentical) {
+ while (isa<DbgInfoIntrinsic>(I1))
+ I1 = &*++BB1ItrPair.first;
+ for (auto &SuccIter : OtherSuccIterRange) {
+ Instruction *I2 = &*SuccIter;
+ while (isa<DbgInfoIntrinsic>(I2))
+ I2 = &*++SuccIter;
+ }
+ }
+
+ bool AllInstsAreIdentical = true;
+ bool HasTerminator = I1->isTerminator();
+ for (auto &SuccIter : OtherSuccIterRange) {
+ Instruction *I2 = &*SuccIter;
+ HasTerminator |= I2->isTerminator();
+ if (AllInstsAreIdentical && !I1->isIdenticalToWhenDefined(I2))
+ AllInstsAreIdentical = false;
+ }
+
// If we are hoisting the terminator instruction, don't move one (making a
// broken BB), instead clone it, and remove BI.
- if (I1->isTerminator() || I2->isTerminator()) {
+ if (HasTerminator) {
+ // Even if BB, which contains only one unreachable instruction, is ignored
+ // at the beginning of the loop, we can hoist the terminator instruction.
// If any instructions remain in the block, we cannot hoist terminators.
- if (NumSkipped || !I1->isIdenticalToWhenDefined(I2))
+ if (NumSkipped || !AllInstsAreIdentical)
return Changed;
- goto HoistTerminator;
+ SmallVector<Instruction *, 8> Insts;
+ for (auto &SuccIter : OtherSuccIterRange)
+ Insts.push_back(&*SuccIter);
+ return hoistSuccIdenticalTerminatorToSwitchOrIf(TI, I1, Insts) || Changed;
}
- if (I1->isIdenticalToWhenDefined(I2) &&
- // Even if the instructions are identical, it may not be safe to hoist
- // them if we have skipped over instructions with side effects or their
- // operands weren't hoisted.
- isSafeToHoistInstr(I1, SkipFlagsBB1) &&
- isSafeToHoistInstr(I2, SkipFlagsBB2) &&
- shouldHoistCommonInstructions(I1, I2, TTI)) {
- if (isa<DbgInfoIntrinsic>(I1) || isa<DbgInfoIntrinsic>(I2)) {
- assert(isa<DbgInfoIntrinsic>(I1) && isa<DbgInfoIntrinsic>(I2));
+ if (AllInstsAreIdentical) {
+ unsigned SkipFlagsBB1 = BB1ItrPair.second;
+ AllInstsAreIdentical =
+ isSafeToHoistInstr(I1, SkipFlagsBB1) &&
+ all_of(OtherSuccIterPairRange, [=](const auto &Pair) {
+ Instruction *I2 = &*Pair.first;
+ unsigned SkipFlagsBB2 = Pair.second;
+ // Even if the instructions are identical, it may not
+ // be safe to hoist them if we have skipped over
+ // instructions with side effects or their operands
+ // weren't hoisted.
+ return isSafeToHoistInstr(I2, SkipFlagsBB2) &&
+ shouldHoistCommonInstructions(I1, I2, TTI);
+ });
+ }
+
+ if (AllInstsAreIdentical) {
+ BB1ItrPair.first++;
+ if (isa<DbgInfoIntrinsic>(I1)) {
// The debug location is an integral part of a debug info intrinsic
// and can't be separated from it or replaced. Instead of attempting
// to merge locations, simply hoist both copies of the intrinsic.
- BIParent->splice(BI->getIterator(), BB1, I1->getIterator());
- BIParent->splice(BI->getIterator(), BB2, I2->getIterator());
+ I1->moveBeforePreserving(TI);
+ for (auto &SuccIter : OtherSuccIterRange) {
+ auto *I2 = &*SuccIter++;
+ assert(isa<DbgInfoIntrinsic>(I2));
+ I2->moveBeforePreserving(TI);
+ }
} else {
// For a normal instruction, we just move one to right before the
// branch, then replace all uses of the other with the first. Finally,
// we remove the now redundant second instruction.
- BIParent->splice(BI->getIterator(), BB1, I1->getIterator());
- if (!I2->use_empty())
- I2->replaceAllUsesWith(I1);
- I1->andIRFlags(I2);
- combineMetadataForCSE(I1, I2, true);
-
- // I1 and I2 are being combined into a single instruction. Its debug
- // location is the merged locations of the original instructions.
- I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc());
-
- I2->eraseFromParent();
+ I1->moveBeforePreserving(TI);
+ BB->splice(TI->getIterator(), BB1, I1->getIterator());
+ for (auto &SuccIter : OtherSuccIterRange) {
+ Instruction *I2 = &*SuccIter++;
+ assert(I2 != I1);
+ if (!I2->use_empty())
+ I2->replaceAllUsesWith(I1);
+ I1->andIRFlags(I2);
+ combineMetadataForCSE(I1, I2, true);
+ // I1 and I2 are being combined into a single instruction. Its debug
+ // location is the merged locations of the original instructions.
+ I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc());
+ I2->eraseFromParent();
+ }
}
+ if (!Changed)
+ NumHoistCommonCode += SuccIterPairs.size();
Changed = true;
- ++NumHoistCommonInstrs;
+ NumHoistCommonInstrs += SuccIterPairs.size();
} else {
if (NumSkipped >= HoistCommonSkipLimit)
return Changed;
// We are about to skip over a pair of non-identical instructions. Record
// if any have characteristics that would prevent reordering instructions
// across them.
- SkipFlagsBB1 |= skippedInstrFlags(I1);
- SkipFlagsBB2 |= skippedInstrFlags(I2);
+ for (auto &SuccIterPair : SuccIterPairs) {
+ Instruction *I = &*SuccIterPair.first++;
+ SuccIterPair.second |= skippedInstrFlags(I);
+ }
++NumSkipped;
}
-
- I1 = &*BB1_Itr++;
- I2 = &*BB2_Itr++;
- // Skip debug info if it is not identical.
- DbgInfoIntrinsic *DBI1 = dyn_cast<DbgInfoIntrinsic>(I1);
- DbgInfoIntrinsic *DBI2 = dyn_cast<DbgInfoIntrinsic>(I2);
- if (!DBI1 || !DBI2 || !DBI1->isIdenticalToWhenDefined(DBI2)) {
- while (isa<DbgInfoIntrinsic>(I1))
- I1 = &*BB1_Itr++;
- while (isa<DbgInfoIntrinsic>(I2))
- I2 = &*BB2_Itr++;
- }
}
+}
- return Changed;
+bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf(
+ Instruction *TI, Instruction *I1,
+ SmallVectorImpl<Instruction *> &OtherSuccTIs) {
-HoistTerminator:
- // It may not be possible to hoist an invoke.
+ auto *BI = dyn_cast<BranchInst>(TI);
+
+ bool Changed = false;
+ BasicBlock *TIParent = TI->getParent();
+ BasicBlock *BB1 = I1->getParent();
+
+ // Use only for an if statement.
+ auto *I2 = *OtherSuccTIs.begin();
+ auto *BB2 = I2->getParent();
+ if (BI) {
+ assert(OtherSuccTIs.size() == 1);
+ assert(BI->getSuccessor(0) == I1->getParent());
+ assert(BI->getSuccessor(1) == I2->getParent());
+ }
+
+ // In the case of an if statement, we try to hoist an invoke.
// FIXME: Can we define a safety predicate for CallBr?
- if (isa<InvokeInst>(I1) && !isSafeToHoistInvoke(BB1, BB2, I1, I2))
- return Changed;
+ // FIXME: Test case llvm/test/Transforms/SimplifyCFG/2009-06-15-InvokeCrash.ll
+ // removed in 4c923b3b3fd0ac1edebf0603265ca3ba51724937 commit?
+ if (isa<InvokeInst>(I1) && (!BI || !isSafeToHoistInvoke(BB1, BB2, I1, I2)))
+ return false;
// TODO: callbr hoisting currently disabled pending further study.
if (isa<CallBrInst>(I1))
- return Changed;
+ return false;
for (BasicBlock *Succ : successors(BB1)) {
for (PHINode &PN : Succ->phis()) {
Value *BB1V = PN.getIncomingValueForBlock(BB1);
- Value *BB2V = PN.getIncomingValueForBlock(BB2);
- if (BB1V == BB2V)
- continue;
+ for (Instruction *OtherSuccTI : OtherSuccTIs) {
+ Value *BB2V = PN.getIncomingValueForBlock(OtherSuccTI->getParent());
+ if (BB1V == BB2V)
+ continue;
- // Check for passingValueIsAlwaysUndefined here because we would rather
- // eliminate undefined control flow then converting it to a select.
- if (passingValueIsAlwaysUndefined(BB1V, &PN) ||
- passingValueIsAlwaysUndefined(BB2V, &PN))
- return Changed;
+ // In the case of an if statement, check for
+ // passingValueIsAlwaysUndefined here because we would rather eliminate
+ // undefined control flow then converting it to a select.
+ if (!BI || passingValueIsAlwaysUndefined(BB1V, &PN) ||
+ passingValueIsAlwaysUndefined(BB2V, &PN))
+ return false;
+ }
}
}
// Okay, it is safe to hoist the terminator.
Instruction *NT = I1->clone();
- NT->insertInto(BIParent, BI->getIterator());
+ NT->insertInto(TIParent, TI->getIterator());
if (!NT->getType()->isVoidTy()) {
I1->replaceAllUsesWith(NT);
- I2->replaceAllUsesWith(NT);
+ for (Instruction *OtherSuccTI : OtherSuccTIs)
+ OtherSuccTI->replaceAllUsesWith(NT);
NT->takeName(I1);
}
Changed = true;
- ++NumHoistCommonInstrs;
+ NumHoistCommonInstrs += OtherSuccTIs.size() + 1;
// Ensure terminator gets a debug location, even an unknown one, in case
// it involves inlinable calls.
- NT->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc());
+ SmallVector<DILocation *, 4> Locs;
+ Locs.push_back(I1->getDebugLoc());
+ for (auto *OtherSuccTI : OtherSuccTIs)
+ Locs.push_back(OtherSuccTI->getDebugLoc());
+ NT->setDebugLoc(DILocation::getMergedLocations(Locs));
// PHIs created below will adopt NT's merged DebugLoc.
IRBuilder<NoFolder> Builder(NT);
- // Hoisting one of the terminators from our successor is a great thing.
- // Unfortunately, the successors of the if/else blocks may have PHI nodes in
- // them. If they do, all PHI entries for BB1/BB2 must agree for all PHI
- // nodes, so we insert select instruction to compute the final result.
- std::map<std::pair<Value *, Value *>, SelectInst *> InsertedSelects;
- for (BasicBlock *Succ : successors(BB1)) {
- for (PHINode &PN : Succ->phis()) {
- Value *BB1V = PN.getIncomingValueForBlock(BB1);
- Value *BB2V = PN.getIncomingValueForBlock(BB2);
- if (BB1V == BB2V)
- continue;
+ // In the case of an if statement, hoisting one of the terminators from our
+ // successor is a great thing. Unfortunately, the successors of the if/else
+ // blocks may have PHI nodes in them. If they do, all PHI entries for BB1/BB2
+ // must agree for all PHI nodes, so we insert select instruction to compute
+ // the final result.
+ if (BI) {
+ std::map<std::pair<Value *, Value *>, SelectInst *> InsertedSelects;
+ for (BasicBlock *Succ : successors(BB1)) {
+ for (PHINode &PN : Succ->phis()) {
+ Value *BB1V = PN.getIncomingValueForBlock(BB1);
+ Value *BB2V = PN.getIncomingValueForBlock(BB2);
+ if (BB1V == BB2V)
+ continue;
- // These values do not agree. Insert a select instruction before NT
- // that determines the right value.
- SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)];
- if (!SI) {
- // Propagate fast-math-flags from phi node to its replacement select.
- IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
- if (isa<FPMathOperator>(PN))
- Builder.setFastMathFlags(PN.getFastMathFlags());
+ // These values do not agree. Insert a select instruction before NT
+ // that determines the right value.
+ SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)];
+ if (!SI) {
+ // Propagate fast-math-flags from phi node to its replacement select.
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ if (isa<FPMathOperator>(PN))
+ Builder.setFastMathFlags(PN.getFastMathFlags());
- SI = cast<SelectInst>(
- Builder.CreateSelect(BI->getCondition(), BB1V, BB2V,
- BB1V->getName() + "." + BB2V->getName(), BI));
- }
+ SI = cast<SelectInst>(Builder.CreateSelect(
+ BI->getCondition(), BB1V, BB2V,
+ BB1V->getName() + "." + BB2V->getName(), BI));
+ }
- // Make the PHI node use the select for all incoming values for BB1/BB2
- for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i)
- if (PN.getIncomingBlock(i) == BB1 || PN.getIncomingBlock(i) == BB2)
- PN.setIncomingValue(i, SI);
+ // Make the PHI node use the select for all incoming values for BB1/BB2
+ for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i)
+ if (PN.getIncomingBlock(i) == BB1 || PN.getIncomingBlock(i) == BB2)
+ PN.setIncomingValue(i, SI);
+ }
}
}
@@ -1733,16 +1816,16 @@ HoistTerminator:
// Update any PHI nodes in our new successors.
for (BasicBlock *Succ : successors(BB1)) {
- AddPredecessorToBlock(Succ, BIParent, BB1);
+ AddPredecessorToBlock(Succ, TIParent, BB1);
if (DTU)
- Updates.push_back({DominatorTree::Insert, BIParent, Succ});
+ Updates.push_back({DominatorTree::Insert, TIParent, Succ});
}
if (DTU)
- for (BasicBlock *Succ : successors(BI))
- Updates.push_back({DominatorTree::Delete, BIParent, Succ});
+ for (BasicBlock *Succ : successors(TI))
+ Updates.push_back({DominatorTree::Delete, TIParent, Succ});
- EraseTerminatorAndDCECond(BI);
+ EraseTerminatorAndDCECond(TI);
if (DTU)
DTU->applyUpdates(Updates);
return Changed;
@@ -1808,10 +1891,19 @@ static bool canSinkInstructions(
}
const Instruction *I0 = Insts.front();
- for (auto *I : Insts)
+ for (auto *I : Insts) {
if (!I->isSameOperationAs(I0))
return false;
+ // swifterror pointers can only be used by a load or store; sinking a load
+ // or store would require introducing a select for the pointer operand,
+ // which isn't allowed for swifterror pointers.
+ if (isa<StoreInst>(I) && I->getOperand(1)->isSwiftError())
+ return false;
+ if (isa<LoadInst>(I) && I->getOperand(0)->isSwiftError())
+ return false;
+ }
+
// All instructions in Insts are known to be the same opcode. If they have a
// use, check that the only user is a PHI or in the same block as the
// instruction, because if a user is in the same block as an instruction we're
@@ -1952,8 +2044,9 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) {
// Create a new PHI in the successor block and populate it.
auto *Op = I0->getOperand(O);
assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!");
- auto *PN = PHINode::Create(Op->getType(), Insts.size(),
- Op->getName() + ".sink", &BBEnd->front());
+ auto *PN =
+ PHINode::Create(Op->getType(), Insts.size(), Op->getName() + ".sink");
+ PN->insertBefore(BBEnd->begin());
for (auto *I : Insts)
PN->addIncoming(I->getOperand(O), I->getParent());
NewOperands.push_back(PN);
@@ -1963,7 +2056,8 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) {
// and move it to the start of the successor block.
for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O)
I0->getOperandUse(O).set(NewOperands[O]);
- I0->moveBefore(&*BBEnd->getFirstInsertionPt());
+
+ I0->moveBefore(*BBEnd, BBEnd->getFirstInsertionPt());
// Update metadata and IR flags, and merge debug locations.
for (auto *I : Insts)
@@ -2765,8 +2859,8 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
Value *OrigV = PN.getIncomingValueForBlock(BB);
Value *ThenV = PN.getIncomingValueForBlock(ThenBB);
- // FIXME: Try to remove some of the duplication with HoistThenElseCodeToIf.
- // Skip PHIs which are trivial.
+ // FIXME: Try to remove some of the duplication with
+ // hoistCommonCodeFromSuccessors. Skip PHIs which are trivial.
if (ThenV == OrigV)
continue;
@@ -3009,7 +3103,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI,
// store %merge, %x.dest, !DIAssignID !2
// dbg.assign %merge, "x", ..., !2
for (auto *DAI : at::getAssignmentMarkers(SpeculatedStore)) {
- if (any_of(DAI->location_ops(), [&](Value *V) { return V == OrigV; }))
+ if (llvm::is_contained(DAI->location_ops(), OrigV))
DAI->replaceVariableLocationOp(OrigV, S);
}
}
@@ -3036,6 +3130,11 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI,
}
// Hoist the instructions.
+ // In "RemoveDIs" non-instr debug-info mode, drop DPValues attached to these
+ // instructions, in the same way that dbg.value intrinsics are dropped at the
+ // end of this block.
+ for (auto &It : make_range(ThenBB->begin(), ThenBB->end()))
+ It.dropDbgValues();
BB->splice(BI->getIterator(), ThenBB, ThenBB->begin(),
std::prev(ThenBB->end()));
@@ -3207,6 +3306,10 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
BasicBlock::iterator InsertPt = EdgeBB->getFirstInsertionPt();
DenseMap<Value *, Value *> TranslateMap; // Track translated values.
TranslateMap[Cond] = CB;
+
+ // RemoveDIs: track instructions that we optimise away while folding, so
+ // that we can copy DPValues from them later.
+ BasicBlock::iterator SrcDbgCursor = BB->begin();
for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) {
if (PHINode *PN = dyn_cast<PHINode>(BBI)) {
TranslateMap[PN] = PN->getIncomingValueForBlock(EdgeBB);
@@ -3241,6 +3344,15 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
TranslateMap[&*BBI] = N;
}
if (N) {
+ // Copy all debug-info attached to instructions from the last we
+ // successfully clone, up to this instruction (they might have been
+ // folded away).
+ for (; SrcDbgCursor != BBI; ++SrcDbgCursor)
+ N->cloneDebugInfoFrom(&*SrcDbgCursor);
+ SrcDbgCursor = std::next(BBI);
+ // Clone debug-info on this instruction too.
+ N->cloneDebugInfoFrom(&*BBI);
+
// Register the new instruction with the assumption cache if necessary.
if (auto *Assume = dyn_cast<AssumeInst>(N))
if (AC)
@@ -3248,6 +3360,10 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
}
}
+ for (; &*SrcDbgCursor != BI; ++SrcDbgCursor)
+ InsertPt->cloneDebugInfoFrom(&*SrcDbgCursor);
+ InsertPt->cloneDebugInfoFrom(BI);
+
BB->removePredecessor(EdgeBB);
BranchInst *EdgeBI = cast<BranchInst>(EdgeBB->getTerminator());
EdgeBI->setSuccessor(0, RealDest);
@@ -3652,22 +3768,22 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
ValueToValueMapTy VMap; // maps original values to cloned values
CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(BB, PredBlock, VMap);
+ Module *M = BB->getModule();
+
+ if (PredBlock->IsNewDbgInfoFormat) {
+ PredBlock->getTerminator()->cloneDebugInfoFrom(BB->getTerminator());
+ for (DPValue &DPV : PredBlock->getTerminator()->getDbgValueRange()) {
+ RemapDPValue(M, &DPV, VMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+ }
+ }
+
// Now that the Cond was cloned into the predecessor basic block,
// or/and the two conditions together.
Value *BICond = VMap[BI->getCondition()];
PBI->setCondition(
createLogicalOp(Builder, Opc, PBI->getCondition(), BICond, "or.cond"));
- // Copy any debug value intrinsics into the end of PredBlock.
- for (Instruction &I : *BB) {
- if (isa<DbgInfoIntrinsic>(I)) {
- Instruction *NewI = I.clone();
- RemapInstruction(NewI, VMap,
- RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
- NewI->insertBefore(PBI);
- }
- }
-
++NumFoldBranchToCommonDest;
return true;
}
@@ -3867,7 +3983,8 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB,
(!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != BB))
return V;
- PHI = PHINode::Create(V->getType(), 2, "simplifycfg.merge", &Succ->front());
+ PHI = PHINode::Create(V->getType(), 2, "simplifycfg.merge");
+ PHI->insertBefore(Succ->begin());
PHI->addIncoming(V, BB);
for (BasicBlock *PredBB : predecessors(Succ))
if (PredBB != BB)
@@ -3991,7 +4108,9 @@ static bool mergeConditionalStoreToAddress(
Value *QPHI = ensureValueAvailableInSuccessor(QStore->getValueOperand(),
QStore->getParent(), PPHI);
- IRBuilder<> QB(&*PostBB->getFirstInsertionPt());
+ BasicBlock::iterator PostBBFirst = PostBB->getFirstInsertionPt();
+ IRBuilder<> QB(PostBB, PostBBFirst);
+ QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc());
Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond);
Value *QPred = QStore->getParent() == QTB ? QCond : QB.CreateNot(QCond);
@@ -4002,9 +4121,11 @@ static bool mergeConditionalStoreToAddress(
QPred = QB.CreateNot(QPred);
Value *CombinedPred = QB.CreateOr(PPred, QPred);
- auto *T = SplitBlockAndInsertIfThen(CombinedPred, &*QB.GetInsertPoint(),
+ BasicBlock::iterator InsertPt = QB.GetInsertPoint();
+ auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
/*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
+
QB.SetInsertPoint(T);
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
SI->setAAMetadata(PStore->getAAMetadata().merge(QStore->getAAMetadata()));
@@ -4140,10 +4261,10 @@ static bool tryWidenCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
// 2) We can sink side effecting instructions into BI's fallthrough
// successor provided they doesn't contribute to computation of
// BI's condition.
- Value *CondWB, *WC;
- BasicBlock *IfTrueBB, *IfFalseBB;
- if (!parseWidenableBranch(PBI, CondWB, WC, IfTrueBB, IfFalseBB) ||
- IfTrueBB != BI->getParent() || !BI->getParent()->getSinglePredecessor())
+ BasicBlock *IfTrueBB = PBI->getSuccessor(0);
+ BasicBlock *IfFalseBB = PBI->getSuccessor(1);
+ if (!isWidenableBranch(PBI) || IfTrueBB != BI->getParent() ||
+ !BI->getParent()->getSinglePredecessor())
return false;
if (!IfFalseBB->phis().empty())
return false; // TODO
@@ -4256,6 +4377,21 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
if (PBI->getSuccessor(PBIOp) == BB)
return false;
+ // If predecessor's branch probability to BB is too low don't merge branches.
+ SmallVector<uint32_t, 2> PredWeights;
+ if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
+ extractBranchWeights(*PBI, PredWeights) &&
+ (static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]) != 0) {
+
+ BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
+ PredWeights[PBIOp],
+ static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]);
+
+ BranchProbability Likely = TTI.getPredictableBranchThreshold();
+ if (CommonDestProb >= Likely)
+ return false;
+ }
+
// Do not perform this transformation if it would require
// insertion of a large number of select instructions. For targets
// without predication/cmovs, this is a big pessimization.
@@ -5088,6 +5224,15 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
bool Changed = false;
+ // Ensure that any debug-info records that used to occur after the Unreachable
+ // are moved to in front of it -- otherwise they'll "dangle" at the end of
+ // the block.
+ BB->flushTerminatorDbgValues();
+
+ // Debug-info records on the unreachable inst itself should be deleted, as
+ // below we delete everything past the final executable instruction.
+ UI->dropDbgValues();
+
// If there are any instructions immediately before the unreachable that can
// be removed, do so.
while (UI->getIterator() != BB->begin()) {
@@ -5104,6 +5249,10 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
// block will be the unwind edges of Invoke/CatchSwitch/CleanupReturn,
// and we can therefore guarantee this block will be erased.
+ // If we're deleting this, we're deleting any subsequent dbg.values, so
+ // delete DPValue records of variable information.
+ BBI->dropDbgValues();
+
// Delete this instruction (any uses are guaranteed to be dead)
BBI->replaceAllUsesWith(PoisonValue::get(BBI->getType()));
BBI->eraseFromParent();
@@ -5667,7 +5816,7 @@ getCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest,
for (Instruction &I : CaseDest->instructionsWithoutDebug(false)) {
if (I.isTerminator()) {
// If the terminator is a simple branch, continue to the next block.
- if (I.getNumSuccessors() != 1 || I.isExceptionalTerminator())
+ if (I.getNumSuccessors() != 1 || I.isSpecialTerminator())
return false;
Pred = CaseDest;
CaseDest = I.getSuccessor(0);
@@ -5890,8 +6039,8 @@ static void removeSwitchAfterSelectFold(SwitchInst *SI, PHINode *PHI,
// Remove the switch.
- while (PHI->getBasicBlockIndex(SelectBB) >= 0)
- PHI->removeIncomingValue(SelectBB);
+ PHI->removeIncomingValueIf(
+ [&](unsigned Idx) { return PHI->getIncomingBlock(Idx) == SelectBB; });
PHI->addIncoming(SelectValue, SelectBB);
SmallPtrSet<BasicBlock *, 4> RemovedSuccessors;
@@ -6051,8 +6200,9 @@ SwitchLookupTable::SwitchLookupTable(
bool LinearMappingPossible = true;
APInt PrevVal;
APInt DistToPrev;
- // When linear map is monotonic, we can attach nsw.
- bool Wrapped = false;
+ // When linear map is monotonic and signed overflow doesn't happen on
+ // maximum index, we can attach nsw on Add and Mul.
+ bool NonMonotonic = false;
assert(TableSize >= 2 && "Should be a SingleValue table.");
// Check if there is the same distance between two consecutive values.
for (uint64_t I = 0; I < TableSize; ++I) {
@@ -6072,7 +6222,7 @@ SwitchLookupTable::SwitchLookupTable(
LinearMappingPossible = false;
break;
}
- Wrapped |=
+ NonMonotonic |=
Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
}
PrevVal = Val;
@@ -6080,7 +6230,10 @@ SwitchLookupTable::SwitchLookupTable(
if (LinearMappingPossible) {
LinearOffset = cast<ConstantInt>(TableContents[0]);
LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev);
- LinearMapValWrapped = Wrapped;
+ bool MayWrap = false;
+ APInt M = LinearMultiplier->getValue();
+ (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap);
+ LinearMapValWrapped = NonMonotonic || MayWrap;
Kind = LinearMapKind;
++NumLinearMaps;
return;
@@ -6503,9 +6656,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
// If the default destination is unreachable, or if the lookup table covers
// all values of the conditional variable, branch directly to the lookup table
// BB. Otherwise, check that the condition is within the case range.
- const bool DefaultIsReachable =
+ bool DefaultIsReachable =
!isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg());
- const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize);
// Create the BB that does the lookups.
Module &Mod = *CommonDest->getParent()->getParent();
@@ -6536,6 +6688,28 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
BranchInst *RangeCheckBranch = nullptr;
+ // Grow the table to cover all possible index values to avoid the range check.
+ // It will use the default result to fill in the table hole later, so make
+  // sure it exists.
+ if (UseSwitchConditionAsTableIndex && HasDefaultResults) {
+ ConstantRange CR = computeConstantRange(TableIndex, /* ForSigned */ false);
+ // Grow the table shouldn't have any size impact by checking
+ // WouldFitInRegister.
+ // TODO: Consider growing the table also when it doesn't fit in a register
+ // if no optsize is specified.
+ const uint64_t UpperBound = CR.getUpper().getLimitedValue();
+ if (!CR.isUpperWrapped() && all_of(ResultTypes, [&](const auto &KV) {
+ return SwitchLookupTable::WouldFitInRegister(
+ DL, UpperBound, KV.second /* ResultType */);
+ })) {
+ // The default branch is unreachable after we enlarge the lookup table.
+ // Adjust DefaultIsReachable to reuse code path.
+ TableSize = UpperBound;
+ DefaultIsReachable = false;
+ }
+ }
+
+ const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize);
if (!DefaultIsReachable || GeneratingCoveredLookupTable) {
Builder.CreateBr(LookupBB);
if (DTU)
@@ -6697,9 +6871,6 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
// This transform can be done speculatively because it is so cheap - it
// results in a single rotate operation being inserted.
- // FIXME: It's possible that optimizing a switch on powers of two might also
- // be beneficial - flag values are often powers of two and we could use a CLZ
- // as the key function.
// countTrailingZeros(0) returns 64. As Values is guaranteed to have more than
// one element and LLVM disallows duplicate cases, Shift is guaranteed to be
@@ -6744,6 +6915,80 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}
+/// Tries to transform switch of powers of two to reduce switch range.
+/// For example, switch like:
+/// switch (C) { case 1: case 2: case 64: case 128: }
+/// will be transformed to:
+/// switch (count_trailing_zeros(C)) { case 0: case 1: case 6: case 7: }
+///
+/// This transformation allows better lowering and could allow transforming into
+/// a lookup table.
+static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
+ const DataLayout &DL,
+ const TargetTransformInfo &TTI) {
+ Value *Condition = SI->getCondition();
+ LLVMContext &Context = SI->getContext();
+ auto *CondTy = cast<IntegerType>(Condition->getType());
+
+ if (CondTy->getIntegerBitWidth() > 64 ||
+ !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
+ return false;
+
+ const auto CttzIntrinsicCost = TTI.getIntrinsicInstrCost(
+ IntrinsicCostAttributes(Intrinsic::cttz, CondTy,
+ {Condition, ConstantInt::getTrue(Context)}),
+ TTI::TCK_SizeAndLatency);
+
+ if (CttzIntrinsicCost > TTI::TCC_Basic)
+ // Inserting intrinsic is too expensive.
+ return false;
+
+ // Only bother with this optimization if there are more than 3 switch cases.
+ // SDAG will only bother creating jump tables for 4 or more cases.
+ if (SI->getNumCases() < 4)
+ return false;
+
+ // We perform this optimization only for switches with
+ // unreachable default case.
+  // This assumption will save us from checking if `Condition` is a power of two.
+ if (!isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()))
+ return false;
+
+ // Check that switch cases are powers of two.
+ SmallVector<uint64_t, 4> Values;
+ for (const auto &Case : SI->cases()) {
+ uint64_t CaseValue = Case.getCaseValue()->getValue().getZExtValue();
+ if (llvm::has_single_bit(CaseValue))
+ Values.push_back(CaseValue);
+ else
+ return false;
+ }
+
+  // isSwitchDense requires case values to be sorted.
+ llvm::sort(Values);
+ if (!isSwitchDense(Values.size(), llvm::countr_zero(Values.back()) -
+ llvm::countr_zero(Values.front()) + 1))
+ // Transform is unable to generate dense switch.
+ return false;
+
+ Builder.SetInsertPoint(SI);
+
+ // Replace each case with its trailing zeros number.
+ for (auto &Case : SI->cases()) {
+ auto *OrigValue = Case.getCaseValue();
+ Case.setValue(ConstantInt::get(OrigValue->getType(),
+ OrigValue->getValue().countr_zero()));
+ }
+
+ // Replace condition with its trailing zeros number.
+ auto *ConditionTrailingZeros = Builder.CreateIntrinsic(
+ Intrinsic::cttz, {CondTy}, {Condition, ConstantInt::getTrue(Context)});
+
+ SI->setCondition(ConditionTrailingZeros);
+
+ return true;
+}
+
bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
BasicBlock *BB = SI->getParent();
@@ -6791,9 +7036,16 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
SwitchToLookupTable(SI, Builder, DTU, DL, TTI))
return requestResimplify();
+ if (simplifySwitchOfPowersOfTwo(SI, Builder, DL, TTI))
+ return requestResimplify();
+
if (ReduceSwitchRange(SI, Builder, DL, TTI))
return requestResimplify();
+ if (HoistCommon &&
+ hoistCommonCodeFromSuccessors(SI->getParent(), !Options.HoistCommonInsts))
+ return requestResimplify();
+
return false;
}
@@ -6978,7 +7230,8 @@ bool SimplifyCFGOpt::simplifyUncondBranch(BranchInst *BI,
// branches to us and our successor, fold the comparison into the
// predecessor and use logical operations to update the incoming value
// for PHI nodes in common successor.
- if (FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI,
+ if (Options.SpeculateBlocks &&
+ FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI,
Options.BonusInstThreshold))
return requestResimplify();
return false;
@@ -7048,7 +7301,8 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
// If this basic block is ONLY a compare and a branch, and if a predecessor
// branches to us and one of our successors, fold the comparison into the
// predecessor and use logical operations to pick the right destination.
- if (FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI,
+ if (Options.SpeculateBlocks &&
+ FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI,
Options.BonusInstThreshold))
return requestResimplify();
@@ -7058,7 +7312,8 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
// can hoist it up to the branching block.
if (BI->getSuccessor(0)->getSinglePredecessor()) {
if (BI->getSuccessor(1)->getSinglePredecessor()) {
- if (HoistCommon && HoistThenElseCodeToIf(BI, !Options.HoistCommonInsts))
+ if (HoistCommon && hoistCommonCodeFromSuccessors(
+ BI->getParent(), !Options.HoistCommonInsts))
return requestResimplify();
} else {
// If Successor #1 has multiple preds, we may be able to conditionally
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index a28916bc9baf..722ed03db3de 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -539,7 +539,8 @@ bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) {
for (auto *ICI : ICmpUsers) {
bool IsSwapped = L->isLoopInvariant(ICI->getOperand(0));
auto *Op1 = IsSwapped ? ICI->getOperand(0) : ICI->getOperand(1);
- Instruction *Ext = nullptr;
+ IRBuilder<> Builder(ICI);
+ Value *Ext = nullptr;
// For signed/unsigned predicate, replace the old comparison with comparison
// of immediate IV against sext/zext of the invariant argument. If we can
// use either sext or zext (i.e. we are dealing with equality predicate),
@@ -550,18 +551,18 @@ bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) {
if (IsSwapped) Pred = ICmpInst::getSwappedPredicate(Pred);
if (CanUseZExt(ICI)) {
assert(DoesZExtCollapse && "Unprofitable zext?");
- Ext = new ZExtInst(Op1, IVTy, "zext", ICI);
+ Ext = Builder.CreateZExt(Op1, IVTy, "zext");
Pred = ICmpInst::getUnsignedPredicate(Pred);
} else {
assert(DoesSExtCollapse && "Unprofitable sext?");
- Ext = new SExtInst(Op1, IVTy, "sext", ICI);
+ Ext = Builder.CreateSExt(Op1, IVTy, "sext");
assert(Pred == ICmpInst::getSignedPredicate(Pred) && "Must be signed!");
}
bool Changed;
L->makeLoopInvariant(Ext, Changed);
(void)Changed;
- ICmpInst *NewICI = new ICmpInst(ICI, Pred, IV, Ext);
- ICI->replaceAllUsesWith(NewICI);
+ auto *NewCmp = Builder.CreateICmp(Pred, IV, Ext);
+ ICI->replaceAllUsesWith(NewCmp);
DeadInsts.emplace_back(ICI);
}
@@ -659,12 +660,12 @@ bool SimplifyIndvar::replaceFloatIVWithIntegerIV(Instruction *UseInst) {
Instruction *IVOperand = cast<Instruction>(UseInst->getOperand(0));
// Get the symbolic expression for this instruction.
const SCEV *IV = SE->getSCEV(IVOperand);
- unsigned MaskBits;
+ int MaskBits;
if (UseInst->getOpcode() == CastInst::SIToFP)
- MaskBits = SE->getSignedRange(IV).getMinSignedBits();
+ MaskBits = (int)SE->getSignedRange(IV).getMinSignedBits();
else
- MaskBits = SE->getUnsignedRange(IV).getActiveBits();
- unsigned DestNumSigBits = UseInst->getType()->getFPMantissaWidth();
+ MaskBits = (int)SE->getUnsignedRange(IV).getActiveBits();
+ int DestNumSigBits = UseInst->getType()->getFPMantissaWidth();
if (MaskBits <= DestNumSigBits) {
for (User *U : UseInst->users()) {
// Match for fptosi/fptoui of sitofp and with same type.
@@ -908,8 +909,9 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) {
if (replaceIVUserWithLoopInvariant(UseInst))
continue;
- // Go further for the bitcast ''prtoint ptr to i64'
- if (isa<PtrToIntInst>(UseInst))
+    // Go further for the bitcast 'ptrtoint ptr to i64' or if the cast is done
+ // by truncation
+ if ((isa<PtrToIntInst>(UseInst)) || (isa<TruncInst>(UseInst)))
for (Use &U : UseInst->uses()) {
Instruction *User = cast<Instruction>(U.getUser());
if (replaceIVUserWithLoopInvariant(User))
@@ -1373,16 +1375,32 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0;
assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU");
- const SCEV *ExtendOperExpr = nullptr;
const OverflowingBinaryOperator *OBO =
cast<OverflowingBinaryOperator>(DU.NarrowUse);
ExtendKind ExtKind = getExtendKind(DU.NarrowDef);
- if (ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap())
- ExtendOperExpr = SE->getSignExtendExpr(
- SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType);
- else if (ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())
- ExtendOperExpr = SE->getZeroExtendExpr(
- SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType);
+ if (!(ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) &&
+ !(ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())) {
+ ExtKind = ExtendKind::Unknown;
+
+ // For a non-negative NarrowDef, we can choose either type of
+ // extension. We want to use the current extend kind if legal
+ // (see above), and we only hit this code if we need to check
+ // the opposite case.
+ if (DU.NeverNegative) {
+ if (OBO->hasNoSignedWrap()) {
+ ExtKind = ExtendKind::Sign;
+ } else if (OBO->hasNoUnsignedWrap()) {
+ ExtKind = ExtendKind::Zero;
+ }
+ }
+ }
+
+ const SCEV *ExtendOperExpr =
+ SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx));
+ if (ExtKind == ExtendKind::Sign)
+ ExtendOperExpr = SE->getSignExtendExpr(ExtendOperExpr, WideType);
+ else if (ExtKind == ExtendKind::Zero)
+ ExtendOperExpr = SE->getZeroExtendExpr(ExtendOperExpr, WideType);
else
return {nullptr, ExtendKind::Unknown};
@@ -1493,10 +1511,6 @@ bool WidenIV::widenLoopCompare(WidenIV::NarrowIVDefUse DU) {
assert(CastWidth <= IVWidth && "Unexpected width while widening compare.");
// Widen the compare instruction.
- auto *InsertPt = getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI);
- if (!InsertPt)
- return false;
- IRBuilder<> Builder(InsertPt);
DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef);
// Widen the other operand of the compare, if necessary.
@@ -1673,7 +1687,8 @@ bool WidenIV::widenWithVariantUse(WidenIV::NarrowIVDefUse DU) {
assert(LoopExitingBlock && L->contains(LoopExitingBlock) &&
"Not a LCSSA Phi?");
WidePN->addIncoming(WideBO, LoopExitingBlock);
- Builder.SetInsertPoint(&*User->getParent()->getFirstInsertionPt());
+ Builder.SetInsertPoint(User->getParent(),
+ User->getParent()->getFirstInsertionPt());
auto *TruncPN = Builder.CreateTrunc(WidePN, User->getType());
User->replaceAllUsesWith(TruncPN);
DeadInsts.emplace_back(User);
@@ -1726,7 +1741,8 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri
PHINode::Create(DU.WideDef->getType(), 1, UsePhi->getName() + ".wide",
UsePhi);
WidePhi->addIncoming(DU.WideDef, UsePhi->getIncomingBlock(0));
- IRBuilder<> Builder(&*WidePhi->getParent()->getFirstInsertionPt());
+ BasicBlock *WidePhiBB = WidePhi->getParent();
+ IRBuilder<> Builder(WidePhiBB, WidePhiBB->getFirstInsertionPt());
Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType());
UsePhi->replaceAllUsesWith(Trunc);
DeadInsts.emplace_back(UsePhi);
@@ -1786,65 +1802,70 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri
return nullptr;
}
- // Does this user itself evaluate to a recurrence after widening?
- WidenedRecTy WideAddRec = getExtendedOperandRecurrence(DU);
- if (!WideAddRec.first)
- WideAddRec = getWideRecurrence(DU);
-
- assert((WideAddRec.first == nullptr) ==
- (WideAddRec.second == ExtendKind::Unknown));
- if (!WideAddRec.first) {
- // If use is a loop condition, try to promote the condition instead of
- // truncating the IV first.
- if (widenLoopCompare(DU))
+ auto tryAddRecExpansion = [&]() -> Instruction* {
+ // Does this user itself evaluate to a recurrence after widening?
+ WidenedRecTy WideAddRec = getExtendedOperandRecurrence(DU);
+ if (!WideAddRec.first)
+ WideAddRec = getWideRecurrence(DU);
+ assert((WideAddRec.first == nullptr) ==
+ (WideAddRec.second == ExtendKind::Unknown));
+ if (!WideAddRec.first)
return nullptr;
- // We are here about to generate a truncate instruction that may hurt
- // performance because the scalar evolution expression computed earlier
- // in WideAddRec.first does not indicate a polynomial induction expression.
- // In that case, look at the operands of the use instruction to determine
- // if we can still widen the use instead of truncating its operand.
- if (widenWithVariantUse(DU))
+ // Reuse the IV increment that SCEVExpander created as long as it dominates
+ // NarrowUse.
+ Instruction *WideUse = nullptr;
+ if (WideAddRec.first == WideIncExpr &&
+ Rewriter.hoistIVInc(WideInc, DU.NarrowUse))
+ WideUse = WideInc;
+ else {
+ WideUse = cloneIVUser(DU, WideAddRec.first);
+ if (!WideUse)
+ return nullptr;
+ }
+ // Evaluation of WideAddRec ensured that the narrow expression could be
+ // extended outside the loop without overflow. This suggests that the wide use
+ // evaluates to the same expression as the extended narrow use, but doesn't
+ // absolutely guarantee it. Hence the following failsafe check. In rare cases
+ // where it fails, we simply throw away the newly created wide use.
+ if (WideAddRec.first != SE->getSCEV(WideUse)) {
+ LLVM_DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse << ": "
+ << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first
+ << "\n");
+ DeadInsts.emplace_back(WideUse);
return nullptr;
+ };
- // This user does not evaluate to a recurrence after widening, so don't
- // follow it. Instead insert a Trunc to kill off the original use,
- // eventually isolating the original narrow IV so it can be removed.
- truncateIVUse(DU, DT, LI);
- return nullptr;
- }
+ // if we reached this point then we are going to replace
+ // DU.NarrowUse with WideUse. Reattach DbgValue then.
+ replaceAllDbgUsesWith(*DU.NarrowUse, *WideUse, *WideUse, *DT);
- // Reuse the IV increment that SCEVExpander created as long as it dominates
- // NarrowUse.
- Instruction *WideUse = nullptr;
- if (WideAddRec.first == WideIncExpr &&
- Rewriter.hoistIVInc(WideInc, DU.NarrowUse))
- WideUse = WideInc;
- else {
- WideUse = cloneIVUser(DU, WideAddRec.first);
- if (!WideUse)
- return nullptr;
- }
- // Evaluation of WideAddRec ensured that the narrow expression could be
- // extended outside the loop without overflow. This suggests that the wide use
- // evaluates to the same expression as the extended narrow use, but doesn't
- // absolutely guarantee it. Hence the following failsafe check. In rare cases
- // where it fails, we simply throw away the newly created wide use.
- if (WideAddRec.first != SE->getSCEV(WideUse)) {
- LLVM_DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse << ": "
- << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first
- << "\n");
- DeadInsts.emplace_back(WideUse);
+ ExtendKindMap[DU.NarrowUse] = WideAddRec.second;
+ // Returning WideUse pushes it on the worklist.
+ return WideUse;
+ };
+
+ if (auto *I = tryAddRecExpansion())
+ return I;
+
+ // If use is a loop condition, try to promote the condition instead of
+ // truncating the IV first.
+ if (widenLoopCompare(DU))
return nullptr;
- }
- // if we reached this point then we are going to replace
- // DU.NarrowUse with WideUse. Reattach DbgValue then.
- replaceAllDbgUsesWith(*DU.NarrowUse, *WideUse, *WideUse, *DT);
+ // We are here about to generate a truncate instruction that may hurt
+ // performance because the scalar evolution expression computed earlier
+ // in WideAddRec.first does not indicate a polynomial induction expression.
+ // In that case, look at the operands of the use instruction to determine
+ // if we can still widen the use instead of truncating its operand.
+ if (widenWithVariantUse(DU))
+ return nullptr;
- ExtendKindMap[DU.NarrowUse] = WideAddRec.second;
- // Returning WideUse pushes it on the worklist.
- return WideUse;
+ // This user does not evaluate to a recurrence after widening, so don't
+ // follow it. Instead insert a Trunc to kill off the original use,
+ // eventually isolating the original narrow IV so it can be removed.
+ truncateIVUse(DU, DT, LI);
+ return nullptr;
}
/// Add eligible users of NarrowDef to NarrowIVUsers.
@@ -1944,13 +1965,15 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) {
// SCEVExpander. Henceforth, we produce 1-to-1 narrow to wide uses.
if (BasicBlock *LatchBlock = L->getLoopLatch()) {
WideInc =
- cast<Instruction>(WidePhi->getIncomingValueForBlock(LatchBlock));
- WideIncExpr = SE->getSCEV(WideInc);
- // Propagate the debug location associated with the original loop increment
- // to the new (widened) increment.
- auto *OrigInc =
- cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock));
- WideInc->setDebugLoc(OrigInc->getDebugLoc());
+ dyn_cast<Instruction>(WidePhi->getIncomingValueForBlock(LatchBlock));
+ if (WideInc) {
+ WideIncExpr = SE->getSCEV(WideInc);
+ // Propagate the debug location associated with the original loop
+ // increment to the new (widened) increment.
+ auto *OrigInc =
+ cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock));
+ WideInc->setDebugLoc(OrigInc->getDebugLoc());
+ }
}
LLVM_DEBUG(dbgs() << "Wide IV: " << *WidePhi << "\n");
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 5b0951252c07..760a626c8b6f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -227,9 +227,21 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr,
return ConstantInt::get(RetTy, Result);
}
+static bool isOnlyUsedInComparisonWithZero(Value *V) {
+ for (User *U : V->users()) {
+ if (ICmpInst *IC = dyn_cast<ICmpInst>(U))
+ if (Constant *C = dyn_cast<Constant>(IC->getOperand(1)))
+ if (C->isNullValue())
+ continue;
+ // Unknown instruction.
+ return false;
+ }
+ return true;
+}
+
static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len,
const DataLayout &DL) {
- if (!isOnlyUsedInZeroComparison(CI))
+ if (!isOnlyUsedInComparisonWithZero(CI))
return false;
if (!isDereferenceableAndAlignedPointer(Str, Align(1), APInt(64, Len), DL))
@@ -1136,7 +1148,7 @@ Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilderBase &B) {
Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) {
// fold strstr(x, x) -> x.
if (CI->getArgOperand(0) == CI->getArgOperand(1))
- return B.CreateBitCast(CI->getArgOperand(0), CI->getType());
+ return CI->getArgOperand(0);
// fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0
if (isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) {
@@ -1164,7 +1176,7 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) {
// fold strstr(x, "") -> x.
if (HasStr2 && ToFindStr.empty())
- return B.CreateBitCast(CI->getArgOperand(0), CI->getType());
+ return CI->getArgOperand(0);
// If both strings are known, constant fold it.
if (HasStr1 && HasStr2) {
@@ -1174,16 +1186,13 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) {
return Constant::getNullValue(CI->getType());
// strstr("abcd", "bc") -> gep((char*)"abcd", 1)
- Value *Result = castToCStr(CI->getArgOperand(0), B);
- Result =
- B.CreateConstInBoundsGEP1_64(B.getInt8Ty(), Result, Offset, "strstr");
- return B.CreateBitCast(Result, CI->getType());
+ return B.CreateConstInBoundsGEP1_64(B.getInt8Ty(), CI->getArgOperand(0),
+ Offset, "strstr");
}
// fold strstr(x, "y") -> strchr(x, 'y').
if (HasStr2 && ToFindStr.size() == 1) {
- Value *StrChr = emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI);
- return StrChr ? B.CreateBitCast(StrChr, CI->getType()) : nullptr;
+ return emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI);
}
annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
@@ -1380,7 +1389,7 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
if (isOnlyUsedInEqualityComparison(CI, SrcStr))
// S is dereferenceable so it's safe to load from it and fold
// memchr(S, C, N) == S to N && *S == C for any C and N.
- // TODO: This is safe even even for nonconstant S.
+ // TODO: This is safe even for nonconstant S.
return memChrToCharCompare(CI, Size, B, DL);
// From now on we need a constant length and constant array.
@@ -1522,12 +1531,10 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS,
// memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS
if (Len == 1) {
- Value *LHSV =
- B.CreateZExt(B.CreateLoad(B.getInt8Ty(), castToCStr(LHS, B), "lhsc"),
- CI->getType(), "lhsv");
- Value *RHSV =
- B.CreateZExt(B.CreateLoad(B.getInt8Ty(), castToCStr(RHS, B), "rhsc"),
- CI->getType(), "rhsv");
+ Value *LHSV = B.CreateZExt(B.CreateLoad(B.getInt8Ty(), LHS, "lhsc"),
+ CI->getType(), "lhsv");
+ Value *RHSV = B.CreateZExt(B.CreateLoad(B.getInt8Ty(), RHS, "rhsc"),
+ CI->getType(), "rhsv");
return B.CreateSub(LHSV, RHSV, "chardiff");
}
@@ -1833,7 +1840,7 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilderBase &B,
StringRef CallerName = CI->getFunction()->getName();
if (!CallerName.empty() && CallerName.back() == 'f' &&
CallerName.size() == (CalleeName.size() + 1) &&
- CallerName.startswith(CalleeName))
+ CallerName.starts_with(CalleeName))
return nullptr;
}
@@ -2368,8 +2375,8 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) {
FMF.setNoSignedZeros();
B.setFastMathFlags(FMF);
- Intrinsic::ID IID = Callee->getName().startswith("fmin") ? Intrinsic::minnum
- : Intrinsic::maxnum;
+ Intrinsic::ID IID = Callee->getName().starts_with("fmin") ? Intrinsic::minnum
+ : Intrinsic::maxnum;
Function *F = Intrinsic::getDeclaration(CI->getModule(), IID, CI->getType());
return copyFlags(
*CI, B.CreateCall(F, {CI->getArgOperand(0), CI->getArgOperand(1)}));
@@ -3066,7 +3073,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
if (!CI->getArgOperand(2)->getType()->isIntegerTy())
return nullptr;
Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char");
- Value *Ptr = castToCStr(Dest, B);
+ Value *Ptr = Dest;
B.CreateStore(V, Ptr);
Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
B.CreateStore(B.getInt8(0), Ptr);
@@ -3093,9 +3100,6 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
return ConstantInt::get(CI->getType(), SrcLen - 1);
} else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) {
// sprintf(dest, "%s", str) -> stpcpy(dest, str) - dest
- // Handle mismatched pointer types (goes away with typeless pointers?).
- V = B.CreatePointerCast(V, B.getInt8PtrTy());
- Dest = B.CreatePointerCast(Dest, B.getInt8PtrTy());
Value *PtrDiff = B.CreatePtrDiff(B.getInt8Ty(), V, Dest);
return B.CreateIntCast(PtrDiff, CI->getType(), false);
}
@@ -3261,7 +3265,7 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI,
if (!CI->getArgOperand(3)->getType()->isIntegerTy())
return nullptr;
Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char");
- Value *Ptr = castToCStr(DstArg, B);
+ Value *Ptr = DstArg;
B.CreateStore(V, Ptr);
Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
B.CreateStore(B.getInt8(0), Ptr);
@@ -3397,8 +3401,7 @@ Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilderBase &B) {
// If this is writing one byte, turn it into fputc.
// This optimisation is only valid, if the return value is unused.
if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F)
- Value *Char = B.CreateLoad(B.getInt8Ty(),
- castToCStr(CI->getArgOperand(0), B), "char");
+ Value *Char = B.CreateLoad(B.getInt8Ty(), CI->getArgOperand(0), "char");
Type *IntTy = B.getIntNTy(TLI->getIntSize());
Value *Cast = B.CreateIntCast(Char, IntTy, /*isSigned*/ true, "chari");
Value *NewCI = emitFPutC(Cast, CI->getArgOperand(3), B, TLI);
diff --git a/llvm/lib/Transforms/Utils/StripGCRelocates.cpp b/llvm/lib/Transforms/Utils/StripGCRelocates.cpp
index 0ff88e8b4612..6094f36a77f4 100644
--- a/llvm/lib/Transforms/Utils/StripGCRelocates.cpp
+++ b/llvm/lib/Transforms/Utils/StripGCRelocates.cpp
@@ -18,8 +18,6 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Statepoint.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
using namespace llvm;
@@ -66,21 +64,3 @@ PreservedAnalyses StripGCRelocates::run(Function &F,
PA.preserveSet<CFGAnalyses>();
return PA;
}
-
-namespace {
-struct StripGCRelocatesLegacy : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- StripGCRelocatesLegacy() : FunctionPass(ID) {
- initializeStripGCRelocatesLegacyPass(*PassRegistry::getPassRegistry());
- }
-
- void getAnalysisUsage(AnalysisUsage &Info) const override {}
-
- bool runOnFunction(Function &F) override { return ::stripGCRelocates(F); }
-};
-char StripGCRelocatesLegacy::ID = 0;
-} // namespace
-
-INITIALIZE_PASS(StripGCRelocatesLegacy, "strip-gc-relocates",
- "Strip gc.relocates inserted through RewriteStatepointsForGC",
- true, false)
diff --git a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp
index c3ae43e567b0..8b4f34209e85 100644
--- a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp
+++ b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp
@@ -68,8 +68,6 @@
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
diff --git a/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp
index 2b706858cbed..d5468909dd4e 100644
--- a/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp
+++ b/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp
@@ -16,33 +16,9 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
-#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils.h"
using namespace llvm;
-char UnifyFunctionExitNodesLegacyPass::ID = 0;
-
-UnifyFunctionExitNodesLegacyPass::UnifyFunctionExitNodesLegacyPass()
- : FunctionPass(ID) {
- initializeUnifyFunctionExitNodesLegacyPassPass(
- *PassRegistry::getPassRegistry());
-}
-
-INITIALIZE_PASS(UnifyFunctionExitNodesLegacyPass, "mergereturn",
- "Unify function exit nodes", false, false)
-
-Pass *llvm::createUnifyFunctionExitNodesPass() {
- return new UnifyFunctionExitNodesLegacyPass();
-}
-
-void UnifyFunctionExitNodesLegacyPass::getAnalysisUsage(
- AnalysisUsage &AU) const {
- // We preserve the non-critical-edgeness property
- AU.addPreservedID(BreakCriticalEdgesID);
- // This is a cluster of orthogonal Transforms
- AU.addPreservedID(LowerSwitchID);
-}
-
namespace {
bool unifyUnreachableBlocks(Function &F) {
@@ -110,16 +86,6 @@ bool unifyReturnBlocks(Function &F) {
}
} // namespace
-// Unify all exit nodes of the CFG by creating a new BasicBlock, and converting
-// all returns to unconditional branches to this new basic block. Also, unify
-// all unreachable blocks.
-bool UnifyFunctionExitNodesLegacyPass::runOnFunction(Function &F) {
- bool Changed = false;
- Changed |= unifyUnreachableBlocks(F);
- Changed |= unifyReturnBlocks(F);
- return Changed;
-}
-
PreservedAnalyses UnifyFunctionExitNodesPass::run(Function &F,
FunctionAnalysisManager &AM) {
bool Changed = false;
diff --git a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp
index 8c781f59ff5a..2f37f7f972cb 100644
--- a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp
+++ b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp
@@ -44,10 +44,8 @@ struct UnifyLoopExitsLegacyPass : public FunctionPass {
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequiredID(LowerSwitchID);
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreservedID(LowerSwitchID);
AU.addPreserved<LoopInfoWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
}
@@ -65,7 +63,6 @@ FunctionPass *llvm::createUnifyLoopExitsPass() {
INITIALIZE_PASS_BEGIN(UnifyLoopExitsLegacyPass, "unify-loop-exits",
"Fixup each natural loop to have a single exit block",
false /* Only looks at CFG */, false /* Analysis Pass */)
-INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(UnifyLoopExitsLegacyPass, "unify-loop-exits",
@@ -234,6 +231,8 @@ bool UnifyLoopExitsLegacyPass::runOnFunction(Function &F) {
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ assert(hasOnlySimpleTerminator(F) && "Unsupported block terminator.");
+
return runImpl(LI, DT);
}
diff --git a/llvm/lib/Transforms/Utils/Utils.cpp b/llvm/lib/Transforms/Utils/Utils.cpp
index 91c743f17764..51e1e824dd26 100644
--- a/llvm/lib/Transforms/Utils/Utils.cpp
+++ b/llvm/lib/Transforms/Utils/Utils.cpp
@@ -21,7 +21,6 @@ using namespace llvm;
/// initializeTransformUtils - Initialize all passes in the TransformUtils
/// library.
void llvm::initializeTransformUtils(PassRegistry &Registry) {
- initializeAssumeBuilderPassLegacyPassPass(Registry);
initializeBreakCriticalEdgesPass(Registry);
initializeCanonicalizeFreezeInLoopsPass(Registry);
initializeLCSSAWrapperPassPass(Registry);
@@ -30,9 +29,6 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) {
initializeLowerInvokeLegacyPassPass(Registry);
initializeLowerSwitchLegacyPassPass(Registry);
initializePromoteLegacyPassPass(Registry);
- initializeUnifyFunctionExitNodesLegacyPassPass(Registry);
- initializeStripGCRelocatesLegacyPass(Registry);
- initializePredicateInfoPrinterLegacyPassPass(Registry);
initializeFixIrreduciblePass(Registry);
initializeUnifyLoopExitsLegacyPassPass(Registry);
}
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 3446e31cc2ef..71d0f09e4771 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -31,6 +31,7 @@
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
@@ -145,6 +146,7 @@ public:
Value *mapValue(const Value *V);
void remapInstruction(Instruction *I);
void remapFunction(Function &F);
+ void remapDPValue(DPValue &DPV);
Constant *mapConstant(const Constant *C) {
return cast_or_null<Constant>(mapValue(C));
@@ -535,6 +537,39 @@ Value *Mapper::mapValue(const Value *V) {
return getVM()[V] = ConstantPointerNull::get(cast<PointerType>(NewTy));
}
+void Mapper::remapDPValue(DPValue &V) {
+ // Remap variables and DILocations.
+ auto *MappedVar = mapMetadata(V.getVariable());
+ auto *MappedDILoc = mapMetadata(V.getDebugLoc());
+ V.setVariable(cast<DILocalVariable>(MappedVar));
+ V.setDebugLoc(DebugLoc(cast<DILocation>(MappedDILoc)));
+
+ // Find Value operands and remap those.
+ SmallVector<Value *, 4> Vals, NewVals;
+ for (Value *Val : V.location_ops())
+ Vals.push_back(Val);
+ for (Value *Val : Vals)
+ NewVals.push_back(mapValue(Val));
+
+ // If there are no changes to the Value operands, finished.
+ if (Vals == NewVals)
+ return;
+
+ bool IgnoreMissingLocals = Flags & RF_IgnoreMissingLocals;
+
+ // Otherwise, do some replacement.
+ if (!IgnoreMissingLocals &&
+ llvm::any_of(NewVals, [&](Value *V) { return V == nullptr; })) {
+ V.setKillLocation();
+ } else {
+ // Either we have all non-empty NewVals, or we're permitted to ignore
+ // missing locals.
+ for (unsigned int I = 0; I < Vals.size(); ++I)
+ if (NewVals[I])
+ V.replaceVariableLocationOp(I, NewVals[I]);
+ }
+}
+
Value *Mapper::mapBlockAddress(const BlockAddress &BA) {
Function *F = cast<Function>(mapValue(BA.getFunction()));
@@ -1179,6 +1214,17 @@ void ValueMapper::remapInstruction(Instruction &I) {
FlushingMapper(pImpl)->remapInstruction(&I);
}
+void ValueMapper::remapDPValue(Module *M, DPValue &V) {
+ FlushingMapper(pImpl)->remapDPValue(V);
+}
+
+void ValueMapper::remapDPValueRange(
+ Module *M, iterator_range<DPValue::self_iterator> Range) {
+ for (DPValue &DPV : Range) {
+ remapDPValue(M, DPV);
+ }
+}
+
void ValueMapper::remapFunction(Function &F) {
FlushingMapper(pImpl)->remapFunction(F);
}
diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index 260d7889906b..c0dbd52acbab 100644
--- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -103,7 +103,6 @@
#include "llvm/Support/ModRef.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Vectorize.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
@@ -900,9 +899,9 @@ bool Vectorizer::vectorizeChain(Chain &C) {
// Chain is in offset order, so C[0] is the instr with the lowest offset,
// i.e. the root of the vector.
- Value *Bitcast = Builder.CreateBitCast(
- getLoadStorePointerOperand(C[0].Inst), VecTy->getPointerTo(AS));
- VecInst = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment);
+ VecInst = Builder.CreateAlignedLoad(VecTy,
+ getLoadStorePointerOperand(C[0].Inst),
+ Alignment);
unsigned VecIdx = 0;
for (const ChainElem &E : C) {
@@ -976,8 +975,7 @@ bool Vectorizer::vectorizeChain(Chain &C) {
// i.e. the root of the vector.
VecInst = Builder.CreateAlignedStore(
Vec,
- Builder.CreateBitCast(getLoadStorePointerOperand(C[0].Inst),
- VecTy->getPointerTo(AS)),
+ getLoadStorePointerOperand(C[0].Inst),
Alignment);
}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index f923f0be6621..37a356c43e29 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -289,7 +289,7 @@ void LoopVectorizeHints::getHintsFromMetadata() {
}
void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) {
- if (!Name.startswith(Prefix()))
+ if (!Name.starts_with(Prefix()))
return;
Name = Name.substr(Prefix().size(), StringRef::npos);
@@ -943,6 +943,11 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
}
}
+ // If we found a vectorized variant of a function, note that so LV can
+ // make better decisions about maximum VF.
+ if (CI && !VFDatabase::getMappings(*CI).empty())
+ VecCallVariantsFound = true;
+
// Check that the instruction return type is vectorizable.
// Also, we can't vectorize extractelement instructions.
if ((!VectorType::isValidElementType(I.getType()) &&
@@ -1242,13 +1247,12 @@ bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const {
bool LoopVectorizationLegality::blockCanBePredicated(
BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs,
- SmallPtrSetImpl<const Instruction *> &MaskedOp,
- SmallPtrSetImpl<Instruction *> &ConditionalAssumes) const {
+ SmallPtrSetImpl<const Instruction *> &MaskedOp) const {
for (Instruction &I : *BB) {
// We can predicate blocks with calls to assume, as long as we drop them in
// case we flatten the CFG via predication.
if (match(&I, m_Intrinsic<Intrinsic::assume>())) {
- ConditionalAssumes.insert(&I);
+ MaskedOp.insert(&I);
continue;
}
@@ -1345,16 +1349,13 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
}
// We must be able to predicate all blocks that need to be predicated.
- if (blockNeedsPredication(BB)) {
- if (!blockCanBePredicated(BB, SafePointers, MaskedOp,
- ConditionalAssumes)) {
- reportVectorizationFailure(
- "Control flow cannot be substituted for a select",
- "control flow cannot be substituted for a select",
- "NoCFGForSelect", ORE, TheLoop,
- BB->getTerminator());
- return false;
- }
+ if (blockNeedsPredication(BB) &&
+ !blockCanBePredicated(BB, SafePointers, MaskedOp)) {
+ reportVectorizationFailure(
+ "Control flow cannot be substituted for a select",
+ "control flow cannot be substituted for a select", "NoCFGForSelect",
+ ORE, TheLoop, BB->getTerminator());
+ return false;
}
}
@@ -1554,14 +1555,14 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
// The list of pointers that we can safely read and write to remains empty.
SmallPtrSet<Value *, 8> SafePointers;
+ // Collect masked ops in temporary set first to avoid partially populating
+ // MaskedOp if a block cannot be predicated.
SmallPtrSet<const Instruction *, 8> TmpMaskedOp;
- SmallPtrSet<Instruction *, 8> TmpConditionalAssumes;
// Check and mark all blocks for predication, including those that ordinarily
// do not need predication such as the header block.
for (BasicBlock *BB : TheLoop->blocks()) {
- if (!blockCanBePredicated(BB, SafePointers, TmpMaskedOp,
- TmpConditionalAssumes)) {
+ if (!blockCanBePredicated(BB, SafePointers, TmpMaskedOp)) {
LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking as requested.\n");
return false;
}
@@ -1570,9 +1571,6 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
LLVM_DEBUG(dbgs() << "LV: can fold tail by masking.\n");
MaskedOp.insert(TmpMaskedOp.begin(), TmpMaskedOp.end());
- ConditionalAssumes.insert(TmpConditionalAssumes.begin(),
- TmpConditionalAssumes.end());
-
return true;
}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 13357cb06c55..577ce8000de2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -31,6 +31,7 @@
namespace llvm {
class LoopInfo;
+class DominatorTree;
class LoopVectorizationLegality;
class LoopVectorizationCostModel;
class PredicatedScalarEvolution;
@@ -45,13 +46,17 @@ class VPBuilder {
VPBasicBlock *BB = nullptr;
VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator();
+ /// Insert \p VPI in BB at InsertPt if BB is set.
+ VPInstruction *tryInsertInstruction(VPInstruction *VPI) {
+ if (BB)
+ BB->insert(VPI, InsertPt);
+ return VPI;
+ }
+
VPInstruction *createInstruction(unsigned Opcode,
ArrayRef<VPValue *> Operands, DebugLoc DL,
const Twine &Name = "") {
- VPInstruction *Instr = new VPInstruction(Opcode, Operands, DL, Name);
- if (BB)
- BB->insert(Instr, InsertPt);
- return Instr;
+ return tryInsertInstruction(new VPInstruction(Opcode, Operands, DL, Name));
}
VPInstruction *createInstruction(unsigned Opcode,
@@ -62,6 +67,7 @@ class VPBuilder {
public:
VPBuilder() = default;
+ VPBuilder(VPBasicBlock *InsertBB) { setInsertPoint(InsertBB); }
/// Clear the insertion point: created instructions will not be inserted into
/// a block.
@@ -116,10 +122,11 @@ public:
InsertPt = IP;
}
- /// Insert and return the specified instruction.
- VPInstruction *insert(VPInstruction *I) const {
- BB->insert(I, InsertPt);
- return I;
+ /// This specifies that created instructions should be inserted at the
+ /// specified point.
+ void setInsertPoint(VPRecipeBase *IP) {
+ BB = IP->getParent();
+ InsertPt = IP->getIterator();
}
/// Create an N-ary operation with \p Opcode, \p Operands and set \p Inst as
@@ -138,6 +145,13 @@ public:
return createInstruction(Opcode, Operands, DL, Name);
}
+ VPInstruction *createOverflowingOp(unsigned Opcode,
+ std::initializer_list<VPValue *> Operands,
+ VPRecipeWithIRFlags::WrapFlagsTy WrapFlags,
+ DebugLoc DL, const Twine &Name = "") {
+ return tryInsertInstruction(
+ new VPInstruction(Opcode, Operands, WrapFlags, DL, Name));
+ }
VPValue *createNot(VPValue *Operand, DebugLoc DL, const Twine &Name = "") {
return createInstruction(VPInstruction::Not, {Operand}, DL, Name);
}
@@ -158,6 +172,12 @@ public:
Name);
}
+ /// Create a new ICmp VPInstruction with predicate \p Pred and operands \p A
+ /// and \p B.
+ /// TODO: add createFCmp when needed.
+ VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
+ DebugLoc DL = {}, const Twine &Name = "");
+
//===--------------------------------------------------------------------===//
// RAII helpers.
//===--------------------------------------------------------------------===//
@@ -268,6 +288,9 @@ class LoopVectorizationPlanner {
/// Loop Info analysis.
LoopInfo *LI;
+ /// The dominator tree.
+ DominatorTree *DT;
+
/// Target Library Info.
const TargetLibraryInfo *TLI;
@@ -298,16 +321,14 @@ class LoopVectorizationPlanner {
VPBuilder Builder;
public:
- LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI,
- const TargetTransformInfo &TTI,
- LoopVectorizationLegality *Legal,
- LoopVectorizationCostModel &CM,
- InterleavedAccessInfo &IAI,
- PredicatedScalarEvolution &PSE,
- const LoopVectorizeHints &Hints,
- OptimizationRemarkEmitter *ORE)
- : OrigLoop(L), LI(LI), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM), IAI(IAI),
- PSE(PSE), Hints(Hints), ORE(ORE) {}
+ LoopVectorizationPlanner(
+ Loop *L, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI,
+ const TargetTransformInfo &TTI, LoopVectorizationLegality *Legal,
+ LoopVectorizationCostModel &CM, InterleavedAccessInfo &IAI,
+ PredicatedScalarEvolution &PSE, const LoopVectorizeHints &Hints,
+ OptimizationRemarkEmitter *ORE)
+ : OrigLoop(L), LI(LI), DT(DT), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM),
+ IAI(IAI), PSE(PSE), Hints(Hints), ORE(ORE) {}
/// Plan how to best vectorize, return the best VF and its cost, or
/// std::nullopt if vectorization and interleaving should be avoided up front.
@@ -333,7 +354,7 @@ public:
executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan,
InnerLoopVectorizer &LB, DominatorTree *DT,
bool IsEpilogueVectorization,
- DenseMap<const SCEV *, Value *> *ExpandedSCEVs = nullptr);
+ const DenseMap<const SCEV *, Value *> *ExpandedSCEVs = nullptr);
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void printPlans(raw_ostream &O);
@@ -377,8 +398,7 @@ private:
/// returned VPlan is valid for. If no VPlan can be built for the input range,
/// set the largest included VF to the maximum VF for which no plan could be
/// built.
- std::optional<VPlanPtr> tryToBuildVPlanWithVPRecipes(
- VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions);
+ VPlanPtr tryToBuildVPlanWithVPRecipes(VFRange &Range);
/// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
/// according to the information gathered by Legal when it checked if it is
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d7e40e8ef978..f82e161fb846 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -27,7 +27,7 @@
//
// There is a development effort going on to migrate loop vectorizer to the
// VPlan infrastructure and to introduce outer loop vectorization support (see
-// docs/Proposal/VectorizationPlan.rst and
+// docs/VectorizationPlan.rst and
// http://lists.llvm.org/pipermail/llvm-dev/2017-December/119523.html). For this
// purpose, we temporarily introduced the VPlan-native vectorization path: an
// alternative vectorization path that is natively implemented on top of the
@@ -57,6 +57,7 @@
#include "LoopVectorizationPlanner.h"
#include "VPRecipeBuilder.h"
#include "VPlan.h"
+#include "VPlanAnalysis.h"
#include "VPlanHCFGBuilder.h"
#include "VPlanTransforms.h"
#include "llvm/ADT/APInt.h"
@@ -111,10 +112,12 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
@@ -390,6 +393,21 @@ static cl::opt<cl::boolOrDefault> ForceSafeDivisor(
cl::desc(
"Override cost based safe divisor widening for div/rem instructions"));
+static cl::opt<bool> UseWiderVFIfCallVariantsPresent(
+ "vectorizer-maximize-bandwidth-for-vector-calls", cl::init(true),
+ cl::Hidden,
+ cl::desc("Try wider VFs if they enable the use of vector variants"));
+
+// Likelihood of bypassing the vectorized loop because assumptions about SCEV
+// variables not overflowing do not hold. See `emitSCEVChecks`.
+static constexpr uint32_t SCEVCheckBypassWeights[] = {1, 127};
+// Likelihood of bypassing the vectorized loop because pointers overlap. See
+// `emitMemRuntimeChecks`.
+static constexpr uint32_t MemCheckBypassWeights[] = {1, 127};
+// Likelihood of bypassing the vectorized loop because there are zero trips left
+// after prolog. See `emitIterationCountCheck`.
+static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
+
/// A helper function that returns true if the given type is irregular. The
/// type is irregular if its allocated size doesn't equal the store size of an
/// element of the corresponding vector type.
@@ -408,13 +426,6 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL) {
/// we always assume predicated blocks have a 50% chance of executing.
static unsigned getReciprocalPredBlockProb() { return 2; }
-/// A helper function that returns an integer or floating-point constant with
-/// value C.
-static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) {
- return Ty->isIntegerTy() ? ConstantInt::getSigned(Ty, C)
- : ConstantFP::get(Ty, C);
-}
-
/// Returns "best known" trip count for the specified loop \p L as defined by
/// the following procedure:
/// 1) Returns exact trip count if it is known.
@@ -556,10 +567,6 @@ public:
const VPIteration &Instance,
VPTransformState &State);
- /// Construct the vector value of a scalarized value \p V one lane at a time.
- void packScalarIntoVectorValue(VPValue *Def, const VPIteration &Instance,
- VPTransformState &State);
-
/// Try to vectorize interleaved access group \p Group with the base address
/// given in \p Addr, optionally masking the vector operations if \p
/// BlockInMask is non-null. Use \p State to translate given VPValues to IR
@@ -634,10 +641,6 @@ protected:
/// the block that was created for it.
void sinkScalarOperands(Instruction *PredInst);
- /// Shrinks vector element sizes to the smallest bitwidth they can be legally
- /// represented as.
- void truncateToMinimalBitwidths(VPTransformState &State);
-
/// Returns (and creates if needed) the trip count of the widened loop.
Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock);
@@ -943,21 +946,21 @@ protected:
/// Look for a meaningful debug location on the instruction or it's
/// operands.
-static Instruction *getDebugLocFromInstOrOperands(Instruction *I) {
+static DebugLoc getDebugLocFromInstOrOperands(Instruction *I) {
if (!I)
- return I;
+ return DebugLoc();
DebugLoc Empty;
if (I->getDebugLoc() != Empty)
- return I;
+ return I->getDebugLoc();
for (Use &Op : I->operands()) {
if (Instruction *OpInst = dyn_cast<Instruction>(Op))
if (OpInst->getDebugLoc() != Empty)
- return OpInst;
+ return OpInst->getDebugLoc();
}
- return I;
+ return I->getDebugLoc();
}
/// Write a \p DebugMsg about vectorization to the debug output stream. If \p I
@@ -1021,14 +1024,6 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
}
-static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy,
- ElementCount VF) {
- assert(FTy->isFloatingPointTy() && "Expected floating point type!");
- Type *IntTy = IntegerType::get(FTy->getContext(), FTy->getScalarSizeInBits());
- Value *RuntimeVF = getRuntimeVF(B, IntTy, VF);
- return B.CreateUIToFP(RuntimeVF, FTy);
-}
-
void reportVectorizationFailure(const StringRef DebugMsg,
const StringRef OREMsg, const StringRef ORETag,
OptimizationRemarkEmitter *ORE, Loop *TheLoop,
@@ -1050,6 +1045,23 @@ void reportVectorizationInfo(const StringRef Msg, const StringRef ORETag,
<< Msg);
}
+/// Report successful vectorization of the loop. In case an outer loop is
+/// vectorized, prepend "outer" to the vectorization remark.
+static void reportVectorization(OptimizationRemarkEmitter *ORE, Loop *TheLoop,
+ VectorizationFactor VF, unsigned IC) {
+ LLVM_DEBUG(debugVectorizationMessage(
+ "Vectorizing: ", TheLoop->isInnermost() ? "innermost loop" : "outer loop",
+ nullptr));
+ StringRef LoopType = TheLoop->isInnermost() ? "" : "outer ";
+ ORE->emit([&]() {
+ return OptimizationRemark(LV_NAME, "Vectorized", TheLoop->getStartLoc(),
+ TheLoop->getHeader())
+ << "vectorized " << LoopType << "loop (vectorization width: "
+ << ore::NV("VectorizationFactor", VF.Width)
+ << ", interleaved count: " << ore::NV("InterleaveCount", IC) << ")";
+ });
+}
+
} // end namespace llvm
#ifndef NDEBUG
@@ -1104,7 +1116,8 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes(
if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) {
RecWithFlags->dropPoisonGeneratingFlags();
} else {
- Instruction *Instr = CurRec->getUnderlyingInstr();
+ Instruction *Instr = dyn_cast_or_null<Instruction>(
+ CurRec->getVPSingleValue()->getUnderlyingValue());
(void)Instr;
assert((!Instr || !Instr->hasPoisonGeneratingFlags()) &&
"found instruction with poison generating flags not covered by "
@@ -1247,6 +1260,13 @@ public:
/// avoid redundant calculations.
void setCostBasedWideningDecision(ElementCount VF);
+ /// A call may be vectorized in different ways depending on whether we have
+ /// vectorized variants available and whether the target supports masking.
+ /// This function analyzes all calls in the function at the supplied VF,
+ /// makes a decision based on the costs of available options, and stores that
+ /// decision in a map for use in planning and plan execution.
+ void setVectorizedCallDecision(ElementCount VF);
+
/// A struct that represents some properties of the register usage
/// of a loop.
struct RegisterUsage {
@@ -1270,7 +1290,7 @@ public:
void collectElementTypesForWidening();
/// Split reductions into those that happen in the loop, and those that happen
- /// outside. In loop reductions are collected into InLoopReductionChains.
+ /// outside. In loop reductions are collected into InLoopReductions.
void collectInLoopReductions();
/// Returns true if we should use strict in-order reductions for the given
@@ -1358,7 +1378,9 @@ public:
CM_Widen_Reverse, // For consecutive accesses with stride -1.
CM_Interleave,
CM_GatherScatter,
- CM_Scalarize
+ CM_Scalarize,
+ CM_VectorCall,
+ CM_IntrinsicCall
};
/// Save vectorization decision \p W and \p Cost taken by the cost model for
@@ -1414,6 +1436,29 @@ public:
return WideningDecisions[InstOnVF].second;
}
+ struct CallWideningDecision {
+ InstWidening Kind;
+ Function *Variant;
+ Intrinsic::ID IID;
+ std::optional<unsigned> MaskPos;
+ InstructionCost Cost;
+ };
+
+ void setCallWideningDecision(CallInst *CI, ElementCount VF, InstWidening Kind,
+ Function *Variant, Intrinsic::ID IID,
+ std::optional<unsigned> MaskPos,
+ InstructionCost Cost) {
+ assert(!VF.isScalar() && "Expected vector VF");
+ CallWideningDecisions[std::make_pair(CI, VF)] = {Kind, Variant, IID,
+ MaskPos, Cost};
+ }
+
+ CallWideningDecision getCallWideningDecision(CallInst *CI,
+ ElementCount VF) const {
+ assert(!VF.isScalar() && "Expected vector VF");
+ return CallWideningDecisions.at(std::make_pair(CI, VF));
+ }
+
/// Return True if instruction \p I is an optimizable truncate whose operand
/// is an induction variable. Such a truncate will be removed by adding a new
/// induction variable with the destination type.
@@ -1447,11 +1492,15 @@ public:
/// Collect Uniform and Scalar values for the given \p VF.
/// The sets depend on CM decision for Load/Store instructions
/// that may be vectorized as interleave, gather-scatter or scalarized.
+ /// Also make a decision on what to do about call instructions in the loop
+ /// at that VF -- scalarize, call a known vector routine, or call a
+ /// vector intrinsic.
void collectUniformsAndScalars(ElementCount VF) {
// Do the analysis once.
if (VF.isScalar() || Uniforms.contains(VF))
return;
setCostBasedWideningDecision(VF);
+ setVectorizedCallDecision(VF);
collectLoopUniforms(VF);
collectLoopScalars(VF);
}
@@ -1606,20 +1655,9 @@ public:
return foldTailByMasking() || Legal->blockNeedsPredication(BB);
}
- /// A SmallMapVector to store the InLoop reduction op chains, mapping phi
- /// nodes to the chain of instructions representing the reductions. Uses a
- /// MapVector to ensure deterministic iteration order.
- using ReductionChainMap =
- SmallMapVector<PHINode *, SmallVector<Instruction *, 4>, 4>;
-
- /// Return the chain of instructions representing an inloop reduction.
- const ReductionChainMap &getInLoopReductionChains() const {
- return InLoopReductionChains;
- }
-
/// Returns true if the Phi is part of an inloop reduction.
bool isInLoopReduction(PHINode *Phi) const {
- return InLoopReductionChains.count(Phi);
+ return InLoopReductions.contains(Phi);
}
/// Estimate cost of an intrinsic call instruction CI if it were vectorized
@@ -1629,16 +1667,13 @@ public:
/// Estimate cost of a call instruction CI if it were vectorized with factor
/// VF. Return the cost of the instruction, including scalarization overhead
- /// if it's needed. The flag NeedToScalarize shows if the call needs to be
- /// scalarized -
- /// i.e. either vector version isn't available, or is too expensive.
- InstructionCost getVectorCallCost(CallInst *CI, ElementCount VF,
- Function **Variant,
- bool *NeedsMask = nullptr) const;
+ /// if it's needed.
+ InstructionCost getVectorCallCost(CallInst *CI, ElementCount VF) const;
/// Invalidates decisions already taken by the cost model.
void invalidateCostModelingDecisions() {
WideningDecisions.clear();
+ CallWideningDecisions.clear();
Uniforms.clear();
Scalars.clear();
}
@@ -1675,14 +1710,14 @@ private:
/// elements is a power-of-2 larger than zero. If scalable vectorization is
/// disabled or unsupported, then the scalable part will be equal to
/// ElementCount::getScalable(0).
- FixedScalableVFPair computeFeasibleMaxVF(unsigned ConstTripCount,
+ FixedScalableVFPair computeFeasibleMaxVF(unsigned MaxTripCount,
ElementCount UserVF,
bool FoldTailByMasking);
/// \return the maximized element count based on the targets vector
/// registers and the loop trip-count, but limited to a maximum safe VF.
/// This is a helper function of computeFeasibleMaxVF.
- ElementCount getMaximizedVFForTarget(unsigned ConstTripCount,
+ ElementCount getMaximizedVFForTarget(unsigned MaxTripCount,
unsigned SmallestType,
unsigned WidestType,
ElementCount MaxSafeVF,
@@ -1705,7 +1740,7 @@ private:
/// part of that pattern.
std::optional<InstructionCost>
getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy,
- TTI::TargetCostKind CostKind);
+ TTI::TargetCostKind CostKind) const;
/// Calculate vectorization cost of memory instruction \p I.
InstructionCost getMemoryInstructionCost(Instruction *I, ElementCount VF);
@@ -1783,15 +1818,12 @@ private:
/// scalarized.
DenseMap<ElementCount, SmallPtrSet<Instruction *, 4>> ForcedScalars;
- /// PHINodes of the reductions that should be expanded in-loop along with
- /// their associated chains of reduction operations, in program order from top
- /// (PHI) to bottom
- ReductionChainMap InLoopReductionChains;
+ /// PHINodes of the reductions that should be expanded in-loop.
+ SmallPtrSet<PHINode *, 4> InLoopReductions;
/// A Map of inloop reduction operations and their immediate chain operand.
/// FIXME: This can be removed once reductions can be costed correctly in
- /// vplan. This was added to allow quick lookup to the inloop operations,
- /// without having to loop through InLoopReductionChains.
+ /// VPlan. This was added to allow quick lookup of the inloop operations.
DenseMap<Instruction *, Instruction *> InLoopReductionImmediateChains;
/// Returns the expected difference in cost from scalarizing the expression
@@ -1830,6 +1862,11 @@ private:
DecisionList WideningDecisions;
+ using CallDecisionList =
+ DenseMap<std::pair<CallInst *, ElementCount>, CallWideningDecision>;
+
+ CallDecisionList CallWideningDecisions;
+
/// Returns true if \p V is expected to be vectorized and it needs to be
/// extracted.
bool needsExtract(Value *V, ElementCount VF) const {
@@ -1933,12 +1970,14 @@ class GeneratedRTChecks {
SCEVExpander MemCheckExp;
bool CostTooHigh = false;
+ const bool AddBranchWeights;
public:
GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
- TargetTransformInfo *TTI, const DataLayout &DL)
+ TargetTransformInfo *TTI, const DataLayout &DL,
+ bool AddBranchWeights)
: DT(DT), LI(LI), TTI(TTI), SCEVExp(SE, DL, "scev.check"),
- MemCheckExp(SE, DL, "scev.check") {}
+ MemCheckExp(SE, DL, "scev.check"), AddBranchWeights(AddBranchWeights) {}
/// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can
/// accurately estimate the cost of the runtime checks. The blocks are
@@ -1990,9 +2029,9 @@ public:
},
IC);
} else {
- MemRuntimeCheckCond =
- addRuntimeChecks(MemCheckBlock->getTerminator(), L,
- RtPtrChecking.getChecks(), MemCheckExp);
+ MemRuntimeCheckCond = addRuntimeChecks(
+ MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(),
+ MemCheckExp, VectorizerParams::HoistRuntimeChecks);
}
assert(MemRuntimeCheckCond &&
"no RT checks generated although RtPtrChecking "
@@ -2131,8 +2170,10 @@ public:
DT->addNewBlock(SCEVCheckBlock, Pred);
DT->changeImmediateDominator(LoopVectorPreHeader, SCEVCheckBlock);
- ReplaceInstWithInst(SCEVCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, Cond));
+ BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, Cond);
+ if (AddBranchWeights)
+ setBranchWeights(BI, SCEVCheckBypassWeights);
+ ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), &BI);
return SCEVCheckBlock;
}
@@ -2156,9 +2197,12 @@ public:
if (auto *PL = LI->getLoopFor(LoopVectorPreHeader))
PL->addBasicBlockToLoop(MemCheckBlock, *LI);
- ReplaceInstWithInst(
- MemCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond);
+ if (AddBranchWeights) {
+ setBranchWeights(BI, MemCheckBypassWeights);
+ }
+ ReplaceInstWithInst(MemCheckBlock->getTerminator(), &BI);
MemCheckBlock->getTerminator()->setDebugLoc(
Pred->getTerminator()->getDebugLoc());
@@ -2252,157 +2296,17 @@ static void collectSupportedLoops(Loop &L, LoopInfo *LI,
// LoopVectorizationCostModel and LoopVectorizationPlanner.
//===----------------------------------------------------------------------===//
-/// This function adds
-/// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...)
-/// to each vector element of Val. The sequence starts at StartIndex.
-/// \p Opcode is relevant for FP induction variable.
-static Value *getStepVector(Value *Val, Value *StartIdx, Value *Step,
- Instruction::BinaryOps BinOp, ElementCount VF,
- IRBuilderBase &Builder) {
- assert(VF.isVector() && "only vector VFs are supported");
-
- // Create and check the types.
- auto *ValVTy = cast<VectorType>(Val->getType());
- ElementCount VLen = ValVTy->getElementCount();
-
- Type *STy = Val->getType()->getScalarType();
- assert((STy->isIntegerTy() || STy->isFloatingPointTy()) &&
- "Induction Step must be an integer or FP");
- assert(Step->getType() == STy && "Step has wrong type");
-
- SmallVector<Constant *, 8> Indices;
-
- // Create a vector of consecutive numbers from zero to VF.
- VectorType *InitVecValVTy = ValVTy;
- if (STy->isFloatingPointTy()) {
- Type *InitVecValSTy =
- IntegerType::get(STy->getContext(), STy->getScalarSizeInBits());
- InitVecValVTy = VectorType::get(InitVecValSTy, VLen);
- }
- Value *InitVec = Builder.CreateStepVector(InitVecValVTy);
-
- // Splat the StartIdx
- Value *StartIdxSplat = Builder.CreateVectorSplat(VLen, StartIdx);
-
- if (STy->isIntegerTy()) {
- InitVec = Builder.CreateAdd(InitVec, StartIdxSplat);
- Step = Builder.CreateVectorSplat(VLen, Step);
- assert(Step->getType() == Val->getType() && "Invalid step vec");
- // FIXME: The newly created binary instructions should contain nsw/nuw
- // flags, which can be found from the original scalar operations.
- Step = Builder.CreateMul(InitVec, Step);
- return Builder.CreateAdd(Val, Step, "induction");
- }
-
- // Floating point induction.
- assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) &&
- "Binary Opcode should be specified for FP induction");
- InitVec = Builder.CreateUIToFP(InitVec, ValVTy);
- InitVec = Builder.CreateFAdd(InitVec, StartIdxSplat);
-
- Step = Builder.CreateVectorSplat(VLen, Step);
- Value *MulOp = Builder.CreateFMul(InitVec, Step);
- return Builder.CreateBinOp(BinOp, Val, MulOp, "induction");
-}
-
-/// Compute scalar induction steps. \p ScalarIV is the scalar induction
-/// variable on which to base the steps, \p Step is the size of the step.
-static void buildScalarSteps(Value *ScalarIV, Value *Step,
- const InductionDescriptor &ID, VPValue *Def,
- VPTransformState &State) {
- IRBuilderBase &Builder = State.Builder;
-
- // Ensure step has the same type as that of scalar IV.
- Type *ScalarIVTy = ScalarIV->getType()->getScalarType();
- if (ScalarIVTy != Step->getType()) {
- // TODO: Also use VPDerivedIVRecipe when only the step needs truncating, to
- // avoid separate truncate here.
- assert(Step->getType()->isIntegerTy() &&
- "Truncation requires an integer step");
- Step = State.Builder.CreateTrunc(Step, ScalarIVTy);
- }
-
- // We build scalar steps for both integer and floating-point induction
- // variables. Here, we determine the kind of arithmetic we will perform.
- Instruction::BinaryOps AddOp;
- Instruction::BinaryOps MulOp;
- if (ScalarIVTy->isIntegerTy()) {
- AddOp = Instruction::Add;
- MulOp = Instruction::Mul;
- } else {
- AddOp = ID.getInductionOpcode();
- MulOp = Instruction::FMul;
- }
-
- // Determine the number of scalars we need to generate for each unroll
- // iteration.
- bool FirstLaneOnly = vputils::onlyFirstLaneUsed(Def);
- // Compute the scalar steps and save the results in State.
- Type *IntStepTy = IntegerType::get(ScalarIVTy->getContext(),
- ScalarIVTy->getScalarSizeInBits());
- Type *VecIVTy = nullptr;
- Value *UnitStepVec = nullptr, *SplatStep = nullptr, *SplatIV = nullptr;
- if (!FirstLaneOnly && State.VF.isScalable()) {
- VecIVTy = VectorType::get(ScalarIVTy, State.VF);
- UnitStepVec =
- Builder.CreateStepVector(VectorType::get(IntStepTy, State.VF));
- SplatStep = Builder.CreateVectorSplat(State.VF, Step);
- SplatIV = Builder.CreateVectorSplat(State.VF, ScalarIV);
- }
-
- unsigned StartPart = 0;
- unsigned EndPart = State.UF;
- unsigned StartLane = 0;
- unsigned EndLane = FirstLaneOnly ? 1 : State.VF.getKnownMinValue();
- if (State.Instance) {
- StartPart = State.Instance->Part;
- EndPart = StartPart + 1;
- StartLane = State.Instance->Lane.getKnownLane();
- EndLane = StartLane + 1;
- }
- for (unsigned Part = StartPart; Part < EndPart; ++Part) {
- Value *StartIdx0 = createStepForVF(Builder, IntStepTy, State.VF, Part);
-
- if (!FirstLaneOnly && State.VF.isScalable()) {
- auto *SplatStartIdx = Builder.CreateVectorSplat(State.VF, StartIdx0);
- auto *InitVec = Builder.CreateAdd(SplatStartIdx, UnitStepVec);
- if (ScalarIVTy->isFloatingPointTy())
- InitVec = Builder.CreateSIToFP(InitVec, VecIVTy);
- auto *Mul = Builder.CreateBinOp(MulOp, InitVec, SplatStep);
- auto *Add = Builder.CreateBinOp(AddOp, SplatIV, Mul);
- State.set(Def, Add, Part);
- // It's useful to record the lane values too for the known minimum number
- // of elements so we do those below. This improves the code quality when
- // trying to extract the first element, for example.
- }
-
- if (ScalarIVTy->isFloatingPointTy())
- StartIdx0 = Builder.CreateSIToFP(StartIdx0, ScalarIVTy);
-
- for (unsigned Lane = StartLane; Lane < EndLane; ++Lane) {
- Value *StartIdx = Builder.CreateBinOp(
- AddOp, StartIdx0, getSignedIntOrFpConstant(ScalarIVTy, Lane));
- // The step returned by `createStepForVF` is a runtime-evaluated value
- // when VF is scalable. Otherwise, it should be folded into a Constant.
- assert((State.VF.isScalable() || isa<Constant>(StartIdx)) &&
- "Expected StartIdx to be folded to a constant when VF is not "
- "scalable");
- auto *Mul = Builder.CreateBinOp(MulOp, StartIdx, Step);
- auto *Add = Builder.CreateBinOp(AddOp, ScalarIV, Mul);
- State.set(Def, Add, VPIteration(Part, Lane));
- }
- }
-}
-
/// Compute the transformed value of Index at offset StartValue using step
/// StepValue.
/// For integer induction, returns StartValue + Index * StepValue.
/// For pointer induction, returns StartValue[Index * StepValue].
/// FIXME: The newly created binary instructions should contain nsw/nuw
/// flags, which can be found from the original scalar operations.
-static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
- Value *StartValue, Value *Step,
- const InductionDescriptor &ID) {
+static Value *
+emitTransformedIndex(IRBuilderBase &B, Value *Index, Value *StartValue,
+ Value *Step,
+ InductionDescriptor::InductionKind InductionKind,
+ const BinaryOperator *InductionBinOp) {
Type *StepTy = Step->getType();
Value *CastedIndex = StepTy->isIntegerTy()
? B.CreateSExtOrTrunc(Index, StepTy)
@@ -2446,7 +2350,7 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
return B.CreateMul(X, Y);
};
- switch (ID.getKind()) {
+ switch (InductionKind) {
case InductionDescriptor::IK_IntInduction: {
assert(!isa<VectorType>(Index->getType()) &&
"Vector indices not supported for integer inductions yet");
@@ -2464,7 +2368,6 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
assert(!isa<VectorType>(Index->getType()) &&
"Vector indices not supported for FP inductions yet");
assert(Step->getType()->isFloatingPointTy() && "Expected FP Step value");
- auto InductionBinOp = ID.getInductionBinOp();
assert(InductionBinOp &&
(InductionBinOp->getOpcode() == Instruction::FAdd ||
InductionBinOp->getOpcode() == Instruction::FSub) &&
@@ -2524,17 +2427,6 @@ static bool isIndvarOverflowCheckKnownFalse(
return false;
}
-void InnerLoopVectorizer::packScalarIntoVectorValue(VPValue *Def,
- const VPIteration &Instance,
- VPTransformState &State) {
- Value *ScalarInst = State.get(Def, Instance);
- Value *VectorValue = State.get(Def, Instance.Part);
- VectorValue = Builder.CreateInsertElement(
- VectorValue, ScalarInst,
- Instance.Lane.getAsRuntimeExpr(State.Builder, VF));
- State.set(Def, VectorValue, Instance.Part);
-}
-
// Return whether we allow using masked interleave-groups (for dealing with
// strided loads/stores that reside in predicated blocks, or for dealing
// with gaps).
@@ -2612,7 +2504,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
for (unsigned Part = 0; Part < UF; Part++) {
Value *AddrPart = State.get(Addr, VPIteration(Part, 0));
- State.setDebugLocFromInst(AddrPart);
+ if (auto *I = dyn_cast<Instruction>(AddrPart))
+ State.setDebugLocFrom(I->getDebugLoc());
// Notice current instruction could be any index. Need to adjust the address
// to the member of index 0.
@@ -2630,14 +2523,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
if (auto *gep = dyn_cast<GetElementPtrInst>(AddrPart->stripPointerCasts()))
InBounds = gep->isInBounds();
AddrPart = Builder.CreateGEP(ScalarTy, AddrPart, Idx, "", InBounds);
-
- // Cast to the vector pointer type.
- unsigned AddressSpace = AddrPart->getType()->getPointerAddressSpace();
- Type *PtrTy = VecTy->getPointerTo(AddressSpace);
- AddrParts.push_back(Builder.CreateBitCast(AddrPart, PtrTy));
+ AddrParts.push_back(AddrPart);
}
- State.setDebugLocFromInst(Instr);
+ State.setDebugLocFrom(Instr->getDebugLoc());
Value *PoisonVec = PoisonValue::get(VecTy);
auto CreateGroupMask = [this, &BlockInMask, &State, &InterleaveFactor](
@@ -2835,13 +2724,20 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr,
bool IsVoidRetTy = Instr->getType()->isVoidTy();
Instruction *Cloned = Instr->clone();
- if (!IsVoidRetTy)
+ if (!IsVoidRetTy) {
Cloned->setName(Instr->getName() + ".cloned");
+#if !defined(NDEBUG)
+ // Verify that VPlan type inference results agree with the type of the
+ // generated values.
+ assert(State.TypeAnalysis.inferScalarType(RepRecipe) == Cloned->getType() &&
+ "inferred type and type from generated instructions do not match");
+#endif
+ }
RepRecipe->setFlags(Cloned);
- if (Instr->getDebugLoc())
- State.setDebugLocFromInst(Instr);
+ if (auto DL = Instr->getDebugLoc())
+ State.setDebugLocFrom(DL);
// Replace the operands of the cloned instructions with their scalar
// equivalents in the new loop.
@@ -3019,9 +2915,11 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
// dominator of the exit blocks.
DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock);
- ReplaceInstWithInst(
- TCCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+ if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+ setBranchWeights(BI, MinItersBypassWeights);
+ ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
LoopBypassBlocks.push_back(TCCheckBlock);
}
@@ -3151,15 +3049,17 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue(
if (II.getInductionBinOp() && isa<FPMathOperator>(II.getInductionBinOp()))
B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags());
- EndValue =
- emitTransformedIndex(B, VectorTripCount, II.getStartValue(), Step, II);
+ EndValue = emitTransformedIndex(B, VectorTripCount, II.getStartValue(),
+ Step, II.getKind(), II.getInductionBinOp());
EndValue->setName("ind.end");
// Compute the end value for the additional bypass (if applicable).
if (AdditionalBypass.first) {
- B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt()));
- EndValueFromAdditionalBypass = emitTransformedIndex(
- B, AdditionalBypass.second, II.getStartValue(), Step, II);
+ B.SetInsertPoint(AdditionalBypass.first,
+ AdditionalBypass.first->getFirstInsertionPt());
+ EndValueFromAdditionalBypass =
+ emitTransformedIndex(B, AdditionalBypass.second, II.getStartValue(),
+ Step, II.getKind(), II.getInductionBinOp());
EndValueFromAdditionalBypass->setName("ind.end");
}
}
@@ -3240,16 +3140,25 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
// 3) Otherwise, construct a runtime check.
if (!Cost->requiresScalarEpilogue(VF.isVector()) &&
!Cost->foldTailByMasking()) {
- Instruction *CmpN = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ,
- Count, VectorTripCount, "cmp.n",
- LoopMiddleBlock->getTerminator());
-
// Here we use the same DebugLoc as the scalar loop latch terminator instead
// of the corresponding compare because they may have ended up with
// different line numbers and we want to avoid awkward line stepping while
// debugging. Eg. if the compare has got a line number inside the loop.
- CmpN->setDebugLoc(ScalarLatchTerm->getDebugLoc());
- cast<BranchInst>(LoopMiddleBlock->getTerminator())->setCondition(CmpN);
+ // TODO: At the moment, CreateICmpEQ will simplify conditions with constant
+ // operands. Perform simplification directly on VPlan once the branch is
+ // modeled there.
+ IRBuilder<> B(LoopMiddleBlock->getTerminator());
+ B.SetCurrentDebugLocation(ScalarLatchTerm->getDebugLoc());
+ Value *CmpN = B.CreateICmpEQ(Count, VectorTripCount, "cmp.n");
+ BranchInst &BI = *cast<BranchInst>(LoopMiddleBlock->getTerminator());
+ BI.setCondition(CmpN);
+ if (hasBranchWeightMD(*ScalarLatchTerm)) {
+ // Assume that `Count % VectorTripCount` is equally distributed.
+ unsigned TripCount = UF * VF.getKnownMinValue();
+ assert(TripCount > 0 && "trip count should not be zero");
+ const uint32_t Weights[] = {1, TripCount - 1};
+ setBranchWeights(BI, Weights);
+ }
}
#ifdef EXPENSIVE_CHECKS
@@ -3373,7 +3282,8 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
Value *Step = StepVPV->isLiveIn() ? StepVPV->getLiveInIRValue()
: State.get(StepVPV, {0, 0});
Value *Escape =
- emitTransformedIndex(B, CountMinusOne, II.getStartValue(), Step, II);
+ emitTransformedIndex(B, CountMinusOne, II.getStartValue(), Step,
+ II.getKind(), II.getInductionBinOp());
Escape->setName("ind.escape");
MissingVals[UI] = Escape;
}
@@ -3445,76 +3355,33 @@ static void cse(BasicBlock *BB) {
}
}
-InstructionCost LoopVectorizationCostModel::getVectorCallCost(
- CallInst *CI, ElementCount VF, Function **Variant, bool *NeedsMask) const {
- Function *F = CI->getCalledFunction();
- Type *ScalarRetTy = CI->getType();
- SmallVector<Type *, 4> Tys, ScalarTys;
- bool MaskRequired = Legal->isMaskRequired(CI);
- for (auto &ArgOp : CI->args())
- ScalarTys.push_back(ArgOp->getType());
+InstructionCost
+LoopVectorizationCostModel::getVectorCallCost(CallInst *CI,
+ ElementCount VF) const {
+ // We only need to calculate a cost if the VF is scalar; for actual vectors
+ // we should already have a pre-calculated cost at each VF.
+ if (!VF.isScalar())
+ return CallWideningDecisions.at(std::make_pair(CI, VF)).Cost;
- // Estimate cost of scalarized vector call. The source operands are assumed
- // to be vectors, so we need to extract individual elements from there,
- // execute VF scalar calls, and then gather the result into the vector return
- // value.
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
- InstructionCost ScalarCallCost =
- TTI.getCallInstrCost(F, ScalarRetTy, ScalarTys, CostKind);
- if (VF.isScalar())
- return ScalarCallCost;
-
- // Compute corresponding vector type for return value and arguments.
- Type *RetTy = ToVectorTy(ScalarRetTy, VF);
- for (Type *ScalarTy : ScalarTys)
- Tys.push_back(ToVectorTy(ScalarTy, VF));
-
- // Compute costs of unpacking argument values for the scalar calls and
- // packing the return values to a vector.
- InstructionCost ScalarizationCost =
- getScalarizationOverhead(CI, VF, CostKind);
-
- InstructionCost Cost =
- ScalarCallCost * VF.getKnownMinValue() + ScalarizationCost;
+ Type *RetTy = CI->getType();
+ if (RecurrenceDescriptor::isFMulAddIntrinsic(CI))
+ if (auto RedCost = getReductionPatternCost(CI, VF, RetTy, CostKind))
+ return *RedCost;
- // If we can't emit a vector call for this function, then the currently found
- // cost is the cost we need to return.
- InstructionCost MaskCost = 0;
- VFShape Shape = VFShape::get(*CI, VF, MaskRequired);
- if (NeedsMask)
- *NeedsMask = MaskRequired;
- Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
- // If we want an unmasked vector function but can't find one matching the VF,
- // maybe we can find vector function that does use a mask and synthesize
- // an all-true mask.
- if (!VecFunc && !MaskRequired) {
- Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true);
- VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
- // If we found one, add in the cost of creating a mask
- if (VecFunc) {
- if (NeedsMask)
- *NeedsMask = true;
- MaskCost = TTI.getShuffleCost(
- TargetTransformInfo::SK_Broadcast,
- VectorType::get(
- IntegerType::getInt1Ty(VecFunc->getFunctionType()->getContext()),
- VF));
- }
- }
+ SmallVector<Type *, 4> Tys;
+ for (auto &ArgOp : CI->args())
+ Tys.push_back(ArgOp->getType());
- // We don't support masked function calls yet, but we can scalarize a
- // masked call with branches (unless VF is scalable).
- if (!TLI || CI->isNoBuiltin() || !VecFunc)
- return VF.isScalable() ? InstructionCost::getInvalid() : Cost;
+ InstructionCost ScalarCallCost =
+ TTI.getCallInstrCost(CI->getCalledFunction(), RetTy, Tys, CostKind);
- // If the corresponding vector cost is cheaper, return its cost.
- InstructionCost VectorCallCost =
- TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost;
- if (VectorCallCost < Cost) {
- *Variant = VecFunc;
- Cost = VectorCallCost;
+ // If this is an intrinsic we may have a lower cost for it.
+ if (getVectorIntrinsicIDForCall(CI, TLI)) {
+ InstructionCost IntrinsicCost = getVectorIntrinsicCost(CI, VF);
+ return std::min(ScalarCallCost, IntrinsicCost);
}
- return Cost;
+ return ScalarCallCost;
}
static Type *MaybeVectorizeType(Type *Elt, ElementCount VF) {
@@ -3558,146 +3425,8 @@ static Type *largestIntegerVectorType(Type *T1, Type *T2) {
return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2;
}
-void InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &State) {
- // For every instruction `I` in MinBWs, truncate the operands, create a
- // truncated version of `I` and reextend its result. InstCombine runs
- // later and will remove any ext/trunc pairs.
- SmallPtrSet<Value *, 4> Erased;
- for (const auto &KV : Cost->getMinimalBitwidths()) {
- // If the value wasn't vectorized, we must maintain the original scalar
- // type. The absence of the value from State indicates that it
- // wasn't vectorized.
- // FIXME: Should not rely on getVPValue at this point.
- VPValue *Def = State.Plan->getVPValue(KV.first, true);
- if (!State.hasAnyVectorValue(Def))
- continue;
- for (unsigned Part = 0; Part < UF; ++Part) {
- Value *I = State.get(Def, Part);
- if (Erased.count(I) || I->use_empty() || !isa<Instruction>(I))
- continue;
- Type *OriginalTy = I->getType();
- Type *ScalarTruncatedTy =
- IntegerType::get(OriginalTy->getContext(), KV.second);
- auto *TruncatedTy = VectorType::get(
- ScalarTruncatedTy, cast<VectorType>(OriginalTy)->getElementCount());
- if (TruncatedTy == OriginalTy)
- continue;
-
- IRBuilder<> B(cast<Instruction>(I));
- auto ShrinkOperand = [&](Value *V) -> Value * {
- if (auto *ZI = dyn_cast<ZExtInst>(V))
- if (ZI->getSrcTy() == TruncatedTy)
- return ZI->getOperand(0);
- return B.CreateZExtOrTrunc(V, TruncatedTy);
- };
-
- // The actual instruction modification depends on the instruction type,
- // unfortunately.
- Value *NewI = nullptr;
- if (auto *BO = dyn_cast<BinaryOperator>(I)) {
- NewI = B.CreateBinOp(BO->getOpcode(), ShrinkOperand(BO->getOperand(0)),
- ShrinkOperand(BO->getOperand(1)));
-
- // Any wrapping introduced by shrinking this operation shouldn't be
- // considered undefined behavior. So, we can't unconditionally copy
- // arithmetic wrapping flags to NewI.
- cast<BinaryOperator>(NewI)->copyIRFlags(I, /*IncludeWrapFlags=*/false);
- } else if (auto *CI = dyn_cast<ICmpInst>(I)) {
- NewI =
- B.CreateICmp(CI->getPredicate(), ShrinkOperand(CI->getOperand(0)),
- ShrinkOperand(CI->getOperand(1)));
- } else if (auto *SI = dyn_cast<SelectInst>(I)) {
- NewI = B.CreateSelect(SI->getCondition(),
- ShrinkOperand(SI->getTrueValue()),
- ShrinkOperand(SI->getFalseValue()));
- } else if (auto *CI = dyn_cast<CastInst>(I)) {
- switch (CI->getOpcode()) {
- default:
- llvm_unreachable("Unhandled cast!");
- case Instruction::Trunc:
- NewI = ShrinkOperand(CI->getOperand(0));
- break;
- case Instruction::SExt:
- NewI = B.CreateSExtOrTrunc(
- CI->getOperand(0),
- smallestIntegerVectorType(OriginalTy, TruncatedTy));
- break;
- case Instruction::ZExt:
- NewI = B.CreateZExtOrTrunc(
- CI->getOperand(0),
- smallestIntegerVectorType(OriginalTy, TruncatedTy));
- break;
- }
- } else if (auto *SI = dyn_cast<ShuffleVectorInst>(I)) {
- auto Elements0 =
- cast<VectorType>(SI->getOperand(0)->getType())->getElementCount();
- auto *O0 = B.CreateZExtOrTrunc(
- SI->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements0));
- auto Elements1 =
- cast<VectorType>(SI->getOperand(1)->getType())->getElementCount();
- auto *O1 = B.CreateZExtOrTrunc(
- SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1));
-
- NewI = B.CreateShuffleVector(O0, O1, SI->getShuffleMask());
- } else if (isa<LoadInst>(I) || isa<PHINode>(I)) {
- // Don't do anything with the operands, just extend the result.
- continue;
- } else if (auto *IE = dyn_cast<InsertElementInst>(I)) {
- auto Elements =
- cast<VectorType>(IE->getOperand(0)->getType())->getElementCount();
- auto *O0 = B.CreateZExtOrTrunc(
- IE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements));
- auto *O1 = B.CreateZExtOrTrunc(IE->getOperand(1), ScalarTruncatedTy);
- NewI = B.CreateInsertElement(O0, O1, IE->getOperand(2));
- } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) {
- auto Elements =
- cast<VectorType>(EE->getOperand(0)->getType())->getElementCount();
- auto *O0 = B.CreateZExtOrTrunc(
- EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements));
- NewI = B.CreateExtractElement(O0, EE->getOperand(2));
- } else {
- // If we don't know what to do, be conservative and don't do anything.
- continue;
- }
-
- // Lastly, extend the result.
- NewI->takeName(cast<Instruction>(I));
- Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy);
- I->replaceAllUsesWith(Res);
- cast<Instruction>(I)->eraseFromParent();
- Erased.insert(I);
- State.reset(Def, Res, Part);
- }
- }
-
- // We'll have created a bunch of ZExts that are now parentless. Clean up.
- for (const auto &KV : Cost->getMinimalBitwidths()) {
- // If the value wasn't vectorized, we must maintain the original scalar
- // type. The absence of the value from State indicates that it
- // wasn't vectorized.
- // FIXME: Should not rely on getVPValue at this point.
- VPValue *Def = State.Plan->getVPValue(KV.first, true);
- if (!State.hasAnyVectorValue(Def))
- continue;
- for (unsigned Part = 0; Part < UF; ++Part) {
- Value *I = State.get(Def, Part);
- ZExtInst *Inst = dyn_cast<ZExtInst>(I);
- if (Inst && Inst->use_empty()) {
- Value *NewI = Inst->getOperand(0);
- Inst->eraseFromParent();
- State.reset(Def, NewI, Part);
- }
- }
- }
-}
-
void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
VPlan &Plan) {
- // Insert truncates and extends for any truncated instructions as hints to
- // InstCombine.
- if (VF.isVector())
- truncateToMinimalBitwidths(State);
-
// Fix widened non-induction PHIs by setting up the PHI operands.
if (EnableVPlanNativePath)
fixNonInductionPHIs(Plan, State);
@@ -3710,6 +3439,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
// Forget the original basic block.
PSE.getSE()->forgetLoop(OrigLoop);
+ PSE.getSE()->forgetBlockAndLoopDispositions();
// After vectorization, the exit blocks of the original loop will have
// additional predecessors. Invalidate SCEVs for the exit phis in case SE
@@ -3718,7 +3448,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
OrigLoop->getExitBlocks(ExitBlocks);
for (BasicBlock *Exit : ExitBlocks)
for (PHINode &PN : Exit->phis())
- PSE.getSE()->forgetValue(&PN);
+ PSE.getSE()->forgetLcssaPhiWithNewPredecessor(OrigLoop, &PN);
VPBasicBlock *LatchVPBB = Plan.getVectorLoopRegion()->getExitingBasicBlock();
Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]);
@@ -3744,7 +3474,8 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
// Fix LCSSA phis not already fixed earlier. Extracts may need to be generated
// in the exit block, so update the builder.
- State.Builder.SetInsertPoint(State.CFG.ExitBB->getFirstNonPHI());
+ State.Builder.SetInsertPoint(State.CFG.ExitBB,
+ State.CFG.ExitBB->getFirstNonPHIIt());
for (const auto &KV : Plan.getLiveOuts())
KV.second->fixPhi(Plan, State);
@@ -3781,10 +3512,14 @@ void InnerLoopVectorizer::fixCrossIterationPHIs(VPTransformState &State) {
// the incoming edges.
VPBasicBlock *Header =
State.Plan->getVectorLoopRegion()->getEntryBasicBlock();
+
for (VPRecipeBase &R : Header->phis()) {
if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R))
fixReduction(ReductionPhi, State);
- else if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R))
+ }
+
+ for (VPRecipeBase &R : Header->phis()) {
+ if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R))
fixFixedOrderRecurrence(FOR, State);
}
}
@@ -3895,7 +3630,7 @@ void InnerLoopVectorizer::fixFixedOrderRecurrence(
}
// Fix the initial value of the original recurrence in the scalar loop.
- Builder.SetInsertPoint(&*LoopScalarPreHeader->begin());
+ Builder.SetInsertPoint(LoopScalarPreHeader, LoopScalarPreHeader->begin());
PHINode *Phi = cast<PHINode>(PhiR->getUnderlyingValue());
auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init");
auto *ScalarInit = PhiR->getStartValue()->getLiveInIRValue();
@@ -3919,90 +3654,56 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
RecurKind RK = RdxDesc.getRecurrenceKind();
TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue();
Instruction *LoopExitInst = RdxDesc.getLoopExitInstr();
- State.setDebugLocFromInst(ReductionStartValue);
+ if (auto *I = dyn_cast<Instruction>(&*ReductionStartValue))
+ State.setDebugLocFrom(I->getDebugLoc());
VPValue *LoopExitInstDef = PhiR->getBackedgeValue();
- // This is the vector-clone of the value that leaves the loop.
- Type *VecTy = State.get(LoopExitInstDef, 0)->getType();
// Before each round, move the insertion point right between
// the PHIs and the values we are going to write.
// This allows us to write both PHINodes and the extractelement
// instructions.
- Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt());
+ Builder.SetInsertPoint(LoopMiddleBlock,
+ LoopMiddleBlock->getFirstInsertionPt());
- State.setDebugLocFromInst(LoopExitInst);
+ State.setDebugLocFrom(LoopExitInst->getDebugLoc());
Type *PhiTy = OrigPhi->getType();
-
- VPBasicBlock *LatchVPBB =
- PhiR->getParent()->getEnclosingLoopRegion()->getExitingBasicBlock();
- BasicBlock *VectorLoopLatch = State.CFG.VPBB2IRBB[LatchVPBB];
// If tail is folded by masking, the vector value to leave the loop should be
// a Select choosing between the vectorized LoopExitInst and vectorized Phi,
// instead of the former. For an inloop reduction the reduction will already
// be predicated, and does not need to be handled here.
if (Cost->foldTailByMasking() && !PhiR->isInLoop()) {
- for (unsigned Part = 0; Part < UF; ++Part) {
- Value *VecLoopExitInst = State.get(LoopExitInstDef, Part);
- SelectInst *Sel = nullptr;
- for (User *U : VecLoopExitInst->users()) {
- if (isa<SelectInst>(U)) {
- assert(!Sel && "Reduction exit feeding two selects");
- Sel = cast<SelectInst>(U);
- } else
- assert(isa<PHINode>(U) && "Reduction exit must feed Phi's or select");
- }
- assert(Sel && "Reduction exit feeds no select");
- State.reset(LoopExitInstDef, Sel, Part);
-
- if (isa<FPMathOperator>(Sel))
- Sel->setFastMathFlags(RdxDesc.getFastMathFlags());
-
- // If the target can create a predicated operator for the reduction at no
- // extra cost in the loop (for example a predicated vadd), it can be
- // cheaper for the select to remain in the loop than be sunk out of it,
- // and so use the select value for the phi instead of the old
- // LoopExitValue.
- if (PreferPredicatedReductionSelect ||
- TTI->preferPredicatedReductionSelect(
- RdxDesc.getOpcode(), PhiTy,
- TargetTransformInfo::ReductionFlags())) {
- auto *VecRdxPhi =
- cast<PHINode>(State.get(PhiR, Part));
- VecRdxPhi->setIncomingValueForBlock(VectorLoopLatch, Sel);
+ VPValue *Def = nullptr;
+ for (VPUser *U : LoopExitInstDef->users()) {
+ auto *S = dyn_cast<VPInstruction>(U);
+ if (S && S->getOpcode() == Instruction::Select) {
+ Def = S;
+ break;
}
}
+ if (Def)
+ LoopExitInstDef = Def;
}
+ VectorParts RdxParts(UF);
+ for (unsigned Part = 0; Part < UF; ++Part)
+ RdxParts[Part] = State.get(LoopExitInstDef, Part);
+
// If the vector reduction can be performed in a smaller type, we truncate
// then extend the loop exit value to enable InstCombine to evaluate the
// entire expression in the smaller type.
if (VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
- assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
+ Builder.SetInsertPoint(LoopMiddleBlock,
+ LoopMiddleBlock->getFirstInsertionPt());
Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
- Builder.SetInsertPoint(VectorLoopLatch->getTerminator());
- VectorParts RdxParts(UF);
- for (unsigned Part = 0; Part < UF; ++Part) {
- RdxParts[Part] = State.get(LoopExitInstDef, Part);
- Value *Trunc = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
- Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy)
- : Builder.CreateZExt(Trunc, VecTy);
- for (User *U : llvm::make_early_inc_range(RdxParts[Part]->users()))
- if (U != Trunc) {
- U->replaceUsesOfWith(RdxParts[Part], Extnd);
- RdxParts[Part] = Extnd;
- }
- }
- Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt());
for (unsigned Part = 0; Part < UF; ++Part) {
RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
- State.reset(LoopExitInstDef, RdxParts[Part], Part);
}
}
// Reduce all of the unrolled parts into a single vector.
- Value *ReducedPartRdx = State.get(LoopExitInstDef, 0);
+ Value *ReducedPartRdx = RdxParts[0];
unsigned Op = RecurrenceDescriptor::getOpcode(RK);
// The middle block terminator has already been assigned a DebugLoc here (the
@@ -4012,21 +3713,21 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
// conditional branch, and (c) other passes may add new predecessors which
// terminate on this line. This is the easiest way to ensure we don't
// accidentally cause an extra step back into the loop while debugging.
- State.setDebugLocFromInst(LoopMiddleBlock->getTerminator());
+ State.setDebugLocFrom(LoopMiddleBlock->getTerminator()->getDebugLoc());
if (PhiR->isOrdered())
- ReducedPartRdx = State.get(LoopExitInstDef, UF - 1);
+ ReducedPartRdx = RdxParts[UF - 1];
else {
// Floating-point operations should have some FMF to enable the reduction.
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
for (unsigned Part = 1; Part < UF; ++Part) {
- Value *RdxPart = State.get(LoopExitInstDef, Part);
- if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
+ Value *RdxPart = RdxParts[Part];
+ if (Op != Instruction::ICmp && Op != Instruction::FCmp)
ReducedPartRdx = Builder.CreateBinOp(
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
- } else if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK))
- ReducedPartRdx = createSelectCmpOp(Builder, ReductionStartValue, RK,
- ReducedPartRdx, RdxPart);
+ else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+ ReducedPartRdx = createAnyOfOp(Builder, ReductionStartValue, RK,
+ ReducedPartRdx, RdxPart);
else
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
}
@@ -4036,7 +3737,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
// target reduction in the loop using a Reduction recipe.
if (VF.isVector() && !PhiR->isInLoop()) {
ReducedPartRdx =
- createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, OrigPhi);
+ createTargetReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
// If the reduction can be performed in a smaller type, we need to extend
// the reduction to the wider type before we branch to the original loop.
if (PhiTy != RdxDesc.getRecurrenceType())
@@ -4073,7 +3774,8 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
// inside the loop, create the final store here.
if (StoreInst *SI = RdxDesc.IntermediateStore) {
StoreInst *NewSI =
- Builder.CreateStore(ReducedPartRdx, SI->getPointerOperand());
+ Builder.CreateAlignedStore(ReducedPartRdx, SI->getPointerOperand(),
+ SI->getAlign());
propagateMetadata(NewSI, SI);
// If the reduction value is used in other places,
@@ -4402,7 +4104,10 @@ bool LoopVectorizationCostModel::isScalarWithPredication(
default:
return true;
case Instruction::Call:
- return !VFDatabase::hasMaskedVariant(*(cast<CallInst>(I)), VF);
+ if (VF.isScalar())
+ return true;
+ return CallWideningDecisions.at(std::make_pair(cast<CallInst>(I), VF))
+ .Kind == CM_Scalarize;
case Instruction::Load:
case Instruction::Store: {
auto *Ptr = getLoadStorePointerOperand(I);
@@ -4954,7 +4659,7 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
}
FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
- unsigned ConstTripCount, ElementCount UserVF, bool FoldTailByMasking) {
+ unsigned MaxTripCount, ElementCount UserVF, bool FoldTailByMasking) {
MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);
unsigned SmallestType, WidestType;
std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes();
@@ -5042,12 +4747,12 @@ FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
FixedScalableVFPair Result(ElementCount::getFixed(1),
ElementCount::getScalable(0));
if (auto MaxVF =
- getMaximizedVFForTarget(ConstTripCount, SmallestType, WidestType,
+ getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType,
MaxSafeFixedVF, FoldTailByMasking))
Result.FixedVF = MaxVF;
if (auto MaxVF =
- getMaximizedVFForTarget(ConstTripCount, SmallestType, WidestType,
+ getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType,
MaxSafeScalableVF, FoldTailByMasking))
if (MaxVF.isScalable()) {
Result.ScalableVF = MaxVF;
@@ -5071,6 +4776,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
}
unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
+ unsigned MaxTC = PSE.getSE()->getSmallConstantMaxTripCount(TheLoop);
LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
if (TC == 1) {
reportVectorizationFailure("Single iteration (non) loop",
@@ -5081,7 +4787,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
switch (ScalarEpilogueStatus) {
case CM_ScalarEpilogueAllowed:
- return computeFeasibleMaxVF(TC, UserVF, false);
+ return computeFeasibleMaxVF(MaxTC, UserVF, false);
case CM_ScalarEpilogueNotAllowedUsePredicate:
[[fallthrough]];
case CM_ScalarEpilogueNotNeededUsePredicate:
@@ -5119,7 +4825,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking: vectorize with a "
"scalar epilogue instead.\n");
ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
- return computeFeasibleMaxVF(TC, UserVF, false);
+ return computeFeasibleMaxVF(MaxTC, UserVF, false);
}
return FixedScalableVFPair::getNone();
}
@@ -5136,7 +4842,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
InterleaveInfo.invalidateGroupsRequiringScalarEpilogue();
}
- FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(TC, UserVF, true);
+ FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(MaxTC, UserVF, true);
// Avoid tail folding if the trip count is known to be a multiple of any VF
// we choose.
@@ -5212,7 +4918,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
}
ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
- unsigned ConstTripCount, unsigned SmallestType, unsigned WidestType,
+ unsigned MaxTripCount, unsigned SmallestType, unsigned WidestType,
ElementCount MaxSafeVF, bool FoldTailByMasking) {
bool ComputeScalableMaxVF = MaxSafeVF.isScalable();
const TypeSize WidestRegister = TTI.getRegisterBitWidth(
@@ -5251,31 +4957,35 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
}
// When a scalar epilogue is required, at least one iteration of the scalar
- // loop has to execute. Adjust ConstTripCount accordingly to avoid picking a
+ // loop has to execute. Adjust MaxTripCount accordingly to avoid picking a
// max VF that results in a dead vector loop.
- if (ConstTripCount > 0 && requiresScalarEpilogue(true))
- ConstTripCount -= 1;
+ if (MaxTripCount > 0 && requiresScalarEpilogue(true))
+ MaxTripCount -= 1;
- if (ConstTripCount && ConstTripCount <= WidestRegisterMinEC &&
- (!FoldTailByMasking || isPowerOf2_32(ConstTripCount))) {
- // If loop trip count (TC) is known at compile time there is no point in
- // choosing VF greater than TC (as done in the loop below). Select maximum
- // power of two which doesn't exceed TC.
- // If MaxVectorElementCount is scalable, we only fall back on a fixed VF
- // when the TC is less than or equal to the known number of lanes.
- auto ClampedConstTripCount = llvm::bit_floor(ConstTripCount);
+ if (MaxTripCount && MaxTripCount <= WidestRegisterMinEC &&
+ (!FoldTailByMasking || isPowerOf2_32(MaxTripCount))) {
+ // If upper bound loop trip count (TC) is known at compile time there is no
+ // point in choosing VF greater than TC (as done in the loop below). Select
+ // maximum power of two which doesn't exceed TC. If MaxVectorElementCount is
+ // scalable, we only fall back on a fixed VF when the TC is less than or
+ // equal to the known number of lanes.
+ auto ClampedUpperTripCount = llvm::bit_floor(MaxTripCount);
LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not "
"exceeding the constant trip count: "
- << ClampedConstTripCount << "\n");
- return ElementCount::getFixed(ClampedConstTripCount);
+ << ClampedUpperTripCount << "\n");
+ return ElementCount::get(
+ ClampedUpperTripCount,
+ FoldTailByMasking ? MaxVectorElementCount.isScalable() : false);
}
TargetTransformInfo::RegisterKind RegKind =
ComputeScalableMaxVF ? TargetTransformInfo::RGK_ScalableVector
: TargetTransformInfo::RGK_FixedWidthVector;
ElementCount MaxVF = MaxVectorElementCount;
- if (MaximizeBandwidth || (MaximizeBandwidth.getNumOccurrences() == 0 &&
- TTI.shouldMaximizeVectorBandwidth(RegKind))) {
+ if (MaximizeBandwidth ||
+ (MaximizeBandwidth.getNumOccurrences() == 0 &&
+ (TTI.shouldMaximizeVectorBandwidth(RegKind) ||
+ (UseWiderVFIfCallVariantsPresent && Legal->hasVectorCallVariants())))) {
auto MaxVectorElementCountMaxBW = ElementCount::get(
llvm::bit_floor(WidestRegister.getKnownMinValue() / SmallestType),
ComputeScalableMaxVF);
@@ -5947,7 +5657,7 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
HasReductions &&
any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool {
const RecurrenceDescriptor &RdxDesc = Reduction.second;
- return RecurrenceDescriptor::isSelectCmpRecurrenceKind(
+ return RecurrenceDescriptor::isAnyOfRecurrenceKind(
RdxDesc.getRecurrenceKind());
});
if (HasSelectCmpReductions) {
@@ -6115,6 +5825,8 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
if (ValuesToIgnore.count(I))
continue;
+ collectInLoopReductions();
+
// For each VF find the maximum usage of registers.
for (unsigned j = 0, e = VFs.size(); j < e; ++j) {
// Count the number of registers used, per register class, given all open
@@ -6634,10 +6346,11 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
std::optional<InstructionCost>
LoopVectorizationCostModel::getReductionPatternCost(
- Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) {
+ Instruction *I, ElementCount VF, Type *Ty,
+ TTI::TargetCostKind CostKind) const {
using namespace llvm::PatternMatch;
// Early exit for no inloop reductions
- if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty))
+ if (InLoopReductions.empty() || VF.isScalar() || !isa<VectorType>(Ty))
return std::nullopt;
auto *VectorTy = cast<VectorType>(Ty);
@@ -6672,10 +6385,10 @@ LoopVectorizationCostModel::getReductionPatternCost(
// Find the reduction this chain is a part of and calculate the basic cost of
// the reduction on its own.
- Instruction *LastChain = InLoopReductionImmediateChains[RetI];
+ Instruction *LastChain = InLoopReductionImmediateChains.at(RetI);
Instruction *ReductionPhi = LastChain;
while (!isa<PHINode>(ReductionPhi))
- ReductionPhi = InLoopReductionImmediateChains[ReductionPhi];
+ ReductionPhi = InLoopReductionImmediateChains.at(ReductionPhi);
const RecurrenceDescriptor &RdxDesc =
Legal->getReductionVars().find(cast<PHINode>(ReductionPhi))->second;
@@ -7093,6 +6806,168 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) {
}
}
+void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
+ assert(!VF.isScalar() &&
+ "Trying to set a vectorization decision for a scalar VF");
+
+ for (BasicBlock *BB : TheLoop->blocks()) {
+ // For each instruction in the old loop.
+ for (Instruction &I : *BB) {
+ CallInst *CI = dyn_cast<CallInst>(&I);
+
+ if (!CI)
+ continue;
+
+ InstructionCost ScalarCost = InstructionCost::getInvalid();
+ InstructionCost VectorCost = InstructionCost::getInvalid();
+ InstructionCost IntrinsicCost = InstructionCost::getInvalid();
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+ Function *ScalarFunc = CI->getCalledFunction();
+ Type *ScalarRetTy = CI->getType();
+ SmallVector<Type *, 4> Tys, ScalarTys;
+ bool MaskRequired = Legal->isMaskRequired(CI);
+ for (auto &ArgOp : CI->args())
+ ScalarTys.push_back(ArgOp->getType());
+
+ // Compute corresponding vector type for return value and arguments.
+ Type *RetTy = ToVectorTy(ScalarRetTy, VF);
+ for (Type *ScalarTy : ScalarTys)
+ Tys.push_back(ToVectorTy(ScalarTy, VF));
+
+ // An in-loop reduction using an fmuladd intrinsic is a special case;
+ // we don't want the normal cost for that intrinsic.
+ if (RecurrenceDescriptor::isFMulAddIntrinsic(CI))
+ if (auto RedCost = getReductionPatternCost(CI, VF, RetTy, CostKind)) {
+ setCallWideningDecision(CI, VF, CM_IntrinsicCall, nullptr,
+ getVectorIntrinsicIDForCall(CI, TLI),
+ std::nullopt, *RedCost);
+ continue;
+ }
+
+ // Estimate cost of scalarized vector call. The source operands are
+ // assumed to be vectors, so we need to extract individual elements from
+ // there, execute VF scalar calls, and then gather the result into the
+ // vector return value.
+ InstructionCost ScalarCallCost =
+ TTI.getCallInstrCost(ScalarFunc, ScalarRetTy, ScalarTys, CostKind);
+
+ // Compute costs of unpacking argument values for the scalar calls and
+ // packing the return values to a vector.
+ InstructionCost ScalarizationCost =
+ getScalarizationOverhead(CI, VF, CostKind);
+
+ ScalarCost = ScalarCallCost * VF.getKnownMinValue() + ScalarizationCost;
+
+ // Find the cost of vectorizing the call, if we can find a suitable
+ // vector variant of the function.
+ bool UsesMask = false;
+ VFInfo FuncInfo;
+ Function *VecFunc = nullptr;
+ // Search through any available variants for one we can use at this VF.
+ for (VFInfo &Info : VFDatabase::getMappings(*CI)) {
+ // Must match requested VF.
+ if (Info.Shape.VF != VF)
+ continue;
+
+ // Must take a mask argument if one is required
+ if (MaskRequired && !Info.isMasked())
+ continue;
+
+ // Check that all parameter kinds are supported
+ bool ParamsOk = true;
+ for (VFParameter Param : Info.Shape.Parameters) {
+ switch (Param.ParamKind) {
+ case VFParamKind::Vector:
+ break;
+ case VFParamKind::OMP_Uniform: {
+ Value *ScalarParam = CI->getArgOperand(Param.ParamPos);
+ // Make sure the scalar parameter in the loop is invariant.
+ if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(ScalarParam),
+ TheLoop))
+ ParamsOk = false;
+ break;
+ }
+ case VFParamKind::OMP_Linear: {
+ Value *ScalarParam = CI->getArgOperand(Param.ParamPos);
+ // Find the stride for the scalar parameter in this loop and see if
+ // it matches the stride for the variant.
+ // TODO: do we need to figure out the cost of an extract to get the
+ // first lane? Or do we hope that it will be folded away?
+ ScalarEvolution *SE = PSE.getSE();
+ const auto *SAR =
+ dyn_cast<SCEVAddRecExpr>(SE->getSCEV(ScalarParam));
+
+ if (!SAR || SAR->getLoop() != TheLoop) {
+ ParamsOk = false;
+ break;
+ }
+
+ const SCEVConstant *Step =
+ dyn_cast<SCEVConstant>(SAR->getStepRecurrence(*SE));
+
+ if (!Step ||
+ Step->getAPInt().getSExtValue() != Param.LinearStepOrPos)
+ ParamsOk = false;
+
+ break;
+ }
+ case VFParamKind::GlobalPredicate:
+ UsesMask = true;
+ break;
+ default:
+ ParamsOk = false;
+ break;
+ }
+ }
+
+ if (!ParamsOk)
+ continue;
+
+ // Found a suitable candidate, stop here.
+ VecFunc = CI->getModule()->getFunction(Info.VectorName);
+ FuncInfo = Info;
+ break;
+ }
+
+ // Add in the cost of synthesizing a mask if one wasn't required.
+ InstructionCost MaskCost = 0;
+ if (VecFunc && UsesMask && !MaskRequired)
+ MaskCost = TTI.getShuffleCost(
+ TargetTransformInfo::SK_Broadcast,
+ VectorType::get(IntegerType::getInt1Ty(
+ VecFunc->getFunctionType()->getContext()),
+ VF));
+
+ if (TLI && VecFunc && !CI->isNoBuiltin())
+ VectorCost =
+ TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost;
+
+ // Find the cost of an intrinsic; some targets may have instructions that
+ // perform the operation without needing an actual call.
+ Intrinsic::ID IID = getVectorIntrinsicIDForCall(CI, TLI);
+ if (IID != Intrinsic::not_intrinsic)
+ IntrinsicCost = getVectorIntrinsicCost(CI, VF);
+
+ InstructionCost Cost = ScalarCost;
+ InstWidening Decision = CM_Scalarize;
+
+ if (VectorCost <= Cost) {
+ Cost = VectorCost;
+ Decision = CM_VectorCall;
+ }
+
+ if (IntrinsicCost <= Cost) {
+ Cost = IntrinsicCost;
+ Decision = CM_IntrinsicCall;
+ }
+
+ setCallWideningDecision(CI, VF, Decision, VecFunc, IID,
+ FuncInfo.getParamIndexForOptionalMask(), Cost);
+ }
+ }
+}
+
InstructionCost
LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
Type *&VectorTy) {
@@ -7122,7 +6997,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
// With the exception of GEPs and PHIs, after scalarization there should
// only be one copy of the instruction generated in the loop. This is
// because the VF is either 1, or any instructions that need scalarizing
- // have already been dealt with by the the time we get here. As a result,
+ // have already been dealt with by the time we get here. As a result,
// it means we don't have to multiply the instruction cost by VF.
assert(I->getOpcode() == Instruction::GetElementPtr ||
I->getOpcode() == Instruction::PHI ||
@@ -7350,6 +7225,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
return TTI::CastContextHint::Reversed;
case LoopVectorizationCostModel::CM_Unknown:
llvm_unreachable("Instr did not go through cost modelling?");
+ case LoopVectorizationCostModel::CM_VectorCall:
+ case LoopVectorizationCostModel::CM_IntrinsicCall:
+ llvm_unreachable_internal("Instr has invalid widening decision");
}
llvm_unreachable("Unhandled case!");
@@ -7407,19 +7285,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
return TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I);
}
- case Instruction::Call: {
- if (RecurrenceDescriptor::isFMulAddIntrinsic(I))
- if (auto RedCost = getReductionPatternCost(I, VF, VectorTy, CostKind))
- return *RedCost;
- Function *Variant;
- CallInst *CI = cast<CallInst>(I);
- InstructionCost CallCost = getVectorCallCost(CI, VF, &Variant);
- if (getVectorIntrinsicIDForCall(CI, TLI)) {
- InstructionCost IntrinsicCost = getVectorIntrinsicCost(CI, VF);
- return std::min(CallCost, IntrinsicCost);
- }
- return CallCost;
- }
+ case Instruction::Call:
+ return getVectorCallCost(cast<CallInst>(I), VF);
case Instruction::ExtractValue:
return TTI.getInstructionCost(I, TTI::TCK_RecipThroughput);
case Instruction::Alloca:
@@ -7487,8 +7354,9 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
SmallVector<Instruction *, 4> ReductionOperations =
RdxDesc.getReductionOpChain(Phi, TheLoop);
bool InLoop = !ReductionOperations.empty();
+
if (InLoop) {
- InLoopReductionChains[Phi] = ReductionOperations;
+ InLoopReductions.insert(Phi);
// Add the elements to InLoopReductionImmediateChains for cost modelling.
Instruction *LastChain = Phi;
for (auto *I : ReductionOperations) {
@@ -7501,21 +7369,38 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
}
}
+VPValue *VPBuilder::createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
+ DebugLoc DL, const Twine &Name) {
+ assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
+ Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
+ return tryInsertInstruction(
+ new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
+}
+
+// This function will select a scalable VF if the target supports scalable
+// vectors and a fixed one otherwise.
// TODO: we could return a pair of values that specify the max VF and
// min VF, to be used in `buildVPlans(MinVF, MaxVF)` instead of
// `buildVPlans(VF, VF)`. We cannot do it because VPLAN at the moment
// doesn't have a cost model that can choose which plan to execute if
// more than one is generated.
-static unsigned determineVPlanVF(const unsigned WidestVectorRegBits,
- LoopVectorizationCostModel &CM) {
+static ElementCount determineVPlanVF(const TargetTransformInfo &TTI,
+ LoopVectorizationCostModel &CM) {
unsigned WidestType;
std::tie(std::ignore, WidestType) = CM.getSmallestAndWidestTypes();
- return WidestVectorRegBits / WidestType;
+
+ TargetTransformInfo::RegisterKind RegKind =
+ TTI.enableScalableVectorization()
+ ? TargetTransformInfo::RGK_ScalableVector
+ : TargetTransformInfo::RGK_FixedWidthVector;
+
+ TypeSize RegSize = TTI.getRegisterBitWidth(RegKind);
+ unsigned N = RegSize.getKnownMinValue() / WidestType;
+ return ElementCount::get(N, RegSize.isScalable());
}
VectorizationFactor
LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
- assert(!UserVF.isScalable() && "scalable vectors not yet supported");
ElementCount VF = UserVF;
// Outer loop handling: They may require CFG and instruction level
// transformations before even evaluating whether vectorization is profitable.
@@ -7525,10 +7410,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
// If the user doesn't provide a vectorization factor, determine a
// reasonable one.
if (UserVF.isZero()) {
- VF = ElementCount::getFixed(determineVPlanVF(
- TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
- .getFixedValue(),
- CM));
+ VF = determineVPlanVF(TTI, CM);
LLVM_DEBUG(dbgs() << "LV: VPlan computed VF " << VF << ".\n");
// Make sure we have a VF > 1 for stress testing.
@@ -7537,6 +7419,17 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
<< "overriding computed VF.\n");
VF = ElementCount::getFixed(4);
}
+ } else if (UserVF.isScalable() && !TTI.supportsScalableVectors() &&
+ !ForceTargetSupportsScalableVectors) {
+ LLVM_DEBUG(dbgs() << "LV: Not vectorizing. Scalable VF requested, but "
+ << "not supported by the target.\n");
+ reportVectorizationFailure(
+ "Scalable vectorization requested but not supported by the target",
+ "the scalable user-specified vectorization width for outer-loop "
+ "vectorization cannot be used because the target does not support "
+ "scalable vectors.",
+ "ScalableVFUnfeasible", ORE, OrigLoop);
+ return VectorizationFactor::Disabled();
}
assert(EnableVPlanNativePath && "VPlan-native path is not enabled.");
assert(isPowerOf2_32(VF.getKnownMinValue()) &&
@@ -7590,9 +7483,9 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
"VF needs to be a power of two");
// Collect the instructions (and their associated costs) that will be more
// profitable to scalarize.
+ CM.collectInLoopReductions();
if (CM.selectUserVectorizationFactor(UserVF)) {
LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n");
- CM.collectInLoopReductions();
buildVPlansWithVPRecipes(UserVF, UserVF);
if (!hasPlanWithVF(UserVF)) {
LLVM_DEBUG(dbgs() << "LV: No VPlan could be built for " << UserVF
@@ -7616,6 +7509,7 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
ElementCount::isKnownLE(VF, MaxFactors.ScalableVF); VF *= 2)
VFCandidates.insert(VF);
+ CM.collectInLoopReductions();
for (const auto &VF : VFCandidates) {
// Collect Uniform and Scalar instructions after vectorization with VF.
CM.collectUniformsAndScalars(VF);
@@ -7626,7 +7520,6 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
CM.collectInstsToScalarize(VF);
}
- CM.collectInLoopReductions();
buildVPlansWithVPRecipes(ElementCount::getFixed(1), MaxFactors.FixedVF);
buildVPlansWithVPRecipes(ElementCount::getScalable(1), MaxFactors.ScalableVF);
@@ -7671,7 +7564,7 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) {
if (MD) {
const auto *S = dyn_cast<MDString>(MD->getOperand(0));
IsUnrollMetadata =
- S && S->getString().startswith("llvm.loop.unroll.disable");
+ S && S->getString().starts_with("llvm.loop.unroll.disable");
}
MDs.push_back(LoopID->getOperand(i));
}
@@ -7695,7 +7588,7 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) {
SCEV2ValueTy LoopVectorizationPlanner::executePlan(
ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
InnerLoopVectorizer &ILV, DominatorTree *DT, bool IsEpilogueVectorization,
- DenseMap<const SCEV *, Value *> *ExpandedSCEVs) {
+ const DenseMap<const SCEV *, Value *> *ExpandedSCEVs) {
assert(BestVPlan.hasVF(BestVF) &&
"Trying to execute plan with unsupported VF");
assert(BestVPlan.hasUF(BestUF) &&
@@ -7711,7 +7604,8 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan(
VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
// Perform the actual loop transformation.
- VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan};
+ VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+ OrigLoop->getHeader()->getContext());
// 0. Generate SCEV-dependent code into the preheader, including TripCount,
// before making any changes to the CFG.
@@ -7764,9 +7658,9 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan(
//===------------------------------------------------===//
// 2. Copy and widen instructions from the old loop into the new loop.
- BestVPlan.prepareToExecute(
- ILV.getTripCount(), ILV.getOrCreateVectorTripCount(nullptr),
- CanonicalIVStartValue, State, IsEpilogueVectorization);
+ BestVPlan.prepareToExecute(ILV.getTripCount(),
+ ILV.getOrCreateVectorTripCount(nullptr),
+ CanonicalIVStartValue, State);
BestVPlan.execute(&State);
@@ -7930,9 +7824,11 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
EPI.TripCount = Count;
}
- ReplaceInstWithInst(
- TCCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+ if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+ setBranchWeights(BI, MinItersBypassWeights);
+ ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
return TCCheckBlock;
}
@@ -8030,8 +7926,8 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton(
// Generate a resume induction for the vector epilogue and put it in the
// vector epilogue preheader
Type *IdxTy = Legal->getWidestInductionType();
- PHINode *EPResumeVal = PHINode::Create(IdxTy, 2, "vec.epilog.resume.val",
- LoopVectorPreHeader->getFirstNonPHI());
+ PHINode *EPResumeVal = PHINode::Create(IdxTy, 2, "vec.epilog.resume.val");
+ EPResumeVal->insertBefore(LoopVectorPreHeader->getFirstNonPHIIt());
EPResumeVal->addIncoming(EPI.VectorTripCount, VecEpilogueIterationCountCheck);
EPResumeVal->addIncoming(ConstantInt::get(IdxTy, 0),
EPI.MainLoopIterationCountCheck);
@@ -8076,9 +7972,22 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
EPI.EpilogueVF, EPI.EpilogueUF),
"min.epilog.iters.check");
- ReplaceInstWithInst(
- Insert->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+ if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
+ unsigned MainLoopStep = UF * VF.getKnownMinValue();
+ unsigned EpilogueLoopStep =
+ EPI.EpilogueUF * EPI.EpilogueVF.getKnownMinValue();
+ // We assume the remaining `Count` is equally distributed in
+ // [0, MainLoopStep)
+ // So the probability for `Count < EpilogueLoopStep` should be
+ // min(MainLoopStep, EpilogueLoopStep) / MainLoopStep
+ unsigned EstimatedSkipCount = std::min(MainLoopStep, EpilogueLoopStep);
+ const uint32_t Weights[] = {EstimatedSkipCount,
+ MainLoopStep - EstimatedSkipCount};
+ setBranchWeights(BI, Weights);
+ }
+ ReplaceInstWithInst(Insert->getTerminator(), &BI);
LoopBypassBlocks.push_back(Insert);
return Insert;
@@ -8172,6 +8081,33 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst,
return EdgeMaskCache[Edge] = EdgeMask;
}
+void VPRecipeBuilder::createHeaderMask(VPlan &Plan) {
+ BasicBlock *Header = OrigLoop->getHeader();
+
+ // When not folding the tail, use nullptr to model all-true mask.
+ if (!CM.foldTailByMasking()) {
+ BlockMaskCache[Header] = nullptr;
+ return;
+ }
+
+ // Introduce the early-exit compare IV <= BTC to form header block mask.
+ // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by
+ // constructing the desired canonical IV in the header block as its first
+ // non-phi instructions.
+
+ VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
+ auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi();
+ auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
+ HeaderVPBB->insert(IV, NewInsertionPoint);
+
+ VPBuilder::InsertPointGuard Guard(Builder);
+ Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint);
+ VPValue *BlockMask = nullptr;
+ VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
+ BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC);
+ BlockMaskCache[Header] = BlockMask;
+}
+
VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) {
assert(OrigLoop->contains(BB) && "Block is not a part of a loop");
@@ -8180,45 +8116,12 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) {
if (BCEntryIt != BlockMaskCache.end())
return BCEntryIt->second;
+ assert(OrigLoop->getHeader() != BB &&
+ "Loop header must have cached block mask");
+
// All-one mask is modelled as no-mask following the convention for masked
// load/store/gather/scatter. Initialize BlockMask to no-mask.
VPValue *BlockMask = nullptr;
-
- if (OrigLoop->getHeader() == BB) {
- if (!CM.blockNeedsPredicationForAnyReason(BB))
- return BlockMaskCache[BB] = BlockMask; // Loop incoming mask is all-one.
-
- assert(CM.foldTailByMasking() && "must fold the tail");
-
- // If we're using the active lane mask for control flow, then we get the
- // mask from the active lane mask PHI that is cached in the VPlan.
- TailFoldingStyle TFStyle = CM.getTailFoldingStyle();
- if (useActiveLaneMaskForControlFlow(TFStyle))
- return BlockMaskCache[BB] = Plan.getActiveLaneMaskPhi();
-
- // Introduce the early-exit compare IV <= BTC to form header block mask.
- // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by
- // constructing the desired canonical IV in the header block as its first
- // non-phi instructions.
-
- VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
- auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi();
- auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
- HeaderVPBB->insert(IV, HeaderVPBB->getFirstNonPhi());
-
- VPBuilder::InsertPointGuard Guard(Builder);
- Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint);
- if (useActiveLaneMask(TFStyle)) {
- VPValue *TC = Plan.getTripCount();
- BlockMask = Builder.createNaryOp(VPInstruction::ActiveLaneMask, {IV, TC},
- nullptr, "active.lane.mask");
- } else {
- VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
- BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC});
- }
- return BlockMaskCache[BB] = BlockMask;
- }
-
// This is the block mask. We OR all incoming edges.
for (auto *Predecessor : predecessors(BB)) {
VPValue *EdgeMask = createEdgeMask(Predecessor, BB, Plan);
@@ -8424,22 +8327,15 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
bool ShouldUseVectorIntrinsic =
ID && LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) -> bool {
- Function *Variant;
- // Is it beneficial to perform intrinsic call compared to lib
- // call?
- InstructionCost CallCost =
- CM.getVectorCallCost(CI, VF, &Variant);
- InstructionCost IntrinsicCost =
- CM.getVectorIntrinsicCost(CI, VF);
- return IntrinsicCost <= CallCost;
+ return CM.getCallWideningDecision(CI, VF).Kind ==
+ LoopVectorizationCostModel::CM_IntrinsicCall;
},
Range);
if (ShouldUseVectorIntrinsic)
return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), ID);
Function *Variant = nullptr;
- ElementCount VariantVF;
- bool NeedsMask = false;
+ std::optional<unsigned> MaskPos;
// Is better to call a vectorized version of the function than to to scalarize
// the call?
auto ShouldUseVectorCall = LoopVectorizationPlanner::getDecisionAndClampRange(
@@ -8458,16 +8354,19 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
// finds a valid variant.
if (Variant)
return false;
- CM.getVectorCallCost(CI, VF, &Variant, &NeedsMask);
- // If we found a valid vector variant at this VF, then store the VF
- // in case we need to generate a mask.
- if (Variant)
- VariantVF = VF;
- return Variant != nullptr;
+ LoopVectorizationCostModel::CallWideningDecision Decision =
+ CM.getCallWideningDecision(CI, VF);
+ if (Decision.Kind == LoopVectorizationCostModel::CM_VectorCall) {
+ Variant = Decision.Variant;
+ MaskPos = Decision.MaskPos;
+ return true;
+ }
+
+ return false;
},
Range);
if (ShouldUseVectorCall) {
- if (NeedsMask) {
+ if (MaskPos.has_value()) {
// We have 2 cases that would require a mask:
// 1) The block needs to be predicated, either due to a conditional
// in the scalar loop or use of an active lane mask with
@@ -8482,17 +8381,7 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
Mask = Plan->getVPValueOrAddLiveIn(ConstantInt::getTrue(
IntegerType::getInt1Ty(Variant->getFunctionType()->getContext())));
- VFShape Shape = VFShape::get(*CI, VariantVF, /*HasGlobalPred=*/true);
- unsigned MaskPos = 0;
-
- for (const VFInfo &Info : VFDatabase::getMappings(*CI))
- if (Info.Shape == Shape) {
- assert(Info.isMasked() && "Vector function info shape mismatch");
- MaskPos = Info.getParamIndexForOptionalMask().value();
- break;
- }
-
- Ops.insert(Ops.begin() + MaskPos, Mask);
+ Ops.insert(Ops.begin() + *MaskPos, Mask);
}
return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()),
@@ -8713,8 +8602,8 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
}
if (auto *CI = dyn_cast<CastInst>(Instr)) {
- return toVPRecipeResult(
- new VPWidenCastRecipe(CI->getOpcode(), Operands[0], CI->getType(), CI));
+ return toVPRecipeResult(new VPWidenCastRecipe(CI->getOpcode(), Operands[0],
+ CI->getType(), *CI));
}
return toVPRecipeResult(tryToWiden(Instr, Operands, VPBB, Plan));
@@ -8724,27 +8613,26 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
ElementCount MaxVF) {
assert(OrigLoop->isInnermost() && "Inner loop expected.");
- // Add assume instructions we need to drop to DeadInstructions, to prevent
- // them from being added to the VPlan.
- // TODO: We only need to drop assumes in blocks that get flattend. If the
- // control flow is preserved, we should keep them.
- SmallPtrSet<Instruction *, 4> DeadInstructions;
- auto &ConditionalAssumes = Legal->getConditionalAssumes();
- DeadInstructions.insert(ConditionalAssumes.begin(), ConditionalAssumes.end());
-
auto MaxVFTimes2 = MaxVF * 2;
for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) {
VFRange SubRange = {VF, MaxVFTimes2};
- if (auto Plan = tryToBuildVPlanWithVPRecipes(SubRange, DeadInstructions))
- VPlans.push_back(std::move(*Plan));
+ if (auto Plan = tryToBuildVPlanWithVPRecipes(SubRange)) {
+ // Now optimize the initial VPlan.
+ if (!Plan->hasVF(ElementCount::getFixed(1)))
+ VPlanTransforms::truncateToMinimalBitwidths(
+ *Plan, CM.getMinimalBitwidths(), PSE.getSE()->getContext());
+ VPlanTransforms::optimize(*Plan, *PSE.getSE());
+ assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid");
+ VPlans.push_back(std::move(Plan));
+ }
VF = SubRange.End;
}
}
// Add the necessary canonical IV and branch recipes required to control the
// loop.
-static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
- TailFoldingStyle Style) {
+static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW,
+ DebugLoc DL) {
Value *StartIdx = ConstantInt::get(IdxTy, 0);
auto *StartV = Plan.getVPValueOrAddLiveIn(StartIdx);
@@ -8756,102 +8644,24 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL,
// Add a CanonicalIVIncrement{NUW} VPInstruction to increment the scalar
// IV by VF * UF.
- bool HasNUW = Style == TailFoldingStyle::None;
auto *CanonicalIVIncrement =
- new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementNUW
- : VPInstruction::CanonicalIVIncrement,
- {CanonicalIVPHI}, DL, "index.next");
+ new VPInstruction(Instruction::Add, {CanonicalIVPHI, &Plan.getVFxUF()},
+ {HasNUW, false}, DL, "index.next");
CanonicalIVPHI->addOperand(CanonicalIVIncrement);
VPBasicBlock *EB = TopRegion->getExitingBasicBlock();
- if (useActiveLaneMaskForControlFlow(Style)) {
- // Create the active lane mask instruction in the vplan preheader.
- VPBasicBlock *VecPreheader =
- cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSinglePredecessor());
-
- // We can't use StartV directly in the ActiveLaneMask VPInstruction, since
- // we have to take unrolling into account. Each part needs to start at
- // Part * VF
- auto *CanonicalIVIncrementParts =
- new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementForPartNUW
- : VPInstruction::CanonicalIVIncrementForPart,
- {StartV}, DL, "index.part.next");
- VecPreheader->appendRecipe(CanonicalIVIncrementParts);
-
- // Create the ActiveLaneMask instruction using the correct start values.
- VPValue *TC = Plan.getTripCount();
-
- VPValue *TripCount, *IncrementValue;
- if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) {
- // When avoiding a runtime check, the active.lane.mask inside the loop
- // uses a modified trip count and the induction variable increment is
- // done after the active.lane.mask intrinsic is called.
- auto *TCMinusVF =
- new VPInstruction(VPInstruction::CalculateTripCountMinusVF, {TC}, DL);
- VecPreheader->appendRecipe(TCMinusVF);
- IncrementValue = CanonicalIVPHI;
- TripCount = TCMinusVF;
- } else {
- // When the loop is guarded by a runtime overflow check for the loop
- // induction variable increment by VF, we can increment the value before
- // the get.active.lane mask and use the unmodified tripcount.
- EB->appendRecipe(CanonicalIVIncrement);
- IncrementValue = CanonicalIVIncrement;
- TripCount = TC;
- }
-
- auto *EntryALM = new VPInstruction(VPInstruction::ActiveLaneMask,
- {CanonicalIVIncrementParts, TC}, DL,
- "active.lane.mask.entry");
- VecPreheader->appendRecipe(EntryALM);
-
- // Now create the ActiveLaneMaskPhi recipe in the main loop using the
- // preheader ActiveLaneMask instruction.
- auto *LaneMaskPhi = new VPActiveLaneMaskPHIRecipe(EntryALM, DebugLoc());
- Header->insert(LaneMaskPhi, Header->getFirstNonPhi());
-
- // Create the active lane mask for the next iteration of the loop.
- CanonicalIVIncrementParts =
- new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementForPartNUW
- : VPInstruction::CanonicalIVIncrementForPart,
- {IncrementValue}, DL);
- EB->appendRecipe(CanonicalIVIncrementParts);
-
- auto *ALM = new VPInstruction(VPInstruction::ActiveLaneMask,
- {CanonicalIVIncrementParts, TripCount}, DL,
- "active.lane.mask.next");
- EB->appendRecipe(ALM);
- LaneMaskPhi->addOperand(ALM);
-
- if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) {
- // Do the increment of the canonical IV after the active.lane.mask, because
- // that value is still based off %CanonicalIVPHI
- EB->appendRecipe(CanonicalIVIncrement);
- }
-
- // We have to invert the mask here because a true condition means jumping
- // to the exit block.
- auto *NotMask = new VPInstruction(VPInstruction::Not, ALM, DL);
- EB->appendRecipe(NotMask);
+ EB->appendRecipe(CanonicalIVIncrement);
- VPInstruction *BranchBack =
- new VPInstruction(VPInstruction::BranchOnCond, {NotMask}, DL);
- EB->appendRecipe(BranchBack);
- } else {
- EB->appendRecipe(CanonicalIVIncrement);
-
- // Add the BranchOnCount VPInstruction to the latch.
- VPInstruction *BranchBack = new VPInstruction(
- VPInstruction::BranchOnCount,
- {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL);
- EB->appendRecipe(BranchBack);
- }
+ // Add the BranchOnCount VPInstruction to the latch.
+ VPInstruction *BranchBack =
+ new VPInstruction(VPInstruction::BranchOnCount,
+ {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL);
+ EB->appendRecipe(BranchBack);
}
// Add exit values to \p Plan. VPLiveOuts are added for each LCSSA phi in the
// original exit block.
-static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB,
- VPBasicBlock *MiddleVPBB, Loop *OrigLoop,
+static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, Loop *OrigLoop,
VPlan &Plan) {
BasicBlock *ExitBB = OrigLoop->getUniqueExitBlock();
BasicBlock *ExitingBB = OrigLoop->getExitingBlock();
@@ -8868,8 +8678,8 @@ static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB,
}
}
-std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
- VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions) {
+VPlanPtr
+LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
SmallPtrSet<const InterleaveGroup<Instruction> *, 1> InterleaveGroups;
@@ -8880,24 +8690,6 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
// process after constructing the initial VPlan.
// ---------------------------------------------------------------------------
- for (const auto &Reduction : CM.getInLoopReductionChains()) {
- PHINode *Phi = Reduction.first;
- RecurKind Kind =
- Legal->getReductionVars().find(Phi)->second.getRecurrenceKind();
- const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second;
-
- RecipeBuilder.recordRecipeOf(Phi);
- for (const auto &R : ReductionOperations) {
- RecipeBuilder.recordRecipeOf(R);
- // For min/max reductions, where we have a pair of icmp/select, we also
- // need to record the ICmp recipe, so it can be removed later.
- assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
- "Only min/max recurrences allowed for inloop reductions");
- if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
- RecipeBuilder.recordRecipeOf(cast<Instruction>(R->getOperand(0)));
- }
- }
-
// For each interleave group which is relevant for this (possibly trimmed)
// Range, add it to the set of groups to be later applied to the VPlan and add
// placeholders for its members' Recipes which we'll be replacing with a
@@ -8938,23 +8730,27 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body");
VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch");
VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB);
- auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop");
- VPBlockUtils::insertBlockAfter(TopRegion, Plan->getEntry());
- VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
- VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);
+ Plan->getVectorLoopRegion()->setEntry(HeaderVPBB);
+ Plan->getVectorLoopRegion()->setExiting(LatchVPBB);
// Don't use getDecisionAndClampRange here, because we don't know the UF
// so this function is better to be conservative, rather than to split
// it up into different VPlans.
+ // TODO: Consider using getDecisionAndClampRange here to split up VPlans.
bool IVUpdateMayOverflow = false;
for (ElementCount VF : Range)
IVUpdateMayOverflow |= !isIndvarOverflowCheckKnownFalse(&CM, VF);
- Instruction *DLInst =
- getDebugLocFromInstOrOperands(Legal->getPrimaryInduction());
- addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(),
- DLInst ? DLInst->getDebugLoc() : DebugLoc(),
- CM.getTailFoldingStyle(IVUpdateMayOverflow));
+ DebugLoc DL = getDebugLocFromInstOrOperands(Legal->getPrimaryInduction());
+ TailFoldingStyle Style = CM.getTailFoldingStyle(IVUpdateMayOverflow);
+ // When not folding the tail, we know that the induction increment will not
+ // overflow.
+ bool HasNUW = Style == TailFoldingStyle::None;
+ addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
+
+ // Proactively create header mask. Masks for other blocks are created on
+ // demand.
+ RecipeBuilder.createHeaderMask(*Plan);
// Scan the body of the loop in a topological order to visit each basic block
// after having visited its predecessor basic blocks.
@@ -8971,14 +8767,8 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
// Introduce each ingredient into VPlan.
// TODO: Model and preserve debug intrinsics in VPlan.
- for (Instruction &I : BB->instructionsWithoutDebug(false)) {
+ for (Instruction &I : drop_end(BB->instructionsWithoutDebug(false))) {
Instruction *Instr = &I;
-
- // First filter out irrelevant instructions, to ensure no recipes are
- // built for them.
- if (isa<BranchInst>(Instr) || DeadInstructions.count(Instr))
- continue;
-
SmallVector<VPValue *, 4> Operands;
auto *Phi = dyn_cast<PHINode>(Instr);
if (Phi && Phi->getParent() == OrigLoop->getHeader()) {
@@ -9018,11 +8808,18 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
}
RecipeBuilder.setRecipe(Instr, Recipe);
- if (isa<VPWidenIntOrFpInductionRecipe>(Recipe) &&
- HeaderVPBB->getFirstNonPhi() != VPBB->end()) {
- // Move VPWidenIntOrFpInductionRecipes for optimized truncates to the
- // phi section of HeaderVPBB.
- assert(isa<TruncInst>(Instr));
+ if (isa<VPHeaderPHIRecipe>(Recipe)) {
+ // VPHeaderPHIRecipes must be kept in the phi section of HeaderVPBB. In
+ // the following cases, VPHeaderPHIRecipes may be created after non-phi
+ // recipes and need to be moved to the phi section of HeaderVPBB:
+ // * tail-folding (non-phi recipes computing the header mask are
+ // introduced earlier than regular header phi recipes, and should appear
+ // after them)
+ // * Optimizing truncates to VPWidenIntOrFpInductionRecipe.
+
+ assert((HeaderVPBB->getFirstNonPhi() == VPBB->end() ||
+ CM.foldTailByMasking() || isa<TruncInst>(Instr)) &&
+ "unexpected recipe needs moving");
Recipe->insertBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi());
} else
VPBB->appendRecipe(Recipe);
@@ -9040,7 +8837,7 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
// and there is nothing to fix from vector loop; phis should have incoming
// from scalar loop only.
} else
- addUsersInExitBlock(HeaderVPBB, MiddleVPBB, OrigLoop, *Plan);
+ addUsersInExitBlock(HeaderVPBB, OrigLoop, *Plan);
assert(isa<VPRegionBlock>(Plan->getVectorLoopRegion()) &&
!Plan->getVectorLoopRegion()->getEntryBasicBlock()->empty() &&
@@ -9054,8 +8851,7 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
// ---------------------------------------------------------------------------
// Adjust the recipes for any inloop reductions.
- adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExiting()), Plan,
- RecipeBuilder, Range.Start);
+ adjustRecipesForReductions(LatchVPBB, Plan, RecipeBuilder, Range.Start);
// Interleave memory: for each Interleave Group we marked earlier as relevant
// for this VPlan, replace the Recipes widening its memory instructions with a
@@ -9116,21 +8912,18 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
// Sink users of fixed-order recurrence past the recipe defining the previous
// value and introduce FirstOrderRecurrenceSplice VPInstructions.
if (!VPlanTransforms::adjustFixedOrderRecurrences(*Plan, Builder))
- return std::nullopt;
-
- VPlanTransforms::removeRedundantCanonicalIVs(*Plan);
- VPlanTransforms::removeRedundantInductionCasts(*Plan);
-
- VPlanTransforms::optimizeInductions(*Plan, *PSE.getSE());
- VPlanTransforms::removeDeadRecipes(*Plan);
-
- VPlanTransforms::createAndOptimizeReplicateRegions(*Plan);
-
- VPlanTransforms::removeRedundantExpandSCEVRecipes(*Plan);
- VPlanTransforms::mergeBlocksIntoPredecessors(*Plan);
+ return nullptr;
- assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid");
- return std::make_optional(std::move(Plan));
+ if (useActiveLaneMask(Style)) {
+ // TODO: Move checks to VPlanTransforms::addActiveLaneMask once
+ // TailFoldingStyle is visible there.
+ bool ForControlFlow = useActiveLaneMaskForControlFlow(Style);
+ bool WithoutRuntimeCheck =
+ Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
+ VPlanTransforms::addActiveLaneMask(*Plan, ForControlFlow,
+ WithoutRuntimeCheck);
+ }
+ return Plan;
}
VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
@@ -9164,8 +8957,11 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
Plan->getVectorLoopRegion()->getExitingBasicBlock()->getTerminator();
Term->eraseFromParent();
- addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), DebugLoc(),
- CM.getTailFoldingStyle());
+ // Tail folding is not supported for outer loops, so the induction increment
+ // is guaranteed to not wrap.
+ bool HasNUW = true;
+ addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW,
+ DebugLoc());
return Plan;
}
@@ -9177,105 +8973,211 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
void LoopVectorizationPlanner::adjustRecipesForReductions(
VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder,
ElementCount MinVF) {
- for (const auto &Reduction : CM.getInLoopReductionChains()) {
- PHINode *Phi = Reduction.first;
- const RecurrenceDescriptor &RdxDesc =
- Legal->getReductionVars().find(Phi)->second;
- const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second;
+ VPBasicBlock *Header = Plan->getVectorLoopRegion()->getEntryBasicBlock();
+ // Gather all VPReductionPHIRecipe and sort them so that Intermediate stores
+ // sank outside of the loop would keep the same order as they had in the
+ // original loop.
+ SmallVector<VPReductionPHIRecipe *> ReductionPHIList;
+ for (VPRecipeBase &R : Header->phis()) {
+ if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R))
+ ReductionPHIList.emplace_back(ReductionPhi);
+ }
+ bool HasIntermediateStore = false;
+ stable_sort(ReductionPHIList,
+ [this, &HasIntermediateStore](const VPReductionPHIRecipe *R1,
+ const VPReductionPHIRecipe *R2) {
+ auto *IS1 = R1->getRecurrenceDescriptor().IntermediateStore;
+ auto *IS2 = R2->getRecurrenceDescriptor().IntermediateStore;
+ HasIntermediateStore |= IS1 || IS2;
+
+ // If neither of the recipes has an intermediate store, keep the
+ // order the same.
+ if (!IS1 && !IS2)
+ return false;
+
+ // If only one of the recipes has an intermediate store, then
+ // move it towards the beginning of the list.
+ if (IS1 && !IS2)
+ return true;
+
+ if (!IS1 && IS2)
+ return false;
- if (MinVF.isScalar() && !CM.useOrderedReductions(RdxDesc))
+ // If both recipes have an intermediate store, then the recipe
+ // with the later store should be processed earlier. So it
+ // should go to the beginning of the list.
+ return DT->dominates(IS2, IS1);
+ });
+
+ if (HasIntermediateStore && ReductionPHIList.size() > 1)
+ for (VPRecipeBase *R : ReductionPHIList)
+ R->moveBefore(*Header, Header->getFirstNonPhi());
+
+ SmallVector<VPReductionPHIRecipe *> InLoopReductionPhis;
+ for (VPRecipeBase &R : Header->phis()) {
+ auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+ if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
continue;
+ InLoopReductionPhis.push_back(PhiR);
+ }
- // ReductionOperations are orders top-down from the phi's use to the
- // LoopExitValue. We keep a track of the previous item (the Chain) to tell
- // which of the two operands will remain scalar and which will be reduced.
- // For minmax the chain will be the select instructions.
- Instruction *Chain = Phi;
- for (Instruction *R : ReductionOperations) {
- VPRecipeBase *WidenRecipe = RecipeBuilder.getRecipe(R);
- RecurKind Kind = RdxDesc.getRecurrenceKind();
+ for (VPReductionPHIRecipe *PhiR : InLoopReductionPhis) {
+ const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+ RecurKind Kind = RdxDesc.getRecurrenceKind();
+ assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
+ "AnyOf reductions are not allowed for in-loop reductions");
- VPValue *ChainOp = Plan->getVPValue(Chain);
- unsigned FirstOpId;
- assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
- "Only min/max recurrences allowed for inloop reductions");
+ // Collect the chain of "link" recipes for the reduction starting at PhiR.
+ SetVector<VPRecipeBase *> Worklist;
+ Worklist.insert(PhiR);
+ for (unsigned I = 0; I != Worklist.size(); ++I) {
+ VPRecipeBase *Cur = Worklist[I];
+ for (VPUser *U : Cur->getVPSingleValue()->users()) {
+ auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
+ if (!UserRecipe)
+ continue;
+ assert(UserRecipe->getNumDefinedValues() == 1 &&
+ "recipes must define exactly one result value");
+ Worklist.insert(UserRecipe);
+ }
+ }
+
+ // Visit operation "Links" along the reduction chain top-down starting from
+ // the phi until LoopExitValue. We keep track of the previous item
+ // (PreviousLink) to tell which of the two operands of a Link will remain
+ // scalar and which will be reduced. For minmax by select(cmp), Link will be
+ // the select instructions.
+ VPRecipeBase *PreviousLink = PhiR; // Aka Worklist[0].
+ for (VPRecipeBase *CurrentLink : Worklist.getArrayRef().drop_front()) {
+ VPValue *PreviousLinkV = PreviousLink->getVPSingleValue();
+
+ Instruction *CurrentLinkI = CurrentLink->getUnderlyingInstr();
+
+ // Index of the first operand which holds a non-mask vector operand.
+ unsigned IndexOfFirstOperand;
// Recognize a call to the llvm.fmuladd intrinsic.
bool IsFMulAdd = (Kind == RecurKind::FMulAdd);
- assert((!IsFMulAdd || RecurrenceDescriptor::isFMulAddIntrinsic(R)) &&
- "Expected instruction to be a call to the llvm.fmuladd intrinsic");
- if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
- assert(isa<VPWidenSelectRecipe>(WidenRecipe) &&
- "Expected to replace a VPWidenSelectSC");
- FirstOpId = 1;
+ VPValue *VecOp;
+ VPBasicBlock *LinkVPBB = CurrentLink->getParent();
+ if (IsFMulAdd) {
+ assert(
+ RecurrenceDescriptor::isFMulAddIntrinsic(CurrentLinkI) &&
+ "Expected instruction to be a call to the llvm.fmuladd intrinsic");
+ assert(((MinVF.isScalar() && isa<VPReplicateRecipe>(CurrentLink)) ||
+ isa<VPWidenCallRecipe>(CurrentLink)) &&
+ CurrentLink->getOperand(2) == PreviousLinkV &&
+ "expected a call where the previous link is the added operand");
+
+ // If the instruction is a call to the llvm.fmuladd intrinsic then we
+ // need to create an fmul recipe (multiplying the first two operands of
+ // the fmuladd together) to use as the vector operand for the fadd
+ // reduction.
+ VPInstruction *FMulRecipe = new VPInstruction(
+ Instruction::FMul,
+ {CurrentLink->getOperand(0), CurrentLink->getOperand(1)},
+ CurrentLinkI->getFastMathFlags());
+ LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator());
+ VecOp = FMulRecipe;
} else {
- assert((MinVF.isScalar() || isa<VPWidenRecipe>(WidenRecipe) ||
- (IsFMulAdd && isa<VPWidenCallRecipe>(WidenRecipe))) &&
- "Expected to replace a VPWidenSC");
- FirstOpId = 0;
+ if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
+ if (isa<VPWidenRecipe>(CurrentLink)) {
+ assert(isa<CmpInst>(CurrentLinkI) &&
+ "need to have the compare of the select");
+ continue;
+ }
+ assert(isa<VPWidenSelectRecipe>(CurrentLink) &&
+ "must be a select recipe");
+ IndexOfFirstOperand = 1;
+ } else {
+ assert((MinVF.isScalar() || isa<VPWidenRecipe>(CurrentLink)) &&
+ "Expected to replace a VPWidenSC");
+ IndexOfFirstOperand = 0;
+ }
+ // Note that for non-commutable operands (cmp-selects), the semantics of
+ // the cmp-select are captured in the recurrence kind.
+ unsigned VecOpId =
+ CurrentLink->getOperand(IndexOfFirstOperand) == PreviousLinkV
+ ? IndexOfFirstOperand + 1
+ : IndexOfFirstOperand;
+ VecOp = CurrentLink->getOperand(VecOpId);
+ assert(VecOp != PreviousLinkV &&
+ CurrentLink->getOperand(CurrentLink->getNumOperands() - 1 -
+ (VecOpId - IndexOfFirstOperand)) ==
+ PreviousLinkV &&
+ "PreviousLinkV must be the operand other than VecOp");
}
- unsigned VecOpId =
- R->getOperand(FirstOpId) == Chain ? FirstOpId + 1 : FirstOpId;
- VPValue *VecOp = Plan->getVPValue(R->getOperand(VecOpId));
+ BasicBlock *BB = CurrentLinkI->getParent();
VPValue *CondOp = nullptr;
- if (CM.blockNeedsPredicationForAnyReason(R->getParent())) {
+ if (CM.blockNeedsPredicationForAnyReason(BB)) {
VPBuilder::InsertPointGuard Guard(Builder);
- Builder.setInsertPoint(WidenRecipe->getParent(),
- WidenRecipe->getIterator());
- CondOp = RecipeBuilder.createBlockInMask(R->getParent(), *Plan);
+ Builder.setInsertPoint(CurrentLink);
+ CondOp = RecipeBuilder.createBlockInMask(BB, *Plan);
}
- if (IsFMulAdd) {
- // If the instruction is a call to the llvm.fmuladd intrinsic then we
- // need to create an fmul recipe to use as the vector operand for the
- // fadd reduction.
- VPInstruction *FMulRecipe = new VPInstruction(
- Instruction::FMul, {VecOp, Plan->getVPValue(R->getOperand(1))});
- FMulRecipe->setFastMathFlags(R->getFastMathFlags());
- WidenRecipe->getParent()->insert(FMulRecipe,
- WidenRecipe->getIterator());
- VecOp = FMulRecipe;
- }
- VPReductionRecipe *RedRecipe =
- new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, &TTI);
- WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
- Plan->removeVPValueFor(R);
- Plan->addVPValue(R, RedRecipe);
+ VPReductionRecipe *RedRecipe = new VPReductionRecipe(
+ RdxDesc, CurrentLinkI, PreviousLinkV, VecOp, CondOp);
// Append the recipe to the end of the VPBasicBlock because we need to
// ensure that it comes after all of it's inputs, including CondOp.
- WidenRecipe->getParent()->appendRecipe(RedRecipe);
- WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
- WidenRecipe->eraseFromParent();
-
- if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
- VPRecipeBase *CompareRecipe =
- RecipeBuilder.getRecipe(cast<Instruction>(R->getOperand(0)));
- assert(isa<VPWidenRecipe>(CompareRecipe) &&
- "Expected to replace a VPWidenSC");
- assert(cast<VPWidenRecipe>(CompareRecipe)->getNumUsers() == 0 &&
- "Expected no remaining users");
- CompareRecipe->eraseFromParent();
- }
- Chain = R;
+ // Note that this transformation may leave over dead recipes (including
+ // CurrentLink), which will be cleaned by a later VPlan transform.
+ LinkVPBB->appendRecipe(RedRecipe);
+ CurrentLink->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
+ PreviousLink = RedRecipe;
}
}
-
- // If tail is folded by masking, introduce selects between the phi
- // and the live-out instruction of each reduction, at the beginning of the
- // dedicated latch block.
- if (CM.foldTailByMasking()) {
- Builder.setInsertPoint(LatchVPBB, LatchVPBB->begin());
+ Builder.setInsertPoint(&*LatchVPBB->begin());
for (VPRecipeBase &R :
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
- VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
- if (!PhiR || PhiR->isInLoop())
- continue;
+ VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+ if (!PhiR || PhiR->isInLoop())
+ continue;
+
+ const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+ auto *Result = PhiR->getBackedgeValue()->getDefiningRecipe();
+ // If tail is folded by masking, introduce selects between the phi
+ // and the live-out instruction of each reduction, at the beginning of the
+ // dedicated latch block.
+ if (CM.foldTailByMasking()) {
VPValue *Cond =
RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), *Plan);
VPValue *Red = PhiR->getBackedgeValue();
assert(Red->getDefiningRecipe()->getParent() != LatchVPBB &&
"reduction recipe must be defined before latch");
- Builder.createNaryOp(Instruction::Select, {Cond, Red, PhiR});
+ FastMathFlags FMFs = RdxDesc.getFastMathFlags();
+ Type *PhiTy = PhiR->getOperand(0)->getLiveInIRValue()->getType();
+ Result =
+ PhiTy->isFloatingPointTy()
+ ? new VPInstruction(Instruction::Select, {Cond, Red, PhiR}, FMFs)
+ : new VPInstruction(Instruction::Select, {Cond, Red, PhiR});
+ Result->insertBefore(&*Builder.getInsertPoint());
+ Red->replaceUsesWithIf(
+ Result->getVPSingleValue(),
+ [](VPUser &U, unsigned) { return isa<VPLiveOut>(&U); });
+ if (PreferPredicatedReductionSelect ||
+ TTI.preferPredicatedReductionSelect(
+ PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy,
+ TargetTransformInfo::ReductionFlags()))
+ PhiR->setOperand(1, Result->getVPSingleValue());
+ }
+ // If the vector reduction can be performed in a smaller type, we truncate
+ // then extend the loop exit value to enable InstCombine to evaluate the
+ // entire expression in the smaller type.
+ Type *PhiTy = PhiR->getStartValue()->getLiveInIRValue()->getType();
+ if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
+ assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!");
+ Type *RdxTy = RdxDesc.getRecurrenceType();
+ auto *Trunc = new VPWidenCastRecipe(Instruction::Trunc,
+ Result->getVPSingleValue(), RdxTy);
+ auto *Extnd =
+ RdxDesc.isSigned()
+ ? new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy)
+ : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy);
+
+ Trunc->insertAfter(Result);
+ Extnd->insertAfter(Trunc);
+ Result->getVPSingleValue()->replaceAllUsesWith(Extnd);
+ Trunc->setOperand(0, Result->getVPSingleValue());
}
}
@@ -9313,107 +9215,6 @@ void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent,
}
#endif
-void VPWidenIntOrFpInductionRecipe::execute(VPTransformState &State) {
- assert(!State.Instance && "Int or FP induction being replicated.");
-
- Value *Start = getStartValue()->getLiveInIRValue();
- const InductionDescriptor &ID = getInductionDescriptor();
- TruncInst *Trunc = getTruncInst();
- IRBuilderBase &Builder = State.Builder;
- assert(IV->getType() == ID.getStartValue()->getType() && "Types must match");
- assert(State.VF.isVector() && "must have vector VF");
-
- // The value from the original loop to which we are mapping the new induction
- // variable.
- Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV;
-
- // Fast-math-flags propagate from the original induction instruction.
- IRBuilder<>::FastMathFlagGuard FMFG(Builder);
- if (ID.getInductionBinOp() && isa<FPMathOperator>(ID.getInductionBinOp()))
- Builder.setFastMathFlags(ID.getInductionBinOp()->getFastMathFlags());
-
- // Now do the actual transformations, and start with fetching the step value.
- Value *Step = State.get(getStepValue(), VPIteration(0, 0));
-
- assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) &&
- "Expected either an induction phi-node or a truncate of it!");
-
- // Construct the initial value of the vector IV in the vector loop preheader
- auto CurrIP = Builder.saveIP();
- BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this);
- Builder.SetInsertPoint(VectorPH->getTerminator());
- if (isa<TruncInst>(EntryVal)) {
- assert(Start->getType()->isIntegerTy() &&
- "Truncation requires an integer type");
- auto *TruncType = cast<IntegerType>(EntryVal->getType());
- Step = Builder.CreateTrunc(Step, TruncType);
- Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType);
- }
-
- Value *Zero = getSignedIntOrFpConstant(Start->getType(), 0);
- Value *SplatStart = Builder.CreateVectorSplat(State.VF, Start);
- Value *SteppedStart = getStepVector(
- SplatStart, Zero, Step, ID.getInductionOpcode(), State.VF, State.Builder);
-
- // We create vector phi nodes for both integer and floating-point induction
- // variables. Here, we determine the kind of arithmetic we will perform.
- Instruction::BinaryOps AddOp;
- Instruction::BinaryOps MulOp;
- if (Step->getType()->isIntegerTy()) {
- AddOp = Instruction::Add;
- MulOp = Instruction::Mul;
- } else {
- AddOp = ID.getInductionOpcode();
- MulOp = Instruction::FMul;
- }
-
- // Multiply the vectorization factor by the step using integer or
- // floating-point arithmetic as appropriate.
- Type *StepType = Step->getType();
- Value *RuntimeVF;
- if (Step->getType()->isFloatingPointTy())
- RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, State.VF);
- else
- RuntimeVF = getRuntimeVF(Builder, StepType, State.VF);
- Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF);
-
- // Create a vector splat to use in the induction update.
- //
- // FIXME: If the step is non-constant, we create the vector splat with
- // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't
- // handle a constant vector splat.
- Value *SplatVF = isa<Constant>(Mul)
- ? ConstantVector::getSplat(State.VF, cast<Constant>(Mul))
- : Builder.CreateVectorSplat(State.VF, Mul);
- Builder.restoreIP(CurrIP);
-
- // We may need to add the step a number of times, depending on the unroll
- // factor. The last of those goes into the PHI.
- PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind",
- &*State.CFG.PrevBB->getFirstInsertionPt());
- VecInd->setDebugLoc(EntryVal->getDebugLoc());
- Instruction *LastInduction = VecInd;
- for (unsigned Part = 0; Part < State.UF; ++Part) {
- State.set(this, LastInduction, Part);
-
- if (isa<TruncInst>(EntryVal))
- State.addMetadata(LastInduction, EntryVal);
-
- LastInduction = cast<Instruction>(
- Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add"));
- LastInduction->setDebugLoc(EntryVal->getDebugLoc());
- }
-
- LastInduction->setName("vec.ind.next");
- VecInd->addIncoming(SteppedStart, VectorPH);
- // Add induction update using an incorrect block temporarily. The phi node
- // will be fixed after VPlan execution. Note that at this point the latch
- // block cannot be used, as it does not exist yet.
- // TODO: Model increment value in VPlan, by turning the recipe into a
- // multi-def and a subclass of VPHeaderPHIRecipe.
- VecInd->addIncoming(LastInduction, VectorPH);
-}
-
void VPWidenPointerInductionRecipe::execute(VPTransformState &State) {
assert(IndDesc.getKind() == InductionDescriptor::IK_PtrInduction &&
"Not a pointer induction according to InductionDescriptor!");
@@ -9446,7 +9247,8 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) {
Value *Step = State.get(getOperand(1), VPIteration(Part, Lane));
Value *SclrGep = emitTransformedIndex(
- State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, IndDesc);
+ State.Builder, GlobalIdx, IndDesc.getStartValue(), Step,
+ IndDesc.getKind(), IndDesc.getInductionBinOp());
SclrGep->setName("next.gep");
State.set(this, SclrGep, VPIteration(Part, Lane));
}
@@ -9513,41 +9315,26 @@ void VPDerivedIVRecipe::execute(VPTransformState &State) {
// Fast-math-flags propagate from the original induction instruction.
IRBuilder<>::FastMathFlagGuard FMFG(State.Builder);
- if (IndDesc.getInductionBinOp() &&
- isa<FPMathOperator>(IndDesc.getInductionBinOp()))
- State.Builder.setFastMathFlags(
- IndDesc.getInductionBinOp()->getFastMathFlags());
+ if (FPBinOp)
+ State.Builder.setFastMathFlags(FPBinOp->getFastMathFlags());
Value *Step = State.get(getStepValue(), VPIteration(0, 0));
Value *CanonicalIV = State.get(getCanonicalIV(), VPIteration(0, 0));
- Value *DerivedIV =
- emitTransformedIndex(State.Builder, CanonicalIV,
- getStartValue()->getLiveInIRValue(), Step, IndDesc);
+ Value *DerivedIV = emitTransformedIndex(
+ State.Builder, CanonicalIV, getStartValue()->getLiveInIRValue(), Step,
+ Kind, cast_if_present<BinaryOperator>(FPBinOp));
DerivedIV->setName("offset.idx");
- if (ResultTy != DerivedIV->getType()) {
- assert(Step->getType()->isIntegerTy() &&
+ if (TruncResultTy) {
+ assert(TruncResultTy != DerivedIV->getType() &&
+ Step->getType()->isIntegerTy() &&
"Truncation requires an integer step");
- DerivedIV = State.Builder.CreateTrunc(DerivedIV, ResultTy);
+ DerivedIV = State.Builder.CreateTrunc(DerivedIV, TruncResultTy);
}
assert(DerivedIV != CanonicalIV && "IV didn't need transforming?");
State.set(this, DerivedIV, VPIteration(0, 0));
}
-void VPScalarIVStepsRecipe::execute(VPTransformState &State) {
- // Fast-math-flags propagate from the original induction instruction.
- IRBuilder<>::FastMathFlagGuard FMFG(State.Builder);
- if (IndDesc.getInductionBinOp() &&
- isa<FPMathOperator>(IndDesc.getInductionBinOp()))
- State.Builder.setFastMathFlags(
- IndDesc.getInductionBinOp()->getFastMathFlags());
-
- Value *BaseIV = State.get(getOperand(0), VPIteration(0, 0));
- Value *Step = State.get(getStepValue(), VPIteration(0, 0));
-
- buildScalarSteps(BaseIV, Step, IndDesc, this, State);
-}
-
void VPInterleaveRecipe::execute(VPTransformState &State) {
assert(!State.Instance && "Interleave group being replicated.");
State.ILV->vectorizeInterleaveGroup(IG, definedValues(), State, getAddr(),
@@ -9558,48 +9345,51 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
void VPReductionRecipe::execute(VPTransformState &State) {
assert(!State.Instance && "Reduction being replicated.");
Value *PrevInChain = State.get(getChainOp(), 0);
- RecurKind Kind = RdxDesc->getRecurrenceKind();
- bool IsOrdered = State.ILV->useOrderedReductions(*RdxDesc);
+ RecurKind Kind = RdxDesc.getRecurrenceKind();
+ bool IsOrdered = State.ILV->useOrderedReductions(RdxDesc);
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
- State.Builder.setFastMathFlags(RdxDesc->getFastMathFlags());
+ State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
for (unsigned Part = 0; Part < State.UF; ++Part) {
Value *NewVecOp = State.get(getVecOp(), Part);
if (VPValue *Cond = getCondOp()) {
- Value *NewCond = State.get(Cond, Part);
- VectorType *VecTy = cast<VectorType>(NewVecOp->getType());
- Value *Iden = RdxDesc->getRecurrenceIdentity(
- Kind, VecTy->getElementType(), RdxDesc->getFastMathFlags());
- Value *IdenVec =
- State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden);
- Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, IdenVec);
+ Value *NewCond = State.VF.isVector() ? State.get(Cond, Part)
+ : State.get(Cond, {Part, 0});
+ VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
+ Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
+ Value *Iden = RdxDesc.getRecurrenceIdentity(Kind, ElementTy,
+ RdxDesc.getFastMathFlags());
+ if (State.VF.isVector()) {
+ Iden =
+ State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden);
+ }
+
+ Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, Iden);
NewVecOp = Select;
}
Value *NewRed;
Value *NextInChain;
if (IsOrdered) {
if (State.VF.isVector())
- NewRed = createOrderedReduction(State.Builder, *RdxDesc, NewVecOp,
+ NewRed = createOrderedReduction(State.Builder, RdxDesc, NewVecOp,
PrevInChain);
else
NewRed = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc->getOpcode(Kind), PrevInChain,
+ (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), PrevInChain,
NewVecOp);
PrevInChain = NewRed;
} else {
PrevInChain = State.get(getChainOp(), Part);
- NewRed = createTargetReduction(State.Builder, TTI, *RdxDesc, NewVecOp);
+ NewRed = createTargetReduction(State.Builder, RdxDesc, NewVecOp);
}
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
- NextInChain =
- createMinMaxOp(State.Builder, RdxDesc->getRecurrenceKind(),
- NewRed, PrevInChain);
+ NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
+ NewRed, PrevInChain);
} else if (IsOrdered)
NextInChain = NewRed;
else
NextInChain = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc->getOpcode(Kind), NewRed,
- PrevInChain);
+ (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), NewRed, PrevInChain);
State.set(this, NextInChain, Part);
}
}
@@ -9618,7 +9408,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
VectorType::get(UI->getType(), State.VF));
State.set(this, Poison, State.Instance->Part);
}
- State.ILV->packScalarIntoVectorValue(this, *State.Instance, State);
+ State.packScalarIntoVectorValue(this, *State.Instance);
}
return;
}
@@ -9684,9 +9474,16 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
auto &Builder = State.Builder;
InnerLoopVectorizer::VectorParts BlockInMaskParts(State.UF);
bool isMaskRequired = getMask();
- if (isMaskRequired)
- for (unsigned Part = 0; Part < State.UF; ++Part)
- BlockInMaskParts[Part] = State.get(getMask(), Part);
+ if (isMaskRequired) {
+ // Mask reversal is only needed for non-all-one (null) masks, as reverse of a
+ // null all-one mask is a null mask.
+ for (unsigned Part = 0; Part < State.UF; ++Part) {
+ Value *Mask = State.get(getMask(), Part);
+ if (isReverse())
+ Mask = Builder.CreateVectorReverse(Mask, "reverse");
+ BlockInMaskParts[Part] = Mask;
+ }
+ }
const auto CreateVecPtr = [&](unsigned Part, Value *Ptr) -> Value * {
// Calculate the pointer for the specific unroll-part.
@@ -9697,7 +9494,8 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
const DataLayout &DL =
Builder.GetInsertBlock()->getModule()->getDataLayout();
Type *IndexTy = State.VF.isScalable() && (isReverse() || Part > 0)
- ? DL.getIndexType(ScalarDataTy->getPointerTo())
+ ? DL.getIndexType(PointerType::getUnqual(
+ ScalarDataTy->getContext()))
: Builder.getInt32Ty();
bool InBounds = false;
if (auto *gep = dyn_cast<GetElementPtrInst>(Ptr->stripPointerCasts()))
@@ -9717,21 +9515,17 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, NumElt, "", InBounds);
PartPtr =
Builder.CreateGEP(ScalarDataTy, PartPtr, LastLane, "", InBounds);
- if (isMaskRequired) // Reverse of a null all-one mask is a null mask.
- BlockInMaskParts[Part] =
- Builder.CreateVectorReverse(BlockInMaskParts[Part], "reverse");
} else {
Value *Increment = createStepForVF(Builder, IndexTy, State.VF, Part);
PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, Increment, "", InBounds);
}
- unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace();
- return Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace));
+ return PartPtr;
};
// Handle Stores:
if (SI) {
- State.setDebugLocFromInst(SI);
+ State.setDebugLocFrom(SI->getDebugLoc());
for (unsigned Part = 0; Part < State.UF; ++Part) {
Instruction *NewSI = nullptr;
@@ -9764,7 +9558,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) {
// Handle loads.
assert(LI && "Must have a load instruction");
- State.setDebugLocFromInst(LI);
+ State.setDebugLocFrom(LI->getDebugLoc());
for (unsigned Part = 0; Part < State.UF; ++Part) {
Value *NewLI;
if (CreateGatherScatter) {
@@ -9843,95 +9637,6 @@ static ScalarEpilogueLowering getScalarEpilogueLowering(
return CM_ScalarEpilogueAllowed;
}
-Value *VPTransformState::get(VPValue *Def, unsigned Part) {
- // If Values have been set for this Def return the one relevant for \p Part.
- if (hasVectorValue(Def, Part))
- return Data.PerPartOutput[Def][Part];
-
- auto GetBroadcastInstrs = [this, Def](Value *V) {
- bool SafeToHoist = Def->isDefinedOutsideVectorRegions();
- if (VF.isScalar())
- return V;
- // Place the code for broadcasting invariant variables in the new preheader.
- IRBuilder<>::InsertPointGuard Guard(Builder);
- if (SafeToHoist) {
- BasicBlock *LoopVectorPreHeader = CFG.VPBB2IRBB[cast<VPBasicBlock>(
- Plan->getVectorLoopRegion()->getSinglePredecessor())];
- if (LoopVectorPreHeader)
- Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator());
- }
-
- // Place the code for broadcasting invariant variables in the new preheader.
- // Broadcast the scalar into all locations in the vector.
- Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast");
-
- return Shuf;
- };
-
- if (!hasScalarValue(Def, {Part, 0})) {
- Value *IRV = Def->getLiveInIRValue();
- Value *B = GetBroadcastInstrs(IRV);
- set(Def, B, Part);
- return B;
- }
-
- Value *ScalarValue = get(Def, {Part, 0});
- // If we aren't vectorizing, we can just copy the scalar map values over
- // to the vector map.
- if (VF.isScalar()) {
- set(Def, ScalarValue, Part);
- return ScalarValue;
- }
-
- bool IsUniform = vputils::isUniformAfterVectorization(Def);
-
- unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1;
- // Check if there is a scalar value for the selected lane.
- if (!hasScalarValue(Def, {Part, LastLane})) {
- // At the moment, VPWidenIntOrFpInductionRecipes, VPScalarIVStepsRecipes and
- // VPExpandSCEVRecipes can also be uniform.
- assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) ||
- isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe()) ||
- isa<VPExpandSCEVRecipe>(Def->getDefiningRecipe())) &&
- "unexpected recipe found to be invariant");
- IsUniform = true;
- LastLane = 0;
- }
-
- auto *LastInst = cast<Instruction>(get(Def, {Part, LastLane}));
- // Set the insert point after the last scalarized instruction or after the
- // last PHI, if LastInst is a PHI. This ensures the insertelement sequence
- // will directly follow the scalar definitions.
- auto OldIP = Builder.saveIP();
- auto NewIP =
- isa<PHINode>(LastInst)
- ? BasicBlock::iterator(LastInst->getParent()->getFirstNonPHI())
- : std::next(BasicBlock::iterator(LastInst));
- Builder.SetInsertPoint(&*NewIP);
-
- // However, if we are vectorizing, we need to construct the vector values.
- // If the value is known to be uniform after vectorization, we can just
- // broadcast the scalar value corresponding to lane zero for each unroll
- // iteration. Otherwise, we construct the vector values using
- // insertelement instructions. Since the resulting vectors are stored in
- // State, we will only generate the insertelements once.
- Value *VectorValue = nullptr;
- if (IsUniform) {
- VectorValue = GetBroadcastInstrs(ScalarValue);
- set(Def, VectorValue, Part);
- } else {
- // Initialize packing with insertelements to start from undef.
- assert(!VF.isScalable() && "VF is assumed to be non scalable.");
- Value *Undef = PoisonValue::get(VectorType::get(LastInst->getType(), VF));
- set(Def, Undef, Part);
- for (unsigned Lane = 0; Lane < VF.getKnownMinValue(); ++Lane)
- ILV->packScalarIntoVectorValue(Def, {Part, Lane}, *this);
- VectorValue = get(Def, Part);
- }
- Builder.restoreIP(OldIP);
- return VectorValue;
-}
-
// Process the loop in the VPlan-native vectorization path. This path builds
// VPlan upfront in the vectorization pipeline, which allows to apply
// VPlan-to-VPlan transformations from the very beginning without modifying the
@@ -9960,7 +9665,8 @@ static bool processLoopInVPlanNativePath(
// Use the planner for outer loop vectorization.
// TODO: CM is not used at this point inside the planner. Turn CM into an
// optional argument if we don't need it in the future.
- LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, LVL, CM, IAI, PSE, Hints, ORE);
+ LoopVectorizationPlanner LVP(L, LI, DT, TLI, *TTI, LVL, CM, IAI, PSE, Hints,
+ ORE);
// Get user vectorization factor.
ElementCount UserVF = Hints.getWidth();
@@ -9979,8 +9685,10 @@ static bool processLoopInVPlanNativePath(
VPlan &BestPlan = LVP.getBestPlanFor(VF.Width);
{
+ bool AddBranchWeights =
+ hasBranchWeightMD(*L->getLoopLatch()->getTerminator());
GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI,
- F->getParent()->getDataLayout());
+ F->getParent()->getDataLayout(), AddBranchWeights);
InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width,
VF.Width, 1, LVL, &CM, BFI, PSI, Checks);
LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \""
@@ -9988,6 +9696,8 @@ static bool processLoopInVPlanNativePath(
LVP.executePlan(VF.Width, 1, BestPlan, LB, DT, false);
}
+ reportVectorization(ORE, L, VF, 1);
+
// Mark the loop as already vectorized to avoid vectorizing again.
Hints.setAlreadyVectorized();
assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()));
@@ -10042,7 +9752,8 @@ static void checkMixedPrecision(Loop *L, OptimizationRemarkEmitter *ORE) {
static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks,
VectorizationFactor &VF,
std::optional<unsigned> VScale, Loop *L,
- ScalarEvolution &SE) {
+ ScalarEvolution &SE,
+ ScalarEpilogueLowering SEL) {
InstructionCost CheckCost = Checks.getCost();
if (!CheckCost.isValid())
return false;
@@ -10112,11 +9823,13 @@ static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks,
// RtC < ScalarC * TC * (1 / X) ==> RtC * X / ScalarC < TC
double MinTC2 = RtC * 10 / ScalarC;
- // Now pick the larger minimum. If it is not a multiple of VF, choose the
- // next closest multiple of VF. This should partly compensate for ignoring
- // the epilogue cost.
+ // Now pick the larger minimum. If it is not a multiple of VF and a scalar
+ // epilogue is allowed, choose the next closest multiple of VF. This should
+ // partly compensate for ignoring the epilogue cost.
uint64_t MinTC = std::ceil(std::max(MinTC1, MinTC2));
- VF.MinProfitableTripCount = ElementCount::getFixed(alignTo(MinTC, IntVF));
+ if (SEL == CM_ScalarEpilogueAllowed)
+ MinTC = alignTo(MinTC, IntVF);
+ VF.MinProfitableTripCount = ElementCount::getFixed(MinTC);
LLVM_DEBUG(
dbgs() << "LV: Minimum required TC for runtime checks to be profitable:"
@@ -10236,7 +9949,14 @@ bool LoopVectorizePass::processLoop(Loop *L) {
else {
if (*ExpectedTC > TTI->getMinTripCountTailFoldingThreshold()) {
LLVM_DEBUG(dbgs() << "\n");
- SEL = CM_ScalarEpilogueNotAllowedLowTripLoop;
+ // Predicate tail-folded loops are efficient even when the loop
+ // iteration count is low. However, setting the epilogue policy to
+ // `CM_ScalarEpilogueNotAllowedLowTripLoop` prevents vectorizing loops
+ // with runtime checks. It's more effective to let
+ // `areRuntimeChecksProfitable` determine if vectorization is beneficial
+ // for the loop.
+ if (SEL != CM_ScalarEpilogueNotNeededUsePredicate)
+ SEL = CM_ScalarEpilogueNotAllowedLowTripLoop;
} else {
LLVM_DEBUG(dbgs() << " But the target considers the trip count too "
"small to consider vectorizing.\n");
@@ -10300,7 +10020,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
LoopVectorizationCostModel CM(SEL, L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE,
F, &Hints, IAI);
// Use the planner for vectorization.
- LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, &LVL, CM, IAI, PSE, Hints,
+ LoopVectorizationPlanner LVP(L, LI, DT, TLI, *TTI, &LVL, CM, IAI, PSE, Hints,
ORE);
// Get user vectorization factor and interleave count.
@@ -10313,8 +10033,10 @@ bool LoopVectorizePass::processLoop(Loop *L) {
VectorizationFactor VF = VectorizationFactor::Disabled();
unsigned IC = 1;
+ bool AddBranchWeights =
+ hasBranchWeightMD(*L->getLoopLatch()->getTerminator());
GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI,
- F->getParent()->getDataLayout());
+ F->getParent()->getDataLayout(), AddBranchWeights);
if (MaybeVF) {
VF = *MaybeVF;
// Select the interleave count.
@@ -10331,7 +10053,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
Hints.getForce() == LoopVectorizeHints::FK_Enabled;
if (!ForceVectorization &&
!areRuntimeChecksProfitable(Checks, VF, getVScaleForTuning(L, *TTI), L,
- *PSE.getSE())) {
+ *PSE.getSE(), SEL)) {
ORE->emit([&]() {
return OptimizationRemarkAnalysisAliasing(
DEBUG_TYPE, "CantReorderMemOps", L->getStartLoc(),
@@ -10553,13 +10275,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
DisableRuntimeUnroll = true;
}
// Report the vectorization decision.
- ORE->emit([&]() {
- return OptimizationRemark(LV_NAME, "Vectorized", L->getStartLoc(),
- L->getHeader())
- << "vectorized loop (vectorization width: "
- << NV("VectorizationFactor", VF.Width)
- << ", interleaved count: " << NV("InterleaveCount", IC) << ")";
- });
+ reportVectorization(ORE, L, VF, IC);
}
if (ORE->allowExtraAnalysis(LV_NAME))
@@ -10642,8 +10358,14 @@ LoopVectorizeResult LoopVectorizePass::runImpl(
Changed |= CFGChanged |= processLoop(L);
- if (Changed)
+ if (Changed) {
LAIs->clear();
+
+#ifndef NDEBUG
+ if (VerifySCEV)
+ SE->verify();
+#endif
+ }
}
// Process each loop nest in the function.
@@ -10691,10 +10413,6 @@ PreservedAnalyses LoopVectorizePass::run(Function &F,
PA.preserve<LoopAnalysis>();
PA.preserve<DominatorTreeAnalysis>();
PA.preserve<ScalarEvolutionAnalysis>();
-
-#ifdef EXPENSIVE_CHECKS
- SE.verify();
-#endif
}
if (Result.MadeCFGChange) {
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 821a3fa22a85..fe2aac78e5ab 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -19,7 +19,6 @@
#include "llvm/Transforms/Vectorize/SLPVectorizer.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/PriorityQueue.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
@@ -34,6 +33,7 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/IVDescriptors.h"
@@ -97,7 +97,6 @@
#include <string>
#include <tuple>
#include <utility>
-#include <vector>
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -108,8 +107,9 @@ using namespace slpvectorizer;
STATISTIC(NumVectorInstructions, "Number of vector instructions generated");
-cl::opt<bool> RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden,
- cl::desc("Run the SLP vectorization passes"));
+static cl::opt<bool>
+ RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden,
+ cl::desc("Run the SLP vectorization passes"));
static cl::opt<int>
SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
@@ -140,10 +140,6 @@ static cl::opt<unsigned>
MaxVFOption("slp-max-vf", cl::init(0), cl::Hidden,
cl::desc("Maximum SLP vectorization factor (0=unlimited)"));
-static cl::opt<int>
-MaxStoreLookup("slp-max-store-lookup", cl::init(32), cl::Hidden,
- cl::desc("Maximum depth of the lookup for consecutive stores."));
-
/// Limits the size of scheduling regions in a block.
/// It avoid long compile times for _very_ large blocks where vector
/// instructions are spread over a wide range.
@@ -232,6 +228,17 @@ static bool isVectorLikeInstWithConstOps(Value *V) {
return isConstant(I->getOperand(2));
}
+#if !defined(NDEBUG)
+/// Print a short descriptor of the instruction bundle suitable for debug output.
+static std::string shortBundleName(ArrayRef<Value *> VL) {
+ std::string Result;
+ raw_string_ostream OS(Result);
+ OS << "n=" << VL.size() << " [" << *VL.front() << ", ..]";
+ OS.flush();
+ return Result;
+}
+#endif
+
/// \returns true if all of the instructions in \p VL are in the same block or
/// false otherwise.
static bool allSameBlock(ArrayRef<Value *> VL) {
@@ -429,26 +436,6 @@ static SmallBitVector isUndefVector(const Value *V,
/// i32 6>
/// %2 = mul <4 x i8> %1, %1
/// ret <4 x i8> %2
-/// We convert this initially to something like:
-/// %x0 = extractelement <4 x i8> %x, i32 0
-/// %x3 = extractelement <4 x i8> %x, i32 3
-/// %y1 = extractelement <4 x i8> %y, i32 1
-/// %y2 = extractelement <4 x i8> %y, i32 2
-/// %1 = insertelement <4 x i8> poison, i8 %x0, i32 0
-/// %2 = insertelement <4 x i8> %1, i8 %x3, i32 1
-/// %3 = insertelement <4 x i8> %2, i8 %y1, i32 2
-/// %4 = insertelement <4 x i8> %3, i8 %y2, i32 3
-/// %5 = mul <4 x i8> %4, %4
-/// %6 = extractelement <4 x i8> %5, i32 0
-/// %ins1 = insertelement <4 x i8> poison, i8 %6, i32 0
-/// %7 = extractelement <4 x i8> %5, i32 1
-/// %ins2 = insertelement <4 x i8> %ins1, i8 %7, i32 1
-/// %8 = extractelement <4 x i8> %5, i32 2
-/// %ins3 = insertelement <4 x i8> %ins2, i8 %8, i32 2
-/// %9 = extractelement <4 x i8> %5, i32 3
-/// %ins4 = insertelement <4 x i8> %ins3, i8 %9, i32 3
-/// ret <4 x i8> %ins4
-/// InstCombiner transforms this into a shuffle and vector mul
/// Mask will return the Shuffle Mask equivalent to the extracted elements.
/// TODO: Can we split off and reuse the shuffle mask detection from
/// ShuffleVectorInst/getShuffleCost?
@@ -539,117 +526,6 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) {
return *EI->idx_begin();
}
-/// Tries to find extractelement instructions with constant indices from fixed
-/// vector type and gather such instructions into a bunch, which highly likely
-/// might be detected as a shuffle of 1 or 2 input vectors. If this attempt was
-/// successful, the matched scalars are replaced by poison values in \p VL for
-/// future analysis.
-static std::optional<TTI::ShuffleKind>
-tryToGatherExtractElements(SmallVectorImpl<Value *> &VL,
- SmallVectorImpl<int> &Mask) {
- // Scan list of gathered scalars for extractelements that can be represented
- // as shuffles.
- MapVector<Value *, SmallVector<int>> VectorOpToIdx;
- SmallVector<int> UndefVectorExtracts;
- for (int I = 0, E = VL.size(); I < E; ++I) {
- auto *EI = dyn_cast<ExtractElementInst>(VL[I]);
- if (!EI) {
- if (isa<UndefValue>(VL[I]))
- UndefVectorExtracts.push_back(I);
- continue;
- }
- auto *VecTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType());
- if (!VecTy || !isa<ConstantInt, UndefValue>(EI->getIndexOperand()))
- continue;
- std::optional<unsigned> Idx = getExtractIndex(EI);
- // Undefined index.
- if (!Idx) {
- UndefVectorExtracts.push_back(I);
- continue;
- }
- SmallBitVector ExtractMask(VecTy->getNumElements(), true);
- ExtractMask.reset(*Idx);
- if (isUndefVector(EI->getVectorOperand(), ExtractMask).all()) {
- UndefVectorExtracts.push_back(I);
- continue;
- }
- VectorOpToIdx[EI->getVectorOperand()].push_back(I);
- }
- // Sort the vector operands by the maximum number of uses in extractelements.
- MapVector<unsigned, SmallVector<Value *>> VFToVector;
- for (const auto &Data : VectorOpToIdx)
- VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()]
- .push_back(Data.first);
- for (auto &Data : VFToVector) {
- stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) {
- return VectorOpToIdx.find(V1)->second.size() >
- VectorOpToIdx.find(V2)->second.size();
- });
- }
- // Find the best pair of the vectors with the same number of elements or a
- // single vector.
- const int UndefSz = UndefVectorExtracts.size();
- unsigned SingleMax = 0;
- Value *SingleVec = nullptr;
- unsigned PairMax = 0;
- std::pair<Value *, Value *> PairVec(nullptr, nullptr);
- for (auto &Data : VFToVector) {
- Value *V1 = Data.second.front();
- if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) {
- SingleMax = VectorOpToIdx[V1].size() + UndefSz;
- SingleVec = V1;
- }
- Value *V2 = nullptr;
- if (Data.second.size() > 1)
- V2 = *std::next(Data.second.begin());
- if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() +
- UndefSz) {
- PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz;
- PairVec = std::make_pair(V1, V2);
- }
- }
- if (SingleMax == 0 && PairMax == 0 && UndefSz == 0)
- return std::nullopt;
- // Check if better to perform a shuffle of 2 vectors or just of a single
- // vector.
- SmallVector<Value *> SavedVL(VL.begin(), VL.end());
- SmallVector<Value *> GatheredExtracts(
- VL.size(), PoisonValue::get(VL.front()->getType()));
- if (SingleMax >= PairMax && SingleMax) {
- for (int Idx : VectorOpToIdx[SingleVec])
- std::swap(GatheredExtracts[Idx], VL[Idx]);
- } else {
- for (Value *V : {PairVec.first, PairVec.second})
- for (int Idx : VectorOpToIdx[V])
- std::swap(GatheredExtracts[Idx], VL[Idx]);
- }
- // Add extracts from undefs too.
- for (int Idx : UndefVectorExtracts)
- std::swap(GatheredExtracts[Idx], VL[Idx]);
- // Check that gather of extractelements can be represented as just a
- // shuffle of a single/two vectors the scalars are extracted from.
- std::optional<TTI::ShuffleKind> Res =
- isFixedVectorShuffle(GatheredExtracts, Mask);
- if (!Res) {
- // TODO: try to check other subsets if possible.
- // Restore the original VL if attempt was not successful.
- VL.swap(SavedVL);
- return std::nullopt;
- }
- // Restore unused scalars from mask, if some of the extractelements were not
- // selected for shuffle.
- for (int I = 0, E = GatheredExtracts.size(); I < E; ++I) {
- auto *EI = dyn_cast<ExtractElementInst>(VL[I]);
- if (!EI || !isa<FixedVectorType>(EI->getVectorOperandType()) ||
- !isa<ConstantInt, UndefValue>(EI->getIndexOperand()) ||
- is_contained(UndefVectorExtracts, I))
- continue;
- if (Mask[I] == PoisonMaskElem && !isa<PoisonValue>(GatheredExtracts[I]))
- std::swap(VL[I], GatheredExtracts[I]);
- }
- return Res;
-}
-
namespace {
/// Main data required for vectorization of instructions.
@@ -695,7 +571,7 @@ static Value *isOneOf(const InstructionsState &S, Value *Op) {
return S.OpValue;
}
-/// \returns true if \p Opcode is allowed as part of of the main/alternate
+/// \returns true if \p Opcode is allowed as part of the main/alternate
/// instruction for SLP vectorization.
///
/// Example of unsupported opcode is SDIV that can potentially cause UB if the
@@ -889,18 +765,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
/// \returns true if all of the values in \p VL have the same type or false
/// otherwise.
static bool allSameType(ArrayRef<Value *> VL) {
- Type *Ty = VL[0]->getType();
- for (int i = 1, e = VL.size(); i < e; i++)
- if (VL[i]->getType() != Ty)
- return false;
-
- return true;
+ Type *Ty = VL.front()->getType();
+ return all_of(VL.drop_front(), [&](Value *V) { return V->getType() == Ty; });
}
/// \returns True if in-tree use also needs extract. This refers to
/// possible scalar operand in vectorized instruction.
-static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
- TargetLibraryInfo *TLI) {
+static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
+ TargetLibraryInfo *TLI) {
unsigned Opcode = UserInst->getOpcode();
switch (Opcode) {
case Instruction::Load: {
@@ -914,11 +786,10 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
case Instruction::Call: {
CallInst *CI = cast<CallInst>(UserInst);
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
- for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) {
- if (isVectorIntrinsicWithScalarOpAtArg(ID, i))
- return (CI->getArgOperand(i) == Scalar);
- }
- [[fallthrough]];
+ return any_of(enumerate(CI->args()), [&](auto &&Arg) {
+ return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index()) &&
+ Arg.value().get() == Scalar;
+ });
}
default:
return false;
@@ -1181,6 +1052,7 @@ public:
void deleteTree() {
VectorizableTree.clear();
ScalarToTreeEntry.clear();
+ MultiNodeScalars.clear();
MustGather.clear();
EntryToLastInstruction.clear();
ExternalUses.clear();
@@ -1273,7 +1145,7 @@ public:
/// {{{i16, i16}, {i16, i16}}, {{i16, i16}, {i16, i16}}} and so on.
///
/// \returns number of elements in vector if isomorphism exists, 0 otherwise.
- unsigned canMapToVector(Type *T, const DataLayout &DL) const;
+ unsigned canMapToVector(Type *T) const;
/// \returns True if the VectorizableTree is both tiny and not fully
/// vectorizable. We do not vectorize such trees.
@@ -1324,6 +1196,9 @@ public:
}
LLVM_DUMP_METHOD void dump() const { dump(dbgs()); }
#endif
+ bool operator == (const EdgeInfo &Other) const {
+ return UserTE == Other.UserTE && EdgeIdx == Other.EdgeIdx;
+ }
};
/// A helper class used for scoring candidates for two consecutive lanes.
@@ -1764,7 +1639,7 @@ public:
auto *IdxLaneI = dyn_cast<Instruction>(IdxLaneV);
if (!IdxLaneI || !isa<Instruction>(OpIdxLaneV))
return 0;
- return R.areAllUsersVectorized(IdxLaneI, std::nullopt)
+ return R.areAllUsersVectorized(IdxLaneI)
? LookAheadHeuristics::ScoreAllUserVectorized
: 0;
}
@@ -1941,7 +1816,7 @@ public:
HashMap[NumFreeOpsHash.Hash] = std::make_pair(1, Lane);
} else if (NumFreeOpsHash.NumOfAPOs == Min &&
NumFreeOpsHash.NumOpsWithSameOpcodeParent == SameOpNumber) {
- auto It = HashMap.find(NumFreeOpsHash.Hash);
+ auto *It = HashMap.find(NumFreeOpsHash.Hash);
if (It == HashMap.end())
HashMap[NumFreeOpsHash.Hash] = std::make_pair(1, Lane);
else
@@ -2203,7 +2078,7 @@ public:
for (int Pass = 0; Pass != 2; ++Pass) {
// Check if no need to reorder operands since they're are perfect or
// shuffled diamond match.
- // Need to to do it to avoid extra external use cost counting for
+ // Need to do it to avoid extra external use cost counting for
// shuffled matches, which may cause regressions.
if (SkipReordering())
break;
@@ -2388,6 +2263,18 @@ public:
~BoUpSLP();
private:
+ /// Determine if a vectorized value \p V in can be demoted to
+ /// a smaller type with a truncation. We collect the values that will be
+ /// demoted in ToDemote and additional roots that require investigating in
+ /// Roots.
+ /// \param DemotedConsts list of Instruction/OperandIndex pairs that are
+ /// constant and to be demoted. Required to correctly identify constant nodes
+ /// to be demoted.
+ bool collectValuesToDemote(
+ Value *V, SmallVectorImpl<Value *> &ToDemote,
+ DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts,
+ SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const;
+
/// Check if the operands on the edges \p Edges of the \p UserTE allows
/// reordering (i.e. the operands can be reordered because they have only one
/// user and reordarable).
@@ -2410,12 +2297,25 @@ private:
TreeEntry *getVectorizedOperand(TreeEntry *UserTE, unsigned OpIdx) {
ArrayRef<Value *> VL = UserTE->getOperand(OpIdx);
TreeEntry *TE = nullptr;
- const auto *It = find_if(VL, [this, &TE](Value *V) {
+ const auto *It = find_if(VL, [&](Value *V) {
TE = getTreeEntry(V);
- return TE;
+ if (TE && is_contained(TE->UserTreeIndices, EdgeInfo(UserTE, OpIdx)))
+ return true;
+ auto It = MultiNodeScalars.find(V);
+ if (It != MultiNodeScalars.end()) {
+ for (TreeEntry *E : It->second) {
+ if (is_contained(E->UserTreeIndices, EdgeInfo(UserTE, OpIdx))) {
+ TE = E;
+ return true;
+ }
+ }
+ }
+ return false;
});
- if (It != VL.end() && TE->isSame(VL))
+ if (It != VL.end()) {
+ assert(TE->isSame(VL) && "Expected same scalars.");
return TE;
+ }
return nullptr;
}
@@ -2428,13 +2328,16 @@ private:
}
/// Checks if all users of \p I are the part of the vectorization tree.
- bool areAllUsersVectorized(Instruction *I,
- ArrayRef<Value *> VectorizedVals) const;
+ bool areAllUsersVectorized(
+ Instruction *I,
+ const SmallDenseSet<Value *> *VectorizedVals = nullptr) const;
/// Return information about the vector formed for the specified index
/// of a vector of (the same) instruction.
- TargetTransformInfo::OperandValueInfo getOperandInfo(ArrayRef<Value *> VL,
- unsigned OpIdx);
+ TargetTransformInfo::OperandValueInfo getOperandInfo(ArrayRef<Value *> Ops);
+
+ /// \ returns the graph entry for the \p Idx operand of the \p E entry.
+ const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const;
/// \returns the cost of the vectorizable entry.
InstructionCost getEntryCost(const TreeEntry *E,
@@ -2450,15 +2353,22 @@ private:
/// vector) and sets \p CurrentOrder to the identity permutation; otherwise
/// returns false, setting \p CurrentOrder to either an empty vector or a
/// non-identity permutation that allows to reuse extract instructions.
+ /// \param ResizeAllowed indicates whether it is allowed to handle subvector
+ /// extract order.
bool canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
- SmallVectorImpl<unsigned> &CurrentOrder) const;
+ SmallVectorImpl<unsigned> &CurrentOrder,
+ bool ResizeAllowed = false) const;
/// Vectorize a single entry in the tree.
- Value *vectorizeTree(TreeEntry *E);
+ /// \param PostponedPHIs true, if need to postpone emission of phi nodes to
+ /// avoid issues with def-use order.
+ Value *vectorizeTree(TreeEntry *E, bool PostponedPHIs);
/// Vectorize a single entry in the tree, the \p Idx-th operand of the entry
/// \p E.
- Value *vectorizeOperand(TreeEntry *E, unsigned NodeIdx);
+ /// \param PostponedPHIs true, if need to postpone emission of phi nodes to
+ /// avoid issues with def-use order.
+ Value *vectorizeOperand(TreeEntry *E, unsigned NodeIdx, bool PostponedPHIs);
/// Create a new vector from a list of scalar values. Produces a sequence
/// which exploits values reused across lanes, and arranges the inserts
@@ -2477,17 +2387,50 @@ private:
/// instruction in the list).
Instruction &getLastInstructionInBundle(const TreeEntry *E);
- /// Checks if the gathered \p VL can be represented as shuffle(s) of previous
- /// tree entries.
+ /// Tries to find extractelement instructions with constant indices from fixed
+ /// vector type and gather such instructions into a bunch, which highly likely
+ /// might be detected as a shuffle of 1 or 2 input vectors. If this attempt
+ /// was successful, the matched scalars are replaced by poison values in \p VL
+ /// for future analysis.
+ std::optional<TargetTransformInfo::ShuffleKind>
+ tryToGatherSingleRegisterExtractElements(MutableArrayRef<Value *> VL,
+ SmallVectorImpl<int> &Mask) const;
+
+ /// Tries to find extractelement instructions with constant indices from fixed
+ /// vector type and gather such instructions into a bunch, which highly likely
+ /// might be detected as a shuffle of 1 or 2 input vectors. If this attempt
+ /// was successful, the matched scalars are replaced by poison values in \p VL
+ /// for future analysis.
+ SmallVector<std::optional<TargetTransformInfo::ShuffleKind>>
+ tryToGatherExtractElements(SmallVectorImpl<Value *> &VL,
+ SmallVectorImpl<int> &Mask,
+ unsigned NumParts) const;
+
+ /// Checks if the gathered \p VL can be represented as a single register
+ /// shuffle(s) of previous tree entries.
/// \param TE Tree entry checked for permutation.
/// \param VL List of scalars (a subset of the TE scalar), checked for
- /// permutations.
+ /// permutations. Must form single-register vector.
/// \returns ShuffleKind, if gathered values can be represented as shuffles of
- /// previous tree entries. \p Mask is filled with the shuffle mask.
+ /// previous tree entries. \p Part of \p Mask is filled with the shuffle mask.
std::optional<TargetTransformInfo::ShuffleKind>
- isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
- SmallVectorImpl<int> &Mask,
- SmallVectorImpl<const TreeEntry *> &Entries);
+ isGatherShuffledSingleRegisterEntry(
+ const TreeEntry *TE, ArrayRef<Value *> VL, MutableArrayRef<int> Mask,
+ SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part);
+
+ /// Checks if the gathered \p VL can be represented as multi-register
+ /// shuffle(s) of previous tree entries.
+ /// \param TE Tree entry checked for permutation.
+ /// \param VL List of scalars (a subset of the TE scalar), checked for
+ /// permutations.
+ /// \returns per-register series of ShuffleKind, if gathered values can be
+ /// represented as shuffles of previous tree entries. \p Mask is filled with
+ /// the shuffle mask (also on per-register base).
+ SmallVector<std::optional<TargetTransformInfo::ShuffleKind>>
+ isGatherShuffledEntry(
+ const TreeEntry *TE, ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask,
+ SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries,
+ unsigned NumParts);
/// \returns the scalarization cost for this list of values. Assuming that
/// this subtree gets vectorized, we may need to extract the values from the
@@ -2517,14 +2460,14 @@ private:
/// Helper for `findExternalStoreUsersReorderIndices()`. It iterates over the
/// users of \p TE and collects the stores. It returns the map from the store
/// pointers to the collected stores.
- DenseMap<Value *, SmallVector<StoreInst *, 4>>
+ DenseMap<Value *, SmallVector<StoreInst *>>
collectUserStores(const BoUpSLP::TreeEntry *TE) const;
/// Helper for `findExternalStoreUsersReorderIndices()`. It checks if the
- /// stores in \p StoresVec can form a vector instruction. If so it returns true
- /// and populates \p ReorderIndices with the shuffle indices of the the stores
- /// when compared to the sorted vector.
- bool canFormVector(const SmallVector<StoreInst *, 4> &StoresVec,
+ /// stores in \p StoresVec can form a vector instruction. If so it returns
+ /// true and populates \p ReorderIndices with the shuffle indices of the
+ /// stores when compared to the sorted vector.
+ bool canFormVector(ArrayRef<StoreInst *> StoresVec,
OrdersType &ReorderIndices) const;
/// Iterates through the users of \p TE, looking for scalar stores that can be
@@ -2621,10 +2564,18 @@ private:
/// The Scalars are vectorized into this value. It is initialized to Null.
WeakTrackingVH VectorizedValue = nullptr;
+ /// New vector phi instructions emitted for the vectorized phi nodes.
+ PHINode *PHI = nullptr;
+
/// Do we need to gather this sequence or vectorize it
/// (either with vector instruction or with scatter/gather
/// intrinsics for store/load)?
- enum EntryState { Vectorize, ScatterVectorize, NeedToGather };
+ enum EntryState {
+ Vectorize,
+ ScatterVectorize,
+ PossibleStridedVectorize,
+ NeedToGather
+ };
EntryState State;
/// Does this sequence require some shuffling?
@@ -2772,6 +2723,14 @@ private:
return FoundLane;
}
+ /// Build a shuffle mask for graph entry which represents a merge of main
+ /// and alternate operations.
+ void
+ buildAltOpShuffleMask(const function_ref<bool(Instruction *)> IsAltOp,
+ SmallVectorImpl<int> &Mask,
+ SmallVectorImpl<Value *> *OpScalars = nullptr,
+ SmallVectorImpl<Value *> *AltScalars = nullptr) const;
+
#ifndef NDEBUG
/// Debug printer.
LLVM_DUMP_METHOD void dump() const {
@@ -2792,6 +2751,9 @@ private:
case ScatterVectorize:
dbgs() << "ScatterVectorize\n";
break;
+ case PossibleStridedVectorize:
+ dbgs() << "PossibleStridedVectorize\n";
+ break;
case NeedToGather:
dbgs() << "NeedToGather\n";
break;
@@ -2892,7 +2854,14 @@ private:
}
if (Last->State != TreeEntry::NeedToGather) {
for (Value *V : VL) {
- assert(!getTreeEntry(V) && "Scalar already in tree!");
+ const TreeEntry *TE = getTreeEntry(V);
+ assert((!TE || TE == Last || doesNotNeedToBeScheduled(V)) &&
+ "Scalar already in tree!");
+ if (TE) {
+ if (TE != Last)
+ MultiNodeScalars.try_emplace(V).first->getSecond().push_back(Last);
+ continue;
+ }
ScalarToTreeEntry[V] = Last;
}
// Update the scheduler bundle to point to this TreeEntry.
@@ -2905,7 +2874,8 @@ private:
for (Value *V : VL) {
if (doesNotNeedToBeScheduled(V))
continue;
- assert(BundleMember && "Unexpected end of bundle.");
+ if (!BundleMember)
+ continue;
BundleMember->TE = Last;
BundleMember = BundleMember->NextInBundle;
}
@@ -2913,6 +2883,10 @@ private:
assert(!BundleMember && "Bundle and VL out of sync");
} else {
MustGather.insert(VL.begin(), VL.end());
+ // Build a map for gathered scalars to the nodes where they are used.
+ for (Value *V : VL)
+ if (!isConstant(V))
+ ValueToGatherNodes.try_emplace(V).first->getSecond().insert(Last);
}
if (UserTreeIdx.UserTE)
@@ -2950,6 +2924,10 @@ private:
/// Maps a specific scalar to its tree entry.
SmallDenseMap<Value *, TreeEntry *> ScalarToTreeEntry;
+ /// List of scalars, used in several vectorize nodes, and the list of the
+ /// nodes.
+ SmallDenseMap<Value *, SmallVector<TreeEntry *>> MultiNodeScalars;
+
/// Maps a value to the proposed vectorizable size.
SmallDenseMap<Value *, unsigned> InstrElementSize;
@@ -2995,25 +2973,25 @@ private:
/// is invariant in the calling loop.
bool isAliased(const MemoryLocation &Loc1, Instruction *Inst1,
Instruction *Inst2) {
+ if (!Loc1.Ptr || !isSimple(Inst1) || !isSimple(Inst2))
+ return true;
// First check if the result is already in the cache.
- AliasCacheKey key = std::make_pair(Inst1, Inst2);
- std::optional<bool> &result = AliasCache[key];
- if (result) {
- return *result;
- }
- bool aliased = true;
- if (Loc1.Ptr && isSimple(Inst1))
- aliased = isModOrRefSet(BatchAA.getModRefInfo(Inst2, Loc1));
+ AliasCacheKey Key = std::make_pair(Inst1, Inst2);
+ auto It = AliasCache.find(Key);
+ if (It != AliasCache.end())
+ return It->second;
+ bool Aliased = isModOrRefSet(BatchAA.getModRefInfo(Inst2, Loc1));
// Store the result in the cache.
- result = aliased;
- return aliased;
+ AliasCache.try_emplace(Key, Aliased);
+ AliasCache.try_emplace(std::make_pair(Inst2, Inst1), Aliased);
+ return Aliased;
}
using AliasCacheKey = std::pair<Instruction *, Instruction *>;
/// Cache for alias results.
/// TODO: consider moving this to the AliasAnalysis itself.
- DenseMap<AliasCacheKey, std::optional<bool>> AliasCache;
+ DenseMap<AliasCacheKey, bool> AliasCache;
// Cache for pointerMayBeCaptured calls inside AA. This is preserved
// globally through SLP because we don't perform any action which
@@ -3047,7 +3025,7 @@ private:
SetVector<Instruction *> GatherShuffleExtractSeq;
/// A list of blocks that we are going to CSE.
- SetVector<BasicBlock *> CSEBlocks;
+ DenseSet<BasicBlock *> CSEBlocks;
/// Contains all scheduling relevant data for an instruction.
/// A ScheduleData either represents a single instruction or a member of an
@@ -3497,7 +3475,7 @@ private:
BasicBlock *BB;
/// Simple memory allocation for ScheduleData.
- std::vector<std::unique_ptr<ScheduleData[]>> ScheduleDataChunks;
+ SmallVector<std::unique_ptr<ScheduleData[]>> ScheduleDataChunks;
/// The size of a ScheduleData array in ScheduleDataChunks.
int ChunkSize;
@@ -3607,7 +3585,7 @@ private:
/// where "width" indicates the minimum bit width and "signed" is True if the
/// value must be signed-extended, rather than zero-extended, back to its
/// original width.
- MapVector<Value *, std::pair<uint64_t, bool>> MinBWs;
+ DenseMap<const TreeEntry *, std::pair<uint64_t, bool>> MinBWs;
};
} // end namespace slpvectorizer
@@ -3676,7 +3654,7 @@ template <> struct GraphTraits<BoUpSLP *> {
template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits {
using TreeEntry = BoUpSLP::TreeEntry;
- DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+ DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
std::string getNodeLabel(const TreeEntry *Entry, const BoUpSLP *R) {
std::string Str;
@@ -3699,7 +3677,8 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits {
const BoUpSLP *) {
if (Entry->State == TreeEntry::NeedToGather)
return "color=red";
- if (Entry->State == TreeEntry::ScatterVectorize)
+ if (Entry->State == TreeEntry::ScatterVectorize ||
+ Entry->State == TreeEntry::PossibleStridedVectorize)
return "color=blue";
return "";
}
@@ -3761,7 +3740,7 @@ static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) {
inversePermutation(Order, MaskOrder);
}
reorderReuses(MaskOrder, Mask);
- if (ShuffleVectorInst::isIdentityMask(MaskOrder)) {
+ if (ShuffleVectorInst::isIdentityMask(MaskOrder, MaskOrder.size())) {
Order.clear();
return;
}
@@ -3779,7 +3758,40 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
OrdersType CurrentOrder(NumScalars, NumScalars);
SmallVector<int> Positions;
SmallBitVector UsedPositions(NumScalars);
- const TreeEntry *STE = nullptr;
+ DenseMap<const TreeEntry *, unsigned> UsedEntries;
+ DenseMap<Value *, std::pair<const TreeEntry *, unsigned>> ValueToEntryPos;
+ for (Value *V : TE.Scalars) {
+ if (!isa<LoadInst, ExtractElementInst, ExtractValueInst>(V))
+ continue;
+ const auto *LocalSTE = getTreeEntry(V);
+ if (!LocalSTE)
+ continue;
+ unsigned Lane =
+ std::distance(LocalSTE->Scalars.begin(), find(LocalSTE->Scalars, V));
+ if (Lane >= NumScalars)
+ continue;
+ ++UsedEntries.try_emplace(LocalSTE, 0).first->getSecond();
+ ValueToEntryPos.try_emplace(V, LocalSTE, Lane);
+ }
+ if (UsedEntries.empty())
+ return std::nullopt;
+ const TreeEntry &BestSTE =
+ *std::max_element(UsedEntries.begin(), UsedEntries.end(),
+ [](const std::pair<const TreeEntry *, unsigned> &P1,
+ const std::pair<const TreeEntry *, unsigned> &P2) {
+ return P1.second < P2.second;
+ })
+ ->first;
+ UsedEntries.erase(&BestSTE);
+ const TreeEntry *SecondBestSTE = nullptr;
+ if (!UsedEntries.empty())
+ SecondBestSTE =
+ std::max_element(UsedEntries.begin(), UsedEntries.end(),
+ [](const std::pair<const TreeEntry *, unsigned> &P1,
+ const std::pair<const TreeEntry *, unsigned> &P2) {
+ return P1.second < P2.second;
+ })
+ ->first;
// Try to find all gathered scalars that are gets vectorized in other
// vectorize node. Here we can have only one single tree vector node to
// correctly identify order of the gathered scalars.
@@ -3787,58 +3799,56 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
Value *V = TE.Scalars[I];
if (!isa<LoadInst, ExtractElementInst, ExtractValueInst>(V))
continue;
- if (const auto *LocalSTE = getTreeEntry(V)) {
- if (!STE)
- STE = LocalSTE;
- else if (STE != LocalSTE)
- // Take the order only from the single vector node.
- return std::nullopt;
- unsigned Lane =
- std::distance(STE->Scalars.begin(), find(STE->Scalars, V));
- if (Lane >= NumScalars)
- return std::nullopt;
- if (CurrentOrder[Lane] != NumScalars) {
- if (Lane != I)
- continue;
- UsedPositions.reset(CurrentOrder[Lane]);
- }
- // The partial identity (where only some elements of the gather node are
- // in the identity order) is good.
- CurrentOrder[Lane] = I;
- UsedPositions.set(I);
+ const auto [LocalSTE, Lane] = ValueToEntryPos.lookup(V);
+ if (!LocalSTE || (LocalSTE != &BestSTE && LocalSTE != SecondBestSTE))
+ continue;
+ if (CurrentOrder[Lane] != NumScalars) {
+ if ((CurrentOrder[Lane] >= BestSTE.Scalars.size() ||
+ BestSTE.Scalars[CurrentOrder[Lane]] == V) &&
+ (Lane != I || LocalSTE == SecondBestSTE))
+ continue;
+ UsedPositions.reset(CurrentOrder[Lane]);
}
+ // The partial identity (where only some elements of the gather node are
+ // in the identity order) is good.
+ CurrentOrder[Lane] = I;
+ UsedPositions.set(I);
}
// Need to keep the order if we have a vector entry and at least 2 scalars or
// the vectorized entry has just 2 scalars.
- if (STE && (UsedPositions.count() > 1 || STE->Scalars.size() == 2)) {
- auto &&IsIdentityOrder = [NumScalars](ArrayRef<unsigned> CurrentOrder) {
- for (unsigned I = 0; I < NumScalars; ++I)
- if (CurrentOrder[I] != I && CurrentOrder[I] != NumScalars)
- return false;
- return true;
- };
- if (IsIdentityOrder(CurrentOrder))
- return OrdersType();
- auto *It = CurrentOrder.begin();
- for (unsigned I = 0; I < NumScalars;) {
- if (UsedPositions.test(I)) {
- ++I;
- continue;
- }
- if (*It == NumScalars) {
- *It = I;
- ++I;
- }
- ++It;
+ if (BestSTE.Scalars.size() != 2 && UsedPositions.count() <= 1)
+ return std::nullopt;
+ auto IsIdentityOrder = [&](ArrayRef<unsigned> CurrentOrder) {
+ for (unsigned I = 0; I < NumScalars; ++I)
+ if (CurrentOrder[I] != I && CurrentOrder[I] != NumScalars)
+ return false;
+ return true;
+ };
+ if (IsIdentityOrder(CurrentOrder))
+ return OrdersType();
+ auto *It = CurrentOrder.begin();
+ for (unsigned I = 0; I < NumScalars;) {
+ if (UsedPositions.test(I)) {
+ ++I;
+ continue;
}
- return std::move(CurrentOrder);
+ if (*It == NumScalars) {
+ *It = I;
+ ++I;
+ }
+ ++It;
}
- return std::nullopt;
+ return std::move(CurrentOrder);
}
namespace {
/// Tracks the state we can represent the loads in the given sequence.
-enum class LoadsState { Gather, Vectorize, ScatterVectorize };
+enum class LoadsState {
+ Gather,
+ Vectorize,
+ ScatterVectorize,
+ PossibleStridedVectorize
+};
} // anonymous namespace
static bool arePointersCompatible(Value *Ptr1, Value *Ptr2,
@@ -3898,6 +3908,7 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
if (IsSorted || all_of(PointerOps, [&](Value *P) {
return arePointersCompatible(P, PointerOps.front(), TLI);
})) {
+ bool IsPossibleStrided = false;
if (IsSorted) {
Value *Ptr0;
Value *PtrN;
@@ -3913,6 +3924,8 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
// Check that the sorted loads are consecutive.
if (static_cast<unsigned>(*Diff) == VL.size() - 1)
return LoadsState::Vectorize;
+ // Simple check if not a strided access - clear order.
+ IsPossibleStrided = *Diff % (VL.size() - 1) == 0;
}
// TODO: need to improve analysis of the pointers, if not all of them are
// GEPs or have > 2 operands, we end up with a gather node, which just
@@ -3934,7 +3947,8 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
auto *VecTy = FixedVectorType::get(ScalarTy, VL.size());
if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) &&
!TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment))
- return LoadsState::ScatterVectorize;
+ return IsPossibleStrided ? LoadsState::PossibleStridedVectorize
+ : LoadsState::ScatterVectorize;
}
}
@@ -4050,7 +4064,8 @@ static bool areTwoInsertFromSameBuildVector(
// Go through the vector operand of insertelement instructions trying to find
// either VU as the original vector for IE2 or V as the original vector for
// IE1.
- SmallSet<int, 8> ReusedIdx;
+ SmallBitVector ReusedIdx(
+ cast<VectorType>(VU->getType())->getElementCount().getKnownMinValue());
bool IsReusedIdx = false;
do {
if (IE2 == VU && !IE1)
@@ -4058,16 +4073,18 @@ static bool areTwoInsertFromSameBuildVector(
if (IE1 == V && !IE2)
return V->hasOneUse();
if (IE1 && IE1 != V) {
- IsReusedIdx |=
- !ReusedIdx.insert(getInsertIndex(IE1).value_or(*Idx2)).second;
+ unsigned Idx1 = getInsertIndex(IE1).value_or(*Idx2);
+ IsReusedIdx |= ReusedIdx.test(Idx1);
+ ReusedIdx.set(Idx1);
if ((IE1 != VU && !IE1->hasOneUse()) || IsReusedIdx)
IE1 = nullptr;
else
IE1 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE1));
}
if (IE2 && IE2 != VU) {
- IsReusedIdx |=
- !ReusedIdx.insert(getInsertIndex(IE2).value_or(*Idx1)).second;
+ unsigned Idx2 = getInsertIndex(IE2).value_or(*Idx1);
+ IsReusedIdx |= ReusedIdx.test(Idx2);
+ ReusedIdx.set(Idx2);
if ((IE2 != V && !IE2->hasOneUse()) || IsReusedIdx)
IE2 = nullptr;
else
@@ -4135,13 +4152,16 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
return std::nullopt; // No need to reorder.
return std::move(ResOrder);
}
- if (TE.State == TreeEntry::Vectorize &&
+ if ((TE.State == TreeEntry::Vectorize ||
+ TE.State == TreeEntry::PossibleStridedVectorize) &&
(isa<LoadInst, ExtractElementInst, ExtractValueInst>(TE.getMainOp()) ||
(TopToBottom && isa<StoreInst, InsertElementInst>(TE.getMainOp()))) &&
!TE.isAltShuffle())
return TE.ReorderIndices;
if (TE.State == TreeEntry::Vectorize && TE.getOpcode() == Instruction::PHI) {
- auto PHICompare = [](llvm::Value *V1, llvm::Value *V2) {
+ auto PHICompare = [&](unsigned I1, unsigned I2) {
+ Value *V1 = TE.Scalars[I1];
+ Value *V2 = TE.Scalars[I2];
if (V1 == V2)
return false;
if (!V1->hasOneUse() || !V2->hasOneUse())
@@ -4180,14 +4200,13 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
};
if (!TE.ReorderIndices.empty())
return TE.ReorderIndices;
- DenseMap<Value *, unsigned> PhiToId;
- SmallVector<Value *, 4> Phis;
+ DenseMap<unsigned, unsigned> PhiToId;
+ SmallVector<unsigned> Phis(TE.Scalars.size());
+ std::iota(Phis.begin(), Phis.end(), 0);
OrdersType ResOrder(TE.Scalars.size());
- for (unsigned Id = 0, Sz = TE.Scalars.size(); Id < Sz; ++Id) {
- PhiToId[TE.Scalars[Id]] = Id;
- Phis.push_back(TE.Scalars[Id]);
- }
- llvm::stable_sort(Phis, PHICompare);
+ for (unsigned Id = 0, Sz = TE.Scalars.size(); Id < Sz; ++Id)
+ PhiToId[Id] = Id;
+ stable_sort(Phis, PHICompare);
for (unsigned Id = 0, Sz = Phis.size(); Id < Sz; ++Id)
ResOrder[Id] = PhiToId[Phis[Id]];
if (IsIdentityOrder(ResOrder))
@@ -4214,7 +4233,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
// Check that gather of extractelements can be represented as
// just a shuffle of a single vector.
OrdersType CurrentOrder;
- bool Reuse = canReuseExtract(TE.Scalars, TE.getMainOp(), CurrentOrder);
+ bool Reuse = canReuseExtract(TE.Scalars, TE.getMainOp(), CurrentOrder,
+ /*ResizeAllowed=*/true);
if (Reuse || !CurrentOrder.empty()) {
if (!CurrentOrder.empty())
fixupOrderingIndices(CurrentOrder);
@@ -4270,7 +4290,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
static bool isRepeatedNonIdentityClusteredMask(ArrayRef<int> Mask,
unsigned Sz) {
ArrayRef<int> FirstCluster = Mask.slice(0, Sz);
- if (ShuffleVectorInst::isIdentityMask(FirstCluster))
+ if (ShuffleVectorInst::isIdentityMask(FirstCluster, Sz))
return false;
for (unsigned I = Sz, E = Mask.size(); I < E; I += Sz) {
ArrayRef<int> Cluster = Mask.slice(I, Sz);
@@ -4386,7 +4406,9 @@ void BoUpSLP::reorderTopToBottom() {
++Cnt;
}
VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get());
- if (TE->State != TreeEntry::Vectorize || !TE->ReuseShuffleIndices.empty())
+ if (!(TE->State == TreeEntry::Vectorize ||
+ TE->State == TreeEntry::PossibleStridedVectorize) ||
+ !TE->ReuseShuffleIndices.empty())
GathersToOrders.try_emplace(TE.get(), *CurrentOrder);
if (TE->State == TreeEntry::Vectorize &&
TE->getOpcode() == Instruction::PHI)
@@ -4409,6 +4431,9 @@ void BoUpSLP::reorderTopToBottom() {
MapVector<OrdersType, unsigned,
DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>>
OrdersUses;
+ // Last chance orders - scatter vectorize. Try to use their orders if no
+ // other orders or the order is counted already.
+ SmallVector<OrdersType> StridedVectorizeOrders;
SmallPtrSet<const TreeEntry *, 4> VisitedOps;
for (const TreeEntry *OpTE : OrderedEntries) {
// No need to reorder this nodes, still need to extend and to use shuffle,
@@ -4455,6 +4480,11 @@ void BoUpSLP::reorderTopToBottom() {
if (Order.empty())
continue;
}
+ // Postpone scatter orders.
+ if (OpTE->State == TreeEntry::PossibleStridedVectorize) {
+ StridedVectorizeOrders.push_back(Order);
+ continue;
+ }
// Stores actually store the mask, not the order, need to invert.
if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() &&
OpTE->getOpcode() == Instruction::Store && !Order.empty()) {
@@ -4472,8 +4502,21 @@ void BoUpSLP::reorderTopToBottom() {
}
}
// Set order of the user node.
- if (OrdersUses.empty())
- continue;
+ if (OrdersUses.empty()) {
+ if (StridedVectorizeOrders.empty())
+ continue;
+ // Add (potentially!) strided vectorize orders.
+ for (OrdersType &Order : StridedVectorizeOrders)
+ ++OrdersUses.insert(std::make_pair(Order, 0)).first->second;
+ } else {
+ // Account (potentially!) strided vectorize orders only if it was used
+ // already.
+ for (OrdersType &Order : StridedVectorizeOrders) {
+ auto *It = OrdersUses.find(Order);
+ if (It != OrdersUses.end())
+ ++It->second;
+ }
+ }
// Choose the most used order.
ArrayRef<unsigned> BestOrder = OrdersUses.front().first;
unsigned Cnt = OrdersUses.front().second;
@@ -4514,7 +4557,8 @@ void BoUpSLP::reorderTopToBottom() {
}
continue;
}
- if (TE->State == TreeEntry::Vectorize &&
+ if ((TE->State == TreeEntry::Vectorize ||
+ TE->State == TreeEntry::PossibleStridedVectorize) &&
isa<ExtractElementInst, ExtractValueInst, LoadInst, StoreInst,
InsertElementInst>(TE->getMainOp()) &&
!TE->isAltShuffle()) {
@@ -4555,6 +4599,10 @@ bool BoUpSLP::canReorderOperands(
}))
continue;
if (TreeEntry *TE = getVectorizedOperand(UserTE, I)) {
+ // FIXME: Do not reorder (possible!) strided vectorized nodes, they
+ // require reordering of the operands, which is not implemented yet.
+ if (TE->State == TreeEntry::PossibleStridedVectorize)
+ return false;
// Do not reorder if operand node is used by many user nodes.
if (any_of(TE->UserTreeIndices,
[UserTE](const EdgeInfo &EI) { return EI.UserTE != UserTE; }))
@@ -4567,7 +4615,8 @@ bool BoUpSLP::canReorderOperands(
// simply add to the list of gathered ops.
// If there are reused scalars, process this node as a regular vectorize
// node, just reorder reuses mask.
- if (TE->State != TreeEntry::Vectorize && TE->ReuseShuffleIndices.empty())
+ if (TE->State != TreeEntry::Vectorize &&
+ TE->ReuseShuffleIndices.empty() && TE->ReorderIndices.empty())
GatherOps.push_back(TE);
continue;
}
@@ -4602,18 +4651,19 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
// Currently the are vectorized loads,extracts without alternate operands +
// some gathering of extracts.
SmallVector<TreeEntry *> NonVectorized;
- for_each(VectorizableTree, [this, &OrderedEntries, &GathersToOrders,
- &NonVectorized](
- const std::unique_ptr<TreeEntry> &TE) {
- if (TE->State != TreeEntry::Vectorize)
+ for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree) {
+ if (TE->State != TreeEntry::Vectorize &&
+ TE->State != TreeEntry::PossibleStridedVectorize)
NonVectorized.push_back(TE.get());
if (std::optional<OrdersType> CurrentOrder =
getReorderingData(*TE, /*TopToBottom=*/false)) {
OrderedEntries.insert(TE.get());
- if (TE->State != TreeEntry::Vectorize || !TE->ReuseShuffleIndices.empty())
+ if (!(TE->State == TreeEntry::Vectorize ||
+ TE->State == TreeEntry::PossibleStridedVectorize) ||
+ !TE->ReuseShuffleIndices.empty())
GathersToOrders.try_emplace(TE.get(), *CurrentOrder);
}
- });
+ }
// 1. Propagate order to the graph nodes, which use only reordered nodes.
// I.e., if the node has operands, that are reordered, try to make at least
@@ -4627,6 +4677,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
SmallVector<TreeEntry *> Filtered;
for (TreeEntry *TE : OrderedEntries) {
if (!(TE->State == TreeEntry::Vectorize ||
+ TE->State == TreeEntry::PossibleStridedVectorize ||
(TE->State == TreeEntry::NeedToGather &&
GathersToOrders.count(TE))) ||
TE->UserTreeIndices.empty() || !TE->ReuseShuffleIndices.empty() ||
@@ -4649,8 +4700,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
}
}
// Erase filtered entries.
- for_each(Filtered,
- [&OrderedEntries](TreeEntry *TE) { OrderedEntries.remove(TE); });
+ for (TreeEntry *TE : Filtered)
+ OrderedEntries.remove(TE);
SmallVector<
std::pair<TreeEntry *, SmallVector<std::pair<unsigned, TreeEntry *>>>>
UsersVec(Users.begin(), Users.end());
@@ -4662,10 +4713,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
SmallVector<TreeEntry *> GatherOps;
if (!canReorderOperands(Data.first, Data.second, NonVectorized,
GatherOps)) {
- for_each(Data.second,
- [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) {
- OrderedEntries.remove(Op.second);
- });
+ for (const std::pair<unsigned, TreeEntry *> &Op : Data.second)
+ OrderedEntries.remove(Op.second);
continue;
}
// All operands are reordered and used only in this node - propagate the
@@ -4673,6 +4722,9 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
MapVector<OrdersType, unsigned,
DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>>
OrdersUses;
+ // Last chance orders - scatter vectorize. Try to use their orders if no
+ // other orders or the order is counted already.
+ SmallVector<std::pair<OrdersType, unsigned>> StridedVectorizeOrders;
// Do the analysis for each tree entry only once, otherwise the order of
// the same node my be considered several times, though might be not
// profitable.
@@ -4694,6 +4746,11 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
Data.second, [OpTE](const std::pair<unsigned, TreeEntry *> &P) {
return P.second == OpTE;
});
+ // Postpone scatter orders.
+ if (OpTE->State == TreeEntry::PossibleStridedVectorize) {
+ StridedVectorizeOrders.emplace_back(Order, NumOps);
+ continue;
+ }
// Stores actually store the mask, not the order, need to invert.
if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() &&
OpTE->getOpcode() == Instruction::Store && !Order.empty()) {
@@ -4754,11 +4811,27 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
}
// If no orders - skip current nodes and jump to the next one, if any.
if (OrdersUses.empty()) {
- for_each(Data.second,
- [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) {
- OrderedEntries.remove(Op.second);
- });
- continue;
+ if (StridedVectorizeOrders.empty() ||
+ (Data.first->ReorderIndices.empty() &&
+ Data.first->ReuseShuffleIndices.empty() &&
+ !(IgnoreReorder &&
+ Data.first == VectorizableTree.front().get()))) {
+ for (const std::pair<unsigned, TreeEntry *> &Op : Data.second)
+ OrderedEntries.remove(Op.second);
+ continue;
+ }
+ // Add (potentially!) strided vectorize orders.
+ for (std::pair<OrdersType, unsigned> &Pair : StridedVectorizeOrders)
+ OrdersUses.insert(std::make_pair(Pair.first, 0)).first->second +=
+ Pair.second;
+ } else {
+ // Account (potentially!) strided vectorize orders only if it was used
+ // already.
+ for (std::pair<OrdersType, unsigned> &Pair : StridedVectorizeOrders) {
+ auto *It = OrdersUses.find(Pair.first);
+ if (It != OrdersUses.end())
+ It->second += Pair.second;
+ }
}
// Choose the best order.
ArrayRef<unsigned> BestOrder = OrdersUses.front().first;
@@ -4771,10 +4844,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
}
// Set order of the user node (reordering of operands and user nodes).
if (BestOrder.empty()) {
- for_each(Data.second,
- [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) {
- OrderedEntries.remove(Op.second);
- });
+ for (const std::pair<unsigned, TreeEntry *> &Op : Data.second)
+ OrderedEntries.remove(Op.second);
continue;
}
// Erase operands from OrderedEntries list and adjust their orders.
@@ -4796,7 +4867,10 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
continue;
}
// Gathers are processed separately.
- if (TE->State != TreeEntry::Vectorize)
+ if (TE->State != TreeEntry::Vectorize &&
+ TE->State != TreeEntry::PossibleStridedVectorize &&
+ (TE->State != TreeEntry::ScatterVectorize ||
+ TE->ReorderIndices.empty()))
continue;
assert((BestOrder.size() == TE->ReorderIndices.size() ||
TE->ReorderIndices.empty()) &&
@@ -4825,7 +4899,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
Data.first->isAltShuffle())
Data.first->reorderOperands(Mask);
if (!isa<InsertElementInst, StoreInst>(Data.first->getMainOp()) ||
- Data.first->isAltShuffle()) {
+ Data.first->isAltShuffle() ||
+ Data.first->State == TreeEntry::PossibleStridedVectorize) {
reorderScalars(Data.first->Scalars, Mask);
reorderOrder(Data.first->ReorderIndices, MaskOrder);
if (Data.first->ReuseShuffleIndices.empty() &&
@@ -4859,10 +4934,12 @@ void BoUpSLP::buildExternalUses(
// For each lane:
for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) {
Value *Scalar = Entry->Scalars[Lane];
+ if (!isa<Instruction>(Scalar))
+ continue;
int FoundLane = Entry->findLaneForValue(Scalar);
// Check if the scalar is externally used as an extra arg.
- auto ExtI = ExternallyUsedValues.find(Scalar);
+ const auto *ExtI = ExternallyUsedValues.find(Scalar);
if (ExtI != ExternallyUsedValues.end()) {
LLVM_DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane "
<< Lane << " from " << *Scalar << ".\n");
@@ -4886,7 +4963,8 @@ void BoUpSLP::buildExternalUses(
// be used.
if (UseScalar != U ||
UseEntry->State == TreeEntry::ScatterVectorize ||
- !InTreeUserNeedToExtract(Scalar, UserInst, TLI)) {
+ UseEntry->State == TreeEntry::PossibleStridedVectorize ||
+ !doesInTreeUserNeedToExtract(Scalar, UserInst, TLI)) {
LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
<< ".\n");
assert(UseEntry->State != TreeEntry::NeedToGather && "Bad state");
@@ -4906,9 +4984,9 @@ void BoUpSLP::buildExternalUses(
}
}
-DenseMap<Value *, SmallVector<StoreInst *, 4>>
+DenseMap<Value *, SmallVector<StoreInst *>>
BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const {
- DenseMap<Value *, SmallVector<StoreInst *, 4>> PtrToStoresMap;
+ DenseMap<Value *, SmallVector<StoreInst *>> PtrToStoresMap;
for (unsigned Lane : seq<unsigned>(0, TE->Scalars.size())) {
Value *V = TE->Scalars[Lane];
// To save compilation time we don't visit if we have too many users.
@@ -4947,14 +5025,14 @@ BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const {
return PtrToStoresMap;
}
-bool BoUpSLP::canFormVector(const SmallVector<StoreInst *, 4> &StoresVec,
+bool BoUpSLP::canFormVector(ArrayRef<StoreInst *> StoresVec,
OrdersType &ReorderIndices) const {
// We check whether the stores in StoreVec can form a vector by sorting them
// and checking whether they are consecutive.
// To avoid calling getPointersDiff() while sorting we create a vector of
// pairs {store, offset from first} and sort this instead.
- SmallVector<std::pair<StoreInst *, int>, 4> StoreOffsetVec(StoresVec.size());
+ SmallVector<std::pair<StoreInst *, int>> StoreOffsetVec(StoresVec.size());
StoreInst *S0 = StoresVec[0];
StoreOffsetVec[0] = {S0, 0};
Type *S0Ty = S0->getValueOperand()->getType();
@@ -5023,7 +5101,7 @@ SmallVector<BoUpSLP::OrdersType, 1>
BoUpSLP::findExternalStoreUsersReorderIndices(TreeEntry *TE) const {
unsigned NumLanes = TE->Scalars.size();
- DenseMap<Value *, SmallVector<StoreInst *, 4>> PtrToStoresMap =
+ DenseMap<Value *, SmallVector<StoreInst *>> PtrToStoresMap =
collectUserStores(TE);
// Holds the reorder indices for each candidate store vector that is a user of
@@ -5244,6 +5322,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
return TreeEntry::Vectorize;
case LoadsState::ScatterVectorize:
return TreeEntry::ScatterVectorize;
+ case LoadsState::PossibleStridedVectorize:
+ return TreeEntry::PossibleStridedVectorize;
case LoadsState::Gather:
#ifndef NDEBUG
Type *ScalarTy = VL0->getType();
@@ -5416,7 +5496,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
VFShape Shape = VFShape::get(
- *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
+ CI->getFunctionType(),
+ ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
false /*HasGlobalPred*/);
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
@@ -5488,9 +5569,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
SmallVector<int> ReuseShuffleIndicies;
SmallVector<Value *> UniqueValues;
- auto &&TryToFindDuplicates = [&VL, &ReuseShuffleIndicies, &UniqueValues,
- &UserTreeIdx,
- this](const InstructionsState &S) {
+ SmallVector<Value *> NonUniqueValueVL;
+ auto TryToFindDuplicates = [&](const InstructionsState &S,
+ bool DoNotFail = false) {
// Check that every instruction appears once in this bundle.
DenseMap<Value *, unsigned> UniquePositions(VL.size());
for (Value *V : VL) {
@@ -5517,6 +5598,24 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
!isConstant(V);
})) ||
!llvm::has_single_bit<uint32_t>(NumUniqueScalarValues)) {
+ if (DoNotFail && UniquePositions.size() > 1 &&
+ NumUniqueScalarValues > 1 && S.MainOp->isSafeToRemove() &&
+ all_of(UniqueValues, [=](Value *V) {
+ return isa<ExtractElementInst>(V) ||
+ areAllUsersVectorized(cast<Instruction>(V),
+ UserIgnoreList);
+ })) {
+ unsigned PWSz = PowerOf2Ceil(UniqueValues.size());
+ if (PWSz == VL.size()) {
+ ReuseShuffleIndicies.clear();
+ } else {
+ NonUniqueValueVL.assign(UniqueValues.begin(), UniqueValues.end());
+ NonUniqueValueVL.append(PWSz - UniqueValues.size(),
+ UniqueValues.back());
+ VL = NonUniqueValueVL;
+ }
+ return true;
+ }
LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n");
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
return false;
@@ -5528,6 +5627,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
InstructionsState S = getSameOpcode(VL, *TLI);
+ // Don't vectorize ephemeral values.
+ if (!EphValues.empty()) {
+ for (Value *V : VL) {
+ if (EphValues.count(V)) {
+ LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V
+ << ") is ephemeral.\n");
+ newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
+ return;
+ }
+ }
+ }
+
// Gather if we hit the RecursionMaxDepth, unless this is a load (or z/sext of
// a load), in which case peek through to include it in the tree, without
// ballooning over-budget.
@@ -5633,7 +5744,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
BasicBlock *BB = nullptr;
bool IsScatterVectorizeUserTE =
UserTreeIdx.UserTE &&
- UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
+ (UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize ||
+ UserTreeIdx.UserTE->State == TreeEntry::PossibleStridedVectorize);
bool AreAllSameInsts =
(S.getOpcode() && allSameBlock(VL)) ||
(S.OpValue->getType()->isPointerTy() && IsScatterVectorizeUserTE &&
@@ -5665,39 +5777,44 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// We now know that this is a vector of instructions of the same type from
// the same block.
- // Don't vectorize ephemeral values.
- if (!EphValues.empty()) {
- for (Value *V : VL) {
- if (EphValues.count(V)) {
- LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V
- << ") is ephemeral.\n");
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
- return;
- }
- }
- }
-
// Check if this is a duplicate of another entry.
if (TreeEntry *E = getTreeEntry(S.OpValue)) {
LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.OpValue << ".\n");
if (!E->isSame(VL)) {
- LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n");
- if (TryToFindDuplicates(S))
- newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
- ReuseShuffleIndicies);
+ auto It = MultiNodeScalars.find(S.OpValue);
+ if (It != MultiNodeScalars.end()) {
+ auto *TEIt = find_if(It->getSecond(),
+ [&](TreeEntry *ME) { return ME->isSame(VL); });
+ if (TEIt != It->getSecond().end())
+ E = *TEIt;
+ else
+ E = nullptr;
+ } else {
+ E = nullptr;
+ }
+ }
+ if (!E) {
+ if (!doesNotNeedToBeScheduled(S.OpValue)) {
+ LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n");
+ if (TryToFindDuplicates(S))
+ newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx,
+ ReuseShuffleIndicies);
+ return;
+ }
+ } else {
+ // Record the reuse of the tree node. FIXME, currently this is only used
+ // to properly draw the graph rather than for the actual vectorization.
+ E->UserTreeIndices.push_back(UserTreeIdx);
+ LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue
+ << ".\n");
return;
}
- // Record the reuse of the tree node. FIXME, currently this is only used to
- // properly draw the graph rather than for the actual vectorization.
- E->UserTreeIndices.push_back(UserTreeIdx);
- LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue
- << ".\n");
- return;
}
// Check that none of the instructions in the bundle are already in the tree.
for (Value *V : VL) {
- if (!IsScatterVectorizeUserTE && !isa<Instruction>(V))
+ if ((!IsScatterVectorizeUserTE && !isa<Instruction>(V)) ||
+ doesNotNeedToBeScheduled(V))
continue;
if (getTreeEntry(V)) {
LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V
@@ -5725,7 +5842,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// Special processing for sorted pointers for ScatterVectorize node with
// constant indeces only.
if (AreAllSameInsts && UserTreeIdx.UserTE &&
- UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize &&
+ (UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize ||
+ UserTreeIdx.UserTE->State == TreeEntry::PossibleStridedVectorize) &&
!(S.getOpcode() && allSameBlock(VL))) {
assert(S.OpValue->getType()->isPointerTy() &&
count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >=
@@ -5760,7 +5878,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
// Check that every instruction appears once in this bundle.
- if (!TryToFindDuplicates(S))
+ if (!TryToFindDuplicates(S, /*DoNotFail=*/true))
return;
// Perform specific checks for each particular instruction kind.
@@ -5780,7 +5898,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
BlockScheduling &BS = *BSRef;
- std::optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S);
+ std::optional<ScheduleData *> Bundle =
+ BS.tryScheduleBundle(UniqueValues, this, S);
#ifdef EXPENSIVE_CHECKS
// Make sure we didn't break any internal invariants
BS.verify();
@@ -5905,6 +6024,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// from such a struct, we read/write packed bits disagreeing with the
// unvectorized version.
TreeEntry *TE = nullptr;
+ fixupOrderingIndices(CurrentOrder);
switch (State) {
case TreeEntry::Vectorize:
if (CurrentOrder.empty()) {
@@ -5913,7 +6033,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
ReuseShuffleIndicies);
LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n");
} else {
- fixupOrderingIndices(CurrentOrder);
// Need to reorder.
TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies, CurrentOrder);
@@ -5921,6 +6040,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
TE->setOperandsInOrder();
break;
+ case TreeEntry::PossibleStridedVectorize:
+ // Vectorizing non-consecutive loads with `llvm.masked.gather`.
+ if (CurrentOrder.empty()) {
+ TE = newTreeEntry(VL, TreeEntry::PossibleStridedVectorize, Bundle, S,
+ UserTreeIdx, ReuseShuffleIndicies);
+ } else {
+ TE = newTreeEntry(VL, TreeEntry::PossibleStridedVectorize, Bundle, S,
+ UserTreeIdx, ReuseShuffleIndicies, CurrentOrder);
+ }
+ TE->setOperandsInOrder();
+ buildTree_rec(PointerOps, Depth + 1, {TE, 0});
+ LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n");
+ break;
case TreeEntry::ScatterVectorize:
// Vectorizing non-consecutive loads with `llvm.masked.gather`.
TE = newTreeEntry(VL, TreeEntry::ScatterVectorize, Bundle, S,
@@ -5951,13 +6083,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n");
TE->setOperandsInOrder();
- for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
+ for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) {
ValueList Operands;
// Prepare the operand vector.
for (Value *V : VL)
- Operands.push_back(cast<Instruction>(V)->getOperand(i));
+ Operands.push_back(cast<Instruction>(V)->getOperand(I));
- buildTree_rec(Operands, Depth + 1, {TE, i});
+ buildTree_rec(Operands, Depth + 1, {TE, I});
}
return;
}
@@ -6031,13 +6163,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
TE->setOperandsInOrder();
- for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
+ for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) {
ValueList Operands;
// Prepare the operand vector.
for (Value *V : VL)
- Operands.push_back(cast<Instruction>(V)->getOperand(i));
+ Operands.push_back(cast<Instruction>(V)->getOperand(I));
- buildTree_rec(Operands, Depth + 1, {TE, i});
+ buildTree_rec(Operands, Depth + 1, {TE, I});
}
return;
}
@@ -6087,8 +6219,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (!CI)
Operands.back().push_back(Op);
else
- Operands.back().push_back(ConstantExpr::getIntegerCast(
- CI, Ty, CI->getValue().isSignBitSet()));
+ Operands.back().push_back(ConstantFoldIntegerCast(
+ CI, Ty, CI->getValue().isSignBitSet(), *DL));
}
TE->setOperand(IndexIdx, Operands.back());
@@ -6132,18 +6264,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndicies);
TE->setOperandsInOrder();
- for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) {
- // For scalar operands no need to to create an entry since no need to
+ for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
+ // For scalar operands no need to create an entry since no need to
// vectorize it.
- if (isVectorIntrinsicWithScalarOpAtArg(ID, i))
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
continue;
ValueList Operands;
// Prepare the operand vector.
for (Value *V : VL) {
auto *CI2 = cast<CallInst>(V);
- Operands.push_back(CI2->getArgOperand(i));
+ Operands.push_back(CI2->getArgOperand(I));
}
- buildTree_rec(Operands, Depth + 1, {TE, i});
+ buildTree_rec(Operands, Depth + 1, {TE, I});
}
return;
}
@@ -6194,13 +6326,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
TE->setOperandsInOrder();
- for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
+ for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) {
ValueList Operands;
// Prepare the operand vector.
for (Value *V : VL)
- Operands.push_back(cast<Instruction>(V)->getOperand(i));
+ Operands.push_back(cast<Instruction>(V)->getOperand(I));
- buildTree_rec(Operands, Depth + 1, {TE, i});
+ buildTree_rec(Operands, Depth + 1, {TE, I});
}
return;
}
@@ -6210,7 +6342,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
llvm_unreachable("Unexpected vectorization of the instructions.");
}
-unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const {
+unsigned BoUpSLP::canMapToVector(Type *T) const {
unsigned N = 1;
Type *EltTy = T;
@@ -6234,15 +6366,16 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const {
if (!isValidElementType(EltTy))
return 0;
- uint64_t VTSize = DL.getTypeStoreSizeInBits(FixedVectorType::get(EltTy, N));
+ uint64_t VTSize = DL->getTypeStoreSizeInBits(FixedVectorType::get(EltTy, N));
if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize ||
- VTSize != DL.getTypeStoreSizeInBits(T))
+ VTSize != DL->getTypeStoreSizeInBits(T))
return 0;
return N;
}
bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
- SmallVectorImpl<unsigned> &CurrentOrder) const {
+ SmallVectorImpl<unsigned> &CurrentOrder,
+ bool ResizeAllowed) const {
const auto *It = find_if(VL, [](Value *V) {
return isa<ExtractElementInst, ExtractValueInst>(V);
});
@@ -6263,8 +6396,7 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
// We have to extract from a vector/aggregate with the same number of elements.
unsigned NElts;
if (E0->getOpcode() == Instruction::ExtractValue) {
- const DataLayout &DL = E0->getModule()->getDataLayout();
- NElts = canMapToVector(Vec->getType(), DL);
+ NElts = canMapToVector(Vec->getType());
if (!NElts)
return false;
// Check if load can be rewritten as load of vector.
@@ -6275,46 +6407,55 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
NElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
}
- if (NElts != VL.size())
- return false;
-
- // Check that all of the indices extract from the correct offset.
- bool ShouldKeepOrder = true;
unsigned E = VL.size();
- // Assign to all items the initial value E + 1 so we can check if the extract
- // instruction index was used already.
- // Also, later we can check that all the indices are used and we have a
- // consecutive access in the extract instructions, by checking that no
- // element of CurrentOrder still has value E + 1.
- CurrentOrder.assign(E, E);
- unsigned I = 0;
- for (; I < E; ++I) {
- auto *Inst = dyn_cast<Instruction>(VL[I]);
+ if (!ResizeAllowed && NElts != E)
+ return false;
+ SmallVector<int> Indices(E, PoisonMaskElem);
+ unsigned MinIdx = NElts, MaxIdx = 0;
+ for (auto [I, V] : enumerate(VL)) {
+ auto *Inst = dyn_cast<Instruction>(V);
if (!Inst)
continue;
if (Inst->getOperand(0) != Vec)
- break;
+ return false;
if (auto *EE = dyn_cast<ExtractElementInst>(Inst))
if (isa<UndefValue>(EE->getIndexOperand()))
continue;
std::optional<unsigned> Idx = getExtractIndex(Inst);
if (!Idx)
- break;
+ return false;
const unsigned ExtIdx = *Idx;
- if (ExtIdx != I) {
- if (ExtIdx >= E || CurrentOrder[ExtIdx] != E)
- break;
- ShouldKeepOrder = false;
- CurrentOrder[ExtIdx] = I;
- } else {
- if (CurrentOrder[I] != E)
- break;
- CurrentOrder[I] = I;
- }
+ if (ExtIdx >= NElts)
+ continue;
+ Indices[I] = ExtIdx;
+ if (MinIdx > ExtIdx)
+ MinIdx = ExtIdx;
+ if (MaxIdx < ExtIdx)
+ MaxIdx = ExtIdx;
}
- if (I < E) {
- CurrentOrder.clear();
+ if (MaxIdx - MinIdx + 1 > E)
return false;
+ if (MaxIdx + 1 <= E)
+ MinIdx = 0;
+
+ // Check that all of the indices extract from the correct offset.
+ bool ShouldKeepOrder = true;
+ // Assign to all items the initial value E + 1 so we can check if the extract
+ // instruction index was used already.
+ // Also, later we can check that all the indices are used and we have a
+ // consecutive access in the extract instructions, by checking that no
+ // element of CurrentOrder still has value E + 1.
+ CurrentOrder.assign(E, E);
+ for (unsigned I = 0; I < E; ++I) {
+ if (Indices[I] == PoisonMaskElem)
+ continue;
+ const unsigned ExtIdx = Indices[I] - MinIdx;
+ if (CurrentOrder[ExtIdx] != E) {
+ CurrentOrder.clear();
+ return false;
+ }
+ ShouldKeepOrder &= ExtIdx == I;
+ CurrentOrder[ExtIdx] = I;
}
if (ShouldKeepOrder)
CurrentOrder.clear();
@@ -6322,9 +6463,9 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
return ShouldKeepOrder;
}
-bool BoUpSLP::areAllUsersVectorized(Instruction *I,
- ArrayRef<Value *> VectorizedVals) const {
- return (I->hasOneUse() && is_contained(VectorizedVals, I)) ||
+bool BoUpSLP::areAllUsersVectorized(
+ Instruction *I, const SmallDenseSet<Value *> *VectorizedVals) const {
+ return (I->hasOneUse() && (!VectorizedVals || VectorizedVals->contains(I))) ||
all_of(I->users(), [this](User *U) {
return ScalarToTreeEntry.count(U) > 0 ||
isVectorLikeInstWithConstOps(U) ||
@@ -6351,8 +6492,8 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
auto IntrinsicCost =
TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
- auto Shape = VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>(
- VecTy->getNumElements())),
+ auto Shape = VFShape::get(CI->getFunctionType(),
+ ElementCount::getFixed(VecTy->getNumElements()),
false /*HasGlobalPred*/);
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
auto LibCost = IntrinsicCost;
@@ -6365,16 +6506,11 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
return {IntrinsicCost, LibCost};
}
-/// Build shuffle mask for shuffle graph entries and lists of main and alternate
-/// operations operands.
-static void
-buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices,
- ArrayRef<int> ReusesIndices,
- const function_ref<bool(Instruction *)> IsAltOp,
- SmallVectorImpl<int> &Mask,
- SmallVectorImpl<Value *> *OpScalars = nullptr,
- SmallVectorImpl<Value *> *AltScalars = nullptr) {
- unsigned Sz = VL.size();
+void BoUpSLP::TreeEntry::buildAltOpShuffleMask(
+ const function_ref<bool(Instruction *)> IsAltOp, SmallVectorImpl<int> &Mask,
+ SmallVectorImpl<Value *> *OpScalars,
+ SmallVectorImpl<Value *> *AltScalars) const {
+ unsigned Sz = Scalars.size();
Mask.assign(Sz, PoisonMaskElem);
SmallVector<int> OrderMask;
if (!ReorderIndices.empty())
@@ -6383,7 +6519,7 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices,
unsigned Idx = I;
if (!ReorderIndices.empty())
Idx = OrderMask[I];
- auto *OpInst = cast<Instruction>(VL[Idx]);
+ auto *OpInst = cast<Instruction>(Scalars[Idx]);
if (IsAltOp(OpInst)) {
Mask[I] = Sz + Idx;
if (AltScalars)
@@ -6394,9 +6530,9 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices,
OpScalars->push_back(OpInst);
}
}
- if (!ReusesIndices.empty()) {
- SmallVector<int> NewMask(ReusesIndices.size(), PoisonMaskElem);
- transform(ReusesIndices, NewMask.begin(), [&Mask](int Idx) {
+ if (!ReuseShuffleIndices.empty()) {
+ SmallVector<int> NewMask(ReuseShuffleIndices.size(), PoisonMaskElem);
+ transform(ReuseShuffleIndices, NewMask.begin(), [&Mask](int Idx) {
return Idx != PoisonMaskElem ? Mask[Idx] : PoisonMaskElem;
});
Mask.swap(NewMask);
@@ -6429,52 +6565,27 @@ static bool isAlternateInstruction(const Instruction *I,
return I->getOpcode() == AltOp->getOpcode();
}
-TTI::OperandValueInfo BoUpSLP::getOperandInfo(ArrayRef<Value *> VL,
- unsigned OpIdx) {
- assert(!VL.empty());
- const auto *I0 = cast<Instruction>(*find_if(VL, Instruction::classof));
- const auto *Op0 = I0->getOperand(OpIdx);
+TTI::OperandValueInfo BoUpSLP::getOperandInfo(ArrayRef<Value *> Ops) {
+ assert(!Ops.empty());
+ const auto *Op0 = Ops.front();
- const bool IsConstant = all_of(VL, [&](Value *V) {
+ const bool IsConstant = all_of(Ops, [](Value *V) {
// TODO: We should allow undef elements here
- const auto *I = dyn_cast<Instruction>(V);
- if (!I)
- return true;
- auto *Op = I->getOperand(OpIdx);
- return isConstant(Op) && !isa<UndefValue>(Op);
+ return isConstant(V) && !isa<UndefValue>(V);
});
- const bool IsUniform = all_of(VL, [&](Value *V) {
+ const bool IsUniform = all_of(Ops, [=](Value *V) {
// TODO: We should allow undef elements here
- const auto *I = dyn_cast<Instruction>(V);
- if (!I)
- return false;
- return I->getOperand(OpIdx) == Op0;
+ return V == Op0;
});
- const bool IsPowerOfTwo = all_of(VL, [&](Value *V) {
+ const bool IsPowerOfTwo = all_of(Ops, [](Value *V) {
// TODO: We should allow undef elements here
- const auto *I = dyn_cast<Instruction>(V);
- if (!I) {
- assert((isa<UndefValue>(V) ||
- I0->getOpcode() == Instruction::GetElementPtr) &&
- "Expected undef or GEP.");
- return true;
- }
- auto *Op = I->getOperand(OpIdx);
- if (auto *CI = dyn_cast<ConstantInt>(Op))
+ if (auto *CI = dyn_cast<ConstantInt>(V))
return CI->getValue().isPowerOf2();
return false;
});
- const bool IsNegatedPowerOfTwo = all_of(VL, [&](Value *V) {
+ const bool IsNegatedPowerOfTwo = all_of(Ops, [](Value *V) {
// TODO: We should allow undef elements here
- const auto *I = dyn_cast<Instruction>(V);
- if (!I) {
- assert((isa<UndefValue>(V) ||
- I0->getOpcode() == Instruction::GetElementPtr) &&
- "Expected undef or GEP.");
- return true;
- }
- const auto *Op = I->getOperand(OpIdx);
- if (auto *CI = dyn_cast<ConstantInt>(Op))
+ if (auto *CI = dyn_cast<ConstantInt>(V))
return CI->getValue().isNegatedPowerOf2();
return false;
});
@@ -6505,9 +6616,24 @@ protected:
bool IsStrict) {
int Limit = Mask.size();
int VF = VecTy->getNumElements();
- return (VF == Limit || !IsStrict) &&
- all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) &&
- ShuffleVectorInst::isIdentityMask(Mask);
+ int Index = -1;
+ if (VF == Limit && ShuffleVectorInst::isIdentityMask(Mask, Limit))
+ return true;
+ if (!IsStrict) {
+ // Consider extract subvector starting from index 0.
+ if (ShuffleVectorInst::isExtractSubvectorMask(Mask, VF, Index) &&
+ Index == 0)
+ return true;
+ // All VF-size submasks are identity (e.g.
+ // <poison,poison,poison,poison,0,1,2,poison,poison,1,2,3> etc. for VF 4).
+ if (Limit % VF == 0 && all_of(seq<int>(0, Limit / VF), [=](int Idx) {
+ ArrayRef<int> Slice = Mask.slice(Idx * VF, VF);
+ return all_of(Slice, [](int I) { return I == PoisonMaskElem; }) ||
+ ShuffleVectorInst::isIdentityMask(Slice, VF);
+ }))
+ return true;
+ }
+ return false;
}
/// Tries to combine 2 different masks into single one.
@@ -6577,7 +6703,8 @@ protected:
if (isIdentityMask(Mask, SVTy, /*IsStrict=*/false)) {
if (!IdentityOp || !SinglePermute ||
(isIdentityMask(Mask, SVTy, /*IsStrict=*/true) &&
- !ShuffleVectorInst::isZeroEltSplatMask(IdentityMask))) {
+ !ShuffleVectorInst::isZeroEltSplatMask(IdentityMask,
+ IdentityMask.size()))) {
IdentityOp = SV;
// Store current mask in the IdentityMask so later we did not lost
// this info if IdentityOp is selected as the best candidate for the
@@ -6647,7 +6774,7 @@ protected:
}
if (auto *OpTy = dyn_cast<FixedVectorType>(Op->getType());
!OpTy || !isIdentityMask(Mask, OpTy, SinglePermute) ||
- ShuffleVectorInst::isZeroEltSplatMask(Mask)) {
+ ShuffleVectorInst::isZeroEltSplatMask(Mask, Mask.size())) {
if (IdentityOp) {
V = IdentityOp;
assert(Mask.size() == IdentityMask.size() &&
@@ -6663,7 +6790,7 @@ protected:
/*IsStrict=*/true) ||
(Shuffle && Mask.size() == Shuffle->getShuffleMask().size() &&
Shuffle->isZeroEltSplat() &&
- ShuffleVectorInst::isZeroEltSplatMask(Mask)));
+ ShuffleVectorInst::isZeroEltSplatMask(Mask, Mask.size())));
}
V = Op;
return false;
@@ -6768,11 +6895,9 @@ protected:
CombinedMask1[I] = CombinedMask2[I] + (Op1 == Op2 ? 0 : VF);
}
}
- const int Limit = CombinedMask1.size() * 2;
- if (Op1 == Op2 && Limit == 2 * VF &&
- all_of(CombinedMask1, [=](int Idx) { return Idx < Limit; }) &&
- (ShuffleVectorInst::isIdentityMask(CombinedMask1) ||
- (ShuffleVectorInst::isZeroEltSplatMask(CombinedMask1) &&
+ if (Op1 == Op2 &&
+ (ShuffleVectorInst::isIdentityMask(CombinedMask1, VF) ||
+ (ShuffleVectorInst::isZeroEltSplatMask(CombinedMask1, VF) &&
isa<ShuffleVectorInst>(Op1) &&
cast<ShuffleVectorInst>(Op1)->getShuffleMask() ==
ArrayRef(CombinedMask1))))
@@ -6807,10 +6932,29 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors;
const TargetTransformInfo &TTI;
InstructionCost Cost = 0;
- ArrayRef<Value *> VectorizedVals;
+ SmallDenseSet<Value *> VectorizedVals;
BoUpSLP &R;
SmallPtrSetImpl<Value *> &CheckedExtracts;
constexpr static TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ /// While set, still trying to estimate the cost for the same nodes and we
+ /// can delay actual cost estimation (virtual shuffle instruction emission).
+ /// May help better estimate the cost if same nodes must be permuted + allows
+ /// to move most of the long shuffles cost estimation to TTI.
+ bool SameNodesEstimated = true;
+
+ static Constant *getAllOnesValue(const DataLayout &DL, Type *Ty) {
+ if (Ty->getScalarType()->isPointerTy()) {
+ Constant *Res = ConstantExpr::getIntToPtr(
+ ConstantInt::getAllOnesValue(
+ IntegerType::get(Ty->getContext(),
+ DL.getTypeStoreSizeInBits(Ty->getScalarType()))),
+ Ty->getScalarType());
+ if (auto *VTy = dyn_cast<VectorType>(Ty))
+ Res = ConstantVector::getSplat(VTy->getElementCount(), Res);
+ return Res;
+ }
+ return Constant::getAllOnesValue(Ty);
+ }
InstructionCost getBuildVectorCost(ArrayRef<Value *> VL, Value *Root) {
if ((!Root && allConstant(VL)) || all_of(VL, UndefValue::classof))
@@ -6821,20 +6965,35 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
// Improve gather cost for gather of loads, if we can group some of the
// loads into vector loads.
InstructionsState S = getSameOpcode(VL, *R.TLI);
- if (VL.size() > 2 && S.getOpcode() == Instruction::Load &&
- !S.isAltShuffle() &&
+ const unsigned Sz = R.DL->getTypeSizeInBits(VL.front()->getType());
+ unsigned MinVF = R.getMinVF(2 * Sz);
+ if (VL.size() > 2 &&
+ ((S.getOpcode() == Instruction::Load && !S.isAltShuffle()) ||
+ (InVectors.empty() &&
+ any_of(seq<unsigned>(0, VL.size() / MinVF),
+ [&](unsigned Idx) {
+ ArrayRef<Value *> SubVL = VL.slice(Idx * MinVF, MinVF);
+ InstructionsState S = getSameOpcode(SubVL, *R.TLI);
+ return S.getOpcode() == Instruction::Load &&
+ !S.isAltShuffle();
+ }))) &&
!all_of(Gathers, [&](Value *V) { return R.getTreeEntry(V); }) &&
!isSplat(Gathers)) {
- BoUpSLP::ValueSet VectorizedLoads;
+ SetVector<Value *> VectorizedLoads;
+ SmallVector<LoadInst *> VectorizedStarts;
+ SmallVector<std::pair<unsigned, unsigned>> ScatterVectorized;
unsigned StartIdx = 0;
unsigned VF = VL.size() / 2;
- unsigned VectorizedCnt = 0;
- unsigned ScatterVectorizeCnt = 0;
- const unsigned Sz = R.DL->getTypeSizeInBits(S.MainOp->getType());
- for (unsigned MinVF = R.getMinVF(2 * Sz); VF >= MinVF; VF /= 2) {
+ for (; VF >= MinVF; VF /= 2) {
for (unsigned Cnt = StartIdx, End = VL.size(); Cnt + VF <= End;
Cnt += VF) {
ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
+ if (S.getOpcode() != Instruction::Load || S.isAltShuffle()) {
+ InstructionsState SliceS = getSameOpcode(Slice, *R.TLI);
+ if (SliceS.getOpcode() != Instruction::Load ||
+ SliceS.isAltShuffle())
+ continue;
+ }
if (!VectorizedLoads.count(Slice.front()) &&
!VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) {
SmallVector<Value *> PointerOps;
@@ -6845,12 +7004,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
switch (LS) {
case LoadsState::Vectorize:
case LoadsState::ScatterVectorize:
+ case LoadsState::PossibleStridedVectorize:
// Mark the vectorized loads so that we don't vectorize them
// again.
- if (LS == LoadsState::Vectorize)
- ++VectorizedCnt;
+ // TODO: better handling of loads with reorders.
+ if (LS == LoadsState::Vectorize && CurrentOrder.empty())
+ VectorizedStarts.push_back(cast<LoadInst>(Slice.front()));
else
- ++ScatterVectorizeCnt;
+ ScatterVectorized.emplace_back(Cnt, VF);
VectorizedLoads.insert(Slice.begin(), Slice.end());
// If we vectorized initial block, no need to try to vectorize
// it again.
@@ -6881,8 +7042,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
}
// Exclude potentially vectorized loads from list of gathered
// scalars.
- auto *LI = cast<LoadInst>(S.MainOp);
- Gathers.assign(Gathers.size(), PoisonValue::get(LI->getType()));
+ Gathers.assign(Gathers.size(), PoisonValue::get(VL.front()->getType()));
// The cost for vectorized loads.
InstructionCost ScalarsCost = 0;
for (Value *V : VectorizedLoads) {
@@ -6892,17 +7052,24 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
LI->getAlign(), LI->getPointerAddressSpace(),
CostKind, TTI::OperandValueInfo(), LI);
}
- auto *LoadTy = FixedVectorType::get(LI->getType(), VF);
- Align Alignment = LI->getAlign();
- GatherCost +=
- VectorizedCnt *
- TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment,
- LI->getPointerAddressSpace(), CostKind,
- TTI::OperandValueInfo(), LI);
- GatherCost += ScatterVectorizeCnt *
- TTI.getGatherScatterOpCost(
- Instruction::Load, LoadTy, LI->getPointerOperand(),
- /*VariableMask=*/false, Alignment, CostKind, LI);
+ auto *LoadTy = FixedVectorType::get(VL.front()->getType(), VF);
+ for (LoadInst *LI : VectorizedStarts) {
+ Align Alignment = LI->getAlign();
+ GatherCost +=
+ TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment,
+ LI->getPointerAddressSpace(), CostKind,
+ TTI::OperandValueInfo(), LI);
+ }
+ for (std::pair<unsigned, unsigned> P : ScatterVectorized) {
+ auto *LI0 = cast<LoadInst>(VL[P.first]);
+ Align CommonAlignment = LI0->getAlign();
+ for (Value *V : VL.slice(P.first + 1, VF - 1))
+ CommonAlignment =
+ std::min(CommonAlignment, cast<LoadInst>(V)->getAlign());
+ GatherCost += TTI.getGatherScatterOpCost(
+ Instruction::Load, LoadTy, LI0->getPointerOperand(),
+ /*VariableMask=*/false, CommonAlignment, CostKind, LI0);
+ }
if (NeedInsertSubvectorAnalysis) {
// Add the cost for the subvectors insert.
for (int I = VF, E = VL.size(); I < E; I += VF)
@@ -6938,77 +7105,137 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
: R.getGatherCost(Gathers, !Root && VL.equals(Gathers)));
};
- /// Compute the cost of creating a vector of type \p VecTy containing the
- /// extracted values from \p VL.
- InstructionCost computeExtractCost(ArrayRef<Value *> VL, ArrayRef<int> Mask,
- TTI::ShuffleKind ShuffleKind) {
- auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size());
- unsigned NumOfParts = TTI.getNumberOfParts(VecTy);
-
- if (ShuffleKind != TargetTransformInfo::SK_PermuteSingleSrc ||
- !NumOfParts || VecTy->getNumElements() < NumOfParts)
- return TTI.getShuffleCost(ShuffleKind, VecTy, Mask);
-
- bool AllConsecutive = true;
- unsigned EltsPerVector = VecTy->getNumElements() / NumOfParts;
- unsigned Idx = -1;
+ /// Compute the cost of creating a vector containing the extracted values from
+ /// \p VL.
+ InstructionCost
+ computeExtractCost(ArrayRef<Value *> VL, ArrayRef<int> Mask,
+ ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds,
+ unsigned NumParts) {
+ assert(VL.size() > NumParts && "Unexpected scalarized shuffle.");
+ unsigned NumElts =
+ std::accumulate(VL.begin(), VL.end(), 0, [](unsigned Sz, Value *V) {
+ auto *EE = dyn_cast<ExtractElementInst>(V);
+ if (!EE)
+ return Sz;
+ auto *VecTy = cast<FixedVectorType>(EE->getVectorOperandType());
+ return std::max(Sz, VecTy->getNumElements());
+ });
+ unsigned NumSrcRegs = TTI.getNumberOfParts(
+ FixedVectorType::get(VL.front()->getType(), NumElts));
+ if (NumSrcRegs == 0)
+ NumSrcRegs = 1;
+ // FIXME: this must be moved to TTI for better estimation.
+ unsigned EltsPerVector = PowerOf2Ceil(std::max(
+ divideCeil(VL.size(), NumParts), divideCeil(NumElts, NumSrcRegs)));
+ auto CheckPerRegistersShuffle =
+ [&](MutableArrayRef<int> Mask) -> std::optional<TTI::ShuffleKind> {
+ DenseSet<int> RegIndices;
+ // Check that if trying to permute same single/2 input vectors.
+ TTI::ShuffleKind ShuffleKind = TTI::SK_PermuteSingleSrc;
+ int FirstRegId = -1;
+ for (int &I : Mask) {
+ if (I == PoisonMaskElem)
+ continue;
+ int RegId = (I / NumElts) * NumParts + (I % NumElts) / EltsPerVector;
+ if (FirstRegId < 0)
+ FirstRegId = RegId;
+ RegIndices.insert(RegId);
+ if (RegIndices.size() > 2)
+ return std::nullopt;
+ if (RegIndices.size() == 2)
+ ShuffleKind = TTI::SK_PermuteTwoSrc;
+ I = (I % NumElts) % EltsPerVector +
+ (RegId == FirstRegId ? 0 : EltsPerVector);
+ }
+ return ShuffleKind;
+ };
InstructionCost Cost = 0;
// Process extracts in blocks of EltsPerVector to check if the source vector
// operand can be re-used directly. If not, add the cost of creating a
// shuffle to extract the values into a vector register.
- SmallVector<int> RegMask(EltsPerVector, PoisonMaskElem);
- for (auto *V : VL) {
- ++Idx;
-
- // Reached the start of a new vector registers.
- if (Idx % EltsPerVector == 0) {
- RegMask.assign(EltsPerVector, PoisonMaskElem);
- AllConsecutive = true;
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ if (!ShuffleKinds[Part])
continue;
- }
-
- // Need to exclude undefs from analysis.
- if (isa<UndefValue>(V) || Mask[Idx] == PoisonMaskElem)
+ ArrayRef<int> MaskSlice =
+ Mask.slice(Part * EltsPerVector,
+ (Part == NumParts - 1 && Mask.size() % EltsPerVector != 0)
+ ? Mask.size() % EltsPerVector
+ : EltsPerVector);
+ SmallVector<int> SubMask(EltsPerVector, PoisonMaskElem);
+ copy(MaskSlice, SubMask.begin());
+ std::optional<TTI::ShuffleKind> RegShuffleKind =
+ CheckPerRegistersShuffle(SubMask);
+ if (!RegShuffleKind) {
+ Cost += TTI.getShuffleCost(
+ *ShuffleKinds[Part],
+ FixedVectorType::get(VL.front()->getType(), NumElts), MaskSlice);
continue;
-
- // Check all extracts for a vector register on the target directly
- // extract values in order.
- unsigned CurrentIdx = *getExtractIndex(cast<Instruction>(V));
- if (!isa<UndefValue>(VL[Idx - 1]) && Mask[Idx - 1] != PoisonMaskElem) {
- unsigned PrevIdx = *getExtractIndex(cast<Instruction>(VL[Idx - 1]));
- AllConsecutive &= PrevIdx + 1 == CurrentIdx &&
- CurrentIdx % EltsPerVector == Idx % EltsPerVector;
- RegMask[Idx % EltsPerVector] = CurrentIdx % EltsPerVector;
}
-
- if (AllConsecutive)
- continue;
-
- // Skip all indices, except for the last index per vector block.
- if ((Idx + 1) % EltsPerVector != 0 && Idx + 1 != VL.size())
- continue;
-
- // If we have a series of extracts which are not consecutive and hence
- // cannot re-use the source vector register directly, compute the shuffle
- // cost to extract the vector with EltsPerVector elements.
- Cost += TTI.getShuffleCost(
- TargetTransformInfo::SK_PermuteSingleSrc,
- FixedVectorType::get(VecTy->getElementType(), EltsPerVector),
- RegMask);
+ if (*RegShuffleKind != TTI::SK_PermuteSingleSrc ||
+ !ShuffleVectorInst::isIdentityMask(SubMask, EltsPerVector)) {
+ Cost += TTI.getShuffleCost(
+ *RegShuffleKind,
+ FixedVectorType::get(VL.front()->getType(), EltsPerVector),
+ SubMask);
+ }
}
return Cost;
}
+ /// Transforms mask \p CommonMask per given \p Mask to make proper set after
+ /// shuffle emission.
+ static void transformMaskAfterShuffle(MutableArrayRef<int> CommonMask,
+ ArrayRef<int> Mask) {
+ for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
+ if (Mask[Idx] != PoisonMaskElem)
+ CommonMask[Idx] = Idx;
+ }
+ /// Adds the cost of reshuffling \p E1 and \p E2 (if present), using given
+ /// mask \p Mask, register number \p Part, that includes \p SliceSize
+ /// elements.
+ void estimateNodesPermuteCost(const TreeEntry &E1, const TreeEntry *E2,
+ ArrayRef<int> Mask, unsigned Part,
+ unsigned SliceSize) {
+ if (SameNodesEstimated) {
+ // Delay the cost estimation if the same nodes are reshuffling.
+ // If we already requested the cost of reshuffling of E1 and E2 before, no
+ // need to estimate another cost with the sub-Mask, instead include this
+ // sub-Mask into the CommonMask to estimate it later and avoid double cost
+ // estimation.
+ if ((InVectors.size() == 2 &&
+ InVectors.front().get<const TreeEntry *>() == &E1 &&
+ InVectors.back().get<const TreeEntry *>() == E2) ||
+ (!E2 && InVectors.front().get<const TreeEntry *>() == &E1)) {
+ assert(all_of(ArrayRef(CommonMask).slice(Part * SliceSize, SliceSize),
+ [](int Idx) { return Idx == PoisonMaskElem; }) &&
+ "Expected all poisoned elements.");
+ ArrayRef<int> SubMask =
+ ArrayRef(Mask).slice(Part * SliceSize, SliceSize);
+ copy(SubMask, std::next(CommonMask.begin(), SliceSize * Part));
+ return;
+ }
+ // Found non-matching nodes - need to estimate the cost for the matched
+ // and transform mask.
+ Cost += createShuffle(InVectors.front(),
+ InVectors.size() == 1 ? nullptr : InVectors.back(),
+ CommonMask);
+ transformMaskAfterShuffle(CommonMask, CommonMask);
+ }
+ SameNodesEstimated = false;
+ Cost += createShuffle(&E1, E2, Mask);
+ transformMaskAfterShuffle(CommonMask, Mask);
+ }
class ShuffleCostBuilder {
const TargetTransformInfo &TTI;
static bool isEmptyOrIdentity(ArrayRef<int> Mask, unsigned VF) {
- int Limit = 2 * VF;
+ int Index = -1;
return Mask.empty() ||
(VF == Mask.size() &&
- all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) &&
- ShuffleVectorInst::isIdentityMask(Mask));
+ ShuffleVectorInst::isIdentityMask(Mask, VF)) ||
+ (ShuffleVectorInst::isExtractSubvectorMask(Mask, VF, Index) &&
+ Index == 0);
}
public:
@@ -7021,21 +7248,17 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue();
if (isEmptyOrIdentity(Mask, VF))
return TTI::TCC_Free;
- return TTI.getShuffleCost(
- TTI::SK_PermuteTwoSrc,
- FixedVectorType::get(
- cast<VectorType>(V1->getType())->getElementType(), Mask.size()),
- Mask);
+ return TTI.getShuffleCost(TTI::SK_PermuteTwoSrc,
+ cast<VectorType>(V1->getType()), Mask);
}
InstructionCost createShuffleVector(Value *V1, ArrayRef<int> Mask) const {
// Empty mask or identity mask are free.
- if (isEmptyOrIdentity(Mask, Mask.size()))
+ unsigned VF =
+ cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue();
+ if (isEmptyOrIdentity(Mask, VF))
return TTI::TCC_Free;
- return TTI.getShuffleCost(
- TTI::SK_PermuteSingleSrc,
- FixedVectorType::get(
- cast<VectorType>(V1->getType())->getElementType(), Mask.size()),
- Mask);
+ return TTI.getShuffleCost(TTI::SK_PermuteSingleSrc,
+ cast<VectorType>(V1->getType()), Mask);
}
InstructionCost createIdentity(Value *) const { return TTI::TCC_Free; }
InstructionCost createPoison(Type *Ty, unsigned VF) const {
@@ -7052,139 +7275,226 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
const PointerUnion<Value *, const TreeEntry *> &P2,
ArrayRef<int> Mask) {
ShuffleCostBuilder Builder(TTI);
+ SmallVector<int> CommonMask(Mask.begin(), Mask.end());
Value *V1 = P1.dyn_cast<Value *>(), *V2 = P2.dyn_cast<Value *>();
- unsigned CommonVF = 0;
- if (!V1) {
+ unsigned CommonVF = Mask.size();
+ if (!V1 && !V2 && !P2.isNull()) {
+ // Shuffle 2 entry nodes.
const TreeEntry *E = P1.get<const TreeEntry *>();
unsigned VF = E->getVectorFactor();
- if (V2) {
- unsigned V2VF = cast<FixedVectorType>(V2->getType())->getNumElements();
- if (V2VF != VF && V2VF == E->Scalars.size())
- VF = E->Scalars.size();
- } else if (!P2.isNull()) {
- const TreeEntry *E2 = P2.get<const TreeEntry *>();
- if (E->Scalars.size() == E2->Scalars.size())
- CommonVF = VF = E->Scalars.size();
- } else {
- // P2 is empty, check that we have same node + reshuffle (if any).
- if (E->Scalars.size() == Mask.size() && VF != Mask.size()) {
- VF = E->Scalars.size();
- SmallVector<int> CommonMask(Mask.begin(), Mask.end());
- ::addMask(CommonMask, E->getCommonMask());
- V1 = Constant::getNullValue(
- FixedVectorType::get(E->Scalars.front()->getType(), VF));
- return BaseShuffleAnalysis::createShuffle<InstructionCost>(
- V1, nullptr, CommonMask, Builder);
+ const TreeEntry *E2 = P2.get<const TreeEntry *>();
+ CommonVF = std::max(VF, E2->getVectorFactor());
+ assert(all_of(Mask,
+ [=](int Idx) {
+ return Idx < 2 * static_cast<int>(CommonVF);
+ }) &&
+ "All elements in mask must be less than 2 * CommonVF.");
+ if (E->Scalars.size() == E2->Scalars.size()) {
+ SmallVector<int> EMask = E->getCommonMask();
+ SmallVector<int> E2Mask = E2->getCommonMask();
+ if (!EMask.empty() || !E2Mask.empty()) {
+ for (int &Idx : CommonMask) {
+ if (Idx == PoisonMaskElem)
+ continue;
+ if (Idx < static_cast<int>(CommonVF) && !EMask.empty())
+ Idx = EMask[Idx];
+ else if (Idx >= static_cast<int>(CommonVF))
+ Idx = (E2Mask.empty() ? Idx - CommonVF : E2Mask[Idx - CommonVF]) +
+ E->Scalars.size();
+ }
}
+ CommonVF = E->Scalars.size();
}
V1 = Constant::getNullValue(
- FixedVectorType::get(E->Scalars.front()->getType(), VF));
- }
- if (!V2 && !P2.isNull()) {
- const TreeEntry *E = P2.get<const TreeEntry *>();
+ FixedVectorType::get(E->Scalars.front()->getType(), CommonVF));
+ V2 = getAllOnesValue(
+ *R.DL, FixedVectorType::get(E->Scalars.front()->getType(), CommonVF));
+ } else if (!V1 && P2.isNull()) {
+ // Shuffle single entry node.
+ const TreeEntry *E = P1.get<const TreeEntry *>();
unsigned VF = E->getVectorFactor();
- unsigned V1VF = cast<FixedVectorType>(V1->getType())->getNumElements();
- if (!CommonVF && V1VF == E->Scalars.size())
+ CommonVF = VF;
+ assert(
+ all_of(Mask,
+ [=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
+ "All elements in mask must be less than CommonVF.");
+ if (E->Scalars.size() == Mask.size() && VF != Mask.size()) {
+ SmallVector<int> EMask = E->getCommonMask();
+ assert(!EMask.empty() && "Expected non-empty common mask.");
+ for (int &Idx : CommonMask) {
+ if (Idx != PoisonMaskElem)
+ Idx = EMask[Idx];
+ }
CommonVF = E->Scalars.size();
- if (CommonVF)
- VF = CommonVF;
- V2 = Constant::getNullValue(
- FixedVectorType::get(E->Scalars.front()->getType(), VF));
+ }
+ V1 = Constant::getNullValue(
+ FixedVectorType::get(E->Scalars.front()->getType(), CommonVF));
+ } else if (V1 && P2.isNull()) {
+ // Shuffle single vector.
+ CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ assert(
+ all_of(Mask,
+ [=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
+ "All elements in mask must be less than CommonVF.");
+ } else if (V1 && !V2) {
+ // Shuffle vector and tree node.
+ unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ const TreeEntry *E2 = P2.get<const TreeEntry *>();
+ CommonVF = std::max(VF, E2->getVectorFactor());
+ assert(all_of(Mask,
+ [=](int Idx) {
+ return Idx < 2 * static_cast<int>(CommonVF);
+ }) &&
+ "All elements in mask must be less than 2 * CommonVF.");
+ if (E2->Scalars.size() == VF && VF != CommonVF) {
+ SmallVector<int> E2Mask = E2->getCommonMask();
+ assert(!E2Mask.empty() && "Expected non-empty common mask.");
+ for (int &Idx : CommonMask) {
+ if (Idx == PoisonMaskElem)
+ continue;
+ if (Idx >= static_cast<int>(CommonVF))
+ Idx = E2Mask[Idx - CommonVF] + VF;
+ }
+ CommonVF = VF;
+ }
+ V1 = Constant::getNullValue(
+ FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF));
+ V2 = getAllOnesValue(
+ *R.DL,
+ FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF));
+ } else if (!V1 && V2) {
+ // Shuffle vector and tree node.
+ unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
+ const TreeEntry *E1 = P1.get<const TreeEntry *>();
+ CommonVF = std::max(VF, E1->getVectorFactor());
+ assert(all_of(Mask,
+ [=](int Idx) {
+ return Idx < 2 * static_cast<int>(CommonVF);
+ }) &&
+ "All elements in mask must be less than 2 * CommonVF.");
+ if (E1->Scalars.size() == VF && VF != CommonVF) {
+ SmallVector<int> E1Mask = E1->getCommonMask();
+ assert(!E1Mask.empty() && "Expected non-empty common mask.");
+ for (int &Idx : CommonMask) {
+ if (Idx == PoisonMaskElem)
+ continue;
+ if (Idx >= static_cast<int>(CommonVF))
+ Idx = E1Mask[Idx - CommonVF] + VF;
+ }
+ CommonVF = VF;
+ }
+ V1 = Constant::getNullValue(
+ FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF));
+ V2 = getAllOnesValue(
+ *R.DL,
+ FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF));
+ } else {
+ assert(V1 && V2 && "Expected both vectors.");
+ unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ CommonVF =
+ std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
+ assert(all_of(Mask,
+ [=](int Idx) {
+ return Idx < 2 * static_cast<int>(CommonVF);
+ }) &&
+ "All elements in mask must be less than 2 * CommonVF.");
+ if (V1->getType() != V2->getType()) {
+ V1 = Constant::getNullValue(FixedVectorType::get(
+ cast<FixedVectorType>(V1->getType())->getElementType(), CommonVF));
+ V2 = getAllOnesValue(
+ *R.DL, FixedVectorType::get(
+ cast<FixedVectorType>(V1->getType())->getElementType(),
+ CommonVF));
+ }
}
- return BaseShuffleAnalysis::createShuffle<InstructionCost>(V1, V2, Mask,
- Builder);
+ InVectors.front() = Constant::getNullValue(FixedVectorType::get(
+ cast<FixedVectorType>(V1->getType())->getElementType(),
+ CommonMask.size()));
+ if (InVectors.size() == 2)
+ InVectors.pop_back();
+ return BaseShuffleAnalysis::createShuffle<InstructionCost>(
+ V1, V2, CommonMask, Builder);
}
public:
ShuffleCostEstimator(TargetTransformInfo &TTI,
ArrayRef<Value *> VectorizedVals, BoUpSLP &R,
SmallPtrSetImpl<Value *> &CheckedExtracts)
- : TTI(TTI), VectorizedVals(VectorizedVals), R(R),
- CheckedExtracts(CheckedExtracts) {}
- Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask,
- TTI::ShuffleKind ShuffleKind) {
+ : TTI(TTI), VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()),
+ R(R), CheckedExtracts(CheckedExtracts) {}
+ Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
+ ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds,
+ unsigned NumParts, bool &UseVecBaseAsInput) {
+ UseVecBaseAsInput = false;
if (Mask.empty())
return nullptr;
Value *VecBase = nullptr;
ArrayRef<Value *> VL = E->Scalars;
- auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size());
// If the resulting type is scalarized, do not adjust the cost.
- unsigned VecNumParts = TTI.getNumberOfParts(VecTy);
- if (VecNumParts == VecTy->getNumElements())
+ if (NumParts == VL.size())
return nullptr;
- DenseMap<Value *, int> ExtractVectorsTys;
- for (auto [I, V] : enumerate(VL)) {
- // Ignore non-extractelement scalars.
- if (isa<UndefValue>(V) || (!Mask.empty() && Mask[I] == PoisonMaskElem))
- continue;
- // If all users of instruction are going to be vectorized and this
- // instruction itself is not going to be vectorized, consider this
- // instruction as dead and remove its cost from the final cost of the
- // vectorized tree.
- // Also, avoid adjusting the cost for extractelements with multiple uses
- // in different graph entries.
- const TreeEntry *VE = R.getTreeEntry(V);
- if (!CheckedExtracts.insert(V).second ||
- !R.areAllUsersVectorized(cast<Instruction>(V), VectorizedVals) ||
- (VE && VE != E))
- continue;
- auto *EE = cast<ExtractElementInst>(V);
- VecBase = EE->getVectorOperand();
- std::optional<unsigned> EEIdx = getExtractIndex(EE);
- if (!EEIdx)
- continue;
- unsigned Idx = *EEIdx;
- if (VecNumParts != TTI.getNumberOfParts(EE->getVectorOperandType())) {
- auto It =
- ExtractVectorsTys.try_emplace(EE->getVectorOperand(), Idx).first;
- It->getSecond() = std::min<int>(It->second, Idx);
- }
- // Take credit for instruction that will become dead.
- if (EE->hasOneUse()) {
- Instruction *Ext = EE->user_back();
- if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) {
- return isa<GetElementPtrInst>(U);
- })) {
- // Use getExtractWithExtendCost() to calculate the cost of
- // extractelement/ext pair.
- Cost -= TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(),
- EE->getVectorOperandType(), Idx);
- // Add back the cost of s|zext which is subtracted separately.
- Cost += TTI.getCastInstrCost(
- Ext->getOpcode(), Ext->getType(), EE->getType(),
- TTI::getCastContextHint(Ext), CostKind, Ext);
+ // Check if it can be considered reused if same extractelements were
+ // vectorized already.
+ bool PrevNodeFound = any_of(
+ ArrayRef(R.VectorizableTree).take_front(E->Idx),
+ [&](const std::unique_ptr<TreeEntry> &TE) {
+ return ((!TE->isAltShuffle() &&
+ TE->getOpcode() == Instruction::ExtractElement) ||
+ TE->State == TreeEntry::NeedToGather) &&
+ all_of(enumerate(TE->Scalars), [&](auto &&Data) {
+ return VL.size() > Data.index() &&
+ (Mask[Data.index()] == PoisonMaskElem ||
+ isa<UndefValue>(VL[Data.index()]) ||
+ Data.value() == VL[Data.index()]);
+ });
+ });
+ SmallPtrSet<Value *, 4> UniqueBases;
+ unsigned SliceSize = VL.size() / NumParts;
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize);
+ for (auto [I, V] : enumerate(VL.slice(Part * SliceSize, SliceSize))) {
+ // Ignore non-extractelement scalars.
+ if (isa<UndefValue>(V) ||
+ (!SubMask.empty() && SubMask[I] == PoisonMaskElem))
continue;
- }
- }
- Cost -= TTI.getVectorInstrCost(*EE, EE->getVectorOperandType(), CostKind,
- Idx);
- }
- // Add a cost for subvector extracts/inserts if required.
- for (const auto &Data : ExtractVectorsTys) {
- auto *EEVTy = cast<FixedVectorType>(Data.first->getType());
- unsigned NumElts = VecTy->getNumElements();
- if (Data.second % NumElts == 0)
- continue;
- if (TTI.getNumberOfParts(EEVTy) > VecNumParts) {
- unsigned Idx = (Data.second / NumElts) * NumElts;
- unsigned EENumElts = EEVTy->getNumElements();
- if (Idx % NumElts == 0)
+ // If all users of instruction are going to be vectorized and this
+ // instruction itself is not going to be vectorized, consider this
+ // instruction as dead and remove its cost from the final cost of the
+ // vectorized tree.
+ // Also, avoid adjusting the cost for extractelements with multiple uses
+ // in different graph entries.
+ auto *EE = cast<ExtractElementInst>(V);
+ VecBase = EE->getVectorOperand();
+ UniqueBases.insert(VecBase);
+ const TreeEntry *VE = R.getTreeEntry(V);
+ if (!CheckedExtracts.insert(V).second ||
+ !R.areAllUsersVectorized(cast<Instruction>(V), &VectorizedVals) ||
+ (VE && VE != E))
continue;
- if (Idx + NumElts <= EENumElts) {
- Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector,
- EEVTy, std::nullopt, CostKind, Idx, VecTy);
- } else {
- // Need to round up the subvector type vectorization factor to avoid a
- // crash in cost model functions. Make SubVT so that Idx + VF of SubVT
- // <= EENumElts.
- auto *SubVT =
- FixedVectorType::get(VecTy->getElementType(), EENumElts - Idx);
- Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector,
- EEVTy, std::nullopt, CostKind, Idx, SubVT);
+ std::optional<unsigned> EEIdx = getExtractIndex(EE);
+ if (!EEIdx)
+ continue;
+ unsigned Idx = *EEIdx;
+ // Take credit for instruction that will become dead.
+ if (EE->hasOneUse() || !PrevNodeFound) {
+ Instruction *Ext = EE->user_back();
+ if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) {
+ return isa<GetElementPtrInst>(U);
+ })) {
+ // Use getExtractWithExtendCost() to calculate the cost of
+ // extractelement/ext pair.
+ Cost -=
+ TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(),
+ EE->getVectorOperandType(), Idx);
+ // Add back the cost of s|zext which is subtracted separately.
+ Cost += TTI.getCastInstrCost(
+ Ext->getOpcode(), Ext->getType(), EE->getType(),
+ TTI::getCastContextHint(Ext), CostKind, Ext);
+ continue;
+ }
}
- } else {
- Cost += TTI.getShuffleCost(TargetTransformInfo::SK_InsertSubvector,
- VecTy, std::nullopt, CostKind, 0, EEVTy);
+ Cost -= TTI.getVectorInstrCost(*EE, EE->getVectorOperandType(),
+ CostKind, Idx);
}
}
// Check that gather of extractelements can be represented as just a
@@ -7192,31 +7502,152 @@ public:
// Found the bunch of extractelement instructions that must be gathered
// into a vector and can be represented as a permutation elements in a
// single input vector or of 2 input vectors.
- Cost += computeExtractCost(VL, Mask, ShuffleKind);
+ // Done for reused if same extractelements were vectorized already.
+ if (!PrevNodeFound)
+ Cost += computeExtractCost(VL, Mask, ShuffleKinds, NumParts);
+ InVectors.assign(1, E);
+ CommonMask.assign(Mask.begin(), Mask.end());
+ transformMaskAfterShuffle(CommonMask, CommonMask);
+ SameNodesEstimated = false;
+ if (NumParts != 1 && UniqueBases.size() != 1) {
+ UseVecBaseAsInput = true;
+ VecBase = Constant::getNullValue(
+ FixedVectorType::get(VL.front()->getType(), CommonMask.size()));
+ }
return VecBase;
}
- void add(const TreeEntry *E1, const TreeEntry *E2, ArrayRef<int> Mask) {
- CommonMask.assign(Mask.begin(), Mask.end());
- InVectors.assign({E1, E2});
+ /// Checks if the specified entry \p E needs to be delayed because of its
+ /// dependency nodes.
+ std::optional<InstructionCost>
+ needToDelay(const TreeEntry *,
+ ArrayRef<SmallVector<const TreeEntry *>>) const {
+ // No need to delay the cost estimation during analysis.
+ return std::nullopt;
}
- void add(const TreeEntry *E1, ArrayRef<int> Mask) {
- CommonMask.assign(Mask.begin(), Mask.end());
- InVectors.assign(1, E1);
+ void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
+ if (&E1 == &E2) {
+ assert(all_of(Mask,
+ [&](int Idx) {
+ return Idx < static_cast<int>(E1.getVectorFactor());
+ }) &&
+ "Expected single vector shuffle mask.");
+ add(E1, Mask);
+ return;
+ }
+ if (InVectors.empty()) {
+ CommonMask.assign(Mask.begin(), Mask.end());
+ InVectors.assign({&E1, &E2});
+ return;
+ }
+ assert(!CommonMask.empty() && "Expected non-empty common mask.");
+ auto *MaskVecTy =
+ FixedVectorType::get(E1.Scalars.front()->getType(), Mask.size());
+ unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
+ if (NumParts == 0 || NumParts >= Mask.size())
+ NumParts = 1;
+ unsigned SliceSize = Mask.size() / NumParts;
+ const auto *It =
+ find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
+ unsigned Part = std::distance(Mask.begin(), It) / SliceSize;
+ estimateNodesPermuteCost(E1, &E2, Mask, Part, SliceSize);
+ }
+ void add(const TreeEntry &E1, ArrayRef<int> Mask) {
+ if (InVectors.empty()) {
+ CommonMask.assign(Mask.begin(), Mask.end());
+ InVectors.assign(1, &E1);
+ return;
+ }
+ assert(!CommonMask.empty() && "Expected non-empty common mask.");
+ auto *MaskVecTy =
+ FixedVectorType::get(E1.Scalars.front()->getType(), Mask.size());
+ unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
+ if (NumParts == 0 || NumParts >= Mask.size())
+ NumParts = 1;
+ unsigned SliceSize = Mask.size() / NumParts;
+ const auto *It =
+ find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
+ unsigned Part = std::distance(Mask.begin(), It) / SliceSize;
+ estimateNodesPermuteCost(E1, nullptr, Mask, Part, SliceSize);
+ if (!SameNodesEstimated && InVectors.size() == 1)
+ InVectors.emplace_back(&E1);
+ }
+ /// Adds 2 input vectors and the mask for their shuffling.
+ void add(Value *V1, Value *V2, ArrayRef<int> Mask) {
+ // May come only for shuffling of 2 vectors with extractelements, already
+ // handled in adjustExtracts.
+ assert(InVectors.size() == 1 &&
+ all_of(enumerate(CommonMask),
+ [&](auto P) {
+ if (P.value() == PoisonMaskElem)
+ return Mask[P.index()] == PoisonMaskElem;
+ auto *EI =
+ cast<ExtractElementInst>(InVectors.front()
+ .get<const TreeEntry *>()
+ ->Scalars[P.index()]);
+ return EI->getVectorOperand() == V1 ||
+ EI->getVectorOperand() == V2;
+ }) &&
+ "Expected extractelement vectors.");
}
/// Adds another one input vector and the mask for the shuffling.
- void add(Value *V1, ArrayRef<int> Mask) {
- assert(CommonMask.empty() && InVectors.empty() &&
- "Expected empty input mask/vectors.");
- CommonMask.assign(Mask.begin(), Mask.end());
- InVectors.assign(1, V1);
+ void add(Value *V1, ArrayRef<int> Mask, bool ForExtracts = false) {
+ if (InVectors.empty()) {
+ assert(CommonMask.empty() && !ForExtracts &&
+ "Expected empty input mask/vectors.");
+ CommonMask.assign(Mask.begin(), Mask.end());
+ InVectors.assign(1, V1);
+ return;
+ }
+ if (ForExtracts) {
+ // No need to add vectors here, already handled them in adjustExtracts.
+ assert(InVectors.size() == 1 &&
+ InVectors.front().is<const TreeEntry *>() && !CommonMask.empty() &&
+ all_of(enumerate(CommonMask),
+ [&](auto P) {
+ Value *Scalar = InVectors.front()
+ .get<const TreeEntry *>()
+ ->Scalars[P.index()];
+ if (P.value() == PoisonMaskElem)
+ return P.value() == Mask[P.index()] ||
+ isa<UndefValue>(Scalar);
+ if (isa<Constant>(V1))
+ return true;
+ auto *EI = cast<ExtractElementInst>(Scalar);
+ return EI->getVectorOperand() == V1;
+ }) &&
+ "Expected only tree entry for extractelement vectors.");
+ return;
+ }
+ assert(!InVectors.empty() && !CommonMask.empty() &&
+ "Expected only tree entries from extracts/reused buildvectors.");
+ unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ if (InVectors.size() == 2) {
+ Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
+ transformMaskAfterShuffle(CommonMask, CommonMask);
+ VF = std::max<unsigned>(VF, CommonMask.size());
+ } else if (const auto *InTE =
+ InVectors.front().dyn_cast<const TreeEntry *>()) {
+ VF = std::max(VF, InTE->getVectorFactor());
+ } else {
+ VF = std::max(
+ VF, cast<FixedVectorType>(InVectors.front().get<Value *>()->getType())
+ ->getNumElements());
+ }
+ InVectors.push_back(V1);
+ for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
+ if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
+ CommonMask[Idx] = Mask[Idx] + VF;
}
- Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) {
+ Value *gather(ArrayRef<Value *> VL, unsigned MaskVF = 0,
+ Value *Root = nullptr) {
Cost += getBuildVectorCost(VL, Root);
if (!Root) {
- assert(InVectors.empty() && "Unexpected input vectors for buildvector.");
// FIXME: Need to find a way to avoid use of getNullValue here.
SmallVector<Constant *> Vals;
- for (Value *V : VL) {
+ unsigned VF = VL.size();
+ if (MaskVF != 0)
+ VF = std::min(VF, MaskVF);
+ for (Value *V : VL.take_front(VF)) {
if (isa<UndefValue>(V)) {
Vals.push_back(cast<Constant>(V));
continue;
@@ -7226,9 +7657,11 @@ public:
return ConstantVector::get(Vals);
}
return ConstantVector::getSplat(
- ElementCount::getFixed(VL.size()),
- Constant::getNullValue(VL.front()->getType()));
+ ElementCount::getFixed(
+ cast<FixedVectorType>(Root->getType())->getNumElements()),
+ getAllOnesValue(*R.DL, VL.front()->getType()));
}
+ InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
/// Finalize emission of the shuffles.
InstructionCost
finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
@@ -7236,31 +7669,24 @@ public:
IsFinalized = true;
if (Action) {
const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front();
- if (InVectors.size() == 2) {
+ if (InVectors.size() == 2)
Cost += createShuffle(Vec, InVectors.back(), CommonMask);
- InVectors.pop_back();
- } else {
+ else
Cost += createShuffle(Vec, nullptr, CommonMask);
- }
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
if (CommonMask[Idx] != PoisonMaskElem)
CommonMask[Idx] = Idx;
assert(VF > 0 &&
"Expected vector length for the final value before action.");
- Value *V = Vec.dyn_cast<Value *>();
- if (!Vec.isNull() && !V)
- V = Constant::getNullValue(FixedVectorType::get(
- Vec.get<const TreeEntry *>()->Scalars.front()->getType(),
- CommonMask.size()));
+ Value *V = Vec.get<Value *>();
Action(V, CommonMask);
+ InVectors.front() = V;
}
::addMask(CommonMask, ExtMask, /*ExtendingManyInputs=*/true);
- if (CommonMask.empty())
- return Cost;
- int Limit = CommonMask.size() * 2;
- if (all_of(CommonMask, [=](int Idx) { return Idx < Limit; }) &&
- ShuffleVectorInst::isIdentityMask(CommonMask))
+ if (CommonMask.empty()) {
+ assert(InVectors.size() == 1 && "Expected only one vector with no mask");
return Cost;
+ }
return Cost +
createShuffle(InVectors.front(),
InVectors.size() == 2 ? InVectors.back() : nullptr,
@@ -7273,28 +7699,63 @@ public:
}
};
+const BoUpSLP::TreeEntry *BoUpSLP::getOperandEntry(const TreeEntry *E,
+ unsigned Idx) const {
+ Value *Op = E->getOperand(Idx).front();
+ if (const TreeEntry *TE = getTreeEntry(Op)) {
+ if (find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) {
+ return EI.EdgeIdx == Idx && EI.UserTE == E;
+ }) != TE->UserTreeIndices.end())
+ return TE;
+ auto MIt = MultiNodeScalars.find(Op);
+ if (MIt != MultiNodeScalars.end()) {
+ for (const TreeEntry *TE : MIt->second) {
+ if (find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) {
+ return EI.EdgeIdx == Idx && EI.UserTE == E;
+ }) != TE->UserTreeIndices.end())
+ return TE;
+ }
+ }
+ }
+ const auto *It =
+ find_if(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) {
+ return TE->State == TreeEntry::NeedToGather &&
+ find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) {
+ return EI.EdgeIdx == Idx && EI.UserTE == E;
+ }) != TE->UserTreeIndices.end();
+ });
+ assert(It != VectorizableTree.end() && "Expected vectorizable entry.");
+ return It->get();
+}
+
InstructionCost
BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
SmallPtrSetImpl<Value *> &CheckedExtracts) {
ArrayRef<Value *> VL = E->Scalars;
Type *ScalarTy = VL[0]->getType();
- if (auto *SI = dyn_cast<StoreInst>(VL[0]))
- ScalarTy = SI->getValueOperand()->getType();
- else if (auto *CI = dyn_cast<CmpInst>(VL[0]))
- ScalarTy = CI->getOperand(0)->getType();
- else if (auto *IE = dyn_cast<InsertElementInst>(VL[0]))
- ScalarTy = IE->getOperand(1)->getType();
+ if (E->State != TreeEntry::NeedToGather) {
+ if (auto *SI = dyn_cast<StoreInst>(VL[0]))
+ ScalarTy = SI->getValueOperand()->getType();
+ else if (auto *CI = dyn_cast<CmpInst>(VL[0]))
+ ScalarTy = CI->getOperand(0)->getType();
+ else if (auto *IE = dyn_cast<InsertElementInst>(VL[0]))
+ ScalarTy = IE->getOperand(1)->getType();
+ }
+ if (!FixedVectorType::isValidElementType(ScalarTy))
+ return InstructionCost::getInvalid();
auto *VecTy = FixedVectorType::get(ScalarTy, VL.size());
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
// If we have computed a smaller type for the expression, update VecTy so
// that the costs will be accurate.
- if (MinBWs.count(VL[0]))
- VecTy = FixedVectorType::get(
- IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size());
+ auto It = MinBWs.find(E);
+ if (It != MinBWs.end()) {
+ ScalarTy = IntegerType::get(F->getContext(), It->second.first);
+ VecTy = FixedVectorType::get(ScalarTy, VL.size());
+ }
unsigned EntryVF = E->getVectorFactor();
- auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF);
+ auto *FinalVecTy = FixedVectorType::get(ScalarTy, EntryVF);
bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty();
if (E->State == TreeEntry::NeedToGather) {
@@ -7302,121 +7763,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return 0;
if (isa<InsertElementInst>(VL[0]))
return InstructionCost::getInvalid();
- ShuffleCostEstimator Estimator(*TTI, VectorizedVals, *this,
- CheckedExtracts);
- unsigned VF = E->getVectorFactor();
- SmallVector<int> ReuseShuffleIndicies(E->ReuseShuffleIndices.begin(),
- E->ReuseShuffleIndices.end());
- SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end());
- // Build a mask out of the reorder indices and reorder scalars per this
- // mask.
- SmallVector<int> ReorderMask;
- inversePermutation(E->ReorderIndices, ReorderMask);
- if (!ReorderMask.empty())
- reorderScalars(GatheredScalars, ReorderMask);
- SmallVector<int> Mask;
- SmallVector<int> ExtractMask;
- std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle;
- std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle;
- SmallVector<const TreeEntry *> Entries;
- Type *ScalarTy = GatheredScalars.front()->getType();
- // Check for gathered extracts.
- ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask);
- SmallVector<Value *> IgnoredVals;
- if (UserIgnoreList)
- IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end());
-
- bool Resized = false;
- if (Value *VecBase = Estimator.adjustExtracts(
- E, ExtractMask, ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc)))
- if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType()))
- if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) {
- Resized = true;
- GatheredScalars.append(VF - GatheredScalars.size(),
- PoisonValue::get(ScalarTy));
- }
-
- // Do not try to look for reshuffled loads for gathered loads (they will be
- // handled later), for vectorized scalars, and cases, which are definitely
- // not profitable (splats and small gather nodes.)
- if (ExtractShuffle || E->getOpcode() != Instruction::Load ||
- E->isAltShuffle() ||
- all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) ||
- isSplat(E->Scalars) ||
- (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2))
- GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries);
- if (GatherShuffle) {
- assert((Entries.size() == 1 || Entries.size() == 2) &&
- "Expected shuffle of 1 or 2 entries.");
- if (*GatherShuffle == TTI::SK_PermuteSingleSrc &&
- Entries.front()->isSame(E->Scalars)) {
- // Perfect match in the graph, will reuse the previously vectorized
- // node. Cost is 0.
- LLVM_DEBUG(
- dbgs()
- << "SLP: perfect diamond match for gather bundle that starts with "
- << *VL.front() << ".\n");
- // Restore the mask for previous partially matched values.
- for (auto [I, V] : enumerate(E->Scalars)) {
- if (isa<PoisonValue>(V)) {
- Mask[I] = PoisonMaskElem;
- continue;
- }
- if (Mask[I] == PoisonMaskElem)
- Mask[I] = Entries.front()->findLaneForValue(V);
- }
- Estimator.add(Entries.front(), Mask);
- return Estimator.finalize(E->ReuseShuffleIndices);
- }
- if (!Resized) {
- unsigned VF1 = Entries.front()->getVectorFactor();
- unsigned VF2 = Entries.back()->getVectorFactor();
- if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF)
- GatheredScalars.append(VF - GatheredScalars.size(),
- PoisonValue::get(ScalarTy));
- }
- // Remove shuffled elements from list of gathers.
- for (int I = 0, Sz = Mask.size(); I < Sz; ++I) {
- if (Mask[I] != PoisonMaskElem)
- GatheredScalars[I] = PoisonValue::get(ScalarTy);
- }
- LLVM_DEBUG(dbgs() << "SLP: shuffled " << Entries.size()
- << " entries for bundle that starts with "
- << *VL.front() << ".\n";);
- if (Entries.size() == 1)
- Estimator.add(Entries.front(), Mask);
- else
- Estimator.add(Entries.front(), Entries.back(), Mask);
- if (all_of(GatheredScalars, PoisonValue ::classof))
- return Estimator.finalize(E->ReuseShuffleIndices);
- return Estimator.finalize(
- E->ReuseShuffleIndices, E->Scalars.size(),
- [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
- Vec = Estimator.gather(GatheredScalars,
- Constant::getNullValue(FixedVectorType::get(
- GatheredScalars.front()->getType(),
- GatheredScalars.size())));
- });
- }
- if (!all_of(GatheredScalars, PoisonValue::classof)) {
- auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size());
- bool SameGathers = VL.equals(Gathers);
- Value *BV = Estimator.gather(
- Gathers, SameGathers ? nullptr
- : Constant::getNullValue(FixedVectorType::get(
- GatheredScalars.front()->getType(),
- GatheredScalars.size())));
- SmallVector<int> ReuseMask(Gathers.size(), PoisonMaskElem);
- std::iota(ReuseMask.begin(), ReuseMask.end(), 0);
- Estimator.add(BV, ReuseMask);
- }
- if (ExtractShuffle)
- Estimator.add(E, std::nullopt);
- return Estimator.finalize(E->ReuseShuffleIndices);
+ return processBuildVector<ShuffleCostEstimator, InstructionCost>(
+ E, *TTI, VectorizedVals, *this, CheckedExtracts);
}
InstructionCost CommonCost = 0;
SmallVector<int> Mask;
- if (!E->ReorderIndices.empty()) {
+ if (!E->ReorderIndices.empty() &&
+ E->State != TreeEntry::PossibleStridedVectorize) {
SmallVector<int> NewMask;
if (E->getOpcode() == Instruction::Store) {
// For stores the order is actually a mask.
@@ -7429,11 +7782,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
if (NeedToShuffleReuses)
::addMask(Mask, E->ReuseShuffleIndices);
- if (!Mask.empty() && !ShuffleVectorInst::isIdentityMask(Mask))
+ if (!Mask.empty() && !ShuffleVectorInst::isIdentityMask(Mask, Mask.size()))
CommonCost =
TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, FinalVecTy, Mask);
assert((E->State == TreeEntry::Vectorize ||
- E->State == TreeEntry::ScatterVectorize) &&
+ E->State == TreeEntry::ScatterVectorize ||
+ E->State == TreeEntry::PossibleStridedVectorize) &&
"Unhandled state");
assert(E->getOpcode() &&
((allSameType(VL) && allSameBlock(VL)) ||
@@ -7443,7 +7797,34 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
Instruction *VL0 = E->getMainOp();
unsigned ShuffleOrOp =
E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode();
- const unsigned Sz = VL.size();
+ SetVector<Value *> UniqueValues(VL.begin(), VL.end());
+ const unsigned Sz = UniqueValues.size();
+ SmallBitVector UsedScalars(Sz, false);
+ for (unsigned I = 0; I < Sz; ++I) {
+ if (getTreeEntry(UniqueValues[I]) == E)
+ continue;
+ UsedScalars.set(I);
+ }
+ auto GetCastContextHint = [&](Value *V) {
+ if (const TreeEntry *OpTE = getTreeEntry(V)) {
+ if (OpTE->State == TreeEntry::ScatterVectorize)
+ return TTI::CastContextHint::GatherScatter;
+ if (OpTE->State == TreeEntry::Vectorize &&
+ OpTE->getOpcode() == Instruction::Load && !OpTE->isAltShuffle()) {
+ if (OpTE->ReorderIndices.empty())
+ return TTI::CastContextHint::Normal;
+ SmallVector<int> Mask;
+ inversePermutation(OpTE->ReorderIndices, Mask);
+ if (ShuffleVectorInst::isReverseMask(Mask, Mask.size()))
+ return TTI::CastContextHint::Reversed;
+ }
+ } else {
+ InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
+ if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
+ return TTI::CastContextHint::GatherScatter;
+ }
+ return TTI::CastContextHint::None;
+ };
auto GetCostDiff =
[=](function_ref<InstructionCost(unsigned)> ScalarEltCost,
function_ref<InstructionCost(InstructionCost)> VectorCost) {
@@ -7453,13 +7834,49 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
// For some of the instructions no need to calculate cost for each
// particular instruction, we can use the cost of the single
// instruction x total number of scalar instructions.
- ScalarCost = Sz * ScalarEltCost(0);
+ ScalarCost = (Sz - UsedScalars.count()) * ScalarEltCost(0);
} else {
- for (unsigned I = 0; I < Sz; ++I)
+ for (unsigned I = 0; I < Sz; ++I) {
+ if (UsedScalars.test(I))
+ continue;
ScalarCost += ScalarEltCost(I);
+ }
}
InstructionCost VecCost = VectorCost(CommonCost);
+ // Check if the current node must be resized, if the parent node is not
+ // resized.
+ if (!UnaryInstruction::isCast(E->getOpcode()) && E->Idx != 0) {
+ const EdgeInfo &EI = E->UserTreeIndices.front();
+ if ((EI.UserTE->getOpcode() != Instruction::Select ||
+ EI.EdgeIdx != 0) &&
+ It != MinBWs.end()) {
+ auto UserBWIt = MinBWs.find(EI.UserTE);
+ Type *UserScalarTy =
+ EI.UserTE->getOperand(EI.EdgeIdx).front()->getType();
+ if (UserBWIt != MinBWs.end())
+ UserScalarTy = IntegerType::get(ScalarTy->getContext(),
+ UserBWIt->second.first);
+ if (ScalarTy != UserScalarTy) {
+ unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+ unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy);
+ unsigned VecOpcode;
+ auto *SrcVecTy =
+ FixedVectorType::get(UserScalarTy, E->getVectorFactor());
+ if (BWSz > SrcBWSz)
+ VecOpcode = Instruction::Trunc;
+ else
+ VecOpcode =
+ It->second.second ? Instruction::SExt : Instruction::ZExt;
+ TTI::CastContextHint CCH = GetCastContextHint(VL0);
+ VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH,
+ CostKind);
+ ScalarCost +=
+ Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy,
+ CCH, CostKind);
+ }
+ }
+ }
LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost - CommonCost,
ScalarCost, "Calculated costs for Tree"));
return VecCost - ScalarCost;
@@ -7550,7 +7967,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
// Count reused scalars.
InstructionCost ScalarCost = 0;
SmallPtrSet<const TreeEntry *, 4> CountedOps;
- for (Value *V : VL) {
+ for (Value *V : UniqueValues) {
auto *PHI = dyn_cast<PHINode>(V);
if (!PHI)
continue;
@@ -7571,8 +7988,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
case Instruction::ExtractValue:
case Instruction::ExtractElement: {
- auto GetScalarCost = [=](unsigned Idx) {
- auto *I = cast<Instruction>(VL[Idx]);
+ auto GetScalarCost = [&](unsigned Idx) {
+ auto *I = cast<Instruction>(UniqueValues[Idx]);
VectorType *SrcVecTy;
if (ShuffleOrOp == Instruction::ExtractElement) {
auto *EE = cast<ExtractElementInst>(I);
@@ -7680,8 +8097,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
// need to shift the vector.
// Do not calculate the cost if the actual size is the register size and
// we can merge this shuffle with the following SK_Select.
- auto *InsertVecTy =
- FixedVectorType::get(SrcVecTy->getElementType(), InsertVecSz);
+ auto *InsertVecTy = FixedVectorType::get(ScalarTy, InsertVecSz);
if (!IsIdentity)
Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
InsertVecTy, Mask);
@@ -7697,8 +8113,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
buildUseMask(NumElts, InsertMask, UseMask::UndefsAsMask));
if (!InMask.all() && NumScalars != NumElts && !IsWholeSubvector) {
if (InsertVecSz != VecSz) {
- auto *ActualVecTy =
- FixedVectorType::get(SrcVecTy->getElementType(), VecSz);
+ auto *ActualVecTy = FixedVectorType::get(ScalarTy, VecSz);
Cost += TTI->getShuffleCost(TTI::SK_InsertSubvector, ActualVecTy,
std::nullopt, CostKind, OffsetBeg - Offset,
InsertVecTy);
@@ -7729,22 +8144,52 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast: {
- auto GetScalarCost = [=](unsigned Idx) {
- auto *VI = cast<Instruction>(VL[Idx]);
- return TTI->getCastInstrCost(E->getOpcode(), ScalarTy,
- VI->getOperand(0)->getType(),
+ auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
+ Type *SrcScalarTy = VL0->getOperand(0)->getType();
+ auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
+ unsigned Opcode = ShuffleOrOp;
+ unsigned VecOpcode = Opcode;
+ if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
+ (SrcIt != MinBWs.end() || It != MinBWs.end())) {
+ // Check if the values are candidates to demote.
+ unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
+ if (SrcIt != MinBWs.end()) {
+ SrcBWSz = SrcIt->second.first;
+ SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
+ SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
+ }
+ unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+ if (BWSz == SrcBWSz) {
+ VecOpcode = Instruction::BitCast;
+ } else if (BWSz < SrcBWSz) {
+ VecOpcode = Instruction::Trunc;
+ } else if (It != MinBWs.end()) {
+ assert(BWSz > SrcBWSz && "Invalid cast!");
+ VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+ }
+ }
+ auto GetScalarCost = [&](unsigned Idx) -> InstructionCost {
+ // Do not count cost here if minimum bitwidth is in effect and it is just
+ // a bitcast (here it is just a noop).
+ if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
+ return TTI::TCC_Free;
+ auto *VI = VL0->getOpcode() == Opcode
+ ? cast<Instruction>(UniqueValues[Idx])
+ : nullptr;
+ return TTI->getCastInstrCost(Opcode, VL0->getType(),
+ VL0->getOperand(0)->getType(),
TTI::getCastContextHint(VI), CostKind, VI);
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
- Type *SrcTy = VL0->getOperand(0)->getType();
- auto *SrcVecTy = FixedVectorType::get(SrcTy, VL.size());
- InstructionCost VecCost = CommonCost;
- // Check if the values are candidates to demote.
- if (!MinBWs.count(VL0) || VecTy != SrcVecTy)
- VecCost +=
- TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
- TTI::getCastContextHint(VL0), CostKind, VL0);
- return VecCost;
+ // Do not count cost here if minimum bitwidth is in effect and it is just
+ // a bitcast (here it is just a noop).
+ if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
+ return CommonCost;
+ auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
+ TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+ return CommonCost +
+ TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
+ VecOpcode == Opcode ? VI : nullptr);
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
@@ -7761,7 +8206,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
? CmpInst::BAD_FCMP_PREDICATE
: CmpInst::BAD_ICMP_PREDICATE;
auto GetScalarCost = [&](unsigned Idx) {
- auto *VI = cast<Instruction>(VL[Idx]);
+ auto *VI = cast<Instruction>(UniqueValues[Idx]);
CmpInst::Predicate CurrentPred = ScalarTy->isFloatingPointTy()
? CmpInst::BAD_FCMP_PREDICATE
: CmpInst::BAD_ICMP_PREDICATE;
@@ -7821,8 +8266,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
case Instruction::And:
case Instruction::Or:
case Instruction::Xor: {
- auto GetScalarCost = [=](unsigned Idx) {
- auto *VI = cast<Instruction>(VL[Idx]);
+ auto GetScalarCost = [&](unsigned Idx) {
+ auto *VI = cast<Instruction>(UniqueValues[Idx]);
unsigned OpIdx = isa<UnaryOperator>(VI) ? 0 : 1;
TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(VI->getOperand(0));
TTI::OperandValueInfo Op2Info =
@@ -7833,8 +8278,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
- TTI::OperandValueInfo Op1Info = getOperandInfo(VL, 0);
- TTI::OperandValueInfo Op2Info = getOperandInfo(VL, OpIdx);
+ TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
+ TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
Op2Info) +
CommonCost;
@@ -7845,23 +8290,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return CommonCost + GetGEPCostDiff(VL, VL0);
}
case Instruction::Load: {
- auto GetScalarCost = [=](unsigned Idx) {
- auto *VI = cast<LoadInst>(VL[Idx]);
+ auto GetScalarCost = [&](unsigned Idx) {
+ auto *VI = cast<LoadInst>(UniqueValues[Idx]);
return TTI->getMemoryOpCost(Instruction::Load, ScalarTy, VI->getAlign(),
VI->getPointerAddressSpace(), CostKind,
TTI::OperandValueInfo(), VI);
};
auto *LI0 = cast<LoadInst>(VL0);
- auto GetVectorCost = [=](InstructionCost CommonCost) {
+ auto GetVectorCost = [&](InstructionCost CommonCost) {
InstructionCost VecLdCost;
if (E->State == TreeEntry::Vectorize) {
VecLdCost = TTI->getMemoryOpCost(
Instruction::Load, VecTy, LI0->getAlign(),
LI0->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo());
} else {
- assert(E->State == TreeEntry::ScatterVectorize && "Unknown EntryState");
+ assert((E->State == TreeEntry::ScatterVectorize ||
+ E->State == TreeEntry::PossibleStridedVectorize) &&
+ "Unknown EntryState");
Align CommonAlignment = LI0->getAlign();
- for (Value *V : VL)
+ for (Value *V : UniqueValues)
CommonAlignment =
std::min(CommonAlignment, cast<LoadInst>(V)->getAlign());
VecLdCost = TTI->getGatherScatterOpCost(
@@ -7874,7 +8321,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
InstructionCost Cost = GetCostDiff(GetScalarCost, GetVectorCost);
// If this node generates masked gather load then it is not a terminal node.
// Hence address operand cost is estimated separately.
- if (E->State == TreeEntry::ScatterVectorize)
+ if (E->State == TreeEntry::ScatterVectorize ||
+ E->State == TreeEntry::PossibleStridedVectorize)
return Cost;
// Estimate cost of GEPs since this tree node is a terminator.
@@ -7887,7 +8335,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
bool IsReorder = !E->ReorderIndices.empty();
auto GetScalarCost = [=](unsigned Idx) {
auto *VI = cast<StoreInst>(VL[Idx]);
- TTI::OperandValueInfo OpInfo = getOperandInfo(VI, 0);
+ TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(VI->getValueOperand());
return TTI->getMemoryOpCost(Instruction::Store, ScalarTy, VI->getAlign(),
VI->getPointerAddressSpace(), CostKind,
OpInfo, VI);
@@ -7896,7 +8344,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0);
auto GetVectorCost = [=](InstructionCost CommonCost) {
// We know that we can merge the stores. Calculate the cost.
- TTI::OperandValueInfo OpInfo = getOperandInfo(VL, 0);
+ TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0));
return TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(),
BaseSI->getPointerAddressSpace(), CostKind,
OpInfo) +
@@ -7912,8 +8360,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
GetGEPCostDiff(PointerOps, BaseSI->getPointerOperand());
}
case Instruction::Call: {
- auto GetScalarCost = [=](unsigned Idx) {
- auto *CI = cast<CallInst>(VL[Idx]);
+ auto GetScalarCost = [&](unsigned Idx) {
+ auto *CI = cast<CallInst>(UniqueValues[Idx]);
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
if (ID != Intrinsic::not_intrinsic) {
IntrinsicCostAttributes CostAttrs(ID, *CI, 1);
@@ -7954,8 +8402,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
return false;
};
- auto GetScalarCost = [=](unsigned Idx) {
- auto *VI = cast<Instruction>(VL[Idx]);
+ auto GetScalarCost = [&](unsigned Idx) {
+ auto *VI = cast<Instruction>(UniqueValues[Idx]);
assert(E->isOpcodeOrAlt(VI) && "Unexpected main/alternate opcode");
(void)E;
return TTI->getInstructionCost(VI, CostKind);
@@ -7995,21 +8443,15 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty,
TTI::CastContextHint::None, CostKind);
}
- if (E->ReuseShuffleIndices.empty()) {
- VecCost +=
- TTI->getShuffleCost(TargetTransformInfo::SK_Select, FinalVecTy);
- } else {
- SmallVector<int> Mask;
- buildShuffleEntryMask(
- E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices,
- [E](Instruction *I) {
- assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
- return I->getOpcode() == E->getAltOpcode();
- },
- Mask);
- VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
- FinalVecTy, Mask);
- }
+ SmallVector<int> Mask;
+ E->buildAltOpShuffleMask(
+ [E](Instruction *I) {
+ assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
+ return I->getOpcode() == E->getAltOpcode();
+ },
+ Mask);
+ VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
+ FinalVecTy, Mask);
return VecCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -8065,7 +8507,8 @@ bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const {
// Gathering cost would be too much for tiny trees.
if (VectorizableTree[0]->State == TreeEntry::NeedToGather ||
(VectorizableTree[1]->State == TreeEntry::NeedToGather &&
- VectorizableTree[0]->State != TreeEntry::ScatterVectorize))
+ VectorizableTree[0]->State != TreeEntry::ScatterVectorize &&
+ VectorizableTree[0]->State != TreeEntry::PossibleStridedVectorize))
return false;
return true;
@@ -8144,6 +8587,23 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
allConstant(VectorizableTree[1]->Scalars))))
return true;
+ // If the graph includes only PHI nodes and gathers, it is definitely not
+ // profitable for the vectorization, we can skip it, if the cost threshold is
+ // default. The cost of vectorized PHI nodes is almost always 0 + the cost of
+ // gathers/buildvectors.
+ constexpr int Limit = 4;
+ if (!ForReduction && !SLPCostThreshold.getNumOccurrences() &&
+ !VectorizableTree.empty() &&
+ all_of(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) {
+ return (TE->State == TreeEntry::NeedToGather &&
+ TE->getOpcode() != Instruction::ExtractElement &&
+ count_if(TE->Scalars,
+ [](Value *V) { return isa<ExtractElementInst>(V); }) <=
+ Limit) ||
+ TE->getOpcode() == Instruction::PHI;
+ }))
+ return true;
+
// We can vectorize the tree if its size is greater than or equal to the
// minimum size specified by the MinTreeSize command line option.
if (VectorizableTree.size() >= MinTreeSize)
@@ -8435,16 +8895,6 @@ static T *performExtractsShuffleAction(
}
InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
- // Build a map for gathered scalars to the nodes where they are used.
- ValueToGatherNodes.clear();
- for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) {
- if (EntryPtr->State != TreeEntry::NeedToGather)
- continue;
- for (Value *V : EntryPtr->Scalars)
- if (!isConstant(V))
- ValueToGatherNodes.try_emplace(V).first->getSecond().insert(
- EntryPtr.get());
- }
InstructionCost Cost = 0;
LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size "
<< VectorizableTree.size() << ".\n");
@@ -8460,8 +8910,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
E->isSame(TE.Scalars)) {
// Some gather nodes might be absolutely the same as some vectorizable
// nodes after reordering, need to handle it.
- LLVM_DEBUG(dbgs() << "SLP: Adding cost 0 for bundle that starts with "
- << *TE.Scalars[0] << ".\n"
+ LLVM_DEBUG(dbgs() << "SLP: Adding cost 0 for bundle "
+ << shortBundleName(TE.Scalars) << ".\n"
<< "SLP: Current total cost = " << Cost << "\n");
continue;
}
@@ -8469,9 +8919,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
InstructionCost C = getEntryCost(&TE, VectorizedVals, CheckedExtracts);
Cost += C;
- LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C
- << " for bundle that starts with " << *TE.Scalars[0]
- << ".\n"
+ LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle "
+ << shortBundleName(TE.Scalars) << ".\n"
<< "SLP: Current total cost = " << Cost << "\n");
}
@@ -8480,6 +8929,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
SmallVector<MapVector<const TreeEntry *, SmallVector<int>>> ShuffleMasks;
SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers;
SmallVector<APInt> DemandedElts;
+ SmallDenseSet<Value *, 4> UsedInserts;
+ DenseSet<Value *> VectorCasts;
for (ExternalUser &EU : ExternalUses) {
// We only add extract cost once for the same scalar.
if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
@@ -8500,6 +8951,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
// to detect it as a final shuffled/identity match.
if (auto *VU = dyn_cast_or_null<InsertElementInst>(EU.User)) {
if (auto *FTy = dyn_cast<FixedVectorType>(VU->getType())) {
+ if (!UsedInserts.insert(VU).second)
+ continue;
std::optional<unsigned> InsertIdx = getInsertIndex(VU);
if (InsertIdx) {
const TreeEntry *ScalarTE = getTreeEntry(EU.Scalar);
@@ -8546,6 +8999,28 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
FirstUsers.emplace_back(VU, ScalarTE);
DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
VecId = FirstUsers.size() - 1;
+ auto It = MinBWs.find(ScalarTE);
+ if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
+ unsigned BWSz = It->second.second;
+ unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
+ unsigned VecOpcode;
+ if (BWSz < SrcBWSz)
+ VecOpcode = Instruction::Trunc;
+ else
+ VecOpcode =
+ It->second.second ? Instruction::SExt : Instruction::ZExt;
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost C = TTI->getCastInstrCost(
+ VecOpcode, FTy,
+ FixedVectorType::get(
+ IntegerType::get(FTy->getContext(), It->second.first),
+ FTy->getNumElements()),
+ TTI::CastContextHint::None, CostKind);
+ LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C
+ << " for extending externally used vector with "
+ "non-equal minimum bitwidth.\n");
+ Cost += C;
+ }
} else {
if (isFirstInsertElement(VU, cast<InsertElementInst>(It->first)))
It->first = VU;
@@ -8567,11 +9042,11 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
// for the extract and the added cost of the sign extend if needed.
auto *VecTy = FixedVectorType::get(EU.Scalar->getType(), BundleWidth);
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
- auto *ScalarRoot = VectorizableTree[0]->Scalars[0];
- if (MinBWs.count(ScalarRoot)) {
- auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first);
- auto Extend =
- MinBWs[ScalarRoot].second ? Instruction::SExt : Instruction::ZExt;
+ auto It = MinBWs.find(getTreeEntry(EU.Scalar));
+ if (It != MinBWs.end()) {
+ auto *MinTy = IntegerType::get(F->getContext(), It->second.first);
+ unsigned Extend =
+ It->second.second ? Instruction::SExt : Instruction::ZExt;
VecTy = FixedVectorType::get(MinTy, BundleWidth);
ExtractCost += TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
VecTy, EU.Lane);
@@ -8580,6 +9055,21 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
CostKind, EU.Lane);
}
}
+ // Add reduced value cost, if resized.
+ if (!VectorizedVals.empty()) {
+ auto BWIt = MinBWs.find(VectorizableTree.front().get());
+ if (BWIt != MinBWs.end()) {
+ Type *DstTy = VectorizableTree.front()->Scalars.front()->getType();
+ unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
+ unsigned Opcode = Instruction::Trunc;
+ if (OriginalSz < BWIt->second.first)
+ Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
+ Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
+ Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
+ TTI::CastContextHint::None,
+ TTI::TCK_RecipThroughput);
+ }
+ }
InstructionCost SpillCost = getSpillCost();
Cost += SpillCost + ExtractCost;
@@ -8590,9 +9080,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
unsigned VecVF = TE->getVectorFactor();
if (VF != VecVF &&
(any_of(Mask, [VF](int Idx) { return Idx >= static_cast<int>(VF); }) ||
- (all_of(Mask,
- [VF](int Idx) { return Idx < 2 * static_cast<int>(VF); }) &&
- !ShuffleVectorInst::isIdentityMask(Mask)))) {
+ !ShuffleVectorInst::isIdentityMask(Mask, VF))) {
SmallVector<int> OrigMask(VecVF, PoisonMaskElem);
std::copy(Mask.begin(), std::next(Mask.begin(), std::min(VF, VecVF)),
OrigMask.begin());
@@ -8611,19 +9099,23 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
// Calculate the cost of the reshuffled vectors, if any.
for (int I = 0, E = FirstUsers.size(); I < E; ++I) {
Value *Base = cast<Instruction>(FirstUsers[I].first)->getOperand(0);
- unsigned VF = ShuffleMasks[I].begin()->second.size();
- auto *FTy = FixedVectorType::get(
- cast<VectorType>(FirstUsers[I].first->getType())->getElementType(), VF);
auto Vector = ShuffleMasks[I].takeVector();
- auto &&EstimateShufflesCost = [this, FTy,
- &Cost](ArrayRef<int> Mask,
- ArrayRef<const TreeEntry *> TEs) {
+ unsigned VF = 0;
+ auto EstimateShufflesCost = [&](ArrayRef<int> Mask,
+ ArrayRef<const TreeEntry *> TEs) {
assert((TEs.size() == 1 || TEs.size() == 2) &&
"Expected exactly 1 or 2 tree entries.");
if (TEs.size() == 1) {
- int Limit = 2 * Mask.size();
- if (!all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) ||
- !ShuffleVectorInst::isIdentityMask(Mask)) {
+ if (VF == 0)
+ VF = TEs.front()->getVectorFactor();
+ auto *FTy =
+ FixedVectorType::get(TEs.back()->Scalars.front()->getType(), VF);
+ if (!ShuffleVectorInst::isIdentityMask(Mask, VF) &&
+ !all_of(enumerate(Mask), [=](const auto &Data) {
+ return Data.value() == PoisonMaskElem ||
+ (Data.index() < VF &&
+ static_cast<int>(Data.index()) == Data.value());
+ })) {
InstructionCost C =
TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, FTy, Mask);
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C
@@ -8634,6 +9126,15 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
Cost += C;
}
} else {
+ if (VF == 0) {
+ if (TEs.front() &&
+ TEs.front()->getVectorFactor() == TEs.back()->getVectorFactor())
+ VF = TEs.front()->getVectorFactor();
+ else
+ VF = Mask.size();
+ }
+ auto *FTy =
+ FixedVectorType::get(TEs.back()->Scalars.front()->getType(), VF);
InstructionCost C =
TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, FTy, Mask);
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C
@@ -8643,6 +9144,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
dbgs() << "SLP: Current total cost = " << Cost << "\n");
Cost += C;
}
+ VF = Mask.size();
return TEs.back();
};
(void)performExtractsShuffleAction<const TreeEntry>(
@@ -8671,54 +9173,198 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
return Cost;
}
-std::optional<TargetTransformInfo::ShuffleKind>
-BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
- SmallVectorImpl<int> &Mask,
- SmallVectorImpl<const TreeEntry *> &Entries) {
- Entries.clear();
- // No need to check for the topmost gather node.
- if (TE == VectorizableTree.front().get())
+/// Tries to find extractelement instructions with constant indices from fixed
+/// vector type and gather such instructions into a bunch, which highly likely
+/// might be detected as a shuffle of 1 or 2 input vectors. If this attempt was
+/// successful, the matched scalars are replaced by poison values in \p VL for
+/// future analysis.
+std::optional<TTI::ShuffleKind>
+BoUpSLP::tryToGatherSingleRegisterExtractElements(
+ MutableArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) const {
+ // Scan list of gathered scalars for extractelements that can be represented
+ // as shuffles.
+ MapVector<Value *, SmallVector<int>> VectorOpToIdx;
+ SmallVector<int> UndefVectorExtracts;
+ for (int I = 0, E = VL.size(); I < E; ++I) {
+ auto *EI = dyn_cast<ExtractElementInst>(VL[I]);
+ if (!EI) {
+ if (isa<UndefValue>(VL[I]))
+ UndefVectorExtracts.push_back(I);
+ continue;
+ }
+ auto *VecTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType());
+ if (!VecTy || !isa<ConstantInt, UndefValue>(EI->getIndexOperand()))
+ continue;
+ std::optional<unsigned> Idx = getExtractIndex(EI);
+ // Undefined index.
+ if (!Idx) {
+ UndefVectorExtracts.push_back(I);
+ continue;
+ }
+ SmallBitVector ExtractMask(VecTy->getNumElements(), true);
+ ExtractMask.reset(*Idx);
+ if (isUndefVector(EI->getVectorOperand(), ExtractMask).all()) {
+ UndefVectorExtracts.push_back(I);
+ continue;
+ }
+ VectorOpToIdx[EI->getVectorOperand()].push_back(I);
+ }
+ // Sort the vector operands by the maximum number of uses in extractelements.
+ MapVector<unsigned, SmallVector<Value *>> VFToVector;
+ for (const auto &Data : VectorOpToIdx)
+ VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()]
+ .push_back(Data.first);
+ for (auto &Data : VFToVector) {
+ stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) {
+ return VectorOpToIdx.find(V1)->second.size() >
+ VectorOpToIdx.find(V2)->second.size();
+ });
+ }
+ // Find the best pair of the vectors with the same number of elements or a
+ // single vector.
+ const int UndefSz = UndefVectorExtracts.size();
+ unsigned SingleMax = 0;
+ Value *SingleVec = nullptr;
+ unsigned PairMax = 0;
+ std::pair<Value *, Value *> PairVec(nullptr, nullptr);
+ for (auto &Data : VFToVector) {
+ Value *V1 = Data.second.front();
+ if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) {
+ SingleMax = VectorOpToIdx[V1].size() + UndefSz;
+ SingleVec = V1;
+ }
+ Value *V2 = nullptr;
+ if (Data.second.size() > 1)
+ V2 = *std::next(Data.second.begin());
+ if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() +
+ UndefSz) {
+ PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz;
+ PairVec = std::make_pair(V1, V2);
+ }
+ }
+ if (SingleMax == 0 && PairMax == 0 && UndefSz == 0)
return std::nullopt;
+ // Check if better to perform a shuffle of 2 vectors or just of a single
+ // vector.
+ SmallVector<Value *> SavedVL(VL.begin(), VL.end());
+ SmallVector<Value *> GatheredExtracts(
+ VL.size(), PoisonValue::get(VL.front()->getType()));
+ if (SingleMax >= PairMax && SingleMax) {
+ for (int Idx : VectorOpToIdx[SingleVec])
+ std::swap(GatheredExtracts[Idx], VL[Idx]);
+ } else {
+ for (Value *V : {PairVec.first, PairVec.second})
+ for (int Idx : VectorOpToIdx[V])
+ std::swap(GatheredExtracts[Idx], VL[Idx]);
+ }
+ // Add extracts from undefs too.
+ for (int Idx : UndefVectorExtracts)
+ std::swap(GatheredExtracts[Idx], VL[Idx]);
+ // Check that gather of extractelements can be represented as just a
+ // shuffle of a single/two vectors the scalars are extracted from.
+ std::optional<TTI::ShuffleKind> Res =
+ isFixedVectorShuffle(GatheredExtracts, Mask);
+ if (!Res) {
+ // TODO: try to check other subsets if possible.
+ // Restore the original VL if attempt was not successful.
+ copy(SavedVL, VL.begin());
+ return std::nullopt;
+ }
+ // Restore unused scalars from mask, if some of the extractelements were not
+ // selected for shuffle.
+ for (int I = 0, E = GatheredExtracts.size(); I < E; ++I) {
+ if (Mask[I] == PoisonMaskElem && !isa<PoisonValue>(GatheredExtracts[I]) &&
+ isa<UndefValue>(GatheredExtracts[I])) {
+ std::swap(VL[I], GatheredExtracts[I]);
+ continue;
+ }
+ auto *EI = dyn_cast<ExtractElementInst>(VL[I]);
+ if (!EI || !isa<FixedVectorType>(EI->getVectorOperandType()) ||
+ !isa<ConstantInt, UndefValue>(EI->getIndexOperand()) ||
+ is_contained(UndefVectorExtracts, I))
+ continue;
+ }
+ return Res;
+}
+
+/// Tries to find extractelement instructions with constant indices from fixed
+/// vector type and gather such instructions into a bunch, which highly likely
+/// might be detected as a shuffle of 1 or 2 input vectors. If this attempt was
+/// successful, the matched scalars are replaced by poison values in \p VL for
+/// future analysis.
+SmallVector<std::optional<TTI::ShuffleKind>>
+BoUpSLP::tryToGatherExtractElements(SmallVectorImpl<Value *> &VL,
+ SmallVectorImpl<int> &Mask,
+ unsigned NumParts) const {
+ assert(NumParts > 0 && "NumParts expected be greater than or equal to 1.");
+ SmallVector<std::optional<TTI::ShuffleKind>> ShufflesRes(NumParts);
Mask.assign(VL.size(), PoisonMaskElem);
- assert(TE->UserTreeIndices.size() == 1 &&
- "Expected only single user of the gather node.");
+ unsigned SliceSize = VL.size() / NumParts;
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ // Scan list of gathered scalars for extractelements that can be represented
+ // as shuffles.
+ MutableArrayRef<Value *> SubVL =
+ MutableArrayRef(VL).slice(Part * SliceSize, SliceSize);
+ SmallVector<int> SubMask;
+ std::optional<TTI::ShuffleKind> Res =
+ tryToGatherSingleRegisterExtractElements(SubVL, SubMask);
+ ShufflesRes[Part] = Res;
+ copy(SubMask, std::next(Mask.begin(), Part * SliceSize));
+ }
+ if (none_of(ShufflesRes, [](const std::optional<TTI::ShuffleKind> &Res) {
+ return Res.has_value();
+ }))
+ ShufflesRes.clear();
+ return ShufflesRes;
+}
+
+std::optional<TargetTransformInfo::ShuffleKind>
+BoUpSLP::isGatherShuffledSingleRegisterEntry(
+ const TreeEntry *TE, ArrayRef<Value *> VL, MutableArrayRef<int> Mask,
+ SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part) {
+ Entries.clear();
// TODO: currently checking only for Scalars in the tree entry, need to count
// reused elements too for better cost estimation.
- Instruction &UserInst =
- getLastInstructionInBundle(TE->UserTreeIndices.front().UserTE);
- BasicBlock *ParentBB = nullptr;
+ const EdgeInfo &TEUseEI = TE->UserTreeIndices.front();
+ const Instruction *TEInsertPt = &getLastInstructionInBundle(TEUseEI.UserTE);
+ const BasicBlock *TEInsertBlock = nullptr;
// Main node of PHI entries keeps the correct order of operands/incoming
// blocks.
- if (auto *PHI =
- dyn_cast<PHINode>(TE->UserTreeIndices.front().UserTE->getMainOp())) {
- ParentBB = PHI->getIncomingBlock(TE->UserTreeIndices.front().EdgeIdx);
+ if (auto *PHI = dyn_cast<PHINode>(TEUseEI.UserTE->getMainOp())) {
+ TEInsertBlock = PHI->getIncomingBlock(TEUseEI.EdgeIdx);
+ TEInsertPt = TEInsertBlock->getTerminator();
} else {
- ParentBB = UserInst.getParent();
+ TEInsertBlock = TEInsertPt->getParent();
}
- auto *NodeUI = DT->getNode(ParentBB);
+ auto *NodeUI = DT->getNode(TEInsertBlock);
assert(NodeUI && "Should only process reachable instructions");
SmallPtrSet<Value *, 4> GatheredScalars(VL.begin(), VL.end());
- auto CheckOrdering = [&](Instruction *LastEI) {
- // Check if the user node of the TE comes after user node of EntryPtr,
- // otherwise EntryPtr depends on TE.
- // Gather nodes usually are not scheduled and inserted before their first
- // user node. So, instead of checking dependency between the gather nodes
- // themselves, we check the dependency between their user nodes.
- // If one user node comes before the second one, we cannot use the second
- // gather node as the source vector for the first gather node, because in
- // the list of instructions it will be emitted later.
- auto *EntryParent = LastEI->getParent();
- auto *NodeEUI = DT->getNode(EntryParent);
+ auto CheckOrdering = [&](const Instruction *InsertPt) {
+ // Argument InsertPt is an instruction where vector code for some other
+ // tree entry (one that shares one or more scalars with TE) is going to be
+ // generated. This lambda returns true if insertion point of vector code
+ // for the TE dominates that point (otherwise dependency is the other way
+ // around). The other node is not limited to be of a gather kind. Gather
+ // nodes are not scheduled and their vector code is inserted before their
+ // first user. If user is PHI, that is supposed to be at the end of a
+ // predecessor block. Otherwise it is the last instruction among scalars of
+ // the user node. So, instead of checking dependency between instructions
+ // themselves, we check dependency between their insertion points for vector
+ // code (since each scalar instruction ends up as a lane of a vector
+ // instruction).
+ const BasicBlock *InsertBlock = InsertPt->getParent();
+ auto *NodeEUI = DT->getNode(InsertBlock);
if (!NodeEUI)
return false;
assert((NodeUI == NodeEUI) ==
(NodeUI->getDFSNumIn() == NodeEUI->getDFSNumIn()) &&
"Different nodes should have different DFS numbers");
// Check the order of the gather nodes users.
- if (UserInst.getParent() != EntryParent &&
+ if (TEInsertPt->getParent() != InsertBlock &&
(DT->dominates(NodeUI, NodeEUI) || !DT->dominates(NodeEUI, NodeUI)))
return false;
- if (UserInst.getParent() == EntryParent && UserInst.comesBefore(LastEI))
+ if (TEInsertPt->getParent() == InsertBlock &&
+ TEInsertPt->comesBefore(InsertPt))
return false;
return true;
};
@@ -8743,43 +9389,42 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
[&](Value *V) { return GatheredScalars.contains(V); }) &&
"Must contain at least single gathered value.");
assert(TEPtr->UserTreeIndices.size() == 1 &&
- "Expected only single user of the gather node.");
- PHINode *EntryPHI =
- dyn_cast<PHINode>(TEPtr->UserTreeIndices.front().UserTE->getMainOp());
- Instruction *EntryUserInst =
- EntryPHI ? nullptr
- : &getLastInstructionInBundle(
- TEPtr->UserTreeIndices.front().UserTE);
- if (&UserInst == EntryUserInst) {
- assert(!EntryPHI && "Unexpected phi node entry.");
- // If 2 gathers are operands of the same entry, compare operands
- // indices, use the earlier one as the base.
- if (TE->UserTreeIndices.front().UserTE ==
- TEPtr->UserTreeIndices.front().UserTE &&
- TE->UserTreeIndices.front().EdgeIdx <
- TEPtr->UserTreeIndices.front().EdgeIdx)
+ "Expected only single user of a gather node.");
+ const EdgeInfo &UseEI = TEPtr->UserTreeIndices.front();
+
+ PHINode *UserPHI = dyn_cast<PHINode>(UseEI.UserTE->getMainOp());
+ const Instruction *InsertPt =
+ UserPHI ? UserPHI->getIncomingBlock(UseEI.EdgeIdx)->getTerminator()
+ : &getLastInstructionInBundle(UseEI.UserTE);
+ if (TEInsertPt == InsertPt) {
+ // If 2 gathers are operands of the same entry (regardless of whether
+ // user is PHI or else), compare operands indices, use the earlier one
+ // as the base.
+ if (TEUseEI.UserTE == UseEI.UserTE && TEUseEI.EdgeIdx < UseEI.EdgeIdx)
+ continue;
+ // If the user instruction is used for some reason in different
+ // vectorized nodes - make it depend on index.
+ if (TEUseEI.UserTE != UseEI.UserTE &&
+ TEUseEI.UserTE->Idx < UseEI.UserTE->Idx)
continue;
}
- // Check if the user node of the TE comes after user node of EntryPtr,
- // otherwise EntryPtr depends on TE.
- auto *EntryI =
- EntryPHI
- ? EntryPHI
- ->getIncomingBlock(TEPtr->UserTreeIndices.front().EdgeIdx)
- ->getTerminator()
- : EntryUserInst;
- if ((ParentBB != EntryI->getParent() ||
- TE->UserTreeIndices.front().EdgeIdx <
- TEPtr->UserTreeIndices.front().EdgeIdx ||
- TE->UserTreeIndices.front().UserTE !=
- TEPtr->UserTreeIndices.front().UserTE) &&
- !CheckOrdering(EntryI))
+
+ // Check if the user node of the TE comes after user node of TEPtr,
+ // otherwise TEPtr depends on TE.
+ if ((TEInsertBlock != InsertPt->getParent() ||
+ TEUseEI.EdgeIdx < UseEI.EdgeIdx || TEUseEI.UserTE != UseEI.UserTE) &&
+ !CheckOrdering(InsertPt))
continue;
VToTEs.insert(TEPtr);
}
if (const TreeEntry *VTE = getTreeEntry(V)) {
- Instruction &EntryUserInst = getLastInstructionInBundle(VTE);
- if (&EntryUserInst == &UserInst || !CheckOrdering(&EntryUserInst))
+ Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
+ if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
+ continue;
+ auto It = MinBWs.find(VTE);
+ // If vectorize node is demoted - do not match.
+ if (It != MinBWs.end() &&
+ It->second.first != DL->getTypeSizeInBits(V->getType()))
continue;
VToTEs.insert(VTE);
}
@@ -8823,8 +9468,10 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
}
}
- if (UsedTEs.empty())
+ if (UsedTEs.empty()) {
+ Entries.clear();
return std::nullopt;
+ }
unsigned VF = 0;
if (UsedTEs.size() == 1) {
@@ -8838,9 +9485,19 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
auto *It = find_if(FirstEntries, [=](const TreeEntry *EntryPtr) {
return EntryPtr->isSame(VL) || EntryPtr->isSame(TE->Scalars);
});
- if (It != FirstEntries.end() && (*It)->getVectorFactor() == VL.size()) {
+ if (It != FirstEntries.end() &&
+ ((*It)->getVectorFactor() == VL.size() ||
+ ((*It)->getVectorFactor() == TE->Scalars.size() &&
+ TE->ReuseShuffleIndices.size() == VL.size() &&
+ (*It)->isSame(TE->Scalars)))) {
Entries.push_back(*It);
- std::iota(Mask.begin(), Mask.end(), 0);
+ if ((*It)->getVectorFactor() == VL.size()) {
+ std::iota(std::next(Mask.begin(), Part * VL.size()),
+ std::next(Mask.begin(), (Part + 1) * VL.size()), 0);
+ } else {
+ SmallVector<int> CommonMask = TE->getCommonMask();
+ copy(CommonMask, Mask.begin());
+ }
// Clear undef scalars.
for (int I = 0, Sz = VL.size(); I < Sz; ++I)
if (isa<PoisonValue>(VL[I]))
@@ -8923,12 +9580,9 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
// by extractelements processing) or may form vector node in future.
auto MightBeIgnored = [=](Value *V) {
auto *I = dyn_cast<Instruction>(V);
- SmallVector<Value *> IgnoredVals;
- if (UserIgnoreList)
- IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end());
return I && !IsSplatOrUndefs && !ScalarToTreeEntry.count(I) &&
!isVectorLikeInstWithConstOps(I) &&
- !areAllUsersVectorized(I, IgnoredVals) && isSimple(I);
+ !areAllUsersVectorized(I, UserIgnoreList) && isSimple(I);
};
// Check that the neighbor instruction may form a full vector node with the
// current instruction V. It is possible, if they have same/alternate opcode
@@ -8980,7 +9634,10 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
TempEntries.push_back(Entries[I]);
}
Entries.swap(TempEntries);
- if (EntryLanes.size() == Entries.size() && !VL.equals(TE->Scalars)) {
+ if (EntryLanes.size() == Entries.size() &&
+ !VL.equals(ArrayRef(TE->Scalars)
+ .slice(Part * VL.size(),
+ std::min<int>(VL.size(), TE->Scalars.size())))) {
// We may have here 1 or 2 entries only. If the number of scalars is equal
// to the number of entries, no need to do the analysis, it is not very
// profitable. Since VL is not the same as TE->Scalars, it means we already
@@ -8993,9 +9650,10 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
// Pair.first is the offset to the vector, while Pair.second is the index of
// scalar in the list.
for (const std::pair<unsigned, int> &Pair : EntryLanes) {
- Mask[Pair.second] = Pair.first * VF +
- Entries[Pair.first]->findLaneForValue(VL[Pair.second]);
- IsIdentity &= Mask[Pair.second] == Pair.second;
+ unsigned Idx = Part * VL.size() + Pair.second;
+ Mask[Idx] = Pair.first * VF +
+ Entries[Pair.first]->findLaneForValue(VL[Pair.second]);
+ IsIdentity &= Mask[Idx] == Pair.second;
}
switch (Entries.size()) {
case 1:
@@ -9010,9 +9668,64 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
break;
}
Entries.clear();
+ // Clear the corresponding mask elements.
+ std::fill(std::next(Mask.begin(), Part * VL.size()),
+ std::next(Mask.begin(), (Part + 1) * VL.size()), PoisonMaskElem);
return std::nullopt;
}
+SmallVector<std::optional<TargetTransformInfo::ShuffleKind>>
+BoUpSLP::isGatherShuffledEntry(
+ const TreeEntry *TE, ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask,
+ SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries,
+ unsigned NumParts) {
+ assert(NumParts > 0 && NumParts < VL.size() &&
+ "Expected positive number of registers.");
+ Entries.clear();
+ // No need to check for the topmost gather node.
+ if (TE == VectorizableTree.front().get())
+ return {};
+ Mask.assign(VL.size(), PoisonMaskElem);
+ assert(TE->UserTreeIndices.size() == 1 &&
+ "Expected only single user of the gather node.");
+ assert(VL.size() % NumParts == 0 &&
+ "Number of scalars must be divisible by NumParts.");
+ unsigned SliceSize = VL.size() / NumParts;
+ SmallVector<std::optional<TTI::ShuffleKind>> Res;
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ArrayRef<Value *> SubVL = VL.slice(Part * SliceSize, SliceSize);
+ SmallVectorImpl<const TreeEntry *> &SubEntries = Entries.emplace_back();
+ std::optional<TTI::ShuffleKind> SubRes =
+ isGatherShuffledSingleRegisterEntry(TE, SubVL, Mask, SubEntries, Part);
+ if (!SubRes)
+ SubEntries.clear();
+ Res.push_back(SubRes);
+ if (SubEntries.size() == 1 && *SubRes == TTI::SK_PermuteSingleSrc &&
+ SubEntries.front()->getVectorFactor() == VL.size() &&
+ (SubEntries.front()->isSame(TE->Scalars) ||
+ SubEntries.front()->isSame(VL))) {
+ SmallVector<const TreeEntry *> LocalSubEntries;
+ LocalSubEntries.swap(SubEntries);
+ Entries.clear();
+ Res.clear();
+ std::iota(Mask.begin(), Mask.end(), 0);
+ // Clear undef scalars.
+ for (int I = 0, Sz = VL.size(); I < Sz; ++I)
+ if (isa<PoisonValue>(VL[I]))
+ Mask[I] = PoisonMaskElem;
+ Entries.emplace_back(1, LocalSubEntries.front());
+ Res.push_back(TargetTransformInfo::SK_PermuteSingleSrc);
+ return Res;
+ }
+ }
+ if (all_of(Res,
+ [](const std::optional<TTI::ShuffleKind> &SK) { return !SK; })) {
+ Entries.clear();
+ return {};
+ }
+ return Res;
+}
+
InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL,
bool ForPoisonSrc) const {
// Find the type of the operands in VL.
@@ -9224,18 +9937,20 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) {
auto *Front = E->getMainOp();
Instruction *LastInst = &getLastInstructionInBundle(E);
assert(LastInst && "Failed to find last instruction in bundle");
+ BasicBlock::iterator LastInstIt = LastInst->getIterator();
// If the instruction is PHI, set the insert point after all the PHIs.
bool IsPHI = isa<PHINode>(LastInst);
if (IsPHI)
- LastInst = LastInst->getParent()->getFirstNonPHI();
+ LastInstIt = LastInst->getParent()->getFirstNonPHIIt();
if (IsPHI || (E->State != TreeEntry::NeedToGather &&
doesNotNeedToSchedule(E->Scalars))) {
- Builder.SetInsertPoint(LastInst);
+ Builder.SetInsertPoint(LastInst->getParent(), LastInstIt);
} else {
// Set the insertion point after the last instruction in the bundle. Set the
// debug location to Front.
- Builder.SetInsertPoint(LastInst->getParent(),
- std::next(LastInst->getIterator()));
+ Builder.SetInsertPoint(
+ LastInst->getParent(),
+ LastInst->getNextNonDebugInstruction()->getIterator());
}
Builder.SetCurrentDebugLocation(Front->getDebugLoc());
}
@@ -9271,10 +9986,12 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) {
GatherShuffleExtractSeq.insert(InsElt);
CSEBlocks.insert(InsElt->getParent());
// Add to our 'need-to-extract' list.
- if (TreeEntry *Entry = getTreeEntry(V)) {
- // Find which lane we need to extract.
- unsigned FoundLane = Entry->findLaneForValue(V);
- ExternalUses.emplace_back(V, InsElt, FoundLane);
+ if (isa<Instruction>(V)) {
+ if (TreeEntry *Entry = getTreeEntry(V)) {
+ // Find which lane we need to extract.
+ unsigned FoundLane = Entry->findLaneForValue(V);
+ ExternalUses.emplace_back(V, InsElt, FoundLane);
+ }
}
return Vec;
};
@@ -9367,12 +10084,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
/// Holds all of the instructions that we gathered.
SetVector<Instruction *> &GatherShuffleExtractSeq;
/// A list of blocks that we are going to CSE.
- SetVector<BasicBlock *> &CSEBlocks;
+ DenseSet<BasicBlock *> &CSEBlocks;
public:
ShuffleIRBuilder(IRBuilderBase &Builder,
SetVector<Instruction *> &GatherShuffleExtractSeq,
- SetVector<BasicBlock *> &CSEBlocks)
+ DenseSet<BasicBlock *> &CSEBlocks)
: Builder(Builder), GatherShuffleExtractSeq(GatherShuffleExtractSeq),
CSEBlocks(CSEBlocks) {}
~ShuffleIRBuilder() = default;
@@ -9392,7 +10109,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
return V1;
unsigned VF = Mask.size();
unsigned LocalVF = cast<FixedVectorType>(V1->getType())->getNumElements();
- if (VF == LocalVF && ShuffleVectorInst::isIdentityMask(Mask))
+ if (VF == LocalVF && ShuffleVectorInst::isIdentityMask(Mask, VF))
return V1;
Value *Vec = Builder.CreateShuffleVector(V1, Mask);
if (auto *I = dyn_cast<Instruction>(Vec)) {
@@ -9455,7 +10172,11 @@ public:
: Builder(Builder), R(R) {}
/// Adjusts extractelements after reusing them.
- Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask) {
+ Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
+ ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds,
+ unsigned NumParts, bool &UseVecBaseAsInput) {
+ UseVecBaseAsInput = false;
+ SmallPtrSet<Value *, 4> UniqueBases;
Value *VecBase = nullptr;
for (int I = 0, Sz = Mask.size(); I < Sz; ++I) {
int Idx = Mask[I];
@@ -9463,6 +10184,10 @@ public:
continue;
auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
VecBase = EI->getVectorOperand();
+ if (const TreeEntry *TE = R.getTreeEntry(VecBase))
+ VecBase = TE->VectorizedValue;
+ assert(VecBase && "Expected vectorized value.");
+ UniqueBases.insert(VecBase);
// If the only one use is vectorized - can delete the extractelement
// itself.
if (!EI->hasOneUse() || any_of(EI->users(), [&](User *U) {
@@ -9471,14 +10196,97 @@ public:
continue;
R.eraseInstruction(EI);
}
- return VecBase;
+ if (NumParts == 1 || UniqueBases.size() == 1)
+ return VecBase;
+ UseVecBaseAsInput = true;
+ auto TransformToIdentity = [](MutableArrayRef<int> Mask) {
+ for (auto [I, Idx] : enumerate(Mask))
+ if (Idx != PoisonMaskElem)
+ Idx = I;
+ };
+ // Perform multi-register vector shuffle, joining them into a single virtual
+ // long vector.
+ // Need to shuffle each part independently and then insert all this parts
+ // into a long virtual vector register, forming the original vector.
+ Value *Vec = nullptr;
+ SmallVector<int> VecMask(Mask.size(), PoisonMaskElem);
+ unsigned SliceSize = E->Scalars.size() / NumParts;
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ArrayRef<Value *> VL =
+ ArrayRef(E->Scalars).slice(Part * SliceSize, SliceSize);
+ MutableArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize);
+ constexpr int MaxBases = 2;
+ SmallVector<Value *, MaxBases> Bases(MaxBases);
+#ifndef NDEBUG
+ int PrevSize = 0;
+#endif // NDEBUG
+ for (const auto [I, V]: enumerate(VL)) {
+ if (SubMask[I] == PoisonMaskElem)
+ continue;
+ Value *VecOp = cast<ExtractElementInst>(V)->getVectorOperand();
+ if (const TreeEntry *TE = R.getTreeEntry(VecOp))
+ VecOp = TE->VectorizedValue;
+ assert(VecOp && "Expected vectorized value.");
+ const int Size =
+ cast<FixedVectorType>(VecOp->getType())->getNumElements();
+#ifndef NDEBUG
+ assert((PrevSize == Size || PrevSize == 0) &&
+ "Expected vectors of the same size.");
+ PrevSize = Size;
+#endif // NDEBUG
+ Bases[SubMask[I] < Size ? 0 : 1] = VecOp;
+ }
+ if (!Bases.front())
+ continue;
+ Value *SubVec;
+ if (Bases.back()) {
+ SubVec = createShuffle(Bases.front(), Bases.back(), SubMask);
+ TransformToIdentity(SubMask);
+ } else {
+ SubVec = Bases.front();
+ }
+ if (!Vec) {
+ Vec = SubVec;
+ assert((Part == 0 || all_of(seq<unsigned>(0, Part),
+ [&](unsigned P) {
+ ArrayRef<int> SubMask =
+ Mask.slice(P * SliceSize, SliceSize);
+ return all_of(SubMask, [](int Idx) {
+ return Idx == PoisonMaskElem;
+ });
+ })) &&
+ "Expected first part or all previous parts masked.");
+ copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
+ } else {
+ unsigned VF = cast<FixedVectorType>(Vec->getType())->getNumElements();
+ if (Vec->getType() != SubVec->getType()) {
+ unsigned SubVecVF =
+ cast<FixedVectorType>(SubVec->getType())->getNumElements();
+ VF = std::max(VF, SubVecVF);
+ }
+ // Adjust SubMask.
+ for (auto [I, Idx] : enumerate(SubMask))
+ if (Idx != PoisonMaskElem)
+ Idx += VF;
+ copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
+ Vec = createShuffle(Vec, SubVec, VecMask);
+ TransformToIdentity(VecMask);
+ }
+ }
+ copy(VecMask, Mask.begin());
+ return Vec;
}
/// Checks if the specified entry \p E needs to be delayed because of its
/// dependency nodes.
- Value *needToDelay(const TreeEntry *E, ArrayRef<const TreeEntry *> Deps) {
+ std::optional<Value *>
+ needToDelay(const TreeEntry *E,
+ ArrayRef<SmallVector<const TreeEntry *>> Deps) const {
// No need to delay emission if all deps are ready.
- if (all_of(Deps, [](const TreeEntry *TE) { return TE->VectorizedValue; }))
- return nullptr;
+ if (all_of(Deps, [](ArrayRef<const TreeEntry *> TEs) {
+ return all_of(
+ TEs, [](const TreeEntry *TE) { return TE->VectorizedValue; });
+ }))
+ return std::nullopt;
// Postpone gather emission, will be emitted after the end of the
// process to keep correct order.
auto *VecTy = FixedVectorType::get(E->Scalars.front()->getType(),
@@ -9487,6 +10295,16 @@ public:
VecTy, PoisonValue::get(PointerType::getUnqual(VecTy->getContext())),
MaybeAlign());
}
+ /// Adds 2 input vectors (in form of tree entries) and the mask for their
+ /// shuffling.
+ void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
+ add(E1.VectorizedValue, E2.VectorizedValue, Mask);
+ }
+ /// Adds single input vector (in form of tree entry) and the mask for its
+ /// shuffling.
+ void add(const TreeEntry &E1, ArrayRef<int> Mask) {
+ add(E1.VectorizedValue, Mask);
+ }
/// Adds 2 input vectors and the mask for their shuffling.
void add(Value *V1, Value *V2, ArrayRef<int> Mask) {
assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors.");
@@ -9516,7 +10334,7 @@ public:
InVectors.push_back(V1);
}
/// Adds another one input vector and the mask for the shuffling.
- void add(Value *V1, ArrayRef<int> Mask) {
+ void add(Value *V1, ArrayRef<int> Mask, bool = false) {
if (InVectors.empty()) {
if (!isa<FixedVectorType>(V1->getType())) {
V1 = createShuffle(V1, nullptr, CommonMask);
@@ -9578,7 +10396,8 @@ public:
inversePermutation(Order, NewMask);
add(V1, NewMask);
}
- Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) {
+ Value *gather(ArrayRef<Value *> VL, unsigned MaskVF = 0,
+ Value *Root = nullptr) {
return R.gather(VL, Root);
}
Value *createFreeze(Value *V) { return Builder.CreateFreeze(V); }
@@ -9639,8 +10458,14 @@ public:
}
};
-Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) {
- ArrayRef<Value *> VL = E->getOperand(NodeIdx);
+Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx,
+ bool PostponedPHIs) {
+ ValueList &VL = E->getOperand(NodeIdx);
+ if (E->State == TreeEntry::PossibleStridedVectorize &&
+ !E->ReorderIndices.empty()) {
+ SmallVector<int> Mask(E->ReorderIndices.begin(), E->ReorderIndices.end());
+ reorderScalars(VL, Mask);
+ }
const unsigned VF = VL.size();
InstructionsState S = getSameOpcode(VL, *TLI);
// Special processing for GEPs bundle, which may include non-gep values.
@@ -9651,23 +10476,39 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) {
S = getSameOpcode(*It, *TLI);
}
if (S.getOpcode()) {
- if (TreeEntry *VE = getTreeEntry(S.OpValue);
- VE && VE->isSame(VL) &&
- (any_of(VE->UserTreeIndices,
- [E, NodeIdx](const EdgeInfo &EI) {
- return EI.UserTE == E && EI.EdgeIdx == NodeIdx;
- }) ||
- any_of(VectorizableTree,
- [E, NodeIdx, VE](const std::unique_ptr<TreeEntry> &TE) {
- return TE->isOperandGatherNode({E, NodeIdx}) &&
- VE->isSame(TE->Scalars);
- }))) {
+ auto CheckSameVE = [&](const TreeEntry *VE) {
+ return VE->isSame(VL) &&
+ (any_of(VE->UserTreeIndices,
+ [E, NodeIdx](const EdgeInfo &EI) {
+ return EI.UserTE == E && EI.EdgeIdx == NodeIdx;
+ }) ||
+ any_of(VectorizableTree,
+ [E, NodeIdx, VE](const std::unique_ptr<TreeEntry> &TE) {
+ return TE->isOperandGatherNode({E, NodeIdx}) &&
+ VE->isSame(TE->Scalars);
+ }));
+ };
+ TreeEntry *VE = getTreeEntry(S.OpValue);
+ bool IsSameVE = VE && CheckSameVE(VE);
+ if (!IsSameVE) {
+ auto It = MultiNodeScalars.find(S.OpValue);
+ if (It != MultiNodeScalars.end()) {
+ auto *I = find_if(It->getSecond(), [&](const TreeEntry *TE) {
+ return TE != VE && CheckSameVE(TE);
+ });
+ if (I != It->getSecond().end()) {
+ VE = *I;
+ IsSameVE = true;
+ }
+ }
+ }
+ if (IsSameVE) {
auto FinalShuffle = [&](Value *V, ArrayRef<int> Mask) {
ShuffleInstructionBuilder ShuffleBuilder(Builder, *this);
ShuffleBuilder.add(V, Mask);
return ShuffleBuilder.finalize(std::nullopt);
};
- Value *V = vectorizeTree(VE);
+ Value *V = vectorizeTree(VE, PostponedPHIs);
if (VF != cast<FixedVectorType>(V->getType())->getNumElements()) {
if (!VE->ReuseShuffleIndices.empty()) {
// Reshuffle to get only unique values.
@@ -9740,14 +10581,7 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) {
assert(I->get()->UserTreeIndices.size() == 1 &&
"Expected only single user for the gather node.");
assert(I->get()->isSame(VL) && "Expected same list of scalars.");
- IRBuilder<>::InsertPointGuard Guard(Builder);
- if (E->getOpcode() != Instruction::InsertElement &&
- E->getOpcode() != Instruction::PHI) {
- Instruction *LastInst = &getLastInstructionInBundle(E);
- assert(LastInst && "Failed to find last instruction in bundle");
- Builder.SetInsertPoint(LastInst);
- }
- return vectorizeTree(I->get());
+ return vectorizeTree(I->get(), PostponedPHIs);
}
template <typename BVTy, typename ResTy, typename... Args>
@@ -9765,7 +10599,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
inversePermutation(E->ReorderIndices, ReorderMask);
if (!ReorderMask.empty())
reorderScalars(GatheredScalars, ReorderMask);
- auto FindReusedSplat = [&](SmallVectorImpl<int> &Mask) {
+ auto FindReusedSplat = [&](MutableArrayRef<int> Mask, unsigned InputVF) {
if (!isSplat(E->Scalars) || none_of(E->Scalars, [](Value *V) {
return isa<UndefValue>(V) && !isa<PoisonValue>(V);
}))
@@ -9782,70 +10616,102 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
});
if (It == VectorizableTree.end())
return false;
- unsigned I =
- *find_if_not(Mask, [](int Idx) { return Idx == PoisonMaskElem; });
- int Sz = Mask.size();
- if (all_of(Mask, [Sz](int Idx) { return Idx < 2 * Sz; }) &&
- ShuffleVectorInst::isIdentityMask(Mask))
+ int Idx;
+ if ((Mask.size() < InputVF &&
+ ShuffleVectorInst::isExtractSubvectorMask(Mask, InputVF, Idx) &&
+ Idx == 0) ||
+ (Mask.size() == InputVF &&
+ ShuffleVectorInst::isIdentityMask(Mask, Mask.size()))) {
std::iota(Mask.begin(), Mask.end(), 0);
- else
+ } else {
+ unsigned I =
+ *find_if_not(Mask, [](int Idx) { return Idx == PoisonMaskElem; });
std::fill(Mask.begin(), Mask.end(), I);
+ }
return true;
};
BVTy ShuffleBuilder(Params...);
ResTy Res = ResTy();
SmallVector<int> Mask;
- SmallVector<int> ExtractMask;
- std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle;
- std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle;
- SmallVector<const TreeEntry *> Entries;
+ SmallVector<int> ExtractMask(GatheredScalars.size(), PoisonMaskElem);
+ SmallVector<std::optional<TTI::ShuffleKind>> ExtractShuffles;
+ Value *ExtractVecBase = nullptr;
+ bool UseVecBaseAsInput = false;
+ SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> GatherShuffles;
+ SmallVector<SmallVector<const TreeEntry *>> Entries;
Type *ScalarTy = GatheredScalars.front()->getType();
+ auto *VecTy = FixedVectorType::get(ScalarTy, GatheredScalars.size());
+ unsigned NumParts = TTI->getNumberOfParts(VecTy);
+ if (NumParts == 0 || NumParts >= GatheredScalars.size())
+ NumParts = 1;
if (!all_of(GatheredScalars, UndefValue::classof)) {
// Check for gathered extracts.
- ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask);
- SmallVector<Value *> IgnoredVals;
- if (UserIgnoreList)
- IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end());
bool Resized = false;
- if (Value *VecBase = ShuffleBuilder.adjustExtracts(E, ExtractMask))
- if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType()))
- if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) {
- Resized = true;
- GatheredScalars.append(VF - GatheredScalars.size(),
- PoisonValue::get(ScalarTy));
- }
+ ExtractShuffles =
+ tryToGatherExtractElements(GatheredScalars, ExtractMask, NumParts);
+ if (!ExtractShuffles.empty()) {
+ SmallVector<const TreeEntry *> ExtractEntries;
+ for (auto [Idx, I] : enumerate(ExtractMask)) {
+ if (I == PoisonMaskElem)
+ continue;
+ if (const auto *TE = getTreeEntry(
+ cast<ExtractElementInst>(E->Scalars[Idx])->getVectorOperand()))
+ ExtractEntries.push_back(TE);
+ }
+ if (std::optional<ResTy> Delayed =
+ ShuffleBuilder.needToDelay(E, ExtractEntries)) {
+ // Delay emission of gathers which are not ready yet.
+ PostponedGathers.insert(E);
+ // Postpone gather emission, will be emitted after the end of the
+ // process to keep correct order.
+ return *Delayed;
+ }
+ if (Value *VecBase = ShuffleBuilder.adjustExtracts(
+ E, ExtractMask, ExtractShuffles, NumParts, UseVecBaseAsInput)) {
+ ExtractVecBase = VecBase;
+ if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType()))
+ if (VF == VecBaseTy->getNumElements() &&
+ GatheredScalars.size() != VF) {
+ Resized = true;
+ GatheredScalars.append(VF - GatheredScalars.size(),
+ PoisonValue::get(ScalarTy));
+ }
+ }
+ }
// Gather extracts after we check for full matched gathers only.
- if (ExtractShuffle || E->getOpcode() != Instruction::Load ||
+ if (!ExtractShuffles.empty() || E->getOpcode() != Instruction::Load ||
E->isAltShuffle() ||
all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) ||
isSplat(E->Scalars) ||
(E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) {
- GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries);
+ GatherShuffles =
+ isGatherShuffledEntry(E, GatheredScalars, Mask, Entries, NumParts);
}
- if (GatherShuffle) {
- if (Value *Delayed = ShuffleBuilder.needToDelay(E, Entries)) {
+ if (!GatherShuffles.empty()) {
+ if (std::optional<ResTy> Delayed =
+ ShuffleBuilder.needToDelay(E, Entries)) {
// Delay emission of gathers which are not ready yet.
PostponedGathers.insert(E);
// Postpone gather emission, will be emitted after the end of the
// process to keep correct order.
- return Delayed;
+ return *Delayed;
}
- assert((Entries.size() == 1 || Entries.size() == 2) &&
- "Expected shuffle of 1 or 2 entries.");
- if (*GatherShuffle == TTI::SK_PermuteSingleSrc &&
- Entries.front()->isSame(E->Scalars)) {
+ if (GatherShuffles.size() == 1 &&
+ *GatherShuffles.front() == TTI::SK_PermuteSingleSrc &&
+ Entries.front().front()->isSame(E->Scalars)) {
// Perfect match in the graph, will reuse the previously vectorized
// node. Cost is 0.
LLVM_DEBUG(
dbgs()
- << "SLP: perfect diamond match for gather bundle that starts with "
- << *E->Scalars.front() << ".\n");
+ << "SLP: perfect diamond match for gather bundle "
+ << shortBundleName(E->Scalars) << ".\n");
// Restore the mask for previous partially matched values.
- if (Entries.front()->ReorderIndices.empty() &&
- ((Entries.front()->ReuseShuffleIndices.empty() &&
- E->Scalars.size() == Entries.front()->Scalars.size()) ||
- (E->Scalars.size() ==
- Entries.front()->ReuseShuffleIndices.size()))) {
+ Mask.resize(E->Scalars.size());
+ const TreeEntry *FrontTE = Entries.front().front();
+ if (FrontTE->ReorderIndices.empty() &&
+ ((FrontTE->ReuseShuffleIndices.empty() &&
+ E->Scalars.size() == FrontTE->Scalars.size()) ||
+ (E->Scalars.size() == FrontTE->ReuseShuffleIndices.size()))) {
std::iota(Mask.begin(), Mask.end(), 0);
} else {
for (auto [I, V] : enumerate(E->Scalars)) {
@@ -9853,17 +10719,20 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
Mask[I] = PoisonMaskElem;
continue;
}
- Mask[I] = Entries.front()->findLaneForValue(V);
+ Mask[I] = FrontTE->findLaneForValue(V);
}
}
- ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask);
+ ShuffleBuilder.add(*FrontTE, Mask);
Res = ShuffleBuilder.finalize(E->getCommonMask());
return Res;
}
if (!Resized) {
- unsigned VF1 = Entries.front()->getVectorFactor();
- unsigned VF2 = Entries.back()->getVectorFactor();
- if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF)
+ if (GatheredScalars.size() != VF &&
+ any_of(Entries, [&](ArrayRef<const TreeEntry *> TEs) {
+ return any_of(TEs, [&](const TreeEntry *TE) {
+ return TE->getVectorFactor() == VF;
+ });
+ }))
GatheredScalars.append(VF - GatheredScalars.size(),
PoisonValue::get(ScalarTy));
}
@@ -9943,78 +10812,108 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
if (It != Scalars.end()) {
// Replace undefs by the non-poisoned scalars and emit broadcast.
int Pos = std::distance(Scalars.begin(), It);
- for_each(UndefPos, [&](int I) {
+ for (int I : UndefPos) {
// Set the undef position to the non-poisoned scalar.
ReuseMask[I] = Pos;
// Replace the undef by the poison, in the mask it is replaced by
// non-poisoned scalar already.
if (I != Pos)
Scalars[I] = PoisonValue::get(ScalarTy);
- });
+ }
} else {
// Replace undefs by the poisons, emit broadcast and then emit
// freeze.
- for_each(UndefPos, [&](int I) {
+ for (int I : UndefPos) {
ReuseMask[I] = PoisonMaskElem;
if (isa<UndefValue>(Scalars[I]))
Scalars[I] = PoisonValue::get(ScalarTy);
- });
+ }
NeedFreeze = true;
}
}
};
- if (ExtractShuffle || GatherShuffle) {
+ if (!ExtractShuffles.empty() || !GatherShuffles.empty()) {
bool IsNonPoisoned = true;
- bool IsUsedInExpr = false;
+ bool IsUsedInExpr = true;
Value *Vec1 = nullptr;
- if (ExtractShuffle) {
+ if (!ExtractShuffles.empty()) {
// Gather of extractelements can be represented as just a shuffle of
// a single/two vectors the scalars are extracted from.
// Find input vectors.
Value *Vec2 = nullptr;
for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) {
- if (ExtractMask[I] == PoisonMaskElem ||
- (!Mask.empty() && Mask[I] != PoisonMaskElem)) {
+ if (!Mask.empty() && Mask[I] != PoisonMaskElem)
ExtractMask[I] = PoisonMaskElem;
- continue;
- }
- if (isa<UndefValue>(E->Scalars[I]))
- continue;
- auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
- if (!Vec1) {
- Vec1 = EI->getVectorOperand();
- } else if (Vec1 != EI->getVectorOperand()) {
- assert((!Vec2 || Vec2 == EI->getVectorOperand()) &&
- "Expected only 1 or 2 vectors shuffle.");
- Vec2 = EI->getVectorOperand();
+ }
+ if (UseVecBaseAsInput) {
+ Vec1 = ExtractVecBase;
+ } else {
+ for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) {
+ if (ExtractMask[I] == PoisonMaskElem)
+ continue;
+ if (isa<UndefValue>(E->Scalars[I]))
+ continue;
+ auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
+ Value *VecOp = EI->getVectorOperand();
+ if (const auto *TE = getTreeEntry(VecOp))
+ if (TE->VectorizedValue)
+ VecOp = TE->VectorizedValue;
+ if (!Vec1) {
+ Vec1 = VecOp;
+ } else if (Vec1 != EI->getVectorOperand()) {
+ assert((!Vec2 || Vec2 == EI->getVectorOperand()) &&
+ "Expected only 1 or 2 vectors shuffle.");
+ Vec2 = VecOp;
+ }
}
}
if (Vec2) {
+ IsUsedInExpr = false;
IsNonPoisoned &=
isGuaranteedNotToBePoison(Vec1) && isGuaranteedNotToBePoison(Vec2);
ShuffleBuilder.add(Vec1, Vec2, ExtractMask);
} else if (Vec1) {
- IsUsedInExpr = FindReusedSplat(ExtractMask);
- ShuffleBuilder.add(Vec1, ExtractMask);
+ IsUsedInExpr &= FindReusedSplat(
+ ExtractMask,
+ cast<FixedVectorType>(Vec1->getType())->getNumElements());
+ ShuffleBuilder.add(Vec1, ExtractMask, /*ForExtracts=*/true);
IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1);
} else {
+ IsUsedInExpr = false;
ShuffleBuilder.add(PoisonValue::get(FixedVectorType::get(
ScalarTy, GatheredScalars.size())),
- ExtractMask);
+ ExtractMask, /*ForExtracts=*/true);
}
}
- if (GatherShuffle) {
- if (Entries.size() == 1) {
- IsUsedInExpr = FindReusedSplat(Mask);
- ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask);
- IsNonPoisoned &=
- isGuaranteedNotToBePoison(Entries.front()->VectorizedValue);
- } else {
- ShuffleBuilder.add(Entries.front()->VectorizedValue,
- Entries.back()->VectorizedValue, Mask);
- IsNonPoisoned &=
- isGuaranteedNotToBePoison(Entries.front()->VectorizedValue) &&
- isGuaranteedNotToBePoison(Entries.back()->VectorizedValue);
+ if (!GatherShuffles.empty()) {
+ unsigned SliceSize = E->Scalars.size() / NumParts;
+ SmallVector<int> VecMask(Mask.size(), PoisonMaskElem);
+ for (const auto [I, TEs] : enumerate(Entries)) {
+ if (TEs.empty()) {
+ assert(!GatherShuffles[I] &&
+ "No shuffles with empty entries list expected.");
+ continue;
+ }
+ assert((TEs.size() == 1 || TEs.size() == 2) &&
+ "Expected shuffle of 1 or 2 entries.");
+ auto SubMask = ArrayRef(Mask).slice(I * SliceSize, SliceSize);
+ VecMask.assign(VecMask.size(), PoisonMaskElem);
+ copy(SubMask, std::next(VecMask.begin(), I * SliceSize));
+ if (TEs.size() == 1) {
+ IsUsedInExpr &=
+ FindReusedSplat(VecMask, TEs.front()->getVectorFactor());
+ ShuffleBuilder.add(*TEs.front(), VecMask);
+ if (TEs.front()->VectorizedValue)
+ IsNonPoisoned &=
+ isGuaranteedNotToBePoison(TEs.front()->VectorizedValue);
+ } else {
+ IsUsedInExpr = false;
+ ShuffleBuilder.add(*TEs.front(), *TEs.back(), VecMask);
+ if (TEs.front()->VectorizedValue && TEs.back()->VectorizedValue)
+ IsNonPoisoned &=
+ isGuaranteedNotToBePoison(TEs.front()->VectorizedValue) &&
+ isGuaranteedNotToBePoison(TEs.back()->VectorizedValue);
+ }
}
}
// Try to figure out best way to combine values: build a shuffle and insert
@@ -10025,16 +10924,24 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
int MSz = Mask.size();
// Try to build constant vector and shuffle with it only if currently we
// have a single permutation and more than 1 scalar constants.
- bool IsSingleShuffle = !ExtractShuffle || !GatherShuffle;
+ bool IsSingleShuffle = ExtractShuffles.empty() || GatherShuffles.empty();
bool IsIdentityShuffle =
- (ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc) ==
- TTI::SK_PermuteSingleSrc &&
+ ((UseVecBaseAsInput ||
+ all_of(ExtractShuffles,
+ [](const std::optional<TTI::ShuffleKind> &SK) {
+ return SK.value_or(TTI::SK_PermuteTwoSrc) ==
+ TTI::SK_PermuteSingleSrc;
+ })) &&
none_of(ExtractMask, [&](int I) { return I >= EMSz; }) &&
- ShuffleVectorInst::isIdentityMask(ExtractMask)) ||
- (GatherShuffle.value_or(TTI::SK_PermuteTwoSrc) ==
- TTI::SK_PermuteSingleSrc &&
+ ShuffleVectorInst::isIdentityMask(ExtractMask, EMSz)) ||
+ (!GatherShuffles.empty() &&
+ all_of(GatherShuffles,
+ [](const std::optional<TTI::ShuffleKind> &SK) {
+ return SK.value_or(TTI::SK_PermuteTwoSrc) ==
+ TTI::SK_PermuteSingleSrc;
+ }) &&
none_of(Mask, [&](int I) { return I >= MSz; }) &&
- ShuffleVectorInst::isIdentityMask(Mask));
+ ShuffleVectorInst::isIdentityMask(Mask, MSz));
bool EnoughConstsForShuffle =
IsSingleShuffle &&
(none_of(GatheredScalars,
@@ -10064,7 +10971,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
if (!all_of(GatheredScalars, PoisonValue::classof)) {
SmallVector<int> BVMask(GatheredScalars.size(), PoisonMaskElem);
TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true);
- Value *BV = ShuffleBuilder.gather(GatheredScalars);
+ Value *BV = ShuffleBuilder.gather(GatheredScalars, BVMask.size());
ShuffleBuilder.add(BV, BVMask);
}
if (all_of(NonConstants, [=](Value *V) {
@@ -10078,13 +10985,13 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
E->ReuseShuffleIndices, E->Scalars.size(),
[&](Value *&Vec, SmallVectorImpl<int> &Mask) {
TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false);
- Vec = ShuffleBuilder.gather(NonConstants, Vec);
+ Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec);
});
} else if (!allConstant(GatheredScalars)) {
// Gather unique scalars and all constants.
SmallVector<int> ReuseMask(GatheredScalars.size(), PoisonMaskElem);
TryPackScalars(GatheredScalars, ReuseMask, /*IsRootPoison=*/true);
- Value *BV = ShuffleBuilder.gather(GatheredScalars);
+ Value *BV = ShuffleBuilder.gather(GatheredScalars, ReuseMask.size());
ShuffleBuilder.add(BV, ReuseMask);
Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices);
} else {
@@ -10109,29 +11016,37 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
*this);
}
-Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
+Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
IRBuilder<>::InsertPointGuard Guard(Builder);
- if (E->VectorizedValue) {
+ if (E->VectorizedValue &&
+ (E->State != TreeEntry::Vectorize || E->getOpcode() != Instruction::PHI ||
+ E->isAltShuffle())) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *E->Scalars[0] << ".\n");
return E->VectorizedValue;
}
if (E->State == TreeEntry::NeedToGather) {
- if (E->getMainOp() && E->Idx == 0)
+ // Set insert point for non-reduction initial nodes.
+ if (E->getMainOp() && E->Idx == 0 && !UserIgnoreList)
setInsertPointAfterBundle(E);
Value *Vec = createBuildVector(E);
E->VectorizedValue = Vec;
return Vec;
}
- auto FinalShuffle = [&](Value *V, const TreeEntry *E) {
+ auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy,
+ bool IsSigned) {
+ if (V->getType() != VecTy)
+ V = Builder.CreateIntCast(V, VecTy, IsSigned);
ShuffleInstructionBuilder ShuffleBuilder(Builder, *this);
if (E->getOpcode() == Instruction::Store) {
ArrayRef<int> Mask =
ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()),
E->ReorderIndices.size());
ShuffleBuilder.add(V, Mask);
+ } else if (E->State == TreeEntry::PossibleStridedVectorize) {
+ ShuffleBuilder.addOrdered(V, std::nullopt);
} else {
ShuffleBuilder.addOrdered(V, E->ReorderIndices);
}
@@ -10139,7 +11054,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
};
assert((E->State == TreeEntry::Vectorize ||
- E->State == TreeEntry::ScatterVectorize) &&
+ E->State == TreeEntry::ScatterVectorize ||
+ E->State == TreeEntry::PossibleStridedVectorize) &&
"Unhandled state");
unsigned ShuffleOrOp =
E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode();
@@ -10149,6 +11065,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
ScalarTy = Store->getValueOperand()->getType();
else if (auto *IE = dyn_cast<InsertElementInst>(VL0))
ScalarTy = IE->getOperand(1)->getType();
+ bool IsSigned = false;
+ auto It = MinBWs.find(E);
+ if (It != MinBWs.end()) {
+ ScalarTy = IntegerType::get(F->getContext(), It->second.first);
+ IsSigned = It->second.second;
+ }
auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size());
switch (ShuffleOrOp) {
case Instruction::PHI: {
@@ -10156,32 +11078,45 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
E != VectorizableTree.front().get() ||
!E->UserTreeIndices.empty()) &&
"PHI reordering is free.");
+ if (PostponedPHIs && E->VectorizedValue)
+ return E->VectorizedValue;
auto *PH = cast<PHINode>(VL0);
- Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI());
+ Builder.SetInsertPoint(PH->getParent(),
+ PH->getParent()->getFirstNonPHIIt());
Builder.SetCurrentDebugLocation(PH->getDebugLoc());
- PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues());
- Value *V = NewPhi;
+ if (PostponedPHIs || !E->VectorizedValue) {
+ PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues());
+ E->PHI = NewPhi;
+ Value *V = NewPhi;
- // Adjust insertion point once all PHI's have been generated.
- Builder.SetInsertPoint(&*PH->getParent()->getFirstInsertionPt());
- Builder.SetCurrentDebugLocation(PH->getDebugLoc());
+ // Adjust insertion point once all PHI's have been generated.
+ Builder.SetInsertPoint(PH->getParent(),
+ PH->getParent()->getFirstInsertionPt());
+ Builder.SetCurrentDebugLocation(PH->getDebugLoc());
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
- E->VectorizedValue = V;
+ E->VectorizedValue = V;
+ if (PostponedPHIs)
+ return V;
+ }
+ PHINode *NewPhi = cast<PHINode>(E->PHI);
+ // If phi node is fully emitted - exit.
+ if (NewPhi->getNumIncomingValues() != 0)
+ return NewPhi;
// PHINodes may have multiple entries from the same block. We want to
// visit every block once.
SmallPtrSet<BasicBlock *, 4> VisitedBBs;
- for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
+ for (unsigned I : seq<unsigned>(0, PH->getNumIncomingValues())) {
ValueList Operands;
- BasicBlock *IBB = PH->getIncomingBlock(i);
+ BasicBlock *IBB = PH->getIncomingBlock(I);
// Stop emission if all incoming values are generated.
if (NewPhi->getNumIncomingValues() == PH->getNumIncomingValues()) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
- return V;
+ return NewPhi;
}
if (!VisitedBBs.insert(IBB).second) {
@@ -10191,37 +11126,54 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
Builder.SetInsertPoint(IBB->getTerminator());
Builder.SetCurrentDebugLocation(PH->getDebugLoc());
- Value *Vec = vectorizeOperand(E, i);
+ Value *Vec = vectorizeOperand(E, I, /*PostponedPHIs=*/true);
+ if (VecTy != Vec->getType()) {
+ assert(MinBWs.contains(getOperandEntry(E, I)) &&
+ "Expected item in MinBWs.");
+ Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second);
+ }
NewPhi->addIncoming(Vec, IBB);
}
assert(NewPhi->getNumIncomingValues() == PH->getNumIncomingValues() &&
"Invalid number of incoming values");
- return V;
+ return NewPhi;
}
case Instruction::ExtractElement: {
Value *V = E->getSingleOperand(0);
setInsertPointAfterBundle(E);
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
return V;
}
case Instruction::ExtractValue: {
auto *LI = cast<LoadInst>(E->getSingleOperand(0));
Builder.SetInsertPoint(LI);
- auto *PtrTy = PointerType::get(VecTy, LI->getPointerAddressSpace());
- Value *Ptr = Builder.CreateBitCast(LI->getOperand(0), PtrTy);
+ Value *Ptr = LI->getPointerOperand();
LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlign());
Value *NewV = propagateMetadata(V, E->Scalars);
- NewV = FinalShuffle(NewV, E);
+ NewV = FinalShuffle(NewV, E, VecTy, IsSigned);
E->VectorizedValue = NewV;
return NewV;
}
case Instruction::InsertElement: {
assert(E->ReuseShuffleIndices.empty() && "All inserts should be unique");
Builder.SetInsertPoint(cast<Instruction>(E->Scalars.back()));
- Value *V = vectorizeOperand(E, 1);
+ Value *V = vectorizeOperand(E, 1, PostponedPHIs);
+ ArrayRef<Value *> Op = E->getOperand(1);
+ Type *ScalarTy = Op.front()->getType();
+ if (cast<VectorType>(V->getType())->getElementType() != ScalarTy) {
+ assert(ScalarTy->isIntegerTy() && "Expected item in MinBWs.");
+ std::pair<unsigned, bool> Res = MinBWs.lookup(getOperandEntry(E, 1));
+ assert(Res.first > 0 && "Expected item in MinBWs.");
+ V = Builder.CreateIntCast(
+ V,
+ FixedVectorType::get(
+ ScalarTy,
+ cast<FixedVectorType>(V->getType())->getNumElements()),
+ Res.second);
+ }
// Create InsertVector shuffle if necessary
auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) {
@@ -10254,7 +11206,57 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
Mask[InsertIdx - Offset] = I;
}
if (!IsIdentity || NumElts != NumScalars) {
- V = Builder.CreateShuffleVector(V, Mask);
+ Value *V2 = nullptr;
+ bool IsVNonPoisonous = isGuaranteedNotToBePoison(V) && !isConstant(V);
+ SmallVector<int> InsertMask(Mask);
+ if (NumElts != NumScalars && Offset == 0) {
+ // Follow all insert element instructions from the current buildvector
+ // sequence.
+ InsertElementInst *Ins = cast<InsertElementInst>(VL0);
+ do {
+ std::optional<unsigned> InsertIdx = getInsertIndex(Ins);
+ if (!InsertIdx)
+ break;
+ if (InsertMask[*InsertIdx] == PoisonMaskElem)
+ InsertMask[*InsertIdx] = *InsertIdx;
+ if (!Ins->hasOneUse())
+ break;
+ Ins = dyn_cast_or_null<InsertElementInst>(
+ Ins->getUniqueUndroppableUser());
+ } while (Ins);
+ SmallBitVector UseMask =
+ buildUseMask(NumElts, InsertMask, UseMask::UndefsAsMask);
+ SmallBitVector IsFirstPoison =
+ isUndefVector<true>(FirstInsert->getOperand(0), UseMask);
+ SmallBitVector IsFirstUndef =
+ isUndefVector(FirstInsert->getOperand(0), UseMask);
+ if (!IsFirstPoison.all()) {
+ unsigned Idx = 0;
+ for (unsigned I = 0; I < NumElts; I++) {
+ if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I) &&
+ IsFirstUndef.test(I)) {
+ if (IsVNonPoisonous) {
+ InsertMask[I] = I < NumScalars ? I : 0;
+ continue;
+ }
+ if (!V2)
+ V2 = UndefValue::get(V->getType());
+ if (Idx >= NumScalars)
+ Idx = NumScalars - 1;
+ InsertMask[I] = NumScalars + Idx;
+ ++Idx;
+ } else if (InsertMask[I] != PoisonMaskElem &&
+ Mask[I] == PoisonMaskElem) {
+ InsertMask[I] = PoisonMaskElem;
+ }
+ }
+ } else {
+ InsertMask = Mask;
+ }
+ }
+ if (!V2)
+ V2 = PoisonValue::get(V->getType());
+ V = Builder.CreateShuffleVector(V, V2, InsertMask);
if (auto *I = dyn_cast<Instruction>(V)) {
GatherShuffleExtractSeq.insert(I);
CSEBlocks.insert(I->getParent());
@@ -10273,15 +11275,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
if ((!IsIdentity || Offset != 0 || !IsFirstUndef.all()) &&
NumElts != NumScalars) {
if (IsFirstUndef.all()) {
- if (!ShuffleVectorInst::isIdentityMask(InsertMask)) {
- SmallBitVector IsFirstPoison =
- isUndefVector<true>(FirstInsert->getOperand(0), UseMask);
- if (!IsFirstPoison.all()) {
- for (unsigned I = 0; I < NumElts; I++) {
- if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I))
- InsertMask[I] = I + NumElts;
+ if (!ShuffleVectorInst::isIdentityMask(InsertMask, NumElts)) {
+ SmallBitVector IsFirstPoison =
+ isUndefVector<true>(FirstInsert->getOperand(0), UseMask);
+ if (!IsFirstPoison.all()) {
+ for (unsigned I = 0; I < NumElts; I++) {
+ if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I))
+ InsertMask[I] = I + NumElts;
+ }
}
- }
V = Builder.CreateShuffleVector(
V,
IsFirstPoison.all() ? PoisonValue::get(V->getType())
@@ -10329,15 +11331,36 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::BitCast: {
setInsertPointAfterBundle(E);
- Value *InVec = vectorizeOperand(E, 0);
+ Value *InVec = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
auto *CI = cast<CastInst>(VL0);
- Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
- V = FinalShuffle(V, E);
+ Instruction::CastOps VecOpcode = CI->getOpcode();
+ Type *SrcScalarTy = VL0->getOperand(0)->getType();
+ auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
+ if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
+ (SrcIt != MinBWs.end() || It != MinBWs.end())) {
+ // Check if the values are candidates to demote.
+ unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
+ if (SrcIt != MinBWs.end())
+ SrcBWSz = SrcIt->second.first;
+ unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+ if (BWSz == SrcBWSz) {
+ VecOpcode = Instruction::BitCast;
+ } else if (BWSz < SrcBWSz) {
+ VecOpcode = Instruction::Trunc;
+ } else if (It != MinBWs.end()) {
+ assert(BWSz > SrcBWSz && "Invalid cast!");
+ VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+ }
+ }
+ Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast)
+ ? InVec
+ : Builder.CreateCast(VecOpcode, InVec, VecTy);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10347,21 +11370,30 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::ICmp: {
setInsertPointAfterBundle(E);
- Value *L = vectorizeOperand(E, 0);
+ Value *L = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- Value *R = vectorizeOperand(E, 1);
+ Value *R = vectorizeOperand(E, 1, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
+ if (L->getType() != R->getType()) {
+ assert((MinBWs.contains(getOperandEntry(E, 0)) ||
+ MinBWs.contains(getOperandEntry(E, 1))) &&
+ "Expected item in MinBWs.");
+ L = Builder.CreateIntCast(L, VecTy, IsSigned);
+ R = Builder.CreateIntCast(R, VecTy, IsSigned);
+ }
CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate();
Value *V = Builder.CreateCmp(P0, L, R);
propagateIRFlags(V, E->Scalars, VL0);
- V = FinalShuffle(V, E);
+ // Do not cast for cmps.
+ VecTy = cast<FixedVectorType>(V->getType());
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10370,24 +11402,31 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::Select: {
setInsertPointAfterBundle(E);
- Value *Cond = vectorizeOperand(E, 0);
+ Value *Cond = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- Value *True = vectorizeOperand(E, 1);
+ Value *True = vectorizeOperand(E, 1, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- Value *False = vectorizeOperand(E, 2);
+ Value *False = vectorizeOperand(E, 2, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
+ if (True->getType() != False->getType()) {
+ assert((MinBWs.contains(getOperandEntry(E, 1)) ||
+ MinBWs.contains(getOperandEntry(E, 2))) &&
+ "Expected item in MinBWs.");
+ True = Builder.CreateIntCast(True, VecTy, IsSigned);
+ False = Builder.CreateIntCast(False, VecTy, IsSigned);
+ }
Value *V = Builder.CreateSelect(Cond, True, False);
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10396,7 +11435,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::FNeg: {
setInsertPointAfterBundle(E);
- Value *Op = vectorizeOperand(E, 0);
+ Value *Op = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
@@ -10409,7 +11448,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
if (auto *I = dyn_cast<Instruction>(V))
V = propagateMetadata(I, E->Scalars);
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10436,16 +11475,23 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
case Instruction::Xor: {
setInsertPointAfterBundle(E);
- Value *LHS = vectorizeOperand(E, 0);
+ Value *LHS = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- Value *RHS = vectorizeOperand(E, 1);
+ Value *RHS = vectorizeOperand(E, 1, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
+ if (LHS->getType() != RHS->getType()) {
+ assert((MinBWs.contains(getOperandEntry(E, 0)) ||
+ MinBWs.contains(getOperandEntry(E, 1))) &&
+ "Expected item in MinBWs.");
+ LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
+ RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
+ }
Value *V = Builder.CreateBinOp(
static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
@@ -10454,7 +11500,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
if (auto *I = dyn_cast<Instruction>(V))
V = propagateMetadata(I, E->Scalars);
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10475,14 +11521,18 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
// The pointer operand uses an in-tree scalar so we add the new
// LoadInst to ExternalUses list to make sure that an extract will
// be generated in the future.
- if (TreeEntry *Entry = getTreeEntry(PO)) {
- // Find which lane we need to extract.
- unsigned FoundLane = Entry->findLaneForValue(PO);
- ExternalUses.emplace_back(PO, NewLI, FoundLane);
+ if (isa<Instruction>(PO)) {
+ if (TreeEntry *Entry = getTreeEntry(PO)) {
+ // Find which lane we need to extract.
+ unsigned FoundLane = Entry->findLaneForValue(PO);
+ ExternalUses.emplace_back(PO, NewLI, FoundLane);
+ }
}
} else {
- assert(E->State == TreeEntry::ScatterVectorize && "Unhandled state");
- Value *VecPtr = vectorizeOperand(E, 0);
+ assert((E->State == TreeEntry::ScatterVectorize ||
+ E->State == TreeEntry::PossibleStridedVectorize) &&
+ "Unhandled state");
+ Value *VecPtr = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
@@ -10496,35 +11546,32 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
Value *V = propagateMetadata(NewLI, E->Scalars);
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
return V;
}
case Instruction::Store: {
auto *SI = cast<StoreInst>(VL0);
- unsigned AS = SI->getPointerAddressSpace();
setInsertPointAfterBundle(E);
- Value *VecValue = vectorizeOperand(E, 0);
- VecValue = FinalShuffle(VecValue, E);
+ Value *VecValue = vectorizeOperand(E, 0, PostponedPHIs);
+ VecValue = FinalShuffle(VecValue, E, VecTy, IsSigned);
- Value *ScalarPtr = SI->getPointerOperand();
- Value *VecPtr = Builder.CreateBitCast(
- ScalarPtr, VecValue->getType()->getPointerTo(AS));
+ Value *Ptr = SI->getPointerOperand();
StoreInst *ST =
- Builder.CreateAlignedStore(VecValue, VecPtr, SI->getAlign());
+ Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
- // The pointer operand uses an in-tree scalar, so add the new BitCast or
- // StoreInst to ExternalUses to make sure that an extract will be
- // generated in the future.
- if (TreeEntry *Entry = getTreeEntry(ScalarPtr)) {
- // Find which lane we need to extract.
- unsigned FoundLane = Entry->findLaneForValue(ScalarPtr);
- ExternalUses.push_back(ExternalUser(
- ScalarPtr, ScalarPtr != VecPtr ? cast<User>(VecPtr) : ST,
- FoundLane));
+ // The pointer operand uses an in-tree scalar, so add the new StoreInst to
+ // ExternalUses to make sure that an extract will be generated in the
+ // future.
+ if (isa<Instruction>(Ptr)) {
+ if (TreeEntry *Entry = getTreeEntry(Ptr)) {
+ // Find which lane we need to extract.
+ unsigned FoundLane = Entry->findLaneForValue(Ptr);
+ ExternalUses.push_back(ExternalUser(Ptr, ST, FoundLane));
+ }
}
Value *V = propagateMetadata(ST, E->Scalars);
@@ -10537,7 +11584,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
auto *GEP0 = cast<GetElementPtrInst>(VL0);
setInsertPointAfterBundle(E);
- Value *Op0 = vectorizeOperand(E, 0);
+ Value *Op0 = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
@@ -10545,7 +11592,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
SmallVector<Value *> OpVecs;
for (int J = 1, N = GEP0->getNumOperands(); J < N; ++J) {
- Value *OpVec = vectorizeOperand(E, J);
+ Value *OpVec = vectorizeOperand(E, J, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
@@ -10563,7 +11610,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
V = propagateMetadata(I, GEPs);
}
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10585,41 +11632,42 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
VecCallCosts.first <= VecCallCosts.second;
Value *ScalarArg = nullptr;
- std::vector<Value *> OpVecs;
+ SmallVector<Value *> OpVecs;
SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
if (isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
TysForDecl.push_back(
FixedVectorType::get(CI->getType(), E->Scalars.size()));
- for (int j = 0, e = CI->arg_size(); j < e; ++j) {
+ for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
ValueList OpVL;
// Some intrinsics have scalar arguments. This argument should not be
// vectorized.
- if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(IID, j)) {
+ if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(IID, I)) {
CallInst *CEI = cast<CallInst>(VL0);
- ScalarArg = CEI->getArgOperand(j);
- OpVecs.push_back(CEI->getArgOperand(j));
- if (isVectorIntrinsicWithOverloadTypeAtArg(IID, j))
+ ScalarArg = CEI->getArgOperand(I);
+ OpVecs.push_back(CEI->getArgOperand(I));
+ if (isVectorIntrinsicWithOverloadTypeAtArg(IID, I))
TysForDecl.push_back(ScalarArg->getType());
continue;
}
- Value *OpVec = vectorizeOperand(E, j);
+ Value *OpVec = vectorizeOperand(E, I, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- LLVM_DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n");
+ LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
OpVecs.push_back(OpVec);
- if (isVectorIntrinsicWithOverloadTypeAtArg(IID, j))
+ if (isVectorIntrinsicWithOverloadTypeAtArg(IID, I))
TysForDecl.push_back(OpVec->getType());
}
Function *CF;
if (!UseIntrinsic) {
VFShape Shape =
- VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>(
- VecTy->getNumElements())),
+ VFShape::get(CI->getFunctionType(),
+ ElementCount::getFixed(
+ static_cast<unsigned>(VecTy->getNumElements())),
false /*HasGlobalPred*/);
CF = VFDatabase(*CI).getVectorizedFunction(Shape);
} else {
@@ -10633,7 +11681,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
// The scalar argument uses an in-tree scalar so we add the new vectorized
// call to ExternalUses list to make sure that an extract will be
// generated in the future.
- if (ScalarArg) {
+ if (isa_and_present<Instruction>(ScalarArg)) {
if (TreeEntry *Entry = getTreeEntry(ScalarArg)) {
// Find which lane we need to extract.
unsigned FoundLane = Entry->findLaneForValue(ScalarArg);
@@ -10643,7 +11691,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
}
propagateIRFlags(V, E->Scalars, VL0);
- V = FinalShuffle(V, E);
+ V = FinalShuffle(V, E, VecTy, IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10661,20 +11709,27 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
Value *LHS = nullptr, *RHS = nullptr;
if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0)) {
setInsertPointAfterBundle(E);
- LHS = vectorizeOperand(E, 0);
+ LHS = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- RHS = vectorizeOperand(E, 1);
+ RHS = vectorizeOperand(E, 1, PostponedPHIs);
} else {
setInsertPointAfterBundle(E);
- LHS = vectorizeOperand(E, 0);
+ LHS = vectorizeOperand(E, 0, PostponedPHIs);
}
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
+ if (LHS && RHS && LHS->getType() != RHS->getType()) {
+ assert((MinBWs.contains(getOperandEntry(E, 0)) ||
+ MinBWs.contains(getOperandEntry(E, 1))) &&
+ "Expected item in MinBWs.");
+ LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
+ RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
+ }
Value *V0, *V1;
if (Instruction::isBinaryOp(E->getOpcode())) {
@@ -10707,8 +11762,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
// each vector operation.
ValueList OpScalars, AltScalars;
SmallVector<int> Mask;
- buildShuffleEntryMask(
- E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices,
+ E->buildAltOpShuffleMask(
[E, this](Instruction *I) {
assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
return isAlternateInstruction(I, E->getMainOp(), E->getAltOp(),
@@ -10726,6 +11780,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
CSEBlocks.insert(I->getParent());
}
+ if (V->getType() != VecTy && !isa<CmpInst>(VL0))
+ V = Builder.CreateIntCast(
+ V, FixedVectorType::get(ScalarTy, E->getVectorFactor()), IsSigned);
E->VectorizedValue = V;
++NumVectorInstructions;
@@ -10766,9 +11823,19 @@ Value *BoUpSLP::vectorizeTree(
// need to rebuild it.
EntryToLastInstruction.clear();
- Builder.SetInsertPoint(ReductionRoot ? ReductionRoot
- : &F->getEntryBlock().front());
- auto *VectorRoot = vectorizeTree(VectorizableTree[0].get());
+ if (ReductionRoot)
+ Builder.SetInsertPoint(ReductionRoot->getParent(),
+ ReductionRoot->getIterator());
+ else
+ Builder.SetInsertPoint(&F->getEntryBlock(), F->getEntryBlock().begin());
+
+ // Postpone emission of PHIs operands to avoid cyclic dependencies issues.
+ (void)vectorizeTree(VectorizableTree[0].get(), /*PostponedPHIs=*/true);
+ for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree)
+ if (TE->State == TreeEntry::Vectorize &&
+ TE->getOpcode() == Instruction::PHI && !TE->isAltShuffle() &&
+ TE->VectorizedValue)
+ (void)vectorizeTree(TE.get(), /*PostponedPHIs=*/false);
// Run through the list of postponed gathers and emit them, replacing the temp
// emitted allocas with actual vector instructions.
ArrayRef<const TreeEntry *> PostponedNodes = PostponedGathers.getArrayRef();
@@ -10785,9 +11852,32 @@ Value *BoUpSLP::vectorizeTree(
TE->VectorizedValue = nullptr;
auto *UserI =
cast<Instruction>(TE->UserTreeIndices.front().UserTE->VectorizedValue);
- Builder.SetInsertPoint(PrevVec);
+ // If user is a PHI node, its vector code have to be inserted right before
+ // block terminator. Since the node was delayed, there were some unresolved
+ // dependencies at the moment when stab instruction was emitted. In a case
+ // when any of these dependencies turn out an operand of another PHI, coming
+ // from this same block, position of a stab instruction will become invalid.
+ // The is because source vector that supposed to feed this gather node was
+ // inserted at the end of the block [after stab instruction]. So we need
+ // to adjust insertion point again to the end of block.
+ if (isa<PHINode>(UserI)) {
+ // Insert before all users.
+ Instruction *InsertPt = PrevVec->getParent()->getTerminator();
+ for (User *U : PrevVec->users()) {
+ if (U == UserI)
+ continue;
+ auto *UI = dyn_cast<Instruction>(U);
+ if (!UI || isa<PHINode>(UI) || UI->getParent() != InsertPt->getParent())
+ continue;
+ if (UI->comesBefore(InsertPt))
+ InsertPt = UI;
+ }
+ Builder.SetInsertPoint(InsertPt);
+ } else {
+ Builder.SetInsertPoint(PrevVec);
+ }
Builder.SetCurrentDebugLocation(UserI->getDebugLoc());
- Value *Vec = vectorizeTree(TE);
+ Value *Vec = vectorizeTree(TE, /*PostponedPHIs=*/false);
PrevVec->replaceAllUsesWith(Vec);
PostponedValues.try_emplace(Vec).first->second.push_back(TE);
// Replace the stub vector node, if it was used before for one of the
@@ -10800,26 +11890,6 @@ Value *BoUpSLP::vectorizeTree(
eraseInstruction(PrevVec);
}
- // If the vectorized tree can be rewritten in a smaller type, we truncate the
- // vectorized root. InstCombine will then rewrite the entire expression. We
- // sign extend the extracted values below.
- auto *ScalarRoot = VectorizableTree[0]->Scalars[0];
- if (MinBWs.count(ScalarRoot)) {
- if (auto *I = dyn_cast<Instruction>(VectorRoot)) {
- // If current instr is a phi and not the last phi, insert it after the
- // last phi node.
- if (isa<PHINode>(I))
- Builder.SetInsertPoint(&*I->getParent()->getFirstInsertionPt());
- else
- Builder.SetInsertPoint(&*++BasicBlock::iterator(I));
- }
- auto BundleWidth = VectorizableTree[0]->Scalars.size();
- auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first);
- auto *VecTy = FixedVectorType::get(MinTy, BundleWidth);
- auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy);
- VectorizableTree[0]->VectorizedValue = Trunc;
- }
-
LLVM_DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size()
<< " values .\n");
@@ -10829,6 +11899,8 @@ Value *BoUpSLP::vectorizeTree(
// Maps extract Scalar to the corresponding extractelement instruction in the
// basic block. Only one extractelement per block should be emitted.
DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
+ SmallDenseSet<Value *, 4> UsedInserts;
+ DenseMap<Value *, Value *> VectorCasts;
// Extract all of the elements with the external uses.
for (const auto &ExternalUse : ExternalUses) {
Value *Scalar = ExternalUse.Scalar;
@@ -10863,7 +11935,8 @@ Value *BoUpSLP::vectorizeTree(
Instruction *I = EEIt->second;
if (Builder.GetInsertPoint() != Builder.GetInsertBlock()->end() &&
Builder.GetInsertPoint()->comesBefore(I))
- I->moveBefore(&*Builder.GetInsertPoint());
+ I->moveBefore(*Builder.GetInsertPoint()->getParent(),
+ Builder.GetInsertPoint());
Ex = I;
}
}
@@ -10886,11 +11959,10 @@ Value *BoUpSLP::vectorizeTree(
}
// If necessary, sign-extend or zero-extend ScalarRoot
// to the larger type.
- if (!MinBWs.count(ScalarRoot))
- return Ex;
- if (MinBWs[ScalarRoot].second)
- return Builder.CreateSExt(Ex, Scalar->getType());
- return Builder.CreateZExt(Ex, Scalar->getType());
+ if (Scalar->getType() != Ex->getType())
+ return Builder.CreateIntCast(Ex, Scalar->getType(),
+ MinBWs.find(E)->second.second);
+ return Ex;
}
assert(isa<FixedVectorType>(Scalar->getType()) &&
isa<InsertElementInst>(Scalar) &&
@@ -10908,12 +11980,13 @@ Value *BoUpSLP::vectorizeTree(
"ExternallyUsedValues map");
if (auto *VecI = dyn_cast<Instruction>(Vec)) {
if (auto *PHI = dyn_cast<PHINode>(VecI))
- Builder.SetInsertPoint(PHI->getParent()->getFirstNonPHI());
+ Builder.SetInsertPoint(PHI->getParent(),
+ PHI->getParent()->getFirstNonPHIIt());
else
Builder.SetInsertPoint(VecI->getParent(),
std::next(VecI->getIterator()));
} else {
- Builder.SetInsertPoint(&F->getEntryBlock().front());
+ Builder.SetInsertPoint(&F->getEntryBlock(), F->getEntryBlock().begin());
}
Value *NewInst = ExtractAndExtendIfNeeded(Vec);
// Required to update internally referenced instructions.
@@ -10926,12 +11999,26 @@ Value *BoUpSLP::vectorizeTree(
// Skip if the scalar is another vector op or Vec is not an instruction.
if (!Scalar->getType()->isVectorTy() && isa<Instruction>(Vec)) {
if (auto *FTy = dyn_cast<FixedVectorType>(User->getType())) {
+ if (!UsedInserts.insert(VU).second)
+ continue;
+ // Need to use original vector, if the root is truncated.
+ auto BWIt = MinBWs.find(E);
+ if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
+ auto VecIt = VectorCasts.find(Scalar);
+ if (VecIt == VectorCasts.end()) {
+ IRBuilder<>::InsertPointGuard Guard(Builder);
+ if (auto *IVec = dyn_cast<Instruction>(Vec))
+ Builder.SetInsertPoint(IVec->getNextNonDebugInstruction());
+ Vec = Builder.CreateIntCast(Vec, VU->getType(),
+ BWIt->second.second);
+ VectorCasts.try_emplace(Scalar, Vec);
+ } else {
+ Vec = VecIt->second;
+ }
+ }
+
std::optional<unsigned> InsertIdx = getInsertIndex(VU);
if (InsertIdx) {
- // Need to use original vector, if the root is truncated.
- if (MinBWs.count(Scalar) &&
- VectorizableTree[0]->VectorizedValue == Vec)
- Vec = VectorRoot;
auto *It =
find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) {
// Checks if 2 insertelements are from the same buildvector.
@@ -10991,18 +12078,18 @@ Value *BoUpSLP::vectorizeTree(
// Find the insertion point for the extractelement lane.
if (auto *VecI = dyn_cast<Instruction>(Vec)) {
if (PHINode *PH = dyn_cast<PHINode>(User)) {
- for (int i = 0, e = PH->getNumIncomingValues(); i != e; ++i) {
- if (PH->getIncomingValue(i) == Scalar) {
+ for (unsigned I : seq<unsigned>(0, PH->getNumIncomingValues())) {
+ if (PH->getIncomingValue(I) == Scalar) {
Instruction *IncomingTerminator =
- PH->getIncomingBlock(i)->getTerminator();
+ PH->getIncomingBlock(I)->getTerminator();
if (isa<CatchSwitchInst>(IncomingTerminator)) {
Builder.SetInsertPoint(VecI->getParent(),
std::next(VecI->getIterator()));
} else {
- Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator());
+ Builder.SetInsertPoint(PH->getIncomingBlock(I)->getTerminator());
}
Value *NewInst = ExtractAndExtendIfNeeded(Vec);
- PH->setOperand(i, NewInst);
+ PH->setOperand(I, NewInst);
}
}
} else {
@@ -11011,7 +12098,7 @@ Value *BoUpSLP::vectorizeTree(
User->replaceUsesOfWith(Scalar, NewInst);
}
} else {
- Builder.SetInsertPoint(&F->getEntryBlock().front());
+ Builder.SetInsertPoint(&F->getEntryBlock(), F->getEntryBlock().begin());
Value *NewInst = ExtractAndExtendIfNeeded(Vec);
User->replaceUsesOfWith(Scalar, NewInst);
}
@@ -11084,7 +12171,7 @@ Value *BoUpSLP::vectorizeTree(
// non-resizing mask.
if (Mask.size() != cast<FixedVectorType>(Vals.front()->getType())
->getNumElements() ||
- !ShuffleVectorInst::isIdentityMask(Mask))
+ !ShuffleVectorInst::isIdentityMask(Mask, Mask.size()))
return CreateShuffle(Vals.front(), nullptr, Mask);
return Vals.front();
}
@@ -11675,7 +12762,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
}
}
- auto makeControlDependent = [&](Instruction *I) {
+ auto MakeControlDependent = [&](Instruction *I) {
auto *DepDest = getScheduleData(I);
assert(DepDest && "must be in schedule window");
DepDest->ControlDependencies.push_back(BundleMember);
@@ -11697,7 +12784,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
continue;
// Add the dependency
- makeControlDependent(I);
+ MakeControlDependent(I);
if (!isGuaranteedToTransferExecutionToSuccessor(I))
// Everything past here must be control dependent on I.
@@ -11723,7 +12810,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
continue;
// Add the dependency
- makeControlDependent(I);
+ MakeControlDependent(I);
}
}
@@ -11741,7 +12828,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
continue;
// Add the dependency
- makeControlDependent(I);
+ MakeControlDependent(I);
break;
}
}
@@ -11756,7 +12843,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
"NextLoadStore list for non memory effecting bundle?");
MemoryLocation SrcLoc = getLocation(SrcInst);
bool SrcMayWrite = BundleMember->Inst->mayWriteToMemory();
- unsigned numAliased = 0;
+ unsigned NumAliased = 0;
unsigned DistToSrc = 1;
for (; DepDest; DepDest = DepDest->NextLoadStore) {
@@ -11771,13 +12858,13 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
// check this limit even between two read-only instructions.
if (DistToSrc >= MaxMemDepDistance ||
((SrcMayWrite || DepDest->Inst->mayWriteToMemory()) &&
- (numAliased >= AliasedCheckLimit ||
+ (NumAliased >= AliasedCheckLimit ||
SLP->isAliased(SrcLoc, SrcInst, DepDest->Inst)))) {
// We increment the counter only if the locations are aliased
// (instead of counting all alias checks). This gives a better
// balance between reduced runtime and accurate dependencies.
- numAliased++;
+ NumAliased++;
DepDest->MemoryDependencies.push_back(BundleMember);
BundleMember->Dependencies++;
@@ -11879,20 +12966,20 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) {
// Do the "real" scheduling.
while (!ReadyInsts.empty()) {
- ScheduleData *picked = *ReadyInsts.begin();
+ ScheduleData *Picked = *ReadyInsts.begin();
ReadyInsts.erase(ReadyInsts.begin());
// Move the scheduled instruction(s) to their dedicated places, if not
// there yet.
- for (ScheduleData *BundleMember = picked; BundleMember;
+ for (ScheduleData *BundleMember = Picked; BundleMember;
BundleMember = BundleMember->NextInBundle) {
- Instruction *pickedInst = BundleMember->Inst;
- if (pickedInst->getNextNode() != LastScheduledInst)
- pickedInst->moveBefore(LastScheduledInst);
- LastScheduledInst = pickedInst;
+ Instruction *PickedInst = BundleMember->Inst;
+ if (PickedInst->getNextNode() != LastScheduledInst)
+ PickedInst->moveBefore(LastScheduledInst);
+ LastScheduledInst = PickedInst;
}
- BS->schedule(picked, ReadyInsts);
+ BS->schedule(Picked, ReadyInsts);
}
// Check that we didn't break any of our invariants.
@@ -11993,21 +13080,22 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) {
// Determine if a value V in a vectorizable expression Expr can be demoted to a
// smaller type with a truncation. We collect the values that will be demoted
// in ToDemote and additional roots that require investigating in Roots.
-static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr,
- SmallVectorImpl<Value *> &ToDemote,
- SmallVectorImpl<Value *> &Roots) {
+bool BoUpSLP::collectValuesToDemote(
+ Value *V, SmallVectorImpl<Value *> &ToDemote,
+ DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts,
+ SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const {
// We can always demote constants.
- if (isa<Constant>(V)) {
- ToDemote.push_back(V);
+ if (isa<Constant>(V))
return true;
- }
- // If the value is not an instruction in the expression with only one use, it
- // cannot be demoted.
+ // If the value is not a vectorized instruction in the expression with only
+ // one use, it cannot be demoted.
auto *I = dyn_cast<Instruction>(V);
- if (!I || !I->hasOneUse() || !Expr.count(I))
+ if (!I || !I->hasOneUse() || !getTreeEntry(I) || !Visited.insert(I).second)
return false;
+ unsigned Start = 0;
+ unsigned End = I->getNumOperands();
switch (I->getOpcode()) {
// We can always demote truncations and extensions. Since truncations can
@@ -12029,16 +13117,21 @@ static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr,
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
- if (!collectValuesToDemote(I->getOperand(0), Expr, ToDemote, Roots) ||
- !collectValuesToDemote(I->getOperand(1), Expr, ToDemote, Roots))
+ if (!collectValuesToDemote(I->getOperand(0), ToDemote, DemotedConsts, Roots,
+ Visited) ||
+ !collectValuesToDemote(I->getOperand(1), ToDemote, DemotedConsts, Roots,
+ Visited))
return false;
break;
// We can demote selects if we can demote their true and false values.
case Instruction::Select: {
+ Start = 1;
SelectInst *SI = cast<SelectInst>(I);
- if (!collectValuesToDemote(SI->getTrueValue(), Expr, ToDemote, Roots) ||
- !collectValuesToDemote(SI->getFalseValue(), Expr, ToDemote, Roots))
+ if (!collectValuesToDemote(SI->getTrueValue(), ToDemote, DemotedConsts,
+ Roots, Visited) ||
+ !collectValuesToDemote(SI->getFalseValue(), ToDemote, DemotedConsts,
+ Roots, Visited))
return false;
break;
}
@@ -12048,7 +13141,8 @@ static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr,
case Instruction::PHI: {
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
- if (!collectValuesToDemote(IncValue, Expr, ToDemote, Roots))
+ if (!collectValuesToDemote(IncValue, ToDemote, DemotedConsts, Roots,
+ Visited))
return false;
break;
}
@@ -12058,6 +13152,10 @@ static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr,
return false;
}
+ // Gather demoted constant operands.
+ for (unsigned Idx : seq<unsigned>(Start, End))
+ if (isa<Constant>(I->getOperand(Idx)))
+ DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx);
// Record the value that we can demote.
ToDemote.push_back(V);
return true;
@@ -12075,44 +13173,26 @@ void BoUpSLP::computeMinimumValueSizes() {
if (!TreeRootIT)
return;
- // If the expression is not rooted by a store, these roots should have
- // external uses. We will rely on InstCombine to rewrite the expression in
- // the narrower type. However, InstCombine only rewrites single-use values.
- // This means that if a tree entry other than a root is used externally, it
- // must have multiple uses and InstCombine will not rewrite it. The code
- // below ensures that only the roots are used externally.
- SmallPtrSet<Value *, 32> Expr(TreeRoot.begin(), TreeRoot.end());
- for (auto &EU : ExternalUses)
- if (!Expr.erase(EU.Scalar))
- return;
- if (!Expr.empty())
+ // Ensure the roots of the vectorizable tree don't form a cycle.
+ if (!VectorizableTree.front()->UserTreeIndices.empty())
return;
- // Collect the scalar values of the vectorizable expression. We will use this
- // context to determine which values can be demoted. If we see a truncation,
- // we mark it as seeding another demotion.
- for (auto &EntryPtr : VectorizableTree)
- Expr.insert(EntryPtr->Scalars.begin(), EntryPtr->Scalars.end());
-
- // Ensure the roots of the vectorizable tree don't form a cycle. They must
- // have a single external user that is not in the vectorizable tree.
- for (auto *Root : TreeRoot)
- if (!Root->hasOneUse() || Expr.count(*Root->user_begin()))
- return;
-
// Conservatively determine if we can actually truncate the roots of the
// expression. Collect the values that can be demoted in ToDemote and
// additional roots that require investigating in Roots.
SmallVector<Value *, 32> ToDemote;
+ DenseMap<Instruction *, SmallVector<unsigned>> DemotedConsts;
SmallVector<Value *, 4> Roots;
- for (auto *Root : TreeRoot)
- if (!collectValuesToDemote(Root, Expr, ToDemote, Roots))
+ for (auto *Root : TreeRoot) {
+ DenseSet<Value *> Visited;
+ if (!collectValuesToDemote(Root, ToDemote, DemotedConsts, Roots, Visited))
return;
+ }
// The maximum bit width required to represent all the values that can be
// demoted without loss of precision. It would be safe to truncate the roots
// of the expression to this width.
- auto MaxBitWidth = 8u;
+ auto MaxBitWidth = 1u;
// We first check if all the bits of the roots are demanded. If they're not,
// we can truncate the roots to this narrower type.
@@ -12137,9 +13217,9 @@ void BoUpSLP::computeMinimumValueSizes() {
// maximum bit width required to store the scalar by using ValueTracking to
// compute the number of high-order bits we can truncate.
if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType()) &&
- llvm::all_of(TreeRoot, [](Value *R) {
- assert(R->hasOneUse() && "Root should have only one use!");
- return isa<GetElementPtrInst>(R->user_back());
+ all_of(TreeRoot, [](Value *V) {
+ return all_of(V->users(),
+ [](User *U) { return isa<GetElementPtrInst>(U); });
})) {
MaxBitWidth = 8u;
@@ -12188,12 +13268,39 @@ void BoUpSLP::computeMinimumValueSizes() {
// If we can truncate the root, we must collect additional values that might
// be demoted as a result. That is, those seeded by truncations we will
// modify.
- while (!Roots.empty())
- collectValuesToDemote(Roots.pop_back_val(), Expr, ToDemote, Roots);
+ while (!Roots.empty()) {
+ DenseSet<Value *> Visited;
+ collectValuesToDemote(Roots.pop_back_val(), ToDemote, DemotedConsts, Roots,
+ Visited);
+ }
// Finally, map the values we can demote to the maximum bit with we computed.
- for (auto *Scalar : ToDemote)
- MinBWs[Scalar] = std::make_pair(MaxBitWidth, !IsKnownPositive);
+ for (auto *Scalar : ToDemote) {
+ auto *TE = getTreeEntry(Scalar);
+ assert(TE && "Expected vectorized scalar.");
+ if (MinBWs.contains(TE))
+ continue;
+ bool IsSigned = any_of(TE->Scalars, [&](Value *R) {
+ KnownBits Known = computeKnownBits(R, *DL);
+ return !Known.isNonNegative();
+ });
+ MinBWs.try_emplace(TE, MaxBitWidth, IsSigned);
+ const auto *I = cast<Instruction>(Scalar);
+ auto DCIt = DemotedConsts.find(I);
+ if (DCIt != DemotedConsts.end()) {
+ for (unsigned Idx : DCIt->getSecond()) {
+ // Check that all instructions operands are demoted.
+ if (all_of(TE->Scalars, [&](Value *V) {
+ auto SIt = DemotedConsts.find(cast<Instruction>(V));
+ return SIt != DemotedConsts.end() &&
+ is_contained(SIt->getSecond(), Idx);
+ })) {
+ const TreeEntry *CTE = getOperandEntry(TE, Idx);
+ MinBWs.try_emplace(CTE, MaxBitWidth, IsSigned);
+ }
+ }
+ }
+ }
}
PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) {
@@ -12347,139 +13454,206 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores,
BoUpSLP::ValueSet VectorizedStores;
bool Changed = false;
- int E = Stores.size();
- SmallBitVector Tails(E, false);
- int MaxIter = MaxStoreLookup.getValue();
- SmallVector<std::pair<int, int>, 16> ConsecutiveChain(
- E, std::make_pair(E, INT_MAX));
- SmallVector<SmallBitVector, 4> CheckedPairs(E, SmallBitVector(E, false));
- int IterCnt;
- auto &&FindConsecutiveAccess = [this, &Stores, &Tails, &IterCnt, MaxIter,
- &CheckedPairs,
- &ConsecutiveChain](int K, int Idx) {
- if (IterCnt >= MaxIter)
- return true;
- if (CheckedPairs[Idx].test(K))
- return ConsecutiveChain[K].second == 1 &&
- ConsecutiveChain[K].first == Idx;
- ++IterCnt;
- CheckedPairs[Idx].set(K);
- CheckedPairs[K].set(Idx);
- std::optional<int> Diff = getPointersDiff(
- Stores[K]->getValueOperand()->getType(), Stores[K]->getPointerOperand(),
- Stores[Idx]->getValueOperand()->getType(),
- Stores[Idx]->getPointerOperand(), *DL, *SE, /*StrictCheck=*/true);
- if (!Diff || *Diff == 0)
- return false;
- int Val = *Diff;
- if (Val < 0) {
- if (ConsecutiveChain[Idx].second > -Val) {
- Tails.set(K);
- ConsecutiveChain[Idx] = std::make_pair(K, -Val);
- }
- return false;
+ // Stores the pair of stores (first_store, last_store) in a range, that were
+ // already tried to be vectorized. Allows to skip the store ranges that were
+ // already tried to be vectorized but the attempts were unsuccessful.
+ DenseSet<std::pair<Value *, Value *>> TriedSequences;
+ struct StoreDistCompare {
+ bool operator()(const std::pair<unsigned, int> &Op1,
+ const std::pair<unsigned, int> &Op2) const {
+ return Op1.second < Op2.second;
}
- if (ConsecutiveChain[K].second <= Val)
- return false;
-
- Tails.set(Idx);
- ConsecutiveChain[K] = std::make_pair(Idx, Val);
- return Val == 1;
};
- // Do a quadratic search on all of the given stores in reverse order and find
- // all of the pairs of stores that follow each other.
- for (int Idx = E - 1; Idx >= 0; --Idx) {
- // If a store has multiple consecutive store candidates, search according
- // to the sequence: Idx-1, Idx+1, Idx-2, Idx+2, ...
- // This is because usually pairing with immediate succeeding or preceding
- // candidate create the best chance to find slp vectorization opportunity.
- const int MaxLookDepth = std::max(E - Idx, Idx + 1);
- IterCnt = 0;
- for (int Offset = 1, F = MaxLookDepth; Offset < F; ++Offset)
- if ((Idx >= Offset && FindConsecutiveAccess(Idx - Offset, Idx)) ||
- (Idx + Offset < E && FindConsecutiveAccess(Idx + Offset, Idx)))
- break;
- }
-
- // Tracks if we tried to vectorize stores starting from the given tail
- // already.
- SmallBitVector TriedTails(E, false);
- // For stores that start but don't end a link in the chain:
- for (int Cnt = E; Cnt > 0; --Cnt) {
- int I = Cnt - 1;
- if (ConsecutiveChain[I].first == E || Tails.test(I))
- continue;
- // We found a store instr that starts a chain. Now follow the chain and try
- // to vectorize it.
+ // A set of pairs (index of store in Stores array ref, Distance of the store
+ // address relative to base store address in units).
+ using StoreIndexToDistSet =
+ std::set<std::pair<unsigned, int>, StoreDistCompare>;
+ auto TryToVectorize = [&](const StoreIndexToDistSet &Set) {
+ int PrevDist = -1;
BoUpSLP::ValueList Operands;
// Collect the chain into a list.
- while (I != E && !VectorizedStores.count(Stores[I])) {
- Operands.push_back(Stores[I]);
- Tails.set(I);
- if (ConsecutiveChain[I].second != 1) {
- // Mark the new end in the chain and go back, if required. It might be
- // required if the original stores come in reversed order, for example.
- if (ConsecutiveChain[I].first != E &&
- Tails.test(ConsecutiveChain[I].first) && !TriedTails.test(I) &&
- !VectorizedStores.count(Stores[ConsecutiveChain[I].first])) {
- TriedTails.set(I);
- Tails.reset(ConsecutiveChain[I].first);
- if (Cnt < ConsecutiveChain[I].first + 2)
- Cnt = ConsecutiveChain[I].first + 2;
- }
- break;
+ for (auto [Idx, Data] : enumerate(Set)) {
+ if (Operands.empty() || Data.second - PrevDist == 1) {
+ Operands.push_back(Stores[Data.first]);
+ PrevDist = Data.second;
+ if (Idx != Set.size() - 1)
+ continue;
+ }
+ if (Operands.size() <= 1) {
+ Operands.clear();
+ Operands.push_back(Stores[Data.first]);
+ PrevDist = Data.second;
+ continue;
}
- // Move to the next value in the chain.
- I = ConsecutiveChain[I].first;
- }
- assert(!Operands.empty() && "Expected non-empty list of stores.");
- unsigned MaxVecRegSize = R.getMaxVecRegSize();
- unsigned EltSize = R.getVectorElementSize(Operands[0]);
- unsigned MaxElts = llvm::bit_floor(MaxVecRegSize / EltSize);
+ unsigned MaxVecRegSize = R.getMaxVecRegSize();
+ unsigned EltSize = R.getVectorElementSize(Operands[0]);
+ unsigned MaxElts = llvm::bit_floor(MaxVecRegSize / EltSize);
- unsigned MaxVF = std::min(R.getMaximumVF(EltSize, Instruction::Store),
- MaxElts);
- auto *Store = cast<StoreInst>(Operands[0]);
- Type *StoreTy = Store->getValueOperand()->getType();
- Type *ValueTy = StoreTy;
- if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand()))
- ValueTy = Trunc->getSrcTy();
- unsigned MinVF = TTI->getStoreMinimumVF(
- R.getMinVF(DL->getTypeSizeInBits(ValueTy)), StoreTy, ValueTy);
+ unsigned MaxVF =
+ std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts);
+ auto *Store = cast<StoreInst>(Operands[0]);
+ Type *StoreTy = Store->getValueOperand()->getType();
+ Type *ValueTy = StoreTy;
+ if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand()))
+ ValueTy = Trunc->getSrcTy();
+ unsigned MinVF = TTI->getStoreMinimumVF(
+ R.getMinVF(DL->getTypeSizeInBits(ValueTy)), StoreTy, ValueTy);
- if (MaxVF <= MinVF) {
- LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF << ") <= "
- << "MinVF (" << MinVF << ")\n");
- }
+ if (MaxVF <= MinVF) {
+ LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF
+ << ") <= "
+ << "MinVF (" << MinVF << ")\n");
+ }
- // FIXME: Is division-by-2 the correct step? Should we assert that the
- // register size is a power-of-2?
- unsigned StartIdx = 0;
- for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2) {
- for (unsigned Cnt = StartIdx, E = Operands.size(); Cnt + Size <= E;) {
- ArrayRef<Value *> Slice = ArrayRef(Operands).slice(Cnt, Size);
- if (!VectorizedStores.count(Slice.front()) &&
- !VectorizedStores.count(Slice.back()) &&
- vectorizeStoreChain(Slice, R, Cnt, MinVF)) {
- // Mark the vectorized stores so that we don't vectorize them again.
- VectorizedStores.insert(Slice.begin(), Slice.end());
- Changed = true;
- // If we vectorized initial block, no need to try to vectorize it
- // again.
- if (Cnt == StartIdx)
- StartIdx += Size;
- Cnt += Size;
- continue;
+ // FIXME: Is division-by-2 the correct step? Should we assert that the
+ // register size is a power-of-2?
+ unsigned StartIdx = 0;
+ for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2) {
+ for (unsigned Cnt = StartIdx, E = Operands.size(); Cnt + Size <= E;) {
+ ArrayRef<Value *> Slice = ArrayRef(Operands).slice(Cnt, Size);
+ assert(
+ all_of(
+ Slice,
+ [&](Value *V) {
+ return cast<StoreInst>(V)->getValueOperand()->getType() ==
+ cast<StoreInst>(Slice.front())
+ ->getValueOperand()
+ ->getType();
+ }) &&
+ "Expected all operands of same type.");
+ if (!VectorizedStores.count(Slice.front()) &&
+ !VectorizedStores.count(Slice.back()) &&
+ TriedSequences.insert(std::make_pair(Slice.front(), Slice.back()))
+ .second &&
+ vectorizeStoreChain(Slice, R, Cnt, MinVF)) {
+ // Mark the vectorized stores so that we don't vectorize them again.
+ VectorizedStores.insert(Slice.begin(), Slice.end());
+ Changed = true;
+ // If we vectorized initial block, no need to try to vectorize it
+ // again.
+ if (Cnt == StartIdx)
+ StartIdx += Size;
+ Cnt += Size;
+ continue;
+ }
+ ++Cnt;
}
- ++Cnt;
+ // Check if the whole array was vectorized already - exit.
+ if (StartIdx >= Operands.size())
+ break;
}
- // Check if the whole array was vectorized already - exit.
- if (StartIdx >= Operands.size())
- break;
+ Operands.clear();
+ Operands.push_back(Stores[Data.first]);
+ PrevDist = Data.second;
}
+ };
+
+ // Stores pair (first: index of the store into Stores array ref, address of
+ // which taken as base, second: sorted set of pairs {index, dist}, which are
+ // indices of stores in the set and their store location distances relative to
+ // the base address).
+
+ // Need to store the index of the very first store separately, since the set
+ // may be reordered after the insertion and the first store may be moved. This
+ // container allows to reduce number of calls of getPointersDiff() function.
+ SmallVector<std::pair<unsigned, StoreIndexToDistSet>> SortedStores;
+ // Inserts the specified store SI with the given index Idx to the set of the
+ // stores. If the store with the same distance is found already - stop
+ // insertion, try to vectorize already found stores. If some stores from this
+ // sequence were not vectorized - try to vectorize them with the new store
+ // later. But this logic is applied only to the stores, that come before the
+ // previous store with the same distance.
+ // Example:
+ // 1. store x, %p
+ // 2. store y, %p+1
+ // 3. store z, %p+2
+ // 4. store a, %p
+ // 5. store b, %p+3
+ // - Scan this from the last to first store. The very first bunch of stores is
+ // {5, {{4, -3}, {2, -2}, {3, -1}, {5, 0}}} (the element in SortedStores
+ // vector).
+ // - The next store in the list - #1 - has the same distance from store #5 as
+ // the store #4.
+ // - Try to vectorize sequence of stores 4,2,3,5.
+ // - If all these stores are vectorized - just drop them.
+ // - If some of them are not vectorized (say, #3 and #5), do extra analysis.
+ // - Start new stores sequence.
+ // The new bunch of stores is {1, {1, 0}}.
+ // - Add the stores from previous sequence, that were not vectorized.
+ // Here we consider the stores in the reversed order, rather they are used in
+ // the IR (Stores are reversed already, see vectorizeStoreChains() function).
+ // Store #3 can be added -> comes after store #4 with the same distance as
+ // store #1.
+ // Store #5 cannot be added - comes before store #4.
+ // This logic allows to improve the compile time, we assume that the stores
+ // after previous store with the same distance most likely have memory
+ // dependencies and no need to waste compile time to try to vectorize them.
+ // - Try to vectorize the sequence {1, {1, 0}, {3, 2}}.
+ auto FillStoresSet = [&](unsigned Idx, StoreInst *SI) {
+ for (std::pair<unsigned, StoreIndexToDistSet> &Set : SortedStores) {
+ std::optional<int> Diff = getPointersDiff(
+ Stores[Set.first]->getValueOperand()->getType(),
+ Stores[Set.first]->getPointerOperand(),
+ SI->getValueOperand()->getType(), SI->getPointerOperand(), *DL, *SE,
+ /*StrictCheck=*/true);
+ if (!Diff)
+ continue;
+ auto It = Set.second.find(std::make_pair(Idx, *Diff));
+ if (It == Set.second.end()) {
+ Set.second.emplace(Idx, *Diff);
+ return;
+ }
+ // Try to vectorize the first found set to avoid duplicate analysis.
+ TryToVectorize(Set.second);
+ StoreIndexToDistSet PrevSet;
+ PrevSet.swap(Set.second);
+ Set.first = Idx;
+ Set.second.emplace(Idx, 0);
+ // Insert stores that followed previous match to try to vectorize them
+ // with this store.
+ unsigned StartIdx = It->first + 1;
+ SmallBitVector UsedStores(Idx - StartIdx);
+ // Distances to previously found dup store (or this store, since they
+ // store to the same addresses).
+ SmallVector<int> Dists(Idx - StartIdx, 0);
+ for (const std::pair<unsigned, int> &Pair : reverse(PrevSet)) {
+ // Do not try to vectorize sequences, we already tried.
+ if (Pair.first <= It->first ||
+ VectorizedStores.contains(Stores[Pair.first]))
+ break;
+ unsigned BI = Pair.first - StartIdx;
+ UsedStores.set(BI);
+ Dists[BI] = Pair.second - It->second;
+ }
+ for (unsigned I = StartIdx; I < Idx; ++I) {
+ unsigned BI = I - StartIdx;
+ if (UsedStores.test(BI))
+ Set.second.emplace(I, Dists[BI]);
+ }
+ return;
+ }
+ auto &Res = SortedStores.emplace_back();
+ Res.first = Idx;
+ Res.second.emplace(Idx, 0);
+ };
+ StoreInst *PrevStore = Stores.front();
+ for (auto [I, SI] : enumerate(Stores)) {
+ // Check that we do not try to vectorize stores of different types.
+ if (PrevStore->getValueOperand()->getType() !=
+ SI->getValueOperand()->getType()) {
+ for (auto &Set : SortedStores)
+ TryToVectorize(Set.second);
+ SortedStores.clear();
+ PrevStore = SI;
+ }
+ FillStoresSet(I, SI);
}
+ // Final vectorization attempt.
+ for (auto &Set : SortedStores)
+ TryToVectorize(Set.second);
+
return Changed;
}
@@ -12506,7 +13680,7 @@ void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) {
// constant index, or a pointer operand that doesn't point to a scalar
// type.
else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
- auto Idx = GEP->idx_begin()->get();
+ Value *Idx = GEP->idx_begin()->get();
if (GEP->getNumIndices() > 1 || isa<Constant>(Idx))
continue;
if (!isValidElementType(Idx->getType()))
@@ -12541,8 +13715,8 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
// NOTE: the following will give user internal llvm type name, which may
// not be useful.
R.getORE()->emit([&]() {
- std::string type_str;
- llvm::raw_string_ostream rso(type_str);
+ std::string TypeStr;
+ llvm::raw_string_ostream rso(TypeStr);
Ty->print(rso);
return OptimizationRemarkMissed(SV_NAME, "UnsupportedType", I0)
<< "Cannot SLP vectorize list: type "
@@ -12877,10 +14051,12 @@ class HorizontalReduction {
static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS,
Value *RHS, const Twine &Name,
const ReductionOpsListType &ReductionOps) {
- bool UseSelect = ReductionOps.size() == 2 ||
- // Logical or/and.
- (ReductionOps.size() == 1 &&
- isa<SelectInst>(ReductionOps.front().front()));
+ bool UseSelect =
+ ReductionOps.size() == 2 ||
+ // Logical or/and.
+ (ReductionOps.size() == 1 && any_of(ReductionOps.front(), [](Value *V) {
+ return isa<SelectInst>(V);
+ }));
assert((!UseSelect || ReductionOps.size() != 2 ||
isa<SelectInst>(ReductionOps[1][0])) &&
"Expected cmp + select pairs for reduction");
@@ -13314,12 +14490,26 @@ public:
// Update the final value in the reduction.
Builder.SetCurrentDebugLocation(
cast<Instruction>(ReductionOps.front().front())->getDebugLoc());
+ if ((isa<PoisonValue>(VectorizedTree) && !isa<PoisonValue>(Res)) ||
+ (isGuaranteedNotToBePoison(Res) &&
+ !isGuaranteedNotToBePoison(VectorizedTree))) {
+ auto It = ReducedValsToOps.find(Res);
+ if (It != ReducedValsToOps.end() &&
+ any_of(It->getSecond(),
+ [](Instruction *I) { return isBoolLogicOp(I); }))
+ std::swap(VectorizedTree, Res);
+ }
+
return createOp(Builder, RdxKind, VectorizedTree, Res, "op.rdx",
ReductionOps);
}
// Initialize the final value in the reduction.
return Res;
};
+ bool AnyBoolLogicOp =
+ any_of(ReductionOps.back(), [](Value *V) {
+ return isBoolLogicOp(cast<Instruction>(V));
+ });
// The reduction root is used as the insertion point for new instructions,
// so set it as externally used to prevent it from being deleted.
ExternallyUsedValues[ReductionRoot];
@@ -13363,10 +14553,12 @@ public:
// Check if the reduction value was not overriden by the extractelement
// instruction because of the vectorization and exclude it, if it is not
// compatible with other values.
- if (auto *Inst = dyn_cast<Instruction>(RdxVal))
- if (isVectorLikeInstWithConstOps(Inst) &&
- (!S.getOpcode() || !S.isOpcodeOrAlt(Inst)))
- continue;
+ // Also check if the instruction was folded to constant/other value.
+ auto *Inst = dyn_cast<Instruction>(RdxVal);
+ if ((Inst && isVectorLikeInstWithConstOps(Inst) &&
+ (!S.getOpcode() || !S.isOpcodeOrAlt(Inst))) ||
+ (S.getOpcode() && !Inst))
+ continue;
Candidates.push_back(RdxVal);
TrackedToOrig.try_emplace(RdxVal, OrigReducedVals[Cnt]);
}
@@ -13542,11 +14734,9 @@ public:
for (unsigned Cnt = 0, Sz = ReducedVals.size(); Cnt < Sz; ++Cnt) {
if (Cnt == I || (ShuffledExtracts && Cnt == I - 1))
continue;
- for_each(ReducedVals[Cnt],
- [&LocalExternallyUsedValues, &TrackedVals](Value *V) {
- if (isa<Instruction>(V))
- LocalExternallyUsedValues[TrackedVals[V]];
- });
+ for (Value *V : ReducedVals[Cnt])
+ if (isa<Instruction>(V))
+ LocalExternallyUsedValues[TrackedVals[V]];
}
if (!IsSupportedHorRdxIdentityOp) {
// Number of uses of the candidates in the vector of values.
@@ -13590,7 +14780,7 @@ public:
// Update LocalExternallyUsedValues for the scalar, replaced by
// extractelement instructions.
for (const std::pair<Value *, Value *> &Pair : ReplacedExternals) {
- auto It = ExternallyUsedValues.find(Pair.first);
+ auto *It = ExternallyUsedValues.find(Pair.first);
if (It == ExternallyUsedValues.end())
continue;
LocalExternallyUsedValues[Pair.second].append(It->second);
@@ -13604,7 +14794,8 @@ public:
InstructionCost ReductionCost =
getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
InstructionCost Cost = TreeCost + ReductionCost;
- LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n");
+ LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
+ << " for reduction\n");
if (!Cost.isValid())
return nullptr;
if (Cost >= -SLPCostThreshold) {
@@ -13651,7 +14842,9 @@ public:
// To prevent poison from leaking across what used to be sequential,
// safe, scalar boolean logic operations, the reduction operand must be
// frozen.
- if (isBoolLogicOp(RdxRootInst))
+ if ((isBoolLogicOp(RdxRootInst) ||
+ (AnyBoolLogicOp && VL.size() != TrackedVals.size())) &&
+ !isGuaranteedNotToBePoison(VectorizedRoot))
VectorizedRoot = Builder.CreateFreeze(VectorizedRoot);
// Emit code to correctly handle reused reduced values, if required.
@@ -13663,6 +14856,16 @@ public:
Value *ReducedSubTree =
emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
+ if (ReducedSubTree->getType() != VL.front()->getType()) {
+ ReducedSubTree = Builder.CreateIntCast(
+ ReducedSubTree, VL.front()->getType(), any_of(VL, [&](Value *R) {
+ KnownBits Known = computeKnownBits(
+ R, cast<Instruction>(ReductionOps.front().front())
+ ->getModule()
+ ->getDataLayout());
+ return !Known.isNonNegative();
+ }));
+ }
// Improved analysis for add/fadd/xor reductions with same scale factor
// for all operands of reductions. We can emit scalar ops for them
@@ -13715,31 +14918,33 @@ public:
// RedOp2 = select i1 ?, i1 RHS, i1 false
// Then, we must freeze LHS in the new op.
- auto &&FixBoolLogicalOps =
- [&Builder, VectorizedTree](Value *&LHS, Value *&RHS,
- Instruction *RedOp1, Instruction *RedOp2) {
- if (!isBoolLogicOp(RedOp1))
- return;
- if (LHS == VectorizedTree || getRdxOperand(RedOp1, 0) == LHS ||
- isGuaranteedNotToBePoison(LHS))
- return;
- if (!isBoolLogicOp(RedOp2))
- return;
- if (RHS == VectorizedTree || getRdxOperand(RedOp2, 0) == RHS ||
- isGuaranteedNotToBePoison(RHS)) {
- std::swap(LHS, RHS);
- return;
- }
- LHS = Builder.CreateFreeze(LHS);
- };
+ auto FixBoolLogicalOps = [&, VectorizedTree](Value *&LHS, Value *&RHS,
+ Instruction *RedOp1,
+ Instruction *RedOp2,
+ bool InitStep) {
+ if (!AnyBoolLogicOp)
+ return;
+ if (isBoolLogicOp(RedOp1) &&
+ ((!InitStep && LHS == VectorizedTree) ||
+ getRdxOperand(RedOp1, 0) == LHS || isGuaranteedNotToBePoison(LHS)))
+ return;
+ if (isBoolLogicOp(RedOp2) && ((!InitStep && RHS == VectorizedTree) ||
+ getRdxOperand(RedOp2, 0) == RHS ||
+ isGuaranteedNotToBePoison(RHS))) {
+ std::swap(LHS, RHS);
+ return;
+ }
+ if (LHS != VectorizedTree)
+ LHS = Builder.CreateFreeze(LHS);
+ };
// Finish the reduction.
// Need to add extra arguments and not vectorized possible reduction
// values.
// Try to avoid dependencies between the scalar remainders after
// reductions.
- auto &&FinalGen =
- [this, &Builder, &TrackedVals, &FixBoolLogicalOps](
- ArrayRef<std::pair<Instruction *, Value *>> InstVals) {
+ auto FinalGen =
+ [&](ArrayRef<std::pair<Instruction *, Value *>> InstVals,
+ bool InitStep) {
unsigned Sz = InstVals.size();
SmallVector<std::pair<Instruction *, Value *>> ExtraReds(Sz / 2 +
Sz % 2);
@@ -13760,7 +14965,7 @@ public:
// sequential, safe, scalar boolean logic operations, the
// reduction operand must be frozen.
FixBoolLogicalOps(StableRdxVal1, StableRdxVal2, InstVals[I].first,
- RedOp);
+ RedOp, InitStep);
Value *ExtraRed = createOp(Builder, RdxKind, StableRdxVal1,
StableRdxVal2, "op.rdx", ReductionOps);
ExtraReds[I / 2] = std::make_pair(InstVals[I].first, ExtraRed);
@@ -13790,11 +14995,13 @@ public:
ExtraReductions.emplace_back(I, Pair.first);
}
// Iterate through all not-vectorized reduction values/extra arguments.
+ bool InitStep = true;
while (ExtraReductions.size() > 1) {
VectorizedTree = ExtraReductions.front().second;
SmallVector<std::pair<Instruction *, Value *>> NewReds =
- FinalGen(ExtraReductions);
+ FinalGen(ExtraReductions, InitStep);
ExtraReductions.swap(NewReds);
+ InitStep = false;
}
VectorizedTree = ExtraReductions.front().second;
@@ -13841,8 +15048,7 @@ private:
bool IsCmpSelMinMax, unsigned ReduxWidth,
FastMathFlags FMF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
- Value *FirstReducedVal = ReducedVals.front();
- Type *ScalarTy = FirstReducedVal->getType();
+ Type *ScalarTy = ReducedVals.front()->getType();
FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
InstructionCost VectorCost = 0, ScalarCost;
// If all of the reduced values are constant, the vector cost is 0, since
@@ -13916,7 +15122,7 @@ private:
}
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
- << " for reduction that starts with " << *FirstReducedVal
+ << " for reduction of " << shortBundleName(ReducedVals)
<< " (It is a splitting reduction)\n");
return VectorCost - ScalarCost;
}
@@ -13931,7 +15137,7 @@ private:
"A call to the llvm.fmuladd intrinsic is not handled yet");
++NumVectorInstructions;
- return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind);
+ return createSimpleTargetReduction(Builder, VectorizedValue, RdxKind);
}
/// Emits optimized code for unique scalar value reused \p Cnt times.
@@ -13978,8 +15184,8 @@ private:
case RecurKind::Mul:
case RecurKind::FMul:
case RecurKind::FMulAdd:
- case RecurKind::SelectICmp:
- case RecurKind::SelectFCmp:
+ case RecurKind::IAnyOf:
+ case RecurKind::FAnyOf:
case RecurKind::None:
llvm_unreachable("Unexpected reduction kind for repeated scalar.");
}
@@ -14067,8 +15273,8 @@ private:
case RecurKind::Mul:
case RecurKind::FMul:
case RecurKind::FMulAdd:
- case RecurKind::SelectICmp:
- case RecurKind::SelectFCmp:
+ case RecurKind::IAnyOf:
+ case RecurKind::FAnyOf:
case RecurKind::None:
llvm_unreachable("Unexpected reduction kind for reused scalars.");
}
@@ -14163,8 +15369,8 @@ static bool findBuildAggregate(Instruction *LastInsertInst,
InsertElts.resize(*AggregateSize);
findBuildAggregate_rec(LastInsertInst, TTI, BuildVectorOpds, InsertElts, 0);
- llvm::erase_value(BuildVectorOpds, nullptr);
- llvm::erase_value(InsertElts, nullptr);
+ llvm::erase(BuildVectorOpds, nullptr);
+ llvm::erase(InsertElts, nullptr);
if (BuildVectorOpds.size() >= 2)
return true;
@@ -14400,8 +15606,7 @@ bool SLPVectorizerPass::tryToVectorize(ArrayRef<WeakTrackingVH> Insts,
bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI,
BasicBlock *BB, BoUpSLP &R) {
- const DataLayout &DL = BB->getModule()->getDataLayout();
- if (!R.canMapToVector(IVI->getType(), DL))
+ if (!R.canMapToVector(IVI->getType()))
return false;
SmallVector<Value *, 16> BuildVectorOpds;
@@ -14540,11 +15745,11 @@ static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI,
if (BasePred1 > BasePred2)
return false;
// Compare operands.
- bool LEPreds = Pred1 <= Pred2;
- bool GEPreds = Pred1 >= Pred2;
+ bool CI1Preds = Pred1 == BasePred1;
+ bool CI2Preds = Pred2 == BasePred1;
for (int I = 0, E = CI1->getNumOperands(); I < E; ++I) {
- auto *Op1 = CI1->getOperand(LEPreds ? I : E - I - 1);
- auto *Op2 = CI2->getOperand(GEPreds ? I : E - I - 1);
+ auto *Op1 = CI1->getOperand(CI1Preds ? I : E - I - 1);
+ auto *Op2 = CI2->getOperand(CI2Preds ? I : E - I - 1);
if (Op1->getValueID() < Op2->getValueID())
return !IsCompatibility;
if (Op1->getValueID() > Op2->getValueID())
@@ -14690,14 +15895,20 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
return true;
if (Opcodes1.size() > Opcodes2.size())
return false;
- std::optional<bool> ConstOrder;
for (int I = 0, E = Opcodes1.size(); I < E; ++I) {
// Undefs are compatible with any other value.
if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) {
- if (!ConstOrder)
- ConstOrder =
- !isa<UndefValue>(Opcodes1[I]) && isa<UndefValue>(Opcodes2[I]);
- continue;
+ if (isa<Instruction>(Opcodes1[I]))
+ return true;
+ if (isa<Instruction>(Opcodes2[I]))
+ return false;
+ if (isa<Constant>(Opcodes1[I]) && !isa<UndefValue>(Opcodes1[I]))
+ return true;
+ if (isa<Constant>(Opcodes2[I]) && !isa<UndefValue>(Opcodes2[I]))
+ return false;
+ if (isa<UndefValue>(Opcodes1[I]) && isa<UndefValue>(Opcodes2[I]))
+ continue;
+ return isa<UndefValue>(Opcodes2[I]);
}
if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I]))
if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) {
@@ -14713,21 +15924,26 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
if (NodeI1 != NodeI2)
return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn();
InstructionsState S = getSameOpcode({I1, I2}, *TLI);
- if (S.getOpcode())
+ if (S.getOpcode() && !S.isAltShuffle())
continue;
return I1->getOpcode() < I2->getOpcode();
}
- if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) {
- if (!ConstOrder)
- ConstOrder = Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID();
- continue;
- }
+ if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I]))
+ return Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID();
+ if (isa<Instruction>(Opcodes1[I]))
+ return true;
+ if (isa<Instruction>(Opcodes2[I]))
+ return false;
+ if (isa<Constant>(Opcodes1[I]))
+ return true;
+ if (isa<Constant>(Opcodes2[I]))
+ return false;
if (Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID())
return true;
if (Opcodes1[I]->getValueID() > Opcodes2[I]->getValueID())
return false;
}
- return ConstOrder && *ConstOrder;
+ return false;
};
auto AreCompatiblePHIs = [&PHIToOpcodes, this](Value *V1, Value *V2) {
if (V1 == V2)
@@ -14775,6 +15991,9 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
Incoming.push_back(P);
}
+ if (Incoming.size() <= 1)
+ break;
+
// Find the corresponding non-phi nodes for better matching when trying to
// build the tree.
for (Value *V : Incoming) {
@@ -14837,41 +16056,41 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
return I->use_empty() &&
(I->getType()->isVoidTy() || isa<CallInst, InvokeInst>(I));
};
- for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
+ for (BasicBlock::iterator It = BB->begin(), E = BB->end(); It != E; ++It) {
// Skip instructions with scalable type. The num of elements is unknown at
// compile-time for scalable type.
- if (isa<ScalableVectorType>(it->getType()))
+ if (isa<ScalableVectorType>(It->getType()))
continue;
// Skip instructions marked for the deletion.
- if (R.isDeleted(&*it))
+ if (R.isDeleted(&*It))
continue;
// We may go through BB multiple times so skip the one we have checked.
- if (!VisitedInstrs.insert(&*it).second) {
- if (HasNoUsers(&*it) &&
- VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator())) {
+ if (!VisitedInstrs.insert(&*It).second) {
+ if (HasNoUsers(&*It) &&
+ VectorizeInsertsAndCmps(/*VectorizeCmps=*/It->isTerminator())) {
// We would like to start over since some instructions are deleted
// and the iterator may become invalid value.
Changed = true;
- it = BB->begin();
- e = BB->end();
+ It = BB->begin();
+ E = BB->end();
}
continue;
}
- if (isa<DbgInfoIntrinsic>(it))
+ if (isa<DbgInfoIntrinsic>(It))
continue;
// Try to vectorize reductions that use PHINodes.
- if (PHINode *P = dyn_cast<PHINode>(it)) {
+ if (PHINode *P = dyn_cast<PHINode>(It)) {
// Check that the PHI is a reduction PHI.
if (P->getNumIncomingValues() == 2) {
// Try to match and vectorize a horizontal reduction.
Instruction *Root = getReductionInstr(DT, P, BB, LI);
if (Root && vectorizeRootInstruction(P, Root, BB, R, TTI)) {
Changed = true;
- it = BB->begin();
- e = BB->end();
+ It = BB->begin();
+ E = BB->end();
continue;
}
}
@@ -14896,23 +16115,23 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
continue;
}
- if (HasNoUsers(&*it)) {
+ if (HasNoUsers(&*It)) {
bool OpsChanged = false;
- auto *SI = dyn_cast<StoreInst>(it);
+ auto *SI = dyn_cast<StoreInst>(It);
bool TryToVectorizeRoot = ShouldStartVectorizeHorAtStore || !SI;
if (SI) {
- auto I = Stores.find(getUnderlyingObject(SI->getPointerOperand()));
+ auto *I = Stores.find(getUnderlyingObject(SI->getPointerOperand()));
// Try to vectorize chain in store, if this is the only store to the
// address in the block.
// TODO: This is just a temporarily solution to save compile time. Need
// to investigate if we can safely turn on slp-vectorize-hor-store
// instead to allow lookup for reduction chains in all non-vectorized
// stores (need to check side effects and compile time).
- TryToVectorizeRoot = (I == Stores.end() || I->second.size() == 1) &&
- SI->getValueOperand()->hasOneUse();
+ TryToVectorizeRoot |= (I == Stores.end() || I->second.size() == 1) &&
+ SI->getValueOperand()->hasOneUse();
}
if (TryToVectorizeRoot) {
- for (auto *V : it->operand_values()) {
+ for (auto *V : It->operand_values()) {
// Postponed instructions should not be vectorized here, delay their
// vectorization.
if (auto *VI = dyn_cast<Instruction>(V);
@@ -14925,21 +16144,21 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
// top-tree instructions to try to vectorize as many instructions as
// possible.
OpsChanged |=
- VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator());
+ VectorizeInsertsAndCmps(/*VectorizeCmps=*/It->isTerminator());
if (OpsChanged) {
// We would like to start over since some instructions are deleted
// and the iterator may become invalid value.
Changed = true;
- it = BB->begin();
- e = BB->end();
+ It = BB->begin();
+ E = BB->end();
continue;
}
}
- if (isa<InsertElementInst, InsertValueInst>(it))
- PostProcessInserts.insert(&*it);
- else if (isa<CmpInst>(it))
- PostProcessCmps.insert(cast<CmpInst>(&*it));
+ if (isa<InsertElementInst, InsertValueInst>(It))
+ PostProcessInserts.insert(&*It);
+ else if (isa<CmpInst>(It))
+ PostProcessCmps.insert(cast<CmpInst>(&*It));
}
return Changed;
@@ -15043,6 +16262,12 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) {
// compatible (have the same opcode, same parent), otherwise it is
// definitely not profitable to try to vectorize them.
auto &&StoreSorter = [this](StoreInst *V, StoreInst *V2) {
+ if (V->getValueOperand()->getType()->getTypeID() <
+ V2->getValueOperand()->getType()->getTypeID())
+ return true;
+ if (V->getValueOperand()->getType()->getTypeID() >
+ V2->getValueOperand()->getType()->getTypeID())
+ return false;
if (V->getPointerOperandType()->getTypeID() <
V2->getPointerOperandType()->getTypeID())
return true;
@@ -15081,6 +16306,8 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) {
auto &&AreCompatibleStores = [this](StoreInst *V1, StoreInst *V2) {
if (V1 == V2)
return true;
+ if (V1->getValueOperand()->getType() != V2->getValueOperand()->getType())
+ return false;
if (V1->getPointerOperandType() != V2->getPointerOperandType())
return false;
// Undefs are compatible with any other value.
@@ -15112,8 +16339,13 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) {
if (!isValidElementType(Pair.second.front()->getValueOperand()->getType()))
continue;
+ // Reverse stores to do bottom-to-top analysis. This is important if the
+ // values are stores to the same addresses several times, in this case need
+ // to follow the stores order (reversed to meet the memory dependecies).
+ SmallVector<StoreInst *> ReversedStores(Pair.second.rbegin(),
+ Pair.second.rend());
Changed |= tryToVectorizeSequence<StoreInst>(
- Pair.second, StoreSorter, AreCompatibleStores,
+ ReversedStores, StoreSorter, AreCompatibleStores,
[this, &R](ArrayRef<StoreInst *> Candidates, bool) {
return vectorizeStores(Candidates, R);
},
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 1271d1424c03..7ff6749a0908 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -133,9 +133,12 @@ public:
Ingredient2Recipe[I] = R;
}
+ /// Create the mask for the vector loop header block.
+ void createHeaderMask(VPlan &Plan);
+
/// A helper function that computes the predicate of the block BB, assuming
- /// that the header block of the loop is set to True. It returns the *entry*
- /// mask for the block BB.
+ /// that the header block of the loop is set to True or the loop mask when
+ /// tail folding. It returns the *entry* mask for the block BB.
VPValue *createBlockInMask(BasicBlock *BB, VPlan &Plan);
/// A helper function that computes the predicate of the edge between SRC
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index e81b88fd8099..263d9938d1f0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -19,7 +19,6 @@
#include "VPlan.h"
#include "VPlanCFG.h"
#include "VPlanDominatorTree.h"
-#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -234,6 +233,99 @@ Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) {
// set(Def, Extract, Instance);
return Extract;
}
+
+Value *VPTransformState::get(VPValue *Def, unsigned Part) {
+ // If Values have been set for this Def return the one relevant for \p Part.
+ if (hasVectorValue(Def, Part))
+ return Data.PerPartOutput[Def][Part];
+
+ auto GetBroadcastInstrs = [this, Def](Value *V) {
+ bool SafeToHoist = Def->isDefinedOutsideVectorRegions();
+ if (VF.isScalar())
+ return V;
+ // Place the code for broadcasting invariant variables in the new preheader.
+ IRBuilder<>::InsertPointGuard Guard(Builder);
+ if (SafeToHoist) {
+ BasicBlock *LoopVectorPreHeader = CFG.VPBB2IRBB[cast<VPBasicBlock>(
+ Plan->getVectorLoopRegion()->getSinglePredecessor())];
+ if (LoopVectorPreHeader)
+ Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator());
+ }
+
+ // Place the code for broadcasting invariant variables in the new preheader.
+ // Broadcast the scalar into all locations in the vector.
+ Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast");
+
+ return Shuf;
+ };
+
+ if (!hasScalarValue(Def, {Part, 0})) {
+ assert(Def->isLiveIn() && "expected a live-in");
+ if (Part != 0)
+ return get(Def, 0);
+ Value *IRV = Def->getLiveInIRValue();
+ Value *B = GetBroadcastInstrs(IRV);
+ set(Def, B, Part);
+ return B;
+ }
+
+ Value *ScalarValue = get(Def, {Part, 0});
+ // If we aren't vectorizing, we can just copy the scalar map values over
+ // to the vector map.
+ if (VF.isScalar()) {
+ set(Def, ScalarValue, Part);
+ return ScalarValue;
+ }
+
+ bool IsUniform = vputils::isUniformAfterVectorization(Def);
+
+ unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1;
+ // Check if there is a scalar value for the selected lane.
+ if (!hasScalarValue(Def, {Part, LastLane})) {
+ // At the moment, VPWidenIntOrFpInductionRecipes, VPScalarIVStepsRecipes and
+ // VPExpandSCEVRecipes can also be uniform.
+ assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) ||
+ isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe()) ||
+ isa<VPExpandSCEVRecipe>(Def->getDefiningRecipe())) &&
+ "unexpected recipe found to be invariant");
+ IsUniform = true;
+ LastLane = 0;
+ }
+
+ auto *LastInst = cast<Instruction>(get(Def, {Part, LastLane}));
+ // Set the insert point after the last scalarized instruction or after the
+ // last PHI, if LastInst is a PHI. This ensures the insertelement sequence
+ // will directly follow the scalar definitions.
+ auto OldIP = Builder.saveIP();
+ auto NewIP =
+ isa<PHINode>(LastInst)
+ ? BasicBlock::iterator(LastInst->getParent()->getFirstNonPHI())
+ : std::next(BasicBlock::iterator(LastInst));
+ Builder.SetInsertPoint(&*NewIP);
+
+ // However, if we are vectorizing, we need to construct the vector values.
+ // If the value is known to be uniform after vectorization, we can just
+ // broadcast the scalar value corresponding to lane zero for each unroll
+ // iteration. Otherwise, we construct the vector values using
+ // insertelement instructions. Since the resulting vectors are stored in
+ // State, we will only generate the insertelements once.
+ Value *VectorValue = nullptr;
+ if (IsUniform) {
+ VectorValue = GetBroadcastInstrs(ScalarValue);
+ set(Def, VectorValue, Part);
+ } else {
+ // Initialize packing with insertelements to start from undef.
+ assert(!VF.isScalable() && "VF is assumed to be non scalable.");
+ Value *Undef = PoisonValue::get(VectorType::get(LastInst->getType(), VF));
+ set(Def, Undef, Part);
+ for (unsigned Lane = 0; Lane < VF.getKnownMinValue(); ++Lane)
+ packScalarIntoVectorValue(Def, {Part, Lane});
+ VectorValue = get(Def, Part);
+ }
+ Builder.restoreIP(OldIP);
+ return VectorValue;
+}
+
BasicBlock *VPTransformState::CFGState::getPreheaderBBFor(VPRecipeBase *R) {
VPRegionBlock *LoopRegion = R->getParent()->getEnclosingLoopRegion();
return VPBB2IRBB[LoopRegion->getPreheaderVPBB()];
@@ -267,18 +359,15 @@ void VPTransformState::addMetadata(ArrayRef<Value *> To, Instruction *From) {
}
}
-void VPTransformState::setDebugLocFromInst(const Value *V) {
- const Instruction *Inst = dyn_cast<Instruction>(V);
- if (!Inst) {
- Builder.SetCurrentDebugLocation(DebugLoc());
- return;
- }
-
- const DILocation *DIL = Inst->getDebugLoc();
+void VPTransformState::setDebugLocFrom(DebugLoc DL) {
+ const DILocation *DIL = DL;
// When a FSDiscriminator is enabled, we don't need to add the multiply
// factors to the discriminators.
- if (DIL && Inst->getFunction()->shouldEmitDebugInfoForProfiling() &&
- !Inst->isDebugOrPseudoInst() && !EnableFSDiscriminator) {
+ if (DIL &&
+ Builder.GetInsertBlock()
+ ->getParent()
+ ->shouldEmitDebugInfoForProfiling() &&
+ !EnableFSDiscriminator) {
// FIXME: For scalable vectors, assume vscale=1.
auto NewDIL =
DIL->cloneByMultiplyingDuplicationFactor(UF * VF.getKnownMinValue());
@@ -291,6 +380,15 @@ void VPTransformState::setDebugLocFromInst(const Value *V) {
Builder.SetCurrentDebugLocation(DIL);
}
+void VPTransformState::packScalarIntoVectorValue(VPValue *Def,
+ const VPIteration &Instance) {
+ Value *ScalarInst = get(Def, Instance);
+ Value *VectorValue = get(Def, Instance.Part);
+ VectorValue = Builder.CreateInsertElement(
+ VectorValue, ScalarInst, Instance.Lane.getAsRuntimeExpr(Builder, VF));
+ set(Def, VectorValue, Instance.Part);
+}
+
BasicBlock *
VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) {
// BB stands for IR BasicBlocks. VPBB stands for VPlan VPBasicBlocks.
@@ -616,22 +714,17 @@ VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE) {
auto Plan = std::make_unique<VPlan>(Preheader, VecPreheader);
Plan->TripCount =
vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE);
+ // Create empty VPRegionBlock, to be filled during processing later.
+ auto *TopRegion = new VPRegionBlock("vector loop", false /*isReplicator*/);
+ VPBlockUtils::insertBlockAfter(TopRegion, VecPreheader);
+ VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
+ VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);
return Plan;
}
-VPActiveLaneMaskPHIRecipe *VPlan::getActiveLaneMaskPhi() {
- VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock();
- for (VPRecipeBase &R : Header->phis()) {
- if (isa<VPActiveLaneMaskPHIRecipe>(&R))
- return cast<VPActiveLaneMaskPHIRecipe>(&R);
- }
- return nullptr;
-}
-
void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
Value *CanonicalIVStartValue,
- VPTransformState &State,
- bool IsEpilogueVectorization) {
+ VPTransformState &State) {
// Check if the backedge taken count is needed, and if so build it.
if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) {
IRBuilder<> Builder(State.CFG.PrevBB->getTerminator());
@@ -648,6 +741,12 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part)
State.set(&VectorTripCount, VectorTripCountV, Part);
+ IRBuilder<> Builder(State.CFG.PrevBB->getTerminator());
+ // FIXME: Model VF * UF computation completely in VPlan.
+ State.set(&VFxUF,
+ createStepForVF(Builder, TripCountV->getType(), State.VF, State.UF),
+ 0);
+
// When vectorizing the epilogue loop, the canonical induction start value
// needs to be changed from zero to the value after the main vector loop.
// FIXME: Improve modeling for canonical IV start values in the epilogue loop.
@@ -656,16 +755,12 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
auto *IV = getCanonicalIV();
assert(all_of(IV->users(),
[](const VPUser *U) {
- if (isa<VPScalarIVStepsRecipe>(U) ||
- isa<VPDerivedIVRecipe>(U))
- return true;
- auto *VPI = cast<VPInstruction>(U);
- return VPI->getOpcode() ==
- VPInstruction::CanonicalIVIncrement ||
- VPI->getOpcode() ==
- VPInstruction::CanonicalIVIncrementNUW;
+ return isa<VPScalarIVStepsRecipe>(U) ||
+ isa<VPDerivedIVRecipe>(U) ||
+ cast<VPInstruction>(U)->getOpcode() ==
+ Instruction::Add;
}) &&
- "the canonical IV should only be used by its increments or "
+ "the canonical IV should only be used by its increment or "
"ScalarIVSteps when resetting the start value");
IV->setOperand(0, VPV);
}
@@ -754,11 +849,14 @@ void VPlan::execute(VPTransformState *State) {
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-LLVM_DUMP_METHOD
-void VPlan::print(raw_ostream &O) const {
+void VPlan::printLiveIns(raw_ostream &O) const {
VPSlotTracker SlotTracker(this);
- O << "VPlan '" << getName() << "' {";
+ if (VFxUF.getNumUsers() > 0) {
+ O << "\nLive-in ";
+ VFxUF.printAsOperand(O, SlotTracker);
+ O << " = VF * UF";
+ }
if (VectorTripCount.getNumUsers() > 0) {
O << "\nLive-in ";
@@ -778,6 +876,15 @@ void VPlan::print(raw_ostream &O) const {
TripCount->printAsOperand(O, SlotTracker);
O << " = original trip-count";
O << "\n";
+}
+
+LLVM_DUMP_METHOD
+void VPlan::print(raw_ostream &O) const {
+ VPSlotTracker SlotTracker(this);
+
+ O << "VPlan '" << getName() << "' {";
+
+ printLiveIns(O);
if (!getPreheader()->empty()) {
O << "\n";
@@ -895,11 +1002,18 @@ void VPlanPrinter::dump() {
OS << "graph [labelloc=t, fontsize=30; label=\"Vectorization Plan";
if (!Plan.getName().empty())
OS << "\\n" << DOT::EscapeString(Plan.getName());
- if (Plan.BackedgeTakenCount) {
- OS << ", where:\\n";
- Plan.BackedgeTakenCount->print(OS, SlotTracker);
- OS << " := BackedgeTakenCount";
+
+ {
+ // Print live-ins.
+ std::string Str;
+ raw_string_ostream SS(Str);
+ Plan.printLiveIns(SS);
+ SmallVector<StringRef, 0> Lines;
+ StringRef(Str).rtrim('\n').split(Lines, "\n");
+ for (auto Line : Lines)
+ OS << DOT::EscapeString(Line.str()) << "\\n";
}
+
OS << "\"]\n";
OS << "node [shape=rect, fontname=Courier, fontsize=30]\n";
OS << "edge [fontname=Courier, fontsize=30]\n";
@@ -1035,6 +1149,26 @@ void VPValue::replaceAllUsesWith(VPValue *New) {
}
}
+void VPValue::replaceUsesWithIf(
+ VPValue *New,
+ llvm::function_ref<bool(VPUser &U, unsigned Idx)> ShouldReplace) {
+ for (unsigned J = 0; J < getNumUsers();) {
+ VPUser *User = Users[J];
+ unsigned NumUsers = getNumUsers();
+ for (unsigned I = 0, E = User->getNumOperands(); I < E; ++I) {
+ if (User->getOperand(I) != this || !ShouldReplace(*User, I))
+ continue;
+
+ User->setOperand(I, New);
+ }
+ // If a user got removed after updating the current user, the next user to
+ // update will be moved to the current position, so we only need to
+ // increment the index if the number of users did not change.
+ if (NumUsers == getNumUsers())
+ J++;
+ }
+}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPValue::printAsOperand(raw_ostream &OS, VPSlotTracker &Tracker) const {
if (const Value *UV = getUnderlyingValue()) {
@@ -1116,6 +1250,8 @@ void VPSlotTracker::assignSlot(const VPValue *V) {
}
void VPSlotTracker::assignSlots(const VPlan &Plan) {
+ if (Plan.VFxUF.getNumUsers() > 0)
+ assignSlot(&Plan.VFxUF);
assignSlot(&Plan.VectorTripCount);
if (Plan.BackedgeTakenCount)
assignSlot(Plan.BackedgeTakenCount);
@@ -1139,6 +1275,11 @@ bool vputils::onlyFirstLaneUsed(VPValue *Def) {
[Def](VPUser *U) { return U->onlyFirstLaneUsed(Def); });
}
+bool vputils::onlyFirstPartUsed(VPValue *Def) {
+ return all_of(Def->users(),
+ [Def](VPUser *U) { return U->onlyFirstPartUsed(Def); });
+}
+
VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr,
ScalarEvolution &SE) {
if (auto *Expanded = Plan.getSCEVExpansion(Expr))
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 73313465adea..94cb76889813 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -23,6 +23,7 @@
#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
+#include "VPlanAnalysis.h"
#include "VPlanValue.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
@@ -233,9 +234,9 @@ struct VPIteration {
struct VPTransformState {
VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
DominatorTree *DT, IRBuilderBase &Builder,
- InnerLoopVectorizer *ILV, VPlan *Plan)
+ InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx)
: VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan),
- LVer(nullptr) {}
+ LVer(nullptr), TypeAnalysis(Ctx) {}
/// The chosen Vectorization and Unroll Factors of the loop being vectorized.
ElementCount VF;
@@ -274,10 +275,6 @@ struct VPTransformState {
I->second[Part];
}
- bool hasAnyVectorValue(VPValue *Def) const {
- return Data.PerPartOutput.contains(Def);
- }
-
bool hasScalarValue(VPValue *Def, VPIteration Instance) {
auto I = Data.PerPartScalars.find(Def);
if (I == Data.PerPartScalars.end())
@@ -349,8 +346,11 @@ struct VPTransformState {
/// vector of instructions.
void addMetadata(ArrayRef<Value *> To, Instruction *From);
- /// Set the debug location in the builder using the debug location in \p V.
- void setDebugLocFromInst(const Value *V);
+ /// Set the debug location in the builder using the debug location \p DL.
+ void setDebugLocFrom(DebugLoc DL);
+
+ /// Construct the vector value of a scalarized value \p V one lane at a time.
+ void packScalarIntoVectorValue(VPValue *Def, const VPIteration &Instance);
/// Hold state information used when constructing the CFG of the output IR,
/// traversing the VPBasicBlocks and generating corresponding IR BasicBlocks.
@@ -410,6 +410,9 @@ struct VPTransformState {
/// Map SCEVs to their expanded values. Populated when executing
/// VPExpandSCEVRecipes.
DenseMap<const SCEV *, Value *> ExpandedSCEVs;
+
+ /// VPlan-based type analysis.
+ VPTypeAnalysis TypeAnalysis;
};
/// VPBlockBase is the building block of the Hierarchical Control-Flow Graph.
@@ -582,6 +585,8 @@ public:
/// This VPBlockBase must have no successors.
void setOneSuccessor(VPBlockBase *Successor) {
assert(Successors.empty() && "Setting one successor when others exist.");
+ assert(Successor->getParent() == getParent() &&
+ "connected blocks must have the same parent");
appendSuccessor(Successor);
}
@@ -693,7 +698,7 @@ public:
};
/// VPRecipeBase is a base class modeling a sequence of one or more output IR
-/// instructions. VPRecipeBase owns the the VPValues it defines through VPDef
+/// instructions. VPRecipeBase owns the VPValues it defines through VPDef
/// and is responsible for deleting its defined values. Single-value
/// VPRecipeBases that also inherit from VPValue must make sure to inherit from
/// VPRecipeBase before VPValue.
@@ -706,13 +711,18 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
/// Each VPRecipe belongs to a single VPBasicBlock.
VPBasicBlock *Parent = nullptr;
+ /// The debug location for the recipe.
+ DebugLoc DL;
+
public:
- VPRecipeBase(const unsigned char SC, ArrayRef<VPValue *> Operands)
- : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe) {}
+ VPRecipeBase(const unsigned char SC, ArrayRef<VPValue *> Operands,
+ DebugLoc DL = {})
+ : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe), DL(DL) {}
template <typename IterT>
- VPRecipeBase(const unsigned char SC, iterator_range<IterT> Operands)
- : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe) {}
+ VPRecipeBase(const unsigned char SC, iterator_range<IterT> Operands,
+ DebugLoc DL = {})
+ : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe), DL(DL) {}
virtual ~VPRecipeBase() = default;
/// \return the VPBasicBlock which this VPRecipe belongs to.
@@ -789,6 +799,9 @@ public:
bool mayReadOrWriteMemory() const {
return mayReadFromMemory() || mayWriteToMemory();
}
+
+ /// Returns the debug location of the recipe.
+ DebugLoc getDebugLoc() const { return DL; }
};
// Helper macro to define common classof implementations for recipes.
@@ -808,153 +821,30 @@ public:
return R->getVPDefID() == VPDefID; \
}
-/// This is a concrete Recipe that models a single VPlan-level instruction.
-/// While as any Recipe it may generate a sequence of IR instructions when
-/// executed, these instructions would always form a single-def expression as
-/// the VPInstruction is also a single def-use vertex.
-class VPInstruction : public VPRecipeBase, public VPValue {
- friend class VPlanSlp;
-
-public:
- /// VPlan opcodes, extending LLVM IR with idiomatics instructions.
- enum {
- FirstOrderRecurrenceSplice =
- Instruction::OtherOpsEnd + 1, // Combines the incoming and previous
- // values of a first-order recurrence.
- Not,
- ICmpULE,
- SLPLoad,
- SLPStore,
- ActiveLaneMask,
- CalculateTripCountMinusVF,
- CanonicalIVIncrement,
- CanonicalIVIncrementNUW,
- // The next two are similar to the above, but instead increment the
- // canonical IV separately for each unrolled part.
- CanonicalIVIncrementForPart,
- CanonicalIVIncrementForPartNUW,
- BranchOnCount,
- BranchOnCond
- };
-
-private:
- typedef unsigned char OpcodeTy;
- OpcodeTy Opcode;
- FastMathFlags FMF;
- DebugLoc DL;
-
- /// An optional name that can be used for the generated IR instruction.
- const std::string Name;
-
- /// Utility method serving execute(): generates a single instance of the
- /// modeled instruction. \returns the generated value for \p Part.
- /// In some cases an existing value is returned rather than a generated
- /// one.
- Value *generateInstruction(VPTransformState &State, unsigned Part);
-
-protected:
- void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); }
-
-public:
- VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
- const Twine &Name = "")
- : VPRecipeBase(VPDef::VPInstructionSC, Operands), VPValue(this),
- Opcode(Opcode), DL(DL), Name(Name.str()) {}
-
- VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
- DebugLoc DL = {}, const Twine &Name = "")
- : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
-
- VP_CLASSOF_IMPL(VPDef::VPInstructionSC)
-
- VPInstruction *clone() const {
- SmallVector<VPValue *, 2> Operands(operands());
- return new VPInstruction(Opcode, Operands, DL, Name);
- }
-
- unsigned getOpcode() const { return Opcode; }
-
- /// Generate the instruction.
- /// TODO: We currently execute only per-part unless a specific instance is
- /// provided.
- void execute(VPTransformState &State) override;
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
- /// Print the VPInstruction to \p O.
- void print(raw_ostream &O, const Twine &Indent,
- VPSlotTracker &SlotTracker) const override;
-
- /// Print the VPInstruction to dbgs() (for debugging).
- LLVM_DUMP_METHOD void dump() const;
-#endif
-
- /// Return true if this instruction may modify memory.
- bool mayWriteToMemory() const {
- // TODO: we can use attributes of the called function to rule out memory
- // modifications.
- return Opcode == Instruction::Store || Opcode == Instruction::Call ||
- Opcode == Instruction::Invoke || Opcode == SLPStore;
- }
-
- bool hasResult() const {
- // CallInst may or may not have a result, depending on the called function.
- // Conservatively return calls have results for now.
- switch (getOpcode()) {
- case Instruction::Ret:
- case Instruction::Br:
- case Instruction::Store:
- case Instruction::Switch:
- case Instruction::IndirectBr:
- case Instruction::Resume:
- case Instruction::CatchRet:
- case Instruction::Unreachable:
- case Instruction::Fence:
- case Instruction::AtomicRMW:
- case VPInstruction::BranchOnCond:
- case VPInstruction::BranchOnCount:
- return false;
- default:
- return true;
- }
- }
-
- /// Set the fast-math flags.
- void setFastMathFlags(FastMathFlags FMFNew);
-
- /// Returns true if the recipe only uses the first lane of operand \p Op.
- bool onlyFirstLaneUsed(const VPValue *Op) const override {
- assert(is_contained(operands(), Op) &&
- "Op must be an operand of the recipe");
- if (getOperand(0) != Op)
- return false;
- switch (getOpcode()) {
- default:
- return false;
- case VPInstruction::ActiveLaneMask:
- case VPInstruction::CalculateTripCountMinusVF:
- case VPInstruction::CanonicalIVIncrement:
- case VPInstruction::CanonicalIVIncrementNUW:
- case VPInstruction::CanonicalIVIncrementForPart:
- case VPInstruction::CanonicalIVIncrementForPartNUW:
- case VPInstruction::BranchOnCount:
- return true;
- };
- llvm_unreachable("switch should return");
- }
-};
-
/// Class to record LLVM IR flag for a recipe along with it.
class VPRecipeWithIRFlags : public VPRecipeBase {
enum class OperationType : unsigned char {
+ Cmp,
OverflowingBinOp,
+ DisjointOp,
PossiblyExactOp,
GEPOp,
FPMathOp,
+ NonNegOp,
Other
};
+
+public:
struct WrapFlagsTy {
char HasNUW : 1;
char HasNSW : 1;
+
+ WrapFlagsTy(bool HasNUW, bool HasNSW) : HasNUW(HasNUW), HasNSW(HasNSW) {}
+ };
+
+private:
+ struct DisjointFlagsTy {
+ char IsDisjoint : 1;
};
struct ExactFlagsTy {
char IsExact : 1;
@@ -962,6 +852,9 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
struct GEPFlagsTy {
char IsInBounds : 1;
};
+ struct NonNegFlagsTy {
+ char NonNeg : 1;
+ };
struct FastMathFlagsTy {
char AllowReassoc : 1;
char NoNaNs : 1;
@@ -970,56 +863,81 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
char AllowReciprocal : 1;
char AllowContract : 1;
char ApproxFunc : 1;
+
+ FastMathFlagsTy(const FastMathFlags &FMF);
};
OperationType OpType;
union {
+ CmpInst::Predicate CmpPredicate;
WrapFlagsTy WrapFlags;
+ DisjointFlagsTy DisjointFlags;
ExactFlagsTy ExactFlags;
GEPFlagsTy GEPFlags;
+ NonNegFlagsTy NonNegFlags;
FastMathFlagsTy FMFs;
- unsigned char AllFlags;
+ unsigned AllFlags;
};
public:
template <typename IterT>
- VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands)
- : VPRecipeBase(SC, Operands) {
+ VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, DebugLoc DL = {})
+ : VPRecipeBase(SC, Operands, DL) {
OpType = OperationType::Other;
AllFlags = 0;
}
template <typename IterT>
- VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands,
- Instruction &I)
- : VPRecipeWithIRFlags(SC, Operands) {
- if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
+ VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, Instruction &I)
+ : VPRecipeWithIRFlags(SC, Operands, I.getDebugLoc()) {
+ if (auto *Op = dyn_cast<CmpInst>(&I)) {
+ OpType = OperationType::Cmp;
+ CmpPredicate = Op->getPredicate();
+ } else if (auto *Op = dyn_cast<PossiblyDisjointInst>(&I)) {
+ OpType = OperationType::DisjointOp;
+ DisjointFlags.IsDisjoint = Op->isDisjoint();
+ } else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
OpType = OperationType::OverflowingBinOp;
- WrapFlags.HasNUW = Op->hasNoUnsignedWrap();
- WrapFlags.HasNSW = Op->hasNoSignedWrap();
+ WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
} else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
OpType = OperationType::PossiblyExactOp;
ExactFlags.IsExact = Op->isExact();
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
OpType = OperationType::GEPOp;
GEPFlags.IsInBounds = GEP->isInBounds();
+ } else if (auto *PNNI = dyn_cast<PossiblyNonNegInst>(&I)) {
+ OpType = OperationType::NonNegOp;
+ NonNegFlags.NonNeg = PNNI->hasNonNeg();
} else if (auto *Op = dyn_cast<FPMathOperator>(&I)) {
OpType = OperationType::FPMathOp;
- FastMathFlags FMF = Op->getFastMathFlags();
- FMFs.AllowReassoc = FMF.allowReassoc();
- FMFs.NoNaNs = FMF.noNaNs();
- FMFs.NoInfs = FMF.noInfs();
- FMFs.NoSignedZeros = FMF.noSignedZeros();
- FMFs.AllowReciprocal = FMF.allowReciprocal();
- FMFs.AllowContract = FMF.allowContract();
- FMFs.ApproxFunc = FMF.approxFunc();
+ FMFs = Op->getFastMathFlags();
}
}
+ template <typename IterT>
+ VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+ CmpInst::Predicate Pred, DebugLoc DL = {})
+ : VPRecipeBase(SC, Operands, DL), OpType(OperationType::Cmp),
+ CmpPredicate(Pred) {}
+
+ template <typename IterT>
+ VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+ WrapFlagsTy WrapFlags, DebugLoc DL = {})
+ : VPRecipeBase(SC, Operands, DL), OpType(OperationType::OverflowingBinOp),
+ WrapFlags(WrapFlags) {}
+
+ template <typename IterT>
+ VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+ FastMathFlags FMFs, DebugLoc DL = {})
+ : VPRecipeBase(SC, Operands, DL), OpType(OperationType::FPMathOp),
+ FMFs(FMFs) {}
+
static inline bool classof(const VPRecipeBase *R) {
- return R->getVPDefID() == VPRecipeBase::VPWidenSC ||
+ return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
+ R->getVPDefID() == VPRecipeBase::VPWidenSC ||
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
+ R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
R->getVPDefID() == VPRecipeBase::VPReplicateSC;
}
@@ -1032,6 +950,9 @@ public:
WrapFlags.HasNUW = false;
WrapFlags.HasNSW = false;
break;
+ case OperationType::DisjointOp:
+ DisjointFlags.IsDisjoint = false;
+ break;
case OperationType::PossiblyExactOp:
ExactFlags.IsExact = false;
break;
@@ -1042,6 +963,10 @@ public:
FMFs.NoNaNs = false;
FMFs.NoInfs = false;
break;
+ case OperationType::NonNegOp:
+ NonNegFlags.NonNeg = false;
+ break;
+ case OperationType::Cmp:
case OperationType::Other:
break;
}
@@ -1054,6 +979,9 @@ public:
I->setHasNoUnsignedWrap(WrapFlags.HasNUW);
I->setHasNoSignedWrap(WrapFlags.HasNSW);
break;
+ case OperationType::DisjointOp:
+ cast<PossiblyDisjointInst>(I)->setIsDisjoint(DisjointFlags.IsDisjoint);
+ break;
case OperationType::PossiblyExactOp:
I->setIsExact(ExactFlags.IsExact);
break;
@@ -1069,43 +997,209 @@ public:
I->setHasAllowContract(FMFs.AllowContract);
I->setHasApproxFunc(FMFs.ApproxFunc);
break;
+ case OperationType::NonNegOp:
+ I->setNonNeg(NonNegFlags.NonNeg);
+ break;
+ case OperationType::Cmp:
case OperationType::Other:
break;
}
}
+ CmpInst::Predicate getPredicate() const {
+ assert(OpType == OperationType::Cmp &&
+ "recipe doesn't have a compare predicate");
+ return CmpPredicate;
+ }
+
bool isInBounds() const {
assert(OpType == OperationType::GEPOp &&
"recipe doesn't have inbounds flag");
return GEPFlags.IsInBounds;
}
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
- FastMathFlags getFastMathFlags() const {
- FastMathFlags Res;
- Res.setAllowReassoc(FMFs.AllowReassoc);
- Res.setNoNaNs(FMFs.NoNaNs);
- Res.setNoInfs(FMFs.NoInfs);
- Res.setNoSignedZeros(FMFs.NoSignedZeros);
- Res.setAllowReciprocal(FMFs.AllowReciprocal);
- Res.setAllowContract(FMFs.AllowContract);
- Res.setApproxFunc(FMFs.ApproxFunc);
- return Res;
+ /// Returns true if the recipe has fast-math flags.
+ bool hasFastMathFlags() const { return OpType == OperationType::FPMathOp; }
+
+ FastMathFlags getFastMathFlags() const;
+
+ bool hasNoUnsignedWrap() const {
+ assert(OpType == OperationType::OverflowingBinOp &&
+ "recipe doesn't have a NUW flag");
+ return WrapFlags.HasNUW;
}
+ bool hasNoSignedWrap() const {
+ assert(OpType == OperationType::OverflowingBinOp &&
+ "recipe doesn't have a NSW flag");
+ return WrapFlags.HasNSW;
+ }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void printFlags(raw_ostream &O) const;
#endif
};
+/// This is a concrete Recipe that models a single VPlan-level instruction.
+/// While as any Recipe it may generate a sequence of IR instructions when
+/// executed, these instructions would always form a single-def expression as
+/// the VPInstruction is also a single def-use vertex.
+class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
+ friend class VPlanSlp;
+
+public:
+ /// VPlan opcodes, extending LLVM IR with idiomatics instructions.
+ enum {
+ FirstOrderRecurrenceSplice =
+ Instruction::OtherOpsEnd + 1, // Combines the incoming and previous
+ // values of a first-order recurrence.
+ Not,
+ SLPLoad,
+ SLPStore,
+ ActiveLaneMask,
+ CalculateTripCountMinusVF,
+ // Increment the canonical IV separately for each unrolled part.
+ CanonicalIVIncrementForPart,
+ BranchOnCount,
+ BranchOnCond
+ };
+
+private:
+ typedef unsigned char OpcodeTy;
+ OpcodeTy Opcode;
+
+ /// An optional name that can be used for the generated IR instruction.
+ const std::string Name;
+
+ /// Utility method serving execute(): generates a single instance of the
+ /// modeled instruction. \returns the generated value for \p Part.
+ /// In some cases an existing value is returned rather than a generated
+ /// one.
+ Value *generateInstruction(VPTransformState &State, unsigned Part);
+
+#if !defined(NDEBUG)
+ /// Return true if the VPInstruction is a floating point math operation, i.e.
+ /// has fast-math flags.
+ bool isFPMathOp() const;
+#endif
+
+protected:
+ void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); }
+
+public:
+ VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
+ const Twine &Name = "")
+ : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
+ VPValue(this), Opcode(Opcode), Name(Name.str()) {}
+
+ VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
+ DebugLoc DL = {}, const Twine &Name = "")
+ : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
+
+ VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
+ VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
+
+ VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
+ WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
+ : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL),
+ VPValue(this), Opcode(Opcode), Name(Name.str()) {}
+
+ VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
+ FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = "");
+
+ VP_CLASSOF_IMPL(VPDef::VPInstructionSC)
+
+ unsigned getOpcode() const { return Opcode; }
+
+ /// Generate the instruction.
+ /// TODO: We currently execute only per-part unless a specific instance is
+ /// provided.
+ void execute(VPTransformState &State) override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// Print the VPInstruction to \p O.
+ void print(raw_ostream &O, const Twine &Indent,
+ VPSlotTracker &SlotTracker) const override;
+
+ /// Print the VPInstruction to dbgs() (for debugging).
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+
+ /// Return true if this instruction may modify memory.
+ bool mayWriteToMemory() const {
+ // TODO: we can use attributes of the called function to rule out memory
+ // modifications.
+ return Opcode == Instruction::Store || Opcode == Instruction::Call ||
+ Opcode == Instruction::Invoke || Opcode == SLPStore;
+ }
+
+ bool hasResult() const {
+ // CallInst may or may not have a result, depending on the called function.
+ // Conservatively return calls have results for now.
+ switch (getOpcode()) {
+ case Instruction::Ret:
+ case Instruction::Br:
+ case Instruction::Store:
+ case Instruction::Switch:
+ case Instruction::IndirectBr:
+ case Instruction::Resume:
+ case Instruction::CatchRet:
+ case Instruction::Unreachable:
+ case Instruction::Fence:
+ case Instruction::AtomicRMW:
+ case VPInstruction::BranchOnCond:
+ case VPInstruction::BranchOnCount:
+ return false;
+ default:
+ return true;
+ }
+ }
+
+ /// Returns true if the recipe only uses the first lane of operand \p Op.
+ bool onlyFirstLaneUsed(const VPValue *Op) const override {
+ assert(is_contained(operands(), Op) &&
+ "Op must be an operand of the recipe");
+ if (getOperand(0) != Op)
+ return false;
+ switch (getOpcode()) {
+ default:
+ return false;
+ case VPInstruction::ActiveLaneMask:
+ case VPInstruction::CalculateTripCountMinusVF:
+ case VPInstruction::CanonicalIVIncrementForPart:
+ case VPInstruction::BranchOnCount:
+ return true;
+ };
+ llvm_unreachable("switch should return");
+ }
+
+ /// Returns true if the recipe only uses the first part of operand \p Op.
+ bool onlyFirstPartUsed(const VPValue *Op) const override {
+ assert(is_contained(operands(), Op) &&
+ "Op must be an operand of the recipe");
+ if (getOperand(0) != Op)
+ return false;
+ switch (getOpcode()) {
+ default:
+ return false;
+ case VPInstruction::BranchOnCount:
+ return true;
+ };
+ llvm_unreachable("switch should return");
+ }
+};
+
/// VPWidenRecipe is a recipe for producing a copy of vector type its
/// ingredient. This recipe covers most of the traditional vectorization cases
/// where each ingredient transforms into a vectorized version of itself.
class VPWidenRecipe : public VPRecipeWithIRFlags, public VPValue {
+ unsigned Opcode;
public:
template <typename IterT>
VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
- : VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPValue(this, &I) {}
+ : VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPValue(this, &I),
+ Opcode(I.getOpcode()) {}
~VPWidenRecipe() override = default;
@@ -1114,6 +1208,8 @@ public:
/// Produce widened copies of all Ingredients.
void execute(VPTransformState &State) override;
+ unsigned getOpcode() const { return Opcode; }
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,
@@ -1122,7 +1218,7 @@ public:
};
/// VPWidenCastRecipe is a recipe to create vector cast instructions.
-class VPWidenCastRecipe : public VPRecipeBase, public VPValue {
+class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPValue {
/// Cast instruction opcode.
Instruction::CastOps Opcode;
@@ -1131,15 +1227,19 @@ class VPWidenCastRecipe : public VPRecipeBase, public VPValue {
public:
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
- CastInst *UI = nullptr)
- : VPRecipeBase(VPDef::VPWidenCastSC, Op), VPValue(this, UI),
+ CastInst &UI)
+ : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, UI), VPValue(this, &UI),
Opcode(Opcode), ResultTy(ResultTy) {
- assert((!UI || UI->getOpcode() == Opcode) &&
+ assert(UI.getOpcode() == Opcode &&
"opcode of underlying cast doesn't match");
- assert((!UI || UI->getType() == ResultTy) &&
+ assert(UI.getType() == ResultTy &&
"result type of underlying cast doesn't match");
}
+ VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
+ : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPValue(this, nullptr),
+ Opcode(Opcode), ResultTy(ResultTy) {}
+
~VPWidenCastRecipe() override = default;
VP_CLASSOF_IMPL(VPDef::VPWidenCastSC)
@@ -1196,7 +1296,8 @@ public:
struct VPWidenSelectRecipe : public VPRecipeBase, public VPValue {
template <typename IterT>
VPWidenSelectRecipe(SelectInst &I, iterator_range<IterT> Operands)
- : VPRecipeBase(VPDef::VPWidenSelectSC, Operands), VPValue(this, &I) {}
+ : VPRecipeBase(VPDef::VPWidenSelectSC, Operands, I.getDebugLoc()),
+ VPValue(this, &I) {}
~VPWidenSelectRecipe() override = default;
@@ -1282,8 +1383,8 @@ public:
class VPHeaderPHIRecipe : public VPRecipeBase, public VPValue {
protected:
VPHeaderPHIRecipe(unsigned char VPDefID, Instruction *UnderlyingInstr,
- VPValue *Start = nullptr)
- : VPRecipeBase(VPDefID, {}), VPValue(this, UnderlyingInstr) {
+ VPValue *Start = nullptr, DebugLoc DL = {})
+ : VPRecipeBase(VPDefID, {}, DL), VPValue(this, UnderlyingInstr) {
if (Start)
addOperand(Start);
}
@@ -1404,7 +1505,7 @@ public:
bool isCanonical() const;
/// Returns the scalar type of the induction.
- const Type *getScalarType() const {
+ Type *getScalarType() const {
return Trunc ? Trunc->getType() : IV->getType();
}
};
@@ -1565,14 +1666,13 @@ public:
/// A recipe for vectorizing a phi-node as a sequence of mask-based select
/// instructions.
class VPBlendRecipe : public VPRecipeBase, public VPValue {
- PHINode *Phi;
-
public:
/// The blend operation is a User of the incoming values and of their
/// respective masks, ordered [I0, M0, I1, M1, ...]. Note that a single value
/// might be incoming with a full mask for which there is no VPValue.
VPBlendRecipe(PHINode *Phi, ArrayRef<VPValue *> Operands)
- : VPRecipeBase(VPDef::VPBlendSC, Operands), VPValue(this, Phi), Phi(Phi) {
+ : VPRecipeBase(VPDef::VPBlendSC, Operands, Phi->getDebugLoc()),
+ VPValue(this, Phi) {
assert(Operands.size() > 0 &&
((Operands.size() == 1) || (Operands.size() % 2 == 0)) &&
"Expected either a single incoming value or a positive even number "
@@ -1701,16 +1801,13 @@ public:
/// The Operands are {ChainOp, VecOp, [Condition]}.
class VPReductionRecipe : public VPRecipeBase, public VPValue {
/// The recurrence decriptor for the reduction in question.
- const RecurrenceDescriptor *RdxDesc;
- /// Pointer to the TTI, needed to create the target reduction
- const TargetTransformInfo *TTI;
+ const RecurrenceDescriptor &RdxDesc;
public:
- VPReductionRecipe(const RecurrenceDescriptor *R, Instruction *I,
- VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
- const TargetTransformInfo *TTI)
+ VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp)
: VPRecipeBase(VPDef::VPReductionSC, {ChainOp, VecOp}), VPValue(this, I),
- RdxDesc(R), TTI(TTI) {
+ RdxDesc(R) {
if (CondOp)
addOperand(CondOp);
}
@@ -2008,11 +2105,9 @@ public:
/// loop). VPWidenCanonicalIVRecipe represents the vector version of the
/// canonical induction variable.
class VPCanonicalIVPHIRecipe : public VPHeaderPHIRecipe {
- DebugLoc DL;
-
public:
VPCanonicalIVPHIRecipe(VPValue *StartV, DebugLoc DL)
- : VPHeaderPHIRecipe(VPDef::VPCanonicalIVPHISC, nullptr, StartV), DL(DL) {}
+ : VPHeaderPHIRecipe(VPDef::VPCanonicalIVPHISC, nullptr, StartV, DL) {}
~VPCanonicalIVPHIRecipe() override = default;
@@ -2032,8 +2127,8 @@ public:
#endif
/// Returns the scalar type of the induction.
- const Type *getScalarType() const {
- return getOperand(0)->getLiveInIRValue()->getType();
+ Type *getScalarType() const {
+ return getStartValue()->getLiveInIRValue()->getType();
}
/// Returns true if the recipe only uses the first lane of operand \p Op.
@@ -2043,6 +2138,13 @@ public:
return true;
}
+ /// Returns true if the recipe only uses the first part of operand \p Op.
+ bool onlyFirstPartUsed(const VPValue *Op) const override {
+ assert(is_contained(operands(), Op) &&
+ "Op must be an operand of the recipe");
+ return true;
+ }
+
/// Check if the induction described by \p Kind, /p Start and \p Step is
/// canonical, i.e. has the same start, step (of 1), and type as the
/// canonical IV.
@@ -2055,12 +2157,10 @@ public:
/// TODO: It would be good to use the existing VPWidenPHIRecipe instead and
/// remove VPActiveLaneMaskPHIRecipe.
class VPActiveLaneMaskPHIRecipe : public VPHeaderPHIRecipe {
- DebugLoc DL;
-
public:
VPActiveLaneMaskPHIRecipe(VPValue *StartMask, DebugLoc DL)
- : VPHeaderPHIRecipe(VPDef::VPActiveLaneMaskPHISC, nullptr, StartMask),
- DL(DL) {}
+ : VPHeaderPHIRecipe(VPDef::VPActiveLaneMaskPHISC, nullptr, StartMask,
+ DL) {}
~VPActiveLaneMaskPHIRecipe() override = default;
@@ -2113,19 +2213,24 @@ public:
/// an IV with different start and step values, using Start + CanonicalIV *
/// Step.
class VPDerivedIVRecipe : public VPRecipeBase, public VPValue {
- /// The type of the result value. It may be smaller than the type of the
- /// induction and in this case it will get truncated to ResultTy.
- Type *ResultTy;
+ /// If not nullptr, the result of the induction will get truncated to
+ /// TruncResultTy.
+ Type *TruncResultTy;
- /// Induction descriptor for the induction the canonical IV is transformed to.
- const InductionDescriptor &IndDesc;
+ /// Kind of the induction.
+ const InductionDescriptor::InductionKind Kind;
+ /// If not nullptr, the floating point induction binary operator. Must be set
+ /// for floating point inductions.
+ const FPMathOperator *FPBinOp;
public:
VPDerivedIVRecipe(const InductionDescriptor &IndDesc, VPValue *Start,
VPCanonicalIVPHIRecipe *CanonicalIV, VPValue *Step,
- Type *ResultTy)
+ Type *TruncResultTy)
: VPRecipeBase(VPDef::VPDerivedIVSC, {Start, CanonicalIV, Step}),
- VPValue(this), ResultTy(ResultTy), IndDesc(IndDesc) {}
+ VPValue(this), TruncResultTy(TruncResultTy), Kind(IndDesc.getKind()),
+ FPBinOp(dyn_cast_or_null<FPMathOperator>(IndDesc.getInductionBinOp())) {
+ }
~VPDerivedIVRecipe() override = default;
@@ -2141,6 +2246,11 @@ public:
VPSlotTracker &SlotTracker) const override;
#endif
+ Type *getScalarType() const {
+ return TruncResultTy ? TruncResultTy
+ : getStartValue()->getLiveInIRValue()->getType();
+ }
+
VPValue *getStartValue() const { return getOperand(0); }
VPValue *getCanonicalIV() const { return getOperand(1); }
VPValue *getStepValue() const { return getOperand(2); }
@@ -2155,14 +2265,23 @@ public:
/// A recipe for handling phi nodes of integer and floating-point inductions,
/// producing their scalar values.
-class VPScalarIVStepsRecipe : public VPRecipeBase, public VPValue {
- const InductionDescriptor &IndDesc;
+class VPScalarIVStepsRecipe : public VPRecipeWithIRFlags, public VPValue {
+ Instruction::BinaryOps InductionOpcode;
public:
+ VPScalarIVStepsRecipe(VPValue *IV, VPValue *Step,
+ Instruction::BinaryOps Opcode, FastMathFlags FMFs)
+ : VPRecipeWithIRFlags(VPDef::VPScalarIVStepsSC,
+ ArrayRef<VPValue *>({IV, Step}), FMFs),
+ VPValue(this), InductionOpcode(Opcode) {}
+
VPScalarIVStepsRecipe(const InductionDescriptor &IndDesc, VPValue *IV,
VPValue *Step)
- : VPRecipeBase(VPDef::VPScalarIVStepsSC, {IV, Step}), VPValue(this),
- IndDesc(IndDesc) {}
+ : VPScalarIVStepsRecipe(
+ IV, Step, IndDesc.getInductionOpcode(),
+ dyn_cast_or_null<FPMathOperator>(IndDesc.getInductionBinOp())
+ ? IndDesc.getInductionBinOp()->getFastMathFlags()
+ : FastMathFlags()) {}
~VPScalarIVStepsRecipe() override = default;
@@ -2445,6 +2564,9 @@ class VPlan {
/// Represents the vector trip count.
VPValue VectorTripCount;
+ /// Represents the loop-invariant VF * UF of the vector loop region.
+ VPValue VFxUF;
+
/// Holds a mapping between Values and their corresponding VPValue inside
/// VPlan.
Value2VPValueTy Value2VPValue;
@@ -2490,15 +2612,17 @@ public:
~VPlan();
- /// Create an initial VPlan with preheader and entry blocks. Creates a
- /// VPExpandSCEVRecipe for \p TripCount and uses it as plan's trip count.
+ /// Create initial VPlan skeleton, having an "entry" VPBasicBlock (wrapping
+ /// original scalar pre-header) which contains SCEV expansions that need to
+ /// happen before the CFG is modified; a VPBasicBlock for the vector
+ /// pre-header, followed by a region for the vector loop, followed by the
+ /// middle VPBasicBlock.
static VPlanPtr createInitialVPlan(const SCEV *TripCount,
ScalarEvolution &PSE);
/// Prepare the plan for execution, setting up the required live-in values.
void prepareToExecute(Value *TripCount, Value *VectorTripCount,
- Value *CanonicalIVStartValue, VPTransformState &State,
- bool IsEpilogueVectorization);
+ Value *CanonicalIVStartValue, VPTransformState &State);
/// Generate the IR code for this VPlan.
void execute(VPTransformState *State);
@@ -2522,6 +2646,9 @@ public:
/// The vector trip count.
VPValue &getVectorTripCount() { return VectorTripCount; }
+ /// Returns VF * UF of the vector loop region.
+ VPValue &getVFxUF() { return VFxUF; }
+
/// Mark the plan to indicate that using Value2VPValue is not safe any
/// longer, because it may be stale.
void disableValue2VPValue() { Value2VPValueEnabled = false; }
@@ -2583,13 +2710,10 @@ public:
return getVPValue(V);
}
- void removeVPValueFor(Value *V) {
- assert(Value2VPValueEnabled &&
- "IR value to VPValue mapping may be out of date!");
- Value2VPValue.erase(V);
- }
-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// Print the live-ins of this VPlan to \p O.
+ void printLiveIns(raw_ostream &O) const;
+
/// Print this VPlan to \p O.
void print(raw_ostream &O) const;
@@ -2628,10 +2752,6 @@ public:
return cast<VPCanonicalIVPHIRecipe>(&*EntryVPBB->begin());
}
- /// Find and return the VPActiveLaneMaskPHIRecipe from the header - there
- /// be only one at most. If there isn't one, then return nullptr.
- VPActiveLaneMaskPHIRecipe *getActiveLaneMaskPhi();
-
void addLiveOut(PHINode *PN, VPValue *V);
void removeLiveOut(PHINode *PN) {
@@ -2959,6 +3079,9 @@ namespace vputils {
/// Returns true if only the first lane of \p Def is used.
bool onlyFirstLaneUsed(VPValue *Def);
+/// Returns true if only the first part of \p Def is used.
+bool onlyFirstPartUsed(VPValue *Def);
+
/// Get or create a VPValue that corresponds to the expansion of \p Expr. If \p
/// Expr is a SCEVConstant or SCEVUnknown, return a VPValue wrapping the live-in
/// value. Otherwise return a VPExpandSCEVRecipe to expand \p Expr. If \p Plan's
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
new file mode 100644
index 000000000000..97a8a1803bbf
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -0,0 +1,237 @@
+//===- VPlanAnalysis.cpp - Various Analyses working on VPlan ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "VPlanAnalysis.h"
+#include "VPlan.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "vplan"
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) {
+ Type *ResTy = inferScalarType(R->getIncomingValue(0));
+ for (unsigned I = 1, E = R->getNumIncomingValues(); I != E; ++I) {
+ VPValue *Inc = R->getIncomingValue(I);
+ assert(inferScalarType(Inc) == ResTy &&
+ "different types inferred for different incoming values");
+ CachedTypes[Inc] = ResTy;
+ }
+ return ResTy;
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
+ switch (R->getOpcode()) {
+ case Instruction::Select: {
+ Type *ResTy = inferScalarType(R->getOperand(1));
+ VPValue *OtherV = R->getOperand(2);
+ assert(inferScalarType(OtherV) == ResTy &&
+ "different types inferred for different operands");
+ CachedTypes[OtherV] = ResTy;
+ return ResTy;
+ }
+ case VPInstruction::FirstOrderRecurrenceSplice: {
+ Type *ResTy = inferScalarType(R->getOperand(0));
+ VPValue *OtherV = R->getOperand(1);
+ assert(inferScalarType(OtherV) == ResTy &&
+ "different types inferred for different operands");
+ CachedTypes[OtherV] = ResTy;
+ return ResTy;
+ }
+ default:
+ break;
+ }
+ // Type inference not implemented for opcode.
+ LLVM_DEBUG({
+ dbgs() << "LV: Found unhandled opcode for: ";
+ R->getVPSingleValue()->dump();
+ });
+ llvm_unreachable("Unhandled opcode!");
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
+ unsigned Opcode = R->getOpcode();
+ switch (Opcode) {
+ case Instruction::ICmp:
+ case Instruction::FCmp:
+ return IntegerType::get(Ctx, 1);
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::SRem:
+ case Instruction::URem:
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::FDiv:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor: {
+ Type *ResTy = inferScalarType(R->getOperand(0));
+ assert(ResTy == inferScalarType(R->getOperand(1)) &&
+ "types for both operands must match for binary op");
+ CachedTypes[R->getOperand(1)] = ResTy;
+ return ResTy;
+ }
+ case Instruction::FNeg:
+ case Instruction::Freeze:
+ return inferScalarType(R->getOperand(0));
+ default:
+ break;
+ }
+
+ // Type inference not implemented for opcode.
+ LLVM_DEBUG({
+ dbgs() << "LV: Found unhandled opcode for: ";
+ R->getVPSingleValue()->dump();
+ });
+ llvm_unreachable("Unhandled opcode!");
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
+ auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
+ return CI.getType();
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(
+ const VPWidenMemoryInstructionRecipe *R) {
+ assert(!R->isStore() && "Store recipes should not define any values");
+ return cast<LoadInst>(&R->getIngredient())->getType();
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
+ Type *ResTy = inferScalarType(R->getOperand(1));
+ VPValue *OtherV = R->getOperand(2);
+ assert(inferScalarType(OtherV) == ResTy &&
+ "different types inferred for different operands");
+ CachedTypes[OtherV] = ResTy;
+ return ResTy;
+}
+
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
+ switch (R->getUnderlyingInstr()->getOpcode()) {
+ case Instruction::Call: {
+ unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
+ return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
+ ->getReturnType();
+ }
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::SRem:
+ case Instruction::URem:
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::FDiv:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor: {
+ Type *ResTy = inferScalarType(R->getOperand(0));
+ assert(ResTy == inferScalarType(R->getOperand(1)) &&
+ "inferred types for operands of binary op don't match");
+ CachedTypes[R->getOperand(1)] = ResTy;
+ return ResTy;
+ }
+ case Instruction::Select: {
+ Type *ResTy = inferScalarType(R->getOperand(1));
+ assert(ResTy == inferScalarType(R->getOperand(2)) &&
+ "inferred types for operands of select op don't match");
+ CachedTypes[R->getOperand(2)] = ResTy;
+ return ResTy;
+ }
+ case Instruction::ICmp:
+ case Instruction::FCmp:
+ return IntegerType::get(Ctx, 1);
+ case Instruction::Alloca:
+ case Instruction::BitCast:
+ case Instruction::Trunc:
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ case Instruction::FPExt:
+ case Instruction::FPTrunc:
+ case Instruction::ExtractValue:
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ case Instruction::FPToSI:
+ case Instruction::FPToUI:
+ case Instruction::PtrToInt:
+ case Instruction::IntToPtr:
+ return R->getUnderlyingInstr()->getType();
+ case Instruction::Freeze:
+ case Instruction::FNeg:
+ case Instruction::GetElementPtr:
+ return inferScalarType(R->getOperand(0));
+ case Instruction::Load:
+ return cast<LoadInst>(R->getUnderlyingInstr())->getType();
+ case Instruction::Store:
+ // FIXME: VPReplicateRecipes with store opcodes still define a result
+ // VPValue, so we need to handle them here. Remove the code here once this
+ // is modeled accurately in VPlan.
+ return Type::getVoidTy(Ctx);
+ default:
+ break;
+ }
+ // Type inference not implemented for opcode.
+ LLVM_DEBUG({
+ dbgs() << "LV: Found unhandled opcode for: ";
+ R->getVPSingleValue()->dump();
+ });
+ llvm_unreachable("Unhandled opcode");
+}
+
+Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
+ if (Type *CachedTy = CachedTypes.lookup(V))
+ return CachedTy;
+
+ if (V->isLiveIn())
+ return V->getLiveInIRValue()->getType();
+
+ Type *ResultTy =
+ TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe())
+ .Case<VPCanonicalIVPHIRecipe, VPFirstOrderRecurrencePHIRecipe,
+ VPReductionPHIRecipe, VPWidenPointerInductionRecipe>(
+ [this](const auto *R) {
+ // Handle header phi recipes, except VPWienIntOrFpInduction
+ // which needs special handling due it being possibly truncated.
+ // TODO: consider inferring/caching type of siblings, e.g.,
+ // backedge value, here and in cases below.
+ return inferScalarType(R->getStartValue());
+ })
+ .Case<VPWidenIntOrFpInductionRecipe, VPDerivedIVRecipe>(
+ [](const auto *R) { return R->getScalarType(); })
+ .Case<VPPredInstPHIRecipe, VPWidenPHIRecipe, VPScalarIVStepsRecipe,
+ VPWidenGEPRecipe>([this](const VPRecipeBase *R) {
+ return inferScalarType(R->getOperand(0));
+ })
+ .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
+ VPWidenCallRecipe, VPWidenMemoryInstructionRecipe,
+ VPWidenSelectRecipe>(
+ [this](const auto *R) { return inferScalarTypeForRecipe(R); })
+ .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
+ // TODO: Use info from interleave group.
+ return V->getUnderlyingValue()->getType();
+ })
+ .Case<VPWidenCastRecipe>(
+ [](const VPWidenCastRecipe *R) { return R->getResultType(); });
+ assert(ResultTy && "could not infer type for the given VPValue");
+ CachedTypes[V] = ResultTy;
+ return ResultTy;
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
new file mode 100644
index 000000000000..473a7c28e48a
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -0,0 +1,64 @@
+//===- VPlanAnalysis.h - Various Analyses working on VPlan ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+
+namespace llvm {
+
+class LLVMContext;
+class VPValue;
+class VPBlendRecipe;
+class VPInterleaveRecipe;
+class VPInstruction;
+class VPReductionPHIRecipe;
+class VPWidenRecipe;
+class VPWidenCallRecipe;
+class VPWidenCastRecipe;
+class VPWidenIntOrFpInductionRecipe;
+class VPWidenMemoryInstructionRecipe;
+struct VPWidenSelectRecipe;
+class VPReplicateRecipe;
+class Type;
+
+/// An analysis for type-inference for VPValues.
+/// It infers the scalar type for a given VPValue by bottom-up traversing
+/// through defining recipes until root nodes with known types are reached (e.g.
+/// live-ins or load recipes). The types are then propagated top down through
+/// operations.
+/// Note that the analysis caches the inferred types. A new analysis object must
+/// be constructed once a VPlan has been modified in a way that invalidates any
+/// of the previously inferred types.
+class VPTypeAnalysis {
+ DenseMap<const VPValue *, Type *> CachedTypes;
+ LLVMContext &Ctx;
+
+ Type *inferScalarTypeForRecipe(const VPBlendRecipe *R);
+ Type *inferScalarTypeForRecipe(const VPInstruction *R);
+ Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R);
+ Type *inferScalarTypeForRecipe(const VPWidenRecipe *R);
+ Type *inferScalarTypeForRecipe(const VPWidenIntOrFpInductionRecipe *R);
+ Type *inferScalarTypeForRecipe(const VPWidenMemoryInstructionRecipe *R);
+ Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
+ Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+
+public:
+ VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {}
+
+ /// Infer the type of \p V. Returns the scalar type of \p V.
+ Type *inferScalarType(const VPValue *V);
+
+ /// Return the LLVMContext used by the analysis.
+ LLVMContext &getContext() { return Ctx; }
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
index f6e3a2a16db8..f950d4740e41 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp
@@ -61,6 +61,7 @@ private:
// Utility functions.
void setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB);
+ void setRegionPredsFromBB(VPRegionBlock *VPBB, BasicBlock *BB);
void fixPhiNodes();
VPBasicBlock *getOrCreateVPBB(BasicBlock *BB);
#ifndef NDEBUG
@@ -81,14 +82,43 @@ public:
// Set predecessors of \p VPBB in the same order as they are in \p BB. \p VPBB
// must have no predecessors.
void PlainCFGBuilder::setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB) {
- SmallVector<VPBlockBase *, 8> VPBBPreds;
+ auto GetLatchOfExit = [this](BasicBlock *BB) -> BasicBlock * {
+ auto *SinglePred = BB->getSinglePredecessor();
+ Loop *LoopForBB = LI->getLoopFor(BB);
+ if (!SinglePred || LI->getLoopFor(SinglePred) == LoopForBB)
+ return nullptr;
+ // The input IR must be in loop-simplify form, ensuring a single predecessor
+ // for exit blocks.
+ assert(SinglePred == LI->getLoopFor(SinglePred)->getLoopLatch() &&
+ "SinglePred must be the only loop latch");
+ return SinglePred;
+ };
+ if (auto *LatchBB = GetLatchOfExit(BB)) {
+ auto *PredRegion = getOrCreateVPBB(LatchBB)->getParent();
+ assert(VPBB == cast<VPBasicBlock>(PredRegion->getSingleSuccessor()) &&
+ "successor must already be set for PredRegion; it must have VPBB "
+ "as single successor");
+ VPBB->setPredecessors({PredRegion});
+ return;
+ }
// Collect VPBB predecessors.
+ SmallVector<VPBlockBase *, 2> VPBBPreds;
for (BasicBlock *Pred : predecessors(BB))
VPBBPreds.push_back(getOrCreateVPBB(Pred));
-
VPBB->setPredecessors(VPBBPreds);
}
+static bool isHeaderBB(BasicBlock *BB, Loop *L) {
+ return L && BB == L->getHeader();
+}
+
+void PlainCFGBuilder::setRegionPredsFromBB(VPRegionBlock *Region,
+ BasicBlock *BB) {
+ // BB is a loop header block. Connect the region to the loop preheader.
+ Loop *LoopOfBB = LI->getLoopFor(BB);
+ Region->setPredecessors({getOrCreateVPBB(LoopOfBB->getLoopPredecessor())});
+}
+
// Add operands to VPInstructions representing phi nodes from the input IR.
void PlainCFGBuilder::fixPhiNodes() {
for (auto *Phi : PhisToFix) {
@@ -100,38 +130,85 @@ void PlainCFGBuilder::fixPhiNodes() {
assert(VPPhi->getNumOperands() == 0 &&
"Expected VPInstruction with no operands.");
+ Loop *L = LI->getLoopFor(Phi->getParent());
+ if (isHeaderBB(Phi->getParent(), L)) {
+ // For header phis, make sure the incoming value from the loop
+ // predecessor is the first operand of the recipe.
+ assert(Phi->getNumOperands() == 2);
+ BasicBlock *LoopPred = L->getLoopPredecessor();
+ VPPhi->addIncoming(
+ getOrCreateVPOperand(Phi->getIncomingValueForBlock(LoopPred)),
+ BB2VPBB[LoopPred]);
+ BasicBlock *LoopLatch = L->getLoopLatch();
+ VPPhi->addIncoming(
+ getOrCreateVPOperand(Phi->getIncomingValueForBlock(LoopLatch)),
+ BB2VPBB[LoopLatch]);
+ continue;
+ }
+
for (unsigned I = 0; I != Phi->getNumOperands(); ++I)
VPPhi->addIncoming(getOrCreateVPOperand(Phi->getIncomingValue(I)),
BB2VPBB[Phi->getIncomingBlock(I)]);
}
}
+static bool isHeaderVPBB(VPBasicBlock *VPBB) {
+ return VPBB->getParent() && VPBB->getParent()->getEntry() == VPBB;
+}
+
+/// Return true of \p L loop is contained within \p OuterLoop.
+static bool doesContainLoop(const Loop *L, const Loop *OuterLoop) {
+ if (L->getLoopDepth() < OuterLoop->getLoopDepth())
+ return false;
+ const Loop *P = L;
+ while (P) {
+ if (P == OuterLoop)
+ return true;
+ P = P->getParentLoop();
+ }
+ return false;
+}
+
// Create a new empty VPBasicBlock for an incoming BasicBlock in the region
// corresponding to the containing loop or retrieve an existing one if it was
// already created. If no region exists yet for the loop containing \p BB, a new
// one is created.
VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) {
- auto BlockIt = BB2VPBB.find(BB);
- if (BlockIt != BB2VPBB.end())
+ if (auto *VPBB = BB2VPBB.lookup(BB)) {
// Retrieve existing VPBB.
- return BlockIt->second;
-
- // Get or create a region for the loop containing BB.
- Loop *CurrentLoop = LI->getLoopFor(BB);
- VPRegionBlock *ParentR = nullptr;
- if (CurrentLoop) {
- auto Iter = Loop2Region.insert({CurrentLoop, nullptr});
- if (Iter.second)
- Iter.first->second = new VPRegionBlock(
- CurrentLoop->getHeader()->getName().str(), false /*isReplicator*/);
- ParentR = Iter.first->second;
+ return VPBB;
}
// Create new VPBB.
- LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << BB->getName() << "\n");
- VPBasicBlock *VPBB = new VPBasicBlock(BB->getName());
+ StringRef Name = isHeaderBB(BB, TheLoop) ? "vector.body" : BB->getName();
+ LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << Name << "\n");
+ VPBasicBlock *VPBB = new VPBasicBlock(Name);
BB2VPBB[BB] = VPBB;
- VPBB->setParent(ParentR);
+
+ // Get or create a region for the loop containing BB.
+ Loop *LoopOfBB = LI->getLoopFor(BB);
+ if (!LoopOfBB || !doesContainLoop(LoopOfBB, TheLoop))
+ return VPBB;
+
+ auto *RegionOfVPBB = Loop2Region.lookup(LoopOfBB);
+ if (!isHeaderBB(BB, LoopOfBB)) {
+ assert(RegionOfVPBB &&
+ "Region should have been created by visiting header earlier");
+ VPBB->setParent(RegionOfVPBB);
+ return VPBB;
+ }
+
+ assert(!RegionOfVPBB &&
+ "First visit of a header basic block expects to register its region.");
+ // Handle a header - take care of its Region.
+ if (LoopOfBB == TheLoop) {
+ RegionOfVPBB = Plan.getVectorLoopRegion();
+ } else {
+ RegionOfVPBB = new VPRegionBlock(Name.str(), false /*isReplicator*/);
+ RegionOfVPBB->setParent(Loop2Region[LoopOfBB->getParentLoop()]);
+ }
+ RegionOfVPBB->setEntry(VPBB);
+ Loop2Region[LoopOfBB] = RegionOfVPBB;
return VPBB;
}
@@ -254,6 +331,25 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
// Main interface to build the plain CFG.
void PlainCFGBuilder::buildPlainCFG() {
+ // 0. Reuse the top-level region, vector-preheader and exit VPBBs from the
+ // skeleton. These were created directly rather than via getOrCreateVPBB(),
+ // revisit them now to update BB2VPBB. Note that header/entry and
+ // latch/exiting VPBB's of top-level region have yet to be created.
+ VPRegionBlock *TheRegion = Plan.getVectorLoopRegion();
+ BasicBlock *ThePreheaderBB = TheLoop->getLoopPreheader();
+ assert((ThePreheaderBB->getTerminator()->getNumSuccessors() == 1) &&
+ "Unexpected loop preheader");
+ auto *VectorPreheaderVPBB =
+ cast<VPBasicBlock>(TheRegion->getSinglePredecessor());
+ // ThePreheaderBB conceptually corresponds to both Plan.getPreheader() (which
+ // wraps the original preheader BB) and Plan.getEntry() (which represents the
+ // new vector preheader); here we're interested in setting BB2VPBB to the
+ // latter.
+ BB2VPBB[ThePreheaderBB] = VectorPreheaderVPBB;
+ BasicBlock *LoopExitBB = TheLoop->getUniqueExitBlock();
+ assert(LoopExitBB && "Loops with multiple exits are not supported.");
+ BB2VPBB[LoopExitBB] = cast<VPBasicBlock>(TheRegion->getSingleSuccessor());
+
// 1. Scan the body of the loop in a topological order to visit each basic
// block after having visited its predecessor basic blocks. Create a VPBB for
// each BB and link it to its successor and predecessor VPBBs. Note that
@@ -263,21 +359,11 @@ void PlainCFGBuilder::buildPlainCFG() {
// Loop PH needs to be explicitly visited since it's not taken into account by
// LoopBlocksDFS.
- BasicBlock *ThePreheaderBB = TheLoop->getLoopPreheader();
- assert((ThePreheaderBB->getTerminator()->getNumSuccessors() == 1) &&
- "Unexpected loop preheader");
- VPBasicBlock *ThePreheaderVPBB = Plan.getEntry();
- BB2VPBB[ThePreheaderBB] = ThePreheaderVPBB;
- ThePreheaderVPBB->setName("vector.ph");
for (auto &I : *ThePreheaderBB) {
if (I.getType()->isVoidTy())
continue;
IRDef2VPValue[&I] = Plan.getVPValueOrAddLiveIn(&I);
}
- // Create empty VPBB for Loop H so that we can link PH->H.
- VPBlockBase *HeaderVPBB = getOrCreateVPBB(TheLoop->getHeader());
- HeaderVPBB->setName("vector.body");
- ThePreheaderVPBB->setOneSuccessor(HeaderVPBB);
LoopBlocksRPO RPO(TheLoop);
RPO.perform(LI);
@@ -286,88 +372,55 @@ void PlainCFGBuilder::buildPlainCFG() {
// Create or retrieve the VPBasicBlock for this BB and create its
// VPInstructions.
VPBasicBlock *VPBB = getOrCreateVPBB(BB);
+ VPRegionBlock *Region = VPBB->getParent();
createVPInstructionsForVPBB(VPBB, BB);
+ Loop *LoopForBB = LI->getLoopFor(BB);
+ // Set VPBB predecessors in the same order as they are in the incoming BB.
+ if (!isHeaderBB(BB, LoopForBB)) {
+ setVPBBPredsFromBB(VPBB, BB);
+ } else {
+ // BB is a loop header, set the predecessor for the region, except for the
+ // top region, whose predecessor was set when creating VPlan's skeleton.
+ assert(isHeaderVPBB(VPBB) && "isHeaderBB and isHeaderVPBB disagree");
+ if (TheRegion != Region)
+ setRegionPredsFromBB(Region, BB);
+ }
// Set VPBB successors. We create empty VPBBs for successors if they don't
// exist already. Recipes will be created when the successor is visited
// during the RPO traversal.
- Instruction *TI = BB->getTerminator();
- assert(TI && "Terminator expected.");
- unsigned NumSuccs = TI->getNumSuccessors();
-
+ auto *BI = cast<BranchInst>(BB->getTerminator());
+ unsigned NumSuccs = succ_size(BB);
if (NumSuccs == 1) {
- VPBasicBlock *SuccVPBB = getOrCreateVPBB(TI->getSuccessor(0));
- assert(SuccVPBB && "VPBB Successor not found.");
- VPBB->setOneSuccessor(SuccVPBB);
- } else if (NumSuccs == 2) {
- VPBasicBlock *SuccVPBB0 = getOrCreateVPBB(TI->getSuccessor(0));
- assert(SuccVPBB0 && "Successor 0 not found.");
- VPBasicBlock *SuccVPBB1 = getOrCreateVPBB(TI->getSuccessor(1));
- assert(SuccVPBB1 && "Successor 1 not found.");
-
- // Get VPBB's condition bit.
- assert(isa<BranchInst>(TI) && "Unsupported terminator!");
- // Look up the branch condition to get the corresponding VPValue
- // representing the condition bit in VPlan (which may be in another VPBB).
- assert(IRDef2VPValue.count(cast<BranchInst>(TI)->getCondition()) &&
- "Missing condition bit in IRDef2VPValue!");
-
- // Link successors.
- VPBB->setTwoSuccessors(SuccVPBB0, SuccVPBB1);
- } else
- llvm_unreachable("Number of successors not supported.");
-
- // Set VPBB predecessors in the same order as they are in the incoming BB.
- setVPBBPredsFromBB(VPBB, BB);
+ auto *Successor = getOrCreateVPBB(BB->getSingleSuccessor());
+ VPBB->setOneSuccessor(isHeaderVPBB(Successor)
+ ? Successor->getParent()
+ : static_cast<VPBlockBase *>(Successor));
+ continue;
+ }
+ assert(BI->isConditional() && NumSuccs == 2 && BI->isConditional() &&
+ "block must have conditional branch with 2 successors");
+ // Look up the branch condition to get the corresponding VPValue
+ // representing the condition bit in VPlan (which may be in another VPBB).
+ assert(IRDef2VPValue.contains(BI->getCondition()) &&
+ "Missing condition bit in IRDef2VPValue!");
+ VPBasicBlock *Successor0 = getOrCreateVPBB(BI->getSuccessor(0));
+ VPBasicBlock *Successor1 = getOrCreateVPBB(BI->getSuccessor(1));
+ if (!LoopForBB || BB != LoopForBB->getLoopLatch()) {
+ VPBB->setTwoSuccessors(Successor0, Successor1);
+ continue;
+ }
+ // For a latch we need to set the successor of the region rather than that
+ // of VPBB and it should be set to the exit, i.e., non-header successor,
+ // except for the top region, whose successor was set when creating VPlan's
+ // skeleton.
+ if (TheRegion != Region)
+ Region->setOneSuccessor(isHeaderVPBB(Successor0) ? Successor1
+ : Successor0);
+ Region->setExiting(VPBB);
}
- // 2. Process outermost loop exit. We created an empty VPBB for the loop
- // single exit BB during the RPO traversal of the loop body but Instructions
- // weren't visited because it's not part of the the loop.
- BasicBlock *LoopExitBB = TheLoop->getUniqueExitBlock();
- assert(LoopExitBB && "Loops with multiple exits are not supported.");
- VPBasicBlock *LoopExitVPBB = BB2VPBB[LoopExitBB];
- // Loop exit was already set as successor of the loop exiting BB.
- // We only set its predecessor VPBB now.
- setVPBBPredsFromBB(LoopExitVPBB, LoopExitBB);
-
- // 3. Fix up region blocks for loops. For each loop,
- // * use the header block as entry to the corresponding region,
- // * use the latch block as exit of the corresponding region,
- // * set the region as successor of the loop pre-header, and
- // * set the exit block as successor to the region.
- SmallVector<Loop *> LoopWorkList;
- LoopWorkList.push_back(TheLoop);
- while (!LoopWorkList.empty()) {
- Loop *L = LoopWorkList.pop_back_val();
- BasicBlock *Header = L->getHeader();
- BasicBlock *Exiting = L->getLoopLatch();
- assert(Exiting == L->getExitingBlock() &&
- "Latch must be the only exiting block");
- VPRegionBlock *Region = Loop2Region[L];
- VPBasicBlock *HeaderVPBB = getOrCreateVPBB(Header);
- VPBasicBlock *ExitingVPBB = getOrCreateVPBB(Exiting);
-
- // Disconnect backedge and pre-header from header.
- VPBasicBlock *PreheaderVPBB = getOrCreateVPBB(L->getLoopPreheader());
- VPBlockUtils::disconnectBlocks(PreheaderVPBB, HeaderVPBB);
- VPBlockUtils::disconnectBlocks(ExitingVPBB, HeaderVPBB);
-
- Region->setParent(PreheaderVPBB->getParent());
- Region->setEntry(HeaderVPBB);
- VPBlockUtils::connectBlocks(PreheaderVPBB, Region);
-
- // Disconnect exit block from exiting (=latch) block, set exiting block and
- // connect region to exit block.
- VPBasicBlock *ExitVPBB = getOrCreateVPBB(L->getExitBlock());
- VPBlockUtils::disconnectBlocks(ExitingVPBB, ExitVPBB);
- Region->setExiting(ExitingVPBB);
- VPBlockUtils::connectBlocks(Region, ExitVPBB);
-
- // Queue sub-loops for processing.
- LoopWorkList.append(L->begin(), L->end());
- }
- // 4. The whole CFG has been built at this point so all the input Values must
+ // 2. The whole CFG has been built at this point so all the input Values must
// have a VPlan couterpart. Fix VPlan phi nodes by adding their corresponding
// VPlan operands.
fixPhiNodes();
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 26c309eed800..c23428e2ba34 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "VPlan.h"
+#include "VPlanAnalysis.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
@@ -114,6 +115,16 @@ bool VPRecipeBase::mayHaveSideEffects() const {
case VPDerivedIVSC:
case VPPredInstPHISC:
return false;
+ case VPInstructionSC:
+ switch (cast<VPInstruction>(this)->getOpcode()) {
+ case Instruction::ICmp:
+ case VPInstruction::Not:
+ case VPInstruction::CalculateTripCountMinusVF:
+ case VPInstruction::CanonicalIVIncrementForPart:
+ return false;
+ default:
+ return true;
+ }
case VPWidenCallSC:
return cast<Instruction>(getVPSingleValue()->getUnderlyingValue())
->mayHaveSideEffects();
@@ -156,8 +167,13 @@ void VPLiveOut::fixPhi(VPlan &Plan, VPTransformState &State) {
VPValue *ExitValue = getOperand(0);
if (vputils::isUniformAfterVectorization(ExitValue))
Lane = VPLane::getFirstLane();
+ VPBasicBlock *MiddleVPBB =
+ cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
+ assert(MiddleVPBB->getNumSuccessors() == 0 &&
+ "the middle block must not have any successors");
+ BasicBlock *MiddleBB = State.CFG.VPBB2IRBB[MiddleVPBB];
Phi->addIncoming(State.get(ExitValue, VPIteration(State.UF - 1, Lane)),
- State.Builder.GetInsertBlock());
+ MiddleBB);
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -216,15 +232,55 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
insertBefore(BB, I);
}
+FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
+ assert(OpType == OperationType::FPMathOp &&
+ "recipe doesn't have fast math flags");
+ FastMathFlags Res;
+ Res.setAllowReassoc(FMFs.AllowReassoc);
+ Res.setNoNaNs(FMFs.NoNaNs);
+ Res.setNoInfs(FMFs.NoInfs);
+ Res.setNoSignedZeros(FMFs.NoSignedZeros);
+ Res.setAllowReciprocal(FMFs.AllowReciprocal);
+ Res.setAllowContract(FMFs.AllowContract);
+ Res.setApproxFunc(FMFs.ApproxFunc);
+ return Res;
+}
+
+VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
+ VPValue *A, VPValue *B, DebugLoc DL,
+ const Twine &Name)
+ : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
+ Pred, DL),
+ VPValue(this), Opcode(Opcode), Name(Name.str()) {
+ assert(Opcode == Instruction::ICmp &&
+ "only ICmp predicates supported at the moment");
+}
+
+VPInstruction::VPInstruction(unsigned Opcode,
+ std::initializer_list<VPValue *> Operands,
+ FastMathFlags FMFs, DebugLoc DL, const Twine &Name)
+ : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs, DL),
+ VPValue(this), Opcode(Opcode), Name(Name.str()) {
+ // Make sure the VPInstruction is a floating-point operation.
+ assert(isFPMathOp() && "this op can't take fast-math flags");
+}
+
Value *VPInstruction::generateInstruction(VPTransformState &State,
unsigned Part) {
IRBuilderBase &Builder = State.Builder;
- Builder.SetCurrentDebugLocation(DL);
+ Builder.SetCurrentDebugLocation(getDebugLoc());
if (Instruction::isBinaryOp(getOpcode())) {
+ if (Part != 0 && vputils::onlyFirstPartUsed(this))
+ return State.get(this, 0);
+
Value *A = State.get(getOperand(0), Part);
Value *B = State.get(getOperand(1), Part);
- return Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
+ auto *Res =
+ Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
+ if (auto *I = dyn_cast<Instruction>(Res))
+ setFlags(I);
+ return Res;
}
switch (getOpcode()) {
@@ -232,10 +288,10 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
Value *A = State.get(getOperand(0), Part);
return Builder.CreateNot(A, Name);
}
- case VPInstruction::ICmpULE: {
- Value *IV = State.get(getOperand(0), Part);
- Value *TC = State.get(getOperand(1), Part);
- return Builder.CreateICmpULE(IV, TC, Name);
+ case Instruction::ICmp: {
+ Value *A = State.get(getOperand(0), Part);
+ Value *B = State.get(getOperand(1), Part);
+ return Builder.CreateCmp(getPredicate(), A, B, Name);
}
case Instruction::Select: {
Value *Cond = State.get(getOperand(0), Part);
@@ -285,23 +341,7 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
Value *Zero = ConstantInt::get(ScalarTC->getType(), 0);
return Builder.CreateSelect(Cmp, Sub, Zero);
}
- case VPInstruction::CanonicalIVIncrement:
- case VPInstruction::CanonicalIVIncrementNUW: {
- if (Part == 0) {
- bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW;
- auto *Phi = State.get(getOperand(0), 0);
- // The loop step is equal to the vectorization factor (num of SIMD
- // elements) times the unroll factor (num of SIMD instructions).
- Value *Step =
- createStepForVF(Builder, Phi->getType(), State.VF, State.UF);
- return Builder.CreateAdd(Phi, Step, Name, IsNUW, false);
- }
- return State.get(this, 0);
- }
-
- case VPInstruction::CanonicalIVIncrementForPart:
- case VPInstruction::CanonicalIVIncrementForPartNUW: {
- bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementForPartNUW;
+ case VPInstruction::CanonicalIVIncrementForPart: {
auto *IV = State.get(getOperand(0), VPIteration(0, 0));
if (Part == 0)
return IV;
@@ -309,7 +349,8 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
// The canonical IV is incremented by the vectorization factor (num of SIMD
// elements) times the unroll part.
Value *Step = createStepForVF(Builder, IV->getType(), State.VF, Part);
- return Builder.CreateAdd(IV, Step, Name, IsNUW, false);
+ return Builder.CreateAdd(IV, Step, Name, hasNoUnsignedWrap(),
+ hasNoSignedWrap());
}
case VPInstruction::BranchOnCond: {
if (Part != 0)
@@ -361,10 +402,25 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
}
}
+#if !defined(NDEBUG)
+bool VPInstruction::isFPMathOp() const {
+ // Inspired by FPMathOperator::classof. Notable differences are that we don't
+ // support Call, PHI and Select opcodes here yet.
+ return Opcode == Instruction::FAdd || Opcode == Instruction::FMul ||
+ Opcode == Instruction::FNeg || Opcode == Instruction::FSub ||
+ Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
+ Opcode == Instruction::FCmp || Opcode == Instruction::Select;
+}
+#endif
+
void VPInstruction::execute(VPTransformState &State) {
assert(!State.Instance && "VPInstruction executing an Instance");
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
- State.Builder.setFastMathFlags(FMF);
+ assert((hasFastMathFlags() == isFPMathOp() ||
+ getOpcode() == Instruction::Select) &&
+ "Recipe not a FPMathOp but has fast-math flags?");
+ if (hasFastMathFlags())
+ State.Builder.setFastMathFlags(getFastMathFlags());
for (unsigned Part = 0; Part < State.UF; ++Part) {
Value *GeneratedValue = generateInstruction(State, Part);
if (!hasResult())
@@ -393,9 +449,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::Not:
O << "not";
break;
- case VPInstruction::ICmpULE:
- O << "icmp ule";
- break;
case VPInstruction::SLPLoad:
O << "combined load";
break;
@@ -408,12 +461,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::FirstOrderRecurrenceSplice:
O << "first-order splice";
break;
- case VPInstruction::CanonicalIVIncrement:
- O << "VF * UF + ";
- break;
- case VPInstruction::CanonicalIVIncrementNUW:
- O << "VF * UF +(nuw) ";
- break;
case VPInstruction::BranchOnCond:
O << "branch-on-cond";
break;
@@ -421,49 +468,35 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
O << "TC > VF ? TC - VF : 0";
break;
case VPInstruction::CanonicalIVIncrementForPart:
- O << "VF * Part + ";
- break;
- case VPInstruction::CanonicalIVIncrementForPartNUW:
- O << "VF * Part +(nuw) ";
+ O << "VF * Part +";
break;
case VPInstruction::BranchOnCount:
- O << "branch-on-count ";
+ O << "branch-on-count";
break;
default:
O << Instruction::getOpcodeName(getOpcode());
}
- O << FMF;
-
- for (const VPValue *Operand : operands()) {
- O << " ";
- Operand->printAsOperand(O, SlotTracker);
- }
+ printFlags(O);
+ printOperands(O, SlotTracker);
- if (DL) {
+ if (auto DL = getDebugLoc()) {
O << ", !dbg ";
DL.print(O);
}
}
#endif
-void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) {
- // Make sure the VPInstruction is a floating-point operation.
- assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul ||
- Opcode == Instruction::FNeg || Opcode == Instruction::FSub ||
- Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
- Opcode == Instruction::FCmp) &&
- "this op can't take fast-math flags");
- FMF = FMFNew;
-}
-
void VPWidenCallRecipe::execute(VPTransformState &State) {
assert(State.VF.isVector() && "not widening");
auto &CI = *cast<CallInst>(getUnderlyingInstr());
assert(!isa<DbgInfoIntrinsic>(CI) &&
"DbgInfoIntrinsic should have been dropped during VPlan construction");
- State.setDebugLocFromInst(&CI);
+ State.setDebugLocFrom(CI.getDebugLoc());
+ FunctionType *VFTy = nullptr;
+ if (Variant)
+ VFTy = Variant->getFunctionType();
for (unsigned Part = 0; Part < State.UF; ++Part) {
SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
@@ -475,12 +508,15 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
for (const auto &I : enumerate(operands())) {
// Some intrinsics have a scalar argument - don't replace it with a
// vector.
+ // Some vectorized function variants may also take a scalar argument,
+ // e.g. linear parameters for pointers.
Value *Arg;
- if (VectorIntrinsicID == Intrinsic::not_intrinsic ||
- !isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index()))
- Arg = State.get(I.value(), Part);
- else
+ if ((VFTy && !VFTy->getParamType(I.index())->isVectorTy()) ||
+ (VectorIntrinsicID != Intrinsic::not_intrinsic &&
+ isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index())))
Arg = State.get(I.value(), VPIteration(0, 0));
+ else
+ Arg = State.get(I.value(), Part);
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
TysForDecl.push_back(Arg->getType());
Args.push_back(Arg);
@@ -553,8 +589,7 @@ void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent,
#endif
void VPWidenSelectRecipe::execute(VPTransformState &State) {
- auto &I = *cast<SelectInst>(getUnderlyingInstr());
- State.setDebugLocFromInst(&I);
+ State.setDebugLocFrom(getDebugLoc());
// The condition can be loop invariant but still defined inside the
// loop. This means that we can't just use the original 'cond' value.
@@ -569,13 +604,31 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
Value *Op1 = State.get(getOperand(2), Part);
Value *Sel = State.Builder.CreateSelect(Cond, Op0, Op1);
State.set(this, Sel, Part);
- State.addMetadata(Sel, &I);
+ State.addMetadata(Sel, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
}
}
+VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
+ const FastMathFlags &FMF) {
+ AllowReassoc = FMF.allowReassoc();
+ NoNaNs = FMF.noNaNs();
+ NoInfs = FMF.noInfs();
+ NoSignedZeros = FMF.noSignedZeros();
+ AllowReciprocal = FMF.allowReciprocal();
+ AllowContract = FMF.allowContract();
+ ApproxFunc = FMF.approxFunc();
+}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
switch (OpType) {
+ case OperationType::Cmp:
+ O << " " << CmpInst::getPredicateName(getPredicate());
+ break;
+ case OperationType::DisjointOp:
+ if (DisjointFlags.IsDisjoint)
+ O << " disjoint";
+ break;
case OperationType::PossiblyExactOp:
if (ExactFlags.IsExact)
O << " exact";
@@ -593,17 +646,22 @@ void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
if (GEPFlags.IsInBounds)
O << " inbounds";
break;
+ case OperationType::NonNegOp:
+ if (NonNegFlags.NonNeg)
+ O << " nneg";
+ break;
case OperationType::Other:
break;
}
- O << " ";
+ if (getNumOperands() > 0)
+ O << " ";
}
#endif
void VPWidenRecipe::execute(VPTransformState &State) {
- auto &I = *cast<Instruction>(getUnderlyingValue());
+ State.setDebugLocFrom(getDebugLoc());
auto &Builder = State.Builder;
- switch (I.getOpcode()) {
+ switch (Opcode) {
case Instruction::Call:
case Instruction::Br:
case Instruction::PHI:
@@ -630,28 +688,24 @@ void VPWidenRecipe::execute(VPTransformState &State) {
case Instruction::Or:
case Instruction::Xor: {
// Just widen unops and binops.
- State.setDebugLocFromInst(&I);
-
for (unsigned Part = 0; Part < State.UF; ++Part) {
SmallVector<Value *, 2> Ops;
for (VPValue *VPOp : operands())
Ops.push_back(State.get(VPOp, Part));
- Value *V = Builder.CreateNAryOp(I.getOpcode(), Ops);
+ Value *V = Builder.CreateNAryOp(Opcode, Ops);
if (auto *VecOp = dyn_cast<Instruction>(V))
setFlags(VecOp);
// Use this vector value for all users of the original instruction.
State.set(this, V, Part);
- State.addMetadata(V, &I);
+ State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
}
break;
}
case Instruction::Freeze: {
- State.setDebugLocFromInst(&I);
-
for (unsigned Part = 0; Part < State.UF; ++Part) {
Value *Op = State.get(getOperand(0), Part);
@@ -663,9 +717,7 @@ void VPWidenRecipe::execute(VPTransformState &State) {
case Instruction::ICmp:
case Instruction::FCmp: {
// Widen compares. Generate vector compares.
- bool FCmp = (I.getOpcode() == Instruction::FCmp);
- auto *Cmp = cast<CmpInst>(&I);
- State.setDebugLocFromInst(Cmp);
+ bool FCmp = Opcode == Instruction::FCmp;
for (unsigned Part = 0; Part < State.UF; ++Part) {
Value *A = State.get(getOperand(0), Part);
Value *B = State.get(getOperand(1), Part);
@@ -673,51 +725,64 @@ void VPWidenRecipe::execute(VPTransformState &State) {
if (FCmp) {
// Propagate fast math flags.
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
- Builder.setFastMathFlags(Cmp->getFastMathFlags());
- C = Builder.CreateFCmp(Cmp->getPredicate(), A, B);
+ if (auto *I = dyn_cast_or_null<Instruction>(getUnderlyingValue()))
+ Builder.setFastMathFlags(I->getFastMathFlags());
+ C = Builder.CreateFCmp(getPredicate(), A, B);
} else {
- C = Builder.CreateICmp(Cmp->getPredicate(), A, B);
+ C = Builder.CreateICmp(getPredicate(), A, B);
}
State.set(this, C, Part);
- State.addMetadata(C, &I);
+ State.addMetadata(C, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
}
break;
}
default:
// This instruction is not vectorized by simple widening.
- LLVM_DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I);
+ LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
+ << Instruction::getOpcodeName(Opcode));
llvm_unreachable("Unhandled instruction!");
} // end of switch.
+
+#if !defined(NDEBUG)
+ // Verify that VPlan type inference results agree with the type of the
+ // generated values.
+ for (unsigned Part = 0; Part < State.UF; ++Part) {
+ assert(VectorType::get(State.TypeAnalysis.inferScalarType(this),
+ State.VF) == State.get(this, Part)->getType() &&
+ "inferred type and type from generated instructions do not match");
+ }
+#endif
}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "WIDEN ";
printAsOperand(O, SlotTracker);
- const Instruction *UI = getUnderlyingInstr();
- O << " = " << UI->getOpcodeName();
+ O << " = " << Instruction::getOpcodeName(Opcode);
printFlags(O);
- if (auto *Cmp = dyn_cast<CmpInst>(UI))
- O << Cmp->getPredicate() << " ";
printOperands(O, SlotTracker);
}
#endif
void VPWidenCastRecipe::execute(VPTransformState &State) {
- auto *I = cast_or_null<Instruction>(getUnderlyingValue());
- if (I)
- State.setDebugLocFromInst(I);
+ State.setDebugLocFrom(getDebugLoc());
auto &Builder = State.Builder;
/// Vectorize casts.
assert(State.VF.isVector() && "Not vectorizing?");
Type *DestTy = VectorType::get(getResultType(), State.VF);
-
+ VPValue *Op = getOperand(0);
for (unsigned Part = 0; Part < State.UF; ++Part) {
- Value *A = State.get(getOperand(0), Part);
+ if (Part > 0 && Op->isLiveIn()) {
+ // FIXME: Remove once explicit unrolling is implemented using VPlan.
+ State.set(this, State.get(this, 0), Part);
+ continue;
+ }
+ Value *A = State.get(Op, Part);
Value *Cast = Builder.CreateCast(Instruction::CastOps(Opcode), A, DestTy);
State.set(this, Cast, Part);
- State.addMetadata(Cast, I);
+ State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue()));
}
}
@@ -727,10 +792,182 @@ void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
O << Indent << "WIDEN-CAST ";
printAsOperand(O, SlotTracker);
O << " = " << Instruction::getOpcodeName(Opcode) << " ";
+ printFlags(O);
printOperands(O, SlotTracker);
O << " to " << *getResultType();
}
+#endif
+
+/// This function adds
+/// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...)
+/// to each vector element of Val. The sequence starts at StartIndex.
+/// \p Opcode is relevant for FP induction variable.
+static Value *getStepVector(Value *Val, Value *StartIdx, Value *Step,
+ Instruction::BinaryOps BinOp, ElementCount VF,
+ IRBuilderBase &Builder) {
+ assert(VF.isVector() && "only vector VFs are supported");
+
+ // Create and check the types.
+ auto *ValVTy = cast<VectorType>(Val->getType());
+ ElementCount VLen = ValVTy->getElementCount();
+ Type *STy = Val->getType()->getScalarType();
+ assert((STy->isIntegerTy() || STy->isFloatingPointTy()) &&
+ "Induction Step must be an integer or FP");
+ assert(Step->getType() == STy && "Step has wrong type");
+
+ SmallVector<Constant *, 8> Indices;
+
+ // Create a vector of consecutive numbers from zero to VF.
+ VectorType *InitVecValVTy = ValVTy;
+ if (STy->isFloatingPointTy()) {
+ Type *InitVecValSTy =
+ IntegerType::get(STy->getContext(), STy->getScalarSizeInBits());
+ InitVecValVTy = VectorType::get(InitVecValSTy, VLen);
+ }
+ Value *InitVec = Builder.CreateStepVector(InitVecValVTy);
+
+ // Splat the StartIdx
+ Value *StartIdxSplat = Builder.CreateVectorSplat(VLen, StartIdx);
+
+ if (STy->isIntegerTy()) {
+ InitVec = Builder.CreateAdd(InitVec, StartIdxSplat);
+ Step = Builder.CreateVectorSplat(VLen, Step);
+ assert(Step->getType() == Val->getType() && "Invalid step vec");
+ // FIXME: The newly created binary instructions should contain nsw/nuw
+ // flags, which can be found from the original scalar operations.
+ Step = Builder.CreateMul(InitVec, Step);
+ return Builder.CreateAdd(Val, Step, "induction");
+ }
+
+ // Floating point induction.
+ assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) &&
+ "Binary Opcode should be specified for FP induction");
+ InitVec = Builder.CreateUIToFP(InitVec, ValVTy);
+ InitVec = Builder.CreateFAdd(InitVec, StartIdxSplat);
+
+ Step = Builder.CreateVectorSplat(VLen, Step);
+ Value *MulOp = Builder.CreateFMul(InitVec, Step);
+ return Builder.CreateBinOp(BinOp, Val, MulOp, "induction");
+}
+
+/// A helper function that returns an integer or floating-point constant with
+/// value C.
+static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) {
+ return Ty->isIntegerTy() ? ConstantInt::getSigned(Ty, C)
+ : ConstantFP::get(Ty, C);
+}
+
+static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy,
+ ElementCount VF) {
+ assert(FTy->isFloatingPointTy() && "Expected floating point type!");
+ Type *IntTy = IntegerType::get(FTy->getContext(), FTy->getScalarSizeInBits());
+ Value *RuntimeVF = getRuntimeVF(B, IntTy, VF);
+ return B.CreateUIToFP(RuntimeVF, FTy);
+}
+
+void VPWidenIntOrFpInductionRecipe::execute(VPTransformState &State) {
+ assert(!State.Instance && "Int or FP induction being replicated.");
+
+ Value *Start = getStartValue()->getLiveInIRValue();
+ const InductionDescriptor &ID = getInductionDescriptor();
+ TruncInst *Trunc = getTruncInst();
+ IRBuilderBase &Builder = State.Builder;
+ assert(IV->getType() == ID.getStartValue()->getType() && "Types must match");
+ assert(State.VF.isVector() && "must have vector VF");
+
+ // The value from the original loop to which we are mapping the new induction
+ // variable.
+ Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV;
+
+ // Fast-math-flags propagate from the original induction instruction.
+ IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+ if (ID.getInductionBinOp() && isa<FPMathOperator>(ID.getInductionBinOp()))
+ Builder.setFastMathFlags(ID.getInductionBinOp()->getFastMathFlags());
+
+ // Now do the actual transformations, and start with fetching the step value.
+ Value *Step = State.get(getStepValue(), VPIteration(0, 0));
+
+ assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) &&
+ "Expected either an induction phi-node or a truncate of it!");
+
+ // Construct the initial value of the vector IV in the vector loop preheader
+ auto CurrIP = Builder.saveIP();
+ BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this);
+ Builder.SetInsertPoint(VectorPH->getTerminator());
+ if (isa<TruncInst>(EntryVal)) {
+ assert(Start->getType()->isIntegerTy() &&
+ "Truncation requires an integer type");
+ auto *TruncType = cast<IntegerType>(EntryVal->getType());
+ Step = Builder.CreateTrunc(Step, TruncType);
+ Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType);
+ }
+
+ Value *Zero = getSignedIntOrFpConstant(Start->getType(), 0);
+ Value *SplatStart = Builder.CreateVectorSplat(State.VF, Start);
+ Value *SteppedStart = getStepVector(
+ SplatStart, Zero, Step, ID.getInductionOpcode(), State.VF, State.Builder);
+
+ // We create vector phi nodes for both integer and floating-point induction
+ // variables. Here, we determine the kind of arithmetic we will perform.
+ Instruction::BinaryOps AddOp;
+ Instruction::BinaryOps MulOp;
+ if (Step->getType()->isIntegerTy()) {
+ AddOp = Instruction::Add;
+ MulOp = Instruction::Mul;
+ } else {
+ AddOp = ID.getInductionOpcode();
+ MulOp = Instruction::FMul;
+ }
+
+ // Multiply the vectorization factor by the step using integer or
+ // floating-point arithmetic as appropriate.
+ Type *StepType = Step->getType();
+ Value *RuntimeVF;
+ if (Step->getType()->isFloatingPointTy())
+ RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, State.VF);
+ else
+ RuntimeVF = getRuntimeVF(Builder, StepType, State.VF);
+ Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF);
+
+ // Create a vector splat to use in the induction update.
+ //
+ // FIXME: If the step is non-constant, we create the vector splat with
+ // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't
+ // handle a constant vector splat.
+ Value *SplatVF = isa<Constant>(Mul)
+ ? ConstantVector::getSplat(State.VF, cast<Constant>(Mul))
+ : Builder.CreateVectorSplat(State.VF, Mul);
+ Builder.restoreIP(CurrIP);
+
+ // We may need to add the step a number of times, depending on the unroll
+ // factor. The last of those goes into the PHI.
+ PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind");
+ VecInd->insertBefore(State.CFG.PrevBB->getFirstInsertionPt());
+ VecInd->setDebugLoc(EntryVal->getDebugLoc());
+ Instruction *LastInduction = VecInd;
+ for (unsigned Part = 0; Part < State.UF; ++Part) {
+ State.set(this, LastInduction, Part);
+
+ if (isa<TruncInst>(EntryVal))
+ State.addMetadata(LastInduction, EntryVal);
+
+ LastInduction = cast<Instruction>(
+ Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add"));
+ LastInduction->setDebugLoc(EntryVal->getDebugLoc());
+ }
+
+ LastInduction->setName("vec.ind.next");
+ VecInd->addIncoming(SteppedStart, VectorPH);
+ // Add induction update using an incorrect block temporarily. The phi node
+ // will be fixed after VPlan execution. Note that at this point the latch
+ // block cannot be used, as it does not exist yet.
+ // TODO: Model increment value in VPlan, by turning the recipe into a
+ // multi-def and a subclass of VPHeaderPHIRecipe.
+ VecInd->addIncoming(LastInduction, VectorPH);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "WIDEN-INDUCTION";
@@ -770,17 +1007,112 @@ void VPDerivedIVRecipe::print(raw_ostream &O, const Twine &Indent,
O << " * ";
getStepValue()->printAsOperand(O, SlotTracker);
- if (IndDesc.getStep()->getType() != ResultTy)
- O << " (truncated to " << *ResultTy << ")";
+ if (TruncResultTy)
+ O << " (truncated to " << *TruncResultTy << ")";
}
#endif
+void VPScalarIVStepsRecipe::execute(VPTransformState &State) {
+ // Fast-math-flags propagate from the original induction instruction.
+ IRBuilder<>::FastMathFlagGuard FMFG(State.Builder);
+ if (hasFastMathFlags())
+ State.Builder.setFastMathFlags(getFastMathFlags());
+
+ /// Compute scalar induction steps. \p ScalarIV is the scalar induction
+ /// variable on which to base the steps, \p Step is the size of the step.
+
+ Value *BaseIV = State.get(getOperand(0), VPIteration(0, 0));
+ Value *Step = State.get(getStepValue(), VPIteration(0, 0));
+ IRBuilderBase &Builder = State.Builder;
+
+ // Ensure step has the same type as that of scalar IV.
+ Type *BaseIVTy = BaseIV->getType()->getScalarType();
+ if (BaseIVTy != Step->getType()) {
+ // TODO: Also use VPDerivedIVRecipe when only the step needs truncating, to
+ // avoid separate truncate here.
+ assert(Step->getType()->isIntegerTy() &&
+ "Truncation requires an integer step");
+ Step = State.Builder.CreateTrunc(Step, BaseIVTy);
+ }
+
+ // We build scalar steps for both integer and floating-point induction
+ // variables. Here, we determine the kind of arithmetic we will perform.
+ Instruction::BinaryOps AddOp;
+ Instruction::BinaryOps MulOp;
+ if (BaseIVTy->isIntegerTy()) {
+ AddOp = Instruction::Add;
+ MulOp = Instruction::Mul;
+ } else {
+ AddOp = InductionOpcode;
+ MulOp = Instruction::FMul;
+ }
+
+ // Determine the number of scalars we need to generate for each unroll
+ // iteration.
+ bool FirstLaneOnly = vputils::onlyFirstLaneUsed(this);
+ // Compute the scalar steps and save the results in State.
+ Type *IntStepTy =
+ IntegerType::get(BaseIVTy->getContext(), BaseIVTy->getScalarSizeInBits());
+ Type *VecIVTy = nullptr;
+ Value *UnitStepVec = nullptr, *SplatStep = nullptr, *SplatIV = nullptr;
+ if (!FirstLaneOnly && State.VF.isScalable()) {
+ VecIVTy = VectorType::get(BaseIVTy, State.VF);
+ UnitStepVec =
+ Builder.CreateStepVector(VectorType::get(IntStepTy, State.VF));
+ SplatStep = Builder.CreateVectorSplat(State.VF, Step);
+ SplatIV = Builder.CreateVectorSplat(State.VF, BaseIV);
+ }
+
+ unsigned StartPart = 0;
+ unsigned EndPart = State.UF;
+ unsigned StartLane = 0;
+ unsigned EndLane = FirstLaneOnly ? 1 : State.VF.getKnownMinValue();
+ if (State.Instance) {
+ StartPart = State.Instance->Part;
+ EndPart = StartPart + 1;
+ StartLane = State.Instance->Lane.getKnownLane();
+ EndLane = StartLane + 1;
+ }
+ for (unsigned Part = StartPart; Part < EndPart; ++Part) {
+ Value *StartIdx0 = createStepForVF(Builder, IntStepTy, State.VF, Part);
+
+ if (!FirstLaneOnly && State.VF.isScalable()) {
+ auto *SplatStartIdx = Builder.CreateVectorSplat(State.VF, StartIdx0);
+ auto *InitVec = Builder.CreateAdd(SplatStartIdx, UnitStepVec);
+ if (BaseIVTy->isFloatingPointTy())
+ InitVec = Builder.CreateSIToFP(InitVec, VecIVTy);
+ auto *Mul = Builder.CreateBinOp(MulOp, InitVec, SplatStep);
+ auto *Add = Builder.CreateBinOp(AddOp, SplatIV, Mul);
+ State.set(this, Add, Part);
+ // It's useful to record the lane values too for the known minimum number
+ // of elements so we do those below. This improves the code quality when
+ // trying to extract the first element, for example.
+ }
+
+ if (BaseIVTy->isFloatingPointTy())
+ StartIdx0 = Builder.CreateSIToFP(StartIdx0, BaseIVTy);
+
+ for (unsigned Lane = StartLane; Lane < EndLane; ++Lane) {
+ Value *StartIdx = Builder.CreateBinOp(
+ AddOp, StartIdx0, getSignedIntOrFpConstant(BaseIVTy, Lane));
+ // The step returned by `createStepForVF` is a runtime-evaluated value
+ // when VF is scalable. Otherwise, it should be folded into a Constant.
+ assert((State.VF.isScalable() || isa<Constant>(StartIdx)) &&
+ "Expected StartIdx to be folded to a constant when VF is not "
+ "scalable");
+ auto *Mul = Builder.CreateBinOp(MulOp, StartIdx, Step);
+ auto *Add = Builder.CreateBinOp(AddOp, BaseIV, Mul);
+ State.set(this, Add, VPIteration(Part, Lane));
+ }
+ }
+}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPScalarIVStepsRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent;
printAsOperand(O, SlotTracker);
- O << Indent << "= SCALAR-STEPS ";
+ O << " = SCALAR-STEPS ";
printOperands(O, SlotTracker);
}
#endif
@@ -874,7 +1206,7 @@ void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent,
#endif
void VPBlendRecipe::execute(VPTransformState &State) {
- State.setDebugLocFromInst(Phi);
+ State.setDebugLocFrom(getDebugLoc());
// We know that all PHIs in non-header blocks are converted into
// selects, so we don't have to worry about the insertion order and we
// can just use the builder.
@@ -916,7 +1248,7 @@ void VPBlendRecipe::execute(VPTransformState &State) {
void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "BLEND ";
- Phi->printAsOperand(O, false);
+ printAsOperand(O, SlotTracker);
O << " =";
if (getNumIncomingValues() == 1) {
// Not a User of any mask: not really blending, this is a
@@ -942,14 +1274,14 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
- O << " reduce." << Instruction::getOpcodeName(RdxDesc->getOpcode()) << " (";
+ O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
getVecOp()->printAsOperand(O, SlotTracker);
if (getCondOp()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (RdxDesc->IntermediateStore)
+ if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
}
@@ -1093,12 +1425,12 @@ void VPWidenMemoryInstructionRecipe::print(raw_ostream &O, const Twine &Indent,
void VPCanonicalIVPHIRecipe::execute(VPTransformState &State) {
Value *Start = getStartValue()->getLiveInIRValue();
- PHINode *EntryPart = PHINode::Create(
- Start->getType(), 2, "index", &*State.CFG.PrevBB->getFirstInsertionPt());
+ PHINode *EntryPart = PHINode::Create(Start->getType(), 2, "index");
+ EntryPart->insertBefore(State.CFG.PrevBB->getFirstInsertionPt());
BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this);
EntryPart->addIncoming(Start, VectorPH);
- EntryPart->setDebugLoc(DL);
+ EntryPart->setDebugLoc(getDebugLoc());
for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part)
State.set(this, EntryPart, Part);
}
@@ -1108,7 +1440,8 @@ void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "EMIT ";
printAsOperand(O, SlotTracker);
- O << " = CANONICAL-INDUCTION";
+ O << " = CANONICAL-INDUCTION ";
+ printOperands(O, SlotTracker);
}
#endif
@@ -1221,8 +1554,8 @@ void VPFirstOrderRecurrencePHIRecipe::execute(VPTransformState &State) {
}
// Create a phi node for the new recurrence.
- PHINode *EntryPart = PHINode::Create(
- VecTy, 2, "vector.recur", &*State.CFG.PrevBB->getFirstInsertionPt());
+ PHINode *EntryPart = PHINode::Create(VecTy, 2, "vector.recur");
+ EntryPart->insertBefore(State.CFG.PrevBB->getFirstInsertionPt());
EntryPart->addIncoming(VectorInit, VectorPH);
State.set(this, EntryPart, 0);
}
@@ -1254,8 +1587,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
"recipe must be in the vector loop header");
unsigned LastPartForNewPhi = isOrdered() ? 1 : State.UF;
for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) {
- Value *EntryPart =
- PHINode::Create(VecTy, 2, "vec.phi", &*HeaderBB->getFirstInsertionPt());
+ Instruction *EntryPart = PHINode::Create(VecTy, 2, "vec.phi");
+ EntryPart->insertBefore(HeaderBB->getFirstInsertionPt());
State.set(this, EntryPart, Part);
}
@@ -1269,8 +1602,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
Value *Iden = nullptr;
RecurKind RK = RdxDesc.getRecurrenceKind();
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) ||
- RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) {
- // MinMax reduction have the start value as their identify.
+ RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
+ // MinMax and AnyOf reductions have the start value as their identity.
if (ScalarPHI) {
Iden = StartV;
} else {
@@ -1316,23 +1649,7 @@ void VPWidenPHIRecipe::execute(VPTransformState &State) {
assert(EnableVPlanNativePath &&
"Non-native vplans are not expected to have VPWidenPHIRecipes.");
- // Currently we enter here in the VPlan-native path for non-induction
- // PHIs where all control flow is uniform. We simply widen these PHIs.
- // Create a vector phi with no operands - the vector phi operands will be
- // set at the end of vector code generation.
- VPBasicBlock *Parent = getParent();
- VPRegionBlock *LoopRegion = Parent->getEnclosingLoopRegion();
- unsigned StartIdx = 0;
- // For phis in header blocks of loop regions, use the index of the value
- // coming from the preheader.
- if (LoopRegion->getEntryBasicBlock() == Parent) {
- for (unsigned I = 0; I < getNumOperands(); ++I) {
- if (getIncomingBlock(I) ==
- LoopRegion->getSinglePredecessor()->getExitingBasicBlock())
- StartIdx = I;
- }
- }
- Value *Op0 = State.get(getOperand(StartIdx), 0);
+ Value *Op0 = State.get(getOperand(0), 0);
Type *VecTy = Op0->getType();
Value *VecPhi = State.Builder.CreatePHI(VecTy, 2, "vec.phi");
State.set(this, VecPhi, 0);
@@ -1368,7 +1685,7 @@ void VPActiveLaneMaskPHIRecipe::execute(VPTransformState &State) {
PHINode *EntryPart =
State.Builder.CreatePHI(StartMask->getType(), 2, "active.lane.mask");
EntryPart->addIncoming(StartMask, VectorPH);
- EntryPart->setDebugLoc(DL);
+ EntryPart->setDebugLoc(getDebugLoc());
State.set(this, EntryPart, Part);
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 83bfdfd09d19..ea90ed4a21b1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -12,17 +12,22 @@
//===----------------------------------------------------------------------===//
#include "VPlanTransforms.h"
-#include "VPlanDominatorTree.h"
#include "VPRecipeBuilder.h"
+#include "VPlanAnalysis.h"
#include "VPlanCFG.h"
+#include "VPlanDominatorTree.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
using namespace llvm;
+using namespace llvm::PatternMatch;
+
void VPlanTransforms::VPInstructionsToVPRecipes(
VPlanPtr &Plan,
function_ref<const InductionDescriptor *(PHINode *)>
@@ -76,7 +81,7 @@ void VPlanTransforms::VPInstructionsToVPRecipes(
NewRecipe = new VPWidenSelectRecipe(*SI, Ingredient.operands());
} else if (auto *CI = dyn_cast<CastInst>(Inst)) {
NewRecipe = new VPWidenCastRecipe(
- CI->getOpcode(), Ingredient.getOperand(0), CI->getType(), CI);
+ CI->getOpcode(), Ingredient.getOperand(0), CI->getType(), *CI);
} else {
NewRecipe = new VPWidenRecipe(*Inst, Ingredient.operands());
}
@@ -158,17 +163,10 @@ static bool sinkScalarOperands(VPlan &Plan) {
// TODO: add ".cloned" suffix to name of Clone's VPValue.
Clone->insertBefore(SinkCandidate);
- for (auto *U : to_vector(SinkCandidate->getVPSingleValue()->users())) {
- auto *UI = cast<VPRecipeBase>(U);
- if (UI->getParent() == SinkTo)
- continue;
-
- for (unsigned Idx = 0; Idx != UI->getNumOperands(); Idx++) {
- if (UI->getOperand(Idx) != SinkCandidate->getVPSingleValue())
- continue;
- UI->setOperand(Idx, Clone);
- }
- }
+ SinkCandidate->getVPSingleValue()->replaceUsesWithIf(
+ Clone, [SinkTo](VPUser &U, unsigned) {
+ return cast<VPRecipeBase>(&U)->getParent() != SinkTo;
+ });
}
SinkCandidate->moveBefore(*SinkTo, SinkTo->getFirstNonPhi());
for (VPValue *Op : SinkCandidate->operands())
@@ -273,16 +271,10 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) {
VPValue *PredInst1 =
cast<VPPredInstPHIRecipe>(&Phi1ToMove)->getOperand(0);
VPValue *Phi1ToMoveV = Phi1ToMove.getVPSingleValue();
- for (VPUser *U : to_vector(Phi1ToMoveV->users())) {
- auto *UI = dyn_cast<VPRecipeBase>(U);
- if (!UI || UI->getParent() != Then2)
- continue;
- for (unsigned I = 0, E = U->getNumOperands(); I != E; ++I) {
- if (Phi1ToMoveV != U->getOperand(I))
- continue;
- U->setOperand(I, PredInst1);
- }
- }
+ Phi1ToMoveV->replaceUsesWithIf(PredInst1, [Then2](VPUser &U, unsigned) {
+ auto *UI = dyn_cast<VPRecipeBase>(&U);
+ return UI && UI->getParent() == Then2;
+ });
Phi1ToMove.moveBefore(*Merge2, Merge2->begin());
}
@@ -479,15 +471,45 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
// The recipes in the block are processed in reverse order, to catch chains
// of dead recipes.
for (VPRecipeBase &R : make_early_inc_range(reverse(*VPBB))) {
- if (R.mayHaveSideEffects() || any_of(R.definedValues(), [](VPValue *V) {
- return V->getNumUsers() > 0;
- }))
+ // A user keeps R alive:
+ if (any_of(R.definedValues(),
+ [](VPValue *V) { return V->getNumUsers(); }))
continue;
+
+ // Having side effects keeps R alive, but do remove conditional assume
+ // instructions as their conditions may be flattened.
+ auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
+ bool IsConditionalAssume =
+ RepR && RepR->isPredicated() &&
+ match(RepR->getUnderlyingInstr(), m_Intrinsic<Intrinsic::assume>());
+ if (R.mayHaveSideEffects() && !IsConditionalAssume)
+ continue;
+
R.eraseFromParent();
}
}
}
+static VPValue *createScalarIVSteps(VPlan &Plan, const InductionDescriptor &ID,
+ ScalarEvolution &SE, Instruction *TruncI,
+ Type *IVTy, VPValue *StartV,
+ VPValue *Step) {
+ VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
+ auto IP = HeaderVPBB->getFirstNonPhi();
+ VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV();
+ Type *TruncTy = TruncI ? TruncI->getType() : IVTy;
+ VPValue *BaseIV = CanonicalIV;
+ if (!CanonicalIV->isCanonical(ID.getKind(), StartV, Step, TruncTy)) {
+ BaseIV = new VPDerivedIVRecipe(ID, StartV, CanonicalIV, Step,
+ TruncI ? TruncI->getType() : nullptr);
+ HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP);
+ }
+
+ VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe(ID, BaseIV, Step);
+ HeaderVPBB->insert(Steps, IP);
+ return Steps;
+}
+
void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) {
SmallVector<VPRecipeBase *> ToRemove;
VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
@@ -501,36 +523,17 @@ void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) {
}))
continue;
- auto IP = HeaderVPBB->getFirstNonPhi();
- VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV();
- Type *ResultTy = WideIV->getPHINode()->getType();
- if (Instruction *TruncI = WideIV->getTruncInst())
- ResultTy = TruncI->getType();
const InductionDescriptor &ID = WideIV->getInductionDescriptor();
- VPValue *Step = WideIV->getStepValue();
- VPValue *BaseIV = CanonicalIV;
- if (!CanonicalIV->isCanonical(ID.getKind(), WideIV->getStartValue(), Step,
- ResultTy)) {
- BaseIV = new VPDerivedIVRecipe(ID, WideIV->getStartValue(), CanonicalIV,
- Step, ResultTy);
- HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP);
- }
-
- VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe(ID, BaseIV, Step);
- HeaderVPBB->insert(Steps, IP);
+ VPValue *Steps = createScalarIVSteps(
+ Plan, ID, SE, WideIV->getTruncInst(), WideIV->getPHINode()->getType(),
+ WideIV->getStartValue(), WideIV->getStepValue());
// Update scalar users of IV to use Step instead. Use SetVector to ensure
// the list of users doesn't contain duplicates.
- SetVector<VPUser *> Users(WideIV->user_begin(), WideIV->user_end());
- for (VPUser *U : Users) {
- if (HasOnlyVectorVFs && !U->usesScalars(WideIV))
- continue;
- for (unsigned I = 0, E = U->getNumOperands(); I != E; I++) {
- if (U->getOperand(I) != WideIV)
- continue;
- U->setOperand(I, Steps);
- }
- }
+ WideIV->replaceUsesWithIf(
+ Steps, [HasOnlyVectorVFs, WideIV](VPUser &U, unsigned) {
+ return !HasOnlyVectorVFs || U.usesScalars(WideIV);
+ });
}
}
@@ -778,3 +781,375 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
}
}
}
+
+/// Returns true if \p V is constant one.
+static bool isConstantOne(VPValue *V) {
+ if (!V->isLiveIn())
+ return false;
+ auto *C = dyn_cast<ConstantInt>(V->getLiveInIRValue());
+ return C && C->isOne();
+}
+
+/// Returns the llvm::Instruction opcode for \p R.
+static unsigned getOpcodeForRecipe(VPRecipeBase &R) {
+ if (auto *WidenR = dyn_cast<VPWidenRecipe>(&R))
+ return WidenR->getUnderlyingInstr()->getOpcode();
+ if (auto *WidenC = dyn_cast<VPWidenCastRecipe>(&R))
+ return WidenC->getOpcode();
+ if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R))
+ return RepR->getUnderlyingInstr()->getOpcode();
+ if (auto *VPI = dyn_cast<VPInstruction>(&R))
+ return VPI->getOpcode();
+ return 0;
+}
+
+/// Try to simplify recipe \p R.
+static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
+ switch (getOpcodeForRecipe(R)) {
+ case Instruction::Mul: {
+ VPValue *A = R.getOperand(0);
+ VPValue *B = R.getOperand(1);
+ if (isConstantOne(A))
+ return R.getVPSingleValue()->replaceAllUsesWith(B);
+ if (isConstantOne(B))
+ return R.getVPSingleValue()->replaceAllUsesWith(A);
+ break;
+ }
+ case Instruction::Trunc: {
+ VPRecipeBase *Ext = R.getOperand(0)->getDefiningRecipe();
+ if (!Ext)
+ break;
+ unsigned ExtOpcode = getOpcodeForRecipe(*Ext);
+ if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt)
+ break;
+ VPValue *A = Ext->getOperand(0);
+ VPValue *Trunc = R.getVPSingleValue();
+ Type *TruncTy = TypeInfo.inferScalarType(Trunc);
+ Type *ATy = TypeInfo.inferScalarType(A);
+ if (TruncTy == ATy) {
+ Trunc->replaceAllUsesWith(A);
+ } else if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) {
+ auto *VPC =
+ new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy);
+ VPC->insertBefore(&R);
+ Trunc->replaceAllUsesWith(VPC);
+ } else if (ATy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits()) {
+ auto *VPC = new VPWidenCastRecipe(Instruction::Trunc, A, TruncTy);
+ VPC->insertBefore(&R);
+ Trunc->replaceAllUsesWith(VPC);
+ }
+#ifndef NDEBUG
+    // Verify that the cached type info for both A and its users is still
+    // accurate by comparing it to freshly computed types.
+ VPTypeAnalysis TypeInfo2(TypeInfo.getContext());
+ assert(TypeInfo.inferScalarType(A) == TypeInfo2.inferScalarType(A));
+ for (VPUser *U : A->users()) {
+ auto *R = dyn_cast<VPRecipeBase>(U);
+ if (!R)
+ continue;
+ for (VPValue *VPV : R->definedValues())
+ assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV));
+ }
+#endif
+ break;
+ }
+ default:
+ break;
+ }
+}
+
+/// Try to simplify the recipes in \p Plan.
+static void simplifyRecipes(VPlan &Plan, LLVMContext &Ctx) {
+ ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
+ Plan.getEntry());
+ VPTypeAnalysis TypeInfo(Ctx);
+ for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
+ for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+ simplifyRecipe(R, TypeInfo);
+ }
+ }
+}
+
+void VPlanTransforms::truncateToMinimalBitwidths(
+ VPlan &Plan, const MapVector<Instruction *, uint64_t> &MinBWs,
+ LLVMContext &Ctx) {
+#ifndef NDEBUG
+ // Count the processed recipes and cross check the count later with MinBWs
+ // size, to make sure all entries in MinBWs have been handled.
+ unsigned NumProcessedRecipes = 0;
+#endif
+ // Keep track of created truncates, so they can be re-used. Note that we
+  // cannot use RAUW after creating a new truncate, as this could make
+ // other uses have different types for their operands, making them invalidly
+ // typed.
+ DenseMap<VPValue *, VPWidenCastRecipe *> ProcessedTruncs;
+ VPTypeAnalysis TypeInfo(Ctx);
+ VPBasicBlock *PH = Plan.getEntry();
+ for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+ vp_depth_first_deep(Plan.getVectorLoopRegion()))) {
+ for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+ if (!isa<VPWidenRecipe, VPWidenCastRecipe, VPReplicateRecipe,
+ VPWidenSelectRecipe>(&R))
+ continue;
+
+ VPValue *ResultVPV = R.getVPSingleValue();
+ auto *UI = cast_or_null<Instruction>(ResultVPV->getUnderlyingValue());
+ unsigned NewResSizeInBits = MinBWs.lookup(UI);
+ if (!NewResSizeInBits)
+ continue;
+
+#ifndef NDEBUG
+ NumProcessedRecipes++;
+#endif
+ // If the value wasn't vectorized, we must maintain the original scalar
+ // type. Skip those here, after incrementing NumProcessedRecipes. Also
+ // skip casts which do not need to be handled explicitly here, as
+ // redundant casts will be removed during recipe simplification.
+ if (isa<VPReplicateRecipe, VPWidenCastRecipe>(&R)) {
+#ifndef NDEBUG
+ // If any of the operands is a live-in and not used by VPWidenRecipe or
+ // VPWidenSelectRecipe, but in MinBWs, make sure it is counted as
+ // processed as well. When MinBWs is currently constructed, there is no
+ // information about whether recipes are widened or replicated and in
+        // case they are replicated the operands are not truncated. Counting
+        // them here ensures we do not miss any recipes in MinBWs.
+ // TODO: Remove once the analysis is done on VPlan.
+ for (VPValue *Op : R.operands()) {
+ if (!Op->isLiveIn())
+ continue;
+ auto *UV = dyn_cast_or_null<Instruction>(Op->getUnderlyingValue());
+ if (UV && MinBWs.contains(UV) && !ProcessedTruncs.contains(Op) &&
+ all_of(Op->users(), [](VPUser *U) {
+ return !isa<VPWidenRecipe, VPWidenSelectRecipe>(U);
+ })) {
+ // Add an entry to ProcessedTruncs to avoid counting the same
+ // operand multiple times.
+ ProcessedTruncs[Op] = nullptr;
+ NumProcessedRecipes += 1;
+ }
+ }
+#endif
+ continue;
+ }
+
+ Type *OldResTy = TypeInfo.inferScalarType(ResultVPV);
+ unsigned OldResSizeInBits = OldResTy->getScalarSizeInBits();
+ assert(OldResTy->isIntegerTy() && "only integer types supported");
+ if (OldResSizeInBits == NewResSizeInBits)
+ continue;
+ assert(OldResSizeInBits > NewResSizeInBits && "Nothing to shrink?");
+ (void)OldResSizeInBits;
+
+ auto *NewResTy = IntegerType::get(Ctx, NewResSizeInBits);
+
+ // Shrink operands by introducing truncates as needed.
+ unsigned StartIdx = isa<VPWidenSelectRecipe>(&R) ? 1 : 0;
+ for (unsigned Idx = StartIdx; Idx != R.getNumOperands(); ++Idx) {
+ auto *Op = R.getOperand(Idx);
+ unsigned OpSizeInBits =
+ TypeInfo.inferScalarType(Op)->getScalarSizeInBits();
+ if (OpSizeInBits == NewResSizeInBits)
+ continue;
+ assert(OpSizeInBits > NewResSizeInBits && "nothing to truncate");
+ auto [ProcessedIter, IterIsEmpty] =
+ ProcessedTruncs.insert({Op, nullptr});
+ VPWidenCastRecipe *NewOp =
+ IterIsEmpty
+ ? new VPWidenCastRecipe(Instruction::Trunc, Op, NewResTy)
+ : ProcessedIter->second;
+ R.setOperand(Idx, NewOp);
+ if (!IterIsEmpty)
+ continue;
+ ProcessedIter->second = NewOp;
+ if (!Op->isLiveIn()) {
+ NewOp->insertBefore(&R);
+ } else {
+ PH->appendRecipe(NewOp);
+#ifndef NDEBUG
+ auto *OpInst = dyn_cast<Instruction>(Op->getLiveInIRValue());
+ bool IsContained = MinBWs.contains(OpInst);
+ NumProcessedRecipes += IsContained;
+#endif
+ }
+ }
+
+ // Any wrapping introduced by shrinking this operation shouldn't be
+ // considered undefined behavior. So, we can't unconditionally copy
+ // arithmetic wrapping flags to VPW.
+ if (auto *VPW = dyn_cast<VPRecipeWithIRFlags>(&R))
+ VPW->dropPoisonGeneratingFlags();
+
+ // Extend result to original width.
+ auto *Ext = new VPWidenCastRecipe(Instruction::ZExt, ResultVPV, OldResTy);
+ Ext->insertAfter(&R);
+ ResultVPV->replaceAllUsesWith(Ext);
+ Ext->setOperand(0, ResultVPV);
+ }
+ }
+
+ assert(MinBWs.size() == NumProcessedRecipes &&
+ "some entries in MinBWs haven't been processed");
+}
+
+void VPlanTransforms::optimize(VPlan &Plan, ScalarEvolution &SE) {
+ removeRedundantCanonicalIVs(Plan);
+ removeRedundantInductionCasts(Plan);
+
+ optimizeInductions(Plan, SE);
+ simplifyRecipes(Plan, SE.getContext());
+ removeDeadRecipes(Plan);
+
+ createAndOptimizeReplicateRegions(Plan);
+
+ removeRedundantExpandSCEVRecipes(Plan);
+ mergeBlocksIntoPredecessors(Plan);
+}
+
+// Add a VPActiveLaneMaskPHIRecipe and related recipes to \p Plan and replace
+// the loop terminator with a branch-on-cond recipe with the negated
+// active-lane-mask as operand. Note that this turns the loop into an
+// uncountable one. Only the existing terminator is replaced, all other existing
+// recipes/users remain unchanged, except for poison-generating flags being
+// dropped from the canonical IV increment. Return the created
+// VPActiveLaneMaskPHIRecipe.
+//
+// The function uses the following definitions:
+//
+// %TripCount = DataWithControlFlowWithoutRuntimeCheck ?
+// calculate-trip-count-minus-VF (original TC) : original TC
+// %IncrementValue = DataWithControlFlowWithoutRuntimeCheck ?
+// CanonicalIVPhi : CanonicalIVIncrement
+// %StartV is the canonical induction start value.
+//
+// The function adds the following recipes:
+//
+// vector.ph:
+// %TripCount = calculate-trip-count-minus-VF (original TC)
+// [if DataWithControlFlowWithoutRuntimeCheck]
+// %EntryInc = canonical-iv-increment-for-part %StartV
+// %EntryALM = active-lane-mask %EntryInc, %TripCount
+//
+// vector.body:
+// ...
+// %P = active-lane-mask-phi [ %EntryALM, %vector.ph ], [ %ALM, %vector.body ]
+// ...
+// %InLoopInc = canonical-iv-increment-for-part %IncrementValue
+// %ALM = active-lane-mask %InLoopInc, TripCount
+// %Negated = Not %ALM
+// branch-on-cond %Negated
+//
+static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch(
+ VPlan &Plan, bool DataAndControlFlowWithoutRuntimeCheck) {
+ VPRegionBlock *TopRegion = Plan.getVectorLoopRegion();
+ VPBasicBlock *EB = TopRegion->getExitingBasicBlock();
+ auto *CanonicalIVPHI = Plan.getCanonicalIV();
+ VPValue *StartV = CanonicalIVPHI->getStartValue();
+
+ auto *CanonicalIVIncrement =
+ cast<VPInstruction>(CanonicalIVPHI->getBackedgeValue());
+ // TODO: Check if dropping the flags is needed if
+ // !DataAndControlFlowWithoutRuntimeCheck.
+ CanonicalIVIncrement->dropPoisonGeneratingFlags();
+ DebugLoc DL = CanonicalIVIncrement->getDebugLoc();
+ // We can't use StartV directly in the ActiveLaneMask VPInstruction, since
+ // we have to take unrolling into account. Each part needs to start at
+ // Part * VF
+ auto *VecPreheader = cast<VPBasicBlock>(TopRegion->getSinglePredecessor());
+ VPBuilder Builder(VecPreheader);
+
+ // Create the ActiveLaneMask instruction using the correct start values.
+ VPValue *TC = Plan.getTripCount();
+
+ VPValue *TripCount, *IncrementValue;
+ if (!DataAndControlFlowWithoutRuntimeCheck) {
+ // When the loop is guarded by a runtime overflow check for the loop
+ // induction variable increment by VF, we can increment the value before
+ // the get.active.lane mask and use the unmodified tripcount.
+ IncrementValue = CanonicalIVIncrement;
+ TripCount = TC;
+ } else {
+ // When avoiding a runtime check, the active.lane.mask inside the loop
+ // uses a modified trip count and the induction variable increment is
+ // done after the active.lane.mask intrinsic is called.
+ IncrementValue = CanonicalIVPHI;
+ TripCount = Builder.createNaryOp(VPInstruction::CalculateTripCountMinusVF,
+ {TC}, DL);
+ }
+ auto *EntryIncrement = Builder.createOverflowingOp(
+ VPInstruction::CanonicalIVIncrementForPart, {StartV}, {false, false}, DL,
+ "index.part.next");
+
+ // Create the active lane mask instruction in the VPlan preheader.
+ auto *EntryALM =
+ Builder.createNaryOp(VPInstruction::ActiveLaneMask, {EntryIncrement, TC},
+ DL, "active.lane.mask.entry");
+
+ // Now create the ActiveLaneMaskPhi recipe in the main loop using the
+ // preheader ActiveLaneMask instruction.
+ auto LaneMaskPhi = new VPActiveLaneMaskPHIRecipe(EntryALM, DebugLoc());
+ LaneMaskPhi->insertAfter(CanonicalIVPHI);
+
+ // Create the active lane mask for the next iteration of the loop before the
+ // original terminator.
+ VPRecipeBase *OriginalTerminator = EB->getTerminator();
+ Builder.setInsertPoint(OriginalTerminator);
+ auto *InLoopIncrement =
+ Builder.createOverflowingOp(VPInstruction::CanonicalIVIncrementForPart,
+ {IncrementValue}, {false, false}, DL);
+ auto *ALM = Builder.createNaryOp(VPInstruction::ActiveLaneMask,
+ {InLoopIncrement, TripCount}, DL,
+ "active.lane.mask.next");
+ LaneMaskPhi->addOperand(ALM);
+
+ // Replace the original terminator with BranchOnCond. We have to invert the
+ // mask here because a true condition means jumping to the exit block.
+ auto *NotMask = Builder.createNot(ALM, DL);
+ Builder.createNaryOp(VPInstruction::BranchOnCond, {NotMask}, DL);
+ OriginalTerminator->eraseFromParent();
+ return LaneMaskPhi;
+}
+
+void VPlanTransforms::addActiveLaneMask(
+ VPlan &Plan, bool UseActiveLaneMaskForControlFlow,
+ bool DataAndControlFlowWithoutRuntimeCheck) {
+ assert((!DataAndControlFlowWithoutRuntimeCheck ||
+ UseActiveLaneMaskForControlFlow) &&
+ "DataAndControlFlowWithoutRuntimeCheck implies "
+ "UseActiveLaneMaskForControlFlow");
+
+ auto FoundWidenCanonicalIVUser =
+ find_if(Plan.getCanonicalIV()->users(),
+ [](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); });
+ assert(FoundWidenCanonicalIVUser &&
+ "Must have widened canonical IV when tail folding!");
+ auto *WideCanonicalIV =
+ cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
+ VPRecipeBase *LaneMask;
+ if (UseActiveLaneMaskForControlFlow) {
+ LaneMask = addVPLaneMaskPhiAndUpdateExitBranch(
+ Plan, DataAndControlFlowWithoutRuntimeCheck);
+ } else {
+ LaneMask = new VPInstruction(VPInstruction::ActiveLaneMask,
+ {WideCanonicalIV, Plan.getTripCount()},
+ nullptr, "active.lane.mask");
+ LaneMask->insertAfter(WideCanonicalIV);
+ }
+
+ // Walk users of WideCanonicalIV and replace all compares of the form
+ // (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with an
+ // active-lane-mask.
+ VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
+ for (VPUser *U : SmallVector<VPUser *>(WideCanonicalIV->users())) {
+ auto *CompareToReplace = dyn_cast<VPInstruction>(U);
+ if (!CompareToReplace ||
+ CompareToReplace->getOpcode() != Instruction::ICmp ||
+ CompareToReplace->getPredicate() != CmpInst::ICMP_ULE ||
+ CompareToReplace->getOperand(1) != BTC)
+ continue;
+
+ assert(CompareToReplace->getOperand(0) == WideCanonicalIV &&
+ "WidenCanonicalIV must be the first operand of the compare");
+ CompareToReplace->replaceAllUsesWith(LaneMask->getVPSingleValue());
+ CompareToReplace->eraseFromParent();
+ }
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 3eccf6e9600d..e8a6da8c3205 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -37,12 +37,56 @@ struct VPlanTransforms {
GetIntOrFpInductionDescriptor,
ScalarEvolution &SE, const TargetLibraryInfo &TLI);
+ /// Sink users of fixed-order recurrences after the recipe defining their
+ /// previous value. Then introduce FirstOrderRecurrenceSplice VPInstructions
+ /// to combine the value from the recurrence phis and previous values. The
+ /// current implementation assumes all users can be sunk after the previous
+ /// value, which is enforced by earlier legality checks.
+ /// \returns true if all users of fixed-order recurrences could be re-arranged
+ /// as needed or false if it is not possible. In the latter case, \p Plan is
+ /// not valid.
+ static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder);
+
+ /// Clear NSW/NUW flags from reduction instructions if necessary.
+ static void clearReductionWrapFlags(VPlan &Plan);
+
+ /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the
+ /// resulting plan to \p BestVF and \p BestUF.
+ static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
+ unsigned BestUF,
+ PredicatedScalarEvolution &PSE);
+
+ /// Apply VPlan-to-VPlan optimizations to \p Plan, including induction recipe
+ /// optimizations, dead recipe removal, replicate region optimizations and
+ /// block merging.
+ static void optimize(VPlan &Plan, ScalarEvolution &SE);
+
/// Wrap predicated VPReplicateRecipes with a mask operand in an if-then
/// region block and remove the mask operand. Optimize the created regions by
/// iteratively sinking scalar operands into the region, followed by merging
/// regions until no improvements are remaining.
static void createAndOptimizeReplicateRegions(VPlan &Plan);
+ /// Replace (ICMP_ULE, wide canonical IV, backedge-taken-count) checks with an
+ /// (active-lane-mask recipe, wide canonical IV, trip-count). If \p
+ /// UseActiveLaneMaskForControlFlow is true, introduce an
+ /// VPActiveLaneMaskPHIRecipe. If \p DataAndControlFlowWithoutRuntimeCheck is
+ /// true, no minimum-iteration runtime check will be created (during skeleton
+ /// creation) and instead it is handled using active-lane-mask. \p
+ /// DataAndControlFlowWithoutRuntimeCheck implies \p
+ /// UseActiveLaneMaskForControlFlow.
+ static void addActiveLaneMask(VPlan &Plan,
+ bool UseActiveLaneMaskForControlFlow,
+ bool DataAndControlFlowWithoutRuntimeCheck);
+
+ /// Insert truncates and extends for any truncated recipe. Redundant casts
+ /// will be folded later.
+ static void
+ truncateToMinimalBitwidths(VPlan &Plan,
+ const MapVector<Instruction *, uint64_t> &MinBWs,
+ LLVMContext &Ctx);
+
+private:
/// Remove redundant VPBasicBlocks by merging them into their predecessor if
/// the predecessor has a single successor.
static bool mergeBlocksIntoPredecessors(VPlan &Plan);
@@ -71,24 +115,6 @@ struct VPlanTransforms {
/// them with already existing recipes expanding the same SCEV expression.
static void removeRedundantExpandSCEVRecipes(VPlan &Plan);
- /// Sink users of fixed-order recurrences after the recipe defining their
- /// previous value. Then introduce FirstOrderRecurrenceSplice VPInstructions
- /// to combine the value from the recurrence phis and previous values. The
- /// current implementation assumes all users can be sunk after the previous
- /// value, which is enforced by earlier legality checks.
- /// \returns true if all users of fixed-order recurrences could be re-arranged
- /// as needed or false if it is not possible. In the latter case, \p Plan is
- /// not valid.
- static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder);
-
- /// Clear NSW/NUW flags from reduction instructions if necessary.
- static void clearReductionWrapFlags(VPlan &Plan);
-
- /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the
- /// resulting plan to \p BestVF and \p BestUF.
- static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
- unsigned BestUF,
- PredicatedScalarEvolution &PSE);
};
} // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index ac110bb3b0ef..e5ca52755dd2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -163,6 +163,13 @@ public:
void replaceAllUsesWith(VPValue *New);
+ /// Go through the uses list for this VPValue and make each use point to \p
+ /// New if the callback ShouldReplace returns true for the given use specified
+ /// by a pair of (VPUser, the use index).
+ void replaceUsesWithIf(
+ VPValue *New,
+ llvm::function_ref<bool(VPUser &U, unsigned Idx)> ShouldReplace);
+
/// Returns the recipe defining this VPValue or nullptr if it is not defined
/// by a recipe, i.e. is a live-in.
VPRecipeBase *getDefiningRecipe();
@@ -296,6 +303,14 @@ public:
"Op must be an operand of the recipe");
return false;
}
+
+ /// Returns true if the VPUser only uses the first part of operand \p Op.
+ /// Conservatively returns false.
+ virtual bool onlyFirstPartUsed(const VPValue *Op) const {
+ assert(is_contained(operands(), Op) &&
+ "Op must be an operand of the recipe");
+ return false;
+ }
};
/// This class augments a recipe with a set of VPValues defined by the recipe.
@@ -325,7 +340,7 @@ class VPDef {
assert(V->Def == this && "can only remove VPValue linked with this VPDef");
assert(is_contained(DefinedValues, V) &&
"VPValue to remove must be in DefinedValues");
- erase_value(DefinedValues, V);
+ llvm::erase(DefinedValues, V);
V->Def = nullptr;
}
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 13464c9d3496..f18711ba30b7 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -13,6 +13,8 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -28,6 +30,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <numeric>
+#include <queue>
#define DEBUG_TYPE "vector-combine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"
@@ -100,8 +103,9 @@ private:
Instruction &I);
bool foldExtractExtract(Instruction &I);
bool foldInsExtFNeg(Instruction &I);
- bool foldBitcastShuf(Instruction &I);
+ bool foldBitcastShuffle(Instruction &I);
bool scalarizeBinopOrCmp(Instruction &I);
+ bool scalarizeVPIntrinsic(Instruction &I);
bool foldExtractedCmps(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
@@ -258,8 +262,8 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
// It is safe and potentially profitable to load a vector directly:
// inselt undef, load Scalar, 0 --> load VecPtr
IRBuilder<> Builder(Load);
- Value *CastedPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
- SrcPtr, MinVecTy->getPointerTo(AS));
+ Value *CastedPtr =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
VecLd = Builder.CreateShuffleVector(VecLd, Mask);
@@ -321,7 +325,7 @@ bool VectorCombine::widenSubvectorLoad(Instruction &I) {
IRBuilder<> Builder(Load);
Value *CastedPtr =
- Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Ty->getPointerTo(AS));
+ Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
replaceValue(I, *VecLd);
++NumVecLoad;
@@ -677,7 +681,7 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
-bool VectorCombine::foldBitcastShuf(Instruction &I) {
+bool VectorCombine::foldBitcastShuffle(Instruction &I) {
Value *V;
ArrayRef<int> Mask;
if (!match(&I, m_BitCast(
@@ -687,35 +691,43 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
// 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
// scalable type is unknown; Second, we cannot reason if the narrowed shuffle
// mask for scalable type is a splat or not.
- // 2) Disallow non-vector casts and length-changing shuffles.
+ // 2) Disallow non-vector casts.
// TODO: We could allow any shuffle.
+ auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
- if (!SrcTy || I.getOperand(0)->getType() != SrcTy)
+ if (!DestTy || !SrcTy)
+ return false;
+
+ unsigned DestEltSize = DestTy->getScalarSizeInBits();
+ unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
+ if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
return false;
- auto *DestTy = cast<FixedVectorType>(I.getType());
- unsigned DestNumElts = DestTy->getNumElements();
- unsigned SrcNumElts = SrcTy->getNumElements();
SmallVector<int, 16> NewMask;
- if (SrcNumElts <= DestNumElts) {
+ if (DestEltSize <= SrcEltSize) {
// The bitcast is from wide to narrow/equal elements. The shuffle mask can
// always be expanded to the equivalent form choosing narrower elements.
- assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
- unsigned ScaleFactor = DestNumElts / SrcNumElts;
+ assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
+ unsigned ScaleFactor = SrcEltSize / DestEltSize;
narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
} else {
// The bitcast is from narrow elements to wide elements. The shuffle mask
// must choose consecutive elements to allow casting first.
- assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
- unsigned ScaleFactor = SrcNumElts / DestNumElts;
+ assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
+ unsigned ScaleFactor = DestEltSize / SrcEltSize;
if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
return false;
}
+ // Bitcast the shuffle src - keep its original width but using the destination
+ // scalar type.
+ unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
+ auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
+
// The new shuffle must not cost more than the old shuffle. The bitcast is
// moved ahead of the shuffle, so assume that it has the same cost as before.
InstructionCost DestCost = TTI.getShuffleCost(
- TargetTransformInfo::SK_PermuteSingleSrc, DestTy, NewMask);
+ TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask);
InstructionCost SrcCost =
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
if (DestCost > SrcCost || !DestCost.isValid())
@@ -723,12 +735,131 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
// bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
++NumShufOfBitcast;
- Value *CastV = Builder.CreateBitCast(V, DestTy);
+ Value *CastV = Builder.CreateBitCast(V, ShuffleTy);
Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
replaceValue(I, *Shuf);
return true;
}
+/// VP Intrinsics whose vector operands are both splat values may be simplified
+/// into the scalar version of the operation and the result splatted. This
+/// can lead to scalarization down the line.
+bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
+ if (!isa<VPIntrinsic>(I))
+ return false;
+ VPIntrinsic &VPI = cast<VPIntrinsic>(I);
+ Value *Op0 = VPI.getArgOperand(0);
+ Value *Op1 = VPI.getArgOperand(1);
+
+ if (!isSplatValue(Op0) || !isSplatValue(Op1))
+ return false;
+
+ // Check getSplatValue early in this function, to avoid doing unnecessary
+ // work.
+ Value *ScalarOp0 = getSplatValue(Op0);
+ Value *ScalarOp1 = getSplatValue(Op1);
+ if (!ScalarOp0 || !ScalarOp1)
+ return false;
+
+ // For the binary VP intrinsics supported here, the result on disabled lanes
+ // is a poison value. For now, only do this simplification if all lanes
+ // are active.
+ // TODO: Relax the condition that all lanes are active by using insertelement
+ // on inactive lanes.
+ auto IsAllTrueMask = [](Value *MaskVal) {
+ if (Value *SplattedVal = getSplatValue(MaskVal))
+ if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
+ return ConstValue->isAllOnesValue();
+ return false;
+ };
+ if (!IsAllTrueMask(VPI.getArgOperand(2)))
+ return false;
+
+ // Check to make sure we support scalarization of the intrinsic
+ Intrinsic::ID IntrID = VPI.getIntrinsicID();
+ if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
+ return false;
+
+ // Calculate cost of splatting both operands into vectors and the vector
+ // intrinsic
+ VectorType *VecTy = cast<VectorType>(VPI.getType());
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost SplatCost =
+ TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
+ TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);
+
+ // Calculate the cost of the VP Intrinsic
+ SmallVector<Type *, 4> Args;
+ for (Value *V : VPI.args())
+ Args.push_back(V->getType());
+ IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
+ InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
+ InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
+
+ // Determine scalar opcode
+ std::optional<unsigned> FunctionalOpcode =
+ VPI.getFunctionalOpcode();
+ std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
+ if (!FunctionalOpcode) {
+ ScalarIntrID = VPI.getFunctionalIntrinsicID();
+ if (!ScalarIntrID)
+ return false;
+ }
+
+ // Calculate cost of scalarizing
+ InstructionCost ScalarOpCost = 0;
+ if (ScalarIntrID) {
+ IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
+ ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
+ } else {
+ ScalarOpCost =
+ TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
+ }
+
+ // The existing splats may be kept around if other instructions use them.
+ InstructionCost CostToKeepSplats =
+ (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
+ InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
+
+ LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
+ << "\n");
+ LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
+ << ", Cost of scalarizing:" << NewCost << "\n");
+
+ // We want to scalarize unless the vector variant actually has lower cost.
+ if (OldCost < NewCost || !NewCost.isValid())
+ return false;
+
+ // Scalarize the intrinsic
+ ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
+ Value *EVL = VPI.getArgOperand(3);
+ const DataLayout &DL = VPI.getModule()->getDataLayout();
+
+ // If the VP op might introduce UB or poison, we can scalarize it provided
+ // that we know the EVL > 0: If the EVL is zero, then the original VP op
+ // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
+ // scalarizing it.
+ bool SafeToSpeculate;
+ if (ScalarIntrID)
+ SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID)
+ .hasFnAttr(Attribute::AttrKind::Speculatable);
+ else
+ SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
+ *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
+ if (!SafeToSpeculate && !isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT))
+ return false;
+
+ Value *ScalarVal =
+ ScalarIntrID
+ ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
+ {ScalarOp0, ScalarOp1})
+ : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
+ ScalarOp0, ScalarOp1);
+
+ replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
+ return true;
+}
+
/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
@@ -1013,19 +1144,24 @@ public:
/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
-static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
- Value *Idx, Instruction *CtxI,
+static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
+ Instruction *CtxI,
AssumptionCache &AC,
const DominatorTree &DT) {
+ // We do checks for both fixed vector types and scalable vector types.
+ // This is the number of elements of fixed vector types,
+ // or the minimum number of elements of scalable vector types.
+ uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
+
if (auto *C = dyn_cast<ConstantInt>(Idx)) {
- if (C->getValue().ult(VecTy->getNumElements()))
+ if (C->getValue().ult(NumElements))
return ScalarizationResult::safe();
return ScalarizationResult::unsafe();
}
unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
APInt Zero(IntWidth, 0);
- APInt MaxElts(IntWidth, VecTy->getNumElements());
+ APInt MaxElts(IntWidth, NumElements);
ConstantRange ValidIndices(Zero, MaxElts);
ConstantRange IdxRange(IntWidth, true);
@@ -1074,8 +1210,7 @@ static Align computeAlignmentAfterScalarization(Align VectorAlignment,
// store i32 %b, i32* %1
bool VectorCombine::foldSingleElementStore(Instruction &I) {
auto *SI = cast<StoreInst>(&I);
- if (!SI->isSimple() ||
- !isa<FixedVectorType>(SI->getValueOperand()->getType()))
+ if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
return false;
// TODO: Combine more complicated patterns (multiple insert) by referencing
@@ -1089,13 +1224,13 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
return false;
if (auto *Load = dyn_cast<LoadInst>(Source)) {
- auto VecTy = cast<FixedVectorType>(SI->getValueOperand()->getType());
+ auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
const DataLayout &DL = I.getModule()->getDataLayout();
Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
// Don't optimize for atomic/volatile load or store. Ensure memory is not
// modified between, vector type matches store size, and index is inbounds.
if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
- !DL.typeSizeEqualsStoreSize(Load->getType()) ||
+ !DL.typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
SrcAddr != SI->getPointerOperand()->stripPointerCasts())
return false;
@@ -1130,19 +1265,26 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (!match(&I, m_Load(m_Value(Ptr))))
return false;
- auto *FixedVT = cast<FixedVectorType>(I.getType());
+ auto *VecTy = cast<VectorType>(I.getType());
auto *LI = cast<LoadInst>(&I);
const DataLayout &DL = I.getModule()->getDataLayout();
- if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+ if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy->getScalarType()))
return false;
InstructionCost OriginalCost =
- TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+ TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace());
InstructionCost ScalarizedCost = 0;
Instruction *LastCheckedInst = LI;
unsigned NumInstChecked = 0;
+ DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+ auto FailureGuard = make_scope_exit([&]() {
+ // If the transform is aborted, discard the ScalarizationResults.
+ for (auto &Pair : NeedFreeze)
+ Pair.second.discard();
+ });
+
// Check if all users of the load are extracts with no memory modifications
// between the load and the extract. Compute the cost of both the original
// code and the scalarized version.
@@ -1151,9 +1293,6 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (!UI || UI->getParent() != LI->getParent())
return false;
- if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
- return false;
-
// Check if any instruction between the load and the extract may modify
// memory.
if (LastCheckedInst->comesBefore(UI)) {
@@ -1168,22 +1307,23 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
LastCheckedInst = UI;
}
- auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
- if (!ScalarIdx.isSafe()) {
- // TODO: Freeze index if it is safe to do so.
- ScalarIdx.discard();
+ auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
+ if (ScalarIdx.isUnsafe())
return false;
+ if (ScalarIdx.isSafeWithFreeze()) {
+ NeedFreeze.try_emplace(UI, ScalarIdx);
+ ScalarIdx.discard();
}
auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
OriginalCost +=
- TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+ TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
Index ? Index->getZExtValue() : -1);
ScalarizedCost +=
- TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+ TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
Align(1), LI->getPointerAddressSpace());
- ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+ ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
}
if (ScalarizedCost >= OriginalCost)
@@ -1192,21 +1332,27 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
// Replace extracts with narrow scalar loads.
for (User *U : LI->users()) {
auto *EI = cast<ExtractElementInst>(U);
- Builder.SetInsertPoint(EI);
-
Value *Idx = EI->getOperand(1);
+
+ // Insert 'freeze' for poison indexes.
+ auto It = NeedFreeze.find(EI);
+ if (It != NeedFreeze.end())
+ It->second.freeze(Builder, *cast<Instruction>(Idx));
+
+ Builder.SetInsertPoint(EI);
Value *GEP =
- Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+ Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
- FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+ VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
Align ScalarOpAlignment = computeAlignmentAfterScalarization(
- LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+ LI->getAlign(), VecTy->getElementType(), Idx, DL);
NewLoad->setAlignment(ScalarOpAlignment);
replaceValue(*EI, *NewLoad);
}
+ FailureGuard.release();
return true;
}
@@ -1340,21 +1486,28 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
if (!ShuffleInputType)
return false;
- int NumInputElts = ShuffleInputType->getNumElements();
+ unsigned NumInputElts = ShuffleInputType->getNumElements();
// Find the mask from sorting the lanes into order. This is most likely to
// become a identity or concat mask. Undef elements are pushed to the end.
SmallVector<int> ConcatMask;
Shuffle->getShuffleMask(ConcatMask);
sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
+ // In the case of a truncating shuffle it's possible for the mask
+ // to have an index greater than the size of the resulting vector.
+ // This requires special handling.
+ bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
bool UsesSecondVec =
- any_of(ConcatMask, [&](int M) { return M >= NumInputElts; });
+ any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
+
+ FixedVectorType *VecTyForCost =
+ (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType;
InstructionCost OldCost = TTI.getShuffleCost(
- UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
- Shuffle->getShuffleMask());
+ UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
+ VecTyForCost, Shuffle->getShuffleMask());
InstructionCost NewCost = TTI.getShuffleCost(
- UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
- ConcatMask);
+ UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
+ VecTyForCost, ConcatMask);
LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
<< "\n");
@@ -1657,16 +1810,16 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
return SSV->getOperand(Op);
return SV->getOperand(Op);
};
- Builder.SetInsertPoint(SVI0A->getInsertionPointAfterDef());
+ Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
GetShuffleOperand(SVI0A, 1), V1A);
- Builder.SetInsertPoint(SVI0B->getInsertionPointAfterDef());
+ Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
GetShuffleOperand(SVI0B, 1), V1B);
- Builder.SetInsertPoint(SVI1A->getInsertionPointAfterDef());
+ Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
GetShuffleOperand(SVI1A, 1), V2A);
- Builder.SetInsertPoint(SVI1B->getInsertionPointAfterDef());
+ Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
GetShuffleOperand(SVI1B, 1), V2B);
Builder.SetInsertPoint(Op0);
@@ -1723,9 +1876,6 @@ bool VectorCombine::run() {
case Instruction::ShuffleVector:
MadeChange |= widenSubvectorLoad(I);
break;
- case Instruction::Load:
- MadeChange |= scalarizeLoadExtract(I);
- break;
default:
break;
}
@@ -1733,13 +1883,15 @@ bool VectorCombine::run() {
// This transform works with scalable and fixed vectors
// TODO: Identify and allow other scalable transforms
- if (isa<VectorType>(I.getType()))
+ if (isa<VectorType>(I.getType())) {
MadeChange |= scalarizeBinopOrCmp(I);
+ MadeChange |= scalarizeLoadExtract(I);
+ MadeChange |= scalarizeVPIntrinsic(I);
+ }
if (Opcode == Instruction::Store)
MadeChange |= foldSingleElementStore(I);
-
// If this is an early pipeline invocation of this pass, we are done.
if (TryEarlyFoldsOnly)
return;
@@ -1758,7 +1910,7 @@ bool VectorCombine::run() {
MadeChange |= foldSelectShuffle(I);
break;
case Instruction::BitCast:
- MadeChange |= foldBitcastShuf(I);
+ MadeChange |= foldBitcastShuffle(I);
break;
}
} else {