diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
| commit | c0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch) | |
| tree | f42add1021b9f2ac6a69ac7cf6c4499962739a45 /llvm/lib/Transforms/AggressiveInstCombine | |
| parent | 344a3780b2e33f6ca763666c380202b18aab72a3 (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/AggressiveInstCombine')
3 files changed, 128 insertions, 24 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 85abbf6d86e0..7243e39c9029 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -18,6 +18,7 @@ #include "llvm-c/Transforms/AggressiveInstCombine.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -205,8 +206,8 @@ struct MaskOps { bool FoundAnd1; MaskOps(unsigned BitWidth, bool MatchAnds) - : Root(nullptr), Mask(APInt::getNullValue(BitWidth)), - MatchAndChain(MatchAnds), FoundAnd1(false) {} + : Root(nullptr), Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds), + FoundAnd1(false) {} }; /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a @@ -377,10 +378,10 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { // Also, we want to avoid matching partial patterns. // TODO: It would be more efficient if we removed dead instructions // iteratively in this loop rather than waiting until the end. - for (Instruction &I : make_range(BB.rbegin(), BB.rend())) { + for (Instruction &I : llvm::reverse(BB)) { MadeChange |= foldAnyOrAllBitsSet(I); MadeChange |= foldGuardedFunnelShift(I, DT); - MadeChange |= tryToRecognizePopCount(I); + MadeChange |= tryToRecognizePopCount(I); } } @@ -394,10 +395,11 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { /// This is the entry point for all transforms. Pass manager differences are /// handled in the callers of this function. -static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) { +static bool runImpl(Function &F, AssumptionCache &AC, TargetLibraryInfo &TLI, + DominatorTree &DT) { bool MadeChange = false; const DataLayout &DL = F.getParent()->getDataLayout(); - TruncInstCombine TIC(TLI, DL, DT); + TruncInstCombine TIC(AC, TLI, DL, DT); MadeChange |= TIC.run(F); MadeChange |= foldUnusualPatterns(F, DT); return MadeChange; @@ -406,6 +408,7 @@ static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) { void AggressiveInstCombinerLegacyPass::getAnalysisUsage( AnalysisUsage &AU) const { AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); @@ -415,16 +418,18 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage( } bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) { + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - return runImpl(F, TLI, DT); + return runImpl(F, AC, TLI, DT); } PreservedAnalyses AggressiveInstCombinePass::run(Function &F, FunctionAnalysisManager &AM) { + auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - if (!runImpl(F, TLI, DT)) { + if (!runImpl(F, AC, TLI, DT)) { // No changes, all analyses are preserved. return PreservedAnalyses::all(); } @@ -438,6 +443,7 @@ char AggressiveInstCombinerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass, "aggressive-instcombine", "Combine pattern based expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine", diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h index 42bcadfc7dcd..5d69e26d6ecc 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h @@ -17,6 +17,8 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; @@ -39,16 +41,18 @@ using namespace llvm; //===----------------------------------------------------------------------===// namespace llvm { - class DataLayout; - class DominatorTree; - class Function; - class Instruction; - class TargetLibraryInfo; - class TruncInst; - class Type; - class Value; +class AssumptionCache; +class DataLayout; +class DominatorTree; +class Function; +class Instruction; +class TargetLibraryInfo; +class TruncInst; +class Type; +class Value; class TruncInstCombine { + AssumptionCache &AC; TargetLibraryInfo &TLI; const DataLayout &DL; const DominatorTree &DT; @@ -75,9 +79,9 @@ class TruncInstCombine { MapVector<Instruction *, Info> InstInfoMap; public: - TruncInstCombine(TargetLibraryInfo &TLI, const DataLayout &DL, - const DominatorTree &DT) - : TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {} + TruncInstCombine(AssumptionCache &AC, TargetLibraryInfo &TLI, + const DataLayout &DL, const DominatorTree &DT) + : AC(AC), TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {} /// Perform TruncInst pattern optimization on given function. bool run(Function &F); @@ -104,6 +108,18 @@ private: /// to be reduced. Type *getBestTruncatedType(); + KnownBits computeKnownBits(const Value *V) const { + return llvm::computeKnownBits(V, DL, /*Depth=*/0, &AC, + /*CtxI=*/cast<Instruction>(CurrentTruncInst), + &DT); + } + + unsigned ComputeNumSignBits(const Value *V) const { + return llvm::ComputeNumSignBits( + V, DL, /*Depth=*/0, &AC, /*CtxI=*/cast<Instruction>(CurrentTruncInst), + &DT); + } + /// Given a \p V value and a \p SclTy scalar type return the generated reduced /// value of \p V based on the type \p SclTy. /// diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 16b82219e8ca..abac3f801a22 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; @@ -61,9 +62,18 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: + case Instruction::InsertElement: Ops.push_back(I->getOperand(0)); Ops.push_back(I->getOperand(1)); break; + case Instruction::ExtractElement: + Ops.push_back(I->getOperand(0)); + break; case Instruction::Select: Ops.push_back(I->getOperand(1)); Ops.push_back(I->getOperand(2)); @@ -127,6 +137,13 @@ bool TruncInstCombine::buildTruncExpressionDag() { case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: + case Instruction::InsertElement: + case Instruction::ExtractElement: case Instruction::Select: { SmallVector<Value *, 2> Operands; getRelevantOperands(I, Operands); @@ -135,10 +152,9 @@ bool TruncInstCombine::buildTruncExpressionDag() { } default: // TODO: Can handle more cases here: - // 1. shufflevector, extractelement, insertelement - // 2. udiv, urem - // 3. shl, lshr, ashr - // 4. phi node(and loop handling) + // 1. shufflevector + // 2. sdiv, srem + // 3. phi node(and loop handling) // ... return false; } @@ -270,6 +286,50 @@ Type *TruncInstCombine::getBestTruncatedType() { unsigned OrigBitWidth = CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); + // Initialize MinBitWidth for shift instructions with the minimum number + // that is greater than shift amount (i.e. shift amount + 1). + // For `lshr` adjust MinBitWidth so that all potentially truncated + // bits of the value-to-be-shifted are zeros. + // For `ashr` adjust MinBitWidth so that all potentially truncated + // bits of the value-to-be-shifted are sign bits (all zeros or ones) + // and even one (first) untruncated bit is sign bit. + // Exit early if MinBitWidth is not less than original bitwidth. + for (auto &Itr : InstInfoMap) { + Instruction *I = Itr.first; + if (I->isShift()) { + KnownBits KnownRHS = computeKnownBits(I->getOperand(1)); + unsigned MinBitWidth = KnownRHS.getMaxValue() + .uadd_sat(APInt(OrigBitWidth, 1)) + .getLimitedValue(OrigBitWidth); + if (MinBitWidth == OrigBitWidth) + return nullptr; + if (I->getOpcode() == Instruction::LShr) { + KnownBits KnownLHS = computeKnownBits(I->getOperand(0)); + MinBitWidth = + std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits()); + } + if (I->getOpcode() == Instruction::AShr) { + unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0)); + MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1); + } + if (MinBitWidth >= OrigBitWidth) + return nullptr; + Itr.second.MinBitWidth = MinBitWidth; + } + if (I->getOpcode() == Instruction::UDiv || + I->getOpcode() == Instruction::URem) { + unsigned MinBitWidth = 0; + for (const auto &Op : I->operands()) { + KnownBits Known = computeKnownBits(Op); + MinBitWidth = + std::max(Known.getMaxValue().getActiveBits(), MinBitWidth); + if (MinBitWidth >= OrigBitWidth) + return nullptr; + } + Itr.second.MinBitWidth = MinBitWidth; + } + } + // Calculate minimum allowed bit-width allowed for shrinking the currently // visited truncate's operand. unsigned MinBitWidth = getMinBitWidth(); @@ -356,10 +416,32 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { case Instruction::Mul: case Instruction::And: case Instruction::Or: - case Instruction::Xor: { + case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: { Value *LHS = getReducedOperand(I->getOperand(0), SclTy); Value *RHS = getReducedOperand(I->getOperand(1), SclTy); Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); + // Preserve `exact` flag since truncation doesn't change exactness + if (auto *PEO = dyn_cast<PossiblyExactOperator>(I)) + if (auto *ResI = dyn_cast<Instruction>(Res)) + ResI->setIsExact(PEO->isExact()); + break; + } + case Instruction::ExtractElement: { + Value *Vec = getReducedOperand(I->getOperand(0), SclTy); + Value *Idx = I->getOperand(1); + Res = Builder.CreateExtractElement(Vec, Idx); + break; + } + case Instruction::InsertElement: { + Value *Vec = getReducedOperand(I->getOperand(0), SclTy); + Value *NewElt = getReducedOperand(I->getOperand(1), SclTy); + Value *Idx = I->getOperand(2); + Res = Builder.CreateInsertElement(Vec, NewElt, Idx); break; } case Instruction::Select: { |
