summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp')
-rw-r--r--llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp1116
1 files changed, 844 insertions, 272 deletions
diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
index 6717d4706aefe..be75d6bef08c4 100644
--- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
+++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
@@ -35,6 +35,20 @@
/// are defined to be as large as this maximum sequence of replacement
/// instructions.
///
+/// A note on VPR.P0 (the lane mask):
+/// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
+/// "VPT Active" context (which includes low-overhead loops and vpt blocks).
+/// They will simply "and" the result of their calculation with the current
+/// value of VPR.P0. You can think of it like this:
+/// \verbatim
+/// if VPT active: ; Between a DLSTP/LETP, or for predicated instrs
+/// VPR.P0 &= Value
+/// else
+/// VPR.P0 = Value
+/// \endverbatim
+/// When we're inside the low-overhead loop (between DLSTP and LETP), we always
+/// fall into the "VPT active" case, so we can consider that all VPR writes by
+/// one of those instructions are actually "and"s.
//===----------------------------------------------------------------------===//
#include "ARM.h"
@@ -45,6 +59,7 @@
#include "Thumb2InstrInfo.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineLoopUtils.h"
@@ -60,34 +75,93 @@ using namespace llvm;
namespace {
+ using InstSet = SmallPtrSetImpl<MachineInstr *>;
+
+ class PostOrderLoopTraversal {
+ MachineLoop &ML;
+ MachineLoopInfo &MLI;
+ SmallPtrSet<MachineBasicBlock*, 4> Visited;
+ SmallVector<MachineBasicBlock*, 4> Order;
+
+ public:
+ PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
+ : ML(ML), MLI(MLI) { }
+
+ const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
+ return Order;
+ }
+
+ // Visit all the blocks within the loop, as well as exit blocks and any
+ // blocks properly dominating the header.
+ void ProcessLoop() {
+ std::function<void(MachineBasicBlock*)> Search = [this, &Search]
+ (MachineBasicBlock *MBB) -> void {
+ if (Visited.count(MBB))
+ return;
+
+ Visited.insert(MBB);
+ for (auto *Succ : MBB->successors()) {
+ if (!ML.contains(Succ))
+ continue;
+ Search(Succ);
+ }
+ Order.push_back(MBB);
+ };
+
+ // Insert exit blocks.
+ SmallVector<MachineBasicBlock*, 2> ExitBlocks;
+ ML.getExitBlocks(ExitBlocks);
+ for (auto *MBB : ExitBlocks)
+ Order.push_back(MBB);
+
+ // Then add the loop body.
+ Search(ML.getHeader());
+
+ // Then try the preheader and its predecessors.
+ std::function<void(MachineBasicBlock*)> GetPredecessor =
+ [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
+ Order.push_back(MBB);
+ if (MBB->pred_size() == 1)
+ GetPredecessor(*MBB->pred_begin());
+ };
+
+ if (auto *Preheader = ML.getLoopPreheader())
+ GetPredecessor(Preheader);
+ else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
+ GetPredecessor(Preheader);
+ }
+ };
+
struct PredicatedMI {
MachineInstr *MI = nullptr;
SetVector<MachineInstr*> Predicates;
public:
- PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
- MI(I) {
+ PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
+ assert(I && "Instruction must not be null!");
Predicates.insert(Preds.begin(), Preds.end());
}
};
- // Represent a VPT block, a list of instructions that begins with a VPST and
- // has a maximum of four proceeding instructions. All instructions within the
- // block are predicated upon the vpr and we allow instructions to define the
- // vpr within in the block too.
+ // Represent a VPT block, a list of instructions that begins with a VPT/VPST
+ // and has a maximum of four subsequent instructions. All instructions within
+ // the block are predicated upon the vpr and we allow instructions to define
+ // the vpr within the block too.
class VPTBlock {
- std::unique_ptr<PredicatedMI> VPST;
+ // The predicate then instruction, which is either a VPT, or a VPST
+ // instruction.
+ std::unique_ptr<PredicatedMI> PredicateThen;
PredicatedMI *Divergent = nullptr;
SmallVector<PredicatedMI, 4> Insts;
public:
VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
- VPST = std::make_unique<PredicatedMI>(MI, Preds);
+ PredicateThen = std::make_unique<PredicatedMI>(MI, Preds);
}
void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
- if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
+ if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) {
Divergent = &Insts.back();
LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
}
@@ -104,38 +178,73 @@ namespace {
// Is the given instruction part of the predicate set controlling the entry
// to the block.
bool IsPredicatedOn(MachineInstr *MI) const {
- return VPST->Predicates.count(MI);
+ return PredicateThen->Predicates.count(MI);
+ }
+
+ // Returns true if this is a VPT instruction.
+ bool isVPT() const { return !isVPST(); }
+
+ // Returns true if this is a VPST instruction.
+ bool isVPST() const {
+ return PredicateThen->MI->getOpcode() == ARM::MVE_VPST;
}
// Is the given instruction the only predicate which controls the entry to
// the block.
bool IsOnlyPredicatedOn(MachineInstr *MI) const {
- return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
+ return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1;
}
unsigned size() const { return Insts.size(); }
SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
- MachineInstr *getVPST() const { return VPST->MI; }
+ MachineInstr *getPredicateThen() const { return PredicateThen->MI; }
PredicatedMI *getDivergent() const { return Divergent; }
};
+ struct Reduction {
+ MachineInstr *Init;
+ MachineInstr &Copy;
+ MachineInstr &Reduce;
+ MachineInstr &VPSEL;
+
+ Reduction(MachineInstr *Init, MachineInstr *Mov, MachineInstr *Add,
+ MachineInstr *Sel)
+ : Init(Init), Copy(*Mov), Reduce(*Add), VPSEL(*Sel) { }
+ };
+
struct LowOverheadLoop {
- MachineLoop *ML = nullptr;
+ MachineLoop &ML;
+ MachineBasicBlock *Preheader = nullptr;
+ MachineLoopInfo &MLI;
+ ReachingDefAnalysis &RDA;
+ const TargetRegisterInfo &TRI;
+ const ARMBaseInstrInfo &TII;
MachineFunction *MF = nullptr;
MachineInstr *InsertPt = nullptr;
MachineInstr *Start = nullptr;
MachineInstr *Dec = nullptr;
MachineInstr *End = nullptr;
MachineInstr *VCTP = nullptr;
+ SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs;
VPTBlock *CurrentBlock = nullptr;
SetVector<MachineInstr*> CurrentPredicate;
SmallVector<VPTBlock, 4> VPTBlocks;
+ SmallPtrSet<MachineInstr*, 4> ToRemove;
+ SmallVector<std::unique_ptr<Reduction>, 1> Reductions;
+ SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
bool Revert = false;
bool CannotTailPredicate = false;
- LowOverheadLoop(MachineLoop *ML) : ML(ML) {
- MF = ML->getHeader()->getParent();
+ LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
+ ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
+ const ARMBaseInstrInfo &TII)
+ : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII) {
+ MF = ML.getHeader()->getParent();
+ if (auto *MBB = ML.getLoopPreheader())
+ Preheader = MBB;
+ else if (auto *MBB = MLI.findLoopPreheader(&ML, true))
+ Preheader = MBB;
}
// If this is an MVE instruction, check that we know how to use tail
@@ -151,22 +260,30 @@ namespace {
// For now, let's keep things really simple and only support a single
// block for tail predication.
return !Revert && FoundAllComponents() && VCTP &&
- !CannotTailPredicate && ML->getNumBlocks() == 1;
+ !CannotTailPredicate && ML.getNumBlocks() == 1;
}
- bool ValidateTailPredicate(MachineInstr *StartInsertPt,
- ReachingDefAnalysis *RDA,
- MachineLoopInfo *MLI);
+ // Check that the predication in the loop will be equivalent once we
+ // perform the conversion. Also ensure that we can provide the number
+ // of elements to the loop start instruction.
+ bool ValidateTailPredicate(MachineInstr *StartInsertPt);
+
+ // See whether the live-out instructions are a reduction that we can fixup
+ // later.
+ bool FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers);
+
+ // Check that any values available outside of the loop will be the same
+ // after tail predication conversion.
+ bool ValidateLiveOuts();
// Is it safe to define LR with DLS/WLS?
// LR can be defined if it is the operand to start, because it's the same
// value, or if it's going to be equivalent to the operand to Start.
- MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
+ MachineInstr *isSafeToDefineLR();
// Check the branch targets are within range and we satisfy our
// restrictions.
- void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
- MachineLoopInfo *MLI);
+ void CheckLegality(ARMBasicBlockUtils *BBUtils);
bool FoundAllComponents() const {
return Start && Dec && End;
@@ -241,18 +358,19 @@ namespace {
void RevertWhile(MachineInstr *MI) const;
- bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
+ bool RevertLoopDec(MachineInstr *MI) const;
void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
- void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
-
void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
+ void FixupReductions(LowOverheadLoop &LoLoop) const;
+
MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
void Expand(LowOverheadLoop &LoLoop);
+ void IterationCountDCE(LowOverheadLoop &LoLoop);
};
}
@@ -261,7 +379,7 @@ char ARMLowOverheadLoops::ID = 0;
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
false, false)
-MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
+MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
// We can define LR because LR already contains the same value.
if (Start->getOperand(0).getReg() == ARM::LR)
return Start;
@@ -279,52 +397,22 @@ MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
// Find an insertion point:
// - Is there a (mov lr, Count) before Start? If so, and nothing else writes
// to Count before Start, we can insert at that mov.
- if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR))
- if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
+ if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
+ if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
return LRDef;
// - Is there a (mov lr, Count) after Start? If so, and nothing else writes
// to Count after Start, we can insert at that mov.
- if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR))
- if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
+ if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
+ if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
return LRDef;
// We've found no suitable LR def and Start doesn't use LR directly. Can we
// just define LR anyway?
- if (!RDA->isRegUsedAfter(Start, ARM::LR))
- return Start;
-
- return nullptr;
-}
-
-// Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must
-// not define a register that is used by any instructions, after and including,
-// 'To'. These instructions also must not redefine any of Froms operands.
-template<typename Iterator>
-static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) {
- SmallSet<int, 2> Defs;
- // First check that From would compute the same value if moved.
- for (auto &MO : From->operands()) {
- if (!MO.isReg() || MO.isUndef() || !MO.getReg())
- continue;
- if (MO.isDef())
- Defs.insert(MO.getReg());
- else if (!RDA->hasSameReachingDef(From, To, MO.getReg()))
- return false;
- }
-
- // Now walk checking that the rest of the instructions will compute the same
- // value.
- for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) {
- for (auto &MO : I->operands())
- if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg()))
- return false;
- }
- return true;
+ return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
}
-bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
- ReachingDefAnalysis *RDA, MachineLoopInfo *MLI) {
+bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
assert(VCTP && "VCTP instruction expected but is not set");
// All predication within the loop should be based on vctp. If the block
// isn't predicated on entry, check whether the vctp is within the block
@@ -332,24 +420,35 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
for (auto &Block : VPTBlocks) {
if (Block.IsPredicatedOn(VCTP))
continue;
- if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
+ if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
- << *Block.getDivergent()->MI);
+ << *Block.getDivergent()->MI);
return false;
}
SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
for (auto &PredMI : Insts) {
- if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
+ // Check the instructions in the block and only allow:
+ // - VCTPs
+ // - Instructions predicated on the main VCTP
+ // - Any VCMP
+ // - VCMPs just "and" their result with VPR.P0. Whether they are
+ // located before/after the VCTP is irrelevant - the end result will
+ // be the same in both cases, so there's no point in requiring them
+ // to be located after the VCTP!
+ if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) ||
+ VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0)
continue;
LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
- << " - which is predicated on:\n";
- for (auto *MI : PredMI.Predicates)
- dbgs() << " - " << *MI;
- );
+ << " - which is predicated on:\n";
+ for (auto *MI : PredMI.Predicates)
+ dbgs() << " - " << *MI);
return false;
}
}
+ if (!ValidateLiveOuts())
+ return false;
+
// For tail predication, we need to provide the number of elements, instead
// of the iteration count, to the loop start instruction. The number of
// elements is provided to the vctp instruction, so we need to check that
@@ -359,7 +458,7 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
// If the register is defined within loop, then we can't perform TP.
// TODO: Check whether this is just a mov of a register that would be
// available.
- if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
+ if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
return false;
}
@@ -367,17 +466,20 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
// The element count register maybe defined after InsertPt, in which case we
// need to try to move either InsertPt or the def so that the [w|d]lstp can
// use the value.
- MachineBasicBlock *InsertBB = InsertPt->getParent();
- if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) {
- if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) {
- if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) {
+ // TODO: On failing to move an instruction, check if the count is provided by
+ // a mov and whether we can use the mov operand directly.
+ MachineBasicBlock *InsertBB = StartInsertPt->getParent();
+ if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
+ if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
+ if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
ElemDef->removeFromParent();
- InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef);
+ InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
<< *ElemDef);
- } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) {
- InsertPt->removeFromParent();
- InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt);
+ } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
+ StartInsertPt->removeFromParent();
+ InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
+ StartInsertPt);
LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
} else {
LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
@@ -390,10 +492,10 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
// Especially in the case of while loops, InsertBB may not be the
// preheader, so we need to check that the register isn't redefined
// before entering the loop.
- auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
+ auto CannotProvideElements = [this](MachineBasicBlock *MBB,
Register NumElements) {
// NumElements is redefined in this block.
- if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
+ if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
return true;
// Don't continue searching up through multiple predecessors.
@@ -404,7 +506,7 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
};
// First, find the block that looks like the preheader.
- MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
+ MachineBasicBlock *MBB = Preheader;
if (!MBB) {
LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
return false;
@@ -419,13 +521,372 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
MBB = *MBB->pred_begin();
}
- LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n");
+ // Check that the value change of the element count is what we expect and
+ // that the predication will be equivalent. For this we need:
+ // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
+ // and we can also allow register copies within the chain too.
+ auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
+ return -getAddSubImmediate(*MI) == ExpectedVecWidth;
+ };
+
+ MBB = VCTP->getParent();
+ if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
+ SmallPtrSet<MachineInstr*, 2> ElementChain;
+ SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
+ unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
+
+ Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end());
+
+ if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
+ bool FoundSub = false;
+
+ for (auto *MI : ElementChain) {
+ if (isMovRegOpcode(MI->getOpcode()))
+ continue;
+
+ if (isSubImmOpcode(MI->getOpcode())) {
+ if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
+ return false;
+ FoundSub = true;
+ } else
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
+ for (auto *MI : ElementChain)
+ dbgs() << " - " << *MI);
+ ToRemove.insert(ElementChain.begin(), ElementChain.end());
+ }
+ }
+ return true;
+}
+
+static bool isVectorPredicated(MachineInstr *MI) {
+ int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
+ return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
+}
+
+static bool isRegInClass(const MachineOperand &MO,
+ const TargetRegisterClass *Class) {
+ return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
+}
+
+// MVE 'narrowing' operate on half a lane, reading from half and writing
+// to half, which are referred to as the top and bottom half. The other
+// half retains its previous value.
+static bool retainsPreviousHalfElement(const MachineInstr &MI) {
+ const MCInstrDesc &MCID = MI.getDesc();
+ uint64_t Flags = MCID.TSFlags;
+ return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
+}
+
+// Some MVE instructions read from the top/bottom halves of their operand(s)
+// and generate a vector result with result elements that are double the
+// width of the input.
+static bool producesDoubleWidthResult(const MachineInstr &MI) {
+ const MCInstrDesc &MCID = MI.getDesc();
+ uint64_t Flags = MCID.TSFlags;
+ return (Flags & ARMII::DoubleWidthResult) != 0;
+}
+
+static bool isHorizontalReduction(const MachineInstr &MI) {
+ const MCInstrDesc &MCID = MI.getDesc();
+ uint64_t Flags = MCID.TSFlags;
+ return (Flags & ARMII::HorizontalReduction) != 0;
+}
+
+// Can this instruction generate a non-zero result when given only zeroed
+// operands? This allows us to know that, given operands with false bytes
+// zeroed by masked loads, that the result will also contain zeros in those
+// bytes.
+static bool canGenerateNonZeros(const MachineInstr &MI) {
+
+ // Check for instructions which can write into a larger element size,
+ // possibly writing into a previous zero'd lane.
+ if (producesDoubleWidthResult(MI))
+ return true;
+
+ switch (MI.getOpcode()) {
+ default:
+ break;
+ // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
+ // fp16 -> fp32 vector conversions.
+ // Instructions that perform a NOT will generate 1s from 0s.
+ case ARM::MVE_VMVN:
+ case ARM::MVE_VORN:
+ // Count leading zeros will do just that!
+ case ARM::MVE_VCLZs8:
+ case ARM::MVE_VCLZs16:
+ case ARM::MVE_VCLZs32:
+ return true;
+ }
+ return false;
+}
+
+
+// Look at its register uses to see if it can only receive zeros
+// into its false lanes which would then produce zeros. Also check that
+// the output register is also defined by a FalseLanesZero instruction
+// so that if tail-predication happens, the lanes that aren't updated will
+// still be zeros.
+static bool producesFalseLanesZero(MachineInstr &MI,
+ const TargetRegisterClass *QPRs,
+ const ReachingDefAnalysis &RDA,
+ InstSet &FalseLanesZero) {
+ if (canGenerateNonZeros(MI))
+ return false;
+
+ bool AllowScalars = isHorizontalReduction(MI);
+ for (auto &MO : MI.operands()) {
+ if (!MO.isReg() || !MO.getReg())
+ continue;
+ if (!isRegInClass(MO, QPRs) && AllowScalars)
+ continue;
+ if (auto *OpDef = RDA.getMIOperand(&MI, MO))
+ if (FalseLanesZero.count(OpDef))
+ continue;
+ return false;
+ }
+ LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
+ return true;
+}
+
+bool
+LowOverheadLoop::FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers) {
+ // Also check for reductions where the operation needs to be merging values
+ // from the last and previous loop iterations. This means an instruction
+ // producing a value and a vmov storing the value calculated in the previous
+ // iteration. So we can have two live-out regs, one produced by a vmov and
+ // both being consumed by a vpsel.
+ LLVM_DEBUG(dbgs() << "ARM Loops: Looking for reduction live-outs:\n";
+ for (auto *MI : LiveMIs)
+ dbgs() << " - " << *MI);
+
+ if (!Preheader)
+ return false;
+
+ // Expect a vmov, a vadd and a single vpsel user.
+ // TODO: This means we can't currently support multiple reductions in the
+ // loop.
+ if (LiveMIs.size() != 2 || LiveOutUsers.size() != 1)
+ return false;
+
+ MachineInstr *VPSEL = *LiveOutUsers.begin();
+ if (VPSEL->getOpcode() != ARM::MVE_VPSEL)
+ return false;
+
+ unsigned VPRIdx = llvm::findFirstVPTPredOperandIdx(*VPSEL) + 1;
+ MachineInstr *Pred = RDA.getMIOperand(VPSEL, VPRIdx);
+ if (!Pred || Pred != VCTP) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Not using equivalent predicate.\n");
+ return false;
+ }
+
+ MachineInstr *Reduce = RDA.getMIOperand(VPSEL, 1);
+ if (!Reduce)
+ return false;
+
+ assert(LiveMIs.count(Reduce) && "Expected MI to be live-out");
+
+ // TODO: Support more operations than VADD.
+ switch (VCTP->getOpcode()) {
+ default:
+ return false;
+ case ARM::MVE_VCTP8:
+ if (Reduce->getOpcode() != ARM::MVE_VADDi8)
+ return false;
+ break;
+ case ARM::MVE_VCTP16:
+ if (Reduce->getOpcode() != ARM::MVE_VADDi16)
+ return false;
+ break;
+ case ARM::MVE_VCTP32:
+ if (Reduce->getOpcode() != ARM::MVE_VADDi32)
+ return false;
+ break;
+ }
+
+ // Test that the reduce op is overwriting one of its operands.
+ if (Reduce->getOperand(0).getReg() != Reduce->getOperand(1).getReg() &&
+ Reduce->getOperand(0).getReg() != Reduce->getOperand(2).getReg()) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Reducing op isn't overwriting itself.\n");
+ return false;
+ }
+
+ // Check that the VORR is actually a VMOV.
+ MachineInstr *Copy = RDA.getMIOperand(VPSEL, 2);
+ if (!Copy || Copy->getOpcode() != ARM::MVE_VORR ||
+ !Copy->getOperand(1).isReg() || !Copy->getOperand(2).isReg() ||
+ Copy->getOperand(1).getReg() != Copy->getOperand(2).getReg())
+ return false;
+
+ assert(LiveMIs.count(Copy) && "Expected MI to be live-out");
+
+ // Check that the vadd and vmov are only used by each other and the vpsel.
+ SmallPtrSet<MachineInstr*, 2> CopyUsers;
+ RDA.getGlobalUses(Copy, Copy->getOperand(0).getReg(), CopyUsers);
+ if (CopyUsers.size() > 2 || !CopyUsers.count(Reduce)) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Copy users unsupported.\n");
+ return false;
+ }
+
+ SmallPtrSet<MachineInstr*, 2> ReduceUsers;
+ RDA.getGlobalUses(Reduce, Reduce->getOperand(0).getReg(), ReduceUsers);
+ if (ReduceUsers.size() > 2 || !ReduceUsers.count(Copy)) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Reduce users unsupported.\n");
+ return false;
+ }
+
+ // Then find whether there's an instruction initialising the register that
+ // is storing the reduction.
+ SmallPtrSet<MachineInstr*, 2> Incoming;
+ RDA.getLiveOuts(Preheader, Copy->getOperand(1).getReg(), Incoming);
+ if (Incoming.size() > 1)
+ return false;
+
+ MachineInstr *Init = Incoming.empty() ? nullptr : *Incoming.begin();
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found a reduction:\n"
+ << " - " << *Copy
+ << " - " << *Reduce
+ << " - " << *VPSEL);
+ Reductions.push_back(std::make_unique<Reduction>(Init, Copy, Reduce, VPSEL));
return true;
}
-void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
- ReachingDefAnalysis *RDA,
- MachineLoopInfo *MLI) {
+bool LowOverheadLoop::ValidateLiveOuts() {
+ // We want to find out if the tail-predicated version of this loop will
+ // produce the same values as the loop in its original form. For this to
+ // be true, the newly inserted implicit predication must not change the
+ // (observable) results.
+ // We're doing this because many instructions in the loop will not be
+ // predicated and so the conversion from VPT predication to tail-predication
+ // can result in different values being produced; due to the tail-predication
+ // preventing many instructions from updating their falsely predicated
+ // lanes. This analysis assumes that all the instructions perform lane-wise
+ // operations and don't perform any exchanges.
+ // A masked load, whether through VPT or tail predication, will write zeros
+ // to any of the falsely predicated bytes. So, from the loads, we know that
+ // the false lanes are zeroed and here we're trying to track that those false
+ // lanes remain zero, or where they change, the differences are masked away
+ // by their user(s).
+ // All MVE loads and stores have to be predicated, so we know that any load
+ // operands, or stored results are equivalent already. Other explicitly
+ // predicated instructions will perform the same operation in the original
+ // loop and the tail-predicated form too. Because of this, we can insert
+ // loads, stores and other predicated instructions into our Predicated
+ // set and build from there.
+ const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
+ SetVector<MachineInstr *> FalseLanesUnknown;
+ SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
+ SmallPtrSet<MachineInstr *, 4> Predicated;
+ MachineBasicBlock *Header = ML.getHeader();
+
+ for (auto &MI : *Header) {
+ const MCInstrDesc &MCID = MI.getDesc();
+ uint64_t Flags = MCID.TSFlags;
+ if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
+ continue;
+
+ if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
+ continue;
+
+ // Predicated loads will write zeros to the falsely predicated bytes of the
+ // destination register.
+ if (isVectorPredicated(&MI)) {
+ if (MI.mayLoad())
+ FalseLanesZero.insert(&MI);
+ Predicated.insert(&MI);
+ continue;
+ }
+
+ if (MI.getNumDefs() == 0)
+ continue;
+
+ if (!producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) {
+ // We require retaining and horizontal operations to operate upon zero'd
+ // false lanes to ensure the conversion doesn't change the output.
+ if (retainsPreviousHalfElement(MI) || isHorizontalReduction(MI))
+ return false;
+ // Otherwise we need to evaluate this instruction later to see whether
+ // unknown false lanes will get masked away by their user(s).
+ FalseLanesUnknown.insert(&MI);
+ } else if (!isHorizontalReduction(MI))
+ FalseLanesZero.insert(&MI);
+ }
+
+ auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
+ SmallPtrSetImpl<MachineInstr *> &Predicated) {
+ SmallPtrSet<MachineInstr *, 2> Uses;
+ RDA.getGlobalUses(MI, MO.getReg(), Uses);
+ for (auto *Use : Uses) {
+ if (Use != MI && !Predicated.count(Use))
+ return false;
+ }
+ return true;
+ };
+
+ // Visit the unknowns in reverse so that we can start at the values being
+ // stored and then we can work towards the leaves, hopefully adding more
+ // instructions to Predicated. Successfully terminating the loop means that
+ // all the unknown values have been found to be masked by predicated user(s).
+ // For any unpredicated values, we store them in NonPredicated so that we
+ // can later check whether these form a reduction.
+ SmallPtrSet<MachineInstr*, 2> NonPredicated;
+ for (auto *MI : reverse(FalseLanesUnknown)) {
+ for (auto &MO : MI->operands()) {
+ if (!isRegInClass(MO, QPRs) || !MO.isDef())
+ continue;
+ if (!HasPredicatedUsers(MI, MO, Predicated)) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
+ << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
+ NonPredicated.insert(MI);
+ continue;
+ }
+ }
+ // Any unknown false lanes have been masked away by the user(s).
+ Predicated.insert(MI);
+ }
+
+ SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
+ SmallPtrSet<MachineInstr*, 2> LiveOutUsers;
+ SmallVector<MachineBasicBlock *, 2> ExitBlocks;
+ ML.getExitBlocks(ExitBlocks);
+ assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
+ assert(ExitBlocks.size() == 1 && "Expected a single exit block");
+ MachineBasicBlock *ExitBB = ExitBlocks.front();
+ for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
+ // Check Q-regs that are live in the exit blocks. We don't collect scalars
+ // because they won't be affected by lane predication.
+ if (QPRs->contains(RegMask.PhysReg)) {
+ if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
+ LiveOutMIs.insert(MI);
+ RDA.getLiveInUses(ExitBB, RegMask.PhysReg, LiveOutUsers);
+ }
+ }
+
+ // If we have any non-predicated live-outs, they need to be part of a
+ // reduction that we can fixup later. The reduction takes the form of an
+ // operation that uses its previous values through a vmov and then a vpsel
+ // that resides in the exit blocks to select the final bytes from the n and
+ // n-1 iterations.
+ if (!NonPredicated.empty() &&
+ !FindValidReduction(NonPredicated, LiveOutUsers))
+ return false;
+
+ // We've already validated that any VPT predication within the loop will be
+ // equivalent when we perform the predication transformation; so we know that
+ // any VPT predicated instruction is predicated upon VCTP. Any live-out
+ // instruction needs to be predicated, so check this here. The instructions
+ // in NonPredicated have been found to be a reduction whose legality we can
+ // ensure.
+ for (auto *MI : LiveOutMIs)
+ if (!isVectorPredicated(MI) && !NonPredicated.count(MI))
+ return false;
+
+ return true;
+}
+
+void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
if (Revert)
return;
@@ -434,7 +895,7 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
// TODO Maybe there's cases where the target doesn't have to be the header,
// but for now be safe and revert.
- if (End->getOperand(1).getMBB() != ML->getHeader()) {
+ if (End->getOperand(1).getMBB() != ML.getHeader()) {
LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
Revert = true;
return;
@@ -442,8 +903,8 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
// The WLS and LE instructions have 12-bits for the label offset. WLS
// requires a positive offset, while LE uses negative.
- if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
- !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
+ if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
+ !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
Revert = true;
return;
@@ -458,7 +919,7 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
return;
}
- InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
+ InsertPt = Revert ? nullptr : isSafeToDefineLR();
if (!InsertPt) {
LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
Revert = true;
@@ -473,9 +934,9 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
return;
}
- assert(ML->getBlocks().size() == 1 &&
+ assert(ML.getBlocks().size() == 1 &&
"Shouldn't be processing a loop with more than one block");
- CannotTailPredicate = !ValidateTailPredicate(InsertPt, RDA, MLI);
+ CannotTailPredicate = !ValidateTailPredicate(InsertPt);
LLVM_DEBUG(if (CannotTailPredicate)
dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
}
@@ -484,29 +945,44 @@ bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
if (CannotTailPredicate)
return false;
- // Only support a single vctp.
- if (isVCTP(MI) && VCTP)
- return false;
+ if (isVCTP(MI)) {
+ // If we find another VCTP, check whether it uses the same value as the main VCTP.
+ // If it does, store it in the SecondaryVCTPs set, else refuse it.
+ if (VCTP) {
+ if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
+ !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
+ "definition from the main VCTP");
+ return false;
+ }
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI);
+ SecondaryVCTPs.insert(MI);
+ } else {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI);
+ VCTP = MI;
+ }
+ } else if (isVPTOpcode(MI->getOpcode())) {
+ if (MI->getOpcode() != ARM::MVE_VPST) {
+ assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 &&
+ "VPT does not implicitly define VPR?!");
+ CurrentPredicate.insert(MI);
+ }
- // Start a new vpt block when we discover a vpt.
- if (MI->getOpcode() == ARM::MVE_VPST) {
VPTBlocks.emplace_back(MI, CurrentPredicate);
CurrentBlock = &VPTBlocks.back();
return true;
- } else if (isVCTP(MI))
- VCTP = MI;
- else if (MI->getOpcode() == ARM::MVE_VPSEL ||
- MI->getOpcode() == ARM::MVE_VPNOT)
+ } else if (MI->getOpcode() == ARM::MVE_VPSEL ||
+ MI->getOpcode() == ARM::MVE_VPNOT) {
+ // TODO: Allow VPSEL and VPNOT, we currently cannot because:
+ // 1) It will use the VPR as a predicate operand, but doesn't have to be
+ // inside a VPT block, which means we can assert while building up
+ // the VPT block because we don't find another VPT or VPST to begin a new
+ // one.
+ // 2) VPSEL still requires a VPR operand even after tail predicating,
+ // which means we can't remove it unless there is another
+ // instruction, such as vcmp, that can provide the VPR def.
return false;
-
- // TODO: Allow VPSEL and VPNOT, we currently cannot because:
- // 1) It will use the VPR as a predicate operand, but doesn't have to be
- // instead a VPT block, which means we can assert while building up
- // the VPT block because we don't find another VPST to being a new
- // one.
- // 2) VPSEL still requires a VPR operand even after tail predicating,
- // which means we can't remove it unless there is another
- // instruction, such as vcmp, that can provide the VPR def.
+ }
bool IsUse = false;
bool IsDef = false;
@@ -548,7 +1024,9 @@ bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
return false;
}
- return true;
+ // If the instruction is already explicitly predicated, then the conversion
+ // will be fine, but ensure that all memory operations are predicated.
+ return !IsUse && MI->mayLoadOrStore() ? false : true;
}
bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
@@ -591,6 +1069,8 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
dbgs() << " - " << Preheader->getName() << "\n";
else if (auto *Preheader = MLI->findLoopPreheader(ML))
dbgs() << " - " << Preheader->getName() << "\n";
+ else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
+ dbgs() << " - " << Preheader->getName() << "\n";
for (auto *MBB : ML->getBlocks())
dbgs() << " - " << MBB->getName() << "\n";
);
@@ -608,14 +1088,12 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
return nullptr;
};
- LowOverheadLoop LoLoop(ML);
+ LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
// Search the preheader for the start intrinsic.
// FIXME: I don't see why we shouldn't be supporting multiple predecessors
// with potentially multiple set.loop.iterations, so we need to enable this.
- if (auto *Preheader = ML->getLoopPreheader())
- LoLoop.Start = SearchForStart(Preheader);
- else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
- LoLoop.Start = SearchForStart(Preheader);
+ if (LoLoop.Preheader)
+ LoLoop.Start = SearchForStart(LoLoop.Preheader);
else
return false;
@@ -624,7 +1102,9 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
// whether we can convert that predicate using tail predication.
for (auto *MBB : reverse(ML->getBlocks())) {
for (auto &MI : *MBB) {
- if (MI.getOpcode() == ARM::t2LoopDec)
+ if (MI.isDebugValue())
+ continue;
+ else if (MI.getOpcode() == ARM::t2LoopDec)
LoLoop.Dec = &MI;
else if (MI.getOpcode() == ARM::t2LoopEnd)
LoLoop.End = &MI;
@@ -641,28 +1121,6 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
// Check we know how to tail predicate any mve instructions.
LoLoop.AnalyseMVEInst(&MI);
}
-
- // We need to ensure that LR is not used or defined inbetween LoopDec and
- // LoopEnd.
- if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
- continue;
-
- // If we find that LR has been written or read between LoopDec and
- // LoopEnd, expect that the decremented value is being used else where.
- // Because this value isn't actually going to be produced until the
- // latch, by LE, we would need to generate a real sub. The value is also
- // likely to be copied/reloaded for use of LoopEnd - in which in case
- // we'd need to perform an add because it gets subtracted again by LE!
- // The other option is to then generate the other form of LE which doesn't
- // perform the sub.
- for (auto &MO : MI.operands()) {
- if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
- MO.getReg() == ARM::LR) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
- LoLoop.Revert = true;
- break;
- }
- }
}
}
@@ -672,7 +1130,15 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
return false;
}
- LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
+ // Check that the only instruction using LoopDec is LoopEnd.
+ // TODO: Check for copy chains that really have no effect.
+ SmallPtrSet<MachineInstr*, 2> Uses;
+ RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
+ if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
+ LoLoop.Revert = true;
+ }
+ LoLoop.CheckLegality(BBUtils.get());
Expand(LoLoop);
return true;
}
@@ -702,16 +1168,19 @@ void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
MI->eraseFromParent();
}
-bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
- bool SetFlags) const {
+bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
MachineBasicBlock *MBB = MI->getParent();
+ SmallPtrSet<MachineInstr*, 1> Ignore;
+ for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
+ if (I->getOpcode() == ARM::t2LoopEnd) {
+ Ignore.insert(&*I);
+ break;
+ }
+ }
// If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
- if (SetFlags &&
- (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
- !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
- SetFlags = false;
+ bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
TII->get(ARM::t2SUBri));
@@ -759,7 +1228,102 @@ void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
MI->eraseFromParent();
}
+// Perform dead code elimination on the loop iteration count setup expression.
+// If we are tail-predicating, the number of elements to be processed is the
+// operand of the VCTP instruction in the vector body, see getCount(), which is
+// register $r3 in this example:
+//
+// $lr = big-itercount-expression
+// ..
+// t2DoLoopStart renamable $lr
+// vector.body:
+// ..
+// $vpr = MVE_VCTP32 renamable $r3
+// renamable $lr = t2LoopDec killed renamable $lr, 1
+// t2LoopEnd renamable $lr, %vector.body
+// tB %end
+//
+// What we would like to achieve here is to replace the do-loop start pseudo
+// instruction t2DoLoopStart with:
+//
+// $lr = MVE_DLSTP_32 killed renamable $r3
+//
+// Thus, $r3 which defines the number of elements, is written to $lr,
+// and then we want to delete the whole chain that used to define $lr,
+// see the comment below how this chain could look like.
+//
+void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
+ if (!LoLoop.IsTailPredicationLegal())
+ return;
+
+ LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
+
+ MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
+ if (!Def) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
+ return;
+ }
+
+ // Collect and remove the users of iteration count.
+ SmallPtrSet<MachineInstr*, 4> Killed = { LoLoop.Start, LoLoop.Dec,
+ LoLoop.End, LoLoop.InsertPt };
+ SmallPtrSet<MachineInstr*, 2> Remove;
+ if (RDA->isSafeToRemove(Def, Remove, Killed))
+ LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
+ else {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
+ return;
+ }
+
+ // Collect the dead code and the MBBs in which they reside.
+ RDA->collectKilledOperands(Def, Killed);
+ SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
+ for (auto *MI : Killed)
+ BasicBlocks.insert(MI->getParent());
+
+ // Collect IT blocks in all affected basic blocks.
+ std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
+ for (auto *MBB : BasicBlocks) {
+ for (auto &MI : *MBB) {
+ if (MI.getOpcode() != ARM::t2IT)
+ continue;
+ RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
+ }
+ }
+
+ // If we're removing all of the instructions within an IT block, then
+ // also remove the IT instruction.
+ SmallPtrSet<MachineInstr*, 2> ModifiedITs;
+ for (auto *MI : Killed) {
+ if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
+ MachineInstr *IT = RDA->getMIOperand(MI, *MO);
+ auto &CurrentBlock = ITBlocks[IT];
+ CurrentBlock.erase(MI);
+ if (CurrentBlock.empty())
+ ModifiedITs.erase(IT);
+ else
+ ModifiedITs.insert(IT);
+ }
+ }
+
+ // Delete the killed instructions only if we don't have any IT blocks that
+ // need to be modified because we need to fixup the mask.
+ // TODO: Handle cases where IT blocks are modified.
+ if (ModifiedITs.empty()) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
+ for (auto *MI : Killed)
+ dbgs() << " - " << *MI);
+ LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
+ } else
+ LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
+}
+
MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
+ // When using tail-predication, try to delete the dead code that was used to
+ // calculate the number of loop iterations.
+ IterationCountDCE(LoLoop);
+
MachineInstr *InsertPt = LoLoop.InsertPt;
MachineInstr *Start = LoLoop.Start;
MachineBasicBlock *MBB = InsertPt->getParent();
@@ -775,109 +1339,67 @@ MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
if (!IsDo)
MIB.add(Start->getOperand(1));
- // When using tail-predication, try to delete the dead code that was used to
- // calculate the number of loop iterations.
- if (LoLoop.IsTailPredicationLegal()) {
- SmallVector<MachineInstr*, 4> Killed;
- SmallVector<MachineInstr*, 4> Dead;
- if (auto *Def = RDA->getReachingMIDef(Start,
- Start->getOperand(0).getReg())) {
- Killed.push_back(Def);
-
- while (!Killed.empty()) {
- MachineInstr *Def = Killed.back();
- Killed.pop_back();
- Dead.push_back(Def);
- for (auto &MO : Def->operands()) {
- if (!MO.isReg() || !MO.isKill())
- continue;
-
- MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
- if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
- Killed.push_back(Kill);
- }
- }
- for (auto *MI : Dead)
- MI->eraseFromParent();
- }
- }
-
// If we're inserting at a mov lr, then remove it as it's redundant.
if (InsertPt != Start)
- InsertPt->eraseFromParent();
- Start->eraseFromParent();
+ LoLoop.ToRemove.insert(InsertPt);
+ LoLoop.ToRemove.insert(Start);
LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
return &*MIB;
}
-// Goal is to optimise and clean-up these loops:
-//
-// vector.body:
-// renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg
-// renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4
-// ..
-// $lr = MVE_DLSTP_32 renamable $r3
-//
-// The SUB is the old update of the loop iteration count expression, which
-// is no longer needed. This sub is removed when the element count, which is in
-// r3 in this example, is defined by an instruction in the loop, and it has
-// no uses.
-//
-void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
- Register ElemCount = LoLoop.VCTP->getOperand(1).getReg();
- MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back();
-
- LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n");
-
- if (LoLoop.ML->getNumBlocks() != 1) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Single block loop expected\n");
- return;
- }
-
- LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing elemcount in operand: ";
- LoLoop.VCTP->getOperand(1).dump());
-
- // Find the definition we are interested in removing, if there is one.
- MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount);
- if (!Def) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Can't find a def, nothing to do.\n");
- return;
- }
-
- // Bail if we define CPSR and it is not dead
- if (!Def->registerDefIsDead(ARM::CPSR, TRI)) {
- LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n");
- return;
- }
-
- // Bail if elemcount is used in exit blocks, i.e. if it is live-in.
- if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n");
- return;
- }
+void ARMLowOverheadLoops::FixupReductions(LowOverheadLoop &LoLoop) const {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Fixing up reduction(s).\n");
+ auto BuildMov = [this](MachineInstr &InsertPt, Register To, Register From) {
+ MachineBasicBlock *MBB = InsertPt.getParent();
+ MachineInstrBuilder MIB =
+ BuildMI(*MBB, &InsertPt, InsertPt.getDebugLoc(), TII->get(ARM::MVE_VORR));
+ MIB.addDef(To);
+ MIB.addReg(From);
+ MIB.addReg(From);
+ MIB.addImm(0);
+ MIB.addReg(0);
+ MIB.addReg(To);
+ LLVM_DEBUG(dbgs() << "ARM Loops: Inserted VMOV: " << *MIB);
+ };
- // Bail if there are uses after this Def in the block.
- SmallVector<MachineInstr*, 4> Uses;
- RDA->getReachingLocalUses(Def, ElemCount, Uses);
- if (Uses.size()) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n");
- return;
- }
+ for (auto &Reduction : LoLoop.Reductions) {
+ MachineInstr &Copy = Reduction->Copy;
+ MachineInstr &Reduce = Reduction->Reduce;
+ Register DestReg = Copy.getOperand(0).getReg();
- Uses.clear();
- RDA->getAllInstWithUseBefore(Def, ElemCount, Uses);
+ // Change the initialiser if present
+ if (Reduction->Init) {
+ MachineInstr *Init = Reduction->Init;
- // Remove Def if there are no uses, or if the only use is the VCTP
- // instruction.
- if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: ";
- Def->dump());
- Def->eraseFromParent();
- return;
+ for (unsigned i = 0; i < Init->getNumOperands(); ++i) {
+ MachineOperand &MO = Init->getOperand(i);
+ if (MO.isReg() && MO.isUse() && MO.isTied() &&
+ Init->findTiedOperandIdx(i) == 0)
+ Init->getOperand(i).setReg(DestReg);
+ }
+ Init->getOperand(0).setReg(DestReg);
+ LLVM_DEBUG(dbgs() << "ARM Loops: Changed init regs: " << *Init);
+ } else
+ BuildMov(LoLoop.Preheader->instr_back(), DestReg, Copy.getOperand(1).getReg());
+
+ // Change the reducing op to write to the register that is used to copy
+ // its value on the next iteration. Also update the tied-def operand.
+ Reduce.getOperand(0).setReg(DestReg);
+ Reduce.getOperand(5).setReg(DestReg);
+ LLVM_DEBUG(dbgs() << "ARM Loops: Changed reduction regs: " << Reduce);
+
+ // Instead of a vpsel, just copy the register into the necessary one.
+ MachineInstr &VPSEL = Reduction->VPSEL;
+ if (VPSEL.getOperand(0).getReg() != DestReg)
+ BuildMov(VPSEL, VPSEL.getOperand(0).getReg(), DestReg);
+
+ // Remove the unnecessary instructions.
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing:\n"
+ << " - " << Copy
+ << " - " << VPSEL << "\n");
+ Copy.eraseFromParent();
+ VPSEL.eraseFromParent();
}
-
- LLVM_DEBUG(dbgs() << "ARM Loops: Can't remove loop update, it's used by:\n";
- for (auto U : Uses) U->dump());
}
void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
@@ -893,28 +1415,24 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
};
// There are a few scenarios which we have to fix up:
- // 1) A VPT block with is only predicated by the vctp and has no internal vpr
- // defs.
- // 2) A VPT block which is only predicated by the vctp but has an internal
- // vpr def.
- // 3) A VPT block which is predicated upon the vctp as well as another vpr
- // def.
- // 4) A VPT block which is not predicated upon a vctp, but contains it and
- // all instructions within the block are predicated upon in.
-
+ // 1. VPT Blocks with non-uniform predicates:
+ // - a. When the divergent instruction is a vctp
+ // - b. When the block uses a vpst, and is only predicated on the vctp
+ // - c. When the block uses a vpt and (optionally) contains one or more
+ // vctp.
+ // 2. VPT Blocks with uniform predicates:
+ // - a. The block uses a vpst, and is only predicated on the vctp
for (auto &Block : LoLoop.getVPTBlocks()) {
SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
if (Block.HasNonUniformPredicate()) {
PredicatedMI *Divergent = Block.getDivergent();
if (isVCTP(Divergent->MI)) {
- // The vctp will be removed, so the size of the vpt block needs to be
- // modified.
- uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
- Block.getVPST()->getOperand(0).setImm(Size);
- LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
- } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
- // The VPT block has a non-uniform predicate but it's entry is guarded
- // only by a vctp, which means we:
+ // The vctp will be removed, so the block mask of the vp(s)t will need
+ // to be recomputed.
+ LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
+ } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
+ // The VPT block has a non-uniform predicate but it uses a vpst and its
+ // entry is guarded only by a vctp, which means we:
// - Need to remove the original vpst.
// - Then need to unpredicate any following instructions, until
// we come across the divergent vpr def.
@@ -922,7 +1440,7 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
// the divergent vpr def.
// TODO: We could be producing more VPT blocks than necessary and could
// fold the newly created one into a proceeding one.
- for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
+ for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()),
E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
RemovePredicate(&*I);
@@ -935,28 +1453,58 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
++Size;
++I;
}
+ // Create a VPST (with a null mask for now, we'll recompute it later).
MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
InsertAt->getDebugLoc(),
TII->get(ARM::MVE_VPST));
- MIB.addImm(getARMVPTBlockMask(Size));
- LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
+ MIB.addImm(0);
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
- Block.getVPST()->eraseFromParent();
+ LoLoop.ToRemove.insert(Block.getPredicateThen());
+ LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
+ }
+ // Else, if the block uses a vpt, iterate over the block, removing the
+ // extra VCTPs it may contain.
+ else if (Block.isVPT()) {
+ bool RemovedVCTP = false;
+ for (PredicatedMI &Elt : Block.getInsts()) {
+ MachineInstr *MI = Elt.MI;
+ if (isVCTP(MI)) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI);
+ LoLoop.ToRemove.insert(MI);
+ RemovedVCTP = true;
+ continue;
+ }
+ }
+ if (RemovedVCTP)
+ LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
}
- } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
- // A vpt block which is only predicated upon vctp and has no internal vpr
- // defs:
+ } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) {
+ // A vpt block starting with VPST, is only predicated upon vctp and has no
+ // internal vpr defs:
// - Remove vpst.
// - Unpredicate the remaining instructions.
- LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
- Block.getVPST()->eraseFromParent();
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
+ LoLoop.ToRemove.insert(Block.getPredicateThen());
for (auto &PredMI : Insts)
RemovePredicate(PredMI.MI);
}
}
-
- LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
- LoLoop.VCTP->eraseFromParent();
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n");
+ // Remove the "main" VCTP
+ LoLoop.ToRemove.insert(LoLoop.VCTP);
+ LLVM_DEBUG(dbgs() << " " << *LoLoop.VCTP);
+ // Remove remaining secondary VCTPs
+ for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) {
+ // All VCTPs that aren't marked for removal yet should be unpredicated ones.
+ // The predicated ones should have already been marked for removal when
+ // visiting the VPT blocks.
+ if (LoLoop.ToRemove.insert(VCTP).second) {
+ assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None &&
+ "Removing Predicated VCTP without updating the block mask!");
+ LLVM_DEBUG(dbgs() << " " << *VCTP);
+ }
+ }
}
void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
@@ -973,9 +1521,8 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
MIB.add(End->getOperand(0));
MIB.add(End->getOperand(1));
LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
-
- LoLoop.End->eraseFromParent();
- LoLoop.Dec->eraseFromParent();
+ LoLoop.ToRemove.insert(LoLoop.Dec);
+ LoLoop.ToRemove.insert(End);
return &*MIB;
};
@@ -1001,7 +1548,7 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
RevertWhile(LoLoop.Start);
else
LoLoop.Start->eraseFromParent();
- bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
+ bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
} else {
LoLoop.Start = ExpandLoopStart(LoLoop);
@@ -1009,10 +1556,35 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
LoLoop.End = ExpandLoopEnd(LoLoop);
RemoveDeadBranch(LoLoop.End);
if (LoLoop.IsTailPredicationLegal()) {
- RemoveLoopUpdate(LoLoop);
ConvertVPTBlocks(LoLoop);
+ FixupReductions(LoLoop);
+ }
+ for (auto *I : LoLoop.ToRemove) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
+ I->eraseFromParent();
+ }
+ for (auto *I : LoLoop.BlockMasksToRecompute) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
+ recomputeVPTBlockMask(*I);
+ LLVM_DEBUG(dbgs() << " ... done: " << *I);
}
}
+
+ PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
+ DFS.ProcessLoop();
+ const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
+ for (auto *MBB : PostOrder) {
+ recomputeLiveIns(*MBB);
+ // FIXME: For some reason, the live-in print order is non-deterministic for
+ // our tests and I can't figure out why... So just sort them.
+ MBB->sortUniqueLiveIns();
+ }
+
+ for (auto *MBB : reverse(PostOrder))
+ recomputeLivenessFlags(*MBB);
+
+ // We've moved, removed and inserted new instructions, so update RDA.
+ RDA->reset();
}
bool ARMLowOverheadLoops::RevertNonLoops() {