summary | refs | log | tree | commit | diff
path: root/llvm/lib/Transforms/AggressiveInstCombine
diff options
context:
space:
mode:
author: Dimitry Andric <dim@FreeBSD.org> 2021-11-19 20:06:13 +0000
committer: Dimitry Andric <dim@FreeBSD.org> 2021-11-19 20:06:13 +0000
commit: c0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch)
tree: f42add1021b9f2ac6a69ac7cf6c4499962739a45 /llvm/lib/Transforms/AggressiveInstCombine
parent: 344a3780b2e33f6ca763666c380202b18aab72a3 (diff)
Diffstat (limited to 'llvm/lib/Transforms/AggressiveInstCombine')
-rw-r--r-- llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp | 22
-rw-r--r-- llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h | 38
-rw-r--r-- llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp | 92
3 files changed, 128 insertions, 24 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 85abbf6d86e0..7243e39c9029 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -18,6 +18,7 @@
#include "llvm-c/Transforms/AggressiveInstCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
@@ -205,8 +206,8 @@ struct MaskOps {
bool FoundAnd1;
MaskOps(unsigned BitWidth, bool MatchAnds)
- : Root(nullptr), Mask(APInt::getNullValue(BitWidth)),
- MatchAndChain(MatchAnds), FoundAnd1(false) {}
+ : Root(nullptr), Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds),
+ FoundAnd1(false) {}
};
/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
@@ -377,10 +378,10 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
// Also, we want to avoid matching partial patterns.
// TODO: It would be more efficient if we removed dead instructions
// iteratively in this loop rather than waiting until the end.
- for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
+ for (Instruction &I : llvm::reverse(BB)) {
MadeChange |= foldAnyOrAllBitsSet(I);
MadeChange |= foldGuardedFunnelShift(I, DT);
- MadeChange |= tryToRecognizePopCount(I);
+ MadeChange |= tryToRecognizePopCount(I);
}
}
@@ -394,10 +395,11 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
-static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
+static bool runImpl(Function &F, AssumptionCache &AC, TargetLibraryInfo &TLI,
+ DominatorTree &DT) {
bool MadeChange = false;
const DataLayout &DL = F.getParent()->getDataLayout();
- TruncInstCombine TIC(TLI, DL, DT);
+ TruncInstCombine TIC(AC, TLI, DL, DT);
MadeChange |= TIC.run(F);
MadeChange |= foldUnusualPatterns(F, DT);
return MadeChange;
@@ -406,6 +408,7 @@ static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
AnalysisUsage &AU) const {
AU.setPreservesCFG();
+ AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addPreserved<AAResultsWrapperPass>();
@@ -415,16 +418,18 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
}
bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
+ auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- return runImpl(F, TLI, DT);
+ return runImpl(F, AC, TLI, DT);
}
PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
FunctionAnalysisManager &AM) {
+ auto &AC = AM.getResult<AssumptionAnalysis>(F);
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
- if (!runImpl(F, TLI, DT)) {
+ if (!runImpl(F, AC, TLI, DT)) {
// No changes, all analyses are preserved.
return PreservedAnalyses::all();
}
@@ -438,6 +443,7 @@ char AggressiveInstCombinerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass,
"aggressive-instcombine",
"Combine pattern based expressions", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine",
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
index 42bcadfc7dcd..5d69e26d6ecc 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
@@ -17,6 +17,8 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Support/KnownBits.h"
using namespace llvm;
@@ -39,16 +41,18 @@ using namespace llvm;
//===----------------------------------------------------------------------===//
namespace llvm {
- class DataLayout;
- class DominatorTree;
- class Function;
- class Instruction;
- class TargetLibraryInfo;
- class TruncInst;
- class Type;
- class Value;
+class AssumptionCache;
+class DataLayout;
+class DominatorTree;
+class Function;
+class Instruction;
+class TargetLibraryInfo;
+class TruncInst;
+class Type;
+class Value;
class TruncInstCombine {
+ AssumptionCache &AC;
TargetLibraryInfo &TLI;
const DataLayout &DL;
const DominatorTree &DT;
@@ -75,9 +79,9 @@ class TruncInstCombine {
MapVector<Instruction *, Info> InstInfoMap;
public:
- TruncInstCombine(TargetLibraryInfo &TLI, const DataLayout &DL,
- const DominatorTree &DT)
- : TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {}
+ TruncInstCombine(AssumptionCache &AC, TargetLibraryInfo &TLI,
+ const DataLayout &DL, const DominatorTree &DT)
+ : AC(AC), TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {}
/// Perform TruncInst pattern optimization on given function.
bool run(Function &F);
@@ -104,6 +108,18 @@ private:
/// to be reduced.
Type *getBestTruncatedType();
+ KnownBits computeKnownBits(const Value *V) const {
+ return llvm::computeKnownBits(V, DL, /*Depth=*/0, &AC,
+ /*CtxI=*/cast<Instruction>(CurrentTruncInst),
+ &DT);
+ }
+
+ unsigned ComputeNumSignBits(const Value *V) const {
+ return llvm::ComputeNumSignBits(
+ V, DL, /*Depth=*/0, &AC, /*CtxI=*/cast<Instruction>(CurrentTruncInst),
+ &DT);
+ }
+
/// Given a \p V value and a \p SclTy scalar type return the generated reduced
/// value of \p V based on the type \p SclTy.
///
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
index 16b82219e8ca..abac3f801a22 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -33,6 +33,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/Support/KnownBits.h"
using namespace llvm;
@@ -61,9 +62,18 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::UDiv:
+ case Instruction::URem:
+ case Instruction::InsertElement:
Ops.push_back(I->getOperand(0));
Ops.push_back(I->getOperand(1));
break;
+ case Instruction::ExtractElement:
+ Ops.push_back(I->getOperand(0));
+ break;
case Instruction::Select:
Ops.push_back(I->getOperand(1));
Ops.push_back(I->getOperand(2));
@@ -127,6 +137,13 @@ bool TruncInstCombine::buildTruncExpressionDag() {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::UDiv:
+ case Instruction::URem:
+ case Instruction::InsertElement:
+ case Instruction::ExtractElement:
case Instruction::Select: {
SmallVector<Value *, 2> Operands;
getRelevantOperands(I, Operands);
@@ -135,10 +152,9 @@ bool TruncInstCombine::buildTruncExpressionDag() {
}
default:
// TODO: Can handle more cases here:
- // 1. shufflevector, extractelement, insertelement
- // 2. udiv, urem
- // 3. shl, lshr, ashr
- // 4. phi node(and loop handling)
+ // 1. shufflevector
+ // 2. sdiv, srem
+ // 3. phi node(and loop handling)
// ...
return false;
}
@@ -270,6 +286,50 @@ Type *TruncInstCombine::getBestTruncatedType() {
unsigned OrigBitWidth =
CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
+ // Initialize MinBitWidth for shift instructions with the minimum number
+ // that is greater than shift amount (i.e. shift amount + 1).
+ // For `lshr` adjust MinBitWidth so that all potentially truncated
+ // bits of the value-to-be-shifted are zeros.
+ // For `ashr` adjust MinBitWidth so that all potentially truncated
+ // bits of the value-to-be-shifted are sign bits (all zeros or ones)
+ // and even one (first) untruncated bit is sign bit.
+ // Exit early if MinBitWidth is not less than original bitwidth.
+ for (auto &Itr : InstInfoMap) {
+ Instruction *I = Itr.first;
+ if (I->isShift()) {
+ KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
+ unsigned MinBitWidth = KnownRHS.getMaxValue()
+ .uadd_sat(APInt(OrigBitWidth, 1))
+ .getLimitedValue(OrigBitWidth);
+ if (MinBitWidth == OrigBitWidth)
+ return nullptr;
+ if (I->getOpcode() == Instruction::LShr) {
+ KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
+ MinBitWidth =
+ std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
+ }
+ if (I->getOpcode() == Instruction::AShr) {
+ unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
+ MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
+ }
+ if (MinBitWidth >= OrigBitWidth)
+ return nullptr;
+ Itr.second.MinBitWidth = MinBitWidth;
+ }
+ if (I->getOpcode() == Instruction::UDiv ||
+ I->getOpcode() == Instruction::URem) {
+ unsigned MinBitWidth = 0;
+ for (const auto &Op : I->operands()) {
+ KnownBits Known = computeKnownBits(Op);
+ MinBitWidth =
+ std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);
+ if (MinBitWidth >= OrigBitWidth)
+ return nullptr;
+ }
+ Itr.second.MinBitWidth = MinBitWidth;
+ }
+ }
+
// Calculate minimum allowed bit-width allowed for shrinking the currently
// visited truncate's operand.
unsigned MinBitWidth = getMinBitWidth();
@@ -356,10 +416,32 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
case Instruction::Mul:
case Instruction::And:
case Instruction::Or:
- case Instruction::Xor: {
+ case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::UDiv:
+ case Instruction::URem: {
Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
+ // Preserve `exact` flag since truncation doesn't change exactness
+ if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
+ if (auto *ResI = dyn_cast<Instruction>(Res))
+ ResI->setIsExact(PEO->isExact());
+ break;
+ }
+ case Instruction::ExtractElement: {
+ Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
+ Value *Idx = I->getOperand(1);
+ Res = Builder.CreateExtractElement(Vec, Idx);
+ break;
+ }
+ case Instruction::InsertElement: {
+ Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
+ Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);
+ Value *Idx = I->getOperand(2);
+ Res = Builder.CreateInsertElement(Vec, NewElt, Idx);
break;
}
case Instruction::Select: {