diff options
Diffstat (limited to 'llvm/lib/Analysis/ConstantFolding.cpp')
| -rw-r--r-- | llvm/lib/Analysis/ConstantFolding.cpp | 164 |
1 files changed, 83 insertions, 81 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 922b38e92785..7cf69f613c66 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -106,11 +106,8 @@ Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { "Invalid constantexpr bitcast!"); // Catch the obvious splat cases. - if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy()) - return Constant::getNullValue(DestTy); - if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() && - !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types! - return Constant::getAllOnesValue(DestTy); + if (Constant *Res = ConstantFoldLoadFromUniformValue(C, DestTy)) + return Res; if (auto *VTy = dyn_cast<VectorType>(C->getType())) { // Handle a vector->scalar integer/fp cast. @@ -362,16 +359,8 @@ Constant *llvm::ConstantFoldLoadThroughBitcast(Constant *C, Type *DestTy, // Catch the obvious splat cases (since all-zeros can coerce non-integral // pointers legally). - if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy()) - return Constant::getNullValue(DestTy); - if (C->isAllOnesValue() && - (DestTy->isIntegerTy() || DestTy->isFloatingPointTy() || - DestTy->isVectorTy()) && - !DestTy->isX86_AMXTy() && !DestTy->isX86_MMXTy() && - !DestTy->isPtrOrPtrVectorTy()) - // Get ones when the input is trivial, but - // only for supported types inside getAllOnesValue. - return Constant::getAllOnesValue(DestTy); + if (Constant *Res = ConstantFoldLoadFromUniformValue(C, DestTy)) + return Res; // If the type sizes are the same and a cast is legal, just directly // cast the constant. @@ -410,6 +399,12 @@ Constant *llvm::ConstantFoldLoadThroughBitcast(Constant *C, Type *DestTy, } while (ElemC && DL.getTypeSizeInBits(ElemC->getType()).isZero()); C = ElemC; } else { + // For non-byte-sized vector elements, the first element is not + // necessarily located at the vector base address. + if (auto *VT = dyn_cast<VectorType>(SrcTy)) + if (!DL.typeSizeEqualsStoreSize(VT->getElementType())) + return nullptr; + C = C->getAggregateElement(0u); } } while (C); @@ -558,23 +553,16 @@ Constant *FoldReinterpretLoadFromConst(Constant *C, Type *LoadTy, // If this isn't an integer load we can't fold it directly. if (!IntType) { - // If this is a float/double load, we can try folding it as an int32/64 load - // and then bitcast the result. This can be useful for union cases. Note + // If this is a non-integer load, we can try folding it as an int load and + // then bitcast the result. This can be useful for union cases. Note // that address spaces don't matter here since we're not going to result in // an actual new load. - Type *MapTy; - if (LoadTy->isHalfTy()) - MapTy = Type::getInt16Ty(C->getContext()); - else if (LoadTy->isFloatTy()) - MapTy = Type::getInt32Ty(C->getContext()); - else if (LoadTy->isDoubleTy()) - MapTy = Type::getInt64Ty(C->getContext()); - else if (LoadTy->isVectorTy()) { - MapTy = PointerType::getIntNTy( - C->getContext(), DL.getTypeSizeInBits(LoadTy).getFixedSize()); - } else + if (!LoadTy->isFloatingPointTy() && !LoadTy->isPointerTy() && + !LoadTy->isVectorTy()) return nullptr; + Type *MapTy = Type::getIntNTy( + C->getContext(), DL.getTypeSizeInBits(LoadTy).getFixedSize()); if (Constant *Res = FoldReinterpretLoadFromConst(C, MapTy, Offset, DL)) { if (Res->isNullValue() && !LoadTy->isX86_MMXTy() && !LoadTy->isX86_AMXTy()) @@ -680,9 +668,21 @@ Constant *llvm::ConstantFoldLoadFromConst(Constant *C, Type *Ty, if (Constant *Result = ConstantFoldLoadThroughBitcast(AtOffset, Ty, DL)) return Result; + // Explicitly check for out-of-bounds access, so we return undef even if the + // constant is a uniform value. + TypeSize Size = DL.getTypeAllocSize(C->getType()); + if (!Size.isScalable() && Offset.sge(Size.getFixedSize())) + return UndefValue::get(Ty); + + // Try an offset-independent fold of a uniform value. + if (Constant *Result = ConstantFoldLoadFromUniformValue(C, Ty)) + return Result; + // Try hard to fold loads from bitcasted strange and non-type-safe things. if (Offset.getMinSignedBits() <= 64) - return FoldReinterpretLoadFromConst(C, Ty, Offset.getSExtValue(), DL); + if (Constant *Result = + FoldReinterpretLoadFromConst(C, Ty, Offset.getSExtValue(), DL)) + return Result; return nullptr; } @@ -704,15 +704,13 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, Offset, DL)) return Result; - // If this load comes from anywhere in a constant global, and if the global - // is all undef or zero, we know what it loads. + // If this load comes from anywhere in a uniform constant global, the value + // is always the same, regardless of the loaded offset. if (auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(C))) { if (GV->isConstant() && GV->hasDefinitiveInitializer()) { - if (GV->getInitializer()->isNullValue() && !Ty->isX86_MMXTy() && - !Ty->isX86_AMXTy()) - return Constant::getNullValue(Ty); - if (isa<UndefValue>(GV->getInitializer())) - return UndefValue::get(Ty); + if (Constant *Res = + ConstantFoldLoadFromUniformValue(GV->getInitializer(), Ty)) + return Res; } } @@ -725,6 +723,19 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, return ConstantFoldLoadFromConstPtr(C, Ty, Offset, DL); } +Constant *llvm::ConstantFoldLoadFromUniformValue(Constant *C, Type *Ty) { + if (isa<PoisonValue>(C)) + return PoisonValue::get(Ty); + if (isa<UndefValue>(C)) + return UndefValue::get(Ty); + if (C->isNullValue() && !Ty->isX86_MMXTy() && !Ty->isX86_AMXTy()) + return Constant::getNullValue(Ty); + if (C->isAllOnesValue() && + (Ty->isIntOrIntVectorTy() || Ty->isFPOrFPVectorTy())) + return Constant::getAllOnesValue(Ty); + return nullptr; +} + namespace { /// One of Op0/Op1 is a constant expression. @@ -930,7 +941,7 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, if (auto *GV = dyn_cast<GlobalValue>(Ptr)) SrcElemTy = GV->getValueType(); else if (!PTy->isOpaque()) - SrcElemTy = PTy->getElementType(); + SrcElemTy = PTy->getNonOpaquePointerElementType(); else SrcElemTy = Type::getInt8Ty(Ptr->getContext()); @@ -1171,10 +1182,11 @@ Constant *llvm::ConstantFoldInstOperands(Instruction *I, return ConstantFoldInstOperandsImpl(I, I->getOpcode(), Ops, DL, TLI); } -Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, +Constant *llvm::ConstantFoldCompareInstOperands(unsigned IntPredicate, Constant *Ops0, Constant *Ops1, const DataLayout &DL, const TargetLibraryInfo *TLI) { + CmpInst::Predicate Predicate = (CmpInst::Predicate)IntPredicate; // fold: icmp (inttoptr x), null -> icmp x, 0 // fold: icmp null, (inttoptr x) -> icmp 0, x // fold: icmp (ptrtoint x), 0 -> icmp x, null @@ -1248,10 +1260,30 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, Predicate == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; return ConstantFoldBinaryOpOperands(OpC, LHS, RHS, DL); } + + // Convert pointer comparison (base+offset1) pred (base+offset2) into + // offset1 pred offset2, for the case where the offset is inbounds. This + // only works for equality and unsigned comparison, as inbounds permits + // crossing the sign boundary. However, the offset comparison itself is + // signed. + if (Ops0->getType()->isPointerTy() && !ICmpInst::isSigned(Predicate)) { + unsigned IndexWidth = DL.getIndexTypeSizeInBits(Ops0->getType()); + APInt Offset0(IndexWidth, 0); + Value *Stripped0 = + Ops0->stripAndAccumulateInBoundsConstantOffsets(DL, Offset0); + APInt Offset1(IndexWidth, 0); + Value *Stripped1 = + Ops1->stripAndAccumulateInBoundsConstantOffsets(DL, Offset1); + if (Stripped0 == Stripped1) + return ConstantExpr::getCompare( + ICmpInst::getSignedPredicate(Predicate), + ConstantInt::get(CE0->getContext(), Offset0), + ConstantInt::get(CE0->getContext(), Offset1)); + } } else if (isa<ConstantExpr>(Ops1)) { // If RHS is a constant expression, but the left side isn't, swap the // operands and try again. - Predicate = ICmpInst::getSwappedPredicate((ICmpInst::Predicate)Predicate); + Predicate = ICmpInst::getSwappedPredicate(Predicate); return ConstantFoldCompareInstOperands(Predicate, Ops1, Ops0, DL, TLI); } @@ -1347,23 +1379,6 @@ Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C, } } -Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, - ConstantExpr *CE, - Type *Ty, - const DataLayout &DL) { - if (!CE->getOperand(1)->isNullValue()) - return nullptr; // Do not allow stepping over the value! - - // Loop over all of the operands, tracking down which value we are - // addressing. - for (unsigned i = 2, e = CE->getNumOperands(); i != e; ++i) { - C = C->getAggregateElement(CE->getOperand(i)); - if (!C) - return nullptr; - } - return ConstantFoldLoadThroughBitcast(C, Ty, DL); -} - //===----------------------------------------------------------------------===// // Constant Folding for Calls // @@ -2463,36 +2478,21 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, !getConstIntOrUndef(Operands[1], C1)) return nullptr; - unsigned BitWidth = Ty->getScalarSizeInBits(); switch (IntrinsicID) { default: break; case Intrinsic::smax: - if (!C0 && !C1) - return UndefValue::get(Ty); - if (!C0 || !C1) - return ConstantInt::get(Ty, APInt::getSignedMaxValue(BitWidth)); - return ConstantInt::get(Ty, C0->sgt(*C1) ? *C0 : *C1); - case Intrinsic::smin: - if (!C0 && !C1) - return UndefValue::get(Ty); - if (!C0 || !C1) - return ConstantInt::get(Ty, APInt::getSignedMinValue(BitWidth)); - return ConstantInt::get(Ty, C0->slt(*C1) ? *C0 : *C1); - case Intrinsic::umax: - if (!C0 && !C1) - return UndefValue::get(Ty); - if (!C0 || !C1) - return ConstantInt::get(Ty, APInt::getMaxValue(BitWidth)); - return ConstantInt::get(Ty, C0->ugt(*C1) ? *C0 : *C1); - case Intrinsic::umin: if (!C0 && !C1) return UndefValue::get(Ty); if (!C0 || !C1) - return ConstantInt::get(Ty, APInt::getMinValue(BitWidth)); - return ConstantInt::get(Ty, C0->ult(*C1) ? *C0 : *C1); + return MinMaxIntrinsic::getSaturationPoint(IntrinsicID, Ty); + return ConstantInt::get( + Ty, ICmpInst::compare(*C0, *C1, + MinMaxIntrinsic::getPredicate(IntrinsicID)) + ? *C0 + : *C1); case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: @@ -2572,9 +2572,9 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, case Intrinsic::ctlz: assert(C1 && "Must be constant int"); - // cttz(0, 1) and ctlz(0, 1) are undef. + // cttz(0, 1) and ctlz(0, 1) are poison. if (C1->isOne() && (!C0 || C0->isZero())) - return UndefValue::get(Ty); + return PoisonValue::get(Ty); if (!C0) return Constant::getNullValue(Ty); if (IntrinsicID == Intrinsic::cttz) @@ -2583,13 +2583,15 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, return ConstantInt::get(Ty, C0->countLeadingZeros()); case Intrinsic::abs: - // Undef or minimum val operand with poison min --> undef assert(C1 && "Must be constant int"); + assert((C1->isOne() || C1->isZero()) && "Must be 0 or 1"); + + // Undef or minimum val operand with poison min --> undef if (C1->isOne() && (!C0 || C0->isMinSignedValue())) return UndefValue::get(Ty); // Undef operand with no poison min --> 0 (sign bit must be clear) - if (C1->isZero() && !C0) + if (!C0) return Constant::getNullValue(Ty); return ConstantInt::get(Ty, C0->abs()); |
