diff options
Diffstat (limited to 'llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp')
-rw-r--r-- | llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp | 1116 |
1 files changed, 844 insertions, 272 deletions
diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp index 6717d4706aefe..be75d6bef08c4 100644 --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -35,6 +35,20 @@ /// are defined to be as large as this maximum sequence of replacement /// instructions. /// +/// A note on VPR.P0 (the lane mask): +/// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a +/// "VPT Active" context (which includes low-overhead loops and vpt blocks). +/// They will simply "and" the result of their calculation with the current +/// value of VPR.P0. You can think of it like this: +/// \verbatim +/// if VPT active: ; Between a DLSTP/LETP, or for predicated instrs +/// VPR.P0 &= Value +/// else +/// VPR.P0 = Value +/// \endverbatim +/// When we're inside the low-overhead loop (between DLSTP and LETP), we always +/// fall in the "VPT active" case, so we can consider that all VPR writes by +/// one of those instruction is actually a "and". 
//===----------------------------------------------------------------------===// #include "ARM.h" @@ -45,6 +59,7 @@ #include "Thumb2InstrInfo.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/CodeGen/LivePhysRegs.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineLoopInfo.h" #include "llvm/CodeGen/MachineLoopUtils.h" @@ -60,34 +75,93 @@ using namespace llvm; namespace { + using InstSet = SmallPtrSetImpl<MachineInstr *>; + + class PostOrderLoopTraversal { + MachineLoop &ML; + MachineLoopInfo &MLI; + SmallPtrSet<MachineBasicBlock*, 4> Visited; + SmallVector<MachineBasicBlock*, 4> Order; + + public: + PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI) + : ML(ML), MLI(MLI) { } + + const SmallVectorImpl<MachineBasicBlock*> &getOrder() const { + return Order; + } + + // Visit all the blocks within the loop, as well as exit blocks and any + // blocks properly dominating the header. + void ProcessLoop() { + std::function<void(MachineBasicBlock*)> Search = [this, &Search] + (MachineBasicBlock *MBB) -> void { + if (Visited.count(MBB)) + return; + + Visited.insert(MBB); + for (auto *Succ : MBB->successors()) { + if (!ML.contains(Succ)) + continue; + Search(Succ); + } + Order.push_back(MBB); + }; + + // Insert exit blocks. + SmallVector<MachineBasicBlock*, 2> ExitBlocks; + ML.getExitBlocks(ExitBlocks); + for (auto *MBB : ExitBlocks) + Order.push_back(MBB); + + // Then add the loop body. + Search(ML.getHeader()); + + // Then try the preheader and its predecessors. 
+ std::function<void(MachineBasicBlock*)> GetPredecessor = + [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void { + Order.push_back(MBB); + if (MBB->pred_size() == 1) + GetPredecessor(*MBB->pred_begin()); + }; + + if (auto *Preheader = ML.getLoopPreheader()) + GetPredecessor(Preheader); + else if (auto *Preheader = MLI.findLoopPreheader(&ML, true)) + GetPredecessor(Preheader); + } + }; + struct PredicatedMI { MachineInstr *MI = nullptr; SetVector<MachineInstr*> Predicates; public: - PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) : - MI(I) { + PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) { + assert(I && "Instruction must not be null!"); Predicates.insert(Preds.begin(), Preds.end()); } }; - // Represent a VPT block, a list of instructions that begins with a VPST and - // has a maximum of four proceeding instructions. All instructions within the - // block are predicated upon the vpr and we allow instructions to define the - // vpr within in the block too. + // Represent a VPT block, a list of instructions that begins with a VPT/VPST + // and has a maximum of four proceeding instructions. All instructions within + // the block are predicated upon the vpr and we allow instructions to define + // the vpr within in the block too. class VPTBlock { - std::unique_ptr<PredicatedMI> VPST; + // The predicate then instruction, which is either a VPT, or a VPST + // instruction. 
+ std::unique_ptr<PredicatedMI> PredicateThen; PredicatedMI *Divergent = nullptr; SmallVector<PredicatedMI, 4> Insts; public: VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { - VPST = std::make_unique<PredicatedMI>(MI, Preds); + PredicateThen = std::make_unique<PredicatedMI>(MI, Preds); } void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI); - if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) { + if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) { Divergent = &Insts.back(); LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI); } @@ -104,38 +178,73 @@ namespace { // Is the given instruction part of the predicate set controlling the entry // to the block. bool IsPredicatedOn(MachineInstr *MI) const { - return VPST->Predicates.count(MI); + return PredicateThen->Predicates.count(MI); + } + + // Returns true if this is a VPT instruction. + bool isVPT() const { return !isVPST(); } + + // Returns true if this is a VPST instruction. + bool isVPST() const { + return PredicateThen->MI->getOpcode() == ARM::MVE_VPST; } // Is the given instruction the only predicate which controls the entry to // the block. 
bool IsOnlyPredicatedOn(MachineInstr *MI) const { - return IsPredicatedOn(MI) && VPST->Predicates.size() == 1; + return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1; } unsigned size() const { return Insts.size(); } SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; } - MachineInstr *getVPST() const { return VPST->MI; } + MachineInstr *getPredicateThen() const { return PredicateThen->MI; } PredicatedMI *getDivergent() const { return Divergent; } }; + struct Reduction { + MachineInstr *Init; + MachineInstr &Copy; + MachineInstr &Reduce; + MachineInstr &VPSEL; + + Reduction(MachineInstr *Init, MachineInstr *Mov, MachineInstr *Add, + MachineInstr *Sel) + : Init(Init), Copy(*Mov), Reduce(*Add), VPSEL(*Sel) { } + }; + struct LowOverheadLoop { - MachineLoop *ML = nullptr; + MachineLoop &ML; + MachineBasicBlock *Preheader = nullptr; + MachineLoopInfo &MLI; + ReachingDefAnalysis &RDA; + const TargetRegisterInfo &TRI; + const ARMBaseInstrInfo &TII; MachineFunction *MF = nullptr; MachineInstr *InsertPt = nullptr; MachineInstr *Start = nullptr; MachineInstr *Dec = nullptr; MachineInstr *End = nullptr; MachineInstr *VCTP = nullptr; + SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs; VPTBlock *CurrentBlock = nullptr; SetVector<MachineInstr*> CurrentPredicate; SmallVector<VPTBlock, 4> VPTBlocks; + SmallPtrSet<MachineInstr*, 4> ToRemove; + SmallVector<std::unique_ptr<Reduction>, 1> Reductions; + SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute; bool Revert = false; bool CannotTailPredicate = false; - LowOverheadLoop(MachineLoop *ML) : ML(ML) { - MF = ML->getHeader()->getParent(); + LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI, + ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI, + const ARMBaseInstrInfo &TII) + : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII) { + MF = ML.getHeader()->getParent(); + if (auto *MBB = ML.getLoopPreheader()) + Preheader = MBB; + else if (auto *MBB = MLI.findLoopPreheader(&ML, true)) + Preheader = MBB; } // If this is 
an MVE instruction, check that we know how to use tail @@ -151,22 +260,30 @@ namespace { // For now, let's keep things really simple and only support a single // block for tail predication. return !Revert && FoundAllComponents() && VCTP && - !CannotTailPredicate && ML->getNumBlocks() == 1; + !CannotTailPredicate && ML.getNumBlocks() == 1; } - bool ValidateTailPredicate(MachineInstr *StartInsertPt, - ReachingDefAnalysis *RDA, - MachineLoopInfo *MLI); + // Check that the predication in the loop will be equivalent once we + // perform the conversion. Also ensure that we can provide the number + // of elements to the loop start instruction. + bool ValidateTailPredicate(MachineInstr *StartInsertPt); + + // See whether the live-out instructions are a reduction that we can fixup + // later. + bool FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers); + + // Check that any values available outside of the loop will be the same + // after tail predication conversion. + bool ValidateLiveOuts(); // Is it safe to define LR with DLS/WLS? // LR can be defined if it is the operand to start, because it's the same // value, or if it's going to be equivalent to the operand to Start. - MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA); + MachineInstr *isSafeToDefineLR(); // Check the branch targets are within range and we satisfy our // restrictions. 
- void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA, - MachineLoopInfo *MLI); + void CheckLegality(ARMBasicBlockUtils *BBUtils); bool FoundAllComponents() const { return Start && Dec && End; @@ -241,18 +358,19 @@ namespace { void RevertWhile(MachineInstr *MI) const; - bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const; + bool RevertLoopDec(MachineInstr *MI) const; void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const; - void RemoveLoopUpdate(LowOverheadLoop &LoLoop); - void ConvertVPTBlocks(LowOverheadLoop &LoLoop); + void FixupReductions(LowOverheadLoop &LoLoop) const; + MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop); void Expand(LowOverheadLoop &LoLoop); + void IterationCountDCE(LowOverheadLoop &LoLoop); }; } @@ -261,7 +379,7 @@ char ARMLowOverheadLoops::ID = 0; INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME, false, false) -MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) { +MachineInstr *LowOverheadLoop::isSafeToDefineLR() { // We can define LR because LR already contains the same value. if (Start->getOperand(0).getReg() == ARM::LR) return Start; @@ -279,52 +397,22 @@ MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) { // Find an insertion point: // - Is there a (mov lr, Count) before Start? If so, and nothing else writes // to Count before Start, we can insert at that mov. - if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR)) - if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg)) + if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR)) + if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg)) return LRDef; // - Is there a (mov lr, Count) after Start? If so, and nothing else writes // to Count after Start, we can insert at that mov. 
- if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR)) - if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg)) + if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR)) + if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg)) return LRDef; // We've found no suitable LR def and Start doesn't use LR directly. Can we // just define LR anyway? - if (!RDA->isRegUsedAfter(Start, ARM::LR)) - return Start; - - return nullptr; -} - -// Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must -// not define a register that is used by any instructions, after and including, -// 'To'. These instructions also must not redefine any of Froms operands. -template<typename Iterator> -static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) { - SmallSet<int, 2> Defs; - // First check that From would compute the same value if moved. - for (auto &MO : From->operands()) { - if (!MO.isReg() || MO.isUndef() || !MO.getReg()) - continue; - if (MO.isDef()) - Defs.insert(MO.getReg()); - else if (!RDA->hasSameReachingDef(From, To, MO.getReg())) - return false; - } - - // Now walk checking that the rest of the instructions will compute the same - // value. - for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) { - for (auto &MO : I->operands()) - if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg())) - return false; - } - return true; + return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr; } -bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, - ReachingDefAnalysis *RDA, MachineLoopInfo *MLI) { +bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) { assert(VCTP && "VCTP instruction expected but is not set"); // All predication within the loop should be based on vctp. 
If the block // isn't predicated on entry, check whether the vctp is within the block @@ -332,24 +420,35 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, for (auto &Block : VPTBlocks) { if (Block.IsPredicatedOn(VCTP)) continue; - if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) { + if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) { LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: " - << *Block.getDivergent()->MI); + << *Block.getDivergent()->MI); return false; } SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); for (auto &PredMI : Insts) { - if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI)) + // Check the instructions in the block and only allow: + // - VCTPs + // - Instructions predicated on the main VCTP + // - Any VCMP + // - VCMPs just "and" their result with VPR.P0. Whether they are + // located before/after the VCTP is irrelevant - the end result will + // be the same in both cases, so there's no point in requiring them + // to be located after the VCTP! + if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) || + VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0) continue; LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI - << " - which is predicated on:\n"; - for (auto *MI : PredMI.Predicates) - dbgs() << " - " << *MI; - ); + << " - which is predicated on:\n"; + for (auto *MI : PredMI.Predicates) + dbgs() << " - " << *MI); return false; } } + if (!ValidateLiveOuts()) + return false; + // For tail predication, we need to provide the number of elements, instead // of the iteration count, to the loop start instruction. The number of // elements is provided to the vctp instruction, so we need to check that @@ -359,7 +458,7 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, // If the register is defined within loop, then we can't perform TP. 
// TODO: Check whether this is just a mov of a register that would be // available. - if (RDA->getReachingDef(VCTP, NumElements) >= 0) { + if (RDA.hasLocalDefBefore(VCTP, NumElements)) { LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n"); return false; } @@ -367,17 +466,20 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, // The element count register maybe defined after InsertPt, in which case we // need to try to move either InsertPt or the def so that the [w|d]lstp can // use the value. - MachineBasicBlock *InsertBB = InsertPt->getParent(); - if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) { - if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) { - if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) { + // TODO: On failing to move an instruction, check if the count is provided by + // a mov and whether we can use the mov operand directly. + MachineBasicBlock *InsertBB = StartInsertPt->getParent(); + if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) { + if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) { + if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) { ElemDef->removeFromParent(); - InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef); + InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef); LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: " << *ElemDef); - } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) { - InsertPt->removeFromParent(); - InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt); + } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) { + StartInsertPt->removeFromParent(); + InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), + StartInsertPt); LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef); } else { LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop " @@ -390,10 +492,10 @@ bool 
LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, // Especially in the case of while loops, InsertBB may not be the // preheader, so we need to check that the register isn't redefined // before entering the loop. - auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB, + auto CannotProvideElements = [this](MachineBasicBlock *MBB, Register NumElements) { // NumElements is redefined in this block. - if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0) + if (RDA.hasLocalDefBefore(&MBB->back(), NumElements)) return true; // Don't continue searching up through multiple predecessors. @@ -404,7 +506,7 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, }; // First, find the block that looks like the preheader. - MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true); + MachineBasicBlock *MBB = Preheader; if (!MBB) { LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n"); return false; @@ -419,13 +521,372 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, MBB = *MBB->pred_begin(); } - LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n"); + // Check that the value change of the element count is what we expect and + // that the predication will be equivalent. For this we need: + // NumElements = NumElements - VectorWidth. The sub will be a sub immediate + // and we can also allow register copies within the chain too. 
+ auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) { + return -getAddSubImmediate(*MI) == ExpectedVecWidth; + }; + + MBB = VCTP->getParent(); + if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) { + SmallPtrSet<MachineInstr*, 2> ElementChain; + SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP }; + unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode()); + + Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end()); + + if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) { + bool FoundSub = false; + + for (auto *MI : ElementChain) { + if (isMovRegOpcode(MI->getOpcode())) + continue; + + if (isSubImmOpcode(MI->getOpcode())) { + if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth)) + return false; + FoundSub = true; + } else + return false; + } + + LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n"; + for (auto *MI : ElementChain) + dbgs() << " - " << *MI); + ToRemove.insert(ElementChain.begin(), ElementChain.end()); + } + } + return true; +} + +static bool isVectorPredicated(MachineInstr *MI) { + int PIdx = llvm::findFirstVPTPredOperandIdx(*MI); + return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR; +} + +static bool isRegInClass(const MachineOperand &MO, + const TargetRegisterClass *Class) { + return MO.isReg() && MO.getReg() && Class->contains(MO.getReg()); +} + +// MVE 'narrowing' operate on half a lane, reading from half and writing +// to half, which are referred to has the top and bottom half. The other +// half retains its previous value. +static bool retainsPreviousHalfElement(const MachineInstr &MI) { + const MCInstrDesc &MCID = MI.getDesc(); + uint64_t Flags = MCID.TSFlags; + return (Flags & ARMII::RetainsPreviousHalfElement) != 0; +} + +// Some MVE instructions read from the top/bottom halves of their operand(s) +// and generate a vector result with result elements that are double the +// width of the input. 
+static bool producesDoubleWidthResult(const MachineInstr &MI) { + const MCInstrDesc &MCID = MI.getDesc(); + uint64_t Flags = MCID.TSFlags; + return (Flags & ARMII::DoubleWidthResult) != 0; +} + +static bool isHorizontalReduction(const MachineInstr &MI) { + const MCInstrDesc &MCID = MI.getDesc(); + uint64_t Flags = MCID.TSFlags; + return (Flags & ARMII::HorizontalReduction) != 0; +} + +// Can this instruction generate a non-zero result when given only zeroed +// operands? This allows us to know that, given operands with false bytes +// zeroed by masked loads, that the result will also contain zeros in those +// bytes. +static bool canGenerateNonZeros(const MachineInstr &MI) { + + // Check for instructions which can write into a larger element size, + // possibly writing into a previous zero'd lane. + if (producesDoubleWidthResult(MI)) + return true; + + switch (MI.getOpcode()) { + default: + break; + // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow + // fp16 -> fp32 vector conversions. + // Instructions that perform a NOT will generate 1s from 0s. + case ARM::MVE_VMVN: + case ARM::MVE_VORN: + // Count leading zeros will do just that! + case ARM::MVE_VCLZs8: + case ARM::MVE_VCLZs16: + case ARM::MVE_VCLZs32: + return true; + } + return false; +} + + +// Look at its register uses to see if it only can only receive zeros +// into its false lanes which would then produce zeros. Also check that +// the output register is also defined by an FalseLanesZero instruction +// so that if tail-predication happens, the lanes that aren't updated will +// still be zeros. 
+static bool producesFalseLanesZero(MachineInstr &MI, + const TargetRegisterClass *QPRs, + const ReachingDefAnalysis &RDA, + InstSet &FalseLanesZero) { + if (canGenerateNonZeros(MI)) + return false; + + bool AllowScalars = isHorizontalReduction(MI); + for (auto &MO : MI.operands()) { + if (!MO.isReg() || !MO.getReg()) + continue; + if (!isRegInClass(MO, QPRs) && AllowScalars) + continue; + if (auto *OpDef = RDA.getMIOperand(&MI, MO)) + if (FalseLanesZero.count(OpDef)) + continue; + return false; + } + LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI); + return true; +} + +bool +LowOverheadLoop::FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers) { + // Also check for reductions where the operation needs to be merging values + // from the last and previous loop iterations. This means an instruction + // producing a value and a vmov storing the value calculated in the previous + // iteration. So we can have two live-out regs, one produced by a vmov and + // both being consumed by a vpsel. + LLVM_DEBUG(dbgs() << "ARM Loops: Looking for reduction live-outs:\n"; + for (auto *MI : LiveMIs) + dbgs() << " - " << *MI); + + if (!Preheader) + return false; + + // Expect a vmov, a vadd and a single vpsel user. + // TODO: This means we can't currently support multiple reductions in the + // loop. + if (LiveMIs.size() != 2 || LiveOutUsers.size() != 1) + return false; + + MachineInstr *VPSEL = *LiveOutUsers.begin(); + if (VPSEL->getOpcode() != ARM::MVE_VPSEL) + return false; + + unsigned VPRIdx = llvm::findFirstVPTPredOperandIdx(*VPSEL) + 1; + MachineInstr *Pred = RDA.getMIOperand(VPSEL, VPRIdx); + if (!Pred || Pred != VCTP) { + LLVM_DEBUG(dbgs() << "ARM Loops: Not using equivalent predicate.\n"); + return false; + } + + MachineInstr *Reduce = RDA.getMIOperand(VPSEL, 1); + if (!Reduce) + return false; + + assert(LiveMIs.count(Reduce) && "Expected MI to be live-out"); + + // TODO: Support more operations than VADD. 
+ switch (VCTP->getOpcode()) { + default: + return false; + case ARM::MVE_VCTP8: + if (Reduce->getOpcode() != ARM::MVE_VADDi8) + return false; + break; + case ARM::MVE_VCTP16: + if (Reduce->getOpcode() != ARM::MVE_VADDi16) + return false; + break; + case ARM::MVE_VCTP32: + if (Reduce->getOpcode() != ARM::MVE_VADDi32) + return false; + break; + } + + // Test that the reduce op is overwriting ones of its operands. + if (Reduce->getOperand(0).getReg() != Reduce->getOperand(1).getReg() && + Reduce->getOperand(0).getReg() != Reduce->getOperand(2).getReg()) { + LLVM_DEBUG(dbgs() << "ARM Loops: Reducing op isn't overwriting itself.\n"); + return false; + } + + // Check that the VORR is actually a VMOV. + MachineInstr *Copy = RDA.getMIOperand(VPSEL, 2); + if (!Copy || Copy->getOpcode() != ARM::MVE_VORR || + !Copy->getOperand(1).isReg() || !Copy->getOperand(2).isReg() || + Copy->getOperand(1).getReg() != Copy->getOperand(2).getReg()) + return false; + + assert(LiveMIs.count(Copy) && "Expected MI to be live-out"); + + // Check that the vadd and vmov are only used by each other and the vpsel. + SmallPtrSet<MachineInstr*, 2> CopyUsers; + RDA.getGlobalUses(Copy, Copy->getOperand(0).getReg(), CopyUsers); + if (CopyUsers.size() > 2 || !CopyUsers.count(Reduce)) { + LLVM_DEBUG(dbgs() << "ARM Loops: Copy users unsupported.\n"); + return false; + } + + SmallPtrSet<MachineInstr*, 2> ReduceUsers; + RDA.getGlobalUses(Reduce, Reduce->getOperand(0).getReg(), ReduceUsers); + if (ReduceUsers.size() > 2 || !ReduceUsers.count(Copy)) { + LLVM_DEBUG(dbgs() << "ARM Loops: Reduce users unsupported.\n"); + return false; + } + + // Then find whether there's an instruction initialising the register that + // is storing the reduction. + SmallPtrSet<MachineInstr*, 2> Incoming; + RDA.getLiveOuts(Preheader, Copy->getOperand(1).getReg(), Incoming); + if (Incoming.size() > 1) + return false; + + MachineInstr *Init = Incoming.empty() ? 
nullptr : *Incoming.begin(); + LLVM_DEBUG(dbgs() << "ARM Loops: Found a reduction:\n" + << " - " << *Copy + << " - " << *Reduce + << " - " << *VPSEL); + Reductions.push_back(std::make_unique<Reduction>(Init, Copy, Reduce, VPSEL)); return true; } -void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, - ReachingDefAnalysis *RDA, - MachineLoopInfo *MLI) { +bool LowOverheadLoop::ValidateLiveOuts() { + // We want to find out if the tail-predicated version of this loop will + // produce the same values as the loop in its original form. For this to + // be true, the newly inserted implicit predication must not change the + // the (observable) results. + // We're doing this because many instructions in the loop will not be + // predicated and so the conversion from VPT predication to tail-predication + // can result in different values being produced; due to the tail-predication + // preventing many instructions from updating their falsely predicated + // lanes. This analysis assumes that all the instructions perform lane-wise + // operations and don't perform any exchanges. + // A masked load, whether through VPT or tail predication, will write zeros + // to any of the falsely predicated bytes. So, from the loads, we know that + // the false lanes are zeroed and here we're trying to track that those false + // lanes remain zero, or where they change, the differences are masked away + // by their user(s). + // All MVE loads and stores have to be predicated, so we know that any load + // operands, or stored results are equivalent already. Other explicitly + // predicated instructions will perform the same operation in the original + // loop and the tail-predicated form too. Because of this, we can insert + // loads, stores and other predicated instructions into our Predicated + // set and build from there. 
+ const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID); + SetVector<MachineInstr *> FalseLanesUnknown; + SmallPtrSet<MachineInstr *, 4> FalseLanesZero; + SmallPtrSet<MachineInstr *, 4> Predicated; + MachineBasicBlock *Header = ML.getHeader(); + + for (auto &MI : *Header) { + const MCInstrDesc &MCID = MI.getDesc(); + uint64_t Flags = MCID.TSFlags; + if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) + continue; + + if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode())) + continue; + + // Predicated loads will write zeros to the falsely predicated bytes of the + // destination register. + if (isVectorPredicated(&MI)) { + if (MI.mayLoad()) + FalseLanesZero.insert(&MI); + Predicated.insert(&MI); + continue; + } + + if (MI.getNumDefs() == 0) + continue; + + if (!producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) { + // We require retaining and horizontal operations to operate upon zero'd + // false lanes to ensure the conversion doesn't change the output. + if (retainsPreviousHalfElement(MI) || isHorizontalReduction(MI)) + return false; + // Otherwise we need to evaluate this instruction later to see whether + // unknown false lanes will get masked away by their user(s). + FalseLanesUnknown.insert(&MI); + } else if (!isHorizontalReduction(MI)) + FalseLanesZero.insert(&MI); + } + + auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO, + SmallPtrSetImpl<MachineInstr *> &Predicated) { + SmallPtrSet<MachineInstr *, 2> Uses; + RDA.getGlobalUses(MI, MO.getReg(), Uses); + for (auto *Use : Uses) { + if (Use != MI && !Predicated.count(Use)) + return false; + } + return true; + }; + + // Visit the unknowns in reverse so that we can start at the values being + // stored and then we can work towards the leaves, hopefully adding more + // instructions to Predicated. Successfully terminating the loop means that + // all the unknown values have to found to be masked by predicated user(s). 
+ // For any unpredicated values, we store them in NonPredicated so that we + // can later check whether these form a reduction. + SmallPtrSet<MachineInstr*, 2> NonPredicated; + for (auto *MI : reverse(FalseLanesUnknown)) { + for (auto &MO : MI->operands()) { + if (!isRegInClass(MO, QPRs) || !MO.isDef()) + continue; + if (!HasPredicatedUsers(MI, MO, Predicated)) { + LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : " + << TRI.getRegAsmName(MO.getReg()) << " at " << *MI); + NonPredicated.insert(MI); + continue; + } + } + // Any unknown false lanes have been masked away by the user(s). + Predicated.insert(MI); + } + + SmallPtrSet<MachineInstr *, 2> LiveOutMIs; + SmallPtrSet<MachineInstr*, 2> LiveOutUsers; + SmallVector<MachineBasicBlock *, 2> ExitBlocks; + ML.getExitBlocks(ExitBlocks); + assert(ML.getNumBlocks() == 1 && "Expected single block loop!"); + assert(ExitBlocks.size() == 1 && "Expected a single exit block"); + MachineBasicBlock *ExitBB = ExitBlocks.front(); + for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) { + // Check Q-regs that are live in the exit blocks. We don't collect scalars + // because they won't be affected by lane predication. + if (QPRs->contains(RegMask.PhysReg)) { + if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg)) + LiveOutMIs.insert(MI); + RDA.getLiveInUses(ExitBB, RegMask.PhysReg, LiveOutUsers); + } + } + + // If we have any non-predicated live-outs, they need to be part of a + // reduction that we can fixup later. The reduction that the form of an + // operation that uses its previous values through a vmov and then a vpsel + // resides in the exit blocks to select the final bytes from n and n-1 + // iterations. 
+ if (!NonPredicated.empty() && + !FindValidReduction(NonPredicated, LiveOutUsers)) + return false; + + // We've already validated that any VPT predication within the loop will be + // equivalent when we perform the predication transformation; so we know that + // any VPT predicated instruction is predicated upon VCTP. Any live-out + // instruction needs to be predicated, so check this here. The instructions + // in NonPredicated have been found to be a reduction that we can ensure its + // legality. + for (auto *MI : LiveOutMIs) + if (!isVectorPredicated(MI) && !NonPredicated.count(MI)) + return false; + + return true; +} + +void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) { if (Revert) return; @@ -434,7 +895,7 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, // TODO Maybe there's cases where the target doesn't have to be the header, // but for now be safe and revert. - if (End->getOperand(1).getMBB() != ML->getHeader()) { + if (End->getOperand(1).getMBB() != ML.getHeader()) { LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n"); Revert = true; return; @@ -442,8 +903,8 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, // The WLS and LE instructions have 12-bits for the label offset. WLS // requires a positive offset, while LE uses negative. - if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) || - !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) { + if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) || + !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) { LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n"); Revert = true; return; @@ -458,7 +919,7 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, return; } - InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA); + InsertPt = Revert ? 
nullptr : isSafeToDefineLR(); if (!InsertPt) { LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n"); Revert = true; @@ -473,9 +934,9 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, return; } - assert(ML->getBlocks().size() == 1 && + assert(ML.getBlocks().size() == 1 && "Shouldn't be processing a loop with more than one block"); - CannotTailPredicate = !ValidateTailPredicate(InsertPt, RDA, MLI); + CannotTailPredicate = !ValidateTailPredicate(InsertPt); LLVM_DEBUG(if (CannotTailPredicate) dbgs() << "ARM Loops: Couldn't validate tail predicate.\n"); } @@ -484,29 +945,44 @@ bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) { if (CannotTailPredicate) return false; - // Only support a single vctp. - if (isVCTP(MI) && VCTP) - return false; + if (isVCTP(MI)) { + // If we find another VCTP, check whether it uses the same value as the main VCTP. + // If it does, store it in the SecondaryVCTPs set, else refuse it. + if (VCTP) { + if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) || + !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) { + LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching " + "definition from the main VCTP"); + return false; + } + LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI); + SecondaryVCTPs.insert(MI); + } else { + LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI); + VCTP = MI; + } + } else if (isVPTOpcode(MI->getOpcode())) { + if (MI->getOpcode() != ARM::MVE_VPST) { + assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 && + "VPT does not implicitly define VPR?!"); + CurrentPredicate.insert(MI); + } - // Start a new vpt block when we discover a vpt. 
- if (MI->getOpcode() == ARM::MVE_VPST) { VPTBlocks.emplace_back(MI, CurrentPredicate); CurrentBlock = &VPTBlocks.back(); return true; - } else if (isVCTP(MI)) - VCTP = MI; - else if (MI->getOpcode() == ARM::MVE_VPSEL || - MI->getOpcode() == ARM::MVE_VPNOT) + } else if (MI->getOpcode() == ARM::MVE_VPSEL || + MI->getOpcode() == ARM::MVE_VPNOT) { + // TODO: Allow VPSEL and VPNOT, we currently cannot because: + // 1) It will use the VPR as a predicate operand, but doesn't have to be + // instead a VPT block, which means we can assert while building up + // the VPT block because we don't find another VPT or VPST to being a new + // one. + // 2) VPSEL still requires a VPR operand even after tail predicating, + // which means we can't remove it unless there is another + // instruction, such as vcmp, that can provide the VPR def. return false; - - // TODO: Allow VPSEL and VPNOT, we currently cannot because: - // 1) It will use the VPR as a predicate operand, but doesn't have to be - // instead a VPT block, which means we can assert while building up - // the VPT block because we don't find another VPST to being a new - // one. - // 2) VPSEL still requires a VPR operand even after tail predicating, - // which means we can't remove it unless there is another - // instruction, such as vcmp, that can provide the VPR def. + } bool IsUse = false; bool IsDef = false; @@ -548,7 +1024,9 @@ bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) { return false; } - return true; + // If the instruction is already explicitly predicated, then the conversion + // will be fine, but ensure that all memory operations are predicated. + return !IsUse && MI->mayLoadOrStore() ? 
false : true; } bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { @@ -591,6 +1069,8 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { dbgs() << " - " << Preheader->getName() << "\n"; else if (auto *Preheader = MLI->findLoopPreheader(ML)) dbgs() << " - " << Preheader->getName() << "\n"; + else if (auto *Preheader = MLI->findLoopPreheader(ML, true)) + dbgs() << " - " << Preheader->getName() << "\n"; for (auto *MBB : ML->getBlocks()) dbgs() << " - " << MBB->getName() << "\n"; ); @@ -608,14 +1088,12 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { return nullptr; }; - LowOverheadLoop LoLoop(ML); + LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII); // Search the preheader for the start intrinsic. // FIXME: I don't see why we shouldn't be supporting multiple predecessors // with potentially multiple set.loop.iterations, so we need to enable this. - if (auto *Preheader = ML->getLoopPreheader()) - LoLoop.Start = SearchForStart(Preheader); - else if (auto *Preheader = MLI->findLoopPreheader(ML, true)) - LoLoop.Start = SearchForStart(Preheader); + if (LoLoop.Preheader) + LoLoop.Start = SearchForStart(LoLoop.Preheader); else return false; @@ -624,7 +1102,9 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { // whether we can convert that predicate using tail predication. for (auto *MBB : reverse(ML->getBlocks())) { for (auto &MI : *MBB) { - if (MI.getOpcode() == ARM::t2LoopDec) + if (MI.isDebugValue()) + continue; + else if (MI.getOpcode() == ARM::t2LoopDec) LoLoop.Dec = &MI; else if (MI.getOpcode() == ARM::t2LoopEnd) LoLoop.End = &MI; @@ -641,28 +1121,6 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { // Check we know how to tail predicate any mve instructions. LoLoop.AnalyseMVEInst(&MI); } - - // We need to ensure that LR is not used or defined inbetween LoopDec and - // LoopEnd. 
- if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert) - continue; - - // If we find that LR has been written or read between LoopDec and - // LoopEnd, expect that the decremented value is being used else where. - // Because this value isn't actually going to be produced until the - // latch, by LE, we would need to generate a real sub. The value is also - // likely to be copied/reloaded for use of LoopEnd - in which in case - // we'd need to perform an add because it gets subtracted again by LE! - // The other option is to then generate the other form of LE which doesn't - // perform the sub. - for (auto &MO : MI.operands()) { - if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() && - MO.getReg() == ARM::LR) { - LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI); - LoLoop.Revert = true; - break; - } - } } } @@ -672,7 +1130,15 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { return false; } - LoLoop.CheckLegality(BBUtils.get(), RDA, MLI); + // Check that the only instruction using LoopDec is LoopEnd. + // TODO: Check for copy chains that really have no effect. 
+ SmallPtrSet<MachineInstr*, 2> Uses;
+ RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
+ if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
+ LoLoop.Revert = true;
+ }
+ LoLoop.CheckLegality(BBUtils.get());
Expand(LoLoop);
return true;
}
@@ -702,16 +1168,19 @@ void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
MI->eraseFromParent();
}
-bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
- bool SetFlags) const {
+bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
MachineBasicBlock *MBB = MI->getParent();
+ SmallPtrSet<MachineInstr*, 1> Ignore;
+ for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
+ if (I->getOpcode() == ARM::t2LoopEnd) {
+ Ignore.insert(&*I);
+ break;
+ }
+ }
// If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
- if (SetFlags &&
- (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
- !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
- SetFlags = false;
+ bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
MachineInstrBuilder MIB =
BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::t2SUBri));
@@ -759,7 +1228,102 @@ void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
MI->eraseFromParent();
}
+// Perform dead code elimination on the loop iteration count setup expression.
+// If we are tail-predicating, the number of elements to be processed is the
+// operand of the VCTP instruction in the vector body, see getCount(), which is
+// register $r3 in this example:
+//
+// $lr = big-itercount-expression
+// ..
+// t2DoLoopStart renamable $lr
+// vector.body:
+// ..
+// $vpr = MVE_VCTP32 renamable $r3
+// renamable $lr = t2LoopDec killed renamable $lr, 1
+// t2LoopEnd renamable $lr, %vector.body
+// tB %end
+//
+// What we would like to achieve here is to replace the do-loop start pseudo
+// instruction t2DoLoopStart with:
+//
+// $lr = MVE_DLSTP_32 killed renamable $r3
+//
+// Thus, $r3 which defines the number of elements, is written to $lr,
+// and then we want to delete the whole chain that used to define $lr,
+// see the comment below for how this chain could look.
+//
+void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
+ if (!LoLoop.IsTailPredicationLegal())
+ return;
+
+ LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
+
+ MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
+ if (!Def) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
+ return;
+ }
+
+ // Collect and remove the users of iteration count.
+ SmallPtrSet<MachineInstr*, 4> Killed = { LoLoop.Start, LoLoop.Dec,
+ LoLoop.End, LoLoop.InsertPt };
+ SmallPtrSet<MachineInstr*, 2> Remove;
+ if (RDA->isSafeToRemove(Def, Remove, Killed))
+ LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
+ else {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
+ return;
+ }
+
+ // Collect the dead code and the MBBs in which they reside.
+ RDA->collectKilledOperands(Def, Killed);
+ SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
+ for (auto *MI : Killed)
+ BasicBlocks.insert(MI->getParent());
+
+ // Collect IT blocks in all affected basic blocks.
+ std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
+ for (auto *MBB : BasicBlocks) {
+ for (auto &MI : *MBB) {
+ if (MI.getOpcode() != ARM::t2IT)
+ continue;
+ RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
+ }
+ }
+
+ // If we're removing all of the instructions within an IT block, then
+ // also remove the IT instruction.
+ SmallPtrSet<MachineInstr*, 2> ModifiedITs; + for (auto *MI : Killed) { + if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) { + MachineInstr *IT = RDA->getMIOperand(MI, *MO); + auto &CurrentBlock = ITBlocks[IT]; + CurrentBlock.erase(MI); + if (CurrentBlock.empty()) + ModifiedITs.erase(IT); + else + ModifiedITs.insert(IT); + } + } + + // Delete the killed instructions only if we don't have any IT blocks that + // need to be modified because we need to fixup the mask. + // TODO: Handle cases where IT blocks are modified. + if (ModifiedITs.empty()) { + LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n"; + for (auto *MI : Killed) + dbgs() << " - " << *MI); + LoLoop.ToRemove.insert(Killed.begin(), Killed.end()); + } else + LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n"); +} + MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) { + LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n"); + // When using tail-predication, try to delete the dead code that was used to + // calculate the number of loop iterations. + IterationCountDCE(LoLoop); + MachineInstr *InsertPt = LoLoop.InsertPt; MachineInstr *Start = LoLoop.Start; MachineBasicBlock *MBB = InsertPt->getParent(); @@ -775,109 +1339,67 @@ MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) { if (!IsDo) MIB.add(Start->getOperand(1)); - // When using tail-predication, try to delete the dead code that was used to - // calculate the number of loop iterations. 
- if (LoLoop.IsTailPredicationLegal()) { - SmallVector<MachineInstr*, 4> Killed; - SmallVector<MachineInstr*, 4> Dead; - if (auto *Def = RDA->getReachingMIDef(Start, - Start->getOperand(0).getReg())) { - Killed.push_back(Def); - - while (!Killed.empty()) { - MachineInstr *Def = Killed.back(); - Killed.pop_back(); - Dead.push_back(Def); - for (auto &MO : Def->operands()) { - if (!MO.isReg() || !MO.isKill()) - continue; - - MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg()); - if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1) - Killed.push_back(Kill); - } - } - for (auto *MI : Dead) - MI->eraseFromParent(); - } - } - // If we're inserting at a mov lr, then remove it as it's redundant. if (InsertPt != Start) - InsertPt->eraseFromParent(); - Start->eraseFromParent(); + LoLoop.ToRemove.insert(InsertPt); + LoLoop.ToRemove.insert(Start); LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB); return &*MIB; } -// Goal is to optimise and clean-up these loops: -// -// vector.body: -// renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg -// renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4 -// .. -// $lr = MVE_DLSTP_32 renamable $r3 -// -// The SUB is the old update of the loop iteration count expression, which -// is no longer needed. This sub is removed when the element count, which is in -// r3 in this example, is defined by an instruction in the loop, and it has -// no uses. 
-// -void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) { - Register ElemCount = LoLoop.VCTP->getOperand(1).getReg(); - MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back(); - - LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n"); - - if (LoLoop.ML->getNumBlocks() != 1) { - LLVM_DEBUG(dbgs() << "ARM Loops: Single block loop expected\n"); - return; - } - - LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing elemcount in operand: "; - LoLoop.VCTP->getOperand(1).dump()); - - // Find the definition we are interested in removing, if there is one. - MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount); - if (!Def) { - LLVM_DEBUG(dbgs() << "ARM Loops: Can't find a def, nothing to do.\n"); - return; - } - - // Bail if we define CPSR and it is not dead - if (!Def->registerDefIsDead(ARM::CPSR, TRI)) { - LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n"); - return; - } - - // Bail if elemcount is used in exit blocks, i.e. if it is live-in. - if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) { - LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n"); - return; - } +void ARMLowOverheadLoops::FixupReductions(LowOverheadLoop &LoLoop) const { + LLVM_DEBUG(dbgs() << "ARM Loops: Fixing up reduction(s).\n"); + auto BuildMov = [this](MachineInstr &InsertPt, Register To, Register From) { + MachineBasicBlock *MBB = InsertPt.getParent(); + MachineInstrBuilder MIB = + BuildMI(*MBB, &InsertPt, InsertPt.getDebugLoc(), TII->get(ARM::MVE_VORR)); + MIB.addDef(To); + MIB.addReg(From); + MIB.addReg(From); + MIB.addImm(0); + MIB.addReg(0); + MIB.addReg(To); + LLVM_DEBUG(dbgs() << "ARM Loops: Inserted VMOV: " << *MIB); + }; - // Bail if there are uses after this Def in the block. 
- SmallVector<MachineInstr*, 4> Uses; - RDA->getReachingLocalUses(Def, ElemCount, Uses); - if (Uses.size()) { - LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n"); - return; - } + for (auto &Reduction : LoLoop.Reductions) { + MachineInstr &Copy = Reduction->Copy; + MachineInstr &Reduce = Reduction->Reduce; + Register DestReg = Copy.getOperand(0).getReg(); - Uses.clear(); - RDA->getAllInstWithUseBefore(Def, ElemCount, Uses); + // Change the initialiser if present + if (Reduction->Init) { + MachineInstr *Init = Reduction->Init; - // Remove Def if there are no uses, or if the only use is the VCTP - // instruction. - if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) { - LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: "; - Def->dump()); - Def->eraseFromParent(); - return; + for (unsigned i = 0; i < Init->getNumOperands(); ++i) { + MachineOperand &MO = Init->getOperand(i); + if (MO.isReg() && MO.isUse() && MO.isTied() && + Init->findTiedOperandIdx(i) == 0) + Init->getOperand(i).setReg(DestReg); + } + Init->getOperand(0).setReg(DestReg); + LLVM_DEBUG(dbgs() << "ARM Loops: Changed init regs: " << *Init); + } else + BuildMov(LoLoop.Preheader->instr_back(), DestReg, Copy.getOperand(1).getReg()); + + // Change the reducing op to write to the register that is used to copy + // its value on the next iteration. Also update the tied-def operand. + Reduce.getOperand(0).setReg(DestReg); + Reduce.getOperand(5).setReg(DestReg); + LLVM_DEBUG(dbgs() << "ARM Loops: Changed reduction regs: " << Reduce); + + // Instead of a vpsel, just copy the register into the necessary one. + MachineInstr &VPSEL = Reduction->VPSEL; + if (VPSEL.getOperand(0).getReg() != DestReg) + BuildMov(VPSEL, VPSEL.getOperand(0).getReg(), DestReg); + + // Remove the unnecessary instructions. 
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing:\n" + << " - " << Copy + << " - " << VPSEL << "\n"); + Copy.eraseFromParent(); + VPSEL.eraseFromParent(); } - - LLVM_DEBUG(dbgs() << "ARM Loops: Can't remove loop update, it's used by:\n"; - for (auto U : Uses) U->dump()); } void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { @@ -893,28 +1415,24 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { }; // There are a few scenarios which we have to fix up: - // 1) A VPT block with is only predicated by the vctp and has no internal vpr - // defs. - // 2) A VPT block which is only predicated by the vctp but has an internal - // vpr def. - // 3) A VPT block which is predicated upon the vctp as well as another vpr - // def. - // 4) A VPT block which is not predicated upon a vctp, but contains it and - // all instructions within the block are predicated upon in. - + // 1. VPT Blocks with non-uniform predicates: + // - a. When the divergent instruction is a vctp + // - b. When the block uses a vpst, and is only predicated on the vctp + // - c. When the block uses a vpt and (optionally) contains one or more + // vctp. + // 2. VPT Blocks with uniform predicates: + // - a. The block uses a vpst, and is only predicated on the vctp for (auto &Block : LoLoop.getVPTBlocks()) { SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); if (Block.HasNonUniformPredicate()) { PredicatedMI *Divergent = Block.getDivergent(); if (isVCTP(Divergent->MI)) { - // The vctp will be removed, so the size of the vpt block needs to be - // modified. - uint64_t Size = getARMVPTBlockMask(Block.size() - 1); - Block.getVPST()->getOperand(0).setImm(Size); - LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n"); - } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { - // The VPT block has a non-uniform predicate but it's entry is guarded - // only by a vctp, which means we: + // The vctp will be removed, so the block mask of the vp(s)t will need + // to be recomputed. 
+ LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen()); + } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { + // The VPT block has a non-uniform predicate but it uses a vpst and its + // entry is guarded only by a vctp, which means we: // - Need to remove the original vpst. // - Then need to unpredicate any following instructions, until // we come across the divergent vpr def. @@ -922,7 +1440,7 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { // the divergent vpr def. // TODO: We could be producing more VPT blocks than necessary and could // fold the newly created one into a proceeding one. - for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()), + for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()), E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I) RemovePredicate(&*I); @@ -935,28 +1453,58 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { ++Size; ++I; } + // Create a VPST (with a null mask for now, we'll recompute it later). MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt, InsertAt->getDebugLoc(), TII->get(ARM::MVE_VPST)); - MIB.addImm(getARMVPTBlockMask(Size)); - LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); + MIB.addImm(0); + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen()); LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB); - Block.getVPST()->eraseFromParent(); + LoLoop.ToRemove.insert(Block.getPredicateThen()); + LoLoop.BlockMasksToRecompute.insert(MIB.getInstr()); + } + // Else, if the block uses a vpt, iterate over the block, removing the + // extra VCTPs it may contain. 
+ else if (Block.isVPT()) { + bool RemovedVCTP = false; + for (PredicatedMI &Elt : Block.getInsts()) { + MachineInstr *MI = Elt.MI; + if (isVCTP(MI)) { + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI); + LoLoop.ToRemove.insert(MI); + RemovedVCTP = true; + continue; + } + } + if (RemovedVCTP) + LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen()); } - } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { - // A vpt block which is only predicated upon vctp and has no internal vpr - // defs: + } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) { + // A vpt block starting with VPST, is only predicated upon vctp and has no + // internal vpr defs: // - Remove vpst. // - Unpredicate the remaining instructions. - LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); - Block.getVPST()->eraseFromParent(); + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen()); + LoLoop.ToRemove.insert(Block.getPredicateThen()); for (auto &PredMI : Insts) RemovePredicate(PredMI.MI); } } - - LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); - LoLoop.VCTP->eraseFromParent(); + LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n"); + // Remove the "main" VCTP + LoLoop.ToRemove.insert(LoLoop.VCTP); + LLVM_DEBUG(dbgs() << " " << *LoLoop.VCTP); + // Remove remaining secondary VCTPs + for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) { + // All VCTPs that aren't marked for removal yet should be unpredicated ones. + // The predicated ones should have already been marked for removal when + // visiting the VPT blocks. 
+ if (LoLoop.ToRemove.insert(VCTP).second) {
+ assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None &&
+ "Removing Predicated VCTP without updating the block mask!");
+ LLVM_DEBUG(dbgs() << " " << *VCTP);
+ }
+ }
}
void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
@@ -973,9 +1521,8 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
MIB.add(End->getOperand(0));
MIB.add(End->getOperand(1));
LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
-
- LoLoop.End->eraseFromParent();
- LoLoop.Dec->eraseFromParent();
+ LoLoop.ToRemove.insert(LoLoop.Dec);
+ LoLoop.ToRemove.insert(End);
return &*MIB;
};
@@ -1001,7 +1548,7 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
RevertWhile(LoLoop.Start);
else
LoLoop.Start->eraseFromParent();
- bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
+ bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
} else {
LoLoop.Start = ExpandLoopStart(LoLoop);
@@ -1009,10 +1556,35 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
LoLoop.End = ExpandLoopEnd(LoLoop);
RemoveDeadBranch(LoLoop.End);
if (LoLoop.IsTailPredicationLegal()) {
- RemoveLoopUpdate(LoLoop);
ConvertVPTBlocks(LoLoop);
+ FixupReductions(LoLoop);
+ }
+ for (auto *I : LoLoop.ToRemove) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
+ I->eraseFromParent();
+ }
+ for (auto *I : LoLoop.BlockMasksToRecompute) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
+ recomputeVPTBlockMask(*I);
+ LLVM_DEBUG(dbgs() << " ... done: " << *I);
}
}
+
+ PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
+ DFS.ProcessLoop();
+ const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
+ for (auto *MBB : PostOrder) {
+ recomputeLiveIns(*MBB);
+ // FIXME: For some reason, the live-in print order is non-deterministic for
+ // our tests and I can't work out why... So just sort them.
+ MBB->sortUniqueLiveIns(); + } + + for (auto *MBB : reverse(PostOrder)) + recomputeLivenessFlags(*MBB); + + // We've moved, removed and inserted new instructions, so update RDA. + RDA->reset(); } bool ARMLowOverheadLoops::RevertNonLoops() { |