Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 479
1 file changed, 367 insertions(+), 112 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 01236aa6b527..63d6fa5bbb26 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -9,11 +9,13 @@
#include "AArch64TargetTransformInfo.h"
#include "AArch64ExpandImm.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
+#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
@@ -220,19 +222,15 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
auto *RetTy = ICA.getReturnType();
switch (ICA.getID()) {
case Intrinsic::umin:
- case Intrinsic::umax: {
- auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
- // umin(x,y) -> sub(x,usubsat(x,y))
- // umax(x,y) -> add(x,usubsat(y,x))
- if (LT.second == MVT::v2i64)
- return LT.first * 2;
- LLVM_FALLTHROUGH;
- }
+ case Intrinsic::umax:
case Intrinsic::smin:
case Intrinsic::smax: {
static const auto ValidMinMaxTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
MVT::v8i16, MVT::v2i32, MVT::v4i32};
auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+ // v2i64 types get converted to cmp+bif, hence the cost of 2
+ if (LT.second == MVT::v2i64)
+ return LT.first * 2;
if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
return LT.first;
break;
@@ -291,13 +289,15 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
const auto *Entry =
CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
- // Cost Model is using the legal type(i32) that i8 and i16 will be converted
- // to +1 so that we match the actual lowering cost
- if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
- TLI->getValueType(DL, RetTy, true) == MVT::i16)
- return LegalisationCost.first * Entry->Cost + 1;
- if (Entry)
+ if (Entry) {
+ // The cost model uses the legal type (i32) that i8 and i16 will be
+ // converted to, plus 1 so that we match the actual lowering cost
+ if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
+ TLI->getValueType(DL, RetTy, true) == MVT::i16)
+ return LegalisationCost.first * Entry->Cost + 1;
+
return LegalisationCost.first * Entry->Cost;
+ }
break;
}
case Intrinsic::ctpop: {
@@ -440,6 +440,18 @@ static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
return IC.replaceInstUsesWith(II, Insert);
}
+static Optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
+ IntrinsicInst &II) {
+ // Replace DupX with a regular IR splat.
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+ auto *RetTy = cast<ScalableVectorType>(II.getType());
+ Value *Splat =
+ Builder.CreateVectorSplat(RetTy->getElementCount(), II.getArgOperand(0));
+ Splat->takeName(&II);
+ return IC.replaceInstUsesWith(II, Splat);
+}
+
static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
IntrinsicInst &II) {
LLVMContext &Ctx = II.getContext();
@@ -457,12 +469,9 @@ static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
return None;
// Check that we have a compare of zero..
- auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
- if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
- return None;
-
- auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
- if (!DupXArg || !DupXArg->isZero())
+ auto *SplatValue =
+ dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
+ if (!SplatValue || !SplatValue->isZero())
return None;
// ..against a dupq
@@ -547,14 +556,34 @@ static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
IntrinsicInst &II) {
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
Value *Pg = II.getArgOperand(0);
Value *Vec = II.getArgOperand(1);
- bool IsAfter = II.getIntrinsicID() == Intrinsic::aarch64_sve_lasta;
+ auto IntrinsicID = II.getIntrinsicID();
+ bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
// lastX(splat(X)) --> X
if (auto *SplatVal = getSplatValue(Vec))
return IC.replaceInstUsesWith(II, SplatVal);
+ // If x and/or y is a splat value then:
+ // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
+ Value *LHS, *RHS;
+ if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
+ if (isSplatValue(LHS) || isSplatValue(RHS)) {
+ auto *OldBinOp = cast<BinaryOperator>(Vec);
+ auto OpC = OldBinOp->getOpcode();
+ auto *NewLHS =
+ Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
+ auto *NewRHS =
+ Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
+ auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
+ OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
+ return IC.replaceInstUsesWith(II, NewBinOp);
+ }
+ }
+
auto *C = dyn_cast<Constant>(Pg);
if (IsAfter && C && C->isNullValue()) {
// The intrinsic is extracting lane 0 so use an extract instead.
@@ -576,39 +605,11 @@ static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
// Can the intrinsic's predicate be converted to a known constant index?
- unsigned Idx;
- switch (PTruePattern) {
- default:
+ unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
+ if (!MinNumElts)
return None;
- case AArch64SVEPredPattern::vl1:
- Idx = 0;
- break;
- case AArch64SVEPredPattern::vl2:
- Idx = 1;
- break;
- case AArch64SVEPredPattern::vl3:
- Idx = 2;
- break;
- case AArch64SVEPredPattern::vl4:
- Idx = 3;
- break;
- case AArch64SVEPredPattern::vl5:
- Idx = 4;
- break;
- case AArch64SVEPredPattern::vl6:
- Idx = 5;
- break;
- case AArch64SVEPredPattern::vl7:
- Idx = 6;
- break;
- case AArch64SVEPredPattern::vl8:
- Idx = 7;
- break;
- case AArch64SVEPredPattern::vl16:
- Idx = 15;
- break;
- }
+ unsigned Idx = MinNumElts - 1;
// Increment the index if extracting the element after the last active
// predicate element.
if (IsAfter)
@@ -661,26 +662,9 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
return IC.replaceInstUsesWith(II, VScale);
}
- unsigned MinNumElts = 0;
- switch (Pattern) {
- default:
- return None;
- case AArch64SVEPredPattern::vl1:
- case AArch64SVEPredPattern::vl2:
- case AArch64SVEPredPattern::vl3:
- case AArch64SVEPredPattern::vl4:
- case AArch64SVEPredPattern::vl5:
- case AArch64SVEPredPattern::vl6:
- case AArch64SVEPredPattern::vl7:
- case AArch64SVEPredPattern::vl8:
- MinNumElts = Pattern;
- break;
- case AArch64SVEPredPattern::vl16:
- MinNumElts = 16;
- break;
- }
+ unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
- return NumElts >= MinNumElts
+ return MinNumElts && NumElts >= MinNumElts
? Optional<Instruction *>(IC.replaceInstUsesWith(
II, ConstantInt::get(II.getType(), MinNumElts)))
: None;
@@ -711,6 +695,116 @@ static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
return None;
}
+static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
+ IntrinsicInst &II) {
+ // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+ Value *P = II.getOperand(0);
+ Value *A = II.getOperand(1);
+ auto FMul = II.getOperand(2);
+ Value *B, *C;
+ if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
+ m_Specific(P), m_Value(B), m_Value(C))))
+ return None;
+
+ if (!FMul->hasOneUse())
+ return None;
+
+ llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+ // Stop the combine when the flags on the inputs differ in case dropping flags
+ // would lead to us missing out on more beneficial optimizations.
+ if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
+ return None;
+ if (!FAddFlags.allowContract())
+ return None;
+
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+ auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+ {II.getType()}, {P, A, B, C}, &II);
+ FMLA->setFastMathFlags(FAddFlags);
+ return IC.replaceInstUsesWith(II, FMLA);
+}
+
+static Optional<Instruction *>
+instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+
+ Value *Pred = II.getOperand(0);
+ Value *PtrOp = II.getOperand(1);
+ Type *VecTy = II.getType();
+ Value *VecPtr = Builder.CreateBitCast(PtrOp, VecTy->getPointerTo());
+
+ if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+ m_ConstantInt<AArch64SVEPredPattern::all>()))) {
+ LoadInst *Load = Builder.CreateLoad(VecTy, VecPtr);
+ return IC.replaceInstUsesWith(II, Load);
+ }
+
+ CallInst *MaskedLoad =
+ Builder.CreateMaskedLoad(VecTy, VecPtr, PtrOp->getPointerAlignment(DL),
+ Pred, ConstantAggregateZero::get(VecTy));
+ return IC.replaceInstUsesWith(II, MaskedLoad);
+}
+
+static Optional<Instruction *>
+instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+
+ Value *VecOp = II.getOperand(0);
+ Value *Pred = II.getOperand(1);
+ Value *PtrOp = II.getOperand(2);
+ Value *VecPtr =
+ Builder.CreateBitCast(PtrOp, VecOp->getType()->getPointerTo());
+
+ if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+ m_ConstantInt<AArch64SVEPredPattern::all>()))) {
+ Builder.CreateStore(VecOp, VecPtr);
+ return IC.eraseInstFromFunction(II);
+ }
+
+ Builder.CreateMaskedStore(VecOp, VecPtr, PtrOp->getPointerAlignment(DL),
+ Pred);
+ return IC.eraseInstFromFunction(II);
+}
+
+static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
+ switch (Intrinsic) {
+ case Intrinsic::aarch64_sve_fmul:
+ return Instruction::BinaryOps::FMul;
+ case Intrinsic::aarch64_sve_fadd:
+ return Instruction::BinaryOps::FAdd;
+ case Intrinsic::aarch64_sve_fsub:
+ return Instruction::BinaryOps::FSub;
+ default:
+ return Instruction::BinaryOpsEnd;
+ }
+}
+
+static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
+ IntrinsicInst &II) {
+ auto *OpPredicate = II.getOperand(0);
+ auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
+ if (BinOpCode == Instruction::BinaryOpsEnd ||
+ !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+ m_ConstantInt<AArch64SVEPredPattern::all>())))
+ return None;
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+ Builder.setFastMathFlags(II.getFastMathFlags());
+ auto BinOp =
+ Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
+ return IC.replaceInstUsesWith(II, BinOp);
+}
+
+static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
+ IntrinsicInst &II) {
+ if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
+ return FMLA;
+ return instCombineSVEVectorBinOp(IC, II);
+}
+
static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
IntrinsicInst &II) {
auto *OpPredicate = II.getOperand(0);
@@ -720,14 +814,11 @@ static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
IRBuilder<> Builder(II.getContext());
Builder.SetInsertPoint(&II);
- // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
- // with a unit splat value, false otherwise.
- auto IsUnitDupX = [](auto *I) {
- auto *IntrI = dyn_cast<IntrinsicInst>(I);
- if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+ // Return true if a given instruction is a unit splat value, false otherwise.
+ auto IsUnitSplat = [](auto *I) {
+ auto *SplatValue = getSplatValue(I);
+ if (!SplatValue)
return false;
-
- auto *SplatValue = IntrI->getOperand(0);
return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
};
@@ -744,10 +835,10 @@ static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
// The OpMultiplier variable should always point to the dup (if any), so
// swap if necessary.
- if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
+ if (IsUnitDup(OpMultiplicand) || IsUnitSplat(OpMultiplicand))
std::swap(OpMultiplier, OpMultiplicand);
- if (IsUnitDupX(OpMultiplier)) {
+ if (IsUnitSplat(OpMultiplier)) {
// [f]mul pg (dupx 1) %n => %n
OpMultiplicand->takeName(&II);
return IC.replaceInstUsesWith(II, OpMultiplicand);
@@ -763,22 +854,40 @@ static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
}
}
- return None;
+ return instCombineSVEVectorBinOp(IC, II);
}
+static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
+ IntrinsicInst &II) {
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+ Value *UnpackArg = II.getArgOperand(0);
+ auto *RetTy = cast<ScalableVectorType>(II.getType());
+ bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
+ II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
+
+ // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
+ // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
+ if (auto *ScalarArg = getSplatValue(UnpackArg)) {
+ ScalarArg =
+ Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
+ Value *NewVal =
+ Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
+ NewVal->takeName(&II);
+ return IC.replaceInstUsesWith(II, NewVal);
+ }
+
+ return None;
+}
static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
IntrinsicInst &II) {
auto *OpVal = II.getOperand(0);
auto *OpIndices = II.getOperand(1);
VectorType *VTy = cast<VectorType>(II.getType());
- // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with
- // constant splat value < minimal element count of result.
- auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
- if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
- return None;
-
- auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
+ // Check whether OpIndices is a constant splat value < minimal element count
+ // of result.
+ auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
if (!SplatValue ||
SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
return None;
@@ -795,6 +904,115 @@ static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
return IC.replaceInstUsesWith(II, VectorSplat);
}
+static Optional<Instruction *> instCombineSVETupleGet(InstCombiner &IC,
+ IntrinsicInst &II) {
+ // Try to remove sequences of tuple get/set.
+ Value *SetTuple, *SetIndex, *SetValue;
+ auto *GetTuple = II.getArgOperand(0);
+ auto *GetIndex = II.getArgOperand(1);
+ // Check that we have tuple_get(GetTuple, GetIndex) where GetTuple is a
+ // call to tuple_set i.e. tuple_set(SetTuple, SetIndex, SetValue).
+ // Make sure that the types of the current intrinsic and SetValue match
+ // in order to safely remove the sequence.
+ if (!match(GetTuple,
+ m_Intrinsic<Intrinsic::aarch64_sve_tuple_set>(
+ m_Value(SetTuple), m_Value(SetIndex), m_Value(SetValue))) ||
+ SetValue->getType() != II.getType())
+ return None;
+ // Case where we get the same index right after setting it.
+ // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex) --> SetValue
+ if (GetIndex == SetIndex)
+ return IC.replaceInstUsesWith(II, SetValue);
+ // If we are getting a different index than what was set in the tuple_set
+ // intrinsic, we can just set the input tuple to the one up the chain.
+ // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex)
+ // --> tuple_get(SetTuple, GetIndex)
+ return IC.replaceOperand(II, 0, SetTuple);
+}
+
+static Optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
+ IntrinsicInst &II) {
+ // zip1(uzp1(A, B), uzp2(A, B)) --> A
+ // zip2(uzp1(A, B), uzp2(A, B)) --> B
+ Value *A, *B;
+ if (match(II.getArgOperand(0),
+ m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
+ match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
+ m_Specific(A), m_Specific(B))))
+ return IC.replaceInstUsesWith(
+ II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
+
+ return None;
+}
+
+static Optional<Instruction *> instCombineLD1GatherIndex(InstCombiner &IC,
+ IntrinsicInst &II) {
+ Value *Mask = II.getOperand(0);
+ Value *BasePtr = II.getOperand(1);
+ Value *Index = II.getOperand(2);
+ Type *Ty = II.getType();
+ Type *BasePtrTy = BasePtr->getType();
+ Value *PassThru = ConstantAggregateZero::get(Ty);
+
+ // Contiguous gather => masked load.
+ // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
+ // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
+ Value *IndexBase;
+ if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
+ m_Value(IndexBase), m_SpecificInt(1)))) {
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+
+ Align Alignment =
+ BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
+
+ Type *VecPtrTy = PointerType::getUnqual(Ty);
+ Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
+ IndexBase);
+ Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
+ CallInst *MaskedLoad =
+ Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
+ MaskedLoad->takeName(&II);
+ return IC.replaceInstUsesWith(II, MaskedLoad);
+ }
+
+ return None;
+}
+
+static Optional<Instruction *> instCombineST1ScatterIndex(InstCombiner &IC,
+ IntrinsicInst &II) {
+ Value *Val = II.getOperand(0);
+ Value *Mask = II.getOperand(1);
+ Value *BasePtr = II.getOperand(2);
+ Value *Index = II.getOperand(3);
+ Type *Ty = Val->getType();
+ Type *BasePtrTy = BasePtr->getType();
+
+ // Contiguous scatter => masked store.
+ // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
+ // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
+ Value *IndexBase;
+ if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
+ m_Value(IndexBase), m_SpecificInt(1)))) {
+ IRBuilder<> Builder(II.getContext());
+ Builder.SetInsertPoint(&II);
+
+ Align Alignment =
+ BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
+
+ Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
+ IndexBase);
+ Type *VecPtrTy = PointerType::getUnqual(Ty);
+ Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
+
+ (void)Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
+
+ return IC.eraseInstFromFunction(II);
+ }
+
+ return None;
+}
+
Optional<Instruction *>
AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
IntrinsicInst &II) const {
@@ -806,6 +1024,8 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
return instCombineConvertFromSVBool(IC, II);
case Intrinsic::aarch64_sve_dup:
return instCombineSVEDup(IC, II);
+ case Intrinsic::aarch64_sve_dup_x:
+ return instCombineSVEDupX(IC, II);
case Intrinsic::aarch64_sve_cmpne:
case Intrinsic::aarch64_sve_cmpne_wide:
return instCombineSVECmpNE(IC, II);
@@ -829,8 +1049,30 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
case Intrinsic::aarch64_sve_mul:
case Intrinsic::aarch64_sve_fmul:
return instCombineSVEVectorMul(IC, II);
+ case Intrinsic::aarch64_sve_fadd:
+ return instCombineSVEVectorFAdd(IC, II);
+ case Intrinsic::aarch64_sve_fsub:
+ return instCombineSVEVectorBinOp(IC, II);
case Intrinsic::aarch64_sve_tbl:
return instCombineSVETBL(IC, II);
+ case Intrinsic::aarch64_sve_uunpkhi:
+ case Intrinsic::aarch64_sve_uunpklo:
+ case Intrinsic::aarch64_sve_sunpkhi:
+ case Intrinsic::aarch64_sve_sunpklo:
+ return instCombineSVEUnpack(IC, II);
+ case Intrinsic::aarch64_sve_tuple_get:
+ return instCombineSVETupleGet(IC, II);
+ case Intrinsic::aarch64_sve_zip1:
+ case Intrinsic::aarch64_sve_zip2:
+ return instCombineSVEZip(IC, II);
+ case Intrinsic::aarch64_sve_ld1_gather_index:
+ return instCombineLD1GatherIndex(IC, II);
+ case Intrinsic::aarch64_sve_st1_scatter_index:
+ return instCombineST1ScatterIndex(IC, II);
+ case Intrinsic::aarch64_sve_ld1:
+ return instCombineSVELD1(IC, II, DL);
+ case Intrinsic::aarch64_sve_st1:
+ return instCombineSVEST1(IC, II, DL);
}
return None;
@@ -1393,9 +1635,13 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
return (Cost + 1) * LT.first;
case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
+ case ISD::FDIV:
+ case ISD::FNEG:
// These nodes are marked as 'custom' just to lower them to SVE.
// We know said lowering will incur no additional cost.
- if (isa<FixedVectorType>(Ty) && !Ty->getScalarType()->isFP128Ty())
+ if (!Ty->getScalarType()->isFP128Ty())
return (Cost + 2) * LT.first;
return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
@@ -1525,8 +1771,7 @@ AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
-
- if (!isa<ScalableVectorType>(DataTy))
+ if (useNeonVector(DataTy))
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
Alignment, CostKind, I);
auto *VT = cast<VectorType>(DataTy);
@@ -1623,9 +1868,10 @@ InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
// ldN/stN only support legal vector types of size 64 or 128 in bits.
// Accesses having vector types that are a multiple of 128 bits can be
// matched to more than one ldN/stN instruction.
+ bool UseScalable;
if (NumElts % Factor == 0 &&
- TLI->isLegalInterleavedAccessType(SubVecTy, DL))
- return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL);
+ TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
+ return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
}
return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
@@ -1705,9 +1951,12 @@ getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
}
void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
- TTI::UnrollingPreferences &UP) {
+ TTI::UnrollingPreferences &UP,
+ OptimizationRemarkEmitter *ORE) {
// Enable partial unrolling and runtime unrolling.
- BaseT::getUnrollingPreferences(L, SE, UP);
+ BaseT::getUnrollingPreferences(L, SE, UP, ORE);
+
+ UP.UpperBound = true;
// For inner loop, it is more likely to be a hot one, and the runtime check
// can be promoted out from LICM pass, so the overhead is less, let's try
@@ -1749,7 +1998,6 @@ void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
!ST->getSchedModel().isOutOfOrder()) {
UP.Runtime = true;
UP.Partial = true;
- UP.UpperBound = true;
UP.UnrollRemainder = true;
UP.DefaultUnrollRuntimeCount = 4;
@@ -1775,7 +2023,7 @@ Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
StructType *ST = dyn_cast<StructType>(ExpectedType);
if (!ST)
return nullptr;
- unsigned NumElts = Inst->getNumArgOperands() - 1;
+ unsigned NumElts = Inst->arg_size() - 1;
if (ST->getNumElements() != NumElts)
return nullptr;
for (unsigned i = 0, e = NumElts; i != e; ++i) {
@@ -1816,7 +2064,7 @@ bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
case Intrinsic::aarch64_neon_st4:
Info.ReadMem = false;
Info.WriteMem = true;
- Info.PtrVal = Inst->getArgOperand(Inst->getNumArgOperands() - 1);
+ Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
break;
}
@@ -1892,6 +2140,8 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
case RecurKind::UMax:
case RecurKind::FMin:
case RecurKind::FMax:
+ case RecurKind::SelectICmp:
+ case RecurKind::SelectFCmp:
return true;
default:
return false;
@@ -1902,23 +2152,23 @@ InstructionCost
AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
bool IsUnsigned,
TTI::TargetCostKind CostKind) {
- if (!isa<ScalableVectorType>(Ty))
+ std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+
+ if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
- assert((isa<ScalableVectorType>(Ty) && isa<ScalableVectorType>(CondTy)) &&
- "Both vector needs to be scalable");
- std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+ assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
+ "Both vector needs to be equally scalable");
+
InstructionCost LegalizationCost = 0;
if (LT.first > 1) {
Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
- unsigned CmpOpcode =
- Ty->isFPOrFPVectorTy() ? Instruction::FCmp : Instruction::ICmp;
- LegalizationCost =
- getCmpSelInstrCost(CmpOpcode, LegalVTy, LegalVTy,
- CmpInst::BAD_ICMP_PREDICATE, CostKind) +
- getCmpSelInstrCost(Instruction::Select, LegalVTy, LegalVTy,
- CmpInst::BAD_ICMP_PREDICATE, CostKind);
- LegalizationCost *= LT.first - 1;
+ unsigned MinMaxOpcode =
+ Ty->isFPOrFPVectorTy()
+ ? Intrinsic::maxnum
+ : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
+ IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
+ LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
}
return LegalizationCost + /*Cost of horizontal reduction*/ 2;
@@ -1954,8 +2204,13 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
Optional<FastMathFlags> FMF,
TTI::TargetCostKind CostKind) {
if (TTI::requiresOrderedReduction(FMF)) {
- if (!isa<ScalableVectorType>(ValTy))
- return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+ if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
+ InstructionCost BaseCost =
+ BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+ // Add on extra cost to reflect the extra overhead on some CPUs. We still
+ // end up vectorizing for more computationally intensive loops.
+ return BaseCost + FixedVTy->getNumElements();
+ }
if (Opcode != Instruction::FAdd)
return InstructionCost::getInvalid();