aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/Type.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/IR/Type.cpp')
-rw-r--r--llvm/lib/IR/Type.cpp176
1 files changed, 99 insertions, 77 deletions
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 3eab5042b542..d869a6e07cca 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -40,6 +40,7 @@ Type *Type::getPrimitiveType(LLVMContext &C, TypeID IDNumber) {
switch (IDNumber) {
case VoidTyID : return getVoidTy(C);
case HalfTyID : return getHalfTy(C);
+ case BFloatTyID : return getBFloatTy(C);
case FloatTyID : return getFloatTy(C);
case DoubleTyID : return getDoubleTy(C);
case X86_FP80TyID : return getX86_FP80Ty(C);
@@ -68,20 +69,17 @@ bool Type::canLosslesslyBitCastTo(Type *Ty) const {
return false;
// Vector -> Vector conversions are always lossless if the two vector types
- // have the same size, otherwise not. Also, 64-bit vector types can be
- // converted to x86mmx.
- if (auto *thisPTy = dyn_cast<VectorType>(this)) {
- if (auto *thatPTy = dyn_cast<VectorType>(Ty))
- return thisPTy->getBitWidth() == thatPTy->getBitWidth();
- if (Ty->getTypeID() == Type::X86_MMXTyID &&
- thisPTy->getBitWidth() == 64)
- return true;
- }
+ // have the same size, otherwise not.
+ if (isa<VectorType>(this) && isa<VectorType>(Ty))
+ return getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits();
- if (this->getTypeID() == Type::X86_MMXTyID)
- if (auto *thatPTy = dyn_cast<VectorType>(Ty))
- if (thatPTy->getBitWidth() == 64)
- return true;
+ // 64-bit fixed width vector types can be losslessly converted to x86mmx.
+ if (((isa<FixedVectorType>(this)) && Ty->isX86_MMXTy()) &&
+ getPrimitiveSizeInBits().getFixedSize() == 64)
+ return true;
+ if ((isX86_MMXTy() && isa<FixedVectorType>(Ty)) &&
+ Ty->getPrimitiveSizeInBits().getFixedSize() == 64)
+ return true;
// At this point we have only various mismatches of the first class types
// remaining and ptr->ptr. Just select the lossless conversions. Everything
@@ -115,6 +113,7 @@ bool Type::isEmptyTy() const {
TypeSize Type::getPrimitiveSizeInBits() const {
switch (getTypeID()) {
case Type::HalfTyID: return TypeSize::Fixed(16);
+ case Type::BFloatTyID: return TypeSize::Fixed(16);
case Type::FloatTyID: return TypeSize::Fixed(32);
case Type::DoubleTyID: return TypeSize::Fixed(64);
case Type::X86_FP80TyID: return TypeSize::Fixed(80);
@@ -123,16 +122,21 @@ TypeSize Type::getPrimitiveSizeInBits() const {
case Type::X86_MMXTyID: return TypeSize::Fixed(64);
case Type::IntegerTyID:
return TypeSize::Fixed(cast<IntegerType>(this)->getBitWidth());
- case Type::VectorTyID: {
+ case Type::FixedVectorTyID:
+ case Type::ScalableVectorTyID: {
const VectorType *VTy = cast<VectorType>(this);
- return TypeSize(VTy->getBitWidth(), VTy->isScalable());
+ ElementCount EC = VTy->getElementCount();
+ TypeSize ETS = VTy->getElementType()->getPrimitiveSizeInBits();
+ assert(!ETS.isScalable() && "Vector type should have fixed-width elements");
+ return {ETS.getFixedSize() * EC.Min, EC.Scalable};
}
default: return TypeSize::Fixed(0);
}
}
unsigned Type::getScalarSizeInBits() const {
- return getScalarType()->getPrimitiveSizeInBits();
+ // It is safe to assume that the scalar types have a fixed size.
+ return getScalarType()->getPrimitiveSizeInBits().getFixedSize();
}
int Type::getFPMantissaWidth() const {
@@ -140,6 +144,7 @@ int Type::getFPMantissaWidth() const {
return VTy->getElementType()->getFPMantissaWidth();
assert(isFloatingPointTy() && "Not a floating point type!");
if (getTypeID() == HalfTyID) return 11;
+ if (getTypeID() == BFloatTyID) return 8;
if (getTypeID() == FloatTyID) return 24;
if (getTypeID() == DoubleTyID) return 53;
if (getTypeID() == X86_FP80TyID) return 64;
@@ -165,6 +170,7 @@ bool Type::isSizedDerivedType(SmallPtrSetImpl<Type*> *Visited) const {
Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; }
Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; }
Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; }
+Type *Type::getBFloatTy(LLVMContext &C) { return &C.pImpl->BFloatTy; }
Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; }
Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; }
Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; }
@@ -189,6 +195,10 @@ PointerType *Type::getHalfPtrTy(LLVMContext &C, unsigned AS) {
return getHalfTy(C)->getPointerTo(AS);
}
+PointerType *Type::getBFloatPtrTy(LLVMContext &C, unsigned AS) {
+ return getBFloatTy(C)->getPointerTo(AS);
+}
+
PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) {
return getFloatTy(C)->getPointerTo(AS);
}
@@ -509,11 +519,9 @@ StringRef StructType::getName() const {
}
bool StructType::isValidElementType(Type *ElemTy) {
- if (auto *VTy = dyn_cast<VectorType>(ElemTy))
- return !VTy->isScalable();
return !ElemTy->isVoidTy() && !ElemTy->isLabelTy() &&
!ElemTy->isMetadataTy() && !ElemTy->isFunctionTy() &&
- !ElemTy->isTokenTy();
+ !ElemTy->isTokenTy() && !isa<ScalableVectorType>(ElemTy);
}
bool StructType::isLayoutIdentical(StructType *Other) const {
@@ -529,52 +537,24 @@ StructType *Module::getTypeByName(StringRef Name) const {
return getContext().pImpl->NamedStructTypes.lookup(Name);
}
-//===----------------------------------------------------------------------===//
-// CompositeType Implementation
-//===----------------------------------------------------------------------===//
-
-Type *CompositeType::getTypeAtIndex(const Value *V) const {
- if (auto *STy = dyn_cast<StructType>(this)) {
- unsigned Idx =
- (unsigned)cast<Constant>(V)->getUniqueInteger().getZExtValue();
- assert(indexValid(Idx) && "Invalid structure index!");
- return STy->getElementType(Idx);
- }
-
- return cast<SequentialType>(this)->getElementType();
+Type *StructType::getTypeAtIndex(const Value *V) const {
+ unsigned Idx = (unsigned)cast<Constant>(V)->getUniqueInteger().getZExtValue();
+ assert(indexValid(Idx) && "Invalid structure index!");
+ return getElementType(Idx);
}
-Type *CompositeType::getTypeAtIndex(unsigned Idx) const{
- if (auto *STy = dyn_cast<StructType>(this)) {
- assert(indexValid(Idx) && "Invalid structure index!");
- return STy->getElementType(Idx);
- }
-
- return cast<SequentialType>(this)->getElementType();
-}
-
-bool CompositeType::indexValid(const Value *V) const {
- if (auto *STy = dyn_cast<StructType>(this)) {
- // Structure indexes require (vectors of) 32-bit integer constants. In the
- // vector case all of the indices must be equal.
- if (!V->getType()->isIntOrIntVectorTy(32))
- return false;
- const Constant *C = dyn_cast<Constant>(V);
- if (C && V->getType()->isVectorTy())
- C = C->getSplatValue();
- const ConstantInt *CU = dyn_cast_or_null<ConstantInt>(C);
- return CU && CU->getZExtValue() < STy->getNumElements();
- }
-
- // Sequential types can be indexed by any integer.
- return V->getType()->isIntOrIntVectorTy();
-}
-
-bool CompositeType::indexValid(unsigned Idx) const {
- if (auto *STy = dyn_cast<StructType>(this))
- return Idx < STy->getNumElements();
- // Sequential types can be indexed by any integer.
- return true;
+bool StructType::indexValid(const Value *V) const {
+ // Structure indexes require (vectors of) 32-bit integer constants. In the
+ // vector case all of the indices must be equal.
+ if (!V->getType()->isIntOrIntVectorTy(32))
+ return false;
+ if (isa<ScalableVectorType>(V->getType()))
+ return false;
+ const Constant *C = dyn_cast<Constant>(V);
+ if (C && V->getType()->isVectorTy())
+ C = C->getSplatValue();
+ const ConstantInt *CU = dyn_cast_or_null<ConstantInt>(C);
+ return CU && CU->getZExtValue() < getNumElements();
}
//===----------------------------------------------------------------------===//
@@ -582,7 +562,11 @@ bool CompositeType::indexValid(unsigned Idx) const {
//===----------------------------------------------------------------------===//
ArrayType::ArrayType(Type *ElType, uint64_t NumEl)
- : SequentialType(ArrayTyID, ElType, NumEl) {}
+ : Type(ElType->getContext(), ArrayTyID), ContainedType(ElType),
+ NumElements(NumEl) {
+ ContainedTys = &ContainedType;
+ NumContainedTys = 1;
+}
ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) {
assert(isValidElementType(ElementType) && "Invalid type for array element!");
@@ -597,37 +581,75 @@ ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) {
}
bool ArrayType::isValidElementType(Type *ElemTy) {
- if (auto *VTy = dyn_cast<VectorType>(ElemTy))
- return !VTy->isScalable();
return !ElemTy->isVoidTy() && !ElemTy->isLabelTy() &&
!ElemTy->isMetadataTy() && !ElemTy->isFunctionTy() &&
- !ElemTy->isTokenTy();
+ !ElemTy->isTokenTy() && !isa<ScalableVectorType>(ElemTy);
}
//===----------------------------------------------------------------------===//
// VectorType Implementation
//===----------------------------------------------------------------------===//
-VectorType::VectorType(Type *ElType, ElementCount EC)
- : SequentialType(VectorTyID, ElType, EC.Min), Scalable(EC.Scalable) {}
+VectorType::VectorType(Type *ElType, unsigned EQ, Type::TypeID TID)
+ : Type(ElType->getContext(), TID), ContainedType(ElType),
+ ElementQuantity(EQ) {
+ ContainedTys = &ContainedType;
+ NumContainedTys = 1;
+}
VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
- assert(EC.Min > 0 && "#Elements of a VectorType must be greater than 0");
+ if (EC.Scalable)
+ return ScalableVectorType::get(ElementType, EC.Min);
+ else
+ return FixedVectorType::get(ElementType, EC.Min);
+}
+
+bool VectorType::isValidElementType(Type *ElemTy) {
+ return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
+ ElemTy->isPointerTy();
+}
+
+//===----------------------------------------------------------------------===//
+// FixedVectorType Implementation
+//===----------------------------------------------------------------------===//
+
+FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
+ assert(NumElts > 0 && "#Elements of a VectorType must be greater than 0");
assert(isValidElementType(ElementType) && "Element type of a VectorType must "
"be an integer, floating point, or "
"pointer type.");
+ ElementCount EC(NumElts, false);
+
LLVMContextImpl *pImpl = ElementType->getContext().pImpl;
- VectorType *&Entry = ElementType->getContext().pImpl
- ->VectorTypes[std::make_pair(ElementType, EC)];
+ VectorType *&Entry = ElementType->getContext()
+ .pImpl->VectorTypes[std::make_pair(ElementType, EC)];
+
if (!Entry)
- Entry = new (pImpl->Alloc) VectorType(ElementType, EC);
- return Entry;
+ Entry = new (pImpl->Alloc) FixedVectorType(ElementType, NumElts);
+ return cast<FixedVectorType>(Entry);
}
-bool VectorType::isValidElementType(Type *ElemTy) {
- return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
- ElemTy->isPointerTy();
+//===----------------------------------------------------------------------===//
+// ScalableVectorType Implementation
+//===----------------------------------------------------------------------===//
+
+ScalableVectorType *ScalableVectorType::get(Type *ElementType,
+ unsigned MinNumElts) {
+ assert(MinNumElts > 0 && "#Elements of a VectorType must be greater than 0");
+ assert(isValidElementType(ElementType) && "Element type of a VectorType must "
+ "be an integer, floating point, or "
+ "pointer type.");
+
+ ElementCount EC(MinNumElts, true);
+
+ LLVMContextImpl *pImpl = ElementType->getContext().pImpl;
+ VectorType *&Entry = ElementType->getContext()
+ .pImpl->VectorTypes[std::make_pair(ElementType, EC)];
+
+ if (!Entry)
+ Entry = new (pImpl->Alloc) ScalableVectorType(ElementType, MinNumElts);
+ return cast<ScalableVectorType>(Entry);
}
//===----------------------------------------------------------------------===//