diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h | 106 |
1 files changed, 90 insertions, 16 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index e79c4f75712b..7caf0fedb2ca 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -37,6 +37,7 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> { const RISCVSubtarget *getST() const { return ST; } const RISCVTargetLowering *getTLI() const { return TLI; } + unsigned getMaxVLFor(VectorType *Ty); public: explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F) : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)), @@ -57,10 +58,15 @@ public: bool shouldExpandReduction(const IntrinsicInst *II) const; bool supportsScalableVectors() const { return ST->hasVInstructions(); } Optional<unsigned> getMaxVScale() const; + Optional<unsigned> getVScaleForTuning() const; TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const; - InstructionCost getRegUsageForType(Type *Ty); + unsigned getRegUsageForType(Type *Ty); + + InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src, + Align Alignment, unsigned AddressSpace, + TTI::TargetCostKind CostKind); void getUnrollingPreferences(Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP, @@ -73,24 +79,50 @@ public: return ST->useRVVForFixedLengthVectors() ? 16 : 0; } + InstructionCost getSpliceCost(VectorType *Tp, int Index); + InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, + ArrayRef<int> Mask, int Index, + VectorType *SubTp, + ArrayRef<const Value *> Args = None); + + InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, + TTI::TargetCostKind CostKind); + InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I); + InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::CastContextHint CCH, + TTI::TargetCostKind CostKind, + const Instruction *I = nullptr); + + InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy, + bool IsUnsigned, + TTI::TargetCostKind CostKind); + + InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, + Optional<FastMathFlags> FMF, + TTI::TargetCostKind CostKind); + + bool isElementTypeLegalForScalableVector(Type *Ty) const { + return TLI->isLegalElementTypeForRVV(Ty); + } + bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) { if (!ST->hasVInstructions()) return false; // Only support fixed vectors if we know the minimum vector size. - if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0) + if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors()) return false; // Don't allow elements larger than the ELEN. // FIXME: How to limit for scalable vectors? if (isa<FixedVectorType>(DataType) && - DataType->getScalarSizeInBits() > ST->getMaxELENForFixedLengthVectors()) + DataType->getScalarSizeInBits() > ST->getELEN()) return false; if (Alignment < @@ -112,13 +144,13 @@ public: return false; // Only support fixed vectors if we know the minimum vector size. - if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0) + if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors()) return false; // Don't allow elements larger than the ELEN. // FIXME: How to limit for scalable vectors? if (isa<FixedVectorType>(DataType) && - DataType->getScalarSizeInBits() > ST->getMaxELENForFixedLengthVectors()) + DataType->getScalarSizeInBits() > ST->getELEN()) return false; if (Alignment < @@ -135,6 +167,16 @@ public: return isLegalMaskedGatherScatter(DataType, Alignment); } + bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment) { + // Scalarize masked gather for RV64 if EEW=64 indices aren't supported. + return ST->is64Bit() && !ST->hasVInstructionsI64(); + } + + bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) { + // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported. + return ST->is64Bit() && !ST->hasVInstructionsI64(); + } + /// \returns How the target needs this vector-predicated operation to be /// transformed. TargetTransformInfo::VPLegalization @@ -145,9 +187,6 @@ public: bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc, ElementCount VF) const { - if (!ST->hasVInstructions()) - return false; - if (!VF.isScalable()) return true; @@ -179,18 +218,53 @@ public: return VF == 1 ? 1 : ST->getMaxInterleaveFactor(); } - // TODO: We should define RISC-V's own register classes. - // e.g. register class for FPR. + enum RISCVRegisterClass { GPRRC, FPRRC, VRRC }; unsigned getNumberOfRegisters(unsigned ClassID) const { - bool Vector = (ClassID == 1); - if (Vector) { - if (ST->hasVInstructions()) + switch (ClassID) { + case RISCVRegisterClass::GPRRC: + // 31 = 32 GPR - x0 (zero register) + // FIXME: Should we exclude fixed registers like SP, TP or GP? + return 31; + case RISCVRegisterClass::FPRRC: + if (ST->hasStdExtF()) return 32; return 0; + case RISCVRegisterClass::VRRC: + // Although there are 32 vector registers, v0 is special in that it is the + // only register that can be used to hold a mask. + // FIXME: Should we conservatively return 31 as the number of usable + // vector registers? + return ST->hasVInstructions() ? 32 : 0; + } + llvm_unreachable("unknown register class"); + } + + unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const { + if (Vector) + return RISCVRegisterClass::VRRC; + if (!Ty) + return RISCVRegisterClass::GPRRC; + + Type *ScalarTy = Ty->getScalarType(); + if ((ScalarTy->isHalfTy() && ST->hasStdExtZfh()) || + (ScalarTy->isFloatTy() && ST->hasStdExtF()) || + (ScalarTy->isDoubleTy() && ST->hasStdExtD())) { + return RISCVRegisterClass::FPRRC; + } + + return RISCVRegisterClass::GPRRC; + } + + const char *getRegisterClassName(unsigned ClassID) const { + switch (ClassID) { + case RISCVRegisterClass::GPRRC: + return "RISCV::GPRRC"; + case RISCVRegisterClass::FPRRC: + return "RISCV::FPRRC"; + case RISCVRegisterClass::VRRC: + return "RISCV::VRRC"; } - // 31 = 32 GPR - x0 (zero register) - // FIXME: Should we exclude fixed registers like SP, TP or GP? - return 31; + llvm_unreachable("unknown register class"); } }; |
