Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 479 |
1 file changed, 367 insertions, 112 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 01236aa6b527..63d6fa5bbb26 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -9,11 +9,13 @@
 #include "AArch64TargetTransformInfo.h"
 #include "AArch64ExpandImm.h"
 #include "MCTargetDesc/AArch64AddressingModes.h"
+#include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/CodeGen/CostTable.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/PatternMatch.h"
@@ -220,19 +222,15 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   auto *RetTy = ICA.getReturnType();
   switch (ICA.getID()) {
   case Intrinsic::umin:
-  case Intrinsic::umax: {
-    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
-    // umin(x,y) -> sub(x,usubsat(x,y))
-    // umax(x,y) -> add(x,usubsat(y,x))
-    if (LT.second == MVT::v2i64)
-      return LT.first * 2;
-    LLVM_FALLTHROUGH;
-  }
+  case Intrinsic::umax:
   case Intrinsic::smin:
   case Intrinsic::smax: {
     static const auto ValidMinMaxTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
                                         MVT::v8i16, MVT::v2i32, MVT::v4i32};
     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+    // v2i64 types get converted to cmp+bif hence the cost of 2
+    if (LT.second == MVT::v2i64)
+      return LT.first * 2;
     if (any_of(ValidMinMaxTys, [&](MVT M) { return M == LT.second; }))
       return LT.first;
     break;
@@ -291,13 +289,15 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
     const auto *Entry = CostTableLookup(BitreverseTbl, ICA.getID(),
                                         LegalisationCost.second);
-    // Cost Model is using the legal type(i32) that i8 and i16 will be converted
-    // to +1 so that we match the actual lowering cost
-    if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
-        TLI->getValueType(DL, RetTy, true) == MVT::i16)
-      return LegalisationCost.first * Entry->Cost + 1;
-    if (Entry)
+    if (Entry) {
+      // Cost Model is using the legal type(i32) that i8 and i16 will be
+      // converted to +1 so that we match the actual lowering cost
+      if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
+          TLI->getValueType(DL, RetTy, true) == MVT::i16)
+        return LegalisationCost.first * Entry->Cost + 1;
+
       return LegalisationCost.first * Entry->Cost;
+    }
     break;
   }
   case Intrinsic::ctpop: {
@@ -440,6 +440,18 @@ static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, Insert);
 }
 
+static Optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
+                                                  IntrinsicInst &II) {
+  // Replace DupX with a regular IR splat.
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto *RetTy = cast<ScalableVectorType>(II.getType());
+  Value *Splat =
+      Builder.CreateVectorSplat(RetTy->getElementCount(), II.getArgOperand(0));
+  Splat->takeName(&II);
+  return IC.replaceInstUsesWith(II, Splat);
+}
+
 static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                    IntrinsicInst &II) {
   LLVMContext &Ctx = II.getContext();
@@ -457,12 +469,9 @@ static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
     return None;
 
   // Check that we have a compare of zero..
-  auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
-  if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
-    return None;
-
-  auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
-  if (!DupXArg || !DupXArg->isZero())
+  auto *SplatValue =
+      dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
+  if (!SplatValue || !SplatValue->isZero())
     return None;
 
   // ..against a dupq
@@ -547,14 +556,34 @@ static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
 
 static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                   IntrinsicInst &II) {
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
   Value *Pg = II.getArgOperand(0);
   Value *Vec = II.getArgOperand(1);
-  bool IsAfter = II.getIntrinsicID() == Intrinsic::aarch64_sve_lasta;
+  auto IntrinsicID = II.getIntrinsicID();
+  bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
 
   // lastX(splat(X)) --> X
   if (auto *SplatVal = getSplatValue(Vec))
     return IC.replaceInstUsesWith(II, SplatVal);
 
+  // If x and/or y is a splat value then:
+  //   lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
+  Value *LHS, *RHS;
+  if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
+    if (isSplatValue(LHS) || isSplatValue(RHS)) {
+      auto *OldBinOp = cast<BinaryOperator>(Vec);
+      auto OpC = OldBinOp->getOpcode();
+      auto *NewLHS =
+          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
+      auto *NewRHS =
+          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
+      auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
+          OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
+      return IC.replaceInstUsesWith(II, NewBinOp);
+    }
+  }
+
   auto *C = dyn_cast<Constant>(Pg);
   if (IsAfter && C && C->isNullValue()) {
     // The intrinsic is extracting lane 0 so use an extract instead.
@@ -576,39 +605,11 @@ static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
 
   // Can the intrinsic's predicate be converted to a known constant index?
-  unsigned Idx;
-  switch (PTruePattern) {
-  default:
+  unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
+  if (!MinNumElts)
     return None;
-  case AArch64SVEPredPattern::vl1:
-    Idx = 0;
-    break;
-  case AArch64SVEPredPattern::vl2:
-    Idx = 1;
-    break;
-  case AArch64SVEPredPattern::vl3:
-    Idx = 2;
-    break;
-  case AArch64SVEPredPattern::vl4:
-    Idx = 3;
-    break;
-  case AArch64SVEPredPattern::vl5:
-    Idx = 4;
-    break;
-  case AArch64SVEPredPattern::vl6:
-    Idx = 5;
-    break;
-  case AArch64SVEPredPattern::vl7:
-    Idx = 6;
-    break;
-  case AArch64SVEPredPattern::vl8:
-    Idx = 7;
-    break;
-  case AArch64SVEPredPattern::vl16:
-    Idx = 15;
-    break;
-  }
+  unsigned Idx = MinNumElts - 1;
 
   // Increment the index if extracting the element after the last active
   // predicate element.
   if (IsAfter)
@@ -661,26 +662,9 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
     return IC.replaceInstUsesWith(II, VScale);
   }
 
-  unsigned MinNumElts = 0;
-  switch (Pattern) {
-  default:
-    return None;
-  case AArch64SVEPredPattern::vl1:
-  case AArch64SVEPredPattern::vl2:
-  case AArch64SVEPredPattern::vl3:
-  case AArch64SVEPredPattern::vl4:
-  case AArch64SVEPredPattern::vl5:
-  case AArch64SVEPredPattern::vl6:
-  case AArch64SVEPredPattern::vl7:
-  case AArch64SVEPredPattern::vl8:
-    MinNumElts = Pattern;
-    break;
-  case AArch64SVEPredPattern::vl16:
-    MinNumElts = 16;
-    break;
-  }
+  unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
 
-  return NumElts >= MinNumElts
+  return MinNumElts && NumElts >= MinNumElts
              ? Optional<Instruction *>(IC.replaceInstUsesWith(
                    II, ConstantInt::get(II.getType(), MinNumElts)))
              : None;
@@ -711,6 +695,116 @@ static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
   return None;
 }
 
+static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+  Value *P = II.getOperand(0);
+  Value *A = II.getOperand(1);
+  auto FMul = II.getOperand(2);
+  Value *B, *C;
+  if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
+                       m_Specific(P), m_Value(B), m_Value(C))))
+    return None;
+
+  if (!FMul->hasOneUse())
+    return None;
+
+  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+  // Stop the combine when the flags on the inputs differ in case dropping flags
+  // would lead to us missing out on more beneficial optimizations.
+  if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
+    return None;
+  if (!FAddFlags.allowContract())
+    return None;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+                                      {II.getType()}, {P, A, B, C}, &II);
+  FMLA->setFastMathFlags(FAddFlags);
+  return IC.replaceInstUsesWith(II, FMLA);
+}
+
+static Optional<Instruction *>
+instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+
+  Value *Pred = II.getOperand(0);
+  Value *PtrOp = II.getOperand(1);
+  Type *VecTy = II.getType();
+  Value *VecPtr = Builder.CreateBitCast(PtrOp, VecTy->getPointerTo());
+
+  if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+                      m_ConstantInt<AArch64SVEPredPattern::all>()))) {
+    LoadInst *Load = Builder.CreateLoad(VecTy, VecPtr);
+    return IC.replaceInstUsesWith(II, Load);
+  }
+
+  CallInst *MaskedLoad =
+      Builder.CreateMaskedLoad(VecTy, VecPtr, PtrOp->getPointerAlignment(DL),
+                               Pred, ConstantAggregateZero::get(VecTy));
+  return IC.replaceInstUsesWith(II, MaskedLoad);
+}
+
+static Optional<Instruction *>
+instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+
+  Value *VecOp = II.getOperand(0);
+  Value *Pred = II.getOperand(1);
+  Value *PtrOp = II.getOperand(2);
+  Value *VecPtr =
+      Builder.CreateBitCast(PtrOp, VecOp->getType()->getPointerTo());
+
+  if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+                      m_ConstantInt<AArch64SVEPredPattern::all>()))) {
+    Builder.CreateStore(VecOp, VecPtr);
+    return IC.eraseInstFromFunction(II);
+  }
+
+  Builder.CreateMaskedStore(VecOp, VecPtr, PtrOp->getPointerAlignment(DL),
+                            Pred);
+  return IC.eraseInstFromFunction(II);
+}
+
+static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
+  switch (Intrinsic) {
+  case Intrinsic::aarch64_sve_fmul:
+    return Instruction::BinaryOps::FMul;
+  case Intrinsic::aarch64_sve_fadd:
+    return Instruction::BinaryOps::FAdd;
+  case Intrinsic::aarch64_sve_fsub:
+    return Instruction::BinaryOps::FSub;
+  default:
+    return Instruction::BinaryOpsEnd;
+  }
+}
+
+static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
+                                                         IntrinsicInst &II) {
+  auto *OpPredicate = II.getOperand(0);
+  auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
+  if (BinOpCode == Instruction::BinaryOpsEnd ||
+      !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+                              m_ConstantInt<AArch64SVEPredPattern::all>())))
+    return None;
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  Builder.setFastMathFlags(II.getFastMathFlags());
+  auto BinOp =
+      Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
+  return IC.replaceInstUsesWith(II, BinOp);
+}
+
+static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
+    return FMLA;
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -720,14 +814,11 @@ static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
   IRBuilder<> Builder(II.getContext());
   Builder.SetInsertPoint(&II);
 
-  // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
-  // with a unit splat value, false otherwise.
-  auto IsUnitDupX = [](auto *I) {
-    auto *IntrI = dyn_cast<IntrinsicInst>(I);
-    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+  // Return true if a given instruction is a unit splat value, false otherwise.
+  auto IsUnitSplat = [](auto *I) {
+    auto *SplatValue = getSplatValue(I);
+    if (!SplatValue)
      return false;
-
-    auto *SplatValue = IntrI->getOperand(0);
     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
   };
 
@@ -744,10 +835,10 @@ static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
 
   // The OpMultiplier variable should always point to the dup (if any), so
   // swap if necessary.
-  if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
+  if (IsUnitDup(OpMultiplicand) || IsUnitSplat(OpMultiplicand))
     std::swap(OpMultiplier, OpMultiplicand);
 
-  if (IsUnitDupX(OpMultiplier)) {
+  if (IsUnitSplat(OpMultiplier)) {
     // [f]mul pg (dupx 1) %n => %n
     OpMultiplicand->takeName(&II);
     return IC.replaceInstUsesWith(II, OpMultiplicand);
@@ -763,22 +854,40 @@ static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
     }
   }
 
-  return None;
+  return instCombineSVEVectorBinOp(IC, II);
 }
 
+static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
+                                                    IntrinsicInst &II) {
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  Value *UnpackArg = II.getArgOperand(0);
+  auto *RetTy = cast<ScalableVectorType>(II.getType());
+  bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
+                  II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
+
+  // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
+  // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
+  if (auto *ScalarArg = getSplatValue(UnpackArg)) {
+    ScalarArg =
+        Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
+    Value *NewVal =
+        Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
+    NewVal->takeName(&II);
+    return IC.replaceInstUsesWith(II, NewVal);
+  }
+
+  return None;
+}
 static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
                                                  IntrinsicInst &II) {
   auto *OpVal = II.getOperand(0);
   auto *OpIndices = II.getOperand(1);
   VectorType *VTy = cast<VectorType>(II.getType());
 
-  // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with
-  // constant splat value < minimal element count of result.
-  auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
-  if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
-    return None;
-
-  auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
+  // Check whether OpIndices is a constant splat value < minimal element count
+  // of result.
+  auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
   if (!SplatValue ||
       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
     return None;
@@ -795,6 +904,115 @@ static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, VectorSplat);
 }
 
+static Optional<Instruction *> instCombineSVETupleGet(InstCombiner &IC,
+                                                      IntrinsicInst &II) {
+  // Try to remove sequences of tuple get/set.
+  Value *SetTuple, *SetIndex, *SetValue;
+  auto *GetTuple = II.getArgOperand(0);
+  auto *GetIndex = II.getArgOperand(1);
+  // Check that we have tuple_get(GetTuple, GetIndex) where GetTuple is a
+  // call to tuple_set i.e. tuple_set(SetTuple, SetIndex, SetValue).
+  // Make sure that the types of the current intrinsic and SetValue match
+  // in order to safely remove the sequence.
+  if (!match(GetTuple,
+             m_Intrinsic<Intrinsic::aarch64_sve_tuple_set>(
+                 m_Value(SetTuple), m_Value(SetIndex), m_Value(SetValue))) ||
+      SetValue->getType() != II.getType())
+    return None;
+  // Case where we get the same index right after setting it.
+  // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex) --> SetValue
+  if (GetIndex == SetIndex)
+    return IC.replaceInstUsesWith(II, SetValue);
+  // If we are getting a different index than what was set in the tuple_set
+  // intrinsic. We can just set the input tuple to the one up in the chain.
+  // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex)
+  // --> tuple_get(SetTuple, GetIndex)
+  return IC.replaceOperand(II, 0, SetTuple);
+}
+
+static Optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
+                                                 IntrinsicInst &II) {
+  // zip1(uzp1(A, B), uzp2(A, B)) --> A
+  // zip2(uzp1(A, B), uzp2(A, B)) --> B
+  Value *A, *B;
+  if (match(II.getArgOperand(0),
+            m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
+      match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
+                                     m_Specific(A), m_Specific(B))))
+    return IC.replaceInstUsesWith(
+        II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
+
+  return None;
+}
+
+static Optional<Instruction *> instCombineLD1GatherIndex(InstCombiner &IC,
+                                                         IntrinsicInst &II) {
+  Value *Mask = II.getOperand(0);
+  Value *BasePtr = II.getOperand(1);
+  Value *Index = II.getOperand(2);
+  Type *Ty = II.getType();
+  Type *BasePtrTy = BasePtr->getType();
+  Value *PassThru = ConstantAggregateZero::get(Ty);
+
+  // Contiguous gather => masked load.
+  // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
+  // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
+  Value *IndexBase;
+  if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
+                       m_Value(IndexBase), m_SpecificInt(1)))) {
+    IRBuilder<> Builder(II.getContext());
+    Builder.SetInsertPoint(&II);
+
+    Align Alignment =
+        BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
+
+    Type *VecPtrTy = PointerType::getUnqual(Ty);
+    Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
+                                   IndexBase);
+    Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
+    CallInst *MaskedLoad =
+        Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
+    MaskedLoad->takeName(&II);
+    return IC.replaceInstUsesWith(II, MaskedLoad);
+  }
+
+  return None;
+}
+
+static Optional<Instruction *> instCombineST1ScatterIndex(InstCombiner &IC,
+                                                          IntrinsicInst &II) {
+  Value *Val = II.getOperand(0);
+  Value *Mask = II.getOperand(1);
+  Value *BasePtr = II.getOperand(2);
+  Value *Index = II.getOperand(3);
+  Type *Ty = Val->getType();
+  Type *BasePtrTy = BasePtr->getType();
+
+  // Contiguous scatter => masked store.
+  // (sve.ld1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
+  // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
+  Value *IndexBase;
+  if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
+                       m_Value(IndexBase), m_SpecificInt(1)))) {
+    IRBuilder<> Builder(II.getContext());
+    Builder.SetInsertPoint(&II);
+
+    Align Alignment =
+        BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
+
+    Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
+                                   IndexBase);
+    Type *VecPtrTy = PointerType::getUnqual(Ty);
+    Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
+
+    (void)Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
+
+    return IC.eraseInstFromFunction(II);
+  }
+
+  return None;
+}
+
 Optional<Instruction *>
 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
                                      IntrinsicInst &II) const {
@@ -806,6 +1024,8 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
     return instCombineConvertFromSVBool(IC, II);
   case Intrinsic::aarch64_sve_dup:
     return instCombineSVEDup(IC, II);
+  case Intrinsic::aarch64_sve_dup_x:
+    return instCombineSVEDupX(IC, II);
   case Intrinsic::aarch64_sve_cmpne:
   case Intrinsic::aarch64_sve_cmpne_wide:
     return instCombineSVECmpNE(IC, II);
@@ -829,8 +1049,30 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_mul:
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
+  case Intrinsic::aarch64_sve_fadd:
+    return instCombineSVEVectorFAdd(IC, II);
+  case Intrinsic::aarch64_sve_fsub:
+    return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:
     return instCombineSVETBL(IC, II);
+  case Intrinsic::aarch64_sve_uunpkhi:
+  case Intrinsic::aarch64_sve_uunpklo:
+  case Intrinsic::aarch64_sve_sunpkhi:
+  case Intrinsic::aarch64_sve_sunpklo:
+    return instCombineSVEUnpack(IC, II);
+  case Intrinsic::aarch64_sve_tuple_get:
+    return instCombineSVETupleGet(IC, II);
+  case Intrinsic::aarch64_sve_zip1:
+  case Intrinsic::aarch64_sve_zip2:
+    return instCombineSVEZip(IC, II);
+  case Intrinsic::aarch64_sve_ld1_gather_index:
+    return instCombineLD1GatherIndex(IC, II);
+  case Intrinsic::aarch64_sve_st1_scatter_index:
+    return instCombineST1ScatterIndex(IC, II);
+  case Intrinsic::aarch64_sve_ld1:
+    return instCombineSVELD1(IC, II, DL);
+  case Intrinsic::aarch64_sve_st1:
+    return instCombineSVEST1(IC, II, DL);
   }
 
   return None;
@@ -1393,9 +1635,13 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     return (Cost + 1) * LT.first;
 
   case ISD::FADD:
+  case ISD::FSUB:
+  case ISD::FMUL:
+  case ISD::FDIV:
+  case ISD::FNEG:
     // These nodes are marked as 'custom' just to lower them to SVE.
     // We know said lowering will incur no additional cost.
-    if (isa<FixedVectorType>(Ty) && !Ty->getScalarType()->isFP128Ty())
+    if (!Ty->getScalarType()->isFP128Ty())
      return (Cost + 2) * LT.first;
 
     return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
@@ -1525,8 +1771,7 @@ AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
-
-  if (!isa<ScalableVectorType>(DataTy))
+  if (useNeonVector(DataTy))
     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                          Alignment, CostKind, I);
   auto *VT = cast<VectorType>(DataTy);
@@ -1623,9 +1868,10 @@ InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
     // ldN/stN only support legal vector types of size 64 or 128 in bits.
     // Accesses having vector types that are a multiple of 128 bits can be
     // matched to more than one ldN/stN instruction.
+    bool UseScalable;
     if (NumElts % Factor == 0 &&
-        TLI->isLegalInterleavedAccessType(SubVecTy, DL))
-      return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL);
+        TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
+      return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
   }
 
   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
@@ -1705,9 +1951,12 @@ getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
 }
 
 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
-                                             TTI::UnrollingPreferences &UP) {
+                                             TTI::UnrollingPreferences &UP,
+                                             OptimizationRemarkEmitter *ORE) {
   // Enable partial unrolling and runtime unrolling.
-  BaseT::getUnrollingPreferences(L, SE, UP);
+  BaseT::getUnrollingPreferences(L, SE, UP, ORE);
+
+  UP.UpperBound = true;
 
   // For inner loop, it is more likely to be a hot one, and the runtime check
   // can be promoted out from LICM pass, so the overhead is less, let's try
@@ -1749,7 +1998,6 @@ void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
       !ST->getSchedModel().isOutOfOrder()) {
     UP.Runtime = true;
     UP.Partial = true;
-    UP.UpperBound = true;
     UP.UnrollRemainder = true;
     UP.DefaultUnrollRuntimeCount = 4;
 
@@ -1775,7 +2023,7 @@ Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
   StructType *ST = dyn_cast<StructType>(ExpectedType);
   if (!ST)
     return nullptr;
-  unsigned NumElts = Inst->getNumArgOperands() - 1;
+  unsigned NumElts = Inst->arg_size() - 1;
   if (ST->getNumElements() != NumElts)
     return nullptr;
   for (unsigned i = 0, e = NumElts; i != e; ++i) {
@@ -1816,7 +2064,7 @@ bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
   case Intrinsic::aarch64_neon_st4:
     Info.ReadMem = false;
     Info.WriteMem = true;
-    Info.PtrVal = Inst->getArgOperand(Inst->getNumArgOperands() - 1);
+    Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
     break;
   }
 
@@ -1892,6 +2140,8 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
   case RecurKind::UMax:
   case RecurKind::FMin:
   case RecurKind::FMax:
+  case RecurKind::SelectICmp:
+  case RecurKind::SelectFCmp:
     return true;
   default:
     return false;
@@ -1902,23 +2152,23 @@ InstructionCost
 AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
                                        bool IsUnsigned,
                                        TTI::TargetCostKind CostKind) {
-  if (!isa<ScalableVectorType>(Ty))
+  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+
+  if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
     return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
-  assert((isa<ScalableVectorType>(Ty) && isa<ScalableVectorType>(CondTy)) &&
-         "Both vector needs to be scalable");
 
-  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+  assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
+         "Both vector needs to be equally scalable");
+
   InstructionCost LegalizationCost = 0;
   if (LT.first > 1) {
     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
-    unsigned CmpOpcode =
-        Ty->isFPOrFPVectorTy() ? Instruction::FCmp : Instruction::ICmp;
-    LegalizationCost =
-        getCmpSelInstrCost(CmpOpcode, LegalVTy, LegalVTy,
-                           CmpInst::BAD_ICMP_PREDICATE, CostKind) +
-        getCmpSelInstrCost(Instruction::Select, LegalVTy, LegalVTy,
-                           CmpInst::BAD_ICMP_PREDICATE, CostKind);
-    LegalizationCost *= LT.first - 1;
+    unsigned MinMaxOpcode =
+        Ty->isFPOrFPVectorTy()
+            ? Intrinsic::maxnum
+            : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
+    IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
+    LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
   }
 
   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
@@ -1954,8 +2204,13 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
                                            Optional<FastMathFlags> FMF,
                                            TTI::TargetCostKind CostKind) {
   if (TTI::requiresOrderedReduction(FMF)) {
-    if (!isa<ScalableVectorType>(ValTy))
-      return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+    if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
+      InstructionCost BaseCost =
+          BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+      // Add on extra cost to reflect the extra overhead on some CPUs. We still
+      // end up vectorizing for more computationally intensive loops.
+      return BaseCost + FixedVTy->getNumElements();
+    }
 
     if (Opcode != Instruction::FAdd)
       return InstructionCost::getInvalid();
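Note (illustrative only, not part of the commit): instCombineSVEDupX above rewrites the aarch64.sve.dup.x intrinsic into a generic IR splat built with CreateVectorSplat, so the later splat-based folds (getSplatValue) can see it. A minimal LLVM IR sketch of the before/after shape, assuming the standard overloaded intrinsic name for an nxv4i32 result; the function names are hypothetical:

  declare <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32)

  ; before the combine: the broadcast is expressed through the SVE intrinsic
  define <vscale x 4 x i32> @dupx_before(i32 %x) {
    %s = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 %x)
    ret <vscale x 4 x i32> %s
  }

  ; after the combine: the generic insertelement + shufflevector splat idiom
  define <vscale x 4 x i32> @dupx_after(i32 %x) {
    %ins = insertelement <vscale x 4 x i32> undef, i32 %x, i64 0
    %s = shufflevector <vscale x 4 x i32> %ins, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
    ret <vscale x 4 x i32> %s
  }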
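Note (illustrative only, not part of the commit): instCombineSVEVectorFMLA fires when a single-use sve.fmul feeds an sve.fadd under the same predicate and both calls carry identical fast-math flags that include 'contract'. A hypothetical IR input that satisfies those conditions, using the standard nxv4f32 overloads:

  declare <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
  declare <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)

  define <vscale x 4 x float> @fmla_candidate(<vscale x 4 x i1> %pg, <vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
    ; %mul has a single use, shares %pg with the fadd, and both are 'contract'
    %mul = call contract <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %b, <vscale x 4 x float> %c)
    %add = call contract <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1> %pg, <vscale x 4 x float> %a, <vscale x 4 x float> %mul)
    ret <vscale x 4 x float> %add
  }

  ; expected rewrite: a single predicated multiply-add
  ;   %r = call contract <vscale x 4 x float> @llvm.aarch64.sve.fmla.nxv4f32(
  ;            <vscale x 4 x i1> %pg, <vscale x 4 x float> %a,
  ;            <vscale x 4 x float> %b, <vscale x 4 x float> %c)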
