summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2021-11-19 20:06:13 +0000
committerDimitry Andric <dim@FreeBSD.org>2021-11-19 20:06:13 +0000
commitc0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch)
treef42add1021b9f2ac6a69ac7cf6c4499962739a45 /llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
parent344a3780b2e33f6ca763666c380202b18aab72a3 (diff)
Diffstat (limited to 'llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp')
-rw-r--r--llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp92
1 files changed, 87 insertions, 5 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
index 16b82219e8ca..abac3f801a22 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -33,6 +33,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/Support/KnownBits.h"
using namespace llvm;
@@ -61,9 +62,18 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::UDiv:
+ case Instruction::URem:
+ case Instruction::InsertElement:
Ops.push_back(I->getOperand(0));
Ops.push_back(I->getOperand(1));
break;
+ case Instruction::ExtractElement:
+ Ops.push_back(I->getOperand(0));
+ break;
case Instruction::Select:
Ops.push_back(I->getOperand(1));
Ops.push_back(I->getOperand(2));
@@ -127,6 +137,13 @@ bool TruncInstCombine::buildTruncExpressionDag() {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::UDiv:
+ case Instruction::URem:
+ case Instruction::InsertElement:
+ case Instruction::ExtractElement:
case Instruction::Select: {
SmallVector<Value *, 2> Operands;
getRelevantOperands(I, Operands);
@@ -135,10 +152,9 @@ bool TruncInstCombine::buildTruncExpressionDag() {
}
default:
// TODO: Can handle more cases here:
- // 1. shufflevector, extractelement, insertelement
- // 2. udiv, urem
- // 3. shl, lshr, ashr
- // 4. phi node(and loop handling)
+ // 1. shufflevector
+ // 2. sdiv, srem
+ // 3. phi node(and loop handling)
// ...
return false;
}
@@ -270,6 +286,50 @@ Type *TruncInstCombine::getBestTruncatedType() {
unsigned OrigBitWidth =
CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
+ // Initialize MinBitWidth for shift instructions with the minimum number
+ // that is greater than the shift amount (i.e. shift amount + 1).
+ // For `lshr`, adjust MinBitWidth so that all potentially truncated
+ // bits of the value-to-be-shifted are zeros.
+ // For `ashr`, adjust MinBitWidth so that all potentially truncated
+ // bits of the value-to-be-shifted are sign bits (all zeros or all ones)
+ // and the first (topmost) untruncated bit is also a sign bit.
+ // Exit early if MinBitWidth is not less than the original bitwidth.
+ for (auto &Itr : InstInfoMap) {
+ Instruction *I = Itr.first;
+ if (I->isShift()) {
+ KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
+ unsigned MinBitWidth = KnownRHS.getMaxValue()
+ .uadd_sat(APInt(OrigBitWidth, 1))
+ .getLimitedValue(OrigBitWidth);
+ if (MinBitWidth == OrigBitWidth)
+ return nullptr;
+ if (I->getOpcode() == Instruction::LShr) {
+ KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
+ MinBitWidth =
+ std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
+ }
+ if (I->getOpcode() == Instruction::AShr) {
+ unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
+ MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
+ }
+ if (MinBitWidth >= OrigBitWidth)
+ return nullptr;
+ Itr.second.MinBitWidth = MinBitWidth;
+ }
+ if (I->getOpcode() == Instruction::UDiv ||
+ I->getOpcode() == Instruction::URem) {
+ unsigned MinBitWidth = 0;
+ for (const auto &Op : I->operands()) {
+ KnownBits Known = computeKnownBits(Op);
+ MinBitWidth =
+ std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);
+ if (MinBitWidth >= OrigBitWidth)
+ return nullptr;
+ }
+ Itr.second.MinBitWidth = MinBitWidth;
+ }
+ }
+
// Calculate minimum allowed bit-width allowed for shrinking the currently
// visited truncate's operand.
unsigned MinBitWidth = getMinBitWidth();
@@ -356,10 +416,32 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
case Instruction::Mul:
case Instruction::And:
case Instruction::Or:
- case Instruction::Xor: {
+ case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::UDiv:
+ case Instruction::URem: {
Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
+ // Preserve `exact` flag since truncation doesn't change exactness
+ if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
+ if (auto *ResI = dyn_cast<Instruction>(Res))
+ ResI->setIsExact(PEO->isExact());
+ break;
+ }
+ case Instruction::ExtractElement: {
+ Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
+ Value *Idx = I->getOperand(1);
+ Res = Builder.CreateExtractElement(Vec, Idx);
+ break;
+ }
+ case Instruction::InsertElement: {
+ Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
+ Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);
+ Value *Idx = I->getOperand(2);
+ Res = Builder.CreateInsertElement(Vec, NewElt, Idx);
break;
}
case Instruction::Select: {