diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/SCCPSolver.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/SCCPSolver.cpp | 511 |
1 files changed, 348 insertions, 163 deletions
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 8d03a0d8a2c4..de3626a24212 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/InstVisitor.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -41,6 +42,14 @@ static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() { MaxNumRangeExtensions); } +static ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty, + bool UndefAllowed = true) { + assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector"); + if (LV.isConstantRange(UndefAllowed)) + return LV.getConstantRange(); + return ConstantRange::getFull(Ty->getScalarSizeInBits()); +} + namespace llvm { bool SCCPSolver::isConstant(const ValueLatticeElement &LV) { @@ -65,30 +74,9 @@ static bool canRemoveInstruction(Instruction *I) { } bool SCCPSolver::tryToReplaceWithConstant(Value *V) { - Constant *Const = nullptr; - if (V->getType()->isStructTy()) { - std::vector<ValueLatticeElement> IVs = getStructLatticeValueFor(V); - if (llvm::any_of(IVs, isOverdefined)) - return false; - std::vector<Constant *> ConstVals; - auto *ST = cast<StructType>(V->getType()); - for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { - ValueLatticeElement V = IVs[i]; - ConstVals.push_back(SCCPSolver::isConstant(V) - ? getConstant(V) - : UndefValue::get(ST->getElementType(i))); - } - Const = ConstantStruct::get(ST, ConstVals); - } else { - const ValueLatticeElement &IV = getLatticeValueFor(V); - if (isOverdefined(IV)) - return false; - - Const = SCCPSolver::isConstant(IV) ? getConstant(IV) - : UndefValue::get(V->getType()); - } - assert(Const && "Constant is nullptr here!"); - + Constant *Const = getConstantOrNull(V); + if (!Const) + return false; // Replacing `musttail` instructions with constant breaks `musttail` invariant // unless the call itself can be removed. // Calls with "clang.arc.attachedcall" implicitly use the return value and @@ -115,6 +103,47 @@ bool SCCPSolver::tryToReplaceWithConstant(Value *V) { return true; } +/// Try to use \p Inst's value range from \p Solver to infer the NUW flag. +static bool refineInstruction(SCCPSolver &Solver, + const SmallPtrSetImpl<Value *> &InsertedValues, + Instruction &Inst) { + if (!isa<OverflowingBinaryOperator>(Inst)) + return false; + + auto GetRange = [&Solver, &InsertedValues](Value *Op) { + if (auto *Const = dyn_cast<ConstantInt>(Op)) + return ConstantRange(Const->getValue()); + if (isa<Constant>(Op) || InsertedValues.contains(Op)) { + unsigned Bitwidth = Op->getType()->getScalarSizeInBits(); + return ConstantRange::getFull(Bitwidth); + } + return getConstantRange(Solver.getLatticeValueFor(Op), Op->getType(), + /*UndefAllowed=*/false); + }; + auto RangeA = GetRange(Inst.getOperand(0)); + auto RangeB = GetRange(Inst.getOperand(1)); + bool Changed = false; + if (!Inst.hasNoUnsignedWrap()) { + auto NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::BinaryOps(Inst.getOpcode()), RangeB, + OverflowingBinaryOperator::NoUnsignedWrap); + if (NUWRange.contains(RangeA)) { + Inst.setHasNoUnsignedWrap(); + Changed = true; + } + } + if (!Inst.hasNoSignedWrap()) { + auto NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::BinaryOps(Inst.getOpcode()), RangeB, OverflowingBinaryOperator::NoSignedWrap); + if (NSWRange.contains(RangeA)) { + Inst.setHasNoSignedWrap(); + Changed = true; + } + } + + return Changed; +} + /// Try to replace signed instructions with their unsigned equivalent. static bool replaceSignedInst(SCCPSolver &Solver, SmallPtrSetImpl<Value *> &InsertedValues, @@ -195,6 +224,8 @@ bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB, } else if (replaceSignedInst(*this, InsertedValues, Inst)) { MadeChanges = true; ++InstReplacedStat; + } else if (refineInstruction(*this, InsertedValues, Inst)) { + MadeChanges = true; } } return MadeChanges; @@ -322,6 +353,10 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> { MapVector<std::pair<Function *, unsigned>, ValueLatticeElement> TrackedMultipleRetVals; + /// The set of values whose lattice has been invalidated. + /// Populated by resetLatticeValueFor(), cleared after resolving undefs. + DenseSet<Value *> Invalidated; + /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is /// represented here for efficient lookup. SmallPtrSet<Function *, 16> MRVFunctionsTracked; @@ -352,14 +387,15 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> { using Edge = std::pair<BasicBlock *, BasicBlock *>; DenseSet<Edge> KnownFeasibleEdges; - DenseMap<Function *, AnalysisResultsForFn> AnalysisResults; + DenseMap<Function *, std::unique_ptr<PredicateInfo>> FnPredicateInfo; + DenseMap<Value *, SmallPtrSet<User *, 2>> AdditionalUsers; LLVMContext &Ctx; private: - ConstantInt *getConstantInt(const ValueLatticeElement &IV) const { - return dyn_cast_or_null<ConstantInt>(getConstant(IV)); + ConstantInt *getConstantInt(const ValueLatticeElement &IV, Type *Ty) const { + return dyn_cast_or_null<ConstantInt>(getConstant(IV, Ty)); } // pushToWorkList - Helper for markConstant/markOverdefined @@ -447,6 +483,64 @@ private: return LV; } + /// Traverse the use-def chain of \p Call, marking itself and its users as + /// "unknown" on the way. + void invalidate(CallBase *Call) { + SmallVector<Instruction *, 64> ToInvalidate; + ToInvalidate.push_back(Call); + + while (!ToInvalidate.empty()) { + Instruction *Inst = ToInvalidate.pop_back_val(); + + if (!Invalidated.insert(Inst).second) + continue; + + if (!BBExecutable.count(Inst->getParent())) + continue; + + Value *V = nullptr; + // For return instructions we need to invalidate the tracked returns map. + // Anything else has its lattice in the value map. + if (auto *RetInst = dyn_cast<ReturnInst>(Inst)) { + Function *F = RetInst->getParent()->getParent(); + if (auto It = TrackedRetVals.find(F); It != TrackedRetVals.end()) { + It->second = ValueLatticeElement(); + V = F; + } else if (MRVFunctionsTracked.count(F)) { + auto *STy = cast<StructType>(F->getReturnType()); + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) + TrackedMultipleRetVals[{F, I}] = ValueLatticeElement(); + V = F; + } + } else if (auto *STy = dyn_cast<StructType>(Inst->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + if (auto It = StructValueState.find({Inst, I}); + It != StructValueState.end()) { + It->second = ValueLatticeElement(); + V = Inst; + } + } + } else if (auto It = ValueState.find(Inst); It != ValueState.end()) { + It->second = ValueLatticeElement(); + V = Inst; + } + + if (V) { + LLVM_DEBUG(dbgs() << "Invalidated lattice for " << *V << "\n"); + + for (User *U : V->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + ToInvalidate.push_back(UI); + + auto It = AdditionalUsers.find(V); + if (It != AdditionalUsers.end()) + for (User *U : It->second) + if (auto *UI = dyn_cast<Instruction>(U)) + ToInvalidate.push_back(UI); + } + } + } + /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest); @@ -520,6 +614,7 @@ private: void visitCastInst(CastInst &I); void visitSelectInst(SelectInst &I); void visitUnaryOperator(Instruction &I); + void visitFreezeInst(FreezeInst &I); void visitBinaryOperator(Instruction &I); void visitCmpInst(CmpInst &I); void visitExtractValueInst(ExtractValueInst &EVI); @@ -557,8 +652,8 @@ private: void visitInstruction(Instruction &I); public: - void addAnalysis(Function &F, AnalysisResultsForFn A) { - AnalysisResults.insert({&F, std::move(A)}); + void addPredicateInfo(Function &F, DominatorTree &DT, AssumptionCache &AC) { + FnPredicateInfo.insert({&F, std::make_unique<PredicateInfo>(F, DT, AC)}); } void visitCallInst(CallInst &I) { visitCallBase(I); } @@ -566,23 +661,10 @@ public: bool markBlockExecutable(BasicBlock *BB); const PredicateBase *getPredicateInfoFor(Instruction *I) { - auto A = AnalysisResults.find(I->getParent()->getParent()); - if (A == AnalysisResults.end()) + auto It = FnPredicateInfo.find(I->getParent()->getParent()); + if (It == FnPredicateInfo.end()) return nullptr; - return A->second.PredInfo->getPredicateInfoFor(I); - } - - const LoopInfo &getLoopInfo(Function &F) { - auto A = AnalysisResults.find(&F); - assert(A != AnalysisResults.end() && A->second.LI && - "Need LoopInfo analysis results for function."); - return *A->second.LI; - } - - DomTreeUpdater getDTU(Function &F) { - auto A = AnalysisResults.find(&F); - assert(A != AnalysisResults.end() && "Need analysis results for function."); - return {A->second.DT, A->second.PDT, DomTreeUpdater::UpdateStrategy::Lazy}; + return It->second->getPredicateInfoFor(I); } SCCPInstVisitor(const DataLayout &DL, @@ -627,6 +709,8 @@ public: void solve(); + bool resolvedUndef(Instruction &I); + bool resolvedUndefsIn(Function &F); bool isBlockExecutable(BasicBlock *BB) const { @@ -649,6 +733,19 @@ public: void removeLatticeValueFor(Value *V) { ValueState.erase(V); } + /// Invalidate the Lattice Value of \p Call and its users after specializing + /// the call. Then recompute it. + void resetLatticeValueFor(CallBase *Call) { + // Calls to void returning functions do not need invalidation. + Function *F = Call->getCalledFunction(); + (void)F; + assert(!F->getReturnType()->isVoidTy() && + (TrackedRetVals.count(F) || MRVFunctionsTracked.count(F)) && + "All non void specializations should be tracked"); + invalidate(Call); + handleCallResult(*Call); + } + const ValueLatticeElement &getLatticeValueFor(Value *V) const { assert(!V->getType()->isStructTy() && "Should use getStructLatticeValueFor"); @@ -681,15 +778,16 @@ public: bool isStructLatticeConstant(Function *F, StructType *STy); - Constant *getConstant(const ValueLatticeElement &LV) const; - ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty) const; + Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const; + + Constant *getConstantOrNull(Value *V) const; SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions() { return TrackingIncomingArguments; } - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl<ArgInfo> &Args); + void setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl<ArgInfo> &Args); void markFunctionUnreachable(Function *F) { for (auto &BB : *F) @@ -715,6 +813,18 @@ public: ResolvedUndefs |= resolvedUndefsIn(*F); } } + + void solveWhileResolvedUndefs() { + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + solve(); + ResolvedUndefs = false; + for (Value *V : Invalidated) + if (auto *I = dyn_cast<Instruction>(V)) + ResolvedUndefs |= resolvedUndef(*I); + } + Invalidated.clear(); + } }; } // namespace llvm @@ -728,9 +838,13 @@ bool SCCPInstVisitor::markBlockExecutable(BasicBlock *BB) { } void SCCPInstVisitor::pushToWorkList(ValueLatticeElement &IV, Value *V) { - if (IV.isOverdefined()) - return OverdefinedInstWorkList.push_back(V); - InstWorkList.push_back(V); + if (IV.isOverdefined()) { + if (OverdefinedInstWorkList.empty() || OverdefinedInstWorkList.back() != V) + OverdefinedInstWorkList.push_back(V); + return; + } + if (InstWorkList.empty() || InstWorkList.back() != V) + InstWorkList.push_back(V); } void SCCPInstVisitor::pushToWorkListMsg(ValueLatticeElement &IV, Value *V) { @@ -771,57 +885,84 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) { return true; } -Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const { - if (LV.isConstant()) - return LV.getConstant(); +Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV, + Type *Ty) const { + if (LV.isConstant()) { + Constant *C = LV.getConstant(); + assert(C->getType() == Ty && "Type mismatch"); + return C; + } if (LV.isConstantRange()) { const auto &CR = LV.getConstantRange(); if (CR.getSingleElement()) - return ConstantInt::get(Ctx, *CR.getSingleElement()); + return ConstantInt::get(Ty, *CR.getSingleElement()); } return nullptr; } -ConstantRange -SCCPInstVisitor::getConstantRange(const ValueLatticeElement &LV, - Type *Ty) const { - assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector"); - if (LV.isConstantRange()) - return LV.getConstantRange(); - return ConstantRange::getFull(Ty->getScalarSizeInBits()); +Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const { + Constant *Const = nullptr; + if (V->getType()->isStructTy()) { + std::vector<ValueLatticeElement> LVs = getStructLatticeValueFor(V); + if (any_of(LVs, SCCPSolver::isOverdefined)) + return nullptr; + std::vector<Constant *> ConstVals; + auto *ST = cast<StructType>(V->getType()); + for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) { + ValueLatticeElement LV = LVs[I]; + ConstVals.push_back(SCCPSolver::isConstant(LV) + ? getConstant(LV, ST->getElementType(I)) + : UndefValue::get(ST->getElementType(I))); + } + Const = ConstantStruct::get(ST, ConstVals); + } else { + const ValueLatticeElement &LV = getLatticeValueFor(V); + if (SCCPSolver::isOverdefined(LV)) + return nullptr; + Const = SCCPSolver::isConstant(LV) ? getConstant(LV, V->getType()) + : UndefValue::get(V->getType()); + } + assert(Const && "Constant is nullptr here!"); + return Const; } -void SCCPInstVisitor::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl<ArgInfo> &Args) { +void SCCPInstVisitor::setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl<ArgInfo> &Args) { assert(!Args.empty() && "Specialization without arguments"); assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() && "Functions should have the same number of arguments"); auto Iter = Args.begin(); - Argument *NewArg = F->arg_begin(); - Argument *OldArg = Args[0].Formal->getParent()->arg_begin(); + Function::arg_iterator NewArg = F->arg_begin(); + Function::arg_iterator OldArg = Args[0].Formal->getParent()->arg_begin(); for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) { LLVM_DEBUG(dbgs() << "SCCP: Marking argument " << NewArg->getNameOrAsOperand() << "\n"); - if (Iter != Args.end() && OldArg == Iter->Formal) { - // Mark the argument constants in the new function. - markConstant(NewArg, Iter->Actual); + // Mark the argument constants in the new function + // or copy the lattice state over from the old function. + if (Iter != Args.end() && Iter->Formal == &*OldArg) { + if (auto *STy = dyn_cast<StructType>(NewArg->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}]; + NewValue.markConstant(Iter->Actual->getAggregateElement(I)); + } + } else { + ValueState[&*NewArg].markConstant(Iter->Actual); + } ++Iter; - } else if (ValueState.count(OldArg)) { - // For the remaining arguments in the new function, copy the lattice state - // over from the old function. - // - // Note: This previously looked like this: - // ValueState[NewArg] = ValueState[OldArg]; - // This is incorrect because the DenseMap class may resize the underlying - // memory when inserting `NewArg`, which will invalidate the reference to - // `OldArg`. Instead, we make sure `NewArg` exists before setting it. - auto &NewValue = ValueState[NewArg]; - NewValue = ValueState[OldArg]; - pushToWorkList(NewValue, NewArg); + } else { + if (auto *STy = dyn_cast<StructType>(NewArg->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}]; + NewValue = StructValueState[{&*OldArg, I}]; + } + } else { + ValueLatticeElement &NewValue = ValueState[&*NewArg]; + NewValue = ValueState[&*OldArg]; + } } } } @@ -874,7 +1015,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, } ValueLatticeElement BCValue = getValueState(BI->getCondition()); - ConstantInt *CI = getConstantInt(BCValue); + ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType()); if (!CI) { // Overdefined condition variables, and branches on unfoldable constant // conditions, mean the branch could go either way. @@ -900,7 +1041,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, return; } const ValueLatticeElement &SCValue = getValueState(SI->getCondition()); - if (ConstantInt *CI = getConstantInt(SCValue)) { + if (ConstantInt *CI = + getConstantInt(SCValue, SI->getCondition()->getType())) { Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true; return; } @@ -931,7 +1073,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) { // Casts are folded by visitCastInst. ValueLatticeElement IBRValue = getValueState(IBR->getAddress()); - BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue)); + BlockAddress *Addr = dyn_cast_or_null<BlockAddress>( + getConstant(IBRValue, IBR->getAddress()->getType())); if (!Addr) { // Overdefined or unknown condition? // All destinations are executable! if (!IBRValue.isUnknownOrUndef()) @@ -1086,7 +1229,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { if (OpSt.isUnknownOrUndef()) return; - if (Constant *OpC = getConstant(OpSt)) { + if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) { // Fold the constant as we build. Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL); markConstant(&I, C); @@ -1221,7 +1364,8 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) { if (CondValue.isUnknownOrUndef()) return; - if (ConstantInt *CondCB = getConstantInt(CondValue)) { + if (ConstantInt *CondCB = + getConstantInt(CondValue, I.getCondition()->getType())) { Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue(); mergeInValue(&I, getValueState(OpVal)); return; @@ -1254,13 +1398,37 @@ void SCCPInstVisitor::visitUnaryOperator(Instruction &I) { return; if (SCCPSolver::isConstant(V0State)) - if (Constant *C = ConstantFoldUnaryOpOperand(I.getOpcode(), - getConstant(V0State), DL)) + if (Constant *C = ConstantFoldUnaryOpOperand( + I.getOpcode(), getConstant(V0State, I.getType()), DL)) return (void)markConstant(IV, &I, C); markOverdefined(&I); } +void SCCPInstVisitor::visitFreezeInst(FreezeInst &I) { + // If this freeze returns a struct, just mark the result overdefined. + // TODO: We could do a lot better than this. + if (I.getType()->isStructTy()) + return (void)markOverdefined(&I); + + ValueLatticeElement V0State = getValueState(I.getOperand(0)); + ValueLatticeElement &IV = ValueState[&I]; + // resolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (SCCPSolver::isOverdefined(IV)) + return (void)markOverdefined(&I); + + // If something is unknown/undef, wait for it to resolve. + if (V0State.isUnknownOrUndef()) + return; + + if (SCCPSolver::isConstant(V0State) && + isGuaranteedNotToBeUndefOrPoison(getConstant(V0State, I.getType()))) + return (void)markConstant(IV, &I, getConstant(V0State, I.getType())); + + markOverdefined(&I); +} + // Handle Binary Operators. void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { ValueLatticeElement V1State = getValueState(I.getOperand(0)); @@ -1280,10 +1448,12 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { // If either of the operands is a constant, try to fold it to a constant. // TODO: Use information from notconstant better. if ((V1State.isConstant() || V2State.isConstant())) { - Value *V1 = SCCPSolver::isConstant(V1State) ? getConstant(V1State) - : I.getOperand(0); - Value *V2 = SCCPSolver::isConstant(V2State) ? getConstant(V2State) - : I.getOperand(1); + Value *V1 = SCCPSolver::isConstant(V1State) + ? getConstant(V1State, I.getOperand(0)->getType()) + : I.getOperand(0); + Value *V2 = SCCPSolver::isConstant(V2State) + ? getConstant(V2State, I.getOperand(1)->getType()) + : I.getOperand(1); Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL)); auto *C = dyn_cast_or_null<Constant>(R); if (C) { @@ -1361,7 +1531,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&I); - if (Constant *C = getConstant(State)) { + if (Constant *C = getConstant(State, I.getOperand(i)->getType())) { Operands.push_back(C); continue; } @@ -1427,7 +1597,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { ValueLatticeElement &IV = ValueState[&I]; if (SCCPSolver::isConstant(PtrVal)) { - Constant *Ptr = getConstant(PtrVal); + Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType()); // load null is undefined. if (isa<ConstantPointerNull>(Ptr)) { @@ -1490,7 +1660,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&CB); assert(SCCPSolver::isConstant(State) && "Unknown state!"); - Operands.push_back(getConstant(State)); + Operands.push_back(getConstant(State, A->getType())); } if (SCCPSolver::isOverdefined(getValueState(&CB))) @@ -1622,6 +1792,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { SmallVector<ConstantRange, 2> OpRanges; for (Value *Op : II->args()) { const ValueLatticeElement &State = getValueState(Op); + if (State.isUnknownOrUndef()) + return; OpRanges.push_back(getConstantRange(State, Op->getType())); } @@ -1666,6 +1838,7 @@ void SCCPInstVisitor::solve() { // things to overdefined more quickly. while (!OverdefinedInstWorkList.empty()) { Value *I = OverdefinedInstWorkList.pop_back_val(); + Invalidated.erase(I); LLVM_DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n'); @@ -1682,6 +1855,7 @@ void SCCPInstVisitor::solve() { // Process the instruction work list. while (!InstWorkList.empty()) { Value *I = InstWorkList.pop_back_val(); + Invalidated.erase(I); LLVM_DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n'); @@ -1709,6 +1883,61 @@ void SCCPInstVisitor::solve() { } } +bool SCCPInstVisitor::resolvedUndef(Instruction &I) { + // Look for instructions which produce undef values. + if (I.getType()->isVoidTy()) + return false; + + if (auto *STy = dyn_cast<StructType>(I.getType())) { + // Only a few things that can be structs matter for undef. + + // Tracked calls must never be marked overdefined in resolvedUndefsIn. + if (auto *CB = dyn_cast<CallBase>(&I)) + if (Function *F = CB->getCalledFunction()) + if (MRVFunctionsTracked.count(F)) + return false; + + // extractvalue and insertvalue don't need to be marked; they are + // tracked as precisely as their operands. + if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I)) + return false; + // Send the results of everything else to overdefined. We could be + // more precise than this but it isn't worth bothering. + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + ValueLatticeElement &LV = getStructValueState(&I, i); + if (LV.isUnknown()) { + markOverdefined(LV, &I); + return true; + } + } + return false; + } + + ValueLatticeElement &LV = getValueState(&I); + if (!LV.isUnknown()) + return false; + + // There are two reasons a call can have an undef result + // 1. It could be tracked. + // 2. It could be constant-foldable. + // Because of the way we solve return values, tracked calls must + // never be marked overdefined in resolvedUndefsIn. + if (auto *CB = dyn_cast<CallBase>(&I)) + if (Function *F = CB->getCalledFunction()) + if (TrackedRetVals.count(F)) + return false; + + if (isa<LoadInst>(I)) { + // A load here means one of two things: a load of undef from a global, + // a load from an unknown pointer. Either way, having it return undef + // is okay. + return false; + } + + markOverdefined(&I); + return true; +} + /// While solving the dataflow for a function, we don't compute a result for /// operations with an undef operand, to allow undef to be lowered to a /// constant later. For example, constant folding of "zext i8 undef to i16" @@ -1728,60 +1957,8 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) { if (!BBExecutable.count(&BB)) continue; - for (Instruction &I : BB) { - // Look for instructions which produce undef values. - if (I.getType()->isVoidTy()) - continue; - - if (auto *STy = dyn_cast<StructType>(I.getType())) { - // Only a few things that can be structs matter for undef. - - // Tracked calls must never be marked overdefined in resolvedUndefsIn. - if (auto *CB = dyn_cast<CallBase>(&I)) - if (Function *F = CB->getCalledFunction()) - if (MRVFunctionsTracked.count(F)) - continue; - - // extractvalue and insertvalue don't need to be marked; they are - // tracked as precisely as their operands. - if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I)) - continue; - // Send the results of everything else to overdefined. We could be - // more precise than this but it isn't worth bothering. - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - ValueLatticeElement &LV = getStructValueState(&I, i); - if (LV.isUnknown()) { - markOverdefined(LV, &I); - MadeChange = true; - } - } - continue; - } - - ValueLatticeElement &LV = getValueState(&I); - if (!LV.isUnknown()) - continue; - - // There are two reasons a call can have an undef result - // 1. It could be tracked. - // 2. It could be constant-foldable. - // Because of the way we solve return values, tracked calls must - // never be marked overdefined in resolvedUndefsIn. - if (auto *CB = dyn_cast<CallBase>(&I)) - if (Function *F = CB->getCalledFunction()) - if (TrackedRetVals.count(F)) - continue; - - if (isa<LoadInst>(I)) { - // A load here means one of two things: a load of undef from a global, - // a load from an unknown pointer. Either way, having it return undef - // is okay. - continue; - } - - markOverdefined(&I); - MadeChange = true; - } + for (Instruction &I : BB) + MadeChange |= resolvedUndef(I); } LLVM_DEBUG(if (MadeChange) dbgs() @@ -1802,8 +1979,9 @@ SCCPSolver::SCCPSolver( SCCPSolver::~SCCPSolver() = default; -void SCCPSolver::addAnalysis(Function &F, AnalysisResultsForFn A) { - return Visitor->addAnalysis(F, std::move(A)); +void SCCPSolver::addPredicateInfo(Function &F, DominatorTree &DT, + AssumptionCache &AC) { + Visitor->addPredicateInfo(F, DT, AC); } bool SCCPSolver::markBlockExecutable(BasicBlock *BB) { @@ -1814,12 +1992,6 @@ const PredicateBase *SCCPSolver::getPredicateInfoFor(Instruction *I) { return Visitor->getPredicateInfoFor(I); } -const LoopInfo &SCCPSolver::getLoopInfo(Function &F) { - return Visitor->getLoopInfo(F); -} - -DomTreeUpdater SCCPSolver::getDTU(Function &F) { return Visitor->getDTU(F); } - void SCCPSolver::trackValueOfGlobalVariable(GlobalVariable *GV) { Visitor->trackValueOfGlobalVariable(GV); } @@ -1859,6 +2031,10 @@ SCCPSolver::solveWhileResolvedUndefsIn(SmallVectorImpl<Function *> &WorkList) { Visitor->solveWhileResolvedUndefsIn(WorkList); } +void SCCPSolver::solveWhileResolvedUndefs() { + Visitor->solveWhileResolvedUndefs(); +} + bool SCCPSolver::isBlockExecutable(BasicBlock *BB) const { return Visitor->isBlockExecutable(BB); } @@ -1876,6 +2052,10 @@ void SCCPSolver::removeLatticeValueFor(Value *V) { return Visitor->removeLatticeValueFor(V); } +void SCCPSolver::resetLatticeValueFor(CallBase *Call) { + Visitor->resetLatticeValueFor(Call); +} + const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const { return Visitor->getLatticeValueFor(V); } @@ -1900,17 +2080,22 @@ bool SCCPSolver::isStructLatticeConstant(Function *F, StructType *STy) { return Visitor->isStructLatticeConstant(F, STy); } -Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const { - return Visitor->getConstant(LV); +Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV, + Type *Ty) const { + return Visitor->getConstant(LV, Ty); +} + +Constant *SCCPSolver::getConstantOrNull(Value *V) const { + return Visitor->getConstantOrNull(V); } SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() { return Visitor->getArgumentTrackedFunctions(); } -void SCCPSolver::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl<ArgInfo> &Args) { - Visitor->markArgInFuncSpecialization(F, Args); +void SCCPSolver::setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl<ArgInfo> &Args) { + Visitor->setLatticeValueForSpecializationArguments(F, Args); } void SCCPSolver::markFunctionUnreachable(Function *F) { |