author    Dimitry Andric <dim@FreeBSD.org>    2021-07-29 20:15:26 +0000
committer Dimitry Andric <dim@FreeBSD.org>    2021-07-29 20:15:26 +0000
commit    344a3780b2e33f6ca763666c380202b18aab72a3 (patch)
tree      f0b203ee6eb71d7fdd792373e3c81eb18d6934dd /llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
parent    b60736ec1405bb0a8dd40989f67ef4c93da068ab (diff)
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 1269
 1 file changed, 1114 insertions(+), 155 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7fda6b8fb602..01236aa6b527 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -18,6 +18,7 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <algorithm>
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -44,7 +45,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
/// Calculate the cost of materializing a 64-bit value. This helper
/// method might only calculate a fraction of a larger immediate. Therefore it
/// is valid to return a cost of ZERO.
-int AArch64TTIImpl::getIntImmCost(int64_t Val) {
+InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
// Check if the immediate can be encoded within an instruction.
if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
return 0;
@@ -59,8 +60,8 @@ int AArch64TTIImpl::getIntImmCost(int64_t Val) {
}
/// Calculate the cost of materializing the given constant.
-int AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
- TTI::TargetCostKind CostKind) {
+InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
+ TTI::TargetCostKind CostKind) {
assert(Ty->isIntegerTy());
unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -74,20 +75,20 @@ int AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
// Split the constant into 64-bit chunks and calculate the cost for each
// chunk.
- int Cost = 0;
+ InstructionCost Cost = 0;
for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
int64_t Val = Tmp.getSExtValue();
Cost += getIntImmCost(Val);
}
// We need at least one instruction to materialize the constant.
- return std::max(1, Cost);
+ return std::max<InstructionCost>(1, Cost);
}
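For illustration (values chosen arbitrarily, not taken from the patch): a 128-bit constant is costed as the sum of its two sign-extended 64-bit halves, so a value whose low half is a legal logical immediate (cost 0) and whose high half needs two MOVZ/MOVK steps would come out as cost 2, and the std::max floor guarantees that a wholly-free split still reports at least 1.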
-int AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
- const APInt &Imm, Type *Ty,
- TTI::TargetCostKind CostKind,
- Instruction *Inst) {
+InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
+ const APInt &Imm, Type *Ty,
+ TTI::TargetCostKind CostKind,
+ Instruction *Inst) {
assert(Ty->isIntegerTy());
unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -144,7 +145,7 @@ int AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
if (Idx == ImmIdx) {
int NumConstants = (BitSize + 63) / 64;
- int Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
+ InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
return (Cost <= NumConstants * TTI::TCC_Basic)
? static_cast<int>(TTI::TCC_Free)
: Cost;
@@ -152,9 +153,10 @@ int AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
-int AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
- const APInt &Imm, Type *Ty,
- TTI::TargetCostKind CostKind) {
+InstructionCost
+AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
+ const APInt &Imm, Type *Ty,
+ TTI::TargetCostKind CostKind) {
assert(Ty->isIntegerTy());
unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -180,7 +182,7 @@ int AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
case Intrinsic::umul_with_overflow:
if (Idx == 1) {
int NumConstants = (BitSize + 63) / 64;
- int Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
+ InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
return (Cost <= NumConstants * TTI::TCC_Basic)
? static_cast<int>(TTI::TCC_Free)
: Cost;
@@ -212,7 +214,7 @@ AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
return TTI::PSK_Software;
}
-unsigned
+InstructionCost
AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
TTI::TargetCostKind CostKind) {
auto *RetTy = ICA.getReturnType();
@@ -235,12 +237,605 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
return LT.first;
break;
}
+ case Intrinsic::sadd_sat:
+ case Intrinsic::ssub_sat:
+ case Intrinsic::uadd_sat:
+ case Intrinsic::usub_sat: {
+ static const auto ValidSatTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
+ MVT::v8i16, MVT::v2i32, MVT::v4i32,
+ MVT::v2i64};
+ auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+ // This is a base cost of 1 for the vadd, plus 3 extra shifts if we
+ // need to extend the type, as it uses shr(qadd(shl, shl)).
+ unsigned Instrs =
+ LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
+ if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
+ return LT.first * Instrs;
+ break;
+ }
+ case Intrinsic::abs: {
+ static const auto ValidAbsTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
+ MVT::v8i16, MVT::v2i32, MVT::v4i32,
+ MVT::v2i64};
+ auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+ if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
+ return LT.first;
+ break;
+ }
+ case Intrinsic::experimental_stepvector: {
+ InstructionCost Cost = 1; // Cost of the `index' instruction
+ auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+ // Legalisation of illegal vectors involves an `index' instruction plus
+ // (LT.first - 1) vector adds.
+ if (LT.first > 1) {
+ Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
+ InstructionCost AddCost =
+ getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
+ Cost += AddCost * (LT.first - 1);
+ }
+ return Cost;
+ }
+ case Intrinsic::bitreverse: {
+ static const CostTblEntry BitreverseTbl[] = {
+ {Intrinsic::bitreverse, MVT::i32, 1},
+ {Intrinsic::bitreverse, MVT::i64, 1},
+ {Intrinsic::bitreverse, MVT::v8i8, 1},
+ {Intrinsic::bitreverse, MVT::v16i8, 1},
+ {Intrinsic::bitreverse, MVT::v4i16, 2},
+ {Intrinsic::bitreverse, MVT::v8i16, 2},
+ {Intrinsic::bitreverse, MVT::v2i32, 2},
+ {Intrinsic::bitreverse, MVT::v4i32, 2},
+ {Intrinsic::bitreverse, MVT::v1i64, 2},
+ {Intrinsic::bitreverse, MVT::v2i64, 2},
+ };
+ const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
+ const auto *Entry =
+ CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
+ // The cost model uses the legal type (i32) that i8 and i16 are promoted to,
+ // plus 1 so that we match the actual lowering cost.
+ if (Entry) {
+ if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
+ TLI->getValueType(DL, RetTy, true) == MVT::i16)
+ return LegalisationCost.first * Entry->Cost + 1;
+ return LegalisationCost.first * Entry->Cost;
+ }
+ break;
+ }
+ case Intrinsic::ctpop: {
+ static const CostTblEntry CtpopCostTbl[] = {
+ {ISD::CTPOP, MVT::v2i64, 4},
+ {ISD::CTPOP, MVT::v4i32, 3},
+ {ISD::CTPOP, MVT::v8i16, 2},
+ {ISD::CTPOP, MVT::v16i8, 1},
+ {ISD::CTPOP, MVT::i64, 4},
+ {ISD::CTPOP, MVT::v2i32, 3},
+ {ISD::CTPOP, MVT::v4i16, 2},
+ {ISD::CTPOP, MVT::v8i8, 1},
+ {ISD::CTPOP, MVT::i32, 5},
+ };
+ auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+ MVT MTy = LT.second;
+ if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
+ // Extra cost of +1 when illegal vector types are legalized by promoting
+ // the integer type.
+ int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
+ RetTy->getScalarSizeInBits()
+ ? 1
+ : 0;
+ return LT.first * Entry->Cost + ExtraCost;
+ }
+ break;
+ }
default:
break;
}
return BaseT::getIntrinsicInstrCost(ICA, CostKind);
}
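As a worked example of the experimental_stepvector costing above (the type is illustrative): a <vscale x 16 x i32> result is not legal and is split into four <vscale x 4 x i32> parts, so LT.first is 4 and the returned cost is 1 for the index instruction plus 3 times the cost of a <vscale x 4 x i32> add.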
+/// The function will remove redundant reinterpret casts (to/from svbool) in
+/// the presence of control flow.
+static Optional<Instruction *> processPhiNode(InstCombiner &IC,
+ IntrinsicInst &II) {
+ SmallVector<Instruction *, 32> Worklist;
+ auto RequiredType = II.getType();
+
+ auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
+ assert(PN && "Expected Phi Node!");
+
+ // Don't create a new Phi unless we can remove the old one.
+ if (!PN->hasOneUse())
+ return None;
+
+ for (Value *IncValPhi : PN->incoming_values()) {
+ auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
+ if (!Reinterpret ||
+ Reinterpret->getIntrinsicID() !=
+ Intrinsic::aarch64_sve_convert_to_svbool ||
+ RequiredType != Reinterpret->getArgOperand(0)->getType())
+ return None;
+ }
+
+ // Create the new Phi
+ LLVMContext &Ctx = PN->getContext();
+ IRBuilder<> Builder(Ctx);
+ Builder.SetInsertPoint(PN);
+ PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
+ Worklist.push_back(PN);
+
+ for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
+ auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
+ NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
+ Worklist.push_back(Reinterpret);
+ }
+
+ // Cleanup Phi Node and reinterprets
+ return IC.replaceInstUsesWith(II, NPN);
+}
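A minimal IR sketch of the pattern this handles (block and value names are illustrative, assuming <vscale x 4 x i1> predicates): every incoming value of the phi is a convert.to.svbool whose operand already has the required type, and the phi's only user is the convert.from.svbool, so the whole diamond collapses to a phi over the narrower predicate type.

  left:
    %a.bool = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %a)
    br label %merge
  right:
    %b.bool = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %b)
    br label %merge
  merge:
    %phi = phi <vscale x 16 x i1> [ %a.bool, %left ], [ %b.bool, %right ]
    %res = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %phi)
    ; becomes:  %res = phi <vscale x 4 x i1> [ %a, %left ], [ %b, %right ]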
+
+static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
+ IntrinsicInst &II) {
+ // If the reinterpret instruction operand is a PHI Node
+ if (isa<PHINode>(II.getArgOperand(0)))
+ return processPhiNode(IC, II);
+
+ SmallVector<Instruction *, 32> CandidatesForRemoval;
+ Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
+
+ const auto *IVTy = cast<VectorType>(II.getType());
+
+ // Walk the chain of conversions.
+ while (Cursor) {
+ // If the type of the cursor has fewer lanes than the final result, zeroing
+ // must take place, which breaks the equivalence chain.
+ const auto *CursorVTy = cast<VectorType>(Cursor->getType());
+ if (CursorVTy->getElementCount().getKnownMinValue() <
+ IVTy->getElementCount().getKnownMinValue())
+ break;
+
+ // If the cursor has the same type as I, it is a viable replacement.
+ if (Cursor->getType() == IVTy)
+ EarliestReplacement = Cursor;
+
+ auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
+
+ // If this is not an SVE conversion intrinsic, this is the end of the chain.
+ if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
+ Intrinsic::aarch64_sve_convert_to_svbool ||
+ IntrinsicCursor->getIntrinsicID() ==
+ Intrinsic::aarch64_sve_convert_from_svbool))
+ break;
+
+ CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
+ Cursor = IntrinsicCursor->getOperand(0);
+ }
+
+ // If no viable replacement in the conversion chain was found, there is
+ // nothing to do.
+ if (!EarliestReplacement)
+ return None;
+
+ return IC.replaceInstUsesWith(II, EarliestReplacement);
+}
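A straight-line sketch of the chain walk (types illustrative): a convert.from.svbool that undoes an earlier convert.to.svbool of the same type is replaced by the original predicate; routing through a narrower predicate type along the way would imply zeroing of the extra lanes and stops the walk.

  %b = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %p)
  %r = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %b)
  ; becomes:  %r is replaced by %p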
+
+static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
+ IntrinsicInst &II) {
+ IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
+ if (!Pg)
+ return None;
+
+ if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+ return None;
+
+ const auto PTruePattern =
+ cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
+ if (PTruePattern != AArch64SVEPredPattern::vl1)
+ return None;
+
+ // The intrinsic is inserting into lane zero so use an insert instead.
+ auto *IdxTy = Type::getInt64Ty(II.getContext());
+ auto *Insert = InsertElementInst::Create(
+ II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
+ Insert->insertBefore(&II);
+ Insert->takeName(&II);
+
+ return IC.replaceInstUsesWith(II, Insert);
+}
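An illustrative sketch (nxv4i32 chosen arbitrarily): a dup governed by a ptrue with the vl1 pattern only writes lane 0, so it becomes a plain insertelement into the passthru operand.

  %pg = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 1)   ; vl1
  %d  = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.nxv4i32(<vscale x 4 x i32> %pt, <vscale x 4 x i1> %pg, i32 %x)
  ; becomes:  %d = insertelement <vscale x 4 x i32> %pt, i32 %x, i64 0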
+
+static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
+ IntrinsicInst &II) {
+ LLVMContext &Ctx = II.getContext();
+ IRBuilder<> Builder(Ctx);
+ Builder.SetInsertPoint(&II);
+
+ // Check that the predicate is all active
+ auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
+ if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+ return None;
+
+ const auto PTruePattern =
+ cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
+ if (PTruePattern != AArch64SVEPredPattern::all)
+ return None;
+
+ // Check that we have a compare of zero..
+ auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
+ if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+ return None;
+
+ auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
+ if (!DupXArg || !DupXArg->isZero())
+ return None;
+
+ // ..against a dupq
+ auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
+ if (!DupQLane ||
+ DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
+ return None;
+
+ // Where the dupq is a lane 0 replicate of a vector insert
+ if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
+ return None;
+
+ auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
+ if (!VecIns ||
+ VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
+ return None;
+
+ // Where the vector insert is a fixed constant vector insert into undef at
+ // index zero
+ if (!isa<UndefValue>(VecIns->getArgOperand(0)))
+ return None;
+
+ if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
+ return None;
+
+ auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
+ if (!ConstVec)
+ return None;
+
+ auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
+ auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
+ if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
+ return None;
+
+ unsigned NumElts = VecTy->getNumElements();
+ unsigned PredicateBits = 0;
+
+ // Expand intrinsic operands to a 16-bit byte level predicate
+ for (unsigned I = 0; I < NumElts; ++I) {
+ auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
+ if (!Arg)
+ return None;
+ if (!Arg->isZero())
+ PredicateBits |= 1 << (I * (16 / NumElts));
+ }
+
+ // If all bits are zero bail early with an empty predicate
+ if (PredicateBits == 0) {
+ auto *PFalse = Constant::getNullValue(II.getType());
+ PFalse->takeName(&II);
+ return IC.replaceInstUsesWith(II, PFalse);
+ }
+
+ // Calculate largest predicate type used (where byte predicate is largest)
+ unsigned Mask = 8;
+ for (unsigned I = 0; I < 16; ++I)
+ if ((PredicateBits & (1 << I)) != 0)
+ Mask |= (I % 8);
+
+ unsigned PredSize = Mask & -Mask;
+ auto *PredType = ScalableVectorType::get(
+ Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
+
+ // Ensure all relevant bits are set
+ for (unsigned I = 0; I < 16; I += PredSize)
+ if ((PredicateBits & (1 << I)) == 0)
+ return None;
+
+ auto *PTruePat =
+ ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
+ auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
+ {PredType}, {PTruePat});
+ auto *ConvertToSVBool = Builder.CreateIntrinsic(
+ Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
+ auto *ConvertFromSVBool =
+ Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
+ {II.getType()}, {ConvertToSVBool});
+
+ ConvertFromSVBool->takeName(&II);
+ return IC.replaceInstUsesWith(II, ConvertFromSVBool);
+}
+
+static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
+ IntrinsicInst &II) {
+ Value *Pg = II.getArgOperand(0);
+ Value *Vec = II.getArgOperand(1);
+ bool IsAfter = II.getIntrinsicID() == Intrinsic::aarch64_sve_lasta;
+
+ // lastX(splat(X)) --> X
+ if (auto *SplatVal = getSplatValue(Vec))
+ return IC.replaceInstUsesWith(II, SplatVal);
+
+ auto *C = dyn_cast<Constant>(Pg);
+ if (IsAfter && C && C->isNullValue()) {
+ // The intrinsic is extracting lane 0 so use an extract instead.
+ auto *IdxTy = Type::getInt64Ty(II.getContext());
+ auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
+ Extract->insertBefore(&II);
+ Extract->takeName(&II);
+ return IC.replaceInstUsesWith(II, Extract);
+ }
+
+ auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
+ if (!IntrPG)
+ return None;
+
+ if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+ return None;
+
+ const auto PTruePattern =
+ cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
+
+ // Can the intrinsic's predicate be converted to a known constant index?
+ unsigned Idx;
+ switch (PTruePattern) {
+ default:
+ return None;
+ case AArch64SVEPredPattern::vl1:
+ Idx = 0;
+ break;
+ case AArch64SVEPredPattern::vl2:
+ Idx = 1;
+ break;
+ case AArch64SVEPredPattern::vl3:
+ Idx = 2;
+ break;
+ case AArch64SVEPredPattern::vl4:
+ Idx = 3;
+ break;
+ case AArch64SVEPredPattern::vl5:
+ Idx = 4;
+ break;
+ case AArch64SVEPredPattern::vl6:
+ Idx = 5;
+ break;
+ case AArch64SVEPredPattern::vl7:
+ Idx = 6;
+ break;
+ case AArch64SVEPredPattern::vl8:
+ Idx = 7;
+ break;
+ case AArch64SVEPredPattern::vl16:
+ Idx = 15;
+ break;
+ }
+
+ // Increment the index if extracting the element after the last active
+ // predicate element.
+ if (IsAfter)
+ ++Idx;
+
+ // Ignore extracts whose index is larger than the known minimum vector
+ // length. NOTE: This is an artificial constraint where we prefer to
+ // maintain what the user asked for until an alternative is proven faster.
+ auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
+ if (Idx >= PgVTy->getMinNumElements())
+ return None;
+
+ // The intrinsic is extracting a fixed lane so use an extract instead.
+ auto *IdxTy = Type::getInt64Ty(II.getContext());
+ auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
+ Extract->insertBefore(&II);
+ Extract->takeName(&II);
+ return IC.replaceInstUsesWith(II, Extract);
+}
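For illustration (the vl3 pattern and nxv4i32 type are arbitrary): with a ptrue vl3 governing predicate, lastb reads the last active element (index 2) and lasta the one after it (index 3), so both become fixed-lane extracts as long as the index stays below the known minimum element count.

  %pg = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 3)   ; vl3
  %lb = call i32 @llvm.aarch64.sve.lastb.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %v)
  ; becomes:  %lb = extractelement <vscale x 4 x i32> %v, i64 2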
+
+static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
+ IntrinsicInst &II) {
+ LLVMContext &Ctx = II.getContext();
+ IRBuilder<> Builder(Ctx);
+ Builder.SetInsertPoint(&II);
+ // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
+ // can work with RDFFR_PP for ptest elimination.
+ auto *AllPat =
+ ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
+ auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
+ {II.getType()}, {AllPat});
+ auto *RDFFR =
+ Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
+ RDFFR->takeName(&II);
+ return IC.replaceInstUsesWith(II, RDFFR);
+}
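Sketch of the rewrite (value names illustrative): the unpredicated rdffr is re-expressed as rdffr.z under an all-true predicate, which is the form optimizePTestInstr can later combine with a ptest (RDFFR_PP).

  %ffr = call <vscale x 16 x i1> @llvm.aarch64.sve.rdffr()
  ; becomes:
  %pg  = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)   ; all
  %ffr = call <vscale x 16 x i1> @llvm.aarch64.sve.rdffr.z(<vscale x 16 x i1> %pg)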
+
+static Optional<Instruction *>
+instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
+ const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+
+ if (Pattern == AArch64SVEPredPattern::all) {
+ LLVMContext &Ctx = II.getContext();
+ IRBuilder<> Builder(Ctx);
+ Builder.SetInsertPoint(&II);
+
+ Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
+ auto *VScale = Builder.CreateVScale(StepVal);
+ VScale->takeName(&II);
+ return IC.replaceInstUsesWith(II, VScale);
+ }
+
+ unsigned MinNumElts = 0;
+ switch (Pattern) {
+ default:
+ return None;
+ case AArch64SVEPredPattern::vl1:
+ case AArch64SVEPredPattern::vl2:
+ case AArch64SVEPredPattern::vl3:
+ case AArch64SVEPredPattern::vl4:
+ case AArch64SVEPredPattern::vl5:
+ case AArch64SVEPredPattern::vl6:
+ case AArch64SVEPredPattern::vl7:
+ case AArch64SVEPredPattern::vl8:
+ MinNumElts = Pattern;
+ break;
+ case AArch64SVEPredPattern::vl16:
+ MinNumElts = 16;
+ break;
+ }
+
+ return NumElts >= MinNumElts
+ ? Optional<Instruction *>(IC.replaceInstUsesWith(
+ II, ConstantInt::get(II.getType(), MinNumElts)))
+ : None;
+}
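Illustrative sketch for cntw (NumElts == 4): the 'all' pattern counts vscale * 4 lanes and becomes a vscale expression; a fixed pattern that is guaranteed to fit, such as vl4, folds to the constant 4, while a pattern that may exceed the minimum element count, such as vl8, is left untouched.

  %n = call i64 @llvm.aarch64.sve.cntw(i32 31)   ; all
  ; becomes (roughly):
  %vs = call i64 @llvm.vscale.i64()
  %n  = mul i64 %vs, 4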
+
+static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
+ IntrinsicInst &II) {
+ IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
+ IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
+
+ if (Op1 && Op2 &&
+ Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
+ Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
+ Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {
+
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+
+ Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
+ Type *Tys[] = {Op1->getArgOperand(0)->getType()};
+
+ auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
+
+ PTest->takeName(&II);
+ return IC.replaceInstUsesWith(II, PTest);
+ }
+
+ return None;
+}
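Sketch (nxv4i1 operands chosen for illustration): when both ptest operands are convert.to.svbool casts from the same predicate type, the casts are bypassed and the ptest is recreated directly on the narrower type.

  %a16 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %a)
  %b16 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %b)
  %t   = call i1 @llvm.aarch64.sve.ptest.any.nxv16i1(<vscale x 16 x i1> %a16, <vscale x 16 x i1> %b16)
  ; becomes:  %t = call i1 @llvm.aarch64.sve.ptest.any.nxv4i1(<vscale x 4 x i1> %a, <vscale x 4 x i1> %b)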
+
+static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
+ IntrinsicInst &II) {
+ auto *OpPredicate = II.getOperand(0);
+ auto *OpMultiplicand = II.getOperand(1);
+ auto *OpMultiplier = II.getOperand(2);
+
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+
+ // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
+ // with a unit splat value, false otherwise.
+ auto IsUnitDupX = [](auto *I) {
+ auto *IntrI = dyn_cast<IntrinsicInst>(I);
+ if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+ return false;
+
+ auto *SplatValue = IntrI->getOperand(0);
+ return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
+ };
+
+ // Return true if a given instruction is an aarch64_sve_dup intrinsic call
+ // with a unit splat value, false otherwise.
+ auto IsUnitDup = [](auto *I) {
+ auto *IntrI = dyn_cast<IntrinsicInst>(I);
+ if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
+ return false;
+
+ auto *SplatValue = IntrI->getOperand(2);
+ return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
+ };
+
+ // The OpMultiplier variable should always point to the dup (if any), so
+ // swap if necessary.
+ if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
+ std::swap(OpMultiplier, OpMultiplicand);
+
+ if (IsUnitDupX(OpMultiplier)) {
+ // [f]mul pg (dupx 1) %n => %n
+ OpMultiplicand->takeName(&II);
+ return IC.replaceInstUsesWith(II, OpMultiplicand);
+ } else if (IsUnitDup(OpMultiplier)) {
+ // [f]mul pg (dup pg 1) %n => %n
+ auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
+ auto *DupPg = DupInst->getOperand(1);
+ // TODO: this is naive. The optimization is still valid if DupPg
+ // 'encompasses' OpPredicate, not only if they're the same predicate.
+ if (OpPredicate == DupPg) {
+ OpMultiplicand->takeName(&II);
+ return IC.replaceInstUsesWith(II, OpMultiplicand);
+ }
+ }
+
+ return None;
+}
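An illustrative sketch of the dup_x case (nxv4f32 arbitrary): multiplying by a splat of 1.0 is an identity, so the other operand simply replaces the intrinsic; the predicated dup case additionally requires the dup's governing predicate to match the multiply's predicate.

  %one = call <vscale x 4 x float> @llvm.aarch64.sve.dup.x.nxv4f32(float 1.0)
  %r   = call <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %x, <vscale x 4 x float> %one)
  ; becomes:  %r is replaced by %x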
+
+static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
+ IntrinsicInst &II) {
+ auto *OpVal = II.getOperand(0);
+ auto *OpIndices = II.getOperand(1);
+ VectorType *VTy = cast<VectorType>(II.getType());
+
+ // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with
+ // constant splat value < minimal element count of result.
+ auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
+ if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+ return None;
+
+ auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
+ if (!SplatValue ||
+ SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
+ return None;
+
+ // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
+ // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+ auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
+ auto *VectorSplat =
+ Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
+
+ VectorSplat->takeName(&II);
+ return IC.replaceInstUsesWith(II, VectorSplat);
+}
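Sketch (the index 1 and nxv4i32 are illustrative): a tbl whose index operand is a dup_x of a constant smaller than the minimum element count selects the same lane everywhere, so it is rewritten as an extractelement followed by a splat (insertelement plus a zero-mask shufflevector).

  %idx = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 1)
  %r   = call <vscale x 4 x i32> @llvm.aarch64.sve.tbl.nxv4i32(<vscale x 4 x i32> %v, <vscale x 4 x i32> %idx)
  ; becomes (roughly):
  %e   = extractelement <vscale x 4 x i32> %v, i32 1
  %i0  = insertelement <vscale x 4 x i32> poison, i32 %e, i64 0
  %r   = shufflevector <vscale x 4 x i32> %i0, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer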
+
+Optional<Instruction *>
+AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
+ IntrinsicInst &II) const {
+ Intrinsic::ID IID = II.getIntrinsicID();
+ switch (IID) {
+ default:
+ break;
+ case Intrinsic::aarch64_sve_convert_from_svbool:
+ return instCombineConvertFromSVBool(IC, II);
+ case Intrinsic::aarch64_sve_dup:
+ return instCombineSVEDup(IC, II);
+ case Intrinsic::aarch64_sve_cmpne:
+ case Intrinsic::aarch64_sve_cmpne_wide:
+ return instCombineSVECmpNE(IC, II);
+ case Intrinsic::aarch64_sve_rdffr:
+ return instCombineRDFFR(IC, II);
+ case Intrinsic::aarch64_sve_lasta:
+ case Intrinsic::aarch64_sve_lastb:
+ return instCombineSVELast(IC, II);
+ case Intrinsic::aarch64_sve_cntd:
+ return instCombineSVECntElts(IC, II, 2);
+ case Intrinsic::aarch64_sve_cntw:
+ return instCombineSVECntElts(IC, II, 4);
+ case Intrinsic::aarch64_sve_cnth:
+ return instCombineSVECntElts(IC, II, 8);
+ case Intrinsic::aarch64_sve_cntb:
+ return instCombineSVECntElts(IC, II, 16);
+ case Intrinsic::aarch64_sve_ptest_any:
+ case Intrinsic::aarch64_sve_ptest_first:
+ case Intrinsic::aarch64_sve_ptest_last:
+ return instCombineSVEPTest(IC, II);
+ case Intrinsic::aarch64_sve_mul:
+ case Intrinsic::aarch64_sve_fmul:
+ return instCombineSVEVectorMul(IC, II);
+ case Intrinsic::aarch64_sve_tbl:
+ return instCombineSVETBL(IC, II);
+ }
+
+ return None;
+}
+
bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
ArrayRef<const Value *> Args) {
@@ -297,18 +892,21 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
return false;
// Get the total number of vector elements in the legalized types.
- unsigned NumDstEls = DstTyL.first * DstTyL.second.getVectorMinNumElements();
- unsigned NumSrcEls = SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
+ InstructionCost NumDstEls =
+ DstTyL.first * DstTyL.second.getVectorMinNumElements();
+ InstructionCost NumSrcEls =
+ SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
// Return true if the legalized types have the same number of vector elements
// and the destination element type size is twice that of the source type.
return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
}
-int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
- TTI::CastContextHint CCH,
- TTI::TargetCostKind CostKind,
- const Instruction *I) {
+InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
+ Type *Src,
+ TTI::CastContextHint CCH,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I) {
int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");
@@ -333,7 +931,7 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
}
// TODO: Allow non-throughput costs that aren't binary.
- auto AdjustCost = [&CostKind](int Cost) {
+ auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
if (CostKind != TTI::TCK_RecipThroughput)
return Cost == 0 ? 0 : 1;
return Cost;
@@ -353,6 +951,24 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
{ ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 3 },
{ ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },
+ // Truncations on nxvmiN
+ { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
+ { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
+ { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
+ { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
+ { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
+ { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
+ { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
+ { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
+ { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
+ { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
+ { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
+ { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
+ { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
+ { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
+ { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
+ { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
+
// The number of shll instructions for the extension.
{ ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
{ ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
@@ -434,6 +1050,16 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
{ ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
{ ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2 },
+ // Complex, from nxv2f32.
+ { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1 },
+
// Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
{ ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
{ ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
@@ -441,6 +1067,107 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
{ ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
{ ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
{ ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2 },
+
+ // Complex, from nxv2f64.
+ { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1 },
+
+ // Complex, from nxv4f32.
+ { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
+ { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
+ { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1 },
+
+ // Complex, from nxv8f64. Illegal -> illegal conversions not required.
+ { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
+ { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7 },
+ { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
+ { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7 },
+
+ // Complex, from nxv4f64. Illegal -> illegal conversions not required.
+ { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
+ { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
+ { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3 },
+ { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
+ { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
+ { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3 },
+
+ // Complex, from nxv8f32. Illegal -> illegal conversions not required.
+ { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
+ { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3 },
+ { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
+ { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3 },
+
+ // Complex, from nxv8f16.
+ { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
+ { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
+ { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
+ { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
+ { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1 },
+
+ // Complex, from nxv4f16.
+ { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
+ { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
+ { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1 },
+
+ // Complex, from nxv2f16.
+ { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
+ { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
+ { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1 },
+
+ // Truncate from nxvmf32 to nxvmf16.
+ { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
+ { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
+ { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
+
+ // Truncate from nxvmf64 to nxvmf16.
+ { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
+ { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
+ { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
+
+ // Truncate from nxvmf64 to nxvmf32.
+ { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
+ { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
+ { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
+
+ // Extend from nxvmf16 to nxvmf32.
+ { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
+ { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
+ { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
+
+ // Extend from nxvmf16 to nxvmf64.
+ { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
+ { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
+ { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
+
+ // Extend from nxvmf32 to nxvmf64.
+ { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
+ { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
+ { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
+
};
if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
@@ -452,9 +1179,10 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
}
-int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
- VectorType *VecTy,
- unsigned Index) {
+InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
+ Type *Dst,
+ VectorType *VecTy,
+ unsigned Index) {
// Make sure we were given a valid extend opcode.
assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
@@ -469,7 +1197,8 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
// Get the cost for the extract. We compute the cost (if any) for the extend
// below.
- auto Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
+ InstructionCost Cost =
+ getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
// Legalize the types.
auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
@@ -511,8 +1240,9 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
CostKind);
}
-unsigned AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
- TTI::TargetCostKind CostKind) {
+InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I) {
if (CostKind != TTI::TCK_RecipThroughput)
return Opcode == Instruction::PHI ? 0 : 1;
assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
@@ -520,13 +1250,13 @@ unsigned AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
return 0;
}
-int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
- unsigned Index) {
+InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
+ unsigned Index) {
assert(Val->isVectorTy() && "This must be a vector type");
if (Index != -1U) {
// Legalize the type.
- std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
// This type is legalized to a scalar type.
if (!LT.second.isVector())
@@ -545,10 +1275,10 @@ int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
return ST->getVectorInsertExtractBaseCost();
}
-int AArch64TTIImpl::getArithmeticInstrCost(
+InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
- TTI::OperandValueKind Opd1Info,
- TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,
+ TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
+ TTI::OperandValueProperties Opd1PropInfo,
TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
const Instruction *CxtI) {
// TODO: Handle more cost kinds.
@@ -558,7 +1288,7 @@ int AArch64TTIImpl::getArithmeticInstrCost(
Opd2PropInfo, Args, CxtI);
// Legalize the type.
- std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
// If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
// add in the widening overhead specified by the sub-target. Since the
@@ -566,7 +1296,7 @@ int AArch64TTIImpl::getArithmeticInstrCost(
// aren't present in the generated code and have a zero cost. By adding a
// widening overhead here, we attach the total cost of the combined operation
// to the widening instruction.
- int Cost = 0;
+ InstructionCost Cost = 0;
if (isWideningInstruction(Ty, Opcode, Args))
Cost += ST->getWideningBaseCost();
@@ -610,18 +1340,15 @@ int AArch64TTIImpl::getArithmeticInstrCost(
// Vector signed division by constant are expanded to the
// sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
// to MULHS + SUB + SRL + ADD + SRL.
- int MulCost = getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
- Opd1Info, Opd2Info,
- TargetTransformInfo::OP_None,
- TargetTransformInfo::OP_None);
- int AddCost = getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
- Opd1Info, Opd2Info,
- TargetTransformInfo::OP_None,
- TargetTransformInfo::OP_None);
- int ShrCost = getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
- Opd1Info, Opd2Info,
- TargetTransformInfo::OP_None,
- TargetTransformInfo::OP_None);
+ InstructionCost MulCost = getArithmeticInstrCost(
+ Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
+ TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
+ InstructionCost AddCost = getArithmeticInstrCost(
+ Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
+ TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
+ InstructionCost ShrCost = getArithmeticInstrCost(
+ Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
+ TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
}
}
@@ -677,8 +1404,9 @@ int AArch64TTIImpl::getArithmeticInstrCost(
}
}
-int AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
- const SCEV *Ptr) {
+InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
+ ScalarEvolution *SE,
+ const SCEV *Ptr) {
// Address computations in vectorized code with non-consecutive addresses will
// likely result in more instructions compared to scalar code where the
// computation can more often be merged into the index mode. The resulting
@@ -695,10 +1423,11 @@ int AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
return 1;
}
-int AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
- Type *CondTy, CmpInst::Predicate VecPred,
- TTI::TargetCostKind CostKind,
- const Instruction *I) {
+InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
+ Type *CondTy,
+ CmpInst::Predicate VecPred,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I) {
// TODO: Handle other cost kinds.
if (CostKind != TTI::TCK_RecipThroughput)
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
@@ -772,7 +1501,28 @@ AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
return Options;
}
-unsigned AArch64TTIImpl::getGatherScatterOpCost(
+InstructionCost
+AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
+ Align Alignment, unsigned AddressSpace,
+ TTI::TargetCostKind CostKind) {
+ if (!isa<ScalableVectorType>(Src))
+ return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
+ CostKind);
+ auto LT = TLI->getTypeLegalizationCost(DL, Src);
+ if (!LT.first.isValid())
+ return InstructionCost::getInvalid();
+
+ // The code-generator is currently not able to handle scalable vectors
+ // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
+ // it. This change will be removed when code-generation for these types is
+ // sufficiently reliable.
+ if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
+ return InstructionCost::getInvalid();
+
+ return LT.first * 2;
+}
+
+InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
@@ -781,35 +1531,56 @@ unsigned AArch64TTIImpl::getGatherScatterOpCost(
Alignment, CostKind, I);
auto *VT = cast<VectorType>(DataTy);
auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
- ElementCount LegalVF = LT.second.getVectorElementCount();
- Optional<unsigned> MaxNumVScale = getMaxVScale();
- assert(MaxNumVScale && "Expected valid max vscale value");
+ if (!LT.first.isValid())
+ return InstructionCost::getInvalid();
- unsigned MemOpCost =
+ // The code-generator is currently not able to handle scalable vectors
+ // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
+ // it. This change will be removed when code-generation for these types is
+ // sufficiently reliable.
+ if (cast<VectorType>(DataTy)->getElementCount() ==
+ ElementCount::getScalable(1))
+ return InstructionCost::getInvalid();
+
+ ElementCount LegalVF = LT.second.getVectorElementCount();
+ InstructionCost MemOpCost =
getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
- unsigned MaxNumElementsPerGather =
- MaxNumVScale.getValue() * LegalVF.getKnownMinValue();
- return LT.first * MaxNumElementsPerGather * MemOpCost;
+ return LT.first * MemOpCost * getMaxNumElements(LegalVF);
}
bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
}
-int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
- MaybeAlign Alignment, unsigned AddressSpace,
- TTI::TargetCostKind CostKind,
- const Instruction *I) {
- // TODO: Handle other cost kinds.
- if (CostKind != TTI::TCK_RecipThroughput)
- return 1;
-
+InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
+ MaybeAlign Alignment,
+ unsigned AddressSpace,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I) {
+ EVT VT = TLI->getValueType(DL, Ty, true);
// Type legalization can't handle structs
- if (TLI->getValueType(DL, Ty, true) == MVT::Other)
+ if (VT == MVT::Other)
return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
CostKind);
auto LT = TLI->getTypeLegalizationCost(DL, Ty);
+ if (!LT.first.isValid())
+ return InstructionCost::getInvalid();
+
+ // The code-generator is currently not able to handle scalable vectors
+ // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
+ // it. This change will be removed when code-generation for these types is
+ // sufficiently reliable.
+ if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
+ if (VTy->getElementCount() == ElementCount::getScalable(1))
+ return InstructionCost::getInvalid();
+
+ // TODO: consider latency as well for TCK_SizeAndLatency.
+ if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
+ return LT.first;
+
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return 1;
if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
@@ -823,29 +1594,20 @@ int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
return LT.first * 2 * AmortizationCost;
}
+ // Check truncating stores and extending loads.
if (useNeonVector(Ty) &&
- cast<VectorType>(Ty)->getElementType()->isIntegerTy(8)) {
- unsigned ProfitableNumElements;
- if (Opcode == Instruction::Store)
- // We use a custom trunc store lowering so v.4b should be profitable.
- ProfitableNumElements = 4;
- else
- // We scalarize the loads because there is not v.4b register and we
- // have to promote the elements to v.2.
- ProfitableNumElements = 8;
-
- if (cast<FixedVectorType>(Ty)->getNumElements() < ProfitableNumElements) {
- unsigned NumVecElts = cast<FixedVectorType>(Ty)->getNumElements();
- unsigned NumVectorizableInstsToAmortize = NumVecElts * 2;
- // We generate 2 instructions per vector element.
- return NumVectorizableInstsToAmortize * NumVecElts * 2;
- }
+ Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
+ // v4i8 types are lowered to a scalar load/store and sshll/xtn.
+ if (VT == MVT::v4i8)
+ return 2;
+ // Otherwise we need to scalarize.
+ return cast<FixedVectorType>(Ty)->getNumElements() * 2;
}
return LT.first;
}
-int AArch64TTIImpl::getInterleavedMemoryOpCost(
+InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
bool UseMaskForCond, bool UseMaskForGaps) {
@@ -871,8 +1633,9 @@ int AArch64TTIImpl::getInterleavedMemoryOpCost(
UseMaskForCond, UseMaskForGaps);
}
-int AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
- int Cost = 0;
+InstructionCost
+AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
+ InstructionCost Cost = 0;
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
for (auto *I : Tys) {
if (!I->isVectorTy())
@@ -958,6 +1721,41 @@ void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
EnableFalkorHWPFUnrollFix)
getFalkorUnrollingPreferences(L, SE, UP);
+
+ // Scan the loop: don't unroll loops with calls as this could prevent
+ // inlining. Don't unroll vector loops either, as they don't benefit much from
+ // unrolling.
+ for (auto *BB : L->getBlocks()) {
+ for (auto &I : *BB) {
+ // Don't unroll vectorised loop.
+ if (I.getType()->isVectorTy())
+ return;
+
+ if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
+ if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
+ if (!isLoweredToCall(F))
+ continue;
+ }
+ return;
+ }
+ }
+ }
+
+ // Enable runtime unrolling for in-order models
+ // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
+ // checking for that case, we can ensure that the default behaviour is
+ // unchanged
+ if (ST->getProcFamily() != AArch64Subtarget::Others &&
+ !ST->getSchedModel().isOutOfOrder()) {
+ UP.Runtime = true;
+ UP.Partial = true;
+ UP.UpperBound = true;
+ UP.UnrollRemainder = true;
+ UP.DefaultUnrollRuntimeCount = 4;
+
+ UP.UnrollAndJam = true;
+ UP.UnrollAndJamInnerLoopThreshold = 60;
+ }
}
void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
@@ -1073,42 +1871,44 @@ bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
return Considerable;
}
-bool AArch64TTIImpl::useReductionIntrinsic(unsigned Opcode, Type *Ty,
- TTI::ReductionFlags Flags) const {
- auto *VTy = cast<VectorType>(Ty);
- unsigned ScalarBits = Ty->getScalarSizeInBits();
- switch (Opcode) {
- case Instruction::FAdd:
- case Instruction::FMul:
- case Instruction::And:
- case Instruction::Or:
- case Instruction::Xor:
- case Instruction::Mul:
+bool AArch64TTIImpl::isLegalToVectorizeReduction(
+ const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
+ if (!VF.isScalable())
+ return true;
+
+ Type *Ty = RdxDesc.getRecurrenceType();
+ if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
return false;
- case Instruction::Add:
- return ScalarBits * cast<FixedVectorType>(VTy)->getNumElements() >= 128;
- case Instruction::ICmp:
- return (ScalarBits < 64) &&
- (ScalarBits * cast<FixedVectorType>(VTy)->getNumElements() >= 128);
- case Instruction::FCmp:
- return Flags.NoNaN;
+
+ switch (RdxDesc.getRecurrenceKind()) {
+ case RecurKind::Add:
+ case RecurKind::FAdd:
+ case RecurKind::And:
+ case RecurKind::Or:
+ case RecurKind::Xor:
+ case RecurKind::SMin:
+ case RecurKind::SMax:
+ case RecurKind::UMin:
+ case RecurKind::UMax:
+ case RecurKind::FMin:
+ case RecurKind::FMax:
+ return true;
default:
- llvm_unreachable("Unhandled reduction opcode");
+ return false;
}
- return false;
}
-int AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
- bool IsPairwise, bool IsUnsigned,
- TTI::TargetCostKind CostKind) {
+InstructionCost
+AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
+ bool IsUnsigned,
+ TTI::TargetCostKind CostKind) {
if (!isa<ScalableVectorType>(Ty))
- return BaseT::getMinMaxReductionCost(Ty, CondTy, IsPairwise, IsUnsigned,
- CostKind);
+ return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
assert((isa<ScalableVectorType>(Ty) && isa<ScalableVectorType>(CondTy)) &&
"Both vector needs to be scalable");
- std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
- int LegalizationCost = 0;
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+ InstructionCost LegalizationCost = 0;
if (LT.first > 1) {
Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
unsigned CmpOpcode =
@@ -1124,13 +1924,10 @@ int AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
return LegalizationCost + /*Cost of horizontal reduction*/ 2;
}
-int AArch64TTIImpl::getArithmeticReductionCostSVE(
- unsigned Opcode, VectorType *ValTy, bool IsPairwise,
- TTI::TargetCostKind CostKind) {
- assert(!IsPairwise && "Cannot be pair wise to continue");
-
- std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
- int LegalizationCost = 0;
+InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
+ unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
+ InstructionCost LegalizationCost = 0;
if (LT.first > 1) {
Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
@@ -1148,51 +1945,162 @@ int AArch64TTIImpl::getArithmeticReductionCostSVE(
case ISD::FADD:
return LegalizationCost + 2;
default:
- // TODO: Replace for invalid when InstructionCost is used
- // cases not supported by SVE
- return 16;
+ return InstructionCost::getInvalid();
}
}
-int AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode,
- VectorType *ValTy,
- bool IsPairwiseForm,
- TTI::TargetCostKind CostKind) {
+InstructionCost
+AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
+ Optional<FastMathFlags> FMF,
+ TTI::TargetCostKind CostKind) {
+ if (TTI::requiresOrderedReduction(FMF)) {
+ if (!isa<ScalableVectorType>(ValTy))
+ return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+
+ if (Opcode != Instruction::FAdd)
+ return InstructionCost::getInvalid();
+
+ auto *VTy = cast<ScalableVectorType>(ValTy);
+ InstructionCost Cost =
+ getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
+ Cost *= getMaxNumElements(VTy->getElementCount());
+ return Cost;
+ }
if (isa<ScalableVectorType>(ValTy))
- return getArithmeticReductionCostSVE(Opcode, ValTy, IsPairwiseForm,
- CostKind);
- if (IsPairwiseForm)
- return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
- CostKind);
+ return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
- std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
MVT MTy = LT.second;
int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");
// Horizontal adds can use the 'addv' instruction. We model the cost of these
- // instructions as normal vector adds. This is the only arithmetic vector
- // reduction operation for which we have an instruction.
+ // instructions as twice a normal vector add, plus 1 for each legalization
+ // step (LT.first). This is the only arithmetic vector reduction operation for
+ // which we have an instruction.
+ // OR, XOR and AND costs should match the codegen from:
+ // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
+ // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
+ // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
static const CostTblEntry CostTblNoPairwise[]{
- {ISD::ADD, MVT::v8i8, 1},
- {ISD::ADD, MVT::v16i8, 1},
- {ISD::ADD, MVT::v4i16, 1},
- {ISD::ADD, MVT::v8i16, 1},
- {ISD::ADD, MVT::v4i32, 1},
+ {ISD::ADD, MVT::v8i8, 2},
+ {ISD::ADD, MVT::v16i8, 2},
+ {ISD::ADD, MVT::v4i16, 2},
+ {ISD::ADD, MVT::v8i16, 2},
+ {ISD::ADD, MVT::v4i32, 2},
+ {ISD::OR, MVT::v8i8, 15},
+ {ISD::OR, MVT::v16i8, 17},
+ {ISD::OR, MVT::v4i16, 7},
+ {ISD::OR, MVT::v8i16, 9},
+ {ISD::OR, MVT::v2i32, 3},
+ {ISD::OR, MVT::v4i32, 5},
+ {ISD::OR, MVT::v2i64, 3},
+ {ISD::XOR, MVT::v8i8, 15},
+ {ISD::XOR, MVT::v16i8, 17},
+ {ISD::XOR, MVT::v4i16, 7},
+ {ISD::XOR, MVT::v8i16, 9},
+ {ISD::XOR, MVT::v2i32, 3},
+ {ISD::XOR, MVT::v4i32, 5},
+ {ISD::XOR, MVT::v2i64, 3},
+ {ISD::AND, MVT::v8i8, 15},
+ {ISD::AND, MVT::v16i8, 17},
+ {ISD::AND, MVT::v4i16, 7},
+ {ISD::AND, MVT::v8i16, 9},
+ {ISD::AND, MVT::v2i32, 3},
+ {ISD::AND, MVT::v4i32, 5},
+ {ISD::AND, MVT::v2i64, 3},
+ };
+ switch (ISD) {
+ default:
+ break;
+ case ISD::ADD:
+ if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
+ return (LT.first - 1) + Entry->Cost;
+ break;
+ case ISD::XOR:
+ case ISD::AND:
+ case ISD::OR:
+ const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
+ if (!Entry)
+ break;
+ auto *ValVTy = cast<FixedVectorType>(ValTy);
+ if (!ValVTy->getElementType()->isIntegerTy(1) &&
+ MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
+ isPowerOf2_32(ValVTy->getNumElements())) {
+ InstructionCost ExtraCost = 0;
+ if (LT.first != 1) {
+ // Type needs to be split, so there is an extra cost of LT.first - 1
+ // arithmetic ops.
+ auto *Ty = FixedVectorType::get(ValTy->getElementType(),
+ MTy.getVectorNumElements());
+ ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
+ ExtraCost *= LT.first - 1;
+ }
+ return Entry->Cost + ExtraCost;
+ }
+ break;
+ }
+ return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+}
+
+InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
+ static const CostTblEntry ShuffleTbl[] = {
+ { TTI::SK_Splice, MVT::nxv16i8, 1 },
+ { TTI::SK_Splice, MVT::nxv8i16, 1 },
+ { TTI::SK_Splice, MVT::nxv4i32, 1 },
+ { TTI::SK_Splice, MVT::nxv2i64, 1 },
+ { TTI::SK_Splice, MVT::nxv2f16, 1 },
+ { TTI::SK_Splice, MVT::nxv4f16, 1 },
+ { TTI::SK_Splice, MVT::nxv8f16, 1 },
+ { TTI::SK_Splice, MVT::nxv2bf16, 1 },
+ { TTI::SK_Splice, MVT::nxv4bf16, 1 },
+ { TTI::SK_Splice, MVT::nxv8bf16, 1 },
+ { TTI::SK_Splice, MVT::nxv2f32, 1 },
+ { TTI::SK_Splice, MVT::nxv4f32, 1 },
+ { TTI::SK_Splice, MVT::nxv2f64, 1 },
};
- if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
- return LT.first * Entry->Cost;
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
+ Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ EVT PromotedVT = LT.second.getScalarType() == MVT::i1
+ ? TLI->getPromotedVTForPredicate(EVT(LT.second))
+ : LT.second;
+ Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
+ InstructionCost LegalizationCost = 0;
+ if (Index < 0) {
+ LegalizationCost =
+ getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
+ CmpInst::BAD_ICMP_PREDICATE, CostKind) +
+ getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
+ CmpInst::BAD_ICMP_PREDICATE, CostKind);
+ }
- return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
- CostKind);
+ // Predicated splices are promoted during lowering; see AArch64ISelLowering.cpp.
+ // The cost is computed on the promoted type.
+ if (LT.second.getScalarType() == MVT::i1) {
+ LegalizationCost +=
+ getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
+ TTI::CastContextHint::None, CostKind) +
+ getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
+ TTI::CastContextHint::None, CostKind);
+ }
+ const auto *Entry =
+ CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
+ assert(Entry && "Illegal Type for Splice");
+ LegalizationCost += Entry->Cost;
+ return LegalizationCost * LT.first;
}
-int AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
- int Index, VectorType *SubTp) {
+InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
+ VectorType *Tp,
+ ArrayRef<int> Mask, int Index,
+ VectorType *SubTp) {
+ Kind = improveShuffleKindFromMask(Kind, Mask);
if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
- Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc) {
+ Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
+ Kind == TTI::SK_Reverse) {
static const CostTblEntry ShuffleTbl[] = {
// Broadcast shuffle kinds can be performed with 'dup'.
{ TTI::SK_Broadcast, MVT::v8i8, 1 },
@@ -1226,18 +2134,69 @@ int AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
{ TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
{ TTI::SK_Select, MVT::v2f64, 1 }, // mov.
// PermuteSingleSrc shuffle kinds.
- // TODO: handle vXi8/vXi16.
{ TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
{ TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
{ TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
{ TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
{ TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
{ TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
+ { TTI::SK_PermuteSingleSrc, MVT::v4i16, 3 }, // perfectshuffle worst case.
+ { TTI::SK_PermuteSingleSrc, MVT::v4f16, 3 }, // perfectshuffle worst case.
+ { TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3 }, // perfectshuffle worst case.
+ { TTI::SK_PermuteSingleSrc, MVT::v8i16, 8 }, // constpool + load + tbl
+ { TTI::SK_PermuteSingleSrc, MVT::v8f16, 8 }, // constpool + load + tbl
+ { TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8 }, // constpool + load + tbl
+ { TTI::SK_PermuteSingleSrc, MVT::v8i8, 8 }, // constpool + load + tbl
+ { TTI::SK_PermuteSingleSrc, MVT::v16i8, 8 }, // constpool + load + tbl
+ // Reverse can be lowered with `rev`.
+ { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
+ { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
+ { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
+ { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
+ { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
+ { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
+ // Broadcast shuffle kinds for scalable vectors
+ { TTI::SK_Broadcast, MVT::nxv16i8, 1 },
+ { TTI::SK_Broadcast, MVT::nxv8i16, 1 },
+ { TTI::SK_Broadcast, MVT::nxv4i32, 1 },
+ { TTI::SK_Broadcast, MVT::nxv2i64, 1 },
+ { TTI::SK_Broadcast, MVT::nxv2f16, 1 },
+ { TTI::SK_Broadcast, MVT::nxv4f16, 1 },
+ { TTI::SK_Broadcast, MVT::nxv8f16, 1 },
+ { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
+ { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
+ { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
+ { TTI::SK_Broadcast, MVT::nxv2f32, 1 },
+ { TTI::SK_Broadcast, MVT::nxv4f32, 1 },
+ { TTI::SK_Broadcast, MVT::nxv2f64, 1 },
+ { TTI::SK_Broadcast, MVT::nxv16i1, 1 },
+ { TTI::SK_Broadcast, MVT::nxv8i1, 1 },
+ { TTI::SK_Broadcast, MVT::nxv4i1, 1 },
+ { TTI::SK_Broadcast, MVT::nxv2i1, 1 },
+ // Handle the cases for vector.reverse with scalable vectors
+ { TTI::SK_Reverse, MVT::nxv16i8, 1 },
+ { TTI::SK_Reverse, MVT::nxv8i16, 1 },
+ { TTI::SK_Reverse, MVT::nxv4i32, 1 },
+ { TTI::SK_Reverse, MVT::nxv2i64, 1 },
+ { TTI::SK_Reverse, MVT::nxv2f16, 1 },
+ { TTI::SK_Reverse, MVT::nxv4f16, 1 },
+ { TTI::SK_Reverse, MVT::nxv8f16, 1 },
+ { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
+ { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
+ { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
+ { TTI::SK_Reverse, MVT::nxv2f32, 1 },
+ { TTI::SK_Reverse, MVT::nxv4f32, 1 },
+ { TTI::SK_Reverse, MVT::nxv2f64, 1 },
+ { TTI::SK_Reverse, MVT::nxv16i1, 1 },
+ { TTI::SK_Reverse, MVT::nxv8i1, 1 },
+ { TTI::SK_Reverse, MVT::nxv4i1, 1 },
+ { TTI::SK_Reverse, MVT::nxv2i1, 1 },
};
- std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
return LT.first * Entry->Cost;
}
-
- return BaseT::getShuffleCost(Kind, Tp, Index, SubTp);
+ if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
+ return getSpliceCost(Tp, Index);
+ return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
}