diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopFlatten.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LoopFlatten.cpp | 59 |
1 files changed, 43 insertions, 16 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index f36193fc468e..7d9ce8d35e0b 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -75,6 +75,7 @@ #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -99,6 +100,7 @@ static cl::opt<bool> cl::desc("Widen the loop induction variables, if possible, so " "overflow checks won't reject flattening")); +namespace { // We require all uses of both induction variables to match this pattern: // // (OuterPHI * InnerTripCount) + InnerPHI @@ -139,7 +141,7 @@ struct FlattenInfo { PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV - // has been apllied. Used to skip + // has been applied. Used to skip // checks on phi nodes. FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){}; @@ -191,7 +193,7 @@ struct FlattenInfo { bool matchLinearIVUser(User *U, Value *InnerTripCount, SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { - LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); + LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: "; U->dump()); Value *MatchedMul = nullptr; Value *MatchedItCount = nullptr; @@ -211,6 +213,18 @@ struct FlattenInfo { if (!MatchedItCount) return false; + LLVM_DEBUG(dbgs() << "Matched multiplication: "; MatchedMul->dump()); + LLVM_DEBUG(dbgs() << "Matched iteration count: "; MatchedItCount->dump()); + + // The mul should not have any other uses. Widening may leave trivially dead + // uses, which can be ignored. + if (count_if(MatchedMul->users(), [](User *U) { + return !isInstructionTriviallyDead(cast<Instruction>(U)); + }) > 1) { + LLVM_DEBUG(dbgs() << "Multiply has more than one use\n"); + return false; + } + // Look through extends if the IV has been widened. Don't look through // extends if we already looked through a trunc. if (Widened && IsAdd && @@ -222,8 +236,11 @@ struct FlattenInfo { : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0); } + LLVM_DEBUG(dbgs() << "Looking for inner trip count: "; + InnerTripCount->dump()); + if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { - LLVM_DEBUG(dbgs() << "Use is optimisable\n"); + LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n"); ValidOuterPHIUses.insert(MatchedMul); LinearIVUses.insert(U); return true; @@ -240,8 +257,11 @@ struct FlattenInfo { SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0); for (User *U : InnerInductionPHI->users()) { - if (isInnerLoopIncrement(U)) + LLVM_DEBUG(dbgs() << "Checking User: "; U->dump()); + if (isInnerLoopIncrement(U)) { + LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n"); continue; + } // After widening the IVs, a trunc instruction might have been introduced, // so look through truncs. @@ -255,15 +275,21 @@ struct FlattenInfo { // branch) then the compare has been altered by another transformation e.g // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is // a constant. Ignore this use as the compare gets removed later anyway. - if (isInnerLoopTest(U)) + if (isInnerLoopTest(U)) { + LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n"); continue; + } - if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) + if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) { + LLVM_DEBUG(dbgs() << "Not a linear IV user\n"); return false; + } + LLVM_DEBUG(dbgs() << "Linear IV users found!\n"); } return true; } }; +} // namespace static bool setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, @@ -413,7 +439,8 @@ static bool findLoopComponents( // increment variable. Increment = cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch)); - if (Increment->hasNUsesOrMore(3)) { + if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) && + !Increment->hasNUses(1)) { LLVM_DEBUG(dbgs() << "Could not find valid increment\n"); return false; } @@ -540,7 +567,7 @@ checkOuterLoopInsts(FlattenInfo &FI, // they make a net difference of zero. if (IterationInstructions.count(&I)) continue; - // The uncoditional branch to the inner loop's header will turn into + // The unconditional branch to the inner loop's header will turn into // a fall-through, so adds no cost. BranchInst *Br = dyn_cast<BranchInst>(&I); if (Br && Br->isUnconditional() && @@ -552,7 +579,7 @@ checkOuterLoopInsts(FlattenInfo &FI, m_Specific(FI.InnerTripCount)))) continue; InstructionCost Cost = - TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump()); RepeatedInstrCost += Cost; } @@ -759,9 +786,9 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, } // Tell LoopInfo, SCEV and the pass manager that the inner loop has been - // deleted, and any information that have about the outer loop invalidated. + // deleted, and invalidate any outer loop information. SE->forgetLoop(FI.OuterLoop); - SE->forgetLoop(FI.InnerLoop); + SE->forgetBlockAndLoopDispositions(); if (U) U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName()); LI->erase(FI.InnerLoop); @@ -911,7 +938,7 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, bool Changed = false; - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (AR.MSSA) { MSSAU = MemorySSAUpdater(AR.MSSA); if (VerifyMemorySSA) @@ -923,7 +950,7 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, - MSSAU ? MSSAU.getPointer() : nullptr); + MSSAU ? &*MSSAU : nullptr); if (!Changed) return PreservedAnalyses::all(); @@ -981,15 +1008,15 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) { auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (MSSA) MSSAU = MemorySSAUpdater(&MSSA->getMSSA()); bool Changed = false; for (Loop *L : *LI) { auto LN = LoopNest::getLoopNest(*L, *SE); - Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, - MSSAU ? MSSAU.getPointer() : nullptr); + Changed |= + Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, MSSAU ? &*MSSAU : nullptr); } return Changed; } |
