| author | Dimitry Andric <dim@FreeBSD.org> | 2021-07-29 20:15:26 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2021-07-29 20:15:26 +0000 |
| commit | 344a3780b2e33f6ca763666c380202b18aab72a3 (patch) | |
| tree | f0b203ee6eb71d7fdd792373e3c81eb18d6934dd /llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | |
| parent | b60736ec1405bb0a8dd40989f67ef4c93da068ab (diff) | |

Refs pointing at this commit:

- vendor/llvm-project/llvmorg-13-init-16847-g88e66fa60ae5
- vendor/llvm-project/llvmorg-12.0.1-rc2-0-ge7dac564cd0e
- vendor/llvm-project/llvmorg-12.0.1-0-gfed41342a82f
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 1269 |
1 file changed, 1114 insertions, 155 deletions
```diff
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7fda6b8fb602..01236aa6b527 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -18,6 +18,7 @@
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <algorithm>
 using namespace llvm;
 using namespace llvm::PatternMatch;
@@ -44,7 +45,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
 /// Calculate the cost of materializing a 64-bit value. This helper
 /// method might only calculate a fraction of a larger immediate. Therefore it
 /// is valid to return a cost of ZERO.
-int AArch64TTIImpl::getIntImmCost(int64_t Val) {
+InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
   // Check if the immediate can be encoded within an instruction.
   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
     return 0;
@@ -59,8 +60,8 @@ int AArch64TTIImpl::getIntImmCost(int64_t Val) {
 }
 
 /// Calculate the cost of materializing the given constant.
-int AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
-                                  TTI::TargetCostKind CostKind) {
+InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
+                                              TTI::TargetCostKind CostKind) {
   assert(Ty->isIntegerTy());
 
   unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -74,20 +75,20 @@ int AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
 
   // Split the constant into 64-bit chunks and calculate the cost for each
   // chunk.
-  int Cost = 0;
+  InstructionCost Cost = 0;
   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
     int64_t Val = Tmp.getSExtValue();
     Cost += getIntImmCost(Val);
   }
   // We need at least one instruction to materialze the constant.
-  return std::max(1, Cost);
+  return std::max<InstructionCost>(1, Cost);
 }
 
-int AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
-                                      const APInt &Imm, Type *Ty,
-                                      TTI::TargetCostKind CostKind,
-                                      Instruction *Inst) {
+InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
+                                                  const APInt &Imm, Type *Ty,
+                                                  TTI::TargetCostKind CostKind,
+                                                  Instruction *Inst) {
   assert(Ty->isIntegerTy());
 
   unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -144,7 +145,7 @@ int AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
   if (Idx == ImmIdx) {
     int NumConstants = (BitSize + 63) / 64;
-    int Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
+    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
     return (Cost <= NumConstants * TTI::TCC_Basic)
                ? static_cast<int>(TTI::TCC_Free)
                : Cost;
@@ -152,9 +153,10 @@ int AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
 }
 
-int AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
-                                        const APInt &Imm, Type *Ty,
-                                        TTI::TargetCostKind CostKind) {
+InstructionCost
+AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
+                                    const APInt &Imm, Type *Ty,
+                                    TTI::TargetCostKind CostKind) {
   assert(Ty->isIntegerTy());
 
   unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -180,7 +182,7 @@ int AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
   case Intrinsic::umul_with_overflow:
     if (Idx == 1) {
       int NumConstants = (BitSize + 63) / 64;
-      int Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
+      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
       return (Cost <= NumConstants * TTI::TCC_Basic)
                  ? static_cast<int>(TTI::TCC_Free)
                  : Cost;
```
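The `int` → `InstructionCost` change above recurs throughout this diff. `InstructionCost` can carry an invalid/unknown state, which is also why `std::max(1, Cost)` becomes `std::max<InstructionCost>(1, Cost)`: with mixed `int`/`InstructionCost` arguments, template argument deduction would otherwise fail. A minimal sketch of how a client pass might consult these immediate-cost hooks — the helper name and setup are illustrative, not part of the patch:

```cpp
#include "llvm/Analysis/TargetTransformInfo.h"
using namespace llvm;

// Hypothetical helper: is this constant operand cheap enough to leave in
// place? Assumes an initialized TargetTransformInfo for an AArch64 target.
static bool isImmCheap(const TargetTransformInfo &TTI, const APInt &Imm,
                       Type *Ty) {
  InstructionCost Cost = TTI.getIntImmCostInst(
      Instruction::Add, /*Idx=*/1, Imm, Ty,
      TargetTransformInfo::TCK_SizeAndLatency);
  // Invalid costs must be checked before any comparison is meaningful.
  return Cost.isValid() && Cost <= TargetTransformInfo::TCC_Basic;
}
```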
```diff
@@ -212,7 +214,7 @@ AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
   return TTI::PSK_Software;
 }
 
-unsigned
+InstructionCost
 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                       TTI::TargetCostKind CostKind) {
   auto *RetTy = ICA.getReturnType();
@@ -235,12 +237,605 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
       return LT.first;
     break;
   }
+  case Intrinsic::sadd_sat:
+  case Intrinsic::ssub_sat:
+  case Intrinsic::uadd_sat:
+  case Intrinsic::usub_sat: {
+    static const auto ValidSatTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
+                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
+                                     MVT::v2i64};
+    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+    // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
+    // need to extend the type, as it uses shr(qadd(shl, shl)).
+    unsigned Instrs =
+        LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
+    if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
+      return LT.first * Instrs;
+    break;
+  }
+  case Intrinsic::abs: {
+    static const auto ValidAbsTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
+                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
+                                     MVT::v2i64};
+    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+    if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
+      return LT.first;
+    break;
+  }
+  case Intrinsic::experimental_stepvector: {
+    InstructionCost Cost = 1; // Cost of the `index' instruction
+    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+    // Legalisation of illegal vectors involves an `index' instruction plus
+    // (LT.first - 1) vector adds.
+    if (LT.first > 1) {
+      Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
+      InstructionCost AddCost =
+          getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
+      Cost += AddCost * (LT.first - 1);
+    }
+    return Cost;
+  }
+  case Intrinsic::bitreverse: {
+    static const CostTblEntry BitreverseTbl[] = {
+        {Intrinsic::bitreverse, MVT::i32, 1},
+        {Intrinsic::bitreverse, MVT::i64, 1},
+        {Intrinsic::bitreverse, MVT::v8i8, 1},
+        {Intrinsic::bitreverse, MVT::v16i8, 1},
+        {Intrinsic::bitreverse, MVT::v4i16, 2},
+        {Intrinsic::bitreverse, MVT::v8i16, 2},
+        {Intrinsic::bitreverse, MVT::v2i32, 2},
+        {Intrinsic::bitreverse, MVT::v4i32, 2},
+        {Intrinsic::bitreverse, MVT::v1i64, 2},
+        {Intrinsic::bitreverse, MVT::v2i64, 2},
+    };
+    const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
+    const auto *Entry =
+        CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
+    // Cost Model is using the legal type(i32) that i8 and i16 will be
+    // converted to +1 so that we match the actual lowering cost
+    if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
+        TLI->getValueType(DL, RetTy, true) == MVT::i16)
+      return LegalisationCost.first * Entry->Cost + 1;
+    if (Entry)
+      return LegalisationCost.first * Entry->Cost;
+    break;
+  }
+  case Intrinsic::ctpop: {
+    static const CostTblEntry CtpopCostTbl[] = {
+        {ISD::CTPOP, MVT::v2i64, 4},
+        {ISD::CTPOP, MVT::v4i32, 3},
+        {ISD::CTPOP, MVT::v8i16, 2},
+        {ISD::CTPOP, MVT::v16i8, 1},
+        {ISD::CTPOP, MVT::i64,   4},
+        {ISD::CTPOP, MVT::v2i32, 3},
+        {ISD::CTPOP, MVT::v4i16, 2},
+        {ISD::CTPOP, MVT::v8i8,  1},
+        {ISD::CTPOP, MVT::i32,   5},
+    };
+    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
+    MVT MTy = LT.second;
+    if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
+      // Extra cost of +1 when illegal vector types are legalized by promoting
+      // the integer type.
+      int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
+                                            RetTy->getScalarSizeInBits()
+                          ? 1
+                          : 0;
+      return LT.first * Entry->Cost + ExtraCost;
+    }
+    break;
+  }
   default:
     break;
   }
   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
 }
```
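A sketch of how these new table entries are exercised — roughly what the vectorizer's cost model does. The variable setup (`TTI`, `II`) is assumed, not shown in the patch:

```cpp
// Assumes: TargetTransformInfo &TTI for an AArch64 target, and II pointing
// at a call such as @llvm.ctpop.v16i8. For v16i8 the table above yields 1
// (a single NEON 'cnt'); wider element types also pay for the widening
// pairwise adds, hence the larger entries.
IntrinsicCostAttributes Attrs(II->getIntrinsicID(), *II);
InstructionCost Cost = TTI.getIntrinsicInstrCost(
    Attrs, TargetTransformInfo::TCK_RecipThroughput);
```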
```diff
+/// The function will remove redundant reinterprets casting in the presence
+/// of the control flow
+static Optional<Instruction *> processPhiNode(InstCombiner &IC,
+                                              IntrinsicInst &II) {
+  SmallVector<Instruction *, 32> Worklist;
+  auto RequiredType = II.getType();
+
+  auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
+  assert(PN && "Expected Phi Node!");
+
+  // Don't create a new Phi unless we can remove the old one.
+  if (!PN->hasOneUse())
+    return None;
+
+  for (Value *IncValPhi : PN->incoming_values()) {
+    auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
+    if (!Reinterpret ||
+        Reinterpret->getIntrinsicID() !=
+            Intrinsic::aarch64_sve_convert_to_svbool ||
+        RequiredType != Reinterpret->getArgOperand(0)->getType())
+      return None;
+  }
+
+  // Create the new Phi
+  LLVMContext &Ctx = PN->getContext();
+  IRBuilder<> Builder(Ctx);
+  Builder.SetInsertPoint(PN);
+  PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
+  Worklist.push_back(PN);
+
+  for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
+    auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
+    NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
+    Worklist.push_back(Reinterpret);
+  }
+
+  // Cleanup Phi Node and reinterprets
+  return IC.replaceInstUsesWith(II, NPN);
+}
+
+static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
+                                                            IntrinsicInst &II) {
+  // If the reinterpret instruction operand is a PHI Node
+  if (isa<PHINode>(II.getArgOperand(0)))
+    return processPhiNode(IC, II);
+
+  SmallVector<Instruction *, 32> CandidatesForRemoval;
+  Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
+
+  const auto *IVTy = cast<VectorType>(II.getType());
+
+  // Walk the chain of conversions.
+  while (Cursor) {
+    // If the type of the cursor has fewer lanes than the final result, zeroing
+    // must take place, which breaks the equivalence chain.
+    const auto *CursorVTy = cast<VectorType>(Cursor->getType());
+    if (CursorVTy->getElementCount().getKnownMinValue() <
+        IVTy->getElementCount().getKnownMinValue())
+      break;
+
+    // If the cursor has the same type as I, it is a viable replacement.
+    if (Cursor->getType() == IVTy)
+      EarliestReplacement = Cursor;
+
+    auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
+
+    // If this is not an SVE conversion intrinsic, this is the end of the chain.
+    if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
+                                  Intrinsic::aarch64_sve_convert_to_svbool ||
+                              IntrinsicCursor->getIntrinsicID() ==
+                                  Intrinsic::aarch64_sve_convert_from_svbool))
+      break;
+
+    CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
+    Cursor = IntrinsicCursor->getOperand(0);
+  }
+
+  // If no viable replacement in the conversion chain was found, there is
+  // nothing to do.
+  if (!EarliestReplacement)
+    return None;
+
+  return IC.replaceInstUsesWith(II, EarliestReplacement);
+}
```
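These `instCombine*` helpers plug into InstCombine's generic target-intrinsic hook. A hedged paraphrase of the dispatch — the real plumbing lives in the InstCombine pass and may differ in detail across LLVM versions:

```cpp
// Paraphrased, not verbatim: InstCombine offers intrinsics it has no
// generic rule for to the target via TargetTransformInfo, which lands in
// AArch64TTIImpl::instCombineIntrinsic (added later in this diff).
static Instruction *tryTargetFold(InstCombiner &IC,
                                  const TargetTransformInfo &TTI,
                                  IntrinsicInst &II) {
  if (Optional<Instruction *> Res = TTI.instCombineIntrinsic(IC, II))
    return *Res;  // e.g. the svbool conversion-chain fold above fired
  return nullptr; // fall back to generic handling
}
```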
```diff
+static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
+                                                 IntrinsicInst &II) {
+  IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
+  if (!Pg)
+    return None;
+
+  if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+    return None;
+
+  const auto PTruePattern =
+      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
+  if (PTruePattern != AArch64SVEPredPattern::vl1)
+    return None;
+
+  // The intrinsic is inserting into lane zero so use an insert instead.
+  auto *IdxTy = Type::getInt64Ty(II.getContext());
+  auto *Insert = InsertElementInst::Create(
+      II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
+  Insert->insertBefore(&II);
+  Insert->takeName(&II);
+
+  return IC.replaceInstUsesWith(II, Insert);
+}
+
+static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
+                                                   IntrinsicInst &II) {
+  LLVMContext &Ctx = II.getContext();
+  IRBuilder<> Builder(Ctx);
+  Builder.SetInsertPoint(&II);
+
+  // Check that the predicate is all active
+  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
+  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+    return None;
+
+  const auto PTruePattern =
+      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
+  if (PTruePattern != AArch64SVEPredPattern::all)
+    return None;
+
+  // Check that we have a compare of zero..
+  auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
+  if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+    return None;
+
+  auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
+  if (!DupXArg || !DupXArg->isZero())
+    return None;
+
+  // ..against a dupq
+  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
+  if (!DupQLane ||
+      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
+    return None;
+
+  // Where the dupq is a lane 0 replicate of a vector insert
+  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
+    return None;
+
+  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
+  if (!VecIns ||
+      VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
+    return None;
+
+  // Where the vector insert is a fixed constant vector insert into undef at
+  // index zero
+  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
+    return None;
+
+  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
+    return None;
+
+  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
+  if (!ConstVec)
+    return None;
+
+  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
+  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
+  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
+    return None;
+
+  unsigned NumElts = VecTy->getNumElements();
+  unsigned PredicateBits = 0;
+
+  // Expand intrinsic operands to a 16-bit byte level predicate
+  for (unsigned I = 0; I < NumElts; ++I) {
+    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
+    if (!Arg)
+      return None;
+    if (!Arg->isZero())
+      PredicateBits |= 1 << (I * (16 / NumElts));
+  }
+
+  // If all bits are zero bail early with an empty predicate
+  if (PredicateBits == 0) {
+    auto *PFalse = Constant::getNullValue(II.getType());
+    PFalse->takeName(&II);
+    return IC.replaceInstUsesWith(II, PFalse);
+  }
+
+  // Calculate largest predicate type used (where byte predicate is largest)
+  unsigned Mask = 8;
+  for (unsigned I = 0; I < 16; ++I)
+    if ((PredicateBits & (1 << I)) != 0)
+      Mask |= (I % 8);
+
+  unsigned PredSize = Mask & -Mask;
+  auto *PredType = ScalableVectorType::get(
+      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
+
+  // Ensure all relevant bits are set
+  for (unsigned I = 0; I < 16; I += PredSize)
+    if ((PredicateBits & (1 << I)) == 0)
+      return None;
+
+  auto *PTruePat =
+      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
+  auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
+                                        {PredType}, {PTruePat});
+  auto *ConvertToSVBool = Builder.CreateIntrinsic(
+      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
+  auto *ConvertFromSVBool =
+      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
+                              {II.getType()}, {ConvertToSVBool});
+
+  ConvertFromSVBool->takeName(&II);
+  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
+}
```
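At the ACLE level, the `instCombineSVEDup` fold corresponds to a merging `svdup` under a single-lane predicate. A hedged illustration — assumes `arm_sve.h` and an SVE-enabled compile; intrinsic names follow the ACLE spec, not this patch:

```cpp
#include <arm_sve.h>

// Only lane 0 is active under an SV_VL1 predicate, so this merging dup
// should fold to a plain insertelement of x into lane 0 of v.
svfloat32_t set_lane0(svfloat32_t v, float x) {
  return svdup_n_f32_m(v, svptrue_pat_b32(SV_VL1), x);
}
```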
```diff
+static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
+                                                  IntrinsicInst &II) {
+  Value *Pg = II.getArgOperand(0);
+  Value *Vec = II.getArgOperand(1);
+  bool IsAfter = II.getIntrinsicID() == Intrinsic::aarch64_sve_lasta;
+
+  // lastX(splat(X)) --> X
+  if (auto *SplatVal = getSplatValue(Vec))
+    return IC.replaceInstUsesWith(II, SplatVal);
+
+  auto *C = dyn_cast<Constant>(Pg);
+  if (IsAfter && C && C->isNullValue()) {
+    // The intrinsic is extracting lane 0 so use an extract instead.
+    auto *IdxTy = Type::getInt64Ty(II.getContext());
+    auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
+    Extract->insertBefore(&II);
+    Extract->takeName(&II);
+    return IC.replaceInstUsesWith(II, Extract);
+  }
+
+  auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
+  if (!IntrPG)
+    return None;
+
+  if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+    return None;
+
+  const auto PTruePattern =
+      cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
+
+  // Can the intrinsic's predicate be converted to a known constant index?
+  unsigned Idx;
+  switch (PTruePattern) {
+  default:
+    return None;
+  case AArch64SVEPredPattern::vl1:
+    Idx = 0;
+    break;
+  case AArch64SVEPredPattern::vl2:
+    Idx = 1;
+    break;
+  case AArch64SVEPredPattern::vl3:
+    Idx = 2;
+    break;
+  case AArch64SVEPredPattern::vl4:
+    Idx = 3;
+    break;
+  case AArch64SVEPredPattern::vl5:
+    Idx = 4;
+    break;
+  case AArch64SVEPredPattern::vl6:
+    Idx = 5;
+    break;
+  case AArch64SVEPredPattern::vl7:
+    Idx = 6;
+    break;
+  case AArch64SVEPredPattern::vl8:
+    Idx = 7;
+    break;
+  case AArch64SVEPredPattern::vl16:
+    Idx = 15;
+    break;
+  }
+
+  // Increment the index if extracting the element after the last active
+  // predicate element.
+  if (IsAfter)
+    ++Idx;
+
+  // Ignore extracts whose index is larger than the known minimum vector
+  // length. NOTE: This is an artificial constraint where we prefer to
+  // maintain what the user asked for until an alternative is proven faster.
+  auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
+  if (Idx >= PgVTy->getMinNumElements())
+    return None;
+
+  // The intrinsic is extracting a fixed lane so use an extract instead.
+  auto *IdxTy = Type::getInt64Ty(II.getContext());
+  auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
+  Extract->insertBefore(&II);
+  Extract->takeName(&II);
+  return IC.replaceInstUsesWith(II, Extract);
+}
+
+static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
+                                                IntrinsicInst &II) {
+  LLVMContext &Ctx = II.getContext();
+  IRBuilder<> Builder(Ctx);
+  Builder.SetInsertPoint(&II);
+  // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
+  // can work with RDFFR_PP for ptest elimination.
+  auto *AllPat =
+      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
+  auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
+                                        {II.getType()}, {AllPat});
+  auto *RDFFR =
+      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
+  RDFFR->takeName(&II);
+  return IC.replaceInstUsesWith(II, RDFFR);
+}
+
+static Optional<Instruction *>
+instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
+  const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+
+  if (Pattern == AArch64SVEPredPattern::all) {
+    LLVMContext &Ctx = II.getContext();
+    IRBuilder<> Builder(Ctx);
+    Builder.SetInsertPoint(&II);
+
+    Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
+    auto *VScale = Builder.CreateVScale(StepVal);
+    VScale->takeName(&II);
+    return IC.replaceInstUsesWith(II, VScale);
+  }
+
+  unsigned MinNumElts = 0;
+  switch (Pattern) {
+  default:
+    return None;
+  case AArch64SVEPredPattern::vl1:
+  case AArch64SVEPredPattern::vl2:
+  case AArch64SVEPredPattern::vl3:
+  case AArch64SVEPredPattern::vl4:
+  case AArch64SVEPredPattern::vl5:
+  case AArch64SVEPredPattern::vl6:
+  case AArch64SVEPredPattern::vl7:
+  case AArch64SVEPredPattern::vl8:
+    MinNumElts = Pattern;
+    break;
+  case AArch64SVEPredPattern::vl16:
+    MinNumElts = 16;
+    break;
+  }
+
+  return NumElts >= MinNumElts
+             ? Optional<Instruction *>(IC.replaceInstUsesWith(
+                   II, ConstantInt::get(II.getType(), MinNumElts)))
+             : None;
+}
```
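The `lasta`/`lastb` and element-count folds also have a direct ACLE reading. A hedged illustration (same assumptions as the previous sketch). Note that the `cntb` fold is valid because every SVE vector holds at least 16 bytes, whereas e.g. `svcntw_pat(SV_VL8)` would not fold, since only 4 word lanes are guaranteed:

```cpp
#include <arm_sve.h>

// lastb with a fixed-length ptrue should become an extract of lane 3.
float last_of_four(svfloat32_t v) {
  return svlastb_f32(svptrue_pat_b32(SV_VL4), v);
}

// cntb with the VL8 pattern should fold to the constant 8.
uint64_t eight(void) { return svcntb_pat(SV_VL8); }
```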
```diff
+static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
+                                                   IntrinsicInst &II) {
+  IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
+  IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
+
+  if (Op1 && Op2 &&
+      Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
+      Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
+      Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {
+
+    IRBuilder<> Builder(II.getContext());
+    Builder.SetInsertPoint(&II);
+
+    Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
+    Type *Tys[] = {Op1->getArgOperand(0)->getType()};
+
+    auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
+
+    PTest->takeName(&II);
+    return IC.replaceInstUsesWith(II, PTest);
+  }
+
+  return None;
+}
+
+static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
+                                                       IntrinsicInst &II) {
+  auto *OpPredicate = II.getOperand(0);
+  auto *OpMultiplicand = II.getOperand(1);
+  auto *OpMultiplier = II.getOperand(2);
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+
+  // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
+  // with a unit splat value, false otherwise.
+  auto IsUnitDupX = [](auto *I) {
+    auto *IntrI = dyn_cast<IntrinsicInst>(I);
+    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+      return false;
+
+    auto *SplatValue = IntrI->getOperand(0);
+    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
+  };
+
+  // Return true if a given instruction is an aarch64_sve_dup intrinsic call
+  // with a unit splat value, false otherwise.
+  auto IsUnitDup = [](auto *I) {
+    auto *IntrI = dyn_cast<IntrinsicInst>(I);
+    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
+      return false;
+
+    auto *SplatValue = IntrI->getOperand(2);
+    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
+  };
+
+  // The OpMultiplier variable should always point to the dup (if any), so
+  // swap if necessary.
+  if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
+    std::swap(OpMultiplier, OpMultiplicand);
+
+  if (IsUnitDupX(OpMultiplier)) {
+    // [f]mul pg (dupx 1) %n => %n
+    OpMultiplicand->takeName(&II);
+    return IC.replaceInstUsesWith(II, OpMultiplicand);
+  } else if (IsUnitDup(OpMultiplier)) {
+    // [f]mul pg (dup pg 1) %n => %n
+    auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
+    auto *DupPg = DupInst->getOperand(1);
+    // TODO: this is naive. The optimization is still valid if DupPg
+    // 'encompasses' OpPredicate, not only if they're the same predicate.
+    if (OpPredicate == DupPg) {
+      OpMultiplicand->takeName(&II);
+      return IC.replaceInstUsesWith(II, OpMultiplicand);
+    }
+  }
+
+  return None;
+}
+
+static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
+                                                 IntrinsicInst &II) {
+  auto *OpVal = II.getOperand(0);
+  auto *OpIndices = II.getOperand(1);
+  VectorType *VTy = cast<VectorType>(II.getType());
+
+  // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with
+  // constant splat value < minimal element count of result.
+  auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
+  if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+    return None;
+
+  auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
+  if (!SplatValue ||
+      SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
+    return None;
+
+  // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
+  // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
+  auto *VectorSplat =
+      Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
+
+  VectorSplat->takeName(&II);
+  return IC.replaceInstUsesWith(II, VectorSplat);
+}
+
+Optional<Instruction *>
+AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
+                                     IntrinsicInst &II) const {
+  Intrinsic::ID IID = II.getIntrinsicID();
+  switch (IID) {
+  default:
+    break;
+  case Intrinsic::aarch64_sve_convert_from_svbool:
+    return instCombineConvertFromSVBool(IC, II);
+  case Intrinsic::aarch64_sve_dup:
+    return instCombineSVEDup(IC, II);
+  case Intrinsic::aarch64_sve_cmpne:
+  case Intrinsic::aarch64_sve_cmpne_wide:
+    return instCombineSVECmpNE(IC, II);
+  case Intrinsic::aarch64_sve_rdffr:
+    return instCombineRDFFR(IC, II);
+  case Intrinsic::aarch64_sve_lasta:
+  case Intrinsic::aarch64_sve_lastb:
+    return instCombineSVELast(IC, II);
+  case Intrinsic::aarch64_sve_cntd:
+    return instCombineSVECntElts(IC, II, 2);
+  case Intrinsic::aarch64_sve_cntw:
+    return instCombineSVECntElts(IC, II, 4);
+  case Intrinsic::aarch64_sve_cnth:
+    return instCombineSVECntElts(IC, II, 8);
+  case Intrinsic::aarch64_sve_cntb:
+    return instCombineSVECntElts(IC, II, 16);
+  case Intrinsic::aarch64_sve_ptest_any:
+  case Intrinsic::aarch64_sve_ptest_first:
+  case Intrinsic::aarch64_sve_ptest_last:
+    return instCombineSVEPTest(IC, II);
+  case Intrinsic::aarch64_sve_mul:
+  case Intrinsic::aarch64_sve_fmul:
+    return instCombineSVEVectorMul(IC, II);
+  case Intrinsic::aarch64_sve_tbl:
+    return instCombineSVETBL(IC, II);
+  }
+
+  return None;
+}
+
 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
                                            ArrayRef<const Value *> Args) {
```
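`instCombineSVEVectorMul` strips multiplications by a splatted 1.0/1. A hedged ACLE-level illustration (assumptions as before; clang historically lowers the unpredicated `svdup_n_f32` to the `aarch64_sve_dup_x` intrinsic this fold matches):

```cpp
#include <arm_sve.h>

// The multiplier is a dup of 1.0, so the whole multiply should fold to x.
svfloat32_t identity_mul(svbool_t pg, svfloat32_t x) {
  return svmul_f32_m(pg, x, svdup_n_f32(1.0f));
}
```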
```diff
@@ -297,18 +892,21 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
     return false;
 
   // Get the total number of vector elements in the legalized types.
-  unsigned NumDstEls = DstTyL.first * DstTyL.second.getVectorMinNumElements();
-  unsigned NumSrcEls = SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
+  InstructionCost NumDstEls =
+      DstTyL.first * DstTyL.second.getVectorMinNumElements();
+  InstructionCost NumSrcEls =
+      SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
 
   // Return true if the legalized types have the same number of vector elements
   // and the destination element type size is twice that of the source type.
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
 }
 
-int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                                     TTI::CastContextHint CCH,
-                                     TTI::TargetCostKind CostKind,
-                                     const Instruction *I) {
+InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
+                                                 Type *Src,
+                                                 TTI::CastContextHint CCH,
+                                                 TTI::TargetCostKind CostKind,
+                                                 const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
@@ -333,7 +931,7 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   }
 
   // TODO: Allow non-throughput costs that aren't binary.
-  auto AdjustCost = [&CostKind](int Cost) {
+  auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
     if (CostKind != TTI::TCK_RecipThroughput)
       return Cost == 0 ? 0 : 1;
     return Cost;
@@ -353,6 +951,24 @@
     { ISD::TRUNCATE, MVT::v8i8,  MVT::v8i32,  3 },
     { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },
 
+    // Truncations on nxvmiN
+    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
+    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
+    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
+    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
+    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
+    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
+    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
+    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
+    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
+    { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
+    { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
+    { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
+    { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
+    { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
+    { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
+    { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
+
     // The number of shll instructions for the extension.
     { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
     { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
@@ -434,6 +1050,16 @@
     { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
     { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },
 
+    // Complex, from nxv2f32.
+    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
+
     // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
@@ -441,6 +1067,107 @@
     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },
+
+    // Complex, from nxv2f64.
+    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
+
+    // Complex, from nxv4f32.
+    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
+    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
+    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
+
+    // Complex, from nxv8f64. Illegal -> illegal conversions not required.
+    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
+    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
+    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
+    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
+
+    // Complex, from nxv4f64. Illegal -> illegal conversions not required.
+    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
+    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
+    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
+    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
+    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
+    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
+
+    // Complex, from nxv8f32. Illegal -> illegal conversions not required.
+    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
+    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
+    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
+    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
+
+    // Complex, from nxv8f16.
+    { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
+    { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
+    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
+    { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
+    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
+
+    // Complex, from nxv4f16.
+    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
+    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
+    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
+
+    // Complex, from nxv2f16.
+    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
+    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
+    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
+
+    // Truncate from nxvmf32 to nxvmf16.
+    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
+    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
+    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
+
+    // Truncate from nxvmf64 to nxvmf16.
+    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
+    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
+    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
+
+    // Truncate from nxvmf64 to nxvmf32.
+    { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
+    { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
+    { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
+
+    // Extend from nxvmf16 to nxvmf32.
+    { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
+    { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
+    { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
+
+    // Extend from nxvmf16 to nxvmf64.
+    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
+    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
+    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
+
+    // Extend from nxvmf32 to nxvmf64.
+    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
+    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
+    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
+
   };
 
   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
@@ -452,9 +1179,10 @@
       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
```
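The scalable-vector rows above are reached through the regular cast-cost query. A hedged sketch — assumes an `LLVMContext &Ctx` and an AArch64 `TargetTransformInfo &TTI` are available:

```cpp
// Truncating <vscale x 4 x i32> to <vscale x 4 x i16>: the new table row
// (nxv4i16 <- nxv4i32) makes this cost 1.
auto *SrcTy = ScalableVectorType::get(Type::getInt32Ty(Ctx), 4);
auto *DstTy = ScalableVectorType::get(Type::getInt16Ty(Ctx), 4);
InstructionCost C = TTI.getCastInstrCost(
    Instruction::Trunc, DstTy, SrcTy,
    TargetTransformInfo::CastContextHint::None,
    TargetTransformInfo::TCK_RecipThroughput);
```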
```diff
-int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
-                                             VectorType *VecTy,
-                                             unsigned Index) {
+InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
+                                                         Type *Dst,
+                                                         VectorType *VecTy,
+                                                         unsigned Index) {
 
   // Make sure we were given a valid extend opcode.
   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
@@ -469,7 +1197,8 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
 
   // Get the cost for the extract. We compute the cost (if any) for the extend
   // below.
-  auto Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
+  InstructionCost Cost =
+      getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
 
   // Legalize the types.
   auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
@@ -511,8 +1240,9 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
                          CostKind);
 }
 
-unsigned AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
-                                        TTI::TargetCostKind CostKind) {
+InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
+                                               TTI::TargetCostKind CostKind,
+                                               const Instruction *I) {
   if (CostKind != TTI::TCK_RecipThroughput)
     return Opcode == Instruction::PHI ? 0 : 1;
   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
@@ -520,13 +1250,13 @@ unsigned AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
   return 0;
 }
 
-int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                       unsigned Index) {
+InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
+                                                   unsigned Index) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
   if (Index != -1U) {
     // Legalize the type.
-    std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
+    std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
 
     // This type is legalized to a scalar type.
     if (!LT.second.isVector())
@@ -545,10 +1275,10 @@ int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   return ST->getVectorInsertExtractBaseCost();
 }
 
-int AArch64TTIImpl::getArithmeticInstrCost(
+InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
-    TTI::OperandValueKind Opd1Info,
-    TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,
+    TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
+    TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
   // TODO: Handle more cost kinds.
@@ -558,7 +1288,7 @@ int AArch64TTIImpl::getArithmeticInstrCost(
                                          Opd2PropInfo, Args, CxtI);
 
   // Legalize the type.
-  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
 
   // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
   // add in the widening overhead specified by the sub-target. Since the
@@ -566,7 +1296,7 @@ int AArch64TTIImpl::getArithmeticInstrCost(
   // aren't present in the generated code and have a zero cost. By adding a
   // widening overhead here, we attach the total cost of the combined operation
   // to the widening instruction.
-  int Cost = 0;
+  InstructionCost Cost = 0;
   if (isWideningInstruction(Ty, Opcode, Args))
     Cost += ST->getWideningBaseCost();
@@ -610,18 +1340,15 @@ int AArch64TTIImpl::getArithmeticInstrCost(
       // Vector signed division by constant are expanded to the
       // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
       // to MULHS + SUB + SRL + ADD + SRL.
-      int MulCost = getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
-                                           Opd1Info, Opd2Info,
-                                           TargetTransformInfo::OP_None,
-                                           TargetTransformInfo::OP_None);
-      int AddCost = getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
-                                           Opd1Info, Opd2Info,
-                                           TargetTransformInfo::OP_None,
-                                           TargetTransformInfo::OP_None);
-      int ShrCost = getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
-                                           Opd1Info, Opd2Info,
-                                           TargetTransformInfo::OP_None,
-                                           TargetTransformInfo::OP_None);
+      InstructionCost MulCost = getArithmeticInstrCost(
+          Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
+          TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
+      InstructionCost AddCost = getArithmeticInstrCost(
+          Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
+          TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
+      InstructionCost ShrCost = getArithmeticInstrCost(
+          Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
+          TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
       return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
     }
 }
@@ -677,8 +1404,9 @@ int AArch64TTIImpl::getArithmeticInstrCost(
   }
 }
 
-int AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
-                                              const SCEV *Ptr) {
+InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
+                                                          ScalarEvolution *SE,
+                                                          const SCEV *Ptr) {
   // Address computations in vectorized code with non-consecutive addresses will
   // likely result in more instructions compared to scalar code where the
   // computation can more often be merged into the index mode. The resulting
@@ -695,10 +1423,11 @@ int AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
   return 1;
 }
 
-int AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
-                                       Type *CondTy, CmpInst::Predicate VecPred,
-                                       TTI::TargetCostKind CostKind,
-                                       const Instruction *I) {
+InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
+                                                   Type *CondTy,
+                                                   CmpInst::Predicate VecPred,
+                                                   TTI::TargetCostKind CostKind,
+                                                   const Instruction *I) {
   // TODO: Handle other cost kinds.
   if (CostKind != TTI::TCK_RecipThroughput)
     return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
@@ -772,7 +1501,28 @@ AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
   return Options;
 }
 
-unsigned AArch64TTIImpl::getGatherScatterOpCost(
+InstructionCost
+AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
+                                      Align Alignment, unsigned AddressSpace,
+                                      TTI::TargetCostKind CostKind) {
+  if (!isa<ScalableVectorType>(Src))
+    return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
+                                        CostKind);
+  auto LT = TLI->getTypeLegalizationCost(DL, Src);
+  if (!LT.first.isValid())
+    return InstructionCost::getInvalid();
+
+  // The code-generator is currently not able to handle scalable vectors
+  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
+  // it. This change will be removed when code-generation for these types is
+  // sufficiently reliable.
+  if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
+    return InstructionCost::getInvalid();
+
+  return LT.first * 2;
+}
+
+InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
@@ -781,35 +1531,56 @@ unsigned AArch64TTIImpl::getGatherScatterOpCost(
                                          Alignment, CostKind, I);
   auto *VT = cast<VectorType>(DataTy);
   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
-  ElementCount LegalVF = LT.second.getVectorElementCount();
-  Optional<unsigned> MaxNumVScale = getMaxVScale();
-  assert(MaxNumVScale && "Expected valid max vscale value");
+  if (!LT.first.isValid())
+    return InstructionCost::getInvalid();
 
-  unsigned MemOpCost =
+  // The code-generator is currently not able to handle scalable vectors
+  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
+  // it. This change will be removed when code-generation for these types is
+  // sufficiently reliable.
+  if (cast<VectorType>(DataTy)->getElementCount() ==
+      ElementCount::getScalable(1))
+    return InstructionCost::getInvalid();
+
+  ElementCount LegalVF = LT.second.getVectorElementCount();
+  InstructionCost MemOpCost =
       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
-  unsigned MaxNumElementsPerGather =
-      MaxNumVScale.getValue() * LegalVF.getKnownMinValue();
-  return LT.first * MaxNumElementsPerGather * MemOpCost;
+  return LT.first * MemOpCost * getMaxNumElements(LegalVF);
 }
 
 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
 }
```
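Note the new `InstructionCost::getInvalid()` escape hatches: `<vscale x 1 x ...>` types are deliberately priced as unusable until codegen supports them. A hedged sketch of the observable behaviour (setup assumed as in the earlier sketches):

```cpp
// A masked store of <vscale x 1 x i32> now reports an invalid cost, which
// callers must test before comparing against alternatives.
auto *VTy = ScalableVectorType::get(Type::getInt32Ty(Ctx), 1);
InstructionCost C = TTI.getMaskedMemoryOpCost(
    Instruction::Store, VTy, Align(4), /*AddressSpace=*/0,
    TargetTransformInfo::TCK_RecipThroughput);
assert(!C.isValid());
```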
```diff
-int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
-                                    MaybeAlign Alignment, unsigned AddressSpace,
-                                    TTI::TargetCostKind CostKind,
-                                    const Instruction *I) {
-  // TODO: Handle other cost kinds.
-  if (CostKind != TTI::TCK_RecipThroughput)
-    return 1;
-
+InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
+                                                MaybeAlign Alignment,
+                                                unsigned AddressSpace,
+                                                TTI::TargetCostKind CostKind,
+                                                const Instruction *I) {
+  EVT VT = TLI->getValueType(DL, Ty, true);
   // Type legalization can't handle structs
-  if (TLI->getValueType(DL, Ty, true) == MVT::Other)
+  if (VT == MVT::Other)
     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                   CostKind);
 
   auto LT = TLI->getTypeLegalizationCost(DL, Ty);
+  if (!LT.first.isValid())
+    return InstructionCost::getInvalid();
+
+  // The code-generator is currently not able to handle scalable vectors
+  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
+  // it. This change will be removed when code-generation for these types is
+  // sufficiently reliable.
+  if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
+    if (VTy->getElementCount() == ElementCount::getScalable(1))
+      return InstructionCost::getInvalid();
+
+  // TODO: consider latency as well for TCK_SizeAndLatency.
+  if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
+    return LT.first;
+
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return 1;
 
   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
       LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
@@ -823,29 +1594,20 @@ int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
     return LT.first * 2 * AmortizationCost;
   }
 
+  // Check truncating stores and extending loads.
   if (useNeonVector(Ty) &&
-      cast<VectorType>(Ty)->getElementType()->isIntegerTy(8)) {
-    unsigned ProfitableNumElements;
-    if (Opcode == Instruction::Store)
-      // We use a custom trunc store lowering so v.4b should be profitable.
-      ProfitableNumElements = 4;
-    else
-      // We scalarize the loads because there is not v.4b register and we
-      // have to promote the elements to v.2.
-      ProfitableNumElements = 8;
-
-    if (cast<FixedVectorType>(Ty)->getNumElements() < ProfitableNumElements) {
-      unsigned NumVecElts = cast<FixedVectorType>(Ty)->getNumElements();
-      unsigned NumVectorizableInstsToAmortize = NumVecElts * 2;
-      // We generate 2 instructions per vector element.
-      return NumVectorizableInstsToAmortize * NumVecElts * 2;
-    }
+      Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
+    // v4i8 types are lowered to scalar a load/store and sshll/xtn.
+    if (VT == MVT::v4i8)
+      return 2;
+    // Otherwise we need to scalarize.
+    return cast<FixedVectorType>(Ty)->getNumElements() * 2;
   }
 
   return LT.first;
 }
 
-int AArch64TTIImpl::getInterleavedMemoryOpCost(
+InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
     bool UseMaskForCond, bool UseMaskForGaps) {
@@ -871,8 +1633,9 @@ int AArch64TTIImpl::getInterleavedMemoryOpCost(
                                            UseMaskForCond, UseMaskForGaps);
 }
 
-int AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
-  int Cost = 0;
+InstructionCost
+AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
+  InstructionCost Cost = 0;
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   for (auto *I : Tys) {
     if (!I->isVectorTy())
@@ -958,6 +1721,41 @@ void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
       EnableFalkorHWPFUnrollFix)
     getFalkorUnrollingPreferences(L, SE, UP);
+
+  // Scan the loop: don't unroll loops with calls as this could prevent
+  // inlining. Don't unroll vector loops either, as they don't benefit much
+  // from unrolling.
+  for (auto *BB : L->getBlocks()) {
+    for (auto &I : *BB) {
+      // Don't unroll vectorised loop.
+      if (I.getType()->isVectorTy())
+        return;
+
+      if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
+        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
+          if (!isLoweredToCall(F))
+            continue;
+        }
+        return;
+      }
+    }
+  }
+
+  // Enable runtime unrolling for in-order models
+  // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
+  // checking for that case, we can ensure that the default behaviour is
+  // unchanged
+  if (ST->getProcFamily() != AArch64Subtarget::Others &&
+      !ST->getSchedModel().isOutOfOrder()) {
+    UP.Runtime = true;
+    UP.Partial = true;
+    UP.UpperBound = true;
+    UP.UnrollRemainder = true;
+    UP.DefaultUnrollRuntimeCount = 4;
+
+    UP.UnrollAndJam = true;
+    UP.UnrollAndJamInnerLoopThreshold = 60;
+  }
 }
 
 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
```
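The unrolling change targets small scalar loops: anything containing a real call or a vector-typed instruction is skipped, and on in-order cores runtime/partial unrolling (count 4) plus unroll-and-jam is enabled. An illustrative candidate loop, not taken from the patch:

```cpp
// No calls, no vector types: on an in-order core (e.g. built with
// -mcpu=cortex-a55) this loop is now eligible for runtime unrolling by 4.
void axpy(float *x, const float *y, float a, int n) {
  for (int i = 0; i < n; ++i)
    x[i] += a * y[i];
}
```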
```diff
@@ -1073,42 +1871,44 @@ bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
   return Considerable;
 }
 
-bool AArch64TTIImpl::useReductionIntrinsic(unsigned Opcode, Type *Ty,
-                                           TTI::ReductionFlags Flags) const {
-  auto *VTy = cast<VectorType>(Ty);
-  unsigned ScalarBits = Ty->getScalarSizeInBits();
-  switch (Opcode) {
-  case Instruction::FAdd:
-  case Instruction::FMul:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor:
-  case Instruction::Mul:
+bool AArch64TTIImpl::isLegalToVectorizeReduction(
+    const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
+  if (!VF.isScalable())
+    return true;
+
+  Type *Ty = RdxDesc.getRecurrenceType();
+  if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
     return false;
-  case Instruction::Add:
-    return ScalarBits * cast<FixedVectorType>(VTy)->getNumElements() >= 128;
-  case Instruction::ICmp:
-    return (ScalarBits < 64) &&
-           (ScalarBits * cast<FixedVectorType>(VTy)->getNumElements() >= 128);
-  case Instruction::FCmp:
-    return Flags.NoNaN;
+
+  switch (RdxDesc.getRecurrenceKind()) {
+  case RecurKind::Add:
+  case RecurKind::FAdd:
+  case RecurKind::And:
+  case RecurKind::Or:
+  case RecurKind::Xor:
+  case RecurKind::SMin:
+  case RecurKind::SMax:
+  case RecurKind::UMin:
+  case RecurKind::UMax:
+  case RecurKind::FMin:
+  case RecurKind::FMax:
+    return true;
   default:
-    llvm_unreachable("Unhandled reduction opcode");
+    return false;
   }
-  return false;
 }
 
-int AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                           bool IsPairwise, bool IsUnsigned,
-                                           TTI::TargetCostKind CostKind) {
+InstructionCost
+AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
+                                       bool IsUnsigned,
+                                       TTI::TargetCostKind CostKind) {
   if (!isa<ScalableVectorType>(Ty))
-    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsPairwise, IsUnsigned,
-                                         CostKind);
+    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
   assert((isa<ScalableVectorType>(Ty) && isa<ScalableVectorType>(CondTy)) &&
          "Both vector needs to be scalable");
 
-  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
-  int LegalizationCost = 0;
+  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
+  InstructionCost LegalizationCost = 0;
   if (LT.first > 1) {
     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
     unsigned CmpOpcode =
@@ -1124,13 +1924,10 @@ int AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
 }
 
-int AArch64TTIImpl::getArithmeticReductionCostSVE(
-    unsigned Opcode, VectorType *ValTy, bool IsPairwise,
-    TTI::TargetCostKind CostKind) {
-  assert(!IsPairwise && "Cannot be pair wise to continue");
-
-  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
-  int LegalizationCost = 0;
+InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
+    unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
+  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
+  InstructionCost LegalizationCost = 0;
   if (LT.first > 1) {
     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
@@ -1148,51 +1945,162 @@ int AArch64TTIImpl::getArithmeticReductionCostSVE(
   case ISD::FADD:
     return LegalizationCost + 2;
   default:
-    // TODO: Replace for invalid when InstructionCost is used
-    // cases not supported by SVE
-    return 16;
+    return InstructionCost::getInvalid();
   }
 }
 
-int AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode,
-                                               VectorType *ValTy,
-                                               bool IsPairwiseForm,
-                                               TTI::TargetCostKind CostKind) {
+InstructionCost
+AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
+                                           Optional<FastMathFlags> FMF,
+                                           TTI::TargetCostKind CostKind) {
+  if (TTI::requiresOrderedReduction(FMF)) {
+    if (!isa<ScalableVectorType>(ValTy))
+      return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
+
+    if (Opcode != Instruction::FAdd)
+      return InstructionCost::getInvalid();
+
+    auto *VTy = cast<ScalableVectorType>(ValTy);
+    InstructionCost Cost =
+        getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
+    Cost *= getMaxNumElements(VTy->getElementCount());
+    return Cost;
+  }
 
   if (isa<ScalableVectorType>(ValTy))
-    return getArithmeticReductionCostSVE(Opcode, ValTy, IsPairwiseForm,
-                                         CostKind);
-  if (IsPairwiseForm)
-    return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
-                                             CostKind);
+    return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
 
-  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
+  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
   MVT MTy = LT.second;
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
   // Horizontal adds can use the 'addv' instruction. We model the cost of these
-  // instructions as normal vector adds. This is the only arithmetic vector
-  // reduction operation for which we have an instruction.
+  // instructions as twice a normal vector add, plus 1 for each legalization
+  // step (LT.first). This is the only arithmetic vector reduction operation for
+  // which we have an instruction.
+  // OR, XOR and AND costs should match the codegen from:
+  // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
+  // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
+  // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
   static const CostTblEntry CostTblNoPairwise[]{
-      {ISD::ADD, MVT::v8i8,  1},
-      {ISD::ADD, MVT::v16i8, 1},
-      {ISD::ADD, MVT::v4i16, 1},
-      {ISD::ADD, MVT::v8i16, 1},
-      {ISD::ADD, MVT::v4i32, 1},
+      {ISD::ADD, MVT::v8i8,   2},
+      {ISD::ADD, MVT::v16i8,  2},
+      {ISD::ADD, MVT::v4i16,  2},
+      {ISD::ADD, MVT::v8i16,  2},
+      {ISD::ADD, MVT::v4i32,  2},
+      {ISD::OR,  MVT::v8i8,  15},
+      {ISD::OR,  MVT::v16i8, 17},
+      {ISD::OR,  MVT::v4i16,  7},
+      {ISD::OR,  MVT::v8i16,  9},
+      {ISD::OR,  MVT::v2i32,  3},
+      {ISD::OR,  MVT::v4i32,  5},
+      {ISD::OR,  MVT::v2i64,  3},
+      {ISD::XOR, MVT::v8i8,  15},
+      {ISD::XOR, MVT::v16i8, 17},
+      {ISD::XOR, MVT::v4i16,  7},
+      {ISD::XOR, MVT::v8i16,  9},
+      {ISD::XOR, MVT::v2i32,  3},
+      {ISD::XOR, MVT::v4i32,  5},
+      {ISD::XOR, MVT::v2i64,  3},
+      {ISD::AND, MVT::v8i8,  15},
+      {ISD::AND, MVT::v16i8, 17},
+      {ISD::AND, MVT::v4i16,  7},
+      {ISD::AND, MVT::v8i16,  9},
+      {ISD::AND, MVT::v2i32,  3},
+      {ISD::AND, MVT::v4i32,  5},
+      {ISD::AND, MVT::v2i64,  3},
   };
-  if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
-    return LT.first * Entry->Cost;
+  switch (ISD) {
+  default:
+    break;
+  case ISD::ADD:
+    if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
+      return (LT.first - 1) + Entry->Cost;
+    break;
+  case ISD::XOR:
+  case ISD::AND:
+  case ISD::OR:
+    const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
+    if (!Entry)
+      break;
+    auto *ValVTy = cast<FixedVectorType>(ValTy);
+    if (!ValVTy->getElementType()->isIntegerTy(1) &&
+        MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
+        isPowerOf2_32(ValVTy->getNumElements())) {
+      InstructionCost ExtraCost = 0;
+      if (LT.first != 1) {
+        // Type needs to be split, so there is an extra cost of LT.first - 1
+        // arithmetic ops.
+        auto *Ty = FixedVectorType::get(ValTy->getElementType(),
                                        MTy.getVectorNumElements());
+        ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
+        ExtraCost *= LT.first - 1;
+      }
+      return Entry->Cost + ExtraCost;
+    }
+    break;
+  }
-  return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
-                                           CostKind);
+  return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
 }
```
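Strict (in-order) floating-point reductions used to be rejected outright; they are now priced as one scalar `fadd` per worst-case element. A hedged query sketch, setup assumed as before:

```cpp
// Ordered @llvm.vector.reduce.fadd over <vscale x 4 x float>: with no
// reassociation allowed (empty FastMathFlags), the cost scales with the
// worst-case element count (vscale_max * 4 scalar fadds).
auto *VTy = ScalableVectorType::get(Type::getFloatTy(Ctx), 4);
InstructionCost C = TTI.getArithmeticReductionCost(
    Instruction::FAdd, VTy, FastMathFlags(),
    TargetTransformInfo::TCK_RecipThroughput);
```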
```diff
+InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
+  static const CostTblEntry ShuffleTbl[] = {
+      { TTI::SK_Splice, MVT::nxv16i8,  1 },
+      { TTI::SK_Splice, MVT::nxv8i16,  1 },
+      { TTI::SK_Splice, MVT::nxv4i32,  1 },
+      { TTI::SK_Splice, MVT::nxv2i64,  1 },
+      { TTI::SK_Splice, MVT::nxv2f16,  1 },
+      { TTI::SK_Splice, MVT::nxv4f16,  1 },
+      { TTI::SK_Splice, MVT::nxv8f16,  1 },
+      { TTI::SK_Splice, MVT::nxv2bf16, 1 },
+      { TTI::SK_Splice, MVT::nxv4bf16, 1 },
+      { TTI::SK_Splice, MVT::nxv8bf16, 1 },
+      { TTI::SK_Splice, MVT::nxv2f32,  1 },
+      { TTI::SK_Splice, MVT::nxv4f32,  1 },
+      { TTI::SK_Splice, MVT::nxv2f64,  1 },
+  };
+
+  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
+  Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  EVT PromotedVT = LT.second.getScalarType() == MVT::i1
+                       ? TLI->getPromotedVTForPredicate(EVT(LT.second))
+                       : LT.second;
+  Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
+  InstructionCost LegalizationCost = 0;
+  if (Index < 0) {
+    LegalizationCost =
+        getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
+                           CmpInst::BAD_ICMP_PREDICATE, CostKind) +
+        getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
+                           CmpInst::BAD_ICMP_PREDICATE, CostKind);
+  }
+
+  // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp
+  // Cost performed on a promoted type.
+  if (LT.second.getScalarType() == MVT::i1) {
+    LegalizationCost +=
+        getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
+                         TTI::CastContextHint::None, CostKind) +
+        getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
+                         TTI::CastContextHint::None, CostKind);
+  }
+  const auto *Entry =
+      CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
+  assert(Entry && "Illegal Type for Splice");
+  LegalizationCost += Entry->Cost;
+  return LegalizationCost * LT.first;
+}
+
-int AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
-                                   int Index, VectorType *SubTp) {
+InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
+                                               VectorType *Tp,
+                                               ArrayRef<int> Mask, int Index,
+                                               VectorType *SubTp) {
+  Kind = improveShuffleKindFromMask(Kind, Mask);
   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
-      Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc) {
+      Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
+      Kind == TTI::SK_Reverse) {
     static const CostTblEntry ShuffleTbl[] = {
       // Broadcast shuffle kinds can be performed with 'dup'.
       { TTI::SK_Broadcast, MVT::v8i8,  1 },
@@ -1226,18 +2134,69 @@ int AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
       { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
       { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
       // PermuteSingleSrc shuffle kinds.
-      // TODO: handle vXi8/vXi16.
       { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
       { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
       { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
       { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
       { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
       { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
+      { TTI::SK_PermuteSingleSrc, MVT::v4i16, 3 }, // perfectshuffle worst case.
+      { TTI::SK_PermuteSingleSrc, MVT::v4f16, 3 }, // perfectshuffle worst case.
+      { TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3 }, // perfectshuffle worst case.
+      { TTI::SK_PermuteSingleSrc, MVT::v8i16, 8 }, // constpool + load + tbl
+      { TTI::SK_PermuteSingleSrc, MVT::v8f16, 8 }, // constpool + load + tbl
+      { TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8 }, // constpool + load + tbl
+      { TTI::SK_PermuteSingleSrc, MVT::v8i8, 8 }, // constpool + load + tbl
+      { TTI::SK_PermuteSingleSrc, MVT::v16i8, 8 }, // constpool + load + tbl
+      // Reverse can be lowered with `rev`.
+      { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
+      { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
+      { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
+      { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
+      { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
+      { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
+      // Broadcast shuffle kinds for scalable vectors
+      { TTI::SK_Broadcast, MVT::nxv16i8,  1 },
+      { TTI::SK_Broadcast, MVT::nxv8i16,  1 },
+      { TTI::SK_Broadcast, MVT::nxv4i32,  1 },
+      { TTI::SK_Broadcast, MVT::nxv2i64,  1 },
+      { TTI::SK_Broadcast, MVT::nxv2f16,  1 },
+      { TTI::SK_Broadcast, MVT::nxv4f16,  1 },
+      { TTI::SK_Broadcast, MVT::nxv8f16,  1 },
+      { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
+      { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
+      { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
+      { TTI::SK_Broadcast, MVT::nxv2f32,  1 },
+      { TTI::SK_Broadcast, MVT::nxv4f32,  1 },
+      { TTI::SK_Broadcast, MVT::nxv2f64,  1 },
+      { TTI::SK_Broadcast, MVT::nxv16i1,  1 },
+      { TTI::SK_Broadcast, MVT::nxv8i1,   1 },
+      { TTI::SK_Broadcast, MVT::nxv4i1,   1 },
+      { TTI::SK_Broadcast, MVT::nxv2i1,   1 },
+      // Handle the cases for vector.reverse with scalable vectors
+      { TTI::SK_Reverse, MVT::nxv16i8,  1 },
+      { TTI::SK_Reverse, MVT::nxv8i16,  1 },
+      { TTI::SK_Reverse, MVT::nxv4i32,  1 },
+      { TTI::SK_Reverse, MVT::nxv2i64,  1 },
+      { TTI::SK_Reverse, MVT::nxv2f16,  1 },
+      { TTI::SK_Reverse, MVT::nxv4f16,  1 },
+      { TTI::SK_Reverse, MVT::nxv8f16,  1 },
+      { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
+      { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
+      { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
+      { TTI::SK_Reverse, MVT::nxv2f32,  1 },
+      { TTI::SK_Reverse, MVT::nxv4f32,  1 },
+      { TTI::SK_Reverse, MVT::nxv2f64,  1 },
+      { TTI::SK_Reverse, MVT::nxv16i1,  1 },
+      { TTI::SK_Reverse, MVT::nxv8i1,   1 },
+      { TTI::SK_Reverse, MVT::nxv4i1,   1 },
+      { TTI::SK_Reverse, MVT::nxv2i1,   1 },
     };
-    std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
+    std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
       return LT.first * Entry->Cost;
   }
-
-  return BaseT::getShuffleCost(Kind, Tp, Index, SubTp);
+  if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
+    return getSpliceCost(Tp, Index);
+  return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
 }
```
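Finally, a hedged sketch of the new shuffle-cost paths (setup assumed as in the earlier sketches): scalable broadcasts and reverses now hit the table directly, and `SK_Splice` on scalable types is routed through `getSpliceCost`.

```cpp
// vector.reverse on <vscale x 4 x i32> costs 1 per the table above; a
// splice at offset 1 goes through getSpliceCost.
auto *VTy = ScalableVectorType::get(Type::getInt32Ty(Ctx), 4);
InstructionCost RevCost = TTI.getShuffleCost(
    TargetTransformInfo::SK_Reverse, VTy, /*Mask=*/None, /*Index=*/0,
    /*SubTp=*/nullptr);
InstructionCost SpliceCost = TTI.getShuffleCost(
    TargetTransformInfo::SK_Splice, VTy, /*Mask=*/None, /*Index=*/1,
    /*SubTp=*/nullptr);
```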
