diff options
Diffstat (limited to 'llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp')
| -rw-r--r-- | llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp | 92 |
1 file changed, 87 insertions, 5 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 16b82219e8ca..abac3f801a22 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; @@ -61,9 +62,18 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: + case Instruction::InsertElement: Ops.push_back(I->getOperand(0)); Ops.push_back(I->getOperand(1)); break; + case Instruction::ExtractElement: + Ops.push_back(I->getOperand(0)); + break; case Instruction::Select: Ops.push_back(I->getOperand(1)); Ops.push_back(I->getOperand(2)); @@ -127,6 +137,13 @@ bool TruncInstCombine::buildTruncExpressionDag() { case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: + case Instruction::InsertElement: + case Instruction::ExtractElement: case Instruction::Select: { SmallVector<Value *, 2> Operands; getRelevantOperands(I, Operands); @@ -135,10 +152,9 @@ bool TruncInstCombine::buildTruncExpressionDag() { } default: // TODO: Can handle more cases here: - // 1. shufflevector, extractelement, insertelement - // 2. udiv, urem - // 3. shl, lshr, ashr - // 4. phi node(and loop handling) + // 1. shufflevector + // 2. sdiv, srem + // 3. phi node(and loop handling) // ... 
return false; } @@ -270,6 +286,50 @@ Type *TruncInstCombine::getBestTruncatedType() { unsigned OrigBitWidth = CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); + // Initialize MinBitWidth for shift instructions with the minimum number + // that is greater than shift amount (i.e. shift amount + 1). + // For `lshr` adjust MinBitWidth so that all potentially truncated + // bits of the value-to-be-shifted are zeros. + // For `ashr` adjust MinBitWidth so that all potentially truncated + // bits of the value-to-be-shifted are sign bits (all zeros or ones) + // and even one (first) untruncated bit is sign bit. + // Exit early if MinBitWidth is not less than original bitwidth. + for (auto &Itr : InstInfoMap) { + Instruction *I = Itr.first; + if (I->isShift()) { + KnownBits KnownRHS = computeKnownBits(I->getOperand(1)); + unsigned MinBitWidth = KnownRHS.getMaxValue() + .uadd_sat(APInt(OrigBitWidth, 1)) + .getLimitedValue(OrigBitWidth); + if (MinBitWidth == OrigBitWidth) + return nullptr; + if (I->getOpcode() == Instruction::LShr) { + KnownBits KnownLHS = computeKnownBits(I->getOperand(0)); + MinBitWidth = + std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits()); + } + if (I->getOpcode() == Instruction::AShr) { + unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0)); + MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1); + } + if (MinBitWidth >= OrigBitWidth) + return nullptr; + Itr.second.MinBitWidth = MinBitWidth; + } + if (I->getOpcode() == Instruction::UDiv || + I->getOpcode() == Instruction::URem) { + unsigned MinBitWidth = 0; + for (const auto &Op : I->operands()) { + KnownBits Known = computeKnownBits(Op); + MinBitWidth = + std::max(Known.getMaxValue().getActiveBits(), MinBitWidth); + if (MinBitWidth >= OrigBitWidth) + return nullptr; + } + Itr.second.MinBitWidth = MinBitWidth; + } + } + // Calculate minimum allowed bit-width allowed for shrinking the currently // visited truncate's operand. 
unsigned MinBitWidth = getMinBitWidth(); @@ -356,10 +416,32 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { case Instruction::Mul: case Instruction::And: case Instruction::Or: - case Instruction::Xor: { + case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: { Value *LHS = getReducedOperand(I->getOperand(0), SclTy); Value *RHS = getReducedOperand(I->getOperand(1), SclTy); Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); + // Preserve `exact` flag since truncation doesn't change exactness + if (auto *PEO = dyn_cast<PossiblyExactOperator>(I)) + if (auto *ResI = dyn_cast<Instruction>(Res)) + ResI->setIsExact(PEO->isExact()); + break; + } + case Instruction::ExtractElement: { + Value *Vec = getReducedOperand(I->getOperand(0), SclTy); + Value *Idx = I->getOperand(1); + Res = Builder.CreateExtractElement(Vec, Idx); + break; + } + case Instruction::InsertElement: { + Value *Vec = getReducedOperand(I->getOperand(0), SclTy); + Value *NewElt = getReducedOperand(I->getOperand(1), SclTy); + Value *Idx = I->getOperand(2); + Res = Builder.CreateInsertElement(Vec, NewElt, Idx); break; } case Instruction::Select: { |
