diff options
Diffstat (limited to 'llvm/lib/IR/Type.cpp')
-rw-r--r-- | llvm/lib/IR/Type.cpp | 176 |
1 files changed, 99 insertions, 77 deletions
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp index 3eab5042b542..d869a6e07cca 100644 --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -40,6 +40,7 @@ Type *Type::getPrimitiveType(LLVMContext &C, TypeID IDNumber) { switch (IDNumber) { case VoidTyID : return getVoidTy(C); case HalfTyID : return getHalfTy(C); + case BFloatTyID : return getBFloatTy(C); case FloatTyID : return getFloatTy(C); case DoubleTyID : return getDoubleTy(C); case X86_FP80TyID : return getX86_FP80Ty(C); @@ -68,20 +69,17 @@ bool Type::canLosslesslyBitCastTo(Type *Ty) const { return false; // Vector -> Vector conversions are always lossless if the two vector types - // have the same size, otherwise not. Also, 64-bit vector types can be - // converted to x86mmx. - if (auto *thisPTy = dyn_cast<VectorType>(this)) { - if (auto *thatPTy = dyn_cast<VectorType>(Ty)) - return thisPTy->getBitWidth() == thatPTy->getBitWidth(); - if (Ty->getTypeID() == Type::X86_MMXTyID && - thisPTy->getBitWidth() == 64) - return true; - } + // have the same size, otherwise not. + if (isa<VectorType>(this) && isa<VectorType>(Ty)) + return getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits(); - if (this->getTypeID() == Type::X86_MMXTyID) - if (auto *thatPTy = dyn_cast<VectorType>(Ty)) - if (thatPTy->getBitWidth() == 64) - return true; + // 64-bit fixed width vector types can be losslessly converted to x86mmx. + if (((isa<FixedVectorType>(this)) && Ty->isX86_MMXTy()) && + getPrimitiveSizeInBits().getFixedSize() == 64) + return true; + if ((isX86_MMXTy() && isa<FixedVectorType>(Ty)) && + Ty->getPrimitiveSizeInBits().getFixedSize() == 64) + return true; // At this point we have only various mismatches of the first class types // remaining and ptr->ptr. Just select the lossless conversions. Everything @@ -115,6 +113,7 @@ bool Type::isEmptyTy() const { TypeSize Type::getPrimitiveSizeInBits() const { switch (getTypeID()) { case Type::HalfTyID: return TypeSize::Fixed(16); + case Type::BFloatTyID: return TypeSize::Fixed(16); case Type::FloatTyID: return TypeSize::Fixed(32); case Type::DoubleTyID: return TypeSize::Fixed(64); case Type::X86_FP80TyID: return TypeSize::Fixed(80); @@ -123,16 +122,21 @@ TypeSize Type::getPrimitiveSizeInBits() const { case Type::X86_MMXTyID: return TypeSize::Fixed(64); case Type::IntegerTyID: return TypeSize::Fixed(cast<IntegerType>(this)->getBitWidth()); - case Type::VectorTyID: { + case Type::FixedVectorTyID: + case Type::ScalableVectorTyID: { const VectorType *VTy = cast<VectorType>(this); - return TypeSize(VTy->getBitWidth(), VTy->isScalable()); + ElementCount EC = VTy->getElementCount(); + TypeSize ETS = VTy->getElementType()->getPrimitiveSizeInBits(); + assert(!ETS.isScalable() && "Vector type should have fixed-width elements"); + return {ETS.getFixedSize() * EC.Min, EC.Scalable}; } default: return TypeSize::Fixed(0); } } unsigned Type::getScalarSizeInBits() const { - return getScalarType()->getPrimitiveSizeInBits(); + // It is safe to assume that the scalar types have a fixed size. + return getScalarType()->getPrimitiveSizeInBits().getFixedSize(); } int Type::getFPMantissaWidth() const { @@ -140,6 +144,7 @@ int Type::getFPMantissaWidth() const { return VTy->getElementType()->getFPMantissaWidth(); assert(isFloatingPointTy() && "Not a floating point type!"); if (getTypeID() == HalfTyID) return 11; + if (getTypeID() == BFloatTyID) return 8; if (getTypeID() == FloatTyID) return 24; if (getTypeID() == DoubleTyID) return 53; if (getTypeID() == X86_FP80TyID) return 64; @@ -165,6 +170,7 @@ bool Type::isSizedDerivedType(SmallPtrSetImpl<Type*> *Visited) const { Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; } Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; } Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; } +Type *Type::getBFloatTy(LLVMContext &C) { return &C.pImpl->BFloatTy; } Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; } Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; } Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; } @@ -189,6 +195,10 @@ PointerType *Type::getHalfPtrTy(LLVMContext &C, unsigned AS) { return getHalfTy(C)->getPointerTo(AS); } +PointerType *Type::getBFloatPtrTy(LLVMContext &C, unsigned AS) { + return getBFloatTy(C)->getPointerTo(AS); +} + PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) { return getFloatTy(C)->getPointerTo(AS); } @@ -509,11 +519,9 @@ StringRef StructType::getName() const { } bool StructType::isValidElementType(Type *ElemTy) { - if (auto *VTy = dyn_cast<VectorType>(ElemTy)) - return !VTy->isScalable(); return !ElemTy->isVoidTy() && !ElemTy->isLabelTy() && !ElemTy->isMetadataTy() && !ElemTy->isFunctionTy() && - !ElemTy->isTokenTy(); + !ElemTy->isTokenTy() && !isa<ScalableVectorType>(ElemTy); } bool StructType::isLayoutIdentical(StructType *Other) const { @@ -529,52 +537,24 @@ StructType *Module::getTypeByName(StringRef Name) const { return getContext().pImpl->NamedStructTypes.lookup(Name); } -//===----------------------------------------------------------------------===// -// CompositeType Implementation -//===----------------------------------------------------------------------===// - -Type *CompositeType::getTypeAtIndex(const Value *V) const { - if (auto *STy = dyn_cast<StructType>(this)) { - unsigned Idx = - (unsigned)cast<Constant>(V)->getUniqueInteger().getZExtValue(); - assert(indexValid(Idx) && "Invalid structure index!"); - return STy->getElementType(Idx); - } - - return cast<SequentialType>(this)->getElementType(); +Type *StructType::getTypeAtIndex(const Value *V) const { + unsigned Idx = (unsigned)cast<Constant>(V)->getUniqueInteger().getZExtValue(); + assert(indexValid(Idx) && "Invalid structure index!"); + return getElementType(Idx); } -Type *CompositeType::getTypeAtIndex(unsigned Idx) const{ - if (auto *STy = dyn_cast<StructType>(this)) { - assert(indexValid(Idx) && "Invalid structure index!"); - return STy->getElementType(Idx); - } - - return cast<SequentialType>(this)->getElementType(); -} - -bool CompositeType::indexValid(const Value *V) const { - if (auto *STy = dyn_cast<StructType>(this)) { - // Structure indexes require (vectors of) 32-bit integer constants. In the - // vector case all of the indices must be equal. - if (!V->getType()->isIntOrIntVectorTy(32)) - return false; - const Constant *C = dyn_cast<Constant>(V); - if (C && V->getType()->isVectorTy()) - C = C->getSplatValue(); - const ConstantInt *CU = dyn_cast_or_null<ConstantInt>(C); - return CU && CU->getZExtValue() < STy->getNumElements(); - } - - // Sequential types can be indexed by any integer. - return V->getType()->isIntOrIntVectorTy(); -} - -bool CompositeType::indexValid(unsigned Idx) const { - if (auto *STy = dyn_cast<StructType>(this)) - return Idx < STy->getNumElements(); - // Sequential types can be indexed by any integer. - return true; +bool StructType::indexValid(const Value *V) const { + // Structure indexes require (vectors of) 32-bit integer constants. In the + // vector case all of the indices must be equal. + if (!V->getType()->isIntOrIntVectorTy(32)) + return false; + if (isa<ScalableVectorType>(V->getType())) + return false; + const Constant *C = dyn_cast<Constant>(V); + if (C && V->getType()->isVectorTy()) + C = C->getSplatValue(); + const ConstantInt *CU = dyn_cast_or_null<ConstantInt>(C); + return CU && CU->getZExtValue() < getNumElements(); } //===----------------------------------------------------------------------===// @@ -582,7 +562,11 @@ bool CompositeType::indexValid(unsigned Idx) const { //===----------------------------------------------------------------------===// ArrayType::ArrayType(Type *ElType, uint64_t NumEl) - : SequentialType(ArrayTyID, ElType, NumEl) {} + : Type(ElType->getContext(), ArrayTyID), ContainedType(ElType), + NumElements(NumEl) { + ContainedTys = &ContainedType; + NumContainedTys = 1; +} ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) { assert(isValidElementType(ElementType) && "Invalid type for array element!"); @@ -597,37 +581,75 @@ ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) { } bool ArrayType::isValidElementType(Type *ElemTy) { - if (auto *VTy = dyn_cast<VectorType>(ElemTy)) - return !VTy->isScalable(); return !ElemTy->isVoidTy() && !ElemTy->isLabelTy() && !ElemTy->isMetadataTy() && !ElemTy->isFunctionTy() && - !ElemTy->isTokenTy(); + !ElemTy->isTokenTy() && !isa<ScalableVectorType>(ElemTy); } //===----------------------------------------------------------------------===// // VectorType Implementation //===----------------------------------------------------------------------===// -VectorType::VectorType(Type *ElType, ElementCount EC) - : SequentialType(VectorTyID, ElType, EC.Min), Scalable(EC.Scalable) {} +VectorType::VectorType(Type *ElType, unsigned EQ, Type::TypeID TID) + : Type(ElType->getContext(), TID), ContainedType(ElType), + ElementQuantity(EQ) { + ContainedTys = &ContainedType; + NumContainedTys = 1; +} VectorType *VectorType::get(Type *ElementType, ElementCount EC) { - assert(EC.Min > 0 && "#Elements of a VectorType must be greater than 0"); + if (EC.Scalable) + return ScalableVectorType::get(ElementType, EC.Min); + else + return FixedVectorType::get(ElementType, EC.Min); +} + +bool VectorType::isValidElementType(Type *ElemTy) { + return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() || + ElemTy->isPointerTy(); +} + +//===----------------------------------------------------------------------===// +// FixedVectorType Implementation +//===----------------------------------------------------------------------===// + +FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) { + assert(NumElts > 0 && "#Elements of a VectorType must be greater than 0"); assert(isValidElementType(ElementType) && "Element type of a VectorType must " "be an integer, floating point, or " "pointer type."); + ElementCount EC(NumElts, false); + LLVMContextImpl *pImpl = ElementType->getContext().pImpl; - VectorType *&Entry = ElementType->getContext().pImpl - ->VectorTypes[std::make_pair(ElementType, EC)]; + VectorType *&Entry = ElementType->getContext() + .pImpl->VectorTypes[std::make_pair(ElementType, EC)]; + if (!Entry) - Entry = new (pImpl->Alloc) VectorType(ElementType, EC); - return Entry; + Entry = new (pImpl->Alloc) FixedVectorType(ElementType, NumElts); + return cast<FixedVectorType>(Entry); } -bool VectorType::isValidElementType(Type *ElemTy) { - return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() || - ElemTy->isPointerTy(); +//===----------------------------------------------------------------------===// +// ScalableVectorType Implementation +//===----------------------------------------------------------------------===// + +ScalableVectorType *ScalableVectorType::get(Type *ElementType, + unsigned MinNumElts) { + assert(MinNumElts > 0 && "#Elements of a VectorType must be greater than 0"); + assert(isValidElementType(ElementType) && "Element type of a VectorType must " + "be an integer, floating point, or " + "pointer type."); + + ElementCount EC(MinNumElts, true); + + LLVMContextImpl *pImpl = ElementType->getContext().pImpl; + VectorType *&Entry = ElementType->getContext() + .pImpl->VectorTypes[std::make_pair(ElementType, EC)]; + + if (!Entry) + Entry = new (pImpl->Alloc) ScalableVectorType(ElementType, MinNumElts); + return cast<ScalableVectorType>(Entry); } //===----------------------------------------------------------------------===// |