diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp | 121 |
1 files changed, 91 insertions, 30 deletions
diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index bccf94fc217f..5c008585869c 100644 --- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -15,7 +15,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" #define AA_NAME "alignment-from-assumptions" #define DEBUG_TYPE AA_NAME @@ -204,33 +203,103 @@ static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, } bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I, - unsigned Idx, Value *&AAPtr, const SCEV *&AlignSCEV, const SCEV *&OffSCEV) { - Type *Int64Ty = Type::getInt64Ty(I->getContext()); - OperandBundleUse AlignOB = I->getOperandBundleAt(Idx); - if (AlignOB.getTagName() != "align") + // An alignment assume must be a statement about the least-significant + // bits of the pointer being zero, possibly with some offset. + ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0)); + if (!ICI) return false; - assert(AlignOB.Inputs.size() >= 2); - AAPtr = AlignOB.Inputs[0].get(); - // TODO: Consider accumulating the offset to the base. - AAPtr = AAPtr->stripPointerCastsSameRepresentation(); - AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get()); - AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty); - if (AlignOB.Inputs.size() == 3) - OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get()); - else + + // This must be an expression of the form: x & m == 0. + if (ICI->getPredicate() != ICmpInst::ICMP_EQ) + return false; + + // Swap things around so that the RHS is 0. + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS); + const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS); + if (CmpLHSSCEV->isZero()) + std::swap(CmpLHS, CmpRHS); + else if (!CmpRHSSCEV->isZero()) + return false; + + BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS); + if (!CmpBO || CmpBO->getOpcode() != Instruction::And) + return false; + + // Swap things around so that the right operand of the and is a constant + // (the mask); we cannot deal with variable masks. + Value *AndLHS = CmpBO->getOperand(0); + Value *AndRHS = CmpBO->getOperand(1); + const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS); + const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS); + if (isa<SCEVConstant>(AndLHSSCEV)) { + std::swap(AndLHS, AndRHS); + std::swap(AndLHSSCEV, AndRHSSCEV); + } + + const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV); + if (!MaskSCEV) + return false; + + // The mask must have some trailing ones (otherwise the condition is + // trivial and tells us nothing about the alignment of the left operand). + unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes(); + if (!TrailingOnes) + return false; + + // Cap the alignment at the maximum with which LLVM can deal (and make sure + // we don't overflow the shift). + uint64_t Alignment; + TrailingOnes = std::min(TrailingOnes, + unsigned(sizeof(unsigned) * CHAR_BIT - 1)); + Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment); + + Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext()); + AlignSCEV = SE->getConstant(Int64Ty, Alignment); + + // The LHS might be a ptrtoint instruction, or it might be the pointer + // with an offset. + AAPtr = nullptr; + OffSCEV = nullptr; + if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) { + AAPtr = PToI->getPointerOperand(); OffSCEV = SE->getZero(Int64Ty); - OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty); + } else if (const SCEVAddExpr* AndLHSAddSCEV = + dyn_cast<SCEVAddExpr>(AndLHSSCEV)) { + // Try to find the ptrtoint; subtract it and the rest is the offset. + for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(), + JE = AndLHSAddSCEV->op_end(); J != JE; ++J) + if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J)) + if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) { + AAPtr = PToI->getPointerOperand(); + OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J); + break; + } + } + + if (!AAPtr) + return false; + + // Sign extend the offset to 64 bits (so that it is like all of the other + // expressions). + unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits(); + if (OffSCEVBits < 64) + OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty); + else if (OffSCEVBits > 64) + return false; + + AAPtr = AAPtr->stripPointerCasts(); return true; } -bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, - unsigned Idx) { +bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { Value *AAPtr; const SCEV *AlignSCEV, *OffSCEV; - if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV)) + if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV)) return false; // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't @@ -248,14 +317,13 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, continue; if (Instruction *K = dyn_cast<Instruction>(J)) + if (isValidAssumeForContext(ACall, K, DT)) WorkList.push_back(K); } while (!WorkList.empty()) { Instruction *J = WorkList.pop_back_val(); if (LoadInst *LI = dyn_cast<LoadInst>(J)) { - if (!isValidAssumeForContext(ACall, J, DT)) - continue; Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, LI->getPointerOperand(), SE); if (NewAlignment > LI->getAlign()) { @@ -263,8 +331,6 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, ++NumLoadAlignChanged; } } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) { - if (!isValidAssumeForContext(ACall, J, DT)) - continue; Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, SI->getPointerOperand(), SE); if (NewAlignment > SI->getAlign()) { @@ -272,8 +338,6 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, ++NumStoreAlignChanged; } } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) { - if (!isValidAssumeForContext(ACall, J, DT)) - continue; Align NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE); @@ -305,7 +369,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, Visited.insert(J); for (User *UJ : J->users()) { Instruction *K = cast<Instruction>(UJ); - if (!Visited.count(K)) + if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT)) WorkList.push_back(K); } } @@ -332,11 +396,8 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, bool Changed = false; for (auto &AssumeVH : AC.assumptions()) - if (AssumeVH) { - CallInst *Call = cast<CallInst>(AssumeVH); - for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++) - Changed |= processAssumption(Call, Idx); - } + if (AssumeVH) + Changed |= processAssumption(cast<CallInst>(AssumeVH)); return Changed; } |
