summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Scalar/StraightLineStrengthReduce.cpp')
-rw-r--r--lib/Transforms/Scalar/StraightLineStrengthReduce.cpp79
1 files changed, 30 insertions, 49 deletions
diff --git a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
index 1faa65eb34175..292d0400a516b 100644
--- a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp
@@ -57,8 +57,6 @@
// SLSR.
#include <vector>
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/FoldingSet.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -76,6 +74,8 @@ using namespace PatternMatch;
namespace {
+static const unsigned UnknownAddressSpace = ~0u;
+
class StraightLineStrengthReduce : public FunctionPass {
public:
// SLSR candidate. Such a candidate must be in one of the forms described in
@@ -234,51 +234,22 @@ bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis,
Basis.CandidateKind == C.CandidateKind);
}
-// TODO: use TTI->getGEPCost.
static bool isGEPFoldable(GetElementPtrInst *GEP,
- const TargetTransformInfo *TTI,
- const DataLayout *DL) {
- GlobalVariable *BaseGV = nullptr;
- int64_t BaseOffset = 0;
- bool HasBaseReg = false;
- int64_t Scale = 0;
-
- if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getPointerOperand()))
- BaseGV = GV;
- else
- HasBaseReg = true;
-
- gep_type_iterator GTI = gep_type_begin(GEP);
- for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I, ++GTI) {
- if (isa<SequentialType>(*GTI)) {
- int64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType());
- if (ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I)) {
- BaseOffset += ConstIdx->getSExtValue() * ElementSize;
- } else {
- // Needs scale register.
- if (Scale != 0) {
- // No addressing mode takes two scale registers.
- return false;
- }
- Scale = ElementSize;
- }
- } else {
- StructType *STy = cast<StructType>(*GTI);
- uint64_t Field = cast<ConstantInt>(*I)->getZExtValue();
- BaseOffset += DL->getStructLayout(STy)->getElementOffset(Field);
- }
- }
-
- unsigned AddrSpace = GEP->getPointerAddressSpace();
- return TTI->isLegalAddressingMode(GEP->getType()->getElementType(), BaseGV,
- BaseOffset, HasBaseReg, Scale, AddrSpace);
+ const TargetTransformInfo *TTI) {
+ SmallVector<const Value*, 4> Indices;
+ for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I)
+ Indices.push_back(*I);
+ return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
+ Indices) == TargetTransformInfo::TCC_Free;
}
// Returns whether (Base + Index * Stride) can be folded to an addressing mode.
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride,
TargetTransformInfo *TTI) {
- return TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true,
- Index->getSExtValue());
+ // Index->getSExtValue() may crash if Index is wider than 64-bit.
+ return Index->getBitWidth() <= 64 &&
+ TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true,
+ Index->getSExtValue(), UnknownAddressSpace);
}
bool StraightLineStrengthReduce::isFoldable(const Candidate &C,
@@ -287,7 +258,7 @@ bool StraightLineStrengthReduce::isFoldable(const Candidate &C,
if (C.CandidateKind == Candidate::Add)
return isAddFoldable(C.Base, C.Index, C.Stride, TTI);
if (C.CandidateKind == Candidate::GEP)
- return isGEPFoldable(cast<GetElementPtrInst>(C.Ins), TTI, DL);
+ return isGEPFoldable(cast<GetElementPtrInst>(C.Ins), TTI);
return false;
}
@@ -533,13 +504,23 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
IndexExprs, GEP->isInBounds());
Value *ArrayIdx = GEP->getOperand(I);
uint64_t ElementSize = DL->getTypeAllocSize(*GTI);
- factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
+ if (ArrayIdx->getType()->getIntegerBitWidth() <=
+ DL->getPointerSizeInBits(GEP->getAddressSpace())) {
+ // Skip factoring if ArrayIdx is wider than the pointer size, because
+ // ArrayIdx is implicitly truncated to the pointer size.
+ factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
+ }
// When ArrayIdx is the sext of a value, we try to factor that value as
// well. Handling this case is important because array indices are
// typically sign-extended to the pointer size.
Value *TruncatedArrayIdx = nullptr;
- if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))))
+ if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) &&
+ TruncatedArrayIdx->getType()->getIntegerBitWidth() <=
+ DL->getPointerSizeInBits(GEP->getAddressSpace())) {
+ // Skip factoring if TruncatedArrayIdx is wider than the pointer size,
+ // because TruncatedArrayIdx is implicitly truncated to the pointer size.
factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
+ }
IndexExprs[I - 1] = OrigIndexExpr;
}
@@ -567,10 +548,10 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
APInt ElementSize(
IndexOffset.getBitWidth(),
DL->getTypeAllocSize(
- cast<GetElementPtrInst>(Basis.Ins)->getType()->getElementType()));
+ cast<GetElementPtrInst>(Basis.Ins)->getResultElementType()));
APInt Q, R;
APInt::sdivrem(IndexOffset, ElementSize, Q, R);
- if (R.getSExtValue() == 0)
+ if (R == 0)
IndexOffset = Q;
else
BumpWithUglyGEP = true;
@@ -578,10 +559,10 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
// Compute Bump = C - Basis = (i' - i) * S.
// Common case 1: if (i' - i) is 1, Bump = S.
- if (IndexOffset.getSExtValue() == 1)
+ if (IndexOffset == 1)
return C.Stride;
// Common case 2: if (i' - i) is -1, Bump = -S.
- if (IndexOffset.getSExtValue() == -1)
+ if (IndexOffset.isAllOnesValue())
return Builder.CreateNeg(C.Stride);
// Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may
@@ -685,7 +666,7 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis(
}
bool StraightLineStrengthReduce::runOnFunction(Function &F) {
- if (skipOptnoneFunction(F))
+ if (skipFunction(F))
return false;
TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);