diff options
Diffstat (limited to 'llvm/lib/IR/Constants.cpp')
| -rw-r--r-- | llvm/lib/IR/Constants.cpp | 403 |
1 files changed, 294 insertions, 109 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index cbbcca20ea51..6fd205c654a8 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -161,7 +161,7 @@ bool Constant::isNotOneValue() const { // Check that vectors don't contain 1 if (auto *VTy = dyn_cast<VectorType>(this->getType())) { - unsigned NumElts = VTy->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(VTy)->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = this->getAggregateElement(i); if (!Elt || !Elt->isNotOneValue()) @@ -211,7 +211,7 @@ bool Constant::isNotMinSignedValue() const { // Check that vectors don't contain INT_MIN if (auto *VTy = dyn_cast<VectorType>(this->getType())) { - unsigned NumElts = VTy->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(VTy)->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = this->getAggregateElement(i); if (!Elt || !Elt->isNotMinSignedValue()) @@ -227,7 +227,7 @@ bool Constant::isNotMinSignedValue() const { bool Constant::isFiniteNonZeroFP() const { if (auto *CFP = dyn_cast<ConstantFP>(this)) return CFP->getValueAPF().isFiniteNonZero(); - auto *VTy = dyn_cast<VectorType>(getType()); + auto *VTy = dyn_cast<FixedVectorType>(getType()); if (!VTy) return false; for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { @@ -304,23 +304,42 @@ bool Constant::isElementWiseEqual(Value *Y) const { return isa<UndefValue>(CmpEq) || match(CmpEq, m_One()); } -bool Constant::containsUndefElement() const { - if (auto *VTy = dyn_cast<VectorType>(getType())) { - for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) - if (isa<UndefValue>(getAggregateElement(i))) +static bool +containsUndefinedElement(const Constant *C, + function_ref<bool(const Constant *)> HasFn) { + if (auto *VTy = dyn_cast<VectorType>(C->getType())) { + if (HasFn(C)) + return true; + if (isa<ConstantAggregateZero>(C)) + return false; + if (isa<ScalableVectorType>(C->getType())) + return false; + + for (unsigned i = 0, e = cast<FixedVectorType>(VTy)->getNumElements(); + i != e; ++i) + if (HasFn(C->getAggregateElement(i))) return true; } return false; } +bool Constant::containsUndefOrPoisonElement() const { + return containsUndefinedElement( + this, [&](const auto *C) { return isa<UndefValue>(C); }); +} + +bool Constant::containsPoisonElement() const { + return containsUndefinedElement( + this, [&](const auto *C) { return isa<PoisonValue>(C); }); +} + bool Constant::containsConstantExpression() const { - if (auto *VTy = dyn_cast<VectorType>(getType())) { + if (auto *VTy = dyn_cast<FixedVectorType>(getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) if (isa<ConstantExpr>(getAggregateElement(i))) return true; } - return false; } @@ -400,16 +419,23 @@ Constant *Constant::getAllOnesValue(Type *Ty) { } Constant *Constant::getAggregateElement(unsigned Elt) const { - if (const ConstantAggregate *CC = dyn_cast<ConstantAggregate>(this)) + if (const auto *CC = dyn_cast<ConstantAggregate>(this)) return Elt < CC->getNumOperands() ? CC->getOperand(Elt) : nullptr; - if (const ConstantAggregateZero *CAZ = dyn_cast<ConstantAggregateZero>(this)) + // FIXME: getNumElements() will fail for non-fixed vector types. + if (isa<ScalableVectorType>(getType())) + return nullptr; + + if (const auto *CAZ = dyn_cast<ConstantAggregateZero>(this)) return Elt < CAZ->getNumElements() ? CAZ->getElementValue(Elt) : nullptr; - if (const UndefValue *UV = dyn_cast<UndefValue>(this)) + if (const auto *PV = dyn_cast<PoisonValue>(this)) + return Elt < PV->getNumElements() ? PV->getElementValue(Elt) : nullptr; + + if (const auto *UV = dyn_cast<UndefValue>(this)) return Elt < UV->getNumElements() ? UV->getElementValue(Elt) : nullptr; - if (const ConstantDataSequential *CDS =dyn_cast<ConstantDataSequential>(this)) + if (const auto *CDS = dyn_cast<ConstantDataSequential>(this)) return Elt < CDS->getNumElements() ? CDS->getElementAsConstant(Elt) : nullptr; return nullptr; @@ -501,9 +527,15 @@ void llvm::deleteConstant(Constant *C) { case Constant::BlockAddressVal: delete static_cast<BlockAddress *>(C); break; + case Constant::DSOLocalEquivalentVal: + delete static_cast<DSOLocalEquivalent *>(C); + break; case Constant::UndefValueVal: delete static_cast<UndefValue *>(C); break; + case Constant::PoisonValueVal: + delete static_cast<PoisonValue *>(C); + break; case Constant::ConstantExprVal: if (isa<UnaryConstantExpr>(C)) delete static_cast<UnaryConstantExpr *>(C); @@ -646,10 +678,17 @@ bool Constant::needsRelocation() const { return false; // Relative pointers do not need to be dynamically relocated. - if (auto *LHSGV = dyn_cast<GlobalValue>(LHSOp0->stripPointerCasts())) - if (auto *RHSGV = dyn_cast<GlobalValue>(RHSOp0->stripPointerCasts())) + if (auto *RHSGV = + dyn_cast<GlobalValue>(RHSOp0->stripInBoundsConstantOffsets())) { + auto *LHS = LHSOp0->stripInBoundsConstantOffsets(); + if (auto *LHSGV = dyn_cast<GlobalValue>(LHS)) { if (LHSGV->isDSOLocal() && RHSGV->isDSOLocal()) return false; + } else if (isa<DSOLocalEquivalent>(LHS)) { + if (RHSGV->isDSOLocal()) + return false; + } + } } } } @@ -729,6 +768,40 @@ Constant *Constant::replaceUndefsWith(Constant *C, Constant *Replacement) { return ConstantVector::get(NewC); } +Constant *Constant::mergeUndefsWith(Constant *C, Constant *Other) { + assert(C && Other && "Expected non-nullptr constant arguments"); + if (match(C, m_Undef())) + return C; + + Type *Ty = C->getType(); + if (match(Other, m_Undef())) + return UndefValue::get(Ty); + + auto *VTy = dyn_cast<FixedVectorType>(Ty); + if (!VTy) + return C; + + Type *EltTy = VTy->getElementType(); + unsigned NumElts = VTy->getNumElements(); + assert(isa<FixedVectorType>(Other->getType()) && + cast<FixedVectorType>(Other->getType())->getNumElements() == NumElts && + "Type mismatch"); + + bool FoundExtraUndef = false; + SmallVector<Constant *, 32> NewC(NumElts); + for (unsigned I = 0; I != NumElts; ++I) { + NewC[I] = C->getAggregateElement(I); + Constant *OtherEltC = Other->getAggregateElement(I); + assert(NewC[I] && OtherEltC && "Unknown vector element"); + if (!match(NewC[I], m_Undef()) && match(OtherEltC, m_Undef())) { + NewC[I] = UndefValue::get(EltTy); + FoundExtraUndef = true; + } + } + if (FoundExtraUndef) + return ConstantVector::get(NewC); + return C; +} //===----------------------------------------------------------------------===// // ConstantInt @@ -753,6 +826,10 @@ ConstantInt *ConstantInt::getFalse(LLVMContext &Context) { return pImpl->TheFalseVal; } +ConstantInt *ConstantInt::getBool(LLVMContext &Context, bool V) { + return V ? getTrue(Context) : getFalse(Context); +} + Constant *ConstantInt::getTrue(Type *Ty) { assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext()); @@ -769,6 +846,10 @@ Constant *ConstantInt::getFalse(Type *Ty) { return FalseC; } +Constant *ConstantInt::getBool(Type *Ty, bool V) { + return V ? getTrue(Ty) : getFalse(Ty); +} + // Get a ConstantInt from an APInt. ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) { // get an existing value or the insertion position @@ -830,30 +911,12 @@ void ConstantInt::destroyConstantImpl() { // ConstantFP //===----------------------------------------------------------------------===// -static const fltSemantics *TypeToFloatSemantics(Type *Ty) { - if (Ty->isHalfTy()) - return &APFloat::IEEEhalf(); - if (Ty->isBFloatTy()) - return &APFloat::BFloat(); - if (Ty->isFloatTy()) - return &APFloat::IEEEsingle(); - if (Ty->isDoubleTy()) - return &APFloat::IEEEdouble(); - if (Ty->isX86_FP80Ty()) - return &APFloat::x87DoubleExtended(); - else if (Ty->isFP128Ty()) - return &APFloat::IEEEquad(); - - assert(Ty->isPPC_FP128Ty() && "Unknown FP format"); - return &APFloat::PPCDoubleDouble(); -} - Constant *ConstantFP::get(Type *Ty, double V) { LLVMContext &Context = Ty->getContext(); APFloat FV(V); bool ignored; - FV.convert(*TypeToFloatSemantics(Ty->getScalarType()), + FV.convert(Ty->getScalarType()->getFltSemantics(), APFloat::rmNearestTiesToEven, &ignored); Constant *C = get(Context, FV); @@ -879,7 +942,7 @@ Constant *ConstantFP::get(Type *Ty, const APFloat &V) { Constant *ConstantFP::get(Type *Ty, StringRef Str) { LLVMContext &Context = Ty->getContext(); - APFloat FV(*TypeToFloatSemantics(Ty->getScalarType()), Str); + APFloat FV(Ty->getScalarType()->getFltSemantics(), Str); Constant *C = get(Context, FV); // For vectors, broadcast the value. @@ -890,7 +953,7 @@ Constant *ConstantFP::get(Type *Ty, StringRef Str) { } Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) { - const fltSemantics &Semantics = *TypeToFloatSemantics(Ty->getScalarType()); + const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics(); APFloat NaN = APFloat::getNaN(Semantics, Negative, Payload); Constant *C = get(Ty->getContext(), NaN); @@ -901,7 +964,7 @@ Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) { } Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) { - const fltSemantics &Semantics = *TypeToFloatSemantics(Ty->getScalarType()); + const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics(); APFloat NaN = APFloat::getQNaN(Semantics, Negative, Payload); Constant *C = get(Ty->getContext(), NaN); @@ -912,7 +975,7 @@ Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) { } Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) { - const fltSemantics &Semantics = *TypeToFloatSemantics(Ty->getScalarType()); + const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics(); APFloat NaN = APFloat::getSNaN(Semantics, Negative, Payload); Constant *C = get(Ty->getContext(), NaN); @@ -923,7 +986,7 @@ Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) { } Constant *ConstantFP::getNegativeZero(Type *Ty) { - const fltSemantics &Semantics = *TypeToFloatSemantics(Ty->getScalarType()); + const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics(); APFloat NegZero = APFloat::getZero(Semantics, /*Negative=*/true); Constant *C = get(Ty->getContext(), NegZero); @@ -949,24 +1012,7 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) { std::unique_ptr<ConstantFP> &Slot = pImpl->FPConstants[V]; if (!Slot) { - Type *Ty; - if (&V.getSemantics() == &APFloat::IEEEhalf()) - Ty = Type::getHalfTy(Context); - else if (&V.getSemantics() == &APFloat::BFloat()) - Ty = Type::getBFloatTy(Context); - else if (&V.getSemantics() == &APFloat::IEEEsingle()) - Ty = Type::getFloatTy(Context); - else if (&V.getSemantics() == &APFloat::IEEEdouble()) - Ty = Type::getDoubleTy(Context); - else if (&V.getSemantics() == &APFloat::x87DoubleExtended()) - Ty = Type::getX86_FP80Ty(Context); - else if (&V.getSemantics() == &APFloat::IEEEquad()) - Ty = Type::getFP128Ty(Context); - else { - assert(&V.getSemantics() == &APFloat::PPCDoubleDouble() && - "Unknown FP format"); - Ty = Type::getPPC_FP128Ty(Context); - } + Type *Ty = Type::getFloatingPointTy(Context, V.getSemantics()); Slot.reset(new ConstantFP(Ty, V)); } @@ -974,7 +1020,7 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) { } Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) { - const fltSemantics &Semantics = *TypeToFloatSemantics(Ty->getScalarType()); + const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics(); Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative)); if (VectorType *VTy = dyn_cast<VectorType>(Ty)) @@ -985,7 +1031,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) { ConstantFP::ConstantFP(Type *Ty, const APFloat &V) : ConstantData(Ty, ConstantFPVal), Val(V) { - assert(&V.getSemantics() == TypeToFloatSemantics(Ty) && + assert(&V.getSemantics() == &Ty->getFltSemantics() && "FP type Mismatch"); } @@ -1029,7 +1075,7 @@ unsigned ConstantAggregateZero::getNumElements() const { if (auto *AT = dyn_cast<ArrayType>(Ty)) return AT->getNumElements(); if (auto *VT = dyn_cast<VectorType>(Ty)) - return VT->getNumElements(); + return cast<FixedVectorType>(VT)->getNumElements(); return Ty->getStructNumElements(); } @@ -1064,11 +1110,37 @@ unsigned UndefValue::getNumElements() const { if (auto *AT = dyn_cast<ArrayType>(Ty)) return AT->getNumElements(); if (auto *VT = dyn_cast<VectorType>(Ty)) - return VT->getNumElements(); + return cast<FixedVectorType>(VT)->getNumElements(); return Ty->getStructNumElements(); } //===----------------------------------------------------------------------===// +// PoisonValue Implementation +//===----------------------------------------------------------------------===// + +PoisonValue *PoisonValue::getSequentialElement() const { + if (ArrayType *ATy = dyn_cast<ArrayType>(getType())) + return PoisonValue::get(ATy->getElementType()); + return PoisonValue::get(cast<VectorType>(getType())->getElementType()); +} + +PoisonValue *PoisonValue::getStructElement(unsigned Elt) const { + return PoisonValue::get(getType()->getStructElementType(Elt)); +} + +PoisonValue *PoisonValue::getElementValue(Constant *C) const { + if (isa<ArrayType>(getType()) || isa<VectorType>(getType())) + return getSequentialElement(); + return getStructElement(cast<ConstantInt>(C)->getZExtValue()); +} + +PoisonValue *PoisonValue::getElementValue(unsigned Idx) const { + if (isa<ArrayType>(getType()) || isa<VectorType>(getType())) + return getSequentialElement(); + return getStructElement(Idx); +} + +//===----------------------------------------------------------------------===// // ConstantXXX Classes //===----------------------------------------------------------------------===// @@ -1246,7 +1318,7 @@ Constant *ConstantStruct::get(StructType *ST, ArrayRef<Constant*> V) { ConstantVector::ConstantVector(VectorType *T, ArrayRef<Constant *> V) : ConstantAggregate(T, ConstantVectorVal, V) { - assert(V.size() == T->getNumElements() && + assert(V.size() == cast<FixedVectorType>(T)->getNumElements() && "Invalid initializer for constant vector"); } @@ -1267,17 +1339,20 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) { Constant *C = V[0]; bool isZero = C->isNullValue(); bool isUndef = isa<UndefValue>(C); + bool isPoison = isa<PoisonValue>(C); if (isZero || isUndef) { for (unsigned i = 1, e = V.size(); i != e; ++i) if (V[i] != C) { - isZero = isUndef = false; + isZero = isUndef = isPoison = false; break; } } if (isZero) return ConstantAggregateZero::get(T); + if (isPoison) + return PoisonValue::get(T); if (isUndef) return UndefValue::get(T); @@ -1292,14 +1367,14 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) { } Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { - if (!EC.Scalable) { + if (!EC.isScalable()) { // If this splat is compatible with ConstantDataVector, use it instead of // ConstantVector. if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) && ConstantDataSequential::isElementTypeCompatible(V->getType())) - return ConstantDataVector::getSplat(EC.Min, V); + return ConstantDataVector::getSplat(EC.getKnownMinValue(), V); - SmallVector<Constant *, 32> Elts(EC.Min, V); + SmallVector<Constant *, 32> Elts(EC.getKnownMinValue(), V); return get(Elts); } @@ -1316,7 +1391,7 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { Constant *UndefV = UndefValue::get(VTy); V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0)); // Build shuffle mask to perform the splat. - SmallVector<int, 8> Zeros(EC.Min, 0); + SmallVector<int, 8> Zeros(EC.getKnownMinValue(), 0); // Splat. return ConstantExpr::getShuffleVector(V, UndefV, Zeros); } @@ -1441,6 +1516,8 @@ Constant *ConstantExpr::getWithOperands(ArrayRef<Constant *> Ops, Type *Ty, OnlyIfReducedTy); case Instruction::ExtractValue: return ConstantExpr::getExtractValue(Ops[0], getIndices(), OnlyIfReducedTy); + case Instruction::FNeg: + return ConstantExpr::getFNeg(Ops[0]); case Instruction::ShuffleVector: return ConstantExpr::getShuffleVector(Ops[0], Ops[1], getShuffleMask(), OnlyIfReducedTy); @@ -1601,7 +1678,7 @@ Constant *Constant::getSplatValue(bool AllowUndefs) const { ConstantInt *Index = dyn_cast<ConstantInt>(IElt->getOperand(2)); if (Index && Index->getValue() == 0 && - std::all_of(Mask.begin(), Mask.end(), [](int I) { return I == 0; })) + llvm::all_of(Mask, [](int I) { return I == 0; })) return SplatVal; } } @@ -1673,7 +1750,26 @@ UndefValue *UndefValue::get(Type *Ty) { /// Remove the constant from the constant table. void UndefValue::destroyConstantImpl() { // Free the constant and any dangling references to it. - getContext().pImpl->UVConstants.erase(getType()); + if (getValueID() == UndefValueVal) { + getContext().pImpl->UVConstants.erase(getType()); + } else if (getValueID() == PoisonValueVal) { + getContext().pImpl->PVConstants.erase(getType()); + } + llvm_unreachable("Not a undef or a poison!"); +} + +PoisonValue *PoisonValue::get(Type *Ty) { + std::unique_ptr<PoisonValue> &Entry = Ty->getContext().pImpl->PVConstants[Ty]; + if (!Entry) + Entry.reset(new PoisonValue(Ty)); + + return Entry.get(); +} + +/// Remove the constant from the constant table. +void PoisonValue::destroyConstantImpl() { + // Free the constant and any dangling references to it. + getContext().pImpl->PVConstants.erase(getType()); } BlockAddress *BlockAddress::get(BasicBlock *BB) { @@ -1754,6 +1850,58 @@ Value *BlockAddress::handleOperandChangeImpl(Value *From, Value *To) { return nullptr; } +DSOLocalEquivalent *DSOLocalEquivalent::get(GlobalValue *GV) { + DSOLocalEquivalent *&Equiv = GV->getContext().pImpl->DSOLocalEquivalents[GV]; + if (!Equiv) + Equiv = new DSOLocalEquivalent(GV); + + assert(Equiv->getGlobalValue() == GV && + "DSOLocalFunction does not match the expected global value"); + return Equiv; +} + +DSOLocalEquivalent::DSOLocalEquivalent(GlobalValue *GV) + : Constant(GV->getType(), Value::DSOLocalEquivalentVal, &Op<0>(), 1) { + setOperand(0, GV); +} + +/// Remove the constant from the constant table. +void DSOLocalEquivalent::destroyConstantImpl() { + const GlobalValue *GV = getGlobalValue(); + GV->getContext().pImpl->DSOLocalEquivalents.erase(GV); +} + +Value *DSOLocalEquivalent::handleOperandChangeImpl(Value *From, Value *To) { + assert(From == getGlobalValue() && "Changing value does not match operand."); + assert(isa<Constant>(To) && "Can only replace the operands with a constant"); + + // The replacement is with another global value. + if (const auto *ToObj = dyn_cast<GlobalValue>(To)) { + DSOLocalEquivalent *&NewEquiv = + getContext().pImpl->DSOLocalEquivalents[ToObj]; + if (NewEquiv) + return llvm::ConstantExpr::getBitCast(NewEquiv, getType()); + } + + // If the argument is replaced with a null value, just replace this constant + // with a null value. + if (cast<Constant>(To)->isNullValue()) + return To; + + // The replacement could be a bitcast or an alias to another function. We can + // replace it with a bitcast to the dso_local_equivalent of that function. + auto *Func = cast<Function>(To->stripPointerCastsAndAliases()); + DSOLocalEquivalent *&NewEquiv = getContext().pImpl->DSOLocalEquivalents[Func]; + if (NewEquiv) + return llvm::ConstantExpr::getBitCast(NewEquiv, getType()); + + // Replace this with the new one. + getContext().pImpl->DSOLocalEquivalents.erase(getGlobalValue()); + NewEquiv = this; + setOperand(0, Func); + return nullptr; +} + //---- ConstantExpr::get() implementations. // @@ -2002,8 +2150,8 @@ Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy, "PtrToInt destination must be integer or integer vector"); assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy)); if (isa<VectorType>(C->getType())) - assert(cast<VectorType>(C->getType())->getNumElements() == - cast<VectorType>(DstTy)->getNumElements() && + assert(cast<FixedVectorType>(C->getType())->getNumElements() == + cast<FixedVectorType>(DstTy)->getNumElements() && "Invalid cast between a different number of vector elements"); return getFoldedCast(Instruction::PtrToInt, C, DstTy, OnlyIfReduced); } @@ -2016,8 +2164,8 @@ Constant *ConstantExpr::getIntToPtr(Constant *C, Type *DstTy, "IntToPtr destination must be a pointer or pointer vector"); assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy)); if (isa<VectorType>(C->getType())) - assert(cast<VectorType>(C->getType())->getNumElements() == - cast<VectorType>(DstTy)->getNumElements() && + assert(cast<VectorType>(C->getType())->getElementCount() == + cast<VectorType>(DstTy)->getElementCount() && "Invalid cast between a different number of vector elements"); return getFoldedCast(Instruction::IntToPtr, C, DstTy, OnlyIfReduced); } @@ -2048,7 +2196,8 @@ Constant *ConstantExpr::getAddrSpaceCast(Constant *C, Type *DstTy, Type *MidTy = PointerType::get(DstElemTy, SrcScalarTy->getAddressSpace()); if (VectorType *VT = dyn_cast<VectorType>(DstTy)) { // Handle vectors of pointers. - MidTy = FixedVectorType::get(MidTy, VT->getNumElements()); + MidTy = FixedVectorType::get(MidTy, + cast<FixedVectorType>(VT)->getNumElements()); } C = getBitCast(C, MidTy); } @@ -2245,7 +2394,7 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C, unsigned AS = C->getType()->getPointerAddressSpace(); Type *ReqTy = DestTy->getPointerTo(AS); - ElementCount EltCount = {0, false}; + auto EltCount = ElementCount::getFixed(0); if (VectorType *VecTy = dyn_cast<VectorType>(C->getType())) EltCount = VecTy->getElementCount(); else @@ -2253,7 +2402,7 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C, if (VectorType *VecTy = dyn_cast<VectorType>(Idx->getType())) EltCount = VecTy->getElementCount(); - if (EltCount.Min != 0) + if (EltCount.isNonZero()) ReqTy = VectorType::get(ReqTy, EltCount); if (OnlyIfReducedTy == ReqTy) @@ -2273,7 +2422,7 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C, if (GTI.isStruct() && Idx->getType()->isVectorTy()) { Idx = Idx->getSplatValue(); - } else if (GTI.isSequential() && EltCount.Min != 0 && + } else if (GTI.isSequential() && EltCount.isNonZero() && !Idx->getType()->isVectorTy()) { Idx = ConstantVector::getSplat(EltCount, Idx); } @@ -2549,6 +2698,11 @@ Constant *ConstantExpr::getXor(Constant *C1, Constant *C2) { return get(Instruction::Xor, C1, C2); } +Constant *ConstantExpr::getUMin(Constant *C1, Constant *C2) { + Constant *Cmp = ConstantExpr::getICmp(CmpInst::ICMP_ULT, C1, C2); + return getSelect(Cmp, C1, C2); +} + Constant *ConstantExpr::getShl(Constant *C1, Constant *C2, bool HasNUW, bool HasNSW) { unsigned Flags = (HasNUW ? OverflowingBinaryOperator::NoUnsignedWrap : 0) | @@ -2566,6 +2720,35 @@ Constant *ConstantExpr::getAShr(Constant *C1, Constant *C2, bool isExact) { isExact ? PossiblyExactOperator::IsExact : 0); } +Constant *ConstantExpr::getExactLogBase2(Constant *C) { + Type *Ty = C->getType(); + const APInt *IVal; + if (match(C, m_APInt(IVal)) && IVal->isPowerOf2()) + return ConstantInt::get(Ty, IVal->logBase2()); + + // FIXME: We can extract pow of 2 of splat constant for scalable vectors. + auto *VecTy = dyn_cast<FixedVectorType>(Ty); + if (!VecTy) + return nullptr; + + SmallVector<Constant *, 4> Elts; + for (unsigned I = 0, E = VecTy->getNumElements(); I != E; ++I) { + Constant *Elt = C->getAggregateElement(I); + if (!Elt) + return nullptr; + // Note that log2(iN undef) is *NOT* iN undef, because log2(iN undef) u< N. + if (isa<UndefValue>(Elt)) { + Elts.push_back(Constant::getNullValue(Ty->getScalarType())); + continue; + } + if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2()) + return nullptr; + Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2())); + } + + return ConstantVector::get(Elts); +} + Constant *ConstantExpr::getBinOpIdentity(unsigned Opcode, Type *Ty, bool AllowRHSConstant) { assert(Instruction::isBinaryOp(Opcode) && "Only binops allowed"); @@ -2690,7 +2873,7 @@ bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) { unsigned ConstantDataSequential::getNumElements() const { if (ArrayType *AT = dyn_cast<ArrayType>(getType())) return AT->getNumElements(); - return cast<VectorType>(getType())->getNumElements(); + return cast<FixedVectorType>(getType())->getNumElements(); } @@ -2739,56 +2922,58 @@ Constant *ConstantDataSequential::getImpl(StringRef Elements, Type *Ty) { // body but different types. For example, 0,0,0,1 could be a 4 element array // of i8, or a 1-element array of i32. They'll both end up in the same /// StringMap bucket, linked up by their Next pointers. Walk the list. - ConstantDataSequential **Entry = &Slot.second; - for (ConstantDataSequential *Node = *Entry; Node; - Entry = &Node->Next, Node = *Entry) - if (Node->getType() == Ty) - return Node; + std::unique_ptr<ConstantDataSequential> *Entry = &Slot.second; + for (; *Entry; Entry = &(*Entry)->Next) + if ((*Entry)->getType() == Ty) + return Entry->get(); // Okay, we didn't get a hit. Create a node of the right class, link it in, // and return it. - if (isa<ArrayType>(Ty)) - return *Entry = new ConstantDataArray(Ty, Slot.first().data()); + if (isa<ArrayType>(Ty)) { + // Use reset because std::make_unique can't access the constructor. + Entry->reset(new ConstantDataArray(Ty, Slot.first().data())); + return Entry->get(); + } assert(isa<VectorType>(Ty)); - return *Entry = new ConstantDataVector(Ty, Slot.first().data()); + // Use reset because std::make_unique can't access the constructor. + Entry->reset(new ConstantDataVector(Ty, Slot.first().data())); + return Entry->get(); } void ConstantDataSequential::destroyConstantImpl() { // Remove the constant from the StringMap. - StringMap<ConstantDataSequential*> &CDSConstants = - getType()->getContext().pImpl->CDSConstants; + StringMap<std::unique_ptr<ConstantDataSequential>> &CDSConstants = + getType()->getContext().pImpl->CDSConstants; - StringMap<ConstantDataSequential*>::iterator Slot = - CDSConstants.find(getRawDataValues()); + auto Slot = CDSConstants.find(getRawDataValues()); assert(Slot != CDSConstants.end() && "CDS not found in uniquing table"); - ConstantDataSequential **Entry = &Slot->getValue(); + std::unique_ptr<ConstantDataSequential> *Entry = &Slot->getValue(); // Remove the entry from the hash table. if (!(*Entry)->Next) { // If there is only one value in the bucket (common case) it must be this // entry, and removing the entry should remove the bucket completely. - assert((*Entry) == this && "Hash mismatch in ConstantDataSequential"); + assert(Entry->get() == this && "Hash mismatch in ConstantDataSequential"); getContext().pImpl->CDSConstants.erase(Slot); - } else { - // Otherwise, there are multiple entries linked off the bucket, unlink the - // node we care about but keep the bucket around. - for (ConstantDataSequential *Node = *Entry; ; - Entry = &Node->Next, Node = *Entry) { - assert(Node && "Didn't find entry in its uniquing hash table!"); - // If we found our entry, unlink it from the list and we're done. - if (Node == this) { - *Entry = Node->Next; - break; - } - } + return; } - // If we were part of a list, make sure that we don't delete the list that is - // still owned by the uniquing map. - Next = nullptr; + // Otherwise, there are multiple entries linked off the bucket, unlink the + // node we care about but keep the bucket around. + while (true) { + std::unique_ptr<ConstantDataSequential> &Node = *Entry; + assert(Node && "Didn't find entry in its uniquing hash table!"); + // If we found our entry, unlink it from the list and we're done. + if (Node.get() == this) { + Node = std::move(Node->Next); + return; + } + + Entry = &Node->Next; + } } /// getFP() constructors - Return a constant of array type with a float @@ -2938,7 +3123,7 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) { return getFP(V->getType(), Elts); } } - return ConstantVector::getSplat({NumElts, false}, V); + return ConstantVector::getSplat(ElementCount::getFixed(NumElts), V); } @@ -3248,7 +3433,7 @@ Value *ConstantExpr::handleOperandChangeImpl(Value *From, Value *ToV) { } Instruction *ConstantExpr::getAsInstruction() const { - SmallVector<Value *, 4> ValueOperands(op_begin(), op_end()); + SmallVector<Value *, 4> ValueOperands(operands()); ArrayRef<Value*> Ops(ValueOperands); switch (getOpcode()) { |
