Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp')
-rw-r--r--  llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp | 208
1 file changed, 175 insertions(+), 33 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index 616196ad5ba3..c4eeb81c5133 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -57,7 +57,7 @@ using namespace llvm;
static cl::opt<unsigned> UnrollThresholdPrivate(
"amdgpu-unroll-threshold-private",
cl::desc("Unroll threshold for AMDGPU if private memory used in a loop"),
- cl::init(2000), cl::Hidden);
+ cl::init(2700), cl::Hidden);
static cl::opt<unsigned> UnrollThresholdLocal(
"amdgpu-unroll-threshold-local",
@@ -90,7 +90,8 @@ static bool dependsOnLocalPhi(const Loop *L, const Value *Cond,
void AMDGPUTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
TTI::UnrollingPreferences &UP) {
- UP.Threshold = 300; // Twice the default.
+ const Function &F = *L->getHeader()->getParent();
+ UP.Threshold = AMDGPU::getIntegerAttribute(F, "amdgpu-unroll-threshold", 300);
UP.MaxCount = std::numeric_limits<unsigned>::max();
UP.Partial = true;
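
The hunk above replaces the hard-coded base threshold of 300 (twice the generic default) with a per-function override read from the "amdgpu-unroll-threshold" attribute, falling back to 300 when the attribute is absent. A minimal standalone sketch of that lookup-with-default pattern, using a plain map in place of LLVM's attribute API (all names here are hypothetical, not LLVM's):

#include <cstdio>
#include <map>
#include <string>

// Hypothetical stand-in for AMDGPU::getIntegerAttribute: parse a string-valued
// function attribute as an integer, returning Default when the attribute is
// missing or malformed.
static int getIntegerAttributeSketch(
    const std::map<std::string, std::string> &Attrs,
    const std::string &Name, int Default) {
  auto It = Attrs.find(Name);
  if (It == Attrs.end())
    return Default;
  int Value;
  if (std::sscanf(It->second.c_str(), "%d", &Value) != 1)
    return Default;
  return Value;
}

int main() {
  std::map<std::string, std::string> Attrs{
      {"amdgpu-unroll-threshold", "1000"}};
  // Attribute present: 1000. An absent attribute yields the 300 fallback.
  std::printf("threshold = %d\n",
              getIntegerAttributeSketch(Attrs, "amdgpu-unroll-threshold", 300));
}
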
@@ -337,10 +338,13 @@ bool GCNTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
}
}
-int GCNTTIImpl::getArithmeticInstrCost(
- unsigned Opcode, Type *Ty, TTI::OperandValueKind Opd1Info,
- TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,
- TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args ) {
+int GCNTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
+ TTI::OperandValueKind Opd1Info,
+ TTI::OperandValueKind Opd2Info,
+ TTI::OperandValueProperties Opd1PropInfo,
+ TTI::OperandValueProperties Opd2PropInfo,
+ ArrayRef<const Value *> Args,
+ const Instruction *CxtI) {
EVT OrigTy = TLI->getValueType(DL, Ty);
if (!OrigTy.isSimple()) {
return BaseT::getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info,
@@ -365,6 +369,9 @@ int GCNTTIImpl::getArithmeticInstrCost(
if (SLT == MVT::i64)
return get64BitInstrCost() * LT.first * NElts;
+ if (ST->has16BitInsts() && SLT == MVT::i16)
+ NElts = (NElts + 1) / 2;
+
// i32
return getFullRateInstrCost() * LT.first * NElts;
case ISD::ADD:
@@ -372,11 +379,14 @@ int GCNTTIImpl::getArithmeticInstrCost(
case ISD::AND:
case ISD::OR:
case ISD::XOR:
- if (SLT == MVT::i64){
+ if (SLT == MVT::i64) {
// and, or and xor are typically split into 2 VALU instructions.
return 2 * getFullRateInstrCost() * LT.first * NElts;
}
+ if (ST->has16BitInsts() && SLT == MVT::i16)
+ NElts = (NElts + 1) / 2;
+
return LT.first * NElts * getFullRateInstrCost();
case ISD::MUL: {
const int QuarterRateCost = getQuarterRateInstrCost();
@@ -385,6 +395,9 @@ int GCNTTIImpl::getArithmeticInstrCost(
return (4 * QuarterRateCost + (2 * 2) * FullRateCost) * LT.first * NElts;
}
+ if (ST->has16BitInsts() && SLT == MVT::i16)
+ NElts = (NElts + 1) / 2;
+
// i32
return QuarterRateCost * NElts * LT.first;
}
@@ -394,6 +407,9 @@ int GCNTTIImpl::getArithmeticInstrCost(
if (SLT == MVT::f64)
return LT.first * NElts * get64BitInstrCost();
+ if (ST->has16BitInsts() && SLT == MVT::f16)
+ NElts = (NElts + 1) / 2;
+
if (SLT == MVT::f32 || SLT == MVT::f16)
return LT.first * NElts * getFullRateInstrCost();
break;
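
The same `NElts = (NElts + 1) / 2` addition recurs across the shift, add/sub/logic, mul, and fadd/fsub/fmul cases above: on subtargets with 16-bit instructions, two i16/f16 lanes pack into one 32-bit VALU operation, so the legalized element count is halved, rounding up so an odd trailing lane still costs a full instruction. A minimal sketch of that arithmetic (the helper name is hypothetical, not LLVM API):

#include <cassert>

// Hypothetical helper mirroring the additions above: packed 16-bit operations
// process two lanes per 32-bit instruction, so round the legalized element
// count up to the next multiple of two.
static unsigned packedInstrCount(unsigned NElts, bool Has16BitInsts) {
  if (Has16BitInsts)
    NElts = (NElts + 1) / 2; // an odd trailing lane still costs a full op
  return NElts;
}

int main() {
  assert(packedInstrCount(4, true) == 2);  // <4 x f16> -> 2 packed ops
  assert(packedInstrCount(3, true) == 2);  // odd lane count rounds up
  assert(packedInstrCount(4, false) == 4); // no packing without 16-bit insts
}
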
@@ -412,7 +428,7 @@ int GCNTTIImpl::getArithmeticInstrCost(
if (!Args.empty() && match(Args[0], PatternMatch::m_FPOne())) {
// TODO: This is more complicated, unsafe flags etc.
- if ((SLT == MVT::f32 && !ST->hasFP32Denormals()) ||
+ if ((SLT == MVT::f32 && !HasFP32Denormals) ||
(SLT == MVT::f16 && ST->has16BitInsts())) {
return LT.first * getQuarterRateInstrCost() * NElts;
}
@@ -431,7 +447,7 @@ int GCNTTIImpl::getArithmeticInstrCost(
if (SLT == MVT::f32 || SLT == MVT::f16) {
int Cost = 7 * getFullRateInstrCost() + 1 * getQuarterRateInstrCost();
- if (!ST->hasFP32Denormals()) {
+ if (!HasFP32Denormals) {
// FP mode switches.
Cost += 2 * getFullRateInstrCost();
}
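
The two denormal hunks above swap `ST->hasFP32Denormals()` for the cached per-function `HasFP32Denormals`, since the denormal mode is now a function-level property rather than a pure subtarget feature. A sketch of the f32 fdiv cost branch with hypothetical unit costs, assuming 1 for full rate and 4 for quarter rate:

#include <cstdio>

// Sketch of the f32 fdiv branch above: the expanded divide takes 7 full-rate
// plus 1 quarter-rate instruction, and two extra full-rate mode switches when
// the function runs with FP32 denormals disabled (the sequence must toggle
// the denormal mode around the divide).
static int f32DivCost(bool HasFP32Denormals) {
  const int FullRate = 1, QuarterRate = 4;
  int Cost = 7 * FullRate + 1 * QuarterRate;
  if (!HasFP32Denormals)
    Cost += 2 * FullRate; // FP mode switches
  return Cost;
}

int main() {
  std::printf("denormals on: %d, off: %d\n",
              f32DivCost(true), f32DivCost(false));
}
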
@@ -447,6 +463,49 @@ int GCNTTIImpl::getArithmeticInstrCost(
Opd1PropInfo, Opd2PropInfo);
}
+template <typename T>
+int GCNTTIImpl::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
+ ArrayRef<T *> Args,
+ FastMathFlags FMF, unsigned VF) {
+ if (ID != Intrinsic::fma)
+ return BaseT::getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF);
+
+ EVT OrigTy = TLI->getValueType(DL, RetTy);
+ if (!OrigTy.isSimple()) {
+ return BaseT::getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF);
+ }
+
+ // Legalize the type.
+ std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, RetTy);
+
+ unsigned NElts = LT.second.isVector() ?
+ LT.second.getVectorNumElements() : 1;
+
+ MVT::SimpleValueType SLT = LT.second.getScalarType().SimpleTy;
+
+ if (SLT == MVT::f64)
+ return LT.first * NElts * get64BitInstrCost();
+
+ if (ST->has16BitInsts() && SLT == MVT::f16)
+ NElts = (NElts + 1) / 2;
+
+ return LT.first * NElts * (ST->hasFastFMAF32() ? getHalfRateInstrCost()
+ : getQuarterRateInstrCost());
+}
+
+int GCNTTIImpl::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
+ ArrayRef<Value*> Args, FastMathFlags FMF,
+ unsigned VF) {
+ return getIntrinsicInstrCost<Value>(ID, RetTy, Args, FMF, VF);
+}
+
+int GCNTTIImpl::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
+ ArrayRef<Type *> Tys, FastMathFlags FMF,
+ unsigned ScalarizationCostPassed) {
+ return getIntrinsicInstrCost<Type>(ID, RetTy, Tys, FMF,
+ ScalarizationCostPassed);
+}
+
unsigned GCNTTIImpl::getCFInstrCost(unsigned Opcode) {
// XXX - For some reason this isn't called for switch.
switch (Opcode) {
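
The new templated getIntrinsicInstrCost handles only llvm.fma itself and defers everything else to the base class: f64 pays the 64-bit rate, f16 lanes pack two per instruction, and f32 is half rate only on subtargets with fast FMA. A sketch of that selection with hypothetical rate constants standing in for get64BitInstrCost() and friends:

#include <cstdio>

// Sketch of the fma cost selection above; the rate values are illustrative
// assumptions, not LLVM's actual constants.
static int fmaCostSketch(bool IsF64, bool IsF16, unsigned NElts,
                         bool Has16BitInsts, bool HasFastFMAF32) {
  const int Rate64Bit = 4, HalfRate = 2, QuarterRate = 4;
  if (IsF64)
    return NElts * Rate64Bit;
  if (IsF16 && Has16BitInsts)
    NElts = (NElts + 1) / 2; // packed f16 lanes, rounded up
  return NElts * (HasFastFMAF32 ? HalfRate : QuarterRate);
}

int main() {
  // <4 x f16> with packing and fast FMA: 2 instructions at half rate.
  std::printf("cost = %d\n", fmaCostSketch(false, true, 4, true, true));
}
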
@@ -671,10 +730,13 @@ unsigned GCNTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index,
bool GCNTTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
const TargetMachine &TM = getTLI()->getTargetMachine();
- const FeatureBitset &CallerBits =
- TM.getSubtargetImpl(*Caller)->getFeatureBits();
- const FeatureBitset &CalleeBits =
- TM.getSubtargetImpl(*Callee)->getFeatureBits();
+ const GCNSubtarget *CallerST
+ = static_cast<const GCNSubtarget *>(TM.getSubtargetImpl(*Caller));
+ const GCNSubtarget *CalleeST
+ = static_cast<const GCNSubtarget *>(TM.getSubtargetImpl(*Callee));
+
+ const FeatureBitset &CallerBits = CallerST->getFeatureBits();
+ const FeatureBitset &CalleeBits = CalleeST->getFeatureBits();
FeatureBitset RealCallerBits = CallerBits & ~InlineFeatureIgnoreList;
FeatureBitset RealCalleeBits = CalleeBits & ~InlineFeatureIgnoreList;
@@ -683,8 +745,8 @@ bool GCNTTIImpl::areInlineCompatible(const Function *Caller,
// FIXME: dx10_clamp can just take the caller setting, but there seems to be
// no way to support merge for backend defined attributes.
- AMDGPU::SIModeRegisterDefaults CallerMode(*Caller);
- AMDGPU::SIModeRegisterDefaults CalleeMode(*Callee);
+ AMDGPU::SIModeRegisterDefaults CallerMode(*Caller, *CallerST);
+ AMDGPU::SIModeRegisterDefaults CalleeMode(*Callee, *CalleeST);
return CallerMode.isInlineCompatible(CalleeMode);
}
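
The two hunks above fetch the GCNSubtarget once per function so it can be reused both for the feature-bitset comparison and for the SIModeRegisterDefaults constructors, which now take the subtarget explicitly. Inlining remains legal only when the callee's features, after masking off InlineFeatureIgnoreList, are a subset of the caller's, and the FP modes are compatible. A sketch of the subset test, with an 8-bit std::bitset standing in for LLVM's FeatureBitset:

#include <bitset>
#include <cassert>

// Sketch of the inline-compatibility subset check: every feature the callee
// relies on must also be enabled in the caller; the ignore mask drops
// features that do not affect codegen compatibility.
static bool featuresCompatible(std::bitset<8> Caller, std::bitset<8> Callee,
                               std::bitset<8> Ignore) {
  Caller &= ~Ignore;
  Callee &= ~Ignore;
  return (Callee & Caller) == Callee;
}

int main() {
  assert(featuresCompatible(0b1011, 0b0011, 0b0000));  // callee subset: ok
  assert(!featuresCompatible(0b0011, 0b0111, 0b0000)); // callee needs more: no
  assert(featuresCompatible(0b0011, 0b0111, 0b0100));  // extra bit ignored: ok
}
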
@@ -695,34 +757,114 @@ void GCNTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
unsigned GCNTTIImpl::getUserCost(const User *U,
ArrayRef<const Value *> Operands) {
- // Estimate extractelement elimination
- if (const ExtractElementInst *EE = dyn_cast<ExtractElementInst>(U)) {
- ConstantInt *CI = dyn_cast<ConstantInt>(EE->getOperand(1));
+ const Instruction *I = dyn_cast<Instruction>(U);
+ if (!I)
+ return BaseT::getUserCost(U, Operands);
+
+ // Estimate different operations to be optimized out
+ switch (I->getOpcode()) {
+ case Instruction::ExtractElement: {
+ ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
unsigned Idx = -1;
if (CI)
Idx = CI->getZExtValue();
- return getVectorInstrCost(EE->getOpcode(), EE->getOperand(0)->getType(),
- Idx);
+ return getVectorInstrCost(I->getOpcode(), I->getOperand(0)->getType(), Idx);
}
-
- // Estimate insertelement elimination
- if (const InsertElementInst *IE = dyn_cast<InsertElementInst>(U)) {
- ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
+ case Instruction::InsertElement: {
+ ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(2));
unsigned Idx = -1;
if (CI)
Idx = CI->getZExtValue();
- return getVectorInstrCost(IE->getOpcode(), IE->getType(), Idx);
+ return getVectorInstrCost(I->getOpcode(), I->getType(), Idx);
}
+ case Instruction::Call: {
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) {
+ SmallVector<Value *, 4> Args(II->arg_operands());
+ FastMathFlags FMF;
+ if (auto *FPMO = dyn_cast<FPMathOperator>(II))
+ FMF = FPMO->getFastMathFlags();
+ return getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(), Args,
+ FMF);
+ } else {
+ return BaseT::getUserCost(U, Operands);
+ }
+ }
+ case Instruction::ShuffleVector: {
+ const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
+ Type *Ty = Shuffle->getType();
+ Type *SrcTy = Shuffle->getOperand(0)->getType();
+
+ // TODO: Identify and add costs for insert subvector, etc.
+ int SubIndex;
+ if (Shuffle->isExtractSubvectorMask(SubIndex))
+ return getShuffleCost(TTI::SK_ExtractSubvector, SrcTy, SubIndex, Ty);
+
+ if (Shuffle->changesLength())
+ return BaseT::getUserCost(U, Operands);
+
+ if (Shuffle->isIdentity())
+ return 0;
- // Estimate different intrinsics, e.g. llvm.fabs
- if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) {
- SmallVector<Value *, 4> Args(II->arg_operands());
- FastMathFlags FMF;
- if (auto *FPMO = dyn_cast<FPMathOperator>(II))
- FMF = FPMO->getFastMathFlags();
- return getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(), Args,
- FMF);
+ if (Shuffle->isReverse())
+ return getShuffleCost(TTI::SK_Reverse, Ty, 0, nullptr);
+
+ if (Shuffle->isSelect())
+ return getShuffleCost(TTI::SK_Select, Ty, 0, nullptr);
+
+ if (Shuffle->isTranspose())
+ return getShuffleCost(TTI::SK_Transpose, Ty, 0, nullptr);
+
+ if (Shuffle->isZeroEltSplat())
+ return getShuffleCost(TTI::SK_Broadcast, Ty, 0, nullptr);
+
+ if (Shuffle->isSingleSource())
+ return getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, 0, nullptr);
+
+ return getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, 0, nullptr);
+ }
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::FPExt:
+ case Instruction::PtrToInt:
+ case Instruction::IntToPtr:
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ case Instruction::Trunc:
+ case Instruction::FPTrunc:
+ case Instruction::BitCast:
+ case Instruction::AddrSpaceCast: {
+ return getCastInstrCost(I->getOpcode(), I->getType(),
+ I->getOperand(0)->getType(), I);
}
+ case Instruction::Add:
+ case Instruction::FAdd:
+ case Instruction::Sub:
+ case Instruction::FSub:
+ case Instruction::Mul:
+ case Instruction::FMul:
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::FDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ case Instruction::FRem:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ case Instruction::FNeg: {
+ return getArithmeticInstrCost(I->getOpcode(), I->getType(),
+ TTI::OK_AnyValue, TTI::OK_AnyValue,
+ TTI::OP_None, TTI::OP_None, Operands, I);
+ }
+ default:
+ break;
+ }
+
return BaseT::getUserCost(U, Operands);
}
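
With this final hunk, getUserCost dispatches on the instruction opcode so that casts, arithmetic, shuffles, and intrinsic calls all receive target-accurate costs instead of the generic base-class estimate. Within the shuffle case the checks run from most specific (an identity shuffle is free) to most general (a two-source permute), so ordering matters. A miniature sketch of that cascade, with boolean predicates standing in for ShuffleVectorInst's mask queries:

#include <cstdio>

// Hypothetical miniature of the shuffle classification above: return the
// first shuffle kind whose predicate matches, falling back to the generic
// two-source permute.
enum ShuffleKind { Identity, Reverse, Select, Transpose, Broadcast,
                   PermuteSingleSrc, PermuteTwoSrc };

static ShuffleKind classify(bool IsIdentity, bool IsReverse, bool IsSelect,
                            bool IsTranspose, bool IsSplat, bool OneSource) {
  if (IsIdentity)  return Identity;  // costs 0 in the hunk above
  if (IsReverse)   return Reverse;
  if (IsSelect)    return Select;
  if (IsTranspose) return Transpose;
  if (IsSplat)     return Broadcast;
  if (OneSource)   return PermuteSingleSrc;
  return PermuteTwoSrc;              // generic two-source fallback
}

int main() {
  // A reverse mask is also single-source; the earlier, cheaper kind wins.
  std::printf("kind = %d\n", classify(false, true, false, false, false, true));
}
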